1use core::fmt::Debug;
8use core::hash::Hash;
9use core::num::NonZeroU16;
10
11use derivative::Derivative;
12use lock_order::lock::{OrderedLockAccess, OrderedLockRef};
13use net_types::ethernet::Mac;
14use net_types::ip::IpVersion;
15use netstack3_base::socket::SocketCookie;
16use netstack3_base::sync::{Mutex, PrimaryRc, RwLock, StrongRc, WeakRc};
17use netstack3_base::{
18 AnyDevice, ContextPair, Counter, Device, DeviceIdContext, FrameDestination, Inspectable,
19 Inspector, InspectorDeviceExt, InspectorExt, NetworkSerializer, ReferenceNotifiers,
20 ReferenceNotifiersExt as _, RemoveResourceResultWithContext, ResourceCounterContext,
21 SendFrameContext, SendFrameErrorReason, StrongDeviceIdentifier, WeakDeviceIdentifier as _,
22};
23use netstack3_hashmap::{HashMap, HashSet};
24use packet::{BufferMut, ParsablePacket as _};
25use packet_formats::error::ParseError;
26use packet_formats::ethernet::{EtherType, EthernetFrameLengthCheck};
27
28use crate::internal::base::DeviceLayerTypes;
29use crate::internal::id::WeakDeviceId;
30
/// The protocol filter a device socket is bound to.
///
/// A socket bound to [`Protocol::All`] is offered every frame, while a socket
/// bound to [`Protocol::Specific`] is only offered received frames whose
/// EtherType equals the contained value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// Match frames of every protocol, sent or received.
    All,
    /// Match received frames carrying exactly this EtherType.
    Specific(NonZeroU16),
}
39
40#[derive(Clone, Debug, Derivative, Eq, Hash, PartialEq)]
42#[derivative(Default(bound = ""))]
43pub enum TargetDevice<D> {
44 #[derivative(Default)]
46 AnyDevice,
47 SpecificDevice(D),
49}
50
51#[derive(Debug)]
53#[cfg_attr(test, derive(PartialEq))]
54pub struct SocketInfo<D> {
55 pub protocol: Option<Protocol>,
57 pub device: TargetDevice<D>,
59}
60
/// Types that bindings supply for device sockets.
pub trait DeviceSocketTypes {
    /// Bindings-owned state stored with every socket, generic over the device
    /// ID type `D` the socket is parameterized with.
    type SocketState<D: Send + Sync + Debug>: Send + Sync + Debug;
}
67
/// Error reported by bindings when a frame cannot be delivered to a socket.
///
/// Derives the usual value traits so callers and tests can compare, copy, and
/// `Debug`-print the error, as is expected of a public error type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReceiveFrameError {
    /// The socket's receive queue is full, so the frame was not queued.
    QueueFull,
}
73
74pub trait DeviceSocketBindingsContext<DeviceId: StrongDeviceIdentifier>:
76 DeviceSocketTypes + Sized
77{
78 fn receive_frame(
82 &self,
83 socket_id: &DeviceSocketId<DeviceId::Weak, Self>,
84 device: &DeviceId,
85 frame: Frame<&[u8]>,
86 raw_frame: &[u8],
87 ) -> Result<(), ReceiveFrameError>;
88}
89
90#[derive(Debug)]
94pub struct PrimaryDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
95 PrimaryRc<SocketState<D, BT>>,
96);
97
98impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> PrimaryDeviceSocketId<D, BT> {
99 fn new(external_state: BT::SocketState<D>) -> Self {
101 Self(PrimaryRc::new(SocketState {
102 external_state,
103 counters: Default::default(),
104 target: Default::default(),
105 }))
106 }
107
108 fn clone_strong(&self) -> DeviceSocketId<D, BT> {
110 let PrimaryDeviceSocketId(rc) = self;
111 DeviceSocketId(PrimaryRc::clone_strong(rc))
112 }
113}
114
115#[derive(Derivative)]
120#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
121pub struct DeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
122 StrongRc<SocketState<D, BT>>,
123);
124
125impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> DeviceSocketId<D, BT> {
126 pub fn socket_cookie(&self) -> SocketCookie {
128 let Self(rc) = self;
129 SocketCookie::new(rc.resource_token())
130 }
131}
132
133impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for DeviceSocketId<D, BT> {
134 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
135 let Self(rc) = self;
136 f.debug_tuple("DeviceSocketId").field(&StrongRc::debug_id(rc)).finish()
137 }
138}
139
140impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<Target<D>>
141 for DeviceSocketId<D, BT>
142{
143 type Lock = Mutex<Target<D>>;
144 fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
145 let Self(rc) = self;
146 OrderedLockRef::new(&rc.target)
147 }
148}
149
150#[derive(Derivative)]
155#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
156pub struct WeakDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
157 WeakRc<SocketState<D, BT>>,
158);
159
160impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for WeakDeviceSocketId<D, BT> {
161 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
162 let Self(rc) = self;
163 f.debug_tuple("WeakDeviceSocketId").field(&WeakRc::debug_id(rc)).finish()
164 }
165}
166
167#[derive(Derivative)]
169#[derivative(Default(bound = ""))]
170pub struct Sockets<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
171 any_device_sockets: RwLock<AnyDeviceSockets<D, BT>>,
174
175 all_sockets: RwLock<AllSockets<D, BT>>,
182}
183
184#[derive(Derivative)]
186#[derivative(Default(bound = ""))]
187pub struct AnyDeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
188 HashSet<DeviceSocketId<D, BT>>,
189);
190
191#[derive(Derivative)]
193#[derivative(Default(bound = ""))]
194pub struct AllSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
195 HashMap<DeviceSocketId<D, BT>, PrimaryDeviceSocketId<D, BT>>,
196);
197
198#[derive(Debug)]
200pub struct SocketState<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
201 pub external_state: BT::SocketState<D>,
203 target: Mutex<Target<D>>,
207 counters: DeviceSocketCounters,
209}
210
211#[derive(Debug, Derivative)]
213#[derivative(Default(bound = ""))]
214pub struct Target<D> {
215 protocol: Option<Protocol>,
216 device: TargetDevice<D>,
217}
218
219#[derive(Derivative)]
224#[derivative(Default(bound = ""))]
225#[cfg_attr(
226 test,
227 derivative(Debug, PartialEq(bound = "BT::SocketState<D>: Hash + Eq, D: Hash + Eq"))
228)]
229pub struct DeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
230 HashSet<DeviceSocketId<D, BT>>,
231);
232
233pub type HeldDeviceSockets<BT> = DeviceSockets<WeakDeviceId<BT>, BT>;
235
236pub type HeldSockets<BT> = Sockets<WeakDeviceId<BT>, BT>;
240
241pub trait DeviceSocketContext<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
243 type SocketTablesCoreCtx<'a>: DeviceSocketAccessor<BT, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>;
245
246 fn with_all_device_sockets<
249 F: FnOnce(&AllSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
250 R,
251 >(
252 &mut self,
253 cb: F,
254 ) -> R;
255
256 fn with_all_device_sockets_mut<F: FnOnce(&mut AllSockets<Self::WeakDeviceId, BT>) -> R, R>(
259 &mut self,
260 cb: F,
261 ) -> R;
262
263 fn with_any_device_sockets<
265 F: FnOnce(&AnyDeviceSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
266 R,
267 >(
268 &mut self,
269 cb: F,
270 ) -> R;
271
272 fn with_any_device_sockets_mut<
274 F: FnOnce(
275 &mut AnyDeviceSockets<Self::WeakDeviceId, BT>,
276 &mut Self::SocketTablesCoreCtx<'_>,
277 ) -> R,
278 R,
279 >(
280 &mut self,
281 cb: F,
282 ) -> R;
283}
284
285pub trait SocketStateAccessor<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
287 fn with_socket_state<F: FnOnce(&Target<Self::WeakDeviceId>) -> R, R>(
289 &mut self,
290 socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
291 cb: F,
292 ) -> R;
293
294 fn with_socket_state_mut<F: FnOnce(&mut Target<Self::WeakDeviceId>) -> R, R>(
296 &mut self,
297 socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
298 cb: F,
299 ) -> R;
300}
301
302pub trait DeviceSocketAccessor<BT: DeviceSocketTypes>: SocketStateAccessor<BT> {
304 type DeviceSocketCoreCtx<'a>: SocketStateAccessor<BT, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>
306 + ResourceCounterContext<DeviceSocketId<Self::WeakDeviceId, BT>, DeviceSocketCounters>;
307
308 fn with_device_sockets<
311 F: FnOnce(&DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
312 R,
313 >(
314 &mut self,
315 device: &Self::DeviceId,
316 cb: F,
317 ) -> R;
318
319 fn with_device_sockets_mut<
322 F: FnOnce(&mut DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
323 R,
324 >(
325 &mut self,
326 device: &Self::DeviceId,
327 cb: F,
328 ) -> R;
329}
330
/// An optional replacement for a field: either keep it or overwrite it.
enum MaybeUpdate<T> {
    /// Leave the current value as it is.
    NoChange,
    /// Replace the current value with this one.
    NewValue(T),
}
335
336fn update_device_and_protocol<CC: DeviceSocketContext<BT>, BT: DeviceSocketTypes>(
337 core_ctx: &mut CC,
338 socket: &DeviceSocketId<CC::WeakDeviceId, BT>,
339 new_device: TargetDevice<&CC::DeviceId>,
340 protocol_update: MaybeUpdate<Protocol>,
341) {
342 core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
343 let old_device = core_ctx.with_socket_state_mut(socket, |Target { protocol, device }| {
349 match protocol_update {
350 MaybeUpdate::NewValue(p) => *protocol = Some(p),
351 MaybeUpdate::NoChange => (),
352 };
353 let old_device = match &device {
354 TargetDevice::SpecificDevice(device) => device.upgrade(),
355 TargetDevice::AnyDevice => {
356 assert!(any_device_sockets.remove(socket));
357 None
358 }
359 };
360 *device = match &new_device {
361 TargetDevice::AnyDevice => TargetDevice::AnyDevice,
362 TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d.downgrade()),
363 };
364 old_device
365 });
366
367 if let Some(device) = old_device {
373 core_ctx.with_device_sockets_mut(
376 &device,
377 |DeviceSockets(device_sockets), _core_ctx| {
378 assert!(device_sockets.remove(socket), "socket not found in device state");
379 },
380 );
381 }
382
383 match &new_device {
385 TargetDevice::SpecificDevice(new_device) => core_ctx.with_device_sockets_mut(
386 new_device,
387 |DeviceSockets(device_sockets), _core_ctx| {
388 assert!(device_sockets.insert(socket.clone()));
389 },
390 ),
391 TargetDevice::AnyDevice => {
392 assert!(any_device_sockets.insert(socket.clone()))
393 }
394 }
395 })
396}
397
/// The API for creating, configuring, and using device sockets.
pub struct DeviceSocketApi<C>(C);

impl<C> DeviceSocketApi<C> {
    /// Wraps the context `ctx` in a device socket API.
    pub fn new(ctx: C) -> Self {
        DeviceSocketApi(ctx)
    }
}
407
408type ApiSocketId<C> = DeviceSocketId<
413 <<C as ContextPair>::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
414 <C as ContextPair>::BindingsContext,
415>;
416
417impl<C> DeviceSocketApi<C>
418where
419 C: ContextPair,
420 C::CoreContext: DeviceSocketContext<C::BindingsContext>
421 + SocketStateAccessor<C::BindingsContext>
422 + ResourceCounterContext<ApiSocketId<C>, DeviceSocketCounters>,
423 C::BindingsContext: DeviceSocketBindingsContext<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>
424 + ReferenceNotifiers
425 + 'static,
426{
427 fn core_ctx(&mut self) -> &mut C::CoreContext {
428 let Self(pair) = self;
429 pair.core_ctx()
430 }
431
432 fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
433 let Self(pair) = self;
434 pair.contexts()
435 }
436
437 pub fn create(
439 &mut self,
440 external_state: <C::BindingsContext as DeviceSocketTypes>::SocketState<
441 <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
442 >,
443 ) -> ApiSocketId<C> {
444 let core_ctx = self.core_ctx();
445
446 let strong = core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
447 let primary = PrimaryDeviceSocketId::new(external_state);
448 let strong = primary.clone_strong();
449 assert!(sockets.insert(strong.clone(), primary).is_none());
450 strong
451 });
452 core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), _core_ctx| {
453 assert!(any_device_sockets.insert(strong.clone()));
460 });
461 strong
462 }
463
464 pub fn set_device(
466 &mut self,
467 socket: &ApiSocketId<C>,
468 device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
469 ) {
470 update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NoChange)
471 }
472
473 pub fn set_device_and_protocol(
475 &mut self,
476 socket: &ApiSocketId<C>,
477 device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
478 protocol: Protocol,
479 ) {
480 update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NewValue(protocol))
481 }
482
483 pub fn get_info(
485 &mut self,
486 id: &ApiSocketId<C>,
487 ) -> SocketInfo<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
488 self.core_ctx().with_socket_state(id, |Target { device, protocol }| SocketInfo {
489 device: device.clone(),
490 protocol: *protocol,
491 })
492 }
493
494 pub fn remove(
496 &mut self,
497 id: ApiSocketId<C>,
498 ) -> RemoveResourceResultWithContext<
499 <C::BindingsContext as DeviceSocketTypes>::SocketState<
500 <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
501 >,
502 C::BindingsContext,
503 > {
504 let core_ctx = self.core_ctx();
505 core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
506 let old_device = core_ctx.with_socket_state_mut(&id, |target| {
507 let Target { device, protocol: _ } = target;
508 match &device {
509 TargetDevice::SpecificDevice(device) => device.upgrade(),
510 TargetDevice::AnyDevice => {
511 assert!(any_device_sockets.remove(&id));
512 None
513 }
514 }
515 });
516 if let Some(device) = old_device {
517 core_ctx.with_device_sockets_mut(
518 &device,
519 |DeviceSockets(device_sockets), _core_ctx| {
520 assert!(device_sockets.remove(&id), "device doesn't have socket");
521 },
522 )
523 }
524 });
525
526 core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
527 let primary = sockets
528 .remove(&id)
529 .unwrap_or_else(|| panic!("{id:?} not present in all socket map"));
530 drop(id);
533
534 let PrimaryDeviceSocketId(primary) = primary;
535 C::BindingsContext::unwrap_or_notify_with_new_reference_notifier(
536 primary,
537 |SocketState { external_state, counters: _, target: _ }| external_state,
538 )
539 })
540 }
541
542 pub fn send_frame<S, D>(
544 &mut self,
545 id: &ApiSocketId<C>,
546 metadata: DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
547 body: S,
548 ) -> Result<(), SendFrameErrorReason>
549 where
550 S: NetworkSerializer,
551 S::Buffer: BufferMut,
552 D: DeviceSocketSendTypes,
553 C::CoreContext: DeviceIdContext<D>
554 + SendFrameContext<
555 C::BindingsContext,
556 DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
557 >,
558 C::BindingsContext: DeviceLayerTypes,
559 {
560 let (core_ctx, bindings_ctx) = self.contexts();
561 let result = core_ctx.send_frame(bindings_ctx, metadata, body).map_err(|e| e.into_err());
562 match &result {
563 Ok(()) => {
564 core_ctx.increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_frames)
565 }
566 Err(SendFrameErrorReason::QueueFull) => core_ctx
567 .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_queue_full),
568 Err(SendFrameErrorReason::Alloc) => core_ctx
569 .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_alloc),
570 Err(SendFrameErrorReason::SizeConstraintsViolation) => core_ctx
571 .increment_both(id, |counters: &DeviceSocketCounters| {
572 &counters.tx_err_size_constraint
573 }),
574 }
575 result
576 }
577
578 pub fn inspect<N>(&mut self, inspector: &mut N)
580 where
581 N: Inspector
582 + InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId>,
583 {
584 self.core_ctx().with_all_device_sockets(|AllSockets(sockets), core_ctx| {
585 sockets.keys().for_each(|socket| {
586 inspector.record_debug_child(socket, |node| {
587 core_ctx.with_socket_state(socket, |Target { protocol, device }| {
588 node.record_debug("Protocol", protocol);
589 match device {
590 TargetDevice::AnyDevice => node.record_str("Device", "Any"),
591 TargetDevice::SpecificDevice(d) => N::record_device(node, "Device", d),
592 }
593 });
594 node.record_child("Counters", |node| {
595 node.delegate_inspectable(socket.counters())
596 })
597 })
598 })
599 })
600 }
601}
602
603pub trait DeviceSocketSendTypes: Device {
605 type Metadata;
607}
608
609#[derive(Debug, PartialEq)]
611pub struct DeviceSocketMetadata<D: DeviceSocketSendTypes, DeviceId> {
612 pub device_id: DeviceId,
614 pub metadata: D::Metadata,
616 }
619
620#[derive(Debug, PartialEq)]
622pub struct EthernetHeaderParams {
623 pub dest_addr: Mac,
625 pub protocol: EtherType,
627}
628
629pub type SocketId<BC> = DeviceSocketId<WeakDeviceId<BC>, BC>;
634
635impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> DeviceSocketId<D, BT> {
636 pub fn socket_state(&self) -> &BT::SocketState<D> {
639 let Self(strong) = self;
640 let SocketState { external_state, counters: _, target: _ } = &**strong;
641 external_state
642 }
643
644 pub fn downgrade(&self) -> WeakDeviceSocketId<D, BT> {
646 let Self(inner) = self;
647 WeakDeviceSocketId(StrongRc::downgrade(inner))
648 }
649
650 pub fn counters(&self) -> &DeviceSocketCounters {
652 let Self(strong) = self;
653 let SocketState { external_state: _, counters, target: _ } = &**strong;
654 counters
655 }
656}
657
658pub trait DeviceSocketHandler<D: Device, BC>: DeviceIdContext<D> {
663 fn handle_frame(
665 &mut self,
666 bindings_ctx: &mut BC,
667 device: &Self::DeviceId,
668 frame: Frame<&[u8]>,
669 whole_frame: &[u8],
670 );
671}
672
673#[derive(Clone, Copy, Debug, Eq, PartialEq)]
675pub enum ReceivedFrame<B> {
676 Ethernet {
678 destination: FrameDestination,
680 frame: EthernetFrame<B>,
682 },
683 Ip(IpFrame<B>),
688}
689
690#[derive(Clone, Copy, Debug, Eq, PartialEq)]
692pub enum SentFrame<B> {
693 Ethernet(EthernetFrame<B>),
695 Ip(IpFrame<B>),
700}
701
702#[derive(Debug)]
704pub struct ParseSentFrameError;
705
706impl SentFrame<&[u8]> {
707 pub fn try_parse_as_ethernet(mut buf: &[u8]) -> Result<SentFrame<&[u8]>, ParseSentFrameError> {
709 packet_formats::ethernet::EthernetFrame::parse(&mut buf, EthernetFrameLengthCheck::NoCheck)
710 .map_err(|_: ParseError| ParseSentFrameError)
711 .map(|frame| SentFrame::Ethernet(frame.into()))
712 }
713}
714
715#[derive(Clone, Copy, Debug, Eq, PartialEq)]
717pub struct EthernetFrame<B> {
718 pub src_mac: Mac,
720 pub dst_mac: Mac,
722 pub ethertype: Option<EtherType>,
724 pub body_offset: usize,
726 pub body: B,
728}
729
730#[derive(Clone, Copy, Debug, Eq, PartialEq)]
732pub struct IpFrame<B> {
733 pub ip_version: IpVersion,
735 pub body: B,
737}
738
739impl<B> IpFrame<B> {
740 fn ethertype(&self) -> EtherType {
741 let IpFrame { ip_version, body: _ } = self;
742 EtherType::from_ip_version(*ip_version)
743 }
744}
745
746#[derive(Clone, Copy, Debug, Eq, PartialEq)]
748pub enum Frame<B> {
749 Sent(SentFrame<B>),
751 Received(ReceivedFrame<B>),
753}
754
755impl<B> From<SentFrame<B>> for Frame<B> {
756 fn from(value: SentFrame<B>) -> Self {
757 Self::Sent(value)
758 }
759}
760
761impl<B> From<ReceivedFrame<B>> for Frame<B> {
762 fn from(value: ReceivedFrame<B>) -> Self {
763 Self::Received(value)
764 }
765}
766
767impl<'a> From<packet_formats::ethernet::EthernetFrame<&'a [u8]>> for EthernetFrame<&'a [u8]> {
768 fn from(frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>) -> Self {
769 Self {
770 src_mac: frame.src_mac(),
771 dst_mac: frame.dst_mac(),
772 ethertype: frame.ethertype(),
773 body_offset: frame.parse_metadata().header_len(),
774 body: frame.into_body(),
775 }
776 }
777}
778
779impl<'a> ReceivedFrame<&'a [u8]> {
780 pub(crate) fn from_ethernet(
781 frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>,
782 destination: FrameDestination,
783 ) -> Self {
784 Self::Ethernet { destination, frame: frame.into() }
785 }
786}
787
788impl<B> Frame<B> {
789 pub fn protocol(&self) -> Option<u16> {
791 let ethertype = match self {
792 Self::Sent(SentFrame::Ethernet(frame))
793 | Self::Received(ReceivedFrame::Ethernet { destination: _, frame }) => frame.ethertype,
794 Self::Sent(SentFrame::Ip(frame)) | Self::Received(ReceivedFrame::Ip(frame)) => {
795 Some(frame.ethertype())
796 }
797 };
798 ethertype.map(Into::into)
799 }
800
801 pub fn into_body(self) -> B {
803 match self {
804 Self::Received(ReceivedFrame::Ethernet { destination: _, frame })
805 | Self::Sent(SentFrame::Ethernet(frame)) => frame.body,
806 Self::Received(ReceivedFrame::Ip(frame)) | Self::Sent(SentFrame::Ip(frame)) => {
807 frame.body
808 }
809 }
810 }
811
812 pub fn body_offset(&self) -> usize {
814 match self {
815 Self::Received(ReceivedFrame::Ethernet { destination: _, frame })
816 | Self::Sent(SentFrame::Ethernet(frame)) => frame.body_offset,
817 Self::Received(ReceivedFrame::Ip(_)) | Self::Sent(SentFrame::Ip(_)) => 0,
818 }
819 }
820}
821
822impl<
823 D: Device,
824 BC: DeviceSocketBindingsContext<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
825 CC: DeviceSocketContext<BC> + DeviceIdContext<D>,
826> DeviceSocketHandler<D, BC> for CC
827where
828 <CC as DeviceIdContext<D>>::DeviceId: Into<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
829{
830 fn handle_frame(
831 &mut self,
832 bindings_ctx: &mut BC,
833 device: &Self::DeviceId,
834 frame: Frame<&[u8]>,
835 whole_frame: &[u8],
836 ) {
837 let device = device.clone().into();
838
839 self.with_any_device_sockets(|AnyDeviceSockets(any_device_sockets), core_ctx| {
843 core_ctx.with_device_sockets(&device, |DeviceSockets(device_sockets), core_ctx| {
858 for socket in any_device_sockets.iter().chain(device_sockets) {
859 let delivered =
860 core_ctx.with_socket_state(socket, |Target { protocol, device: _ }| {
861 let should_deliver = match protocol {
862 None => false,
863 Some(p) => match p {
864 Protocol::Specific(p) => match frame {
868 Frame::Received(_) => Some(p.get()) == frame.protocol(),
869 Frame::Sent(_) => false,
870 },
871 Protocol::All => true,
872 },
873 };
874 should_deliver.then(|| {
875 bindings_ctx.receive_frame(socket, &device, frame, whole_frame)
876 })
877 });
878 match delivered {
879 None => {}
880 Some(result) => {
881 core_ctx.increment_both(socket, |counters: &DeviceSocketCounters| {
882 &counters.rx_frames
883 });
884 match result {
885 Ok(()) => {}
886 Err(ReceiveFrameError::QueueFull) => {
887 core_ctx.increment_both(
888 socket,
889 |counters: &DeviceSocketCounters| &counters.rx_queue_full,
890 );
891 }
892 }
893 }
894 }
895 }
896 })
897 })
898 }
899}
900
901#[derive(Debug, Default)]
905pub struct DeviceSocketCounters {
906 rx_frames: Counter,
912 rx_queue_full: Counter,
915 tx_frames: Counter,
917 tx_err_queue_full: Counter,
919 tx_err_alloc: Counter,
921 tx_err_size_constraint: Counter,
923}
924
925impl Inspectable for DeviceSocketCounters {
926 fn record<I: Inspector>(&self, inspector: &mut I) {
927 let Self {
928 rx_frames,
929 rx_queue_full,
930 tx_frames,
931 tx_err_queue_full,
932 tx_err_alloc,
933 tx_err_size_constraint,
934 } = self;
935 inspector.record_child("Rx", |inspector| {
936 inspector.record_counter("DeliveredFrames", rx_frames);
937 inspector.record_counter("DroppedQueueFull", rx_queue_full);
938 });
939 inspector.record_child("Tx", |inspector| {
940 inspector.record_counter("SentFrames", tx_frames);
941 inspector.record_counter("QueueFullError", tx_err_queue_full);
942 inspector.record_counter("AllocError", tx_err_alloc);
943 inspector.record_counter("SizeConstraintError", tx_err_size_constraint);
944 });
945 }
946}
947
948impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AnyDeviceSockets<D, BT>>
949 for Sockets<D, BT>
950{
951 type Lock = RwLock<AnyDeviceSockets<D, BT>>;
952 fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
953 OrderedLockRef::new(&self.any_device_sockets)
954 }
955}
956
957impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AllSockets<D, BT>>
958 for Sockets<D, BT>
959{
960 type Lock = RwLock<AllSockets<D, BT>>;
961 fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
962 OrderedLockRef::new(&self.all_sockets)
963 }
964}
965
/// Fake bindings-side implementations of the device socket traits for tests.
#[cfg(any(test, feature = "testutils"))]
mod testutil {
    use alloc::vec::Vec;
    use core::num::NonZeroU64;
    use core::ops::DerefMut;
    use netstack3_base::StrongDeviceIdentifier;
    use netstack3_base::testutil::{FakeBindingsCtx, MonotonicIdentifier};

    use super::*;
    use crate::internal::base::{
        DeviceClassMatcher, DeviceIdAndNameMatcher, DeviceLayerStateTypes,
    };

    /// A fake socket receive queue.
    #[derive(Derivative, Debug)]
    #[derivative(Default(bound = ""))]
    pub struct RxQueue<D> {
        /// Frames delivered to the socket, in arrival order.
        pub frames: Vec<ReceivedFrame<D>>,
        /// How many frames the queue holds before it reports
        /// [`ReceiveFrameError::QueueFull`]. Defaults to `usize::MAX`.
        #[derivative(Default(value = "usize::MAX"))]
        pub max_size: usize,
    }

    /// An owned record of a frame delivered to a fake socket.
    #[derive(Clone, Debug, PartialEq)]
    pub struct ReceivedFrame<D> {
        /// The weak ID of the device the frame was seen on.
        pub device: D,
        /// An owned copy of the parsed frame.
        pub frame: Frame<Vec<u8>>,
        /// An owned copy of the complete unparsed frame.
        pub raw: Vec<u8>,
    }

    /// Fake bindings-owned socket state: a mutex-protected receive queue.
    #[derive(Debug, Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct ExternalSocketState<D>(pub Mutex<RxQueue<D>>);

    impl<TimerId, Event: Debug, State> DeviceSocketTypes
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        type SocketState<D: Send + Sync + Debug> = ExternalSocketState<D>;
    }

    impl Frame<&[u8]> {
        /// Returns a copy of this frame that owns its body.
        pub(crate) fn cloned(self) -> Frame<Vec<u8>> {
            match self {
                Self::Sent(sent) => Frame::Sent(match sent {
                    SentFrame::Ethernet(frame) => SentFrame::Ethernet(frame.cloned()),
                    SentFrame::Ip(frame) => SentFrame::Ip(frame.cloned()),
                }),
                Self::Received(received) => Frame::Received(match received {
                    super::ReceivedFrame::Ethernet { destination, frame } => {
                        super::ReceivedFrame::Ethernet { destination, frame: frame.cloned() }
                    }
                    super::ReceivedFrame::Ip(frame) => super::ReceivedFrame::Ip(frame.cloned()),
                }),
            }
        }
    }

    impl EthernetFrame<&[u8]> {
        /// Returns a copy of this frame that owns its body.
        fn cloned(self) -> EthernetFrame<Vec<u8>> {
            let Self { src_mac, dst_mac, ethertype, body_offset, body } = self;
            let body = Vec::from(body);
            EthernetFrame { src_mac, dst_mac, ethertype, body_offset, body }
        }
    }

    impl IpFrame<&[u8]> {
        /// Returns a copy of this frame that owns its body.
        fn cloned(self) -> IpFrame<Vec<u8>> {
            let Self { ip_version, body } = self;
            let body = Vec::from(body);
            IpFrame { ip_version, body }
        }
    }

    impl<TimerId, Event: Debug, State, D: StrongDeviceIdentifier> DeviceSocketBindingsContext<D>
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        fn receive_frame(
            &self,
            state: &DeviceSocketId<D::Weak, Self>,
            device: &D,
            frame: Frame<&[u8]>,
            raw_frame: &[u8],
        ) -> Result<(), ReceiveFrameError> {
            let ExternalSocketState(queue) = state.socket_state();
            let mut guard = queue.lock();
            let RxQueue { frames, max_size } = guard.deref_mut();
            // Reject the frame once the queue has reached its configured limit.
            if frames.len() >= *max_size {
                return Err(ReceiveFrameError::QueueFull);
            }
            frames.push(ReceivedFrame {
                device: device.downgrade(),
                frame: frame.cloned(),
                raw: raw_frame.into(),
            });
            Ok(())
        }
    }

    impl<
            TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
            Event: Debug + 'static,
            State: 'static,
        > DeviceLayerStateTypes for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        type EthernetDeviceState = ();
        type LoopbackDeviceState = ();
        type PureIpDeviceState = ();
        type BlackholeDeviceState = ();
        type DeviceIdentifier = MonotonicIdentifier;
    }

    impl DeviceClassMatcher<()> for () {
        fn device_class_matches(&self, (): &()) -> bool {
            unimplemented!()
        }
    }

    impl DeviceIdAndNameMatcher for MonotonicIdentifier {
        fn id_matches(&self, _id: &NonZeroU64) -> bool {
            unimplemented!()
        }

        fn name_matches(&self, _name: &str) -> bool {
            unimplemented!()
        }
    }
}
1093
1094#[cfg(test)]
1095mod tests {
1096 use alloc::vec;
1097 use alloc::vec::Vec;
1098 use core::marker::PhantomData;
1099 use core::ops::Deref;
1100
1101 use crate::internal::socket::testutil::{ExternalSocketState, ReceivedFrame};
1102 use netstack3_base::testutil::{
1103 FakeReferencyDeviceId, FakeStrongDeviceId, FakeWeakDeviceId, MultipleDevicesId,
1104 };
1105 use netstack3_base::{
1106 CounterContext, CtxPair, NetworkSerializationContext, SendFrameError, SendableFrameMeta,
1107 };
1108 use netstack3_hashmap::HashMap;
1109 use packet::ParsablePacket;
1110 use test_case::test_case;
1111
1112 use super::*;
1113
1114 type FakeCoreCtx<D> = netstack3_base::testutil::FakeCoreCtx<FakeSockets<D>, (), D>;
1115 type FakeBindingsCtx = netstack3_base::testutil::FakeBindingsCtx<(), (), (), ()>;
1116 type FakeCtx<D> = CtxPair<FakeCoreCtx<D>, FakeBindingsCtx>;
1117
1118 trait DeviceSocketApiExt: ContextPair + Sized {
1121 fn device_socket_api(&mut self) -> DeviceSocketApi<&mut Self> {
1122 DeviceSocketApi::new(self)
1123 }
1124 }
1125
1126 impl<O> DeviceSocketApiExt for O where O: ContextPair + Sized {}
1127
1128 #[derive(Derivative)]
1129 #[derivative(Default(bound = ""))]
1130 struct FakeSockets<D: FakeStrongDeviceId> {
1131 any_device_sockets: AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
1132 device_sockets: HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
1133 all_sockets: AllSockets<D::Weak, FakeBindingsCtx>,
1134 counters: DeviceSocketCounters,
1136 sent_frames: Vec<Vec<u8>>,
1137 }
1138
1139 pub struct FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>(
1141 &'m mut AnyDevice,
1142 &'m mut AllSockets,
1143 &'m mut Devices,
1144 PhantomData<Device>,
1145 &'m DeviceSocketCounters,
1146 );
1147
1148 pub trait AsFakeSocketsMutRefs {
1151 type AnyDevice: 'static;
1152 type AllSockets: 'static;
1153 type Devices: 'static;
1154 type Device: 'static;
1155 fn as_sockets_ref(
1156 &mut self,
1157 ) -> FakeSocketsMutRefs<'_, Self::AnyDevice, Self::AllSockets, Self::Devices, Self::Device>;
1158 }
1159
1160 impl<D: FakeStrongDeviceId> AsFakeSocketsMutRefs for FakeCoreCtx<D> {
1161 type AnyDevice = AnyDeviceSockets<D::Weak, FakeBindingsCtx>;
1162 type AllSockets = AllSockets<D::Weak, FakeBindingsCtx>;
1163 type Devices = HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>;
1164 type Device = D;
1165
1166 fn as_sockets_ref(
1167 &mut self,
1168 ) -> FakeSocketsMutRefs<
1169 '_,
1170 AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
1171 AllSockets<D::Weak, FakeBindingsCtx>,
1172 HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
1173 D,
1174 > {
1175 let FakeSockets {
1176 any_device_sockets,
1177 device_sockets,
1178 all_sockets,
1179 counters,
1180 sent_frames: _,
1181 } = &mut self.state;
1182 FakeSocketsMutRefs(
1183 any_device_sockets,
1184 all_sockets,
1185 device_sockets,
1186 PhantomData,
1187 counters,
1188 )
1189 }
1190 }
1191
1192 impl<'m, AnyDevice: 'static, AllSockets: 'static, Devices: 'static, Device: 'static>
1193 AsFakeSocketsMutRefs for FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>
1194 {
1195 type AnyDevice = AnyDevice;
1196 type AllSockets = AllSockets;
1197 type Devices = Devices;
1198 type Device = Device;
1199
1200 fn as_sockets_ref(
1201 &mut self,
1202 ) -> FakeSocketsMutRefs<'_, AnyDevice, AllSockets, Devices, Device> {
1203 let Self(any_device, all_sockets, devices, PhantomData, counters) = self;
1204 FakeSocketsMutRefs(any_device, all_sockets, devices, PhantomData, counters)
1205 }
1206 }
1207
1208 impl<D: Clone> TargetDevice<&D> {
1209 fn with_weak_id(&self) -> TargetDevice<FakeWeakDeviceId<D>> {
1210 match self {
1211 TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1212 TargetDevice::SpecificDevice(d) => {
1213 TargetDevice::SpecificDevice(FakeWeakDeviceId((*d).clone()))
1214 }
1215 }
1216 }
1217 }
1218
1219 impl<D: Eq + Hash + FakeStrongDeviceId> FakeSockets<D> {
1220 fn new(devices: impl IntoIterator<Item = D>) -> Self {
1221 let device_sockets =
1222 devices.into_iter().map(|d| (d, DeviceSockets::default())).collect();
1223 Self {
1224 any_device_sockets: AnyDeviceSockets::default(),
1225 device_sockets,
1226 all_sockets: Default::default(),
1227 counters: Default::default(),
1228 sent_frames: Default::default(),
1229 }
1230 }
1231 }
1232
1233 impl<
1234 'm,
1235 DeviceId: FakeStrongDeviceId,
1236 As: AsFakeSocketsMutRefs
1237 + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1238 > SocketStateAccessor<FakeBindingsCtx> for As
1239 {
1240 fn with_socket_state<F: FnOnce(&Target<Self::WeakDeviceId>) -> R, R>(
1241 &mut self,
1242 socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1243 cb: F,
1244 ) -> R {
1245 let DeviceSocketId(rc) = socket;
1246 let target = rc.target.lock();
1248 cb(&target)
1249 }
1250
1251 fn with_socket_state_mut<F: FnOnce(&mut Target<Self::WeakDeviceId>) -> R, R>(
1252 &mut self,
1253 socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1254 cb: F,
1255 ) -> R {
1256 let DeviceSocketId(rc) = socket;
1257 let mut target = rc.target.lock();
1259 cb(&mut target)
1260 }
1261 }
1262
/// Fake `DeviceSocketAccessor`: looks up a device's socket table in the fake
/// `Devices` map and passes the callback an inner context whose device table
/// is a `HashSet` of the known device IDs rather than the per-device tables.
impl<
    'm,
    DeviceId: FakeStrongDeviceId,
    As: AsFakeSocketsMutRefs<
        Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
    > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
> DeviceSocketAccessor<FakeBindingsCtx> for As
{
    // The inner context carries only the set of device IDs, not the socket
    // tables themselves.
    type DeviceSocketCoreCtx<'a> =
        FakeSocketsMutRefs<'a, As::AnyDevice, As::AllSockets, HashSet<DeviceId>, DeviceId>;
    /// Calls `cb` with a shared reference to `device`'s socket table.
    ///
    /// # Panics
    ///
    /// Panics if `device` has no entry in the fake device map.
    fn with_device_sockets<
        F: FnOnce(
            &DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
            &mut Self::DeviceSocketCoreCtx<'_>,
        ) -> R,
        R,
    >(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> R {
        let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
            self.as_sockets_ref();
        // Owned snapshot of the device IDs; it backs the inner context's
        // device table for the duration of `cb`. Its type is inferred from
        // `DeviceSocketCoreCtx` below.
        let mut devices = device_sockets.keys().cloned().collect();
        let device = device_sockets.get(device).unwrap();
        cb(
            device,
            &mut FakeSocketsMutRefs(
                any_device,
                all_sockets,
                &mut devices,
                PhantomData,
                counters,
            ),
        )
    }
    /// Calls `cb` with a mutable reference to `device`'s socket table.
    ///
    /// # Panics
    ///
    /// Panics if `device` has no entry in the fake device map.
    fn with_device_sockets_mut<
        F: FnOnce(
            &mut DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
            &mut Self::DeviceSocketCoreCtx<'_>,
        ) -> R,
        R,
    >(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> R {
        let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
            self.as_sockets_ref();
        // Collected before `get_mut` so the ID snapshot does not overlap the
        // mutable borrow of the map.
        let mut devices = device_sockets.keys().cloned().collect();
        let device = device_sockets.get_mut(device).unwrap();
        cb(
            device,
            &mut FakeSocketsMutRefs(
                any_device,
                all_sockets,
                &mut devices,
                PhantomData,
                counters,
            ),
        )
    }
}
1326
/// Fake `DeviceSocketContext`: exposes the any-device and all-sockets tables
/// from the fake state. The inner context passed to callbacks replaces both
/// of those tables with `()` and keeps only the per-device map and counters.
impl<
    'm,
    DeviceId: FakeStrongDeviceId,
    As: AsFakeSocketsMutRefs<
        AnyDevice = AnyDeviceSockets<DeviceId::Weak, FakeBindingsCtx>,
        AllSockets = AllSockets<DeviceId::Weak, FakeBindingsCtx>,
        Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
    > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
> DeviceSocketContext<FakeBindingsCtx> for As
{
    type SocketTablesCoreCtx<'a> = FakeSocketsMutRefs<
        'a,
        (),
        (),
        HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
        DeviceId,
    >;

    /// Calls `cb` with a shared reference to the any-device socket table.
    fn with_any_device_sockets<
        F: FnOnce(
            &AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
            &mut Self::SocketTablesCoreCtx<'_>,
        ) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R {
        let FakeSocketsMutRefs(
            any_device_sockets,
            _all_sockets,
            device_sockets,
            PhantomData,
            counters,
        ) = self.as_sockets_ref();
        cb(
            any_device_sockets,
            &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
        )
    }
    /// Calls `cb` with a mutable reference to the any-device socket table.
    fn with_any_device_sockets_mut<
        F: FnOnce(
            &mut AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
            &mut Self::SocketTablesCoreCtx<'_>,
        ) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R {
        let FakeSocketsMutRefs(
            any_device_sockets,
            _all_sockets,
            device_sockets,
            PhantomData,
            counters,
        ) = self.as_sockets_ref();
        cb(
            any_device_sockets,
            &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
        )
    }

    /// Calls `cb` with a shared reference to the table of all sockets.
    fn with_all_device_sockets<
        F: FnOnce(
            &AllSockets<Self::WeakDeviceId, FakeBindingsCtx>,
            &mut Self::SocketTablesCoreCtx<'_>,
        ) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R {
        let FakeSocketsMutRefs(
            _any_device_sockets,
            all_sockets,
            device_sockets,
            PhantomData,
            counters,
        ) = self.as_sockets_ref();
        cb(
            all_sockets,
            &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
        )
    }

    /// Calls `cb` with a mutable reference to the table of all sockets. Unlike
    /// the other accessors, no inner context is passed to the callback.
    fn with_all_device_sockets_mut<
        F: FnOnce(&mut AllSockets<Self::WeakDeviceId, FakeBindingsCtx>) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R {
        let FakeSocketsMutRefs(_, all_sockets, _, _, _) = self.as_sockets_ref();
        cb(all_sockets)
    }
}
1424
1425 impl<'m, X, Y, Z, D: FakeStrongDeviceId> DeviceIdContext<AnyDevice>
1426 for FakeSocketsMutRefs<'m, X, Y, Z, D>
1427 {
1428 type DeviceId = D;
1429 type WeakDeviceId = FakeWeakDeviceId<D>;
1430 }
1431
1432 impl<D: FakeStrongDeviceId> CounterContext<DeviceSocketCounters> for FakeCoreCtx<D> {
1433 fn counters(&self) -> &DeviceSocketCounters {
1434 &self.state.counters
1435 }
1436 }
1437
1438 impl<D: FakeStrongDeviceId>
1439 ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
1440 for FakeCoreCtx<D>
1441 {
1442 fn per_resource_counters<'a>(
1443 &'a self,
1444 socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
1445 ) -> &'a DeviceSocketCounters {
1446 socket.counters()
1447 }
1448 }
1449
1450 impl<'m, X, Y, Z, D> CounterContext<DeviceSocketCounters> for FakeSocketsMutRefs<'m, X, Y, Z, D> {
1451 fn counters(&self) -> &DeviceSocketCounters {
1452 let FakeSocketsMutRefs(_, _, _, _, counters) = self;
1453 counters
1454 }
1455 }
1456
1457 impl<'m, X, Y, Z, D: FakeStrongDeviceId>
1458 ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
1459 for FakeSocketsMutRefs<'m, X, Y, Z, D>
1460 {
1461 fn per_resource_counters<'a>(
1462 &'a self,
1463 socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
1464 ) -> &'a DeviceSocketCounters {
1465 socket.counters()
1466 }
1467 }
1468
/// Arbitrary non-zero protocol number used by tests that only need *some*
/// specific protocol.
const SOME_PROTOCOL: NonZeroU16 = NonZeroU16::new(2000).unwrap();
1470
#[test]
fn create_remove() {
    let devices = MultipleDevicesId::all();
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(devices)));
    let mut api = ctx.device_socket_api();

    // A freshly created socket targets any device and has no protocol.
    let socket = api.create(Default::default());
    let expected = SocketInfo { device: TargetDevice::AnyDevice, protocol: None };
    assert_eq!(api.get_info(&socket), expected);

    // Removal hands back the external state the socket was created with.
    let ExternalSocketState(_received_frames) = api.remove(socket).into_removed();
}
1486
1487 #[test_case(TargetDevice::AnyDevice)]
1488 #[test_case(TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1489 fn test_set_device(device: TargetDevice<&MultipleDevicesId>) {
1490 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1491 MultipleDevicesId::all(),
1492 )));
1493 let mut api = ctx.device_socket_api();
1494
1495 let bound = api.create(Default::default());
1496 api.set_device(&bound, device.clone());
1497 assert_eq!(
1498 api.get_info(&bound),
1499 SocketInfo { device: device.with_weak_id(), protocol: None }
1500 );
1501
1502 let device_sockets = &api.core_ctx().state.device_sockets;
1503 if let TargetDevice::SpecificDevice(d) = device {
1504 let DeviceSockets(socket_ids) = device_sockets.get(&d).expect("device state exists");
1505 assert_eq!(socket_ids, &HashSet::from([bound]));
1506 }
1507 }
1508
/// Rebinding a socket from device `A` to device `B` moves it between the
/// per-device tables: afterwards only `B`'s table lists it.
#[test]
fn update_device() {
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));
    let mut api = ctx.device_socket_api();
    let bound = api.create(Default::default());

    api.set_device(&bound, TargetDevice::SpecificDevice(&MultipleDevicesId::A));

    // Rebind to a different device; the info and tables must reflect only
    // the latest binding.
    api.set_device(&bound, TargetDevice::SpecificDevice(&MultipleDevicesId::B));
    assert_eq!(
        api.get_info(&bound),
        SocketInfo {
            device: TargetDevice::SpecificDevice(FakeWeakDeviceId(MultipleDevicesId::B)),
            protocol: None
        }
    );

    let device_sockets = &api.core_ctx().state.device_sockets;
    // Each table holds at most one socket here, so the Vec ordering produced
    // by iterating the set is deterministic.
    let device_socket_lists = device_sockets
        .iter()
        .map(|(d, DeviceSockets(indexes))| (d, indexes.iter().collect()))
        .collect::<HashMap<_, _>>();

    assert_eq!(
        device_socket_lists,
        HashMap::from([
            (&MultipleDevicesId::A, vec![]),
            (&MultipleDevicesId::B, vec![&bound]),
            (&MultipleDevicesId::C, vec![])
        ])
    );
}
1545
1546 #[test_case(Protocol::All, TargetDevice::AnyDevice)]
1547 #[test_case(Protocol::Specific(SOME_PROTOCOL), TargetDevice::AnyDevice)]
1548 #[test_case(Protocol::All, TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1549 #[test_case(
1550 Protocol::Specific(SOME_PROTOCOL),
1551 TargetDevice::SpecificDevice(&MultipleDevicesId::A)
1552 )]
1553 fn create_set_device_and_protocol_remove_multiple(
1554 protocol: Protocol,
1555 device: TargetDevice<&MultipleDevicesId>,
1556 ) {
1557 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1558 MultipleDevicesId::all(),
1559 )));
1560 let mut api = ctx.device_socket_api();
1561
1562 let mut sockets = [(); 3].map(|()| api.create(Default::default()));
1563 for socket in &mut sockets {
1564 api.set_device_and_protocol(socket, device.clone(), protocol);
1565 assert_eq!(
1566 api.get_info(socket),
1567 SocketInfo { device: device.with_weak_id(), protocol: Some(protocol) }
1568 );
1569 }
1570
1571 for socket in sockets {
1572 let ExternalSocketState(_received_frames) = api.remove(socket).into_removed();
1573 }
1574 }
1575
/// A socket bound to a device that has since been removed can still be
/// rebound to a live device.
#[test]
fn change_device_after_removal() {
    let device_to_remove = FakeReferencyDeviceId::default();
    let device_to_maintain = FakeReferencyDeviceId::default();
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new([
        device_to_remove.clone(),
        device_to_maintain.clone(),
    ])));
    let mut api = ctx.device_socket_api();

    let bound = api.create(Default::default());
    // Bind to the device that is about to go away.
    api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_remove));

    // Mark the device removed while the socket is still bound to it
    // (presumably this makes its weak ID fail to upgrade; TODO confirm).
    device_to_remove.mark_removed();

    // Rebinding away from the removed device must still succeed.
    api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_maintain));
    assert_eq!(
        api.get_info(&bound),
        SocketInfo {
            device: TargetDevice::SpecificDevice(FakeWeakDeviceId(device_to_maintain.clone())),
            protocol: None,
        }
    );

    let device_sockets = &api.core_ctx().state.device_sockets;
    let DeviceSockets(weak_sockets) =
        device_sockets.get(&device_to_maintain).expect("device state exists");
    assert_eq!(weak_sockets, &HashSet::from([bound]));
}
1611
/// Fixed Ethernet frame used by the delivery tests.
struct TestData;
impl TestData {
    const SRC_MAC: Mac = Mac::new([0, 1, 2, 3, 4, 5]);
    const DST_MAC: Mac = Mac::new([6, 7, 8, 9, 10, 11]);
    const PROTO: NonZeroU16 = NonZeroU16::new(0x08AB).unwrap();
    // The bytes of `BUFFER` must stay in sync with the constants above:
    // `DST_MAC`, then `SRC_MAC`, then `PROTO` big-endian, then `BODY`.
    const BODY: &'static [u8] = b"some pig";
    const BUFFER: &'static [u8] = &[
        6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 0x08, 0xAB, b's', b'o', b'm', b'e', b' ', b'p',
        b'i', b'g',
    ];
    /// Offset of `BODY` within `BUFFER` (the 14-byte header length).
    const BUFFER_OFFSET: usize = Self::BUFFER.len() - Self::BODY.len();

    /// Parses `BUFFER` as an Ethernet frame.
    fn frame() -> packet_formats::ethernet::EthernetFrame<&'static [u8]> {
        // `NoCheck`: `BUFFER` is shorter than a minimum-size Ethernet frame,
        // so presumably the length check must be skipped for it to parse.
        let mut buffer_view = Self::BUFFER;
        packet_formats::ethernet::EthernetFrame::parse(
            &mut buffer_view,
            EthernetFrameLengthCheck::NoCheck,
        )
        .unwrap()
    }
}
1635
/// A protocol that differs from `TestData::PROTO`; sockets bound to it must
/// not match the test frame.
const WRONG_PROTO: NonZeroU16 = NonZeroU16::new(0x08ff).unwrap();
1637
1638 fn make_bound<D: FakeStrongDeviceId>(
1639 ctx: &mut FakeCtx<D>,
1640 device: TargetDevice<D>,
1641 protocol: Option<Protocol>,
1642 state: ExternalSocketState<D::Weak>,
1643 ) -> DeviceSocketId<D::Weak, FakeBindingsCtx> {
1644 let mut api = ctx.device_socket_api();
1645 let id = api.create(state);
1646 let device = match &device {
1647 TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1648 TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d),
1649 };
1650 match protocol {
1651 Some(protocol) => api.set_device_and_protocol(&id, device, protocol),
1652 None => api.set_device(&id, device),
1653 };
1654 id
1655 }
1656
/// Handles `delivered_frame` as a frame on device `A` (with
/// `TestData::BUFFER` as the raw bytes).
///
/// Returns every socket whose receive queue is non-empty afterwards. Each
/// such socket is asserted to hold exactly one entry: the delivered frame.
fn deliver_one_frame(
    // The frame handed to the socket layer, either received or sent.
    delivered_frame: Frame<&[u8]>,
    FakeCtx { core_ctx, bindings_ctx }: &mut FakeCtx<MultipleDevicesId>,
) -> HashSet<DeviceSocketId<FakeWeakDeviceId<MultipleDevicesId>, FakeBindingsCtx>> {
    DeviceSocketHandler::handle_frame(
        core_ctx,
        bindings_ctx,
        &MultipleDevicesId::A,
        delivered_frame.clone(),
        TestData::BUFFER,
    );

    let FakeSockets {
        all_sockets: AllSockets(all_sockets),
        any_device_sockets: _,
        device_sockets: _,
        counters: _,
        sent_frames: _,
    } = &core_ctx.state;

    all_sockets
        .iter()
        .filter_map(|(id, _primary)| {
            let DeviceSocketId(rc) = &id;
            let ExternalSocketState(frames) = &rc.external_state;
            let lock_guard = frames.lock();
            // `frames` is shadowed: from here on it is the queued frame list.
            let testutil::RxQueue { frames, .. } = lock_guard.deref();
            (!frames.is_empty()).then(|| {
                assert_eq!(
                    &*frames,
                    &[ReceivedFrame {
                        device: FakeWeakDeviceId(MultipleDevicesId::A),
                        frame: delivered_frame.cloned(),
                        raw: TestData::BUFFER.into(),
                    }]
                );
                id.clone()
            })
        })
        .collect()
}
1700
/// A frame received on device `A` is delivered to every socket bound to `A`
/// or to any device whose protocol is `All` or matches the frame's
/// ethertype.
///
/// Unbound sockets, sockets with no protocol, sockets bound to `B`, and
/// sockets with the wrong protocol receive nothing.
#[test]
fn receive_frame_deliver_to_multiple() {
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));

    use Protocol::*;
    use TargetDevice::*;
    let never_bound = {
        let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
        ctx.device_socket_api().create(state)
    };

    // Shadows the free `make_bound` with a closure that supplies a default
    // external state.
    let mut make_bound = |device, protocol| {
        let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
        make_bound(&mut ctx, device, protocol, state)
    };
    let bound_a_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::A), None);
    let bound_a_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::A), Some(All));
    let bound_a_right_protocol =
        make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
    let bound_a_wrong_protocol =
        make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
    let bound_b_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::B), None);
    let bound_b_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::B), Some(All));
    let bound_b_right_protocol =
        make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
    let bound_b_wrong_protocol =
        make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
    let bound_any_no_protocol = make_bound(AnyDevice, None);
    let bound_any_all_protocols = make_bound(AnyDevice, Some(All));
    let bound_any_right_protocol = make_bound(AnyDevice, Some(Specific(TestData::PROTO)));
    let bound_any_wrong_protocol = make_bound(AnyDevice, Some(Specific(WRONG_PROTO)));

    let mut sockets_with_received_frames = deliver_one_frame(
        super::ReceivedFrame::from_ethernet(
            TestData::frame(),
            FrameDestination::Individual { local: true },
        )
        .into(),
        &mut ctx,
    );

    let sockets_not_expecting_frames = [
        never_bound,
        bound_a_no_protocol,
        bound_a_wrong_protocol,
        bound_b_no_protocol,
        bound_b_all_protocols,
        bound_b_right_protocol,
        bound_b_wrong_protocol,
        bound_any_no_protocol,
        bound_any_wrong_protocol,
    ];
    let sockets_expecting_frames = [
        bound_a_all_protocols,
        bound_a_right_protocol,
        bound_any_all_protocols,
        bound_any_right_protocol,
    ];

    // Every expected socket got the frame, and nothing else did.
    for (n, socket) in sockets_expecting_frames.iter().enumerate() {
        assert!(
            sockets_with_received_frames.remove(&socket),
            "socket {n} didn't receive the frame"
        );
    }
    assert!(sockets_with_received_frames.is_empty());

    // Per-socket rx counters must agree with the delivery results.
    for (n, socket) in sockets_expecting_frames.iter().enumerate() {
        assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
    }
    for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
        assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
    }
}
1778
/// A frame *sent* on device `A` is delivered only to sockets bound to `A` or
/// to any device with `Protocol::All`.
///
/// Unlike the receive path, sockets with a matching specific protocol are
/// expected *not* to see it.
#[test]
fn sent_frame_deliver_to_multiple() {
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));

    use Protocol::*;
    use TargetDevice::*;
    let never_bound = {
        let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
        ctx.device_socket_api().create(state)
    };

    // Shadows the free `make_bound` with a closure that supplies a default
    // external state.
    let mut make_bound = |device, protocol| {
        let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
        make_bound(&mut ctx, device, protocol, state)
    };
    let bound_a_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::A), None);
    let bound_a_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::A), Some(All));
    let bound_a_same_protocol =
        make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
    let bound_a_wrong_protocol =
        make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
    let bound_b_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::B), None);
    let bound_b_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::B), Some(All));
    let bound_b_same_protocol =
        make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
    let bound_b_wrong_protocol =
        make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
    let bound_any_no_protocol = make_bound(AnyDevice, None);
    let bound_any_all_protocols = make_bound(AnyDevice, Some(All));
    let bound_any_same_protocol = make_bound(AnyDevice, Some(Specific(TestData::PROTO)));
    let bound_any_wrong_protocol = make_bound(AnyDevice, Some(Specific(WRONG_PROTO)));

    let mut sockets_with_received_frames =
        deliver_one_frame(SentFrame::Ethernet(TestData::frame().into()).into(), &mut ctx);

    let sockets_not_expecting_frames = [
        never_bound,
        bound_a_no_protocol,
        bound_a_same_protocol,
        bound_a_wrong_protocol,
        bound_b_no_protocol,
        bound_b_all_protocols,
        bound_b_same_protocol,
        bound_b_wrong_protocol,
        bound_any_no_protocol,
        bound_any_same_protocol,
        bound_any_wrong_protocol,
    ];
    // Only `Protocol::All` sockets on `A` or on any device see sent frames.
    let sockets_expecting_frames = [bound_a_all_protocols, bound_any_all_protocols];

    for (n, socket) in sockets_expecting_frames.iter().enumerate() {
        assert!(
            sockets_with_received_frames.remove(&socket),
            "socket {n} didn't receive the frame"
        );
    }
    assert!(sockets_with_received_frames.is_empty());

    // Per-socket rx counters must agree with the delivery results.
    for (n, socket) in sockets_expecting_frames.iter().enumerate() {
        assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
    }
    for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
        assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
    }
}
1848
/// Delivering the same received frame `RECEIVE_COUNT` times queues
/// `RECEIVE_COUNT` identical entries on an any-device/all-protocols socket.
/// It also bumps that socket's `rx_frames` counter by the same amount.
#[test]
fn deliver_multiple_frames() {
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));
    let socket = make_bound(
        &mut ctx,
        TargetDevice::AnyDevice,
        Some(Protocol::All),
        ExternalSocketState::default(),
    );
    let FakeCtx { mut core_ctx, mut bindings_ctx } = ctx;

    const RECEIVE_COUNT: usize = 10;
    for _ in 0..RECEIVE_COUNT {
        DeviceSocketHandler::handle_frame(
            &mut core_ctx,
            &mut bindings_ctx,
            &MultipleDevicesId::A,
            super::ReceivedFrame::from_ethernet(
                TestData::frame(),
                FrameDestination::Individual { local: true },
            )
            .into(),
            TestData::BUFFER,
        );
    }

    // Take ownership of the socket table so the primary reference can be
    // unwrapped below.
    let FakeSockets {
        all_sockets: AllSockets(mut all_sockets),
        any_device_sockets: _,
        device_sockets: _,
        counters: _,
        sent_frames: _,
    } = core_ctx.into_state();
    let primary = all_sockets.remove(&socket).unwrap();
    let PrimaryDeviceSocketId(primary) = primary;
    assert!(all_sockets.is_empty());
    // Drop the last strong reference before `PrimaryRc::unwrap`, which
    // presumably requires no outstanding strong references.
    drop(socket);
    let SocketState { external_state: ExternalSocketState(received), counters, target: _ } =
        PrimaryRc::unwrap(primary);
    assert_eq!(
        received.into_inner().frames,
        vec![
            ReceivedFrame {
                device: FakeWeakDeviceId(MultipleDevicesId::A),
                frame: Frame::Received(super::ReceivedFrame::Ethernet {
                    destination: FrameDestination::Individual { local: true },
                    frame: EthernetFrame {
                        src_mac: TestData::SRC_MAC,
                        dst_mac: TestData::DST_MAC,
                        ethertype: Some(TestData::PROTO.get().into()),
                        body_offset: TestData::BUFFER_OFFSET,
                        body: Vec::from(TestData::BODY),
                    }
                }),
                raw: TestData::BUFFER.into()
            };
            RECEIVE_COUNT
        ]
    );
    assert_eq!(counters.rx_frames.get(), u64::try_from(RECEIVE_COUNT).unwrap());
}
1912
/// When one matching socket's receive queue is full, the frame is still
/// counted in `rx_frames` for that socket.
///
/// The socket's `rx_queue_full` counter is also bumped. Other sockets are
/// unaffected, and the stack-wide counters aggregate both sockets.
#[test]
fn deliver_frame_queue_full() {
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));

    let sock1 = make_bound(
        // `max_size: 0` gives this socket a queue that cannot accept frames;
        // the assertions below expect its delivery to hit the queue-full path.
        &mut ctx,
        TargetDevice::AnyDevice,
        Some(Protocol::All),
        ExternalSocketState(Mutex::new(testutil::RxQueue { frames: vec![], max_size: 0 })),
    );
    let sock2 = make_bound(
        &mut ctx,
        TargetDevice::AnyDevice,
        Some(Protocol::All),
        ExternalSocketState::default(),
    );

    let FakeCtx { mut core_ctx, mut bindings_ctx } = ctx;

    DeviceSocketHandler::handle_frame(
        &mut core_ctx,
        &mut bindings_ctx,
        &MultipleDevicesId::A,
        super::ReceivedFrame::from_ethernet(
            TestData::frame(),
            FrameDestination::Individual { local: true },
        )
        .into(),
        TestData::BUFFER,
    );

    assert_eq!(core_ctx.state.counters.rx_frames.get(), 2);
    assert_eq!(core_ctx.state.counters.rx_queue_full.get(), 1);
    assert_eq!(sock1.counters().rx_frames.get(), 1);
    assert_eq!(sock1.counters().rx_queue_full.get(), 1);
    assert_eq!(sock2.counters().rx_frames.get(), 1);
    assert_eq!(sock2.counters().rx_queue_full.get(), 0);

    // Explicitly release the socket references before the context goes away.
    drop(sock1);
    drop(sock2);
}
1959
/// Placeholder send metadata for the fake `AnyDevice` device class.
pub struct FakeSendMetadata;
// Sockets sending on `AnyDevice` in these tests use `FakeSendMetadata`.
impl DeviceSocketSendTypes for AnyDevice {
    type Metadata = FakeSendMetadata;
}
impl<BC, D: FakeStrongDeviceId> SendableFrameMeta<FakeCoreCtx<D>, BC>
    for DeviceSocketMetadata<AnyDevice, D>
{
    /// Serializes `frame` and records its bytes in the fake state's
    /// `sent_frames` instead of transmitting it. Always returns `Ok(())`.
    fn send_meta<S>(
        self,
        core_ctx: &mut FakeCoreCtx<D>,
        _bindings_ctx: &mut BC,
        frame: S,
    ) -> Result<(), SendFrameError<S>>
    where
        S: NetworkSerializer,
        S::Buffer: BufferMut,
    {
        let frame = match frame.serialize_vec_outer(&mut NetworkSerializationContext::default())
        {
            Err(e) => {
                // Compile-time pin of the error type to
                // `SerializeError<Infallible>`; the test assumes serializing
                // test payloads never fails.
                let _: (packet::SerializeError<core::convert::Infallible>, _) = e;
                unreachable!()
            }
            // NOTE(review): `unwrap_a` takes the `A` side of the returned
            // `Either`; presumably the only variant produced here. Confirm.
            Ok(frame) => frame.unwrap_a().as_ref().to_vec(),
        };
        core_ctx.state.sent_frames.push(frame);
        Ok(())
    }
}
1988 }
1989
/// Sending the same payload `SEND_COUNT` times records `SEND_COUNT` copies in
/// the fake sent-frame log and bumps the socket's `tx_frames` counter to match.
#[test]
fn send_multiple_frames() {
    const DEVICE: MultipleDevicesId = MultipleDevicesId::A;
    const SEND_COUNT: usize = 10;
    const PAYLOAD: &'static [u8] = &[1, 2, 3, 4, 5];

    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));

    let socket = make_bound(
        &mut ctx,
        TargetDevice::SpecificDevice(DEVICE),
        Some(Protocol::All),
        ExternalSocketState::default(),
    );
    let mut api = ctx.device_socket_api();

    (0..SEND_COUNT).for_each(|_| {
        let buf = packet::Buf::new(PAYLOAD.to_vec(), ..);
        let metadata = DeviceSocketMetadata { device_id: DEVICE, metadata: FakeSendMetadata };
        api.send_frame(&socket, metadata, buf).expect("send failed");
    });

    let expected_frames = vec![PAYLOAD.to_vec(); SEND_COUNT];
    assert_eq!(ctx.core_ctx().state.sent_frames, expected_frames);

    assert_eq!(socket.counters().tx_frames.get(), u64::try_from(SEND_COUNT).unwrap());
}
2021}