axum_extra/routing/
mod.rs

1//! Additional types for defining routes.
2
3use axum::{
4    extract::Request,
5    response::{IntoResponse, Redirect, Response},
6    routing::{any, MethodRouter},
7    Router,
8};
9use http::{uri::PathAndQuery, StatusCode, Uri};
10use std::{borrow::Cow, convert::Infallible};
11use tower_service::Service;
12
13mod resource;
14
15#[cfg(feature = "typed-routing")]
16mod typed;
17
18pub use self::resource::Resource;
19
20#[cfg(feature = "typed-routing")]
21pub use self::typed::WithQueryParams;
22#[cfg(feature = "typed-routing")]
23pub use axum_macros::TypedPath;
24
25#[cfg(feature = "typed-routing")]
26pub use self::typed::{SecondElementIs, TypedPath};
27
28/// Extension trait that adds additional methods to [`Router`].
29pub trait RouterExt<S>: sealed::Sealed {
30    /// Add a typed `GET` route to the router.
31    ///
32    /// The path will be inferred from the first argument to the handler function which must
33    /// implement [`TypedPath`].
34    ///
35    /// See [`TypedPath`] for more details and examples.
36    #[cfg(feature = "typed-routing")]
37    fn typed_get<H, T, P>(self, handler: H) -> Self
38    where
39        H: axum::handler::Handler<T, S>,
40        T: SecondElementIs<P> + 'static,
41        P: TypedPath;
42
43    /// Add a typed `DELETE` route to the router.
44    ///
45    /// The path will be inferred from the first argument to the handler function which must
46    /// implement [`TypedPath`].
47    ///
48    /// See [`TypedPath`] for more details and examples.
49    #[cfg(feature = "typed-routing")]
50    fn typed_delete<H, T, P>(self, handler: H) -> Self
51    where
52        H: axum::handler::Handler<T, S>,
53        T: SecondElementIs<P> + 'static,
54        P: TypedPath;
55
56    /// Add a typed `HEAD` route to the router.
57    ///
58    /// The path will be inferred from the first argument to the handler function which must
59    /// implement [`TypedPath`].
60    ///
61    /// See [`TypedPath`] for more details and examples.
62    #[cfg(feature = "typed-routing")]
63    fn typed_head<H, T, P>(self, handler: H) -> Self
64    where
65        H: axum::handler::Handler<T, S>,
66        T: SecondElementIs<P> + 'static,
67        P: TypedPath;
68
69    /// Add a typed `OPTIONS` route to the router.
70    ///
71    /// The path will be inferred from the first argument to the handler function which must
72    /// implement [`TypedPath`].
73    ///
74    /// See [`TypedPath`] for more details and examples.
75    #[cfg(feature = "typed-routing")]
76    fn typed_options<H, T, P>(self, handler: H) -> Self
77    where
78        H: axum::handler::Handler<T, S>,
79        T: SecondElementIs<P> + 'static,
80        P: TypedPath;
81
82    /// Add a typed `PATCH` route to the router.
83    ///
84    /// The path will be inferred from the first argument to the handler function which must
85    /// implement [`TypedPath`].
86    ///
87    /// See [`TypedPath`] for more details and examples.
88    #[cfg(feature = "typed-routing")]
89    fn typed_patch<H, T, P>(self, handler: H) -> Self
90    where
91        H: axum::handler::Handler<T, S>,
92        T: SecondElementIs<P> + 'static,
93        P: TypedPath;
94
95    /// Add a typed `POST` route to the router.
96    ///
97    /// The path will be inferred from the first argument to the handler function which must
98    /// implement [`TypedPath`].
99    ///
100    /// See [`TypedPath`] for more details and examples.
101    #[cfg(feature = "typed-routing")]
102    fn typed_post<H, T, P>(self, handler: H) -> Self
103    where
104        H: axum::handler::Handler<T, S>,
105        T: SecondElementIs<P> + 'static,
106        P: TypedPath;
107
108    /// Add a typed `PUT` route to the router.
109    ///
110    /// The path will be inferred from the first argument to the handler function which must
111    /// implement [`TypedPath`].
112    ///
113    /// See [`TypedPath`] for more details and examples.
114    #[cfg(feature = "typed-routing")]
115    fn typed_put<H, T, P>(self, handler: H) -> Self
116    where
117        H: axum::handler::Handler<T, S>,
118        T: SecondElementIs<P> + 'static,
119        P: TypedPath;
120
121    /// Add a typed `TRACE` route to the router.
122    ///
123    /// The path will be inferred from the first argument to the handler function which must
124    /// implement [`TypedPath`].
125    ///
126    /// See [`TypedPath`] for more details and examples.
127    #[cfg(feature = "typed-routing")]
128    fn typed_trace<H, T, P>(self, handler: H) -> Self
129    where
130        H: axum::handler::Handler<T, S>,
131        T: SecondElementIs<P> + 'static,
132        P: TypedPath;
133
134    /// Add another route to the router with an additional "trailing slash redirect" route.
135    ///
136    /// If you add a route _without_ a trailing slash, such as `/foo`, this method will also add a
137    /// route for `/foo/` that redirects to `/foo`.
138    ///
139    /// If you add a route _with_ a trailing slash, such as `/bar/`, this method will also add a
140    /// route for `/bar` that redirects to `/bar/`.
141    ///
142    /// This is similar to what axum 0.5.x did by default, except this explicitly adds another
143    /// route, so trying to add a `/foo/` route after calling `.route_with_tsr("/foo", /* ... */)`
144    /// will result in a panic due to route overlap.
145    ///
146    /// # Example
147    ///
148    /// ```
149    /// use axum::{Router, routing::get};
150    /// use axum_extra::routing::RouterExt;
151    ///
152    /// let app = Router::new()
153    ///     // `/foo/` will redirect to `/foo`
154    ///     .route_with_tsr("/foo", get(|| async {}))
155    ///     // `/bar` will redirect to `/bar/`
156    ///     .route_with_tsr("/bar/", get(|| async {}));
157    /// # let _: Router = app;
158    /// ```
159    fn route_with_tsr(self, path: &str, method_router: MethodRouter<S>) -> Self
160    where
161        Self: Sized;
162
163    /// Add another route to the router with an additional "trailing slash redirect" route.
164    ///
165    /// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`].
166    fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
167    where
168        T: Service<Request, Error = Infallible> + Clone + Send + 'static,
169        T::Response: IntoResponse,
170        T::Future: Send + 'static,
171        Self: Sized;
172}
173
174impl<S> RouterExt<S> for Router<S>
175where
176    S: Clone + Send + Sync + 'static,
177{
178    #[cfg(feature = "typed-routing")]
179    fn typed_get<H, T, P>(self, handler: H) -> Self
180    where
181        H: axum::handler::Handler<T, S>,
182        T: SecondElementIs<P> + 'static,
183        P: TypedPath,
184    {
185        self.route(P::PATH, axum::routing::get(handler))
186    }
187
188    #[cfg(feature = "typed-routing")]
189    fn typed_delete<H, T, P>(self, handler: H) -> Self
190    where
191        H: axum::handler::Handler<T, S>,
192        T: SecondElementIs<P> + 'static,
193        P: TypedPath,
194    {
195        self.route(P::PATH, axum::routing::delete(handler))
196    }
197
198    #[cfg(feature = "typed-routing")]
199    fn typed_head<H, T, P>(self, handler: H) -> Self
200    where
201        H: axum::handler::Handler<T, S>,
202        T: SecondElementIs<P> + 'static,
203        P: TypedPath,
204    {
205        self.route(P::PATH, axum::routing::head(handler))
206    }
207
208    #[cfg(feature = "typed-routing")]
209    fn typed_options<H, T, P>(self, handler: H) -> Self
210    where
211        H: axum::handler::Handler<T, S>,
212        T: SecondElementIs<P> + 'static,
213        P: TypedPath,
214    {
215        self.route(P::PATH, axum::routing::options(handler))
216    }
217
218    #[cfg(feature = "typed-routing")]
219    fn typed_patch<H, T, P>(self, handler: H) -> Self
220    where
221        H: axum::handler::Handler<T, S>,
222        T: SecondElementIs<P> + 'static,
223        P: TypedPath,
224    {
225        self.route(P::PATH, axum::routing::patch(handler))
226    }
227
228    #[cfg(feature = "typed-routing")]
229    fn typed_post<H, T, P>(self, handler: H) -> Self
230    where
231        H: axum::handler::Handler<T, S>,
232        T: SecondElementIs<P> + 'static,
233        P: TypedPath,
234    {
235        self.route(P::PATH, axum::routing::post(handler))
236    }
237
238    #[cfg(feature = "typed-routing")]
239    fn typed_put<H, T, P>(self, handler: H) -> Self
240    where
241        H: axum::handler::Handler<T, S>,
242        T: SecondElementIs<P> + 'static,
243        P: TypedPath,
244    {
245        self.route(P::PATH, axum::routing::put(handler))
246    }
247
248    #[cfg(feature = "typed-routing")]
249    fn typed_trace<H, T, P>(self, handler: H) -> Self
250    where
251        H: axum::handler::Handler<T, S>,
252        T: SecondElementIs<P> + 'static,
253        P: TypedPath,
254    {
255        self.route(P::PATH, axum::routing::trace(handler))
256    }
257
258    #[track_caller]
259    fn route_with_tsr(mut self, path: &str, method_router: MethodRouter<S>) -> Self
260    where
261        Self: Sized,
262    {
263        validate_tsr_path(path);
264        self = self.route(path, method_router);
265        add_tsr_redirect_route(self, path)
266    }
267
268    #[track_caller]
269    fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
270    where
271        T: Service<Request, Error = Infallible> + Clone + Send + 'static,
272        T::Response: IntoResponse,
273        T::Future: Send + 'static,
274        Self: Sized,
275    {
276        validate_tsr_path(path);
277        self = self.route_service(path, service);
278        add_tsr_redirect_route(self, path)
279    }
280}
281
/// Guard used by the `*_with_tsr` methods before they register anything.
///
/// # Panics
///
/// Panics when `path` is exactly `/`. Every other input is accepted.
#[track_caller]
fn validate_tsr_path(path: &str) {
    assert!(
        path != "/",
        "Cannot add a trailing slash redirect route for `/`"
    );
}
288
289fn add_tsr_redirect_route<S>(router: Router<S>, path: &str) -> Router<S>
290where
291    S: Clone + Send + Sync + 'static,
292{
293    async fn redirect_handler(uri: Uri) -> Response {
294        let new_uri = map_path(uri, |path| {
295            path.strip_suffix('/')
296                .map(Cow::Borrowed)
297                .unwrap_or_else(|| Cow::Owned(format!("{path}/")))
298        });
299
300        if let Some(new_uri) = new_uri {
301            Redirect::permanent(&new_uri.to_string()).into_response()
302        } else {
303            StatusCode::BAD_REQUEST.into_response()
304        }
305    }
306
307    if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
308        router.route(path_without_trailing_slash, any(redirect_handler))
309    } else {
310        router.route(&format!("{path}/"), any(redirect_handler))
311    }
312}
313
314/// Map the path of a `Uri`.
315///
316/// Returns `None` if the `Uri` cannot be put back together with the new path.
317fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
318where
319    F: FnOnce(&str) -> Cow<'_, str>,
320{
321    let mut parts = original_uri.into_parts();
322    let path_and_query = parts.path_and_query.as_ref()?;
323
324    let new_path = f(path_and_query.path());
325
326    let new_path_and_query = if let Some(query) = &path_and_query.query() {
327        format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
328    } else {
329        new_path.parse::<PathAndQuery>().ok()?
330    };
331    parts.path_and_query = Some(new_path_and_query);
332
333    Uri::from_parts(parts).ok()
334}
335
336mod sealed {
337    pub trait Sealed {}
338    impl<S> Sealed for axum::Router<S> {}
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::Path, routing::get};

    /// Static paths: the registered path answers directly, its slash-flipped twin redirects.
    #[tokio::test]
    async fn test_tsr() {
        let router = Router::new()
            .route_with_tsr("/foo", get(|| async {}))
            .route_with_tsr("/bar/", get(|| async {}));
        let client = TestClient::new(router);

        let direct = client.get("/foo").await;
        assert_eq!(direct.status(), StatusCode::OK);

        let redirected = client.get("/foo/").await;
        assert_eq!(redirected.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected.headers()["location"], "/foo");

        let direct = client.get("/bar/").await;
        assert_eq!(direct.status(), StatusCode::OK);

        let redirected = client.get("/bar").await;
        assert_eq!(redirected.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected.headers()["location"], "/bar/");
    }

    /// Paths with a parameter: the redirect target carries the concrete segment.
    #[tokio::test]
    async fn tsr_with_params() {
        /// Echoes the captured path parameter back as the response body.
        async fn echo(Path(param): Path<String>) -> String {
            param
        }

        let router = Router::new()
            .route_with_tsr("/a/:a", get(echo))
            .route_with_tsr("/b/:b/", get(echo));
        let client = TestClient::new(router);

        let direct = client.get("/a/foo").await;
        assert_eq!(direct.status(), StatusCode::OK);
        assert_eq!(direct.text().await, "foo");

        let redirected = client.get("/a/foo/").await;
        assert_eq!(redirected.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected.headers()["location"], "/a/foo");

        let direct = client.get("/b/foo/").await;
        assert_eq!(direct.status(), StatusCode::OK);
        assert_eq!(direct.text().await, "foo");

        let redirected = client.get("/b/foo").await;
        assert_eq!(redirected.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected.headers()["location"], "/b/foo/");
    }

    /// The query string survives the redirect unchanged.
    #[tokio::test]
    async fn tsr_maintains_query_params() {
        let router = Router::new().route_with_tsr("/foo", get(|| async {}));
        let client = TestClient::new(router);

        let redirected = client.get("/foo/?a=a").await;
        assert_eq!(redirected.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected.headers()["location"], "/foo?a=a");
    }

    /// Registering a trailing slash redirect for `/` is rejected with a panic.
    #[test]
    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
    fn tsr_at_root() {
        let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
    }
}
417}