1use fuchsia_rcu::RcuReadScope;
6use fuchsia_rcu_collections::rcu_raw_hash_map::{InsertionResult, RcuRawHashMap};
7use starnix_sync::Mutex;
8use std::borrow::Borrow;
9use std::hash::{BuildHasher, Hash};
10
/// A hash map whose readers go through an RCU read scope and whose writers are
/// serialized by an internal mutex.
///
/// Readers use [`RcuHashMap::get`], [`RcuHashMap::iter`] and
/// [`RcuHashMap::keys`] with an [`RcuReadScope`]; none of these touch `mutex`.
/// Writers either call the one-shot helpers ([`RcuHashMap::insert`],
/// [`RcuHashMap::remove`]) or take an explicit [`RcuHashMapGuard`] via
/// [`RcuHashMap::lock`] to perform several mutations under one lock.
///
/// `S` is the hasher builder and defaults to `rapidhash::RapidBuildHasher`.
pub struct RcuHashMap<K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    // The underlying RCU hash map; every read and write is forwarded to it.
    map: RcuRawHashMap<K, V, S>,
    // Serializes writers: it is held for the whole lifetime of each `RcuHashMapGuard`.
    mutex: Mutex<()>,
}
28
29impl<K, V> Default for RcuHashMap<K, V, rapidhash::RapidBuildHasher>
30where
31 K: Eq + Hash + Clone + Send + Sync + 'static,
32 V: Clone + Send + Sync + 'static,
33{
34 fn default() -> Self {
35 Self { map: Default::default(), mutex: Mutex::new(()) }
36 }
37}
38
39impl<K, V, S> RcuHashMap<K, V, S>
40where
41 K: Eq + Hash + Clone + Send + Sync + 'static,
42 V: Clone + Send + Sync + 'static,
43 S: BuildHasher + Send + Sync + 'static,
44{
45 pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
47 Self {
48 map: RcuRawHashMap::with_capacity_and_hasher(capacity, hash_builder),
49 mutex: Mutex::new(()),
50 }
51 }
52
53 pub fn with_hasher(hash_builder: S) -> Self {
55 Self { map: RcuRawHashMap::with_hasher(hash_builder), mutex: Mutex::new(()) }
56 }
57
58 pub fn get<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> Option<&'a V>
62 where
63 K: Borrow<Q>,
64 Q: ?Sized + Hash + Eq,
65 {
66 self.map.get(scope, key)
67 }
68
69 pub fn lock(&self) -> RcuHashMapGuard<'_, K, V, S> {
71 RcuHashMapGuard { map: &self.map, _guard: self.mutex.lock() }
72 }
73
74 pub fn insert(&self, key: K, value: V) -> Option<V> {
76 self.lock().insert(key, value)
77 }
78
79 pub fn remove<Q>(&self, key: &Q) -> Option<V>
81 where
82 K: Borrow<Q>,
83 Q: ?Sized + Hash + Eq,
84 {
85 self.lock().remove(key)
86 }
87
88 pub fn iter<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = (&'a K, &'a V)> {
90 let mut cursor = self.map.cursor(scope);
91 std::iter::from_fn(move || {
92 let current = cursor.current();
93 if current.is_some() {
94 cursor.advance();
95 }
96 current
97 })
98 }
99
100 pub fn keys<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = &'a K> {
102 self.iter(scope).map(|(k, _)| k)
103 }
104}
105
106impl<K, V, S> std::fmt::Debug for RcuHashMap<K, V, S>
108where
109 K: Eq + Hash + std::fmt::Debug + Clone + Send + Sync + 'static,
110 V: std::fmt::Debug + Clone + Send + Sync + 'static,
111 S: BuildHasher + Send + Sync + 'static,
112{
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("RcuHashMap").field("map", &self.map).finish()
115 }
116}
117
/// Exclusive write access to an [`RcuHashMap`], obtained from [`RcuHashMap::lock`].
///
/// While this guard is alive the map's writer mutex is held, so mutations made
/// through it are serialized with respect to every other writer. Readers using
/// [`RcuHashMap::get`] / [`RcuHashMap::iter`] do not take this mutex.
pub struct RcuHashMapGuard<'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    // The map being mutated; borrowed from the owning `RcuHashMap`.
    map: &'a RcuRawHashMap<K, V, S>,
    // Keeps the writer mutex locked; never read, only held until drop.
    _guard: starnix_sync::MutexGuard<'a, ()>,
}
128
129impl<'a, K, V, S> RcuHashMapGuard<'a, K, V, S>
130where
131 K: Eq + Hash + Clone + Send + Sync + 'static,
132 V: Clone + Send + Sync + 'static,
133 S: BuildHasher + Send + Sync + 'static,
134{
135 pub fn get<Q>(&self, key: &Q) -> Option<V>
137 where
138 K: Borrow<Q>,
139 Q: ?Sized + Hash + Eq,
140 {
141 let scope = RcuReadScope::new();
142 self.map.get(&scope, key).cloned()
143 }
144
145 pub fn insert(&mut self, key: K, value: V) -> Option<V> {
147 let scope = RcuReadScope::new();
148 match unsafe { self.map.insert(&scope, key, value) } {
150 InsertionResult::Inserted(_) => None,
151 InsertionResult::Updated(old_value) => Some(old_value),
152 }
153 }
154
155 pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
157 where
158 K: Borrow<Q>,
159 Q: ?Sized + Hash + Eq,
160 {
161 unsafe { self.map.remove(key) }
163 }
164
165 pub fn drain<'b>(&'b mut self) -> impl Iterator<Item = (K, V)> + 'b {
167 let scope = RcuReadScope::new();
168 #[allow(clippy::needless_collect)]
170 let keys: Vec<_> = self.map.keys(&scope).map(Clone::clone).collect();
171 keys.into_iter().filter_map(move |k| self.remove(&k).map(|v| (k, v)))
172 }
173
174 pub fn contains_key<Q>(&self, key: &Q) -> bool
176 where
177 K: Borrow<Q>,
178 Q: ?Sized + Hash + Eq,
179 {
180 self.get(key).is_some()
181 }
182
183 pub fn entry<'b>(&'b mut self, key: K) -> Entry<'b, 'a, K, V, S> {
185 if self.get(&key).is_some() {
186 Entry::Occupied(OccupiedEntry { guard: self, key })
187 } else {
188 Entry::Vacant(VacantEntry { guard: self, key })
189 }
190 }
191}
192
/// A view into a single key of an [`RcuHashMapGuard`], produced by
/// [`RcuHashMapGuard::entry`].
///
/// `'b` is the mutable borrow of the guard; `'a` is the guard's own borrow of
/// the underlying map.
pub enum Entry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// The key was present when the entry was created.
    Occupied(OccupiedEntry<'b, 'a, K, V, S>),
    /// The key was absent when the entry was created.
    Vacant(VacantEntry<'b, 'a, K, V, S>),
}
205
206impl<'b, 'a, K, V, S> Entry<'b, 'a, K, V, S>
207where
208 K: Eq + Hash + Clone + Send + Sync + 'static,
209 V: Clone + Send + Sync + 'static,
210 S: BuildHasher + Send + Sync + 'static,
211{
212 pub fn or_insert_with<F: FnOnce() -> V>(self, default: F) -> OccupiedEntry<'b, 'a, K, V, S> {
215 match self {
216 Entry::Occupied(entry) => entry,
217 Entry::Vacant(entry) => entry.insert_entry(default()),
218 }
219 }
220}
221
/// An [`Entry`] whose key is present in the map.
///
/// Holds the [`RcuHashMapGuard`] mutably, so no other mutation of the map can
/// happen while this entry is alive.
pub struct OccupiedEntry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    // Exclusive borrow of the writer guard that owns the map lock.
    guard: &'b mut RcuHashMapGuard<'a, K, V, S>,
    // The (present) key this entry refers to.
    key: K,
}
232
233impl<K, V, S> OccupiedEntry<'_, '_, K, V, S>
234where
235 K: Eq + Hash + Clone + Send + Sync + 'static,
236 V: Clone + Send + Sync + 'static,
237 S: BuildHasher + Send + Sync + 'static,
238{
239 pub fn get(&self) -> V {
241 self.guard.get(&self.key).unwrap()
242 }
243
244 pub fn insert(&mut self, value: V) -> V {
246 self.guard.insert(self.key.clone(), value).unwrap()
247 }
248
249 pub fn remove(self) -> V {
251 self.guard.remove(&self.key).unwrap()
252 }
253}
254
/// An [`Entry`] whose key is absent from the map.
///
/// Holds the [`RcuHashMapGuard`] mutably, so no other mutation of the map can
/// happen while this entry is alive.
pub struct VacantEntry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    // Exclusive borrow of the writer guard that owns the map lock.
    guard: &'b mut RcuHashMapGuard<'a, K, V, S>,
    // The (absent) key this entry would insert under.
    key: K,
}
265
266impl<'b, 'a, K, V, S> VacantEntry<'b, 'a, K, V, S>
267where
268 K: Eq + Hash + Clone + Send + Sync + 'static,
269 V: Clone + Send + Sync + 'static,
270 S: BuildHasher + Send + Sync + 'static,
271{
272 pub fn insert(self, value: V) {
274 self.guard.insert(self.key, value);
275 }
276
277 pub fn insert_entry(self, value: V) -> OccupiedEntry<'b, 'a, K, V, S> {
279 self.guard.insert(self.key.clone(), value);
280 OccupiedEntry { guard: self.guard, key: self.key }
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_rcu::rcu_synchronize;

    // A map built with a non-default hasher supports basic insert/get through a guard.
    #[test]
    fn test_rcu_hash_map_custom_hasher() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::BuildHasherDefault;
        let hasher = BuildHasherDefault::<DefaultHasher>::default();
        let map = RcuHashMap::with_capacity_and_hasher(10, hasher);
        let mut guard = map.lock();
        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));
    }

    // Values inserted through the guard are visible both via the guard and,
    // after the guard is dropped, via a read scope opened earlier.
    #[test]
    fn test_rcu_hash_map_insert_and_get() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        guard.insert(1, 10);
        guard.insert(2, 20);

        assert_eq!(guard.get(&1), Some(10));
        assert_eq!(guard.get(&2), Some(20));
        assert_eq!(guard.get(&3), None);

        drop(guard);
        // Lock-free read path: no writer guard is held here.
        assert_eq!(map.get(&scope, &1), Some(&10));
        assert_eq!(map.get(&scope, &2), Some(&20));

        // Close the read scope before waiting for a grace period; presumably this
        // lets any deferred reclamation run before the test ends.
        drop(scope);
        rcu_synchronize();
    }

    // Inserting an existing key replaces its value.
    #[test]
    fn test_rcu_hash_map_update() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));

        guard.insert(1, 20);
        assert_eq!(guard.get(&1), Some(20));

        drop(guard);
        assert_eq!(map.get(&scope, &1), Some(&20));

        drop(scope);
        rcu_synchronize();
    }

    // A removed key is gone for both the guard and later readers.
    #[test]
    fn test_rcu_hash_map_remove() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));

        guard.remove(&1);
        assert_eq!(guard.get(&1), None);

        drop(guard);
        assert_eq!(map.get(&scope, &1), None);

        drop(scope);
        rcu_synchronize();
    }

    // The Entry API reports vacant/occupied correctly and supports in-place updates.
    #[test]
    fn test_rcu_hash_map_entry_api() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();

        // Key 1 has never been inserted, so the entry must be vacant.
        match guard.entry(1) {
            Entry::Vacant(e) => e.insert(10),
            Entry::Occupied(_) => panic!("Should be vacant"),
        }
        assert_eq!(guard.get(&1), Some(10));

        // Key 1 now exists; overwrite it through the occupied entry.
        match guard.entry(1) {
            Entry::Occupied(mut e) => {
                assert_eq!(e.get(), 10);
                e.insert(20);
            }
            Entry::Vacant(_) => panic!("Should be occupied"),
        }
        assert_eq!(guard.get(&1), Some(20));

        drop(guard);
        rcu_synchronize();
    }

    // `iter` yields every inserted pair; order is unspecified, so sort before comparing.
    #[test]
    fn test_rcu_hash_map_iter() {
        let map = RcuHashMap::<i32, i32>::default();
        let scope = RcuReadScope::new();
        map.insert(1, 10);
        map.insert(2, 20);
        map.insert(3, 30);

        let mut items: Vec<_> = map.iter(&scope).collect();
        items.sort_by_key(|(k, _)| **k);
        assert_eq!(items, vec![(&1, &10), (&2, &20), (&3, &30)]);
    }

    // `keys` yields every inserted key; order is unspecified, so sort before comparing.
    #[test]
    fn test_rcu_hash_map_keys() {
        let map = RcuHashMap::<i32, i32>::default();
        let scope = RcuReadScope::new();
        map.insert(1, 10);
        map.insert(2, 20);
        map.insert(3, 30);

        let mut keys: Vec<_> = map.keys(&scope).collect();
        keys.sort();
        assert_eq!(keys, vec![&1, &2, &3]);
    }

    // `or_insert_with` inserts only when vacant, and an occupied entry can be removed.
    #[test]
    fn test_rcu_hash_map_or_insert_with() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();

        // Vacant: the closure runs and its value is stored.
        guard.entry(1).or_insert_with(|| 10);
        assert!(guard.contains_key(&1));
        assert_eq!(guard.get(&1), Some(10));

        // Occupied: the existing value must be kept.
        guard.entry(1).or_insert_with(|| 20);
        assert_eq!(guard.get(&1), Some(10));

        // Removing through the occupied entry returns the stored value.
        match guard.entry(1) {
            Entry::Occupied(e) => {
                assert_eq!(e.remove(), 10);
            }
            Entry::Vacant(_) => panic!("Should be occupied"),
        }
        assert!(!guard.contains_key(&1));
    }

    // Fully consuming `drain` returns all pairs and leaves the map empty.
    #[test]
    fn test_rcu_hash_map_drain() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();

        guard.insert(1, 10);
        guard.insert(2, 20);
        guard.insert(3, 30);

        let mut items: Vec<_> = guard.drain().collect();
        items.sort_by_key(|(k, _)| *k);
        assert_eq!(items, vec![(1, 10), (2, 20), (3, 30)]);

        assert!(!guard.contains_key(&1));
        assert!(!guard.contains_key(&2));
        assert!(!guard.contains_key(&3));
    }
}