aho_corasick/packed/teddy/
compile.rs

// See the README in this directory for an explanation of the Teddy algorithm.
2
3use std::cmp;
4use std::collections::BTreeMap;
5use std::fmt;
6
7use crate::packed::pattern::{PatternID, Patterns};
8use crate::packed::teddy::Teddy;
9
/// Configuration for constructing a Teddy matcher.
///
/// Both knobs default to "decide automatically" and exist mostly so that
/// tests and benchmarks can pin a specific variant. In normal use the
/// variant is derived from the number of patterns and from the CPU features
/// detected at build time.
#[derive(Clone, Debug)]
pub struct Builder {
    /// Bucket layout. `None` selects automatically, `Some(false)` forces
    /// slim Teddy (8 buckets) and `Some(true)` forces fat Teddy (16
    /// buckets). Fat Teddy needs AVX2; if it is forced while AVX2 is not
    /// in use, no matcher is built.
    fat: Option<bool>,
    /// Vector width. `None` selects automatically, `Some(false)` forces
    /// 128-bit vectors (SSSE3) and `Some(true)` forces 256-bit vectors
    /// (AVX2). If the forced width is not supported by the running CPU, no
    /// searcher is built.
    avx: Option<bool>,
}
30
31impl Default for Builder {
32    fn default() -> Builder {
33        Builder::new()
34    }
35}
36
37impl Builder {
38    /// Create a new builder for configuring a Teddy matcher.
39    pub fn new() -> Builder {
40        Builder { fat: None, avx: None }
41    }
42
43    /// Build a matcher for the set of patterns given. If a matcher could not
44    /// be built, then `None` is returned.
45    ///
46    /// Generally, a matcher isn't built if the necessary CPU features aren't
47    /// available, an unsupported target or if the searcher is believed to be
48    /// slower than standard techniques (i.e., if there are too many literals).
49    pub fn build(&self, patterns: &Patterns) -> Option<Teddy> {
50        self.build_imp(patterns)
51    }
52
53    /// Require the use of Fat (true) or Slim (false) Teddy. Fat Teddy uses
54    /// 16 buckets where as Slim Teddy uses 8 buckets. More buckets are useful
55    /// for a larger set of literals.
56    ///
57    /// `None` is the default, which results in an automatic selection based
58    /// on the number of literals and available CPU features.
59    pub fn fat(&mut self, yes: Option<bool>) -> &mut Builder {
60        self.fat = yes;
61        self
62    }
63
64    /// Request the use of 256-bit vectors (true) or 128-bit vectors (false).
65    /// Generally, a larger vector size is better since it either permits
66    /// matching more patterns or matching more bytes in the haystack at once.
67    ///
68    /// `None` is the default, which results in an automatic selection based on
69    /// the number of literals and available CPU features.
70    pub fn avx(&mut self, yes: Option<bool>) -> &mut Builder {
71        self.avx = yes;
72        self
73    }
74
75    fn build_imp(&self, patterns: &Patterns) -> Option<Teddy> {
76        use crate::packed::teddy::runtime;
77
78        // Most of the logic here is just about selecting the optimal settings,
79        // or perhaps even rejecting construction altogether. The choices
80        // we have are: fat (avx only) or not, ssse3 or avx2, and how many
81        // patterns we allow ourselves to search. Additionally, for testing
82        // and benchmarking, we permit callers to try to "force" a setting,
83        // and if the setting isn't allowed (e.g., forcing AVX when AVX isn't
84        // available), then we bail and return nothing.
85
86        if patterns.len() > 64 {
87            return None;
88        }
89        let has_ssse3 = is_x86_feature_detected!("ssse3");
90        let has_avx = is_x86_feature_detected!("avx2");
91        let avx = if self.avx == Some(true) {
92            if !has_avx {
93                return None;
94            }
95            true
96        } else if self.avx == Some(false) {
97            if !has_ssse3 {
98                return None;
99            }
100            false
101        } else if !has_ssse3 && !has_avx {
102            return None;
103        } else {
104            has_avx
105        };
106        let fat = match self.fat {
107            None => avx && patterns.len() > 32,
108            Some(false) => false,
109            Some(true) if !avx => return None,
110            Some(true) => true,
111        };
112
113        let mut compiler = Compiler::new(patterns, fat);
114        compiler.compile();
115        let Compiler { buckets, masks, .. } = compiler;
116        // SAFETY: It is required that the builder only produce Teddy matchers
117        // that are allowed to run on the current CPU, since we later assume
118        // that the presence of (for example) TeddySlim1Mask256 means it is
119        // safe to call functions marked with the `avx2` target feature.
120        match (masks.len(), avx, fat) {
121            (1, false, _) => Some(Teddy {
122                buckets,
123                max_pattern_id: patterns.max_pattern_id(),
124                exec: runtime::Exec::TeddySlim1Mask128(
125                    runtime::TeddySlim1Mask128 {
126                        mask1: runtime::Mask128::new(masks[0]),
127                    },
128                ),
129            }),
130            (1, true, false) => Some(Teddy {
131                buckets,
132                max_pattern_id: patterns.max_pattern_id(),
133                exec: runtime::Exec::TeddySlim1Mask256(
134                    runtime::TeddySlim1Mask256 {
135                        mask1: runtime::Mask256::new(masks[0]),
136                    },
137                ),
138            }),
139            (1, true, true) => Some(Teddy {
140                buckets,
141                max_pattern_id: patterns.max_pattern_id(),
142                exec: runtime::Exec::TeddyFat1Mask256(
143                    runtime::TeddyFat1Mask256 {
144                        mask1: runtime::Mask256::new(masks[0]),
145                    },
146                ),
147            }),
148            (2, false, _) => Some(Teddy {
149                buckets,
150                max_pattern_id: patterns.max_pattern_id(),
151                exec: runtime::Exec::TeddySlim2Mask128(
152                    runtime::TeddySlim2Mask128 {
153                        mask1: runtime::Mask128::new(masks[0]),
154                        mask2: runtime::Mask128::new(masks[1]),
155                    },
156                ),
157            }),
158            (2, true, false) => Some(Teddy {
159                buckets,
160                max_pattern_id: patterns.max_pattern_id(),
161                exec: runtime::Exec::TeddySlim2Mask256(
162                    runtime::TeddySlim2Mask256 {
163                        mask1: runtime::Mask256::new(masks[0]),
164                        mask2: runtime::Mask256::new(masks[1]),
165                    },
166                ),
167            }),
168            (2, true, true) => Some(Teddy {
169                buckets,
170                max_pattern_id: patterns.max_pattern_id(),
171                exec: runtime::Exec::TeddyFat2Mask256(
172                    runtime::TeddyFat2Mask256 {
173                        mask1: runtime::Mask256::new(masks[0]),
174                        mask2: runtime::Mask256::new(masks[1]),
175                    },
176                ),
177            }),
178            (3, false, _) => Some(Teddy {
179                buckets,
180                max_pattern_id: patterns.max_pattern_id(),
181                exec: runtime::Exec::TeddySlim3Mask128(
182                    runtime::TeddySlim3Mask128 {
183                        mask1: runtime::Mask128::new(masks[0]),
184                        mask2: runtime::Mask128::new(masks[1]),
185                        mask3: runtime::Mask128::new(masks[2]),
186                    },
187                ),
188            }),
189            (3, true, false) => Some(Teddy {
190                buckets,
191                max_pattern_id: patterns.max_pattern_id(),
192                exec: runtime::Exec::TeddySlim3Mask256(
193                    runtime::TeddySlim3Mask256 {
194                        mask1: runtime::Mask256::new(masks[0]),
195                        mask2: runtime::Mask256::new(masks[1]),
196                        mask3: runtime::Mask256::new(masks[2]),
197                    },
198                ),
199            }),
200            (3, true, true) => Some(Teddy {
201                buckets,
202                max_pattern_id: patterns.max_pattern_id(),
203                exec: runtime::Exec::TeddyFat3Mask256(
204                    runtime::TeddyFat3Mask256 {
205                        mask1: runtime::Mask256::new(masks[0]),
206                        mask2: runtime::Mask256::new(masks[1]),
207                        mask3: runtime::Mask256::new(masks[2]),
208                    },
209                ),
210            }),
211            _ => unreachable!(),
212        }
213    }
214}
215
216/// A compiler is in charge of allocating patterns into buckets and generating
217/// the masks necessary for searching.
218#[derive(Clone)]
219struct Compiler<'p> {
220    patterns: &'p Patterns,
221    buckets: Vec<Vec<PatternID>>,
222    masks: Vec<Mask>,
223}
224
225impl<'p> Compiler<'p> {
226    /// Create a new Teddy compiler for the given patterns. If `fat` is true,
227    /// then 16 buckets will be used instead of 8.
228    ///
229    /// This panics if any of the patterns given are empty.
230    fn new(patterns: &'p Patterns, fat: bool) -> Compiler<'p> {
231        let mask_len = cmp::min(3, patterns.minimum_len());
232        assert!(1 <= mask_len && mask_len <= 3);
233
234        Compiler {
235            patterns,
236            buckets: vec![vec![]; if fat { 16 } else { 8 }],
237            masks: vec![Mask::default(); mask_len],
238        }
239    }
240
241    /// Compile the patterns in this compiler into buckets and masks.
242    fn compile(&mut self) {
243        let mut lonibble_to_bucket: BTreeMap<Vec<u8>, usize> = BTreeMap::new();
244        for (id, pattern) in self.patterns.iter() {
245            // We try to be slightly clever in how we assign patterns into
246            // buckets. Generally speaking, we want patterns with the same
247            // prefix to be in the same bucket, since it minimizes the amount
248            // of time we spend churning through buckets in the verification
249            // step.
250            //
251            // So we could assign patterns with the same N-prefix (where N
252            // is the size of the mask, which is one of {1, 2, 3}) to the
253            // same bucket. However, case insensitive searches are fairly
254            // common, so we'd for example, ideally want to treat `abc` and
255            // `ABC` as if they shared the same prefix. ASCII has the nice
256            // property that the lower 4 bits of A and a are the same, so we
257            // therefore group patterns with the same low-nybbe-N-prefix into
258            // the same bucket.
259            //
260            // MOREOVER, this is actually necessary for correctness! In
261            // particular, by grouping patterns with the same prefix into the
262            // same bucket, we ensure that we preserve correct leftmost-first
263            // and leftmost-longest match semantics. In addition to the fact
264            // that `patterns.iter()` iterates in the correct order, this
265            // guarantees that all possible ambiguous matches will occur in
266            // the same bucket. The verification routine could be adjusted to
267            // support correct leftmost match semantics regardless of bucket
268            // allocation, but that results in a performance hit. It's much
269            // nicer to be able to just stop as soon as a match is found.
270            let lonybs = pattern.low_nybbles(self.masks.len());
271            if let Some(&bucket) = lonibble_to_bucket.get(&lonybs) {
272                self.buckets[bucket].push(id);
273            } else {
274                // N.B. We assign buckets in reverse because it shouldn't have
275                // any influence on performance, but it does make it harder to
276                // get leftmost match semantics accidentally correct.
277                let bucket = (self.buckets.len() - 1)
278                    - (id as usize % self.buckets.len());
279                self.buckets[bucket].push(id);
280                lonibble_to_bucket.insert(lonybs, bucket);
281            }
282        }
283        for (bucket_index, bucket) in self.buckets.iter().enumerate() {
284            for &pat_id in bucket {
285                let pat = self.patterns.get(pat_id);
286                for (i, mask) in self.masks.iter_mut().enumerate() {
287                    if self.buckets.len() == 8 {
288                        mask.add_slim(bucket_index as u8, pat.bytes()[i]);
289                    } else {
290                        mask.add_fat(bucket_index as u8, pat.bytes()[i]);
291                    }
292                }
293            }
294        }
295    }
296}
297
298impl<'p> fmt::Debug for Compiler<'p> {
299    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
300        let mut buckets = vec![vec![]; self.buckets.len()];
301        for (i, bucket) in self.buckets.iter().enumerate() {
302            for &patid in bucket {
303                buckets[i].push(self.patterns.get(patid));
304            }
305        }
306        f.debug_struct("Compiler")
307            .field("buckets", &buckets)
308            .field("masks", &self.masks)
309            .finish()
310    }
311}
312
/// A pair of nybble masks (one for low nybbles, one for high nybbles) used
/// during search. Each mask is 32 bytes wide, although only the first 16
/// bytes are used for the SSSE3 runtime.
///
/// Every byte of a mask is an 8-bit bitset in which bit `i` is set if and
/// only if the corresponding nybble is in the ith bucket. The position of
/// the byte within its 16-byte half (0-15, inclusive) is the nybble value.
///
/// A mask serves as the target of a shuffle whose indices come from the
/// haystack. AND'ing the shuffle results of the low and high masks again
/// yields 8-bit bitsets, but now bit `i` is set if and only if the
/// corresponding *byte* is in the ith bucket.
///
/// Masks are plain arrays while compiling; at search time they are
/// represented as 128-bit or 256-bit vectors.
///
/// (See the README in this directory for more details.)
#[derive(Clone, Copy, Default)]
pub struct Mask {
    lo: [u8; 32],
    hi: [u8; 32],
}

impl Mask {
    /// Record that `byte` belongs to `bucket`, which must be in the range
    /// 0-7.
    ///
    /// This is the "slim" Teddy variant, which has only 8 buckets.
    fn add_slim(&mut self, bucket: u8, byte: u8) {
        assert!(bucket < 8);

        let lo_nybble = usize::from(byte & 0x0F);
        let hi_nybble = usize::from(byte >> 4);
        let bit: u8 = 1 << bucket;
        // With 256-bit vectors the assignment has to be present in both
        // 128-bit halves of the mask, because AVX2 shuffles work on each
        // 128-bit lane separately rather than on the whole vector. That is
        // what lets the search process 32 bytes at a time.
        for &lane in &[0usize, 16] {
            self.lo[lo_nybble + lane] |= bit;
            self.hi[hi_nybble + lane] |= bit;
        }
    }

    /// Record that `byte` belongs to `bucket`, which must be in the range
    /// 0-15.
    ///
    /// This is the "fat" Teddy variant, which has 16 buckets.
    fn add_fat(&mut self, bucket: u8, byte: u8) {
        assert!(bucket < 16);

        let lo_nybble = usize::from(byte & 0x0F);
        let hi_nybble = usize::from(byte >> 4);
        // Fat Teddy exists only for AVX2. The low 128 bits of the mask
        // carry buckets 0-7 and the high 128 bits carry buckets 8-15, so
        // the bucket number picks the half and `bucket % 8` picks the bit.
        let lane = if bucket < 8 { 0 } else { 16 };
        let bit: u8 = 1 << (bucket % 8);
        self.lo[lo_nybble + lane] |= bit;
        self.hi[hi_nybble + lane] |= bit;
    }

    /// Copy out the first 16 bytes of a 32-byte mask.
    fn first_half(full: &[u8; 32]) -> [u8; 16] {
        let mut half = [0u8; 16];
        half.copy_from_slice(&full[..16]);
        half
    }

    /// Return the low 128 bits of the low-nybble mask.
    pub fn lo128(&self) -> [u8; 16] {
        Mask::first_half(&self.lo)
    }

    /// Return the full low-nybble mask.
    pub fn lo256(&self) -> [u8; 32] {
        self.lo
    }

    /// Return the low 128 bits of the high-nybble mask.
    pub fn hi128(&self) -> [u8; 16] {
        Mask::first_half(&self.hi)
    }

    /// Return the full high-nybble mask.
    pub fn hi256(&self) -> [u8; 32] {
        self.hi
    }
}

/// Debug output lists every mask byte as `index: bits`, with the index
/// zero-padded to two digits and the value printed as eight binary digits.
impl fmt::Debug for Mask {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let render = |mask: &[u8; 32]| -> Vec<String> {
            mask.iter()
                .enumerate()
                .map(|(i, &bits)| format!("{:02}: {:08b}", i, bits))
                .collect()
        };
        let parts_lo = render(&self.lo);
        let parts_hi = render(&self.hi);
        f.debug_struct("Mask")
            .field("lo", &parts_lo)
            .field("hi", &parts_hi)
            .finish()
    }
}