base64ct/
encoder.rs

1//! Buffered Base64 encoder.
2
3use crate::{
4    Encoding,
5    Error::{self, InvalidLength},
6    LineEnding, MIN_LINE_WIDTH,
7};
8use core::{cmp, marker::PhantomData, str};
9
10#[cfg(feature = "std")]
11use std::io;
12
13#[cfg(doc)]
14use crate::{Base64, Base64Unpadded};
15
16/// Stateful Base64 encoder with support for buffered, incremental encoding.
17///
18/// The `E` type parameter can be any type which impls [`Encoding`] such as
19/// [`Base64`] or [`Base64Unpadded`].
20pub struct Encoder<'o, E: Encoding> {
21    /// Output buffer.
22    output: &'o mut [u8],
23
24    /// Cursor within the output buffer.
25    position: usize,
26
27    /// Block buffer used for non-block-aligned data.
28    block_buffer: BlockBuffer,
29
30    /// Configuration and state for line-wrapping the output at a specified
31    /// column.
32    line_wrapper: Option<LineWrapper>,
33
34    /// Phantom parameter for the Base64 encoding in use.
35    encoding: PhantomData<E>,
36}
37
38impl<'o, E: Encoding> Encoder<'o, E> {
39    /// Create a new encoder which writes output to the given byte slice.
40    ///
41    /// Output constructed using this method is not line-wrapped.
42    pub fn new(output: &'o mut [u8]) -> Result<Self, Error> {
43        if output.is_empty() {
44            return Err(InvalidLength);
45        }
46
47        Ok(Self {
48            output,
49            position: 0,
50            block_buffer: BlockBuffer::default(),
51            line_wrapper: None,
52            encoding: PhantomData,
53        })
54    }
55
56    /// Create a new encoder which writes line-wrapped output to the given byte
57    /// slice.
58    ///
59    /// Output will be wrapped at the specified interval, using the provided
60    /// line ending. Use [`LineEnding::default()`] to use the conventional line
61    /// ending for the target OS.
62    ///
63    /// Minimum allowed line width is 4.
64    pub fn new_wrapped(
65        output: &'o mut [u8],
66        width: usize,
67        ending: LineEnding,
68    ) -> Result<Self, Error> {
69        let mut encoder = Self::new(output)?;
70        encoder.line_wrapper = Some(LineWrapper::new(width, ending)?);
71        Ok(encoder)
72    }
73
74    /// Encode the provided buffer as Base64, writing it to the output buffer.
75    ///
76    /// # Returns
77    /// - `Ok(bytes)` if the expected amount of data was read
78    /// - `Err(Error::InvalidLength)` if there is insufficient space in the output buffer
79    pub fn encode(&mut self, mut input: &[u8]) -> Result<(), Error> {
80        // If there's data in the block buffer, fill it
81        if !self.block_buffer.is_empty() {
82            self.process_buffer(&mut input)?;
83        }
84
85        while !input.is_empty() {
86            // Attempt to encode a stride of block-aligned data
87            let in_blocks = input.len() / 3;
88            let out_blocks = self.remaining().len() / 4;
89            let mut blocks = cmp::min(in_blocks, out_blocks);
90
91            // When line wrapping, cap the block-aligned stride at near/at line length
92            if let Some(line_wrapper) = &self.line_wrapper {
93                line_wrapper.wrap_blocks(&mut blocks)?;
94            }
95
96            if blocks > 0 {
97                let len = blocks.checked_mul(3).ok_or(InvalidLength)?;
98                let (in_aligned, in_rem) = input.split_at(len);
99                input = in_rem;
100                self.perform_encode(in_aligned)?;
101            }
102
103            // If there's remaining non-aligned data, fill the block buffer
104            if !input.is_empty() {
105                self.process_buffer(&mut input)?;
106            }
107        }
108
109        Ok(())
110    }
111
112    /// Get the position inside of the output buffer where the write cursor
113    /// is currently located.
114    pub fn position(&self) -> usize {
115        self.position
116    }
117
118    /// Finish encoding data, returning the resulting Base64 as a `str`.
119    pub fn finish(self) -> Result<&'o str, Error> {
120        self.finish_with_remaining().map(|(base64, _)| base64)
121    }
122
123    /// Finish encoding data, returning the resulting Base64 as a `str`
124    /// along with the remaining space in the output buffer.
125    pub fn finish_with_remaining(mut self) -> Result<(&'o str, &'o mut [u8]), Error> {
126        if !self.block_buffer.is_empty() {
127            let buffer_len = self.block_buffer.position;
128            let block = self.block_buffer.bytes;
129            self.perform_encode(&block[..buffer_len])?;
130        }
131
132        let (base64, remaining) = self.output.split_at_mut(self.position);
133        Ok((str::from_utf8(base64)?, remaining))
134    }
135
136    /// Borrow the remaining data in the buffer.
137    fn remaining(&mut self) -> &mut [u8] {
138        &mut self.output[self.position..]
139    }
140
141    /// Fill the block buffer with data, consuming and encoding it when the
142    /// buffer is full.
143    fn process_buffer(&mut self, input: &mut &[u8]) -> Result<(), Error> {
144        self.block_buffer.fill(input)?;
145
146        if self.block_buffer.is_full() {
147            let block = self.block_buffer.take();
148            self.perform_encode(&block)?;
149        }
150
151        Ok(())
152    }
153
154    /// Perform Base64 encoding operation.
155    fn perform_encode(&mut self, input: &[u8]) -> Result<usize, Error> {
156        let mut len = E::encode(input, self.remaining())?.as_bytes().len();
157
158        // Insert newline characters into the output as needed
159        if let Some(line_wrapper) = &mut self.line_wrapper {
160            line_wrapper.insert_newlines(&mut self.output[self.position..], &mut len)?;
161        }
162
163        self.position = self.position.checked_add(len).ok_or(InvalidLength)?;
164        Ok(len)
165    }
166}
167
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<'o, E: Encoding> io::Write for Encoder<'o, E> {
    /// Encode all of `buf`, reporting its full length as written on success.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.encode(buf) {
            Ok(()) => Ok(buf.len()),
            Err(err) => Err(err.into()),
        }
    }

    /// No-op: always succeeds without touching the encoder state.
    fn flush(&mut self) -> io::Result<()> {
        // TODO(tarcieri): return an error if there's still data remaining in the buffer?
        Ok(())
    }
}
181
182/// Base64 encode buffer for a 1-block output.
183///
184/// This handles a partial block of data, i.e. data which hasn't been
185#[derive(Clone, Default, Debug)]
186struct BlockBuffer {
187    /// 3 decoded bytes to be encoded to a 4-byte Base64-encoded input.
188    bytes: [u8; Self::SIZE],
189
190    /// Position within the buffer.
191    position: usize,
192}
193
194impl BlockBuffer {
195    /// Size of the buffer in bytes: 3-bytes of unencoded input which
196    /// Base64 encode to 4-bytes of output.
197    const SIZE: usize = 3;
198
199    /// Fill the remaining space in the buffer with the input data.
200    fn fill(&mut self, input: &mut &[u8]) -> Result<(), Error> {
201        let remaining = Self::SIZE.checked_sub(self.position).ok_or(InvalidLength)?;
202        let len = cmp::min(input.len(), remaining);
203        self.bytes[self.position..][..len].copy_from_slice(&input[..len]);
204        self.position = self.position.checked_add(len).ok_or(InvalidLength)?;
205        *input = &input[len..];
206        Ok(())
207    }
208
209    /// Take the output buffer, resetting the position to 0.
210    fn take(&mut self) -> [u8; Self::SIZE] {
211        debug_assert!(self.is_full());
212        let result = self.bytes;
213        *self = Default::default();
214        result
215    }
216
217    /// Is the buffer empty?
218    fn is_empty(&self) -> bool {
219        self.position == 0
220    }
221
222    /// Is the buffer full?
223    fn is_full(&self) -> bool {
224        self.position == Self::SIZE
225    }
226}
227
228/// Helper for wrapping Base64 at a given line width.
229#[derive(Debug)]
230struct LineWrapper {
231    /// Number of bytes remaining in the current line.
232    remaining: usize,
233
234    /// Column at which Base64 should be wrapped.
235    width: usize,
236
237    /// Newline characters to use at the end of each line.
238    ending: LineEnding,
239}
240
241impl LineWrapper {
242    /// Create a new linewrapper.
243    fn new(width: usize, ending: LineEnding) -> Result<Self, Error> {
244        if width < MIN_LINE_WIDTH {
245            return Err(InvalidLength);
246        }
247
248        Ok(Self {
249            remaining: width,
250            width,
251            ending,
252        })
253    }
254
255    /// Wrap the number of blocks to encode near/at EOL.
256    fn wrap_blocks(&self, blocks: &mut usize) -> Result<(), Error> {
257        if blocks.checked_mul(4).ok_or(InvalidLength)? >= self.remaining {
258            *blocks = self.remaining / 4;
259        }
260
261        Ok(())
262    }
263
264    /// Insert newlines into the output buffer as needed.
265    fn insert_newlines(&mut self, mut buffer: &mut [u8], len: &mut usize) -> Result<(), Error> {
266        let mut buffer_len = *len;
267
268        if buffer_len <= self.remaining {
269            self.remaining = self
270                .remaining
271                .checked_sub(buffer_len)
272                .ok_or(InvalidLength)?;
273
274            return Ok(());
275        }
276
277        buffer = &mut buffer[self.remaining..];
278        buffer_len = buffer_len
279            .checked_sub(self.remaining)
280            .ok_or(InvalidLength)?;
281
282        // The `wrap_blocks` function should ensure the buffer is no larger than a Base64 block
283        debug_assert!(buffer_len <= 4, "buffer too long: {}", buffer_len);
284
285        // Ensure space in buffer to add newlines
286        let buffer_end = buffer_len
287            .checked_add(self.ending.len())
288            .ok_or(InvalidLength)?;
289
290        if buffer_end >= buffer.len() {
291            return Err(InvalidLength);
292        }
293
294        // Shift the buffer contents to make space for the line ending
295        for i in (0..buffer_len).rev() {
296            buffer[i.checked_add(self.ending.len()).ok_or(InvalidLength)?] = buffer[i];
297        }
298
299        buffer[..self.ending.len()].copy_from_slice(self.ending.as_bytes());
300        *len = (*len).checked_add(self.ending.len()).ok_or(InvalidLength)?;
301        self.remaining = self.width.checked_sub(buffer_len).ok_or(InvalidLength)?;
302
303        Ok(())
304    }
305}
306
#[cfg(test)]
mod tests {
    use crate::{alphabet::Alphabet, test_vectors::*, Base64, Base64Unpadded, Encoder, LineEnding};

    /// Padded Base64 without line wrapping.
    #[test]
    fn encode_padded() {
        encode_test::<Base64>(PADDED_BIN, PADDED_BASE64, None);
    }

    /// Unpadded Base64 without line wrapping.
    #[test]
    fn encode_unpadded() {
        encode_test::<Base64Unpadded>(UNPADDED_BIN, UNPADDED_BASE64, None);
    }

    /// Padded Base64 wrapped at 70 columns.
    #[test]
    fn encode_multiline_padded() {
        encode_test::<Base64>(MULTILINE_PADDED_BIN, MULTILINE_PADDED_BASE64, Some(70));
    }

    /// Unpadded Base64 wrapped at 70 columns.
    #[test]
    fn encode_multiline_unpadded() {
        encode_test::<Base64Unpadded>(MULTILINE_UNPADDED_BIN, MULTILINE_UNPADDED_BASE64, Some(70));
    }

    /// Output which fills a line exactly must not end in a line ending.
    #[test]
    fn no_trailing_newline_when_aligned() {
        let expected = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";
        let input = [0u8; 48];

        let mut buffer = [0u8; 64];
        let mut encoder = Encoder::<Base64>::new_wrapped(&mut buffer, 64, LineEnding::LF).unwrap();
        encoder.encode(&input).unwrap();

        // Ensure no newline character is present in this case
        assert_eq!(expected, encoder.finish().unwrap());
    }

    /// Shared driver: encodes `input` in chunks of every size from 1 up to
    /// (but excluding) `input.len()` and compares the result with `expected`.
    ///
    /// `wrapped` selects the line width (with LF line endings); `None`
    /// disables line wrapping.
    fn encode_test<V: Alphabet>(input: &[u8], expected: &str, wrapped: Option<usize>) {
        let mut buffer = [0u8; 1024];

        for chunk_size in 1..input.len() {
            let constructed = if let Some(line_width) = wrapped {
                Encoder::<V>::new_wrapped(&mut buffer, line_width, LineEnding::LF)
            } else {
                Encoder::<V>::new(&mut buffer)
            };
            let mut encoder = constructed.unwrap();

            input
                .chunks(chunk_size)
                .for_each(|chunk| encoder.encode(chunk).unwrap());

            assert_eq!(expected, encoder.finish().unwrap());
        }
    }
}