1use alloc::collections::{HashMap, VecDeque};
9use alloc::sync::Arc;
10#[cfg(test)]
11use alloc::vec::Vec;
12use core::fmt::Debug;
13use core::hash::Hash;
14use core::ops::DerefMut;
15
16use assert_matches::assert_matches;
17use derivative::Derivative;
18use netstack3_base::sync::Mutex;
19use netstack3_base::Inspector;
20
/// Receives updates about how many entries of an accept queue are ready.
///
/// The queue invokes this hook after its set of ready entries changes: an
/// entry became ready, a ready entry was popped, or a ready entry was removed.
/// Closing the queue does not invoke it; the notifier is handed back instead.
pub trait ListenerNotifier {
    /// Reports `num_ready`, the number of entries ready immediately after the
    /// change that triggered this call.
    fn new_incoming_connections(&mut self, num_ready: usize);
}
27
28#[derive(Debug, Derivative)]
34#[derivative(Clone(bound = ""))]
35pub struct AcceptQueue<S, R, N>(Arc<Mutex<AcceptQueueInner<S, R, N>>>);
36
// Test-only equality: two handles are equal when they share the same `Arc`,
// or when the inner states behind their mutexes compare equal.
#[cfg(test)]
impl<S, R, N> PartialEq for AcceptQueue<S, R, N>
where
    AcceptQueueInner<S, R, N>: PartialEq,
{
    fn eq(&self, Self(other): &Self) -> bool {
        let Self(this) = self;
        // Handles sharing one `Arc` are trivially equal; checking this first
        // also avoids locking the same mutex twice below.
        if Arc::ptr_eq(this, other) {
            true
        } else {
            let lhs = this.lock();
            let rhs = other.lock();
            *lhs == *rhs
        }
    }
}
52
// Test-only: queue equality is treated as a full equivalence relation
// wherever the handle type is `PartialEq`.
#[cfg(test)]
impl<S, R, N> Eq for AcceptQueue<S, R, N>
where
    Self: PartialEq,
{
}
55
/// Per-socket state tracked by an accept queue.
#[derive(Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
enum EntryState {
    /// Tracked, but not yet marked ready; never returned by `pop_ready`.
    Pending,
    /// Marked ready by `notify_ready` and waiting in the ready queue.
    Ready,
}
62
63#[derive(Debug)]
64#[cfg_attr(test, derive(Derivative))]
65#[cfg_attr(
66 test,
67 derivative(PartialEq(bound = "S: Hash + Clone + Eq + PartialEq, R: PartialEq, N: PartialEq"))
68)]
69struct AcceptQueueInner<S, R, N> {
70 ready_queue: VecDeque<(S, R)>,
71 all_sockets: HashMap<S, EntryState>,
72 notifier: Option<N>,
73}
74
75impl<S, R, N> AcceptQueue<S, R, N>
76where
77 S: Hash + Clone + Eq + PartialEq + Debug,
78 N: ListenerNotifier,
79{
80 pub(crate) fn new(notifier: N) -> Self {
82 Self(Arc::new(Mutex::new(AcceptQueueInner::new(notifier))))
83 }
84
85 fn lock(&self) -> impl DerefMut<Target = AcceptQueueInner<S, R, N>> + '_ {
86 let Self(inner) = self;
87 inner.lock()
88 }
89
90 pub(crate) fn pop_ready(&self) -> Option<(S, R)> {
97 self.lock().pop_ready()
98 }
99
100 #[cfg(test)]
102 pub(crate) fn collect_pending(&self) -> Vec<S> {
103 self.lock().collect_pending()
104 }
105
106 pub(crate) fn push_pending(&self, pending: S) {
111 self.lock().push_pending(pending)
112 }
113
114 pub(crate) fn len(&self) -> usize {
116 self.lock().len()
117 }
118
119 #[cfg(test)]
121 pub(crate) fn ready_len(&self) -> usize {
122 self.lock().ready_len()
123 }
124
125 #[cfg(test)]
127 pub(crate) fn pending_len(&self) -> usize {
128 self.lock().pending_len()
129 }
130
131 pub(crate) fn notify_ready(&self, newly_ready: &S, ready_state: R) {
138 self.lock().notify_ready(newly_ready, ready_state)
139 }
140
141 pub(crate) fn remove(&self, entry: &S) {
145 self.lock().remove(entry)
146 }
147
148 pub(crate) fn close(&self) -> (impl Iterator<Item = S>, N) {
153 self.lock().close()
154 }
155
156 pub(crate) fn is_closed(&self) -> bool {
158 self.lock().is_closed()
159 }
160
161 pub(crate) fn inspect<I: Inspector>(&self, inspector: &mut I) {
163 let inner = self.lock();
164 inspector.record_usize("NumReady", inner.ready_len());
165 inspector.record_usize("NumPending", inner.pending_len());
166 inspector.record_debug("Contents", &inner.all_sockets);
167 }
168}
169
170impl<S, R, N> AcceptQueueInner<S, R, N>
171where
172 S: Hash + Clone + Eq + PartialEq + Debug,
173 N: ListenerNotifier,
174{
175 fn new(notifier: N) -> Self {
176 Self {
177 ready_queue: Default::default(),
178 all_sockets: Default::default(),
179 notifier: Some(notifier),
180 }
181 }
182
183 fn pop_ready(&mut self) -> Option<(S, R)> {
184 let AcceptQueueInner { ready_queue, all_sockets, notifier } = self;
185 let (socket, ready_state) = ready_queue.pop_front()?;
186 assert_matches!(all_sockets.remove(&socket), Some(EntryState::Ready));
188 let notifier = notifier.as_mut().unwrap();
190 notifier.new_incoming_connections(ready_queue.len());
191 Some((socket, ready_state))
192 }
193
194 #[cfg(test)]
195 pub(crate) fn collect_pending(&self) -> Vec<S> {
196 let AcceptQueueInner { all_sockets, .. } = self;
197 all_sockets
198 .iter()
199 .filter_map(|(socket, state)| match state {
200 EntryState::Ready => None,
201 EntryState::Pending => Some(socket.clone()),
202 })
203 .collect()
204 }
205
206 fn push_pending(&mut self, pending: S) {
207 let AcceptQueueInner { all_sockets, notifier, .. } = self;
208 assert!(notifier.is_some());
211 assert_matches!(all_sockets.insert(pending, EntryState::Pending), None);
212 }
213
214 fn len(&self) -> usize {
215 let AcceptQueueInner { all_sockets, .. } = self;
216 all_sockets.len()
217 }
218
219 fn ready_len(&self) -> usize {
220 let AcceptQueueInner { ready_queue, .. } = self;
221 ready_queue.len()
222 }
223
224 fn pending_len(&self) -> usize {
225 let AcceptQueueInner { ready_queue, all_sockets, .. } = self;
226 all_sockets.len() - ready_queue.len()
227 }
228
229 fn notify_ready(&mut self, newly_ready: &S, ready_state: R) {
230 let AcceptQueueInner { ready_queue, all_sockets, notifier } = self;
231 let notifier = match notifier {
232 Some(notifier) => notifier,
233
234 None => {
235 debug_assert!(ready_queue.is_empty());
239 debug_assert!(all_sockets.is_empty());
240 return;
241 }
242 };
243 let entry = all_sockets
244 .get_mut(newly_ready)
245 .expect("attempted to notify ready entry that was not in queue");
246 let prev_state = core::mem::replace(entry, EntryState::Ready);
247 assert_matches!(prev_state, EntryState::Pending);
248 ready_queue.push_back((newly_ready.clone(), ready_state));
249 notifier.new_incoming_connections(ready_queue.len());
250 }
251
252 fn remove(&mut self, entry: &S) {
253 let AcceptQueueInner { ready_queue, all_sockets, notifier } = self;
254 let notifier = match notifier.as_mut() {
256 Some(notifier) => notifier,
257 None => {
258 debug_assert!(ready_queue.is_empty());
261 debug_assert!(all_sockets.is_empty());
262 return;
263 }
264 };
265
266 match all_sockets.remove(entry) {
267 Some(EntryState::Pending) | None => (),
268 Some(EntryState::Ready) => {
269 let before_len = ready_queue.len();
270 ready_queue.retain(|(s, _ready_data)| s != entry);
271 let after_len = ready_queue.len();
272 assert_eq!(after_len, before_len - 1);
274 notifier.new_incoming_connections(after_len);
275 }
276 }
277 }
278
279 fn close(&mut self) -> (impl Iterator<Item = S>, N) {
280 let AcceptQueueInner { ready_queue, all_sockets, notifier } = self;
281 let notifier = notifier.take().expect("queue is already closed");
283 *ready_queue = Default::default();
287 let entries = core::mem::take(all_sockets);
290 (entries.into_keys(), notifier)
291 }
292
293 fn is_closed(&self) -> bool {
294 let AcceptQueueInner { notifier, .. } = self;
295 notifier.is_none()
296 }
297}
298
#[cfg(test)]
mod tests {
    use alloc::collections::HashSet;
    use assert_matches::assert_matches;

    #[test]
    fn push_ready_pop() {
        let mut queue = AcceptQueueInner::new(Notifier::default());
        assert_eq!(queue.pop_ready(), None);
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.ready_len(), 0);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.clear_notifier(), None);

        queue.push_pending(Socket(0));
        assert_eq!(queue.pop_ready(), None);
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.ready_len(), 0);
        assert_eq!(queue.pending_len(), 1);
        assert_eq!(queue.clear_notifier(), None);

        queue.notify_ready(&Socket(0), Ready(2));
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.ready_len(), 1);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.clear_notifier(), Some(1));

        assert_eq!(queue.pop_ready(), Some((Socket(0), Ready(2))));
        assert_eq!(queue.clear_notifier(), Some(0));
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.ready_len(), 0);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.pop_ready(), None);
    }

    #[test]
    fn close() {
        let mut queue = AcceptQueueInner::new(Notifier::default());
        let mut expect = HashSet::new();
        for i in 0..3 {
            let s = Socket(i);
            queue.push_pending(s.clone());
            assert!(expect.insert(s));
        }
        let (socks, _notifier) = queue.close();
        let got = socks.collect::<HashSet<_>>();
        assert_eq!(got, expect);

        assert!(queue.is_closed());
        assert_eq!(queue.len(), 0);
    }

    #[test]
    fn remove() {
        let mut queue = AcceptQueueInner::new(Notifier::default());
        let s1 = Socket(1);
        let s2 = Socket(2);
        queue.push_pending(s1.clone());
        queue.push_pending(s2.clone());
        queue.notify_ready(&s2, Ready(2));
        assert_eq!(queue.len(), 2);
        assert_eq!(queue.ready_len(), 1);
        assert_eq!(queue.pending_len(), 1);
        assert_eq!(queue.clear_notifier(), Some(1));

        // Removing a pending socket does not change the ready count.
        queue.remove(&s1);
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.ready_len(), 1);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.clear_notifier(), None);

        // Removing a ready socket does, and the listener hears about it.
        queue.remove(&s2);
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.ready_len(), 0);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.clear_notifier(), Some(0));

        // Removing sockets that are no longer tracked is a no-op.
        queue.remove(&s1);
        queue.remove(&s2);
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.ready_len(), 0);
        assert_eq!(queue.pending_len(), 0);
        assert_eq!(queue.clear_notifier(), None);
    }

    /// Sockets are popped in the order they became ready, not the order in
    /// which they were pushed; sockets never marked ready stay pending.
    #[test]
    fn pop_ready_order_follows_readiness() {
        let mut queue = AcceptQueueInner::new(Notifier::default());
        for i in 0..3 {
            queue.push_pending(Socket(i));
        }
        queue.notify_ready(&Socket(2), Ready(20));
        assert_eq!(queue.clear_notifier(), Some(1));
        queue.notify_ready(&Socket(0), Ready(0));
        assert_eq!(queue.clear_notifier(), Some(2));
        assert_eq!(queue.collect_pending(), [Socket(1)]);

        assert_eq!(queue.pop_ready(), Some((Socket(2), Ready(20))));
        assert_eq!(queue.clear_notifier(), Some(1));
        assert_eq!(queue.pop_ready(), Some((Socket(0), Ready(0))));
        assert_eq!(queue.clear_notifier(), Some(0));
        assert_eq!(queue.pop_ready(), None);
        assert_eq!(queue.clear_notifier(), None);
        assert_eq!(queue.len(), 1);
        assert_eq!(queue.pending_len(), 1);
    }

    /// After `close`, readiness updates and removals for the returned sockets
    /// are ignored rather than panicking.
    #[test]
    fn closed_queue_ignores_late_updates() {
        let mut queue = AcceptQueueInner::new(Notifier::default());
        queue.push_pending(Socket(0));
        let (mut socks, Notifier(notified)) = queue.close();
        assert_eq!(socks.next(), Some(Socket(0)));
        assert_eq!(socks.next(), None);
        assert_eq!(notified, None);

        queue.notify_ready(&Socket(0), Ready(0));
        queue.remove(&Socket(0));
        assert!(queue.is_closed());
        assert_eq!(queue.pop_ready(), None);
        assert_eq!(queue.len(), 0);
    }

    #[derive(Default, Eq, PartialEq, Debug, Hash, Clone)]
    struct Socket(usize);
    #[derive(Default, Eq, PartialEq, Debug)]
    struct Ready(usize);

    /// Records the most recent ready count; tests must clear it (via
    /// `clear_notifier`) before the next notification arrives.
    #[derive(Default, Eq, PartialEq, Debug)]
    struct Notifier(Option<usize>);

    type AcceptQueueInner = super::AcceptQueueInner<Socket, Ready, Notifier>;

    impl AcceptQueueInner {
        /// Takes the last ready count reported to the notifier, if any.
        fn clear_notifier(&mut self) -> Option<usize> {
            let Self { notifier, .. } = self;
            let Notifier(v) = notifier.as_mut().unwrap();
            v.take()
        }
    }

    impl super::ListenerNotifier for Notifier {
        fn new_incoming_connections(&mut self, num_ready: usize) {
            let Self(n) = self;
            // Fail loudly if a previous notification was never inspected.
            assert_matches!(n.replace(num_ready), None);
        }
    }
}