Skip to main content

starnix_rcu/
rcu_hash_map.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fuchsia_rcu::RcuReadScope;
6use fuchsia_rcu_collections::rcu_raw_hash_map::{InsertionResult, RcuRawHashMap};
7use starnix_sync::Mutex;
8use std::borrow::Borrow;
9use std::hash::{BuildHasher, Hash};
10
/// A concurrent hash map that uses RCU for read synchronization and a mutex for write
/// synchronization.
///
/// This map allows concurrent readers to access entries without blocking, while writers are
/// synchronized via a `Mutex`. Readers pass an `RcuReadScope` to `get`/`iter`/`keys` and never
/// take the mutex; every mutation goes through `lock()` (directly or via `insert`/`remove`).
///
/// By default, this map uses `rapidhash::RapidBuildHasher`, which provides high performance.
/// However, if this map holds keys which may be attacker-controlled, consider using
/// `std::collections::hash_map::RandomState` instead.
pub struct RcuHashMap<K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    // The underlying RCU hash map. Lookups read it directly under an `RcuReadScope`; its
    // `unsafe` mutation methods may only be called while `mutex` is held.
    map: RcuRawHashMap<K, V, S>,
    // Serializes writers. It protects no data of its own: holding it is what makes calling the
    // raw map's mutation methods sound (see the SAFETY comments on `RcuHashMapGuard`).
    mutex: Mutex<()>,
}
28
29impl<K, V> Default for RcuHashMap<K, V, rapidhash::RapidBuildHasher>
30where
31    K: Eq + Hash + Clone + Send + Sync + 'static,
32    V: Clone + Send + Sync + 'static,
33{
34    fn default() -> Self {
35        Self { map: Default::default(), mutex: Mutex::new(()) }
36    }
37}
38
39impl<K, V, S> RcuHashMap<K, V, S>
40where
41    K: Eq + Hash + Clone + Send + Sync + 'static,
42    V: Clone + Send + Sync + 'static,
43    S: BuildHasher + Send + Sync + 'static,
44{
45    /// Creates a new hash map with the given capacity and hasher.
46    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
47        Self {
48            map: RcuRawHashMap::with_capacity_and_hasher(capacity, hash_builder),
49            mutex: Mutex::new(()),
50        }
51    }
52
53    /// Creates a new hash map with the given hasher.
54    pub fn with_hasher(hash_builder: S) -> Self {
55        Self { map: RcuRawHashMap::with_hasher(hash_builder), mutex: Mutex::new(()) }
56    }
57
58    /// Returns a reference to the value associated with the given key, if it exists.
59    ///
60    /// The returned reference is bound to the lifetime of the `RcuReadScope`.
61    pub fn get<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> Option<&'a V>
62    where
63        K: Borrow<Q>,
64        Q: ?Sized + Hash + Eq,
65    {
66        self.map.get(scope, key)
67    }
68
69    /// Locks the map for exclusive access, returning a guard that allows mutation.
70    pub fn lock(&self) -> RcuHashMapGuard<'_, K, V, S> {
71        RcuHashMapGuard { map: &self.map, _guard: self.mutex.lock() }
72    }
73
74    /// Inserts a key-value pair into the map, returning the old value if the key was already present.
75    pub fn insert(&self, key: K, value: V) -> Option<V> {
76        self.lock().insert(key, value)
77    }
78
79    /// Removes a key from the map, returning the value if the key was present.
80    pub fn remove<Q>(&self, key: &Q) -> Option<V>
81    where
82        K: Borrow<Q>,
83        Q: ?Sized + Hash + Eq,
84    {
85        self.lock().remove(key)
86    }
87
88    /// Returns an iterator over the map's entries.
89    pub fn iter<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = (&'a K, &'a V)> {
90        let mut cursor = self.map.cursor(scope);
91        std::iter::from_fn(move || {
92            let current = cursor.current();
93            if current.is_some() {
94                cursor.advance();
95            }
96            current
97        })
98    }
99
100    /// Returns an iterator over the map's keys.
101    pub fn keys<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = &'a K> {
102        self.iter(scope).map(|(k, _)| k)
103    }
104
105    /// Returns the number of entries in the map.
106    pub fn len(&self) -> usize {
107        self.map.len()
108    }
109}
110
111// TODO(https://fxbug.dev/482462174): switch back to #[derive(Debug)]
112impl<K, V, S> std::fmt::Debug for RcuHashMap<K, V, S>
113where
114    K: Eq + Hash + std::fmt::Debug + Clone + Send + Sync + 'static,
115    V: std::fmt::Debug + Clone + Send + Sync + 'static,
116    S: BuildHasher + Send + Sync + 'static,
117{
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        f.debug_struct("RcuHashMap").field("map", &self.map).finish()
120    }
121}
122
/// A guard that provides exclusive access to the `RcuHashMap`.
///
/// Obtained from `RcuHashMap::lock`. The map's writer mutex remains locked until this guard is
/// dropped, so other writers (including `RcuHashMap::insert` and `RcuHashMap::remove`, which
/// lock internally) wait in the meantime; keep the guard's lifetime short. Readers that use an
/// `RcuReadScope` do not take this mutex.
pub struct RcuHashMapGuard<'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    // The raw map owned by the `RcuHashMap` this guard was created from.
    map: &'a RcuRawHashMap<K, V, S>,
    // Keeps the owning map's writer mutex locked for as long as this guard lives; never read,
    // held only for its drop.
    _guard: starnix_sync::MutexGuard<'a, ()>,
}
133
134impl<'a, K, V, S> RcuHashMapGuard<'a, K, V, S>
135where
136    K: Eq + Hash + Clone + Send + Sync + 'static,
137    V: Clone + Send + Sync + 'static,
138    S: BuildHasher + Send + Sync + 'static,
139{
140    /// Returns a copy (clone) of the value associated with the given key, if it exists.
141    pub fn get<Q>(&self, key: &Q) -> Option<V>
142    where
143        K: Borrow<Q>,
144        Q: ?Sized + Hash + Eq,
145    {
146        let scope = RcuReadScope::new();
147        self.map.get(&scope, key).cloned()
148    }
149
150    /// Inserts a key-value pair into the map.
151    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
152        let scope = RcuReadScope::new();
153        // SAFETY: We have exclusive access to the map because we have exclusive access to the mutex.
154        match unsafe { self.map.insert(&scope, key, value) } {
155            InsertionResult::Inserted(_) => None,
156            InsertionResult::Updated(old_value) => Some(old_value),
157        }
158    }
159
160    /// Removes a key from the map.
161    pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
162    where
163        K: Borrow<Q>,
164        Q: ?Sized + Hash + Eq,
165    {
166        // SAFETY: We have exclusive access to the map because we have exclusive access to the mutex.
167        unsafe { self.map.remove(key) }
168    }
169
170    /// Removes all values from the map and returns them.
171    pub fn drain<'b>(&'b mut self) -> impl Iterator<Item = (K, V)> + 'b {
172        let scope = RcuReadScope::new();
173        // We collect the keys first because we cannot iterate and modify the map at the same time.
174        #[allow(clippy::needless_collect)]
175        let keys: Vec<_> = self.map.keys(&scope).map(Clone::clone).collect();
176        keys.into_iter().filter_map(move |k| self.remove(&k).map(|v| (k, v)))
177    }
178
179    /// Returns true if the map contains a value for the specified key.
180    pub fn contains_key<Q>(&self, key: &Q) -> bool
181    where
182        K: Borrow<Q>,
183        Q: ?Sized + Hash + Eq,
184    {
185        self.get(key).is_some()
186    }
187
188    /// Gets the given key's corresponding entry in the map for in-place manipulation.
189    pub fn entry<'b>(&'b mut self, key: K) -> Entry<'b, 'a, K, V, S> {
190        if self.get(&key).is_some() {
191            Entry::Occupied(OccupiedEntry { guard: self, key })
192        } else {
193            Entry::Vacant(VacantEntry { guard: self, key })
194        }
195    }
196}
197
/// A view into a single entry in the map, which may either be vacant or occupied.
///
/// Returned by `RcuHashMapGuard::entry`. `'b` is the mutable borrow of the guard and `'a` is the
/// guard's own borrow of the map, so the writer lock stays held while the entry is alive.
pub enum Entry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// An occupied entry: the key was present when the entry was created.
    Occupied(OccupiedEntry<'b, 'a, K, V, S>),
    /// A vacant entry: the key was absent when the entry was created.
    Vacant(VacantEntry<'b, 'a, K, V, S>),
}
210
211impl<'b, 'a, K, V, S> Entry<'b, 'a, K, V, S>
212where
213    K: Eq + Hash + Clone + Send + Sync + 'static,
214    V: Clone + Send + Sync + 'static,
215    S: BuildHasher + Send + Sync + 'static,
216{
217    /// Ensures a value is in the entry by inserting the result of the default function if empty,
218    /// and returns an occupied entry.
219    pub fn or_insert_with<F: FnOnce() -> V>(self, default: F) -> OccupiedEntry<'b, 'a, K, V, S> {
220        match self {
221            Entry::Occupied(entry) => entry,
222            Entry::Vacant(entry) => entry.insert_entry(default()),
223        }
224    }
225}
226
/// A view into an occupied entry in a `RcuHashMap`.
///
/// Created by `RcuHashMapGuard::entry` when the key is present, or by
/// `VacantEntry::insert_entry` right after inserting it.
pub struct OccupiedEntry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    // Exclusive borrow of the write guard; prevents any other mutation through it while the
    // entry lives.
    guard: &'b mut RcuHashMapGuard<'a, K, V, S>,
    // The key this entry refers to; used for every lookup the entry performs.
    key: K,
}
237
238impl<K, V, S> OccupiedEntry<'_, '_, K, V, S>
239where
240    K: Eq + Hash + Clone + Send + Sync + 'static,
241    V: Clone + Send + Sync + 'static,
242    S: BuildHasher + Send + Sync + 'static,
243{
244    /// Gets a copy (clone) of the value in the entry.
245    pub fn get(&self) -> V {
246        self.guard.get(&self.key).unwrap()
247    }
248
249    /// Sets the value of the entry, returning the old value.
250    pub fn insert(&mut self, value: V) -> V {
251        self.guard.insert(self.key.clone(), value).unwrap()
252    }
253
254    /// Removes the entry from the map, returning the value.
255    pub fn remove(self) -> V {
256        self.guard.remove(&self.key).unwrap()
257    }
258}
259
/// A view into a vacant entry in a `RcuHashMap`.
///
/// Created by `RcuHashMapGuard::entry` when the key is absent.
pub struct VacantEntry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    // Exclusive borrow of the write guard, so the key cannot be inserted by anyone else while the
    // entry lives.
    guard: &'b mut RcuHashMapGuard<'a, K, V, S>,
    // The (absent) key that an insertion through this entry will use.
    key: K,
}
270
271impl<'b, 'a, K, V, S> VacantEntry<'b, 'a, K, V, S>
272where
273    K: Eq + Hash + Clone + Send + Sync + 'static,
274    V: Clone + Send + Sync + 'static,
275    S: BuildHasher + Send + Sync + 'static,
276{
277    /// Sets the value of the entry with the VacantEntry's key.
278    pub fn insert(self, value: V) {
279        self.guard.insert(self.key, value);
280    }
281
282    /// Sets the value of the entry with the VacantEntry's key, and returns an occupied entry.
283    pub fn insert_entry(self, value: V) -> OccupiedEntry<'b, 'a, K, V, S> {
284        self.guard.insert(self.key.clone(), value);
285        OccupiedEntry { guard: self.guard, key: self.key }
286    }
287}
288
// Unit tests for `RcuHashMap`, its write guard, and the entry API.
//
// Several tests finish with `drop(scope)` followed by `rcu_synchronize()`. Presumably this lets
// RCU-deferred reclamation from updates/removals complete before the test exits, with the read
// scope dropped first so the grace period is not held open by this thread.
// NOTE(review): confirm against `fuchsia_rcu` semantics.
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_rcu::rcu_synchronize;

    // A caller-supplied `BuildHasher` works with `with_capacity_and_hasher`.
    #[test]
    fn test_rcu_hash_map_custom_hasher() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::BuildHasherDefault;
        let hasher = BuildHasherDefault::<DefaultHasher>::default();
        let map = RcuHashMap::with_capacity_and_hasher(10, hasher);
        let mut guard = map.lock();
        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));
    }

    // Inserted values are visible through the guard and, after unlocking, through lock-free reads.
    #[test]
    fn test_rcu_hash_map_insert_and_get() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        guard.insert(1, 10);
        guard.insert(2, 20);

        assert_eq!(guard.get(&1), Some(10));
        assert_eq!(guard.get(&2), Some(20));
        assert_eq!(guard.get(&3), None);

        // Verify we can read without the lock too
        drop(guard);
        assert_eq!(map.get(&scope, &1), Some(&10));
        assert_eq!(map.get(&scope, &2), Some(&20));

        drop(scope);
        rcu_synchronize();
    }

    // Re-inserting an existing key replaces its value.
    #[test]
    fn test_rcu_hash_map_update() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));

        guard.insert(1, 20);
        assert_eq!(guard.get(&1), Some(20));

        drop(guard);
        assert_eq!(map.get(&scope, &1), Some(&20));

        drop(scope);
        rcu_synchronize();
    }

    // Removed keys are no longer visible through either the guard or lock-free reads.
    #[test]
    fn test_rcu_hash_map_remove() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));

        guard.remove(&1);
        assert_eq!(guard.get(&1), None);

        drop(guard);
        assert_eq!(map.get(&scope, &1), None);

        drop(scope);
        rcu_synchronize();
    }

    // `entry` classifies keys as vacant/occupied and both variants can write through the guard.
    #[test]
    fn test_rcu_hash_map_entry_api() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();

        // Vacant entry
        match guard.entry(1) {
            Entry::Vacant(e) => e.insert(10),
            Entry::Occupied(_) => panic!("Should be vacant"),
        }
        assert_eq!(guard.get(&1), Some(10));

        // Occupied entry
        match guard.entry(1) {
            Entry::Occupied(mut e) => {
                assert_eq!(e.get(), 10);
                e.insert(20);
            }
            Entry::Vacant(_) => panic!("Should be occupied"),
        }
        assert_eq!(guard.get(&1), Some(20));

        drop(guard);
        rcu_synchronize();
    }

    // `iter` yields every entry exactly once (order is unspecified, hence the sort).
    #[test]
    fn test_rcu_hash_map_iter() {
        let map = RcuHashMap::<i32, i32>::default();
        let scope = RcuReadScope::new();
        map.insert(1, 10);
        map.insert(2, 20);
        map.insert(3, 30);

        let mut items: Vec<_> = map.iter(&scope).collect();
        items.sort_by_key(|(k, _)| **k);
        assert_eq!(items, vec![(&1, &10), (&2, &20), (&3, &30)]);
    }

    // `keys` yields every key exactly once (order is unspecified, hence the sort).
    #[test]
    fn test_rcu_hash_map_keys() {
        let map = RcuHashMap::<i32, i32>::default();
        let scope = RcuReadScope::new();
        map.insert(1, 10);
        map.insert(2, 20);
        map.insert(3, 30);

        let mut keys: Vec<_> = map.keys(&scope).collect();
        keys.sort();
        assert_eq!(keys, vec![&1, &2, &3]);
    }

    // `len` counts distinct inserted keys.
    #[test]
    fn test_rcu_hash_map_len() {
        let map = RcuHashMap::<i32, i32>::default();
        map.insert(1, 10);
        map.insert(2, 20);
        map.insert(3, 30);

        assert_eq!(map.len(), 3);
    }

    // `or_insert_with` only inserts into vacant entries; `OccupiedEntry::remove` returns the value.
    #[test]
    fn test_rcu_hash_map_or_insert_with() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();

        // test or_insert_with
        guard.entry(1).or_insert_with(|| 10);
        assert!(guard.contains_key(&1));
        assert_eq!(guard.get(&1), Some(10));

        // test or_insert_with existing
        guard.entry(1).or_insert_with(|| 20);
        assert_eq!(guard.get(&1), Some(10));

        // test OccupiedEntry::remove
        match guard.entry(1) {
            Entry::Occupied(e) => {
                assert_eq!(e.remove(), 10);
            }
            Entry::Vacant(_) => panic!("Should be occupied"),
        }
        assert!(!guard.contains_key(&1));
    }

    // A fully consumed `drain` returns every pair and leaves the map empty.
    #[test]
    fn test_rcu_hash_map_drain() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();

        guard.insert(1, 10);
        guard.insert(2, 20);
        guard.insert(3, 30);

        let mut items: Vec<_> = guard.drain().collect();
        items.sort_by_key(|(k, _)| *k);
        assert_eq!(items, vec![(1, 10), (2, 20), (3, 30)]);

        assert!(!guard.contains_key(&1));
        assert!(!guard.contains_key(&2));
        assert!(!guard.contains_key(&3));
    }

    // A map created with zero initial capacity still supports lookup, insert and remove.
    #[test]
    fn test_rcu_hash_map_capacity_zero() {
        use std::collections::hash_map::RandomState;
        let map =
            RcuHashMap::<i32, i32, RandomState>::with_capacity_and_hasher(0, RandomState::new());
        let mut guard = map.lock();

        assert_eq!(guard.get(&1), None);

        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));

        guard.remove(&1);
        assert_eq!(guard.get(&1), None);
    }
}