1use alloc::collections::VecDeque;
9use alloc::sync::Arc;
10#[cfg(test)]
11use alloc::vec::Vec;
12use core::fmt::Debug;
13use core::hash::Hash;
14use core::ops::DerefMut;
15
16use assert_matches::assert_matches;
17use derivative::Derivative;
18use netstack3_base::Inspector;
19use netstack3_base::sync::Mutex;
20use netstack3_hashmap::HashMap;
21
/// Receives notifications about changes in the number of connections that are
/// ready to be accepted from an [`AcceptQueue`].
pub trait ListenerNotifier {
    /// Reports that the set of ready connections changed; `num_ready` is the
    /// number of entries now waiting in the ready queue.
    ///
    /// Within this module it is invoked after an entry becomes ready, after a
    /// ready entry is popped, and after a ready entry is removed. It is called
    /// while the queue's internal lock is held.
    ///
    /// NOTE(review): calling back into the same `AcceptQueue` from here would
    /// try to re-acquire that lock. It is presumably not reentrant; confirm.
    fn new_incoming_connections(&mut self, num_ready: usize);
}
28
/// A shared, lock-protected queue of incoming connections for a listener.
///
/// Entries are tracked from the moment they are pushed as pending until they
/// are popped after becoming ready, removed, or handed back by `close`.
///
/// - `S` identifies a socket and is used as a hash-map key.
/// - `R` is per-entry data supplied when an entry becomes ready and returned
///   when it is popped.
/// - `N` is the [`ListenerNotifier`] told about ready-count changes.
///
/// Cloning produces another handle to the same underlying queue (the `Arc` is
/// cloned). The empty `bound` on the derived `Clone` means `S`, `R` and `N` do
/// not themselves need to be `Clone`.
#[derive(Debug, Derivative)]
#[derivative(Clone(bound = ""))]
pub struct AcceptQueue<S, R, N>(Arc<Mutex<AcceptQueueInner<S, R, N>>>);
37
#[cfg(test)]
impl<S, R, N> PartialEq for AcceptQueue<S, R, N>
where
    AcceptQueueInner<S, R, N>: PartialEq,
{
    /// Test-only equality: two handles are equal when they point at the same
    /// allocation, or otherwise when their inner states compare equal.
    ///
    /// The identity check runs first so that comparing a handle with itself
    /// (or a clone of itself) never tries to take the same lock twice.
    fn eq(&self, Self(other): &Self) -> bool {
        let Self(this) = self;
        Arc::ptr_eq(other, this) || {
            // Lock order matches the original: `self` first, then `other`.
            let ours = this.lock();
            let theirs = other.lock();
            *ours == *theirs
        }
    }
}
53
// Test-only marker impl. `eq` short-circuits on pointer identity, so equality
// is reflexive even before locking. This assumes the inner state's
// `PartialEq` is itself an equivalence relation; confirm if that changes.
#[cfg(test)]
impl<S, R, N> Eq for AcceptQueue<S, R, N> where Self: PartialEq {}
56
/// The lifecycle state of an entry tracked by `AcceptQueueInner`.
#[derive(Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
enum EntryState {
    /// Added via `push_pending` and not yet marked ready.
    Pending,
    /// Marked ready via `notify_ready`. The entry also sits in the ready queue
    /// until it is popped or removed.
    Ready,
}
63
/// The state protected by the mutex inside an [`AcceptQueue`].
///
/// The methods below maintain this invariant: every element of `ready_queue`
/// has an `EntryState::Ready` entry in `all_sockets`. As a result,
/// `all_sockets.len() >= ready_queue.len()`, which `pending_len` relies on.
#[derive(Debug)]
#[cfg_attr(test, derive(Derivative))]
#[cfg_attr(
    test,
    derivative(PartialEq(bound = "S: Hash + Clone + Eq + PartialEq, R: PartialEq, N: PartialEq"))
)]
struct AcceptQueueInner<S, R, N> {
    /// Ready entries with their ready data, in the order they became ready.
    /// New entries are pushed to the back and popped from the front.
    ready_queue: VecDeque<(S, R)>,
    /// Every tracked entry, pending or ready, keyed by socket.
    all_sockets: HashMap<S, EntryState>,
    /// `Some` while the queue is open. `close` takes it, leaving `None`.
    notifier: Option<N>,
}
75
76impl<S, R, N> AcceptQueue<S, R, N>
77where
78 S: Hash + Clone + Eq + PartialEq + Debug,
79 N: ListenerNotifier,
80{
81 pub(crate) fn new(notifier: N) -> Self {
83 Self(Arc::new(Mutex::new(AcceptQueueInner::new(notifier))))
84 }
85
86 fn lock(&self) -> impl DerefMut<Target = AcceptQueueInner<S, R, N>> + '_ {
87 let Self(inner) = self;
88 inner.lock()
89 }
90
91 pub(crate) fn pop_ready(&self) -> Option<(S, R)> {
98 self.lock().pop_ready()
99 }
100
101 #[cfg(test)]
103 pub(crate) fn collect_pending(&self) -> Vec<S> {
104 self.lock().collect_pending()
105 }
106
107 pub(crate) fn push_pending(&self, pending: S) {
112 self.lock().push_pending(pending)
113 }
114
115 pub(crate) fn len(&self) -> usize {
117 self.lock().len()
118 }
119
120 #[cfg(test)]
122 pub(crate) fn ready_len(&self) -> usize {
123 self.lock().ready_len()
124 }
125
126 #[cfg(test)]
128 pub(crate) fn pending_len(&self) -> usize {
129 self.lock().pending_len()
130 }
131
132 pub(crate) fn notify_ready(&self, newly_ready: &S, ready_state: R) {
139 self.lock().notify_ready(newly_ready, ready_state)
140 }
141
142 pub(crate) fn remove(&self, entry: &S) {
146 self.lock().remove(entry)
147 }
148
149 pub(crate) fn close(&self) -> (impl Iterator<Item = S> + use<S, R, N>, N) {
154 self.lock().close()
155 }
156
157 pub(crate) fn is_closed(&self) -> bool {
159 self.lock().is_closed()
160 }
161
162 pub(crate) fn inspect<I: Inspector>(&self, inspector: &mut I) {
164 let inner = self.lock();
165 inspector.record_usize("NumReady", inner.ready_len());
166 inspector.record_usize("NumPending", inner.pending_len());
167 inspector.record_debug("Contents", &inner.all_sockets);
168 }
169}
170
171impl<S, R, N> AcceptQueueInner<S, R, N>
172where
173 S: Hash + Clone + Eq + PartialEq + Debug,
174 N: ListenerNotifier,
175{
176 fn new(notifier: N) -> Self {
177 Self {
178 ready_queue: Default::default(),
179 all_sockets: Default::default(),
180 notifier: Some(notifier),
181 }
182 }
183
184 fn pop_ready(&mut self) -> Option<(S, R)> {
185 let AcceptQueueInner { ready_queue, all_sockets, notifier } = self;
186 let (socket, ready_state) = ready_queue.pop_front()?;
187 assert_matches!(all_sockets.remove(&socket), Some(EntryState::Ready));
189 let notifier = notifier.as_mut().unwrap();
191 notifier.new_incoming_connections(ready_queue.len());
192 Some((socket, ready_state))
193 }
194
195 #[cfg(test)]
196 pub(crate) fn collect_pending(&self) -> Vec<S> {
197 let AcceptQueueInner { all_sockets, .. } = self;
198 all_sockets
199 .iter()
200 .filter_map(|(socket, state)| match state {
201 EntryState::Ready => None,
202 EntryState::Pending => Some(socket.clone()),
203 })
204 .collect()
205 }
206
207 fn push_pending(&mut self, pending: S) {
208 let AcceptQueueInner { all_sockets, notifier, .. } = self;
209 assert!(notifier.is_some());
212 assert_matches!(all_sockets.insert(pending, EntryState::Pending), None);
213 }
214
215 fn len(&self) -> usize {
216 let AcceptQueueInner { all_sockets, .. } = self;
217 all_sockets.len()
218 }
219
220 fn ready_len(&self) -> usize {
221 let AcceptQueueInner { ready_queue, .. } = self;
222 ready_queue.len()
223 }
224
225 fn pending_len(&self) -> usize {
226 let AcceptQueueInner { ready_queue, all_sockets, .. } = self;
227 all_sockets.len() - ready_queue.len()
228 }
229
230 fn notify_ready(&mut self, newly_ready: &S, ready_state: R) {
231 let AcceptQueueInner { ready_queue, all_sockets, notifier } = self;
232 let notifier = match notifier {
233 Some(notifier) => notifier,
234
235 None => {
236 debug_assert!(ready_queue.is_empty());
240 debug_assert!(all_sockets.is_empty());
241 return;
242 }
243 };
244 let entry = all_sockets
245 .get_mut(newly_ready)
246 .expect("attempted to notify ready entry that was not in queue");
247 let prev_state = core::mem::replace(entry, EntryState::Ready);
248 assert_matches!(prev_state, EntryState::Pending);
249 ready_queue.push_back((newly_ready.clone(), ready_state));
250 notifier.new_incoming_connections(ready_queue.len());
251 }
252
253 fn remove(&mut self, entry: &S) {
254 let AcceptQueueInner { ready_queue, all_sockets, notifier } = self;
255 let notifier = match notifier.as_mut() {
257 Some(notifier) => notifier,
258 None => {
259 debug_assert!(ready_queue.is_empty());
262 debug_assert!(all_sockets.is_empty());
263 return;
264 }
265 };
266
267 match all_sockets.remove(entry) {
268 Some(EntryState::Pending) | None => (),
269 Some(EntryState::Ready) => {
270 let before_len = ready_queue.len();
271 ready_queue.retain(|(s, _ready_data)| s != entry);
272 let after_len = ready_queue.len();
273 assert_eq!(after_len, before_len - 1);
275 notifier.new_incoming_connections(after_len);
276 }
277 }
278 }
279
280 fn close(&mut self) -> (impl Iterator<Item = S> + use<S, R, N>, N) {
281 let AcceptQueueInner { ready_queue, all_sockets, notifier } = self;
282 let notifier = notifier.take().expect("queue is already closed");
284 *ready_queue = Default::default();
288 let entries = core::mem::take(all_sockets);
291 (entries.into_keys(), notifier)
292 }
293
294 fn is_closed(&self) -> bool {
295 let AcceptQueueInner { notifier, .. } = self;
296 notifier.is_none()
297 }
298}
299
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use netstack3_hashmap::HashSet;

    #[test]
    fn push_ready_pop() {
        let mut queue = AcceptQueueInner::new(Notifier::default());
        assert_eq!(queue.pop_ready(), None);
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.ready_len(), 0);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.clear_notifier(), None);

        queue.push_pending(Socket(0));
        assert_eq!(queue.pop_ready(), None);
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.ready_len(), 0);
        assert_eq!(queue.pending_len(), 1);
        assert_eq!(queue.clear_notifier(), None);

        queue.notify_ready(&Socket(0), Ready(2));
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.ready_len(), 1);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.clear_notifier(), Some(1));

        assert_eq!(queue.pop_ready(), Some((Socket(0), Ready(2))));
        assert_eq!(queue.clear_notifier(), Some(0));
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.ready_len(), 0);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.pop_ready(), None);
    }

    /// Ready entries are popped in the order they became ready, not the order
    /// they were pushed. Entries never marked ready stay pending.
    #[test]
    fn pop_ready_is_fifo_by_readiness() {
        let mut queue = AcceptQueueInner::new(Notifier::default());
        for i in 0..3 {
            queue.push_pending(Socket(i));
        }

        queue.notify_ready(&Socket(2), Ready(20));
        assert_eq!(queue.clear_notifier(), Some(1));
        queue.notify_ready(&Socket(0), Ready(0));
        assert_eq!(queue.clear_notifier(), Some(2));
        assert_eq!(queue.collect_pending(), [Socket(1)]);

        assert_eq!(queue.pop_ready(), Some((Socket(2), Ready(20))));
        assert_eq!(queue.clear_notifier(), Some(1));
        assert_eq!(queue.pop_ready(), Some((Socket(0), Ready(0))));
        assert_eq!(queue.clear_notifier(), Some(0));

        assert_eq!(queue.len(), 1);
        assert_eq!(queue.pending_len(), 1);
        assert_eq!(queue.collect_pending(), [Socket(1)]);
    }

    #[test]
    fn close() {
        let mut queue = AcceptQueueInner::new(Notifier::default());
        let mut expect = HashSet::new();
        for i in 0..3 {
            let s = Socket(i);
            queue.push_pending(s.clone());
            assert!(expect.insert(s));
        }
        let (socks, _notifier) = queue.close();
        let got = socks.collect::<HashSet<_>>();
        assert_eq!(got, expect);

        assert!(queue.is_closed());
        assert_eq!(queue.len(), 0);
    }

    /// After `close`, `notify_ready`, `remove` and `pop_ready` take their
    /// early-return paths and do not panic, even for sockets that were
    /// tracked before closing.
    #[test]
    fn operations_after_close_are_noops() {
        let mut queue = AcceptQueueInner::new(Notifier::default());
        queue.push_pending(Socket(0));
        queue.push_pending(Socket(1));
        let (socks, Notifier(notified)) = queue.close();
        assert_eq!(socks.count(), 2);
        assert_eq!(notified, None);

        queue.notify_ready(&Socket(0), Ready(0));
        queue.remove(&Socket(1));
        assert_eq!(queue.pop_ready(), None);

        assert!(queue.is_closed());
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.ready_len(), 0);
        assert_eq!(queue.pending_len(), 0);
    }

    #[test]
    fn remove() {
        let mut queue = AcceptQueueInner::new(Notifier::default());
        let s1 = Socket(1);
        let s2 = Socket(2);
        queue.push_pending(s1.clone());
        queue.push_pending(s2.clone());
        queue.notify_ready(&s2, Ready(2));
        assert_eq!(queue.len(), 2);
        assert_eq!(queue.ready_len(), 1);
        assert_eq!(queue.pending_len(), 1);
        assert_eq!(queue.clear_notifier(), Some(1));

        queue.remove(&s1);
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.ready_len(), 1);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.clear_notifier(), None);

        queue.remove(&s2);
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.ready_len(), 0);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.clear_notifier(), Some(0));

        // Removing sockets that are no longer tracked is a no-op.
        queue.remove(&s1);
        queue.remove(&s2);
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.ready_len(), 0);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.clear_notifier(), None);
    }

    /// Test socket identifier.
    #[derive(Default, Eq, PartialEq, Debug, Hash, Clone)]
    struct Socket(usize);
    /// Test per-entry ready data.
    #[derive(Default, Eq, PartialEq, Debug)]
    struct Ready(usize);

    /// Records the most recent ready count it was notified with. It asserts
    /// that each notification is consumed (via `clear_notifier`) before the
    /// next one arrives.
    #[derive(Default, Eq, PartialEq, Debug)]
    struct Notifier(Option<usize>);

    type AcceptQueueInner = super::AcceptQueueInner<Socket, Ready, Notifier>;

    impl AcceptQueueInner {
        /// Takes the last recorded ready count, if any. Panics if the queue is
        /// closed, since the notifier has then been taken.
        fn clear_notifier(&mut self) -> Option<usize> {
            let Self { notifier, .. } = self;
            let Notifier(v) = notifier.as_mut().unwrap();
            v.take()
        }
    }

    impl super::ListenerNotifier for Notifier {
        fn new_incoming_connections(&mut self, num_ready: usize) {
            let Self(n) = self;
            assert_matches!(n.replace(num_ready), None);
        }
    }
}