1use axum::{
4 extract::{OriginalUri, Request},
5 response::{IntoResponse, Redirect, Response},
6 routing::{any, MethodRouter},
7 Router,
8};
9use http::{uri::PathAndQuery, StatusCode, Uri};
10use std::{borrow::Cow, convert::Infallible};
11use tower_service::Service;
12
13mod resource;
14
15#[cfg(feature = "typed-routing")]
16mod typed;
17
18pub use self::resource::Resource;
19
20#[cfg(feature = "typed-routing")]
21pub use self::typed::WithQueryParams;
22#[cfg(feature = "typed-routing")]
23pub use axum_macros::TypedPath;
24
25#[cfg(feature = "typed-routing")]
26pub use self::typed::{SecondElementIs, TypedPath};
27
28#[rustversion::since(1.80)]
30#[doc(hidden)]
31pub const fn __private_validate_static_path(path: &'static str) -> &'static str {
32 if path.is_empty() {
33 panic!("Paths must start with a `/`. Use \"/\" for root routes")
34 }
35 if path.as_bytes()[0] != b'/' {
36 panic!("Paths must start with /");
37 }
38 path
39}
40
41#[rustversion::since(1.80)]
71#[macro_export]
72macro_rules! vpath {
73 ($e:expr) => {
74 const { $crate::routing::__private_validate_static_path($e) }
75 };
76}
77
/// Extension trait that adds additional methods to [`Router`].
///
/// The `Sealed` supertrait lives in a private module, so this trait can only be
/// implemented inside this crate (it is implemented for `Router<S>`).
pub trait RouterExt<S>: sealed::Sealed {
    /// Adds a route at `P::PATH` that runs `handler` for `GET` requests.
    ///
    /// `P` is the [`TypedPath`] tied to the handler's argument list through the
    /// [`SecondElementIs`] bound on `T`.
    #[cfg(feature = "typed-routing")]
    fn typed_get<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Adds a route at `P::PATH` that runs `handler` for `DELETE` requests.
    ///
    /// `P` is the [`TypedPath`] tied to the handler's argument list through the
    /// [`SecondElementIs`] bound on `T`.
    #[cfg(feature = "typed-routing")]
    fn typed_delete<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Adds a route at `P::PATH` that runs `handler` for `HEAD` requests.
    ///
    /// `P` is the [`TypedPath`] tied to the handler's argument list through the
    /// [`SecondElementIs`] bound on `T`.
    #[cfg(feature = "typed-routing")]
    fn typed_head<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Adds a route at `P::PATH` that runs `handler` for `OPTIONS` requests.
    ///
    /// `P` is the [`TypedPath`] tied to the handler's argument list through the
    /// [`SecondElementIs`] bound on `T`.
    #[cfg(feature = "typed-routing")]
    fn typed_options<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Adds a route at `P::PATH` that runs `handler` for `PATCH` requests.
    ///
    /// `P` is the [`TypedPath`] tied to the handler's argument list through the
    /// [`SecondElementIs`] bound on `T`.
    #[cfg(feature = "typed-routing")]
    fn typed_patch<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Adds a route at `P::PATH` that runs `handler` for `POST` requests.
    ///
    /// `P` is the [`TypedPath`] tied to the handler's argument list through the
    /// [`SecondElementIs`] bound on `T`.
    #[cfg(feature = "typed-routing")]
    fn typed_post<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Adds a route at `P::PATH` that runs `handler` for `PUT` requests.
    ///
    /// `P` is the [`TypedPath`] tied to the handler's argument list through the
    /// [`SecondElementIs`] bound on `T`.
    #[cfg(feature = "typed-routing")]
    fn typed_put<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Adds a route at `P::PATH` that runs `handler` for `TRACE` requests.
    ///
    /// `P` is the [`TypedPath`] tied to the handler's argument list through the
    /// [`SecondElementIs`] bound on `T`.
    #[cfg(feature = "typed-routing")]
    fn typed_trace<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Adds a route at `P::PATH` that runs `handler` for `CONNECT` requests.
    ///
    /// `P` is the [`TypedPath`] tied to the handler's argument list through the
    /// [`SecondElementIs`] bound on `T`.
    #[cfg(feature = "typed-routing")]
    fn typed_connect<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Adds `method_router` at `path`, plus a companion "trailing slash redirect" route.
    ///
    /// If `path` ends with `/`, requests to the same path without the slash are
    /// redirected to it. Otherwise, requests with a trailing slash are redirected to
    /// the slash-less form. The redirect is a `308 Permanent Redirect` to the original
    /// request URI with its trailing slash toggled; any query string is preserved.
    ///
    /// # Panics
    ///
    /// Panics if `path` is `/`. Presumably it also panics wherever [`Router::route`]
    /// rejects the path; confirm against axum's documentation.
    fn route_with_tsr(self, path: &str, method_router: MethodRouter<S>) -> Self
    where
        Self: Sized;

    /// Like [`RouterExt::route_with_tsr`], but registers a [`Service`] at `path`.
    ///
    /// The service is registered with [`Router::route_service`]. The companion
    /// trailing-slash redirect route is added exactly as for `route_with_tsr`.
    ///
    /// # Panics
    ///
    /// Panics if `path` is `/`. Presumably it also panics wherever
    /// [`Router::route_service`] rejects its arguments; confirm against axum's
    /// documentation.
    fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
    where
        T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
        T::Response: IntoResponse,
        T::Future: Send + 'static,
        Self: Sized;
}
236
237impl<S> RouterExt<S> for Router<S>
238where
239 S: Clone + Send + Sync + 'static,
240{
241 #[cfg(feature = "typed-routing")]
242 fn typed_get<H, T, P>(self, handler: H) -> Self
243 where
244 H: axum::handler::Handler<T, S>,
245 T: SecondElementIs<P> + 'static,
246 P: TypedPath,
247 {
248 self.route(P::PATH, axum::routing::get(handler))
249 }
250
251 #[cfg(feature = "typed-routing")]
252 fn typed_delete<H, T, P>(self, handler: H) -> Self
253 where
254 H: axum::handler::Handler<T, S>,
255 T: SecondElementIs<P> + 'static,
256 P: TypedPath,
257 {
258 self.route(P::PATH, axum::routing::delete(handler))
259 }
260
261 #[cfg(feature = "typed-routing")]
262 fn typed_head<H, T, P>(self, handler: H) -> Self
263 where
264 H: axum::handler::Handler<T, S>,
265 T: SecondElementIs<P> + 'static,
266 P: TypedPath,
267 {
268 self.route(P::PATH, axum::routing::head(handler))
269 }
270
271 #[cfg(feature = "typed-routing")]
272 fn typed_options<H, T, P>(self, handler: H) -> Self
273 where
274 H: axum::handler::Handler<T, S>,
275 T: SecondElementIs<P> + 'static,
276 P: TypedPath,
277 {
278 self.route(P::PATH, axum::routing::options(handler))
279 }
280
281 #[cfg(feature = "typed-routing")]
282 fn typed_patch<H, T, P>(self, handler: H) -> Self
283 where
284 H: axum::handler::Handler<T, S>,
285 T: SecondElementIs<P> + 'static,
286 P: TypedPath,
287 {
288 self.route(P::PATH, axum::routing::patch(handler))
289 }
290
291 #[cfg(feature = "typed-routing")]
292 fn typed_post<H, T, P>(self, handler: H) -> Self
293 where
294 H: axum::handler::Handler<T, S>,
295 T: SecondElementIs<P> + 'static,
296 P: TypedPath,
297 {
298 self.route(P::PATH, axum::routing::post(handler))
299 }
300
301 #[cfg(feature = "typed-routing")]
302 fn typed_put<H, T, P>(self, handler: H) -> Self
303 where
304 H: axum::handler::Handler<T, S>,
305 T: SecondElementIs<P> + 'static,
306 P: TypedPath,
307 {
308 self.route(P::PATH, axum::routing::put(handler))
309 }
310
311 #[cfg(feature = "typed-routing")]
312 fn typed_trace<H, T, P>(self, handler: H) -> Self
313 where
314 H: axum::handler::Handler<T, S>,
315 T: SecondElementIs<P> + 'static,
316 P: TypedPath,
317 {
318 self.route(P::PATH, axum::routing::trace(handler))
319 }
320
321 #[cfg(feature = "typed-routing")]
322 fn typed_connect<H, T, P>(self, handler: H) -> Self
323 where
324 H: axum::handler::Handler<T, S>,
325 T: SecondElementIs<P> + 'static,
326 P: TypedPath,
327 {
328 self.route(P::PATH, axum::routing::connect(handler))
329 }
330
331 #[track_caller]
332 fn route_with_tsr(mut self, path: &str, method_router: MethodRouter<S>) -> Self
333 where
334 Self: Sized,
335 {
336 validate_tsr_path(path);
337 self = self.route(path, method_router);
338 add_tsr_redirect_route(self, path)
339 }
340
341 #[track_caller]
342 fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
343 where
344 T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
345 T::Response: IntoResponse,
346 T::Future: Send + 'static,
347 Self: Sized,
348 {
349 validate_tsr_path(path);
350 self = self.route_service(path, service);
351 add_tsr_redirect_route(self, path)
352 }
353}
354
/// Rejects paths for which a trailing-slash redirect route cannot be added.
///
/// # Panics
///
/// Panics if `path` is exactly `/`, because toggling its trailing slash would
/// produce an empty path. `#[track_caller]` attributes the panic to the caller.
#[track_caller]
fn validate_tsr_path(path: &str) {
    assert!(
        path != "/",
        "Cannot add a trailing slash redirect route for `/`"
    );
}
361
362fn add_tsr_redirect_route<S>(router: Router<S>, path: &str) -> Router<S>
363where
364 S: Clone + Send + Sync + 'static,
365{
366 async fn redirect_handler(OriginalUri(uri): OriginalUri) -> Response {
367 let new_uri = map_path(uri, |path| {
368 path.strip_suffix('/')
369 .map(Cow::Borrowed)
370 .unwrap_or_else(|| Cow::Owned(format!("{path}/")))
371 });
372
373 if let Some(new_uri) = new_uri {
374 Redirect::permanent(&new_uri.to_string()).into_response()
375 } else {
376 StatusCode::BAD_REQUEST.into_response()
377 }
378 }
379
380 if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
381 router.route(path_without_trailing_slash, any(redirect_handler))
382 } else {
383 router.route(&format!("{path}/"), any(redirect_handler))
384 }
385}
386
387fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
391where
392 F: FnOnce(&str) -> Cow<'_, str>,
393{
394 let mut parts = original_uri.into_parts();
395 let path_and_query = parts.path_and_query.as_ref()?;
396
397 let new_path = f(path_and_query.path());
398
399 let new_path_and_query = if let Some(query) = &path_and_query.query() {
400 format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
401 } else {
402 new_path.parse::<PathAndQuery>().ok()?
403 };
404 parts.path_and_query = Some(new_path_and_query);
405
406 Uri::from_parts(parts).ok()
407}
408
409mod sealed {
410 pub trait Sealed {}
411 impl<S> Sealed for axum::Router<S> {}
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::Path, routing::get};

    #[tokio::test]
    async fn test_tsr() {
        let router = Router::new()
            .route_with_tsr("/foo", get(|| async {}))
            .route_with_tsr("/bar/", get(|| async {}));
        let client = TestClient::new(router);

        // `/foo` was registered without a slash, so `/foo/` redirects to it.
        let response = client.get("/foo").await;
        assert_eq!(response.status(), StatusCode::OK);
        let response = client.get("/foo/").await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()["location"], "/foo");

        // `/bar/` was registered with a slash, so `/bar` redirects to it.
        let response = client.get("/bar/").await;
        assert_eq!(response.status(), StatusCode::OK);
        let response = client.get("/bar").await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()["location"], "/bar/");
    }

    #[tokio::test]
    async fn tsr_with_params() {
        let router = Router::new()
            .route_with_tsr(
                "/a/{a}",
                get(|Path(param): Path<String>| async move { param }),
            )
            .route_with_tsr(
                "/b/{b}/",
                get(|Path(param): Path<String>| async move { param }),
            );
        let client = TestClient::new(router);

        // Captured segments are echoed back and carried into the redirect target.
        let response = client.get("/a/foo").await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await, "foo");
        let response = client.get("/a/foo/").await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()["location"], "/a/foo");

        let response = client.get("/b/foo/").await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await, "foo");
        let response = client.get("/b/foo").await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()["location"], "/b/foo/");
    }

    #[tokio::test]
    async fn tsr_maintains_query_params() {
        let router = Router::new().route_with_tsr("/foo", get(|| async {}));
        let client = TestClient::new(router);

        // The query string must survive the path rewrite.
        let response = client.get("/foo/?a=a").await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()["location"], "/foo?a=a");
    }

    #[tokio::test]
    async fn tsr_works_in_nested_router() {
        let inner = Router::new().route_with_tsr("/nyan/", get(|| async {}));
        let router = Router::new().nest("/neko", inner);
        let client = TestClient::new(router);

        let response = client.get("/neko/nyan/").await;
        assert_eq!(response.status(), StatusCode::OK);

        // The redirect target keeps the outer `/neko` prefix.
        let response = client.get("/neko/nyan").await;
        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(response.headers()["location"], "/neko/nyan/");
    }

    #[test]
    #[should_panic(expected = "Cannot add a trailing slash redirect route for `/`")]
    fn tsr_at_root() {
        let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
    }
}