1use crate::base::scan::{ConvertSymbols, EntrySymbol, ScannerError};
16use core::fmt;
17use octseq::builder::{
18 EmptyBuilder, FreezeBuilder, FromBuilder, OctetsBuilder, ShortBuf,
19};
20#[cfg(feature = "std")]
21use std::string::String;
22
23pub fn decode<Octets>(s: &str) -> Result<Octets, DecodeError>
30where
31 Octets: FromBuilder,
32 <Octets as FromBuilder>::Builder: OctetsBuilder + EmptyBuilder,
33{
34 let mut decoder = Decoder::<<Octets as FromBuilder>::Builder>::new();
35 for ch in s.chars() {
36 decoder.push(ch)?;
37 }
38 decoder.finalize()
39}
40
41pub fn display<B, W>(bytes: &B, f: &mut W) -> fmt::Result
59where
60 B: AsRef<[u8]> + ?Sized,
61 W: fmt::Write,
62{
63 fn ch(i: u8) -> char {
64 ENCODE_ALPHABET[i as usize]
65 }
66
67 for chunk in bytes.as_ref().chunks(3) {
68 match chunk.len() {
69 1 => {
70 f.write_char(ch(chunk[0] >> 2))?;
71 f.write_char(ch((chunk[0] & 0x03) << 4))?;
72 f.write_char('=')?;
73 f.write_char('=')?;
74 }
75 2 => {
76 f.write_char(ch(chunk[0] >> 2))?;
77 f.write_char(ch(((chunk[0] & 0x03) << 4) | (chunk[1] >> 4)))?;
78 f.write_char(ch((chunk[1] & 0x0F) << 2))?;
79 f.write_char('=')?;
80 }
81 3 => {
82 f.write_char(ch(chunk[0] >> 2))?;
83 f.write_char(ch(((chunk[0] & 0x03) << 4) | (chunk[1] >> 4)))?;
84 f.write_char(ch(((chunk[1] & 0x0F) << 2) | (chunk[2] >> 6)))?;
85 f.write_char(ch(chunk[2] & 0x3F))?;
86 }
87 _ => unreachable!(),
88 }
89 }
90 Ok(())
91}
92
#[cfg(feature = "std")]
/// Encodes `bytes` as padded Base 64 and returns the result as a `String`.
///
/// The string is allocated with exactly the capacity needed: four
/// characters for every started group of three bytes.
pub fn encode_string<B: AsRef<[u8]> + ?Sized>(bytes: &B) -> String {
    let len = bytes.as_ref().len();
    // Number of (possibly partial) three-byte groups in the input.
    let groups = len / 3 + usize::from(len % 3 != 0);
    let mut res = String::with_capacity(groups * 4);
    display(bytes, &mut res).expect("writing to a String cannot fail");
    res
}
100
101pub fn encode_display<Octets: AsRef<[u8]>>(
103 octets: &Octets,
104) -> impl fmt::Display + '_ {
105 struct Display<'a>(&'a [u8]);
106
107 impl fmt::Display for Display<'_> {
108 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
109 display(self.0, f)
110 }
111 }
112
113 Display(octets.as_ref())
114}
115
#[cfg(feature = "serde")]
/// Serde helpers that represent octets as Base 64 in human-readable formats.
///
/// `serialize` and `deserialize` have the shapes expected by serde's `with`
/// field attribute. For human-readable formats the octets are written and
/// read as a Base 64 string. For other formats they are handled as raw
/// octets through `octseq`'s serde support.
pub mod serde {
    use super::encode_display;
    use core::fmt;
    use octseq::builder::{EmptyBuilder, FromBuilder, OctetsBuilder};
    use octseq::serde::{DeserializeOctets, SerializeOctets};

    /// Serializes `octets` as a Base 64 string or as raw octets.
    ///
    /// The choice depends on whether the serializer is human readable.
    pub fn serialize<Octets, S>(
        octets: &Octets,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        Octets: AsRef<[u8]> + SerializeOctets,
        S: serde::Serializer,
    {
        if serializer.is_human_readable() {
            serializer.collect_str(&encode_display(octets))
        } else {
            octets.serialize_octets(serializer)
        }
    }

    /// Deserializes octets from a Base 64 string or from raw octets.
    ///
    /// The choice depends on whether the deserializer is human readable.
    /// Base 64 decoding errors are reported through `D::Error::custom`.
    pub fn deserialize<'de, Octets, D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> Result<Octets, D::Error>
    where
        Octets: FromBuilder + DeserializeOctets<'de>,
        <Octets as FromBuilder>::Builder: EmptyBuilder,
    {
        /// Visitor accepting either a Base 64 string or raw octets.
        ///
        /// Raw octets are forwarded to the octets type's own visitor.
        struct Visitor<'de, Octets: DeserializeOctets<'de>>(Octets::Visitor);

        impl<'de, Octets> serde::de::Visitor<'de> for Visitor<'de, Octets>
        where
            Octets: FromBuilder + DeserializeOctets<'de>,
            <Octets as FromBuilder>::Builder: OctetsBuilder + EmptyBuilder,
        {
            type Value = Octets;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a Base64-encoded string")
            }

            fn visit_str<E: serde::de::Error>(
                self,
                v: &str,
            ) -> Result<Self::Value, E> {
                super::decode(v).map_err(E::custom)
            }

            fn visit_borrowed_bytes<E: serde::de::Error>(
                self,
                value: &'de [u8],
            ) -> Result<Self::Value, E> {
                self.0.visit_borrowed_bytes(value)
            }

            #[cfg(feature = "std")]
            fn visit_byte_buf<E: serde::de::Error>(
                self,
                value: std::vec::Vec<u8>,
            ) -> Result<Self::Value, E> {
                self.0.visit_byte_buf(value)
            }
        }

        if deserializer.is_human_readable() {
            deserializer.deserialize_str(Visitor(Octets::visitor()))
        } else {
            Octets::deserialize_with_visitor(
                deserializer,
                Visitor(Octets::visitor()),
            )
        }
    }
}
196
/// A Base 64 decoder that is fed one character at a time.
///
/// Decoded bytes are appended to a builder of type `Builder` whenever a
/// group of four characters is complete.
pub struct Decoder<Builder> {
    /// The alphabet values of the characters of the current group.
    ///
    /// A padding character is stored as `0x80`.
    buf: [u8; 4],

    /// The index in `buf` where the next character goes.
    ///
    /// It is set to `0xF0` once a group containing padding has been
    /// completed; after that no further characters are accepted.
    next: usize,

    /// The builder receiving the decoded bytes, or the error that stopped
    /// decoding.
    target: Result<Builder, DecodeError>,
}
220
221impl<Builder: EmptyBuilder> Decoder<Builder> {
222 #[must_use]
224 pub fn new() -> Self {
225 Decoder {
226 buf: [0; 4],
227 next: 0,
228 target: Ok(Builder::empty()),
229 }
230 }
231}
232
233impl<Builder: OctetsBuilder> Decoder<Builder> {
234 pub fn finalize(self) -> Result<Builder::Octets, DecodeError>
236 where
237 Builder: FreezeBuilder,
238 {
239 let (target, next) = (self.target, self.next);
240 target.and_then(|bytes| {
241 if next & 0x0F != 0 {
243 Err(DecodeError::ShortInput)
244 } else {
245 Ok(bytes.freeze())
246 }
247 })
248 }
249
250 pub fn push(&mut self, ch: char) -> Result<(), DecodeError> {
256 if self.next == 0xF0 {
257 self.target = Err(DecodeError::TrailingInput);
258 return Err(DecodeError::TrailingInput);
259 }
260
261 let val = if ch == PAD {
262 if self.next < 2 {
264 return Err(DecodeError::IllegalChar(ch));
265 }
266 0x80 } else {
268 if ch > (127 as char) {
269 return Err(DecodeError::IllegalChar(ch));
270 }
271 let val = DECODE_ALPHABET[ch as usize];
272 if val == 0xFF {
273 return Err(DecodeError::IllegalChar(ch));
274 }
275 val
276 };
277 self.buf[self.next] = val;
278 self.next += 1;
279
280 if self.next == 4 {
281 let target = self.target.as_mut().unwrap(); target
283 .append_slice(&[(self.buf[0] << 2) | (self.buf[1] >> 4)])
284 .map_err(Into::into)?;
285 if self.buf[2] != 0x80 {
286 target
287 .append_slice(&[(self.buf[1] << 4) | (self.buf[2] >> 2)])
288 .map_err(Into::into)?;
289 }
290 if self.buf[3] != 0x80 {
291 if self.buf[2] == 0x80 {
292 return Err(DecodeError::TrailingInput);
293 }
294 target
295 .append_slice(&[(self.buf[2] << 6) | self.buf[3]])
296 .map_err(Into::into)?;
297 self.next = 0
298 } else {
299 self.next = 0xF0
300 }
301 }
302
303 Ok(())
304 }
305}
306
307impl<Builder: EmptyBuilder> Default for Decoder<Builder> {
310 fn default() -> Self {
311 Self::new()
312 }
313}
314
/// Converts a sequence of Base 64 symbols into the octets they encode.
///
/// The scanner feeds symbols in through the `ConvertSymbols` trait. Octets
/// are produced one group of four characters at a time.
#[derive(Clone, Debug, Default)]
pub struct SymbolConverter {
    /// The alphabet values of the characters of the current group.
    ///
    /// A padding character is stored as `PAD_MARKER`.
    input: [u8; 4],

    /// The index in `input` where the next character goes.
    ///
    /// It is set to `EOF_MARKER` after a group containing padding.
    next: usize,

    /// Buffer holding the octets decoded from the last complete group.
    output: [u8; 3],
}
335
336impl SymbolConverter {
337 #[must_use]
339 pub fn new() -> Self {
340 Default::default()
341 }
342
343 fn process_char<Error: ScannerError>(
344 &mut self,
345 ch: char,
346 ) -> Result<Option<&[u8]>, Error> {
347 if self.next == EOF_MARKER {
348 return Err(Error::custom("trailing Base 64 data"));
349 }
350
351 let val = if ch == PAD {
352 if self.next < 2 {
354 return Err(Error::custom("illegal Base 64 data"));
355 }
356 PAD_MARKER } else {
358 if ch > (127 as char) {
359 return Err(Error::custom("illegal Base 64 data"));
360 }
361 let val = DECODE_ALPHABET[ch as usize];
362 if val == 0xFF {
363 return Err(Error::custom("illegal Base 64 data"));
364 }
365 val
366 };
367 self.input[self.next] = val;
368 self.next += 1;
369
370 if self.next == 4 {
371 self.output[0] = (self.input[0] << 2) | (self.input[1] >> 4);
372
373 if self.input[2] == PAD_MARKER {
374 if self.input[3] == PAD_MARKER {
377 self.next = EOF_MARKER;
378 Ok(Some(&self.output[..1]))
379 } else {
380 Err(Error::custom("illegal Base 64 data"))
381 }
382 } else {
383 self.output[1] = (self.input[1] << 4) | (self.input[2] >> 2);
384
385 if self.input[3] == PAD_MARKER {
386 self.next = EOF_MARKER;
388 Ok(Some(&self.output[..2]))
389 } else {
390 self.output[2] = (self.input[2] << 6) | self.input[3];
391 self.next = 0;
392 Ok(Some(&self.output))
393 }
394 }
395 } else {
396 Ok(None)
397 }
398 }
399}
400
401impl<Sym, Error> ConvertSymbols<Sym, Error> for SymbolConverter
402where
403 Sym: Into<EntrySymbol>,
404 Error: ScannerError,
405{
406 fn process_symbol(
407 &mut self,
408 symbol: Sym,
409 ) -> Result<Option<&[u8]>, Error> {
410 match symbol.into() {
411 EntrySymbol::Symbol(symbol) => self.process_char(
412 symbol
413 .into_char()
414 .map_err(|_| Error::custom("illegal Base 64 data"))?,
415 ),
416 EntrySymbol::EndOfToken => Ok(None),
417 }
418 }
419
420 fn process_tail(&mut self) -> Result<Option<&[u8]>, Error> {
421 if self.next & 0x0F != 0 {
423 Err(Error::custom("incomplete Base 64 data"))
424 } else {
425 Ok(None)
426 }
427 }
428}
429
/// An error that happened while decoding Base 64 data.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecodeError {
    /// The character is not allowed at this point.
    ///
    /// Either it is not part of the Base 64 alphabet, or it is a padding
    /// character in one of the first two positions of a group.
    IllegalChar(char),

    /// Data was found after padding.
    TrailingInput,

    /// The input ended in the middle of a group of four characters.
    ShortInput,

    /// The target builder failed to append decoded data.
    ///
    /// This corresponds to octseq's `ShortBuf` error.
    ShortBuf,
}
449
450impl From<ShortBuf> for DecodeError {
451 fn from(_: ShortBuf) -> Self {
452 DecodeError::ShortBuf
453 }
454}
455
456impl fmt::Display for DecodeError {
459 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
460 match *self {
461 DecodeError::TrailingInput => f.write_str("trailing input"),
462 DecodeError::IllegalChar(ch) => {
463 write!(f, "illegal character '{}'", ch)
464 }
465 DecodeError::ShortInput => f.write_str("incomplete input"),
466 DecodeError::ShortBuf => ShortBuf.fmt(f),
467 }
468 }
469}
470
// `DecodeError` wraps no underlying error, so the default methods of
// `std::error::Error` are sufficient.
#[cfg(feature = "std")]
impl std::error::Error for DecodeError {}
473
/// Maps an ASCII code point to its six-bit Base 64 value.
///
/// Entries of `0xFF` mark characters that are not part of the alphabet.
/// Each row below covers sixteen code points, starting at the value in the
/// trailing comment.
const DECODE_ALPHABET: [u8; 128] = [
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x00
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x10
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0x3E, 0xFF, 0xFF, 0xFF, 0x3F, // 0x20: '+', '/'
    0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B,
    0x3C, 0x3D, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x30: '0'..'9'
    0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
    0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, // 0x40: 'A'..'O'
    0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16,
    0x17, 0x18, 0x19, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x50: 'P'..'Z'
    0xFF, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20,
    0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, // 0x60: 'a'..'o'
    0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30,
    0x31, 0x32, 0x33, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x70: 'p'..'z'
];
499
/// Maps a six-bit value to its character in the standard Base 64 alphabet.
///
/// This is the alphabet of RFC 4648, section 4.
const ENCODE_ALPHABET: [char; 64] = [
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', // 0x00
    'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', // 0x08
    'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', // 0x10
    'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', // 0x18
    'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', // 0x20
    'o', 'p', 'q', 'r', 's', 't', 'u', 'v', // 0x28
    'w', 'x', 'y', 'z', '0', '1', '2', '3', // 0x30
    '4', '5', '6', '7', '8', '9', '+', '/', // 0x38
];
510
/// The padding character that fills up a final, incomplete group.
const PAD: char = '=';

/// The value stored in a group buffer in place of a padding character.
///
/// It lies outside the range `0..64` of real alphabet values.
const PAD_MARKER: u8 = 0x80;

/// The value of the `next` index after a group containing padding.
///
/// No further input is accepted in this state.
const EOF_MARKER: usize = 0xF0;
519
#[cfg(all(test, feature = "std"))]
mod test {
    use super::*;

    /// Test vectors from RFC 4648, section 10.
    const HAPPY_CASES: &[(&[u8], &str)] = &[
        (b"", ""),
        (b"f", "Zg=="),
        (b"fo", "Zm8="),
        (b"foo", "Zm9v"),
        (b"foob", "Zm9vYg=="),
        (b"fooba", "Zm9vYmE="),
        (b"foobar", "Zm9vYmFy"),
    ];

    #[test]
    fn alphabets_are_inverse() {
        for (idx, &ch) in ENCODE_ALPHABET.iter().enumerate() {
            assert_eq!(
                usize::from(DECODE_ALPHABET[ch as usize]),
                idx,
                "decode alphabet entry for {}",
                ch
            );
        }
        // Every entry not reached above must be marked as invalid.
        let valid =
            DECODE_ALPHABET.iter().filter(|&&val| val != 0xFF).count();
        assert_eq!(valid, ENCODE_ALPHABET.len());
    }

    #[test]
    fn decode_str() {
        fn decode(s: &str) -> Result<std::vec::Vec<u8>, DecodeError> {
            super::decode(s)
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&decode(text).unwrap(), bin, "decode {}", text)
        }

        assert_eq!(decode("FPucA").unwrap_err(), DecodeError::ShortInput);
        assert_eq!(
            decode("FPucA=").unwrap_err(),
            DecodeError::IllegalChar('=')
        );
        assert_eq!(decode("FPucAw=").unwrap_err(), DecodeError::ShortInput);
        assert_eq!(
            decode("FPucAw=a").unwrap_err(),
            DecodeError::TrailingInput
        );
        assert_eq!(
            decode("FPucAw==a").unwrap_err(),
            DecodeError::TrailingInput
        );
        assert_eq!(
            decode("Zg!=").unwrap_err(),
            DecodeError::IllegalChar('!')
        );
        assert_eq!(
            decode("Zg\u{e4}=").unwrap_err(),
            DecodeError::IllegalChar('\u{e4}')
        );
    }

    #[test]
    fn round_trip() {
        let data: std::vec::Vec<u8> = (0..=255u8).collect();
        for len in 0..=data.len() {
            let text = encode_string(&data[..len]);
            let back: std::vec::Vec<u8> = decode(&text).unwrap();
            assert_eq!(back.as_slice(), &data[..len], "round trip {}", text);
        }
    }

    #[test]
    fn symbol_converter() {
        use crate::base::scan::Symbols;
        use std::vec::Vec;

        fn decode(s: &str) -> Result<Vec<u8>, std::io::Error> {
            let mut convert = SymbolConverter::new();
            let convert: &mut dyn ConvertSymbols<_, std::io::Error> =
                &mut convert;
            let mut res = Vec::new();
            for sym in Symbols::new(s.chars()) {
                if let Some(octs) = convert.process_symbol(sym)? {
                    res.extend_from_slice(octs);
                }
            }
            if let Some(octs) = convert.process_tail()? {
                res.extend_from_slice(octs);
            }
            Ok(res)
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&decode(text).unwrap(), bin, "convert {}", text)
        }

        assert!(decode("Zg=").is_err(), "incomplete group");
        assert!(decode("Zg=a").is_err(), "data after padding");
        assert!(decode("Zg==a").is_err(), "data after padded group");
        assert!(decode("Z!==").is_err(), "illegal character");
    }

    #[test]
    fn display_bytes() {
        fn via_display(s: &[u8]) -> String {
            let mut out = String::new();
            display(s, &mut out).unwrap();
            out
        }

        fn via_encode_display(s: &[u8]) -> String {
            let mut out = String::new();
            fmt::Write::write_fmt(
                &mut out,
                format_args!("{}", encode_display(&s)),
            )
            .unwrap();
            out
        }

        for (bin, text) in HAPPY_CASES {
            assert_eq!(&via_display(bin), text, "fmt {}", text);
            assert_eq!(&via_encode_display(bin), text, "display {}", text);
        }
    }
}