axum_core/ext_traits/request_parts.rs
1use crate::extract::FromRequestParts;
2use http::request::Parts;
3use std::future::Future;
4
5mod sealed {
6 pub trait Sealed {}
7 impl Sealed for http::request::Parts {}
8}
9
/// Extension trait that adds additional methods to [`Parts`].
///
/// The trait is sealed: its `sealed::Sealed` supertrait lives in a private
/// module, so it can only be implemented inside this crate.
pub trait RequestPartsExt: sealed::Sealed + Sized {
    /// Apply an extractor to this `Parts`.
    ///
    /// This is just a convenience for `E::from_request_parts(parts, &())`: the
    /// extractor is run with the unit type `()` as its state. It therefore only
    /// accepts extractors that implement [`FromRequestParts`] for the state `()`.
    /// Use [`extract_with_state`](RequestPartsExt::extract_with_state) for
    /// extractors that need access to application state.
    ///
    /// # Errors
    ///
    /// The returned future resolves to `Err(E::Rejection)` if the extractor
    /// rejects the request parts.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     extract::{Query, Path, FromRequestParts},
    ///     response::{Response, IntoResponse},
    ///     http::request::Parts,
    ///     RequestPartsExt,
    /// };
    /// use std::collections::HashMap;
    ///
    /// struct MyExtractor {
    ///     path_params: HashMap<String, String>,
    ///     query_params: HashMap<String, String>,
    /// }
    ///
    /// impl<S> FromRequestParts<S> for MyExtractor
    /// where
    ///     S: Send + Sync,
    /// {
    ///     type Rejection = Response;
    ///
    ///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
    ///         let path_params = parts
    ///             .extract::<Path<HashMap<String, String>>>()
    ///             .await
    ///             .map(|Path(path_params)| path_params)
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         let query_params = parts
    ///             .extract::<Query<HashMap<String, String>>>()
    ///             .await
    ///             .map(|Query(params)| params)
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         Ok(MyExtractor { path_params, query_params })
    ///     }
    /// }
    /// ```
    fn extract<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
    where
        E: FromRequestParts<()> + 'static;

    /// Apply an extractor that requires some state to this `Parts`.
    ///
    /// This is just a convenience for `E::from_request_parts(parts, state)`.
    /// The returned future holds the mutable borrow of `self` and the shared
    /// borrow of `state` for the lifetime `'a`.
    ///
    /// # Errors
    ///
    /// The returned future resolves to `Err(E::Rejection)` if the extractor
    /// rejects the request parts.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     extract::{FromRef, FromRequestParts},
    ///     response::{Response, IntoResponse},
    ///     http::request::Parts,
    ///     RequestPartsExt,
    /// };
    ///
    /// struct MyExtractor {
    ///     requires_state: RequiresState,
    /// }
    ///
    /// impl<S> FromRequestParts<S> for MyExtractor
    /// where
    ///     String: FromRef<S>,
    ///     S: Send + Sync,
    /// {
    ///     type Rejection = std::convert::Infallible;
    ///
    ///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
    ///         let requires_state = parts
    ///             .extract_with_state::<RequiresState, _>(state)
    ///             .await?;
    ///
    ///         Ok(MyExtractor { requires_state })
    ///     }
    /// }
    ///
    /// struct RequiresState { /* ... */ }
    ///
    /// // some extractor that requires a `String` in the state
    /// impl<S> FromRequestParts<S> for RequiresState
    /// where
    ///     String: FromRef<S>,
    ///     S: Send + Sync,
    /// {
    ///     // ...
    /// # type Rejection = std::convert::Infallible;
    /// # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
    /// #     unimplemented!()
    /// # }
    /// }
    /// ```
    fn extract_with_state<'a, E, S>(
        &'a mut self,
        state: &'a S,
    ) -> impl Future<Output = Result<E, E::Rejection>> + Send + 'a
    where
        E: FromRequestParts<S> + 'static,
        S: Send + Sync;
}
116
117impl RequestPartsExt for Parts {
118 fn extract<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
119 where
120 E: FromRequestParts<()> + 'static,
121 {
122 self.extract_with_state(&())
123 }
124
125 fn extract_with_state<'a, E, S>(
126 &'a mut self,
127 state: &'a S,
128 ) -> impl Future<Output = Result<E, E::Rejection>> + Send + 'a
129 where
130 E: FromRequestParts<S> + 'static,
131 S: Send + Sync,
132 {
133 E::from_request_parts(self, state)
134 }
135}
136
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;
    use crate::{
        ext_traits::tests::{RequiresState, State},
        extract::FromRef,
    };
    use http::{Method, Request};

    /// Request parts taken from a freshly constructed `Request<()>`; the body is discarded.
    fn empty_parts() -> Parts {
        Request::new(()).into_parts().0
    }

    #[tokio::test]
    async fn extract_without_state() {
        let mut parts = empty_parts();

        let got = parts.extract::<Method>().await.unwrap();

        assert_eq!(got, Method::GET);
    }

    #[tokio::test]
    async fn extract_with_state() {
        let mut parts = empty_parts();
        let expected = String::from("state");

        let State(got) = parts
            .extract_with_state::<State<String>, String>(&expected)
            .await
            .unwrap();

        assert_eq!(got, expected);
    }

    // Compile-only check: both extension methods are usable from inside a
    // generic `FromRequestParts` impl, with rejections propagated via `?`.
    #[allow(dead_code)] // never constructed; it only has to type-check
    struct WorksForCustomExtractor {
        method: Method,
        from_state: String,
    }

    impl<S> FromRequestParts<S> for WorksForCustomExtractor
    where
        S: Send + Sync,
        String: FromRef<S>,
    {
        type Rejection = Infallible;

        async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
            let from_state = parts
                .extract_with_state::<RequiresState, S>(state)
                .await?
                .0;
            let method = parts.extract::<Method>().await?;
            Ok(Self { method, from_state })
        }
    }
}