fuchsia_async/net/fuchsia/
udp.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![deny(missing_docs)]
6
7use crate::net::EventedFd;
8use futures::future::Future;
9use futures::ready;
10use futures::task::{Context, Poll};
11use std::io;
12use std::net::{self, SocketAddr};
13use std::ops::Deref;
14use std::os::fd::{AsRawFd, RawFd};
15use std::pin::Pin;
16
/// Builds the error reported when a `socket2::SockAddr` cannot be turned into a
/// [`std::net::SocketAddr`], i.e. when the address is neither IPv4 nor IPv6.
fn new_socket_address_conversion_error() -> std::io::Error {
    const MESSAGE: &str = "socket address is not IPv4 or IPv6";
    io::Error::other(MESSAGE)
}
20
21/// An I/O object representing a UDP socket.
22///
23/// Like [`std::net::UdpSocket`], a `UdpSocket` represents a socket that is
24/// bound to a local address, and optionally is connected to a remote address.
25#[derive(Debug)]
26pub struct UdpSocket(DatagramSocket);
27
28impl Deref for UdpSocket {
29    type Target = DatagramSocket;
30
31    fn deref(&self) -> &Self::Target {
32        &self.0
33    }
34}
35
36impl UdpSocket {
37    /// Creates an async UDP socket from the given address.
38    ///
39    /// See [`std::net::UdpSocket::bind()`].
40    pub fn bind(addr: &SocketAddr) -> io::Result<UdpSocket> {
41        let socket = net::UdpSocket::bind(addr)?;
42        UdpSocket::from_socket(socket)
43    }
44
45    /// Creates an async UDP socket from a [`std::net::UdpSocket`].
46    pub fn from_socket(socket: net::UdpSocket) -> io::Result<UdpSocket> {
47        let socket: socket2::Socket = socket.into();
48        socket.set_nonblocking(true)?;
49        let evented_fd = unsafe { EventedFd::new(socket)? };
50        Ok(UdpSocket(DatagramSocket(evented_fd)))
51    }
52
53    /// Create a new UDP socket from an existing bound socket.
54    pub fn from_datagram(socket: DatagramSocket) -> io::Result<Self> {
55        let sock: &socket2::Socket = socket.as_ref();
56        if sock.r#type()? != socket2::Type::DGRAM {
57            return Err(io::Error::new(io::ErrorKind::InvalidInput, "socket type is not datagram"));
58        }
59        if sock.protocol()? != Some(socket2::Protocol::UDP) {
60            return Err(io::Error::new(io::ErrorKind::InvalidInput, "socket protocol is not UDP"));
61        }
62        // Maintain the invariant that the socket is bound (or connected).
63        let _: socket2::SockAddr = socket.local_addr()?;
64        Ok(Self(socket))
65    }
66
67    /// Returns the socket address that this socket was created from.
68    pub fn local_addr(&self) -> io::Result<SocketAddr> {
69        self.0
70            .local_addr()
71            .and_then(|sa| sa.as_socket().ok_or_else(new_socket_address_conversion_error))
72    }
73
74    /// Connects the socket to the specified remote address.
75    ///
76    /// See [`std::net::UdpSocket::connect()`].
77    pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> {
78        let addr: socket2::SockAddr = (*addr).into();
79        self.0.as_ref().connect(&addr)
80    }
81
82    /// Receive a UDP datagram from the socket.
83    ///
84    /// Asynchronous version of [`std::net::UdpSocket::recv_from()`].
85    pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> UdpRecvFrom<'a> {
86        UdpRecvFrom { socket: self, buf }
87    }
88
89    /// Send a UDP datagram via the socket. Fails if the socket is not connected.
90    ///
91    /// Asynchronous version of [`std::net::UdpSocket::send()`].
92    pub fn send<'a>(&'a self, buf: &'a [u8]) -> SendFuture<'a> {
93        SendFuture { socket: self, buf }
94    }
95
96    /// Send a UDP datagram via the socket to the specified address.
97    ///
98    /// Asynchronous version of [`std::net::UdpSocket::send_to()`].
99    pub fn send_to<'a>(&'a self, buf: &'a [u8], addr: SocketAddr) -> SendTo<'a> {
100        SendTo { socket: self, buf, addr: addr.into() }
101    }
102
103    /// Asynchronously send a datagram (possibly split over multiple buffers) via the socket.
104    pub fn send_to_vectored<'a>(
105        &'a self,
106        bufs: &'a [io::IoSlice<'a>],
107        addr: SocketAddr,
108    ) -> SendToVectored<'a> {
109        SendToVectored { socket: self, bufs, addr: addr.into() }
110    }
111}
112
113impl AsRawFd for UdpSocket {
114    fn as_raw_fd(&self) -> RawFd {
115        self.0.as_raw_fd()
116    }
117}
118
119/// An I/O object representing a datagram socket.
120#[derive(Debug)]
121pub struct DatagramSocket(EventedFd<socket2::Socket>);
122
123impl Deref for DatagramSocket {
124    type Target = EventedFd<socket2::Socket>;
125
126    fn deref(&self) -> &Self::Target {
127        &self.0
128    }
129}
130
131impl DatagramSocket {
132    /// Create a new async datagram socket.
133    pub fn new(domain: socket2::Domain, protocol: Option<socket2::Protocol>) -> io::Result<Self> {
134        let socket = socket2::Socket::new(domain, socket2::Type::DGRAM.nonblocking(), protocol)?;
135        let evented_fd = unsafe { EventedFd::new(socket)? };
136        Ok(Self(evented_fd))
137    }
138
139    /// Create a new async datagram socket from an existing socket.
140    pub fn new_from_socket(socket: socket2::Socket) -> io::Result<Self> {
141        match socket.r#type()? {
142            socket2::Type::DGRAM
143            // SOCK_RAW sockets operate on raw datagrams (e.g. datagrams that
144            // include the frame/packet header). For the purposes of
145            // `DatagramSocket`, their semantics are identical.
146            | socket2::Type::RAW => {
147                socket.set_nonblocking(true)?;
148                let evented_fd = unsafe { EventedFd::new(socket)? };
149                Ok(Self(evented_fd))
150            }
151            _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid socket type.")),
152        }
153    }
154
155    /// Returns the socket address that this socket was created from.
156    pub fn local_addr(&self) -> io::Result<socket2::SockAddr> {
157        self.0.as_ref().local_addr()
158    }
159
160    /// Receive a datagram asynchronously from the socket.
161    ///
162    /// The returned future will resolve with the number of bytes read and the source address of
163    /// the datagram on success.
164    pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> RecvFrom<'a> {
165        RecvFrom { socket: self, buf }
166    }
167
168    /// Attempt to receive a datagram from the socket without blocking.
169    pub fn async_recv_from(
170        &self,
171        buf: &mut [u8],
172        cx: &mut Context<'_>,
173    ) -> Poll<io::Result<(usize, socket2::SockAddr)>> {
174        ready!(EventedFd::poll_readable(&self.0, cx))?;
175        // SAFETY: socket2::Socket::recv_from takes a `&mut [MaybeUninit<u8>]`, so it's necessary to
176        // type-pun `&mut [u8]`. This is safe because the bytes are known to be initialized, and
177        // MaybeUninit's layout is guaranteed to be equivalent to its wrapped type.
178        let buf = unsafe {
179            std::slice::from_raw_parts_mut(
180                buf.as_mut_ptr() as *mut core::mem::MaybeUninit<u8>,
181                buf.len(),
182            )
183        };
184        match self.0.as_ref().recv_from(buf) {
185            Err(e) => {
186                if e.kind() == io::ErrorKind::WouldBlock {
187                    self.0.need_read(cx);
188                    Poll::Pending
189                } else {
190                    Poll::Ready(Err(e))
191                }
192            }
193            Ok((size, addr)) => Poll::Ready(Ok((size, addr))),
194        }
195    }
196
197    /// Send a datagram via the socket to the given address.
198    ///
199    /// The returned future will resolve with the number of bytes sent on success.
200    pub fn send_to<'a>(&'a self, buf: &'a [u8], addr: socket2::SockAddr) -> SendTo<'a> {
201        SendTo { socket: self, buf, addr }
202    }
203
204    fn send_result_to_poll_result(
205        &self,
206        r: io::Result<usize>,
207        cx: &mut Context<'_>,
208    ) -> Poll<io::Result<usize>> {
209        match r {
210            Err(e) => {
211                if e.kind() == io::ErrorKind::WouldBlock {
212                    self.0.need_write(cx);
213                    Poll::Pending
214                } else {
215                    Poll::Ready(Err(e))
216                }
217            }
218            Ok(size) => Poll::Ready(Ok(size)),
219        }
220    }
221
222    /// Attempt to send a datagram via the socket without blocking.
223    pub fn async_send(&self, buf: &[u8], cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
224        ready!(EventedFd::poll_writable(&self.0, cx))?;
225        self.send_result_to_poll_result(self.0.as_ref().send(buf), cx)
226    }
227
228    /// Attempt to send a datagram to the specified address via the socket
229    /// without blocking.
230    pub fn async_send_to(
231        &self,
232        buf: &[u8],
233        addr: &socket2::SockAddr,
234        cx: &mut Context<'_>,
235    ) -> Poll<io::Result<usize>> {
236        ready!(EventedFd::poll_writable(&self.0, cx))?;
237        self.send_result_to_poll_result(self.0.as_ref().send_to(buf, addr), cx)
238    }
239
240    /// Send a datagram (possibly split over multiple buffers) via the socket.
241    pub fn send_to_vectored<'a>(
242        &'a self,
243        bufs: &'a [io::IoSlice<'a>],
244        addr: socket2::SockAddr,
245    ) -> SendToVectored<'a> {
246        SendToVectored { socket: self, bufs, addr }
247    }
248
249    /// Attempt to send a datagram (possibly split over multiple buffers) via the socket without
250    /// blocking.
251    pub fn async_send_to_vectored<'a>(
252        &self,
253        bufs: &'a [io::IoSlice<'a>],
254        addr: &socket2::SockAddr,
255        cx: &mut Context<'_>,
256    ) -> Poll<io::Result<usize>> {
257        ready!(EventedFd::poll_writable(&self.0, cx))?;
258        self.send_result_to_poll_result(self.0.as_ref().send_to_vectored(bufs, addr), cx)
259    }
260
261    /// Sets the value of the `SO_BROADCAST` option for this socket.
262    ///
263    /// When enabled, this socket is allowed to send packets to a broadcast address.
264    pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
265        self.0.as_ref().set_broadcast(broadcast)
266    }
267
268    /// Gets the value of the `SO_BROADCAST` option for this socket.
269    pub fn broadcast(&self) -> io::Result<bool> {
270        self.0.as_ref().broadcast()
271    }
272
273    /// Sets the `SO_BINDTODEVICE` socket option.
274    ///
275    /// If a socket is bound to an interface, only packets received from that particular interface
276    /// are processed by the socket. Note that this only works for some socket types, particularly
277    /// AF_INET sockets.
278    ///
279    /// The binding will be removed if `interface` is `None` or an empty byte slice.
280    pub fn bind_device(&self, interface: Option<&[u8]>) -> io::Result<()> {
281        self.0.as_ref().bind_device(interface)
282    }
283
284    /// Gets the value of the `SO_BINDTODEVICE` socket option.
285    ///
286    /// `Ok(None)` will be returned if the socket option is not set.
287    pub fn device(&self) -> io::Result<Option<Vec<u8>>> {
288        self.0.as_ref().device()
289    }
290}
291
292/// Future returned by [`UdpSocket::recv_from()`].
293#[must_use = "futures do nothing unless you `.await` or poll them"]
294pub struct UdpRecvFrom<'a> {
295    socket: &'a UdpSocket,
296    buf: &'a mut [u8],
297}
298
299impl<'a> Future for UdpRecvFrom<'a> {
300    type Output = io::Result<(usize, SocketAddr)>;
301
302    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
303        let this = &mut *self;
304        let (received, addr) = ready!(this.socket.0.async_recv_from(this.buf, cx))?;
305        Poll::Ready(
306            addr.as_socket()
307                .ok_or_else(new_socket_address_conversion_error)
308                .map(|addr| (received, addr)),
309        )
310    }
311}
312
313/// Future returned by [`DatagramSocket::recv_from()`].
314#[must_use = "futures do nothing unless you `.await` or poll them"]
315pub struct RecvFrom<'a> {
316    socket: &'a DatagramSocket,
317    buf: &'a mut [u8],
318}
319
320impl<'a> Future for RecvFrom<'a> {
321    type Output = io::Result<(usize, socket2::SockAddr)>;
322
323    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
324        let this = &mut *self;
325        let (received, addr) = ready!(this.socket.async_recv_from(this.buf, cx))?;
326        Poll::Ready(Ok((received, addr)))
327    }
328}
329
330/// Future returned by [`DatagramSocket::send()`].
331#[must_use = "futures do nothing unless you `.await` or poll them"]
332pub struct SendFuture<'a> {
333    socket: &'a DatagramSocket,
334    buf: &'a [u8],
335}
336
337impl<'a> Future for SendFuture<'a> {
338    type Output = io::Result<usize>;
339
340    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
341        self.socket.async_send(self.buf, cx)
342    }
343}
344
345/// Future returned by [`DatagramSocket::send_to()`].
346#[must_use = "futures do nothing unless you `.await` or poll them"]
347pub struct SendTo<'a> {
348    socket: &'a DatagramSocket,
349    buf: &'a [u8],
350    addr: socket2::SockAddr,
351}
352
353impl<'a> Future for SendTo<'a> {
354    type Output = io::Result<usize>;
355
356    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
357        self.socket.async_send_to(self.buf, &self.addr, cx)
358    }
359}
360
361/// Future returned by [`DatagramSocket::send_to_vectored()`].
362#[must_use = "futures do nothing unless you `.await` or poll them"]
363pub struct SendToVectored<'a> {
364    socket: &'a DatagramSocket,
365    bufs: &'a [io::IoSlice<'a>],
366    addr: socket2::SockAddr,
367}
368
369impl<'a> Future for SendToVectored<'a> {
370    type Output = io::Result<usize>;
371
372    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
373        self.socket.async_send_to_vectored(self.bufs, &self.addr, cx)
374    }
375}
376
#[cfg(test)]
mod test {
    /// `DatagramSocket::new_from_socket` must reject a stream socket with `InvalidInput`.
    #[test]
    fn datagram_socket_new_from_socket() {
        let stream = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)
            .expect("failed to create stream socket");
        let Err(e) = super::DatagramSocket::new_from_socket(stream) else {
            panic!("DatagramSocket created from stream socket succeeded unexpectedly");
        };
        assert!(
            e.kind() == std::io::ErrorKind::InvalidInput,
            "got: {e:?}; want error of kind InvalidInput"
        );
    }
}
392}