use alloc::collections::HashSet;
use core::{fmt::Debug, hash::Hash, num::NonZeroU16};
use dense_map::{DenseMap, EntryKey};
use derivative::Derivative;
use lock_order::{
lock::{LockFor, RwLockFor},
relation::LockBefore,
wrap::prelude::*,
};
use net_types::{ethernet::Mac, ip::IpVersion};
use packet::{BufferMut, ParsablePacket as _, Serializer};
use packet_formats::{
error::ParseError,
ethernet::{EtherType, EthernetFrameLengthCheck},
};
use crate::{
context::{ContextPair, SendFrameContext},
device::{
self, AnyDevice, Device, DeviceId, DeviceIdContext, DeviceLayerTypes, FrameDestination,
StrongId as _, WeakDeviceId, WeakId as _,
},
for_any_device_id,
sync::{Mutex, PrimaryRc, RwLock, StrongRc},
CoreCtx, StackState,
};
/// The protocol filter configured on a device socket.
///
/// Frame delivery consults this value: a socket with [`Protocol::All`]
/// receives every frame, while a socket with [`Protocol::Specific`] only
/// receives *received* frames whose protocol number matches.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum Protocol {
/// Match frames regardless of their protocol.
All,
/// Match only frames carrying exactly this (non-zero) protocol number.
Specific(NonZeroU16),
}
/// The device selection for a socket, generic over the device identifier `D`.
///
/// The default value is [`TargetDevice::AnyDevice`].
#[derive(Clone, Debug, Derivative, Eq, Hash, PartialEq)]
#[derivative(Default(bound = ""))]
pub enum TargetDevice<D> {
/// The socket is not tied to any particular device.
#[derivative(Default)]
AnyDevice,
/// The socket is tied to the contained device.
SpecificDevice(D),
}
/// A snapshot of a socket's configuration, as returned by
/// `DeviceSocketApi::get_info`.
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct SocketInfo<D> {
/// The protocol filter, or `None` if no protocol has been set.
pub protocol: Option<Protocol>,
/// The device the socket targets.
pub device: TargetDevice<D>,
}
/// Types supplied by bindings for use with device sockets.
pub trait DeviceSocketTypes {
/// Opaque per-socket state provided when a socket is created and handed
/// back to bindings whenever a frame is delivered to that socket.
type SocketState: Send + Sync + Debug;
}
/// The bindings-side interface used to deliver frames to device sockets.
pub trait DeviceSocketBindingsContext<DeviceId>: DeviceSocketTypes {
/// Delivers a frame to a socket.
///
/// `socket` is the external state of the receiving socket, `device` is the
/// device the frame is associated with, `frame` is the parsed view of the
/// frame and `raw_frame` is the unparsed frame bytes.
fn receive_frame(
&self,
socket: &Self::SocketState,
device: &DeviceId,
frame: Frame<&[u8]>,
raw_frame: &[u8],
);
}
/// The primary reference to a socket's state; wraps a [`PrimaryRc`].
///
/// Strong references ([`StrongId`]) are cloned from this.
#[derive(Debug)]
pub struct PrimaryId<S, D>(PrimaryRc<SocketState<S, D>>);
/// A strong reference to a socket's state; wraps a [`StrongRc`].
///
/// `Clone` and `Hash` are derived without placing bounds on `S` or `D`.
#[derive(Debug, Derivative)]
#[derivative(Clone(bound = ""), Hash(bound = ""))]
pub struct StrongId<S, D>(StrongRc<SocketState<S, D>>);
/// Equality for [`StrongId`] is reference identity, as decided by
/// [`StrongRc::ptr_eq`]; the contents of the socket state are not compared.
impl<S, D> PartialEq for StrongId<S, D> {
    fn eq(&self, other: &Self) -> bool {
        let (Self(lhs), Self(rhs)) = (self, other);
        StrongRc::ptr_eq(lhs, rhs)
    }
}
// Marker impl: the `PartialEq` impl for `StrongId` compares by reference
// identity via `StrongRc::ptr_eq`.
impl<S, D> Eq for StrongId<S, D> {}
impl<S, D> EntryKey for StrongId<S, D> {
    /// Returns the index stored in the referenced socket state's
    /// `all_sockets_index` field.
    fn get_key_index(&self) -> usize {
        let Self(rc) = self;
        // Destructure every field so that adding one forces this to be revisited.
        let SocketState { all_sockets_index: index, external_state: _, target: _ } = &**rc;
        *index
    }
}
/// Associates a strong socket identifier type with its primary counterpart.
pub trait StrongSocketId {
/// The primary identifier type corresponding to this strong identifier.
type Primary;
}
// The primary counterpart of `StrongId<S, D>` is `PrimaryId<S, D>`.
impl<S, D> StrongSocketId for StrongId<S, D> {
type Primary = PrimaryId<S, D>;
}
/// Socket tables that are not attached to a particular device.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub(super) struct Sockets<Primary, Strong> {
/// The sockets that target any device, guarded by a read/write lock.
any_device_sockets: RwLock<AnyDeviceSockets<Strong>>,
/// The table of all sockets (holding primary IDs), guarded by a mutex.
all_sockets: Mutex<AllSockets<Primary>>,
}
/// The set of socket IDs whose target is `TargetDevice::AnyDevice`.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct AnyDeviceSockets<Id>(HashSet<Id>);
/// The table of every socket, stored in a [`DenseMap`] keyed by index.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct AllSockets<Id>(DenseMap<Id>);
/// The state held for a single socket.
#[derive(Debug)]
struct SocketState<S, D> {
/// The socket's index in the all-sockets `DenseMap`; assigned at creation
/// and used as the key when the socket is removed.
all_sockets_index: usize,
/// State provided by bindings when the socket was created.
external_state: S,
/// The socket's protocol/device target, behind its own mutex.
target: Mutex<Target<D>>,
}
/// The mutable configuration of a socket: which protocol and device it targets.
///
/// The default has no protocol and the default `TargetDevice`.
#[derive(Debug, Derivative)]
#[derivative(Default(bound = ""))]
pub struct Target<D> {
/// The protocol filter, or `None` if unset (in which case frame delivery
/// skips the socket).
protocol: Option<Protocol>,
/// The device the socket targets.
device: TargetDevice<D>,
}
/// The set of socket IDs that target one specific device.
///
/// `Debug` and `PartialEq` are only derived in test builds.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
#[cfg_attr(test, derivative(Debug, PartialEq(bound = "Id: Hash + Eq")))]
pub struct DeviceSockets<Id>(HashSet<Id>);
/// [`DeviceSockets`] instantiated with the [`StrongId`] type for bindings
/// types `BT`.
// NOTE(review): presumably this is the per-device set stored in device state;
// the storage site is not visible here.
pub(super) type HeldDeviceSockets<BT> =
DeviceSockets<StrongId<<BT as DeviceSocketTypes>::SocketState, WeakDeviceId<BT>>>;
/// [`Sockets`] instantiated with the [`PrimaryId`] and [`StrongId`] types for
/// bindings types `BT`.
pub(super) type HeldSockets<BT> = Sockets<
PrimaryId<<BT as DeviceSocketTypes>::SocketState, WeakDeviceId<BT>>,
StrongId<<BT as DeviceSocketTypes>::SocketState, WeakDeviceId<BT>>,
>;
/// Common types used by the device socket context traits.
pub trait DeviceSocketContextTypes {
/// The strongly-owning identifier for a socket.
type SocketId: Clone + Debug + Eq + Hash + StrongSocketId;
}
/// Core context for creating and removing sockets and for accessing the set
/// of sockets that target any device.
pub trait DeviceSocketContext<BC: DeviceSocketBindingsContext<Self::DeviceId>>:
DeviceSocketAccessor<BC>
{
/// The context passed to the callbacks of
/// [`DeviceSocketContext::with_any_device_sockets`] and
/// [`DeviceSocketContext::with_any_device_sockets_mut`]. It shares this
/// context's device and socket ID types.
type SocketTablesCoreCtx<'a>: DeviceSocketAccessor<
BC,
DeviceId = Self::DeviceId,
WeakDeviceId = Self::WeakDeviceId,
SocketId = Self::SocketId,
>;
/// Creates a new socket holding `state` and returns its ID.
fn create_socket(&mut self, state: BC::SocketState) -> Self::SocketId;
/// Removes the socket identified by `socket`, consuming the ID.
fn remove_socket(&mut self, socket: Self::SocketId);
/// Calls `cb` with shared access to the any-device socket set and a
/// context for further accesses, returning `cb`'s result.
fn with_any_device_sockets<
F: FnOnce(&AnyDeviceSockets<Self::SocketId>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
R,
>(
&mut self,
cb: F,
) -> R;
/// Calls `cb` with exclusive access to the any-device socket set and a
/// context for further accesses, returning `cb`'s result.
fn with_any_device_sockets_mut<
F: FnOnce(&mut AnyDeviceSockets<Self::SocketId>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
R,
>(
&mut self,
cb: F,
) -> R;
}
/// Core context for accessing the state of an individual socket.
pub trait SocketStateAccessor<BC: DeviceSocketBindingsContext<Self::DeviceId>>:
DeviceSocketContextTypes + DeviceIdContext<AnyDevice>
{
/// Calls `cb` with the socket's external state and a shared reference to
/// its [`Target`], returning `cb`'s result.
fn with_socket_state<F: FnOnce(&BC::SocketState, &Target<Self::WeakDeviceId>) -> R, R>(
&mut self,
socket: &Self::SocketId,
cb: F,
) -> R;
/// Calls `cb` with the socket's external state and an exclusive reference
/// to its [`Target`], returning `cb`'s result.
fn with_socket_state_mut<F: FnOnce(&BC::SocketState, &mut Target<Self::WeakDeviceId>) -> R, R>(
&mut self,
socket: &Self::SocketId,
cb: F,
) -> R;
}
/// Core context for accessing the set of sockets attached to a specific device.
pub trait DeviceSocketAccessor<BC: DeviceSocketBindingsContext<Self::DeviceId>>:
SocketStateAccessor<BC>
{
/// The context passed to the callbacks of
/// [`DeviceSocketAccessor::with_device_sockets`] and
/// [`DeviceSocketAccessor::with_device_sockets_mut`]. It shares this
/// context's device and socket ID types.
type DeviceSocketCoreCtx<'a>: SocketStateAccessor<
BC,
SocketId = Self::SocketId,
DeviceId = Self::DeviceId,
WeakDeviceId = Self::WeakDeviceId,
>;
/// Calls `cb` with shared access to `device`'s socket set and a context
/// for accessing socket state, returning `cb`'s result.
fn with_device_sockets<
F: FnOnce(&DeviceSockets<Self::SocketId>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
R,
>(
&mut self,
device: &Self::DeviceId,
cb: F,
) -> R;
/// Calls `cb` with exclusive access to `device`'s socket set and a context
/// for accessing socket state, returning `cb`'s result.
fn with_device_sockets_mut<
F: FnOnce(&mut DeviceSockets<Self::SocketId>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
R,
>(
&mut self,
device: &Self::DeviceId,
cb: F,
) -> R;
}
/// The error returned by `DeviceSocketApi::send_frame`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum SendFrameError {
/// The underlying send operation reported a failure.
SendFailed,
}
/// An optional update to a value of type `T`.
enum MaybeUpdate<T> {
/// Leave the existing value untouched.
NoChange,
/// Replace the existing value with the contained one.
NewValue(T),
}
/// Re-targets `socket` at `new_device`, optionally replacing its protocol.
///
/// The whole operation runs inside `with_any_device_sockets_mut`, so the
/// any-device set is held exclusively while the socket is moved between sets.
///
/// # Panics
///
/// Panics if the socket is missing from the set it is recorded as belonging
/// to, or is already present in the set it is being moved into.
fn update_device_and_protocol<
    CC: DeviceSocketContext<BC>,
    BC: DeviceSocketBindingsContext<CC::DeviceId>,
>(
    core_ctx: &mut CC,
    socket: &CC::SocketId,
    new_device: TargetDevice<&CC::DeviceId>,
    protocol_update: MaybeUpdate<Protocol>,
) {
    core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(unbound_sockets), tables_ctx| {
        // Step 1: rewrite the socket's own target, remembering where it used to
        // live. An any-device socket is unregistered from that set right here.
        let previous_device = tables_ctx.with_socket_state_mut(
            socket,
            |_: &BC::SocketState, Target { protocol, device }| {
                if let MaybeUpdate::NewValue(new_protocol) = protocol_update {
                    *protocol = Some(new_protocol);
                }
                let previous = match &*device {
                    TargetDevice::AnyDevice => {
                        assert!(unbound_sockets.remove(socket));
                        None
                    }
                    // `upgrade` yields `None` if the old device is gone, in
                    // which case there is nothing to unregister from.
                    TargetDevice::SpecificDevice(weak) => weak.upgrade(),
                };
                *device = match &new_device {
                    TargetDevice::SpecificDevice(strong) => {
                        TargetDevice::SpecificDevice(strong.downgrade())
                    }
                    TargetDevice::AnyDevice => TargetDevice::AnyDevice,
                };
                previous
            },
        );

        // Step 2: unregister from the previous device's set, if any.
        if let Some(previous) = previous_device {
            tables_ctx.with_device_sockets_mut(&previous, |DeviceSockets(bound), _core_ctx| {
                assert!(bound.remove(socket), "socket not found in device state");
            });
        }

        // Step 3: register with the new target's set.
        match &new_device {
            TargetDevice::AnyDevice => {
                assert!(unbound_sockets.insert(socket.clone()))
            }
            TargetDevice::SpecificDevice(target) => {
                tables_ctx.with_device_sockets_mut(target, |DeviceSockets(bound), _core_ctx| {
                    assert!(bound.insert(socket.clone()));
                })
            }
        }
    })
}
/// The device socket API, a thin wrapper around a context `C`.
pub struct DeviceSocketApi<C>(C);
impl<C> DeviceSocketApi<C> {
pub(crate) fn new(ctx: C) -> Self {
Self(ctx)
}
}
impl<C> DeviceSocketApi<C>
where
C: ContextPair,
C::CoreContext: DeviceSocketContext<C::BindingsContext>,
C::BindingsContext:
DeviceSocketBindingsContext<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
{
fn core_ctx(&mut self) -> &mut C::CoreContext {
let Self(pair) = self;
pair.core_ctx()
}
fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
let Self(pair) = self;
pair.contexts()
}
pub fn create(
&mut self,
external_state: <C::BindingsContext as DeviceSocketTypes>::SocketState,
) -> <C::CoreContext as DeviceSocketContextTypes>::SocketId {
let core_ctx = self.core_ctx();
let strong = core_ctx.create_socket(external_state);
core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), _core_ctx| {
assert!(any_device_sockets.insert(strong.clone()));
});
strong
}
pub fn set_device(
&mut self,
socket: &<C::CoreContext as DeviceSocketContextTypes>::SocketId,
device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
) {
update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NoChange)
}
pub fn set_device_and_protocol(
&mut self,
socket: &<C::CoreContext as DeviceSocketContextTypes>::SocketId,
device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
protocol: Protocol,
) {
update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NewValue(protocol))
}
pub fn get_info(
&mut self,
id: &<C::CoreContext as DeviceSocketContextTypes>::SocketId,
) -> SocketInfo<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
self.core_ctx().with_socket_state(id, |_external_state, Target { device, protocol }| {
SocketInfo { device: device.clone(), protocol: *protocol }
})
}
pub fn remove(&mut self, id: <C::CoreContext as DeviceSocketContextTypes>::SocketId) {
let core_ctx = self.core_ctx();
core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
let old_device = core_ctx.with_socket_state_mut(&id, |_external_state, target| {
let Target { device, protocol: _ } = target;
match &device {
TargetDevice::SpecificDevice(device) => device.upgrade(),
TargetDevice::AnyDevice => {
assert!(any_device_sockets.remove(&id));
None
}
}
});
if let Some(device) = old_device {
core_ctx.with_device_sockets_mut(
&device,
|DeviceSockets(device_sockets), _core_ctx| {
assert!(device_sockets.remove(&id), "device doesn't have socket");
},
)
}
});
core_ctx.remove_socket(id)
}
pub fn send_frame<S, D>(
&mut self,
_id: &<C::CoreContext as DeviceSocketContextTypes>::SocketId,
metadata: DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
body: S,
) -> Result<(), (S, SendFrameError)>
where
S: Serializer,
S::Buffer: BufferMut,
D: DeviceSocketSendTypes,
C::CoreContext: DeviceIdContext<D>
+ SendFrameContext<
C::BindingsContext,
DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
>,
C::BindingsContext: DeviceLayerTypes,
{
let (core_ctx, bindings_ctx) = self.contexts();
core_ctx
.send_frame(bindings_ctx, metadata, body)
.map_err(|s| (s, SendFrameError::SendFailed))
}
}
/// Per-device-type definitions needed to send frames from a device socket.
pub trait DeviceSocketSendTypes: Device {
/// The device-specific metadata carried in [`DeviceSocketMetadata`].
type Metadata;
}
/// Metadata accompanying a frame sent via `DeviceSocketApi::send_frame`.
#[derive(Debug, PartialEq)]
pub struct DeviceSocketMetadata<D: DeviceSocketSendTypes, DeviceId> {
/// The device the frame is addressed to.
pub device_id: DeviceId,
/// Device-type-specific metadata.
pub metadata: D::Metadata,
}
/// Parameters describing an Ethernet header.
// NOTE(review): presumably used as `DeviceSocketSendTypes::Metadata` for
// Ethernet devices; that impl is not visible in this file section.
#[derive(Debug, PartialEq)]
pub struct EthernetHeaderParams {
/// The destination MAC address.
pub dest_addr: Mac,
/// The EtherType.
pub protocol: EtherType,
}
/// The [`StrongId`] socket identifier type for bindings context `BC`.
pub type SocketId<BC> = StrongId<<BC as DeviceSocketTypes>::SocketState, WeakDeviceId<BC>>;
impl<S, D> StrongId<S, D> {
    /// Borrows the external (bindings-provided) state of the referenced socket.
    pub fn socket_state(&self) -> &S {
        let Self(rc) = self;
        // Destructure every field so that adding one forces this to be revisited.
        let SocketState { all_sockets_index: _, target: _, external_state: state } = &**rc;
        state
    }
}
/// Handler for frames that should be offered to device sockets.
pub trait DeviceSocketHandler<D: Device, BC>: DeviceIdContext<D> {
/// Offers a frame associated with `device` to the interested sockets.
///
/// `frame` is the parsed view of the frame and `whole_frame` is the
/// complete, unparsed frame.
fn handle_frame(
&mut self,
bindings_ctx: &mut BC,
device: &Self::DeviceId,
frame: Frame<&[u8]>,
whole_frame: &[u8],
);
}
/// A frame that was received on a device, with body type `B`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReceivedFrame<B> {
/// An Ethernet frame.
Ethernet {
/// The frame's destination classification.
destination: FrameDestination,
/// The parsed Ethernet frame.
frame: EthernetFrame<B>,
},
/// An IP frame.
Ip(IpFrame<B>),
}
/// A frame that was sent on a device, with body type `B`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SentFrame<B> {
/// An Ethernet frame.
Ethernet(EthernetFrame<B>),
/// An IP frame.
Ip(IpFrame<B>),
}
/// The error returned when bytes cannot be parsed as a [`SentFrame`].
#[derive(Debug)]
pub struct ParseSentFrameError;
impl SentFrame<&[u8]> {
    /// Attempts to interpret `buf` as an Ethernet frame.
    ///
    /// Parsing uses `EthernetFrameLengthCheck::NoCheck`. Any parse error is
    /// collapsed into [`ParseSentFrameError`].
    pub(crate) fn try_parse_as_ethernet(
        mut buf: &[u8],
    ) -> Result<SentFrame<&[u8]>, ParseSentFrameError> {
        let parsed = packet_formats::ethernet::EthernetFrame::parse(
            &mut buf,
            EthernetFrameLengthCheck::NoCheck,
        );
        parsed
            .map(|frame| SentFrame::Ethernet(frame.into()))
            .map_err(|_: ParseError| ParseSentFrameError)
    }
}
/// The fields of an Ethernet frame exposed to device sockets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct EthernetFrame<B> {
/// The source MAC address.
pub src_mac: Mac,
/// The destination MAC address.
pub dst_mac: Mac,
/// The frame's EtherType, if it has one.
pub ethertype: Option<EtherType>,
/// The frame body.
pub body: B,
}
/// An IP frame exposed to device sockets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct IpFrame<B> {
/// The IP version of the frame.
pub ip_version: IpVersion,
/// The frame body.
pub body: B,
}
impl<B> IpFrame<B> {
    /// Returns the [`EtherType`] that corresponds to this frame's IP version.
    fn ethertype(&self) -> EtherType {
        let Self { body: _, ip_version: version } = self;
        EtherType::from_ip_version(*version)
    }
}
/// A frame offered to device sockets, tagged with its direction.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Frame<B> {
/// A frame that was sent.
Sent(SentFrame<B>),
/// A frame that was received.
Received(ReceivedFrame<B>),
}
impl<B> From<SentFrame<B>> for Frame<B> {
fn from(value: SentFrame<B>) -> Self {
Self::Sent(value)
}
}
impl<B> From<ReceivedFrame<B>> for Frame<B> {
fn from(value: ReceivedFrame<B>) -> Self {
Self::Received(value)
}
}
impl<'a> From<packet_formats::ethernet::EthernetFrame<&'a [u8]>> for EthernetFrame<&'a [u8]> {
    /// Copies the header fields out of a parsed Ethernet frame and takes its
    /// body.
    fn from(frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>) -> Self {
        // Read the header fields first; `into_body` consumes the parsed frame.
        let src_mac = frame.src_mac();
        let dst_mac = frame.dst_mac();
        let ethertype = frame.ethertype();
        let body = frame.into_body();
        Self { src_mac, dst_mac, ethertype, body }
    }
}
impl<'a> ReceivedFrame<&'a [u8]> {
pub(crate) fn from_ethernet(
frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>,
destination: FrameDestination,
) -> Self {
Self::Ethernet { destination, frame: frame.into() }
}
}
impl<B> Frame<B> {
    /// Returns the frame's protocol number.
    ///
    /// For Ethernet frames this is the EtherType, if present; for IP frames
    /// it is the EtherType derived from the IP version.
    fn protocol(&self) -> Option<u16> {
        match self {
            Self::Received(ReceivedFrame::Ethernet { destination: _, frame })
            | Self::Sent(SentFrame::Ethernet(frame)) => frame.ethertype.map(Into::into),
            Self::Received(ReceivedFrame::Ip(frame)) | Self::Sent(SentFrame::Ip(frame)) => {
                Some(frame.ethertype().into())
            }
        }
    }

    /// Consumes the frame and returns its body, discarding all header data.
    pub fn into_body(self) -> B {
        match self {
            Self::Sent(sent) => match sent {
                SentFrame::Ethernet(frame) => frame.body,
                SentFrame::Ip(frame) => frame.body,
            },
            Self::Received(received) => match received {
                ReceivedFrame::Ethernet { destination: _, frame } => frame.body,
                ReceivedFrame::Ip(frame) => frame.body,
            },
        }
    }
}
impl<
        D: Device,
        BC: DeviceSocketBindingsContext<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
        CC: DeviceSocketContext<BC> + DeviceIdContext<D>,
    > DeviceSocketHandler<D, BC> for CC
where
    <CC as DeviceIdContext<D>>::DeviceId: Into<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
{
    /// Offers `frame` to every any-device socket and every socket bound to
    /// `device`, delivering it to those whose protocol filter accepts it.
    ///
    /// Delivery rules:
    /// - no protocol set: never delivered;
    /// - [`Protocol::All`]: always delivered;
    /// - [`Protocol::Specific`]: delivered only for received frames whose
    ///   protocol number matches; sent frames are never delivered.
    fn handle_frame(
        &mut self,
        bindings_ctx: &mut BC,
        device: &Self::DeviceId,
        frame: Frame<&[u8]>,
        whole_frame: &[u8],
    ) {
        // Convert the typed device ID into the any-device ID used by the tables.
        let any_device_id = device.clone().into();
        self.with_any_device_sockets(|AnyDeviceSockets(unbound_sockets), tables_ctx| {
            tables_ctx.with_device_sockets(
                &any_device_id,
                |DeviceSockets(bound_sockets), state_ctx| {
                    for socket in unbound_sockets.iter().chain(bound_sockets) {
                        state_ctx.with_socket_state(socket, |external_state, target| {
                            let Target { protocol, device: _ } = target;
                            let should_deliver = match (protocol, &frame) {
                                (None, _) => false,
                                (Some(Protocol::All), _) => true,
                                (Some(Protocol::Specific(_)), Frame::Sent(_)) => false,
                                (Some(Protocol::Specific(wanted)), Frame::Received(_)) => {
                                    Some(wanted.get()) == frame.protocol()
                                }
                            };
                            if should_deliver {
                                bindings_ctx.receive_frame(
                                    external_state,
                                    &any_device_id,
                                    frame,
                                    whole_frame,
                                )
                            }
                        })
                    }
                },
            )
        })
    }
}
// The real core context identifies sockets by `StrongId`.
impl<BC: crate::BindingsContext, L> DeviceSocketContextTypes for CoreCtx<'_, BC, L> {
type SocketId = StrongId<BC::SocketState, WeakDeviceId<BC>>;
}
// `DeviceSocketContext` for the real core context. It is available at any
// lock level `L` that precedes `AllDeviceSockets` in the lock ordering.
impl<BC: crate::BindingsContext, L: LockBefore<crate::lock_ordering::AllDeviceSockets>>
DeviceSocketContext<BC> for CoreCtx<'_, BC, L>
{
// Callbacks receive a context positioned at the `AnyDeviceSockets` lock level.
type SocketTablesCoreCtx<'a> = CoreCtx<'a, BC, crate::lock_ordering::AnyDeviceSockets>;
/// Inserts a new socket into the all-sockets table with a default
/// [`Target`], and returns a strong ID cloned from the stored primary ID.
fn create_socket(&mut self, state: BC::SocketState) -> Self::SocketId {
let mut sockets = self.lock();
let AllSockets(sockets) = &mut *sockets;
// `push_with` supplies the index the entry will occupy, so that it can be
// recorded inside the socket state itself.
let entry = sockets.push_with(|index| {
PrimaryId(PrimaryRc::new(SocketState {
all_sockets_index: index,
external_state: state,
target: Mutex::new(Target::default()),
}))
});
let PrimaryId(primary) = &entry.get();
StrongId(PrimaryRc::clone_strong(primary))
}
/// Removes the socket from the all-sockets table and unwraps its primary
/// reference.
///
/// Panics with "unknown socket ID" if no entry exists at the socket's index.
fn remove_socket(&mut self, socket: Self::SocketId) {
let mut state = self.lock();
let AllSockets(sockets) = &mut *state;
let PrimaryId(primary) = sockets.remove(socket.get_key_index()).expect("unknown socket ID");
// Drop the caller's strong reference before unwrapping the primary one.
// NOTE(review): presumably `PrimaryRc::unwrap` requires that no strong
// references remain — confirm against its documentation.
drop(socket);
let _: SocketState<_, _> = PrimaryRc::unwrap(primary);
}
/// Takes the `AnyDeviceSockets` read lock and runs `cb` with the guarded set.
fn with_any_device_sockets<
F: FnOnce(&AnyDeviceSockets<Self::SocketId>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
R,
>(
&mut self,
cb: F,
) -> R {
let (sockets, mut locked) = self.read_lock_and::<crate::lock_ordering::AnyDeviceSockets>();
cb(&*sockets, &mut locked)
}
/// Takes the `AnyDeviceSockets` write lock and runs `cb` with the guarded set.
fn with_any_device_sockets_mut<
F: FnOnce(&mut AnyDeviceSockets<Self::SocketId>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
R,
>(
&mut self,
cb: F,
) -> R {
let (mut sockets, mut locked) =
self.write_lock_and::<crate::lock_ordering::AnyDeviceSockets>();
cb(&mut *sockets, &mut locked)
}
}
impl<BC: crate::BindingsContext, L: LockBefore<crate::lock_ordering::DeviceSocketState>>
    SocketStateAccessor<BC> for CoreCtx<'_, BC, L>
{
    /// Locks the socket's target mutex and calls `cb` with the external state
    /// and a shared view of the target.
    fn with_socket_state<F: FnOnce(&BC::SocketState, &Target<Self::WeakDeviceId>) -> R, R>(
        &mut self,
        socket: &Self::SocketId,
        cb: F,
    ) -> R {
        let StrongId(rc) = socket;
        let SocketState { all_sockets_index: _, external_state, target } = &**rc;
        let guard = target.lock();
        cb(external_state, &*guard)
    }

    /// Locks the socket's target mutex and calls `cb` with the external state
    /// and an exclusive view of the target.
    fn with_socket_state_mut<
        F: FnOnce(&BC::SocketState, &mut Target<Self::WeakDeviceId>) -> R,
        R,
    >(
        &mut self,
        socket: &Self::SocketId,
        cb: F,
    ) -> R {
        let StrongId(rc) = socket;
        let SocketState { all_sockets_index: _, external_state, target } = &**rc;
        let mut guard = target.lock();
        cb(external_state, &mut *guard)
    }
}
// `DeviceSocketAccessor` for the real core context. It is available at any
// lock level `L` that precedes `DeviceSockets` in the lock ordering.
impl<BC: crate::BindingsContext, L: LockBefore<crate::lock_ordering::DeviceSockets>>
DeviceSocketAccessor<BC> for CoreCtx<'_, BC, L>
{
// Callbacks receive a context positioned at the `DeviceSockets` lock level.
type DeviceSocketCoreCtx<'a> = CoreCtx<'a, BC, crate::lock_ordering::DeviceSockets>;
/// Takes the `DeviceSockets` read lock for `device` and runs `cb` with the
/// guarded set.
fn with_device_sockets<
F: FnOnce(&DeviceSockets<Self::SocketId>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
R,
>(
&mut self,
device: &Self::DeviceId,
cb: F,
) -> R {
// NOTE(review): `for_any_device_id!` presumably dispatches on the concrete
// variant of `DeviceId` — confirm against the macro definition.
for_any_device_id!(
DeviceId,
device,
device => device::integration::with_device_state_and_core_ctx(
self,
device,
|mut core_ctx_and_resource| {
// NOTE(review): `c.right()` presumably selects the device-state half
// of the context/resource pair — confirm.
let (device_sockets, mut locked) = core_ctx_and_resource
.read_lock_with_and::<crate::lock_ordering::DeviceSockets, _>(
|c| c.right(),
);
cb(&*device_sockets, &mut locked.cast_core_ctx())
},
)
)
}
/// Takes the `DeviceSockets` write lock for `device` and runs `cb` with the
/// guarded set.
fn with_device_sockets_mut<
F: FnOnce(&mut DeviceSockets<Self::SocketId>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
R,
>(
&mut self,
device: &Self::DeviceId,
cb: F,
) -> R {
for_any_device_id!(
DeviceId,
device,
device => device::integration::with_device_state_and_core_ctx(
self,
device,
|mut core_ctx_and_resource| {
let (mut device_sockets, mut locked) = core_ctx_and_resource
.write_lock_with_and::<crate::lock_ordering::DeviceSockets, _>(
|c| c.right(),
);
cb(&mut *device_sockets, &mut locked.cast_core_ctx())
},
)
)
}
}
// Maps the `AnyDeviceSockets` lock level onto the read/write lock stored at
// `self.device.shared_sockets.any_device_sockets`.
impl<BC: crate::BindingsContext> RwLockFor<crate::lock_ordering::AnyDeviceSockets>
for StackState<BC>
{
type Data = AnyDeviceSockets<StrongId<BC::SocketState, WeakDeviceId<BC>>>;
type ReadGuard<'l> = crate::sync::RwLockReadGuard<'l, AnyDeviceSockets<StrongId<BC::SocketState, WeakDeviceId<BC>>>>
where Self: 'l;
type WriteGuard<'l> = crate::sync::RwLockWriteGuard<'l, AnyDeviceSockets<StrongId<BC::SocketState, WeakDeviceId<BC>>>>
where Self: 'l;
/// Acquires the any-device socket set for shared access.
fn read_lock(&self) -> Self::ReadGuard<'_> {
self.device.shared_sockets.any_device_sockets.read()
}
/// Acquires the any-device socket set for exclusive access.
fn write_lock(&self) -> Self::WriteGuard<'_> {
self.device.shared_sockets.any_device_sockets.write()
}
}
// Maps the `AllDeviceSockets` lock level onto the mutex stored at
// `self.device.shared_sockets.all_sockets`.
impl<BC: crate::BindingsContext> LockFor<crate::lock_ordering::AllDeviceSockets>
for StackState<BC>
{
type Data = AllSockets<PrimaryId<BC::SocketState, WeakDeviceId<BC>>>;
type Guard<'l> = crate::sync::LockGuard<'l, AllSockets<PrimaryId<BC::SocketState, WeakDeviceId<BC>>>>
where Self: 'l;
/// Acquires the all-sockets table.
fn lock(&self) -> Self::Guard<'_> {
self.device.shared_sockets.all_sockets.lock()
}
}
/// Test-only trait implementations that let the crate-wide `FakeBindingsCtx`
/// stand in as a device socket bindings context.
#[cfg(test)]
mod testutil {
use crate::context::testutil::FakeBindingsCtx;
use crate::device::DeviceLayerStateTypes;
use crate::testutil::MonotonicIdentifier;
use super::*;
// Sockets created through this fake carry no external state.
impl<TimerId, Event: Debug, State> DeviceSocketTypes
for FakeBindingsCtx<TimerId, Event, State, ()>
{
type SocketState = ();
}
// Every per-device-type state is the unit type; devices are identified by
// `MonotonicIdentifier`.
impl<TimerId: Debug + PartialEq + Clone + Send + Sync, Event: Debug, State>
DeviceLayerStateTypes for FakeBindingsCtx<TimerId, Event, State, ()>
{
type EthernetDeviceState = ();
type LoopbackDeviceState = ();
type PureIpDeviceState = ();
type DeviceIdentifier = MonotonicIdentifier;
}
impl<TimerId, Event: Debug, State, DeviceId> DeviceSocketBindingsContext<DeviceId>
for FakeBindingsCtx<TimerId, Event, State, ()>
{
/// Not supported by this fake: panics via `unimplemented!()` if a frame is
/// ever delivered.
fn receive_frame(
&self,
_socket: &Self::SocketState,
_device: &DeviceId,
_frame: Frame<&[u8]>,
_raw_frame: &[u8],
) {
unimplemented!()
}
}
}
#[cfg(test)]
mod tests {
use alloc::{collections::HashMap, vec, vec::Vec};
use const_unwrap::const_unwrap_option;
use derivative::Derivative;
use packet::ParsablePacket;
use test_case::test_case;
use crate::{
context::ContextProvider,
device::{
testutil::{
FakeReferencyDeviceId, FakeStrongDeviceId, FakeWeakDeviceId, MultipleDevicesId,
},
Id,
},
};
use super::*;
impl Frame<&[u8]> {
    /// Converts a borrowed frame into one that owns a copy of its body.
    pub(crate) fn cloned(self) -> Frame<Vec<u8>> {
        match self {
            Self::Sent(sent) => Frame::Sent(match sent {
                SentFrame::Ethernet(frame) => SentFrame::Ethernet(frame.cloned()),
                SentFrame::Ip(frame) => SentFrame::Ip(frame.cloned()),
            }),
            Self::Received(received) => Frame::Received(match received {
                super::ReceivedFrame::Ethernet { destination, frame } => {
                    super::ReceivedFrame::Ethernet { destination, frame: frame.cloned() }
                }
                super::ReceivedFrame::Ip(frame) => super::ReceivedFrame::Ip(frame.cloned()),
            }),
        }
    }
}
impl EthernetFrame<&[u8]> {
    /// Copies the header fields and duplicates the body into a `Vec`.
    fn cloned(self) -> EthernetFrame<Vec<u8>> {
        let Self { src_mac, dst_mac, ethertype, body: borrowed } = self;
        let body = borrowed.to_vec();
        EthernetFrame { src_mac, dst_mac, ethertype, body }
    }
}
impl IpFrame<&[u8]> {
    /// Copies the IP version and duplicates the body into a `Vec`.
    fn cloned(self) -> IpFrame<Vec<u8>> {
        let Self { ip_version, body: borrowed } = self;
        let body = borrowed.to_vec();
        IpFrame { ip_version, body }
    }
}
/// A record of one frame delivered to a fake socket.
///
/// Within this module the name shadows the parent's `ReceivedFrame` enum,
/// which is referred to as `super::ReceivedFrame`.
#[derive(Clone, Debug, PartialEq)]
struct ReceivedFrame<D> {
/// The device the frame was delivered for.
device: D,
/// An owned copy of the parsed frame.
frame: Frame<Vec<u8>>,
/// An owned copy of the raw frame bytes.
raw: Vec<u8>,
}
/// The fake core context used by these tests; its state is a [`FakeSockets`].
type FakeCoreCtx<D> = crate::context::testutil::FakeCoreCtx<FakeSockets<D>, (), D>;
/// The context pair combining [`FakeCoreCtx`] with this module's
/// [`FakeBindingsCtx`].
type FakeCtx<D> = crate::testutil::ContextPair<FakeCoreCtx<D>, FakeBindingsCtx<D>>;
/// A stateless fake bindings context, parameterized only by device ID type.
#[derive(Debug, Derivative)]
#[derivative(Default(bound = ""))]
struct FakeBindingsCtx<D>(core::marker::PhantomData<D>);
// The fake bindings context is its own context.
impl<D> ContextProvider for FakeBindingsCtx<D> {
type Context = Self;
/// Returns `self`.
fn context(&mut self) -> &mut Self::Context {
self
}
}
// Each fake socket's external state is a queue of the frames delivered to it.
impl<D: Id> DeviceSocketTypes for FakeBindingsCtx<D> {
type SocketState = ExternalSocketState<D>;
}
impl<D: Id> DeviceSocketBindingsContext<D> for FakeBindingsCtx<D> {
    /// Records the delivered frame (as owned copies) at the end of the
    /// socket's queue.
    fn receive_frame(
        &self,
        state: &ExternalSocketState<D>,
        device: &D,
        frame: Frame<&[u8]>,
        raw_frame: &[u8],
    ) {
        let ExternalSocketState(delivered) = state;
        let record =
            ReceivedFrame { device: device.clone(), frame: frame.cloned(), raw: raw_frame.into() };
        delivered.lock().push(record)
    }
}
/// Test convenience for obtaining a [`DeviceSocketApi`] from a context pair.
trait DeviceSocketApiExt: crate::context::ContextPair + Sized {
/// Wraps a mutable borrow of `self` in a [`DeviceSocketApi`].
fn device_socket_api(&mut self) -> DeviceSocketApi<&mut Self> {
DeviceSocketApi::new(self)
}
}
// Blanket impl: every sized `ContextPair` gets `device_socket_api`.
impl<O> DeviceSocketApiExt for O where O: crate::context::ContextPair + Sized {}
/// The external state of a fake socket: a mutex-guarded list of the frames
/// delivered to it, in delivery order.
#[derive(Debug, Derivative)]
#[derivative(Default(bound = ""))]
struct ExternalSocketState<D>(Mutex<Vec<ReceivedFrame<D>>>);
/// The fake all-sockets table: each entry pairs a socket's external state
/// with its [`Target`].
type FakeAllSockets<D> =
DenseMap<(ExternalSocketState<D>, Target<<D as crate::device::StrongId>::Weak>)>;
/// The complete fake socket state backing [`FakeCoreCtx`].
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
struct FakeSockets<D: FakeStrongDeviceId> {
/// Sockets that target any device.
any_device_sockets: AnyDeviceSockets<FakeStrongId>,
/// Sockets that target a specific device, keyed by that device.
device_sockets: HashMap<D, DeviceSockets<FakeStrongId>>,
/// Every socket's external state and target, indexed by `FakeStrongId`.
all_sockets: FakeAllSockets<D>,
}
/// Mutable references to the three parts of the fake socket state, in the
/// order: any-device sockets, all sockets, per-device state.
///
/// The fake contexts substitute placeholder types (e.g. `()`) for parts that
/// have already been handed to a callback.
struct FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices>(
&'m mut AnyDevice,
&'m mut AllSockets,
&'m mut Devices,
);
/// Abstraction over anything that can be viewed as a [`FakeSocketsMutRefs`].
trait AsFakeSocketsMutRefs {
/// The type in the any-device slot.
type AnyDevice: 'static;
/// The type in the all-sockets slot.
type AllSockets: 'static;
/// The type in the per-device slot.
type Devices: 'static;
/// Borrows `self` as a [`FakeSocketsMutRefs`].
fn as_sockets_ref(
&mut self,
) -> FakeSocketsMutRefs<'_, Self::AnyDevice, Self::AllSockets, Self::Devices>;
}
impl<D: FakeStrongDeviceId> AsFakeSocketsMutRefs for FakeCoreCtx<D> {
type AnyDevice = AnyDeviceSockets<FakeStrongId>;
type AllSockets = FakeAllSockets<D>;
type Devices = HashMap<D, DeviceSockets<FakeStrongId>>;
fn as_sockets_ref(
&mut self,
) -> FakeSocketsMutRefs<
'_,
AnyDeviceSockets<FakeStrongId>,
FakeAllSockets<D>,
HashMap<D, DeviceSockets<FakeStrongId>>,
> {
let FakeSockets { any_device_sockets, device_sockets, all_sockets } = self.get_mut();
FakeSocketsMutRefs(any_device_sockets, all_sockets, device_sockets)
}
}
impl<'m, AnyDevice: 'static, AllSockets: 'static, Devices: 'static> AsFakeSocketsMutRefs
    for FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices>
{
    type AnyDevice = AnyDevice;
    type AllSockets = AllSockets;
    type Devices = Devices;

    /// Reborrows each of the three held references for a shorter lifetime.
    fn as_sockets_ref(&mut self) -> FakeSocketsMutRefs<'_, AnyDevice, AllSockets, Devices> {
        let Self(any_device, all_sockets, devices) = self;
        FakeSocketsMutRefs(&mut **any_device, &mut **all_sockets, &mut **devices)
    }
}
// Every fake context identifies sockets by `FakeStrongId`.
impl<As: AsFakeSocketsMutRefs> DeviceSocketContextTypes for As {
type SocketId = FakeStrongId;
}
/// A fake socket ID: the socket's index in [`FakeAllSockets`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct FakeStrongId(usize);
/// Placeholder primary ID type paired with [`FakeStrongId`].
#[derive(Debug)]
pub struct FakePrimaryId;
// The primary counterpart of `FakeStrongId` is `FakePrimaryId`.
impl StrongSocketId for FakeStrongId {
type Primary = FakePrimaryId;
}
impl<D: Clone> TargetDevice<&D> {
    /// Produces the equivalent target expressed with a [`FakeWeakDeviceId`]
    /// wrapping a clone of the referenced device.
    fn with_weak_id(&self) -> TargetDevice<FakeWeakDeviceId<D>> {
        match *self {
            TargetDevice::SpecificDevice(device) => {
                TargetDevice::SpecificDevice(FakeWeakDeviceId(device.clone()))
            }
            TargetDevice::AnyDevice => TargetDevice::AnyDevice,
        }
    }
}
impl<D: Eq + Hash + FakeStrongDeviceId> FakeSockets<D> {
fn new(devices: impl IntoIterator<Item = D>) -> Self {
let device_sockets =
devices.into_iter().map(|d| (d, DeviceSockets::default())).collect();
Self {
any_device_sockets: AnyDeviceSockets::default(),
device_sockets,
all_sockets: FakeAllSockets::<D>::default(),
}
}
}
/// Test-only trait for types that know a set of device IDs.
///
/// Implementers receive a blanket `DeviceIdContext<AnyDevice>` impl.
pub trait FakeDeviceIdContext {
/// The strong device ID type.
type DeviceId: FakeStrongDeviceId;
/// Returns whether `device_id` is known to this context.
fn contains_id(&self, device_id: &Self::DeviceId) -> bool;
}
// Blanket impl: a `FakeDeviceIdContext` uses its own device ID as the strong
// ID and `FakeWeakDeviceId` of it as the weak ID.
impl<CC: FakeDeviceIdContext> DeviceIdContext<AnyDevice> for CC {
type DeviceId = CC::DeviceId;
type WeakDeviceId = FakeWeakDeviceId<CC::DeviceId>;
}
impl<
'm,
DeviceId: FakeStrongDeviceId,
As: AsFakeSocketsMutRefs<AllSockets = FakeAllSockets<DeviceId>>
+ DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>
+ DeviceSocketContextTypes<SocketId = FakeStrongId>,
> SocketStateAccessor<FakeBindingsCtx<DeviceId>> for As
where
As::Devices: FakeDeviceIdContext<DeviceId = DeviceId>,
{
fn with_socket_state<
F: FnOnce(&ExternalSocketState<Self::DeviceId>, &Target<Self::WeakDeviceId>) -> R,
R,
>(
&mut self,
socket: &Self::SocketId,
cb: F,
) -> R {
let FakeSocketsMutRefs(_, all_sockets, _) = self.as_sockets_ref();
let (state, target) = all_sockets.get(socket.0).unwrap();
cb(state, target)
}
fn with_socket_state_mut<
F: FnOnce(&ExternalSocketState<Self::DeviceId>, &mut Target<Self::WeakDeviceId>) -> R,
R,
>(
&mut self,
socket: &Self::SocketId,
cb: F,
) -> R {
let FakeSocketsMutRefs(_, all_sockets, _) = self.as_sockets_ref();
let (state, target) = all_sockets.get_mut(socket.0).unwrap();
cb(state, target)
}
}
impl<
        'm,
        DeviceId: FakeStrongDeviceId,
        As: AsFakeSocketsMutRefs<
                AllSockets = FakeAllSockets<DeviceId>,
                Devices = HashMap<DeviceId, DeviceSockets<FakeStrongId>>,
            > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>
            + DeviceSocketContextTypes<SocketId = FakeStrongId>,
    > DeviceSocketAccessor<FakeBindingsCtx<DeviceId>> for As
{
    // The per-device slot of the inner context holds only the set of known
    // device IDs, since the socket set itself is lent to the callback.
    type DeviceSocketCoreCtx<'a> =
        FakeSocketsMutRefs<'a, As::AnyDevice, FakeAllSockets<DeviceId>, HashSet<DeviceId>>;

    /// Calls `cb` with `device`'s socket set. Panics if `device` is unknown.
    fn with_device_sockets<
        F: FnOnce(&DeviceSockets<Self::SocketId>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> R {
        let FakeSocketsMutRefs(any_device, all_sockets, per_device) = self.as_sockets_ref();
        let mut known_devices: HashSet<DeviceId> = per_device.keys().cloned().collect();
        let sockets = per_device.get(device).unwrap();
        let mut inner_ctx = FakeSocketsMutRefs(any_device, all_sockets, &mut known_devices);
        cb(sockets, &mut inner_ctx)
    }

    /// Calls `cb` with mutable access to `device`'s socket set. Panics if
    /// `device` is unknown.
    fn with_device_sockets_mut<
        F: FnOnce(&mut DeviceSockets<Self::SocketId>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> R {
        let FakeSocketsMutRefs(any_device, all_sockets, per_device) = self.as_sockets_ref();
        let mut known_devices: HashSet<DeviceId> = per_device.keys().cloned().collect();
        let sockets = per_device.get_mut(device).unwrap();
        let mut inner_ctx = FakeSocketsMutRefs(any_device, all_sockets, &mut known_devices);
        cb(sockets, &mut inner_ctx)
    }
}
impl<
        'm,
        DeviceId: FakeStrongDeviceId,
        As: AsFakeSocketsMutRefs<
                AnyDevice = AnyDeviceSockets<FakeStrongId>,
                AllSockets = FakeAllSockets<DeviceId>,
                Devices = HashMap<DeviceId, DeviceSockets<FakeStrongId>>,
            > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>
            + DeviceSocketContextTypes<SocketId = FakeStrongId>,
    > DeviceSocketContext<FakeBindingsCtx<DeviceId>> for As
{
    // The any-device slot of the inner context is `()`, since the set itself
    // is lent to the callback.
    type SocketTablesCoreCtx<'a> = FakeSocketsMutRefs<
        'a,
        (),
        FakeAllSockets<DeviceId>,
        HashMap<DeviceId, DeviceSockets<FakeStrongId>>,
    >;

    /// Appends a socket with a default target to the fake table and returns
    /// its index as the ID.
    fn create_socket(&mut self, state: ExternalSocketState<DeviceId>) -> Self::SocketId {
        let FakeSocketsMutRefs(_any_device, table, _devices) = self.as_sockets_ref();
        let index = table.push((state, Target::default()));
        FakeStrongId(index)
    }

    /// Removes the socket from the fake table.
    ///
    /// Panics if the socket is still registered in the any-device set or in
    /// any device's set, or if it is not in the table.
    fn remove_socket(&mut self, id: Self::SocketId) {
        let FakeSocketsMutRefs(AnyDeviceSockets(unbound), table, per_device) =
            self.as_sockets_ref();
        assert!(!unbound.contains(&id));
        let still_bound =
            per_device.iter().any(|(_device, DeviceSockets(sockets))| sockets.contains(&id));
        assert!(!still_bound);
        let FakeStrongId(index) = id;
        let _: (_, _) = table.remove(index).unwrap();
    }

    /// Calls `cb` with the any-device set and a context over the remaining
    /// state.
    fn with_any_device_sockets<
        F: FnOnce(&AnyDeviceSockets<Self::SocketId>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R {
        let FakeSocketsMutRefs(unbound, table, per_device) = self.as_sockets_ref();
        let mut inner_ctx = FakeSocketsMutRefs(&mut (), table, per_device);
        cb(unbound, &mut inner_ctx)
    }

    /// Calls `cb` with mutable access to the any-device set and a context over
    /// the remaining state.
    fn with_any_device_sockets_mut<
        F: FnOnce(&mut AnyDeviceSockets<Self::SocketId>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R {
        let FakeSocketsMutRefs(unbound, table, per_device) = self.as_sockets_ref();
        let mut inner_ctx = FakeSocketsMutRefs(&mut (), table, per_device);
        cb(unbound, &mut inner_ctx)
    }
}
impl<'m, X: 'static, Y: 'static, Z: FakeDeviceIdContext + 'static> FakeDeviceIdContext
    for FakeSocketsMutRefs<'m, X, Y, Z>
{
    type DeviceId = Z::DeviceId;

    /// Delegates to the value held in the third (devices) slot.
    fn contains_id(&self, device_id: &Self::DeviceId) -> bool {
        let Self(_, _, devices) = self;
        devices.contains_id(device_id)
    }
}
impl<D: FakeStrongDeviceId> FakeDeviceIdContext for HashSet<D> {
type DeviceId = D;
fn contains_id(&self, device_id: &Self::DeviceId) -> bool {
self.contains(device_id)
}
}
impl<V, D: FakeStrongDeviceId> FakeDeviceIdContext for HashMap<D, V> {
type DeviceId = D;
fn contains_id(&self, device_id: &Self::DeviceId) -> bool {
self.contains_key(device_id)
}
}
/// An arbitrary non-zero protocol number (2000) used by the parameterized tests.
const SOME_PROTOCOL: NonZeroU16 = const_unwrap_option(NonZeroU16::new(2000));
/// A freshly created socket targets any device with no protocol, and can be
/// removed again.
#[test]
fn create_remove() {
    let fake_sockets = FakeSockets::new(MultipleDevicesId::all());
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(fake_sockets));
    let mut api = ctx.device_socket_api();

    let socket = api.create(Default::default());
    let expected = SocketInfo { device: TargetDevice::AnyDevice, protocol: None };
    assert_eq!(api.get_info(&socket), expected);

    api.remove(socket);
}
/// Setting a socket's device is reflected by `get_info` and, for a specific
/// device, in that device's socket set.
#[test_case(TargetDevice::AnyDevice)]
#[test_case(TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
fn test_set_device(device: TargetDevice<&MultipleDevicesId>) {
    let fake_sockets = FakeSockets::new(MultipleDevicesId::all());
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(fake_sockets));
    let mut api = ctx.device_socket_api();

    let bound = api.create(Default::default());
    api.set_device(&bound, device.clone());

    let expected_info = SocketInfo { device: device.with_weak_id(), protocol: None };
    assert_eq!(api.get_info(&bound), expected_info);

    let FakeSockets { any_device_sockets: _, all_sockets: _, device_sockets } =
        api.core_ctx().get_ref();
    if let TargetDevice::SpecificDevice(d) = device {
        let DeviceSockets(socket_ids) = device_sockets.get(&d).expect("device state exists");
        assert_eq!(socket_ids, &HashSet::from([bound]));
    }
}
/// Moving a socket from device A to device B leaves it registered only with B.
#[test]
fn update_device() {
    let fake_sockets = FakeSockets::new(MultipleDevicesId::all());
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(fake_sockets));
    let mut api = ctx.device_socket_api();

    let bound = api.create(Default::default());
    for device in [&MultipleDevicesId::A, &MultipleDevicesId::B] {
        api.set_device(&bound, TargetDevice::SpecificDevice(device));
    }

    let expected_info = SocketInfo {
        device: TargetDevice::SpecificDevice(FakeWeakDeviceId(MultipleDevicesId::B)),
        protocol: None,
    };
    assert_eq!(api.get_info(&bound), expected_info);

    let FakeSockets { any_device_sockets: _, all_sockets: _, device_sockets } =
        api.core_ctx().get_ref();
    let device_socket_lists = device_sockets
        .iter()
        .map(|(d, DeviceSockets(indexes))| (d, indexes.iter().collect()))
        .collect::<HashMap<_, _>>();
    let expected_lists = HashMap::from([
        (&MultipleDevicesId::A, vec![]),
        (&MultipleDevicesId::B, vec![&bound]),
        (&MultipleDevicesId::C, vec![]),
    ]);
    assert_eq!(device_socket_lists, expected_lists);
}
/// Several sockets can be given the same device and protocol, report it back
/// via `get_info`, and then all be removed.
#[test_case(Protocol::All, TargetDevice::AnyDevice)]
#[test_case(Protocol::Specific(SOME_PROTOCOL), TargetDevice::AnyDevice)]
#[test_case(Protocol::All, TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
#[test_case(
    Protocol::Specific(SOME_PROTOCOL),
    TargetDevice::SpecificDevice(&MultipleDevicesId::A)
)]
fn create_set_device_and_protocol_remove_multiple(
    protocol: Protocol,
    device: TargetDevice<&MultipleDevicesId>,
) {
    let fake_sockets = FakeSockets::new(MultipleDevicesId::all());
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(fake_sockets));
    let mut api = ctx.device_socket_api();

    let sockets = [(); 3].map(|()| api.create(Default::default()));
    for socket in &sockets {
        api.set_device_and_protocol(socket, device.clone(), protocol);
        let expected = SocketInfo { device: device.with_weak_id(), protocol: Some(protocol) };
        assert_eq!(api.get_info(socket), expected);
    }

    for socket in sockets {
        api.remove(socket)
    }
}
// Binds a socket to a device, marks that device as removed, and then checks
// that the socket can still be rebound to a different device.
#[test]
fn change_device_after_removal() {
    let device_to_remove = FakeReferencyDeviceId::default();
    let device_to_maintain = FakeReferencyDeviceId::default();
    let fake_sockets =
        FakeSockets::new([device_to_remove.clone(), device_to_maintain.clone()]);
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(fake_sockets));
    let mut api = ctx.device_socket_api();

    let socket = api.create(Default::default());
    api.set_device(&socket, TargetDevice::SpecificDevice(&device_to_remove));

    // The original device goes away while the socket is still bound to it.
    device_to_remove.mark_removed();

    api.set_device(&socket, TargetDevice::SpecificDevice(&device_to_maintain));
    let expected_info = SocketInfo {
        device: TargetDevice::SpecificDevice(FakeWeakDeviceId(device_to_maintain.clone())),
        protocol: None,
    };
    assert_eq!(api.get_info(&socket), expected_info);

    let FakeSockets { device_sockets, any_device_sockets: _, all_sockets: _ } =
        api.core_ctx().get_ref();
    let DeviceSockets(sockets_on_device) =
        device_sockets.get(&device_to_maintain).expect("device state exists");
    assert_eq!(sockets_on_device, &HashSet::from([socket]));
}
/// Namespace for the canned Ethernet frame shared by the frame-delivery tests.
struct TestData;

impl TestData {
    /// Source MAC address encoded in [`Self::BUFFER`].
    const SRC_MAC: Mac = Mac::new([0, 1, 2, 3, 4, 5]);
    /// Destination MAC address encoded in [`Self::BUFFER`].
    const DST_MAC: Mac = Mac::new([6, 7, 8, 9, 10, 11]);
    /// EtherType value encoded in [`Self::BUFFER`].
    const PROTO: NonZeroU16 = const_unwrap_option(NonZeroU16::new(0x08AB));
    /// Payload bytes encoded in [`Self::BUFFER`].
    const BODY: &'static [u8] = b"some pig";
    /// The complete serialized frame: destination MAC, source MAC, EtherType,
    /// then the body.
    const BUFFER: &'static [u8] = &[
        // Destination MAC.
        6, 7, 8, 9, 10, 11,
        // Source MAC.
        0, 1, 2, 3, 4, 5,
        // EtherType.
        0x08, 0xAB,
        // Body.
        b's', b'o', b'm', b'e', b' ', b'p', b'i', b'g',
    ];

    /// Parses [`Self::BUFFER`] into an Ethernet frame without length checks.
    ///
    /// # Panics
    ///
    /// Panics if the buffer fails to parse.
    fn frame() -> packet_formats::ethernet::EthernetFrame<&'static [u8]> {
        let mut bytes: &'static [u8] = Self::BUFFER;
        let parsed = packet_formats::ethernet::EthernetFrame::parse(
            &mut bytes,
            EthernetFrameLengthCheck::NoCheck,
        );
        parsed.unwrap()
    }
}
/// An EtherType value (0x08ff) that differs from `TestData::PROTO` (0x08AB).
/// The delivery tests bind sockets to it and expect them to receive nothing.
const WRONG_PROTO: NonZeroU16 = const_unwrap_option(NonZeroU16::new(0x08ff));
/// Creates a socket holding `state` and binds it to `device`.
///
/// When `protocol` is `Some`, the device and protocol are set together;
/// otherwise only the device is set. Returns the ID of the new socket.
fn make_bound<D: FakeStrongDeviceId>(
    ctx: &mut FakeCtx<D>,
    device: TargetDevice<D>,
    protocol: Option<Protocol>,
    state: ExternalSocketState<D>,
) -> FakeStrongId {
    let mut api = ctx.device_socket_api();
    let socket = api.create(state);

    // The API takes the target device by reference, so re-wrap the owned
    // target around a borrow of its device.
    let borrowed_device = match &device {
        TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d),
        TargetDevice::AnyDevice => TargetDevice::AnyDevice,
    };

    if let Some(protocol) = protocol {
        api.set_device_and_protocol(&socket, borrowed_device, protocol);
    } else {
        api.set_device(&socket, borrowed_device);
    }
    socket
}
/// Hands `delivered_frame` (with `TestData::BUFFER` as its raw bytes) to the
/// socket handler for device `MultipleDevicesId::A`, consuming the context.
///
/// Returns the IDs of all sockets whose external state holds at least one
/// frame afterwards.
///
/// # Panics
///
/// Panics if any such socket holds anything other than exactly the one
/// delivered frame.
fn deliver_one_frame(
    delivered_frame: Frame<&[u8]>,
    FakeCtx { mut core_ctx, mut bindings_ctx }: FakeCtx<MultipleDevicesId>,
) -> HashSet<FakeStrongId> {
    DeviceSocketHandler::handle_frame(
        &mut core_ctx,
        &mut bindings_ctx,
        &MultipleDevicesId::A,
        delivered_frame.clone(),
        TestData::BUFFER,
    );

    let FakeSockets { all_sockets, any_device_sockets: _, device_sockets: _ } =
        core_ctx.into_state();
    all_sockets
        .into_iter()
        .filter_map(|(key, (ExternalSocketState(received), _)): (_, (_, Target<_>))| {
            let received = received.into_inner();
            // Sockets that saw nothing are not part of the result.
            if received.is_empty() {
                return None;
            }
            assert_eq!(
                received,
                &[ReceivedFrame {
                    device: MultipleDevicesId::A,
                    frame: delivered_frame.cloned(),
                    raw: TestData::BUFFER.into(),
                }]
            );
            Some(FakeStrongId(key))
        })
        .collect()
}
// Delivers one received frame on device A to sockets bound with every
// combination of {A, B, any device} x {no, all, matching, wrong protocol} and
// checks that only the sockets on A or any device with the "all" or matching
// protocol get it.
#[test]
fn receive_frame_deliver_to_multiple() {
    use Protocol::*;
    use TargetDevice::*;

    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));

    let never_bound = {
        let state = ExternalSocketState::<MultipleDevicesId>::default();
        ctx.device_socket_api().create(state)
    };

    // Shorthand for creating a bound socket with default external state.
    let mut bind = |device, protocol| {
        make_bound(
            &mut ctx,
            device,
            protocol,
            ExternalSocketState::<MultipleDevicesId>::default(),
        )
    };

    let bound_a_no_protocol = bind(SpecificDevice(MultipleDevicesId::A), None);
    let bound_a_all_protocols = bind(SpecificDevice(MultipleDevicesId::A), Some(All));
    let bound_a_right_protocol =
        bind(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
    let bound_a_wrong_protocol =
        bind(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));

    let bound_b_no_protocol = bind(SpecificDevice(MultipleDevicesId::B), None);
    let bound_b_all_protocols = bind(SpecificDevice(MultipleDevicesId::B), Some(All));
    let bound_b_right_protocol =
        bind(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
    let bound_b_wrong_protocol =
        bind(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));

    let bound_any_no_protocol = bind(AnyDevice, None);
    let bound_any_all_protocols = bind(AnyDevice, Some(All));
    let bound_any_right_protocol = bind(AnyDevice, Some(Specific(TestData::PROTO)));
    let bound_any_wrong_protocol = bind(AnyDevice, Some(Specific(WRONG_PROTO)));

    let frame: Frame<&[u8]> = super::ReceivedFrame::from_ethernet(
        TestData::frame(),
        FrameDestination::Individual { local: true },
    )
    .into();
    let mut recipients = deliver_one_frame(frame, ctx);

    // These sockets are expected to have received nothing; name them here so
    // the bindings above count as used.
    let _ = (
        never_bound,
        bound_a_no_protocol,
        bound_a_wrong_protocol,
        bound_b_no_protocol,
        bound_b_all_protocols,
        bound_b_right_protocol,
        bound_b_wrong_protocol,
        bound_any_no_protocol,
        bound_any_wrong_protocol,
    );

    // Each expected recipient must be present, and no one else.
    for expected in [
        &bound_a_all_protocols,
        &bound_a_right_protocol,
        &bound_any_all_protocols,
        &bound_any_right_protocol,
    ] {
        assert!(recipients.remove(expected));
    }
    assert!(recipients.is_empty());
}
// Delivers one sent frame on device A to sockets bound with every combination
// of {A, B, any device} x {no, all, same, wrong protocol} and checks that only
// the "all protocols" sockets on A or any device get it.
#[test]
fn sent_frame_deliver_to_multiple() {
    use Protocol::*;
    use TargetDevice::*;

    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));

    let never_bound = {
        let state = ExternalSocketState::<MultipleDevicesId>::default();
        ctx.device_socket_api().create(state)
    };

    // Shorthand for creating a bound socket with default external state.
    let mut bind = |device, protocol| {
        make_bound(
            &mut ctx,
            device,
            protocol,
            ExternalSocketState::<MultipleDevicesId>::default(),
        )
    };

    let bound_a_no_protocol = bind(SpecificDevice(MultipleDevicesId::A), None);
    let bound_a_all_protocols = bind(SpecificDevice(MultipleDevicesId::A), Some(All));
    let bound_a_same_protocol =
        bind(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
    let bound_a_wrong_protocol =
        bind(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));

    let bound_b_no_protocol = bind(SpecificDevice(MultipleDevicesId::B), None);
    let bound_b_all_protocols = bind(SpecificDevice(MultipleDevicesId::B), Some(All));
    let bound_b_same_protocol =
        bind(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
    let bound_b_wrong_protocol =
        bind(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));

    let bound_any_no_protocol = bind(AnyDevice, None);
    let bound_any_all_protocols = bind(AnyDevice, Some(All));
    let bound_any_same_protocol = bind(AnyDevice, Some(Specific(TestData::PROTO)));
    let bound_any_wrong_protocol = bind(AnyDevice, Some(Specific(WRONG_PROTO)));

    let frame: Frame<&[u8]> = SentFrame::Ethernet(TestData::frame().into()).into();
    let mut recipients = deliver_one_frame(frame, ctx);

    // These sockets are expected to have received nothing; name them here so
    // the bindings above count as used.
    let _ = (
        never_bound,
        bound_a_no_protocol,
        bound_a_same_protocol,
        bound_a_wrong_protocol,
        bound_b_no_protocol,
        bound_b_all_protocols,
        bound_b_same_protocol,
        bound_b_wrong_protocol,
        bound_any_no_protocol,
        bound_any_same_protocol,
        bound_any_wrong_protocol,
    );

    // Each expected recipient must be present, and no one else.
    for expected in [&bound_a_all_protocols, &bound_any_all_protocols] {
        assert!(recipients.remove(expected));
    }
    assert!(recipients.is_empty());
}
// Delivers the same received frame repeatedly to a single socket bound to any
// device and all protocols, then checks that the socket recorded every copy.
#[test]
fn deliver_multiple_frames() {
    const RECEIVE_COUNT: usize = 10;

    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));
    let socket = make_bound(
        &mut ctx,
        TargetDevice::AnyDevice,
        Some(Protocol::All),
        ExternalSocketState::default(),
    );
    let FakeCtx { mut core_ctx, mut bindings_ctx } = ctx;

    for _ in 0..RECEIVE_COUNT {
        let frame = super::ReceivedFrame::from_ethernet(
            TestData::frame(),
            FrameDestination::Individual { local: true },
        );
        DeviceSocketHandler::handle_frame(
            &mut core_ctx,
            &mut bindings_ctx,
            &MultipleDevicesId::A,
            frame.into(),
            TestData::BUFFER,
        );
    }

    let FakeSockets { mut all_sockets, any_device_sockets: _, device_sockets: _ } =
        core_ctx.into_state();
    let FakeStrongId(index) = socket;
    let (ExternalSocketState(received), _): (_, Target<_>) = all_sockets.remove(index).unwrap();

    // What every recorded entry should look like.
    let expected_frame = ReceivedFrame {
        device: MultipleDevicesId::A,
        frame: Frame::Received(super::ReceivedFrame::Ethernet {
            destination: FrameDestination::Individual { local: true },
            frame: EthernetFrame {
                src_mac: TestData::SRC_MAC,
                dst_mac: TestData::DST_MAC,
                ethertype: Some(TestData::PROTO.get().into()),
                body: Vec::from(TestData::BODY),
            },
        }),
        raw: TestData::BUFFER.into(),
    };
    assert_eq!(received.into_inner(), vec![expected_frame; RECEIVE_COUNT]);

    // The socket under test was the only one in the map.
    assert!(all_sockets.is_empty());
}
// Creates sockets through the real (non-fake) core API in three states (never
// bound, bound to any device, bound to a specific device) and then drops
// their IDs. Presumably a smoke test that dropping real IDs is safe.
#[test]
fn drop_real_ids() {
    use crate::testutil::{FakeEventDispatcherBuilder, FAKE_CONFIG_V4};

    let (mut ctx, device_ids) = FakeEventDispatcherBuilder::from_config(FAKE_CONFIG_V4).build();
    let mut api = ctx.core_api().device_socket();

    let never_bound = api.create(Mutex::default());

    let bound_any_device = api.create(Mutex::default());
    api.set_device(&bound_any_device, TargetDevice::AnyDevice);

    let bound_specific_device = api.create(Mutex::default());
    api.set_device(
        &bound_specific_device,
        TargetDevice::SpecificDevice(&DeviceId::Ethernet(device_ids[0].clone())),
    );

    // Drop in creation order, matching the field order of a tuple drop.
    drop(never_bound);
    drop(bound_any_device);
    drop(bound_specific_device);
}
}