axum_core/ext_traits/request.rs
1use crate::body::Body;
2use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts, Request};
3use std::future::Future;
4
5mod sealed {
6 pub trait Sealed {}
7 impl Sealed for http::Request<crate::body::Body> {}
8}
9
10/// Extension trait that adds additional methods to [`Request`].
11pub trait RequestExt: sealed::Sealed + Sized {
12 /// Apply an extractor to this `Request`.
13 ///
14 /// This is just a convenience for `E::from_request(req, &())`.
15 ///
16 /// Note this consumes the request. Use [`RequestExt::extract_parts`] if you're not extracting
17 /// the body and don't want to consume the request.
18 ///
19 /// # Example
20 ///
21 /// ```
22 /// use axum::{
23 /// extract::{Request, FromRequest},
24 /// body::Body,
25 /// http::{header::CONTENT_TYPE, StatusCode},
26 /// response::{IntoResponse, Response},
27 /// Form, Json, RequestExt,
28 /// };
29 ///
30 /// struct FormOrJson<T>(T);
31 ///
32 /// impl<S, T> FromRequest<S> for FormOrJson<T>
33 /// where
34 /// Json<T>: FromRequest<()>,
35 /// Form<T>: FromRequest<()>,
36 /// T: 'static,
37 /// S: Send + Sync,
38 /// {
39 /// type Rejection = Response;
40 ///
41 /// async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
42 /// let content_type = req
43 /// .headers()
44 /// .get(CONTENT_TYPE)
45 /// .and_then(|value| value.to_str().ok())
46 /// .ok_or_else(|| StatusCode::BAD_REQUEST.into_response())?;
47 ///
48 /// if content_type.starts_with("application/json") {
49 /// let Json(payload) = req
50 /// .extract::<Json<T>, _>()
51 /// .await
52 /// .map_err(|err| err.into_response())?;
53 ///
54 /// Ok(Self(payload))
55 /// } else if content_type.starts_with("application/x-www-form-urlencoded") {
56 /// let Form(payload) = req
57 /// .extract::<Form<T>, _>()
58 /// .await
59 /// .map_err(|err| err.into_response())?;
60 ///
61 /// Ok(Self(payload))
62 /// } else {
63 /// Err(StatusCode::BAD_REQUEST.into_response())
64 /// }
65 /// }
66 /// }
67 /// ```
68 fn extract<E, M>(self) -> impl Future<Output = Result<E, E::Rejection>> + Send
69 where
70 E: FromRequest<(), M> + 'static,
71 M: 'static;
72
73 /// Apply an extractor that requires some state to this `Request`.
74 ///
75 /// This is just a convenience for `E::from_request(req, state)`.
76 ///
77 /// Note this consumes the request. Use [`RequestExt::extract_parts_with_state`] if you're not
78 /// extracting the body and don't want to consume the request.
79 ///
80 /// # Example
81 ///
82 /// ```
83 /// use axum::{
84 /// body::Body,
85 /// extract::{Request, FromRef, FromRequest},
86 /// RequestExt,
87 /// };
88 ///
89 /// struct MyExtractor {
90 /// requires_state: RequiresState,
91 /// }
92 ///
93 /// impl<S> FromRequest<S> for MyExtractor
94 /// where
95 /// String: FromRef<S>,
96 /// S: Send + Sync,
97 /// {
98 /// type Rejection = std::convert::Infallible;
99 ///
100 /// async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
101 /// let requires_state = req.extract_with_state::<RequiresState, _, _>(state).await?;
102 ///
103 /// Ok(Self { requires_state })
104 /// }
105 /// }
106 ///
107 /// // some extractor that consumes the request body and requires state
108 /// struct RequiresState { /* ... */ }
109 ///
110 /// impl<S> FromRequest<S> for RequiresState
111 /// where
112 /// String: FromRef<S>,
113 /// S: Send + Sync,
114 /// {
115 /// // ...
116 /// # type Rejection = std::convert::Infallible;
117 /// # async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
118 /// # todo!()
119 /// # }
120 /// }
121 /// ```
122 fn extract_with_state<E, S, M>(
123 self,
124 state: &S,
125 ) -> impl Future<Output = Result<E, E::Rejection>> + Send
126 where
127 E: FromRequest<S, M> + 'static,
128 S: Send + Sync;
129
130 /// Apply a parts extractor to this `Request`.
131 ///
132 /// This is just a convenience for `E::from_request_parts(parts, state)`.
133 ///
134 /// # Example
135 ///
136 /// ```
137 /// use axum::{
138 /// extract::{Path, Request, FromRequest},
139 /// response::{IntoResponse, Response},
140 /// body::Body,
141 /// Json, RequestExt,
142 /// };
143 /// use axum_extra::{
144 /// TypedHeader,
145 /// headers::{authorization::Bearer, Authorization},
146 /// };
147 /// use std::collections::HashMap;
148 ///
149 /// struct MyExtractor<T> {
150 /// path_params: HashMap<String, String>,
151 /// payload: T,
152 /// }
153 ///
154 /// impl<S, T> FromRequest<S> for MyExtractor<T>
155 /// where
156 /// S: Send + Sync,
157 /// Json<T>: FromRequest<()>,
158 /// T: 'static,
159 /// {
160 /// type Rejection = Response;
161 ///
162 /// async fn from_request(mut req: Request, _state: &S) -> Result<Self, Self::Rejection> {
163 /// let path_params = req
164 /// .extract_parts::<Path<_>>()
165 /// .await
166 /// .map(|Path(path_params)| path_params)
167 /// .map_err(|err| err.into_response())?;
168 ///
169 /// let Json(payload) = req
170 /// .extract::<Json<T>, _>()
171 /// .await
172 /// .map_err(|err| err.into_response())?;
173 ///
174 /// Ok(Self { path_params, payload })
175 /// }
176 /// }
177 /// ```
178 fn extract_parts<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
179 where
180 E: FromRequestParts<()> + 'static;
181
182 /// Apply a parts extractor that requires some state to this `Request`.
183 ///
184 /// This is just a convenience for `E::from_request_parts(parts, state)`.
185 ///
186 /// # Example
187 ///
188 /// ```
189 /// use axum::{
190 /// extract::{Request, FromRef, FromRequest, FromRequestParts},
191 /// http::request::Parts,
192 /// response::{IntoResponse, Response},
193 /// body::Body,
194 /// Json, RequestExt,
195 /// };
196 ///
197 /// struct MyExtractor<T> {
198 /// requires_state: RequiresState,
199 /// payload: T,
200 /// }
201 ///
202 /// impl<S, T> FromRequest<S> for MyExtractor<T>
203 /// where
204 /// String: FromRef<S>,
205 /// Json<T>: FromRequest<()>,
206 /// T: 'static,
207 /// S: Send + Sync,
208 /// {
209 /// type Rejection = Response;
210 ///
211 /// async fn from_request(mut req: Request, state: &S) -> Result<Self, Self::Rejection> {
212 /// let requires_state = req
213 /// .extract_parts_with_state::<RequiresState, _>(state)
214 /// .await
215 /// .map_err(|err| err.into_response())?;
216 ///
217 /// let Json(payload) = req
218 /// .extract::<Json<T>, _>()
219 /// .await
220 /// .map_err(|err| err.into_response())?;
221 ///
222 /// Ok(Self {
223 /// requires_state,
224 /// payload,
225 /// })
226 /// }
227 /// }
228 ///
229 /// struct RequiresState {}
230 ///
231 /// impl<S> FromRequestParts<S> for RequiresState
232 /// where
233 /// String: FromRef<S>,
234 /// S: Send + Sync,
235 /// {
236 /// // ...
237 /// # type Rejection = std::convert::Infallible;
238 /// # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
239 /// # todo!()
240 /// # }
241 /// }
242 /// ```
243 fn extract_parts_with_state<'a, E, S>(
244 &'a mut self,
245 state: &'a S,
246 ) -> impl Future<Output = Result<E, E::Rejection>> + Send + 'a
247 where
248 E: FromRequestParts<S> + 'static,
249 S: Send + Sync;
250
251 /// Apply the [default body limit](crate::extract::DefaultBodyLimit).
252 ///
253 /// If it is disabled, the request is returned as-is.
254 fn with_limited_body(self) -> Request;
255
256 /// Consumes the request, returning the body wrapped in [`http_body_util::Limited`] if a
257 /// [default limit](crate::extract::DefaultBodyLimit) is in place, or not wrapped if the
258 /// default limit is disabled.
259 fn into_limited_body(self) -> Body;
260}
261
262impl RequestExt for Request {
263 fn extract<E, M>(self) -> impl Future<Output = Result<E, E::Rejection>> + Send
264 where
265 E: FromRequest<(), M> + 'static,
266 M: 'static,
267 {
268 self.extract_with_state(&())
269 }
270
271 fn extract_with_state<E, S, M>(
272 self,
273 state: &S,
274 ) -> impl Future<Output = Result<E, E::Rejection>> + Send
275 where
276 E: FromRequest<S, M> + 'static,
277 S: Send + Sync,
278 {
279 E::from_request(self, state)
280 }
281
282 fn extract_parts<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
283 where
284 E: FromRequestParts<()> + 'static,
285 {
286 self.extract_parts_with_state(&())
287 }
288
289 async fn extract_parts_with_state<'a, E, S>(
290 &'a mut self,
291 state: &'a S,
292 ) -> Result<E, E::Rejection>
293 where
294 E: FromRequestParts<S> + 'static,
295 S: Send + Sync,
296 {
297 let mut req = Request::new(());
298 *req.version_mut() = self.version();
299 *req.method_mut() = self.method().clone();
300 *req.uri_mut() = self.uri().clone();
301 *req.headers_mut() = std::mem::take(self.headers_mut());
302 *req.extensions_mut() = std::mem::take(self.extensions_mut());
303 let (mut parts, ()) = req.into_parts();
304
305 let result = E::from_request_parts(&mut parts, state).await;
306
307 *self.version_mut() = parts.version;
308 *self.method_mut() = parts.method.clone();
309 *self.uri_mut() = parts.uri.clone();
310 *self.headers_mut() = std::mem::take(&mut parts.headers);
311 *self.extensions_mut() = std::mem::take(&mut parts.extensions);
312
313 result
314 }
315
316 fn with_limited_body(self) -> Request {
317 // update docs in `axum-core/src/extract/default_body_limit.rs` and
318 // `axum/src/docs/extract.md` if this changes
319 const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb
320
321 match self.extensions().get::<DefaultBodyLimitKind>().copied() {
322 Some(DefaultBodyLimitKind::Disable) => self,
323 Some(DefaultBodyLimitKind::Limit(limit)) => {
324 self.map(|b| Body::new(http_body_util::Limited::new(b, limit)))
325 }
326 None => self.map(|b| Body::new(http_body_util::Limited::new(b, DEFAULT_LIMIT))),
327 }
328 }
329
330 fn into_limited_body(self) -> Body {
331 self.with_limited_body().into_body()
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ext_traits::tests::{RequiresState, State},
        extract::FromRef,
    };
    use http::Method;

    /// Builds a request with an empty body and an `x-foo: foo` header, so tests can check that
    /// headers survive a parts extraction.
    fn request_with_x_foo() -> Request {
        Request::builder()
            .header("x-foo", "foo")
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn extract_without_state() {
        let method = Request::new(Body::empty())
            .extract::<Method, _>()
            .await
            .unwrap();

        assert_eq!(method, Method::GET);
    }

    #[tokio::test]
    async fn extract_body_without_state() {
        let text = Request::new(Body::from("foobar"))
            .extract::<String, _>()
            .await
            .unwrap();

        assert_eq!(text, "foobar");
    }

    #[tokio::test]
    async fn extract_with_state() {
        let expected = String::from("state");

        let State(actual) = Request::new(Body::empty())
            .extract_with_state::<State<String>, _, _>(&expected)
            .await
            .unwrap();

        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn extract_parts_without_state() {
        let mut request = request_with_x_foo();

        let method = request.extract_parts::<Method>().await.unwrap();

        assert_eq!(method, Method::GET);
        // The request is only borrowed, so its headers must still be there.
        assert_eq!(request.headers()["x-foo"], "foo");
    }

    #[tokio::test]
    async fn extract_parts_with_state() {
        let mut request = request_with_x_foo();
        let expected = String::from("state");

        let State(actual) = request
            .extract_parts_with_state::<State<String>, _>(&expected)
            .await
            .unwrap();

        assert_eq!(actual, expected);
        // The request is only borrowed, so its headers must still be there.
        assert_eq!(request.headers()["x-foo"], "foo");
    }

    // Compile-only check: a custom extractor can chain a stateful parts extraction, a stateless
    // parts extraction and a body extraction on the same request. The fields are never read,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    struct WorksForCustomExtractor {
        method: Method,
        from_state: String,
        body: String,
    }

    impl<S> FromRequest<S> for WorksForCustomExtractor
    where
        S: Send + Sync,
        String: FromRef<S> + FromRequest<()>,
    {
        type Rejection = <String as FromRequest<()>>::Rejection;

        async fn from_request(mut req: Request, state: &S) -> Result<Self, Self::Rejection> {
            let RequiresState(from_state) = req.extract_parts_with_state(state).await.unwrap();
            let method = req.extract_parts().await.unwrap();
            let body = req.extract().await?;

            Ok(Self {
                method,
                from_state,
                body,
            })
        }
    }
}