Skip to main content

starnix_rcu/
rcu_hash_map.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fuchsia_rcu::RcuReadScope;
6use fuchsia_rcu_collections::rcu_raw_hash_map::{InsertionResult, RcuRawHashMap};
7use starnix_sync::Mutex;
8use std::borrow::Borrow;
9use std::hash::{BuildHasher, Hash};
10
/// A concurrent hash map that uses RCU for read synchronization and a mutex for write synchronization.
///
/// This map allows concurrent readers to access entries without blocking, while writers are
/// synchronized via a `Mutex`.
///
/// Readers go through [`RcuHashMap::get`], [`RcuHashMap::iter`] and [`RcuHashMap::keys`], which
/// take an `RcuReadScope` and never touch the mutex. Writers go through [`RcuHashMap::lock`] (or the
/// [`RcuHashMap::insert`] / [`RcuHashMap::remove`] shortcuts, which lock internally).
///
/// By default, this map uses `rapidhash::RapidBuildHasher`, which provides high performance.
/// However, if this map holds keys which may be attacker-controlled, consider using
/// `std::collections::hash_map::RandomState` instead.
pub struct RcuHashMap<K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// The underlying RCU-protected table. Readers access it directly under an `RcuReadScope`;
    /// its `unsafe` mutation methods are only called while `mutex` is held.
    map: RcuRawHashMap<K, V, S>,
    /// Serializes writers. It guards no data of its own (`()`); its guard is stored inside
    /// `RcuHashMapGuard` for as long as a writer holds the map.
    mutex: Mutex<()>,
}
28
29impl<K, V> Default for RcuHashMap<K, V, rapidhash::RapidBuildHasher>
30where
31    K: Eq + Hash + Clone + Send + Sync + 'static,
32    V: Clone + Send + Sync + 'static,
33{
34    fn default() -> Self {
35        Self { map: Default::default(), mutex: Mutex::new(()) }
36    }
37}
38
39impl<K, V, S> RcuHashMap<K, V, S>
40where
41    K: Eq + Hash + Clone + Send + Sync + 'static,
42    V: Clone + Send + Sync + 'static,
43    S: BuildHasher + Send + Sync + 'static,
44{
45    /// Creates a new hash map with the given capacity and hasher.
46    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
47        Self {
48            map: RcuRawHashMap::with_capacity_and_hasher(capacity, hash_builder),
49            mutex: Mutex::new(()),
50        }
51    }
52
53    /// Creates a new hash map with the given hasher.
54    pub fn with_hasher(hash_builder: S) -> Self {
55        Self { map: RcuRawHashMap::with_hasher(hash_builder), mutex: Mutex::new(()) }
56    }
57
58    /// Returns a reference to the value associated with the given key, if it exists.
59    ///
60    /// The returned reference is bound to the lifetime of the `RcuReadScope`.
61    pub fn get<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> Option<&'a V>
62    where
63        K: Borrow<Q>,
64        Q: ?Sized + Hash + Eq,
65    {
66        self.map.get(scope, key)
67    }
68
69    /// Locks the map for exclusive access, returning a guard that allows mutation.
70    pub fn lock(&self) -> RcuHashMapGuard<'_, K, V, S> {
71        RcuHashMapGuard { map: &self.map, _guard: self.mutex.lock() }
72    }
73
74    /// Inserts a key-value pair into the map, returning the old value if the key was already present.
75    pub fn insert(&self, key: K, value: V) -> Option<V> {
76        self.lock().insert(key, value)
77    }
78
79    /// Removes a key from the map, returning the value if the key was present.
80    pub fn remove<Q>(&self, key: &Q) -> Option<V>
81    where
82        K: Borrow<Q>,
83        Q: ?Sized + Hash + Eq,
84    {
85        self.lock().remove(key)
86    }
87
88    /// Returns an iterator over the map's entries.
89    pub fn iter<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = (&'a K, &'a V)> {
90        let mut cursor = self.map.cursor(scope);
91        std::iter::from_fn(move || {
92            let current = cursor.current();
93            if current.is_some() {
94                cursor.advance();
95            }
96            current
97        })
98    }
99
100    /// Returns an iterator over the map's keys.
101    pub fn keys<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = &'a K> {
102        self.iter(scope).map(|(k, _)| k)
103    }
104}
105
106// TODO(https://fxbug.dev/482462174): switch back to #[derive(Debug)]
107impl<K, V, S> std::fmt::Debug for RcuHashMap<K, V, S>
108where
109    K: Eq + Hash + std::fmt::Debug + Clone + Send + Sync + 'static,
110    V: std::fmt::Debug + Clone + Send + Sync + 'static,
111    S: BuildHasher + Send + Sync + 'static,
112{
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("RcuHashMap").field("map", &self.map).finish()
115    }
116}
117
118/// A guard that provides exclusive access to the `RcuHashMap`.
119pub struct RcuHashMapGuard<'a, K, V, S = rapidhash::RapidBuildHasher>
120where
121    K: Eq + Hash + Clone + Send + Sync + 'static,
122    V: Clone + Send + Sync + 'static,
123    S: BuildHasher + Send + Sync + 'static,
124{
125    map: &'a RcuRawHashMap<K, V, S>,
126    _guard: starnix_sync::MutexGuard<'a, ()>,
127}
128
129impl<'a, K, V, S> RcuHashMapGuard<'a, K, V, S>
130where
131    K: Eq + Hash + Clone + Send + Sync + 'static,
132    V: Clone + Send + Sync + 'static,
133    S: BuildHasher + Send + Sync + 'static,
134{
135    /// Returns a copy (clone) of the value associated with the given key, if it exists.
136    pub fn get<Q>(&self, key: &Q) -> Option<V>
137    where
138        K: Borrow<Q>,
139        Q: ?Sized + Hash + Eq,
140    {
141        let scope = RcuReadScope::new();
142        self.map.get(&scope, key).cloned()
143    }
144
145    /// Inserts a key-value pair into the map.
146    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
147        let scope = RcuReadScope::new();
148        // SAFETY: We have exclusive access to the map because we have exclusive access to the mutex.
149        match unsafe { self.map.insert(&scope, key, value) } {
150            InsertionResult::Inserted(_) => None,
151            InsertionResult::Updated(old_value) => Some(old_value),
152        }
153    }
154
155    /// Removes a key from the map.
156    pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
157    where
158        K: Borrow<Q>,
159        Q: ?Sized + Hash + Eq,
160    {
161        // SAFETY: We have exclusive access to the map because we have exclusive access to the mutex.
162        unsafe { self.map.remove(key) }
163    }
164
165    /// Removes all values from the map and returns them.
166    pub fn drain<'b>(&'b mut self) -> impl Iterator<Item = (K, V)> + 'b {
167        let scope = RcuReadScope::new();
168        // We collect the keys first because we cannot iterate and modify the map at the same time.
169        #[allow(clippy::needless_collect)]
170        let keys: Vec<_> = self.map.keys(&scope).map(Clone::clone).collect();
171        keys.into_iter().filter_map(move |k| self.remove(&k).map(|v| (k, v)))
172    }
173
174    /// Returns true if the map contains a value for the specified key.
175    pub fn contains_key<Q>(&self, key: &Q) -> bool
176    where
177        K: Borrow<Q>,
178        Q: ?Sized + Hash + Eq,
179    {
180        self.get(key).is_some()
181    }
182
183    /// Gets the given key's corresponding entry in the map for in-place manipulation.
184    pub fn entry<'b>(&'b mut self, key: K) -> Entry<'b, 'a, K, V, S> {
185        if self.get(&key).is_some() {
186            Entry::Occupied(OccupiedEntry { guard: self, key })
187        } else {
188            Entry::Vacant(VacantEntry { guard: self, key })
189        }
190    }
191}
192
/// A view into a single entry in the map, which may either be vacant or occupied.
///
/// Produced by `RcuHashMapGuard::entry`. Both variants hold `&'b mut RcuHashMapGuard<'a, ..>`,
/// so the map's writer mutex remains held for as long as the entry exists. `'b` is the borrow of
/// the guard; `'a` is the guard's own lifetime.
pub enum Entry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// An occupied entry: the key was present when the entry was created.
    Occupied(OccupiedEntry<'b, 'a, K, V, S>),
    /// A vacant entry: the key was absent when the entry was created.
    Vacant(VacantEntry<'b, 'a, K, V, S>),
}
205
206impl<'b, 'a, K, V, S> Entry<'b, 'a, K, V, S>
207where
208    K: Eq + Hash + Clone + Send + Sync + 'static,
209    V: Clone + Send + Sync + 'static,
210    S: BuildHasher + Send + Sync + 'static,
211{
212    /// Ensures a value is in the entry by inserting the result of the default function if empty,
213    /// and returns an occupied entry.
214    pub fn or_insert_with<F: FnOnce() -> V>(self, default: F) -> OccupiedEntry<'b, 'a, K, V, S> {
215        match self {
216            Entry::Occupied(entry) => entry,
217            Entry::Vacant(entry) => entry.insert_entry(default()),
218        }
219    }
220}
221
/// A view into an occupied entry in a `RcuHashMap`.
///
/// Constructed by `RcuHashMapGuard::entry` when the key is present, or by
/// `VacantEntry::insert_entry` right after inserting it.
pub struct OccupiedEntry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// Exclusive borrow of the guard, so no other writer can touch the map meanwhile.
    guard: &'b mut RcuHashMapGuard<'a, K, V, S>,
    /// The key this entry refers to; it is present in the map.
    key: K,
}
232
233impl<K, V, S> OccupiedEntry<'_, '_, K, V, S>
234where
235    K: Eq + Hash + Clone + Send + Sync + 'static,
236    V: Clone + Send + Sync + 'static,
237    S: BuildHasher + Send + Sync + 'static,
238{
239    /// Gets a copy (clone) of the value in the entry.
240    pub fn get(&self) -> V {
241        self.guard.get(&self.key).unwrap()
242    }
243
244    /// Sets the value of the entry, returning the old value.
245    pub fn insert(&mut self, value: V) -> V {
246        self.guard.insert(self.key.clone(), value).unwrap()
247    }
248
249    /// Removes the entry from the map, returning the value.
250    pub fn remove(self) -> V {
251        self.guard.remove(&self.key).unwrap()
252    }
253}
254
/// A view into a vacant entry in a `RcuHashMap`.
///
/// Constructed by `RcuHashMapGuard::entry` when the key is absent.
pub struct VacantEntry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// Exclusive borrow of the guard, so no other writer can insert the key meanwhile.
    guard: &'b mut RcuHashMapGuard<'a, K, V, S>,
    /// The key that is absent from the map and will be inserted on demand.
    key: K,
}
265
266impl<'b, 'a, K, V, S> VacantEntry<'b, 'a, K, V, S>
267where
268    K: Eq + Hash + Clone + Send + Sync + 'static,
269    V: Clone + Send + Sync + 'static,
270    S: BuildHasher + Send + Sync + 'static,
271{
272    /// Sets the value of the entry with the VacantEntry's key.
273    pub fn insert(self, value: V) {
274        self.guard.insert(self.key, value);
275    }
276
277    /// Sets the value of the entry with the VacantEntry's key, and returns an occupied entry.
278    pub fn insert_entry(self, value: V) -> OccupiedEntry<'b, 'a, K, V, S> {
279        self.guard.insert(self.key.clone(), value);
280        OccupiedEntry { guard: self.guard, key: self.key }
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_rcu::rcu_synchronize;

    /// Builds an empty `i32 -> i32` map using the default hasher.
    fn new_map() -> RcuHashMap<i32, i32> {
        RcuHashMap::default()
    }

    #[test]
    fn test_rcu_hash_map_custom_hasher() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::BuildHasherDefault;

        let map =
            RcuHashMap::with_capacity_and_hasher(10, BuildHasherDefault::<DefaultHasher>::default());
        let mut guard = map.lock();
        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));
    }

    #[test]
    fn test_rcu_hash_map_insert_and_get() {
        let map = new_map();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        for (k, v) in [(1, 10), (2, 20)] {
            guard.insert(k, v);
        }

        assert_eq!(guard.get(&1), Some(10));
        assert_eq!(guard.get(&2), Some(20));
        assert_eq!(guard.get(&3), None);

        // Readers do not need the writer lock.
        drop(guard);
        for (k, v) in [(1, 10), (2, 20)] {
            assert_eq!(map.get(&scope, &k), Some(&v));
        }

        drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_update() {
        let map = new_map();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        // Each insert overwrites the previous value for the same key.
        for v in [10, 20] {
            guard.insert(1, v);
            assert_eq!(guard.get(&1), Some(v));
        }

        drop(guard);
        assert_eq!(map.get(&scope, &1), Some(&20));

        drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_remove() {
        let map = new_map();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));

        guard.remove(&1);
        assert_eq!(guard.get(&1), None);

        drop(guard);
        assert_eq!(map.get(&scope, &1), None);

        drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_entry_api() {
        let map = new_map();
        let mut guard = map.lock();

        // First access: the key is absent.
        if let Entry::Vacant(vacant) = guard.entry(1) {
            vacant.insert(10);
        } else {
            panic!("Should be vacant");
        }
        assert_eq!(guard.get(&1), Some(10));

        // Second access: the key is present and can be overwritten.
        if let Entry::Occupied(mut occupied) = guard.entry(1) {
            assert_eq!(occupied.get(), 10);
            occupied.insert(20);
        } else {
            panic!("Should be occupied");
        }
        assert_eq!(guard.get(&1), Some(20));

        drop(guard);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_iter() {
        let map = new_map();
        let scope = RcuReadScope::new();
        for (k, v) in [(1, 10), (2, 20), (3, 30)] {
            map.insert(k, v);
        }

        let mut items: Vec<(&i32, &i32)> = map.iter(&scope).collect();
        items.sort_unstable_by_key(|&(k, _)| *k);
        assert_eq!(items, [(&1, &10), (&2, &20), (&3, &30)]);
    }

    #[test]
    fn test_rcu_hash_map_keys() {
        let map = new_map();
        let scope = RcuReadScope::new();
        for (k, v) in [(1, 10), (2, 20), (3, 30)] {
            map.insert(k, v);
        }

        let mut keys: Vec<&i32> = map.keys(&scope).collect();
        keys.sort_unstable();
        assert_eq!(keys, [&1, &2, &3]);
    }

    #[test]
    fn test_rcu_hash_map_or_insert_with() {
        let map = new_map();
        let mut guard = map.lock();

        // A vacant entry takes the value produced by the closure.
        guard.entry(1).or_insert_with(|| 10);
        assert!(guard.contains_key(&1));
        assert_eq!(guard.get(&1), Some(10));

        // An occupied entry keeps its existing value.
        guard.entry(1).or_insert_with(|| 20);
        assert_eq!(guard.get(&1), Some(10));

        // OccupiedEntry::remove hands back the stored value.
        if let Entry::Occupied(occupied) = guard.entry(1) {
            assert_eq!(occupied.remove(), 10);
        } else {
            panic!("Should be occupied");
        }
        assert!(!guard.contains_key(&1));
    }

    #[test]
    fn test_rcu_hash_map_drain() {
        let map = new_map();
        let mut guard = map.lock();

        for (k, v) in [(1, 10), (2, 20), (3, 30)] {
            guard.insert(k, v);
        }

        let mut drained: Vec<(i32, i32)> = guard.drain().collect();
        drained.sort_unstable();
        assert_eq!(drained, [(1, 10), (2, 20), (3, 30)]);

        for k in 1..=3 {
            assert!(!guard.contains_key(&k));
        }
    }
}