axum_core/ext_traits/request_parts.rs

1use crate::extract::FromRequestParts;
2use http::request::Parts;
3use std::future::Future;
4
5mod sealed {
6    pub trait Sealed {}
7    impl Sealed for http::request::Parts {}
8}
9
10/// Extension trait that adds additional methods to [`Parts`].
11pub trait RequestPartsExt: sealed::Sealed + Sized {
12    /// Apply an extractor to this `Parts`.
13    ///
14    /// This is just a convenience for `E::from_request_parts(parts, &())`.
15    ///
16    /// # Example
17    ///
18    /// ```
19    /// use axum::{
20    ///     extract::{Query, Path, FromRequestParts},
21    ///     response::{Response, IntoResponse},
22    ///     http::request::Parts,
23    ///     RequestPartsExt,
24    /// };
25    /// use std::collections::HashMap;
26    ///
27    /// struct MyExtractor {
28    ///     path_params: HashMap<String, String>,
29    ///     query_params: HashMap<String, String>,
30    /// }
31    ///
32    /// impl<S> FromRequestParts<S> for MyExtractor
33    /// where
34    ///     S: Send + Sync,
35    /// {
36    ///     type Rejection = Response;
37    ///
38    ///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
39    ///         let path_params = parts
40    ///             .extract::<Path<HashMap<String, String>>>()
41    ///             .await
42    ///             .map(|Path(path_params)| path_params)
43    ///             .map_err(|err| err.into_response())?;
44    ///
45    ///         let query_params = parts
46    ///             .extract::<Query<HashMap<String, String>>>()
47    ///             .await
48    ///             .map(|Query(params)| params)
49    ///             .map_err(|err| err.into_response())?;
50    ///
51    ///         Ok(MyExtractor { path_params, query_params })
52    ///     }
53    /// }
54    /// ```
55    fn extract<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
56    where
57        E: FromRequestParts<()> + 'static;
58
59    /// Apply an extractor that requires some state to this `Parts`.
60    ///
61    /// This is just a convenience for `E::from_request_parts(parts, state)`.
62    ///
63    /// # Example
64    ///
65    /// ```
66    /// use axum::{
67    ///     extract::{FromRef, FromRequestParts},
68    ///     response::{Response, IntoResponse},
69    ///     http::request::Parts,
70    ///     RequestPartsExt,
71    /// };
72    ///
73    /// struct MyExtractor {
74    ///     requires_state: RequiresState,
75    /// }
76    ///
77    /// impl<S> FromRequestParts<S> for MyExtractor
78    /// where
79    ///     String: FromRef<S>,
80    ///     S: Send + Sync,
81    /// {
82    ///     type Rejection = std::convert::Infallible;
83    ///
84    ///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
85    ///         let requires_state = parts
86    ///             .extract_with_state::<RequiresState, _>(state)
87    ///             .await?;
88    ///
89    ///         Ok(MyExtractor { requires_state })
90    ///     }
91    /// }
92    ///
93    /// struct RequiresState { /* ... */ }
94    ///
95    /// // some extractor that requires a `String` in the state
96    /// impl<S> FromRequestParts<S> for RequiresState
97    /// where
98    ///     String: FromRef<S>,
99    ///     S: Send + Sync,
100    /// {
101    ///     // ...
102    ///     # type Rejection = std::convert::Infallible;
103    ///     # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
104    ///     #     unimplemented!()
105    ///     # }
106    /// }
107    /// ```
108    fn extract_with_state<'a, E, S>(
109        &'a mut self,
110        state: &'a S,
111    ) -> impl Future<Output = Result<E, E::Rejection>> + Send + 'a
112    where
113        E: FromRequestParts<S> + 'static,
114        S: Send + Sync;
115}
116
117impl RequestPartsExt for Parts {
118    fn extract<E>(&mut self) -> impl Future<Output = Result<E, E::Rejection>> + Send
119    where
120        E: FromRequestParts<()> + 'static,
121    {
122        self.extract_with_state(&())
123    }
124
125    fn extract_with_state<'a, E, S>(
126        &'a mut self,
127        state: &'a S,
128    ) -> impl Future<Output = Result<E, E::Rejection>> + Send + 'a
129    where
130        E: FromRequestParts<S> + 'static,
131        S: Send + Sync,
132    {
133        E::from_request_parts(self, state)
134    }
135}
136
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;
    use crate::{
        ext_traits::tests::{RequiresState, State},
        extract::FromRef,
    };
    use http::{Method, Request};

    /// Request parts taken from `Request::new(())`, whose method is `GET`.
    fn empty_parts() -> Parts {
        Request::new(()).into_parts().0
    }

    #[tokio::test]
    async fn extract_without_state() {
        let mut parts = empty_parts();

        let method = parts.extract::<Method>().await.unwrap();

        assert_eq!(Method::GET, method);
    }

    #[tokio::test]
    async fn extract_with_state() {
        let mut parts = empty_parts();
        let state = String::from("state");

        // The state type is inferred from `&state` (i.e. `String`).
        let State(extracted_state) = parts
            .extract_with_state::<State<String>, _>(&state)
            .await
            .unwrap();

        assert_eq!(state, extracted_state);
    }

    // Never constructed at runtime. It only has to type-check, to prove that both
    // extension methods are usable from inside a custom extractor; hence the
    // `dead_code` allowance.
    #[allow(dead_code)]
    struct WorksForCustomExtractor {
        method: Method,
        from_state: String,
    }

    impl<S> FromRequestParts<S> for WorksForCustomExtractor
    where
        S: Send + Sync,
        String: FromRef<S>,
    {
        type Rejection = Infallible;

        async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
            let from_state = parts
                .extract_with_state::<RequiresState, _>(state)
                .await?
                .0;
            let method = parts.extract::<Method>().await?;

            Ok(Self { method, from_state })
        }
    }
}