1use crate::{
4 encoding,
5 line_ending::{CHAR_CR, CHAR_LF},
6 Encoding,
7 Error::{self, InvalidLength},
8 MIN_LINE_WIDTH,
9};
10use core::{cmp, marker::PhantomData};
11
12#[cfg(feature = "alloc")]
13use {alloc::vec::Vec, core::iter};
14
15#[cfg(feature = "std")]
16use std::io;
17
18#[cfg(doc)]
19use crate::{Base64, Base64Unpadded};
20
21#[derive(Clone)]
26pub struct Decoder<'i, E: Encoding> {
27 line: Line<'i>,
29
30 line_reader: LineReader<'i>,
32
33 remaining_len: usize,
35
36 block_buffer: BlockBuffer,
38
39 encoding: PhantomData<E>,
41}
42
43impl<'i, E: Encoding> Decoder<'i, E> {
44 pub fn new(input: &'i [u8]) -> Result<Self, Error> {
51 let line_reader = LineReader::new_unwrapped(input)?;
52 let remaining_len = line_reader.decoded_len::<E>()?;
53
54 Ok(Self {
55 line: Line::default(),
56 line_reader,
57 remaining_len,
58 block_buffer: BlockBuffer::default(),
59 encoding: PhantomData,
60 })
61 }
62
63 pub fn new_wrapped(input: &'i [u8], line_width: usize) -> Result<Self, Error> {
88 let line_reader = LineReader::new_wrapped(input, line_width)?;
89 let remaining_len = line_reader.decoded_len::<E>()?;
90
91 Ok(Self {
92 line: Line::default(),
93 line_reader,
94 remaining_len,
95 block_buffer: BlockBuffer::default(),
96 encoding: PhantomData,
97 })
98 }
99
100 pub fn decode<'o>(&mut self, out: &'o mut [u8]) -> Result<&'o [u8], Error> {
108 if self.is_finished() {
109 return Err(InvalidLength);
110 }
111
112 let mut out_pos = 0;
113
114 while out_pos < out.len() {
115 if !self.block_buffer.is_empty() {
117 let out_rem = out.len().checked_sub(out_pos).ok_or(InvalidLength)?;
118 let bytes = self.block_buffer.take(out_rem)?;
119 out[out_pos..][..bytes.len()].copy_from_slice(bytes);
120 out_pos = out_pos.checked_add(bytes.len()).ok_or(InvalidLength)?;
121 }
122
123 if self.line.is_empty() && !self.line_reader.is_empty() {
125 self.advance_line()?;
126 }
127
128 let in_blocks = self.line.len() / 4;
130 let out_rem = out.len().checked_sub(out_pos).ok_or(InvalidLength)?;
131 let out_blocks = out_rem / 3;
132 let blocks = cmp::min(in_blocks, out_blocks);
133 let in_aligned = self.line.take(blocks.checked_mul(4).ok_or(InvalidLength)?);
134
135 if !in_aligned.is_empty() {
136 let out_buf = &mut out[out_pos..][..blocks.checked_mul(3).ok_or(InvalidLength)?];
137 let decoded_len = self.perform_decode(in_aligned, out_buf)?.len();
138 out_pos = out_pos.checked_add(decoded_len).ok_or(InvalidLength)?;
139 }
140
141 if out_pos < out.len() {
142 if self.is_finished() {
143 return Err(InvalidLength);
146 } else {
147 self.fill_block_buffer()?;
152 }
153 }
154 }
155
156 self.remaining_len = self
157 .remaining_len
158 .checked_sub(out.len())
159 .ok_or(InvalidLength)?;
160
161 Ok(out)
162 }
163
164 #[cfg(feature = "alloc")]
169 #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
170 pub fn decode_to_end<'o>(&mut self, buf: &'o mut Vec<u8>) -> Result<&'o [u8], Error> {
171 let start_len = buf.len();
172 let remaining_len = self.remaining_len();
173 let total_len = start_len.checked_add(remaining_len).ok_or(InvalidLength)?;
174
175 if total_len > buf.capacity() {
176 buf.reserve(total_len.checked_sub(buf.capacity()).ok_or(InvalidLength)?);
177 }
178
179 buf.extend(iter::repeat(0).take(remaining_len));
181 self.decode(&mut buf[start_len..])?;
182 Ok(&buf[start_len..])
183 }
184
185 pub fn remaining_len(&self) -> usize {
189 self.remaining_len
190 }
191
192 pub fn is_finished(&self) -> bool {
194 self.line.is_empty() && self.line_reader.is_empty() && self.block_buffer.is_empty()
195 }
196
197 fn fill_block_buffer(&mut self) -> Result<(), Error> {
199 let mut buf = [0u8; BlockBuffer::SIZE];
200
201 let decoded = if self.line.len() < 4 && !self.line_reader.is_empty() {
202 let mut tmp = [0u8; 4];
204
205 let line_end = self.line.take(4);
207 tmp[..line_end.len()].copy_from_slice(line_end);
208
209 self.advance_line()?;
211 let len = 4usize.checked_sub(line_end.len()).ok_or(InvalidLength)?;
212 let line_begin = self.line.take(len);
213 tmp[line_end.len()..][..line_begin.len()].copy_from_slice(line_begin);
214
215 let tmp_len = line_begin
216 .len()
217 .checked_add(line_end.len())
218 .ok_or(InvalidLength)?;
219
220 self.perform_decode(&tmp[..tmp_len], &mut buf)
221 } else {
222 let block = self.line.take(4);
223 self.perform_decode(block, &mut buf)
224 }?;
225
226 self.block_buffer.fill(decoded)
227 }
228
229 fn advance_line(&mut self) -> Result<(), Error> {
231 debug_assert!(self.line.is_empty(), "expected line buffer to be empty");
232
233 if let Some(line) = self.line_reader.next().transpose()? {
234 self.line = line;
235 Ok(())
236 } else {
237 Err(InvalidLength)
238 }
239 }
240
241 fn perform_decode<'o>(&self, src: &[u8], dst: &'o mut [u8]) -> Result<&'o [u8], Error> {
243 if self.is_finished() {
244 E::decode(src, dst)
245 } else {
246 E::Unpadded::decode(src, dst)
247 }
248 }
249}
250
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<'i, E: Encoding> io::Read for Decoder<'i, E> {
    /// Read up to `buf.len()` decoded bytes.
    ///
    /// Returns `Ok(0)` once the decoder is finished, as the `Read` contract
    /// requires for end of input.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.is_finished() {
            return Ok(0);
        }

        // Never request more than remains, so short reads near the end succeed.
        let len = buf.len().min(self.remaining_len());
        let slice = &mut buf[..len];
        self.decode(slice)?;
        Ok(len)
    }

    /// Append all remaining decoded bytes to `buf`, returning how many were
    /// appended.
    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
        Ok(self.decode_to_end(buf)?.len())
    }

    /// Fill `buf` completely with decoded bytes.
    ///
    /// Fails with `ErrorKind::UnexpectedEof` (without consuming any input) if
    /// fewer than `buf.len()` decoded bytes remain. An empty `buf` always
    /// succeeds.
    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        if buf.len() > self.remaining_len() {
            return Err(io::ErrorKind::UnexpectedEof.into());
        }

        if buf.is_empty() {
            return Ok(());
        }

        self.decode(buf)?;
        Ok(())
    }
}
273
274#[derive(Clone, Default, Debug)]
279struct BlockBuffer {
280 decoded: [u8; Self::SIZE],
282
283 length: usize,
285
286 position: usize,
288}
289
290impl BlockBuffer {
291 const SIZE: usize = 3;
293
294 fn fill(&mut self, decoded_input: &[u8]) -> Result<(), Error> {
296 debug_assert!(self.is_empty());
297
298 if decoded_input.len() > Self::SIZE {
299 return Err(InvalidLength);
300 }
301
302 self.position = 0;
303 self.length = decoded_input.len();
304 self.decoded[..decoded_input.len()].copy_from_slice(decoded_input);
305 Ok(())
306 }
307
308 fn take(&mut self, mut nbytes: usize) -> Result<&[u8], Error> {
313 debug_assert!(self.position <= self.length);
314 let start_pos = self.position;
315 let remaining_len = self.length.checked_sub(start_pos).ok_or(InvalidLength)?;
316
317 if nbytes > remaining_len {
318 nbytes = remaining_len;
319 }
320
321 self.position = self.position.checked_add(nbytes).ok_or(InvalidLength)?;
322 Ok(&self.decoded[start_pos..][..nbytes])
323 }
324
325 fn is_empty(&self) -> bool {
327 self.position == self.length
328 }
329}
330
/// A single line of Base64 input, consumed from the front as it is decoded.
#[derive(Clone, Debug)]
pub struct Line<'i> {
    /// Bytes of the line that have not been consumed yet.
    remaining: &'i [u8],
}
337
338impl<'i> Default for Line<'i> {
339 fn default() -> Self {
340 Self::new(&[])
341 }
342}
343
344impl<'i> Line<'i> {
345 fn new(bytes: &'i [u8]) -> Self {
347 Self { remaining: bytes }
348 }
349
350 fn take(&mut self, nbytes: usize) -> &'i [u8] {
352 let (bytes, rest) = if nbytes < self.remaining.len() {
353 self.remaining.split_at(nbytes)
354 } else {
355 (self.remaining, [].as_ref())
356 };
357
358 self.remaining = rest;
359 bytes
360 }
361
362 fn slice_tail(&self, nbytes: usize) -> Result<&'i [u8], Error> {
364 let offset = self.len().checked_sub(nbytes).ok_or(InvalidLength)?;
365 self.remaining.get(offset..).ok_or(InvalidLength)
366 }
367
368 fn len(&self) -> usize {
370 self.remaining.len()
371 }
372
373 fn is_empty(&self) -> bool {
375 self.len() == 0
376 }
377
378 fn trim_end(&self) -> Self {
380 Line::new(match self.remaining {
381 [line @ .., CHAR_CR, CHAR_LF] => line,
382 [line @ .., CHAR_CR] => line,
383 [line @ .., CHAR_LF] => line,
384 line => line,
385 })
386 }
387}
388
/// Splits Base64 input into lines, either as one unwrapped line or as lines of
/// a fixed width separated by CRLF, CR, or LF.
#[derive(Clone)]
struct LineReader<'i> {
    /// Input that has not yet been split off into a line.
    remaining: &'i [u8],

    /// Required line width, or `None` when the input is a single unwrapped line.
    line_width: Option<usize>,
}
398
399impl<'i> LineReader<'i> {
400 fn new_unwrapped(bytes: &'i [u8]) -> Result<Self, Error> {
402 if bytes.is_empty() {
403 Err(InvalidLength)
404 } else {
405 Ok(Self {
406 remaining: bytes,
407 line_width: None,
408 })
409 }
410 }
411
412 fn new_wrapped(bytes: &'i [u8], line_width: usize) -> Result<Self, Error> {
414 if line_width < MIN_LINE_WIDTH {
415 return Err(InvalidLength);
416 }
417
418 let mut reader = Self::new_unwrapped(bytes)?;
419 reader.line_width = Some(line_width);
420 Ok(reader)
421 }
422
423 fn is_empty(&self) -> bool {
425 self.remaining.is_empty()
426 }
427
428 fn decoded_len<E: Encoding>(&self) -> Result<usize, Error> {
430 let mut buffer = [0u8; 4];
431 let mut lines = self.clone();
432 let mut line = match lines.next().transpose()? {
433 Some(l) => l,
434 None => return Ok(0),
435 };
436 let mut base64_len = 0usize;
437
438 loop {
439 base64_len = base64_len.checked_add(line.len()).ok_or(InvalidLength)?;
440
441 match lines.next().transpose()? {
442 Some(l) => {
443 buffer.copy_from_slice(line.slice_tail(4)?);
446
447 line = l
448 }
449
450 None => {
455 let base64_last_block_len = match base64_len % 4 {
457 0 => 4,
458 n => n,
459 };
460
461 let decoded_len = encoding::decoded_len(
463 base64_len
464 .checked_sub(base64_last_block_len)
465 .ok_or(InvalidLength)?,
466 );
467
468 let mut out = [0u8; 3];
470 let last_block_len = if line.len() < base64_last_block_len {
471 let buffered_part_len = base64_last_block_len
472 .checked_sub(line.len())
473 .ok_or(InvalidLength)?;
474
475 let offset = 4usize.checked_sub(buffered_part_len).ok_or(InvalidLength)?;
476
477 for i in 0..buffered_part_len {
478 buffer[i] = buffer[offset.checked_add(i).ok_or(InvalidLength)?];
479 }
480
481 buffer[buffered_part_len..][..line.len()].copy_from_slice(line.remaining);
482 let buffer_len = buffered_part_len
483 .checked_add(line.len())
484 .ok_or(InvalidLength)?;
485
486 E::decode(&buffer[..buffer_len], &mut out)?.len()
487 } else {
488 let last_block = line.slice_tail(base64_last_block_len)?;
489 E::decode(last_block, &mut out)?.len()
490 };
491
492 return decoded_len.checked_add(last_block_len).ok_or(InvalidLength);
493 }
494 }
495 }
496 }
497}
498
499impl<'i> Iterator for LineReader<'i> {
500 type Item = Result<Line<'i>, Error>;
501
502 fn next(&mut self) -> Option<Result<Line<'i>, Error>> {
503 if let Some(line_width) = self.line_width {
504 let rest = match self.remaining.get(line_width..) {
505 None | Some([]) => {
506 if self.remaining.is_empty() {
507 return None;
508 } else {
509 let line = Line::new(self.remaining).trim_end();
510 self.remaining = &[];
511 return Some(Ok(line));
512 }
513 }
514 Some([CHAR_CR, CHAR_LF, rest @ ..]) => rest,
515 Some([CHAR_CR, rest @ ..]) => rest,
516 Some([CHAR_LF, rest @ ..]) => rest,
517 _ => {
518 return Some(Err(Error::InvalidEncoding));
520 }
521 };
522
523 let line = Line::new(&self.remaining[..line_width]);
524 self.remaining = rest;
525 Some(Ok(line))
526 } else if !self.remaining.is_empty() {
527 let line = Line::new(self.remaining).trim_end();
528 self.remaining = b"";
529
530 if line.is_empty() {
531 None
532 } else {
533 Some(Ok(line))
534 }
535 } else {
536 None
537 }
538 }
539}
540
#[cfg(test)]
mod tests {
    use crate::{alphabet::Alphabet, test_vectors::*, Base64, Base64Unpadded, Decoder};

    #[cfg(feature = "std")]
    use {alloc::vec::Vec, std::io::Read};

    #[test]
    fn decode_padded() {
        decode_test(PADDED_BIN, || {
            Decoder::<Base64>::new(PADDED_BASE64.as_bytes()).unwrap()
        });
    }

    #[test]
    fn decode_unpadded() {
        decode_test(UNPADDED_BIN, || {
            Decoder::<Base64Unpadded>::new(UNPADDED_BASE64.as_bytes()).unwrap()
        });
    }

    #[test]
    fn decode_multiline_padded() {
        decode_test(MULTILINE_PADDED_BIN, || {
            Decoder::<Base64>::new_wrapped(MULTILINE_PADDED_BASE64.as_bytes(), 70).unwrap()
        });
    }

    #[test]
    fn decode_multiline_unpadded() {
        decode_test(MULTILINE_UNPADDED_BIN, || {
            Decoder::<Base64Unpadded>::new_wrapped(MULTILINE_UNPADDED_BASE64.as_bytes(), 70)
                .unwrap()
        });
    }

    /// `read_to_end` on wrapped, padded input yields the full expected output.
    #[cfg(feature = "std")]
    #[test]
    fn read_multiline_padded() {
        let mut decoder =
            Decoder::<Base64>::new_wrapped(MULTILINE_PADDED_BASE64.as_bytes(), 70).unwrap();

        let mut decoded = Vec::new();
        let count = decoder.read_to_end(&mut decoded).unwrap();

        assert_eq!(count, MULTILINE_PADDED_BIN.len());
        assert_eq!(decoded.as_slice(), MULTILINE_PADDED_BIN);
    }

    /// Decode `expected` in chunks of every size from 1 up to (but excluding)
    /// `expected.len()`, using a fresh decoder from `make_decoder` for each
    /// size, and check the output and `remaining_len` after every chunk.
    fn decode_test<'a, F, V>(expected: &[u8], make_decoder: F)
    where
        F: Fn() -> Decoder<'a, V>,
        V: Alphabet,
    {
        let mut scratch = [0u8; 1024];

        for chunk_size in 1..expected.len() {
            let mut decoder = make_decoder();
            let mut expected_remaining = decoder.remaining_len();

            for chunk in expected.chunks(chunk_size) {
                assert!(!decoder.is_finished());

                let decoded = decoder.decode(&mut scratch[..chunk.len()]).unwrap();
                assert_eq!(chunk, decoded);

                expected_remaining -= decoded.len();
                assert_eq!(expected_remaining, decoder.remaining_len());
            }

            assert!(decoder.is_finished());
            assert_eq!(decoder.remaining_len(), 0);
        }
    }
}