// Test-only assertion that the GMPv1 state machine for `$group` in `$ctx` is
// currently in the named member state (`NonMember` or `Delaying`).
//
// Panics if the group is absent, if the group is not in v1 mode (via `v1()`),
// if the state machine's `inner` is `None`, or if the state does not match.
#[cfg(test)]
macro_rules! assert_gmp_state {
    ($ctx:expr, $group:expr, NonMember) => {
        assert_gmp_state!(@inner $ctx, $group, crate::internal::gmp::v1::MemberState::NonMember(_));
    };
    ($ctx:expr, $group:expr, Delaying) => {
        assert_gmp_state!(@inner $ctx, $group, crate::internal::gmp::v1::MemberState::Delaying(_));
    };
    // Shared implementation: looks up the group and matches its v1 state
    // against the given pattern.
    (@inner $ctx:expr, $group:expr, $pattern:pat) => {
        assert!(matches!($ctx.state.groups().get($group).unwrap().v1().inner.as_ref().unwrap(), $pattern))
    };
}
pub(crate) mod igmp;
pub(crate) mod mld;
#[cfg(test)]
mod testutil;
mod v1;
mod v2;
use core::fmt::Debug;
use core::num::NonZeroU64;
use core::time::Duration;
use assert_matches::assert_matches;
use log::info;
use net_types::ip::{Ip, IpAddress, IpVersionMarker};
use net_types::MulticastAddr;
use netstack3_base::ref_counted_hash_map::{InsertResult, RefCountedHashMap, RemoveResult};
use netstack3_base::{
AnyDevice, CoreTimerContext, DeviceIdContext, InspectableValue, Inspector,
InstantBindingsTypes, LocalTimerHeap, RngContext, TimerBindingsTypes, TimerContext,
WeakDeviceIdentifier,
};
use rand::Rng;
/// Outcome of a request to join a multicast group.
///
/// `O` is the output produced when the join created a new membership entry.
#[cfg_attr(test, derive(Debug, Eq, PartialEq))]
pub enum GroupJoinResult<O = ()> {
    /// The group was newly joined; carries the join's output.
    Joined(O),
    /// The group was already joined before this request.
    AlreadyMember,
}
impl<O> GroupJoinResult<O> {
    /// Applies `f` to the `Joined` payload; `AlreadyMember` passes through
    /// and `f` is never invoked for it.
    pub(crate) fn map<P, F: FnOnce(O) -> P>(self, f: F) -> GroupJoinResult<P> {
        match self {
            Self::Joined(joined) => GroupJoinResult::Joined(f(joined)),
            Self::AlreadyMember => GroupJoinResult::AlreadyMember,
        }
    }
}
impl<O> From<InsertResult<O>> for GroupJoinResult<O> {
    /// Maps a ref-counted map insertion result onto a join result:
    /// `Inserted` becomes `Joined` and `AlreadyPresent` becomes `AlreadyMember`.
    fn from(insert: InsertResult<O>) -> Self {
        match insert {
            InsertResult::Inserted(out) => Self::Joined(out),
            InsertResult::AlreadyPresent => Self::AlreadyMember,
        }
    }
}
/// Outcome of a request to leave a multicast group.
///
/// `T` is the value handed back when the group's entry was removed.
#[cfg_attr(test, derive(Debug, Eq, PartialEq))]
pub enum GroupLeaveResult<T = ()> {
    /// The group's entry was removed; carries the removed value.
    Left(T),
    /// The group's entry is still present after this request.
    StillMember,
    /// The group was not joined in the first place.
    NotMember,
}
impl<T> GroupLeaveResult<T> {
    /// Applies `f` to the `Left` payload; the other variants pass through and
    /// `f` is never invoked for them.
    pub(crate) fn map<U, F: FnOnce(T) -> U>(self, f: F) -> GroupLeaveResult<U> {
        match self {
            Self::Left(removed) => GroupLeaveResult::Left(f(removed)),
            Self::StillMember => GroupLeaveResult::StillMember,
            Self::NotMember => GroupLeaveResult::NotMember,
        }
    }
}
impl<T> From<RemoveResult<T>> for GroupLeaveResult<T> {
    /// Maps a ref-counted map removal result onto a leave result:
    /// `Removed` becomes `Left`, `StillPresent` becomes `StillMember` and
    /// `NotPresent` becomes `NotMember`.
    fn from(removal: RemoveResult<T>) -> Self {
        match removal {
            RemoveResult::Removed(removed) => Self::Left(removed),
            RemoveResult::StillPresent => Self::StillMember,
            RemoveResult::NotPresent => Self::NotMember,
        }
    }
}
/// Set of joined multicast groups, backed by a reference-counted map from
/// group address to per-group state `T`.
#[cfg_attr(test, derive(Debug))]
pub struct MulticastGroupSet<A: IpAddress, T> {
    inner: RefCountedHashMap<MulticastAddr<A>, T>,
}
impl<A: IpAddress, T> Default for MulticastGroupSet<A, T> {
    /// Creates an empty group set.
    fn default() -> MulticastGroupSet<A, T> {
        Self { inner: Default::default() }
    }
}
impl<A: IpAddress, T> MulticastGroupSet<A, T> {
    /// Iterates mutably over every group and its state (same as `iter_mut`).
    fn groups_mut(&mut self) -> impl Iterator<Item = (&MulticastAddr<A>, &mut T)> + '_ {
        self.iter_mut()
    }
    /// Joins `group`, calling `f` to build the state and output only when
    /// the map reports a fresh insertion.
    fn join_group_with<O, F: FnOnce() -> (T, O)>(
        &mut self,
        group: MulticastAddr<A>,
        f: F,
    ) -> GroupJoinResult<O> {
        let insertion = self.inner.insert_with(group, f);
        insertion.into()
    }
    /// Drops one reference to `group`, reporting whether it was removed.
    fn leave_group(&mut self, group: MulticastAddr<A>) -> GroupLeaveResult<T> {
        let removal = self.inner.remove(group);
        removal.into()
    }
    /// Returns `true` if `group` is currently joined.
    pub(crate) fn contains(&self, group: &MulticastAddr<A>) -> bool {
        self.inner.contains_key(group)
    }
    /// Returns the state for `group`, if joined.
    #[cfg(test)]
    fn get(&self, group: &MulticastAddr<A>) -> Option<&T> {
        self.inner.get(group)
    }
    /// Returns the state for `group` mutably, if joined.
    fn get_mut(&mut self, group: &MulticastAddr<A>) -> Option<&mut T> {
        self.inner.get_mut(group)
    }
    /// Iterates mutably over every group and its state.
    fn iter_mut<'a>(&'a mut self) -> impl 'a + Iterator<Item = (&'a MulticastAddr<A>, &'a mut T)> {
        self.inner.iter_mut()
    }
    /// Iterates over every group and its state.
    fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a MulticastAddr<A>, &'a T)> + Clone {
        self.inner.iter()
    }
    /// Returns `true` if no group is joined.
    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
impl<A: IpAddress, T> InspectableValue for MulticastGroupSet<A, T> {
    /// Records a child node named `name` holding one child per group address,
    /// each carrying the group's reference count under `"Refs"`.
    fn record<I: Inspector>(&self, name: &str, inspector: &mut I) {
        inspector.record_child(name, |groups_node| {
            for (group, refs) in self.inner.iter_ref_counts() {
                groups_node
                    .record_display_child(group, |group_node| group_node.record_usize("Refs", refs.get()));
            }
        });
    }
}
/// Read-only queries against a device's GMP group membership.
pub trait GmpQueryHandler<I: Ip, BC>: DeviceIdContext<AnyDevice> {
    /// Returns `true` if `device` has joined `group_addr`.
    fn gmp_is_in_group(
        &mut self,
        device: &Self::DeviceId,
        group_addr: MulticastAddr<I::Addr>,
    ) -> bool;
}
/// Operations that drive a device's GMP (group management protocol) state.
pub trait GmpHandler<I: IpExt, BC>: DeviceIdContext<AnyDevice> {
    /// Notifies GMP that the device may have become enabled.
    ///
    /// The implementation in this module does nothing unless the GMP state
    /// reports itself enabled, and runs the enable logic only once per
    /// disabled-to-enabled transition.
    fn gmp_handle_maybe_enabled(&mut self, bindings_ctx: &mut BC, device: &Self::DeviceId);
    /// Notifies GMP that the device has been disabled.
    ///
    /// The implementation in this module panics if the GMP state still
    /// reports itself enabled; repeated calls are no-ops.
    fn gmp_handle_disabled(&mut self, bindings_ctx: &mut BC, device: &Self::DeviceId);
    /// Joins `group_addr` on `device`.
    fn gmp_join_group(
        &mut self,
        bindings_ctx: &mut BC,
        device: &Self::DeviceId,
        group_addr: MulticastAddr<I::Addr>,
    ) -> GroupJoinResult;
    /// Leaves `group_addr` on `device`.
    fn gmp_leave_group(
        &mut self,
        bindings_ctx: &mut BC,
        device: &Self::DeviceId,
        group_addr: MulticastAddr<I::Addr>,
    ) -> GroupLeaveResult;
    /// Returns the current protocol mode in its configuration representation.
    fn gmp_get_mode(&mut self, device: &Self::DeviceId) -> I::GmpProtoConfigMode;
    /// Switches the protocol mode to `new_mode`, returning the previous mode.
    fn gmp_set_mode(
        &mut self,
        bindings_ctx: &mut BC,
        device: &Self::DeviceId,
        new_mode: I::GmpProtoConfigMode,
    ) -> I::GmpProtoConfigMode;
}
/// Any context that can read GMP state can answer membership queries.
impl<I: IpExt, BT: GmpBindingsTypes, CC: GmpStateContext<I, BT>> GmpQueryHandler<I, BT> for CC {
    /// Checks membership by looking `group_addr` up in the device's group set.
    fn gmp_is_in_group(
        &mut self,
        device: &Self::DeviceId,
        group_addr: MulticastAddr<I::Addr>,
    ) -> bool {
        self.with_multicast_groups(device, |joined| joined.contains(&group_addr))
    }
}
/// Blanket `GmpHandler` implementation for every `GmpContext`.
///
/// Each operation obtains the device's GMP state through
/// `with_gmp_state_mut_and_ctx` / `with_gmp_state_mut` and dispatches to the
/// `v1` or `v2` module according to the current `GmpMode`.
impl<I: IpExt, BC: GmpBindingsContext, CC: GmpContext<I, BC>> GmpHandler<I, BC> for CC {
    fn gmp_handle_maybe_enabled(&mut self, bindings_ctx: &mut BC, device: &Self::DeviceId) {
        self.with_gmp_state_mut_and_ctx(device, |mut core_ctx, state| {
            // Nothing to do unless the state reports GMP as enabled.
            if !state.enabled {
                return;
            }
            // Record the enabled transition. If it was already recorded, this
            // call is a repeat and must not re-run the enable logic.
            match core::mem::replace(
                &mut state.gmp.enablement_idempotency_guard,
                LastState::Enabled,
            ) {
                LastState::Disabled => {}
                LastState::Enabled => {
                    return;
                }
            }
            match state.gmp.gmp_mode() {
                GmpMode::V1 { compat: _ } => {
                    v1::handle_enabled(&mut core_ctx, bindings_ctx, device, state);
                }
                GmpMode::V2 => {
                    v2::handle_enabled(bindings_ctx, state);
                }
            }
        })
    }
    fn gmp_handle_disabled(&mut self, bindings_ctx: &mut BC, device: &Self::DeviceId) {
        self.with_gmp_state_mut_and_ctx(device, |mut core_ctx, mut state| {
            assert!(!state.enabled, "handle_disabled called with enabled GMP state");
            // Record the disabled transition; a repeated disable is a no-op.
            match core::mem::replace(
                &mut state.gmp.enablement_idempotency_guard,
                LastState::Disabled,
            ) {
                LastState::Enabled => {}
                LastState::Disabled => {
                    return;
                }
            }
            // Let the active protocol version run its disable handling first.
            match state.gmp.gmp_mode() {
                GmpMode::V1 { .. } => {
                    v1::handle_disabled(&mut core_ctx, bindings_ctx, device, state.as_mut());
                }
                GmpMode::V2 => {
                    v2::handle_disabled(&mut core_ctx, bindings_ctx, device, state.as_mut());
                }
            }
            // Switch to the mode the protocol chooses for the disabled state.
            let next_mode =
                <CC::Inner<'_> as GmpContextInner<I, BC>>::mode_on_disable(&state.gmp.mode);
            enter_mode(bindings_ctx, state.as_mut(), next_mode);
            // Regardless of mode, reset all v2 protocol state and drop every
            // pending timer.
            state.gmp.v2_proto = Default::default();
            state.gmp.timers.clear(bindings_ctx);
        })
    }
    fn gmp_join_group(
        &mut self,
        bindings_ctx: &mut BC,
        device: &CC::DeviceId,
        group_addr: MulticastAddr<I::Addr>,
    ) -> GroupJoinResult {
        // The v1 path also needs the inner context (and device).
        self.with_gmp_state_mut_and_ctx(device, |mut core_ctx, state| match state.gmp.gmp_mode() {
            GmpMode::V1 { compat: _ } => {
                v1::join_group(&mut core_ctx, bindings_ctx, device, group_addr, state)
            }
            GmpMode::V2 => v2::join_group(bindings_ctx, group_addr, state),
        })
    }
    fn gmp_leave_group(
        &mut self,
        bindings_ctx: &mut BC,
        device: &CC::DeviceId,
        group_addr: MulticastAddr<I::Addr>,
    ) -> GroupLeaveResult {
        // The v1 path also needs the inner context (and device).
        self.with_gmp_state_mut_and_ctx(device, |mut core_ctx, state| match state.gmp.gmp_mode() {
            GmpMode::V1 { compat: _ } => {
                v1::leave_group(&mut core_ctx, bindings_ctx, device, group_addr, state)
            }
            GmpMode::V2 => v2::leave_group(bindings_ctx, group_addr, state),
        })
    }
    fn gmp_get_mode(&mut self, device: &CC::DeviceId) -> I::GmpProtoConfigMode {
        // Only reads the mode; the mutable accessor is used because it is the
        // accessor `GmpContext` provides.
        self.with_gmp_state_mut(device, |state| {
            <CC::Inner<'_> as GmpContextInner<I, BC>>::mode_to_config(&state.gmp.mode)
        })
    }
    fn gmp_set_mode(
        &mut self,
        bindings_ctx: &mut BC,
        device: &CC::DeviceId,
        new_mode: I::GmpProtoConfigMode,
    ) -> I::GmpProtoConfigMode {
        self.with_gmp_state_mut(device, |state| {
            let old_mode =
                <CC::Inner<'_> as GmpContextInner<I, BC>>::mode_to_config(&state.gmp.mode);
            info!("GMP({}) mode change by user from {:?} to {:?}", I::NAME, old_mode, new_mode);
            // Translate the requested config into a protocol mode; the current
            // mode is passed in so the protocol can take it into account.
            let new_mode = <CC::Inner<'_> as GmpContextInner<I, BC>>::config_to_mode(
                &state.gmp.mode,
                new_mode,
            );
            enter_mode(bindings_ctx, state, new_mode);
            old_mode
        })
    }
}
/// Picks a uniformly random report delay in `[1, P]` microseconds, where `P`
/// is `period` truncated to whole microseconds (saturating at `u64::MAX`).
///
/// When `P` is zero, a fixed 1µs delay is returned so that `gen_range` is
/// never given an empty range.
///
/// NOTE(review): for periods under 1µs the returned delay exceeds `period`;
/// confirm callers are fine with that.
fn random_report_timeout<R: Rng>(rng: &mut R, period: Duration) -> Duration {
    let max_micros = u64::try_from(period.as_micros()).unwrap_or(u64::MAX);
    let delay_micros = match NonZeroU64::new(max_micros) {
        Some(upper) => rng.gen_range(1..=upper.get()),
        None => 1,
    };
    Duration::from_micros(delay_micros)
}
/// Identifier of the GMP timer for one (weakly referenced) device, tagged
/// with IP version `I`.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
pub struct GmpTimerId<I: Ip, D: WeakDeviceIdentifier> {
    pub(crate) device: D,
    pub(crate) _marker: IpVersionMarker<I>,
}
impl<I: Ip, D: WeakDeviceIdentifier> GmpTimerId<I, D> {
    /// Returns the weak device this timer belongs to.
    fn device_id(&self) -> &D {
        &self.device
    }
    /// Creates the timer ID for `device`.
    const fn new(device: D) -> Self {
        let _marker = IpVersionMarker::new();
        Self { device, _marker }
    }
}
/// Bindings types GMP depends on: instants and timers.
pub trait GmpBindingsTypes: InstantBindingsTypes + TimerBindingsTypes {}
// Blanket impl: any type providing the required bindings types qualifies.
impl<BT> GmpBindingsTypes for BT where BT: InstantBindingsTypes + TimerBindingsTypes {}
/// Bindings context GMP depends on: randomness plus timer scheduling.
pub trait GmpBindingsContext: RngContext + TimerContext + GmpBindingsTypes {}
// Blanket impl: any context providing the required capabilities qualifies.
impl<BC> GmpBindingsContext for BC where BC: RngContext + TimerContext + GmpBindingsTypes {}
/// IP-version-specific hooks used by the generic GMP implementation.
pub trait IpExt: Ip {
    /// Configuration representation of the protocol mode, as exchanged with
    /// users through `gmp_get_mode` / `gmp_set_mode`.
    type GmpProtoConfigMode: Debug + Copy + Clone + Eq + PartialEq;
    /// Returns whether GMP should be performed for `addr`; gates construction
    /// of `GmpEnabledGroup`.
    fn should_perform_gmp(addr: MulticastAddr<Self::Addr>) -> bool;
}
/// Keys of the timers stored in a device's GMP timer heap.
#[derive(Debug, Eq, PartialEq, Hash, Clone)]
enum TimerIdInner<I: Ip> {
    /// A v1 delayed-report timer.
    V1(v1::DelayedReportTimerId<I>),
    /// The timer that ends v1 compatibility mode.
    V1Compat,
    /// A v2 protocol timer.
    V2(v2::TimerId<I>),
}
impl<I: Ip> From<v1::DelayedReportTimerId<I>> for TimerIdInner<I> {
    fn from(report_timer: v1::DelayedReportTimerId<I>) -> Self {
        TimerIdInner::V1(report_timer)
    }
}
impl<I: Ip> From<v2::TimerId<I>> for TimerIdInner<I> {
    fn from(v2_timer: v2::TimerId<I>) -> Self {
        TimerIdInner::V2(v2_timer)
    }
}
/// Per-device GMP protocol state.
#[cfg_attr(test, derive(Debug))]
pub struct GmpState<I: Ip, CC: GmpTypeLayout<I, BT>, BT: GmpBindingsTypes> {
    /// Pending GMP timers for this device, keyed by `TimerIdInner`.
    timers: LocalTimerHeap<TimerIdInner<I>, (), BT>,
    /// Current protocol mode.
    mode: CC::ProtoMode,
    /// v2 protocol state; reset to its default when GMP is disabled.
    v2_proto: v2::ProtocolState<I>,
    /// Last enable/disable transition that was handled, so repeated
    /// notifications of the same transition are ignored.
    enablement_idempotency_guard: LastState,
}
/// The last enablement transition GMP processed. Used to make enable and
/// disable handling idempotent.
#[cfg_attr(test, derive(Debug))]
enum LastState {
    Disabled,
    Enabled,
}
impl LastState {
    /// Returns the variant corresponding to the boolean `enabled` flag.
    fn from_enabled(enabled: bool) -> Self {
        match enabled {
            true => Self::Enabled,
            false => Self::Disabled,
        }
    }
}
impl<I: Ip, T: GmpTypeLayout<I, BC>, BC: GmpBindingsTypes + TimerContext> GmpState<I, T, BC> {
    /// Creates GMP state for `device` that starts disabled, in the default
    /// protocol mode.
    pub fn new<D: WeakDeviceIdentifier, CC: CoreTimerContext<GmpTimerId<I, D>, BC>>(
        bindings_ctx: &mut BC,
        device: D,
    ) -> Self {
        let starts_enabled = false;
        let initial_mode: T::ProtoMode = Default::default();
        Self::new_with_enabled_and_mode::<D, CC>(bindings_ctx, device, starts_enabled, initial_mode)
    }
    /// Creates GMP state with an explicit initial enablement and mode.
    ///
    /// The timer heap is created with a `GmpTimerId` for `device`; the v2
    /// protocol state starts at its default.
    fn new_with_enabled_and_mode<
        D: WeakDeviceIdentifier,
        CC: CoreTimerContext<GmpTimerId<I, D>, BC>,
    >(
        bindings_ctx: &mut BC,
        device: D,
        enabled: bool,
        mode: T::ProtoMode,
    ) -> Self {
        let timer_id = GmpTimerId::new(device);
        let timers = LocalTimerHeap::new_with_context::<_, CC>(bindings_ctx, timer_id);
        let enablement_idempotency_guard = LastState::from_enabled(enabled);
        Self { timers, mode, v2_proto: Default::default(), enablement_idempotency_guard }
    }
}
impl<I: IpExt, T: GmpTypeLayout<I, BT>, BT: GmpBindingsTypes> GmpState<I, T, BT> {
    /// Returns the version-agnostic view of the current protocol mode.
    fn gmp_mode(&self) -> GmpMode {
        Into::<GmpMode>::into(self.mode)
    }
    /// Returns the protocol-specific mode.
    pub(crate) fn mode(&self) -> &T::ProtoMode {
        let Self { mode, .. } = self;
        mode
    }
}
/// Mutable view of a device's GMP state, together with its enablement flag,
/// joined groups and protocol configuration.
pub struct GmpStateRef<'a, I: IpExt, CC: GmpTypeLayout<I, BT>, BT: GmpBindingsTypes> {
    /// Whether GMP is enabled.
    pub enabled: bool,
    /// The device's multicast group memberships.
    pub groups: &'a mut MulticastGroupSet<I::Addr, GmpGroupState<I, BT>>,
    /// Protocol state (mode, timers, v2 state).
    pub gmp: &'a mut GmpState<I, CC, BT>,
    /// Protocol configuration.
    pub config: &'a CC::Config,
}
impl<'a, I: IpExt, CC: GmpTypeLayout<I, BT>, BT: GmpBindingsTypes> GmpStateRef<'a, I, CC, BT> {
    /// Reborrows this view for a shorter lifetime, so it can be passed to a
    /// callee while `self` stays usable afterwards.
    fn as_mut(&mut self) -> GmpStateRef<'_, I, CC, BT> {
        GmpStateRef {
            enabled: self.enabled,
            groups: &mut *self.groups,
            gmp: &mut *self.gmp,
            config: self.config,
        }
    }
}
/// Protocol-specific types plugged into the generic GMP implementation.
pub trait GmpTypeLayout<I: Ip, BT: GmpBindingsTypes>: Sized {
    /// Protocol configuration; provides both the v1 and v2 configuration
    /// interfaces.
    type Config: Debug + v1::ProtocolConfig + v2::ProtocolConfig;
    /// Protocol-specific mode. It converts into the version-agnostic
    /// `GmpMode`, and its `Default` is the mode new `GmpState` starts in.
    type ProtoMode: Debug
        + Copy
        + Clone
        + Eq
        + PartialEq
        + Into<GmpMode>
        + Default
        + InspectableValue;
}
/// Per-group GMP state, holding the state for whichever protocol version is
/// active.
pub struct GmpGroupState<I: Ip, BT: GmpBindingsTypes> {
    version_specific: GmpGroupStateByVersion<I, BT>,
}
impl<I: Ip, BT: GmpBindingsTypes> GmpGroupState<I, BT> {
    /// Returns the v1 state machine mutably.
    ///
    /// # Panics
    ///
    /// Panics if the group holds v2 state.
    fn v1_mut(&mut self) -> &mut v1::GmpStateMachine<BT::Instant> {
        match &mut self.version_specific {
            GmpGroupStateByVersion::V1(v1) => v1,
            GmpGroupStateByVersion::V2(_) => panic!("expected GMP v1"),
        }
    }
    /// Returns the v2 group state mutably.
    ///
    /// # Panics
    ///
    /// Panics if the group holds v1 state.
    fn v2_mut(&mut self) -> &mut v2::GroupState<I> {
        match &mut self.version_specific {
            GmpGroupStateByVersion::V2(v2) => v2,
            GmpGroupStateByVersion::V1(_) => panic!("expected GMP v2"),
        }
    }
    /// Returns the v1 state machine (test only).
    ///
    /// # Panics
    ///
    /// Panics if the group holds v2 state.
    #[cfg(test)]
    fn v1(&self) -> &v1::GmpStateMachine<BT::Instant> {
        match &self.version_specific {
            GmpGroupStateByVersion::V1(v1) => v1,
            GmpGroupStateByVersion::V2(_) => panic!("group not in v1 mode"),
        }
    }
    /// Returns the v2 group state.
    ///
    /// # Panics
    ///
    /// Panics if the group holds v1 state.
    fn v2(&self) -> &v2::GroupState<I> {
        match &self.version_specific {
            GmpGroupStateByVersion::V2(v2) => v2,
            GmpGroupStateByVersion::V1(_) => panic!("group not in v2 mode"),
        }
    }
    /// Consumes the group, returning its v1 state machine.
    ///
    /// # Panics
    ///
    /// Panics if the group holds v2 state.
    fn into_v1(self) -> v1::GmpStateMachine<BT::Instant> {
        let Self { version_specific } = self;
        match version_specific {
            GmpGroupStateByVersion::V1(v1) => v1,
            GmpGroupStateByVersion::V2(_) => panic!("expected GMP v1"),
        }
    }
    /// Consumes the group, returning its v2 state.
    ///
    /// # Panics
    ///
    /// Panics if the group holds v1 state.
    fn into_v2(self) -> v2::GroupState<I> {
        let Self { version_specific } = self;
        match version_specific {
            GmpGroupStateByVersion::V2(v2) => v2,
            GmpGroupStateByVersion::V1(_) => panic!("expected GMP v2"),
        }
    }
    /// Wraps a v1 state machine as group state.
    fn new_v1(v1: v1::GmpStateMachine<BT::Instant>) -> Self {
        Self { version_specific: GmpGroupStateByVersion::V1(v1) }
    }
    /// Wraps v2 group state as group state.
    fn new_v2(v2: v2::GroupState<I>) -> Self {
        Self { version_specific: GmpGroupStateByVersion::V2(v2) }
    }
}
/// Version-agnostic GMP protocol mode.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Default)]
pub enum GmpMode {
    /// v1 operation. `compat` is `true` when v1 was entered as a
    /// compatibility fallback from v2, and `false` when v1 is forced.
    V1 {
        compat: bool,
    },
    /// v2 operation (the default).
    #[default]
    V2,
}
impl GmpMode {
    /// Returns `true` for any `V1` mode, compatibility or forced.
    fn is_v1(&self) -> bool {
        matches!(self, Self::V1 { .. })
    }
    /// Returns `true` only for `V2`.
    fn is_v2(&self) -> bool {
        matches!(self, Self::V2)
    }
    /// `V2` becomes `V1 { compat: true }`; any `V1` mode is kept as is.
    fn maybe_enter_v1_compat(&self) -> Self {
        if self.is_v2() {
            Self::V1 { compat: true }
        } else {
            *self
        }
    }
    /// `V1 { compat: true }` becomes `V2`; every other mode is kept as is.
    fn maybe_exit_v1_compat(&self) -> Self {
        if matches!(self, Self::V1 { compat: true }) {
            Self::V2
        } else {
            *self
        }
    }
}
/// Per-group state for whichever protocol version is active. `enter_mode`
/// replaces every group's variant when the version changes.
#[cfg_attr(test, derive(derivative::Derivative))]
#[cfg_attr(test, derivative(Debug(bound = "")))]
enum GmpGroupStateByVersion<I: Ip, BT: GmpBindingsTypes> {
    /// v1 per-group state machine.
    V1(v1::GmpStateMachine<BT::Instant>),
    /// v2 per-group state.
    V2(v2::GroupState<I>),
}
/// Read-only access to a device's GMP state.
pub trait GmpStateContext<I: IpExt, BT: GmpBindingsTypes>: DeviceIdContext<AnyDevice> {
    /// Protocol type layout of this context.
    type TypeLayout: GmpTypeLayout<I, BT>;
    /// Calls `cb` with `device`'s multicast group set.
    ///
    /// Provided in terms of `with_gmp_state`, ignoring the protocol state.
    fn with_multicast_groups<
        O,
        F: FnOnce(&MulticastGroupSet<I::Addr, GmpGroupState<I, BT>>) -> O,
    >(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> O {
        self.with_gmp_state(device, |group_set, _| cb(group_set))
    }
    /// Calls `cb` with `device`'s multicast group set and GMP protocol state.
    fn with_gmp_state<
        O,
        F: FnOnce(
            &MulticastGroupSet<I::Addr, GmpGroupState<I, BT>>,
            &GmpState<I, Self::TypeLayout, BT>,
        ) -> O,
    >(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> O;
}
/// Mutable access to a device's GMP state, paired with an inner context used
/// for sending messages and making protocol-specific mode decisions.
trait GmpContext<I: IpExt, BC: GmpBindingsContext>: DeviceIdContext<AnyDevice> {
    /// Protocol type layout; `Inner` is bound to use the same one.
    type TypeLayout: GmpTypeLayout<I, BC>;
    /// Inner context handed to callbacks alongside the state.
    type Inner<'a>: GmpContextInner<I, BC, TypeLayout = Self::TypeLayout, DeviceId = Self::DeviceId>
        + 'a;
    /// Calls `cb` with the inner context and a mutable view of `device`'s
    /// GMP state.
    fn with_gmp_state_mut_and_ctx<
        O,
        F: FnOnce(Self::Inner<'_>, GmpStateRef<'_, I, Self::TypeLayout, BC>) -> O,
    >(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> O;
    /// Calls `cb` with a mutable view of `device`'s GMP state; the inner
    /// context is not exposed.
    fn with_gmp_state_mut<O, F: FnOnce(GmpStateRef<'_, I, Self::TypeLayout, BC>) -> O>(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> O {
        self.with_gmp_state_mut_and_ctx(device, |_inner, gmp_state| cb(gmp_state))
    }
}
/// Protocol-specific operations that the generic GMP logic delegates:
/// sending messages and mapping between protocol modes and configuration.
trait GmpContextInner<I: IpExt, BC: GmpBindingsContext>: DeviceIdContext<AnyDevice> {
    /// Protocol type layout of this context.
    type TypeLayout: GmpTypeLayout<I, BC>;
    /// Sends a v1 message of kind `msg_type` for `group_addr` on `device`.
    /// `cur_mode` is the current protocol mode.
    fn send_message_v1(
        &mut self,
        bindings_ctx: &mut BC,
        device: &Self::DeviceId,
        cur_mode: &<Self::TypeLayout as GmpTypeLayout<I, BC>>::ProtoMode,
        group_addr: GmpEnabledGroup<I::Addr>,
        msg_type: v1::GmpMessageType,
    );
    /// Sends a v2 report containing the given group records on `device`.
    fn send_report_v2(
        &mut self,
        bindings_ctx: &mut BC,
        device: &Self::DeviceId,
        groups: impl Iterator<Item: v2::VerifiedReportGroupRecord<I::Addr> + Clone> + Clone,
    );
    /// Returns the protocol mode to use after receiving the v1 `query`,
    /// given the current `gmp_state` and `config`.
    fn mode_update_from_v1_query<Q: v1::QueryMessage<I>>(
        &mut self,
        bindings_ctx: &mut BC,
        query: &Q,
        gmp_state: &GmpState<I, Self::TypeLayout, BC>,
        config: &<Self::TypeLayout as GmpTypeLayout<I, BC>>::Config,
    ) -> <Self::TypeLayout as GmpTypeLayout<I, BC>>::ProtoMode;
    /// Converts a protocol mode into its configuration representation.
    fn mode_to_config(
        mode: &<Self::TypeLayout as GmpTypeLayout<I, BC>>::ProtoMode,
    ) -> I::GmpProtoConfigMode;
    /// Converts a requested configuration into a protocol mode; the current
    /// mode is supplied so the protocol can take it into account.
    fn config_to_mode(
        cur_mode: &<Self::TypeLayout as GmpTypeLayout<I, BC>>::ProtoMode,
        config: I::GmpProtoConfigMode,
    ) -> <Self::TypeLayout as GmpTypeLayout<I, BC>>::ProtoMode;
    /// Returns the mode to enter when GMP is disabled, given the current mode.
    fn mode_on_disable(
        cur_mode: &<Self::TypeLayout as GmpTypeLayout<I, BC>>::ProtoMode,
    ) -> <Self::TypeLayout as GmpTypeLayout<I, BC>>::ProtoMode;
    /// Returns the mode to enter when the v1 compatibility timer fires;
    /// `handle_timer` debug-asserts that it maps to `GmpMode::V2`.
    fn mode_on_exit_compat() -> <Self::TypeLayout as GmpTypeLayout<I, BC>>::ProtoMode;
}
/// Handles a fired GMP timer for the device identified by `timer`.
///
/// Returns early if the weak device ID no longer upgrades, or if the device's
/// timer heap yields nothing from `pop` (presumably no timer is due).
/// Otherwise, dispatches the popped timer to the handler for the current mode.
///
/// # Panics
///
/// Panics if GMP is disabled when a timer fires, or if the timer kind does
/// not match the current mode.
fn handle_timer<I, BC, CC>(
    core_ctx: &mut CC,
    bindings_ctx: &mut BC,
    timer: GmpTimerId<I, CC::WeakDeviceId>,
) where
    BC: GmpBindingsContext,
    CC: GmpContext<I, BC>,
    I: IpExt,
{
    let GmpTimerId { device, _marker: IpVersionMarker { .. } } = timer;
    // The device may have been removed since the timer was scheduled.
    let Some(device) = device.upgrade() else {
        return;
    };
    core_ctx.with_gmp_state_mut_and_ctx(&device, |mut core_ctx, state| {
        let Some((timer_id, ())) = state.gmp.timers.pop(bindings_ctx) else {
            return;
        };
        assert!(state.enabled, "{timer_id:?} fired in GMP disabled state");
        // Arm order matters: the `(V1Compat, bad)` catch-all must come after
        // the valid `(V1Compat, V1 { compat: true })` arm.
        match (timer_id, state.gmp.gmp_mode()) {
            (TimerIdInner::V1(v1), GmpMode::V1 { .. }) => {
                v1::handle_timer(&mut core_ctx, bindings_ctx, &device, state, v1);
            }
            // Compatibility period elapsed: switch to the protocol's
            // exit-compat mode.
            (TimerIdInner::V1Compat, GmpMode::V1 { compat: true }) => {
                let mode = <CC::Inner<'_> as GmpContextInner<I, BC>>::mode_on_exit_compat();
                debug_assert_eq!(mode.into(), GmpMode::V2);
                enter_mode(bindings_ctx, state, mode);
            }
            (TimerIdInner::V2(timer), GmpMode::V2) => {
                v2::handle_timer(&mut core_ctx, bindings_ctx, &device, timer, state);
            }
            (TimerIdInner::V1Compat, bad) => {
                panic!("v1 compat timer fired in non v1 compat mode: {bad:?}")
            }
            bad @ (TimerIdInner::V1(_), GmpMode::V2)
            | bad @ (TimerIdInner::V2(_), GmpMode::V1 { .. }) => {
                panic!("incompatible timer fired {bad:?}")
            }
        }
    });
}
/// Switches the GMP state to `new_mode`, resetting per-group state as needed.
///
/// - v1 to v1 with a different `compat` flag: only leaving compatibility mode
///   is allowed. The pending `V1Compat` timer is cancelled and must exist.
///   Group state and other timers are kept.
/// - Same version (and same `compat` for v1): only the stored mode changes.
/// - Across versions: every group's version-specific state is replaced with
///   fresh state for the new version and all timers are cleared. When
///   entering v1, the v2 protocol state is notified via `on_enter_v1`.
///
/// # Panics
///
/// Panics when asked to enter v1 compatibility mode from forced v1, or if no
/// `V1Compat` timer was scheduled when leaving compatibility mode.
fn enter_mode<I: IpExt, CC: GmpTypeLayout<I, BC>, BC: GmpBindingsContext>(
    bindings_ctx: &mut BC,
    state: GmpStateRef<'_, I, CC, BC>,
    new_mode: CC::ProtoMode,
) {
    let GmpStateRef { enabled: _, gmp, groups, config: _ } = state;
    // Store the new mode up front, so `gmp.gmp_mode()` below already reflects
    // it. No later step touches `gmp.mode`, so no second write is needed.
    let old_mode = core::mem::replace(&mut gmp.mode, new_mode);
    match (old_mode.into(), gmp.gmp_mode()) {
        (GmpMode::V1 { compat }, GmpMode::V1 { compat: new_compat }) => {
            if new_compat != compat {
                // Compatibility mode may be left, but never entered from
                // forced v1.
                assert_eq!(new_compat, false, "attempted to enter compatibility mode from forced");
                assert_matches!(
                    gmp.timers.cancel(bindings_ctx, &TimerIdInner::V1Compat),
                    Some((_, ()))
                );
                info!("GMP({}) enter mode {:?}", I::NAME, &gmp.mode);
            }
            return;
        }
        (GmpMode::V2, GmpMode::V2) => {
            return;
        }
        (GmpMode::V1 { compat: _ }, GmpMode::V2) => {
            for (_, GmpGroupState { version_specific }) in groups.iter_mut() {
                *version_specific =
                    GmpGroupStateByVersion::V2(v2::GroupState::new_for_mode_transition())
            }
        }
        (GmpMode::V2, GmpMode::V1 { compat: _ }) => {
            for (_, GmpGroupState { version_specific }) in groups.iter_mut() {
                *version_specific =
                    GmpGroupStateByVersion::V1(v1::GmpStateMachine::new_for_mode_transition())
            }
            gmp.v2_proto.on_enter_v1();
        }
    };
    info!("GMP({}) enter mode {:?}", I::NAME, new_mode);
    // Timers from the previous version are meaningless in the new one.
    gmp.timers.clear(bindings_ctx);
}
/// Schedules the timer that ends v1 compatibility mode, firing after the v2
/// "older version querier present" timeout computed from `config`.
fn schedule_v1_compat<I: IpExt, CC: GmpTypeLayout<I, BC>, BC: GmpBindingsContext>(
    bindings_ctx: &mut BC,
    state: GmpStateRef<'_, I, CC, BC>,
) {
    let GmpStateRef { gmp, config, enabled: _, groups: _ } = state;
    let querier_timeout = gmp.v2_proto.older_version_querier_present_timeout(config);
    // `schedule_after` returns an `Option` (presumably any previously
    // scheduled deadline), which is intentionally ignored.
    let _: Option<_> = gmp.timers.schedule_after(
        bindings_ctx,
        TimerIdInner::V1Compat,
        (),
        querier_timeout.into(),
    );
}
/// Error carrying a group address that the device is not a member of.
#[cfg_attr(test, derive(Debug, Eq, PartialEq))]
struct NotAMemberErr<I: Ip>(I::Addr);
/// Group targeted by a query: no specific group (the query carried the
/// unspecified address), or one specific multicast group.
enum QueryTarget<A> {
    Unspecified,
    Specified(MulticastAddr<A>),
}
impl<A: IpAddress> QueryTarget<A> {
    /// Classifies `addr`. The unspecified address yields `Unspecified`, a
    /// multicast address yields `Specified`, and anything else yields `None`.
    fn new(addr: A) -> Option<Self> {
        let unspecified = <A::Version as Ip>::UNSPECIFIED_ADDRESS;
        if addr != unspecified {
            return MulticastAddr::new(addr).map(QueryTarget::Specified);
        }
        Some(QueryTarget::Unspecified)
    }
}
mod witness {
    use super::*;
    /// A multicast address for which `IpExt::should_perform_gmp` returned
    /// `true`.
    #[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
    pub(super) struct GmpEnabledGroup<A>(MulticastAddr<A>);
    impl<A: IpAddress<Version: IpExt>> GmpEnabledGroup<A> {
        /// Returns `Some` if GMP should be performed for `addr`, else `None`.
        pub fn new(addr: MulticastAddr<A>) -> Option<Self> {
            <A::Version as IpExt>::should_perform_gmp(addr).then(|| Self(addr))
        }
        /// Like `new`, but hands the rejected address back as the error.
        pub fn try_new(addr: MulticastAddr<A>) -> Result<Self, MulticastAddr<A>> {
            // `MulticastAddr<A>` is `Copy`, so there is nothing to defer with
            // a closure.
            Self::new(addr).ok_or(addr)
        }
        /// Returns a copy of the wrapped multicast address.
        pub fn multicast_addr(&self) -> MulticastAddr<A> {
            let Self(addr) = self;
            *addr
        }
        /// Unwraps into the underlying multicast address.
        pub fn into_multicast_addr(self) -> MulticastAddr<A> {
            let Self(addr) = self;
            addr
        }
    }
    impl<A> AsRef<MulticastAddr<A>> for GmpEnabledGroup<A> {
        fn as_ref(&self) -> &MulticastAddr<A> {
            let Self(addr) = self;
            addr
        }
    }
}
use witness::GmpEnabledGroup;
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use core::num::NonZeroU8;
    use assert_matches::assert_matches;
    use ip_test_macro::ip_test;
    use net_types::Witness as _;
    use netstack3_base::testutil::{FakeDeviceId, FakeTimerCtxExt, FakeWeakDeviceId};
    use netstack3_base::InstantContext as _;
    use testutil::{FakeCtx, FakeGmpContextInner, FakeV1Query, TestIpExt};
    use super::*;
    // Switching between v1 and v2 converts each group's state to the new
    // version and clears timers, without sending any messages.
    #[ip_test(I)]
    fn mode_change_state_clearing<I: TestIpExt>() {
        let FakeCtx { mut core_ctx, mut bindings_ctx } =
            testutil::new_context_with_mode::<I>(GmpMode::V1 { compat: false });
        assert_eq!(
            core_ctx.gmp_join_group(&mut bindings_ctx, &FakeDeviceId, I::GROUP_ADDR1),
            GroupJoinResult::Joined(())
        );
        // Discard messages caused by the join so that only mode-change
        // traffic (expected: none) is checked at the end.
        core_ctx.inner.v1_messages.clear();
        // The v1 join left a timer pending.
        assert!(core_ctx.gmp.timers.iter().next().is_some());
        assert_matches!(
            core_ctx.groups.get(&I::GROUP_ADDR1).unwrap().version_specific,
            GmpGroupStateByVersion::V1(_)
        );
        core_ctx.with_gmp_state_mut(&FakeDeviceId, |mut state| {
            enter_mode(&mut bindings_ctx, state.as_mut(), GmpMode::V2);
            assert_eq!(state.gmp.mode, GmpMode::V2);
        });
        // Entering v2 dropped the v1 timer and converted the group.
        core_ctx.gmp.timers.assert_timers([]);
        assert_matches!(
            core_ctx.groups.get(&I::GROUP_ADDR1).unwrap().version_specific,
            GmpGroupStateByVersion::V2(_)
        );
        core_ctx.with_gmp_state_mut(&FakeDeviceId, |mut state| {
            enter_mode(&mut bindings_ctx, state.as_mut(), GmpMode::V1 { compat: false });
            assert_eq!(state.gmp.mode, GmpMode::V1 { compat: false });
        });
        assert_matches!(
            core_ctx.groups.get(&I::GROUP_ADDR1).unwrap().version_specific,
            GmpGroupStateByVersion::V1(_)
        );
        let FakeGmpContextInner { v1_messages, v2_messages } = &core_ctx.inner;
        assert_eq!(v1_messages, &Vec::new());
        assert_eq!(v2_messages, &Vec::<Vec<_>>::new());
    }
    // Forced v1 must not transition into v1 compatibility mode.
    #[ip_test(I)]
    #[should_panic(expected = "attempted to enter compatibility mode from forced")]
    fn cant_enter_v1_compat<I: TestIpExt>() {
        let FakeCtx { mut core_ctx, mut bindings_ctx } =
            testutil::new_context_with_mode::<I>(GmpMode::V1 { compat: false });
        core_ctx.with_gmp_state_mut(&FakeDeviceId, |mut state| {
            enter_mode(&mut bindings_ctx, state.as_mut(), GmpMode::V1 { compat: true });
        });
    }
    // Disabling leaves compatibility v1 for v2, but keeps forced v1.
    #[ip_test(I)]
    fn disable_exits_compat<I: TestIpExt>() {
        let FakeCtx { mut core_ctx, mut bindings_ctx } =
            testutil::new_context_with_mode::<I>(GmpMode::V1 { compat: true });
        core_ctx.enabled = false;
        core_ctx.gmp_handle_disabled(&mut bindings_ctx, &FakeDeviceId);
        assert_eq!(core_ctx.gmp.mode, GmpMode::V2);
        let FakeCtx { mut core_ctx, mut bindings_ctx } =
            testutil::new_context_with_mode::<I>(GmpMode::V1 { compat: false });
        core_ctx.enabled = false;
        core_ctx.gmp_handle_disabled(&mut bindings_ctx, &FakeDeviceId);
        assert_eq!(core_ctx.gmp.mode, GmpMode::V1 { compat: false });
    }
    // Disabling resets v2 protocol state to its default, even while in v1.
    #[ip_test(I)]
    fn disable_clears_v2_state<I: TestIpExt>() {
        let FakeCtx { mut core_ctx, mut bindings_ctx } =
            testutil::new_context_with_mode::<I>(GmpMode::V1 { compat: false });
        // Perturb every field of the v2 protocol state away from its default.
        let v2::ProtocolState { robustness_variable, query_interval, left_groups } =
            &mut core_ctx.gmp.v2_proto;
        *robustness_variable = robustness_variable.checked_add(1).unwrap();
        *query_interval = *query_interval + Duration::from_secs(20);
        *left_groups =
            [(GmpEnabledGroup::new(I::GROUP_ADDR1).unwrap(), NonZeroU8::new(1).unwrap())]
                .into_iter()
                .collect();
        core_ctx.enabled = false;
        core_ctx.gmp_handle_disabled(&mut bindings_ctx, &FakeDeviceId);
        assert_eq!(core_ctx.gmp.v2_proto, v2::ProtocolState::default());
    }
    // A v1 query in v2 mode enters compatibility mode with a V1Compat timer.
    // A later v1 query reschedules that timer, and when it fires GMP returns
    // to v2 without sending any messages.
    #[ip_test(I)]
    fn v1_compat_mode_on_timeout<I: TestIpExt>() {
        let FakeCtx { mut core_ctx, mut bindings_ctx } =
            testutil::new_context_with_mode::<I>(GmpMode::V2);
        assert_eq!(
            v1::handle_query_message(
                &mut core_ctx,
                &mut bindings_ctx,
                &FakeDeviceId,
                &FakeV1Query {
                    group_addr: I::GROUP_ADDR1.get(),
                    max_response_time: Duration::from_secs(1)
                }
            ),
            Err(NotAMemberErr(I::GROUP_ADDR1.get()))
        );
        assert_eq!(core_ctx.gmp.mode, GmpMode::V1 { compat: true });
        let timeout =
            core_ctx.gmp.v2_proto.older_version_querier_present_timeout(&core_ctx.config).into();
        core_ctx.gmp.timers.assert_timers([(
            TimerIdInner::V1Compat,
            (),
            bindings_ctx.now() + timeout,
        )]);
        // Halfway through, another v1 query pushes the deadline out again.
        bindings_ctx.timers.instant.sleep(timeout / 2);
        assert_eq!(
            v1::handle_query_message(
                &mut core_ctx,
                &mut bindings_ctx,
                &FakeDeviceId,
                &FakeV1Query {
                    group_addr: I::GROUP_ADDR1.get(),
                    max_response_time: Duration::from_secs(1)
                }
            ),
            Err(NotAMemberErr(I::GROUP_ADDR1.get()))
        );
        assert_eq!(core_ctx.gmp.mode, GmpMode::V1 { compat: true });
        core_ctx.gmp.timers.assert_timers([(
            TimerIdInner::V1Compat,
            (),
            bindings_ctx.now() + timeout,
        )]);
        let timer = bindings_ctx.trigger_next_timer(&mut core_ctx);
        assert_eq!(timer, Some(GmpTimerId::new(FakeWeakDeviceId(FakeDeviceId))));
        assert_eq!(core_ctx.gmp.mode, GmpMode::V2);
        core_ctx.gmp.timers.assert_timers([]);
        let testutil::FakeGmpContextInner { v1_messages, v2_messages } = &core_ctx.inner;
        assert_eq!(v1_messages, &Vec::new());
        assert_eq!(v2_messages, &Vec::<Vec<_>>::new());
    }
}