//! ICMP echo ("ping") helpers.
//!
//! This crate provides [`PingSink`] and [`PingStream`], which send ICMP echo
//! requests and receive ICMP echo replies over any socket implementing
//! [`IcmpSocket`], for both IPv4 and IPv6 (see [`IpExt`]).

#![deny(missing_docs)]
6
7#[cfg(target_os = "fuchsia")]
21mod fuchsia;
22
23#[cfg(target_os = "fuchsia")]
24pub use fuchsia::{new_icmp_socket, IpExt as FuchsiaIpExt};
25
26use futures::{ready, Sink, SinkExt as _, Stream, TryStreamExt as _};
27use net_types::ip::{Ip, Ipv4, Ipv6};
28use std::marker::PhantomData;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use thiserror::Error;
32use zerocopy::byteorder::network_endian::U16;
33use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned};
34
/// Size, in bytes, of the ICMP echo header that precedes the message body.
pub const ICMP_HEADER_LEN: usize = std::mem::size_of::<IcmpHeader>();
37
/// Header of an ICMP echo request/reply as it is laid out in a packet.
///
/// The struct is `repr(C)` and derives the zerocopy traits so it can be read
/// from, and written to, raw packet bytes without copying. Multi-byte fields
/// are stored in network byte order via [`U16`].
#[repr(C)]
#[derive(KnownLayout, FromBytes, IntoBytes, Immutable, Unaligned, Debug, PartialEq, Eq, Clone)]
struct IcmpHeader {
    // ICMP message type; see `IpExt::ECHO_REQUEST_TYPE` / `IpExt::ECHO_REPLY_TYPE`.
    type_: u8,
    // ICMP message code; echo requests are built with 0 and replies must carry 0.
    code: u8,
    // Message checksum. This file always writes 0 and never inspects it.
    checksum: U16,
    // Echo identifier. This file always writes 0 and never inspects it.
    id: U16,
    // Echo sequence number, chosen by the sender and echoed back in the reply.
    sequence: U16,
}
48
49impl IcmpHeader {
50 fn new<I: IpExt>(sequence: u16) -> Self {
51 Self {
52 type_: I::ECHO_REQUEST_TYPE,
53 code: 0,
54 checksum: 0.into(),
55 id: 0.into(),
56 sequence: sequence.into(),
57 }
58 }
59}
60
/// Errors produced while sending echo requests or receiving echo replies.
#[derive(Debug, Error)]
pub enum PingError {
    /// The socket reported an error while sending an echo request.
    #[error("send error")]
    Send(#[source] std::io::Error),
    /// The socket reported sending a different number of bytes than the
    /// request (header plus body) contains.
    #[error("wrong number of bytes sent, got: {got}, want: {want}")]
    SendLength {
        /// Number of bytes the socket reported as sent.
        got: usize,
        /// Number of bytes that should have been sent.
        want: usize,
    },
    /// The socket reported an error while receiving a message.
    #[error("recv error")]
    Recv(#[source] std::io::Error),
    /// The received message was too short to contain an ICMP header.
    #[error("failed to parse ICMP header")]
    Parse,
    /// The received message's ICMP type is not the echo reply type.
    #[error("wrong reply type, got: {got}, want: {want}")]
    ReplyType {
        /// ICMP type found in the received message.
        got: u8,
        /// ICMP echo reply type expected for the IP version in use.
        want: u8,
    },
    /// The received message's ICMP code is not zero; carries the code found.
    #[error("non-zero reply code: {0}")]
    ReplyCode(u8),
    /// The body of the received reply differs from the expected body.
    #[error("reply message body mismatch, got: {got:?}, want: {want:?}")]
    Body {
        /// Body found in the received reply.
        got: Vec<u8>,
        /// Body that was expected.
        want: Vec<u8>,
    },
}
101
/// Fallible conversion from a generic [`socket2::SockAddr`] into an
/// address-family-specific socket address type.
pub trait TryFromSockAddr: Sized {
    /// Converts `value` into `Self`, returning an I/O error when the address
    /// cannot be represented as `Self`.
    fn try_from(value: socket2::SockAddr) -> std::io::Result<Self>;
}
110
111impl TryFromSockAddr for std::net::SocketAddrV4 {
112 fn try_from(addr: socket2::SockAddr) -> std::io::Result<Self> {
113 addr.as_socket_ipv4().ok_or_else(|| {
114 std::io::Error::new(
115 std::io::ErrorKind::Other,
116 format!("socket address is not v4 {:?}", addr),
117 )
118 })
119 }
120}
121
122impl TryFromSockAddr for std::net::SocketAddrV6 {
123 fn try_from(addr: socket2::SockAddr) -> std::io::Result<Self> {
124 addr.as_socket_ipv6().ok_or_else(|| {
125 std::io::Error::new(
126 std::io::ErrorKind::Other,
127 format!("socket address is not v6 {:?}", addr),
128 )
129 })
130 }
131}
132
/// Per-IP-version parameters needed to open an ICMP socket and to build and
/// recognize echo messages.
pub trait IpExt: Ip + Unpin {
    /// Socket address type for this IP version.
    ///
    /// It converts into, and fallibly from, a [`socket2::SockAddr`].
    type SockAddr: Into<socket2::SockAddr>
        + TryFromSockAddr
        + Clone
        + Copy
        + Unpin
        + PartialEq
        + std::fmt::Debug
        + std::fmt::Display
        + Eq
        + std::hash::Hash;

    /// Socket domain (address family) for this IP version.
    const DOMAIN: socket2::Domain;
    /// ICMP protocol for this IP version.
    const PROTOCOL: socket2::Protocol;

    /// ICMP message type that identifies an echo request.
    const ECHO_REQUEST_TYPE: u8;
    /// ICMP message type that identifies an echo reply.
    const ECHO_REPLY_TYPE: u8;
}
157
// IPv4 parameters: ICMPv4 over an IPv4 socket.
impl IpExt for Ipv4 {
    type SockAddr = std::net::SocketAddrV4;

    const DOMAIN: socket2::Domain = socket2::Domain::IPV4;
    const PROTOCOL: socket2::Protocol = socket2::Protocol::ICMPV4;

    // ICMPv4 message type numbers for echo request and echo reply (RFC 792).
    const ECHO_REQUEST_TYPE: u8 = 8;
    const ECHO_REPLY_TYPE: u8 = 0;
}
169
// IPv6 parameters: ICMPv6 over an IPv6 socket.
impl IpExt for Ipv6 {
    type SockAddr = std::net::SocketAddrV6;

    const DOMAIN: socket2::Domain = socket2::Domain::IPV6;
    const PROTOCOL: socket2::Protocol = socket2::Protocol::ICMPV6;

    // ICMPv6 message type numbers for echo request and echo reply (RFC 4443).
    const ECHO_REQUEST_TYPE: u8 = 128;
    const ECHO_REPLY_TYPE: u8 = 129;
}
181
/// Poll-based ICMP socket interface that [`PingSink`] and [`PingStream`] are
/// built on.
pub trait IcmpSocket<I>: Unpin
where
    I: IpExt,
{
    /// Polls for one incoming message, writing it into `buf`.
    ///
    /// When ready with `Ok`, yields the number of bytes of `buf` that hold the
    /// received message together with the address it was received from.
    fn async_recv_from(
        &self,
        buf: &mut [u8],
        cx: &mut Context<'_>,
    ) -> Poll<std::io::Result<(usize, I::SockAddr)>>;

    /// Polls to send the concatenation of `bufs` as one message to `addr`.
    ///
    /// When ready with `Ok`, yields the number of bytes sent.
    fn async_send_to_vectored(
        &self,
        bufs: &[std::io::IoSlice<'_>],
        addr: &I::SockAddr,
        cx: &mut Context<'_>,
    ) -> Poll<std::io::Result<usize>>;

    /// Binds the socket to the interface named by `interface`.
    // NOTE(review): presumably `None` removes an existing binding — confirm
    // against the implementations of this trait.
    fn bind_device(&self, interface: Option<&[u8]>) -> std::io::Result<()>;
}
212
/// One ICMP echo message: the item sent through [`PingSink`] and yielded by
/// [`PingStream`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PingData<I: IpExt> {
    /// Destination address for a request; source address for a received reply.
    pub addr: I::SockAddr,
    /// Echo sequence number carried in the ICMP header.
    pub sequence: u16,
    /// Message body that follows the ICMP header.
    pub body: Vec<u8>,
}
223
224pub fn new_unicast_sink_and_stream<'a, I, S, const N: usize>(
232 socket: &'a S,
233 addr: &'a I::SockAddr,
234 body: &'a [u8],
235) -> (impl Sink<u16, Error = PingError> + 'a, impl Stream<Item = Result<u16, PingError>> + 'a)
236where
237 I: IpExt,
238 S: IcmpSocket<I>,
239{
240 (
241 PingSink::new(socket).with(move |sequence| {
242 futures::future::ok(PingData { addr: addr.clone(), sequence, body: body.to_vec() })
243 }),
244 PingStream::<I, S, N>::new(socket).try_filter_map(
245 move |PingData { addr: got_addr, sequence, body: got_body }| {
246 futures::future::ready(if got_addr == *addr {
247 if got_body == body {
248 Ok(Some(sequence))
249 } else {
250 Err(PingError::Body { got: got_body, want: body.to_vec() })
251 }
252 } else {
253 Ok(None)
254 })
255 },
256 ),
257 )
258}
259
/// Stream of ICMP echo replies received from a borrowed [`IcmpSocket`].
///
/// `N` is the size, in bytes, of the internal buffer handed to the socket for
/// each receive; it has to hold the ICMP header as well as the body.
// NOTE(review): what happens to messages longer than `N` bytes depends on the
// socket implementation — confirm before relying on truncation.
pub struct PingStream<'a, I, S, const N: usize>
where
    I: IpExt,
    S: IcmpSocket<I>,
{
    // Socket the replies are received from.
    socket: &'a S,
    // Scratch buffer each message is received into before it is parsed.
    recv_buf: [u8; N],
    // Ties the stream to one IP version without storing a value of it.
    _marker: PhantomData<I>,
}
272
273impl<'a, I, S, const N: usize> PingStream<'a, I, S, N>
274where
275 I: IpExt,
276 S: IcmpSocket<I>,
277{
278 pub fn new(socket: &'a S) -> Self {
286 Self { socket, recv_buf: [0; N], _marker: PhantomData::<I> }
287 }
288}
289
290impl<'a, I, S, const N: usize> futures::stream::Stream for PingStream<'a, I, S, N>
291where
292 I: IpExt,
293 S: IcmpSocket<I>,
294{
295 type Item = Result<PingData<I>, PingError>;
296
297 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
298 let ping_stream = Pin::into_inner(self);
299 let buf = &mut ping_stream.recv_buf[..];
300 let socket = &ping_stream.socket;
301 Poll::Ready(Some(
302 ready!(socket.async_recv_from(buf, cx))
303 .map_err(PingError::Recv)
304 .and_then(|(len, addr)| verify_packet::<I>(addr, &ping_stream.recv_buf[..len])),
305 ))
306 }
307}
308
/// Sink that sends each [`PingData`] it is given as an ICMP echo request over
/// a borrowed [`IcmpSocket`].
pub struct PingSink<'a, I, S>
where
    I: IpExt,
    S: IcmpSocket<I>,
{
    // Socket the requests are sent on.
    socket: &'a S,
    // The single request (destination, header, body) accepted by `start_send`
    // and not yet sent successfully; `None` when nothing is pending.
    packet: Option<(I::SockAddr, IcmpHeader, Vec<u8>)>,
    // Ties the sink to one IP version without storing a value of it.
    _marker: PhantomData<I>,
}
319
320impl<'a, I, S> PingSink<'a, I, S>
321where
322 I: IpExt,
323 S: IcmpSocket<I>,
324{
325 pub fn new(socket: &'a S) -> Self {
327 Self { socket, packet: None, _marker: PhantomData::<I> }
328 }
329}
330
331impl<'a, I, S> futures::sink::Sink<PingData<I>> for PingSink<'a, I, S>
332where
333 I: IpExt,
334 S: IcmpSocket<I>,
335{
336 type Error = PingError;
337
338 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
339 self.poll_flush(cx)
340 }
341
342 fn start_send(
343 mut self: Pin<&mut Self>,
344 PingData { addr, sequence, body }: PingData<I>,
345 ) -> Result<(), Self::Error> {
346 let header = IcmpHeader::new::<I>(sequence);
347 assert_eq!(
348 self.packet.replace((addr, header, body)),
349 None,
350 "start_send called while element has yet to be flushed"
351 );
352 Ok(())
353 }
354
355 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
356 Poll::Ready(match &self.packet {
357 Some((addr, header, body)) => {
358 match ready!(self.socket.async_send_to_vectored(
359 &[
360 std::io::IoSlice::new(header.as_bytes()),
361 std::io::IoSlice::new(body.as_bytes()),
362 ],
363 addr,
364 cx
365 )) {
366 Ok(got) => {
367 let want = std::mem::size_of_val(&header) + body.len();
368 if got != want {
369 Err(PingError::SendLength { got, want })
370 } else {
371 self.packet = None;
372 Ok(())
373 }
374 }
375 Err(e) => Err(PingError::Send(e)),
376 }
377 }
378 None => Ok(()),
379 })
380 }
381
382 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
383 self.poll_flush(cx)
384 }
385}
386
387fn verify_packet<I: IpExt>(addr: I::SockAddr, packet: &[u8]) -> Result<PingData<I>, PingError> {
388 let (reply, body): (zerocopy::Ref<_, IcmpHeader>, _) = zerocopy::Ref::from_prefix(packet)
389 .map_err(Into::into)
390 .map_err(|_: zerocopy::SizeError<_, _>| PingError::Parse)?;
391
392 let &IcmpHeader { type_, code, checksum: _, id: _, sequence } = zerocopy::Ref::into_ref(reply);
400
401 if type_ != I::ECHO_REPLY_TYPE {
402 return Err(PingError::ReplyType { got: type_, want: I::ECHO_REPLY_TYPE });
403 }
404
405 if code != 0 {
406 return Err(PingError::ReplyCode(code));
407 }
408
409 Ok(PingData { addr, sequence: sequence.into(), body: body.to_vec() })
410}
411
#[cfg(test)]
mod test {
    use super::{IcmpHeader, IcmpSocket, Ipv4, Ipv6, PingData, PingSink, PingStream};

    use futures::{FutureExt as _, SinkExt as _, StreamExt as _, TryStreamExt as _};
    use net_declare::{std_socket_addr_v4, std_socket_addr_v6};
    use std::cell::RefCell;
    use std::collections::VecDeque;
    use std::task::{Context, Poll};
    use zerocopy::IntoBytes as _;

    /// In-memory stand-in for an ICMP socket.
    ///
    /// Every message "sent" has its ICMP type rewritten to the echo reply type
    /// and is queued, so that a later receive hands it back as the reply.
    #[derive(Default, Debug)]
    struct FakeSocket<I: IpExt> {
        // Queued replies, oldest first, each with the address the originating
        // request was sent to.
        buffer: RefCell<VecDeque<(Vec<u8>, I::SockAddr)>>,
    }

    impl<I: IpExt> FakeSocket<I> {
        /// Creates a fake socket with no queued replies.
        fn new() -> Self {
            let buffer = RefCell::new(VecDeque::new());
            Self { buffer }
        }
    }

    impl<I: IpExt> IcmpSocket<I> for FakeSocket<I> {
        /// Pops the oldest queued reply into `buf`.
        ///
        /// Always ready: an empty queue is reported as a `WouldBlock` error
        /// and a too-small `buf` as an `Other` error (the reply is consumed
        /// either way).
        fn async_recv_from(
            &self,
            buf: &mut [u8],
            _cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<(usize, I::SockAddr)>> {
            let outcome = match self.buffer.borrow_mut().pop_front() {
                None => Err(std::io::Error::new(
                    std::io::ErrorKind::WouldBlock,
                    "fake socket request buffer is empty",
                )),
                Some((reply, _)) if buf.len() < reply.len() => Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("recv buffer too small, got: {}, want: {}", buf.len(), reply.len()),
                )),
                Some((reply, addr)) => {
                    buf[..reply.len()].copy_from_slice(&reply);
                    Ok((reply.len(), addr))
                }
            };
            Poll::Ready(outcome)
        }

        /// Concatenates `bufs`, turns the message into an echo reply and
        /// queues it for a later receive; reports the full length as sent.
        fn async_send_to_vectored(
            &self,
            bufs: &[std::io::IoSlice<'_>],
            addr: &I::SockAddr,
            _cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<usize>> {
            let mut message: Vec<u8> =
                bufs.iter().flat_map(|io_slice| io_slice.as_bytes()).copied().collect();
            let (mut header, _): (zerocopy::Ref<_, IcmpHeader>, _) =
                match zerocopy::Ref::from_prefix(&mut message[..]).map_err(Into::into) {
                    Ok(layout_verified) => layout_verified,
                    Err(zerocopy::SizeError { .. }) => {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "failed to parse ICMP header from provided bytes",
                        )))
                    }
                };
            // Rewrite the request in place so it reads back as a reply.
            header.type_ = I::ECHO_REPLY_TYPE;
            let sent = message.len();
            self.buffer.borrow_mut().push_back((message, addr.clone()));
            Poll::Ready(Ok(sent))
        }

        /// Not expected to be called by these tests.
        fn bind_device(&self, interface: Option<&[u8]>) -> std::io::Result<()> {
            panic!("unexpected call to bind_device({:?})", interface);
        }
    }

    /// Test-only extension that supplies a fixed address per IP version.
    trait IpExt: super::IpExt {
        /// Address used as the ping destination (and reply source) in tests.
        fn test_addr() -> Self::SockAddr;
    }

    impl IpExt for Ipv4 {
        fn test_addr() -> Self::SockAddr {
            std_socket_addr_v4!("1.2.3.4:0")
        }
    }

    impl IpExt for Ipv6 {
        fn test_addr() -> Self::SockAddr {
            std_socket_addr_v6!("[abcd::1]:0")
        }
    }

    const PING_MESSAGE: &str = "Hello from ping library unit test!";
    const PING_COUNT: u16 = 3;
    const PING_SEQ_RANGE: std::ops::RangeInclusive<u16> = 1..=PING_COUNT;

    #[test]
    fn test_ipv4() {
        test_ping::<Ipv4>();
    }

    #[test]
    fn test_ipv6() {
        test_ping::<Ipv6>();
    }

    /// Sends `PING_COUNT` requests through a `PingSink` and checks that a
    /// `PingStream` on the same fake socket yields the matching replies.
    fn test_ping<I: IpExt>() {
        let socket = FakeSocket::<I>::new();

        let packets: Vec<_> = PING_SEQ_RANGE
            .into_iter()
            .map(|sequence| PingData {
                addr: I::test_addr(),
                sequence,
                body: PING_MESSAGE.as_bytes().to_vec(),
            })
            .collect();

        // The fake socket never returns `Pending`, so sending must complete
        // on the first poll.
        let mut requests = futures::stream::iter(packets.iter().cloned()).map(Ok);
        let sent = PingSink::new(&socket).send_all(&mut requests).now_or_never();
        let () = sent.expect("ping request send blocked unexpectedly").expect("ping send error");

        // The receive buffer is sized to hold exactly one header plus the body.
        let reply_stream =
            PingStream::<_, _, { PING_MESSAGE.len() + std::mem::size_of::<IcmpHeader>() }>::new(
                &socket,
            );
        let replies = reply_stream
            .take(PING_COUNT.into())
            .try_collect::<Vec<_>>()
            .now_or_never()
            .expect("ping reply stream blocked unexpectedly")
            .expect("failed to collect ping reply stream");
        assert_eq!(packets, replies);
    }
}