1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
use axum::{
    async_trait,
    extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts, Path},
    RequestPartsExt,
};
use serde::de::DeserializeOwned;

/// Extractor that extracts path arguments the same way as [`Path`], except if there aren't any.
///
/// This extractor can be used in place of `Path` when you have two routes that you want to handle
/// in mostly the same way, where one has a path parameter and the other one doesn't.
///
/// # Example
///
/// ```
/// use std::num::NonZeroU32;
/// use axum::{
///     response::IntoResponse,
///     routing::get,
///     Router,
/// };
/// use axum_extra::extract::OptionalPath;
///
/// async fn render_blog(OptionalPath(page): OptionalPath<NonZeroU32>) -> impl IntoResponse {
///     // Convert to u32, default to page 1 if not specified
///     let page = page.map_or(1, |param| param.get());
///     // ...
/// }
///
/// let app = Router::new()
///     .route("/blog", get(render_blog))
///     .route("/blog/:page", get(render_blog));
/// # let app: Router = app;
/// ```
#[derive(Debug)]
pub struct OptionalPath<T>(pub Option<T>);

#[async_trait]
impl<T, S> FromRequestParts<S> for OptionalPath<T>
where
    T: DeserializeOwned + Send + 'static,
    S: Send + Sync,
{
    type Rejection = PathRejection;

    async fn from_request_parts(
        parts: &mut http::request::Parts,
        _: &S,
    ) -> Result<Self, Self::Rejection> {
        match parts.extract::<Path<T>>().await {
            Ok(Path(params)) => Ok(Self(Some(params))),
            Err(PathRejection::FailedToDeserializePathParams(e))
                if matches!(e.kind(), ErrorKind::WrongNumberOfParameters { got: 0, .. }) =>
            {
                Ok(Self(None))
            }
            Err(e) => Err(e),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::num::NonZeroU32;

    use axum::{routing::get, Router};

    use super::OptionalPath;
    use crate::test_helpers::TestClient;

    #[crate::test]
    async fn supports_128_bit_numbers() {
        async fn handle(OptionalPath(param): OptionalPath<NonZeroU32>) -> String {
            let num = param.map_or(0, |p| p.get());
            format!("Success: {num}")
        }

        let app = Router::new()
            .route("/", get(handle))
            .route("/:num", get(handle));

        let client = TestClient::new(app);

        let res = client.get("/").await;
        assert_eq!(res.text().await, "Success: 0");

        let res = client.get("/1").await;
        assert_eq!(res.text().await, "Success: 1");

        let res = client.get("/0").await;
        assert_eq!(
            res.text().await,
            "Invalid URL: invalid value: integer `0`, expected a nonzero u32"
        );

        let res = client.get("/NaN").await;
        assert_eq!(
            res.text().await,
            "Invalid URL: Cannot parse `\"NaN\"` to a `u32`"
        );
    }
}