1use crate::base::scan::{ConvertSymbols, EntrySymbol, ScannerError};
16use core::fmt;
17use octseq::builder::{
18 EmptyBuilder, FreezeBuilder, FromBuilder, OctetsBuilder, ShortBuf,
19};
20#[cfg(feature = "std")]
21use std::string::String;
22
23pub fn decode<Octets>(s: &str) -> Result<Octets, DecodeError>
30where
31 Octets: FromBuilder,
32 <Octets as FromBuilder>::Builder: OctetsBuilder + EmptyBuilder,
33{
34 let mut decoder = Decoder::<<Octets as FromBuilder>::Builder>::new();
35 for ch in s.chars() {
36 decoder.push(ch)?;
37 }
38 decoder.finalize()
39}
40
41pub fn display<B, W>(bytes: &B, f: &mut W) -> fmt::Result
59where
60 B: AsRef<[u8]> + ?Sized,
61 W: fmt::Write,
62{
63 fn ch(i: u8) -> char {
64 ENCODE_ALPHABET[i as usize]
65 }
66
67 for chunk in bytes.as_ref().chunks(3) {
68 match chunk.len() {
69 1 => {
70 f.write_char(ch(chunk[0] >> 2))?;
71 f.write_char(ch((chunk[0] & 0x03) << 4))?;
72 f.write_char('=')?;
73 f.write_char('=')?;
74 }
75 2 => {
76 f.write_char(ch(chunk[0] >> 2))?;
77 f.write_char(ch((chunk[0] & 0x03) << 4 | chunk[1] >> 4))?;
78 f.write_char(ch((chunk[1] & 0x0F) << 2))?;
79 f.write_char('=')?;
80 }
81 3 => {
82 f.write_char(ch(chunk[0] >> 2))?;
83 f.write_char(ch((chunk[0] & 0x03) << 4 | chunk[1] >> 4))?;
84 f.write_char(ch((chunk[1] & 0x0F) << 2 | chunk[2] >> 6))?;
85 f.write_char(ch(chunk[2] & 0x3F))?;
86 }
87 _ => unreachable!(),
88 }
89 }
90 Ok(())
91}
92
/// Returns the Base64 encoding of `bytes` as a newly allocated string.
///
/// The string is allocated with exactly the length of the encoded output:
/// four characters for every started group of three input bytes.
#[cfg(feature = "std")]
pub fn encode_string<B: AsRef<[u8]> + ?Sized>(bytes: &B) -> String {
    let len = bytes.as_ref().len();
    // Rounds up to whole groups. A byte slice is at most `isize::MAX` long,
    // so `len + 2` cannot overflow.
    let mut res = String::with_capacity((len + 2) / 3 * 4);
    display(bytes, &mut res).expect("writing to a String cannot fail");
    res
}
100
101pub fn encode_display<Octets: AsRef<[u8]>>(
103 octets: &Octets,
104) -> impl fmt::Display + '_ {
105 struct Display<'a>(&'a [u8]);
106
107 impl<'a> fmt::Display for Display<'a> {
108 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
109 display(self.0, f)
110 }
111 }
112
113 Display(octets.as_ref())
114}
115
#[cfg(feature = "serde")]
/// Serde support for Base64-encoded octets.
///
/// The `serialize` and `deserialize` functions have the signatures serde
/// expects of a module named in a `#[serde(with = "...")]` field attribute.
/// Human-readable formats use a Base64 string; all other formats use the
/// octets type's own serialization.
pub mod serde {
    use super::encode_display;
    use core::fmt;
    use octseq::builder::{EmptyBuilder, FromBuilder, OctetsBuilder};
    use octseq::serde::{DeserializeOctets, SerializeOctets};

    /// Serializes `octets` as a Base64 string for human-readable formats
    /// and via [`SerializeOctets`] otherwise.
    pub fn serialize<Octets, S>(
        octets: &Octets,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        Octets: AsRef<[u8]> + SerializeOctets,
        S: serde::Serializer,
    {
        if serializer.is_human_readable() {
            serializer.collect_str(&encode_display(octets))
        } else {
            octets.serialize_octets(serializer)
        }
    }

    /// Deserializes octets from a Base64 string for human-readable formats
    /// and via [`DeserializeOctets`] otherwise.
    ///
    /// # Errors
    ///
    /// Fails with a custom error if a string is not valid Base64, or with
    /// whatever error the deserializer itself reports.
    pub fn deserialize<'de, Octets, D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> Result<Octets, D::Error>
    where
        Octets: FromBuilder + DeserializeOctets<'de>,
        <Octets as FromBuilder>::Builder: EmptyBuilder,
    {
        /// Accepts a Base64 string, or raw bytes that are handed on to the
        /// octets type's own visitor.
        struct Visitor<'de, Octets: DeserializeOctets<'de>>(Octets::Visitor);

        impl<'de, Octets> serde::de::Visitor<'de> for Visitor<'de, Octets>
        where
            Octets: FromBuilder + DeserializeOctets<'de>,
            <Octets as FromBuilder>::Builder: OctetsBuilder + EmptyBuilder,
        {
            type Value = Octets;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a Base64-encoded string")
            }

            fn visit_str<E: serde::de::Error>(
                self,
                v: &str,
            ) -> Result<Self::Value, E> {
                super::decode(v).map_err(E::custom)
            }

            fn visit_borrowed_bytes<E: serde::de::Error>(
                self,
                value: &'de [u8],
            ) -> Result<Self::Value, E> {
                self.0.visit_borrowed_bytes(value)
            }

            #[cfg(feature = "std")]
            fn visit_byte_buf<E: serde::de::Error>(
                self,
                value: std::vec::Vec<u8>,
            ) -> Result<Self::Value, E> {
                self.0.visit_byte_buf(value)
            }
        }

        if deserializer.is_human_readable() {
            deserializer.deserialize_str(Visitor(Octets::visitor()))
        } else {
            Octets::deserialize_with_visitor(
                deserializer,
                Visitor(Octets::visitor()),
            )
        }
    }
}
196
/// An incremental Base64 decoder.
///
/// Characters are fed one at a time via [`push`](Self::push). The decoded
/// octets are collected in a builder, which [`finalize`](Self::finalize)
/// freezes into the resulting octets sequence.
pub struct Decoder<Builder> {
    /// Decoded 6-bit values of the current group of up to four characters.
    ///
    /// A padding character is stored as the marker value `0x80`, which no
    /// real 6-bit value can take.
    buf: [u8; 4],

    /// Index of the next free slot in `buf`.
    ///
    /// The value `0xF0` means a group ending in padding has been processed,
    /// so any further character is trailing input.
    next: usize,

    /// The builder receiving decoded octets, or the error that ended
    /// decoding.
    target: Result<Builder, DecodeError>,
}
220
221impl<Builder: EmptyBuilder> Decoder<Builder> {
222 #[must_use]
224 pub fn new() -> Self {
225 Decoder {
226 buf: [0; 4],
227 next: 0,
228 target: Ok(Builder::empty()),
229 }
230 }
231}
232
233impl<Builder: OctetsBuilder> Decoder<Builder> {
234 pub fn finalize(self) -> Result<Builder::Octets, DecodeError>
236 where
237 Builder: FreezeBuilder,
238 {
239 let (target, next) = (self.target, self.next);
240 target.and_then(|bytes| {
241 if next & 0x0F != 0 {
243 Err(DecodeError::ShortInput)
244 } else {
245 Ok(bytes.freeze())
246 }
247 })
248 }
249
250 pub fn push(&mut self, ch: char) -> Result<(), DecodeError> {
256 if self.next == 0xF0 {
257 self.target = Err(DecodeError::TrailingInput);
258 return Err(DecodeError::TrailingInput);
259 }
260
261 let val = if ch == PAD {
262 if self.next < 2 {
264 return Err(DecodeError::IllegalChar(ch));
265 }
266 0x80 } else {
268 if ch > (127 as char) {
269 return Err(DecodeError::IllegalChar(ch));
270 }
271 let val = DECODE_ALPHABET[ch as usize];
272 if val == 0xFF {
273 return Err(DecodeError::IllegalChar(ch));
274 }
275 val
276 };
277 self.buf[self.next] = val;
278 self.next += 1;
279
280 if self.next == 4 {
281 let target = self.target.as_mut().unwrap(); target
283 .append_slice(&[self.buf[0] << 2 | self.buf[1] >> 4])
284 .map_err(Into::into)?;
285 if self.buf[2] != 0x80 {
286 target
287 .append_slice(&[self.buf[1] << 4 | self.buf[2] >> 2])
288 .map_err(Into::into)?;
289 }
290 if self.buf[3] != 0x80 {
291 if self.buf[2] == 0x80 {
292 return Err(DecodeError::TrailingInput);
293 }
294 target
295 .append_slice(&[(self.buf[2] << 6) | self.buf[3]])
296 .map_err(Into::into)?;
297 self.next = 0
298 } else {
299 self.next = 0xF0
300 }
301 }
302
303 Ok(())
304 }
305}
306
307#[cfg(feature = "bytes")]
310impl<Builder: EmptyBuilder> Default for Decoder<Builder> {
311 fn default() -> Self {
312 Self::new()
313 }
314}
315
/// Converts scanner symbols holding Base64 data into octets.
///
/// Characters are collected in groups of four; each complete group yields
/// up to three decoded octets.
#[derive(Clone, Debug, Default)]
pub struct SymbolConverter {
    /// Decoded 6-bit values of the current group; padding is stored as
    /// `PAD_MARKER`.
    input: [u8; 4],

    /// Index of the next free slot in `input`, or `EOF_MARKER` once a group
    /// ending in padding has been processed.
    next: usize,

    /// Octets decoded from the most recently completed group.
    output: [u8; 3],
}
336
337impl SymbolConverter {
338 #[must_use]
340 pub fn new() -> Self {
341 Default::default()
342 }
343
344 fn process_char<Error: ScannerError>(
345 &mut self,
346 ch: char,
347 ) -> Result<Option<&[u8]>, Error> {
348 if self.next == EOF_MARKER {
349 return Err(Error::custom("trailing Base 64 data"));
350 }
351
352 let val = if ch == PAD {
353 if self.next < 2 {
355 return Err(Error::custom("illegal Base 64 data"));
356 }
357 PAD_MARKER } else {
359 if ch > (127 as char) {
360 return Err(Error::custom("illegal Base 64 data"));
361 }
362 let val = DECODE_ALPHABET[ch as usize];
363 if val == 0xFF {
364 return Err(Error::custom("illegal Base 64 data"));
365 }
366 val
367 };
368 self.input[self.next] = val;
369 self.next += 1;
370
371 if self.next == 4 {
372 self.output[0] = self.input[0] << 2 | self.input[1] >> 4;
373
374 if self.input[2] == PAD_MARKER {
375 if self.input[3] == PAD_MARKER {
378 self.next = EOF_MARKER;
379 Ok(Some(&self.output[..1]))
380 } else {
381 Err(Error::custom("illegal Base 64 data"))
382 }
383 } else {
384 self.output[1] = self.input[1] << 4 | self.input[2] >> 2;
385
386 if self.input[3] == PAD_MARKER {
387 self.next = EOF_MARKER;
389 Ok(Some(&self.output[..2]))
390 } else {
391 self.output[2] = (self.input[2] << 6) | self.input[3];
392 self.next = 0;
393 Ok(Some(&self.output))
394 }
395 }
396 } else {
397 Ok(None)
398 }
399 }
400}
401
402impl<Sym, Error> ConvertSymbols<Sym, Error> for SymbolConverter
403where
404 Sym: Into<EntrySymbol>,
405 Error: ScannerError,
406{
407 fn process_symbol(
408 &mut self,
409 symbol: Sym,
410 ) -> Result<Option<&[u8]>, Error> {
411 match symbol.into() {
412 EntrySymbol::Symbol(symbol) => self.process_char(
413 symbol
414 .into_char()
415 .map_err(|_| Error::custom("illegal Base 64 data"))?,
416 ),
417 EntrySymbol::EndOfToken => Ok(None),
418 }
419 }
420
421 fn process_tail(&mut self) -> Result<Option<&[u8]>, Error> {
422 if self.next & 0x0F != 0 {
424 Err(Error::custom("incomplete Base 64 data"))
425 } else {
426 Ok(None)
427 }
428 }
429}
430
/// An error that happened while decoding Base64 data.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// A character outside the Base64 alphabet, or padding where it is not
    /// allowed.
    IllegalChar(char),

    /// Data followed padding.
    TrailingInput,

    /// The input ended in the middle of a group of four characters.
    ShortInput,

    /// The octets builder ran out of space.
    ShortBuf,
}
450
451impl From<ShortBuf> for DecodeError {
452 fn from(_: ShortBuf) -> Self {
453 DecodeError::ShortBuf
454 }
455}
456
457impl fmt::Display for DecodeError {
460 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
461 match *self {
462 DecodeError::TrailingInput => f.write_str("trailing input"),
463 DecodeError::IllegalChar(ch) => {
464 write!(f, "illegal character '{}'", ch)
465 }
466 DecodeError::ShortInput => f.write_str("incomplete input"),
467 DecodeError::ShortBuf => ShortBuf.fmt(f),
468 }
469 }
470}
471
// Relies on the trait's default methods, so `source` returns `None`.
#[cfg(feature = "std")]
impl std::error::Error for DecodeError {}
474
/// Maps an ASCII character code to its 6-bit Base64 value.
///
/// Characters outside the standard alphabet, including the padding
/// character `=`, map to `0xFF`. The decoders check for padding before
/// consulting this table.
const DECODE_ALPHABET: [u8; 128] = [
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x3E, 0xFF, 0xFF, 0xFF, 0x3F, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, ];
500
/// The standard Base64 alphabet, indexed by 6-bit value.
const ENCODE_ALPHABET: [char; 64] = [
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/', ];
511
/// The Base64 padding character.
const PAD: char = '=';

/// Value stored in a group buffer in place of a padding character.
///
/// Decoded 6-bit values never exceed `0x3F`, so the marker cannot be
/// confused with data.
const PAD_MARKER: u8 = 0x80;

/// Value of `SymbolConverter::next` once a group ending in padding has been
/// processed and no further input is accepted.
const EOF_MARKER: usize = 0xF0;
520
#[cfg(test)]
mod test {
    use super::*;

    // Only used by tests gated on the "std" feature.
    #[allow(dead_code)]
    const HAPPY_CASES: &[(&[u8], &str)] = &[
        (b"", ""),
        (b"f", "Zg=="),
        (b"fo", "Zm8="),
        (b"foo", "Zm9v"),
        (b"foob", "Zm9vYg=="),
        (b"fooba", "Zm9vYmE="),
        (b"foobar", "Zm9vYmFy"),
    ];

    #[cfg(feature = "std")]
    #[test]
    fn decode_str() {
        fn decode(s: &str) -> Result<std::vec::Vec<u8>, DecodeError> {
            super::decode(s)
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&decode(text).unwrap(), bin, "decode {}", text)
        }

        assert_eq!(decode("FPucA").unwrap_err(), DecodeError::ShortInput);
        assert_eq!(
            decode("FPucA=").unwrap_err(),
            DecodeError::IllegalChar('=')
        );
        assert_eq!(decode("FPucAw=").unwrap_err(), DecodeError::ShortInput);
        assert_eq!(
            decode("FPucAw=a").unwrap_err(),
            DecodeError::TrailingInput
        );
        assert_eq!(
            decode("FPucAw==a").unwrap_err(),
            DecodeError::TrailingInput
        );
        assert_eq!(
            decode("Zm9v!").unwrap_err(),
            DecodeError::IllegalChar('!')
        );
        assert_eq!(
            decode("Zm9v\u{e9}").unwrap_err(),
            DecodeError::IllegalChar('\u{e9}')
        );
        assert_eq!(
            decode("Zg==Zg==").unwrap_err(),
            DecodeError::TrailingInput
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn symbol_converter() {
        use crate::base::scan::Symbols;
        use std::vec::Vec;

        fn decode(s: &str) -> Result<Vec<u8>, std::io::Error> {
            let mut convert = SymbolConverter::new();
            let convert: &mut dyn ConvertSymbols<_, std::io::Error> =
                &mut convert;
            let mut res = Vec::new();
            for sym in Symbols::new(s.chars()) {
                if let Some(octs) = convert.process_symbol(sym)? {
                    res.extend_from_slice(octs);
                }
            }
            if let Some(octs) = convert.process_tail()? {
                res.extend_from_slice(octs);
            }
            Ok(res)
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&decode(text).unwrap(), bin, "convert {}", text)
        }

        assert!(decode("Zg").is_err(), "incomplete group");
        assert!(decode("Zg=a").is_err(), "data after padding");
        assert!(decode("Zg==Zg==").is_err(), "input after final group");
        assert!(decode("Z!==").is_err(), "illegal character");
    }

    #[cfg(feature = "std")]
    #[test]
    fn display_bytes() {
        fn fmt(s: &[u8]) -> String {
            let mut out = String::new();
            display(s, &mut out).unwrap();
            out
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&fmt(bin), text, "fmt {}", text);
        }
    }

    #[cfg(feature = "std")]
    #[test]
    fn encode_string_and_display() {
        for (bin, text) in HAPPY_CASES {
            assert_eq!(encode_string(bin), *text, "encode_string {}", text);
            assert_eq!(
                std::format!("{}", encode_display(bin)),
                *text,
                "encode_display {}",
                text
            );
        }
    }
}
607}