use alloc::vec::Vec;
use lock_order::{
lock::{LockFor, RwLockFor},
relation::LockBefore,
Locked,
};
use net_types::{ethernet::Mac, ip::IpAddress, SpecifiedAddr};
use packet::{Buf, Buffer as _, BufferMut, Serializer};
use packet_formats::ethernet::{
EtherType, EthernetFrame, EthernetFrameBuilder, EthernetFrameLengthCheck, EthernetIpExt,
};
use tracing::trace;
use crate::{
context::{CounterContext, SendFrameContext},
device::{
self,
id::{BaseDeviceId, BasePrimaryDeviceId, BaseWeakDeviceId},
queue::{
rx::{
ReceiveDequeContext, ReceiveDequeFrameContext, ReceiveQueue, ReceiveQueueContext,
ReceiveQueueHandler, ReceiveQueueNonSyncContext, ReceiveQueueState,
ReceiveQueueTypes,
},
tx::{
BufVecU8Allocator, TransmitDequeueContext, TransmitQueue, TransmitQueueCommon,
TransmitQueueContext, TransmitQueueHandler, TransmitQueueNonSyncContext,
TransmitQueueState,
},
DequeueState, ReceiveQueueFullError, TransmitQueueFrameError,
},
socket::{
DatagramHeader, DeviceSocketHandler, DeviceSocketMetadata, HeldDeviceSockets,
ParseSentFrameError, ReceivedFrame, SentFrame,
},
state::{DeviceStateSpec, IpLinkDeviceState},
Device, DeviceCounters, DeviceIdContext, DeviceLayerEventDispatcher, DeviceLayerTypes,
DeviceSendFrameError, FrameDestination, Mtu,
},
ip::types::RawMetric,
NonSyncContext, SyncCtx,
};
/// MAC address the loopback device uses as the source of every Ethernet frame
/// it builds, and as the destination of IP frames sent over it.
const LOOPBACK_MAC: Mac = Mac::UNSPECIFIED;
/// Weak identifier for a loopback device; upgrading it to a strong ID may fail.
pub type LoopbackWeakDeviceId<C> = BaseWeakDeviceId<LoopbackDevice, C>;
/// Strong identifier for a loopback device.
pub type LoopbackDeviceId<C> = BaseDeviceId<LoopbackDevice, C>;
/// Primary identifier for a loopback device.
pub(crate) type LoopbackPrimaryDeviceId<C> = BasePrimaryDeviceId<LoopbackDevice, C>;
/// Type-level marker for the loopback device.
///
/// The enum has no variants and so can never be constructed. It is only used
/// as a type parameter that selects loopback-specific implementations.
#[derive(Copy, Clone)]
pub enum LoopbackDevice {}
/// Marks [`LoopbackDevice`] as a device kind.
impl Device for LoopbackDevice {}
/// Describes how loopback device state is laid out and identified.
impl DeviceStateSpec for LoopbackDevice {
    const IS_LOOPBACK: bool = true;
    const DEBUG_TYPE: &'static str = "Loopback";

    /// Link-layer state kept by the core (MTU, metric, RX/TX queues).
    type Link<C: DeviceLayerTypes> = LoopbackDeviceState;
    /// Extra per-device state whose type is chosen by the `DeviceLayerTypes`
    /// implementation (`C::LoopbackDeviceState`).
    type External<C: DeviceLayerTypes> = C::LoopbackDeviceState;
}
/// Lets a locked `SyncCtx` refer to loopback devices through their strong and
/// weak identifiers.
impl<NonSyncCtx: NonSyncContext, L> DeviceIdContext<LoopbackDevice>
    for Locked<&SyncCtx<NonSyncCtx>, L>
{
    type WeakDeviceId = LoopbackWeakDeviceId<NonSyncCtx>;
    type DeviceId = LoopbackDeviceId<NonSyncCtx>;

    /// Tries to obtain a strong ID from `weak_device_id`. Returns `None` when
    /// the upgrade fails.
    fn upgrade_weak_device_id(
        &self,
        weak_device_id: &Self::WeakDeviceId,
    ) -> Option<Self::DeviceId> {
        weak_device_id.upgrade()
    }

    /// Returns a weak ID for the same device as `device_id`.
    fn downgrade_device_id(&self, device_id: &Self::DeviceId) -> Self::WeakDeviceId {
        device_id.downgrade()
    }
}
/// Link-layer state of a loopback device.
pub struct LoopbackDeviceState {
    /// MTU reported for the device.
    mtu: Mtu,
    /// Routing metric reported for the device.
    metric: RawMetric,
    /// Received frames waiting to be processed; frames carry no metadata (`()`).
    rx_queue: ReceiveQueue<(), Buf<Vec<u8>>>,
    /// Frames waiting to be transmitted; frames carry no metadata (`()`).
    tx_queue: TransmitQueue<(), Buf<Vec<u8>>, BufVecU8Allocator>,
}
impl LoopbackDeviceState {
    /// Builds loopback link state with the given `mtu` and routing `metric`.
    /// The RX and TX queues are default-constructed.
    pub(super) fn new(mtu: Mtu, metric: RawMetric) -> LoopbackDeviceState {
        Self { rx_queue: Default::default(), tx_queue: Default::default(), mtu, metric }
    }
}
/// Grants access to the loopback RX queue state at the `LoopbackRxQueue`
/// lock level.
impl<C: NonSyncContext> LockFor<crate::lock_ordering::LoopbackRxQueue>
    for IpLinkDeviceState<LoopbackDevice, C>
{
    type Data = ReceiveQueueState<(), Buf<Vec<u8>>>;
    // Inside this impl, `Self::Data` names the `Data` type declared just above.
    type Guard<'l> = crate::sync::LockGuard<'l, Self::Data>
    where
        Self: 'l;

    fn lock(&self) -> Self::Guard<'_> {
        let rx_queue = &self.link.rx_queue;
        rx_queue.queue.lock()
    }
}
/// Grants access to the loopback RX dequeue state at the `LoopbackRxDequeue`
/// lock level.
impl<C: NonSyncContext> LockFor<crate::lock_ordering::LoopbackRxDequeue>
    for IpLinkDeviceState<LoopbackDevice, C>
{
    type Data = DequeueState<(), Buf<Vec<u8>>>;
    // Inside this impl, `Self::Data` names the `Data` type declared just above.
    type Guard<'l> = crate::sync::LockGuard<'l, Self::Data>
    where
        Self: 'l;

    fn lock(&self) -> Self::Guard<'_> {
        let rx_queue = &self.link.rx_queue;
        rx_queue.deque.lock()
    }
}
/// Grants access to the loopback TX queue state at the `LoopbackTxQueue`
/// lock level.
impl<C: NonSyncContext> LockFor<crate::lock_ordering::LoopbackTxQueue>
    for IpLinkDeviceState<LoopbackDevice, C>
{
    type Data = TransmitQueueState<(), Buf<Vec<u8>>, BufVecU8Allocator>;
    // Inside this impl, `Self::Data` names the `Data` type declared just above.
    type Guard<'l> = crate::sync::LockGuard<'l, Self::Data>
    where
        Self: 'l;

    fn lock(&self) -> Self::Guard<'_> {
        let tx_queue = &self.link.tx_queue;
        tx_queue.queue.lock()
    }
}
/// Grants access to the loopback TX dequeue state at the `LoopbackTxDequeue`
/// lock level.
impl<C: NonSyncContext> LockFor<crate::lock_ordering::LoopbackTxDequeue>
    for IpLinkDeviceState<LoopbackDevice, C>
{
    type Data = DequeueState<(), Buf<Vec<u8>>>;
    // Inside this impl, `Self::Data` names the `Data` type declared just above.
    type Guard<'l> = crate::sync::LockGuard<'l, Self::Data>
    where
        Self: 'l;

    fn lock(&self) -> Self::Guard<'_> {
        let tx_queue = &self.link.tx_queue;
        tx_queue.deque.lock()
    }
}
/// Grants read and write access to the device sockets held by a loopback
/// device at the `DeviceSockets` lock level.
impl<C: NonSyncContext> RwLockFor<crate::lock_ordering::DeviceSockets>
    for IpLinkDeviceState<LoopbackDevice, C>
{
    type Data = HeldDeviceSockets<C>;
    // Inside this impl, `Self::Data` names the `Data` type declared just above.
    type ReadGuard<'l> = crate::sync::RwLockReadGuard<'l, Self::Data>
    where
        Self: 'l;
    type WriteGuard<'l> = crate::sync::RwLockWriteGuard<'l, Self::Data>
    where
        Self: 'l;

    fn read_lock(&self) -> Self::ReadGuard<'_> {
        let sockets = &self.sockets;
        sockets.read()
    }

    fn write_lock(&self) -> Self::WriteGuard<'_> {
        let sockets = &self.sockets;
        sockets.write()
    }
}
/// Sends frames written by device sockets out over a loopback device.
impl<C: NonSyncContext, L: LockBefore<crate::lock_ordering::LoopbackTxQueue>>
    SendFrameContext<C, DeviceSocketMetadata<LoopbackDeviceId<C>>> for Locked<&SyncCtx<C>, L>
{
    /// Sends `body` on the device named in `metadata`.
    ///
    /// With a datagram header, `body` is wrapped in an Ethernet frame that
    /// carries the header's destination address and protocol. Without one,
    /// `body` is assumed to already be a complete Ethernet frame and is queued
    /// as-is. On failure, `body` is handed back as `Err`.
    fn send_frame<S>(
        &mut self,
        ctx: &mut C,
        metadata: DeviceSocketMetadata<LoopbackDeviceId<C>>,
        body: S,
    ) -> Result<(), S>
    where
        S: Serializer,
        S::Buffer: BufferMut,
    {
        let DeviceSocketMetadata { device_id, header } = metadata;
        if let Some(DatagramHeader { dest_addr, protocol }) = header {
            send_as_ethernet_frame_to_dst(self, ctx, &device_id, body, protocol, dest_addr)
        } else {
            send_ethernet_frame(self, ctx, &device_id, body)
        }
    }
}
/// Sends an IP `packet` over the loopback device.
///
/// The packet is wrapped in an Ethernet frame whose EtherType matches the IP
/// version of `A` and whose destination is [`LOOPBACK_MAC`]. `_local_addr` is
/// ignored. On failure, the unserialized `packet` is handed back as `Err`.
pub(super) fn send_ip_frame<NonSyncCtx, A, S, L>(
    sync_ctx: &mut Locked<&SyncCtx<NonSyncCtx>, L>,
    ctx: &mut NonSyncCtx,
    device_id: &LoopbackDeviceId<NonSyncCtx>,
    _local_addr: SpecifiedAddr<A>,
    packet: S,
) -> Result<(), S>
where
    NonSyncCtx: NonSyncContext,
    A: IpAddress,
    S: Serializer,
    S::Buffer: BufferMut,
    L: LockBefore<crate::lock_ordering::LoopbackTxQueue>,
    A::Version: EthernetIpExt,
{
    let ether_type = <A::Version as EthernetIpExt>::ETHER_TYPE;
    send_as_ethernet_frame_to_dst(sync_ctx, ctx, device_id, packet, ether_type, LOOPBACK_MAC)
}
/// Wraps `packet` in an Ethernet frame from [`LOOPBACK_MAC`] to `dst_mac` with
/// EtherType `protocol`, then queues it on the loopback device.
///
/// On failure, the Ethernet wrapper is stripped and the original `packet` is
/// handed back as `Err`.
fn send_as_ethernet_frame_to_dst<NonSyncCtx, S, L>(
    sync_ctx: &mut Locked<&SyncCtx<NonSyncCtx>, L>,
    ctx: &mut NonSyncCtx,
    device_id: &LoopbackDeviceId<NonSyncCtx>,
    packet: S,
    protocol: EtherType,
    dst_mac: Mac,
) -> Result<(), S>
where
    NonSyncCtx: NonSyncContext,
    S: Serializer,
    S::Buffer: BufferMut,
    L: LockBefore<crate::lock_ordering::LoopbackTxQueue>,
{
    // The builder gets a minimum body length of zero.
    const MIN_BODY_LEN: usize = 0;
    let builder = EthernetFrameBuilder::new(LOOPBACK_MAC, dst_mac, protocol, MIN_BODY_LEN);
    send_ethernet_frame(sync_ctx, ctx, device_id, packet.encapsulate(builder))
        .map_err(|nested| nested.into_inner())
}
/// Queues an already-built Ethernet `frame` on the loopback device's TX queue
/// and updates the loopback send counters.
///
/// `send_total_frames` is incremented for every call. Exactly one of
/// `send_frame`, `send_queue_full` or `send_serialize_error` is then
/// incremented, depending on the outcome.
///
/// # Errors
///
/// Returns the frame as `Err` if the TX queue is full or serialization fails.
///
/// # Panics
///
/// Panics if the TX queue reports `TransmitQueueFrameError::NoQueue`, which
/// this function treats as impossible for loopback.
fn send_ethernet_frame<L, S, NonSyncCtx>(
    sync_ctx: &mut Locked<&SyncCtx<NonSyncCtx>, L>,
    ctx: &mut NonSyncCtx,
    device_id: &LoopbackDeviceId<NonSyncCtx>,
    frame: S,
) -> Result<(), S>
where
    L: LockBefore<crate::lock_ordering::LoopbackTxQueue>,
    S: Serializer,
    S::Buffer: BufferMut,
    NonSyncCtx: NonSyncContext,
{
    sync_ctx.with_counters(|stats: &DeviceCounters| {
        stats.loopback.common.send_total_frames.increment();
    });

    let queued = TransmitQueueHandler::<LoopbackDevice, _>::queue_tx_frame(
        sync_ctx,
        ctx,
        device_id,
        (),
        frame,
    );

    match queued {
        Ok(()) => {
            sync_ctx.with_counters(|stats: &DeviceCounters| {
                stats.loopback.common.send_frame.increment();
            });
            Ok(())
        }
        Err(TransmitQueueFrameError::QueueFull(rejected)) => {
            sync_ctx.with_counters(|stats: &DeviceCounters| {
                stats.loopback.common.send_queue_full.increment();
            });
            Err(rejected)
        }
        Err(TransmitQueueFrameError::SerializeError(rejected)) => {
            sync_ctx.with_counters(|stats: &DeviceCounters| {
                stats.loopback.common.send_serialize_error.increment();
            });
            Err(rejected)
        }
        Err(TransmitQueueFrameError::NoQueue(_)) => {
            unreachable!("loopback never fails to send a frame")
        }
    }
}
/// Returns the routing metric stored in the loopback device's link state.
pub(super) fn get_routing_metric<NonSyncCtx: NonSyncContext, L>(
    ctx: &mut Locked<&SyncCtx<NonSyncCtx>, L>,
    device_id: &LoopbackDeviceId<NonSyncCtx>,
) -> RawMetric {
    device::integration::with_loopback_state(ctx, device_id, |mut loopback| {
        let metric = loopback.cast_with(|dev| &dev.link.metric);
        metric.copied()
    })
}
/// Returns the MTU stored in the loopback device's link state.
pub(super) fn get_mtu<NonSyncCtx: NonSyncContext, L>(
    ctx: &mut Locked<&SyncCtx<NonSyncCtx>, L>,
    device_id: &LoopbackDeviceId<NonSyncCtx>,
) -> Mtu {
    device::integration::with_loopback_state(ctx, device_id, |mut loopback| {
        let mtu = loopback.cast_with(|dev| &dev.link.mtu);
        mtu.copied()
    })
}
impl<C: NonSyncContext> ReceiveQueueNonSyncContext<LoopbackDevice, LoopbackDeviceId<C>> for C {
fn wake_rx_task(&mut self, device_id: &LoopbackDeviceId<C>) {
DeviceLayerEventDispatcher::wake_rx_task(self, device_id)
}
}
/// Loopback RX queue entries: the raw frame bytes with no metadata.
impl<C: NonSyncContext, L: LockBefore<crate::lock_ordering::LoopbackRxQueue>>
    ReceiveQueueTypes<LoopbackDevice, C> for Locked<&SyncCtx<C>, L>
{
    type Buffer = Buf<Vec<u8>>;
    type Meta = ();
}
/// Gives the RX queue machinery locked access to the loopback RX queue.
impl<C: NonSyncContext, L: LockBefore<crate::lock_ordering::LoopbackRxQueue>>
    ReceiveQueueContext<LoopbackDevice, C> for Locked<&SyncCtx<C>, L>
{
    /// Runs `cb` on the loopback device's RX queue state, which is held at the
    /// `LoopbackRxQueue` lock level for the duration of the callback.
    fn with_receive_queue_mut<
        O,
        F: FnOnce(&mut ReceiveQueueState<Self::Meta, Self::Buffer>) -> O,
    >(
        &mut self,
        device_id: &LoopbackDeviceId<C>,
        cb: F,
    ) -> O {
        device::integration::with_loopback_state(self, device_id, |mut device_state| {
            let mut rx_queue = device_state.lock::<crate::lock_ordering::LoopbackRxQueue>();
            cb(&mut rx_queue)
        })
    }
}
impl<C: NonSyncContext> ReceiveDequeFrameContext<LoopbackDevice, C>
for Locked<&SyncCtx<C>, crate::lock_ordering::LoopbackRxDequeue>
{
fn handle_frame(
&mut self,
ctx: &mut C,
device_id: &LoopbackDeviceId<C>,
(): Self::Meta,
mut buf: Buf<Vec<u8>>,
) {
self.with_counters(|counters: &DeviceCounters| {
counters.loopback.common.recv_frame.increment();
});
let (frame, whole_body) =
match buf.parse_with_view::<_, EthernetFrame<_>>(EthernetFrameLengthCheck::NoCheck) {
Err(e) => {
self.with_counters(|counters: &DeviceCounters| {
counters.loopback.common.recv_parse_error.increment();
});
trace!("dropping invalid ethernet frame over loopback: {:?}", e);
return;
}
Ok(e) => e,
};
let frame_dest = FrameDestination::from_dest(frame.dst_mac(), Mac::UNSPECIFIED);
let ethertype = frame.ethertype();
DeviceSocketHandler::<LoopbackDevice, _>::handle_frame(
self,
ctx,
device_id,
ReceivedFrame::from_ethernet(frame, frame_dest).into(),
whole_body,
);
let ethertype = match ethertype {
Some(e) => e,
None => {
self.with_counters(|counters: &DeviceCounters| {
counters.loopback.recv_no_ethertype.increment();
});
trace!("dropping ethernet frame without ethertype");
return;
}
};
match ethertype {
EtherType::Ipv4 => {
self.with_counters(|counters: &DeviceCounters| {
counters.loopback.common.recv_ip_delivered.increment();
});
crate::ip::receive_ipv4_packet(
self,
ctx,
&device_id.clone().into(),
frame_dest,
buf,
)
}
EtherType::Ipv6 => {
self.with_counters(|counters: &DeviceCounters| {
counters.loopback.common.recv_ip_delivered.increment();
});
crate::ip::receive_ipv6_packet(
self,
ctx,
&device_id.clone().into(),
frame_dest,
buf,
)
}
ethertype @ EtherType::Arp | ethertype @ EtherType::Other(_) => {
self.with_counters(|counters: &DeviceCounters| {
counters.loopback.common.recv_unsupported_ethertype.increment();
});
trace!("not handling loopback frame of type {:?}", ethertype)
}
}
}
}
/// Gives the RX dequeue machinery the dequeued-frame buffer plus a context
/// restricted to the `LoopbackRxDequeue` lock level.
impl<C: NonSyncContext, L: LockBefore<crate::lock_ordering::LoopbackRxDequeue>>
    ReceiveDequeContext<LoopbackDevice, C> for Locked<&SyncCtx<C>, L>
{
    type ReceiveQueueCtx<'a> = Locked<&'a SyncCtx<C>, crate::lock_ordering::LoopbackRxDequeue>;

    /// Locks the loopback RX dequeue state and calls `cb` with it and a sync
    /// context cast to the `LoopbackRxDequeue` lock level.
    fn with_dequed_frames_and_rx_queue_ctx<
        O,
        F: FnOnce(&mut DequeueState<(), Buf<Vec<u8>>>, &mut Self::ReceiveQueueCtx<'_>) -> O,
    >(
        &mut self,
        device_id: &LoopbackDeviceId<C>,
        cb: F,
    ) -> O {
        device::integration::with_loopback_state_and_sync_ctx(
            self,
            device_id,
            |mut device_state, inner_ctx| {
                let mut dequeued =
                    device_state.lock::<crate::lock_ordering::LoopbackRxDequeue>();
                let mut restricted_ctx = inner_ctx.cast_locked();
                cb(&mut dequeued, &mut restricted_ctx)
            },
        )
    }
}
impl<C: NonSyncContext> TransmitQueueNonSyncContext<LoopbackDevice, LoopbackDeviceId<C>> for C {
fn wake_tx_task(&mut self, device_id: &LoopbackDeviceId<C>) {
DeviceLayerEventDispatcher::wake_tx_task(self, &device_id.clone().into())
}
}
/// Loopback TX queue entries: Ethernet frame bytes with no metadata.
impl<C: NonSyncContext, L: LockBefore<crate::lock_ordering::LoopbackTxQueue>>
    TransmitQueueCommon<LoopbackDevice, C> for Locked<&SyncCtx<C>, L>
{
    type Buffer = Buf<Vec<u8>>;
    type Allocator = BufVecU8Allocator;
    type Meta = ();

    /// Parses an outgoing frame; loopback frames are always Ethernet.
    fn parse_outgoing_frame(buf: &[u8]) -> Result<SentFrame<&[u8]>, ParseSentFrameError> {
        SentFrame::try_parse_as_ethernet(buf)
    }
}
impl<C: NonSyncContext, L: LockBefore<crate::lock_ordering::LoopbackTxQueue>>
TransmitQueueContext<LoopbackDevice, C> for Locked<&SyncCtx<C>, L>
{
fn with_transmit_queue_mut<
O,
F: FnOnce(&mut TransmitQueueState<Self::Meta, Self::Buffer, Self::Allocator>) -> O,
>(
&mut self,
device_id: &LoopbackDeviceId<C>,
cb: F,
) -> O {
device::integration::with_loopback_state(self, device_id, |mut state| {
let mut x = state.lock::<crate::lock_ordering::LoopbackTxQueue>();
cb(&mut x)
})
}
fn send_frame(
&mut self,
ctx: &mut C,
device_id: &Self::DeviceId,
meta: Self::Meta,
buf: Self::Buffer,
) -> Result<(), DeviceSendFrameError<(Self::Meta, Self::Buffer)>> {
match ReceiveQueueHandler::queue_rx_frame(self, ctx, device_id, meta, buf) {
Ok(()) => {}
Err(ReceiveQueueFullError(((), _frame))) => {
tracing::error!("dropped RX frame on loopback device due to full RX queue")
}
}
Ok(())
}
}
/// Gives the TX dequeue machinery the dequeued-packet buffer plus a context
/// restricted to the `LoopbackTxDequeue` lock level.
impl<C: NonSyncContext, L: LockBefore<crate::lock_ordering::LoopbackTxDequeue>>
    TransmitDequeueContext<LoopbackDevice, C> for Locked<&SyncCtx<C>, L>
{
    type TransmitQueueCtx<'a> = Locked<&'a SyncCtx<C>, crate::lock_ordering::LoopbackTxDequeue>;

    /// Locks the loopback TX dequeue state and calls `cb` with it and a sync
    /// context cast to the `LoopbackTxDequeue` lock level.
    fn with_dequed_packets_and_tx_queue_ctx<
        O,
        F: FnOnce(&mut DequeueState<Self::Meta, Self::Buffer>, &mut Self::TransmitQueueCtx<'_>) -> O,
    >(
        &mut self,
        device_id: &Self::DeviceId,
        cb: F,
    ) -> O {
        device::integration::with_loopback_state_and_sync_ctx(
            self,
            device_id,
            |mut device_state, inner_ctx| {
                let mut dequeued =
                    device_state.lock::<crate::lock_ordering::LoopbackTxDequeue>();
                let mut restricted_ctx = inner_ctx.cast_locked();
                cb(&mut dequeued, &mut restricted_ctx)
            },
        )
    }
}
// Unit tests for the loopback device: MTU reporting, address add/remove, and
// the Ethernet framing applied to IP packets sent over loopback.
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use assert_matches::assert_matches;
    use ip_test_macro::ip_test;
    use lock_order::{Locked, Unlocked};
    use net_types::ip::{AddrSubnet, AddrSubnetEither, Ip, Ipv4, Ipv6};
    use packet::ParseBuffer;
    use crate::{
        device::{DeviceId, Mtu},
        error::NotFoundError,
        ip::device::{IpAddressId as _, IpDeviceIpExt, IpDeviceStateContext},
        testutil::{
            Ctx, FakeEventDispatcherConfig, FakeNonSyncCtx, TestIpExt, DEFAULT_INTERFACE_METRIC,
        },
        SyncCtx,
    };
    use super::*;

    // MTU used for every loopback device created in these tests.
    const MTU: Mtu = Mtu::new(66);

    // The MTU given when the loopback device is added is the MTU the IP layer
    // reports for it, for both IPv4 and IPv6.
    #[test]
    fn loopback_mtu() {
        let Ctx { sync_ctx, mut non_sync_ctx } = crate::testutil::FakeCtx::default();
        let sync_ctx = &sync_ctx;
        let device = crate::device::add_loopback_device(&sync_ctx, MTU, DEFAULT_INTERFACE_METRIC)
            .expect("error adding loopback device")
            .into();
        crate::device::testutil::enable_device(&sync_ctx, &mut non_sync_ctx, &device);
        assert_eq!(
            crate::ip::IpDeviceContext::<Ipv4, _>::get_mtu(&mut Locked::new(sync_ctx), &device),
            MTU
        );
        assert_eq!(
            crate::ip::IpDeviceContext::<Ipv6, _>::get_mtu(&mut Locked::new(sync_ctx), &device),
            MTU
        );
    }

    // An address can be added to and removed from a loopback device. Removing
    // it a second time reports `NotFoundError`.
    #[ip_test]
    fn test_loopback_add_remove_addrs<I: Ip + TestIpExt + IpDeviceIpExt>()
    where
        for<'a> Locked<&'a SyncCtx<FakeNonSyncCtx>, Unlocked>:
            IpDeviceStateContext<I, FakeNonSyncCtx, DeviceId = DeviceId<FakeNonSyncCtx>>,
    {
        let Ctx { sync_ctx, mut non_sync_ctx } = crate::testutil::FakeCtx::default();
        let sync_ctx = &sync_ctx;
        let device = crate::device::add_loopback_device(&sync_ctx, MTU, DEFAULT_INTERFACE_METRIC)
            .expect("error adding loopback device")
            .into();
        crate::device::testutil::enable_device(&sync_ctx, &mut non_sync_ctx, &device);
        // Collects the `I` addresses currently assigned to `device`.
        let get_addrs = || {
            crate::ip::device::IpDeviceStateContext::<I, _>::with_address_ids(
                &mut Locked::new(sync_ctx),
                &device,
                |addrs, _sync_ctx| addrs.map(|a| a.addr()).collect::<Vec<_>>(),
            )
        };
        let FakeEventDispatcherConfig {
            subnet,
            local_ip,
            local_mac: _,
            remote_ip: _,
            remote_mac: _,
        } = I::FAKE_CONFIG;
        let addr =
            AddrSubnet::from_witness(local_ip, subnet.prefix()).expect("error creating AddrSubnet");
        assert_eq!(get_addrs(), []);
        assert_eq!(
            crate::device::add_ip_addr_subnet(
                sync_ctx,
                &mut non_sync_ctx,
                &device,
                AddrSubnetEither::from(addr)
            ),
            Ok(())
        );
        let addr = addr.addr();
        assert_eq!(&get_addrs()[..], [addr]);
        assert_eq!(crate::device::del_ip_addr(sync_ctx, &mut non_sync_ctx, &device, &addr), Ok(()));
        assert_eq!(get_addrs(), []);
        // Deleting an address that is no longer assigned must fail.
        assert_eq!(
            crate::device::del_ip_addr(sync_ctx, &mut non_sync_ctx, &device, &addr),
            Err(NotFoundError)
        );
    }

    // An IP packet sent over loopback ends up in the RX queue as a single
    // Ethernet frame with unspecified source and destination MACs, the
    // version's EtherType, and a body that starts with the original bytes.
    #[ip_test]
    fn loopback_sends_ethernet<I: Ip + TestIpExt>() {
        let Ctx { sync_ctx, mut non_sync_ctx } = crate::testutil::FakeCtx::default();
        let sync_ctx = &sync_ctx;
        let device = crate::device::add_loopback_device(&sync_ctx, MTU, DEFAULT_INTERFACE_METRIC)
            .expect("error adding loopback device");
        crate::device::testutil::enable_device(
            &sync_ctx,
            &mut non_sync_ctx,
            &device.clone().into(),
        );
        let local_addr = I::FAKE_CONFIG.local_ip;
        const BODY: &[u8] = b"IP body".as_slice();
        let body = Buf::new(Vec::from(BODY), ..);
        send_ip_frame(&mut Locked::new(sync_ctx), &mut non_sync_ctx, &device, local_addr, body)
            .expect("can send");
        // Drain the RX queue directly to inspect what the send path enqueued.
        let mut frames = ReceiveQueueContext::<LoopbackDevice, _>::with_receive_queue_mut(
            &mut Locked::new(sync_ctx),
            &device,
            |queue_state| queue_state.take_frames().map(|((), frame)| frame).collect::<Vec<_>>(),
        );
        let frame = assert_matches!(frames.as_mut_slice(), [frame] => frame);
        let eth = frame
            .parse_with::<_, EthernetFrame<_>>(EthernetFrameLengthCheck::NoCheck)
            .expect("is ethernet");
        assert_eq!(eth.src_mac(), Mac::UNSPECIFIED);
        assert_eq!(eth.dst_mac(), Mac::UNSPECIFIED);
        assert_eq!(eth.ethertype(), Some(I::ETHER_TYPE));
        assert_eq!(&frame.as_ref()[..BODY.len()], BODY);
    }
}