axum/routing/
route.rs

1use crate::{
2    body::{Body, HttpBody},
3    response::Response,
4    util::AxumMutex,
5};
6use axum_core::{extract::Request, response::IntoResponse};
7use bytes::Bytes;
8use http::{
9    header::{self, CONTENT_LENGTH},
10    HeaderMap, HeaderValue,
11};
12use pin_project_lite::pin_project;
13use std::{
14    convert::Infallible,
15    fmt,
16    future::Future,
17    pin::Pin,
18    task::{Context, Poll},
19};
20use tower::{
21    util::{BoxCloneService, MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot},
22    ServiceExt,
23};
24use tower_layer::Layer;
25use tower_service::Service;
26
27/// How routes are stored inside a [`Router`](super::Router).
28///
29/// You normally shouldn't need to care about this type. It's used in
30/// [`Router::layer`](super::Router::layer).
31pub struct Route<E = Infallible>(AxumMutex<BoxCloneService<Request, Response, E>>);
32
33impl<E> Route<E> {
34    pub(crate) fn new<T>(svc: T) -> Self
35    where
36        T: Service<Request, Error = E> + Clone + Send + 'static,
37        T::Response: IntoResponse + 'static,
38        T::Future: Send + 'static,
39    {
40        Self(AxumMutex::new(BoxCloneService::new(
41            svc.map_response(IntoResponse::into_response),
42        )))
43    }
44
45    pub(crate) fn oneshot_inner(
46        &mut self,
47        req: Request,
48    ) -> Oneshot<BoxCloneService<Request, Response, E>, Request> {
49        self.0.get_mut().unwrap().clone().oneshot(req)
50    }
51
52    pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
53    where
54        L: Layer<Route<E>> + Clone + Send + 'static,
55        L::Service: Service<Request> + Clone + Send + 'static,
56        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
57        <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
58        <L::Service as Service<Request>>::Future: Send + 'static,
59        NewError: 'static,
60    {
61        let layer = (
62            MapRequestLayer::new(|req: Request<_>| req.map(Body::new)),
63            MapErrLayer::new(Into::into),
64            MapResponseLayer::new(IntoResponse::into_response),
65            layer,
66        );
67
68        Route::new(layer.layer(self))
69    }
70}
71
72impl<E> Clone for Route<E> {
73    #[track_caller]
74    fn clone(&self) -> Self {
75        Self(AxumMutex::new(self.0.lock().unwrap().clone()))
76    }
77}
78
79impl<E> fmt::Debug for Route<E> {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        f.debug_struct("Route").finish()
82    }
83}
84
85impl<B, E> Service<Request<B>> for Route<E>
86where
87    B: HttpBody<Data = bytes::Bytes> + Send + 'static,
88    B::Error: Into<axum_core::BoxError>,
89{
90    type Response = Response;
91    type Error = E;
92    type Future = RouteFuture<E>;
93
94    #[inline]
95    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
96        Poll::Ready(Ok(()))
97    }
98
99    #[inline]
100    fn call(&mut self, req: Request<B>) -> Self::Future {
101        let req = req.map(Body::new);
102        RouteFuture::from_future(self.oneshot_inner(req))
103    }
104}
105
106pin_project! {
107    /// Response future for [`Route`].
108    pub struct RouteFuture<E> {
109        #[pin]
110        kind: RouteFutureKind<E>,
111        strip_body: bool,
112        allow_header: Option<Bytes>,
113    }
114}
115
116pin_project! {
117    #[project = RouteFutureKindProj]
118    enum RouteFutureKind<E> {
119        Future {
120            #[pin]
121            future: Oneshot<
122                BoxCloneService<Request, Response, E>,
123                Request,
124            >,
125        },
126        Response {
127            response: Option<Response>,
128        }
129    }
130}
131
132impl<E> RouteFuture<E> {
133    pub(crate) fn from_future(
134        future: Oneshot<BoxCloneService<Request, Response, E>, Request>,
135    ) -> Self {
136        Self {
137            kind: RouteFutureKind::Future { future },
138            strip_body: false,
139            allow_header: None,
140        }
141    }
142
143    pub(crate) fn strip_body(mut self, strip_body: bool) -> Self {
144        self.strip_body = strip_body;
145        self
146    }
147
148    pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
149        self.allow_header = Some(allow_header);
150        self
151    }
152}
153
154impl<E> Future for RouteFuture<E> {
155    type Output = Result<Response, E>;
156
157    #[inline]
158    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159        let this = self.project();
160
161        let mut res = match this.kind.project() {
162            RouteFutureKindProj::Future { future } => match future.poll(cx) {
163                Poll::Ready(Ok(res)) => res,
164                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
165                Poll::Pending => return Poll::Pending,
166            },
167            RouteFutureKindProj::Response { response } => {
168                response.take().expect("future polled after completion")
169            }
170        };
171
172        set_allow_header(res.headers_mut(), this.allow_header);
173
174        // make sure to set content-length before removing the body
175        set_content_length(res.size_hint(), res.headers_mut());
176
177        let res = if *this.strip_body {
178            res.map(|_| Body::empty())
179        } else {
180            res
181        };
182
183        Poll::Ready(Ok(res))
184    }
185}
186
187fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
188    match allow_header.take() {
189        Some(allow_header) if !headers.contains_key(header::ALLOW) => {
190            headers.insert(
191                header::ALLOW,
192                HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
193            );
194        }
195        _ => {}
196    }
197}
198
199fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
200    if headers.contains_key(CONTENT_LENGTH) {
201        return;
202    }
203
204    if let Some(size) = size_hint.exact() {
205        let header_value = if size == 0 {
206            #[allow(clippy::declare_interior_mutable_const)]
207            const ZERO: HeaderValue = HeaderValue::from_static("0");
208
209            ZERO
210        } else {
211            let mut buffer = itoa::Buffer::new();
212            HeaderValue::from_str(buffer.format(size)).unwrap()
213        };
214
215        headers.insert(CONTENT_LENGTH, header_value);
216    }
217}
218
219pin_project! {
220    /// A [`RouteFuture`] that always yields a [`Response`].
221    pub struct InfallibleRouteFuture {
222        #[pin]
223        future: RouteFuture<Infallible>,
224    }
225}
226
227impl InfallibleRouteFuture {
228    pub(crate) fn new(future: RouteFuture<Infallible>) -> Self {
229        Self { future }
230    }
231}
232
233impl Future for InfallibleRouteFuture {
234    type Output = Response;
235
236    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237        match futures_util::ready!(self.project().future.poll(cx)) {
238            Ok(response) => Poll::Ready(response),
239            Err(err) => match err {},
240        }
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn traits() {
        use crate::test_helpers::*;

        // Defined locally so this check does not depend on what `test_helpers` exports.
        fn assert_sync<T: Sync>() {}

        assert_send::<Route<()>>();
        // `Route` wraps its `BoxCloneService` in `AxumMutex`; pin down the `Sync`
        // guarantee that wrapper is presumably there to provide.
        assert_sync::<Route<()>>();
    }
}