domain/utils/
base64.rs

1//! Decoding and encoding of Base 64.
2//!
3//! The Base 64 encoding is defined in [RFC 4648]. There are two variants
4//! defined in the RFC, dubbed *base64* and *base64url* which are
//! differentiated by the last two characters in the alphabet. The DNS uses
6//! only the original *base64* variant, so this is what is implemented by the
7//! module for now.
8//!
9//! The module defines the type [`Decoder`] which keeps the state necessary
10//! for decoding. The various functions offered use such a decoder to decode
11//! and encode octets in various forms.
12//!
13//! [RFC 4648]: https://tools.ietf.org/html/rfc4648
14
15use crate::base::scan::{ConvertSymbols, EntrySymbol, ScannerError};
16use core::fmt;
17use octseq::builder::{
18    EmptyBuilder, FreezeBuilder, FromBuilder, OctetsBuilder, ShortBuf,
19};
20#[cfg(feature = "std")]
21use std::string::String;
22
23//------------ Convenience Functions -----------------------------------------
24
25/// Decodes a string with *base64* encoded data.
26///
27/// The function attempts to decode the entire string and returns the result
28/// as a `Bytes` value.
29pub fn decode<Octets>(s: &str) -> Result<Octets, DecodeError>
30where
31    Octets: FromBuilder,
32    <Octets as FromBuilder>::Builder: OctetsBuilder + EmptyBuilder,
33{
34    let mut decoder = Decoder::<<Octets as FromBuilder>::Builder>::new();
35    for ch in s.chars() {
36        decoder.push(ch)?;
37    }
38    decoder.finalize()
39}
40
41/// Encodes binary data in *base64* and writes it into a format stream.
42///
43/// This function is intended to be used in implementations of formatting
44/// traits:
45///
46/// ```
47/// use core::fmt;
48/// use domain::utils::base64;
49///
50/// struct Foo<'a>(&'a [u8]);
51///
52/// impl<'a> fmt::Display for Foo<'a> {
53///     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54///         base64::display(&self.0, f)
55///     }
56/// }
57/// ```
58pub fn display<B, W>(bytes: &B, f: &mut W) -> fmt::Result
59where
60    B: AsRef<[u8]> + ?Sized,
61    W: fmt::Write,
62{
63    fn ch(i: u8) -> char {
64        ENCODE_ALPHABET[i as usize]
65    }
66
67    for chunk in bytes.as_ref().chunks(3) {
68        match chunk.len() {
69            1 => {
70                f.write_char(ch(chunk[0] >> 2))?;
71                f.write_char(ch((chunk[0] & 0x03) << 4))?;
72                f.write_char('=')?;
73                f.write_char('=')?;
74            }
75            2 => {
76                f.write_char(ch(chunk[0] >> 2))?;
77                f.write_char(ch((chunk[0] & 0x03) << 4 | chunk[1] >> 4))?;
78                f.write_char(ch((chunk[1] & 0x0F) << 2))?;
79                f.write_char('=')?;
80            }
81            3 => {
82                f.write_char(ch(chunk[0] >> 2))?;
83                f.write_char(ch((chunk[0] & 0x03) << 4 | chunk[1] >> 4))?;
84                f.write_char(ch((chunk[1] & 0x0F) << 2 | chunk[2] >> 6))?;
85                f.write_char(ch(chunk[2] & 0x3F))?;
86            }
87            _ => unreachable!(),
88        }
89    }
90    Ok(())
91}
92
/// Returns the *base64* encoding of binary data as a newly allocated string.
///
/// The encoding itself is produced by [`display`]; this function only
/// provides a pre-sized `String` for it to write into.
#[cfg(feature = "std")]
pub fn encode_string<B: AsRef<[u8]> + ?Sized>(bytes: &B) -> String {
    // Upper bound on the number of three-octet groups; each group encodes
    // to four characters.
    let groups = bytes.as_ref().len() / 3 + 1;
    let mut encoded = String::with_capacity(groups * 4);
    display(bytes, &mut encoded).unwrap();
    encoded
}
100
101/// Returns a placeholder value that implements `Display` for encoded data.
102pub fn encode_display<Octets: AsRef<[u8]>>(
103    octets: &Octets,
104) -> impl fmt::Display + '_ {
105    struct Display<'a>(&'a [u8]);
106
107    impl<'a> fmt::Display for Display<'a> {
108        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
109            display(self.0, f)
110        }
111    }
112
113    Display(octets.as_ref())
114}
115
/// Serialize and deserialize octets Base64 encoded or binary.
///
/// This module can be used with Serde’s `with` attribute. It will serialize
/// an octets sequence as a Base64 encoded string with human readable
/// serializers or as a raw octets sequence for compact serializers.
#[cfg(feature = "serde")]
pub mod serde {
    use super::encode_display;
    use core::fmt;
    use octseq::builder::{EmptyBuilder, FromBuilder, OctetsBuilder};
    use octseq::serde::{DeserializeOctets, SerializeOctets};

    /// Serializes `octets` for use with `#[serde(with = "...")]`.
    ///
    /// Human readable serializers receive the Base64 text of the octets;
    /// all other serializers receive whatever `serialize_octets` produces.
    pub fn serialize<Octets, S>(
        octets: &Octets,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        Octets: AsRef<[u8]> + SerializeOctets,
        S: serde::Serializer,
    {
        if serializer.is_human_readable() {
            // `collect_str` streams the `Display` output, so no temporary
            // string has to be built here.
            serializer.collect_str(&encode_display(octets))
        } else {
            octets.serialize_octets(serializer)
        }
    }

    /// Deserializes octets for use with `#[serde(with = "...")]`.
    ///
    /// Human readable deserializers are asked for a string, which is then
    /// Base64 decoded. For all other deserializers, the octets type's own
    /// deserialization is used, with a visitor that additionally accepts a
    /// Base64 string.
    pub fn deserialize<'de, Octets, D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> Result<Octets, D::Error>
    where
        Octets: FromBuilder + DeserializeOctets<'de>,
        <Octets as FromBuilder>::Builder: EmptyBuilder,
    {
        /// Wraps the octets type's own visitor and adds string decoding.
        struct Visitor<'de, Octets: DeserializeOctets<'de>>(Octets::Visitor);

        impl<'de, Octets> serde::de::Visitor<'de> for Visitor<'de, Octets>
        where
            Octets: FromBuilder + DeserializeOctets<'de>,
            <Octets as FromBuilder>::Builder: OctetsBuilder + EmptyBuilder,
        {
            type Value = Octets;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                // Serde splices this text into "expected ..." messages, so
                // it has to read as a noun phrase.
                f.write_str("a Base64-encoded string")
            }

            // Strings are always taken to be Base64 text.
            fn visit_str<E: serde::de::Error>(
                self,
                v: &str,
            ) -> Result<Self::Value, E> {
                super::decode(v).map_err(E::custom)
            }

            // Raw bytes are handed to the wrapped visitor unchanged.
            fn visit_borrowed_bytes<E: serde::de::Error>(
                self,
                value: &'de [u8],
            ) -> Result<Self::Value, E> {
                self.0.visit_borrowed_bytes(value)
            }

            #[cfg(feature = "std")]
            fn visit_byte_buf<E: serde::de::Error>(
                self,
                value: std::vec::Vec<u8>,
            ) -> Result<Self::Value, E> {
                self.0.visit_byte_buf(value)
            }
        }

        if deserializer.is_human_readable() {
            deserializer.deserialize_str(Visitor(Octets::visitor()))
        } else {
            Octets::deserialize_with_visitor(
                deserializer,
                Visitor(Octets::visitor()),
            )
        }
    }
}
196
197//------------ Decoder -------------------------------------------------------
198
/// A base 64 decoder.
///
/// This type keeps all the state for decoding a sequence of characters
/// representing data encoded in base 64. Characters are fed in one at a
/// time via `push`. Upon success, `finalize` returns the decoded data as
/// the octets type produced by freezing the `Builder`.
pub struct Decoder<Builder> {
    /// A buffer for up to four characters.
    ///
    /// We only keep `u8`s here because only ASCII characters are used by
    /// Base64. Each entry holds the decoded 6-bit value of a character, or
    /// the padding marker (0x80) if the character was padding.
    buf: [u8; 4],

    /// The index in `buf` where we place the next character.
    ///
    /// We also abuse this to mark when we are done (because there was
    /// padding, in which case we set it to 0xF0).
    next: usize,

    /// The target or an error if something went wrong.
    target: Result<Builder, DecodeError>,
}
220
221impl<Builder: EmptyBuilder> Decoder<Builder> {
222    /// Creates a new empty decoder.
223    #[must_use]
224    pub fn new() -> Self {
225        Decoder {
226            buf: [0; 4],
227            next: 0,
228            target: Ok(Builder::empty()),
229        }
230    }
231}
232
233impl<Builder: OctetsBuilder> Decoder<Builder> {
234    /// Finalizes decoding and returns the decoded data.
235    pub fn finalize(self) -> Result<Builder::Octets, DecodeError>
236    where
237        Builder: FreezeBuilder,
238    {
239        let (target, next) = (self.target, self.next);
240        target.and_then(|bytes| {
241            // next is either 0 or 0xF0 for a completed group.
242            if next & 0x0F != 0 {
243                Err(DecodeError::ShortInput)
244            } else {
245                Ok(bytes.freeze())
246            }
247        })
248    }
249
250    /// Decodes one more character of data.
251    ///
252    /// Returns an error as soon as the encoded data is determined to be
253    /// illegal. It is okay to push more data after the first error. The
254    /// method will just keep returned errors.
255    pub fn push(&mut self, ch: char) -> Result<(), DecodeError> {
256        if self.next == 0xF0 {
257            self.target = Err(DecodeError::TrailingInput);
258            return Err(DecodeError::TrailingInput);
259        }
260
261        let val = if ch == PAD {
262            // Only up to two padding characters possible.
263            if self.next < 2 {
264                return Err(DecodeError::IllegalChar(ch));
265            }
266            0x80 // Acts as a marker later on.
267        } else {
268            if ch > (127 as char) {
269                return Err(DecodeError::IllegalChar(ch));
270            }
271            let val = DECODE_ALPHABET[ch as usize];
272            if val == 0xFF {
273                return Err(DecodeError::IllegalChar(ch));
274            }
275            val
276        };
277        self.buf[self.next] = val;
278        self.next += 1;
279
280        if self.next == 4 {
281            let target = self.target.as_mut().unwrap(); // Err covered above.
282            target
283                .append_slice(&[self.buf[0] << 2 | self.buf[1] >> 4])
284                .map_err(Into::into)?;
285            if self.buf[2] != 0x80 {
286                target
287                    .append_slice(&[self.buf[1] << 4 | self.buf[2] >> 2])
288                    .map_err(Into::into)?;
289            }
290            if self.buf[3] != 0x80 {
291                if self.buf[2] == 0x80 {
292                    return Err(DecodeError::TrailingInput);
293                }
294                target
295                    .append_slice(&[(self.buf[2] << 6) | self.buf[3]])
296                    .map_err(Into::into)?;
297                self.next = 0
298            } else {
299                self.next = 0xF0
300            }
301        }
302
303        Ok(())
304    }
305}
306
307//--- Default
308
309#[cfg(feature = "bytes")]
310impl<Builder: EmptyBuilder> Default for Decoder<Builder> {
311    fn default() -> Self {
312        Self::new()
313    }
314}
315
316//------------ SymbolConverter -----------------------------------------------
317
/// A Base 64 decoder that can be used as a converter with a scanner.
///
/// Input characters are collected in groups of four. Each completed group
/// is converted into up to three octets which are handed out as a slice
/// borrowed from the converter itself.
#[derive(Clone, Debug, Default)]
pub struct SymbolConverter {
    /// A buffer for up to four input characters.
    ///
    /// We only keep `u8`s here because only ASCII characters are used by
    /// Base64. Each entry holds the decoded 6-bit value of a character or
    /// `PAD_MARKER` for a padding character.
    input: [u8; 4],

    /// The index in `input` where we place the next character.
    ///
    /// We also abuse this to mark when we are done (because there was
    /// padding, in which case we set it to 0xF0, i.e., `EOF_MARKER`).
    next: usize,

    /// A buffer to return a slice for the output.
    output: [u8; 3],
}
336
337impl SymbolConverter {
338    /// Creates a new symbol converter.
339    #[must_use]
340    pub fn new() -> Self {
341        Default::default()
342    }
343
344    fn process_char<Error: ScannerError>(
345        &mut self,
346        ch: char,
347    ) -> Result<Option<&[u8]>, Error> {
348        if self.next == EOF_MARKER {
349            return Err(Error::custom("trailing Base 64 data"));
350        }
351
352        let val = if ch == PAD {
353            // Only up to two padding characters possible.
354            if self.next < 2 {
355                return Err(Error::custom("illegal Base 64 data"));
356            }
357            PAD_MARKER // Acts as a marker later on.
358        } else {
359            if ch > (127 as char) {
360                return Err(Error::custom("illegal Base 64 data"));
361            }
362            let val = DECODE_ALPHABET[ch as usize];
363            if val == 0xFF {
364                return Err(Error::custom("illegal Base 64 data"));
365            }
366            val
367        };
368        self.input[self.next] = val;
369        self.next += 1;
370
371        if self.next == 4 {
372            self.output[0] = self.input[0] << 2 | self.input[1] >> 4;
373
374            if self.input[2] == PAD_MARKER {
375                // The second to last character is padding. The last one
376                // needs to be, too.
377                if self.input[3] == PAD_MARKER {
378                    self.next = EOF_MARKER;
379                    Ok(Some(&self.output[..1]))
380                } else {
381                    Err(Error::custom("illegal Base 64 data"))
382                }
383            } else {
384                self.output[1] = self.input[1] << 4 | self.input[2] >> 2;
385
386                if self.input[3] == PAD_MARKER {
387                    // The last characters is padding.
388                    self.next = EOF_MARKER;
389                    Ok(Some(&self.output[..2]))
390                } else {
391                    self.output[2] = (self.input[2] << 6) | self.input[3];
392                    self.next = 0;
393                    Ok(Some(&self.output))
394                }
395            }
396        } else {
397            Ok(None)
398        }
399    }
400}
401
402impl<Sym, Error> ConvertSymbols<Sym, Error> for SymbolConverter
403where
404    Sym: Into<EntrySymbol>,
405    Error: ScannerError,
406{
407    fn process_symbol(
408        &mut self,
409        symbol: Sym,
410    ) -> Result<Option<&[u8]>, Error> {
411        match symbol.into() {
412            EntrySymbol::Symbol(symbol) => self.process_char(
413                symbol
414                    .into_char()
415                    .map_err(|_| Error::custom("illegal Base 64 data"))?,
416            ),
417            EntrySymbol::EndOfToken => Ok(None),
418        }
419    }
420
421    fn process_tail(&mut self) -> Result<Option<&[u8]>, Error> {
422        // next is either 0 or 0xF0 for a completed group.
423        if self.next & 0x0F != 0 {
424            Err(Error::custom("incomplete Base 64 data"))
425        } else {
426            Ok(None)
427        }
428    }
429}
430
431//============ Error Types ===================================================
432
433//------------ DecodeError ---------------------------------------------------
434
/// An error happened while decoding a base 64 or base 32 encoded string.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// A character was pushed that isn’t allowed in the encoding.
    ///
    /// The variant carries the offending character.
    IllegalChar(char),

    /// There was trailing data after a padding sequence.
    TrailingInput,

    /// The input ended with an incomplete sequence.
    ShortInput,

    /// The buffer to decode into is too short.
    ///
    /// This is what a `ShortBuf` error from the octets builder is
    /// converted into.
    ShortBuf,
}
450
451impl From<ShortBuf> for DecodeError {
452    fn from(_: ShortBuf) -> Self {
453        DecodeError::ShortBuf
454    }
455}
456
457//--- Display and Error
458
459impl fmt::Display for DecodeError {
460    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
461        match *self {
462            DecodeError::TrailingInput => f.write_str("trailing input"),
463            DecodeError::IllegalChar(ch) => {
464                write!(f, "illegal character '{}'", ch)
465            }
466            DecodeError::ShortInput => f.write_str("incomplete input"),
467            DecodeError::ShortBuf => ShortBuf.fmt(f),
468        }
469    }
470}
471
472#[cfg(feature = "std")]
473impl std::error::Error for DecodeError {}
474
475//============ Constants =====================================================
476
/// The alphabet used by the decoder.
///
/// This maps encoding characters into their values. A value of 0xFF stands in
/// for illegal characters. We only provide the first 128 characters since the
/// alphabet will only use ASCII characters.
///
/// The padding character `=` is marked as illegal here, too. The decoders
/// check for it before consulting this table.
const DECODE_ALPHABET: [u8; 128] = [
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x00 .. 0x07
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x08 .. 0x0F
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x10 .. 0x17
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x18 .. 0x1F
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x20 .. 0x27
    0xFF, 0xFF, 0xFF, 0x3E, 0xFF, 0xFF, 0xFF, 0x3F, // 0x28 .. 0x2F
    0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, // 0x30 .. 0x37
    0x3C, 0x3D, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x38 .. 0x3F
    0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, // 0x40 .. 0x47
    0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, // 0x48 .. 0x4F
    0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, // 0x50 .. 0x57
    0x17, 0x18, 0x19, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x58 .. 0x5F
    0xFF, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, // 0x60 .. 0x67
    0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, // 0x68 .. 0x6F
    0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, // 0x70 .. 0x77
    0x31, 0x32, 0x33, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x78 .. 0x7F
];
500
/// The alphabet used by the encoder.
///
/// This maps each 6-bit value to the character representing it. The table
/// is indexed with the value itself, so it has exactly 64 entries.
const ENCODE_ALPHABET: [char; 64] = [
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', // 0x00 .. 0x07
    'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', // 0x08 .. 0x0F
    'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', // 0x10 .. 0x17
    'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', // 0x18 .. 0x1F
    'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', // 0x20 .. 0x27
    'o', 'p', 'q', 'r', 's', 't', 'u', 'v', // 0x28 .. 0x2F
    'w', 'x', 'y', 'z', '0', '1', '2', '3', // 0x30 .. 0x37
    '4', '5', '6', '7', '8', '9', '+', '/', // 0x38 .. 0x3F
];
511
/// The padding character.
const PAD: char = '=';

/// The marker for padding.
///
/// This is stored in the input buffer in place of a decoded value when a
/// padding character was read. Decoded values never exceed 0x3F, so the
/// marker cannot be confused with one.
const PAD_MARKER: u8 = 0x80;

/// The marker for complete data.
///
/// The `next` index is set to this value once a padded group has been
/// processed. Its low nibble is zero, so the `next & 0x0F` checks for an
/// unfinished group treat it the same as 0.
const EOF_MARKER: usize = 0xF0;
520
521//============ Test ==========================================================
522
#[cfg(test)]
mod test {
    use super::*;

    /// Pairs of binary data and its *base64* encoding.
    //
    // Every test reading this table is gated on `std`, so the table is dead
    // code in test builds without that feature.
    #[allow(dead_code)]
    const HAPPY_CASES: &[(&[u8], &str)] = &[
        (b"", ""),
        (b"f", "Zg=="),
        (b"fo", "Zm8="),
        (b"foo", "Zm9v"),
        (b"foob", "Zm9vYg=="),
        (b"fooba", "Zm9vYmE="),
        (b"foobar", "Zm9vYmFy"),
    ];

    /// Checks `decode` against the happy cases and a set of broken inputs.
    #[cfg(feature = "std")]
    #[test]
    fn decode_str() {
        fn to_vec(s: &str) -> Result<std::vec::Vec<u8>, DecodeError> {
            super::decode(s)
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&to_vec(text).unwrap(), bin, "decode {}", text)
        }

        let rejected = [
            ("FPucA", DecodeError::ShortInput),
            ("FPucA=", DecodeError::IllegalChar('=')),
            ("FPucAw=", DecodeError::ShortInput),
            ("FPucAw=a", DecodeError::TrailingInput),
            ("FPucAw==a", DecodeError::TrailingInput),
        ];
        for (text, expected) in rejected.iter() {
            assert_eq!(
                to_vec(text).unwrap_err(),
                *expected,
                "decode {}",
                text
            );
        }
    }

    /// Checks that `SymbolConverter` decodes the happy cases when driven
    /// through the `ConvertSymbols` trait object.
    #[cfg(feature = "std")]
    #[test]
    fn symbol_converter() {
        use crate::base::scan::Symbols;
        use std::vec::Vec;

        fn convert_all(s: &str) -> Result<Vec<u8>, std::io::Error> {
            let mut converter = SymbolConverter::new();
            let converter: &mut dyn ConvertSymbols<_, std::io::Error> =
                &mut converter;
            let mut collected = Vec::new();
            for sym in Symbols::new(s.chars()) {
                if let Some(octs) = converter.process_symbol(sym)? {
                    collected.extend_from_slice(octs);
                }
            }
            if let Some(octs) = converter.process_tail()? {
                collected.extend_from_slice(octs);
            }
            Ok(collected)
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&convert_all(text).unwrap(), bin, "convert {}", text)
        }
    }

    /// Checks that `display` encodes the happy cases.
    #[cfg(feature = "std")]
    #[test]
    fn display_bytes() {
        fn encode(s: &[u8]) -> String {
            let mut encoded = String::new();
            display(s, &mut encoded).unwrap();
            encoded
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&encode(bin), text, "fmt {}", text);
        }
    }
}