axum_extra/routing/mod.rs

1//! Additional types for defining routes.
2
3use axum::{
4    extract::{OriginalUri, Request},
5    response::{IntoResponse, Redirect, Response},
6    routing::{any, MethodRouter},
7    Router,
8};
9use http::{uri::PathAndQuery, StatusCode, Uri};
10use std::{borrow::Cow, convert::Infallible};
11use tower_service::Service;
12
13mod resource;
14
15#[cfg(feature = "typed-routing")]
16mod typed;
17
18pub use self::resource::Resource;
19
20#[cfg(feature = "typed-routing")]
21pub use self::typed::WithQueryParams;
22#[cfg(feature = "typed-routing")]
23pub use axum_macros::TypedPath;
24
25#[cfg(feature = "typed-routing")]
26pub use self::typed::{SecondElementIs, TypedPath};
27
28// Validates a path at compile time, used with the vpath macro.
29#[rustversion::since(1.80)]
30#[doc(hidden)]
31pub const fn __private_validate_static_path(path: &'static str) -> &'static str {
32    if path.is_empty() {
33        panic!("Paths must start with a `/`. Use \"/\" for root routes")
34    }
35    if path.as_bytes()[0] != b'/' {
36        panic!("Paths must start with /");
37    }
38    path
39}
40
41/// This macro aborts compilation if the path is invalid.
42///
43/// This example will fail to compile:
44///
45/// ```compile_fail
46/// use axum::routing::{Router, get};
47/// use axum_extra::vpath;
48///
49/// let router = axum::Router::<()>::new()
50///     .route(vpath!("invalid_path"), get(root))
51///     .to_owned();
52///
53/// async fn root() {}
54/// ```
55///
56/// This one will compile without problems:
57///
58/// ```no_run
59/// use axum::routing::{Router, get};
60/// use axum_extra::vpath;
61///
62/// let router = axum::Router::<()>::new()
63///     .route(vpath!("/valid_path"), get(root))
64///     .to_owned();
65///
66/// async fn root() {}
67/// ```
68///
69/// This macro is available only on rust versions 1.80 and above.
70#[rustversion::since(1.80)]
71#[macro_export]
72macro_rules! vpath {
73    ($e:expr) => {
74        const { $crate::routing::__private_validate_static_path($e) }
75    };
76}
77
78/// Extension trait that adds additional methods to [`Router`].
79pub trait RouterExt<S>: sealed::Sealed {
80    /// Add a typed `GET` route to the router.
81    ///
82    /// The path will be inferred from the first argument to the handler function which must
83    /// implement [`TypedPath`].
84    ///
85    /// See [`TypedPath`] for more details and examples.
86    #[cfg(feature = "typed-routing")]
87    fn typed_get<H, T, P>(self, handler: H) -> Self
88    where
89        H: axum::handler::Handler<T, S>,
90        T: SecondElementIs<P> + 'static,
91        P: TypedPath;
92
93    /// Add a typed `DELETE` route to the router.
94    ///
95    /// The path will be inferred from the first argument to the handler function which must
96    /// implement [`TypedPath`].
97    ///
98    /// See [`TypedPath`] for more details and examples.
99    #[cfg(feature = "typed-routing")]
100    fn typed_delete<H, T, P>(self, handler: H) -> Self
101    where
102        H: axum::handler::Handler<T, S>,
103        T: SecondElementIs<P> + 'static,
104        P: TypedPath;
105
106    /// Add a typed `HEAD` route to the router.
107    ///
108    /// The path will be inferred from the first argument to the handler function which must
109    /// implement [`TypedPath`].
110    ///
111    /// See [`TypedPath`] for more details and examples.
112    #[cfg(feature = "typed-routing")]
113    fn typed_head<H, T, P>(self, handler: H) -> Self
114    where
115        H: axum::handler::Handler<T, S>,
116        T: SecondElementIs<P> + 'static,
117        P: TypedPath;
118
119    /// Add a typed `OPTIONS` route to the router.
120    ///
121    /// The path will be inferred from the first argument to the handler function which must
122    /// implement [`TypedPath`].
123    ///
124    /// See [`TypedPath`] for more details and examples.
125    #[cfg(feature = "typed-routing")]
126    fn typed_options<H, T, P>(self, handler: H) -> Self
127    where
128        H: axum::handler::Handler<T, S>,
129        T: SecondElementIs<P> + 'static,
130        P: TypedPath;
131
132    /// Add a typed `PATCH` route to the router.
133    ///
134    /// The path will be inferred from the first argument to the handler function which must
135    /// implement [`TypedPath`].
136    ///
137    /// See [`TypedPath`] for more details and examples.
138    #[cfg(feature = "typed-routing")]
139    fn typed_patch<H, T, P>(self, handler: H) -> Self
140    where
141        H: axum::handler::Handler<T, S>,
142        T: SecondElementIs<P> + 'static,
143        P: TypedPath;
144
145    /// Add a typed `POST` route to the router.
146    ///
147    /// The path will be inferred from the first argument to the handler function which must
148    /// implement [`TypedPath`].
149    ///
150    /// See [`TypedPath`] for more details and examples.
151    #[cfg(feature = "typed-routing")]
152    fn typed_post<H, T, P>(self, handler: H) -> Self
153    where
154        H: axum::handler::Handler<T, S>,
155        T: SecondElementIs<P> + 'static,
156        P: TypedPath;
157
158    /// Add a typed `PUT` route to the router.
159    ///
160    /// The path will be inferred from the first argument to the handler function which must
161    /// implement [`TypedPath`].
162    ///
163    /// See [`TypedPath`] for more details and examples.
164    #[cfg(feature = "typed-routing")]
165    fn typed_put<H, T, P>(self, handler: H) -> Self
166    where
167        H: axum::handler::Handler<T, S>,
168        T: SecondElementIs<P> + 'static,
169        P: TypedPath;
170
171    /// Add a typed `TRACE` route to the router.
172    ///
173    /// The path will be inferred from the first argument to the handler function which must
174    /// implement [`TypedPath`].
175    ///
176    /// See [`TypedPath`] for more details and examples.
177    #[cfg(feature = "typed-routing")]
178    fn typed_trace<H, T, P>(self, handler: H) -> Self
179    where
180        H: axum::handler::Handler<T, S>,
181        T: SecondElementIs<P> + 'static,
182        P: TypedPath;
183
184    /// Add a typed `CONNECT` route to the router.
185    ///
186    /// The path will be inferred from the first argument to the handler function which must
187    /// implement [`TypedPath`].
188    ///
189    /// See [`TypedPath`] for more details and examples.
190    #[cfg(feature = "typed-routing")]
191    fn typed_connect<H, T, P>(self, handler: H) -> Self
192    where
193        H: axum::handler::Handler<T, S>,
194        T: SecondElementIs<P> + 'static,
195        P: TypedPath;
196
197    /// Add another route to the router with an additional "trailing slash redirect" route.
198    ///
199    /// If you add a route _without_ a trailing slash, such as `/foo`, this method will also add a
200    /// route for `/foo/` that redirects to `/foo`.
201    ///
202    /// If you add a route _with_ a trailing slash, such as `/bar/`, this method will also add a
203    /// route for `/bar` that redirects to `/bar/`.
204    ///
205    /// This is similar to what axum 0.5.x did by default, except this explicitly adds another
206    /// route, so trying to add a `/foo/` route after calling `.route_with_tsr("/foo", /* ... */)`
207    /// will result in a panic due to route overlap.
208    ///
209    /// # Example
210    ///
211    /// ```
212    /// use axum::{Router, routing::get};
213    /// use axum_extra::routing::RouterExt;
214    ///
215    /// let app = Router::new()
216    ///     // `/foo/` will redirect to `/foo`
217    ///     .route_with_tsr("/foo", get(|| async {}))
218    ///     // `/bar` will redirect to `/bar/`
219    ///     .route_with_tsr("/bar/", get(|| async {}));
220    /// # let _: Router = app;
221    /// ```
222    fn route_with_tsr(self, path: &str, method_router: MethodRouter<S>) -> Self
223    where
224        Self: Sized;
225
226    /// Add another route to the router with an additional "trailing slash redirect" route.
227    ///
228    /// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`].
229    fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
230    where
231        T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
232        T::Response: IntoResponse,
233        T::Future: Send + 'static,
234        Self: Sized;
235}
236
237impl<S> RouterExt<S> for Router<S>
238where
239    S: Clone + Send + Sync + 'static,
240{
241    #[cfg(feature = "typed-routing")]
242    fn typed_get<H, T, P>(self, handler: H) -> Self
243    where
244        H: axum::handler::Handler<T, S>,
245        T: SecondElementIs<P> + 'static,
246        P: TypedPath,
247    {
248        self.route(P::PATH, axum::routing::get(handler))
249    }
250
251    #[cfg(feature = "typed-routing")]
252    fn typed_delete<H, T, P>(self, handler: H) -> Self
253    where
254        H: axum::handler::Handler<T, S>,
255        T: SecondElementIs<P> + 'static,
256        P: TypedPath,
257    {
258        self.route(P::PATH, axum::routing::delete(handler))
259    }
260
261    #[cfg(feature = "typed-routing")]
262    fn typed_head<H, T, P>(self, handler: H) -> Self
263    where
264        H: axum::handler::Handler<T, S>,
265        T: SecondElementIs<P> + 'static,
266        P: TypedPath,
267    {
268        self.route(P::PATH, axum::routing::head(handler))
269    }
270
271    #[cfg(feature = "typed-routing")]
272    fn typed_options<H, T, P>(self, handler: H) -> Self
273    where
274        H: axum::handler::Handler<T, S>,
275        T: SecondElementIs<P> + 'static,
276        P: TypedPath,
277    {
278        self.route(P::PATH, axum::routing::options(handler))
279    }
280
281    #[cfg(feature = "typed-routing")]
282    fn typed_patch<H, T, P>(self, handler: H) -> Self
283    where
284        H: axum::handler::Handler<T, S>,
285        T: SecondElementIs<P> + 'static,
286        P: TypedPath,
287    {
288        self.route(P::PATH, axum::routing::patch(handler))
289    }
290
291    #[cfg(feature = "typed-routing")]
292    fn typed_post<H, T, P>(self, handler: H) -> Self
293    where
294        H: axum::handler::Handler<T, S>,
295        T: SecondElementIs<P> + 'static,
296        P: TypedPath,
297    {
298        self.route(P::PATH, axum::routing::post(handler))
299    }
300
301    #[cfg(feature = "typed-routing")]
302    fn typed_put<H, T, P>(self, handler: H) -> Self
303    where
304        H: axum::handler::Handler<T, S>,
305        T: SecondElementIs<P> + 'static,
306        P: TypedPath,
307    {
308        self.route(P::PATH, axum::routing::put(handler))
309    }
310
311    #[cfg(feature = "typed-routing")]
312    fn typed_trace<H, T, P>(self, handler: H) -> Self
313    where
314        H: axum::handler::Handler<T, S>,
315        T: SecondElementIs<P> + 'static,
316        P: TypedPath,
317    {
318        self.route(P::PATH, axum::routing::trace(handler))
319    }
320
321    #[cfg(feature = "typed-routing")]
322    fn typed_connect<H, T, P>(self, handler: H) -> Self
323    where
324        H: axum::handler::Handler<T, S>,
325        T: SecondElementIs<P> + 'static,
326        P: TypedPath,
327    {
328        self.route(P::PATH, axum::routing::connect(handler))
329    }
330
331    #[track_caller]
332    fn route_with_tsr(mut self, path: &str, method_router: MethodRouter<S>) -> Self
333    where
334        Self: Sized,
335    {
336        validate_tsr_path(path);
337        self = self.route(path, method_router);
338        add_tsr_redirect_route(self, path)
339    }
340
341    #[track_caller]
342    fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
343    where
344        T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
345        T::Response: IntoResponse,
346        T::Future: Send + 'static,
347        Self: Sized,
348    {
349        validate_tsr_path(path);
350        self = self.route_service(path, service);
351        add_tsr_redirect_route(self, path)
352    }
353}
354
/// Guards the trailing-slash-redirect helpers against the root path.
///
/// # Panics
///
/// Panics when `path` is exactly `/`, since toggling its trailing slash would not yield a
/// distinct route. `#[track_caller]` attributes the panic to the caller.
#[track_caller]
fn validate_tsr_path(path: &str) {
    assert!(
        path != "/",
        "Cannot add a trailing slash redirect route for `/`"
    );
}
361
362fn add_tsr_redirect_route<S>(router: Router<S>, path: &str) -> Router<S>
363where
364    S: Clone + Send + Sync + 'static,
365{
366    async fn redirect_handler(OriginalUri(uri): OriginalUri) -> Response {
367        let new_uri = map_path(uri, |path| {
368            path.strip_suffix('/')
369                .map(Cow::Borrowed)
370                .unwrap_or_else(|| Cow::Owned(format!("{path}/")))
371        });
372
373        if let Some(new_uri) = new_uri {
374            Redirect::permanent(&new_uri.to_string()).into_response()
375        } else {
376            StatusCode::BAD_REQUEST.into_response()
377        }
378    }
379
380    if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
381        router.route(path_without_trailing_slash, any(redirect_handler))
382    } else {
383        router.route(&format!("{path}/"), any(redirect_handler))
384    }
385}
386
387/// Map the path of a `Uri`.
388///
389/// Returns `None` if the `Uri` cannot be put back together with the new path.
390fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
391where
392    F: FnOnce(&str) -> Cow<'_, str>,
393{
394    let mut parts = original_uri.into_parts();
395    let path_and_query = parts.path_and_query.as_ref()?;
396
397    let new_path = f(path_and_query.path());
398
399    let new_path_and_query = if let Some(query) = &path_and_query.query() {
400        format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
401    } else {
402        new_path.parse::<PathAndQuery>().ok()?
403    };
404    parts.path_and_query = Some(new_path_and_query);
405
406    Uri::from_parts(parts).ok()
407}
408
409mod sealed {
410    pub trait Sealed {}
411    impl<S> Sealed for axum::Router<S> {}
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::Path, routing::get};

    /// Canonical paths are served; their trailing-slash twins redirect to them.
    #[tokio::test]
    async fn test_tsr() {
        let client = TestClient::new(
            Router::new()
                .route_with_tsr("/foo", get(|| async {}))
                .route_with_tsr("/bar/", get(|| async {})),
        );

        for uri in ["/foo", "/bar/"] {
            let res = client.get(uri).await;
            assert_eq!(res.status(), StatusCode::OK);
        }

        for (uri, location) in [("/foo/", "/foo"), ("/bar", "/bar/")] {
            let res = client.get(uri).await;
            assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
            assert_eq!(res.headers()["location"], location);
        }
    }

    /// Path parameters survive both the direct route and the redirect.
    #[tokio::test]
    async fn tsr_with_params() {
        let client = TestClient::new(
            Router::new()
                .route_with_tsr(
                    "/a/{a}",
                    get(|Path(param): Path<String>| async move { param }),
                )
                .route_with_tsr(
                    "/b/{b}/",
                    get(|Path(param): Path<String>| async move { param }),
                ),
        );

        for uri in ["/a/foo", "/b/foo/"] {
            let res = client.get(uri).await;
            assert_eq!(res.status(), StatusCode::OK);
            assert_eq!(res.text().await, "foo");
        }

        for (uri, location) in [("/a/foo/", "/a/foo"), ("/b/foo", "/b/foo/")] {
            let res = client.get(uri).await;
            assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
            assert_eq!(res.headers()["location"], location);
        }
    }

    /// The query string is carried over into the redirect location.
    #[tokio::test]
    async fn tsr_maintains_query_params() {
        let client = TestClient::new(Router::new().route_with_tsr("/foo", get(|| async {})));

        let res = client.get("/foo/?a=a").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/foo?a=a");
    }

    /// Redirects inside a nested router point at the full, prefixed path.
    #[tokio::test]
    async fn tsr_works_in_nested_router() {
        let inner = Router::new().route_with_tsr("/nyan/", get(|| async {}));
        let client = TestClient::new(Router::new().nest("/neko", inner));

        let ok = client.get("/neko/nyan/").await;
        assert_eq!(ok.status(), StatusCode::OK);

        let redirected = client.get("/neko/nyan").await;
        assert_eq!(redirected.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected.headers()["location"], "/neko/nyan/");
    }

    /// The root route has no trailing-slash twin and must be rejected.
    #[test]
    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
    fn tsr_at_root() {
        let _ = Router::<()>::new().route_with_tsr("/", get(|| async {}));
    }
}