aws_smithy_http/
query_writer.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::query::fmt_string as percent_encode_query;
7use http_02x::uri::InvalidUri;
8use http_02x::Uri;
9
10/// Utility for updating the query string in a [`Uri`].
11#[allow(missing_debug_implementations)]
12pub struct QueryWriter {
13    base_uri: Uri,
14    new_path_and_query: String,
15    prefix: Option<char>,
16}
17
18impl QueryWriter {
19    /// Creates a new `QueryWriter` from a string
20    pub fn new_from_string(uri: &str) -> Result<Self, InvalidUri> {
21        Ok(Self::new(&Uri::try_from(uri)?))
22    }
23
24    /// Creates a new `QueryWriter` based off the given `uri`.
25    pub fn new(uri: &Uri) -> Self {
26        let new_path_and_query = uri
27            .path_and_query()
28            .map(|pq| pq.to_string())
29            .unwrap_or_default();
30        let prefix = if uri.query().is_none() {
31            Some('?')
32        } else if !uri.query().unwrap_or_default().is_empty() {
33            Some('&')
34        } else {
35            None
36        };
37        QueryWriter {
38            base_uri: uri.clone(),
39            new_path_and_query,
40            prefix,
41        }
42    }
43
44    /// Clears all query parameters.
45    pub fn clear_params(&mut self) {
46        if let Some(index) = self.new_path_and_query.find('?') {
47            self.new_path_and_query.truncate(index);
48            self.prefix = Some('?');
49        }
50    }
51
52    /// Inserts a new query parameter. The key and value are percent encoded
53    /// by `QueryWriter`. Passing in percent encoded values will result in double encoding.
54    pub fn insert(&mut self, k: &str, v: &str) {
55        if let Some(prefix) = self.prefix {
56            self.new_path_and_query.push(prefix);
57        }
58        self.prefix = Some('&');
59        self.new_path_and_query.push_str(&percent_encode_query(k));
60        self.new_path_and_query.push('=');
61
62        self.new_path_and_query.push_str(&percent_encode_query(v));
63    }
64
65    /// Returns just the built query string.
66    pub fn build_query(self) -> String {
67        self.build_uri().query().unwrap_or_default().to_string()
68    }
69
70    /// Returns a full [`Uri`] with the query string updated.
71    pub fn build_uri(self) -> Uri {
72        let mut parts = self.base_uri.into_parts();
73        parts.path_and_query = Some(
74            self.new_path_and_query
75                .parse()
76                .expect("adding query should not invalidate URI"),
77        );
78        Uri::from_parts(parts).expect("a valid URL in should always produce a valid URL out")
79    }
80}
81
#[cfg(test)]
mod test {
    use super::QueryWriter;
    use http_02x::Uri;

    #[test]
    // A URI with neither a path nor a query.
    fn empty_uri() {
        let uri = Uri::from_static("http://www.example.com");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert("key", "val%ue");
        query_writer.insert("another", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static("http://www.example.com?key=val%25ue&another=value")
        );
    }

    #[test]
    fn uri_with_path() {
        let uri = Uri::from_static("http://www.example.com/path");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert("key", "val%ue");
        query_writer.insert("another", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static("http://www.example.com/path?key=val%25ue&another=value")
        );
    }

    #[test]
    fn uri_with_path_and_query() {
        let uri = Uri::from_static("http://www.example.com/path?original=here");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert("key", "val%ue");
        query_writer.insert("another", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static(
                "http://www.example.com/path?original=here&key=val%25ue&another=value"
            )
        );
    }

    #[test]
    // A bare trailing `?` must be followed directly by the first parameter (no `?&` or `??`).
    fn uri_with_empty_query() {
        let uri = Uri::from_static("http://www.example.com/path?");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert("key", "value");
        query_writer.insert("another", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static("http://www.example.com/path?key=value&another=value")
        );
    }

    #[test]
    fn new_from_string() {
        let mut query_writer =
            QueryWriter::new_from_string("http://www.example.com/path").expect("valid uri");
        query_writer.insert("key", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static("http://www.example.com/path?key=value")
        );

        assert!(QueryWriter::new_from_string("not a valid uri").is_err());
    }

    #[test]
    fn build_query() {
        let uri = Uri::from_static("http://www.example.com");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert("key", "val%ue");
        query_writer.insert("ano%ther", "value");
        assert_eq!("key=val%25ue&ano%25ther=value", query_writer.build_query());
    }

    #[test]
    fn build_query_without_params() {
        let uri = Uri::from_static("http://www.example.com/path");
        let query_writer = QueryWriter::new(&uri);
        assert_eq!("", query_writer.build_query());
    }

    #[test]
    // This test ensures that the percent encoding applied to queries always produces a valid URI if
    // the starting URI is valid
    fn doesnt_panic_when_adding_query_to_valid_uri() {
        let uri = Uri::from_static("http://www.example.com");

        let mut problematic_chars = Vec::new();

        for byte in u8::MIN..=u8::MAX {
            let bytes = [byte];
            // A single byte is only valid UTF-8 when it is ASCII; anything else cannot be a
            // `&str`, so it cannot be passed to `insert` in the first place.
            if let Ok(value) = std::str::from_utf8(&bytes) {
                let mut query_writer = QueryWriter::new(&uri);
                query_writer.insert("key", value);

                if std::panic::catch_unwind(|| query_writer.build_uri()).is_err() {
                    problematic_chars.push(char::from(byte));
                }
            }
        }

        assert!(
            problematic_chars.is_empty(),
            "we got some bad bytes here: {:#?}",
            problematic_chars
        );
    }

    #[test]
    fn clear_params() {
        let uri = Uri::from_static("http://www.example.com/path?original=here&foo=1");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.clear_params();
        query_writer.insert("new", "value");
        assert_eq!("new=value", query_writer.build_query());
    }

    #[test]
    fn clear_params_without_existing_query() {
        let uri = Uri::from_static("http://www.example.com/path");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.clear_params();
        query_writer.insert("new", "value");
        assert_eq!(
            query_writer.build_uri(),
            Uri::from_static("http://www.example.com/path?new=value")
        );
    }

    #[test]
    fn clear_params_after_insert() {
        let uri = Uri::from_static("http://www.example.com/path?original=here");
        let mut query_writer = QueryWriter::new(&uri);
        query_writer.insert("key", "value");
        query_writer.clear_params();
        query_writer.insert("new", "value");
        assert_eq!("new=value", query_writer.build_query());
    }
}