netstack3_ip/multicast_forwarding/
api.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Declares the API for configuring multicast forwarding within the netstack.
6
7use alloc::collections::btree_map;
8use core::sync::atomic::Ordering;
9
10use log::warn;
11use net_types::ip::{Ip, IpVersionMarker};
12use net_types::SpecifiedAddr;
13use netstack3_base::{
14    AnyDevice, AtomicInstant, ContextPair, CoreTimerContext, CounterContext, DeviceIdContext,
15    Inspector, InspectorDeviceExt, InstantBindingsTypes, InstantContext, StrongDeviceIdentifier,
16    WeakDeviceIdentifier,
17};
18
19use crate::internal::base::IpLayerForwardingContext;
20use crate::internal::multicast_forwarding::counters::MulticastForwardingCounters;
21use crate::internal::multicast_forwarding::packet_queue::{PacketQueue, QueuedPacket};
22use crate::internal::multicast_forwarding::route::{
23    Action, MulticastRoute, MulticastRouteEntry, MulticastRouteKey, MulticastRouteStats,
24    MulticastRouteTarget,
25};
26use crate::internal::multicast_forwarding::state::{
27    MulticastForwardingEnabledState, MulticastForwardingPendingPacketsContext as _,
28    MulticastForwardingState, MulticastForwardingStateContext, MulticastRouteTableContext as _,
29};
30use crate::internal::multicast_forwarding::{
31    MulticastForwardingBindingsTypes, MulticastForwardingDeviceContext, MulticastForwardingEvent,
32    MulticastForwardingTimerId,
33};
34use crate::{IpLayerBindingsContext, IpLayerIpExt, IpPacketDestination};
35
/// The API action can not be performed while multicast forwarding is disabled.
///
/// Returned by the [`MulticastForwardingApi`] operations that require
/// multicast forwarding to be enabled. The error carries no data, so it is
/// freely copyable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MulticastForwardingDisabledError {}
39
40trait MulticastForwardingStateExt<
41    I: IpLayerIpExt,
42    D: StrongDeviceIdentifier,
43    BT: MulticastForwardingBindingsTypes,
44>
45{
46    fn try_enabled(
47        &self,
48    ) -> Result<&MulticastForwardingEnabledState<I, D, BT>, MulticastForwardingDisabledError>;
49}
50
51impl<I: IpLayerIpExt, D: StrongDeviceIdentifier, BT: MulticastForwardingBindingsTypes>
52    MulticastForwardingStateExt<I, D, BT> for MulticastForwardingState<I, D, BT>
53{
54    fn try_enabled(
55        &self,
56    ) -> Result<&MulticastForwardingEnabledState<I, D, BT>, MulticastForwardingDisabledError> {
57        self.enabled().ok_or(MulticastForwardingDisabledError {})
58    }
59}
60
61/// The multicast forwarding API.
62pub struct MulticastForwardingApi<I: Ip, C> {
63    ctx: C,
64    _ip_mark: IpVersionMarker<I>,
65}
66
67impl<I: Ip, C> MulticastForwardingApi<I, C> {
68    /// Constructs a new multicast forwarding API.
69    pub fn new(ctx: C) -> Self {
70        Self { ctx, _ip_mark: IpVersionMarker::new() }
71    }
72}
73
/// Multicast forwarding configuration operations.
///
/// Operations that need forwarding to be enabled (adding/removing routes and
/// reading route stats) fail with [`MulticastForwardingDisabledError`] while
/// it is disabled.
impl<I: IpLayerIpExt, C> MulticastForwardingApi<I, C>
where
    C: ContextPair,
    C::CoreContext: MulticastForwardingStateContext<I, C::BindingsContext>
        + MulticastForwardingDeviceContext<I>
        + IpLayerForwardingContext<I, C::BindingsContext>
        + CounterContext<MulticastForwardingCounters<I>>
        + CoreTimerContext<MulticastForwardingTimerId<I>, C::BindingsContext>,
    C::BindingsContext:
        IpLayerBindingsContext<I, <C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
{
    /// Returns the core context from the wrapped context pair.
    pub(crate) fn core_ctx(&mut self) -> &mut C::CoreContext {
        let Self { ctx, _ip_mark } = self;
        ctx.core_ctx()
    }

    /// Returns both the core and bindings contexts from the wrapped context
    /// pair.
    pub(crate) fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
        let Self { ctx, _ip_mark } = self;
        ctx.contexts()
    }

    /// Enables multicast forwarding.
    ///
    /// Returns whether multicast forwarding was newly enabled. If forwarding is
    /// already enabled, the existing state is left untouched and `false` is
    /// returned.
    pub fn enable(&mut self) -> bool {
        let (core_ctx, bindings_ctx) = self.contexts();
        core_ctx.with_state_mut(|state, _ctx| match state {
            MulticastForwardingState::Enabled(_) => false,
            MulticastForwardingState::Disabled => {
                // Install freshly constructed enabled state; its constructor is
                // given access to the bindings context.
                *state = MulticastForwardingState::Enabled(MulticastForwardingEnabledState::new::<
                    C::CoreContext,
                >(bindings_ctx));
                true
            }
        })
    }

    /// Disables multicast forwarding.
    ///
    /// Returns whether multicast forwarding was newly disabled.
    ///
    /// Upon being disabled, the multicast route table will be cleared,
    /// and all pending packets will be dropped.
    pub fn disable(&mut self) -> bool {
        self.core_ctx().with_state_mut(|state, _ctx| match state {
            MulticastForwardingState::Disabled => false,
            MulticastForwardingState::Enabled(_) => {
                // Overwriting the state drops the previous enabled state; this
                // is how the route table and pending packets are discarded.
                *state = MulticastForwardingState::Disabled;
                true
            }
        })
    }

    /// Add the route to the multicast route table.
    ///
    /// If a route already exists with the same key, it will be replaced, and
    /// the original route will be returned.
    ///
    /// The route's [`MulticastRouteStats`] are initialized with the current
    /// time, including when an existing route is replaced. When no route
    /// previously existed for `key`, any packets pending for `key` are removed
    /// from the pending table and forwarded according to the new route.
    ///
    /// # Errors
    ///
    /// Returns [`MulticastForwardingDisabledError`] if multicast forwarding is
    /// disabled.
    pub fn add_multicast_route(
        &mut self,
        key: MulticastRouteKey<I>,
        route: MulticastRoute<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
    ) -> Result<
        Option<MulticastRoute<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>>,
        MulticastForwardingDisabledError,
    > {
        let (core_ctx, bindings_ctx) = self.contexts();
        let (orig_route, packet_queue_and_new_route) = core_ctx.with_state_mut(|state, ctx| {
            let state = state.try_enabled()?;
            ctx.with_route_table_mut(state, |route_table, ctx| {
                let stats = MulticastRouteStats { last_used: bindings_ctx.now_atomic() };
                match route_table.entry(key.clone()) {
                    btree_map::Entry::Occupied(mut entry) => {
                        // NB: We consider the stats to be associated with the
                        // `route` rather than the route's key. As such we
                        // replace the stats instead of preserving them.
                        let MulticastRouteEntry { route: orig_route, stats: _ } =
                            entry.insert(MulticastRouteEntry { route, stats });
                        // NB: Check the invariant that any key present in the
                        // route table is not also present in the pending table.
                        #[cfg(debug_assertions)]
                        ctx.with_pending_table_mut(state, |pending_table| {
                            debug_assert!(!pending_table.contains(&key));
                        });
                        Ok((Some(orig_route), None))
                    }
                    btree_map::Entry::Vacant(entry) => {
                        let MulticastRouteEntry { route: new_route_ref, stats: _ } =
                            entry.insert(MulticastRouteEntry { route, stats });
                        // Take ownership of any packets waiting on this key so
                        // they can be forwarded once the state/route table
                        // accessors have returned.
                        let packet_queue_and_new_route = ctx
                            .with_pending_table_mut(state, |pending_table| {
                                pending_table.remove(&key, bindings_ctx)
                            })
                            .map(|packet_queue| (packet_queue, new_route_ref.clone()));
                        Ok((None, packet_queue_and_new_route))
                    }
                }
            })
        })?;

        if let Some((packet_queue, new_route)) = packet_queue_and_new_route {
            // NB: we cloned the route out to a context that's no longer holding
            // the routing table lock. This means the route could have been
            // removed. In general, that's okay. We'll operate on the
            // potentially stale route as if it still exists. This mirrors the
            // lookup pattern used by the unicast/multicast route tables in
            // other parts of the stack.
            handle_pending_packets(core_ctx, bindings_ctx, packet_queue, key, new_route)
        }

        Ok(orig_route)
    }

    /// Remove the route from the multicast route table.
    ///
    /// Returns `None` if the route did not exist.
    ///
    /// # Errors
    ///
    /// Returns [`MulticastForwardingDisabledError`] if multicast forwarding is
    /// disabled.
    pub fn remove_multicast_route(
        &mut self,
        key: &MulticastRouteKey<I>,
    ) -> Result<
        Option<MulticastRoute<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>>,
        MulticastForwardingDisabledError,
    > {
        self.core_ctx().with_state_mut(|state, ctx| {
            let state = state.try_enabled()?;
            ctx.with_route_table_mut(state, |route_table, _ctx| {
                Ok(route_table.remove(key).map(|MulticastRouteEntry { route, stats: _ }| route))
            })
        })
    }

    /// Remove all references to the device from the multicast forwarding state.
    ///
    /// Typically, this is called as part of device removal to purge all strong
    /// device references.
    ///
    /// Any routes that reference the device as an `input_interface` will be
    /// removed. Any routes that reference the device as a
    /// [`MulticastRouteTarget`] will have that target removed (and will
    /// themselves be removed if it's the only target).
    ///
    /// This is a no-op while multicast forwarding is disabled.
    pub fn remove_references_to_device(
        &mut self,
        dev: &<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
    ) {
        self.core_ctx().with_state_mut(|state, ctx| {
            let Some(state) = state.enabled() else {
                // There's no state to update if forwarding is disabled.
                return;
            };
            ctx.with_route_table_mut(state, |route_table, _ctx| {
                // Returning `false` from the closure drops the route;
                // returning `true` keeps it (possibly with fewer targets).
                route_table.retain(
                    |_route_key,
                     MulticastRouteEntry {
                         route: MulticastRoute { action, input_interface },
                         stats: _,
                     }| {
                        if dev == &*input_interface {
                            return false;
                        }
                        match action {
                            Action::Forward(ref mut targets) => {
                                // If all targets reference the device, we should
                                // discard the route entirely.
                                if targets.iter().all(|target| dev == &target.output_interface) {
                                    return false;
                                }
                                // Otherwise, if any target references the device,
                                // we should remove it from the set of targets.
                                if targets.iter().any(|target| dev == &target.output_interface) {
                                    *targets = targets
                                        .iter()
                                        .filter(|target| dev != &target.output_interface)
                                        .cloned()
                                        .collect();
                                }
                            }
                        }
                        true
                    },
                )
            })
        })
    }

    /// Returns the [`MulticastRouteStats`], if any, for the given key.
    ///
    /// Returns `Ok(None)` if no route is installed for `key`.
    ///
    /// # Errors
    ///
    /// Returns [`MulticastForwardingDisabledError`] if multicast forwarding is
    /// disabled.
    pub fn get_route_stats(
        &mut self,
        key: &MulticastRouteKey<I>,
    ) -> Result<
        Option<MulticastRouteStats<<C::BindingsContext as InstantBindingsTypes>::Instant>>,
        MulticastForwardingDisabledError,
    > {
        self.core_ctx().with_state(|state, ctx| {
            let state = state.try_enabled()?;
            ctx.with_route_table(state, |route_table, _ctx| {
                // Snapshot the atomic `last_used` into a plain instant. The
                // load is `Relaxed`: it reads the value on its own, without
                // ordering any other memory accesses.
                Ok(route_table.get(key).map(
                    |MulticastRouteEntry { route: _, stats: MulticastRouteStats { last_used } }| {
                        MulticastRouteStats { last_used: last_used.load(Ordering::Relaxed) }
                    },
                ))
            })
        })
    }

    /// Writes multicast routing table information to the provided `inspector`.
    ///
    /// Always records `ForwardingEnabled`. When forwarding is enabled, also
    /// records a `Routes` child (one unnamed child per route, holding the
    /// route key and entry) and the pending table under `PendingRoutes`.
    pub fn inspect<
        N: Inspector + InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
    >(
        &mut self,
        inspector: &mut N,
    ) {
        self.core_ctx().with_state(|state, ctx| match state {
            MulticastForwardingState::Disabled => {
                inspector.record_bool("ForwardingEnabled", false);
            }
            MulticastForwardingState::Enabled(state) => {
                inspector.record_bool("ForwardingEnabled", true);
                inspector.record_child("Routes", |inspector| {
                    ctx.with_route_table(state, |route_table, _ctx| {
                        for (route_key, route_entry) in route_table.iter() {
                            inspector.record_unnamed_child(|inspector| {
                                inspector.delegate_inspectable(route_key);
                                route_entry.inspect::<_, N>(inspector);
                            })
                        }
                    })
                });
                // NB: All other operations on the pending table require mutable
                // access; don't bother introducing an immutable accessor just
                // for inspect.
                ctx.with_pending_table_mut(state, |pending_table| {
                    inspector.record_inspectable("PendingRoutes", pending_table);
                });
            }
        })
    }
}
310
311/// Attempt to forward the packets from a pending [`PacketQueue`] according to a
312/// newly installed [`MulticastRoute`].
313fn handle_pending_packets<I: IpLayerIpExt, CC, BC>(
314    core_ctx: &mut CC,
315    bindings_ctx: &mut BC,
316    packet_queue: PacketQueue<I, CC::WeakDeviceId, BC>,
317    key: MulticastRouteKey<I>,
318    route: MulticastRoute<CC::DeviceId>,
319) where
320    CC: IpLayerForwardingContext<I, BC>
321        + MulticastForwardingDeviceContext<I>
322        + CounterContext<MulticastForwardingCounters<I>>,
323    BC: IpLayerBindingsContext<I, CC::DeviceId>,
324{
325    let MulticastRoute { input_interface, action } = route;
326
327    // NB: We checked that forwarding was enabled on the device before the
328    // packet was enqueued in the pending table. However, the packet may sit in
329    // the queue for an extended period of time, during which forwarding may
330    // have been disabled on the device. Check again here just in case.
331    if !core_ctx.is_device_multicast_forwarding_enabled(&input_interface) {
332        // The user just installed a multicast route, but also disabled
333        // forwarding on the device. Log a warning because that likely indicates
334        // incorrect API usage.
335        warn!(
336            "Dropping pending packets for newly installed multicast route: {key:?}. \
337            Multicast forwarding is disabled on input interface: {input_interface:?}"
338        );
339        CounterContext::<MulticastForwardingCounters<I>>::counters(core_ctx)
340            .pending_packet_drops_disabled_dev
341            .increment();
342        return;
343    }
344
345    let MulticastRouteKey { src_addr, dst_addr } = key.clone();
346    let dst_ip: SpecifiedAddr<I::Addr> = dst_addr.into();
347    let src_ip: I::RecvSrcAddr = src_addr.into();
348
349    for QueuedPacket { device, packet, frame_dst } in packet_queue.into_iter() {
350        let device = match device.upgrade() {
351            // Short circuit if the device was removed while the packet was
352            // pending.
353            None => continue,
354            Some(d) => d,
355        };
356        // Short circuit if the queued packet arrived on the wrong device.
357        if device != input_interface {
358            CounterContext::<MulticastForwardingCounters<I>>::counters(core_ctx)
359                .pending_packet_drops_wrong_dev
360                .increment();
361            bindings_ctx.on_event(
362                MulticastForwardingEvent::WrongInputInterface {
363                    key: key.clone(),
364                    actual_input_interface: device.clone(),
365                    expected_input_interface: input_interface.clone(),
366                }
367                .into(),
368            );
369            continue;
370        }
371
372        // NB: We could choose to update the `last_used` value on the route's
373        // statistics here, but that's probably overkill. We only end up in this
374        // function as part of route installation, which will have appropriately
375        // initialized `last_used`. It's not worth re-acquiring the route table
376        // lock to update it again here, as the change in time will be
377        // negligible.
378
379        match &action {
380            Action::Forward(targets) => {
381                CounterContext::<MulticastForwardingCounters<I>>::counters(core_ctx)
382                    .pending_packet_tx
383                    .increment();
384                let packet_iter = RepeatN::new(packet, targets.len());
385                for (mut packet, MulticastRouteTarget { output_interface, min_ttl }) in
386                    packet_iter.zip(targets.iter())
387                {
388                    let packet_metadata = Default::default();
389                    crate::internal::base::determine_ip_packet_forwarding_action::<I, _, _>(
390                        core_ctx,
391                        packet.parse_ip_packet_mut(),
392                        packet_metadata,
393                        Some(*min_ttl),
394                        &input_interface,
395                        &output_interface,
396                        IpPacketDestination::from_addr(dst_ip),
397                        frame_dst,
398                        src_ip,
399                        dst_ip,
400                    )
401                    .perform_action_with_buffer(
402                        core_ctx,
403                        bindings_ctx,
404                        packet.into_inner(),
405                    );
406                }
407            }
408        }
409    }
410}
411
/// An iterator that yields a provided item exactly `size` times.
///
/// Notably, this iterator clones the item `size - 1` times and moves the owned
/// value out as the final item, so no clone is spent on the last element. When
/// constructed with a `size` of 0, the item is dropped immediately and nothing
/// is yielded.
///
/// The iterator reports an exact length and keeps returning `None` once
/// exhausted.
// TODO(https://github.com/rust-lang/rust/issues/104434): Replace this with the
// standard library version, once it stabilizes.
struct RepeatN<T> {
    // The owned item, yielded (by move) as the final element. It stays `Some`
    // after `size` reaches 0 until that final element is taken, and is `None`
    // from the start if constructed with a size of 0.
    elem: Option<T>,
    // The number of clones still to be yielded before the owned item.
    size: usize,
}

impl<T> RepeatN<T> {
    fn new(elem: T, size: usize) -> Self {
        if size == 0 {
            Self { elem: None, size }
        } else {
            // The final item is `elem` itself, so only `size - 1` clones are
            // needed.
            Self { elem: Some(elem), size: size - 1 }
        }
    }

    /// Returns the number of items this iterator has yet to yield.
    fn remaining(&self) -> usize {
        // Cannot overflow: `size` is at most `usize::MAX - 1` whenever `elem`
        // is `Some`.
        self.size + usize::from(self.elem.is_some())
    }
}

impl<T: Clone> Iterator for RepeatN<T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let Self { elem, size } = self;
        if *size > 0 {
            *size -= 1;
            Some(elem.as_ref().expect("elem is present while clones remain").clone())
        } else {
            elem.take()
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.remaining();
        (remaining, Some(remaining))
    }
}

impl<T: Clone> ExactSizeIterator for RepeatN<T> {}

// Once `elem` has been taken, `next` always returns `None`.
impl<T: Clone> core::iter::FusedIterator for RepeatN<T> {}
447
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::vec;
    use core::ops::Deref;
    use core::time::Duration;

    use assert_matches::assert_matches;
    use ip_test_macro::ip_test;
    use net_types::MulticastAddr;
    use netstack3_base::testutil::MultipleDevicesId;
    use netstack3_base::{FrameDestination, StrongDeviceIdentifier};
    use packet::ParseBuffer;
    use test_case::test_case;

    use crate::internal::multicast_forwarding;
    use crate::internal::multicast_forwarding::packet_queue::QueuePacketOutcome;
    use crate::internal::multicast_forwarding::testutil::{SentPacket, TestIpExt};
    use crate::multicast_forwarding::{MulticastRoute, MulticastRouteKey, MulticastRouteTarget};
    use crate::IpLayerEvent;

    #[ip_test(I)]
    fn enable_disable<I: IpLayerIpExt>() {
        let mut api = multicast_forwarding::testutil::new_api::<I>();

        assert_matches!(
            api.core_ctx().state.multicast_forwarding.borrow().deref(),
            &MulticastForwardingState::Disabled
        );
        assert!(api.enable());
        assert!(!api.enable());
        assert_matches!(
            api.core_ctx().state.multicast_forwarding.borrow().deref(),
            &MulticastForwardingState::Enabled(_)
        );
        assert!(api.disable());
        assert!(!api.disable());
        assert_matches!(
            api.core_ctx().state.multicast_forwarding.borrow().deref(),
            &MulticastForwardingState::Disabled
        );
    }

    #[ip_test(I)]
    fn add_remove_route<I: TestIpExt>() {
        let key1 = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
        let key2 = MulticastRouteKey::new(I::SRC2, I::DST2).unwrap();
        let forward_to_b = MulticastRoute::new_forward(
            MultipleDevicesId::A,
            [MulticastRouteTarget { output_interface: MultipleDevicesId::B, min_ttl: 0 }].into(),
        )
        .unwrap();
        let forward_to_c = MulticastRoute::new_forward(
            MultipleDevicesId::A,
            [MulticastRouteTarget { output_interface: MultipleDevicesId::C, min_ttl: 0 }].into(),
        )
        .unwrap();

        let mut api = multicast_forwarding::testutil::new_api::<I>();

        // Adding/removing routes before multicast forwarding is enabled should
        // fail.
        assert_eq!(
            api.add_multicast_route(key1.clone(), forward_to_b.clone()),
            Err(MulticastForwardingDisabledError {})
        );
        assert_eq!(api.remove_multicast_route(&key1), Err(MulticastForwardingDisabledError {}));

        // Enable the API and observe success.
        assert!(api.enable());
        assert_eq!(api.add_multicast_route(key1.clone(), forward_to_b.clone()), Ok(None));
        assert_eq!(api.remove_multicast_route(&key1), Ok(Some(forward_to_b.clone())));

        // Removing a route that doesn't exist should return `None`.
        assert_eq!(api.remove_multicast_route(&key1), Ok(None));

        // Adding a route with the same key as an existing route should
        // overwrite the original.
        assert_eq!(api.add_multicast_route(key1.clone(), forward_to_b.clone()), Ok(None));
        assert_eq!(
            api.add_multicast_route(key1.clone(), forward_to_c.clone()),
            Ok(Some(forward_to_b.clone()))
        );
        assert_eq!(api.remove_multicast_route(&key1), Ok(Some(forward_to_c.clone())));

        // Routes with different keys can co-exist.
        assert_eq!(api.add_multicast_route(key1.clone(), forward_to_b.clone()), Ok(None));
        assert_eq!(api.add_multicast_route(key2.clone(), forward_to_c.clone()), Ok(None));
        assert_eq!(api.remove_multicast_route(&key1), Ok(Some(forward_to_b)));
        assert_eq!(api.remove_multicast_route(&key2), Ok(Some(forward_to_c)));
    }

    #[ip_test(I)]
    #[test_case(false, true; "forwarding_disabled")]
    #[test_case(true, false; "forwarding_enabled_and_wrong_dev")]
    #[test_case(true, true; "forwarding_enabled_and_right_dev")]
    fn add_route_with_pending_packets<I: TestIpExt>(
        forwarding_enabled_for_dev: bool,
        right_dev: bool,
    ) {
        const FRAME_DST: Option<FrameDestination> = None;
        const OUTPUT_DEV: MultipleDevicesId = MultipleDevicesId::C;
        let right_key = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
        let wrong_key = MulticastRouteKey::new(I::SRC2, I::DST2).unwrap();
        let expected_dev = MultipleDevicesId::A;
        let actual_dev = if right_dev { expected_dev } else { MultipleDevicesId::B };

        let route = MulticastRoute::new_forward(
            expected_dev,
            [MulticastRouteTarget { output_interface: OUTPUT_DEV, min_ttl: 0 }].into(),
        )
        .unwrap();

        let mut api = multicast_forwarding::testutil::new_api::<I>();
        assert!(api.enable());
        api.core_ctx()
            .state
            .set_multicast_forwarding_enabled_for_dev(expected_dev, forwarding_enabled_for_dev);

        // Setup a queued packet for `right_key`.
        let (core_ctx, bindings_ctx) = api.contexts();
        multicast_forwarding::testutil::with_pending_table(core_ctx, |pending_table| {
            let buf = multicast_forwarding::testutil::new_ip_packet_buf::<I>(I::SRC1, I::DST1);
            let mut buf_ref = buf.as_ref();
            let packet = buf_ref.parse::<I::Packet<_>>().expect("parse should succeed");
            assert_eq!(
                pending_table.try_queue_packet(
                    bindings_ctx,
                    right_key.clone(),
                    &packet,
                    &actual_dev,
                    FRAME_DST
                ),
                QueuePacketOutcome::QueuedInNewQueue,
            );
        });

        // Add a route with the wrong key and expect that the packet queue is
        // unaffected.
        assert_eq!(api.add_multicast_route(wrong_key, route.clone()), Ok(None));
        assert!(multicast_forwarding::testutil::with_pending_table(
            api.core_ctx(),
            |pending_table| pending_table.contains(&right_key)
        ));

        // Add a route with the right key and expect that the packet queue is
        // removed.
        assert_eq!(api.add_multicast_route(right_key.clone(), route), Ok(None));
        assert!(multicast_forwarding::testutil::with_pending_table(
            api.core_ctx(),
            |pending_table| !pending_table.contains(&right_key)
        ));

        let expect_sent_packet = forwarding_enabled_for_dev && right_dev;
        let mut expected_sent_packets = vec![];
        if expect_sent_packet {
            expected_sent_packets.push(SentPacket {
                dst: MulticastAddr::new(right_key.dst_addr()).unwrap(),
                device: OUTPUT_DEV,
            });
        }
        assert_eq!(api.core_ctx().state.take_sent_packets(), expected_sent_packets);

        // Verify that multicast routing events are generated.
        let mut expected_events = vec![];
        if !right_dev {
            expected_events.push(IpLayerEvent::MulticastForwarding(
                MulticastForwardingEvent::WrongInputInterface {
                    key: right_key,
                    actual_input_interface: actual_dev,
                    expected_input_interface: expected_dev,
                },
            ));
        }

        let (_core_ctx, bindings_ctx) = api.contexts();
        assert_eq!(bindings_ctx.take_events(), expected_events);

        // Verify that counters are updated.
        let counters: &MulticastForwardingCounters<I> = api.core_ctx().counters();
        assert_eq!(counters.pending_packet_tx.get(), if expect_sent_packet { 1 } else { 0 });
        assert_eq!(
            counters.pending_packet_drops_disabled_dev.get(),
            if forwarding_enabled_for_dev { 0 } else { 1 }
        );
        assert_eq!(counters.pending_packet_drops_wrong_dev.get(), if right_dev { 0 } else { 1 });
    }

    #[ip_test(I)]
    fn remove_references_to_device<I: TestIpExt>() {
        // NB: 4 arbitrary keys, that are unique from each other.
        let key1 = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();
        let key2 = MulticastRouteKey::new(I::SRC2, I::DST1).unwrap();
        let key3 = MulticastRouteKey::new(I::SRC1, I::DST2).unwrap();
        let key4 = MulticastRouteKey::new(I::SRC2, I::DST2).unwrap();

        // Create 4 routes, each exercising a different edge case.
        const GOOD_DEV1: MultipleDevicesId = MultipleDevicesId::A;
        const GOOD_DEV2: MultipleDevicesId = MultipleDevicesId::B;
        const BAD_DEV: MultipleDevicesId = MultipleDevicesId::C;
        const GOOD_TARGET1: MulticastRouteTarget<MultipleDevicesId> =
            MulticastRouteTarget { output_interface: GOOD_DEV1, min_ttl: 0 };
        const GOOD_TARGET2: MulticastRouteTarget<MultipleDevicesId> =
            MulticastRouteTarget { output_interface: GOOD_DEV2, min_ttl: 0 };
        const BAD_TARGET: MulticastRouteTarget<MultipleDevicesId> =
            MulticastRouteTarget { output_interface: BAD_DEV, min_ttl: 0 };
        let dev_is_input = MulticastRoute::new_forward(BAD_DEV, [GOOD_TARGET1].into()).unwrap();
        let dev_is_only_output =
            MulticastRoute::new_forward(GOOD_DEV1, [BAD_TARGET].into()).unwrap();
        let dev_is_one_output =
            MulticastRoute::new_forward(GOOD_DEV1, [GOOD_TARGET2, BAD_TARGET].into()).unwrap();
        let no_ref_to_dev = MulticastRoute::new_forward(GOOD_DEV1, [GOOD_TARGET2].into()).unwrap();

        // Verify that removing device references is a no-op when multicast
        // forwarding is disabled.
        let mut api = multicast_forwarding::testutil::new_api::<I>();
        api.remove_references_to_device(&BAD_DEV.downgrade());
        assert!(api.enable());

        // Add the four routes, remove references to `Dev`, and verify that:
        // * `dev_is_input` & `dev_is_only_output`, were both removed.
        // * `dev_is_one_output` was updated to not list the dev in its
        //    targets.
        // * `no_ref_to_dev` was not updated.
        assert_eq!(api.add_multicast_route(key1.clone(), dev_is_input), Ok(None));
        assert_eq!(api.add_multicast_route(key2.clone(), dev_is_only_output), Ok(None));
        assert_eq!(api.add_multicast_route(key3.clone(), dev_is_one_output), Ok(None));
        assert_eq!(api.add_multicast_route(key4.clone(), no_ref_to_dev.clone()), Ok(None));
        api.remove_references_to_device(&BAD_DEV.downgrade());
        assert_eq!(api.remove_multicast_route(&key1), Ok(None));
        assert_eq!(api.remove_multicast_route(&key2), Ok(None));
        // NB: Equal to `dev_is_one_output`, but with `BAD_TARGET` removed.
        assert_eq!(
            api.remove_multicast_route(&key3),
            Ok(Some(MulticastRoute::new_forward(GOOD_DEV1, [GOOD_TARGET2].into()).unwrap()))
        );
        assert_eq!(api.remove_multicast_route(&key4), Ok(Some(no_ref_to_dev)));
    }

    #[ip_test(I)]
    fn get_route_stats<I: TestIpExt>() {
        let key = MulticastRouteKey::new(I::SRC1, I::DST1).unwrap();

        let mut api = multicast_forwarding::testutil::new_api::<I>();

        // Verify that get_route_stats fails when forwarding is disabled.
        assert_eq!(api.get_route_stats(&key), Err(MulticastForwardingDisabledError {}));

        // Verify that get_route_stats returns `None` if the route doesn't exist.
        assert!(api.enable());
        assert_eq!(api.get_route_stats(&key), Ok(None));

        // Install a route and verify that get_route_stats succeeds.
        let route = MulticastRoute::new_forward(
            MultipleDevicesId::A,
            [MulticastRouteTarget { output_interface: MultipleDevicesId::B, min_ttl: 0 }].into(),
        )
        .unwrap();
        assert_eq!(api.add_multicast_route(key.clone(), route.clone()), Ok(None));
        let original_time = api.ctx.bindings_ctx().now();
        let expected_stats = MulticastRouteStats { last_used: original_time };
        assert_eq!(api.get_route_stats(&key), Ok(Some(expected_stats)));

        // Advance the timer and overwrite the route to prove we initialize
        // stats with an up-to-date instant.
        api.ctx.bindings_ctx().timers.instant.sleep(Duration::from_secs(5));
        let new_time = api.ctx.bindings_ctx().now();
        assert!(new_time > original_time);
        let expected_stats = MulticastRouteStats { last_used: new_time };
        assert_eq!(api.add_multicast_route(key.clone(), route.clone()), Ok(Some(route)));
        assert_eq!(api.get_route_stats(&key), Ok(Some(expected_stats)));
    }

    #[test_case(0)]
    #[test_case(1)]
    #[test_case(10)]
    fn repeat_n(size: usize) {
        use core::cell::Cell;

        // Records every clone in a shared counter, so we can verify that
        // `RepeatN` clones the item `size - 1` times and moves the original
        // out as the final item.
        struct Foo<'a>(&'a Cell<usize>);
        impl Clone for Foo<'_> {
            fn clone(&self) -> Self {
                self.0.set(self.0.get() + 1);
                Foo(self.0)
            }
        }

        let clones = Cell::new(0);
        assert_eq!(RepeatN::new(Foo(&clones), size).count(), size);
        assert_eq!(clones.get(), size.saturating_sub(1));
    }
}