fuchsia_async/net/fuchsia/
udp.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![deny(missing_docs)]
6
7use crate::net::EventedFd;
8use futures::future::Future;
9use futures::ready;
10use futures::task::{Context, Poll};
11use std::io;
12use std::net::{self, SocketAddr};
13use std::ops::Deref;
14use std::pin::Pin;
15
/// Builds the error reported when a `socket2::SockAddr` cannot be represented as a
/// [`SocketAddr`], i.e. it belongs to neither the IPv4 nor the IPv6 family.
fn new_socket_address_conversion_error() -> io::Error {
    const MESSAGE: &str = "socket address is not IPv4 or IPv6";
    io::Error::other(MESSAGE)
}
19
20/// An I/O object representing a UDP socket.
21///
22/// Like [`std::net::UdpSocket`], a `UdpSocket` represents a socket that is
23/// bound to a local address, and optionally is connected to a remote address.
24#[derive(Debug)]
25pub struct UdpSocket(DatagramSocket);
26
27impl Deref for UdpSocket {
28    type Target = DatagramSocket;
29
30    fn deref(&self) -> &Self::Target {
31        &self.0
32    }
33}
34
35impl UdpSocket {
36    /// Creates an async UDP socket from the given address.
37    ///
38    /// See [`std::net::UdpSocket::bind()`].
39    pub fn bind(addr: &SocketAddr) -> io::Result<UdpSocket> {
40        let socket = net::UdpSocket::bind(addr)?;
41        UdpSocket::from_socket(socket)
42    }
43
44    /// Creates an async UDP socket from a [`std::net::UdpSocket`].
45    pub fn from_socket(socket: net::UdpSocket) -> io::Result<UdpSocket> {
46        let socket: socket2::Socket = socket.into();
47        socket.set_nonblocking(true)?;
48        let socket = socket.into();
49        let evented_fd = unsafe { EventedFd::new(socket)? };
50        Ok(UdpSocket(DatagramSocket(evented_fd)))
51    }
52
53    /// Create a new UDP socket from an existing bound socket.
54    pub fn from_datagram(socket: DatagramSocket) -> io::Result<Self> {
55        let sock: &socket2::Socket = socket.as_ref();
56        if sock.r#type()? != socket2::Type::DGRAM {
57            return Err(io::Error::new(io::ErrorKind::InvalidInput, "socket type is not datagram"));
58        }
59        if sock.protocol()? != Some(socket2::Protocol::UDP) {
60            return Err(io::Error::new(io::ErrorKind::InvalidInput, "socket protocol is not UDP"));
61        }
62        // Maintain the invariant that the socket is bound (or connected).
63        let _: socket2::SockAddr = socket.local_addr()?;
64        Ok(Self(socket))
65    }
66
67    /// Returns the socket address that this socket was created from.
68    pub fn local_addr(&self) -> io::Result<SocketAddr> {
69        self.0
70            .local_addr()
71            .and_then(|sa| sa.as_socket().ok_or_else(new_socket_address_conversion_error))
72    }
73
74    /// Receive a UDP datagram from the socket.
75    ///
76    /// Asynchronous version of [`std::net::UdpSocket::recv_from()`].
77    pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> UdpRecvFrom<'a> {
78        UdpRecvFrom { socket: self, buf }
79    }
80
81    /// Send a UDP datagram via the socket.
82    ///
83    /// Asynchronous version of [`std::net::UdpSocket::send_to()`].
84    pub fn send_to<'a>(&'a self, buf: &'a [u8], addr: SocketAddr) -> SendTo<'a> {
85        SendTo { socket: self, buf, addr: addr.into() }
86    }
87
88    /// Asynchronously send a datagram (possibly split over multiple buffers) via the socket.
89    pub fn send_to_vectored<'a>(
90        &'a self,
91        bufs: &'a [io::IoSlice<'a>],
92        addr: SocketAddr,
93    ) -> SendToVectored<'a> {
94        SendToVectored { socket: self, bufs, addr: addr.into() }
95    }
96}
97
98/// An I/O object representing a datagram socket.
99#[derive(Debug)]
100pub struct DatagramSocket(EventedFd<socket2::Socket>);
101
102impl Deref for DatagramSocket {
103    type Target = EventedFd<socket2::Socket>;
104
105    fn deref(&self) -> &Self::Target {
106        &self.0
107    }
108}
109
110impl DatagramSocket {
111    /// Create a new async datagram socket.
112    pub fn new(domain: socket2::Domain, protocol: Option<socket2::Protocol>) -> io::Result<Self> {
113        let socket = socket2::Socket::new(domain, socket2::Type::DGRAM.nonblocking(), protocol)?;
114        let evented_fd = unsafe { EventedFd::new(socket)? };
115        Ok(Self(evented_fd))
116    }
117
118    /// Create a new async datagram socket from an existing socket.
119    pub fn new_from_socket(socket: socket2::Socket) -> io::Result<Self> {
120        match socket.r#type()? {
121            socket2::Type::DGRAM
122            // SOCK_RAW sockets operate on raw datagrams (e.g. datagrams that
123            // include the frame/packet header). For the purposes of
124            // `DatagramSocket`, their semantics are identical.
125            | socket2::Type::RAW => {
126                socket.set_nonblocking(true)?;
127                let evented_fd = unsafe { EventedFd::new(socket)? };
128                Ok(Self(evented_fd))
129            }
130            _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid socket type.")),
131        }
132    }
133
134    /// Returns the socket address that this socket was created from.
135    pub fn local_addr(&self) -> io::Result<socket2::SockAddr> {
136        self.0.as_ref().local_addr()
137    }
138
139    /// Receive a datagram asynchronously from the socket.
140    ///
141    /// The returned future will resolve with the number of bytes read and the source address of
142    /// the datagram on success.
143    pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> RecvFrom<'a> {
144        RecvFrom { socket: self, buf }
145    }
146
147    /// Attempt to receive a datagram from the socket without blocking.
148    pub fn async_recv_from(
149        &self,
150        buf: &mut [u8],
151        cx: &mut Context<'_>,
152    ) -> Poll<io::Result<(usize, socket2::SockAddr)>> {
153        ready!(EventedFd::poll_readable(&self.0, cx))?;
154        // SAFETY: socket2::Socket::recv_from takes a `&mut [MaybeUninit<u8>]`, so it's necessary to
155        // type-pun `&mut [u8]`. This is safe because the bytes are known to be initialized, and
156        // MaybeUninit's layout is guaranteed to be equivalent to its wrapped type.
157        let buf = unsafe {
158            std::slice::from_raw_parts_mut(
159                buf.as_mut_ptr() as *mut core::mem::MaybeUninit<u8>,
160                buf.len(),
161            )
162        };
163        match self.0.as_ref().recv_from(buf) {
164            Err(e) => {
165                if e.kind() == io::ErrorKind::WouldBlock {
166                    self.0.need_read(cx);
167                    Poll::Pending
168                } else {
169                    Poll::Ready(Err(e))
170                }
171            }
172            Ok((size, addr)) => Poll::Ready(Ok((size, addr))),
173        }
174    }
175
176    /// Send a datagram via the socket to the given address.
177    ///
178    /// The returned future will resolve with the number of bytes sent on success.
179    pub fn send_to<'a>(&'a self, buf: &'a [u8], addr: socket2::SockAddr) -> SendTo<'a> {
180        SendTo { socket: self, buf, addr }
181    }
182
183    /// Attempt to send a datagram via the socket without blocking.
184    pub fn async_send_to(
185        &self,
186        buf: &[u8],
187        addr: &socket2::SockAddr,
188        cx: &mut Context<'_>,
189    ) -> Poll<io::Result<usize>> {
190        ready!(EventedFd::poll_writable(&self.0, cx))?;
191        match self.0.as_ref().send_to(buf, addr) {
192            Err(e) => {
193                if e.kind() == io::ErrorKind::WouldBlock {
194                    self.0.need_write(cx);
195                    Poll::Pending
196                } else {
197                    Poll::Ready(Err(e))
198                }
199            }
200            Ok(size) => Poll::Ready(Ok(size)),
201        }
202    }
203
204    /// Send a datagram (possibly split over multiple buffers) via the socket.
205    pub fn send_to_vectored<'a>(
206        &'a self,
207        bufs: &'a [io::IoSlice<'a>],
208        addr: socket2::SockAddr,
209    ) -> SendToVectored<'a> {
210        SendToVectored { socket: self, bufs, addr }
211    }
212
213    /// Attempt to send a datagram (possibly split over multiple buffers) via the socket without
214    /// blocking.
215    pub fn async_send_to_vectored<'a>(
216        &self,
217        bufs: &'a [io::IoSlice<'a>],
218        addr: &socket2::SockAddr,
219        cx: &mut Context<'_>,
220    ) -> Poll<io::Result<usize>> {
221        ready!(EventedFd::poll_writable(&self.0, cx))?;
222        match self.0.as_ref().send_to_vectored(bufs, addr) {
223            Err(e) => {
224                if e.kind() == io::ErrorKind::WouldBlock {
225                    self.0.need_write(cx);
226                    Poll::Pending
227                } else {
228                    Poll::Ready(Err(e))
229                }
230            }
231            Ok(size) => Poll::Ready(Ok(size)),
232        }
233    }
234
235    /// Sets the value of the `SO_BROADCAST` option for this socket.
236    ///
237    /// When enabled, this socket is allowed to send packets to a broadcast address.
238    pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
239        self.0.as_ref().set_broadcast(broadcast)
240    }
241
242    /// Gets the value of the `SO_BROADCAST` option for this socket.
243    pub fn broadcast(&self) -> io::Result<bool> {
244        self.0.as_ref().broadcast()
245    }
246
247    /// Sets the `SO_BINDTODEVICE` socket option.
248    ///
249    /// If a socket is bound to an interface, only packets received from that particular interface
250    /// are processed by the socket. Note that this only works for some socket types, particularly
251    /// AF_INET sockets.
252    ///
253    /// The binding will be removed if `interface` is `None` or an empty byte slice.
254    pub fn bind_device(&self, interface: Option<&[u8]>) -> io::Result<()> {
255        self.0.as_ref().bind_device(interface)
256    }
257
258    /// Gets the value of the `SO_BINDTODEVICE` socket option.
259    ///
260    /// `Ok(None)` will be returned if the socket option is not set.
261    pub fn device(&self) -> io::Result<Option<Vec<u8>>> {
262        self.0.as_ref().device()
263    }
264}
265
266/// Future returned by [`UdpSocket::recv_from()`].
267#[must_use = "futures do nothing unless you `.await` or poll them"]
268pub struct UdpRecvFrom<'a> {
269    socket: &'a UdpSocket,
270    buf: &'a mut [u8],
271}
272
273impl<'a> Future for UdpRecvFrom<'a> {
274    type Output = io::Result<(usize, SocketAddr)>;
275
276    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
277        let this = &mut *self;
278        let (received, addr) = ready!(this.socket.0.async_recv_from(this.buf, cx))?;
279        Poll::Ready(
280            addr.as_socket()
281                .ok_or_else(new_socket_address_conversion_error)
282                .map(|addr| (received, addr)),
283        )
284    }
285}
286
287/// Future returned by [`DatagramSocket::recv_from()`].
288#[must_use = "futures do nothing unless you `.await` or poll them"]
289pub struct RecvFrom<'a> {
290    socket: &'a DatagramSocket,
291    buf: &'a mut [u8],
292}
293
294impl<'a> Future for RecvFrom<'a> {
295    type Output = io::Result<(usize, socket2::SockAddr)>;
296
297    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
298        let this = &mut *self;
299        let (received, addr) = ready!(this.socket.async_recv_from(this.buf, cx))?;
300        Poll::Ready(Ok((received, addr)))
301    }
302}
303
304/// Future returned by [`DatagramSocket::send_to()`].
305#[must_use = "futures do nothing unless you `.await` or poll them"]
306pub struct SendTo<'a> {
307    socket: &'a DatagramSocket,
308    buf: &'a [u8],
309    addr: socket2::SockAddr,
310}
311
312impl<'a> Future for SendTo<'a> {
313    type Output = io::Result<usize>;
314
315    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
316        self.socket.async_send_to(self.buf, &self.addr, cx)
317    }
318}
319
320/// Future returned by [`DatagramSocket::send_to_vectored()`].
321#[must_use = "futures do nothing unless you `.await` or poll them"]
322pub struct SendToVectored<'a> {
323    socket: &'a DatagramSocket,
324    bufs: &'a [io::IoSlice<'a>],
325    addr: socket2::SockAddr,
326}
327
328impl<'a> Future for SendToVectored<'a> {
329    type Output = io::Result<usize>;
330
331    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
332        self.socket.async_send_to_vectored(self.bufs, &self.addr, cx)
333    }
334}
335
#[cfg(test)]
mod test {
    use super::DatagramSocket;
    use std::io::ErrorKind;

    /// A stream socket must be rejected by `DatagramSocket::new_from_socket` with an
    /// `InvalidInput` error.
    #[test]
    fn datagram_socket_new_from_socket() {
        let stream = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)
            .expect("failed to create stream socket");
        let Err(e) = DatagramSocket::new_from_socket(stream) else {
            panic!("DatagramSocket created from stream socket succeeded unexpectedly");
        };
        assert!(
            e.kind() == ErrorKind::InvalidInput,
            "got: {:?}; want error of kind InvalidInput",
            e
        );
    }
}
351}