axum_core/ext_traits/request.rs
1use crate::body::Body;
2use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts, Request};
3use futures_util::future::BoxFuture;
4
// Backs the "sealed trait" pattern used by `RequestExt`: `Sealed` is `pub` so that it can appear
// as a supertrait in a public trait, but it lives in a private module, so code outside this
// module cannot name it and therefore cannot add its own `RequestExt` implementations.
mod sealed {
    /// Marker trait without methods; it exists only to restrict who may implement `RequestExt`.
    pub trait Sealed {}
    // The `http` request type carrying this crate's `Body` is the only type sealed in here.
    impl Sealed for http::Request<crate::body::Body> {}
}
9
/// Extension trait that adds additional methods to [`Request`].
///
/// The methods fall into two groups: the `extract*` methods run an extractor against the
/// request, and the `*_limited_body` methods apply the
/// [default body limit](crate::extract::DefaultBodyLimit).
///
/// The trait has the private `sealed::Sealed` marker as a supertrait, so it cannot be
/// implemented for additional types from the outside.
pub trait RequestExt: sealed::Sealed + Sized {
    /// Apply an extractor to this `Request`.
    ///
    /// This is just a convenience for `E::from_request(req, &())`.
    ///
    /// `E` is the extractor to run. `M` is the second type parameter of [`FromRequest`]; it is
    /// normally left to inference by writing `_`, as the example below does.
    ///
    /// Note this consumes the request. Use [`RequestExt::extract_parts`] if you're not extracting
    /// the body and don't want to consume the request.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     async_trait,
    ///     extract::{Request, FromRequest},
    ///     body::Body,
    ///     http::{header::CONTENT_TYPE, StatusCode},
    ///     response::{IntoResponse, Response},
    ///     Form, Json, RequestExt,
    /// };
    ///
    /// struct FormOrJson<T>(T);
    ///
    /// #[async_trait]
    /// impl<S, T> FromRequest<S> for FormOrJson<T>
    /// where
    ///     Json<T>: FromRequest<()>,
    ///     Form<T>: FromRequest<()>,
    ///     T: 'static,
    ///     S: Send + Sync,
    /// {
    ///     type Rejection = Response;
    ///
    ///     async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
    ///         let content_type = req
    ///             .headers()
    ///             .get(CONTENT_TYPE)
    ///             .and_then(|value| value.to_str().ok())
    ///             .ok_or_else(|| StatusCode::BAD_REQUEST.into_response())?;
    ///
    ///         if content_type.starts_with("application/json") {
    ///             let Json(payload) = req
    ///                 .extract::<Json<T>, _>()
    ///                 .await
    ///                 .map_err(|err| err.into_response())?;
    ///
    ///             Ok(Self(payload))
    ///         } else if content_type.starts_with("application/x-www-form-urlencoded") {
    ///             let Form(payload) = req
    ///                 .extract::<Form<T>, _>()
    ///                 .await
    ///                 .map_err(|err| err.into_response())?;
    ///
    ///             Ok(Self(payload))
    ///         } else {
    ///             Err(StatusCode::BAD_REQUEST.into_response())
    ///         }
    ///     }
    /// }
    /// ```
    fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>>
    where
        E: FromRequest<(), M> + 'static,
        M: 'static;

    /// Apply an extractor that requires some state to this `Request`.
    ///
    /// This is just a convenience for `E::from_request(req, state)`.
    ///
    /// `E` is the extractor to run and `S` is the type of the state it is given. The returned
    /// future borrows `state`, so the state has to outlive it.
    ///
    /// Note this consumes the request. Use [`RequestExt::extract_parts_with_state`] if you're not
    /// extracting the body and don't want to consume the request.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     async_trait,
    ///     body::Body,
    ///     extract::{Request, FromRef, FromRequest},
    ///     RequestExt,
    /// };
    ///
    /// struct MyExtractor {
    ///     requires_state: RequiresState,
    /// }
    ///
    /// #[async_trait]
    /// impl<S> FromRequest<S> for MyExtractor
    /// where
    ///     String: FromRef<S>,
    ///     S: Send + Sync,
    /// {
    ///     type Rejection = std::convert::Infallible;
    ///
    ///     async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
    ///         let requires_state = req.extract_with_state::<RequiresState, _, _>(state).await?;
    ///
    ///         Ok(Self { requires_state })
    ///     }
    /// }
    ///
    /// // some extractor that consumes the request body and requires state
    /// struct RequiresState { /* ... */ }
    ///
    /// #[async_trait]
    /// impl<S> FromRequest<S> for RequiresState
    /// where
    ///     String: FromRef<S>,
    ///     S: Send + Sync,
    /// {
    ///     // ...
    ///     # type Rejection = std::convert::Infallible;
    ///     # async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
    ///     #     todo!()
    ///     # }
    /// }
    /// ```
    fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>>
    where
        E: FromRequest<S, M> + 'static,
        S: Send + Sync;

    /// Apply a parts extractor to this `Request`.
    ///
    /// This is just a convenience for `E::from_request_parts(parts, &())`.
    ///
    /// The request is only borrowed, so it can still be used once the returned future has
    /// completed, for example to extract the body as the example below does.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     async_trait,
    ///     extract::{Path, Request, FromRequest},
    ///     response::{IntoResponse, Response},
    ///     body::Body,
    ///     Json, RequestExt,
    /// };
    /// use axum_extra::{
    ///     TypedHeader,
    ///     headers::{authorization::Bearer, Authorization},
    /// };
    /// use std::collections::HashMap;
    ///
    /// struct MyExtractor<T> {
    ///     path_params: HashMap<String, String>,
    ///     payload: T,
    /// }
    ///
    /// #[async_trait]
    /// impl<S, T> FromRequest<S> for MyExtractor<T>
    /// where
    ///     S: Send + Sync,
    ///     Json<T>: FromRequest<()>,
    ///     T: 'static,
    /// {
    ///     type Rejection = Response;
    ///
    ///     async fn from_request(mut req: Request, _state: &S) -> Result<Self, Self::Rejection> {
    ///         let path_params = req
    ///             .extract_parts::<Path<_>>()
    ///             .await
    ///             .map(|Path(path_params)| path_params)
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         let Json(payload) = req
    ///             .extract::<Json<T>, _>()
    ///             .await
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         Ok(Self { path_params, payload })
    ///     }
    /// }
    /// ```
    fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
    where
        E: FromRequestParts<()> + 'static;

    /// Apply a parts extractor that requires some state to this `Request`.
    ///
    /// This is just a convenience for `E::from_request_parts(parts, state)`.
    ///
    /// The returned future borrows both the request and `state` for `'a`. The request is only
    /// borrowed, so it can still be used once the future has completed.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     async_trait,
    ///     extract::{Request, FromRef, FromRequest, FromRequestParts},
    ///     http::request::Parts,
    ///     response::{IntoResponse, Response},
    ///     body::Body,
    ///     Json, RequestExt,
    /// };
    ///
    /// struct MyExtractor<T> {
    ///     requires_state: RequiresState,
    ///     payload: T,
    /// }
    ///
    /// #[async_trait]
    /// impl<S, T> FromRequest<S> for MyExtractor<T>
    /// where
    ///     String: FromRef<S>,
    ///     Json<T>: FromRequest<()>,
    ///     T: 'static,
    ///     S: Send + Sync,
    /// {
    ///     type Rejection = Response;
    ///
    ///     async fn from_request(mut req: Request, state: &S) -> Result<Self, Self::Rejection> {
    ///         let requires_state = req
    ///             .extract_parts_with_state::<RequiresState, _>(state)
    ///             .await
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         let Json(payload) = req
    ///             .extract::<Json<T>, _>()
    ///             .await
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         Ok(Self {
    ///             requires_state,
    ///             payload,
    ///         })
    ///     }
    /// }
    ///
    /// struct RequiresState {}
    ///
    /// #[async_trait]
    /// impl<S> FromRequestParts<S> for RequiresState
    /// where
    ///     String: FromRef<S>,
    ///     S: Send + Sync,
    /// {
    ///     // ...
    ///     # type Rejection = std::convert::Infallible;
    ///     # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
    ///     #     todo!()
    ///     # }
    /// }
    /// ```
    fn extract_parts_with_state<'a, E, S>(
        &'a mut self,
        state: &'a S,
    ) -> BoxFuture<'a, Result<E, E::Rejection>>
    where
        E: FromRequestParts<S> + 'static,
        S: Send + Sync;

    /// Apply the [default body limit](crate::extract::DefaultBodyLimit).
    ///
    /// If it is disabled, the request is returned as-is. Otherwise the request is returned with
    /// its body wrapped in [`http_body_util::Limited`].
    fn with_limited_body(self) -> Request;

    /// Consumes the request, returning the body wrapped in [`http_body_util::Limited`] if a
    /// [default limit](crate::extract::DefaultBodyLimit) is in place, or not wrapped if the
    /// default limit is disabled.
    fn into_limited_body(self) -> Body;
}
268
269impl RequestExt for Request {
270    fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>>
271    where
272        E: FromRequest<(), M> + 'static,
273        M: 'static,
274    {
275        self.extract_with_state(&())
276    }
277
278    fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>>
279    where
280        E: FromRequest<S, M> + 'static,
281        S: Send + Sync,
282    {
283        E::from_request(self, state)
284    }
285
286    fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
287    where
288        E: FromRequestParts<()> + 'static,
289    {
290        self.extract_parts_with_state(&())
291    }
292
293    fn extract_parts_with_state<'a, E, S>(
294        &'a mut self,
295        state: &'a S,
296    ) -> BoxFuture<'a, Result<E, E::Rejection>>
297    where
298        E: FromRequestParts<S> + 'static,
299        S: Send + Sync,
300    {
301        let mut req = Request::new(());
302        *req.version_mut() = self.version();
303        *req.method_mut() = self.method().clone();
304        *req.uri_mut() = self.uri().clone();
305        *req.headers_mut() = std::mem::take(self.headers_mut());
306        *req.extensions_mut() = std::mem::take(self.extensions_mut());
307        let (mut parts, ()) = req.into_parts();
308
309        Box::pin(async move {
310            let result = E::from_request_parts(&mut parts, state).await;
311
312            *self.version_mut() = parts.version;
313            *self.method_mut() = parts.method.clone();
314            *self.uri_mut() = parts.uri.clone();
315            *self.headers_mut() = std::mem::take(&mut parts.headers);
316            *self.extensions_mut() = std::mem::take(&mut parts.extensions);
317
318            result
319        })
320    }
321
322    fn with_limited_body(self) -> Request {
323        // update docs in `axum-core/src/extract/default_body_limit.rs` and
324        // `axum/src/docs/extract.md` if this changes
325        const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb
326
327        match self.extensions().get::<DefaultBodyLimitKind>().copied() {
328            Some(DefaultBodyLimitKind::Disable) => self,
329            Some(DefaultBodyLimitKind::Limit(limit)) => {
330                self.map(|b| Body::new(http_body_util::Limited::new(b, limit)))
331            }
332            None => self.map(|b| Body::new(http_body_util::Limited::new(b, DEFAULT_LIMIT))),
333        }
334    }
335
336    fn into_limited_body(self) -> Body {
337        self.with_limited_body().into_body()
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ext_traits::tests::{RequiresState, State},
        extract::FromRef,
    };
    use async_trait::async_trait;
    use http::Method;

    /// `extract` works with an extractor that only looks at the request head.
    #[tokio::test]
    async fn extract_without_state() {
        let request = Request::new(Body::empty());

        let extracted = request.extract::<Method, _>().await.unwrap();

        assert_eq!(extracted, Method::GET);
    }

    /// `extract` works with an extractor that consumes the body.
    #[tokio::test]
    async fn extract_body_without_state() {
        let request = Request::new(Body::from("foobar"));

        let text = request.extract::<String, _>().await.unwrap();

        assert_eq!(text, "foobar");
    }

    /// `extract_with_state` hands the given state to the extractor.
    #[tokio::test]
    async fn extract_with_state() {
        let state = "state".to_owned();
        let request = Request::new(Body::empty());

        let State(from_extractor) = request
            .extract_with_state::<State<String>, _, _>(&state)
            .await
            .unwrap();

        assert_eq!(from_extractor, state);
    }

    /// `extract_parts` leaves the request usable, with its headers still in place.
    #[tokio::test]
    async fn extract_parts_without_state() {
        let mut request = Request::builder()
            .header("x-foo", "foo")
            .body(Body::empty())
            .unwrap();

        let extracted = request.extract_parts::<Method>().await.unwrap();

        assert_eq!(extracted, Method::GET);
        assert_eq!(request.headers()["x-foo"], "foo");
    }

    /// `extract_parts_with_state` passes the state on and leaves the headers in place.
    #[tokio::test]
    async fn extract_parts_with_state() {
        let state = "state".to_owned();
        let mut request = Request::builder()
            .header("x-foo", "foo")
            .body(Body::empty())
            .unwrap();

        let State(from_extractor) = request
            .extract_parts_with_state::<State<String>, _>(&state)
            .await
            .unwrap();

        assert_eq!(from_extractor, state);
        assert_eq!(request.headers()["x-foo"], "foo");
    }

    // Compile-only check: a custom extractor can combine all of the `RequestExt` extraction
    // methods. The fields are never read, hence the `dead_code` allowance.
    #[allow(dead_code)]
    struct WorksForCustomExtractor {
        method: Method,
        from_state: String,
        body: String,
    }

    #[async_trait]
    impl<S> FromRequest<S> for WorksForCustomExtractor
    where
        S: Send + Sync,
        String: FromRef<S> + FromRequest<()>,
    {
        type Rejection = <String as FromRequest<()>>::Rejection;

        async fn from_request(mut req: Request, state: &S) -> Result<Self, Self::Rejection> {
            // Head-only extractors go first because they merely borrow the request.
            let RequiresState(from_state) = req
                .extract_parts_with_state::<RequiresState, _>(state)
                .await
                .unwrap();
            let method = req.extract_parts::<Method>().await.unwrap();

            // The body extractor consumes the request, so it has to come last.
            let body = req.extract::<String, _>().await?;

            Ok(Self {
                method,
                from_state,
                body,
            })
        }
    }
}
437}