// base64ct/encoding.rs
1use crate::{
4 alphabet::Alphabet,
5 errors::{Error, InvalidEncodingError, InvalidLengthError},
6};
7use core::str;
8
9#[cfg(feature = "alloc")]
10use alloc::{string::String, vec::Vec};
11
12#[cfg(doc)]
13use crate::{Base64, Base64Bcrypt, Base64Crypt, Base64Unpadded, Base64Url, Base64UrlUnpadded};
14
/// Padding character (`=`) written after the final group by padded Base64
/// encodings and stripped/validated again when decoding them.
const PAD: u8 = b'=';

/// Base64 encoding trait.
///
/// Implemented for every [`Alphabet`] through a blanket impl, so it must be in
/// scope to call these methods on an alphabet type. The encodings exported by
/// this crate include [`Base64`], [`Base64Bcrypt`], [`Base64Crypt`],
/// [`Base64Unpadded`], [`Base64Url`] and [`Base64UrlUnpadded`].
pub trait Encoding: Alphabet {
    /// Decode the Base64 string `src` into the provided destination buffer,
    /// returning the prefix of `dst` that holds the decoded bytes.
    ///
    /// # Errors
    /// - [`Error::InvalidLength`] if `dst` is too small for the decoded output.
    /// - An error if `src` is not well-formed Base64 for this alphabet.
    fn decode(src: impl AsRef<[u8]>, dst: &mut [u8]) -> Result<&[u8], Error>;

    /// Decode a Base64 string in place, returning the prefix of `buf` that
    /// holds the decoded bytes.
    ///
    /// # Errors
    /// Returns [`InvalidEncodingError`] if `buf` is not well-formed Base64 for
    /// this alphabet. On error, `buf` may already have been partially
    /// overwritten with decoded bytes.
    fn decode_in_place(buf: &mut [u8]) -> Result<&[u8], InvalidEncodingError>;

    /// Decode a Base64 string into a newly allocated byte vector.
    ///
    /// # Errors
    /// Returns an error if `input` is not well-formed Base64 for this alphabet.
    #[cfg(feature = "alloc")]
    #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
    fn decode_vec(input: &str) -> Result<Vec<u8>, Error>;

    /// Encode `src` as Base64 into the provided destination buffer, returning
    /// the written prefix of `dst` as an ASCII `&str`.
    ///
    /// # Errors
    /// Returns [`InvalidLengthError`] if `dst` is too small for the encoded
    /// output, or if computing the encoded length overflows `usize`.
    fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, InvalidLengthError>;

    /// Encode `input` as Base64 into a newly allocated [`String`].
    ///
    /// # Panics
    /// Panics if computing the encoded length of `input` overflows `usize`.
    #[cfg(feature = "alloc")]
    #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
    fn encode_string(input: &[u8]) -> String;

    /// Get the length of the Base64 produced by encoding `bytes`, including
    /// `=` padding for padded alphabets.
    ///
    /// WARNING: returns `0` if computing the encoded length overflows `usize`.
    fn encoded_len(bytes: &[u8]) -> usize;
}

66impl<T: Alphabet> Encoding for T {
67 fn decode(src: impl AsRef<[u8]>, dst: &mut [u8]) -> Result<&[u8], Error> {
68 let (src_unpadded, mut err) = if T::PADDED {
69 let (unpadded_len, e) = decode_padding(src.as_ref())?;
70 (&src.as_ref()[..unpadded_len], e)
71 } else {
72 (src.as_ref(), 0)
73 };
74
75 let dlen = decoded_len(src_unpadded.len());
76
77 if dlen > dst.len() {
78 return Err(Error::InvalidLength);
79 }
80
81 let dst = &mut dst[..dlen];
82
83 let mut src_chunks = src_unpadded.chunks_exact(4);
84 let mut dst_chunks = dst.chunks_exact_mut(3);
85 for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
86 err |= Self::decode_3bytes(s, d);
87 }
88 let src_rem = src_chunks.remainder();
89 let dst_rem = dst_chunks.into_remainder();
90
91 err |= !(src_rem.is_empty() || src_rem.len() >= 2) as i16;
92 let mut tmp_out = [0u8; 3];
93 let mut tmp_in = [b'A'; 4];
94 tmp_in[..src_rem.len()].copy_from_slice(src_rem);
95 err |= Self::decode_3bytes(&tmp_in, &mut tmp_out);
96 dst_rem.copy_from_slice(&tmp_out[..dst_rem.len()]);
97
98 if err == 0 {
99 validate_last_block::<T>(src.as_ref(), dst)?;
100 Ok(dst)
101 } else {
102 Err(Error::InvalidEncoding)
103 }
104 }
105
106 #[allow(clippy::integer_arithmetic)]
108 fn decode_in_place(mut buf: &mut [u8]) -> Result<&[u8], InvalidEncodingError> {
109 let mut err = if T::PADDED {
112 let (unpadded_len, e) = decode_padding(buf)?;
113 buf = &mut buf[..unpadded_len];
114 e
115 } else {
116 0
117 };
118
119 let dlen = decoded_len(buf.len());
120 let full_chunks = buf.len() / 4;
121
122 for chunk in 0..full_chunks {
123 #[allow(unsafe_code)]
127 unsafe {
128 debug_assert!(3 * chunk + 3 <= buf.len());
129 debug_assert!(4 * chunk + 4 <= buf.len());
130
131 let p3 = buf.as_mut_ptr().add(3 * chunk) as *mut [u8; 3];
132 let p4 = buf.as_ptr().add(4 * chunk) as *const [u8; 4];
133
134 let mut tmp_out = [0u8; 3];
135 err |= Self::decode_3bytes(&*p4, &mut tmp_out);
136 *p3 = tmp_out;
137 }
138 }
139
140 let src_rem_pos = 4 * full_chunks;
141 let src_rem_len = buf.len() - src_rem_pos;
142 let dst_rem_pos = 3 * full_chunks;
143 let dst_rem_len = dlen - dst_rem_pos;
144
145 err |= !(src_rem_len == 0 || src_rem_len >= 2) as i16;
146 let mut tmp_in = [b'A'; 4];
147 tmp_in[..src_rem_len].copy_from_slice(&buf[src_rem_pos..]);
148 let mut tmp_out = [0u8; 3];
149
150 err |= Self::decode_3bytes(&tmp_in, &mut tmp_out);
151
152 if err == 0 {
153 #[allow(unsafe_code)]
157 unsafe {
158 debug_assert!(dst_rem_pos + dst_rem_len <= buf.len());
159 debug_assert!(dst_rem_len <= tmp_out.len());
160 debug_assert!(dlen <= buf.len());
161
162 core::ptr::copy_nonoverlapping(
163 tmp_out.as_ptr(),
164 buf.as_mut_ptr().add(dst_rem_pos),
165 dst_rem_len,
166 );
167 Ok(buf.get_unchecked(..dlen))
168 }
169 } else {
170 Err(InvalidEncodingError)
171 }
172 }
173
174 #[cfg(feature = "alloc")]
175 fn decode_vec(input: &str) -> Result<Vec<u8>, Error> {
176 let mut output = vec![0u8; decoded_len(input.len())];
177 let len = Self::decode(input, &mut output)?.len();
178
179 if len <= output.len() {
180 output.truncate(len);
181 Ok(output)
182 } else {
183 Err(Error::InvalidLength)
184 }
185 }
186
187 fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, InvalidLengthError> {
188 let elen = match encoded_len_inner(src.len(), T::PADDED) {
189 Some(v) => v,
190 None => return Err(InvalidLengthError),
191 };
192
193 if elen > dst.len() {
194 return Err(InvalidLengthError);
195 }
196
197 let dst = &mut dst[..elen];
198
199 let mut src_chunks = src.chunks_exact(3);
200 let mut dst_chunks = dst.chunks_exact_mut(4);
201
202 for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
203 Self::encode_3bytes(s, d);
204 }
205
206 let src_rem = src_chunks.remainder();
207
208 if T::PADDED {
209 if let Some(dst_rem) = dst_chunks.next() {
210 let mut tmp = [0u8; 3];
211 tmp[..src_rem.len()].copy_from_slice(src_rem);
212 Self::encode_3bytes(&tmp, dst_rem);
213
214 let flag = src_rem.len() == 1;
215 let mask = (flag as u8).wrapping_sub(1);
216 dst_rem[2] = (dst_rem[2] & mask) | (PAD & !mask);
217 dst_rem[3] = PAD;
218 }
219 } else {
220 let dst_rem = dst_chunks.into_remainder();
221
222 let mut tmp_in = [0u8; 3];
223 let mut tmp_out = [0u8; 4];
224 tmp_in[..src_rem.len()].copy_from_slice(src_rem);
225 Self::encode_3bytes(&tmp_in, &mut tmp_out);
226 dst_rem.copy_from_slice(&tmp_out[..dst_rem.len()]);
227 }
228
229 debug_assert!(str::from_utf8(dst).is_ok());
230
231 #[allow(unsafe_code)]
233 Ok(unsafe { str::from_utf8_unchecked(dst) })
234 }
235
236 #[cfg(feature = "alloc")]
237 fn encode_string(input: &[u8]) -> String {
238 let elen = encoded_len_inner(input.len(), T::PADDED).expect("input is too big");
239 let mut dst = vec![0u8; elen];
240 let res = Self::encode(input, &mut dst).expect("encoding error");
241
242 debug_assert_eq!(elen, res.len());
243 debug_assert!(str::from_utf8(&dst).is_ok());
244
245 #[allow(unsafe_code)]
247 unsafe {
248 String::from_utf8_unchecked(dst)
249 }
250 }
251
252 fn encoded_len(bytes: &[u8]) -> usize {
253 encoded_len_inner(bytes.len(), T::PADDED).unwrap_or(0)
254 }
255}
256
257#[inline(always)]
267pub(crate) fn decode_padding(input: &[u8]) -> Result<(usize, i16), InvalidEncodingError> {
268 if input.len() % 4 != 0 {
269 return Err(InvalidEncodingError);
270 }
271
272 let unpadded_len = match *input {
273 [.., b0, b1] => is_pad_ct(b0)
274 .checked_add(is_pad_ct(b1))
275 .and_then(|len| len.try_into().ok())
276 .and_then(|len| input.len().checked_sub(len))
277 .ok_or(InvalidEncodingError)?,
278 _ => input.len(),
279 };
280
281 let padding_len = input
282 .len()
283 .checked_sub(unpadded_len)
284 .ok_or(InvalidEncodingError)?;
285
286 let err = match *input {
287 [.., b0] if padding_len == 1 => is_pad_ct(b0) ^ 1,
288 [.., b0, b1] if padding_len == 2 => (is_pad_ct(b0) & is_pad_ct(b1)) ^ 1,
289 _ => {
290 if padding_len == 0 {
291 0
292 } else {
293 return Err(InvalidEncodingError);
294 }
295 }
296 };
297
298 Ok((unpadded_len, err))
299}
300
301fn validate_last_block<T: Alphabet>(encoded: &[u8], decoded: &[u8]) -> Result<(), Error> {
304 if encoded.is_empty() && decoded.is_empty() {
305 return Ok(());
306 }
307
308 #[allow(clippy::integer_arithmetic)]
310 fn last_block_start(bytes: &[u8], block_size: usize) -> usize {
311 (bytes.len().saturating_sub(1) / block_size) * block_size
312 }
313
314 let enc_block = encoded
315 .get(last_block_start(encoded, 4)..)
316 .ok_or(Error::InvalidEncoding)?;
317
318 let dec_block = decoded
319 .get(last_block_start(decoded, 3)..)
320 .ok_or(Error::InvalidEncoding)?;
321
322 let mut buf = [0u8; 4];
324 let block = T::encode(dec_block, &mut buf)?;
325
326 if block
329 .as_bytes()
330 .iter()
331 .zip(enc_block.iter())
332 .fold(0, |acc, (a, b)| acc | (a ^ b))
333 == 0
334 {
335 Ok(())
336 } else {
337 Err(Error::InvalidEncoding)
338 }
339}
340
/// Length of the output produced by decoding `input_len` characters of
/// *unpadded* Base64, i.e. `floor(3 * input_len / 4)`.
///
/// Only the length is inspected, so the result is meaningless for malformed
/// input (e.g. a single dangling trailing character).
// TODO(tarcieri): explicitly checked/wrapped arithmetic
#[allow(clippy::integer_arithmetic)]
#[inline(always)]
pub(crate) fn decoded_len(input_len: usize) -> usize {
    // Whole 4-character groups and a 0..=3 character tail, so that
    // `3 * input_len` is never formed and cannot overflow.
    let (groups, tail) = (input_len / 4, input_len % 4);
    let tail_bytes = match tail {
        0 | 1 => 0,
        2 => 1,
        _ => 2,
    };
    groups * 3 + tail_bytes
}

356#[allow(clippy::integer_arithmetic)]
359#[inline(always)]
360fn is_pad_ct(input: u8) -> i16 {
361 ((((PAD as i16 - 1) - input as i16) & (input as i16 - (PAD as i16 + 1))) >> 8) & 1
362}
363
/// Length of the Base64 encoding of `n` input bytes, including `=` padding when
/// `padded` is true, or `None` if that length does not fit in a `usize`.
///
/// The computation works on whole 3-byte groups before multiplying by 4, so it
/// fails only when the *result* overflows. Computing `n * 4` first would
/// spuriously fail for every `n > usize::MAX / 4`, which matters on 16- and
/// 32-bit targets.
// TODO(tarcieri): explicitly checked/wrapped arithmetic
#[allow(clippy::integer_arithmetic)]
#[inline(always)]
const fn encoded_len_inner(n: usize, padded: bool) -> Option<usize> {
    let groups = n / 3;
    let leftover = n % 3;

    // Every whole 3-byte group becomes exactly 4 characters.
    let full = match groups.checked_mul(4) {
        Some(len) => len,
        None => return None,
    };

    if leftover == 0 {
        Some(full)
    } else if padded {
        // A partial group is padded out to a full 4-character block.
        full.checked_add(4)
    } else {
        // Unpadded: 1 leftover byte -> 2 characters, 2 leftover bytes -> 3.
        full.checked_add(leftover + 1)
    }
}