1use fuchsia_rcu::RcuReadScope;
6use fuchsia_rcu_collections::rcu_raw_hash_map::{InsertionResult, RcuRawHashMap};
7use starnix_sync::Mutex;
8use std::borrow::Borrow;
9use std::hash::{BuildHasher, Hash};
10
11pub struct RcuHashMap<K, V, S = rapidhash::RapidBuildHasher>
20where
21 K: Eq + Hash + Clone + Send + Sync + 'static,
22 V: Clone + Send + Sync + 'static,
23 S: BuildHasher + Send + Sync + 'static,
24{
25 map: RcuRawHashMap<K, V, S>,
26 mutex: Mutex<()>,
27}
28
29impl<K, V> Default for RcuHashMap<K, V, rapidhash::RapidBuildHasher>
30where
31 K: Eq + Hash + Clone + Send + Sync + 'static,
32 V: Clone + Send + Sync + 'static,
33{
34 fn default() -> Self {
35 Self { map: Default::default(), mutex: Mutex::new(()) }
36 }
37}
38
39impl<K, V, S> RcuHashMap<K, V, S>
40where
41 K: Eq + Hash + Clone + Send + Sync + 'static,
42 V: Clone + Send + Sync + 'static,
43 S: BuildHasher + Send + Sync + 'static,
44{
45 pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
47 Self {
48 map: RcuRawHashMap::with_capacity_and_hasher(capacity, hash_builder),
49 mutex: Mutex::new(()),
50 }
51 }
52
53 pub fn with_hasher(hash_builder: S) -> Self {
55 Self { map: RcuRawHashMap::with_hasher(hash_builder), mutex: Mutex::new(()) }
56 }
57
58 pub fn get<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> Option<&'a V>
62 where
63 K: Borrow<Q>,
64 Q: ?Sized + Hash + Eq,
65 {
66 self.map.get(scope, key)
67 }
68
69 pub fn lock(&self) -> RcuHashMapGuard<'_, K, V, S> {
71 RcuHashMapGuard { map: &self.map, _guard: self.mutex.lock() }
72 }
73
74 pub fn insert(&self, key: K, value: V) -> Option<V> {
76 self.lock().insert(key, value)
77 }
78
79 pub fn remove<Q>(&self, key: &Q) -> Option<V>
81 where
82 K: Borrow<Q>,
83 Q: ?Sized + Hash + Eq,
84 {
85 self.lock().remove(key)
86 }
87
88 pub fn iter<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = (&'a K, &'a V)> {
90 let mut cursor = self.map.cursor(scope);
91 std::iter::from_fn(move || {
92 let current = cursor.current();
93 if current.is_some() {
94 cursor.advance();
95 }
96 current
97 })
98 }
99
100 pub fn keys<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = &'a K> {
102 self.iter(scope).map(|(k, _)| k)
103 }
104
105 pub fn len(&self) -> usize {
107 self.map.len()
108 }
109}
110
111impl<K, V, S> std::fmt::Debug for RcuHashMap<K, V, S>
113where
114 K: Eq + Hash + std::fmt::Debug + Clone + Send + Sync + 'static,
115 V: std::fmt::Debug + Clone + Send + Sync + 'static,
116 S: BuildHasher + Send + Sync + 'static,
117{
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("RcuHashMap").field("map", &self.map).finish()
120 }
121}
122
123pub struct RcuHashMapGuard<'a, K, V, S = rapidhash::RapidBuildHasher>
125where
126 K: Eq + Hash + Clone + Send + Sync + 'static,
127 V: Clone + Send + Sync + 'static,
128 S: BuildHasher + Send + Sync + 'static,
129{
130 map: &'a RcuRawHashMap<K, V, S>,
131 _guard: starnix_sync::MutexGuard<'a, ()>,
132}
133
134impl<'a, K, V, S> RcuHashMapGuard<'a, K, V, S>
135where
136 K: Eq + Hash + Clone + Send + Sync + 'static,
137 V: Clone + Send + Sync + 'static,
138 S: BuildHasher + Send + Sync + 'static,
139{
140 pub fn get<Q>(&self, key: &Q) -> Option<V>
142 where
143 K: Borrow<Q>,
144 Q: ?Sized + Hash + Eq,
145 {
146 let scope = RcuReadScope::new();
147 self.map.get(&scope, key).cloned()
148 }
149
150 pub fn insert(&mut self, key: K, value: V) -> Option<V> {
152 let scope = RcuReadScope::new();
153 match unsafe { self.map.insert(&scope, key, value) } {
155 InsertionResult::Inserted(_) => None,
156 InsertionResult::Updated(old_value) => Some(old_value),
157 }
158 }
159
160 pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
162 where
163 K: Borrow<Q>,
164 Q: ?Sized + Hash + Eq,
165 {
166 unsafe { self.map.remove(key) }
168 }
169
170 pub fn drain<'b>(&'b mut self) -> impl Iterator<Item = (K, V)> + 'b {
172 let scope = RcuReadScope::new();
173 #[allow(clippy::needless_collect)]
175 let keys: Vec<_> = self.map.keys(&scope).map(Clone::clone).collect();
176 keys.into_iter().filter_map(move |k| self.remove(&k).map(|v| (k, v)))
177 }
178
179 pub fn contains_key<Q>(&self, key: &Q) -> bool
181 where
182 K: Borrow<Q>,
183 Q: ?Sized + Hash + Eq,
184 {
185 self.get(key).is_some()
186 }
187
188 pub fn entry<'b>(&'b mut self, key: K) -> Entry<'b, 'a, K, V, S> {
190 if self.get(&key).is_some() {
191 Entry::Occupied(OccupiedEntry { guard: self, key })
192 } else {
193 Entry::Vacant(VacantEntry { guard: self, key })
194 }
195 }
196}
197
198pub enum Entry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
200where
201 K: Eq + Hash + Clone + Send + Sync + 'static,
202 V: Clone + Send + Sync + 'static,
203 S: BuildHasher + Send + Sync + 'static,
204{
205 Occupied(OccupiedEntry<'b, 'a, K, V, S>),
207 Vacant(VacantEntry<'b, 'a, K, V, S>),
209}
210
211impl<'b, 'a, K, V, S> Entry<'b, 'a, K, V, S>
212where
213 K: Eq + Hash + Clone + Send + Sync + 'static,
214 V: Clone + Send + Sync + 'static,
215 S: BuildHasher + Send + Sync + 'static,
216{
217 pub fn or_insert_with<F: FnOnce() -> V>(self, default: F) -> OccupiedEntry<'b, 'a, K, V, S> {
220 match self {
221 Entry::Occupied(entry) => entry,
222 Entry::Vacant(entry) => entry.insert_entry(default()),
223 }
224 }
225}
226
227pub struct OccupiedEntry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
229where
230 K: Eq + Hash + Clone + Send + Sync + 'static,
231 V: Clone + Send + Sync + 'static,
232 S: BuildHasher + Send + Sync + 'static,
233{
234 guard: &'b mut RcuHashMapGuard<'a, K, V, S>,
235 key: K,
236}
237
238impl<K, V, S> OccupiedEntry<'_, '_, K, V, S>
239where
240 K: Eq + Hash + Clone + Send + Sync + 'static,
241 V: Clone + Send + Sync + 'static,
242 S: BuildHasher + Send + Sync + 'static,
243{
244 pub fn get(&self) -> V {
246 self.guard.get(&self.key).unwrap()
247 }
248
249 pub fn insert(&mut self, value: V) -> V {
251 self.guard.insert(self.key.clone(), value).unwrap()
252 }
253
254 pub fn remove(self) -> V {
256 self.guard.remove(&self.key).unwrap()
257 }
258}
259
260pub struct VacantEntry<'b, 'a, K, V, S = rapidhash::RapidBuildHasher>
262where
263 K: Eq + Hash + Clone + Send + Sync + 'static,
264 V: Clone + Send + Sync + 'static,
265 S: BuildHasher + Send + Sync + 'static,
266{
267 guard: &'b mut RcuHashMapGuard<'a, K, V, S>,
268 key: K,
269}
270
271impl<'b, 'a, K, V, S> VacantEntry<'b, 'a, K, V, S>
272where
273 K: Eq + Hash + Clone + Send + Sync + 'static,
274 V: Clone + Send + Sync + 'static,
275 S: BuildHasher + Send + Sync + 'static,
276{
277 pub fn insert(self, value: V) {
279 self.guard.insert(self.key, value);
280 }
281
282 pub fn insert_entry(self, value: V) -> OccupiedEntry<'b, 'a, K, V, S> {
284 self.guard.insert(self.key.clone(), value);
285 OccupiedEntry { guard: self.guard, key: self.key }
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_rcu::rcu_synchronize;

    // Every test ends the same way: drop any guard and read scope it holds, then call
    // `rcu_synchronize()`. Scopes are dropped first because a grace period presumably cannot
    // complete while this thread is still inside a read-side section.

    #[test]
    fn test_rcu_hash_map_custom_hasher() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::BuildHasherDefault;
        let hasher = BuildHasherDefault::<DefaultHasher>::default();
        let map = RcuHashMap::with_capacity_and_hasher(10, hasher);
        let mut guard = map.lock();
        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));

        drop(guard);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_insert_and_get() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        guard.insert(1, 10);
        guard.insert(2, 20);

        assert_eq!(guard.get(&1), Some(10));
        assert_eq!(guard.get(&2), Some(20));
        assert_eq!(guard.get(&3), None);

        // Writes made under the guard are visible to lock-free readers once it is released.
        drop(guard);
        assert_eq!(map.get(&scope, &1), Some(&10));
        assert_eq!(map.get(&scope, &2), Some(&20));

        drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_update() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));

        guard.insert(1, 20);
        assert_eq!(guard.get(&1), Some(20));

        drop(guard);
        assert_eq!(map.get(&scope, &1), Some(&20));

        drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_remove() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();
        let scope = RcuReadScope::new();

        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));

        guard.remove(&1);
        assert_eq!(guard.get(&1), None);

        drop(guard);
        assert_eq!(map.get(&scope, &1), None);

        drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_entry_api() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();

        match guard.entry(1) {
            Entry::Vacant(e) => e.insert(10),
            Entry::Occupied(_) => panic!("Should be vacant"),
        }
        assert_eq!(guard.get(&1), Some(10));

        match guard.entry(1) {
            Entry::Occupied(mut e) => {
                assert_eq!(e.get(), 10);
                e.insert(20);
            }
            Entry::Vacant(_) => panic!("Should be occupied"),
        }
        assert_eq!(guard.get(&1), Some(20));

        drop(guard);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_iter() {
        let map = RcuHashMap::<i32, i32>::default();
        let scope = RcuReadScope::new();
        map.insert(1, 10);
        map.insert(2, 20);
        map.insert(3, 30);

        let mut items: Vec<_> = map.iter(&scope).collect();
        items.sort_by_key(|(k, _)| **k);
        assert_eq!(items, vec![(&1, &10), (&2, &20), (&3, &30)]);

        // `items` borrows from `scope`, so release it before leaving the read-side section.
        drop(items);
        drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_keys() {
        let map = RcuHashMap::<i32, i32>::default();
        let scope = RcuReadScope::new();
        map.insert(1, 10);
        map.insert(2, 20);
        map.insert(3, 30);

        let mut keys: Vec<_> = map.keys(&scope).collect();
        keys.sort();
        assert_eq!(keys, vec![&1, &2, &3]);

        drop(keys);
        drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_len() {
        let map = RcuHashMap::<i32, i32>::default();
        map.insert(1, 10);
        map.insert(2, 20);
        map.insert(3, 30);

        assert_eq!(map.len(), 3);

        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_or_insert_with() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();

        guard.entry(1).or_insert_with(|| 10);
        assert!(guard.contains_key(&1));
        assert_eq!(guard.get(&1), Some(10));

        // A second `or_insert_with` must not overwrite the existing value.
        guard.entry(1).or_insert_with(|| 20);
        assert_eq!(guard.get(&1), Some(10));

        match guard.entry(1) {
            Entry::Occupied(e) => {
                assert_eq!(e.remove(), 10);
            }
            Entry::Vacant(_) => panic!("Should be occupied"),
        }
        assert!(!guard.contains_key(&1));

        drop(guard);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_drain() {
        let map = RcuHashMap::<i32, i32>::default();
        let mut guard = map.lock();

        guard.insert(1, 10);
        guard.insert(2, 20);
        guard.insert(3, 30);

        let mut items: Vec<_> = guard.drain().collect();
        items.sort_by_key(|(k, _)| *k);
        assert_eq!(items, vec![(1, 10), (2, 20), (3, 30)]);

        assert!(!guard.contains_key(&1));
        assert!(!guard.contains_key(&2));
        assert!(!guard.contains_key(&3));

        drop(guard);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_capacity_zero() {
        use std::collections::hash_map::RandomState;
        let map =
            RcuHashMap::<i32, i32, RandomState>::with_capacity_and_hasher(0, RandomState::new());
        let mut guard = map.lock();

        assert_eq!(guard.get(&1), None);

        guard.insert(1, 10);
        assert_eq!(guard.get(&1), Some(10));

        guard.remove(&1);
        assert_eq!(guard.get(&1), None);

        drop(guard);
        rcu_synchronize();
    }
}