askama_escape/
lib.rs

1#![cfg_attr(not(any(feature = "json", test)), no_std)]
2#![deny(elided_lifetimes_in_paths)]
3#![deny(unreachable_pub)]
4
5use core::fmt::{self, Display, Formatter, Write};
6use core::str;
7
8#[derive(Debug)]
9pub struct MarkupDisplay<E, T>
10where
11    E: Escaper,
12    T: Display,
13{
14    value: DisplayValue<T>,
15    escaper: E,
16}
17
18impl<E, T> MarkupDisplay<E, T>
19where
20    E: Escaper,
21    T: Display,
22{
23    pub fn new_unsafe(value: T, escaper: E) -> Self {
24        Self {
25            value: DisplayValue::Unsafe(value),
26            escaper,
27        }
28    }
29
30    pub fn new_safe(value: T, escaper: E) -> Self {
31        Self {
32            value: DisplayValue::Safe(value),
33            escaper,
34        }
35    }
36
37    #[must_use]
38    pub fn mark_safe(mut self) -> MarkupDisplay<E, T> {
39        self.value = match self.value {
40            DisplayValue::Unsafe(t) => DisplayValue::Safe(t),
41            _ => self.value,
42        };
43        self
44    }
45}
46
47impl<E, T> Display for MarkupDisplay<E, T>
48where
49    E: Escaper,
50    T: Display,
51{
52    fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
53        match self.value {
54            DisplayValue::Unsafe(ref t) => write!(
55                EscapeWriter {
56                    fmt,
57                    escaper: &self.escaper
58                },
59                "{}",
60                t
61            ),
62            DisplayValue::Safe(ref t) => t.fmt(fmt),
63        }
64    }
65}
66
/// A `fmt::Write` adapter that passes every string written to it through an
/// escaper before forwarding the result to the inner writer.
#[derive(Debug)]
pub struct EscapeWriter<'a, E, W> {
    /// Destination that receives the escaped output.
    fmt: W,
    /// Escaper applied to each chunk that is written.
    escaper: &'a E,
}
72
73impl<E, W> Write for EscapeWriter<'_, E, W>
74where
75    W: Write,
76    E: Escaper,
77{
78    fn write_str(&mut self, s: &str) -> fmt::Result {
79        self.escaper.write_escaped(&mut self.fmt, s)
80    }
81}
82
83pub fn escape<E>(string: &str, escaper: E) -> Escaped<'_, E>
84where
85    E: Escaper,
86{
87    Escaped { string, escaper }
88}
89
90#[derive(Debug)]
91pub struct Escaped<'a, E>
92where
93    E: Escaper,
94{
95    string: &'a str,
96    escaper: E,
97}
98
99impl<E> Display for Escaped<'_, E>
100where
101    E: Escaper,
102{
103    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
104        self.escaper.write_escaped(fmt, self.string)
105    }
106}
107
/// Escaper that replaces the HTML-significant characters `<`, `>`, `&`, `"`
/// and `'` with character entities.
///
/// `Debug` is derived so that `MarkupDisplay<Html, _>` and `Escaped<'_, Html>`,
/// which derive `Debug` themselves, are actually debuggable. `Clone`/`Copy`/`Default`
/// make the zero-sized escaper convenient to pass by value.
#[derive(Debug, Clone, Copy, Default)]
pub struct Html;
109
/// Flushes the pending unescaped run `bytes[start..i]` to `fmt`, writes
/// `quote` in place of the byte at `i`, and advances `start` past that byte.
///
/// Must be expanded inside a function returning `fmt::Result`, because it uses `?`.
macro_rules! escaping_body {
    ($start:ident, $i:ident, $fmt:ident, $bytes:ident, $quote:expr) => {{
        if $i > $start {
            // SAFETY: callers must pass the bytes of a `&str`, with `$start` either 0 or
            // one past a previously escaped ASCII byte, and `$i` indexing an ASCII byte.
            // Both ends are then char boundaries, so the slice is valid UTF-8.
            // `Html::write_escaped` satisfies this.
            let pending = unsafe { str::from_utf8_unchecked(&$bytes[$start..$i]) };
            $fmt.write_str(pending)?;
        }
        $fmt.write_str($quote)?;
        $start = $i + 1;
    }};
}
119
120impl Escaper for Html {
121    fn write_escaped<W>(&self, mut fmt: W, string: &str) -> fmt::Result
122    where
123        W: Write,
124    {
125        let bytes = string.as_bytes();
126        let mut start = 0;
127        for (i, b) in bytes.iter().enumerate() {
128            if b.wrapping_sub(b'"') <= FLAG {
129                match *b {
130                    b'<' => escaping_body!(start, i, fmt, bytes, "&lt;"),
131                    b'>' => escaping_body!(start, i, fmt, bytes, "&gt;"),
132                    b'&' => escaping_body!(start, i, fmt, bytes, "&amp;"),
133                    b'"' => escaping_body!(start, i, fmt, bytes, "&quot;"),
134                    b'\'' => escaping_body!(start, i, fmt, bytes, "&#x27;"),
135                    _ => (),
136                }
137            }
138        }
139        if start < bytes.len() {
140            fmt.write_str(unsafe { str::from_utf8_unchecked(&bytes[start..]) })
141        } else {
142            Ok(())
143        }
144    }
145}
146
/// Escaper that writes text through unchanged.
///
/// `Debug` is derived so that `MarkupDisplay<Text, _>` and `Escaped<'_, Text>`,
/// which derive `Debug` themselves, are actually debuggable. `Clone`/`Copy`/`Default`
/// make the zero-sized escaper convenient to pass by value.
#[derive(Debug, Clone, Copy, Default)]
pub struct Text;
148
149impl Escaper for Text {
150    fn write_escaped<W>(&self, mut fmt: W, string: &str) -> fmt::Result
151    where
152        W: Write,
153    {
154        fmt.write_str(string)
155    }
156}
157
/// Tags a wrapped value as needing escaping (`Unsafe`) or as writable
/// verbatim (`Safe`).
#[derive(Debug, PartialEq)]
enum DisplayValue<T: Display> {
    /// Written as-is.
    Safe(T),
    /// Passed through the escaper before being written.
    Unsafe(T),
}
166
/// A strategy for escaping text as it is written to a `fmt::Write` sink.
pub trait Escaper {
    /// Writes `string` to `fmt`, transformed according to this escaper's rules.
    ///
    /// # Errors
    ///
    /// The implementations in this crate return the first error reported by `fmt`.
    fn write_escaped<W: Write>(&self, fmt: W, string: &str) -> fmt::Result;
}
172
/// Span of the byte range `b'"'..=b'>'` (0x22..=0x3E), i.e. `0x3E - 0x22 == 28`.
///
/// `byte.wrapping_sub(b'"') <= FLAG` holds exactly for bytes inside that range,
/// which contains every byte the HTML escaper replaces. Bytes below `b'"'`
/// wrap around to large values and fail the test.
const FLAG: u8 = b'>' - b'"';
174
/// Byte buffer that, while data is written to it through `std::io::Write`,
/// replaces `&`, `'`, `<` and `>` with the JSON string escapes `\u0026`,
/// `\u0027`, `\u003c` and `\u003e`.
#[cfg(feature = "json")]
#[derive(Debug, Clone, Default)]
pub struct JsonEscapeBuffer(Vec<u8>);
179
#[cfg(feature = "json")]
impl JsonEscapeBuffer {
    /// Creates an empty buffer.
    pub fn new() -> Self {
        Self(Vec::new())
    }

    /// Consumes the buffer and returns its contents as a `String`.
    ///
    /// The escaping in the `io::Write` impl only substitutes ASCII sequences for
    /// ASCII bytes, so if everything written was valid UTF-8 the result is returned
    /// unchanged. Because `io::Write` is public, arbitrary bytes can also be written.
    /// Invalid sequences are therefore replaced with U+FFFD rather than being turned
    /// into an ill-formed `String`, which would be undefined behaviour.
    pub fn finish(self) -> String {
        match String::from_utf8(self.0) {
            Ok(text) => text,
            Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
        }
    }
}
190
#[cfg(feature = "json")]
impl std::io::Write for JsonEscapeBuffer {
    /// Appends `bytes` to the buffer, replacing `&`, `'`, `<` and `>` with their
    /// `\uXXXX` escapes. Always consumes the entire input and never fails.
    fn write(&mut self, bytes: &[u8]) -> std::io::Result<usize> {
        // Room for the unescaped input; escape sequences may grow it further.
        self.0.reserve(bytes.len());
        let mut run_start = 0;
        for (pos, &byte) in bytes.iter().enumerate() {
            let replacement: &[u8] = match byte {
                b'&' => br"\u0026",
                b'\'' => br"\u0027",
                b'<' => br"\u003c",
                b'>' => br"\u003e",
                _ => continue,
            };
            // Flush the pending run of literal bytes, then the escape sequence.
            self.0.extend_from_slice(&bytes[run_start..pos]);
            self.0.extend_from_slice(replacement);
            run_start = pos + 1;
        }
        // Trailing literal run (possibly empty).
        self.0.extend_from_slice(&bytes[run_start..]);
        Ok(bytes.len())
    }

    /// All data already lives in the in-memory `Vec`, so there is nothing to flush.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
225
#[cfg(test)]
mod tests {
    use super::*;
    use std::string::ToString;

    #[test]
    fn test_escape() {
        assert_eq!(escape("", Html).to_string(), "");
        assert_eq!(escape("<&>", Html).to_string(), "&lt;&amp;&gt;");
        assert_eq!(escape("bla&", Html).to_string(), "bla&amp;");
        assert_eq!(escape("<foo", Html).to_string(), "&lt;foo");
        assert_eq!(escape("bla&h", Html).to_string(), "bla&amp;h");
        // Both quote characters are escaped.
        assert_eq!(escape("\"q\"", Html).to_string(), "&quot;q&quot;");
        assert_eq!(escape("it's", Html).to_string(), "it&#x27;s");
        // Multi-byte characters next to escaped bytes must survive intact.
        assert_eq!(escape("é<ü>", Html).to_string(), "é&lt;ü&gt;");
        // Bytes just outside, and unescaped bytes inside, the pre-check range.
        assert_eq!(escape("!?=", Html).to_string(), "!?=");
    }

    #[test]
    fn test_text_escaper_is_identity() {
        assert_eq!(
            escape("<a href='x'>&</a>", Text).to_string(),
            "<a href='x'>&</a>"
        );
    }

    #[test]
    fn test_markup_display() {
        assert_eq!(MarkupDisplay::new_unsafe("<b>", Html).to_string(), "&lt;b&gt;");
        assert_eq!(MarkupDisplay::new_safe("<b>", Html).to_string(), "<b>");
        assert_eq!(
            MarkupDisplay::new_unsafe("<b>", Html).mark_safe().to_string(),
            "<b>"
        );
        assert_eq!(
            MarkupDisplay::new_safe("<b>", Html).mark_safe().to_string(),
            "<b>"
        );
    }

    #[cfg(feature = "json")]
    #[test]
    fn test_json_escape_buffer() {
        use std::io::Write as _;
        let mut buf = JsonEscapeBuffer::new();
        buf.write_all(b"<a href='x'>&</a>").unwrap();
        assert_eq!(
            buf.finish(),
            r#"\u003ca href=\u0027x\u0027\u003e\u0026\u003c/a\u003e"#
        );
    }
}