base64ct/
encoding.rs

1//! Base64 encodings
2
3use crate::{
4    alphabet::Alphabet,
5    errors::{Error, InvalidEncodingError, InvalidLengthError},
6};
7use core::str;
8
9#[cfg(feature = "alloc")]
10use alloc::{string::String, vec::Vec};
11
12#[cfg(doc)]
13use crate::{Base64, Base64Bcrypt, Base64Crypt, Base64Unpadded, Base64Url, Base64UrlUnpadded};
14
/// Padding character (`=`) written into the final block of padded encodings
/// and recognised when stripping padding before decoding.
const PAD: u8 = b'=';
17
/// Base64 encoding trait.
///
/// This trait must be imported to make use of any Base64 alphabet defined
/// in this crate.
///
/// It is implemented for every type which implements [`Alphabet`]; whether
/// `=` padding is produced and expected is controlled by the alphabet's
/// `PADDED` constant.
///
/// The following encoding types impl this trait:
///
/// - [`Base64`]: standard Base64 encoding with `=` padding.
/// - [`Base64Bcrypt`]: bcrypt Base64 encoding.
/// - [`Base64Crypt`]: `crypt(3)` Base64 encoding.
/// - [`Base64Unpadded`]: standard Base64 encoding *without* padding.
/// - [`Base64Url`]: URL-safe Base64 encoding with `=` padding.
/// - [`Base64UrlUnpadded`]: URL-safe Base64 encoding *without* padding.
pub trait Encoding: Alphabet {
    /// Decode a Base64 string into the provided destination buffer.
    ///
    /// On success, returns the decoded bytes as a sub-slice at the start of
    /// `dst`.
    ///
    /// # Errors
    /// - [`Error::InvalidLength`] if `dst` is too small to hold the decoded
    ///   output.
    /// - [`Error::InvalidEncoding`] if `src` is not well-formed Base64 for
    ///   this alphabet, including a final block which does not round-trip
    ///   back to the same encoded characters.
    fn decode(src: impl AsRef<[u8]>, dst: &mut [u8]) -> Result<&[u8], Error>;

    /// Decode a Base64 string in-place.
    ///
    /// The decoded bytes overwrite the beginning of `buf`, and the returned
    /// slice borrows from `buf`.
    ///
    /// NOTE: this method does not (yet) validate that padding is well-formed,
    /// if the given Base64 encoding is padded.
    ///
    /// # Errors
    /// Returns [`InvalidEncodingError`] if `buf` does not contain valid
    /// Base64 for this alphabet.
    fn decode_in_place(buf: &mut [u8]) -> Result<&[u8], InvalidEncodingError>;

    /// Decode a Base64 string into a byte vector.
    ///
    /// # Errors
    /// Returns the same errors as [`Encoding::decode`].
    #[cfg(feature = "alloc")]
    #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
    fn decode_vec(input: &str) -> Result<Vec<u8>, Error>;

    /// Encode the input byte slice as Base64.
    ///
    /// Writes the result into the provided destination slice, returning an
    /// ASCII-encoded Base64 string value.
    ///
    /// # Errors
    /// Returns [`InvalidLengthError`] if `dst` is too small to hold the
    /// encoded output, or if the encoded length cannot be represented as a
    /// `usize`.
    fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, InvalidLengthError>;

    /// Encode input byte slice into a [`String`] containing Base64.
    ///
    /// # Panics
    /// If `input` length is greater than `usize::MAX/4`.
    #[cfg(feature = "alloc")]
    #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
    fn encode_string(input: &[u8]) -> String;

    /// Get the length of Base64 produced by encoding the given bytes.
    ///
    /// Only the length of `bytes` is inspected; its contents are not read.
    ///
    /// WARNING: this function will return `0` for lengths greater than `usize::MAX/4`!
    fn encoded_len(bytes: &[u8]) -> usize;
}
65
66impl<T: Alphabet> Encoding for T {
67    fn decode(src: impl AsRef<[u8]>, dst: &mut [u8]) -> Result<&[u8], Error> {
68        let (src_unpadded, mut err) = if T::PADDED {
69            let (unpadded_len, e) = decode_padding(src.as_ref())?;
70            (&src.as_ref()[..unpadded_len], e)
71        } else {
72            (src.as_ref(), 0)
73        };
74
75        let dlen = decoded_len(src_unpadded.len());
76
77        if dlen > dst.len() {
78            return Err(Error::InvalidLength);
79        }
80
81        let dst = &mut dst[..dlen];
82
83        let mut src_chunks = src_unpadded.chunks_exact(4);
84        let mut dst_chunks = dst.chunks_exact_mut(3);
85        for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
86            err |= Self::decode_3bytes(s, d);
87        }
88        let src_rem = src_chunks.remainder();
89        let dst_rem = dst_chunks.into_remainder();
90
91        err |= !(src_rem.is_empty() || src_rem.len() >= 2) as i16;
92        let mut tmp_out = [0u8; 3];
93        let mut tmp_in = [b'A'; 4];
94        tmp_in[..src_rem.len()].copy_from_slice(src_rem);
95        err |= Self::decode_3bytes(&tmp_in, &mut tmp_out);
96        dst_rem.copy_from_slice(&tmp_out[..dst_rem.len()]);
97
98        if err == 0 {
99            validate_last_block::<T>(src.as_ref(), dst)?;
100            Ok(dst)
101        } else {
102            Err(Error::InvalidEncoding)
103        }
104    }
105
106    // TODO(tarcieri): explicitly checked/wrapped arithmetic
107    #[allow(clippy::integer_arithmetic)]
108    fn decode_in_place(mut buf: &mut [u8]) -> Result<&[u8], InvalidEncodingError> {
109        // TODO: eliminate unsafe code when LLVM12 is stable
110        // See: https://github.com/rust-lang/rust/issues/80963
111        let mut err = if T::PADDED {
112            let (unpadded_len, e) = decode_padding(buf)?;
113            buf = &mut buf[..unpadded_len];
114            e
115        } else {
116            0
117        };
118
119        let dlen = decoded_len(buf.len());
120        let full_chunks = buf.len() / 4;
121
122        for chunk in 0..full_chunks {
123            // SAFETY: `p3` and `p4` point inside `buf`, while they may overlap,
124            // read and write are clearly separated from each other and done via
125            // raw pointers.
126            #[allow(unsafe_code)]
127            unsafe {
128                debug_assert!(3 * chunk + 3 <= buf.len());
129                debug_assert!(4 * chunk + 4 <= buf.len());
130
131                let p3 = buf.as_mut_ptr().add(3 * chunk) as *mut [u8; 3];
132                let p4 = buf.as_ptr().add(4 * chunk) as *const [u8; 4];
133
134                let mut tmp_out = [0u8; 3];
135                err |= Self::decode_3bytes(&*p4, &mut tmp_out);
136                *p3 = tmp_out;
137            }
138        }
139
140        let src_rem_pos = 4 * full_chunks;
141        let src_rem_len = buf.len() - src_rem_pos;
142        let dst_rem_pos = 3 * full_chunks;
143        let dst_rem_len = dlen - dst_rem_pos;
144
145        err |= !(src_rem_len == 0 || src_rem_len >= 2) as i16;
146        let mut tmp_in = [b'A'; 4];
147        tmp_in[..src_rem_len].copy_from_slice(&buf[src_rem_pos..]);
148        let mut tmp_out = [0u8; 3];
149
150        err |= Self::decode_3bytes(&tmp_in, &mut tmp_out);
151
152        if err == 0 {
153            // SAFETY: `dst_rem_len` is always smaller than 4, so we don't
154            // read outside of `tmp_out`, write and the final slicing never go
155            // outside of `buf`.
156            #[allow(unsafe_code)]
157            unsafe {
158                debug_assert!(dst_rem_pos + dst_rem_len <= buf.len());
159                debug_assert!(dst_rem_len <= tmp_out.len());
160                debug_assert!(dlen <= buf.len());
161
162                core::ptr::copy_nonoverlapping(
163                    tmp_out.as_ptr(),
164                    buf.as_mut_ptr().add(dst_rem_pos),
165                    dst_rem_len,
166                );
167                Ok(buf.get_unchecked(..dlen))
168            }
169        } else {
170            Err(InvalidEncodingError)
171        }
172    }
173
174    #[cfg(feature = "alloc")]
175    fn decode_vec(input: &str) -> Result<Vec<u8>, Error> {
176        let mut output = vec![0u8; decoded_len(input.len())];
177        let len = Self::decode(input, &mut output)?.len();
178
179        if len <= output.len() {
180            output.truncate(len);
181            Ok(output)
182        } else {
183            Err(Error::InvalidLength)
184        }
185    }
186
187    fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, InvalidLengthError> {
188        let elen = match encoded_len_inner(src.len(), T::PADDED) {
189            Some(v) => v,
190            None => return Err(InvalidLengthError),
191        };
192
193        if elen > dst.len() {
194            return Err(InvalidLengthError);
195        }
196
197        let dst = &mut dst[..elen];
198
199        let mut src_chunks = src.chunks_exact(3);
200        let mut dst_chunks = dst.chunks_exact_mut(4);
201
202        for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
203            Self::encode_3bytes(s, d);
204        }
205
206        let src_rem = src_chunks.remainder();
207
208        if T::PADDED {
209            if let Some(dst_rem) = dst_chunks.next() {
210                let mut tmp = [0u8; 3];
211                tmp[..src_rem.len()].copy_from_slice(src_rem);
212                Self::encode_3bytes(&tmp, dst_rem);
213
214                let flag = src_rem.len() == 1;
215                let mask = (flag as u8).wrapping_sub(1);
216                dst_rem[2] = (dst_rem[2] & mask) | (PAD & !mask);
217                dst_rem[3] = PAD;
218            }
219        } else {
220            let dst_rem = dst_chunks.into_remainder();
221
222            let mut tmp_in = [0u8; 3];
223            let mut tmp_out = [0u8; 4];
224            tmp_in[..src_rem.len()].copy_from_slice(src_rem);
225            Self::encode_3bytes(&tmp_in, &mut tmp_out);
226            dst_rem.copy_from_slice(&tmp_out[..dst_rem.len()]);
227        }
228
229        debug_assert!(str::from_utf8(dst).is_ok());
230
231        // SAFETY: values written by `encode_3bytes` are valid one-byte UTF-8 chars
232        #[allow(unsafe_code)]
233        Ok(unsafe { str::from_utf8_unchecked(dst) })
234    }
235
236    #[cfg(feature = "alloc")]
237    fn encode_string(input: &[u8]) -> String {
238        let elen = encoded_len_inner(input.len(), T::PADDED).expect("input is too big");
239        let mut dst = vec![0u8; elen];
240        let res = Self::encode(input, &mut dst).expect("encoding error");
241
242        debug_assert_eq!(elen, res.len());
243        debug_assert!(str::from_utf8(&dst).is_ok());
244
245        // SAFETY: `dst` is fully written and contains only valid one-byte UTF-8 chars
246        #[allow(unsafe_code)]
247        unsafe {
248            String::from_utf8_unchecked(dst)
249        }
250    }
251
252    fn encoded_len(bytes: &[u8]) -> usize {
253        encoded_len_inner(bytes.len(), T::PADDED).unwrap_or(0)
254    }
255}
256
257/// Validate padding is of the expected length compute unpadded length.
258///
259/// Note that this method does not explicitly check that the padded data
260/// is valid in and of itself: that is performed by `validate_last_block` as a
261/// final step.
262///
263/// Returns length-related errors eagerly as a [`Result`], and data-dependent
264/// errors (i.e. malformed padding bytes) as `i16` to be combined with other
265/// encoding-related errors prior to branching.
266#[inline(always)]
267pub(crate) fn decode_padding(input: &[u8]) -> Result<(usize, i16), InvalidEncodingError> {
268    if input.len() % 4 != 0 {
269        return Err(InvalidEncodingError);
270    }
271
272    let unpadded_len = match *input {
273        [.., b0, b1] => is_pad_ct(b0)
274            .checked_add(is_pad_ct(b1))
275            .and_then(|len| len.try_into().ok())
276            .and_then(|len| input.len().checked_sub(len))
277            .ok_or(InvalidEncodingError)?,
278        _ => input.len(),
279    };
280
281    let padding_len = input
282        .len()
283        .checked_sub(unpadded_len)
284        .ok_or(InvalidEncodingError)?;
285
286    let err = match *input {
287        [.., b0] if padding_len == 1 => is_pad_ct(b0) ^ 1,
288        [.., b0, b1] if padding_len == 2 => (is_pad_ct(b0) & is_pad_ct(b1)) ^ 1,
289        _ => {
290            if padding_len == 0 {
291                0
292            } else {
293                return Err(InvalidEncodingError);
294            }
295        }
296    };
297
298    Ok((unpadded_len, err))
299}
300
301/// Validate that the last block of the decoded data round-trips back to the
302/// encoded data.
303fn validate_last_block<T: Alphabet>(encoded: &[u8], decoded: &[u8]) -> Result<(), Error> {
304    if encoded.is_empty() && decoded.is_empty() {
305        return Ok(());
306    }
307
308    // TODO(tarcieri): explicitly checked/wrapped arithmetic
309    #[allow(clippy::integer_arithmetic)]
310    fn last_block_start(bytes: &[u8], block_size: usize) -> usize {
311        (bytes.len().saturating_sub(1) / block_size) * block_size
312    }
313
314    let enc_block = encoded
315        .get(last_block_start(encoded, 4)..)
316        .ok_or(Error::InvalidEncoding)?;
317
318    let dec_block = decoded
319        .get(last_block_start(decoded, 3)..)
320        .ok_or(Error::InvalidEncoding)?;
321
322    // Round-trip encode the decoded block
323    let mut buf = [0u8; 4];
324    let block = T::encode(dec_block, &mut buf)?;
325
326    // Non-short-circuiting comparison of padding
327    // TODO(tarcieri): better constant-time mechanisms (e.g. `subtle`)?
328    if block
329        .as_bytes()
330        .iter()
331        .zip(enc_block.iter())
332        .fold(0, |acc, (a, b)| acc | (a ^ b))
333        == 0
334    {
335        Ok(())
336    } else {
337        Err(Error::InvalidEncoding)
338    }
339}
340
/// Compute how many bytes decoding *unpadded* Base64 input of the given
/// length produces.
///
/// Note that this function does not fully validate the Base64 is well-formed
/// and may return incorrect results for malformed Base64.
// TODO(tarcieri): explicitly checked/wrapped arithmetic
#[allow(clippy::integer_arithmetic)]
#[inline(always)]
pub(crate) fn decoded_len(input_len: usize) -> usize {
    // Equivalent to `(3 * input_len) / 4`, but split into whole 4-character
    // blocks plus leftover so the multiplication cannot overflow.
    let whole_blocks = input_len / 4;
    let leftover_chars = input_len % 4;
    whole_blocks * 3 + (leftover_chars * 3) / 4
}
355
356/// Branchless match that a given byte is the `PAD` character
357// TODO(tarcieri): explicitly checked/wrapped arithmetic
358#[allow(clippy::integer_arithmetic)]
359#[inline(always)]
360fn is_pad_ct(input: u8) -> i16 {
361    ((((PAD as i16 - 1) - input as i16) & (input as i16 - (PAD as i16 + 1))) >> 8) & 1
362}
363
// Length of the Base64 encoding of `n` input bytes, with or without `=`
// padding. Yields `None` when `4 * n` does not fit in a `usize`.
// TODO(tarcieri): explicitly checked/wrapped arithmetic
#[allow(clippy::integer_arithmetic)]
#[inline(always)]
const fn encoded_len_inner(n: usize, padded: bool) -> Option<usize> {
    let quadrupled = match n.checked_mul(4) {
        Some(value) => value,
        None => return None,
    };
    let whole = quadrupled / 3;

    let len = if padded {
        // Round up to the next multiple of 4.
        (whole + 3) & !3
    } else {
        // Round the division up instead of down.
        whole + (quadrupled % 3 != 0) as usize
    };

    Some(len)
}
378}