pub(super) mod pool;
pub mod sys;

use std::iter;
use std::num::{NonZeroU16, NonZeroU64};
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU8, Ordering};

use fidl_fuchsia_hardware_network as netdev;
use fuchsia_runtime::vmar_root_self;
use zx::sys::ZX_MIN_PAGE_SHIFT;

use static_assertions::{const_assert, const_assert_eq};

use crate::error::{Error, Result};
use crate::session::Port;
use types::{ChainLength, DESCID_NO_NEXT};

pub use pool::{AllocKind, Buffer, Rx, Tx};
pub const NETWORK_DEVICE_DESCRIPTOR_VERSION: u8 = sys::__NETWORK_DEVICE_DESCRIPTOR_VERSION as u8;
pub(super) use types::DescId;
pub(super) const NETWORK_DEVICE_DESCRIPTOR_LENGTH: usize =
    std::mem::size_of::<sys::buffer_descriptor>();

const_assert_eq!(NETWORK_DEVICE_DESCRIPTOR_LENGTH % std::mem::size_of::<u64>(), 0);
const_assert!(
    std::mem::align_of::<Descriptor<Tx>>().count_ones() == 1
        && std::mem::align_of::<Descriptor<Tx>>() <= (1 << ZX_MIN_PAGE_SHIFT)
);
const_assert!(
    std::mem::align_of::<Descriptor<Rx>>().count_ones() == 1
        && std::mem::align_of::<Descriptor<Rx>>() <= (1 << ZX_MIN_PAGE_SHIFT)
);

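/// A typed view of a `sys::buffer_descriptor`.
///
/// The `K` parameter tags a descriptor as belonging to the tx or rx half of
/// the session so the two cannot be mixed up at compile time;
/// `#[repr(transparent)]` keeps the layout identical to the raw struct.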
#[repr(transparent)]
struct Descriptor<K: AllocKind>(sys::buffer_descriptor, std::marker::PhantomData<K>);

impl<K: AllocKind> Descriptor<K> {
    fn frame_type(&self) -> Result<netdev::FrameType> {
        let Self(this, _marker) = self;
        let prim = this.frame_type;
        netdev::FrameType::from_primitive(prim).ok_or(Error::FrameType(prim))
    }

    fn chain_length(&self) -> Result<ChainLength> {
        let Self(this, _marker) = self;
        ChainLength::try_from(this.chain_length)
    }

    fn nxt(&self) -> Option<u16> {
        let Self(this, _marker) = self;
        if this.nxt == DESCID_NO_NEXT {
            None
        } else {
            Some(this.nxt)
        }
    }

    fn set_nxt(&mut self, desc: Option<DescId<K>>) {
        let Self(this, _marker) = self;
        this.nxt = desc.as_ref().map(DescId::get).unwrap_or(DESCID_NO_NEXT);
    }

    fn offset(&self) -> u64 {
        let Self(this, _marker) = self;
        this.offset
    }

    fn set_offset(&mut self, offset: u64) {
        let Self(this, _marker) = self;
        this.offset = offset;
    }

    fn head_length(&self) -> u16 {
        let Self(this, _marker) = self;
        this.head_length
    }

    fn data_length(&self) -> u32 {
        let Self(this, _marker) = self;
        this.data_length
    }

    fn tail_length(&self) -> u16 {
        let Self(this, _marker) = self;
        this.tail_length
    }

    fn port(&self) -> Port {
        let Self(
            sys::buffer_descriptor {
                port_id: sys::buffer_descriptor_port_id { base, salt }, ..
            },
            _marker,
        ) = self;
        Port { base: *base, salt: *salt }
    }

    fn set_port(&mut self, Port { base, salt }: Port) {
        let Self(sys::buffer_descriptor { port_id, .. }, _marker) = self;
        *port_id = sys::buffer_descriptor_port_id { base, salt };
    }

    fn initialize(&mut self, chain_len: ChainLength, head_len: u16, data_len: u32, tail_len: u16) {
        let Self(
            sys::buffer_descriptor {
                frame_type,
                chain_length,
                nxt: _,
                info_type,
                port_id: sys::buffer_descriptor_port_id { base, salt },
                _reserved: _,
                client_opaque_data: _,
                offset: _,
                head_length,
                tail_length,
                data_length,
                inbound_flags,
                return_flags,
            },
            _marker,
        ) = self;
        *frame_type = 0;
        *chain_length = chain_len.get();
        *info_type = 0;
        *base = 0;
        *salt = 0;
        *head_length = head_len;
        *tail_length = tail_len;
        *data_length = data_len;
        *inbound_flags = 0;
        *return_flags = 0;
    }
}

impl Descriptor<Rx> {
    fn rx_flags(&self) -> Result<netdev::RxFlags> {
        let Self(this, _marker) = self;
        let bits = this.inbound_flags;
        netdev::RxFlags::from_bits(bits).ok_or(Error::RxFlags(bits))
    }
}

impl Descriptor<Tx> {
    fn set_tx_flags(&mut self, flags: netdev::TxFlags) {
        let Self(this, _marker) = self;
        let bits = flags.bits();
        this.return_flags = bits;
    }

    fn set_frame_type(&mut self, frame_type: netdev::FrameType) {
        let Self(this, _marker) = self;
        this.frame_type = frame_type.into_primitive();
    }

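    /// Commits `used` bytes for transmission: the descriptor's data length
    /// becomes `used` and the unused remainder is folded into the tail.
    /// Panics if `used` exceeds the available space or the new tail does not
    /// fit in a `u16`.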
    fn commit(&mut self, used: u32) {
        let Self(this, _marker) = self;
        let total = this.data_length + u32::from(this.tail_length);
        let tail = total.checked_sub(used).unwrap();
        this.data_length = used;
        this.tail_length = u16::try_from(tail).unwrap();
    }
}

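/// The descriptor region of a session: a VMO shared with the device, mapped
/// into our address space. `ptr` is the base of the mapping and `count` is
/// the total number of tx plus rx descriptors it holds.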
struct Descriptors {
    ptr: NonNull<sys::buffer_descriptor>,
    count: u16,
}

impl Descriptors {
    fn new(
        num_tx: NonZeroU16,
        num_rx: NonZeroU16,
        buffer_stride: NonZeroU64,
    ) -> Result<(Self, zx::Vmo, Vec<DescId<Tx>>, Vec<DescId<Rx>>)> {
        let total = num_tx.get() + num_rx.get();
        let size = u64::try_from(NETWORK_DEVICE_DESCRIPTOR_LENGTH * usize::from(total))
            .expect("vmo_size overflows u64");
        let vmo = zx::Vmo::create(size).map_err(|status| Error::Vmo("descriptors", status))?;
        let ptr = NonNull::new(
            vmar_root_self()
                .map(
                    0,
                    &vmo,
                    0,
                    usize::try_from(size).unwrap(),
                    zx::VmarFlags::PERM_WRITE | zx::VmarFlags::PERM_READ,
                )
                .map_err(|status| Error::Map("descriptors", status))?
                as *mut sys::buffer_descriptor,
        )
        .unwrap();

        let mut tx =
            (0..num_tx.get()).map(|x| unsafe { DescId::<Tx>::from_raw(x) }).collect::<Vec<_>>();
        let mut rx =
            (num_tx.get()..total).map(|x| unsafe { DescId::<Rx>::from_raw(x) }).collect::<Vec<_>>();
        let descriptors = Self { ptr, count: total };
        fn init_offset<K: AllocKind>(
            descriptors: &Descriptors,
            desc: &mut DescId<K>,
            buffer_stride: NonZeroU64,
        ) {
            let offset = buffer_stride.get().checked_mul(u64::from(desc.get())).unwrap();
            descriptors.borrow_mut(desc).set_offset(offset);
        }
        tx.iter_mut().for_each(|desc| init_offset(&descriptors, desc, buffer_stride));
        rx.iter_mut().for_each(|desc| init_offset(&descriptors, desc, buffer_stride));
        Ok((descriptors, vmo, tx, rx))
    }

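    /// Shared-borrows the descriptor identified by `id`.
    ///
    /// Panics if `id` is out of range or the descriptor is already
    /// exclusively borrowed.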
    fn borrow<'a, 'b: 'a, K: AllocKind>(&'b self, id: &'a DescId<K>) -> DescRef<'a, K> {
        assert!(
            id.get() < self.count,
            "descriptor index out of range: {} >= {}",
            id.get(),
            self.count
        );
        unsafe { DescRef::new(self.ptr.as_ptr().add(id.get().into())) }
    }

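    /// Exclusively borrows the descriptor identified by `id`.
    ///
    /// Panics if `id` is out of range or the descriptor is borrowed at all.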
    fn borrow_mut<'a, 'b: 'a, K: AllocKind>(&'b self, id: &'a mut DescId<K>) -> DescRefMut<'a, K> {
        assert!(
            id.get() < self.count,
            "descriptor index out of range: {} >= {}",
            id.get(),
            self.count
        );
        unsafe { DescRefMut::new(self.ptr.as_ptr().add(id.get().into())) }
    }

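    /// Iterates over the chain of descriptors starting at `head`, following
    /// each descriptor's `nxt` field until one reports a chain length of
    /// zero; a descriptor with an invalid chain length yields an error and
    /// ends the iteration.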
    fn chain<K: AllocKind>(&self, head: DescId<K>) -> impl Iterator<Item = Result<DescId<K>>> + '_ {
        iter::successors(
            Some(Ok(head)),
            move |curr: &Result<DescId<K>>| -> Option<Result<DescId<K>>> {
                match curr {
                    Err(_err) => None,
                    Ok(curr) => {
                        let descriptor = self.borrow(curr);
                        match descriptor.chain_length() {
                            Err(e) => Some(Err(e)),
                            Ok(len) => {
                                if len == ChainLength::ZERO {
                                    None
                                } else {
                                    descriptor.nxt().map(|id| Ok(unsafe { DescId::from_raw(id) }))
                                }
                            }
                        }
                    }
                }
            },
        )
    }
}

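// SAFETY: `Descriptors` owns the mapping for its entire lifetime, and all
// access to the shared memory is mediated by the per-descriptor atomic
// reference count (see `ref_count` below), so the type can be sent and
// shared across threads.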
unsafe impl Send for Descriptors {}
unsafe impl Sync for Descriptors {}

impl Drop for Descriptors {
    fn drop(&mut self) {
        let len = NETWORK_DEVICE_DESCRIPTOR_LENGTH * usize::from(self.count);
        let page_size = usize::try_from(zx::system_get_page_size()).unwrap();
        // Unmapping must cover whole pages, so round the mapping length up to
        // a multiple of the page size.
        let aligned = (len + page_size - 1) / page_size * page_size;
        unsafe {
            vmar_root_self()
                .unmap(self.ptr.as_ptr() as usize, aligned)
                .expect("failed to unmap VMO")
        }
    }
}

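/// Reinterprets the first byte of a descriptor's `client_opaque_data` as an
/// atomic reference count, used to enforce borrow semantics at runtime.
///
/// # Safety
///
/// `ptr` must point to a valid, mapped `buffer_descriptor`.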
unsafe fn ref_count<'a>(ptr: *const sys::buffer_descriptor) -> &'a AtomicU8 {
    const_assert_eq!(std::mem::align_of::<AtomicU8>(), std::mem::align_of::<u8>());
    &*(&((*ptr).client_opaque_data[0]) as *const u8 as *const AtomicU8)
}

/// The descriptor is not borrowed at all.
const DESC_REF_UNUSED: u8 = 0;
/// The descriptor is exclusively borrowed.
const DESC_REF_EXCLUSIVE: u8 = u8::MAX;

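/// A shared reference to a descriptor, analogous to `&Descriptor<K>`.
///
/// Creating one increments the descriptor's runtime reference count and
/// dropping it decrements the count, enforcing Rust's aliasing rules at
/// runtime rather than compile time.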
struct DescRef<'a, K: AllocKind> {
    ptr: &'a Descriptor<K>,
}

impl<K: AllocKind> DescRef<'_, K> {
    unsafe fn new(ptr: *const sys::buffer_descriptor) -> Self {
        let ref_cnt = ref_count(ptr);
        let prev = ref_cnt.fetch_add(1, Ordering::AcqRel);
        if prev == DESC_REF_EXCLUSIVE {
            panic!("trying to create a shared reference when there is already a mutable reference");
        }
        if prev + 1 == DESC_REF_EXCLUSIVE {
            panic!("there are too many shared references");
        }
        Self { ptr: &*(ptr as *const Descriptor<K>) }
    }
}

impl<K: AllocKind> Drop for DescRef<'_, K> {
    fn drop(&mut self) {
        let ref_cnt = unsafe { ref_count(&self.ptr.0 as *const _) };
        let prev = ref_cnt.fetch_sub(1, Ordering::AcqRel);
        assert!(prev != DESC_REF_EXCLUSIVE && prev != DESC_REF_UNUSED);
    }
}

impl<K: AllocKind> Deref for DescRef<'_, K> {
    type Target = Descriptor<K>;

    fn deref(&self) -> &Self::Target {
        self.ptr
    }
}

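/// An exclusive reference to a descriptor, analogous to `&mut Descriptor<K>`.
///
/// Creating one transitions the runtime reference count from unused to
/// exclusive and dropping it transitions the count back, panicking if any
/// other reference exists at either point.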
struct DescRefMut<'a, K: AllocKind> {
    ptr: &'a mut Descriptor<K>,
}

impl<K: AllocKind> DescRefMut<'_, K> {
    unsafe fn new(ptr: *mut sys::buffer_descriptor) -> Self {
        let ref_cnt = ref_count(ptr);
        if let Err(prev) = ref_cnt.compare_exchange(
            DESC_REF_UNUSED,
            DESC_REF_EXCLUSIVE,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            panic!(
                "trying to create an exclusive reference when there are other references: {}",
                prev
            );
        }
        Self { ptr: &mut *(ptr as *mut Descriptor<K>) }
    }
}

impl<K: AllocKind> Drop for DescRefMut<'_, K> {
    fn drop(&mut self) {
        let ref_cnt = unsafe { ref_count(&self.ptr.0 as *const _) };
        if let Err(prev) = ref_cnt.compare_exchange(
            DESC_REF_EXCLUSIVE,
            DESC_REF_UNUSED,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            panic!(
                "we have a mutable reference while the descriptor is not exclusively borrowed: {}",
                prev
            );
        }
    }
}

impl<K: AllocKind> Deref for DescRefMut<'_, K> {
    type Target = Descriptor<K>;

    fn deref(&self) -> &Self::Target {
        self.ptr
    }
}

impl<K: AllocKind> DerefMut for DescRefMut<'_, K> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.ptr
    }
}

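/// Newtypes that make raw descriptor indices and chain lengths harder to
/// misuse.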
mod types {
    use super::{netdev, AllocKind, Error, Result};
    use std::fmt::Debug;
    use std::num::TryFromIntError;
    use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};

    /// A descriptor index, tagged with the allocation kind it belongs to.
    #[derive(PartialEq, Eq, KnownLayout, FromBytes, IntoBytes, Immutable)]
    #[repr(transparent)]
    pub(in crate::session) struct DescId<K: AllocKind>(u16, std::marker::PhantomData<K>);

    impl<K: AllocKind> Debug for DescId<K> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let Self(id, _marker) = self;
            f.debug_tuple(K::REFL.as_str()).field(id).finish()
        }
    }

    /// The sentinel stored in a descriptor's `nxt` field when there is no
    /// next descriptor in the chain; it is therefore never a valid `DescId`.
    pub(super) const DESCID_NO_NEXT: u16 = u16::MAX;

    impl<K: AllocKind> DescId<K> {
        /// Creates a `DescId` from a raw index. The caller must guarantee
        /// that `id` identifies a descriptor of allocation kind `K`.
        pub(super) unsafe fn from_raw(id: u16) -> Self {
            assert_ne!(id, DESCID_NO_NEXT);
            Self(id, std::marker::PhantomData)
        }

        pub(super) fn get(&self) -> u16 {
            let Self(id, _marker) = self;
            *id
        }
    }

    /// The length of a descriptor chain, guaranteed to be at most
    /// `netdev::MAX_DESCRIPTOR_CHAIN`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
    pub(in crate::session) struct ChainLength(u8);

    impl TryFrom<u8> for ChainLength {
        type Error = Error;

        fn try_from(value: u8) -> Result<Self> {
            if value > netdev::MAX_DESCRIPTOR_CHAIN {
                return Err(Error::LargeChain(value.into()));
            }
            Ok(ChainLength(value))
        }
    }

    impl TryFrom<usize> for ChainLength {
        type Error = Error;

        fn try_from(value: usize) -> Result<Self> {
            let value =
                u8::try_from(value).map_err(|TryFromIntError { .. }| Error::LargeChain(value))?;
            value.try_into()
        }
    }

    impl From<ChainLength> for usize {
        fn from(ChainLength(len): ChainLength) -> Self {
            len.into()
        }
    }

    impl ChainLength {
        pub(super) const ZERO: Self = Self(0);

        pub(super) fn get(&self) -> u8 {
            let ChainLength(len) = self;
            *len
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    const TX_BUFFERS: NonZeroU16 = NonZeroU16::new(1).unwrap();
    const RX_BUFFERS: NonZeroU16 = NonZeroU16::new(2).unwrap();
    const BUFFER_STRIDE: NonZeroU64 = NonZeroU64::new(4).unwrap();

    #[test]
    fn get_descriptor_after_vmo_write() {
        let (descriptors, vmo, tx, rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        vmo.write(&[netdev::FrameType::Ethernet.into_primitive()][..], 0).expect("vmo write");
        assert_eq!(tx.len(), TX_BUFFERS.get().into());
        assert_eq!(rx.len(), RX_BUFFERS.get().into());
        assert_eq!(
            descriptors.borrow(&tx[0]).frame_type().expect("failed to get frame type"),
            netdev::FrameType::Ethernet
        );
    }

    #[test]
    fn init_descriptor() {
        const HEAD_LEN: u16 = 1;
        const DATA_LEN: u32 = 2;
        const TAIL_LEN: u16 = 3;
        let (descriptors, _vmo, mut tx, _rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        {
            let mut descriptor = descriptors.borrow_mut(&mut tx[0]);
            descriptor.initialize(ChainLength::ZERO, HEAD_LEN, DATA_LEN, TAIL_LEN);
        }

        let got = descriptors.borrow(&tx[0]);
        assert_eq!(got.chain_length().unwrap(), ChainLength::ZERO);
        assert_eq!(got.offset(), 0);
        assert_eq!(got.head_length(), HEAD_LEN);
        assert_eq!(got.data_length(), DATA_LEN);
        assert_eq!(got.tail_length(), TAIL_LEN);
    }
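
    // A minimal sketch (not from the original source) exercising
    // `Descriptor::commit`: committing `used` bytes shrinks the data region
    // and folds the unused remainder into the tail.
    #[test]
    fn commit_splits_data_and_tail() {
        let (descriptors, _vmo, mut tx, _rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        {
            let mut descriptor = descriptors.borrow_mut(&mut tx[0]);
            descriptor.initialize(ChainLength::ZERO, 0, 8, 0);
            // 5 of the 8 data bytes are used; the other 3 become tail.
            descriptor.commit(5);
        }
        let got = descriptors.borrow(&tx[0]);
        assert_eq!(got.data_length(), 5);
        assert_eq!(got.tail_length(), 3);
    }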

    #[test]
    fn chain_length() {
        for raw in 0..=netdev::MAX_DESCRIPTOR_CHAIN {
            let got = ChainLength::try_from(raw)
                .expect("the conversion should succeed with length <= MAX_DESCRIPTOR_CHAIN");
            assert_eq!(got.get(), raw);
        }

        for raw in netdev::MAX_DESCRIPTOR_CHAIN + 1..u8::MAX {
            assert_matches!(
                ChainLength::try_from(raw)
                    .expect_err("the conversion should fail with length > MAX_DESCRIPTOR_CHAIN"),
                Error::LargeChain(len) if len == raw.into()
            );
        }
    }
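
    // A minimal sketch (not from the original source) of walking a chain
    // with `Descriptors::chain`: a descriptor initialized with
    // `ChainLength::ZERO` forms a chain consisting of just itself.
    #[test]
    fn chain_of_single_descriptor() {
        let (descriptors, _vmo, mut tx, _rx) =
            Descriptors::new(TX_BUFFERS, RX_BUFFERS, BUFFER_STRIDE).expect("create descriptors");
        {
            let mut descriptor = descriptors.borrow_mut(&mut tx[0]);
            descriptor.initialize(ChainLength::ZERO, 0, 0, 0);
        }
        let head = tx.remove(0);
        let chain = descriptors
            .chain(head)
            .collect::<Result<Vec<_>>>()
            .expect("walking the chain should not fail");
        assert_eq!(chain.len(), 1);
    }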
}