1use crate::{
4 Encoding,
5 Error::{self, InvalidLength},
6 LineEnding, MIN_LINE_WIDTH,
7};
8use core::{cmp, marker::PhantomData, str};
9
10#[cfg(feature = "std")]
11use std::io;
12
13#[cfg(doc)]
14use crate::{Base64, Base64Unpadded};
15
16pub struct Encoder<'o, E: Encoding> {
21 output: &'o mut [u8],
23
24 position: usize,
26
27 block_buffer: BlockBuffer,
29
30 line_wrapper: Option<LineWrapper>,
33
34 encoding: PhantomData<E>,
36}
37
38impl<'o, E: Encoding> Encoder<'o, E> {
39 pub fn new(output: &'o mut [u8]) -> Result<Self, Error> {
43 if output.is_empty() {
44 return Err(InvalidLength);
45 }
46
47 Ok(Self {
48 output,
49 position: 0,
50 block_buffer: BlockBuffer::default(),
51 line_wrapper: None,
52 encoding: PhantomData,
53 })
54 }
55
56 pub fn new_wrapped(
65 output: &'o mut [u8],
66 width: usize,
67 ending: LineEnding,
68 ) -> Result<Self, Error> {
69 let mut encoder = Self::new(output)?;
70 encoder.line_wrapper = Some(LineWrapper::new(width, ending)?);
71 Ok(encoder)
72 }
73
74 pub fn encode(&mut self, mut input: &[u8]) -> Result<(), Error> {
80 if !self.block_buffer.is_empty() {
82 self.process_buffer(&mut input)?;
83 }
84
85 while !input.is_empty() {
86 let in_blocks = input.len() / 3;
88 let out_blocks = self.remaining().len() / 4;
89 let mut blocks = cmp::min(in_blocks, out_blocks);
90
91 if let Some(line_wrapper) = &self.line_wrapper {
93 line_wrapper.wrap_blocks(&mut blocks)?;
94 }
95
96 if blocks > 0 {
97 let len = blocks.checked_mul(3).ok_or(InvalidLength)?;
98 let (in_aligned, in_rem) = input.split_at(len);
99 input = in_rem;
100 self.perform_encode(in_aligned)?;
101 }
102
103 if !input.is_empty() {
105 self.process_buffer(&mut input)?;
106 }
107 }
108
109 Ok(())
110 }
111
112 pub fn position(&self) -> usize {
115 self.position
116 }
117
118 pub fn finish(self) -> Result<&'o str, Error> {
120 self.finish_with_remaining().map(|(base64, _)| base64)
121 }
122
123 pub fn finish_with_remaining(mut self) -> Result<(&'o str, &'o mut [u8]), Error> {
126 if !self.block_buffer.is_empty() {
127 let buffer_len = self.block_buffer.position;
128 let block = self.block_buffer.bytes;
129 self.perform_encode(&block[..buffer_len])?;
130 }
131
132 let (base64, remaining) = self.output.split_at_mut(self.position);
133 Ok((str::from_utf8(base64)?, remaining))
134 }
135
136 fn remaining(&mut self) -> &mut [u8] {
138 &mut self.output[self.position..]
139 }
140
141 fn process_buffer(&mut self, input: &mut &[u8]) -> Result<(), Error> {
144 self.block_buffer.fill(input)?;
145
146 if self.block_buffer.is_full() {
147 let block = self.block_buffer.take();
148 self.perform_encode(&block)?;
149 }
150
151 Ok(())
152 }
153
154 fn perform_encode(&mut self, input: &[u8]) -> Result<usize, Error> {
156 let mut len = E::encode(input, self.remaining())?.as_bytes().len();
157
158 if let Some(line_wrapper) = &mut self.line_wrapper {
160 line_wrapper.insert_newlines(&mut self.output[self.position..], &mut len)?;
161 }
162
163 self.position = self.position.checked_add(len).ok_or(InvalidLength)?;
164 Ok(len)
165 }
166}
167
/// Forwards writes to [`Encoder::encode`]; encoded text lands directly in the
/// encoder's output slice.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<'o, E: Encoding> io::Write for Encoder<'o, E> {
    /// Encode all of `buf`, reporting the whole buffer as consumed on success.
    ///
    /// Bytes that do not complete a 3-byte block are buffered inside the
    /// encoder rather than written to the output immediately.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.encode(buf)
            .map(|()| buf.len())
            .map_err(io::Error::from)
    }

    /// No-op: this does not encode an incomplete buffered block; that only
    /// happens when the encoder is finished.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
181
182#[derive(Clone, Default, Debug)]
186struct BlockBuffer {
187 bytes: [u8; Self::SIZE],
189
190 position: usize,
192}
193
194impl BlockBuffer {
195 const SIZE: usize = 3;
198
199 fn fill(&mut self, input: &mut &[u8]) -> Result<(), Error> {
201 let remaining = Self::SIZE.checked_sub(self.position).ok_or(InvalidLength)?;
202 let len = cmp::min(input.len(), remaining);
203 self.bytes[self.position..][..len].copy_from_slice(&input[..len]);
204 self.position = self.position.checked_add(len).ok_or(InvalidLength)?;
205 *input = &input[len..];
206 Ok(())
207 }
208
209 fn take(&mut self) -> [u8; Self::SIZE] {
211 debug_assert!(self.is_full());
212 let result = self.bytes;
213 *self = Default::default();
214 result
215 }
216
217 fn is_empty(&self) -> bool {
219 self.position == 0
220 }
221
222 fn is_full(&self) -> bool {
224 self.position == Self::SIZE
225 }
226}
227
228#[derive(Debug)]
230struct LineWrapper {
231 remaining: usize,
233
234 width: usize,
236
237 ending: LineEnding,
239}
240
241impl LineWrapper {
242 fn new(width: usize, ending: LineEnding) -> Result<Self, Error> {
244 if width < MIN_LINE_WIDTH {
245 return Err(InvalidLength);
246 }
247
248 Ok(Self {
249 remaining: width,
250 width,
251 ending,
252 })
253 }
254
255 fn wrap_blocks(&self, blocks: &mut usize) -> Result<(), Error> {
257 if blocks.checked_mul(4).ok_or(InvalidLength)? >= self.remaining {
258 *blocks = self.remaining / 4;
259 }
260
261 Ok(())
262 }
263
264 fn insert_newlines(&mut self, mut buffer: &mut [u8], len: &mut usize) -> Result<(), Error> {
266 let mut buffer_len = *len;
267
268 if buffer_len <= self.remaining {
269 self.remaining = self
270 .remaining
271 .checked_sub(buffer_len)
272 .ok_or(InvalidLength)?;
273
274 return Ok(());
275 }
276
277 buffer = &mut buffer[self.remaining..];
278 buffer_len = buffer_len
279 .checked_sub(self.remaining)
280 .ok_or(InvalidLength)?;
281
282 debug_assert!(buffer_len <= 4, "buffer too long: {}", buffer_len);
284
285 let buffer_end = buffer_len
287 .checked_add(self.ending.len())
288 .ok_or(InvalidLength)?;
289
290 if buffer_end >= buffer.len() {
291 return Err(InvalidLength);
292 }
293
294 for i in (0..buffer_len).rev() {
296 buffer[i.checked_add(self.ending.len()).ok_or(InvalidLength)?] = buffer[i];
297 }
298
299 buffer[..self.ending.len()].copy_from_slice(self.ending.as_bytes());
300 *len = (*len).checked_add(self.ending.len()).ok_or(InvalidLength)?;
301 self.remaining = self.width.checked_sub(buffer_len).ok_or(InvalidLength)?;
302
303 Ok(())
304 }
305}
306
#[cfg(test)]
mod tests {
    use crate::{alphabet::Alphabet, test_vectors::*, Base64, Base64Unpadded, Encoder, LineEnding};

    #[test]
    fn encode_padded() {
        encode_test::<Base64>(PADDED_BIN, PADDED_BASE64, None);
    }

    #[test]
    fn encode_unpadded() {
        encode_test::<Base64Unpadded>(UNPADDED_BIN, UNPADDED_BASE64, None);
    }

    #[test]
    fn encode_multiline_padded() {
        encode_test::<Base64>(MULTILINE_PADDED_BIN, MULTILINE_PADDED_BASE64, Some(70));
    }

    #[test]
    fn encode_multiline_unpadded() {
        encode_test::<Base64Unpadded>(MULTILINE_UNPADDED_BIN, MULTILINE_UNPADDED_BASE64, Some(70));
    }

    #[test]
    fn no_trailing_newline_when_aligned() {
        let mut buffer = [0u8; 64];
        let mut encoder = Encoder::<Base64>::new_wrapped(&mut buffer, 64, LineEnding::LF).unwrap();
        encoder.encode(&[0u8; 48]).unwrap();

        // 48 input bytes encode to exactly one full 64-byte line, which must
        // not be followed by a line ending.
        assert_eq!(
            "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
            encoder.finish().unwrap()
        );
    }

    /// Encode `input` in chunks of every size from 1 byte up to the whole
    /// input in a single call, asserting that each run yields `expected`.
    ///
    /// `wrapped` enables line wrapping at the given width (LF endings);
    /// `None` disables wrapping.
    fn encode_test<V: Alphabet>(input: &[u8], expected: &str, wrapped: Option<usize>) {
        let mut buffer = [0u8; 1024];

        // Inclusive bound so a single `encode` call with the whole input is
        // covered as well.
        for chunk_size in 1..=input.len() {
            let mut encoder = match wrapped {
                Some(line_width) => {
                    Encoder::<V>::new_wrapped(&mut buffer, line_width, LineEnding::LF)
                }
                None => Encoder::<V>::new(&mut buffer),
            }
            .unwrap();

            for chunk in input.chunks(chunk_size) {
                encoder.encode(chunk).unwrap();
            }

            assert_eq!(expected, encoder.finish().unwrap());
        }
    }
}