netstack3_base/data_structures/
ref_counted_hash_map.rs

1// Copyright 2021 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use alloc::collections::hash_map::{Entry, HashMap};
6use core::hash::Hash;
7use core::num::NonZeroUsize;
8
/// Outcome of an insertion into a [`RefCountedHashMap`].
///
/// `O` is the auxiliary output that accompanies a newly created value; it is
/// only produced when the key was actually added to the map.
#[derive(Debug, PartialEq, Eq)]
pub enum InsertResult<O> {
    /// No entry existed for the key, so a new one was created. Carries the
    /// output returned next to the new value.
    Inserted(O),
    /// An entry already existed for the key; only its reference count was
    /// incremented.
    AlreadyPresent,
}
18
/// Outcome of a removal from a [`RefCountedHashMap`].
#[derive(Debug, PartialEq, Eq)]
pub enum RemoveResult<V> {
    /// The last reference was dropped, so the entry was taken out of the map.
    /// Carries the value that was stored in it.
    Removed(V),
    /// Other references remain, so the entry stays in the map with a smaller
    /// reference count.
    StillPresent,
    /// There was no entry for the key.
    NotPresent,
}
29
/// A [`HashMap`] wrapper that tracks a reference count for every entry.
///
/// Each key maps to its reference count together with its value. The count is
/// a [`NonZeroUsize`], so an entry that exists always has at least one
/// reference.
#[derive(Debug)]
pub struct RefCountedHashMap<K, V> {
    // Per-key `(reference count, value)` pairs.
    inner: HashMap<K, (NonZeroUsize, V)>,
}
35
36impl<K, V> Default for RefCountedHashMap<K, V> {
37    fn default() -> RefCountedHashMap<K, V> {
38        RefCountedHashMap { inner: HashMap::default() }
39    }
40}
41
42impl<K: Eq + Hash, V> RefCountedHashMap<K, V> {
43    /// Increments the reference count of the entry with the given key.
44    ///
45    /// If the key isn't in the map, the given function is called to create its
46    /// associated value.
47    pub fn insert_with<O, F: FnOnce() -> (V, O)>(&mut self, key: K, f: F) -> InsertResult<O> {
48        match self.inner.entry(key) {
49            Entry::Occupied(mut entry) => {
50                let (refcnt, _): &mut (NonZeroUsize, V) = entry.get_mut();
51                *refcnt = refcnt.checked_add(1).unwrap();
52                InsertResult::AlreadyPresent
53            }
54            Entry::Vacant(entry) => {
55                let (value, output) = f();
56                let _: &mut (NonZeroUsize, V) =
57                    entry.insert((NonZeroUsize::new(1).unwrap(), value));
58                InsertResult::Inserted(output)
59            }
60        }
61    }
62
63    /// Decrements the reference count of the entry with the given key.
64    ///
65    /// If the reference count reaches 0, the entry will be removed and its
66    /// value returned.
67    pub fn remove(&mut self, key: K) -> RemoveResult<V> {
68        match self.inner.entry(key) {
69            Entry::Vacant(_) => RemoveResult::NotPresent,
70            Entry::Occupied(mut entry) => {
71                let (refcnt, _): &mut (NonZeroUsize, V) = entry.get_mut();
72                match NonZeroUsize::new(refcnt.get() - 1) {
73                    None => {
74                        let (_, value): (NonZeroUsize, V) = entry.remove();
75                        RemoveResult::Removed(value)
76                    }
77                    Some(new_refcnt) => {
78                        *refcnt = new_refcnt;
79                        RemoveResult::StillPresent
80                    }
81                }
82            }
83        }
84    }
85
86    /// Returns `true` if the map contains a value for the specified key.
87    pub fn contains_key(&self, key: &K) -> bool {
88        self.inner.contains_key(key)
89    }
90
91    /// Returns a reference to the value corresponding to the key.
92    pub fn get(&self, key: &K) -> Option<&V> {
93        self.inner.get(key).map(|(_, value)| value)
94    }
95
96    /// Returns a mutable reference to the value corresponding to the key.
97    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
98        self.inner.get_mut(key).map(|(_, value)| value)
99    }
100
101    /// An iterator visiting all key-value pairs in arbitrary order, with
102    /// mutable references to the values.
103    pub fn iter_mut<'a>(&'a mut self) -> impl 'a + Iterator<Item = (&'a K, &'a mut V)> {
104        self.inner.iter_mut().map(|(key, (_, value))| (key, value))
105    }
106
107    /// An iterator visiting all key-value pairs in arbitrary order, with
108    /// non-mutable references to the values.
109    pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> + Clone {
110        self.inner.iter().map(|(key, (_, value))| (key, value))
111    }
112
113    /// An iterator visiting all keys in arbitrary order with the reference
114    /// count for each key.
115    pub fn iter_ref_counts<'a>(
116        &'a self,
117    ) -> impl 'a + Iterator<Item = (&'a K, &'a NonZeroUsize)> + Clone {
118        self.inner.iter().map(|(key, (count, _))| (key, count))
119    }
120
121    /// Returns whether the map is empty.
122    pub fn is_empty(&self) -> bool {
123        self.inner.is_empty()
124    }
125}
126
127/// A [`RefCountedHashMap`] where the value is `()`.
128#[derive(Debug)]
129pub struct RefCountedHashSet<T> {
130    inner: RefCountedHashMap<T, ()>,
131}
132
133impl<T> Default for RefCountedHashSet<T> {
134    fn default() -> RefCountedHashSet<T> {
135        RefCountedHashSet { inner: RefCountedHashMap::default() }
136    }
137}
138
139impl<T: Eq + Hash> RefCountedHashSet<T> {
140    /// Increments the reference count of the given value.
141    pub fn insert(&mut self, value: T) -> InsertResult<()> {
142        self.inner.insert_with(value, || ((), ()))
143    }
144
145    /// Decrements the reference count of the given value.
146    ///
147    /// If the reference count reaches 0, the value will be removed from the
148    /// set.
149    pub fn remove(&mut self, value: T) -> RemoveResult<()> {
150        self.inner.remove(value)
151    }
152
153    /// Returns `true` if the set contains the given value.
154    pub fn contains(&self, value: &T) -> bool {
155        self.inner.contains_key(value)
156    }
157
158    /// Returns the number of values in the set.
159    pub fn len(&self) -> usize {
160        self.inner.inner.len()
161    }
162
163    /// Iterates over values and reference counts.
164    pub fn iter_counts(&self) -> impl Iterator<Item = (&'_ T, NonZeroUsize)> + '_ {
165        self.inner.inner.iter().map(|(key, (count, ()))| (key, *count))
166    }
167}
168
169impl<T: Eq + Hash> core::iter::FromIterator<T> for RefCountedHashSet<T> {
170    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
171        iter.into_iter().fold(Self::default(), |mut set, t| {
172            let _: InsertResult<()> = set.insert(t);
173            set
174        })
175    }
176}
177
#[cfg(test)]
mod test {
    use super::*;

    /// Walks a single key through the whole insert/remove lifecycle, once
    /// with a peak reference count of 1 and once with a peak of 2.
    #[test]
    fn test_ref_counted_hash_map() {
        let mut map = RefCountedHashMap::<&str, ()>::default();
        let key = "key";

        // Test refcounts 1 and 2. The behavioral difference is that testing
        // only with a refcount of 1 doesn't exercise the refcount incrementing
        // functionality - it only exercises the functionality of initializing a
        // new entry with a refcount of 1.
        for refcount in 1..=2 {
            // Each iteration starts (and ends) with the key absent.
            assert!(!map.contains_key(&key));

            // Insert an entry for the first time, initializing the refcount to
            // 1.
            assert_eq!(map.insert_with(key, || ((), ())), InsertResult::Inserted(()));
            assert!(map.contains_key(&key));
            assert_refcount(&map, key, 1, "after initial insert");

            // Increase the refcount to `refcount`.
            for i in 1..refcount {
                // Since the refcount starts at 1, the entry is always already
                // in the map.
                assert_eq!(map.insert_with(key, || ((), ())), InsertResult::AlreadyPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, i + 1, "after subsequent insert");
            }

            // Decrement the refcount to 1.
            for i in 1..refcount {
                // Since we don't decrement the refcount past 1, the entry is
                // always still present.
                assert_eq!(map.remove(key), RemoveResult::StillPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, refcount - i, "after decrement refcount");
            }

            assert_refcount(&map, key, 1, "before entry removed");
            // Remove the entry when the refcount is 1.
            assert_eq!(map.remove(key), RemoveResult::Removed(()));
            assert!(!map.contains_key(&key));

            // Try to remove an entry that no longer exists.
            assert_eq!(map.remove(key), RemoveResult::NotPresent);
        }
    }

    /// Asserts that `key` is present in `map` with exactly
    /// `expected_refcount` references.
    ///
    /// `context` names the step of the test being checked. It is included in
    /// the failure message both when the key is missing and when the count
    /// differs, so a failure identifies which step went wrong.
    fn assert_refcount(
        map: &RefCountedHashMap<&str, ()>,
        key: &str,
        expected_refcount: usize,
        context: &str,
    ) {
        let (actual_refcount, _value) =
            map.inner.get(key).unwrap_or_else(|| panic!("refcount should be non-zero {}", context));
        assert_eq!(
            actual_refcount.get(),
            expected_refcount,
            "unexpected refcount for {:?} {}",
            key,
            context
        );
    }
}