1use diagnostics_traits::InspectableInstant;
10use fuchsia_async as fasync;
11use rand::Rng;
12use std::future::Future;
13
14pub trait RngProvider {
16 type RNG: Rng + ?Sized;
18
19 fn get_rng(&mut self) -> &mut Self::RNG;
21}
22
23impl RngProvider for rand::rngs::StdRng {
24 type RNG = Self;
25 fn get_rng(&mut self) -> &mut Self::RNG {
26 self
27 }
28}
29
/// Metadata returned alongside a datagram received through [`Socket::recv_from`].
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct DatagramInfo<T> {
    /// Length in bytes of the received datagram.
    pub length: usize,
    /// Address reported with the datagram by `recv_from`.
    pub address: T,
}
39
40#[derive(thiserror::Error, Debug)]
41pub enum SocketError {
43 #[error("failed to open socket: {0}")]
45 FailedToOpen(anyhow::Error),
46 #[error("tried to bind socket on nonexistent interface")]
48 NoInterface,
49 #[error("unsupported hardware type")]
51 UnsupportedHardwareType,
52 #[error("host unreachable")]
54 HostUnreachable,
55 #[error("network unreachable")]
57 NetworkUnreachable,
58 #[error("address not available")]
60 AddrNotAvailable,
61 #[error("broken pipe")]
63 BrokenPipe,
64 #[error("connection aborted")]
66 ConnectionAborted,
67 #[error("socket error: {0}")]
69 Other(std::io::Error),
70}
71
72pub trait Socket<T> {
74 fn send_to(&self, buf: &[u8], addr: T) -> impl Future<Output = Result<(), SocketError>>;
76
77 fn recv_from(
80 &self,
81 buf: &mut [u8],
82 ) -> impl Future<Output = Result<DatagramInfo<T>, SocketError>>;
83}
84
85pub trait PacketSocketProvider {
87 type Sock: Socket<net_types::ethernet::Mac>;
89
90 fn get_packet_socket(&self) -> impl Future<Output = Result<Self::Sock, SocketError>>;
94}
95
96pub trait UdpSocketProvider {
98 type Sock: Socket<std::net::SocketAddr>;
100
101 fn bind_new_udp_socket(
104 &self,
105 bound_addr: std::net::SocketAddr,
106 ) -> impl Future<Output = Result<Self::Sock, SocketError>>;
107}
108
109pub trait Instant:
111 Sized + Ord + Copy + Clone + std::fmt::Debug + Send + Sync + InspectableInstant
112{
113 fn add(&self, duration: std::time::Duration) -> Self;
116
117 fn average(&self, other: Self) -> Self;
119}
120
121impl Instant for fasync::MonotonicInstant {
122 fn add(&self, duration: std::time::Duration) -> Self {
123 #[allow(clippy::useless_conversion)]
126 {
127 *self + duration.into()
128 }
129 }
130
131 fn average(&self, other: Self) -> Self {
132 let lower = *self.min(&other);
133 let higher = *self.max(&other);
134 lower + (higher - lower) / 2
135 }
136}
137
138pub trait Clock {
140 type Instant: Instant;
142
143 fn wait_until(&self, time: Self::Instant) -> impl Future<Output = ()>;
145
146 fn now(&self) -> Self::Instant;
148}
149
#[cfg(test)]
pub(crate) mod testutil {
    //! Deterministic in-memory fakes for the dependency traits of the parent module.

    use super::*;
    use diagnostics_traits::InstantPropertyName;
    use futures::StreamExt as _;
    use futures::channel::{mpsc, oneshot};
    use futures::lock::Mutex;
    use rand::SeedableRng as _;
    use std::cell::RefCell;
    use std::cmp::Reverse;
    use std::collections::BTreeMap;
    use std::future::Future;
    use std::ops::{Deref as _, DerefMut as _};
    use std::rc::Rc;

    /// An [`RngProvider`] backed by a seeded [`rand::rngs::StdRng`], so a given seed always
    /// yields the same random sequence.
    pub(crate) struct FakeRngProvider {
        std_rng: rand::rngs::StdRng,
    }

    impl FakeRngProvider {
        /// Creates a provider whose generator is seeded with `seed`.
        pub(crate) fn new(seed: u64) -> Self {
            Self { std_rng: rand::rngs::StdRng::seed_from_u64(seed) }
        }
    }

    impl RngProvider for FakeRngProvider {
        type RNG = rand::rngs::StdRng;
        fn get_rng(&mut self) -> &mut Self::RNG {
            &mut self.std_rng
        }
    }

    /// One end of an in-memory datagram channel; create connected ends with
    /// [`FakeSocket::new_pair`].
    pub(crate) struct FakeSocket<T> {
        /// Delivers datagrams to the peer's `receiver`.
        sender: mpsc::UnboundedSender<(Vec<u8>, T)>,
        /// Yields datagrams sent by the peer. `next()` needs `&mut`, but `recv_from` only
        /// has `&self`, hence the async mutex.
        receiver: Mutex<mpsc::UnboundedReceiver<(Vec<u8>, T)>>,
    }

    impl<T> FakeSocket<T> {
        /// Returns two sockets wired to each other: whatever one sends, the other receives.
        pub(crate) fn new_pair() -> (FakeSocket<T>, FakeSocket<T>) {
            let (send_a, recv_a) = mpsc::unbounded();
            let (send_b, recv_b) = mpsc::unbounded();
            (
                FakeSocket { sender: send_a, receiver: Mutex::new(recv_b) },
                FakeSocket { sender: send_b, receiver: Mutex::new(recv_a) },
            )
        }
    }

    impl<T: Send> Socket<T> for FakeSocket<T> {
        /// Queues `buf` for the peer, tagged with `addr`.
        ///
        /// # Panics
        /// If the peer socket has been dropped.
        async fn send_to(&self, buf: &[u8], addr: T) -> Result<(), SocketError> {
            let FakeSocket { sender, receiver: _ } = self;
            // `unbounded_send` takes `&self`, so the sender does not need to be cloned.
            sender.unbounded_send((buf.to_vec(), addr)).expect("unbounded_send error");
            Ok(())
        }

        /// Waits for the next datagram from the peer and copies it into the front of `buf`.
        ///
        /// # Panics
        /// If the peer socket has been dropped, or if `buf` is too short for the datagram.
        async fn recv_from(&self, buf: &mut [u8]) -> Result<DatagramInfo<T>, SocketError> {
            let FakeSocket { receiver, sender: _ } = self;
            let mut receiver = receiver.lock().await;
            let (bytes, addr) = receiver.next().await.expect("TestSocket receiver closed");
            assert!(buf.len() >= bytes.len(), "TestSocket receiver would produce short read");
            buf[..bytes.len()].copy_from_slice(&bytes);
            Ok(DatagramInfo { length: bytes.len(), address: addr })
        }
    }

    /// Lets any wrapper that is `AsRef<FakeSocket<U>>` (e.g. `Rc<FakeSocket<U>>`) act as a
    /// socket by forwarding to the wrapped [`FakeSocket`].
    impl<T, U> Socket<U> for T
    where
        T: AsRef<FakeSocket<U>>,
        U: Send + 'static,
    {
        async fn send_to(&self, buf: &[u8], addr: U) -> Result<(), SocketError> {
            self.as_ref().send_to(buf, addr).await
        }

        async fn recv_from(&self, buf: &mut [u8]) -> Result<DatagramInfo<U>, SocketError> {
            self.as_ref().recv_from(buf).await
        }
    }

    /// A fake socket provider that hands out clones of one shared [`FakeSocket`], and can
    /// optionally report each request on `bound_events`.
    pub(crate) struct FakeSocketProvider<T, E> {
        /// The socket returned (as a shared `Rc`) from every request.
        pub(crate) socket: Rc<FakeSocket<T>>,

        /// If set, receives one event per request: `()` for packet sockets, or the bound
        /// address for UDP sockets.
        pub(crate) bound_events: Option<mpsc::UnboundedSender<E>>,
    }

    impl<T, E> FakeSocketProvider<T, E> {
        /// Creates a provider that does not report requests.
        pub(crate) fn new(socket: FakeSocket<T>) -> Self {
            Self { socket: Rc::new(socket), bound_events: None }
        }

        /// Creates a provider that reports every request on `bound_events`.
        pub(crate) fn new_with_events(
            socket: FakeSocket<T>,
            bound_events: mpsc::UnboundedSender<E>,
        ) -> Self {
            Self { socket: Rc::new(socket), bound_events: Some(bound_events) }
        }
    }

    impl PacketSocketProvider for FakeSocketProvider<net_types::ethernet::Mac, ()> {
        type Sock = Rc<FakeSocket<net_types::ethernet::Mac>>;
        async fn get_packet_socket(&self) -> Result<Self::Sock, SocketError> {
            let Self { socket, bound_events } = self;
            if let Some(bound_events) = bound_events {
                bound_events.unbounded_send(()).expect("events receiver should not be dropped");
            }
            Ok(socket.clone())
        }
    }

    impl UdpSocketProvider for FakeSocketProvider<std::net::SocketAddr, std::net::SocketAddr> {
        type Sock = Rc<FakeSocket<std::net::SocketAddr>>;
        async fn bind_new_udp_socket(
            &self,
            bound_addr: std::net::SocketAddr,
        ) -> Result<Self::Sock, SocketError> {
            let Self { socket, bound_events } = self;
            if let Some(bound_events) = bound_events {
                bound_events
                    .unbounded_send(bound_addr)
                    .expect("events receiver should not be dropped");
            }
            Ok(socket.clone())
        }
    }

    /// Fake instant: an offset from the fake clock's zero point.
    #[derive(Copy, Clone, PartialEq, Eq, Debug, PartialOrd, Ord)]
    pub(crate) struct TestInstant(pub(crate) std::time::Duration);

    impl InspectableInstant for TestInstant {
        fn record<I: diagnostics_traits::Inspector>(
            &self,
            name: InstantPropertyName,
            inspector: &mut I,
        ) {
            inspector.record_debug(name.into(), self);
        }
    }

    impl Instant for TestInstant {
        fn add(&self, duration: std::time::Duration) -> Self {
            Self(self.0.checked_add(duration).expect("TestInstant::add overflowed"))
        }

        fn average(&self, other: Self) -> Self {
            // Offset from the earlier value by half the gap to avoid overflowing a sum.
            let lower = self.0.min(other.0);
            let higher = self.0.max(other.0);
            Self(lower + (higher - lower) / 2)
        }
    }

    /// Manually driven fake time. It implements [`Clock`] when wrapped in
    /// `Rc<RefCell<_>>`; time is moved forward by [`advance`] and
    /// [`run_until_next_timers_fire`].
    pub(crate) struct FakeTimeController {
        /// Pending timers grouped by deadline. Keys are `Reverse`d, so iteration goes from
        /// the latest deadline to the earliest, and the last entry is the earliest deadline.
        pub(super) timer_heap:
            BTreeMap<Reverse<std::time::Duration>, Vec<oneshot::Sender<()>>>,
        /// The current fake time, as an offset from zero.
        pub(super) current_time: std::time::Duration,
    }

    impl FakeTimeController {
        /// Creates a shared controller at time zero with no pending timers.
        pub(crate) fn new() -> Rc<RefCell<FakeTimeController>> {
            Rc::new(RefCell::new(FakeTimeController {
                timer_heap: BTreeMap::default(),
                current_time: std::time::Duration::default(),
            }))
        }
    }

    /// Signals every timer in `senders`.
    fn fire_timers(senders: impl IntoIterator<Item = oneshot::Sender<()>>) {
        for sender in senders {
            // A send error means the waiting `wait_until` future was dropped, i.e. the
            // timer was cancelled, so there is nothing to wake and the error is ignored.
            let _: Result<(), ()> = sender.send(());
        }
    }

    /// Moves fake time forward by `duration` and fires every timer whose deadline is at or
    /// before the new time.
    ///
    /// # Panics
    /// If the new time overflows `Duration`, or if the controller is already borrowed.
    pub(crate) fn advance(ctl: &Rc<RefCell<FakeTimeController>>, duration: std::time::Duration) {
        let timers_to_fire = {
            let mut ctl = ctl.borrow_mut();
            let FakeTimeController { timer_heap, current_time } = ctl.deref_mut();
            *current_time += duration;
            // Keys are `Reverse(deadline)`, so every entry at or after `Reverse(now)` has a
            // deadline at or before `now`.
            timer_heap.split_off(&Reverse(*current_time))
        };
        // The `RefCell` borrow is released before any timer is signalled.
        fire_timers(timers_to_fire.into_values().flatten());
    }

    /// Runs `main_future` until it stalls. If it is still pending, jumps fake time to the
    /// earliest pending deadline, fires every timer at that deadline, and runs `main_future`
    /// until it stalls again.
    ///
    /// # Panics
    /// If `main_future` is pending and no timers are installed.
    pub(crate) fn run_until_next_timers_fire<F>(
        executor: &mut fasync::TestExecutor,
        time: &Rc<RefCell<FakeTimeController>>,
        main_future: &mut F,
    ) -> std::task::Poll<F::Output>
    where
        F: Future + Unpin,
    {
        let poll: std::task::Poll<_> = executor.run_until_stalled(main_future);
        if poll.is_ready() {
            return poll;
        }

        let senders = {
            let mut ctl = time.borrow_mut();
            let FakeTimeController { timer_heap, current_time } = ctl.deref_mut();
            // Keys are `Reverse(deadline)`, so the last entry holds the earliest deadline.
            let (Reverse(deadline), senders) = timer_heap.pop_last().expect("no timers installed");
            *current_time = deadline;
            senders
        };
        fire_timers(senders);

        executor.run_until_stalled(main_future)
    }

    impl Clock for Rc<RefCell<FakeTimeController>> {
        type Instant = TestInstant;

        fn now(&self) -> Self::Instant {
            // Only reads state, so a shared borrow suffices. `borrow_mut` would panic even
            // when another shared borrow of the controller is outstanding.
            let ctl = self.borrow();
            let FakeTimeController { timer_heap: _, current_time } = ctl.deref();
            TestInstant(*current_time)
        }

        /// Returns immediately if `time` is not in the future; otherwise registers a timer
        /// and waits for it to be fired by [`advance`] or [`run_until_next_timers_fire`].
        async fn wait_until(&self, TestInstant(time): Self::Instant) {
            log::info!("registering timer at {:?}", time);
            let receiver = {
                let mut ctl = self.borrow_mut();
                let FakeTimeController { timer_heap, current_time } = ctl.deref_mut();
                if *current_time >= time {
                    return;
                }
                let (sender, receiver) = oneshot::channel();
                timer_heap.entry(Reverse(time)).or_default().push(sender);
                receiver
            };
            receiver.await.expect("shouldn't be cancelled")
        }
    }
}
420
#[cfg(test)]
mod test {
    use super::testutil::*;
    use super::*;
    use fuchsia_async as fasync;
    use futures::channel::mpsc;
    use futures::{FutureExt, StreamExt};
    use net_declare::std_socket_addr;
    use std::pin::pin;

    /// Pops the next event reported by a `FakeSocketProvider`, asserting that it was
    /// already queued (i.e. reported synchronously by the request).
    fn next_bound_event<E>(events: &mut mpsc::UnboundedReceiver<E>) -> E {
        events
            .next()
            .now_or_never()
            .expect("should have received bound event")
            .expect("stream should not have ended")
    }

    #[test]
    fn test_rng() {
        let make_sequence = |seed| {
            let mut rng = FakeRngProvider::new(seed);
            std::iter::from_fn(|| Some(rng.get_rng().random::<u32>())).take(5).collect::<Vec<_>>()
        };
        assert_eq!(
            make_sequence(42),
            make_sequence(42),
            "should provide identical sequences with same seed"
        );
        assert_ne!(
            make_sequence(42),
            make_sequence(999999),
            "should provide different sequences with different seeds"
        );
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_socket() {
        let (a, b) = FakeSocket::new_pair();
        let to_send = [
            (b"hello".to_vec(), "1.2.3.4:5".to_string()),
            (b"test".to_vec(), "1.2.3.5:5".to_string()),
            (b"socket".to_vec(), "1.2.3.6:5".to_string()),
        ];

        let mut buf = [0u8; 10];
        // Exercise both directions of the pair.
        for (sender, receiver) in [(&a, &b), (&b, &a)] {
            for (msg, addr) in &to_send {
                sender.send_to(msg, addr.clone()).await.unwrap();

                let DatagramInfo { length: n, address: received_addr } =
                    receiver.recv_from(&mut buf).await.unwrap();
                assert_eq!(&received_addr, addr);
                assert_eq!(&buf[..n], msg);
            }
        }
    }

    #[fasync::run_singlethreaded(test)]
    #[should_panic]
    async fn test_socket_panics_on_short_read() {
        let (a, b) = FakeSocket::new_pair();

        let mut buf = [0u8; 10];
        let message = b"this message is way longer than 10 bytes";
        a.send_to(message, "1.2.3.4:5".to_string()).await.unwrap();

        // The receive buffer is too small for the datagram, so this must panic.
        let _: Result<_, _> = b.recv_from(&mut buf).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_fake_udp_socket_provider() {
        let (a, b) = FakeSocket::new_pair();
        let (events_sender, mut events_receiver) = mpsc::unbounded();
        let provider = FakeSocketProvider::new_with_events(b, events_sender);
        const ADDR_1: std::net::SocketAddr = std_socket_addr!("1.1.1.1:11");
        const ADDR_2: std::net::SocketAddr = std_socket_addr!("2.2.2.2:22");
        const ADDR_3: std::net::SocketAddr = std_socket_addr!("3.3.3.3:33");

        let b_1 = provider.bind_new_udp_socket(ADDR_1).await.expect("bind udp socket");
        assert_eq!(next_bound_event(&mut events_receiver), ADDR_1);

        let b_2 = provider.bind_new_udp_socket(ADDR_2).await.expect("bind udp socket");
        assert_eq!(next_bound_event(&mut events_receiver), ADDR_2);

        a.send_to(b"hello", ADDR_3).await.unwrap();
        a.send_to(b"world", ADDR_3).await.unwrap();

        // Both handles share one underlying socket, so they drain the same queue in order.
        let mut buf = [0u8; 5];
        let DatagramInfo { length, address } = b_1.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"hello");
        assert_eq!(address, ADDR_3);

        let DatagramInfo { length, address } = b_2.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"world");
        assert_eq!(address, ADDR_3);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_fake_packet_socket_provider() {
        let (a, b) = FakeSocket::new_pair();
        let (events_sender, mut events_receiver) = mpsc::unbounded();
        let provider = FakeSocketProvider::new_with_events(b, events_sender);

        let b_1 = provider.get_packet_socket().await.expect("get packet socket");
        next_bound_event(&mut events_receiver);

        let b_2 = provider.get_packet_socket().await.expect("get packet socket");
        next_bound_event(&mut events_receiver);

        const ADDRESS: net_types::ethernet::Mac = net_declare::net_mac!("01:02:03:04:05:06");

        a.send_to(b"hello", ADDRESS).await.unwrap();
        a.send_to(b"world", ADDRESS).await.unwrap();

        // Both handles share one underlying socket, so they drain the same queue in order.
        let mut buf = [0u8; 5];
        let DatagramInfo { length, address } = b_1.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"hello");
        assert_eq!(address, ADDRESS);

        let DatagramInfo { length, address } = b_2.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"world");
        assert_eq!(address, ADDRESS);
    }

    #[test]
    fn test_time_controller() {
        let time_ctl = FakeTimeController::new();
        assert!(time_ctl.borrow().timer_heap.is_empty());
        assert_eq!(time_ctl.borrow().current_time, std::time::Duration::from_secs(0));
        assert_eq!(time_ctl.now(), TestInstant(std::time::Duration::from_secs(0)));

        let mut timer_registered_before_should_fire_1 =
            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(1))));
        let mut timer_registered_before_should_fire_2 =
            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(1))));

        let mut timer_should_not_fire =
            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(10))));

        {
            // Futures are lazy: poll each one once so it registers its timer with the
            // controller. None can be ready yet because time has not advanced.
            let waker = std::task::Waker::noop();
            let mut context = futures::task::Context::from_waker(waker);
            assert_eq!(
                timer_registered_before_should_fire_1.poll_unpin(&mut context),
                futures::task::Poll::Pending
            );
            assert_eq!(
                timer_registered_before_should_fire_2.poll_unpin(&mut context),
                futures::task::Poll::Pending
            );
            assert_eq!(
                timer_should_not_fire.poll_unpin(&mut context),
                futures::task::Poll::Pending
            );
        }

        {
            // Read-only inspection, so a shared borrow suffices.
            let time_ctl = time_ctl.borrow();
            let entries = time_ctl.timer_heap.iter().collect::<Vec<_>>();
            assert_eq!(entries.len(), 2);

            let (time, senders) = entries[0];
            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(10)));
            assert_eq!(senders.len(), 1);

            let (time, senders) = entries[1];
            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(1)));
            assert_eq!(senders.len(), 2);
        }

        advance(&time_ctl, std::time::Duration::from_secs(1));

        assert_eq!(time_ctl.now(), TestInstant(std::time::Duration::from_secs(1)));
        {
            let time_ctl = time_ctl.borrow();
            let entries = time_ctl.timer_heap.iter().collect::<Vec<_>>();
            assert_eq!(entries.len(), 1);
            let (time, senders) = entries[0];
            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(10)));
            assert_eq!(senders.len(), 1);
        }

        assert_eq!(timer_registered_before_should_fire_1.now_or_never(), Some(()));
        assert_eq!(timer_registered_before_should_fire_2.now_or_never(), Some(()));
        assert_eq!(timer_should_not_fire.now_or_never(), None);

        let timer_set_in_past = time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(0)));
        assert_eq!(timer_set_in_past.now_or_never(), Some(()));

        let timer_set_for_present = time_ctl.wait_until(time_ctl.now());
        assert_eq!(timer_set_for_present.now_or_never(), Some(()));
    }
}