netstack3_base/socket/sndbuf.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Types and traits providing send buffer management for netstack sockets.

use netstack3_sync::atomic::{AtomicIsize, Ordering};
use netstack3_sync::Mutex;

use crate::num::PositiveIsize;

/// Tracks the available send buffer space for a socket.
#[derive(Debug)]
pub struct SendBufferTracking<L> {
    /// Keeps track of how many bytes are available in the send buffer.
    ///
    /// Due to overcommit, this is allowed to become negative. Hence, whenever
    /// `available_bytes` changes, the only observable event in terms of
    /// signaling the listener is whether zero is crossed (from negative to
    /// positive and vice-versa) as part of adding or subtracting in-flight
    /// bytes.
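    ///
    /// For example, with 4 bytes available, acquiring 6 bytes (overcommit)
    /// takes this from 4 to -2 and crosses zero, so the listener is told the
    /// socket is no longer writable; releasing those 6 bytes takes it back to
    /// 4, crossing zero again, and the listener is told the socket is
    /// writable once more.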
    available_bytes: AtomicIsize,
    inner: Mutex<Inner<L>>,
}

impl<L> SendBufferTracking<L> {
    /// Creates a new `SendBufferTracking` with the initial available send
    /// buffer space `capacity` and a writable listener `listener`.
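    ///
    /// A minimal usage sketch; the `listener` value and the sizes here are
    /// illustrative only:
    ///
    /// ```ignore
    /// let capacity = PositiveIsize::new(4096).unwrap();
    /// let tracking = SendBufferTracking::new(capacity, listener);
    /// // Reserve space before enqueueing a payload for send.
    /// let space = tracking.acquire(PositiveIsize::new(128).unwrap())?;
    /// // ... enqueue and transmit the payload ...
    /// // Return the space once the payload leaves the send buffer.
    /// tracking.release(space);
    /// ```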
    pub fn new(capacity: PositiveIsize, listener: L) -> Self {
        Self {
            available_bytes: AtomicIsize::new(capacity.into()),
            inner: Mutex::new(Inner {
                listener,
                // Listeners must assume the socket is writable from creation.
                // See SocketWritableListener.
                notified_state: true,
                capacity,
            }),
        }
    }

    /// Calls the callback `f` with a mutable reference to the listener.
    #[cfg(any(test, feature = "testutils"))]
    pub fn with_listener<R, F: FnOnce(&mut L) -> R>(&self, f: F) -> R {
        f(&mut self.inner.lock().listener)
    }
}

impl<L: SocketWritableListener> SendBufferTracking<L> {
    /// Acquires `bytes` for send with this `SendBufferTracking` instance.
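    ///
    /// On success, the returned [`SendBufferSpace`] must eventually be given
    /// back via [`SendBufferTracking::release`] or explicitly discarded with
    /// [`SendBufferSpace::acknowledge_drop`]; simply dropping it panics.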
    pub fn acquire(&self, bytes: PositiveIsize) -> Result<SendBufferSpace, SendBufferFullError> {
        let bytes_isize = bytes.into();
        // Ordering: Paired with Release in return_and_notify.
        let prev = self.available_bytes.fetch_sub(bytes_isize, Ordering::Acquire);
        // If we had enough bytes available, allow the write.
        if prev > bytes_isize {
            return Ok(SendBufferSpace(bytes));
        }
        // Turns out we didn't have enough space. Place the bytes back in and
        // potentially notify, but return an error.
        //
        // The regular return flow here is necessary because we could be racing
        // with other agents acquiring and returning space, and this return
        // could be the one that flips the socket back into writable.
        if prev <= 0 {
            self.return_and_notify(bytes_isize);
            return Err(SendBufferFullError);
        }

        // prev is in the interval (0, bytes], meaning this allocation crossed
        // the zero threshold. Notify the listener before returning.
        self.notify();
        Ok(SendBufferSpace(bytes))
    }

    /// Releases `space` back to this `SendBufferTracking` instance.
    pub fn release(&self, space: SendBufferSpace) {
        let SendBufferSpace(delta) = &space;
        self.return_and_notify((*delta).into());
        // Prevent drop panic for send buffer space.
        core::mem::forget(space);
    }

    fn return_and_notify(&self, delta: isize) {
        // Ordering: Paired with Acquire in acquire.
        let prev = self.available_bytes.fetch_add(delta, Ordering::Release);
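        // Example with illustrative numbers: prev = -2 and delta = 3 leaves 1
        // byte available; the tracker flipped from non-writable to writable,
        // so the listener must be notified.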
        if prev <= 0 && prev + delta > 0 {
            self.notify();
        }
    }

    fn notify(&self) {
        let Self { available_bytes, inner } = self;
        let mut inner = inner.lock();
        Self::notify_locked(available_bytes, &mut inner);
    }

    fn notify_locked(available_bytes: &AtomicIsize, inner: &mut Inner<L>) {
        let Inner { listener, notified_state, capacity: _ } = inner;
        // Read the available buffer space under lock and change the
        // notification state accordingly. Relaxed ordering is okay here
        // because the lock guarantees the ordering.
        let new_writable = available_bytes.load(Ordering::Relaxed) > 0;
        if core::mem::replace(notified_state, new_writable) != new_writable {
            listener.on_writable_changed(new_writable);
        }
    }

    /// Returns the tracker's capacity.
    pub fn capacity(&self) -> PositiveIsize {
        self.inner.lock().capacity
    }

    /// Returns the currently available buffer space.
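    ///
    /// Returns `None` when no bytes are available, which can happen due to
    /// overcommit when more bytes are in flight than the current capacity.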
    pub fn available(&self) -> Option<PositiveIsize> {
        PositiveIsize::new(self.available_bytes.load(Ordering::Relaxed))
    }

    /// Updates the tracker's capacity to `new_capacity`.
    ///
    /// Note that changing the capacity may change the socket's writable
    /// state.
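    ///
    /// For example, shrinking the capacity below the number of in-flight
    /// bytes drives the available count to zero or below, and the socket is
    /// reported non-writable until enough space is released.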
    pub fn set_capacity(&self, new_capacity: PositiveIsize) {
        let Self { available_bytes, inner } = self;
        let mut inner = inner.lock();
        let Inner { listener: _, notified_state: _, capacity } = &mut *inner;
        let old = core::mem::replace(capacity, new_capacity);
        let delta = new_capacity.get() - old.get();
        // Ordering: We're already under lock here and we want to ensure we're
        // not reordering with the relaxed load in notify_locked.
        let _: isize = available_bytes.fetch_add(delta, Ordering::AcqRel);
        // We already hold the lock, so check whether we need to notify
        // regardless of whether we crossed zero.
        Self::notify_locked(available_bytes, &mut inner);
    }
}

#[derive(Debug)]
struct Inner<L> {
    listener: L,
    /// Keeps track of the expected "writable" state seen by `listener`.
    ///
    /// This is necessary because of the zero-crossing optimization implemented
    /// by [`SendBufferTracking`]. Multiple threads can be acquiring and
    /// releasing bytes from the buffer at the same time, and the zero crossing
    /// itself doesn't fully guarantee that we'd never notify the listener
    /// twice with the same value. This boolean provides that guarantee.
    ///
    /// Example race:
    ///
    /// - T1 acquires N bytes, observes a zero crossing, and tries to acquire
    ///   the lock; gets descheduled.
    /// - T2 releases N bytes, observes a zero crossing, and tries to acquire
    ///   the lock; gets descheduled.
    /// - T3 acquires N bytes, observes a zero crossing, and tries to acquire
    ///   the lock; gets descheduled.
    ///
    /// Whenever the 3 threads become runnable again, they acquire the lock
    /// one at a time, read the number of available bytes in the buffer, and
    /// decide on a writable state, which could otherwise result in notifying
    /// the listener 3 times in a row with the same value.
    notified_state: bool,
    /// The maximum capacity allowed in the [`SendBufferTracking`] instance.
    ///
    /// Note that, due to overcommit, more than `capacity` bytes can be
    /// in-flight at once. A single [`SendBufferSpace`] may be emitted that
    /// crosses the capacity threshold.
    capacity: PositiveIsize,
}

/// A type stating that some amount of space was reserved within a
/// [`SendBufferTracking`] instance.
///
/// This type is returned from [`SendBufferTracking::acquire`] and *must* be
/// given back to [`SendBufferTracking::release`] when the annotated send
/// buffer space is freed; otherwise, [`SendBufferSpace::acknowledge_drop`]
/// must be called. This type panics if it is simply dropped.
#[derive(Debug, Eq, PartialEq)]
pub struct SendBufferSpace(PositiveIsize);

impl Drop for SendBufferSpace {
    fn drop(&mut self) {
        panic!("dropped send buffer space with {:?} bytes", self)
    }
}

impl SendBufferSpace {
    /// Acknowledges buffer space that is unused and, from this point on,
    /// untracked.
    ///
    /// This may be used to drop previously allocated buffer space tracking
    /// that has outlived the owning socket.
    pub fn acknowledge_drop(self) {
        core::mem::forget(self)
    }
}

/// An error indicating that the send buffer is full.
#[derive(Debug, Eq, PartialEq)]
pub struct SendBufferFullError;

/// A type capable of handling socket writable changes.
///
/// Upon creation, listeners must always assume the socket is writable.
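///
/// A minimal implementation sketch; the `WakerListener` type is illustrative
/// and not part of this crate:
///
/// ```ignore
/// struct WakerListener {
///     writable: bool,
///     // Plus whatever is needed to unblock pending writers.
/// }
///
/// impl SocketWritableListener for WakerListener {
///     fn on_writable_changed(&mut self, writable: bool) {
///         // Callers guarantee the state actually changed.
///         assert_ne!(self.writable, writable);
///         self.writable = writable;
///         if writable {
///             // Wake tasks blocked on writability here.
///         }
///     }
/// }
/// ```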
pub trait SocketWritableListener {
    /// Notifies the listener that the writable state has changed to
    /// `writable`.
    ///
    /// Callers must only call this when the writable state has actually
    /// changed. Implementers may panic if the notified state matches the
    /// one they currently hold.
    fn on_writable_changed(&mut self, writable: bool);
}

#[cfg(any(test, feature = "testutils"))]
pub(crate) mod testutil {
    use super::*;

    /// A fake [`SocketWritableListener`] implementation.
    #[derive(Debug)]
    pub struct FakeSocketWritableListener {
        writable: bool,
    }

    impl FakeSocketWritableListener {
        /// Returns whether the listener has observed a writable state.
        pub fn is_writable(&self) -> bool {
            self.writable
        }
    }

    impl Default for FakeSocketWritableListener {
        fn default() -> Self {
            Self { writable: true }
        }
    }

    impl SocketWritableListener for FakeSocketWritableListener {
        fn on_writable_changed(&mut self, writable: bool) {
            assert_ne!(core::mem::replace(&mut self.writable, writable), writable);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use testutil::FakeSocketWritableListener;

    use alloc::vec::Vec;

    const SNDBUF: PositiveIsize = PositiveIsize::new(4).unwrap();
    const HALF_SNDBUF: PositiveIsize = PositiveIsize::new(SNDBUF.get() / 2).unwrap();
    const TWO_SNDBUF: PositiveIsize = PositiveIsize::new(SNDBUF.get() * 2).unwrap();
    const ONE: PositiveIsize = PositiveIsize::new(1).unwrap();

    impl SendBufferTracking<FakeSocketWritableListener> {
        fn listener_writable(&self) -> bool {
            self.inner.lock().listener.is_writable()
        }
    }

    #[test]
    fn acquire_all_buffer() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let acquired = tracking.acquire(SNDBUF).expect("acquire");
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), 0);
        assert_eq!(tracking.listener_writable(), false);
        assert_eq!(tracking.acquire(ONE), Err(SendBufferFullError));
        tracking.release(acquired);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get());
    }

    #[test]
    fn acquire_half_buffer() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let acquired = tracking.acquire(HALF_SNDBUF).expect("acquire");
        assert_eq!(
            tracking.available_bytes.load(Ordering::SeqCst),
            SNDBUF.get() - HALF_SNDBUF.get()
        );
        assert_eq!(tracking.listener_writable(), true);
        tracking.release(acquired);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get());
    }

    #[test]
    fn acquire_multiple() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let tokens = (0..SNDBUF.get())
            .map(|_| {
                assert_eq!(tracking.listener_writable(), true);
                tracking.acquire(ONE).expect("acquire")
            })
            .collect::<Vec<_>>();
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), 0);
        assert_eq!(tracking.listener_writable(), false);

        assert_eq!(tracking.acquire(ONE), Err(SendBufferFullError));
        for t in tokens {
            tracking.release(t);
            assert_eq!(tracking.listener_writable(), true);
        }

        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get());
        assert_eq!(tracking.listener_writable(), true);
    }

    #[test]
    fn overcommit_single_buffer() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let acquired = tracking.acquire(TWO_SNDBUF).expect("acquire");
        assert_eq!(tracking.listener_writable(), false);
        assert_eq!(
            tracking.available_bytes.load(Ordering::SeqCst),
            SNDBUF.get() - TWO_SNDBUF.get()
        );
        assert_eq!(tracking.acquire(ONE), Err(SendBufferFullError));

        tracking.release(acquired);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get());
    }

    #[test]
    fn overcommit_two_buffers() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let acquired1 = tracking.acquire(ONE).expect("acquire");
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get() - 1);
        let acquired2 = tracking.acquire(SNDBUF).expect("acquire");
        assert_eq!(tracking.listener_writable(), false);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), -1);
        assert_eq!(tracking.acquire(ONE), Err(SendBufferFullError));

        tracking.release(acquired1);
        // Still not writable.
        assert_eq!(tracking.listener_writable(), false);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), 0);

        tracking.release(acquired2);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), SNDBUF.get());
    }

    #[test]
    fn capacity_increase_makes_writable() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        assert_eq!(tracking.capacity(), SNDBUF);
        let acquired = tracking.acquire(SNDBUF).expect("acquire");
        assert_eq!(tracking.listener_writable(), false);

        tracking.set_capacity(TWO_SNDBUF);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(
            tracking.available_bytes.load(Ordering::SeqCst),
            TWO_SNDBUF.get() - SNDBUF.get()
        );
        assert_eq!(tracking.capacity(), TWO_SNDBUF);
        tracking.release(acquired);
    }

    #[test]
    fn capacity_decrease_makes_non_writable() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        assert_eq!(tracking.capacity(), SNDBUF);
        let acquired = tracking.acquire(HALF_SNDBUF).expect("acquire");
        assert_eq!(tracking.listener_writable(), true);

        tracking.set_capacity(HALF_SNDBUF);
        assert_eq!(tracking.listener_writable(), false);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), 0);
        assert_eq!(tracking.capacity(), HALF_SNDBUF);
        tracking.release(acquired);
        assert_eq!(tracking.listener_writable(), true);
        assert_eq!(tracking.available_bytes.load(Ordering::SeqCst), HALF_SNDBUF.get());
    }
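
    // An additional sketch beyond the original suite: `acknowledge_drop` must
    // skip the drop panic without returning the bytes to the tracker, so the
    // space stays accounted as in-flight.
    #[test]
    fn acknowledge_drop_keeps_space_accounted() {
        let tracking = SendBufferTracking::new(SNDBUF, FakeSocketWritableListener::default());
        let acquired = tracking.acquire(HALF_SNDBUF).expect("acquire");
        acquired.acknowledge_drop();
        // The acknowledged bytes are intentionally never returned.
        assert_eq!(
            tracking.available_bytes.load(Ordering::SeqCst),
            SNDBUF.get() - HALF_SNDBUF.get()
        );
        assert_eq!(tracking.listener_writable(), true);
    }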
}