netstack3_base/data_structures/
socketmap.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Defines generic data structures used to implement common application socket
6//! functionality for multiple protocols.
7//!
8//! The core of this module is the [`SocketMap`] struct. It provides a map-like
9//! API for setting and getting values while maintaining extra information about
10//! the number of values of certain types present in the map.
11
12use alloc::collections::{hash_map, HashMap};
13use core::fmt::Debug;
14use core::hash::Hash;
15use core::num::NonZeroUsize;
16use either::Either;
17
18use derivative::Derivative;
19
/// A type whose values can "shadow" other values of the type.
///
/// An implementation of this trait defines a relationship between values of the
/// type. For any value `s: S`, if `t` appears in
/// `IterShadows::iter_shadows(s)`, then `s` shadows `t`.
///
/// This "shadows" relationship is similar to [`PartialOrd`] in that certain
/// properties must hold:
///
/// 1. transitivity: if `s.iter_shadows()` yields `t`, and `t.iter_shadows()`
///    yields `u`, then `s.iter_shadows()` must also yield `u`.
/// 2. anticyclic: `s` cannot shadow itself.
///
/// These properties are requirements on implementers; the trait definition
/// itself cannot check them.
pub trait IterShadows {
    /// The iterator returned by `iter_shadows`.
    type IterShadows: Iterator<Item = Self>;
    /// Produces an iterator that yields all the shadows of `self`.
    ///
    /// The order of iteration is unspecified.
    fn iter_shadows(&self) -> Self::IterShadows;
}
41
/// A type whose values can be used to produce "tag" values of a different type.
///
/// This can be used to provide a summary value, e.g. even or odd for an
/// integer-like type.
///
/// The type parameter `A` is the type of the address that is handed to
/// [`Tagged::tag`] together with the value being tagged.
pub trait Tagged<A> {
    /// The tag type.
    ///
    /// Tags are required to be `Copy`, comparable with `Eq`, and printable
    /// with `Debug`.
    type Tag: Copy + Eq + core::fmt::Debug;

    /// Returns the tag value for `self` at the given address.
    ///
    /// This function must be deterministic, such that calling `Tagged::tag` on
    /// the same values always returns the same tag value.
    fn tag(&self, address: &A) -> Self::Tag;
}
56
/// A map that stores values and summarizes tag counts.
///
/// This provides a similar insertion/removal API to [`HashMap`] for individual
/// key/value pairs. Unlike a regular `HashMap`, the key type `A` is required to
/// implement [`IterShadows`], and `V` to implement [`Tagged`].
///
/// Since `A` implements `IterShadows`, a given value `a : A` has zero or more
/// shadow values. Since the shadow relationship is transitive, we call any
/// value `v` that is reachable by following shadows of `a` one of `a`'s
/// "ancestors", and we say `a` is a "descendant" of `v`.
///
/// In addition to keys and values, this map stores the number of values
/// present in the map for all descendants of each key. These counts are
/// separated into buckets for different tags of type `V::Tag`.
#[derive(Derivative, Debug)]
#[derivative(Default(bound = ""))]
pub struct SocketMap<A: Hash + Eq, V: Tagged<A>> {
    /// Backing storage, keyed by address.
    ///
    /// A key can be present here without holding a value: ancestors of an
    /// inserted key get a record so that their descendant counts can be
    /// tracked.
    map: HashMap<A, MapValue<V, V::Tag>>,
    /// The number of values stored in the map.
    ///
    /// This is incremented when a value is inserted and decremented when one
    /// is removed, so it can be smaller than the number of keys in `map`.
    len: usize,
}
77
/// The record kept for a single key in the `HashMap` backing a [`SocketMap`].
#[derive(Derivative, Debug)]
#[derivative(Default(bound = ""))]
struct MapValue<V, T> {
    /// The value stored at this key, or `None` if the key is only present to
    /// carry descendant counts.
    value: Option<V>,
    /// Per-tag counts of the values stored at descendants of this key.
    descendant_counts: DescendantCounts<T>,
}
84
/// A set of per-tag counts.
///
/// Counts are stored as [`NonZeroUsize`], so a tag is only listed while its
/// count is at least one. `INLINE_SIZE` is the length of the array type used
/// to parameterize the backing [`smallvec::SmallVec`].
#[derive(Derivative, Debug)]
#[derivative(Default(bound = ""))]
struct DescendantCounts<T, const INLINE_SIZE: usize = 1> {
    /// Holds unordered (tag, count) pairs.
    ///
    /// [`DescendantCounts`] maintains the invariant that tags are unique. The
    /// ordering of tags is unspecified.
    counts: smallvec::SmallVec<[(T, NonZeroUsize); INLINE_SIZE]>,
}
94
/// An entry for a key in a map that has a value.
///
/// This type maintains the invariant that, if an `OccupiedEntry(map, a)`
/// exists, `SocketMap::get(map, a)` is `Some(v)`, i.e. the `HashMap` that
/// [`SocketMap`] wraps contains a [`MapValue`] whose `value` field is
/// `Some(v)`.
///
/// The first field is the map the entry belongs to; the second is the key the
/// entry refers to.
pub struct OccupiedEntry<'a, A: Hash + Eq, V: Tagged<A>>(&'a mut SocketMap<A, V>, A);
102
/// An entry for a key in a map that does not have a value.
///
/// This type maintains the invariant that, if a `VacantEntry(map, a)` exists,
/// `SocketMap::get(map, a)` is `None`. This means that in the `HashMap` that
/// `SocketMap` wraps, either there is no value for key `a` or there is a
/// `MapValue` whose `value` field is `None`.
///
/// The first field is the map the entry belongs to; the second is the key the
/// entry refers to.
#[cfg_attr(test, derive(Debug))]
pub struct VacantEntry<'a, A: Hash + Eq, V: Tagged<A>>(&'a mut SocketMap<A, V>, A);
111
/// An entry in a map that can be used to manipulate the value in-place.
///
/// The variant reflects whether the map held a value for the key at the time
/// the entry was created.
#[cfg_attr(test, derive(Debug))]
pub enum Entry<'a, A: Hash + Eq, V: Tagged<A>> {
    // NB: Both `OccupiedEntry` and `VacantEntry` store a reference to the map
    // and a key directly since they need access to the entire map to update
    // descendant counts. This means that any operation on them requires an
    // additional map lookup with the same key. Experimentation suggests the
    // compiler will optimize this duplicate lookup out, since it is the same
    // one done by `SocketMap::entry` to produce the `Entry` in the first place.
    /// An occupied entry: the key currently has a value.
    Occupied(OccupiedEntry<'a, A, V>),
    /// A vacant entry: the key currently has no value.
    Vacant(VacantEntry<'a, A, V>),
}
126
127impl<A, V> SocketMap<A, V>
128where
129    A: IterShadows + Hash + Eq,
130    V: Tagged<A>,
131{
132    /// Returns the number of entries in this `SocketMap`.
133    pub fn len(&self) -> usize {
134        self.len
135    }
136
137    /// Gets a reference to the value associated with the given key, if any.
138    pub fn get(&self, key: &A) -> Option<&V> {
139        let Self { map, len: _ } = self;
140        map.get(key).and_then(|MapValue { value, descendant_counts: _ }| value.as_ref())
141    }
142
143    /// Provides an [`Entry`] for the given key for in-place manipulation.
144    ///
145    /// This is similar to the API provided by [`HashMap::entry`]. Callers can
146    /// match on the result to perform different actions depending on whether
147    /// the map has a value for the key or not.
148    pub fn entry(&mut self, key: A) -> Entry<'_, A, V> {
149        let Self { map, len: _ } = self;
150        match map.get(&key) {
151            Some(MapValue { descendant_counts: _, value: Some(_) }) => {
152                Entry::Occupied(OccupiedEntry(self, key))
153            }
154            Some(MapValue { descendant_counts: _, value: None }) | None => {
155                Entry::Vacant(VacantEntry(self, key))
156            }
157        }
158    }
159
160    /// Removes the value for the given key if there is one.
161    ///
162    /// If there is a value for key `key`, removes it and returns it. Otherwise
163    /// returns None.
164    #[cfg(test)]
165    pub fn remove(&mut self, key: &A) -> Option<V>
166    where
167        A: Clone,
168    {
169        match self.entry(key.clone()) {
170            Entry::Vacant(_) => return None,
171            Entry::Occupied(o) => Some(o.remove()),
172        }
173    }
174
175    /// Returns counts of tags for values at keys that shadow `key`.
176    ///
177    /// This is equivalent to iterating over all keys in the map, filtering for
178    /// those keys for which `key` is one of their shadows, then calling
179    /// [`Tagged::tag`] on the value for each of those keys, and then computing
180    /// the number of occurrences for each tag.
181    pub fn descendant_counts(
182        &self,
183        key: &A,
184    ) -> impl ExactSizeIterator<Item = &'_ (V::Tag, NonZeroUsize)> {
185        let Self { map, len: _ } = self;
186        map.get(key)
187            .map(|MapValue { value: _, descendant_counts }| {
188                Either::Left(descendant_counts.into_iter())
189            })
190            .unwrap_or(Either::Right(core::iter::empty()))
191    }
192
193    /// Returns an iterator over the keys and values in the map.
194    pub fn iter(&self) -> impl Iterator<Item = (&'_ A, &'_ V)> {
195        let Self { map, len: _ } = self;
196        map.iter().filter_map(|(a, MapValue { value, descendant_counts: _ })| {
197            value.as_ref().map(|v| (a, v))
198        })
199    }
200
201    fn increment_descendant_counts(
202        map: &mut HashMap<A, MapValue<V, V::Tag>>,
203        shadows: A::IterShadows,
204        tag: V::Tag,
205    ) {
206        for shadow in shadows {
207            let MapValue { descendant_counts, value: _ } = map.entry(shadow).or_default();
208            descendant_counts.increment(tag);
209        }
210    }
211
212    fn update_descendant_counts(
213        map: &mut HashMap<A, MapValue<V, V::Tag>>,
214        shadows: A::IterShadows,
215        old_tag: V::Tag,
216        new_tag: V::Tag,
217    ) {
218        if old_tag != new_tag {
219            for shadow in shadows {
220                let counts = &mut map.get_mut(&shadow).unwrap().descendant_counts;
221                counts.increment(new_tag);
222                counts.decrement(old_tag);
223            }
224        }
225    }
226
227    fn decrement_descendant_counts(
228        map: &mut HashMap<A, MapValue<V, V::Tag>>,
229        shadows: A::IterShadows,
230        old_tag: V::Tag,
231    ) {
232        for shadow in shadows {
233            let mut entry = match map.entry(shadow) {
234                hash_map::Entry::Occupied(o) => o,
235                hash_map::Entry::Vacant(_) => unreachable!(),
236            };
237            let MapValue { descendant_counts, value } = entry.get_mut();
238            descendant_counts.decrement(old_tag);
239            if descendant_counts.is_empty() && value.is_none() {
240                let _: MapValue<_, _> = entry.remove();
241            }
242        }
243    }
244}
245
246impl<'a, K: Eq + Hash + IterShadows, V: Tagged<K>> OccupiedEntry<'a, K, V> {
247    /// Gets a reference to the key for the entry.
248    pub fn key(&self) -> &K {
249        let Self(SocketMap { map: _, len: _ }, key) = self;
250        key
251    }
252
253    /// Retrieves the value referenced by this entry.
254    pub fn get(&self) -> &V {
255        let Self(SocketMap { map, len: _ }, key) = self;
256        let MapValue { descendant_counts: _, value } = map.get(key).unwrap();
257        // unwrap() call is guaranteed safe by OccupiedEntry invariant.
258        value.as_ref().unwrap()
259    }
260
261    // NB: there is no get_mut because that would allow the caller to manipulate
262    // a value without updating the descendant tag counts.
263
264    /// Runs the provided callback on the value referenced by this entry.
265    ///
266    /// Returns the result of the callback.
267    pub fn map_mut<R>(&mut self, apply: impl FnOnce(&mut V) -> R) -> R {
268        let Self(SocketMap { map, len: _ }, key) = self;
269        // unwrap() calls are guaranteed safe by OccupiedEntry invariant.
270        let MapValue { descendant_counts: _, value } = map.get_mut(key).unwrap();
271        let value = value.as_mut().unwrap();
272
273        let old_tag = value.tag(key);
274        let r = apply(value);
275        let new_tag = value.tag(key);
276        SocketMap::update_descendant_counts(map, key.iter_shadows(), old_tag, new_tag);
277        r
278    }
279
280    /// Extracts the underlying [`SocketMap`] reference backing this entry.
281    pub fn into_map(self) -> &'a mut SocketMap<K, V> {
282        let Self(socketmap, _) = self;
283        socketmap
284    }
285
286    /// Removes the value from the map and returns it.
287    pub fn remove(self) -> V {
288        let (value, _map) = self.remove_from_map();
289        value
290    }
291
292    /// Gets a reference to the backing map.
293    pub fn get_map(&self) -> &SocketMap<K, V> {
294        let Self(socketmap, _) = self;
295        socketmap
296    }
297
298    /// Removes the value from the map and returns the value and map.
299    pub fn remove_from_map(self) -> (V, &'a mut SocketMap<K, V>) {
300        let Self(socketmap, key) = self;
301        let SocketMap { map, len } = socketmap;
302        let shadows = key.iter_shadows();
303        let mut entry = match map.entry(key) {
304            hash_map::Entry::Occupied(o) => o,
305            hash_map::Entry::Vacant(_) => unreachable!("OccupiedEntry not occupied"),
306        };
307        let tag = {
308            let MapValue { descendant_counts: _, value } = entry.get();
309            // unwrap() is guaranteed safe by OccupiedEntry invariant.
310            value.as_ref().unwrap().tag(entry.key())
311        };
312
313        let MapValue { descendant_counts, value } = entry.get_mut();
314        // unwrap() is guaranteed safe by OccupiedEntry invariant.
315        let value =
316            value.take().expect("OccupiedEntry invariant violated: expected Some, found None");
317        if descendant_counts.is_empty() {
318            let _: MapValue<V, V::Tag> = entry.remove();
319        }
320        SocketMap::decrement_descendant_counts(map, shadows, tag);
321        *len -= 1;
322        (value, socketmap)
323    }
324}
325
326impl<'a, K: Eq + Hash + IterShadows, V: Tagged<K>> VacantEntry<'a, K, V> {
327    /// Inserts a value for the key referenced by this entry.
328    ///
329    /// Returns a reference to the newly-inserted value.
330    pub fn insert(self, value: V) -> OccupiedEntry<'a, K, V>
331    where
332        K: Clone,
333    {
334        let Self(socket_map, key) = self;
335        let SocketMap { map, len } = socket_map;
336        let iter_shadows = key.iter_shadows();
337        let tag = value.tag(&key);
338        *len += 1;
339        SocketMap::increment_descendant_counts(map, iter_shadows, tag);
340        let MapValue { value: map_value, descendant_counts: _ } =
341            map.entry(key.clone()).or_default();
342        assert!(map_value.replace(value).is_none());
343        OccupiedEntry(socket_map, key)
344    }
345
346    /// Extracts the underlying [`SocketMap`] reference backing this entry.
347    pub fn into_map(self) -> &'a mut SocketMap<K, V> {
348        let Self(socketmap, _) = self;
349        socketmap
350    }
351
352    /// Gets the descendant counts for this entry.
353    pub fn descendant_counts(&self) -> impl ExactSizeIterator<Item = &'_ (V::Tag, NonZeroUsize)> {
354        let Self(socket_map, key) = self;
355        socket_map.descendant_counts(&key)
356    }
357}
358
359impl<'a, A: Debug + Eq + Hash, V: Tagged<A>> Debug for OccupiedEntry<'a, A, V> {
360    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
361        let Self(_socket_map, key) = self;
362        f.debug_tuple("OccupiedEntry").field(&"_").field(key).finish()
363    }
364}
365
366impl<T: Eq, const INLINE_SIZE: usize> DescendantCounts<T, INLINE_SIZE> {
367    const ONE: NonZeroUsize = NonZeroUsize::new(1).unwrap();
368
369    /// Increments the count for the given tag.
370    fn increment(&mut self, tag: T) {
371        let Self { counts } = self;
372        match counts.iter_mut().find_map(|(t, count)| (t == &tag).then_some(count)) {
373            Some(count) => *count = NonZeroUsize::new(count.get() + 1).unwrap(),
374            None => counts.push((tag, Self::ONE)),
375        }
376    }
377
378    /// Decrements the count for the given tag.
379    ///
380    /// # Panics
381    ///
382    /// Panics if there is no count for the given tag.
383    fn decrement(&mut self, tag: T) {
384        let Self { counts } = self;
385        let (index, count) = counts
386            .iter_mut()
387            .enumerate()
388            .find_map(|(i, (t, count))| (t == &tag).then_some((i, count)))
389            .unwrap();
390        if let Some(new_count) = NonZeroUsize::new(count.get() - 1) {
391            *count = new_count
392        } else {
393            let _: (T, NonZeroUsize) = counts.swap_remove(index);
394        }
395    }
396
397    fn is_empty(&self) -> bool {
398        let Self { counts } = self;
399        counts.is_empty()
400    }
401}
402
403impl<'d, T, const INLINE_SIZE: usize> IntoIterator for &'d DescendantCounts<T, INLINE_SIZE> {
404    type Item = &'d (T, NonZeroUsize);
405    type IntoIter =
406        <&'d smallvec::SmallVec<[(T, NonZeroUsize); INLINE_SIZE]> as IntoIterator>::IntoIter;
407
408    fn into_iter(self) -> Self::IntoIter {
409        let DescendantCounts { counts } = self;
410        counts.into_iter()
411    }
412}
413
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use alloc::{format, vec};

    use assert_matches::assert_matches;
    use proptest::prop_assert_eq;
    use proptest::strategy::Strategy;

    use super::*;

    /// Test helper for collecting an iterator of pairs into a `HashMap` so
    /// that results can be compared without regard to order.
    trait AsMap {
        type K: Hash + Eq;
        type V;
        fn as_map(self) -> HashMap<Self::K, Self::V>;
    }

    // Implemented for any iterator over borrowed `(key, count)` pairs whose
    // count converts into `usize`; keys are cloned and counts are converted.
    impl<'d, K, V, I> AsMap for I
    where
        K: Hash + Eq + Clone + 'd,
        V: 'd,
        V: Clone + Into<usize>,
        I: Iterator<Item = &'d (K, V)>,
    {
        type K = K;
        type V = usize;
        fn as_map(self) -> HashMap<Self::K, Self::V> {
            self.map(|(k, v)| (k.clone(), v.clone().into())).collect()
        }
    }

    /// A three-level test key: `ABC` is shadowed by `AB` and `A`, and `AB` is
    /// shadowed by `A` (see the `IterShadows` impl below).
    #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
    enum Address {
        A(u8),
        AB(u8, char),
        ABC(u8, char, u8),
    }
    use Address::*;

    impl IterShadows for Address {
        type IterShadows = <Vec<Address> as IntoIterator>::IntoIter;
        fn iter_shadows(&self) -> Self::IterShadows {
            // Each address shadows every shorter address that shares its
            // leading components.
            match self {
                A(_) => vec![],
                AB(a, _) => vec![A(*a)],
                ABC(a, b, _) => vec![AB(*a, *b), A(*a)],
            }
            .into_iter()
        }
    }

    /// A test value made of a tag (`.0`) and a payload (`.1`).
    #[derive(Eq, PartialEq, Clone, Copy, Debug)]
    struct TV<T, V>(T, V);

    impl<T: Copy + Eq + core::fmt::Debug, V> Tagged<Address> for TV<T, V> {
        type Tag = T;

        // The tag is the first field; the address is ignored.
        fn tag(&self, _: &Address) -> Self::Tag {
            self.0
        }
    }

    type TestSocketMap<T> = SocketMap<Address, TV<T, u8>>;

    /// A value that was inserted can be read back, and is gone after removal.
    #[test]
    fn insert_get_remove() {
        let mut map = TestSocketMap::default();

        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(0, 32)));
        assert_eq!(map.get(&ABC(1, 'c', 2)), Some(&TV(0, 32)));

        assert_eq!(map.remove(&ABC(1, 'c', 2)), Some(TV(0, 32)));
        assert_eq!(map.get(&ABC(1, 'c', 2)), None);
    }

    /// The `len` field tracks insertions and removals.
    #[test]
    fn insert_remove_len() {
        let mut map = TestSocketMap::default();
        // Destructuring copies `len` out and ignores `map`, so the socket map
        // remains usable afterwards.
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 0);

        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(0, 32)));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 1);

        assert_eq!(map.remove(&ABC(1, 'c', 2)), Some(TV(0, 32)));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 0);
    }

    /// Asking for the entry of a key that already has a value yields an
    /// occupied entry holding that value, and does not change `len`.
    #[test]
    fn entry_same_key() {
        let mut map = TestSocketMap::default();

        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(0, 32)));
        let occupied = assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Occupied(o) => o);
        assert_eq!(occupied.get(), &TV(0, 32));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 1);
    }

    /// Descendant counts are bucketed per tag and only include values stored
    /// at keys that shadow the queried key.
    #[test]
    fn multiple_insert_descendant_counts() {
        let mut map = TestSocketMap::default();

        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(1, 111)));
        assert_matches!(map.entry(ABC(1, 'd', 2)), Entry::Vacant(v) => v.insert(TV(2, 111)));
        assert_matches!(map.entry(AB(5, 'd')), Entry::Vacant(v) => v.insert(TV(1, 54)));
        assert_matches!(map.entry(AB(1, 'd')),  Entry::Vacant(v) => v.insert(TV(3, 56)));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 4);

        // `A(1)` is shadowed by both `ABC(1, ..)` keys and by `AB(1, 'd')`.
        assert_eq!(map.descendant_counts(&A(1)).as_map(), HashMap::from([(1, 1), (2, 1), (3, 1)]));
        assert_eq!(map.descendant_counts(&AB(1, 'c')).as_map(), HashMap::from([(1, 1)]));
        // The value stored at `AB(1, 'd')` itself is not one of its own
        // descendants.
        assert_eq!(map.descendant_counts(&AB(1, 'd')).as_map(), HashMap::from([(2, 1)]));

        assert_eq!(map.descendant_counts(&A(5)).as_map(), HashMap::from([(1, 1)]));

        // Keys with no descendants, or never touched at all, report nothing.
        assert_eq!(map.descendant_counts(&ABC(1, 'd', 2)).as_map(), HashMap::from([]));
        assert_eq!(map.descendant_counts(&A(2)).as_map(), HashMap::from([]));
    }

    /// Removing the only value leaves the backing map completely empty,
    /// including the records that were created for the key's shadows.
    #[test]
    fn entry_remove_no_shadows() {
        let mut map = TestSocketMap::default();

        assert_matches!(map.entry(ABC(16, 'c', 8)), Entry::Vacant(v) => v.insert(TV(3, 111)));

        let entry = assert_matches!(map.entry(ABC(16, 'c', 8)), Entry::Occupied(o) => o);
        assert_eq!(entry.remove(), TV(3, 111));
        let TestSocketMap { map, len } = map;
        assert_eq!(len, 0);
        assert_eq!(map.len(), 0);
    }

    /// Removing a value from a key that still has descendants keeps the key's
    /// record in the backing map.
    #[test]
    fn entry_remove_with_shadows() {
        let mut map = TestSocketMap::default();

        assert_matches!(map.entry(ABC(16, 'c', 8)), Entry::Vacant(v) => v.insert(TV(2, 112)));
        assert_matches!(map.entry(AB(16, 'c')), Entry::Vacant(v) => v.insert(TV(1, 111)));
        assert_matches!(map.entry(A(16)), Entry::Vacant(v) => v.insert(TV(0, 110)));

        let entry = assert_matches!(map.entry(AB(16, 'c')), Entry::Occupied(o) => o);
        assert_eq!(entry.remove(), TV(1, 111));
        let TestSocketMap { map, len } = map;
        // Two values remain, but the backing map still has three keys:
        // `AB(16, 'c')` keeps its record for the count of `ABC(16, 'c', 8)`.
        assert_eq!(len, 2);
        assert_eq!(map.len(), 3);
    }

    /// Removing a value at an intermediate key only removes that value's
    /// contribution from its ancestor's counts.
    #[test]
    fn remove_ancestor_value() {
        let mut map = TestSocketMap::default();
        assert_matches!(map.entry(ABC(2, 'e', 1)), Entry::Vacant(v) => v.insert(TV(20, 100)));
        assert_matches!(map.entry(AB(2, 'e')), Entry::Vacant(v) => v.insert(TV(20, 100)));
        assert_eq!(map.remove(&AB(2, 'e')), Some(TV(20, 100)));

        // Only the value at `ABC(2, 'e', 1)` is still counted under `A(2)`.
        assert_eq!(map.descendant_counts(&A(2)).as_map(), HashMap::from([(20, 1)]));
    }

    /// Generates `A`, `AB` or `ABC` addresses whose components are drawn from
    /// small ranges.
    fn key_strategy() -> impl Strategy<Value = Address> {
        let a_strategy = 1..5u8;
        let b_strategy = proptest::char::range('a', 'e');
        let c_strategy = 1..5u8;
        // The nested options decide how many components the address has.
        (a_strategy, proptest::option::of((b_strategy, proptest::option::of(c_strategy)))).prop_map(
            |(a, b)| match b {
                None => A(a),
                Some((b, None)) => AB(a, b),
                Some((b, Some(c))) => ABC(a, b, c),
            },
        )
    }

    /// Generates values with a tag in `20..25` and a payload in `100..105`.
    fn value_strategy() -> impl Strategy<Value = TV<u8, u8>> {
        (20..25u8, 100..105u8).prop_map(|(t, v)| TV(t, v))
    }

    /// A single step of the randomized test.
    #[derive(Debug, Copy, Clone, Eq, PartialEq)]
    enum Operation {
        /// Insert the value at the address, or replace the existing value.
        Entry(Address, TV<u8, u8>),
        /// Remove the value at the address, if any.
        Remove(Address),
    }

    impl Operation {
        /// Applies the operation to both the socket map and a plain `HashMap`
        /// used as a reference model, asserting that the two agree.
        fn apply(
            self,
            socket_map: &mut TestSocketMap<u8>,
            reference: &mut HashMap<Address, TV<u8, u8>>,
        ) {
            match self {
                Operation::Entry(a, v) => match (socket_map.entry(a), reference.entry(a)) {
                    (Entry::Occupied(mut s), hash_map::Entry::Occupied(mut h)) => {
                        // Replace the value in both maps; the previous values
                        // must be equal.
                        assert_eq!(s.map_mut(|value| core::mem::replace(value, v)), h.insert(v))
                    }
                    (Entry::Vacant(s), hash_map::Entry::Vacant(h)) => {
                        let _: OccupiedEntry<'_, _, _> = s.insert(v);
                        let _: &mut TV<_, _> = h.insert(v);
                    }
                    // The two maps disagreeing on occupancy is a test failure.
                    (Entry::Occupied(_), hash_map::Entry::Vacant(_)) => {
                        panic!("socketmap has a value for {:?} but reference does not", a)
                    }
                    (Entry::Vacant(_), hash_map::Entry::Occupied(_)) => {
                        panic!("socketmap has no value for {:?} but reference does", a)
                    }
                },
                Operation::Remove(a) => assert_eq!(socket_map.remove(&a), reference.remove(&a)),
            }
        }
    }

    /// Generates either an `Entry` or a `Remove` operation.
    fn operation_strategy() -> impl Strategy<Value = Operation> {
        proptest::prop_oneof!(
            (key_strategy(), value_strategy()).prop_map(|(a, v)| Operation::Entry(a, v)),
            key_strategy().prop_map(|a| Operation::Remove(a)),
        )
    }

    /// Checks the socket map against the reference model.
    ///
    /// Verifies that both hold the same key/value pairs, that `len` matches
    /// the reference, and that the descendant counts stored for every key in
    /// the backing map equal counts recomputed from the stored values.
    fn validate_map(
        map: TestSocketMap<u8>,
        reference: HashMap<Address, TV<u8, u8>>,
    ) -> Result<(), proptest::test_runner::TestCaseError> {
        let map_values: HashMap<_, _> = map.iter().map(|(a, v)| (*a, *v)).collect();
        assert_eq!(map_values, reference);
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, reference.len());

        let TestSocketMap { map: inner_map, len: _ } = &map;
        for (key, entry) in inner_map {
            // All values stored at keys that have `key` among their shadows.
            let descendant_values = map
                .iter()
                .filter(|(k, _)| k.iter_shadows().any(|s| s == *key))
                .map(|(_, value)| value);

            // Fold values into a map from tag to count.
            let expected_tag_counts = descendant_values.fold(HashMap::new(), |mut m, v| {
                *m.entry(v.tag(key)).or_default() += 1;
                m
            });

            let MapValue { descendant_counts, value: _ } = entry;
            prop_assert_eq!(
                expected_tag_counts,
                descendant_counts.into_iter().as_map(),
                "key = {:?}",
                key
            );
        }
        Ok(())
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config {
            // Add all failed seeds here.
            failure_persistence: proptest_support::failed_seeds_no_std!(),
            ..proptest::test_runner::Config::default()
        })]

        // Applies a generated sequence of operations (collection size
        // argument: 10) to a socket map and a reference `HashMap`, then
        // validates the socket map's invariants.
        #[test]
        fn test_arbitrary_operations(operations in proptest::collection::vec(operation_strategy(), 10)) {
            let mut map = TestSocketMap::default();
            let mut reference = HashMap::new();
            for op in operations {
                op.apply(&mut map, &mut reference);
            }

            // After all operations have completed, check invariants for
            // SocketMap.
            validate_map(map, reference)?;
        }

    }
}