1use alloc::collections::{BTreeMap, btree_map};
8use alloc::vec::Vec;
9use arrayvec::ArrayVec;
10use core::time::Duration;
11use derivative::Derivative;
12use net_types::ip::{Ip, IpVersionMarker};
13use netstack3_base::{
14 CoreTimerContext, FrameDestination, Inspectable, Inspector, Instant as _,
15 StrongDeviceIdentifier as _, WeakDeviceIdentifier,
16};
17use packet::{Buf, ParseBufferMut};
18use packet_formats::ip::IpPacket;
19use zerocopy::SplitByteSlice;
20
21use crate::IpLayerIpExt;
22use crate::internal::multicast_forwarding::{
23 MulticastForwardingBindingsContext, MulticastForwardingBindingsTypes,
24 MulticastForwardingTimerId,
25};
26use crate::multicast_forwarding::MulticastRouteKey;
27
/// Maximum number of packets held by a single [`PacketQueue`].
pub(crate) const PACKET_QUEUE_LEN: usize = 3;

/// How long a newly created [`PacketQueue`] lives before garbage collection
/// may discard it.
const PENDING_ROUTE_EXPIRATION: Duration = Duration::new(10, 0);

/// Interval at which the pending table's garbage-collection timer is
/// (re)scheduled.
///
/// The tests in this module statically assert that this is no shorter than
/// [`PENDING_ROUTE_EXPIRATION`].
const PENDING_ROUTE_GC_PERIOD: Duration = Duration::new(10, 0);
46
/// A table of packets queued per [`MulticastRouteKey`].
///
/// Each key maps to a bounded [`PacketQueue`]. Expired queues are removed by
/// a periodic garbage-collection timer, which is only kept scheduled while the
/// table is non-empty.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub struct MulticastForwardingPendingPackets<
    I: IpLayerIpExt,
    D: WeakDeviceIdentifier,
    BT: MulticastForwardingBindingsTypes,
> {
    /// The per-key packet queues.
    table: BTreeMap<MulticastRouteKey<I>, PacketQueue<I, D, BT>>,
    /// Timer that triggers garbage collection of expired queues.
    ///
    /// Scheduled when the table goes from empty to non-empty, rescheduled by
    /// each garbage-collection pass that leaves entries behind, and cancelled
    /// when removing a queue empties the table.
    gc_timer: BT::Timer,
}
71
72impl<I: IpLayerIpExt, D: WeakDeviceIdentifier, BC: MulticastForwardingBindingsContext<I, D::Strong>>
73 MulticastForwardingPendingPackets<I, D, BC>
74{
75 pub(crate) fn new<CC>(bindings_ctx: &mut BC) -> Self
76 where
77 CC: CoreTimerContext<MulticastForwardingTimerId<I>, BC>,
78 {
79 Self {
80 table: Default::default(),
81 gc_timer: CC::new_timer(
82 bindings_ctx,
83 MulticastForwardingTimerId::PendingPacketsGc(IpVersionMarker::<I>::new()),
84 ),
85 }
86 }
87
88 pub(crate) fn try_queue_packet<B>(
92 &mut self,
93 bindings_ctx: &mut BC,
94 key: MulticastRouteKey<I>,
95 packet: &I::Packet<B>,
96 dev: &D::Strong,
97 frame_dst: Option<FrameDestination>,
98 ) -> QueuePacketOutcome
99 where
100 B: SplitByteSlice,
101 {
102 let was_empty = self.table.is_empty();
103 let outcome = match self.table.entry(key) {
104 btree_map::Entry::Vacant(entry) => {
105 let queue = entry.insert(PacketQueue::new(bindings_ctx));
106 queue
107 .try_push(|| QueuedPacket::new(dev, packet, frame_dst))
108 .expect("newly instantiated queue must have capacity");
109 QueuePacketOutcome::QueuedInNewQueue
110 }
111 btree_map::Entry::Occupied(mut entry) => {
112 match entry.get_mut().try_push(|| QueuedPacket::new(dev, packet, frame_dst)) {
113 Ok(()) => QueuePacketOutcome::QueuedInExistingQueue,
114 Err(PacketQueueFullError) => QueuePacketOutcome::ExistingQueueFull,
115 }
116 }
117 };
118
119 if was_empty && !self.table.is_empty() {
122 assert!(
123 bindings_ctx.schedule_timer(PENDING_ROUTE_GC_PERIOD, &mut self.gc_timer).is_none()
124 );
125 }
126
127 outcome
128 }
129
130 #[cfg(any(debug_assertions, test))]
131 pub(crate) fn contains(&self, key: &MulticastRouteKey<I>) -> bool {
132 self.table.contains_key(key)
133 }
134
135 pub(crate) fn remove(
139 &mut self,
140 key: &MulticastRouteKey<I>,
141 bindings_ctx: &mut BC,
142 ) -> Option<PacketQueue<I, D, BC>> {
143 let was_empty = self.table.is_empty();
144 let queue = self.table.remove(key);
145
146 if !was_empty && self.table.is_empty() {
150 let _: Option<BC::Instant> = bindings_ctx.cancel_timer(&mut self.gc_timer);
151 }
152
153 queue
154 }
155
156 pub(crate) fn run_garbage_collection(&mut self, bindings_ctx: &mut BC) -> u64 {
160 let now = bindings_ctx.now();
161 let mut removed_count = 0u64;
162 self.table.retain(|_key, packet_queue| {
163 if packet_queue.expires_at > now {
164 true
165 } else {
166 removed_count += packet_queue.queue.len() as u64;
169 false
170 }
171 });
172
173 if !self.table.is_empty() {
177 let _: Option<BC::Instant> =
178 bindings_ctx.schedule_timer(PENDING_ROUTE_GC_PERIOD, &mut self.gc_timer);
179 }
180
181 removed_count
182 }
183}
184
185impl<I: IpLayerIpExt, D: WeakDeviceIdentifier, BT: MulticastForwardingBindingsTypes> Inspectable
186 for MulticastForwardingPendingPackets<I, D, BT>
187{
188 fn record<II: Inspector>(&self, inspector: &mut II) {
189 let MulticastForwardingPendingPackets { table, gc_timer: _ } = self;
190 inspector.record_usize("NumRoutes", table.len())
193 }
194}
195
/// The result of [`MulticastForwardingPendingPackets::try_queue_packet`].
#[derive(Debug, PartialEq)]
pub(crate) enum QueuePacketOutcome {
    /// No queue existed for the key; a new queue was created and the packet
    /// was added to it.
    QueuedInNewQueue,
    /// A queue already existed for the key and the packet was appended to it.
    QueuedInExistingQueue,
    /// A queue already existed for the key but was full, so the packet was not
    /// queued.
    ExistingQueueFull,
}
209
/// A bounded queue of packets pending for a single [`MulticastRouteKey`].
///
/// Holds at most [`PACKET_QUEUE_LEN`] packets.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub struct PacketQueue<I: Ip, D: WeakDeviceIdentifier, BT: MulticastForwardingBindingsTypes> {
    /// The queued packets, in the order they were pushed.
    queue: ArrayVec<QueuedPacket<I, D>, PACKET_QUEUE_LEN>,
    /// Once the current time reaches this instant, the queue is dropped by the
    /// next garbage-collection pass.
    expires_at: BT::Instant,
}
218
219impl<I: IpLayerIpExt, D: WeakDeviceIdentifier, BC: MulticastForwardingBindingsContext<I, D::Strong>>
220 PacketQueue<I, D, BC>
221{
222 fn new(bindings_ctx: &mut BC) -> Self {
223 Self {
224 queue: Default::default(),
225 expires_at: bindings_ctx.now().panicking_add(PENDING_ROUTE_EXPIRATION),
226 }
227 }
228
229 fn try_push(
236 &mut self,
237 packet_builder: impl FnOnce() -> QueuedPacket<I, D>,
238 ) -> Result<(), PacketQueueFullError> {
239 if self.queue.is_full() {
240 return Err(PacketQueueFullError);
241 }
242 self.queue.push(packet_builder());
243 Ok(())
244 }
245}
246
/// Error returned by [`PacketQueue::try_push`] when the queue has no spare
/// capacity.
#[derive(Debug)]
struct PacketQueueFullError;
249
250impl<I: Ip, D: WeakDeviceIdentifier, BT: MulticastForwardingBindingsTypes> IntoIterator
251 for PacketQueue<I, D, BT>
252{
253 type Item = QueuedPacket<I, D>;
254 type IntoIter = <ArrayVec<QueuedPacket<I, D>, PACKET_QUEUE_LEN> as IntoIterator>::IntoIter;
255 fn into_iter(self) -> Self::IntoIter {
256 let Self { queue, expires_at: _ } = self;
257 queue.into_iter()
258 }
259}
260
/// A packet held in a [`PacketQueue`], together with its metadata.
#[derive(Debug, PartialEq)]
pub struct QueuedPacket<I: Ip, D: WeakDeviceIdentifier> {
    /// Weak reference to the device the packet is associated with
    /// (presumably the device it was received on — confirm with callers).
    pub(crate) device: D,
    /// An owned copy of the packet's bytes.
    pub(crate) packet: ValidIpPacketBuf<I>,
    /// The frame destination supplied when the packet was queued, if any.
    pub(crate) frame_dst: Option<FrameDestination>,
}
272
273impl<I: IpLayerIpExt, D: WeakDeviceIdentifier> QueuedPacket<I, D> {
274 fn new<B: SplitByteSlice>(
275 device: &D::Strong,
276 packet: &I::Packet<B>,
277 frame_dst: Option<FrameDestination>,
278 ) -> Self {
279 QueuedPacket {
280 device: device.downgrade(),
281 packet: ValidIpPacketBuf::new(packet),
282 frame_dst,
283 }
284 }
285}
286
/// An owned buffer holding the bytes of an IP packet that had already been
/// parsed successfully when the buffer was created.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct ValidIpPacketBuf<I: Ip> {
    /// The copied packet bytes.
    buffer: Buf<Vec<u8>>,
    /// Marker tying the buffer to IP version `I`.
    _version_marker: IpVersionMarker<I>,
}
296
297impl<I: IpLayerIpExt> ValidIpPacketBuf<I> {
298 fn new<B: SplitByteSlice>(packet: &I::Packet<B>) -> Self {
299 Self { buffer: Buf::new(packet.to_vec(), ..), _version_marker: Default::default() }
300 }
301
302 pub(crate) fn parse_ip_packet_mut(&mut self) -> I::Packet<&mut [u8]> {
310 self.buffer.parse_mut().unwrap()
312 }
313
314 pub(crate) fn into_inner(self) -> Buf<Vec<u8>> {
315 let Self { buffer, _version_marker } = self;
316 buffer
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;
    use ip_test_macro::ip_test;
    use netstack3_base::testutil::{
        FakeInstant, FakeTimerCtxExt, FakeWeakDeviceId, MultipleDevicesId,
    };
    use netstack3_base::{CounterContext, InstantContext, StrongDeviceIdentifier, TimerContext};
    use packet::ParseBuffer;
    use static_assertions::const_assert;
    use test_case::test_case;

    use crate::internal::multicast_forwarding;
    use crate::internal::multicast_forwarding::counters::MulticastForwardingCounters;
    use crate::internal::multicast_forwarding::testutil::{
        FakeBindingsCtx, FakeCoreCtx, TestIpExt,
    };

    // Exercises the queueing outcomes (new queue, existing queue, full queue)
    // and checks that removed queues hold the expected packets.
    #[ip_test(I)]
    #[test_case(None; "no_frame_dst")]
    #[test_case(Some(FrameDestination::Multicast); "some_frame_dst")]
    fn queue_packet<I: TestIpExt>(frame_dst: Option<FrameDestination>) {
        const DEV: MultipleDevicesId = MultipleDevicesId::A;
        let key1 = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
        let key2 = MulticastRouteKey::new(I::SRC2, I::DST2).unwrap();
        let key3 = MulticastRouteKey::new(I::SRC1, I::DST2).unwrap();

        // The same packet is queued under every key; the table itself does not
        // compare the packet's addresses against the key.
        let buf = multicast_forwarding::testutil::new_ip_packet_buf::<I>(I::SRC1, I::DST1);
        let mut buf_ref = buf.as_ref();
        let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");

        let mut bindings_ctx = FakeBindingsCtx::<I, MultipleDevicesId>::default();

        let mut pending_table =
            MulticastForwardingPendingPackets::<
                I,
                <MultipleDevicesId as StrongDeviceIdentifier>::Weak,
                _,
            >::new::<FakeCoreCtx<I, MultipleDevicesId>>(&mut bindings_ctx);

        // The first packet for key1 creates a new queue.
        assert_eq!(
            pending_table.try_queue_packet(
                &mut bindings_ctx,
                key1.clone(),
                &packet,
                &DEV,
                frame_dst
            ),
            QueuePacketOutcome::QueuedInNewQueue
        );
        // Fill the remaining capacity of key1's queue.
        for _ in 1..PACKET_QUEUE_LEN {
            assert_eq!(
                pending_table.try_queue_packet(
                    &mut bindings_ctx,
                    key1.clone(),
                    &packet,
                    &DEV,
                    frame_dst
                ),
                QueuePacketOutcome::QueuedInExistingQueue
            );
        }
        // key1's queue is now full, so further packets are rejected.
        assert_eq!(
            pending_table.try_queue_packet(
                &mut bindings_ctx,
                key1.clone(),
                &packet,
                &DEV,
                frame_dst
            ),
            QueuePacketOutcome::ExistingQueueFull
        );

        // A different key gets its own, independent queue.
        assert_eq!(
            pending_table.try_queue_packet(
                &mut bindings_ctx,
                key2.clone(),
                &packet,
                &DEV,
                frame_dst
            ),
            QueuePacketOutcome::QueuedInNewQueue
        );

        // Every queued packet should equal a `QueuedPacket` built from the same
        // inputs.
        let expected_packet = QueuedPacket::new(&DEV, &packet, frame_dst);
        let queue =
            pending_table.remove(&key1, &mut bindings_ctx).expect("key1 should have a queue");
        assert_eq!(queue.queue.len(), PACKET_QUEUE_LEN);
        for packet in queue.queue.as_slice() {
            assert_eq!(packet, &expected_packet);
        }

        let queue =
            pending_table.remove(&key2, &mut bindings_ctx).expect("key2 should have a queue");
        let packet = assert_matches!(&queue.queue[..], [p] => p);
        assert_eq!(packet, &expected_packet);

        // Nothing was ever queued for key3.
        assert_matches!(pending_table.remove(&key3, &mut bindings_ctx), None);
    }

    // Returns the instant the pending table's GC timer is scheduled for, or
    // `None` if it is not scheduled.
    fn next_gc_time<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
    ) -> Option<FakeInstant> {
        multicast_forwarding::testutil::with_pending_table(core_ctx, |pending_table| {
            bindings_ctx.scheduled_instant(&mut pending_table.gc_timer)
        })
    }

    // Builds a packet whose source/destination match `key` and queues it in
    // the pending table reachable through `core_ctx`.
    fn try_queue_packet<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
        key: MulticastRouteKey<I>,
        dev: &MultipleDevicesId,
        frame_dst: Option<FrameDestination>,
    ) -> QueuePacketOutcome {
        let buf =
            multicast_forwarding::testutil::new_ip_packet_buf::<I>(key.src_addr(), key.dst_addr());
        let mut buf_ref = buf.as_ref();
        let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
        multicast_forwarding::testutil::with_pending_table(core_ctx, |pending_table| {
            pending_table.try_queue_packet(bindings_ctx, key, &packet, dev, frame_dst)
        })
    }

    // Removes and returns the queue for `key` from the pending table reachable
    // through `core_ctx`.
    fn remove_packet_queue<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
        key: &MulticastRouteKey<I>,
    ) -> Option<
        PacketQueue<I, FakeWeakDeviceId<MultipleDevicesId>, FakeBindingsCtx<I, MultipleDevicesId>>,
    > {
        multicast_forwarding::testutil::with_pending_table(core_ctx, |pending_table| {
            pending_table.remove(key, bindings_ctx)
        })
    }

    // Triggers timers up to the current fake time and asserts that exactly one
    // fired: the pending-packets GC timer.
    fn run_gc<I: TestIpExt>(
        core_ctx: &mut FakeCoreCtx<I, MultipleDevicesId>,
        bindings_ctx: &mut FakeBindingsCtx<I, MultipleDevicesId>,
    ) {
        assert_matches!(
            &bindings_ctx.trigger_timers_until_instant(bindings_ctx.now(), core_ctx)[..],
            [MulticastForwardingTimerId::PendingPacketsGc(_)]
        );
    }

    // Verifies GC timer scheduling/cancellation and that GC drops only expired
    // queues, updating the GC counters accordingly.
    #[ip_test(I)]
    fn garbage_collection<I: TestIpExt>() {
        const DEV: MultipleDevicesId = MultipleDevicesId::A;
        const FRAME_DST: Option<FrameDestination> = None;
        let key1 = MulticastRouteKey::<I>::new(I::SRC1, I::DST1).unwrap();
        let key2 = MulticastRouteKey::<I>::new(I::SRC2, I::DST2).unwrap();

        let mut api = multicast_forwarding::testutil::new_api();
        assert!(api.enable());
        let (core_ctx, bindings_ctx) = api.contexts();

        // The test relies on two timing properties:
        // - A queue created when the GC timer is scheduled has expired by the
        //   time that timer fires, so the GC period must be >= the expiration.
        // - A queue created exactly at GC time survives that GC pass, so the
        //   expiration must be non-zero.
        const_assert!(PENDING_ROUTE_GC_PERIOD.checked_sub(PENDING_ROUTE_EXPIRATION).is_some());
        const_assert!(!PENDING_ROUTE_EXPIRATION.is_zero());

        // With nothing queued, no GC is scheduled and no GC has run.
        assert!(next_gc_time(core_ctx, bindings_ctx).is_none());
        let counters: &MulticastForwardingCounters<I> = core_ctx.counters();
        assert_eq!(counters.pending_table_gc.get(), 0);
        assert_eq!(counters.pending_packet_drops_gc.get(), 0);

        // Queueing into the empty table schedules the first GC.
        let expected_first_gc = bindings_ctx.now() + PENDING_ROUTE_GC_PERIOD;
        assert_eq!(
            try_queue_packet(core_ctx, bindings_ctx, key1.clone(), &DEV, FRAME_DST),
            QueuePacketOutcome::QueuedInNewQueue
        );
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(expected_first_gc));

        // Advance to the GC instant without firing the timer, then queue a
        // second key. The table was already non-empty, so the schedule is
        // unchanged.
        bindings_ctx.timers.instant.sleep(PENDING_ROUTE_GC_PERIOD);
        assert_eq!(
            try_queue_packet(core_ctx, bindings_ctx, key2.clone(), &DEV, FRAME_DST),
            QueuePacketOutcome::QueuedInNewQueue
        );
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(expected_first_gc));

        // GC drops key1's expired queue, keeps key2's fresh one, and reschedules
        // because the table is still non-empty.
        run_gc(core_ctx, bindings_ctx);
        let expected_second_gc = bindings_ctx.timers.instant.now() + PENDING_ROUTE_GC_PERIOD;
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(expected_second_gc));

        let counters: &MulticastForwardingCounters<I> = core_ctx.counters();
        assert_eq!(counters.pending_table_gc.get(), 1);
        assert_eq!(counters.pending_packet_drops_gc.get(), 1);
        assert_matches!(remove_packet_queue(core_ctx, bindings_ctx, &key1), None);
        assert_matches!(remove_packet_queue(core_ctx, bindings_ctx, &key2), Some(_));

        // Removing the last queue empties the table and cancels the GC timer.
        assert!(next_gc_time(core_ctx, bindings_ctx).is_none());

        // Queueing again reschedules GC. One period later, GC drops the queue
        // and leaves the timer unscheduled because the table is empty.
        assert_eq!(
            try_queue_packet(core_ctx, bindings_ctx, key1.clone(), &DEV, FRAME_DST),
            QueuePacketOutcome::QueuedInNewQueue
        );
        assert_eq!(next_gc_time(core_ctx, bindings_ctx), Some(expected_second_gc));
        bindings_ctx.timers.instant.sleep(PENDING_ROUTE_GC_PERIOD);
        run_gc(core_ctx, bindings_ctx);
        let counters: &MulticastForwardingCounters<I> = core_ctx.counters();
        assert_eq!(counters.pending_table_gc.get(), 2);
        assert_eq!(counters.pending_packet_drops_gc.get(), 2);
        assert_matches!(remove_packet_queue(core_ctx, bindings_ctx, &key1), None);
        assert!(next_gc_time(core_ctx, bindings_ctx).is_none());
    }
}