Skip to main content

fuchsia_rcu_collections/
rcu_raw_hash_map.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![warn(unsafe_op_in_unsafe_fn)]
6
7use crate::rcu_array::RcuArray;
8use crate::rcu_intrusive_list::{
9    Link, RcuIntrusiveList, RcuIntrusiveListCursor, RcuListAdapter, rcu_list_adapter,
10};
11use crate::rcu_list::RcuList;
12use fuchsia_rcu::RcuReadScope;
13use std::borrow::Borrow;
14use std::hash::{BuildHasher, Hash, Hasher};
15use std::sync::atomic::{AtomicUsize, Ordering};
16
/// The initial capacity of the hash map.
///
/// Used by `Default` and `with_hasher`. `with_capacity_and_hasher` turns a capacity into roughly
/// half as many buckets, so this starts the table with 64 buckets.
const INITIAL_CAPACITY: usize = 128;
19
20/// An entry in the hash table.
21#[derive(Debug)]
22struct Entry<K, V> {
23    /// The key for this entry.
24    key: K,
25
26    /// The value for this entry.
27    value: V,
28
29    /// The link to the next node in the collision chain for this bucket.
30    collision_chain: Link,
31
32    /// The link to the next node in the insertion chain for this bucket.
33    insertion_chain: Link,
34}
35
36impl<K, V> Entry<K, V> {
37    /// Create a new hash table entry.
38    fn new(key: K, value: V) -> Self {
39        Self {
40            key,
41            value,
42            collision_chain: Default::default(),
43            insertion_chain: Default::default(),
44        }
45    }
46}
47
/// An RcuListAdapter for the collision chain.
///
/// Lets an `RcuList` thread `Entry` values through their `collision_chain` link. `Bucket` uses
/// this adapter to chain together the entries whose keys select the same bucket.
#[derive(Debug)]
struct CollisionAdapter;

impl<K, V> RcuListAdapter<Entry<K, V>> for CollisionAdapter {
    // The macro from `rcu_intrusive_list` supplies the trait items, binding the adapter to the
    // `collision_chain` field of `Entry`.
    rcu_list_adapter!(Entry<K, V>, collision_chain);
}
55
/// An RcuListAdapter for the insertion chain.
///
/// Lets an `RcuIntrusiveList` thread `Entry` values through their `insertion_chain` link. The map
/// uses it to keep every entry on a single list in insertion order.
#[derive(Debug)]
struct InsertionAdapter;

impl<K, V> RcuListAdapter<Entry<K, V>> for InsertionAdapter {
    // The macro from `rcu_intrusive_list` supplies the trait items, binding the adapter to the
    // `insertion_chain` field of `Entry`.
    rcu_list_adapter!(Entry<K, V>, insertion_chain);
}
63
/// The result of inserting an entry into the map.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` (whenever `V` does), so results can be logged
/// and compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InsertionResult<V> {
    /// The key was not present, and a new entry was inserted.
    ///
    /// Carries the number of entries in the map after the insertion.
    Inserted(usize),

    /// The key was already present, and its value was replaced.
    ///
    /// Carries the old value.
    Updated(V),
}
76
/// The bucket in the hash table.
///
/// Each bucket is a linked list holding the collision chain: the entries whose key hashes select
/// this bucket, threaded through their `collision_chain` links via `CollisionAdapter`.
type Bucket<K, V> = RcuList<Entry<K, V>, CollisionAdapter>;
81
/// A hash map that uses read-copy-update (RCU) to manage concurrent accesses.
///
/// Lookups take an `RcuReadScope`. The mutating operations (`insert`, `remove`, and the cursor's
/// `remove`) are `unsafe`, and their callers must exclude concurrent writers. Every entry is also
/// kept on an insertion-ordered chain, which `cursor` and `keys` traverse.
///
/// By default, this map uses `rapidhash::RapidBuildHasher`, which provides high performance.
/// However, if this map holds keys which may be attacker-controlled, consider using
/// `std::collections::hash_map::RandomState` instead.
pub struct RcuRawHashMap<K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// The table of buckets.
    ///
    /// Each bucket chains the entries whose key hashes select it. When the map grows, the whole
    /// table is replaced by a larger one.
    table: RcuArray<Bucket<K, V>>,

    /// The number of entries in the map.
    ///
    /// Only the (externally synchronized) writers modify this. Reads use relaxed ordering, so a
    /// concurrent reader may observe a stale count.
    num_entries: AtomicUsize,

    /// The entries in this map in the order they were inserted.
    insertion_chain: RcuIntrusiveList<Entry<K, V>, InsertionAdapter>,

    /// The build hasher used to hash keys when selecting a bucket.
    hash_builder: S,
}
105
106impl<K, V> Default for RcuRawHashMap<K, V, rapidhash::RapidBuildHasher>
107where
108    K: Eq + Hash + Clone + Send + Sync + 'static,
109    V: Clone + Send + Sync + 'static,
110{
111    fn default() -> Self {
112        Self::with_capacity_and_hasher(INITIAL_CAPACITY, rapidhash::RapidBuildHasher::default())
113    }
114}
115
116impl<K, V> RcuRawHashMap<K, V, rapidhash::RapidBuildHasher>
117where
118    K: Eq + Hash + Clone + Send + Sync + 'static,
119    V: Clone + Send + Sync + 'static,
120{
121    /// Creates a new hash map with the given capacity.
122    pub fn with_capacity(capacity: usize) -> Self {
123        Self::with_capacity_and_hasher(capacity, rapidhash::RapidBuildHasher::default())
124    }
125}
126
127impl<K, V, S> RcuRawHashMap<K, V, S>
128where
129    K: Eq + Hash + Clone + Send + Sync + 'static,
130    V: Clone + Send + Sync + 'static,
131    S: BuildHasher + Send + Sync + 'static,
132{
133    /// Creates a new hash map with the given capacity and hasher.
134    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
135        let mut table = Vec::new();
136        table.resize_with((capacity + 1) / 2, Default::default);
137        Self {
138            table: RcuArray::from(table),
139            num_entries: AtomicUsize::new(0),
140            insertion_chain: Default::default(),
141            hash_builder,
142        }
143    }
144
145    /// Creates a new hash map with the given hasher.
146    pub fn with_hasher(hash_builder: S) -> Self {
147        Self::with_capacity_and_hasher(INITIAL_CAPACITY, hash_builder)
148    }
149
150    /// Returns the hash of the key as a u64.
151    fn hash_key<Q>(&self, key: &Q) -> u64
152    where
153        Q: ?Sized + Hash,
154    {
155        let mut hasher = self.hash_builder.build_hasher();
156        key.hash(&mut hasher);
157        hasher.finish()
158    }
159
160    /// Returns the bucket for the given key in the given table.
161    fn get_bucket<'a, Q>(&self, table: &'a [Bucket<K, V>], key: &Q) -> &'a Bucket<K, V>
162    where
163        K: Borrow<Q>,
164        Q: ?Sized + Hash,
165    {
166        let hash = self.hash_key(key);
167        let index = hash as usize % table.len();
168        &table[index]
169    }
170
171    /// Returns a reference to the bucket for the given key.
172    fn read_bucket<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> &'a Bucket<K, V>
173    where
174        K: Borrow<Q>,
175        Q: ?Sized + Hash,
176    {
177        let table = self.table.as_slice(scope);
178        self.get_bucket(table, key)
179    }
180
181    /// Returns a reference to the value corresponding to the key.
182    ///
183    /// Another thread running concurrently might see a different value for the object.
184    pub fn get<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> Option<&'a V>
185    where
186        K: Borrow<Q>,
187        Q: ?Sized + Hash + Eq,
188    {
189        let bucket = self.read_bucket(scope, key);
190        bucket.iter(scope).find(|entry| entry.key.borrow() == key).map(|entry| &entry.value)
191    }
192
193    /// Returns the number of entries in the map.
194    ///
195    /// The length can change concurrently with this call.
196    pub fn len(&self) -> usize {
197        self.num_entries.load(Ordering::Relaxed)
198    }
199
200    /// Inserts a key-value pair into the map.
201    ///
202    /// If the map did not have this key present, `None` is returned.
203    ///
204    /// If the map did have this key present, the value is updated, and the old
205    /// value is returned.
206    ///
207    /// Concurrent readers might not see the inserted value until the RCU state machine has made
208    /// sufficient progress to ensure that no concurrent readers are holding read guards.
209    ///
210    /// # Safety
211    ///
212    /// Requires external synchronization to exclude concurrent writers.
213    pub unsafe fn insert(&self, scope: &RcuReadScope, key: K, value: V) -> InsertionResult<V> {
214        let mut table = self.table.as_slice(scope);
215        if self.needs_to_grow(table) {
216            // SAFETY: Our caller is required to use external synchronization to exclude concurrent
217            // writers.
218            table = unsafe { self.grow(&scope, table) };
219        }
220        let bucket = self.get_bucket(table, &key);
221        let mut cursor = bucket.cursor(&scope);
222        while let Some(entry) = cursor.current() {
223            if entry.key == key {
224                let old_value = entry.value.clone();
225                // SAFETY: We have exclusive access to the bucket because we have exclusive access
226                // to the table.
227                unsafe {
228                    let removed_entry = cursor.remove();
229                    self.insertion_chain.remove(&scope, removed_entry);
230                    let entry = bucket.push_front(&scope, Entry::new(key, value));
231                    self.insertion_chain.push_back(&scope, entry);
232                };
233                return InsertionResult::Updated(old_value);
234            }
235            cursor.advance();
236        }
237
238        // SAFETY: We have exclusive access to the bucket because we have exclusive access to the
239        // table.
240        unsafe {
241            let entry = bucket.push_front(&scope, Entry::new(key, value));
242            self.insertion_chain.push_back(&scope, entry);
243        }
244        let count = self.num_entries.fetch_add(1, Ordering::Relaxed);
245        InsertionResult::Inserted(count + 1)
246    }
247
248    /// Removes a key from the map, returning the value at the key if the key
249    /// was previously in the map.
250    ///
251    /// Concurrent readers might see the removed value until the RCU state machine has made
252    /// sufficient progress to ensure that no concurrent readers are holding read guards.
253    ///
254    /// # Safety
255    ///
256    /// Requires external synchronization to exclude concurrent writers.
257    pub unsafe fn remove<Q>(&self, key: &Q) -> Option<V>
258    where
259        K: Borrow<Q>,
260        Q: ?Sized + Hash + Eq,
261    {
262        let scope = RcuReadScope::new();
263        let bucket = self.read_bucket(&scope, key);
264        let mut cursor = bucket.cursor(&scope);
265        while let Some(entry) = cursor.current() {
266            if entry.key.borrow() == key {
267                let old_value = entry.value.clone();
268                // SAFETY: We have exclusive access to the bucket because we have exclusive access
269                // to the table.
270                unsafe {
271                    let removed_entry = cursor.remove();
272                    self.insertion_chain.remove(&scope, removed_entry);
273                };
274                self.num_entries.fetch_sub(1, Ordering::Relaxed);
275                return Some(old_value);
276            }
277            cursor.advance();
278        }
279        None
280    }
281
282    /// Whether the given table needs to grow to reduce the number of collisions.
283    fn needs_to_grow(&self, table: &[Bucket<K, V>]) -> bool {
284        self.num_entries.load(Ordering::Relaxed) > table.len() * 2
285    }
286
287    /// Grows the table to reduce the number of collisions.
288    ///
289    /// Returns a reference to the new table. Callers should be sure to update the table reference
290    /// they are using to the returned value.
291    ///
292    /// # Safety
293    ///
294    /// Requires external synchronization to exclude concurrent writers.
295    #[must_use]
296    unsafe fn grow<'a>(
297        &self,
298        scope: &'a RcuReadScope,
299        old_table: &[Bucket<K, V>],
300    ) -> &'a [Bucket<K, V>] {
301        let new_size = old_table.len() * 2;
302        let mut new_table = Vec::new();
303        let new_insertion_chain = RcuIntrusiveList::default();
304        new_table.resize_with(new_size, Default::default);
305
306        for entry in self.insertion_chain.iter(scope) {
307            let bucket = self.get_bucket(&new_table, &entry.key);
308            let key = entry.key.clone();
309            let value = entry.value.clone();
310            // SAFETY: We have exclusive access to new_table_vec because we just created it.
311            unsafe {
312                let entry = bucket.push_front(&scope, Entry::new(key, value));
313                new_insertion_chain.push_back(&scope, entry);
314            };
315        }
316
317        self.table.update(new_table);
318        // SAFETY: Our caller promises to exclude concurrent writers.
319        unsafe {
320            self.insertion_chain.update(&scope, new_insertion_chain);
321        }
322        self.table.as_slice(scope)
323    }
324
325    /// Returns a cursor that can be used to traverse and modify the map.
326    ///
327    /// The cursor iterates through the map in insertion order.
328    pub fn cursor<'a>(&'a self, scope: &'a RcuReadScope) -> RcuRawHashMapCursor<'a, K, V, S> {
329        RcuRawHashMapCursor { inner: self.insertion_chain.cursor(scope), map: self }
330    }
331
332    /// Returns an iterator over the keys in the map.
333    pub fn keys<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = &'a K> {
334        self.insertion_chain.iter(scope).map(|entry| &entry.key)
335    }
336}
337
338// TODO(https://fxbug.dev/482462174): switch back to #[derive(Debug)]
339impl<K, V, S> std::fmt::Debug for RcuRawHashMap<K, V, S>
340where
341    K: Eq + Hash + Clone + Send + Sync + 'static + std::fmt::Debug,
342    V: Clone + Send + Sync + 'static + std::fmt::Debug,
343    S: std::hash::BuildHasher + Send + Sync + 'static,
344{
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        f.debug_struct("RcuRawHashMap")
347            .field("table", &self.table)
348            .field("num_entries", &self.num_entries)
349            .field("insertion_chain", &self.insertion_chain)
350            .field("hash_builder", &std::any::type_name::<S>())
351            .finish_non_exhaustive()
352    }
353}
354
/// A cursor for traversing and modifying an `RcuRawHashMap`.
///
/// The cursor walks the map's entries in insertion order.
///
/// See `RcuRawHashMap::cursor` for more information.
pub struct RcuRawHashMapCursor<'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// The position within the map's insertion chain.
    inner: RcuIntrusiveListCursor<'a, Entry<K, V>, InsertionAdapter>,
    /// The map being traversed. `remove` uses it to unlink the entry under the cursor.
    map: &'a RcuRawHashMap<K, V, S>,
}
367
368impl<'a, K, V, S> RcuRawHashMapCursor<'a, K, V, S>
369where
370    K: Eq + Hash + Clone + Send + Sync + 'static,
371    V: Clone + Send + Sync + 'static,
372    S: BuildHasher + Send + Sync + 'static,
373{
374    /// Returns the element at the current cursor position.
375    pub fn current(&self) -> Option<(&'a K, &'a V)> {
376        self.inner.current().map(|entry| (&entry.key, &entry.value))
377    }
378
379    /// Advances the cursor to the next element in the list.
380    pub fn advance(&mut self) {
381        self.inner.advance()
382    }
383
384    /// Removes the element at the current cursor position.
385    ///
386    /// After calling `remove`, the cursor will be positioned at the next element in the list.
387    ///
388    /// Concurrent readers may continue to see this entry in the list until the RCU state machine
389    /// has made sufficient progress to ensure that no concurrent readers are holding read guards.
390    ///
391    /// # Safety
392    ///
393    /// Requires external synchronization to exclude concurrent writers.
394    pub unsafe fn remove(&mut self) -> Option<V> {
395        if let Some((key, _)) = self.current() {
396            self.advance();
397            // SAFETY: The caller promises to exclude concurrent writers.
398            unsafe { self.map.remove(key) }
399        } else {
400            None
401        }
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_rcu::rcu_synchronize;

    /// Walks `map` from front to back with a cursor and returns a copy of every entry, in
    /// insertion order.
    fn entries_in_order(
        map: &RcuRawHashMap<usize, usize>,
        scope: &RcuReadScope,
    ) -> Vec<(usize, usize)> {
        let mut cursor = map.cursor(scope);
        let mut entries = Vec::new();
        while let Some((key, value)) = cursor.current() {
            entries.push((*key, *value));
            cursor.advance();
        }
        entries
    }

    #[test]
    fn test_rcu_hash_map_custom_hasher() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::BuildHasherDefault;
        let hasher = BuildHasherDefault::<DefaultHasher>::default();
        let map = RcuRawHashMap::with_capacity_and_hasher(10, hasher);
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
        }
        assert_eq!(map.get(&scope, &1), Some(&10));

        // Release the read scope and let RCU quiesce, as every other test does.
        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_insert_and_get() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
            map.insert(&scope, 2, 20);
        }

        assert_eq!(map.get(&scope, &1), Some(&10));
        assert_eq!(map.get(&scope, &2), Some(&20));
        assert_eq!(map.get(&scope, &3), None);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_remove() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
            map.insert(&scope, 2, 20);
        }

        assert_eq!(map.get(&scope, &1), Some(&10));

        unsafe {
            assert_eq!(map.remove(&1), Some(10));
        }

        assert_eq!(map.get(&scope, &1), None);
        assert_eq!(map.get(&scope, &2), Some(&20));

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_len_tracks_inserts_and_removes() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        assert_eq!(map.len(), 0);

        let first = unsafe { map.insert(&scope, 1, 10) };
        assert!(matches!(first, InsertionResult::Inserted(1)));
        let second = unsafe { map.insert(&scope, 2, 20) };
        assert!(matches!(second, InsertionResult::Inserted(2)));
        assert_eq!(map.len(), 2);

        // Overwriting an existing key leaves the entry count unchanged.
        unsafe {
            map.insert(&scope, 1, 100);
        }
        assert_eq!(map.len(), 2);

        // Removing a key that is not present is a no-op.
        unsafe {
            assert_eq!(map.remove(&3), None);
        }
        assert_eq!(map.len(), 2);

        unsafe {
            assert_eq!(map.remove(&1), Some(100));
        }
        assert_eq!(map.len(), 1);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_insert_update() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
        }

        assert_eq!(map.get(&scope, &1), Some(&10));

        let result = unsafe { map.insert(&scope, 1, 100) };
        assert!(matches!(result, InsertionResult::Updated(10)));

        assert_eq!(map.get(&scope, &1), Some(&100));

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_cursor() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
            map.insert(&scope, 2, 20);
            map.insert(&scope, 3, 30);
        }

        let mut cursor = map.cursor(&scope);

        assert_eq!(cursor.current(), Some((&1, &10)));
        cursor.advance();
        assert_eq!(cursor.current(), Some((&2, &20)));

        unsafe {
            cursor.remove();
        }

        assert_eq!(cursor.current(), Some((&3, &30)));
        assert_eq!(map.get(&scope, &2), None);

        cursor.advance();
        assert_eq!(cursor.current(), None);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_grow_maintains_order() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        let num_elements = INITIAL_CAPACITY * 3;
        let mut expected_order = Vec::new();

        for i in 0..num_elements {
            unsafe {
                map.insert(&scope, i, i * 10);
            }
            expected_order.push((i, i * 10));
        }

        assert_eq!(entries_in_order(&map, &scope), expected_order);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_grow_overwrites_maintain_order() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        let num_elements = INITIAL_CAPACITY * 3;
        let mut expected_order = Vec::new();

        for i in 0..num_elements {
            unsafe {
                map.insert(&scope, i, i * 10);
            }
            expected_order.push((i, i * 10));
        }

        // Once the table has grown, overwrite an existing entry and add a brand-new one.
        let new_key = num_elements;
        unsafe {
            map.insert(&scope, 5, 500);
            map.insert(&scope, new_key, new_key * 10);
        }
        expected_order.retain(|(k, _)| *k != 5);
        expected_order.push((5, 500));
        expected_order.push((new_key, new_key * 10));

        assert_eq!(entries_in_order(&map, &scope), expected_order);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_grow() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        for i in 0..(INITIAL_CAPACITY * 3) {
            unsafe {
                map.insert(&scope, i, i * 10);
            }
        }

        for i in 0..(INITIAL_CAPACITY * 3) {
            assert_eq!(map.get(&scope, &i), Some(&(i * 10)));
        }

        std::mem::drop(scope);
        rcu_synchronize();
    }
}