usb_vsock/
connection.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use futures::channel::{mpsc, oneshot};
6use futures::lock::Mutex;
7use log::{debug, trace, warn};
8use std::collections::hash_map::Entry;
9use std::collections::HashMap;
10use std::io::{Error, ErrorKind};
11use std::ops::DerefMut;
12use std::sync::Arc;
13
14use fuchsia_async::{Scope, Socket};
15use futures::io::{ReadHalf, WriteHalf};
16use futures::{AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt};
17
18use crate::{Address, Header, Packet, PacketType, UsbPacketBuilder, UsbPacketFiller};
19
/// Marker trait for buffer types that a [`Connection`] can fill with outgoing USB packets.
///
/// Any type that dereferences mutably to a byte slice and is [`Send`], [`Unpin`] and `'static`
/// qualifies automatically through the blanket implementation that follows.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

impl<T: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for T {}
23
24/// Manages the state of a vsock-over-usb connection and the sockets over which data is being
25/// transmitted for them.
26///
27/// This implementation aims to be agnostic to both the underlying transport and the buffers used
28/// to read and write from it. The buffer type must conform to [`PacketBuffer`], which is essentially
29/// a type that holds a mutable slice of bytes and is [`Send`] and [`Unpin`]-able.
30///
31/// The client of this library will:
32/// - Use methods on this struct to initiate actions like connecting and accepting
33/// connections to the other end.
34/// - Provide buffers to be filled and sent to the other end with [`Connection::fill_usb_packet`].
35/// - Pump usb packets received into it using [`Connection::handle_vsock_packet`].
36pub struct Connection<B> {
37    control_socket_writer: Mutex<WriteHalf<Socket>>,
38    packet_filler: Arc<UsbPacketFiller<B>>,
39    connections: std::sync::Mutex<HashMap<Address, VsockConnection>>,
40    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
41    _task_scope: Scope,
42}
43
44impl<B: PacketBuffer> Connection<B> {
45    /// Creates a new connection with:
46    /// - a `control_socket`, over which data addressed to and from cid 0, port 0 (a control channel
47    /// between host and device) can be read and written from.
48    /// - An `incoming_requests_tx` that is the sender half of a request queue for incoming
49    /// connection requests from the other side.
50    pub fn new(
51        control_socket: Socket,
52        incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
53    ) -> Self {
54        let (control_socket_reader, control_socket_writer) = control_socket.split();
55        let control_socket_writer = Mutex::new(control_socket_writer);
56        let packet_filler = Arc::new(UsbPacketFiller::default());
57        let connections = Default::default();
58        let task_scope = Scope::new_with_name("vsock_usb");
59        task_scope.spawn(Self::run_socket(
60            control_socket_reader,
61            Address::default(),
62            packet_filler.clone(),
63        ));
64        Self {
65            control_socket_writer,
66            packet_filler,
67            connections,
68            incoming_requests_tx,
69            _task_scope: task_scope,
70        }
71    }
72
73    async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
74        let header = &mut Header::new(PacketType::Finish);
75        header.set_address(address);
76        usb_packet_filler
77            .write_vsock_packet(&Packet { header, payload: &[] })
78            .await
79            .expect("Finish packet should never be too big");
80    }
81
82    async fn run_socket(
83        mut reader: ReadHalf<Socket>,
84        address: Address,
85        usb_packet_filler: Arc<UsbPacketFiller<B>>,
86    ) {
87        let mut buf = [0; 4096];
88        loop {
89            log::trace!("reading from control socket");
90            let read = match reader.read(&mut buf).await {
91                Ok(0) => {
92                    if !address.is_zeros() {
93                        Self::send_close_packet(&address, &usb_packet_filler).await;
94                    }
95                    return;
96                }
97                Ok(read) => read,
98                Err(err) => {
99                    if address.is_zeros() {
100                        log::error!("Error reading usb socket: {err:?}");
101                    } else {
102                        Self::send_close_packet(&address, &usb_packet_filler).await;
103                    }
104                    return;
105                }
106            };
107            log::trace!("writing {read} bytes to vsock packet");
108            usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
109            log::trace!("wrote {read} bytes to vsock packet");
110        }
111    }
112
113    fn set_connection(&self, address: Address, state: VsockConnectionState) -> Result<(), Error> {
114        let mut connections = self.connections.lock().unwrap();
115        if !connections.contains_key(&address) {
116            connections.insert(address.clone(), VsockConnection { _address: address, state });
117            Ok(())
118        } else {
119            Err(Error::other(format!("connection on address {address:?} already set")))
120        }
121    }
122
123    /// Sends an echo packet to the remote end that you don't care about the reply, so it doesn't
124    /// have a distinct target address or payload.
125    pub async fn send_empty_echo(&self) {
126        debug!("Sending empty echo packet");
127        let header = &mut Header::new(PacketType::Echo);
128        self.packet_filler
129            .write_vsock_packet(&Packet { header, payload: &[] })
130            .await
131            .expect("empty echo packet should never be too large to fit in a usb packet");
132    }
133
134    /// Starts a connection attempt to the other end of the USB connection, and provides a socket
135    /// to read and write from. The function will complete when the other end has accepted or
136    /// rejected the connection, and the returned [`ConnectionState`] handle can be used to wait
137    /// for the connection to be closed.
138    pub async fn connect(&self, addr: Address, socket: Socket) -> Result<ConnectionState, Error> {
139        let (read_socket, write_socket) = socket.split();
140        let write_socket = Arc::new(Mutex::new(write_socket));
141        let (connected_tx, connected_rx) = oneshot::channel();
142
143        self.set_connection(
144            addr.clone(),
145            VsockConnectionState::ConnectingOutgoing(write_socket, read_socket, connected_tx),
146        )?;
147
148        let header = &mut Header::new(PacketType::Connect);
149        header.set_address(&addr);
150        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
151        connected_rx.await.map_err(|_| Error::other("Accept was never received for {addr:?}"))?
152    }
153
154    /// Sends a request for the other end to close the connection.
155    pub async fn close(&self, address: &Address) {
156        Self::send_close_packet(address, &self.packet_filler).await
157    }
158
159    /// Resets the named connection without going through a close request.
160    pub async fn reset(&self, address: &Address) -> Result<(), Error> {
161        let mut notify = None;
162        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
163            if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
164                notify = Some(notify_closed);
165            }
166        } else {
167            return Err(Error::other(
168                "Client asked to reset connection {address:?} that did not exist",
169            ));
170        }
171
172        if let Some(mut notify) = notify {
173            notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
174        }
175
176        let header = &mut Header::new(PacketType::Reset);
177        header.set_address(address);
178        self.packet_filler
179            .write_vsock_packet(&Packet { header, payload: &[] })
180            .await
181            .expect("Reset packet should never be too big");
182        Ok(())
183    }
184
185    /// Accepts a connection for which an outstanding connection request has been made, and
186    /// provides a socket to read and write data packets to and from. The returned [`ConnectionState`]
187    /// can be used to wait for the connection to be closed.
188    pub async fn accept(
189        &self,
190        request: ConnectionRequest,
191        socket: Socket,
192    ) -> Result<ConnectionState, Error> {
193        let address = request.address;
194        let notify_closed_rx;
195        if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
196            let VsockConnectionState::ConnectingIncoming = &conn.state else {
197                return Err(Error::other(format!(
198                    "Attempted to accept connection that was not waiting at {address:?}"
199                )));
200            };
201
202            let (read_socket, write_socket) = socket.split();
203            let writer = Arc::new(Mutex::new(write_socket));
204            let notify_closed = mpsc::channel(2);
205            notify_closed_rx = notify_closed.1;
206            let notify_closed = notify_closed.0;
207
208            let reader_task = Scope::new_with_name("connection-reader");
209            reader_task.spawn(Self::run_socket(read_socket, address, self.packet_filler.clone()));
210
211            conn.state = VsockConnectionState::Connected {
212                writer,
213                _reader_scope: reader_task,
214                notify_closed,
215            };
216        } else {
217            return Err(Error::other(format!(
218                "Attempting to accept connection that did not exist at {address:?}"
219            )));
220        }
221        let header = &mut Header::new(PacketType::Accept);
222        header.set_address(&address);
223        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
224        Ok(ConnectionState(notify_closed_rx))
225    }
226
227    /// Rejects a pending connection request from the other side.
228    pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
229        let address = request.address;
230        match self.connections.lock().unwrap().entry(address.clone()) {
231            Entry::Occupied(entry) => {
232                let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
233                    return Err(Error::other(format!(
234                        "Attempted to reject connection that was not waiting at {address:?}"
235                    )));
236                };
237                entry.remove();
238            }
239            Entry::Vacant(_) => {
240                return Err(Error::other(format!(
241                    "Attempted to reject connection that was not waiting at {address:?}"
242                )));
243            }
244        }
245
246        let header = &mut Header::new(PacketType::Reset);
247        header.set_address(&address);
248        self.packet_filler
249            .write_vsock_packet(&Packet { header, payload: &[] })
250            .await
251            .expect("accept packet should never be too large for packet buffer");
252        Ok(())
253    }
254
255    async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
256        // all zero data packets go to the control channel
257        if address.is_zeros() {
258            let written = self.control_socket_writer.lock().await.write(payload).await?;
259            assert_eq!(written, payload.len());
260            Ok(())
261        } else {
262            let payload_socket;
263            if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
264                let VsockConnectionState::Connected { writer, .. } = &conn.state else {
265                    warn!(
266                        "Received data packet for connection in unexpected state for {address:?}"
267                    );
268                    return Ok(());
269                };
270                payload_socket = writer.clone();
271            } else {
272                warn!("Received data packet for connection that didn't exist at {address:?}");
273                return Ok(());
274            }
275            if let Err(err) = payload_socket.lock().await.write_all(payload).await {
276                debug!("Write to socket address {address:?} failed, resetting connection immediately: {err:?}");
277                self.reset(&address).await.inspect_err(|err| warn!("Attempt to reset connection to {address:?} failed after write error: {err:?}")).ok();
278            }
279            Ok(())
280        }
281    }
282
283    async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
284        debug!("received echo for {address:?} with payload {payload:?}");
285        let header = &mut Header::new(PacketType::EchoReply);
286        header.payload_len.set(payload.len() as u32);
287        header.set_address(&address);
288        self.packet_filler
289            .write_vsock_packet(&Packet { header, payload })
290            .await
291            .map_err(|_| Error::other("Echo packet was too large to be sent back"))
292    }
293
294    async fn handle_echo_reply_packet(
295        &self,
296        address: Address,
297        payload: &[u8],
298    ) -> Result<(), Error> {
299        // ignore but log replies
300        debug!("received echo reply for {address:?} with payload {payload:?}");
301        Ok(())
302    }
303
304    async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
305        if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
306            let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
307            let VsockConnectionState::ConnectingOutgoing(writer, read_socket, connected_tx) = state
308            else {
309                warn!("Received accept packet for connection in unexpected state for {address:?}");
310                return Ok(());
311            };
312            let (notify_closed, notify_closed_rx) = mpsc::channel(2);
313            if connected_tx.send(Ok(ConnectionState(notify_closed_rx))).is_err() {
314                warn!("Accept packet received for {address:?} but connect caller stopped waiting for it");
315            }
316
317            let reader_task = Scope::new_with_name("connection-reader");
318            reader_task.spawn(Self::run_socket(read_socket, address, self.packet_filler.clone()));
319            conn.state = VsockConnectionState::Connected {
320                writer,
321                _reader_scope: reader_task,
322                notify_closed,
323            };
324        } else {
325            warn!("Got accept packet for connection that was not being made at {address:?}");
326            return Ok(());
327        }
328        Ok(())
329    }
330
331    async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
332        trace!("received connect packet for {address:?}");
333        match self.connections.lock().unwrap().entry(address.clone()) {
334            Entry::Vacant(entry) => {
335                debug!("valid connect request for {address:?}");
336                entry.insert(VsockConnection {
337                    _address: address,
338                    state: VsockConnectionState::ConnectingIncoming,
339                });
340            }
341            Entry::Occupied(_) => {
342                warn!("Received connect packet for already existing connection for address {address:?}. Ignoring");
343                return Ok(());
344            }
345        }
346
347        trace!("sending incoming connection request to client for {address:?}");
348        let connection_request = ConnectionRequest { address };
349        self.incoming_requests_tx
350            .clone()
351            .send(connection_request)
352            .await
353            .inspect(|_| trace!("sent incoming request for {address:?}"))
354            .map_err(|_| Error::other("Failed to send connection request"))
355    }
356
357    async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
358        trace!("received finish packet for {address:?}");
359        let mut notify;
360        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
361            let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
362                warn!("Received finish (close) packet for {address:?} which was not in a connected state. Ignoring and dropping connection state.");
363                return Ok(());
364            };
365            notify = notify_closed;
366        } else {
367            warn!("Received finish (close) packet for connection that didn't exist on address {address:?}. Ignoring");
368            return Ok(());
369        }
370
371        notify.send(Ok(())).await.ok();
372
373        let header = &mut Header::new(PacketType::Reset);
374        header.set_address(&address);
375        self.packet_filler
376            .write_vsock_packet(&Packet { header, payload: &[] })
377            .await
378            .expect("accept packet should never be too large for packet buffer");
379        Ok(())
380    }
381
382    async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
383        trace!("received reset packet for {address:?}");
384        let mut notify = None;
385        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
386            if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
387                notify = Some(notify_closed);
388            } else {
389                debug!("Received reset packet for connection that wasn't in a connecting or disconnected state on address {address:?}.");
390            }
391        } else {
392            warn!("Received reset packet for connection that didn't exist on address {address:?}. Ignoring");
393        }
394
395        if let Some(mut notify) = notify {
396            notify.send(Ok(())).await.ok();
397        }
398        Ok(())
399    }
400
401    /// Dispatches the given vsock packet type and handles its effect on any outstanding connections
402    /// or the overall state of the connection.
403    pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
404        trace!("received vsock packet {header:?}", header = packet.header);
405        let payload_len = packet.header.payload_len.get() as usize;
406        let payload = &packet.payload[..payload_len];
407        let address = Address::from(packet.header);
408        match packet.header.packet_type {
409            PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
410            PacketType::Data => self.handle_data_packet(address, payload).await,
411            PacketType::Accept => self.handle_accept_packet(address).await,
412            PacketType::Connect => self.handle_connect_packet(address).await,
413            PacketType::Finish => self.handle_finish_packet(address).await,
414            PacketType::Reset => self.handle_reset_packet(address).await,
415            PacketType::Echo => self.handle_echo_packet(address, payload).await,
416            PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
417        }
418    }
419
420    /// Provides a packet builder for the state machine to write packets to. Returns a future that
421    /// will be fulfilled when there is data available to send on the packet.
422    ///
423    /// # Panics
424    ///
425    /// Panics if called while another [`Self::fill_usb_packet`] future is pending.
426    pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
427        self.packet_filler.fill_usb_packet(builder).await
428    }
429}
430
431enum VsockConnectionState {
432    ConnectingOutgoing(
433        Arc<Mutex<WriteHalf<Socket>>>,
434        ReadHalf<Socket>,
435        oneshot::Sender<Result<ConnectionState, Error>>,
436    ),
437    ConnectingIncoming,
438    Connected {
439        writer: Arc<Mutex<WriteHalf<Socket>>>,
440        notify_closed: mpsc::Sender<Result<(), Error>>,
441        _reader_scope: Scope,
442    },
443    Invalid,
444}
445
446struct VsockConnection {
447    _address: Address,
448    state: VsockConnectionState,
449}
450
451/// A handle for the state of a connection established with either [`Connection::connect`] or
452/// [`Connection::accept`]. Use this to get notified when the connection has been closed without
453/// needing to hold on to the Socket end.
454#[derive(Debug)]
455pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
456
457impl ConnectionState {
458    /// Wait for this connection to close. Returns Ok(()) if the connection was closed without error,
459    /// and an error if it closed because of an error.
460    pub async fn wait_for_close(mut self) -> Result<(), Error> {
461        self.0
462            .next()
463            .await
464            .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
465    }
466}
467
468/// An outstanding connection request that needs to be either [`Connection::accept`]ed or
469/// [`Connection::reject`]ed.
470#[derive(Debug)]
471pub struct ConnectionRequest {
472    address: Address,
473}
474
475impl ConnectionRequest {
476    /// Creates a new connection request for the given address.
477    pub fn new(address: Address) -> Self {
478        Self { address }
479    }
480
481    /// The address this connection request is being made for.
482    pub fn address(&self) -> &Address {
483        &self.address
484    }
485}
486
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::VsockPacketIterator;

    use super::*;

    #[cfg(not(target_os = "fuchsia"))]
    use fuchsia_async::emulated_handle::Socket as SyncSocket;
    use fuchsia_async::Task;
    use futures::StreamExt;
    #[cfg(target_os = "fuchsia")]
    use zx::Socket as SyncSocket;

    /// Payload sizes used by the data tests; every byte of a payload equals the payload's size.
    const ROUND_TRIP_SIZES: [u8; 8] = [1, 2, 8, 16, 32, 64, 128, 255];

    /// Address used for every non-control connection in these tests.
    fn test_address() -> Address {
        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
    }

    /// Writes each payload from [`ROUND_TRIP_SIZES`] into `socket` and checks that identical
    /// bytes are read back.
    async fn assert_round_trips(socket: &mut Socket) {
        for size in ROUND_TRIP_SIZES {
            println!("round tripping packet of size {size}");
            let sent = vec![size; usize::from(size)];
            socket.write_all(&sent).await.unwrap();
            let mut received = vec![0u8; usize::from(size)];
            socket.read_exact(&mut received).await.unwrap();
            assert_eq!(received, sent);
        }
    }

    /// Acts as the far end of the USB link by feeding every packet `echo_connection` emits
    /// straight back into it. Connect requests are answered with an accept instead.
    async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
        loop {
            println!("waiting for usb packet");
            builder = echo_connection.fill_usb_packet(builder).await;
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            println!("got usb packet, echoing it back to the other side");
            let mut handled = 0;
            for packet in packets {
                let packet = packet.unwrap();
                match packet.header.packet_type {
                    // Answer connection requests as though the far end accepted them.
                    PacketType::Connect => {
                        let mut accept_header = packet.header.clone();
                        accept_header.packet_type = PacketType::Accept;
                        let accept = Packet { header: &accept_header, payload: &[] };
                        echo_connection.handle_vsock_packet(accept).await.unwrap();
                    }
                    // An accept produced by `Connection::accept` must not be looped back. It
                    // would be treated as an answer to a connection this side never started.
                    PacketType::Accept => {}
                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
                }
                handled += 1;
            }
            println!("handled {handled} packets");
        }
    }

    /// A [`Connection`] looped back on itself by [`usb_echo_server`], together with the handles
    /// that must stay alive for the duration of a test.
    struct EchoHarness {
        /// Local end of the connection's control socket.
        control_socket: SyncSocket,
        connection: Arc<Connection<Vec<u8>>>,
        incoming_requests: mpsc::Receiver<ConnectionRequest>,
        echo_task: Task<()>,
    }

    impl EchoHarness {
        fn new() -> Self {
            let (control_socket, connection_end) = SyncSocket::create_stream();
            let (incoming_requests_tx, incoming_requests) = mpsc::channel(5);
            let connection = Arc::new(Connection::new(
                Socket::from_socket(connection_end),
                incoming_requests_tx,
            ));
            let echo_task = Task::spawn(usb_echo_server(connection.clone()));
            Self { control_socket, connection, incoming_requests, echo_task }
        }
    }

    #[fuchsia::test]
    async fn data_over_control_socket() {
        let harness = EchoHarness::new();
        let mut control_socket = Socket::from_socket(harness.control_socket);
        assert_round_trips(&mut control_socket).await;
        harness.echo_task.cancel().await;
    }

    #[fuchsia::test]
    async fn data_over_normal_outgoing_socket() {
        let harness = EchoHarness::new();

        let (local, remote) = SyncSocket::create_stream();
        let mut local = Socket::from_socket(local);
        harness.connection.connect(test_address(), Socket::from_socket(remote)).await.unwrap();

        assert_round_trips(&mut local).await;
        harness.echo_task.cancel().await;
    }

    #[fuchsia::test]
    async fn data_over_normal_incoming_socket() {
        let mut harness = EchoHarness::new();

        // Inject a connection request as though it had arrived over USB.
        let header = &mut Header::new(PacketType::Connect);
        header.set_address(&test_address());
        harness.connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();

        let request = harness.incoming_requests.next().await.unwrap();
        assert_eq!(request.address, test_address());

        let (local, remote) = SyncSocket::create_stream();
        let mut local = Socket::from_socket(local);
        harness.connection.accept(request, Socket::from_socket(remote)).await.unwrap();

        assert_round_trips(&mut local).await;
        harness.echo_task.cancel().await;
    }

    /// Forever moves every vsock packet produced by `from` into `to`, standing in for the cable
    /// between two ends.
    async fn copy_connection(from: &Connection<Vec<u8>>, to: &Connection<Vec<u8>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
        loop {
            builder = from.fill_usb_packet(builder).await;
            for packet in VsockPacketIterator::new(builder.take_usb_packet().unwrap()) {
                println!("forwarding vsock packet");
                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
            }
        }
    }

    pub(crate) trait EndToEndTestFn<R>:
        AsyncFnOnce(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }
    impl<R, T: AsyncFnOnce(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>) -> R>
        EndToEndTestFn<R> for T
    {
    }

    /// Wires two [`Connection`]s to each other and runs `left_side` and `right_side` concurrently,
    /// one against each connection. Returns both results.
    pub(crate) async fn end_to_end_test<R1, R2>(
        left_side: impl EndToEndTestFn<R1>,
        right_side: impl EndToEndTestFn<R2>,
    ) -> (R1, R2) {
        type Connection = crate::Connection<Vec<u8>>;

        // Builds one end of the link. The raw control socket is handed back so the caller can
        // keep it open for the whole test.
        let make_side = || {
            let (control_socket, connection_end) = SyncSocket::create_stream();
            let (incoming_requests_tx, incoming_requests) = mpsc::channel(5);
            let connection = Arc::new(Connection::new(
                Socket::from_socket(connection_end),
                incoming_requests_tx,
            ));
            (control_socket, connection, incoming_requests)
        };
        let (_control_socket1, connection1, incoming_requests1) = make_side();
        let (_control_socket2, connection2, incoming_requests2) = make_side();

        let forward_left = connection1.clone();
        let forward_right = connection2.clone();
        let passthrough_task = Task::spawn(async move {
            futures::join!(
                copy_connection(&forward_left, &forward_right),
                copy_connection(&forward_right, &forward_left),
            );
            println!("passthrough task loop ended");
        });

        let results = futures::join!(
            left_side(connection1, incoming_requests1),
            right_side(connection2, incoming_requests2)
        );
        passthrough_task.cancel().await;
        results
    }

    #[fuchsia::test]
    async fn data_over_end_to_end() {
        end_to_end_test(
            async |conn, _incoming| {
                println!("sending request on connection 1");
                let (local, remote) = SyncSocket::create_stream();
                let mut local = Socket::from_socket(local);
                let state =
                    conn.connect(test_address(), Socket::from_socket(remote)).await.unwrap();

                for size in ROUND_TRIP_SIZES {
                    println!("round tripping packet of size {size}");
                    local.write_all(&vec![size; usize::from(size)]).await.unwrap();
                }
                // Closing our end must close the connection on both sides.
                drop(local);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, test_address());

                let (local, remote) = SyncSocket::create_stream();
                let mut local = Socket::from_socket(local);
                let state = conn.accept(request, Socket::from_socket(remote)).await.unwrap();

                println!("accepted request on connection 2");
                for size in ROUND_TRIP_SIZES {
                    let mut received = vec![0u8; usize::from(size)];
                    local.read_exact(&mut received).await.unwrap();
                    assert_eq!(received, vec![size; usize::from(size)]);
                }
                // The peer closed its socket, so ours must now report end-of-stream.
                assert_eq!(local.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    #[fuchsia::test]
    async fn normal_close_end_to_end() {
        let addr = test_address();
        end_to_end_test(
            async |conn, _incoming| {
                let (local, remote) = SyncSocket::create_stream();
                let mut local = Socket::from_socket(local);
                let state = conn.connect(addr.clone(), Socket::from_socket(remote)).await.unwrap();
                conn.close(&addr).await;
                assert_eq!(local.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr);

                let (local, remote) = SyncSocket::create_stream();
                let mut local = Socket::from_socket(local);
                let state = conn.accept(request, Socket::from_socket(remote)).await.unwrap();
                assert_eq!(local.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    #[fuchsia::test]
    async fn reset_end_to_end() {
        let addr = test_address();
        end_to_end_test(
            async |conn, _incoming| {
                let (local, remote) = SyncSocket::create_stream();
                let mut local = Socket::from_socket(local);
                let state = conn.connect(addr.clone(), Socket::from_socket(remote)).await.unwrap();
                conn.reset(&addr).await.unwrap();
                assert_eq!(local.read(&mut [0u8; 1]).await.unwrap(), 0);
                // A local reset is reported to our own state handle as an error.
                state.wait_for_close().await.expect_err("expected reset");
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr);

                let (local, remote) = SyncSocket::create_stream();
                let mut local = Socket::from_socket(local);
                let state = conn.accept(request, Socket::from_socket(remote)).await.unwrap();
                assert_eq!(local.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }
}