1use core::num::NonZeroU8;
10use core::time::Duration;
11
12use assert_matches::assert_matches;
13use derivative::Derivative;
14use net_types::ip::Ipv6Addr;
15use net_types::UnicastAddr;
16use netstack3_base::{
17 AnyDevice, CoreTimerContext, DeviceIdContext, HandleableTimer, RngContext,
18 StrongDeviceIdentifier as _, TimerBindingsTypes, TimerContext, TimerHandler,
19 WeakDeviceIdentifier,
20};
21use packet::{EitherSerializer, EmptyBuf, InnerPacketBuilder as _, Serializer};
22use packet_formats::icmp::ndp::options::NdpOptionBuilder;
23use packet_formats::icmp::ndp::{OptionSequenceBuilder, RouterSolicitation};
24use rand::Rng as _;
25
26use crate::internal::base::IpSendFrameError;
27use crate::internal::device::Ipv6LinkLayerAddr;
28
/// Exclusive upper bound on the random delay applied before the first router
/// solicitation of a cycle is transmitted.
///
/// Named after the host constant of the same name in RFC 4861 section 10.
pub const MAX_RTR_SOLICITATION_DELAY: Duration = Duration::from_secs(1);

/// Delay between consecutive router solicitations while more remain to be
/// sent.
///
/// Named after the host constant of the same name in RFC 4861 section 10.
pub const RTR_SOLICITATION_INTERVAL: Duration = Duration::from_secs(4);
50
/// Identifier for the router solicitation timer of a single device.
///
/// Only a weak device identifier is stored; the timer handler in this file
/// upgrades it when the timer fires and does nothing if the upgrade fails.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
pub struct RsTimerId<D: WeakDeviceIdentifier> {
    /// Weak identifier of the device this timer belongs to.
    device_id: D,
}
56
57impl<D: WeakDeviceIdentifier> RsTimerId<D> {
58 pub(super) fn device_id(&self) -> &D {
59 &self.device_id
60 }
61
62 #[cfg(any(test, feature = "testutils"))]
64 pub fn new(device_id: D) -> Self {
65 Self { device_id }
66 }
67}
68
/// Per-device router solicitation state.
///
/// `Default` is derived without any bound on `BT`; the default value has both
/// fields set to `None`.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct RsState<BT: RsBindingsTypes> {
    /// Number of router solicitations still to be sent in the current cycle,
    /// or `None` when none are left.
    remaining: Option<NonZeroU8>,
    /// Timer used to trigger the next solicitation. It is `None` until
    /// solicitation is started with a non-zero maximum, and is reset to `None`
    /// when solicitation is stopped.
    timer: Option<BT::Timer>,
}
76
/// The core execution context required to perform router solicitation.
///
/// `BC` is the bindings context type that provides the timer types.
pub trait RsContext<BC: RsBindingsTypes>:
    DeviceIdContext<AnyDevice> + CoreTimerContext<RsTimerId<Self::WeakDeviceId>, BC>
{
    /// The link-layer address type returned by
    /// [`RsContext::get_link_layer_addr`].
    type LinkLayerAddr: Ipv6LinkLayerAddr;

    /// Calls `cb` with mutable access to the router solicitation state of
    /// `device_id` and the maximum number of router solicitations to send.
    ///
    /// When the maximum is `None`, starting router solicitation schedules
    /// nothing. Returns whatever `cb` returns.
    fn with_rs_state_mut_and_max<O, F: FnOnce(&mut RsState<BC>, Option<NonZeroU8>) -> O>(
        &mut self,
        device_id: &Self::DeviceId,
        cb: F,
    ) -> O;

    /// Calls `cb` with mutable access to the router solicitation state of
    /// `device_id`.
    ///
    /// The default implementation delegates to
    /// [`RsContext::with_rs_state_mut_and_max`] and discards the maximum.
    fn with_rs_state_mut<O, F: FnOnce(&mut RsState<BC>) -> O>(
        &mut self,
        device_id: &Self::DeviceId,
        cb: F,
    ) -> O {
        self.with_rs_state_mut_and_max(device_id, |state, _max| cb(state))
    }

    /// Returns the link-layer address of `device_id`, or `None` if there is
    /// none to report.
    fn get_link_layer_addr(&mut self, device_id: &Self::DeviceId) -> Option<Self::LinkLayerAddr>;

    /// Sends the router solicitation `message` for `device_id`.
    ///
    /// `body` is called with an optional unicast IPv6 address (the caller in
    /// this file treats it as the packet's source address) and returns the
    /// serializer that is sent along with `message`.
    ///
    /// # Errors
    ///
    /// Returns an [`IpSendFrameError`] carrying the serializer if the frame
    /// could not be sent.
    fn send_rs_packet<
        S: Serializer<Buffer = EmptyBuf>,
        F: FnOnce(Option<UnicastAddr<Ipv6Addr>>) -> S,
    >(
        &mut self,
        bindings_ctx: &mut BC,
        device_id: &Self::DeviceId,
        message: RouterSolicitation,
        body: F,
    ) -> Result<(), IpSendFrameError<S>>;
}
121
/// The bindings types required by router solicitation.
///
/// This is a marker trait, blanket-implemented for every type that implements
/// [`TimerBindingsTypes`].
pub trait RsBindingsTypes: TimerBindingsTypes {}
impl<BT> RsBindingsTypes for BT where BT: TimerBindingsTypes {}
125
/// The bindings execution context required by router solicitation.
///
/// This is a marker trait, blanket-implemented for every type that implements
/// both [`RngContext`] and [`TimerContext`].
pub trait RsBindingsContext: RngContext + TimerContext {}
impl<BC> RsBindingsContext for BC where BC: RngContext + TimerContext {}
129
/// An implementation of router solicitation.
pub trait RsHandler<BC: RsBindingsTypes>:
    DeviceIdContext<AnyDevice> + TimerHandler<BC, RsTimerId<Self::WeakDeviceId>>
{
    /// Starts router solicitation for `device_id`.
    ///
    /// # Panics
    ///
    /// The blanket implementation in this file panics if the device's state
    /// still holds a timer, i.e. if solicitation was started earlier and not
    /// stopped since.
    fn start_router_solicitation(&mut self, bindings_ctx: &mut BC, device_id: &Self::DeviceId);

    /// Stops router solicitation for `device_id`.
    ///
    /// In the blanket implementation in this file this is a no-op when no
    /// timer is held for the device.
    fn stop_router_solicitation(&mut self, bindings_ctx: &mut BC, device_id: &Self::DeviceId);
}
142
143impl<BC: RsBindingsContext, CC: RsContext<BC>> RsHandler<BC> for CC {
144 fn start_router_solicitation(&mut self, bindings_ctx: &mut BC, device_id: &Self::DeviceId) {
145 self.with_rs_state_mut_and_max(device_id, |state, max| {
146 let RsState { remaining, timer } = state;
147 *remaining = max;
148
149 assert_matches!(timer, None);
152
153 match remaining {
154 None => {}
155 Some(_) => {
156 let delay =
161 bindings_ctx.rng().gen_range(Duration::ZERO..MAX_RTR_SOLICITATION_DELAY);
162
163 let timer = timer.insert(CC::new_timer(
164 bindings_ctx,
165 RsTimerId { device_id: device_id.downgrade() },
166 ));
167 assert_eq!(bindings_ctx.schedule_timer(delay, timer), None);
168 }
169 }
170 });
171 }
172
173 fn stop_router_solicitation(&mut self, bindings_ctx: &mut BC, device_id: &Self::DeviceId) {
174 self.with_rs_state_mut(device_id, |state| {
175 if let Some(mut timer) = state.timer.take() {
177 let _: Option<BC::Instant> = bindings_ctx.cancel_timer(&mut timer);
178 }
179 });
180 }
181}
182
183impl<BC: RsBindingsContext, CC: RsContext<BC>> HandleableTimer<CC, BC>
184 for RsTimerId<CC::WeakDeviceId>
185{
186 fn handle(self, core_ctx: &mut CC, bindings_ctx: &mut BC, timer: BC::UniqueTimerId) {
187 let Self { device_id } = self;
188 if let Some(device_id) = device_id.upgrade() {
189 do_router_solicitation(core_ctx, bindings_ctx, &device_id, timer)
190 }
191 }
192}
193
194fn do_router_solicitation<BC: RsBindingsContext, CC: RsContext<BC>>(
196 core_ctx: &mut CC,
197 bindings_ctx: &mut BC,
198 device_id: &CC::DeviceId,
199 timer_id: BC::UniqueTimerId,
200) {
201 let send_rs = core_ctx.with_rs_state_mut(device_id, |RsState { remaining, timer }| {
202 let Some(timer) = timer.as_mut() else {
203 return false;
205 };
206 if bindings_ctx.unique_timer_id(timer) != timer_id {
207 return false;
210 }
211 *remaining = NonZeroU8::new(
212 remaining
213 .expect("should only send a router solicitations when at least one is remaining")
214 .get()
215 - 1,
216 );
217
218 match *remaining {
220 None => {}
221 Some(NonZeroU8 { .. }) => {
222 assert_eq!(bindings_ctx.schedule_timer(RTR_SOLICITATION_INTERVAL, timer), None);
223 }
224 }
225
226 true
227 });
228
229 if !send_rs {
230 return;
231 }
232
233 let src_ll = core_ctx.get_link_layer_addr(device_id);
234
235 let _: Result<(), _> =
238 core_ctx.send_rs_packet(bindings_ctx, device_id, RouterSolicitation::default(), |src_ip| {
239 src_ip.map_or(EitherSerializer::A(EmptyBuf), |UnicastAddr { .. }| {
249 EitherSerializer::B(
250 OptionSequenceBuilder::new(
251 src_ll
252 .as_ref()
253 .map(Ipv6LinkLayerAddr::as_bytes)
254 .into_iter()
255 .map(NdpOptionBuilder::SourceLinkLayerAddress),
256 )
257 .into_serializer(),
258 )
259 })
260 });
261}
262
/// Unit tests for the router solicitation state machine, run against fake
/// core and bindings contexts.
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use net_declare::net_ip_v6;
    use netstack3_base::testutil::{
        FakeBindingsCtx, FakeCoreCtx, FakeDeviceId, FakeTimerCtxExt as _, FakeWeakDeviceId,
    };
    use netstack3_base::{CtxPair, InstantContext as _, SendFrameContext as _};
    use packet_formats::icmp::ndp::options::NdpOption;
    use packet_formats::icmp::ndp::Options;
    use test_case::test_case;

    use super::*;

    /// State held by the fake core context used in these tests.
    struct FakeRsContext {
        /// Maximum handed to `with_rs_state_mut_and_max` callbacks.
        max_router_solicitations: Option<NonZeroU8>,
        /// The router solicitation state under test.
        rs_state: RsState<FakeBindingsCtxImpl>,
        /// Address passed to the `body` callback of `send_rs_packet`.
        source_address: Option<UnicastAddr<Ipv6Addr>>,
        /// Bytes returned (cloned) from `get_link_layer_addr`.
        link_layer_bytes: Option<Vec<u8>>,
    }

    /// Metadata recorded alongside every frame sent through the fake context.
    #[derive(Debug, PartialEq)]
    struct RsMessageMeta {
        message: RouterSolicitation,
    }

    type FakeCoreCtxImpl = FakeCoreCtx<FakeRsContext, RsMessageMeta, FakeDeviceId>;
    type FakeBindingsCtxImpl =
        FakeBindingsCtx<RsTimerId<FakeWeakDeviceId<FakeDeviceId>>, (), (), ()>;

    impl CoreTimerContext<RsTimerId<FakeWeakDeviceId<FakeDeviceId>>, FakeBindingsCtxImpl>
        for FakeCoreCtxImpl
    {
        // The bindings dispatch ID is the timer ID itself, so no conversion is
        // needed.
        fn convert_timer(
            dispatch_id: RsTimerId<FakeWeakDeviceId<FakeDeviceId>>,
        ) -> <FakeBindingsCtxImpl as TimerBindingsTypes>::DispatchId {
            dispatch_id
        }
    }

    // Lets tests use arbitrary-length byte vectors as link-layer addresses.
    impl Ipv6LinkLayerAddr for Vec<u8> {
        fn as_bytes(&self) -> &[u8] {
            &self
        }

        // Not exercised by the tests in this module.
        fn eui64_iid(&self) -> [u8; 8] {
            unimplemented!()
        }
    }

    impl RsContext<FakeBindingsCtxImpl> for FakeCoreCtxImpl {
        type LinkLayerAddr = Vec<u8>;

        fn with_rs_state_mut_and_max<
            O,
            F: FnOnce(&mut RsState<FakeBindingsCtxImpl>, Option<NonZeroU8>) -> O,
        >(
            &mut self,
            &FakeDeviceId: &FakeDeviceId,
            cb: F,
        ) -> O {
            let FakeRsContext { max_router_solicitations, rs_state, .. } = &mut self.state;
            cb(rs_state, *max_router_solicitations)
        }

        fn get_link_layer_addr(&mut self, &FakeDeviceId: &FakeDeviceId) -> Option<Vec<u8>> {
            let FakeRsContext { link_layer_bytes, .. } = &self.state;
            link_layer_bytes.clone()
        }

        // Records the frame in the fake context, building the body from the
        // configured source address.
        fn send_rs_packet<
            S: Serializer<Buffer = EmptyBuf>,
            F: FnOnce(Option<UnicastAddr<Ipv6Addr>>) -> S,
        >(
            &mut self,
            bindings_ctx: &mut FakeBindingsCtxImpl,
            &FakeDeviceId: &FakeDeviceId,
            message: RouterSolicitation,
            body: F,
        ) -> Result<(), IpSendFrameError<S>> {
            let FakeRsContext { source_address, .. } = &self.state;
            self.send_frame(bindings_ctx, RsMessageMeta { message }, body(*source_address))
                .map_err(|e| e.err_into())
        }
    }

    /// The timer ID expected for the single fake device.
    const RS_TIMER_ID: RsTimerId<FakeWeakDeviceId<FakeDeviceId>> =
        RsTimerId { device_id: FakeWeakDeviceId(FakeDeviceId) };

    /// Stopping before the first timer fires removes the timer and sends
    /// nothing.
    #[test]
    fn stop_router_solicitation() {
        let CtxPair { mut core_ctx, mut bindings_ctx } =
            CtxPair::with_core_ctx(FakeCoreCtxImpl::with_state(FakeRsContext {
                max_router_solicitations: NonZeroU8::new(1),
                rs_state: Default::default(),
                source_address: None,
                link_layer_bytes: None,
            }));
        RsHandler::start_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);

        // The first solicitation is scheduled within the initial random delay
        // window.
        let now = bindings_ctx.now();
        bindings_ctx
            .timers
            .assert_timers_installed_range([(RS_TIMER_ID, now..=now + MAX_RTR_SOLICITATION_DELAY)]);

        RsHandler::stop_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);
        bindings_ctx.timers.assert_no_timers_installed();

        assert_eq!(core_ctx.frames(), &[][..]);
    }

    // SAFETY: `fe80::1` is a link-local unicast address, which is presumably
    // the invariant `UnicastAddr::new_unchecked` requires of its argument.
    const SOURCE_ADDRESS: UnicastAddr<Ipv6Addr> =
        unsafe { UnicastAddr::new_unchecked(net_ip_v6!("fe80::1")) };

    /// Runs a full solicitation cycle and checks the timers, the remaining
    /// count and the transmitted frames after every step.
    ///
    /// A source link-layer address option is expected only when both a source
    /// address and link-layer bytes are configured. The expected option bytes
    /// in the short and long address cases are zero-padded (presumably to the
    /// NDP option alignment).
    #[test_case(0, None, None, None; "disabled")]
    #[test_case(1, None, None, None; "once_without_source_address_or_link_layer_option")]
    #[test_case(
        1,
        Some(SOURCE_ADDRESS),
        None,
        None; "once_with_source_address_and_without_link_layer_option")]
    #[test_case(
        1,
        None,
        Some(vec![1, 2, 3, 4, 5, 6]),
        None; "once_without_source_address_and_with_mac_address_source_link_layer_option")]
    #[test_case(
        1,
        Some(SOURCE_ADDRESS),
        Some(vec![1, 2, 3, 4, 5, 6]),
        Some(&[1, 2, 3, 4, 5, 6]); "once_with_source_address_and_mac_address_source_link_layer_option")]
    #[test_case(
        1,
        Some(SOURCE_ADDRESS),
        Some(vec![1, 2, 3, 4, 5]),
        Some(&[1, 2, 3, 4, 5, 0]); "once_with_source_address_and_short_address_source_link_layer_option")]
    #[test_case(
        1,
        Some(SOURCE_ADDRESS),
        Some(vec![1, 2, 3, 4, 5, 6, 7]),
        Some(&[
            1, 2, 3, 4, 5, 6, 7,
            0, 0, 0, 0, 0, 0, 0,
        ]); "once_with_source_address_and_long_address_source_link_layer_option")]
    fn perform_router_solicitation(
        max_router_solicitations: u8,
        source_address: Option<UnicastAddr<Ipv6Addr>>,
        link_layer_bytes: Option<Vec<u8>>,
        expected_sll_bytes: Option<&[u8]>,
    ) {
        let CtxPair { mut core_ctx, mut bindings_ctx } =
            CtxPair::with_core_ctx(FakeCoreCtxImpl::with_state(FakeRsContext {
                max_router_solicitations: NonZeroU8::new(max_router_solicitations),
                rs_state: Default::default(),
                source_address,
                link_layer_bytes,
            }));
        RsHandler::start_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);

        // Starting must not transmit anything by itself.
        assert_eq!(core_ctx.frames(), &[][..]);

        // The first timer is bounded by the initial delay; later ones by the
        // solicitation interval.
        let mut duration = MAX_RTR_SOLICITATION_DELAY;
        for i in 0..max_router_solicitations {
            assert_eq!(
                core_ctx.state.rs_state.remaining,
                NonZeroU8::new(max_router_solicitations - i)
            );
            let now = bindings_ctx.now();
            bindings_ctx
                .timers
                .assert_timers_installed_range([(RS_TIMER_ID, now..=now + duration)]);

            assert_eq!(bindings_ctx.trigger_next_timer(&mut core_ctx), Some(RS_TIMER_ID));
            let frames = core_ctx.frames();
            assert_eq!(frames.len(), usize::from(i + 1), "frames = {:?}", frames);
            let (RsMessageMeta { message }, frame) =
                frames.last().expect("should have transmitted a frame");
            assert_eq!(*message, RouterSolicitation::default());
            let options = Options::parse(&frame[..]).expect("parse NDP options");
            // Any option other than the source link-layer address is a test
            // failure.
            let sll_bytes = options.iter().find_map(|o| match o {
                NdpOption::SourceLinkLayerAddress(a) => Some(a),
                o => panic!("unexpected NDP option = {:?}", o),
            });

            assert_eq!(sll_bytes, expected_sll_bytes);
            duration = RTR_SOLICITATION_INTERVAL;
        }

        // Once the budget is exhausted nothing further is scheduled or sent.
        bindings_ctx.timers.assert_no_timers_installed();
        assert_eq!(core_ctx.state.rs_state.remaining, None);
        let frames = core_ctx.frames();
        assert_eq!(frames.len(), usize::from(max_router_solicitations), "frames = {:?}", frames);
    }

    /// A timer firing that belongs to an earlier solicitation cycle must not
    /// cause a solicitation to be sent.
    #[test]
    fn previous_cycle_timers_ignored() {
        let CtxPair { mut core_ctx, mut bindings_ctx } =
            CtxPair::with_core_ctx(FakeCoreCtxImpl::with_state(FakeRsContext {
                max_router_solicitations: NonZeroU8::new(1),
                rs_state: Default::default(),
                source_address: None,
                link_layer_bytes: None,
            }));
        RsHandler::start_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);
        // Remember the unique ID of the first cycle's timer.
        let timer_id = core_ctx.state.rs_state.timer.as_ref().unwrap().timer_id();
        // After stopping there is no timer at all, so the stale firing is
        // ignored.
        RsHandler::stop_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);
        do_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId, timer_id);
        assert_eq!(core_ctx.frames(), &[][..]);
        // After restarting there is a new timer, and the stale ID from the
        // first cycle is still ignored.
        RsHandler::start_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId);
        do_router_solicitation(&mut core_ctx, &mut bindings_ctx, &FakeDeviceId, timer_id);
        assert_eq!(core_ctx.frames(), &[][..]);
    }
}