domain/utils/
base64.rs

//! Decoding and encoding of Base 64.
//!
//! The Base 64 encoding is defined in [RFC 4648]. There are two variants
//! defined in the RFC, dubbed *base64* and *base64url*, which are
//! differentiated by the last two characters in the alphabet. The DNS uses
//! only the original *base64* variant, so this is what is implemented by the
//! module for now.
//!
//! The module defines the type [`Decoder`], which keeps the state necessary
//! for decoding, and [`SymbolConverter`], which decodes data for a scanner.
//! The various functions offered decode and encode octets in various forms;
//! decoding is built on top of a [`Decoder`].
//!
//! [RFC 4648]: https://tools.ietf.org/html/rfc4648
14
15use crate::base::scan::{ConvertSymbols, EntrySymbol, ScannerError};
16use core::fmt;
17use octseq::builder::{
18    EmptyBuilder, FreezeBuilder, FromBuilder, OctetsBuilder, ShortBuf,
19};
20#[cfg(feature = "std")]
21use std::string::String;
22
23//------------ Convenience Functions -----------------------------------------
24
25/// Decodes a string with *base64* encoded data.
26///
27/// The function attempts to decode the entire string and returns the result
28/// as a `Octets` value.
29pub fn decode<Octets>(s: &str) -> Result<Octets, DecodeError>
30where
31    Octets: FromBuilder,
32    <Octets as FromBuilder>::Builder: OctetsBuilder + EmptyBuilder,
33{
34    let mut decoder = Decoder::<<Octets as FromBuilder>::Builder>::new();
35    for ch in s.chars() {
36        decoder.push(ch)?;
37    }
38    decoder.finalize()
39}
40
41/// Encodes binary data in *base64* and writes it into a format stream.
42///
43/// This function is intended to be used in implementations of formatting
44/// traits:
45///
46/// ```
47/// use core::fmt;
48/// use domain::utils::base64;
49///
50/// struct Foo<'a>(&'a [u8]);
51///
52/// impl<'a> fmt::Display for Foo<'a> {
53///     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54///         base64::display(&self.0, f)
55///     }
56/// }
57/// ```
58pub fn display<B, W>(bytes: &B, f: &mut W) -> fmt::Result
59where
60    B: AsRef<[u8]> + ?Sized,
61    W: fmt::Write,
62{
63    fn ch(i: u8) -> char {
64        ENCODE_ALPHABET[i as usize]
65    }
66
67    for chunk in bytes.as_ref().chunks(3) {
68        match chunk.len() {
69            1 => {
70                f.write_char(ch(chunk[0] >> 2))?;
71                f.write_char(ch((chunk[0] & 0x03) << 4))?;
72                f.write_char('=')?;
73                f.write_char('=')?;
74            }
75            2 => {
76                f.write_char(ch(chunk[0] >> 2))?;
77                f.write_char(ch(((chunk[0] & 0x03) << 4) | (chunk[1] >> 4)))?;
78                f.write_char(ch((chunk[1] & 0x0F) << 2))?;
79                f.write_char('=')?;
80            }
81            3 => {
82                f.write_char(ch(chunk[0] >> 2))?;
83                f.write_char(ch(((chunk[0] & 0x03) << 4) | (chunk[1] >> 4)))?;
84                f.write_char(ch(((chunk[1] & 0x0F) << 2) | (chunk[2] >> 6)))?;
85                f.write_char(ch(chunk[2] & 0x3F))?;
86            }
87            _ => unreachable!(),
88        }
89    }
90    Ok(())
91}
92
/// Encodes binary data in *base64* and returns the encoded data as a string.
///
/// The output uses the standard alphabet and is padded with `=` to a
/// multiple of four characters.
#[cfg(feature = "std")]
pub fn encode_string<B: AsRef<[u8]> + ?Sized>(bytes: &B) -> String {
    // Every started group of three octets becomes exactly four characters,
    // so this is the exact output length (no over-allocation for inputs
    // whose length is a multiple of three, including the empty input).
    let groups = (bytes.as_ref().len() + 2) / 3;
    let mut res = String::with_capacity(groups * 4);
    display(bytes, &mut res).expect("writing to a String never fails");
    res
}
100
101/// Returns a placeholder value that implements `Display` for encoded data.
102pub fn encode_display<Octets: AsRef<[u8]>>(
103    octets: &Octets,
104) -> impl fmt::Display + '_ {
105    struct Display<'a>(&'a [u8]);
106
107    impl fmt::Display for Display<'_> {
108        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
109            display(self.0, f)
110        }
111    }
112
113    Display(octets.as_ref())
114}
115
/// Serialize and deserialize octets Base64 encoded or binary.
///
/// This module can be used with Serde’s `with` attribute. It will serialize
/// an octets sequence as a Base64 encoded string with human readable
/// serializers or as a raw octets sequence for compact serializers.
#[cfg(feature = "serde")]
pub mod serde {
    use super::encode_display;
    use core::fmt;
    use octseq::builder::{EmptyBuilder, FromBuilder, OctetsBuilder};
    use octseq::serde::{DeserializeOctets, SerializeOctets};

    /// Serializes `octets` as a Base64 string if the serializer is human
    /// readable and as a raw octets sequence otherwise.
    ///
    /// # Errors
    ///
    /// Returns whatever error the serializer reports.
    pub fn serialize<Octets, S>(
        octets: &Octets,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        Octets: AsRef<[u8]> + SerializeOctets,
        S: serde::Serializer,
    {
        if serializer.is_human_readable() {
            serializer.collect_str(&encode_display(octets))
        } else {
            octets.serialize_octets(serializer)
        }
    }

    /// Deserializes octets from a Base64 string if the deserializer is
    /// human readable and from a raw octets sequence otherwise.
    ///
    /// # Errors
    ///
    /// Fails if a string is not valid Base64 or if the octets type's own
    /// visitor rejects the raw data.
    pub fn deserialize<'de, Octets, D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> Result<Octets, D::Error>
    where
        Octets: FromBuilder + DeserializeOctets<'de>,
        <Octets as FromBuilder>::Builder: EmptyBuilder,
    {
        /// Accepts a Base64 string or raw octets; raw octets are handed on
        /// to the octets type's own visitor.
        struct Visitor<'de, Octets: DeserializeOctets<'de>>(Octets::Visitor);

        impl<'de, Octets> serde::de::Visitor<'de> for Visitor<'de, Octets>
        where
            Octets: FromBuilder + DeserializeOctets<'de>,
            <Octets as FromBuilder>::Builder: OctetsBuilder + EmptyBuilder,
        {
            type Value = Octets;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a Base64-encoded string")
            }

            fn visit_str<E: serde::de::Error>(
                self,
                v: &str,
            ) -> Result<Self::Value, E> {
                super::decode(v).map_err(E::custom)
            }

            fn visit_bytes<E: serde::de::Error>(
                self,
                value: &[u8],
            ) -> Result<Self::Value, E> {
                // Serde's default implementation rejects transient byte
                // slices outright; let the octets visitor decide instead.
                self.0.visit_bytes(value)
            }

            fn visit_borrowed_bytes<E: serde::de::Error>(
                self,
                value: &'de [u8],
            ) -> Result<Self::Value, E> {
                self.0.visit_borrowed_bytes(value)
            }

            #[cfg(feature = "std")]
            fn visit_byte_buf<E: serde::de::Error>(
                self,
                value: std::vec::Vec<u8>,
            ) -> Result<Self::Value, E> {
                self.0.visit_byte_buf(value)
            }
        }

        if deserializer.is_human_readable() {
            deserializer.deserialize_str(Visitor(Octets::visitor()))
        } else {
            Octets::deserialize_with_visitor(
                deserializer,
                Visitor(Octets::visitor()),
            )
        }
    }
}
196
197//------------ Decoder -------------------------------------------------------
198
199/// A base 64 decoder.
200///
201/// This type keeps all the state for decoding a sequence of characters
202/// representing data encoded in base 32. Upon success, the decoder returns
203/// the decoded data in a `bytes::Bytes` value.
204pub struct Decoder<Builder> {
205    /// A buffer for up to four characters.
206    ///
207    /// We only keep `u8`s here because only ASCII characters are used by
208    /// Base64.
209    buf: [u8; 4],
210
211    /// The index in `buf` where we place the next character.
212    ///
213    /// We also abuse this to mark when we are done (because there was
214    /// padding, in which case we set it to 0xF0).
215    next: usize,
216
217    /// The target or an error if something went wrong.
218    target: Result<Builder, DecodeError>,
219}
220
221impl<Builder: EmptyBuilder> Decoder<Builder> {
222    /// Creates a new empty decoder.
223    #[must_use]
224    pub fn new() -> Self {
225        Decoder {
226            buf: [0; 4],
227            next: 0,
228            target: Ok(Builder::empty()),
229        }
230    }
231}
232
233impl<Builder: OctetsBuilder> Decoder<Builder> {
234    /// Finalizes decoding and returns the decoded data.
235    pub fn finalize(self) -> Result<Builder::Octets, DecodeError>
236    where
237        Builder: FreezeBuilder,
238    {
239        let (target, next) = (self.target, self.next);
240        target.and_then(|bytes| {
241            // next is either 0 or 0xF0 for a completed group.
242            if next & 0x0F != 0 {
243                Err(DecodeError::ShortInput)
244            } else {
245                Ok(bytes.freeze())
246            }
247        })
248    }
249
250    /// Decodes one more character of data.
251    ///
252    /// Returns an error as soon as the encoded data is determined to be
253    /// illegal. It is okay to push more data after the first error. The
254    /// method will just keep returned errors.
255    pub fn push(&mut self, ch: char) -> Result<(), DecodeError> {
256        if self.next == 0xF0 {
257            self.target = Err(DecodeError::TrailingInput);
258            return Err(DecodeError::TrailingInput);
259        }
260
261        let val = if ch == PAD {
262            // Only up to two padding characters possible.
263            if self.next < 2 {
264                return Err(DecodeError::IllegalChar(ch));
265            }
266            0x80 // Acts as a marker later on.
267        } else {
268            if ch > (127 as char) {
269                return Err(DecodeError::IllegalChar(ch));
270            }
271            let val = DECODE_ALPHABET[ch as usize];
272            if val == 0xFF {
273                return Err(DecodeError::IllegalChar(ch));
274            }
275            val
276        };
277        self.buf[self.next] = val;
278        self.next += 1;
279
280        if self.next == 4 {
281            let target = self.target.as_mut().unwrap(); // Err covered above.
282            target
283                .append_slice(&[(self.buf[0] << 2) | (self.buf[1] >> 4)])
284                .map_err(Into::into)?;
285            if self.buf[2] != 0x80 {
286                target
287                    .append_slice(&[(self.buf[1] << 4) | (self.buf[2] >> 2)])
288                    .map_err(Into::into)?;
289            }
290            if self.buf[3] != 0x80 {
291                if self.buf[2] == 0x80 {
292                    return Err(DecodeError::TrailingInput);
293                }
294                target
295                    .append_slice(&[(self.buf[2] << 6) | self.buf[3]])
296                    .map_err(Into::into)?;
297                self.next = 0
298            } else {
299                self.next = 0xF0
300            }
301        }
302
303        Ok(())
304    }
305}
306
307//--- Default
308
309impl<Builder: EmptyBuilder> Default for Decoder<Builder> {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
315//------------ SymbolConverter -----------------------------------------------
316
/// A Base 64 decoder for use as a symbol converter with a scanner.
///
/// Characters are collected in groups of four. Each completed group is
/// turned into up to three octets, which are handed back as a slice.
#[derive(Clone, Debug, Default)]
pub struct SymbolConverter {
    /// The values of the characters of the group currently collected.
    ///
    /// Base64 only uses ASCII characters, so a `u8` per character is
    /// enough.
    input: [u8; 4],

    /// The position in `input` that receives the next character.
    ///
    /// Set to 0xF0 once a group containing padding has been completed,
    /// marking the end of the data.
    next: usize,

    /// Storage for the octets of the most recently completed group, so
    /// that a slice of it can be returned.
    output: [u8; 3],
}
335
336impl SymbolConverter {
337    /// Creates a new symbol converter.
338    #[must_use]
339    pub fn new() -> Self {
340        Default::default()
341    }
342
343    fn process_char<Error: ScannerError>(
344        &mut self,
345        ch: char,
346    ) -> Result<Option<&[u8]>, Error> {
347        if self.next == EOF_MARKER {
348            return Err(Error::custom("trailing Base 64 data"));
349        }
350
351        let val = if ch == PAD {
352            // Only up to two padding characters possible.
353            if self.next < 2 {
354                return Err(Error::custom("illegal Base 64 data"));
355            }
356            PAD_MARKER // Acts as a marker later on.
357        } else {
358            if ch > (127 as char) {
359                return Err(Error::custom("illegal Base 64 data"));
360            }
361            let val = DECODE_ALPHABET[ch as usize];
362            if val == 0xFF {
363                return Err(Error::custom("illegal Base 64 data"));
364            }
365            val
366        };
367        self.input[self.next] = val;
368        self.next += 1;
369
370        if self.next == 4 {
371            self.output[0] = (self.input[0] << 2) | (self.input[1] >> 4);
372
373            if self.input[2] == PAD_MARKER {
374                // The second to last character is padding. The last one
375                // needs to be, too.
376                if self.input[3] == PAD_MARKER {
377                    self.next = EOF_MARKER;
378                    Ok(Some(&self.output[..1]))
379                } else {
380                    Err(Error::custom("illegal Base 64 data"))
381                }
382            } else {
383                self.output[1] = (self.input[1] << 4) | (self.input[2] >> 2);
384
385                if self.input[3] == PAD_MARKER {
386                    // The last characters is padding.
387                    self.next = EOF_MARKER;
388                    Ok(Some(&self.output[..2]))
389                } else {
390                    self.output[2] = (self.input[2] << 6) | self.input[3];
391                    self.next = 0;
392                    Ok(Some(&self.output))
393                }
394            }
395        } else {
396            Ok(None)
397        }
398    }
399}
400
401impl<Sym, Error> ConvertSymbols<Sym, Error> for SymbolConverter
402where
403    Sym: Into<EntrySymbol>,
404    Error: ScannerError,
405{
406    fn process_symbol(
407        &mut self,
408        symbol: Sym,
409    ) -> Result<Option<&[u8]>, Error> {
410        match symbol.into() {
411            EntrySymbol::Symbol(symbol) => self.process_char(
412                symbol
413                    .into_char()
414                    .map_err(|_| Error::custom("illegal Base 64 data"))?,
415            ),
416            EntrySymbol::EndOfToken => Ok(None),
417        }
418    }
419
420    fn process_tail(&mut self) -> Result<Option<&[u8]>, Error> {
421        // next is either 0 or 0xF0 for a completed group.
422        if self.next & 0x0F != 0 {
423            Err(Error::custom("incomplete Base 64 data"))
424        } else {
425            Ok(None)
426        }
427    }
428}
429
430//============ Error Types ===================================================
431
432//------------ DecodeError ---------------------------------------------------
433
/// The reasons why decoding base 64 or base 32 encoded text can fail.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// A character that is not part of the encoding was encountered.
    IllegalChar(char),

    /// More input followed the padding that ends the data.
    TrailingInput,

    /// The input stopped in the middle of a group of characters.
    ShortInput,

    /// The decoded data did not fit into the target buffer.
    ShortBuf,
}
449
450impl From<ShortBuf> for DecodeError {
451    fn from(_: ShortBuf) -> Self {
452        DecodeError::ShortBuf
453    }
454}
455
456//--- Display and Error
457
458impl fmt::Display for DecodeError {
459    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
460        match *self {
461            DecodeError::TrailingInput => f.write_str("trailing input"),
462            DecodeError::IllegalChar(ch) => {
463                write!(f, "illegal character '{}'", ch)
464            }
465            DecodeError::ShortInput => f.write_str("incomplete input"),
466            DecodeError::ShortBuf => ShortBuf.fmt(f),
467        }
468    }
469}
470
471#[cfg(feature = "std")]
472impl std::error::Error for DecodeError {}
473
474//============ Constants =====================================================
475
/// The alphabet used by the decoder.
///
/// This maps encoding characters into their values. A value of 0xFF stands
/// in for illegal characters. We only provide the first 128 characters since
/// the alphabet will only use ASCII characters.
///
/// The table is derived from [`ENCODE_ALPHABET`] at compile time, so the two
/// alphabets can never disagree.
const DECODE_ALPHABET: [u8; 128] = {
    let mut table = [0xFF_u8; 128];
    let mut value = 0;
    while value < 64 {
        table[ENCODE_ALPHABET[value] as usize] = value as u8;
        value += 1;
    }
    table
};

/// The alphabet used by the encoder.
///
/// Index `i` holds the character encoding the six-bit value `i` (RFC 4648,
/// Table 1).
const ENCODE_ALPHABET: [char; 64] = [
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', // 0x00 .. 0x07
    'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', // 0x08 .. 0x0F
    'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', // 0x10 .. 0x17
    'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', // 0x18 .. 0x1F
    'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', // 0x20 .. 0x27
    'o', 'p', 'q', 'r', 's', 't', 'u', 'v', // 0x28 .. 0x2F
    'w', 'x', 'y', 'z', '0', '1', '2', '3', // 0x30 .. 0x37
    '4', '5', '6', '7', '8', '9', '+', '/', // 0x38 .. 0x3F
];
510
/// The character used to pad the final group of encoded data.
const PAD: char = '=';

/// The value stored in a group buffer in place of a padding character.
///
/// It lies outside the six-bit range of legal character values, so it can
/// be told apart from them.
const PAD_MARKER: u8 = 0b1000_0000;

/// The value of a group position counter once a padded group has been
/// completed, i.e., once no further input is acceptable.
///
/// Its low nibble is zero, like that of position 0.
const EOF_MARKER: usize = 0b1111_0000;
519
520//============ Test ==========================================================
521
#[cfg(all(test, feature = "std"))]
mod test {
    use super::*;
    use crate::base::scan::Symbols;
    use std::vec::Vec;

    const HAPPY_CASES: &[(&[u8], &str)] = &[
        (b"", ""),
        (b"f", "Zg=="),
        (b"fo", "Zm8="),
        (b"foo", "Zm9v"),
        (b"foob", "Zm9vYg=="),
        (b"fooba", "Zm9vYmE="),
        (b"foobar", "Zm9vYmFy"),
    ];

    /// Decodes `s` into a vector via the `decode` convenience function.
    fn decode_vec(s: &str) -> Result<Vec<u8>, DecodeError> {
        super::decode(s)
    }

    /// Decodes `s` by feeding its symbols through a `SymbolConverter`.
    fn convert(s: &str) -> Result<Vec<u8>, std::io::Error> {
        let mut convert = SymbolConverter::new();
        let convert: &mut dyn ConvertSymbols<_, std::io::Error> =
            &mut convert;
        let mut res = Vec::new();
        for sym in Symbols::new(s.chars()) {
            if let Some(octs) = convert.process_symbol(sym)? {
                res.extend_from_slice(octs);
            }
        }
        if let Some(octs) = convert.process_tail()? {
            res.extend_from_slice(octs);
        }
        Ok(res)
    }

    #[test]
    fn decode_str() {
        for (bin, text) in HAPPY_CASES {
            assert_eq!(&decode_vec(text).unwrap(), bin, "decode {}", text)
        }

        assert_eq!(decode_vec("FPucA").unwrap_err(), DecodeError::ShortInput);
        assert_eq!(
            decode_vec("FPucA=").unwrap_err(),
            DecodeError::IllegalChar('=')
        );
        assert_eq!(
            decode_vec("FPucAw=").unwrap_err(),
            DecodeError::ShortInput
        );
        assert_eq!(
            decode_vec("FPucAw=a").unwrap_err(),
            DecodeError::TrailingInput
        );
        assert_eq!(
            decode_vec("FPucAw==a").unwrap_err(),
            DecodeError::TrailingInput
        );
    }

    #[test]
    fn decode_illegal_input() {
        assert_eq!(decode_vec("=").unwrap_err(), DecodeError::IllegalChar('='));
        assert_eq!(
            decode_vec("Zm9v!").unwrap_err(),
            DecodeError::IllegalChar('!')
        );
        assert_eq!(
            decode_vec("Zm9v Zg==").unwrap_err(),
            DecodeError::IllegalChar(' ')
        );
        assert_eq!(
            decode_vec("Z\u{e9}==").unwrap_err(),
            DecodeError::IllegalChar('\u{e9}')
        );
        assert_eq!(
            decode_vec("Zg=a").unwrap_err(),
            DecodeError::TrailingInput
        );
    }

    #[test]
    fn symbol_converter() {
        for (bin, text) in HAPPY_CASES {
            assert_eq!(&convert(text).unwrap(), bin, "convert {}", text)
        }
    }

    #[test]
    fn symbol_converter_errors() {
        // Incomplete group.
        assert!(convert("Zg").is_err());
        // Padding in third position followed by data.
        assert!(convert("Zg=a").is_err());
        // Data after the final, padded group.
        assert!(convert("Zg==Zg==").is_err());
        // Padding too early in a group.
        assert!(convert("=").is_err());
        // Character outside the alphabet.
        assert!(convert("Zm9v!").is_err());
    }

    #[test]
    fn display_bytes() {
        fn encode(s: &[u8]) -> String {
            let mut out = String::new();
            display(s, &mut out).unwrap();
            out
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&encode(bin), text, "fmt {}", text);
        }
    }

    #[test]
    fn round_trip() {
        let data: Vec<u8> = (0..=255).collect();
        for len in 0..=data.len() {
            let bin = &data[..len];
            let text = encode_string(bin);
            assert_eq!(text.len() % 4, 0, "length of {}", text);
            assert_eq!(decode_vec(&text).unwrap(), bin, "decode {}", text);
            assert_eq!(convert(&text).unwrap(), bin, "convert {}", text);
        }
    }
}