use core::num::NonZeroU8;
use core::time::Duration;
use assert_matches::assert_matches;
use derivative::Derivative;
use net_types::ip::Ipv6Addr;
use net_types::UnicastAddr;
use netstack3_base::{
AnyDevice, CoreTimerContext, DeviceIdContext, HandleableTimer, RngContext,
StrongDeviceIdentifier as _, TimerBindingsTypes, TimerContext, TimerHandler,
WeakDeviceIdentifier,
};
use packet::{EitherSerializer, EmptyBuf, InnerPacketBuilder as _, Serializer};
use packet_formats::icmp::ndp::options::NdpOptionBuilder;
use packet_formats::icmp::ndp::{OptionSequenceBuilder, RouterSolicitation};
use rand::Rng as _;
use crate::internal::base::IpSendFrameError;
/// Upper bound on the random delay applied before the first Router
/// Solicitation of a cycle is sent.
///
/// Mirrors the RFC 4861 host constant of the same name (1 second).
pub const MAX_RTR_SOLICITATION_DELAY: Duration = Duration::from_millis(1_000);

/// Delay between consecutive Router Solicitations within one cycle.
///
/// Mirrors the RFC 4861 host constant of the same name (4 seconds).
pub const RTR_SOLICITATION_INTERVAL: Duration = Duration::from_millis(4_000);
/// Timer ID for the router solicitation timer of a single device.
///
/// The device is held as a weak identifier. When the timer fires, it is only
/// acted upon if that identifier can still be upgraded to a strong one.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct RsTimerId<D: WeakDeviceIdentifier> {
    /// The device whose router solicitation timer this is.
    device_id: D,
}
impl<D: WeakDeviceIdentifier> RsTimerId<D> {
    /// Borrows the (weak) identifier of the device this timer belongs to.
    pub(super) fn device_id(&self) -> &D {
        let Self { device_id } = self;
        device_id
    }

    /// Builds a timer ID for `device_id`; only available to tests.
    #[cfg(any(test, feature = "testutils"))]
    pub fn new(device_id: D) -> Self {
        RsTimerId { device_id }
    }
}
/// Per-device router solicitation state.
///
/// `Default` is derived with an empty bound so that it does not require
/// `BT: Default`; the default state has no remaining solicitations and no
/// timer.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct RsState<BT: RsBindingsTypes> {
    /// Solicitations still to be sent in the current cycle.
    ///
    /// Set to the configured maximum when a cycle starts and decremented after
    /// each send; `None` once none remain (or when solicitation is disabled).
    remaining: Option<NonZeroU8>,
    /// The timer driving the current cycle, if one has been created.
    ///
    /// Removed (and cancelled) on stop. Note that after the last solicitation
    /// of a cycle the timer is not rescheduled but is also not removed here.
    timer: Option<BT::Timer>,
}
/// The core execution context required for router solicitation.
///
/// Implementors provide:
/// - per-device [`RsState`] storage;
/// - the device's link-layer address;
/// - the ability to send Router Solicitation messages.
pub trait RsContext<BC: RsBindingsTypes>:
    DeviceIdContext<AnyDevice> + CoreTimerContext<RsTimerId<Self::WeakDeviceId>, BC>
{
    /// A link-layer address, viewable as its raw bytes.
    type LinkLayerAddr: AsRef<[u8]>;

    /// Calls `cb` with mutable access to the [`RsState`] of `device_id` and
    /// the device's configured maximum number of router solicitations.
    ///
    /// When the maximum is `None`, [`RsHandler::start_router_solicitation`]
    /// schedules no solicitations.
    fn with_rs_state_mut_and_max<O, F: FnOnce(&mut RsState<BC>, Option<NonZeroU8>) -> O>(
        &mut self,
        device_id: &Self::DeviceId,
        cb: F,
    ) -> O;

    /// Calls `cb` with mutable access to the [`RsState`] of `device_id`.
    ///
    /// Provided in terms of [`RsContext::with_rs_state_mut_and_max`]; the
    /// maximum is discarded.
    fn with_rs_state_mut<O, F: FnOnce(&mut RsState<BC>) -> O>(
        &mut self,
        device_id: &Self::DeviceId,
        cb: F,
    ) -> O {
        self.with_rs_state_mut_and_max(device_id, move |rs_state, _ignored_max| cb(rs_state))
    }

    /// Returns the link-layer address of `device_id`, if it has one.
    ///
    /// These bytes become the payload of the Source Link-Layer Address option.
    fn get_link_layer_addr_bytes(
        &mut self,
        device_id: &Self::DeviceId,
    ) -> Option<Self::LinkLayerAddr>;

    /// Sends a Router Solicitation `message` out of `device_id`.
    ///
    /// `body` receives the unicast source address selected for the packet
    /// (`None` if there is none). It returns the serializer for the message
    /// body (the NDP options).
    fn send_rs_packet<
        S: Serializer<Buffer = EmptyBuf>,
        F: FnOnce(Option<UnicastAddr<Ipv6Addr>>) -> S,
    >(
        &mut self,
        bindings_ctx: &mut BC,
        device_id: &Self::DeviceId,
        message: RouterSolicitation,
        body: F,
    ) -> Result<(), IpSendFrameError<S>>;
}
/// The bindings types needed by router solicitation.
///
/// Blanket-implemented for every [`TimerBindingsTypes`] implementor.
pub trait RsBindingsTypes: TimerBindingsTypes {}
impl<BT: TimerBindingsTypes> RsBindingsTypes for BT {}
/// The bindings execution context needed by router solicitation.
///
/// It supplies randomness (for the initial delay) and timers. It is
/// blanket-implemented for anything that is both an [`RngContext`] and a
/// [`TimerContext`].
pub trait RsBindingsContext: RngContext + TimerContext {}
impl<BC: RngContext + TimerContext> RsBindingsContext for BC {}
/// Router solicitation operations on a device.
///
/// Blanket-implemented for every core context that implements [`RsContext`].
/// Firing of the solicitation timer is handled through the
/// [`TimerHandler`] supertrait.
pub trait RsHandler<BC: RsBindingsTypes>:
    DeviceIdContext<AnyDevice> + TimerHandler<BC, RsTimerId<Self::WeakDeviceId>>
{
    /// Starts a router solicitation cycle on `device_id`.
    ///
    /// The number of remaining solicitations is reset to the device's
    /// configured maximum. If it is non-`None`, the first solicitation is
    /// scheduled after a random delay below [`MAX_RTR_SOLICITATION_DELAY`].
    fn start_router_solicitation(&mut self, bindings_ctx: &mut BC, device_id: &Self::DeviceId);

    /// Stops router solicitation on `device_id`, removing and cancelling the
    /// solicitation timer if one is held.
    fn stop_router_solicitation(&mut self, bindings_ctx: &mut BC, device_id: &Self::DeviceId);
}
/// Blanket [`RsHandler`] implementation for every core context that provides
/// [`RsContext`].
impl<BC: RsBindingsContext, CC: RsContext<BC>> RsHandler<BC> for CC {
    /// Begins a new solicitation cycle on `device_id`.
    ///
    /// Resets the remaining-solicitation count to the configured maximum. If
    /// that maximum is `None`, nothing is scheduled. Otherwise a fresh timer is
    /// created and scheduled after a random delay in
    /// `[0, MAX_RTR_SOLICITATION_DELAY)`.
    ///
    /// # Panics
    ///
    /// Panics if the device's [`RsState`] still holds a timer, i.e. the
    /// previous cycle was not ended with
    /// [`RsHandler::stop_router_solicitation`].
    fn start_router_solicitation(&mut self, bindings_ctx: &mut BC, device_id: &Self::DeviceId) {
        self.with_rs_state_mut_and_max(device_id, |RsState { remaining, timer }, max| {
            *remaining = max;
            assert_matches!(timer, None);
            if remaining.is_none() {
                // Solicitation is disabled for this device.
                return;
            }
            let initial_delay =
                bindings_ctx.rng().gen_range(Duration::ZERO..MAX_RTR_SOLICITATION_DELAY);
            let fresh_timer =
                CC::new_timer(bindings_ctx, RsTimerId { device_id: device_id.downgrade() });
            let installed = timer.insert(fresh_timer);
            assert_eq!(bindings_ctx.schedule_timer(initial_delay, installed), None);
        });
    }

    /// Ends the current solicitation cycle on `device_id`.
    ///
    /// Takes the timer out of the device's [`RsState`] (if any) and cancels
    /// it. The remaining-solicitation count is left as-is; the next
    /// [`RsHandler::start_router_solicitation`] overwrites it.
    fn stop_router_solicitation(&mut self, bindings_ctx: &mut BC, device_id: &Self::DeviceId) {
        self.with_rs_state_mut(device_id, |RsState { timer, .. }| {
            let Some(mut pending) = timer.take() else {
                return;
            };
            let _: Option<BC::Instant> = bindings_ctx.cancel_timer(&mut pending);
        });
    }
}
/// Routes a fired router solicitation timer to `do_router_solicitation`.
impl<BC: RsBindingsContext, CC: RsContext<BC>> HandleableTimer<CC, BC>
    for RsTimerId<CC::WeakDeviceId>
{
    /// Performs a router solicitation for this timer's device.
    ///
    /// The firing is ignored if the weak device identifier can no longer be
    /// upgraded.
    fn handle(self, core_ctx: &mut CC, bindings_ctx: &mut BC, timer: BC::UniqueTimerId) {
        let Some(strong_device) = self.device_id.upgrade() else {
            return;
        };
        do_router_solicitation(core_ctx, bindings_ctx, &strong_device, timer);
    }
}
/// Handles a fired router solicitation timer for `device_id`.
///
/// The firing is ignored (nothing is sent) when:
/// - the device's [`RsState`] holds no timer; or
/// - it holds a timer whose unique ID differs from `timer_id`, which is a
///   stale firing from an earlier cycle.
///
/// Otherwise:
/// - the remaining-solicitation count is decremented;
/// - the timer is rescheduled after [`RTR_SOLICITATION_INTERVAL`] if any
///   solicitations remain;
/// - one Router Solicitation is sent.
///
/// The message carries a Source Link-Layer Address option only when the packet
/// has a unicast source address and the device reports a link-layer address.
/// Send errors are ignored.
///
/// # Panics
///
/// Panics if the current timer fires while no solicitations remain.
fn do_router_solicitation<BC: RsBindingsContext, CC: RsContext<BC>>(
    core_ctx: &mut CC,
    bindings_ctx: &mut BC,
    device_id: &CC::DeviceId,
    timer_id: BC::UniqueTimerId,
) {
    let should_send = core_ctx.with_rs_state_mut(device_id, |RsState { remaining, timer }| {
        let Some(current) = timer.as_mut() else {
            // Solicitation was stopped; nothing to do.
            return false;
        };
        if bindings_ctx.unique_timer_id(current) != timer_id {
            // Stale firing of a timer from a previous cycle.
            return false;
        }
        let before = remaining
            .expect("should only send a router solicitations when at least one is remaining");
        *remaining = NonZeroU8::new(before.get() - 1);
        if remaining.is_some() {
            assert_eq!(bindings_ctx.schedule_timer(RTR_SOLICITATION_INTERVAL, current), None);
        }
        true
    });
    if !should_send {
        return;
    }

    let src_ll = core_ctx.get_link_layer_addr_bytes(device_id);
    let _: Result<(), _> = core_ctx.send_rs_packet(
        bindings_ctx,
        device_id,
        RouterSolicitation::default(),
        |src_ip| match src_ip {
            // No unicast source address: send without options. RFC 4861
            // section 4.1 forbids the SLLA option with an unspecified source.
            None => EitherSerializer::A(EmptyBuf),
            Some(UnicastAddr { .. }) => EitherSerializer::B(
                OptionSequenceBuilder::new(
                    src_ll
                        .as_ref()
                        .map(AsRef::as_ref)
                        .into_iter()
                        .map(NdpOptionBuilder::SourceLinkLayerAddress),
                )
                .into_serializer(),
            ),
        },
    );
}
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;
    use net_declare::net_ip_v6;
    use netstack3_base::testutil::{
        FakeBindingsCtx, FakeCoreCtx, FakeDeviceId, FakeTimerCtxExt as _, FakeWeakDeviceId,
    };
    use netstack3_base::{CtxPair, InstantContext as _, SendFrameContext as _};
    use packet_formats::icmp::ndp::options::NdpOption;
    use packet_formats::icmp::ndp::Options;
    use test_case::test_case;
    use super::*;

    /// State backing the fake core context used by these tests.
    struct FakeRsContext {
        /// Reported as the device's maximum number of router solicitations.
        max_router_solicitations: Option<NonZeroU8>,
        /// The (single) device's router solicitation state.
        rs_state: RsState<FakeBindingsCtxImpl>,
        /// Passed to the body builder by the fake `send_rs_packet`.
        source_address: Option<UnicastAddr<Ipv6Addr>>,
        /// Returned by the fake `get_link_layer_addr_bytes`.
        link_layer_bytes: Option<Vec<u8>>,
    }

    /// Metadata recorded alongside each frame sent through the fake context.
    #[derive(Debug, PartialEq)]
    struct RsMessageMeta {
        message: RouterSolicitation,
    }

    type FakeCoreCtxImpl = FakeCoreCtx<FakeRsContext, RsMessageMeta, FakeDeviceId>;
    type FakeBindingsCtxImpl =
        FakeBindingsCtx<RsTimerId<FakeWeakDeviceId<FakeDeviceId>>, (), (), ()>;

    // The fake bindings dispatch `RsTimerId`s directly, so no conversion is
    // needed.
    impl CoreTimerContext<RsTimerId<FakeWeakDeviceId<FakeDeviceId>>, FakeBindingsCtxImpl>
        for FakeCoreCtxImpl
    {
        fn convert_timer(
            dispatch_id: RsTimerId<FakeWeakDeviceId<FakeDeviceId>>,
        ) -> <FakeBindingsCtxImpl as TimerBindingsTypes>::DispatchId {
            dispatch_id
        }
    }

    // Fake `RsContext` backed by `FakeRsContext`; there is only one device.
    impl RsContext<FakeBindingsCtxImpl> for FakeCoreCtxImpl {
        type LinkLayerAddr = Vec<u8>;

        fn with_rs_state_mut_and_max<
            O,
            F: FnOnce(&mut RsState<FakeBindingsCtxImpl>, Option<NonZeroU8>) -> O,
        >(
            &mut self,
            &FakeDeviceId: &FakeDeviceId,
            cb: F,
        ) -> O {
            let FakeRsContext { max_router_solicitations, rs_state, .. } = &mut self.state;
            cb(rs_state, *max_router_solicitations)
        }

        fn get_link_layer_addr_bytes(&mut self, &FakeDeviceId: &FakeDeviceId) -> Option<Vec<u8>> {
            let FakeRsContext { link_layer_bytes, .. } = &self.state;
            link_layer_bytes.clone()
        }

        // Records the frame (with the message as metadata) in the fake core
        // context so tests can inspect it via `frames()`.
        fn send_rs_packet<
            S: Serializer<Buffer = EmptyBuf>,
            F: FnOnce(Option<UnicastAddr<Ipv6Addr>>) -> S,
        >(
            &mut self,
            bindings_ctx: &mut FakeBindingsCtxImpl,
            &FakeDeviceId: &FakeDeviceId,
            message: RouterSolicitation,
            body: F,
        ) -> Result<(), IpSendFrameError<S>> {
            let FakeRsContext { source_address, .. } = &self.state;
            self.send_frame(bindings_ctx, RsMessageMeta { message }, body(*source_address))
                .map_err(|e| e.err_into())
        }
    }

    /// The timer ID installed for the fake device.
    const RS_TIMER_ID: RsTimerId<FakeWeakDeviceId<FakeDeviceId>> =
        RsTimerId { device_id: FakeWeakDeviceId(FakeDeviceId) };

    /// Stopping right after starting removes the pending timer and sends
    /// nothing.
    #[test]
    fn stop_router_solicitation() {
        let CtxPair { mut core_ctx, mut bindings_ctx } =
            CtxPair::with_core_ctx(FakeCoreCtxImpl::with_state(FakeRsContext {
                max_router_solicitations: NonZeroU8::new(1),
                rs_state: Default::default(),
                source_address: None,
                link_layer_bytes: None,
            }));
        RsHandler::start_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);
        let now = bindings_ctx.now();
        // The first solicitation is scheduled within the random initial delay.
        bindings_ctx
            .timers
            .assert_timers_installed_range([(RS_TIMER_ID, now..=now + MAX_RTR_SOLICITATION_DELAY)]);
        RsHandler::stop_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);
        bindings_ctx.timers.assert_no_timers_installed();
        assert_eq!(core_ctx.frames(), &[][..]);
    }

    // SAFETY: fe80::1 is a link-local unicast address.
    const SOURCE_ADDRESS: UnicastAddr<Ipv6Addr> =
        unsafe { UnicastAddr::new_unchecked(net_ip_v6!("fe80::1")) };

    // Each case is: (max solicitations, source address, device link-layer
    // bytes, expected SLLA option bytes).
    // - No SLLA option is expected without a source address, even when
    //   link-layer bytes exist.
    // - The short/long cases show the option payload being zero-padded.
    #[test_case(0, None, None, None; "disabled")]
    #[test_case(1, None, None, None; "once_without_source_address_or_link_layer_option")]
    #[test_case(
        1,
        Some(SOURCE_ADDRESS),
        None,
        None; "once_with_source_address_and_without_link_layer_option")]
    #[test_case(
        1,
        None,
        Some(vec![1, 2, 3, 4, 5, 6]),
        None; "once_without_source_address_and_with_mac_address_source_link_layer_option")]
    #[test_case(
        1,
        Some(SOURCE_ADDRESS),
        Some(vec![1, 2, 3, 4, 5, 6]),
        Some(&[1, 2, 3, 4, 5, 6]); "once_with_source_address_and_mac_address_source_link_layer_option")]
    #[test_case(
        1,
        Some(SOURCE_ADDRESS),
        Some(vec![1, 2, 3, 4, 5]),
        Some(&[1, 2, 3, 4, 5, 0]); "once_with_source_address_and_short_address_source_link_layer_option")]
    #[test_case(
        1,
        Some(SOURCE_ADDRESS),
        Some(vec![1, 2, 3, 4, 5, 6, 7]),
        Some(&[
            1, 2, 3, 4, 5, 6, 7,
            0, 0, 0, 0, 0, 0, 0,
        ]); "once_with_source_address_and_long_address_source_link_layer_option")]
    fn perform_router_solicitation(
        max_router_solicitations: u8,
        source_address: Option<UnicastAddr<Ipv6Addr>>,
        link_layer_bytes: Option<Vec<u8>>,
        expected_sll_bytes: Option<&[u8]>,
    ) {
        let CtxPair { mut core_ctx, mut bindings_ctx } =
            CtxPair::with_core_ctx(FakeCoreCtxImpl::with_state(FakeRsContext {
                max_router_solicitations: NonZeroU8::new(max_router_solicitations),
                rs_state: Default::default(),
                source_address,
                link_layer_bytes,
            }));
        RsHandler::start_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);
        // Nothing is sent until the timer fires.
        assert_eq!(core_ctx.frames(), &[][..]);
        // The first firing is bounded by the random initial delay; later
        // firings use the fixed interval.
        let mut duration = MAX_RTR_SOLICITATION_DELAY;
        for i in 0..max_router_solicitations {
            assert_eq!(
                core_ctx.state.rs_state.remaining,
                NonZeroU8::new(max_router_solicitations - i)
            );
            let now = bindings_ctx.now();
            bindings_ctx
                .timers
                .assert_timers_installed_range([(RS_TIMER_ID, now..=now + duration)]);
            assert_eq!(bindings_ctx.trigger_next_timer(&mut core_ctx), Some(RS_TIMER_ID));
            // Each firing sends exactly one more frame.
            let frames = core_ctx.frames();
            assert_eq!(frames.len(), usize::from(i + 1), "frames = {:?}", frames);
            let (RsMessageMeta { message }, frame) =
                frames.last().expect("should have transmitted a frame");
            assert_eq!(*message, RouterSolicitation::default());
            // The frame body is the NDP option sequence; only an SLLA option
            // is allowed to appear.
            let options = Options::parse(&frame[..]).expect("parse NDP options");
            let sll_bytes = options.iter().find_map(|o| match o {
                NdpOption::SourceLinkLayerAddress(a) => Some(a),
                o => panic!("unexpected NDP option = {:?}", o),
            });
            assert_eq!(sll_bytes, expected_sll_bytes);
            duration = RTR_SOLICITATION_INTERVAL;
        }
        // After the last solicitation nothing is rescheduled.
        bindings_ctx.timers.assert_no_timers_installed();
        assert_eq!(core_ctx.state.rs_state.remaining, None);
        let frames = core_ctx.frames();
        assert_eq!(frames.len(), usize::from(max_router_solicitations), "frames = {:?}", frames);
    }

    /// A timer ID from a stopped cycle sends nothing. This holds both while
    /// stopped (no timer held) and after a new cycle starts (IDs differ).
    #[test]
    fn previous_cycle_timers_ignored() {
        let CtxPair { mut core_ctx, mut bindings_ctx } =
            CtxPair::with_core_ctx(FakeCoreCtxImpl::with_state(FakeRsContext {
                max_router_solicitations: NonZeroU8::new(1),
                rs_state: Default::default(),
                source_address: None,
                link_layer_bytes: None,
            }));
        RsHandler::start_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);
        let timer_id = core_ctx.state.rs_state.timer.as_ref().unwrap().timer_id();
        RsHandler::stop_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);
        do_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId, timer_id);
        assert_eq!(core_ctx.frames(), &[][..]);
        // Restarting creates a new timer whose unique ID differs from the old one.
        RsHandler::start_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);
        do_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId, timer_id);
        assert_eq!(core_ctx.frames(), &[][..]);
    }
}