base64_simd/
encode.rs

1use crate::{Config, Kind};
2use crate::{STANDARD_CHARSET, URL_SAFE_CHARSET};
3
4use vsimd::isa::{NEON, SSE2, WASM128};
5use vsimd::tools::{read, write};
6use vsimd::vector::{V128, V256};
7use vsimd::{matches_isa, POD};
8use vsimd::{Scalable, SIMD128, SIMD256};
9
10#[inline(always)]
11pub(crate) const fn encoded_length_unchecked(len: usize, config: Config) -> usize {
12    let extra = len % 3;
13    if extra == 0 {
14        len / 3 * 4
15    } else if config.extra.padding() {
16        len / 3 * 4 + 4
17    } else {
18        len / 3 * 4 + extra + 1
19    }
20}
21
22#[inline(always)]
23unsafe fn encode_bits24(src: *const u8, dst: *mut u8, charset: *const u8) {
24    let x = u32::from_be_bytes([0, read(src, 0), read(src, 1), read(src, 2)]);
25    let mut i = 0;
26    while i < 4 {
27        let bits = (x >> (18 - i * 6)) & 0x3f;
28        let y = read(charset, bits as usize);
29        write(dst, i, y);
30        i += 1;
31    }
32}
33
34#[inline(always)]
35unsafe fn encode_bits48(src: *const u8, dst: *mut u8, charset: *const u8) {
36    let x = u64::from_be_bytes(src.cast::<[u8; 8]>().read());
37    let mut i = 0;
38    while i < 8 {
39        let bits = (x >> (58 - i * 6)) & 0x3f;
40        let y = read(charset, bits as usize);
41        write(dst, i, y);
42        i += 1;
43    }
44}
45
46#[inline(always)]
47unsafe fn encode_extra(extra: usize, src: *const u8, dst: *mut u8, charset: *const u8, padding: bool) {
48    match extra {
49        0 => {}
50        1 => {
51            let x = read(src, 0);
52            let y1 = read(charset, (x >> 2) as usize);
53            let y2 = read(charset, ((x << 6) >> 2) as usize);
54            write(dst, 0, y1);
55            write(dst, 1, y2);
56            if padding {
57                write(dst, 2, b'=');
58                write(dst, 3, b'=');
59            }
60        }
61        2 => {
62            let x1 = read(src, 0);
63            let x2 = read(src, 1);
64            let y1 = read(charset, (x1 >> 2) as usize);
65            let y2 = read(charset, (((x1 << 6) >> 2) | (x2 >> 4)) as usize);
66            let y3 = read(charset, ((x2 << 4) >> 2) as usize);
67            write(dst, 0, y1);
68            write(dst, 1, y2);
69            write(dst, 2, y3);
70            if padding {
71                write(dst, 3, b'=');
72            }
73        }
74        _ => core::hint::unreachable_unchecked(),
75    }
76}
77
78#[inline]
79pub(crate) unsafe fn encode_fallback(mut src: *const u8, mut len: usize, mut dst: *mut u8, config: Config) {
80    let kind = config.kind;
81    let padding = config.extra.padding();
82
83    let charset = match kind {
84        Kind::Standard => STANDARD_CHARSET.as_ptr(),
85        Kind::UrlSafe => URL_SAFE_CHARSET.as_ptr(),
86    };
87
88    const L: usize = 4;
89    while len >= L * 6 + 2 {
90        let mut i = 0;
91        while i < L {
92            encode_bits48(src, dst, charset);
93            src = src.add(6);
94            dst = dst.add(8);
95            i += 1;
96        }
97        len -= L * 6;
98    }
99
100    while len >= 6 + 2 {
101        encode_bits48(src, dst, charset);
102        src = src.add(6);
103        dst = dst.add(8);
104        len -= 6;
105    }
106
107    let end = src.add(len / 3 * 3);
108    while src < end {
109        encode_bits24(src, dst, charset);
110        src = src.add(3);
111        dst = dst.add(4);
112    }
113    len %= 3;
114
115    encode_extra(len, src, dst, charset, padding);
116}
117
118#[inline(always)]
119pub(crate) unsafe fn encode_simd<S: SIMD256>(
120    s: S,
121    mut src: *const u8,
122    mut len: usize,
123    mut dst: *mut u8,
124    config: Config,
125) {
126    let kind = config.kind;
127
128    if len >= (6 + 24 + 4) {
129        let (charset, shift_lut) = match kind {
130            Kind::Standard => (STANDARD_CHARSET.as_ptr(), STANDARD_ENCODING_SHIFT_X2),
131            Kind::UrlSafe => (URL_SAFE_CHARSET.as_ptr(), URL_SAFE_ENCODING_SHIFT_X2),
132        };
133
134        for _ in 0..2 {
135            encode_bits24(src, dst, charset);
136            src = src.add(3);
137            dst = dst.add(4);
138            len -= 3;
139        }
140
141        while len >= (24 + 4) {
142            let x = s.v256_load_unaligned(src.sub(4));
143            let y = encode_bytes24(s, x, shift_lut);
144            s.v256_store_unaligned(dst, y);
145            src = src.add(24);
146            dst = dst.add(32);
147            len -= 24;
148        }
149    }
150
151    if len >= 12 + 4 {
152        let shift_lut = match kind {
153            Kind::Standard => STANDARD_ENCODING_SHIFT,
154            Kind::UrlSafe => URL_SAFE_ENCODING_SHIFT,
155        };
156
157        let x = s.v128_load_unaligned(src);
158        let y = encode_bytes12(s, x, shift_lut);
159        s.v128_store_unaligned(dst, y);
160        src = src.add(12);
161        dst = dst.add(16);
162        len -= 12;
163    }
164
165    encode_fallback(src, len, dst, config);
166}
167
// Byte shuffle used by `split_bits_x2`; `split_bits_x1` uses its upper half.
// Each 3-byte input group `[i, i+1, i+2]` becomes the 4-byte group
// `[i+1, i, i+2, i+1]`. All indices are < 16, consistent with a per-lane
// swizzle:
// - the low lane starts at byte 4, skipping the 4 look-behind bytes that the
//   256-bit load in `encode_simd` includes;
// - the high lane starts at lane byte 0.
const SPLIT_SHUFFLE: V256 = V256::from_bytes([
    0x05, 0x04, 0x06, 0x05, 0x08, 0x07, 0x09, 0x08, //
    0x0b, 0x0a, 0x0c, 0x0b, 0x0e, 0x0d, 0x0f, 0x0e, //
    0x01, 0x00, 0x02, 0x01, 0x04, 0x03, 0x05, 0x04, //
    0x07, 0x06, 0x08, 0x07, 0x0a, 0x09, 0x0b, 0x0a, //
]);
174
/// Splits 24 packed input bytes into 32 bytes, each holding one 6-bit value
/// in its low bits.
///
/// `x` is a 32-byte block whose bytes 4..28 are the payload (see layout
/// below); bytes 0..4 and 28..32 are not selected by `SPLIT_SHUFFLE`.
#[inline(always)]
fn split_bits_x2<S: SIMD256>(s: S, x: V256) -> V256 {
    // x: {????|AAAB|BBCC|CDDD|EEEF|FFGG|GHHH|????}

    let x0 = s.u8x16x2_swizzle(x, SPLIT_SHUFFLE);
    // x0: {bbbbcccc|aaaaaabb|ccdddddd|bbbbcccc} x8 (1021)

    if matches_isa!(S, SSE2) {
        // Masking followed by 16-bit multiplies acts as per-lane variable shifts:
        // - mul_hi by 0x0040 / 0x0400 is a right shift by 10 / 6;
        // - mul_lo by 0x0010 / 0x0100 is a left shift by 4 / 8.
        let m1 = s.u32x8_splat(u32::from_le_bytes([0x00, 0xfc, 0xc0, 0x0f]));
        let x1 = s.v256_and(x0, m1);
        // x1: {00000000|aaaaaa00|cc000000|0000cccc} x8

        let m2 = s.u32x8_splat(u32::from_le_bytes([0xf0, 0x03, 0x3f, 0x00]));
        let x2 = s.v256_and(x0, m2);
        // x2: {bbbb0000|000000bb|00dddddd|00000000} x8

        let m3 = s.u32x8_splat(u32::from_le_bytes([0x40, 0x00, 0x00, 0x04]));
        let x3 = s.u16x16_mul_hi(x1, m3);
        // x3: {00aaaaaa|00000000|00cccccc|00000000} x8

        let m4 = s.u32x8_splat(u32::from_le_bytes([0x10, 0x00, 0x00, 0x01]));
        let x4 = s.i16x16_mul_lo(x2, m4);
        // x4: {00000000|00bbbbbb|00000000|00dddddd} x8

        return s.v256_or(x3, x4);
        // {00aaaaaa|00bbbbbb|00cccccc|00dddddd} x8
    }

    if matches_isa!(S, NEON | WASM128) {
        // Isolate each 6-bit field with a mask, then move it into place with
        // a fixed 16-bit shift.
        let m1 = s.u32x8_splat(u32::from_le_bytes([0x00, 0xfc, 0x00, 0x00]));
        let x1 = s.u16x16_shr::<10>(s.v256_and(x0, m1));
        // x1: {00aaaaaa|00000000|00000000|00000000} x8

        let m2 = s.u32x8_splat(u32::from_le_bytes([0xf0, 0x03, 0x00, 0x00]));
        let x2 = s.u16x16_shl::<4>(s.v256_and(x0, m2));
        // x2: {00000000|00bbbbbb|00000000|00000000} x8

        let m3 = s.u32x8_splat(u32::from_le_bytes([0x00, 0x00, 0xc0, 0x0f]));
        let x3 = s.u16x16_shr::<6>(s.v256_and(x0, m3));
        // x3: {00000000|00000000|00cccccc|00000000} x8

        let m4 = s.u32x8_splat(u32::from_le_bytes([0x00, 0x00, 0x3f, 0x00]));
        let x4 = s.u16x16_shl::<8>(s.v256_and(x0, m4));
        // x4: {00000000|00000000|00000000|00dddddd} x8

        return s.v256_or(s.v256_or(x1, x2), s.v256_or(x3, x4));
        // {00aaaaaa|00bbbbbb|00cccccc|00dddddd} x8
    }

    // Only reached if `S` matches none of the ISAs handled above.
    unreachable!()
}
226
/// Splits 12 packed input bytes into 16 bytes, each holding one 6-bit value
/// in its low bits. This is the 128-bit counterpart of `split_bits_x2`.
///
/// Bytes 12..16 of `x` are not selected by the shuffle.
#[inline(always)]
fn split_bits_x1<S: SIMD128>(s: S, x: V128) -> V128 {
    // x: {AAAB|BBCC|CDDD|????}

    // The upper lane of SPLIT_SHUFFLE starts at byte 0, matching this layout.
    const SHUFFLE: V128 = SPLIT_SHUFFLE.to_v128x2().1;
    let x0 = s.u8x16_swizzle(x, SHUFFLE);
    // x0: {bbbbcccc|aaaaaabb|ccdddddd|bbbbcccc} x4 (1021)

    if matches_isa!(S, SSE2) {
        // Same mask-and-multiply scheme as the SSE2 path of `split_bits_x2`.
        let m1 = s.u32x4_splat(u32::from_le_bytes([0x00, 0xfc, 0xc0, 0x0f]));
        let x1 = s.v128_and(x0, m1);
        // x1: {00000000|aaaaaa00|cc000000|0000cccc} x4

        let m2 = s.u32x4_splat(u32::from_le_bytes([0xf0, 0x03, 0x3f, 0x00]));
        let x2 = s.v128_and(x0, m2);
        // x2: {bbbb0000|000000bb|00dddddd|00000000} x4

        let m3 = s.u32x4_splat(u32::from_le_bytes([0x40, 0x00, 0x00, 0x04]));
        let x3 = s.u16x8_mul_hi(x1, m3);
        // x3: {00aaaaaa|00000000|00cccccc|00000000} x4

        let m4 = s.u32x4_splat(u32::from_le_bytes([0x10, 0x00, 0x00, 0x01]));
        let x4 = s.i16x8_mul_lo(x2, m4);
        // x4: {00000000|00bbbbbb|00000000|00dddddd} x4

        return s.v128_or(x3, x4);
        // {00aaaaaa|00bbbbbb|00cccccc|00dddddd} x4
    }

    if matches_isa!(S, NEON | WASM128) {
        // Same mask-and-shift scheme as the NEON/WASM path of `split_bits_x2`.
        let m1 = s.u32x4_splat(u32::from_le_bytes([0x00, 0xfc, 0x00, 0x00]));
        let x1 = s.u16x8_shr::<10>(s.v128_and(x0, m1));
        // x1: {00aaaaaa|00000000|00000000|00000000} x4

        let m2 = s.u32x4_splat(u32::from_le_bytes([0xf0, 0x03, 0x00, 0x00]));
        let x2 = s.u16x8_shl::<4>(s.v128_and(x0, m2));
        // x2: {00000000|00bbbbbb|00000000|00000000} x4

        let m3 = s.u32x4_splat(u32::from_le_bytes([0x00, 0x00, 0xc0, 0x0f]));
        let x3 = s.u16x8_shr::<6>(s.v128_and(x0, m3));
        // x3: {00000000|00000000|00cccccc|00000000} x4

        let m4 = s.u32x4_splat(u32::from_le_bytes([0x00, 0x00, 0x3f, 0x00]));
        let x4 = s.u16x8_shl::<8>(s.v128_and(x0, m4));
        // x4: {00000000|00000000|00000000|00dddddd} x4

        return s.v128_or(s.v128_or(x1, x2), s.v128_or(x3, x4));
        // {00aaaaaa|00bbbbbb|00cccccc|00dddddd} x4
    }

    // Only reached if `S` matches none of the ISAs handled above.
    unreachable!()
}
269
270#[inline]
271const fn encoding_shift(charset: &'static [u8; 64]) -> V128 {
272    // 0~25     'A'   [13]
273    // 26~51    'a'   [0]
274    // 52~61    '0'   [1~10]
275    // 62       c62   [11]
276    // 63       c63   [12]
277
278    let mut lut = [0x80; 16];
279    lut[13] = b'A';
280    lut[0] = b'a' - 26;
281    let mut i = 1;
282    while i <= 10 {
283        lut[i] = b'0'.wrapping_sub(52);
284        i += 1;
285    }
286    lut[11] = charset[62].wrapping_sub(62);
287    lut[12] = charset[63].wrapping_sub(63);
288    V128::from_bytes(lut)
289}
290
// Offset tables for the 128-bit kernel. `encoding_shift` reads only
// `charset[62]` and `charset[63]`, so these two tables can differ only at
// indices 11 and 12.
const STANDARD_ENCODING_SHIFT: V128 = encoding_shift(STANDARD_CHARSET);
const URL_SAFE_ENCODING_SHIFT: V128 = encoding_shift(URL_SAFE_CHARSET);

// Tables for the 256-bit kernel. `x2()` presumably repeats the 128-bit table
// in both lanes for the per-lane swizzle (TODO confirm against vsimd).
const STANDARD_ENCODING_SHIFT_X2: V256 = STANDARD_ENCODING_SHIFT.x2();
const URL_SAFE_ENCODING_SHIFT_X2: V256 = URL_SAFE_ENCODING_SHIFT.x2();
296
297#[inline(always)]
298fn encode_values<S: Scalable<V>, V: POD>(s: S, x: V, shift_lut: V) -> V {
299    // x: {00aaaaaa|00bbbbbb|00cccccc|00dddddd} xn
300
301    let x1 = s.u8xn_sub_sat(x, s.u8xn_splat(51));
302    // 0~25    => 0
303    // 26~51   => 0
304    // 52~61   => 1~10
305    // 62      => 11
306    // 63      => 12
307
308    let x2 = s.i8xn_lt(x, s.u8xn_splat(26));
309    let x3 = s.and(x2, s.u8xn_splat(13));
310    let x4 = s.or(x1, x3);
311    // 0~25    => 0xff  => 13
312    // 26~51   => 0     => 0
313    // 52~61   => 0     => 1~10
314    // 62      => 0     => 11
315    // 63      => 0     => 12
316
317    let shift = s.u8x16xn_swizzle(shift_lut, x4);
318    s.u8xn_add(x, shift)
319    // {{ascii}} xn
320}
321
322#[inline(always)]
323fn encode_bytes24<S: SIMD256>(s: S, x: V256, shift_lut: V256) -> V256 {
324    // x: {????|AAAB|BBCC|CDDD|EEEF|FFGG|GHHH|????}
325
326    let values = split_bits_x2(s, x);
327    // values: {00aaaaaa|00bbbbbb|00cccccc|00dddddd} x8
328
329    encode_values(s, values, shift_lut)
330    // {{ascii}} x32
331}
332
333#[inline(always)]
334fn encode_bytes12<S: SIMD256>(s: S, x: V128, shift_lut: V128) -> V128 {
335    // x: {AAAB|BBCC|CDDD|????}
336
337    let values = split_bits_x1(s, x);
338    // values: {00aaaaaa|00bbbbbb|00cccccc|00dddddd} x4
339
340    encode_values(s, values, shift_lut)
341    // {{ascii}} x16
342}