1use crate::extract::Request;
2use crate::extract::{rejection::*, FromRequest, RawForm};
3use async_trait::async_trait;
4use axum_core::response::{IntoResponse, Response};
5use axum_core::RequestExt;
6use http::header::CONTENT_TYPE;
7use http::StatusCode;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10
/// URL-encoded form extractor and response type.
///
/// As an extractor, `Form<T>` deserializes `T` from URL-encoded form data.
/// The raw bytes are obtained through `RawForm`. For `GET` and `HEAD`
/// requests this is the query string. For other methods it is the request
/// body, which must carry a form content type or the extraction is rejected.
///
/// As a response, `Form<T>` serializes `T` with `serde_urlencoded` and sets
/// the `Content-Type` header to `application/x-www-form-urlencoded`.
#[cfg_attr(docsrs, doc(cfg(feature = "form")))]
#[must_use]
#[derive(Clone, Copy, Debug, Default)]
pub struct Form<T>(pub T);
74
75#[async_trait]
76impl<T, S> FromRequest<S> for Form<T>
77where
78 T: DeserializeOwned,
79 S: Send + Sync,
80{
81 type Rejection = FormRejection;
82
83 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
84 let is_get_or_head =
85 req.method() == http::Method::GET || req.method() == http::Method::HEAD;
86
87 match req.extract().await {
88 Ok(RawForm(bytes)) => {
89 let value =
90 serde_urlencoded::from_bytes(&bytes).map_err(|err| -> FormRejection {
91 if is_get_or_head {
92 FailedToDeserializeForm::from_err(err).into()
93 } else {
94 FailedToDeserializeFormBody::from_err(err).into()
95 }
96 })?;
97 Ok(Form(value))
98 }
99 Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)),
100 Err(RawFormRejection::InvalidFormContentType(r)) => {
101 Err(FormRejection::InvalidFormContentType(r))
102 }
103 }
104 }
105}
106
107impl<T> IntoResponse for Form<T>
108where
109 T: Serialize,
110{
111 fn into_response(self) -> Response {
112 match serde_urlencoded::to_string(&self.0) {
113 Ok(body) => (
114 [(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())],
115 body,
116 )
117 .into_response(),
118 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
119 }
120 }
121}
122
// NOTE(review): based on its name, this macro presumably generates
// `Deref`/`DerefMut` impls from `Form<T>` to the wrapped `T`. Confirm against
// the definition of `axum_core::__impl_deref!`.
axum_core::__impl_deref!(Form);
124
#[cfg(test)]
mod tests {
    use crate::{
        routing::{on, MethodFilter},
        test_helpers::TestClient,
        Router,
    };

    use super::*;
    use axum_core::body::Body;
    use http::{Method, Request};
    use mime::APPLICATION_WWW_FORM_URLENCODED;
    use serde::{Deserialize, Serialize};
    use std::fmt::Debug;

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Pagination {
        size: Option<u64>,
        page: Option<u64>,
    }

    /// Extracts `Form<T>` from a default-method (GET) request to `uri` and
    /// asserts it equals `value`.
    async fn check_query<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
        let req = Request::builder()
            .uri(uri.as_ref())
            .body(Body::empty())
            .unwrap();
        assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    /// Sends `value` URL-encoded in a POST body and asserts that extracting
    /// `Form<T>` round-trips it.
    async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) {
        let req = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body(Body::from(serde_urlencoded::to_string(&value).unwrap()))
            .unwrap();
        assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    #[crate::test]
    async fn test_form_query() {
        check_query(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check_query(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check_query(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    #[crate::test]
    async fn test_form_body() {
        check_body(Pagination {
            size: None,
            page: None,
        })
        .await;

        check_body(Pagination {
            size: Some(10),
            page: None,
        })
        .await;

        check_body(Pagination {
            size: Some(10),
            page: Some(20),
        })
        .await;
    }

    #[crate::test]
    async fn test_incorrect_content_type() {
        let req = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(
                serde_urlencoded::to_string(&Pagination {
                    size: Some(10),
                    page: None,
                })
                .unwrap(),
            ))
            .unwrap();
        assert!(matches!(
            Form::<Pagination>::from_request(req, &())
                .await
                .unwrap_err(),
            FormRejection::InvalidFormContentType(InvalidFormContentType)
        ));
    }

    #[tokio::test]
    async fn deserialize_error_status_codes() {
        #[allow(dead_code)] // the field only exists to drive deserialization
        #[derive(Deserialize)]
        struct Payload {
            a: i32,
        }

        let app = Router::new().route(
            "/",
            on(
                MethodFilter::GET.or(MethodFilter::POST),
                |_: Form<Payload>| async {},
            ),
        );

        let client = TestClient::new(app);

        // Query-string deserialization failures are client errors (400)...
        let res = client.get("/?a=false").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);

        // ...while body deserialization failures are semantic errors (422).
        let res = client
            .post("/")
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body("a=false")
            .await;
        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    /// `Form<T>` as a response encodes `T` and sets the form content type.
    #[crate::test]
    async fn into_response_sets_content_type_and_body() {
        let app = Router::new().route(
            "/",
            on(MethodFilter::GET, || async {
                Form(Pagination {
                    size: Some(10),
                    page: Some(20),
                })
            }),
        );

        let client = TestClient::new(app);
        let res = client.get("/").await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()[CONTENT_TYPE],
            APPLICATION_WWW_FORM_URLENCODED.as_ref()
        );
        assert_eq!(res.text().await, "size=10&page=20");
    }

    /// Values that `serde_urlencoded` cannot encode produce a 500 response.
    #[test]
    fn into_response_serialization_error_is_500() {
        // A bare integer is neither a map nor a struct, so it cannot be the
        // top-level value of a URL-encoded form.
        let res = Form(123_u32).into_response();
        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}