fuchsia_async/net/fuchsia/
tcp.rs

// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#![deny(missing_docs)]
6
7use crate::net::EventedFd;
8use futures::future::Future;
9use futures::io::{AsyncRead, AsyncWrite};
10use futures::ready;
11use futures::stream::Stream;
12use futures::task::{Context, Poll};
13use std::io::{self, Write};
14use std::net::{self, Shutdown, SocketAddr};
15use std::ops::Deref;
16use std::os::fd::{AsRawFd, RawFd};
17use std::os::unix::io::FromRawFd as _;
18use std::pin::Pin;
19
20/// An I/O object representing a TCP socket listening for incoming connections.
21///
22/// This object can be converted into a stream of incoming connections for
23/// various forms of processing.
24#[derive(Debug)]
25pub struct TcpListener(EventedFd<net::TcpListener>);
26
27impl Unpin for TcpListener {}
28
29impl Deref for TcpListener {
30    type Target = EventedFd<net::TcpListener>;
31
32    fn deref(&self) -> &Self::Target {
33        &self.0
34    }
35}
36
37impl TcpListener {
38    /// Creates a new `TcpListener` bound to the provided socket.
39    pub fn bind(addr: &SocketAddr) -> io::Result<TcpListener> {
40        let domain = match *addr {
41            SocketAddr::V4(..) => socket2::Domain::IPV4,
42            SocketAddr::V6(..) => socket2::Domain::IPV6,
43        };
44        let socket =
45            socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
46        // Allow this socket to be rebound while it is in TIME_WAIT.
47        //
48        // This is borrowed from std::net::TcpListener::bind. See
49        // https://github.com/rust-lang/rust/blob/db492ec/library/std/src/sys_common/net.rs#L371-L379.
50        let () = socket.set_reuse_address(true)?;
51        let addr = (*addr).into();
52        let () = socket.bind(&addr)?;
53        let () = socket.listen(1024)?;
54        TcpListener::from_std(socket.into())
55    }
56
57    /// Consumes this listener and returns a `Future` that resolves to an
58    /// `io::Result<(TcpListener, TcpStream, SocketAddr)>`.
59    pub fn accept(self) -> Acceptor {
60        Acceptor(Some(self))
61    }
62
63    /// Consumes this listener and returns a `Stream` that resolves to elements
64    /// of type `io::Result<(TcpStream, SocketAddr)>`.
65    pub fn accept_stream(self) -> AcceptStream {
66        AcceptStream(self)
67    }
68
69    /// Poll on `accept`ing a new `TcpStream` from this listener.
70    /// This function is mainly intended for usage in manual `Future` or `Stream`
71    /// implementations.
72    pub fn async_accept(
73        &mut self,
74        cx: &mut Context<'_>,
75    ) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
76        ready!(EventedFd::poll_readable(&self.0, cx))?;
77        match self.0.as_ref().accept() {
78            Err(e) => {
79                if e.kind() == io::ErrorKind::WouldBlock {
80                    self.0.need_read(cx);
81                    Poll::Pending
82                } else {
83                    Poll::Ready(Err(e))
84                }
85            }
86            Ok((sock, addr)) => Poll::Ready(Ok((TcpStream::from_std(sock)?, addr))),
87        }
88    }
89
90    /// Creates a new instance of `fuchsia_async::net::TcpListener` from an
91    /// `std::net::TcpListener`.
92    pub fn from_std(listener: net::TcpListener) -> io::Result<TcpListener> {
93        let listener: socket2::Socket = listener.into();
94        let () = listener.set_nonblocking(true)?;
95        let listener = listener.into();
96        let listener = unsafe { EventedFd::new(listener)? };
97        Ok(TcpListener(listener))
98    }
99
100    /// Returns a reference to the underlying `std::net::TcpListener`.
101    pub fn std(&self) -> &net::TcpListener {
102        self.as_ref()
103    }
104
105    /// Returns the local socket address of the listener.
106    pub fn local_addr(&self) -> io::Result<net::SocketAddr> {
107        self.std().local_addr()
108    }
109}
110
111/// A future which resolves to an `io::Result<(TcpListener, TcpStream, SocketAddr)>`.
112#[derive(Debug)]
113pub struct Acceptor(Option<TcpListener>);
114
115impl Future for Acceptor {
116    type Output = io::Result<(TcpListener, TcpStream, SocketAddr)>;
117
118    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
119        let (stream, addr);
120        {
121            let listener = self.0.as_mut().expect("polled an Acceptor after completion");
122            let (s, a) = ready!(listener.async_accept(cx))?;
123            stream = s;
124            addr = a;
125        }
126        let listener = self.0.take().unwrap();
127        Poll::Ready(Ok((listener, stream, addr)))
128    }
129}
130
131/// A stream which resolves to an `io::Result<(TcpStream, SocketAddr)>`.
132#[derive(Debug)]
133pub struct AcceptStream(TcpListener);
134
135impl Stream for AcceptStream {
136    type Item = io::Result<(TcpStream, SocketAddr)>;
137
138    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
139        let (stream, addr) = ready!(self.0.async_accept(cx)?);
140        Poll::Ready(Some(Ok((stream, addr))))
141    }
142}
143
144/// A single TCP connection.
145///
146/// This type and references to it implement the `AsyncRead` and `AsyncWrite`
147/// traits. For more on using this type, see the `AsyncReadExt` and `AsyncWriteExt`
148/// traits.
149#[derive(Debug)]
150pub struct TcpStream {
151    stream: EventedFd<net::TcpStream>,
152}
153
154impl Deref for TcpStream {
155    type Target = EventedFd<net::TcpStream>;
156
157    fn deref(&self) -> &Self::Target {
158        &self.stream
159    }
160}
161
162impl TcpStream {
163    /// Creates a new `TcpStream` connected to a specific socket address from an existing socket
164    /// descriptor.
165    /// This function returns a future which resolves to an `io::Result<TcpStream>`.
166    pub fn connect_from_raw(
167        socket: impl std::os::unix::io::IntoRawFd,
168        addr: SocketAddr,
169    ) -> io::Result<TcpConnector> {
170        // This is safe because `into_raw_fd()` consumes ownership of the socket, so we are
171        // guaranteed that the returned value is not shared among more than one owner at this
172        // point.
173        let socket = unsafe { socket2::Socket::from_raw_fd(socket.into_raw_fd()) };
174        Self::from_socket2(socket, addr)
175    }
176
177    /// Creates a new `TcpStream` connected to a specific socket address.
178    ///
179    /// This function returns a future which resolves to an `io::Result<TcpStream>`.
180    pub fn connect(addr: SocketAddr) -> io::Result<TcpConnector> {
181        let domain = match addr {
182            SocketAddr::V4(..) => socket2::Domain::IPV4,
183            SocketAddr::V6(..) => socket2::Domain::IPV6,
184        };
185        let socket =
186            socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
187        Self::from_socket2(socket, addr)
188    }
189
190    // This function is intentionally kept private to avoid socket2 appearing in the public API.
191    fn from_socket2(socket: socket2::Socket, addr: SocketAddr) -> io::Result<TcpConnector> {
192        let () = socket.set_nonblocking(true)?;
193        let addr = addr.into();
194        let () = match socket.connect(&addr) {
195            Err(e) if e.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
196            res => res,
197        }?;
198        let stream = socket.into();
199        // This is safe because the file descriptor for stream will live as long as the TcpStream.
200        let stream = unsafe { EventedFd::new(stream)? };
201        let stream = Some(TcpStream { stream });
202
203        Ok(TcpConnector { need_write: true, stream })
204    }
205
206    /// Shuts down the connection, see `std::net::TcpStream.shutdown`
207    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
208        self.std().shutdown(how)
209    }
210
211    /// Flushes the connection, see `std::net::TcpStream.flush`
212    fn flush(&mut self) -> io::Result<()> {
213        self.std_mut().flush()
214    }
215
216    /// Creates a new `fuchsia_async::net::TcpStream` from a `std::net::TcpStream`.
217    fn from_std(stream: net::TcpStream) -> io::Result<TcpStream> {
218        let stream: socket2::Socket = stream.into();
219        let () = stream.set_nonblocking(true)?;
220        let stream = stream.into();
221        // This is safe because the file descriptor for stream will live as long as the TcpStream.
222        let stream = unsafe { EventedFd::new(stream)? };
223        Ok(TcpStream { stream })
224    }
225
226    /// Returns a reference to the underlying `std::net::TcpStream`
227    pub fn std(&self) -> &net::TcpStream {
228        self.as_ref()
229    }
230
231    fn std_mut(&mut self) -> &mut net::TcpStream {
232        self.stream.as_mut()
233    }
234}
235
236impl AsRawFd for TcpStream {
237    fn as_raw_fd(&self) -> RawFd {
238        self.stream.as_raw_fd()
239    }
240}
241
242impl AsyncRead for TcpStream {
243    fn poll_read(
244        mut self: Pin<&mut Self>,
245        cx: &mut Context<'_>,
246        buf: &mut [u8],
247    ) -> Poll<io::Result<usize>> {
248        Pin::new(&mut self.stream).poll_read(cx, buf)
249    }
250
251    // TODO: override poll_vectored_read and call readv on the underlying stream
252}
253
254impl AsyncWrite for TcpStream {
255    fn poll_write(
256        mut self: Pin<&mut Self>,
257        cx: &mut Context<'_>,
258        buf: &[u8],
259    ) -> Poll<io::Result<usize>> {
260        Pin::new(&mut self.stream).poll_write(cx, buf)
261    }
262
263    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
264        match self.get_mut().flush() {
265            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
266            Err(e) => Poll::Ready(Err(e)),
267            Ok(()) => Poll::Ready(Ok(())),
268        }
269    }
270
271    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
272        Poll::Ready(self.as_ref().shutdown(Shutdown::Write))
273    }
274
275    // TODO: override poll_vectored_write and call writev on the underlying stream
276}
277
278/// A future which resolves to a connected `TcpStream`.
279#[derive(Debug)]
280pub struct TcpConnector {
281    // The stream needs to have `need_write` called on it to defeat the optimization in
282    // EventedFd::new which assumes that the operand is immediately readable and writable.
283    need_write: bool,
284    stream: Option<TcpStream>,
285}
286
287impl Future for TcpConnector {
288    type Output = io::Result<TcpStream>;
289
290    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
291        let this = &mut *self;
292        {
293            let stream = this.stream.as_mut().expect("polled a TcpConnector after completion");
294            if this.need_write {
295                this.need_write = false;
296                stream.need_write(cx);
297                return Poll::Pending;
298            }
299            let () = ready!(stream.poll_writable(cx)?);
300            let () = match stream.as_ref().take_error() {
301                Ok(None) => Ok(()),
302                Ok(Some(err)) | Err(err) => Err(err),
303            }?;
304        }
305        let stream = this.stream.take().unwrap();
306        Poll::Ready(Ok(stream))
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::{TcpListener, TcpStream};
    use crate::TestExecutor;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures::stream::StreamExt;
    use std::io::{Error, ErrorKind};
    use std::net::{self, Ipv4Addr, SocketAddr};

    // Binding to port 0 must yield a listener with a concrete, non-zero port.
    #[test]
    fn choose_listen_port() {
        let _exec = TestExecutor::new();
        let requested = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&requested).expect("could not create listener");
        let bound = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(bound.ip(), requested.ip());
        assert_ne!(bound.port(), 0);
    }

    // Same as `choose_listen_port`, but going through a std listener first.
    #[test]
    fn choose_listen_port_from_std() {
        let _exec = TestExecutor::new();
        let requested = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let std_listener =
            net::TcpListener::bind(requested).expect("could not create inner listener");
        let listener = TcpListener::from_std(std_listener).expect("could not create listener");
        let bound = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(bound.ip(), requested.ip());
        assert_ne!(bound.port(), 0);
    }

    // Connecting to a bound-but-not-listening port must resolve to an error.
    #[test]
    fn connect_to_nonlistening_endpoint() {
        let mut exec = TestExecutor::new();

        // bind to a port to find an unused one, but don't start listening.
        let wildcard_port = socket2::SockAddr::from(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0));
        let socket = socket2::Socket::new(
            socket2::Domain::IPV4,
            socket2::Type::STREAM,
            Some(socket2::Protocol::TCP),
        )
        .expect("could not create socket");
        socket.bind(&wildcard_port).expect("could not bind");
        let target = socket
            .local_addr()
            .expect("local addr query to succeed")
            .as_socket()
            .expect("local addr to be ipv4 or ipv6");

        // connecting to the nonlistening port should fail.
        let connector = TcpStream::connect(target).expect("could not create client");
        let fut = async move {
            let outcome = connector.await;
            assert!(outcome.is_err());
            Ok::<(), Error>(())
        };

        exec.run_singlethreaded(fut).expect("failed to run tcp socket test");
    }

    // A small request/response exchange, followed by the server observing EOF.
    #[test]
    fn send_recv() {
        let mut exec = TestExecutor::new();

        let any_port = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&any_port).expect("could not create listener");
        let server_addr = listener.local_addr().expect("local_addr query to succeed");
        let mut incoming = listener.accept_stream();

        let query = b"ping";
        let response = b"pong";

        let server = async move {
            let (mut conn, _clientaddr) =
                incoming.next().await.expect("stream to not be done").expect("client to connect");
            drop(incoming);

            let mut scratch = [0u8; 20];
            let n = conn.read(&mut scratch[..]).await.expect("server read to succeed");
            assert_eq!(query, &scratch[..n]);

            conn.write_all(&response[..]).await.expect("server write to succeed");

            // The client drops its socket after reading, so filling the buffer must hit EOF.
            let err = conn.read_exact(&mut scratch[..]).await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
        };

        let client = async move {
            let connector = TcpStream::connect(server_addr).expect("could not create client");
            let mut conn = connector.await.expect("client to connect to server");

            conn.write_all(&query[..]).await.expect("client write to succeed");

            let mut scratch = [0u8; 20];
            let n = conn.read(&mut scratch[..]).await.expect("client read to succeed");
            assert_eq!(response, &scratch[..n]);
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }

    // Streams a large all-zero payload from server to client and checks every byte.
    #[test]
    fn send_recv_large() {
        let mut exec = TestExecutor::new();
        let any_port: SocketAddr = "127.0.0.1:0".parse().unwrap();

        const BUF_SIZE: usize = 10 * 1024;
        const WRITES: usize = 1024;
        const LENGTH: usize = WRITES * BUF_SIZE;

        let listener = TcpListener::bind(&any_port).expect("could not create listener");
        let server_addr = listener.local_addr().expect("query local_addr");
        let mut incoming = listener.accept_stream();

        let server = async move {
            let (mut conn, _clientaddr) =
                incoming.next().await.expect("stream to not be done").expect("client to connect");
            drop(incoming);

            let payload = [0u8; BUF_SIZE];
            for _ in 0usize..WRITES {
                conn.write_all(&payload[..]).await.expect("server write to succeed");
            }
        };

        let client = async move {
            let connector = TcpStream::connect(server_addr).expect("could not create client");
            let mut conn = connector.await.expect("client to connect to server");

            let expected = Box::new([0u8; BUF_SIZE]);
            let mut received = 0;
            while received < LENGTH {
                // Start from a non-zero fill so stale bytes cannot pass the comparison.
                let mut chunk = Box::new([1u8; BUF_SIZE]);
                let n = conn.read(&mut chunk[..]).await.expect("client read to succeed");
                assert_eq!(&chunk[0..n], &expected[0..n]);
                received += n;
            }
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }
}
448}