axum/
form.rs

1use crate::extract::Request;
2use crate::extract::{rejection::*, FromRequest, RawForm};
3use async_trait::async_trait;
4use axum_core::response::{IntoResponse, Response};
5use axum_core::RequestExt;
6use http::header::CONTENT_TYPE;
7use http::StatusCode;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10
/// URL encoded extractor and response.
///
/// # As extractor
///
/// If used as an extractor, `Form` will deserialize form data from the request,
/// specifically:
///
/// - If the request has a method of `GET` or `HEAD`, the form data will be read
///   from the query string (same as with [`Query`]).
/// - If the request has any other method, the form data will be read from the
///   body of the request. It must have a `content-type` of
///   `application/x-www-form-urlencoded` for this to work. If you want to parse
///   `multipart/form-data` request bodies, use [`Multipart`] instead.
///
/// This matches how HTML forms are sent by browsers by default.
/// In both cases, the inner type `T` must implement [`serde::de::DeserializeOwned`]
/// (for example by deriving [`serde::Deserialize`] on a type without borrowed fields).
///
/// ⚠️ Since parsing form data might require consuming the request body, the `Form` extractor must be
/// *last* if there are multiple extractors in a handler. See ["the order of
/// extractors"][order-of-extractors].
///
/// [order-of-extractors]: crate::extract#the-order-of-extractors
///
/// ```rust
/// use axum::Form;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct SignUp {
///     username: String,
///     password: String,
/// }
///
/// async fn accept_form(Form(sign_up): Form<SignUp>) {
///     // ...
/// }
/// ```
///
/// ## Rejections
///
/// Extraction fails with a [`FormRejection`] when:
///
/// - the raw form data could not be obtained because the request has an
///   unsupported `content-type` ([`FormRejection::InvalidFormContentType`]);
/// - the request body could not be read ([`FormRejection::BytesRejection`]);
/// - the form data could not be deserialized into `T`. This is reported as
///   `400 Bad Request` for `GET`/`HEAD` requests (query string) and as
///   `422 Unprocessable Entity` for all other methods (request body).
///
/// # As response
///
/// `Form` can also be used to encode any type that implements
/// [`serde::Serialize`] as `application/x-www-form-urlencoded`. The response
/// carries a matching `content-type` header. If serialization fails, a
/// `500 Internal Server Error` response is returned instead, with the
/// serializer's error message as its body.
///
/// ```rust
/// use axum::Form;
/// use serde::Serialize;
///
/// #[derive(Serialize)]
/// struct Payload {
///     value: String,
/// }
///
/// async fn handler() -> Form<Payload> {
///     Form(Payload { value: "foo".to_owned() })
/// }
/// ```
///
/// [`Query`]: crate::extract::Query
/// [`Multipart`]: crate::extract::Multipart
#[cfg_attr(docsrs, doc(cfg(feature = "form")))]
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Form<T>(pub T);
74
75#[async_trait]
76impl<T, S> FromRequest<S> for Form<T>
77where
78    T: DeserializeOwned,
79    S: Send + Sync,
80{
81    type Rejection = FormRejection;
82
83    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
84        let is_get_or_head =
85            req.method() == http::Method::GET || req.method() == http::Method::HEAD;
86
87        match req.extract().await {
88            Ok(RawForm(bytes)) => {
89                let value =
90                    serde_urlencoded::from_bytes(&bytes).map_err(|err| -> FormRejection {
91                        if is_get_or_head {
92                            FailedToDeserializeForm::from_err(err).into()
93                        } else {
94                            FailedToDeserializeFormBody::from_err(err).into()
95                        }
96                    })?;
97                Ok(Form(value))
98            }
99            Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)),
100            Err(RawFormRejection::InvalidFormContentType(r)) => {
101                Err(FormRejection::InvalidFormContentType(r))
102            }
103        }
104    }
105}
106
107impl<T> IntoResponse for Form<T>
108where
109    T: Serialize,
110{
111    fn into_response(self) -> Response {
112        match serde_urlencoded::to_string(&self.0) {
113            Ok(body) => (
114                [(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())],
115                body,
116            )
117                .into_response(),
118            Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
119        }
120    }
121}
122
// NOTE(review): presumably generates `Deref`/`DerefMut` impls from `Form<T>` to
// the wrapped `T`, so handlers can use the inner value directly — confirm
// against the `axum_core::__impl_deref!` definition.
axum_core::__impl_deref!(Form);
124
#[cfg(test)]
mod tests {
    use crate::{
        routing::{on, MethodFilter},
        test_helpers::TestClient,
        Router,
    };

    use super::*;
    use axum_core::body::Body;
    use http::{Method, Request};
    use mime::APPLICATION_WWW_FORM_URLENCODED;
    use serde::{Deserialize, Serialize};
    use std::fmt::Debug;

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Pagination {
        size: Option<u64>,
        page: Option<u64>,
    }

    /// Extracts `Form<T>` from a request (default method, i.e. `GET`) to `uri`
    /// and asserts the result equals `value`.
    async fn check_query<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
        let req = Request::builder()
            .uri(uri.as_ref())
            .body(Body::empty())
            .unwrap();
        assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    /// Sends `value` URL-encoded in a `POST` body and asserts it round-trips
    /// through the `Form` extractor.
    async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) {
        let req = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body(Body::from(serde_urlencoded::to_string(&value).unwrap()))
            .unwrap();
        assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    #[crate::test]
    async fn test_form_query() {
        check_query(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check_query(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check_query(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    #[crate::test]
    async fn test_form_body() {
        check_body(Pagination {
            size: None,
            page: None,
        })
        .await;

        check_body(Pagination {
            size: Some(10),
            page: None,
        })
        .await;

        check_body(Pagination {
            size: Some(10),
            page: Some(20),
        })
        .await;
    }

    #[crate::test]
    async fn test_incorrect_content_type() {
        let req = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(
                serde_urlencoded::to_string(&Pagination {
                    size: Some(10),
                    page: None,
                })
                .unwrap(),
            ))
            .unwrap();
        assert!(matches!(
            Form::<Pagination>::from_request(req, &())
                .await
                .unwrap_err(),
            FormRejection::InvalidFormContentType(InvalidFormContentType)
        ));
    }

    #[test]
    fn test_form_into_response() {
        let res = Form(Pagination {
            size: Some(10),
            page: None,
        })
        .into_response();
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()[CONTENT_TYPE],
            APPLICATION_WWW_FORM_URLENCODED.as_ref()
        );
    }

    // Uses `crate::test` like the other async tests in this module.
    #[crate::test]
    async fn deserialize_error_status_codes() {
        #[allow(dead_code)]
        #[derive(Deserialize)]
        struct Payload {
            a: i32,
        }

        let app = Router::new().route(
            "/",
            on(
                MethodFilter::GET.or(MethodFilter::POST),
                |_: Form<Payload>| async {},
            ),
        );

        let client = TestClient::new(app);

        // Query-string deserialization failure -> 400.
        let res = client.get("/?a=false").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);

        // Body deserialization failure -> 422.
        let res = client
            .post("/")
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body("a=false")
            .await;
        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }
}