1use alloc::collections::{hash_map, HashMap};
13use core::fmt::Debug;
14use core::hash::Hash;
15use core::num::NonZeroUsize;
16use either::Either;
17
18use derivative::Derivative;
19
/// A key type that can enumerate the keys shadowing it.
///
/// When a value is stored at a key, `SocketMap` records that value's tag in
/// the descendant counts of every key this trait yields for it.
pub trait IterShadows {
    /// The iterator returned by [`IterShadows::iter_shadows`]; it yields keys
    /// of the same type as `Self`.
    type IterShadows: Iterator<Item = Self>;

    /// Returns an iterator over the keys that shadow `self`.
    fn iter_shadows(&self) -> Self::IterShadows;
}
41
/// A value that reports a small tag for itself when stored at an address of
/// type `A`.
///
/// `SocketMap` aggregates these tags: for each key it keeps per-tag counts of
/// the values stored at the keys it shadows.
pub trait Tagged<A> {
    /// The tag type. It is stored inline in per-key descendant counts and
    /// compared for equality, hence the `Copy + Eq` bounds.
    type Tag: Copy + Eq + core::fmt::Debug;

    /// Returns the tag for `self` when stored at `address`.
    fn tag(&self, address: &A) -> Self::Tag;
}
56
57#[derive(Derivative, Debug)]
72#[derivative(Default(bound = ""))]
73pub struct SocketMap<A: Hash + Eq, V: Tagged<A>> {
74 map: HashMap<A, MapValue<V, V::Tag>>,
75 len: usize,
76}
77
78#[derive(Derivative, Debug)]
79#[derivative(Default(bound = ""))]
80struct MapValue<V, T> {
81 value: Option<V>,
82 descendant_counts: DescendantCounts<T>,
83}
84
85#[derive(Derivative, Debug)]
86#[derivative(Default(bound = ""))]
87struct DescendantCounts<T, const INLINE_SIZE: usize = 1> {
88 counts: smallvec::SmallVec<[(T, NonZeroUsize); INLINE_SIZE]>,
93}
94
95pub struct OccupiedEntry<'a, A: Hash + Eq, V: Tagged<A>>(&'a mut SocketMap<A, V>, A);
102
103#[cfg_attr(test, derive(Debug))]
110pub struct VacantEntry<'a, A: Hash + Eq, V: Tagged<A>>(&'a mut SocketMap<A, V>, A);
111
112#[cfg_attr(test, derive(Debug))]
114pub enum Entry<'a, A: Hash + Eq, V: Tagged<A>> {
115 Occupied(OccupiedEntry<'a, A, V>),
123 Vacant(VacantEntry<'a, A, V>),
125}
126
127impl<A, V> SocketMap<A, V>
128where
129 A: IterShadows + Hash + Eq,
130 V: Tagged<A>,
131{
132 pub fn len(&self) -> usize {
134 self.len
135 }
136
137 pub fn get(&self, key: &A) -> Option<&V> {
139 let Self { map, len: _ } = self;
140 map.get(key).and_then(|MapValue { value, descendant_counts: _ }| value.as_ref())
141 }
142
143 pub fn entry(&mut self, key: A) -> Entry<'_, A, V> {
149 let Self { map, len: _ } = self;
150 match map.get(&key) {
151 Some(MapValue { descendant_counts: _, value: Some(_) }) => {
152 Entry::Occupied(OccupiedEntry(self, key))
153 }
154 Some(MapValue { descendant_counts: _, value: None }) | None => {
155 Entry::Vacant(VacantEntry(self, key))
156 }
157 }
158 }
159
160 #[cfg(test)]
165 pub fn remove(&mut self, key: &A) -> Option<V>
166 where
167 A: Clone,
168 {
169 match self.entry(key.clone()) {
170 Entry::Vacant(_) => return None,
171 Entry::Occupied(o) => Some(o.remove()),
172 }
173 }
174
175 pub fn descendant_counts(
182 &self,
183 key: &A,
184 ) -> impl ExactSizeIterator<Item = &'_ (V::Tag, NonZeroUsize)> {
185 let Self { map, len: _ } = self;
186 map.get(key)
187 .map(|MapValue { value: _, descendant_counts }| {
188 Either::Left(descendant_counts.into_iter())
189 })
190 .unwrap_or(Either::Right(core::iter::empty()))
191 }
192
193 pub fn iter(&self) -> impl Iterator<Item = (&'_ A, &'_ V)> {
195 let Self { map, len: _ } = self;
196 map.iter().filter_map(|(a, MapValue { value, descendant_counts: _ })| {
197 value.as_ref().map(|v| (a, v))
198 })
199 }
200
201 fn increment_descendant_counts(
202 map: &mut HashMap<A, MapValue<V, V::Tag>>,
203 shadows: A::IterShadows,
204 tag: V::Tag,
205 ) {
206 for shadow in shadows {
207 let MapValue { descendant_counts, value: _ } = map.entry(shadow).or_default();
208 descendant_counts.increment(tag);
209 }
210 }
211
212 fn update_descendant_counts(
213 map: &mut HashMap<A, MapValue<V, V::Tag>>,
214 shadows: A::IterShadows,
215 old_tag: V::Tag,
216 new_tag: V::Tag,
217 ) {
218 if old_tag != new_tag {
219 for shadow in shadows {
220 let counts = &mut map.get_mut(&shadow).unwrap().descendant_counts;
221 counts.increment(new_tag);
222 counts.decrement(old_tag);
223 }
224 }
225 }
226
227 fn decrement_descendant_counts(
228 map: &mut HashMap<A, MapValue<V, V::Tag>>,
229 shadows: A::IterShadows,
230 old_tag: V::Tag,
231 ) {
232 for shadow in shadows {
233 let mut entry = match map.entry(shadow) {
234 hash_map::Entry::Occupied(o) => o,
235 hash_map::Entry::Vacant(_) => unreachable!(),
236 };
237 let MapValue { descendant_counts, value } = entry.get_mut();
238 descendant_counts.decrement(old_tag);
239 if descendant_counts.is_empty() && value.is_none() {
240 let _: MapValue<_, _> = entry.remove();
241 }
242 }
243 }
244}
245
246impl<'a, K: Eq + Hash + IterShadows, V: Tagged<K>> OccupiedEntry<'a, K, V> {
247 pub fn key(&self) -> &K {
249 let Self(SocketMap { map: _, len: _ }, key) = self;
250 key
251 }
252
253 pub fn get(&self) -> &V {
255 let Self(SocketMap { map, len: _ }, key) = self;
256 let MapValue { descendant_counts: _, value } = map.get(key).unwrap();
257 value.as_ref().unwrap()
259 }
260
261 pub fn map_mut<R>(&mut self, apply: impl FnOnce(&mut V) -> R) -> R {
268 let Self(SocketMap { map, len: _ }, key) = self;
269 let MapValue { descendant_counts: _, value } = map.get_mut(key).unwrap();
271 let value = value.as_mut().unwrap();
272
273 let old_tag = value.tag(key);
274 let r = apply(value);
275 let new_tag = value.tag(key);
276 SocketMap::update_descendant_counts(map, key.iter_shadows(), old_tag, new_tag);
277 r
278 }
279
280 pub fn into_map(self) -> &'a mut SocketMap<K, V> {
282 let Self(socketmap, _) = self;
283 socketmap
284 }
285
286 pub fn remove(self) -> V {
288 let (value, _map) = self.remove_from_map();
289 value
290 }
291
292 pub fn get_map(&self) -> &SocketMap<K, V> {
294 let Self(socketmap, _) = self;
295 socketmap
296 }
297
298 pub fn remove_from_map(self) -> (V, &'a mut SocketMap<K, V>) {
300 let Self(socketmap, key) = self;
301 let SocketMap { map, len } = socketmap;
302 let shadows = key.iter_shadows();
303 let mut entry = match map.entry(key) {
304 hash_map::Entry::Occupied(o) => o,
305 hash_map::Entry::Vacant(_) => unreachable!("OccupiedEntry not occupied"),
306 };
307 let tag = {
308 let MapValue { descendant_counts: _, value } = entry.get();
309 value.as_ref().unwrap().tag(entry.key())
311 };
312
313 let MapValue { descendant_counts, value } = entry.get_mut();
314 let value =
316 value.take().expect("OccupiedEntry invariant violated: expected Some, found None");
317 if descendant_counts.is_empty() {
318 let _: MapValue<V, V::Tag> = entry.remove();
319 }
320 SocketMap::decrement_descendant_counts(map, shadows, tag);
321 *len -= 1;
322 (value, socketmap)
323 }
324}
325
326impl<'a, K: Eq + Hash + IterShadows, V: Tagged<K>> VacantEntry<'a, K, V> {
327 pub fn insert(self, value: V) -> OccupiedEntry<'a, K, V>
331 where
332 K: Clone,
333 {
334 let Self(socket_map, key) = self;
335 let SocketMap { map, len } = socket_map;
336 let iter_shadows = key.iter_shadows();
337 let tag = value.tag(&key);
338 *len += 1;
339 SocketMap::increment_descendant_counts(map, iter_shadows, tag);
340 let MapValue { value: map_value, descendant_counts: _ } =
341 map.entry(key.clone()).or_default();
342 assert!(map_value.replace(value).is_none());
343 OccupiedEntry(socket_map, key)
344 }
345
346 pub fn into_map(self) -> &'a mut SocketMap<K, V> {
348 let Self(socketmap, _) = self;
349 socketmap
350 }
351
352 pub fn descendant_counts(&self) -> impl ExactSizeIterator<Item = &'_ (V::Tag, NonZeroUsize)> {
354 let Self(socket_map, key) = self;
355 socket_map.descendant_counts(&key)
356 }
357}
358
359impl<'a, A: Debug + Eq + Hash, V: Tagged<A>> Debug for OccupiedEntry<'a, A, V> {
360 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
361 let Self(_socket_map, key) = self;
362 f.debug_tuple("OccupiedEntry").field(&"_").field(key).finish()
363 }
364}
365
366impl<T: Eq, const INLINE_SIZE: usize> DescendantCounts<T, INLINE_SIZE> {
367 const ONE: NonZeroUsize = NonZeroUsize::new(1).unwrap();
368
369 fn increment(&mut self, tag: T) {
371 let Self { counts } = self;
372 match counts.iter_mut().find_map(|(t, count)| (t == &tag).then_some(count)) {
373 Some(count) => *count = NonZeroUsize::new(count.get() + 1).unwrap(),
374 None => counts.push((tag, Self::ONE)),
375 }
376 }
377
378 fn decrement(&mut self, tag: T) {
384 let Self { counts } = self;
385 let (index, count) = counts
386 .iter_mut()
387 .enumerate()
388 .find_map(|(i, (t, count))| (t == &tag).then_some((i, count)))
389 .unwrap();
390 if let Some(new_count) = NonZeroUsize::new(count.get() - 1) {
391 *count = new_count
392 } else {
393 let _: (T, NonZeroUsize) = counts.swap_remove(index);
394 }
395 }
396
397 fn is_empty(&self) -> bool {
398 let Self { counts } = self;
399 counts.is_empty()
400 }
401}
402
403impl<'d, T, const INLINE_SIZE: usize> IntoIterator for &'d DescendantCounts<T, INLINE_SIZE> {
404 type Item = &'d (T, NonZeroUsize);
405 type IntoIter =
406 <&'d smallvec::SmallVec<[(T, NonZeroUsize); INLINE_SIZE]> as IntoIterator>::IntoIter;
407
408 fn into_iter(self) -> Self::IntoIter {
409 let DescendantCounts { counts } = self;
410 counts.into_iter()
411 }
412}
413
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use alloc::{format, vec};

    use assert_matches::assert_matches;
    use proptest::prop_assert_eq;
    use proptest::strategy::Strategy;

    use super::*;

    /// Collects an iterator of `&(key, count)` pairs into an owned map with
    /// `usize` counts, so descendant counts can be compared against `HashMap`s.
    trait AsMap {
        type K: Hash + Eq;
        type V;
        fn as_map(self) -> HashMap<Self::K, Self::V>;
    }

    impl<'d, K, V, I> AsMap for I
    where
        K: Hash + Eq + Clone + 'd,
        V: 'd,
        V: Clone + Into<usize>,
        I: Iterator<Item = &'d (K, V)>,
    {
        type K = K;
        type V = usize;
        fn as_map(self) -> HashMap<Self::K, Self::V> {
            self.map(|(k, v)| (k.clone(), v.clone().into())).collect()
        }
    }

    /// A three-level test key: `ABC(a, b, c)` is shadowed by `AB(a, b)` and
    /// `A(a)`; `AB(a, b)` is shadowed by `A(a)`; `A(a)` has no shadows.
    #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
    enum Address {
        A(u8),
        AB(u8, char),
        ABC(u8, char, u8),
    }
    use Address::*;

    impl IterShadows for Address {
        type IterShadows = <Vec<Address> as IntoIterator>::IntoIter;
        fn iter_shadows(&self) -> Self::IterShadows {
            match self {
                A(_) => vec![],
                AB(a, _) => vec![A(*a)],
                ABC(a, b, _) => vec![AB(*a, *b), A(*a)],
            }
            .into_iter()
        }
    }

    /// A test value whose tag is its first field, independent of the address.
    #[derive(Eq, PartialEq, Clone, Copy, Debug)]
    struct TV<T, V>(T, V);

    impl<T: Copy + Eq + core::fmt::Debug, V> Tagged<Address> for TV<T, V> {
        type Tag = T;

        fn tag(&self, _: &Address) -> Self::Tag {
            self.0
        }
    }

    type TestSocketMap<T> = SocketMap<Address, TV<T, u8>>;

    #[test]
    fn insert_get_remove() {
        let mut map = TestSocketMap::default();

        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(0, 32)));
        assert_eq!(map.get(&ABC(1, 'c', 2)), Some(&TV(0, 32)));

        assert_eq!(map.remove(&ABC(1, 'c', 2)), Some(TV(0, 32)));
        assert_eq!(map.get(&ABC(1, 'c', 2)), None);
    }

    #[test]
    fn insert_remove_len() {
        let mut map = TestSocketMap::default();
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 0);

        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(0, 32)));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 1);

        assert_eq!(map.remove(&ABC(1, 'c', 2)), Some(TV(0, 32)));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 0);
    }

    #[test]
    fn entry_same_key() {
        let mut map = TestSocketMap::default();

        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(0, 32)));
        let occupied = assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Occupied(o) => o);
        assert_eq!(occupied.get(), &TV(0, 32));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 1);
    }

    #[test]
    fn multiple_insert_descendant_counts() {
        let mut map = TestSocketMap::default();

        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(1, 111)));
        assert_matches!(map.entry(ABC(1, 'd', 2)), Entry::Vacant(v) => v.insert(TV(2, 111)));
        assert_matches!(map.entry(AB(5, 'd')), Entry::Vacant(v) => v.insert(TV(1, 54)));
        assert_matches!(map.entry(AB(1, 'd')), Entry::Vacant(v) => v.insert(TV(3, 56)));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 4);

        assert_eq!(map.descendant_counts(&A(1)).as_map(), HashMap::from([(1, 1), (2, 1), (3, 1)]));
        assert_eq!(map.descendant_counts(&AB(1, 'c')).as_map(), HashMap::from([(1, 1)]));
        assert_eq!(map.descendant_counts(&AB(1, 'd')).as_map(), HashMap::from([(2, 1)]));

        assert_eq!(map.descendant_counts(&A(5)).as_map(), HashMap::from([(1, 1)]));

        assert_eq!(map.descendant_counts(&ABC(1, 'd', 2)).as_map(), HashMap::from([]));
        assert_eq!(map.descendant_counts(&A(2)).as_map(), HashMap::from([]));
    }

    #[test]
    fn entry_remove_no_shadows() {
        let mut map = TestSocketMap::default();

        assert_matches!(map.entry(ABC(16, 'c', 8)), Entry::Vacant(v) => v.insert(TV(3, 111)));

        let entry = assert_matches!(map.entry(ABC(16, 'c', 8)), Entry::Occupied(o) => o);
        assert_eq!(entry.remove(), TV(3, 111));
        let TestSocketMap { map, len } = map;
        assert_eq!(len, 0);
        assert_eq!(map.len(), 0);
    }

    #[test]
    fn entry_remove_with_shadows() {
        let mut map = TestSocketMap::default();

        assert_matches!(map.entry(ABC(16, 'c', 8)), Entry::Vacant(v) => v.insert(TV(2, 112)));
        assert_matches!(map.entry(AB(16, 'c')), Entry::Vacant(v) => v.insert(TV(1, 111)));
        assert_matches!(map.entry(A(16)), Entry::Vacant(v) => v.insert(TV(0, 110)));

        let entry = assert_matches!(map.entry(AB(16, 'c')), Entry::Occupied(o) => o);
        assert_eq!(entry.remove(), TV(1, 111));
        let TestSocketMap { map, len } = map;
        assert_eq!(len, 2);
        assert_eq!(map.len(), 3);
    }

    #[test]
    fn remove_ancestor_value() {
        let mut map = TestSocketMap::default();
        assert_matches!(map.entry(ABC(2, 'e', 1)), Entry::Vacant(v) => v.insert(TV(20, 100)));
        assert_matches!(map.entry(AB(2, 'e')), Entry::Vacant(v) => v.insert(TV(20, 100)));
        assert_eq!(map.remove(&AB(2, 'e')), Some(TV(20, 100)));

        assert_eq!(map.descendant_counts(&A(2)).as_map(), HashMap::from([(20, 1)]));
    }

    /// Generates keys from a small space so that operations frequently collide.
    fn key_strategy() -> impl Strategy<Value = Address> {
        let a_strategy = 1..5u8;
        let b_strategy = proptest::char::range('a', 'e');
        let c_strategy = 1..5u8;
        (a_strategy, proptest::option::of((b_strategy, proptest::option::of(c_strategy)))).prop_map(
            |(a, b)| match b {
                None => A(a),
                Some((b, None)) => AB(a, b),
                Some((b, Some(c))) => ABC(a, b, c),
            },
        )
    }

    fn value_strategy() -> impl Strategy<Value = TV<u8, u8>> {
        (20..25u8, 100..105u8).prop_map(|(t, v)| TV(t, v))
    }

    #[derive(Debug, Copy, Clone, Eq, PartialEq)]
    enum Operation {
        Entry(Address, TV<u8, u8>),
        Remove(Address),
    }

    impl Operation {
        /// Applies the operation to both `socket_map` and the plain `reference`
        /// map, asserting that they agree on the outcome.
        fn apply(
            self,
            socket_map: &mut TestSocketMap<u8>,
            reference: &mut HashMap<Address, TV<u8, u8>>,
        ) {
            match self {
                Operation::Entry(a, v) => match (socket_map.entry(a), reference.entry(a)) {
                    (Entry::Occupied(mut s), hash_map::Entry::Occupied(mut h)) => {
                        assert_eq!(s.map_mut(|value| core::mem::replace(value, v)), h.insert(v))
                    }
                    (Entry::Vacant(s), hash_map::Entry::Vacant(h)) => {
                        let _: OccupiedEntry<'_, _, _> = s.insert(v);
                        let _: &mut TV<_, _> = h.insert(v);
                    }
                    (Entry::Occupied(_), hash_map::Entry::Vacant(_)) => {
                        panic!("socketmap has a value for {:?} but reference does not", a)
                    }
                    (Entry::Vacant(_), hash_map::Entry::Occupied(_)) => {
                        panic!("socketmap has no value for {:?} but reference does", a)
                    }
                },
                Operation::Remove(a) => assert_eq!(socket_map.remove(&a), reference.remove(&a)),
            }
        }
    }

    fn operation_strategy() -> impl Strategy<Value = Operation> {
        proptest::prop_oneof!(
            (key_strategy(), value_strategy()).prop_map(|(a, v)| Operation::Entry(a, v)),
            key_strategy().prop_map(Operation::Remove),
        )
    }

    /// Checks that `map` agrees with `reference` on stored values and length,
    /// and that every key's descendant counts match a from-scratch recount.
    fn validate_map(
        map: TestSocketMap<u8>,
        reference: HashMap<Address, TV<u8, u8>>,
    ) -> Result<(), proptest::test_runner::TestCaseError> {
        let map_values: HashMap<_, _> = map.iter().map(|(a, v)| (*a, *v)).collect();
        prop_assert_eq!(&map_values, &reference);
        let TestSocketMap { len, map: _ } = map;
        prop_assert_eq!(len, reference.len());

        let TestSocketMap { map: inner_map, len: _ } = &map;
        for (key, entry) in inner_map {
            // Recount the tags of every stored value that `key` shadows. Each
            // tag is computed against the value's own key, matching how
            // `SocketMap` tags values when they are inserted.
            let expected_tag_counts = map
                .iter()
                .filter(|(k, _)| k.iter_shadows().any(|s| s == *key))
                .fold(HashMap::<u8, usize>::new(), |mut m, (k, v)| {
                    *m.entry(v.tag(k)).or_default() += 1;
                    m
                });

            let MapValue { descendant_counts, value: _ } = entry;
            prop_assert_eq!(
                expected_tag_counts,
                descendant_counts.into_iter().as_map(),
                "key = {:?}",
                key
            );
        }
        Ok(())
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config {
            failure_persistence: proptest_support::failed_seeds_no_std!(),
            ..proptest::test_runner::Config::default()
        })]

        #[test]
        fn test_arbitrary_operations(operations in proptest::collection::vec(operation_strategy(), 10)) {
            let mut map = TestSocketMap::default();
            let mut reference = HashMap::new();
            for op in operations {
                op.apply(&mut map, &mut reference);
            }

            validate_map(map, reference)?;
        }
    }
}
685}