// axum/middleware/from_fn.rs

1use crate::response::{IntoResponse, Response};
2use axum_core::extract::{FromRequest, FromRequestParts, Request};
3use futures_util::future::BoxFuture;
4use std::{
5    any::type_name,
6    convert::Infallible,
7    fmt,
8    future::Future,
9    marker::PhantomData,
10    pin::Pin,
11    task::{Context, Poll},
12};
13use tower::{util::BoxCloneService, ServiceBuilder};
14use tower_layer::Layer;
15use tower_service::Service;
16
17/// Create a middleware from an async function.
18///
19/// `from_fn` requires the function given to
20///
21/// 1. Be an `async fn`.
22/// 2. Take one or more [extractors] as the first arguments.
23/// 3. Take [`Next`](Next) as the final argument.
24/// 4. Return something that implements [`IntoResponse`].
25///
26/// Note that this function doesn't support extracting [`State`]. For that, use [`from_fn_with_state`].
27///
28/// # Example
29///
30/// ```rust
31/// use axum::{
32///     Router,
33///     http,
34///     routing::get,
35///     response::Response,
36///     middleware::{self, Next},
37///     extract::Request,
38/// };
39///
40/// async fn my_middleware(
41///     request: Request,
42///     next: Next,
43/// ) -> Response {
44///     // do something with `request`...
45///
46///     let response = next.run(request).await;
47///
48///     // do something with `response`...
49///
50///     response
51/// }
52///
53/// let app = Router::new()
54///     .route("/", get(|| async { /* ... */ }))
55///     .layer(middleware::from_fn(my_middleware));
56/// # let app: Router = app;
57/// ```
58///
59/// # Running extractors
60///
61/// ```rust
62/// use axum::{
63///     Router,
64///     extract::Request,
65///     http::{StatusCode, HeaderMap},
66///     middleware::{self, Next},
67///     response::Response,
68///     routing::get,
69/// };
70///
71/// async fn auth(
72///     // run the `HeaderMap` extractor
73///     headers: HeaderMap,
74///     // you can also add more extractors here but the last
75///     // extractor must implement `FromRequest` which
76///     // `Request` does
77///     request: Request,
78///     next: Next,
79/// ) -> Result<Response, StatusCode> {
80///     match get_token(&headers) {
81///         Some(token) if token_is_valid(token) => {
82///             let response = next.run(request).await;
83///             Ok(response)
84///         }
85///         _ => {
86///             Err(StatusCode::UNAUTHORIZED)
87///         }
88///     }
89/// }
90///
91/// fn get_token(headers: &HeaderMap) -> Option<&str> {
92///     // ...
93///     # None
94/// }
95///
96/// fn token_is_valid(token: &str) -> bool {
97///     // ...
98///     # false
99/// }
100///
101/// let app = Router::new()
102///     .route("/", get(|| async { /* ... */ }))
103///     .route_layer(middleware::from_fn(auth));
104/// # let app: Router = app;
105/// ```
106///
107/// [extractors]: crate::extract::FromRequest
108/// [`State`]: crate::extract::State
109pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
110    from_fn_with_state((), f)
111}
112
113/// Create a middleware from an async function with the given state.
114///
115/// See [`State`](crate::extract::State) for more details about accessing state.
116///
117/// # Example
118///
119/// ```rust
120/// use axum::{
121///     Router,
122///     http::StatusCode,
123///     routing::get,
124///     response::{IntoResponse, Response},
125///     middleware::{self, Next},
126///     extract::{Request, State},
127/// };
128///
129/// #[derive(Clone)]
130/// struct AppState { /* ... */ }
131///
132/// async fn my_middleware(
133///     State(state): State<AppState>,
134///     // you can add more extractors here but the last
135///     // extractor must implement `FromRequest` which
136///     // `Request` does
137///     request: Request,
138///     next: Next,
139/// ) -> Response {
140///     // do something with `request`...
141///
142///     let response = next.run(request).await;
143///
144///     // do something with `response`...
145///
146///     response
147/// }
148///
149/// let state = AppState { /* ... */ };
150///
151/// let app = Router::new()
152///     .route("/", get(|| async { /* ... */ }))
153///     .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware))
154///     .with_state(state);
155/// # let _: axum::Router = app;
156/// ```
157pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
158    FromFnLayer {
159        f,
160        state,
161        _extractor: PhantomData,
162    }
163}
164
/// A [`tower::Layer`] that wraps services in a middleware built from an async function.
///
/// [`tower::Layer`]s are used to apply middleware to [`Router`](crate::Router)s.
///
/// Created with [`from_fn`] or [`from_fn_with_state`]. See those functions for more details.
#[must_use]
pub struct FromFnLayer<F, S, T> {
    /// The middleware function; cloned into every [`FromFn`] this layer produces.
    f: F,
    /// State handed to the function's extractors; cloned alongside `f`.
    state: S,
    /// Records the extractor tuple type without storing a value of it. The `fn() -> T` form
    /// keeps the layer's auto traits (such as `Send`/`Sync`) independent of `T`.
    _extractor: PhantomData<fn() -> T>,
}
176
177impl<F, S, T> Clone for FromFnLayer<F, S, T>
178where
179    F: Clone,
180    S: Clone,
181{
182    fn clone(&self) -> Self {
183        Self {
184            f: self.f.clone(),
185            state: self.state.clone(),
186            _extractor: self._extractor,
187        }
188    }
189}
190
191impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
192where
193    F: Clone,
194    S: Clone,
195{
196    type Service = FromFn<F, S, I, T>;
197
198    fn layer(&self, inner: I) -> Self::Service {
199        FromFn {
200            f: self.f.clone(),
201            state: self.state.clone(),
202            inner,
203            _extractor: PhantomData,
204        }
205    }
206}
207
208impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
209where
210    S: fmt::Debug,
211{
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        f.debug_struct("FromFnLayer")
214            // Write out the type name, without quoting it as `&type_name::<F>()` would
215            .field("f", &format_args!("{}", type_name::<F>()))
216            .field("state", &self.state)
217            .finish()
218    }
219}
220
/// A middleware created from an async function.
///
/// Produced when a [`FromFnLayer`] (see [`from_fn`] and [`from_fn_with_state`]) is applied to
/// an inner service. See [`from_fn`] for more details.
pub struct FromFn<F, S, I, T> {
    /// The middleware function, cloned for each request.
    f: F,
    /// The wrapped service; it is handed to the function as [`Next`].
    inner: I,
    /// State passed to the function's extractors, cloned for each request.
    state: S,
    /// Type-level record of the extractor tuple; `fn() -> T` keeps auto traits independent of `T`.
    _extractor: PhantomData<fn() -> T>,
}
230
231impl<F, S, I, T> Clone for FromFn<F, S, I, T>
232where
233    F: Clone,
234    I: Clone,
235    S: Clone,
236{
237    fn clone(&self) -> Self {
238        Self {
239            f: self.f.clone(),
240            inner: self.inner.clone(),
241            state: self.state.clone(),
242            _extractor: self._extractor,
243        }
244    }
245}
246
247macro_rules! impl_service {
248    (
249        [$($ty:ident),*], $last:ident
250    ) => {
251        #[allow(non_snake_case, unused_mut)]
252        impl<F, Fut, Out, S, I, $($ty,)* $last> Service<Request> for FromFn<F, S, I, ($($ty,)* $last,)>
253        where
254            F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static,
255            $( $ty: FromRequestParts<S> + Send, )*
256            $last: FromRequest<S> + Send,
257            Fut: Future<Output = Out> + Send + 'static,
258            Out: IntoResponse + 'static,
259            I: Service<Request, Error = Infallible>
260                + Clone
261                + Send
262                + 'static,
263            I::Response: IntoResponse,
264            I::Future: Send + 'static,
265            S: Clone + Send + Sync + 'static,
266        {
267            type Response = Response;
268            type Error = Infallible;
269            type Future = ResponseFuture;
270
271            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
272                self.inner.poll_ready(cx)
273            }
274
275            fn call(&mut self, req: Request) -> Self::Future {
276                let not_ready_inner = self.inner.clone();
277                let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
278
279                let mut f = self.f.clone();
280                let state = self.state.clone();
281
282                let future = Box::pin(async move {
283                    let (mut parts, body) = req.into_parts();
284
285                    $(
286                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
287                            Ok(value) => value,
288                            Err(rejection) => return rejection.into_response(),
289                        };
290                    )*
291
292                    let req = Request::from_parts(parts, body);
293
294                    let $last = match $last::from_request(req, &state).await {
295                        Ok(value) => value,
296                        Err(rejection) => return rejection.into_response(),
297                    };
298
299                    let inner = ServiceBuilder::new()
300                        .boxed_clone()
301                        .map_response(IntoResponse::into_response)
302                        .service(ready_inner);
303                    let next = Next { inner };
304
305                    f($($ty,)* $last, next).await.into_response()
306                });
307
308                ResponseFuture {
309                    inner: future
310                }
311            }
312        }
313    };
314}
315
316all_the_tuples!(impl_service);
317
318impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
319where
320    S: fmt::Debug,
321    I: fmt::Debug,
322{
323    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
324        f.debug_struct("FromFnLayer")
325            .field("f", &format_args!("{}", type_name::<F>()))
326            .field("inner", &self.inner)
327            .field("state", &self.state)
328            .finish()
329    }
330}
331
332/// The remainder of a middleware stack, including the handler.
333#[derive(Debug, Clone)]
334pub struct Next {
335    inner: BoxCloneService<Request, Response, Infallible>,
336}
337
338impl Next {
339    /// Execute the remaining middleware stack.
340    pub async fn run(mut self, req: Request) -> Response {
341        match self.inner.call(req).await {
342            Ok(res) => res,
343            Err(err) => match err {},
344        }
345    }
346}
347
348impl Service<Request> for Next {
349    type Response = Response;
350    type Error = Infallible;
351    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
352
353    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
354        self.inner.poll_ready(cx)
355    }
356
357    fn call(&mut self, req: Request) -> Self::Future {
358        self.inner.call(req)
359    }
360}
361
362/// Response future for [`FromFn`].
363pub struct ResponseFuture {
364    inner: BoxFuture<'static, Response>,
365}
366
367impl Future for ResponseFuture {
368    type Output = Result<Response, Infallible>;
369
370    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
371        self.inner.as_mut().poll(cx).map(Ok)
372    }
373}
374
375impl fmt::Debug for ResponseFuture {
376    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
377        f.debug_struct("ResponseFuture").finish()
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{HeaderMap, HeaderValue, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// A `from_fn` middleware can modify the request before the handler sees it.
    #[crate::test]
    async fn basic() {
        // Middleware: stamp a header onto the request, then continue down the stack.
        async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse {
            let headers = req.headers_mut();
            headers.insert("x-axum-test", HeaderValue::from_static("ok"));
            next.run(req).await
        }

        // Handler: echo the header's value back as the body.
        async fn handle(headers: HeaderMap) -> String {
            let value = &headers["x-axum-test"];
            value.to_str().unwrap().to_owned()
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header));

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let res = app.oneshot(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let collected = res.collect().await.unwrap();
        let bytes = collected.to_bytes();
        assert_eq!(&bytes[..], b"ok");
    }
}