aho_corasick/
dfa.rs

1use std::mem::size_of;
2
3use crate::ahocorasick::MatchKind;
4use crate::automaton::Automaton;
5use crate::classes::ByteClasses;
6use crate::error::Result;
7use crate::nfa::{PatternID, PatternLength, NFA};
8use crate::prefilter::{Prefilter, PrefilterObj, PrefilterState};
9use crate::state_id::{dead_id, fail_id, premultiply_overflow_error, StateID};
10use crate::Match;
11
12#[derive(Clone, Debug)]
13pub enum DFA<S> {
14    Standard(Standard<S>),
15    ByteClass(ByteClass<S>),
16    Premultiplied(Premultiplied<S>),
17    PremultipliedByteClass(PremultipliedByteClass<S>),
18}
19
20impl<S: StateID> DFA<S> {
21    fn repr(&self) -> &Repr<S> {
22        match *self {
23            DFA::Standard(ref dfa) => dfa.repr(),
24            DFA::ByteClass(ref dfa) => dfa.repr(),
25            DFA::Premultiplied(ref dfa) => dfa.repr(),
26            DFA::PremultipliedByteClass(ref dfa) => dfa.repr(),
27        }
28    }
29
30    pub fn match_kind(&self) -> &MatchKind {
31        &self.repr().match_kind
32    }
33
34    pub fn heap_bytes(&self) -> usize {
35        self.repr().heap_bytes
36    }
37
38    pub fn max_pattern_len(&self) -> usize {
39        self.repr().max_pattern_len
40    }
41
42    pub fn pattern_count(&self) -> usize {
43        self.repr().pattern_count
44    }
45
46    pub fn prefilter(&self) -> Option<&dyn Prefilter> {
47        self.repr().prefilter.as_ref().map(|p| p.as_ref())
48    }
49
50    pub fn start_state(&self) -> S {
51        self.repr().start_id
52    }
53
54    #[inline(always)]
55    pub fn overlapping_find_at(
56        &self,
57        prestate: &mut PrefilterState,
58        haystack: &[u8],
59        at: usize,
60        state_id: &mut S,
61        match_index: &mut usize,
62    ) -> Option<Match> {
63        match *self {
64            DFA::Standard(ref dfa) => dfa.overlapping_find_at(
65                prestate,
66                haystack,
67                at,
68                state_id,
69                match_index,
70            ),
71            DFA::ByteClass(ref dfa) => dfa.overlapping_find_at(
72                prestate,
73                haystack,
74                at,
75                state_id,
76                match_index,
77            ),
78            DFA::Premultiplied(ref dfa) => dfa.overlapping_find_at(
79                prestate,
80                haystack,
81                at,
82                state_id,
83                match_index,
84            ),
85            DFA::PremultipliedByteClass(ref dfa) => dfa.overlapping_find_at(
86                prestate,
87                haystack,
88                at,
89                state_id,
90                match_index,
91            ),
92        }
93    }
94
95    #[inline(always)]
96    pub fn earliest_find_at(
97        &self,
98        prestate: &mut PrefilterState,
99        haystack: &[u8],
100        at: usize,
101        state_id: &mut S,
102    ) -> Option<Match> {
103        match *self {
104            DFA::Standard(ref dfa) => {
105                dfa.earliest_find_at(prestate, haystack, at, state_id)
106            }
107            DFA::ByteClass(ref dfa) => {
108                dfa.earliest_find_at(prestate, haystack, at, state_id)
109            }
110            DFA::Premultiplied(ref dfa) => {
111                dfa.earliest_find_at(prestate, haystack, at, state_id)
112            }
113            DFA::PremultipliedByteClass(ref dfa) => {
114                dfa.earliest_find_at(prestate, haystack, at, state_id)
115            }
116        }
117    }
118
119    #[inline(always)]
120    pub fn find_at_no_state(
121        &self,
122        prestate: &mut PrefilterState,
123        haystack: &[u8],
124        at: usize,
125    ) -> Option<Match> {
126        match *self {
127            DFA::Standard(ref dfa) => {
128                dfa.find_at_no_state(prestate, haystack, at)
129            }
130            DFA::ByteClass(ref dfa) => {
131                dfa.find_at_no_state(prestate, haystack, at)
132            }
133            DFA::Premultiplied(ref dfa) => {
134                dfa.find_at_no_state(prestate, haystack, at)
135            }
136            DFA::PremultipliedByteClass(ref dfa) => {
137                dfa.find_at_no_state(prestate, haystack, at)
138            }
139        }
140    }
141}
142
143#[derive(Clone, Debug)]
144pub struct Standard<S>(Repr<S>);
145
146impl<S: StateID> Standard<S> {
147    fn repr(&self) -> &Repr<S> {
148        &self.0
149    }
150}
151
152impl<S: StateID> Automaton for Standard<S> {
153    type ID = S;
154
155    fn match_kind(&self) -> &MatchKind {
156        &self.repr().match_kind
157    }
158
159    fn anchored(&self) -> bool {
160        self.repr().anchored
161    }
162
163    fn prefilter(&self) -> Option<&dyn Prefilter> {
164        self.repr().prefilter.as_ref().map(|p| p.as_ref())
165    }
166
167    fn start_state(&self) -> S {
168        self.repr().start_id
169    }
170
171    fn is_valid(&self, id: S) -> bool {
172        id.to_usize() < self.repr().state_count
173    }
174
175    fn is_match_state(&self, id: S) -> bool {
176        self.repr().is_match_state(id)
177    }
178
179    fn is_match_or_dead_state(&self, id: S) -> bool {
180        self.repr().is_match_or_dead_state(id)
181    }
182
183    fn get_match(
184        &self,
185        id: S,
186        match_index: usize,
187        end: usize,
188    ) -> Option<Match> {
189        self.repr().get_match(id, match_index, end)
190    }
191
192    fn match_count(&self, id: S) -> usize {
193        self.repr().match_count(id)
194    }
195
196    fn next_state(&self, current: S, input: u8) -> S {
197        let o = current.to_usize() * 256 + input as usize;
198        self.repr().trans[o]
199    }
200}
201
202#[derive(Clone, Debug)]
203pub struct ByteClass<S>(Repr<S>);
204
205impl<S: StateID> ByteClass<S> {
206    fn repr(&self) -> &Repr<S> {
207        &self.0
208    }
209}
210
211impl<S: StateID> Automaton for ByteClass<S> {
212    type ID = S;
213
214    fn match_kind(&self) -> &MatchKind {
215        &self.repr().match_kind
216    }
217
218    fn anchored(&self) -> bool {
219        self.repr().anchored
220    }
221
222    fn prefilter(&self) -> Option<&dyn Prefilter> {
223        self.repr().prefilter.as_ref().map(|p| p.as_ref())
224    }
225
226    fn start_state(&self) -> S {
227        self.repr().start_id
228    }
229
230    fn is_valid(&self, id: S) -> bool {
231        id.to_usize() < self.repr().state_count
232    }
233
234    fn is_match_state(&self, id: S) -> bool {
235        self.repr().is_match_state(id)
236    }
237
238    fn is_match_or_dead_state(&self, id: S) -> bool {
239        self.repr().is_match_or_dead_state(id)
240    }
241
242    fn get_match(
243        &self,
244        id: S,
245        match_index: usize,
246        end: usize,
247    ) -> Option<Match> {
248        self.repr().get_match(id, match_index, end)
249    }
250
251    fn match_count(&self, id: S) -> usize {
252        self.repr().match_count(id)
253    }
254
255    fn next_state(&self, current: S, input: u8) -> S {
256        let alphabet_len = self.repr().byte_classes.alphabet_len();
257        let input = self.repr().byte_classes.get(input);
258        let o = current.to_usize() * alphabet_len + input as usize;
259        self.repr().trans[o]
260    }
261}
262
263#[derive(Clone, Debug)]
264pub struct Premultiplied<S>(Repr<S>);
265
266impl<S: StateID> Premultiplied<S> {
267    fn repr(&self) -> &Repr<S> {
268        &self.0
269    }
270}
271
272impl<S: StateID> Automaton for Premultiplied<S> {
273    type ID = S;
274
275    fn match_kind(&self) -> &MatchKind {
276        &self.repr().match_kind
277    }
278
279    fn anchored(&self) -> bool {
280        self.repr().anchored
281    }
282
283    fn prefilter(&self) -> Option<&dyn Prefilter> {
284        self.repr().prefilter.as_ref().map(|p| p.as_ref())
285    }
286
287    fn start_state(&self) -> S {
288        self.repr().start_id
289    }
290
291    fn is_valid(&self, id: S) -> bool {
292        (id.to_usize() / 256) < self.repr().state_count
293    }
294
295    fn is_match_state(&self, id: S) -> bool {
296        self.repr().is_match_state(id)
297    }
298
299    fn is_match_or_dead_state(&self, id: S) -> bool {
300        self.repr().is_match_or_dead_state(id)
301    }
302
303    fn get_match(
304        &self,
305        id: S,
306        match_index: usize,
307        end: usize,
308    ) -> Option<Match> {
309        if id > self.repr().max_match {
310            return None;
311        }
312        self.repr()
313            .matches
314            .get(id.to_usize() / 256)
315            .and_then(|m| m.get(match_index))
316            .map(|&(id, len)| Match { pattern: id, len, end })
317    }
318
319    fn match_count(&self, id: S) -> usize {
320        let o = id.to_usize() / 256;
321        self.repr().matches[o].len()
322    }
323
324    fn next_state(&self, current: S, input: u8) -> S {
325        let o = current.to_usize() + input as usize;
326        self.repr().trans[o]
327    }
328}
329
330#[derive(Clone, Debug)]
331pub struct PremultipliedByteClass<S>(Repr<S>);
332
333impl<S: StateID> PremultipliedByteClass<S> {
334    fn repr(&self) -> &Repr<S> {
335        &self.0
336    }
337}
338
339impl<S: StateID> Automaton for PremultipliedByteClass<S> {
340    type ID = S;
341
342    fn match_kind(&self) -> &MatchKind {
343        &self.repr().match_kind
344    }
345
346    fn anchored(&self) -> bool {
347        self.repr().anchored
348    }
349
350    fn prefilter(&self) -> Option<&dyn Prefilter> {
351        self.repr().prefilter.as_ref().map(|p| p.as_ref())
352    }
353
354    fn start_state(&self) -> S {
355        self.repr().start_id
356    }
357
358    fn is_valid(&self, id: S) -> bool {
359        (id.to_usize() / self.repr().alphabet_len()) < self.repr().state_count
360    }
361
362    fn is_match_state(&self, id: S) -> bool {
363        self.repr().is_match_state(id)
364    }
365
366    fn is_match_or_dead_state(&self, id: S) -> bool {
367        self.repr().is_match_or_dead_state(id)
368    }
369
370    fn get_match(
371        &self,
372        id: S,
373        match_index: usize,
374        end: usize,
375    ) -> Option<Match> {
376        if id > self.repr().max_match {
377            return None;
378        }
379        self.repr()
380            .matches
381            .get(id.to_usize() / self.repr().alphabet_len())
382            .and_then(|m| m.get(match_index))
383            .map(|&(id, len)| Match { pattern: id, len, end })
384    }
385
386    fn match_count(&self, id: S) -> usize {
387        let o = id.to_usize() / self.repr().alphabet_len();
388        self.repr().matches[o].len()
389    }
390
391    fn next_state(&self, current: S, input: u8) -> S {
392        let input = self.repr().byte_classes.get(input);
393        let o = current.to_usize() + input as usize;
394        self.repr().trans[o]
395    }
396}
397
398#[derive(Clone, Debug)]
399pub struct Repr<S> {
400    match_kind: MatchKind,
401    anchored: bool,
402    premultiplied: bool,
403    start_id: S,
404    /// The length, in bytes, of the longest pattern in this automaton. This
405    /// information is useful for keeping correct buffer sizes when searching
406    /// on streams.
407    max_pattern_len: usize,
408    /// The total number of patterns added to this automaton. This includes
409    /// patterns that may never match.
410    pattern_count: usize,
411    state_count: usize,
412    max_match: S,
413    /// The number of bytes of heap used by this NFA's transition table.
414    heap_bytes: usize,
415    /// A prefilter for quickly detecting candidate matchs, if pertinent.
416    prefilter: Option<PrefilterObj>,
417    byte_classes: ByteClasses,
418    trans: Vec<S>,
419    matches: Vec<Vec<(PatternID, PatternLength)>>,
420}
421
422impl<S: StateID> Repr<S> {
423    /// Returns the total alphabet size for this DFA.
424    ///
425    /// If byte classes are enabled, then this corresponds to the number of
426    /// equivalence classes. If they are disabled, then this is always 256.
427    fn alphabet_len(&self) -> usize {
428        self.byte_classes.alphabet_len()
429    }
430
431    /// Returns true only if the given state is a match state.
432    fn is_match_state(&self, id: S) -> bool {
433        id <= self.max_match && id > dead_id()
434    }
435
436    /// Returns true only if the given state is either a dead state or a match
437    /// state.
438    fn is_match_or_dead_state(&self, id: S) -> bool {
439        id <= self.max_match
440    }
441
442    /// Get the ith match for the given state, where the end position of a
443    /// match was found at `end`.
444    ///
445    /// # Panics
446    ///
447    /// The caller must ensure that the given state identifier is valid,
448    /// otherwise this may panic. The `match_index` need not be valid. That is,
449    /// if the given state has no matches then this returns `None`.
450    fn get_match(
451        &self,
452        id: S,
453        match_index: usize,
454        end: usize,
455    ) -> Option<Match> {
456        if id > self.max_match {
457            return None;
458        }
459        self.matches
460            .get(id.to_usize())
461            .and_then(|m| m.get(match_index))
462            .map(|&(id, len)| Match { pattern: id, len, end })
463    }
464
465    /// Return the total number of matches for the given state.
466    ///
467    /// # Panics
468    ///
469    /// The caller must ensure that the given identifier is valid, or else
470    /// this panics.
471    fn match_count(&self, id: S) -> usize {
472        self.matches[id.to_usize()].len()
473    }
474
475    /// Get the next state given `from` as the current state and `byte` as the
476    /// current input byte.
477    fn next_state(&self, from: S, byte: u8) -> S {
478        let alphabet_len = self.alphabet_len();
479        let byte = self.byte_classes.get(byte);
480        self.trans[from.to_usize() * alphabet_len + byte as usize]
481    }
482
483    /// Set the `byte` transition for the `from` state to point to `to`.
484    fn set_next_state(&mut self, from: S, byte: u8, to: S) {
485        let alphabet_len = self.alphabet_len();
486        let byte = self.byte_classes.get(byte);
487        self.trans[from.to_usize() * alphabet_len + byte as usize] = to;
488    }
489
490    /// Swap the given states in place.
491    fn swap_states(&mut self, id1: S, id2: S) {
492        assert!(!self.premultiplied, "can't swap states in premultiplied DFA");
493
494        let o1 = id1.to_usize() * self.alphabet_len();
495        let o2 = id2.to_usize() * self.alphabet_len();
496        for b in 0..self.alphabet_len() {
497            self.trans.swap(o1 + b, o2 + b);
498        }
499        self.matches.swap(id1.to_usize(), id2.to_usize());
500    }
501
502    /// This routine shuffles all match states in this DFA to the beginning
503    /// of the DFA such that every non-match state appears after every match
504    /// state. (With one exception: the special fail and dead states remain as
505    /// the first two states.)
506    ///
507    /// The purpose of doing this shuffling is to avoid an extra conditional
508    /// in the search loop, and in particular, detecting whether a state is a
509    /// match or not does not need to access any memory.
510    ///
511    /// This updates `self.max_match` to point to the last matching state as
512    /// well as `self.start` if the starting state was moved.
513    fn shuffle_match_states(&mut self) {
514        assert!(
515            !self.premultiplied,
516            "cannot shuffle match states of premultiplied DFA"
517        );
518
519        if self.state_count <= 1 {
520            return;
521        }
522
523        let mut first_non_match = self.start_id.to_usize();
524        while first_non_match < self.state_count
525            && self.matches[first_non_match].len() > 0
526        {
527            first_non_match += 1;
528        }
529
530        let mut swaps: Vec<S> = vec![fail_id(); self.state_count];
531        let mut cur = self.state_count - 1;
532        while cur > first_non_match {
533            if self.matches[cur].len() > 0 {
534                self.swap_states(
535                    S::from_usize(cur),
536                    S::from_usize(first_non_match),
537                );
538                swaps[cur] = S::from_usize(first_non_match);
539                swaps[first_non_match] = S::from_usize(cur);
540
541                first_non_match += 1;
542                while first_non_match < cur
543                    && self.matches[first_non_match].len() > 0
544                {
545                    first_non_match += 1;
546                }
547            }
548            cur -= 1;
549        }
550        for id in (0..self.state_count).map(S::from_usize) {
551            let alphabet_len = self.alphabet_len();
552            let offset = id.to_usize() * alphabet_len;
553            for next in &mut self.trans[offset..offset + alphabet_len] {
554                if swaps[next.to_usize()] != fail_id() {
555                    *next = swaps[next.to_usize()];
556                }
557            }
558        }
559        if swaps[self.start_id.to_usize()] != fail_id() {
560            self.start_id = swaps[self.start_id.to_usize()];
561        }
562        self.max_match = S::from_usize(first_non_match - 1);
563    }
564
565    fn premultiply(&mut self) -> Result<()> {
566        if self.premultiplied || self.state_count <= 1 {
567            return Ok(());
568        }
569
570        let alpha_len = self.alphabet_len();
571        premultiply_overflow_error(
572            S::from_usize(self.state_count - 1),
573            alpha_len,
574        )?;
575
576        for id in (2..self.state_count).map(S::from_usize) {
577            let offset = id.to_usize() * alpha_len;
578            for next in &mut self.trans[offset..offset + alpha_len] {
579                if *next == dead_id() {
580                    continue;
581                }
582                *next = S::from_usize(next.to_usize() * alpha_len);
583            }
584        }
585        self.premultiplied = true;
586        self.start_id = S::from_usize(self.start_id.to_usize() * alpha_len);
587        self.max_match = S::from_usize(self.max_match.to_usize() * alpha_len);
588        Ok(())
589    }
590
591    /// Computes the total amount of heap used by this NFA in bytes.
592    fn calculate_size(&mut self) {
593        let mut size = (self.trans.len() * size_of::<S>())
594            + (self.matches.len()
595                * size_of::<Vec<(PatternID, PatternLength)>>());
596        for state_matches in &self.matches {
597            size +=
598                state_matches.len() * size_of::<(PatternID, PatternLength)>();
599        }
600        size += self.prefilter.as_ref().map_or(0, |p| p.as_ref().heap_bytes());
601        self.heap_bytes = size;
602    }
603}
604
/// A builder for configuring the determinization of an NFA into a DFA.
///
/// Both options are enabled by `Builder::new`.
#[derive(Clone, Debug)]
pub struct Builder {
    /// Whether state identifiers are premultiplied by the alphabet length.
    premultiply: bool,
    /// Whether the DFA reuses the NFA's byte classes instead of giving every
    /// byte its own class.
    byte_classes: bool,
}
611
612impl Builder {
613    /// Create a new builder for a DFA.
614    pub fn new() -> Builder {
615        Builder { premultiply: true, byte_classes: true }
616    }
617
618    /// Build a DFA from the given NFA.
619    ///
620    /// This returns an error if the state identifiers exceed their
621    /// representation size. This can only happen when state ids are
622    /// premultiplied (which is enabled by default).
623    pub fn build<S: StateID>(&self, nfa: &NFA<S>) -> Result<DFA<S>> {
624        let byte_classes = if self.byte_classes {
625            nfa.byte_classes().clone()
626        } else {
627            ByteClasses::singletons()
628        };
629        let alphabet_len = byte_classes.alphabet_len();
630        let trans = vec![fail_id(); alphabet_len * nfa.state_len()];
631        let matches = vec![vec![]; nfa.state_len()];
632        let mut repr = Repr {
633            match_kind: nfa.match_kind().clone(),
634            anchored: nfa.anchored(),
635            premultiplied: false,
636            start_id: nfa.start_state(),
637            max_pattern_len: nfa.max_pattern_len(),
638            pattern_count: nfa.pattern_count(),
639            state_count: nfa.state_len(),
640            max_match: fail_id(),
641            heap_bytes: 0,
642            prefilter: nfa.prefilter_obj().map(|p| p.clone()),
643            byte_classes: byte_classes.clone(),
644            trans,
645            matches,
646        };
647        for id in (0..nfa.state_len()).map(S::from_usize) {
648            repr.matches[id.to_usize()].extend_from_slice(nfa.matches(id));
649
650            let fail = nfa.failure_transition(id);
651            nfa.iter_all_transitions(&byte_classes, id, |b, mut next| {
652                if next == fail_id() {
653                    next = nfa_next_state_memoized(nfa, &repr, id, fail, b);
654                }
655                repr.set_next_state(id, b, next);
656            });
657        }
658        repr.shuffle_match_states();
659        repr.calculate_size();
660        if self.premultiply {
661            repr.premultiply()?;
662            if byte_classes.is_singleton() {
663                Ok(DFA::Premultiplied(Premultiplied(repr)))
664            } else {
665                Ok(DFA::PremultipliedByteClass(PremultipliedByteClass(repr)))
666            }
667        } else {
668            if byte_classes.is_singleton() {
669                Ok(DFA::Standard(Standard(repr)))
670            } else {
671                Ok(DFA::ByteClass(ByteClass(repr)))
672            }
673        }
674    }
675
676    /// Whether to use byte classes or in the DFA.
677    pub fn byte_classes(&mut self, yes: bool) -> &mut Builder {
678        self.byte_classes = yes;
679        self
680    }
681
682    /// Whether to premultiply state identifier in the DFA.
683    pub fn premultiply(&mut self, yes: bool) -> &mut Builder {
684        self.premultiply = yes;
685        self
686    }
687}
688
689/// This returns the next NFA transition (including resolving failure
690/// transitions), except once it sees a state id less than the id of the DFA
691/// state that is currently being populated, then we no longer need to follow
692/// failure transitions and can instead query the pre-computed state id from
693/// the DFA itself.
694///
695/// In general, this should only be called when a failure transition is seen.
696fn nfa_next_state_memoized<S: StateID>(
697    nfa: &NFA<S>,
698    dfa: &Repr<S>,
699    populating: S,
700    mut current: S,
701    input: u8,
702) -> S {
703    loop {
704        if current < populating {
705            return dfa.next_state(current, input);
706        }
707        let next = nfa.next_state(current, input);
708        if next != fail_id() {
709            return next;
710        }
711        current = nfa.failure_transition(current);
712    }
713}