mz_expr/scalar/func/
encoding.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Encoding and decoding support for various formats that represent binary data
11//! as text data.
12
13use mz_ore::fmt::FormatBuffer;
14use mz_repr::strconv;
15use uncased::UncasedStr;
16
17use crate::EvalError;
18
19/// An encoding format.
20pub trait Format {
21    /// Encodes a byte slice into its string representation according to this
22    /// format.
23    fn encode(&self, bytes: &[u8]) -> String;
24
25    /// Decodes a byte slice from its string representation according to this
26    /// format.
27    fn decode(&self, s: &str) -> Result<Vec<u8>, EvalError>;
28}
29
30/// PostgreSQL-style Base64 encoding.
31///
32/// PostgreSQL follows RFC 2045, which requires that lines are broken after 76
33/// characters when encoding and that all whitespace characters are ignored when
34/// decoding. See <http://materialize.com/docs/sql/functions/encode> for
35/// details.
36struct Base64Format;
37
38impl Base64Format {
39    const CHARSET: &'static [u8] =
40        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
41
42    fn encode_sextet(v: u8) -> char {
43        char::from(Self::CHARSET[usize::from(v)])
44    }
45
46    fn decode_sextet(b: u8) -> Result<u8, EvalError> {
47        match b {
48            b'A'..=b'Z' => Ok(b - b'A'),
49            b'a'..=b'z' => Ok(b - b'a' + 26),
50            b'0'..=b'9' => Ok(b + 4),
51            b'+' => Ok(62),
52            b'/' => Ok(63),
53            _ => Err(EvalError::InvalidBase64Symbol(char::from(b))),
54        }
55    }
56}
57
58impl Format for Base64Format {
59    // Support for PostgreSQL-style (which is really MIME-style) Base64 encoding
60    // was, frustratingly, removed from Rust's `base64` crate. So we roll our
61    // own Base64 encoder and decoder here.
62
63    fn encode(&self, bytes: &[u8]) -> String {
64        // Process input in chunks of three octets. Each chunk is converted to
65        // four sextets. Each sextet is encoded as a printable ASCII character
66        // via `CHARSET`.
67        //
68        // When the input length is not divisible by three, the last chunk is
69        // partial. Sextets that are entirely determined by missing octets are
70        // encoded as `=`. Sextets that are partially determined by a missing
71        // octect assume the octet was zero.
72        //
73        // Line breaks are emitted after every 76 characters.
74
75        let mut buf = String::new();
76        for chunk in bytes.chunks(3) {
77            match chunk {
78                [o1, o2, o3] => {
79                    let s1 = (o1 & 0b11111100) >> 2;
80                    let s2 = (o1 & 0b00000011) << 4 | (o2 & 0b11110000) >> 4;
81                    let s3 = (o2 & 0b00001111) << 2 | (o3 & 0b11000000) >> 6;
82                    let s4 = o3 & 0b00111111;
83                    buf.push(Self::encode_sextet(s1));
84                    buf.push(Self::encode_sextet(s2));
85                    buf.push(Self::encode_sextet(s3));
86                    buf.push(Self::encode_sextet(s4));
87                }
88                [o1, o2] => {
89                    let s1 = (o1 & 0b11111100) >> 2;
90                    let s2 = (o1 & 0b00000011) << 4 | (o2 & 0b11110000) >> 4;
91                    let s3 = (o2 & 0b00001111) << 2;
92                    buf.push(Self::encode_sextet(s1));
93                    buf.push(Self::encode_sextet(s2));
94                    buf.push(Self::encode_sextet(s3));
95                    buf.push('=');
96                }
97                [o1] => {
98                    let s1 = (o1 & 0b11111100) >> 2;
99                    let s2 = (o1 & 0b00000011) << 4;
100                    buf.push(Self::encode_sextet(s1));
101                    buf.push(Self::encode_sextet(s2));
102                    buf.push('=');
103                    buf.push('=');
104                }
105                _ => unreachable!(),
106            }
107            if buf.len() % 77 == 76 {
108                buf.push('\n');
109            }
110        }
111        buf
112    }
113
114    fn decode(&self, s: &str) -> Result<Vec<u8>, EvalError> {
115        // Process input in chunks of four bytes, after filtering out any bytes
116        // that represent whitespace. Each byte is decoded into a sextet
117        // according to the reverse charset mapping maintained in
118        // `Self::decode_sextet`. The four sextets are converted to three octets
119        // and emitted.
120        //
121        // When the last character in a chunk is `=` or the last two characters
122        // in a chunk are both `=`, the chunk is missing its last one or two
123        // sextets, respectively. Octets that are entirely determined by missing
124        // sextets are elided. Octets that are partially determined by a missing
125        // sextet assume the sextet was zero.
126        //
127        // It is an error for a `=` character to appear in another position in
128        // a chunk. It is also an error if a chunk is incomplete.
129
130        let mut buf = vec![];
131        let mut bytes = s
132            .as_bytes()
133            .iter()
134            .copied()
135            .filter(|ch| !matches!(ch, b' ' | b'\t' | b'\n' | b'\r'));
136        loop {
137            match (bytes.next(), bytes.next(), bytes.next(), bytes.next()) {
138                (Some(c1), Some(c2), Some(b'='), Some(b'=')) => {
139                    let s1 = Self::decode_sextet(c1)?;
140                    let s2 = Self::decode_sextet(c2)?;
141                    buf.push(s1 << 2 | (s2 & 0b110000) >> 4);
142                }
143                (Some(c1), Some(c2), Some(c3), Some(b'=')) => {
144                    let s1 = Self::decode_sextet(c1)?;
145                    let s2 = Self::decode_sextet(c2)?;
146                    let s3 = Self::decode_sextet(c3)?;
147                    buf.push(s1 << 2 | (s2 & 0b110000) >> 4);
148                    buf.push((s2 & 0b001111) << 4 | (s3 & 0b111100) >> 2);
149                }
150                (Some(b'='), _, _, _) | (_, Some(b'='), _, _) | (_, _, Some(b'='), _) => {
151                    return Err(EvalError::InvalidBase64Equals);
152                }
153                (Some(c1), Some(c2), Some(c3), Some(c4)) => {
154                    let s1 = Self::decode_sextet(c1)?;
155                    let s2 = Self::decode_sextet(c2)?;
156                    let s3 = Self::decode_sextet(c3)?;
157                    let s4 = Self::decode_sextet(c4)?;
158                    buf.push(s1 << 2 | (s2 & 0b110000) >> 4);
159                    buf.push((s2 & 0b001111) << 4 | (s3 & 0b111100) >> 2);
160                    buf.push((s3 & 0b000011) << 6 | s4);
161                }
162                (None, None, None, None) => return Ok(buf),
163                _ => return Err(EvalError::InvalidBase64EndSequence),
164            }
165        }
166    }
167}
168
169struct EscapeFormat;
170
171impl Format for EscapeFormat {
172    fn encode(&self, bytes: &[u8]) -> String {
173        let mut buf = String::new();
174        for b in bytes {
175            match b {
176                b'\0' | (b'\x80'..=b'\xff') => {
177                    buf.push('\\');
178                    write!(&mut buf, "{:03o}", b);
179                }
180                b'\\' => buf.push_str("\\\\"),
181                _ => buf.push(char::from(*b)),
182            }
183        }
184        buf
185    }
186
187    fn decode(&self, s: &str) -> Result<Vec<u8>, EvalError> {
188        Ok(strconv::parse_bytes_traditional(s)?)
189    }
190}
191
192struct HexFormat;
193
194impl Format for HexFormat {
195    fn encode(&self, bytes: &[u8]) -> String {
196        hex::encode(bytes)
197    }
198
199    fn decode(&self, s: &str) -> Result<Vec<u8>, EvalError> {
200        // Can't use `hex::decode` here, as it doesn't tolerate whitespace
201        // between encoded bytes.
202        Ok(strconv::parse_bytes_hex(s)?)
203    }
204}
205
206pub fn lookup_format(s: &str) -> Result<&'static dyn Format, EvalError> {
207    let s = UncasedStr::new(s);
208    if s == "base64" {
209        Ok(&Base64Format)
210    } else if s == "escape" {
211        Ok(&EscapeFormat)
212    } else if s == "hex" {
213        Ok(&HexFormat)
214    } else {
215        Err(EvalError::InvalidEncodingName(s.as_str().into()))
216    }
217}