api-server/src/main.rs

189 lines
4.9 KiB
Rust
Raw Normal View History

mod dem;
use std::env;

use axum::{
    extract::{Extension, Path, Query, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::get,
    Json, Router,
};
use axum_macros::debug_handler;
use serde::ser::SerializeSeq;
use serde::{Deserialize, Deserializer, Serialize};
use serde_qs::axum::{QsQueryConfig, QsQueryRejection};
use tower_http::trace::TraceLayer;
use tower_http::{
    services::ServeDir,
    trace::{self},
};
use tracing::{debug, error, info, Level};

use dem::DatasetRepository;
/// Default dataset directory, used by `load_config` when `DEM_LOCATION` is not set.
const DEFAULT_DATA_DIR: &str = "/data";
/// Default HTTP port (as text), used by `load_config` when `HTTP_PORT` is not set.
const DEFAULT_PORT: &str = "3000";
/// Query-string options accepted by the single-point elevation endpoint.
#[derive(Debug, Deserialize)]
struct Opts {
    /// Return a JSON object instead of plain text.
    ///
    /// Absent means `false`; an empty value means `true`; the literal `false` means `false`;
    /// any other value means `true`.
    #[serde(deserialize_with = "empty_string_as_none", default)]
    json: bool,
}
/// Query-string parameters for the batch elevation endpoint.
#[derive(Debug, Deserialize)]
struct JsParams {
    /// Points to look up, given as a JSON array of two-number arrays (passed to the lookup as
    /// `(lat, lon)` in that order). Defaults to an empty list when the parameter is omitted.
    #[serde(deserialize_with = "deserialize_array", default)]
    pts: Vec<(f64, f64)>,
}
/// Response body of the batch elevation endpoint.
#[derive(Debug, Serialize)]
struct JsResult {
    /// One entry per requested point, each rounded to one decimal place on output;
    /// `None` (serialized as a "none" value) when the elevation is unavailable.
    #[serde(serialize_with = "serialize_vec_round")]
    elevations: Vec<Option<f64>>,
}
/// Serializes `v` as a sequence, rounding every present value to one decimal place.
///
/// `None` entries are emitted unchanged (the serializer's "none" value). Rounding is done by
/// formatting with `{:.1}` and parsing the text back, so each output is the `f64` closest to
/// the one-decimal string representation of the input.
fn serialize_vec_round<S>(v: &Vec<Option<f64>>, s: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let mut seq = s.serialize_seq(Some(v.len()))?;
    for entry in v.iter() {
        if let Some(raw) = entry {
            // Cannot fail: `{:.1}` output (including "NaN"/"inf") is always valid f64 text.
            let rounded = format!("{:.1}", raw).parse::<f64>().unwrap();
            seq.serialize_element(&rounded)?;
        } else {
            seq.serialize_element(entry)?;
        }
    }
    seq.end()
}
fn deserialize_array<'de, D>(deserializer: D) -> Result<Vec<(f64, f64)>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let result: Result<Vec<(f64, f64)>, serde_json::Error> = serde_json::from_str(&s);
match result {
Ok(x) => Ok(x),
Err(e) => Err(serde::de::Error::custom(
"Invalid array: ".to_string() + &e.to_string(),
)),
}
}
/// Deserializes a query flag into a `bool`.
///
/// Despite its name this yields a `bool`, not an `Option`:
/// * missing value → `false`
/// * present but empty value (as in `?json=`) → `true`
/// * the exact text `false` → `false`
/// * any other value → `true`
fn empty_string_as_none<'de, D>(de: D) -> Result<bool, D::Error>
where
    D: Deserializer<'de>,
{
    let raw = Option::<String>::deserialize(de)?;
    // Only an absent value or the exact text "false" switches the flag off.
    Ok(!matches!(raw.as_deref(), None | Some("false")))
}
/// Starts the elevation API server.
///
/// Loads configuration from the environment via `load_config`, registers the elevation routes,
/// mounts static files from `./assets` at the root, and listens on `[::]:<port>`.
///
/// # Panics
///
/// Panics if the configuration cannot be loaded, the listener cannot be bound, or the server
/// terminates with an error.
#[tokio::main]
async fn main() {
    // initialize tracing
    tracing_subscriber::fmt::init();

    let config = load_config().expect("failed to load configuration from the environment");
    let cache = DatasetRepository::new(config.basedir);
    let serve_dir = ServeDir::new("assets");

    let app = Router::<DatasetRepository>::new()
        .route("/v1/elevation/:lat/:lon", get(get_elevation))
        .route("/v1/elevation", get(get_elevation_js))
        .nest_service("/", serve_dir)
        .with_state(cache)
        .layer(
            TraceLayer::new_for_http()
                .make_span_with(trace::DefaultMakeSpan::new().level(Level::DEBUG))
                .on_response(trace::DefaultOnResponse::new().level(Level::DEBUG)),
        )
        // NOTE(review): this configures `serde_qs::axum::QsQuery` extractors, but the handlers
        // in this file extract with axum's own `Query`, so this presumably never takes effect.
        // Confirm whether it can be dropped or whether the handlers should switch to `QsQuery`.
        .layer(Extension(QsQueryConfig::new(5, false).error_handler(|err| {
            QsQueryRejection::new(
                format!("Invalid query string: {}", err),
                StatusCode::UNPROCESSABLE_ENTITY,
            )
        })));

    let host = format!("[::]:{}", config.port);
    info!("Will start server on {host}");
    let listener = tokio::net::TcpListener::bind(host.as_str())
        .await
        .unwrap_or_else(|e| panic!("failed to bind {host}: {e}"));
    axum::serve(listener, app)
        .await
        .expect("HTTP server terminated with an error");
}
#[debug_handler]
async fn get_elevation(
State(dsr): State<DatasetRepository>,
query_opts: Query<Opts>,
Path((lat, lon)): Path<(f64, f64)>,
) -> Response {
println!("lat: {}, lon: {}", lat, lon);
println!("query_opts: {:?}", query_opts);
let ele;
match dem::elevation_from_coordinates(&dsr, lat, lon).await {
Ok(x) => match x {
Some(el) => ele = el,
None => ele = 0.,
},
Err(e) => {
return e.to_string().into_response();
}
}
#[derive(Serialize)]
struct Ele {
elevation: f64,
}
if query_opts.json {
let r = Ele { elevation: ele };
Json(r).into_response()
} else {
format!("{}", ele).into_response()
}
}
/// Handles `GET /v1/elevation?pts=...` and returns one elevation per requested point.
///
/// Points are looked up one after another, in request order. A point whose lookup fails is
/// logged and reported as `None`, so the output always has as many entries as the input.
#[debug_handler]
async fn get_elevation_js(
    State(dsr): State<DatasetRepository>,
    Query(params): Query<JsParams>,
) -> Response {
    let mut elevations = Vec::with_capacity(params.pts.len());
    for (lat, lon) in params.pts {
        let lookup = dem::elevation_from_coordinates(&dsr, lat, lon).await;
        let value = lookup.unwrap_or_else(|e| {
            error!("Error: {e}");
            None
        });
        elevations.push(value);
    }
    Json(JsResult { elevations }).into_response()
}
/// Reads the server configuration from the environment.
///
/// * `DEM_LOCATION`: base directory passed to `DatasetRepository::new`;
///   defaults to `DEFAULT_DATA_DIR`.
/// * `HTTP_PORT`: port the HTTP server listens on; defaults to `DEFAULT_PORT`.
///
/// A variable that is unset or set to the empty string falls back to its default. An empty
/// value used to be taken verbatim, which for `HTTP_PORT` produced the unbindable address `[::]:`.
///
/// # Errors
///
/// Returns `env::VarError::NotUnicode` if a variable is set to a value that is not valid Unicode.
fn load_config() -> Result<Config, env::VarError> {
    Ok(Config {
        basedir: env_or_default("DEM_LOCATION", DEFAULT_DATA_DIR)?,
        port: env_or_default("HTTP_PORT", DEFAULT_PORT)?,
    })
}

/// Returns environment variable `key`, or `default` when it is unset or empty.
fn env_or_default(key: &str, default: &str) -> Result<String, env::VarError> {
    match env::var(key) {
        Ok(value) if !value.is_empty() => Ok(value),
        Ok(_) | Err(env::VarError::NotPresent) => Ok(default.to_string()),
        Err(e) => Err(e),
    }
}

/// Runtime configuration assembled by `load_config`.
#[derive(Debug)]
struct Config {
    /// Base directory passed to `DatasetRepository::new`.
    basedir: String,
    /// Port to listen on, kept as text and spliced into the bind address.
    port: String,
}