ucd_trie/
owned.rs

1use std::borrow::Borrow;
2use std::collections::HashMap;
3use std::error;
4use std::fmt;
5use std::io;
6use std::result;
7
8use super::{TrieSetSlice, CHUNK_SIZE};
9
10// This implementation was pretty much cribbed from raphlinus' contribution
11// to the standard library: https://github.com/rust-lang/rust/pull/33098/files
12//
13// The fundamental principle guiding this implementation is to take advantage
14// of the fact that similar Unicode codepoints are often grouped together, and
15// that most boolean Unicode properties are quite sparse over the entire space
16// of Unicode codepoints.
17//
18// To do this, we represent sets using something like a trie (which gives us
19// prefix compression). The "final" states of the trie are embedded in leaves
20// or "chunks," where each chunk is a 64 bit integer. Each bit position of the
21// integer corresponds to whether a particular codepoint is in the set or not.
// These chunks are not just a compact representation of the final states of
// the trie, but are also a form of suffix compression. In particular, if
// multiple ranges of 64 contiguous codepoints have the same set membership
// pattern, then they all map to the exact same chunk in the trie.
26//
// We organize this structure by partitioning the space of Unicode codepoints
// into three disjoint sets. The first set corresponds to codepoints
// [0, 0x800), the second [0x800, 0x10000) and the third [0x10000, 0x110000).
// These partitions conveniently correspond to the space of 1 or 2 byte UTF-8
// encoded codepoints, 3 byte UTF-8 encoded codepoints and 4 byte UTF-8 encoded
// codepoints, respectively.
33//
34// Each partition has its own tree with its own root. The first partition is
35// the simplest, since the tree is completely flat. In particular, to determine
36// the set membership of a Unicode codepoint (that is less than `0x800`), we
37// do the following (where `cp` is the codepoint we're testing):
38//
39//     let chunk_address = cp >> 6;
40//     let chunk_bit = cp & 0b111111;
//     let chunk = tree1_level1[chunk_address];
42//     let is_member = 1 == ((chunk >> chunk_bit) & 1);
43//
44// We do something similar for the second partition:
45//
46//     // we subtract 0x20 since (0x800 >> 6) == 0x20.
47//     let child_address = (cp >> 6) - 0x20;
48//     let chunk_address = tree2_level1[child_address];
49//     let chunk_bit = cp & 0b111111;
50//     let chunk = tree2_level2[chunk_address];
51//     let is_member = 1 == ((chunk >> chunk_bit) & 1);
52//
53// And so on for the third partition.
54//
55// Note that as a special case, if the second or third partitions are empty,
56// then the trie will store empty slices for those levels. The `contains`
57// check knows to return `false` in those cases.
58
59const CHUNKS: usize = 0x110000 / CHUNK_SIZE;
60
61/// A type alias that maps to `std::result::Result<T, ucd_trie::Error>`.
62pub type Result<T> = result::Result<T, Error>;
63
/// An error that can occur during construction of a trie.
///
/// Errors implement `PartialEq`/`Eq`, so a specific failure (for example a
/// particular out-of-range codepoint) can be checked with `==`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Error {
    /// This error is returned when an invalid codepoint is given to
    /// `TrieSetOwned::from_codepoints`. An invalid codepoint is a `u32` that
    /// is greater than `0x10FFFF`. The offending value is carried along.
    InvalidCodepoint(u32),
    /// This error is returned when a set of Unicode codepoints could not be
    /// sufficiently compressed into the trie provided by this crate, i.e.
    /// when a partition needs more than 256 distinct chunks (child indices
    /// are stored as `u8`). There is no work-around for this error at this
    /// time.
    GaveUp,
}
76
77impl error::Error for Error {}
78
79impl fmt::Display for Error {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        match *self {
82            Error::InvalidCodepoint(cp) => write!(
83                f,
84                "could not construct trie set containing an \
85                 invalid Unicode codepoint: 0x{:X}",
86                cp
87            ),
88            Error::GaveUp => {
89                write!(f, "could not compress codepoint set into a trie")
90            }
91        }
92    }
93}
94
95impl From<Error> for io::Error {
96    fn from(err: Error) -> io::Error {
97        io::Error::new(io::ErrorKind::Other, err)
98    }
99}
100
/// An owned trie set.
///
/// Membership is split across three partitions of the codepoint space, each
/// stored in its own tree. `as_slice` provides a borrowed view of the same
/// tables.
#[derive(Clone)]
pub struct TrieSetOwned {
    /// One bitvector chunk per `CHUNK_SIZE` codepoints of `[0, 0x800)`.
    tree1_level1: Vec<u64>,
    /// For each chunk of `[0x800, 0x10000)`, an index into `tree2_level2`.
    /// Empty when that partition has no members.
    tree2_level1: Vec<u8>,
    /// The distinct bitvector chunks referenced by `tree2_level1`.
    tree2_level2: Vec<u64>,
    /// For each group of 64 chunks in `[0x10000, 0x110000)`, the index of a
    /// group in `tree3_level2`. Empty when that partition has no members.
    tree3_level1: Vec<u8>,
    /// Distinct groups of 64 indices into `tree3_level3`, stored back to back.
    tree3_level2: Vec<u8>,
    /// The distinct bitvector chunks referenced by `tree3_level2`.
    tree3_level3: Vec<u64>,
}

impl fmt::Debug for TrieSetOwned {
    /// Prints a fixed placeholder instead of the table contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("TrieSetOwned(...)")
    }
}
117
118impl TrieSetOwned {
119    fn new(all: &[bool]) -> Result<TrieSetOwned> {
120        let mut bitvectors = Vec::with_capacity(CHUNKS);
121        for i in 0..CHUNKS {
122            let mut bitvector = 0u64;
123            for j in 0..CHUNK_SIZE {
124                if all[i * CHUNK_SIZE + j] {
125                    bitvector |= 1 << j;
126                }
127            }
128            bitvectors.push(bitvector);
129        }
130
131        let tree1_level1 =
132            bitvectors.iter().cloned().take(0x800 / CHUNK_SIZE).collect();
133
134        let (mut tree2_level1, mut tree2_level2) = compress_postfix_leaves(
135            &bitvectors[0x800 / CHUNK_SIZE..0x10000 / CHUNK_SIZE],
136        )?;
137        if tree2_level2.len() == 1 && tree2_level2[0] == 0 {
138            tree2_level1.clear();
139            tree2_level2.clear();
140        }
141
142        let (mid, mut tree3_level3) = compress_postfix_leaves(
143            &bitvectors[0x10000 / CHUNK_SIZE..0x110000 / CHUNK_SIZE],
144        )?;
145        let (mut tree3_level1, mut tree3_level2) =
146            compress_postfix_mid(&mid, 64)?;
147        if tree3_level3.len() == 1 && tree3_level3[0] == 0 {
148            tree3_level1.clear();
149            tree3_level2.clear();
150            tree3_level3.clear();
151        }
152
153        Ok(TrieSetOwned {
154            tree1_level1,
155            tree2_level1,
156            tree2_level2,
157            tree3_level1,
158            tree3_level2,
159            tree3_level3,
160        })
161    }
162
163    /// Create a new trie set from a set of Unicode scalar values.
164    ///
165    /// This returns an error if a set could not be sufficiently compressed to
166    /// fit into a trie.
167    pub fn from_scalars<I, C>(scalars: I) -> Result<TrieSetOwned>
168    where
169        I: IntoIterator<Item = C>,
170        C: Borrow<char>,
171    {
172        let mut all = vec![false; 0x110000];
173        for s in scalars {
174            all[*s.borrow() as usize] = true;
175        }
176        TrieSetOwned::new(&all)
177    }
178
179    /// Create a new trie set from a set of Unicode scalar values.
180    ///
181    /// This returns an error if a set could not be sufficiently compressed to
182    /// fit into a trie. This also returns an error if any of the given
183    /// codepoints are greater than `0x10FFFF`.
184    pub fn from_codepoints<I, C>(codepoints: I) -> Result<TrieSetOwned>
185    where
186        I: IntoIterator<Item = C>,
187        C: Borrow<u32>,
188    {
189        let mut all = vec![false; 0x110000];
190        for cp in codepoints {
191            let cp = *cp.borrow();
192            if cp > 0x10FFFF {
193                return Err(Error::InvalidCodepoint(cp));
194            }
195            all[cp as usize] = true;
196        }
197        TrieSetOwned::new(&all)
198    }
199
200    /// Return this set as a slice.
201    #[inline(always)]
202    pub fn as_slice(&self) -> TrieSetSlice<'_> {
203        TrieSetSlice {
204            tree1_level1: &self.tree1_level1,
205            tree2_level1: &self.tree2_level1,
206            tree2_level2: &self.tree2_level2,
207            tree3_level1: &self.tree3_level1,
208            tree3_level2: &self.tree3_level2,
209            tree3_level3: &self.tree3_level3,
210        }
211    }
212
213    /// Returns true if and only if the given Unicode scalar value is in this
214    /// set.
215    pub fn contains_char(&self, c: char) -> bool {
216        self.as_slice().contains_char(c)
217    }
218
219    /// Returns true if and only if the given codepoint is in this set.
220    ///
221    /// If the given value exceeds the codepoint range (i.e., it's greater
222    /// than `0x10FFFF`), then this returns false.
223    pub fn contains_u32(&self, cp: u32) -> bool {
224        self.as_slice().contains_u32(cp)
225    }
226}
227
228fn compress_postfix_leaves(chunks: &[u64]) -> Result<(Vec<u8>, Vec<u64>)> {
229    let mut root = vec![];
230    let mut children = vec![];
231    let mut bychild = HashMap::new();
232    for &chunk in chunks {
233        if !bychild.contains_key(&chunk) {
234            let start = bychild.len();
235            if start > ::std::u8::MAX as usize {
236                return Err(Error::GaveUp);
237            }
238            bychild.insert(chunk, start as u8);
239            children.push(chunk);
240        }
241        root.push(bychild[&chunk]);
242    }
243    Ok((root, children))
244}
245
246fn compress_postfix_mid(
247    chunks: &[u8],
248    chunk_size: usize,
249) -> Result<(Vec<u8>, Vec<u8>)> {
250    let mut root = vec![];
251    let mut children = vec![];
252    let mut bychild = HashMap::new();
253    for i in 0..(chunks.len() / chunk_size) {
254        let chunk = &chunks[i * chunk_size..(i + 1) * chunk_size];
255        if !bychild.contains_key(chunk) {
256            let start = bychild.len();
257            if start > ::std::u8::MAX as usize {
258                return Err(Error::GaveUp);
259            }
260            bychild.insert(chunk, start as u8);
261            children.extend(chunk);
262        }
263        root.push(bychild[chunk]);
264    }
265    Ok((root, children))
266}
267
#[cfg(test)]
mod tests {
    use super::{Error, TrieSetOwned};
    use crate::general_category;
    use std::collections::HashSet;

    fn mk(scalars: &[char]) -> TrieSetOwned {
        TrieSetOwned::from_scalars(scalars).unwrap()
    }

    /// Expands inclusive `(start, end)` ranges into every codepoint they
    /// cover.
    fn ranges_to_set(ranges: &[(u32, u32)]) -> Vec<u32> {
        ranges.iter().flat_map(|&(start, end)| start..=end).collect()
    }

    #[test]
    fn set1() {
        let set = mk(&['a']);
        assert!(set.contains_char('a'));
        assert!(!set.contains_char('b'));
        assert!(!set.contains_char('β'));
        assert!(!set.contains_char('☃'));
        assert!(!set.contains_char('😼'));
    }

    #[test]
    fn set_combined() {
        let set = mk(&['a', 'b', 'β', '☃', '😼']);
        assert!(set.contains_char('a'));
        assert!(set.contains_char('b'));
        assert!(set.contains_char('β'));
        assert!(set.contains_char('☃'));
        assert!(set.contains_char('😼'));

        assert!(!set.contains_char('c'));
        assert!(!set.contains_char('θ'));
        assert!(!set.contains_char('⛇'));
        assert!(!set.contains_char('🐲'));
    }

    // An empty set exercises the path where the second and third partitions
    // are stored as empty tables.
    #[test]
    fn empty_set() {
        let set = mk(&[]);
        for &c in &['\0', 'a', 'β', '☃', '😼', '\u{10FFFF}'] {
            assert!(!set.contains_char(c), "unexpected member {:?}", c);
        }
    }

    // The first and last codepoint of each partition must not bleed into
    // neighbouring partitions.
    #[test]
    fn partition_boundaries() {
        let boundaries = [
            '\u{0}',
            '\u{7FF}',
            '\u{800}',
            '\u{FFFF}',
            '\u{10000}',
            '\u{10FFFF}',
        ];
        for &c in &boundaries {
            let set = mk(&[c]);
            for &other in &boundaries {
                assert_eq!(
                    set.contains_char(other),
                    other == c,
                    "set {{{:?}}} queried with {:?}",
                    c,
                    other
                );
            }
        }
    }

    #[test]
    fn invalid_codepoint_is_rejected() {
        match TrieSetOwned::from_codepoints(&[0x61u32, 0x110000]) {
            Err(Error::InvalidCodepoint(0x110000)) => {}
            other => panic!("expected InvalidCodepoint(0x110000), got {:?}", other),
        }
    }

    #[test]
    fn surrogate_codepoints_are_accepted() {
        let set = TrieSetOwned::from_codepoints(&[0xD800u32, 0xDFFF]).unwrap();
        assert!(set.contains_u32(0xD800));
        assert!(set.contains_u32(0xDFFF));
        assert!(!set.contains_u32(0xD801));
    }

    // Basic tests on all of the general category sets. We check that
    // membership is correct on every Unicode codepoint... because we can.

    macro_rules! category_test {
        ($name:ident, $ranges:ident) => {
            #[test]
            fn $name() {
                let set = ranges_to_set(general_category::$ranges);
                let hashset: HashSet<u32> = set.iter().cloned().collect();
                let trie = TrieSetOwned::from_codepoints(&set).unwrap();
                for cp in 0..0x110000 {
                    assert_eq!(
                        trie.contains_u32(cp),
                        hashset.contains(&cp),
                        "membership mismatch for codepoint 0x{:X}",
                        cp
                    );
                }
                // Test that an invalid codepoint is treated correctly.
                assert!(!trie.contains_u32(0x110000));
                assert!(!hashset.contains(&0x110000));
            }
        };
    }

    category_test!(gencat_cased_letter, CASED_LETTER);
    category_test!(gencat_close_punctuation, CLOSE_PUNCTUATION);
    category_test!(gencat_connector_punctuation, CONNECTOR_PUNCTUATION);
    category_test!(gencat_control, CONTROL);
    category_test!(gencat_currency_symbol, CURRENCY_SYMBOL);
    category_test!(gencat_dash_punctuation, DASH_PUNCTUATION);
    category_test!(gencat_decimal_number, DECIMAL_NUMBER);
    category_test!(gencat_enclosing_mark, ENCLOSING_MARK);
    category_test!(gencat_final_punctuation, FINAL_PUNCTUATION);
    category_test!(gencat_format, FORMAT);
    category_test!(gencat_initial_punctuation, INITIAL_PUNCTUATION);
    category_test!(gencat_letter, LETTER);
    category_test!(gencat_letter_number, LETTER_NUMBER);
    category_test!(gencat_line_separator, LINE_SEPARATOR);
    category_test!(gencat_lowercase_letter, LOWERCASE_LETTER);
    category_test!(gencat_math_symbol, MATH_SYMBOL);
    category_test!(gencat_mark, MARK);
    category_test!(gencat_modifier_letter, MODIFIER_LETTER);
    category_test!(gencat_modifier_symbol, MODIFIER_SYMBOL);
    category_test!(gencat_nonspacing_mark, NONSPACING_MARK);
    category_test!(gencat_number, NUMBER);
    category_test!(gencat_open_punctuation, OPEN_PUNCTUATION);
    category_test!(gencat_other, OTHER);
    category_test!(gencat_other_letter, OTHER_LETTER);
    category_test!(gencat_other_number, OTHER_NUMBER);
    category_test!(gencat_other_punctuation, OTHER_PUNCTUATION);
    category_test!(gencat_other_symbol, OTHER_SYMBOL);
    category_test!(gencat_paragraph_separator, PARAGRAPH_SEPARATOR);
    category_test!(gencat_private_use, PRIVATE_USE);
    category_test!(gencat_punctuation, PUNCTUATION);
    category_test!(gencat_separator, SEPARATOR);
    category_test!(gencat_space_separator, SPACE_SEPARATOR);
    category_test!(gencat_spacing_mark, SPACING_MARK);
    category_test!(gencat_surrogate, SURROGATE);
    category_test!(gencat_symbol, SYMBOL);
    category_test!(gencat_titlecase_letter, TITLECASE_LETTER);
    category_test!(gencat_unassigned, UNASSIGNED);
    category_test!(gencat_uppercase_letter, UPPERCASE_LETTER);
}