base64_simd/
decode.rs

1use crate::alsw::{STANDARD_ALSW_CHECK_X2, URL_SAFE_ALSW_CHECK_X2};
2use crate::alsw::{STANDARD_ALSW_DECODE_X2, URL_SAFE_ALSW_DECODE_X2};
3use crate::{Config, Error, Extra, Kind};
4use crate::{STANDARD_CHARSET, URL_SAFE_CHARSET};
5
6use vsimd::alsw::AlswLut;
7use vsimd::isa::{NEON, SSSE3, WASM128};
8use vsimd::mask::u8x32_highbit_any;
9use vsimd::matches_isa;
10use vsimd::tools::{read, write};
11use vsimd::vector::V256;
12use vsimd::SIMD256;
13
14use core::ops::Not;
15
/// Builds the 256-entry reverse lookup table for `charset`.
///
/// Entry `b` of the result holds the position of byte `b` inside `charset`
/// (a value in `0..64`). Every byte that does not occur in `charset` keeps the
/// sentinel value `0xff`.
const fn decode_table(charset: &'static [u8; 64]) -> [u8; 256] {
    const INVALID: u8 = 0xff;

    let mut lut = [INVALID; 256];
    let mut sextet = 0;
    // A manual counter is used because `for` loops are not available in `const fn`.
    while sextet < 64 {
        let symbol = charset[sextet];
        lut[symbol as usize] = sextet as u8;
        sextet += 1;
    }
    lut
}
25
26pub const STANDARD_DECODE_TABLE: &[u8; 256] = &decode_table(STANDARD_CHARSET);
27pub const URL_SAFE_DECODE_TABLE: &[u8; 256] = &decode_table(URL_SAFE_CHARSET);
28
29#[inline(always)]
30pub(crate) fn decoded_length(src: &[u8], config: Config) -> Result<(usize, usize), Error> {
31    if src.is_empty() {
32        return Ok((0, 0));
33    }
34
35    let n = unsafe {
36        let len = src.len();
37
38        let count_pad = || {
39            let last1 = *src.get_unchecked(len - 1);
40            let last2 = *src.get_unchecked(len - 2);
41            if last1 == b'=' {
42                if last2 == b'=' {
43                    2
44                } else {
45                    1
46                }
47            } else {
48                0
49            }
50        };
51
52        match config.extra {
53            Extra::Pad => {
54                ensure!(len % 4 == 0);
55                len - count_pad()
56            }
57            Extra::NoPad => len,
58            Extra::Forgiving => {
59                if len % 4 == 0 {
60                    len - count_pad()
61                } else {
62                    len
63                }
64            }
65        }
66    };
67
68    let m = match n % 4 {
69        0 => n / 4 * 3,
70        1 => return Err(Error::new()),
71        2 => n / 4 * 3 + 1,
72        3 => n / 4 * 3 + 2,
73        _ => unsafe { core::hint::unreachable_unchecked() },
74    };
75
76    Ok((n, m))
77}
78
79#[inline(always)]
80pub unsafe fn decode_ascii8<const WRITE: bool>(src: *const u8, dst: *mut u8, table: *const u8) -> Result<(), Error> {
81    let mut y: u64 = 0;
82    let mut flag = 0;
83
84    let mut i = 0;
85    while i < 8 {
86        let x = read(src, i);
87        let bits = read(table, x as usize);
88        flag |= bits;
89
90        if WRITE {
91            y |= (bits as u64) << (58 - i * 6);
92        }
93
94        i += 1;
95    }
96
97    if WRITE {
98        dst.cast::<u64>().write_unaligned(y.to_be());
99    }
100
101    ensure!(flag != 0xff);
102    Ok(())
103}
104
105#[inline(always)]
106pub unsafe fn decode_ascii4<const WRITE: bool>(src: *const u8, dst: *mut u8, table: *const u8) -> Result<(), Error> {
107    let mut y: u32 = 0;
108    let mut flag = 0;
109
110    let mut i = 0;
111    while i < 4 {
112        let x = read(src, i);
113        let bits = read(table, x as usize);
114        flag |= bits;
115
116        if WRITE {
117            y |= (bits as u32) << (18 - i * 6);
118        }
119
120        i += 1;
121    }
122
123    if WRITE {
124        let y = y.to_be_bytes();
125        write(dst, 0, y[1]);
126        write(dst, 1, y[2]);
127        write(dst, 2, y[3]);
128    }
129
130    ensure!(flag != 0xff);
131    Ok(())
132}
133
134#[inline(always)]
135pub unsafe fn decode_extra<const WRITE: bool>(
136    extra: usize,
137    src: *const u8,
138    dst: *mut u8,
139    table: *const u8,
140    forgiving: bool,
141) -> Result<(), Error> {
142    match extra {
143        0 => {}
144        1 => core::hint::unreachable_unchecked(),
145        2 => {
146            let [x1, x2] = src.cast::<[u8; 2]>().read();
147
148            let y1 = read(table, x1 as usize);
149            let y2 = read(table, x2 as usize);
150            ensure!((y1 | y2) != 0xff && (forgiving || (y2 & 0x0f) == 0));
151
152            if WRITE {
153                write(dst, 0, (y1 << 2) | (y2 >> 4));
154            }
155        }
156        3 => {
157            let [x1, x2, x3] = src.cast::<[u8; 3]>().read();
158
159            let y1 = read(table, x1 as usize);
160            let y2 = read(table, x2 as usize);
161            let y3 = read(table, x3 as usize);
162            ensure!((y1 | y2 | y3) != 0xff && (forgiving || (y3 & 0x03) == 0));
163
164            if WRITE {
165                write(dst, 0, (y1 << 2) | (y2 >> 4));
166                write(dst, 1, (y2 << 4) | (y3 >> 2));
167            }
168        }
169        _ => core::hint::unreachable_unchecked(),
170    }
171    Ok(())
172}
173
174#[inline]
175pub(crate) unsafe fn decode_fallback(
176    mut src: *const u8,
177    mut dst: *mut u8,
178    mut n: usize,
179    config: Config,
180) -> Result<(), Error> {
181    let kind = config.kind;
182    let forgiving = config.extra.forgiving();
183
184    let table = match kind {
185        Kind::Standard => STANDARD_DECODE_TABLE.as_ptr(),
186        Kind::UrlSafe => URL_SAFE_DECODE_TABLE.as_ptr(),
187    };
188
189    // n*3/4 >= 6+2
190    while n >= 11 {
191        decode_ascii8::<true>(src, dst, table)?;
192        src = src.add(8);
193        dst = dst.add(6);
194        n -= 8;
195    }
196
197    let end = src.add(n / 4 * 4);
198    while src < end {
199        decode_ascii4::<true>(src, dst, table)?;
200        src = src.add(4);
201        dst = dst.add(3);
202    }
203    n %= 4;
204
205    decode_extra::<true>(n, src, dst, table, forgiving)
206}
207
208#[inline(always)]
209pub(crate) unsafe fn decode_simd<S: SIMD256>(
210    s: S,
211    mut src: *const u8,
212    mut dst: *mut u8,
213    mut n: usize,
214    config: Config,
215) -> Result<(), Error> {
216    let kind = config.kind;
217
218    let (check_lut, decode_lut) = match kind {
219        Kind::Standard => (STANDARD_ALSW_CHECK_X2, STANDARD_ALSW_DECODE_X2),
220        Kind::UrlSafe => (URL_SAFE_ALSW_CHECK_X2, URL_SAFE_ALSW_DECODE_X2),
221    };
222
223    // n*3/4 >= 24+4
224    while n >= 38 {
225        let x = s.v256_load_unaligned(src);
226        let y = try_!(decode_ascii32(s, x, check_lut, decode_lut));
227
228        let (y1, y2) = y.to_v128x2();
229        s.v128_store_unaligned(dst, y1);
230        s.v128_store_unaligned(dst.add(12), y2);
231
232        src = src.add(32);
233        dst = dst.add(24);
234        n -= 32;
235    }
236
237    decode_fallback(src, dst, n, config)
238}
239
240#[inline(always)]
241fn merge_bits_x2<S: SIMD256>(s: S, x: V256) -> V256 {
242    // x : {00aaaaaa|00bbbbbb|00cccccc|00dddddd} x8
243
244    let y = if matches_isa!(S, SSSE3) {
245        let m1 = s.u16x16_splat(u16::from_le_bytes([0x40, 0x01]));
246        let x1 = s.i16x16_maddubs(x, m1);
247        // x1: {aabbbbbb|0000aaaa|ccdddddd|0000cccc} x8
248
249        let m2 = s.u32x8_splat(u32::from_le_bytes([0x00, 0x10, 0x01, 0x00]));
250        s.i16x16_madd(x1, m2)
251        // {ccdddddd|bbbbcccc|aaaaaabb|00000000} x8
252    } else if matches_isa!(S, NEON | WASM128) {
253        let m1 = s.u32x8_splat(u32::from_le_bytes([0x3f, 0x00, 0x3f, 0x00]));
254        let x1 = s.v256_and(x, m1);
255        // x1: {00aaaaaa|00000000|00cccccc|00000000} x8
256
257        let m2 = s.u32x8_splat(u32::from_le_bytes([0x00, 0x3f, 0x00, 0x3f]));
258        let x2 = s.v256_and(x, m2);
259        // x2: {00000000|00bbbbbb|00000000|00dddddd} x8
260
261        let x3 = s.v256_or(s.u32x8_shl::<18>(x1), s.u32x8_shr::<10>(x1));
262        // x3: {cc000000|0000cccc|aaaaaa00|00000000} x8
263
264        let x4 = s.v256_or(s.u32x8_shl::<4>(x2), s.u32x8_shr::<24>(x2));
265        // x4: {00dddddd|bbbb0000|000000bb|dddd0000}
266
267        let mask = s.u32x8_splat(u32::from_le_bytes([0xff, 0xff, 0xff, 0x00]));
268        s.v256_and(s.v256_or(x3, x4), mask)
269        // {ccdddddd|bbbbcccc|aaaaaabb|00000000} x8
270    } else {
271        unreachable!()
272    };
273
274    const SHUFFLE: V256 = V256::double_bytes([
275        0x02, 0x01, 0x00, 0x06, 0x05, 0x04, 0x0a, 0x09, //
276        0x08, 0x0e, 0x0d, 0x0c, 0x80, 0x80, 0x80, 0x80, //
277    ]);
278    s.u8x16x2_swizzle(y, SHUFFLE)
279    // {AAAB|BBCC|CDDD|0000|EEEF|FFGG|GHHH|0000}
280}
281
282#[inline(always)]
283fn decode_ascii32<S: SIMD256>(s: S, x: V256, check: AlswLut<V256>, decode: AlswLut<V256>) -> Result<V256, Error> {
284    let (c1, c2) = vsimd::alsw::decode_ascii_xn(s, x, check, decode);
285    let y = merge_bits_x2(s, c2);
286    ensure!(u8x32_highbit_any(s, c1).not());
287    Ok(y)
288}