1use axum::{
4 extract::Request,
5 response::{IntoResponse, Redirect, Response},
6 routing::{any, MethodRouter},
7 Router,
8};
9use http::{uri::PathAndQuery, StatusCode, Uri};
10use std::{borrow::Cow, convert::Infallible};
11use tower_service::Service;
12
13mod resource;
14
15#[cfg(feature = "typed-routing")]
16mod typed;
17
18pub use self::resource::Resource;
19
20#[cfg(feature = "typed-routing")]
21pub use self::typed::WithQueryParams;
22#[cfg(feature = "typed-routing")]
23pub use axum_macros::TypedPath;
24
25#[cfg(feature = "typed-routing")]
26pub use self::typed::{SecondElementIs, TypedPath};
27
28pub trait RouterExt<S>: sealed::Sealed {
30 #[cfg(feature = "typed-routing")]
37 fn typed_get<H, T, P>(self, handler: H) -> Self
38 where
39 H: axum::handler::Handler<T, S>,
40 T: SecondElementIs<P> + 'static,
41 P: TypedPath;
42
43 #[cfg(feature = "typed-routing")]
50 fn typed_delete<H, T, P>(self, handler: H) -> Self
51 where
52 H: axum::handler::Handler<T, S>,
53 T: SecondElementIs<P> + 'static,
54 P: TypedPath;
55
56 #[cfg(feature = "typed-routing")]
63 fn typed_head<H, T, P>(self, handler: H) -> Self
64 where
65 H: axum::handler::Handler<T, S>,
66 T: SecondElementIs<P> + 'static,
67 P: TypedPath;
68
69 #[cfg(feature = "typed-routing")]
76 fn typed_options<H, T, P>(self, handler: H) -> Self
77 where
78 H: axum::handler::Handler<T, S>,
79 T: SecondElementIs<P> + 'static,
80 P: TypedPath;
81
82 #[cfg(feature = "typed-routing")]
89 fn typed_patch<H, T, P>(self, handler: H) -> Self
90 where
91 H: axum::handler::Handler<T, S>,
92 T: SecondElementIs<P> + 'static,
93 P: TypedPath;
94
95 #[cfg(feature = "typed-routing")]
102 fn typed_post<H, T, P>(self, handler: H) -> Self
103 where
104 H: axum::handler::Handler<T, S>,
105 T: SecondElementIs<P> + 'static,
106 P: TypedPath;
107
108 #[cfg(feature = "typed-routing")]
115 fn typed_put<H, T, P>(self, handler: H) -> Self
116 where
117 H: axum::handler::Handler<T, S>,
118 T: SecondElementIs<P> + 'static,
119 P: TypedPath;
120
121 #[cfg(feature = "typed-routing")]
128 fn typed_trace<H, T, P>(self, handler: H) -> Self
129 where
130 H: axum::handler::Handler<T, S>,
131 T: SecondElementIs<P> + 'static,
132 P: TypedPath;
133
134 fn route_with_tsr(self, path: &str, method_router: MethodRouter<S>) -> Self
160 where
161 Self: Sized;
162
163 fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
167 where
168 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
169 T::Response: IntoResponse,
170 T::Future: Send + 'static,
171 Self: Sized;
172}
173
174impl<S> RouterExt<S> for Router<S>
175where
176 S: Clone + Send + Sync + 'static,
177{
178 #[cfg(feature = "typed-routing")]
179 fn typed_get<H, T, P>(self, handler: H) -> Self
180 where
181 H: axum::handler::Handler<T, S>,
182 T: SecondElementIs<P> + 'static,
183 P: TypedPath,
184 {
185 self.route(P::PATH, axum::routing::get(handler))
186 }
187
188 #[cfg(feature = "typed-routing")]
189 fn typed_delete<H, T, P>(self, handler: H) -> Self
190 where
191 H: axum::handler::Handler<T, S>,
192 T: SecondElementIs<P> + 'static,
193 P: TypedPath,
194 {
195 self.route(P::PATH, axum::routing::delete(handler))
196 }
197
198 #[cfg(feature = "typed-routing")]
199 fn typed_head<H, T, P>(self, handler: H) -> Self
200 where
201 H: axum::handler::Handler<T, S>,
202 T: SecondElementIs<P> + 'static,
203 P: TypedPath,
204 {
205 self.route(P::PATH, axum::routing::head(handler))
206 }
207
208 #[cfg(feature = "typed-routing")]
209 fn typed_options<H, T, P>(self, handler: H) -> Self
210 where
211 H: axum::handler::Handler<T, S>,
212 T: SecondElementIs<P> + 'static,
213 P: TypedPath,
214 {
215 self.route(P::PATH, axum::routing::options(handler))
216 }
217
218 #[cfg(feature = "typed-routing")]
219 fn typed_patch<H, T, P>(self, handler: H) -> Self
220 where
221 H: axum::handler::Handler<T, S>,
222 T: SecondElementIs<P> + 'static,
223 P: TypedPath,
224 {
225 self.route(P::PATH, axum::routing::patch(handler))
226 }
227
228 #[cfg(feature = "typed-routing")]
229 fn typed_post<H, T, P>(self, handler: H) -> Self
230 where
231 H: axum::handler::Handler<T, S>,
232 T: SecondElementIs<P> + 'static,
233 P: TypedPath,
234 {
235 self.route(P::PATH, axum::routing::post(handler))
236 }
237
238 #[cfg(feature = "typed-routing")]
239 fn typed_put<H, T, P>(self, handler: H) -> Self
240 where
241 H: axum::handler::Handler<T, S>,
242 T: SecondElementIs<P> + 'static,
243 P: TypedPath,
244 {
245 self.route(P::PATH, axum::routing::put(handler))
246 }
247
248 #[cfg(feature = "typed-routing")]
249 fn typed_trace<H, T, P>(self, handler: H) -> Self
250 where
251 H: axum::handler::Handler<T, S>,
252 T: SecondElementIs<P> + 'static,
253 P: TypedPath,
254 {
255 self.route(P::PATH, axum::routing::trace(handler))
256 }
257
258 #[track_caller]
259 fn route_with_tsr(mut self, path: &str, method_router: MethodRouter<S>) -> Self
260 where
261 Self: Sized,
262 {
263 validate_tsr_path(path);
264 self = self.route(path, method_router);
265 add_tsr_redirect_route(self, path)
266 }
267
268 #[track_caller]
269 fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
270 where
271 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
272 T::Response: IntoResponse,
273 T::Future: Send + 'static,
274 Self: Sized,
275 {
276 validate_tsr_path(path);
277 self = self.route_service(path, service);
278 add_tsr_redirect_route(self, path)
279 }
280}
281
/// Ensure `path` may be given a trailing-slash redirect.
///
/// # Panics
///
/// Panics if `path` is exactly `/`. The panic is reported at the caller's location.
#[track_caller]
fn validate_tsr_path(path: &str) {
    assert!(
        path != "/",
        "Cannot add a trailing slash redirect route for `/`"
    );
}
288
289fn add_tsr_redirect_route<S>(router: Router<S>, path: &str) -> Router<S>
290where
291 S: Clone + Send + Sync + 'static,
292{
293 async fn redirect_handler(uri: Uri) -> Response {
294 let new_uri = map_path(uri, |path| {
295 path.strip_suffix('/')
296 .map(Cow::Borrowed)
297 .unwrap_or_else(|| Cow::Owned(format!("{path}/")))
298 });
299
300 if let Some(new_uri) = new_uri {
301 Redirect::permanent(&new_uri.to_string()).into_response()
302 } else {
303 StatusCode::BAD_REQUEST.into_response()
304 }
305 }
306
307 if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
308 router.route(path_without_trailing_slash, any(redirect_handler))
309 } else {
310 router.route(&format!("{path}/"), any(redirect_handler))
311 }
312}
313
314fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
318where
319 F: FnOnce(&str) -> Cow<'_, str>,
320{
321 let mut parts = original_uri.into_parts();
322 let path_and_query = parts.path_and_query.as_ref()?;
323
324 let new_path = f(path_and_query.path());
325
326 let new_path_and_query = if let Some(query) = &path_and_query.query() {
327 format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
328 } else {
329 new_path.parse::<PathAndQuery>().ok()?
330 };
331 parts.path_and_query = Some(new_path_and_query);
332
333 Uri::from_parts(parts).ok()
334}
335
336mod sealed {
337 pub trait Sealed {}
338 impl<S> Sealed for axum::Router<S> {}
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::Path, routing::get};

    /// Send a `GET` to `path` and assert the response status is `200 OK`.
    async fn assert_ok(client: &TestClient, path: &str) {
        let res = client.get(path).await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Send a `GET` to `path` and assert a `200 OK` response whose body equals `body`.
    async fn assert_ok_with_body(client: &TestClient, path: &str, body: &str) {
        let res = client.get(path).await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, body);
    }

    /// Send a `GET` to `from` and assert a `308 Permanent Redirect` whose `Location` is `to`.
    async fn assert_permanent_redirect(client: &TestClient, from: &str, to: &str) {
        let res = client.get(from).await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], to);
    }

    #[tokio::test]
    async fn test_tsr() {
        let app = Router::new()
            .route_with_tsr("/foo", get(|| async {}))
            .route_with_tsr("/bar/", get(|| async {}));
        let client = TestClient::new(app);

        assert_ok(&client, "/foo").await;
        assert_permanent_redirect(&client, "/foo/", "/foo").await;

        assert_ok(&client, "/bar/").await;
        assert_permanent_redirect(&client, "/bar", "/bar/").await;
    }

    #[tokio::test]
    async fn tsr_with_params() {
        let app = Router::new()
            .route_with_tsr(
                "/a/:a",
                get(|Path(param): Path<String>| async move { param }),
            )
            .route_with_tsr(
                "/b/:b/",
                get(|Path(param): Path<String>| async move { param }),
            );
        let client = TestClient::new(app);

        assert_ok_with_body(&client, "/a/foo", "foo").await;
        assert_permanent_redirect(&client, "/a/foo/", "/a/foo").await;

        assert_ok_with_body(&client, "/b/foo/", "foo").await;
        assert_permanent_redirect(&client, "/b/foo", "/b/foo/").await;
    }

    #[tokio::test]
    async fn tsr_maintains_query_params() {
        let app = Router::new().route_with_tsr("/foo", get(|| async {}));
        let client = TestClient::new(app);

        assert_permanent_redirect(&client, "/foo/?a=a", "/foo?a=a").await;
    }

    #[test]
    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
    fn tsr_at_root() {
        let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
    }
}