1#![warn(unsafe_op_in_unsafe_fn)]
6
7use crate::rcu_array::RcuArray;
8use crate::rcu_intrusive_list::{
9 Link, RcuIntrusiveList, RcuIntrusiveListCursor, RcuListAdapter, rcu_list_adapter,
10};
11use crate::rcu_list::RcuList;
12use fuchsia_rcu::RcuReadScope;
13use std::borrow::Borrow;
14use std::hash::{BuildHasher, Hash, Hasher};
15use std::sync::atomic::{AtomicUsize, Ordering};
16
/// Capacity used by `RcuRawHashMap::default` and `RcuRawHashMap::with_hasher` when the caller
/// does not choose one.
const INITIAL_CAPACITY: usize = 128;
19
/// A single key/value pair stored in the map.
///
/// Every live entry is linked into two intrusive lists at the same time: the collision chain of
/// the bucket its key hashes to, and the map-wide insertion chain.
#[derive(Debug)]
struct Entry<K, V> {
    /// The key this entry is stored under.
    key: K,

    /// The value associated with `key`.
    value: V,

    /// Link used by `CollisionAdapter` to thread this entry into its bucket's list.
    collision_chain: Link,

    /// Link used by `InsertionAdapter` to thread this entry into the map's insertion-order list.
    insertion_chain: Link,
}
35
36impl<K, V> Entry<K, V> {
37 fn new(key: K, value: V) -> Self {
39 Self {
40 key,
41 value,
42 collision_chain: Default::default(),
43 insertion_chain: Default::default(),
44 }
45 }
46}
47
/// List adapter, generated by `rcu_list_adapter!`, that selects the `collision_chain` link of
/// `Entry`. It is the adapter of the per-bucket lists (see `Bucket`).
#[derive(Debug)]
struct CollisionAdapter;

impl<K, V> RcuListAdapter<Entry<K, V>> for CollisionAdapter {
    rcu_list_adapter!(Entry<K, V>, collision_chain);
}
55
/// List adapter, generated by `rcu_list_adapter!`, that selects the `insertion_chain` link of
/// `Entry`. It is the adapter of the map-wide insertion-order list
/// (`RcuRawHashMap::insertion_chain`).
#[derive(Debug)]
struct InsertionAdapter;

impl<K, V> RcuListAdapter<Entry<K, V>> for InsertionAdapter {
    rcu_list_adapter!(Entry<K, V>, insertion_chain);
}
63
/// Outcome of `RcuRawHashMap::insert`.
///
/// Derives the standard comparison/diagnostic traits so callers can compare, print, and clone
/// results instead of being limited to `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InsertionResult<V> {
    /// The key was not present. Carries the number of entries in the map after the insert.
    Inserted(usize),

    /// The key was already present. Carries a clone of the value that was replaced.
    Updated(V),
}
76
/// One slot of the hash table: an RCU list of the entries whose keys hash to this slot, linked
/// through `Entry::collision_chain`.
type Bucket<K, V> = RcuList<Entry<K, V>, CollisionAdapter>;
81
/// A hash map whose lookups only need an `RcuReadScope`.
///
/// Entries live in a table of buckets (each an `RcuList` chained through
/// `Entry::collision_chain`). Each entry is also threaded onto a single insertion-ordered list,
/// which backs iteration (`cursor`, `keys`) and table growth. Reads (`get`, `cursor`, `keys`)
/// are safe. Mutations (`insert`, `remove`) are `unsafe`, and the values they hand back are
/// clones.
///
/// NOTE(review): fields are dropped in declaration order (`table` before `insertion_chain`).
/// Confirm that order is safe for the intrusive list before reordering fields.
pub struct RcuRawHashMap<K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// The buckets. Replaced wholesale (doubled in size) by `grow`.
    table: RcuArray<Bucket<K, V>>,

    /// Number of entries, maintained by `insert` and `remove` with relaxed atomics.
    num_entries: AtomicUsize,

    /// Every entry in insertion order. An updated key is moved to the back.
    insertion_chain: RcuIntrusiveList<Entry<K, V>, InsertionAdapter>,

    /// Produces the hashers used to pick a key's bucket.
    hash_builder: S,
}
105
106impl<K, V> Default for RcuRawHashMap<K, V, rapidhash::RapidBuildHasher>
107where
108 K: Eq + Hash + Clone + Send + Sync + 'static,
109 V: Clone + Send + Sync + 'static,
110{
111 fn default() -> Self {
112 Self::with_capacity_and_hasher(INITIAL_CAPACITY, rapidhash::RapidBuildHasher::default())
113 }
114}
115
116impl<K, V> RcuRawHashMap<K, V, rapidhash::RapidBuildHasher>
117where
118 K: Eq + Hash + Clone + Send + Sync + 'static,
119 V: Clone + Send + Sync + 'static,
120{
121 pub fn with_capacity(capacity: usize) -> Self {
123 Self::with_capacity_and_hasher(capacity, rapidhash::RapidBuildHasher::default())
124 }
125}
126
127impl<K, V, S> RcuRawHashMap<K, V, S>
128where
129 K: Eq + Hash + Clone + Send + Sync + 'static,
130 V: Clone + Send + Sync + 'static,
131 S: BuildHasher + Send + Sync + 'static,
132{
133 pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
135 let mut table = Vec::new();
136 table.resize_with((capacity + 1) / 2, Default::default);
137 Self {
138 table: RcuArray::from(table),
139 num_entries: AtomicUsize::new(0),
140 insertion_chain: Default::default(),
141 hash_builder,
142 }
143 }
144
145 pub fn with_hasher(hash_builder: S) -> Self {
147 Self::with_capacity_and_hasher(INITIAL_CAPACITY, hash_builder)
148 }
149
150 fn hash_key<Q>(&self, key: &Q) -> u64
152 where
153 Q: ?Sized + Hash,
154 {
155 let mut hasher = self.hash_builder.build_hasher();
156 key.hash(&mut hasher);
157 hasher.finish()
158 }
159
160 fn get_bucket<'a, Q>(&self, table: &'a [Bucket<K, V>], key: &Q) -> &'a Bucket<K, V>
162 where
163 K: Borrow<Q>,
164 Q: ?Sized + Hash,
165 {
166 let hash = self.hash_key(key);
167 let index = hash as usize % table.len();
168 &table[index]
169 }
170
171 fn read_bucket<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> &'a Bucket<K, V>
173 where
174 K: Borrow<Q>,
175 Q: ?Sized + Hash,
176 {
177 let table = self.table.as_slice(scope);
178 self.get_bucket(table, key)
179 }
180
181 pub fn get<'a, Q>(&self, scope: &'a RcuReadScope, key: &Q) -> Option<&'a V>
185 where
186 K: Borrow<Q>,
187 Q: ?Sized + Hash + Eq,
188 {
189 let bucket = self.read_bucket(scope, key);
190 bucket.iter(scope).find(|entry| entry.key.borrow() == key).map(|entry| &entry.value)
191 }
192
193 pub fn len(&self) -> usize {
197 self.num_entries.load(Ordering::Relaxed)
198 }
199
200 pub unsafe fn insert(&self, scope: &RcuReadScope, key: K, value: V) -> InsertionResult<V> {
214 let mut table = self.table.as_slice(scope);
215 if self.needs_to_grow(table) {
216 table = unsafe { self.grow(&scope, table) };
219 }
220 let bucket = self.get_bucket(table, &key);
221 let mut cursor = bucket.cursor(&scope);
222 while let Some(entry) = cursor.current() {
223 if entry.key == key {
224 let old_value = entry.value.clone();
225 unsafe {
228 let removed_entry = cursor.remove();
229 self.insertion_chain.remove(&scope, removed_entry);
230 let entry = bucket.push_front(&scope, Entry::new(key, value));
231 self.insertion_chain.push_back(&scope, entry);
232 };
233 return InsertionResult::Updated(old_value);
234 }
235 cursor.advance();
236 }
237
238 unsafe {
241 let entry = bucket.push_front(&scope, Entry::new(key, value));
242 self.insertion_chain.push_back(&scope, entry);
243 }
244 let count = self.num_entries.fetch_add(1, Ordering::Relaxed);
245 InsertionResult::Inserted(count + 1)
246 }
247
248 pub unsafe fn remove<Q>(&self, key: &Q) -> Option<V>
258 where
259 K: Borrow<Q>,
260 Q: ?Sized + Hash + Eq,
261 {
262 let scope = RcuReadScope::new();
263 let bucket = self.read_bucket(&scope, key);
264 let mut cursor = bucket.cursor(&scope);
265 while let Some(entry) = cursor.current() {
266 if entry.key.borrow() == key {
267 let old_value = entry.value.clone();
268 unsafe {
271 let removed_entry = cursor.remove();
272 self.insertion_chain.remove(&scope, removed_entry);
273 };
274 self.num_entries.fetch_sub(1, Ordering::Relaxed);
275 return Some(old_value);
276 }
277 cursor.advance();
278 }
279 None
280 }
281
282 fn needs_to_grow(&self, table: &[Bucket<K, V>]) -> bool {
284 self.num_entries.load(Ordering::Relaxed) > table.len() * 2
285 }
286
287 #[must_use]
296 unsafe fn grow<'a>(
297 &self,
298 scope: &'a RcuReadScope,
299 old_table: &[Bucket<K, V>],
300 ) -> &'a [Bucket<K, V>] {
301 let new_size = old_table.len() * 2;
302 let mut new_table = Vec::new();
303 let new_insertion_chain = RcuIntrusiveList::default();
304 new_table.resize_with(new_size, Default::default);
305
306 for entry in self.insertion_chain.iter(scope) {
307 let bucket = self.get_bucket(&new_table, &entry.key);
308 let key = entry.key.clone();
309 let value = entry.value.clone();
310 unsafe {
312 let entry = bucket.push_front(&scope, Entry::new(key, value));
313 new_insertion_chain.push_back(&scope, entry);
314 };
315 }
316
317 self.table.update(new_table);
318 unsafe {
320 self.insertion_chain.update(&scope, new_insertion_chain);
321 }
322 self.table.as_slice(scope)
323 }
324
325 pub fn cursor<'a>(&'a self, scope: &'a RcuReadScope) -> RcuRawHashMapCursor<'a, K, V, S> {
329 RcuRawHashMapCursor { inner: self.insertion_chain.cursor(scope), map: self }
330 }
331
332 pub fn keys<'a>(&'a self, scope: &'a RcuReadScope) -> impl Iterator<Item = &'a K> {
334 self.insertion_chain.iter(scope).map(|entry| &entry.key)
335 }
336}
337
338impl<K, V, S> std::fmt::Debug for RcuRawHashMap<K, V, S>
340where
341 K: Eq + Hash + Clone + Send + Sync + 'static + std::fmt::Debug,
342 V: Clone + Send + Sync + 'static + std::fmt::Debug,
343 S: std::hash::BuildHasher + Send + Sync + 'static,
344{
345 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346 f.debug_struct("RcuRawHashMap")
347 .field("table", &self.table)
348 .field("num_entries", &self.num_entries)
349 .field("insertion_chain", &self.insertion_chain)
350 .field("hash_builder", &std::any::type_name::<S>())
351 .finish_non_exhaustive()
352 }
353}
354
/// A cursor over the entries of an `RcuRawHashMap` in insertion order, obtained from
/// `RcuRawHashMap::cursor`.
pub struct RcuRawHashMapCursor<'a, K, V, S = rapidhash::RapidBuildHasher>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Send + Sync + 'static,
{
    /// Current position in the map's insertion chain.
    inner: RcuIntrusiveListCursor<'a, Entry<K, V>, InsertionAdapter>,
    /// The map being traversed. `remove` goes through it so the entry is unlinked from its
    /// bucket as well as from the insertion chain.
    map: &'a RcuRawHashMap<K, V, S>,
}
367
368impl<'a, K, V, S> RcuRawHashMapCursor<'a, K, V, S>
369where
370 K: Eq + Hash + Clone + Send + Sync + 'static,
371 V: Clone + Send + Sync + 'static,
372 S: BuildHasher + Send + Sync + 'static,
373{
374 pub fn current(&self) -> Option<(&'a K, &'a V)> {
376 self.inner.current().map(|entry| (&entry.key, &entry.value))
377 }
378
379 pub fn advance(&mut self) {
381 self.inner.advance()
382 }
383
384 pub unsafe fn remove(&mut self) -> Option<V> {
395 if let Some((key, _)) = self.current() {
396 self.advance();
397 unsafe { self.map.remove(key) }
399 } else {
400 None
401 }
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_rcu::rcu_synchronize;

    #[test]
    fn test_rcu_hash_map_custom_hasher() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::BuildHasherDefault;
        let hasher = BuildHasherDefault::<DefaultHasher>::default();
        let map = RcuRawHashMap::with_capacity_and_hasher(10, hasher);
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
        }
        assert_eq!(map.get(&scope, &1), Some(&10));

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_insert_and_get() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
            map.insert(&scope, 2, 20);
        }

        assert_eq!(map.get(&scope, &1), Some(&10));
        assert_eq!(map.get(&scope, &2), Some(&20));
        assert_eq!(map.get(&scope, &3), None);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_remove() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
            map.insert(&scope, 2, 20);
        }

        assert_eq!(map.get(&scope, &1), Some(&10));

        unsafe {
            assert_eq!(map.remove(&1), Some(10));
        }

        assert_eq!(map.get(&scope, &1), None);
        assert_eq!(map.get(&scope, &2), Some(&20));

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_insert_update() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
        }

        assert_eq!(map.get(&scope, &1), Some(&10));

        let result = unsafe { map.insert(&scope, 1, 100) };
        assert!(matches!(result, InsertionResult::Updated(10)));

        assert_eq!(map.get(&scope, &1), Some(&100));

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_len_keys_and_missing_remove() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        assert_eq!(map.len(), 0);

        let first = unsafe { map.insert(&scope, 1, 10) };
        assert!(matches!(first, InsertionResult::Inserted(1)));
        let second = unsafe { map.insert(&scope, 2, 20) };
        assert!(matches!(second, InsertionResult::Inserted(2)));
        let update = unsafe { map.insert(&scope, 1, 11) };
        assert!(matches!(update, InsertionResult::Updated(10)));

        // Updating an existing key leaves the count alone and moves the key to the back.
        assert_eq!(map.len(), 2);
        assert_eq!(map.keys(&scope).copied().collect::<Vec<_>>(), vec![2, 1]);

        unsafe {
            assert_eq!(map.remove(&2), Some(20));
            // Removing an absent key (already removed, or never inserted) is a no-op.
            assert_eq!(map.remove(&2), None);
            assert_eq!(map.remove(&3), None);
        }
        assert_eq!(map.len(), 1);
        assert_eq!(map.keys(&scope).copied().collect::<Vec<_>>(), vec![1]);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_cursor() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        unsafe {
            map.insert(&scope, 1, 10);
            map.insert(&scope, 2, 20);
            map.insert(&scope, 3, 30);
        }

        let mut cursor = map.cursor(&scope);

        assert_eq!(cursor.current(), Some((&1, &10)));
        cursor.advance();
        assert_eq!(cursor.current(), Some((&2, &20)));

        unsafe {
            cursor.remove();
        }

        assert_eq!(cursor.current(), Some((&3, &30)));
        assert_eq!(map.get(&scope, &2), None);

        cursor.advance();
        assert_eq!(cursor.current(), None);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_grow_maintains_order() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        let num_elements = INITIAL_CAPACITY * 3;
        let mut expected_order = Vec::new();

        for i in 0..num_elements {
            unsafe {
                map.insert(&scope, i, i * 10);
            }
            expected_order.push((i, i * 10));
        }

        let mut cursor = map.cursor(&scope);
        let mut actual_order = Vec::new();

        while let Some((key, value)) = cursor.current() {
            actual_order.push((*key, *value));
            cursor.advance();
        }

        assert_eq!(actual_order, expected_order);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_grow_overwrites_maintain_order() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        let num_elements = INITIAL_CAPACITY * 3;
        let mut expected_order = Vec::new();

        for i in 0..num_elements {
            unsafe {
                map.insert(&scope, i, i * 10);
            }
            expected_order.push((i, i * 10));
        }

        unsafe {
            map.insert(&scope, 5, 500);
            map.insert(&scope, INITIAL_CAPACITY * 3, (INITIAL_CAPACITY * 3) * 10);
        }
        expected_order.retain(|(k, _)| *k != 5);
        expected_order.push((5, 500));
        expected_order.push((INITIAL_CAPACITY * 3, (INITIAL_CAPACITY * 3) * 10));

        let mut cursor = map.cursor(&scope);
        let mut actual_order = Vec::new();

        while let Some((key, value)) = cursor.current() {
            actual_order.push((*key, *value));
            cursor.advance();
        }

        assert_eq!(actual_order, expected_order);

        std::mem::drop(scope);
        rcu_synchronize();
    }

    #[test]
    fn test_rcu_hash_map_grow() {
        let map = RcuRawHashMap::default();
        let scope = RcuReadScope::new();
        for i in 0..(INITIAL_CAPACITY * 3) {
            unsafe {
                map.insert(&scope, i, i * 10);
            }
        }

        for i in 0..(INITIAL_CAPACITY * 3) {
            assert_eq!(map.get(&scope, &i), Some(&(i * 10)));
        }

        std::mem::drop(scope);
        rcu_synchronize();
    }
}