Skip to main content

fuchsia_async/net/fuchsia/
tcp.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![deny(missing_docs)]
6
7use crate::net::EventedFd;
8use futures::future::Future;
9use futures::io::{AsyncRead, AsyncWrite};
10use futures::ready;
11use futures::stream::Stream;
12use futures::task::{Context, Poll};
13use std::io::{self, Write};
14use std::net::{self, Shutdown, SocketAddr};
15use std::ops::Deref;
16use std::os::fd::{AsRawFd, RawFd};
17use std::os::unix::io::FromRawFd as _;
18use std::pin::Pin;
19use zx_status_ext::StatusExt;
20
21/// An I/O object representing a TCP socket listening for incoming connections.
22///
23/// This object can be converted into a stream of incoming connections for
24/// various forms of processing.
25#[derive(Debug)]
26pub struct TcpListener(EventedFd<net::TcpListener>);
27
28impl Unpin for TcpListener {}
29
30impl Deref for TcpListener {
31    type Target = EventedFd<net::TcpListener>;
32
33    fn deref(&self) -> &Self::Target {
34        &self.0
35    }
36}
37
38impl TcpListener {
39    /// Creates a new `TcpListener` bound to the provided socket.
40    pub fn bind(addr: &SocketAddr) -> io::Result<TcpListener> {
41        let domain = match *addr {
42            SocketAddr::V4(..) => socket2::Domain::IPV4,
43            SocketAddr::V6(..) => socket2::Domain::IPV6,
44        };
45        let socket =
46            socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
47        // Allow this socket to be rebound while it is in TIME_WAIT.
48        //
49        // This is borrowed from std::net::TcpListener::bind. See
50        // https://github.com/rust-lang/rust/blob/db492ec/library/std/src/sys_common/net.rs#L371-L379.
51        let () = socket.set_reuse_address(true)?;
52        let addr = (*addr).into();
53        let () = socket.bind(&addr)?;
54        let () = socket.listen(1024)?;
55        TcpListener::from_std(socket.into())
56    }
57
58    /// Consumes this listener and returns a `Future` that resolves to an
59    /// `io::Result<(TcpListener, TcpStream, SocketAddr)>`.
60    pub fn accept(self) -> Acceptor {
61        Acceptor(Some(self))
62    }
63
64    /// Consumes this listener and returns a `Stream` that resolves to elements
65    /// of type `io::Result<(TcpStream, SocketAddr)>`.
66    pub fn accept_stream(self) -> AcceptStream {
67        AcceptStream(self)
68    }
69
70    /// Poll on `accept`ing a new `TcpStream` from this listener.
71    /// This function is mainly intended for usage in manual `Future` or `Stream`
72    /// implementations.
73    pub fn async_accept(
74        &mut self,
75        cx: &mut Context<'_>,
76    ) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
77        ready!(EventedFd::poll_readable(&self.0, cx)).map_err(|s| s.into_io_error())?;
78        match self.0.as_ref().accept() {
79            Err(e) => {
80                if e.kind() == io::ErrorKind::WouldBlock {
81                    self.0.need_read(cx);
82                    Poll::Pending
83                } else {
84                    Poll::Ready(Err(e))
85                }
86            }
87            Ok((sock, addr)) => Poll::Ready(Ok((TcpStream::from_std(sock)?, addr))),
88        }
89    }
90
91    /// Creates a new instance of `fuchsia_async::net::TcpListener` from an
92    /// `std::net::TcpListener`.
93    pub fn from_std(listener: net::TcpListener) -> io::Result<TcpListener> {
94        let listener: socket2::Socket = listener.into();
95        let () = listener.set_nonblocking(true)?;
96        let listener = listener.into();
97        let listener = unsafe { EventedFd::new(listener)? };
98        Ok(TcpListener(listener))
99    }
100
101    /// Returns a reference to the underlying `std::net::TcpListener`.
102    pub fn std(&self) -> &net::TcpListener {
103        self.as_ref()
104    }
105
106    /// Returns the local socket address of the listener.
107    pub fn local_addr(&self) -> io::Result<net::SocketAddr> {
108        self.std().local_addr()
109    }
110}
111
112/// A future which resolves to an `io::Result<(TcpListener, TcpStream, SocketAddr)>`.
113#[derive(Debug)]
114pub struct Acceptor(Option<TcpListener>);
115
116impl Future for Acceptor {
117    type Output = io::Result<(TcpListener, TcpStream, SocketAddr)>;
118
119    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
120        let (stream, addr);
121        {
122            let listener = self.0.as_mut().expect("polled an Acceptor after completion");
123            let (s, a) = ready!(listener.async_accept(cx))?;
124            stream = s;
125            addr = a;
126        }
127        let listener = self.0.take().unwrap();
128        Poll::Ready(Ok((listener, stream, addr)))
129    }
130}
131
132/// A stream which resolves to an `io::Result<(TcpStream, SocketAddr)>`.
133#[derive(Debug)]
134pub struct AcceptStream(TcpListener);
135
136impl Stream for AcceptStream {
137    type Item = io::Result<(TcpStream, SocketAddr)>;
138
139    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140        let (stream, addr) = ready!(self.0.async_accept(cx)?);
141        Poll::Ready(Some(Ok((stream, addr))))
142    }
143}
144
145/// A single TCP connection.
146///
147/// This type and references to it implement the `AsyncRead` and `AsyncWrite`
148/// traits. For more on using this type, see the `AsyncReadExt` and `AsyncWriteExt`
149/// traits.
150#[derive(Debug)]
151pub struct TcpStream {
152    stream: EventedFd<net::TcpStream>,
153}
154
155impl Deref for TcpStream {
156    type Target = EventedFd<net::TcpStream>;
157
158    fn deref(&self) -> &Self::Target {
159        &self.stream
160    }
161}
162
163impl TcpStream {
164    /// Creates a new `TcpStream` connected to a specific socket address from an existing socket
165    /// descriptor.
166    /// This function returns a future which resolves to an `io::Result<TcpStream>`.
167    pub fn connect_from_raw(
168        socket: impl std::os::unix::io::IntoRawFd,
169        addr: SocketAddr,
170    ) -> io::Result<TcpConnector> {
171        // This is safe because `into_raw_fd()` consumes ownership of the socket, so we are
172        // guaranteed that the returned value is not shared among more than one owner at this
173        // point.
174        let socket = unsafe { socket2::Socket::from_raw_fd(socket.into_raw_fd()) };
175        Self::from_socket2(socket, addr)
176    }
177
178    /// Creates a new `TcpStream` connected to a specific socket address.
179    ///
180    /// This function returns a future which resolves to an `io::Result<TcpStream>`.
181    pub fn connect(addr: SocketAddr) -> io::Result<TcpConnector> {
182        let domain = match addr {
183            SocketAddr::V4(..) => socket2::Domain::IPV4,
184            SocketAddr::V6(..) => socket2::Domain::IPV6,
185        };
186        let socket =
187            socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
188        Self::from_socket2(socket, addr)
189    }
190
191    // This function is intentionally kept private to avoid socket2 appearing in the public API.
192    fn from_socket2(socket: socket2::Socket, addr: SocketAddr) -> io::Result<TcpConnector> {
193        let () = socket.set_nonblocking(true)?;
194        let addr = addr.into();
195        let () = match socket.connect(&addr) {
196            Err(e) if e.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
197            res => res,
198        }?;
199        let stream = socket.into();
200        // This is safe because the file descriptor for stream will live as long as the TcpStream.
201        let stream = unsafe { EventedFd::new(stream)? };
202        let stream = Some(TcpStream { stream });
203
204        Ok(TcpConnector { need_write: true, stream })
205    }
206
207    /// Shuts down the connection, see `std::net::TcpStream.shutdown`
208    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
209        self.std().shutdown(how)
210    }
211
212    /// Flushes the connection, see `std::net::TcpStream.flush`
213    fn flush(&mut self) -> io::Result<()> {
214        self.std_mut().flush()
215    }
216
217    /// Creates a new `fuchsia_async::net::TcpStream` from a `std::net::TcpStream`.
218    fn from_std(stream: net::TcpStream) -> io::Result<TcpStream> {
219        let stream: socket2::Socket = stream.into();
220        let () = stream.set_nonblocking(true)?;
221        let stream = stream.into();
222        // This is safe because the file descriptor for stream will live as long as the TcpStream.
223        let stream = unsafe { EventedFd::new(stream)? };
224        Ok(TcpStream { stream })
225    }
226
227    /// Returns a reference to the underlying `std::net::TcpStream`
228    pub fn std(&self) -> &net::TcpStream {
229        self.as_ref()
230    }
231
232    fn std_mut(&mut self) -> &mut net::TcpStream {
233        self.stream.as_mut()
234    }
235}
236
237impl AsRawFd for TcpStream {
238    fn as_raw_fd(&self) -> RawFd {
239        self.stream.as_raw_fd()
240    }
241}
242
243impl AsyncRead for TcpStream {
244    fn poll_read(
245        mut self: Pin<&mut Self>,
246        cx: &mut Context<'_>,
247        buf: &mut [u8],
248    ) -> Poll<io::Result<usize>> {
249        Pin::new(&mut self.stream).poll_read(cx, buf)
250    }
251
252    // TODO: override poll_vectored_read and call readv on the underlying stream
253}
254
255impl AsyncWrite for TcpStream {
256    fn poll_write(
257        mut self: Pin<&mut Self>,
258        cx: &mut Context<'_>,
259        buf: &[u8],
260    ) -> Poll<io::Result<usize>> {
261        Pin::new(&mut self.stream).poll_write(cx, buf)
262    }
263
264    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
265        match self.get_mut().flush() {
266            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
267            Err(e) => Poll::Ready(Err(e)),
268            Ok(()) => Poll::Ready(Ok(())),
269        }
270    }
271
272    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
273        Poll::Ready(self.as_ref().shutdown(Shutdown::Write))
274    }
275
276    // TODO: override poll_vectored_write and call writev on the underlying stream
277}
278
279/// A future which resolves to a connected `TcpStream`.
280#[derive(Debug)]
281pub struct TcpConnector {
282    // The stream needs to have `need_write` called on it to defeat the optimization in
283    // EventedFd::new which assumes that the operand is immediately readable and writable.
284    need_write: bool,
285    stream: Option<TcpStream>,
286}
287
288impl Future for TcpConnector {
289    type Output = io::Result<TcpStream>;
290
291    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
292        let this = &mut *self;
293        {
294            let stream = this.stream.as_mut().expect("polled a TcpConnector after completion");
295            if this.need_write {
296                this.need_write = false;
297                stream.need_write(cx);
298                return Poll::Pending;
299            }
300            let () = ready!(stream.poll_writable(cx)).map_err(|s| s.into_io_error())?;
301            let () = match stream.as_ref().take_error() {
302                Ok(None) => Ok(()),
303                Ok(Some(err)) | Err(err) => Err(err),
304            }?;
305        }
306        let stream = this.stream.take().unwrap();
307        Poll::Ready(Ok(stream))
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::{TcpListener, TcpStream};
    use crate::TestExecutor;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures::stream::StreamExt;
    use std::io::{Error, ErrorKind};
    use std::net::{self, Ipv4Addr, SocketAddr};

    /// Binding to port 0 yields a listener on the requested IP with a non-zero port.
    #[test]
    fn choose_listen_port() {
        let _exec = TestExecutor::new();
        let requested = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&requested).expect("could not create listener");
        let bound = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(bound.ip(), requested.ip());
        assert_ne!(bound.port(), 0);
    }

    /// Same check as `choose_listen_port`, but wrapping a `std::net::TcpListener`.
    #[test]
    fn choose_listen_port_from_std() {
        let _exec = TestExecutor::new();
        let requested = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let std_listener =
            net::TcpListener::bind(requested).expect("could not create inner listener");
        let listener = TcpListener::from_std(std_listener).expect("could not create listener");
        let bound = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(bound.ip(), requested.ip());
        assert_ne!(bound.port(), 0);
    }

    /// Connecting to a bound but non-listening port resolves to an error.
    #[test]
    fn connect_to_nonlistening_endpoint() {
        let mut exec = TestExecutor::new();

        // Bind a socket to discover an unused port, but never start listening on it.
        let any_port = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0).into();
        let placeholder = socket2::Socket::new(
            socket2::Domain::IPV4,
            socket2::Type::STREAM,
            Some(socket2::Protocol::TCP),
        )
        .expect("could not create socket");
        placeholder.bind(&any_port).expect("could not bind");
        let target = placeholder
            .local_addr()
            .expect("local addr query to succeed")
            .as_socket()
            .expect("local addr to be ipv4 or ipv6");

        // Nothing listens on that port, so the connection attempt has to fail.
        let connector = TcpStream::connect(target).expect("could not create client");
        let attempt = async move {
            let outcome = connector.await;
            assert!(outcome.is_err());
            Ok::<(), Error>(())
        };

        exec.run_singlethreaded(attempt).expect("failed to run tcp socket test");
    }

    /// A client and a server exchange one short message in each direction.
    #[test]
    fn send_recv() {
        let mut exec = TestExecutor::new();

        let any_port = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&any_port).expect("could not create listener");
        let server_addr = listener.local_addr().expect("local_addr query to succeed");
        let mut incoming = listener.accept_stream();

        let ping = b"ping";
        let pong = b"pong";

        let server = async move {
            let (mut conn, _peer) =
                incoming.next().await.expect("stream to not be done").expect("client to connect");
            drop(incoming);

            let mut scratch = [0u8; 20];
            let len = conn.read(&mut scratch[..]).await.expect("server read to succeed");
            assert_eq!(ping, &scratch[..len]);

            conn.write_all(&pong[..]).await.expect("server write to succeed");

            // Once the client has gone away, no further full buffer can arrive.
            let eof = conn.read_exact(&mut scratch[..]).await.unwrap_err();
            assert_eq!(eof.kind(), ErrorKind::UnexpectedEof);
        };

        let client = async move {
            let connector = TcpStream::connect(server_addr).expect("could not create client");
            let mut conn = connector.await.expect("client to connect to server");

            conn.write_all(&ping[..]).await.expect("client write to succeed");

            let mut scratch = [0u8; 20];
            let len = conn.read(&mut scratch[..]).await.expect("client read to succeed");
            assert_eq!(pong, &scratch[..len]);
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }

    /// The server streams `LENGTH` zero bytes; the client checks every byte it reads.
    #[test]
    fn send_recv_large() {
        let mut exec = TestExecutor::new();
        let any_port: SocketAddr = "127.0.0.1:0".parse().unwrap();

        const BUF_SIZE: usize = 10 * 1024;
        const WRITES: usize = 1024;
        const LENGTH: usize = WRITES * BUF_SIZE;

        let listener = TcpListener::bind(&any_port).expect("could not create listener");
        let server_addr = listener.local_addr().expect("query local_addr");
        let mut incoming = listener.accept_stream();

        let server = async move {
            let (mut conn, _peer) =
                incoming.next().await.expect("stream to not be done").expect("client to connect");
            drop(incoming);

            let payload = [0u8; BUF_SIZE];
            for _ in 0..WRITES {
                conn.write_all(&payload[..]).await.expect("server write to succeed");
            }
        };

        let client = async move {
            let connector = TcpStream::connect(server_addr).expect("could not create client");
            let mut conn = connector.await.expect("client to connect to server");

            let expected = Box::new([0u8; BUF_SIZE]);
            let mut received = 0;
            while received < LENGTH {
                // Pre-fill with ones so stale buffer contents cannot pass for data.
                let mut chunk = Box::new([1u8; BUF_SIZE]);
                let len = conn.read(&mut chunk[..]).await.expect("client read to succeed");
                assert_eq!(&chunk[0..len], &expected[0..len]);
                received += len;
            }
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }
}