ping/
lib.rs

1// Copyright 2021 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![deny(missing_docs)]
6
7//! Helpers to ping an IPv4 or IPv6 address by sending ICMP echo requests and
8//! waiting for ICMP echo replies.
9//!
10//! Functionality in this crate relies on [ICMP sockets], a kind of socket where
11//! each payload read/written contains ICMP headers.
12//!
//! As a starting point, see [`new_unicast_sink_and_stream`], which is built
//! on top of the other facilities in the crate. It models pinging as a sink
//! that sends an ICMP echo request for every sequence number sent to it,
//! paired with a stream that yields an item for every echo reply received.
17//!
18//! [ICMP sockets]: https://lwn.net/Articles/422330/
19
20#[cfg(target_os = "fuchsia")]
21mod fuchsia;
22
23#[cfg(target_os = "fuchsia")]
24pub use fuchsia::{new_icmp_socket, IpExt as FuchsiaIpExt};
25
26use futures::{ready, Sink, SinkExt as _, Stream, TryStreamExt as _};
27use net_types::ip::{Ip, Ipv4, Ipv6};
28use std::marker::PhantomData;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use thiserror::Error;
32use zerocopy::byteorder::network_endian::U16;
33use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned};
34
/// The number of bytes of an ICMP (v4 or v6) header.
///
/// This is the size of `IcmpHeader`: type (1) + code (1) + checksum (2) +
/// identifier (2) + sequence number (2) = 8 bytes.
pub const ICMP_HEADER_LEN: usize = std::mem::size_of::<IcmpHeader>();

/// ICMP header representation.
///
/// Field order plus `#[repr(C)]` define the on-the-wire layout. Every field is
/// byte-aligned (`Unaligned`), so there is no padding, and the multi-byte
/// fields are stored in network byte order (`network_endian::U16`).
#[repr(C)]
#[derive(KnownLayout, FromBytes, IntoBytes, Immutable, Unaligned, Debug, PartialEq, Eq, Clone)]
struct IcmpHeader {
    // ICMP message type; see `IpExt::ECHO_REQUEST_TYPE` / `IpExt::ECHO_REPLY_TYPE`.
    type_: u8,
    // ICMP code; zero on sent requests and required to be zero on accepted replies.
    code: u8,
    // Checksum; left zero on send and not verified on receive.
    checksum: U16,
    // Echo identifier; ICMP socket implementations rewrite it, so it is left zero on send and
    // ignored on receive.
    id: U16,
    // Echo sequence number.
    sequence: U16,
}
48
49impl IcmpHeader {
50    fn new<I: IpExt>(sequence: u16) -> Self {
51        Self {
52            type_: I::ECHO_REQUEST_TYPE,
53            code: 0,
54            checksum: 0.into(),
55            id: 0.into(),
56            sequence: sequence.into(),
57        }
58    }
59}
60
/// Ping error.
///
/// Covers failures to send echo requests (`PingSink`), failures to receive
/// packets (`PingStream`), and validation failures for received echo replies.
#[derive(Debug, Error)]
pub enum PingError {
    /// Send error: the socket failed to send an echo request.
    #[error("send error")]
    Send(#[source] std::io::Error),
    /// Send length mismatch: the socket reported sending a different number of
    /// bytes than the full echo request (ICMP header plus body).
    #[error("wrong number of bytes sent, got: {got}, want: {want}")]
    SendLength {
        /// Number of bytes sent.
        got: usize,
        /// Number of bytes expected to be sent.
        want: usize,
    },
    /// Recv error: the socket failed to receive a packet.
    #[error("recv error")]
    Recv(#[source] std::io::Error),
    /// ICMP header parsing error: the received packet is shorter than an ICMP
    /// header.
    #[error("failed to parse ICMP header")]
    Parse,
    /// Reply type mismatch: the received packet's ICMP type is not the echo
    /// reply type for this IP version.
    #[error("wrong reply type, got: {got}, want: {want}")]
    ReplyType {
        /// ICMP type received in reply.
        got: u8,
        /// ICMP type expected in reply.
        want: u8,
    },
    /// Reply code mismatch: the echo reply carried a non-zero ICMP code.
    #[error("non-zero reply code: {0}")]
    ReplyCode(u8),
    /// ICMP message body mismatch: an echo reply from the expected address did
    /// not carry the body that was sent (reported by
    /// `new_unicast_sink_and_stream`).
    #[error("reply message body mismatch, got: {got:?}, want: {want:?}")]
    Body {
        /// Body received in reply.
        got: Vec<u8>,
        /// Body expected in reply.
        want: Vec<u8>,
    },
}
101
102/// Addresses which can be converted from `socket2::SockAddr`.
103///
104/// This trait exists to get around not being able to implement the foreign trait
105/// `TryFrom<socket2::SockAddr>` for the foreign types `std::net::SocketAddr(V4|V6)?`.
106pub trait TryFromSockAddr: Sized {
107    /// Try to convert from `socket2::SockAddr`.
108    fn try_from(value: socket2::SockAddr) -> std::io::Result<Self>;
109}
110
111impl TryFromSockAddr for std::net::SocketAddrV4 {
112    fn try_from(addr: socket2::SockAddr) -> std::io::Result<Self> {
113        addr.as_socket_ipv4()
114            .ok_or_else(|| std::io::Error::other(format!("socket address is not v4 {:?}", addr)))
115    }
116}
117
118impl TryFromSockAddr for std::net::SocketAddrV6 {
119    fn try_from(addr: socket2::SockAddr) -> std::io::Result<Self> {
120        addr.as_socket_ipv6()
121            .ok_or_else(|| std::io::Error::other(format!("socket address is not v6 {:?}", addr)))
122    }
123}
124
/// Trait for IP protocol versions.
///
/// Supplies, per IP version, the socket address type and the socket/ICMP
/// constants needed to send echo requests and recognize echo replies.
pub trait IpExt: Ip + Unpin {
    /// IP Socket address type.
    ///
    /// Used as the destination of echo requests and the source of echo replies
    /// (see `PingData::addr`).
    type SockAddr: Into<socket2::SockAddr>
        + TryFromSockAddr
        + Clone
        + Copy
        + Unpin
        + PartialEq
        + std::fmt::Debug
        + std::fmt::Display
        + Eq
        + std::hash::Hash;

    /// ICMP socket domain.
    const DOMAIN: socket2::Domain;
    /// ICMP socket protocol.
    const PROTOCOL: socket2::Protocol;

    /// ICMP echo request type, written into the header of every request sent.
    const ECHO_REQUEST_TYPE: u8;
    /// ICMP echo reply type; received packets with any other type are rejected.
    const ECHO_REPLY_TYPE: u8;
}
149
150// TODO(https://fxbug.dev/323955204): Implement ext trait on net_types::ip::Ipv4
151// instead and remove the Ipv4 type.
152impl IpExt for Ipv4 {
153    type SockAddr = std::net::SocketAddrV4;
154
155    const DOMAIN: socket2::Domain = socket2::Domain::IPV4;
156    const PROTOCOL: socket2::Protocol = socket2::Protocol::ICMPV4;
157
158    const ECHO_REQUEST_TYPE: u8 = 8;
159    const ECHO_REPLY_TYPE: u8 = 0;
160}
161
162// TODO(https://fxbug.dev/323955204): Implement ext trait on net_types::ip::Ipv6
163// instead and remove the Ipv6 type.
164impl IpExt for Ipv6 {
165    type SockAddr = std::net::SocketAddrV6;
166
167    const DOMAIN: socket2::Domain = socket2::Domain::IPV6;
168    const PROTOCOL: socket2::Protocol = socket2::Protocol::ICMPV6;
169
170    const ECHO_REQUEST_TYPE: u8 = 128;
171    const ECHO_REPLY_TYPE: u8 = 129;
172}
173
/// Async ICMP socket.
///
/// Every payload sent or received through this trait is a complete ICMP
/// packet, header included. `PingSink` and `PingStream` propagate
/// `Poll::Pending` from these methods unchanged (via `ready!`), so
/// implementations must arrange for `cx` to be woken before returning it.
pub trait IcmpSocket<I>: Unpin
where
    I: IpExt,
{
    /// Async method for receiving an ICMP packet.
    ///
    /// Upon successful return, `buf` will contain an ICMP packet. Returns the
    /// number of bytes written into `buf` and the address the packet came from.
    fn async_recv_from(
        &self,
        buf: &mut [u8],
        cx: &mut Context<'_>,
    ) -> Poll<std::io::Result<(usize, I::SockAddr)>>;

    /// Async method for sending an ICMP packet.
    ///
    /// `bufs` must contain a valid ICMP packet, split across the slices in
    /// order. Returns the number of bytes sent; `PingSink` treats any count
    /// other than the combined length of `bufs` as an error.
    fn async_send_to_vectored(
        &self,
        bufs: &[std::io::IoSlice<'_>],
        addr: &I::SockAddr,
        cx: &mut Context<'_>,
    ) -> Poll<std::io::Result<usize>>;

    /// Binds this to an interface so that packets can only flow in/out via the specified
    /// interface.
    ///
    /// If `interface` is `None`, the binding is removed.
    fn bind_device(&self, interface: Option<&[u8]>) -> std::io::Result<()>;
}
204
/// Parameters of a ping request/reply.
///
/// `PingSink` consumes these as echo requests; `PingStream` yields them for
/// validated echo replies.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PingData<I: IpExt> {
    /// The destination address of a ping request; or the source address of a ping reply.
    pub addr: I::SockAddr,
    /// The sequence number in the ICMP header.
    pub sequence: u16,
    /// The body of the echo request/reply, i.e. the bytes following the ICMP header.
    pub body: Vec<u8>,
}
215
216// TODO(https://github.com/rust-lang/rust/issues/76560): Define N as the length of the message body
217// rather than the length of the ICMP packet.
218/// Create a ping sink and stream for pinging a unicast destination with the same body for every
219/// packet.
220///
221/// Echo replies received with a source address not equal to `addr` will be silently dropped. Echo
222/// replies with a body not equal to `body` will result in an error on the stream.
223pub fn new_unicast_sink_and_stream<'a, I, S, const N: usize>(
224    socket: &'a S,
225    addr: &'a I::SockAddr,
226    body: &'a [u8],
227) -> (impl Sink<u16, Error = PingError> + 'a, impl Stream<Item = Result<u16, PingError>> + 'a)
228where
229    I: IpExt,
230    S: IcmpSocket<I>,
231{
232    (
233        PingSink::new(socket).with(move |sequence| {
234            futures::future::ok(PingData { addr: addr.clone(), sequence, body: body.to_vec() })
235        }),
236        PingStream::<I, S, N>::new(socket).try_filter_map(
237            move |PingData { addr: got_addr, sequence, body: got_body }| {
238                futures::future::ready(if got_addr == *addr {
239                    if got_body == body {
240                        Ok(Some(sequence))
241                    } else {
242                        Err(PingError::Body { got: got_body, want: body.to_vec() })
243                    }
244                } else {
245                    Ok(None)
246                })
247            },
248        ),
249    )
250}
251
// TODO(https://github.com/rust-lang/rust/issues/76560): Define N as the length of the message body
// rather than the length of the ICMP packet.
/// Stream of received ping replies.
///
/// Each item is the result of receiving one packet from the socket and validating it as an ICMP
/// echo reply. `N` is the size in bytes of the receive buffer, which holds the whole ICMP packet
/// (header and body).
pub struct PingStream<'a, I, S, const N: usize>
where
    I: IpExt,
    S: IcmpSocket<I>,
{
    // Socket that packets are received from.
    socket: &'a S,
    // Receive buffer, reused across polls; only the prefix reported by the socket is parsed.
    recv_buf: [u8; N],
    // Ties the stream to IP version `I`, which is not otherwise stored.
    _marker: PhantomData<I>,
}
264
265impl<'a, I, S, const N: usize> PingStream<'a, I, S, N>
266where
267    I: IpExt,
268    S: IcmpSocket<I>,
269{
270    /// Construct a stream from an `IcmpSocket`.
271    ///
272    /// `N` must be set to the length of the largest ICMP body expected
273    /// to be received plus the 8 bytes of overhead due to the ICMP
274    /// header, otherwise received packets may be truncated. Note
275    /// that this does not need to include the 8-byte overhead of
276    /// the ICMP header.
277    pub fn new(socket: &'a S) -> Self {
278        Self { socket, recv_buf: [0; N], _marker: PhantomData::<I> }
279    }
280}
281
282impl<'a, I, S, const N: usize> futures::stream::Stream for PingStream<'a, I, S, N>
283where
284    I: IpExt,
285    S: IcmpSocket<I>,
286{
287    type Item = Result<PingData<I>, PingError>;
288
289    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
290        let ping_stream = Pin::into_inner(self);
291        let buf = &mut ping_stream.recv_buf[..];
292        let socket = &ping_stream.socket;
293        Poll::Ready(Some(
294            ready!(socket.async_recv_from(buf, cx))
295                .map_err(PingError::Recv)
296                .and_then(|(len, addr)| verify_packet::<I>(addr, &ping_stream.recv_buf[..len])),
297        ))
298    }
299}
300
/// Sink for sending ping requests.
///
/// Holds at most one request at a time. `start_send` buffers an item, and it is
/// written to the socket as a single ICMP echo request when the sink is
/// flushed (`poll_ready` and `poll_close` flush as well).
pub struct PingSink<'a, I, S>
where
    I: IpExt,
    S: IcmpSocket<I>,
{
    // Socket that requests are sent on.
    socket: &'a S,
    // Request accepted by `start_send` but not yet fully sent: destination, prebuilt ICMP header,
    // and body.
    packet: Option<(I::SockAddr, IcmpHeader, Vec<u8>)>,
    // Ties the sink to IP version `I`.
    _marker: PhantomData<I>,
}
311
312impl<'a, I, S> PingSink<'a, I, S>
313where
314    I: IpExt,
315    S: IcmpSocket<I>,
316{
317    /// Construct a sink from an `IcmpSocket`.
318    pub fn new(socket: &'a S) -> Self {
319        Self { socket, packet: None, _marker: PhantomData::<I> }
320    }
321}
322
323impl<'a, I, S> futures::sink::Sink<PingData<I>> for PingSink<'a, I, S>
324where
325    I: IpExt,
326    S: IcmpSocket<I>,
327{
328    type Error = PingError;
329
330    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
331        self.poll_flush(cx)
332    }
333
334    fn start_send(
335        mut self: Pin<&mut Self>,
336        PingData { addr, sequence, body }: PingData<I>,
337    ) -> Result<(), Self::Error> {
338        let header = IcmpHeader::new::<I>(sequence);
339        assert_eq!(
340            self.packet.replace((addr, header, body)),
341            None,
342            "start_send called while element has yet to be flushed"
343        );
344        Ok(())
345    }
346
347    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
348        Poll::Ready(match &self.packet {
349            Some((addr, header, body)) => {
350                match ready!(self.socket.async_send_to_vectored(
351                    &[
352                        std::io::IoSlice::new(header.as_bytes()),
353                        std::io::IoSlice::new(body.as_bytes()),
354                    ],
355                    addr,
356                    cx
357                )) {
358                    Ok(got) => {
359                        let want = std::mem::size_of_val(&header) + body.len();
360                        if got != want {
361                            Err(PingError::SendLength { got, want })
362                        } else {
363                            self.packet = None;
364                            Ok(())
365                        }
366                    }
367                    Err(e) => Err(PingError::Send(e)),
368                }
369            }
370            None => Ok(()),
371        })
372    }
373
374    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
375        self.poll_flush(cx)
376    }
377}
378
379fn verify_packet<I: IpExt>(addr: I::SockAddr, packet: &[u8]) -> Result<PingData<I>, PingError> {
380    let (reply, body): (zerocopy::Ref<_, IcmpHeader>, _) = zerocopy::Ref::from_prefix(packet)
381        .map_err(Into::into)
382        .map_err(|_: zerocopy::SizeError<_, _>| PingError::Parse)?;
383
384    // The identifier cannot be verified, since ICMP socket implementations rewrites the field on
385    // send and uses its value to demultiplex packets for delivery to sockets on receive.
386    //
387    // Also, don't bother verifying the checksum, since ICMP socket implementations must have
388    // verified the checksum since the code and identifier fields must be inspected. Also, the
389    // ICMPv6 checksum computation includes a pseudo header which includes the src and dst
390    // addresses, and the dst/local address is not readily available.
391    let &IcmpHeader { type_, code, checksum: _, id: _, sequence } = zerocopy::Ref::into_ref(reply);
392
393    if type_ != I::ECHO_REPLY_TYPE {
394        return Err(PingError::ReplyType { got: type_, want: I::ECHO_REPLY_TYPE });
395    }
396
397    if code != 0 {
398        return Err(PingError::ReplyCode(code));
399    }
400
401    Ok(PingData { addr, sequence: sequence.into(), body: body.to_vec() })
402}
403
#[cfg(test)]
mod test {
    use super::{
        verify_packet, IcmpHeader, IcmpSocket, Ipv4, Ipv6, PingData, PingError, PingSink,
        PingStream,
    };

    use futures::{FutureExt as _, SinkExt as _, StreamExt as _, TryStreamExt as _};
    use net_declare::{std_socket_addr_v4, std_socket_addr_v6};
    use std::cell::RefCell;
    use std::collections::VecDeque;
    use std::task::{Context, Poll};
    use zerocopy::IntoBytes as _;

    // A fake impl of a IcmpSocket which computes and buffers a reply when
    // `async_send_to_vectored` is called, which is then returned when `async_recv_from` is
    // called. The order in which replies are returned is guaranteed to be FIFO.
    #[derive(Default, Debug)]
    struct FakeSocket<I: IpExt> {
        // NB: interior mutability is necessary here because the `IcmpSocket` trait's methods
        // operate on &self.
        buffer: RefCell<VecDeque<(Vec<u8>, I::SockAddr)>>,
    }

    impl<I: IpExt> FakeSocket<I> {
        fn new() -> Self {
            Self { buffer: RefCell::new(VecDeque::new()) }
        }
    }

    impl<I: IpExt> IcmpSocket<I> for FakeSocket<I> {
        fn async_recv_from(
            &self,
            buf: &mut [u8],
            _cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<(usize, I::SockAddr)>> {
            Poll::Ready(
                self.buffer
                    .borrow_mut()
                    .pop_front()
                    .ok_or_else(|| {
                        std::io::Error::new(
                            std::io::ErrorKind::WouldBlock,
                            "fake socket request buffer is empty",
                        )
                    })
                    .and_then(|(reply, addr)| {
                        if buf.len() < reply.len() {
                            Err(std::io::Error::other(format!(
                                "recv buffer too small, got: {}, want: {}",
                                buf.len(),
                                reply.len()
                            )))
                        } else {
                            buf[..reply.len()].copy_from_slice(&reply);
                            Ok((reply.len(), addr))
                        }
                    }),
            )
        }

        fn async_send_to_vectored(
            &self,
            bufs: &[std::io::IoSlice<'_>],
            addr: &I::SockAddr,
            _cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<usize>> {
            // Reassemble the packet from its slices, then turn the request into a reply in place.
            let mut buf =
                bufs.iter().flat_map(|io_slice| io_slice.iter().copied()).collect::<Vec<u8>>();
            let (mut header, _): (zerocopy::Ref<_, IcmpHeader>, _) =
                match zerocopy::Ref::from_prefix(&mut buf[..]).map_err(Into::into) {
                    Ok(layout_verified) => layout_verified,
                    Err(zerocopy::SizeError { .. }) => {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "failed to parse ICMP header from provided bytes",
                        )))
                    }
                };
            header.type_ = I::ECHO_REPLY_TYPE;
            let len = buf.len();
            let () = self.buffer.borrow_mut().push_back((buf, addr.clone()));
            Poll::Ready(Ok(len))
        }

        fn bind_device(&self, interface: Option<&[u8]>) -> std::io::Result<()> {
            panic!("unexpected call to bind_device({:?})", interface);
        }
    }

    trait IpExt: super::IpExt {
        // NB: This is only a function because there is no way to create a constant for any of the
        // socket address types.
        fn test_addr() -> Self::SockAddr;
    }

    impl IpExt for Ipv4 {
        fn test_addr() -> Self::SockAddr {
            // A port must be specified in the socket addr, but it is irrelevant for ICMP sockets,
            // so just set it to 0.
            std_socket_addr_v4!("1.2.3.4:0")
        }
    }

    impl IpExt for Ipv6 {
        fn test_addr() -> Self::SockAddr {
            // A port must be specified in the socket addr, but it is irrelevant for ICMP sockets,
            // so just set it to 0.
            std_socket_addr_v6!("[abcd::1]:0")
        }
    }

    const PING_MESSAGE: &str = "Hello from ping library unit test!";
    const PING_COUNT: u16 = 3;
    const PING_SEQ_RANGE: std::ops::RangeInclusive<u16> = 1..=PING_COUNT;

    #[test]
    fn test_ipv4() {
        test_ping::<Ipv4>();
    }

    #[test]
    fn test_ipv6() {
        test_ping::<Ipv6>();
    }

    fn test_ping<I: IpExt>() {
        let socket = FakeSocket::<I>::new();

        let packets = PING_SEQ_RANGE
            .map(|sequence| PingData {
                addr: I::test_addr(),
                sequence,
                body: PING_MESSAGE.as_bytes().to_vec(),
            })
            .collect::<Vec<_>>();
        let packet_stream = futures::stream::iter(packets.iter().cloned());
        let () = PingSink::new(&socket)
            .send_all(&mut packet_stream.map(Ok))
            .now_or_never()
            .expect("ping request send blocked unexpectedly")
            .expect("ping send error");

        let replies =
            PingStream::<_, _, { PING_MESSAGE.len() + std::mem::size_of::<IcmpHeader>() }>::new(
                &socket,
            )
            .take(PING_COUNT.into())
            .try_collect::<Vec<_>>()
            .now_or_never()
            .expect("ping reply stream blocked unexpectedly")
            .expect("failed to collect ping reply stream");
        assert_eq!(packets, replies);
    }

    #[test]
    fn test_verify_packet_errors_ipv4() {
        test_verify_packet_errors::<Ipv4>();
    }

    #[test]
    fn test_verify_packet_errors_ipv6() {
        test_verify_packet_errors::<Ipv6>();
    }

    // Exercises the rejection paths of `verify_packet`, which the round-trip test above never
    // reaches because `FakeSocket` only produces well-formed echo replies.
    fn test_verify_packet_errors<I: IpExt>() {
        // Too short to hold an ICMP header.
        let short = [0u8; std::mem::size_of::<IcmpHeader>() - 1];
        assert!(matches!(verify_packet::<I>(I::test_addr(), &short), Err(PingError::Parse)));

        // An echo request is not an acceptable reply.
        let request = IcmpHeader::new::<I>(1);
        assert!(matches!(
            verify_packet::<I>(I::test_addr(), request.as_bytes()),
            Err(PingError::ReplyType { got, want })
                if got == I::ECHO_REQUEST_TYPE && want == I::ECHO_REPLY_TYPE
        ));

        // An echo reply with a non-zero code is rejected.
        let mut reply = IcmpHeader::new::<I>(1);
        reply.type_ = I::ECHO_REPLY_TYPE;
        reply.code = 1;
        assert!(matches!(
            verify_packet::<I>(I::test_addr(), reply.as_bytes()),
            Err(PingError::ReplyCode(1))
        ));
    }
}