use alloc::collections::{hash_map, HashMap};
use core::fmt::Debug;
use core::hash::Hash;
use core::num::NonZeroUsize;
use either::Either;
use derivative::Derivative;
/// A key type whose instances can enumerate their "shadows".
///
/// [`SocketMap`] treats the shadows of a key as the keys whose descendant
/// counts must reflect a value stored at that key: inserting, removing, or
/// re-tagging a value at `k` adjusts the per-tag count at every key yielded by
/// `k.iter_shadows()`.
pub trait IterShadows {
    /// Iterator over the shadows of a key.
    type IterShadows: Iterator<Item = Self>;
    /// Returns an iterator over the keys that `self` shadows-lists.
    fn iter_shadows(&self) -> Self::IterShadows;
}
/// A value that can be summarized by a small, copyable tag.
///
/// [`SocketMap`] keeps descendant counts grouped by this tag. The tag may
/// depend on both the value and the address it is stored at.
pub trait Tagged<A> {
    /// The summary type that descendant counts are grouped by.
    type Tag: Copy + Eq + core::fmt::Debug;
    /// Computes the tag for `self` when stored at `address`.
    fn tag(&self, address: &A) -> Self::Tag;
}
/// A map from keys to values that also tracks, for every key, how many values
/// (grouped by [`Tagged::Tag`]) are stored at keys that list it as a shadow
/// (see [`IterShadows`]).
///
/// A key may be present in the underlying map without a value when it exists
/// only to hold descendant counts.
#[derive(Derivative, Debug)]
// `bound = ""` derives `Default` without requiring `A: Default` or `V: Default`.
#[derivative(Default(bound = ""))]
pub struct SocketMap<A: Hash + Eq, V: Tagged<A>> {
    // Per-key slot: the optional value plus its descendant counts.
    map: HashMap<A, MapValue<V, V::Tag>>,
    // Number of keys that currently hold a value. This is not the number of
    // entries in `map`, which can also contain value-less slots.
    len: usize,
}
/// The contents of a single key's slot in a [`SocketMap`].
#[derive(Derivative, Debug)]
#[derivative(Default(bound = ""))]
struct MapValue<V, T> {
    // The value stored at this key, or `None` if the slot exists only to carry
    // descendant counts.
    value: Option<V>,
    // Per-tag counts of values stored at keys that list this key as a shadow.
    descendant_counts: DescendantCounts<T>,
}
/// Per-tag counts of descendant values, stored as `(tag, count)` pairs.
///
/// A pair is removed when its count would drop to zero, so every stored count
/// is non-zero. Up to `INLINE_SIZE` pairs fit in the `SmallVec`'s inline
/// storage before it spills to the heap.
#[derive(Derivative, Debug)]
#[derivative(Default(bound = ""))]
struct DescendantCounts<T, const INLINE_SIZE: usize = 1> {
    counts: smallvec::SmallVec<[(T, NonZeroUsize); INLINE_SIZE]>,
}
/// A view into a [`SocketMap`] slot whose key currently holds a value.
///
/// Holds the map mutably borrowed together with the entry's key.
pub struct OccupiedEntry<'a, A: Hash + Eq, V: Tagged<A>>(&'a mut SocketMap<A, V>, A);
/// A view into a [`SocketMap`] for a key that holds no value.
///
/// The key may be absent from the map, or present only to carry descendant
/// counts.
#[cfg_attr(test, derive(Debug))]
pub struct VacantEntry<'a, A: Hash + Eq, V: Tagged<A>>(&'a mut SocketMap<A, V>, A);
/// The result of [`SocketMap::entry`]: either an occupied or a vacant slot.
#[cfg_attr(test, derive(Debug))]
pub enum Entry<'a, A: Hash + Eq, V: Tagged<A>> {
    /// The key holds a value.
    Occupied(OccupiedEntry<'a, A, V>),
    /// The key holds no value (it may still have descendant counts).
    Vacant(VacantEntry<'a, A, V>),
}
impl<A, V> SocketMap<A, V>
where
    A: IterShadows + Hash + Eq,
    V: Tagged<A>,
{
    /// Returns the number of keys that currently hold a value.
    ///
    /// Slots that exist only to carry descendant counts are not counted.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns the value stored at `key`, if any.
    pub fn get(&self, key: &A) -> Option<&V> {
        let Self { map, len: _ } = self;
        map.get(key).and_then(|MapValue { value, descendant_counts: _ }| value.as_ref())
    }

    /// Returns an [`Entry`] for `key`.
    ///
    /// The entry is [`Entry::Occupied`] iff `key` currently holds a value. A key
    /// present only to carry descendant counts yields [`Entry::Vacant`].
    pub fn entry(&mut self, key: A) -> Entry<'_, A, V> {
        let Self { map, len: _ } = self;
        match map.get(&key) {
            Some(MapValue { descendant_counts: _, value: Some(_) }) => {
                Entry::Occupied(OccupiedEntry(self, key))
            }
            Some(MapValue { descendant_counts: _, value: None }) | None => {
                Entry::Vacant(VacantEntry(self, key))
            }
        }
    }

    /// Removes and returns the value at `key`, or `None` if it holds none.
    #[cfg(test)]
    pub fn remove(&mut self, key: &A) -> Option<V>
    where
        A: Clone,
    {
        match self.entry(key.clone()) {
            Entry::Vacant(_) => None,
            Entry::Occupied(o) => Some(o.remove()),
        }
    }

    /// Returns the per-tag counts of values stored at keys that list `key` as a
    /// shadow. The iterator is empty if `key` has no slot in the map.
    pub fn descendant_counts(
        &self,
        key: &A,
    ) -> impl ExactSizeIterator<Item = &'_ (V::Tag, NonZeroUsize)> {
        let Self { map, len: _ } = self;
        map.get(key)
            .map(|MapValue { value: _, descendant_counts }| {
                Either::Left(descendant_counts.into_iter())
            })
            .unwrap_or(Either::Right(core::iter::empty()))
    }

    /// Iterates over the `(key, value)` pairs of keys that hold a value,
    /// skipping value-less slots.
    pub fn iter(&self) -> impl Iterator<Item = (&'_ A, &'_ V)> {
        let Self { map, len: _ } = self;
        map.iter().filter_map(|(a, MapValue { value, descendant_counts: _ })| {
            value.as_ref().map(|v| (a, v))
        })
    }

    /// Increments the count for `tag` at every key in `shadows`, creating an
    /// empty slot for any shadow not yet in the map.
    fn increment_descendant_counts(
        map: &mut HashMap<A, MapValue<V, V::Tag>>,
        shadows: A::IterShadows,
        tag: V::Tag,
    ) {
        for shadow in shadows {
            let MapValue { descendant_counts, value: _ } = map.entry(shadow).or_default();
            descendant_counts.increment(tag);
        }
    }

    /// Moves one count from `old_tag` to `new_tag` at every key in `shadows`.
    /// Does nothing when the tags are equal.
    ///
    /// # Panics
    ///
    /// Panics if a shadow has no slot in the map. That cannot happen for the
    /// shadows of a key that holds a value.
    fn update_descendant_counts(
        map: &mut HashMap<A, MapValue<V, V::Tag>>,
        shadows: A::IterShadows,
        old_tag: V::Tag,
        new_tag: V::Tag,
    ) {
        if old_tag != new_tag {
            for shadow in shadows {
                let counts = &mut map
                    .get_mut(&shadow)
                    .expect("shadow of an occupied key must be present in the map")
                    .descendant_counts;
                // Increment first so the slot is never transiently empty.
                counts.increment(new_tag);
                counts.decrement(old_tag);
            }
        }
    }

    /// Decrements the count for `old_tag` at every key in `shadows`. Any slot
    /// left with neither a value nor descendant counts is removed.
    ///
    /// # Panics
    ///
    /// Panics if a shadow has no slot in the map, or if its counts do not
    /// contain `old_tag`.
    fn decrement_descendant_counts(
        map: &mut HashMap<A, MapValue<V, V::Tag>>,
        shadows: A::IterShadows,
        old_tag: V::Tag,
    ) {
        for shadow in shadows {
            let mut entry = match map.entry(shadow) {
                hash_map::Entry::Occupied(o) => o,
                hash_map::Entry::Vacant(_) => {
                    unreachable!("shadow of an occupied key must be present in the map")
                }
            };
            let MapValue { descendant_counts, value } = entry.get_mut();
            descendant_counts.decrement(old_tag);
            if descendant_counts.is_empty() && value.is_none() {
                let _: MapValue<_, _> = entry.remove();
            }
        }
    }
}
impl<'a, K: Eq + Hash + IterShadows, V: Tagged<K>> OccupiedEntry<'a, K, V> {
    /// Returns the key of this occupied slot.
    pub fn key(&self) -> &K {
        &self.1
    }

    /// Returns the value stored at this entry's key.
    ///
    /// # Panics
    ///
    /// Panics if the slot is missing or holds no value. Either case breaks the
    /// `OccupiedEntry` invariant.
    pub fn get(&self) -> &V {
        let Self(socket_map, key) = self;
        let slot = socket_map.map.get(key).unwrap();
        slot.value.as_ref().unwrap()
    }

    /// Runs `apply` on the stored value and returns its result.
    ///
    /// The value's tag is computed before and after `apply`. If it changed, the
    /// descendant counts of every shadow of the key are moved from the old tag
    /// to the new one.
    pub fn map_mut<R>(&mut self, apply: impl FnOnce(&mut V) -> R) -> R {
        let Self(socket_map, key) = self;
        let key = &*key;
        let map = &mut socket_map.map;
        let slot = map.get_mut(key).unwrap().value.as_mut().unwrap();
        let tag_before = slot.tag(key);
        let result = apply(slot);
        let tag_after = slot.tag(key);
        SocketMap::update_descendant_counts(map, key.iter_shadows(), tag_before, tag_after);
        result
    }

    /// Consumes the entry and returns the underlying map.
    pub fn into_map(self) -> &'a mut SocketMap<K, V> {
        self.0
    }

    /// Removes the value from the map and returns it.
    pub fn remove(self) -> V {
        self.remove_from_map().0
    }

    /// Returns a shared view of the underlying map.
    pub fn get_map(&self) -> &SocketMap<K, V> {
        &*self.0
    }

    /// Removes the value from the map, returning it together with the map.
    ///
    /// The slot itself is dropped only when no descendant counts remain on it.
    /// The shadows' counts for the removed value's tag are decremented, and the
    /// map's length shrinks by one.
    pub fn remove_from_map(self) -> (V, &'a mut SocketMap<K, V>) {
        let Self(socket_map, key) = self;
        let shadows = key.iter_shadows();
        let SocketMap { map, len } = socket_map;
        let mut slot = match map.entry(key) {
            hash_map::Entry::Occupied(occupied) => occupied,
            hash_map::Entry::Vacant(_) => unreachable!("OccupiedEntry not occupied"),
        };
        // Read the tag while the value is still in place.
        let tag = slot.get().value.as_ref().unwrap().tag(slot.key());
        let MapValue { descendant_counts, value } = slot.get_mut();
        let removed =
            value.take().expect("OccupiedEntry invariant violated: expected Some, found None");
        // Only keep the now value-less slot if it still tracks descendants.
        if descendant_counts.is_empty() {
            let _: MapValue<V, V::Tag> = slot.remove();
        }
        SocketMap::decrement_descendant_counts(map, shadows, tag);
        *len -= 1;
        (removed, socket_map)
    }
}
impl<'a, K: Eq + Hash + IterShadows, V: Tagged<K>> VacantEntry<'a, K, V> {
    /// Stores `value` at this entry's key and returns an [`OccupiedEntry`].
    ///
    /// The map's length grows by one. Every shadow of the key gets its count
    /// for `value`'s tag incremented, and value-less slots are created for
    /// shadows as needed. An existing value-less slot for the key (one holding
    /// only descendant counts) is reused.
    ///
    /// # Panics
    ///
    /// Panics if the key already holds a value, which breaks the
    /// `VacantEntry` invariant.
    pub fn insert(self, value: V) -> OccupiedEntry<'a, K, V>
    where
        K: Clone,
    {
        let Self(socket_map, key) = self;
        let SocketMap { map, len } = socket_map;
        let iter_shadows = key.iter_shadows();
        let tag = value.tag(&key);
        *len += 1;
        SocketMap::increment_descendant_counts(map, iter_shadows, tag);
        let MapValue { value: map_value, descendant_counts: _ } =
            map.entry(key.clone()).or_default();
        assert!(
            map_value.replace(value).is_none(),
            "VacantEntry invariant violated: expected None, found Some"
        );
        OccupiedEntry(socket_map, key)
    }

    /// Consumes the entry and returns the underlying map.
    pub fn into_map(self) -> &'a mut SocketMap<K, V> {
        let Self(socketmap, _) = self;
        socketmap
    }

    /// Returns the per-tag counts of values stored at keys that list this
    /// entry's key as a shadow.
    pub fn descendant_counts(&self) -> impl ExactSizeIterator<Item = &'_ (V::Tag, NonZeroUsize)> {
        let Self(socket_map, key) = self;
        // `key` is already `&K`; pass it directly rather than as `&&K`.
        socket_map.descendant_counts(key)
    }
}
impl<'a, A: Debug + Eq + Hash, V: Tagged<A>> Debug for OccupiedEntry<'a, A, V> {
    /// Formats as `OccupiedEntry("_", key)`.
    ///
    /// The backing map is printed as the placeholder `"_"`, so this impl does
    /// not need `V: Debug`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let key = &self.1;
        let mut tuple = f.debug_tuple("OccupiedEntry");
        tuple.field(&"_");
        tuple.field(key);
        tuple.finish()
    }
}
impl<T: Eq, const INLINE_SIZE: usize> DescendantCounts<T, INLINE_SIZE> {
    /// The count recorded the first time a tag is seen.
    const ONE: NonZeroUsize = const_unwrap::const_unwrap_option(NonZeroUsize::new(1));

    /// Records one more descendant with `tag`, adding a `(tag, 1)` pair if the
    /// tag is not tracked yet.
    ///
    /// # Panics
    ///
    /// Panics with an explicit message if the count for `tag` would overflow
    /// `usize`. This behaves the same in debug and release builds.
    fn increment(&mut self, tag: T) {
        let Self { counts } = self;
        match counts.iter_mut().find_map(|(t, count)| (t == &tag).then(|| count)) {
            Some(count) => {
                *count = count
                    .get()
                    .checked_add(1)
                    .and_then(NonZeroUsize::new)
                    .expect("descendant count overflowed usize")
            }
            None => counts.push((tag, Self::ONE)),
        }
    }

    /// Records one fewer descendant with `tag`. The pair is removed entirely
    /// once its count reaches zero; the order of the remaining pairs may change.
    ///
    /// # Panics
    ///
    /// Panics if `tag` is not currently tracked.
    fn decrement(&mut self, tag: T) {
        let Self { counts } = self;
        let index = counts
            .iter()
            .position(|(t, _)| t == &tag)
            .expect("decrement of a tag with no descendant count");
        let (_, count) = &mut counts[index];
        match NonZeroUsize::new(count.get() - 1) {
            Some(new_count) => *count = new_count,
            None => {
                let _: (T, NonZeroUsize) = counts.swap_remove(index);
            }
        }
    }

    /// Returns true if no tag has a non-zero count.
    fn is_empty(&self) -> bool {
        let Self { counts } = self;
        counts.is_empty()
    }
}
/// Iterates over the `(tag, count)` pairs by reference, in storage order.
impl<'d, T, const INLINE_SIZE: usize> IntoIterator for &'d DescendantCounts<T, INLINE_SIZE> {
    type Item = &'d (T, NonZeroUsize);
    type IntoIter =
        <&'d smallvec::SmallVec<[(T, NonZeroUsize); INLINE_SIZE]> as IntoIterator>::IntoIter;

    fn into_iter(self) -> Self::IntoIter {
        (&self.counts).into_iter()
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use alloc::{format, vec};

    use assert_matches::assert_matches;
    use proptest::prop_assert_eq;
    use proptest::strategy::Strategy;

    use super::*;

    /// Collects an iterator of `&(key, count)` pairs into a `HashMap`, with
    /// counts converted to `usize`, for order-insensitive comparisons.
    trait AsMap {
        type K: Hash + Eq;
        type V;
        fn as_map(self) -> HashMap<Self::K, Self::V>;
    }

    impl<'d, K, V, I> AsMap for I
    where
        K: Hash + Eq + Clone + 'd,
        V: 'd,
        V: Clone + Into<usize>,
        I: Iterator<Item = &'d (K, V)>,
    {
        type K = K;
        type V = usize;
        fn as_map(self) -> HashMap<Self::K, Self::V> {
            self.map(|(k, v)| (k.clone(), v.clone().into())).collect()
        }
    }

    /// A three-level hierarchical test key: `ABC` is shadowed by `AB` and `A`,
    /// and `AB` is shadowed by `A`.
    #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
    enum Address {
        A(u8),
        AB(u8, char),
        ABC(u8, char, u8),
    }

    use Address::*;

    impl IterShadows for Address {
        type IterShadows = <Vec<Address> as IntoIterator>::IntoIter;
        fn iter_shadows(&self) -> Self::IterShadows {
            match self {
                A(_) => vec![],
                AB(a, _) => vec![A(*a)],
                ABC(a, b, _) => vec![AB(*a, *b), A(*a)],
            }
            .into_iter()
        }
    }

    /// A test value whose tag is its first field.
    #[derive(Eq, PartialEq, Clone, Copy, Debug)]
    struct TV<T, V>(T, V);

    impl<T: Copy + Eq + core::fmt::Debug, V> Tagged<Address> for TV<T, V> {
        type Tag = T;
        fn tag(&self, _: &Address) -> Self::Tag {
            self.0
        }
    }

    type TestSocketMap<T> = SocketMap<Address, TV<T, u8>>;

    #[test]
    fn insert_get_remove() {
        let mut map = TestSocketMap::default();
        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(0, 32)));
        assert_eq!(map.get(&ABC(1, 'c', 2)), Some(&TV(0, 32)));
        assert_eq!(map.remove(&ABC(1, 'c', 2)), Some(TV(0, 32)));
        assert_eq!(map.get(&ABC(1, 'c', 2)), None);
    }

    #[test]
    fn insert_remove_len() {
        let mut map = TestSocketMap::default();
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 0);
        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(0, 32)));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 1);
        assert_eq!(map.remove(&ABC(1, 'c', 2)), Some(TV(0, 32)));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 0);
    }

    #[test]
    fn entry_same_key() {
        let mut map = TestSocketMap::default();
        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(0, 32)));
        let occupied = assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Occupied(o) => o);
        assert_eq!(occupied.get(), &TV(0, 32));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 1);
    }

    #[test]
    fn multiple_insert_descendant_counts() {
        let mut map = TestSocketMap::default();
        assert_matches!(map.entry(ABC(1, 'c', 2)), Entry::Vacant(v) => v.insert(TV(1, 111)));
        assert_matches!(map.entry(ABC(1, 'd', 2)), Entry::Vacant(v) => v.insert(TV(2, 111)));
        assert_matches!(map.entry(AB(5, 'd')), Entry::Vacant(v) => v.insert(TV(1, 54)));
        assert_matches!(map.entry(AB(1, 'd')), Entry::Vacant(v) => v.insert(TV(3, 56)));
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, 4);
        assert_eq!(map.descendant_counts(&A(1)).as_map(), HashMap::from([(1, 1), (2, 1), (3, 1)]));
        assert_eq!(map.descendant_counts(&AB(1, 'c')).as_map(), HashMap::from([(1, 1)]));
        assert_eq!(map.descendant_counts(&AB(1, 'd')).as_map(), HashMap::from([(2, 1)]));
        assert_eq!(map.descendant_counts(&A(5)).as_map(), HashMap::from([(1, 1)]));
        assert_eq!(map.descendant_counts(&ABC(1, 'd', 2)).as_map(), HashMap::from([]));
        assert_eq!(map.descendant_counts(&A(2)).as_map(), HashMap::from([]));
    }

    #[test]
    fn entry_remove_no_shadows() {
        let mut map = TestSocketMap::default();
        assert_matches!(map.entry(ABC(16, 'c', 8)), Entry::Vacant(v) => v.insert(TV(3, 111)));
        let entry = assert_matches!(map.entry(ABC(16, 'c', 8)), Entry::Occupied(o) => o);
        assert_eq!(entry.remove(), TV(3, 111));
        let TestSocketMap { map, len } = map;
        assert_eq!(len, 0);
        assert_eq!(map.len(), 0);
    }

    #[test]
    fn entry_remove_with_shadows() {
        let mut map = TestSocketMap::default();
        assert_matches!(map.entry(ABC(16, 'c', 8)), Entry::Vacant(v) => v.insert(TV(2, 112)));
        assert_matches!(map.entry(AB(16, 'c')), Entry::Vacant(v) => v.insert(TV(1, 111)));
        assert_matches!(map.entry(A(16)), Entry::Vacant(v) => v.insert(TV(0, 110)));
        let entry = assert_matches!(map.entry(AB(16, 'c')), Entry::Occupied(o) => o);
        assert_eq!(entry.remove(), TV(1, 111));
        let TestSocketMap { map, len } = map;
        assert_eq!(len, 2);
        // `AB(16, 'c')` keeps a value-less slot because `ABC(16, 'c', 8)` still
        // counts toward it.
        assert_eq!(map.len(), 3);
    }

    #[test]
    fn remove_ancestor_value() {
        let mut map = TestSocketMap::default();
        assert_matches!(map.entry(ABC(2, 'e', 1)), Entry::Vacant(v) => v.insert(TV(20, 100)));
        assert_matches!(map.entry(AB(2, 'e')), Entry::Vacant(v) => v.insert(TV(20, 100)));
        assert_eq!(map.remove(&AB(2, 'e')), Some(TV(20, 100)));
        assert_eq!(map.descendant_counts(&A(2)).as_map(), HashMap::from([(20, 1)]));
    }

    /// Generates keys at any of the three `Address` levels from small ranges,
    /// so that collisions and shared shadows are likely.
    fn key_strategy() -> impl Strategy<Value = Address> {
        let a_strategy = 1..5u8;
        let b_strategy = proptest::char::range('a', 'e');
        let c_strategy = 1..5u8;
        (a_strategy, proptest::option::of((b_strategy, proptest::option::of(c_strategy)))).prop_map(
            |(a, b)| match b {
                None => A(a),
                Some((b, None)) => AB(a, b),
                Some((b, Some(c))) => ABC(a, b, c),
            },
        )
    }

    fn value_strategy() -> impl Strategy<Value = TV<u8, u8>> {
        (20..25u8, 100..105u8).prop_map(|(t, v)| TV(t, v))
    }

    #[derive(Debug, Copy, Clone, Eq, PartialEq)]
    enum Operation {
        Entry(Address, TV<u8, u8>),
        Remove(Address),
    }

    impl Operation {
        /// Applies the operation to both the `SocketMap` under test and a plain
        /// `HashMap` reference model, asserting that they agree.
        fn apply(
            self,
            socket_map: &mut TestSocketMap<u8>,
            reference: &mut HashMap<Address, TV<u8, u8>>,
        ) {
            match self {
                Operation::Entry(a, v) => match (socket_map.entry(a), reference.entry(a)) {
                    (Entry::Occupied(mut s), hash_map::Entry::Occupied(mut h)) => {
                        assert_eq!(s.map_mut(|value| core::mem::replace(value, v)), h.insert(v))
                    }
                    (Entry::Vacant(s), hash_map::Entry::Vacant(h)) => {
                        let _: OccupiedEntry<'_, _, _> = s.insert(v);
                        let _: &mut TV<_, _> = h.insert(v);
                    }
                    (Entry::Occupied(_), hash_map::Entry::Vacant(_)) => {
                        panic!("socketmap has a value for {:?} but reference does not", a)
                    }
                    (Entry::Vacant(_), hash_map::Entry::Occupied(_)) => {
                        panic!("socketmap has no value for {:?} but reference does", a)
                    }
                },
                Operation::Remove(a) => assert_eq!(socket_map.remove(&a), reference.remove(&a)),
            }
        }
    }

    fn operation_strategy() -> impl Strategy<Value = Operation> {
        proptest::prop_oneof!(
            (key_strategy(), value_strategy()).prop_map(|(a, v)| Operation::Entry(a, v)),
            key_strategy().prop_map(Operation::Remove),
        )
    }

    /// Checks that `map` holds exactly the values in `reference`, that its
    /// length matches, and that every slot's descendant counts equal a
    /// recomputation from the stored values.
    fn validate_map(
        map: TestSocketMap<u8>,
        reference: HashMap<Address, TV<u8, u8>>,
    ) -> Result<(), proptest::test_runner::TestCaseError> {
        let map_values: HashMap<_, _> = map.iter().map(|(a, v)| (*a, *v)).collect();
        assert_eq!(map_values, reference);
        let TestSocketMap { len, map: _ } = map;
        assert_eq!(len, reference.len());

        let TestSocketMap { map: inner_map, len: _ } = &map;
        for (key, entry) in inner_map {
            let descendant_values = map
                .iter()
                .filter(|(k, _)| k.iter_shadows().any(|s| s == *key))
                .map(|(_, value)| value);
            let expected_tag_counts = descendant_values.fold(HashMap::new(), |mut m, v| {
                *m.entry(v.tag(key)).or_default() += 1;
                m
            });
            let MapValue { descendant_counts, value: _ } = entry;
            prop_assert_eq!(
                expected_tag_counts,
                descendant_counts.into_iter().as_map(),
                "key = {:?}",
                key
            );
        }
        Ok(())
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config {
            failure_persistence: proptest_support::failed_seeds_no_std!(),
            ..proptest::test_runner::Config::default()
        })]

        #[test]
        fn test_arbitrary_operations(operations in proptest::collection::vec(operation_strategy(), 10)) {
            let mut map = TestSocketMap::default();
            let mut reference = HashMap::new();
            for op in operations {
                op.apply(&mut map, &mut reference);
            }
            validate_map(map, reference)?;
        }
    }
}