1use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
4#[cfg(feature = "tokio")]
5use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6use crate::{
7 body::{Body, HttpBody},
8 boxed::BoxedIntoRoute,
9 handler::Handler,
10 util::try_downcast,
11};
12use axum_core::{
13 extract::Request,
14 response::{IntoResponse, Response},
15};
16use std::{
17 convert::Infallible,
18 fmt,
19 marker::PhantomData,
20 sync::Arc,
21 task::{Context, Poll},
22};
23use tower_layer::Layer;
24use tower_service::Service;
25
26pub mod future;
27pub mod method_routing;
28
29mod into_make_service;
30mod method_filter;
31mod not_found;
32pub(crate) mod path_router;
33mod route;
34mod strip_prefix;
35pub(crate) mod url_params;
36
37#[cfg(test)]
38mod tests;
39
40pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route};
41
42pub use self::method_routing::{
43 any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service,
44 options, options_service, patch, patch_service, post, post_service, put, put_service, trace,
45 trace_service, MethodRouter,
46};
47
/// Unwraps a `Result`, panicking with the error's `Display` output on `Err`.
///
/// The expansion is a plain `match` with no closure, so the `panic!` sits
/// directly in the body of whichever function invokes the macro.
macro_rules! panic_on_err {
    ($result:expr) => {
        match $result {
            Err(error) => panic!("{error}"),
            Ok(value) => value,
        }
    };
}
56
/// Opaque numeric identifier for a route.
///
/// NOTE(review): presumably handed out by `PathRouter` to key its route
/// tables — confirm in `path_router`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct RouteId(u32);
59
/// The router type: a handle around shared routing tables.
///
/// All routing data lives behind an [`Arc`], so cloning a `Router` only clones
/// the `Arc`. The state type `S` defaults to `()`.
#[must_use]
pub struct Router<S = ()> {
    // Shared, immutable routing data; builder methods take it out via
    // `into_inner`, modify it, and wrap the result in a fresh `Arc`.
    inner: Arc<RouterInner<S>>,
}
65
66impl<S> Clone for Router<S> {
67 fn clone(&self) -> Self {
68 Self {
69 inner: Arc::clone(&self.inner),
70 }
71 }
72}
73
/// The data shared by every clone of a [`Router`].
struct RouterInner<S> {
    // Regular routes; consulted first by `Router::call_with_state`.
    path_router: PathRouter<S, false>,
    // Consulted second, when `path_router` hands the request back.
    fallback_router: PathRouter<S, true>,
    // `true` on a fresh router; set to `false` once `fallback_endpoint` installs
    // a fallback. `nest` and `merge` read it to decide whose fallback to keep.
    default_fallback: bool,
    // Last resort when neither router above handles the request.
    catch_all_fallback: Fallback<S>,
}
80
81impl<S> Default for Router<S>
82where
83 S: Clone + Send + Sync + 'static,
84{
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl<S> fmt::Debug for Router<S> {
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 f.debug_struct("Router")
93 .field("path_router", &self.inner.path_router)
94 .field("fallback_router", &self.inner.fallback_router)
95 .field("default_fallback", &self.inner.default_fallback)
96 .field("catch_all_fallback", &self.inner.catch_all_fallback)
97 .finish()
98 }
99}
100
// Private path-parameter names. The `__private__axum_` prefix keeps them out of
// the way of user-chosen parameter names.

/// Name of the private parameter used for nesting.
/// NOTE(review): presumably captures the rest of the path below a nested
/// prefix — confirm in `path_router`.
pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
/// `/*` followed by [`NEST_TAIL_PARAM`], i.e. the same name written as a
/// wildcard path segment.
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
/// Name of the private parameter used by fallback routes.
pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback";
/// `/*` followed by [`FALLBACK_PARAM`], i.e. the same name written as a
/// wildcard path segment.
pub(crate) const FALLBACK_PARAM_PATH: &str = "/*__private__axum_fallback";
105
106impl<S> Router<S>
107where
108 S: Clone + Send + Sync + 'static,
109{
110 pub fn new() -> Self {
115 Self {
116 inner: Arc::new(RouterInner {
117 path_router: Default::default(),
118 fallback_router: PathRouter::new_fallback(),
119 default_fallback: true,
120 catch_all_fallback: Fallback::Default(Route::new(NotFound)),
121 }),
122 }
123 }
124
125 fn map_inner<F, S2>(self, f: F) -> Router<S2>
126 where
127 F: FnOnce(RouterInner<S>) -> RouterInner<S2>,
128 {
129 Router {
130 inner: Arc::new(f(self.into_inner())),
131 }
132 }
133
134 fn tap_inner_mut<F>(self, f: F) -> Self
135 where
136 F: FnOnce(&mut RouterInner<S>),
137 {
138 let mut inner = self.into_inner();
139 f(&mut inner);
140 Router {
141 inner: Arc::new(inner),
142 }
143 }
144
145 fn into_inner(self) -> RouterInner<S> {
146 match Arc::try_unwrap(self.inner) {
147 Ok(inner) => inner,
148 Err(arc) => RouterInner {
149 path_router: arc.path_router.clone(),
150 fallback_router: arc.fallback_router.clone(),
151 default_fallback: arc.default_fallback,
152 catch_all_fallback: arc.catch_all_fallback.clone(),
153 },
154 }
155 }
156
157 #[doc = include_str!("../docs/routing/route.md")]
158 #[track_caller]
159 pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self {
160 self.tap_inner_mut(|this| {
161 panic_on_err!(this.path_router.route(path, method_router));
162 })
163 }
164
165 #[doc = include_str!("../docs/routing/route_service.md")]
166 pub fn route_service<T>(self, path: &str, service: T) -> Self
167 where
168 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
169 T::Response: IntoResponse,
170 T::Future: Send + 'static,
171 {
172 let service = match try_downcast::<Router<S>, _>(service) {
173 Ok(_) => {
174 panic!(
175 "Invalid route: `Router::route_service` cannot be used with `Router`s. \
176 Use `Router::nest` instead"
177 );
178 }
179 Err(service) => service,
180 };
181
182 self.tap_inner_mut(|this| {
183 panic_on_err!(this.path_router.route_service(path, service));
184 })
185 }
186
187 #[doc = include_str!("../docs/routing/nest.md")]
188 #[track_caller]
189 pub fn nest(self, path: &str, router: Router<S>) -> Self {
190 let RouterInner {
191 path_router,
192 fallback_router,
193 default_fallback,
194 catch_all_fallback: _,
198 } = router.into_inner();
199
200 self.tap_inner_mut(|this| {
201 panic_on_err!(this.path_router.nest(path, path_router));
202
203 if !default_fallback {
204 panic_on_err!(this.fallback_router.nest(path, fallback_router));
205 }
206 })
207 }
208
209 #[track_caller]
211 pub fn nest_service<T>(self, path: &str, service: T) -> Self
212 where
213 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
214 T::Response: IntoResponse,
215 T::Future: Send + 'static,
216 {
217 self.tap_inner_mut(|this| {
218 panic_on_err!(this.path_router.nest_service(path, service));
219 })
220 }
221
222 #[doc = include_str!("../docs/routing/merge.md")]
223 #[track_caller]
224 pub fn merge<R>(self, other: R) -> Self
225 where
226 R: Into<Router<S>>,
227 {
228 const PANIC_MSG: &str =
229 "Failed to merge fallbacks. This is a bug in axum. Please file an issue";
230
231 let other: Router<S> = other.into();
232 let RouterInner {
233 path_router,
234 fallback_router: mut other_fallback,
235 default_fallback,
236 catch_all_fallback,
237 } = other.into_inner();
238
239 self.map_inner(|mut this| {
240 panic_on_err!(this.path_router.merge(path_router));
241
242 match (this.default_fallback, default_fallback) {
243 (true, true) => {
246 this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
247 }
248 (true, false) => {
250 this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
251 this.default_fallback = false;
252 }
253 (false, true) => {
255 let fallback_router = std::mem::take(&mut this.fallback_router);
256 other_fallback.merge(fallback_router).expect(PANIC_MSG);
257 this.fallback_router = other_fallback;
258 }
259 (false, false) => {
261 panic!("Cannot merge two `Router`s that both have a fallback")
262 }
263 };
264
265 this.catch_all_fallback = this
266 .catch_all_fallback
267 .merge(catch_all_fallback)
268 .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback"));
269
270 this
271 })
272 }
273
274 #[doc = include_str!("../docs/routing/layer.md")]
275 pub fn layer<L>(self, layer: L) -> Router<S>
276 where
277 L: Layer<Route> + Clone + Send + 'static,
278 L::Service: Service<Request> + Clone + Send + 'static,
279 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
280 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
281 <L::Service as Service<Request>>::Future: Send + 'static,
282 {
283 self.map_inner(|this| RouterInner {
284 path_router: this.path_router.layer(layer.clone()),
285 fallback_router: this.fallback_router.layer(layer.clone()),
286 default_fallback: this.default_fallback,
287 catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)),
288 })
289 }
290
291 #[doc = include_str!("../docs/routing/route_layer.md")]
292 #[track_caller]
293 pub fn route_layer<L>(self, layer: L) -> Self
294 where
295 L: Layer<Route> + Clone + Send + 'static,
296 L::Service: Service<Request> + Clone + Send + 'static,
297 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
298 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
299 <L::Service as Service<Request>>::Future: Send + 'static,
300 {
301 self.map_inner(|this| RouterInner {
302 path_router: this.path_router.route_layer(layer),
303 fallback_router: this.fallback_router,
304 default_fallback: this.default_fallback,
305 catch_all_fallback: this.catch_all_fallback,
306 })
307 }
308
309 #[track_caller]
310 #[doc = include_str!("../docs/routing/fallback.md")]
311 pub fn fallback<H, T>(self, handler: H) -> Self
312 where
313 H: Handler<T, S>,
314 T: 'static,
315 {
316 self.tap_inner_mut(|this| {
317 this.catch_all_fallback =
318 Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
319 })
320 .fallback_endpoint(Endpoint::MethodRouter(any(handler)))
321 }
322
323 pub fn fallback_service<T>(self, service: T) -> Self
327 where
328 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
329 T::Response: IntoResponse,
330 T::Future: Send + 'static,
331 {
332 let route = Route::new(service);
333 self.tap_inner_mut(|this| {
334 this.catch_all_fallback = Fallback::Service(route.clone());
335 })
336 .fallback_endpoint(Endpoint::Route(route))
337 }
338
339 fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
340 self.tap_inner_mut(|this| {
341 this.fallback_router.set_fallback(endpoint);
342 this.default_fallback = false;
343 })
344 }
345
346 #[doc = include_str!("../docs/routing/with_state.md")]
347 pub fn with_state<S2>(self, state: S) -> Router<S2> {
348 self.map_inner(|this| RouterInner {
349 path_router: this.path_router.with_state(state.clone()),
350 fallback_router: this.fallback_router.with_state(state.clone()),
351 default_fallback: this.default_fallback,
352 catch_all_fallback: this.catch_all_fallback.with_state(state),
353 })
354 }
355
356 pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<Infallible> {
357 let (req, state) = match self.inner.path_router.call_with_state(req, state) {
358 Ok(future) => return future,
359 Err((req, state)) => (req, state),
360 };
361
362 let (req, state) = match self.inner.fallback_router.call_with_state(req, state) {
363 Ok(future) => return future,
364 Err((req, state)) => (req, state),
365 };
366
367 self.inner
368 .catch_all_fallback
369 .clone()
370 .call_with_state(req, state)
371 }
372
373 pub fn as_service<B>(&mut self) -> RouterAsService<'_, B, S> {
427 RouterAsService {
428 router: self,
429 _marker: PhantomData,
430 }
431 }
432
433 pub fn into_service<B>(self) -> RouterIntoService<B, S> {
439 RouterIntoService {
440 router: self,
441 _marker: PhantomData,
442 }
443 }
444}
445
446impl Router {
447 pub fn into_make_service(self) -> IntoMakeService<Self> {
466 IntoMakeService::new(self.with_state(()))
469 }
470
471 #[doc = include_str!("../docs/routing/into_make_service_with_connect_info.md")]
472 #[cfg(feature = "tokio")]
473 pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
474 IntoMakeServiceWithConnectInfo::new(self.with_state(()))
477 }
478}
479
// Lets a `Router<()>` act as a "make service": for every incoming stream it
// yields a router that then serves the requests on that stream.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve::IncomingStream;

    impl Service<IncomingStream<'_>> for Router<()> {
        type Response = Self;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        /// Always ready.
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Returns a clone of this router with `with_state(())` applied. The
        /// stream itself is not inspected.
        fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
            let router = self.clone().with_state(());
            std::future::ready(Ok(router))
        }
    }
};
501
502impl<B> Service<Request<B>> for Router<()>
503where
504 B: HttpBody<Data = bytes::Bytes> + Send + 'static,
505 B::Error: Into<axum_core::BoxError>,
506{
507 type Response = Response;
508 type Error = Infallible;
509 type Future = RouteFuture<Infallible>;
510
511 #[inline]
512 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
513 Poll::Ready(Ok(()))
514 }
515
516 #[inline]
517 fn call(&mut self, req: Request<B>) -> Self::Future {
518 let req = req.map(Body::new);
519 self.call_with_state(req, ())
520 }
521}
522
/// A mutably borrowed [`Router`] whose request body type is pinned to `B`.
///
/// Created by [`Router::as_service`].
pub struct RouterAsService<'a, B, S = ()> {
    router: &'a mut Router<S>,
    // `B` only appears in the `Service` impl; no value of type `B` is stored.
    _marker: PhantomData<B>,
}
530
531impl<'a, B> Service<Request<B>> for RouterAsService<'a, B, ()>
532where
533 B: HttpBody<Data = bytes::Bytes> + Send + 'static,
534 B::Error: Into<axum_core::BoxError>,
535{
536 type Response = Response;
537 type Error = Infallible;
538 type Future = RouteFuture<Infallible>;
539
540 #[inline]
541 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
542 <Router as Service<Request<B>>>::poll_ready(self.router, cx)
543 }
544
545 #[inline]
546 fn call(&mut self, req: Request<B>) -> Self::Future {
547 self.router.call(req)
548 }
549}
550
551impl<'a, B, S> fmt::Debug for RouterAsService<'a, B, S>
552where
553 S: fmt::Debug,
554{
555 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
556 f.debug_struct("RouterAsService")
557 .field("router", &self.router)
558 .finish()
559 }
560}
561
/// An owned [`Router`] whose request body type is pinned to `B`.
///
/// Created by [`Router::into_service`].
pub struct RouterIntoService<B, S = ()> {
    router: Router<S>,
    // `B` only appears in the `Service` impl; no value of type `B` is stored.
    _marker: PhantomData<B>,
}
569
570impl<B, S> Clone for RouterIntoService<B, S>
571where
572 Router<S>: Clone,
573{
574 fn clone(&self) -> Self {
575 Self {
576 router: self.router.clone(),
577 _marker: PhantomData,
578 }
579 }
580}
581
582impl<B> Service<Request<B>> for RouterIntoService<B, ()>
583where
584 B: HttpBody<Data = bytes::Bytes> + Send + 'static,
585 B::Error: Into<axum_core::BoxError>,
586{
587 type Response = Response;
588 type Error = Infallible;
589 type Future = RouteFuture<Infallible>;
590
591 #[inline]
592 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
593 <Router as Service<Request<B>>>::poll_ready(&mut self.router, cx)
594 }
595
596 #[inline]
597 fn call(&mut self, req: Request<B>) -> Self::Future {
598 self.router.call(req)
599 }
600}
601
602impl<B, S> fmt::Debug for RouterIntoService<B, S>
603where
604 S: fmt::Debug,
605{
606 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
607 f.debug_struct("RouterIntoService")
608 .field("router", &self.router)
609 .finish()
610 }
611}
612
/// The last-resort handler of a router.
enum Fallback<S, E = Infallible> {
    /// The fallback a router starts with (`Router::new` stores a `NotFound`
    /// route here). `merge` lets any other variant take precedence over it.
    Default(Route<E>),
    /// A ready-to-call route: set by `Router::fallback_service`, or produced
    /// from `BoxedHandler` once state is supplied via `with_state`.
    Service(Route<E>),
    /// A handler that still needs a state value of type `S` before it can be
    /// turned into a route; set by `Router::fallback`.
    BoxedHandler(BoxedIntoRoute<S, E>),
}
618
619impl<S, E> Fallback<S, E>
620where
621 S: Clone,
622{
623 fn merge(self, other: Self) -> Option<Self> {
624 match (self, other) {
625 (Self::Default(_), pick @ Self::Default(_)) => Some(pick),
626 (Self::Default(_), pick) | (pick, Self::Default(_)) => Some(pick),
627 _ => None,
628 }
629 }
630
631 fn map<F, E2>(self, f: F) -> Fallback<S, E2>
632 where
633 S: 'static,
634 E: 'static,
635 F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + 'static,
636 E2: 'static,
637 {
638 match self {
639 Self::Default(route) => Fallback::Default(f(route)),
640 Self::Service(route) => Fallback::Service(f(route)),
641 Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)),
642 }
643 }
644
645 fn with_state<S2>(self, state: S) -> Fallback<S2, E> {
646 match self {
647 Fallback::Default(route) => Fallback::Default(route),
648 Fallback::Service(route) => Fallback::Service(route),
649 Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)),
650 }
651 }
652
653 fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<E> {
654 match self {
655 Fallback::Default(route) | Fallback::Service(route) => {
656 RouteFuture::from_future(route.oneshot_inner(req))
657 }
658 Fallback::BoxedHandler(handler) => {
659 let mut route = handler.clone().into_route(state);
660 RouteFuture::from_future(route.oneshot_inner(req))
661 }
662 }
663 }
664}
665
666impl<S, E> Clone for Fallback<S, E> {
667 fn clone(&self) -> Self {
668 match self {
669 Self::Default(inner) => Self::Default(inner.clone()),
670 Self::Service(inner) => Self::Service(inner.clone()),
671 Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
672 }
673 }
674}
675
676impl<S, E> fmt::Debug for Fallback<S, E> {
677 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
678 match self {
679 Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
680 Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(),
681 Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
682 }
683 }
684}
685
/// What a path can resolve to: a method router or a plain route.
// NOTE(review): the size difference between the variants is accepted rather
// than boxing one of them — presumably to avoid an extra allocation per
// endpoint; confirm before changing.
#[allow(clippy::large_enum_variant)]
enum Endpoint<S> {
    /// A per-method dispatcher that may still need state of type `S`.
    MethodRouter(MethodRouter<S>),
    /// A route that is already independent of the state type.
    Route(Route),
}
691
692impl<S> Endpoint<S>
693where
694 S: Clone + Send + Sync + 'static,
695{
696 fn layer<L>(self, layer: L) -> Endpoint<S>
697 where
698 L: Layer<Route> + Clone + Send + 'static,
699 L::Service: Service<Request> + Clone + Send + 'static,
700 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
701 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
702 <L::Service as Service<Request>>::Future: Send + 'static,
703 {
704 match self {
705 Endpoint::MethodRouter(method_router) => {
706 Endpoint::MethodRouter(method_router.layer(layer))
707 }
708 Endpoint::Route(route) => Endpoint::Route(route.layer(layer)),
709 }
710 }
711}
712
713impl<S> Clone for Endpoint<S> {
714 fn clone(&self) -> Self {
715 match self {
716 Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()),
717 Self::Route(inner) => Self::Route(inner.clone()),
718 }
719 }
720}
721
722impl<S> fmt::Debug for Endpoint<S> {
723 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
724 match self {
725 Self::MethodRouter(method_router) => {
726 f.debug_tuple("MethodRouter").field(method_router).finish()
727 }
728 Self::Route(route) => f.debug_tuple("Route").field(route).finish(),
729 }
730 }
731}
732
// Compile-time check that the default `Router` (state `()`) can be shared
// between and sent across threads.
#[test]
fn traits() {
    use crate::test_helpers::*;
    assert_sync::<Router>();
    assert_send::<Router>();
}