axum/extract/
nested_path.rs

1use std::{
2    sync::Arc,
3    task::{Context, Poll},
4};
5
6use crate::extract::Request;
7use axum_core::extract::FromRequestParts;
8use http::request::Parts;
9use tower_layer::{layer_fn, Layer};
10use tower_service::Service;
11
12use super::rejection::NestedPathRejection;
13
/// Access the path the matched route is nested at.
///
/// This can, for example, be used when doing redirects.
///
/// Extraction fails with [`NestedPathRejection`] when no nested path has been
/// recorded for the request, typically because the matched route is not nested.
///
/// # Example
///
/// ```
/// use axum::{
///     Router,
///     extract::NestedPath,
///     routing::get,
/// };
///
/// let api = Router::new().route(
///     "/users",
///     get(|path: NestedPath| async move {
///         // `path` will be "/api" because that's what this
///         // router is nested at when we build `app`
///         let path = path.as_str();
///     })
/// );
///
/// let app = Router::new().nest("/api", api);
/// # let _: Router = app;
/// ```
#[derive(Debug, Clone)]
pub struct NestedPath(Arc<str>);
41
42impl NestedPath {
43    /// Returns a `str` representation of the path.
44    pub fn as_str(&self) -> &str {
45        &self.0
46    }
47}
48
49impl<S> FromRequestParts<S> for NestedPath
50where
51    S: Send + Sync,
52{
53    type Rejection = NestedPathRejection;
54
55    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
56        match parts.extensions.get::<Self>() {
57            Some(nested_path) => Ok(nested_path.clone()),
58            None => Err(NestedPathRejection),
59        }
60    }
61}
62
63#[derive(Clone)]
64pub(crate) struct SetNestedPath<S> {
65    inner: S,
66    path: Arc<str>,
67}
68
69impl<S> SetNestedPath<S> {
70    pub(crate) fn layer(path: &str) -> impl Layer<S, Service = Self> + Clone {
71        let path = Arc::from(path);
72        layer_fn(move |inner| Self {
73            inner,
74            path: Arc::clone(&path),
75        })
76    }
77}
78
79impl<S, B> Service<Request<B>> for SetNestedPath<S>
80where
81    S: Service<Request<B>>,
82{
83    type Response = S::Response;
84    type Error = S::Error;
85    type Future = S::Future;
86
87    #[inline]
88    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89        self.inner.poll_ready(cx)
90    }
91
92    fn call(&mut self, mut req: Request<B>) -> Self::Future {
93        if let Some(prev) = req.extensions_mut().get_mut::<NestedPath>() {
94            let new_path = if prev.as_str() == "/" {
95                Arc::clone(&self.path)
96            } else {
97                format!("{}{}", prev.as_str().trim_end_matches('/'), self.path).into()
98            };
99            prev.0 = new_path;
100        } else {
101            req.extensions_mut()
102                .insert(NestedPath(Arc::clone(&self.path)));
103        };
104
105        self.inner.call(req)
106    }
107}
108
#[cfg(test)]
mod tests {
    use axum_core::response::Response;
    use http::StatusCode;

    use crate::{
        extract::{NestedPath, Request},
        middleware::{from_fn, Next},
        routing::get,
        test_helpers::*,
        Router,
    };

    // Each handler below asserts on the `NestedPath` it extracts; each test then
    // checks that the request completed with `200 OK`.

    #[crate::test]
    async fn one_level_of_nesting() {
        let users = Router::new().route(
            "/users",
            get(|path: NestedPath| {
                assert_eq!(path.as_str(), "/api");
                async {}
            }),
        );
        let client = TestClient::new(Router::new().nest("/api", users));

        let status = client.get("/api/users").await.status();
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn one_level_of_nesting_with_trailing_slash() {
        // With no outer nesting, the nest path is reported exactly as given.
        let users = Router::new().route(
            "/users",
            get(|path: NestedPath| {
                assert_eq!(path.as_str(), "/api/");
                async {}
            }),
        );
        let client = TestClient::new(Router::new().nest("/api/", users));

        let status = client.get("/api/users").await.status();
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn two_levels_of_nesting() {
        let users = Router::new().route(
            "/users",
            get(|path: NestedPath| {
                assert_eq!(path.as_str(), "/api/v2");
                async {}
            }),
        );
        let v2 = Router::new().nest("/v2", users);
        let client = TestClient::new(Router::new().nest("/api", v2));

        let status = client.get("/api/v2/users").await.status();
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn two_levels_of_nesting_with_trailing_slash() {
        // The outer path's trailing slash is dropped when the inner path is appended.
        let users = Router::new().route(
            "/users",
            get(|path: NestedPath| {
                assert_eq!(path.as_str(), "/api/v2");
                async {}
            }),
        );
        let v2 = Router::new().nest("/v2", users);
        let client = TestClient::new(Router::new().nest("/api/", v2));

        let status = client.get("/api/v2/users").await.status();
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn in_fallbacks() {
        let fallback_router = Router::new().fallback(get(|path: NestedPath| {
            assert_eq!(path.as_str(), "/api");
            async {}
        }));
        let client = TestClient::new(Router::new().nest("/api", fallback_router));

        let status = client.get("/api/doesnt-exist").await.status();
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn in_middleware() {
        async fn check_nested_path(path: NestedPath, request: Request, next: Next) -> Response {
            assert_eq!(path.as_str(), "/api");
            next.run(request).await
        }

        let users = Router::new()
            .route("/users", get(|| async {}))
            .layer(from_fn(check_nested_path));
        let client = TestClient::new(Router::new().nest("/api", users));

        let status = client.get("/api/users").await.status();
        assert_eq!(status, StatusCode::OK);
    }
}