1use crate::{
2 extract::{nested_path::SetNestedPath, Request},
3 handler::Handler,
4};
5use axum_core::response::IntoResponse;
6use matchit::MatchError;
7use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
8use tower_layer::Layer;
9use tower_service::Service;
10
11use super::{
12 future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
13 MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
14};
15
16pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
17 routes: HashMap<RouteId, Endpoint<S>>,
18 node: Arc<Node>,
19 prev_route_id: RouteId,
20 v7_checks: bool,
21}
22
23impl<S> PathRouter<S, true>
24where
25 S: Clone + Send + Sync + 'static,
26{
27 pub(super) fn new_fallback() -> Self {
28 let mut this = Self::default();
29 this.set_fallback(Endpoint::Route(Route::new(NotFound)));
30 this
31 }
32
33 pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
34 self.replace_endpoint("/", endpoint.clone());
35 self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
36 }
37}
38
39fn validate_path(v7_checks: bool, path: &str) -> Result<(), &'static str> {
40 if path.is_empty() {
41 return Err("Paths must start with a `/`. Use \"/\" for root routes");
42 } else if !path.starts_with('/') {
43 return Err("Paths must start with a `/`");
44 }
45
46 if v7_checks {
47 validate_v07_paths(path)?;
48 }
49
50 Ok(())
51}
52
/// Rejects paths written with the axum 0.7 capture syntax.
///
/// Segments are inspected left to right. The first segment that starts with `:` (old-style
/// capture) or `*` (old-style wildcard) produces an error that points at the `{capture}` /
/// `{*wildcard}` syntax and at `without_v07_checks`. Paths without such a segment are accepted.
fn validate_v07_paths(path: &str) -> Result<(), &'static str> {
    for segment in path.split('/') {
        match segment.chars().next() {
            Some(':') => {
                return Err(
                    "Path segments must not start with `:`. For capture groups, use \
                     `{capture}`. If you meant to literally match a segment starting with \
                     a colon, call `without_v07_checks` on the router.",
                );
            }
            Some('*') => {
                return Err(
                    "Path segments must not start with `*`. For wildcard capture, use \
                     `{*wildcard}`. If you meant to literally match a segment starting with \
                     an asterisk, call `without_v07_checks` on the router.",
                );
            }
            _ => {}
        }
    }

    Ok(())
}
74
75impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
76where
77 S: Clone + Send + Sync + 'static,
78{
79 pub(super) fn without_v07_checks(&mut self) {
80 self.v7_checks = false;
81 }
82
83 pub(super) fn route(
84 &mut self,
85 path: &str,
86 method_router: MethodRouter<S>,
87 ) -> Result<(), Cow<'static, str>> {
88 validate_path(self.v7_checks, path)?;
89
90 let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
91 .node
92 .path_to_route_id
93 .get(path)
94 .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
95 {
96 let service = Endpoint::MethodRouter(
99 prev_method_router
100 .clone()
101 .merge_for_path(Some(path), method_router),
102 );
103 self.routes.insert(route_id, service);
104 return Ok(());
105 } else {
106 Endpoint::MethodRouter(method_router)
107 };
108
109 let id = self.next_route_id();
110 self.set_node(path, id)?;
111 self.routes.insert(id, endpoint);
112
113 Ok(())
114 }
115
116 pub(super) fn method_not_allowed_fallback<H, T>(&mut self, handler: H)
117 where
118 H: Handler<T, S>,
119 T: 'static,
120 {
121 for (_, endpoint) in self.routes.iter_mut() {
122 if let Endpoint::MethodRouter(rt) = endpoint {
123 *rt = rt.clone().default_fallback(handler.clone());
124 }
125 }
126 }
127
128 pub(super) fn route_service<T>(
129 &mut self,
130 path: &str,
131 service: T,
132 ) -> Result<(), Cow<'static, str>>
133 where
134 T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
135 T::Response: IntoResponse,
136 T::Future: Send + 'static,
137 {
138 self.route_endpoint(path, Endpoint::Route(Route::new(service)))
139 }
140
141 pub(super) fn route_endpoint(
142 &mut self,
143 path: &str,
144 endpoint: Endpoint<S>,
145 ) -> Result<(), Cow<'static, str>> {
146 validate_path(self.v7_checks, path)?;
147
148 let id = self.next_route_id();
149 self.set_node(path, id)?;
150 self.routes.insert(id, endpoint);
151
152 Ok(())
153 }
154
155 fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> {
156 let node = Arc::make_mut(&mut self.node);
157
158 node.insert(path, id)
159 .map_err(|err| format!("Invalid route {path:?}: {err}"))
160 }
161
162 pub(super) fn merge(
163 &mut self,
164 other: PathRouter<S, IS_FALLBACK>,
165 ) -> Result<(), Cow<'static, str>> {
166 let PathRouter {
167 routes,
168 node,
169 prev_route_id: _,
170 v7_checks,
171 } = other;
172
173 self.v7_checks |= v7_checks;
175
176 for (id, route) in routes {
177 let path = node
178 .route_id_to_path
179 .get(&id)
180 .expect("no path for route id. This is a bug in axum. Please file an issue");
181
182 if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
183 self.replace_endpoint(path, route);
195 } else {
196 match route {
197 Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
198 Endpoint::Route(route) => self.route_service(path, route)?,
199 }
200 }
201 }
202
203 Ok(())
204 }
205
206 pub(super) fn nest(
207 &mut self,
208 path_to_nest_at: &str,
209 router: PathRouter<S, IS_FALLBACK>,
210 ) -> Result<(), Cow<'static, str>> {
211 let prefix = validate_nest_path(self.v7_checks, path_to_nest_at);
212
213 let PathRouter {
214 routes,
215 node,
216 prev_route_id: _,
217 v7_checks: _,
219 } = router;
220
221 for (id, endpoint) in routes {
222 let inner_path = node
223 .route_id_to_path
224 .get(&id)
225 .expect("no path for route id. This is a bug in axum. Please file an issue");
226
227 let path = path_for_nested_route(prefix, inner_path);
228
229 let layer = (
230 StripPrefix::layer(prefix),
231 SetNestedPath::layer(path_to_nest_at),
232 );
233 match endpoint.layer(layer) {
234 Endpoint::MethodRouter(method_router) => {
235 self.route(&path, method_router)?;
236 }
237 Endpoint::Route(route) => {
238 self.route_endpoint(&path, Endpoint::Route(route))?;
239 }
240 }
241 }
242
243 Ok(())
244 }
245
246 pub(super) fn nest_service<T>(
247 &mut self,
248 path_to_nest_at: &str,
249 svc: T,
250 ) -> Result<(), Cow<'static, str>>
251 where
252 T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
253 T::Response: IntoResponse,
254 T::Future: Send + 'static,
255 {
256 let path = validate_nest_path(self.v7_checks, path_to_nest_at);
257 let prefix = path;
258
259 let path = if path.ends_with('/') {
260 format!("{path}{{*{NEST_TAIL_PARAM}}}")
261 } else {
262 format!("{path}/{{*{NEST_TAIL_PARAM}}}")
263 };
264
265 let layer = (
266 StripPrefix::layer(prefix),
267 SetNestedPath::layer(path_to_nest_at),
268 );
269 let endpoint = Endpoint::Route(Route::new(layer.layer(svc)));
270
271 self.route_endpoint(&path, endpoint.clone())?;
272
273 self.route_endpoint(prefix, endpoint.clone())?;
277 if !prefix.ends_with('/') {
278 self.route_endpoint(&format!("{prefix}/"), endpoint)?;
280 }
281
282 Ok(())
283 }
284
285 pub(super) fn layer<L>(self, layer: L) -> PathRouter<S, IS_FALLBACK>
286 where
287 L: Layer<Route> + Clone + Send + Sync + 'static,
288 L::Service: Service<Request> + Clone + Send + Sync + 'static,
289 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
290 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
291 <L::Service as Service<Request>>::Future: Send + 'static,
292 {
293 let routes = self
294 .routes
295 .into_iter()
296 .map(|(id, endpoint)| {
297 let route = endpoint.layer(layer.clone());
298 (id, route)
299 })
300 .collect();
301
302 PathRouter {
303 routes,
304 node: self.node,
305 prev_route_id: self.prev_route_id,
306 v7_checks: self.v7_checks,
307 }
308 }
309
310 #[track_caller]
311 pub(super) fn route_layer<L>(self, layer: L) -> Self
312 where
313 L: Layer<Route> + Clone + Send + Sync + 'static,
314 L::Service: Service<Request> + Clone + Send + Sync + 'static,
315 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
316 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
317 <L::Service as Service<Request>>::Future: Send + 'static,
318 {
319 if self.routes.is_empty() {
320 panic!(
321 "Adding a route_layer before any routes is a no-op. \
322 Add the routes you want the layer to apply to first."
323 );
324 }
325
326 let routes = self
327 .routes
328 .into_iter()
329 .map(|(id, endpoint)| {
330 let route = endpoint.layer(layer.clone());
331 (id, route)
332 })
333 .collect();
334
335 PathRouter {
336 routes,
337 node: self.node,
338 prev_route_id: self.prev_route_id,
339 v7_checks: self.v7_checks,
340 }
341 }
342
343 pub(super) fn has_routes(&self) -> bool {
344 !self.routes.is_empty()
345 }
346
347 pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, IS_FALLBACK> {
348 let routes = self
349 .routes
350 .into_iter()
351 .map(|(id, endpoint)| {
352 let endpoint: Endpoint<S2> = match endpoint {
353 Endpoint::MethodRouter(method_router) => {
354 Endpoint::MethodRouter(method_router.with_state(state.clone()))
355 }
356 Endpoint::Route(route) => Endpoint::Route(route),
357 };
358 (id, endpoint)
359 })
360 .collect();
361
362 PathRouter {
363 routes,
364 node: self.node,
365 prev_route_id: self.prev_route_id,
366 v7_checks: self.v7_checks,
367 }
368 }
369
370 pub(super) fn call_with_state(
371 &self,
372 #[cfg_attr(not(feature = "original-uri"), allow(unused_mut))] mut req: Request,
373 state: S,
374 ) -> Result<RouteFuture<Infallible>, (Request, S)> {
375 #[cfg(feature = "original-uri")]
376 {
377 use crate::extract::OriginalUri;
378
379 if req.extensions().get::<OriginalUri>().is_none() {
380 let original_uri = OriginalUri(req.uri().clone());
381 req.extensions_mut().insert(original_uri);
382 }
383 }
384
385 let (mut parts, body) = req.into_parts();
386
387 match self.node.at(parts.uri.path()) {
388 Ok(match_) => {
389 let id = *match_.value;
390
391 if !IS_FALLBACK {
392 #[cfg(feature = "matched-path")]
393 crate::extract::matched_path::set_matched_path_for_request(
394 id,
395 &self.node.route_id_to_path,
396 &mut parts.extensions,
397 );
398 }
399
400 url_params::insert_url_params(&mut parts.extensions, match_.params);
401
402 let endpoint = self
403 .routes
404 .get(&id)
405 .expect("no route for id. This is a bug in axum. Please file an issue");
406
407 let req = Request::from_parts(parts, body);
408 match endpoint {
409 Endpoint::MethodRouter(method_router) => {
410 Ok(method_router.call_with_state(req, state))
411 }
412 Endpoint::Route(route) => Ok(route.clone().call_owned(req)),
413 }
414 }
415 Err(MatchError::NotFound) => Err((Request::from_parts(parts, body), state)),
418 }
419 }
420
421 pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S>) {
422 match self.node.at(path) {
423 Ok(match_) => {
424 let id = *match_.value;
425 self.routes.insert(id, endpoint);
426 }
427 Err(_) => self
428 .route_endpoint(path, endpoint)
429 .expect("path wasn't matched so endpoint shouldn't exist"),
430 }
431 }
432
433 fn next_route_id(&mut self) -> RouteId {
434 let next_id = self
435 .prev_route_id
436 .0
437 .checked_add(1)
438 .expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
439 self.prev_route_id = RouteId(next_id);
440 self.prev_route_id
441 }
442}
443
444impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
445 fn default() -> Self {
446 Self {
447 routes: Default::default(),
448 node: Default::default(),
449 prev_route_id: RouteId(0),
450 v7_checks: true,
451 }
452 }
453}
454
455impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
456 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
457 f.debug_struct("PathRouter")
458 .field("routes", &self.routes)
459 .field("node", &self.node)
460 .finish()
461 }
462}
463
464impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
465 fn clone(&self) -> Self {
466 Self {
467 routes: self.routes.clone(),
468 node: self.node.clone(),
469 prev_route_id: self.prev_route_id,
470 v7_checks: self.v7_checks,
471 }
472 }
473}
474
475#[derive(Clone, Default)]
477struct Node {
478 inner: matchit::Router<RouteId>,
479 route_id_to_path: HashMap<RouteId, Arc<str>>,
480 path_to_route_id: HashMap<Arc<str>, RouteId>,
481}
482
483impl Node {
484 fn insert(
485 &mut self,
486 path: impl Into<String>,
487 val: RouteId,
488 ) -> Result<(), matchit::InsertError> {
489 let path = path.into();
490
491 self.inner.insert(&path, val)?;
492
493 let shared_path: Arc<str> = path.into();
494 self.route_id_to_path.insert(val, shared_path.clone());
495 self.path_to_route_id.insert(shared_path, val);
496
497 Ok(())
498 }
499
500 fn at<'n, 'p>(
501 &'n self,
502 path: &'p str,
503 ) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
504 self.inner.at(path)
505 }
506}
507
508impl fmt::Debug for Node {
509 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
510 f.debug_struct("Node")
511 .field("paths", &self.route_id_to_path)
512 .finish()
513 }
514}
515
516#[track_caller]
517fn validate_nest_path(v7_checks: bool, path: &str) -> &str {
518 assert!(path.starts_with('/'));
519 assert!(path.len() > 1);
520
521 if path.split('/').any(|segment| {
522 segment.starts_with("{*") && segment.ends_with('}') && !segment.ends_with("}}")
523 }) {
524 panic!("Invalid route: nested routes cannot contain wildcards (*)");
525 }
526
527 if v7_checks {
528 validate_v07_paths(path).unwrap();
529 }
530
531 path
532}
533
/// Joins a nesting `prefix` and an inner route `path` into the full registered path.
///
/// The join avoids producing a doubled `/`:
/// - a prefix ending in `/` has all leading slashes of `path` removed before joining;
/// - an inner path of `/` maps to the prefix itself (borrowed, no allocation);
/// - otherwise the two are concatenated.
///
/// Both arguments are expected to start with `/` (checked in debug builds only).
pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> {
    debug_assert!(prefix.starts_with('/'));
    debug_assert!(path.starts_with('/'));

    match (prefix.ends_with('/'), path) {
        (true, _) => Cow::Owned(format!("{prefix}{}", path.trim_start_matches('/'))),
        (false, "/") => Cow::Borrowed(prefix),
        (false, _) => Cow::Owned(format!("{prefix}{path}")),
    }
}
545}