Skip to main content

dhcp_client_core/
deps.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Defines trait abstractions for platform dependencies of the DHCP client
6//! core, and provides fake implementations of these dependencies for testing
7//! purposes.
8
9use diagnostics_traits::InspectableInstant;
10use fuchsia_async as fasync;
11use rand::Rng;
12use std::future::Future;
13
14/// Provides access to random number generation.
15pub trait RngProvider {
16    /// The random number generator being provided.
17    type RNG: Rng + ?Sized;
18
19    /// Get access to a random number generator.
20    fn get_rng(&mut self) -> &mut Self::RNG;
21}
22
23impl RngProvider for rand::rngs::StdRng {
24    type RNG = Self;
25    fn get_rng(&mut self) -> &mut Self::RNG {
26        self
27    }
28}
29
/// Metadata describing a single datagram received on a socket.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DatagramInfo<T> {
    /// Size of the received datagram, in bytes.
    pub length: usize,
    /// Address associated with the received datagram (typically the address
    /// the datagram was sent from).
    pub address: T,
}
39
40#[derive(thiserror::Error, Debug)]
41/// Errors encountered while performing a socket operation.
42pub enum SocketError {
43    /// Failure while attempting to open a socket.
44    #[error("failed to open socket: {0}")]
45    FailedToOpen(anyhow::Error),
46    /// Tried to bind a socket on a nonexistent interface.
47    #[error("tried to bind socket on nonexistent interface")]
48    NoInterface,
49    /// The hardware type of the interface is unsupported.
50    #[error("unsupported hardware type")]
51    UnsupportedHardwareType,
52    /// The host we are attempting to send to is unreachable.
53    #[error("host unreachable")]
54    HostUnreachable,
55    /// The network is unreachable.
56    #[error("network unreachable")]
57    NetworkUnreachable,
58    /// The address is not available.
59    #[error("address not available")]
60    AddrNotAvailable,
61    /// Other IO errors observed on socket operations.
62    #[error("socket error: {0}")]
63    Other(std::io::Error),
64}
65
66/// Abstracts sending and receiving datagrams on a socket.
67pub trait Socket<T> {
68    /// Sends a datagram containing the contents of `buf` to `addr`.
69    fn send_to(&self, buf: &[u8], addr: T) -> impl Future<Output = Result<(), SocketError>>;
70
71    /// Receives a datagram into `buf`, returning the number of bytes received
72    /// and the address the datagram was received from.
73    fn recv_from(
74        &self,
75        buf: &mut [u8],
76    ) -> impl Future<Output = Result<DatagramInfo<T>, SocketError>>;
77}
78
79/// Provides access to AF_PACKET sockets.
80pub trait PacketSocketProvider {
81    /// The type of sockets provided by this `PacketSocketProvider`.
82    type Sock: Socket<net_types::ethernet::Mac>;
83
84    /// Gets a packet socket bound to the device on which the DHCP client
85    /// protocol is being performed. The packet socket should already be bound
86    /// to the appropriate device and protocol number.
87    fn get_packet_socket(&self) -> impl Future<Output = Result<Self::Sock, SocketError>>;
88}
89
90/// Provides access to UDP sockets.
91pub trait UdpSocketProvider {
92    /// The type of sockets provided by this `UdpSocketProvider`.
93    type Sock: Socket<std::net::SocketAddr>;
94
95    /// Gets a UDP socket bound to the given address. The UDP socket should be
96    /// allowed to send broadcast packets.
97    fn bind_new_udp_socket(
98        &self,
99        bound_addr: std::net::SocketAddr,
100    ) -> impl Future<Output = Result<Self::Sock, SocketError>>;
101}
102
103/// A type representing an instant in time.
104pub trait Instant:
105    Sized + Ord + Copy + Clone + std::fmt::Debug + Send + Sync + InspectableInstant
106{
107    /// Returns the time `self + duration`. Panics if `self + duration` would
108    /// overflow the underlying instant storage type.
109    fn add(&self, duration: std::time::Duration) -> Self;
110
111    /// Returns the instant halfway between `self` and `other`.
112    fn average(&self, other: Self) -> Self;
113}
114
115impl Instant for fasync::MonotonicInstant {
116    fn add(&self, duration: std::time::Duration) -> Self {
117        // On host builds, fasync::MonotonicDuration is simply an alias for
118        // std::time::Duration, making the `duration.into()` appear useless.
119        #[allow(clippy::useless_conversion)]
120        {
121            *self + duration.into()
122        }
123    }
124
125    fn average(&self, other: Self) -> Self {
126        let lower = *self.min(&other);
127        let higher = *self.max(&other);
128        lower + (higher - lower) / 2
129    }
130}
131
132/// Provides access to system-time-related operations.
133pub trait Clock {
134    /// The type representing monotonic system time.
135    type Instant: Instant;
136
137    /// Completes once the monotonic system time is at or after the given time.
138    fn wait_until(&self, time: Self::Instant) -> impl Future<Output = ()>;
139
140    /// Gets the monotonic system time.
141    fn now(&self) -> Self::Instant;
142}
143
144#[cfg(test)]
145pub(crate) mod testutil {
146    use super::*;
147    use diagnostics_traits::InstantPropertyName;
148    use futures::StreamExt as _;
149    use futures::channel::{mpsc, oneshot};
150    use futures::lock::Mutex;
151    use rand::SeedableRng as _;
152    use std::cell::RefCell;
153    use std::cmp::Reverse;
154    use std::collections::BTreeMap;
155    use std::future::Future;
156    use std::ops::{Deref as _, DerefMut as _};
157    use std::rc::Rc;
158
159    /// Provides a seedable implementation of `RngProvider` using `StdRng`.
160    pub(crate) struct FakeRngProvider {
161        std_rng: rand::rngs::StdRng,
162    }
163
164    impl FakeRngProvider {
165        pub(crate) fn new(seed: u64) -> Self {
166            Self { std_rng: rand::rngs::StdRng::seed_from_u64(seed) }
167        }
168    }
169
170    impl RngProvider for FakeRngProvider {
171        type RNG = rand::rngs::StdRng;
172        fn get_rng(&mut self) -> &mut Self::RNG {
173            &mut self.std_rng
174        }
175    }
176
177    /// Provides a fake implementation of `Socket` using `mpsc` channels.
178    ///
179    /// Simply forwards pairs of (payload, address) over the channel. This means
180    /// that the "sent to" address from the sender side is actually observed as
181    /// the "received from" address on the receiver side.
182    pub(crate) struct FakeSocket<T> {
183        sender: mpsc::UnboundedSender<(Vec<u8>, T)>,
184        receiver: Mutex<mpsc::UnboundedReceiver<(Vec<u8>, T)>>,
185    }
186
187    impl<T> FakeSocket<T> {
188        pub(crate) fn new_pair() -> (FakeSocket<T>, FakeSocket<T>) {
189            let (send_a, recv_a) = mpsc::unbounded();
190            let (send_b, recv_b) = mpsc::unbounded();
191            (
192                FakeSocket { sender: send_a, receiver: Mutex::new(recv_b) },
193                FakeSocket { sender: send_b, receiver: Mutex::new(recv_a) },
194            )
195        }
196    }
197
198    impl<T: Send> Socket<T> for FakeSocket<T> {
199        async fn send_to(&self, buf: &[u8], addr: T) -> Result<(), SocketError> {
200            let FakeSocket { sender, receiver: _ } = self;
201            sender.clone().unbounded_send((buf.to_vec(), addr)).expect("unbounded_send error");
202            Ok(())
203        }
204
205        async fn recv_from(&self, buf: &mut [u8]) -> Result<DatagramInfo<T>, SocketError> {
206            let FakeSocket { receiver, sender: _ } = self;
207            let mut receiver = receiver.lock().await;
208            let (bytes, addr) = receiver.next().await.expect("TestSocket receiver closed");
209            if buf.len() < bytes.len() {
210                panic!("TestSocket receiver would produce short read")
211            }
212            (buf[..bytes.len()]).copy_from_slice(&bytes);
213            Ok(DatagramInfo { length: bytes.len(), address: addr })
214        }
215    }
216
217    impl<T, U> Socket<U> for T
218    where
219        T: AsRef<FakeSocket<U>>,
220        U: Send + 'static,
221    {
222        async fn send_to(&self, buf: &[u8], addr: U) -> Result<(), SocketError> {
223            self.as_ref().send_to(buf, addr).await
224        }
225
226        async fn recv_from(&self, buf: &mut [u8]) -> Result<DatagramInfo<U>, SocketError> {
227            self.as_ref().recv_from(buf).await
228        }
229    }
230
231    /// Fake socket provider implementation that vends out copies of
232    /// the same `FakeSocket`.
233    ///
234    /// These copies will compete to receive and send on the same underlying
235    /// `mpsc` channels.
236    pub(crate) struct FakeSocketProvider<T, E> {
237        /// The socket being vended out.
238        pub(crate) socket: Rc<FakeSocket<T>>,
239
240        /// If present, used to notify tests when the client binds new sockets.
241        pub(crate) bound_events: Option<mpsc::UnboundedSender<E>>,
242    }
243
244    impl<T, E> FakeSocketProvider<T, E> {
245        pub(crate) fn new(socket: FakeSocket<T>) -> Self {
246            Self { socket: Rc::new(socket), bound_events: None }
247        }
248
249        pub(crate) fn new_with_events(
250            socket: FakeSocket<T>,
251            bound_events: mpsc::UnboundedSender<E>,
252        ) -> Self {
253            Self { socket: Rc::new(socket), bound_events: Some(bound_events) }
254        }
255    }
256
257    impl PacketSocketProvider for FakeSocketProvider<net_types::ethernet::Mac, ()> {
258        type Sock = Rc<FakeSocket<net_types::ethernet::Mac>>;
259        async fn get_packet_socket(&self) -> Result<Self::Sock, SocketError> {
260            let Self { socket, bound_events } = self;
261            if let Some(bound_events) = bound_events {
262                bound_events.unbounded_send(()).expect("events receiver should not be dropped");
263            }
264            Ok(socket.clone())
265        }
266    }
267
268    impl UdpSocketProvider for FakeSocketProvider<std::net::SocketAddr, std::net::SocketAddr> {
269        type Sock = Rc<FakeSocket<std::net::SocketAddr>>;
270        async fn bind_new_udp_socket(
271            &self,
272            bound_addr: std::net::SocketAddr,
273        ) -> Result<Self::Sock, SocketError> {
274            let Self { socket, bound_events } = self;
275            if let Some(bound_events) = bound_events {
276                bound_events
277                    .unbounded_send(bound_addr)
278                    .expect("events receiver should not be dropped");
279            }
280            Ok(socket.clone())
281        }
282    }
283
284    #[derive(Copy, Clone, PartialEq, Eq, Debug, PartialOrd, Ord)]
285    pub(crate) struct TestInstant(pub(crate) std::time::Duration);
286
287    impl InspectableInstant for TestInstant {
288        fn record<I: diagnostics_traits::Inspector>(
289            &self,
290            name: InstantPropertyName,
291            inspector: &mut I,
292        ) {
293            inspector.record_debug(name.into(), self);
294        }
295    }
296
297    impl Instant for TestInstant {
298        fn add(&self, duration: std::time::Duration) -> Self {
299            Self(self.0.checked_add(duration).unwrap())
300        }
301
302        fn average(&self, other: Self) -> Self {
303            let lower = self.0.min(other.0);
304            let higher = self.0.max(other.0);
305            Self(lower + (higher - lower) / 2)
306        }
307    }
308
309    /// Fake implementation of `Time` that uses `std::time::Duration` as its
310    /// `Instant` type.
311    pub(crate) struct FakeTimeController {
312        pub(super) timer_heap:
313            BTreeMap<std::cmp::Reverse<std::time::Duration>, Vec<oneshot::Sender<()>>>,
314        pub(super) current_time: std::time::Duration,
315    }
316
317    impl FakeTimeController {
318        pub(crate) fn new() -> Rc<RefCell<FakeTimeController>> {
319            Rc::new(RefCell::new(FakeTimeController {
320                timer_heap: BTreeMap::default(),
321                current_time: std::time::Duration::default(),
322            }))
323        }
324    }
325
326    /// Advances the "current time" encoded by `ctl` by `duration`. Any timers
327    /// that were set at or before the resulting "current time" will fire.
328    pub(crate) fn advance(ctl: &Rc<RefCell<FakeTimeController>>, duration: std::time::Duration) {
329        let timers_to_fire = {
330            let mut ctl = ctl.borrow_mut();
331            let FakeTimeController { timer_heap, current_time } = ctl.deref_mut();
332            let next_time = *current_time + duration;
333            *current_time = next_time;
334            timer_heap.split_off(&std::cmp::Reverse(next_time))
335        };
336        for (_, senders) in timers_to_fire {
337            for sender in senders {
338                match sender.send(()) {
339                    Ok(()) => (),
340                    Err(()) => {
341                        // ignore, it's fine for the client core to drop a timer
342                        // to cancel it
343                    }
344                }
345            }
346        }
347    }
348
349    pub(crate) fn run_until_next_timers_fire<F>(
350        executor: &mut fasync::TestExecutor,
351        time: &Rc<RefCell<FakeTimeController>>,
352        main_future: &mut F,
353    ) -> std::task::Poll<F::Output>
354    where
355        F: Future + Unpin,
356    {
357        let poll: std::task::Poll<_> = executor.run_until_stalled(main_future);
358        if poll.is_ready() {
359            return poll;
360        }
361
362        {
363            let mut time = time.borrow_mut();
364            let FakeTimeController { timer_heap, current_time } = time.deref_mut();
365
366            // NOTE: the timer heap is ordered by Reverse<Duration> in order to
367            // facilitate the implementation of `advance()` by making
368            // `BTreeMap::split_off` have the right edge-case behavior. This
369            // makes it easy to get it first_entry vs last_entry mixed up here,
370            // though.
371            let earliest_entry = timer_heap.last_entry().expect("no timers installed");
372
373            let (Reverse(instant), senders) = earliest_entry.remove_entry();
374            *current_time = instant;
375            for sender in senders {
376                match sender.send(()) {
377                    Ok(()) => (),
378                    Err(()) => {
379                        // ignore, it's fine for the client core to drop a timer
380                        // to cancel it
381                    }
382                }
383            }
384        }
385
386        executor.run_until_stalled(main_future)
387    }
388
389    impl Clock for Rc<RefCell<FakeTimeController>> {
390        type Instant = TestInstant;
391
392        fn now(&self) -> Self::Instant {
393            let ctl = self.borrow_mut();
394            let FakeTimeController { timer_heap: _, current_time } = ctl.deref();
395            TestInstant(*current_time)
396        }
397
398        async fn wait_until(&self, TestInstant(time): Self::Instant) {
399            log::info!("registering timer at {:?}", time);
400            let receiver = {
401                let mut ctl = self.borrow_mut();
402                let FakeTimeController { timer_heap, current_time } = ctl.deref_mut();
403                if *current_time >= time {
404                    return;
405                }
406                let (sender, receiver) = oneshot::channel();
407                timer_heap.entry(std::cmp::Reverse(time)).or_default().push(sender);
408                receiver
409            };
410            receiver.await.expect("shouldn't be cancelled")
411        }
412    }
413}
414
415#[cfg(test)]
416mod test {
417    use super::testutil::*;
418    use super::*;
419    use fuchsia_async as fasync;
420    use futures::channel::mpsc;
421    use futures::{FutureExt, StreamExt};
422    use net_declare::std_socket_addr;
423    use std::pin::pin;
424
425    #[test]
426    fn test_rng() {
427        let make_sequence = |seed| {
428            let mut rng = FakeRngProvider::new(seed);
429            std::iter::from_fn(|| Some(rng.get_rng().random::<u32>())).take(5).collect::<Vec<_>>()
430        };
431        assert_eq!(
432            make_sequence(42),
433            make_sequence(42),
434            "should provide identical sequences with same seed"
435        );
436        assert_ne!(
437            make_sequence(42),
438            make_sequence(999999),
439            "should provide different sequences with different seeds"
440        );
441    }
442
443    #[fasync::run_singlethreaded(test)]
444    async fn test_socket() {
445        let (a, b) = FakeSocket::new_pair();
446        let to_send = [
447            (b"hello".to_vec(), "1.2.3.4:5".to_string()),
448            (b"test".to_vec(), "1.2.3.5:5".to_string()),
449            (b"socket".to_vec(), "1.2.3.6:5".to_string()),
450        ];
451
452        let mut buf = [0u8; 10];
453        for (msg, addr) in &to_send {
454            a.send_to(msg, addr.clone()).await.unwrap();
455
456            let DatagramInfo { length: n, address: received_addr } =
457                b.recv_from(&mut buf).await.unwrap();
458            assert_eq!(&received_addr, addr);
459            assert_eq!(&buf[..n], msg);
460        }
461
462        let (a, b) = (b, a);
463        for (msg, addr) in &to_send {
464            a.send_to(msg, addr.clone()).await.unwrap();
465
466            let DatagramInfo { length: n, address: received_addr } =
467                b.recv_from(&mut buf).await.unwrap();
468            assert_eq!(&received_addr, addr);
469            assert_eq!(&buf[..n], msg);
470        }
471    }
472
473    #[fasync::run_singlethreaded(test)]
474    #[should_panic]
475    async fn test_socket_panics_on_short_read() {
476        let (a, b) = FakeSocket::new_pair();
477
478        let mut buf = [0u8; 10];
479        let message = b"this message is way longer than 10 bytes";
480        a.send_to(message, "1.2.3.4:5".to_string()).await.unwrap();
481
482        // Should panic here.
483        let _: Result<_, _> = b.recv_from(&mut buf).await;
484    }
485
486    #[fasync::run_singlethreaded(test)]
487    async fn test_fake_udp_socket_provider() {
488        let (a, b) = FakeSocket::new_pair();
489        let (events_sender, mut events_receiver) = mpsc::unbounded();
490        let provider = FakeSocketProvider::new_with_events(b, events_sender);
491        const ADDR_1: std::net::SocketAddr = std_socket_addr!("1.1.1.1:11");
492        const ADDR_2: std::net::SocketAddr = std_socket_addr!("2.2.2.2:22");
493        const ADDR_3: std::net::SocketAddr = std_socket_addr!("3.3.3.3:33");
494        let b_1 = provider.bind_new_udp_socket(ADDR_1).await.expect("get packet socket");
495        assert_eq!(
496            events_receiver
497                .next()
498                .now_or_never()
499                .expect("should have received bound event")
500                .expect("stream should not have ended"),
501            ADDR_1
502        );
503
504        let b_2 = provider.bind_new_udp_socket(ADDR_2).await.expect("get packet socket");
505        assert_eq!(
506            events_receiver
507                .next()
508                .now_or_never()
509                .expect("should have received bound event")
510                .expect("stream should not have ended"),
511            ADDR_2
512        );
513
514        a.send_to(b"hello", ADDR_3).await.unwrap();
515        a.send_to(b"world", ADDR_3).await.unwrap();
516
517        let mut buf = [0u8; 5];
518        let DatagramInfo { length, address } = b_1.recv_from(&mut buf).await.unwrap();
519        assert_eq!(&buf[..length], b"hello");
520        assert_eq!(address, ADDR_3);
521
522        let DatagramInfo { length, address } = b_2.recv_from(&mut buf).await.unwrap();
523        assert_eq!(&buf[..length], b"world");
524        assert_eq!(address, ADDR_3);
525    }
526
527    #[fasync::run_singlethreaded(test)]
528    async fn test_fake_packet_socket_provider() {
529        let (a, b) = FakeSocket::new_pair();
530        let (events_sender, mut events_receiver) = mpsc::unbounded();
531        let provider = FakeSocketProvider::new_with_events(b, events_sender);
532        let b_1 = provider.get_packet_socket().await.expect("get packet socket");
533        events_receiver
534            .next()
535            .now_or_never()
536            .expect("should have received bound event")
537            .expect("stream should not have ended");
538
539        let b_2 = provider.get_packet_socket().await.expect("get packet socket");
540        events_receiver
541            .next()
542            .now_or_never()
543            .expect("should have received bound event")
544            .expect("stream should not have ended");
545
546        const ADDRESS: net_types::ethernet::Mac = net_declare::net_mac!("01:02:03:04:05:06");
547
548        a.send_to(b"hello", ADDRESS).await.unwrap();
549
550        a.send_to(b"world", ADDRESS).await.unwrap();
551
552        let mut buf = [0u8; 5];
553        let DatagramInfo { length, address } = b_1.recv_from(&mut buf).await.unwrap();
554        assert_eq!(&buf[..length], b"hello");
555        assert_eq!(address, ADDRESS);
556
557        let DatagramInfo { length, address } = b_2.recv_from(&mut buf).await.unwrap();
558        assert_eq!(&buf[..length], b"world");
559        assert_eq!(address, ADDRESS);
560    }
561
562    #[test]
563    fn test_time_controller() {
564        let time_ctl = FakeTimeController::new();
565        assert!(time_ctl.borrow().timer_heap.is_empty());
566        assert_eq!(time_ctl.borrow().current_time, std::time::Duration::from_secs(0));
567        assert_eq!(time_ctl.now(), TestInstant(std::time::Duration::from_secs(0)));
568
569        let mut timer_registered_before_should_fire_1 =
570            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(1))));
571        let mut timer_registered_before_should_fire_2 =
572            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(1))));
573
574        let mut timer_should_not_fire =
575            pin!(time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(10))));
576
577        // Poll the timer futures once so that they have the chance to
578        // register themselves in our timer heap.
579        {
580            let waker = std::task::Waker::noop();
581            let mut context = futures::task::Context::from_waker(&waker);
582            assert_eq!(
583                timer_registered_before_should_fire_1.poll_unpin(&mut context),
584                futures::task::Poll::Pending
585            );
586            assert_eq!(
587                timer_registered_before_should_fire_2.poll_unpin(&mut context),
588                futures::task::Poll::Pending
589            );
590            assert_eq!(
591                timer_should_not_fire.poll_unpin(&mut context),
592                futures::task::Poll::Pending
593            );
594        }
595
596        {
597            let time_ctl = time_ctl.borrow_mut();
598            let entries = time_ctl.timer_heap.iter().collect::<Vec<_>>();
599            assert_eq!(entries.len(), 2);
600
601            let (time, senders) = entries[0];
602            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(10)));
603            assert_eq!(senders.len(), 1);
604
605            let (time, senders) = entries[1];
606            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(1)));
607            assert_eq!(senders.len(), 2);
608        }
609
610        advance(&time_ctl, std::time::Duration::from_secs(1));
611
612        assert_eq!(time_ctl.now(), TestInstant(std::time::Duration::from_secs(1)));
613        {
614            let time_ctl = time_ctl.borrow_mut();
615            let entries = time_ctl.timer_heap.iter().collect::<Vec<_>>();
616            assert_eq!(entries.len(), 1);
617            let (time, senders) = entries[0];
618            assert_eq!(time, &std::cmp::Reverse(std::time::Duration::from_secs(10)));
619            assert_eq!(senders.len(), 1);
620        }
621
622        assert_eq!(timer_registered_before_should_fire_1.now_or_never(), Some(()));
623        assert_eq!(timer_registered_before_should_fire_2.now_or_never(), Some(()));
624        assert_eq!(timer_should_not_fire.now_or_never(), None);
625
626        let timer_set_in_past = time_ctl.wait_until(TestInstant(std::time::Duration::from_secs(0)));
627        assert_eq!(timer_set_in_past.now_or_never(), Some(()));
628
629        let timer_set_for_present = time_ctl.wait_until(time_ctl.now());
630        assert_eq!(timer_set_for_present.now_or_never(), Some(()));
631    }
632}