// fuchsia_async/net/fuchsia/tcp.rs
1#![deny(missing_docs)]
6
7use crate::net::EventedFd;
8use futures::future::Future;
9use futures::io::{AsyncRead, AsyncWrite};
10use futures::ready;
11use futures::stream::Stream;
12use futures::task::{Context, Poll};
13use std::io::{self, Write};
14use std::net::{self, Shutdown, SocketAddr};
15use std::ops::Deref;
16use std::os::unix::io::FromRawFd as _;
17use std::pin::Pin;
18
19#[derive(Debug)]
24pub struct TcpListener(EventedFd<net::TcpListener>);
25
26impl Unpin for TcpListener {}
27
28impl Deref for TcpListener {
29 type Target = EventedFd<net::TcpListener>;
30
31 fn deref(&self) -> &Self::Target {
32 &self.0
33 }
34}
35
36impl TcpListener {
37 pub fn bind(addr: &SocketAddr) -> io::Result<TcpListener> {
39 let domain = match *addr {
40 SocketAddr::V4(..) => socket2::Domain::IPV4,
41 SocketAddr::V6(..) => socket2::Domain::IPV6,
42 };
43 let socket =
44 socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
45 let () = socket.set_reuse_address(true)?;
50 let addr = (*addr).into();
51 let () = socket.bind(&addr)?;
52 let () = socket.listen(1024)?;
53 TcpListener::from_std(socket.into())
54 }
55
56 pub fn accept(self) -> Acceptor {
59 Acceptor(Some(self))
60 }
61
62 pub fn accept_stream(self) -> AcceptStream {
65 AcceptStream(self)
66 }
67
68 pub fn async_accept(
72 &mut self,
73 cx: &mut Context<'_>,
74 ) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
75 ready!(EventedFd::poll_readable(&self.0, cx))?;
76 match self.0.as_ref().accept() {
77 Err(e) => {
78 if e.kind() == io::ErrorKind::WouldBlock {
79 self.0.need_read(cx);
80 Poll::Pending
81 } else {
82 Poll::Ready(Err(e))
83 }
84 }
85 Ok((sock, addr)) => Poll::Ready(Ok((TcpStream::from_std(sock)?, addr))),
86 }
87 }
88
89 pub fn from_std(listener: net::TcpListener) -> io::Result<TcpListener> {
92 let listener: socket2::Socket = listener.into();
93 let () = listener.set_nonblocking(true)?;
94 let listener = listener.into();
95 let listener = unsafe { EventedFd::new(listener)? };
96 Ok(TcpListener(listener))
97 }
98
99 pub fn std(&self) -> &net::TcpListener {
101 self.as_ref()
102 }
103
104 pub fn local_addr(&self) -> io::Result<net::SocketAddr> {
106 self.std().local_addr()
107 }
108}
109
110#[derive(Debug)]
112pub struct Acceptor(Option<TcpListener>);
113
114impl Future for Acceptor {
115 type Output = io::Result<(TcpListener, TcpStream, SocketAddr)>;
116
117 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
118 let (stream, addr);
119 {
120 let listener = self.0.as_mut().expect("polled an Acceptor after completion");
121 let (s, a) = ready!(listener.async_accept(cx))?;
122 stream = s;
123 addr = a;
124 }
125 let listener = self.0.take().unwrap();
126 Poll::Ready(Ok((listener, stream, addr)))
127 }
128}
129
130#[derive(Debug)]
132pub struct AcceptStream(TcpListener);
133
134impl Stream for AcceptStream {
135 type Item = io::Result<(TcpStream, SocketAddr)>;
136
137 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138 let (stream, addr) = ready!(self.0.async_accept(cx)?);
139 Poll::Ready(Some(Ok((stream, addr))))
140 }
141}
142
143#[derive(Debug)]
149pub struct TcpStream {
150 stream: EventedFd<net::TcpStream>,
151}
152
153impl Deref for TcpStream {
154 type Target = EventedFd<net::TcpStream>;
155
156 fn deref(&self) -> &Self::Target {
157 &self.stream
158 }
159}
160
161impl TcpStream {
162 pub fn connect_from_raw(
166 socket: impl std::os::unix::io::IntoRawFd,
167 addr: SocketAddr,
168 ) -> io::Result<TcpConnector> {
169 let socket = unsafe { socket2::Socket::from_raw_fd(socket.into_raw_fd()) };
173 Self::from_socket2(socket, addr)
174 }
175
176 pub fn connect(addr: SocketAddr) -> io::Result<TcpConnector> {
180 let domain = match addr {
181 SocketAddr::V4(..) => socket2::Domain::IPV4,
182 SocketAddr::V6(..) => socket2::Domain::IPV6,
183 };
184 let socket =
185 socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
186 Self::from_socket2(socket, addr)
187 }
188
189 fn from_socket2(socket: socket2::Socket, addr: SocketAddr) -> io::Result<TcpConnector> {
191 let () = socket.set_nonblocking(true)?;
192 let addr = addr.into();
193 let () = match socket.connect(&addr) {
194 Err(e) if e.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
195 res => res,
196 }?;
197 let stream = socket.into();
198 let stream = unsafe { EventedFd::new(stream)? };
200 let stream = Some(TcpStream { stream });
201
202 Ok(TcpConnector { need_write: true, stream })
203 }
204
205 pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
207 self.std().shutdown(how)
208 }
209
210 fn flush(&mut self) -> io::Result<()> {
212 self.std_mut().flush()
213 }
214
215 fn from_std(stream: net::TcpStream) -> io::Result<TcpStream> {
217 let stream: socket2::Socket = stream.into();
218 let () = stream.set_nonblocking(true)?;
219 let stream = stream.into();
220 let stream = unsafe { EventedFd::new(stream)? };
222 Ok(TcpStream { stream })
223 }
224
225 pub fn std(&self) -> &net::TcpStream {
227 self.as_ref()
228 }
229
230 fn std_mut<'a>(&'a mut self) -> &'a mut net::TcpStream {
231 self.stream.as_mut()
232 }
233}
234
235impl AsyncRead for TcpStream {
236 fn poll_read(
237 mut self: Pin<&mut Self>,
238 cx: &mut Context<'_>,
239 buf: &mut [u8],
240 ) -> Poll<io::Result<usize>> {
241 Pin::new(&mut self.stream).poll_read(cx, buf)
242 }
243
244 }
246
247impl AsyncWrite for TcpStream {
248 fn poll_write(
249 mut self: Pin<&mut Self>,
250 cx: &mut Context<'_>,
251 buf: &[u8],
252 ) -> Poll<io::Result<usize>> {
253 Pin::new(&mut self.stream).poll_write(cx, buf)
254 }
255
256 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
257 match self.get_mut().flush() {
258 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
259 Err(e) => Poll::Ready(Err(e)),
260 Ok(()) => Poll::Ready(Ok(())),
261 }
262 }
263
264 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
265 Poll::Ready(self.as_ref().shutdown(Shutdown::Write))
266 }
267
268 }
270
271#[derive(Debug)]
273pub struct TcpConnector {
274 need_write: bool,
277 stream: Option<TcpStream>,
278}
279
280impl Future for TcpConnector {
281 type Output = io::Result<TcpStream>;
282
283 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
284 let this = &mut *self;
285 {
286 let stream = this.stream.as_mut().expect("polled a TcpConnector after completion");
287 if this.need_write {
288 this.need_write = false;
289 stream.need_write(cx);
290 return Poll::Pending;
291 }
292 let () = ready!(stream.poll_writable(cx)?);
293 let () = match stream.as_ref().take_error() {
294 Ok(None) => Ok(()),
295 Ok(Some(err)) | Err(err) => Err(err),
296 }?;
297 }
298 let stream = this.stream.take().unwrap();
299 Poll::Ready(Ok(stream))
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::{TcpListener, TcpStream};
    use crate::TestExecutor;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures::stream::StreamExt;
    use std::io::{Error, ErrorKind};
    use std::net::{self, Ipv4Addr, SocketAddr};

    /// IPv4 loopback with port 0, i.e. "let the system pick a port".
    fn loopback_any_port() -> SocketAddr {
        SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0)
    }

    #[test]
    fn choose_listen_port() {
        let _exec = TestExecutor::new();
        let requested = loopback_any_port();
        let listener = TcpListener::bind(&requested).expect("could not create listener");
        let bound = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(bound.ip(), requested.ip());
        assert_ne!(bound.port(), 0);
    }

    #[test]
    fn choose_listen_port_from_std() {
        let _exec = TestExecutor::new();
        let requested = loopback_any_port();
        let std_listener =
            net::TcpListener::bind(&requested).expect("could not create inner listener");
        let listener = TcpListener::from_std(std_listener).expect("could not create listener");
        let bound = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(bound.ip(), requested.ip());
        assert_ne!(bound.port(), 0);
    }

    #[test]
    fn connect_to_nonlistening_endpoint() {
        let mut exec = TestExecutor::new();

        // Bind (but never listen on) a socket: the port is reserved, yet
        // connection attempts to it are refused.
        let socket = socket2::Socket::new(
            socket2::Domain::IPV4,
            socket2::Type::STREAM,
            Some(socket2::Protocol::TCP),
        )
        .expect("could not create socket");
        let bind_addr: socket2::SockAddr = loopback_any_port().into();
        let () = socket.bind(&bind_addr).expect("could not bind");
        let target = socket
            .local_addr()
            .expect("local addr query to succeed")
            .as_socket()
            .expect("local addr to be ipv4 or ipv6");

        let connector = TcpStream::connect(target).expect("could not create client");
        let fut = async move {
            let outcome = connector.await;
            assert!(outcome.is_err());
            Ok::<(), Error>(())
        };

        exec.run_singlethreaded(fut).expect("failed to run tcp socket test");
    }

    #[test]
    fn send_recv() {
        let mut exec = TestExecutor::new();

        let listener =
            TcpListener::bind(&loopback_any_port()).expect("could not create listener");
        let server_addr = listener.local_addr().expect("local_addr query to succeed");
        let mut incoming = listener.accept_stream();

        let query = b"ping";
        let response = b"pong";

        let server = async move {
            let (mut conn, _peer) = incoming
                .next()
                .await
                .expect("stream to not be done")
                .expect("client to connect");
            drop(incoming);

            let mut buf = [0u8; 20];
            let n = conn.read(&mut buf[..]).await.expect("server read to succeed");
            assert_eq!(query, &buf[..n]);

            conn.write_all(&response[..]).await.expect("server write to succeed");

            // The client drops its end after reading the response, so the
            // server observes EOF here.
            let err = conn.read_exact(&mut buf[..]).await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
        };

        let client = async move {
            let mut conn = TcpStream::connect(server_addr)
                .expect("could not create client")
                .await
                .expect("client to connect to server");

            conn.write_all(&query[..]).await.expect("client write to succeed");

            let mut buf = [0u8; 20];
            let n = conn.read(&mut buf[..]).await.expect("client read to succeed");
            assert_eq!(response, &buf[..n]);
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }

    #[test]
    fn send_recv_large() {
        let mut exec = TestExecutor::new();

        const BUF_SIZE: usize = 10 * 1024;
        const WRITES: usize = 1024;
        const LENGTH: usize = WRITES * BUF_SIZE;

        let listener =
            TcpListener::bind(&loopback_any_port()).expect("could not create listener");
        let server_addr = listener.local_addr().expect("query local_addr");
        let mut incoming = listener.accept_stream();

        let server = async move {
            let (mut conn, _peer) = incoming
                .next()
                .await
                .expect("stream to not be done")
                .expect("client to connect");
            drop(incoming);

            let chunk = [0u8; BUF_SIZE];
            for _ in 0..WRITES {
                conn.write_all(&chunk[..]).await.expect("server write to succeed");
            }
        };

        let client = async move {
            let mut conn = TcpStream::connect(server_addr)
                .expect("could not create client")
                .await
                .expect("client to connect to server");

            let zeroes = Box::new([0u8; BUF_SIZE]);
            let mut total = 0;
            while total < LENGTH {
                // Fresh buffer of ones each time, so leftover zeroes from a
                // previous read cannot mask a short or bogus read.
                let mut buf = Box::new([1u8; BUF_SIZE]);
                let n = conn.read(&mut buf[..]).await.expect("client read to succeed");
                assert_eq!(&buf[0..n], &zeroes[0..n]);
                total += n;
            }
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }
}