1use crate::alsw::{STANDARD_ALSW_CHECK_X2, URL_SAFE_ALSW_CHECK_X2};
2use crate::alsw::{STANDARD_ALSW_DECODE_X2, URL_SAFE_ALSW_DECODE_X2};
3use crate::{Config, Error, Extra, Kind};
4use crate::{STANDARD_CHARSET, URL_SAFE_CHARSET};
5
6use vsimd::alsw::AlswLut;
7use vsimd::isa::{NEON, SSSE3, WASM128};
8use vsimd::mask::u8x32_highbit_any;
9use vsimd::matches_isa;
10use vsimd::tools::{read, write};
11use vsimd::vector::V256;
12use vsimd::SIMD256;
13
14use core::ops::Not;
15
/// Builds the 256-entry reverse lookup table for a 64-symbol alphabet.
///
/// Entry `b` of the result holds the position of byte `b` inside `charset`.
/// Every byte that does not occur in `charset` maps to `0xff`. Should a byte
/// occur more than once, the highest position wins because later writes
/// overwrite earlier ones.
const fn decode_table(charset: &'static [u8; 64]) -> [u8; 256] {
    let mut lookup = [u8::MAX; 256];
    let mut position = 0_usize;
    // `for` loops are not available in `const fn`, hence the manual counter.
    while position != charset.len() {
        let symbol = charset[position] as usize;
        // `position` is below 64, so narrowing to `u8` is lossless.
        lookup[symbol] = position as u8;
        position += 1;
    }
    lookup
}
25
26pub const STANDARD_DECODE_TABLE: &[u8; 256] = &decode_table(STANDARD_CHARSET);
27pub const URL_SAFE_DECODE_TABLE: &[u8; 256] = &decode_table(URL_SAFE_CHARSET);
28
29#[inline(always)]
30pub(crate) fn decoded_length(src: &[u8], config: Config) -> Result<(usize, usize), Error> {
31 if src.is_empty() {
32 return Ok((0, 0));
33 }
34
35 let n = unsafe {
36 let len = src.len();
37
38 let count_pad = || {
39 let last1 = *src.get_unchecked(len - 1);
40 let last2 = *src.get_unchecked(len - 2);
41 if last1 == b'=' {
42 if last2 == b'=' {
43 2
44 } else {
45 1
46 }
47 } else {
48 0
49 }
50 };
51
52 match config.extra {
53 Extra::Pad => {
54 ensure!(len % 4 == 0);
55 len - count_pad()
56 }
57 Extra::NoPad => len,
58 Extra::Forgiving => {
59 if len % 4 == 0 {
60 len - count_pad()
61 } else {
62 len
63 }
64 }
65 }
66 };
67
68 let m = match n % 4 {
69 0 => n / 4 * 3,
70 1 => return Err(Error::new()),
71 2 => n / 4 * 3 + 1,
72 3 => n / 4 * 3 + 2,
73 _ => unsafe { core::hint::unreachable_unchecked() },
74 };
75
76 Ok((n, m))
77}
78
79#[inline(always)]
80pub unsafe fn decode_ascii8<const WRITE: bool>(src: *const u8, dst: *mut u8, table: *const u8) -> Result<(), Error> {
81 let mut y: u64 = 0;
82 let mut flag = 0;
83
84 let mut i = 0;
85 while i < 8 {
86 let x = read(src, i);
87 let bits = read(table, x as usize);
88 flag |= bits;
89
90 if WRITE {
91 y |= (bits as u64) << (58 - i * 6);
92 }
93
94 i += 1;
95 }
96
97 if WRITE {
98 dst.cast::<u64>().write_unaligned(y.to_be());
99 }
100
101 ensure!(flag != 0xff);
102 Ok(())
103}
104
105#[inline(always)]
106pub unsafe fn decode_ascii4<const WRITE: bool>(src: *const u8, dst: *mut u8, table: *const u8) -> Result<(), Error> {
107 let mut y: u32 = 0;
108 let mut flag = 0;
109
110 let mut i = 0;
111 while i < 4 {
112 let x = read(src, i);
113 let bits = read(table, x as usize);
114 flag |= bits;
115
116 if WRITE {
117 y |= (bits as u32) << (18 - i * 6);
118 }
119
120 i += 1;
121 }
122
123 if WRITE {
124 let y = y.to_be_bytes();
125 write(dst, 0, y[1]);
126 write(dst, 1, y[2]);
127 write(dst, 2, y[3]);
128 }
129
130 ensure!(flag != 0xff);
131 Ok(())
132}
133
134#[inline(always)]
135pub unsafe fn decode_extra<const WRITE: bool>(
136 extra: usize,
137 src: *const u8,
138 dst: *mut u8,
139 table: *const u8,
140 forgiving: bool,
141) -> Result<(), Error> {
142 match extra {
143 0 => {}
144 1 => core::hint::unreachable_unchecked(),
145 2 => {
146 let [x1, x2] = src.cast::<[u8; 2]>().read();
147
148 let y1 = read(table, x1 as usize);
149 let y2 = read(table, x2 as usize);
150 ensure!((y1 | y2) != 0xff && (forgiving || (y2 & 0x0f) == 0));
151
152 if WRITE {
153 write(dst, 0, (y1 << 2) | (y2 >> 4));
154 }
155 }
156 3 => {
157 let [x1, x2, x3] = src.cast::<[u8; 3]>().read();
158
159 let y1 = read(table, x1 as usize);
160 let y2 = read(table, x2 as usize);
161 let y3 = read(table, x3 as usize);
162 ensure!((y1 | y2 | y3) != 0xff && (forgiving || (y3 & 0x03) == 0));
163
164 if WRITE {
165 write(dst, 0, (y1 << 2) | (y2 >> 4));
166 write(dst, 1, (y2 << 4) | (y3 >> 2));
167 }
168 }
169 _ => core::hint::unreachable_unchecked(),
170 }
171 Ok(())
172}
173
174#[inline]
175pub(crate) unsafe fn decode_fallback(
176 mut src: *const u8,
177 mut dst: *mut u8,
178 mut n: usize,
179 config: Config,
180) -> Result<(), Error> {
181 let kind = config.kind;
182 let forgiving = config.extra.forgiving();
183
184 let table = match kind {
185 Kind::Standard => STANDARD_DECODE_TABLE.as_ptr(),
186 Kind::UrlSafe => URL_SAFE_DECODE_TABLE.as_ptr(),
187 };
188
189 while n >= 11 {
191 decode_ascii8::<true>(src, dst, table)?;
192 src = src.add(8);
193 dst = dst.add(6);
194 n -= 8;
195 }
196
197 let end = src.add(n / 4 * 4);
198 while src < end {
199 decode_ascii4::<true>(src, dst, table)?;
200 src = src.add(4);
201 dst = dst.add(3);
202 }
203 n %= 4;
204
205 decode_extra::<true>(n, src, dst, table, forgiving)
206}
207
208#[inline(always)]
209pub(crate) unsafe fn decode_simd<S: SIMD256>(
210 s: S,
211 mut src: *const u8,
212 mut dst: *mut u8,
213 mut n: usize,
214 config: Config,
215) -> Result<(), Error> {
216 let kind = config.kind;
217
218 let (check_lut, decode_lut) = match kind {
219 Kind::Standard => (STANDARD_ALSW_CHECK_X2, STANDARD_ALSW_DECODE_X2),
220 Kind::UrlSafe => (URL_SAFE_ALSW_CHECK_X2, URL_SAFE_ALSW_DECODE_X2),
221 };
222
223 while n >= 38 {
225 let x = s.v256_load_unaligned(src);
226 let y = try_!(decode_ascii32(s, x, check_lut, decode_lut));
227
228 let (y1, y2) = y.to_v128x2();
229 s.v128_store_unaligned(dst, y1);
230 s.v128_store_unaligned(dst.add(12), y2);
231
232 src = src.add(32);
233 dst = dst.add(24);
234 n -= 32;
235 }
236
237 decode_fallback(src, dst, n, config)
238}
239
240#[inline(always)]
241fn merge_bits_x2<S: SIMD256>(s: S, x: V256) -> V256 {
242 let y = if matches_isa!(S, SSSE3) {
245 let m1 = s.u16x16_splat(u16::from_le_bytes([0x40, 0x01]));
246 let x1 = s.i16x16_maddubs(x, m1);
247 let m2 = s.u32x8_splat(u32::from_le_bytes([0x00, 0x10, 0x01, 0x00]));
250 s.i16x16_madd(x1, m2)
251 } else if matches_isa!(S, NEON | WASM128) {
253 let m1 = s.u32x8_splat(u32::from_le_bytes([0x3f, 0x00, 0x3f, 0x00]));
254 let x1 = s.v256_and(x, m1);
255 let m2 = s.u32x8_splat(u32::from_le_bytes([0x00, 0x3f, 0x00, 0x3f]));
258 let x2 = s.v256_and(x, m2);
259 let x3 = s.v256_or(s.u32x8_shl::<18>(x1), s.u32x8_shr::<10>(x1));
262 let x4 = s.v256_or(s.u32x8_shl::<4>(x2), s.u32x8_shr::<24>(x2));
265 let mask = s.u32x8_splat(u32::from_le_bytes([0xff, 0xff, 0xff, 0x00]));
268 s.v256_and(s.v256_or(x3, x4), mask)
269 } else {
271 unreachable!()
272 };
273
274 const SHUFFLE: V256 = V256::double_bytes([
275 0x02, 0x01, 0x00, 0x06, 0x05, 0x04, 0x0a, 0x09, 0x08, 0x0e, 0x0d, 0x0c, 0x80, 0x80, 0x80, 0x80, ]);
278 s.u8x16x2_swizzle(y, SHUFFLE)
279 }
281
282#[inline(always)]
283fn decode_ascii32<S: SIMD256>(s: S, x: V256, check: AlswLut<V256>, decode: AlswLut<V256>) -> Result<V256, Error> {
284 let (c1, c2) = vsimd::alsw::decode_ascii_xn(s, x, check, decode);
285 let y = merge_bits_x2(s, c2);
286 ensure!(u8x32_highbit_any(s, c1).not());
287 Ok(y)
288}