axum/routing/path_router.rs

1use crate::extract::{nested_path::SetNestedPath, Request};
2use axum_core::response::IntoResponse;
3use matchit::MatchError;
4use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
5use tower_layer::Layer;
6use tower_service::Service;
7
8use super::{
9    future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
10    MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
11};
12
13pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
14    routes: HashMap<RouteId, Endpoint<S>>,
15    node: Arc<Node>,
16    prev_route_id: RouteId,
17}
18
19impl<S> PathRouter<S, true>
20where
21    S: Clone + Send + Sync + 'static,
22{
23    pub(super) fn new_fallback() -> Self {
24        let mut this = Self::default();
25        this.set_fallback(Endpoint::Route(Route::new(NotFound)));
26        this
27    }
28
29    pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
30        self.replace_endpoint("/", endpoint.clone());
31        self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
32    }
33}
34
35impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
36where
37    S: Clone + Send + Sync + 'static,
38{
39    pub(super) fn route(
40        &mut self,
41        path: &str,
42        method_router: MethodRouter<S>,
43    ) -> Result<(), Cow<'static, str>> {
44        fn validate_path(path: &str) -> Result<(), &'static str> {
45            if path.is_empty() {
46                return Err("Paths must start with a `/`. Use \"/\" for root routes");
47            } else if !path.starts_with('/') {
48                return Err("Paths must start with a `/`");
49            }
50
51            Ok(())
52        }
53
54        validate_path(path)?;
55
56        let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
57            .node
58            .path_to_route_id
59            .get(path)
60            .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
61        {
62            // if we're adding a new `MethodRouter` to a route that already has one just
63            // merge them. This makes `.route("/", get(_)).route("/", post(_))` work
64            let service = Endpoint::MethodRouter(
65                prev_method_router
66                    .clone()
67                    .merge_for_path(Some(path), method_router),
68            );
69            self.routes.insert(route_id, service);
70            return Ok(());
71        } else {
72            Endpoint::MethodRouter(method_router)
73        };
74
75        let id = self.next_route_id();
76        self.set_node(path, id)?;
77        self.routes.insert(id, endpoint);
78
79        Ok(())
80    }
81
82    pub(super) fn route_service<T>(
83        &mut self,
84        path: &str,
85        service: T,
86    ) -> Result<(), Cow<'static, str>>
87    where
88        T: Service<Request, Error = Infallible> + Clone + Send + 'static,
89        T::Response: IntoResponse,
90        T::Future: Send + 'static,
91    {
92        self.route_endpoint(path, Endpoint::Route(Route::new(service)))
93    }
94
95    pub(super) fn route_endpoint(
96        &mut self,
97        path: &str,
98        endpoint: Endpoint<S>,
99    ) -> Result<(), Cow<'static, str>> {
100        if path.is_empty() {
101            return Err("Paths must start with a `/`. Use \"/\" for root routes".into());
102        } else if !path.starts_with('/') {
103            return Err("Paths must start with a `/`".into());
104        }
105
106        let id = self.next_route_id();
107        self.set_node(path, id)?;
108        self.routes.insert(id, endpoint);
109
110        Ok(())
111    }
112
113    fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> {
114        let mut node =
115            Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone());
116        if let Err(err) = node.insert(path, id) {
117            return Err(format!("Invalid route {path:?}: {err}"));
118        }
119        self.node = Arc::new(node);
120        Ok(())
121    }
122
123    pub(super) fn merge(
124        &mut self,
125        other: PathRouter<S, IS_FALLBACK>,
126    ) -> Result<(), Cow<'static, str>> {
127        let PathRouter {
128            routes,
129            node,
130            prev_route_id: _,
131        } = other;
132
133        for (id, route) in routes {
134            let path = node
135                .route_id_to_path
136                .get(&id)
137                .expect("no path for route id. This is a bug in axum. Please file an issue");
138
139            if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
140                // when merging two routers it doesn't matter if you do `a.merge(b)` or
141                // `b.merge(a)`. This must also be true for fallbacks.
142                //
143                // However all fallback routers will have routes for `/` and `/*` so when merging
144                // we have to ignore the top level fallbacks on one side otherwise we get
145                // conflicts.
146                //
147                // `Router::merge` makes sure that when merging fallbacks `other` always has the
148                // fallback we want to keep. It panics if both routers have a custom fallback. Thus
149                // it is always okay to ignore one fallback and `Router::merge` also makes sure the
150                // one we can ignore is that of `self`.
151                self.replace_endpoint(path, route);
152            } else {
153                match route {
154                    Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
155                    Endpoint::Route(route) => self.route_service(path, route)?,
156                }
157            }
158        }
159
160        Ok(())
161    }
162
163    pub(super) fn nest(
164        &mut self,
165        path_to_nest_at: &str,
166        router: PathRouter<S, IS_FALLBACK>,
167    ) -> Result<(), Cow<'static, str>> {
168        let prefix = validate_nest_path(path_to_nest_at);
169
170        let PathRouter {
171            routes,
172            node,
173            prev_route_id: _,
174        } = router;
175
176        for (id, endpoint) in routes {
177            let inner_path = node
178                .route_id_to_path
179                .get(&id)
180                .expect("no path for route id. This is a bug in axum. Please file an issue");
181
182            let path = path_for_nested_route(prefix, inner_path);
183
184            let layer = (
185                StripPrefix::layer(prefix),
186                SetNestedPath::layer(path_to_nest_at),
187            );
188            match endpoint.layer(layer) {
189                Endpoint::MethodRouter(method_router) => {
190                    self.route(&path, method_router)?;
191                }
192                Endpoint::Route(route) => {
193                    self.route_endpoint(&path, Endpoint::Route(route))?;
194                }
195            }
196        }
197
198        Ok(())
199    }
200
201    pub(super) fn nest_service<T>(
202        &mut self,
203        path_to_nest_at: &str,
204        svc: T,
205    ) -> Result<(), Cow<'static, str>>
206    where
207        T: Service<Request, Error = Infallible> + Clone + Send + 'static,
208        T::Response: IntoResponse,
209        T::Future: Send + 'static,
210    {
211        let path = validate_nest_path(path_to_nest_at);
212        let prefix = path;
213
214        let path = if path.ends_with('/') {
215            format!("{path}*{NEST_TAIL_PARAM}")
216        } else {
217            format!("{path}/*{NEST_TAIL_PARAM}")
218        };
219
220        let layer = (
221            StripPrefix::layer(prefix),
222            SetNestedPath::layer(path_to_nest_at),
223        );
224        let endpoint = Endpoint::Route(Route::new(layer.layer(svc)));
225
226        self.route_endpoint(&path, endpoint.clone())?;
227
228        // `/*rest` is not matched by `/` so we need to also register a router at the
229        // prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself
230        // wouldn't match, which it should
231        self.route_endpoint(prefix, endpoint.clone())?;
232        if !prefix.ends_with('/') {
233            // same goes for `/foo/`, that should also match
234            self.route_endpoint(&format!("{prefix}/"), endpoint)?;
235        }
236
237        Ok(())
238    }
239
240    pub(super) fn layer<L>(self, layer: L) -> PathRouter<S, IS_FALLBACK>
241    where
242        L: Layer<Route> + Clone + Send + 'static,
243        L::Service: Service<Request> + Clone + Send + 'static,
244        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
245        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
246        <L::Service as Service<Request>>::Future: Send + 'static,
247    {
248        let routes = self
249            .routes
250            .into_iter()
251            .map(|(id, endpoint)| {
252                let route = endpoint.layer(layer.clone());
253                (id, route)
254            })
255            .collect();
256
257        PathRouter {
258            routes,
259            node: self.node,
260            prev_route_id: self.prev_route_id,
261        }
262    }
263
264    #[track_caller]
265    pub(super) fn route_layer<L>(self, layer: L) -> Self
266    where
267        L: Layer<Route> + Clone + Send + 'static,
268        L::Service: Service<Request> + Clone + Send + 'static,
269        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
270        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
271        <L::Service as Service<Request>>::Future: Send + 'static,
272    {
273        if self.routes.is_empty() {
274            panic!(
275                "Adding a route_layer before any routes is a no-op. \
276                 Add the routes you want the layer to apply to first."
277            );
278        }
279
280        let routes = self
281            .routes
282            .into_iter()
283            .map(|(id, endpoint)| {
284                let route = endpoint.layer(layer.clone());
285                (id, route)
286            })
287            .collect();
288
289        PathRouter {
290            routes,
291            node: self.node,
292            prev_route_id: self.prev_route_id,
293        }
294    }
295
296    pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, IS_FALLBACK> {
297        let routes = self
298            .routes
299            .into_iter()
300            .map(|(id, endpoint)| {
301                let endpoint: Endpoint<S2> = match endpoint {
302                    Endpoint::MethodRouter(method_router) => {
303                        Endpoint::MethodRouter(method_router.with_state(state.clone()))
304                    }
305                    Endpoint::Route(route) => Endpoint::Route(route),
306                };
307                (id, endpoint)
308            })
309            .collect();
310
311        PathRouter {
312            routes,
313            node: self.node,
314            prev_route_id: self.prev_route_id,
315        }
316    }
317
318    pub(super) fn call_with_state(
319        &self,
320        mut req: Request,
321        state: S,
322    ) -> Result<RouteFuture<Infallible>, (Request, S)> {
323        #[cfg(feature = "original-uri")]
324        {
325            use crate::extract::OriginalUri;
326
327            if req.extensions().get::<OriginalUri>().is_none() {
328                let original_uri = OriginalUri(req.uri().clone());
329                req.extensions_mut().insert(original_uri);
330            }
331        }
332
333        let path = req.uri().path().to_owned();
334
335        match self.node.at(&path) {
336            Ok(match_) => {
337                let id = *match_.value;
338
339                if !IS_FALLBACK {
340                    #[cfg(feature = "matched-path")]
341                    crate::extract::matched_path::set_matched_path_for_request(
342                        id,
343                        &self.node.route_id_to_path,
344                        req.extensions_mut(),
345                    );
346                }
347
348                url_params::insert_url_params(req.extensions_mut(), match_.params);
349
350                let endpoint = self
351                    .routes
352                    .get(&id)
353                    .expect("no route for id. This is a bug in axum. Please file an issue");
354
355                match endpoint {
356                    Endpoint::MethodRouter(method_router) => {
357                        Ok(method_router.call_with_state(req, state))
358                    }
359                    Endpoint::Route(route) => Ok(route.clone().call(req)),
360                }
361            }
362            // explicitly handle all variants in case matchit adds
363            // new ones we need to handle differently
364            Err(
365                MatchError::NotFound
366                | MatchError::ExtraTrailingSlash
367                | MatchError::MissingTrailingSlash,
368            ) => Err((req, state)),
369        }
370    }
371
372    pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S>) {
373        match self.node.at(path) {
374            Ok(match_) => {
375                let id = *match_.value;
376                self.routes.insert(id, endpoint);
377            }
378            Err(_) => self
379                .route_endpoint(path, endpoint)
380                .expect("path wasn't matched so endpoint shouldn't exist"),
381        }
382    }
383
384    fn next_route_id(&mut self) -> RouteId {
385        let next_id = self
386            .prev_route_id
387            .0
388            .checked_add(1)
389            .expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
390        self.prev_route_id = RouteId(next_id);
391        self.prev_route_id
392    }
393}
394
395impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
396    fn default() -> Self {
397        Self {
398            routes: Default::default(),
399            node: Default::default(),
400            prev_route_id: RouteId(0),
401        }
402    }
403}
404
405impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
406    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407        f.debug_struct("PathRouter")
408            .field("routes", &self.routes)
409            .field("node", &self.node)
410            .finish()
411    }
412}
413
414impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
415    fn clone(&self) -> Self {
416        Self {
417            routes: self.routes.clone(),
418            node: self.node.clone(),
419            prev_route_id: self.prev_route_id,
420        }
421    }
422}
423
424/// Wrapper around `matchit::Router` that supports merging two `Router`s.
425#[derive(Clone, Default)]
426struct Node {
427    inner: matchit::Router<RouteId>,
428    route_id_to_path: HashMap<RouteId, Arc<str>>,
429    path_to_route_id: HashMap<Arc<str>, RouteId>,
430}
431
432impl Node {
433    fn insert(
434        &mut self,
435        path: impl Into<String>,
436        val: RouteId,
437    ) -> Result<(), matchit::InsertError> {
438        let path = path.into();
439
440        self.inner.insert(&path, val)?;
441
442        let shared_path: Arc<str> = path.into();
443        self.route_id_to_path.insert(val, shared_path.clone());
444        self.path_to_route_id.insert(shared_path, val);
445
446        Ok(())
447    }
448
449    fn at<'n, 'p>(
450        &'n self,
451        path: &'p str,
452    ) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
453        self.inner.at(path)
454    }
455}
456
457impl fmt::Debug for Node {
458    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
459        f.debug_struct("Node")
460            .field("paths", &self.route_id_to_path)
461            .finish()
462    }
463}
464
/// Normalizes and checks the path a router or service is nested at.
///
/// An empty path is treated as `/`, so nesting at `""` and `"/"` means the same thing.
/// Any other path is returned unchanged.
///
/// # Panics
///
/// Panics if `path` contains a wildcard (`*`).
#[track_caller]
fn validate_nest_path(path: &str) -> &str {
    match path {
        "" => "/",
        wildcard if wildcard.contains('*') => {
            panic!("Invalid route: nested routes cannot contain wildcards (*)")
        }
        other => other,
    }
}
478
/// Builds the full path of a nested route from the nest `prefix` and the route's
/// own `path`.
///
/// Both arguments must start with `/` (checked in debug builds). There are three cases:
/// - If `prefix` ends with `/`, the leading slashes of `path` are dropped before
///   joining, so the separator is not doubled.
/// - Otherwise, a `path` of `/` yields `prefix` itself, borrowed with no allocation.
/// - Otherwise, `path` is appended to `prefix` as-is.
pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> {
    debug_assert!(prefix.starts_with('/'));
    debug_assert!(path.starts_with('/'));

    if !prefix.ends_with('/') {
        if path == "/" {
            return Cow::Borrowed(prefix);
        }
        return Cow::Owned(format!("{prefix}{path}"));
    }

    let tail = path.trim_start_matches('/');
    Cow::Owned(format!("{prefix}{tail}"))
}
490}