1use netstack3_sync::atomic::{AtomicIsize, Ordering};
8use netstack3_sync::Mutex;
9
10use crate::num::PositiveIsize;
11
12#[derive(Debug)]
14pub struct SendBufferTracking<L> {
15 available_bytes: AtomicIsize,
23 inner: Mutex<Inner<L>>,
24}
25
26impl<L> SendBufferTracking<L> {
27 pub fn new(capacity: PositiveIsize, listener: L) -> Self {
30 Self {
31 available_bytes: AtomicIsize::new(capacity.into()),
32 inner: Mutex::new(Inner {
33 listener,
34 notified_state: true,
37 capacity,
38 }),
39 }
40 }
41
42 #[cfg(any(test, feature = "testutils"))]
44 pub fn with_listener<R, F: FnOnce(&mut L) -> R>(&self, f: F) -> R {
45 f(&mut self.inner.lock().listener)
46 }
47}
48
49impl<L: SocketWritableListener> SendBufferTracking<L> {
50 pub fn acquire(&self, bytes: PositiveIsize) -> Result<SendBufferSpace, SendBufferFullError> {
52 let bytes_isize = bytes.into();
53 let prev = self.available_bytes.fetch_sub(bytes_isize, Ordering::Acquire);
55 if prev > bytes_isize {
57 return Ok(SendBufferSpace(bytes));
58 }
59 if prev <= 0 {
66 self.return_and_notify(bytes_isize);
67 return Err(SendBufferFullError);
68 }
69
70 self.notify();
73 Ok(SendBufferSpace(bytes))
74 }
75
76 pub fn release(&self, space: SendBufferSpace) {
78 let SendBufferSpace(delta) = &space;
79 self.return_and_notify((*delta).into());
80 core::mem::forget(space);
82 }
83
84 fn return_and_notify(&self, delta: isize) {
85 let prev = self.available_bytes.fetch_add(delta, Ordering::Release);
87 if prev <= 0 && prev + delta > 0 {
88 self.notify();
89 }
90 }
91
92 fn notify(&self) {
93 let Self { available_bytes, inner } = self;
94 let mut inner = inner.lock();
95 Self::notify_locked(available_bytes, &mut inner);
96 }
97
98 fn notify_locked(available_bytes: &AtomicIsize, inner: &mut Inner<L>) {
99 let Inner { listener, notified_state, capacity: _ } = inner;
100 let new_writable = available_bytes.load(Ordering::Relaxed) > 0;
104 if core::mem::replace(notified_state, new_writable) != new_writable {
105 listener.on_writable_changed(new_writable);
106 }
107 }
108
109 pub fn capacity(&self) -> PositiveIsize {
111 self.inner.lock().capacity
112 }
113
114 pub fn available(&self) -> Option<PositiveIsize> {
116 PositiveIsize::new(self.available_bytes.load(Ordering::Relaxed))
117 }
118
119 pub fn set_capacity(&self, new_capacity: PositiveIsize) {
124 let Self { available_bytes, inner } = self;
125 let mut inner = inner.lock();
126 let Inner { listener: _, notified_state: _, capacity } = &mut *inner;
127 let old = core::mem::replace(capacity, new_capacity);
128 let delta = new_capacity.get() - old.get();
129 let _: isize = available_bytes.fetch_add(delta, Ordering::AcqRel);
132 Self::notify_locked(available_bytes, &mut inner);
135 }
136}
137
/// State of a [`SendBufferTracking`] that is only reachable through its mutex.
#[derive(Debug)]
struct Inner<L> {
    /// Receives writability change notifications.
    listener: L,
    /// Writability last reported to `listener`. `SendBufferTracking::new`
    /// initializes it to `true`, so the listener is treated as starting out
    /// writable.
    notified_state: bool,
    /// Configured send buffer capacity. Replaced by `set_capacity`.
    capacity: PositiveIsize,
}
170
171#[derive(Debug, Eq, PartialEq)]
179pub struct SendBufferSpace(PositiveIsize);
180
181impl Drop for SendBufferSpace {
182 fn drop(&mut self) {
183 panic!("dropped send buffer space with {:?} bytes", self)
184 }
185}
186
187impl SendBufferSpace {
188 pub fn acknowledge_drop(self) {
193 core::mem::forget(self)
194 }
195}
196
/// Error returned by [`SendBufferTracking::acquire`] when the send buffer has
/// no space available.
#[derive(Debug, Eq, PartialEq)]
pub struct SendBufferFullError;
200
/// Receives writability notifications from a [`SendBufferTracking`].
pub trait SocketWritableListener {
    /// Called with the new writability whenever it differs from the value
    /// last reported to this listener.
    ///
    /// A freshly created tracker assumes the listener starts out writable.
    ///
    /// NOTE: this is invoked while the tracker's internal mutex is held.
    /// Calling back into the same tracker from here would try to lock it
    /// again.
    fn on_writable_changed(&mut self, writable: bool);
}
212
#[cfg(any(test, feature = "testutils"))]
pub(crate) mod testutil {
    use super::*;

    /// A [`SocketWritableListener`] that records the last writability it was
    /// told about.
    ///
    /// It asserts that every notification is an actual change.
    #[derive(Debug)]
    pub struct FakeSocketWritableListener {
        writable: bool,
    }

    impl FakeSocketWritableListener {
        /// Returns the writability most recently reported to this listener.
        pub fn is_writable(&self) -> bool {
            let Self { writable } = self;
            *writable
        }
    }

    impl Default for FakeSocketWritableListener {
        /// Starts writable, matching the initial state assumed by
        /// `SendBufferTracking::new`.
        fn default() -> Self {
            Self { writable: true }
        }
    }

    impl SocketWritableListener for FakeSocketWritableListener {
        fn on_writable_changed(&mut self, writable: bool) {
            let previous = core::mem::replace(&mut self.writable, writable);
            // Redundant notifications indicate a tracker bug.
            assert_ne!(previous, writable);
        }
    }
}
242
#[cfg(test)]
mod tests {
    use super::*;
    use testutil::FakeSocketWritableListener;

    use alloc::vec::Vec;

    const SNDBUF: PositiveIsize = PositiveIsize::new(4).unwrap();
    const HALF_SNDBUF: PositiveIsize = PositiveIsize::new(SNDBUF.get() / 2).unwrap();
    const TWO_SNDBUF: PositiveIsize = PositiveIsize::new(SNDBUF.get() * 2).unwrap();
    const ONE: PositiveIsize = PositiveIsize::new(1).unwrap();

    impl SendBufferTracking<FakeSocketWritableListener> {
        /// Returns the writability last reported to the fake listener.
        fn listener_writable(&self) -> bool {
            self.inner.lock().listener.is_writable()
        }
    }

    #[test]
    fn acquire_all_buffer() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let acquired = tracking.acquire(SNDBUF).expect("acquire");
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), 0);
        assert_eq!(tracking.listener_writable(), false);
        assert_eq!(tracking.acquire(ONE), Err(SendBufferFullError));
        tracking.release(acquired);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get());
    }

    #[test]
    fn acquire_half_buffer() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let acquired = tracking.acquire(HALF_SNDBUF).expect("acquire");
        assert_eq!(
            tracking.available_bytes.load(Ordering::SeqCst),
            SNDBUF.get() - HALF_SNDBUF.get()
        );
        assert_eq!(tracking.listener_writable(), true);
        tracking.release(acquired);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get());
    }

    #[test]
    fn acquire_multiple() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let tokens = (0..SNDBUF.get())
            .map(|_| {
                assert_eq!(tracking.listener_writable(), true);
                tracking.acquire(ONE).expect("acquire")
            })
            .collect::<Vec<_>>();
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), 0);
        assert_eq!(tracking.listener_writable(), false);

        assert_eq!(tracking.acquire(ONE), Err(SendBufferFullError));
        for t in tokens {
            tracking.release(t);
            assert_eq!(tracking.listener_writable(), true);
        }

        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get());
        assert_eq!(tracking.listener_writable(), true);
    }

    #[test]
    fn overcommit_single_buffer() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let acquired = tracking.acquire(TWO_SNDBUF).expect("acquire");
        assert_eq!(tracking.listener_writable(), false);
        assert_eq!(
            tracking.available_bytes.load(Ordering::SeqCst),
            SNDBUF.get() - TWO_SNDBUF.get()
        );
        assert_eq!(tracking.acquire(ONE), Err(SendBufferFullError));

        tracking.release(acquired);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get());
    }

    #[test]
    fn overcommit_two_buffers() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let acquired1 = tracking.acquire(ONE).expect("acquire");
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get() - 1);
        let acquired2 = tracking.acquire(SNDBUF).expect("acquire");
        assert_eq!(tracking.listener_writable(), false);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), -1);
        assert_eq!(tracking.acquire(ONE), Err(SendBufferFullError));

        tracking.release(acquired1);
        // Returning one byte only brings the count back to zero: still full.
        assert_eq!(tracking.listener_writable(), false);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), 0);

        tracking.release(acquired2);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get());
    }

    #[test]
    fn capacity_increase_makes_writable() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        assert_eq!(tracking.capacity(), SNDBUF);
        let acquired = tracking.acquire(SNDBUF).expect("acquire");
        assert_eq!(tracking.listener_writable(), false);

        tracking.set_capacity(TWO_SNDBUF);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(
            tracking.available_bytes.load(Ordering::SeqCst),
            TWO_SNDBUF.get() - SNDBUF.get()
        );
        assert_eq!(tracking.capacity(), TWO_SNDBUF);
        tracking.release(acquired);
    }

    #[test]
    fn capacity_decrease_makes_non_writable() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        assert_eq!(tracking.capacity(), SNDBUF);
        let acquired = tracking.acquire(HALF_SNDBUF).expect("acquire");
        assert_eq!(tracking.listener_writable(), true);

        tracking.set_capacity(HALF_SNDBUF);
        assert_eq!(tracking.listener_writable(), false);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), 0);
        assert_eq!(tracking.capacity(), HALF_SNDBUF);
        tracking.release(acquired);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), HALF_SNDBUF.get());
    }

    /// `available()` reports the remaining space, and nothing once the buffer
    /// is overcommitted.
    #[test]
    fn available_reflects_space() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        assert_eq!(tracking.available(), Some(SNDBUF));
        let half = tracking.acquire(HALF_SNDBUF).expect("acquire");
        assert_eq!(tracking.available(), PositiveIsize::new(SNDBUF.get() - HALF_SNDBUF.get()));
        let overcommit = tracking.acquire(SNDBUF).expect("acquire");
        assert_eq!(tracking.available(), None);
        tracking.release(overcommit);
        tracking.release(half);
        assert_eq!(tracking.available(), Some(SNDBUF));
    }

    /// A rejected acquisition must roll back its subtraction and must not
    /// flip the listener state.
    #[test]
    fn failed_acquire_leaves_state_unchanged() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let acquired = tracking.acquire(SNDBUF).expect("acquire");
        assert_eq!(tracking.acquire(TWO_SNDBUF), Err(SendBufferFullError));
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), 0);
        assert_eq!(tracking.listener_writable(), false);
        tracking.release(acquired);
        assert_eq!(tracking.listener_writable(), true);
    }

    /// `with_listener` exposes the same listener the tracker notifies.
    #[test]
    fn with_listener_exposes_listener() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        assert!(tracking.with_listener(|listener| listener.is_writable()));
        let acquired = tracking.acquire(SNDBUF).expect("acquire");
        assert!(!tracking.with_listener(|listener| listener.is_writable()));
        tracking.release(acquired);
    }
}