1use crate::{
2 body::{Body, HttpBody},
3 response::Response,
4 util::AxumMutex,
5};
6use axum_core::{extract::Request, response::IntoResponse};
7use bytes::Bytes;
8use http::{
9 header::{self, CONTENT_LENGTH},
10 HeaderMap, HeaderValue,
11};
12use pin_project_lite::pin_project;
13use std::{
14 convert::Infallible,
15 fmt,
16 future::Future,
17 pin::Pin,
18 task::{Context, Poll},
19};
20use tower::{
21 util::{BoxCloneService, MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot},
22 ServiceExt,
23};
24use tower_layer::Layer;
25use tower_service::Service;
26
/// A type-erased, cloneable service producing [`Response`]s, used to store a route.
///
/// The boxed service is held inside an `AxumMutex`. In this file the mutex is only locked
/// when a `Route` is cloned; calls go through `&mut self` and reach the service via
/// `get_mut`, without locking.
///
/// NOTE(review): presumably the mutex exists so that `Route` is `Sync` even though
/// `BoxCloneService` is only `Send` — confirm.
pub struct Route<E = Infallible>(AxumMutex<BoxCloneService<Request, Response, E>>);
32
33impl<E> Route<E> {
34 pub(crate) fn new<T>(svc: T) -> Self
35 where
36 T: Service<Request, Error = E> + Clone + Send + 'static,
37 T::Response: IntoResponse + 'static,
38 T::Future: Send + 'static,
39 {
40 Self(AxumMutex::new(BoxCloneService::new(
41 svc.map_response(IntoResponse::into_response),
42 )))
43 }
44
45 pub(crate) fn oneshot_inner(
46 &mut self,
47 req: Request,
48 ) -> Oneshot<BoxCloneService<Request, Response, E>, Request> {
49 self.0.get_mut().unwrap().clone().oneshot(req)
50 }
51
52 pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
53 where
54 L: Layer<Route<E>> + Clone + Send + 'static,
55 L::Service: Service<Request> + Clone + Send + 'static,
56 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
57 <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
58 <L::Service as Service<Request>>::Future: Send + 'static,
59 NewError: 'static,
60 {
61 let layer = (
62 MapRequestLayer::new(|req: Request<_>| req.map(Body::new)),
63 MapErrLayer::new(Into::into),
64 MapResponseLayer::new(IntoResponse::into_response),
65 layer,
66 );
67
68 Route::new(layer.layer(self))
69 }
70}
71
72impl<E> Clone for Route<E> {
73 #[track_caller]
74 fn clone(&self) -> Self {
75 Self(AxumMutex::new(self.0.lock().unwrap().clone()))
76 }
77}
78
79impl<E> fmt::Debug for Route<E> {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 f.debug_struct("Route").finish()
82 }
83}
84
85impl<B, E> Service<Request<B>> for Route<E>
86where
87 B: HttpBody<Data = bytes::Bytes> + Send + 'static,
88 B::Error: Into<axum_core::BoxError>,
89{
90 type Response = Response;
91 type Error = E;
92 type Future = RouteFuture<E>;
93
94 #[inline]
95 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
96 Poll::Ready(Ok(()))
97 }
98
99 #[inline]
100 fn call(&mut self, req: Request<B>) -> Self::Future {
101 let req = req.map(Body::new);
102 RouteFuture::from_future(self.oneshot_inner(req))
103 }
104}
105
pin_project! {
    /// Response future for [`Route`].
    ///
    /// Resolves to the inner service's response after post-processing it: an `Allow`
    /// header and a `Content-Length` header may be filled in, and the body may be
    /// replaced with an empty one.
    pub struct RouteFuture<E> {
        // What is being awaited: an in-flight service call or an already-available response.
        #[pin]
        kind: RouteFutureKind<E>,
        // When true, the body is replaced with `Body::empty()` after `Content-Length` is set.
        strip_body: bool,
        // `Allow` header value to insert if the response does not already carry one.
        allow_header: Option<Bytes>,
    }
}
115
pin_project! {
    // The source of the response a `RouteFuture` yields.
    #[project = RouteFutureKindProj]
    enum RouteFutureKind<E> {
        // A one-shot call to the boxed route service, still being driven to completion.
        Future {
            #[pin]
            future: Oneshot<
                BoxCloneService<Request, Response, E>,
                Request,
            >,
        },
        // A response that is already available. It is taken out on the first poll;
        // polling again panics. NOTE(review): nothing in this file appears to construct
        // this variant — confirm whether it is still needed.
        Response {
            response: Option<Response>,
        }
    }
}
131
132impl<E> RouteFuture<E> {
133 pub(crate) fn from_future(
134 future: Oneshot<BoxCloneService<Request, Response, E>, Request>,
135 ) -> Self {
136 Self {
137 kind: RouteFutureKind::Future { future },
138 strip_body: false,
139 allow_header: None,
140 }
141 }
142
143 pub(crate) fn strip_body(mut self, strip_body: bool) -> Self {
144 self.strip_body = strip_body;
145 self
146 }
147
148 pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
149 self.allow_header = Some(allow_header);
150 self
151 }
152}
153
154impl<E> Future for RouteFuture<E> {
155 type Output = Result<Response, E>;
156
157 #[inline]
158 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159 let this = self.project();
160
161 let mut res = match this.kind.project() {
162 RouteFutureKindProj::Future { future } => match future.poll(cx) {
163 Poll::Ready(Ok(res)) => res,
164 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
165 Poll::Pending => return Poll::Pending,
166 },
167 RouteFutureKindProj::Response { response } => {
168 response.take().expect("future polled after completion")
169 }
170 };
171
172 set_allow_header(res.headers_mut(), this.allow_header);
173
174 set_content_length(res.size_hint(), res.headers_mut());
176
177 let res = if *this.strip_body {
178 res.map(|_| Body::empty())
179 } else {
180 res
181 };
182
183 Poll::Ready(Ok(res))
184 }
185}
186
187fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
188 match allow_header.take() {
189 Some(allow_header) if !headers.contains_key(header::ALLOW) => {
190 headers.insert(
191 header::ALLOW,
192 HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
193 );
194 }
195 _ => {}
196 }
197}
198
199fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
200 if headers.contains_key(CONTENT_LENGTH) {
201 return;
202 }
203
204 if let Some(size) = size_hint.exact() {
205 let header_value = if size == 0 {
206 #[allow(clippy::declare_interior_mutable_const)]
207 const ZERO: HeaderValue = HeaderValue::from_static("0");
208
209 ZERO
210 } else {
211 let mut buffer = itoa::Buffer::new();
212 HeaderValue::from_str(buffer.format(size)).unwrap()
213 };
214
215 headers.insert(CONTENT_LENGTH, header_value);
216 }
217}
218
pin_project! {
    /// A [`RouteFuture`] whose error type is `Infallible`, so it resolves directly to a
    /// [`Response`].
    pub struct InfallibleRouteFuture {
        // The wrapped route future; its `Err` case can never occur.
        #[pin]
        future: RouteFuture<Infallible>,
    }
}
226
227impl InfallibleRouteFuture {
228 pub(crate) fn new(future: RouteFuture<Infallible>) -> Self {
229 Self { future }
230 }
231}
232
233impl Future for InfallibleRouteFuture {
234 type Output = Response;
235
236 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237 match futures_util::ready!(self.project().future.poll(cx)) {
238 Ok(response) => Poll::Ready(response),
239 Err(err) => match err {},
240 }
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time checks of the auto traits that routes and their futures must implement.
    #[test]
    fn traits() {
        use crate::test_helpers::*;

        assert_send::<Route<()>>();
        // The mutex around the boxed service is what makes `Route` shareable across
        // threads; catch any change that would silently drop `Sync`.
        assert_sync::<Route<()>>();
        // The response future must be `Send` as well, not just the service.
        assert_send::<RouteFuture<()>>();
    }
}