axum/extract/
query.rs
1use super::{rejection::*, FromRequestParts};
2use async_trait::async_trait;
3use http::{request::Parts, Uri};
4use serde::de::DeserializeOwned;
5
/// Extractor that deserializes the request URI's query string into a value of type `T`.
///
/// `Query` is a transparent newtype: the deserialized value is stored in the public
/// field `.0`. Deserialization is performed by [`Query::try_from_uri`], and a failure
/// is reported as a [`QueryRejection`].
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Clone, Copy, Debug, Default)]
pub struct Query<T>(pub T);
53
54#[async_trait]
55impl<T, S> FromRequestParts<S> for Query<T>
56where
57 T: DeserializeOwned,
58 S: Send + Sync,
59{
60 type Rejection = QueryRejection;
61
62 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
63 Self::try_from_uri(&parts.uri)
64 }
65}
66
67impl<T> Query<T>
68where
69 T: DeserializeOwned,
70{
71 pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> {
91 let query = value.query().unwrap_or_default();
92 let params =
93 serde_urlencoded::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?;
94 Ok(Query(params))
95 }
96}
97
98axum_core::__impl_deref!(Query);
99
#[cfg(test)]
mod tests {
    use crate::{routing::get, test_helpers::TestClient, Router};

    use super::*;
    use axum_core::{body::Body, extract::FromRequest};
    use http::{Request, StatusCode};
    use serde::Deserialize;
    use std::fmt::Debug;

    /// Builds a request for `uri`, runs the `Query<T>` extractor on it and asserts that
    /// the extracted value equals `value`. Panics if the extractor rejects the request.
    async fn check<T>(uri: impl AsRef<str>, value: T)
    where
        T: DeserializeOwned + PartialEq + Debug,
    {
        let req = Request::builder()
            .uri(uri.as_ref())
            .body(Body::empty())
            .unwrap();
        assert_eq!(Query::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    /// Missing, partial and complete query strings all deserialize into optional fields.
    #[crate::test]
    async fn test_query() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Pagination {
            size: Option<u64>,
            page: Option<u64>,
        }

        check(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    /// A query string that fails to deserialize is rejected with `400 Bad Request`.
    #[crate::test]
    async fn correct_rejection_status_code() {
        #[derive(Deserialize)]
        #[allow(dead_code)] // `n` is only read by the deserializer, never by the test
        struct Params {
            n: i32,
        }

        async fn handler(_: Query<Params>) {}

        let app = Router::new().route("/", get(handler));
        let client = TestClient::new(app);

        let res = client.get("/?n=hi").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }

    /// `try_from_uri` can be used directly, outside of request extraction.
    #[test]
    fn test_try_from_uri() {
        #[derive(Deserialize)]
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
        let result: Query<TestQueryParams> = Query::try_from_uri(&uri).unwrap();
        assert_eq!(result.foo, String::from("hello"));
        assert_eq!(result.bar, 42);
    }

    /// A value that does not parse as the target field type makes `try_from_uri` fail.
    #[test]
    fn test_try_from_uri_with_invalid_query() {
        // Field names must match the query keys; otherwise the error would come from a
        // missing field rather than from the unparsable `bar` value under test.
        #[derive(Deserialize)]
        #[allow(dead_code)] // fields are only read by the deserializer
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }

        // Control: the same struct deserializes when `bar` is a valid `u32`, so the
        // failure below can only be caused by `bar=invalid`.
        let valid: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
        assert!(Query::<TestQueryParams>::try_from_uri(&valid).is_ok());

        let uri: Uri = "http://example.com/path?foo=hello&bar=invalid"
            .parse()
            .unwrap();
        let result: Result<Query<TestQueryParams>, _> = Query::try_from_uri(&uri);

        assert!(result.is_err());
    }
}