1use super::{future::InfallibleRouteFuture, IntoMakeService};
4#[cfg(feature = "tokio")]
5use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6use crate::{
7 body::{Body, Bytes, HttpBody},
8 boxed::BoxedIntoRoute,
9 error_handling::{HandleError, HandleErrorLayer},
10 handler::Handler,
11 http::{Method, StatusCode},
12 response::Response,
13 routing::{future::RouteFuture, Fallback, MethodFilter, Route},
14};
15use axum_core::{extract::Request, response::IntoResponse, BoxError};
16use bytes::BytesMut;
17use std::{
18 convert::Infallible,
19 fmt,
20 task::{Context, Poll},
21};
22use tower::service_fn;
23use tower_layer::Layer;
24use tower_service::Service;
25
// Generates a free function (e.g. `get_service`) that creates a `MethodRouter` routing a
// single HTTP method to a service.
//
// Every forwarding arm must pass at least one attribute: the final arm requires
// `$(#[$m:meta])+`, so forwarding a bare `$name, GET` would re-match the `GET` arm and recurse
// until the macro recursion limit is reached.
macro_rules! top_level_service_fn {
    (
        $name:ident, GET
    ) => {
        top_level_service_fn!(
            /// Route `GET` requests to the given service.
            ///
            /// `HEAD` requests are also dispatched to the `GET` route unless a dedicated `HEAD`
            /// route has been added.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        top_level_service_fn!(
            /// Route `CONNECT` requests to the given service.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<T, S>(svc: T) -> MethodRouter<S, T::Error>
        where
            T: Service<Request> + Clone + Send + Sync + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            S: Clone,
        {
            on_service(MethodFilter::$method, svc)
        }
    };
}
103
// Generates a free function (e.g. `get`) that creates a `MethodRouter` routing a single HTTP
// method to a handler.
//
// Every forwarding arm must pass at least one attribute: the final arm requires
// `$(#[$m:meta])+`, so forwarding a bare `$name, GET` would re-match the `GET` arm and recurse
// until the macro recursion limit is reached.
macro_rules! top_level_handler_fn {
    (
        $name:ident, GET
    ) => {
        top_level_handler_fn!(
            /// Route `GET` requests to the given handler.
            ///
            /// `HEAD` requests are also dispatched to the `GET` route unless a dedicated `HEAD`
            /// route has been added.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        top_level_handler_fn!(
            /// Route `CONNECT` requests to the given handler.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
        where
            H: Handler<T, S>,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            on(MethodFilter::$method, handler)
        }
    };
}
174
// Generates a `MethodRouter` method (e.g. `get_service`) that chains a service for a single
// HTTP method onto an existing router.
//
// Every forwarding arm must pass at least one attribute: the final arm requires
// `$(#[$m:meta])+`, so forwarding a bare `$name, GET` would re-match the `GET` arm and recurse
// until the macro recursion limit is reached.
macro_rules! chained_service_fn {
    (
        $name:ident, GET
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `GET` requests.
            ///
            /// `HEAD` requests are also dispatched to the `GET` route unless a dedicated `HEAD`
            /// route has been added.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `CONNECT` requests.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<T>(self, svc: T) -> Self
        where
            T: Service<Request, Error = E>
                + Clone
                + Send
                + Sync
                + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            self.on_service(MethodFilter::$method, svc)
        }
    };
}
261
// Generates a `MethodRouter` method (e.g. `get`) that chains a handler for a single HTTP
// method onto an existing router.
//
// Every forwarding arm must pass at least one attribute: the final arm requires
// `$(#[$m:meta])+`, so forwarding a bare `$name, GET` would re-match the `GET` arm and recurse
// until the macro recursion limit is reached.
macro_rules! chained_handler_fn {
    (
        $name:ident, GET
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `GET` requests.
            ///
            /// `HEAD` requests are also dispatched to the `GET` route unless a dedicated `HEAD`
            /// route has been added.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `CONNECT` requests.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<H, T>(self, handler: H) -> Self
        where
            H: Handler<T, S>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            self.on(MethodFilter::$method, handler)
        }
    };
}
333
334top_level_service_fn!(connect_service, CONNECT);
335top_level_service_fn!(delete_service, DELETE);
336top_level_service_fn!(get_service, GET);
337top_level_service_fn!(head_service, HEAD);
338top_level_service_fn!(options_service, OPTIONS);
339top_level_service_fn!(patch_service, PATCH);
340top_level_service_fn!(post_service, POST);
341top_level_service_fn!(put_service, PUT);
342top_level_service_fn!(trace_service, TRACE);
343
344pub fn on_service<T, S>(filter: MethodFilter, svc: T) -> MethodRouter<S, T::Error>
368where
369 T: Service<Request> + Clone + Send + Sync + 'static,
370 T::Response: IntoResponse + 'static,
371 T::Future: Send + 'static,
372 S: Clone,
373{
374 MethodRouter::new().on_service(filter, svc)
375}
376
377pub fn any_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
427where
428 T: Service<Request> + Clone + Send + Sync + 'static,
429 T::Response: IntoResponse + 'static,
430 T::Future: Send + 'static,
431 S: Clone,
432{
433 MethodRouter::new()
434 .fallback_service(svc)
435 .skip_allow_header()
436}
437
438top_level_handler_fn!(connect, CONNECT);
439top_level_handler_fn!(delete, DELETE);
440top_level_handler_fn!(get, GET);
441top_level_handler_fn!(head, HEAD);
442top_level_handler_fn!(options, OPTIONS);
443top_level_handler_fn!(patch, PATCH);
444top_level_handler_fn!(post, POST);
445top_level_handler_fn!(put, PUT);
446top_level_handler_fn!(trace, TRACE);
447
448pub fn on<H, T, S>(filter: MethodFilter, handler: H) -> MethodRouter<S, Infallible>
466where
467 H: Handler<T, S>,
468 T: 'static,
469 S: Clone + Send + Sync + 'static,
470{
471 MethodRouter::new().on(filter, handler)
472}
473
474pub fn any<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
508where
509 H: Handler<T, S>,
510 T: 'static,
511 S: Clone + Send + Sync + 'static,
512{
513 MethodRouter::new().fallback(handler).skip_allow_header()
514}
515
/// A router that dispatches requests to an endpoint based on the request's HTTP method.
///
/// Each supported method has at most one endpoint. Requests whose method has no endpoint go
/// to `fallback`, together with the accumulated `allow_header` value.
#[must_use]
pub struct MethodRouter<S = (), E = Infallible> {
    // One endpoint slot per supported HTTP method; `MethodEndpoint::None` means "not routed".
    get: MethodEndpoint<S, E>,
    head: MethodEndpoint<S, E>,
    delete: MethodEndpoint<S, E>,
    options: MethodEndpoint<S, E>,
    patch: MethodEndpoint<S, E>,
    post: MethodEndpoint<S, E>,
    put: MethodEndpoint<S, E>,
    trace: MethodEndpoint<S, E>,
    connect: MethodEndpoint<S, E>,
    // Called when no method endpoint matched the request.
    fallback: Fallback<S, E>,
    // Methods registered so far, used for the `Allow` header on fallback responses.
    allow_header: AllowHeader,
}
559
/// Accumulated value for the `Allow` header passed to the fallback's response future.
#[derive(Clone, Debug)]
enum AllowHeader {
    /// No methods registered yet; an empty `Allow` value is passed to the fallback future.
    None,
    /// Don't pass any `Allow` value (used by `any` and `any_service`).
    Skip,
    /// Comma-separated list of the methods registered so far.
    Bytes(BytesMut),
}
569
570impl AllowHeader {
571 fn merge(self, other: Self) -> Self {
572 match (self, other) {
573 (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
574 (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
575 (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
576 (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
577 (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
578 a.extend_from_slice(b",");
579 a.extend_from_slice(&b);
580 AllowHeader::Bytes(a)
581 }
582 }
583 }
584}
585
586impl<S, E> fmt::Debug for MethodRouter<S, E> {
587 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
588 f.debug_struct("MethodRouter")
589 .field("get", &self.get)
590 .field("head", &self.head)
591 .field("delete", &self.delete)
592 .field("options", &self.options)
593 .field("patch", &self.patch)
594 .field("post", &self.post)
595 .field("put", &self.put)
596 .field("trace", &self.trace)
597 .field("connect", &self.connect)
598 .field("fallback", &self.fallback)
599 .field("allow_header", &self.allow_header)
600 .finish()
601 }
602}
603
604impl<S> MethodRouter<S, Infallible>
605where
606 S: Clone,
607{
608 #[track_caller]
630 pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
631 where
632 H: Handler<T, S>,
633 T: 'static,
634 S: Send + Sync + 'static,
635 {
636 self.on_endpoint(
637 filter,
638 MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
639 )
640 }
641
642 chained_handler_fn!(connect, CONNECT);
643 chained_handler_fn!(delete, DELETE);
644 chained_handler_fn!(get, GET);
645 chained_handler_fn!(head, HEAD);
646 chained_handler_fn!(options, OPTIONS);
647 chained_handler_fn!(patch, PATCH);
648 chained_handler_fn!(post, POST);
649 chained_handler_fn!(put, PUT);
650 chained_handler_fn!(trace, TRACE);
651
652 pub fn fallback<H, T>(mut self, handler: H) -> Self
654 where
655 H: Handler<T, S>,
656 T: 'static,
657 S: Send + Sync + 'static,
658 {
659 self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
660 self
661 }
662
663 pub(crate) fn default_fallback<H, T>(self, handler: H) -> Self
665 where
666 H: Handler<T, S>,
667 T: 'static,
668 S: Send + Sync + 'static,
669 {
670 match self.fallback {
671 Fallback::Default(_) => self.fallback(handler),
672 _ => self,
673 }
674 }
675}
676
677impl MethodRouter<(), Infallible> {
678 pub fn into_make_service(self) -> IntoMakeService<Self> {
706 IntoMakeService::new(self.with_state(()))
707 }
708
709 #[cfg(feature = "tokio")]
738 pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
739 IntoMakeServiceWithConnectInfo::new(self.with_state(()))
740 }
741}
742
743impl<S, E> MethodRouter<S, E>
744where
745 S: Clone,
746{
747 pub fn new() -> Self {
750 let fallback = Route::new(service_fn(|_: Request| async {
751 Ok(StatusCode::METHOD_NOT_ALLOWED)
752 }));
753
754 Self {
755 get: MethodEndpoint::None,
756 head: MethodEndpoint::None,
757 delete: MethodEndpoint::None,
758 options: MethodEndpoint::None,
759 patch: MethodEndpoint::None,
760 post: MethodEndpoint::None,
761 put: MethodEndpoint::None,
762 trace: MethodEndpoint::None,
763 connect: MethodEndpoint::None,
764 allow_header: AllowHeader::None,
765 fallback: Fallback::Default(fallback),
766 }
767 }
768
769 pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E> {
771 MethodRouter {
772 get: self.get.with_state(&state),
773 head: self.head.with_state(&state),
774 delete: self.delete.with_state(&state),
775 options: self.options.with_state(&state),
776 patch: self.patch.with_state(&state),
777 post: self.post.with_state(&state),
778 put: self.put.with_state(&state),
779 trace: self.trace.with_state(&state),
780 connect: self.connect.with_state(&state),
781 allow_header: self.allow_header,
782 fallback: self.fallback.with_state(state),
783 }
784 }
785
786 #[track_caller]
810 pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
811 where
812 T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
813 T::Response: IntoResponse + 'static,
814 T::Future: Send + 'static,
815 {
816 self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
817 }
818
819 #[track_caller]
820 fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, E>) -> Self {
821 #[track_caller]
823 fn set_endpoint<S, E>(
824 method_name: &str,
825 out: &mut MethodEndpoint<S, E>,
826 endpoint: &MethodEndpoint<S, E>,
827 endpoint_filter: MethodFilter,
828 filter: MethodFilter,
829 allow_header: &mut AllowHeader,
830 methods: &[&'static str],
831 ) where
832 MethodEndpoint<S, E>: Clone,
833 S: Clone,
834 {
835 if endpoint_filter.contains(filter) {
836 if out.is_some() {
837 panic!(
838 "Overlapping method route. Cannot add two method routes that both handle \
839 `{method_name}`",
840 )
841 }
842 *out = endpoint.clone();
843 for method in methods {
844 append_allow_header(allow_header, method);
845 }
846 }
847 }
848
849 set_endpoint(
850 "GET",
851 &mut self.get,
852 &endpoint,
853 filter,
854 MethodFilter::GET,
855 &mut self.allow_header,
856 &["GET", "HEAD"],
857 );
858
859 set_endpoint(
860 "HEAD",
861 &mut self.head,
862 &endpoint,
863 filter,
864 MethodFilter::HEAD,
865 &mut self.allow_header,
866 &["HEAD"],
867 );
868
869 set_endpoint(
870 "TRACE",
871 &mut self.trace,
872 &endpoint,
873 filter,
874 MethodFilter::TRACE,
875 &mut self.allow_header,
876 &["TRACE"],
877 );
878
879 set_endpoint(
880 "PUT",
881 &mut self.put,
882 &endpoint,
883 filter,
884 MethodFilter::PUT,
885 &mut self.allow_header,
886 &["PUT"],
887 );
888
889 set_endpoint(
890 "POST",
891 &mut self.post,
892 &endpoint,
893 filter,
894 MethodFilter::POST,
895 &mut self.allow_header,
896 &["POST"],
897 );
898
899 set_endpoint(
900 "PATCH",
901 &mut self.patch,
902 &endpoint,
903 filter,
904 MethodFilter::PATCH,
905 &mut self.allow_header,
906 &["PATCH"],
907 );
908
909 set_endpoint(
910 "OPTIONS",
911 &mut self.options,
912 &endpoint,
913 filter,
914 MethodFilter::OPTIONS,
915 &mut self.allow_header,
916 &["OPTIONS"],
917 );
918
919 set_endpoint(
920 "DELETE",
921 &mut self.delete,
922 &endpoint,
923 filter,
924 MethodFilter::DELETE,
925 &mut self.allow_header,
926 &["DELETE"],
927 );
928
929 set_endpoint(
930 "CONNECT",
931 &mut self.options,
932 &endpoint,
933 filter,
934 MethodFilter::CONNECT,
935 &mut self.allow_header,
936 &["CONNECT"],
937 );
938
939 self
940 }
941
942 chained_service_fn!(connect_service, CONNECT);
943 chained_service_fn!(delete_service, DELETE);
944 chained_service_fn!(get_service, GET);
945 chained_service_fn!(head_service, HEAD);
946 chained_service_fn!(options_service, OPTIONS);
947 chained_service_fn!(patch_service, PATCH);
948 chained_service_fn!(post_service, POST);
949 chained_service_fn!(put_service, PUT);
950 chained_service_fn!(trace_service, TRACE);
951
952 #[doc = include_str!("../docs/method_routing/fallback.md")]
953 pub fn fallback_service<T>(mut self, svc: T) -> Self
954 where
955 T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
956 T::Response: IntoResponse + 'static,
957 T::Future: Send + 'static,
958 {
959 self.fallback = Fallback::Service(Route::new(svc));
960 self
961 }
962
963 #[doc = include_str!("../docs/method_routing/layer.md")]
964 pub fn layer<L, NewError>(self, layer: L) -> MethodRouter<S, NewError>
965 where
966 L: Layer<Route<E>> + Clone + Send + Sync + 'static,
967 L::Service: Service<Request> + Clone + Send + Sync + 'static,
968 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
969 <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
970 <L::Service as Service<Request>>::Future: Send + 'static,
971 E: 'static,
972 S: 'static,
973 NewError: 'static,
974 {
975 let layer_fn = move |route: Route<E>| route.layer(layer.clone());
976
977 MethodRouter {
978 get: self.get.map(layer_fn.clone()),
979 head: self.head.map(layer_fn.clone()),
980 delete: self.delete.map(layer_fn.clone()),
981 options: self.options.map(layer_fn.clone()),
982 patch: self.patch.map(layer_fn.clone()),
983 post: self.post.map(layer_fn.clone()),
984 put: self.put.map(layer_fn.clone()),
985 trace: self.trace.map(layer_fn.clone()),
986 connect: self.connect.map(layer_fn.clone()),
987 fallback: self.fallback.map(layer_fn),
988 allow_header: self.allow_header,
989 }
990 }
991
992 #[doc = include_str!("../docs/method_routing/route_layer.md")]
993 #[track_caller]
994 pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, E>
995 where
996 L: Layer<Route<E>> + Clone + Send + Sync + 'static,
997 L::Service: Service<Request, Error = E> + Clone + Send + Sync + 'static,
998 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
999 <L::Service as Service<Request>>::Future: Send + 'static,
1000 E: 'static,
1001 S: 'static,
1002 {
1003 if self.get.is_none()
1004 && self.head.is_none()
1005 && self.delete.is_none()
1006 && self.options.is_none()
1007 && self.patch.is_none()
1008 && self.post.is_none()
1009 && self.put.is_none()
1010 && self.trace.is_none()
1011 && self.connect.is_none()
1012 {
1013 panic!(
1014 "Adding a route_layer before any routes is a no-op. \
1015 Add the routes you want the layer to apply to first."
1016 );
1017 }
1018
1019 let layer_fn = move |svc| Route::new(layer.layer(svc));
1020
1021 self.get = self.get.map(layer_fn.clone());
1022 self.head = self.head.map(layer_fn.clone());
1023 self.delete = self.delete.map(layer_fn.clone());
1024 self.options = self.options.map(layer_fn.clone());
1025 self.patch = self.patch.map(layer_fn.clone());
1026 self.post = self.post.map(layer_fn.clone());
1027 self.put = self.put.map(layer_fn.clone());
1028 self.trace = self.trace.map(layer_fn.clone());
1029 self.connect = self.connect.map(layer_fn);
1030
1031 self
1032 }
1033
1034 #[track_caller]
1035 pub(crate) fn merge_for_path(mut self, path: Option<&str>, other: MethodRouter<S, E>) -> Self {
1036 #[track_caller]
1038 fn merge_inner<S, E>(
1039 path: Option<&str>,
1040 name: &str,
1041 first: MethodEndpoint<S, E>,
1042 second: MethodEndpoint<S, E>,
1043 ) -> MethodEndpoint<S, E> {
1044 match (first, second) {
1045 (MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None,
1046 (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick,
1047 _ => {
1048 if let Some(path) = path {
1049 panic!(
1050 "Overlapping method route. Handler for `{name} {path}` already exists"
1051 );
1052 } else {
1053 panic!(
1054 "Overlapping method route. Cannot merge two method routes that both \
1055 define `{name}`"
1056 );
1057 }
1058 }
1059 }
1060 }
1061
1062 self.get = merge_inner(path, "GET", self.get, other.get);
1063 self.head = merge_inner(path, "HEAD", self.head, other.head);
1064 self.delete = merge_inner(path, "DELETE", self.delete, other.delete);
1065 self.options = merge_inner(path, "OPTIONS", self.options, other.options);
1066 self.patch = merge_inner(path, "PATCH", self.patch, other.patch);
1067 self.post = merge_inner(path, "POST", self.post, other.post);
1068 self.put = merge_inner(path, "PUT", self.put, other.put);
1069 self.trace = merge_inner(path, "TRACE", self.trace, other.trace);
1070 self.connect = merge_inner(path, "CONNECT", self.connect, other.connect);
1071
1072 self.fallback = self
1073 .fallback
1074 .merge(other.fallback)
1075 .expect("Cannot merge two `MethodRouter`s that both have a fallback");
1076
1077 self.allow_header = self.allow_header.merge(other.allow_header);
1078
1079 self
1080 }
1081
1082 #[doc = include_str!("../docs/method_routing/merge.md")]
1083 #[track_caller]
1084 pub fn merge(self, other: MethodRouter<S, E>) -> Self {
1085 self.merge_for_path(None, other)
1086 }
1087
1088 pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, Infallible>
1092 where
1093 F: Clone + Send + Sync + 'static,
1094 HandleError<Route<E>, F, T>: Service<Request, Error = Infallible>,
1095 <HandleError<Route<E>, F, T> as Service<Request>>::Future: Send,
1096 <HandleError<Route<E>, F, T> as Service<Request>>::Response: IntoResponse + Send,
1097 T: 'static,
1098 E: 'static,
1099 S: 'static,
1100 {
1101 self.layer(HandleErrorLayer::new(f))
1102 }
1103
1104 fn skip_allow_header(mut self) -> Self {
1105 self.allow_header = AllowHeader::Skip;
1106 self
1107 }
1108
1109 pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
1110 macro_rules! call {
1111 (
1112 $req:expr,
1113 $method_variant:ident,
1114 $svc:expr
1115 ) => {
1116 if *req.method() == Method::$method_variant {
1117 match $svc {
1118 MethodEndpoint::None => {}
1119 MethodEndpoint::Route(route) => {
1120 return route.clone().oneshot_inner_owned($req);
1121 }
1122 MethodEndpoint::BoxedHandler(handler) => {
1123 let route = handler.clone().into_route(state);
1124 return route.oneshot_inner_owned($req);
1125 }
1126 }
1127 }
1128 };
1129 }
1130
1131 let Self {
1133 get,
1134 head,
1135 delete,
1136 options,
1137 patch,
1138 post,
1139 put,
1140 trace,
1141 connect,
1142 fallback,
1143 allow_header,
1144 } = self;
1145
1146 call!(req, HEAD, head);
1147 call!(req, HEAD, get);
1148 call!(req, GET, get);
1149 call!(req, POST, post);
1150 call!(req, OPTIONS, options);
1151 call!(req, PATCH, patch);
1152 call!(req, PUT, put);
1153 call!(req, DELETE, delete);
1154 call!(req, TRACE, trace);
1155 call!(req, CONNECT, connect);
1156
1157 let future = fallback.clone().call_with_state(req, state);
1158
1159 match allow_header {
1160 AllowHeader::None => future.allow_header(Bytes::new()),
1161 AllowHeader::Skip => future,
1162 AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
1163 }
1164 }
1165}
1166
1167fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
1168 match allow_header {
1169 AllowHeader::None => {
1170 *allow_header = AllowHeader::Bytes(BytesMut::from(method));
1171 }
1172 AllowHeader::Skip => {}
1173 AllowHeader::Bytes(allow_header) => {
1174 if let Ok(s) = std::str::from_utf8(allow_header) {
1175 if !s.contains(method) {
1176 allow_header.extend_from_slice(b",");
1177 allow_header.extend_from_slice(method.as_bytes());
1178 }
1179 } else {
1180 #[cfg(debug_assertions)]
1181 panic!("`allow_header` contained invalid uft-8. This should never happen")
1182 }
1183 }
1184 }
1185}
1186
1187impl<S, E> Clone for MethodRouter<S, E> {
1188 fn clone(&self) -> Self {
1189 Self {
1190 get: self.get.clone(),
1191 head: self.head.clone(),
1192 delete: self.delete.clone(),
1193 options: self.options.clone(),
1194 patch: self.patch.clone(),
1195 post: self.post.clone(),
1196 put: self.put.clone(),
1197 trace: self.trace.clone(),
1198 connect: self.connect.clone(),
1199 fallback: self.fallback.clone(),
1200 allow_header: self.allow_header.clone(),
1201 }
1202 }
1203}
1204
1205impl<S, E> Default for MethodRouter<S, E>
1206where
1207 S: Clone,
1208{
1209 fn default() -> Self {
1210 Self::new()
1211 }
1212}
1213
/// The endpoint stored in one method slot of a `MethodRouter`.
enum MethodEndpoint<S, E> {
    /// Nothing is routed for this method.
    None,
    /// A concrete route (a service, or a handler that already received its state).
    Route(Route<E>),
    /// A handler that still needs state `S` before it can become a `Route`.
    BoxedHandler(BoxedIntoRoute<S, E>),
}
1219
1220impl<S, E> MethodEndpoint<S, E>
1221where
1222 S: Clone,
1223{
1224 fn is_some(&self) -> bool {
1225 matches!(self, Self::Route(_) | Self::BoxedHandler(_))
1226 }
1227
1228 fn is_none(&self) -> bool {
1229 matches!(self, Self::None)
1230 }
1231
1232 fn map<F, E2>(self, f: F) -> MethodEndpoint<S, E2>
1233 where
1234 S: 'static,
1235 E: 'static,
1236 F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + Sync + 'static,
1237 E2: 'static,
1238 {
1239 match self {
1240 Self::None => MethodEndpoint::None,
1241 Self::Route(route) => MethodEndpoint::Route(f(route)),
1242 Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
1243 }
1244 }
1245
1246 fn with_state<S2>(self, state: &S) -> MethodEndpoint<S2, E> {
1247 match self {
1248 MethodEndpoint::None => MethodEndpoint::None,
1249 MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
1250 MethodEndpoint::BoxedHandler(handler) => {
1251 MethodEndpoint::Route(handler.into_route(state.clone()))
1252 }
1253 }
1254 }
1255}
1256
1257impl<S, E> Clone for MethodEndpoint<S, E> {
1258 fn clone(&self) -> Self {
1259 match self {
1260 Self::None => Self::None,
1261 Self::Route(inner) => Self::Route(inner.clone()),
1262 Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
1263 }
1264 }
1265}
1266
1267impl<S, E> fmt::Debug for MethodEndpoint<S, E> {
1268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1269 match self {
1270 Self::None => f.debug_tuple("None").finish(),
1271 Self::Route(inner) => inner.fmt(f),
1272 Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
1273 }
1274 }
1275}
1276
1277impl<B, E> Service<Request<B>> for MethodRouter<(), E>
1278where
1279 B: HttpBody<Data = Bytes> + Send + 'static,
1280 B::Error: Into<BoxError>,
1281{
1282 type Response = Response;
1283 type Error = E;
1284 type Future = RouteFuture<E>;
1285
1286 #[inline]
1287 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1288 Poll::Ready(Ok(()))
1289 }
1290
1291 #[inline]
1292 fn call(&mut self, req: Request<B>) -> Self::Future {
1293 let req = req.map(Body::new);
1294 self.call_with_state(req, ())
1295 }
1296}
1297
1298impl<S> Handler<(), S> for MethodRouter<S>
1299where
1300 S: Clone + 'static,
1301{
1302 type Future = InfallibleRouteFuture;
1303
1304 fn call(self, req: Request, state: S) -> Self::Future {
1305 InfallibleRouteFuture::new(self.call_with_state(req, state))
1306 }
1307}
1308
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
// Lets a `MethodRouter<()>` be passed directly to `serve`: each incoming connection receives a
// clone of the router.
const _: () = {
    use crate::serve;

    impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
    where
        L: serve::Listener,
    {
        type Response = Self;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
            let per_connection = self.clone().with_state(());
            std::future::ready(Ok(per_connection))
        }
    }
};
1331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract::State, handler::HandlerWithoutStateExt};
    use http::{header::ALLOW, HeaderMap};
    use http_body_util::BodyExt;
    use std::time::Duration;
    use tower::ServiceExt;
    use tower_http::{
        services::fs::ServeDir, timeout::TimeoutLayer, validate_request::ValidateRequestHeaderLayer,
    };

    #[crate::test]
    async fn method_not_allowed_by_default() {
        let mut svc = MethodRouter::new();
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn get_service_fn() {
        async fn handle(_req: Request) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from("ok")))
        }

        let mut svc = get_service(service_fn(handle));

        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_handler() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_accepts_head() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn head_takes_precedence_over_get() {
        let mut svc = MethodRouter::new().head(created).get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // `call` sends no `Authorization` header, so the layer rejects the request before the
        // (never-completing) handler runs.
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // `layer` also wraps the fallback.
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // `call` sends no `Authorization` header, so the layer rejects the request before the
        // (never-completing) handler runs.
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // `route_layer` leaves the fallback untouched.
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    // Only compiled to check that this builder chain type-checks; it is never run.
    #[allow(dead_code)]
    async fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(TimeoutLayer::new(Duration::from_secs(10))),
        );

        let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
        crate::serve(listener, app).await.unwrap();
    }

    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,POST");
    }

    #[crate::test]
    async fn allow_header_noop_middleware() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .layer(tower::layer::util::Identity::new());

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
    )]
    async fn handler_overlaps() {
        let _: MethodRouter<()> = get(ok).get(ok);
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
    )]
    async fn service_overlaps() {
        let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
    }

    #[crate::test]
    async fn get_head_does_not_overlap() {
        let _: MethodRouter<()> = get(ok).head(ok);
    }

    #[crate::test]
    async fn head_get_does_not_overlap() {
        let _: MethodRouter<()> = head(ok).get(ok);
    }

    #[crate::test]
    async fn accessing_state() {
        let mut svc = MethodRouter::new()
            .get(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn fallback_accessing_state() {
        let mut svc = MethodRouter::new()
            .fallback(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn merge_accessing_state() {
        let one = get(|State(state): State<&'static str>| async move { state });
        let two = post(|State(state): State<&'static str>| async move { state });

        let mut svc = one.merge(two).with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");

        // Capture the POST body too; previously the stale GET `text` was re-asserted, so the
        // POST handler's output was never checked.
        let (status, _, text) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    /// Send an empty-bodied request with `method` to `/` and return the status, headers and
    /// body text of the response.
    async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
    where
        S: Service<Request, Error = Infallible>,
        S::Response: IntoResponse,
    {
        let request = Request::builder()
            .uri("/")
            .method(method)
            .body(Body::empty())
            .unwrap();
        let response = svc
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_response();
        let (parts, body) = response.into_parts();
        let body =
            String::from_utf8(BodyExt::collect(body).await.unwrap().to_bytes().to_vec()).unwrap();
        (parts.status, parts.headers, body)
    }

    async fn ok() -> (StatusCode, &'static str) {
        (StatusCode::OK, "ok")
    }

    async fn created() -> (StatusCode, &'static str) {
        (StatusCode::CREATED, "created")
    }
}