1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use alloc::collections::hash_map::{Entry, HashMap};
use core::hash::Hash;
use core::num::NonZeroUsize;

/// Outcome of adding a reference to a key in a [`RefCountedHashMap`].
///
/// `O` is the auxiliary output produced by the value-building closure; it only
/// exists when a brand-new entry was created.
#[derive(PartialEq, Eq, Debug)]
pub enum InsertResult<O> {
    /// No entry existed for the key, so one was created with a reference count
    /// of 1; carries the closure's auxiliary output.
    Inserted(O),
    /// An entry already existed for the key; only its reference count was
    /// incremented.
    AlreadyPresent,
}

/// Outcome of dropping a reference to a key in a [`RefCountedHashMap`].
#[derive(PartialEq, Eq, Debug)]
pub enum RemoveResult<V> {
    /// That was the last reference: the entry is gone and its value is handed
    /// back to the caller.
    Removed(V),
    /// Other references remain, so the entry was kept with a lower count.
    StillPresent,
    /// Nothing was stored under the key.
    NotPresent,
}

/// A [`HashMap`] whose entries each carry a reference count.
///
/// An entry is created on the first insertion of its key and is only removed
/// once every insertion has been balanced by a removal.
#[derive(Debug)]
pub struct RefCountedHashMap<K, V> {
    // Each value sits next to its count. A zero count is unrepresentable, so an
    // entry's presence implies at least one outstanding reference.
    inner: HashMap<K, (NonZeroUsize, V)>,
}

impl<K, V> Default for RefCountedHashMap<K, V> {
    // Written by hand: `#[derive(Default)]` would add `K: Default` and
    // `V: Default` bounds.
    fn default() -> Self {
        Self { inner: Default::default() }
    }
}

impl<K: Eq + Hash, V> RefCountedHashMap<K, V> {
    /// Adds one reference to `key`.
    ///
    /// When `key` is absent, `f` is called to build its value plus an auxiliary
    /// output (returned in [`InsertResult::Inserted`]), and the new entry starts
    /// with a count of 1. When `key` is present, `f` is not called and the
    /// existing count is incremented.
    ///
    /// # Panics
    ///
    /// Panics if the reference count would overflow `usize`.
    pub fn insert_with<O, F: FnOnce() -> (V, O)>(&mut self, key: K, f: F) -> InsertResult<O> {
        // Evaluated at compile time; `NonZeroUsize::new(1)` is never `None`.
        const ONE: NonZeroUsize = match NonZeroUsize::new(1) {
            Some(one) => one,
            None => unreachable!(),
        };

        let vacant = match self.inner.entry(key) {
            Entry::Vacant(vacant) => vacant,
            Entry::Occupied(mut occupied) => {
                let count: &mut NonZeroUsize = &mut occupied.get_mut().0;
                *count = count.checked_add(1).unwrap();
                return InsertResult::AlreadyPresent;
            }
        };
        let (value, output) = f();
        let _: &mut (NonZeroUsize, V) = vacant.insert((ONE, value));
        InsertResult::Inserted(output)
    }

    /// Drops one reference to `key`.
    ///
    /// If that was the last reference, the entry is taken out of the map and its
    /// value is returned in [`RemoveResult::Removed`]. Otherwise the count is
    /// lowered and [`RemoveResult::StillPresent`] is returned. An absent key
    /// yields [`RemoveResult::NotPresent`].
    pub fn remove(&mut self, key: K) -> RemoveResult<V> {
        let mut occupied = match self.inner.entry(key) {
            Entry::Occupied(occupied) => occupied,
            Entry::Vacant(_) => return RemoveResult::NotPresent,
        };
        let count: &mut NonZeroUsize = &mut occupied.get_mut().0;
        // `count` is at least 1, so the subtraction cannot underflow. A result
        // of 0 maps to `None`, meaning the last reference is gone.
        if let Some(lowered) = NonZeroUsize::new(count.get() - 1) {
            *count = lowered;
            return RemoveResult::StillPresent;
        }
        let (_, value): (NonZeroUsize, V) = occupied.remove();
        RemoveResult::Removed(value)
    }

    /// Reports whether `key` currently has an entry in the map.
    pub fn contains_key(&self, key: &K) -> bool {
        self.get(key).is_some()
    }

    /// Borrows the value stored under `key`, if any.
    pub fn get(&self, key: &K) -> Option<&V> {
        self.inner.get(key).map(|entry| &entry.1)
    }

    /// Mutably borrows the value stored under `key`, if any. The reference
    /// count is left untouched.
    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
        self.inner.get_mut(key).map(|entry| &mut entry.1)
    }

    /// Visits every key with a mutable borrow of its value, in arbitrary order.
    /// Reference counts are not exposed.
    pub fn iter_mut<'a>(&'a mut self) -> impl 'a + Iterator<Item = (&'a K, &'a mut V)> {
        self.inner.iter_mut().map(|(key, entry)| (key, &mut entry.1))
    }
}

/// A [`RefCountedHashMap`] where the value is `()`.
///
/// Each distinct value is stored once, together with the number of insertions
/// of it that have not yet been balanced by a removal.
#[derive(Debug)]
pub struct RefCountedHashSet<T> {
    inner: RefCountedHashMap<T, ()>,
}

impl<T> Default for RefCountedHashSet<T> {
    fn default() -> RefCountedHashSet<T> {
        RefCountedHashSet { inner: RefCountedHashMap::default() }
    }
}

impl<T: Eq + Hash> RefCountedHashSet<T> {
    /// Increments the reference count of the given value.
    ///
    /// Returns [`InsertResult::Inserted`] if `value` was not in the set (its
    /// count is now 1), or [`InsertResult::AlreadyPresent`] if an existing
    /// count was incremented.
    ///
    /// # Panics
    ///
    /// Panics if the reference count would overflow `usize`.
    pub fn insert(&mut self, value: T) -> InsertResult<()> {
        self.inner.insert_with(value, || ((), ()))
    }

    /// Decrements the reference count of the given value.
    ///
    /// If the reference count reaches 0, the value is removed from the set and
    /// [`RemoveResult::Removed`] is returned. If other references remain,
    /// [`RemoveResult::StillPresent`] is returned. If `value` was not in the
    /// set, [`RemoveResult::NotPresent`] is returned.
    pub fn remove(&mut self, value: T) -> RemoveResult<()> {
        self.inner.remove(value)
    }

    /// Returns `true` if the set contains the given value.
    pub fn contains(&self, value: &T) -> bool {
        self.inner.contains_key(value)
    }

    /// Returns the number of distinct values in the set. This is not the sum of
    /// their reference counts.
    pub fn len(&self) -> usize {
        self.inner.inner.len()
    }

    /// Returns `true` if the set contains no values.
    ///
    /// Companion to [`Self::len`] (see Clippy's `len_without_is_empty`).
    pub fn is_empty(&self) -> bool {
        self.inner.inner.is_empty()
    }

    /// Iterates over values and their reference counts, in arbitrary order.
    pub fn iter_counts(&self) -> impl Iterator<Item = (&'_ T, NonZeroUsize)> + '_ {
        self.inner.inner.iter().map(|(key, (count, ()))| (key, *count))
    }
}

impl<T: Eq + Hash> core::iter::FromIterator<T> for RefCountedHashSet<T> {
    /// Builds a set by inserting each item of `iter` in turn. A value that
    /// appears `n` times ends up with a reference count of `n`.
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let mut set = Self::default();
        for t in iter {
            let _: InsertResult<()> = set.insert(t);
        }
        set
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_ref_counted_hash_map() {
        let mut map = RefCountedHashMap::<&str, ()>::default();
        let key = "key";

        // Test refcounts 1 and 2. The behavioral difference is that testing
        // only with a refcount of 1 doesn't exercise the refcount incrementing
        // functionality - it only exercises the functionality of initializing a
        // new entry with a refcount of 1.
        for refcount in 1..=2 {
            assert!(!map.contains_key(&key));

            // Insert an entry for the first time, initializing the refcount to
            // 1.
            assert_eq!(map.insert_with(key, || ((), ())), InsertResult::Inserted(()));
            assert!(map.contains_key(&key));
            assert_refcount(&map, key, 1, "after initial insert");

            // Increase the refcount to `refcount`.
            for i in 1..refcount {
                // Since the refcount starts at 1, the entry is always already
                // in the map.
                assert_eq!(map.insert_with(key, || ((), ())), InsertResult::AlreadyPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, i + 1, "after subsequent insert");
            }

            // Decrement the refcount to 1.
            for i in 1..refcount {
                // Since we don't decrement the refcount past 1, the entry is
                // always still present.
                assert_eq!(map.remove(key), RemoveResult::StillPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, refcount - i, "after decrement refcount");
            }

            assert_refcount(&map, key, 1, "before entry removed");
            // Remove the entry when the refcount is 1.
            assert_eq!(map.remove(key), RemoveResult::Removed(()));
            assert!(!map.contains_key(&key));

            // Try to remove an entry that no longer exists.
            assert_eq!(map.remove(key), RemoveResult::NotPresent);
        }
    }

    #[test]
    fn test_ref_counted_hash_set() {
        let mut set = RefCountedHashSet::<u8>::default();
        assert_eq!(set.len(), 0);
        assert!(!set.contains(&1));

        // The first insertion creates the entry; later ones only bump the count.
        assert_eq!(set.insert(1), InsertResult::Inserted(()));
        assert_eq!(set.insert(1), InsertResult::AlreadyPresent);
        assert_eq!(set.insert(2), InsertResult::Inserted(()));
        assert_eq!(set.len(), 2);
        assert!(set.contains(&1));
        assert!(set.contains(&2));
        assert_eq!(set_refcount(&set, 1), Some(2));
        assert_eq!(set_refcount(&set, 2), Some(1));

        // The value survives until every insertion has been balanced.
        assert_eq!(set.remove(1), RemoveResult::StillPresent);
        assert!(set.contains(&1));
        assert_eq!(set_refcount(&set, 1), Some(1));
        assert_eq!(set.remove(1), RemoveResult::Removed(()));
        assert!(!set.contains(&1));
        assert_eq!(set_refcount(&set, 1), None);

        // Removing a value that is no longer present is reported, not a panic.
        assert_eq!(set.remove(1), RemoveResult::NotPresent);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_ref_counted_hash_set_from_iter() {
        // Duplicates in the source iterator accumulate as reference counts.
        let set: RefCountedHashSet<u8> = [1, 2, 1, 1].iter().copied().collect();
        assert_eq!(set.len(), 2);
        assert_eq!(set_refcount(&set, 1), Some(3));
        assert_eq!(set_refcount(&set, 2), Some(1));
        assert_eq!(set_refcount(&set, 3), None);
    }

    /// Returns the reference count recorded for `value` in `set`, or `None` if
    /// `value` is absent.
    fn set_refcount(set: &RefCountedHashSet<u8>, value: u8) -> Option<usize> {
        set.iter_counts().find(|(v, _)| **v == value).map(|(_, count)| count.get())
    }

    /// Asserts that `key` is in `map` with exactly `expected_refcount`
    /// references; `context` is included in the failure message either way.
    fn assert_refcount(
        map: &RefCountedHashMap<&str, ()>,
        key: &str,
        expected_refcount: usize,
        context: &str,
    ) {
        let (actual_refcount, _value) =
            map.inner.get(key).unwrap_or_else(|| panic!("refcount should be non-zero {}", context));
        // Pass `context` through so a mismatch says which step failed.
        assert_eq!(actual_refcount.get(), expected_refcount, "unexpected refcount {}", context);
    }
}