Skip to main content

dhcp_client_core/
deps.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Defines trait abstractions for platform dependencies of the DHCP client
6//! core, and provides fake implementations of these dependencies for testing
7//! purposes.
8
9use diagnostics_traits::InspectableInstant;
10use fuchsia_async as fasync;
11use rand::Rng;
12use std::future::Future;
13
14/// Provides access to random number generation.
15pub trait RngProvider {
16    /// The random number generator being provided.
17    type RNG: Rng + ?Sized;
18
19    /// Get access to a random number generator.
20    fn get_rng(&mut self) -> &mut Self::RNG;
21}
22
23impl RngProvider for rand::rngs::StdRng {
24    type RNG = Self;
25    fn get_rng(&mut self) -> &mut Self::RNG {
26        self
27    }
28}
29
/// Metadata describing one datagram read from a socket.
///
/// `T` is the socket's address type (for example a MAC address for packet
/// sockets or a `SocketAddr` for UDP sockets).
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct DatagramInfo<T> {
    /// How many bytes of the datagram were received.
    pub length: usize,
    /// The address attached to the received datagram, which is normally the
    /// sender's address.
    pub address: T,
}
39
40#[derive(thiserror::Error, Debug)]
41/// Errors encountered while performing a socket operation.
42pub enum SocketError {
43    /// Failure while attempting to open a socket.
44    #[error("failed to open socket: {0}")]
45    FailedToOpen(anyhow::Error),
46    /// Tried to bind a socket on a nonexistent interface.
47    #[error("tried to bind socket on nonexistent interface")]
48    NoInterface,
49    /// The hardware type of the interface is unsupported.
50    #[error("unsupported hardware type")]
51    UnsupportedHardwareType,
52    /// The host we are attempting to send to is unreachable.
53    #[error("host unreachable")]
54    HostUnreachable,
55    /// The network is unreachable.
56    #[error("network unreachable")]
57    NetworkUnreachable,
58    /// The address is not available.
59    #[error("address not available")]
60    AddrNotAvailable,
61    /// Broken pipe.
62    #[error("broken pipe")]
63    BrokenPipe,
64    /// The connection was aborted.
65    #[error("connection aborted")]
66    ConnectionAborted,
67    /// Other IO errors observed on socket operations.
68    #[error("socket error: {0}")]
69    Other(std::io::Error),
70}
71
72/// Abstracts sending and receiving datagrams on a socket.
73pub trait Socket<T> {
74    /// Sends a datagram containing the contents of `buf` to `addr`.
75    fn send_to(&self, buf: &[u8], addr: T) -> impl Future<Output = Result<(), SocketError>>;
76
77    /// Receives a datagram into `buf`, returning the number of bytes received
78    /// and the address the datagram was received from.
79    fn recv_from(
80        &self,
81        buf: &mut [u8],
82    ) -> impl Future<Output = Result<DatagramInfo<T>, SocketError>>;
83}
84
85/// Provides access to AF_PACKET sockets.
86pub trait PacketSocketProvider {
87    /// The type of sockets provided by this `PacketSocketProvider`.
88    type Sock: Socket<net_types::ethernet::Mac>;
89
90    /// Gets a packet socket bound to the device on which the DHCP client
91    /// protocol is being performed. The packet socket should already be bound
92    /// to the appropriate device and protocol number.
93    fn get_packet_socket(&self) -> impl Future<Output = Result<Self::Sock, SocketError>>;
94}
95
96/// Provides access to UDP sockets.
97pub trait UdpSocketProvider {
98    /// The type of sockets provided by this `UdpSocketProvider`.
99    type Sock: Socket<std::net::SocketAddr>;
100
101    /// Gets a UDP socket bound to the given address. The UDP socket should be
102    /// allowed to send broadcast packets.
103    fn bind_new_udp_socket(
104        &self,
105        bound_addr: std::net::SocketAddr,
106    ) -> impl Future<Output = Result<Self::Sock, SocketError>>;
107}
108
109/// A type representing an instant in time.
110pub trait Instant:
111    Sized + Ord + Copy + Clone + std::fmt::Debug + Send + Sync + InspectableInstant
112{
113    /// Returns the time `self + duration`. Panics if `self + duration` would
114    /// overflow the underlying instant storage type.
115    fn add(&self, duration: std::time::Duration) -> Self;
116
117    /// Returns the instant halfway between `self` and `other`.
118    fn average(&self, other: Self) -> Self;
119}
120
121impl Instant for fasync::MonotonicInstant {
122    fn add(&self, duration: std::time::Duration) -> Self {
123        // On host builds, fasync::MonotonicDuration is simply an alias for
124        // std::time::Duration, making the `duration.into()` appear useless.
125        #[allow(clippy::useless_conversion)]
126        {
127            *self + duration.into()
128        }
129    }
130
131    fn average(&self, other: Self) -> Self {
132        let lower = *self.min(&other);
133        let higher = *self.max(&other);
134        lower + (higher - lower) / 2
135    }
136}
137
138/// Provides access to system-time-related operations.
139pub trait Clock {
140    /// The type representing monotonic system time.
141    type Instant: Instant;
142
143    /// Completes once the monotonic system time is at or after the given time.
144    fn wait_until(&self, time: Self::Instant) -> impl Future<Output = ()>;
145
146    /// Gets the monotonic system time.
147    fn now(&self) -> Self::Instant;
148}
149
#[cfg(test)]
pub(crate) mod testutil {
    use super::*;
    use diagnostics_traits::InstantPropertyName;
    use futures::StreamExt as _;
    use futures::channel::{mpsc, oneshot};
    use futures::lock::Mutex;
    use rand::SeedableRng as _;
    use std::cell::RefCell;
    use std::cmp::Reverse;
    use std::collections::BTreeMap;
    use std::future::Future;
    use std::ops::{Deref as _, DerefMut as _};
    use std::rc::Rc;

    /// Provides a seedable implementation of `RngProvider` using `StdRng`.
    ///
    /// Providers created from the same seed produce identical sequences.
    pub(crate) struct FakeRngProvider {
        std_rng: rand::rngs::StdRng,
    }

    impl FakeRngProvider {
        /// Creates a provider whose generator is seeded with `seed`.
        pub(crate) fn new(seed: u64) -> Self {
            Self { std_rng: rand::rngs::StdRng::seed_from_u64(seed) }
        }
    }

    impl RngProvider for FakeRngProvider {
        type RNG = rand::rngs::StdRng;
        fn get_rng(&mut self) -> &mut Self::RNG {
            &mut self.std_rng
        }
    }

    /// Provides a fake implementation of `Socket` using `mpsc` channels.
    ///
    /// Simply forwards pairs of (payload, address) over the channel. This means
    /// that the "sent to" address from the sender side is actually observed as
    /// the "received from" address on the receiver side.
    pub(crate) struct FakeSocket<T> {
        sender: mpsc::UnboundedSender<(Vec<u8>, T)>,
        // Behind an async `Mutex` because `recv_from` only gets `&self` but
        // needs exclusive access to the receiver while awaiting on it.
        receiver: Mutex<mpsc::UnboundedReceiver<(Vec<u8>, T)>>,
    }

    impl<T> FakeSocket<T> {
        /// Creates two connected sockets. Whatever one sends, the other
        /// receives.
        pub(crate) fn new_pair() -> (FakeSocket<T>, FakeSocket<T>) {
            let (send_a, recv_a) = mpsc::unbounded();
            let (send_b, recv_b) = mpsc::unbounded();
            (
                FakeSocket { sender: send_a, receiver: Mutex::new(recv_b) },
                FakeSocket { sender: send_b, receiver: Mutex::new(recv_a) },
            )
        }
    }

    impl<T: Send> Socket<T> for FakeSocket<T> {
        /// # Panics
        ///
        /// Panics if the peer's receiving end has been dropped.
        async fn send_to(&self, buf: &[u8], addr: T) -> Result<(), SocketError> {
            let FakeSocket { sender, receiver: _ } = self;
            // `unbounded_send` only needs `&self`, so the sender is not cloned.
            sender.unbounded_send((buf.to_vec(), addr)).expect("unbounded_send error");
            Ok(())
        }

        /// # Panics
        ///
        /// Panics if the peer's sending end has been dropped, or if `buf` is
        /// too small to hold the next datagram.
        async fn recv_from(&self, buf: &mut [u8]) -> Result<DatagramInfo<T>, SocketError> {
            let FakeSocket { receiver, sender: _ } = self;
            let mut receiver = receiver.lock().await;
            let (bytes, addr) = receiver.next().await.expect("TestSocket receiver closed");
            assert!(buf.len() >= bytes.len(), "TestSocket receiver would produce short read");
            buf[..bytes.len()].copy_from_slice(&bytes);
            Ok(DatagramInfo { length: bytes.len(), address: addr })
        }
    }

    /// Lets anything that can be viewed as a `FakeSocket` be used directly as a
    /// socket, for example the `Rc<FakeSocket<_>>` vended by
    /// `FakeSocketProvider`.
    impl<T, U> Socket<U> for T
    where
        T: AsRef<FakeSocket<U>>,
        U: Send + 'static,
    {
        async fn send_to(&self, buf: &[u8], addr: U) -> Result<(), SocketError> {
            self.as_ref().send_to(buf, addr).await
        }

        async fn recv_from(&self, buf: &mut [u8]) -> Result<DatagramInfo<U>, SocketError> {
            self.as_ref().recv_from(buf).await
        }
    }

    /// Fake socket provider implementation that vends out copies of
    /// the same `FakeSocket`.
    ///
    /// These copies will compete to receive and send on the same underlying
    /// `mpsc` channels.
    pub(crate) struct FakeSocketProvider<T, E> {
        /// The socket being vended out.
        pub(crate) socket: Rc<FakeSocket<T>>,

        /// If present, used to notify tests when the client binds new sockets.
        pub(crate) bound_events: Option<mpsc::UnboundedSender<E>>,
    }

    impl<T, E> FakeSocketProvider<T, E> {
        /// Creates a provider that does not report bind events.
        pub(crate) fn new(socket: FakeSocket<T>) -> Self {
            Self { socket: Rc::new(socket), bound_events: None }
        }

        /// Creates a provider that reports each bind on `bound_events`.
        pub(crate) fn new_with_events(
            socket: FakeSocket<T>,
            bound_events: mpsc::UnboundedSender<E>,
        ) -> Self {
            Self { socket: Rc::new(socket), bound_events: Some(bound_events) }
        }
    }

    impl PacketSocketProvider for FakeSocketProvider<net_types::ethernet::Mac, ()> {
        type Sock = Rc<FakeSocket<net_types::ethernet::Mac>>;
        async fn get_packet_socket(&self) -> Result<Self::Sock, SocketError> {
            let Self { socket, bound_events } = self;
            if let Some(bound_events) = bound_events {
                bound_events.unbounded_send(()).expect("events receiver should not be dropped");
            }
            Ok(socket.clone())
        }
    }

    impl UdpSocketProvider for FakeSocketProvider<std::net::SocketAddr, std::net::SocketAddr> {
        type Sock = Rc<FakeSocket<std::net::SocketAddr>>;
        async fn bind_new_udp_socket(
            &self,
            bound_addr: std::net::SocketAddr,
        ) -> Result<Self::Sock, SocketError> {
            let Self { socket, bound_events } = self;
            if let Some(bound_events) = bound_events {
                bound_events
                    .unbounded_send(bound_addr)
                    .expect("events receiver should not be dropped");
            }
            Ok(socket.clone())
        }
    }

    /// Instant type of the fake clock: an offset from the moment the
    /// `FakeTimeController` was created (which starts at `Duration::ZERO`).
    #[derive(Copy, Clone, PartialEq, Eq, Debug, PartialOrd, Ord)]
    pub(crate) struct TestInstant(pub(crate) std::time::Duration);

    impl InspectableInstant for TestInstant {
        /// Records the instant using its `Debug` representation.
        fn record<I: diagnostics_traits::Inspector>(
            &self,
            name: InstantPropertyName,
            inspector: &mut I,
        ) {
            inspector.record_debug(name.into(), self);
        }
    }

    impl Instant for TestInstant {
        fn add(&self, duration: std::time::Duration) -> Self {
            Self(self.0.checked_add(duration).unwrap())
        }

        fn average(&self, other: Self) -> Self {
            let lower = self.0.min(other.0);
            let higher = self.0.max(other.0);
            Self(lower + (higher - lower) / 2)
        }
    }

    /// State behind the fake [`Clock`], which is implemented for
    /// `Rc<RefCell<FakeTimeController>>` and uses [`TestInstant`] as its
    /// `Instant` type.
    ///
    /// Fake time moves only via [`advance`] and [`run_until_next_timers_fire`]
    /// (or by writing `current_time` directly).
    pub(crate) struct FakeTimeController {
        /// Pending timers keyed by `Reverse(deadline)`. Iteration therefore
        /// runs from the latest deadline to the earliest, and
        /// `BTreeMap::split_off` yields every deadline at or before a given time.
        pub(super) timer_heap: BTreeMap<Reverse<std::time::Duration>, Vec<oneshot::Sender<()>>>,
        /// The current fake time, as an offset from the controller's creation.
        pub(super) current_time: std::time::Duration,
    }

    impl FakeTimeController {
        /// Creates a controller at time zero with no pending timers.
        pub(crate) fn new() -> Rc<RefCell<FakeTimeController>> {
            Rc::new(RefCell::new(FakeTimeController {
                timer_heap: BTreeMap::default(),
                current_time: std::time::Duration::default(),
            }))
        }
    }

    /// Fires every timer in `senders`.
    ///
    /// A send error only means that the timer's future was dropped. That is
    /// fine, because the client core cancels timers by dropping them.
    fn fire_timers(senders: impl IntoIterator<Item = oneshot::Sender<()>>) {
        for sender in senders {
            let _: Result<(), ()> = sender.send(());
        }
    }

    /// Advances the "current time" encoded by `ctl` by `duration`. Any timers
    /// that were set at or before the resulting "current time" will fire.
    ///
    /// # Panics
    ///
    /// Panics if the new time overflows `Duration`.
    pub(crate) fn advance(ctl: &Rc<RefCell<FakeTimeController>>, duration: std::time::Duration) {
        let timers_to_fire = {
            let mut ctl = ctl.borrow_mut();
            let FakeTimeController { timer_heap, current_time } = ctl.deref_mut();
            let next_time = *current_time + duration;
            *current_time = next_time;
            timer_heap.split_off(&Reverse(next_time))
        };
        // The borrow is released above, before any timer is woken.
        fire_timers(timers_to_fire.into_values().flatten());
    }

    /// Runs `main_future` on `executor` until it stalls. If it is still
    /// pending, jumps fake time to the earliest pending timer deadline, fires
    /// every timer registered for that deadline, and runs `main_future` until
    /// it stalls again.
    ///
    /// Returns the result of the last poll of `main_future`.
    ///
    /// # Panics
    ///
    /// Panics if `main_future` is still pending but no timers are installed.
    pub(crate) fn run_until_next_timers_fire<F>(
        executor: &mut fasync::TestExecutor,
        time: &Rc<RefCell<FakeTimeController>>,
        main_future: &mut F,
    ) -> std::task::Poll<F::Output>
    where
        F: Future + Unpin,
    {
        let poll: std::task::Poll<_> = executor.run_until_stalled(main_future);
        if poll.is_ready() {
            return poll;
        }

        let senders = {
            let mut time = time.borrow_mut();
            let FakeTimeController { timer_heap, current_time } = time.deref_mut();

            // NOTE: the timer heap is ordered by Reverse<Duration> so that
            // `BTreeMap::split_off` has the right edge-case behavior in
            // `advance()`. As a consequence the *earliest* deadline is the
            // map's `last_entry`, not its `first_entry`, which is easy to mix up.
            let earliest_entry = timer_heap.last_entry().expect("no timers installed");

            let (Reverse(instant), senders) = earliest_entry.remove_entry();
            *current_time = instant;
            senders
        };
        fire_timers(senders);

        executor.run_until_stalled(main_future)
    }

    impl Clock for Rc<RefCell<FakeTimeController>> {
        type Instant = TestInstant;

        fn now(&self) -> Self::Instant {
            // A shared borrow is enough to read the time. A mutable borrow
            // would panic needlessly if the caller is holding a `borrow()`.
            let ctl = self.borrow();
            let FakeTimeController { timer_heap: _, current_time } = ctl.deref();
            TestInstant(*current_time)
        }

        /// Resolves immediately if `time` is not in the future. Otherwise it
        /// registers a timer that fires when fake time reaches `time`.
        async fn wait_until(&self, TestInstant(time): Self::Instant) {
            log::info!("registering timer at {:?}", time);
            let receiver = {
                let mut ctl = self.borrow_mut();
                let FakeTimeController { timer_heap, current_time } = ctl.deref_mut();
                if *current_time >= time {
                    return;
                }
                let (sender, receiver) = oneshot::channel();
                timer_heap.entry(Reverse(time)).or_default().push(sender);
                receiver
            };
            receiver.await.expect("shouldn't be cancelled")
        }
    }
}
420
#[cfg(test)]
mod test {
    use super::testutil::*;
    use super::*;
    use fuchsia_async as fasync;
    use futures::channel::mpsc;
    use futures::{FutureExt, StreamExt};
    use net_declare::std_socket_addr;
    use std::pin::pin;

    /// Returns the event that is already queued on `events`. Panics if no
    /// event is ready or the stream has ended.
    fn next_ready_event<E>(events: &mut mpsc::UnboundedReceiver<E>) -> E {
        events
            .next()
            .now_or_never()
            .expect("should have received bound event")
            .expect("stream should not have ended")
    }

    #[test]
    fn test_rng() {
        let make_sequence = |seed| {
            let mut rng = FakeRngProvider::new(seed);
            std::iter::from_fn(|| Some(rng.get_rng().random::<u32>())).take(5).collect::<Vec<_>>()
        };
        assert_eq!(
            make_sequence(42),
            make_sequence(42),
            "should provide identical sequences with same seed"
        );
        assert_ne!(
            make_sequence(42),
            make_sequence(999999),
            "should provide different sequences with different seeds"
        );
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_socket() {
        let (a, b) = FakeSocket::new_pair();
        let to_send = [
            (b"hello".to_vec(), "1.2.3.4:5".to_string()),
            (b"test".to_vec(), "1.2.3.5:5".to_string()),
            (b"socket".to_vec(), "1.2.3.6:5".to_string()),
        ];

        let mut buf = [0u8; 10];
        // Exercise both directions of the pair.
        for (sender, receiver) in [(&a, &b), (&b, &a)] {
            for (msg, addr) in &to_send {
                sender.send_to(msg, addr.clone()).await.unwrap();

                let DatagramInfo { length: n, address: received_addr } =
                    receiver.recv_from(&mut buf).await.unwrap();
                assert_eq!(&received_addr, addr);
                assert_eq!(&buf[..n], msg);
            }
        }
    }

    #[fasync::run_singlethreaded(test)]
    #[should_panic]
    async fn test_socket_panics_on_short_read() {
        let (a, b) = FakeSocket::new_pair();

        let mut buf = [0u8; 10];
        let message = b"this message is way longer than 10 bytes";
        a.send_to(message, "1.2.3.4:5".to_string()).await.unwrap();

        // Should panic here.
        let _: Result<_, _> = b.recv_from(&mut buf).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_fake_udp_socket_provider() {
        let (a, b) = FakeSocket::new_pair();
        let (events_sender, mut events_receiver) = mpsc::unbounded();
        let provider = FakeSocketProvider::new_with_events(b, events_sender);
        const ADDR_1: std::net::SocketAddr = std_socket_addr!("1.1.1.1:11");
        const ADDR_2: std::net::SocketAddr = std_socket_addr!("2.2.2.2:22");
        const ADDR_3: std::net::SocketAddr = std_socket_addr!("3.3.3.3:33");

        let b_1 = provider.bind_new_udp_socket(ADDR_1).await.expect("bind udp socket");
        assert_eq!(next_ready_event(&mut events_receiver), ADDR_1);

        let b_2 = provider.bind_new_udp_socket(ADDR_2).await.expect("bind udp socket");
        assert_eq!(next_ready_event(&mut events_receiver), ADDR_2);

        a.send_to(b"hello", ADDR_3).await.unwrap();
        a.send_to(b"world", ADDR_3).await.unwrap();

        // Both vended sockets share one channel, so they drain it in order.
        let mut buf = [0u8; 5];
        let DatagramInfo { length, address } = b_1.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"hello");
        assert_eq!(address, ADDR_3);

        let DatagramInfo { length, address } = b_2.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"world");
        assert_eq!(address, ADDR_3);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_fake_packet_socket_provider() {
        let (a, b) = FakeSocket::new_pair();
        let (events_sender, mut events_receiver) = mpsc::unbounded();
        let provider = FakeSocketProvider::new_with_events(b, events_sender);

        let b_1 = provider.get_packet_socket().await.expect("get packet socket");
        let () = next_ready_event(&mut events_receiver);

        let b_2 = provider.get_packet_socket().await.expect("get packet socket");
        let () = next_ready_event(&mut events_receiver);

        const ADDRESS: net_types::ethernet::Mac = net_declare::net_mac!("01:02:03:04:05:06");

        a.send_to(b"hello", ADDRESS).await.unwrap();

        a.send_to(b"world", ADDRESS).await.unwrap();

        // Both vended sockets share one channel, so they drain it in order.
        let mut buf = [0u8; 5];
        let DatagramInfo { length, address } = b_1.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"hello");
        assert_eq!(address, ADDRESS);

        let DatagramInfo { length, address } = b_2.recv_from(&mut buf).await.unwrap();
        assert_eq!(&buf[..length], b"world");
        assert_eq!(address, ADDRESS);
    }

    #[test]
    fn test_time_controller() {
        let time_ctl = FakeTimeController::new();
        assert!(time_ctl.borrow().timer_heap.is_empty());
        assert_eq!(time_ctl.borrow().current_time, std::time::Duration::from_secs(0));
        assert_eq!(time_ctl.now(), TestInstant(std::time::Duration::from_secs(0)));

        let mut timer_registered_before_should_fire_1 =
            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(1))));
        let mut timer_registered_before_should_fire_2 =
            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(1))));

        let mut timer_should_not_fire =
            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(10))));

        // Poll the timer futures once so that they have the chance to
        // register themselves in our timer heap.
        {
            let mut context = futures::task::Context::from_waker(std::task::Waker::noop());
            assert_eq!(
                timer_registered_before_should_fire_1.poll_unpin(&mut context),
                futures::task::Poll::Pending
            );
            assert_eq!(
                timer_registered_before_should_fire_2.poll_unpin(&mut context),
                futures::task::Poll::Pending
            );
            assert_eq!(
                timer_should_not_fire.poll_unpin(&mut context),
                futures::task::Poll::Pending
            );
        }

        {
            let time_ctl = time_ctl.borrow();
            let entries = time_ctl.timer_heap.iter().collect::<Vec<_>>();
            assert_eq!(entries.len(), 2);

            // Keys are `Reverse(deadline)`, so the latest deadline comes first.
            let (time, senders) = entries[0];
            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(10)));
            assert_eq!(senders.len(), 1);

            let (time, senders) = entries[1];
            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(1)));
            assert_eq!(senders.len(), 2);
        }

        advance(&time_ctl, std::time::Duration::from_secs(1));

        assert_eq!(time_ctl.now(), TestInstant(std::time::Duration::from_secs(1)));
        {
            let time_ctl = time_ctl.borrow();
            let entries = time_ctl.timer_heap.iter().collect::<Vec<_>>();
            assert_eq!(entries.len(), 1);
            let (time, senders) = entries[0];
            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(10)));
            assert_eq!(senders.len(), 1);
        }

        assert_eq!(timer_registered_before_should_fire_1.now_or_never(), Some(()));
        assert_eq!(timer_registered_before_should_fire_2.now_or_never(), Some(()));
        assert_eq!(timer_should_not_fire.now_or_never(), None);

        let timer_set_in_past = time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(0)));
        assert_eq!(timer_set_in_past.now_or_never(), Some(()));

        let timer_set_for_present = time_ctl.wait_until(time_ctl.now());
        assert_eq!(timer_set_for_present.now_or_never(), Some(()));
    }
}