aws_smithy_json/
escape.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use std::borrow::Cow;
7use std::fmt;
8
/// The specific reason a JSON string could not be unescaped.
#[derive(Debug, PartialEq, Eq)]
enum EscapeErrorKind {
    /// A high surrogate was not followed by another `\u` escape. Holds the
    /// (lossily decoded) text that was found where the low word was expected.
    ExpectedSurrogatePair(String),
    /// A backslash was followed by a character that is not a recognized JSON escape.
    InvalidEscapeCharacter(char),
    /// A high surrogate was followed by a `\u` escape that is not a low surrogate.
    /// Holds the high and low words, in that order.
    InvalidSurrogatePair(u16, u16),
    /// A `\u` escape was malformed or does not map to a valid character.
    /// Holds the offending escape text.
    InvalidUnicodeEscape(String),
    /// Bytes that were expected to be UTF-8 text were not valid UTF-8.
    InvalidUtf8,
    /// The input ended in the middle of an escape sequence.
    UnexpectedEndOfString,
}
18
/// Error returned when a JSON-escaped string cannot be unescaped.
///
/// Equality comparison is only derived in test builds, where it is used to
/// assert on the exact failure.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct EscapeError {
    // The concrete failure reason; private, so it is only surfaced through `Display`/`Debug`.
    kind: EscapeErrorKind,
}
24
// No trait methods are overridden, so the default `std::error::Error` behavior applies.
impl std::error::Error for EscapeError {}
26
27impl fmt::Display for EscapeError {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        use EscapeErrorKind::*;
30        match &self.kind {
31            ExpectedSurrogatePair(low) => {
32                write!(
33                    f,
34                    "expected a UTF-16 surrogate pair, but got {} as the low word",
35                    low
36                )
37            }
38            InvalidEscapeCharacter(chr) => write!(f, "invalid JSON escape: \\{}", chr),
39            InvalidSurrogatePair(high, low) => {
40                write!(f, "invalid surrogate pair: \\u{:04X}\\u{:04X}", high, low)
41            }
42            InvalidUnicodeEscape(escape) => write!(f, "invalid JSON Unicode escape: \\u{}", escape),
43            InvalidUtf8 => write!(f, "invalid UTF-8 codepoint in JSON string"),
44            UnexpectedEndOfString => write!(f, "unexpected end of string"),
45        }
46    }
47}
48
49impl From<EscapeErrorKind> for EscapeError {
50    fn from(kind: EscapeErrorKind) -> Self {
51        Self { kind }
52    }
53}
54
55/// Escapes a string for embedding in a JSON string value.
56pub(crate) fn escape_string(value: &str) -> Cow<'_, str> {
57    let bytes = value.as_bytes();
58    for (index, byte) in bytes.iter().enumerate() {
59        match byte {
60            0..=0x1F | b'"' | b'\\' => {
61                return Cow::Owned(escape_string_inner(&bytes[0..index], &bytes[index..]))
62            }
63            _ => {}
64        }
65    }
66    Cow::Borrowed(value)
67}
68
/// Builds the escaped form of a string that is known to need escaping.
///
/// `start` is a leading run of bytes that needs no escaping and is copied verbatim.
/// `rest` is the remainder, which is escaped byte by byte. Together the two slices
/// must form valid UTF-8 (they are the two halves of a `&str` split at an ASCII byte).
///
/// Returns the escaped text as an owned `String`.
fn escape_string_inner(start: &[u8], rest: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    // Escaping only ever grows the output, so the input length is a lower bound.
    let mut escaped = Vec::with_capacity(start.len() + rest.len() + 1);
    escaped.extend_from_slice(start);

    for &byte in rest {
        match byte {
            b'"' => escaped.extend_from_slice(b"\\\""),
            b'\\' => escaped.extend_from_slice(b"\\\\"),
            0x08 => escaped.extend_from_slice(b"\\b"),
            0x0C => escaped.extend_from_slice(b"\\f"),
            b'\n' => escaped.extend_from_slice(b"\\n"),
            b'\r' => escaped.extend_from_slice(b"\\r"),
            b'\t' => escaped.extend_from_slice(b"\\t"),
            // Remaining control characters have no short form and become `\u00xx`.
            // The digits are written directly to avoid a temporary `String` per byte.
            0..=0x1F => escaped.extend_from_slice(&[
                b'\\',
                b'u',
                b'0',
                b'0',
                HEX_DIGITS[usize::from(byte >> 4)],
                HEX_DIGITS[usize::from(byte & 0x0F)],
            ]),
            _ => escaped.push(byte),
        }
    }

    debug_assert!(std::str::from_utf8(&escaped).is_ok());
    // SAFETY:
    // - The original input was valid UTF-8 since it came in as a `&str`
    // - Only single-byte code points were escaped; every other byte is copied unchanged
    // - The escape sequences themselves are ASCII, and therefore valid UTF-8
    unsafe { String::from_utf8_unchecked(escaped) }
}
94
95/// Unescapes a JSON-escaped string.
96/// If there are no escape sequences, it directly returns the reference.
97pub(crate) fn unescape_string(value: &str) -> Result<Cow<'_, str>, EscapeError> {
98    let bytes = value.as_bytes();
99    for (index, byte) in bytes.iter().enumerate() {
100        if *byte == b'\\' {
101            return unescape_string_inner(&bytes[0..index], &bytes[index..]).map(Cow::Owned);
102        }
103    }
104    Ok(Cow::Borrowed(value))
105}
106
107fn unescape_string_inner(start: &[u8], rest: &[u8]) -> Result<String, EscapeError> {
108    let mut unescaped = Vec::with_capacity(start.len() + rest.len());
109    unescaped.extend(start);
110
111    let mut index = 0;
112    while index < rest.len() {
113        match rest[index] {
114            b'\\' => {
115                index += 1;
116                if index == rest.len() {
117                    return Err(EscapeErrorKind::UnexpectedEndOfString.into());
118                }
119                match rest[index] {
120                    b'u' => {
121                        index -= 1;
122                        index += read_unicode_escapes(&rest[index..], &mut unescaped)?;
123                    }
124                    byte => {
125                        match byte {
126                            b'\\' => unescaped.push(b'\\'),
127                            b'/' => unescaped.push(b'/'),
128                            b'"' => unescaped.push(b'"'),
129                            b'b' => unescaped.push(0x08),
130                            b'f' => unescaped.push(0x0C),
131                            b'n' => unescaped.push(b'\n'),
132                            b'r' => unescaped.push(b'\r'),
133                            b't' => unescaped.push(b'\t'),
134                            _ => {
135                                return Err(
136                                    EscapeErrorKind::InvalidEscapeCharacter(byte.into()).into()
137                                )
138                            }
139                        }
140                        index += 1;
141                    }
142                }
143            }
144            byte => {
145                unescaped.push(byte);
146                index += 1
147            }
148        }
149    }
150
151    String::from_utf8(unescaped).map_err(|_| EscapeErrorKind::InvalidUtf8.into())
152}
153
/// Returns true if `codepoint` is a UTF-16 low (trailing) surrogate, `0xDC00..=0xDFFF`.
fn is_utf16_low_surrogate(codepoint: u16) -> bool {
    matches!(codepoint, 0xDC00..=0xDFFF)
}
157
/// Returns true if `codepoint` is a UTF-16 high (leading) surrogate, `0xD800..=0xDBFF`.
fn is_utf16_high_surrogate(codepoint: u16) -> bool {
    matches!(codepoint, 0xD800..=0xDBFF)
}
161
162fn read_codepoint(rest: &[u8]) -> Result<u16, EscapeError> {
163    if rest.len() < 6 {
164        return Err(EscapeErrorKind::UnexpectedEndOfString.into());
165    }
166    if &rest[0..2] != b"\\u" {
167        // The first codepoint is always prefixed with "\u" since unescape_string_inner does
168        // that check, so this error will always be for the low word of a surrogate pair.
169        return Err(EscapeErrorKind::ExpectedSurrogatePair(
170            String::from_utf8_lossy(&rest[0..6]).into(),
171        )
172        .into());
173    }
174
175    let codepoint_str =
176        std::str::from_utf8(&rest[2..6]).map_err(|_| EscapeErrorKind::InvalidUtf8)?;
177
178    // Error on characters `u16::from_str_radix` would otherwise accept, such as `+`
179    if codepoint_str.bytes().any(|byte| !byte.is_ascii_hexdigit()) {
180        return Err(EscapeErrorKind::InvalidUnicodeEscape(codepoint_str.into()).into());
181    }
182    Ok(u16::from_str_radix(codepoint_str, 16).expect("hex string is valid 16-bit value"))
183}
184
185/// Reads JSON Unicode escape sequences (i.e., "\u1234"). Will also read
186/// an additional codepoint if the first codepoint is the start of a surrogate pair.
187fn read_unicode_escapes(bytes: &[u8], into: &mut Vec<u8>) -> Result<usize, EscapeError> {
188    let high = read_codepoint(bytes)?;
189    let (bytes_read, chr) = if is_utf16_high_surrogate(high) {
190        let low = read_codepoint(&bytes[6..])?;
191        if !is_utf16_low_surrogate(low) {
192            return Err(EscapeErrorKind::InvalidSurrogatePair(high, low).into());
193        }
194
195        let codepoint =
196            std::char::from_u32(0x10000 + (high - 0xD800) as u32 * 0x400 + (low - 0xDC00) as u32)
197                .ok_or(EscapeErrorKind::InvalidSurrogatePair(high, low))?;
198        (12, codepoint)
199    } else {
200        let codepoint = std::char::from_u32(high as u32).ok_or_else(|| {
201            EscapeErrorKind::InvalidUnicodeEscape(String::from_utf8_lossy(&bytes[0..6]).into())
202        })?;
203        (6, codepoint)
204    };
205
206    match chr.len_utf8() {
207        1 => into.push(chr as u8),
208        _ => into.extend_from_slice(chr.encode_utf8(&mut [0; 4]).as_bytes()),
209    }
210    Ok(bytes_read)
211}
212
// Unit and property tests for `escape_string` and `unescape_string`.
#[cfg(test)]
mod test {
    use super::escape_string;
    use crate::escape::{unescape_string, EscapeErrorKind};
    use std::borrow::Cow;

    // Escaping against hand-written expectations: pass-through, short escapes,
    // quotes/backslashes, and the `\u00xx` form for other control characters.
    #[test]
    fn escape() {
        assert_eq!("", escape_string("").as_ref());
        assert_eq!("foo", escape_string("foo").as_ref());
        assert_eq!("foo\\r\\n", escape_string("foo\r\n").as_ref());
        assert_eq!("foo\\r\\nbar", escape_string("foo\r\nbar").as_ref());
        assert_eq!(r"foo\\bar", escape_string(r"foo\bar").as_ref());
        assert_eq!(r"\\foobar", escape_string(r"\foobar").as_ref());
        assert_eq!(
            r"\bf\fo\to\r\n",
            escape_string("\u{08}f\u{0C}o\to\r\n").as_ref()
        );
        assert_eq!("\\\"test\\\"", escape_string("\"test\"").as_ref());
        assert_eq!("\\u0000", escape_string("\u{0}").as_ref());
        assert_eq!("\\u001f", escape_string("\u{1f}").as_ref());
    }

    // Input without any backslash must come back borrowed, i.e. without allocating.
    #[test]
    fn unescape_no_escapes() {
        let unescaped = unescape_string("test test").unwrap();
        assert_eq!("test test", unescaped);
        assert!(matches!(unescaped, Cow::Borrowed(_)));
    }

    // Unescaping of valid input, followed by one assertion per error kind.
    #[test]
    fn unescape() {
        assert_eq!(
            "\x08f\x0Co\to\r\n",
            unescape_string(r"\bf\fo\to\r\n").unwrap()
        );
        assert_eq!("\"test\"", unescape_string(r#"\"test\""#).unwrap());
        assert_eq!("\x00", unescape_string("\\u0000").unwrap());
        assert_eq!("\x1f", unescape_string("\\u001f").unwrap());
        assert_eq!("foo\r\nbar", unescape_string("foo\\r\\nbar").unwrap());
        assert_eq!("foo\r\n", unescape_string("foo\\r\\n").unwrap());
        assert_eq!("\r\nbar", unescape_string("\\r\\nbar").unwrap());
        // A surrogate pair decodes to a single character outside the BMP.
        assert_eq!("\u{10437}", unescape_string("\\uD801\\uDC37").unwrap());

        // Truncated escapes.
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\u")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\u00")
        );
        // Unknown escape character.
        assert_eq!(
            Err(EscapeErrorKind::InvalidEscapeCharacter('z').into()),
            unescape_string("\\z")
        );

        // High surrogate followed by something other than a valid low surrogate.
        assert_eq!(
            Err(EscapeErrorKind::ExpectedSurrogatePair("\\nasdf".into()).into()),
            unescape_string("\\uD801\\nasdf")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\uD801\\u00")
        );
        assert_eq!(
            Err(EscapeErrorKind::InvalidSurrogatePair(0xD801, 0xC501).into()),
            unescape_string("\\uD801\\uC501")
        );

        // Non-hex characters inside a `\u` escape; the error carries only the four digits.
        assert_eq!(
            Err(EscapeErrorKind::InvalidUnicodeEscape("+04D".into()).into()),
            unescape_string("\\u+04D")
        );
    }

    use proptest::proptest;
    proptest! {
        // For arbitrary strings, our escaping must match serde_json's output.
        #[test]
        fn matches_serde_json(s in ".*") {
            let serde_escaped = serde_json::to_string(&s).unwrap();
            // Drop the first and last byte of serde's output (the surrounding quotes).
            let serde_escaped = &serde_escaped[1..(serde_escaped.len() - 1)];
            assert_eq!(serde_escaped,escape_string(&s))
        }

        // Escaping and then unescaping any single character yields the original.
        #[test]
        fn round_trip(chr in proptest::char::any()) {
            let mut original = String::new();
            original.push(chr);

            let escaped = escape_string(&original);
            let unescaped = unescape_string(&escaped).unwrap();
            assert_eq!(original, unescaped);
        }

        // Every character from U+10000 to U+10FFFF, written as a `\uXXXX\uXXXX`
        // surrogate pair, unescapes back to that character.
        #[test]
        fn unicode_surrogates(chr in proptest::char::range(
            std::char::from_u32(0x10000).unwrap(),
            std::char::from_u32(0x10FFFF).unwrap(),
        )) {
            let mut codepoints = [0; 2];
            chr.encode_utf16(&mut codepoints);

            let escaped = format!("\\u{:04X}\\u{:04X}", codepoints[0], codepoints[1]);
            let unescaped = unescape_string(&escaped).unwrap();

            let expected = format!("{}", chr);
            assert_eq!(expected, unescaped);
        }
    }

    // Exhaustive comparison against serde_json for every valid `char`.
    #[test]
    #[ignore] // This tests escaping of all codepoints, but can take a long time in debug builds
    fn all_codepoints() {
        for value in 0..u32::MAX {
            // Values that are not valid `char`s (e.g. surrogates) are skipped.
            if let Some(chr) = char::from_u32(value) {
                let string = String::from(chr);
                let escaped = escape_string(&string);
                let serde_escaped = serde_json::to_string(&string).unwrap();
                let serde_escaped = &serde_escaped[1..(serde_escaped.len() - 1)];
                assert_eq!(&escaped, serde_escaped);
            }
        }
    }
}