// bstr/byteset/scalar.rs

// This is adapted from `fallback.rs` from rust-memchr. It's modified to return
// the 'inverse' query of memchr, e.g. finding the first byte not in the
// provided set. This is simple for the 1-byte case.

use core::{cmp, usize};

// The number of bytes in a single machine word.
const USIZE_BYTES: usize = core::mem::size_of::<usize>();

// The number of bytes to loop at in one iteration of memchr/memrchr.
const LOOP_SIZE: usize = 2 * USIZE_BYTES;

/// Splat the byte `b` into every byte of a word. For example, if `b` is
/// `\x4E` (`01001110` in binary), the result on a 32-bit system is
/// `01001110_01001110_01001110_01001110`.
#[inline(always)]
fn repeat_byte(b: u8) -> usize {
    // `usize::MAX / 255` is `0x0101...01`, so multiplying by the byte value
    // copies it into every 8-bit lane of the word.
    let ones_in_every_byte = usize::MAX / 255;
    usize::from(b) * ones_in_every_byte
}

/// Returns the index of the first byte in `haystack` that is NOT equal to
/// `n1`, or `None` if every byte equals `n1`.
pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    // `n1` repeated into every byte of a word, so that a whole-word XOR is
    // zero iff every byte in that word equals `n1`.
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = start_ptr;

        // Too short for word-at-a-time: fall back to a byte-by-byte scan.
        if haystack.len() < USIZE_BYTES {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Check the first (possibly unaligned) word. A non-zero XOR means
        // some byte in it differs from `n1`; let the byte scan pinpoint it.
        let chunk = read_unaligned_usize(ptr);
        if (chunk ^ vn1) != 0 {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Advance to the next word-aligned address. The bytes skipped over
        // were already covered by the unaligned read above.
        ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
        // Examine two aligned words per iteration while at least LOOP_SIZE
        // bytes remain before `end_ptr`.
        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

            let a = *(ptr as *const usize);
            let b = *(ptr.add(USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                // One of these two words contains a byte != n1; stop and let
                // the trailing byte scan locate it exactly.
                break;
            }
            ptr = ptr.add(LOOP_SIZE);
        }
        // Scan whatever remains (the tail, or the pair of words that
        // triggered the break above).
        forward_search(start_ptr, end_ptr, ptr, confirm)
    }
}

/// Returns the index of the last byte in `haystack` that is NOT equal to
/// `n1`, or `None` if every byte equals `n1`.
pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    // `n1` repeated into every byte of a word, so that a whole-word XOR is
    // zero iff every byte in that word equals `n1`.
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = end_ptr;

        // Too short for word-at-a-time: fall back to a byte-by-byte scan.
        if haystack.len() < USIZE_BYTES {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Check the last (possibly unaligned) word. A non-zero XOR means
        // some byte in it differs from `n1`; let the byte scan pinpoint it.
        let chunk = read_unaligned_usize(ptr.sub(USIZE_BYTES));
        if (chunk ^ vn1) != 0 {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

        // Round `ptr` down to a word-aligned address. The bytes skipped over
        // were already covered by the unaligned read above.
        ptr = ptr.sub(end_ptr as usize & align);
        debug_assert!(start_ptr <= ptr && ptr <= end_ptr);
        // Examine two aligned words per iteration while at least LOOP_SIZE
        // bytes remain after `start_ptr`.
        while loop_size == LOOP_SIZE && ptr >= start_ptr.add(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

            let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
            let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                // One of these two words contains a byte != n1; stop and let
                // the trailing byte scan locate it exactly.
                break;
            }
            ptr = ptr.sub(loop_size);
        }
        // Scan whatever remains (the head, or the pair of words that
        // triggered the break above).
        reverse_search(start_ptr, end_ptr, ptr, confirm)
    }
}

99#[inline(always)]
100unsafe fn forward_search<F: Fn(u8) -> bool>(
101    start_ptr: *const u8,
102    end_ptr: *const u8,
103    mut ptr: *const u8,
104    confirm: F,
105) -> Option<usize> {
106    debug_assert!(start_ptr <= ptr);
107    debug_assert!(ptr <= end_ptr);
108
109    while ptr < end_ptr {
110        if confirm(*ptr) {
111            return Some(sub(ptr, start_ptr));
112        }
113        ptr = ptr.offset(1);
114    }
115    None
116}
117
118#[inline(always)]
119unsafe fn reverse_search<F: Fn(u8) -> bool>(
120    start_ptr: *const u8,
121    end_ptr: *const u8,
122    mut ptr: *const u8,
123    confirm: F,
124) -> Option<usize> {
125    debug_assert!(start_ptr <= ptr);
126    debug_assert!(ptr <= end_ptr);
127
128    while ptr > start_ptr {
129        ptr = ptr.offset(-1);
130        if confirm(*ptr) {
131            return Some(sub(ptr, start_ptr));
132        }
133    }
134    None
135}
136
/// Load a native-endian `usize` from `ptr` without any alignment requirement.
///
/// # Safety
///
/// `ptr` must be valid for a read of `size_of::<usize>()` bytes.
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    ptr.cast::<usize>().read_unaligned()
}

/// Distance in bytes from `b` up to `a`. Callers must ensure `a >= b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    let (hi, lo) = (a as usize, b as usize);
    hi - lo
}

148/// Safe wrapper around `forward_search`
149#[inline]
150pub(crate) fn forward_search_bytes<F: Fn(u8) -> bool>(
151    s: &[u8],
152    confirm: F,
153) -> Option<usize> {
154    unsafe {
155        let start = s.as_ptr();
156        let end = start.add(s.len());
157        forward_search(start, end, start, confirm)
158    }
159}
160
161/// Safe wrapper around `reverse_search`
162#[inline]
163pub(crate) fn reverse_search_bytes<F: Fn(u8) -> bool>(
164    s: &[u8],
165    confirm: F,
166) -> Option<usize> {
167    unsafe {
168        let start = s.as_ptr();
169        let end = start.add(s.len());
170        reverse_search(start, end, end, confirm)
171    }
172}
173
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{inv_memchr, inv_memrchr};

    // search string, search byte, inv_memchr result, inv_memrchr result.
    // these are expanded into a much larger set of tests in build_tests
    const TESTS: &[(&[u8], u8, usize, usize)] = &[
        (b"z", b'a', 0, 0),
        (b"zz", b'a', 0, 1),
        (b"aza", b'a', 1, 1),
        (b"zaz", b'a', 0, 2),
        (b"zza", b'a', 0, 1),
        (b"zaa", b'a', 0, 0),
        (b"zzz", b'a', 0, 2),
    ];

    // (haystack, needle, Some((first non-match index, last non-match index)))
    // where `None` means every byte of the haystack equals the needle.
    type TestCase = (Vec<u8>, u8, Option<(usize, usize)>);

    fn build_tests() -> Vec<TestCase> {
        // Use a much smaller expansion factor under Miri, since it runs
        // these tests far more slowly than native execution.
        #[cfg(not(miri))]
        const MAX_PER: usize = 515;
        #[cfg(miri)]
        const MAX_PER: usize = 10;

        let mut result = vec![];
        for &(search, byte, fwd_pos, rev_pos) in TESTS {
            result.push((search.to_vec(), byte, Some((fwd_pos, rev_pos))));
            for i in 1..MAX_PER {
                // add a bunch of copies of the search byte to the end.
                let mut suffixed: Vec<u8> = search.into();
                suffixed.extend(std::iter::repeat(byte).take(i));
                result.push((suffixed, byte, Some((fwd_pos, rev_pos))));

                // add a bunch of copies of the search byte to the start.
                let mut prefixed: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                prefixed.extend(search);
                result.push((
                    prefixed,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));

                // add a bunch of copies of the search byte to both ends.
                let mut surrounded: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                surrounded.extend(search);
                surrounded.extend(std::iter::repeat(byte).take(i));
                result.push((
                    surrounded,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));
            }
        }

        // build non-matching tests for several sizes
        for i in 0..MAX_PER {
            result.push((
                std::iter::repeat(b'\0').take(i).collect(),
                b'\0',
                None,
            ));
        }

        result
    }

    #[test]
    fn test_inv_memchr() {
        use crate::{ByteSlice, B};

        #[cfg(not(miri))]
        const MAX_OFFSET: usize = 130;
        #[cfg(miri)]
        const MAX_OFFSET: usize = 13;

        for (search, byte, matching) in build_tests() {
            assert_eq!(
                inv_memchr(byte, &search),
                matching.map(|m| m.0),
                "inv_memchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            assert_eq!(
                inv_memrchr(byte, &search),
                matching.map(|m| m.1),
                "inv_memrchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            // Test a rather large number of offsets for potential alignment
            // issues.
            for offset in 1..MAX_OFFSET {
                if offset >= search.len() {
                    break;
                }
                // If this would cause us to shift the results off the end,
                // skip it so that we don't have to recompute them.
                if let Some((f, r)) = matching {
                    if offset > f || offset > r {
                        break;
                    }
                }
                let realigned = &search[offset..];

                let forward_pos = matching.map(|m| m.0 - offset);
                let reverse_pos = matching.map(|m| m.1 - offset);

                assert_eq!(
                    inv_memchr(byte, &realigned),
                    forward_pos,
                    "inv_memchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
                assert_eq!(
                    inv_memrchr(byte, &realigned),
                    reverse_pos,
                    "inv_memrchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
            }
        }
    }
}