// axum/extract/nested_path.rs
1use std::{
2 sync::Arc,
3 task::{Context, Poll},
4};
5
6use crate::extract::Request;
7use axum_core::extract::FromRequestParts;
8use http::request::Parts;
9use tower_layer::{layer_fn, Layer};
10use tower_service::Service;
11
12use super::rejection::NestedPathRejection;
13
/// Extractor exposing the path prefix at which the current router was nested.
///
/// Prefixes accumulate across nesting levels: a handler reached through
/// `nest("/api", Router::new().nest("/v2", ..))` observes `"/api/v2"`. With a
/// single level the prefix is reported verbatim, including a trailing slash
/// (`nest("/api/", ..)` yields `"/api/"`).
///
/// Extraction fails with `NestedPathRejection` when no `NestedPath` is present
/// in the request extensions (presumably: the request was not routed through a
/// nested router).
#[derive(Debug, Clone)]
pub struct NestedPath(Arc<str>);

impl NestedPath {
    /// Borrows the accumulated nested path as a string slice.
    pub fn as_str(&self) -> &str {
        &*self.0
    }
}
48
49impl<S> FromRequestParts<S> for NestedPath
50where
51 S: Send + Sync,
52{
53 type Rejection = NestedPathRejection;
54
55 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
56 match parts.extensions.get::<Self>() {
57 Some(nested_path) => Ok(nested_path.clone()),
58 None => Err(NestedPathRejection),
59 }
60 }
61}
62
/// Middleware that records a nesting prefix in the request's `NestedPath`
/// extension before forwarding the request to `inner`.
#[derive(Clone)]
pub(crate) struct SetNestedPath<S> {
    /// Service that receives the request once the prefix has been recorded.
    inner: S,
    /// Prefix this layer contributes to the accumulated nested path.
    path: Arc<str>,
}
68
69impl<S> SetNestedPath<S> {
70 pub(crate) fn layer(path: &str) -> impl Layer<S, Service = Self> + Clone {
71 let path = Arc::from(path);
72 layer_fn(move |inner| Self {
73 inner,
74 path: Arc::clone(&path),
75 })
76 }
77}
78
79impl<S, B> Service<Request<B>> for SetNestedPath<S>
80where
81 S: Service<Request<B>>,
82{
83 type Response = S::Response;
84 type Error = S::Error;
85 type Future = S::Future;
86
87 #[inline]
88 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
89 self.inner.poll_ready(cx)
90 }
91
92 fn call(&mut self, mut req: Request<B>) -> Self::Future {
93 if let Some(prev) = req.extensions_mut().get_mut::<NestedPath>() {
94 let new_path = if prev.as_str() == "/" {
95 Arc::clone(&self.path)
96 } else {
97 format!("{}{}", prev.as_str().trim_end_matches('/'), self.path).into()
98 };
99 prev.0 = new_path;
100 } else {
101 req.extensions_mut()
102 .insert(NestedPath(Arc::clone(&self.path)));
103 };
104
105 self.inner.call(req)
106 }
107}
108
#[cfg(test)]
mod tests {
    use axum_core::response::Response;
    use http::StatusCode;

    use crate::{
        extract::{NestedPath, Request},
        middleware::{from_fn, Next},
        routing::get,
        test_helpers::*,
        Router,
    };

    /// A single `nest("/api", ..)` exposes exactly `"/api"`.
    #[crate::test]
    async fn one_level_of_nesting() {
        let inner = Router::new().route(
            "/users",
            get(|path: NestedPath| {
                assert_eq!(path.as_str(), "/api");
                async {}
            }),
        );
        let app = Router::new().nest("/api", inner);

        let client = TestClient::new(app);
        let response = client.get("/api/users").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// With one level, a trailing slash on the prefix is preserved.
    #[crate::test]
    async fn one_level_of_nesting_with_trailing_slash() {
        let inner = Router::new().route(
            "/users",
            get(|path: NestedPath| {
                assert_eq!(path.as_str(), "/api/");
                async {}
            }),
        );
        let app = Router::new().nest("/api/", inner);

        let client = TestClient::new(app);
        let response = client.get("/api/users").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// Prefixes from nested routers are concatenated outer-to-inner.
    #[crate::test]
    async fn two_levels_of_nesting() {
        let inner = Router::new().route(
            "/users",
            get(|path: NestedPath| {
                assert_eq!(path.as_str(), "/api/v2");
                async {}
            }),
        );
        let app = Router::new().nest("/api", Router::new().nest("/v2", inner));

        let client = TestClient::new(app);
        let response = client.get("/api/v2/users").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// The outer prefix's trailing slash is dropped when an inner prefix is appended.
    #[crate::test]
    async fn two_levels_of_nesting_with_trailing_slash() {
        let inner = Router::new().route(
            "/users",
            get(|path: NestedPath| {
                assert_eq!(path.as_str(), "/api/v2");
                async {}
            }),
        );
        let app = Router::new().nest("/api/", Router::new().nest("/v2", inner));

        let client = TestClient::new(app);
        let response = client.get("/api/v2/users").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// The nested path is also available to a nested router's fallback.
    #[crate::test]
    async fn in_fallbacks() {
        let inner = Router::new().fallback(get(|path: NestedPath| {
            assert_eq!(path.as_str(), "/api");
            async {}
        }));
        let app = Router::new().nest("/api", inner);

        let client = TestClient::new(app);
        let response = client.get("/api/doesnt-exist").await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// Middleware layered on the nested router can extract the nested path.
    #[crate::test]
    async fn in_middleware() {
        async fn check_nested_path(path: NestedPath, req: Request, next: Next) -> Response {
            assert_eq!(path.as_str(), "/api");
            next.run(req).await
        }

        let inner = Router::new()
            .route("/users", get(|| async {}))
            .layer(from_fn(check_nested_path));
        let app = Router::new().nest("/api", inner);

        let client = TestClient::new(app);
        let response = client.get("/api/users").await;
        assert_eq!(response.status(), StatusCode::OK);
    }
}