pub(super) mod pool;
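/// Raw types mirroring the device's buffer descriptor layout.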
pub mod sys;

use std::iter;
use std::num::{NonZeroU16, NonZeroU64};
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU8, Ordering};

use fidl_fuchsia_hardware_network as netdev;
use fuchsia_runtime::vmar_root_self;
use static_assertions::{const_assert, const_assert_eq};
use zx::sys::ZX_MIN_PAGE_SHIFT;

use crate::error::{Error, Result};
use crate::session::Port;
use types::{ChainLength, DESCID_NO_NEXT};

pub use pool::{AllocKind, Buffer, Rx, Tx};
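/// The buffer descriptor format version this library implements.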
pub const NETWORK_DEVICE_DESCRIPTOR_VERSION: u8 = sys::__NETWORK_DEVICE_DESCRIPTOR_VERSION as u8;
pub(super) use types::DescId;
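/// The size in bytes of a single buffer descriptor.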
pub(super) const NETWORK_DEVICE_DESCRIPTOR_LENGTH: usize =
    std::mem::size_of::<sys::buffer_descriptor>();

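// The descriptor size must divide evenly into u64 words, and its alignment
// must be a power of two no larger than a page so that descriptors stay
// aligned when the VMO is mapped at a page boundary.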
const_assert_eq!(NETWORK_DEVICE_DESCRIPTOR_LENGTH % std::mem::size_of::<u64>(), 0);
const_assert!(
    std::mem::align_of::<Descriptor<Tx>>().count_ones() == 1
        && std::mem::align_of::<Descriptor<Tx>>() <= (1 << ZX_MIN_PAGE_SHIFT)
);
const_assert!(
    std::mem::align_of::<Descriptor<Rx>>().count_ones() == 1
        && std::mem::align_of::<Descriptor<Rx>>() <= (1 << ZX_MIN_PAGE_SHIFT)
);

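/// A typed view of a raw `sys::buffer_descriptor`, tagged by allocation kind
/// (`Tx` or `Rx`) so that tx-only and rx-only operations cannot be mixed up.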
#[repr(transparent)]
struct Descriptor<K: AllocKind>(sys::buffer_descriptor, std::marker::PhantomData<K>);

impl<K: AllocKind> Descriptor<K> {
    fn frame_type(&self) -> Result<netdev::FrameType> {
        let Self(this, _marker) = self;
        let prim = this.frame_type;
        netdev::FrameType::from_primitive(prim).ok_or(Error::FrameType(prim))
    }

    fn chain_length(&self) -> Result<ChainLength> {
        let Self(this, _marker) = self;
        ChainLength::try_from(this.chain_length)
    }

    fn nxt(&self) -> Option<u16> {
        let Self(this, _marker) = self;
        if this.nxt == DESCID_NO_NEXT { None } else { Some(this.nxt) }
    }

    fn set_nxt(&mut self, desc: Option<DescId<K>>) {
        let Self(this, _marker) = self;
        this.nxt = desc.as_ref().map(DescId::get).unwrap_or(DESCID_NO_NEXT);
    }

    fn offset(&self) -> u64 {
        let Self(this, _marker) = self;
        this.offset
    }

    fn set_offset(&mut self, offset: u64) {
        let Self(this, _marker) = self;
        this.offset = offset;
    }

    fn head_length(&self) -> u16 {
        let Self(this, _marker) = self;
        this.head_length
    }

    fn data_length(&self) -> u32 {
        let Self(this, _marker) = self;
        this.data_length
    }

    fn tail_length(&self) -> u16 {
        let Self(this, _marker) = self;
        this.tail_length
    }

    fn port(&self) -> Port {
        let Self(
            sys::buffer_descriptor {
                port_id: sys::buffer_descriptor_port_id { base, salt }, ..
            },
            _marker,
        ) = self;
        Port { base: *base, salt: *salt }
    }

    fn set_port(&mut self, Port { base, salt }: Port) {
        let Self(sys::buffer_descriptor { port_id, .. }, _marker) = self;
        *port_id = sys::buffer_descriptor_port_id { base, salt };
    }

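    /// Reinitializes the descriptor for a new buffer layout, zeroing the
    /// frame type, info type, port, and flags while leaving `offset`, `nxt`,
    /// and the opaque client data untouched.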
    fn initialize(&mut self, chain_len: ChainLength, head_len: u16, data_len: u32, tail_len: u16) {
        let Self(
            sys::buffer_descriptor {
                frame_type,
                chain_length,
                nxt: _,
                info_type,
                port_id: sys::buffer_descriptor_port_id { base, salt },
                _reserved: _,
                client_opaque_data: _,
                offset: _,
                head_length,
                tail_length,
                data_length,
                inbound_flags,
                return_flags,
            },
            _marker,
        ) = self;
        *frame_type = 0;
        *chain_length = chain_len.get();
        *info_type = 0;
        *base = 0;
        *salt = 0;
        *head_length = head_len;
        *tail_length = tail_len;
        *data_length = data_len;
        *inbound_flags = 0;
        *return_flags = 0;
    }
}

impl Descriptor<Rx> {
    fn rx_flags(&self) -> Result<netdev::RxFlags> {
        let Self(this, _marker) = self;
        let bits = this.inbound_flags;
        netdev::RxFlags::from_bits(bits).ok_or(Error::RxFlags(bits))
    }
}

impl Descriptor<Tx> {
    fn set_tx_flags(&mut self, flags: netdev::TxFlags) {
        let Self(this, _marker) = self;
        let bits = flags.bits();
        this.return_flags = bits;
    }

    fn set_frame_type(&mut self, frame_type: netdev::FrameType) {
        let Self(this, _marker) = self;
        this.frame_type = frame_type.into_primitive();
    }

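    /// Commits `used` bytes as the final data length, returning the unused
    /// remainder of the data section to the tail.
    ///
    /// Panics if `used` exceeds the available data plus tail space, or if the
    /// new tail length does not fit in a `u16`.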
    fn commit(&mut self, used: u32) {
        let Self(this, _marker) = self;
        let total = this.data_length + u32::from(this.tail_length);
        let tail = total.checked_sub(used).unwrap();
        this.data_length = used;
        this.tail_length = u16::try_from(tail).unwrap();
    }
}

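/// The descriptor region shared with the device: a mapping of `count`
/// consecutive `buffer_descriptor`s backed by the descriptors VMO.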
struct Descriptors {
    ptr: NonNull<sys::buffer_descriptor>,
    count: u16,
}

impl Descriptors {
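    /// Creates and maps a VMO large enough for `num_tx + num_rx` descriptors.
    ///
    /// Returns the mapped region, the VMO to register with the device, and
    /// the initial tx and rx descriptor ids: tx ids take `0..num_tx` and rx
    /// ids take `num_tx..num_tx + num_rx`. Each descriptor's buffer offset is
    /// initialized to `id * buffer_stride`.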
    fn new(
        num_tx: NonZeroU16,
        num_rx: NonZeroU16,
        buffer_stride: NonZeroU64,
    ) -> Result<(Self, zx::Vmo, Vec<DescId<Tx>>, Vec<DescId<Rx>>)> {
        let total = num_tx.get().checked_add(num_rx.get()).expect("too many descriptors");
        let size = u64::try_from(NETWORK_DEVICE_DESCRIPTOR_LENGTH * usize::from(total))
            .expect("vmo_size overflows u64");
        let vmo = zx::Vmo::create(size).map_err(|status| Error::Vmo("descriptors", status))?;

        const VMO_NAME: zx::Name =
            const_unwrap::const_unwrap_result(zx::Name::new("netdevice:descriptors"));
        vmo.set_name(&VMO_NAME).map_err(|status| Error::Vmo("set name", status))?;
        let ptr = NonNull::new(
            vmar_root_self()
                .map(
                    0,
                    &vmo,
                    0,
                    usize::try_from(size).unwrap(),
                    zx::VmarFlags::PERM_WRITE | zx::VmarFlags::PERM_READ,
                )
                .map_err(|status| Error::Map("descriptors", status))?
                as *mut sys::buffer_descriptor,
        )
        .unwrap();

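        // Each id in `0..total` is created exactly once, so every descriptor
        // has a unique owning `DescId`.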
        let mut tx =
            (0..num_tx.get()).map(|x| unsafe { DescId::<Tx>::from_raw(x) }).collect::<Vec<_>>();
        let mut rx =
            (num_tx.get()..total).map(|x| unsafe { DescId::<Rx>::from_raw(x) }).collect::<Vec<_>>();
        let descriptors = Self { ptr, count: total };
        fn init_offset<K: AllocKind>(
            descriptors: &Descriptors,
            desc: &mut DescId<K>,
            buffer_stride: NonZeroU64,
        ) {
            let offset = buffer_stride.get().checked_mul(u64::from(desc.get())).unwrap();
            descriptors.borrow_mut(desc).set_offset(offset);
        }
        tx.iter_mut().for_each(|desc| init_offset(&descriptors, desc, buffer_stride));
        rx.iter_mut().for_each(|desc| init_offset(&descriptors, desc, buffer_stride));
        Ok((descriptors, vmo, tx, rx))
    }

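    /// Borrows the descriptor identified by `id` for shared access.
    ///
    /// Panics if `id` is out of range or the descriptor is currently borrowed
    /// exclusively.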
    fn borrow<'a, 'b: 'a, K: AllocKind>(&'b self, id: &'a DescId<K>) -> DescRef<'a, K> {
        assert!(
            id.get() < self.count,
            "descriptor index out of range: {} >= {}",
            id.get(),
            self.count
        );
        unsafe { DescRef::new(self.ptr.as_ptr().add(id.get().into())) }
    }

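    /// Borrows the descriptor identified by `id` for exclusive access.
    ///
    /// Panics if `id` is out of range or the descriptor is currently borrowed
    /// at all.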
    fn borrow_mut<'a, 'b: 'a, K: AllocKind>(&'b self, id: &'a mut DescId<K>) -> DescRefMut<'a, K> {
        assert!(
            id.get() < self.count,
            "descriptor index out of range: {} >= {}",
            id.get(),
            self.count
        );
        unsafe { DescRefMut::new(self.ptr.as_ptr().add(id.get().into())) }
    }

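    /// Walks the descriptor chain rooted at `head`, following `nxt` links
    /// until a descriptor with a chain length of zero terminates the chain.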
    fn chain<K: AllocKind>(&self, head: DescId<K>) -> impl Iterator<Item = Result<DescId<K>>> + '_ {
        iter::successors(
            Some(Ok(head)),
            move |curr: &Result<DescId<K>>| -> Option<Result<DescId<K>>> {
                match curr {
                    Err(_err) => None,
                    Ok(curr) => {
                        let descriptor = self.borrow(curr);
                        match descriptor.chain_length() {
                            Err(e) => Some(Err(e)),
                            Ok(len) => {
                                if len == ChainLength::ZERO {
                                    None
                                } else {
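                                    // A nonzero chain length promises a
                                    // successor named by `nxt`.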
                                    descriptor.nxt().map(|id| Ok(unsafe { DescId::from_raw(id) }))
                                }
                            }
                        }
                    }
                }
            },
        )
    }
}

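// SAFETY: the mapping stays valid until `Descriptors` is dropped, and all
// access to individual descriptors is mediated by the atomic reference count
// embedded in each descriptor (see `ref_count` below).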
unsafe impl Send for Descriptors {}
unsafe impl Sync for Descriptors {}

impl Drop for Descriptors {
    fn drop(&mut self) {
        let len = NETWORK_DEVICE_DESCRIPTOR_LENGTH * usize::from(self.count);
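        // The mapping occupies whole pages; round the length up before
        // unmapping.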
        let page_size = usize::try_from(zx::system_get_page_size()).unwrap();
        let aligned = len.div_ceil(page_size) * page_size;
        unsafe {
            vmar_root_self()
                .unmap(self.ptr.as_ptr() as usize, aligned)
                .expect("failed to unmap VMO")
        }
    }
}

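// Runtime borrow tracking for descriptors, in the spirit of `RefCell`: the
// first byte of each descriptor's `client_opaque_data` is used as an atomic
// reference count. `DESC_REF_UNUSED` means the descriptor is not borrowed,
// `DESC_REF_EXCLUSIVE` marks an exclusive borrow, and any value in between
// counts the outstanding shared borrows.
//
// # Safety
//
// `ptr` must point to a live descriptor within the mapped descriptors VMO.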
unsafe fn ref_count<'a>(ptr: *const sys::buffer_descriptor) -> &'a AtomicU8 {
    const_assert_eq!(std::mem::align_of::<AtomicU8>(), std::mem::align_of::<u8>());
    unsafe { &*(&((*ptr).client_opaque_data[0]) as *const u8 as *const AtomicU8) }
}

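/// The descriptor is not borrowed.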
const DESC_REF_UNUSED: u8 = 0;
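/// The descriptor is borrowed exclusively.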
const DESC_REF_EXCLUSIVE: u8 = u8::MAX;

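/// A shared borrow of a descriptor, analogous to `std::cell::Ref`.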
struct DescRef<'a, K: AllocKind> {
    ptr: &'a Descriptor<K>,
}

impl<K: AllocKind> DescRef<'_, K> {
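    /// # Safety
    ///
    /// `ptr` must point to a live descriptor in the mapped descriptors VMO,
    /// and the mapping must outlive the returned reference. Borrow violations
    /// are detected through the embedded reference count and panic.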
    unsafe fn new(ptr: *const sys::buffer_descriptor) -> Self {
        let ref_cnt = unsafe { ref_count(ptr) };
        let prev = ref_cnt.fetch_add(1, Ordering::AcqRel);
        if prev == DESC_REF_EXCLUSIVE {
            panic!("trying to create a shared reference when there is already a mutable reference");
        }
        if prev + 1 == DESC_REF_EXCLUSIVE {
            panic!("there are too many shared references");
        }
        Self { ptr: unsafe { &*(ptr as *const Descriptor<K>) } }
    }
}

impl<K: AllocKind> Drop for DescRef<'_, K> {
    fn drop(&mut self) {
        let ref_cnt = unsafe { ref_count(&self.ptr.0 as *const _) };
        let prev = ref_cnt.fetch_sub(1, Ordering::AcqRel);
        assert!(prev != DESC_REF_EXCLUSIVE && prev != DESC_REF_UNUSED);
    }
}

impl<K: AllocKind> Deref for DescRef<'_, K> {
    type Target = Descriptor<K>;

    fn deref(&self) -> &Self::Target {
        self.ptr
    }
}

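/// An exclusive borrow of a descriptor, analogous to `std::cell::RefMut`.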
struct DescRefMut<'a, K: AllocKind> {
    ptr: &'a mut Descriptor<K>,
}

impl<K: AllocKind> DescRefMut<'_, K> {
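    /// # Safety
    ///
    /// Same contract as `DescRef::new`, with `ptr` additionally valid for
    /// writes. Borrow violations are detected through the embedded reference
    /// count and panic.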
    unsafe fn new(ptr: *mut sys::buffer_descriptor) -> Self {
        let ref_cnt = unsafe { ref_count(ptr) };
        if let Err(prev) = ref_cnt.compare_exchange(
            DESC_REF_UNUSED,
            DESC_REF_EXCLUSIVE,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            panic!(
                "trying to create an exclusive reference when there are other references: {}",
                prev
            );
        }
        Self { ptr: unsafe { &mut *(ptr as *mut Descriptor<K>) } }
    }
}

impl<K: AllocKind> Drop for DescRefMut<'_, K> {
    fn drop(&mut self) {
        let ref_cnt = unsafe { ref_count(&self.ptr.0 as *const _) };
        if let Err(prev) = ref_cnt.compare_exchange(
            DESC_REF_EXCLUSIVE,
            DESC_REF_UNUSED,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            panic!(
                "we have a mutable reference while the descriptor is not exclusively borrowed: {}",
                prev
            );
        }
    }
}

impl<K: AllocKind> Deref for DescRefMut<'_, K> {
    type Target = Descriptor<K>;

    fn deref(&self) -> &Self::Target {
        self.ptr
    }
}

impl<K: AllocKind> DerefMut for DescRefMut<'_, K> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.ptr
    }
}

mod types {
    use super::{AllocKind, Error, Result, netdev};
    use std::fmt::Debug;
    use std::num::TryFromIntError;
    use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};

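    /// An index identifying a descriptor in the descriptors VMO, tagged by
    /// allocation kind (tx or rx).
    ///
    /// A `DescId` represents ownership of the descriptor it names; `u16::MAX`
    /// is reserved as the `DESCID_NO_NEXT` sentinel and is never a valid id.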
    #[derive(PartialEq, Eq, KnownLayout, FromBytes, IntoBytes, Immutable)]
    #[repr(transparent)]
    pub(in crate::session) struct DescId<K: AllocKind>(u16, std::marker::PhantomData<K>);

    impl<K: AllocKind> Debug for DescId<K> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let Self(id, _marker) = self;
            f.debug_tuple(K::REFL.as_str()).field(id).finish()
        }
    }

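    /// The sentinel value stored in a descriptor's `nxt` field to mark the
    /// end of a chain.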
    pub(super) const DESCID_NO_NEXT: u16 = u16::MAX;

    impl<K: AllocKind> DescId<K> {
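        /// # Safety
        ///
        /// The caller must ensure `id` is within the bounds of the descriptor
        /// region and not already owned by another `DescId`, since a `DescId`
        /// confers ownership of the descriptor it names.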
        pub(super) unsafe fn from_raw(id: u16) -> Self {
            assert_ne!(id, DESCID_NO_NEXT);
            Self(id, std::marker::PhantomData)
        }

        pub(super) fn get(&self) -> u16 {
            let Self(id, _marker) = self;
            *id
        }
    }

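    /// The number of descriptors chained after the current one in a buffer,
    /// validated to be at most `netdev::MAX_DESCRIPTOR_CHAIN`.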
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
    pub(in crate::session) struct ChainLength(u8);

    impl TryFrom<u8> for ChainLength {
        type Error = Error;

        fn try_from(value: u8) -> Result<Self> {
            if value > netdev::MAX_DESCRIPTOR_CHAIN {
                return Err(Error::LargeChain(value.into()));
            }
            Ok(ChainLength(value))
        }
    }

    impl TryFrom<usize> for ChainLength {
        type Error = Error;

        fn try_from(value: usize) -> Result<Self> {
            let value =
                u8::try_from(value).map_err(|TryFromIntError { .. }| Error::LargeChain(value))?;
            value.try_into()
        }
    }

    impl From<ChainLength> for usize {
        fn from(ChainLength(len): ChainLength) -> Self {
            len.into()
        }
    }

    impl ChainLength {
        pub(super) const ZERO: Self = Self(0);

        pub(super) fn get(&self) -> u8 {
            let ChainLength(len) = self;
            *len
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    const TX_BUFFERS: NonZeroU16 = NonZeroU16::new(1).unwrap();
    const RX_BUFFERS: NonZeroU16 = NonZeroU16::new(2).unwrap();
    const BUFFER_STRIDE: NonZeroU64 = NonZeroU64::new(4).unwrap();

    #[test]
    fn get_descriptor_after_vmo_write() {
        let (descriptors, vmo, tx, rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        vmo.write(&[netdev::FrameType::Ethernet.into_primitive()][..], 0).expect("vmo write");
        assert_eq!(tx.len(), TX_BUFFERS.get().into());
        assert_eq!(rx.len(), RX_BUFFERS.get().into());
        assert_eq!(
            descriptors.borrow(&tx[0]).frame_type().expect("failed to get frame type"),
            netdev::FrameType::Ethernet
        );
    }

    #[test]
    fn init_descriptor() {
        const HEAD_LEN: u16 = 1;
        const DATA_LEN: u32 = 2;
        const TAIL_LEN: u16 = 3;
        let (descriptors, _vmo, mut tx, _rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        {
            let mut descriptor = descriptors.borrow_mut(&mut tx[0]);
            descriptor.initialize(ChainLength::ZERO, HEAD_LEN, DATA_LEN, TAIL_LEN);
        }

        let got = descriptors.borrow(&tx[0]);
        assert_eq!(got.chain_length().unwrap(), ChainLength::ZERO);
        assert_eq!(got.offset(), 0);
        assert_eq!(got.head_length(), HEAD_LEN);
        assert_eq!(got.data_length(), DATA_LEN);
        assert_eq!(got.tail_length(), TAIL_LEN);
    }

    #[test]
    fn chain_length() {
        for raw in 0..=netdev::MAX_DESCRIPTOR_CHAIN {
            let got = ChainLength::try_from(raw)
                .expect("the conversion should succeed with length <= MAX_DESCRIPTOR_CHAIN");
            assert_eq!(got.get(), raw);
        }

        for raw in netdev::MAX_DESCRIPTOR_CHAIN + 1..=u8::MAX {
            assert_matches!(
                ChainLength::try_from(raw)
                    .expect_err("the conversion should fail with length > MAX_DESCRIPTOR_CHAIN"),
                Error::LargeChain(len) if len == raw.into()
            );
        }
    }
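
    // Added example: build a two-descriptor chain and walk it back with
    // `Descriptors::chain`. Assumes `netdev::MAX_DESCRIPTOR_CHAIN >= 1`.
    #[test]
    fn chain_walk() {
        let (descriptors, _vmo, _tx, mut rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        let tail_id = rx[1].get();
        // Aliasing rx[1]'s id is sound here: the alias is only stored in the
        // head's `nxt` field and never used to borrow the descriptor.
        let next = unsafe { DescId::<Rx>::from_raw(tail_id) };
        {
            let mut head = descriptors.borrow_mut(&mut rx[0]);
            head.initialize(ChainLength::try_from(1u8).unwrap(), 0, 0, 0);
            head.set_nxt(Some(next));
        }
        // rx[1] is still zero-initialized, so its zero chain length ends the
        // walk after two descriptors.
        let head = unsafe { DescId::<Rx>::from_raw(rx[0].get()) };
        let got =
            descriptors.chain(head).map(|r| r.expect("valid link").get()).collect::<Vec<_>>();
        assert_eq!(got, [rx[0].get(), tail_id]);
    }

    // Added example: conflicting borrows of the same descriptor are caught at
    // runtime by the embedded reference count, much like `RefCell`.
    #[test]
    #[should_panic(expected = "exclusive reference")]
    fn borrow_conflict() {
        let (descriptors, _vmo, tx, _rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        let mut alias = unsafe { DescId::<Tx>::from_raw(tx[0].get()) };
        let _shared = descriptors.borrow(&tx[0]);
        // Panics: the descriptor is already borrowed shared.
        let _exclusive = descriptors.borrow_mut(&mut alias);
    }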
}