1use crate::extract::{nested_path::SetNestedPath, Request};
2use axum_core::response::IntoResponse;
3use matchit::MatchError;
4use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
5use tower_layer::Layer;
6use tower_service::Service;
7
8use super::{
9 future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
10 MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
11};
12
13pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
14 routes: HashMap<RouteId, Endpoint<S>>,
15 node: Arc<Node>,
16 prev_route_id: RouteId,
17}
18
19impl<S> PathRouter<S, true>
20where
21 S: Clone + Send + Sync + 'static,
22{
23 pub(super) fn new_fallback() -> Self {
24 let mut this = Self::default();
25 this.set_fallback(Endpoint::Route(Route::new(NotFound)));
26 this
27 }
28
29 pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
30 self.replace_endpoint("/", endpoint.clone());
31 self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
32 }
33}
34
35impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
36where
37 S: Clone + Send + Sync + 'static,
38{
39 pub(super) fn route(
40 &mut self,
41 path: &str,
42 method_router: MethodRouter<S>,
43 ) -> Result<(), Cow<'static, str>> {
44 fn validate_path(path: &str) -> Result<(), &'static str> {
45 if path.is_empty() {
46 return Err("Paths must start with a `/`. Use \"/\" for root routes");
47 } else if !path.starts_with('/') {
48 return Err("Paths must start with a `/`");
49 }
50
51 Ok(())
52 }
53
54 validate_path(path)?;
55
56 let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
57 .node
58 .path_to_route_id
59 .get(path)
60 .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
61 {
62 let service = Endpoint::MethodRouter(
65 prev_method_router
66 .clone()
67 .merge_for_path(Some(path), method_router),
68 );
69 self.routes.insert(route_id, service);
70 return Ok(());
71 } else {
72 Endpoint::MethodRouter(method_router)
73 };
74
75 let id = self.next_route_id();
76 self.set_node(path, id)?;
77 self.routes.insert(id, endpoint);
78
79 Ok(())
80 }
81
82 pub(super) fn route_service<T>(
83 &mut self,
84 path: &str,
85 service: T,
86 ) -> Result<(), Cow<'static, str>>
87 where
88 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
89 T::Response: IntoResponse,
90 T::Future: Send + 'static,
91 {
92 self.route_endpoint(path, Endpoint::Route(Route::new(service)))
93 }
94
95 pub(super) fn route_endpoint(
96 &mut self,
97 path: &str,
98 endpoint: Endpoint<S>,
99 ) -> Result<(), Cow<'static, str>> {
100 if path.is_empty() {
101 return Err("Paths must start with a `/`. Use \"/\" for root routes".into());
102 } else if !path.starts_with('/') {
103 return Err("Paths must start with a `/`".into());
104 }
105
106 let id = self.next_route_id();
107 self.set_node(path, id)?;
108 self.routes.insert(id, endpoint);
109
110 Ok(())
111 }
112
113 fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> {
114 let mut node =
115 Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone());
116 if let Err(err) = node.insert(path, id) {
117 return Err(format!("Invalid route {path:?}: {err}"));
118 }
119 self.node = Arc::new(node);
120 Ok(())
121 }
122
123 pub(super) fn merge(
124 &mut self,
125 other: PathRouter<S, IS_FALLBACK>,
126 ) -> Result<(), Cow<'static, str>> {
127 let PathRouter {
128 routes,
129 node,
130 prev_route_id: _,
131 } = other;
132
133 for (id, route) in routes {
134 let path = node
135 .route_id_to_path
136 .get(&id)
137 .expect("no path for route id. This is a bug in axum. Please file an issue");
138
139 if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
140 self.replace_endpoint(path, route);
152 } else {
153 match route {
154 Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
155 Endpoint::Route(route) => self.route_service(path, route)?,
156 }
157 }
158 }
159
160 Ok(())
161 }
162
163 pub(super) fn nest(
164 &mut self,
165 path_to_nest_at: &str,
166 router: PathRouter<S, IS_FALLBACK>,
167 ) -> Result<(), Cow<'static, str>> {
168 let prefix = validate_nest_path(path_to_nest_at);
169
170 let PathRouter {
171 routes,
172 node,
173 prev_route_id: _,
174 } = router;
175
176 for (id, endpoint) in routes {
177 let inner_path = node
178 .route_id_to_path
179 .get(&id)
180 .expect("no path for route id. This is a bug in axum. Please file an issue");
181
182 let path = path_for_nested_route(prefix, inner_path);
183
184 let layer = (
185 StripPrefix::layer(prefix),
186 SetNestedPath::layer(path_to_nest_at),
187 );
188 match endpoint.layer(layer) {
189 Endpoint::MethodRouter(method_router) => {
190 self.route(&path, method_router)?;
191 }
192 Endpoint::Route(route) => {
193 self.route_endpoint(&path, Endpoint::Route(route))?;
194 }
195 }
196 }
197
198 Ok(())
199 }
200
201 pub(super) fn nest_service<T>(
202 &mut self,
203 path_to_nest_at: &str,
204 svc: T,
205 ) -> Result<(), Cow<'static, str>>
206 where
207 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
208 T::Response: IntoResponse,
209 T::Future: Send + 'static,
210 {
211 let path = validate_nest_path(path_to_nest_at);
212 let prefix = path;
213
214 let path = if path.ends_with('/') {
215 format!("{path}*{NEST_TAIL_PARAM}")
216 } else {
217 format!("{path}/*{NEST_TAIL_PARAM}")
218 };
219
220 let layer = (
221 StripPrefix::layer(prefix),
222 SetNestedPath::layer(path_to_nest_at),
223 );
224 let endpoint = Endpoint::Route(Route::new(layer.layer(svc)));
225
226 self.route_endpoint(&path, endpoint.clone())?;
227
228 self.route_endpoint(prefix, endpoint.clone())?;
232 if !prefix.ends_with('/') {
233 self.route_endpoint(&format!("{prefix}/"), endpoint)?;
235 }
236
237 Ok(())
238 }
239
240 pub(super) fn layer<L>(self, layer: L) -> PathRouter<S, IS_FALLBACK>
241 where
242 L: Layer<Route> + Clone + Send + 'static,
243 L::Service: Service<Request> + Clone + Send + 'static,
244 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
245 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
246 <L::Service as Service<Request>>::Future: Send + 'static,
247 {
248 let routes = self
249 .routes
250 .into_iter()
251 .map(|(id, endpoint)| {
252 let route = endpoint.layer(layer.clone());
253 (id, route)
254 })
255 .collect();
256
257 PathRouter {
258 routes,
259 node: self.node,
260 prev_route_id: self.prev_route_id,
261 }
262 }
263
264 #[track_caller]
265 pub(super) fn route_layer<L>(self, layer: L) -> Self
266 where
267 L: Layer<Route> + Clone + Send + 'static,
268 L::Service: Service<Request> + Clone + Send + 'static,
269 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
270 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
271 <L::Service as Service<Request>>::Future: Send + 'static,
272 {
273 if self.routes.is_empty() {
274 panic!(
275 "Adding a route_layer before any routes is a no-op. \
276 Add the routes you want the layer to apply to first."
277 );
278 }
279
280 let routes = self
281 .routes
282 .into_iter()
283 .map(|(id, endpoint)| {
284 let route = endpoint.layer(layer.clone());
285 (id, route)
286 })
287 .collect();
288
289 PathRouter {
290 routes,
291 node: self.node,
292 prev_route_id: self.prev_route_id,
293 }
294 }
295
296 pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, IS_FALLBACK> {
297 let routes = self
298 .routes
299 .into_iter()
300 .map(|(id, endpoint)| {
301 let endpoint: Endpoint<S2> = match endpoint {
302 Endpoint::MethodRouter(method_router) => {
303 Endpoint::MethodRouter(method_router.with_state(state.clone()))
304 }
305 Endpoint::Route(route) => Endpoint::Route(route),
306 };
307 (id, endpoint)
308 })
309 .collect();
310
311 PathRouter {
312 routes,
313 node: self.node,
314 prev_route_id: self.prev_route_id,
315 }
316 }
317
318 pub(super) fn call_with_state(
319 &self,
320 mut req: Request,
321 state: S,
322 ) -> Result<RouteFuture<Infallible>, (Request, S)> {
323 #[cfg(feature = "original-uri")]
324 {
325 use crate::extract::OriginalUri;
326
327 if req.extensions().get::<OriginalUri>().is_none() {
328 let original_uri = OriginalUri(req.uri().clone());
329 req.extensions_mut().insert(original_uri);
330 }
331 }
332
333 let path = req.uri().path().to_owned();
334
335 match self.node.at(&path) {
336 Ok(match_) => {
337 let id = *match_.value;
338
339 if !IS_FALLBACK {
340 #[cfg(feature = "matched-path")]
341 crate::extract::matched_path::set_matched_path_for_request(
342 id,
343 &self.node.route_id_to_path,
344 req.extensions_mut(),
345 );
346 }
347
348 url_params::insert_url_params(req.extensions_mut(), match_.params);
349
350 let endpoint = self
351 .routes
352 .get(&id)
353 .expect("no route for id. This is a bug in axum. Please file an issue");
354
355 match endpoint {
356 Endpoint::MethodRouter(method_router) => {
357 Ok(method_router.call_with_state(req, state))
358 }
359 Endpoint::Route(route) => Ok(route.clone().call(req)),
360 }
361 }
362 Err(
365 MatchError::NotFound
366 | MatchError::ExtraTrailingSlash
367 | MatchError::MissingTrailingSlash,
368 ) => Err((req, state)),
369 }
370 }
371
372 pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S>) {
373 match self.node.at(path) {
374 Ok(match_) => {
375 let id = *match_.value;
376 self.routes.insert(id, endpoint);
377 }
378 Err(_) => self
379 .route_endpoint(path, endpoint)
380 .expect("path wasn't matched so endpoint shouldn't exist"),
381 }
382 }
383
384 fn next_route_id(&mut self) -> RouteId {
385 let next_id = self
386 .prev_route_id
387 .0
388 .checked_add(1)
389 .expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
390 self.prev_route_id = RouteId(next_id);
391 self.prev_route_id
392 }
393}
394
395impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
396 fn default() -> Self {
397 Self {
398 routes: Default::default(),
399 node: Default::default(),
400 prev_route_id: RouteId(0),
401 }
402 }
403}
404
405impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
406 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407 f.debug_struct("PathRouter")
408 .field("routes", &self.routes)
409 .field("node", &self.node)
410 .finish()
411 }
412}
413
414impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
415 fn clone(&self) -> Self {
416 Self {
417 routes: self.routes.clone(),
418 node: self.node.clone(),
419 prev_route_id: self.prev_route_id,
420 }
421 }
422}
423
424#[derive(Clone, Default)]
426struct Node {
427 inner: matchit::Router<RouteId>,
428 route_id_to_path: HashMap<RouteId, Arc<str>>,
429 path_to_route_id: HashMap<Arc<str>, RouteId>,
430}
431
432impl Node {
433 fn insert(
434 &mut self,
435 path: impl Into<String>,
436 val: RouteId,
437 ) -> Result<(), matchit::InsertError> {
438 let path = path.into();
439
440 self.inner.insert(&path, val)?;
441
442 let shared_path: Arc<str> = path.into();
443 self.route_id_to_path.insert(val, shared_path.clone());
444 self.path_to_route_id.insert(shared_path, val);
445
446 Ok(())
447 }
448
449 fn at<'n, 'p>(
450 &'n self,
451 path: &'p str,
452 ) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
453 self.inner.at(path)
454 }
455}
456
457impl fmt::Debug for Node {
458 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
459 f.debug_struct("Node")
460 .field("paths", &self.route_id_to_path)
461 .finish()
462 }
463}
464
/// Normalises and checks a path that a router or service is being nested at.
///
/// An empty path is treated the same as `/`; any other path is returned unchanged.
///
/// # Panics
///
/// Panics if the path contains a wildcard (`*`).
#[track_caller]
fn validate_nest_path(path: &str) -> &str {
    if path.is_empty() {
        "/"
    } else if path.contains('*') {
        panic!("Invalid route: nested routes cannot contain wildcards (*)");
    } else {
        path
    }
}
478
/// Joins a nesting `prefix` and an inner route `path` into the full route path.
///
/// Both arguments are expected to start with `/` (checked in debug builds). If `prefix`
/// ends with `/`, all leading slashes of `path` are dropped before joining. If `path` is
/// exactly `/`, the prefix is returned as-is (borrowed). Otherwise the two are
/// concatenated.
pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> {
    debug_assert!(prefix.starts_with('/'));
    debug_assert!(path.starts_with('/'));

    match (prefix.ends_with('/'), path) {
        (true, _) => Cow::Owned(format!("{prefix}{}", path.trim_start_matches('/'))),
        (false, "/") => Cow::Borrowed(prefix),
        (false, _) => Cow::Owned(format!("{prefix}{path}")),
    }
}