ping/
lib.rs

1// Copyright 2021 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![deny(missing_docs)]
6
7//! Helpers to ping an IPv4 or IPv6 address by sending ICMP echo requests and
8//! waiting for ICMP echo replies.
9//!
10//! Functionality in this crate relies on [ICMP sockets], a kind of socket where
11//! each payload read/written contains ICMP headers.
12//!
13//! As a starting point, see [`new_unicast_sink_and_stream`], which is built
14//! on top of the other facilities in the crate and models pinging as sending
15//! an ICMP echo request whenever a value is sent to the sink, and a stream
16//! which yields an item for every echo reply received.
17//!
18//! [ICMP sockets]: https://lwn.net/Articles/422330/
19
20#[cfg(target_os = "fuchsia")]
21mod fuchsia;
22
23#[cfg(target_os = "fuchsia")]
24pub use fuchsia::{new_icmp_socket, IpExt as FuchsiaIpExt};
25
26use futures::{ready, Sink, SinkExt as _, Stream, TryStreamExt as _};
27use net_types::ip::{Ip, Ipv4, Ipv6};
28use std::marker::PhantomData;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use thiserror::Error;
32use zerocopy::byteorder::network_endian::U16;
33use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned};
34
35/// The number of bytes of an ICMP (v4 or v6) header.
36pub const ICMP_HEADER_LEN: usize = std::mem::size_of::<IcmpHeader>();
37
38/// ICMP header representation.
39#[repr(C)]
40#[derive(KnownLayout, FromBytes, IntoBytes, Immutable, Unaligned, Debug, PartialEq, Eq, Clone)]
41struct IcmpHeader {
42    type_: u8,
43    code: u8,
44    checksum: U16,
45    id: U16,
46    sequence: U16,
47}
48
49impl IcmpHeader {
50    fn new<I: IpExt>(sequence: u16) -> Self {
51        Self {
52            type_: I::ECHO_REQUEST_TYPE,
53            code: 0,
54            checksum: 0.into(),
55            id: 0.into(),
56            sequence: sequence.into(),
57        }
58    }
59}
60
61/// Ping error.
62#[derive(Debug, Error)]
63pub enum PingError {
64    /// Send error.
65    #[error("send error")]
66    Send(#[source] std::io::Error),
67    /// Send length mismatch.
68    #[error("wrong number of bytes sent, got: {got}, want: {want}")]
69    SendLength {
70        /// Number of bytes sent.
71        got: usize,
72        /// Number of bytes expected to be sent.
73        want: usize,
74    },
75    /// Recv error.
76    #[error("recv error")]
77    Recv(#[source] std::io::Error),
78    /// ICMP header parsing error.
79    #[error("failed to parse ICMP header")]
80    Parse,
81    /// Reply type mismatch.
82    #[error("wrong reply type, got: {got}, want: {want}")]
83    ReplyType {
84        /// ICMP type received in reply.
85        got: u8,
86        /// ICMP type expected in reply.
87        want: u8,
88    },
89    /// Reply code mismatch.
90    #[error("non-zero reply code: {0}")]
91    ReplyCode(u8),
92    /// ICMP message body mismatch.
93    #[error("reply message body mismatch, got: {got:?}, want: {want:?}")]
94    Body {
95        /// Body received in reply.
96        got: Vec<u8>,
97        /// Body expected in reply.
98        want: Vec<u8>,
99    },
100}
101
102/// Addresses which can be converted from `socket2::SockAddr`.
103///
104/// This trait exists to get around not being able to implement the foreign trait
105/// `TryFrom<socket2::SockAddr>` for the foreign types `std::net::SocketAddr(V4|V6)?`.
106pub trait TryFromSockAddr: Sized {
107    /// Try to convert from `socket2::SockAddr`.
108    fn try_from(value: socket2::SockAddr) -> std::io::Result<Self>;
109}
110
111impl TryFromSockAddr for std::net::SocketAddrV4 {
112    fn try_from(addr: socket2::SockAddr) -> std::io::Result<Self> {
113        addr.as_socket_ipv4().ok_or_else(|| {
114            std::io::Error::new(
115                std::io::ErrorKind::Other,
116                format!("socket address is not v4 {:?}", addr),
117            )
118        })
119    }
120}
121
122impl TryFromSockAddr for std::net::SocketAddrV6 {
123    fn try_from(addr: socket2::SockAddr) -> std::io::Result<Self> {
124        addr.as_socket_ipv6().ok_or_else(|| {
125            std::io::Error::new(
126                std::io::ErrorKind::Other,
127                format!("socket address is not v6 {:?}", addr),
128            )
129        })
130    }
131}
132
133/// Trait for IP protocol versions.
134pub trait IpExt: Ip + Unpin {
135    /// IP Socket address type.
136    type SockAddr: Into<socket2::SockAddr>
137        + TryFromSockAddr
138        + Clone
139        + Copy
140        + Unpin
141        + PartialEq
142        + std::fmt::Debug
143        + std::fmt::Display
144        + Eq
145        + std::hash::Hash;
146
147    /// ICMP socket domain.
148    const DOMAIN: socket2::Domain;
149    /// ICMP socket protocol.
150    const PROTOCOL: socket2::Protocol;
151
152    /// ICMP echo request type.
153    const ECHO_REQUEST_TYPE: u8;
154    /// ICMP echo reply type.
155    const ECHO_REPLY_TYPE: u8;
156}
157
158// TODO(https://fxbug.dev/323955204): Implement ext trait on net_types::ip::Ipv4
159// instead and remove the Ipv4 type.
160impl IpExt for Ipv4 {
161    type SockAddr = std::net::SocketAddrV4;
162
163    const DOMAIN: socket2::Domain = socket2::Domain::IPV4;
164    const PROTOCOL: socket2::Protocol = socket2::Protocol::ICMPV4;
165
166    const ECHO_REQUEST_TYPE: u8 = 8;
167    const ECHO_REPLY_TYPE: u8 = 0;
168}
169
170// TODO(https://fxbug.dev/323955204): Implement ext trait on net_types::ip::Ipv6
171// instead and remove the Ipv6 type.
172impl IpExt for Ipv6 {
173    type SockAddr = std::net::SocketAddrV6;
174
175    const DOMAIN: socket2::Domain = socket2::Domain::IPV6;
176    const PROTOCOL: socket2::Protocol = socket2::Protocol::ICMPV6;
177
178    const ECHO_REQUEST_TYPE: u8 = 128;
179    const ECHO_REPLY_TYPE: u8 = 129;
180}
181
182/// Async ICMP socket.
183pub trait IcmpSocket<I>: Unpin
184where
185    I: IpExt,
186{
187    /// Async method for receiving an ICMP packet.
188    ///
189    /// Upon successful return, `buf` will contain an ICMP packet.
190    fn async_recv_from(
191        &self,
192        buf: &mut [u8],
193        cx: &mut Context<'_>,
194    ) -> Poll<std::io::Result<(usize, I::SockAddr)>>;
195
196    /// Async method for sending an ICMP packet.
197    ///
198    /// `bufs` must contain a valid ICMP packet.
199    fn async_send_to_vectored(
200        &self,
201        bufs: &[std::io::IoSlice<'_>],
202        addr: &I::SockAddr,
203        cx: &mut Context<'_>,
204    ) -> Poll<std::io::Result<usize>>;
205
206    /// Binds this to an interface so that packets can only flow in/out via the specified
207    /// interface.
208    ///
209    /// If `interface` is `None`, the binding is removed.
210    fn bind_device(&self, interface: Option<&[u8]>) -> std::io::Result<()>;
211}
212
213/// Parameters of a ping request/reply.
214#[derive(Clone, Debug, PartialEq, Eq, Hash)]
215pub struct PingData<I: IpExt> {
216    /// The destination address of a ping request; or the source address of a ping reply.
217    pub addr: I::SockAddr,
218    /// The sequence number in the ICMP header.
219    pub sequence: u16,
220    /// The body of the echo request/reply.
221    pub body: Vec<u8>,
222}
223
224// TODO(https://github.com/rust-lang/rust/issues/76560): Define N as the length of the message body
225// rather than the length of the ICMP packet.
226/// Create a ping sink and stream for pinging a unicast destination with the same body for every
227/// packet.
228///
229/// Echo replies received with a source address not equal to `addr` will be silently dropped. Echo
230/// replies with a body not equal to `body` will result in an error on the stream.
231pub fn new_unicast_sink_and_stream<'a, I, S, const N: usize>(
232    socket: &'a S,
233    addr: &'a I::SockAddr,
234    body: &'a [u8],
235) -> (impl Sink<u16, Error = PingError> + 'a, impl Stream<Item = Result<u16, PingError>> + 'a)
236where
237    I: IpExt,
238    S: IcmpSocket<I>,
239{
240    (
241        PingSink::new(socket).with(move |sequence| {
242            futures::future::ok(PingData { addr: addr.clone(), sequence, body: body.to_vec() })
243        }),
244        PingStream::<I, S, N>::new(socket).try_filter_map(
245            move |PingData { addr: got_addr, sequence, body: got_body }| {
246                futures::future::ready(if got_addr == *addr {
247                    if got_body == body {
248                        Ok(Some(sequence))
249                    } else {
250                        Err(PingError::Body { got: got_body, want: body.to_vec() })
251                    }
252                } else {
253                    Ok(None)
254                })
255            },
256        ),
257    )
258}
259
260// TODO(https://github.com/rust-lang/rust/issues/76560): Define N as the length of the message body
261// rather than the length of the ICMP packet.
262/// Stream of received ping replies.
263pub struct PingStream<'a, I, S, const N: usize>
264where
265    I: IpExt,
266    S: IcmpSocket<I>,
267{
268    socket: &'a S,
269    recv_buf: [u8; N],
270    _marker: PhantomData<I>,
271}
272
273impl<'a, I, S, const N: usize> PingStream<'a, I, S, N>
274where
275    I: IpExt,
276    S: IcmpSocket<I>,
277{
278    /// Construct a stream from an `IcmpSocket`.
279    ///
280    /// `N` must be set to the length of the largest ICMP body expected
281    /// to be received plus the 8 bytes of overhead due to the ICMP
282    /// header, otherwise received packets may be truncated. Note
283    /// that this does not need to include the 8-byte overhead of
284    /// the ICMP header.
285    pub fn new(socket: &'a S) -> Self {
286        Self { socket, recv_buf: [0; N], _marker: PhantomData::<I> }
287    }
288}
289
290impl<'a, I, S, const N: usize> futures::stream::Stream for PingStream<'a, I, S, N>
291where
292    I: IpExt,
293    S: IcmpSocket<I>,
294{
295    type Item = Result<PingData<I>, PingError>;
296
297    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
298        let ping_stream = Pin::into_inner(self);
299        let buf = &mut ping_stream.recv_buf[..];
300        let socket = &ping_stream.socket;
301        Poll::Ready(Some(
302            ready!(socket.async_recv_from(buf, cx))
303                .map_err(PingError::Recv)
304                .and_then(|(len, addr)| verify_packet::<I>(addr, &ping_stream.recv_buf[..len])),
305        ))
306    }
307}
308
309/// Sink for sending ping requests.
310pub struct PingSink<'a, I, S>
311where
312    I: IpExt,
313    S: IcmpSocket<I>,
314{
315    socket: &'a S,
316    packet: Option<(I::SockAddr, IcmpHeader, Vec<u8>)>,
317    _marker: PhantomData<I>,
318}
319
320impl<'a, I, S> PingSink<'a, I, S>
321where
322    I: IpExt,
323    S: IcmpSocket<I>,
324{
325    /// Construct a sink from an `IcmpSocket`.
326    pub fn new(socket: &'a S) -> Self {
327        Self { socket, packet: None, _marker: PhantomData::<I> }
328    }
329}
330
331impl<'a, I, S> futures::sink::Sink<PingData<I>> for PingSink<'a, I, S>
332where
333    I: IpExt,
334    S: IcmpSocket<I>,
335{
336    type Error = PingError;
337
338    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
339        self.poll_flush(cx)
340    }
341
342    fn start_send(
343        mut self: Pin<&mut Self>,
344        PingData { addr, sequence, body }: PingData<I>,
345    ) -> Result<(), Self::Error> {
346        let header = IcmpHeader::new::<I>(sequence);
347        assert_eq!(
348            self.packet.replace((addr, header, body)),
349            None,
350            "start_send called while element has yet to be flushed"
351        );
352        Ok(())
353    }
354
355    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
356        Poll::Ready(match &self.packet {
357            Some((addr, header, body)) => {
358                match ready!(self.socket.async_send_to_vectored(
359                    &[
360                        std::io::IoSlice::new(header.as_bytes()),
361                        std::io::IoSlice::new(body.as_bytes()),
362                    ],
363                    addr,
364                    cx
365                )) {
366                    Ok(got) => {
367                        let want = std::mem::size_of_val(&header) + body.len();
368                        if got != want {
369                            Err(PingError::SendLength { got, want })
370                        } else {
371                            self.packet = None;
372                            Ok(())
373                        }
374                    }
375                    Err(e) => Err(PingError::Send(e)),
376                }
377            }
378            None => Ok(()),
379        })
380    }
381
382    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
383        self.poll_flush(cx)
384    }
385}
386
387fn verify_packet<I: IpExt>(addr: I::SockAddr, packet: &[u8]) -> Result<PingData<I>, PingError> {
388    let (reply, body): (zerocopy::Ref<_, IcmpHeader>, _) = zerocopy::Ref::from_prefix(packet)
389        .map_err(Into::into)
390        .map_err(|_: zerocopy::SizeError<_, _>| PingError::Parse)?;
391
392    // The identifier cannot be verified, since ICMP socket implementations rewrites the field on
393    // send and uses its value to demultiplex packets for delivery to sockets on receive.
394    //
395    // Also, don't bother verifying the checksum, since ICMP socket implementations must have
396    // verified the checksum since the code and identifier fields must be inspected. Also, the
397    // ICMPv6 checksum computation includes a pseudo header which includes the src and dst
398    // addresses, and the dst/local address is not readily available.
399    let &IcmpHeader { type_, code, checksum: _, id: _, sequence } = zerocopy::Ref::into_ref(reply);
400
401    if type_ != I::ECHO_REPLY_TYPE {
402        return Err(PingError::ReplyType { got: type_, want: I::ECHO_REPLY_TYPE });
403    }
404
405    if code != 0 {
406        return Err(PingError::ReplyCode(code));
407    }
408
409    Ok(PingData { addr, sequence: sequence.into(), body: body.to_vec() })
410}
411
#[cfg(test)]
mod test {
    use super::{IcmpHeader, IcmpSocket, Ipv4, Ipv6, PingData, PingSink, PingStream};

    use futures::{FutureExt as _, SinkExt as _, StreamExt as _, TryStreamExt as _};
    use net_declare::{std_socket_addr_v4, std_socket_addr_v6};
    use std::cell::RefCell;
    use std::collections::VecDeque;
    use std::task::{Context, Poll};
    use zerocopy::IntoBytes as _;

    // Fake `IcmpSocket` that loops requests back as replies: every packet passed to
    // `async_send_to_vectored` is turned into an echo reply and queued, and `async_recv_from`
    // hands the queued replies back in FIFO order.
    #[derive(Default, Debug)]
    struct FakeSocket<I: IpExt> {
        // NB: the `IcmpSocket` methods take `&self`, so the queue needs interior mutability.
        buffer: RefCell<VecDeque<(Vec<u8>, I::SockAddr)>>,
    }

    impl<I: IpExt> FakeSocket<I> {
        fn new() -> Self {
            let buffer = RefCell::new(VecDeque::new());
            Self { buffer }
        }
    }

    impl<I: IpExt> IcmpSocket<I> for FakeSocket<I> {
        fn async_recv_from(
            &self,
            buf: &mut [u8],
            _cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<(usize, I::SockAddr)>> {
            let next_reply = self.buffer.borrow_mut().pop_front();
            let result = match next_reply {
                None => Err(std::io::Error::new(
                    std::io::ErrorKind::WouldBlock,
                    "fake socket request buffer is empty",
                )),
                Some((reply, _)) if buf.len() < reply.len() => Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("recv buffer too small, got: {}, want: {}", buf.len(), reply.len()),
                )),
                Some((reply, addr)) => {
                    let len = reply.len();
                    buf[..len].copy_from_slice(&reply);
                    Ok((len, addr))
                }
            };
            Poll::Ready(result)
        }

        fn async_send_to_vectored(
            &self,
            bufs: &[std::io::IoSlice<'_>],
            addr: &I::SockAddr,
            _cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<usize>> {
            // Gather the scattered request into one contiguous packet.
            let mut packet = Vec::new();
            for io_slice in bufs {
                packet.extend_from_slice(io_slice.as_bytes());
            }

            let parsed: Result<
                (zerocopy::Ref<_, IcmpHeader>, _),
                zerocopy::SizeError<_, _>,
            > = zerocopy::Ref::from_prefix(&mut packet[..]).map_err(Into::into);
            let mut header = match parsed {
                Ok((header, _body)) => header,
                Err(zerocopy::SizeError { .. }) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "failed to parse ICMP header from provided bytes",
                    )))
                }
            };
            // Rewriting the type is all it takes to turn the request into its reply.
            header.type_ = I::ECHO_REPLY_TYPE;

            let sent = packet.len();
            let () = self.buffer.borrow_mut().push_back((packet, *addr));
            Poll::Ready(Ok(sent))
        }

        fn bind_device(&self, interface: Option<&[u8]>) -> std::io::Result<()> {
            panic!("unexpected call to bind_device({:?})", interface);
        }
    }

    trait IpExt: super::IpExt {
        // NB: a function rather than a constant, because there is no way to create a constant
        // for any of the socket address types.
        fn test_addr() -> Self::SockAddr;
    }

    impl IpExt for Ipv4 {
        fn test_addr() -> Self::SockAddr {
            // The socket addr needs a port, but ports are irrelevant for ICMP sockets, so use 0.
            std_socket_addr_v4!("1.2.3.4:0")
        }
    }

    impl IpExt for Ipv6 {
        fn test_addr() -> Self::SockAddr {
            // The socket addr needs a port, but ports are irrelevant for ICMP sockets, so use 0.
            std_socket_addr_v6!("[abcd::1]:0")
        }
    }

    const PING_MESSAGE: &str = "Hello from ping library unit test!";
    const PING_COUNT: u16 = 3;
    const PING_SEQ_RANGE: std::ops::RangeInclusive<u16> = 1..=PING_COUNT;

    #[test]
    fn test_ipv4() {
        test_ping::<Ipv4>();
    }

    #[test]
    fn test_ipv6() {
        test_ping::<Ipv6>();
    }

    // Sends `PING_COUNT` requests through a `PingSink`, then checks that a `PingStream` on the
    // same fake socket yields matching replies in the same order.
    fn test_ping<I: IpExt>() {
        // Receive buffer length: the whole ICMP message, i.e. body plus header.
        const REPLY_BUF_LEN: usize = PING_MESSAGE.len() + std::mem::size_of::<IcmpHeader>();

        let socket = FakeSocket::<I>::new();

        let mut packets = Vec::new();
        for sequence in PING_SEQ_RANGE {
            packets.push(PingData {
                addr: I::test_addr(),
                sequence,
                body: PING_MESSAGE.as_bytes().to_vec(),
            });
        }

        let mut requests = futures::stream::iter(packets.iter().cloned()).map(Ok);
        let () = PingSink::new(&socket)
            .send_all(&mut requests)
            .now_or_never()
            .expect("ping request send blocked unexpectedly")
            .expect("ping send error");

        let replies = PingStream::<_, _, REPLY_BUF_LEN>::new(&socket)
            .take(usize::from(PING_COUNT))
            .try_collect::<Vec<_>>()
            .now_or_never()
            .expect("ping reply stream blocked unexpectedly")
            .expect("failed to collect ping reply stream");
        assert_eq!(packets, replies);
    }
}
571}