netdevice_client/
port_slab.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Data structure helper to keep data associated with netdevice ports.
6
7use crate::Port;
8
/// Storage for one port's value, tagged with the salt it was inserted under.
#[derive(Debug)]
struct Slot<T> {
    /// Salt of the port this value was stored for.
    salt: u8,
    /// The stored value.
    value: T,
}
14
15/// A data structure that is keyed on [`Port`], guarantees O(1) lookup
16/// and takes into account salted port identifiers.
17#[derive(Debug, derivative::Derivative)]
18#[derivative(Default(bound = ""))]
19pub struct PortSlab<T> {
20    slots: Vec<Option<Slot<T>>>,
21}
22
/// The possible results of [`PortSlab::remove`].
#[derive(Debug, PartialEq, Eq)]
pub enum RemoveOutcome<T> {
    /// Nothing is stored for the requested port.
    PortNotPresent,
    /// A value is stored under the same base identifier, but with a different
    /// salt; nothing was removed. Carries the salt that is currently stored.
    SaltMismatch(u8),
    /// The entry was removed; carries the value that was stored.
    Removed(T),
}
34
35impl<T> PortSlab<T> {
36    /// Creates a new empty `PortSlab`.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Inserts `value` indexed by `port`.
42    ///
43    /// Returns `Some(_)` if the value is inserted and there was already
44    /// some value stored in the slab.
45    pub fn insert(&mut self, port: Port, value: T) -> Option<T> {
46        let Port { base, salt } = port;
47        let slot = self.get_slot_mut(base.into());
48        slot.replace(Slot { salt, value }).map(|Slot { salt: _, value }| value)
49    }
50
51    /// Removes the entry indexed by `port`, if one exists.
52    ///
53    /// Note that `remove` will *not* remove an entry if the currently stored
54    /// port's salt doesn't match `port`.
55    pub fn remove(&mut self, port: &Port) -> RemoveOutcome<T> {
56        match self.entry(*port) {
57            Entry::SaltMismatch(SaltMismatchEntry(slot)) => {
58                RemoveOutcome::SaltMismatch(slot.as_ref().unwrap().salt)
59            }
60            Entry::Vacant(VacantEntry(_, _)) => RemoveOutcome::PortNotPresent,
61            Entry::Occupied(e) => RemoveOutcome::Removed(e.remove()),
62        }
63    }
64
65    /// Gets a reference to the value indexed by `port`.
66    ///
67    /// `get` only returns `Some` if the slab contains an entry for `port` with
68    /// a matching salt.
69    pub fn get(&self, port: &Port) -> Option<&T> {
70        let Self { slots } = self;
71        let Port { base, salt } = port;
72        slots.get(usize::from(*base)).and_then(|s| s.as_ref()).and_then(
73            |Slot { salt: existing_salt, value }| (existing_salt == salt).then_some(value),
74        )
75    }
76
77    /// Gets a mutable reference to the value indexed by `port`.
78    ///
79    /// `get_mut` only returns `Some` if the slab contains an entry for `port`
80    /// with a matching salt.
81    pub fn get_mut(&mut self, port: &Port) -> Option<&mut T> {
82        let Self { slots } = self;
83        let Port { base, salt } = port;
84        slots.get_mut(usize::from(*base)).and_then(|s| s.as_mut()).and_then(
85            |Slot { salt: existing_salt, value }| (existing_salt == salt).then_some(value),
86        )
87    }
88
89    /// Retrieves an [`entry`] indexed by `port`.
90    pub fn entry(&mut self, port: Port) -> Entry<'_, T> {
91        let Port { base, salt } = port;
92        let base = usize::from(base);
93
94        // NB: Lifetimes in this function disallow us from doing the "pretty"
95        // thing here of matching just once on the result of `get_mut`. We need
96        // to erase the lifetime information with the boolean check to appease
97        // the borrow checker. Otherwise we get errors of the form
98        // "error[E0499]: cannot borrow `*self` as mutable more than once at a
99        // time".
100        if self.slots.get_mut(base).is_none() {
101            return Entry::Vacant(VacantEntry(
102                VacantState::NeedSlot(self, usize::from(base)),
103                salt,
104            ));
105        }
106        let slot = self.slots.get_mut(base).unwrap();
107        match slot {
108            Some(Slot { salt: existing_salt, value: _ }) => {
109                if *existing_salt == salt {
110                    Entry::Occupied(OccupiedEntry(slot))
111                } else {
112                    Entry::SaltMismatch(SaltMismatchEntry(slot))
113                }
114            }
115            None => Entry::Vacant(VacantEntry(VacantState::EmptySlot(slot), salt)),
116        }
117    }
118
119    fn get_slot_mut(&mut self, index: usize) -> &mut Option<Slot<T>> {
120        let Self { slots } = self;
121        // The slab only ever grows.
122        if slots.len() <= index {
123            slots.resize_with(index + 1, || None);
124        }
125
126        &mut slots[index]
127    }
128}
129
130/// An entry obtained from [`PortSlab::entry`].
131#[derive(Debug)]
132pub enum Entry<'a, T> {
133    /// Slot is vacant.
134    Vacant(VacantEntry<'a, T>),
135    /// Slot is occupied with a matching salt.
136    Occupied(OccupiedEntry<'a, T>),
137    /// Slot is occupied with a mismatched salt.
138    SaltMismatch(SaltMismatchEntry<'a, T>),
139}
140
141#[derive(Debug)]
142enum VacantState<'a, T> {
143    NeedSlot(&'a mut PortSlab<T>, usize),
144    EmptySlot(&'a mut Option<Slot<T>>),
145}
146
147/// A vacant slot in a [`PortSlab`].
148#[derive(Debug)]
149pub struct VacantEntry<'a, T>(VacantState<'a, T>, u8);
150
151impl<'a, T> VacantEntry<'a, T> {
152    /// Inserts `value` in this entry slot.
153    pub fn insert(self, value: T) {
154        let VacantEntry(state, salt) = self;
155        let slot = match state {
156            VacantState::NeedSlot(slab, base) => slab.get_slot_mut(base),
157            VacantState::EmptySlot(slot) => slot,
158        };
159        assert!(slot.replace(Slot { salt, value }).is_none(), "violated VacantEntry invariant");
160    }
161}
162
163/// An occupied entry in a [`PortSlab`].
164#[derive(Debug)]
165pub struct OccupiedEntry<'a, T>(&'a mut Option<Slot<T>>);
166
167impl<'a, T> OccupiedEntry<'a, T> {
168    /// Gets a reference to the stored value.
169    pub fn get(&self) -> &T {
170        let OccupiedEntry(slot) = self;
171        // OccupiedEntry is a witness to the slot being filled.
172        &slot.as_ref().unwrap().value
173    }
174
175    /// Gets a mutable reference to the stored value.
176    pub fn get_mut(&mut self) -> &mut T {
177        let OccupiedEntry(slot) = self;
178        // OccupiedEntry is a witness to the slot being filled.
179        &mut slot.as_mut().unwrap().value
180    }
181
182    /// Removes the value from the slab.
183    pub fn remove(self) -> T {
184        let OccupiedEntry(slot) = self;
185        // OccupiedEntry is a witness to the slot being filled.
186        slot.take().unwrap().value
187    }
188}
189
190/// A mismatched salt entry in a [`PortSlab`].
191#[derive(Debug)]
192pub struct SaltMismatchEntry<'a, T>(&'a mut Option<Slot<T>>);
193
194impl<'a, T> SaltMismatchEntry<'a, T> {
195    /// Removes the mismatched entry from the slab.
196    pub fn remove(self) -> T {
197        let SaltMismatchEntry(slot) = self;
198        // SaltMismatch is a witness to the slot being filled.
199        slot.take().unwrap().value
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    const PORT_A: Port = Port { base: 0, salt: 1 };
    const PORT_A_GEN_2: Port = Port { salt: 2, ..PORT_A };
    const PORT_B: Port = Port { base: 1, salt: 1 };

    #[test]
    fn insert_new_entry() {
        let mut slab = PortSlab::new();
        assert_eq!(slab.insert(PORT_A, 0), None);
        assert_eq!(slab.get(&PORT_A), Some(&0));
    }

    #[test]
    fn insert_replaces() {
        let mut slab = PortSlab::new();
        assert_eq!(slab.insert(PORT_A, 0), None);
        assert_eq!(slab.insert(PORT_A, 1), Some(0));
        assert_eq!(slab.get(&PORT_A), Some(&1));
    }

    #[test]
    fn insert_replaces_even_on_salt_mismatch() {
        let mut slab = PortSlab::new();
        assert_eq!(slab.insert(PORT_A, 0), None);
        assert_eq!(slab.insert(PORT_A_GEN_2, 1), Some(0));
        assert_eq!(slab.get(&PORT_A), None);
        assert_eq!(slab.get(&PORT_A_GEN_2), Some(&1));
    }

    #[test]
    fn remove_nonexisting() {
        let mut slab = PortSlab::<u32>::new();
        assert_eq!(slab.remove(&PORT_A), RemoveOutcome::PortNotPresent);
    }

    #[test]
    fn remove_out_of_range_does_not_grow() {
        let mut slab = PortSlab::new();
        assert_eq!(slab.insert(PORT_A, 0), None);
        // PORT_B's base lies beyond the current storage.
        assert_eq!(slab.remove(&PORT_B), RemoveOutcome::PortNotPresent);
        assert_eq!(slab.slots.len(), usize::from(PORT_A.base + 1), "{:?}", slab.slots);
        assert_eq!(slab.get(&PORT_B), None);
        assert_eq!(slab.get(&PORT_A), Some(&0));
    }

    #[test]
    fn remove_matching_salt() {
        let mut slab = PortSlab::new();
        assert_eq!(slab.insert(PORT_A, 0), None);
        assert_eq!(slab.remove(&PORT_A), RemoveOutcome::Removed(0));
        assert_eq!(slab.get(&PORT_A), None);
    }

    #[test]
    fn remove_salt_mismatch() {
        let mut slab = PortSlab::new();
        assert_eq!(slab.insert(PORT_A, 0), None);
        assert_eq!(slab.remove(&PORT_A_GEN_2), RemoveOutcome::SaltMismatch(PORT_A.salt));
        assert_eq!(slab.get(&PORT_A), Some(&0));
    }

    #[test]
    fn get_mut() {
        let mut slab = PortSlab::new();
        assert_eq!(slab.get_mut(&PORT_A), None);
        assert_eq!(slab.insert(PORT_A, 0), None);
        assert_eq!(slab.insert(PORT_B, 1), None);
        let a = slab.get_mut(&PORT_A).unwrap();
        assert_eq!(*a, 0);
        *a = 3;
        assert_eq!(slab.get_mut(&PORT_A_GEN_2), None);
        assert_eq!(slab.get_mut(&PORT_A), Some(&mut 3));
    }

    #[test]
    fn entry_vacant_no_slot() {
        let mut slab = PortSlab::new();
        let vacant = assert_matches!(slab.entry(PORT_A), Entry::Vacant(v) => v);
        vacant.insert(1);
        assert_eq!(slab.get(&PORT_A), Some(&1));
    }

    #[test]
    fn entry_vacant_does_not_grow_until_insert() {
        let mut slab = PortSlab::<u32>::new();
        assert_matches!(slab.entry(PORT_B), Entry::Vacant(_));
        // Merely looking up an entry must not allocate slots.
        assert_eq!(slab.slots.len(), 0, "{:?}", slab.slots);
    }

    #[test]
    fn entry_vacant_existing_slot() {
        let mut slab = PortSlab::new();
        assert_eq!(slab.insert(PORT_A, 0), None);
        assert_eq!(slab.remove(&PORT_A), RemoveOutcome::Removed(0));
        let vacant = assert_matches!(slab.entry(PORT_A), Entry::Vacant(v) => v);
        vacant.insert(1);
        assert_eq!(slab.get(&PORT_A), Some(&1));
    }

    #[test]
    fn entry_occupied_get() {
        let mut slab = PortSlab::new();
        assert_eq!(slab.insert(PORT_A, 2), None);
        let mut occupied = assert_matches!(slab.entry(PORT_A), Entry::Occupied(o) => o);
        assert_eq!(occupied.get(), &2);
        assert_eq!(occupied.get_mut(), &mut 2);
    }

    #[test]
    fn entry_occupied_remove() {
        let mut slab = PortSlab::new();
        assert_eq!(slab.insert(PORT_A, 2), None);
        let occupied = assert_matches!(slab.entry(PORT_A), Entry::Occupied(o) => o);
        assert_eq!(occupied.remove(), 2);
        assert_eq!(slab.get(&PORT_A), None);
    }

    #[test]
    fn entry_mismatch() {
        let mut slab = PortSlab::new();
        assert_eq!(slab.insert(PORT_A, 2), None);
        let mismatch = assert_matches!(slab.entry(PORT_A_GEN_2), Entry::SaltMismatch(m) => m);
        assert_eq!(mismatch.remove(), 2);
        assert_eq!(slab.get(&PORT_A), None);
        assert_eq!(slab.get(&PORT_A_GEN_2), None);
    }

    #[test]
    fn underlying_vec_only_grows() {
        let mut slab = PortSlab::new();
        let high_port = Port { base: 4, salt: 0 };
        let low_port = Port { base: 0, salt: 0 };
        assert_eq!(slab.slots.len(), 0, "{:?}", slab.slots);
        assert_eq!(slab.insert(high_port, 0), None);
        assert_eq!(slab.slots.len(), usize::from(high_port.base + 1), "{:?}", slab.slots);
        assert_eq!(slab.remove(&high_port), RemoveOutcome::Removed(0));
        assert_eq!(slab.slots.len(), usize::from(high_port.base + 1), "{:?}", slab.slots);
        assert_eq!(slab.insert(low_port, 1), None);
        // Inserting at a lower index must neither shrink nor regrow storage.
        assert_eq!(slab.slots.len(), usize::from(high_port.base + 1), "{:?}", slab.slots);
        assert_eq!(slab.get(&low_port), Some(&1));
    }
}
329}