1use diagnostics_traits::InspectableInstant;
10use fuchsia_async as fasync;
11use rand::Rng;
12use std::future::Future;
13
14pub trait RngProvider {
16 type RNG: Rng + ?Sized;
18
19 fn get_rng(&mut self) -> &mut Self::RNG;
21}
22
23impl RngProvider for rand::rngs::StdRng {
24 type RNG = Self;
25 fn get_rng(&mut self) -> &mut Self::RNG {
26 self
27 }
28}
29
/// Metadata describing a datagram returned by `Socket::recv_from`.
///
/// `T` is the address type used by the socket the datagram came from.
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct DatagramInfo<T> {
    /// Length of the datagram, in bytes.
    pub length: usize,
    /// Address reported alongside the datagram.
    pub address: T,
}
39
40#[derive(thiserror::Error, Debug)]
41pub enum SocketError {
43 #[error("failed to open socket: {0}")]
45 FailedToOpen(anyhow::Error),
46 #[error("tried to bind socket on nonexistent interface")]
48 NoInterface,
49 #[error("unsupported hardware type")]
51 UnsupportedHardwareType,
52 #[error("host unreachable")]
54 HostUnreachable,
55 #[error("network unreachable")]
57 NetworkUnreachable,
58 #[error("address not available")]
60 AddrNotAvailable,
61 #[error("socket error: {0}")]
63 Other(std::io::Error),
64}
65
66pub trait Socket<T> {
68 fn send_to(&self, buf: &[u8], addr: T) -> impl Future<Output = Result<(), SocketError>>;
70
71 fn recv_from(
74 &self,
75 buf: &mut [u8],
76 ) -> impl Future<Output = Result<DatagramInfo<T>, SocketError>>;
77}
78
79pub trait PacketSocketProvider {
81 type Sock: Socket<net_types::ethernet::Mac>;
83
84 fn get_packet_socket(&self) -> impl Future<Output = Result<Self::Sock, SocketError>>;
88}
89
90pub trait UdpSocketProvider {
92 type Sock: Socket<std::net::SocketAddr>;
94
95 fn bind_new_udp_socket(
98 &self,
99 bound_addr: std::net::SocketAddr,
100 ) -> impl Future<Output = Result<Self::Sock, SocketError>>;
101}
102
103pub trait Instant:
105 Sized + Ord + Copy + Clone + std::fmt::Debug + Send + Sync + InspectableInstant
106{
107 fn add(&self, duration: std::time::Duration) -> Self;
110
111 fn average(&self, other: Self) -> Self;
113}
114
115impl Instant for fasync::MonotonicInstant {
116 fn add(&self, duration: std::time::Duration) -> Self {
117 #[allow(clippy::useless_conversion)]
120 {
121 *self + duration.into()
122 }
123 }
124
125 fn average(&self, other: Self) -> Self {
126 let lower = *self.min(&other);
127 let higher = *self.max(&other);
128 lower + (higher - lower) / 2
129 }
130}
131
132pub trait Clock {
134 type Instant: Instant;
136
137 fn wait_until(&self, time: Self::Instant) -> impl Future<Output = ()>;
139
140 fn now(&self) -> Self::Instant;
142}
143
144#[cfg(test)]
145pub(crate) mod testutil {
146 use super::*;
147 use diagnostics_traits::InstantPropertyName;
148 use futures::StreamExt as _;
149 use futures::channel::{mpsc, oneshot};
150 use futures::lock::Mutex;
151 use rand::SeedableRng as _;
152 use std::cell::RefCell;
153 use std::cmp::Reverse;
154 use std::collections::BTreeMap;
155 use std::future::Future;
156 use std::ops::{Deref as _, DerefMut as _};
157 use std::rc::Rc;
158
159 pub(crate) struct FakeRngProvider {
161 std_rng: rand::rngs::StdRng,
162 }
163
164 impl FakeRngProvider {
165 pub(crate) fn new(seed: u64) -> Self {
166 Self { std_rng: rand::rngs::StdRng::seed_from_u64(seed) }
167 }
168 }
169
170 impl RngProvider for FakeRngProvider {
171 type RNG = rand::rngs::StdRng;
172 fn get_rng(&mut self) -> &mut Self::RNG {
173 &mut self.std_rng
174 }
175 }
176
177 pub(crate) struct FakeSocket<T> {
183 sender: mpsc::UnboundedSender<(Vec<u8>, T)>,
184 receiver: Mutex<mpsc::UnboundedReceiver<(Vec<u8>, T)>>,
185 }
186
187 impl<T> FakeSocket<T> {
188 pub(crate) fn new_pair() -> (FakeSocket<T>, FakeSocket<T>) {
189 let (send_a, recv_a) = mpsc::unbounded();
190 let (send_b, recv_b) = mpsc::unbounded();
191 (
192 FakeSocket { sender: send_a, receiver: Mutex::new(recv_b) },
193 FakeSocket { sender: send_b, receiver: Mutex::new(recv_a) },
194 )
195 }
196 }
197
198 impl<T: Send> Socket<T> for FakeSocket<T> {
199 async fn send_to(&self, buf: &[u8], addr: T) -> Result<(), SocketError> {
200 let FakeSocket { sender, receiver: _ } = self;
201 sender.clone().unbounded_send((buf.to_vec(), addr)).expect("unbounded_send error");
202 Ok(())
203 }
204
205 async fn recv_from(&self, buf: &mut [u8]) -> Result<DatagramInfo<T>, SocketError> {
206 let FakeSocket { receiver, sender: _ } = self;
207 let mut receiver = receiver.lock().await;
208 let (bytes, addr) = receiver.next().await.expect("TestSocket receiver closed");
209 if buf.len() < bytes.len() {
210 panic!("TestSocket receiver would produce short read")
211 }
212 (buf[..bytes.len()]).copy_from_slice(&bytes);
213 Ok(DatagramInfo { length: bytes.len(), address: addr })
214 }
215 }
216
217 impl<T, U> Socket<U> for T
218 where
219 T: AsRef<FakeSocket<U>>,
220 U: Send + 'static,
221 {
222 async fn send_to(&self, buf: &[u8], addr: U) -> Result<(), SocketError> {
223 self.as_ref().send_to(buf, addr).await
224 }
225
226 async fn recv_from(&self, buf: &mut [u8]) -> Result<DatagramInfo<U>, SocketError> {
227 self.as_ref().recv_from(buf).await
228 }
229 }
230
231 pub(crate) struct FakeSocketProvider<T, E> {
237 pub(crate) socket: Rc<FakeSocket<T>>,
239
240 pub(crate) bound_events: Option<mpsc::UnboundedSender<E>>,
242 }
243
244 impl<T, E> FakeSocketProvider<T, E> {
245 pub(crate) fn new(socket: FakeSocket<T>) -> Self {
246 Self { socket: Rc::new(socket), bound_events: None }
247 }
248
249 pub(crate) fn new_with_events(
250 socket: FakeSocket<T>,
251 bound_events: mpsc::UnboundedSender<E>,
252 ) -> Self {
253 Self { socket: Rc::new(socket), bound_events: Some(bound_events) }
254 }
255 }
256
257 impl PacketSocketProvider for FakeSocketProvider<net_types::ethernet::Mac, ()> {
258 type Sock = Rc<FakeSocket<net_types::ethernet::Mac>>;
259 async fn get_packet_socket(&self) -> Result<Self::Sock, SocketError> {
260 let Self { socket, bound_events } = self;
261 if let Some(bound_events) = bound_events {
262 bound_events.unbounded_send(()).expect("events receiver should not be dropped");
263 }
264 Ok(socket.clone())
265 }
266 }
267
268 impl UdpSocketProvider for FakeSocketProvider<std::net::SocketAddr, std::net::SocketAddr> {
269 type Sock = Rc<FakeSocket<std::net::SocketAddr>>;
270 async fn bind_new_udp_socket(
271 &self,
272 bound_addr: std::net::SocketAddr,
273 ) -> Result<Self::Sock, SocketError> {
274 let Self { socket, bound_events } = self;
275 if let Some(bound_events) = bound_events {
276 bound_events
277 .unbounded_send(bound_addr)
278 .expect("events receiver should not be dropped");
279 }
280 Ok(socket.clone())
281 }
282 }
283
/// Instant used by the fake clock: a duration wrapped in a newtype, ordered
/// and compared by that duration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct TestInstant(pub(crate) std::time::Duration);
286
287 impl InspectableInstant for TestInstant {
288 fn record<I: diagnostics_traits::Inspector>(
289 &self,
290 name: InstantPropertyName,
291 inspector: &mut I,
292 ) {
293 inspector.record_debug(name.into(), self);
294 }
295 }
296
297 impl Instant for TestInstant {
298 fn add(&self, duration: std::time::Duration) -> Self {
299 Self(self.0.checked_add(duration).unwrap())
300 }
301
302 fn average(&self, other: Self) -> Self {
303 let lower = self.0.min(other.0);
304 let higher = self.0.max(other.0);
305 Self(lower + (higher - lower) / 2)
306 }
307 }
308
309 pub(crate) struct FakeTimeController {
312 pub(super) timer_heap:
313 BTreeMap<std::cmp::Reverse<std::time::Duration>, Vec<oneshot::Sender<()>>>,
314 pub(super) current_time: std::time::Duration,
315 }
316
317 impl FakeTimeController {
318 pub(crate) fn new() -> Rc<RefCell<FakeTimeController>> {
319 Rc::new(RefCell::new(FakeTimeController {
320 timer_heap: BTreeMap::default(),
321 current_time: std::time::Duration::default(),
322 }))
323 }
324 }
325
326 pub(crate) fn advance(ctl: &Rc<RefCell<FakeTimeController>>, duration: std::time::Duration) {
329 let timers_to_fire = {
330 let mut ctl = ctl.borrow_mut();
331 let FakeTimeController { timer_heap, current_time } = ctl.deref_mut();
332 let next_time = *current_time + duration;
333 *current_time = next_time;
334 timer_heap.split_off(&std::cmp::Reverse(next_time))
335 };
336 for (_, senders) in timers_to_fire {
337 for sender in senders {
338 match sender.send(()) {
339 Ok(()) => (),
340 Err(()) => {
341 }
344 }
345 }
346 }
347 }
348
349 pub(crate) fn run_until_next_timers_fire<F>(
350 executor: &mut fasync::TestExecutor,
351 time: &Rc<RefCell<FakeTimeController>>,
352 main_future: &mut F,
353 ) -> std::task::Poll<F::Output>
354 where
355 F: Future + Unpin,
356 {
357 let poll: std::task::Poll<_> = executor.run_until_stalled(main_future);
358 if poll.is_ready() {
359 return poll;
360 }
361
362 {
363 let mut time = time.borrow_mut();
364 let FakeTimeController { timer_heap, current_time } = time.deref_mut();
365
366 let earliest_entry = timer_heap.last_entry().expect("no timers installed");
372
373 let (Reverse(instant), senders) = earliest_entry.remove_entry();
374 *current_time = instant;
375 for sender in senders {
376 match sender.send(()) {
377 Ok(()) => (),
378 Err(()) => {
379 }
382 }
383 }
384 }
385
386 executor.run_until_stalled(main_future)
387 }
388
389 impl Clock for Rc<RefCell<FakeTimeController>> {
390 type Instant = TestInstant;
391
392 fn now(&self) -> Self::Instant {
393 let ctl = self.borrow_mut();
394 let FakeTimeController { timer_heap: _, current_time } = ctl.deref();
395 TestInstant(*current_time)
396 }
397
398 async fn wait_until(&self, TestInstant(time): Self::Instant) {
399 log::info!("registering timer at {:?}", time);
400 let receiver = {
401 let mut ctl = self.borrow_mut();
402 let FakeTimeController { timer_heap, current_time } = ctl.deref_mut();
403 if *current_time >= time {
404 return;
405 }
406 let (sender, receiver) = oneshot::channel();
407 timer_heap.entry(std::cmp::Reverse(time)).or_default().push(sender);
408 receiver
409 };
410 receiver.await.expect("shouldn't be cancelled")
411 }
412 }
413}
414
415#[cfg(test)]
416mod test {
417 use super::testutil::*;
418 use super::*;
419 use fuchsia_async as fasync;
420 use futures::channel::mpsc;
421 use futures::{FutureExt, StreamExt};
422 use net_declare::std_socket_addr;
423 use std::pin::pin;
424
/// The fake RNG provider is deterministic per seed and differs across seeds.
#[test]
fn test_rng() {
    /// Draws the first five `u32` values from a provider seeded with `seed`.
    fn sequence_for(seed: u64) -> Vec<u32> {
        let mut provider = FakeRngProvider::new(seed);
        (0..5).map(|_| provider.get_rng().random::<u32>()).collect()
    }

    assert_eq!(
        sequence_for(42),
        sequence_for(42),
        "should provide identical sequences with same seed"
    );
    assert_ne!(
        sequence_for(42),
        sequence_for(999999),
        "should provide different sequences with different seeds"
    );
}
442
443 #[fasync::run_singlethreaded(test)]
444 async fn test_socket() {
445 let (a, b) = FakeSocket::new_pair();
446 let to_send = [
447 (b"hello".to_vec(), "1.2.3.4:5".to_string()),
448 (b"test".to_vec(), "1.2.3.5:5".to_string()),
449 (b"socket".to_vec(), "1.2.3.6:5".to_string()),
450 ];
451
452 let mut buf = [0u8; 10];
453 for (msg, addr) in &to_send {
454 a.send_to(msg, addr.clone()).await.unwrap();
455
456 let DatagramInfo { length: n, address: received_addr } =
457 b.recv_from(&mut buf).await.unwrap();
458 assert_eq!(&received_addr, addr);
459 assert_eq!(&buf[..n], msg);
460 }
461
462 let (a, b) = (b, a);
463 for (msg, addr) in &to_send {
464 a.send_to(msg, addr.clone()).await.unwrap();
465
466 let DatagramInfo { length: n, address: received_addr } =
467 b.recv_from(&mut buf).await.unwrap();
468 assert_eq!(&received_addr, addr);
469 assert_eq!(&buf[..n], msg);
470 }
471 }
472
473 #[fasync::run_singlethreaded(test)]
474 #[should_panic]
475 async fn test_socket_panics_on_short_read() {
476 let (a, b) = FakeSocket::new_pair();
477
478 let mut buf = [0u8; 10];
479 let message = b"this message is way longer than 10 bytes";
480 a.send_to(message, "1.2.3.4:5".to_string()).await.unwrap();
481
482 let _: Result<_, _> = b.recv_from(&mut buf).await;
484 }
485
486 #[fasync::run_singlethreaded(test)]
487 async fn test_fake_udp_socket_provider() {
488 let (a, b) = FakeSocket::new_pair();
489 let (events_sender, mut events_receiver) = mpsc::unbounded();
490 let provider = FakeSocketProvider::new_with_events(b, events_sender);
491 const ADDR_1: std::net::SocketAddr = std_socket_addr!("1.1.1.1:11");
492 const ADDR_2: std::net::SocketAddr = std_socket_addr!("2.2.2.2:22");
493 const ADDR_3: std::net::SocketAddr = std_socket_addr!("3.3.3.3:33");
494 let b_1 = provider.bind_new_udp_socket(ADDR_1).await.expect("get packet socket");
495 assert_eq!(
496 events_receiver
497 .next()
498 .now_or_never()
499 .expect("should have received bound event")
500 .expect("stream should not have ended"),
501 ADDR_1
502 );
503
504 let b_2 = provider.bind_new_udp_socket(ADDR_2).await.expect("get packet socket");
505 assert_eq!(
506 events_receiver
507 .next()
508 .now_or_never()
509 .expect("should have received bound event")
510 .expect("stream should not have ended"),
511 ADDR_2
512 );
513
514 a.send_to(b"hello", ADDR_3).await.unwrap();
515 a.send_to(b"world", ADDR_3).await.unwrap();
516
517 let mut buf = [0u8; 5];
518 let DatagramInfo { length, address } = b_1.recv_from(&mut buf).await.unwrap();
519 assert_eq!(&buf[..length], b"hello");
520 assert_eq!(address, ADDR_3);
521
522 let DatagramInfo { length, address } = b_2.recv_from(&mut buf).await.unwrap();
523 assert_eq!(&buf[..length], b"world");
524 assert_eq!(address, ADDR_3);
525 }
526
527 #[fasync::run_singlethreaded(test)]
528 async fn test_fake_packet_socket_provider() {
529 let (a, b) = FakeSocket::new_pair();
530 let (events_sender, mut events_receiver) = mpsc::unbounded();
531 let provider = FakeSocketProvider::new_with_events(b, events_sender);
532 let b_1 = provider.get_packet_socket().await.expect("get packet socket");
533 events_receiver
534 .next()
535 .now_or_never()
536 .expect("should have received bound event")
537 .expect("stream should not have ended");
538
539 let b_2 = provider.get_packet_socket().await.expect("get packet socket");
540 events_receiver
541 .next()
542 .now_or_never()
543 .expect("should have received bound event")
544 .expect("stream should not have ended");
545
546 const ADDRESS: net_types::ethernet::Mac = net_declare::net_mac!("01:02:03:04:05:06");
547
548 a.send_to(b"hello", ADDRESS).await.unwrap();
549
550 a.send_to(b"world", ADDRESS).await.unwrap();
551
552 let mut buf = [0u8; 5];
553 let DatagramInfo { length, address } = b_1.recv_from(&mut buf).await.unwrap();
554 assert_eq!(&buf[..length], b"hello");
555 assert_eq!(address, ADDRESS);
556
557 let DatagramInfo { length, address } = b_2.recv_from(&mut buf).await.unwrap();
558 assert_eq!(&buf[..length], b"world");
559 assert_eq!(address, ADDRESS);
560 }
561
/// Exercises the fake clock: timers register lazily on first poll, `advance`
/// fires exactly the timers that are due, and waits for past or present
/// instants complete immediately.
#[test]
fn test_time_controller() {
    /// Lists `(deadline, waiter count)` for each heap entry in iteration order.
    fn heap_summary(
        ctl: &std::rc::Rc<std::cell::RefCell<FakeTimeController>>,
    ) -> Vec<(std::time::Duration, usize)> {
        let ctl = ctl.borrow_mut();
        let summary = ctl
            .timer_heap
            .iter()
            .map(|(std::cmp::Reverse(deadline), senders)| (*deadline, senders.len()))
            .collect();
        summary
    }

    let secs = std::time::Duration::from_secs;
    let at = |s: u64| TestInstant(std::time::Duration::from_secs(s));

    let time_ctl = FakeTimeController::new();
    assert!(time_ctl.borrow().timer_heap.is_empty());
    assert_eq!(time_ctl.borrow().current_time, secs(0));
    assert_eq!(time_ctl.now(), at(0));

    let mut timer_registered_before_should_fire_1 = pin!(time_ctl.wait_until(at(1)));
    let mut timer_registered_before_should_fire_2 = pin!(time_ctl.wait_until(at(1)));
    let mut timer_should_not_fire = pin!(time_ctl.wait_until(at(10)));

    {
        // Timers are only registered once their futures are polled.
        let waker = std::task::Waker::noop();
        let mut context = std::task::Context::from_waker(&waker);
        assert_eq!(
            timer_registered_before_should_fire_1.poll_unpin(&mut context),
            std::task::Poll::Pending
        );
        assert_eq!(
            timer_registered_before_should_fire_2.poll_unpin(&mut context),
            std::task::Poll::Pending
        );
        assert_eq!(timer_should_not_fire.poll_unpin(&mut context), std::task::Poll::Pending);
    }

    // Reversed keys: the later deadline iterates first.
    assert_eq!(heap_summary(&time_ctl), vec![(secs(10), 1), (secs(1), 2)]);

    advance(&time_ctl, secs(1));

    assert_eq!(time_ctl.now(), at(1));
    assert_eq!(heap_summary(&time_ctl), vec![(secs(10), 1)]);

    assert_eq!(timer_registered_before_should_fire_1.now_or_never(), Some(()));
    assert_eq!(timer_registered_before_should_fire_2.now_or_never(), Some(()));
    assert_eq!(timer_should_not_fire.now_or_never(), None);

    let timer_set_in_past = time_ctl.wait_until(at(0));
    assert_eq!(timer_set_in_past.now_or_never(), Some(()));

    let timer_set_for_present = time_ctl.wait_until(time_ctl.now());
    assert_eq!(timer_set_for_present.now_or_never(), Some(()));
}
632}