// fuchsia_async/net/fuchsia/tcp.rs
#![deny(missing_docs)]
use crate::net::EventedFd;
use futures::future::Future;
use futures::io::{AsyncRead, AsyncWrite};
use futures::ready;
use futures::stream::Stream;
use futures::task::{Context, Poll};
use std::io::{self, Write};
use std::net::{self, Shutdown, SocketAddr};
use std::ops::Deref;
use std::os::unix::io::FromRawFd as _;
use std::pin::Pin;
/// An asynchronous TCP socket listening for incoming connections.
///
/// Wraps a `std::net::TcpListener` in an [`EventedFd`]. Incoming connections
/// can be obtained one at a time via [`TcpListener::accept`], as a stream via
/// [`TcpListener::accept_stream`], or by polling
/// [`TcpListener::async_accept`] directly.
#[derive(Debug)]
pub struct TcpListener(EventedFd<net::TcpListener>);
// Explicit `Unpin` marker. `Acceptor` and `AcceptStream` reach their inner
// listener mutably through a `Pin<&mut Self>`, which requires
// `TcpListener: Unpin`.
impl Unpin for TcpListener {}
impl Deref for TcpListener {
type Target = EventedFd<net::TcpListener>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl TcpListener {
    /// Creates a new `TcpListener` bound to `addr` and starts listening on it.
    ///
    /// The socket is created with address reuse enabled
    /// (`set_reuse_address(true)`) and a listen backlog of 1024. It is then
    /// handed to [`TcpListener::from_std`], which switches it to non-blocking
    /// mode.
    ///
    /// # Errors
    ///
    /// Returns the first error reported while creating, configuring, binding,
    /// listening on, or wrapping the socket.
    pub fn bind(addr: &SocketAddr) -> io::Result<TcpListener> {
        let domain = if addr.is_ipv4() { socket2::Domain::IPV4 } else { socket2::Domain::IPV6 };
        let socket =
            socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
        socket.set_reuse_address(true)?;
        let bind_addr = socket2::SockAddr::from(*addr);
        socket.bind(&bind_addr)?;
        socket.listen(1024)?;
        Self::from_std(socket.into())
    }

    /// Consumes this listener and returns a `Future` that resolves to an
    /// `io::Result<(TcpListener, TcpStream, SocketAddr)>`.
    pub fn accept(self) -> Acceptor {
        Acceptor(Some(self))
    }

    /// Consumes this listener and returns a `Stream` whose items are of type
    /// `io::Result<(TcpStream, SocketAddr)>`.
    pub fn accept_stream(self) -> AcceptStream {
        AcceptStream(self)
    }

    /// Polls for a single incoming connection.
    ///
    /// First waits until the underlying descriptor is reported readable, then
    /// attempts a non-blocking `accept`:
    ///
    /// * on success, the accepted socket is wrapped in a [`TcpStream`] and
    ///   returned together with the address reported by `accept`;
    /// * on `WouldBlock`, read interest is re-registered for `cx` and
    ///   `Poll::Pending` is returned;
    /// * any other error is returned as `Poll::Ready(Err(_))`.
    pub fn async_accept(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
        ready!(EventedFd::poll_readable(&self.0, cx))?;
        match self.0.as_ref().accept() {
            Ok((sock, peer)) => Poll::Ready(Ok((TcpStream::from_std(sock)?, peer))),
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                // Spurious readiness: ask to be woken on the next readable event.
                self.0.need_read(cx);
                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }

    /// Creates a new `TcpListener` from a `std::net::TcpListener`.
    ///
    /// The listener is put into non-blocking mode before being wrapped in an
    /// [`EventedFd`].
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be made non-blocking or if
    /// `EventedFd::new` fails.
    pub fn from_std(listener: net::TcpListener) -> io::Result<TcpListener> {
        let socket = socket2::Socket::from(listener);
        socket.set_nonblocking(true)?;
        let listener = net::TcpListener::from(socket);
        // SAFETY: NOTE(review) — the contract of `EventedFd::new` is not
        // visible in this file. `listener` is an owned socket that is moved
        // into the `EventedFd`; confirm this satisfies its requirements.
        let evented = unsafe { EventedFd::new(listener) }?;
        Ok(TcpListener(evented))
    }

    /// Returns a reference to the underlying `std::net::TcpListener`.
    pub fn std(&self) -> &net::TcpListener {
        self.0.as_ref()
    }

    /// Returns the local socket address this listener is bound to.
    pub fn local_addr(&self) -> io::Result<net::SocketAddr> {
        self.std().local_addr()
    }
}
/// A future that resolves once a single incoming connection has been accepted.
///
/// Created by [`TcpListener::accept`]. On success it yields the listener back
/// together with the accepted [`TcpStream`] and the address returned by
/// `accept`. The inner `Option` is emptied on successful completion; polling
/// again afterwards panics.
#[derive(Debug)]
pub struct Acceptor(Option<TcpListener>);
impl Future for Acceptor {
    type Output = io::Result<(TcpListener, TcpStream, SocketAddr)>;

    /// Drives one `accept` on the owned listener.
    ///
    /// The listener is only moved out of `self` once a connection has been
    /// accepted. On `Pending` or on error it stays in place.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has successfully completed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let (stream, addr) = {
            let listener = this.0.as_mut().expect("polled an Acceptor after completion");
            ready!(listener.async_accept(cx))?
        };
        // The accept succeeded, so hand the listener back to the caller.
        let listener = this.0.take().unwrap();
        Poll::Ready(Ok((listener, stream, addr)))
    }
}
/// A stream of incoming connections on a [`TcpListener`].
///
/// Created by [`TcpListener::accept_stream`]. Each item is the result of one
/// call to [`TcpListener::async_accept`] on the owned listener.
#[derive(Debug)]
pub struct AcceptStream(TcpListener);
impl Stream for AcceptStream {
type Item = io::Result<(TcpStream, SocketAddr)>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let (stream, addr) = ready!(self.0.async_accept(cx)?);
Poll::Ready(Some(Ok((stream, addr))))
}
}
/// An asynchronous TCP connection.
///
/// Wraps a `std::net::TcpStream` in an [`EventedFd`] and implements
/// `AsyncRead` and `AsyncWrite` by delegating to it. Instances are obtained
/// by awaiting a [`TcpConnector`] or by accepting on a [`TcpListener`].
#[derive(Debug)]
pub struct TcpStream {
/// The wrapped standard-library stream.
stream: EventedFd<net::TcpStream>,
}
impl Deref for TcpStream {
type Target = EventedFd<net::TcpStream>;
fn deref(&self) -> &Self::Target {
&self.stream
}
}
impl TcpStream {
    /// Starts connecting an existing raw socket to `addr`.
    ///
    /// Takes ownership of the file descriptor behind `socket` and returns a
    /// [`TcpConnector`] that resolves to the connected stream.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be configured or if the
    /// connection attempt fails immediately.
    pub fn connect_from_raw(
        socket: impl std::os::unix::io::IntoRawFd,
        addr: SocketAddr,
    ) -> io::Result<TcpConnector> {
        let raw_fd = socket.into_raw_fd();
        // SAFETY: `into_raw_fd` gave up ownership of the descriptor, so the
        // new `socket2::Socket` becomes its only owner.
        // NOTE(review): nothing here verifies that the descriptor actually
        // refers to a socket; that is left to the caller.
        let socket = unsafe { socket2::Socket::from_raw_fd(raw_fd) };
        Self::from_socket2(socket, addr)
    }

    /// Creates a new TCP socket and starts connecting it to `addr`.
    ///
    /// Returns a [`TcpConnector`] that resolves to the connected stream.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be created or configured, or if
    /// the connection attempt fails immediately.
    pub fn connect(addr: SocketAddr) -> io::Result<TcpConnector> {
        let domain = if addr.is_ipv4() { socket2::Domain::IPV4 } else { socket2::Domain::IPV6 };
        let socket =
            socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
        Self::from_socket2(socket, addr)
    }

    /// Puts `socket` into non-blocking mode, initiates the connection to
    /// `addr`, and wraps the socket in a [`TcpConnector`].
    fn from_socket2(socket: socket2::Socket, addr: SocketAddr) -> io::Result<TcpConnector> {
        socket.set_nonblocking(true)?;
        let target = socket2::SockAddr::from(addr);
        if let Err(e) = socket.connect(&target) {
            // `EINPROGRESS` is the expected result of a non-blocking connect
            // that could not complete at once. Everything else is fatal.
            if e.raw_os_error() != Some(libc::EINPROGRESS) {
                return Err(e);
            }
        }
        // SAFETY: NOTE(review) — the contract of `EventedFd::new` is not
        // visible in this file. The socket is owned and moved into the
        // `EventedFd`; confirm this satisfies its requirements.
        let stream = unsafe { EventedFd::new(net::TcpStream::from(socket)) }?;
        Ok(TcpConnector { need_write: true, stream: Some(TcpStream { stream }) })
    }

    /// Shuts down the read half, the write half, or both halves of this
    /// connection, as selected by `how`.
    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        self.std().shutdown(how)
    }

    /// Flushes the underlying `std::net::TcpStream`.
    fn flush(&mut self) -> io::Result<()> {
        self.std_mut().flush()
    }

    /// Wraps a `std::net::TcpStream`, switching it to non-blocking mode first.
    fn from_std(stream: net::TcpStream) -> io::Result<TcpStream> {
        let socket = socket2::Socket::from(stream);
        socket.set_nonblocking(true)?;
        let stream = net::TcpStream::from(socket);
        // SAFETY: NOTE(review) — the contract of `EventedFd::new` is not
        // visible in this file. `stream` is an owned socket that is moved
        // into the `EventedFd`; confirm this satisfies its requirements.
        let stream = unsafe { EventedFd::new(stream) }?;
        Ok(TcpStream { stream })
    }

    /// Returns a reference to the underlying `std::net::TcpStream`.
    pub fn std(&self) -> &net::TcpStream {
        self.stream.as_ref()
    }

    /// Returns a mutable reference to the underlying `std::net::TcpStream`.
    fn std_mut(&mut self) -> &mut net::TcpStream {
        self.stream.as_mut()
    }
}
impl AsyncRead for TcpStream {
    /// Reads into `buf` by delegating to the wrapped `EventedFd`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let inner = &mut self.get_mut().stream;
        Pin::new(inner).poll_read(cx, buf)
    }
}
impl AsyncWrite for TcpStream {
    /// Writes from `buf` by delegating to the wrapped `EventedFd`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.stream).poll_write(cx, buf)
    }

    /// Flushes the underlying stream.
    ///
    /// If the flush reports `WouldBlock`, write interest is registered for
    /// `cx` before returning `Poll::Pending`. `AsyncWrite` requires a wake-up
    /// to be arranged whenever `Pending` is returned; without it the task
    /// could be parked forever.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match this.flush() {
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                // Ask to be woken when the socket becomes writable again.
                this.need_write(cx);
                Poll::Pending
            }
            res => Poll::Ready(res),
        }
    }

    /// Closes the write half of the connection via `shutdown(Shutdown::Write)`.
    ///
    /// Always completes immediately with the result of the shutdown call.
    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(self.as_ref().shutdown(Shutdown::Write))
    }
}
/// A future that resolves to a connected [`TcpStream`].
///
/// Returned by [`TcpStream::connect`] and [`TcpStream::connect_from_raw`]
/// after the non-blocking connect has been initiated.
#[derive(Debug)]
pub struct TcpConnector {
/// `true` until the first poll, which registers write interest on the
/// socket and returns `Pending` without checking the connection state.
need_write: bool,
/// The stream being connected. Taken out (`None`) once the future has
/// completed successfully.
stream: Option<TcpStream>,
}
impl Future for TcpConnector {
type Output = io::Result<TcpStream>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
{
let stream = this.stream.as_mut().expect("polled a TcpConnector after completion");
if this.need_write {
this.need_write = false;
stream.need_write(cx);
return Poll::Pending;
}
let () = ready!(stream.poll_writable(cx)?);
let () = match stream.as_ref().take_error() {
Ok(None) => Ok(()),
Ok(Some(err)) | Err(err) => Err(err),
}?;
}
let stream = this.stream.take().unwrap();
Poll::Ready(Ok(stream))
}
}
#[cfg(test)]
mod tests {
    use super::{TcpListener, TcpStream};
    use crate::TestExecutor;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures::stream::StreamExt;
    use std::io::{Error, ErrorKind};
    use std::net::{self, Ipv4Addr, SocketAddr};

    /// Binding to port 0 yields a listener on the requested IP with a
    /// concrete, non-zero port.
    #[test]
    fn choose_listen_port() {
        let _exec = TestExecutor::new();
        let requested = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&requested).expect("could not create listener");
        let bound = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(bound.ip(), requested.ip());
        assert_ne!(bound.port(), 0);
    }

    /// Same as `choose_listen_port`, but the listener is created by the
    /// standard library and then adopted via `TcpListener::from_std`.
    #[test]
    fn choose_listen_port_from_std() {
        let _exec = TestExecutor::new();
        let requested = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let std_listener =
            net::TcpListener::bind(&requested).expect("could not create inner listener");
        let listener = TcpListener::from_std(std_listener).expect("could not create listener");
        let bound = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(bound.ip(), requested.ip());
        assert_ne!(bound.port(), 0);
    }

    /// Connecting to a socket that is bound but not listening must fail.
    #[test]
    fn connect_to_nonlistening_endpoint() {
        let mut exec = TestExecutor::new();

        // Bind a socket without ever calling `listen`, so a concrete local
        // address exists that nobody accepts connections on.
        let socket = socket2::Socket::new(
            socket2::Domain::IPV4,
            socket2::Type::STREAM,
            Some(socket2::Protocol::TCP),
        )
        .expect("could not create socket");
        let any_port: socket2::SockAddr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0).into();
        socket.bind(&any_port).expect("could not bind");
        let target = socket
            .local_addr()
            .expect("local addr query to succeed")
            .as_socket()
            .expect("local addr to be ipv4 or ipv6");

        let connector = TcpStream::connect(target).expect("could not create client");
        let attempt = async move {
            assert!(connector.await.is_err());
            Ok::<(), Error>(())
        };

        exec.run_singlethreaded(attempt).expect("failed to run tcp socket test");
    }

    /// A client and a server exchange a short query and response over loopback.
    #[test]
    fn send_recv() {
        let mut exec = TestExecutor::new();

        let any_port = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&any_port).expect("could not create listener");
        let server_addr = listener.local_addr().expect("local_addr query to succeed");
        let mut incoming = listener.accept_stream();

        let query = b"ping";
        let response = b"pong";

        let server = async move {
            let (mut socket, _clientaddr) =
                incoming.next().await.expect("stream to not be done").expect("client to connect");
            drop(incoming);

            let mut buf = [0u8; 20];
            let n = socket.read(&mut buf[..]).await.expect("server read to succeed");
            assert_eq!(query, &buf[..n]);
            socket.write_all(&response[..]).await.expect("server write to succeed");

            // The client drops its socket after reading the response, so a
            // further exact read is expected to hit end-of-file.
            let err = socket.read_exact(&mut buf[..]).await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
        };

        let client = async move {
            let connector = TcpStream::connect(server_addr).expect("could not create client");
            let mut socket = connector.await.expect("client to connect to server");
            socket.write_all(&query[..]).await.expect("client write to succeed");

            let mut buf = [0u8; 20];
            let n = socket.read(&mut buf[..]).await.expect("client read to succeed");
            assert_eq!(response, &buf[..n]);
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }

    /// The server streams `WRITES` buffers of `BUF_SIZE` zero bytes and the
    /// client reads until it has received all `LENGTH` bytes.
    #[test]
    fn send_recv_large() {
        let mut exec = TestExecutor::new();

        const BUF_SIZE: usize = 10 * 1024;
        const WRITES: usize = 1024;
        const LENGTH: usize = WRITES * BUF_SIZE;

        let any_port = "127.0.0.1:0".parse().unwrap();
        let listener = TcpListener::bind(&any_port).expect("could not create listener");
        let server_addr = listener.local_addr().expect("query local_addr");
        let mut incoming = listener.accept_stream();

        let server = async move {
            let (mut socket, _clientaddr) =
                incoming.next().await.expect("stream to not be done").expect("client to connect");
            drop(incoming);

            let payload = [0u8; BUF_SIZE];
            for _ in 0usize..WRITES {
                socket.write_all(&payload[..]).await.expect("server write to succeed");
            }
        };

        let client = async move {
            let connector = TcpStream::connect(server_addr).expect("could not create client");
            let mut socket = connector.await.expect("client to connect to server");

            let zeroes = Box::new([0u8; BUF_SIZE]);
            let mut received = 0;
            while received < LENGTH {
                // Start from a buffer of ones so that bytes not actually
                // written by the read cannot pass as zero payload.
                let mut buf = Box::new([1u8; BUF_SIZE]);
                let n = socket.read(&mut buf[..]).await.expect("client read to succeed");
                assert_eq!(&buf[0..n], &zeroes[0..n]);
                received += n;
            }
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }
}