// axum/extract/query.rs

1use super::{rejection::*, FromRequestParts};
2use async_trait::async_trait;
3use http::{request::Parts, Uri};
4use serde::de::DeserializeOwned;
5
/// Extractor that deserializes the request's query string into a value of type `T`.
///
/// The wrapped type must implement [`serde::Deserialize`]; more precisely, the extractor
/// requires [`serde::de::DeserializeOwned`].
///
/// # Example
///
/// ```rust,no_run
/// use axum::{extract::Query, routing::get, Router};
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: usize,
///     per_page: usize,
/// }
///
/// // Turns a query string such as `?page=2&per_page=30` into a `Pagination`.
/// async fn list_things(pagination: Query<Pagination>) {
///     let pagination: Pagination = pagination.0;
///
///     // ...
/// }
///
/// let app = Router::new().route("/list_things", get(list_things));
/// # let _: Router = app;
/// ```
///
/// # Missing query strings
///
/// A request whose URI carries no query string is handled as if the query string
/// were empty, so a `T` whose fields are all `Option`s still deserializes.
///
/// # Rejections
///
/// When the query string cannot be deserialized into `T`, the request is rejected
/// with a `400 Bad Request` response.
///
/// # See also
///
/// To tell empty values apart from missing ones, see the
/// [query-params-with-empty-strings][example] example.
///
/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs
///
/// To accept the same parameter several times, as in `?foo=1&foo=2&foo=3`, use
/// [`axum_extra::extract::Query`] instead.
///
/// [`axum_extra::extract::Query`]: https://docs.rs/axum-extra/latest/axum_extra/extract/struct.Query.html
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Debug, Clone, Copy, Default)]
pub struct Query<T>(pub T);
53
54#[async_trait]
55impl<T, S> FromRequestParts<S> for Query<T>
56where
57    T: DeserializeOwned,
58    S: Send + Sync,
59{
60    type Rejection = QueryRejection;
61
62    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
63        Self::try_from_uri(&parts.uri)
64    }
65}
66
67impl<T> Query<T>
68where
69    T: DeserializeOwned,
70{
71    /// Attempts to construct a [`Query`] from a reference to a [`Uri`].
72    ///
73    /// # Example
74    /// ```
75    /// use axum::extract::Query;
76    /// use http::Uri;
77    /// use serde::Deserialize;
78    ///
79    /// #[derive(Deserialize)]
80    /// struct ExampleParams {
81    ///     foo: String,
82    ///     bar: u32,
83    /// }
84    ///
85    /// let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
86    /// let result: Query<ExampleParams> = Query::try_from_uri(&uri).unwrap();
87    /// assert_eq!(result.foo, String::from("hello"));
88    /// assert_eq!(result.bar, 42);
89    /// ```
90    pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> {
91        let query = value.query().unwrap_or_default();
92        let params =
93            serde_urlencoded::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?;
94        Ok(Query(params))
95    }
96}
97
98axum_core::__impl_deref!(Query);
99
#[cfg(test)]
mod tests {
    use crate::{routing::get, test_helpers::TestClient, Router};

    use super::*;
    use axum_core::{body::Body, extract::FromRequest};
    use http::{Request, StatusCode};
    use serde::Deserialize;
    use std::fmt::Debug;

    /// Runs the `Query<T>` extractor on a request for `uri` and asserts that it
    /// produces `value`.
    async fn check<T>(uri: impl AsRef<str>, value: T)
    where
        T: DeserializeOwned + PartialEq + Debug,
    {
        let req = Request::builder()
            .uri(uri.as_ref())
            .body(Body::empty())
            .unwrap();
        assert_eq!(Query::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    #[crate::test]
    async fn test_query() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Pagination {
            size: Option<u64>,
            page: Option<u64>,
        }

        // No query string at all: every optional field stays `None`.
        check(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    #[crate::test]
    async fn correct_rejection_status_code() {
        #[derive(Deserialize)]
        #[allow(dead_code)] // the field only exists to drive deserialization
        struct Params {
            n: i32,
        }

        async fn handler(_: Query<Params>) {}

        let app = Router::new().route("/", get(handler));
        let client = TestClient::new(app);

        let res = client.get("/?n=hi").await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_try_from_uri() {
        #[derive(Deserialize)]
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
        let result: Query<TestQueryParams> = Query::try_from_uri(&uri).unwrap();
        assert_eq!(result.foo, String::from("hello"));
        assert_eq!(result.bar, 42);
    }

    #[test]
    fn test_try_from_uri_with_invalid_query() {
        // The field names must match the query keys. With mismatched names
        // (e.g. `_foo`/`_bar`), deserialization fails on a missing field and the
        // test passes without ever exercising the invalid `bar` value.
        #[derive(Deserialize)]
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }

        // Control: the same struct accepts a well-formed query string ...
        let valid: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
        let parsed: Query<TestQueryParams> = Query::try_from_uri(&valid).unwrap();
        assert_eq!(parsed.foo, "hello");
        assert_eq!(parsed.bar, 42);

        // ... so the only reason this one can fail is that `invalid` is not a `u32`.
        let uri: Uri = "http://example.com/path?foo=hello&bar=invalid"
            .parse()
            .unwrap();
        let result: Result<Query<TestQueryParams>, _> = Query::try_from_uri(&uri);

        assert!(result.is_err());
    }
}