1use core::fmt::Debug;
8use core::hash::Hash;
9use core::num::NonZeroU16;
10
11use derivative::Derivative;
12use lock_order::lock::{OrderedLockAccess, OrderedLockRef};
13use net_types::ethernet::Mac;
14use net_types::ip::IpVersion;
15use netstack3_base::socket::SocketCookie;
16use netstack3_base::sync::{Mutex, PrimaryRc, RwLock, StrongRc, WeakRc};
17use netstack3_base::{
18 AnyDevice, ContextPair, Counter, Device, DeviceIdContext, FrameDestination, Inspectable,
19 Inspector, InspectorDeviceExt, InspectorExt, ReferenceNotifiers, ReferenceNotifiersExt as _,
20 RemoveResourceResultWithContext, ResourceCounterContext, SendFrameContext,
21 SendFrameErrorReason, StrongDeviceIdentifier, WeakDeviceIdentifier as _,
22};
23use netstack3_hashmap::{HashMap, HashSet};
24use packet::{BufferMut, ParsablePacket as _, Serializer};
25use packet_formats::error::ParseError;
26use packet_formats::ethernet::{EtherType, EthernetFrameLengthCheck};
27
28use crate::internal::base::DeviceLayerTypes;
29use crate::internal::id::WeakDeviceId;
30
/// The protocol filter applied to a device socket.
///
/// The filter is compared against a frame's EtherType when frames are
/// offered to sockets.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Protocol {
    /// Accept frames of every protocol, both sent and received.
    All,
    /// Accept only received frames whose EtherType equals this value.
    Specific(NonZeroU16),
}
39
40#[derive(Clone, Debug, Derivative, Eq, Hash, PartialEq)]
42#[derivative(Default(bound = ""))]
43pub enum TargetDevice<D> {
44 #[derivative(Default)]
46 AnyDevice,
47 SpecificDevice(D),
49}
50
51#[derive(Debug)]
53#[cfg_attr(test, derive(PartialEq))]
54pub struct SocketInfo<D> {
55 pub protocol: Option<Protocol>,
57 pub device: TargetDevice<D>,
59}
60
/// Types provided by bindings for device sockets.
pub trait DeviceSocketTypes {
    /// Per-socket state owned by bindings, stored alongside core's state for
    /// the socket. `D` is the device identifier type the socket refers to
    /// (a weak device ID in this module's uses).
    type SocketState<D: Send + Sync + Debug>: Send + Sync + Debug;
}

/// Error reported by bindings when a frame could not be delivered to a socket.
pub enum ReceiveFrameError {
    /// The socket's receive queue is full, so the frame could not be queued.
    QueueFull,
}

/// Bindings-side hooks for delivering frames to device sockets.
pub trait DeviceSocketBindingsContext<DeviceId: StrongDeviceIdentifier>:
    DeviceSocketTypes + Sized
{
    /// Delivers a frame that passed `socket_id`'s filters.
    ///
    /// `device` is the device the frame was observed on, `frame` is the
    /// parsed view of the frame, and `raw_frame` is the raw frame bytes.
    /// Returns [`ReceiveFrameError::QueueFull`] if the socket can't accept
    /// the frame.
    fn receive_frame(
        &self,
        socket_id: &DeviceSocketId<DeviceId::Weak, Self>,
        device: &DeviceId,
        frame: Frame<&[u8]>,
        raw_frame: &[u8],
    ) -> Result<(), ReceiveFrameError>;
}
89
90#[derive(Debug)]
94pub struct PrimaryDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
95 PrimaryRc<SocketState<D, BT>>,
96);
97
98impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> PrimaryDeviceSocketId<D, BT> {
99 fn new(external_state: BT::SocketState<D>) -> Self {
101 Self(PrimaryRc::new(SocketState {
102 external_state,
103 counters: Default::default(),
104 target: Default::default(),
105 }))
106 }
107
108 fn clone_strong(&self) -> DeviceSocketId<D, BT> {
110 let PrimaryDeviceSocketId(rc) = self;
111 DeviceSocketId(PrimaryRc::clone_strong(rc))
112 }
113}
114
115#[derive(Derivative)]
120#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
121pub struct DeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
122 StrongRc<SocketState<D, BT>>,
123);
124
125impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> DeviceSocketId<D, BT> {
126 pub fn socket_cookie(&self) -> SocketCookie {
128 let Self(rc) = self;
129 SocketCookie::new(rc.resource_token())
130 }
131}
132
133impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for DeviceSocketId<D, BT> {
134 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
135 let Self(rc) = self;
136 f.debug_tuple("DeviceSocketId").field(&StrongRc::debug_id(rc)).finish()
137 }
138}
139
140impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<Target<D>>
141 for DeviceSocketId<D, BT>
142{
143 type Lock = Mutex<Target<D>>;
144 fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
145 let Self(rc) = self;
146 OrderedLockRef::new(&rc.target)
147 }
148}
149
150#[derive(Derivative)]
155#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
156pub struct WeakDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
157 WeakRc<SocketState<D, BT>>,
158);
159
160impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for WeakDeviceSocketId<D, BT> {
161 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
162 let Self(rc) = self;
163 f.debug_tuple("WeakDeviceSocketId").field(&WeakRc::debug_id(rc)).finish()
164 }
165}
166
/// Shared device-socket tables: the any-device table and the table of all
/// sockets.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct Sockets<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
    /// Strong (non-owning) references to sockets whose target is
    /// [`TargetDevice::AnyDevice`].
    any_device_sockets: RwLock<AnyDeviceSockets<D, BT>>,

    /// Every socket, along with its primary (owning) reference.
    all_sockets: RwLock<AllSockets<D, BT>>,
}

/// Sockets not bound to a specific device. Frames handled on any device are
/// offered to these sockets.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct AnyDeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
    HashSet<DeviceSocketId<D, BT>>,
);

/// Every live socket, mapping a strong reference to the socket's primary
/// reference.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct AllSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
    HashMap<DeviceSocketId<D, BT>, PrimaryDeviceSocketId<D, BT>>,
);

/// State shared by every reference to a single socket.
#[derive(Debug)]
pub struct SocketState<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
    /// Bindings-owned state; handed back to bindings when the socket is
    /// removed.
    pub external_state: BT::SocketState<D>,
    /// The socket's protocol and device filter.
    target: Mutex<Target<D>>,
    /// Counters for frames this socket sent and received.
    counters: DeviceSocketCounters,
}

/// A socket's filter: the protocol it accepts and the device it is bound to.
///
/// The default is no protocol (`None`) and [`TargetDevice::AnyDevice`].
#[derive(Debug, Derivative)]
#[derivative(Default(bound = ""))]
pub struct Target<D> {
    protocol: Option<Protocol>,
    device: TargetDevice<D>,
}

/// The sockets bound to a single device.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
#[cfg_attr(
    test,
    derivative(Debug, PartialEq(bound = "BT::SocketState<D>: Hash + Eq, D: Hash + Eq"))
)]
pub struct DeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
    HashSet<DeviceSocketId<D, BT>>,
);

/// [`DeviceSockets`] keyed by this crate's [`WeakDeviceId`].
pub type HeldDeviceSockets<BT> = DeviceSockets<WeakDeviceId<BT>, BT>;

/// [`Sockets`] keyed by this crate's [`WeakDeviceId`].
pub type HeldSockets<BT> = Sockets<WeakDeviceId<BT>, BT>;
240
/// Core context that gives access to the shared device-socket tables.
pub trait DeviceSocketContext<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
    /// The core context passed to callbacks while a table is held. It can
    /// reach per-device tables and per-socket state.
    type SocketTablesCoreCtx<'a>: DeviceSocketAccessor<
            BT,
            DeviceId = Self::DeviceId,
            WeakDeviceId = Self::WeakDeviceId,
        >;

    /// Calls `cb` with shared access to the table of all sockets.
    fn with_all_device_sockets<
        F: FnOnce(&AllSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R;

    /// Calls `cb` with mutable access to the table of all sockets.
    fn with_all_device_sockets_mut<F: FnOnce(&mut AllSockets<Self::WeakDeviceId, BT>) -> R, R>(
        &mut self,
        cb: F,
    ) -> R;

    /// Calls `cb` with shared access to the any-device table.
    fn with_any_device_sockets<
        F: FnOnce(&AnyDeviceSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R;

    /// Calls `cb` with mutable access to the any-device table.
    fn with_any_device_sockets_mut<
        F: FnOnce(
            &mut AnyDeviceSockets<Self::WeakDeviceId, BT>,
            &mut Self::SocketTablesCoreCtx<'_>,
        ) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R;
}
284
/// Core context that gives access to a single socket's [`Target`].
pub trait SocketStateAccessor<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
    /// Calls `cb` with shared access to `socket`'s target.
    fn with_socket_state<F: FnOnce(&Target<Self::WeakDeviceId>) -> R, R>(
        &mut self,
        socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
        cb: F,
    ) -> R;

    /// Calls `cb` with mutable access to `socket`'s target.
    fn with_socket_state_mut<F: FnOnce(&mut Target<Self::WeakDeviceId>) -> R, R>(
        &mut self,
        socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
        cb: F,
    ) -> R;
}
301
/// Core context that gives access to a device's socket table.
pub trait DeviceSocketAccessor<BT: DeviceSocketTypes>: SocketStateAccessor<BT> {
    /// The core context passed to callbacks while a device's table is held.
    /// It can access socket state and increment socket counters.
    type DeviceSocketCoreCtx<'a>: SocketStateAccessor<BT, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>
        + ResourceCounterContext<DeviceSocketId<Self::WeakDeviceId, BT>, DeviceSocketCounters>;

    /// Calls `cb` with shared access to `device`'s socket table.
    fn with_device_sockets<
        F: FnOnce(&DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> R;

    /// Calls `cb` with mutable access to `device`'s socket table.
    fn with_device_sockets_mut<
        F: FnOnce(&mut DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> R;
}
330
/// Whether an optional setting should change, and to what.
enum MaybeUpdate<T> {
    /// Leave the current value untouched.
    NoChange,
    /// Overwrite the current value with this one.
    NewValue(T),
}
335
/// Rebinds `socket` to `new_device`, optionally replacing its protocol.
///
/// Keeps the socket tables consistent with the socket's [`Target`]. The
/// socket is removed from the table for its previous target: the any-device
/// table, or the previous device's table if that device can still be
/// upgraded. It is then inserted into the table for `new_device`.
///
/// # Panics
///
/// Panics if the socket is missing from the table for its previous target,
/// or is already present in the table for `new_device`.
fn update_device_and_protocol<CC: DeviceSocketContext<BT>, BT: DeviceSocketTypes>(
    core_ctx: &mut CC,
    socket: &DeviceSocketId<CC::WeakDeviceId, BT>,
    new_device: TargetDevice<&CC::DeviceId>,
    protocol_update: MaybeUpdate<Protocol>,
) {
    // The any-device table is held for the whole update, even if neither the
    // old nor the new target is `AnyDevice`. `handle_frame` walks the
    // any-device table and a device's table while holding the any-device
    // table, so this keeps a move between tables from being seen half-done
    // by frame delivery. This assumes the `_mut` accessor excludes readers
    // (TODO confirm the lock semantics).
    core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
        let old_device = core_ctx.with_socket_state_mut(socket, |Target { protocol, device }| {
            match protocol_update {
                MaybeUpdate::NewValue(p) => *protocol = Some(p),
                MaybeUpdate::NoChange => (),
            };
            // Detach from the old target. A device-bound socket yields its old
            // device (if it is still alive) so that device's table can be
            // updated after this callback returns.
            let old_device = match &device {
                TargetDevice::SpecificDevice(device) => device.upgrade(),
                TargetDevice::AnyDevice => {
                    assert!(any_device_sockets.remove(socket));
                    None
                }
            };
            *device = match &new_device {
                TargetDevice::AnyDevice => TargetDevice::AnyDevice,
                TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d.downgrade()),
            };
            old_device
        });

        // Remove the socket from the old device's table. If the old device
        // could not be upgraded, there is no table left to update.
        if let Some(device) = old_device {
            core_ctx.with_device_sockets_mut(
                &device,
                |DeviceSockets(device_sockets), _core_ctx| {
                    assert!(device_sockets.remove(socket), "socket not found in device state");
                },
            );
        }

        // Register the socket in the table for its new target.
        match &new_device {
            TargetDevice::SpecificDevice(new_device) => core_ctx.with_device_sockets_mut(
                new_device,
                |DeviceSockets(device_sockets), _core_ctx| {
                    assert!(device_sockets.insert(socket.clone()));
                },
            ),
            TargetDevice::AnyDevice => {
                assert!(any_device_sockets.insert(socket.clone()))
            }
        }
    })
}
397
/// Entry point for the device socket API; wraps a context pair `C`.
pub struct DeviceSocketApi<C>(C);

impl<C> DeviceSocketApi<C> {
    /// Builds the API around `ctx`.
    pub fn new(ctx: C) -> Self {
        DeviceSocketApi(ctx)
    }
}
407
/// The socket ID type used by [`DeviceSocketApi`] for context pair `C`.
///
/// It is a [`DeviceSocketId`] keyed by the core context's weak device ID and
/// the bindings context's socket types.
type ApiSocketId<C> = DeviceSocketId<
    <<C as ContextPair>::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
    <C as ContextPair>::BindingsContext,
>;
416
/// Device socket operations, available for any context pair whose core and
/// bindings contexts implement the device-socket traits.
impl<C> DeviceSocketApi<C>
where
    C: ContextPair,
    C::CoreContext: DeviceSocketContext<C::BindingsContext>
        + SocketStateAccessor<C::BindingsContext>
        + ResourceCounterContext<ApiSocketId<C>, DeviceSocketCounters>,
    C::BindingsContext: DeviceSocketBindingsContext<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>
        + ReferenceNotifiers
        + 'static,
{
    /// Returns the core context from the wrapped pair.
    fn core_ctx(&mut self) -> &mut C::CoreContext {
        let Self(pair) = self;
        pair.core_ctx()
    }

    /// Returns both the core and bindings contexts from the wrapped pair.
    fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
        let Self(pair) = self;
        pair.contexts()
    }

    /// Creates a new socket that holds `external_state`.
    ///
    /// The socket starts with the default [`Target`] (no protocol, any
    /// device), so it is added to the any-device table. Its primary
    /// reference is stored in [`AllSockets`]. The returned ID is a strong
    /// reference.
    pub fn create(
        &mut self,
        external_state: <C::BindingsContext as DeviceSocketTypes>::SocketState<
            <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
        >,
    ) -> ApiSocketId<C> {
        let core_ctx = self.core_ctx();

        let strong = core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
            let primary = PrimaryDeviceSocketId::new(external_state);
            let strong = primary.clone_strong();
            assert!(sockets.insert(strong.clone(), primary).is_none());
            strong
        });
        core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), _core_ctx| {
            // A new socket targets `AnyDevice` by default, so it belongs in
            // the any-device table.
            assert!(any_device_sockets.insert(strong.clone()));
        });
        strong
    }

    /// Binds `socket` to `device` and leaves its protocol unchanged.
    pub fn set_device(
        &mut self,
        socket: &ApiSocketId<C>,
        device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
    ) {
        update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NoChange)
    }

    /// Binds `socket` to `device` and sets its protocol filter to `protocol`.
    pub fn set_device_and_protocol(
        &mut self,
        socket: &ApiSocketId<C>,
        device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
        protocol: Protocol,
    ) {
        update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NewValue(protocol))
    }

    /// Returns a snapshot of the socket's protocol filter and device binding.
    pub fn get_info(
        &mut self,
        id: &ApiSocketId<C>,
    ) -> SocketInfo<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
        self.core_ctx().with_socket_state(id, |Target { device, protocol }| SocketInfo {
            device: device.clone(),
            protocol: *protocol,
        })
    }

    /// Removes a socket and returns its bindings state.
    ///
    /// `id` is removed from the table for its current target, the caller's
    /// reference is dropped, and the primary reference is taken out of
    /// [`AllSockets`]. The result comes from
    /// `unwrap_or_notify_with_new_reference_notifier`. Presumably it yields
    /// the bindings state right away if no other strong references remain,
    /// and otherwise a notifier to wait on (confirm against
    /// `ReferenceNotifiersExt`).
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - the socket is missing from the any-device table, or from the table
    ///   of its (still-live) device;
    /// - the socket is missing from [`AllSockets`].
    pub fn remove(
        &mut self,
        id: ApiSocketId<C>,
    ) -> RemoveResourceResultWithContext<
        <C::BindingsContext as DeviceSocketTypes>::SocketState<
            <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
        >,
        C::BindingsContext,
    > {
        let core_ctx = self.core_ctx();
        core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
            let old_device = core_ctx.with_socket_state_mut(&id, |target| {
                let Target { device, protocol: _ } = target;
                match &device {
                    TargetDevice::SpecificDevice(device) => device.upgrade(),
                    TargetDevice::AnyDevice => {
                        assert!(any_device_sockets.remove(&id));
                        None
                    }
                }
            });
            // If the old device is gone, there is no device table to update.
            if let Some(device) = old_device {
                core_ctx.with_device_sockets_mut(
                    &device,
                    |DeviceSockets(device_sockets), _core_ctx| {
                        assert!(device_sockets.remove(&id), "device doesn't have socket");
                    },
                )
            }
        });

        core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
            let primary = sockets
                .remove(&id)
                .unwrap_or_else(|| panic!("{id:?} not present in all socket map"));
            // Drop the caller's strong reference before unwrapping the primary
            // one, so it is not counted as an outstanding reference.
            drop(id);

            let PrimaryDeviceSocketId(primary) = primary;
            C::BindingsContext::unwrap_or_notify_with_new_reference_notifier(
                primary,
                |SocketState { external_state, counters: _, target: _ }| external_state,
            )
        })
    }

    /// Sends `body` as a frame on behalf of socket `id`.
    ///
    /// The frame goes to the device named in `metadata`. `id` is used only to
    /// attribute the outcome to the socket's counters: `tx_frames` on
    /// success, or the `tx_err_*` counter that matches the error. On failure
    /// the [`SendFrameErrorReason`] is returned.
    pub fn send_frame<S, D>(
        &mut self,
        id: &ApiSocketId<C>,
        metadata: DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
        body: S,
    ) -> Result<(), SendFrameErrorReason>
    where
        S: Serializer,
        S::Buffer: BufferMut,
        D: DeviceSocketSendTypes,
        C::CoreContext: DeviceIdContext<D>
            + SendFrameContext<
                C::BindingsContext,
                DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
            >,
        C::BindingsContext: DeviceLayerTypes,
    {
        let (core_ctx, bindings_ctx) = self.contexts();
        let result = core_ctx.send_frame(bindings_ctx, metadata, body).map_err(|e| e.into_err());
        match &result {
            Ok(()) => {
                core_ctx.increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_frames)
            }
            Err(SendFrameErrorReason::QueueFull) => core_ctx
                .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_queue_full),
            Err(SendFrameErrorReason::Alloc) => core_ctx
                .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_alloc),
            Err(SendFrameErrorReason::SizeConstraintsViolation) => core_ctx
                .increment_both(id, |counters: &DeviceSocketCounters| {
                    &counters.tx_err_size_constraint
                }),
        }
        result
    }

    /// Records each socket into `inspector` as a child node. The node holds
    /// the socket's protocol, its device ("Any" when not bound to a device),
    /// and its counters.
    pub fn inspect<N>(&mut self, inspector: &mut N)
    where
        N: Inspector
            + InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId>,
    {
        self.core_ctx().with_all_device_sockets(|AllSockets(sockets), core_ctx| {
            sockets.keys().for_each(|socket| {
                inspector.record_debug_child(socket, |node| {
                    core_ctx.with_socket_state(socket, |Target { protocol, device }| {
                        node.record_debug("Protocol", protocol);
                        match device {
                            TargetDevice::AnyDevice => node.record_str("Device", "Any"),
                            TargetDevice::SpecificDevice(d) => N::record_device(node, "Device", d),
                        }
                    });
                    node.record_child("Counters", |node| {
                        node.delegate_inspectable(socket.counters())
                    })
                })
            })
        })
    }
}
602
/// A device type that device sockets can send frames on.
pub trait DeviceSocketSendTypes: Device {
    /// Device-type-specific metadata that goes with a frame sent through a
    /// device socket (carried in [`DeviceSocketMetadata::metadata`]).
    type Metadata;
}
608
609#[derive(Debug, PartialEq)]
611pub struct DeviceSocketMetadata<D: DeviceSocketSendTypes, DeviceId> {
612 pub device_id: DeviceId,
614 pub metadata: D::Metadata,
616 }
619
620#[derive(Debug, PartialEq)]
622pub struct EthernetHeaderParams {
623 pub dest_addr: Mac,
625 pub protocol: EtherType,
627}
628
/// A [`DeviceSocketId`] keyed by this crate's [`WeakDeviceId`], for bindings
/// context `BC`.
pub type SocketId<BC> = DeviceSocketId<WeakDeviceId<BC>, BC>;
634
635impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> DeviceSocketId<D, BT> {
636 pub fn socket_state(&self) -> &BT::SocketState<D> {
639 let Self(strong) = self;
640 let SocketState { external_state, counters: _, target: _ } = &**strong;
641 external_state
642 }
643
644 pub fn downgrade(&self) -> WeakDeviceSocketId<D, BT> {
646 let Self(inner) = self;
647 WeakDeviceSocketId(StrongRc::downgrade(inner))
648 }
649
650 pub fn counters(&self) -> &DeviceSocketCounters {
652 let Self(strong) = self;
653 let SocketState { external_state: _, counters, target: _ } = &**strong;
654 counters
655 }
656}
657
/// Core-side delivery of frames seen on a device to device sockets.
pub trait DeviceSocketHandler<D: Device, BC>: DeviceIdContext<D> {
    /// Offers `frame`, which was seen on `device`, to every matching socket.
    ///
    /// `frame` is the parsed form. `whole_frame` is the raw frame bytes and is
    /// passed to bindings as `raw_frame`.
    fn handle_frame(
        &mut self,
        bindings_ctx: &mut BC,
        device: &Self::DeviceId,
        frame: Frame<&[u8]>,
        whole_frame: &[u8],
    );
}
672
673#[derive(Clone, Copy, Debug, Eq, PartialEq)]
675pub enum ReceivedFrame<B> {
676 Ethernet {
678 destination: FrameDestination,
680 frame: EthernetFrame<B>,
682 },
683 Ip(IpFrame<B>),
688}
689
690#[derive(Clone, Copy, Debug, Eq, PartialEq)]
692pub enum SentFrame<B> {
693 Ethernet(EthernetFrame<B>),
695 Ip(IpFrame<B>),
700}
701
702#[derive(Debug)]
704pub struct ParseSentFrameError;
705
706impl SentFrame<&[u8]> {
707 pub fn try_parse_as_ethernet(mut buf: &[u8]) -> Result<SentFrame<&[u8]>, ParseSentFrameError> {
709 packet_formats::ethernet::EthernetFrame::parse(&mut buf, EthernetFrameLengthCheck::NoCheck)
710 .map_err(|_: ParseError| ParseSentFrameError)
711 .map(|frame| SentFrame::Ethernet(frame.into()))
712 }
713}
714
715#[derive(Clone, Copy, Debug, Eq, PartialEq)]
717pub struct EthernetFrame<B> {
718 pub src_mac: Mac,
720 pub dst_mac: Mac,
722 pub ethertype: Option<EtherType>,
724 pub body_offset: usize,
726 pub body: B,
728}
729
730#[derive(Clone, Copy, Debug, Eq, PartialEq)]
732pub struct IpFrame<B> {
733 pub ip_version: IpVersion,
735 pub body: B,
737}
738
739impl<B> IpFrame<B> {
740 fn ethertype(&self) -> EtherType {
741 let IpFrame { ip_version, body: _ } = self;
742 EtherType::from_ip_version(*ip_version)
743 }
744}
745
746#[derive(Clone, Copy, Debug, Eq, PartialEq)]
748pub enum Frame<B> {
749 Sent(SentFrame<B>),
751 Received(ReceivedFrame<B>),
753}
754
755impl<B> From<SentFrame<B>> for Frame<B> {
756 fn from(value: SentFrame<B>) -> Self {
757 Self::Sent(value)
758 }
759}
760
761impl<B> From<ReceivedFrame<B>> for Frame<B> {
762 fn from(value: ReceivedFrame<B>) -> Self {
763 Self::Received(value)
764 }
765}
766
767impl<'a> From<packet_formats::ethernet::EthernetFrame<&'a [u8]>> for EthernetFrame<&'a [u8]> {
768 fn from(frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>) -> Self {
769 Self {
770 src_mac: frame.src_mac(),
771 dst_mac: frame.dst_mac(),
772 ethertype: frame.ethertype(),
773 body_offset: frame.parse_metadata().header_len(),
774 body: frame.into_body(),
775 }
776 }
777}
778
779impl<'a> ReceivedFrame<&'a [u8]> {
780 pub(crate) fn from_ethernet(
781 frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>,
782 destination: FrameDestination,
783 ) -> Self {
784 Self::Ethernet { destination, frame: frame.into() }
785 }
786}
787
788impl<B> Frame<B> {
789 pub fn protocol(&self) -> Option<u16> {
791 let ethertype = match self {
792 Self::Sent(SentFrame::Ethernet(frame))
793 | Self::Received(ReceivedFrame::Ethernet { destination: _, frame }) => frame.ethertype,
794 Self::Sent(SentFrame::Ip(frame)) | Self::Received(ReceivedFrame::Ip(frame)) => {
795 Some(frame.ethertype())
796 }
797 };
798 ethertype.map(Into::into)
799 }
800
801 pub fn into_body(self) -> B {
803 match self {
804 Self::Received(ReceivedFrame::Ethernet { destination: _, frame })
805 | Self::Sent(SentFrame::Ethernet(frame)) => frame.body,
806 Self::Received(ReceivedFrame::Ip(frame)) | Self::Sent(SentFrame::Ip(frame)) => {
807 frame.body
808 }
809 }
810 }
811
812 pub fn body_offset(&self) -> usize {
814 match self {
815 Self::Received(ReceivedFrame::Ethernet { destination: _, frame })
816 | Self::Sent(SentFrame::Ethernet(frame)) => frame.body_offset,
817 Self::Received(ReceivedFrame::Ip(_)) | Self::Sent(SentFrame::Ip(_)) => 0,
818 }
819 }
820}
821
/// Frame delivery to device sockets. Implemented for any core context whose
/// device IDs for `D` convert into its any-device ID type.
impl<
    D: Device,
    BC: DeviceSocketBindingsContext<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
    CC: DeviceSocketContext<BC> + DeviceIdContext<D>,
> DeviceSocketHandler<D, BC> for CC
where
    <CC as DeviceIdContext<D>>::DeviceId: Into<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
{
    /// Offers `frame` to each socket in the any-device table and in
    /// `device`'s table. The frame goes to sockets whose protocol filter
    /// matches, and each recipient's receive counters are updated.
    fn handle_frame(
        &mut self,
        bindings_ctx: &mut BC,
        device: &Self::DeviceId,
        frame: Frame<&[u8]>,
        whole_frame: &[u8],
    ) {
        let device = device.clone().into();

        // The device's table is walked while the any-device table is still
        // held. `update_device_and_protocol` also holds the any-device table
        // while it moves a socket between tables, so delivery can't see a
        // socket in both tables (double delivery) or in neither. This assumes
        // the `_mut` accessor takes that lock exclusively.
        self.with_any_device_sockets(|AnyDeviceSockets(any_device_sockets), core_ctx| {
            core_ctx.with_device_sockets(&device, |DeviceSockets(device_sockets), core_ctx| {
                for socket in any_device_sockets.iter().chain(device_sockets) {
                    let delivered =
                        core_ctx.with_socket_state(socket, |Target { protocol, device: _ }| {
                            let should_deliver = match protocol {
                                // No protocol set: the socket receives nothing.
                                None => false,
                                Some(p) => match p {
                                    // Sockets filtering on one protocol only see
                                    // received frames with a matching EtherType.
                                    // Sent frames go only to `Protocol::All`
                                    // sockets (presumably to mirror Linux
                                    // packet-socket behavior; confirm).
                                    Protocol::Specific(p) => match frame {
                                        Frame::Received(_) => Some(p.get()) == frame.protocol(),
                                        Frame::Sent(_) => false,
                                    },
                                    Protocol::All => true,
                                },
                            };
                            should_deliver.then(|| {
                                bindings_ctx.receive_frame(socket, &device, frame, whole_frame)
                            })
                        });
                    match delivered {
                        // Filtered out: no counters change.
                        None => {}
                        Some(result) => {
                            // `rx_frames` counts every frame handed to bindings,
                            // including frames bindings reject as queue-full.
                            core_ctx.increment_both(socket, |counters: &DeviceSocketCounters| {
                                &counters.rx_frames
                            });
                            match result {
                                Ok(()) => {}
                                Err(ReceiveFrameError::QueueFull) => {
                                    core_ctx.increment_both(
                                        socket,
                                        |counters: &DeviceSocketCounters| &counters.rx_queue_full,
                                    );
                                }
                            }
                        }
                    }
                }
            })
        })
    }
}
900
901#[derive(Debug, Default)]
905pub struct DeviceSocketCounters {
906 rx_frames: Counter,
912 rx_queue_full: Counter,
915 tx_frames: Counter,
917 tx_err_queue_full: Counter,
919 tx_err_alloc: Counter,
921 tx_err_size_constraint: Counter,
923}
924
925impl Inspectable for DeviceSocketCounters {
926 fn record<I: Inspector>(&self, inspector: &mut I) {
927 let Self {
928 rx_frames,
929 rx_queue_full,
930 tx_frames,
931 tx_err_queue_full,
932 tx_err_alloc,
933 tx_err_size_constraint,
934 } = self;
935 inspector.record_child("Rx", |inspector| {
936 inspector.record_counter("DeliveredFrames", rx_frames);
937 inspector.record_counter("DroppedQueueFull", rx_queue_full);
938 });
939 inspector.record_child("Tx", |inspector| {
940 inspector.record_counter("SentFrames", tx_frames);
941 inspector.record_counter("QueueFullError", tx_err_queue_full);
942 inspector.record_counter("AllocError", tx_err_alloc);
943 inspector.record_counter("SizeConstraintError", tx_err_size_constraint);
944 });
945 }
946}
947
948impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AnyDeviceSockets<D, BT>>
949 for Sockets<D, BT>
950{
951 type Lock = RwLock<AnyDeviceSockets<D, BT>>;
952 fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
953 OrderedLockRef::new(&self.any_device_sockets)
954 }
955}
956
957impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AllSockets<D, BT>>
958 for Sockets<D, BT>
959{
960 type Lock = RwLock<AllSockets<D, BT>>;
961 fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
962 OrderedLockRef::new(&self.all_sockets)
963 }
964}
965
/// Test utilities: fake bindings-side socket state, plus the trait impls that
/// let `FakeBindingsCtx` be used with device sockets.
#[cfg(any(test, feature = "testutils"))]
mod testutil {
    use alloc::vec::Vec;
    use core::num::NonZeroU64;
    use core::ops::DerefMut;
    use netstack3_base::StrongDeviceIdentifier;
    use netstack3_base::testutil::{FakeBindingsCtx, MonotonicIdentifier};

    use super::*;
    use crate::internal::base::{
        DeviceClassMatcher, DeviceIdAndNameMatcher, DeviceLayerStateTypes,
    };

    /// A bounded queue of frames received by a fake socket.
    #[derive(Derivative, Debug)]
    #[derivative(Default(bound = ""))]
    pub struct RxQueue<D> {
        /// Received frames, in arrival order.
        pub frames: Vec<ReceivedFrame<D>>,
        /// The most frames the queue holds. `receive_frame` rejects frames
        /// beyond this with `ReceiveFrameError::QueueFull`. Defaults to
        /// `usize::MAX`.
        #[derivative(Default(value = "usize::MAX"))]
        pub max_size: usize,
    }

    /// A frame recorded by a fake socket, with owned copies of its bytes.
    #[derive(Clone, Debug, PartialEq)]
    pub struct ReceivedFrame<D> {
        /// The device the frame was seen on. `receive_frame` stores the
        /// downgraded device ID here.
        pub device: D,
        /// The parsed frame, with an owned body.
        pub frame: Frame<Vec<u8>>,
        /// The raw frame bytes.
        pub raw: Vec<u8>,
    }

    /// Bindings-side socket state used with `FakeBindingsCtx`.
    #[derive(Debug, Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct ExternalSocketState<D>(pub Mutex<RxQueue<D>>);

    impl<TimerId, Event: Debug, State> DeviceSocketTypes
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        type SocketState<D: Send + Sync + Debug> = ExternalSocketState<D>;
    }

    impl Frame<&[u8]> {
        /// Converts a borrowed frame into one that owns a copy of its body.
        pub(crate) fn cloned(self) -> Frame<Vec<u8>> {
            match self {
                Self::Sent(SentFrame::Ethernet(frame)) => {
                    Frame::Sent(SentFrame::Ethernet(frame.cloned()))
                }
                Self::Received(super::ReceivedFrame::Ethernet { destination, frame }) => {
                    Frame::Received(super::ReceivedFrame::Ethernet {
                        destination,
                        frame: frame.cloned(),
                    })
                }
                Self::Sent(SentFrame::Ip(frame)) => Frame::Sent(SentFrame::Ip(frame.cloned())),
                Self::Received(super::ReceivedFrame::Ip(frame)) => {
                    Frame::Received(super::ReceivedFrame::Ip(frame.cloned()))
                }
            }
        }
    }

    impl EthernetFrame<&[u8]> {
        /// Copies the borrowed body into an owned `Vec`.
        fn cloned(self) -> EthernetFrame<Vec<u8>> {
            let Self { src_mac, dst_mac, ethertype, body_offset, body } = self;
            EthernetFrame { src_mac, dst_mac, ethertype, body_offset, body: Vec::from(body) }
        }
    }

    impl IpFrame<&[u8]> {
        /// Copies the borrowed body into an owned `Vec`.
        fn cloned(self) -> IpFrame<Vec<u8>> {
            let Self { ip_version, body } = self;
            IpFrame { ip_version, body: Vec::from(body) }
        }
    }

    impl<TimerId, Event: Debug, State, D: StrongDeviceIdentifier> DeviceSocketBindingsContext<D>
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        /// Appends an owned copy of the frame to the socket's `RxQueue`.
        /// Returns `QueueFull` if the queue already holds `max_size` frames.
        fn receive_frame(
            &self,
            state: &DeviceSocketId<D::Weak, Self>,
            device: &D,
            frame: Frame<&[u8]>,
            raw_frame: &[u8],
        ) -> Result<(), ReceiveFrameError> {
            let ExternalSocketState(queue) = state.socket_state();
            let mut lock_guard = queue.lock();
            let RxQueue { frames, max_size } = lock_guard.deref_mut();
            if frames.len() < *max_size {
                frames.push(ReceivedFrame {
                    device: device.downgrade(),
                    frame: frame.cloned(),
                    raw: raw_frame.into(),
                });
                Ok(())
            } else {
                Err(ReceiveFrameError::QueueFull)
            }
        }
    }

    // All device-specific state is unit; IDs are `MonotonicIdentifier`s.
    impl<
        TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
        Event: Debug + 'static,
        State: 'static,
    > DeviceLayerStateTypes for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        type EthernetDeviceState = ();
        type LoopbackDeviceState = ();
        type PureIpDeviceState = ();
        type BlackholeDeviceState = ();
        type DeviceIdentifier = MonotonicIdentifier;
    }

    // The matcher impls below presumably exist only to satisfy trait bounds;
    // they panic (via `unimplemented!`) if ever called.
    impl DeviceClassMatcher<()> for () {
        fn device_class_matches(&self, (): &()) -> bool {
            unimplemented!()
        }
    }

    impl DeviceIdAndNameMatcher for MonotonicIdentifier {
        fn id_matches(&self, _id: &NonZeroU64) -> bool {
            unimplemented!()
        }

        fn name_matches(&self, _name: &str) -> bool {
            unimplemented!()
        }
    }
}
1093
1094#[cfg(test)]
1095mod tests {
1096 use alloc::vec;
1097 use alloc::vec::Vec;
1098 use core::marker::PhantomData;
1099 use core::ops::Deref;
1100
1101 use crate::internal::socket::testutil::{ExternalSocketState, ReceivedFrame};
1102 use netstack3_base::testutil::{
1103 FakeReferencyDeviceId, FakeStrongDeviceId, FakeWeakDeviceId, MultipleDevicesId,
1104 };
1105 use netstack3_base::{CounterContext, CtxPair, SendFrameError, SendableFrameMeta};
1106 use netstack3_hashmap::HashMap;
1107 use packet::ParsablePacket;
1108 use test_case::test_case;
1109
1110 use super::*;
1111
1112 type FakeCoreCtx<D> = netstack3_base::testutil::FakeCoreCtx<FakeSockets<D>, (), D>;
1113 type FakeBindingsCtx = netstack3_base::testutil::FakeBindingsCtx<(), (), (), ()>;
1114 type FakeCtx<D> = CtxPair<FakeCoreCtx<D>, FakeBindingsCtx>;
1115
1116 trait DeviceSocketApiExt: ContextPair + Sized {
1119 fn device_socket_api(&mut self) -> DeviceSocketApi<&mut Self> {
1120 DeviceSocketApi::new(self)
1121 }
1122 }
1123
1124 impl<O> DeviceSocketApiExt for O where O: ContextPair + Sized {}
1125
    /// Fake socket tables that back `FakeCoreCtx` in these tests.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    struct FakeSockets<D: FakeStrongDeviceId> {
        /// Sockets targeting any device.
        any_device_sockets: AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
        /// Per-device socket tables, one entry per registered device.
        device_sockets: HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
        /// Every socket and its primary reference.
        all_sockets: AllSockets<D::Weak, FakeBindingsCtx>,
        /// NOTE(review): presumably the stack-wide counters used by
        /// `increment_both`; the counter impl is not visible here.
        counters: DeviceSocketCounters,
        /// NOTE(review): presumably frames recorded by the fake send path;
        /// the impl that fills this is not visible here.
        sent_frames: Vec<Vec<u8>>,
    }

    /// Borrowed views of the parts of `FakeSockets`, used as nested fake core
    /// contexts. The fields are, in order: the any-device table, the
    /// all-sockets table, the device collection, a device type marker, and
    /// the counters.
    pub struct FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>(
        &'m mut AnyDevice,
        &'m mut AllSockets,
        &'m mut Devices,
        PhantomData<Device>,
        &'m DeviceSocketCounters,
    );

    /// Types that can lend out a `FakeSocketsMutRefs` view of fake socket
    /// state.
    pub trait AsFakeSocketsMutRefs {
        type AnyDevice: 'static;
        type AllSockets: 'static;
        type Devices: 'static;
        type Device: 'static;
        fn as_sockets_ref(
            &mut self,
        ) -> FakeSocketsMutRefs<'_, Self::AnyDevice, Self::AllSockets, Self::Devices, Self::Device>;
    }
1157
1158 impl<D: FakeStrongDeviceId> AsFakeSocketsMutRefs for FakeCoreCtx<D> {
1159 type AnyDevice = AnyDeviceSockets<D::Weak, FakeBindingsCtx>;
1160 type AllSockets = AllSockets<D::Weak, FakeBindingsCtx>;
1161 type Devices = HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>;
1162 type Device = D;
1163
1164 fn as_sockets_ref(
1165 &mut self,
1166 ) -> FakeSocketsMutRefs<
1167 '_,
1168 AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
1169 AllSockets<D::Weak, FakeBindingsCtx>,
1170 HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
1171 D,
1172 > {
1173 let FakeSockets {
1174 any_device_sockets,
1175 device_sockets,
1176 all_sockets,
1177 counters,
1178 sent_frames: _,
1179 } = &mut self.state;
1180 FakeSocketsMutRefs(
1181 any_device_sockets,
1182 all_sockets,
1183 device_sockets,
1184 PhantomData,
1185 counters,
1186 )
1187 }
1188 }
1189
1190 impl<'m, AnyDevice: 'static, AllSockets: 'static, Devices: 'static, Device: 'static>
1191 AsFakeSocketsMutRefs for FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>
1192 {
1193 type AnyDevice = AnyDevice;
1194 type AllSockets = AllSockets;
1195 type Devices = Devices;
1196 type Device = Device;
1197
1198 fn as_sockets_ref(
1199 &mut self,
1200 ) -> FakeSocketsMutRefs<'_, AnyDevice, AllSockets, Devices, Device> {
1201 let Self(any_device, all_sockets, devices, PhantomData, counters) = self;
1202 FakeSocketsMutRefs(any_device, all_sockets, devices, PhantomData, counters)
1203 }
1204 }
1205
1206 impl<D: Clone> TargetDevice<&D> {
1207 fn with_weak_id(&self) -> TargetDevice<FakeWeakDeviceId<D>> {
1208 match self {
1209 TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1210 TargetDevice::SpecificDevice(d) => {
1211 TargetDevice::SpecificDevice(FakeWeakDeviceId((*d).clone()))
1212 }
1213 }
1214 }
1215 }
1216
1217 impl<D: Eq + Hash + FakeStrongDeviceId> FakeSockets<D> {
1218 fn new(devices: impl IntoIterator<Item = D>) -> Self {
1219 let device_sockets =
1220 devices.into_iter().map(|d| (d, DeviceSockets::default())).collect();
1221 Self {
1222 any_device_sockets: AnyDeviceSockets::default(),
1223 device_sockets,
1224 all_sockets: Default::default(),
1225 counters: Default::default(),
1226 sent_frames: Default::default(),
1227 }
1228 }
1229 }
1230
1231 impl<
1232 'm,
1233 DeviceId: FakeStrongDeviceId,
1234 As: AsFakeSocketsMutRefs
1235 + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1236 > SocketStateAccessor<FakeBindingsCtx> for As
1237 {
1238 fn with_socket_state<F: FnOnce(&Target<Self::WeakDeviceId>) -> R, R>(
1239 &mut self,
1240 socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1241 cb: F,
1242 ) -> R {
1243 let DeviceSocketId(rc) = socket;
1244 let target = rc.target.lock();
1246 cb(&target)
1247 }
1248
1249 fn with_socket_state_mut<F: FnOnce(&mut Target<Self::WeakDeviceId>) -> R, R>(
1250 &mut self,
1251 socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1252 cb: F,
1253 ) -> R {
1254 let DeviceSocketId(rc) = socket;
1255 let mut target = rc.target.lock();
1257 cb(&mut target)
1258 }
1259 }
1260
1261 impl<
1262 'm,
1263 DeviceId: FakeStrongDeviceId,
1264 As: AsFakeSocketsMutRefs<
1265 Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1266 > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1267 > DeviceSocketAccessor<FakeBindingsCtx> for As
1268 {
1269 type DeviceSocketCoreCtx<'a> =
1270 FakeSocketsMutRefs<'a, As::AnyDevice, As::AllSockets, HashSet<DeviceId>, DeviceId>;
1271 fn with_device_sockets<
1272 F: FnOnce(
1273 &DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1274 &mut Self::DeviceSocketCoreCtx<'_>,
1275 ) -> R,
1276 R,
1277 >(
1278 &mut self,
1279 device: &Self::DeviceId,
1280 cb: F,
1281 ) -> R {
1282 let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
1283 self.as_sockets_ref();
1284 let mut devices = device_sockets.keys().cloned().collect();
1285 let device = device_sockets.get(device).unwrap();
1286 cb(
1287 device,
1288 &mut FakeSocketsMutRefs(
1289 any_device,
1290 all_sockets,
1291 &mut devices,
1292 PhantomData,
1293 counters,
1294 ),
1295 )
1296 }
1297 fn with_device_sockets_mut<
1298 F: FnOnce(
1299 &mut DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1300 &mut Self::DeviceSocketCoreCtx<'_>,
1301 ) -> R,
1302 R,
1303 >(
1304 &mut self,
1305 device: &Self::DeviceId,
1306 cb: F,
1307 ) -> R {
1308 let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
1309 self.as_sockets_ref();
1310 let mut devices = device_sockets.keys().cloned().collect();
1311 let device = device_sockets.get_mut(device).unwrap();
1312 cb(
1313 device,
1314 &mut FakeSocketsMutRefs(
1315 any_device,
1316 all_sockets,
1317 &mut devices,
1318 PhantomData,
1319 counters,
1320 ),
1321 )
1322 }
1323 }
1324
1325 impl<
1326 'm,
1327 DeviceId: FakeStrongDeviceId,
1328 As: AsFakeSocketsMutRefs<
1329 AnyDevice = AnyDeviceSockets<DeviceId::Weak, FakeBindingsCtx>,
1330 AllSockets = AllSockets<DeviceId::Weak, FakeBindingsCtx>,
1331 Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1332 > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1333 > DeviceSocketContext<FakeBindingsCtx> for As
1334 {
1335 type SocketTablesCoreCtx<'a> = FakeSocketsMutRefs<
1336 'a,
1337 (),
1338 (),
1339 HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1340 DeviceId,
1341 >;
1342
1343 fn with_any_device_sockets<
1344 F: FnOnce(
1345 &AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1346 &mut Self::SocketTablesCoreCtx<'_>,
1347 ) -> R,
1348 R,
1349 >(
1350 &mut self,
1351 cb: F,
1352 ) -> R {
1353 let FakeSocketsMutRefs(
1354 any_device_sockets,
1355 _all_sockets,
1356 device_sockets,
1357 PhantomData,
1358 counters,
1359 ) = self.as_sockets_ref();
1360 cb(
1361 any_device_sockets,
1362 &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1363 )
1364 }
1365 fn with_any_device_sockets_mut<
1366 F: FnOnce(
1367 &mut AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1368 &mut Self::SocketTablesCoreCtx<'_>,
1369 ) -> R,
1370 R,
1371 >(
1372 &mut self,
1373 cb: F,
1374 ) -> R {
1375 let FakeSocketsMutRefs(
1376 any_device_sockets,
1377 _all_sockets,
1378 device_sockets,
1379 PhantomData,
1380 counters,
1381 ) = self.as_sockets_ref();
1382 cb(
1383 any_device_sockets,
1384 &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1385 )
1386 }
1387
1388 fn with_all_device_sockets<
1389 F: FnOnce(
1390 &AllSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1391 &mut Self::SocketTablesCoreCtx<'_>,
1392 ) -> R,
1393 R,
1394 >(
1395 &mut self,
1396 cb: F,
1397 ) -> R {
1398 let FakeSocketsMutRefs(
1399 _any_device_sockets,
1400 all_sockets,
1401 device_sockets,
1402 PhantomData,
1403 counters,
1404 ) = self.as_sockets_ref();
1405 cb(
1406 all_sockets,
1407 &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1408 )
1409 }
1410
1411 fn with_all_device_sockets_mut<
1412 F: FnOnce(&mut AllSockets<Self::WeakDeviceId, FakeBindingsCtx>) -> R,
1413 R,
1414 >(
1415 &mut self,
1416 cb: F,
1417 ) -> R {
1418 let FakeSocketsMutRefs(_, all_sockets, _, _, _) = self.as_sockets_ref();
1419 cb(all_sockets)
1420 }
1421 }
1422
1423 impl<'m, X, Y, Z, D: FakeStrongDeviceId> DeviceIdContext<AnyDevice>
1424 for FakeSocketsMutRefs<'m, X, Y, Z, D>
1425 {
1426 type DeviceId = D;
1427 type WeakDeviceId = FakeWeakDeviceId<D>;
1428 }
1429
1430 impl<D: FakeStrongDeviceId> CounterContext<DeviceSocketCounters> for FakeCoreCtx<D> {
1431 fn counters(&self) -> &DeviceSocketCounters {
1432 &self.state.counters
1433 }
1434 }
1435
1436 impl<D: FakeStrongDeviceId>
1437 ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
1438 for FakeCoreCtx<D>
1439 {
1440 fn per_resource_counters<'a>(
1441 &'a self,
1442 socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
1443 ) -> &'a DeviceSocketCounters {
1444 socket.counters()
1445 }
1446 }
1447
1448 impl<'m, X, Y, Z, D> CounterContext<DeviceSocketCounters> for FakeSocketsMutRefs<'m, X, Y, Z, D> {
1449 fn counters(&self) -> &DeviceSocketCounters {
1450 let FakeSocketsMutRefs(_, _, _, _, counters) = self;
1451 counters
1452 }
1453 }
1454
1455 impl<'m, X, Y, Z, D: FakeStrongDeviceId>
1456 ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
1457 for FakeSocketsMutRefs<'m, X, Y, Z, D>
1458 {
1459 fn per_resource_counters<'a>(
1460 &'a self,
1461 socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
1462 ) -> &'a DeviceSocketCounters {
1463 socket.counters()
1464 }
1465 }
1466
1467 const SOME_PROTOCOL: NonZeroU16 = NonZeroU16::new(2000).unwrap();
1468
1469 #[test]
1470 fn create_remove() {
1471 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1472 MultipleDevicesId::all(),
1473 )));
1474 let mut api = ctx.device_socket_api();
1475
1476 let bound = api.create(Default::default());
1477 assert_eq!(
1478 api.get_info(&bound),
1479 SocketInfo { device: TargetDevice::AnyDevice, protocol: None }
1480 );
1481
1482 let ExternalSocketState(_received_frames) = api.remove(bound).into_removed();
1483 }
1484
1485 #[test_case(TargetDevice::AnyDevice)]
1486 #[test_case(TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1487 fn test_set_device(device: TargetDevice<&MultipleDevicesId>) {
1488 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1489 MultipleDevicesId::all(),
1490 )));
1491 let mut api = ctx.device_socket_api();
1492
1493 let bound = api.create(Default::default());
1494 api.set_device(&bound, device.clone());
1495 assert_eq!(
1496 api.get_info(&bound),
1497 SocketInfo { device: device.with_weak_id(), protocol: None }
1498 );
1499
1500 let device_sockets = &api.core_ctx().state.device_sockets;
1501 if let TargetDevice::SpecificDevice(d) = device {
1502 let DeviceSockets(socket_ids) = device_sockets.get(&d).expect("device state exists");
1503 assert_eq!(socket_ids, &HashSet::from([bound]));
1504 }
1505 }
1506
1507 #[test]
1508 fn update_device() {
1509 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1510 MultipleDevicesId::all(),
1511 )));
1512 let mut api = ctx.device_socket_api();
1513 let bound = api.create(Default::default());
1514
1515 api.set_device(&bound, TargetDevice::SpecificDevice(&MultipleDevicesId::A));
1516
1517 api.set_device(&bound, TargetDevice::SpecificDevice(&MultipleDevicesId::B));
1520 assert_eq!(
1521 api.get_info(&bound),
1522 SocketInfo {
1523 device: TargetDevice::SpecificDevice(FakeWeakDeviceId(MultipleDevicesId::B)),
1524 protocol: None
1525 }
1526 );
1527
1528 let device_sockets = &api.core_ctx().state.device_sockets;
1529 let device_socket_lists = device_sockets
1530 .iter()
1531 .map(|(d, DeviceSockets(indexes))| (d, indexes.iter().collect()))
1532 .collect::<HashMap<_, _>>();
1533
1534 assert_eq!(
1535 device_socket_lists,
1536 HashMap::from([
1537 (&MultipleDevicesId::A, vec![]),
1538 (&MultipleDevicesId::B, vec![&bound]),
1539 (&MultipleDevicesId::C, vec![])
1540 ])
1541 );
1542 }
1543
1544 #[test_case(Protocol::All, TargetDevice::AnyDevice)]
1545 #[test_case(Protocol::Specific(SOME_PROTOCOL), TargetDevice::AnyDevice)]
1546 #[test_case(Protocol::All, TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1547 #[test_case(
1548 Protocol::Specific(SOME_PROTOCOL),
1549 TargetDevice::SpecificDevice(&MultipleDevicesId::A)
1550 )]
1551 fn create_set_device_and_protocol_remove_multiple(
1552 protocol: Protocol,
1553 device: TargetDevice<&MultipleDevicesId>,
1554 ) {
1555 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1556 MultipleDevicesId::all(),
1557 )));
1558 let mut api = ctx.device_socket_api();
1559
1560 let mut sockets = [(); 3].map(|()| api.create(Default::default()));
1561 for socket in &mut sockets {
1562 api.set_device_and_protocol(socket, device.clone(), protocol);
1563 assert_eq!(
1564 api.get_info(socket),
1565 SocketInfo { device: device.with_weak_id(), protocol: Some(protocol) }
1566 );
1567 }
1568
1569 for socket in sockets {
1570 let ExternalSocketState(_received_frames) = api.remove(socket).into_removed();
1571 }
1572 }
1573
1574 #[test]
1575 fn change_device_after_removal() {
1576 let device_to_remove = FakeReferencyDeviceId::default();
1577 let device_to_maintain = FakeReferencyDeviceId::default();
1578 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new([
1579 device_to_remove.clone(),
1580 device_to_maintain.clone(),
1581 ])));
1582 let mut api = ctx.device_socket_api();
1583
1584 let bound = api.create(Default::default());
1585 api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_remove));
1588
1589 device_to_remove.mark_removed();
1592
1593 api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_maintain));
1596 assert_eq!(
1597 api.get_info(&bound),
1598 SocketInfo {
1599 device: TargetDevice::SpecificDevice(FakeWeakDeviceId(device_to_maintain.clone())),
1600 protocol: None,
1601 }
1602 );
1603
1604 let device_sockets = &api.core_ctx().state.device_sockets;
1605 let DeviceSockets(weak_sockets) =
1606 device_sockets.get(&device_to_maintain).expect("device state exists");
1607 assert_eq!(weak_sockets, &HashSet::from([bound]));
1608 }
1609
1610 struct TestData;
1611 impl TestData {
1612 const SRC_MAC: Mac = Mac::new([0, 1, 2, 3, 4, 5]);
1613 const DST_MAC: Mac = Mac::new([6, 7, 8, 9, 10, 11]);
1614 const PROTO: NonZeroU16 = NonZeroU16::new(0x08AB).unwrap();
1616 const BODY: &'static [u8] = b"some pig";
1617 const BUFFER: &'static [u8] = &[
1618 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 0x08, 0xAB, b's', b'o', b'm', b'e', b' ', b'p',
1619 b'i', b'g',
1620 ];
1621 const BUFFER_OFFSET: usize = Self::BUFFER.len() - Self::BODY.len();
1622
1623 fn frame() -> packet_formats::ethernet::EthernetFrame<&'static [u8]> {
1625 let mut buffer_view = Self::BUFFER;
1626 packet_formats::ethernet::EthernetFrame::parse(
1627 &mut buffer_view,
1628 EthernetFrameLengthCheck::NoCheck,
1629 )
1630 .unwrap()
1631 }
1632 }
1633
1634 const WRONG_PROTO: NonZeroU16 = NonZeroU16::new(0x08ff).unwrap();
1635
1636 fn make_bound<D: FakeStrongDeviceId>(
1637 ctx: &mut FakeCtx<D>,
1638 device: TargetDevice<D>,
1639 protocol: Option<Protocol>,
1640 state: ExternalSocketState<D::Weak>,
1641 ) -> DeviceSocketId<D::Weak, FakeBindingsCtx> {
1642 let mut api = ctx.device_socket_api();
1643 let id = api.create(state);
1644 let device = match &device {
1645 TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1646 TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d),
1647 };
1648 match protocol {
1649 Some(protocol) => api.set_device_and_protocol(&id, device, protocol),
1650 None => api.set_device(&id, device),
1651 };
1652 id
1653 }
1654
1655 fn deliver_one_frame(
1658 delivered_frame: Frame<&[u8]>,
1659 FakeCtx { core_ctx, bindings_ctx }: &mut FakeCtx<MultipleDevicesId>,
1660 ) -> HashSet<DeviceSocketId<FakeWeakDeviceId<MultipleDevicesId>, FakeBindingsCtx>> {
1661 DeviceSocketHandler::handle_frame(
1662 core_ctx,
1663 bindings_ctx,
1664 &MultipleDevicesId::A,
1665 delivered_frame.clone(),
1666 TestData::BUFFER,
1667 );
1668
1669 let FakeSockets {
1670 all_sockets: AllSockets(all_sockets),
1671 any_device_sockets: _,
1672 device_sockets: _,
1673 counters: _,
1674 sent_frames: _,
1675 } = &core_ctx.state;
1676
1677 all_sockets
1678 .iter()
1679 .filter_map(|(id, _primary)| {
1680 let DeviceSocketId(rc) = &id;
1681 let ExternalSocketState(frames) = &rc.external_state;
1682 let lock_guard = frames.lock();
1683 let testutil::RxQueue { frames, .. } = lock_guard.deref();
1684 (!frames.is_empty()).then(|| {
1685 assert_eq!(
1686 &*frames,
1687 &[ReceivedFrame {
1688 device: FakeWeakDeviceId(MultipleDevicesId::A),
1689 frame: delivered_frame.cloned(),
1690 raw: TestData::BUFFER.into(),
1691 }]
1692 );
1693 id.clone()
1694 })
1695 })
1696 .collect()
1697 }
1698
1699 #[test]
1700 fn receive_frame_deliver_to_multiple() {
1701 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1702 MultipleDevicesId::all(),
1703 )));
1704
1705 use Protocol::*;
1706 use TargetDevice::*;
1707 let never_bound = {
1708 let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1709 ctx.device_socket_api().create(state)
1710 };
1711
1712 let mut make_bound = |device, protocol| {
1713 let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1714 make_bound(&mut ctx, device, protocol, state)
1715 };
1716 let bound_a_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::A), None);
1717 let bound_a_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::A), Some(All));
1718 let bound_a_right_protocol =
1719 make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
1720 let bound_a_wrong_protocol =
1721 make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
1722 let bound_b_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::B), None);
1723 let bound_b_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::B), Some(All));
1724 let bound_b_right_protocol =
1725 make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
1726 let bound_b_wrong_protocol =
1727 make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
1728 let bound_any_no_protocol = make_bound(AnyDevice, None);
1729 let bound_any_all_protocols = make_bound(AnyDevice, Some(All));
1730 let bound_any_right_protocol = make_bound(AnyDevice, Some(Specific(TestData::PROTO)));
1731 let bound_any_wrong_protocol = make_bound(AnyDevice, Some(Specific(WRONG_PROTO)));
1732
1733 let mut sockets_with_received_frames = deliver_one_frame(
1734 super::ReceivedFrame::from_ethernet(
1735 TestData::frame(),
1736 FrameDestination::Individual { local: true },
1737 )
1738 .into(),
1739 &mut ctx,
1740 );
1741
1742 let sockets_not_expecting_frames = [
1743 never_bound,
1744 bound_a_no_protocol,
1745 bound_a_wrong_protocol,
1746 bound_b_no_protocol,
1747 bound_b_all_protocols,
1748 bound_b_right_protocol,
1749 bound_b_wrong_protocol,
1750 bound_any_no_protocol,
1751 bound_any_wrong_protocol,
1752 ];
1753 let sockets_expecting_frames = [
1754 bound_a_all_protocols,
1755 bound_a_right_protocol,
1756 bound_any_all_protocols,
1757 bound_any_right_protocol,
1758 ];
1759
1760 for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1761 assert!(
1762 sockets_with_received_frames.remove(&socket),
1763 "socket {n} didn't receive the frame"
1764 );
1765 }
1766 assert!(sockets_with_received_frames.is_empty());
1767
1768 for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1770 assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
1771 }
1772 for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
1773 assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
1774 }
1775 }
1776
1777 #[test]
1778 fn sent_frame_deliver_to_multiple() {
1779 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1780 MultipleDevicesId::all(),
1781 )));
1782
1783 use Protocol::*;
1784 use TargetDevice::*;
1785 let never_bound = {
1786 let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1787 ctx.device_socket_api().create(state)
1788 };
1789
1790 let mut make_bound = |device, protocol| {
1791 let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1792 make_bound(&mut ctx, device, protocol, state)
1793 };
1794 let bound_a_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::A), None);
1795 let bound_a_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::A), Some(All));
1796 let bound_a_same_protocol =
1797 make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
1798 let bound_a_wrong_protocol =
1799 make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
1800 let bound_b_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::B), None);
1801 let bound_b_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::B), Some(All));
1802 let bound_b_same_protocol =
1803 make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
1804 let bound_b_wrong_protocol =
1805 make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
1806 let bound_any_no_protocol = make_bound(AnyDevice, None);
1807 let bound_any_all_protocols = make_bound(AnyDevice, Some(All));
1808 let bound_any_same_protocol = make_bound(AnyDevice, Some(Specific(TestData::PROTO)));
1809 let bound_any_wrong_protocol = make_bound(AnyDevice, Some(Specific(WRONG_PROTO)));
1810
1811 let mut sockets_with_received_frames =
1812 deliver_one_frame(SentFrame::Ethernet(TestData::frame().into()).into(), &mut ctx);
1813
1814 let sockets_not_expecting_frames = [
1815 never_bound,
1816 bound_a_no_protocol,
1817 bound_a_same_protocol,
1818 bound_a_wrong_protocol,
1819 bound_b_no_protocol,
1820 bound_b_all_protocols,
1821 bound_b_same_protocol,
1822 bound_b_wrong_protocol,
1823 bound_any_no_protocol,
1824 bound_any_same_protocol,
1825 bound_any_wrong_protocol,
1826 ];
1827 let sockets_expecting_frames = [bound_a_all_protocols, bound_any_all_protocols];
1829
1830 for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1831 assert!(
1832 sockets_with_received_frames.remove(&socket),
1833 "socket {n} didn't receive the frame"
1834 );
1835 }
1836 assert!(sockets_with_received_frames.is_empty());
1837
1838 for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1840 assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
1841 }
1842 for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
1843 assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
1844 }
1845 }
1846
1847 #[test]
1848 fn deliver_multiple_frames() {
1849 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1850 MultipleDevicesId::all(),
1851 )));
1852 let socket = make_bound(
1853 &mut ctx,
1854 TargetDevice::AnyDevice,
1855 Some(Protocol::All),
1856 ExternalSocketState::default(),
1857 );
1858 let FakeCtx { mut core_ctx, mut bindings_ctx } = ctx;
1859
1860 const RECEIVE_COUNT: usize = 10;
1861 for _ in 0..RECEIVE_COUNT {
1862 DeviceSocketHandler::handle_frame(
1863 &mut core_ctx,
1864 &mut bindings_ctx,
1865 &MultipleDevicesId::A,
1866 super::ReceivedFrame::from_ethernet(
1867 TestData::frame(),
1868 FrameDestination::Individual { local: true },
1869 )
1870 .into(),
1871 TestData::BUFFER,
1872 );
1873 }
1874
1875 let FakeSockets {
1876 all_sockets: AllSockets(mut all_sockets),
1877 any_device_sockets: _,
1878 device_sockets: _,
1879 counters: _,
1880 sent_frames: _,
1881 } = core_ctx.into_state();
1882 let primary = all_sockets.remove(&socket).unwrap();
1883 let PrimaryDeviceSocketId(primary) = primary;
1884 assert!(all_sockets.is_empty());
1885 drop(socket);
1886 let SocketState { external_state: ExternalSocketState(received), counters, target: _ } =
1887 PrimaryRc::unwrap(primary);
1888 assert_eq!(
1889 received.into_inner().frames,
1890 vec![
1891 ReceivedFrame {
1892 device: FakeWeakDeviceId(MultipleDevicesId::A),
1893 frame: Frame::Received(super::ReceivedFrame::Ethernet {
1894 destination: FrameDestination::Individual { local: true },
1895 frame: EthernetFrame {
1896 src_mac: TestData::SRC_MAC,
1897 dst_mac: TestData::DST_MAC,
1898 ethertype: Some(TestData::PROTO.get().into()),
1899 body_offset: TestData::BUFFER_OFFSET,
1900 body: Vec::from(TestData::BODY),
1901 }
1902 }),
1903 raw: TestData::BUFFER.into()
1904 };
1905 RECEIVE_COUNT
1906 ]
1907 );
1908 assert_eq!(counters.rx_frames.get(), u64::try_from(RECEIVE_COUNT).unwrap());
1909 }
1910
1911 #[test]
1912 fn deliver_frame_queue_full() {
1913 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1914 MultipleDevicesId::all(),
1915 )));
1916
1917 let sock1 = make_bound(
1919 &mut ctx,
1920 TargetDevice::AnyDevice,
1921 Some(Protocol::All),
1922 ExternalSocketState(Mutex::new(testutil::RxQueue { frames: vec![], max_size: 0 })),
1923 );
1924 let sock2 = make_bound(
1925 &mut ctx,
1926 TargetDevice::AnyDevice,
1927 Some(Protocol::All),
1928 ExternalSocketState::default(),
1929 );
1930
1931 let FakeCtx { mut core_ctx, mut bindings_ctx } = ctx;
1932
1933 DeviceSocketHandler::handle_frame(
1934 &mut core_ctx,
1935 &mut bindings_ctx,
1936 &MultipleDevicesId::A,
1937 super::ReceivedFrame::from_ethernet(
1938 TestData::frame(),
1939 FrameDestination::Individual { local: true },
1940 )
1941 .into(),
1942 TestData::BUFFER,
1943 );
1944
1945 assert_eq!(core_ctx.state.counters.rx_frames.get(), 2);
1946 assert_eq!(core_ctx.state.counters.rx_queue_full.get(), 1);
1947 assert_eq!(sock1.counters().rx_frames.get(), 1);
1948 assert_eq!(sock1.counters().rx_queue_full.get(), 1);
1949 assert_eq!(sock2.counters().rx_frames.get(), 1);
1950 assert_eq!(sock2.counters().rx_queue_full.get(), 0);
1951
1952 drop(sock1);
1955 drop(sock2);
1956 }
1957
1958 pub struct FakeSendMetadata;
1959 impl DeviceSocketSendTypes for AnyDevice {
1960 type Metadata = FakeSendMetadata;
1961 }
1962 impl<BC, D: FakeStrongDeviceId> SendableFrameMeta<FakeCoreCtx<D>, BC>
1963 for DeviceSocketMetadata<AnyDevice, D>
1964 {
1965 fn send_meta<S>(
1966 self,
1967 core_ctx: &mut FakeCoreCtx<D>,
1968 _bindings_ctx: &mut BC,
1969 frame: S,
1970 ) -> Result<(), SendFrameError<S>>
1971 where
1972 S: packet::Serializer,
1973 S::Buffer: packet::BufferMut,
1974 {
1975 let frame = match frame.serialize_vec_outer() {
1976 Err(e) => {
1977 let _: (packet::SerializeError<core::convert::Infallible>, _) = e;
1978 unreachable!()
1979 }
1980 Ok(frame) => frame.unwrap_a().as_ref().to_vec(),
1981 };
1982 core_ctx.state.sent_frames.push(frame);
1983 Ok(())
1984 }
1985 }
1986
1987 #[test]
1988 fn send_multiple_frames() {
1989 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1990 MultipleDevicesId::all(),
1991 )));
1992
1993 const DEVICE: MultipleDevicesId = MultipleDevicesId::A;
1994 let socket = make_bound(
1995 &mut ctx,
1996 TargetDevice::SpecificDevice(DEVICE),
1997 Some(Protocol::All),
1998 ExternalSocketState::default(),
1999 );
2000 let mut api = ctx.device_socket_api();
2001
2002 const SEND_COUNT: usize = 10;
2003 const PAYLOAD: &'static [u8] = &[1, 2, 3, 4, 5];
2004 for _ in 0..SEND_COUNT {
2005 let buf = packet::Buf::new(PAYLOAD.to_vec(), ..);
2006 api.send_frame(
2007 &socket,
2008 DeviceSocketMetadata { device_id: DEVICE, metadata: FakeSendMetadata },
2009 buf,
2010 )
2011 .expect("send failed");
2012 }
2013
2014 assert_eq!(ctx.core_ctx().state.sent_frames, vec![PAYLOAD.to_vec(); SEND_COUNT]);
2015
2016 assert_eq!(socket.counters().tx_frames.get(), u64::try_from(SEND_COUNT).unwrap());
2017 }
2018}