// fuchsia_rcu_collections/rcu_raw_hash_map.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![warn(unsafe_op_in_unsafe_fn)]
6
7use crate::rcu_array::RcuArray;
8use crate::rcu_intrusive_list::{
9    Link, RcuIntrusiveList, RcuIntrusiveListCursor, RcuListAdapter, rcu_list_adapter,
10};
11use crate::rcu_list::RcuList;
12use fuchsia_rcu::RcuReadScope;
13use std::borrow::Borrow;
14use std::hash::{BuildHasher, Hash, Hasher};
15use std::sync::atomic::{AtomicUsize, Ordering};
16
17/// The initial capacity of the hash map.
18const INITIAL_CAPACITY: usize = 16;
19
/// An entry in the hash table.
///
/// Each entry is threaded onto two intrusive lists at once: its bucket's
/// collision chain and the map-wide insertion-order chain.
#[derive(Debug)]
struct Entry<K, V> {
    /// The key for this entry.
    key: K,

    /// The value for this entry.
    value: V,

    /// The link to the next node in the collision chain for this entry's bucket.
    collision_chain: Link,

    /// The link to the next node in the map-wide insertion-order chain.
    insertion_chain: Link,
}
35
36impl<K, V> Entry<K, V> {
37    /// Create a new hash table entry.
38    fn new(key: K, value: V) -> Self {
39        Self {
40            key,
41            value,
42            collision_chain: Default::default(),
43            insertion_chain: Default::default(),
44        }
45    }
46}
47
/// An `RcuListAdapter` that locates the `collision_chain` link inside an
/// `Entry`, letting a bucket's `RcuList` thread its entries together.
#[derive(Debug)]
struct CollisionAdapter;

impl<K, V> RcuListAdapter<Entry<K, V>> for CollisionAdapter {
    rcu_list_adapter!(Entry<K, V>, collision_chain);
}
55
/// An `RcuListAdapter` that locates the `insertion_chain` link inside an
/// `Entry`, letting the map-wide insertion-order list thread entries together.
#[derive(Debug)]
struct InsertionAdapter;

impl<K, V> RcuListAdapter<Entry<K, V>> for InsertionAdapter {
    rcu_list_adapter!(Entry<K, V>, insertion_chain);
}
63
/// The result of inserting an entry into the map.
pub enum InsertionResult<V> {
    /// The entry was inserted under a previously-absent key.
    ///
    /// The number of entries in the map (including the new one) is returned.
    Inserted(usize),

    /// The key was already present, so its entry was replaced.
    ///
    /// The old value is returned.
    Updated(V),
}
76
/// The bucket in the hash table.
///
/// Each bucket is an RCU linked list holding the collision chain for the keys
/// that hash to it.
type Bucket<K, V> = RcuList<Entry<K, V>, CollisionAdapter>;
81
/// A hash map that uses read-copy-update (RCU) to manage concurrent accesses.
///
/// Lookups take an `RcuReadScope`; mutating operations (`insert`, `remove`) are
/// `unsafe` and require external synchronization to exclude concurrent writers.
///
/// By default, this map uses `rapidhash::RapidBuildHasher`, which provides high performance.
/// However, if this map holds keys which may be attacker-controlled, consider using
/// `std::collections::hash_map::RandomState` instead.
pub struct RcuRawHashMap<K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// The table of buckets. Replaced wholesale by `grow` when the map resizes.
    table: RcuArray<Bucket<K, V>>,

    /// The number of entries in the map.
    num_entries: AtomicUsize,

    /// The entries in this map in the order they were inserted.
    insertion_chain: RcuIntrusiveList<Entry<K, V>, InsertionAdapter>,

    /// The build hasher used to map keys to buckets.
    hash_builder: S,
}
105
106impl<K, V> Default for RcuRawHashMap<K, V, rapidhash::RapidBuildHasher>
107where
108    K: Eq + Hash + Clone + Send + Sync + 'static,
109    V: Clone + Send + Sync + 'static,
110{
111    fn default() -> Self {
112        Self::with_capacity_and_hasher(0, rapidhash::RapidBuildHasher::default())
113    }
114}
115
116impl<K, V> RcuRawHashMap<K, V, rapidhash::RapidBuildHasher>
117where
118    K: Eq + Hash + Clone + Send + Sync + 'static,
119    V: Clone + Send + Sync + 'static,
120{
121    /// Creates a new hash map with the given capacity.
122    pub fn with_capacity(capacity: usize) -> Self {
123        Self::with_capacity_and_hasher(capacity, rapidhash::RapidBuildHasher::default())
124    }
125}
126
127impl<K, V, S> RcuRawHashMap<K, V, S>
128where
129    K: Eq + Hash + Clone + Send + Sync + 'static,
130    V: Clone + Send + Sync + 'static,
131    S: BuildHasher + Send + Sync + 'static,
132{
133    /// Creates a new hash map with the given capacity and hasher.
134    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
135        let mut table = Vec::new();
136        table.resize_with((capacity + 1) / 2, Default::default);
137        Self {
138            table: RcuArray::from(table),
139            num_entries: AtomicUsize::new(0),
140            insertion_chain: Default::default(),
141            hash_builder,
142        }
143    }
144
    /// Creates a new hash map with the given hasher and no preallocated buckets.
    ///
    /// With zero capacity, the bucket table is allocated on the first insert.
    pub fn with_hasher(hash_builder: S) -> Self {
        Self::with_capacity_and_hasher(0, hash_builder)
    }
149
150    /// Returns the hash of the key as a u64.
151    fn hash_key<Q>(&self, key: &Q) -> u64
152    where
153        Q: ?Sized + Hash,
154    {
155        let mut hasher = self.hash_builder.build_hasher();
156        key.hash(&mut hasher);
157        hasher.finish()
158    }
159
160    /// Returns the bucket for the given key in the given table.
161    fn get_bucket<'a, Q>(&self, table: &'a [Bucket<K, V>], key: &Q) -> &'a Bucket<K, V>
162    where
163        K: Borrow<Q>,
164        Q: ?Sized + Hash,
165    {
166        let hash = self.hash_key(key);
167        let index = hash as usize % table.len();
168        &table[index]
169    }
170
171    /// Returns a reference to the bucket for the given key.
172    fn read_bucket<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> Option<&'a Bucket<K, V>>
173    where
174        K: Borrow<Q>,
175        Q: ?Sized + Hash,
176    {
177        let table = self.table.as_slice(scope);
178        if table.is_empty() {
179            return None;
180        }
181        Some(self.get_bucket(table, key))
182    }
183
184    /// Returns a reference to the value corresponding to the key.
185    ///
186    /// Another thread running concurrently might see a different value for the object.
187    pub fn get<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> Option<&'a V>
188    where
189        K: Borrow<Q>,
190        Q: ?Sized + Hash + Eq,
191    {
192        let bucket = self.read_bucket(scope, key)?;
193        bucket.iter(scope).find(|entry| entry.key.borrow() == key).map(|entry| &entry.value)
194    }
195
196    /// Returns the number of entries in the map.
197    ///
198    /// The length can change concurrently with this call.
199    pub fn len(&self) -> usize {
200        self.num_entries.load(Ordering::Relaxed)
201    }
202
    /// Inserts a key-value pair into the map.
    ///
    /// If the map did not have this key present, `InsertionResult::Inserted` is
    /// returned carrying the new number of entries.
    ///
    /// If the map did have this key present, the entry is replaced (and moved to
    /// the back of the insertion order), and the old value is returned as
    /// `InsertionResult::Updated`.
    ///
    /// Concurrent readers might not see the inserted value until the RCU state machine has made
    /// sufficient progress to ensure that no concurrent readers are holding read guards.
    ///
    /// # Safety
    ///
    /// Requires external synchronization to exclude concurrent writers.
    pub unsafe fn insert(&self, scope: &RcuReadScope, key: K, value: V) -> InsertionResult<V> {
        let mut table = self.table.as_slice(scope);
        if self.needs_to_grow(table) {
            // SAFETY: Our caller is required to use external synchronization to exclude concurrent
            // writers.
            table = unsafe { self.grow(&scope, table) };
        }
        let bucket = self.get_bucket(table, &key);
        let mut cursor = bucket.cursor(&scope);
        while let Some(entry) = cursor.current() {
            if entry.key == key {
                let old_value = entry.value.clone();
                // SAFETY: We have exclusive access to the bucket because we have exclusive access
                // to the table.
                unsafe {
                    // Replace the whole entry (remove + push) rather than mutating it in place;
                    // readers may still hold references to the old entry.
                    let removed_entry = cursor.remove();
                    self.insertion_chain.remove(&scope, removed_entry);
                    let entry = bucket.push_front(&scope, Entry::new(key, value));
                    self.insertion_chain.push_back(&scope, entry);
                };
                // Note: the entry count is unchanged on an update.
                return InsertionResult::Updated(old_value);
            }
            cursor.advance();
        }

        // SAFETY: We have exclusive access to the bucket because we have exclusive access to the
        // table.
        unsafe {
            let entry = bucket.push_front(&scope, Entry::new(key, value));
            self.insertion_chain.push_back(&scope, entry);
        }
        let count = self.num_entries.fetch_add(1, Ordering::Relaxed);
        InsertionResult::Inserted(count + 1)
    }
250
    /// Removes a key from the map, returning the value at the key if the key
    /// was previously in the map.
    ///
    /// Concurrent readers might see the removed value until the RCU state machine has made
    /// sufficient progress to ensure that no concurrent readers are holding read guards.
    ///
    /// # Safety
    ///
    /// Requires external synchronization to exclude concurrent writers.
    pub unsafe fn remove<Q>(&self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        // Unlike `insert`, this creates its own read scope; the removed value is
        // cloned out, so no reference escapes the scope.
        let scope = RcuReadScope::new();
        let bucket = self.read_bucket(&scope, key)?;
        let mut cursor = bucket.cursor(&scope);
        while let Some(entry) = cursor.current() {
            if entry.key.borrow() == key {
                // Clone the value before unlinking; readers may still hold the entry.
                let old_value = entry.value.clone();
                // SAFETY: We have exclusive access to the bucket because we have exclusive access
                // to the table.
                unsafe {
                    // Unlink from both the collision chain and the insertion-order chain.
                    let removed_entry = cursor.remove();
                    self.insertion_chain.remove(&scope, removed_entry);
                };
                self.num_entries.fetch_sub(1, Ordering::Relaxed);
                return Some(old_value);
            }
            cursor.advance();
        }
        None
    }
284
285    /// Whether the given table needs to grow to reduce the number of collisions.
286    fn needs_to_grow(&self, table: &[Bucket<K, V>]) -> bool {
287        table.is_empty() || self.num_entries.load(Ordering::Relaxed) > table.len() * 2
288    }
289
    /// Grows the table to reduce the number of collisions.
    ///
    /// Rebuilds the bucket table at double the size (or `INITIAL_CAPACITY` when
    /// it was empty) and rehashes every entry, preserving insertion order.
    ///
    /// Returns a reference to the new table. Callers should be sure to update the table reference
    /// they are using to the returned value.
    ///
    /// # Safety
    ///
    /// Requires external synchronization to exclude concurrent writers.
    #[must_use]
    unsafe fn grow<'a>(
        &self,
        scope: &'a RcuReadScope,
        old_table: &[Bucket<K, V>],
    ) -> &'a [Bucket<K, V>] {
        let new_size = if old_table.is_empty() { INITIAL_CAPACITY } else { old_table.len() * 2 };
        let mut new_table = Vec::new();
        let new_insertion_chain = RcuIntrusiveList::default();
        new_table.resize_with(new_size, Default::default);

        // Walk the insertion chain (not the old buckets) so that the rebuilt
        // map keeps the original insertion order.
        for entry in self.insertion_chain.iter(scope) {
            let bucket = self.get_bucket(&new_table, &entry.key);
            let key = entry.key.clone();
            let value = entry.value.clone();
            // SAFETY: We have exclusive access to `new_table` because we just created it.
            unsafe {
                let entry = bucket.push_front(&scope, Entry::new(key, value));
                new_insertion_chain.push_back(&scope, entry);
            };
        }

        // Publish the new table first, then swap in the matching insertion chain.
        self.table.update(new_table);
        // SAFETY: Our caller promises to exclude concurrent writers.
        unsafe {
            self.insertion_chain.update(&scope, new_insertion_chain);
        }
        self.table.as_slice(scope)
    }
327
328    /// Returns a cursor that can be used to traverse and modify the map.
329    ///
330    /// The cursor iterates through the map in insertion order.
331    pub fn cursor<'a>(&'a self, scope: &'a RcuReadScope) -> RcuRawHashMapCursor<'a, K, V, S> {
332        RcuRawHashMapCursor { inner: self.insertion_chain.cursor(scope), map: self }
333    }
334
335    /// Returns an iterator over the keys in the map.
336    pub fn keys<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = &'a K> {
337        self.insertion_chain.iter(scope).map(|entry| &entry.key)
338    }
339}
340
341// TODO(https://fxbug.dev/482462174): switch back to #[derive(Debug)]
342impl<K, V, S> std::fmt::Debug for RcuRawHashMap<K, V, S>
343where
344    K: Eq + Hash + Clone + Send + Sync + 'static + std::fmt::Debug,
345    V: Clone + Send + Sync + 'static + std::fmt::Debug,
346    S: std::hash::BuildHasher + Send + Sync + 'static,
347{
348    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
349        f.debug_struct("RcuRawHashMap")
350            .field("table", &self.table)
351            .field("num_entries", &self.num_entries)
352            .field("insertion_chain", &self.insertion_chain)
353            .field("hash_builder", &std::any::type_name::<S>())
354            .finish_non_exhaustive()
355    }
356}
357
/// A cursor for traversing and modifying an `RcuRawHashMap`.
///
/// The cursor walks the map's insertion-order chain.
///
/// See `RcuRawHashMap::cursor` for more information.
pub struct RcuRawHashMapCursor<'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// The current position within the map's insertion-order list.
    inner: RcuIntrusiveListCursor<'a, Entry<K, V>, InsertionAdapter>,
    /// The map being traversed; needed so `remove` can unlink entries from it.
    map: &'a RcuRawHashMap<K, V, S>,
}
370
371impl<'a, K, V, S> RcuRawHashMapCursor<'a, K, V, S>
372where
373    K: Eq + Hash + Clone + Send + Sync + 'static,
374    V: Clone + Send + Sync + 'static,
375    S: BuildHasher + Send + Sync + 'static,
376{
377    /// Returns the element at the current cursor position.
378    pub fn current(&self) -> Option<(&'a K, &'a V)> {
379        self.inner.current().map(|entry| (&entry.key, &entry.value))
380    }
381
382    /// Advances the cursor to the next element in the list.
383    pub fn advance(&mut self) {
384        self.inner.advance()
385    }
386
387    /// Removes the element at the current cursor position.
388    ///
389    /// After calling `remove`, the cursor will be positioned at the next element in the list.
390    ///
391    /// Concurrent readers may continue to see this entry in the list until the RCU state machine
392    /// has made sufficient progress to ensure that no concurrent readers are holding read guards.
393    ///
394    /// # Safety
395    ///
396    /// Requires external synchronization to exclude concurrent writers.
397    pub unsafe fn remove(&mut self) -> Option<V> {
398        if let Some((key, _)) = self.current() {
399            self.advance();
400            // SAFETY: The caller promises to exclude concurrent writers.
401            unsafe { self.map.remove(key) }
402        } else {
403            None
404        }
405    }
406}
407
408#[cfg(test)]
409mod tests {
410    use super::*;
411    use fuchsia_rcu::rcu_synchronize;
412
413    #[test]
414    fn test_rcu_hash_map_custom_hasher() {
415        use std::collections::hash_map::DefaultHasher;
416        use std::hash::BuildHasherDefault;
417        let hasher = BuildHasherDefault::<DefaultHasher>::default();
418        let map = RcuRawHashMap::with_capacity_and_hasher(10, hasher);
419        let scope = RcuReadScope::new();
420        unsafe {
421            map.insert(&scope, 1, 10);
422        }
423        assert_eq!(map.get(&scope, &1), Some(&10));
424    }
425
426    #[test]
427    fn test_rcu_hash_map_insert_and_get() {
428        let map = RcuRawHashMap::default();
429        let scope = RcuReadScope::new();
430        unsafe {
431            map.insert(&scope, 1, 10);
432            map.insert(&scope, 2, 20);
433        }
434
435        assert_eq!(map.get(&scope, &1), Some(&10));
436        assert_eq!(map.get(&scope, &2), Some(&20));
437        assert_eq!(map.get(&scope, &3), None);
438
439        std::mem::drop(scope);
440        rcu_synchronize();
441    }
442
443    #[test]
444    fn test_rcu_hash_map_remove() {
445        let map = RcuRawHashMap::default();
446        let scope = RcuReadScope::new();
447        unsafe {
448            map.insert(&scope, 1, 10);
449            map.insert(&scope, 2, 20);
450        }
451
452        assert_eq!(map.get(&scope, &1), Some(&10));
453
454        unsafe {
455            assert_eq!(map.remove(&1), Some(10));
456        }
457
458        assert_eq!(map.get(&scope, &1), None);
459        assert_eq!(map.get(&scope, &2), Some(&20));
460
461        std::mem::drop(scope);
462        rcu_synchronize();
463    }
464
465    #[test]
466    fn test_rcu_hash_map_insert_update() {
467        let map = RcuRawHashMap::default();
468        let scope = RcuReadScope::new();
469        unsafe {
470            map.insert(&scope, 1, 10);
471        }
472
473        assert_eq!(map.get(&scope, &1), Some(&10));
474
475        let result = unsafe { map.insert(&scope, 1, 100) };
476        assert!(matches!(result, InsertionResult::Updated(10)));
477
478        assert_eq!(map.get(&scope, &1), Some(&100));
479
480        std::mem::drop(scope);
481        rcu_synchronize();
482    }
483
484    #[test]
485    fn test_rcu_hash_map_cursor() {
486        let map = RcuRawHashMap::default();
487        let scope = RcuReadScope::new();
488        unsafe {
489            map.insert(&scope, 1, 10);
490            map.insert(&scope, 2, 20);
491            map.insert(&scope, 3, 30);
492        }
493
494        let mut cursor = map.cursor(&scope);
495
496        assert_eq!(cursor.current(), Some((&1, &10)));
497        cursor.advance();
498        assert_eq!(cursor.current(), Some((&2, &20)));
499
500        unsafe {
501            cursor.remove();
502        }
503
504        assert_eq!(cursor.current(), Some((&3, &30)));
505        assert_eq!(map.get(&scope, &2), None);
506
507        cursor.advance();
508        assert_eq!(cursor.current(), None);
509
510        std::mem::drop(scope);
511        rcu_synchronize();
512    }
513
    #[test]
    fn test_rcu_hash_map_grow_maintains_order() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        // Three times the initial capacity forces at least one `grow` rehash.
        let num_elements = INITIAL_CAPACITY * 3;
        let mut expected_order = Vec::new();

        for i in 0..num_elements {
            unsafe {
                map.insert(&scope, i, i * 10);
            }
            expected_order.push((i, i * 10));
        }

        // After growth, the cursor must still walk entries in insertion order.
        let mut cursor = map.cursor(&scope);
        let mut actual_order = Vec::new();

        while let Some((key, value)) = cursor.current() {
            actual_order.push((*key, *value));
            cursor.advance();
        }

        assert_eq!(actual_order, expected_order);

        std::mem::drop(scope);
        rcu_synchronize();
    }
    #[test]
    fn test_rcu_hash_map_grow_overwrites_maintain_order() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        // Enough elements to force growth before the overwrites below.
        let num_elements = INITIAL_CAPACITY * 3;
        let mut expected_order = Vec::new();

        for i in 0..num_elements {
            unsafe {
                map.insert(&scope, i, i * 10);
            }
            expected_order.push((i, i * 10));
        }

        // Overwrite some existing entries and add new ones
        unsafe {
            map.insert(&scope, 5, 500);
            map.insert(&scope, INITIAL_CAPACITY * 3, (INITIAL_CAPACITY * 3) * 10); // New entry
        }
        // Overwriting a key moves its entry to the back of the insertion order.
        expected_order.retain(|(k, _)| *k != 5);
        expected_order.push((5, 500));
        expected_order.push((INITIAL_CAPACITY * 3, (INITIAL_CAPACITY * 3) * 10));

        let mut cursor = map.cursor(&scope);
        let mut actual_order = Vec::new();

        while let Some((key, value)) = cursor.current() {
            actual_order.push((*key, *value));
            cursor.advance();
        }

        assert_eq!(actual_order, expected_order);

        std::mem::drop(scope);
        rcu_synchronize();
    }
577
578    #[test]
579    fn test_rcu_hash_map_grow() {
580        let map = RcuRawHashMap::default();
581        let scope = RcuReadScope::new();
582        for i in 0..(INITIAL_CAPACITY * 3) {
583            unsafe {
584                map.insert(&scope, i, i * 10);
585            }
586        }
587
588        for i in 0..(INITIAL_CAPACITY * 3) {
589            assert_eq!(map.get(&scope, &i), Some(&(i * 10)));
590        }
591
592        std::mem::drop(scope);
593        rcu_synchronize();
594    }
595
596    #[test]
597    fn test_rcu_hash_map_capacity_zero() {
598        let map = RcuRawHashMap::with_capacity(0);
599        let scope = RcuReadScope::new();
600
601        assert_eq!(map.get(&scope, &1), None);
602
603        unsafe {
604            map.insert(&scope, 1, 10);
605        }
606        assert_eq!(map.get(&scope, &1), Some(&10));
607
608        unsafe {
609            assert_eq!(map.remove(&1), Some(10));
610        }
611        assert_eq!(map.get(&scope, &1), None);
612
613        std::mem::drop(scope);
614        rcu_synchronize();
615    }
616}