use axum::{
async_trait,
extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts, Path},
RequestPartsExt,
};
use serde::de::DeserializeOwned;
/// Extractor that makes path parameters optional.
///
/// Wraps the result of axum's [`Path`] extractor: the inner value is
/// `Some(T)` when the path parameters were deserialized successfully and
/// `None` when the matched route captured zero parameters. Any other
/// [`Path`] failure is surfaced as a [`PathRejection`].
#[derive(Debug)]
pub struct OptionalPath<T>(pub Option<T>);
#[async_trait]
impl<T, S> FromRequestParts<S> for OptionalPath<T>
where
    T: DeserializeOwned + Send + 'static,
    S: Send + Sync,
{
    /// Failures are reported with the same rejection type that [`Path`] uses.
    type Rejection = PathRejection;

    /// Runs the [`Path<T>`] extractor and relaxes exactly one of its failures.
    ///
    /// # Returns
    ///
    /// * `Ok(OptionalPath(Some(value)))` when `Path<T>` extraction succeeds.
    /// * `Ok(OptionalPath(None))` when extraction fails with
    ///   `ErrorKind::WrongNumberOfParameters { got: 0, .. }`.
    ///
    /// # Errors
    ///
    /// Every other [`PathRejection`] is returned unchanged.
    async fn from_request_parts(
        parts: &mut http::request::Parts,
        _: &S,
    ) -> Result<Self, Self::Rejection> {
        // Happy path first: hand back the extracted parameters as-is.
        let rejection = match parts.extract::<Path<T>>().await {
            Ok(Path(value)) => return Ok(Self(Some(value))),
            Err(rejection) => rejection,
        };

        // Inspect the rejection by reference so it can still be returned
        // untouched when it is not the "zero parameters" case.
        let route_has_no_params = match &rejection {
            PathRejection::FailedToDeserializePathParams(failure) => matches!(
                failure.kind(),
                ErrorKind::WrongNumberOfParameters { got: 0, .. }
            ),
            _ => false,
        };

        if route_has_no_params {
            Ok(Self(None))
        } else {
            Err(rejection)
        }
    }
}
#[cfg(test)]
mod tests {
    use std::num::NonZeroU32;

    use axum::{routing::get, Router};

    use super::OptionalPath;
    use crate::test_helpers::TestClient;

    /// Checks `OptionalPath` on a route without parameters, with a valid
    /// parameter, and with parameters that fail to deserialize.
    #[crate::test]
    async fn optional_path_handles_missing_valid_and_invalid_params() {
        // Echoes the extracted number, or 0 when no parameter was captured.
        async fn handle(OptionalPath(param): OptionalPath<NonZeroU32>) -> String {
            let num = param.map_or(0, |p| p.get());
            format!("Success: {num}")
        }

        // The same handler serves both the parameterless and the
        // parameterized route, so both extractor outcomes are exercised.
        let app = Router::new()
            .route("/", get(handle))
            .route("/:num", get(handle));

        let client = TestClient::new(app);

        // (request path, expected response body)
        let cases = [
            // No path parameter captured: the extractor yields `None`.
            ("/", "Success: 0"),
            // Valid parameter: the extractor yields `Some(1)`.
            ("/1", "Success: 1"),
            // Deserialization failures must still be reported as rejections.
            (
                "/0",
                "Invalid URL: invalid value: integer `0`, expected a nonzero u32",
            ),
            ("/NaN", "Invalid URL: Cannot parse `\"NaN\"` to a `u32`"),
        ];

        for (path, expected) in cases {
            let res = client.get(path).await;
            assert_eq!(
                res.text().await,
                expected,
                "unexpected response body for GET {path}"
            );
        }
    }
}