pub(crate) mod api;
pub(crate) mod counters;
pub(crate) mod packet_queue;
pub(crate) mod route;
pub(crate) mod state;
use core::sync::atomic::Ordering;
use net_types::ip::{GenericOverIp, Ip, IpVersionMarker};
use netstack3_base::{
AnyDevice, AtomicInstant, CounterContext, DeviceIdContext, EventContext, FrameDestination,
HandleableTimer, InstantBindingsTypes, InstantContext, TimerBindingsTypes, TimerContext,
WeakDeviceIdentifier,
};
use packet_formats::ip::IpPacket;
use zerocopy::SplitByteSlice;
use crate::internal::multicast_forwarding::counters::MulticastForwardingCounters;
use crate::internal::multicast_forwarding::packet_queue::QueuePacketOutcome;
use crate::internal::multicast_forwarding::route::{
Action, MulticastRouteEntry, MulticastRouteTargets,
};
use crate::multicast_forwarding::{
MulticastForwardingPendingPacketsContext, MulticastForwardingState,
MulticastForwardingStateContext, MulticastRoute, MulticastRouteKey,
MulticastRouteTableContext as _,
};
use crate::{IpLayerEvent, IpLayerIpExt};
pub trait MulticastForwardingBindingsTypes: InstantBindingsTypes + TimerBindingsTypes {}
impl<BT: InstantBindingsTypes + TimerBindingsTypes> MulticastForwardingBindingsTypes for BT {}
pub trait MulticastForwardingBindingsContext<I: IpLayerIpExt, D>:
MulticastForwardingBindingsTypes + InstantContext + TimerContext + EventContext<IpLayerEvent<D, I>>
{
}
impl<
I: IpLayerIpExt,
D,
BC: MulticastForwardingBindingsTypes
+ InstantContext
+ TimerContext
+ EventContext<IpLayerEvent<D, I>>,
> MulticastForwardingBindingsContext<I, D> for BC
{
}
pub trait MulticastForwardingDeviceContext<I: IpLayerIpExt>: DeviceIdContext<AnyDevice> {
fn is_device_multicast_forwarding_enabled(&mut self, dev: &Self::DeviceId) -> bool;
}
#[derive(Clone, Debug, Eq, GenericOverIp, Hash, PartialEq)]
#[generic_over_ip(I, Ip)]
pub enum MulticastForwardingTimerId<I: Ip> {
PendingPacketsGc(IpVersionMarker<I>),
}
impl<
I: IpLayerIpExt,
BC: MulticastForwardingBindingsContext<I, CC::DeviceId>,
CC: MulticastForwardingStateContext<I, BC> + CounterContext<MulticastForwardingCounters<I>>,
> HandleableTimer<CC, BC> for MulticastForwardingTimerId<I>
{
fn handle(self, core_ctx: &mut CC, bindings_ctx: &mut BC, _: BC::UniqueTimerId) {
match self {
MulticastForwardingTimerId::PendingPacketsGc(_) => {
core_ctx.with_state(|state, ctx| match state {
MulticastForwardingState::Disabled => {}
MulticastForwardingState::Enabled(state) => {
ctx.increment(|counters: &MulticastForwardingCounters<I>| {
&counters.pending_table_gc
});
let removed_count = ctx.with_pending_table_mut(state, |pending_table| {
pending_table.run_garbage_collection(bindings_ctx)
});
ctx.add(removed_count, |counters: &MulticastForwardingCounters<I>| {
&counters.pending_packet_drops_gc
});
}
})
}
}
}
}
#[derive(Debug, Eq, Hash, PartialEq, GenericOverIp)]
#[generic_over_ip(I, Ip)]
pub enum MulticastForwardingEvent<I: IpLayerIpExt, D> {
MissingRoute {
key: MulticastRouteKey<I>,
input_interface: D,
},
WrongInputInterface {
key: MulticastRouteKey<I>,
actual_input_interface: D,
expected_input_interface: D,
},
}
impl<I: IpLayerIpExt, D> MulticastForwardingEvent<I, D> {
pub(crate) fn map_device<O, F: Fn(D) -> O>(self, map: F) -> MulticastForwardingEvent<I, O> {
match self {
MulticastForwardingEvent::MissingRoute { key, input_interface } => {
MulticastForwardingEvent::MissingRoute {
key,
input_interface: map(input_interface),
}
}
MulticastForwardingEvent::WrongInputInterface {
key,
actual_input_interface,
expected_input_interface,
} => MulticastForwardingEvent::WrongInputInterface {
key,
actual_input_interface: map(actual_input_interface),
expected_input_interface: map(expected_input_interface),
},
}
}
}
impl<I: IpLayerIpExt, D: WeakDeviceIdentifier> MulticastForwardingEvent<I, D> {
pub fn upgrade_device_id(self) -> Option<MulticastForwardingEvent<I, D::Strong>> {
match self {
MulticastForwardingEvent::MissingRoute { key, input_interface } => {
Some(MulticastForwardingEvent::MissingRoute {
key,
input_interface: input_interface.upgrade()?,
})
}
MulticastForwardingEvent::WrongInputInterface {
key,
actual_input_interface,
expected_input_interface,
} => Some(MulticastForwardingEvent::WrongInputInterface {
key,
actual_input_interface: actual_input_interface.upgrade()?,
expected_input_interface: expected_input_interface.upgrade()?,
}),
}
}
}
pub(crate) fn lookup_multicast_route_or_stash_packet<I, B, CC, BC>(
core_ctx: &mut CC,
bindings_ctx: &mut BC,
packet: &I::Packet<B>,
dev: &CC::DeviceId,
frame_dst: Option<FrameDestination>,
) -> Option<MulticastRouteTargets<CC::DeviceId>>
where
I: IpLayerIpExt,
B: SplitByteSlice,
CC: MulticastForwardingStateContext<I, BC>
+ MulticastForwardingDeviceContext<I>
+ CounterContext<MulticastForwardingCounters<I>>,
BC: MulticastForwardingBindingsContext<I, CC::DeviceId>,
{
core_ctx.increment(|counters: &MulticastForwardingCounters<I>| &counters.rx);
let key = MulticastRouteKey::new(packet.src_ip(), packet.dst_ip())?;
core_ctx.increment(|counters: &MulticastForwardingCounters<I>| &counters.no_tx_invalid_key);
if !core_ctx.is_device_multicast_forwarding_enabled(dev) {
core_ctx
.increment(|counters: &MulticastForwardingCounters<I>| &counters.no_tx_disabled_dev);
return None;
}
core_ctx.with_state(|state, ctx| {
let Some(state) = state.enabled() else {
ctx.increment(|counters: &MulticastForwardingCounters<I>| {
&counters.no_tx_disabled_stack_wide
});
return None;
};
ctx.with_route_table(state, |route_table, ctx| {
if let Some(MulticastRouteEntry {
route: MulticastRoute { input_interface, action },
stats,
}) = route_table.get(&key)
{
if dev != input_interface {
ctx.increment(|counters: &MulticastForwardingCounters<I>| {
&counters.no_tx_wrong_dev
});
bindings_ctx.on_event(
MulticastForwardingEvent::WrongInputInterface {
key,
actual_input_interface: dev.clone(),
expected_input_interface: input_interface.clone(),
}
.into(),
);
return None;
}
stats.last_used.store_max(bindings_ctx.now(), Ordering::Relaxed);
match action {
Action::Forward(targets) => {
ctx.increment(|counters: &MulticastForwardingCounters<I>| &counters.tx);
return Some(targets.clone());
}
}
}
ctx.increment(|counters: &MulticastForwardingCounters<I>| &counters.pending_packets);
match ctx.with_pending_table_mut(state, |pending_table| {
pending_table.try_queue_packet(bindings_ctx, key.clone(), packet, dev, frame_dst)
}) {
QueuePacketOutcome::QueuedInNewQueue => {
bindings_ctx.on_event(
MulticastForwardingEvent::MissingRoute {
key,
input_interface: dev.clone(),
}
.into(),
);
}
QueuePacketOutcome::QueuedInExistingQueue => {}
QueuePacketOutcome::ExistingQueueFull => {
ctx.increment(|counters: &MulticastForwardingCounters<I>| {
&counters.pending_packet_drops_queue_full
});
}
}
return None;
})
})
}
#[cfg(test)]
mod testutil {
use super::*;
use alloc::collections::HashSet;
use alloc::rc::Rc;
use alloc::vec::Vec;
use core::cell::RefCell;
use derivative::Derivative;
use net_declare::{net_ip_v4, net_ip_v6};
use net_types::ip::{Ipv4, Ipv4Addr, Ipv6, Ipv6Addr, Mtu};
use net_types::{MulticastAddr, SpecifiedAddr};
use netstack3_base::testutil::{FakeStrongDeviceId, MultipleDevicesId};
use netstack3_base::{CoreTimerContext, CounterContext, CtxPair, FrameDestination};
use netstack3_filter::ProofOfEgressCheck;
use packet::{BufferMut, InnerPacketBuilder, Serializer};
use packet_formats::ip::{IpPacketBuilder, IpProto};
use crate::device::IpDeviceSendContext;
use crate::internal::base::DeviceIpLayerMetadata;
use crate::internal::icmp::{IcmpErrorHandler, IcmpHandlerIpExt};
use crate::multicast_forwarding::{
MulticastForwardingApi, MulticastForwardingEnabledState, MulticastForwardingPendingPackets,
MulticastForwardingPendingPacketsContext, MulticastForwardingState, MulticastRouteTable,
MulticastRouteTableContext,
};
use crate::{IpCounters, IpDeviceMtuContext, IpLayerEvent, IpPacketDestination};
pub(crate) trait TestIpExt: IpLayerIpExt {
const SRC1: Self::Addr;
const SRC2: Self::Addr;
const DST1: Self::Addr;
const DST2: Self::Addr;
}
impl TestIpExt for Ipv4 {
const SRC1: Ipv4Addr = net_ip_v4!("192.0.2.1");
const SRC2: Ipv4Addr = net_ip_v4!("192.0.2.2");
const DST1: Ipv4Addr = net_ip_v4!("224.0.1.1");
const DST2: Ipv4Addr = net_ip_v4!("224.0.1.2");
}
impl TestIpExt for Ipv6 {
const SRC1: Ipv6Addr = net_ip_v6!("2001:0DB8::1");
const SRC2: Ipv6Addr = net_ip_v6!("2001:0DB8::2");
const DST1: Ipv6Addr = net_ip_v6!("ff0e::1");
const DST2: Ipv6Addr = net_ip_v6!("ff0e::2");
}
pub(crate) fn new_ip_packet_buf<I: IpLayerIpExt>(
src_addr: I::Addr,
dst_addr: I::Addr,
) -> impl AsRef<[u8]> {
const TTL: u8 = 255;
const IP_BODY: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
IP_BODY
.into_serializer()
.encapsulate(I::PacketBuilder::new(src_addr, dst_addr, TTL, IpProto::Udp.into()))
.serialize_vec_outer()
.unwrap()
}
#[derive(Debug, PartialEq)]
pub(crate) struct SentPacket<I: IpLayerIpExt, D> {
pub(crate) dst: MulticastAddr<I::Addr>,
pub(crate) device: D,
}
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub(crate) struct FakeCoreCtxState<I: IpLayerIpExt, D: FakeStrongDeviceId> {
pub(crate) multicast_forwarding:
Rc<RefCell<MulticastForwardingState<I, D, FakeBindingsCtx<I, D>>>>,
pub(crate) forwarding_enabled_devices: HashSet<D>,
pub(crate) sent_packets: Vec<SentPacket<I, D>>,
counters: IpCounters<I>,
multicast_forwarding_counters: MulticastForwardingCounters<I>,
}
impl<I: IpLayerIpExt, D: FakeStrongDeviceId> FakeCoreCtxState<I, D> {
pub(crate) fn set_multicast_forwarding_enabled_for_dev(&mut self, dev: D, enabled: bool) {
if enabled {
let _: bool = self.forwarding_enabled_devices.insert(dev);
} else {
let _: bool = self.forwarding_enabled_devices.remove(&dev);
}
}
pub(crate) fn take_sent_packets(&mut self) -> Vec<SentPacket<I, D>> {
core::mem::take(&mut self.sent_packets)
}
}
impl<I: IpLayerIpExt, D: FakeStrongDeviceId> CounterContext<IpCounters<I>>
for FakeCoreCtxState<I, D>
{
fn with_counters<O, F: FnOnce(&IpCounters<I>) -> O>(&self, cb: F) -> O {
cb(&self.counters)
}
}
impl<I: IpLayerIpExt, D: FakeStrongDeviceId> CounterContext<MulticastForwardingCounters<I>>
for FakeCoreCtxState<I, D>
{
fn with_counters<O, F: FnOnce(&MulticastForwardingCounters<I>) -> O>(&self, cb: F) -> O {
cb(&self.multicast_forwarding_counters)
}
}
pub(crate) type FakeBindingsCtx<I, D> = netstack3_base::testutil::FakeBindingsCtx<
MulticastForwardingTimerId<I>,
IpLayerEvent<D, I>,
(),
(),
>;
pub(crate) type FakeCoreCtx<I, D> =
netstack3_base::testutil::FakeCoreCtx<FakeCoreCtxState<I, D>, (), D>;
impl<I: IpLayerIpExt, D: FakeStrongDeviceId>
MulticastForwardingStateContext<I, FakeBindingsCtx<I, D>> for FakeCoreCtx<I, D>
{
type Ctx<'a> = FakeCoreCtx<I, D>;
fn with_state<
O,
F: FnOnce(
&MulticastForwardingState<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
&mut Self::Ctx<'_>,
) -> O,
>(
&mut self,
cb: F,
) -> O {
let state = self.state.multicast_forwarding.clone();
let borrow = state.borrow();
cb(&borrow, self)
}
fn with_state_mut<
O,
F: FnOnce(
&mut MulticastForwardingState<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
&mut Self::Ctx<'_>,
) -> O,
>(
&mut self,
cb: F,
) -> O {
let state = self.state.multicast_forwarding.clone();
let mut borrow = state.borrow_mut();
cb(&mut borrow, self)
}
}
impl<I: IpLayerIpExt, D: FakeStrongDeviceId>
MulticastRouteTableContext<I, FakeBindingsCtx<I, D>> for FakeCoreCtx<I, D>
{
type Ctx<'a> = FakeCoreCtx<I, D>;
fn with_route_table<
O,
F: FnOnce(
&MulticastRouteTable<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
&mut Self::Ctx<'_>,
) -> O,
>(
&mut self,
state: &MulticastForwardingEnabledState<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
cb: F,
) -> O {
let route_table = state.route_table().read();
cb(&route_table, self)
}
fn with_route_table_mut<
O,
F: FnOnce(
&mut MulticastRouteTable<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
&mut Self::Ctx<'_>,
) -> O,
>(
&mut self,
state: &MulticastForwardingEnabledState<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
cb: F,
) -> O {
let mut route_table = state.route_table().write();
cb(&mut route_table, self)
}
}
impl<I: IpLayerIpExt, D: FakeStrongDeviceId>
MulticastForwardingPendingPacketsContext<I, FakeBindingsCtx<I, D>> for FakeCoreCtx<I, D>
{
fn with_pending_table_mut<
O,
F: FnOnce(
&mut MulticastForwardingPendingPackets<I, Self::WeakDeviceId, FakeBindingsCtx<I, D>>,
) -> O,
>(
&mut self,
state: &MulticastForwardingEnabledState<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
cb: F,
) -> O {
let mut pending_table = state.pending_table().lock();
cb(&mut pending_table)
}
}
impl<I: IpLayerIpExt, D: FakeStrongDeviceId> MulticastForwardingDeviceContext<I>
for FakeCoreCtx<I, D>
{
fn is_device_multicast_forwarding_enabled(&mut self, device_id: &Self::DeviceId) -> bool {
self.state.forwarding_enabled_devices.contains(device_id)
}
}
impl<I: IpLayerIpExt, D: FakeStrongDeviceId>
CoreTimerContext<MulticastForwardingTimerId<I>, FakeBindingsCtx<I, D>>
for FakeCoreCtx<I, D>
{
fn convert_timer(
dispatch_id: MulticastForwardingTimerId<I>,
) -> MulticastForwardingTimerId<I> {
dispatch_id
}
}
impl<I: IpLayerIpExt, D: FakeStrongDeviceId> IpDeviceSendContext<I, FakeBindingsCtx<I, D>>
for FakeCoreCtx<I, D>
{
fn send_ip_frame<S>(
&mut self,
_bindings_ctx: &mut FakeBindingsCtx<I, D>,
device_id: &D,
destination: IpPacketDestination<I, &D>,
_ip_layer_metadata: DeviceIpLayerMetadata,
_body: S,
_egress_proof: ProofOfEgressCheck,
) -> Result<(), netstack3_base::SendFrameError<S>>
where
S: Serializer,
S::Buffer: BufferMut,
{
let dst = match destination {
IpPacketDestination::Multicast(dst) => dst,
dst => panic!("unexpected sent packet: destination={dst:?}"),
};
self.state.sent_packets.push(SentPacket { dst, device: device_id.clone() });
Ok(())
}
}
impl<I: IpLayerIpExt, D: FakeStrongDeviceId> IpDeviceMtuContext<I> for FakeCoreCtx<I, D> {
fn get_mtu(&mut self, _device_id: &Self::DeviceId) -> Mtu {
Mtu::max()
}
}
impl<I: IpLayerIpExt, D: FakeStrongDeviceId> IcmpErrorHandler<I, FakeBindingsCtx<I, D>>
for FakeCoreCtx<I, D>
{
fn send_icmp_error_message<B: BufferMut>(
&mut self,
_bindings_ctx: &mut FakeBindingsCtx<I, D>,
_device: &D,
_frame_dst: Option<FrameDestination>,
_src_ip: <I as IcmpHandlerIpExt>::SourceAddress,
_dst_ip: SpecifiedAddr<I::Addr>,
_original_packet: B,
_error: I::IcmpError,
) {
unimplemented!()
}
}
pub(crate) fn new_api<I: IpLayerIpExt>() -> MulticastForwardingApi<
I,
CtxPair<FakeCoreCtx<I, MultipleDevicesId>, FakeBindingsCtx<I, MultipleDevicesId>>,
> {
MulticastForwardingApi::new(CtxPair::with_core_ctx(FakeCoreCtx::with_state(
Default::default(),
)))
}
pub(crate) fn with_pending_table<I, O, F, CC, BT>(core_ctx: &mut CC, cb: F) -> O
where
I: IpLayerIpExt,
CC: MulticastForwardingStateContext<I, BT>,
BT: MulticastForwardingBindingsTypes,
F: FnOnce(&mut MulticastForwardingPendingPackets<I, CC::WeakDeviceId, BT>) -> O,
{
core_ctx.with_state(|state, ctx| {
let state = state.enabled().unwrap();
ctx.with_route_table(state, |_routing_table, ctx| {
ctx.with_pending_table_mut(state, |pending_table| cb(pending_table))
})
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;
use core::time::Duration;
use ip_test_macro::ip_test;
use netstack3_base::testutil::MultipleDevicesId;
use packet::ParseBuffer;
use test_case::test_case;
use testutil::TestIpExt;
use crate::internal::multicast_forwarding::route::MulticastRouteStats;
use crate::multicast_forwarding::MulticastRouteTarget;
struct LookupTestCase {
enabled: bool,
dev_enabled: bool,
right_key: bool,
right_dev: bool,
}
const LOOKUP_SUCCESS_CASE: LookupTestCase =
LookupTestCase { enabled: true, dev_enabled: true, right_key: true, right_dev: true };
#[ip_test(I)]
#[test_case(LOOKUP_SUCCESS_CASE => true; "success")]
#[test_case(LookupTestCase{enabled: false, ..LOOKUP_SUCCESS_CASE} => false; "disabled")]
#[test_case(LookupTestCase{dev_enabled: false, ..LOOKUP_SUCCESS_CASE} => false; "dev_disabled")]
#[test_case(LookupTestCase{right_key: false, ..LOOKUP_SUCCESS_CASE} => false; "wrong_key")]
#[test_case(LookupTestCase{right_dev: false, ..LOOKUP_SUCCESS_CASE} => false; "wrong_dev")]
fn lookup_route<I: TestIpExt>(test_case: LookupTestCase) -> bool {
let LookupTestCase { enabled, dev_enabled, right_key, right_dev } = test_case;
const FRAME_DST: Option<FrameDestination> = None;
let mut api = testutil::new_api::<I>();
let expected_key = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
let actual_key = if right_key {
expected_key.clone()
} else {
MulticastRouteKey::new(I::SRC2, I::DST2).unwrap()
};
let expected_dev = MultipleDevicesId::A;
let actual_dev = if right_dev { expected_dev } else { MultipleDevicesId::B };
if enabled {
assert!(api.enable());
assert_eq!(
api.add_multicast_route(
expected_key.clone(),
MulticastRoute::new_forward(
expected_dev,
[MulticastRouteTarget {
output_interface: MultipleDevicesId::C,
min_ttl: 0
}]
.into()
)
.unwrap()
),
Ok(None)
);
}
api.core_ctx().state.set_multicast_forwarding_enabled_for_dev(actual_dev, dev_enabled);
let (core_ctx, bindings_ctx) = api.contexts();
let creation_time = bindings_ctx.now();
bindings_ctx.timers.instant.sleep(Duration::from_secs(5));
let lookup_time = bindings_ctx.now();
assert!(lookup_time > creation_time);
let buf = testutil::new_ip_packet_buf::<I>(actual_key.src_addr(), actual_key.dst_addr());
let mut buf_ref = buf.as_ref();
let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
let route = lookup_multicast_route_or_stash_packet(
core_ctx,
bindings_ctx,
&packet,
&actual_dev,
FRAME_DST,
);
let mut expected_events = vec![];
if !right_key {
expected_events.push(IpLayerEvent::MulticastForwarding(
MulticastForwardingEvent::MissingRoute {
key: actual_key.clone(),
input_interface: actual_dev,
},
));
}
if !right_dev {
expected_events.push(IpLayerEvent::MulticastForwarding(
MulticastForwardingEvent::WrongInputInterface {
key: actual_key,
actual_input_interface: actual_dev,
expected_input_interface: expected_dev,
},
));
}
assert_eq!(bindings_ctx.take_events(), expected_events);
let lookup_succeeded = route.is_some();
if enabled {
let expected_stats = if lookup_succeeded {
MulticastRouteStats { last_used: lookup_time }
} else {
MulticastRouteStats { last_used: creation_time }
};
assert_eq!(api.get_route_stats(&expected_key), Ok(Some(expected_stats)));
}
api.core_ctx().with_counters(|counters: &MulticastForwardingCounters<I>| {
assert_eq!(counters.rx.get(), 1);
assert_eq!(counters.tx.get(), if lookup_succeeded { 1 } else { 0 });
assert_eq!(counters.no_tx_disabled_dev.get(), if dev_enabled { 0 } else { 1 });
assert_eq!(counters.no_tx_disabled_stack_wide.get(), if enabled { 0 } else { 1 });
assert_eq!(counters.no_tx_wrong_dev.get(), if right_dev { 0 } else { 1 });
assert_eq!(counters.pending_packets.get(), if right_key { 0 } else { 1 });
});
lookup_succeeded
}
}