1use alloc::collections::BinaryHeap;
8use alloc::vec::Vec;
9use core::fmt::Debug;
10use core::hash::Hash;
11use core::time::Duration;
12use netstack3_hashmap::HashMap;
13
14use packet::Buf;
15
16use crate::InstantContext as _;
17use crate::testutil::{
18 FakeInstant, FakeTimerId, InstantAndData, WithFakeFrameContext, WithFakeTimerContext,
19};
20
21pub struct FakeNetwork<Spec: FakeNetworkSpec, CtxId, Links> {
26 links: Links,
27 current_time: FakeInstant,
28 pending_frames: BinaryHeap<PendingFrame<CtxId, Spec::RecvMeta>>,
29 contexts: HashMap<CtxId, Spec::Context>,
33}
34
/// A frame in flight, waiting to be delivered to `dst_context`.
#[derive(Debug)]
pub struct PendingFrameData<CtxId, Meta> {
    /// The context the frame will be delivered to.
    pub dst_context: CtxId,
    /// Receive-side metadata handed to `FakeNetworkSpec::handle_frame` on
    /// delivery.
    pub meta: Meta,
    /// The raw frame bytes.
    pub frame: Vec<u8>,
}

/// A [`PendingFrameData`] tagged with the instant at which it becomes
/// deliverable.
pub type PendingFrame<CtxId, Meta> = InstantAndData<PendingFrameData<CtxId, Meta>>;
48
49pub trait FakeNetworkSpec: Sized {
51 type Context: WithFakeTimerContext<Self::TimerId>;
53 type TimerId: Clone;
55 type SendMeta;
58 type RecvMeta;
61
62 fn handle_frame(ctx: &mut Self::Context, recv: Self::RecvMeta, data: Buf<Vec<u8>>);
64 fn handle_timer(ctx: &mut Self::Context, dispatch: Self::TimerId, timer: FakeTimerId);
70 fn process_queues(ctx: &mut Self::Context) -> bool;
76
77 fn fake_frames(ctx: &mut Self::Context) -> &mut impl WithFakeFrameContext<Self::SendMeta>;
79
80 fn new_network<CtxId, Links, I>(contexts: I, links: Links) -> FakeNetwork<Self, CtxId, Links>
82 where
83 CtxId: Eq + Hash + Copy + Debug,
84 I: IntoIterator<Item = (CtxId, Self::Context)>,
85 Links: FakeNetworkLinks<Self::SendMeta, Self::RecvMeta, CtxId>,
86 {
87 FakeNetwork::new(contexts, links)
88 }
89}
90
/// Decides where the frames sent on a [`FakeNetwork`] go.
pub trait FakeNetworkLinks<SendMeta, RecvMeta, CtxId> {
    /// Maps a frame sent by context `ctx` with metadata `meta` to its
    /// destinations.
    ///
    /// Each returned entry is `(destination, receive metadata, latency)`. A
    /// latency of `None` makes the frame deliverable at the network's current
    /// time; an empty vector means the frame is not delivered anywhere.
    fn map_link(&self, ctx: CtxId, meta: SendMeta) -> Vec<(CtxId, RecvMeta, Option<Duration>)>;
}
103
104impl<SendMeta, RecvMeta, CtxId, F: Fn(CtxId, SendMeta) -> Vec<(CtxId, RecvMeta, Option<Duration>)>>
105 FakeNetworkLinks<SendMeta, RecvMeta, CtxId> for F
106{
107 fn map_link(&self, ctx: CtxId, meta: SendMeta) -> Vec<(CtxId, RecvMeta, Option<Duration>)> {
108 (self)(ctx, meta)
109 }
110}
111
/// Summary of the work performed by a single step of a `FakeNetwork`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct StepResult {
    /// The number of timers fired during the step.
    pub timers_fired: usize,
    /// The number of in-flight frames taken off the network during the step,
    /// including frames that a filter chose to drop.
    pub frames_sent: usize,
    /// The number of contexts whose queue processing reported work.
    pub contexts_with_queued_frames: usize,
}

impl StepResult {
    /// Returns a result in which no work has been recorded.
    fn new_idle() -> Self {
        Self::default()
    }

    /// Returns `true` if the step did nothing at all: no timers fired, no
    /// frames were delivered and no context reported queued work.
    pub fn is_idle(&self) -> bool {
        // Exhaustive destructuring: adding a field forces this check to be
        // revisited instead of silently ignoring the new counter.
        let Self { timers_fired, frames_sent, contexts_with_queued_frames } = *self;
        timers_fired == 0 && frames_sent == 0 && contexts_with_queued_frames == 0
    }
}
135
136impl<Spec, CtxId, Links> FakeNetwork<Spec, CtxId, Links>
137where
138 CtxId: Eq + Hash,
139 Spec: FakeNetworkSpec,
140{
141 pub fn context<K: Into<CtxId>>(&mut self, context: K) -> &mut Spec::Context {
143 self.contexts.get_mut(&context.into()).unwrap()
144 }
145
146 pub fn with_context<K: Into<CtxId>, O, F: FnOnce(&mut Spec::Context) -> O>(
148 &mut self,
149 context: K,
150 f: F,
151 ) -> O {
152 f(self.context(context))
153 }
154}
155
156impl<Spec, CtxId, Links> FakeNetwork<Spec, CtxId, Links>
157where
158 Spec: FakeNetworkSpec,
159 CtxId: Eq + Hash + Copy + Debug,
160 Links: FakeNetworkLinks<Spec::SendMeta, Spec::RecvMeta, CtxId>,
161{
162 pub fn new<I: IntoIterator<Item = (CtxId, Spec::Context)>>(contexts: I, links: Links) -> Self {
175 let mut contexts = contexts.into_iter().collect::<HashMap<_, _>>();
176 let latest_time = contexts
186 .iter()
187 .map(|(_, ctx)| ctx.with_fake_timer_ctx(|ctx| ctx.instant.time))
188 .max()
189 .unwrap_or_else(FakeInstant::default);
193
194 assert!(
195 !contexts
196 .iter()
197 .any(|(_, ctx)| { !ctx.with_fake_timer_ctx(|ctx| ctx.timers.is_empty()) }),
198 "can't start network with contexts that already have timers set"
199 );
200
201 for (_, ctx) in contexts.iter_mut() {
204 ctx.with_fake_timer_ctx_mut(|ctx| ctx.instant.time = latest_time);
205 }
206
207 Self { contexts, current_time: latest_time, pending_frames: BinaryHeap::new(), links }
208 }
209
210 pub fn iter_pending_frames(
212 &self,
213 ) -> impl Iterator<Item = &PendingFrame<CtxId, Spec::RecvMeta>> {
214 self.pending_frames.iter()
215 }
216
217 #[track_caller]
219 pub fn assert_no_pending_frames(&self)
220 where
221 Spec::RecvMeta: Debug,
222 {
223 assert!(self.pending_frames.is_empty(), "pending frames: {:?}", self.pending_frames);
224 }
225
226 pub fn drop_pending_frames(&mut self) {
228 self.pending_frames.clear();
229 }
230
231 pub fn step(&mut self) -> StepResult
264 where
265 Spec::TimerId: Debug,
266 {
267 self.step_with(|_, meta, buf| Some((meta, buf)))
268 }
269
270 pub fn step_with<
274 F: FnMut(
275 &mut Spec::Context,
276 Spec::RecvMeta,
277 Buf<Vec<u8>>,
278 ) -> Option<(Spec::RecvMeta, Buf<Vec<u8>>)>,
279 >(
280 &mut self,
281 filter_map_frame: F,
282 ) -> StepResult
283 where
284 Spec::TimerId: Debug,
285 {
286 let mut ret = StepResult::new_idle();
287 for (_, ctx) in self.contexts.iter_mut() {
290 if Spec::process_queues(ctx) {
291 ret.contexts_with_queued_frames += 1;
292 }
293 }
294
295 self.collect_frames();
296
297 let next_step = if let Some(t) = self.next_step() {
298 t
299 } else {
300 return ret;
301 };
302
303 assert!(next_step >= self.current_time);
306
307 self.current_time = next_step;
309 for (_, ctx) in self.contexts.iter_mut() {
310 ctx.with_fake_timer_ctx_mut(|ctx| ctx.instant.time = next_step);
311 }
312
313 ret.frames_sent = self.dispatch_pending_frames(filter_map_frame);
314
315 for (_, ctx) in self.contexts.iter_mut() {
317 let mut timers = Vec::<(Spec::TimerId, FakeTimerId)>::new();
321 ctx.with_fake_timer_ctx_mut(|ctx| {
322 while let Some(InstantAndData(t, timer)) = ctx.timers.peek()
323 && *t <= ctx.now()
324 {
325 timers.push((timer.dispatch_id.clone(), timer.timer_id()));
326 assert_ne!(ctx.timers.pop(), None);
327 }
328 });
329
330 for (dispatch_id, timer_id) in timers {
331 Spec::handle_timer(ctx, dispatch_id, timer_id);
332 ret.timers_fired += 1;
333 }
334 }
335 ret
336 }
337
338 pub fn step_deliver_frames(&mut self) -> StepResult
345 where
346 Spec::TimerId: Debug,
347 {
348 self.step_deliver_frames_with(|_, meta, frame| Some((meta, frame)))
349 }
350
351 pub fn step_deliver_frames_with<
355 F: FnMut(
356 &mut Spec::Context,
357 Spec::RecvMeta,
358 Buf<Vec<u8>>,
359 ) -> Option<(Spec::RecvMeta, Buf<Vec<u8>>)>,
360 >(
361 &mut self,
362 filter_map_frame: F,
363 ) -> StepResult
364 where
365 Spec::TimerId: Debug,
366 {
367 let mut ret = StepResult::new_idle();
368 for (_, ctx) in self.contexts.iter_mut() {
370 if Spec::process_queues(ctx) {
371 ret.contexts_with_queued_frames += 1;
372 }
373 }
374
375 self.collect_frames();
376 ret.frames_sent = self.dispatch_pending_frames(filter_map_frame);
377
378 ret
379 }
380
381 pub fn run_until_idle(&mut self)
388 where
389 Spec::TimerId: Debug,
390 {
391 self.run_until_idle_with(|_, meta, frame| Some((meta, frame)))
392 }
393
394 pub fn run_until_idle_with<
398 F: FnMut(
399 &mut Spec::Context,
400 Spec::RecvMeta,
401 Buf<Vec<u8>>,
402 ) -> Option<(Spec::RecvMeta, Buf<Vec<u8>>)>,
403 >(
404 &mut self,
405 mut filter_map_frame: F,
406 ) where
407 Spec::TimerId: Debug,
408 {
409 for _ in 0..1_000_000 {
410 if self.step_with(&mut filter_map_frame).is_idle() {
411 return;
412 }
413 }
414 panic!("FakeNetwork seems to have gotten stuck in a loop.");
415 }
416
417 pub fn collect_frames(&mut self) {
425 let all_frames = self.contexts.iter_mut().filter_map(|(n, ctx)| {
426 Spec::fake_frames(ctx).with_fake_frame_ctx_mut(|ctx| {
427 let frames = ctx.take_frames();
428 if frames.is_empty() { None } else { Some((n.clone(), frames)) }
429 })
430 });
431
432 for (src_context, frames) in all_frames {
433 for (send_meta, frame) in frames.into_iter() {
434 for (dst_context, recv_meta, latency) in self.links.map_link(src_context, send_meta)
435 {
436 self.pending_frames.push(PendingFrame::new(
437 self.current_time + latency.unwrap_or(Duration::from_millis(0)),
438 PendingFrameData { frame: frame.clone(), dst_context, meta: recv_meta },
439 ));
440 }
441 }
442 }
443 }
444
445 pub fn dispatch_pending_frames<
456 F: FnMut(
457 &mut Spec::Context,
458 Spec::RecvMeta,
459 Buf<Vec<u8>>,
460 ) -> Option<(Spec::RecvMeta, Buf<Vec<u8>>)>,
461 >(
462 &mut self,
463 mut filter_map_frame: F,
464 ) -> usize {
465 let mut frames_sent = 0;
466 while let Some(InstantAndData(t, _)) = self.pending_frames.peek()
467 && *t <= self.current_time
468 {
469 let PendingFrameData { dst_context, meta, frame } =
471 self.pending_frames.pop().unwrap().1;
472 let dst_context = self.context(dst_context);
473 if let Some((meta, frame)) = filter_map_frame(dst_context, meta, Buf::new(frame, ..)) {
474 Spec::handle_frame(dst_context, meta, frame)
475 }
476 frames_sent += 1;
477 }
478
479 frames_sent
480 }
481
482 pub fn next_step(&self) -> Option<FakeInstant> {
488 let next_timer = self
490 .contexts
491 .iter()
492 .filter_map(|(_, ctx)| {
493 ctx.with_fake_timer_ctx(|ctx| match ctx.timers.peek() {
494 Some(tmr) => Some(tmr.0),
495 None => None,
496 })
497 })
498 .min();
499 let next_packet_due = self.pending_frames.peek().map(|t| t.0);
501
502 match next_timer {
505 Some(t) if next_packet_due.is_some() => Some(t).min(next_packet_due),
506 Some(t) => Some(t),
507 None => next_packet_due,
508 }
509 .map(|t| t.max(self.current_time))
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::vec;

    use crate::testutil::{FakeFrameCtx, FakeTimerCtx};
    use crate::{SendFrameContext as _, TimerContext as _};

    /// Minimal context used to exercise the fake network: it counts fired
    /// timers and received frames, and answers every `request()` frame with a
    /// `response()` frame.
    #[derive(Default)]
    struct FakeNetworkTestCtx {
        timer_ctx: FakeTimerCtx<u32>,
        frame_ctx: FakeFrameCtx<()>,
        fired_timers: HashMap<u32, usize>,
        frames_received: usize,
    }

    impl FakeNetworkTestCtx {
        /// Asserts that exactly the given `(timer, fire_count)` pairs were
        /// recorded, and clears them.
        #[track_caller]
        fn drain_and_assert_timers(&mut self, iter: impl IntoIterator<Item = (u32, usize)>) {
            for (timer, fire_count) in iter {
                assert_eq!(self.fired_timers.remove(&timer), Some(fire_count), "for timer {timer}");
            }
            assert!(self.fired_timers.is_empty(), "remaining timers: {:?}", self.fired_timers);
        }

        /// A frame payload that triggers a `response()` from the receiver.
        fn request() -> Vec<u8> {
            vec![1, 2, 3, 4]
        }

        /// The payload sent back upon receiving `request()`.
        fn response() -> Vec<u8> {
            vec![4, 3, 2, 1]
        }
    }

    impl FakeNetworkSpec for FakeNetworkTestCtx {
        type Context = Self;
        type TimerId = u32;
        type SendMeta = ();
        type RecvMeta = ();

        fn handle_frame(ctx: &mut Self, _recv: (), data: Buf<Vec<u8>>) {
            ctx.frames_received += 1;
            if data.into_inner() == Self::request() {
                ctx.frame_ctx.push((), Self::response())
            }
        }

        fn handle_timer(ctx: &mut Self, dispatch: u32, _: FakeTimerId) {
            *ctx.fired_timers.entry(dispatch).or_insert(0) += 1;
        }

        fn process_queues(_ctx: &mut Self) -> bool {
            false
        }

        fn fake_frames(ctx: &mut Self) -> &mut impl WithFakeFrameContext<Self::SendMeta> {
            ctx
        }
    }

    impl WithFakeFrameContext<()> for FakeNetworkTestCtx {
        fn with_fake_frame_ctx_mut<O, F: FnOnce(&mut FakeFrameCtx<()>) -> O>(&mut self, f: F) -> O {
            f(&mut self.frame_ctx)
        }
    }

    impl WithFakeTimerContext<u32> for FakeNetworkTestCtx {
        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<u32>) -> O>(&self, f: F) -> O {
            f(&self.timer_ctx)
        }

        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<u32>) -> O>(
            &mut self,
            f: F,
        ) -> O {
            f(&mut self.timer_ctx)
        }
    }

    /// Builds a two-context network (ids 1 and 2) in which each context's
    /// frames go to the other one with the given `latency`.
    fn new_fake_network_with_latency(
        latency: Option<Duration>,
    ) -> FakeNetwork<FakeNetworkTestCtx, i32, impl FakeNetworkLinks<(), (), i32>> {
        FakeNetwork::new(
            [(1, FakeNetworkTestCtx::default()), (2, FakeNetworkTestCtx::default())],
            move |id, ()| {
                vec![(
                    match id {
                        1 => 2,
                        2 => 1,
                        _ => unreachable!(),
                    },
                    (),
                    latency,
                )]
            },
        )
    }

    #[test]
    fn timers() {
        let mut net = new_fake_network_with_latency(None);

        let (mut t1, mut t4, mut t5) =
            net.with_context(1, |FakeNetworkTestCtx { timer_ctx, .. }| {
                (timer_ctx.new_timer(1), timer_ctx.new_timer(4), timer_ctx.new_timer(5))
            });

        net.with_context(1, |FakeNetworkTestCtx { timer_ctx, .. }| {
            assert_eq!(timer_ctx.schedule_timer(Duration::from_secs(1), &mut t1), None);
            assert_eq!(timer_ctx.schedule_timer(Duration::from_secs(4), &mut t4), None);
            assert_eq!(timer_ctx.schedule_timer(Duration::from_secs(5), &mut t5), None);
        });

        let (mut t2, mut t3, mut t6) =
            net.with_context(2, |FakeNetworkTestCtx { timer_ctx, .. }| {
                (timer_ctx.new_timer(2), timer_ctx.new_timer(3), timer_ctx.new_timer(6))
            });

        net.with_context(2, |FakeNetworkTestCtx { timer_ctx, .. }| {
            assert_eq!(timer_ctx.schedule_timer(Duration::from_secs(2), &mut t2), None);
            assert_eq!(timer_ctx.schedule_timer(Duration::from_secs(3), &mut t3), None);
            assert_eq!(timer_ctx.schedule_timer(Duration::from_secs(5), &mut t6), None);
        });

        net.context(1).drain_and_assert_timers([]);
        net.context(2).drain_and_assert_timers([]);
        assert_eq!(net.step().timers_fired, 1);
        net.context(1).drain_and_assert_timers([(1, 1)]);
        net.context(2).drain_and_assert_timers([]);
        assert_eq!(net.step().timers_fired, 1);
        net.context(1).drain_and_assert_timers([]);
        net.context(2).drain_and_assert_timers([(2, 1)]);
        assert_eq!(net.step().timers_fired, 1);
        net.context(1).drain_and_assert_timers([]);
        net.context(2).drain_and_assert_timers([(3, 1)]);
        assert_eq!(net.step().timers_fired, 1);
        net.context(1).drain_and_assert_timers([(4, 1)]);
        net.context(2).drain_and_assert_timers([]);
        assert_eq!(net.step().timers_fired, 2);
        net.context(1).drain_and_assert_timers([(5, 1)]);
        net.context(2).drain_and_assert_timers([(6, 1)]);

        assert!(net.step().is_idle());
        let t1 = net.with_context(1, |FakeNetworkTestCtx { timer_ctx, .. }| timer_ctx.now());
        let t2 = net.with_context(2, |FakeNetworkTestCtx { timer_ctx, .. }| timer_ctx.now());
        assert_eq!(t1, t2);
    }

    #[test]
    fn until_idle() {
        let mut net = new_fake_network_with_latency(None);

        let mut t1 =
            net.with_context(1, |FakeNetworkTestCtx { timer_ctx, .. }| timer_ctx.new_timer(1));
        net.with_context(1, |FakeNetworkTestCtx { timer_ctx, .. }| {
            assert_eq!(timer_ctx.schedule_timer(Duration::from_secs(1), &mut t1), None);
        });

        let (mut t2, mut t3) = net.with_context(2, |FakeNetworkTestCtx { timer_ctx, .. }| {
            (timer_ctx.new_timer(2), timer_ctx.new_timer(3))
        });
        net.with_context(2, |FakeNetworkTestCtx { timer_ctx, .. }| {
            assert_eq!(timer_ctx.schedule_timer(Duration::from_secs(2), &mut t2), None);
            assert_eq!(timer_ctx.schedule_timer(Duration::from_secs(3), &mut t3), None);
        });

        // Step until both contexts have fired at least one timer. The
        // parentheses matter: without them `&&` binds tighter than `||`, and
        // the loop would spin forever on an idle network if context 2 never
        // fired a timer.
        while !net.step().is_idle()
            && (net.context(1).fired_timers.is_empty() || net.context(2).fired_timers.is_empty())
        {}
        assert_eq!(net.step().timers_fired, 1);
    }

    #[test]
    fn delayed_packets() {
        let mut net = new_fake_network_with_latency(Some(Duration::from_millis(5)));

        let mut t11 =
            net.with_context(1, |FakeNetworkTestCtx { timer_ctx, .. }| timer_ctx.new_timer(1));
        net.with_context(1, |FakeNetworkTestCtx { frame_ctx, timer_ctx, .. }| {
            frame_ctx.push((), FakeNetworkTestCtx::request());
            assert_eq!(timer_ctx.schedule_timer(Duration::from_millis(3), &mut t11), None);
        });
        let (mut t21, mut t22) = net.with_context(2, |FakeNetworkTestCtx { timer_ctx, .. }| {
            (timer_ctx.new_timer(1), timer_ctx.new_timer(2))
        });
        net.with_context(2, |FakeNetworkTestCtx { timer_ctx, .. }| {
            assert_eq!(timer_ctx.schedule_timer(Duration::from_millis(7), &mut t22), None);
            assert_eq!(timer_ctx.schedule_timer(Duration::from_millis(10), &mut t21), None);
        });

        let assert_full_state = |net: &mut FakeNetwork<FakeNetworkTestCtx, _, _>,
                                 ctx1_timers,
                                 ctx2_timers,
                                 ctx2_frames,
                                 ctx1_frames| {
            let ctx1 = net.context(1);
            assert_eq!(ctx1.fired_timers.len(), ctx1_timers);
            assert_eq!(ctx1.frames_received, ctx1_frames);
            let ctx2 = net.context(2);
            assert_eq!(ctx2.fired_timers.len(), ctx2_timers);
            assert_eq!(ctx2.frames_received, ctx2_frames);
        };

        assert_eq!(net.step().timers_fired, 1);
        assert_full_state(&mut net, 1, 0, 0, 0);
        assert_eq!(net.step().frames_sent, 1);
        assert_full_state(&mut net, 1, 0, 1, 0);
        assert_eq!(net.step().timers_fired, 1);
        assert_full_state(&mut net, 1, 1, 1, 0);
        let step = net.step();
        assert_eq!(step.frames_sent, 1);
        assert_eq!(step.timers_fired, 1);
        assert_full_state(&mut net, 1, 2, 1, 1);

        assert!(net.step().is_idle());
    }

    #[test]
    fn fake_network_transmits_packets() {
        let mut net = new_fake_network_with_latency(None);

        net.with_context(1, |FakeNetworkTestCtx { frame_ctx, .. }| {
            frame_ctx.send_frame(&mut (), (), Buf::new(FakeNetworkTestCtx::request(), ..)).unwrap();
        });

        assert_eq!(net.step().frames_sent, 1);
        assert_eq!(net.step().frames_sent, 1);
        assert!(net.step().is_idle());
    }

    #[test]
    fn send_to_many() {
        let mut net = FakeNetworkTestCtx::new_network(
            [
                (1, FakeNetworkTestCtx::default()),
                (2, FakeNetworkTestCtx::default()),
                (3, FakeNetworkTestCtx::default()),
            ],
            |id, ()| match id {
                1 => vec![(2, (), None), (3, (), None)],
                2 => vec![(1, (), None)],
                3 => vec![],
                _ => unreachable!(),
            },
        );
        net.assert_no_pending_frames();

        net.with_context(1, |FakeNetworkTestCtx { frame_ctx, .. }| {
            frame_ctx.send_frame(&mut (), (), Buf::new(vec![], ..)).unwrap();
        });
        net.collect_frames();
        assert_eq!(net.iter_pending_frames().count(), 2);
        assert!(net.iter_pending_frames().any(|InstantAndData(_, x)| x.dst_context == 2));
        assert!(net.iter_pending_frames().any(|InstantAndData(_, x)| x.dst_context == 3));
        net.drop_pending_frames();

        net.with_context(2, |FakeNetworkTestCtx { frame_ctx, .. }| {
            frame_ctx.send_frame(&mut (), (), Buf::new(vec![], ..)).unwrap();
        });
        net.collect_frames();
        assert_eq!(net.iter_pending_frames().count(), 1);
        assert!(net.iter_pending_frames().any(|InstantAndData(_, x)| x.dst_context == 1));
        net.drop_pending_frames();

        net.with_context(3, |FakeNetworkTestCtx { frame_ctx, .. }| {
            frame_ctx.send_frame(&mut (), (), Buf::new(vec![], ..)).unwrap();
        });
        net.collect_frames();
        net.assert_no_pending_frames();

        for i in 1..=3 {
            assert_eq!(net.context(i).frames_received, 0, "context: {i}");
        }
    }
}