1use futures::channel::{mpsc, oneshot};
6use futures::lock::Mutex;
7use log::{debug, trace, warn};
8use std::collections::hash_map::Entry;
9use std::collections::HashMap;
10use std::io::{Error, ErrorKind};
11use std::ops::DerefMut;
12use std::sync::Arc;
13
14use fuchsia_async::{Scope, Socket};
15use futures::io::{ReadHalf, WriteHalf};
16use futures::{AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt};
17
18use crate::{Address, Header, Packet, PacketType, UsbPacketBuilder, UsbPacketFiller};
19
/// Marker trait for the byte buffer type `B` that [`Connection`] passes to
/// [`UsbPacketFiller`] and [`UsbPacketBuilder`].
///
/// It adds no methods of its own; it only bundles the bounds a buffer has to
/// satisfy: mutable access to a `[u8]`, `Send`, `Unpin` and `'static`.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

/// Blanket implementation: every type meeting the bounds (for example
/// `Vec<u8>`) is automatically a [`PacketBuffer`].
impl<T: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for T {}
23
/// Multiplexes a control socket and any number of vsock connections over a
/// single stream of USB packets.
///
/// `B` is the buffer type used by the [`UsbPacketFiller`] / [`UsbPacketBuilder`]
/// that outgoing packets are written through.
pub struct Connection<B> {
    /// Write half of the control socket. Data packets whose address is all
    /// zeros are written here (see `handle_data_packet`).
    control_socket_writer: Mutex<WriteHalf<Socket>>,
    /// Shared sink for every outgoing vsock packet; clones are handed to the
    /// socket reader tasks.
    packet_filler: Arc<UsbPacketFiller<B>>,
    /// State of each known connection, keyed by address. This is a
    /// synchronous (`std::sync`) mutex.
    connections: std::sync::Mutex<HashMap<Address, VsockConnection>>,
    /// Channel on which incoming connection requests are delivered to the
    /// owner of this `Connection` (see `handle_connect_packet`).
    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
    /// Scope that owns the control-socket reader task spawned in `new`.
    _task_scope: Scope,
}
43
44impl<B: PacketBuffer> Connection<B> {
45 pub fn new(
51 control_socket: Socket,
52 incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
53 ) -> Self {
54 let (control_socket_reader, control_socket_writer) = control_socket.split();
55 let control_socket_writer = Mutex::new(control_socket_writer);
56 let packet_filler = Arc::new(UsbPacketFiller::default());
57 let connections = Default::default();
58 let task_scope = Scope::new_with_name("vsock_usb");
59 task_scope.spawn(Self::run_socket(
60 control_socket_reader,
61 Address::default(),
62 packet_filler.clone(),
63 ));
64 Self {
65 control_socket_writer,
66 packet_filler,
67 connections,
68 incoming_requests_tx,
69 _task_scope: task_scope,
70 }
71 }
72
73 async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
74 let header = &mut Header::new(PacketType::Finish);
75 header.set_address(address);
76 usb_packet_filler
77 .write_vsock_packet(&Packet { header, payload: &[] })
78 .await
79 .expect("Finish packet should never be too big");
80 }
81
82 async fn run_socket(
83 mut reader: ReadHalf<Socket>,
84 address: Address,
85 usb_packet_filler: Arc<UsbPacketFiller<B>>,
86 ) {
87 let mut buf = [0; 4096];
88 loop {
89 log::trace!("reading from control socket");
90 let read = match reader.read(&mut buf).await {
91 Ok(0) => {
92 if !address.is_zeros() {
93 Self::send_close_packet(&address, &usb_packet_filler).await;
94 }
95 return;
96 }
97 Ok(read) => read,
98 Err(err) => {
99 if address.is_zeros() {
100 log::error!("Error reading usb socket: {err:?}");
101 } else {
102 Self::send_close_packet(&address, &usb_packet_filler).await;
103 }
104 return;
105 }
106 };
107 log::trace!("writing {read} bytes to vsock packet");
108 usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
109 log::trace!("wrote {read} bytes to vsock packet");
110 }
111 }
112
113 fn set_connection(&self, address: Address, state: VsockConnectionState) -> Result<(), Error> {
114 let mut connections = self.connections.lock().unwrap();
115 if !connections.contains_key(&address) {
116 connections.insert(address.clone(), VsockConnection { _address: address, state });
117 Ok(())
118 } else {
119 Err(Error::other(format!("connection on address {address:?} already set")))
120 }
121 }
122
123 pub async fn send_empty_echo(&self) {
126 debug!("Sending empty echo packet");
127 let header = &mut Header::new(PacketType::Echo);
128 self.packet_filler
129 .write_vsock_packet(&Packet { header, payload: &[] })
130 .await
131 .expect("empty echo packet should never be too large to fit in a usb packet");
132 }
133
134 pub async fn connect(&self, addr: Address, socket: Socket) -> Result<ConnectionState, Error> {
139 let (read_socket, write_socket) = socket.split();
140 let write_socket = Arc::new(Mutex::new(write_socket));
141 let (connected_tx, connected_rx) = oneshot::channel();
142
143 self.set_connection(
144 addr.clone(),
145 VsockConnectionState::ConnectingOutgoing(write_socket, read_socket, connected_tx),
146 )?;
147
148 let header = &mut Header::new(PacketType::Connect);
149 header.set_address(&addr);
150 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
151 connected_rx.await.map_err(|_| Error::other("Accept was never received for {addr:?}"))?
152 }
153
154 pub async fn close(&self, address: &Address) {
156 Self::send_close_packet(address, &self.packet_filler).await
157 }
158
159 pub async fn reset(&self, address: &Address) -> Result<(), Error> {
161 let mut notify = None;
162 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
163 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
164 notify = Some(notify_closed);
165 }
166 } else {
167 return Err(Error::other(
168 "Client asked to reset connection {address:?} that did not exist",
169 ));
170 }
171
172 if let Some(mut notify) = notify {
173 notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
174 }
175
176 let header = &mut Header::new(PacketType::Reset);
177 header.set_address(address);
178 self.packet_filler
179 .write_vsock_packet(&Packet { header, payload: &[] })
180 .await
181 .expect("Reset packet should never be too big");
182 Ok(())
183 }
184
185 pub async fn accept(
189 &self,
190 request: ConnectionRequest,
191 socket: Socket,
192 ) -> Result<ConnectionState, Error> {
193 let address = request.address;
194 let notify_closed_rx;
195 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
196 let VsockConnectionState::ConnectingIncoming = &conn.state else {
197 return Err(Error::other(format!(
198 "Attempted to accept connection that was not waiting at {address:?}"
199 )));
200 };
201
202 let (read_socket, write_socket) = socket.split();
203 let writer = Arc::new(Mutex::new(write_socket));
204 let notify_closed = mpsc::channel(2);
205 notify_closed_rx = notify_closed.1;
206 let notify_closed = notify_closed.0;
207
208 let reader_task = Scope::new_with_name("connection-reader");
209 reader_task.spawn(Self::run_socket(read_socket, address, self.packet_filler.clone()));
210
211 conn.state = VsockConnectionState::Connected {
212 writer,
213 _reader_scope: reader_task,
214 notify_closed,
215 };
216 } else {
217 return Err(Error::other(format!(
218 "Attempting to accept connection that did not exist at {address:?}"
219 )));
220 }
221 let header = &mut Header::new(PacketType::Accept);
222 header.set_address(&address);
223 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
224 Ok(ConnectionState(notify_closed_rx))
225 }
226
227 pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
229 let address = request.address;
230 match self.connections.lock().unwrap().entry(address.clone()) {
231 Entry::Occupied(entry) => {
232 let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
233 return Err(Error::other(format!(
234 "Attempted to reject connection that was not waiting at {address:?}"
235 )));
236 };
237 entry.remove();
238 }
239 Entry::Vacant(_) => {
240 return Err(Error::other(format!(
241 "Attempted to reject connection that was not waiting at {address:?}"
242 )));
243 }
244 }
245
246 let header = &mut Header::new(PacketType::Reset);
247 header.set_address(&address);
248 self.packet_filler
249 .write_vsock_packet(&Packet { header, payload: &[] })
250 .await
251 .expect("accept packet should never be too large for packet buffer");
252 Ok(())
253 }
254
255 async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
256 if address.is_zeros() {
258 let written = self.control_socket_writer.lock().await.write(payload).await?;
259 assert_eq!(written, payload.len());
260 Ok(())
261 } else {
262 let payload_socket;
263 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
264 let VsockConnectionState::Connected { writer, .. } = &conn.state else {
265 warn!(
266 "Received data packet for connection in unexpected state for {address:?}"
267 );
268 return Ok(());
269 };
270 payload_socket = writer.clone();
271 } else {
272 warn!("Received data packet for connection that didn't exist at {address:?}");
273 return Ok(());
274 }
275 if let Err(err) = payload_socket.lock().await.write_all(payload).await {
276 debug!("Write to socket address {address:?} failed, resetting connection immediately: {err:?}");
277 self.reset(&address).await.inspect_err(|err| warn!("Attempt to reset connection to {address:?} failed after write error: {err:?}")).ok();
278 }
279 Ok(())
280 }
281 }
282
283 async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
284 debug!("received echo for {address:?} with payload {payload:?}");
285 let header = &mut Header::new(PacketType::EchoReply);
286 header.payload_len.set(payload.len() as u32);
287 header.set_address(&address);
288 self.packet_filler
289 .write_vsock_packet(&Packet { header, payload })
290 .await
291 .map_err(|_| Error::other("Echo packet was too large to be sent back"))
292 }
293
294 async fn handle_echo_reply_packet(
295 &self,
296 address: Address,
297 payload: &[u8],
298 ) -> Result<(), Error> {
299 debug!("received echo reply for {address:?} with payload {payload:?}");
301 Ok(())
302 }
303
304 async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
305 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
306 let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
307 let VsockConnectionState::ConnectingOutgoing(writer, read_socket, connected_tx) = state
308 else {
309 warn!("Received accept packet for connection in unexpected state for {address:?}");
310 return Ok(());
311 };
312 let (notify_closed, notify_closed_rx) = mpsc::channel(2);
313 if connected_tx.send(Ok(ConnectionState(notify_closed_rx))).is_err() {
314 warn!("Accept packet received for {address:?} but connect caller stopped waiting for it");
315 }
316
317 let reader_task = Scope::new_with_name("connection-reader");
318 reader_task.spawn(Self::run_socket(read_socket, address, self.packet_filler.clone()));
319 conn.state = VsockConnectionState::Connected {
320 writer,
321 _reader_scope: reader_task,
322 notify_closed,
323 };
324 } else {
325 warn!("Got accept packet for connection that was not being made at {address:?}");
326 return Ok(());
327 }
328 Ok(())
329 }
330
331 async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
332 trace!("received connect packet for {address:?}");
333 match self.connections.lock().unwrap().entry(address.clone()) {
334 Entry::Vacant(entry) => {
335 debug!("valid connect request for {address:?}");
336 entry.insert(VsockConnection {
337 _address: address,
338 state: VsockConnectionState::ConnectingIncoming,
339 });
340 }
341 Entry::Occupied(_) => {
342 warn!("Received connect packet for already existing connection for address {address:?}. Ignoring");
343 return Ok(());
344 }
345 }
346
347 trace!("sending incoming connection request to client for {address:?}");
348 let connection_request = ConnectionRequest { address };
349 self.incoming_requests_tx
350 .clone()
351 .send(connection_request)
352 .await
353 .inspect(|_| trace!("sent incoming request for {address:?}"))
354 .map_err(|_| Error::other("Failed to send connection request"))
355 }
356
357 async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
358 trace!("received finish packet for {address:?}");
359 let mut notify;
360 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
361 let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
362 warn!("Received finish (close) packet for {address:?} which was not in a connected state. Ignoring and dropping connection state.");
363 return Ok(());
364 };
365 notify = notify_closed;
366 } else {
367 warn!("Received finish (close) packet for connection that didn't exist on address {address:?}. Ignoring");
368 return Ok(());
369 }
370
371 notify.send(Ok(())).await.ok();
372
373 let header = &mut Header::new(PacketType::Reset);
374 header.set_address(&address);
375 self.packet_filler
376 .write_vsock_packet(&Packet { header, payload: &[] })
377 .await
378 .expect("accept packet should never be too large for packet buffer");
379 Ok(())
380 }
381
382 async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
383 trace!("received reset packet for {address:?}");
384 let mut notify = None;
385 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
386 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
387 notify = Some(notify_closed);
388 } else {
389 debug!("Received reset packet for connection that wasn't in a connecting or disconnected state on address {address:?}.");
390 }
391 } else {
392 warn!("Received reset packet for connection that didn't exist on address {address:?}. Ignoring");
393 }
394
395 if let Some(mut notify) = notify {
396 notify.send(Ok(())).await.ok();
397 }
398 Ok(())
399 }
400
401 pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
404 trace!("received vsock packet {header:?}", header = packet.header);
405 let payload_len = packet.header.payload_len.get() as usize;
406 let payload = &packet.payload[..payload_len];
407 let address = Address::from(packet.header);
408 match packet.header.packet_type {
409 PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
410 PacketType::Data => self.handle_data_packet(address, payload).await,
411 PacketType::Accept => self.handle_accept_packet(address).await,
412 PacketType::Connect => self.handle_connect_packet(address).await,
413 PacketType::Finish => self.handle_finish_packet(address).await,
414 PacketType::Reset => self.handle_reset_packet(address).await,
415 PacketType::Echo => self.handle_echo_packet(address, payload).await,
416 PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
417 }
418 }
419
420 pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
427 self.packet_filler.fill_usb_packet(builder).await
428 }
429}
430
/// Lifecycle state of one entry in `Connection::connections`.
enum VsockConnectionState {
    /// `connect` has sent a connect packet and is waiting for an accept.
    /// Holds the local socket's write half, its read half, and the sender
    /// used to complete the pending `connect` call.
    ConnectingOutgoing(
        Arc<Mutex<WriteHalf<Socket>>>,
        ReadHalf<Socket>,
        oneshot::Sender<Result<ConnectionState, Error>>,
    ),
    /// A connect packet was received; waiting for `accept` or `reject`.
    ConnectingIncoming,
    /// The connection is established.
    Connected {
        /// Write half of the local socket; incoming data packets are written here.
        writer: Arc<Mutex<WriteHalf<Socket>>>,
        /// Reports the end of the connection to the `ConnectionState` holder.
        notify_closed: mpsc::Sender<Result<(), Error>>,
        /// Scope owning the task that forwards socket reads as vsock data.
        _reader_scope: Scope,
    },
    /// Placeholder swapped in by `handle_accept_packet` while the previous
    /// state is moved out of the entry.
    Invalid,
}
445
/// One tracked connection: the value type of `Connection::connections`.
struct VsockConnection {
    /// Address of the connection; also the key it is stored under.
    _address: Address,
    /// Current lifecycle state.
    state: VsockConnectionState,
}
450
/// Handle returned by `Connection::connect` and `Connection::accept`.
///
/// Wraps the receiving end of the channel on which the end of the connection
/// is reported; see `wait_for_close`.
#[derive(Debug)]
pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
456
457impl ConnectionState {
458 pub async fn wait_for_close(mut self) -> Result<(), Error> {
461 self.0
462 .next()
463 .await
464 .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
465 }
466}
467
/// An incoming connection request, delivered on the channel passed to
/// `Connection::new` and consumed by `Connection::accept` or
/// `Connection::reject`.
#[derive(Debug)]
pub struct ConnectionRequest {
    /// Address the other side asked to connect on.
    address: Address,
}
474
475impl ConnectionRequest {
476 pub fn new(address: Address) -> Self {
478 Self { address }
479 }
480
481 pub fn address(&self) -> &Address {
483 &self.address
484 }
485}
486
487#[cfg(test)]
488mod test {
489 use std::sync::Arc;
490
491 use crate::VsockPacketIterator;
492
493 use super::*;
494
495 #[cfg(not(target_os = "fuchsia"))]
496 use fuchsia_async::emulated_handle::Socket as SyncSocket;
497 use fuchsia_async::Task;
498 use futures::StreamExt;
499 #[cfg(target_os = "fuchsia")]
500 use zx::Socket as SyncSocket;
501
502 async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>>>) {
503 let mut builder = UsbPacketBuilder::new(vec![0; 128]);
504 loop {
505 println!("waiting for usb packet");
506 builder = echo_connection.fill_usb_packet(builder).await;
507 let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
508 println!("got usb packet, echoing it back to the other side");
509 let mut packet_count = 0;
510 for packet in packets {
511 let packet = packet.unwrap();
512 match packet.header.packet_type {
513 PacketType::Connect => {
514 let mut reply_header = packet.header.clone();
516 reply_header.packet_type = PacketType::Accept;
517 echo_connection
518 .handle_vsock_packet(Packet { header: &reply_header, payload: &[] })
519 .await
520 .unwrap();
521 }
522 PacketType::Accept => {
523 }
525 _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
526 }
527 packet_count += 1;
528 }
529 println!("handled {packet_count} packets");
530 }
531 }
532
533 #[fuchsia::test]
534 async fn data_over_control_socket() {
535 let (socket, other_socket) = SyncSocket::create_stream();
536 let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
537 let mut socket = Socket::from_socket(socket);
538 let connection =
539 Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));
540
541 let echo_task = Task::spawn(usb_echo_server(connection.clone()));
542
543 for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
544 println!("round tripping packet of size {size}");
545 socket.write_all(&vec![size; size as usize]).await.unwrap();
546 let mut buf = vec![0u8; size as usize];
547 socket.read_exact(&mut buf).await.unwrap();
548 assert_eq!(buf, vec![size; size as usize]);
549 }
550 echo_task.cancel().await;
551 }
552
553 #[fuchsia::test]
554 async fn data_over_normal_outgoing_socket() {
555 let (_control_socket, other_socket) = SyncSocket::create_stream();
556 let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
557 let connection =
558 Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));
559
560 let echo_task = Task::spawn(usb_echo_server(connection.clone()));
561
562 let (socket, other_socket) = SyncSocket::create_stream();
563 let mut socket = Socket::from_socket(socket);
564 connection
565 .connect(
566 Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
567 Socket::from_socket(other_socket),
568 )
569 .await
570 .unwrap();
571
572 for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
573 println!("round tripping packet of size {size}");
574 socket.write_all(&vec![size; size as usize]).await.unwrap();
575 let mut buf = vec![0u8; size as usize];
576 socket.read_exact(&mut buf).await.unwrap();
577 assert_eq!(buf, vec![size; size as usize]);
578 }
579 echo_task.cancel().await;
580 }
581
582 #[fuchsia::test]
583 async fn data_over_normal_incoming_socket() {
584 let (_control_socket, other_socket) = SyncSocket::create_stream();
585 let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
586 let connection =
587 Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));
588
589 let echo_task = Task::spawn(usb_echo_server(connection.clone()));
590
591 let header = &mut Header::new(PacketType::Connect);
592 header.set_address(&Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 });
593 connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();
594
595 let request = incoming_requests.next().await.unwrap();
596 assert_eq!(
597 request.address,
598 Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
599 );
600
601 let (socket, other_socket) = SyncSocket::create_stream();
602 let mut socket = Socket::from_socket(socket);
603 connection.accept(request, Socket::from_socket(other_socket)).await.unwrap();
604
605 for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
606 println!("round tripping packet of size {size}");
607 socket.write_all(&vec![size; size as usize]).await.unwrap();
608 let mut buf = vec![0u8; size as usize];
609 socket.read_exact(&mut buf).await.unwrap();
610 assert_eq!(buf, vec![size; size as usize]);
611 }
612 echo_task.cancel().await;
613 }
614
615 async fn copy_connection(from: &Connection<Vec<u8>>, to: &Connection<Vec<u8>>) {
616 let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
617 loop {
618 builder = from.fill_usb_packet(builder).await;
619 let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
620 for packet in packets {
621 println!("forwarding vsock packet");
622 to.handle_vsock_packet(packet.unwrap()).await.unwrap();
623 }
624 }
625 }
626
627 pub(crate) trait EndToEndTestFn<R>:
628 AsyncFnOnce(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>) -> R
629 {
630 }
631 impl<T, R> EndToEndTestFn<R> for T where
632 T: AsyncFnOnce(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>) -> R
633 {
634 }
635
636 pub(crate) async fn end_to_end_test<R1, R2>(
637 left_side: impl EndToEndTestFn<R1>,
638 right_side: impl EndToEndTestFn<R2>,
639 ) -> (R1, R2) {
640 type Connection = crate::Connection<Vec<u8>>;
641 let (_control_socket1, other_socket1) = SyncSocket::create_stream();
642 let (_control_socket2, other_socket2) = SyncSocket::create_stream();
643 let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
644 let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);
645
646 let connection1 =
647 Arc::new(Connection::new(Socket::from_socket(other_socket1), incoming_requests_tx1));
648 let connection2 =
649 Arc::new(Connection::new(Socket::from_socket(other_socket2), incoming_requests_tx2));
650
651 let conn1 = connection1.clone();
652 let conn2 = connection2.clone();
653 let passthrough_task = Task::spawn(async move {
654 futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1),);
655 println!("passthrough task loop ended");
656 });
657
658 let res = futures::join!(
659 left_side(connection1, incoming_requests1),
660 right_side(connection2, incoming_requests2)
661 );
662 passthrough_task.cancel().await;
663 res
664 }
665
666 #[fuchsia::test]
667 async fn data_over_end_to_end() {
668 end_to_end_test(
669 async |conn, _incoming| {
670 println!("sending request on connection 1");
671 let (socket, other_socket) = SyncSocket::create_stream();
672 let mut socket = Socket::from_socket(socket);
673 let state = conn
674 .connect(
675 Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
676 Socket::from_socket(other_socket),
677 )
678 .await
679 .unwrap();
680
681 for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
682 println!("round tripping packet of size {size}");
683 socket.write_all(&vec![size; size as usize]).await.unwrap();
684 }
685 drop(socket);
686 state.wait_for_close().await.unwrap();
687 },
688 async |conn, mut incoming| {
689 println!("accepting request on connection 2");
690 let request = incoming.next().await.unwrap();
691 assert_eq!(
692 request.address,
693 Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
694 );
695
696 let (socket, other_socket) = SyncSocket::create_stream();
697 let mut socket = Socket::from_socket(socket);
698 let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
699
700 println!("accepted request on connection 2");
701 for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
702 let mut buf = vec![0u8; size as usize];
703 socket.read_exact(&mut buf).await.unwrap();
704 assert_eq!(buf, vec![size; size as usize]);
705 }
706 assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
707 state.wait_for_close().await.unwrap();
708 },
709 )
710 .await;
711 }
712
713 #[fuchsia::test]
714 async fn normal_close_end_to_end() {
715 let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
716 end_to_end_test(
717 async |conn, _incoming| {
718 let (socket, other_socket) = SyncSocket::create_stream();
719 let mut socket = Socket::from_socket(socket);
720 let state =
721 conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
722 conn.close(&addr).await;
723 assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
724 state.wait_for_close().await.unwrap();
725 },
726 async |conn, mut incoming| {
727 println!("accepting request on connection 2");
728 let request = incoming.next().await.unwrap();
729 assert_eq!(request.address, addr.clone(),);
730
731 let (socket, other_socket) = SyncSocket::create_stream();
732 let mut socket = Socket::from_socket(socket);
733 let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
734 assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
735 state.wait_for_close().await.unwrap();
736 },
737 )
738 .await;
739 }
740
741 #[fuchsia::test]
742 async fn reset_end_to_end() {
743 let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
744 end_to_end_test(
745 async |conn, _incoming| {
746 let (socket, other_socket) = SyncSocket::create_stream();
747 let mut socket = Socket::from_socket(socket);
748 let state =
749 conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
750 conn.reset(&addr).await.unwrap();
751 assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
752 state.wait_for_close().await.expect_err("expected reset");
753 },
754 async |conn, mut incoming| {
755 println!("accepting request on connection 2");
756 let request = incoming.next().await.unwrap();
757 assert_eq!(request.address, addr.clone(),);
758
759 let (socket, other_socket) = SyncSocket::create_stream();
760 let mut socket = Socket::from_socket(socket);
761 let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
762 assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
763 state.wait_for_close().await.unwrap();
764 },
765 )
766 .await;
767 }
768}