1use alloc::collections::{btree_map, BTreeMap};
8use alloc::vec::Vec;
9use arrayvec::ArrayVec;
10use core::time::Duration;
11use derivative::Derivative;
12use net_types::ip::{Ip, IpVersionMarker};
13use netstack3_base::{
14 CoreTimerContext, FrameDestination, Inspectable, Inspector, Instant as _,
15 StrongDeviceIdentifier as _, WeakDeviceIdentifier,
16};
17use packet::{Buf, ParseBufferMut};
18use packet_formats::ip::IpPacket;
19use zerocopy::SplitByteSlice;
20
21use crate::internal::multicast_forwarding::{
22 MulticastForwardingBindingsContext, MulticastForwardingBindingsTypes,
23 MulticastForwardingTimerId,
24};
25use crate::multicast_forwarding::MulticastRouteKey;
26use crate::IpLayerIpExt;
27
/// Maximum number of packets a single [`PacketQueue`] can hold; it is the
/// fixed capacity of the queue's backing `ArrayVec`.
pub(crate) const PACKET_QUEUE_LEN: usize = 3;

/// Lifetime of a [`PacketQueue`]: `PacketQueue::new` adds this duration to the
/// current time to compute the instant at which the queue expires.
const PENDING_ROUTE_EXPIRATION: Duration = Duration::from_secs(10);

/// Delay used every time the garbage-collection timer of
/// [`MulticastForwardingPendingPackets`] is scheduled or rescheduled.
const PENDING_ROUTE_GC_PERIOD: Duration = Duration::from_secs(10);
46
47#[derive(Derivative)]
53#[derivative(Debug(bound = ""))]
54pub struct MulticastForwardingPendingPackets<
55 I: IpLayerIpExt,
56 D: WeakDeviceIdentifier,
57 BT: MulticastForwardingBindingsTypes,
58> {
59 table: BTreeMap<MulticastRouteKey<I>, PacketQueue<I, D, BT>>,
60 gc_timer: BT::Timer,
70}
71
72impl<
73 I: IpLayerIpExt,
74 D: WeakDeviceIdentifier,
75 BC: MulticastForwardingBindingsContext<I, D::Strong>,
76 > MulticastForwardingPendingPackets<I, D, BC>
77{
78 pub(crate) fn new<CC>(bindings_ctx: &mut BC) -> Self
79 where
80 CC: CoreTimerContext<MulticastForwardingTimerId<I>, BC>,
81 {
82 Self {
83 table: Default::default(),
84 gc_timer: CC::new_timer(
85 bindings_ctx,
86 MulticastForwardingTimerId::PendingPacketsGc(IpVersionMarker::<I>::new()),
87 ),
88 }
89 }
90
91 pub(crate) fn try_queue_packet<B>(
95 &mut self,
96 bindings_ctx: &mut BC,
97 key: MulticastRouteKey<I>,
98 packet: &I::Packet<B>,
99 dev: &D::Strong,
100 frame_dst: Option<FrameDestination>,
101 ) -> QueuePacketOutcome
102 where
103 B: SplitByteSlice,
104 {
105 let was_empty = self.table.is_empty();
106 let outcome = match self.table.entry(key) {
107 btree_map::Entry::Vacant(entry) => {
108 let queue = entry.insert(PacketQueue::new(bindings_ctx));
109 queue
110 .try_push(|| QueuedPacket::new(dev, packet, frame_dst))
111 .expect("newly instantiated queue must have capacity");
112 QueuePacketOutcome::QueuedInNewQueue
113 }
114 btree_map::Entry::Occupied(mut entry) => {
115 match entry.get_mut().try_push(|| QueuedPacket::new(dev, packet, frame_dst)) {
116 Ok(()) => QueuePacketOutcome::QueuedInExistingQueue,
117 Err(PacketQueueFullError) => QueuePacketOutcome::ExistingQueueFull,
118 }
119 }
120 };
121
122 if was_empty && !self.table.is_empty() {
125 assert!(bindings_ctx
126 .schedule_timer(PENDING_ROUTE_GC_PERIOD, &mut self.gc_timer)
127 .is_none());
128 }
129
130 outcome
131 }
132
133 #[cfg(any(debug_assertions, test))]
134 pub(crate) fn contains(&self, key: &MulticastRouteKey<I>) -> bool {
135 self.table.contains_key(key)
136 }
137
138 pub(crate) fn remove(
142 &mut self,
143 key: &MulticastRouteKey<I>,
144 bindings_ctx: &mut BC,
145 ) -> Option<PacketQueue<I, D, BC>> {
146 let was_empty = self.table.is_empty();
147 let queue = self.table.remove(key);
148
149 if !was_empty && self.table.is_empty() {
153 let _: Option<BC::Instant> = bindings_ctx.cancel_timer(&mut self.gc_timer);
154 }
155
156 queue
157 }
158
159 pub(crate) fn run_garbage_collection(&mut self, bindings_ctx: &mut BC) -> u64 {
163 let now = bindings_ctx.now();
164 let mut removed_count = 0u64;
165 self.table.retain(|_key, packet_queue| {
166 if packet_queue.expires_at > now {
167 true
168 } else {
169 removed_count += packet_queue.queue.len() as u64;
172 false
173 }
174 });
175
176 if !self.table.is_empty() {
180 let _: Option<BC::Instant> =
181 bindings_ctx.schedule_timer(PENDING_ROUTE_GC_PERIOD, &mut self.gc_timer);
182 }
183
184 removed_count
185 }
186}
187
188impl<I: IpLayerIpExt, D: WeakDeviceIdentifier, BT: MulticastForwardingBindingsTypes> Inspectable
189 for MulticastForwardingPendingPackets<I, D, BT>
190{
191 fn record<II: Inspector>(&self, inspector: &mut II) {
192 let MulticastForwardingPendingPackets { table, gc_timer: _ } = self;
193 inspector.record_usize("NumRoutes", table.len())
196 }
197}
198
/// The result of `MulticastForwardingPendingPackets::try_queue_packet`.
#[derive(Debug, PartialEq)]
pub(crate) enum QueuePacketOutcome {
    /// No queue existed for the key: one was created and the packet stored in
    /// it.
    QueuedInNewQueue,
    /// A queue already existed for the key and had room for the packet.
    QueuedInExistingQueue,
    /// A queue already existed for the key but was full; the packet was not
    /// stored.
    ExistingQueueFull,
}
212
213#[derive(Derivative)]
215#[derivative(Debug(bound = ""))]
216pub struct PacketQueue<I: Ip, D: WeakDeviceIdentifier, BT: MulticastForwardingBindingsTypes> {
217 queue: ArrayVec<QueuedPacket<I, D>, PACKET_QUEUE_LEN>,
218 expires_at: BT::Instant,
220}
221
222impl<
223 I: IpLayerIpExt,
224 D: WeakDeviceIdentifier,
225 BC: MulticastForwardingBindingsContext<I, D::Strong>,
226 > PacketQueue<I, D, BC>
227{
228 fn new(bindings_ctx: &mut BC) -> Self {
229 Self {
230 queue: Default::default(),
231 expires_at: bindings_ctx.now().panicking_add(PENDING_ROUTE_EXPIRATION),
232 }
233 }
234
235 fn try_push(
242 &mut self,
243 packet_builder: impl FnOnce() -> QueuedPacket<I, D>,
244 ) -> Result<(), PacketQueueFullError> {
245 if self.queue.is_full() {
246 return Err(PacketQueueFullError);
247 }
248 self.queue.push(packet_builder());
249 Ok(())
250 }
251}
252
/// Error returned by `PacketQueue::try_push` when the queue has no room left.
#[derive(Debug)]
struct PacketQueueFullError;
255
256impl<I: Ip, D: WeakDeviceIdentifier, BT: MulticastForwardingBindingsTypes> IntoIterator
257 for PacketQueue<I, D, BT>
258{
259 type Item = QueuedPacket<I, D>;
260 type IntoIter = <ArrayVec<QueuedPacket<I, D>, PACKET_QUEUE_LEN> as IntoIterator>::IntoIter;
261 fn into_iter(self) -> Self::IntoIter {
262 let Self { queue, expires_at: _ } = self;
263 queue.into_iter()
264 }
265}
266
267#[derive(Debug, PartialEq)]
269pub struct QueuedPacket<I: Ip, D: WeakDeviceIdentifier> {
270 pub(crate) device: D,
272 pub(crate) packet: ValidIpPacketBuf<I>,
274 pub(crate) frame_dst: Option<FrameDestination>,
277}
278
279impl<I: IpLayerIpExt, D: WeakDeviceIdentifier> QueuedPacket<I, D> {
280 fn new<B: SplitByteSlice>(
281 device: &D::Strong,
282 packet: &I::Packet<B>,
283 frame_dst: Option<FrameDestination>,
284 ) -> Self {
285 QueuedPacket {
286 device: device.downgrade(),
287 packet: ValidIpPacketBuf::new(packet),
288 frame_dst,
289 }
290 }
291}
292
293#[derive(Clone, Debug, PartialEq)]
298pub(crate) struct ValidIpPacketBuf<I: Ip> {
299 buffer: Buf<Vec<u8>>,
300 _version_marker: IpVersionMarker<I>,
301}
302
303impl<I: IpLayerIpExt> ValidIpPacketBuf<I> {
304 fn new<B: SplitByteSlice>(packet: &I::Packet<B>) -> Self {
305 Self { buffer: Buf::new(packet.to_vec(), ..), _version_marker: Default::default() }
306 }
307
308 pub(crate) fn parse_ip_packet_mut(&mut self) -> I::Packet<&mut [u8]> {
316 self.buffer.parse_mut().unwrap()
318 }
319
320 pub(crate) fn into_inner(self) -> Buf<Vec<u8>> {
321 let Self { buffer, _version_marker } = self;
322 buffer
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;
    use ip_test_macro::ip_test;
    use netstack3_base::testutil::{
        FakeInstant, FakeTimerCtxExt, FakeWeakDeviceId, MultipleDevicesId,
    };
    use netstack3_base::{CounterContext, InstantContext, StrongDeviceIdentifier, TimerContext};
    use packet::ParseBuffer;
    use static_assertions::const_assert;
    use test_case::test_case;

    use crate::internal::multicast_forwarding;
    use crate::internal::multicast_forwarding::counters::MulticastForwardingCounters;
    use crate::internal::multicast_forwarding::testutil::{
        FakeBindingsCtx, FakeCoreCtx, TestIpExt,
    };

    // Exercises queueing and removal directly against a standalone table:
    // filling one queue to capacity, overflowing it, queueing under a second
    // key, and removing present and absent keys.
    #[ip_test(I)]
    #[test_case(None; "no_frame_dst")]
    #[test_case(Some(FrameDestination::Multicast); "some_frame_dst")]
    fn queue_packet<I: TestIpExt>(frame_dst: Option<FrameDestination>) {
        const DEV: MultipleDevicesId = MultipleDevicesId::A;
        let key1 = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
        let key2 = MulticastRouteKey::new(I::SRC2, I::DST2).unwrap();
        let key3 = MulticastRouteKey::new(I::SRC1, I::DST2).unwrap();

        // The same packet is queued under every key used by this test.
        let buf = multicast_forwarding::testutil::new_ip_packet_buf::<I>(I::SRC1, I::DST1);
        let mut buf_ref = buf.as_ref();
        let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");

        let mut bindings_ctx = FakeBindingsCtx::<I, MultipleDevicesId>::default();

        let mut pending_table =
            MulticastForwardingPendingPackets::<
                I,
                <MultipleDevicesId as StrongDeviceIdentifier>::Weak,
                _,
            >::new::<FakeCoreCtx<I, MultipleDevicesId>>(&mut bindings_ctx);

        let mut queue_under = |key: &MulticastRouteKey<I>| {
            pending_table.try_queue_packet(&mut bindings_ctx, key.clone(), &packet, &DEV, frame_dst)
        };

        // The first packet under `key1` creates the queue.
        assert_eq!(queue_under(&key1), QueuePacketOutcome::QueuedInNewQueue);
        // The following packets fill the queue up to its capacity.
        for _ in 1..PACKET_QUEUE_LEN {
            assert_eq!(queue_under(&key1), QueuePacketOutcome::QueuedInExistingQueue);
        }
        // One more packet does not fit.
        assert_eq!(queue_under(&key1), QueuePacketOutcome::ExistingQueueFull);

        // A different key gets a queue of its own, even though the queue for
        // `key1` is full.
        assert_eq!(queue_under(&key2), QueuePacketOutcome::QueuedInNewQueue);

        let expected_packet = QueuedPacket::new(&DEV, &packet, frame_dst);

        // `key1` holds a full queue of identical packets.
        let key1_queue =
            pending_table.remove(&key1, &mut bindings_ctx).expect("key1 should have a queue");
        assert_eq!(key1_queue.queue.len(), PACKET_QUEUE_LEN);
        for queued in key1_queue.queue.iter() {
            assert_eq!(queued, &expected_packet);
        }

        // `key2` holds exactly one packet.
        let key2_queue =
            pending_table.remove(&key2, &mut bindings_ctx).expect("key2 should have a queue");
        let only_packet = assert_matches!(&key2_queue.queue[..], [p] => p);
        assert_eq!(only_packet, &expected_packet);

        // Nothing was ever queued under `key3`.
        assert_matches!(pending_table.remove(&key3, &mut bindings_ctx), None);
    }

    // Returns the instant at which the pending table's GC timer is scheduled,
    // or `None` if it is not scheduled.
    fn next_gc_time<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
    ) -> Option<FakeInstant> {
        multicast_forwarding::testutil::with_pending_table(core_ctx, |table| {
            bindings_ctx.scheduled_instant(&mut table.gc_timer)
        })
    }

    // Builds a packet whose addresses match `key` and tries to queue it in the
    // pending table held by `core_ctx`.
    fn try_queue_packet<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
        key: MulticastRouteKey<I>,
        dev: &MultipleDevicesId,
        frame_dst: Option<FrameDestination>,
    ) -> QueuePacketOutcome {
        let buf =
            multicast_forwarding::testutil::new_ip_packet_buf::<I>(key.src_addr(), key.dst_addr());
        let mut buf_ref = buf.as_ref();
        let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
        multicast_forwarding::testutil::with_pending_table(core_ctx, |table| {
            table.try_queue_packet(bindings_ctx, key, &packet, dev, frame_dst)
        })
    }

    // Removes the queue stored under `key` from the pending table held by
    // `core_ctx`.
    fn remove_packet_queue<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
        key: &MulticastRouteKey<I>,
    ) -> Option<
        PacketQueue<I, FakeWeakDeviceId<MultipleDevicesId>, FakeBindingsCtx<I, MultipleDevicesId>>,
    > {
        multicast_forwarding::testutil::with_pending_table(core_ctx, |table| {
            table.remove(key, bindings_ctx)
        })
    }

    // Fires every timer due at the current fake time and checks that exactly
    // one timer fired and that it was the pending-packets GC timer.
    fn run_gc<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
    ) {
        let now = bindings_ctx.now();
        let fired = bindings_ctx.trigger_timers_until_instant(now, core_ctx);
        assert_matches!(&fired[..], [MulticastForwardingTimerId::PendingPacketsGc(_)]);
    }

    // Walks the GC timer through its whole lifecycle: scheduled by the first
    // queued packet, rescheduled by a GC run that leaves entries behind,
    // cancelled by an explicit removal, and not rescheduled by a GC run that
    // empties the table.
    #[ip_test(I)]
    fn garbage_collection<I: TestIpExt>() {
        const DEV: MultipleDevicesId = MultipleDevicesId::A;
        const FRAME_DST: Option<FrameDestination> = None;
        let key1 = MulticastRouteKey::<I>::new(I::SRC1, I::DST1).unwrap();
        let key2 = MulticastRouteKey::<I>::new(I::SRC2, I::DST2).unwrap();

        let mut api = multicast_forwarding::testutil::new_api();
        assert!(api.enable());
        let (core_ctx, bindings_ctx) = api.contexts();

        // The scenario below relies on the GC period being at least as long
        // as the expiration time, and on the expiration time being non-zero.
        const_assert!(PENDING_ROUTE_GC_PERIOD.checked_sub(PENDING_ROUTE_EXPIRATION).is_some());
        const_assert!(!PENDING_ROUTE_EXPIRATION.is_zero());

        // With an empty table there is no GC scheduled and nothing counted.
        assert!(next_gc_time(core_ctx, bindings_ctx).is_none());
        let counters: &MulticastForwardingCounters<I> = core_ctx.counters();
        assert_eq!(counters.pending_table_gc.get(), 0);
        assert_eq!(counters.pending_packet_drops_gc.get(), 0);

        // The first queued packet schedules the GC one period from now.
        let first_gc_at = bindings_ctx.now() + PENDING_ROUTE_GC_PERIOD;
        assert_eq!(
            try_queue_packet(core_ctx, bindings_ctx, key1.clone(), &DEV, FRAME_DST),
            QueuePacketOutcome::QueuedInNewQueue
        );
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(first_gc_at));

        // Advance to the GC deadline and queue under a second key; the
        // already-scheduled GC must be left untouched.
        bindings_ctx.timers.instant.sleep(PENDING_ROUTE_GC_PERIOD);
        assert_eq!(
            try_queue_packet(core_ctx, bindings_ctx, key2.clone(), &DEV, FRAME_DST),
            QueuePacketOutcome::QueuedInNewQueue
        );
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(first_gc_at));

        // Running the GC leaves `key2` in the table, so the GC reschedules
        // itself one period from now.
        run_gc(core_ctx, bindings_ctx);
        let second_gc_at = bindings_ctx.timers.instant.now() + PENDING_ROUTE_GC_PERIOD;
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(second_gc_at));

        // Only `key1` (one packet) was collected.
        let counters: &MulticastForwardingCounters<I> = core_ctx.counters();
        assert_eq!(counters.pending_table_gc.get(), 1);
        assert_eq!(counters.pending_packet_drops_gc.get(), 1);
        assert_matches!(remove_packet_queue(core_ctx, bindings_ctx, &key1), None);
        assert_matches!(remove_packet_queue(core_ctx, bindings_ctx, &key2), Some(_));

        // Removing `key2` emptied the table, which cancels the GC.
        assert!(next_gc_time(core_ctx, bindings_ctx).is_none());

        // A GC run that empties the table must not reschedule itself.
        assert_eq!(
            try_queue_packet(core_ctx, bindings_ctx, key1.clone(), &DEV, FRAME_DST),
            QueuePacketOutcome::QueuedInNewQueue
        );
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(second_gc_at));
        bindings_ctx.timers.instant.sleep(PENDING_ROUTE_GC_PERIOD);
        run_gc(core_ctx, bindings_ctx);
        let counters: &MulticastForwardingCounters<I> = core_ctx.counters();
        assert_eq!(counters.pending_table_gc.get(), 2);
        assert_eq!(counters.pending_packet_drops_gc.get(), 2);
        assert_matches!(remove_packet_queue(core_ctx, bindings_ctx, &key1), None);
        assert!(next_gc_time(core_ctx, bindings_ctx).is_none());
    }
}