Skip to main content

netdevice_client/session/buffer/
pool.rs

1// Copyright 2021 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Fuchsia netdevice buffer pool.
6
7use fuchsia_sync::Mutex;
8use futures::task::AtomicWaker;
9use std::borrow::Borrow;
10use std::collections::VecDeque;
11use std::convert::TryInto as _;
12use std::fmt::Debug;
13use std::io::{Read, Seek, SeekFrom, Write};
14use std::mem::MaybeUninit;
15use std::num::TryFromIntError;
16use std::ops::{Deref, DerefMut};
17use std::ptr::NonNull;
18use std::sync::Arc;
19use std::sync::atomic::{self, AtomicBool, AtomicU64};
20use std::task::Poll;
21
22use arrayvec::ArrayVec;
23use explicit::ResultExt as _;
24use fidl_fuchsia_hardware_network as netdev;
25use fuchsia_runtime::vmar_root_self;
26use futures::channel::oneshot::{Receiver, Sender, channel};
27
28use super::{ChainLength, DescId, DescRef, DescRefMut, Descriptors};
29use crate::error::{Error, Result};
30use crate::session::{BufferLayout, Config, Pending, Port};
31
/// Responsible for managing [`Buffer`]s for a [`Session`](crate::session::Session).
///
/// The pool owns one data VMO, mapped read-write into the root VMAR by
/// [`Pool::new`] and unmapped when the pool is dropped. It also owns the
/// descriptors that carve that mapping into buffers. Free tx descriptors live
/// in a mutex-protected free list. Free rx descriptors wait in `rx_pending`
/// until they are sent to the driver.
pub(in crate::session) struct Pool {
    /// Base address of the pool (the start of the mapped data VMO).
    // Note: This field requires us to manually implement `Sync` and `Send`.
    base: NonNull<u8>,
    /// The length of the pool in bytes (the length of the mapping at `base`).
    bytes: usize,
    /// The descriptors allocated for the pool.
    descriptors: Descriptors,
    /// Shared state for allocation: the tx free list and the queue of tx
    /// allocation requests waiting for descriptors.
    tx_alloc_state: Mutex<TxAllocState>,
    /// The free rx descriptors pending to be sent to driver.
    pub(in crate::session) rx_pending: Pending<Rx>,
    /// The buffer layout.
    buffer_layout: BufferLayout,
    /// State-keeping allowing sessions to handle rx leases.
    rx_leases: RxLeaseHandlingState,
}

// SAFETY: `Pool` is `Send` and `Sync`, and this allows the compiler to deduce
// `Buffer` to be `Send`. These impls are safe because we can safely share
// `Pool` and `&Pool`: the implementation would never allocate the same buffer
// to two callers at the same time. The raw `base` pointer is the only field
// that prevents the compiler from deriving these automatically.
unsafe impl Send for Pool {}
unsafe impl Sync for Pool {}
57
/// The shared state which keeps track of available buffers and tx buffers.
///
/// Always accessed through `Pool::tx_alloc_state`, which wraps it in a mutex.
struct TxAllocState {
    /// All pending tx allocation requests. `Pool::free_tx` serves them from
    /// the front, i.e. in FIFO order.
    requests: VecDeque<TxAllocReq>,
    /// The tx descriptors currently available for allocation.
    free_list: TxFreeList,
}
64
/// We use a linked list to maintain the tx free descriptors - they are linked
/// through their `nxt` fields, note this differs from the chaining expected
/// by the network device protocol:
/// - You can chain more than [`netdev::MAX_DESCRIPTOR_CHAIN`] descriptors
///   together.
/// - the free-list ends when the `nxt` field is 0xff, while the normal chain
///   ends when `chain_length` becomes 0. (In code, the end of the free list is
///   a descriptor whose `nxt()` returns `None`.)
struct TxFreeList {
    /// The head of a linked list of available descriptors that can be allocated
    /// for tx.
    head: Option<DescId<Tx>>,
    /// How many free descriptors are there in the pool. Kept in sync with the
    /// length of the list starting at `head`.
    len: u16,
}
79
80impl Pool {
81    /// Creates a new [`Pool`] and its backing [`zx::Vmo`]s.
82    ///
83    /// Returns [`Pool`] and the [`zx::Vmo`]s for descriptors and data, in that
84    /// order.
85    pub(in crate::session) fn new(config: Config) -> Result<(Arc<Self>, zx::Vmo, zx::Vmo)> {
86        let Config { buffer_stride, num_rx_buffers, num_tx_buffers, options, buffer_layout } =
87            config;
88        let num_buffers = num_rx_buffers.get() + num_tx_buffers.get();
89        let (descriptors, descriptors_vmo, tx_free, mut rx_free) =
90            Descriptors::new(num_tx_buffers, num_rx_buffers, buffer_stride)?;
91
92        // Construct the free list.
93        let free_head = tx_free.into_iter().rev().fold(None, |head, mut curr| {
94            descriptors.borrow_mut(&mut curr).set_nxt(head);
95            Some(curr)
96        });
97
98        for rx_desc in rx_free.iter_mut() {
99            descriptors.borrow_mut(rx_desc).initialize(
100                ChainLength::ZERO,
101                0,
102                buffer_layout.length.try_into().unwrap(),
103                0,
104            );
105        }
106
107        let tx_alloc_state = TxAllocState {
108            free_list: TxFreeList { head: free_head, len: num_tx_buffers.get() },
109            requests: VecDeque::new(),
110        };
111
112        let size = buffer_stride.get() * u64::from(num_buffers);
113        let data_vmo = zx::Vmo::create(size).map_err(|status| Error::Vmo("data", status))?;
114
115        const VMO_NAME: zx::Name =
116            const_unwrap::const_unwrap_result(zx::Name::new("netdevice:data"));
117        data_vmo.set_name(&VMO_NAME).map_err(|status| Error::Vmo("set name", status))?;
118        // `as` is OK because `size` is positive and smaller than isize::MAX.
119        // This is following the practice of rust stdlib to ensure allocation
120        // size never reaches isize::MAX.
121        // https://doc.rust-lang.org/std/primitive.pointer.html#method.add-1.
122        let len = isize::try_from(size).expect("VMO size larger than isize::MAX") as usize;
123        // The returned address of zx_vmar_map on success must be non-zero:
124        // https://fuchsia.dev/fuchsia-src/reference/syscalls/vmar_map
125        let base = NonNull::new(
126            vmar_root_self()
127                .map(0, &data_vmo, 0, len, zx::VmarFlags::PERM_READ | zx::VmarFlags::PERM_WRITE)
128                .map_err(|status| Error::Map("data", status))? as *mut u8,
129        )
130        .unwrap();
131
132        Ok((
133            Arc::new(Pool {
134                base,
135                bytes: len,
136                descriptors,
137                tx_alloc_state: Mutex::new(tx_alloc_state),
138                rx_pending: Pending::new(rx_free),
139                buffer_layout,
140                rx_leases: RxLeaseHandlingState::new_with_flags(options),
141            }),
142            descriptors_vmo,
143            data_vmo,
144        ))
145    }
146
147    /// Allocates `num_parts` tx descriptors.
148    ///
149    /// It will block if there are not enough descriptors. Note that the
150    /// descriptors are not initialized, you need to call [`AllocGuard::init()`]
151    /// on the returned [`AllocGuard`] if you want to send it to the driver
152    /// later. See [`AllocGuard<Rx>::into_tx()`] for an example where
153    /// [`AllocGuard::init()`] is not needed because the tx allocation will be
154    /// returned to the pool immediately and won't be sent to the driver.
155    pub(in crate::session) async fn alloc_tx(
156        self: &Arc<Self>,
157        num_parts: ChainLength,
158    ) -> AllocGuard<Tx> {
159        let receiver = {
160            let mut state = self.tx_alloc_state.lock();
161            match state.free_list.try_alloc(num_parts, &self.descriptors) {
162                Some(allocated) => {
163                    return AllocGuard::new(allocated, self.clone());
164                }
165                None => {
166                    let (request, receiver) = TxAllocReq::new(num_parts);
167                    state.requests.push_back(request);
168                    receiver
169                }
170            }
171        };
172        // The sender must not be dropped.
173        receiver.await.unwrap()
174    }
175
176    /// Tries to allocate a [`SinglePartTxBuffer`].
177    ///
178    /// Returns `Ok(None)` if there is no available buffer, or `Err(Error::TxLength)`
179    /// if the requested size cannot meet the device requirement.
180    pub(in crate::session) fn try_alloc_single_part_tx_buffer(
181        self: &Arc<Self>,
182        num_bytes: usize,
183    ) -> Result<Option<SinglePartTxBuffer>> {
184        let BufferLayout { min_tx_data: _, min_tx_head, min_tx_tail, length: buffer_length } =
185            self.buffer_layout;
186        if num_bytes > buffer_length - usize::from(min_tx_head) - usize::from(min_tx_tail) {
187            return Err(Error::TxLength);
188        }
189        self.tx_alloc_state
190            .lock()
191            .free_list
192            .try_alloc(ChainLength::try_from(1u8).unwrap(), &self.descriptors)
193            .map(|allocated| -> Result<SinglePartTxBuffer> {
194                let mut alloc = AllocGuard::new(allocated, self.clone());
195                alloc.init(num_bytes)?;
196                let buffer = Buffer::from(alloc);
197                Ok(SinglePartTxBuffer::new(buffer, num_bytes).expect("must be single part"))
198            })
199            .transpose()
200    }
201
202    /// Allocates a tx [`Buffer`].
203    ///
204    /// The returned buffer will have `num_bytes` as its capacity, the method
205    /// will block if there are not enough buffers. An error will be returned if
206    /// the requested size cannot meet the device requirement, for example, if
207    /// the size of the head or tail region will become unrepresentable in u16.
208    pub(in crate::session) async fn alloc_tx_buffer(
209        self: &Arc<Self>,
210        num_bytes: usize,
211    ) -> Result<Buffer<Tx>> {
212        self.alloc_tx_buffers(num_bytes).await?.next().unwrap()
213    }
214
215    /// Waits for at least one TX buffer to be available and returns an iterator
216    /// of buffers with `num_bytes` as capacity.
217    ///
218    /// The returned iterator is guaranteed to yield at least one item (though
219    /// it might be an error if the requested size cannot meet the device
220    /// requirement).
221    ///
222    /// # Note
223    ///
224    /// Given a `Buffer<Tx>` is returned to the pool when it's dropped, the
225    /// returned iterator will seemingly yield infinite items if the yielded
226    /// `Buffer`s are dropped while iterating.
227    pub(in crate::session) async fn alloc_tx_buffers<'a>(
228        self: &'a Arc<Self>,
229        num_bytes: usize,
230    ) -> Result<impl Iterator<Item = Result<Buffer<Tx>>> + 'a> {
231        let BufferLayout { min_tx_data, min_tx_head, min_tx_tail, length: buffer_length } =
232            self.buffer_layout;
233        let tx_head = usize::from(min_tx_head);
234        let tx_tail = usize::from(min_tx_tail);
235        let total_bytes = num_bytes.max(min_tx_data) + tx_head + tx_tail;
236        let num_parts = (total_bytes + buffer_length - 1) / buffer_length;
237        let chain_length = ChainLength::try_from(num_parts)?;
238        let first = self.alloc_tx(chain_length).await;
239        let iter = std::iter::once(first)
240            .chain(std::iter::from_fn(move || {
241                let mut state = self.tx_alloc_state.lock();
242                state
243                    .free_list
244                    .try_alloc(chain_length, &self.descriptors)
245                    .map(|allocated| AllocGuard::new(allocated, self.clone()))
246            }))
247            // Fuse afterwards so we're guaranteeing we can't see a new entry
248            // after having yielded `None` once.
249            .fuse()
250            .map(move |mut alloc| {
251                alloc.init(num_bytes)?;
252                Ok(alloc.into())
253            });
254        Ok(iter)
255    }
256
257    /// Frees rx descriptors.
258    pub(in crate::session) fn free_rx(&self, descs: impl IntoIterator<Item = DescId<Rx>>) {
259        self.rx_pending.extend(descs.into_iter().map(|mut desc| {
260            self.descriptors.borrow_mut(&mut desc).initialize(
261                ChainLength::ZERO,
262                0,
263                self.buffer_layout.length.try_into().unwrap(),
264                0,
265            );
266            desc
267        }));
268    }
269
270    /// Frees tx descriptors.
271    ///
272    /// # Panics
273    ///
274    /// Panics if given an empty chain.
275    fn free_tx(self: &Arc<Self>, chain: Chained<DescId<Tx>>) {
276        // We store any pending request that need to be fulfilled in the stack
277        // here, to fulfill them only once we drop the lock, guaranteeing an
278        // AllocGuard can't be dropped while the lock is held.
279        let mut to_fulfill = ArrayVec::<
280            (TxAllocReq, AllocGuard<Tx>),
281            { netdev::MAX_DESCRIPTOR_CHAIN as usize },
282        >::new();
283
284        let mut state = self.tx_alloc_state.lock();
285
286        {
287            let mut descs = chain.into_iter();
288            // The following can't overflow because we can have at most u16::MAX
289            // descriptors: free_len + #(to_free) + #(descs in use) <= u16::MAX,
290            // Thus free_len + #(to_free) <= u16::MAX.
291            state.free_list.len += u16::try_from(descs.len()).unwrap();
292            let head = descs.next();
293            let old_head = std::mem::replace(&mut state.free_list.head, head);
294            let mut tail = descs.last();
295            let mut tail_ref = self.descriptors.borrow_mut(
296                tail.as_mut().unwrap_or_else(|| state.free_list.head.as_mut().unwrap()),
297            );
298            tail_ref.set_nxt(old_head);
299        }
300
301        // After putting the chain back into the free list, we try to fulfill
302        // any pending tx allocation requests.
303        while let Some(req) = state.requests.front() {
304            // Skip a request that we know is canceled.
305            //
306            // This is an optimization for long-ago dropped requests, since the
307            // receiver side can be dropped between here and fulfillment later.
308            if req.sender.is_canceled() {
309                let _cancelled: Option<TxAllocReq> = state.requests.pop_front();
310                continue;
311            }
312            let size = req.size;
313            match state.free_list.try_alloc(size, &self.descriptors) {
314                Some(descs) => {
315                    // The unwrap is safe because we know requests is not empty.
316                    let req = state.requests.pop_front().unwrap();
317                    to_fulfill.push((req, AllocGuard::new(descs, self.clone())));
318
319                    // If we're full temporarily release the lock to go again
320                    // later. Fulfilling a request must _always_ be done without
321                    // holding the lock.
322                    if to_fulfill.is_full() {
323                        drop(state);
324                        for (req, alloc) in to_fulfill.drain(..) {
325                            req.fulfill(alloc)
326                        }
327                        state = self.tx_alloc_state.lock();
328                    }
329                }
330                None => break,
331            }
332        }
333
334        // Make sure we're not holding the state lock when fulfilling requests.
335        drop(state);
336        // Fulfill any ready requests.
337        for (req, alloc) in to_fulfill {
338            req.fulfill(alloc)
339        }
340    }
341
342    /// Frees the completed tx descriptors chained by head to the pool.
343    ///
344    /// Call this function when the driver hands back a completed tx descriptor.
345    pub(in crate::session) fn tx_completed(self: &Arc<Self>, head: DescId<Tx>) -> Result<()> {
346        let chain = self.descriptors.chain(head).collect::<Result<Chained<_>>>()?;
347        Ok(self.free_tx(chain))
348    }
349
350    /// Creates a [`Buffer<Rx>`] corresponding to the completed rx descriptors.
351    ///
352    /// Whenever the driver hands back a completed rx descriptor, this function
353    /// can be used to create the buffer that is represented by those chained
354    /// descriptors.
355    pub(in crate::session) fn rx_completed(
356        self: &Arc<Self>,
357        head: DescId<Rx>,
358    ) -> Result<Buffer<Rx>> {
359        let descs = self.descriptors.chain(head).collect::<Result<Chained<_>>>()?;
360        let alloc = AllocGuard::new(descs, self.clone());
361        Ok(alloc.into())
362    }
363
364    fn get_slice<'a, K: AllocKind>(&self, desc: &'a DescId<K>) -> &'a [u8] {
365        let desc = self.descriptors.borrow(desc);
366        let offset = usize::try_from(desc.offset() + u64::from(desc.head_length()))
367            .expect("usize must hold u64");
368        let len = usize::try_from(desc.data_length()).expect("usize must hold u32");
369        // Safety: The descriptor is describing a buffer from this pool. It must
370        // be valid to create a slice into that region. We hold a immutable
371        // reference to the underlying descriptor, this means no one else should
372        // have mutable reference to this memory region.
373        unsafe {
374            let ptr = self.base.as_ptr().add(offset);
375            std::slice::from_raw_parts(ptr, len)
376        }
377    }
378
379    fn get_slice_mut<'a, K: AllocKind>(&self, desc: &'a mut DescId<K>) -> &'a mut [u8] {
380        let desc = self.descriptors.borrow_mut(desc);
381        let offset = usize::try_from(desc.offset() + u64::from(desc.head_length()))
382            .expect("usize must hold u64");
383        let len = usize::try_from(desc.data_length()).expect("usize must hold u32");
384        // Safety: The descriptor is describing a buffer from this pool. It must
385        // be valid to create a slice into that region. We hold a mutable
386        // reference to the underlying descriptor, this means we are currently
387        // the only one has access to this memory region.
388        unsafe {
389            let ptr = self.base.as_ptr().add(offset);
390            std::slice::from_raw_parts_mut(ptr, len)
391        }
392    }
393}
394
395impl Drop for Pool {
396    fn drop(&mut self) {
397        unsafe {
398            vmar_root_self()
399                .unmap(self.base.as_ptr() as usize, self.bytes)
400                .expect("failed to unmap VMO for Pool")
401        }
402    }
403}
404
405impl TxFreeList {
406    /// Tries to allocate tx descriptors.
407    ///
408    /// Returns [`None`] if there are not enough descriptors.
409    fn try_alloc(
410        &mut self,
411        num_parts: ChainLength,
412        descriptors: &Descriptors,
413    ) -> Option<Chained<DescId<Tx>>> {
414        if u16::from(num_parts.get()) > self.len {
415            return None;
416        }
417
418        let free_list = std::iter::from_fn(|| -> Option<DescId<Tx>> {
419            let new_head = self.head.as_ref().and_then(|head| {
420                let nxt = descriptors.borrow(head).nxt();
421                nxt.map(|id| unsafe {
422                    // Safety: This is the nxt field of head of the free list,
423                    // it must be a tx descriptor id.
424                    DescId::from_raw(id)
425                })
426            });
427            std::mem::replace(&mut self.head, new_head)
428        });
429        let allocated = free_list.take(num_parts.get().into()).collect::<Chained<_>>();
430        assert_eq!(allocated.len(), num_parts.into());
431        self.len -= u16::from(num_parts.get());
432        Some(allocated)
433    }
434}
435
/// The buffer that can be used by the [`Session`](crate::session::Session).
///
/// A buffer is a chain of at most [`netdev::MAX_DESCRIPTOR_CHAIN`] descriptors.
/// Their data regions live in the pool's shared data mapping. `K` selects
/// whether this is an rx or a tx buffer. The descriptors go back to the pool
/// when the buffer is dropped.
pub struct Buffer<K: AllocKind> {
    /// The descriptors allocation.
    alloc: AllocGuard<K>,
}
441
442impl<K: AllocKind> Buffer<K> {
443    /// Returns the length of data region of the buffer.
444    pub fn len(&self) -> usize {
445        self.parts().map(|s| s.len()).sum()
446    }
447
448    /// Returns an iterator over the data slices of the buffer parts.
449    fn parts(&self) -> impl Iterator<Item = &[u8]> + '_ {
450        self.alloc.descs.iter().map(|desc| self.alloc.pool.get_slice(desc))
451    }
452
453    /// Returns an iterator over the mutable valid data slices of the buffer parts.
454    fn parts_mut(&mut self) -> impl Iterator<Item = &mut [u8]> + '_ {
455        self.alloc.descs.iter_mut().map(|desc| self.alloc.pool.get_slice_mut(desc))
456    }
457
458    /// Leaks the underlying buffer descriptors to the driver.
459    pub(in crate::session) fn leak(mut self) -> DescId<K> {
460        let descs = std::mem::replace(&mut self.alloc.descs, Chained::empty());
461        descs.into_iter().next().unwrap()
462    }
463
464    /// Retrieves the frame type of the buffer.
465    pub fn frame_type(&self) -> Result<netdev::FrameType> {
466        self.alloc.descriptor().frame_type()
467    }
468
469    /// Retrieves the buffer's source port.
470    pub fn port(&self) -> Port {
471        self.alloc.descriptor().port()
472    }
473
474    /// Returns the buffer data as a slice.
475    pub fn as_slice(&self) -> Option<&[u8]> {
476        if self.alloc.len() != 1 {
477            return None;
478        }
479        self.parts().next()
480    }
481
482    /// Returns the buffer data as a mutable slice.
483    pub fn as_slice_mut(&mut self) -> Option<&mut [u8]> {
484        if self.alloc.len() != 1 {
485            return None;
486        }
487        self.parts_mut().next()
488    }
489
490    /// Returns a wrapper for read-only operations.
491    pub fn io(&self) -> BufferIORef<'_, K> {
492        let mut len = 0;
493        let parts: Chained<&[u8]> = self.parts().inspect(|s| len += s.len()).collect();
494        BufferIO { parts, pos: 0, len, _marker: std::marker::PhantomData }
495    }
496
497    /// Returns a wrapper for read-write operations.
498    pub fn io_mut(&mut self) -> BufferIOMut<'_, K> {
499        let mut len = 0;
500        let parts: Chained<&mut [u8]> = self.parts_mut().inspect(|s| len += s.len()).collect();
501        BufferIO { parts, pos: 0, len, _marker: std::marker::PhantomData }
502    }
503}
504
505impl Buffer<Tx> {
506    /// Sets the buffer's destination port.
507    pub fn set_port(&mut self, port: Port) {
508        self.alloc.descriptor_mut().set_port(port)
509    }
510
511    /// Sets the frame type of the buffer.
512    pub fn set_frame_type(&mut self, frame_type: netdev::FrameType) {
513        self.alloc.descriptor_mut().set_frame_type(frame_type)
514    }
515
516    /// Sets TxFlags of a Tx buffer.
517    pub fn set_tx_flags(&mut self, flags: netdev::TxFlags) {
518        self.alloc.descriptor_mut().set_tx_flags(flags)
519    }
520
521    /// Shrinks the buffer.
522    ///
523    /// This method shrinks the buffer length to the larger of
524    ///   - requested new length
525    ///   - device required minimum Tx data length
526    ///
527    /// It is an error to try to increase the buffer length.
528    pub fn shrink_to(&mut self, mut new_len: usize) -> Result<()> {
529        let current_len = self.len();
530
531        if new_len > current_len {
532            return Err(Error::TxLength);
533        }
534
535        let min_tx_data = usize::from(self.alloc.pool.buffer_layout.min_tx_data);
536        new_len = new_len.max(min_tx_data);
537
538        let layouts = self.alloc.calculate_descriptor_layouts(new_len)?;
539
540        for (desc_id, DescriptorLayout { data_length, tail_length, .. }) in
541            self.alloc.descs.iter_mut().zip(layouts)
542        {
543            let mut descriptor = self.alloc.pool.descriptors.borrow_mut(desc_id);
544            descriptor.set_data_length(data_length);
545            descriptor.set_tail_length(tail_length);
546        }
547        Ok(())
548    }
549}
550
551impl Buffer<Rx> {
552    /// Turns an rx buffer into a tx one.
553    pub async fn into_tx(self) -> Buffer<Tx> {
554        let Buffer { alloc } = self;
555        Buffer { alloc: alloc.into_tx().await }
556    }
557
558    /// Retrieves RxFlags of an Rx Buffer.
559    pub fn rx_flags(&self) -> Result<netdev::RxFlags> {
560        self.alloc.descriptor().rx_flags()
561    }
562}
563
564impl<K: AllocKind> Debug for Buffer<K> {
565    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566        let Self { alloc } = self;
567        f.debug_struct("Buffer").field("alloc", alloc).finish()
568    }
569}
570
/// A witness type that proves the buffer is backed by one part only
/// and thus can be converted into `&[u8]`.
///
/// Build it with [`SinglePartTxBuffer::new`]. That constructor rejects
/// multi-part buffers and buffers whose data region is shorter than the
/// requested length.
pub struct SinglePartTxBuffer(Buffer<Tx>);
574
575impl SinglePartTxBuffer {
576    /// Creates a new [`SinglePartTxBuffer`] from a [`Buffer<Tx>`] if it is
577    /// backed by one part only.
578    pub fn new(buffer: Buffer<Tx>, len: usize) -> Option<Self> {
579        if buffer.alloc.len() != 1 {
580            None
581        } else {
582            let cap = usize::try_from(buffer.alloc.descriptor().data_length())
583                .expect("u32 must fit in a usize");
584            if cap < len {
585                return None;
586            }
587            Some(Self(buffer))
588        }
589    }
590
591    /// Converts back to a Tx buffer.
592    pub fn into_inner(self) -> Buffer<Tx> {
593        let Self(buffer) = self;
594        buffer
595    }
596}
597
598impl AsRef<[u8]> for SinglePartTxBuffer {
599    fn as_ref(&self) -> &[u8] {
600        // Safety: `SinglePartTxBuffer` is guaranteed to have exactly one part
601        // (verified on creation), so the first descriptor is always initialized.
602        let desc = unsafe { self.0.alloc.descs.storage[0].assume_init_ref() };
603        self.0.alloc.pool.get_slice(desc)
604    }
605}
606
607impl AsMut<[u8]> for SinglePartTxBuffer {
608    fn as_mut(&mut self) -> &mut [u8] {
609        // Safety: `SinglePartTxBuffer` is guaranteed to have exactly one part
610        // (verified on creation), so the first descriptor is always initialized.
611        let desc = unsafe { self.0.alloc.descs.storage[0].assume_init_mut() };
612        self.0.alloc.pool.get_slice_mut(desc)
613    }
614}
615
616impl packet::FragmentedBuffer for SinglePartTxBuffer {
617    fn len(&self) -> usize {
618        let desc = self.0.alloc.descriptor();
619        usize::try_from(desc.data_length()).expect("u32 must fit in a usize")
620    }
621
622    fn with_bytes<'a, R, F>(&'a self, f: F) -> R
623    where
624        F: for<'b> FnOnce(packet::FragmentedBytes<'b, 'a>) -> R,
625    {
626        f(packet::FragmentedBytes::new(&mut [self.as_ref()][..]))
627    }
628}
629
/// A wrapper around [`Buffer`] for sequential I/O.
///
/// `T` must be a slice reference type, typically `&'a [u8]` for read-only
/// operations, or `&'a mut [u8]` for read-write operations.
pub struct BufferIO<T, K: AllocKind> {
    /// The data slices of the buffer's parts, in chain order.
    parts: Chained<T>,
    /// Cursor position in bytes from the start of the buffer. `io`/`io_mut`
    /// start it at 0. NOTE(review): presumably advanced by the
    /// `Read`/`Write`/`Seek` impls, which are not visible here.
    pos: usize,
    /// Total number of bytes across all `parts`.
    len: usize,
    /// Ties the wrapper to the buffer's allocation kind (rx or tx).
    _marker: std::marker::PhantomData<K>,
}

/// Read-only I/O view of a buffer, as returned by [`Buffer::io`].
pub type BufferIORef<'a, K> = BufferIO<&'a [u8], K>;
/// Read-write I/O view of a buffer, as returned by [`Buffer::io_mut`].
pub type BufferIOMut<'a, K> = BufferIO<&'a mut [u8], K>;
643
644impl<T> BufferIO<T, Tx>
645where
646    T: AsMut<[u8]>,
647{
648    /// Writes data from `src` into the TX buffer starting at the specified `offset`.
649    ///
650    /// This method is infallible. It returns the number of bytes successfully written.
651    ///
652    /// If the specified `offset` is greater than or equal to the total length of the
653    /// buffer, or if the buffer has no remaining capacity at the offset, `0` bytes
654    /// will be written.
655    ///
656    /// If `src` is larger than the remaining capacity of the buffer starting at
657    /// `offset`, a short write occurs: only the bytes that fit within the buffer
658    /// are written, and the returned value will be less than `src.len()`.
659    pub fn write_at(&mut self, mut offset: usize, src: &[u8]) -> usize {
660        let mut total = 0;
661
662        for slice in self.parts.iter_mut() {
663            let slice = slice.as_mut();
664            if offset < slice.len() {
665                let available = slice.len() - offset;
666                let to_copy = std::cmp::min(src.len() - total, available);
667                slice[offset..offset + to_copy].copy_from_slice(&src[total..total + to_copy]);
668                total += to_copy;
669                offset = 0;
670                if total == src.len() {
671                    break;
672                }
673            } else {
674                offset -= slice.len();
675            }
676        }
677        total
678    }
679}
680
681impl<T, K: AllocKind> BufferIO<T, K>
682where
683    T: AsRef<[u8]>,
684{
685    /// Reads data from the buffer starting at the specified `offset` into `dst`.
686    ///
687    /// This method is infallible. It returns the number of bytes successfully read.
688    ///
689    /// If the specified `offset` is greater than or equal to the total length of the
690    /// buffer, `0` bytes will be read.
691    ///
692    /// If the remaining data in the buffer starting at `offset` is less than the
693    /// size of `dst`, a short read occurs: only the available bytes are copied,
694    /// and the returned value will be less than `dst.len()`.
695    pub fn read_at(&self, mut offset: usize, dst: &mut [u8]) -> usize {
696        let mut total = 0;
697
698        for slice in self.parts.iter() {
699            let slice = slice.as_ref();
700            if offset < slice.len() {
701                let available = slice.len() - offset;
702                let to_copy = std::cmp::min(dst.len() - total, available);
703                dst[total..total + to_copy].copy_from_slice(&slice[offset..offset + to_copy]);
704                total += to_copy;
705                offset = 0;
706                if total == dst.len() {
707                    break;
708                }
709            } else {
710                offset -= slice.len();
711            }
712        }
713        total
714    }
715}
716
impl AllocGuard<Rx> {
    /// Turns an rx allocation into a tx one.
    ///
    /// To achieve this we allocate the same number of tx descriptors and swap
    /// them with ours. The tx descriptors end up in `self` and are returned to
    /// the rx pool when `self` is dropped. This compensates for our rx
    /// descriptors being converted into tx ones. Waits (via
    /// [`Pool::alloc_tx`]) until enough tx descriptors are free.
    async fn into_tx(mut self) -> AllocGuard<Tx> {
        let mut tx = self.pool.alloc_tx(self.descs.len).await;
        // [MaybeUninit<DescId<Tx>>; 4] and [MaybeUninit<DescId<Rx>>; 4] have the
        // same memory layout because DescId is repr(transparent). So it is safe
        // to transmute and swap the values between the storages. After the swap
        // the drop implementation of self will return the descriptors back to
        // rx pool.
        std::mem::swap(&mut self.descs.storage, unsafe {
            std::mem::transmute(&mut tx.descs.storage)
        });
        tx
    }
}
736
/// An inline container holding at most [`netdev::MAX_DESCRIPTOR_CHAIN`]
/// elements, such as the descriptors or data slices of one buffer chain.
///
/// Only the first `len` entries of `storage` are initialized. A `Chained` is
/// usually non-empty, but it can be empty: `Chained::empty` creates one with
/// `len` zero, and `into_iter` resets `len` to zero on the consumed value.
struct Chained<T> {
    storage: [MaybeUninit<T>; netdev::MAX_DESCRIPTOR_CHAIN as usize],
    len: ChainLength,
}
742
743impl<T> Deref for Chained<T> {
744    type Target = [T];
745
746    fn deref(&self) -> &Self::Target {
747        // Safety: `self.storage[..self.len]` is already initialized.
748        unsafe { std::mem::transmute(&self.storage[..self.len.into()]) }
749    }
750}
751
752impl<T> DerefMut for Chained<T> {
753    fn deref_mut(&mut self) -> &mut Self::Target {
754        // Safety: `self.storage[..self.len]` is already initialized.
755        unsafe { std::mem::transmute(&mut self.storage[..self.len.into()]) }
756    }
757}
758
759impl<T> Drop for Chained<T> {
760    fn drop(&mut self) {
761        // Safety: `self.deref_mut()` is a slice of all initialized elements.
762        unsafe {
763            std::ptr::drop_in_place(self.deref_mut());
764        }
765    }
766}
767
768impl<T: Debug> Debug for Chained<T> {
769    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
770        f.debug_list().entries(self.iter()).finish()
771    }
772}
773
774impl<T> Chained<T> {
775    #[allow(clippy::uninit_assumed_init)]
776    fn empty() -> Self {
777        // Create an uninitialized array of `MaybeUninit`. The `assume_init` is
778        // safe because the type we are claiming to have initialized here is a
779        // bunch of `MaybeUninit`s, which do not require initialization.
780        // TODO(https://fxbug.dev/42160423): use MaybeUninit::uninit_array once it
781        // is stablized.
782        // https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.uninit_array
783        Self { storage: unsafe { MaybeUninit::uninit().assume_init() }, len: ChainLength::ZERO }
784    }
785}
786
787impl<T> FromIterator<T> for Chained<T> {
788    /// # Panics
789    ///
790    /// if the iterator can yield more than MAX_DESCRIPTOR_CHAIN elements.
791    fn from_iter<I: IntoIterator<Item = T>>(elements: I) -> Self {
792        let mut result = Self::empty();
793        let mut len = 0u8;
794        for (idx, e) in elements.into_iter().enumerate() {
795            result.storage[idx] = MaybeUninit::new(e);
796            len += 1;
797        }
798        // `len` can not be larger than `MAX_DESCRIPTOR_CHAIN`, otherwise we can't
799        // get here due to the bound checks on `result.storage`.
800        result.len = ChainLength::try_from(len).unwrap();
801        result
802    }
803}
804
805impl<T> IntoIterator for Chained<T> {
806    type Item = T;
807    type IntoIter = ChainedIter<T>;
808
809    fn into_iter(mut self) -> Self::IntoIter {
810        let len = self.len;
811        self.len = ChainLength::ZERO;
812        // Safety: we have reset the length to zero, it is now safe to move out
813        // the values and set them to be uninitialized. The `assume_init` is
814        // safe because the type we are claiming to have initialized here is a
815        // bunch of `MaybeUninit`s, which do not require initialization.
816        // TODO(https://fxbug.dev/42160423): use MaybeUninit::uninit_array once it
817        // is stablized.
818        #[allow(clippy::uninit_assumed_init)]
819        let storage =
820            std::mem::replace(&mut self.storage, unsafe { MaybeUninit::uninit().assume_init() });
821        ChainedIter { storage, len, consumed: 0 }
822    }
823}
824
825struct ChainedIter<T> {
826    storage: [MaybeUninit<T>; netdev::MAX_DESCRIPTOR_CHAIN as usize],
827    len: ChainLength,
828    consumed: u8,
829}
830
831impl<T> Iterator for ChainedIter<T> {
832    type Item = T;
833
834    fn next(&mut self) -> Option<Self::Item> {
835        if self.consumed < self.len.get() {
836            // Safety: it is safe now to replace that slot with an uninitialized
837            // value because we will advance consumed by 1.
838            let value = unsafe {
839                std::mem::replace(
840                    &mut self.storage[usize::from(self.consumed)],
841                    MaybeUninit::uninit(),
842                )
843                .assume_init()
844            };
845            self.consumed += 1;
846            Some(value)
847        } else {
848            None
849        }
850    }
851
852    fn size_hint(&self) -> (usize, Option<usize>) {
853        let len = usize::from(self.len.get() - self.consumed);
854        (len, Some(len))
855    }
856}
857
858impl<T> ExactSizeIterator for ChainedIter<T> {}
859
860impl<T> Drop for ChainedIter<T> {
861    fn drop(&mut self) {
862        // Safety: `self.storage[self.consumed..self.len]` is initialized.
863        unsafe {
864            std::ptr::drop_in_place(std::mem::transmute::<_, &mut [T]>(
865                &mut self.storage[self.consumed.into()..self.len.into()],
866            ));
867        }
868    }
869}
870
871/// Guards the allocated descriptors; they will be freed when dropped.
872pub(in crate::session) struct AllocGuard<K: AllocKind> {
873    descs: Chained<DescId<K>>,
874    pool: Arc<Pool>,
875}
876
877impl<K: AllocKind> Debug for AllocGuard<K> {
878    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
879        let Self { descs, pool: _ } = self;
880        f.debug_struct("AllocGuard").field("descs", descs).finish()
881    }
882}
883
884impl<K: AllocKind> AllocGuard<K> {
885    fn new(descs: Chained<DescId<K>>, pool: Arc<Pool>) -> Self {
886        Self { descs, pool }
887    }
888
889    /// Iterates over references to the descriptors.
890    fn descriptors(&self) -> impl Iterator<Item = DescRef<'_, K>> + '_ {
891        self.descs.iter().map(move |desc| self.pool.descriptors.borrow(desc))
892    }
893
894    /// Iterates over mutable references to the descriptors.
895    fn descriptors_mut(&mut self) -> impl Iterator<Item = DescRefMut<'_, K>> + '_ {
896        let descriptors = &self.pool.descriptors;
897        self.descs.iter_mut().map(move |desc| descriptors.borrow_mut(desc))
898    }
899
900    /// Gets a reference to the head descriptor.
901    fn descriptor(&self) -> DescRef<'_, K> {
902        self.descriptors().next().expect("descriptors must not be empty")
903    }
904
905    /// Gets a mutable reference to the head descriptor.
906    fn descriptor_mut(&mut self) -> DescRefMut<'_, K> {
907        self.descriptors_mut().next().expect("descriptors must not be empty")
908    }
909}
910
911#[derive(Debug, Clone, Copy, PartialEq, Eq)]
912struct DescriptorLayout {
913    chain_length: ChainLength,
914    head_length: u16,
915    data_length: u32,
916    tail_length: u16,
917}
918
impl AllocGuard<Tx> {
    /// Calculates the layout for each descriptor in this allocation chain.
    ///
    /// The layouts are calculated to satisfy the requested `target_len`, while
    /// ensuring the session's `min_tx_head` and `min_tx_tail` requirements are
    /// met:
    /// - Only the head descriptor reserves `min_tx_head` bytes of head space.
    /// - Only the last descriptor reserves `min_tx_tail` bytes of tail space.
    /// - Capacity a descriptor does not need to cover `target_len` is added to
    ///   its tail, so the user cannot write more than was requested.
    ///
    /// The layouts are produced in chain order (head first), matching the
    /// order of `self.descs`.
    ///
    /// NOTE(review): if `target_len` exceeds the chain's total capacity, the
    /// excess is not rejected here; presumably callers size the chain
    /// accordingly. Confirm.
    ///
    /// Returns `Err(Error::TxLength)` if the requirements cannot be met (e.g. if the
    /// required tail padding overflows `u16`).
    fn calculate_descriptor_layouts(&self, target_len: usize) -> Result<Chained<DescriptorLayout>> {
        let len = self.len();
        let BufferLayout { min_tx_head, min_tx_tail, length: buffer_length, .. } =
            self.pool.buffer_layout;

        // Bytes of `target_len` not yet assigned to an earlier descriptor.
        let mut remaining_target = target_len;
        // `clen` runs from `len - 1` down to 0. It is the number of
        // descriptors after the current one, which is its `chain_length`.
        (0..len)
            .rev()
            .map(|clen| {
                let chain_length = ChainLength::try_from(clen).unwrap();
                // The head descriptor is the first one produced (`clen == len - 1`).
                let head_length = if clen + 1 == len { min_tx_head } else { 0 };
                // The last descriptor in the chain has no successors (`clen == 0`).
                let mut tail_length = if clen == 0 { min_tx_tail } else { 0 };

                // Every buffer is large enough to hold head_length and
                // tail_length. The check was done when the config for pool
                // was created, so the subtraction won't overflow.
                let available_bytes = u32::try_from(
                    buffer_length - usize::from(head_length) - usize::from(tail_length),
                )
                .unwrap();

                let data_length = match u32::try_from(remaining_target) {
                    Ok(target) => {
                        if target < available_bytes {
                            // The target bytes are less than what is available,
                            // we need to put the excess in the tail so that the
                            // user cannot write more than they requested (or padded).
                            let excess = available_bytes - target;
                            tail_length = u16::try_from(excess)
                                .ok_checked::<TryFromIntError>()
                                .and_then(|tail_adjustment| {
                                    tail_length.checked_add(tail_adjustment)
                                })
                                .ok_or(Error::TxLength)?;
                        }
                        target.min(available_bytes)
                    }
                    // The remaining target does not fit in a `u32`, so it exceeds
                    // `available_bytes`: the whole data region is used.
                    Err(TryFromIntError { .. }) => available_bytes,
                };

                let data_length_usize =
                    usize::try_from(data_length).expect("u32 must fit in a usize");
                remaining_target = remaining_target.saturating_sub(data_length_usize);

                Ok::<_, Error>(DescriptorLayout {
                    chain_length,
                    head_length,
                    data_length,
                    tail_length,
                })
            })
            // Collecting into `Result<Chained<_>>` stops at the first error.
            .collect()
    }

    /// Initializes descriptors of a tx allocation.
    ///
    /// We choose to enforce and satisfy the `min_tx_data` layout requirement
    /// (imposed by the device/driver) immediately during buffer allocation and
    /// initialization here.
    ///
    /// Consequently, the allocated buffer's capacity (`target_len`) may be
    /// larger than the `requested_bytes` if `requested_bytes` is smaller than
    /// `min_tx_data`.
    ///
    /// While this means we might spend CPU cycles zero-padding buffers that are
    /// subsequently dropped without being sent (a rare occurrence in typical
    /// usage), this guarantees that buffer is always suitable for sending. This
    /// also makes the transmit path (`Session::send`) infallible.
    ///
    /// Returns `Err(Error::TxLength)` if the descriptor layouts cannot be
    /// computed (see `calculate_descriptor_layouts`).
    fn init(&mut self, requested_bytes: usize) -> Result<()> {
        let min_tx_data = self.pool.buffer_layout.min_tx_data;
        let target_len = requested_bytes.max(usize::from(min_tx_data));
        let layouts = self.calculate_descriptor_layouts(target_len)?;

        // Bytes the caller asked for that are not yet accounted for by an
        // earlier descriptor; anything beyond this in a data region is padding.
        let mut remaining_requested = requested_bytes;

        for (desc_id, DescriptorLayout { chain_length, head_length, data_length, tail_length }) in
            self.descs.iter_mut().zip(layouts)
        {
            // Initialize the descriptor.
            {
                let mut descriptor = self.pool.descriptors.borrow_mut(desc_id);
                descriptor.initialize(chain_length, head_length, data_length, tail_length);
            }

            let data_length_usize = usize::try_from(data_length).expect("u32 must fit in a usize");
            let requested_in_part = std::cmp::min(remaining_requested, data_length_usize);
            let pad_in_part = data_length_usize - requested_in_part;

            // Zero-pad any excess capacity in this buffer part that was allocated
            // to satisfy the `min_tx_data` layout requirement but not requested by
            // the caller.
            //
            // We decided to pad the buffer on initialization because the lazy commit
            // model can only avoid padding for the following 2 cases:
            // 1) User only allocates but never sends.
            // 2) User writes past their requested size and meets the min_tx_data
            //    requirement.
            // Both should be uncommon, and in case 2) we can fix the client by
            // requesting a larger size to avoid padding.
            if pad_in_part > 0 {
                // NOTE(review): assumes `get_slice_mut` returns this descriptor's
                // data region (starting after the head). Confirm.
                let slice = self.pool.get_slice_mut(desc_id);
                slice[requested_in_part..requested_in_part + pad_in_part].fill(0);
            }

            remaining_requested -= requested_in_part;
        }
        Ok(())
    }
}
1036
1037impl<K: AllocKind> Drop for AllocGuard<K> {
1038    fn drop(&mut self) {
1039        if self.is_empty() {
1040            return;
1041        }
1042        K::free(private::Allocation(self));
1043    }
1044}
1045
1046impl<K: AllocKind> Deref for AllocGuard<K> {
1047    type Target = [DescId<K>];
1048
1049    fn deref(&self) -> &Self::Target {
1050        self.descs.deref()
1051    }
1052}
1053
1054impl<K: AllocKind> From<AllocGuard<K>> for Buffer<K> {
1055    fn from(alloc: AllocGuard<K>) -> Self {
1056        Self { alloc }
1057    }
1058}
1059
1060impl<T, K: AllocKind> Read for BufferIO<T, K>
1061where
1062    T: AsRef<[u8]>,
1063{
1064    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
1065        let read_len = self.read_at(self.pos, buf);
1066        self.pos += read_len;
1067        Ok(read_len)
1068    }
1069}
1070
1071impl<T> Write for BufferIO<T, Tx>
1072where
1073    T: AsMut<[u8]>,
1074{
1075    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
1076        let write_len = self.write_at(self.pos, buf);
1077        self.pos += write_len;
1078        Ok(write_len)
1079    }
1080
1081    fn flush(&mut self) -> std::io::Result<()> {
1082        Ok(())
1083    }
1084}
1085
1086impl<T, K: AllocKind> Seek for BufferIO<T, K> {
1087    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
1088        let pos = match pos {
1089            SeekFrom::Start(offset) => offset,
1090            SeekFrom::End(offset) => {
1091                let end = i64::try_from(self.len).unwrap();
1092                u64::try_from(end.wrapping_add(offset)).unwrap()
1093            }
1094            SeekFrom::Current(offset) => {
1095                let current = i64::try_from(self.pos).map_err(|TryFromIntError { .. }| {
1096                    std::io::Error::from(std::io::ErrorKind::InvalidInput)
1097                })?;
1098                u64::try_from(current.wrapping_add(offset)).unwrap()
1099            }
1100        };
1101        self.pos = usize::try_from(pos).map_err(|TryFromIntError { .. }| {
1102            std::io::Error::from(std::io::ErrorKind::InvalidInput)
1103        })?;
1104        Ok(pos)
1105    }
1106}
1107
1108/// A pending tx allocation request.
1109struct TxAllocReq {
1110    sender: Sender<AllocGuard<Tx>>,
1111    size: ChainLength,
1112}
1113
1114impl TxAllocReq {
1115    fn new(size: ChainLength) -> (Self, Receiver<AllocGuard<Tx>>) {
1116        let (sender, receiver) = channel();
1117        (TxAllocReq { sender, size }, receiver)
1118    }
1119
1120    /// Fulfills the pending request with an `AllocGuard`.
1121    ///
1122    /// If the request is already closed, the guard is simply dropped and
1123    /// returned to the queue.
1124    ///
1125    /// `fulfill` must *not* be called when the `guard`'s pool is holding the tx
1126    /// lock, since we may deadlock/panic upon the double tx lock acquisition.
1127    fn fulfill(self, guard: AllocGuard<Tx>) {
1128        let Self { sender, size: _ } = self;
1129        match sender.send(guard) {
1130            Ok(()) => (),
1131            Err(guard) => {
1132                // It's ok to just drop the guard here, it'll be returned to the
1133                // pool.
1134                drop(guard);
1135            }
1136        }
1137    }
1138}
1139
1140/// A module for sealed traits so that the user of this crate can not implement
1141/// [`AllocKind`] for anything than [`Rx`] and [`Tx`].
1142mod private {
1143    use super::{AllocKind, Rx, Tx};
1144    pub trait Sealed: 'static + Sized {}
1145    impl Sealed for Rx {}
1146    impl Sealed for Tx {}
1147
1148    // We can't leak a private type in a public trait, create an opaque private
1149    // new type for &mut super::AllocGuard so that we can mention it in the
1150    // AllocKind trait.
1151    pub struct Allocation<'a, K: AllocKind>(pub(super) &'a mut super::AllocGuard<K>);
1152}
1153
1154/// An allocation can have two kinds, this trait provides a way to project a
1155/// type ([`Rx`] or [`Tx`]) into a value.
1156pub trait AllocKind: private::Sealed {
1157    /// The reflected value of Self.
1158    const REFL: AllocKindRefl;
1159
1160    /// frees an allocation of the given kind.
1161    fn free(alloc: private::Allocation<'_, Self>);
1162}
1163
/// Type-level tag marking Tx allocations and their related types.
pub enum Tx {}
/// Type-level tag marking Rx allocations and their related types.
pub enum Rx {}
1168
1169/// The reflected value that allows inspection on an [`AllocKind`] type.
1170pub enum AllocKindRefl {
1171    Tx,
1172    Rx,
1173}
1174
1175impl AllocKindRefl {
1176    pub(in crate::session) fn as_str(&self) -> &'static str {
1177        match self {
1178            AllocKindRefl::Tx => "Tx",
1179            AllocKindRefl::Rx => "Rx",
1180        }
1181    }
1182}
1183
1184impl AllocKind for Tx {
1185    const REFL: AllocKindRefl = AllocKindRefl::Tx;
1186
1187    fn free(alloc: private::Allocation<'_, Self>) {
1188        let private::Allocation(AllocGuard { pool, descs }) = alloc;
1189        pool.free_tx(std::mem::replace(descs, Chained::empty()));
1190    }
1191}
1192
1193impl AllocKind for Rx {
1194    const REFL: AllocKindRefl = AllocKindRefl::Rx;
1195
1196    fn free(alloc: private::Allocation<'_, Self>) {
1197        let private::Allocation(AllocGuard { pool, descs }) = alloc;
1198        pool.free_rx(std::mem::replace(descs, Chained::empty()));
1199        pool.rx_leases.rx_complete();
1200    }
1201}
1202
1203/// An extracted struct containing state pertaining to watching rx leases.
1204pub(in crate::session) struct RxLeaseHandlingState {
1205    can_watch_rx_leases: AtomicBool,
1206    /// Keeps a rolling counter of received rx frames MINUS the target frame
1207    /// number of the current outstanding lease.
1208    ///
1209    /// When no leases are pending (via [`RxLeaseWatcher::wait_until`]),
1210    /// then this matches exactly the number of received frames.
1211    ///
1212    /// Otherwise, the lease is currently waiting for remaining `u64::MAX -
1213    /// rx_Frame_counter` frames. The logic depends on `AtomicU64` wrapping
1214    /// around as part of completing rx buffers.
1215    rx_frame_counter: AtomicU64,
1216    rx_lease_waker: AtomicWaker,
1217}
1218
1219impl RxLeaseHandlingState {
1220    fn new_with_flags(flags: netdev::SessionFlags) -> Self {
1221        Self::new_with_enabled(flags.contains(netdev::SessionFlags::RECEIVE_RX_POWER_LEASES))
1222    }
1223
1224    fn new_with_enabled(enabled: bool) -> Self {
1225        Self {
1226            can_watch_rx_leases: AtomicBool::new(enabled),
1227            rx_frame_counter: AtomicU64::new(0),
1228            rx_lease_waker: AtomicWaker::new(),
1229        }
1230    }
1231
1232    /// Increments the total receive frame counter and possibly wakes up a
1233    /// waiting lease yielder.
1234    fn rx_complete(&self) {
1235        let Self { can_watch_rx_leases: _, rx_frame_counter, rx_lease_waker } = self;
1236        let prev = rx_frame_counter.fetch_add(1, atomic::Ordering::SeqCst);
1237
1238        // See wait_until for details. We need to hit a waker whenever our add
1239        // wrapped the u64 back around to 0.
1240        if prev == u64::MAX {
1241            rx_lease_waker.wake();
1242        }
1243    }
1244}
1245
1246/// A trait allowing [`RxLeaseWatcher`] to be agnostic over how to get an
1247/// [`RxLeaseHandlingState`].
1248pub(in crate::session) trait RxLeaseHandlingStateContainer {
1249    fn lease_handling_state(&self) -> &RxLeaseHandlingState;
1250}
1251
1252impl<T: Borrow<RxLeaseHandlingState>> RxLeaseHandlingStateContainer for T {
1253    fn lease_handling_state(&self) -> &RxLeaseHandlingState {
1254        self.borrow()
1255    }
1256}
1257
1258impl RxLeaseHandlingStateContainer for Arc<Pool> {
1259    fn lease_handling_state(&self) -> &RxLeaseHandlingState {
1260        &self.rx_leases
1261    }
1262}
1263
1264/// A type safe-wrapper around a single lease watcher per `Pool`.
1265pub(in crate::session) struct RxLeaseWatcher<T> {
1266    state: T,
1267}
1268
impl<T: RxLeaseHandlingStateContainer> RxLeaseWatcher<T> {
    /// Creates a new lease watcher.
    ///
    /// Claims the single watcher slot by atomically clearing
    /// `can_watch_rx_leases`, so at most one watcher is ever created per
    /// lease handling state.
    ///
    /// # Panics
    ///
    /// Panics if an [`RxLeaseWatcher`] has already been created for the given
    /// pool or the pool was not configured for it.
    pub(in crate::session) fn new(state: T) -> Self {
        assert!(
            state.lease_handling_state().can_watch_rx_leases.swap(false, atomic::Ordering::SeqCst),
            "can't watch rx leases"
        );
        Self { state }
    }

    /// Called by sessions to wait until `hold_until_frame` is fulfilled to
    /// yield leases out.
    ///
    /// The returned future resolves once the `hold_until_frame`-th rx buffer
    /// has been released, i.e. once at least `hold_until_frame` rx frames have
    /// been completed in total. It does not block the calling thread.
    ///
    /// Note that this method takes `&mut self` because only one
    /// [`RxLeaseWatcher`] may be created per lease handling state, and exclusive
    /// access to it is required to watch lease completion.
    pub(in crate::session) async fn wait_until(&mut self, hold_until_frame: u64) {
        // A note about wrap-arounds.
        //
        // We're assuming the frame counter will never wrap around for
        // correctness here. This should be fine, even assuming a packet
        // rate of 1 million pps it'd take almost 600k years for this counter
        // to wrap around:
        // - 2^64 / 1e6 / 60 / 60 / 24 / 365 ~ 584e3.

        let RxLeaseHandlingState { can_watch_rx_leases: _, rx_frame_counter, rx_lease_waker } =
            self.state.lease_handling_state();

        // Subtract the target up front. From here on the counter reads as
        // `completed - hold_until_frame` (wrapping). `prev` is the number of
        // frames completed before this call.
        let prev = rx_frame_counter.fetch_sub(hold_until_frame, atomic::Ordering::SeqCst);
        // After having subtracted the waiting value we *must always restore the
        // value* on return, even if the future is not polled to completion.
        let _guard = scopeguard::guard((), |()| {
            let _: u64 = rx_frame_counter.fetch_add(hold_until_frame, atomic::Ordering::SeqCst);
        });

        // Lease is ready to be fulfilled.
        if prev >= hold_until_frame {
            return;
        }
        // Threshold is a wrapped around subtraction. So now we must wait
        // until the read value from the atomic is LESS THAN the threshold.
        // Each completed frame moves the counter from `threshold` towards
        // `u64::MAX`. The frame that satisfies the wait wraps it to 0, which is
        // below `threshold`, and `rx_complete` wakes `rx_lease_waker` then.
        let threshold = prev.wrapping_sub(hold_until_frame);
        futures::future::poll_fn(|cx| {
            let v = rx_frame_counter.load(atomic::Ordering::SeqCst);
            if v < threshold {
                return Poll::Ready(());
            }
            // Register, then re-check. A wake-up that lands between the first
            // load and `register` is still observed by the second load.
            rx_lease_waker.register(cx.waker());
            let v = rx_frame_counter.load(atomic::Ordering::SeqCst);
            if v < threshold {
                return Poll::Ready(());
            }
            Poll::Pending
        })
        .await;
    }
}
1333
1334#[cfg(test)]
1335mod tests {
1336
1337    use super::*;
1338
1339    use assert_matches::assert_matches;
1340    use fuchsia_async as fasync;
1341    use futures::future::FutureExt;
1342    use test_case::test_case;
1343
1344    use std::collections::HashSet;
1345    use std::num::{NonZeroU16, NonZeroU64, NonZeroUsize};
1346    use std::pin::pin;
1347    use std::task::{Poll, Waker};
1348
1349    const DEFAULT_MIN_TX_BUFFER_HEAD: u16 = 4;
1350    const DEFAULT_MIN_TX_BUFFER_TAIL: u16 = 8;
1351    // Safety: These are safe because none of the values are zero.
1352    const DEFAULT_BUFFER_LENGTH: NonZeroUsize = NonZeroUsize::new(64).unwrap();
1353    const DEFAULT_TX_BUFFERS: NonZeroU16 = NonZeroU16::new(8).unwrap();
1354    const DEFAULT_RX_BUFFERS: NonZeroU16 = NonZeroU16::new(8).unwrap();
1355    const MAX_BUFFER_BYTES: usize = DEFAULT_BUFFER_LENGTH.get()
1356        * netdev::MAX_DESCRIPTOR_CHAIN as usize
1357        - DEFAULT_MIN_TX_BUFFER_HEAD as usize
1358        - DEFAULT_MIN_TX_BUFFER_TAIL as usize;
1359
1360    const SENTINEL_BYTE: u8 = 0xab;
1361    const WRITE_BYTE: u8 = 1;
1362    const PAD_BYTE: u8 = 0;
1363
1364    const DEFAULT_CONFIG: Config = Config {
1365        buffer_stride: NonZeroU64::new(DEFAULT_BUFFER_LENGTH.get() as u64).unwrap(),
1366        num_rx_buffers: DEFAULT_RX_BUFFERS,
1367        num_tx_buffers: DEFAULT_TX_BUFFERS,
1368        options: netdev::SessionFlags::empty(),
1369        buffer_layout: BufferLayout {
1370            length: DEFAULT_BUFFER_LENGTH.get(),
1371            min_tx_head: DEFAULT_MIN_TX_BUFFER_HEAD,
1372            min_tx_tail: DEFAULT_MIN_TX_BUFFER_TAIL,
1373            min_tx_data: 0,
1374        },
1375    };
1376
1377    impl Pool {
1378        fn new_test_default() -> Arc<Self> {
1379            let (pool, _descriptors, _data) =
1380                Pool::new(DEFAULT_CONFIG).expect("failed to create default pool");
1381            pool
1382        }
1383
1384        async fn alloc_tx_checked(self: &Arc<Self>, n: u8) -> AllocGuard<Tx> {
1385            self.alloc_tx(ChainLength::try_from(n).expect("failed to convert to chain length"))
1386                .await
1387        }
1388
1389        fn alloc_tx_now_or_never(self: &Arc<Self>, n: u8) -> Option<AllocGuard<Tx>> {
1390            self.alloc_tx_checked(n).now_or_never()
1391        }
1392
1393        fn alloc_tx_all(self: &Arc<Self>, n: u8) -> Vec<AllocGuard<Tx>> {
1394            std::iter::from_fn(|| self.alloc_tx_now_or_never(n)).collect()
1395        }
1396
1397        fn alloc_tx_buffer_now_or_never(self: &Arc<Self>, num_bytes: usize) -> Option<Buffer<Tx>> {
1398            self.alloc_tx_buffer(num_bytes)
1399                .now_or_never()
1400                .transpose()
1401                .expect("invalid arguments for alloc_tx_buffer")
1402        }
1403
1404        fn set_min_tx_buffer_length(self: &mut Arc<Self>, length: usize) {
1405            Arc::get_mut(self).unwrap().buffer_layout.min_tx_data = length;
1406        }
1407
1408        fn fill_sentinel_bytes(&mut self) {
1409            // Safety: We have mut reference to Pool, so we get to modify the
1410            // VMO pointed by self.base.
1411            unsafe { std::ptr::write_bytes(self.base.as_ptr(), SENTINEL_BYTE, self.bytes) };
1412        }
1413    }
1414
1415    impl Buffer<Tx> {
1416        // Write a byte at offset, the result buffer should be pad_size long, with
1417        // 0..offset being the SENTINEL_BYTE, offset being the WRITE_BYTE and the
1418        // rest being PAD_BYTE.
1419        fn check_write_and_pad(&mut self, offset: usize, pad_size: usize) {
1420            {
1421                let mut io = self.io_mut();
1422                assert_eq!(io.write_at(offset, &[WRITE_BYTE][..]), 1);
1423            }
1424            assert_eq!(self.len(), pad_size);
1425            // An arbitrary value that is not SENTINAL/WRITE/PAD_BYTE so that
1426            // we can make sure the write really happened.
1427            const INIT_BYTE: u8 = 42;
1428            let mut read_buf = vec![INIT_BYTE; pad_size];
1429            assert_eq!(self.io().read_at(0, &mut read_buf[..]), read_buf.len());
1430            for (idx, byte) in read_buf.iter().enumerate() {
1431                if idx < offset {
1432                    assert_eq!(*byte, SENTINEL_BYTE);
1433                } else if idx == offset {
1434                    assert_eq!(*byte, WRITE_BYTE);
1435                } else {
1436                    assert_eq!(*byte, PAD_BYTE);
1437                }
1438            }
1439        }
1440    }
1441
1442    impl<K, I, T> PartialEq<T> for Chained<DescId<K>>
1443    where
1444        K: AllocKind,
1445        I: ExactSizeIterator<Item = u16>,
1446        T: Copy + IntoIterator<IntoIter = I>,
1447    {
1448        fn eq(&self, other: &T) -> bool {
1449            let iter = other.into_iter();
1450            if usize::from(self.len) != iter.len() {
1451                return false;
1452            }
1453            self.iter().zip(iter).all(|(l, r)| l.get() == r)
1454        }
1455    }
1456
1457    impl Debug for TxAllocReq {
1458        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1459            let TxAllocReq { sender: _, size } = self;
1460            f.debug_struct("TxAllocReq").field("size", &size).finish_non_exhaustive()
1461        }
1462    }
1463
1464    #[test]
1465    fn alloc_tx_distinct() {
1466        let pool = Pool::new_test_default();
1467        let allocated = pool.alloc_tx_all(1);
1468        assert_eq!(allocated.len(), DEFAULT_TX_BUFFERS.get().into());
1469        let distinct = allocated
1470            .iter()
1471            .map(|alloc| {
1472                assert_eq!(alloc.descs.len(), 1);
1473                alloc.descs[0].get()
1474            })
1475            .collect::<HashSet<u16>>();
1476        assert_eq!(allocated.len(), distinct.len());
1477    }
1478
1479    #[test]
1480    fn alloc_tx_free_len() {
1481        let pool = Pool::new_test_default();
1482        {
1483            let allocated = pool.alloc_tx_all(2);
1484            assert_eq!(
1485                allocated.iter().fold(0, |acc, a| { acc + a.descs.len() }),
1486                DEFAULT_TX_BUFFERS.get().into()
1487            );
1488            assert_eq!(pool.tx_alloc_state.lock().free_list.len, 0);
1489        }
1490        assert_eq!(pool.tx_alloc_state.lock().free_list.len, DEFAULT_TX_BUFFERS.get());
1491    }
1492
1493    #[test]
1494    fn alloc_tx_chain() {
1495        let pool = Pool::new_test_default();
1496        let allocated = pool.alloc_tx_all(3);
1497        assert_eq!(allocated.len(), usize::from(DEFAULT_TX_BUFFERS.get()) / 3);
1498        assert_matches!(pool.alloc_tx_now_or_never(3), None);
1499        assert_matches!(pool.alloc_tx_now_or_never(2), Some(a) if a.descs.len() == 2);
1500    }
1501
1502    #[test]
1503    fn alloc_tx_many() {
1504        let pool = Pool::new_test_default();
1505        let data_len = u32::try_from(DEFAULT_BUFFER_LENGTH.get()).unwrap()
1506            - u32::from(DEFAULT_MIN_TX_BUFFER_HEAD)
1507            - u32::from(DEFAULT_MIN_TX_BUFFER_TAIL);
1508        let data_len = usize::try_from(data_len).unwrap();
1509        let mut buffers = pool
1510            .alloc_tx_buffers(data_len)
1511            .now_or_never()
1512            .expect("failed to alloc")
1513            .unwrap()
1514            // Collect into a vec so we keep the buffers alive, otherwise they
1515            // are immediately returned to the pool.
1516            .collect::<Result<Vec<_>>>()
1517            .expect("buffer error");
1518        assert_eq!(buffers.len(), DEFAULT_TX_BUFFERS.get().into());
1519
1520        // We have all the buffers, which means allocating more should not
1521        // resolve.
1522        assert!(pool.alloc_tx_buffers(data_len).now_or_never().is_none());
1523
1524        // If we release a single buffer we should be able to retrieve it again.
1525        assert_matches!(buffers.pop(), Some(_));
1526        let mut more_buffers =
1527            pool.alloc_tx_buffers(data_len).now_or_never().expect("failed to alloc").unwrap();
1528        let buffer = assert_matches!(more_buffers.next(), Some(Ok(b)) => b);
1529        assert_matches!(more_buffers.next(), None);
1530        // The iterator is fused, so None is yielded even after dropping the
1531        // buffer.
1532        drop(buffer);
1533        assert_matches!(more_buffers.next(), None);
1534    }
1535
1536    #[test]
1537    fn alloc_tx_after_free() {
1538        let pool = Pool::new_test_default();
1539        let mut allocated = pool.alloc_tx_all(1);
1540        assert_matches!(pool.alloc_tx_now_or_never(2), None);
1541        {
1542            let _drained = allocated.drain(..2);
1543        }
1544        assert_matches!(pool.alloc_tx_now_or_never(2), Some(a) if a.descs.len() == 2);
1545    }
1546
1547    #[test]
1548    fn blocking_alloc_tx() {
1549        let mut executor = fasync::TestExecutor::new();
1550        let pool = Pool::new_test_default();
1551        let mut allocated = pool.alloc_tx_all(1);
1552        let alloc_fut = pool.alloc_tx_checked(1);
1553        let mut alloc_fut = pin!(alloc_fut);
1554        // The allocation should block.
1555        assert_matches!(executor.run_until_stalled(&mut alloc_fut), Poll::Pending);
1556        // And the allocation request should be queued.
1557        assert!(!pool.tx_alloc_state.lock().requests.is_empty());
1558        let freed = allocated
1559            .pop()
1560            .expect("no fulfulled allocations")
1561            .iter()
1562            .map(|x| x.get())
1563            .collect::<Chained<_>>();
1564        let same_as_freed =
1565            |descs: &Chained<DescId<Tx>>| descs.iter().map(|x| x.get()).eq(freed.iter().copied());
1566        // Now the task should be able to continue.
1567        assert_matches!(
1568            &executor.run_until_stalled(&mut alloc_fut),
1569            Poll::Ready(AllocGuard{ descs, pool: _ }) if same_as_freed(descs)
1570        );
1571        // And the queued request should now be removed.
1572        assert!(pool.tx_alloc_state.lock().requests.is_empty());
1573    }
1574
    /// Checks that dropping a blocked tx allocation future before any
    /// descriptor is freed leaves no live request behind. A descriptor freed
    /// afterwards goes back to the free list, and the request queue ends up
    /// empty.
    #[test]
    fn blocking_alloc_tx_cancel_before_free() {
        let mut executor = fasync::TestExecutor::new();
        let pool = Pool::new_test_default();
        // Exhaust the pool so the allocation below has to wait.
        let mut allocated = pool.alloc_tx_all(1);
        {
            let alloc_fut = pool.alloc_tx_checked(1);
            let mut alloc_fut = pin!(alloc_fut);
            assert_matches!(executor.run_until_stalled(&mut alloc_fut), Poll::Pending);
            // Two size-1 requests are queued: the leftover (already cancelled)
            // one from `alloc_tx_all` and the one made by `alloc_fut`.
            assert_matches!(
                pool.tx_alloc_state.lock().requests.as_slices(),
                (&[ref req1, ref req2], &[]) if req1.size.get() == 1 && req2.size.get() == 1
            );
        }
        // `alloc_fut` has been dropped at this point, before anything is freed.
        // Pop the most recently allocated guard; it is dropped at the end of
        // this statement, returning its single descriptor to the pool.
        assert_matches!(
            allocated.pop(),
            Some(AllocGuard { ref descs, pool: ref p })
                if descs == &[DEFAULT_TX_BUFFERS.get() - 1] && Arc::ptr_eq(p, &pool)
        );
        // With no live requester, the freed descriptor lands on the free list
        // and the queue holds no requests.
        let state = pool.tx_alloc_state.lock();
        assert_eq!(state.free_list.len, 1);
        assert!(state.requests.is_empty());
    }
1598
    /// Checks the case where a descriptor is freed while a tx allocation
    /// future is waiting, and the future is then dropped without being polled
    /// again. The freed descriptor must end up back on the free list, and no
    /// request may remain queued.
    #[test]
    fn blocking_alloc_tx_cancel_after_free() {
        let mut executor = fasync::TestExecutor::new();
        let pool = Pool::new_test_default();
        // Exhaust the pool so the allocation below has to wait.
        let mut allocated = pool.alloc_tx_all(1);
        {
            let alloc_fut = pool.alloc_tx_checked(1);
            let mut alloc_fut = pin!(alloc_fut);
            assert_matches!(executor.run_until_stalled(&mut alloc_fut), Poll::Pending);
            // Two size-1 requests are queued: the leftover (already cancelled)
            // one from `alloc_tx_all` and the one made by `alloc_fut`.
            assert_matches!(
                pool.tx_alloc_state.lock().requests.as_slices(),
                (&[ref req1, ref req2], &[]) if req1.size.get() == 1 && req2.size.get() == 1
            );
            // Free the most recently allocated descriptor while `alloc_fut` is
            // still alive. The popped guard is dropped at the end of this
            // statement, and `alloc_fut` is not polled afterwards.
            assert_matches!(
                allocated.pop(),
                Some(AllocGuard { ref descs, pool: ref p })
                    if descs == &[DEFAULT_TX_BUFFERS.get() - 1] && Arc::ptr_eq(p, &pool)
            );
        }
        // `alloc_fut` has now been dropped without observing the freed
        // descriptor. Whichever path the descriptor took, it must be back on
        // the free list rather than lost.
        let state = pool.tx_alloc_state.lock();
        assert_eq!(state.free_list.len, 1);
        assert!(state.requests.is_empty());
    }
1622
    /// Checks that several blocked tx allocations are fulfilled in the order
    /// they were queued. Requests for 3, 2 and 1 descriptors are queued in
    /// that order. Descriptors are then freed one at a time, and each request
    /// must complete exactly when enough descriptors have been freed for it,
    /// while the later (smaller) requests stay pending.
    #[test]
    fn multiple_blocking_alloc_tx_fulfill_order() {
        const TASKS_TOTAL: usize = 3;
        let mut executor = fasync::TestExecutor::new();
        let pool = Pool::new_test_default();
        // Exhaust the pool so every request below has to wait.
        let mut allocated = pool.alloc_tx_all(1);
        // One boxed future per request size, largest first: (3, fut), (2, fut),
        // (1, fut).
        let mut alloc_futs = (1..=TASKS_TOTAL)
            .rev()
            .map(|x| {
                let pool = pool.clone();
                (x, Box::pin(async move { pool.alloc_tx_checked(x.try_into().unwrap()).await }))
            })
            .collect::<Vec<_>>();

        // Poll each future once so its request is queued, in list order.
        for (idx, (req_size, task)) in alloc_futs.iter_mut().enumerate() {
            assert_matches!(executor.run_until_stalled(task), Poll::Pending);
            // assert that the tasks are sorted decreasing on the requested size.
            assert_eq!(idx + *req_size, TASKS_TOTAL);
        }
        {
            let state = pool.tx_alloc_state.lock();
            // The first pending request was introduced by `alloc_tx_all`.
            assert_eq!(state.requests.len(), TASKS_TOTAL + 1);
            let mut requests = state.requests.iter();
            // It should already be cancelled because the requesting future is
            // already dropped.
            assert!(requests.next().unwrap().sender.is_canceled());
            // The rest of the requests must not be cancelled.
            assert!(requests.all(|req| !req.sender.is_canceled()))
        }

        // Completed allocations are kept here so their descriptors are not
        // returned to the pool during the test.
        let mut to_free = Vec::new();
        // Total number of single-descriptor guards popped from `allocated`.
        // Guards are popped from the back, so the n-th one freed holds
        // descriptor `DEFAULT_TX_BUFFERS - n`.
        let mut freed = 0;
        for free_size in (1..=TASKS_TOTAL).rev() {
            // The head of the queue is the request for `free_size` descriptors.
            let (_req_size, mut task) = alloc_futs.remove(0);
            // Free all but one of the descriptors this request needs; it must
            // stay pending after each partial free.
            for _ in 1..free_size {
                freed += 1;
                assert_matches!(
                    allocated.pop(),
                    Some(AllocGuard { ref descs, pool: ref p })
                        if descs == &[DEFAULT_TX_BUFFERS.get() - freed] && Arc::ptr_eq(p, &pool)
                );
                assert_matches!(executor.run_until_stalled(&mut task), Poll::Pending);
            }
            // Freeing the final descriptor it needs lets the request complete.
            freed += 1;
            assert_matches!(
                allocated.pop(),
                Some(AllocGuard { ref descs, pool: ref p })
                    if descs == &[DEFAULT_TX_BUFFERS.get() - freed] && Arc::ptr_eq(p, &pool)
            );
            match executor.run_until_stalled(&mut task) {
                Poll::Ready(alloc) => {
                    assert_eq!(alloc.len(), free_size);
                    // Don't return the allocation to the pool now.
                    to_free.push(alloc);
                }
                Poll::Pending => panic!("The request should be fulfilled"),
            }
            // The rest of requests can not be fulfilled.
            for (_req_size, task) in alloc_futs.iter_mut() {
                assert_matches!(executor.run_until_stalled(task), Poll::Pending);
            }
        }
        // Every request, including the cancelled one left by `alloc_tx_all`,
        // has been removed from the queue.
        assert!(pool.tx_alloc_state.lock().requests.is_empty());
    }
1688
1689    #[test]
1690    fn singleton_tx_layout() {
1691        let pool = Pool::new_test_default();
1692        let buffers = std::iter::from_fn(|| {
1693            let data_len = u32::try_from(DEFAULT_BUFFER_LENGTH.get()).unwrap()
1694                - u32::from(DEFAULT_MIN_TX_BUFFER_HEAD)
1695                - u32::from(DEFAULT_MIN_TX_BUFFER_TAIL);
1696            pool.alloc_tx_buffer_now_or_never(usize::try_from(data_len).unwrap()).map(|buffer| {
1697                assert_eq!(buffer.alloc.descriptors().count(), 1);
1698                let offset = u64::try_from(DEFAULT_BUFFER_LENGTH.get()).unwrap()
1699                    * u64::from(buffer.alloc[0].get());
1700                {
1701                    let descriptor = buffer.alloc.descriptor();
1702                    assert_matches!(descriptor.chain_length(), Ok(ChainLength::ZERO));
1703                    assert_eq!(descriptor.head_length(), DEFAULT_MIN_TX_BUFFER_HEAD);
1704                    assert_eq!(descriptor.tail_length(), DEFAULT_MIN_TX_BUFFER_TAIL);
1705                    assert_eq!(descriptor.data_length(), data_len);
1706                    assert_eq!(descriptor.offset(), offset);
1707                }
1708
1709                {
1710                    let mut slices = buffer.parts();
1711                    let slice = slices.next().expect("should have one slice");
1712                    assert_matches!(slices.next(), None);
1713                    assert_eq!(slice.len(), usize::try_from(data_len).unwrap());
1714                    assert_eq!(
1715                        slice.as_ptr(),
1716                        pool.base.as_ptr().wrapping_add(
1717                            usize::try_from(offset).unwrap()
1718                                + usize::from(DEFAULT_MIN_TX_BUFFER_HEAD),
1719                        )
1720                    );
1721                }
1722                buffer
1723            })
1724        })
1725        .collect::<Vec<_>>();
1726        assert_eq!(buffers.len(), usize::from(DEFAULT_TX_BUFFERS.get()));
1727    }
1728
1729    #[test]
1730    fn chained_tx_layout() {
1731        let pool = Pool::new_test_default();
1732        let alloc_len = 4 * DEFAULT_BUFFER_LENGTH.get()
1733            - usize::from(DEFAULT_MIN_TX_BUFFER_HEAD)
1734            - usize::from(DEFAULT_MIN_TX_BUFFER_TAIL);
1735        let buffers = std::iter::from_fn(|| {
1736            pool.alloc_tx_buffer_now_or_never(alloc_len).map(|buffer| {
1737                assert_eq!(buffer.parts().count(), 4);
1738                for (idx, (descriptor, slice)) in
1739                    buffer.alloc.descriptors().zip(buffer.parts()).enumerate()
1740                {
1741                    let chain_length = ChainLength::try_from(buffer.alloc.len() - idx - 1).unwrap();
1742                    let head_length = if idx == 0 { DEFAULT_MIN_TX_BUFFER_HEAD } else { 0 };
1743                    let tail_length = if chain_length == ChainLength::ZERO {
1744                        DEFAULT_MIN_TX_BUFFER_TAIL
1745                    } else {
1746                        0
1747                    };
1748                    let data_len = u32::try_from(DEFAULT_BUFFER_LENGTH.get()).unwrap()
1749                        - u32::from(head_length)
1750                        - u32::from(tail_length);
1751                    let offset = u64::try_from(DEFAULT_BUFFER_LENGTH.get()).unwrap()
1752                        * u64::from(buffer.alloc[idx].get());
1753                    assert_eq!(descriptor.chain_length().unwrap(), chain_length);
1754                    assert_eq!(descriptor.head_length(), head_length);
1755                    assert_eq!(descriptor.tail_length(), tail_length);
1756                    assert_eq!(descriptor.offset(), offset);
1757                    assert_eq!(descriptor.data_length(), data_len);
1758                    if chain_length != ChainLength::ZERO {
1759                        assert_eq!(descriptor.nxt(), Some(buffer.alloc[idx + 1].get()));
1760                    }
1761
1762                    assert_eq!(slice.len(), usize::try_from(data_len).unwrap());
1763                    assert_eq!(
1764                        slice.as_ptr(),
1765                        pool.base.as_ptr().wrapping_add(
1766                            usize::try_from(offset).unwrap() + usize::from(head_length),
1767                        )
1768                    );
1769                }
1770                buffer
1771            })
1772        })
1773        .collect::<Vec<_>>();
1774        assert_eq!(buffers.len(), usize::from(DEFAULT_TX_BUFFERS.get()) / 4);
1775    }
1776
1777    #[test]
1778    fn rx_distinct() {
1779        let pool = Pool::new_test_default();
1780        let mut guard = pool.rx_pending.inner.lock();
1781        let (descs, _): &mut (Vec<_>, Option<Waker>) = &mut *guard;
1782        assert_eq!(descs.len(), usize::from(DEFAULT_RX_BUFFERS.get()));
1783        let distinct = descs.iter().map(|desc| desc.get()).collect::<HashSet<u16>>();
1784        assert_eq!(descs.len(), distinct.len());
1785    }
1786
1787    #[test]
1788    fn alloc_rx_layout() {
1789        let pool = Pool::new_test_default();
1790        let mut guard = pool.rx_pending.inner.lock();
1791        let (descs, _): &mut (Vec<_>, Option<Waker>) = &mut *guard;
1792        assert_eq!(descs.len(), usize::from(DEFAULT_RX_BUFFERS.get()));
1793        for desc in descs.iter() {
1794            let descriptor = pool.descriptors.borrow(desc);
1795            let offset =
1796                u64::try_from(DEFAULT_BUFFER_LENGTH.get()).unwrap() * u64::from(desc.get());
1797            assert_matches!(descriptor.chain_length(), Ok(ChainLength::ZERO));
1798            assert_eq!(descriptor.head_length(), 0);
1799            assert_eq!(descriptor.tail_length(), 0);
1800            assert_eq!(descriptor.offset(), offset);
1801            assert_eq!(
1802                descriptor.data_length(),
1803                u32::try_from(DEFAULT_BUFFER_LENGTH.get()).unwrap()
1804            );
1805        }
1806    }
1807
1808    #[test]
1809    fn buffer_read_at_write_at() {
1810        let pool = Pool::new_test_default();
1811        let alloc_bytes = DEFAULT_BUFFER_LENGTH.get();
1812        let mut buffer =
1813            pool.alloc_tx_buffer_now_or_never(alloc_bytes).expect("failed to allocate");
1814        // Because we have to accommodate the space for head and tail, there
1815        // would be 2 parts instead of 1.
1816        assert_eq!(buffer.parts().count(), 2);
1817        assert_eq!(buffer.len(), alloc_bytes);
1818        let write_buf = (0..u8::try_from(DEFAULT_BUFFER_LENGTH.get()).unwrap()).collect::<Vec<_>>();
1819        assert_eq!(buffer.io_mut().write_at(0, &write_buf[..]), write_buf.len());
1820        let mut read_buf = [0xff; DEFAULT_BUFFER_LENGTH.get()];
1821        assert_eq!(buffer.io().read_at(0, &mut read_buf[..]), read_buf.len());
1822        for (idx, byte) in read_buf.iter().enumerate() {
1823            assert_eq!(*byte, write_buf[idx]);
1824        }
1825    }
1826
1827    #[test]
1828    fn buffer_write_at_short() {
1829        let pool = Pool::new_test_default();
1830        let alloc_bytes = DEFAULT_BUFFER_LENGTH.get();
1831        let mut buffer =
1832            pool.alloc_tx_buffer_now_or_never(alloc_bytes).expect("failed to allocate");
1833        assert_eq!(buffer.parts().count(), 2);
1834        assert_eq!(buffer.len(), alloc_bytes);
1835
1836        let write_buf = vec![WRITE_BYTE; alloc_bytes + 10];
1837
1838        // Test short write (writing more than buffer capacity)
1839        assert_eq!(buffer.io_mut().write_at(0, &write_buf[..]), alloc_bytes);
1840
1841        // Verify short write
1842        let mut read_buf = vec![0; alloc_bytes];
1843        assert_eq!(buffer.io().read_at(0, &mut read_buf[..]), alloc_bytes);
1844        for byte in read_buf.iter() {
1845            assert_eq!(*byte, WRITE_BYTE);
1846        }
1847
1848        // Test write with offset past end
1849        assert_eq!(buffer.io_mut().write_at(alloc_bytes + 1, &write_buf[..]), 0);
1850
1851        // Test write with offset inside buffer but src extending past end
1852        let offset = alloc_bytes / 2;
1853        let expected_write = alloc_bytes - offset;
1854        let write_buf = vec![2; alloc_bytes]; // Different byte to distinguish
1855        assert_eq!(buffer.io_mut().write_at(offset, &write_buf[..]), expected_write);
1856
1857        // Verify the write
1858        let mut read_buf = vec![0; alloc_bytes];
1859        assert_eq!(buffer.io().read_at(0, &mut read_buf[..]), alloc_bytes);
1860        for (idx, byte) in read_buf.iter().enumerate() {
1861            if idx < offset {
1862                assert_eq!(*byte, WRITE_BYTE);
1863            } else {
1864                assert_eq!(*byte, 2);
1865            }
1866        }
1867    }
1868
1869    #[test]
1870    fn buffer_read_at_short() {
1871        let pool = Pool::new_test_default();
1872        let alloc_bytes = DEFAULT_BUFFER_LENGTH.get();
1873        let mut buffer =
1874            pool.alloc_tx_buffer_now_or_never(alloc_bytes).expect("failed to allocate");
1875        assert_eq!(buffer.parts().count(), 2);
1876        assert_eq!(buffer.len(), alloc_bytes);
1877
1878        let write_buf = vec![WRITE_BYTE; alloc_bytes];
1879        assert_eq!(buffer.io_mut().write_at(0, &write_buf[..]), alloc_bytes);
1880
1881        // Test short read (reading more than buffer capacity)
1882        let mut read_buf = vec![0xff; alloc_bytes + 10];
1883        assert_eq!(buffer.io().read_at(0, &mut read_buf[..]), alloc_bytes);
1884        for (idx, byte) in read_buf.iter().enumerate() {
1885            if idx < alloc_bytes {
1886                assert_eq!(*byte, WRITE_BYTE);
1887            } else {
1888                assert_eq!(*byte, 0xff);
1889            }
1890        }
1891
1892        // Test read with offset past end
1893        assert_eq!(buffer.io().read_at(alloc_bytes + 1, &mut read_buf[..]), 0);
1894
1895        // Test read with offset inside buffer but dst extending past end
1896        let offset = alloc_bytes / 2;
1897        let expected_read = alloc_bytes - offset;
1898        let mut read_buf = vec![0xff; alloc_bytes];
1899        assert_eq!(buffer.io().read_at(offset, &mut read_buf[..]), expected_read);
1900        for (idx, byte) in read_buf.iter().enumerate() {
1901            if idx < expected_read {
1902                assert_eq!(*byte, WRITE_BYTE);
1903            } else {
1904                assert_eq!(*byte, 0xff);
1905            }
1906        }
1907    }
1908
1909    #[test]
1910    fn buffer_read_write_seek() {
1911        let pool = Pool::new_test_default();
1912        let alloc_bytes = DEFAULT_BUFFER_LENGTH.get();
1913        let mut buffer =
1914            pool.alloc_tx_buffer_now_or_never(alloc_bytes).expect("failed to allocate");
1915        // Because we have to accommodate the space for head and tail, there
1916        // would be 2 parts instead of 1.
1917        assert_eq!(buffer.parts().count(), 2);
1918        assert_eq!(buffer.len(), alloc_bytes);
1919        let write_buf = (0..u8::try_from(DEFAULT_BUFFER_LENGTH.get()).unwrap()).collect::<Vec<_>>();
1920
1921        let mut io = buffer.io_mut();
1922
1923        assert_eq!(io.write(&write_buf[..]).expect("failed to write into buffer"), write_buf.len());
1924        const SEEK_FROM_END: usize = 64;
1925        const READ_LEN: usize = 12;
1926        assert_eq!(
1927            io.seek(SeekFrom::End(-i64::try_from(SEEK_FROM_END).unwrap())).unwrap(),
1928            u64::try_from(io.len - SEEK_FROM_END).unwrap()
1929        );
1930        let mut read_buf = [0xff; READ_LEN];
1931        assert_eq!(io.read(&mut read_buf[..]).expect("failed to read from buffer"), read_buf.len());
1932        assert_eq!(&write_buf[..READ_LEN], &read_buf[..]);
1933    }
1934
1935    #[test_case(32; "single buffer part")]
1936    #[test_case(MAX_BUFFER_BYTES; "multiple buffer parts")]
1937    fn buffer_pad(pad_size: usize) {
1938        let mut pool = Pool::new_test_default();
1939        pool.set_min_tx_buffer_length(pad_size);
1940        for offset in 0..pad_size {
1941            Arc::get_mut(&mut pool)
1942                .expect("there are multiple owners of the underlying VMO")
1943                .fill_sentinel_bytes();
1944            let mut buffer =
1945                pool.alloc_tx_buffer_now_or_never(offset + 1).expect("failed to allocate buffer");
1946            buffer.check_write_and_pad(offset, pad_size);
1947        }
1948    }
1949
1950    #[test]
1951    fn buffer_pad_grow() {
1952        const BUFFER_PARTS: u8 = 3;
1953        let mut pool = Pool::new_test_default();
1954        let pad_size = u32::try_from(DEFAULT_BUFFER_LENGTH.get()).unwrap()
1955            * u32::from(BUFFER_PARTS)
1956            - u32::from(DEFAULT_MIN_TX_BUFFER_HEAD)
1957            - u32::from(DEFAULT_MIN_TX_BUFFER_TAIL);
1958        pool.set_min_tx_buffer_length(pad_size.try_into().unwrap());
1959        for offset in 0..pad_size - u32::try_from(DEFAULT_BUFFER_LENGTH.get()).unwrap() {
1960            Arc::get_mut(&mut pool)
1961                .expect("there are multiple owners of the underlying VMO")
1962                .fill_sentinel_bytes();
1963            let mut alloc =
1964                pool.alloc_tx_now_or_never(BUFFER_PARTS).expect("failed to alloc descriptors");
1965            alloc
1966                .init(usize::try_from(offset).unwrap() + 1)
1967                .expect("head/body/tail sizes are representable with u16/u32/u16");
1968            let mut buffer = Buffer::try_from(alloc).unwrap();
1969            buffer.check_write_and_pad(offset.try_into().unwrap(), pad_size.try_into().unwrap());
1970        }
1971    }
1972
1973    #[test_case(  0; "writes at the beginning")]
1974    #[test_case( 15; "writes in the first part")]
1975    #[test_case( 75; "writes in the second part")]
1976    #[test_case(135; "writes in the third part")]
1977    #[test_case(195; "writes in the last part")]
1978    fn buffer_used(write_offset: usize) {
1979        let pool = Pool::new_test_default();
1980        let mut buffer =
1981            pool.alloc_tx_buffer_now_or_never(MAX_BUFFER_BYTES).expect("failed to allocate buffer");
1982        let expected_caps = (0..netdev::MAX_DESCRIPTOR_CHAIN).map(|i| {
1983            if i == 0 {
1984                DEFAULT_BUFFER_LENGTH.get() - usize::from(DEFAULT_MIN_TX_BUFFER_HEAD)
1985            } else if i < netdev::MAX_DESCRIPTOR_CHAIN - 1 {
1986                DEFAULT_BUFFER_LENGTH.get()
1987            } else {
1988                DEFAULT_BUFFER_LENGTH.get() - usize::from(DEFAULT_MIN_TX_BUFFER_TAIL)
1989            }
1990        });
1991        assert_eq!(buffer.alloc.len(), netdev::MAX_DESCRIPTOR_CHAIN.into());
1992        assert_eq!(buffer.io_mut().write_at(write_offset, &[WRITE_BYTE][..]), 1);
1993        // The accumulator is Some if we haven't found the part where the byte
1994        // was written, None if we've already found it.
1995        assert_eq!(
1996            buffer.parts().zip(expected_caps).fold(
1997                Some(write_offset),
1998                |offset, (slice, expected_cap)| {
1999                    assert_eq!(slice.len(), expected_cap);
2000                    match offset {
2001                        Some(offset) => {
2002                            if offset >= expected_cap {
2003                                Some(offset - slice.len())
2004                            } else {
2005                                assert_eq!(slice[offset], WRITE_BYTE);
2006                                None
2007                            }
2008                        }
2009                        None => None,
2010                    }
2011                }
2012            ),
2013            None
2014        );
2015    }
2016
2017    #[test]
2018    fn allocate_under_device_minimum() {
2019        const MIN_TX_DATA: usize = 32;
2020        const ALLOC_SIZE: usize = 16;
2021        const WRITE_BYTE: u8 = 0xff;
2022        const WRITE_SENTINAL_BYTE: u8 = 0xee;
2023        const READ_SENTINAL_BYTE: u8 = 0xdd;
2024        let mut config = DEFAULT_CONFIG;
2025        config.buffer_layout.min_tx_data = MIN_TX_DATA;
2026        let (pool, _descriptors, _vmo) = Pool::new(config).expect("failed to create a new pool");
2027        for mut buffer in Vec::from_iter(std::iter::from_fn({
2028            let pool = pool.clone();
2029            move || pool.alloc_tx_buffer_now_or_never(MIN_TX_DATA)
2030        })) {
2031            assert_eq!(
2032                buffer.io_mut().write_at(0, &[WRITE_SENTINAL_BYTE; MIN_TX_DATA]),
2033                MIN_TX_DATA
2034            );
2035        }
2036        let mut allocated =
2037            pool.alloc_tx_buffer_now_or_never(16).expect("failed to allocate buffer");
2038        assert_eq!(allocated.len(), MIN_TX_DATA);
2039        const WRITE_BUF_SIZE: usize = MIN_TX_DATA + 1;
2040        assert_eq!(allocated.io_mut().write_at(0, &[WRITE_BYTE; WRITE_BUF_SIZE]), MIN_TX_DATA);
2041        assert_eq!(allocated.io_mut().write_at(0, &[WRITE_BYTE; ALLOC_SIZE]), ALLOC_SIZE);
2042        assert_eq!(allocated.len(), MIN_TX_DATA);
2043        const READ_BUF_SIZE: usize = MIN_TX_DATA + 1;
2044        let mut read_buf = [READ_SENTINAL_BYTE; READ_BUF_SIZE];
2045        assert_eq!(allocated.io().read_at(0, &mut read_buf[..]), MIN_TX_DATA);
2046        assert_eq!(allocated.io().read_at(0, &mut read_buf[..MIN_TX_DATA]), MIN_TX_DATA);
2047        assert_eq!(&read_buf[..ALLOC_SIZE], &[WRITE_BYTE; ALLOC_SIZE][..]);
2048        assert_eq!(&read_buf[ALLOC_SIZE..MIN_TX_DATA], &[WRITE_BYTE; ALLOC_SIZE][..]);
2049        assert_eq!(&read_buf[MIN_TX_DATA..], &[READ_SENTINAL_BYTE; 1][..]);
2050    }
2051
2052    #[test]
2053    fn invalid_tx_length() {
2054        let mut config = DEFAULT_CONFIG;
2055        config.buffer_layout.length = usize::from(u16::MAX) + 2;
2056        config.buffer_layout.min_tx_head = 0;
2057        let (pool, _descriptors, _vmo) = Pool::new(config).expect("failed to create pool");
2058        assert_matches!(pool.alloc_tx_buffer(1).now_or_never(), Some(Err(Error::TxLength)));
2059    }
2060
    /// Exercises `RxLeaseWatcher::wait_until` as `rx_complete` is called on
    /// the shared `RxLeaseHandlingState`. It also checks that dropping an
    /// unfinished wait leaves `rx_frame_counter` unchanged.
    #[test]
    fn rx_leases() {
        let mut executor = fuchsia_async::TestExecutor::new();
        let state = RxLeaseHandlingState::new_with_enabled(true);
        let mut watcher = RxLeaseWatcher { state: &state };

        // Waiting for 0 completes immediately, before any `rx_complete` call.
        {
            let mut fut = pin!(watcher.wait_until(0));
            assert_eq!(executor.run_until_stalled(&mut fut), Poll::Ready(()));
        }
        // After one `rx_complete`, waiting for 1 completes immediately.
        {
            state.rx_complete();
            let mut fut = pin!(watcher.wait_until(1));
            assert_eq!(executor.run_until_stalled(&mut fut), Poll::Ready(()));
        }
        // Waiting for 0 still completes immediately.
        {
            let mut fut = pin!(watcher.wait_until(0));
            assert_eq!(executor.run_until_stalled(&mut fut), Poll::Ready(()));
        }
        // Waiting for 3 needs two more `rx_complete` calls after the earlier
        // one. It stays pending after the first and resolves after the second.
        {
            let mut fut = pin!(watcher.wait_until(3));
            assert_eq!(executor.run_until_stalled(&mut fut), Poll::Pending);
            state.rx_complete();
            assert_eq!(executor.run_until_stalled(&mut fut), Poll::Pending);
            state.rx_complete();
            assert_eq!(executor.run_until_stalled(&mut fut), Poll::Ready(()));
        }
        // Dropping the wait future without seeing it complete restores the
        // value.
        let counter_before = state.rx_frame_counter.load(atomic::Ordering::SeqCst);
        {
            let mut fut = pin!(watcher.wait_until(10000));
            assert_eq!(executor.run_until_stalled(&mut fut), Poll::Pending);
        }
        let counter_after = state.rx_frame_counter.load(atomic::Ordering::SeqCst);
        assert_eq!(counter_before, counter_after);
    }
2098}