netstack3_device/
socket.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Link-layer sockets (analogous to Linux's AF_PACKET sockets).
6
7use core::fmt::Debug;
8use core::hash::Hash;
9use core::num::NonZeroU16;
10
11use derivative::Derivative;
12use lock_order::lock::{OrderedLockAccess, OrderedLockRef};
13use net_types::ethernet::Mac;
14use net_types::ip::IpVersion;
15use netstack3_base::socket::SocketCookie;
16use netstack3_base::sync::{Mutex, PrimaryRc, RwLock, StrongRc, WeakRc};
17use netstack3_base::{
18    AnyDevice, ContextPair, Counter, Device, DeviceIdContext, FrameDestination, Inspectable,
19    Inspector, InspectorDeviceExt, InspectorExt, ReferenceNotifiers, ReferenceNotifiersExt as _,
20    RemoveResourceResultWithContext, ResourceCounterContext, SendFrameContext,
21    SendFrameErrorReason, StrongDeviceIdentifier, WeakDeviceIdentifier as _,
22};
23use netstack3_hashmap::{HashMap, HashSet};
24use packet::{BufferMut, ParsablePacket as _, Serializer};
25use packet_formats::error::ParseError;
26use packet_formats::ethernet::{EtherType, EthernetFrameLengthCheck};
27
28use crate::internal::base::DeviceLayerTypes;
29use crate::internal::id::WeakDeviceId;
30
/// Frame selector keyed on the link-layer protocol number.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum Protocol {
    /// Matches every frame, whatever its protocol number.
    All,
    /// Matches only frames carrying exactly this protocol number.
    Specific(NonZeroU16),
}
39
40/// Selector for devices to send and receive packets on.
41#[derive(Clone, Debug, Derivative, Eq, Hash, PartialEq)]
42#[derivative(Default(bound = ""))]
43pub enum TargetDevice<D> {
44    /// Act on any device in the system.
45    #[derivative(Default)]
46    AnyDevice,
47    /// Act on a specific device.
48    SpecificDevice(D),
49}
50
51/// Information about the bound state of a socket.
52#[derive(Debug)]
53#[cfg_attr(test, derive(PartialEq))]
54pub struct SocketInfo<D> {
55    /// The protocol the socket is bound to, or `None` if no protocol is set.
56    pub protocol: Option<Protocol>,
57    /// The device selector for which the socket is set.
58    pub device: TargetDevice<D>,
59}
60
/// Associated types that the bindings context supplies for device sockets.
pub trait DeviceSocketTypes {
    /// Bindings-defined per-socket state. Core stores it alongside the
    /// socket and hands it back to bindings on request.
    type SocketState<D: Send + Sync + Debug>: Send + Sync + Debug;
}
67
/// Errors that bindings may report when handed a frame for a device socket.
///
/// Deriving `Debug`/`PartialEq` lets callers log, `expect`, or compare the
/// error; the type carries no data, so `Copy` is free.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReceiveFrameError {
    /// The socket's receive queue is full and cannot accept the frame.
    QueueFull,
}
73
74/// The execution context for device sockets provided by bindings.
75pub trait DeviceSocketBindingsContext<DeviceId: StrongDeviceIdentifier>:
76    DeviceSocketTypes + Sized
77{
78    /// Called for each received frame that matches the provided socket.
79    ///
80    /// `frame` and `raw_frame` are parsed and raw views into the same data.
81    fn receive_frame(
82        &self,
83        socket_id: &DeviceSocketId<DeviceId::Weak, Self>,
84        device: &DeviceId,
85        frame: Frame<&[u8]>,
86        raw_frame: &[u8],
87    ) -> Result<(), ReceiveFrameError>;
88}
89
90/// Strong owner of socket state.
91///
92/// This type strongly owns the socket state.
93#[derive(Debug)]
94pub struct PrimaryDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
95    PrimaryRc<SocketState<D, BT>>,
96);
97
98impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> PrimaryDeviceSocketId<D, BT> {
99    /// Creates a new socket ID with `external_state`.
100    fn new(external_state: BT::SocketState<D>) -> Self {
101        Self(PrimaryRc::new(SocketState {
102            external_state,
103            counters: Default::default(),
104            target: Default::default(),
105        }))
106    }
107
108    /// Clones the primary's underlying reference and returns as a strong id.
109    fn clone_strong(&self) -> DeviceSocketId<D, BT> {
110        let PrimaryDeviceSocketId(rc) = self;
111        DeviceSocketId(PrimaryRc::clone_strong(rc))
112    }
113}
114
115/// Reference to live socket state.
116///
117/// The existence of a `StrongId` attests to the liveness of the state of the
118/// backing socket.
119#[derive(Derivative)]
120#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
121pub struct DeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
122    StrongRc<SocketState<D, BT>>,
123);
124
125impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> DeviceSocketId<D, BT> {
126    /// Returns [`SocketCookie`] for this socket.
127    pub fn socket_cookie(&self) -> SocketCookie {
128        let Self(rc) = self;
129        SocketCookie::new(rc.resource_token())
130    }
131}
132
133impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for DeviceSocketId<D, BT> {
134    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
135        let Self(rc) = self;
136        f.debug_tuple("DeviceSocketId").field(&StrongRc::debug_id(rc)).finish()
137    }
138}
139
140impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<Target<D>>
141    for DeviceSocketId<D, BT>
142{
143    type Lock = Mutex<Target<D>>;
144    fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
145        let Self(rc) = self;
146        OrderedLockRef::new(&rc.target)
147    }
148}
149
150/// A weak reference to socket state.
151///
152/// The existence of a [`WeakSocketDeviceId`] does not attest to the liveness of
153/// the backing socket.
154#[derive(Derivative)]
155#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
156pub struct WeakDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
157    WeakRc<SocketState<D, BT>>,
158);
159
160impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for WeakDeviceSocketId<D, BT> {
161    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
162        let Self(rc) = self;
163        f.debug_tuple("WeakDeviceSocketId").field(&WeakRc::debug_id(rc)).finish()
164    }
165}
166
167/// Holds shared state for sockets.
168#[derive(Derivative)]
169#[derivative(Default(bound = ""))]
170pub struct Sockets<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
171    /// Holds strong (but not owning) references to sockets that aren't
172    /// targeting a particular device.
173    any_device_sockets: RwLock<AnyDeviceSockets<D, BT>>,
174
175    /// Table of all sockets in the system, regardless of target.
176    ///
177    /// Holds the primary (owning) reference for all sockets.
178    // This needs to be after `any_device_sockets` so that when an instance of
179    // this type is dropped, any strong IDs get dropped before their
180    // corresponding primary IDs.
181    all_sockets: RwLock<AllSockets<D, BT>>,
182}
183
184/// The set of sockets associated with a device.
185#[derive(Derivative)]
186#[derivative(Default(bound = ""))]
187pub struct AnyDeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
188    HashSet<DeviceSocketId<D, BT>>,
189);
190
191/// A collection of all device sockets in the system.
192#[derive(Derivative)]
193#[derivative(Default(bound = ""))]
194pub struct AllSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
195    HashMap<DeviceSocketId<D, BT>, PrimaryDeviceSocketId<D, BT>>,
196);
197
198/// State held by a device socket.
199#[derive(Debug)]
200pub struct SocketState<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
201    /// State provided by bindings that is held in core.
202    pub external_state: BT::SocketState<D>,
203    /// The socket's target device and protocol.
204    // TODO(https://fxbug.dev/42077026): Consider splitting up the state here to
205    // improve performance.
206    target: Mutex<Target<D>>,
207    /// Statistics about the socket's usage.
208    counters: DeviceSocketCounters,
209}
210
211/// A device socket's binding information.
212#[derive(Debug, Derivative)]
213#[derivative(Default(bound = ""))]
214pub struct Target<D> {
215    protocol: Option<Protocol>,
216    device: TargetDevice<D>,
217}
218
219/// Per-device state for packet sockets.
220///
221/// Holds sockets that are bound to a particular device. An instance of this
222/// should be held in the state for each device in the system.
223#[derive(Derivative)]
224#[derivative(Default(bound = ""))]
225#[cfg_attr(
226    test,
227    derivative(Debug, PartialEq(bound = "BT::SocketState<D>: Hash + Eq, D: Hash + Eq"))
228)]
229pub struct DeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
230    HashSet<DeviceSocketId<D, BT>>,
231);
232
233/// Convenience alias for use in device state storage.
234pub type HeldDeviceSockets<BT> = DeviceSockets<WeakDeviceId<BT>, BT>;
235
236/// Convenience alias for use in shared storage.
237///
238/// The type parameter is expected to implement [`DeviceSocketTypes`].
239pub type HeldSockets<BT> = Sockets<WeakDeviceId<BT>, BT>;
240
241/// Core context for accessing socket state.
242pub trait DeviceSocketContext<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
243    /// The core context available in callbacks to methods on this context.
244    type SocketTablesCoreCtx<'a>: DeviceSocketAccessor<BT, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>;
245
246    /// Executes the provided callback with access to the collection of all
247    /// sockets.
248    fn with_all_device_sockets<
249        F: FnOnce(&AllSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
250        R,
251    >(
252        &mut self,
253        cb: F,
254    ) -> R;
255
256    /// Executes the provided callback with mutable access to the collection of
257    /// all sockets.
258    fn with_all_device_sockets_mut<F: FnOnce(&mut AllSockets<Self::WeakDeviceId, BT>) -> R, R>(
259        &mut self,
260        cb: F,
261    ) -> R;
262
263    /// Executes the provided callback with immutable access to socket state.
264    fn with_any_device_sockets<
265        F: FnOnce(&AnyDeviceSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
266        R,
267    >(
268        &mut self,
269        cb: F,
270    ) -> R;
271
272    /// Executes the provided callback with mutable access to socket state.
273    fn with_any_device_sockets_mut<
274        F: FnOnce(
275            &mut AnyDeviceSockets<Self::WeakDeviceId, BT>,
276            &mut Self::SocketTablesCoreCtx<'_>,
277        ) -> R,
278        R,
279    >(
280        &mut self,
281        cb: F,
282    ) -> R;
283}
284
285/// Core context for accessing the state of an individual socket.
286pub trait SocketStateAccessor<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
287    /// Provides read-only access to the state of a socket.
288    fn with_socket_state<F: FnOnce(&Target<Self::WeakDeviceId>) -> R, R>(
289        &mut self,
290        socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
291        cb: F,
292    ) -> R;
293
294    /// Provides mutable access to the state of a socket.
295    fn with_socket_state_mut<F: FnOnce(&mut Target<Self::WeakDeviceId>) -> R, R>(
296        &mut self,
297        socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
298        cb: F,
299    ) -> R;
300}
301
302/// Core context for accessing the socket state for a device.
303pub trait DeviceSocketAccessor<BT: DeviceSocketTypes>: SocketStateAccessor<BT> {
304    /// Core context available in callbacks to methods on this context.
305    type DeviceSocketCoreCtx<'a>: SocketStateAccessor<BT, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>
306        + ResourceCounterContext<DeviceSocketId<Self::WeakDeviceId, BT>, DeviceSocketCounters>;
307
308    /// Executes the provided callback with immutable access to device-specific
309    /// socket state.
310    fn with_device_sockets<
311        F: FnOnce(&DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
312        R,
313    >(
314        &mut self,
315        device: &Self::DeviceId,
316        cb: F,
317    ) -> R;
318
319    /// Executes the provided callback with mutable access to device-specific
320    /// socket state.
321    fn with_device_sockets_mut<
322        F: FnOnce(&mut DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
323        R,
324    >(
325        &mut self,
326        device: &Self::DeviceId,
327        cb: F,
328    ) -> R;
329}
330
/// An optional update to a value: leave it alone, or overwrite it.
enum MaybeUpdate<T> {
    /// Keep the existing value.
    NoChange,
    /// Replace the existing value with this one.
    NewValue(T),
}
335
336fn update_device_and_protocol<CC: DeviceSocketContext<BT>, BT: DeviceSocketTypes>(
337    core_ctx: &mut CC,
338    socket: &DeviceSocketId<CC::WeakDeviceId, BT>,
339    new_device: TargetDevice<&CC::DeviceId>,
340    protocol_update: MaybeUpdate<Protocol>,
341) {
342    core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
343        // Even if we're never moving the socket from/to the any-device
344        // state, we acquire the lock to make the move between devices
345        // atomic from the perspective of frame delivery. Otherwise there
346        // would be a brief period during which arriving frames wouldn't be
347        // delivered to the socket from either device.
348        let old_device = core_ctx.with_socket_state_mut(socket, |Target { protocol, device }| {
349            match protocol_update {
350                MaybeUpdate::NewValue(p) => *protocol = Some(p),
351                MaybeUpdate::NoChange => (),
352            };
353            let old_device = match &device {
354                TargetDevice::SpecificDevice(device) => device.upgrade(),
355                TargetDevice::AnyDevice => {
356                    assert!(any_device_sockets.remove(socket));
357                    None
358                }
359            };
360            *device = match &new_device {
361                TargetDevice::AnyDevice => TargetDevice::AnyDevice,
362                TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d.downgrade()),
363            };
364            old_device
365        });
366
367        // This modification occurs without holding the socket's individual
368        // lock. That's safe because all modifications to the socket's
369        // device are done within a `with_sockets_mut` call, which
370        // synchronizes them.
371
372        if let Some(device) = old_device {
373            // Remove the reference to the socket from the old device if
374            // there is one, and it hasn't been removed.
375            core_ctx.with_device_sockets_mut(
376                &device,
377                |DeviceSockets(device_sockets), _core_ctx| {
378                    assert!(device_sockets.remove(socket), "socket not found in device state");
379                },
380            );
381        }
382
383        // Add the reference to the new device, if there is one.
384        match &new_device {
385            TargetDevice::SpecificDevice(new_device) => core_ctx.with_device_sockets_mut(
386                new_device,
387                |DeviceSockets(device_sockets), _core_ctx| {
388                    assert!(device_sockets.insert(socket.clone()));
389                },
390            ),
391            TargetDevice::AnyDevice => {
392                assert!(any_device_sockets.insert(socket.clone()))
393            }
394        }
395    })
396}
397
/// Entry point for the device socket API, wrapping a context pair `C`.
pub struct DeviceSocketApi<C>(C);

impl<C> DeviceSocketApi<C> {
    /// Wraps `ctx` in a [`DeviceSocketApi`].
    pub fn new(ctx: C) -> Self {
        DeviceSocketApi(ctx)
    }
}
407
408/// A local alias for [`DeviceSocketId`] for use in [`DeviceSocketApi`].
409///
410/// TODO(https://github.com/rust-lang/rust/issues/8995): Make this an inherent
411/// associated type.
412type ApiSocketId<C> = DeviceSocketId<
413    <<C as ContextPair>::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
414    <C as ContextPair>::BindingsContext,
415>;
416
417impl<C> DeviceSocketApi<C>
418where
419    C: ContextPair,
420    C::CoreContext: DeviceSocketContext<C::BindingsContext>
421        + SocketStateAccessor<C::BindingsContext>
422        + ResourceCounterContext<ApiSocketId<C>, DeviceSocketCounters>,
423    C::BindingsContext: DeviceSocketBindingsContext<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>
424        + ReferenceNotifiers
425        + 'static,
426{
427    fn core_ctx(&mut self) -> &mut C::CoreContext {
428        let Self(pair) = self;
429        pair.core_ctx()
430    }
431
432    fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
433        let Self(pair) = self;
434        pair.contexts()
435    }
436
437    /// Creates an packet socket with no protocol set configured for all devices.
438    pub fn create(
439        &mut self,
440        external_state: <C::BindingsContext as DeviceSocketTypes>::SocketState<
441            <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
442        >,
443    ) -> ApiSocketId<C> {
444        let core_ctx = self.core_ctx();
445
446        let strong = core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
447            let primary = PrimaryDeviceSocketId::new(external_state);
448            let strong = primary.clone_strong();
449            assert!(sockets.insert(strong.clone(), primary).is_none());
450            strong
451        });
452        core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), _core_ctx| {
453            // On creation, sockets do not target any device or protocol.
454            // Inserting them into the `any_device_sockets` table lets us treat
455            // newly-created sockets uniformly with sockets whose target device
456            // or protocol was set. The difference is unobservable at runtime
457            // since newly-created sockets won't match any frames being
458            // delivered.
459            assert!(any_device_sockets.insert(strong.clone()));
460        });
461        strong
462    }
463
464    /// Sets the device for which a packet socket will receive packets.
465    pub fn set_device(
466        &mut self,
467        socket: &ApiSocketId<C>,
468        device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
469    ) {
470        update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NoChange)
471    }
472
473    /// Sets the device and protocol for which a socket will receive packets.
474    pub fn set_device_and_protocol(
475        &mut self,
476        socket: &ApiSocketId<C>,
477        device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
478        protocol: Protocol,
479    ) {
480        update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NewValue(protocol))
481    }
482
483    /// Gets the bound info for a socket.
484    pub fn get_info(
485        &mut self,
486        id: &ApiSocketId<C>,
487    ) -> SocketInfo<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
488        self.core_ctx().with_socket_state(id, |Target { device, protocol }| SocketInfo {
489            device: device.clone(),
490            protocol: *protocol,
491        })
492    }
493
494    /// Removes a bound socket.
495    pub fn remove(
496        &mut self,
497        id: ApiSocketId<C>,
498    ) -> RemoveResourceResultWithContext<
499        <C::BindingsContext as DeviceSocketTypes>::SocketState<
500            <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
501        >,
502        C::BindingsContext,
503    > {
504        let core_ctx = self.core_ctx();
505        core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
506            let old_device = core_ctx.with_socket_state_mut(&id, |target| {
507                let Target { device, protocol: _ } = target;
508                match &device {
509                    TargetDevice::SpecificDevice(device) => device.upgrade(),
510                    TargetDevice::AnyDevice => {
511                        assert!(any_device_sockets.remove(&id));
512                        None
513                    }
514                }
515            });
516            if let Some(device) = old_device {
517                core_ctx.with_device_sockets_mut(
518                    &device,
519                    |DeviceSockets(device_sockets), _core_ctx| {
520                        assert!(device_sockets.remove(&id), "device doesn't have socket");
521                    },
522                )
523            }
524        });
525
526        core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
527            let primary = sockets
528                .remove(&id)
529                .unwrap_or_else(|| panic!("{id:?} not present in all socket map"));
530            // Make sure to drop the strong ID before trying to unwrap the primary
531            // ID.
532            drop(id);
533
534            let PrimaryDeviceSocketId(primary) = primary;
535            C::BindingsContext::unwrap_or_notify_with_new_reference_notifier(
536                primary,
537                |SocketState { external_state, counters: _, target: _ }| external_state,
538            )
539        })
540    }
541
542    /// Sends a frame for the specified socket.
543    pub fn send_frame<S, D>(
544        &mut self,
545        id: &ApiSocketId<C>,
546        metadata: DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
547        body: S,
548    ) -> Result<(), SendFrameErrorReason>
549    where
550        S: Serializer,
551        S::Buffer: BufferMut,
552        D: DeviceSocketSendTypes,
553        C::CoreContext: DeviceIdContext<D>
554            + SendFrameContext<
555                C::BindingsContext,
556                DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
557            >,
558        C::BindingsContext: DeviceLayerTypes,
559    {
560        let (core_ctx, bindings_ctx) = self.contexts();
561        let result = core_ctx.send_frame(bindings_ctx, metadata, body).map_err(|e| e.into_err());
562        match &result {
563            Ok(()) => {
564                core_ctx.increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_frames)
565            }
566            Err(SendFrameErrorReason::QueueFull) => core_ctx
567                .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_queue_full),
568            Err(SendFrameErrorReason::Alloc) => core_ctx
569                .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_alloc),
570            Err(SendFrameErrorReason::SizeConstraintsViolation) => core_ctx
571                .increment_both(id, |counters: &DeviceSocketCounters| {
572                    &counters.tx_err_size_constraint
573                }),
574        }
575        result
576    }
577
578    /// Provides inspect data for raw IP sockets.
579    pub fn inspect<N>(&mut self, inspector: &mut N)
580    where
581        N: Inspector
582            + InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId>,
583    {
584        self.core_ctx().with_all_device_sockets(|AllSockets(sockets), core_ctx| {
585            sockets.keys().for_each(|socket| {
586                inspector.record_debug_child(socket, |node| {
587                    core_ctx.with_socket_state(socket, |Target { protocol, device }| {
588                        node.record_debug("Protocol", protocol);
589                        match device {
590                            TargetDevice::AnyDevice => node.record_str("Device", "Any"),
591                            TargetDevice::SpecificDevice(d) => N::record_device(node, "Device", d),
592                        }
593                    });
594                    node.record_child("Counters", |node| {
595                        node.delegate_inspectable(socket.counters())
596                    })
597                })
598            })
599        })
600    }
601}
602
603/// A provider of the types required to send on a device socket.
604pub trait DeviceSocketSendTypes: Device {
605    /// The metadata required to send a frame on the device.
606    type Metadata;
607}
608
609/// Metadata required to send a frame on a device socket.
610#[derive(Debug, PartialEq)]
611pub struct DeviceSocketMetadata<D: DeviceSocketSendTypes, DeviceId> {
612    /// The device ID to send via.
613    pub device_id: DeviceId,
614    /// The metadata required to send that's specific to the device type.
615    pub metadata: D::Metadata,
616    // TODO(https://fxbug.dev/391946195): Include send buffer ownership metadata
617    // here.
618}
619
620/// Parameters needed to apply system-framing of an Ethernet frame.
621#[derive(Debug, PartialEq)]
622pub struct EthernetHeaderParams {
623    /// The destination MAC address to send to.
624    pub dest_addr: Mac,
625    /// The upperlayer protocol of the data contained in this Ethernet frame.
626    pub protocol: EtherType,
627}
628
629/// Public identifier for a socket.
630///
631/// Strongly owns the state of the socket. So long as the `SocketId` for a
632/// socket is not dropped, the socket is guaranteed to exist.
633pub type SocketId<BC> = DeviceSocketId<WeakDeviceId<BC>, BC>;
634
635impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> DeviceSocketId<D, BT> {
636    /// Provides immutable access to [`DeviceSocketTypes::SocketState`] for the
637    /// socket.
638    pub fn socket_state(&self) -> &BT::SocketState<D> {
639        let Self(strong) = self;
640        let SocketState { external_state, counters: _, target: _ } = &**strong;
641        external_state
642    }
643
644    /// Obtain a [`WeakDeviceSocketId`] from this [`DeviceSocketId`].
645    pub fn downgrade(&self) -> WeakDeviceSocketId<D, BT> {
646        let Self(inner) = self;
647        WeakDeviceSocketId(StrongRc::downgrade(inner))
648    }
649
650    /// Provides access to the socket's counters.
651    pub fn counters(&self) -> &DeviceSocketCounters {
652        let Self(strong) = self;
653        let SocketState { external_state: _, counters, target: _ } = &**strong;
654        counters
655    }
656}
657
658/// Allows the rest of the stack to dispatch packets to listening sockets.
659///
660/// This is implemented on top of [`DeviceSocketContext`] and abstracts packet
661/// socket delivery from the rest of the system.
662pub trait DeviceSocketHandler<D: Device, BC>: DeviceIdContext<D> {
663    /// Dispatch a received frame to sockets.
664    fn handle_frame(
665        &mut self,
666        bindings_ctx: &mut BC,
667        device: &Self::DeviceId,
668        frame: Frame<&[u8]>,
669        whole_frame: &[u8],
670    );
671}
672
673/// A frame received on a device.
674#[derive(Clone, Copy, Debug, Eq, PartialEq)]
675pub enum ReceivedFrame<B> {
676    /// An ethernet frame received on a device.
677    Ethernet {
678        /// Where the frame was destined.
679        destination: FrameDestination,
680        /// The parsed ethernet frame.
681        frame: EthernetFrame<B>,
682    },
683    /// An IP frame received on a device.
684    ///
685    /// Note that this is not an IP packet within an Ethernet Frame. This is an
686    /// IP packet received directly from the device (e.g. a pure IP device).
687    Ip(IpFrame<B>),
688}
689
690/// A frame sent on a device.
691#[derive(Clone, Copy, Debug, Eq, PartialEq)]
692pub enum SentFrame<B> {
693    /// An ethernet frame sent on a device.
694    Ethernet(EthernetFrame<B>),
695    /// An IP frame sent on a device.
696    ///
697    /// Note that this is not an IP packet within an Ethernet Frame. This is an
698    /// IP Packet send directly on the device (e.g. a pure IP device).
699    Ip(IpFrame<B>),
700}
701
/// Error returned when bytes cannot be interpreted as a [`SentFrame`].
#[derive(Debug)]
pub struct ParseSentFrameError;
705
706impl SentFrame<&[u8]> {
707    /// Tries to parse the given frame as an Ethernet frame.
708    pub fn try_parse_as_ethernet(mut buf: &[u8]) -> Result<SentFrame<&[u8]>, ParseSentFrameError> {
709        packet_formats::ethernet::EthernetFrame::parse(&mut buf, EthernetFrameLengthCheck::NoCheck)
710            .map_err(|_: ParseError| ParseSentFrameError)
711            .map(|frame| SentFrame::Ethernet(frame.into()))
712    }
713}
714
715/// Data from an Ethernet frame.
716#[derive(Clone, Copy, Debug, Eq, PartialEq)]
717pub struct EthernetFrame<B> {
718    /// The source address of the frame.
719    pub src_mac: Mac,
720    /// The destination address of the frame.
721    pub dst_mac: Mac,
722    /// The EtherType of the frame, or `None` if there was none.
723    pub ethertype: Option<EtherType>,
724    /// The offset of the body within the frame.
725    pub body_offset: usize,
726    /// The body of the frame.
727    pub body: B,
728}
729
730/// Data from an IP frame.
731#[derive(Clone, Copy, Debug, Eq, PartialEq)]
732pub struct IpFrame<B> {
733    /// The IP version of the frame.
734    pub ip_version: IpVersion,
735    /// The body of the frame.
736    pub body: B,
737}
738
739impl<B> IpFrame<B> {
740    fn ethertype(&self) -> EtherType {
741        let IpFrame { ip_version, body: _ } = self;
742        EtherType::from_ip_version(*ip_version)
743    }
744}
745
746/// A frame sent or received on a device
747#[derive(Clone, Copy, Debug, Eq, PartialEq)]
748pub enum Frame<B> {
749    /// A sent frame.
750    Sent(SentFrame<B>),
751    /// A received frame.
752    Received(ReceivedFrame<B>),
753}
754
755impl<B> From<SentFrame<B>> for Frame<B> {
756    fn from(value: SentFrame<B>) -> Self {
757        Self::Sent(value)
758    }
759}
760
761impl<B> From<ReceivedFrame<B>> for Frame<B> {
762    fn from(value: ReceivedFrame<B>) -> Self {
763        Self::Received(value)
764    }
765}
766
767impl<'a> From<packet_formats::ethernet::EthernetFrame<&'a [u8]>> for EthernetFrame<&'a [u8]> {
768    fn from(frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>) -> Self {
769        Self {
770            src_mac: frame.src_mac(),
771            dst_mac: frame.dst_mac(),
772            ethertype: frame.ethertype(),
773            body_offset: frame.parse_metadata().header_len(),
774            body: frame.into_body(),
775        }
776    }
777}
778
779impl<'a> ReceivedFrame<&'a [u8]> {
780    pub(crate) fn from_ethernet(
781        frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>,
782        destination: FrameDestination,
783    ) -> Self {
784        Self::Ethernet { destination, frame: frame.into() }
785    }
786}
787
788impl<B> Frame<B> {
789    /// Returns ether type for the packet if it's known.
790    pub fn protocol(&self) -> Option<u16> {
791        let ethertype = match self {
792            Self::Sent(SentFrame::Ethernet(frame))
793            | Self::Received(ReceivedFrame::Ethernet { destination: _, frame }) => frame.ethertype,
794            Self::Sent(SentFrame::Ip(frame)) | Self::Received(ReceivedFrame::Ip(frame)) => {
795                Some(frame.ethertype())
796            }
797        };
798        ethertype.map(Into::into)
799    }
800
801    /// Convenience method for consuming the `Frame` and producing the body.
802    pub fn into_body(self) -> B {
803        match self {
804            Self::Received(ReceivedFrame::Ethernet { destination: _, frame })
805            | Self::Sent(SentFrame::Ethernet(frame)) => frame.body,
806            Self::Received(ReceivedFrame::Ip(frame)) | Self::Sent(SentFrame::Ip(frame)) => {
807                frame.body
808            }
809        }
810    }
811
812    /// Returns the offset of the body within the frame.
813    pub fn body_offset(&self) -> usize {
814        match self {
815            Self::Received(ReceivedFrame::Ethernet { destination: _, frame })
816            | Self::Sent(SentFrame::Ethernet(frame)) => frame.body_offset,
817            Self::Received(ReceivedFrame::Ip(_)) | Self::Sent(SentFrame::Ip(_)) => 0,
818        }
819    }
820}
821
822impl<
823    D: Device,
824    BC: DeviceSocketBindingsContext<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
825    CC: DeviceSocketContext<BC> + DeviceIdContext<D>,
826> DeviceSocketHandler<D, BC> for CC
827where
828    <CC as DeviceIdContext<D>>::DeviceId: Into<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
829{
830    fn handle_frame(
831        &mut self,
832        bindings_ctx: &mut BC,
833        device: &Self::DeviceId,
834        frame: Frame<&[u8]>,
835        whole_frame: &[u8],
836    ) {
837        let device = device.clone().into();
838
839        // TODO(https://fxbug.dev/42076496): Invert the order of acquisition
840        // for the lock on the sockets held in the device and the any-device
841        // sockets lock.
842        self.with_any_device_sockets(|AnyDeviceSockets(any_device_sockets), core_ctx| {
843            // Iterate through the device's sockets while also holding the
844            // any-device sockets lock. This prevents double delivery to the
845            // same socket. If the two tables were locked independently,
846            // we could end up with a race, with the following thread
847            // interleaving (thread A is executing this code for device D,
848            // thread B is updating the device to D for the same socket X):
849            //   A) lock the any device sockets table
850            //   A) deliver to socket X in the table
851            //   A) unlock the any device sockets table
852            //   B) lock the any device sockets table, then D's sockets
853            //   B) remove X from the any table and add to D's
854            //   B) unlock D's sockets and any device sockets
855            //   A) lock D's sockets
856            //   A) deliver to socket X in D's table (!)
857            core_ctx.with_device_sockets(&device, |DeviceSockets(device_sockets), core_ctx| {
858                for socket in any_device_sockets.iter().chain(device_sockets) {
859                    let delivered =
860                        core_ctx.with_socket_state(socket, |Target { protocol, device: _ }| {
861                            let should_deliver = match protocol {
862                                None => false,
863                                Some(p) => match p {
864                                    // Sent frames are only delivered to sockets
865                                    // matching all protocols for Linux
866                                    // compatibility. See https://github.com/google/gvisor/blob/68eae979409452209e4faaeac12aee4191b3d6f0/test/syscalls/linux/packet_socket.cc#L331-L392.
867                                    Protocol::Specific(p) => match frame {
868                                        Frame::Received(_) => Some(p.get()) == frame.protocol(),
869                                        Frame::Sent(_) => false,
870                                    },
871                                    Protocol::All => true,
872                                },
873                            };
874                            should_deliver.then(|| {
875                                bindings_ctx.receive_frame(socket, &device, frame, whole_frame)
876                            })
877                        });
878                    match delivered {
879                        None => {}
880                        Some(result) => {
881                            core_ctx.increment_both(socket, |counters: &DeviceSocketCounters| {
882                                &counters.rx_frames
883                            });
884                            match result {
885                                Ok(()) => {}
886                                Err(ReceiveFrameError::QueueFull) => {
887                                    core_ctx.increment_both(
888                                        socket,
889                                        |counters: &DeviceSocketCounters| &counters.rx_queue_full,
890                                    );
891                                }
892                            }
893                        }
894                    }
895                }
896            })
897        })
898    }
899}
900
/// Usage statistics about Device Sockets.
///
/// Tracked stack-wide and per-socket. The counters are recorded for
/// inspection through this type's [`Inspectable`] implementation.
#[derive(Debug, Default)]
pub struct DeviceSocketCounters {
    /// Count of incoming frames that were delivered to the socket.
    ///
    /// Every frame handed to the bindings for a socket is counted here. This
    /// includes frames for which the bindings reported that the socket's
    /// receive queue was full; those are also counted in `rx_queue_full`.
    ///
    /// Note that a single frame may be delivered to multiple device sockets.
    /// Thus this counter, when tracking the stack-wide aggregate, may exceed
    /// the total number of frames received by the stack.
    rx_frames: Counter,
    /// Count of incoming frames that could not be delivered to a socket because
    /// its receive buffer was full.
    rx_queue_full: Counter,
    /// Count of outgoing frames that were sent by the socket.
    tx_frames: Counter,
    /// Count of failed tx frames due to [`SendFrameErrorReason::QueueFull`].
    tx_err_queue_full: Counter,
    /// Count of failed tx frames due to [`SendFrameErrorReason::Alloc`].
    tx_err_alloc: Counter,
    /// Count of failed tx frames due to [`SendFrameErrorReason::SizeConstraintsViolation`].
    tx_err_size_constraint: Counter,
}
924
925impl Inspectable for DeviceSocketCounters {
926    fn record<I: Inspector>(&self, inspector: &mut I) {
927        let Self {
928            rx_frames,
929            rx_queue_full,
930            tx_frames,
931            tx_err_queue_full,
932            tx_err_alloc,
933            tx_err_size_constraint,
934        } = self;
935        inspector.record_child("Rx", |inspector| {
936            inspector.record_counter("DeliveredFrames", rx_frames);
937            inspector.record_counter("DroppedQueueFull", rx_queue_full);
938        });
939        inspector.record_child("Tx", |inspector| {
940            inspector.record_counter("SentFrames", tx_frames);
941            inspector.record_counter("QueueFullError", tx_err_queue_full);
942            inspector.record_counter("AllocError", tx_err_alloc);
943            inspector.record_counter("SizeConstraintError", tx_err_size_constraint);
944        });
945    }
946}
947
948impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AnyDeviceSockets<D, BT>>
949    for Sockets<D, BT>
950{
951    type Lock = RwLock<AnyDeviceSockets<D, BT>>;
952    fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
953        OrderedLockRef::new(&self.any_device_sockets)
954    }
955}
956
957impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AllSockets<D, BT>>
958    for Sockets<D, BT>
959{
960    type Lock = RwLock<AllSockets<D, BT>>;
961    fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
962        OrderedLockRef::new(&self.all_sockets)
963    }
964}
965
#[cfg(any(test, feature = "testutils"))]
mod testutil {
    use alloc::vec::Vec;
    use core::num::NonZeroU64;
    use core::ops::DerefMut;
    use netstack3_base::StrongDeviceIdentifier;
    use netstack3_base::testutil::{FakeBindingsCtx, MonotonicIdentifier};

    use super::*;
    use crate::internal::base::{
        DeviceClassMatcher, DeviceIdAndNameMatcher, DeviceLayerStateTypes,
    };

    /// The receive queue backing a fake device socket.
    ///
    /// Frames are appended to `frames` while it holds fewer than `max_size`
    /// entries. Once it is full, further frames are rejected. `max_size`
    /// defaults to `usize::MAX`.
    #[derive(Derivative, Debug)]
    #[derivative(Default(bound = ""))]
    pub struct RxQueue<D> {
        pub frames: Vec<ReceivedFrame<D>>,
        #[derivative(Default(value = "usize::MAX"))]
        pub max_size: usize,
    }

    /// An owned copy of a frame handed to a fake device socket.
    #[derive(Clone, Debug, PartialEq)]
    pub struct ReceivedFrame<D> {
        /// The device the frame was delivered for.
        pub device: D,
        /// The frame, with its body copied into an owned buffer.
        pub frame: Frame<Vec<u8>>,
        /// A copy of the raw frame bytes passed alongside the parsed frame.
        pub raw: Vec<u8>,
    }

    /// Per-socket state held by the fake bindings: a lockable receive queue.
    #[derive(Debug, Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct ExternalSocketState<D>(pub Mutex<RxQueue<D>>);

    impl<TimerId, Event: Debug, State> DeviceSocketTypes
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        type SocketState<D: Send + Sync + Debug> = ExternalSocketState<D>;
    }

    impl Frame<&[u8]> {
        /// Produces an owned copy of this frame by copying its body.
        pub(crate) fn cloned(self) -> Frame<Vec<u8>> {
            match self {
                Self::Sent(sent) => Frame::Sent(match sent {
                    SentFrame::Ethernet(eth) => SentFrame::Ethernet(eth.cloned()),
                    SentFrame::Ip(ip) => SentFrame::Ip(ip.cloned()),
                }),
                Self::Received(received) => Frame::Received(match received {
                    super::ReceivedFrame::Ethernet { destination, frame } => {
                        super::ReceivedFrame::Ethernet { destination, frame: frame.cloned() }
                    }
                    super::ReceivedFrame::Ip(ip) => super::ReceivedFrame::Ip(ip.cloned()),
                }),
            }
        }
    }

    impl EthernetFrame<&[u8]> {
        /// Copies the body into an owned buffer, keeping all header fields.
        fn cloned(self) -> EthernetFrame<Vec<u8>> {
            let Self { src_mac, dst_mac, ethertype, body_offset, body } = self;
            let body = body.to_vec();
            EthernetFrame { src_mac, dst_mac, ethertype, body_offset, body }
        }
    }

    impl IpFrame<&[u8]> {
        /// Copies the body into an owned buffer, keeping the IP version.
        fn cloned(self) -> IpFrame<Vec<u8>> {
            let Self { ip_version, body } = self;
            let body = body.to_vec();
            IpFrame { ip_version, body }
        }
    }

    impl<TimerId, Event: Debug, State, D: StrongDeviceIdentifier> DeviceSocketBindingsContext<D>
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        /// Appends an owned copy of the frame to the socket's queue.
        ///
        /// Returns [`ReceiveFrameError::QueueFull`] without queueing anything
        /// if the queue already holds `max_size` frames.
        fn receive_frame(
            &self,
            state: &DeviceSocketId<D::Weak, Self>,
            device: &D,
            frame: Frame<&[u8]>,
            raw_frame: &[u8],
        ) -> Result<(), ReceiveFrameError> {
            let ExternalSocketState(queue) = state.socket_state();
            let mut guard = queue.lock();
            let RxQueue { frames, max_size } = guard.deref_mut();
            if frames.len() >= *max_size {
                return Err(ReceiveFrameError::QueueFull);
            }
            frames.push(ReceivedFrame {
                device: device.downgrade(),
                frame: frame.cloned(),
                raw: raw_frame.to_vec(),
            });
            Ok(())
        }
    }

    impl<
        TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
        Event: Debug + 'static,
        State: 'static,
    > DeviceLayerStateTypes for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        type EthernetDeviceState = ();
        type LoopbackDeviceState = ();
        type PureIpDeviceState = ();
        type BlackholeDeviceState = ();
        type DeviceIdentifier = MonotonicIdentifier;
    }

    // Device matching is not supported by these fakes; calling it panics.
    impl DeviceClassMatcher<()> for () {
        fn device_class_matches(&self, (): &()) -> bool {
            unimplemented!()
        }
    }

    // Device matching is not supported by these fakes; calling it panics.
    impl DeviceIdAndNameMatcher for MonotonicIdentifier {
        fn id_matches(&self, _id: &NonZeroU64) -> bool {
            unimplemented!()
        }

        fn name_matches(&self, _name: &str) -> bool {
            unimplemented!()
        }
    }
}
1093
1094#[cfg(test)]
1095mod tests {
1096    use alloc::vec;
1097    use alloc::vec::Vec;
1098    use core::marker::PhantomData;
1099    use core::ops::Deref;
1100
1101    use crate::internal::socket::testutil::{ExternalSocketState, ReceivedFrame};
1102    use netstack3_base::testutil::{
1103        FakeReferencyDeviceId, FakeStrongDeviceId, FakeWeakDeviceId, MultipleDevicesId,
1104    };
1105    use netstack3_base::{CounterContext, CtxPair, SendFrameError, SendableFrameMeta};
1106    use netstack3_hashmap::HashMap;
1107    use packet::ParsablePacket;
1108    use test_case::test_case;
1109
1110    use super::*;
1111
    /// The fake core context used by these tests; its state is a
    /// [`FakeSockets`].
    type FakeCoreCtx<D> = netstack3_base::testutil::FakeCoreCtx<FakeSockets<D>, (), D>;
    /// The fake bindings context used by these tests.
    type FakeBindingsCtx = netstack3_base::testutil::FakeBindingsCtx<(), (), (), ()>;
    /// A pair of the fake core and fake bindings contexts.
    type FakeCtx<D> = CtxPair<FakeCoreCtx<D>, FakeBindingsCtx>;
1115
1116    /// A trait providing a shortcut to instantiate a [`DeviceSocketApi`] from a
1117    /// context.
1118    trait DeviceSocketApiExt: ContextPair + Sized {
1119        fn device_socket_api(&mut self) -> DeviceSocketApi<&mut Self> {
1120            DeviceSocketApi::new(self)
1121        }
1122    }
1123
1124    impl<O> DeviceSocketApiExt for O where O: ContextPair + Sized {}
1125
    /// The socket tables and counters backing the fake core context.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    struct FakeSockets<D: FakeStrongDeviceId> {
        /// Sockets bound to any device.
        any_device_sockets: AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
        /// Per-device socket tables, keyed by device.
        device_sockets: HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
        /// The table handed out by `with_all_device_sockets{,_mut}`.
        all_sockets: AllSockets<D::Weak, FakeBindingsCtx>,
        /// The stack-wide counters for device sockets.
        counters: DeviceSocketCounters,
        /// Serialized frames; presumably appended to by the fake send path
        /// (not visible here) — TODO confirm.
        sent_frames: Vec<Vec<u8>>,
    }
1136
    /// A bundle of references into fake socket state. It contains, in order:
    /// - the any-device table;
    /// - the all-sockets table;
    /// - the per-device state (either the map of socket tables or just the
    ///   set of device IDs);
    /// - a marker for the device ID type;
    /// - the stack-wide counters.
    ///
    /// The type parameters let this stand in for the restricted core
    /// contexts given to callbacks. Tables a callback may not touch are
    /// replaced with `()`.
    pub struct FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>(
        &'m mut AnyDevice,
        &'m mut AllSockets,
        &'m mut Devices,
        PhantomData<Device>,
        &'m DeviceSocketCounters,
    );
1145
    /// Helper trait to allow treating a `&mut self` as a
    /// [`FakeSocketsMutRefs`].
    ///
    /// Both the fake core context (borrowing from its state) and
    /// `FakeSocketsMutRefs` itself (reborrowing) implement this. That lets
    /// the fake accessor trait impls be written once, generically over it.
    pub trait AsFakeSocketsMutRefs {
        /// The any-device socket table type.
        type AnyDevice: 'static;
        /// The all-sockets table type.
        type AllSockets: 'static;
        /// The per-device state type.
        type Devices: 'static;
        /// The device ID type.
        type Device: 'static;
        /// Borrows the referenced state as a [`FakeSocketsMutRefs`].
        fn as_sockets_ref(
            &mut self,
        ) -> FakeSocketsMutRefs<'_, Self::AnyDevice, Self::AllSockets, Self::Devices, Self::Device>;
    }
1157
1158    impl<D: FakeStrongDeviceId> AsFakeSocketsMutRefs for FakeCoreCtx<D> {
1159        type AnyDevice = AnyDeviceSockets<D::Weak, FakeBindingsCtx>;
1160        type AllSockets = AllSockets<D::Weak, FakeBindingsCtx>;
1161        type Devices = HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>;
1162        type Device = D;
1163
1164        fn as_sockets_ref(
1165            &mut self,
1166        ) -> FakeSocketsMutRefs<
1167            '_,
1168            AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
1169            AllSockets<D::Weak, FakeBindingsCtx>,
1170            HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
1171            D,
1172        > {
1173            let FakeSockets {
1174                any_device_sockets,
1175                device_sockets,
1176                all_sockets,
1177                counters,
1178                sent_frames: _,
1179            } = &mut self.state;
1180            FakeSocketsMutRefs(
1181                any_device_sockets,
1182                all_sockets,
1183                device_sockets,
1184                PhantomData,
1185                counters,
1186            )
1187        }
1188    }
1189
1190    impl<'m, AnyDevice: 'static, AllSockets: 'static, Devices: 'static, Device: 'static>
1191        AsFakeSocketsMutRefs for FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>
1192    {
1193        type AnyDevice = AnyDevice;
1194        type AllSockets = AllSockets;
1195        type Devices = Devices;
1196        type Device = Device;
1197
1198        fn as_sockets_ref(
1199            &mut self,
1200        ) -> FakeSocketsMutRefs<'_, AnyDevice, AllSockets, Devices, Device> {
1201            let Self(any_device, all_sockets, devices, PhantomData, counters) = self;
1202            FakeSocketsMutRefs(any_device, all_sockets, devices, PhantomData, counters)
1203        }
1204    }
1205
1206    impl<D: Clone> TargetDevice<&D> {
1207        fn with_weak_id(&self) -> TargetDevice<FakeWeakDeviceId<D>> {
1208            match self {
1209                TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1210                TargetDevice::SpecificDevice(d) => {
1211                    TargetDevice::SpecificDevice(FakeWeakDeviceId((*d).clone()))
1212                }
1213            }
1214        }
1215    }
1216
1217    impl<D: Eq + Hash + FakeStrongDeviceId> FakeSockets<D> {
1218        fn new(devices: impl IntoIterator<Item = D>) -> Self {
1219            let device_sockets =
1220                devices.into_iter().map(|d| (d, DeviceSockets::default())).collect();
1221            Self {
1222                any_device_sockets: AnyDeviceSockets::default(),
1223                device_sockets,
1224                all_sockets: Default::default(),
1225                counters: Default::default(),
1226                sent_frames: Default::default(),
1227            }
1228        }
1229    }
1230
1231    impl<
1232        'm,
1233        DeviceId: FakeStrongDeviceId,
1234        As: AsFakeSocketsMutRefs
1235            + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1236    > SocketStateAccessor<FakeBindingsCtx> for As
1237    {
1238        fn with_socket_state<F: FnOnce(&Target<Self::WeakDeviceId>) -> R, R>(
1239            &mut self,
1240            socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1241            cb: F,
1242        ) -> R {
1243            let DeviceSocketId(rc) = socket;
1244            // NB: Circumvent lock ordering for tests.
1245            let target = rc.target.lock();
1246            cb(&target)
1247        }
1248
1249        fn with_socket_state_mut<F: FnOnce(&mut Target<Self::WeakDeviceId>) -> R, R>(
1250            &mut self,
1251            socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1252            cb: F,
1253        ) -> R {
1254            let DeviceSocketId(rc) = socket;
1255            // NB: Circumvent lock ordering for tests.
1256            let mut target = rc.target.lock();
1257            cb(&mut target)
1258        }
1259    }
1260
1261    impl<
1262        'm,
1263        DeviceId: FakeStrongDeviceId,
1264        As: AsFakeSocketsMutRefs<
1265                Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1266            > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1267    > DeviceSocketAccessor<FakeBindingsCtx> for As
1268    {
1269        type DeviceSocketCoreCtx<'a> =
1270            FakeSocketsMutRefs<'a, As::AnyDevice, As::AllSockets, HashSet<DeviceId>, DeviceId>;
1271        fn with_device_sockets<
1272            F: FnOnce(
1273                &DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1274                &mut Self::DeviceSocketCoreCtx<'_>,
1275            ) -> R,
1276            R,
1277        >(
1278            &mut self,
1279            device: &Self::DeviceId,
1280            cb: F,
1281        ) -> R {
1282            let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
1283                self.as_sockets_ref();
1284            let mut devices = device_sockets.keys().cloned().collect();
1285            let device = device_sockets.get(device).unwrap();
1286            cb(
1287                device,
1288                &mut FakeSocketsMutRefs(
1289                    any_device,
1290                    all_sockets,
1291                    &mut devices,
1292                    PhantomData,
1293                    counters,
1294                ),
1295            )
1296        }
1297        fn with_device_sockets_mut<
1298            F: FnOnce(
1299                &mut DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1300                &mut Self::DeviceSocketCoreCtx<'_>,
1301            ) -> R,
1302            R,
1303        >(
1304            &mut self,
1305            device: &Self::DeviceId,
1306            cb: F,
1307        ) -> R {
1308            let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
1309                self.as_sockets_ref();
1310            let mut devices = device_sockets.keys().cloned().collect();
1311            let device = device_sockets.get_mut(device).unwrap();
1312            cb(
1313                device,
1314                &mut FakeSocketsMutRefs(
1315                    any_device,
1316                    all_sockets,
1317                    &mut devices,
1318                    PhantomData,
1319                    counters,
1320                ),
1321            )
1322        }
1323    }
1324
1325    impl<
1326        'm,
1327        DeviceId: FakeStrongDeviceId,
1328        As: AsFakeSocketsMutRefs<
1329                AnyDevice = AnyDeviceSockets<DeviceId::Weak, FakeBindingsCtx>,
1330                AllSockets = AllSockets<DeviceId::Weak, FakeBindingsCtx>,
1331                Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1332            > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1333    > DeviceSocketContext<FakeBindingsCtx> for As
1334    {
1335        type SocketTablesCoreCtx<'a> = FakeSocketsMutRefs<
1336            'a,
1337            (),
1338            (),
1339            HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1340            DeviceId,
1341        >;
1342
1343        fn with_any_device_sockets<
1344            F: FnOnce(
1345                &AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1346                &mut Self::SocketTablesCoreCtx<'_>,
1347            ) -> R,
1348            R,
1349        >(
1350            &mut self,
1351            cb: F,
1352        ) -> R {
1353            let FakeSocketsMutRefs(
1354                any_device_sockets,
1355                _all_sockets,
1356                device_sockets,
1357                PhantomData,
1358                counters,
1359            ) = self.as_sockets_ref();
1360            cb(
1361                any_device_sockets,
1362                &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1363            )
1364        }
1365        fn with_any_device_sockets_mut<
1366            F: FnOnce(
1367                &mut AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1368                &mut Self::SocketTablesCoreCtx<'_>,
1369            ) -> R,
1370            R,
1371        >(
1372            &mut self,
1373            cb: F,
1374        ) -> R {
1375            let FakeSocketsMutRefs(
1376                any_device_sockets,
1377                _all_sockets,
1378                device_sockets,
1379                PhantomData,
1380                counters,
1381            ) = self.as_sockets_ref();
1382            cb(
1383                any_device_sockets,
1384                &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1385            )
1386        }
1387
1388        fn with_all_device_sockets<
1389            F: FnOnce(
1390                &AllSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1391                &mut Self::SocketTablesCoreCtx<'_>,
1392            ) -> R,
1393            R,
1394        >(
1395            &mut self,
1396            cb: F,
1397        ) -> R {
1398            let FakeSocketsMutRefs(
1399                _any_device_sockets,
1400                all_sockets,
1401                device_sockets,
1402                PhantomData,
1403                counters,
1404            ) = self.as_sockets_ref();
1405            cb(
1406                all_sockets,
1407                &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1408            )
1409        }
1410
1411        fn with_all_device_sockets_mut<
1412            F: FnOnce(&mut AllSockets<Self::WeakDeviceId, FakeBindingsCtx>) -> R,
1413            R,
1414        >(
1415            &mut self,
1416            cb: F,
1417        ) -> R {
1418            let FakeSocketsMutRefs(_, all_sockets, _, _, _) = self.as_sockets_ref();
1419            cb(all_sockets)
1420        }
1421    }
1422
    /// Device ID types for the fake reference bundles. The strong ID is `D`
    /// and the weak ID is [`FakeWeakDeviceId`] wrapping `D`.
    impl<'m, X, Y, Z, D: FakeStrongDeviceId> DeviceIdContext<AnyDevice>
        for FakeSocketsMutRefs<'m, X, Y, Z, D>
    {
        type DeviceId = D;
        type WeakDeviceId = FakeWeakDeviceId<D>;
    }
1429
1430    impl<D: FakeStrongDeviceId> CounterContext<DeviceSocketCounters> for FakeCoreCtx<D> {
1431        fn counters(&self) -> &DeviceSocketCounters {
1432            &self.state.counters
1433        }
1434    }
1435
    /// Per-socket counters are read straight from the socket ID.
    impl<D: FakeStrongDeviceId>
        ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
        for FakeCoreCtx<D>
    {
        fn per_resource_counters<'a>(
            &'a self,
            socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
        ) -> &'a DeviceSocketCounters {
            socket.counters()
        }
    }
1447
1448    impl<'m, X, Y, Z, D> CounterContext<DeviceSocketCounters> for FakeSocketsMutRefs<'m, X, Y, Z, D> {
1449        fn counters(&self) -> &DeviceSocketCounters {
1450            let FakeSocketsMutRefs(_, _, _, _, counters) = self;
1451            counters
1452        }
1453    }
1454
    /// Per-socket counters are read straight from the socket ID.
    impl<'m, X, Y, Z, D: FakeStrongDeviceId>
        ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
        for FakeSocketsMutRefs<'m, X, Y, Z, D>
    {
        fn per_resource_counters<'a>(
            &'a self,
            socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
        ) -> &'a DeviceSocketCounters {
            socket.counters()
        }
    }
1466
    /// An arbitrary protocol number for tests that need some specific protocol.
    const SOME_PROTOCOL: NonZeroU16 = NonZeroU16::new(2000).unwrap();
1468
1469    #[test]
1470    fn create_remove() {
1471        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1472            MultipleDevicesId::all(),
1473        )));
1474        let mut api = ctx.device_socket_api();
1475
1476        let bound = api.create(Default::default());
1477        assert_eq!(
1478            api.get_info(&bound),
1479            SocketInfo { device: TargetDevice::AnyDevice, protocol: None }
1480        );
1481
1482        let ExternalSocketState(_received_frames) = api.remove(bound).into_removed();
1483    }
1484
1485    #[test_case(TargetDevice::AnyDevice)]
1486    #[test_case(TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1487    fn test_set_device(device: TargetDevice<&MultipleDevicesId>) {
1488        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1489            MultipleDevicesId::all(),
1490        )));
1491        let mut api = ctx.device_socket_api();
1492
1493        let bound = api.create(Default::default());
1494        api.set_device(&bound, device.clone());
1495        assert_eq!(
1496            api.get_info(&bound),
1497            SocketInfo { device: device.with_weak_id(), protocol: None }
1498        );
1499
1500        let device_sockets = &api.core_ctx().state.device_sockets;
1501        if let TargetDevice::SpecificDevice(d) = device {
1502            let DeviceSockets(socket_ids) = device_sockets.get(&d).expect("device state exists");
1503            assert_eq!(socket_ids, &HashSet::from([bound]));
1504        }
1505    }
1506
1507    #[test]
1508    fn update_device() {
1509        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1510            MultipleDevicesId::all(),
1511        )));
1512        let mut api = ctx.device_socket_api();
1513        let bound = api.create(Default::default());
1514
1515        api.set_device(&bound, TargetDevice::SpecificDevice(&MultipleDevicesId::A));
1516
1517        // Now update the device and make sure the socket only appears in the
1518        // one device's list.
1519        api.set_device(&bound, TargetDevice::SpecificDevice(&MultipleDevicesId::B));
1520        assert_eq!(
1521            api.get_info(&bound),
1522            SocketInfo {
1523                device: TargetDevice::SpecificDevice(FakeWeakDeviceId(MultipleDevicesId::B)),
1524                protocol: None
1525            }
1526        );
1527
1528        let device_sockets = &api.core_ctx().state.device_sockets;
1529        let device_socket_lists = device_sockets
1530            .iter()
1531            .map(|(d, DeviceSockets(indexes))| (d, indexes.iter().collect()))
1532            .collect::<HashMap<_, _>>();
1533
1534        assert_eq!(
1535            device_socket_lists,
1536            HashMap::from([
1537                (&MultipleDevicesId::A, vec![]),
1538                (&MultipleDevicesId::B, vec![&bound]),
1539                (&MultipleDevicesId::C, vec![])
1540            ])
1541        );
1542    }
1543
1544    #[test_case(Protocol::All, TargetDevice::AnyDevice)]
1545    #[test_case(Protocol::Specific(SOME_PROTOCOL), TargetDevice::AnyDevice)]
1546    #[test_case(Protocol::All, TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1547    #[test_case(
1548        Protocol::Specific(SOME_PROTOCOL),
1549        TargetDevice::SpecificDevice(&MultipleDevicesId::A)
1550    )]
1551    fn create_set_device_and_protocol_remove_multiple(
1552        protocol: Protocol,
1553        device: TargetDevice<&MultipleDevicesId>,
1554    ) {
1555        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1556            MultipleDevicesId::all(),
1557        )));
1558        let mut api = ctx.device_socket_api();
1559
1560        let mut sockets = [(); 3].map(|()| api.create(Default::default()));
1561        for socket in &mut sockets {
1562            api.set_device_and_protocol(socket, device.clone(), protocol);
1563            assert_eq!(
1564                api.get_info(socket),
1565                SocketInfo { device: device.with_weak_id(), protocol: Some(protocol) }
1566            );
1567        }
1568
1569        for socket in sockets {
1570            let ExternalSocketState(_received_frames) = api.remove(socket).into_removed();
1571        }
1572    }
1573
1574    #[test]
1575    fn change_device_after_removal() {
1576        let device_to_remove = FakeReferencyDeviceId::default();
1577        let device_to_maintain = FakeReferencyDeviceId::default();
1578        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new([
1579            device_to_remove.clone(),
1580            device_to_maintain.clone(),
1581        ])));
1582        let mut api = ctx.device_socket_api();
1583
1584        let bound = api.create(Default::default());
1585        // Set the device for the socket before removing the device state
1586        // entirely.
1587        api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_remove));
1588
1589        // Now remove the device; this should cause future attempts to upgrade
1590        // the device ID to fail.
1591        device_to_remove.mark_removed();
1592
1593        // Changing the device should gracefully handle the fact that the
1594        // earlier-bound device is now gone.
1595        api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_maintain));
1596        assert_eq!(
1597            api.get_info(&bound),
1598            SocketInfo {
1599                device: TargetDevice::SpecificDevice(FakeWeakDeviceId(device_to_maintain.clone())),
1600                protocol: None,
1601            }
1602        );
1603
1604        let device_sockets = &api.core_ctx().state.device_sockets;
1605        let DeviceSockets(weak_sockets) =
1606            device_sockets.get(&device_to_maintain).expect("device state exists");
1607        assert_eq!(weak_sockets, &HashSet::from([bound]));
1608    }
1609
1610    struct TestData;
1611    impl TestData {
1612        const SRC_MAC: Mac = Mac::new([0, 1, 2, 3, 4, 5]);
1613        const DST_MAC: Mac = Mac::new([6, 7, 8, 9, 10, 11]);
1614        /// Arbitrary protocol number.
1615        const PROTO: NonZeroU16 = NonZeroU16::new(0x08AB).unwrap();
1616        const BODY: &'static [u8] = b"some pig";
1617        const BUFFER: &'static [u8] = &[
1618            6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 0x08, 0xAB, b's', b'o', b'm', b'e', b' ', b'p',
1619            b'i', b'g',
1620        ];
1621        const BUFFER_OFFSET: usize = Self::BUFFER.len() - Self::BODY.len();
1622
1623        /// Creates an EthernetFrame with the values specified above.
1624        fn frame() -> packet_formats::ethernet::EthernetFrame<&'static [u8]> {
1625            let mut buffer_view = Self::BUFFER;
1626            packet_formats::ethernet::EthernetFrame::parse(
1627                &mut buffer_view,
1628                EthernetFrameLengthCheck::NoCheck,
1629            )
1630            .unwrap()
1631        }
1632    }
1633
1634    const WRONG_PROTO: NonZeroU16 = NonZeroU16::new(0x08ff).unwrap();
1635
1636    fn make_bound<D: FakeStrongDeviceId>(
1637        ctx: &mut FakeCtx<D>,
1638        device: TargetDevice<D>,
1639        protocol: Option<Protocol>,
1640        state: ExternalSocketState<D::Weak>,
1641    ) -> DeviceSocketId<D::Weak, FakeBindingsCtx> {
1642        let mut api = ctx.device_socket_api();
1643        let id = api.create(state);
1644        let device = match &device {
1645            TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1646            TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d),
1647        };
1648        match protocol {
1649            Some(protocol) => api.set_device_and_protocol(&id, device, protocol),
1650            None => api.set_device(&id, device),
1651        };
1652        id
1653    }
1654
1655    /// Deliver one frame to the provided contexts and return the IDs of the
1656    /// sockets it was delivered to.
1657    fn deliver_one_frame(
1658        delivered_frame: Frame<&[u8]>,
1659        FakeCtx { core_ctx, bindings_ctx }: &mut FakeCtx<MultipleDevicesId>,
1660    ) -> HashSet<DeviceSocketId<FakeWeakDeviceId<MultipleDevicesId>, FakeBindingsCtx>> {
1661        DeviceSocketHandler::handle_frame(
1662            core_ctx,
1663            bindings_ctx,
1664            &MultipleDevicesId::A,
1665            delivered_frame.clone(),
1666            TestData::BUFFER,
1667        );
1668
1669        let FakeSockets {
1670            all_sockets: AllSockets(all_sockets),
1671            any_device_sockets: _,
1672            device_sockets: _,
1673            counters: _,
1674            sent_frames: _,
1675        } = &core_ctx.state;
1676
1677        all_sockets
1678            .iter()
1679            .filter_map(|(id, _primary)| {
1680                let DeviceSocketId(rc) = &id;
1681                let ExternalSocketState(frames) = &rc.external_state;
1682                let lock_guard = frames.lock();
1683                let testutil::RxQueue { frames, .. } = lock_guard.deref();
1684                (!frames.is_empty()).then(|| {
1685                    assert_eq!(
1686                        &*frames,
1687                        &[ReceivedFrame {
1688                            device: FakeWeakDeviceId(MultipleDevicesId::A),
1689                            frame: delivered_frame.cloned(),
1690                            raw: TestData::BUFFER.into(),
1691                        }]
1692                    );
1693                    id.clone()
1694                })
1695            })
1696            .collect()
1697    }
1698
1699    #[test]
1700    fn receive_frame_deliver_to_multiple() {
1701        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1702            MultipleDevicesId::all(),
1703        )));
1704
1705        use Protocol::*;
1706        use TargetDevice::*;
1707        let never_bound = {
1708            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1709            ctx.device_socket_api().create(state)
1710        };
1711
1712        let mut make_bound = |device, protocol| {
1713            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1714            make_bound(&mut ctx, device, protocol, state)
1715        };
1716        let bound_a_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::A), None);
1717        let bound_a_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::A), Some(All));
1718        let bound_a_right_protocol =
1719            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
1720        let bound_a_wrong_protocol =
1721            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
1722        let bound_b_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::B), None);
1723        let bound_b_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::B), Some(All));
1724        let bound_b_right_protocol =
1725            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
1726        let bound_b_wrong_protocol =
1727            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
1728        let bound_any_no_protocol = make_bound(AnyDevice, None);
1729        let bound_any_all_protocols = make_bound(AnyDevice, Some(All));
1730        let bound_any_right_protocol = make_bound(AnyDevice, Some(Specific(TestData::PROTO)));
1731        let bound_any_wrong_protocol = make_bound(AnyDevice, Some(Specific(WRONG_PROTO)));
1732
1733        let mut sockets_with_received_frames = deliver_one_frame(
1734            super::ReceivedFrame::from_ethernet(
1735                TestData::frame(),
1736                FrameDestination::Individual { local: true },
1737            )
1738            .into(),
1739            &mut ctx,
1740        );
1741
1742        let sockets_not_expecting_frames = [
1743            never_bound,
1744            bound_a_no_protocol,
1745            bound_a_wrong_protocol,
1746            bound_b_no_protocol,
1747            bound_b_all_protocols,
1748            bound_b_right_protocol,
1749            bound_b_wrong_protocol,
1750            bound_any_no_protocol,
1751            bound_any_wrong_protocol,
1752        ];
1753        let sockets_expecting_frames = [
1754            bound_a_all_protocols,
1755            bound_a_right_protocol,
1756            bound_any_all_protocols,
1757            bound_any_right_protocol,
1758        ];
1759
1760        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1761            assert!(
1762                sockets_with_received_frames.remove(&socket),
1763                "socket {n} didn't receive the frame"
1764            );
1765        }
1766        assert!(sockets_with_received_frames.is_empty());
1767
1768        // Verify Counters were set appropriately for each socket.
1769        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1770            assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
1771        }
1772        for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
1773            assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
1774        }
1775    }
1776
1777    #[test]
1778    fn sent_frame_deliver_to_multiple() {
1779        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1780            MultipleDevicesId::all(),
1781        )));
1782
1783        use Protocol::*;
1784        use TargetDevice::*;
1785        let never_bound = {
1786            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1787            ctx.device_socket_api().create(state)
1788        };
1789
1790        let mut make_bound = |device, protocol| {
1791            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1792            make_bound(&mut ctx, device, protocol, state)
1793        };
1794        let bound_a_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::A), None);
1795        let bound_a_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::A), Some(All));
1796        let bound_a_same_protocol =
1797            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
1798        let bound_a_wrong_protocol =
1799            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
1800        let bound_b_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::B), None);
1801        let bound_b_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::B), Some(All));
1802        let bound_b_same_protocol =
1803            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
1804        let bound_b_wrong_protocol =
1805            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
1806        let bound_any_no_protocol = make_bound(AnyDevice, None);
1807        let bound_any_all_protocols = make_bound(AnyDevice, Some(All));
1808        let bound_any_same_protocol = make_bound(AnyDevice, Some(Specific(TestData::PROTO)));
1809        let bound_any_wrong_protocol = make_bound(AnyDevice, Some(Specific(WRONG_PROTO)));
1810
1811        let mut sockets_with_received_frames =
1812            deliver_one_frame(SentFrame::Ethernet(TestData::frame().into()).into(), &mut ctx);
1813
1814        let sockets_not_expecting_frames = [
1815            never_bound,
1816            bound_a_no_protocol,
1817            bound_a_same_protocol,
1818            bound_a_wrong_protocol,
1819            bound_b_no_protocol,
1820            bound_b_all_protocols,
1821            bound_b_same_protocol,
1822            bound_b_wrong_protocol,
1823            bound_any_no_protocol,
1824            bound_any_same_protocol,
1825            bound_any_wrong_protocol,
1826        ];
1827        // Only any-protocol sockets receive sent frames.
1828        let sockets_expecting_frames = [bound_a_all_protocols, bound_any_all_protocols];
1829
1830        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1831            assert!(
1832                sockets_with_received_frames.remove(&socket),
1833                "socket {n} didn't receive the frame"
1834            );
1835        }
1836        assert!(sockets_with_received_frames.is_empty());
1837
1838        // Verify Counters were set appropriately for each socket.
1839        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1840            assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
1841        }
1842        for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
1843            assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
1844        }
1845    }
1846
1847    #[test]
1848    fn deliver_multiple_frames() {
1849        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1850            MultipleDevicesId::all(),
1851        )));
1852        let socket = make_bound(
1853            &mut ctx,
1854            TargetDevice::AnyDevice,
1855            Some(Protocol::All),
1856            ExternalSocketState::default(),
1857        );
1858        let FakeCtx { mut core_ctx, mut bindings_ctx } = ctx;
1859
1860        const RECEIVE_COUNT: usize = 10;
1861        for _ in 0..RECEIVE_COUNT {
1862            DeviceSocketHandler::handle_frame(
1863                &mut core_ctx,
1864                &mut bindings_ctx,
1865                &MultipleDevicesId::A,
1866                super::ReceivedFrame::from_ethernet(
1867                    TestData::frame(),
1868                    FrameDestination::Individual { local: true },
1869                )
1870                .into(),
1871                TestData::BUFFER,
1872            );
1873        }
1874
1875        let FakeSockets {
1876            all_sockets: AllSockets(mut all_sockets),
1877            any_device_sockets: _,
1878            device_sockets: _,
1879            counters: _,
1880            sent_frames: _,
1881        } = core_ctx.into_state();
1882        let primary = all_sockets.remove(&socket).unwrap();
1883        let PrimaryDeviceSocketId(primary) = primary;
1884        assert!(all_sockets.is_empty());
1885        drop(socket);
1886        let SocketState { external_state: ExternalSocketState(received), counters, target: _ } =
1887            PrimaryRc::unwrap(primary);
1888        assert_eq!(
1889            received.into_inner().frames,
1890            vec![
1891                ReceivedFrame {
1892                    device: FakeWeakDeviceId(MultipleDevicesId::A),
1893                    frame: Frame::Received(super::ReceivedFrame::Ethernet {
1894                        destination: FrameDestination::Individual { local: true },
1895                        frame: EthernetFrame {
1896                            src_mac: TestData::SRC_MAC,
1897                            dst_mac: TestData::DST_MAC,
1898                            ethertype: Some(TestData::PROTO.get().into()),
1899                            body_offset: TestData::BUFFER_OFFSET,
1900                            body: Vec::from(TestData::BODY),
1901                        }
1902                    }),
1903                    raw: TestData::BUFFER.into()
1904                };
1905                RECEIVE_COUNT
1906            ]
1907        );
1908        assert_eq!(counters.rx_frames.get(), u64::try_from(RECEIVE_COUNT).unwrap());
1909    }
1910
1911    #[test]
1912    fn deliver_frame_queue_full() {
1913        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1914            MultipleDevicesId::all(),
1915        )));
1916
1917        // Simulate a full RX queue for sock1.
1918        let sock1 = make_bound(
1919            &mut ctx,
1920            TargetDevice::AnyDevice,
1921            Some(Protocol::All),
1922            ExternalSocketState(Mutex::new(testutil::RxQueue { frames: vec![], max_size: 0 })),
1923        );
1924        let sock2 = make_bound(
1925            &mut ctx,
1926            TargetDevice::AnyDevice,
1927            Some(Protocol::All),
1928            ExternalSocketState::default(),
1929        );
1930
1931        let FakeCtx { mut core_ctx, mut bindings_ctx } = ctx;
1932
1933        DeviceSocketHandler::handle_frame(
1934            &mut core_ctx,
1935            &mut bindings_ctx,
1936            &MultipleDevicesId::A,
1937            super::ReceivedFrame::from_ethernet(
1938                TestData::frame(),
1939                FrameDestination::Individual { local: true },
1940            )
1941            .into(),
1942            TestData::BUFFER,
1943        );
1944
1945        assert_eq!(core_ctx.state.counters.rx_frames.get(), 2);
1946        assert_eq!(core_ctx.state.counters.rx_queue_full.get(), 1);
1947        assert_eq!(sock1.counters().rx_frames.get(), 1);
1948        assert_eq!(sock1.counters().rx_queue_full.get(), 1);
1949        assert_eq!(sock2.counters().rx_frames.get(), 1);
1950        assert_eq!(sock2.counters().rx_queue_full.get(), 0);
1951
1952        // Drop our strong references to the sockets so that `core_ctx` can tear
1953        // down successfully.
1954        drop(sock1);
1955        drop(sock2);
1956    }
1957
1958    pub struct FakeSendMetadata;
1959    impl DeviceSocketSendTypes for AnyDevice {
1960        type Metadata = FakeSendMetadata;
1961    }
1962    impl<BC, D: FakeStrongDeviceId> SendableFrameMeta<FakeCoreCtx<D>, BC>
1963        for DeviceSocketMetadata<AnyDevice, D>
1964    {
1965        fn send_meta<S>(
1966            self,
1967            core_ctx: &mut FakeCoreCtx<D>,
1968            _bindings_ctx: &mut BC,
1969            frame: S,
1970        ) -> Result<(), SendFrameError<S>>
1971        where
1972            S: packet::Serializer,
1973            S::Buffer: packet::BufferMut,
1974        {
1975            let frame = match frame.serialize_vec_outer() {
1976                Err(e) => {
1977                    let _: (packet::SerializeError<core::convert::Infallible>, _) = e;
1978                    unreachable!()
1979                }
1980                Ok(frame) => frame.unwrap_a().as_ref().to_vec(),
1981            };
1982            core_ctx.state.sent_frames.push(frame);
1983            Ok(())
1984        }
1985    }
1986
1987    #[test]
1988    fn send_multiple_frames() {
1989        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1990            MultipleDevicesId::all(),
1991        )));
1992
1993        const DEVICE: MultipleDevicesId = MultipleDevicesId::A;
1994        let socket = make_bound(
1995            &mut ctx,
1996            TargetDevice::SpecificDevice(DEVICE),
1997            Some(Protocol::All),
1998            ExternalSocketState::default(),
1999        );
2000        let mut api = ctx.device_socket_api();
2001
2002        const SEND_COUNT: usize = 10;
2003        const PAYLOAD: &'static [u8] = &[1, 2, 3, 4, 5];
2004        for _ in 0..SEND_COUNT {
2005            let buf = packet::Buf::new(PAYLOAD.to_vec(), ..);
2006            api.send_frame(
2007                &socket,
2008                DeviceSocketMetadata { device_id: DEVICE, metadata: FakeSendMetadata },
2009                buf,
2010            )
2011            .expect("send failed");
2012        }
2013
2014        assert_eq!(ctx.core_ctx().state.sent_frames, vec![PAYLOAD.to_vec(); SEND_COUNT]);
2015
2016        assert_eq!(socket.counters().tx_frames.get(), u64::try_from(SEND_COUNT).unwrap());
2017    }
2018}