1use diagnostics_traits::InspectableInstant;
10use fuchsia_async as fasync;
11use rand::Rng;
12
/// Provides access to a source of randomness.
///
/// Abstracting the generator behind this trait lets callers substitute a
/// deterministic, seeded generator (for example a `StdRng` seeded with a fixed value).
pub trait RngProvider {
    /// The random number generator type handed out by `get_rng`. It may be
    /// unsized (for example a trait object).
    type RNG: Rng + ?Sized;

    /// Returns a mutable reference to the underlying random number generator.
    fn get_rng(&mut self) -> &mut Self::RNG;
}
21
22impl RngProvider for rand::rngs::StdRng {
23 type RNG = Self;
24 fn get_rng(&mut self) -> &mut Self::RNG {
25 self
26 }
27}
28
/// Metadata describing a datagram returned by [`Socket::recv_from`].
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct DatagramInfo<T> {
    /// Number of bytes of the datagram that were written into the caller's buffer.
    pub length: usize,
    /// Address associated with the datagram (presumably the sender's address for
    /// real sockets; the in-memory test fake reports the address given to `send_to`).
    pub address: T,
}
38
/// Errors returned when opening, binding, or using a socket.
#[derive(thiserror::Error, Debug)]
pub enum SocketError {
    /// The socket could not be opened; wraps the underlying cause.
    #[error("failed to open socket: {0}")]
    FailedToOpen(anyhow::Error),
    /// Binding was attempted on an interface that does not exist.
    #[error("tried to bind socket on nonexistent interface")]
    NoInterface,
    /// An unsupported hardware type was encountered (presumably the interface's
    /// link type — confirm against callers).
    #[error("unsupported hardware type")]
    UnsupportedHardwareType,
    /// The destination host is unreachable.
    #[error("host unreachable")]
    HostUnreachable,
    /// The destination network is unreachable.
    #[error("network unreachable")]
    NetworkUnreachable,
    /// Any other I/O error reported by the socket.
    #[error("socket error: {0}")]
    Other(std::io::Error),
}
61
/// A datagram socket whose peer addresses have type `T`.
// The `async_fn_in_trait` lint warns that callers of a public trait's `async fn` cannot
// require the returned future to be `Send`. Presumably acceptable here because these
// futures are driven on a single-threaded executor — TODO confirm.
#[allow(async_fn_in_trait)]
pub trait Socket<T> {
    /// Sends the contents of `buf` as a single datagram to `addr`.
    async fn send_to(&self, buf: &[u8], addr: T) -> Result<(), SocketError>;

    /// Receives a single datagram into `buf`, returning its length and address.
    async fn recv_from(&self, buf: &mut [u8]) -> Result<DatagramInfo<T>, SocketError>;
}
76
/// Provides packet sockets whose peers are addressed by Ethernet MAC address.
// The `async_fn_in_trait` lint warns that callers of a public trait's `async fn` cannot
// require the returned future to be `Send`. Presumably acceptable here because these
// futures are driven on a single-threaded executor — TODO confirm.
#[allow(async_fn_in_trait)]
pub trait PacketSocketProvider {
    /// The packet socket type produced by this provider.
    type Sock: Socket<net_types::ethernet::Mac>;

    /// Returns a packet socket, or the error that prevented obtaining one.
    async fn get_packet_socket(&self) -> Result<Self::Sock, SocketError>;
}
92
/// Provides UDP sockets bound to caller-chosen local addresses.
// The `async_fn_in_trait` lint warns that callers of a public trait's `async fn` cannot
// require the returned future to be `Send`. Presumably acceptable here because these
// futures are driven on a single-threaded executor — TODO confirm.
#[allow(async_fn_in_trait)]
pub trait UdpSocketProvider {
    /// The UDP socket type produced by this provider.
    type Sock: Socket<std::net::SocketAddr>;

    /// Creates a new UDP socket bound to `bound_addr`.
    async fn bind_new_udp_socket(
        &self,
        bound_addr: std::net::SocketAddr,
    ) -> Result<Self::Sock, SocketError>;
}
110
/// A point in time, abstracted so that real code can use the monotonic clock's instant
/// type while tests use a fake instant.
pub trait Instant:
    Sized + Ord + Copy + Clone + std::fmt::Debug + Send + Sync + InspectableInstant
{
    /// Returns the instant `duration` after `self`.
    ///
    /// Implementations may panic on overflow.
    fn add(&self, duration: std::time::Duration) -> Self;

    /// Returns the instant halfway between `self` and `other`, regardless of which
    /// of the two is earlier.
    fn average(&self, other: Self) -> Self;
}
122
123impl Instant for fasync::MonotonicInstant {
124 fn add(&self, duration: std::time::Duration) -> Self {
125 #[allow(clippy::useless_conversion)]
128 {
129 *self + duration.into()
130 }
131 }
132
133 fn average(&self, other: Self) -> Self {
134 let lower = *self.min(&other);
135 let higher = *self.max(&other);
136 lower + (higher - lower) / 2
137 }
138}
139
/// A source of time that can report the current instant and wait for a later one.
// The `async_fn_in_trait` lint warns that callers of a public trait's `async fn` cannot
// require the returned future to be `Send`. Presumably acceptable here because these
// futures are driven on a single-threaded executor — TODO confirm.
#[allow(async_fn_in_trait)]
pub trait Clock {
    /// The instant type reported by this clock.
    type Instant: Instant;

    /// Completes once the clock has reached `time`.
    async fn wait_until(&self, time: Self::Instant);

    /// Returns the clock's current instant.
    fn now(&self) -> Self::Instant;
}
156
#[cfg(test)]
pub(crate) mod testutil {
    //! Deterministic fakes for the dependency traits defined in the parent module.

    use super::*;
    use diagnostics_traits::InstantPropertyName;
    use futures::channel::{mpsc, oneshot};
    use futures::lock::Mutex;
    use futures::StreamExt as _;
    use rand::SeedableRng as _;
    use std::cell::RefCell;
    use std::cmp::Reverse;
    use std::collections::BTreeMap;
    use std::future::Future;
    use std::ops::{Deref as _, DerefMut as _};
    use std::rc::Rc;

    /// An [`RngProvider`] backed by a seeded `StdRng`, so a given seed always yields the
    /// same sequence of random values.
    pub(crate) struct FakeRngProvider {
        /// The seeded generator handed out by `get_rng`.
        std_rng: rand::rngs::StdRng,
    }

    impl FakeRngProvider {
        /// Creates a provider whose generator is seeded with `seed`.
        pub(crate) fn new(seed: u64) -> Self {
            Self { std_rng: rand::rngs::StdRng::seed_from_u64(seed) }
        }
    }

    impl RngProvider for FakeRngProvider {
        type RNG = rand::rngs::StdRng;
        fn get_rng(&mut self) -> &mut Self::RNG {
            &mut self.std_rng
        }
    }

    /// An in-memory [`Socket`] connected to the peer created alongside it by
    /// `FakeSocket::new_pair`.
    ///
    /// Each datagram is delivered together with the address that was passed to `send_to`.
    pub(crate) struct FakeSocket<T> {
        /// Delivers datagrams to the peer's receiver.
        sender: mpsc::UnboundedSender<(Vec<u8>, T)>,
        /// Datagrams sent by the peer. Behind an async mutex because `recv_from` only has
        /// `&self` while `StreamExt::next` needs `&mut` access to the receiver.
        receiver: Mutex<mpsc::UnboundedReceiver<(Vec<u8>, T)>>,
    }

    impl<T> FakeSocket<T> {
        /// Creates two connected sockets: whatever one sends, the other receives.
        pub(crate) fn new_pair() -> (FakeSocket<T>, FakeSocket<T>) {
            let (send_a, recv_a) = mpsc::unbounded();
            let (send_b, recv_b) = mpsc::unbounded();
            (
                FakeSocket { sender: send_a, receiver: Mutex::new(recv_b) },
                FakeSocket { sender: send_b, receiver: Mutex::new(recv_a) },
            )
        }
    }

    impl<T: Send> Socket<T> for FakeSocket<T> {
        /// Queues a copy of `buf`, tagged with `addr`, for the peer to receive.
        ///
        /// # Panics
        ///
        /// Panics if the peer's receiver is gone.
        async fn send_to(&self, buf: &[u8], addr: T) -> Result<(), SocketError> {
            let FakeSocket { sender, receiver: _ } = self;
            // `unbounded_send` takes `&self`, so no clone of the sender is needed.
            sender.unbounded_send((buf.to_vec(), addr)).expect("unbounded_send error");
            Ok(())
        }

        /// Waits for the next datagram from the peer and copies it into `buf`.
        ///
        /// # Panics
        ///
        /// Panics if the peer has been dropped and no datagrams remain, or if `buf` is
        /// too small to hold the whole datagram.
        async fn recv_from(&self, buf: &mut [u8]) -> Result<DatagramInfo<T>, SocketError> {
            let FakeSocket { receiver, sender: _ } = self;
            let mut receiver = receiver.lock().await;
            let (bytes, addr) = receiver.next().await.expect("TestSocket receiver closed");
            let length = bytes.len();
            assert!(buf.len() >= length, "TestSocket receiver would produce short read");
            buf[..length].copy_from_slice(&bytes);
            Ok(DatagramInfo { length, address: addr })
        }
    }

    /// Lets shared handles to a fake socket (such as `Rc<FakeSocket<U>>`) be used
    /// directly as sockets by forwarding to the underlying `FakeSocket`.
    impl<T, U> Socket<U> for T
    where
        T: AsRef<FakeSocket<U>>,
        U: Send + 'static,
    {
        async fn send_to(&self, buf: &[u8], addr: U) -> Result<(), SocketError> {
            self.as_ref().send_to(buf, addr).await
        }

        async fn recv_from(&self, buf: &mut [u8]) -> Result<DatagramInfo<U>, SocketError> {
            self.as_ref().recv_from(buf).await
        }
    }

    /// A socket provider that hands out handles to a single shared [`FakeSocket`] and
    /// can optionally report every socket request on an event channel.
    pub(crate) struct FakeSocketProvider<T, E> {
        /// The socket whose `Rc` handle is returned for every request.
        pub(crate) socket: Rc<FakeSocket<T>>,

        /// When set, receives one event per socket request.
        pub(crate) bound_events: Option<mpsc::UnboundedSender<E>>,
    }

    impl<T, E> FakeSocketProvider<T, E> {
        /// Creates a provider that does not report socket requests.
        pub(crate) fn new(socket: FakeSocket<T>) -> Self {
            Self { socket: Rc::new(socket), bound_events: None }
        }

        /// Creates a provider that reports each socket request on `bound_events`.
        pub(crate) fn new_with_events(
            socket: FakeSocket<T>,
            bound_events: mpsc::UnboundedSender<E>,
        ) -> Self {
            Self { socket: Rc::new(socket), bound_events: Some(bound_events) }
        }
    }

    /// Each `get_packet_socket` call sends `()` on the event channel (if any) and
    /// returns the shared socket.
    impl PacketSocketProvider for FakeSocketProvider<net_types::ethernet::Mac, ()> {
        type Sock = Rc<FakeSocket<net_types::ethernet::Mac>>;
        async fn get_packet_socket(&self) -> Result<Self::Sock, SocketError> {
            let Self { socket, bound_events } = self;
            if let Some(bound_events) = bound_events {
                bound_events.unbounded_send(()).expect("events receiver should not be dropped");
            }
            Ok(Rc::clone(socket))
        }
    }

    /// Each `bind_new_udp_socket` call sends the requested bound address on the event
    /// channel (if any) and returns the shared socket.
    impl UdpSocketProvider for FakeSocketProvider<std::net::SocketAddr, std::net::SocketAddr> {
        type Sock = Rc<FakeSocket<std::net::SocketAddr>>;
        async fn bind_new_udp_socket(
            &self,
            bound_addr: std::net::SocketAddr,
        ) -> Result<Self::Sock, SocketError> {
            let Self { socket, bound_events } = self;
            if let Some(bound_events) = bound_events {
                bound_events
                    .unbounded_send(bound_addr)
                    .expect("events receiver should not be dropped");
            }
            Ok(Rc::clone(socket))
        }
    }

    /// A fake instant, expressed as the elapsed time since the fake clock's zero.
    #[derive(Copy, Clone, PartialEq, Eq, Debug, PartialOrd, Ord)]
    pub(crate) struct TestInstant(pub(crate) std::time::Duration);

    impl InspectableInstant for TestInstant {
        /// Records the instant's `Debug` representation under `name`.
        fn record<I: diagnostics_traits::Inspector>(
            &self,
            name: InstantPropertyName,
            inspector: &mut I,
        ) {
            inspector.record_debug(name.into(), self);
        }
    }

    impl Instant for TestInstant {
        /// # Panics
        ///
        /// Panics if the result overflows `Duration`.
        fn add(&self, duration: std::time::Duration) -> Self {
            Self(self.0.checked_add(duration).expect("TestInstant overflowed"))
        }

        fn average(&self, other: Self) -> Self {
            let lower = self.0.min(other.0);
            let higher = self.0.max(other.0);
            Self(lower + (higher - lower) / 2)
        }
    }

    /// A manually driven fake clock. Within this module, time moves only in `advance`
    /// and `run_until_next_timers_fire`.
    pub(crate) struct FakeTimeController {
        /// Pending timers keyed by `Reverse(deadline)`, so the earliest deadline is the
        /// *last* map entry. Each deadline maps to the senders of all its waiters.
        pub(super) timer_heap:
            BTreeMap<std::cmp::Reverse<std::time::Duration>, Vec<oneshot::Sender<()>>>,
        /// The current fake time, measured from zero.
        pub(super) current_time: std::time::Duration,
    }

    impl FakeTimeController {
        /// Creates a shared controller at time zero with no pending timers.
        pub(crate) fn new() -> Rc<RefCell<FakeTimeController>> {
            Rc::new(RefCell::new(FakeTimeController {
                timer_heap: BTreeMap::default(),
                current_time: std::time::Duration::default(),
            }))
        }
    }

    /// Signals every timer sender in `senders`.
    ///
    /// A send fails only when the matching `wait_until` future has already been dropped.
    /// Nobody is waiting in that case, so the failure is deliberately ignored.
    fn fire_timers(senders: impl IntoIterator<Item = oneshot::Sender<()>>) {
        for sender in senders {
            let _: Result<(), ()> = sender.send(());
        }
    }

    /// Advances fake time by `duration` and fires every timer whose deadline is at or
    /// before the new current time.
    pub(crate) fn advance(ctl: &Rc<RefCell<FakeTimeController>>, duration: std::time::Duration) {
        let timers_to_fire = {
            let mut ctl = ctl.borrow_mut();
            let FakeTimeController { timer_heap, current_time } = ctl.deref_mut();
            *current_time += duration;
            // Keys are `Reverse(deadline)`, so the entries ordered at or after
            // `Reverse(now)` are exactly those with `deadline <= now`.
            timer_heap.split_off(&Reverse(*current_time))
        };
        // The controller borrow is released before any waiter is signalled.
        fire_timers(timers_to_fire.into_values().flatten());
    }

    /// Runs `main_future` until it stalls.
    ///
    /// If the future is still pending, this jumps fake time to the earliest registered
    /// deadline and fires the timers registered for exactly that deadline. It then runs
    /// `main_future` until it stalls again.
    ///
    /// # Panics
    ///
    /// Panics if `main_future` stalls while no timers are registered.
    pub(crate) fn run_until_next_timers_fire<F>(
        executor: &mut fasync::TestExecutor,
        time: &Rc<RefCell<FakeTimeController>>,
        main_future: &mut F,
    ) -> std::task::Poll<F::Output>
    where
        F: Future + Unpin,
    {
        let poll: std::task::Poll<_> = executor.run_until_stalled(main_future);
        if poll.is_ready() {
            return poll;
        }

        let senders = {
            let mut time = time.borrow_mut();
            let FakeTimeController { timer_heap, current_time } = time.deref_mut();

            // Keys are `Reverse(deadline)`, so the last entry holds the earliest deadline.
            let earliest_entry = timer_heap.last_entry().expect("no timers installed");
            let (Reverse(instant), senders) = earliest_entry.remove_entry();
            *current_time = instant;
            senders
        };
        // The controller borrow is released before any waiter is signalled.
        fire_timers(senders);

        executor.run_until_stalled(main_future)
    }

    impl Clock for Rc<RefCell<FakeTimeController>> {
        type Instant = TestInstant;

        fn now(&self) -> Self::Instant {
            // Reading the time needs only a shared borrow. A `borrow_mut` here would
            // needlessly panic while a caller holds a `borrow()` of the controller.
            let ctl = self.borrow();
            let FakeTimeController { timer_heap: _, current_time } = ctl.deref();
            TestInstant(*current_time)
        }

        /// Returns immediately if `time` is not in the future. Otherwise it registers a
        /// timer and completes once fake time reaches `time`.
        ///
        /// # Panics
        ///
        /// Panics if the timer's sender is dropped without firing.
        async fn wait_until(&self, TestInstant(time): Self::Instant) {
            log::info!("registering timer at {:?}", time);
            let receiver = {
                let mut ctl = self.borrow_mut();
                let FakeTimeController { timer_heap, current_time } = ctl.deref_mut();
                if *current_time >= time {
                    return;
                }
                let (sender, receiver) = oneshot::channel();
                timer_heap.entry(Reverse(time)).or_default().push(sender);
                receiver
            };
            receiver.await.expect("shouldn't be cancelled")
        }
    }
}
427
#[cfg(test)]
mod test {
    use super::testutil::*;
    use super::*;
    use fuchsia_async as fasync;
    use futures::channel::mpsc;
    use futures::{FutureExt, StreamExt};
    use net_declare::std_socket_addr;
    use std::pin::pin;

    #[test]
    fn test_rng() {
        // Draws the first five `u32`s from a provider seeded with `seed`.
        let make_sequence = |seed| {
            let mut rng = FakeRngProvider::new(seed);
            std::iter::from_fn(|| Some(rng.get_rng().gen::<u32>())).take(5).collect::<Vec<_>>()
        };
        assert_eq!(
            make_sequence(42),
            make_sequence(42),
            "should provide identical sequences with same seed"
        );
        assert_ne!(
            make_sequence(42),
            make_sequence(999999),
            "should provide different sequences with different seeds"
        );
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_socket() {
        let (a, b) = FakeSocket::new_pair();
        let to_send = [
            (b"hello".to_vec(), "1.2.3.4:5".to_string()),
            (b"test".to_vec(), "1.2.3.5:5".to_string()),
            (b"socket".to_vec(), "1.2.3.6:5".to_string()),
        ];

        let mut buf = [0u8; 10];
        for (msg, addr) in &to_send {
            a.send_to(msg, addr.clone()).await.unwrap();

            let DatagramInfo { length: n, address: received_addr } =
                b.recv_from(&mut buf).await.unwrap();
            assert_eq!(&received_addr, addr);
            assert_eq!(&buf[..n], msg);
        }

        // The pair is symmetric: repeat the exchange in the other direction.
        let (a, b) = (b, a);
        for (msg, addr) in &to_send {
            a.send_to(msg, addr.clone()).await.unwrap();

            let DatagramInfo { length: n, address: received_addr } =
                b.recv_from(&mut buf).await.unwrap();
            assert_eq!(&received_addr, addr);
            assert_eq!(&buf[..n], msg);
        }
    }

    #[fasync::run_singlethreaded(test)]
    #[should_panic]
    async fn test_socket_panics_on_short_read() {
        let (a, b) = FakeSocket::new_pair();

        let mut buf = [0u8; 10];
        let message = b"this message is way longer than 10 bytes";
        a.send_to(message, "1.2.3.4:5".to_string()).await.unwrap();

        // The receive buffer is too small for the datagram, so this must panic.
        let _: Result<_, _> = b.recv_from(&mut buf).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_fake_udp_socket_provider() {
        let (a, b) = FakeSocket::new_pair();
        let (events_sender, mut events_receiver) = mpsc::unbounded();
        let provider = FakeSocketProvider::new_with_events(b, events_sender);
        const ADDR_1: std::net::SocketAddr = std_socket_addr!("1.1.1.1:11");
        const ADDR_2: std::net::SocketAddr = std_socket_addr!("2.2.2.2:22");
        const ADDR_3: std::net::SocketAddr = std_socket_addr!("3.3.3.3:33");
        let b_1 = provider.bind_new_udp_socket(ADDR_1).await.expect("bind udp socket");
        assert_eq!(
            events_receiver
                .next()
                .now_or_never()
                .expect("should have received bound event")
                .expect("stream should not have ended"),
            ADDR_1
        );

        let b_2 = provider.bind_new_udp_socket(ADDR_2).await.expect("bind udp socket");
        assert_eq!(
            events_receiver
                .next()
                .now_or_never()
                .expect("should have received bound event")
                .expect("stream should not have ended"),
            ADDR_2
        );

        a.send_to(b"hello", ADDR_3).await.unwrap();
        a.send_to(b"world", ADDR_3).await.unwrap();

        // Both handles share one underlying socket, so they drain a single queue in order.
        let mut buf = [0u8; 5];
        let DatagramInfo { length, address } = b_1.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"hello");
        assert_eq!(address, ADDR_3);

        let DatagramInfo { length, address } = b_2.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"world");
        assert_eq!(address, ADDR_3);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_fake_packet_socket_provider() {
        let (a, b) = FakeSocket::new_pair();
        let (events_sender, mut events_receiver) = mpsc::unbounded();
        let provider = FakeSocketProvider::new_with_events(b, events_sender);
        let b_1 = provider.get_packet_socket().await.expect("get packet socket");
        events_receiver
            .next()
            .now_or_never()
            .expect("should have received bound event")
            .expect("stream should not have ended");

        let b_2 = provider.get_packet_socket().await.expect("get packet socket");
        events_receiver
            .next()
            .now_or_never()
            .expect("should have received bound event")
            .expect("stream should not have ended");

        const ADDRESS: net_types::ethernet::Mac = net_declare::net_mac!("01:02:03:04:05:06");

        a.send_to(b"hello", ADDRESS).await.unwrap();

        a.send_to(b"world", ADDRESS).await.unwrap();

        // Both handles share one underlying socket, so they drain a single queue in order.
        let mut buf = [0u8; 5];
        let DatagramInfo { length, address } = b_1.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"hello");
        assert_eq!(address, ADDRESS);

        let DatagramInfo { length, address } = b_2.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"world");
        assert_eq!(address, ADDRESS);
    }

    #[test]
    fn test_time_controller() {
        let time_ctl = FakeTimeController::new();
        assert!(time_ctl.borrow().timer_heap.is_empty());
        assert_eq!(time_ctl.borrow().current_time, std::time::Duration::from_secs(0));
        assert_eq!(time_ctl.now(), TestInstant(std::time::Duration::from_secs(0)));

        let mut timer_registered_before_should_fire_1 =
            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(1))));
        let mut timer_registered_before_should_fire_2 =
            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(1))));

        let mut timer_should_not_fire =
            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(10))));

        {
            // Futures are lazy: poll each one once so that it registers its timer with
            // the controller. A no-op waker suffices because readiness is re-checked below.
            let waker = futures::task::noop_waker();
            let mut context = futures::task::Context::from_waker(&waker);
            assert_eq!(
                timer_registered_before_should_fire_1.poll_unpin(&mut context),
                futures::task::Poll::Pending
            );
            assert_eq!(
                timer_registered_before_should_fire_2.poll_unpin(&mut context),
                futures::task::Poll::Pending
            );
            assert_eq!(
                timer_should_not_fire.poll_unpin(&mut context),
                futures::task::Poll::Pending
            );
        }

        {
            // Read-only inspection: a shared borrow is enough.
            let time_ctl = time_ctl.borrow();
            let entries = time_ctl.timer_heap.iter().collect::<Vec<_>>();
            assert_eq!(entries.len(), 2);

            // `Reverse` keys iterate from the latest deadline to the earliest.
            let (time, senders) = entries[0];
            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(10)));
            assert_eq!(senders.len(), 1);

            let (time, senders) = entries[1];
            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(1)));
            assert_eq!(senders.len(), 2);
        }

        advance(&time_ctl, std::time::Duration::from_secs(1));

        assert_eq!(time_ctl.now(), TestInstant(std::time::Duration::from_secs(1)));
        {
            let time_ctl = time_ctl.borrow();
            let entries = time_ctl.timer_heap.iter().collect::<Vec<_>>();
            assert_eq!(entries.len(), 1);
            let (time, senders) = entries[0];
            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(10)));
            assert_eq!(senders.len(), 1);
        }

        assert_eq!(timer_registered_before_should_fire_1.now_or_never(), Some(()));
        assert_eq!(timer_registered_before_should_fire_2.now_or_never(), Some(()));
        assert_eq!(timer_should_not_fire.now_or_never(), None);

        let timer_set_in_past = time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(0)));
        assert_eq!(timer_set_in_past.now_or_never(), Some(()));

        let timer_set_for_present = time_ctl.wait_until(time_ctl.now());
        assert_eq!(timer_set_for_present.now_or_never(), Some(()));
    }
}