axum/middleware/from_extractor.rs

1use crate::{
2    extract::FromRequestParts,
3    response::{IntoResponse, Response},
4};
5use futures_util::future::BoxFuture;
6use http::Request;
7use pin_project_lite::pin_project;
8use std::{
9    fmt,
10    future::Future,
11    marker::PhantomData,
12    pin::Pin,
13    task::{ready, Context, Poll},
14};
15use tower_layer::Layer;
16use tower_service::Service;
17
18/// Create a middleware from an extractor.
19///
20/// If the extractor succeeds the value will be discarded and the inner service
21/// will be called. If the extractor fails the rejection will be returned and
22/// the inner service will _not_ be called.
23///
24/// This can be used to perform validation of requests if the validation doesn't
25/// produce any useful output, and run the extractor for several handlers
26/// without repeating it in the function signature.
27///
28/// Note that if the extractor consumes the request body, as `String` or
29/// [`Bytes`] does, an empty body will be left in its place. Thus won't be
30/// accessible to subsequent extractors or handlers.
31///
32/// # Example
33///
34/// ```rust
35/// use axum::{
36///     extract::FromRequestParts,
37///     middleware::from_extractor,
38///     routing::{get, post},
39///     Router,
40///     http::{header, StatusCode, request::Parts},
41/// };
42///
43/// // An extractor that performs authorization.
44/// struct RequireAuth;
45///
46/// impl<S> FromRequestParts<S> for RequireAuth
47/// where
48///     S: Send + Sync,
49/// {
50///     type Rejection = StatusCode;
51///
52///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
53///         let auth_header = parts
54///             .headers
55///             .get(header::AUTHORIZATION)
56///             .and_then(|value| value.to_str().ok());
57///
58///         match auth_header {
59///             Some(auth_header) if token_is_valid(auth_header) => {
60///                 Ok(Self)
61///             }
62///             _ => Err(StatusCode::UNAUTHORIZED),
63///         }
64///     }
65/// }
66///
67/// fn token_is_valid(token: &str) -> bool {
68///     // ...
69///     # false
70/// }
71///
72/// async fn handler() {
73///     // If we get here the request has been authorized
74/// }
75///
76/// async fn other_handler() {
77///     // If we get here the request has been authorized
78/// }
79///
80/// let app = Router::new()
81///     .route("/", get(handler))
82///     .route("/foo", post(other_handler))
83///     // The extractor will run before all routes
84///     .route_layer(from_extractor::<RequireAuth>());
85/// # let _: Router = app;
86/// ```
87///
88/// [`Bytes`]: bytes::Bytes
89pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
90    from_extractor_with_state(())
91}
92
93/// Create a middleware from an extractor with the given state.
94///
95/// See [`State`](crate::extract::State) for more details about accessing state.
96pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
97    FromExtractorLayer {
98        state,
99        _marker: PhantomData,
100    }
101}
102
/// [`Layer`] that wraps services in [`FromExtractor`], which runs an extractor
/// on every request and throws the extracted value away.
///
/// Built by [`from_extractor`] or [`from_extractor_with_state`]; see
/// [`from_extractor`] for more details.
///
/// [`Layer`]: tower::Layer
#[must_use]
pub struct FromExtractorLayer<E, S> {
    // State cloned into every service this layer produces.
    state: S,
    // `fn() -> E` rather than `E`, so the layer's auto traits (Send/Sync)
    // do not depend on the extractor type.
    _marker: PhantomData<fn() -> E>,
}
114
115impl<E, S> Clone for FromExtractorLayer<E, S>
116where
117    S: Clone,
118{
119    fn clone(&self) -> Self {
120        Self {
121            state: self.state.clone(),
122            _marker: PhantomData,
123        }
124    }
125}
126
127impl<E, S> fmt::Debug for FromExtractorLayer<E, S>
128where
129    S: fmt::Debug,
130{
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        f.debug_struct("FromExtractorLayer")
133            .field("state", &self.state)
134            .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
135            .finish()
136    }
137}
138
139impl<E, T, S> Layer<T> for FromExtractorLayer<E, S>
140where
141    S: Clone,
142{
143    type Service = FromExtractor<T, E, S>;
144
145    fn layer(&self, inner: T) -> Self::Service {
146        FromExtractor {
147            inner,
148            state: self.state.clone(),
149            _extractor: PhantomData,
150        }
151    }
152}
153
/// Middleware service that runs the extractor `E` before calling `T`.
///
/// The extractor receives a clone of the state `S`. The extracted value is
/// dropped, and the request reaches the wrapped service only if extraction
/// succeeded. See [`from_extractor`] for more details.
pub struct FromExtractor<T, E, S> {
    // The wrapped service.
    inner: T,
    // State handed to the extractor on each request.
    state: S,
    // `fn() -> E` keeps Send/Sync independent of the extractor type.
    _extractor: PhantomData<fn() -> E>,
}
162
#[test]
fn traits() {
    use crate::test_helpers::*;

    // The extractor is only held as `PhantomData<fn() -> E>`, so the middleware
    // must stay `Send + Sync` even when the extractor type is neither.
    type Middleware = FromExtractor<(), NotSendSync, ()>;
    assert_send::<Middleware>();
    assert_sync::<Middleware>();
}
169
170impl<T, E, S> Clone for FromExtractor<T, E, S>
171where
172    T: Clone,
173    S: Clone,
174{
175    fn clone(&self) -> Self {
176        Self {
177            inner: self.inner.clone(),
178            state: self.state.clone(),
179            _extractor: PhantomData,
180        }
181    }
182}
183
184impl<T, E, S> fmt::Debug for FromExtractor<T, E, S>
185where
186    T: fmt::Debug,
187    S: fmt::Debug,
188{
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        f.debug_struct("FromExtractor")
191            .field("inner", &self.inner)
192            .field("state", &self.state)
193            .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
194            .finish()
195    }
196}
197
198impl<T, E, B, S> Service<Request<B>> for FromExtractor<T, E, S>
199where
200    E: FromRequestParts<S> + 'static,
201    B: Send + 'static,
202    T: Service<Request<B>> + Clone,
203    T::Response: IntoResponse,
204    S: Clone + Send + Sync + 'static,
205{
206    type Response = Response;
207    type Error = T::Error;
208    type Future = ResponseFuture<B, T, E, S>;
209
210    #[inline]
211    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212        self.inner.poll_ready(cx)
213    }
214
215    fn call(&mut self, req: Request<B>) -> Self::Future {
216        let state = self.state.clone();
217        let (mut parts, body) = req.into_parts();
218
219        let extract_future = Box::pin(async move {
220            let extracted = E::from_request_parts(&mut parts, &state).await;
221            let req = Request::from_parts(parts, body);
222            (req, extracted)
223        });
224
225        ResponseFuture {
226            state: State::Extracting {
227                future: extract_future,
228            },
229            svc: Some(self.inner.clone()),
230        }
231    }
232}
233
234pin_project! {
235    /// Response future for [`FromExtractor`].
236    #[allow(missing_debug_implementations)]
237    pub struct ResponseFuture<B, T, E, S>
238    where
239        E: FromRequestParts<S>,
240        T: Service<Request<B>>,
241    {
242        #[pin]
243        state: State<B, T, E, S>,
244        svc: Option<T>,
245    }
246}
247
248pin_project! {
249    #[project = StateProj]
250    enum State<B, T, E, S>
251    where
252        E: FromRequestParts<S>,
253        T: Service<Request<B>>,
254    {
255        Extracting {
256            future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
257        },
258        Call { #[pin] future: T::Future },
259    }
260}
261
262impl<B, T, E, S> Future for ResponseFuture<B, T, E, S>
263where
264    E: FromRequestParts<S>,
265    T: Service<Request<B>>,
266    T::Response: IntoResponse,
267{
268    type Output = Result<Response, T::Error>;
269
270    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
271        loop {
272            let mut this = self.as_mut().project();
273
274            let new_state = match this.state.as_mut().project() {
275                StateProj::Extracting { future } => {
276                    let (req, extracted) = ready!(future.as_mut().poll(cx));
277
278                    match extracted {
279                        Ok(_) => {
280                            let mut svc = this.svc.take().expect("future polled after completion");
281                            let future = svc.call(req);
282                            State::Call { future }
283                        }
284                        Err(err) => {
285                            let res = err.into_response();
286                            return Poll::Ready(Ok(res));
287                        }
288                    }
289                }
290                StateProj::Call { future } => {
291                    return future
292                        .poll(cx)
293                        .map(|result| result.map(IntoResponse::into_response));
294                }
295            };
296
297            this.state.set(new_state);
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{handler::Handler, routing::get, test_helpers::*, Router};
    use axum_core::extract::FromRef;
    use http::{header, request::Parts, StatusCode};
    use tower_http::limit::RequestBodyLimitLayer;

    #[crate::test]
    async fn test_from_extractor() {
        // The expected `Authorization` header value, supplied through state.
        #[derive(Clone)]
        struct Secret(&'static str);

        // Accepts a request only if its `Authorization` header equals the secret.
        struct RequireAuth;

        impl<S> FromRequestParts<S> for RequireAuth
        where
            S: Send + Sync,
            Secret: FromRef<S>,
        {
            type Rejection = StatusCode;

            async fn from_request_parts(
                parts: &mut Parts,
                state: &S,
            ) -> Result<Self, Self::Rejection> {
                let Secret(expected) = Secret::from_ref(state);
                let provided = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .and_then(|value| value.to_str().ok());

                if provided == Some(expected) {
                    Ok(Self)
                } else {
                    Err(StatusCode::UNAUTHORIZED)
                }
            }
        }

        async fn handler() {}

        let auth_layer = from_extractor_with_state::<RequireAuth, _>(Secret("secret"));
        let app = Router::new().route("/", get(handler.layer(auth_layer)));
        let client = TestClient::new(app);

        // Missing credentials: the extractor rejects and the handler never runs.
        let rejected = client.get("/").await;
        assert_eq!(rejected.status(), StatusCode::UNAUTHORIZED);

        // Correct credentials: the request reaches the handler.
        let accepted = client
            .get("/")
            .header(http::header::AUTHORIZATION, "secret")
            .await;
        assert_eq!(accepted.status(), StatusCode::OK);
    }

    // Compile-only check: `from_extractor` must compose with `RequestBodyLimitLayer`.
    #[allow(dead_code)]
    fn works_with_request_body_limit() {
        struct MyExtractor;

        impl<S> FromRequestParts<S> for MyExtractor
        where
            S: Send + Sync,
        {
            type Rejection = std::convert::Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                unimplemented!()
            }
        }

        let _: Router = Router::new()
            .layer(from_extractor::<MyExtractor>())
            .layer(RequestBodyLimitLayer::new(1));
    }
}