fuchsia_async/net/fuchsia/
tcp.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![deny(missing_docs)]
6
7use crate::net::EventedFd;
8use futures::future::Future;
9use futures::io::{AsyncRead, AsyncWrite};
10use futures::ready;
11use futures::stream::Stream;
12use futures::task::{Context, Poll};
13use std::io::{self, Write};
14use std::net::{self, Shutdown, SocketAddr};
15use std::ops::Deref;
16use std::os::unix::io::FromRawFd as _;
17use std::pin::Pin;
18
/// An I/O object representing a TCP socket listening for incoming connections.
///
/// This object can be converted into a stream of incoming connections for
/// various forms of processing.
///
/// The wrapped `std::net::TcpListener` is switched to non-blocking mode by
/// `TcpListener::from_std` (which `TcpListener::bind` also goes through) and is
/// held inside an `EventedFd`. Use `TcpListener::accept` to accept a single
/// connection, or `TcpListener::accept_stream` for a stream of connections.
#[derive(Debug)]
pub struct TcpListener(EventedFd<net::TcpListener>);
25
// Explicitly mark the listener `Unpin` so that wrappers which reach it through
// `Pin<&mut Self>` (`Acceptor`, `AcceptStream`) are `Unpin` as well and can use
// plain `DerefMut` access.
// NOTE(review): this may be redundant if `EventedFd` is already `Unpin` — confirm.
impl Unpin for TcpListener {}
27
// Exposes the underlying `EventedFd`, which provides readiness polling and (via
// `as_ref`) access to the wrapped `std::net::TcpListener`.
impl Deref for TcpListener {
    type Target = EventedFd<net::TcpListener>;

    /// Returns the `EventedFd` wrapping the listening socket.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
35
36impl TcpListener {
37    /// Creates a new `TcpListener` bound to the provided socket.
38    pub fn bind(addr: &SocketAddr) -> io::Result<TcpListener> {
39        let domain = match *addr {
40            SocketAddr::V4(..) => socket2::Domain::IPV4,
41            SocketAddr::V6(..) => socket2::Domain::IPV6,
42        };
43        let socket =
44            socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
45        // Allow this socket to be rebound while it is in TIME_WAIT.
46        //
47        // This is borrowed from std::net::TcpListener::bind. See
48        // https://github.com/rust-lang/rust/blob/db492ec/library/std/src/sys_common/net.rs#L371-L379.
49        let () = socket.set_reuse_address(true)?;
50        let addr = (*addr).into();
51        let () = socket.bind(&addr)?;
52        let () = socket.listen(1024)?;
53        TcpListener::from_std(socket.into())
54    }
55
56    /// Consumes this listener and returns a `Future` that resolves to an
57    /// `io::Result<(TcpListener, TcpStream, SocketAddr)>`.
58    pub fn accept(self) -> Acceptor {
59        Acceptor(Some(self))
60    }
61
62    /// Consumes this listener and returns a `Stream` that resolves to elements
63    /// of type `io::Result<(TcpStream, SocketAddr)>`.
64    pub fn accept_stream(self) -> AcceptStream {
65        AcceptStream(self)
66    }
67
68    /// Poll on `accept`ing a new `TcpStream` from this listener.
69    /// This function is mainly intended for usage in manual `Future` or `Stream`
70    /// implementations.
71    pub fn async_accept(
72        &mut self,
73        cx: &mut Context<'_>,
74    ) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
75        ready!(EventedFd::poll_readable(&self.0, cx))?;
76        match self.0.as_ref().accept() {
77            Err(e) => {
78                if e.kind() == io::ErrorKind::WouldBlock {
79                    self.0.need_read(cx);
80                    Poll::Pending
81                } else {
82                    Poll::Ready(Err(e))
83                }
84            }
85            Ok((sock, addr)) => Poll::Ready(Ok((TcpStream::from_std(sock)?, addr))),
86        }
87    }
88
89    /// Creates a new instance of `fuchsia_async::net::TcpListener` from an
90    /// `std::net::TcpListener`.
91    pub fn from_std(listener: net::TcpListener) -> io::Result<TcpListener> {
92        let listener: socket2::Socket = listener.into();
93        let () = listener.set_nonblocking(true)?;
94        let listener = listener.into();
95        let listener = unsafe { EventedFd::new(listener)? };
96        Ok(TcpListener(listener))
97    }
98
99    /// Returns a reference to the underlying `std::net::TcpListener`.
100    pub fn std(&self) -> &net::TcpListener {
101        self.as_ref()
102    }
103
104    /// Returns the local socket address of the listener.
105    pub fn local_addr(&self) -> io::Result<net::SocketAddr> {
106        self.std().local_addr()
107    }
108}
109
/// A future which resolves to an `io::Result<(TcpListener, TcpStream, SocketAddr)>`.
///
/// Created by `TcpListener::accept`. The listener is held until a connection
/// is accepted and is then returned to the caller along with it. Polling again
/// after a successful resolution panics.
#[derive(Debug)]
pub struct Acceptor(Option<TcpListener>);
113
114impl Future for Acceptor {
115    type Output = io::Result<(TcpListener, TcpStream, SocketAddr)>;
116
117    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
118        let (stream, addr);
119        {
120            let listener = self.0.as_mut().expect("polled an Acceptor after completion");
121            let (s, a) = ready!(listener.async_accept(cx))?;
122            stream = s;
123            addr = a;
124        }
125        let listener = self.0.take().unwrap();
126        Poll::Ready(Ok((listener, stream, addr)))
127    }
128}
129
/// A stream which resolves to an `io::Result<(TcpStream, SocketAddr)>`.
///
/// Created by `TcpListener::accept_stream`. The stream never terminates (it never
/// yields `None`); accept failures are yielded as `Some(Err(_))`.
#[derive(Debug)]
pub struct AcceptStream(TcpListener);
133
134impl Stream for AcceptStream {
135    type Item = io::Result<(TcpStream, SocketAddr)>;
136
137    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138        let (stream, addr) = ready!(self.0.async_accept(cx)?);
139        Poll::Ready(Some(Ok((stream, addr))))
140    }
141}
142
/// A single TCP connection.
///
/// This type and references to it implement the `AsyncRead` and `AsyncWrite`
/// traits. For more on using this type, see the `AsyncReadExt` and `AsyncWriteExt`
/// traits.
///
/// The wrapped `std::net::TcpStream` is always put into non-blocking mode before
/// it is stored here.
#[derive(Debug)]
pub struct TcpStream {
    // The non-blocking socket, wrapped for readiness-driven I/O.
    stream: EventedFd<net::TcpStream>,
}
152
// Exposes the underlying `EventedFd`, which provides readiness polling
// (`poll_writable`, `need_write`, ...) and, via `as_ref`, the wrapped
// `std::net::TcpStream`.
impl Deref for TcpStream {
    type Target = EventedFd<net::TcpStream>;

    /// Returns the `EventedFd` wrapping the connected socket.
    fn deref(&self) -> &Self::Target {
        &self.stream
    }
}
160
161impl TcpStream {
162    /// Creates a new `TcpStream` connected to a specific socket address from an existing socket
163    /// descriptor.
164    /// This function returns a future which resolves to an `io::Result<TcpStream>`.
165    pub fn connect_from_raw(
166        socket: impl std::os::unix::io::IntoRawFd,
167        addr: SocketAddr,
168    ) -> io::Result<TcpConnector> {
169        // This is safe because `into_raw_fd()` consumes ownership of the socket, so we are
170        // guaranteed that the returned value is not shared among more than one owner at this
171        // point.
172        let socket = unsafe { socket2::Socket::from_raw_fd(socket.into_raw_fd()) };
173        Self::from_socket2(socket, addr)
174    }
175
176    /// Creates a new `TcpStream` connected to a specific socket address.
177    ///
178    /// This function returns a future which resolves to an `io::Result<TcpStream>`.
179    pub fn connect(addr: SocketAddr) -> io::Result<TcpConnector> {
180        let domain = match addr {
181            SocketAddr::V4(..) => socket2::Domain::IPV4,
182            SocketAddr::V6(..) => socket2::Domain::IPV6,
183        };
184        let socket =
185            socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
186        Self::from_socket2(socket, addr)
187    }
188
189    // This function is intentionally kept private to avoid socket2 appearing in the public API.
190    fn from_socket2(socket: socket2::Socket, addr: SocketAddr) -> io::Result<TcpConnector> {
191        let () = socket.set_nonblocking(true)?;
192        let addr = addr.into();
193        let () = match socket.connect(&addr) {
194            Err(e) if e.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
195            res => res,
196        }?;
197        let stream = socket.into();
198        // This is safe because the file descriptor for stream will live as long as the TcpStream.
199        let stream = unsafe { EventedFd::new(stream)? };
200        let stream = Some(TcpStream { stream });
201
202        Ok(TcpConnector { need_write: true, stream })
203    }
204
205    /// Shuts down the connection, see `std::net::TcpStream.shutdown`
206    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
207        self.std().shutdown(how)
208    }
209
210    /// Flushes the connection, see `std::net::TcpStream.flush`
211    fn flush(&mut self) -> io::Result<()> {
212        self.std_mut().flush()
213    }
214
215    /// Creates a new `fuchsia_async::net::TcpStream` from a `std::net::TcpStream`.
216    fn from_std(stream: net::TcpStream) -> io::Result<TcpStream> {
217        let stream: socket2::Socket = stream.into();
218        let () = stream.set_nonblocking(true)?;
219        let stream = stream.into();
220        // This is safe because the file descriptor for stream will live as long as the TcpStream.
221        let stream = unsafe { EventedFd::new(stream)? };
222        Ok(TcpStream { stream })
223    }
224
225    /// Returns a reference to the underlying `std::net::TcpStream`
226    pub fn std(&self) -> &net::TcpStream {
227        self.as_ref()
228    }
229
230    fn std_mut<'a>(&'a mut self) -> &'a mut net::TcpStream {
231        self.stream.as_mut()
232    }
233}
234
235impl AsyncRead for TcpStream {
236    fn poll_read(
237        mut self: Pin<&mut Self>,
238        cx: &mut Context<'_>,
239        buf: &mut [u8],
240    ) -> Poll<io::Result<usize>> {
241        Pin::new(&mut self.stream).poll_read(cx, buf)
242    }
243
244    // TODO: override poll_vectored_read and call readv on the underlying stream
245}
246
247impl AsyncWrite for TcpStream {
248    fn poll_write(
249        mut self: Pin<&mut Self>,
250        cx: &mut Context<'_>,
251        buf: &[u8],
252    ) -> Poll<io::Result<usize>> {
253        Pin::new(&mut self.stream).poll_write(cx, buf)
254    }
255
256    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
257        match self.get_mut().flush() {
258            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
259            Err(e) => Poll::Ready(Err(e)),
260            Ok(()) => Poll::Ready(Ok(())),
261        }
262    }
263
264    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
265        Poll::Ready(self.as_ref().shutdown(Shutdown::Write))
266    }
267
268    // TODO: override poll_vectored_write and call writev on the underlying stream
269}
270
/// A future which resolves to a connected `TcpStream`.
///
/// Created by `TcpStream::connect` and `TcpStream::connect_from_raw`. It resolves
/// once the socket becomes writable and the socket reports no pending error.
#[derive(Debug)]
pub struct TcpConnector {
    // The stream needs to have `need_write` called on it to defeat the optimization in
    // EventedFd::new which assumes that the operand is immediately readable and writable.
    need_write: bool,
    // `Some` until the connection succeeds; taken when the future resolves to `Ok`.
    stream: Option<TcpStream>,
}
279
280impl Future for TcpConnector {
281    type Output = io::Result<TcpStream>;
282
283    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
284        let this = &mut *self;
285        {
286            let stream = this.stream.as_mut().expect("polled a TcpConnector after completion");
287            if this.need_write {
288                this.need_write = false;
289                stream.need_write(cx);
290                return Poll::Pending;
291            }
292            let () = ready!(stream.poll_writable(cx)?);
293            let () = match stream.as_ref().take_error() {
294                Ok(None) => Ok(()),
295                Ok(Some(err)) | Err(err) => Err(err),
296            }?;
297        }
298        let stream = this.stream.take().unwrap();
299        Poll::Ready(Ok(stream))
300    }
301}
302
// Loopback tests for the TCP listener and stream wrappers.
#[cfg(test)]
mod tests {
    use super::{TcpListener, TcpStream};
    use crate::TestExecutor;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures::stream::StreamExt;
    use std::io::{Error, ErrorKind};
    use std::net::{self, Ipv4Addr, SocketAddr};

    // Binding to port 0 must yield the requested IP and a system-chosen,
    // non-zero port.
    #[test]
    fn choose_listen_port() {
        // Kept alive for the whole test; presumably EventedFd needs a live
        // executor — TODO confirm.
        let _exec = TestExecutor::new();
        let addr_request = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr_request).expect("could not create listener");
        let actual_addr = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(actual_addr.ip(), addr_request.ip());
        assert_ne!(actual_addr.port(), 0);
    }

    // Same as `choose_listen_port`, but wraps an already-bound std listener via
    // `TcpListener::from_std`.
    #[test]
    fn choose_listen_port_from_std() {
        let _exec = TestExecutor::new();
        let addr_request = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let inner = net::TcpListener::bind(&addr_request).expect("could not create inner listener");
        let listener = TcpListener::from_std(inner).expect("could not create listener");
        let actual_addr = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(actual_addr.ip(), addr_request.ip());
        assert_ne!(actual_addr.port(), 0);
    }

    // A connect to a bound-but-not-listening port must resolve to an error.
    #[test]
    fn connect_to_nonlistening_endpoint() {
        let mut exec = TestExecutor::new();

        // bind to a port to find an unused one, but don't start listening.
        // `socket` stays in scope until the end of the test, so the port remains
        // reserved while the client tries to connect.
        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0).into();
        let socket = socket2::Socket::new(
            socket2::Domain::IPV4,
            socket2::Type::STREAM,
            Some(socket2::Protocol::TCP),
        )
        .expect("could not create socket");
        let () = socket.bind(&addr).expect("could not bind");
        let addr = socket.local_addr().expect("local addr query to succeed");
        let addr = addr.as_socket().expect("local addr to be ipv4 or ipv6");

        // connecting to the nonlistening port should fail.
        let connector = TcpStream::connect(addr).expect("could not create client");
        let fut = async move {
            let res = connector.await;
            assert!(res.is_err());
            Ok::<(), Error>(())
        };

        exec.run_singlethreaded(fut).expect("failed to run tcp socket test");
    }

    // Round-trips a small request/response and checks that the server sees EOF
    // once the client goes away.
    #[test]
    fn send_recv() {
        let mut exec = TestExecutor::new();

        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr).expect("could not create listener");
        let addr = listener.local_addr().expect("local_addr query to succeed");
        let mut listener = listener.accept_stream();

        let query = b"ping";
        let response = b"pong";

        let server = async move {
            let (mut socket, _clientaddr) =
                listener.next().await.expect("stream to not be done").expect("client to connect");
            // Only one connection is expected; stop listening.
            drop(listener);

            let mut buf = [0u8; 20];
            let n = socket.read(&mut buf[..]).await.expect("server read to succeed");
            assert_eq!(query, &buf[..n]);

            socket.write_all(&response[..]).await.expect("server write to succeed");

            // The client drops its socket when its future completes, so filling
            // the 20-byte buffer must fail with EOF.
            let err = socket.read_exact(&mut buf[..]).await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
        };

        let client = async move {
            let connector = TcpStream::connect(addr).expect("could not create client");
            let mut socket = connector.await.expect("client to connect to server");

            socket.write_all(&query[..]).await.expect("client write to succeed");

            let mut buf = [0u8; 20];
            let n = socket.read(&mut buf[..]).await.expect("client read to succeed");
            assert_eq!(response, &buf[..n]);
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }

    // Streams WRITES * BUF_SIZE zero bytes from server to client and checks that
    // every chunk the client reads is all zeroes.
    #[test]
    fn send_recv_large() {
        let mut exec = TestExecutor::new();
        let addr = "127.0.0.1:0".parse().unwrap();

        const BUF_SIZE: usize = 10 * 1024;
        const WRITES: usize = 1024;
        const LENGTH: usize = WRITES * BUF_SIZE;

        let listener = TcpListener::bind(&addr).expect("could not create listener");
        let addr = listener.local_addr().expect("query local_addr");
        let mut listener = listener.accept_stream();

        let server = async move {
            let (mut socket, _clientaddr) =
                listener.next().await.expect("stream to not be done").expect("client to connect");
            drop(listener);

            let buf = [0u8; BUF_SIZE];
            for _ in 0usize..WRITES {
                socket.write_all(&buf[..]).await.expect("server write to succeed");
            }
        };

        let client = async move {
            let connector = TcpStream::connect(addr).expect("could not create client");
            let mut socket = connector.await.expect("client to connect to server");

            let zeroes = Box::new([0u8; BUF_SIZE]);
            let mut read = 0;
            while read < LENGTH {
                // Pre-filled with 1s so that any byte not overwritten by the read
                // would fail the comparison below.
                let mut buf = Box::new([1u8; BUF_SIZE]);
                let n = socket.read(&mut buf[..]).await.expect("client read to succeed");
                assert_eq!(&buf[0..n], &zeroes[0..n]);
                read += n;
            }
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }
}