1use crate::{
4 grammar, Base64Encoder, Error, LineEnding, Result, BASE64_WRAP_WIDTH,
5 ENCAPSULATION_BOUNDARY_DELIMITER, POST_ENCAPSULATION_BOUNDARY, PRE_ENCAPSULATION_BOUNDARY,
6};
7use base64ct::{Base64, Encoding};
8use core::str;
9
10#[cfg(feature = "alloc")]
11use alloc::string::String;
12
13#[cfg(feature = "std")]
14use std::io;
15
16pub fn encapsulated_len(label: &str, line_ending: LineEnding, input_len: usize) -> Result<usize> {
34 encapsulated_len_wrapped(label, BASE64_WRAP_WIDTH, line_ending, input_len)
35}
36
37pub fn encapsulated_len_wrapped(
52 label: &str,
53 line_width: usize,
54 line_ending: LineEnding,
55 input_len: usize,
56) -> Result<usize> {
57 if line_width < 4 {
58 return Err(Error::Length);
59 }
60
61 let base64_len = input_len
62 .checked_mul(4)
63 .and_then(|n| n.checked_div(3))
64 .and_then(|n| n.checked_add(3))
65 .ok_or(Error::Length)?
66 & !3;
67
68 let base64_len_wrapped = base64_len_wrapped(base64_len, line_width, line_ending)?;
69 encapsulated_len_inner(label, line_ending, base64_len_wrapped)
70}
71
72pub fn encoded_len(label: &str, line_ending: LineEnding, input: &[u8]) -> Result<usize> {
81 let base64_len = Base64::encoded_len(input);
82 let base64_len_wrapped = base64_len_wrapped(base64_len, BASE64_WRAP_WIDTH, line_ending)?;
83 encapsulated_len_inner(label, line_ending, base64_len_wrapped)
84}
85
/// Encode `input` as a PEM document with the given `type_label`, writing the
/// result into `buf` and returning the written portion as a `&str`.
///
/// The document is produced by an [`Encoder`] created with [`Encoder::new`],
/// i.e. with the default Base64 line width.
///
/// # Errors
///
/// - Propagates any error from [`Encoder::new`], [`Encoder::encode`] or
///   [`Encoder::finish`] (for example when `buf` is too small).
/// - Returns [`Error::CharacterEncoding`] if the encoded output contains a
///   byte with the high bit set (i.e. a non-ASCII byte).
pub fn encode<'o>(
    type_label: &str,
    line_ending: LineEnding,
    input: &[u8],
    buf: &'o mut [u8],
) -> Result<&'o str> {
    let mut encoder = Encoder::new(type_label, line_ending, buf)?;
    encoder.encode(input)?;
    let encoded_len = encoder.finish()?;
    let output = &buf[..encoded_len];

    // Sanity check (debug builds only): the encoder output must be UTF-8.
    debug_assert!(str::from_utf8(output).is_ok());

    // Always-on check that every byte is 7-bit ASCII. The fold ORs the high
    // bit of *all* bytes together and never exits early, so the amount of
    // work does not depend on the buffer contents.
    // NOTE(review): presumably chosen over `is_ascii()`/`from_utf8` to avoid
    // data-dependent timing on potentially secret PEM contents -- confirm.
    if output.iter().fold(0u8, |acc, &byte| acc | (byte & 0x80)) == 0 {
        // SAFETY: the condition above established that every byte of `output`
        // has its high bit clear, i.e. the slice is pure ASCII, and ASCII is
        // always valid UTF-8.
        #[allow(unsafe_code)]
        Ok(unsafe { str::from_utf8_unchecked(output) })
    } else {
        Err(Error::CharacterEncoding)
    }
}
118
/// Encode `input` as a PEM document with the given `label` and return it as
/// an owned [`String`].
///
/// The output buffer is sized up front with [`encoded_len`] and then filled by
/// [`encode`].
///
/// # Errors
///
/// Propagates any error from [`encoded_len`] or [`encode`], and returns
/// [`Error::CharacterEncoding`] if the encoded bytes are not valid UTF-8.
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub fn encode_string(label: &str, line_ending: LineEnding, input: &[u8]) -> Result<String> {
    let expected_len = encoded_len(label, line_ending, input)?;
    let mut buf = vec![0u8; expected_len];
    let actual_len = encode(label, line_ending, input, &mut buf)?.len();
    debug_assert_eq!(expected_len, actual_len);

    // The assertion above is compiled out in release builds. Trim the buffer
    // to what was actually written so a (buggy) over-estimate can never leak
    // trailing NUL padding into the returned string. This is a no-op when the
    // two lengths agree.
    buf.truncate(actual_len);

    String::from_utf8(buf).map_err(|_| Error::CharacterEncoding)
}
130
131fn encapsulated_len_inner(
133 label: &str,
134 line_ending: LineEnding,
135 base64_len: usize,
136) -> Result<usize> {
137 [
138 PRE_ENCAPSULATION_BOUNDARY.len(),
139 label.as_bytes().len(),
140 ENCAPSULATION_BOUNDARY_DELIMITER.len(),
141 line_ending.len(),
142 base64_len,
143 line_ending.len(),
144 POST_ENCAPSULATION_BOUNDARY.len(),
145 label.as_bytes().len(),
146 ENCAPSULATION_BOUNDARY_DELIMITER.len(),
147 line_ending.len(),
148 ]
149 .into_iter()
150 .try_fold(0usize, |acc, len| acc.checked_add(len))
151 .ok_or(Error::Length)
152}
153
154fn base64_len_wrapped(
157 base64_len: usize,
158 line_width: usize,
159 line_ending: LineEnding,
160) -> Result<usize> {
161 base64_len
162 .saturating_sub(1)
163 .checked_div(line_width)
164 .and_then(|lines| lines.checked_mul(line_ending.len()))
165 .and_then(|len| len.checked_add(base64_len))
166 .ok_or(Error::Length)
167}
168
/// Buffered PEM encoder which writes into a caller-provided output slice.
///
/// The lifetime `'l` is that of the borrowed type label and `'o` that of the
/// output buffer.
pub struct Encoder<'l, 'o> {
    /// Type label written into the encapsulation boundaries.
    type_label: &'l str,

    /// Line ending written after each boundary and handed to the Base64
    /// encoder.
    line_ending: LineEnding,

    /// Base64 encoder operating on the part of the output buffer that follows
    /// the pre-encapsulation boundary.
    base64: Base64Encoder<'o>,
}
183
184impl<'l, 'o> Encoder<'l, 'o> {
185 pub fn new(type_label: &'l str, line_ending: LineEnding, out: &'o mut [u8]) -> Result<Self> {
190 Self::new_wrapped(type_label, BASE64_WRAP_WIDTH, line_ending, out)
191 }
192
193 pub fn new_wrapped(
208 type_label: &'l str,
209 line_width: usize,
210 line_ending: LineEnding,
211 mut out: &'o mut [u8],
212 ) -> Result<Self> {
213 grammar::validate_label(type_label.as_bytes())?;
214
215 for boundary_part in [
216 PRE_ENCAPSULATION_BOUNDARY,
217 type_label.as_bytes(),
218 ENCAPSULATION_BOUNDARY_DELIMITER,
219 line_ending.as_bytes(),
220 ] {
221 if out.len() < boundary_part.len() {
222 return Err(Error::Length);
223 }
224
225 let (part, rest) = out.split_at_mut(boundary_part.len());
226 out = rest;
227
228 part.copy_from_slice(boundary_part);
229 }
230
231 let base64 = Base64Encoder::new_wrapped(out, line_width, line_ending)?;
232
233 Ok(Self {
234 type_label,
235 line_ending,
236 base64,
237 })
238 }
239
240 pub fn type_label(&self) -> &'l str {
242 self.type_label
243 }
244
245 pub fn encode(&mut self, input: &[u8]) -> Result<()> {
252 self.base64.encode(input)?;
253 Ok(())
254 }
255
256 pub fn base64_encoder(&mut self) -> &mut Base64Encoder<'o> {
258 &mut self.base64
259 }
260
261 pub fn finish(self) -> Result<usize> {
266 let (base64, mut out) = self.base64.finish_with_remaining()?;
267
268 for boundary_part in [
269 self.line_ending.as_bytes(),
270 POST_ENCAPSULATION_BOUNDARY,
271 self.type_label.as_bytes(),
272 ENCAPSULATION_BOUNDARY_DELIMITER,
273 self.line_ending.as_bytes(),
274 ] {
275 if out.len() < boundary_part.len() {
276 return Err(Error::Length);
277 }
278
279 let (part, rest) = out.split_at_mut(boundary_part.len());
280 out = rest;
281
282 part.copy_from_slice(boundary_part);
283 }
284
285 encapsulated_len_inner(self.type_label, self.line_ending, base64.len())
286 }
287}
288
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<'l, 'o> io::Write for Encoder<'l, 'o> {
    /// Encode all of `buf` via [`Encoder::encode`].
    ///
    /// On success the whole input has been accepted, so `buf.len()` is
    /// returned. An encoder error is converted into an [`io::Error`].
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.encode(buf) {
            Ok(()) => Ok(buf.len()),
            Err(err) => Err(err.into()),
        }
    }

    /// No-op: always returns `Ok(())`.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}