use alloc::collections::btree_map;
use core::sync::atomic::Ordering;
use log::warn;
use net_types::ip::{Ip, IpVersionMarker};
use net_types::SpecifiedAddr;
use netstack3_base::{
AnyDevice, AtomicInstant, ContextPair, CoreTimerContext, CounterContext, DeviceIdContext,
Inspector, InspectorDeviceExt, InstantBindingsTypes, InstantContext, StrongDeviceIdentifier,
WeakDeviceIdentifier,
};
use crate::internal::base::IpLayerForwardingContext;
use crate::internal::multicast_forwarding::counters::MulticastForwardingCounters;
use crate::internal::multicast_forwarding::packet_queue::{PacketQueue, QueuedPacket};
use crate::internal::multicast_forwarding::route::{
Action, MulticastRoute, MulticastRouteEntry, MulticastRouteKey, MulticastRouteStats,
MulticastRouteTarget,
};
use crate::internal::multicast_forwarding::state::{
MulticastForwardingEnabledState, MulticastForwardingPendingPacketsContext as _,
MulticastForwardingState, MulticastForwardingStateContext, MulticastRouteTableContext as _,
};
use crate::internal::multicast_forwarding::{
MulticastForwardingBindingsTypes, MulticastForwardingDeviceContext, MulticastForwardingEvent,
MulticastForwardingTimerId,
};
use crate::{IpLayerBindingsContext, IpLayerIpExt, IpPacketDestination};
#[derive(Debug, Eq, PartialEq)]
pub struct MulticastForwardingDisabledError {}
trait MulticastForwardingStateExt<
I: IpLayerIpExt,
D: StrongDeviceIdentifier,
BT: MulticastForwardingBindingsTypes,
>
{
fn try_enabled(
&self,
) -> Result<&MulticastForwardingEnabledState<I, D, BT>, MulticastForwardingDisabledError>;
}
impl<I: IpLayerIpExt, D: StrongDeviceIdentifier, BT: MulticastForwardingBindingsTypes>
MulticastForwardingStateExt<I, D, BT> for MulticastForwardingState<I, D, BT>
{
fn try_enabled(
&self,
) -> Result<&MulticastForwardingEnabledState<I, D, BT>, MulticastForwardingDisabledError> {
self.enabled().ok_or(MulticastForwardingDisabledError {})
}
}
pub struct MulticastForwardingApi<I: Ip, C> {
ctx: C,
_ip_mark: IpVersionMarker<I>,
}
impl<I: Ip, C> MulticastForwardingApi<I, C> {
pub fn new(ctx: C) -> Self {
Self { ctx, _ip_mark: IpVersionMarker::new() }
}
}
impl<I: IpLayerIpExt, C> MulticastForwardingApi<I, C>
where
C: ContextPair,
C::CoreContext: MulticastForwardingStateContext<I, C::BindingsContext>
+ MulticastForwardingDeviceContext<I>
+ IpLayerForwardingContext<I, C::BindingsContext>
+ CounterContext<MulticastForwardingCounters<I>>
+ CoreTimerContext<MulticastForwardingTimerId<I>, C::BindingsContext>,
C::BindingsContext:
IpLayerBindingsContext<I, <C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
{
pub(crate) fn core_ctx(&mut self) -> &mut C::CoreContext {
let Self { ctx, _ip_mark } = self;
ctx.core_ctx()
}
pub(crate) fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
let Self { ctx, _ip_mark } = self;
ctx.contexts()
}
pub fn enable(&mut self) -> bool {
let (core_ctx, bindings_ctx) = self.contexts();
core_ctx.with_state_mut(|state, _ctx| match state {
MulticastForwardingState::Enabled(_) => false,
MulticastForwardingState::Disabled => {
*state = MulticastForwardingState::Enabled(MulticastForwardingEnabledState::new::<
C::CoreContext,
>(bindings_ctx));
true
}
})
}
pub fn disable(&mut self) -> bool {
self.core_ctx().with_state_mut(|state, _ctx| match state {
MulticastForwardingState::Disabled => false,
MulticastForwardingState::Enabled(_) => {
*state = MulticastForwardingState::Disabled;
true
}
})
}
pub fn add_multicast_route(
&mut self,
key: MulticastRouteKey<I>,
route: MulticastRoute<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
) -> Result<
Option<MulticastRoute<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>>,
MulticastForwardingDisabledError,
> {
let (core_ctx, bindings_ctx) = self.contexts();
let (orig_route, packet_queue_and_new_route) = core_ctx.with_state_mut(|state, ctx| {
let state = state.try_enabled()?;
ctx.with_route_table_mut(state, |route_table, ctx| {
let stats = MulticastRouteStats { last_used: bindings_ctx.now_atomic() };
match route_table.entry(key.clone()) {
btree_map::Entry::Occupied(mut entry) => {
let MulticastRouteEntry { route: orig_route, stats: _ } =
entry.insert(MulticastRouteEntry { route, stats });
#[cfg(debug_assertions)]
ctx.with_pending_table_mut(state, |pending_table| {
debug_assert!(!pending_table.contains(&key));
});
Ok((Some(orig_route), None))
}
btree_map::Entry::Vacant(entry) => {
let MulticastRouteEntry { route: new_route_ref, stats: _ } =
entry.insert(MulticastRouteEntry { route, stats });
let packet_queue_and_new_route = ctx
.with_pending_table_mut(state, |pending_table| {
pending_table.remove(&key, bindings_ctx)
})
.map(|packet_queue| (packet_queue, new_route_ref.clone()));
Ok((None, packet_queue_and_new_route))
}
}
})
})?;
if let Some((packet_queue, new_route)) = packet_queue_and_new_route {
handle_pending_packets(core_ctx, bindings_ctx, packet_queue, key, new_route)
}
Ok(orig_route)
}
pub fn remove_multicast_route(
&mut self,
key: &MulticastRouteKey<I>,
) -> Result<
Option<MulticastRoute<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>>,
MulticastForwardingDisabledError,
> {
self.core_ctx().with_state_mut(|state, ctx| {
let state = state.try_enabled()?;
ctx.with_route_table_mut(state, |route_table, _ctx| {
Ok(route_table.remove(key).map(|MulticastRouteEntry { route, stats: _ }| route))
})
})
}
pub fn remove_references_to_device(
&mut self,
dev: &<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
) {
self.core_ctx().with_state_mut(|state, ctx| {
let Some(state) = state.enabled() else {
return;
};
ctx.with_route_table_mut(state, |route_table, _ctx| {
route_table.retain(
|_route_key,
MulticastRouteEntry {
route: MulticastRoute { action, input_interface },
stats: _,
}| {
if dev == &*input_interface {
return false;
}
match action {
Action::Forward(ref mut targets) => {
if targets.iter().all(|target| dev == &target.output_interface) {
return false;
}
if targets.iter().any(|target| dev == &target.output_interface) {
*targets = targets
.iter()
.filter(|target| dev != &target.output_interface)
.cloned()
.collect();
}
}
}
true
},
)
})
})
}
pub fn get_route_stats(
&mut self,
key: &MulticastRouteKey<I>,
) -> Result<
Option<MulticastRouteStats<<C::BindingsContext as InstantBindingsTypes>::Instant>>,
MulticastForwardingDisabledError,
> {
self.core_ctx().with_state(|state, ctx| {
let state = state.try_enabled()?;
ctx.with_route_table(state, |route_table, _ctx| {
Ok(route_table.get(key).map(
|MulticastRouteEntry { route: _, stats: MulticastRouteStats { last_used } }| {
MulticastRouteStats { last_used: last_used.load(Ordering::Relaxed) }
},
))
})
})
}
pub fn inspect<
N: Inspector + InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
>(
&mut self,
inspector: &mut N,
) {
self.core_ctx().with_state(|state, ctx| match state {
MulticastForwardingState::Disabled => {
inspector.record_bool("ForwardingEnabled", false);
}
MulticastForwardingState::Enabled(state) => {
inspector.record_bool("ForwardingEnabled", true);
inspector.record_child("Routes", |inspector| {
ctx.with_route_table(state, |route_table, _ctx| {
for (route_key, route_entry) in route_table.iter() {
inspector.record_unnamed_child(|inspector| {
inspector.delegate_inspectable(route_key);
route_entry.inspect::<_, N>(inspector);
})
}
})
});
ctx.with_pending_table_mut(state, |pending_table| {
inspector.record_inspectable("PendingRoutes", pending_table);
});
}
})
}
}
fn handle_pending_packets<I: IpLayerIpExt, CC, BC>(
core_ctx: &mut CC,
bindings_ctx: &mut BC,
packet_queue: PacketQueue<I, CC::WeakDeviceId, BC>,
key: MulticastRouteKey<I>,
route: MulticastRoute<CC::DeviceId>,
) where
CC: IpLayerForwardingContext<I, BC>
+ MulticastForwardingDeviceContext<I>
+ CounterContext<MulticastForwardingCounters<I>>,
BC: IpLayerBindingsContext<I, CC::DeviceId>,
{
let MulticastRoute { input_interface, action } = route;
if !core_ctx.is_device_multicast_forwarding_enabled(&input_interface) {
warn!(
"Dropping pending packets for newly installed multicast route: {key:?}. \
Multicast forwarding is disabled on input interface: {input_interface:?}"
);
core_ctx.increment(|counters: &MulticastForwardingCounters<I>| {
&counters.pending_packet_drops_disabled_dev
});
return;
}
let MulticastRouteKey { src_addr, dst_addr } = key.clone();
let dst_ip: SpecifiedAddr<I::Addr> = dst_addr.into();
let src_ip: I::RecvSrcAddr = src_addr.into();
for QueuedPacket { device, packet, frame_dst } in packet_queue.into_iter() {
let device = match device.upgrade() {
None => continue,
Some(d) => d,
};
if device != input_interface {
core_ctx.increment(|counters: &MulticastForwardingCounters<I>| {
&counters.pending_packet_drops_wrong_dev
});
bindings_ctx.on_event(
MulticastForwardingEvent::WrongInputInterface {
key: key.clone(),
actual_input_interface: device.clone(),
expected_input_interface: input_interface.clone(),
}
.into(),
);
continue;
}
match &action {
Action::Forward(targets) => {
core_ctx.increment(|counters: &MulticastForwardingCounters<I>| {
&counters.pending_packet_tx
});
let packet_iter = RepeatN::new(packet, targets.len());
for (mut packet, MulticastRouteTarget { output_interface, min_ttl }) in
packet_iter.zip(targets.iter())
{
let packet_metadata = Default::default();
crate::internal::base::determine_ip_packet_forwarding_action::<I, _, _>(
core_ctx,
packet.parse_ip_packet_mut(),
packet_metadata,
Some(*min_ttl),
&input_interface,
&output_interface,
IpPacketDestination::from_addr(dst_ip),
frame_dst,
src_ip,
dst_ip,
)
.perform_action_with_buffer(
core_ctx,
bindings_ctx,
packet.into_inner(),
);
}
}
}
}
}
struct RepeatN<T> {
elem: Option<T>,
size: usize,
}
impl<T> RepeatN<T> {
fn new(elem: T, size: usize) -> Self {
if size == 0 {
Self { elem: None, size }
} else {
Self { elem: Some(elem), size: size - 1 }
}
}
}
impl<T: Clone> Iterator for RepeatN<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
let Self { elem, size } = self;
if *size > 0 {
*size -= 1;
Some(elem.as_ref().unwrap().clone())
} else {
elem.take()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;
use core::ops::Deref;
use core::time::Duration;
use assert_matches::assert_matches;
use ip_test_macro::ip_test;
use net_types::MulticastAddr;
use netstack3_base::testutil::MultipleDevicesId;
use netstack3_base::{FrameDestination, StrongDeviceIdentifier};
use packet::ParseBuffer;
use test_case::test_case;
use crate::internal::multicast_forwarding;
use crate::internal::multicast_forwarding::packet_queue::QueuePacketOutcome;
use crate::internal::multicast_forwarding::testutil::{SentPacket, TestIpExt};
use crate::multicast_forwarding::{MulticastRoute, MulticastRouteKey, MulticastRouteTarget};
use crate::IpLayerEvent;
#[ip_test(I)]
fn enable_disable<I: IpLayerIpExt>() {
let mut api = multicast_forwarding::testutil::new_api::<I>();
assert_matches!(
api.core_ctx().state.multicast_forwarding.borrow().deref(),
&MulticastForwardingState::Disabled
);
assert!(api.enable());
assert!(!api.enable());
assert_matches!(
api.core_ctx().state.multicast_forwarding.borrow().deref(),
&MulticastForwardingState::Enabled(_)
);
assert!(api.disable());
assert!(!api.disable());
assert_matches!(
api.core_ctx().state.multicast_forwarding.borrow().deref(),
&MulticastForwardingState::Disabled
);
}
#[ip_test(I)]
fn add_remove_route<I: TestIpExt>() {
let key1 = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
let key2 = MulticastRouteKey::new(I::SRC2, I::DST2).unwrap();
let forward_to_b = MulticastRoute::new_forward(
MultipleDevicesId::A,
[MulticastRouteTarget { output_interface: MultipleDevicesId::B, min_ttl: 0 }].into(),
)
.unwrap();
let forward_to_c = MulticastRoute::new_forward(
MultipleDevicesId::A,
[MulticastRouteTarget { output_interface: MultipleDevicesId::C, min_ttl: 0 }].into(),
)
.unwrap();
let mut api = multicast_forwarding::testutil::new_api::<I>();
assert_eq!(
api.add_multicast_route(key1.clone(), forward_to_b.clone()),
Err(MulticastForwardingDisabledError {})
);
assert_eq!(api.remove_multicast_route(&key1), Err(MulticastForwardingDisabledError {}));
assert!(api.enable());
assert_eq!(api.add_multicast_route(key1.clone(), forward_to_b.clone()), Ok(None));
assert_eq!(api.remove_multicast_route(&key1), Ok(Some(forward_to_b.clone())));
assert_eq!(api.remove_multicast_route(&key1), Ok(None));
assert_eq!(api.add_multicast_route(key1.clone(), forward_to_b.clone()), Ok(None));
assert_eq!(
api.add_multicast_route(key1.clone(), forward_to_c.clone()),
Ok(Some(forward_to_b.clone()))
);
assert_eq!(api.remove_multicast_route(&key1), Ok(Some(forward_to_c.clone())));
assert_eq!(api.add_multicast_route(key1.clone(), forward_to_b.clone()), Ok(None));
assert_eq!(api.add_multicast_route(key2.clone(), forward_to_c.clone()), Ok(None));
assert_eq!(api.remove_multicast_route(&key1), Ok(Some(forward_to_b)));
assert_eq!(api.remove_multicast_route(&key2), Ok(Some(forward_to_c)));
}
#[ip_test(I)]
#[test_case(false, true; "forwarding_disabled")]
#[test_case(true, false; "forwarding_enabled_and_wrong_dev")]
#[test_case(true, true; "forwarding_enabled_and_right_dev")]
fn add_route_with_pending_packets<I: TestIpExt>(
forwarding_enabled_for_dev: bool,
right_dev: bool,
) {
const FRAME_DST: Option<FrameDestination> = None;
const OUTPUT_DEV: MultipleDevicesId = MultipleDevicesId::C;
let right_key = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
let wrong_key = MulticastRouteKey::new(I::SRC2, I::DST2).unwrap();
let expected_dev = MultipleDevicesId::A;
let actual_dev = if right_dev { expected_dev } else { MultipleDevicesId::B };
let route = MulticastRoute::new_forward(
expected_dev,
[MulticastRouteTarget { output_interface: OUTPUT_DEV, min_ttl: 0 }].into(),
)
.unwrap();
let mut api = multicast_forwarding::testutil::new_api::<I>();
assert!(api.enable());
api.core_ctx()
.state
.set_multicast_forwarding_enabled_for_dev(expected_dev, forwarding_enabled_for_dev);
let (core_ctx, bindings_ctx) = api.contexts();
multicast_forwarding::testutil::with_pending_table(core_ctx, |pending_table| {
let buf = multicast_forwarding::testutil::new_ip_packet_buf::<I>(I::SRC1, I::DST1);
let mut buf_ref = buf.as_ref();
let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
assert_eq!(
pending_table.try_queue_packet(
bindings_ctx,
right_key.clone(),
&packet,
&actual_dev,
FRAME_DST
),
QueuePacketOutcome::QueuedInNewQueue,
);
});
assert_eq!(api.add_multicast_route(wrong_key, route.clone()), Ok(None));
assert!(multicast_forwarding::testutil::with_pending_table(
api.core_ctx(),
|pending_table| pending_table.contains(&right_key)
));
assert_eq!(api.add_multicast_route(right_key.clone(), route), Ok(None));
assert!(multicast_forwarding::testutil::with_pending_table(
api.core_ctx(),
|pending_table| !pending_table.contains(&right_key)
));
let expect_sent_packet = forwarding_enabled_for_dev && right_dev;
let mut expected_sent_packets = vec![];
if expect_sent_packet {
expected_sent_packets.push(SentPacket {
dst: MulticastAddr::new(right_key.dst_addr()).unwrap(),
device: OUTPUT_DEV,
});
}
assert_eq!(api.core_ctx().state.take_sent_packets(), expected_sent_packets);
let mut expected_events = vec![];
if !right_dev {
expected_events.push(IpLayerEvent::MulticastForwarding(
MulticastForwardingEvent::WrongInputInterface {
key: right_key,
actual_input_interface: actual_dev,
expected_input_interface: expected_dev,
},
));
}
let (_core_ctx, bindings_ctx) = api.contexts();
assert_eq!(bindings_ctx.take_events(), expected_events);
api.core_ctx().with_counters(|counters: &MulticastForwardingCounters<I>| {
assert_eq!(counters.pending_packet_tx.get(), if expect_sent_packet { 1 } else { 0 });
assert_eq!(
counters.pending_packet_drops_disabled_dev.get(),
if forwarding_enabled_for_dev { 0 } else { 1 }
);
assert_eq!(
counters.pending_packet_drops_wrong_dev.get(),
if right_dev { 0 } else { 1 }
);
});
}
#[ip_test(I)]
fn remove_references_to_device<I: TestIpExt>() {
let key1 = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
let key2 = MulticastRouteKey::new(I::SRC2, I::DST1).unwrap();
let key3 = MulticastRouteKey::new(I::SRC1, I::DST2).unwrap();
let key4 = MulticastRouteKey::new(I::SRC2, I::DST2).unwrap();
const GOOD_DEV1: MultipleDevicesId = MultipleDevicesId::A;
const GOOD_DEV2: MultipleDevicesId = MultipleDevicesId::B;
const BAD_DEV: MultipleDevicesId = MultipleDevicesId::C;
const GOOD_TARGET1: MulticastRouteTarget<MultipleDevicesId> =
MulticastRouteTarget { output_interface: GOOD_DEV1, min_ttl: 0 };
const GOOD_TARGET2: MulticastRouteTarget<MultipleDevicesId> =
MulticastRouteTarget { output_interface: GOOD_DEV2, min_ttl: 0 };
const BAD_TARGET: MulticastRouteTarget<MultipleDevicesId> =
MulticastRouteTarget { output_interface: BAD_DEV, min_ttl: 0 };
let dev_is_input = MulticastRoute::new_forward(BAD_DEV, [GOOD_TARGET1].into()).unwrap();
let dev_is_only_output =
MulticastRoute::new_forward(GOOD_DEV1, [BAD_TARGET].into()).unwrap();
let dev_is_one_output =
MulticastRoute::new_forward(GOOD_DEV1, [GOOD_TARGET2, BAD_TARGET].into()).unwrap();
let no_ref_to_dev = MulticastRoute::new_forward(GOOD_DEV1, [GOOD_TARGET2].into()).unwrap();
let mut api = multicast_forwarding::testutil::new_api::<I>();
api.remove_references_to_device(&BAD_DEV.downgrade());
assert!(api.enable());
assert_eq!(api.add_multicast_route(key1.clone(), dev_is_input), Ok(None));
assert_eq!(api.add_multicast_route(key2.clone(), dev_is_only_output), Ok(None));
assert_eq!(api.add_multicast_route(key3.clone(), dev_is_one_output), Ok(None));
assert_eq!(api.add_multicast_route(key4.clone(), no_ref_to_dev.clone()), Ok(None));
api.remove_references_to_device(&BAD_DEV.downgrade());
assert_eq!(api.remove_multicast_route(&key1), Ok(None));
assert_eq!(api.remove_multicast_route(&key2), Ok(None));
assert_eq!(
api.remove_multicast_route(&key3),
Ok(Some(MulticastRoute::new_forward(GOOD_DEV1, [GOOD_TARGET2].into()).unwrap()))
);
assert_eq!(api.remove_multicast_route(&key4), Ok(Some(no_ref_to_dev)));
}
#[ip_test(I)]
fn get_route_stats<I: TestIpExt>() {
let key = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
let mut api = multicast_forwarding::testutil::new_api::<I>();
assert_eq!(api.get_route_stats(&key), Err(MulticastForwardingDisabledError {}));
assert!(api.enable());
assert_eq!(api.get_route_stats(&key), Ok(None));
let route = MulticastRoute::new_forward(
MultipleDevicesId::A,
[MulticastRouteTarget { output_interface: MultipleDevicesId::B, min_ttl: 0 }].into(),
)
.unwrap();
assert_eq!(api.add_multicast_route(key.clone(), route.clone()), Ok(None));
let original_time = api.ctx.bindings_ctx().now();
let expected_stats = MulticastRouteStats { last_used: original_time };
assert_eq!(api.get_route_stats(&key), Ok(Some(expected_stats)));
api.ctx.bindings_ctx().timers.instant.sleep(Duration::from_secs(5));
let new_time = api.ctx.bindings_ctx().now();
assert!(new_time > original_time);
let expected_stats = MulticastRouteStats { last_used: new_time };
assert_eq!(api.add_multicast_route(key.clone(), route.clone()), Ok(Some(route)));
assert_eq!(api.get_route_stats(&key), Ok(Some(expected_stats)));
}
#[test_case(0)]
#[test_case(1)]
#[test_case(10)]
fn repeat_n(size: usize) {
#[derive(Clone)]
struct Foo;
assert_eq!(RepeatN::new(Foo, size).count(), size);
}
}