1#![warn(unsafe_op_in_unsafe_fn)]
6
7use crate::rcu_array::RcuArray;
8use crate::rcu_intrusive_list::{
9 Link, RcuIntrusiveList, RcuIntrusiveListCursor, RcuListAdapter, rcu_list_adapter,
10};
11use crate::rcu_list::RcuList;
12use fuchsia_rcu::RcuReadScope;
13use std::borrow::Borrow;
14use std::hash::{BuildHasher, Hash, Hasher};
15use std::sync::atomic::{AtomicUsize, Ordering};
16
/// Number of buckets allocated the first time an empty table grows.
const INITIAL_CAPACITY: usize = 16;
19
20#[derive(Debug)]
/// A single key/value pair stored in the map.
///
/// Each entry is simultaneously a member of two intrusive lists: the
/// per-bucket collision chain and the map-wide insertion-order chain.
#[derive(Debug)]
struct Entry<K, V> {
    // The entry's key.
    key: K,

    // The entry's value.
    value: V,

    // Link for the bucket's collision chain (entries whose keys hash to the
    // same bucket).
    collision_chain: Link,

    // Link for the map-wide chain that records insertion order.
    insertion_chain: Link,
}
35
36impl<K, V> Entry<K, V> {
37 fn new(key: K, value: V) -> Self {
39 Self {
40 key,
41 value,
42 collision_chain: Default::default(),
43 insertion_chain: Default::default(),
44 }
45 }
46}
47
/// List adapter selecting `Entry::collision_chain` — used by the
/// per-bucket collision lists.
#[derive(Debug)]
struct CollisionAdapter;

impl<K, V> RcuListAdapter<Entry<K, V>> for CollisionAdapter {
    rcu_list_adapter!(Entry<K, V>, collision_chain);
}
55
/// List adapter selecting `Entry::insertion_chain` — used by the map-wide
/// insertion-order list.
#[derive(Debug)]
struct InsertionAdapter;

impl<K, V> RcuListAdapter<Entry<K, V>> for InsertionAdapter {
    rcu_list_adapter!(Entry<K, V>, insertion_chain);
}
63
/// The outcome of [`RcuRawHashMap::insert`].
pub enum InsertionResult<V> {
    /// The key was not present; a new entry was added. Carries the number
    /// of entries in the map after the insertion.
    Inserted(usize),

    /// The key was already present; its entry was replaced. Carries a
    /// clone of the previous value.
    Updated(V),
}
76
/// A hash bucket: an RCU list of entries chained through `collision_chain`.
type Bucket<K, V> = RcuList<Entry<K, V>, CollisionAdapter>;
81
/// An RCU-based hash map whose readers go through an [`RcuReadScope`].
///
/// Entries are additionally chained in insertion order, which iteration
/// (`cursor`, `keys`) and table growth both follow.
///
/// NOTE(review): the mutation methods are `unsafe` and their original safety
/// docs are not visible in this chunk — presumably writers must be
/// externally serialized; confirm against the RCU list primitives' contracts.
pub struct RcuRawHashMap<K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    // The bucket array; replaced wholesale when the map grows.
    table: RcuArray<Bucket<K, V>>,

    // Number of entries currently in the map (Relaxed atomics).
    num_entries: AtomicUsize,

    // All entries in insertion order; used for iteration and for
    // rebuilding the table in `grow`.
    insertion_chain: RcuIntrusiveList<Entry<K, V>, InsertionAdapter>,

    // Builds the hashers used to map keys to buckets.
    hash_builder: S,
}
105
106impl<K, V> Default for RcuRawHashMap<K, V, rapidhash::RapidBuildHasher>
107where
108 K: Eq + Hash + Clone + Send + Sync + 'static,
109 V: Clone + Send + Sync + 'static,
110{
111 fn default() -> Self {
112 Self::with_capacity_and_hasher(0, rapidhash::RapidBuildHasher::default())
113 }
114}
115
116impl<K, V> RcuRawHashMap<K, V, rapidhash::RapidBuildHasher>
117where
118 K: Eq + Hash + Clone + Send + Sync + 'static,
119 V: Clone + Send + Sync + 'static,
120{
121 pub fn with_capacity(capacity: usize) -> Self {
123 Self::with_capacity_and_hasher(capacity, rapidhash::RapidBuildHasher::default())
124 }
125}
126
impl<K, V, S> RcuRawHashMap<K, V, S>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// Creates an empty map sized for `capacity` entries, hashing with
    /// `hash_builder`.
    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
        let mut table = Vec::new();
        // Allocate ceil(capacity / 2) buckets: `needs_to_grow` triggers at a
        // load factor of 2, so this holds `capacity` entries without growing.
        table.resize_with((capacity + 1) / 2, Default::default);
        Self {
            table: RcuArray::from(table),
            num_entries: AtomicUsize::new(0),
            insertion_chain: Default::default(),
            hash_builder,
        }
    }

    /// Creates an empty map hashing with `hash_builder`; buckets are
    /// allocated lazily on first insert.
    pub fn with_hasher(hash_builder: S) -> Self {
        Self::with_capacity_and_hasher(0, hash_builder)
    }

    /// Computes the 64-bit hash of `key` with this map's hasher.
    fn hash_key<Q>(&self, key: &Q) -> u64
    where
        Q: ?Sized + Hash,
    {
        let mut hasher = self.hash_builder.build_hasher();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns the bucket in `table` that `key` hashes into.
    ///
    /// Panics if `table` is empty (division by zero in the modulo); callers
    /// must check emptiness first — see `read_bucket`.
    fn get_bucket<'a, Q>(&self, table: &'a [Bucket<K, V>], key: &Q) -> &'a Bucket<K, V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash,
    {
        let hash = self.hash_key(key);
        let index = hash as usize % table.len();
        &table[index]
    }

    /// Returns the bucket for `key` under `scope`, or `None` if the table
    /// has not been allocated yet.
    fn read_bucket<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> Option<&'a Bucket<K, V>>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash,
    {
        let table = self.table.as_slice(scope);
        if table.is_empty() {
            return None;
        }
        Some(self.get_bucket(table, key))
    }

    /// Returns a reference to the value for `key`, valid for the lifetime of
    /// `scope`, or `None` if the key is absent.
    pub fn get<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> Option<&'a V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let bucket = self.read_bucket(scope, key)?;
        bucket.iter(scope).find(|entry| entry.key.borrow() == key).map(|entry| &entry.value)
    }

    /// Returns the number of entries in the map.
    ///
    /// Loaded with `Relaxed` ordering, so under concurrent writes this is
    /// only an approximation.
    pub fn len(&self) -> usize {
        self.num_entries.load(Ordering::Relaxed)
    }

    /// Inserts `key`/`value`. If `key` was already present, its entry is
    /// replaced — which also moves the key to the back of the insertion
    /// order — and a clone of the old value is returned as `Updated`.
    /// Otherwise the new entry count is returned as `Inserted`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the original safety contract is not visible in this
    /// chunk. This calls the `unsafe` mutation primitives of `RcuList` and
    /// `RcuIntrusiveList`; presumably callers must serialize all mutations
    /// of this map — confirm against those types' documented contracts.
    pub unsafe fn insert(&self, scope: &RcuReadScope, key: K, value: V) -> InsertionResult<V> {
        let mut table = self.table.as_slice(scope);
        if self.needs_to_grow(table) {
            // SAFETY: forwarding this function's own safety contract.
            table = unsafe { self.grow(&scope, table) };
        }
        let bucket = self.get_bucket(table, &key);
        let mut cursor = bucket.cursor(&scope);
        while let Some(entry) = cursor.current() {
            if entry.key == key {
                let old_value = entry.value.clone();
                // Replace = unlink the old entry from both chains, then push
                // a fresh entry; the new entry lands at the back of the
                // insertion order.
                // SAFETY: forwarding this function's own safety contract.
                unsafe {
                    let removed_entry = cursor.remove();
                    self.insertion_chain.remove(&scope, removed_entry);
                    let entry = bucket.push_front(&scope, Entry::new(key, value));
                    self.insertion_chain.push_back(&scope, entry);
                };
                return InsertionResult::Updated(old_value);
            }
            cursor.advance();
        }

        // Key not found: link a fresh entry into both chains.
        // SAFETY: forwarding this function's own safety contract.
        unsafe {
            let entry = bucket.push_front(&scope, Entry::new(key, value));
            self.insertion_chain.push_back(&scope, entry);
        }
        let count = self.num_entries.fetch_add(1, Ordering::Relaxed);
        InsertionResult::Inserted(count + 1)
    }

    /// Removes `key`, returning a clone of its value if it was present.
    ///
    /// # Safety
    ///
    /// NOTE(review): the original safety contract is not visible in this
    /// chunk; presumably callers must serialize all mutations of this map —
    /// confirm against the `RcuList`/`RcuIntrusiveList` contracts.
    pub unsafe fn remove<Q>(&self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let scope = RcuReadScope::new();
        let bucket = self.read_bucket(&scope, key)?;
        let mut cursor = bucket.cursor(&scope);
        while let Some(entry) = cursor.current() {
            if entry.key.borrow() == key {
                let old_value = entry.value.clone();
                // SAFETY: forwarding this function's own safety contract.
                unsafe {
                    let removed_entry = cursor.remove();
                    self.insertion_chain.remove(&scope, removed_entry);
                };
                self.num_entries.fetch_sub(1, Ordering::Relaxed);
                return Some(old_value);
            }
            cursor.advance();
        }
        None
    }

    /// True when the table is unallocated or the load factor exceeds two
    /// entries per bucket.
    fn needs_to_grow(&self, table: &[Bucket<K, V>]) -> bool {
        table.is_empty() || self.num_entries.load(Ordering::Relaxed) > table.len() * 2
    }

    /// Builds a table twice as large (or `INITIAL_CAPACITY` buckets if the
    /// old table was empty), clones every entry into it in insertion order,
    /// publishes the new table and insertion chain, and returns the new
    /// table slice.
    ///
    /// # Safety
    ///
    /// NOTE(review): original safety contract not visible here; presumably
    /// the same writer-serialization requirement as `insert` applies.
    #[must_use]
    unsafe fn grow<'a>(
        &self,
        scope: &'a RcuReadScope,
        old_table: &[Bucket<K, V>],
    ) -> &'a [Bucket<K, V>] {
        let new_size = if old_table.is_empty() { INITIAL_CAPACITY } else { old_table.len() * 2 };
        let mut new_table = Vec::new();
        let new_insertion_chain = RcuIntrusiveList::default();
        new_table.resize_with(new_size, Default::default);

        // Walk the old insertion chain so the rebuilt map preserves
        // insertion order.
        for entry in self.insertion_chain.iter(scope) {
            let bucket = self.get_bucket(&new_table, &entry.key);
            let key = entry.key.clone();
            let value = entry.value.clone();
            // SAFETY: forwarding this function's own safety contract.
            unsafe {
                let entry = bucket.push_front(&scope, Entry::new(key, value));
                new_insertion_chain.push_back(&scope, entry);
            };
        }

        self.table.update(new_table);
        // SAFETY: forwarding this function's own safety contract.
        unsafe {
            self.insertion_chain.update(&scope, new_insertion_chain);
        }
        self.table.as_slice(scope)
    }

    /// Returns a cursor over the map's entries in insertion order.
    pub fn cursor<'a>(&'a self, scope: &'a RcuReadScope) -> RcuRawHashMapCursor<'a, K, V, S> {
        RcuRawHashMapCursor { inner: self.insertion_chain.cursor(scope), map: self }
    }

    /// Returns an iterator over the keys in insertion order.
    pub fn keys<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = &'a K> {
        self.insertion_chain.iter(scope).map(|entry| &entry.key)
    }
}
340
341impl<K, V, S> std::fmt::Debug for RcuRawHashMap<K, V, S>
343where
344 K: Eq + Hash + Clone + Send + Sync + 'static + std::fmt::Debug,
345 V: Clone + Send + Sync + 'static + std::fmt::Debug,
346 S: std::hash::BuildHasher + Send + Sync + 'static,
347{
348 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
349 f.debug_struct("RcuRawHashMap")
350 .field("table", &self.table)
351 .field("num_entries", &self.num_entries)
352 .field("insertion_chain", &self.insertion_chain)
353 .field("hash_builder", &std::any::type_name::<S>())
354 .finish_non_exhaustive()
355 }
356}
357
/// A cursor over an [`RcuRawHashMap`] that visits entries in insertion
/// order and supports removal during iteration.
pub struct RcuRawHashMapCursor<'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    // Position within the map's insertion-order chain.
    inner: RcuIntrusiveListCursor<'a, Entry<K, V>, InsertionAdapter>,
    // The map being iterated; needed so `remove` can unhash entries.
    map: &'a RcuRawHashMap<K, V, S>,
}
370
impl<'a, K, V, S> RcuRawHashMapCursor<'a, K, V, S>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// Returns the key/value pair the cursor points at, or `None` if the
    /// cursor is exhausted.
    pub fn current(&self) -> Option<(&'a K, &'a V)> {
        self.inner.current().map(|entry| (&entry.key, &entry.value))
    }

    /// Moves the cursor to the next entry in insertion order.
    pub fn advance(&mut self) {
        self.inner.advance()
    }

    /// Removes the entry the cursor points at from the map, returning a
    /// clone of its value; the cursor ends up on the following entry.
    ///
    /// # Safety
    ///
    /// NOTE(review): original safety docs not visible in this chunk; this
    /// forwards to `RcuRawHashMap::remove`, so presumably the same
    /// writer-serialization requirement applies — confirm.
    pub unsafe fn remove(&mut self) -> Option<V> {
        if let Some((key, _)) = self.current() {
            // Step past the entry before unlinking it so the cursor does not
            // dangle on a removed node.
            self.advance();
            // SAFETY: forwarding this function's own safety contract.
            unsafe { self.map.remove(key) }
        } else {
            None
        }
    }
}
407
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_rcu::rcu_synchronize;

    // Insert/get works with a user-supplied (non-default) hasher.
    #[test]
    fn test_rcu_hash_map_custom_hasher() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::BuildHasherDefault;
        let hasher = BuildHasherDefault::<DefaultHasher>::default();
        let map = RcuRawHashMap::with_capacity_and_hasher(10, hasher);
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
        }
        assert_eq!(map.get(&scope, &1), Some(&10));

        // Teardown consistent with the other tests: end the read scope and
        // wait for the grace period.
        std::mem::drop(scope);
        rcu_synchronize();
    }

    // Basic insert/get; absent keys return None.
    #[test]
    fn test_rcu_hash_map_insert_and_get() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
            map.insert(&scope, 2, 20);
        }

        assert_eq!(map.get(&scope, &1), Some(&10));
        assert_eq!(map.get(&scope, &2), Some(&20));
        assert_eq!(map.get(&scope, &3), None);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    // Removal returns the old value and leaves other entries intact.
    #[test]
    fn test_rcu_hash_map_remove() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
            map.insert(&scope, 2, 20);
        }

        assert_eq!(map.get(&scope, &1), Some(&10));

        unsafe {
            assert_eq!(map.remove(&1), Some(10));
        }

        assert_eq!(map.get(&scope, &1), None);
        assert_eq!(map.get(&scope, &2), Some(&20));

        std::mem::drop(scope);
        rcu_synchronize();
    }

    // Inserting an existing key reports Updated with the previous value.
    #[test]
    fn test_rcu_hash_map_insert_update() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
        }

        assert_eq!(map.get(&scope, &1), Some(&10));

        let result = unsafe { map.insert(&scope, 1, 100) };
        assert!(matches!(result, InsertionResult::Updated(10)));

        assert_eq!(map.get(&scope, &1), Some(&100));

        std::mem::drop(scope);
        rcu_synchronize();
    }

    // Cursor walks insertion order and supports mid-iteration removal.
    #[test]
    fn test_rcu_hash_map_cursor() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
            map.insert(&scope, 2, 20);
            map.insert(&scope, 3, 30);
        }

        let mut cursor = map.cursor(&scope);

        assert_eq!(cursor.current(), Some((&1, &10)));
        cursor.advance();
        assert_eq!(cursor.current(), Some((&2, &20)));

        unsafe {
            cursor.remove();
        }

        assert_eq!(cursor.current(), Some((&3, &30)));
        assert_eq!(map.get(&scope, &2), None);

        cursor.advance();
        assert_eq!(cursor.current(), None);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    // Growing the table preserves insertion order.
    #[test]
    fn test_rcu_hash_map_grow_maintains_order() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        let num_elements = INITIAL_CAPACITY * 3;
        let mut expected_order = Vec::new();

        for i in 0..num_elements {
            unsafe {
                map.insert(&scope, i, i * 10);
            }
            expected_order.push((i, i * 10));
        }

        let mut cursor = map.cursor(&scope);
        let mut actual_order = Vec::new();

        while let Some((key, value)) = cursor.current() {
            actual_order.push((*key, *value));
            cursor.advance();
        }

        assert_eq!(actual_order, expected_order);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    // Overwriting a key moves it to the back of the insertion order, and
    // that ordering survives table growth.
    #[test]
    fn test_rcu_hash_map_grow_overwrites_maintain_order() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        let num_elements = INITIAL_CAPACITY * 3;
        let mut expected_order = Vec::new();

        for i in 0..num_elements {
            unsafe {
                map.insert(&scope, i, i * 10);
            }
            expected_order.push((i, i * 10));
        }

        unsafe {
            map.insert(&scope, 5, 500);
            map.insert(&scope, INITIAL_CAPACITY * 3, (INITIAL_CAPACITY * 3) * 10);
        }
        expected_order.retain(|(k, _)| *k != 5);
        expected_order.push((5, 500));
        expected_order.push((INITIAL_CAPACITY * 3, (INITIAL_CAPACITY * 3) * 10));

        let mut cursor = map.cursor(&scope);
        let mut actual_order = Vec::new();

        while let Some((key, value)) = cursor.current() {
            actual_order.push((*key, *value));
            cursor.advance();
        }

        assert_eq!(actual_order, expected_order);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    // All entries remain reachable after multiple growth cycles.
    #[test]
    fn test_rcu_hash_map_grow() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        for i in 0..(INITIAL_CAPACITY * 3) {
            unsafe {
                map.insert(&scope, i, i * 10);
            }
        }

        for i in 0..(INITIAL_CAPACITY * 3) {
            assert_eq!(map.get(&scope, &i), Some(&(i * 10)));
        }

        std::mem::drop(scope);
        rcu_synchronize();
    }

    // A zero-capacity map allocates its table lazily and still supports
    // get/insert/remove.
    #[test]
    fn test_rcu_hash_map_capacity_zero() {
        let map = RcuRawHashMap::with_capacity(0);
        let scope = RcuReadScope::new();

        assert_eq!(map.get(&scope, &1), None);

        unsafe {
            map.insert(&scope, 1, 10);
        }
        assert_eq!(map.get(&scope, &1), Some(&10));

        unsafe {
            assert_eq!(map.remove(&1), Some(10));
        }
        assert_eq!(map.get(&scope, &1), None);

        std::mem::drop(scope);
        rcu_synchronize();
    }
}