Skip to main content

rkyv/collections/swiss_table/
table.rs

//! An archived hash table implementation based on Google's high-performance
//! SwissTable hash map.
//!
//! Notable differences from other implementations:
//!
//! - The number of control bytes is rounded up to a maximum group width (16)
//!   instead of the next power of two. This reduces the number of empty buckets
//!   on the wire. Since this collection is immutable after writing, we'll never
//!   benefit from having more buckets than we need.
//! - Because the bucket count is not a power of two, the triangular probing
//!   sequence simply skips any indices larger than the actual size of the
//!   buckets array.
//! - Instead of the final control bytes always being marked EMPTY, the last
//!   control bytes repeat the first few. This helps reduce the number of
//!   lookups when probing at the end of the control bytes.
//! - Because the available SIMD group width may be less than the maximum group
//!   width, each probe reads N groups before striding where N is the maximum
//!   group width divided by the SIMD group width.

20use core::{
21    alloc::Layout,
22    borrow::Borrow,
23    error::Error,
24    fmt,
25    marker::PhantomData,
26    mem::{size_of, MaybeUninit},
27    ptr::{self, null, NonNull},
28    slice::from_raw_parts,
29};
30
31use munge::munge;
32use rancor::{fail, Fallible, ResultExt as _, Source};
33
34use crate::{
35    collections::util::IteratorLengthMismatch,
36    primitive::{ArchivedUsize, FixedUsize},
37    seal::Seal,
38    ser::{Allocator, Writer, WriterExt},
39    simd::{Bitmask, Group, MAX_GROUP_WIDTH},
40    util::SerVec,
41    Archive as _, Place, Portable, RawRelPtr, Serialize,
42};
43
/// A low-level archived SwissTable hash table with explicit hashing.
#[derive(Portable)]
#[cfg_attr(
    feature = "bytecheck",
    derive(bytecheck::CheckBytes),
    bytecheck(verify)
)]
#[rkyv(crate)]
#[repr(C)]
pub struct ArchivedHashTable<T> {
    // Relative pointer to the boundary between the bucket array (which is
    // addressed downward from this location; see `bucket_raw`) and the
    // control bytes (addressed upward; see `control_raw`). Emplaced invalid
    // when the table is empty (see `resolve_from_len`).
    ptr: RawRelPtr,
    // Number of occupied entries.
    len: ArchivedUsize,
    // Total number of buckets; not necessarily a power of two.
    cap: ArchivedUsize,
    _phantom: PhantomData<T>,
}
59
/// Returns the primary hash, used to pick the starting probe position.
///
/// The full 64-bit hash is used unchanged.
#[inline]
fn h1(value: u64) -> u64 {
    value
}
64
/// Returns the secondary hash: the top seven bits of the full hash, which
/// are stored in each control byte for fast group matching.
#[inline]
fn h2(hash: u64) -> u8 {
    // Keep only the uppermost 7 bits of the 64-bit hash.
    (hash >> (u64::BITS - 7)) as u8
}
69
/// State of a triangular probe sequence over the control bytes.
struct ProbeSeq {
    // Current probe position (an index into the control bytes).
    pos: usize,
    // Current stride; grows by `MAX_GROUP_WIDTH` each step, producing a
    // triangular probe sequence.
    stride: usize,
}
74
75impl ProbeSeq {
76    #[inline]
77    fn move_next(&mut self, bucket_mask: usize) {
78        self.stride += MAX_GROUP_WIDTH;
79        self.pos += self.stride;
80        self.pos &= bucket_mask;
81    }
82}
83
impl<T> ArchivedHashTable<T> {
    /// Returns the initial probe state for `hash` in a table with `capacity`
    /// buckets: the full hash reduced modulo the bucket count (which is not
    /// necessarily a power of two).
    fn probe_seq(hash: u64, capacity: usize) -> ProbeSeq {
        ProbeSeq {
            pos: (h1(hash) % capacity as u64) as usize,
            stride: 0,
        }
    }

    /// Returns a pointer to the control byte at `index`.
    ///
    /// # Safety
    ///
    /// - `this` must point to a valid `ArchivedHashTable`
    /// - `index` must be less than `len()`
    ///
    /// NOTE(review): callers such as `get_entry_raw` and `verify` pass probe
    /// positions that can exceed `len()`; the effective requirement appears
    /// to be that `index` is within the control bytes — confirm the stated
    /// bound.
    unsafe fn control_raw(this: *mut Self, index: usize) -> *const u8 {
        debug_assert!(unsafe { !(*this).is_empty() });

        // SAFETY: As an invariant of `ArchivedHashTable`, if `self` is not
        // empty then `self.ptr` is a valid relative pointer. Since `index` is
        // at least 0 and strictly less than `len()`, this table must not be
        // empty.
        let ptr =
            unsafe { RawRelPtr::as_ptr_raw(ptr::addr_of_mut!((*this).ptr)) };
        // SAFETY: The caller has guaranteed that `index` is less than `len()`,
        // and the first `len()` bytes following `ptr` are the control bytes of
        // the hash table.
        unsafe { ptr.cast::<u8>().add(index) }
    }

    /// Returns a pointer to the bucket entry at `index`.
    ///
    /// Entries are stored in reverse order immediately *below* the location
    /// `ptr` targets, so bucket `index` lives `index + 1` entries before it
    /// (see the reverse-order write in `serialize_from_iter`).
    ///
    /// # Safety
    ///
    /// - `this` must point to a valid `ArchivedHashTable`
    /// - `index` must be less than `len()`
    unsafe fn bucket_raw(this: *mut Self, index: usize) -> NonNull<T> {
        unsafe {
            NonNull::new_unchecked(
                RawRelPtr::as_ptr_raw(ptr::addr_of_mut!((*this).ptr))
                    .cast::<T>()
                    .sub(index + 1),
            )
        }
    }

    /// Returns the mask used by the triangular probe sequence: `capacity`
    /// rounded up to the next power of two, minus one.
    fn bucket_mask(capacity: usize) -> usize {
        capacity.checked_next_power_of_two().unwrap() - 1
    }

    /// Finds the entry whose control byte matches `hash` and which satisfies
    /// `cmp`, if any.
    ///
    /// # Safety
    ///
    /// `this` must point to a valid `ArchivedHashTable`
    unsafe fn get_entry_raw<C>(
        this: *mut Self,
        hash: u64,
        cmp: C,
    ) -> Option<NonNull<T>>
    where
        C: Fn(&T) -> bool,
    {
        let is_empty = unsafe { (*this).is_empty() };
        if is_empty {
            return None;
        }

        let capacity = unsafe { (*this).capacity() };
        let probe_cap = Self::probe_cap(capacity);
        let control_count = Self::control_count(probe_cap);

        let h2_hash = h2(hash);
        let mut probe_seq = Self::probe_seq(hash, capacity);

        let bucket_mask = Self::bucket_mask(control_count);

        loop {
            let mut any_empty = false;

            // Each probe step inspects one maximum-width group, read as N
            // SIMD-width groups (see the module docs).
            for i in 0..MAX_GROUP_WIDTH / Group::WIDTH {
                let pos = probe_seq.pos + i * Group::WIDTH;

                let group =
                    unsafe { Group::read(Self::control_raw(this, pos)) };

                for bit in group.match_byte(h2_hash) {
                    // Control bytes at or past `capacity` are wraparound
                    // duplicates of the first few, so reduce the matched
                    // position modulo the bucket count.
                    let index = (pos + bit) % capacity;
                    let bucket_ptr = unsafe { Self::bucket_raw(this, index) };
                    let bucket = unsafe { bucket_ptr.as_ref() };

                    // Opt: These can be marked as likely true on nightly.
                    if cmp(bucket) {
                        return Some(bucket_ptr);
                    }
                }

                // Opt: These can be marked as likely true on nightly.
                any_empty = any_empty || group.match_empty().any_bit_set();
            }

            // An EMPTY byte anywhere in the probed groups means the sought
            // entry was never inserted along this probe sequence.
            if any_empty {
                return None;
            }

            // Advance the triangular probe, skipping positions that fall
            // outside the actual control bytes (the mask is a power of two
            // but the table size is not).
            loop {
                probe_seq.move_next(bucket_mask);
                if probe_seq.pos < probe_cap {
                    break;
                }
            }
        }
    }

    /// Returns a reference to the entry with the given hash which satisfies
    /// the comparison function, if any.
    pub fn get_with<C>(&self, hash: u64, cmp: C) -> Option<&T>
    where
        C: Fn(&T) -> bool,
    {
        let this = (self as *const Self).cast_mut();
        // SAFETY: `this` is derived from `self` and so points to a valid
        // `ArchivedHashTable`.
        let ptr = unsafe { Self::get_entry_raw(this, hash, |e| cmp(e))? };
        Some(unsafe { ptr.as_ref() })
    }

    /// Returns a sealed mutable reference to the entry with the given hash
    /// which satisfies the comparison function, if any.
    pub fn get_seal_with<C>(
        this: Seal<'_, Self>,
        hash: u64,
        cmp: C,
    ) -> Option<Seal<'_, T>>
    where
        C: Fn(&T) -> bool,
    {
        let mut ptr = unsafe {
            Self::get_entry_raw(this.unseal_unchecked(), hash, |e| cmp(e))?
        };
        Some(Seal::new(unsafe { ptr.as_mut() }))
    }

    /// Returns whether the hash table is empty.
    pub const fn is_empty(&self) -> bool {
        self.len.to_native() == 0
    }

    /// Returns the number of elements in the hash table.
    pub const fn len(&self) -> usize {
        self.len.to_native() as usize
    }

    /// Returns the total capacity of the hash table.
    pub fn capacity(&self) -> usize {
        self.cap.to_native() as usize
    }

    /// Returns an iterator over the control bytes, primed with the first
    /// group already loaded.
    ///
    /// # Safety
    ///
    /// This hash table must not be empty.
    unsafe fn control_iter(this: *mut Self) -> ControlIter {
        ControlIter {
            current_mask: unsafe {
                Group::read(Self::control_raw(this, 0)).match_full()
            },
            next_group: unsafe { Self::control_raw(this, Group::WIDTH) },
        }
    }

    /// Returns an iterator over the entry pointers in the hash table.
    pub fn raw_iter(&self) -> RawIter<T> {
        if self.is_empty() {
            RawIter::empty()
        } else {
            let this = (self as *const Self).cast_mut();
            RawIter {
                // SAFETY: We have checked that `self` is not empty.
                controls: unsafe { Self::control_iter(this) },
                entries: unsafe {
                    NonNull::new_unchecked(self.ptr.as_ptr().cast_mut().cast())
                },
                items_left: self.len(),
            }
        }
    }

    /// Returns a sealed iterator over the entry pointers in the hash table.
    pub fn raw_iter_seal(mut this: Seal<'_, Self>) -> RawIter<T> {
        if this.is_empty() {
            RawIter::empty()
        } else {
            // SAFETY: We have checked that `this` is not empty.
            let controls =
                unsafe { Self::control_iter(this.as_mut().unseal_unchecked()) };
            let items_left = this.len();
            munge!(let Self { ptr, .. } = this);
            RawIter {
                controls,
                entries: unsafe {
                    NonNull::new_unchecked(RawRelPtr::as_mut_ptr(ptr).cast())
                },
                items_left,
            }
        }
    }

    /// Returns the bucket count for `len` items at the given load factor
    /// (a fraction `load_factor.0 / load_factor.1`).
    ///
    /// The result is always at least `len + 1` for nonzero `len`, so a full
    /// table always contains at least one EMPTY control byte.
    fn capacity_from_len(len: usize, load_factor: (usize, usize)) -> usize {
        if len == 0 {
            0
        } else {
            usize::max(len * load_factor.1 / load_factor.0, len + 1)
        }
    }

    /// Returns the bucket count rounded up to a multiple of the maximum
    /// group width; probe positions always stay below this bound.
    fn probe_cap(capacity: usize) -> usize {
        capacity.next_multiple_of(MAX_GROUP_WIDTH)
    }

    /// Returns the total number of control bytes: `probe_cap` plus enough
    /// extra bytes that a full group read starting at any probe position
    /// stays in bounds.
    fn control_count(probe_cap: usize) -> usize {
        probe_cap + MAX_GROUP_WIDTH - 1
    }

    /// Returns the combined layout of the bucket array followed by the
    /// control bytes, along with the byte offset of the control bytes
    /// within that layout.
    #[allow(dead_code)]
    fn memory_layout<E: Source>(
        capacity: usize,
        control_count: usize,
    ) -> Result<(Layout, usize), E> {
        let buckets_layout = Layout::array::<T>(capacity).into_error()?;
        let control_layout = Layout::array::<u8>(control_count).into_error()?;
        buckets_layout.extend(control_layout).into_error()
    }

    /// Serializes an iterator of items as a hash table.
    pub fn serialize_from_iter<I, U, H, S>(
        items: I,
        hashes: H,
        load_factor: (usize, usize),
        serializer: &mut S,
    ) -> Result<HashTableResolver, S::Error>
    where
        I: Clone + ExactSizeIterator,
        I::Item: Borrow<U>,
        U: Serialize<S, Archived = T>,
        H: ExactSizeIterator<Item = u64>,
        S: Fallible + Writer + Allocator + ?Sized,
        S::Error: Source,
    {
        // Error raised when the load factor is zero or greater than one.
        #[derive(Debug)]
        struct InvalidLoadFactor {
            numerator: usize,
            denominator: usize,
        }

        impl fmt::Display for InvalidLoadFactor {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(
                    f,
                    "invalid load factor {} / {}, load factor must be a \
                     fraction in the range (0, 1]",
                    self.numerator, self.denominator
                )
            }
        }

        impl Error for InvalidLoadFactor {}

        if load_factor.0 == 0
            || load_factor.1 == 0
            || load_factor.0 > load_factor.1
        {
            fail!(InvalidLoadFactor {
                numerator: load_factor.0,
                denominator: load_factor.1,
            });
        }

        let len = items.len();

        if len == 0 {
            // Consume the iterator to double-check the reported length.
            let count = items.count();
            if count != 0 {
                fail!(IteratorLengthMismatch {
                    expected: 0,
                    actual: count,
                });
            }

            // An empty table stores no data; its relative pointer is
            // emplaced invalid during resolution.
            return Ok(HashTableResolver { pos: 0 });
        }

        let capacity = Self::capacity_from_len(len, load_factor);
        let probe_cap = Self::probe_cap(capacity);
        let control_count = Self::control_count(probe_cap);

        // Determine hash locations for all items
        SerVec::with_capacity(
            serializer,
            capacity,
            |ordered_items, serializer| {
                // Start with every bucket vacant.
                for _ in 0..capacity {
                    unsafe {
                        ordered_items.push_unchecked(None);
                    }
                }

                SerVec::<u8>::with_capacity(
                    serializer,
                    control_count,
                    |control_bytes, serializer| {
                        // Initialize all control bytes to EMPTY (0xFF)
                        unsafe {
                            control_bytes
                                .as_mut_ptr()
                                .write_bytes(0xff, control_bytes.capacity());
                            control_bytes.set_len(control_bytes.capacity());
                        }

                        let bucket_mask = Self::bucket_mask(control_count);

                        // Place each item in the first empty slot along its
                        // probe sequence, mirroring the lookup algorithm in
                        // `get_entry_raw` exactly.
                        for (item, hash) in items.zip(hashes) {
                            let h2_hash = h2(hash);
                            let mut probe_seq = Self::probe_seq(hash, capacity);

                            'insert: loop {
                                for i in 0..MAX_GROUP_WIDTH / Group::WIDTH {
                                    let pos = probe_seq.pos + i * Group::WIDTH;
                                    let group = unsafe {
                                        Group::read(
                                            control_bytes.as_ptr().add(pos),
                                        )
                                    };

                                    if let Some(bit) =
                                        group.match_empty().lowest_set_bit()
                                    {
                                        let index = (pos + bit) % capacity;

                                        // Update control byte
                                        control_bytes[index] = h2_hash;
                                        // If it's near the beginning of the
                                        // control bytes,
                                        // update the wraparound control byte
                                        if index < (control_count - capacity) {
                                            control_bytes[capacity + index] =
                                                h2_hash;
                                        }

                                        ordered_items[index] = Some(item);
                                        break 'insert;
                                    }
                                }

                                // Skip probe positions that fall outside the
                                // control bytes.
                                loop {
                                    probe_seq.move_next(bucket_mask);
                                    if probe_seq.pos < probe_cap {
                                        break;
                                    }
                                }
                            }
                        }

                        // Zeroed bytes written in place of unoccupied
                        // buckets so the bucket array has a fixed stride.
                        let mut zeros = MaybeUninit::<T>::uninit();
                        unsafe {
                            zeros.as_mut_ptr().write_bytes(0, 1);
                        }
                        let zeros = unsafe {
                            from_raw_parts(
                                zeros.as_ptr().cast::<u8>(),
                                size_of::<T>(),
                            )
                        };
                        SerVec::with_capacity(
                            serializer,
                            len,
                            |resolvers, serializer| {
                                // Serialize the occupied items in bucket
                                // order, collecting their resolvers.
                                for item in ordered_items
                                    .iter()
                                    .filter_map(|x| x.as_ref())
                                {
                                    resolvers.push(
                                        item.borrow().serialize(serializer)?,
                                    );
                                }

                                serializer.align_for::<T>()?;

                                // Entries are emitted in *reverse* bucket
                                // order so that bucket `i` ends up `i + 1`
                                // entries before the control bytes (see
                                // `bucket_raw`).
                                let mut resolvers = resolvers.drain().rev();
                                for item in ordered_items.iter().rev() {
                                    if let Some(item) = item {
                                        unsafe {
                                            serializer.resolve_aligned(
                                                item.borrow(),
                                                resolvers.next().unwrap(),
                                            )?;
                                        }
                                    } else {
                                        serializer.write(zeros)?;
                                    }
                                }

                                // The table's relative pointer targets the
                                // start of the control bytes.
                                let pos = serializer.pos();
                                serializer.write(control_bytes)?;

                                Ok(HashTableResolver {
                                    pos: pos as FixedUsize,
                                })
                            },
                        )?
                    },
                )?
            },
        )?
    }

    /// Resolves an archived hash table from a given length and parameters.
    ///
    /// `len` and `load_factor` must match the values used when serializing,
    /// since the capacity is recomputed from them here.
    pub fn resolve_from_len(
        len: usize,
        load_factor: (usize, usize),
        resolver: HashTableResolver,
        out: Place<Self>,
    ) {
        munge!(let Self { ptr, len: out_len, cap, _phantom: _ } = out);

        if len == 0 {
            // Empty tables store no out-of-line data.
            RawRelPtr::emplace_invalid(ptr);
        } else {
            RawRelPtr::emplace(resolver.pos as usize, ptr);
        }

        len.resolve((), out_len);

        let capacity = Self::capacity_from_len(len, load_factor);
        capacity.resolve((), cap);

        // PhantomData doesn't need to be initialized
    }
}
511
/// The resolver for [`ArchivedHashTable`].
pub struct HashTableResolver {
    // Archive position of the control bytes — the location the table's
    // relative pointer is emplaced against during resolution.
    pos: FixedUsize,
}
516
/// An iterator over groups of control bytes, yielding the in-group
/// positions of full buckets.
struct ControlIter {
    // Bitmask of full buckets remaining in the current group.
    current_mask: Bitmask,
    // Pointer to the first byte of the next group to read.
    next_group: *const u8,
}
521
// SAFETY(review): `ControlIter` only holds a raw pointer used for reads of
// archived control bytes and carries no thread-affine state — confirm the
// pointed-to data is never mutated while shared.
unsafe impl Send for ControlIter {}
unsafe impl Sync for ControlIter {}
524
525impl ControlIter {
526    fn none() -> Self {
527        Self {
528            current_mask: Bitmask::EMPTY,
529            next_group: null(),
530        }
531    }
532
533    #[inline]
534    fn next_full(&mut self) -> Option<usize> {
535        let bit = self.current_mask.lowest_set_bit()?;
536        self.current_mask = self.current_mask.remove_lowest_bit();
537        Some(bit)
538    }
539
540    #[inline]
541    fn move_next(&mut self) {
542        self.current_mask =
543            unsafe { Group::read(self.next_group).match_full() };
544        self.next_group = unsafe { self.next_group.add(Group::WIDTH) };
545    }
546}
547
/// An iterator over the entry pointers of an [`ArchivedHashTable`].
pub struct RawIter<T> {
    // Yields the in-group bit positions of full buckets.
    controls: ControlIter,
    // Entries are addressed *downward* from this pointer; it moves down by
    // one group's worth of entries each time `controls` advances a group.
    entries: NonNull<T>,
    // Number of full buckets not yet yielded.
    items_left: usize,
}
554
555impl<T> RawIter<T> {
556    /// Returns a raw iterator which yields no elements.
557    pub fn empty() -> Self {
558        Self {
559            controls: ControlIter::none(),
560            entries: NonNull::dangling(),
561            items_left: 0,
562        }
563    }
564}
565
impl<T> Iterator for RawIter<T> {
    type Item = NonNull<T>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.items_left == 0 {
            None
        } else {
            // Scan for the next full bucket one group at a time. The entry
            // pointer moves *down* by `Group::WIDTH` entries per group
            // because entries are stored in reverse order below the control
            // bytes.
            let bit = loop {
                if let Some(bit) = self.controls.next_full() {
                    break bit;
                }
                self.controls.move_next();
                self.entries = unsafe {
                    NonNull::new_unchecked(
                        self.entries.as_ptr().sub(Group::WIDTH),
                    )
                };
            };
            self.items_left -= 1;
            // The entry for in-group bit `bit` lives `bit + 1` entries below
            // the current entries pointer (matching `bucket_raw`).
            let entry = unsafe {
                NonNull::new_unchecked(self.entries.as_ptr().sub(bit + 1))
            };
            Some(entry)
        }
    }
}
592
593impl<T> ExactSizeIterator for RawIter<T> {
594    fn len(&self) -> usize {
595        self.items_left
596    }
597}
598
#[cfg(feature = "bytecheck")]
mod verify {
    use core::{error::Error, fmt};

    use bytecheck::{CheckBytes, Verify};
    use rancor::{fail, Fallible, Source};

    use super::ArchivedHashTable;
    use crate::{
        simd::Group,
        validation::{ArchiveContext, ArchiveContextExt as _},
    };

    // Error raised when a table's length is not strictly less than its
    // capacity.
    #[derive(Debug)]
    struct InvalidLength {
        len: usize,
        cap: usize,
    }

    impl fmt::Display for InvalidLength {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(
                f,
                "hash table length must be strictly less than its capacity \
                 (length: {}, capacity: {})",
                self.len, self.cap,
            )
        }
    }

    impl Error for InvalidLength {}

    // Error raised when a wraparound control byte does not mirror the byte
    // at the start of the control bytes.
    #[derive(Debug)]
    struct UnwrappedControlByte {
        index: usize,
    }

    impl fmt::Display for UnwrappedControlByte {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "unwrapped control byte at index {}", self.index,)
        }
    }

    impl Error for UnwrappedControlByte {}

    unsafe impl<C, T> Verify<C> for ArchivedHashTable<T>
    where
        C: Fallible + ArchiveContext + ?Sized,
        C::Error: Source,
        T: CheckBytes<C>,
    {
        fn verify(&self, context: &mut C) -> Result<(), C::Error> {
            let len = self.len();
            let cap = self.capacity();

            // Empty tables store no out-of-line data, so there is nothing
            // further to check.
            if len == 0 && cap == 0 {
                return Ok(());
            }

            // Serialization always produces a capacity of at least
            // `len + 1` (see `capacity_from_len`).
            if len >= cap {
                fail!(InvalidLength { len, cap });
            }

            // Check memory allocation
            let probe_cap = Self::probe_cap(cap);
            let control_count = Self::control_count(probe_cap);
            let (layout, control_offset) =
                Self::memory_layout(cap, control_count)?;
            // The table's pointer targets the control bytes, which follow
            // the bucket array, so the allocation begins `control_offset`
            // bytes earlier.
            let ptr = self
                .ptr
                .as_ptr_wrapping()
                .cast::<u8>()
                .wrapping_sub(control_offset);

            context.in_subtree_raw(ptr, layout, |context| {
                // Check each non-empty bucket

                let this = (self as *const Self).cast_mut();
                // SAFETY: We have checked that `self` is not empty.
                let mut controls = unsafe { Self::control_iter(this) };
                let mut base_index = 0;
                'outer: while base_index < cap {
                    while let Some(bit) = controls.next_full() {
                        let index = base_index + bit;
                        // Full control bytes at or past `cap` are wraparound
                        // duplicates, not additional buckets.
                        if index >= cap {
                            break 'outer;
                        }

                        unsafe {
                            T::check_bytes(
                                Self::bucket_raw(this, index).as_ptr(),
                                context,
                            )?;
                        }
                    }

                    controls.move_next();
                    base_index += Group::WIDTH;
                }

                // Verify that wrapped bytes are set correctly
                // NOTE(review): when `control_count - cap < cap` this range
                // is empty and some wrapped bytes go unchecked — confirm
                // the intended upper bound.
                for i in cap..usize::min(2 * cap, control_count - cap) {
                    let byte = unsafe { *Self::control_raw(this, i) };
                    let wrapped = unsafe { *Self::control_raw(this, i % cap) };
                    if wrapped != byte {
                        fail!(UnwrappedControlByte { index: i })
                    }
                }

                Ok(())
            })
        }
    }
}