axum_core/ext_traits/
request.rs

1use crate::body::Body;
2use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts, Request};
3use std::future::Future;
4
5mod sealed {
6    pub trait Sealed {}
7    impl Sealed for http::Request<crate::body::Body> {}
8}
9
/// Extension trait that adds additional methods to [`Request`].
///
/// The trait is sealed: its `Sealed` supertrait lives in a private module, so it cannot be
/// implemented for other types.
pub trait RequestExt: sealed::Sealed + Sized {
    /// Apply an extractor to this `Request`.
    ///
    /// This is just a convenience for `E::from_request(req, &())`.
    ///
    /// Note this consumes the request. Use [`RequestExt::extract_parts`] if you're not extracting
    /// the body and don't want to consume the request.
    ///
    /// The second type parameter `M` is passed through to [`FromRequest`]; as the example
    /// shows, it can be left to inference by writing `_`.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     extract::{Request, FromRequest},
    ///     body::Body,
    ///     http::{header::CONTENT_TYPE, StatusCode},
    ///     response::{IntoResponse, Response},
    ///     Form, Json, RequestExt,
    /// };
    ///
    /// struct FormOrJson<T>(T);
    ///
    /// impl<S, T> FromRequest<S> for FormOrJson<T>
    /// where
    ///     Json<T>: FromRequest<()>,
    ///     Form<T>: FromRequest<()>,
    ///     T: 'static,
    ///     S: Send + Sync,
    /// {
    ///     type Rejection = Response;
    ///
    ///     async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
    ///         let content_type = req
    ///             .headers()
    ///             .get(CONTENT_TYPE)
    ///             .and_then(|value| value.to_str().ok())
    ///             .ok_or_else(|| StatusCode::BAD_REQUEST.into_response())?;
    ///
    ///         if content_type.starts_with("application/json") {
    ///             let Json(payload) = req
    ///                 .extract::<Json<T>, _>()
    ///                 .await
    ///                 .map_err(|err| err.into_response())?;
    ///
    ///             Ok(Self(payload))
    ///         } else if content_type.starts_with("application/x-www-form-urlencoded") {
    ///             let Form(payload) = req
    ///                 .extract::<Form<T>, _>()
    ///                 .await
    ///                 .map_err(|err| err.into_response())?;
    ///
    ///             Ok(Self(payload))
    ///         } else {
    ///             Err(StatusCode::BAD_REQUEST.into_response())
    ///         }
    ///     }
    /// }
    /// ```
    fn extract<E, M>(self) -> impl Future<Output = Result<E, E::Rejection>> + Send
    where
        E: FromRequest<(), M> + 'static,
        M: 'static;

    /// Apply an extractor that requires some state to this `Request`.
    ///
    /// This is just a convenience for `E::from_request(req, state)`.
    ///
    /// Note this consumes the request. Use [`RequestExt::extract_parts_with_state`] if you're not
    /// extracting the body and don't want to consume the request.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     body::Body,
    ///     extract::{Request, FromRef, FromRequest},
    ///     RequestExt,
    /// };
    ///
    /// struct MyExtractor {
    ///     requires_state: RequiresState,
    /// }
    ///
    /// impl<S> FromRequest<S> for MyExtractor
    /// where
    ///     String: FromRef<S>,
    ///     S: Send + Sync,
    /// {
    ///     type Rejection = std::convert::Infallible;
    ///
    ///     async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
    ///         let requires_state = req.extract_with_state::<RequiresState, _, _>(state).await?;
    ///
    ///         Ok(Self { requires_state })
    ///     }
    /// }
    ///
    /// // some extractor that consumes the request body and requires state
    /// struct RequiresState { /* ... */ }
    ///
    /// impl<S> FromRequest<S> for RequiresState
    /// where
    ///     String: FromRef<S>,
    ///     S: Send + Sync,
    /// {
    ///     // ...
    ///     # type Rejection = std::convert::Infallible;
    ///     # async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
    ///     #     todo!()
    ///     # }
    /// }
    /// ```
    fn extract_with_state<E, S, M>(
        self,
        state: &S,
    ) -> impl Future<Output = Result<E, E::Rejection>> + Send
    where
        E: FromRequest<S, M> + 'static,
        S: Send + Sync;

    /// Apply a parts extractor to this `Request`.
    ///
    /// This is just a convenience for `E::from_request_parts(parts, &())`.
    ///
    /// Unlike [`RequestExt::extract`], this only borrows the request, so the request (and its
    /// body) can still be used afterwards, as the example shows.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     extract::{Path, Request, FromRequest},
    ///     response::{IntoResponse, Response},
    ///     body::Body,
    ///     Json, RequestExt,
    /// };
    /// use std::collections::HashMap;
    ///
    /// struct MyExtractor<T> {
    ///     path_params: HashMap<String, String>,
    ///     payload: T,
    /// }
    ///
    /// impl<S, T> FromRequest<S> for MyExtractor<T>
    /// where
    ///     S: Send + Sync,
    ///     Json<T>: FromRequest<()>,
    ///     T: 'static,
    /// {
    ///     type Rejection = Response;
    ///
    ///     async fn from_request(mut req: Request, _state: &S) -> Result<Self, Self::Rejection> {
    ///         let path_params = req
    ///             .extract_parts::<Path<_>>()
    ///             .await
    ///             .map(|Path(path_params)| path_params)
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         let Json(payload) = req
    ///             .extract::<Json<T>, _>()
    ///             .await
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         Ok(Self { path_params, payload })
    ///     }
    /// }
    /// ```
    fn extract_parts<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
    where
        E: FromRequestParts<()> + 'static;

    /// Apply a parts extractor that requires some state to this `Request`.
    ///
    /// This is just a convenience for `E::from_request_parts(parts, state)`.
    ///
    /// Unlike [`RequestExt::extract_with_state`], this only borrows the request, so the
    /// request (and its body) can still be used afterwards, as the example shows.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     extract::{Request, FromRef, FromRequest, FromRequestParts},
    ///     http::request::Parts,
    ///     response::{IntoResponse, Response},
    ///     body::Body,
    ///     Json, RequestExt,
    /// };
    ///
    /// struct MyExtractor<T> {
    ///     requires_state: RequiresState,
    ///     payload: T,
    /// }
    ///
    /// impl<S, T> FromRequest<S> for MyExtractor<T>
    /// where
    ///     String: FromRef<S>,
    ///     Json<T>: FromRequest<()>,
    ///     T: 'static,
    ///     S: Send + Sync,
    /// {
    ///     type Rejection = Response;
    ///
    ///     async fn from_request(mut req: Request, state: &S) -> Result<Self, Self::Rejection> {
    ///         let requires_state = req
    ///             .extract_parts_with_state::<RequiresState, _>(state)
    ///             .await
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         let Json(payload) = req
    ///             .extract::<Json<T>, _>()
    ///             .await
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         Ok(Self {
    ///             requires_state,
    ///             payload,
    ///         })
    ///     }
    /// }
    ///
    /// struct RequiresState {}
    ///
    /// impl<S> FromRequestParts<S> for RequiresState
    /// where
    ///     String: FromRef<S>,
    ///     S: Send + Sync,
    /// {
    ///     // ...
    ///     # type Rejection = std::convert::Infallible;
    ///     # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
    ///     #     todo!()
    ///     # }
    /// }
    /// ```
    fn extract_parts_with_state<'a, E, S>(
        &'a mut self,
        state: &'a S,
    ) -> impl Future<Output = Result<E, E::Rejection>> + Send + 'a
    where
        E: FromRequestParts<S> + 'static,
        S: Send + Sync;

    /// Apply the [default body limit](crate::extract::DefaultBodyLimit).
    ///
    /// If it is disabled, the request is returned as-is.
    fn with_limited_body(self) -> Request;

    /// Consumes the request, returning the body wrapped in [`http_body_util::Limited`] if a
    /// [default limit](crate::extract::DefaultBodyLimit) is in place, or not wrapped if the
    /// default limit is disabled.
    fn into_limited_body(self) -> Body;
}
261
262impl RequestExt for Request {
263    fn extract<E, M>(self) -> impl Future<Output = Result<E, E::Rejection>> + Send
264    where
265        E: FromRequest<(), M> + 'static,
266        M: 'static,
267    {
268        self.extract_with_state(&())
269    }
270
271    fn extract_with_state<E, S, M>(
272        self,
273        state: &S,
274    ) -> impl Future<Output = Result<E, E::Rejection>> + Send
275    where
276        E: FromRequest<S, M> + 'static,
277        S: Send + Sync,
278    {
279        E::from_request(self, state)
280    }
281
282    fn extract_parts<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
283    where
284        E: FromRequestParts<()> + 'static,
285    {
286        self.extract_parts_with_state(&())
287    }
288
289    async fn extract_parts_with_state<'a, E, S>(
290        &'a mut self,
291        state: &'a S,
292    ) -> Result<E, E::Rejection>
293    where
294        E: FromRequestParts<S> + 'static,
295        S: Send + Sync,
296    {
297        let mut req = Request::new(());
298        *req.version_mut() = self.version();
299        *req.method_mut() = self.method().clone();
300        *req.uri_mut() = self.uri().clone();
301        *req.headers_mut() = std::mem::take(self.headers_mut());
302        *req.extensions_mut() = std::mem::take(self.extensions_mut());
303        let (mut parts, ()) = req.into_parts();
304
305        let result = E::from_request_parts(&mut parts, state).await;
306
307        *self.version_mut() = parts.version;
308        *self.method_mut() = parts.method.clone();
309        *self.uri_mut() = parts.uri.clone();
310        *self.headers_mut() = std::mem::take(&mut parts.headers);
311        *self.extensions_mut() = std::mem::take(&mut parts.extensions);
312
313        result
314    }
315
316    fn with_limited_body(self) -> Request {
317        // update docs in `axum-core/src/extract/default_body_limit.rs` and
318        // `axum/src/docs/extract.md` if this changes
319        const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb
320
321        match self.extensions().get::<DefaultBodyLimitKind>().copied() {
322            Some(DefaultBodyLimitKind::Disable) => self,
323            Some(DefaultBodyLimitKind::Limit(limit)) => {
324                self.map(|b| Body::new(http_body_util::Limited::new(b, limit)))
325            }
326            None => self.map(|b| Body::new(http_body_util::Limited::new(b, DEFAULT_LIMIT))),
327        }
328    }
329
330    fn into_limited_body(self) -> Body {
331        self.with_limited_body().into_body()
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ext_traits::tests::{RequiresState, State},
        extract::FromRef,
    };
    use http::Method;

    /// Builds a request with an empty body and a single `x-foo: foo` header.
    fn request_with_foo_header() -> Request {
        Request::builder()
            .header("x-foo", "foo")
            .body(Body::empty())
            .unwrap()
    }

    // `extract` works for an extractor that only looks at the request head.
    #[tokio::test]
    async fn extract_without_state() {
        let request = Request::new(Body::empty());

        let extracted: Method = request.extract().await.unwrap();

        assert_eq!(extracted, Method::GET);
    }

    // `extract` works for an extractor that consumes the body.
    #[tokio::test]
    async fn extract_body_without_state() {
        let request = Request::new(Body::from("foobar"));

        let text: String = request.extract().await.unwrap();

        assert_eq!(text, "foobar");
    }

    // `extract_with_state` passes the given state through to the extractor.
    #[tokio::test]
    async fn extract_with_state() {
        let state = "state".to_owned();
        let request = Request::new(Body::empty());

        let State(from_request): State<String> =
            request.extract_with_state(&state).await.unwrap();

        assert_eq!(from_request, state);
    }

    // `extract_parts` must leave the request's headers in place afterwards.
    #[tokio::test]
    async fn extract_parts_without_state() {
        let mut request = request_with_foo_header();

        let extracted: Method = request.extract_parts().await.unwrap();

        assert_eq!(extracted, Method::GET);
        assert_eq!(request.headers()["x-foo"], "foo");
    }

    // `extract_parts_with_state` sees the state and leaves the headers in place afterwards.
    #[tokio::test]
    async fn extract_parts_with_state() {
        let state = "state".to_owned();
        let mut request = request_with_foo_header();

        let State(from_request): State<String> =
            request.extract_parts_with_state(&state).await.unwrap();

        assert_eq!(from_request, state);
        assert_eq!(request.headers()["x-foo"], "foo");
    }

    // this stuff just needs to compile
    #[allow(dead_code)]
    struct WorksForCustomExtractor {
        method: Method,
        from_state: String,
        body: String,
    }

    // Uses every flavour of `RequestExt` extraction from inside a custom extractor.
    impl<S> FromRequest<S> for WorksForCustomExtractor
    where
        S: Send + Sync,
        String: FromRef<S> + FromRequest<()>,
    {
        type Rejection = <String as FromRequest<()>>::Rejection;

        async fn from_request(mut req: Request, state: &S) -> Result<Self, Self::Rejection> {
            // Parts extractors first (they only borrow `req`), the body extractor last
            // because it consumes the request.
            let RequiresState(from_state) = req.extract_parts_with_state(state).await.unwrap();
            let method = req.extract_parts().await.unwrap();
            let body = req.extract().await?;

            Ok(Self { method, from_state, body })
        }
    }
}
429}