// bstr/ascii.rs

// The following ~400 lines of code exist for exactly one purpose, which is
// to optimize this code:
//
//     byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len())
//
// Yes... Overengineered is a word that comes to mind, but this is effectively
// a very similar problem to memchr, and virtually nobody has been able to
// resist optimizing the crap out of that (except for perhaps the BSD and MUSL
// folks). In particular, this routine makes a very common case (ASCII) very
// fast, which seems worth it. We do stop short of adding AVX variants of the
// code below in order to retain our sanity and also to avoid needing to deal
// with runtime target feature detection. RESIST!
//
// In order to understand the SIMD version below, it would be good to read this
// comment describing how my memchr routine works:
// https://github.com/BurntSushi/rust-memchr/blob/b0a29f267f4a7fad8ffcc8fe8377a06498202883/src/x86/sse2.rs#L19-L106
//
// The primary difference with memchr is that for ASCII, we can do a bit less
// work. In particular, we don't need to detect the presence of a specific
// byte, but rather, whether any byte has its most significant bit set. That
// means we can effectively skip the _mm_cmpeq_epi8 step and jump straight to
// _mm_movemask_epi8.

/// Width of a `usize`, in bytes, on the target platform.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const USIZE_BYTES: usize = core::mem::size_of::<usize>();

/// Bytes examined per iteration of the fallback's main loop: two `usize`
/// words are loaded and tested together.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const FALLBACK_LOOP_SIZE: usize = USIZE_BYTES * 2;

// Every byte of this mask has only its most significant bit set. A single
// byte is an ASCII codepoint if and only if its most significant bit is clear,
// so ANDing a word of input with this mask gives zero exactly when every byte
// in that word is ASCII.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK_U64: u64 = 0x8080_8080_8080_8080;

/// `ASCII_MASK_U64` narrowed to the width of `usize`. On 32-bit targets the
/// truncating cast keeps the low four `0x80` bytes, which is still the right
/// mask for a 4-byte word.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
const ASCII_MASK: usize = ASCII_MASK_U64 as usize;

38/// Returns the index of the first non ASCII byte in the given slice.
39///
40/// If slice only contains ASCII bytes, then the length of the slice is
41/// returned.
42pub fn first_non_ascii_byte(slice: &[u8]) -> usize {
43    #[cfg(any(miri, not(target_arch = "x86_64")))]
44    {
45        first_non_ascii_byte_fallback(slice)
46    }
47
48    #[cfg(all(not(miri), target_arch = "x86_64"))]
49    {
50        first_non_ascii_byte_sse2(slice)
51    }
52}
53
/// Portable word-at-a-time implementation of [`first_non_ascii_byte`].
///
/// Returns the index of the first byte in `slice` that is greater than
/// `0x7F`, or `slice.len()` if every byte is ASCII. This is compiled for
/// tests, under Miri, and on every non-x86_64 target.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_fallback(slice: &[u8]) -> usize {
    /// Locates the non-ASCII byte inside the two aligned words starting at
    /// `pair` and returns its index relative to `start_ptr`.
    ///
    /// What a kludge: this is deliberately kept out of line. Moving the
    /// position-finding code out of the tight loop below improves that loop's
    /// codegen a little (a couple fewer `mov`s). The cost is two additional
    /// stores, paid only once a non-ASCII byte has actually been found.
    #[inline(never)]
    unsafe fn find_in_pair(start_ptr: *const u8, pair: *const u8) -> usize {
        let offset = sub(pair, start_ptr);

        let first = *(pair as *const usize) & ASCII_MASK;
        if first != 0 {
            return offset + first_non_ascii_byte_mask(first);
        }

        let second = *(ptr_add(pair, USIZE_BYTES) as *const usize) & ASCII_MASK;
        debug_assert!(second != 0);
        offset + USIZE_BYTES + first_non_ascii_byte_mask(second)
    }

    let len = slice.len();
    let start_ptr = slice.as_ptr();
    let end_ptr = slice[len..].as_ptr();

    // Not even one full word: scan byte by byte.
    if len < USIZE_BYTES {
        return unsafe { first_non_ascii_byte_slow(start_ptr, end_ptr, start_ptr) };
    }

    unsafe {
        // `start_ptr` may have any alignment, so the first word is checked
        // with an unaligned read.
        let head = read_unaligned_usize(start_ptr) & ASCII_MASK;
        if head != 0 {
            return first_non_ascii_byte_mask(head);
        }

        // Round up to the next usize-aligned address. This advances by
        // between 1 and USIZE_BYTES bytes, all of which belong to the word
        // just checked. `len >= USIZE_BYTES` keeps the result within the
        // slice.
        let misalignment = start_ptr as usize & (USIZE_BYTES - 1);
        let mut cur = ptr_add(start_ptr, USIZE_BYTES - misalignment);
        debug_assert!(cur > start_ptr);
        debug_assert!(ptr_sub(end_ptr, USIZE_BYTES) >= start_ptr);

        // Main loop: two aligned words per iteration, tested with one OR.
        if len >= FALLBACK_LOOP_SIZE {
            let last_start = ptr_sub(end_ptr, FALLBACK_LOOP_SIZE);
            while cur <= last_start {
                debug_assert_eq!(0, (cur as usize) % USIZE_BYTES);

                let lo = *(cur as *const usize);
                let hi = *(ptr_add(cur, USIZE_BYTES) as *const usize);
                if (lo | hi) & ASCII_MASK != 0 {
                    return find_in_pair(start_ptr, cur);
                }
                cur = ptr_add(cur, FALLBACK_LOOP_SIZE);
            }
        }

        // Fewer than FALLBACK_LOOP_SIZE bytes remain; scan them one at a time.
        first_non_ascii_byte_slow(start_ptr, end_ptr, cur)
    }
}

116#[cfg(all(not(miri), target_arch = "x86_64"))]
117fn first_non_ascii_byte_sse2(slice: &[u8]) -> usize {
118    use core::arch::x86_64::*;
119
120    const VECTOR_SIZE: usize = core::mem::size_of::<__m128i>();
121    const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
122    const VECTOR_LOOP_SIZE: usize = 4 * VECTOR_SIZE;
123
124    let start_ptr = slice.as_ptr();
125    let end_ptr = slice[slice.len()..].as_ptr();
126    let mut ptr = start_ptr;
127
128    unsafe {
129        if slice.len() < VECTOR_SIZE {
130            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
131        }
132
133        let chunk = _mm_loadu_si128(ptr as *const __m128i);
134        let mask = _mm_movemask_epi8(chunk);
135        if mask != 0 {
136            return mask.trailing_zeros() as usize;
137        }
138
139        ptr = ptr.add(VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN));
140        debug_assert!(ptr > start_ptr);
141        debug_assert!(end_ptr.sub(VECTOR_SIZE) >= start_ptr);
142        if slice.len() >= VECTOR_LOOP_SIZE {
143            while ptr <= ptr_sub(end_ptr, VECTOR_LOOP_SIZE) {
144                debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);
145
146                let a = _mm_load_si128(ptr as *const __m128i);
147                let b = _mm_load_si128(ptr.add(VECTOR_SIZE) as *const __m128i);
148                let c =
149                    _mm_load_si128(ptr.add(2 * VECTOR_SIZE) as *const __m128i);
150                let d =
151                    _mm_load_si128(ptr.add(3 * VECTOR_SIZE) as *const __m128i);
152
153                let or1 = _mm_or_si128(a, b);
154                let or2 = _mm_or_si128(c, d);
155                let or3 = _mm_or_si128(or1, or2);
156                if _mm_movemask_epi8(or3) != 0 {
157                    let mut at = sub(ptr, start_ptr);
158                    let mask = _mm_movemask_epi8(a);
159                    if mask != 0 {
160                        return at + mask.trailing_zeros() as usize;
161                    }
162
163                    at += VECTOR_SIZE;
164                    let mask = _mm_movemask_epi8(b);
165                    if mask != 0 {
166                        return at + mask.trailing_zeros() as usize;
167                    }
168
169                    at += VECTOR_SIZE;
170                    let mask = _mm_movemask_epi8(c);
171                    if mask != 0 {
172                        return at + mask.trailing_zeros() as usize;
173                    }
174
175                    at += VECTOR_SIZE;
176                    let mask = _mm_movemask_epi8(d);
177                    debug_assert!(mask != 0);
178                    return at + mask.trailing_zeros() as usize;
179                }
180                ptr = ptr_add(ptr, VECTOR_LOOP_SIZE);
181            }
182        }
183        while ptr <= end_ptr.sub(VECTOR_SIZE) {
184            debug_assert!(sub(end_ptr, ptr) >= VECTOR_SIZE);
185
186            let chunk = _mm_loadu_si128(ptr as *const __m128i);
187            let mask = _mm_movemask_epi8(chunk);
188            if mask != 0 {
189                return sub(ptr, start_ptr) + mask.trailing_zeros() as usize;
190            }
191            ptr = ptr.add(VECTOR_SIZE);
192        }
193        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
194    }
195}
196
197#[inline(always)]
198unsafe fn first_non_ascii_byte_slow(
199    start_ptr: *const u8,
200    end_ptr: *const u8,
201    mut ptr: *const u8,
202) -> usize {
203    debug_assert!(start_ptr <= ptr);
204    debug_assert!(ptr <= end_ptr);
205
206    while ptr < end_ptr {
207        if *ptr > 0x7F {
208            return sub(ptr, start_ptr);
209        }
210        ptr = ptr.offset(1);
211    }
212    sub(end_ptr, start_ptr)
213}
214
/// Compute the position of the first non-ASCII byte in the given mask.
///
/// The mask should be computed by `chunk & ASCII_MASK`, where `chunk` holds
/// `size_of::<usize>()` contiguous bytes of the slice being checked. `chunk`
/// is read in native byte order, and *at least* one of those bytes must be
/// non-ASCII.
///
/// The position returned is always in the inclusive range
/// `[0, size_of::<usize>() - 1]`, i.e. `[0, 7]` on 64-bit targets and
/// `[0, 3]` on 32-bit targets.
///
/// On little-endian targets, the first byte in memory is the least
/// significant byte of the word, so the trailing zero count locates it. On
/// big-endian targets, the first byte is the most significant byte, so the
/// leading zero count is used. Either count is divided by 8 to turn a bit
/// index into a byte index.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_mask(mask: usize) -> usize {
    #[cfg(target_endian = "little")]
    {
        mask.trailing_zeros() as usize / 8
    }
    #[cfg(target_endian = "big")]
    {
        mask.leading_zeros() as usize / 8
    }
}

/// Advances `ptr` forward by `amt` bytes.
///
/// # Safety
///
/// The resulting pointer must stay within, or one byte past the end of, the
/// allocation that `ptr` points into (the contract of `<*const u8>::add`).
unsafe fn ptr_add(ptr: *const u8, amt: usize) -> *const u8 {
    debug_assert!(amt < ::core::isize::MAX as usize);
    ptr.add(amt)
}

/// Moves `ptr` backward by `amt` bytes.
///
/// # Safety
///
/// The resulting pointer must stay within the allocation that `ptr` points
/// into (the contract of `<*const u8>::sub`).
unsafe fn ptr_sub(ptr: *const u8, amt: usize) -> *const u8 {
    debug_assert!(amt < ::core::isize::MAX as usize);
    ptr.sub(amt)
}

/// Reads a native-endian `usize` from `ptr`, which need not be aligned.
///
/// # Safety
///
/// `ptr` must point to at least `size_of::<usize>()` readable bytes.
#[cfg(any(test, miri, not(target_arch = "x86_64")))]
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    core::ptr::read_unaligned(ptr as *const usize)
}

/// Returns the number of bytes from `b` up to `a`, i.e. `a - b` as addresses.
///
/// `a` must not be below `b`; this is checked only in debug builds.
fn sub(a: *const u8, b: *const u8) -> usize {
    let (hi, lo) = (a as usize, b as usize);
    debug_assert!(hi >= lo);
    hi - lo
}

#[cfg(test)]
mod tests {
    use super::*;

    // Our testing approach here is to try and exhaustively test every case.
    // This includes the position at which a non-ASCII byte occurs in addition
    // to the alignment of the slice that we're searching.

    /// The straightforward definition that every implementation in this
    /// module must agree with.
    fn naive(slice: &[u8]) -> usize {
        slice.iter().position(|&b| b > 0x7F).unwrap_or(slice.len())
    }

    #[test]
    fn positive_fallback_forward() {
        for i in 0..517 {
            let s = "a".repeat(i);
            assert_eq!(
                i,
                first_non_ascii_byte_fallback(s.as_bytes()),
                "i: {:?}, len: {:?}, s: {:?}",
                i,
                s.len(),
                s
            );
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn positive_sse2_forward() {
        for i in 0..517 {
            let b = "a".repeat(i).into_bytes();
            assert_eq!(b.len(), first_non_ascii_byte_sse2(&b));
        }
    }

    #[test]
    #[cfg(not(miri))]
    fn negative_fallback_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_fallback(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    #[cfg(not(miri))]
    fn negative_sse2_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_sse2(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }

    // The tests above only exercise the private implementations, and their
    // only non-ASCII input is the UTF-8 encoding of U+2603 (lead byte 0xE2).
    // This test also covers the public dispatcher and the boundary non-ASCII
    // byte values 0x80 and 0xFF. It places a lone non-ASCII byte at every
    // position, for every starting offset modulo 16.
    #[test]
    #[cfg(not(miri))]
    fn all_implementations_agree_with_naive() {
        for align in 0..16 {
            for len in 0..80 {
                for pos in 0..=len {
                    for &high in &[0x80u8, 0xFF] {
                        let mut buf = vec![b'a'; align + len];
                        if pos < len {
                            buf[align + pos] = high;
                        }
                        let slice = &buf[align..];
                        let expected = naive(slice);

                        #[cfg(target_arch = "x86_64")]
                        {
                            assert_eq!(expected, first_non_ascii_byte_sse2(slice));
                        }
                        assert_eq!(expected, first_non_ascii_byte_fallback(slice));
                        assert_eq!(
                            expected,
                            first_non_ascii_byte(slice),
                            "align: {:?}, len: {:?}, pos: {:?}, byte: {:?}",
                            align,
                            len,
                            pos,
                            high
                        );
                    }
                }
            }
        }
    }
}