// bstr/byteset/scalar.rs

1// This is adapted from `fallback.rs` from rust-memchr. It's modified to return
2// the 'inverse' query of memchr, e.g. finding the first byte not in the
3// provided set. This is simple for the 1-byte case.
4
5use core::{cmp, usize};
6
// The number of bytes in one machine word (`usize`) on the target platform.
const USIZE_BYTES: usize = core::mem::size_of::<usize>();

// The number of bytes to loop at in one iteration of memchr/memrchr.
const LOOP_SIZE: usize = 2 * USIZE_BYTES;
11
/// Broadcast `b` into every byte of a `usize`. For example, if `b` is
/// `\x4E` (`01001110`), the result on a 32-bit system is
/// `01001110_01001110_01001110_01001110`.
#[inline(always)]
fn repeat_byte(b: u8) -> usize {
    usize::from_ne_bytes([b; core::mem::size_of::<usize>()])
}
20
/// Return the index of the first byte in `haystack` that is NOT equal to
/// `n1`, or `None` if every byte equals `n1`.
///
/// This is the "inverse" of `memchr`: instead of finding the needle, it
/// finds the first byte that differs from it. Once the cursor is word
/// aligned, the scan proceeds two machine words at a time.
pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    // `n1` broadcast to every byte of a word, so `word ^ vn1 == 0` iff
    // every byte of `word` equals `n1`.
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = start_ptr;

        // Too small for even one word-sized read: byte-at-a-time fallback.
        if haystack.len() < USIZE_BYTES {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Check the first (possibly unaligned) word. Any non-zero XOR bit
        // means some byte in it differs from `n1`; let the byte-wise
        // search pin down the exact index.
        let chunk = read_unaligned_usize(ptr);
        if (chunk ^ vn1) != 0 {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Advance the cursor to the next word boundary. The bytes skipped
        // were covered by the unaligned read above, which just verified
        // they all equal `n1`.
        ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        // Holds because of the `len < USIZE_BYTES` early return above.
        debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
        // Process two aligned words per iteration while at least
        // `LOOP_SIZE` bytes remain at or after the cursor.
        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

            let a = *(ptr as *const usize);
            let b = *(ptr.add(USIZE_BYTES) as *const usize);
            // True when the corresponding word contains at least one byte
            // different from `n1`.
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.add(LOOP_SIZE);
        }
        // Scan the remainder (or the pair of words that broke the loop)
        // one byte at a time to find the precise index.
        forward_search(start_ptr, end_ptr, ptr, confirm)
    }
}
59
/// Return the index of the last byte in `haystack` that is NOT equal to
/// `n1`, or `None` if every byte equals `n1`.
///
/// This is the "inverse" of `memrchr`: it scans backwards, two machine
/// words at a time once the cursor is word aligned.
pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    // `n1` broadcast to every byte of a word, so `word ^ vn1 == 0` iff
    // every byte of `word` equals `n1`.
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        // Cursor points one past the next byte to examine.
        let mut ptr = end_ptr;

        // Too small for even one word-sized read: byte-at-a-time fallback.
        if haystack.len() < USIZE_BYTES {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Check the last (possibly unaligned) word. Any non-zero XOR bit
        // means some byte in it differs from `n1`; let the byte-wise
        // search pin down the exact index.
        let chunk = read_unaligned_usize(ptr.sub(USIZE_BYTES));
        if (chunk ^ vn1) != 0 {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Round the cursor down to a word boundary. The tail bytes
        // skipped lie inside the word just verified to be all `n1`.
        ptr = ptr.sub(end_ptr as usize & align);
        debug_assert!(start_ptr <= ptr && ptr <= end_ptr);
        // Process two aligned words per iteration while at least
        // `LOOP_SIZE` bytes remain below the cursor.
        while loop_size == LOOP_SIZE && ptr >= start_ptr.add(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

            let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
            let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
            // True when the corresponding word contains at least one byte
            // different from `n1`.
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.sub(loop_size);
        }
        // Scan the remainder (or the pair of words that broke the loop)
        // one byte at a time, backwards, to find the precise index.
        reverse_search(start_ptr, end_ptr, ptr, confirm)
    }
}
98
/// Scan forward from `ptr` (within `start_ptr..end_ptr`) for the first
/// byte accepted by `confirm`, returning its offset from `start_ptr`.
///
/// Safety: the caller must guarantee `start_ptr <= ptr <= end_ptr` and
/// that the whole range is valid for reads.
#[inline(always)]
unsafe fn forward_search<F: Fn(u8) -> bool>(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    confirm: F,
) -> Option<usize> {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    loop {
        if ptr >= end_ptr {
            return None;
        }
        if confirm(*ptr) {
            return Some((ptr as usize) - (start_ptr as usize));
        }
        ptr = ptr.add(1);
    }
}
117
/// Scan backward from `ptr` (within `start_ptr..end_ptr`) for the last
/// byte accepted by `confirm`, returning its offset from `start_ptr`.
///
/// Safety: the caller must guarantee `start_ptr <= ptr <= end_ptr` and
/// that the whole range is valid for reads.
#[inline(always)]
unsafe fn reverse_search<F: Fn(u8) -> bool>(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    confirm: F,
) -> Option<usize> {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    loop {
        if ptr <= start_ptr {
            return None;
        }
        ptr = ptr.sub(1);
        if confirm(*ptr) {
            return Some((ptr as usize) - (start_ptr as usize));
        }
    }
}
136
/// Load a `usize` from `ptr` without requiring alignment.
///
/// Safety: the caller must guarantee `ptr` is valid for reading
/// `size_of::<usize>()` bytes.
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    core::ptr::read_unaligned(ptr as *const usize)
}
140
/// Return the byte distance from `b` up to `a`. The caller must ensure
/// that `a >= b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    let (upper, lower) = (a as usize, b as usize);
    debug_assert!(upper >= lower);
    upper - lower
}
147
148/// Safe wrapper around `forward_search`
149#[inline]
150pub(crate) fn forward_search_bytes<F: Fn(u8) -> bool>(
151    s: &[u8],
152    confirm: F,
153) -> Option<usize> {
154    unsafe {
155        let start = s.as_ptr();
156        let end = start.add(s.len());
157        forward_search(start, end, start, confirm)
158    }
159}
160
161/// Safe wrapper around `reverse_search`
162#[inline]
163pub(crate) fn reverse_search_bytes<F: Fn(u8) -> bool>(
164    s: &[u8],
165    confirm: F,
166) -> Option<usize> {
167    unsafe {
168        let start = s.as_ptr();
169        let end = start.add(s.len());
170        reverse_search(start, end, end, confirm)
171    }
172}
173
#[cfg(all(test, feature = "std"))]
mod tests {
    use alloc::{vec, vec::Vec};

    use super::{inv_memchr, inv_memrchr};

    // search string, search byte, inv_memchr result, inv_memrchr result.
    // these are expanded into a much larger set of tests in build_tests
    const TESTS: &[(&[u8], u8, usize, usize)] = &[
        (b"z", b'a', 0, 0),
        (b"zz", b'a', 0, 1),
        (b"aza", b'a', 1, 1),
        (b"zaz", b'a', 0, 2),
        (b"zza", b'a', 0, 1),
        (b"zaa", b'a', 0, 0),
        (b"zzz", b'a', 0, 2),
    ];

    // (haystack, needle byte, Some((forward index, reverse index)) or None).
    type TestCase = (Vec<u8>, u8, Option<(usize, usize)>);

    /// Expand `TESTS` into many cases by padding each base case with runs
    /// of the needle byte (suffix, prefix, and both ends), plus a family
    /// of all-needle haystacks that must yield `None`.
    fn build_tests() -> Vec<TestCase> {
        // Use a much smaller padding bound under Miri, which runs slowly.
        #[cfg(not(miri))]
        const MAX_PER: usize = 515;
        #[cfg(miri)]
        const MAX_PER: usize = 10;

        let mut result = vec![];
        for &(search, byte, fwd_pos, rev_pos) in TESTS {
            result.push((search.to_vec(), byte, Some((fwd_pos, rev_pos))));
            for i in 1..MAX_PER {
                // add a bunch of copies of the search byte to the end.
                let mut suffixed: Vec<u8> = search.into();
                suffixed.extend(std::iter::repeat(byte).take(i));
                result.push((suffixed, byte, Some((fwd_pos, rev_pos))));

                // add a bunch of copies of the search byte to the start.
                let mut prefixed: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                prefixed.extend(search);
                result.push((
                    prefixed,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));

                // add a bunch of copies of the search byte to both ends.
                let mut surrounded: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                surrounded.extend(search);
                surrounded.extend(std::iter::repeat(byte).take(i));
                result.push((
                    surrounded,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));
            }
        }

        // build non-matching tests for several sizes
        for i in 0..MAX_PER {
            result.push((
                std::iter::repeat(b'\0').take(i).collect(),
                b'\0',
                None,
            ));
        }

        result
    }

    #[test]
    fn test_inv_memchr() {
        use crate::{ByteSlice, B};

        // How many leading bytes to slice off when re-running each case
        // below to probe alignment sensitivity; smaller under Miri.
        #[cfg(not(miri))]
        const MAX_OFFSET: usize = 130;
        #[cfg(miri)]
        const MAX_OFFSET: usize = 13;

        for (search, byte, matching) in build_tests() {
            assert_eq!(
                inv_memchr(byte, &search),
                matching.map(|m| m.0),
                "inv_memchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            assert_eq!(
                inv_memrchr(byte, &search),
                matching.map(|m| m.1),
                "inv_memrchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            // Test a rather large number off offsets for potential alignment
            // issues.
            for offset in 1..MAX_OFFSET {
                if offset >= search.len() {
                    break;
                }
                // If this would cause us to shift the results off the end,
                // skip it so that we don't have to recompute them.
                if let Some((f, r)) = matching {
                    if offset > f || offset > r {
                        break;
                    }
                }
                let realigned = &search[offset..];

                let forward_pos = matching.map(|m| m.0 - offset);
                let reverse_pos = matching.map(|m| m.1 - offset);

                assert_eq!(
                    inv_memchr(byte, &realigned),
                    forward_pos,
                    "inv_memchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
                assert_eq!(
                    inv_memrchr(byte, &realigned),
                    reverse_pos,
                    "inv_memrchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
            }
        }
    }
}