fuchsia_async/net/fuchsia/
tcp.rs1#![deny(missing_docs)]
6
7use crate::net::EventedFd;
8use futures::future::Future;
9use futures::io::{AsyncRead, AsyncWrite};
10use futures::ready;
11use futures::stream::Stream;
12use futures::task::{Context, Poll};
13use std::io::{self, Write};
14use std::net::{self, Shutdown, SocketAddr};
15use std::ops::Deref;
16use std::os::fd::{AsRawFd, RawFd};
17use std::os::unix::io::FromRawFd as _;
18use std::pin::Pin;
19use zx_status_ext::StatusExt;
20
21#[derive(Debug)]
26pub struct TcpListener(EventedFd<net::TcpListener>);
27
28impl Unpin for TcpListener {}
29
30impl Deref for TcpListener {
31 type Target = EventedFd<net::TcpListener>;
32
33 fn deref(&self) -> &Self::Target {
34 &self.0
35 }
36}
37
38impl TcpListener {
39 pub fn bind(addr: &SocketAddr) -> io::Result<TcpListener> {
41 let domain = match *addr {
42 SocketAddr::V4(..) => socket2::Domain::IPV4,
43 SocketAddr::V6(..) => socket2::Domain::IPV6,
44 };
45 let socket =
46 socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
47 let () = socket.set_reuse_address(true)?;
52 let addr = (*addr).into();
53 let () = socket.bind(&addr)?;
54 let () = socket.listen(1024)?;
55 TcpListener::from_std(socket.into())
56 }
57
58 pub fn accept(self) -> Acceptor {
61 Acceptor(Some(self))
62 }
63
64 pub fn accept_stream(self) -> AcceptStream {
67 AcceptStream(self)
68 }
69
70 pub fn async_accept(
74 &mut self,
75 cx: &mut Context<'_>,
76 ) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
77 ready!(EventedFd::poll_readable(&self.0, cx)).map_err(|s| s.into_io_error())?;
78 match self.0.as_ref().accept() {
79 Err(e) => {
80 if e.kind() == io::ErrorKind::WouldBlock {
81 self.0.need_read(cx);
82 Poll::Pending
83 } else {
84 Poll::Ready(Err(e))
85 }
86 }
87 Ok((sock, addr)) => Poll::Ready(Ok((TcpStream::from_std(sock)?, addr))),
88 }
89 }
90
91 pub fn from_std(listener: net::TcpListener) -> io::Result<TcpListener> {
94 let listener: socket2::Socket = listener.into();
95 let () = listener.set_nonblocking(true)?;
96 let listener = listener.into();
97 let listener = unsafe { EventedFd::new(listener)? };
98 Ok(TcpListener(listener))
99 }
100
101 pub fn std(&self) -> &net::TcpListener {
103 self.as_ref()
104 }
105
106 pub fn local_addr(&self) -> io::Result<net::SocketAddr> {
108 self.std().local_addr()
109 }
110}
111
112#[derive(Debug)]
114pub struct Acceptor(Option<TcpListener>);
115
116impl Future for Acceptor {
117 type Output = io::Result<(TcpListener, TcpStream, SocketAddr)>;
118
119 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
120 let (stream, addr);
121 {
122 let listener = self.0.as_mut().expect("polled an Acceptor after completion");
123 let (s, a) = ready!(listener.async_accept(cx))?;
124 stream = s;
125 addr = a;
126 }
127 let listener = self.0.take().unwrap();
128 Poll::Ready(Ok((listener, stream, addr)))
129 }
130}
131
132#[derive(Debug)]
134pub struct AcceptStream(TcpListener);
135
136impl Stream for AcceptStream {
137 type Item = io::Result<(TcpStream, SocketAddr)>;
138
139 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140 let (stream, addr) = ready!(self.0.async_accept(cx)?);
141 Poll::Ready(Some(Ok((stream, addr))))
142 }
143}
144
145#[derive(Debug)]
151pub struct TcpStream {
152 stream: EventedFd<net::TcpStream>,
153}
154
155impl Deref for TcpStream {
156 type Target = EventedFd<net::TcpStream>;
157
158 fn deref(&self) -> &Self::Target {
159 &self.stream
160 }
161}
162
163impl TcpStream {
164 pub fn connect_from_raw(
168 socket: impl std::os::unix::io::IntoRawFd,
169 addr: SocketAddr,
170 ) -> io::Result<TcpConnector> {
171 let socket = unsafe { socket2::Socket::from_raw_fd(socket.into_raw_fd()) };
175 Self::from_socket2(socket, addr)
176 }
177
178 pub fn connect(addr: SocketAddr) -> io::Result<TcpConnector> {
182 let domain = match addr {
183 SocketAddr::V4(..) => socket2::Domain::IPV4,
184 SocketAddr::V6(..) => socket2::Domain::IPV6,
185 };
186 let socket =
187 socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
188 Self::from_socket2(socket, addr)
189 }
190
191 fn from_socket2(socket: socket2::Socket, addr: SocketAddr) -> io::Result<TcpConnector> {
193 let () = socket.set_nonblocking(true)?;
194 let addr = addr.into();
195 let () = match socket.connect(&addr) {
196 Err(e) if e.raw_os_error() == Some(libc::EINPROGRESS) => Ok(()),
197 res => res,
198 }?;
199 let stream = socket.into();
200 let stream = unsafe { EventedFd::new(stream)? };
202 let stream = Some(TcpStream { stream });
203
204 Ok(TcpConnector { need_write: true, stream })
205 }
206
207 pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
209 self.std().shutdown(how)
210 }
211
212 fn flush(&mut self) -> io::Result<()> {
214 self.std_mut().flush()
215 }
216
217 fn from_std(stream: net::TcpStream) -> io::Result<TcpStream> {
219 let stream: socket2::Socket = stream.into();
220 let () = stream.set_nonblocking(true)?;
221 let stream = stream.into();
222 let stream = unsafe { EventedFd::new(stream)? };
224 Ok(TcpStream { stream })
225 }
226
227 pub fn std(&self) -> &net::TcpStream {
229 self.as_ref()
230 }
231
232 fn std_mut(&mut self) -> &mut net::TcpStream {
233 self.stream.as_mut()
234 }
235}
236
237impl AsRawFd for TcpStream {
238 fn as_raw_fd(&self) -> RawFd {
239 self.stream.as_raw_fd()
240 }
241}
242
243impl AsyncRead for TcpStream {
244 fn poll_read(
245 mut self: Pin<&mut Self>,
246 cx: &mut Context<'_>,
247 buf: &mut [u8],
248 ) -> Poll<io::Result<usize>> {
249 Pin::new(&mut self.stream).poll_read(cx, buf)
250 }
251
252 }
254
255impl AsyncWrite for TcpStream {
256 fn poll_write(
257 mut self: Pin<&mut Self>,
258 cx: &mut Context<'_>,
259 buf: &[u8],
260 ) -> Poll<io::Result<usize>> {
261 Pin::new(&mut self.stream).poll_write(cx, buf)
262 }
263
264 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
265 match self.get_mut().flush() {
266 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
267 Err(e) => Poll::Ready(Err(e)),
268 Ok(()) => Poll::Ready(Ok(())),
269 }
270 }
271
272 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
273 Poll::Ready(self.as_ref().shutdown(Shutdown::Write))
274 }
275
276 }
278
279#[derive(Debug)]
281pub struct TcpConnector {
282 need_write: bool,
285 stream: Option<TcpStream>,
286}
287
288impl Future for TcpConnector {
289 type Output = io::Result<TcpStream>;
290
291 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
292 let this = &mut *self;
293 {
294 let stream = this.stream.as_mut().expect("polled a TcpConnector after completion");
295 if this.need_write {
296 this.need_write = false;
297 stream.need_write(cx);
298 return Poll::Pending;
299 }
300 let () = ready!(stream.poll_writable(cx)).map_err(|s| s.into_io_error())?;
301 let () = match stream.as_ref().take_error() {
302 Ok(None) => Ok(()),
303 Ok(Some(err)) | Err(err) => Err(err),
304 }?;
305 }
306 let stream = this.stream.take().unwrap();
307 Poll::Ready(Ok(stream))
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::{TcpListener, TcpStream};
    use crate::TestExecutor;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use futures::stream::StreamExt;
    use std::io::{Error, ErrorKind};
    use std::net::{self, Ipv4Addr, SocketAddr};

    #[test]
    fn choose_listen_port() {
        let _exec = TestExecutor::new();
        let addr_request = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr_request).expect("could not create listener");
        let actual_addr = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(actual_addr.ip(), addr_request.ip());
        assert_ne!(actual_addr.port(), 0);
    }

    #[test]
    fn choose_listen_port_from_std() {
        let _exec = TestExecutor::new();
        let addr_request = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let inner = net::TcpListener::bind(addr_request).expect("could not create inner listener");
        let listener = TcpListener::from_std(inner).expect("could not create listener");
        let actual_addr = listener.local_addr().expect("local_addr query to succeed");
        assert_eq!(actual_addr.ip(), addr_request.ip());
        assert_ne!(actual_addr.port(), 0);
    }

    #[test]
    fn connect_to_nonlistening_endpoint() {
        let mut exec = TestExecutor::new();

        // Bind a socket but never listen on it, so the port is reserved and a connect to it
        // is expected to fail.
        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0).into();
        let socket = socket2::Socket::new(
            socket2::Domain::IPV4,
            socket2::Type::STREAM,
            Some(socket2::Protocol::TCP),
        )
        .expect("could not create socket");
        let () = socket.bind(&addr).expect("could not bind");
        let addr = socket.local_addr().expect("local addr query to succeed");
        let addr = addr.as_socket().expect("local addr to be ipv4 or ipv6");

        let connector = TcpStream::connect(addr).expect("could not create client");
        let fut = async move {
            let res = connector.await;
            assert!(res.is_err());
            Ok::<(), Error>(())
        };

        exec.run_singlethreaded(fut).expect("failed to run tcp socket test");
    }

    #[test]
    fn send_recv() {
        let mut exec = TestExecutor::new();

        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr).expect("could not create listener");
        let addr = listener.local_addr().expect("local_addr query to succeed");
        let mut listener = listener.accept_stream();

        let query = b"ping";
        let response = b"pong";

        let server = async move {
            let (mut socket, _clientaddr) =
                listener.next().await.expect("stream to not be done").expect("client to connect");
            drop(listener);

            // TCP is a byte stream, so a single `read` may return only part of the query.
            // Read exactly as many bytes as were sent.
            let mut buf = [0u8; 20];
            socket.read_exact(&mut buf[..query.len()]).await.expect("server read to succeed");
            assert_eq!(query, &buf[..query.len()]);

            socket.write_all(&response[..]).await.expect("server write to succeed");

            // The client drops its socket after reading the response, so the server must
            // see EOF before `buf` is filled.
            let err = socket.read_exact(&mut buf[..]).await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
        };

        let client = async move {
            let connector = TcpStream::connect(addr).expect("could not create client");
            let mut socket = connector.await.expect("client to connect to server");

            socket.write_all(&query[..]).await.expect("client write to succeed");

            let mut buf = [0u8; 20];
            socket.read_exact(&mut buf[..response.len()]).await.expect("client read to succeed");
            assert_eq!(response, &buf[..response.len()]);
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }

    #[test]
    fn send_recv_large() {
        let mut exec = TestExecutor::new();
        let addr = "127.0.0.1:0".parse().unwrap();

        const BUF_SIZE: usize = 10 * 1024;
        const WRITES: usize = 1024;
        const LENGTH: usize = WRITES * BUF_SIZE;

        let listener = TcpListener::bind(&addr).expect("could not create listener");
        let addr = listener.local_addr().expect("query local_addr");
        let mut listener = listener.accept_stream();

        let server = async move {
            let (mut socket, _clientaddr) =
                listener.next().await.expect("stream to not be done").expect("client to connect");
            drop(listener);

            let buf = [0u8; BUF_SIZE];
            for _ in 0usize..WRITES {
                socket.write_all(&buf[..]).await.expect("server write to succeed");
            }
        };

        let client = async move {
            let connector = TcpStream::connect(addr).expect("could not create client");
            let mut socket = connector.await.expect("client to connect to server");

            let zeroes = Box::new([0u8; BUF_SIZE]);
            let mut read = 0;
            while read < LENGTH {
                let mut buf = Box::new([1u8; BUF_SIZE]);
                let n = socket.read(&mut buf[..]).await.expect("client read to succeed");
                assert_eq!(&buf[0..n], &zeroes[0..n]);
                read += n;
            }
        };

        exec.run_singlethreaded(futures::future::join(server, client));
    }
}