1pub(super) mod pool;
8pub mod sys;
9
10use std::iter;
11use std::num::{NonZeroU16, NonZeroU64};
12use std::ops::{Deref, DerefMut};
13use std::ptr::NonNull;
14use std::sync::atomic::{AtomicU8, Ordering};
15
16use fidl_fuchsia_hardware_network as netdev;
17use fuchsia_runtime::vmar_root_self;
18use static_assertions::{const_assert, const_assert_eq};
19use zx::AsHandleRef as _;
20use zx::sys::ZX_MIN_PAGE_SHIFT;
21
22use crate::error::{Error, Result};
23use crate::session::Port;
24use types::{ChainLength, DESCID_NO_NEXT};
25
26pub use pool::{AllocKind, Buffer, Rx, Tx};
27pub const NETWORK_DEVICE_DESCRIPTOR_VERSION: u8 = sys::__NETWORK_DEVICE_DESCRIPTOR_VERSION as u8;
29pub(super) use types::DescId;
30pub(super) const NETWORK_DEVICE_DESCRIPTOR_LENGTH: usize =
32 std::mem::size_of::<sys::buffer_descriptor>();
33
34const_assert_eq!(NETWORK_DEVICE_DESCRIPTOR_LENGTH % std::mem::size_of::<u64>(), 0);
36const_assert!(
40 std::mem::align_of::<Descriptor<Tx>>().count_ones() == 1
41 && std::mem::align_of::<Descriptor<Tx>>() <= (1 << ZX_MIN_PAGE_SHIFT)
42);
43const_assert!(
44 std::mem::align_of::<Descriptor<Rx>>().count_ones() == 1
45 && std::mem::align_of::<Descriptor<Rx>>() <= (1 << ZX_MIN_PAGE_SHIFT)
46);
47
/// A typed view of a raw `sys::buffer_descriptor`.
///
/// `K` records which pool (Tx or Rx) the descriptor belongs to, so Tx-only and
/// Rx-only accessors are only available on the matching type.
/// `repr(transparent)` guarantees the same layout as the raw descriptor, which
/// is what allows pointers into the descriptor mapping to be reinterpreted as
/// `Descriptor<K>` (see `DescRef::new` / `DescRefMut::new`).
#[repr(transparent)]
struct Descriptor<K: AllocKind>(sys::buffer_descriptor, std::marker::PhantomData<K>);
51
52impl<K: AllocKind> Descriptor<K> {
53 fn frame_type(&self) -> Result<netdev::FrameType> {
54 let Self(this, _marker) = self;
55 let prim = this.frame_type;
56 netdev::FrameType::from_primitive(prim).ok_or(Error::FrameType(prim))
57 }
58
59 fn chain_length(&self) -> Result<ChainLength> {
60 let Self(this, _marker) = self;
61 ChainLength::try_from(this.chain_length)
62 }
63
64 fn nxt(&self) -> Option<u16> {
65 let Self(this, _marker) = self;
66 if this.nxt == DESCID_NO_NEXT { None } else { Some(this.nxt) }
67 }
68
69 fn set_nxt(&mut self, desc: Option<DescId<K>>) {
70 let Self(this, _marker) = self;
71 this.nxt = desc.as_ref().map(DescId::get).unwrap_or(DESCID_NO_NEXT);
72 }
73
74 fn offset(&self) -> u64 {
75 let Self(this, _marker) = self;
76 this.offset
77 }
78
79 fn set_offset(&mut self, offset: u64) {
80 let Self(this, _marker) = self;
81 this.offset = offset;
82 }
83
84 fn head_length(&self) -> u16 {
85 let Self(this, _marker) = self;
86 this.head_length
87 }
88
89 fn data_length(&self) -> u32 {
90 let Self(this, _marker) = self;
91 this.data_length
92 }
93
94 fn tail_length(&self) -> u16 {
95 let Self(this, _marker) = self;
96 this.tail_length
97 }
98
99 fn port(&self) -> Port {
100 let Self(
101 sys::buffer_descriptor {
102 port_id: sys::buffer_descriptor_port_id { base, salt }, ..
103 },
104 _marker,
105 ) = self;
106 Port { base: *base, salt: *salt }
107 }
108
109 fn set_port(&mut self, Port { base, salt }: Port) {
110 let Self(sys::buffer_descriptor { port_id, .. }, _marker) = self;
111 *port_id = sys::buffer_descriptor_port_id { base, salt };
112 }
113
114 fn initialize(&mut self, chain_len: ChainLength, head_len: u16, data_len: u32, tail_len: u16) {
116 let Self(
117 sys::buffer_descriptor {
118 frame_type,
119 chain_length,
120 nxt: _,
123 info_type,
124 port_id: sys::buffer_descriptor_port_id { base, salt },
125 _reserved: _,
127 client_opaque_data: _,
129 offset: _,
132 head_length,
133 tail_length,
134 data_length,
135 inbound_flags,
136 return_flags,
137 },
138 _marker,
139 ) = self;
140 *frame_type = 0;
141 *chain_length = chain_len.get();
142 *info_type = 0;
143 *base = 0;
144 *salt = 0;
145 *head_length = head_len;
146 *tail_length = tail_len;
147 *data_length = data_len;
148 *inbound_flags = 0;
149 *return_flags = 0;
150 }
151}
152
153impl Descriptor<Rx> {
154 fn rx_flags(&self) -> Result<netdev::RxFlags> {
155 let Self(this, _marker) = self;
156 let bits = this.inbound_flags;
157 netdev::RxFlags::from_bits(bits).ok_or(Error::RxFlags(bits))
158 }
159}
160
161impl Descriptor<Tx> {
162 fn set_tx_flags(&mut self, flags: netdev::TxFlags) {
163 let Self(this, _marker) = self;
164 let bits = flags.bits();
165 this.return_flags = bits;
166 }
167
168 fn set_frame_type(&mut self, frame_type: netdev::FrameType) {
169 let Self(this, _marker) = self;
170 this.frame_type = frame_type.into_primitive();
171 }
172
173 fn commit(&mut self, used: u32) {
179 let Self(this, _marker) = self;
180 let total = this.data_length + u32::from(this.tail_length);
183 let tail = total.checked_sub(used).unwrap();
184 this.data_length = used;
185 this.tail_length = u16::try_from(tail).unwrap();
186 }
187}
188
/// The descriptor region, mapped from the VMO returned by `Descriptors::new`.
///
/// The mapping is created in `Descriptors::new` and removed when this value is
/// dropped.
struct Descriptors {
    /// Start of the mapping; `count` descriptors are laid out contiguously from here.
    ptr: NonNull<sys::buffer_descriptor>,
    /// Total number of descriptors (Tx plus Rx) in the mapping.
    count: u16,
}
197
198impl Descriptors {
199 fn new(
207 num_tx: NonZeroU16,
208 num_rx: NonZeroU16,
209 buffer_stride: NonZeroU64,
210 ) -> Result<(Self, zx::Vmo, Vec<DescId<Tx>>, Vec<DescId<Rx>>)> {
211 let total = num_tx.get() + num_rx.get();
212 let size = u64::try_from(NETWORK_DEVICE_DESCRIPTOR_LENGTH * usize::from(total))
213 .expect("vmo_size overflows u64");
214 let vmo = zx::Vmo::create(size).map_err(|status| Error::Vmo("descriptors", status))?;
215
216 const VMO_NAME: zx::Name =
217 const_unwrap::const_unwrap_result(zx::Name::new("netdevice:descriptors"));
218 vmo.set_name(&VMO_NAME).map_err(|status| Error::Vmo("set name", status))?;
219 let ptr = NonNull::new(
223 vmar_root_self()
224 .map(
225 0,
226 &vmo,
227 0,
228 usize::try_from(size).unwrap(),
229 zx::VmarFlags::PERM_WRITE | zx::VmarFlags::PERM_READ,
230 )
231 .map_err(|status| Error::Map("descriptors", status))?
232 as *mut sys::buffer_descriptor,
233 )
234 .unwrap();
235
236 let mut tx =
240 (0..num_tx.get()).map(|x| unsafe { DescId::<Tx>::from_raw(x) }).collect::<Vec<_>>();
241 let mut rx =
242 (num_tx.get()..total).map(|x| unsafe { DescId::<Rx>::from_raw(x) }).collect::<Vec<_>>();
243 let descriptors = Self { ptr, count: total };
244 fn init_offset<K: AllocKind>(
245 descriptors: &Descriptors,
246 desc: &mut DescId<K>,
247 buffer_stride: NonZeroU64,
248 ) {
249 let offset = buffer_stride.get().checked_mul(u64::from(desc.get())).unwrap();
250 descriptors.borrow_mut(desc).set_offset(offset);
251 }
252 tx.iter_mut().for_each(|desc| init_offset(&descriptors, desc, buffer_stride));
253 rx.iter_mut().for_each(|desc| init_offset(&descriptors, desc, buffer_stride));
254 Ok((descriptors, vmo, tx, rx))
255 }
256
257 fn borrow<'a, 'b: 'a, K: AllocKind>(&'b self, id: &'a DescId<K>) -> DescRef<'a, K> {
265 assert!(
266 id.get() < self.count,
267 "descriptor index out of range: {} >= {}",
268 id.get(),
269 self.count
270 );
271 unsafe { DescRef::new(self.ptr.as_ptr().add(id.get().into())) }
272 }
273
274 fn borrow_mut<'a, 'b: 'a, K: AllocKind>(&'b self, id: &'a mut DescId<K>) -> DescRefMut<'a, K> {
282 assert!(
283 id.get() < self.count,
284 "descriptor index out of range: {} >= {}",
285 id.get(),
286 self.count
287 );
288 unsafe { DescRefMut::new(self.ptr.as_ptr().add(id.get().into())) }
289 }
290
291 fn chain<K: AllocKind>(&self, head: DescId<K>) -> impl Iterator<Item = Result<DescId<K>>> + '_ {
296 iter::successors(
297 Some(Ok(head)),
298 move |curr: &Result<DescId<K>>| -> Option<Result<DescId<K>>> {
299 match curr {
300 Err(_err) => None,
301 Ok(curr) => {
302 let descriptor = self.borrow(curr);
303 match descriptor.chain_length() {
304 Err(e) => Some(Err(e)),
305 Ok(len) => {
306 if len == ChainLength::ZERO {
307 None
308 } else {
309 descriptor.nxt().map(|id| Ok(unsafe { DescId::from_raw(id) }))
314 }
315 }
316 }
317 }
318 }
319 },
320 )
321 }
322}
323
// SAFETY: `Descriptors` only holds a raw pointer into a mapping that it owns
// for its whole lifetime (created in `Descriptors::new`, unmapped in `Drop`),
// so moving it to another thread is sound.
unsafe impl Send for Descriptors {}
// SAFETY: shared access from several threads goes through `borrow` /
// `borrow_mut`. Those hand out `DescRef` / `DescRefMut`, which arbitrate
// shared vs. exclusive access per descriptor with an atomic reference count.
unsafe impl Sync for Descriptors {}
328
329impl Drop for Descriptors {
330 fn drop(&mut self) {
331 let len = NETWORK_DEVICE_DESCRIPTOR_LENGTH * usize::from(self.count);
334 let page_size = usize::try_from(zx::system_get_page_size()).unwrap();
335 let aligned = (len + page_size - 1) / page_size * page_size;
336 unsafe {
337 vmar_root_self()
338 .unmap(self.ptr.as_ptr() as usize, aligned)
339 .expect("failed to unmap VMO")
340 }
341 }
342}
343
344unsafe fn ref_count<'a>(ptr: *const sys::buffer_descriptor) -> &'a AtomicU8 {
369 const_assert_eq!(std::mem::align_of::<AtomicU8>(), std::mem::align_of::<u8>());
373 unsafe { &*(&((*ptr).client_opaque_data[0]) as *const u8 as *const AtomicU8) }
374}
375
/// Borrow-counter value meaning the descriptor is not borrowed at all.
const DESC_REF_UNUSED: u8 = u8::MIN;
/// Borrow-counter value reserved to mark an exclusive (mutable) borrow.
/// Shared-borrow counts must stay strictly below it.
const DESC_REF_EXCLUSIVE: u8 = u8::MAX;
382
383struct DescRef<'a, K: AllocKind> {
385 ptr: &'a Descriptor<K>,
386}
387
388impl<K: AllocKind> DescRef<'_, K> {
389 unsafe fn new(ptr: *const sys::buffer_descriptor) -> Self {
403 let ref_cnt = unsafe { ref_count(ptr) };
404 let prev = ref_cnt.fetch_add(1, Ordering::AcqRel);
405 if prev == DESC_REF_EXCLUSIVE {
406 panic!("trying to create a shared reference when there is already a mutable reference");
407 }
408 if prev + 1 == DESC_REF_EXCLUSIVE {
409 panic!("there are too many shared references")
410 }
411 Self { ptr: unsafe { &*(ptr as *const Descriptor<K>) } }
412 }
413}
414
415impl<K: AllocKind> Drop for DescRef<'_, K> {
416 fn drop(&mut self) {
417 let ref_cnt = unsafe { ref_count(&self.ptr.0 as *const _) };
418 let prev = ref_cnt.fetch_sub(1, Ordering::AcqRel);
419 assert!(prev != DESC_REF_EXCLUSIVE && prev != DESC_REF_UNUSED);
420 }
421}
422
423impl<K: AllocKind> Deref for DescRef<'_, K> {
424 type Target = Descriptor<K>;
425
426 fn deref(&self) -> &Self::Target {
427 self.ptr
428 }
429}
430
431struct DescRefMut<'a, K: AllocKind> {
433 ptr: &'a mut Descriptor<K>,
434}
435
436impl<K: AllocKind> DescRefMut<'_, K> {
437 unsafe fn new(ptr: *mut sys::buffer_descriptor) -> Self {
450 let ref_cnt = unsafe { ref_count(ptr) };
451 if let Err(prev) = ref_cnt.compare_exchange(
452 DESC_REF_UNUSED,
453 DESC_REF_EXCLUSIVE,
454 Ordering::AcqRel,
455 Ordering::Acquire,
456 ) {
457 panic!(
458 "trying to create an exclusive reference when there are other references: {}",
459 prev
460 );
461 }
462 Self { ptr: unsafe { &mut *(ptr as *mut Descriptor<K>) } }
463 }
464}
465
466impl<K: AllocKind> Drop for DescRefMut<'_, K> {
467 fn drop(&mut self) {
468 let ref_cnt = unsafe { ref_count(&self.ptr.0 as *const _) };
469 if let Err(prev) = ref_cnt.compare_exchange(
470 DESC_REF_EXCLUSIVE,
471 DESC_REF_UNUSED,
472 Ordering::AcqRel,
473 Ordering::Acquire,
474 ) {
475 panic!(
476 "we have a mutable reference while the descriptor is not exclusively borrowed: {}",
477 prev
478 );
479 }
480 }
481}
482
483impl<K: AllocKind> Deref for DescRefMut<'_, K> {
484 type Target = Descriptor<K>;
485
486 fn deref(&self) -> &Self::Target {
487 self.ptr
488 }
489}
490
491impl<K: AllocKind> DerefMut for DescRefMut<'_, K> {
492 fn deref_mut(&mut self) -> &mut Self::Target {
493 &mut *self.ptr
494 }
495}
496
497mod types {
500 use super::{AllocKind, Error, Result, netdev};
501 use std::fmt::Debug;
502 use std::num::TryFromIntError;
503 use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
504
505 #[derive(PartialEq, Eq, KnownLayout, FromBytes, IntoBytes, Immutable)]
517 #[repr(transparent)]
518 pub(in crate::session) struct DescId<K: AllocKind>(u16, std::marker::PhantomData<K>);
519
520 impl<K: AllocKind> Debug for DescId<K> {
521 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
522 let Self(id, _marker) = self;
523 f.debug_tuple(K::REFL.as_str()).field(id).finish()
524 }
525 }
526
527 pub(super) const DESCID_NO_NEXT: u16 = u16::MAX;
532
533 impl<K: AllocKind> DescId<K> {
534 pub(super) unsafe fn from_raw(id: u16) -> Self {
537 assert_ne!(id, DESCID_NO_NEXT);
538 Self(id, std::marker::PhantomData)
539 }
540
541 pub(super) fn get(&self) -> u16 {
542 let Self(id, _marker) = self;
543 *id
544 }
545 }
546
547 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
549 pub(in crate::session) struct ChainLength(u8);
550
551 impl TryFrom<u8> for ChainLength {
552 type Error = Error;
553
554 fn try_from(value: u8) -> Result<Self> {
555 if value > netdev::MAX_DESCRIPTOR_CHAIN {
556 return Err(Error::LargeChain(value.into()));
557 }
558 Ok(ChainLength(value))
559 }
560 }
561
562 impl TryFrom<usize> for ChainLength {
563 type Error = Error;
564
565 fn try_from(value: usize) -> Result<Self> {
566 let value =
567 u8::try_from(value).map_err(|TryFromIntError { .. }| Error::LargeChain(value))?;
568 value.try_into()
569 }
570 }
571
572 impl From<ChainLength> for usize {
573 fn from(ChainLength(len): ChainLength) -> Self {
574 len.into()
575 }
576 }
577
578 impl ChainLength {
579 pub(super) const ZERO: Self = Self(0);
580
581 pub(super) fn get(&self) -> u8 {
582 let ChainLength(len) = self;
583 *len
584 }
585 }
586}
587
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    const TX_BUFFERS: NonZeroU16 = NonZeroU16::new(1).unwrap();
    const RX_BUFFERS: NonZeroU16 = NonZeroU16::new(2).unwrap();
    const BUFFER_STRIDE: NonZeroU64 = NonZeroU64::new(4).unwrap();

    #[test]
    fn get_descriptor_after_vmo_write() {
        let (descriptors, vmo, tx, rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        // Writing at VMO offset 0 must be visible as descriptor 0's frame type.
        vmo.write(&[netdev::FrameType::Ethernet.into_primitive()][..], 0).expect("vmo write");
        assert_eq!(tx.len(), TX_BUFFERS.get().into());
        assert_eq!(rx.len(), RX_BUFFERS.get().into());
        assert_eq!(
            descriptors.borrow(&tx[0]).frame_type().expect("failed to get frame type"),
            netdev::FrameType::Ethernet
        );
    }

    #[test]
    fn init_descriptor() {
        const HEAD_LEN: u16 = 1;
        const DATA_LEN: u32 = 2;
        const TAIL_LEN: u16 = 3;
        let (descriptors, _vmo, mut tx, _rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        {
            let mut descriptor = descriptors.borrow_mut(&mut tx[0]);
            descriptor.initialize(ChainLength::ZERO, HEAD_LEN, DATA_LEN, TAIL_LEN);
        }

        let got = descriptors.borrow(&tx[0]);
        assert_eq!(got.chain_length().unwrap(), ChainLength::ZERO);
        assert_eq!(got.offset(), 0);
        assert_eq!(got.head_length(), HEAD_LEN);
        assert_eq!(got.data_length(), DATA_LEN);
        assert_eq!(got.tail_length(), TAIL_LEN);
    }

    #[test]
    fn descriptor_offsets_follow_stride() {
        let (descriptors, _vmo, tx, rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        for id in &tx {
            assert_eq!(descriptors.borrow(id).offset(), BUFFER_STRIDE.get() * u64::from(id.get()));
        }
        for id in &rx {
            assert_eq!(descriptors.borrow(id).offset(), BUFFER_STRIDE.get() * u64::from(id.get()));
        }
    }

    #[test]
    fn commit_moves_unused_space_to_tail() {
        let (descriptors, _vmo, mut tx, _rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        let mut descriptor = descriptors.borrow_mut(&mut tx[0]);
        descriptor.initialize(ChainLength::ZERO, 0, 10, 5);
        descriptor.commit(7);
        assert_eq!(descriptor.data_length(), 7);
        assert_eq!(descriptor.tail_length(), 8);
    }

    #[test]
    fn chain_follows_nxt_until_zero_length() {
        let (descriptors, _vmo, mut tx, _rx) =
            Descriptors::new(NonZeroU16::new(2).unwrap(), RX_BUFFERS, BUFFER_STRIDE)
                .expect("create descriptors");
        let mut tail = tx.pop().expect("tail descriptor");
        let mut head = tx.pop().expect("head descriptor");
        let (head_raw, tail_raw) = (head.get(), tail.get());
        descriptors.borrow_mut(&mut tail).initialize(ChainLength::ZERO, 0, 0, 0);
        {
            let mut descriptor = descriptors.borrow_mut(&mut head);
            descriptor.initialize(
                ChainLength::try_from(1u8).expect("valid chain length"),
                0,
                0,
                0,
            );
            descriptor.set_nxt(Some(tail));
        }
        let got = descriptors
            .chain(head)
            .map(|id| id.expect("valid chain").get())
            .collect::<Vec<_>>();
        assert_eq!(got, [head_raw, tail_raw]);
    }

    #[test]
    fn chain_length() {
        for raw in 0..=netdev::MAX_DESCRIPTOR_CHAIN {
            let got = ChainLength::try_from(raw)
                .expect("the conversion should succeed with length <= MAX_DESCRIPTOR_CHAIN");
            assert_eq!(got.get(), raw);
        }

        // Inclusive upper bound so that u8::MAX itself is also checked.
        for raw in netdev::MAX_DESCRIPTOR_CHAIN + 1..=u8::MAX {
            assert_matches!(
                ChainLength::try_from(raw)
                    .expect_err("the conversion should fail with length > MAX_DESCRIPTOR_CHAIN"),
                Error::LargeChain(len) if len == raw.into()
            );
        }
    }
}