Skip to main content

starnix_rcu/
rcu_cache.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use fuchsia_rcu::RcuReadScope;
6use fuchsia_rcu_collections::rcu_raw_hash_map::{InsertionResult, RcuRawHashMap};
7use starnix_sync::{Mutex, MutexGuard};
8use std::hash::Hash;
9
/// The outcome of inserting a key-value pair into an `RcuCache`.
///
/// Derives `Debug`, `PartialEq` and `Eq` (when `V` supports them) so results can be logged and
/// compared directly in assertions.
#[derive(Debug, PartialEq, Eq)]
pub enum RcuCacheInsertionResult<V> {
    /// The key was not present before and the entry was added without displacing anything.
    Inserted,

    /// The key was already present and its value was replaced.
    ///
    /// Carries the value that was previously stored under the key.
    Updated(V),

    /// The key was not present before, and adding it caused another entry to be evicted.
    ///
    /// Carries the value of the evicted entry.
    Evicted(V),
}
24
25/// A cache that uses RCU to provide thread-safe access to a hash map.
26///
27/// This is similar to `RcuHashMap`, but it also evicts items when the cache
28/// exceeds a specified capacity.
29///
30/// Entries are evicted in a FIFO manner.
31///
32/// By default, this map uses `rapidhash::RapidBuildHasher`, which provides high performance.
33/// However, if this map holds keys which may be attacker-controlled, consider using
34/// `std::collections::hash_map::RandomState` instead.
35#[derive(Debug)]
36pub struct RcuCache<K, V, S = rapidhash::RapidBuildHasher>
37where
38    K: Eq + Hash + Clone + Send + Sync + 'static,
39    V: Clone + Send + Sync + 'static,
40    S: std::hash::BuildHasher + Send + Sync + 'static,
41{
42    /// The maximum number of entries in the cache.
43    capacity: usize,
44
45    /// The underlying hash map.
46    map: RcuRawHashMap<K, V, S>,
47
48    /// A mutex to provide synchronization for writing to the map.
49    mutex: Mutex<()>,
50}
51
52impl<K, V> RcuCache<K, V, rapidhash::RapidBuildHasher>
53where
54    K: Eq + Hash + Clone + Send + Sync + 'static,
55    V: Clone + Send + Sync + 'static,
56{
57    /// Creates a new `RcuCache` with the specified capacity.
58    pub fn new(capacity: usize) -> Self {
59        Self { capacity, map: RcuRawHashMap::with_capacity(capacity + 1), mutex: Mutex::new(()) }
60    }
61}
62
63impl<K, V, S> RcuCache<K, V, S>
64where
65    K: Eq + Hash + Clone + Send + Sync + 'static,
66    V: Clone + Send + Sync + 'static,
67    S: std::hash::BuildHasher + Send + Sync + 'static,
68{
69    /// Creates a new `RcuCache` with the specified capacity and hasher.
70    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
71        Self {
72            capacity,
73            map: RcuRawHashMap::with_capacity_and_hasher(capacity + 1, hash_builder),
74            mutex: Mutex::new(()),
75        }
76    }
77}
78
79impl<K, V, S> RcuCache<K, V, S>
80where
81    K: Eq + Hash + Clone + Send + Sync + 'static,
82    V: Clone + Send + Sync + 'static,
83    S: std::hash::BuildHasher + Send + Sync + 'static,
84{
85    /// Returns the capacity with which this instance was initialized.
86    pub fn capacity(&self) -> usize {
87        self.capacity
88    }
89
90    /// Returns the number of entries in the cache.
91    pub fn len(&self) -> usize {
92        self.map.len()
93    }
94
95    /// Returns a reference to the value associated with the key.
96    pub fn get<'a>(&self, scope: &'a RcuReadScope, key: &K) -> Option<&'a V> {
97        self.map.get(scope, key)
98    }
99
100    pub fn lock(&self) -> RcuCacheGuard<'_, K, V, S> {
101        let guard = self.mutex.lock();
102        RcuCacheGuard { cache: self, _guard: guard }
103    }
104
105    /// Removes all entries from the cache.
106    pub fn clear(&self) {
107        let _guard = self.mutex.lock();
108        let scope = RcuReadScope::new();
109        let mut cursor = self.map.cursor(&scope);
110        loop {
111            // SAFETY: We have exclusive access to the map because we have exclusive access to the
112            // mutex.
113            let removed = unsafe { cursor.remove() };
114            if removed.is_none() {
115                break;
116            }
117        }
118    }
119}
120
121pub struct RcuCacheGuard<'a, K, V, S = rapidhash::RapidBuildHasher>
122where
123    K: Eq + Hash + Clone + Send + Sync + 'static,
124    V: Clone + Send + Sync + 'static,
125    S: std::hash::BuildHasher + Send + Sync + 'static,
126{
127    cache: &'a RcuCache<K, V, S>,
128    _guard: MutexGuard<'a, ()>,
129}
130
131impl<'a, K, V, S> RcuCacheGuard<'a, K, V, S>
132where
133    K: Eq + Hash + Clone + Send + Sync + 'static,
134    V: Clone + Send + Sync + 'static,
135    S: std::hash::BuildHasher + Send + Sync + 'static,
136{
137    pub fn get<'rcu>(&self, scope: &'rcu RcuReadScope, key: &K) -> Option<&'rcu V> {
138        self.cache.map.get(scope, key)
139    }
140
141    /// Inserts a key-value pair into the cache.
142    ///
143    /// If the cache exceeds its capacity, entries are evicted in a FIFO manner.
144    pub fn insert(&self, scope: &RcuReadScope, key: K, value: V) -> RcuCacheInsertionResult<V> {
145        // SAFETY: We have exclusive access to the map because we have exclusive access to the mutex.
146        match unsafe { self.cache.map.insert(scope, key, value) } {
147            InsertionResult::Inserted(count) => {
148                if count > self.cache.capacity {
149                    // The mutex should prevent any other modifications to the map while the insert
150                    // operation is in progress.
151                    assert!(count == self.cache.capacity + 1);
152                    let mut cursor = self.cache.map.cursor(&scope);
153                    // SAFETY: We have exclusive access to the map because we have exclusive access
154                    // to the mutex.
155                    if let Some(old_value) = unsafe { cursor.remove() } {
156                        RcuCacheInsertionResult::Evicted(old_value)
157                    } else {
158                        unreachable!("cache is full but no entries to evict")
159                    }
160                } else {
161                    RcuCacheInsertionResult::Inserted
162                }
163            }
164            InsertionResult::Updated(old_value) => RcuCacheInsertionResult::Updated(old_value),
165        }
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_rcu::rcu_synchronize;

    /// Exercises FIFO eviction, in-place updates, and the result reported by each insert.
    #[test]
    fn test_rcu_cache_fifo_eviction() {
        let capacity = 3;
        let cache = RcuCache::new(capacity);
        let guard = cache.lock();
        let scope = RcuReadScope::new();

        assert_eq!(cache.capacity(), capacity);
        assert_eq!(cache.len(), 0);

        // Insert items up to capacity
        assert!(matches!(guard.insert(&scope, 1, 10), RcuCacheInsertionResult::Inserted));
        assert!(matches!(guard.insert(&scope, 2, 20), RcuCacheInsertionResult::Inserted));
        assert!(matches!(guard.insert(&scope, 3, 30), RcuCacheInsertionResult::Inserted));
        assert_eq!(cache.len(), capacity);

        assert_eq!(guard.get(&scope, &1), Some(&10));
        assert_eq!(guard.get(&scope, &2), Some(&20));
        assert_eq!(guard.get(&scope, &3), Some(&30));

        // Insert an item beyond capacity, should evict 1
        assert!(matches!(guard.insert(&scope, 4, 40), RcuCacheInsertionResult::Evicted(10)));
        assert_eq!(cache.len(), capacity);

        assert_eq!(cache.get(&scope, &1), None);
        assert_eq!(cache.get(&scope, &2), Some(&20));
        assert_eq!(cache.get(&scope, &3), Some(&30));
        assert_eq!(cache.get(&scope, &4), Some(&40));

        // Insert another item, should evict 2
        assert!(matches!(guard.insert(&scope, 5, 50), RcuCacheInsertionResult::Evicted(20)));

        assert_eq!(cache.get(&scope, &1), None);
        assert_eq!(cache.get(&scope, &2), None);
        assert_eq!(cache.get(&scope, &3), Some(&30));
        assert_eq!(cache.get(&scope, &4), Some(&40));
        assert_eq!(cache.get(&scope, &5), Some(&50));

        // Update an existing item, should not evict and not change order for eviction
        assert!(matches!(guard.insert(&scope, 3, 300), RcuCacheInsertionResult::Updated(30)));
        assert_eq!(cache.len(), capacity);

        assert_eq!(cache.get(&scope, &1), None);
        assert_eq!(cache.get(&scope, &2), None);
        assert_eq!(cache.get(&scope, &3), Some(&300));
        assert_eq!(cache.get(&scope, &4), Some(&40));
        assert_eq!(cache.get(&scope, &5), Some(&50));

        // Insert another item, should evict 4 (because 3 was updated, not re-inserted)
        assert!(matches!(guard.insert(&scope, 6, 60), RcuCacheInsertionResult::Evicted(40)));
        assert_eq!(cache.len(), capacity);

        assert_eq!(cache.get(&scope, &1), None);
        assert_eq!(cache.get(&scope, &2), None);
        assert_eq!(cache.get(&scope, &3), Some(&300));
        assert_eq!(cache.get(&scope, &4), None);
        assert_eq!(cache.get(&scope, &5), Some(&50));
        assert_eq!(cache.get(&scope, &6), Some(&60));

        std::mem::drop(guard);
        std::mem::drop(scope);
        rcu_synchronize();
    }

    /// `clear` must leave the cache empty. It takes the mutex itself, so the write guard has to
    /// be released before calling it.
    #[test]
    fn test_rcu_cache_clear() {
        let cache = RcuCache::new(2);

        let guard = cache.lock();
        let scope = RcuReadScope::new();
        guard.insert(&scope, 1, 10);
        guard.insert(&scope, 2, 20);
        assert_eq!(cache.len(), 2);
        std::mem::drop(guard);
        std::mem::drop(scope);

        cache.clear();
        assert_eq!(cache.len(), 0);

        let scope = RcuReadScope::new();
        assert_eq!(cache.get(&scope, &1), None);
        assert_eq!(cache.get(&scope, &2), None);
        std::mem::drop(scope);

        rcu_synchronize();
    }
}