axum/middleware/
map_request.rs

1use crate::body::{Body, Bytes, HttpBody};
2use crate::response::{IntoResponse, Response};
3use crate::BoxError;
4use axum_core::extract::{FromRequest, FromRequestParts};
5use futures_util::future::BoxFuture;
6use http::Request;
7use std::{
8    any::type_name,
9    convert::Infallible,
10    fmt,
11    future::Future,
12    marker::PhantomData,
13    pin::Pin,
14    task::{Context, Poll},
15};
16use tower_layer::Layer;
17use tower_service::Service;
18
19/// Create a middleware from an async function that transforms a request.
20///
21/// This differs from [`tower::util::MapRequest`] in that it allows you to easily run axum-specific
22/// extractors.
23///
24/// # Example
25///
26/// ```
27/// use axum::{
28///     Router,
29///     routing::get,
30///     middleware::map_request,
31///     http::Request,
32/// };
33///
34/// async fn set_header<B>(mut request: Request<B>) -> Request<B> {
35///     request.headers_mut().insert("x-foo", "foo".parse().unwrap());
36///     request
37/// }
38///
39/// async fn handler<B>(request: Request<B>) {
40///     // `request` will have an `x-foo` header
41/// }
42///
43/// let app = Router::new()
44///     .route("/", get(handler))
45///     .layer(map_request(set_header));
46/// # let _: Router = app;
47/// ```
48///
49/// # Rejecting the request
50///
51/// The function given to `map_request` is allowed to also return a `Result` which can be used to
52/// reject the request and return a response immediately, without calling the remaining
53/// middleware.
54///
55/// Specifically the valid return types are:
56///
57/// - `Request<B>`
58/// - `Result<Request<B>, E> where E:  IntoResponse`
59///
60/// ```
61/// use axum::{
62///     Router,
63///     http::{Request, StatusCode},
64///     routing::get,
65///     middleware::map_request,
66/// };
67///
68/// async fn auth<B>(request: Request<B>) -> Result<Request<B>, StatusCode> {
69///     let auth_header = request.headers()
70///         .get(http::header::AUTHORIZATION)
71///         .and_then(|header| header.to_str().ok());
72///
73///     match auth_header {
74///         Some(auth_header) if token_is_valid(auth_header) => Ok(request),
75///         _ => Err(StatusCode::UNAUTHORIZED),
76///     }
77/// }
78///
79/// fn token_is_valid(token: &str) -> bool {
80///     // ...
81///     # false
82/// }
83///
84/// let app = Router::new()
85///     .route("/", get(|| async { /* ... */ }))
86///     .route_layer(map_request(auth));
87/// # let app: Router = app;
88/// ```
89///
90/// # Running extractors
91///
92/// ```
93/// use axum::{
94///     Router,
95///     routing::get,
96///     middleware::map_request,
97///     extract::Path,
98///     http::Request,
99/// };
100/// use std::collections::HashMap;
101///
102/// async fn log_path_params<B>(
103///     Path(path_params): Path<HashMap<String, String>>,
104///     request: Request<B>,
105/// ) -> Request<B> {
106///     tracing::debug!(?path_params);
107///     request
108/// }
109///
110/// let app = Router::new()
111///     .route("/", get(|| async { /* ... */ }))
112///     .layer(map_request(log_path_params));
113/// # let _: Router = app;
114/// ```
115///
116/// Note that to access state you must use either [`map_request_with_state`].
117pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T> {
118    map_request_with_state((), f)
119}
120
121/// Create a middleware from an async function that transforms a request, with the given state.
122///
123/// See [`State`](crate::extract::State) for more details about accessing state.
124///
125/// # Example
126///
127/// ```rust
128/// use axum::{
129///     Router,
130///     http::{Request, StatusCode},
131///     routing::get,
132///     response::IntoResponse,
133///     middleware::map_request_with_state,
134///     extract::State,
135/// };
136///
137/// #[derive(Clone)]
138/// struct AppState { /* ... */ }
139///
140/// async fn my_middleware<B>(
141///     State(state): State<AppState>,
142///     // you can add more extractors here but the last
143///     // extractor must implement `FromRequest` which
144///     // `Request` does
145///     request: Request<B>,
146/// ) -> Request<B> {
147///     // do something with `state` and `request`...
148///     request
149/// }
150///
151/// let state = AppState { /* ... */ };
152///
153/// let app = Router::new()
154///     .route("/", get(|| async { /* ... */ }))
155///     .route_layer(map_request_with_state(state.clone(), my_middleware))
156///     .with_state(state);
157/// # let _: axum::Router = app;
158/// ```
159pub fn map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T> {
160    MapRequestLayer {
161        f,
162        state,
163        _extractor: PhantomData,
164    }
165}
166
/// A [`tower::Layer`] that wraps services in a [`MapRequest`], running an async function over
/// each request before it reaches the wrapped service.
///
/// Created with [`map_request`] or [`map_request_with_state`]. See [`map_request`] for more
/// details.
#[must_use]
pub struct MapRequestLayer<F, S, T> {
    // The async function applied to every request.
    f: F,
    // State made available to the extractors that run before `f`.
    state: S,
    // Records the extractor tuple type without storing one; `fn() -> T` keeps this marker
    // `Send + Sync` whatever `T` is.
    _extractor: PhantomData<fn() -> T>,
}
176
177impl<F, S, T> Clone for MapRequestLayer<F, S, T>
178where
179    F: Clone,
180    S: Clone,
181{
182    fn clone(&self) -> Self {
183        Self {
184            f: self.f.clone(),
185            state: self.state.clone(),
186            _extractor: self._extractor,
187        }
188    }
189}
190
191impl<S, I, F, T> Layer<I> for MapRequestLayer<F, S, T>
192where
193    F: Clone,
194    S: Clone,
195{
196    type Service = MapRequest<F, S, I, T>;
197
198    fn layer(&self, inner: I) -> Self::Service {
199        MapRequest {
200            f: self.f.clone(),
201            state: self.state.clone(),
202            inner,
203            _extractor: PhantomData,
204        }
205    }
206}
207
208impl<F, S, T> fmt::Debug for MapRequestLayer<F, S, T>
209where
210    S: fmt::Debug,
211{
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        f.debug_struct("MapRequestLayer")
214            // Write out the type name, without quoting it as `&type_name::<F>()` would
215            .field("f", &format_args!("{}", type_name::<F>()))
216            .field("state", &self.state)
217            .finish()
218    }
219}
220
/// A middleware that runs an async function over each request before handing the result to the
/// wrapped service.
///
/// Produced by [`MapRequestLayer`], which is created with [`map_request`]. See that function for
/// more details.
pub struct MapRequest<F, S, I, T> {
    // The async function applied to every request.
    f: F,
    // The wrapped service that receives requests the function lets through.
    inner: I,
    // State made available to the extractors that run before `f`.
    state: S,
    // Records the extractor tuple type without storing one.
    _extractor: PhantomData<fn() -> T>,
}
230
231impl<F, S, I, T> Clone for MapRequest<F, S, I, T>
232where
233    F: Clone,
234    I: Clone,
235    S: Clone,
236{
237    fn clone(&self) -> Self {
238        Self {
239            f: self.f.clone(),
240            inner: self.inner.clone(),
241            state: self.state.clone(),
242            _extractor: self._extractor,
243        }
244    }
245}
246
247macro_rules! impl_service {
248    (
249        [$($ty:ident),*], $last:ident
250    ) => {
251        #[allow(non_snake_case, unused_mut)]
252        impl<F, Fut, S, I, B, $($ty,)* $last> Service<Request<B>> for MapRequest<F, S, I, ($($ty,)* $last,)>
253        where
254            F: FnMut($($ty,)* $last) -> Fut + Clone + Send + 'static,
255            $( $ty: FromRequestParts<S> + Send, )*
256            $last: FromRequest<S> + Send,
257            Fut: Future + Send + 'static,
258            Fut::Output: IntoMapRequestResult<B> + Send + 'static,
259            I: Service<Request<B>, Error = Infallible>
260                + Clone
261                + Send
262                + 'static,
263            I::Response: IntoResponse,
264            I::Future: Send + 'static,
265            B: HttpBody<Data = Bytes> + Send + 'static,
266            B::Error: Into<BoxError>,
267            S: Clone + Send + Sync + 'static,
268        {
269            type Response = Response;
270            type Error = Infallible;
271            type Future = ResponseFuture;
272
273            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
274                self.inner.poll_ready(cx)
275            }
276
277            fn call(&mut self, req: Request<B>) -> Self::Future {
278                let req = req.map(Body::new);
279
280                let not_ready_inner = self.inner.clone();
281                let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
282
283                let mut f = self.f.clone();
284                let state = self.state.clone();
285                let (mut parts, body) = req.into_parts();
286
287                let future = Box::pin(async move {
288                    $(
289                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
290                            Ok(value) => value,
291                            Err(rejection) => return rejection.into_response(),
292                        };
293                    )*
294
295                    let req = Request::from_parts(parts, body);
296
297                    let $last = match $last::from_request(req, &state).await {
298                        Ok(value) => value,
299                        Err(rejection) => return rejection.into_response(),
300                    };
301
302                    match f($($ty,)* $last).await.into_map_request_result() {
303                        Ok(req) => {
304                            ready_inner.call(req).await.into_response()
305                        }
306                        Err(res) => {
307                            res
308                        }
309                    }
310                });
311
312                ResponseFuture {
313                    inner: future
314                }
315            }
316        }
317    };
318}
319
320all_the_tuples!(impl_service);
321
322impl<F, S, I, T> fmt::Debug for MapRequest<F, S, I, T>
323where
324    S: fmt::Debug,
325    I: fmt::Debug,
326{
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        f.debug_struct("MapRequest")
329            .field("f", &format_args!("{}", type_name::<F>()))
330            .field("inner", &self.inner)
331            .field("state", &self.state)
332            .finish()
333    }
334}
335
336/// Response future for [`MapRequest`].
337pub struct ResponseFuture {
338    inner: BoxFuture<'static, Response>,
339}
340
341impl Future for ResponseFuture {
342    type Output = Result<Response, Infallible>;
343
344    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
345        self.inner.as_mut().poll(cx).map(Ok)
346    }
347}
348
349impl fmt::Debug for ResponseFuture {
350    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351        f.debug_struct("ResponseFuture").finish()
352    }
353}
354
355mod private {
356    use crate::{http::Request, response::IntoResponse};
357
358    pub trait Sealed<B> {}
359    impl<B, E> Sealed<B> for Result<Request<B>, E> where E: IntoResponse {}
360    impl<B> Sealed<B> for Request<B> {}
361}
362
363/// Trait implemented by types that can be returned from [`map_request`],
364/// [`map_request_with_state`].
365///
366/// This trait is sealed such that it cannot be implemented outside this crate.
367pub trait IntoMapRequestResult<B>: private::Sealed<B> {
368    /// Perform the conversion.
369    fn into_map_request_result(self) -> Result<Request<B>, Response>;
370}
371
372impl<B, E> IntoMapRequestResult<B> for Result<Request<B>, E>
373where
374    E: IntoResponse,
375{
376    fn into_map_request_result(self) -> Result<Request<B>, Response> {
377        self.map_err(IntoResponse::into_response)
378    }
379}
380
381impl<B> IntoMapRequestResult<B> for Request<B> {
382    fn into_map_request_result(self) -> Result<Request<B>, Response> {
383        Ok(self)
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router};
    use http::{HeaderMap, StatusCode};

    /// A middleware that returns the modified request lets it reach the handler.
    #[crate::test]
    async fn works() {
        async fn add_header<B>(mut req: Request<B>) -> Request<B> {
            req.headers_mut().insert("x-foo", "foo".parse().unwrap());
            req
        }

        async fn handler(headers: HeaderMap) -> Response {
            headers["x-foo"]
                .to_str()
                .unwrap()
                .to_owned()
                .into_response()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(add_header));
        let client = TestClient::new(app);

        let res = client.get("/").await;

        assert_eq!(res.text().await, "foo");
    }

    /// A middleware that returns `Err` responds immediately without calling the handler.
    #[crate::test]
    async fn works_for_short_circuiting() {
        async fn reject<B>(_req: Request<B>) -> Result<Request<B>, (StatusCode, &'static str)> {
            Err((StatusCode::INTERNAL_SERVER_ERROR, "something went wrong"))
        }

        async fn handler(_headers: HeaderMap) -> Response {
            unreachable!()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(reject));
        let client = TestClient::new(app);

        let res = client.get("/").await;

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(res.text().await, "something went wrong");
    }
}