Skip to main content

fuchsia_async/net/fuchsia/
udp.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![deny(missing_docs)]
6
7use crate::net::EventedFd;
8use futures::future::Future;
9use futures::ready;
10use futures::task::{Context, Poll};
11use std::io;
12use std::net::{self, SocketAddr};
13use std::ops::Deref;
14use std::os::fd::{AsRawFd, RawFd};
15use std::pin::Pin;
16use zx_status_ext::StatusExt;
17
/// Builds the error reported when a `socket2::SockAddr` cannot be represented as a
/// [`SocketAddr`], i.e. when it is neither an IPv4 nor an IPv6 address.
fn new_socket_address_conversion_error() -> io::Error {
    const MESSAGE: &str = "socket address is not IPv4 or IPv6";
    io::Error::other(MESSAGE)
}
21
22/// An I/O object representing a UDP socket.
23///
24/// Like [`std::net::UdpSocket`], a `UdpSocket` represents a socket that is
25/// bound to a local address, and optionally is connected to a remote address.
26#[derive(Debug)]
27pub struct UdpSocket(DatagramSocket);
28
29impl Deref for UdpSocket {
30    type Target = DatagramSocket;
31
32    fn deref(&self) -> &Self::Target {
33        &self.0
34    }
35}
36
37impl UdpSocket {
38    /// Creates an async UDP socket from the given address.
39    ///
40    /// See [`std::net::UdpSocket::bind()`].
41    pub fn bind(addr: &SocketAddr) -> io::Result<UdpSocket> {
42        let socket = net::UdpSocket::bind(addr)?;
43        UdpSocket::from_socket(socket)
44    }
45
46    /// Creates an async UDP socket from a [`std::net::UdpSocket`].
47    pub fn from_socket(socket: net::UdpSocket) -> io::Result<UdpSocket> {
48        let socket: socket2::Socket = socket.into();
49        socket.set_nonblocking(true)?;
50        let evented_fd = unsafe { EventedFd::new(socket)? };
51        Ok(UdpSocket(DatagramSocket(evented_fd)))
52    }
53
54    /// Create a new UDP socket from an existing bound socket.
55    pub fn from_datagram(socket: DatagramSocket) -> io::Result<Self> {
56        let sock: &socket2::Socket = socket.as_ref();
57        if sock.r#type()? != socket2::Type::DGRAM {
58            return Err(io::Error::new(io::ErrorKind::InvalidInput, "socket type is not datagram"));
59        }
60        if sock.protocol()? != Some(socket2::Protocol::UDP) {
61            return Err(io::Error::new(io::ErrorKind::InvalidInput, "socket protocol is not UDP"));
62        }
63        // Maintain the invariant that the socket is bound (or connected).
64        let _: socket2::SockAddr = socket.local_addr()?;
65        Ok(Self(socket))
66    }
67
68    /// Returns the socket address that this socket was created from.
69    pub fn local_addr(&self) -> io::Result<SocketAddr> {
70        self.0
71            .local_addr()
72            .and_then(|sa| sa.as_socket().ok_or_else(new_socket_address_conversion_error))
73    }
74
75    /// Connects the socket to the specified remote address.
76    ///
77    /// See [`std::net::UdpSocket::connect()`].
78    pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> {
79        let addr: socket2::SockAddr = (*addr).into();
80        self.0.as_ref().connect(&addr)
81    }
82
83    /// Receive a UDP datagram from the socket.
84    ///
85    /// Asynchronous version of [`std::net::UdpSocket::recv_from()`].
86    pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> UdpRecvFrom<'a> {
87        UdpRecvFrom { socket: self, buf }
88    }
89
90    /// Send a UDP datagram via the socket. Fails if the socket is not connected.
91    ///
92    /// Asynchronous version of [`std::net::UdpSocket::send()`].
93    pub fn send<'a>(&'a self, buf: &'a [u8]) -> SendFuture<'a> {
94        SendFuture { socket: self, buf }
95    }
96
97    /// Send a UDP datagram via the socket to the specified address.
98    ///
99    /// Asynchronous version of [`std::net::UdpSocket::send_to()`].
100    pub fn send_to<'a>(&'a self, buf: &'a [u8], addr: SocketAddr) -> SendTo<'a> {
101        SendTo { socket: self, buf, addr: addr.into() }
102    }
103
104    /// Asynchronously send a datagram (possibly split over multiple buffers) via the socket.
105    pub fn send_to_vectored<'a>(
106        &'a self,
107        bufs: &'a [io::IoSlice<'a>],
108        addr: SocketAddr,
109    ) -> SendToVectored<'a> {
110        SendToVectored { socket: self, bufs, addr: addr.into() }
111    }
112}
113
114impl AsRawFd for UdpSocket {
115    fn as_raw_fd(&self) -> RawFd {
116        self.0.as_raw_fd()
117    }
118}
119
120/// An I/O object representing a datagram socket.
121#[derive(Debug)]
122pub struct DatagramSocket(EventedFd<socket2::Socket>);
123
124impl Deref for DatagramSocket {
125    type Target = EventedFd<socket2::Socket>;
126
127    fn deref(&self) -> &Self::Target {
128        &self.0
129    }
130}
131
132impl DatagramSocket {
133    /// Create a new async datagram socket.
134    pub fn new(domain: socket2::Domain, protocol: Option<socket2::Protocol>) -> io::Result<Self> {
135        let socket = socket2::Socket::new(domain, socket2::Type::DGRAM.nonblocking(), protocol)?;
136        let evented_fd = unsafe { EventedFd::new(socket)? };
137        Ok(Self(evented_fd))
138    }
139
140    /// Create a new async datagram socket from an existing socket.
141    pub fn new_from_socket(socket: socket2::Socket) -> io::Result<Self> {
142        match socket.r#type()? {
143            socket2::Type::DGRAM
144            // SOCK_RAW sockets operate on raw datagrams (e.g. datagrams that
145            // include the frame/packet header). For the purposes of
146            // `DatagramSocket`, their semantics are identical.
147            | socket2::Type::RAW => {
148                socket.set_nonblocking(true)?;
149                let evented_fd = unsafe { EventedFd::new(socket)? };
150                Ok(Self(evented_fd))
151            }
152            _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid socket type.")),
153        }
154    }
155
156    /// Returns the socket address that this socket was created from.
157    pub fn local_addr(&self) -> io::Result<socket2::SockAddr> {
158        self.0.as_ref().local_addr()
159    }
160
161    /// Receive a datagram asynchronously from the socket.
162    ///
163    /// The returned future will resolve with the number of bytes read and the source address of
164    /// the datagram on success.
165    pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> RecvFrom<'a> {
166        RecvFrom { socket: self, buf }
167    }
168
169    /// Attempt to receive a datagram from the socket without blocking.
170    pub fn async_recv_from(
171        &self,
172        buf: &mut [u8],
173        cx: &mut Context<'_>,
174    ) -> Poll<io::Result<(usize, socket2::SockAddr)>> {
175        ready!(EventedFd::poll_readable(&self.0, cx)).map_err(|s| s.into_io_error())?;
176        // SAFETY: socket2::Socket::recv_from takes a `&mut [MaybeUninit<u8>]`, so it's necessary to
177        // type-pun `&mut [u8]`. This is safe because the bytes are known to be initialized, and
178        // MaybeUninit's layout is guaranteed to be equivalent to its wrapped type.
179        let buf = unsafe {
180            std::slice::from_raw_parts_mut(
181                buf.as_mut_ptr() as *mut core::mem::MaybeUninit<u8>,
182                buf.len(),
183            )
184        };
185        match self.0.as_ref().recv_from(buf) {
186            Err(e) => {
187                if e.kind() == io::ErrorKind::WouldBlock {
188                    self.0.need_read(cx);
189                    Poll::Pending
190                } else {
191                    Poll::Ready(Err(e))
192                }
193            }
194            Ok((size, addr)) => Poll::Ready(Ok((size, addr))),
195        }
196    }
197
198    /// Send a datagram via the socket to the given address.
199    ///
200    /// The returned future will resolve with the number of bytes sent on success.
201    pub fn send_to<'a>(&'a self, buf: &'a [u8], addr: socket2::SockAddr) -> SendTo<'a> {
202        SendTo { socket: self, buf, addr }
203    }
204
205    fn send_result_to_poll_result(
206        &self,
207        r: io::Result<usize>,
208        cx: &mut Context<'_>,
209    ) -> Poll<io::Result<usize>> {
210        match r {
211            Err(e) => {
212                if e.kind() == io::ErrorKind::WouldBlock {
213                    self.0.need_write(cx);
214                    Poll::Pending
215                } else {
216                    Poll::Ready(Err(e))
217                }
218            }
219            Ok(size) => Poll::Ready(Ok(size)),
220        }
221    }
222
223    /// Attempt to send a datagram via the socket without blocking.
224    pub fn async_send(&self, buf: &[u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
225        ready!(EventedFd::poll_writable(&self.0, cx)).map_err(|s| s.into_io_error())?;
226        self.send_result_to_poll_result(self.0.as_ref().send(buf), cx)
227    }
228
229    /// Attempt to send a datagram to the specified address via the socket
230    /// without blocking.
231    pub fn async_send_to(
232        &self,
233        buf: &[u8],
234        addr: &socket2::SockAddr,
235        cx: &mut Context<'_>,
236    ) -> Poll<io::Result<usize>> {
237        ready!(EventedFd::poll_writable(&self.0, cx)).map_err(|s| s.into_io_error())?;
238        self.send_result_to_poll_result(self.0.as_ref().send_to(buf, addr), cx)
239    }
240
241    /// Send a datagram (possibly split over multiple buffers) via the socket.
242    pub fn send_to_vectored<'a>(
243        &'a self,
244        bufs: &'a [io::IoSlice<'a>],
245        addr: socket2::SockAddr,
246    ) -> SendToVectored<'a> {
247        SendToVectored { socket: self, bufs, addr }
248    }
249
250    /// Attempt to send a datagram (possibly split over multiple buffers) via the socket without
251    /// blocking.
252    pub fn async_send_to_vectored<'a>(
253        &self,
254        bufs: &'a [io::IoSlice<'a>],
255        addr: &socket2::SockAddr,
256        cx: &mut Context<'_>,
257    ) -> Poll<io::Result<usize>> {
258        ready!(EventedFd::poll_writable(&self.0, cx)).map_err(|s| s.into_io_error())?;
259        self.send_result_to_poll_result(self.0.as_ref().send_to_vectored(bufs, addr), cx)
260    }
261
262    /// Sets the value of the `SO_BROADCAST` option for this socket.
263    ///
264    /// When enabled, this socket is allowed to send packets to a broadcast address.
265    pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
266        self.0.as_ref().set_broadcast(broadcast)
267    }
268
269    /// Gets the value of the `SO_BROADCAST` option for this socket.
270    pub fn broadcast(&self) -> io::Result<bool> {
271        self.0.as_ref().broadcast()
272    }
273
274    /// Sets the `SO_BINDTODEVICE` socket option.
275    ///
276    /// If a socket is bound to an interface, only packets received from that particular interface
277    /// are processed by the socket. Note that this only works for some socket types, particularly
278    /// AF_INET sockets.
279    ///
280    /// The binding will be removed if `interface` is `None` or an empty byte slice.
281    pub fn bind_device(&self, interface: Option<&[u8]>) -> io::Result<()> {
282        self.0.as_ref().bind_device(interface)
283    }
284
285    /// Gets the value of the `SO_BINDTODEVICE` socket option.
286    ///
287    /// `Ok(None)` will be returned if the socket option is not set.
288    pub fn device(&self) -> io::Result<Option<Vec<u8>>> {
289        self.0.as_ref().device()
290    }
291}
292
293/// Future returned by [`UdpSocket::recv_from()`].
294#[must_use = "futures do nothing unless you `.await` or poll them"]
295pub struct UdpRecvFrom<'a> {
296    socket: &'a UdpSocket,
297    buf: &'a mut [u8],
298}
299
300impl<'a> Future for UdpRecvFrom<'a> {
301    type Output = io::Result<(usize, SocketAddr)>;
302
303    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
304        let this = &mut *self;
305        let (received, addr) = ready!(this.socket.0.async_recv_from(this.buf, cx))?;
306        Poll::Ready(
307            addr.as_socket()
308                .ok_or_else(new_socket_address_conversion_error)
309                .map(|addr| (received, addr)),
310        )
311    }
312}
313
314/// Future returned by [`DatagramSocket::recv_from()`].
315#[must_use = "futures do nothing unless you `.await` or poll them"]
316pub struct RecvFrom<'a> {
317    socket: &'a DatagramSocket,
318    buf: &'a mut [u8],
319}
320
321impl<'a> Future for RecvFrom<'a> {
322    type Output = io::Result<(usize, socket2::SockAddr)>;
323
324    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
325        let this = &mut *self;
326        let (received, addr) = ready!(this.socket.async_recv_from(this.buf, cx))?;
327        Poll::Ready(Ok((received, addr)))
328    }
329}
330
331/// Future returned by [`DatagramSocket::send()`].
332#[must_use = "futures do nothing unless you `.await` or poll them"]
333pub struct SendFuture<'a> {
334    socket: &'a DatagramSocket,
335    buf: &'a [u8],
336}
337
338impl<'a> Future for SendFuture<'a> {
339    type Output = io::Result<usize>;
340
341    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
342        self.socket.async_send(self.buf, cx)
343    }
344}
345
346/// Future returned by [`DatagramSocket::send_to()`].
347#[must_use = "futures do nothing unless you `.await` or poll them"]
348pub struct SendTo<'a> {
349    socket: &'a DatagramSocket,
350    buf: &'a [u8],
351    addr: socket2::SockAddr,
352}
353
354impl<'a> Future for SendTo<'a> {
355    type Output = io::Result<usize>;
356
357    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
358        self.socket.async_send_to(self.buf, &self.addr, cx)
359    }
360}
361
362/// Future returned by [`DatagramSocket::send_to_vectored()`].
363#[must_use = "futures do nothing unless you `.await` or poll them"]
364pub struct SendToVectored<'a> {
365    socket: &'a DatagramSocket,
366    bufs: &'a [io::IoSlice<'a>],
367    addr: socket2::SockAddr,
368}
369
370impl<'a> Future for SendToVectored<'a> {
371    type Output = io::Result<usize>;
372
373    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
374        self.socket.async_send_to_vectored(self.bufs, &self.addr, cx)
375    }
376}
377
#[cfg(test)]
mod test {
    use super::DatagramSocket;
    use std::io::ErrorKind;

    /// A stream socket must be rejected by `DatagramSocket::new_from_socket`
    /// with an `InvalidInput` error.
    #[test]
    fn datagram_socket_new_from_socket() {
        let stream = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)
            .expect("failed to create stream socket");

        let e = match DatagramSocket::new_from_socket(stream) {
            Ok(_) => panic!("DatagramSocket created from stream socket succeeded unexpectedly"),
            Err(e) => e,
        };
        if e.kind() != ErrorKind::InvalidInput {
            panic!("got: {e:?}; want error of kind InvalidInput");
        }
    }
}