use alloc::collections::{btree_map, BTreeMap};
use alloc::vec::Vec;
use arrayvec::ArrayVec;
use core::time::Duration;
use derivative::Derivative;
use net_types::ip::{Ip, IpVersionMarker};
use netstack3_base::{
CoreTimerContext, FrameDestination, Inspectable, Inspector, Instant as _,
StrongDeviceIdentifier as _, WeakDeviceIdentifier,
};
use packet::{Buf, ParseBufferMut};
use packet_formats::ip::IpPacket;
use zerocopy::SplitByteSlice;
use crate::internal::multicast_forwarding::{
MulticastForwardingBindingsContext, MulticastForwardingBindingsTypes,
MulticastForwardingTimerId,
};
use crate::multicast_forwarding::MulticastRouteKey;
use crate::IpLayerIpExt;
/// Maximum number of packets buffered under a single route key. Once a key's
/// queue holds this many packets, further packets for that key are rejected
/// with `QueuePacketOutcome::ExistingQueueFull`.
pub(crate) const PACKET_QUEUE_LEN: usize = 3;
/// How long after its creation a packet queue stays in the table. It is not
/// extended when more packets are pushed. Once the current time reaches the
/// expiration instant, the next garbage-collection run drops the queue.
const PENDING_ROUTE_EXPIRATION: Duration = Duration::from_secs(10);
/// Delay used when (re)scheduling the garbage-collection timer. The timer is
/// only scheduled while the pending table is non-empty.
const PENDING_ROUTE_GC_PERIOD: Duration = Duration::from_secs(10);
/// A table of packets that have been queued per [`MulticastRouteKey`].
///
/// Each key maps to a bounded queue of at most `PACKET_QUEUE_LEN` packets:
/// - `try_queue_packet` creates entries on demand.
/// - `remove` takes entries out.
/// - `run_garbage_collection` drops entries once they expire.
///
/// NOTE(review): the code that calls `run_garbage_collection` when `gc_timer`
/// fires is not in this file; presumably it is the `PendingPacketsGc` timer
/// handler.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub struct MulticastForwardingPendingPackets<
    I: IpLayerIpExt,
    D: WeakDeviceIdentifier,
    BT: MulticastForwardingBindingsTypes,
> {
    /// Queued packets, grouped by the route key they were queued under.
    table: BTreeMap<MulticastRouteKey<I>, PacketQueue<I, D, BT>>,
    /// Garbage-collection timer. It is scheduled when the table goes from
    /// empty to non-empty and cancelled when `remove` empties the table.
    gc_timer: BT::Timer,
}
impl<
I: IpLayerIpExt,
D: WeakDeviceIdentifier,
BC: MulticastForwardingBindingsContext<I, D::Strong>,
> MulticastForwardingPendingPackets<I, D, BC>
{
pub(crate) fn new<CC>(bindings_ctx: &mut BC) -> Self
where
CC: CoreTimerContext<MulticastForwardingTimerId<I>, BC>,
{
Self {
table: Default::default(),
gc_timer: CC::new_timer(
bindings_ctx,
MulticastForwardingTimerId::PendingPacketsGc(IpVersionMarker::<I>::new()),
),
}
}
pub(crate) fn try_queue_packet<B>(
&mut self,
bindings_ctx: &mut BC,
key: MulticastRouteKey<I>,
packet: &I::Packet<B>,
dev: &D::Strong,
frame_dst: Option<FrameDestination>,
) -> QueuePacketOutcome
where
B: SplitByteSlice,
{
let was_empty = self.table.is_empty();
let outcome = match self.table.entry(key) {
btree_map::Entry::Vacant(entry) => {
let queue = entry.insert(PacketQueue::new(bindings_ctx));
queue
.try_push(|| QueuedPacket::new(dev, packet, frame_dst))
.expect("newly instantiated queue must have capacity");
QueuePacketOutcome::QueuedInNewQueue
}
btree_map::Entry::Occupied(mut entry) => {
match entry.get_mut().try_push(|| QueuedPacket::new(dev, packet, frame_dst)) {
Ok(()) => QueuePacketOutcome::QueuedInExistingQueue,
Err(PacketQueueFullError) => QueuePacketOutcome::ExistingQueueFull,
}
}
};
if was_empty && !self.table.is_empty() {
assert!(bindings_ctx
.schedule_timer(PENDING_ROUTE_GC_PERIOD, &mut self.gc_timer)
.is_none());
}
outcome
}
#[cfg(any(debug_assertions, test))]
pub(crate) fn contains(&self, key: &MulticastRouteKey<I>) -> bool {
self.table.contains_key(key)
}
pub(crate) fn remove(
&mut self,
key: &MulticastRouteKey<I>,
bindings_ctx: &mut BC,
) -> Option<PacketQueue<I, D, BC>> {
let was_empty = self.table.is_empty();
let queue = self.table.remove(key);
if !was_empty && self.table.is_empty() {
let _: Option<BC::Instant> = bindings_ctx.cancel_timer(&mut self.gc_timer);
}
queue
}
pub(crate) fn run_garbage_collection(&mut self, bindings_ctx: &mut BC) -> u64 {
let now = bindings_ctx.now();
let mut removed_count = 0u64;
self.table.retain(|_key, packet_queue| {
if packet_queue.expires_at > now {
true
} else {
removed_count += packet_queue.queue.len() as u64;
false
}
});
if !self.table.is_empty() {
let _: Option<BC::Instant> =
bindings_ctx.schedule_timer(PENDING_ROUTE_GC_PERIOD, &mut self.gc_timer);
}
removed_count
}
}
impl<I, D, BT> Inspectable for MulticastForwardingPendingPackets<I, D, BT>
where
    I: IpLayerIpExt,
    D: WeakDeviceIdentifier,
    BT: MulticastForwardingBindingsTypes,
{
    /// Records how many route keys currently have a pending packet queue.
    fn record<II: Inspector>(&self, inspector: &mut II) {
        // Exhaustive destructuring: adding a field to the struct becomes a
        // compile error here, prompting a decision about inspecting it.
        let Self { table, gc_timer: _ } = self;
        let num_routes = table.len();
        inspector.record_usize("NumRoutes", num_routes);
    }
}
/// Result of `MulticastForwardingPendingPackets::try_queue_packet`.
#[derive(PartialEq, Debug)]
pub(crate) enum QueuePacketOutcome {
    /// The key had no queue, so a new queue was created holding the packet.
    QueuedInNewQueue,
    /// The packet was appended to the key's existing queue.
    QueuedInExistingQueue,
    /// The key's queue already held `PACKET_QUEUE_LEN` packets, so the packet
    /// was not queued.
    ExistingQueueFull,
}
/// A bounded queue (at most `PACKET_QUEUE_LEN` packets) of packets held under a
/// single route key, together with the instant at which it expires.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub struct PacketQueue<I: Ip, D: WeakDeviceIdentifier, BT: MulticastForwardingBindingsTypes> {
    /// Queued packets in push order. Capacity is fixed at `PACKET_QUEUE_LEN`.
    queue: ArrayVec<QueuedPacket<I, D>, PACKET_QUEUE_LEN>,
    /// Set once when the queue is created and not extended by later pushes.
    /// Garbage collection drops the queue once the current time reaches it.
    expires_at: BT::Instant,
}
impl<
I: IpLayerIpExt,
D: WeakDeviceIdentifier,
BC: MulticastForwardingBindingsContext<I, D::Strong>,
> PacketQueue<I, D, BC>
{
fn new(bindings_ctx: &mut BC) -> Self {
Self {
queue: Default::default(),
expires_at: bindings_ctx.now().panicking_add(PENDING_ROUTE_EXPIRATION),
}
}
fn try_push(
&mut self,
packet_builder: impl FnOnce() -> QueuedPacket<I, D>,
) -> Result<(), PacketQueueFullError> {
if self.queue.is_full() {
return Err(PacketQueueFullError);
}
self.queue.push(packet_builder());
Ok(())
}
}
/// Error returned by `PacketQueue::try_push` when the queue already holds
/// `PACKET_QUEUE_LEN` packets.
#[derive(Debug)]
struct PacketQueueFullError;
/// Consuming a `PacketQueue` yields its packets in the order they were pushed.
/// The expiration instant is discarded.
impl<I, D, BT> IntoIterator for PacketQueue<I, D, BT>
where
    I: Ip,
    D: WeakDeviceIdentifier,
    BT: MulticastForwardingBindingsTypes,
{
    type Item = QueuedPacket<I, D>;
    type IntoIter = <ArrayVec<QueuedPacket<I, D>, PACKET_QUEUE_LEN> as IntoIterator>::IntoIter;

    fn into_iter(self) -> Self::IntoIter {
        self.queue.into_iter()
    }
}
/// A packet held in the pending table, together with the device and frame
/// destination that were supplied when it was queued.
#[derive(Debug, PartialEq)]
pub struct QueuedPacket<I: Ip, D: WeakDeviceIdentifier> {
    /// Weak (downgraded) form of the device passed in at queue time.
    pub(crate) device: D,
    /// Owned copy of the packet bytes.
    pub(crate) packet: ValidIpPacketBuf<I>,
    /// Frame destination recorded with the packet, if any.
    pub(crate) frame_dst: Option<FrameDestination>,
}
impl<I, D> QueuedPacket<I, D>
where
    I: IpLayerIpExt,
    D: WeakDeviceIdentifier,
{
    /// Captures `packet` for later processing.
    ///
    /// It stores a weak reference to `device`, an owned copy of the packet
    /// bytes, and `frame_dst`.
    fn new<B: SplitByteSlice>(
        device: &D::Strong,
        packet: &I::Packet<B>,
        frame_dst: Option<FrameDestination>,
    ) -> Self {
        let weak_device = device.downgrade();
        let packet_copy = ValidIpPacketBuf::<I>::new(packet);
        Self { device: weak_device, packet: packet_copy, frame_dst }
    }
}
/// An owned buffer holding the bytes of an IP packet of version `I`.
///
/// It is built from an already-parsed packet, so the stored bytes are expected
/// to parse again as an `I::Packet`.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ValidIpPacketBuf<I: Ip> {
    /// The copied packet bytes.
    buffer: Buf<Vec<u8>>,
    /// Ties the buffer to IP version `I` without storing any data.
    _version_marker: IpVersionMarker<I>,
}
impl<I: IpLayerIpExt> ValidIpPacketBuf<I> {
    /// Copies the bytes of an already-parsed `packet` into an owned buffer.
    fn new<B: SplitByteSlice>(packet: &I::Packet<B>) -> Self {
        Self { buffer: Buf::new(packet.to_vec(), ..), _version_marker: Default::default() }
    }

    /// Parses the stored bytes as an `I::Packet`, borrowing the buffer
    /// mutably.
    ///
    /// NOTE(review): parsing through `ParseBufferMut` may shrink the buffer's
    /// body past the parsed header. Confirm this before parsing the same
    /// buffer twice or relying on `into_inner` afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the stored bytes fail to parse. `new` builds the buffer only
    /// from an already-parsed packet, so a failure here indicates a bug rather
    /// than bad input.
    pub(crate) fn parse_ip_packet_mut(&mut self) -> I::Packet<&mut [u8]> {
        // State the invariant instead of using a bare `unwrap()`, so a failure
        // points straight at the broken assumption.
        self.buffer
            .parse_mut()
            .expect("ValidIpPacketBuf must hold a previously validated IP packet")
    }

    /// Consumes `self`, returning the underlying buffer.
    pub(crate) fn into_inner(self) -> Buf<Vec<u8>> {
        let Self { buffer, _version_marker } = self;
        buffer
    }
}
// Tests for `MulticastForwardingPendingPackets`: queueing outcomes, removal, and
// the lifecycle of the garbage-collection timer.
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use ip_test_macro::ip_test;
    use netstack3_base::testutil::{
        FakeInstant, FakeTimerCtxExt, FakeWeakDeviceId, MultipleDevicesId,
    };
    use netstack3_base::{CounterContext, InstantContext, StrongDeviceIdentifier, TimerContext};
    use packet::ParseBuffer;
    use static_assertions::const_assert;
    use test_case::test_case;
    use crate::internal::multicast_forwarding;
    use crate::internal::multicast_forwarding::counters::MulticastForwardingCounters;
    use crate::internal::multicast_forwarding::testutil::{
        FakeBindingsCtx, FakeCoreCtx, TestIpExt,
    };

    // Exercises every `try_queue_packet` outcome:
    // - the first packet for a key creates a new queue;
    // - later packets fill it up to `PACKET_QUEUE_LEN`;
    // - the next packet is rejected once the queue is full;
    // - a different key gets its own queue.
    // It then checks that `remove` hands back exactly what was queued.
    #[ip_test(I)]
    #[test_case(None; "no_frame_dst")]
    #[test_case(Some(FrameDestination::Multicast); "some_frame_dst")]
    fn queue_packet<I: TestIpExt>(frame_dst: Option<FrameDestination>) {
        const DEV: MultipleDevicesId = MultipleDevicesId::A;
        let key1 = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
        let key2 = MulticastRouteKey::new(I::SRC2, I::DST2).unwrap();
        // Never queued; used to check that removing an absent key yields `None`.
        let key3 = MulticastRouteKey::new(I::SRC1, I::DST2).unwrap();
        let buf = multicast_forwarding::testutil::new_ip_packet_buf::<I>(I::SRC1, I::DST1);
        let mut buf_ref = buf.as_ref();
        let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
        let mut bindings_ctx = FakeBindingsCtx::<I, MultipleDevicesId>::default();
        let mut pending_table =
            MulticastForwardingPendingPackets::<
                I,
                <MultipleDevicesId as StrongDeviceIdentifier>::Weak,
                _,
            >::new::<FakeCoreCtx<I, MultipleDevicesId>>(&mut bindings_ctx);
        // The first packet for key1 creates its queue.
        assert_eq!(
            pending_table.try_queue_packet(
                &mut bindings_ctx,
                key1.clone(),
                &packet,
                &DEV,
                frame_dst
            ),
            QueuePacketOutcome::QueuedInNewQueue
        );
        // Fill the remaining capacity of key1's queue.
        for _ in 1..PACKET_QUEUE_LEN {
            assert_eq!(
                pending_table.try_queue_packet(
                    &mut bindings_ctx,
                    key1.clone(),
                    &packet,
                    &DEV,
                    frame_dst
                ),
                QueuePacketOutcome::QueuedInExistingQueue
            );
        }
        // key1's queue is now full, so the next packet is rejected.
        assert_eq!(
            pending_table.try_queue_packet(
                &mut bindings_ctx,
                key1.clone(),
                &packet,
                &DEV,
                frame_dst
            ),
            QueuePacketOutcome::ExistingQueueFull
        );
        // A different key gets its own queue even though key1's is full.
        assert_eq!(
            pending_table.try_queue_packet(
                &mut bindings_ctx,
                key2.clone(),
                &packet,
                &DEV,
                frame_dst
            ),
            QueuePacketOutcome::QueuedInNewQueue
        );
        let expected_packet = QueuedPacket::new(&DEV, &packet, frame_dst);
        // key1's queue holds exactly `PACKET_QUEUE_LEN` copies of the packet.
        let queue =
            pending_table.remove(&key1, &mut bindings_ctx).expect("key1 should have a queue");
        assert_eq!(queue.queue.len(), PACKET_QUEUE_LEN);
        for packet in queue.queue.as_slice() {
            assert_eq!(packet, &expected_packet);
        }
        // key2's queue holds the single packet queued for it.
        let queue =
            pending_table.remove(&key2, &mut bindings_ctx).expect("key2 should have a queue");
        let packet = assert_matches!(&queue.queue[..], [p] => p);
        assert_eq!(packet, &expected_packet);
        // key3 was never queued.
        assert_matches!(pending_table.remove(&key3, &mut bindings_ctx), None);
    }

    // Returns the instant at which the pending table's GC timer is scheduled,
    // or `None` if it is not scheduled.
    fn next_gc_time<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
    ) -> Option<FakeInstant> {
        multicast_forwarding::testutil::with_pending_table(core_ctx, |pending_table| {
            bindings_ctx.scheduled_instant(&mut pending_table.gc_timer)
        })
    }

    // Builds a packet addressed from `key`'s source to its destination and
    // queues it under `key` in the core context's pending table.
    fn try_queue_packet<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
        key: MulticastRouteKey<I>,
        dev: &MultipleDevicesId,
        frame_dst: Option<FrameDestination>,
    ) -> QueuePacketOutcome {
        let buf =
            multicast_forwarding::testutil::new_ip_packet_buf::<I>(key.src_addr(), key.dst_addr());
        let mut buf_ref = buf.as_ref();
        let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
        multicast_forwarding::testutil::with_pending_table(core_ctx, |pending_table| {
            pending_table.try_queue_packet(bindings_ctx, key, &packet, dev, frame_dst)
        })
    }

    // Removes and returns the queue for `key` from the core context's pending
    // table, if any.
    fn remove_packet_queue<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
        key: &MulticastRouteKey<I>,
    ) -> Option<
        PacketQueue<I, FakeWeakDeviceId<MultipleDevicesId>, FakeBindingsCtx<I, MultipleDevicesId>>,
    > {
        multicast_forwarding::testutil::with_pending_table(core_ctx, |pending_table| {
            pending_table.remove(key, bindings_ctx)
        })
    }

    // Fires all timers due at the current fake time and asserts that exactly
    // one timer fired: the pending-packets GC timer.
    fn run_gc<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
    ) {
        assert_matches!(
            &bindings_ctx.trigger_timers_until_instant(bindings_ctx.now(), core_ctx)[..],
            [MulticastForwardingTimerId::PendingPacketsGc(_)]
        );
    }

    // Walks the GC timer through its lifecycle (arm, fire, reschedule, cancel,
    // re-arm, lapse) and checks the GC counters at each step.
    #[ip_test(I)]
    fn garbage_collection<I: TestIpExt>() {
        const DEV: MultipleDevicesId = MultipleDevicesId::A;
        const FRAME_DST: Option<FrameDestination> = None;
        let key1 = MulticastRouteKey::<I>::new(I::SRC1, I::DST1).unwrap();
        let key2 = MulticastRouteKey::<I>::new(I::SRC2, I::DST2).unwrap();
        let mut api = multicast_forwarding::testutil::new_api();
        assert!(api.enable());
        let (core_ctx, bindings_ctx) = api.contexts();
        // The timeline below relies on two properties of the constants:
        // - GC_PERIOD >= EXPIRATION, so a queue created at one GC run has
        //   expired by the next run;
        // - a non-zero EXPIRATION, so a queue created at the same instant as a
        //   GC run survives that run.
        const_assert!(PENDING_ROUTE_GC_PERIOD.checked_sub(PENDING_ROUTE_EXPIRATION).is_some());
        const_assert!(!PENDING_ROUTE_EXPIRATION.is_zero());
        // Nothing has been queued yet, so the timer is unscheduled and no GC
        // has run.
        assert!(next_gc_time(core_ctx, bindings_ctx).is_none());
        core_ctx.with_counters(|counters: &MulticastForwardingCounters<I>| {
            assert_eq!(counters.pending_table_gc.get(), 0);
            assert_eq!(counters.pending_packet_drops_gc.get(), 0);
        });
        // Queueing into the empty table arms the timer one period from now.
        let expected_first_gc = bindings_ctx.now() + PENDING_ROUTE_GC_PERIOD;
        assert_eq!(
            try_queue_packet(core_ctx, bindings_ctx, key1.clone(), &DEV, FRAME_DST),
            QueuePacketOutcome::QueuedInNewQueue
        );
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(expected_first_gc));
        // Advance to the first GC deadline. Queueing key2 into the already
        // non-empty table leaves the timer unchanged.
        bindings_ctx.timers.instant.sleep(PENDING_ROUTE_GC_PERIOD);
        assert_eq!(
            try_queue_packet(core_ctx, bindings_ctx, key2.clone(), &DEV, FRAME_DST),
            QueuePacketOutcome::QueuedInNewQueue
        );
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(expected_first_gc));
        // The first GC drops key1 (expired now) and keeps key2 (just queued).
        // Because key2 remains, the timer is rescheduled.
        run_gc(core_ctx, bindings_ctx);
        let expected_second_gc = bindings_ctx.timers.instant.now() + PENDING_ROUTE_GC_PERIOD;
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(expected_second_gc));
        core_ctx.with_counters(|counters: &MulticastForwardingCounters<I>| {
            assert_eq!(counters.pending_table_gc.get(), 1);
            assert_eq!(counters.pending_packet_drops_gc.get(), 1);
        });
        // key1 is gone. Removing key2 empties the table, which cancels the
        // timer.
        assert_matches!(remove_packet_queue(core_ctx, bindings_ctx, &key1), None);
        assert_matches!(remove_packet_queue(core_ctx, bindings_ctx, &key2), Some(_));
        assert!(next_gc_time(core_ctx, bindings_ctx).is_none());
        // Queueing again re-arms the timer one period from the current time.
        assert_eq!(
            try_queue_packet(core_ctx, bindings_ctx, key1.clone(), &DEV, FRAME_DST),
            QueuePacketOutcome::QueuedInNewQueue
        );
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(expected_second_gc));
        // After another period key1 has expired and is collected. The table is
        // then empty, so the timer is not rescheduled.
        bindings_ctx.timers.instant.sleep(PENDING_ROUTE_GC_PERIOD);
        run_gc(core_ctx, bindings_ctx);
        core_ctx.with_counters(|counters: &MulticastForwardingCounters<I>| {
            assert_eq!(counters.pending_table_gc.get(), 2);
            assert_eq!(counters.pending_packet_drops_gc.get(), 2);
        });
        assert_matches!(remove_packet_queue(core_ctx, bindings_ctx, &key1), None);
        assert!(next_gc_time(core_ctx, bindings_ctx).is_none());
    }
}