1use std::borrow::Cow;
7use std::fmt;
8
/// The concrete reason unescaping a JSON string failed.
///
/// Kept private; callers only ever see it wrapped in `EscapeError`.
#[derive(Debug, Eq, PartialEq)]
enum EscapeErrorKind {
    /// A high surrogate was not followed by a `\u` escape; holds the text found instead.
    ExpectedSurrogatePair(String),
    /// A backslash was followed by a character that is not a recognized escape.
    InvalidEscapeCharacter(char),
    /// A high surrogate was followed by a `\u` escape that is not a low surrogate.
    InvalidSurrogatePair(u16, u16),
    /// A `\u` escape was malformed or does not map to a character; holds the offending text.
    InvalidUnicodeEscape(String),
    /// Bytes that were expected to be UTF-8 were not.
    InvalidUtf8,
    /// The input ended in the middle of an escape sequence.
    UnexpectedEndOfString,
}
18
19#[derive(Debug)]
20#[cfg_attr(test, derive(PartialEq, Eq))]
21pub struct EscapeError {
22 kind: EscapeErrorKind,
23}
24
25impl std::error::Error for EscapeError {}
26
27impl fmt::Display for EscapeError {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 use EscapeErrorKind::*;
30 match &self.kind {
31 ExpectedSurrogatePair(low) => {
32 write!(
33 f,
34 "expected a UTF-16 surrogate pair, but got {} as the low word",
35 low
36 )
37 }
38 InvalidEscapeCharacter(chr) => write!(f, "invalid JSON escape: \\{}", chr),
39 InvalidSurrogatePair(high, low) => {
40 write!(f, "invalid surrogate pair: \\u{:04X}\\u{:04X}", high, low)
41 }
42 InvalidUnicodeEscape(escape) => write!(f, "invalid JSON Unicode escape: \\u{}", escape),
43 InvalidUtf8 => write!(f, "invalid UTF-8 codepoint in JSON string"),
44 UnexpectedEndOfString => write!(f, "unexpected end of string"),
45 }
46 }
47}
48
49impl From<EscapeErrorKind> for EscapeError {
50 fn from(kind: EscapeErrorKind) -> Self {
51 Self { kind }
52 }
53}
54
55pub(crate) fn escape_string(value: &str) -> Cow<'_, str> {
57 let bytes = value.as_bytes();
58 for (index, byte) in bytes.iter().enumerate() {
59 match byte {
60 0..=0x1F | b'"' | b'\\' => {
61 return Cow::Owned(escape_string_inner(&bytes[0..index], &bytes[index..]))
62 }
63 _ => {}
64 }
65 }
66 Cow::Borrowed(value)
67}
68
/// Builds the escaped string from a verbatim prefix and a remainder that needs escaping.
///
/// * `start` - bytes copied to the output unchanged.
/// * `rest` - bytes scanned one at a time; `"`, `\` and control bytes (`0x00..=0x1F`) are
///   replaced by their JSON escape, everything else is copied through.
///
/// Control bytes without a short escape are written as `\u00XX` with lowercase hex digits.
/// `start` followed by `rest` is expected to be valid UTF-8; this is only checked in debug builds.
fn escape_string_inner(start: &[u8], rest: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut escaped = Vec::with_capacity(start.len() + rest.len() + 1);
    escaped.extend_from_slice(start);

    for &byte in rest {
        match byte {
            b'"' => escaped.extend_from_slice(b"\\\""),
            b'\\' => escaped.extend_from_slice(b"\\\\"),
            0x08 => escaped.extend_from_slice(b"\\b"),
            0x0C => escaped.extend_from_slice(b"\\f"),
            b'\n' => escaped.extend_from_slice(b"\\n"),
            b'\r' => escaped.extend_from_slice(b"\\r"),
            b'\t' => escaped.extend_from_slice(b"\\t"),
            // Remaining control bytes: emit `\u00XX` straight from a digit table instead of
            // allocating a temporary `String` with `format!` for every such byte.
            0..=0x1F => escaped.extend_from_slice(&[
                b'\\',
                b'u',
                b'0',
                b'0',
                HEX_DIGITS[usize::from(byte >> 4)],
                HEX_DIGITS[usize::from(byte & 0x0F)],
            ]),
            _ => escaped.push(byte),
        }
    }

    debug_assert!(std::str::from_utf8(&escaped).is_ok());
    // SAFETY: only ASCII bytes are ever replaced, and only by ASCII sequences; every other byte
    // is copied through unchanged. ASCII bytes never occur inside a multi-byte UTF-8 sequence,
    // so the output is valid UTF-8 whenever `start` followed by `rest` is, which holds when the
    // two slices come from splitting a `&str`.
    unsafe { String::from_utf8_unchecked(escaped) }
}
94
95pub(crate) fn unescape_string(value: &str) -> Result<Cow<'_, str>, EscapeError> {
98 let bytes = value.as_bytes();
99 for (index, byte) in bytes.iter().enumerate() {
100 if *byte == b'\\' {
101 return unescape_string_inner(&bytes[0..index], &bytes[index..]).map(Cow::Owned);
102 }
103 }
104 Ok(Cow::Borrowed(value))
105}
106
107fn unescape_string_inner(start: &[u8], rest: &[u8]) -> Result<String, EscapeError> {
108 let mut unescaped = Vec::with_capacity(start.len() + rest.len());
109 unescaped.extend(start);
110
111 let mut index = 0;
112 while index < rest.len() {
113 match rest[index] {
114 b'\\' => {
115 index += 1;
116 if index == rest.len() {
117 return Err(EscapeErrorKind::UnexpectedEndOfString.into());
118 }
119 match rest[index] {
120 b'u' => {
121 index -= 1;
122 index += read_unicode_escapes(&rest[index..], &mut unescaped)?;
123 }
124 byte => {
125 match byte {
126 b'\\' => unescaped.push(b'\\'),
127 b'/' => unescaped.push(b'/'),
128 b'"' => unescaped.push(b'"'),
129 b'b' => unescaped.push(0x08),
130 b'f' => unescaped.push(0x0C),
131 b'n' => unescaped.push(b'\n'),
132 b'r' => unescaped.push(b'\r'),
133 b't' => unescaped.push(b'\t'),
134 _ => {
135 return Err(
136 EscapeErrorKind::InvalidEscapeCharacter(byte.into()).into()
137 )
138 }
139 }
140 index += 1;
141 }
142 }
143 }
144 byte => {
145 unescaped.push(byte);
146 index += 1
147 }
148 }
149 }
150
151 String::from_utf8(unescaped).map_err(|_| EscapeErrorKind::InvalidUtf8.into())
152}
153
/// Returns `true` when `codepoint` is a UTF-16 low (trailing) surrogate, `0xDC00..=0xDFFF`.
fn is_utf16_low_surrogate(codepoint: u16) -> bool {
    matches!(codepoint, 0xDC00..=0xDFFF)
}
157
/// Returns `true` when `codepoint` is a UTF-16 high (leading) surrogate, `0xD800..=0xDBFF`.
fn is_utf16_high_surrogate(codepoint: u16) -> bool {
    matches!(codepoint, 0xD800..=0xDBFF)
}
161
162fn read_codepoint(rest: &[u8]) -> Result<u16, EscapeError> {
163 if rest.len() < 6 {
164 return Err(EscapeErrorKind::UnexpectedEndOfString.into());
165 }
166 if &rest[0..2] != b"\\u" {
167 return Err(EscapeErrorKind::ExpectedSurrogatePair(
170 String::from_utf8_lossy(&rest[0..6]).into(),
171 )
172 .into());
173 }
174
175 let codepoint_str =
176 std::str::from_utf8(&rest[2..6]).map_err(|_| EscapeErrorKind::InvalidUtf8)?;
177
178 if codepoint_str.bytes().any(|byte| !byte.is_ascii_hexdigit()) {
180 return Err(EscapeErrorKind::InvalidUnicodeEscape(codepoint_str.into()).into());
181 }
182 Ok(u16::from_str_radix(codepoint_str, 16).expect("hex string is valid 16-bit value"))
183}
184
185fn read_unicode_escapes(bytes: &[u8], into: &mut Vec<u8>) -> Result<usize, EscapeError> {
188 let high = read_codepoint(bytes)?;
189 let (bytes_read, chr) = if is_utf16_high_surrogate(high) {
190 let low = read_codepoint(&bytes[6..])?;
191 if !is_utf16_low_surrogate(low) {
192 return Err(EscapeErrorKind::InvalidSurrogatePair(high, low).into());
193 }
194
195 let codepoint =
196 std::char::from_u32(0x10000 + (high - 0xD800) as u32 * 0x400 + (low - 0xDC00) as u32)
197 .ok_or(EscapeErrorKind::InvalidSurrogatePair(high, low))?;
198 (12, codepoint)
199 } else {
200 let codepoint = std::char::from_u32(high as u32).ok_or_else(|| {
201 EscapeErrorKind::InvalidUnicodeEscape(String::from_utf8_lossy(&bytes[0..6]).into())
202 })?;
203 (6, codepoint)
204 };
205
206 match chr.len_utf8() {
207 1 => into.push(chr as u8),
208 _ => into.extend_from_slice(chr.encode_utf8(&mut [0; 4]).as_bytes()),
209 }
210 Ok(bytes_read)
211}
212
/// Unit and property tests for JSON string escaping and unescaping.
#[cfg(test)]
mod test {
    use super::escape_string;
    use crate::escape::{unescape_string, EscapeErrorKind};
    use std::borrow::Cow;

    /// Fixed examples covering short escapes, quotes, backslashes and `\u00XX` control escapes.
    #[test]
    fn escape() {
        assert_eq!("", escape_string("").as_ref());
        assert_eq!("foo", escape_string("foo").as_ref());
        assert_eq!("foo\\r\\n", escape_string("foo\r\n").as_ref());
        assert_eq!("foo\\r\\nbar", escape_string("foo\r\nbar").as_ref());
        assert_eq!(r"foo\\bar", escape_string(r"foo\bar").as_ref());
        assert_eq!(r"\\foobar", escape_string(r"\foobar").as_ref());
        assert_eq!(
            r"\bf\fo\to\r\n",
            escape_string("\u{08}f\u{0C}o\to\r\n").as_ref()
        );
        assert_eq!("\\\"test\\\"", escape_string("\"test\"").as_ref());
        assert_eq!("\\u0000", escape_string("\u{0}").as_ref());
        assert_eq!("\\u001f", escape_string("\u{1f}").as_ref());
    }

    /// Input without a backslash must come back borrowed, not copied.
    #[test]
    fn unescape_no_escapes() {
        let unescaped = unescape_string("test test").unwrap();
        assert_eq!("test test", unescaped);
        assert!(matches!(unescaped, Cow::Borrowed(_)));
    }

    /// Fixed examples for successful unescaping followed by each error kind.
    #[test]
    fn unescape() {
        assert_eq!(
            "\x08f\x0Co\to\r\n",
            unescape_string(r"\bf\fo\to\r\n").unwrap()
        );
        assert_eq!("\"test\"", unescape_string(r#"\"test\""#).unwrap());
        assert_eq!("\x00", unescape_string("\\u0000").unwrap());
        assert_eq!("\x1f", unescape_string("\\u001f").unwrap());
        assert_eq!("foo\r\nbar", unescape_string("foo\\r\\nbar").unwrap());
        assert_eq!("foo\r\n", unescape_string("foo\\r\\n").unwrap());
        assert_eq!("\r\nbar", unescape_string("\\r\\nbar").unwrap());
        assert_eq!("\u{10437}", unescape_string("\\uD801\\uDC37").unwrap());

        // Truncated and unknown escapes.
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\u")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\u00")
        );
        assert_eq!(
            Err(EscapeErrorKind::InvalidEscapeCharacter('z').into()),
            unescape_string("\\z")
        );

        // Broken surrogate pairs.
        assert_eq!(
            Err(EscapeErrorKind::ExpectedSurrogatePair("\\nasdf".into()).into()),
            unescape_string("\\uD801\\nasdf")
        );
        assert_eq!(
            Err(EscapeErrorKind::UnexpectedEndOfString.into()),
            unescape_string("\\uD801\\u00")
        );
        assert_eq!(
            Err(EscapeErrorKind::InvalidSurrogatePair(0xD801, 0xC501).into()),
            unescape_string("\\uD801\\uC501")
        );

        // A sign character must not be accepted as a hex digit.
        assert_eq!(
            Err(EscapeErrorKind::InvalidUnicodeEscape("+04D".into()).into()),
            unescape_string("\\u+04D")
        );
    }

    use proptest::proptest;
    proptest! {
        /// Escaping must agree with `serde_json` once its surrounding quotes are stripped.
        #[test]
        fn matches_serde_json(s in ".*") {
            let serde_escaped = serde_json::to_string(&s).unwrap();
            let serde_escaped = &serde_escaped[1..(serde_escaped.len() - 1)];
            assert_eq!(serde_escaped, escape_string(&s))
        }

        /// Escaping then unescaping any single character yields the original.
        #[test]
        fn round_trip(chr in proptest::char::any()) {
            let mut original = String::new();
            original.push(chr);

            let escaped = escape_string(&original);
            let unescaped = unescape_string(&escaped).unwrap();
            assert_eq!(original, unescaped);
        }

        /// Every supplementary-plane character written as a surrogate pair decodes correctly.
        #[test]
        fn unicode_surrogates(chr in proptest::char::range(
            std::char::from_u32(0x10000).unwrap(),
            std::char::from_u32(0x10FFFF).unwrap(),
        )) {
            let mut codepoints = [0; 2];
            chr.encode_utf16(&mut codepoints);

            let escaped = format!("\\u{:04X}\\u{:04X}", codepoints[0], codepoints[1]);
            let unescaped = unescape_string(&escaped).unwrap();

            let expected = format!("{}", chr);
            assert_eq!(expected, unescaped);
        }
    }

    /// Exhaustively compares escaping of every `char` against `serde_json`.
    ///
    /// Ignored by default; run it explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn all_codepoints() {
        // `char::from_u32` returns `None` for everything above `char::MAX`, so stopping there
        // covers the same characters without looping over the whole `u32` range.
        for value in 0..=u32::from(char::MAX) {
            if let Some(chr) = char::from_u32(value) {
                let string = String::from(chr);
                let escaped = escape_string(&string);
                let serde_escaped = serde_json::to_string(&string).unwrap();
                let serde_escaped = &serde_escaped[1..(serde_escaped.len() - 1)];
                assert_eq!(&escaped, serde_escaped);
            }
        }
    }
}