Skip to main content

netstack3_device/
socket.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Link-layer sockets (analogous to Linux's AF_PACKET sockets).
6
7use core::fmt::Debug;
8use core::hash::Hash;
9use core::num::NonZeroU16;
10
11use derivative::Derivative;
12use lock_order::lock::{OrderedLockAccess, OrderedLockRef};
13use net_types::ethernet::Mac;
14use net_types::ip::IpVersion;
15use netstack3_base::socket::SocketCookie;
16use netstack3_base::sync::{Mutex, PrimaryRc, RwLock, StrongRc, WeakRc};
17use netstack3_base::{
18    AnyDevice, ContextPair, Counter, Device, DeviceIdContext, FrameDestination, Inspectable,
19    Inspector, InspectorDeviceExt, InspectorExt, NetworkSerializer, ReferenceNotifiers,
20    ReferenceNotifiersExt as _, RemoveResourceResultWithContext, ResourceCounterContext,
21    SendFrameContext, SendFrameErrorReason, StrongDeviceIdentifier, WeakDeviceIdentifier as _,
22};
23use netstack3_hashmap::{HashMap, HashSet};
24use packet::{BufferMut, ParsablePacket as _};
25use packet_formats::error::ParseError;
26use packet_formats::ethernet::{EtherType, EthernetFrameLengthCheck};
27
28use crate::internal::base::DeviceLayerTypes;
29use crate::internal::id::WeakDeviceId;
30
/// Selects frames by their link-layer protocol number.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum Protocol {
    /// Matches every frame, whatever its protocol number.
    All,
    /// Matches only frames carrying exactly this protocol number.
    Specific(NonZeroU16),
}
39
40/// Selector for devices to send and receive packets on.
41#[derive(Clone, Debug, Derivative, Eq, Hash, PartialEq)]
42#[derivative(Default(bound = ""))]
43pub enum TargetDevice<D> {
44    /// Act on any device in the system.
45    #[derivative(Default)]
46    AnyDevice,
47    /// Act on a specific device.
48    SpecificDevice(D),
49}
50
51/// Information about the bound state of a socket.
52#[derive(Debug)]
53#[cfg_attr(test, derive(PartialEq))]
54pub struct SocketInfo<D> {
55    /// The protocol the socket is bound to, or `None` if no protocol is set.
56    pub protocol: Option<Protocol>,
57    /// The device selector for which the socket is set.
58    pub device: TargetDevice<D>,
59}
60
/// Associated types, supplied by the bindings context, used by device sockets.
pub trait DeviceSocketTypes {
    /// Per-socket state that core stores on behalf of bindings and exposes
    /// back to them.
    type SocketState<D: Send + Sync + Debug>: Send + Sync + Debug;
}
67
/// Errors that bindings may report when a frame is delivered to a device
/// socket.
///
/// The common traits are derived so callers can log the error, compare
/// against it, and use `Result` helpers such as `expect`/`unwrap_err` that
/// require `Debug`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReceiveFrameError {
    /// The socket's receive queue is full and can't hold the frame.
    QueueFull,
}
73
74/// The execution context for device sockets provided by bindings.
75pub trait DeviceSocketBindingsContext<DeviceId: StrongDeviceIdentifier>:
76    DeviceSocketTypes + Sized
77{
78    /// Called for each received frame that matches the provided socket.
79    ///
80    /// `frame` and `raw_frame` are parsed and raw views into the same data.
81    fn receive_frame(
82        &self,
83        socket_id: &DeviceSocketId<DeviceId::Weak, Self>,
84        device: &DeviceId,
85        frame: Frame<&[u8]>,
86        raw_frame: &[u8],
87    ) -> Result<(), ReceiveFrameError>;
88}
89
90/// Strong owner of socket state.
91///
92/// This type strongly owns the socket state.
93#[derive(Debug)]
94pub struct PrimaryDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
95    PrimaryRc<SocketState<D, BT>>,
96);
97
98impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> PrimaryDeviceSocketId<D, BT> {
99    /// Creates a new socket ID with `external_state`.
100    fn new(external_state: BT::SocketState<D>) -> Self {
101        Self(PrimaryRc::new(SocketState {
102            external_state,
103            counters: Default::default(),
104            target: Default::default(),
105        }))
106    }
107
108    /// Clones the primary's underlying reference and returns as a strong id.
109    fn clone_strong(&self) -> DeviceSocketId<D, BT> {
110        let PrimaryDeviceSocketId(rc) = self;
111        DeviceSocketId(PrimaryRc::clone_strong(rc))
112    }
113}
114
115/// Reference to live socket state.
116///
117/// The existence of a `StrongId` attests to the liveness of the state of the
118/// backing socket.
119#[derive(Derivative)]
120#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
121pub struct DeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
122    StrongRc<SocketState<D, BT>>,
123);
124
125impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> DeviceSocketId<D, BT> {
126    /// Returns [`SocketCookie`] for this socket.
127    pub fn socket_cookie(&self) -> SocketCookie {
128        let Self(rc) = self;
129        SocketCookie::new(rc.resource_token())
130    }
131}
132
133impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for DeviceSocketId<D, BT> {
134    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
135        let Self(rc) = self;
136        f.debug_tuple("DeviceSocketId").field(&StrongRc::debug_id(rc)).finish()
137    }
138}
139
140impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<Target<D>>
141    for DeviceSocketId<D, BT>
142{
143    type Lock = Mutex<Target<D>>;
144    fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
145        let Self(rc) = self;
146        OrderedLockRef::new(&rc.target)
147    }
148}
149
150/// A weak reference to socket state.
151///
152/// The existence of a [`WeakSocketDeviceId`] does not attest to the liveness of
153/// the backing socket.
154#[derive(Derivative)]
155#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
156pub struct WeakDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
157    WeakRc<SocketState<D, BT>>,
158);
159
160impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for WeakDeviceSocketId<D, BT> {
161    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
162        let Self(rc) = self;
163        f.debug_tuple("WeakDeviceSocketId").field(&WeakRc::debug_id(rc)).finish()
164    }
165}
166
167/// Holds shared state for sockets.
168#[derive(Derivative)]
169#[derivative(Default(bound = ""))]
170pub struct Sockets<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
171    /// Holds strong (but not owning) references to sockets that aren't
172    /// targeting a particular device.
173    any_device_sockets: RwLock<AnyDeviceSockets<D, BT>>,
174
175    /// Table of all sockets in the system, regardless of target.
176    ///
177    /// Holds the primary (owning) reference for all sockets.
178    // This needs to be after `any_device_sockets` so that when an instance of
179    // this type is dropped, any strong IDs get dropped before their
180    // corresponding primary IDs.
181    all_sockets: RwLock<AllSockets<D, BT>>,
182}
183
184/// The set of sockets associated with a device.
185#[derive(Derivative)]
186#[derivative(Default(bound = ""))]
187pub struct AnyDeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
188    HashSet<DeviceSocketId<D, BT>>,
189);
190
191/// A collection of all device sockets in the system.
192#[derive(Derivative)]
193#[derivative(Default(bound = ""))]
194pub struct AllSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
195    HashMap<DeviceSocketId<D, BT>, PrimaryDeviceSocketId<D, BT>>,
196);
197
198/// State held by a device socket.
199#[derive(Debug)]
200pub struct SocketState<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
201    /// State provided by bindings that is held in core.
202    pub external_state: BT::SocketState<D>,
203    /// The socket's target device and protocol.
204    // TODO(https://fxbug.dev/42077026): Consider splitting up the state here to
205    // improve performance.
206    target: Mutex<Target<D>>,
207    /// Statistics about the socket's usage.
208    counters: DeviceSocketCounters,
209}
210
211/// A device socket's binding information.
212#[derive(Debug, Derivative)]
213#[derivative(Default(bound = ""))]
214pub struct Target<D> {
215    protocol: Option<Protocol>,
216    device: TargetDevice<D>,
217}
218
219/// Per-device state for packet sockets.
220///
221/// Holds sockets that are bound to a particular device. An instance of this
222/// should be held in the state for each device in the system.
223#[derive(Derivative)]
224#[derivative(Default(bound = ""))]
225#[cfg_attr(
226    test,
227    derivative(Debug, PartialEq(bound = "BT::SocketState<D>: Hash + Eq, D: Hash + Eq"))
228)]
229pub struct DeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
230    HashSet<DeviceSocketId<D, BT>>,
231);
232
233/// Convenience alias for use in device state storage.
234pub type HeldDeviceSockets<BT> = DeviceSockets<WeakDeviceId<BT>, BT>;
235
236/// Convenience alias for use in shared storage.
237///
238/// The type parameter is expected to implement [`DeviceSocketTypes`].
239pub type HeldSockets<BT> = Sockets<WeakDeviceId<BT>, BT>;
240
241/// Core context for accessing socket state.
242pub trait DeviceSocketContext<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
243    /// The core context available in callbacks to methods on this context.
244    type SocketTablesCoreCtx<'a>: DeviceSocketAccessor<BT, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>;
245
246    /// Executes the provided callback with access to the collection of all
247    /// sockets.
248    fn with_all_device_sockets<
249        F: FnOnce(&AllSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
250        R,
251    >(
252        &mut self,
253        cb: F,
254    ) -> R;
255
256    /// Executes the provided callback with mutable access to the collection of
257    /// all sockets.
258    fn with_all_device_sockets_mut<F: FnOnce(&mut AllSockets<Self::WeakDeviceId, BT>) -> R, R>(
259        &mut self,
260        cb: F,
261    ) -> R;
262
263    /// Executes the provided callback with immutable access to socket state.
264    fn with_any_device_sockets<
265        F: FnOnce(&AnyDeviceSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
266        R,
267    >(
268        &mut self,
269        cb: F,
270    ) -> R;
271
272    /// Executes the provided callback with mutable access to socket state.
273    fn with_any_device_sockets_mut<
274        F: FnOnce(
275            &mut AnyDeviceSockets<Self::WeakDeviceId, BT>,
276            &mut Self::SocketTablesCoreCtx<'_>,
277        ) -> R,
278        R,
279    >(
280        &mut self,
281        cb: F,
282    ) -> R;
283}
284
285/// Core context for accessing the state of an individual socket.
286pub trait SocketStateAccessor<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
287    /// Provides read-only access to the state of a socket.
288    fn with_socket_state<F: FnOnce(&Target<Self::WeakDeviceId>) -> R, R>(
289        &mut self,
290        socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
291        cb: F,
292    ) -> R;
293
294    /// Provides mutable access to the state of a socket.
295    fn with_socket_state_mut<F: FnOnce(&mut Target<Self::WeakDeviceId>) -> R, R>(
296        &mut self,
297        socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
298        cb: F,
299    ) -> R;
300}
301
302/// Core context for accessing the socket state for a device.
303pub trait DeviceSocketAccessor<BT: DeviceSocketTypes>: SocketStateAccessor<BT> {
304    /// Core context available in callbacks to methods on this context.
305    type DeviceSocketCoreCtx<'a>: SocketStateAccessor<BT, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>
306        + ResourceCounterContext<DeviceSocketId<Self::WeakDeviceId, BT>, DeviceSocketCounters>;
307
308    /// Executes the provided callback with immutable access to device-specific
309    /// socket state.
310    fn with_device_sockets<
311        F: FnOnce(&DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
312        R,
313    >(
314        &mut self,
315        device: &Self::DeviceId,
316        cb: F,
317    ) -> R;
318
319    /// Executes the provided callback with mutable access to device-specific
320    /// socket state.
321    fn with_device_sockets_mut<
322        F: FnOnce(&mut DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
323        R,
324    >(
325        &mut self,
326        device: &Self::DeviceId,
327        cb: F,
328    ) -> R;
329}
330
/// Whether a setting should be left as-is or replaced with a new value.
enum MaybeUpdate<T> {
    /// Keep the current value.
    NoChange,
    /// Replace the current value with this one.
    NewValue(T),
}
335
336fn update_device_and_protocol<CC: DeviceSocketContext<BT>, BT: DeviceSocketTypes>(
337    core_ctx: &mut CC,
338    socket: &DeviceSocketId<CC::WeakDeviceId, BT>,
339    new_device: TargetDevice<&CC::DeviceId>,
340    protocol_update: MaybeUpdate<Protocol>,
341) {
342    core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
343        // Even if we're never moving the socket from/to the any-device
344        // state, we acquire the lock to make the move between devices
345        // atomic from the perspective of frame delivery. Otherwise there
346        // would be a brief period during which arriving frames wouldn't be
347        // delivered to the socket from either device.
348        let old_device = core_ctx.with_socket_state_mut(socket, |Target { protocol, device }| {
349            match protocol_update {
350                MaybeUpdate::NewValue(p) => *protocol = Some(p),
351                MaybeUpdate::NoChange => (),
352            };
353            let old_device = match &device {
354                TargetDevice::SpecificDevice(device) => device.upgrade(),
355                TargetDevice::AnyDevice => {
356                    assert!(any_device_sockets.remove(socket));
357                    None
358                }
359            };
360            *device = match &new_device {
361                TargetDevice::AnyDevice => TargetDevice::AnyDevice,
362                TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d.downgrade()),
363            };
364            old_device
365        });
366
367        // This modification occurs without holding the socket's individual
368        // lock. That's safe because all modifications to the socket's
369        // device are done within a `with_sockets_mut` call, which
370        // synchronizes them.
371
372        if let Some(device) = old_device {
373            // Remove the reference to the socket from the old device if
374            // there is one, and it hasn't been removed.
375            core_ctx.with_device_sockets_mut(
376                &device,
377                |DeviceSockets(device_sockets), _core_ctx| {
378                    assert!(device_sockets.remove(socket), "socket not found in device state");
379                },
380            );
381        }
382
383        // Add the reference to the new device, if there is one.
384        match &new_device {
385            TargetDevice::SpecificDevice(new_device) => core_ctx.with_device_sockets_mut(
386                new_device,
387                |DeviceSockets(device_sockets), _core_ctx| {
388                    assert!(device_sockets.insert(socket.clone()));
389                },
390            ),
391            TargetDevice::AnyDevice => {
392                assert!(any_device_sockets.insert(socket.clone()))
393            }
394        }
395    })
396}
397
/// Entry point for the device socket API, wrapping a context `C`.
pub struct DeviceSocketApi<C>(C);

impl<C> DeviceSocketApi<C> {
    /// Wraps `ctx` in a new `DeviceSocketApi`.
    pub fn new(ctx: C) -> Self {
        DeviceSocketApi(ctx)
    }
}
407
408/// A local alias for [`DeviceSocketId`] for use in [`DeviceSocketApi`].
409///
410/// TODO(https://github.com/rust-lang/rust/issues/8995): Make this an inherent
411/// associated type.
412type ApiSocketId<C> = DeviceSocketId<
413    <<C as ContextPair>::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
414    <C as ContextPair>::BindingsContext,
415>;
416
417impl<C> DeviceSocketApi<C>
418where
419    C: ContextPair,
420    C::CoreContext: DeviceSocketContext<C::BindingsContext>
421        + SocketStateAccessor<C::BindingsContext>
422        + ResourceCounterContext<ApiSocketId<C>, DeviceSocketCounters>,
423    C::BindingsContext: DeviceSocketBindingsContext<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>
424        + ReferenceNotifiers
425        + 'static,
426{
427    fn core_ctx(&mut self) -> &mut C::CoreContext {
428        let Self(pair) = self;
429        pair.core_ctx()
430    }
431
432    fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
433        let Self(pair) = self;
434        pair.contexts()
435    }
436
437    /// Creates an packet socket with no protocol set configured for all devices.
438    pub fn create(
439        &mut self,
440        external_state: <C::BindingsContext as DeviceSocketTypes>::SocketState<
441            <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
442        >,
443    ) -> ApiSocketId<C> {
444        let core_ctx = self.core_ctx();
445
446        let strong = core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
447            let primary = PrimaryDeviceSocketId::new(external_state);
448            let strong = primary.clone_strong();
449            assert!(sockets.insert(strong.clone(), primary).is_none());
450            strong
451        });
452        core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), _core_ctx| {
453            // On creation, sockets do not target any device or protocol.
454            // Inserting them into the `any_device_sockets` table lets us treat
455            // newly-created sockets uniformly with sockets whose target device
456            // or protocol was set. The difference is unobservable at runtime
457            // since newly-created sockets won't match any frames being
458            // delivered.
459            assert!(any_device_sockets.insert(strong.clone()));
460        });
461        strong
462    }
463
464    /// Sets the device for which a packet socket will receive packets.
465    pub fn set_device(
466        &mut self,
467        socket: &ApiSocketId<C>,
468        device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
469    ) {
470        update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NoChange)
471    }
472
473    /// Sets the device and protocol for which a socket will receive packets.
474    pub fn set_device_and_protocol(
475        &mut self,
476        socket: &ApiSocketId<C>,
477        device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
478        protocol: Protocol,
479    ) {
480        update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NewValue(protocol))
481    }
482
483    /// Gets the bound info for a socket.
484    pub fn get_info(
485        &mut self,
486        id: &ApiSocketId<C>,
487    ) -> SocketInfo<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
488        self.core_ctx().with_socket_state(id, |Target { device, protocol }| SocketInfo {
489            device: device.clone(),
490            protocol: *protocol,
491        })
492    }
493
494    /// Removes a bound socket.
495    pub fn remove(
496        &mut self,
497        id: ApiSocketId<C>,
498    ) -> RemoveResourceResultWithContext<
499        <C::BindingsContext as DeviceSocketTypes>::SocketState<
500            <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
501        >,
502        C::BindingsContext,
503    > {
504        let core_ctx = self.core_ctx();
505        core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
506            let old_device = core_ctx.with_socket_state_mut(&id, |target| {
507                let Target { device, protocol: _ } = target;
508                match &device {
509                    TargetDevice::SpecificDevice(device) => device.upgrade(),
510                    TargetDevice::AnyDevice => {
511                        assert!(any_device_sockets.remove(&id));
512                        None
513                    }
514                }
515            });
516            if let Some(device) = old_device {
517                core_ctx.with_device_sockets_mut(
518                    &device,
519                    |DeviceSockets(device_sockets), _core_ctx| {
520                        assert!(device_sockets.remove(&id), "device doesn't have socket");
521                    },
522                )
523            }
524        });
525
526        core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
527            let primary = sockets
528                .remove(&id)
529                .unwrap_or_else(|| panic!("{id:?} not present in all socket map"));
530            // Make sure to drop the strong ID before trying to unwrap the primary
531            // ID.
532            drop(id);
533
534            let PrimaryDeviceSocketId(primary) = primary;
535            C::BindingsContext::unwrap_or_notify_with_new_reference_notifier(
536                primary,
537                |SocketState { external_state, counters: _, target: _ }| external_state,
538            )
539        })
540    }
541
542    /// Sends a frame for the specified socket.
543    pub fn send_frame<S, D>(
544        &mut self,
545        id: &ApiSocketId<C>,
546        metadata: DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
547        body: S,
548    ) -> Result<(), SendFrameErrorReason>
549    where
550        S: NetworkSerializer,
551        S::Buffer: BufferMut,
552        D: DeviceSocketSendTypes,
553        C::CoreContext: DeviceIdContext<D>
554            + SendFrameContext<
555                C::BindingsContext,
556                DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
557            >,
558        C::BindingsContext: DeviceLayerTypes,
559    {
560        let (core_ctx, bindings_ctx) = self.contexts();
561        let result = core_ctx.send_frame(bindings_ctx, metadata, body).map_err(|e| e.into_err());
562        match &result {
563            Ok(()) => {
564                core_ctx.increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_frames)
565            }
566            Err(SendFrameErrorReason::QueueFull) => core_ctx
567                .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_queue_full),
568            Err(SendFrameErrorReason::Alloc) => core_ctx
569                .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_alloc),
570            Err(SendFrameErrorReason::SizeConstraintsViolation) => core_ctx
571                .increment_both(id, |counters: &DeviceSocketCounters| {
572                    &counters.tx_err_size_constraint
573                }),
574        }
575        result
576    }
577
578    /// Provides inspect data for raw IP sockets.
579    pub fn inspect<N>(&mut self, inspector: &mut N)
580    where
581        N: Inspector
582            + InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId>,
583    {
584        self.core_ctx().with_all_device_sockets(|AllSockets(sockets), core_ctx| {
585            sockets.keys().for_each(|socket| {
586                inspector.record_debug_child(socket, |node| {
587                    core_ctx.with_socket_state(socket, |Target { protocol, device }| {
588                        node.record_debug("Protocol", protocol);
589                        match device {
590                            TargetDevice::AnyDevice => node.record_str("Device", "Any"),
591                            TargetDevice::SpecificDevice(d) => N::record_device(node, "Device", d),
592                        }
593                    });
594                    node.record_child("Counters", |node| {
595                        node.delegate_inspectable(socket.counters())
596                    })
597                })
598            })
599        })
600    }
601}
602
603/// A provider of the types required to send on a device socket.
604pub trait DeviceSocketSendTypes: Device {
605    /// The metadata required to send a frame on the device.
606    type Metadata;
607}
608
609/// Metadata required to send a frame on a device socket.
610#[derive(Debug, PartialEq)]
611pub struct DeviceSocketMetadata<D: DeviceSocketSendTypes, DeviceId> {
612    /// The device ID to send via.
613    pub device_id: DeviceId,
614    /// The metadata required to send that's specific to the device type.
615    pub metadata: D::Metadata,
616    // TODO(https://fxbug.dev/391946195): Include send buffer ownership metadata
617    // here.
618}
619
620/// Parameters needed to apply system-framing of an Ethernet frame.
621#[derive(Debug, PartialEq)]
622pub struct EthernetHeaderParams {
623    /// The destination MAC address to send to.
624    pub dest_addr: Mac,
625    /// The upperlayer protocol of the data contained in this Ethernet frame.
626    pub protocol: EtherType,
627}
628
629/// Public identifier for a socket.
630///
631/// Strongly owns the state of the socket. So long as the `SocketId` for a
632/// socket is not dropped, the socket is guaranteed to exist.
633pub type SocketId<BC> = DeviceSocketId<WeakDeviceId<BC>, BC>;
634
635impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> DeviceSocketId<D, BT> {
636    /// Provides immutable access to [`DeviceSocketTypes::SocketState`] for the
637    /// socket.
638    pub fn socket_state(&self) -> &BT::SocketState<D> {
639        let Self(strong) = self;
640        let SocketState { external_state, counters: _, target: _ } = &**strong;
641        external_state
642    }
643
644    /// Obtain a [`WeakDeviceSocketId`] from this [`DeviceSocketId`].
645    pub fn downgrade(&self) -> WeakDeviceSocketId<D, BT> {
646        let Self(inner) = self;
647        WeakDeviceSocketId(StrongRc::downgrade(inner))
648    }
649
650    /// Provides access to the socket's counters.
651    pub fn counters(&self) -> &DeviceSocketCounters {
652        let Self(strong) = self;
653        let SocketState { external_state: _, counters, target: _ } = &**strong;
654        counters
655    }
656}
657
658/// Allows the rest of the stack to dispatch packets to listening sockets.
659///
660/// This is implemented on top of [`DeviceSocketContext`] and abstracts packet
661/// socket delivery from the rest of the system.
662pub trait DeviceSocketHandler<D: Device, BC>: DeviceIdContext<D> {
663    /// Dispatch a received frame to sockets.
664    fn handle_frame(
665        &mut self,
666        bindings_ctx: &mut BC,
667        device: &Self::DeviceId,
668        frame: Frame<&[u8]>,
669        whole_frame: &[u8],
670    );
671}
672
673/// A frame received on a device.
674#[derive(Clone, Copy, Debug, Eq, PartialEq)]
675pub enum ReceivedFrame<B> {
676    /// An ethernet frame received on a device.
677    Ethernet {
678        /// Where the frame was destined.
679        destination: FrameDestination,
680        /// The parsed ethernet frame.
681        frame: EthernetFrame<B>,
682    },
683    /// An IP frame received on a device.
684    ///
685    /// Note that this is not an IP packet within an Ethernet Frame. This is an
686    /// IP packet received directly from the device (e.g. a pure IP device).
687    Ip(IpFrame<B>),
688}
689
690/// A frame sent on a device.
691#[derive(Clone, Copy, Debug, Eq, PartialEq)]
692pub enum SentFrame<B> {
693    /// An ethernet frame sent on a device.
694    Ethernet(EthernetFrame<B>),
695    /// An IP frame sent on a device.
696    ///
697    /// Note that this is not an IP packet within an Ethernet Frame. This is an
698    /// IP Packet send directly on the device (e.g. a pure IP device).
699    Ip(IpFrame<B>),
700}
701
/// Error returned when bytes can't be parsed as a [`SentFrame`].
#[derive(Debug)]
pub struct ParseSentFrameError;
705
706impl SentFrame<&[u8]> {
707    /// Tries to parse the given frame as an Ethernet frame.
708    pub fn try_parse_as_ethernet(mut buf: &[u8]) -> Result<SentFrame<&[u8]>, ParseSentFrameError> {
709        packet_formats::ethernet::EthernetFrame::parse(&mut buf, EthernetFrameLengthCheck::NoCheck)
710            .map_err(|_: ParseError| ParseSentFrameError)
711            .map(|frame| SentFrame::Ethernet(frame.into()))
712    }
713}
714
715/// Data from an Ethernet frame.
716#[derive(Clone, Copy, Debug, Eq, PartialEq)]
717pub struct EthernetFrame<B> {
718    /// The source address of the frame.
719    pub src_mac: Mac,
720    /// The destination address of the frame.
721    pub dst_mac: Mac,
722    /// The EtherType of the frame, or `None` if there was none.
723    pub ethertype: Option<EtherType>,
724    /// The offset of the body within the frame.
725    pub body_offset: usize,
726    /// The body of the frame.
727    pub body: B,
728}
729
730/// Data from an IP frame.
731#[derive(Clone, Copy, Debug, Eq, PartialEq)]
732pub struct IpFrame<B> {
733    /// The IP version of the frame.
734    pub ip_version: IpVersion,
735    /// The body of the frame.
736    pub body: B,
737}
738
739impl<B> IpFrame<B> {
740    fn ethertype(&self) -> EtherType {
741        let IpFrame { ip_version, body: _ } = self;
742        EtherType::from_ip_version(*ip_version)
743    }
744}
745
746/// A frame sent or received on a device
747#[derive(Clone, Copy, Debug, Eq, PartialEq)]
748pub enum Frame<B> {
749    /// A sent frame.
750    Sent(SentFrame<B>),
751    /// A received frame.
752    Received(ReceivedFrame<B>),
753}
754
755impl<B> From<SentFrame<B>> for Frame<B> {
756    fn from(value: SentFrame<B>) -> Self {
757        Self::Sent(value)
758    }
759}
760
761impl<B> From<ReceivedFrame<B>> for Frame<B> {
762    fn from(value: ReceivedFrame<B>) -> Self {
763        Self::Received(value)
764    }
765}
766
767impl<'a> From<packet_formats::ethernet::EthernetFrame<&'a [u8]>> for EthernetFrame<&'a [u8]> {
768    fn from(frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>) -> Self {
769        Self {
770            src_mac: frame.src_mac(),
771            dst_mac: frame.dst_mac(),
772            ethertype: frame.ethertype(),
773            body_offset: frame.parse_metadata().header_len(),
774            body: frame.into_body(),
775        }
776    }
777}
778
779impl<'a> ReceivedFrame<&'a [u8]> {
780    pub(crate) fn from_ethernet(
781        frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>,
782        destination: FrameDestination,
783    ) -> Self {
784        Self::Ethernet { destination, frame: frame.into() }
785    }
786}
787
788impl<B> Frame<B> {
789    /// Returns ether type for the packet if it's known.
790    pub fn protocol(&self) -> Option<u16> {
791        let ethertype = match self {
792            Self::Sent(SentFrame::Ethernet(frame))
793            | Self::Received(ReceivedFrame::Ethernet { destination: _, frame }) => frame.ethertype,
794            Self::Sent(SentFrame::Ip(frame)) | Self::Received(ReceivedFrame::Ip(frame)) => {
795                Some(frame.ethertype())
796            }
797        };
798        ethertype.map(Into::into)
799    }
800
801    /// Convenience method for consuming the `Frame` and producing the body.
802    pub fn into_body(self) -> B {
803        match self {
804            Self::Received(ReceivedFrame::Ethernet { destination: _, frame })
805            | Self::Sent(SentFrame::Ethernet(frame)) => frame.body,
806            Self::Received(ReceivedFrame::Ip(frame)) | Self::Sent(SentFrame::Ip(frame)) => {
807                frame.body
808            }
809        }
810    }
811
812    /// Returns the offset of the body within the frame.
813    pub fn body_offset(&self) -> usize {
814        match self {
815            Self::Received(ReceivedFrame::Ethernet { destination: _, frame })
816            | Self::Sent(SentFrame::Ethernet(frame)) => frame.body_offset,
817            Self::Received(ReceivedFrame::Ip(_)) | Self::Sent(SentFrame::Ip(_)) => 0,
818        }
819    }
820}
821
822impl<
823    D: Device,
824    BC: DeviceSocketBindingsContext<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
825    CC: DeviceSocketContext<BC> + DeviceIdContext<D>,
826> DeviceSocketHandler<D, BC> for CC
827where
828    <CC as DeviceIdContext<D>>::DeviceId: Into<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
829{
830    fn handle_frame(
831        &mut self,
832        bindings_ctx: &mut BC,
833        device: &Self::DeviceId,
834        frame: Frame<&[u8]>,
835        whole_frame: &[u8],
836    ) {
837        let device = device.clone().into();
838
839        // TODO(https://fxbug.dev/42076496): Invert the order of acquisition
840        // for the lock on the sockets held in the device and the any-device
841        // sockets lock.
842        self.with_any_device_sockets(|AnyDeviceSockets(any_device_sockets), core_ctx| {
843            // Iterate through the device's sockets while also holding the
844            // any-device sockets lock. This prevents double delivery to the
845            // same socket. If the two tables were locked independently,
846            // we could end up with a race, with the following thread
847            // interleaving (thread A is executing this code for device D,
848            // thread B is updating the device to D for the same socket X):
849            //   A) lock the any device sockets table
850            //   A) deliver to socket X in the table
851            //   A) unlock the any device sockets table
852            //   B) lock the any device sockets table, then D's sockets
853            //   B) remove X from the any table and add to D's
854            //   B) unlock D's sockets and any device sockets
855            //   A) lock D's sockets
856            //   A) deliver to socket X in D's table (!)
857            core_ctx.with_device_sockets(&device, |DeviceSockets(device_sockets), core_ctx| {
858                for socket in any_device_sockets.iter().chain(device_sockets) {
859                    let delivered =
860                        core_ctx.with_socket_state(socket, |Target { protocol, device: _ }| {
861                            let should_deliver = match protocol {
862                                None => false,
863                                Some(p) => match p {
864                                    // Sent frames are only delivered to sockets
865                                    // matching all protocols for Linux
866                                    // compatibility. See https://github.com/google/gvisor/blob/68eae979409452209e4faaeac12aee4191b3d6f0/test/syscalls/linux/packet_socket.cc#L331-L392.
867                                    Protocol::Specific(p) => match frame {
868                                        Frame::Received(_) => Some(p.get()) == frame.protocol(),
869                                        Frame::Sent(_) => false,
870                                    },
871                                    Protocol::All => true,
872                                },
873                            };
874                            should_deliver.then(|| {
875                                bindings_ctx.receive_frame(socket, &device, frame, whole_frame)
876                            })
877                        });
878                    match delivered {
879                        None => {}
880                        Some(result) => {
881                            core_ctx.increment_both(socket, |counters: &DeviceSocketCounters| {
882                                &counters.rx_frames
883                            });
884                            match result {
885                                Ok(()) => {}
886                                Err(ReceiveFrameError::QueueFull) => {
887                                    core_ctx.increment_both(
888                                        socket,
889                                        |counters: &DeviceSocketCounters| &counters.rx_queue_full,
890                                    );
891                                }
892                            }
893                        }
894                    }
895                }
896            })
897        })
898    }
899}
900
/// Usage statistics about Device Sockets.
///
/// Tracked stack-wide and per-socket: the receive path updates both the
/// stack-wide instance and the socket's own instance together.
#[derive(Debug, Default)]
pub struct DeviceSocketCounters {
    /// Count of incoming frames that were delivered to the socket.
    ///
    /// This counts every delivery attempt to a socket whose protocol selector
    /// matched. It includes attempts that failed because the socket's receive
    /// queue was full; those are additionally counted in `rx_queue_full`.
    ///
    /// Note that a single frame may be delivered to multiple device sockets.
    /// Thus this counter, when tracking the stack-wide aggregate, may exceed
    /// the total number of frames received by the stack.
    rx_frames: Counter,
    /// Count of incoming frames that could not be delivered to a socket because
    /// its receive buffer was full.
    rx_queue_full: Counter,
    /// Count of outgoing frames that were sent by the socket.
    tx_frames: Counter,
    /// Count of failed tx frames due to [`SendFrameErrorReason::QueueFull`].
    tx_err_queue_full: Counter,
    /// Count of failed tx frames due to [`SendFrameErrorReason::Alloc`].
    tx_err_alloc: Counter,
    /// Count of failed tx frames due to [`SendFrameErrorReason::SizeConstraintsViolation`].
    tx_err_size_constraint: Counter,
}
924
925impl Inspectable for DeviceSocketCounters {
926    fn record<I: Inspector>(&self, inspector: &mut I) {
927        let Self {
928            rx_frames,
929            rx_queue_full,
930            tx_frames,
931            tx_err_queue_full,
932            tx_err_alloc,
933            tx_err_size_constraint,
934        } = self;
935        inspector.record_child("Rx", |inspector| {
936            inspector.record_counter("DeliveredFrames", rx_frames);
937            inspector.record_counter("DroppedQueueFull", rx_queue_full);
938        });
939        inspector.record_child("Tx", |inspector| {
940            inspector.record_counter("SentFrames", tx_frames);
941            inspector.record_counter("QueueFullError", tx_err_queue_full);
942            inspector.record_counter("AllocError", tx_err_alloc);
943            inspector.record_counter("SizeConstraintError", tx_err_size_constraint);
944        });
945    }
946}
947
948impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AnyDeviceSockets<D, BT>>
949    for Sockets<D, BT>
950{
951    type Lock = RwLock<AnyDeviceSockets<D, BT>>;
952    fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
953        OrderedLockRef::new(&self.any_device_sockets)
954    }
955}
956
957impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AllSockets<D, BT>>
958    for Sockets<D, BT>
959{
960    type Lock = RwLock<AllSockets<D, BT>>;
961    fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
962        OrderedLockRef::new(&self.all_sockets)
963    }
964}
965
#[cfg(any(test, feature = "testutils"))]
mod testutil {
    use alloc::vec::Vec;
    use core::num::NonZeroU64;
    use netstack3_base::StrongDeviceIdentifier;
    use netstack3_base::testutil::{FakeBindingsCtx, MonotonicIdentifier};

    use super::*;
    use crate::internal::base::{
        DeviceClassMatcher, DeviceIdAndNameMatcher, DeviceLayerStateTypes,
    };

    /// Receive queue backing a fake device socket's bindings-side state.
    #[derive(Derivative, Debug)]
    #[derivative(Default(bound = ""))]
    pub struct RxQueue<D> {
        /// Frames accepted by the socket, oldest first.
        pub frames: Vec<ReceivedFrame<D>>,
        /// Capacity of `frames`. Once it is reached, further deliveries are
        /// rejected with [`ReceiveFrameError::QueueFull`]. Defaults to
        /// `usize::MAX`.
        #[derivative(Default(value = "usize::MAX"))]
        pub max_size: usize,
    }

    /// An owned record of one frame delivered to a fake device socket.
    #[derive(Clone, Debug, PartialEq)]
    pub struct ReceivedFrame<D> {
        /// Weak ID of the device the frame was delivered on.
        pub device: D,
        /// Owned copy of the parsed frame.
        pub frame: Frame<Vec<u8>>,
        /// Owned copy of the complete frame bytes.
        pub raw: Vec<u8>,
    }

    /// Fake bindings-side socket state: a lock-protected [`RxQueue`].
    #[derive(Debug, Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct ExternalSocketState<D>(pub Mutex<RxQueue<D>>);

    impl<TimerId, Event: Debug, State> DeviceSocketTypes
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        type SocketState<D: Send + Sync + Debug> = ExternalSocketState<D>;
    }

    impl Frame<&[u8]> {
        /// Produces an owned copy of this frame by copying its body into a `Vec`.
        pub(crate) fn cloned(self) -> Frame<Vec<u8>> {
            match self {
                Self::Sent(sent) => Frame::Sent(match sent {
                    SentFrame::Ethernet(eth) => SentFrame::Ethernet(eth.cloned()),
                    SentFrame::Ip(ip) => SentFrame::Ip(ip.cloned()),
                }),
                Self::Received(received) => Frame::Received(match received {
                    super::ReceivedFrame::Ethernet { destination, frame } => {
                        super::ReceivedFrame::Ethernet { destination, frame: frame.cloned() }
                    }
                    super::ReceivedFrame::Ip(ip) => super::ReceivedFrame::Ip(ip.cloned()),
                }),
            }
        }
    }

    impl EthernetFrame<&[u8]> {
        /// Copies the body into an owned buffer and keeps every header field.
        fn cloned(self) -> EthernetFrame<Vec<u8>> {
            let Self { src_mac, dst_mac, ethertype, body_offset, body } = self;
            let body = body.to_vec();
            EthernetFrame { src_mac, dst_mac, ethertype, body_offset, body }
        }
    }

    impl IpFrame<&[u8]> {
        /// Copies the body into an owned buffer and keeps the IP version.
        fn cloned(self) -> IpFrame<Vec<u8>> {
            let Self { ip_version, body } = self;
            let body = body.to_vec();
            IpFrame { ip_version, body }
        }
    }

    impl<TimerId, Event: Debug, State, D: StrongDeviceIdentifier> DeviceSocketBindingsContext<D>
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        /// Appends an owned copy of the frame to the socket's queue. Fails with
        /// [`ReceiveFrameError::QueueFull`] once the queue holds `max_size`
        /// frames.
        fn receive_frame(
            &self,
            state: &DeviceSocketId<D::Weak, Self>,
            device: &D,
            frame: Frame<&[u8]>,
            raw_frame: &[u8],
        ) -> Result<(), ReceiveFrameError> {
            let ExternalSocketState(queue) = state.socket_state();
            let mut guard = queue.lock();
            let RxQueue { frames, max_size } = &mut *guard;
            if frames.len() >= *max_size {
                return Err(ReceiveFrameError::QueueFull);
            }
            frames.push(ReceivedFrame {
                device: device.downgrade(),
                frame: frame.cloned(),
                raw: raw_frame.to_vec(),
            });
            Ok(())
        }
    }

    impl<TimerId, Event, State> DeviceLayerStateTypes for FakeBindingsCtx<TimerId, Event, State, ()>
    where
        TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
        Event: Debug + 'static,
        State: 'static,
    {
        type EthernetDeviceState = ();
        type LoopbackDeviceState = ();
        type PureIpDeviceState = ();
        type BlackholeDeviceState = ();
        type DeviceIdentifier = MonotonicIdentifier;
    }

    /// Placeholder matcher; panics if ever consulted.
    impl DeviceClassMatcher<()> for () {
        fn device_class_matches(&self, (): &()) -> bool {
            unimplemented!()
        }
    }

    /// Placeholder matcher; panics if ever consulted.
    impl DeviceIdAndNameMatcher for MonotonicIdentifier {
        fn id_matches(&self, _id: &NonZeroU64) -> bool {
            unimplemented!()
        }

        fn name_matches(&self, _name: &str) -> bool {
            unimplemented!()
        }
    }
}
1093
1094#[cfg(test)]
1095mod tests {
1096    use alloc::vec;
1097    use alloc::vec::Vec;
1098    use core::marker::PhantomData;
1099    use core::ops::Deref;
1100
1101    use crate::internal::socket::testutil::{ExternalSocketState, ReceivedFrame};
1102    use netstack3_base::testutil::{
1103        FakeReferencyDeviceId, FakeStrongDeviceId, FakeWeakDeviceId, MultipleDevicesId,
1104    };
1105    use netstack3_base::{
1106        CounterContext, CtxPair, NetworkSerializationContext, SendFrameError, SendableFrameMeta,
1107    };
1108    use netstack3_hashmap::HashMap;
1109    use packet::ParsablePacket;
1110    use test_case::test_case;
1111
1112    use super::*;
1113
1114    type FakeCoreCtx<D> = netstack3_base::testutil::FakeCoreCtx<FakeSockets<D>, (), D>;
1115    type FakeBindingsCtx = netstack3_base::testutil::FakeBindingsCtx<(), (), (), ()>;
1116    type FakeCtx<D> = CtxPair<FakeCoreCtx<D>, FakeBindingsCtx>;
1117
1118    /// A trait providing a shortcut to instantiate a [`DeviceSocketApi`] from a
1119    /// context.
1120    trait DeviceSocketApiExt: ContextPair + Sized {
1121        fn device_socket_api(&mut self) -> DeviceSocketApi<&mut Self> {
1122            DeviceSocketApi::new(self)
1123        }
1124    }
1125
1126    impl<O> DeviceSocketApiExt for O where O: ContextPair + Sized {}
1127
1128    #[derive(Derivative)]
1129    #[derivative(Default(bound = ""))]
1130    struct FakeSockets<D: FakeStrongDeviceId> {
1131        any_device_sockets: AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
1132        device_sockets: HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
1133        all_sockets: AllSockets<D::Weak, FakeBindingsCtx>,
1134        /// The stack-wide counters for device sockets.
1135        counters: DeviceSocketCounters,
1136        sent_frames: Vec<Vec<u8>>,
1137    }
1138
1139    /// Tuple of references
1140    pub struct FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>(
1141        &'m mut AnyDevice,
1142        &'m mut AllSockets,
1143        &'m mut Devices,
1144        PhantomData<Device>,
1145        &'m DeviceSocketCounters,
1146    );
1147
1148    /// Helper trait to allow treating a `&mut self` as a
1149    /// [`FakeSocketsMutRefs`].
1150    pub trait AsFakeSocketsMutRefs {
1151        type AnyDevice: 'static;
1152        type AllSockets: 'static;
1153        type Devices: 'static;
1154        type Device: 'static;
1155        fn as_sockets_ref(
1156            &mut self,
1157        ) -> FakeSocketsMutRefs<'_, Self::AnyDevice, Self::AllSockets, Self::Devices, Self::Device>;
1158    }
1159
1160    impl<D: FakeStrongDeviceId> AsFakeSocketsMutRefs for FakeCoreCtx<D> {
1161        type AnyDevice = AnyDeviceSockets<D::Weak, FakeBindingsCtx>;
1162        type AllSockets = AllSockets<D::Weak, FakeBindingsCtx>;
1163        type Devices = HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>;
1164        type Device = D;
1165
1166        fn as_sockets_ref(
1167            &mut self,
1168        ) -> FakeSocketsMutRefs<
1169            '_,
1170            AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
1171            AllSockets<D::Weak, FakeBindingsCtx>,
1172            HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
1173            D,
1174        > {
1175            let FakeSockets {
1176                any_device_sockets,
1177                device_sockets,
1178                all_sockets,
1179                counters,
1180                sent_frames: _,
1181            } = &mut self.state;
1182            FakeSocketsMutRefs(
1183                any_device_sockets,
1184                all_sockets,
1185                device_sockets,
1186                PhantomData,
1187                counters,
1188            )
1189        }
1190    }
1191
1192    impl<'m, AnyDevice: 'static, AllSockets: 'static, Devices: 'static, Device: 'static>
1193        AsFakeSocketsMutRefs for FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>
1194    {
1195        type AnyDevice = AnyDevice;
1196        type AllSockets = AllSockets;
1197        type Devices = Devices;
1198        type Device = Device;
1199
1200        fn as_sockets_ref(
1201            &mut self,
1202        ) -> FakeSocketsMutRefs<'_, AnyDevice, AllSockets, Devices, Device> {
1203            let Self(any_device, all_sockets, devices, PhantomData, counters) = self;
1204            FakeSocketsMutRefs(any_device, all_sockets, devices, PhantomData, counters)
1205        }
1206    }
1207
1208    impl<D: Clone> TargetDevice<&D> {
1209        fn with_weak_id(&self) -> TargetDevice<FakeWeakDeviceId<D>> {
1210            match self {
1211                TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1212                TargetDevice::SpecificDevice(d) => {
1213                    TargetDevice::SpecificDevice(FakeWeakDeviceId((*d).clone()))
1214                }
1215            }
1216        }
1217    }
1218
1219    impl<D: Eq + Hash + FakeStrongDeviceId> FakeSockets<D> {
1220        fn new(devices: impl IntoIterator<Item = D>) -> Self {
1221            let device_sockets =
1222                devices.into_iter().map(|d| (d, DeviceSockets::default())).collect();
1223            Self {
1224                any_device_sockets: AnyDeviceSockets::default(),
1225                device_sockets,
1226                all_sockets: Default::default(),
1227                counters: Default::default(),
1228                sent_frames: Default::default(),
1229            }
1230        }
1231    }
1232
1233    impl<
1234        'm,
1235        DeviceId: FakeStrongDeviceId,
1236        As: AsFakeSocketsMutRefs
1237            + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1238    > SocketStateAccessor<FakeBindingsCtx> for As
1239    {
1240        fn with_socket_state<F: FnOnce(&Target<Self::WeakDeviceId>) -> R, R>(
1241            &mut self,
1242            socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1243            cb: F,
1244        ) -> R {
1245            let DeviceSocketId(rc) = socket;
1246            // NB: Circumvent lock ordering for tests.
1247            let target = rc.target.lock();
1248            cb(&target)
1249        }
1250
1251        fn with_socket_state_mut<F: FnOnce(&mut Target<Self::WeakDeviceId>) -> R, R>(
1252            &mut self,
1253            socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1254            cb: F,
1255        ) -> R {
1256            let DeviceSocketId(rc) = socket;
1257            // NB: Circumvent lock ordering for tests.
1258            let mut target = rc.target.lock();
1259            cb(&mut target)
1260        }
1261    }
1262
1263    impl<
1264        'm,
1265        DeviceId: FakeStrongDeviceId,
1266        As: AsFakeSocketsMutRefs<
1267                Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1268            > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1269    > DeviceSocketAccessor<FakeBindingsCtx> for As
1270    {
1271        type DeviceSocketCoreCtx<'a> =
1272            FakeSocketsMutRefs<'a, As::AnyDevice, As::AllSockets, HashSet<DeviceId>, DeviceId>;
1273        fn with_device_sockets<
1274            F: FnOnce(
1275                &DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1276                &mut Self::DeviceSocketCoreCtx<'_>,
1277            ) -> R,
1278            R,
1279        >(
1280            &mut self,
1281            device: &Self::DeviceId,
1282            cb: F,
1283        ) -> R {
1284            let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
1285                self.as_sockets_ref();
1286            let mut devices = device_sockets.keys().cloned().collect();
1287            let device = device_sockets.get(device).unwrap();
1288            cb(
1289                device,
1290                &mut FakeSocketsMutRefs(
1291                    any_device,
1292                    all_sockets,
1293                    &mut devices,
1294                    PhantomData,
1295                    counters,
1296                ),
1297            )
1298        }
1299        fn with_device_sockets_mut<
1300            F: FnOnce(
1301                &mut DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1302                &mut Self::DeviceSocketCoreCtx<'_>,
1303            ) -> R,
1304            R,
1305        >(
1306            &mut self,
1307            device: &Self::DeviceId,
1308            cb: F,
1309        ) -> R {
1310            let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
1311                self.as_sockets_ref();
1312            let mut devices = device_sockets.keys().cloned().collect();
1313            let device = device_sockets.get_mut(device).unwrap();
1314            cb(
1315                device,
1316                &mut FakeSocketsMutRefs(
1317                    any_device,
1318                    all_sockets,
1319                    &mut devices,
1320                    PhantomData,
1321                    counters,
1322                ),
1323            )
1324        }
1325    }
1326
1327    impl<
1328        'm,
1329        DeviceId: FakeStrongDeviceId,
1330        As: AsFakeSocketsMutRefs<
1331                AnyDevice = AnyDeviceSockets<DeviceId::Weak, FakeBindingsCtx>,
1332                AllSockets = AllSockets<DeviceId::Weak, FakeBindingsCtx>,
1333                Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1334            > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1335    > DeviceSocketContext<FakeBindingsCtx> for As
1336    {
1337        type SocketTablesCoreCtx<'a> = FakeSocketsMutRefs<
1338            'a,
1339            (),
1340            (),
1341            HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1342            DeviceId,
1343        >;
1344
1345        fn with_any_device_sockets<
1346            F: FnOnce(
1347                &AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1348                &mut Self::SocketTablesCoreCtx<'_>,
1349            ) -> R,
1350            R,
1351        >(
1352            &mut self,
1353            cb: F,
1354        ) -> R {
1355            let FakeSocketsMutRefs(
1356                any_device_sockets,
1357                _all_sockets,
1358                device_sockets,
1359                PhantomData,
1360                counters,
1361            ) = self.as_sockets_ref();
1362            cb(
1363                any_device_sockets,
1364                &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1365            )
1366        }
1367        fn with_any_device_sockets_mut<
1368            F: FnOnce(
1369                &mut AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1370                &mut Self::SocketTablesCoreCtx<'_>,
1371            ) -> R,
1372            R,
1373        >(
1374            &mut self,
1375            cb: F,
1376        ) -> R {
1377            let FakeSocketsMutRefs(
1378                any_device_sockets,
1379                _all_sockets,
1380                device_sockets,
1381                PhantomData,
1382                counters,
1383            ) = self.as_sockets_ref();
1384            cb(
1385                any_device_sockets,
1386                &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1387            )
1388        }
1389
1390        fn with_all_device_sockets<
1391            F: FnOnce(
1392                &AllSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1393                &mut Self::SocketTablesCoreCtx<'_>,
1394            ) -> R,
1395            R,
1396        >(
1397            &mut self,
1398            cb: F,
1399        ) -> R {
1400            let FakeSocketsMutRefs(
1401                _any_device_sockets,
1402                all_sockets,
1403                device_sockets,
1404                PhantomData,
1405                counters,
1406            ) = self.as_sockets_ref();
1407            cb(
1408                all_sockets,
1409                &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1410            )
1411        }
1412
1413        fn with_all_device_sockets_mut<
1414            F: FnOnce(&mut AllSockets<Self::WeakDeviceId, FakeBindingsCtx>) -> R,
1415            R,
1416        >(
1417            &mut self,
1418            cb: F,
1419        ) -> R {
1420            let FakeSocketsMutRefs(_, all_sockets, _, _, _) = self.as_sockets_ref();
1421            cb(all_sockets)
1422        }
1423    }
1424
1425    impl<'m, X, Y, Z, D: FakeStrongDeviceId> DeviceIdContext<AnyDevice>
1426        for FakeSocketsMutRefs<'m, X, Y, Z, D>
1427    {
1428        type DeviceId = D;
1429        type WeakDeviceId = FakeWeakDeviceId<D>;
1430    }
1431
1432    impl<D: FakeStrongDeviceId> CounterContext<DeviceSocketCounters> for FakeCoreCtx<D> {
1433        fn counters(&self) -> &DeviceSocketCounters {
1434            &self.state.counters
1435        }
1436    }
1437
1438    impl<D: FakeStrongDeviceId>
1439        ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
1440        for FakeCoreCtx<D>
1441    {
1442        fn per_resource_counters<'a>(
1443            &'a self,
1444            socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
1445        ) -> &'a DeviceSocketCounters {
1446            socket.counters()
1447        }
1448    }
1449
1450    impl<'m, X, Y, Z, D> CounterContext<DeviceSocketCounters> for FakeSocketsMutRefs<'m, X, Y, Z, D> {
1451        fn counters(&self) -> &DeviceSocketCounters {
1452            let FakeSocketsMutRefs(_, _, _, _, counters) = self;
1453            counters
1454        }
1455    }
1456
1457    impl<'m, X, Y, Z, D: FakeStrongDeviceId>
1458        ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
1459        for FakeSocketsMutRefs<'m, X, Y, Z, D>
1460    {
1461        fn per_resource_counters<'a>(
1462            &'a self,
1463            socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
1464        ) -> &'a DeviceSocketCounters {
1465            socket.counters()
1466        }
1467    }
1468
1469    const SOME_PROTOCOL: NonZeroU16 = NonZeroU16::new(2000).unwrap();
1470
1471    #[test]
1472    fn create_remove() {
1473        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1474            MultipleDevicesId::all(),
1475        )));
1476        let mut api = ctx.device_socket_api();
1477
1478        let bound = api.create(Default::default());
1479        assert_eq!(
1480            api.get_info(&bound),
1481            SocketInfo { device: TargetDevice::AnyDevice, protocol: None }
1482        );
1483
1484        let ExternalSocketState(_received_frames) = api.remove(bound).into_removed();
1485    }
1486
1487    #[test_case(TargetDevice::AnyDevice)]
1488    #[test_case(TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1489    fn test_set_device(device: TargetDevice<&MultipleDevicesId>) {
1490        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1491            MultipleDevicesId::all(),
1492        )));
1493        let mut api = ctx.device_socket_api();
1494
1495        let bound = api.create(Default::default());
1496        api.set_device(&bound, device.clone());
1497        assert_eq!(
1498            api.get_info(&bound),
1499            SocketInfo { device: device.with_weak_id(), protocol: None }
1500        );
1501
1502        let device_sockets = &api.core_ctx().state.device_sockets;
1503        if let TargetDevice::SpecificDevice(d) = device {
1504            let DeviceSockets(socket_ids) = device_sockets.get(&d).expect("device state exists");
1505            assert_eq!(socket_ids, &HashSet::from([bound]));
1506        }
1507    }
1508
1509    #[test]
1510    fn update_device() {
1511        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1512            MultipleDevicesId::all(),
1513        )));
1514        let mut api = ctx.device_socket_api();
1515        let bound = api.create(Default::default());
1516
1517        api.set_device(&bound, TargetDevice::SpecificDevice(&MultipleDevicesId::A));
1518
1519        // Now update the device and make sure the socket only appears in the
1520        // one device's list.
1521        api.set_device(&bound, TargetDevice::SpecificDevice(&MultipleDevicesId::B));
1522        assert_eq!(
1523            api.get_info(&bound),
1524            SocketInfo {
1525                device: TargetDevice::SpecificDevice(FakeWeakDeviceId(MultipleDevicesId::B)),
1526                protocol: None
1527            }
1528        );
1529
1530        let device_sockets = &api.core_ctx().state.device_sockets;
1531        let device_socket_lists = device_sockets
1532            .iter()
1533            .map(|(d, DeviceSockets(indexes))| (d, indexes.iter().collect()))
1534            .collect::<HashMap<_, _>>();
1535
1536        assert_eq!(
1537            device_socket_lists,
1538            HashMap::from([
1539                (&MultipleDevicesId::A, vec![]),
1540                (&MultipleDevicesId::B, vec![&bound]),
1541                (&MultipleDevicesId::C, vec![])
1542            ])
1543        );
1544    }
1545
1546    #[test_case(Protocol::All, TargetDevice::AnyDevice)]
1547    #[test_case(Protocol::Specific(SOME_PROTOCOL), TargetDevice::AnyDevice)]
1548    #[test_case(Protocol::All, TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1549    #[test_case(
1550        Protocol::Specific(SOME_PROTOCOL),
1551        TargetDevice::SpecificDevice(&MultipleDevicesId::A)
1552    )]
1553    fn create_set_device_and_protocol_remove_multiple(
1554        protocol: Protocol,
1555        device: TargetDevice<&MultipleDevicesId>,
1556    ) {
1557        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1558            MultipleDevicesId::all(),
1559        )));
1560        let mut api = ctx.device_socket_api();
1561
1562        let mut sockets = [(); 3].map(|()| api.create(Default::default()));
1563        for socket in &mut sockets {
1564            api.set_device_and_protocol(socket, device.clone(), protocol);
1565            assert_eq!(
1566                api.get_info(socket),
1567                SocketInfo { device: device.with_weak_id(), protocol: Some(protocol) }
1568            );
1569        }
1570
1571        for socket in sockets {
1572            let ExternalSocketState(_received_frames) = api.remove(socket).into_removed();
1573        }
1574    }
1575
1576    #[test]
1577    fn change_device_after_removal() {
1578        let device_to_remove = FakeReferencyDeviceId::default();
1579        let device_to_maintain = FakeReferencyDeviceId::default();
1580        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new([
1581            device_to_remove.clone(),
1582            device_to_maintain.clone(),
1583        ])));
1584        let mut api = ctx.device_socket_api();
1585
1586        let bound = api.create(Default::default());
1587        // Set the device for the socket before removing the device state
1588        // entirely.
1589        api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_remove));
1590
1591        // Now remove the device; this should cause future attempts to upgrade
1592        // the device ID to fail.
1593        device_to_remove.mark_removed();
1594
1595        // Changing the device should gracefully handle the fact that the
1596        // earlier-bound device is now gone.
1597        api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_maintain));
1598        assert_eq!(
1599            api.get_info(&bound),
1600            SocketInfo {
1601                device: TargetDevice::SpecificDevice(FakeWeakDeviceId(device_to_maintain.clone())),
1602                protocol: None,
1603            }
1604        );
1605
1606        let device_sockets = &api.core_ctx().state.device_sockets;
1607        let DeviceSockets(weak_sockets) =
1608            device_sockets.get(&device_to_maintain).expect("device state exists");
1609        assert_eq!(weak_sockets, &HashSet::from([bound]));
1610    }
1611
1612    struct TestData;
1613    impl TestData {
1614        const SRC_MAC: Mac = Mac::new([0, 1, 2, 3, 4, 5]);
1615        const DST_MAC: Mac = Mac::new([6, 7, 8, 9, 10, 11]);
1616        /// Arbitrary protocol number.
1617        const PROTO: NonZeroU16 = NonZeroU16::new(0x08AB).unwrap();
1618        const BODY: &'static [u8] = b"some pig";
1619        const BUFFER: &'static [u8] = &[
1620            6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 0x08, 0xAB, b's', b'o', b'm', b'e', b' ', b'p',
1621            b'i', b'g',
1622        ];
1623        const BUFFER_OFFSET: usize = Self::BUFFER.len() - Self::BODY.len();
1624
1625        /// Creates an EthernetFrame with the values specified above.
1626        fn frame() -> packet_formats::ethernet::EthernetFrame<&'static [u8]> {
1627            let mut buffer_view = Self::BUFFER;
1628            packet_formats::ethernet::EthernetFrame::parse(
1629                &mut buffer_view,
1630                EthernetFrameLengthCheck::NoCheck,
1631            )
1632            .unwrap()
1633        }
1634    }
1635
1636    const WRONG_PROTO: NonZeroU16 = NonZeroU16::new(0x08ff).unwrap();
1637
1638    fn make_bound<D: FakeStrongDeviceId>(
1639        ctx: &mut FakeCtx<D>,
1640        device: TargetDevice<D>,
1641        protocol: Option<Protocol>,
1642        state: ExternalSocketState<D::Weak>,
1643    ) -> DeviceSocketId<D::Weak, FakeBindingsCtx> {
1644        let mut api = ctx.device_socket_api();
1645        let id = api.create(state);
1646        let device = match &device {
1647            TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1648            TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d),
1649        };
1650        match protocol {
1651            Some(protocol) => api.set_device_and_protocol(&id, device, protocol),
1652            None => api.set_device(&id, device),
1653        };
1654        id
1655    }
1656
    /// Deliver one frame to the provided contexts and return the IDs of the
    /// sockets it was delivered to.
    ///
    /// The frame is handled as if it arrived on `MultipleDevicesId::A`, with
    /// `TestData::BUFFER` as its raw bytes. Every socket whose receive queue
    /// is non-empty afterwards is asserted to hold exactly this one frame, so
    /// callers should make sure no socket has frames queued from an earlier
    /// delivery.
    fn deliver_one_frame(
        delivered_frame: Frame<&[u8]>,
        FakeCtx { core_ctx, bindings_ctx }: &mut FakeCtx<MultipleDevicesId>,
    ) -> HashSet<DeviceSocketId<FakeWeakDeviceId<MultipleDevicesId>, FakeBindingsCtx>> {
        DeviceSocketHandler::handle_frame(
            core_ctx,
            bindings_ctx,
            &MultipleDevicesId::A,
            delivered_frame.clone(),
            TestData::BUFFER,
        );

        // Only the table of all sockets is needed. The destructure lists every
        // field without `..`, so a new `FakeSockets` field must be
        // acknowledged here.
        let FakeSockets {
            all_sockets: AllSockets(all_sockets),
            any_device_sockets: _,
            device_sockets: _,
            counters: _,
            sent_frames: _,
        } = &core_ctx.state;

        // Inspect each socket's fake receive queue directly; a non-empty queue
        // means the socket received the frame.
        all_sockets
            .iter()
            .filter_map(|(id, _primary)| {
                let DeviceSocketId(rc) = &id;
                let ExternalSocketState(frames) = &rc.external_state;
                let lock_guard = frames.lock();
                let testutil::RxQueue { frames, .. } = lock_guard.deref();
                (!frames.is_empty()).then(|| {
                    // The queue must contain exactly the delivered frame,
                    // attributed to device A with the original raw bytes.
                    assert_eq!(
                        &*frames,
                        &[ReceivedFrame {
                            device: FakeWeakDeviceId(MultipleDevicesId::A),
                            frame: delivered_frame.cloned(),
                            raw: TestData::BUFFER.into(),
                        }]
                    );
                    id.clone()
                })
            })
            .collect()
    }
1700
1701    #[test]
1702    fn receive_frame_deliver_to_multiple() {
1703        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1704            MultipleDevicesId::all(),
1705        )));
1706
1707        use Protocol::*;
1708        use TargetDevice::*;
1709        let never_bound = {
1710            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1711            ctx.device_socket_api().create(state)
1712        };
1713
1714        let mut make_bound = |device, protocol| {
1715            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1716            make_bound(&mut ctx, device, protocol, state)
1717        };
1718        let bound_a_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::A), None);
1719        let bound_a_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::A), Some(All));
1720        let bound_a_right_protocol =
1721            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
1722        let bound_a_wrong_protocol =
1723            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
1724        let bound_b_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::B), None);
1725        let bound_b_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::B), Some(All));
1726        let bound_b_right_protocol =
1727            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
1728        let bound_b_wrong_protocol =
1729            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
1730        let bound_any_no_protocol = make_bound(AnyDevice, None);
1731        let bound_any_all_protocols = make_bound(AnyDevice, Some(All));
1732        let bound_any_right_protocol = make_bound(AnyDevice, Some(Specific(TestData::PROTO)));
1733        let bound_any_wrong_protocol = make_bound(AnyDevice, Some(Specific(WRONG_PROTO)));
1734
1735        let mut sockets_with_received_frames = deliver_one_frame(
1736            super::ReceivedFrame::from_ethernet(
1737                TestData::frame(),
1738                FrameDestination::Individual { local: true },
1739            )
1740            .into(),
1741            &mut ctx,
1742        );
1743
1744        let sockets_not_expecting_frames = [
1745            never_bound,
1746            bound_a_no_protocol,
1747            bound_a_wrong_protocol,
1748            bound_b_no_protocol,
1749            bound_b_all_protocols,
1750            bound_b_right_protocol,
1751            bound_b_wrong_protocol,
1752            bound_any_no_protocol,
1753            bound_any_wrong_protocol,
1754        ];
1755        let sockets_expecting_frames = [
1756            bound_a_all_protocols,
1757            bound_a_right_protocol,
1758            bound_any_all_protocols,
1759            bound_any_right_protocol,
1760        ];
1761
1762        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1763            assert!(
1764                sockets_with_received_frames.remove(&socket),
1765                "socket {n} didn't receive the frame"
1766            );
1767        }
1768        assert!(sockets_with_received_frames.is_empty());
1769
1770        // Verify Counters were set appropriately for each socket.
1771        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1772            assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
1773        }
1774        for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
1775            assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
1776        }
1777    }
1778
1779    #[test]
1780    fn sent_frame_deliver_to_multiple() {
1781        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1782            MultipleDevicesId::all(),
1783        )));
1784
1785        use Protocol::*;
1786        use TargetDevice::*;
1787        let never_bound = {
1788            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1789            ctx.device_socket_api().create(state)
1790        };
1791
1792        let mut make_bound = |device, protocol| {
1793            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1794            make_bound(&mut ctx, device, protocol, state)
1795        };
1796        let bound_a_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::A), None);
1797        let bound_a_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::A), Some(All));
1798        let bound_a_same_protocol =
1799            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
1800        let bound_a_wrong_protocol =
1801            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
1802        let bound_b_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::B), None);
1803        let bound_b_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::B), Some(All));
1804        let bound_b_same_protocol =
1805            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
1806        let bound_b_wrong_protocol =
1807            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
1808        let bound_any_no_protocol = make_bound(AnyDevice, None);
1809        let bound_any_all_protocols = make_bound(AnyDevice, Some(All));
1810        let bound_any_same_protocol = make_bound(AnyDevice, Some(Specific(TestData::PROTO)));
1811        let bound_any_wrong_protocol = make_bound(AnyDevice, Some(Specific(WRONG_PROTO)));
1812
1813        let mut sockets_with_received_frames =
1814            deliver_one_frame(SentFrame::Ethernet(TestData::frame().into()).into(), &mut ctx);
1815
1816        let sockets_not_expecting_frames = [
1817            never_bound,
1818            bound_a_no_protocol,
1819            bound_a_same_protocol,
1820            bound_a_wrong_protocol,
1821            bound_b_no_protocol,
1822            bound_b_all_protocols,
1823            bound_b_same_protocol,
1824            bound_b_wrong_protocol,
1825            bound_any_no_protocol,
1826            bound_any_same_protocol,
1827            bound_any_wrong_protocol,
1828        ];
1829        // Only any-protocol sockets receive sent frames.
1830        let sockets_expecting_frames = [bound_a_all_protocols, bound_any_all_protocols];
1831
1832        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1833            assert!(
1834                sockets_with_received_frames.remove(&socket),
1835                "socket {n} didn't receive the frame"
1836            );
1837        }
1838        assert!(sockets_with_received_frames.is_empty());
1839
1840        // Verify Counters were set appropriately for each socket.
1841        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1842            assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
1843        }
1844        for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
1845            assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
1846        }
1847    }
1848
    /// Delivers the same received frame `RECEIVE_COUNT` times to a single
    /// socket bound to any device and all protocols, then checks that every
    /// copy was queued and counted.
    #[test]
    fn deliver_multiple_frames() {
        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
            MultipleDevicesId::all(),
        )));
        let socket = make_bound(
            &mut ctx,
            TargetDevice::AnyDevice,
            Some(Protocol::All),
            ExternalSocketState::default(),
        );
        // Split the context so the core context can be consumed below.
        let FakeCtx { mut core_ctx, mut bindings_ctx } = ctx;

        const RECEIVE_COUNT: usize = 10;
        for _ in 0..RECEIVE_COUNT {
            DeviceSocketHandler::handle_frame(
                &mut core_ctx,
                &mut bindings_ctx,
                &MultipleDevicesId::A,
                super::ReceivedFrame::from_ethernet(
                    TestData::frame(),
                    FrameDestination::Individual { local: true },
                )
                .into(),
                TestData::BUFFER,
            );
        }

        // Consume the core context to take ownership of the socket table.
        let FakeSockets {
            all_sockets: AllSockets(mut all_sockets),
            any_device_sockets: _,
            device_sockets: _,
            counters: _,
            sent_frames: _,
        } = core_ctx.into_state();
        let primary = all_sockets.remove(&socket).unwrap();
        let PrimaryDeviceSocketId(primary) = primary;
        // This test created exactly one socket.
        assert!(all_sockets.is_empty());
        // Release the test's own reference before unwrapping the primary one.
        // NOTE(review): presumably `PrimaryRc::unwrap` requires that no other
        // strong references remain — confirm.
        drop(socket);
        let SocketState { external_state: ExternalSocketState(received), counters, target: _ } =
            PrimaryRc::unwrap(primary);
        // Every delivery should be queued as an identical copy of the frame.
        assert_eq!(
            received.into_inner().frames,
            vec![
                ReceivedFrame {
                    device: FakeWeakDeviceId(MultipleDevicesId::A),
                    frame: Frame::Received(super::ReceivedFrame::Ethernet {
                        destination: FrameDestination::Individual { local: true },
                        frame: EthernetFrame {
                            src_mac: TestData::SRC_MAC,
                            dst_mac: TestData::DST_MAC,
                            ethertype: Some(TestData::PROTO.get().into()),
                            body_offset: TestData::BUFFER_OFFSET,
                            body: Vec::from(TestData::BODY),
                        }
                    }),
                    raw: TestData::BUFFER.into()
                };
                RECEIVE_COUNT
            ]
        );
        // The per-socket rx counter is bumped once per delivery.
        assert_eq!(counters.rx_frames.get(), u64::try_from(RECEIVE_COUNT).unwrap());
    }
1912
1913    #[test]
1914    fn deliver_frame_queue_full() {
1915        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1916            MultipleDevicesId::all(),
1917        )));
1918
1919        // Simulate a full RX queue for sock1.
1920        let sock1 = make_bound(
1921            &mut ctx,
1922            TargetDevice::AnyDevice,
1923            Some(Protocol::All),
1924            ExternalSocketState(Mutex::new(testutil::RxQueue { frames: vec![], max_size: 0 })),
1925        );
1926        let sock2 = make_bound(
1927            &mut ctx,
1928            TargetDevice::AnyDevice,
1929            Some(Protocol::All),
1930            ExternalSocketState::default(),
1931        );
1932
1933        let FakeCtx { mut core_ctx, mut bindings_ctx } = ctx;
1934
1935        DeviceSocketHandler::handle_frame(
1936            &mut core_ctx,
1937            &mut bindings_ctx,
1938            &MultipleDevicesId::A,
1939            super::ReceivedFrame::from_ethernet(
1940                TestData::frame(),
1941                FrameDestination::Individual { local: true },
1942            )
1943            .into(),
1944            TestData::BUFFER,
1945        );
1946
1947        assert_eq!(core_ctx.state.counters.rx_frames.get(), 2);
1948        assert_eq!(core_ctx.state.counters.rx_queue_full.get(), 1);
1949        assert_eq!(sock1.counters().rx_frames.get(), 1);
1950        assert_eq!(sock1.counters().rx_queue_full.get(), 1);
1951        assert_eq!(sock2.counters().rx_frames.get(), 1);
1952        assert_eq!(sock2.counters().rx_queue_full.get(), 0);
1953
1954        // Drop our strong references to the sockets so that `core_ctx` can tear
1955        // down successfully.
1956        drop(sock1);
1957        drop(sock2);
1958    }
1959
1960    pub struct FakeSendMetadata;
1961    impl DeviceSocketSendTypes for AnyDevice {
1962        type Metadata = FakeSendMetadata;
1963    }
1964    impl<BC, D: FakeStrongDeviceId> SendableFrameMeta<FakeCoreCtx<D>, BC>
1965        for DeviceSocketMetadata<AnyDevice, D>
1966    {
1967        fn send_meta<S>(
1968            self,
1969            core_ctx: &mut FakeCoreCtx<D>,
1970            _bindings_ctx: &mut BC,
1971            frame: S,
1972        ) -> Result<(), SendFrameError<S>>
1973        where
1974            S: NetworkSerializer,
1975            S::Buffer: BufferMut,
1976        {
1977            let frame = match frame.serialize_vec_outer(&mut NetworkSerializationContext::default())
1978            {
1979                Err(e) => {
1980                    let _: (packet::SerializeError<core::convert::Infallible>, _) = e;
1981                    unreachable!()
1982                }
1983                Ok(frame) => frame.unwrap_a().as_ref().to_vec(),
1984            };
1985            core_ctx.state.sent_frames.push(frame);
1986            Ok(())
1987        }
1988    }
1989
1990    #[test]
1991    fn send_multiple_frames() {
1992        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1993            MultipleDevicesId::all(),
1994        )));
1995
1996        const DEVICE: MultipleDevicesId = MultipleDevicesId::A;
1997        let socket = make_bound(
1998            &mut ctx,
1999            TargetDevice::SpecificDevice(DEVICE),
2000            Some(Protocol::All),
2001            ExternalSocketState::default(),
2002        );
2003        let mut api = ctx.device_socket_api();
2004
2005        const SEND_COUNT: usize = 10;
2006        const PAYLOAD: &'static [u8] = &[1, 2, 3, 4, 5];
2007        for _ in 0..SEND_COUNT {
2008            let buf = packet::Buf::new(PAYLOAD.to_vec(), ..);
2009            api.send_frame(
2010                &socket,
2011                DeviceSocketMetadata { device_id: DEVICE, metadata: FakeSendMetadata },
2012                buf,
2013            )
2014            .expect("send failed");
2015        }
2016
2017        assert_eq!(ctx.core_ctx().state.sent_frames, vec![PAYLOAD.to_vec(); SEND_COUNT]);
2018
2019        assert_eq!(socket.counters().tx_frames.get(), u64::try_from(SEND_COUNT).unwrap());
2020    }
2021}