axum/routing/path_router.rs

1use crate::{
2    extract::{nested_path::SetNestedPath, Request},
3    handler::Handler,
4};
5use axum_core::response::IntoResponse;
6use matchit::MatchError;
7use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
8use tower_layer::Layer;
9use tower_service::Service;
10
11use super::{
12    future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
13    MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
14};
15
16pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
17    routes: HashMap<RouteId, Endpoint<S>>,
18    node: Arc<Node>,
19    prev_route_id: RouteId,
20    v7_checks: bool,
21}
22
23impl<S> PathRouter<S, true>
24where
25    S: Clone + Send + Sync + 'static,
26{
27    pub(super) fn new_fallback() -> Self {
28        let mut this = Self::default();
29        this.set_fallback(Endpoint::Route(Route::new(NotFound)));
30        this
31    }
32
33    pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
34        self.replace_endpoint("/", endpoint.clone());
35        self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
36    }
37}
38
/// Checks that `path` can be used as a route path.
///
/// Every path must begin with `/`. When `v7_checks` is enabled the path is additionally
/// checked by `validate_v07_paths`.
///
/// # Errors
///
/// Returns a static, user-facing message describing the first problem found.
fn validate_path(v7_checks: bool, path: &str) -> Result<(), &'static str> {
    if !path.starts_with('/') {
        // An empty path gets a more helpful hint than a path that merely lacks the slash.
        let message = if path.is_empty() {
            "Paths must start with a `/`. Use \"/\" for root routes"
        } else {
            "Paths must start with a `/`"
        };
        return Err(message);
    }

    if !v7_checks {
        return Ok(());
    }
    validate_v07_paths(path)
}

/// Rejects path segments written in the axum 0.7 capture syntax.
///
/// A segment beginning with `:` (old named capture) or `*` (old wildcard) produces an error
/// pointing at the `{capture}` / `{*wildcard}` syntax. Only the first offending segment is
/// reported.
fn validate_v07_paths(path: &str) -> Result<(), &'static str> {
    for segment in path.split('/') {
        if segment.starts_with(':') {
            return Err(
                "Path segments must not start with `:`. For capture groups, use \
                `{capture}`. If you meant to literally match a segment starting with \
                a colon, call `without_v07_checks` on the router.",
            );
        }
        if segment.starts_with('*') {
            return Err(
                "Path segments must not start with `*`. For wildcard capture, use \
                `{*wildcard}`. If you meant to literally match a segment starting with \
                an asterisk, call `without_v07_checks` on the router.",
            );
        }
    }
    Ok(())
}
74
75impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
76where
77    S: Clone + Send + Sync + 'static,
78{
79    pub(super) fn without_v07_checks(&mut self) {
80        self.v7_checks = false;
81    }
82
83    pub(super) fn route(
84        &mut self,
85        path: &str,
86        method_router: MethodRouter<S>,
87    ) -> Result<(), Cow<'static, str>> {
88        validate_path(self.v7_checks, path)?;
89
90        let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
91            .node
92            .path_to_route_id
93            .get(path)
94            .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
95        {
96            // if we're adding a new `MethodRouter` to a route that already has one just
97            // merge them. This makes `.route("/", get(_)).route("/", post(_))` work
98            let service = Endpoint::MethodRouter(
99                prev_method_router
100                    .clone()
101                    .merge_for_path(Some(path), method_router),
102            );
103            self.routes.insert(route_id, service);
104            return Ok(());
105        } else {
106            Endpoint::MethodRouter(method_router)
107        };
108
109        let id = self.next_route_id();
110        self.set_node(path, id)?;
111        self.routes.insert(id, endpoint);
112
113        Ok(())
114    }
115
116    pub(super) fn method_not_allowed_fallback<H, T>(&mut self, handler: H)
117    where
118        H: Handler<T, S>,
119        T: 'static,
120    {
121        for (_, endpoint) in self.routes.iter_mut() {
122            if let Endpoint::MethodRouter(rt) = endpoint {
123                *rt = rt.clone().default_fallback(handler.clone());
124            }
125        }
126    }
127
128    pub(super) fn route_service<T>(
129        &mut self,
130        path: &str,
131        service: T,
132    ) -> Result<(), Cow<'static, str>>
133    where
134        T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
135        T::Response: IntoResponse,
136        T::Future: Send + 'static,
137    {
138        self.route_endpoint(path, Endpoint::Route(Route::new(service)))
139    }
140
141    pub(super) fn route_endpoint(
142        &mut self,
143        path: &str,
144        endpoint: Endpoint<S>,
145    ) -> Result<(), Cow<'static, str>> {
146        validate_path(self.v7_checks, path)?;
147
148        let id = self.next_route_id();
149        self.set_node(path, id)?;
150        self.routes.insert(id, endpoint);
151
152        Ok(())
153    }
154
155    fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> {
156        let node = Arc::make_mut(&mut self.node);
157
158        node.insert(path, id)
159            .map_err(|err| format!("Invalid route {path:?}: {err}"))
160    }
161
162    pub(super) fn merge(
163        &mut self,
164        other: PathRouter<S, IS_FALLBACK>,
165    ) -> Result<(), Cow<'static, str>> {
166        let PathRouter {
167            routes,
168            node,
169            prev_route_id: _,
170            v7_checks,
171        } = other;
172
173        // If either of the two did not allow paths starting with `:` or `*`, do not allow them for the merged router either.
174        self.v7_checks |= v7_checks;
175
176        for (id, route) in routes {
177            let path = node
178                .route_id_to_path
179                .get(&id)
180                .expect("no path for route id. This is a bug in axum. Please file an issue");
181
182            if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
183                // when merging two routers it doesn't matter if you do `a.merge(b)` or
184                // `b.merge(a)`. This must also be true for fallbacks.
185                //
186                // However all fallback routers will have routes for `/` and `/*` so when merging
187                // we have to ignore the top level fallbacks on one side otherwise we get
188                // conflicts.
189                //
190                // `Router::merge` makes sure that when merging fallbacks `other` always has the
191                // fallback we want to keep. It panics if both routers have a custom fallback. Thus
192                // it is always okay to ignore one fallback and `Router::merge` also makes sure the
193                // one we can ignore is that of `self`.
194                self.replace_endpoint(path, route);
195            } else {
196                match route {
197                    Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
198                    Endpoint::Route(route) => self.route_service(path, route)?,
199                }
200            }
201        }
202
203        Ok(())
204    }
205
206    pub(super) fn nest(
207        &mut self,
208        path_to_nest_at: &str,
209        router: PathRouter<S, IS_FALLBACK>,
210    ) -> Result<(), Cow<'static, str>> {
211        let prefix = validate_nest_path(self.v7_checks, path_to_nest_at);
212
213        let PathRouter {
214            routes,
215            node,
216            prev_route_id: _,
217            // Ignore the configuration of the nested router
218            v7_checks: _,
219        } = router;
220
221        for (id, endpoint) in routes {
222            let inner_path = node
223                .route_id_to_path
224                .get(&id)
225                .expect("no path for route id. This is a bug in axum. Please file an issue");
226
227            let path = path_for_nested_route(prefix, inner_path);
228
229            let layer = (
230                StripPrefix::layer(prefix),
231                SetNestedPath::layer(path_to_nest_at),
232            );
233            match endpoint.layer(layer) {
234                Endpoint::MethodRouter(method_router) => {
235                    self.route(&path, method_router)?;
236                }
237                Endpoint::Route(route) => {
238                    self.route_endpoint(&path, Endpoint::Route(route))?;
239                }
240            }
241        }
242
243        Ok(())
244    }
245
246    pub(super) fn nest_service<T>(
247        &mut self,
248        path_to_nest_at: &str,
249        svc: T,
250    ) -> Result<(), Cow<'static, str>>
251    where
252        T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
253        T::Response: IntoResponse,
254        T::Future: Send + 'static,
255    {
256        let path = validate_nest_path(self.v7_checks, path_to_nest_at);
257        let prefix = path;
258
259        let path = if path.ends_with('/') {
260            format!("{path}{{*{NEST_TAIL_PARAM}}}")
261        } else {
262            format!("{path}/{{*{NEST_TAIL_PARAM}}}")
263        };
264
265        let layer = (
266            StripPrefix::layer(prefix),
267            SetNestedPath::layer(path_to_nest_at),
268        );
269        let endpoint = Endpoint::Route(Route::new(layer.layer(svc)));
270
271        self.route_endpoint(&path, endpoint.clone())?;
272
273        // `/{*rest}` is not matched by `/` so we need to also register a router at the
274        // prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself
275        // wouldn't match, which it should
276        self.route_endpoint(prefix, endpoint.clone())?;
277        if !prefix.ends_with('/') {
278            // same goes for `/foo/`, that should also match
279            self.route_endpoint(&format!("{prefix}/"), endpoint)?;
280        }
281
282        Ok(())
283    }
284
285    pub(super) fn layer<L>(self, layer: L) -> PathRouter<S, IS_FALLBACK>
286    where
287        L: Layer<Route> + Clone + Send + Sync + 'static,
288        L::Service: Service<Request> + Clone + Send + Sync + 'static,
289        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
290        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
291        <L::Service as Service<Request>>::Future: Send + 'static,
292    {
293        let routes = self
294            .routes
295            .into_iter()
296            .map(|(id, endpoint)| {
297                let route = endpoint.layer(layer.clone());
298                (id, route)
299            })
300            .collect();
301
302        PathRouter {
303            routes,
304            node: self.node,
305            prev_route_id: self.prev_route_id,
306            v7_checks: self.v7_checks,
307        }
308    }
309
310    #[track_caller]
311    pub(super) fn route_layer<L>(self, layer: L) -> Self
312    where
313        L: Layer<Route> + Clone + Send + Sync + 'static,
314        L::Service: Service<Request> + Clone + Send + Sync + 'static,
315        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
316        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
317        <L::Service as Service<Request>>::Future: Send + 'static,
318    {
319        if self.routes.is_empty() {
320            panic!(
321                "Adding a route_layer before any routes is a no-op. \
322                 Add the routes you want the layer to apply to first."
323            );
324        }
325
326        let routes = self
327            .routes
328            .into_iter()
329            .map(|(id, endpoint)| {
330                let route = endpoint.layer(layer.clone());
331                (id, route)
332            })
333            .collect();
334
335        PathRouter {
336            routes,
337            node: self.node,
338            prev_route_id: self.prev_route_id,
339            v7_checks: self.v7_checks,
340        }
341    }
342
343    pub(super) fn has_routes(&self) -> bool {
344        !self.routes.is_empty()
345    }
346
347    pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, IS_FALLBACK> {
348        let routes = self
349            .routes
350            .into_iter()
351            .map(|(id, endpoint)| {
352                let endpoint: Endpoint<S2> = match endpoint {
353                    Endpoint::MethodRouter(method_router) => {
354                        Endpoint::MethodRouter(method_router.with_state(state.clone()))
355                    }
356                    Endpoint::Route(route) => Endpoint::Route(route),
357                };
358                (id, endpoint)
359            })
360            .collect();
361
362        PathRouter {
363            routes,
364            node: self.node,
365            prev_route_id: self.prev_route_id,
366            v7_checks: self.v7_checks,
367        }
368    }
369
370    pub(super) fn call_with_state(
371        &self,
372        #[cfg_attr(not(feature = "original-uri"), allow(unused_mut))] mut req: Request,
373        state: S,
374    ) -> Result<RouteFuture<Infallible>, (Request, S)> {
375        #[cfg(feature = "original-uri")]
376        {
377            use crate::extract::OriginalUri;
378
379            if req.extensions().get::<OriginalUri>().is_none() {
380                let original_uri = OriginalUri(req.uri().clone());
381                req.extensions_mut().insert(original_uri);
382            }
383        }
384
385        let (mut parts, body) = req.into_parts();
386
387        match self.node.at(parts.uri.path()) {
388            Ok(match_) => {
389                let id = *match_.value;
390
391                if !IS_FALLBACK {
392                    #[cfg(feature = "matched-path")]
393                    crate::extract::matched_path::set_matched_path_for_request(
394                        id,
395                        &self.node.route_id_to_path,
396                        &mut parts.extensions,
397                    );
398                }
399
400                url_params::insert_url_params(&mut parts.extensions, match_.params);
401
402                let endpoint = self
403                    .routes
404                    .get(&id)
405                    .expect("no route for id. This is a bug in axum. Please file an issue");
406
407                let req = Request::from_parts(parts, body);
408                match endpoint {
409                    Endpoint::MethodRouter(method_router) => {
410                        Ok(method_router.call_with_state(req, state))
411                    }
412                    Endpoint::Route(route) => Ok(route.clone().call_owned(req)),
413                }
414            }
415            // explicitly handle all variants in case matchit adds
416            // new ones we need to handle differently
417            Err(MatchError::NotFound) => Err((Request::from_parts(parts, body), state)),
418        }
419    }
420
421    pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S>) {
422        match self.node.at(path) {
423            Ok(match_) => {
424                let id = *match_.value;
425                self.routes.insert(id, endpoint);
426            }
427            Err(_) => self
428                .route_endpoint(path, endpoint)
429                .expect("path wasn't matched so endpoint shouldn't exist"),
430        }
431    }
432
433    fn next_route_id(&mut self) -> RouteId {
434        let next_id = self
435            .prev_route_id
436            .0
437            .checked_add(1)
438            .expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
439        self.prev_route_id = RouteId(next_id);
440        self.prev_route_id
441    }
442}
443
444impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
445    fn default() -> Self {
446        Self {
447            routes: Default::default(),
448            node: Default::default(),
449            prev_route_id: RouteId(0),
450            v7_checks: true,
451        }
452    }
453}
454
455impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
456    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
457        f.debug_struct("PathRouter")
458            .field("routes", &self.routes)
459            .field("node", &self.node)
460            .finish()
461    }
462}
463
464impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
465    fn clone(&self) -> Self {
466        Self {
467            routes: self.routes.clone(),
468            node: self.node.clone(),
469            prev_route_id: self.prev_route_id,
470            v7_checks: self.v7_checks,
471        }
472    }
473}
474
475/// Wrapper around `matchit::Router` that supports merging two `Router`s.
476#[derive(Clone, Default)]
477struct Node {
478    inner: matchit::Router<RouteId>,
479    route_id_to_path: HashMap<RouteId, Arc<str>>,
480    path_to_route_id: HashMap<Arc<str>, RouteId>,
481}
482
483impl Node {
484    fn insert(
485        &mut self,
486        path: impl Into<String>,
487        val: RouteId,
488    ) -> Result<(), matchit::InsertError> {
489        let path = path.into();
490
491        self.inner.insert(&path, val)?;
492
493        let shared_path: Arc<str> = path.into();
494        self.route_id_to_path.insert(val, shared_path.clone());
495        self.path_to_route_id.insert(shared_path, val);
496
497        Ok(())
498    }
499
500    fn at<'n, 'p>(
501        &'n self,
502        path: &'p str,
503    ) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
504        self.inner.at(path)
505    }
506}
507
508impl fmt::Debug for Node {
509    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
510        f.debug_struct("Node")
511            .field("paths", &self.route_id_to_path)
512            .finish()
513    }
514}
515
516#[track_caller]
517fn validate_nest_path(v7_checks: bool, path: &str) -> &str {
518    assert!(path.starts_with('/'));
519    assert!(path.len() > 1);
520
521    if path.split('/').any(|segment| {
522        segment.starts_with("{*") && segment.ends_with('}') && !segment.ends_with("}}")
523    }) {
524        panic!("Invalid route: nested routes cannot contain wildcards (*)");
525    }
526
527    if v7_checks {
528        validate_v07_paths(path).unwrap();
529    }
530
531    path
532}
533
/// Joins a nest `prefix` and a nested route `path` into one path.
///
/// When `prefix` ends with `/`, all leading slashes of `path` are dropped so the separator
/// comes from `prefix`. A nested root route `/` under a prefix without a trailing slash maps
/// to `prefix` itself, borrowed without allocating.
pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> {
    debug_assert!(prefix.starts_with('/'));
    debug_assert!(path.starts_with('/'));

    match (prefix.ends_with('/'), path) {
        (true, _) => Cow::Owned(format!("{prefix}{}", path.trim_start_matches('/'))),
        (false, "/") => Cow::Borrowed(prefix),
        (false, _) => Cow::Owned(format!("{prefix}{path}")),
    }
}