str_indices/byte_chunk.rs

#[cfg(target_arch = "x86_64")]
use core::arch::x86_64;

// Which chunk type to actually use at build time.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
pub(crate) type Chunk = x86_64::__m128i;
#[cfg(not(all(feature = "simd", target_arch = "x86_64")))]
pub(crate) type Chunk = usize;

/// Interface for working with chunks of bytes at a time, providing the
/// operations needed for the functionality in str_utils.
pub(crate) trait ByteChunk: Copy + Clone {
    /// Size of the chunk in bytes.
    const SIZE: usize;

    /// Maximum number of iterations the chunk can accumulate
    /// flag bytes via add() before sum_bytes() becomes inaccurate.
    const MAX_ACC: usize;

    /// Creates a new chunk with all bytes set to zero.
    fn zero() -> Self;

    /// Creates a new chunk with all bytes set to n.
    fn splat(n: u8) -> Self;

    /// Returns whether all bytes are zero or not.
    fn is_zero(&self) -> bool;

    /// Shifts bytes back lexicographically by n bytes.
    fn shift_back_lex(&self, n: usize) -> Self;

    /// Shifts bits to the right by n bits.
    fn shr(&self, n: usize) -> Self;

    /// Compares bytes for equality with the given byte.
    ///
    /// Bytes that are equal are set to 1, bytes that are not
    /// are set to 0.
    fn cmp_eq_byte(&self, byte: u8) -> Self;

    /// Compares bytes to see if they're in the non-inclusive range (a, b),
    /// where a < b <= 127.
    ///
    /// Bytes in the range are set to 1, bytes not in the range are set to 0.
    fn bytes_between_127(&self, a: u8, b: u8) -> Self;

    /// Performs a bitwise and on two chunks.
    fn bitand(&self, other: Self) -> Self;

    /// Adds the bytes of two chunks together.
    fn add(&self, other: Self) -> Self;

    /// Subtracts other's bytes from this chunk.
    fn sub(&self, other: Self) -> Self;

    /// Increments the nth-from-last lexicographic byte by 1.
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self;

    /// Decrements the last lexicographic byte by 1.
    fn dec_last_lex_byte(&self) -> Self;

    /// Returns the sum of all bytes in the chunk.
    fn sum_bytes(&self) -> usize;
}
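
// A hedged sketch of the intended usage pattern. The function below is
// illustrative and not part of this module's API: turn each chunk into
// 0x01/0x00 flag bytes with `cmp_eq_byte()`, accumulate flags with `add()`
// for at most `MAX_ACC` iterations, then flush the running total with
// `sum_bytes()` before the accumulated counts can overflow.
#[allow(dead_code)]
fn count_byte_sketch(data: &[u8], byte: u8) -> usize {
    let mut count = 0;
    let mut acc = usize::zero();
    let mut iters = 0;
    let mut chunks = data.chunks_exact(<usize as ByteChunk>::SIZE);
    for chunk in &mut chunks {
        // Assemble a native-endian word from the next SIZE bytes.
        let mut buf = [0u8; <usize as ByteChunk>::SIZE];
        buf.copy_from_slice(chunk);
        acc = acc.add(usize::from_ne_bytes(buf).cmp_eq_byte(byte));
        iters += 1;
        if iters == <usize as ByteChunk>::MAX_ACC {
            count += acc.sum_bytes();
            acc = usize::zero();
            iters = 0;
        }
    }
    // Flush the accumulator and handle the leftover tail bytes.
    count += acc.sum_bytes();
    count + chunks.remainder().iter().filter(|&&b| b == byte).count()
}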

impl ByteChunk for usize {
    const SIZE: usize = core::mem::size_of::<usize>();
    const MAX_ACC: usize = (256 / core::mem::size_of::<usize>()) - 1;

    #[inline(always)]
    fn zero() -> Self {
        0
    }

    #[inline(always)]
    fn splat(n: u8) -> Self {
        const ONES: usize = usize::MAX / 0xFF;
        ONES * n as usize
    }

    #[inline(always)]
    fn is_zero(&self) -> bool {
        *self == 0
    }

    #[inline(always)]
    fn shift_back_lex(&self, n: usize) -> Self {
        if cfg!(target_endian = "little") {
            *self >> (n * 8)
        } else {
            *self << (n * 8)
        }
    }

    #[inline(always)]
    fn shr(&self, n: usize) -> Self {
        *self >> n
    }

    #[inline(always)]
    fn cmp_eq_byte(&self, byte: u8) -> Self {
        const ONES: usize = usize::MAX / 0xFF;
        const ONES_HIGH: usize = ONES << 7;
        // The XOR zeroes exactly the byte lanes equal to `byte`; the
        // expression below is the standard SWAR zero-byte test, leaving
        // 0x01 in the zeroed lanes and 0x00 everywhere else.
        let word = *self ^ (byte as usize * ONES);
        (!(((word & !ONES_HIGH) + !ONES_HIGH) | word) & ONES_HIGH) >> 7
    }

    #[inline(always)]
    fn bytes_between_127(&self, a: u8, b: u8) -> Self {
        const ONES: usize = usize::MAX / 0xFF;
        const ONES_HIGH: usize = ONES << 7;
        // Masking to the low 7 bits keeps each lane's arithmetic below
        // from carrying into its neighbor; `!*self` filters out lanes
        // with the high bit set (bytes >= 128).
        let tmp = *self & (ONES * 127);
        (((ONES * (127 + b as usize) - tmp) & !*self & (tmp + (ONES * (127 - a as usize))))
            & ONES_HIGH)
            >> 7
    }

    #[inline(always)]
    fn bitand(&self, other: Self) -> Self {
        *self & other
    }

    #[inline(always)]
    fn add(&self, other: Self) -> Self {
        *self + other
    }

    #[inline(always)]
    fn sub(&self, other: Self) -> Self {
        *self - other
    }

    #[inline(always)]
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self {
        if cfg!(target_endian = "little") {
            *self + (1 << ((Self::SIZE - 1 - n) * 8))
        } else {
            *self + (1 << (n * 8))
        }
    }

    #[inline(always)]
    fn dec_last_lex_byte(&self) -> Self {
        if cfg!(target_endian = "little") {
            *self - (1 << ((Self::SIZE - 1) * 8))
        } else {
            *self - 1
        }
    }

    #[inline(always)]
    fn sum_bytes(&self) -> usize {
        const ONES: usize = usize::MAX / 0xFF;
        // The wrapping multiply accumulates the sum of all byte lanes
        // (mod 256) into the top byte, which the shift then extracts.
        self.wrapping_mul(ONES) >> ((Self::SIZE - 1) * 8)
    }
}
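
// Worked example of the lane arithmetic in `cmp_eq_byte()` above, using
// the same value as the unit tests at the bottom of this file (a 64-bit
// target is assumed):
//
//     v                 = 0xE2_09_08_A6_E2_A6_E2_09
//     v ^ (0x09 * ONES) = 0xEB_00_01_AF_EB_AF_EB_00
//     zero-lane test    = 0x00_01_00_00_00_00_00_01
//
// Exactly the lanes of `v` that equal 0x09 come out as 0x01, ready to be
// accumulated with `add()` and later totaled with `sum_bytes()`.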

// Note: use only SSE2 and older instructions, since these are
// guaranteed on all x86_64 platforms.
#[cfg(target_arch = "x86_64")]
impl ByteChunk for x86_64::__m128i {
    const SIZE: usize = core::mem::size_of::<x86_64::__m128i>();
    const MAX_ACC: usize = 255;

    #[inline(always)]
    fn zero() -> Self {
        unsafe { x86_64::_mm_setzero_si128() }
    }

    #[inline(always)]
    fn splat(n: u8) -> Self {
        unsafe { x86_64::_mm_set1_epi8(n as i8) }
    }

    #[inline(always)]
    fn is_zero(&self) -> bool {
        let tmp = unsafe { core::mem::transmute::<Self, (u64, u64)>(*self) };
        tmp.0 == 0 && tmp.1 == 0
    }

    #[inline(always)]
    fn shift_back_lex(&self, n: usize) -> Self {
        match n {
            0 => *self,
            1 => unsafe { x86_64::_mm_srli_si128(*self, 1) },
            2 => unsafe { x86_64::_mm_srli_si128(*self, 2) },
            3 => unsafe { x86_64::_mm_srli_si128(*self, 3) },
            4 => unsafe { x86_64::_mm_srli_si128(*self, 4) },
            _ => unreachable!(),
        }
    }

    #[inline(always)]
    fn shr(&self, n: usize) -> Self {
        match n {
            0 => *self,
            1 => unsafe { x86_64::_mm_srli_epi64(*self, 1) },
            2 => unsafe { x86_64::_mm_srli_epi64(*self, 2) },
            3 => unsafe { x86_64::_mm_srli_epi64(*self, 3) },
            4 => unsafe { x86_64::_mm_srli_epi64(*self, 4) },
            _ => unreachable!(),
        }
    }

    #[inline(always)]
    fn cmp_eq_byte(&self, byte: u8) -> Self {
        // The compare yields 0xFF per matching lane; mask down to 0x01.
        let tmp = unsafe { x86_64::_mm_cmpeq_epi8(*self, Self::splat(byte)) };
        unsafe { x86_64::_mm_and_si128(tmp, Self::splat(1)) }
    }

    #[inline(always)]
    fn bytes_between_127(&self, a: u8, b: u8) -> Self {
        // Signed compares are correct here because a < b <= 127: bytes
        // >= 0x80 are negative as i8 and thus always excluded.
        let tmp1 = unsafe { x86_64::_mm_cmpgt_epi8(*self, Self::splat(a)) };
        let tmp2 = unsafe { x86_64::_mm_cmplt_epi8(*self, Self::splat(b)) };
        let tmp3 = unsafe { x86_64::_mm_and_si128(tmp1, tmp2) };
        unsafe { x86_64::_mm_and_si128(tmp3, Self::splat(1)) }
    }

    #[inline(always)]
    fn bitand(&self, other: Self) -> Self {
        unsafe { x86_64::_mm_and_si128(*self, other) }
    }

    #[inline(always)]
    fn add(&self, other: Self) -> Self {
        unsafe { x86_64::_mm_add_epi8(*self, other) }
    }

    #[inline(always)]
    fn sub(&self, other: Self) -> Self {
        unsafe { x86_64::_mm_sub_epi8(*self, other) }
    }

    #[inline(always)]
    fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self {
        let mut tmp = unsafe { core::mem::transmute::<Self, [u8; 16]>(*self) };
        tmp[15 - n] += 1;
        unsafe { core::mem::transmute::<[u8; 16], Self>(tmp) }
    }

    #[inline(always)]
    fn dec_last_lex_byte(&self) -> Self {
        let mut tmp = unsafe { core::mem::transmute::<Self, [u8; 16]>(*self) };
        tmp[15] -= 1;
        unsafe { core::mem::transmute::<[u8; 16], Self>(tmp) }
    }

    #[inline(always)]
    fn sum_bytes(&self) -> usize {
        // _mm_sad_epu8 against zero sums the absolute values of each
        // 8-byte half into the two u64 lanes; add the halves together.
        let half_sum = unsafe { x86_64::_mm_sad_epu8(*self, x86_64::_mm_setzero_si128()) };
        let (low, high) = unsafe { core::mem::transmute::<Self, (u64, u64)>(half_sum) };
        (low + high) as usize
    }
}
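
// A hedged sketch of how a `__m128i` chunk would typically be produced
// from a byte slice before the methods above are applied. The function
// name is illustrative and not part of this module; `_mm_loadu_si128`
// tolerates unaligned addresses, so the only caller obligation is that
// at least 16 bytes are readable.
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
#[inline(always)]
fn load_chunk_sketch(bytes: &[u8]) -> x86_64::__m128i {
    assert!(bytes.len() >= <x86_64::__m128i as ByteChunk>::SIZE);
    unsafe { x86_64::_mm_loadu_si128(bytes.as_ptr() as *const x86_64::__m128i) }
}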

//=============================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn usize_flag_bytes_01() {
        let v: usize = 0xE2_09_08_A6_E2_A6_E2_09;
        assert_eq!(0x00_00_00_00_00_00_00_00, v.cmp_eq_byte(0x07));
        assert_eq!(0x00_00_01_00_00_00_00_00, v.cmp_eq_byte(0x08));
        assert_eq!(0x00_01_00_00_00_00_00_01, v.cmp_eq_byte(0x09));
        assert_eq!(0x00_00_00_01_00_01_00_00, v.cmp_eq_byte(0xA6));
        assert_eq!(0x01_00_00_00_01_00_01_00, v.cmp_eq_byte(0xE2));
    }

    #[test]
    fn usize_bytes_between_127_01() {
        let v: usize = 0x7E_09_00_A6_FF_7F_08_07;
        assert_eq!(0x01_01_00_00_00_00_01_01, v.bytes_between_127(0x00, 0x7F));
        assert_eq!(0x00_01_00_00_00_00_01_00, v.bytes_between_127(0x07, 0x7E));
        assert_eq!(0x00_01_00_00_00_00_00_00, v.bytes_between_127(0x08, 0x7E));
    }

    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
    #[test]
    fn sum_bytes_x86_64() {
        use core::arch::x86_64::__m128i as T;

        let ones = T::splat(1);
        let mut acc = T::zero();
        for _ in 0..T::MAX_ACC {
            acc = acc.add(ones);
        }

        assert_eq!(acc.sum_bytes(), T::SIZE * T::MAX_ACC);
    }
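
    // Companion to `sum_bytes_x86_64` above, exercising the `usize`
    // fallback path: after `MAX_ACC` accumulations of all-ones flags,
    // every byte lane holds `MAX_ACC`, so the total must be exactly
    // `SIZE * MAX_ACC`. One more iteration would wrap the mod-256
    // total that the usize `sum_bytes()` relies on.
    #[test]
    fn sum_bytes_usize() {
        let ones = usize::splat(1);
        let mut acc = usize::zero();
        for _ in 0..usize::MAX_ACC {
            acc = acc.add(ones);
        }

        assert_eq!(acc.sum_bytes(), usize::SIZE * usize::MAX_ACC);
    }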
}