Skip to main content

netstack3_ip/
multicast_forwarding.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! An implementation of multicast forwarding.
6//!
7//! Multicast forwarding is the ability for netstack to forward multicast
8//! packets that arrive on an interface out multiple interfaces (while also
9//! optionally delivering the packet to the host itself if the arrival host has
10//! an interest in the packet).
11//!
12//! Note that multicast forwarding decisions are made by consulting the
13//! multicast routing table, a routing table entirely separate from the unicast
14//! routing table(s).
15
16pub(crate) mod api;
17pub(crate) mod counters;
18pub(crate) mod packet_queue;
19pub(crate) mod route;
20pub(crate) mod state;
21
22use core::sync::atomic::Ordering;
23
24use net_types::ip::{GenericOverIp, Ip, IpVersionMarker};
25use netstack3_base::{
26    AnyDevice, AtomicInstant, CounterContext, DeviceIdContext, EventContext, FrameDestination,
27    HandleableTimer, InstantBindingsTypes, InstantContext, TimerBindingsTypes, TimerContext,
28    WeakDeviceIdentifier,
29};
30use packet_formats::ip::IpPacket;
31use zerocopy::SplitByteSlice;
32
33use crate::internal::multicast_forwarding::counters::MulticastForwardingCounters;
34use crate::internal::multicast_forwarding::packet_queue::QueuePacketOutcome;
35use crate::internal::multicast_forwarding::route::{
36    Action, MulticastRouteEntry, MulticastRouteTargets,
37};
38use crate::multicast_forwarding::{
39    MulticastForwardingPendingPacketsContext, MulticastForwardingState,
40    MulticastForwardingStateContext, MulticastRoute, MulticastRouteKey,
41    MulticastRouteTableContext as _,
42};
43use crate::{IpLayerEvent, IpLayerIpExt};
44
45/// Required types for multicast forwarding provided by Bindings.
46pub trait MulticastForwardingBindingsTypes: InstantBindingsTypes + TimerBindingsTypes {}
47impl<BT: InstantBindingsTypes + TimerBindingsTypes> MulticastForwardingBindingsTypes for BT {}
48
49/// Required functionality for multicast forwarding provided by Bindings.
50pub trait MulticastForwardingBindingsContext<I: IpLayerIpExt, D>:
51    MulticastForwardingBindingsTypes + InstantContext + TimerContext + EventContext<IpLayerEvent<D, I>>
52{
53}
54impl<
55    I: IpLayerIpExt,
56    D,
57    BC: MulticastForwardingBindingsTypes
58        + InstantContext
59        + TimerContext
60        + EventContext<IpLayerEvent<D, I>>,
61> MulticastForwardingBindingsContext<I, D> for BC
62{
63}
64
65/// Device related functionality required by multicast forwarding.
66pub trait MulticastForwardingDeviceContext<I: IpLayerIpExt>: DeviceIdContext<AnyDevice> {
67    /// True if the given device has multicast forwarding enabled.
68    fn is_device_multicast_forwarding_enabled(&mut self, dev: &Self::DeviceId) -> bool;
69}
70
/// A timer event for multicast forwarding.
///
/// The timer is dispatched through the [`HandleableTimer`] implementation on
/// this type.
#[derive(Clone, Debug, Eq, GenericOverIp, Hash, PartialEq)]
#[generic_over_ip(I, Ip)]
pub enum MulticastForwardingTimerId<I: Ip> {
    /// A trigger to perform garbage collection on the pending packets table.
    ///
    /// When handled while multicast forwarding is enabled, the pending table's
    /// garbage collection is run and the GC counters are updated; when
    /// forwarding is disabled the timer is a no-op.
    PendingPacketsGc(IpVersionMarker<I>),
}
78
79impl<
80    I: IpLayerIpExt,
81    BC: MulticastForwardingBindingsContext<I, CC::DeviceId>,
82    CC: MulticastForwardingStateContext<I, BC> + CounterContext<MulticastForwardingCounters<I>>,
83> HandleableTimer<CC, BC> for MulticastForwardingTimerId<I>
84{
85    fn handle(self, core_ctx: &mut CC, bindings_ctx: &mut BC, _: BC::UniqueTimerId) {
86        match self {
87            MulticastForwardingTimerId::PendingPacketsGc(_) => {
88                core_ctx.with_state(|state, ctx| match state {
89                    // Multicast forwarding was disabled after GC was scheduled;
90                    // there are no resources to GC now.
91                    MulticastForwardingState::Disabled => {}
92                    MulticastForwardingState::Enabled(state) => {
93                        CounterContext::<MulticastForwardingCounters<I>>::counters(ctx)
94                            .pending_table_gc
95                            .increment();
96                        let removed_count = ctx.with_pending_table_mut(state, |pending_table| {
97                            pending_table.run_garbage_collection(bindings_ctx)
98                        });
99                        CounterContext::<MulticastForwardingCounters<I>>::counters(ctx)
100                            .pending_packet_drops_gc
101                            .add(removed_count);
102                    }
103                })
104            }
105        }
106    }
107}
108
/// Events that may be published by the multicast forwarding engine.
///
/// `D` is the device identifier type carried by the event; see
/// `map_device` and `upgrade_device_id` for converting between device types.
#[derive(Debug, Eq, Hash, PartialEq, GenericOverIp)]
#[generic_over_ip(I, Ip)]
pub enum MulticastForwardingEvent<I: IpLayerIpExt, D> {
    /// A multicast packet was received for which there was no applicable route.
    ///
    /// Published by the route lookup when the packet was placed into a newly
    /// created pending queue.
    MissingRoute {
        /// The key of the route that's missing.
        key: MulticastRouteKey<I>,
        /// The interface on which the packet was received.
        input_interface: D,
    },
    /// A multicast packet was received on an unexpected input interface.
    ///
    /// Published by the route lookup when a route exists for the packet's key
    /// but its input interface differs from the arrival device.
    WrongInputInterface {
        /// The key of the route with the unexpected input interface.
        key: MulticastRouteKey<I>,
        /// The interface on which the packet was received.
        actual_input_interface: D,
        /// The interface on which the packet was expected (as specified in the
        /// multicast route).
        expected_input_interface: D,
    },
}
131
132impl<I: IpLayerIpExt, D> MulticastForwardingEvent<I, D> {
133    pub(crate) fn map_device<O, F: Fn(D) -> O>(self, map: F) -> MulticastForwardingEvent<I, O> {
134        match self {
135            MulticastForwardingEvent::MissingRoute { key, input_interface } => {
136                MulticastForwardingEvent::MissingRoute {
137                    key,
138                    input_interface: map(input_interface),
139                }
140            }
141            MulticastForwardingEvent::WrongInputInterface {
142                key,
143                actual_input_interface,
144                expected_input_interface,
145            } => MulticastForwardingEvent::WrongInputInterface {
146                key,
147                actual_input_interface: map(actual_input_interface),
148                expected_input_interface: map(expected_input_interface),
149            },
150        }
151    }
152}
153
154impl<I: IpLayerIpExt, D: WeakDeviceIdentifier> MulticastForwardingEvent<I, D> {
155    /// Upgrades the device IDs held by this event.
156    pub fn upgrade_device_id(self) -> Option<MulticastForwardingEvent<I, D::Strong>> {
157        match self {
158            MulticastForwardingEvent::MissingRoute { key, input_interface } => {
159                Some(MulticastForwardingEvent::MissingRoute {
160                    key,
161                    input_interface: input_interface.upgrade()?,
162                })
163            }
164            MulticastForwardingEvent::WrongInputInterface {
165                key,
166                actual_input_interface,
167                expected_input_interface,
168            } => Some(MulticastForwardingEvent::WrongInputInterface {
169                key,
170                actual_input_interface: actual_input_interface.upgrade()?,
171                expected_input_interface: expected_input_interface.upgrade()?,
172            }),
173        }
174    }
175}
176
177/// Query the multicast route table and return the forwarding targets.
178///
179/// `None` may be returned in several situations:
180///   * if multicast forwarding is disabled (either stack-wide or for the
181///     provided `dev`),
182///   * if the packets src/dst addrs are not viable for multicast forwarding
183///     (see the requirements on [`MulticastRouteKey`]), or
184///   * if the route table does not have an entry suitable for this packet.
185///
186/// In the latter case, the packet is stashed in the
187/// [`MulticastForwardingPendingPackets`] table, and a relevant event is
188/// dispatched to bindings.
189///
190/// Note that the returned targets are not synchronized with the multicast route
191/// table and may grow stale if the table is updated.
192pub(crate) fn lookup_multicast_route_or_stash_packet<I, B, CC, BC>(
193    core_ctx: &mut CC,
194    bindings_ctx: &mut BC,
195    packet: &I::Packet<B>,
196    dev: &CC::DeviceId,
197    frame_dst: Option<FrameDestination>,
198) -> Option<MulticastRouteTargets<CC::DeviceId>>
199where
200    I: IpLayerIpExt,
201    B: SplitByteSlice,
202    CC: MulticastForwardingStateContext<I, BC>
203        + MulticastForwardingDeviceContext<I>
204        + CounterContext<MulticastForwardingCounters<I>>,
205    BC: MulticastForwardingBindingsContext<I, CC::DeviceId>,
206{
207    CounterContext::<MulticastForwardingCounters<I>>::counters(core_ctx).rx.increment();
208    // Short circuit if the packet's addresses don't constitute a valid
209    // multicast route key (e.g. src is not unicast, or dst is not multicast).
210    let Some(key) = MulticastRouteKey::new(packet.src_ip(), packet.dst_ip()) else {
211        CounterContext::<MulticastForwardingCounters<I>>::counters(core_ctx)
212            .no_tx_invalid_key
213            .increment();
214        return None;
215    };
216
217    // Short circuit if the device has forwarding disabled.
218    if !core_ctx.is_device_multicast_forwarding_enabled(dev) {
219        CounterContext::<MulticastForwardingCounters<I>>::counters(core_ctx)
220            .no_tx_disabled_dev
221            .increment();
222        return None;
223    }
224
225    core_ctx.with_state(|state, ctx| {
226        // Short circuit if forwarding is disabled stack-wide.
227        let Some(state) = state.enabled() else {
228            CounterContext::<MulticastForwardingCounters<I>>::counters(ctx)
229                .no_tx_disabled_stack_wide
230                .increment();
231            return None;
232        };
233        ctx.with_route_table(state, |route_table, ctx| {
234            if let Some(MulticastRouteEntry {
235                route: MulticastRoute { input_interface, action },
236                stats,
237            }) = route_table.get(&key)
238            {
239                if dev != input_interface {
240                    CounterContext::<MulticastForwardingCounters<I>>::counters(ctx)
241                        .no_tx_wrong_dev
242                        .increment();
243                    bindings_ctx.on_event(
244                        MulticastForwardingEvent::WrongInputInterface {
245                            key,
246                            actual_input_interface: dev.clone(),
247                            expected_input_interface: input_interface.clone(),
248                        }
249                        .into(),
250                    );
251                    return None;
252                }
253
254                stats.last_used.store_max(bindings_ctx.now(), Ordering::Relaxed);
255
256                match action {
257                    Action::Forward(targets) => {
258                        CounterContext::<MulticastForwardingCounters<I>>::counters(ctx)
259                            .tx
260                            .increment();
261                        return Some(targets.clone());
262                    }
263                }
264            }
265            CounterContext::<MulticastForwardingCounters<I>>::counters(ctx)
266                .pending_packets
267                .increment();
268            match ctx.with_pending_table_mut(state, |pending_table| {
269                pending_table.try_queue_packet(bindings_ctx, key.clone(), packet, dev, frame_dst)
270            }) {
271                QueuePacketOutcome::QueuedInNewQueue => {
272                    bindings_ctx.on_event(
273                        MulticastForwardingEvent::MissingRoute {
274                            key,
275                            input_interface: dev.clone(),
276                        }
277                        .into(),
278                    );
279                }
280                QueuePacketOutcome::QueuedInExistingQueue => {}
281                QueuePacketOutcome::ExistingQueueFull => {
282                    CounterContext::<MulticastForwardingCounters<I>>::counters(ctx)
283                        .pending_packet_drops_queue_full
284                        .increment();
285                }
286            }
287            return None;
288        })
289    })
290}
291
#[cfg(test)]
mod testutil {
    use super::*;

    use alloc::rc::Rc;
    use alloc::vec::Vec;
    use core::cell::RefCell;
    use derivative::Derivative;
    use net_declare::{net_ip_v4, net_ip_v6};
    use net_types::MulticastAddr;
    use net_types::ip::{Ipv4, Ipv4Addr, Ipv6, Ipv6Addr, Mtu};
    use netstack3_base::socket::SocketIpAddr;
    use netstack3_base::testutil::{FakeStrongDeviceId, MultipleDevicesId};
    use netstack3_base::{
        CoreTimerContext, CounterContext, CtxPair, FrameDestination, Marks, ResourceCounterContext,
    };
    use netstack3_filter::ProofOfEgressCheck;
    use netstack3_hashmap::HashSet;
    use packet::{BufferMut, InnerPacketBuilder, PacketBuilder, Serializer};
    use packet_formats::ip::{IpPacketBuilder, IpProto};

    use crate::device::IpDeviceSendContext;
    use crate::internal::base::DeviceIpLayerMetadata;
    use crate::internal::icmp::IcmpErrorHandler;
    use crate::multicast_forwarding::{
        MulticastForwardingApi, MulticastForwardingEnabledState, MulticastForwardingPendingPackets,
        MulticastForwardingPendingPacketsContext, MulticastForwardingState, MulticastRouteTable,
        MulticastRouteTableContext,
    };
    use crate::{IpCounters, IpDeviceMtuContext, IpLayerEvent, IpPacketDestination};

    /// Per-IP-version address constants shared by the multicast forwarding
    /// tests.
    pub(crate) trait TestIpExt: IpLayerIpExt {
        /// First source address.
        const SRC1: Self::Addr;
        /// Second source address, distinct from `SRC1`.
        const SRC2: Self::Addr;
        /// First destination address.
        const DST1: Self::Addr;
        /// Second destination address, distinct from `DST1`.
        const DST2: Self::Addr;
    }

    // IPv4 test addresses.
    impl TestIpExt for Ipv4 {
        const SRC1: Ipv4Addr = net_ip_v4!("192.0.2.1");
        const SRC2: Ipv4Addr = net_ip_v4!("192.0.2.2");
        const DST1: Ipv4Addr = net_ip_v4!("224.0.1.1");
        const DST2: Ipv4Addr = net_ip_v4!("224.0.1.2");
    }

    // IPv6 test addresses.
    impl TestIpExt for Ipv6 {
        const SRC1: Ipv6Addr = net_ip_v6!("2001:0DB8::1");
        const SRC2: Ipv6Addr = net_ip_v6!("2001:0DB8::2");
        const DST1: Ipv6Addr = net_ip_v6!("ff0e::1");
        const DST2: Ipv6Addr = net_ip_v6!("ff0e::2");
    }

    /// Serializes an IP packet from `src_addr` to `dst_addr` into a fresh
    /// buffer.
    ///
    /// The packet uses a TTL of 255, the UDP protocol number, and a fixed
    /// 10-byte body.
    ///
    /// # Panics
    ///
    /// Panics if serialization fails.
    pub(crate) fn new_ip_packet_buf<I: IpLayerIpExt>(
        src_addr: I::Addr,
        dst_addr: I::Addr,
    ) -> impl AsRef<[u8]> {
        const TTL: u8 = 255;
        /// Arbitrary data to put inside of an IP packet.
        const IP_BODY: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];

        let header = I::PacketBuilder::new(src_addr, dst_addr, TTL, IpProto::Udp.into());
        let body = IP_BODY.into_serializer();
        header.wrap_body(body).serialize_vec_outer().unwrap()
    }

    /// A record of one packet handed to the fake device layer.
    #[derive(Debug, PartialEq)]
    pub(crate) struct SentPacket<I, D>
    where
        I: IpLayerIpExt,
    {
        /// The multicast destination the packet was sent to.
        pub(crate) dst: MulticastAddr<I::Addr>,
        /// The device the packet was sent out of.
        pub(crate) device: D,
    }

    /// The state backing [`FakeCoreCtx`] in multicast forwarding tests.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    pub(crate) struct FakeCoreCtxState<I: IpLayerIpExt, D: FakeStrongDeviceId> {
        // NB: Kept behind `Rc<RefCell<...>>` so that borrows are checked at
        // runtime. That lets the multicast forwarding state be borrowed while
        // the enclosing `FakeCoreCtx` is itself mutably borrowed.
        pub(crate) multicast_forwarding:
            Rc<RefCell<MulticastForwardingState<I, D, FakeBindingsCtx<I, D>>>>,
        // Devices on which multicast forwarding is currently enabled.
        pub(crate) forwarding_enabled_devices: HashSet<D>,
        // Every packet the netstack has sent, in send order.
        pub(crate) sent_packets: Vec<SentPacket<I, D>>,
        stack_wide_counters: IpCounters<I>,
        per_device_counters: IpCounters<I>,
        multicast_forwarding_counters: MulticastForwardingCounters<I>,
    }

    impl<I, D> FakeCoreCtxState<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        /// Enables or disables multicast forwarding on `dev`.
        pub(crate) fn set_multicast_forwarding_enabled_for_dev(&mut self, dev: D, enabled: bool) {
            let devices = &mut self.forwarding_enabled_devices;
            // Whether the set actually changed is irrelevant to callers.
            let _: bool = if enabled { devices.insert(dev) } else { devices.remove(&dev) };
        }

        /// Returns all packets recorded so far, leaving the record empty.
        pub(crate) fn take_sent_packets(&mut self) -> Vec<SentPacket<I, D>> {
            let recorded = &mut self.sent_packets;
            core::mem::take(recorded)
        }
    }

    impl<I, D> CounterContext<IpCounters<I>> for FakeCoreCtxState<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        fn counters(&self) -> &IpCounters<I> {
            let Self { stack_wide_counters, .. } = self;
            stack_wide_counters
        }
    }

    impl<I, D> ResourceCounterContext<D, IpCounters<I>> for FakeCoreCtxState<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        // NB: A single counter set is shared by every device.
        fn per_resource_counters(&self, _resource: &D) -> &IpCounters<I> {
            let Self { per_device_counters, .. } = self;
            per_device_counters
        }
    }

    impl<I, D> CounterContext<MulticastForwardingCounters<I>> for FakeCoreCtxState<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        fn counters(&self) -> &MulticastForwardingCounters<I> {
            let Self { multicast_forwarding_counters, .. } = self;
            multicast_forwarding_counters
        }
    }

    /// The fake bindings context: multicast forwarding timers and IP layer
    /// events, with no additional state.
    pub(crate) type FakeBindingsCtx<I, D> = netstack3_base::testutil::FakeBindingsCtx<
        MulticastForwardingTimerId<I>,
        IpLayerEvent<D, I>,
        (),
        (),
    >;
    /// The fake core context, backed by [`FakeCoreCtxState`].
    pub(crate) type FakeCoreCtx<I, D> =
        netstack3_base::testutil::FakeCoreCtx<FakeCoreCtxState<I, D>, (), D>;

    impl<I, D> MulticastForwardingStateContext<I, FakeBindingsCtx<I, D>> for FakeCoreCtx<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        type Ctx<'a> = FakeCoreCtx<I, D>;

        fn with_state<O, F>(&mut self, cb: F) -> O
        where
            F: FnOnce(
                &MulticastForwardingState<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
                &mut Self::Ctx<'_>,
            ) -> O,
        {
            // Clone the `Rc` so the borrow of the state is independent of the
            // mutable borrow of `self` handed to the callback.
            let shared = Rc::clone(&self.state.multicast_forwarding);
            let guard = shared.borrow();
            cb(&*guard, self)
        }

        fn with_state_mut<O, F>(&mut self, cb: F) -> O
        where
            F: FnOnce(
                &mut MulticastForwardingState<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
                &mut Self::Ctx<'_>,
            ) -> O,
        {
            let shared = Rc::clone(&self.state.multicast_forwarding);
            let mut guard = shared.borrow_mut();
            cb(&mut *guard, self)
        }
    }

    impl<I, D> MulticastRouteTableContext<I, FakeBindingsCtx<I, D>> for FakeCoreCtx<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        type Ctx<'a> = FakeCoreCtx<I, D>;

        fn with_route_table<O, F>(
            &mut self,
            state: &MulticastForwardingEnabledState<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
            cb: F,
        ) -> O
        where
            F: FnOnce(
                &MulticastRouteTable<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
                &mut Self::Ctx<'_>,
            ) -> O,
        {
            let table_guard = state.route_table().read();
            cb(&*table_guard, self)
        }

        fn with_route_table_mut<O, F>(
            &mut self,
            state: &MulticastForwardingEnabledState<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
            cb: F,
        ) -> O
        where
            F: FnOnce(
                &mut MulticastRouteTable<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
                &mut Self::Ctx<'_>,
            ) -> O,
        {
            let mut table_guard = state.route_table().write();
            cb(&mut *table_guard, self)
        }
    }

    impl<I, D> MulticastForwardingPendingPacketsContext<I, FakeBindingsCtx<I, D>>
        for FakeCoreCtx<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        fn with_pending_table_mut<O, F>(
            &mut self,
            state: &MulticastForwardingEnabledState<I, Self::DeviceId, FakeBindingsCtx<I, D>>,
            cb: F,
        ) -> O
        where
            F: FnOnce(
                &mut MulticastForwardingPendingPackets<I, Self::WeakDeviceId, FakeBindingsCtx<I, D>>,
            ) -> O,
        {
            let mut pending_guard = state.pending_table().lock();
            cb(&mut *pending_guard)
        }
    }

    impl<I, D> MulticastForwardingDeviceContext<I> for FakeCoreCtx<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        fn is_device_multicast_forwarding_enabled(&mut self, device_id: &Self::DeviceId) -> bool {
            let enabled_devices = &self.state.forwarding_enabled_devices;
            enabled_devices.contains(device_id)
        }
    }

    impl<I, D> CoreTimerContext<MulticastForwardingTimerId<I>, FakeBindingsCtx<I, D>>
        for FakeCoreCtx<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        // The bindings timer ID is the multicast forwarding timer ID itself,
        // so conversion is the identity.
        fn convert_timer(
            dispatch_id: MulticastForwardingTimerId<I>,
        ) -> MulticastForwardingTimerId<I> {
            dispatch_id
        }
    }

    impl<I, D> IpDeviceSendContext<I, FakeBindingsCtx<I, D>> for FakeCoreCtx<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        /// Records the send in `sent_packets` instead of transmitting.
        ///
        /// # Panics
        ///
        /// Panics if `destination` is not a multicast destination.
        fn send_ip_frame<S>(
            &mut self,
            _bindings_ctx: &mut FakeBindingsCtx<I, D>,
            device_id: &D,
            destination: IpPacketDestination<I, &D>,
            _ip_layer_metadata: DeviceIpLayerMetadata<FakeBindingsCtx<I, D>>,
            _body: S,
            _egress_proof: ProofOfEgressCheck,
        ) -> Result<(), netstack3_base::SendFrameError<S>>
        where
            S: Serializer,
            S::Buffer: BufferMut,
        {
            let group = match destination {
                IpPacketDestination::Multicast(group) => group,
                dst => panic!("unexpected sent packet: destination={dst:?}"),
            };
            let record = SentPacket { dst: group, device: device_id.clone() };
            self.state.sent_packets.push(record);
            Ok(())
        }
    }

    impl<I, D> IpDeviceMtuContext<I> for FakeCoreCtx<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        // Every fake device reports the maximum MTU.
        fn get_mtu(&mut self, _device_id: &Self::DeviceId) -> Mtu {
            Mtu::max()
        }
    }

    impl<I, D> IcmpErrorHandler<I, FakeBindingsCtx<I, D>> for FakeCoreCtx<I, D>
    where
        I: IpLayerIpExt,
        D: FakeStrongDeviceId,
    {
        // ICMP error generation is not supported by this fake; calling it
        // panics.
        fn send_icmp_error_message<B: BufferMut>(
            &mut self,
            _bindings_ctx: &mut FakeBindingsCtx<I, D>,
            _device: Option<&D>,
            _frame_dst: Option<FrameDestination>,
            _src_ip: SocketIpAddr<I::Addr>,
            _dst_ip: SocketIpAddr<I::Addr>,
            _original_packet: B,
            _error: I::IcmpError,
            _header_len: usize,
            _proto: I::Proto,
            _marks: &Marks,
        ) {
            unimplemented!()
        }
    }

    /// Builds a [`MulticastForwardingApi`] over fake contexts whose core state
    /// is default-initialized.
    pub(crate) fn new_api<I: IpLayerIpExt>() -> MulticastForwardingApi<
        I,
        CtxPair<FakeCoreCtx<I, MultipleDevicesId>, FakeBindingsCtx<I, MultipleDevicesId>>,
    > {
        let core_ctx = FakeCoreCtx::with_state(Default::default());
        let ctx_pair = CtxPair::with_core_ctx(core_ctx);
        MulticastForwardingApi::new(ctx_pair)
    }

    /// Runs `cb` with mutable access to the
    /// [`MulticastForwardingPendingPackets`] table.
    ///
    /// # Panics
    ///
    /// Panics if multicast forwarding is disabled.
    pub(crate) fn with_pending_table<I, O, F, CC, BT>(core_ctx: &mut CC, cb: F) -> O
    where
        I: IpLayerIpExt,
        CC: MulticastForwardingStateContext<I, BT>,
        BT: MulticastForwardingBindingsTypes,
        F: FnOnce(&mut MulticastForwardingPendingPackets<I, CC::WeakDeviceId, BT>) -> O,
    {
        core_ctx.with_state(|forwarding_state, state_ctx| {
            let enabled_state = forwarding_state.enabled().unwrap();
            // The pending table is reached through the route table context.
            state_ctx.with_route_table(enabled_state, |_route_table, table_ctx| {
                table_ctx.with_pending_table_mut(enabled_state, |pending| cb(pending))
            })
        })
    }
}
614
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::vec;
    use core::time::Duration;

    use ip_test_macro::ip_test;
    use netstack3_base::testutil::MultipleDevicesId;
    use packet::ParseBuffer;
    use test_case::test_case;
    use testutil::TestIpExt;

    use crate::internal::multicast_forwarding::route::MulticastRouteStats;
    use crate::multicast_forwarding::MulticastRouteTarget;

    /// The knobs varied by `lookup_route`; each `false` flips exactly one
    /// precondition of a successful lookup.
    struct LookupTestCase {
        // Is multicast forwarding enabled for the netstack as a whole?
        enabled: bool,
        // Is multicast forwarding enabled for the arrival device?
        dev_enabled: bool,
        // Do the packet's src/dst addrs match the installed route?
        right_key: bool,
        // Did the packet arrive on the route's input interface?
        right_dev: bool,
    }
    /// The baseline case in which every precondition holds.
    const LOOKUP_SUCCESS_CASE: LookupTestCase =
        LookupTestCase { enabled: true, dev_enabled: true, right_key: true, right_dev: true };

    /// Performs a single route lookup under the given conditions, checks the
    /// resulting events, route statistics and counters, and returns whether
    /// the lookup produced forwarding targets.
    #[ip_test(I)]
    #[test_case(LOOKUP_SUCCESS_CASE => true; "success")]
    #[test_case(LookupTestCase{enabled: false, ..LOOKUP_SUCCESS_CASE} => false; "disabled")]
    #[test_case(LookupTestCase{dev_enabled: false, ..LOOKUP_SUCCESS_CASE} => false; "dev_disabled")]
    #[test_case(LookupTestCase{right_key: false, ..LOOKUP_SUCCESS_CASE} => false; "wrong_key")]
    #[test_case(LookupTestCase{right_dev: false, ..LOOKUP_SUCCESS_CASE} => false; "wrong_dev")]
    fn lookup_route<I: TestIpExt>(test_case: LookupTestCase) -> bool {
        const FRAME_DST: Option<FrameDestination> = None;
        let LookupTestCase { enabled, dev_enabled, right_key, right_dev } = test_case;
        let mut api = testutil::new_api::<I>();

        // The key the route is installed under vs. the key of the packet.
        let expected_key = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
        let actual_key = match right_key {
            true => expected_key.clone(),
            false => MulticastRouteKey::new(I::SRC2, I::DST2).unwrap(),
        };

        // The route's input interface vs. the device the packet arrives on.
        let expected_dev = MultipleDevicesId::A;
        let actual_dev = match right_dev {
            true => expected_dev,
            false => MultipleDevicesId::B,
        };

        if enabled {
            assert!(api.enable());
            // NB: Routes can only be installed while forwarding is enabled;
            // installation fails otherwise.
            let target =
                MulticastRouteTarget { output_interface: MultipleDevicesId::C, min_ttl: 0 };
            let route = MulticastRoute::new_forward(expected_dev, [target].into()).unwrap();
            assert_eq!(api.add_multicast_route(expected_key.clone(), route), Ok(None));
        }

        api.core_ctx().state.set_multicast_forwarding_enabled_for_dev(actual_dev, dev_enabled);

        // Advance the clock so that the lookup time is distinguishable from
        // the route's creation time.
        let (core_ctx, bindings_ctx) = api.contexts();
        let creation_time = bindings_ctx.now();
        bindings_ctx.timers.instant.sleep(Duration::from_secs(5));
        let lookup_time = bindings_ctx.now();
        assert!(lookup_time > creation_time);

        let packet_bytes =
            testutil::new_ip_packet_buf::<I>(actual_key.src_addr(), actual_key.dst_addr());
        let mut remaining = packet_bytes.as_ref();
        let packet = remaining.parse::<I::Packet<_>>().expect("parse should succeed");

        let lookup_result = lookup_multicast_route_or_stash_packet(
            core_ctx,
            bindings_ctx,
            &packet,
            &actual_dev,
            FRAME_DST,
        );

        // Check the multicast routing events that were dispatched.
        let mut expected_events = vec![];
        if !right_key {
            let missing_route = MulticastForwardingEvent::MissingRoute {
                key: actual_key.clone(),
                input_interface: actual_dev,
            };
            expected_events.push(IpLayerEvent::MulticastForwarding(missing_route));
        }
        if !right_dev {
            let wrong_interface = MulticastForwardingEvent::WrongInputInterface {
                key: actual_key,
                actual_input_interface: actual_dev,
                expected_input_interface: expected_dev,
            };
            expected_events.push(IpLayerEvent::MulticastForwarding(wrong_interface));
        }
        assert_eq!(bindings_ctx.take_events(), expected_events);

        let lookup_succeeded = lookup_result.is_some();

        if enabled {
            // `last_used` only advances to the lookup time on success;
            // otherwise it still holds the creation time.
            let expected_stats = MulticastRouteStats {
                last_used: if lookup_succeeded { lookup_time } else { creation_time },
            };
            assert_eq!(api.get_route_stats(&expected_key), Ok(Some(expected_stats)));
        }

        // Check that each counter reflects the outcome.
        let counters: &MulticastForwardingCounters<I> = api.core_ctx().counters();
        assert_eq!(counters.rx.get(), 1);
        assert_eq!(counters.tx.get(), if lookup_succeeded { 1 } else { 0 });
        assert_eq!(counters.no_tx_disabled_dev.get(), if !dev_enabled { 1 } else { 0 });
        assert_eq!(counters.no_tx_disabled_stack_wide.get(), if !enabled { 1 } else { 0 });
        assert_eq!(counters.no_tx_wrong_dev.get(), if !right_dev { 1 } else { 0 });
        assert_eq!(counters.pending_packets.get(), if !right_key { 1 } else { 0 });

        lookup_succeeded
    }
}