1#![deny(missing_docs)]
6
7#[cfg(target_os = "fuchsia")]
21mod fuchsia;
22
23#[cfg(target_os = "fuchsia")]
24pub use fuchsia::{new_icmp_socket, IpExt as FuchsiaIpExt};
25
26use futures::{ready, Sink, SinkExt as _, Stream, TryStreamExt as _};
27use net_types::ip::{Ip, Ipv4, Ipv6};
28use std::marker::PhantomData;
29use std::pin::Pin;
30use std::task::{Context, Poll};
31use thiserror::Error;
32use zerocopy::byteorder::network_endian::U16;
33use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned};
34
/// The number of bytes occupied by the ICMP header, i.e. the in-memory size of the header
/// layout that precedes the message body.
pub const ICMP_HEADER_LEN: usize = std::mem::size_of::<IcmpHeader>();
37
/// Header of an ICMP echo message.
///
/// The struct is `repr(C)`, so the fields appear in declaration order, and the derived
/// zerocopy traits allow it to be viewed as, and parsed from, raw bytes. The 16-bit fields
/// use the network-endian `U16` wrapper.
#[repr(C)]
#[derive(KnownLayout, FromBytes, IntoBytes, Immutable, Unaligned, Debug, PartialEq, Eq, Clone)]
struct IcmpHeader {
    /// ICMP message type (see `IpExt::ECHO_REQUEST_TYPE` / `IpExt::ECHO_REPLY_TYPE`).
    type_: u8,
    /// ICMP message code; replies with a non-zero code are rejected by `verify_packet`.
    code: u8,
    /// Message checksum. This library neither computes nor validates it.
    checksum: U16,
    /// Echo identifier. This library neither sets nor inspects it.
    id: U16,
    /// Echo sequence number.
    sequence: U16,
}
48
49impl IcmpHeader {
50 fn new<I: IpExt>(sequence: u16) -> Self {
51 Self {
52 type_: I::ECHO_REQUEST_TYPE,
53 code: 0,
54 checksum: 0.into(),
55 id: 0.into(),
56 sequence: sequence.into(),
57 }
58 }
59}
60
/// Errors produced while sending echo requests or receiving echo replies.
#[derive(Debug, Error)]
pub enum PingError {
    /// The socket reported an error while sending an echo request.
    #[error("send error")]
    Send(#[source] std::io::Error),
    /// The socket reported sending a different number of bytes than the full message.
    #[error("wrong number of bytes sent, got: {got}, want: {want}")]
    SendLength {
        /// Number of bytes the socket reported as sent.
        got: usize,
        /// Number of bytes that were expected to be sent (header plus body).
        want: usize,
    },
    /// The socket reported an error while receiving.
    #[error("recv error")]
    Recv(#[source] std::io::Error),
    /// The received packet could not be parsed as an ICMP header (e.g. it is too short).
    #[error("failed to parse ICMP header")]
    Parse,
    /// The received packet's ICMP type is not the echo reply type for this IP version.
    #[error("wrong reply type, got: {got}, want: {want}")]
    ReplyType {
        /// ICMP type found in the received packet.
        got: u8,
        /// ICMP type that was expected.
        want: u8,
    },
    /// The received packet carries a non-zero ICMP code; the payload is that code.
    #[error("non-zero reply code: {0}")]
    ReplyCode(u8),
    /// The body of the reply differs from the body that was expected.
    #[error("reply message body mismatch, got: {got:?}, want: {want:?}")]
    Body {
        /// Body found in the reply.
        got: Vec<u8>,
        /// Body that was expected.
        want: Vec<u8>,
    },
}
101
/// Fallible conversion from a [`socket2::SockAddr`] into a concrete socket address type.
///
/// Implemented in this file for [`std::net::SocketAddrV4`] and [`std::net::SocketAddrV6`].
pub trait TryFromSockAddr: Sized {
    /// Attempts to convert `value` into `Self`, returning an I/O error if `value` is not an
    /// address of the matching kind.
    fn try_from(value: socket2::SockAddr) -> std::io::Result<Self>;
}
110
111impl TryFromSockAddr for std::net::SocketAddrV4 {
112 fn try_from(addr: socket2::SockAddr) -> std::io::Result<Self> {
113 addr.as_socket_ipv4()
114 .ok_or_else(|| std::io::Error::other(format!("socket address is not v4 {:?}", addr)))
115 }
116}
117
118impl TryFromSockAddr for std::net::SocketAddrV6 {
119 fn try_from(addr: socket2::SockAddr) -> std::io::Result<Self> {
120 addr.as_socket_ipv6()
121 .ok_or_else(|| std::io::Error::other(format!("socket address is not v6 {:?}", addr)))
122 }
123}
124
/// Extension of [`Ip`] that carries the per-IP-version socket and ICMP parameters this
/// library needs.
pub trait IpExt: Ip + Unpin {
    /// Socket address type for this IP version; convertible to and (fallibly) from
    /// [`socket2::SockAddr`].
    type SockAddr: Into<socket2::SockAddr>
        + TryFromSockAddr
        + Clone
        + Copy
        + Unpin
        + PartialEq
        + std::fmt::Debug
        + std::fmt::Display
        + Eq
        + std::hash::Hash;

    /// The [`socket2::Domain`] for this IP version.
    const DOMAIN: socket2::Domain;
    /// The ICMP [`socket2::Protocol`] for this IP version.
    const PROTOCOL: socket2::Protocol;

    /// ICMP message type placed in the header of outgoing echo requests.
    const ECHO_REQUEST_TYPE: u8;
    /// ICMP message type that incoming echo replies must carry.
    const ECHO_REPLY_TYPE: u8;
}
149
// IPv4 parameters: ICMPv4 sockets addressed by `SocketAddrV4`.
impl IpExt for Ipv4 {
    type SockAddr = std::net::SocketAddrV4;

    const DOMAIN: socket2::Domain = socket2::Domain::IPV4;
    const PROTOCOL: socket2::Protocol = socket2::Protocol::ICMPV4;

    // ICMPv4 message types for echo request and echo reply.
    const ECHO_REQUEST_TYPE: u8 = 8;
    const ECHO_REPLY_TYPE: u8 = 0;
}
161
// IPv6 parameters: ICMPv6 sockets addressed by `SocketAddrV6`.
impl IpExt for Ipv6 {
    type SockAddr = std::net::SocketAddrV6;

    const DOMAIN: socket2::Domain = socket2::Domain::IPV6;
    const PROTOCOL: socket2::Protocol = socket2::Protocol::ICMPV6;

    // ICMPv6 message types for echo request and echo reply.
    const ECHO_REQUEST_TYPE: u8 = 128;
    const ECHO_REPLY_TYPE: u8 = 129;
}
173
/// Poll-based ICMP socket abstraction that [`PingSink`] and [`PingStream`] are built on.
pub trait IcmpSocket<I>: Unpin
where
    I: IpExt,
{
    /// Polls for an incoming message, writing it into `buf`.
    ///
    /// On success, resolves to the number of bytes written into `buf` together with the
    /// socket address the message is associated with.
    fn async_recv_from(
        &self,
        buf: &mut [u8],
        cx: &mut Context<'_>,
    ) -> Poll<std::io::Result<(usize, I::SockAddr)>>;

    /// Polls to send the concatenation of `bufs` as one message to `addr`.
    ///
    /// On success, resolves to the number of bytes sent.
    fn async_send_to_vectored(
        &self,
        bufs: &[std::io::IoSlice<'_>],
        addr: &I::SockAddr,
        cx: &mut Context<'_>,
    ) -> Poll<std::io::Result<usize>>;

    /// Binds the socket to the device named by `interface`.
    // NOTE(review): presumably `None` clears an existing binding — confirm against the
    // concrete socket implementations; nothing in this file calls it with `None`.
    fn bind_device(&self, interface: Option<&[u8]>) -> std::io::Result<()>;
}
204
/// The parameters of a single echo request or echo reply.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PingData<I: IpExt> {
    /// Destination address when sending a request; address the reply was received from when
    /// receiving.
    pub addr: I::SockAddr,
    /// ICMP sequence number of the message.
    pub sequence: u16,
    /// Message body, i.e. the bytes following the ICMP header.
    pub body: Vec<u8>,
}
215
216pub fn new_unicast_sink_and_stream<'a, I, S, const N: usize>(
224 socket: &'a S,
225 addr: &'a I::SockAddr,
226 body: &'a [u8],
227) -> (impl Sink<u16, Error = PingError> + 'a, impl Stream<Item = Result<u16, PingError>> + 'a)
228where
229 I: IpExt,
230 S: IcmpSocket<I>,
231{
232 (
233 PingSink::new(socket).with(move |sequence| {
234 futures::future::ok(PingData { addr: addr.clone(), sequence, body: body.to_vec() })
235 }),
236 PingStream::<I, S, N>::new(socket).try_filter_map(
237 move |PingData { addr: got_addr, sequence, body: got_body }| {
238 futures::future::ready(if got_addr == *addr {
239 if got_body == body {
240 Ok(Some(sequence))
241 } else {
242 Err(PingError::Body { got: got_body, want: body.to_vec() })
243 }
244 } else {
245 Ok(None)
246 })
247 },
248 ),
249 )
250}
251
/// Stream of received ICMP echo replies.
///
/// Each item is the [`PingData`] parsed from one received message, or the [`PingError`]
/// describing why receiving or parsing failed. `N` is the size in bytes of the receive
/// buffer.
pub struct PingStream<'a, I, S, const N: usize>
where
    I: IpExt,
    S: IcmpSocket<I>,
{
    // Socket that replies are received from.
    socket: &'a S,
    // Scratch buffer each incoming message is received into before being parsed.
    recv_buf: [u8; N],
    // Ties the stream to a single IP version without storing a value of it.
    _marker: PhantomData<I>,
}
264
265impl<'a, I, S, const N: usize> PingStream<'a, I, S, N>
266where
267 I: IpExt,
268 S: IcmpSocket<I>,
269{
270 pub fn new(socket: &'a S) -> Self {
278 Self { socket, recv_buf: [0; N], _marker: PhantomData::<I> }
279 }
280}
281
282impl<'a, I, S, const N: usize> futures::stream::Stream for PingStream<'a, I, S, N>
283where
284 I: IpExt,
285 S: IcmpSocket<I>,
286{
287 type Item = Result<PingData<I>, PingError>;
288
289 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
290 let ping_stream = Pin::into_inner(self);
291 let buf = &mut ping_stream.recv_buf[..];
292 let socket = &ping_stream.socket;
293 Poll::Ready(Some(
294 ready!(socket.async_recv_from(buf, cx))
295 .map_err(PingError::Recv)
296 .and_then(|(len, addr)| verify_packet::<I>(addr, &ping_stream.recv_buf[..len])),
297 ))
298 }
299}
300
/// Sink of ICMP echo requests.
///
/// Each [`PingData`] item is sent over the socket as an echo request. At most one request is
/// buffered at a time.
pub struct PingSink<'a, I, S>
where
    I: IpExt,
    S: IcmpSocket<I>,
{
    // Socket that requests are sent on.
    socket: &'a S,
    // The request accepted by `start_send` that has not yet been fully sent, if any:
    // destination address, ICMP header, and message body.
    packet: Option<(I::SockAddr, IcmpHeader, Vec<u8>)>,
    // Ties the sink to a single IP version without storing a value of it.
    _marker: PhantomData<I>,
}
311
312impl<'a, I, S> PingSink<'a, I, S>
313where
314 I: IpExt,
315 S: IcmpSocket<I>,
316{
317 pub fn new(socket: &'a S) -> Self {
319 Self { socket, packet: None, _marker: PhantomData::<I> }
320 }
321}
322
323impl<'a, I, S> futures::sink::Sink<PingData<I>> for PingSink<'a, I, S>
324where
325 I: IpExt,
326 S: IcmpSocket<I>,
327{
328 type Error = PingError;
329
330 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
331 self.poll_flush(cx)
332 }
333
334 fn start_send(
335 mut self: Pin<&mut Self>,
336 PingData { addr, sequence, body }: PingData<I>,
337 ) -> Result<(), Self::Error> {
338 let header = IcmpHeader::new::<I>(sequence);
339 assert_eq!(
340 self.packet.replace((addr, header, body)),
341 None,
342 "start_send called while element has yet to be flushed"
343 );
344 Ok(())
345 }
346
347 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
348 Poll::Ready(match &self.packet {
349 Some((addr, header, body)) => {
350 match ready!(self.socket.async_send_to_vectored(
351 &[
352 std::io::IoSlice::new(header.as_bytes()),
353 std::io::IoSlice::new(body.as_bytes()),
354 ],
355 addr,
356 cx
357 )) {
358 Ok(got) => {
359 let want = std::mem::size_of_val(&header) + body.len();
360 if got != want {
361 Err(PingError::SendLength { got, want })
362 } else {
363 self.packet = None;
364 Ok(())
365 }
366 }
367 Err(e) => Err(PingError::Send(e)),
368 }
369 }
370 None => Ok(()),
371 })
372 }
373
374 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
375 self.poll_flush(cx)
376 }
377}
378
379fn verify_packet<I: IpExt>(addr: I::SockAddr, packet: &[u8]) -> Result<PingData<I>, PingError> {
380 let (reply, body): (zerocopy::Ref<_, IcmpHeader>, _) = zerocopy::Ref::from_prefix(packet)
381 .map_err(Into::into)
382 .map_err(|_: zerocopy::SizeError<_, _>| PingError::Parse)?;
383
384 let &IcmpHeader { type_, code, checksum: _, id: _, sequence } = zerocopy::Ref::into_ref(reply);
392
393 if type_ != I::ECHO_REPLY_TYPE {
394 return Err(PingError::ReplyType { got: type_, want: I::ECHO_REPLY_TYPE });
395 }
396
397 if code != 0 {
398 return Err(PingError::ReplyCode(code));
399 }
400
401 Ok(PingData { addr, sequence: sequence.into(), body: body.to_vec() })
402}
403
// Unit tests that exercise `PingSink` and `PingStream` against an in-memory fake socket.
#[cfg(test)]
mod test {
    use super::{IcmpHeader, IcmpSocket, Ipv4, Ipv6, PingData, PingSink, PingStream};

    use futures::{FutureExt as _, SinkExt as _, StreamExt as _, TryStreamExt as _};
    use net_declare::{std_socket_addr_v4, std_socket_addr_v6};
    use std::cell::RefCell;
    use std::collections::VecDeque;
    use std::task::{Context, Poll};
    use zerocopy::IntoBytes as _;

    /// Fake ICMP socket that answers every sent message with an echo reply: the sent bytes
    /// are queued with the header type rewritten to the reply type, and later handed back by
    /// `async_recv_from` in FIFO order.
    #[derive(Default, Debug)]
    struct FakeSocket<I: IpExt> {
        // Queue of (reply bytes, address the request was sent to), oldest first. `RefCell`
        // is needed because the socket trait methods take `&self`.
        buffer: RefCell<VecDeque<(Vec<u8>, I::SockAddr)>>,
    }

    impl<I: IpExt> FakeSocket<I> {
        /// Creates a fake socket with no queued replies.
        fn new() -> Self {
            Self { buffer: RefCell::new(VecDeque::new()) }
        }
    }

    impl<I: IpExt> IcmpSocket<I> for FakeSocket<I> {
        /// Pops the oldest queued reply into `buf`.
        ///
        /// Always completes immediately: an empty queue is reported as a `WouldBlock` error
        /// rather than `Poll::Pending`, and a `buf` smaller than the reply is an error.
        fn async_recv_from(
            &self,
            buf: &mut [u8],
            _cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<(usize, I::SockAddr)>> {
            Poll::Ready(
                self.buffer
                    .borrow_mut()
                    .pop_front()
                    .ok_or_else(|| {
                        std::io::Error::new(
                            std::io::ErrorKind::WouldBlock,
                            "fake socket request buffer is empty",
                        )
                    })
                    .and_then(|(reply, addr)| {
                        if buf.len() < reply.len() {
                            Err(std::io::Error::other(format!(
                                "recv buffer too small, got: {}, want: {}",
                                buf.len(),
                                reply.len()
                            )))
                        } else {
                            buf[..reply.len()].copy_from_slice(&reply);
                            Ok((reply.len(), addr))
                        }
                    }),
            )
        }

        /// Concatenates `bufs`, rewrites the ICMP type to the echo reply type, and queues the
        /// result as a reply associated with `addr`. Reports the full length as sent.
        fn async_send_to_vectored(
            &self,
            bufs: &[std::io::IoSlice<'_>],
            addr: &I::SockAddr,
            _cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<usize>> {
            // Flatten the scatter/gather list into one contiguous, mutable message.
            let mut buf = bufs
                .iter()
                .map(|io_slice| io_slice.as_bytes())
                .flatten()
                .copied()
                .collect::<Vec<u8>>();
            // View the leading bytes as an ICMP header so its type can be edited in place.
            let (mut header, _): (zerocopy::Ref<_, IcmpHeader>, _) =
                match zerocopy::Ref::from_prefix(&mut buf[..]).map_err(Into::into) {
                    Ok(layout_verified) => layout_verified,
                    Err(zerocopy::SizeError { .. }) => {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "failed to parse ICMP header from provided bytes",
                        )))
                    }
                };
            // Turn the request into the reply the receive side expects.
            header.type_ = I::ECHO_REPLY_TYPE;
            let len = buf.len();
            let () = self.buffer.borrow_mut().push_back((buf, addr.clone()));
            Poll::Ready(Ok(len))
        }

        /// Not expected to be called by these tests; panics if it is.
        fn bind_device(&self, interface: Option<&[u8]>) -> std::io::Result<()> {
            panic!("unexpected call to bind_device({:?})", interface);
        }
    }

    /// Test-only extension of `super::IpExt` supplying a fixed address per IP version.
    trait IpExt: super::IpExt {
        /// Returns the socket address used as the ping destination in tests.
        fn test_addr() -> Self::SockAddr;
    }

    impl IpExt for Ipv4 {
        fn test_addr() -> Self::SockAddr {
            std_socket_addr_v4!("1.2.3.4:0")
        }
    }

    impl IpExt for Ipv6 {
        fn test_addr() -> Self::SockAddr {
            std_socket_addr_v6!("[abcd::1]:0")
        }
    }

    // Body carried by every test ping, and the sequence numbers that are sent.
    const PING_MESSAGE: &str = "Hello from ping library unit test!";
    const PING_COUNT: u16 = 3;
    const PING_SEQ_RANGE: std::ops::RangeInclusive<u16> = 1..=PING_COUNT;

    #[test]
    fn test_ipv4() {
        test_ping::<Ipv4>();
    }

    #[test]
    fn test_ipv6() {
        test_ping::<Ipv6>();
    }

    /// Sends `PING_COUNT` requests through a `PingSink`, then reads the same number of
    /// replies back through a `PingStream`, and checks they round-trip unchanged.
    fn test_ping<I: IpExt>() {
        let socket = FakeSocket::<I>::new();

        let packets = PING_SEQ_RANGE
            .into_iter()
            .map(|sequence| PingData {
                addr: I::test_addr(),
                sequence,
                body: PING_MESSAGE.as_bytes().to_vec(),
            })
            .collect::<Vec<_>>();
        let packet_stream = futures::stream::iter(packets.iter().cloned());
        // The fake socket never returns `Pending`, so sending must complete in a single poll.
        let () = PingSink::new(&socket)
            .send_all(&mut packet_stream.map(Ok))
            .now_or_never()
            .expect("ping request send blocked unexpectedly")
            .expect("ping send error");

        // The receive buffer is sized to hold exactly one header plus the test body. `take`
        // bounds the stream, since draining the fake socket further would yield an error.
        let replies =
            PingStream::<_, _, { PING_MESSAGE.len() + std::mem::size_of::<IcmpHeader>() }>::new(
                &socket,
            )
            .take(PING_COUNT.into())
            .try_collect::<Vec<_>>()
            .now_or_never()
            .expect("ping reply stream blocked unexpectedly")
            .expect("failed to collect ping reply stream");
        assert_eq!(packets, replies);
    }
}