usb_vsock/
connection.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use futures::channel::{mpsc, oneshot};
6use futures::lock::{Mutex, OwnedMutexGuard};
7use log::{debug, trace, warn};
8use std::collections::HashMap;
9use std::collections::hash_map::Entry;
10use std::future::Future;
11use std::io::{Error, ErrorKind};
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll, Waker, ready};
16
17use fuchsia_async::Scope;
18use futures::io::{ReadHalf, WriteHalf};
19use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, StreamExt};
20
21use crate::connection::overflow_writer::OverflowHandleFut;
22use crate::{
23    Address, Header, Packet, PacketType, ProtocolVersion, ShutdownError, UsbPacketBuilder,
24    UsbPacketFiller, WritePacketErrorExt,
25};
26
27mod overflow_writer;
28mod pause_state;
29
30use overflow_writer::OverflowWriter;
31use pause_state::PauseState;
32
/// Marker for buffer types a [`Connection`] can use to hold packet bytes.
///
/// Any type that mutably derefs to a byte slice, can be sent to another thread, is
/// [`Unpin`], and owns no borrowed data qualifies; the blanket impl below makes this automatic.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

impl<Buf> PacketBuffer for Buf where Buf: DerefMut<Target = [u8]> + Send + Unpin + 'static {}
36
/// Payload of a `Pause` packet, telling the other end whether to stop or resume reading
/// from its socket for the addressed connection.
#[derive(Copy, Clone, PartialEq, Eq)]
enum PausePacket {
    /// Ask the peer to stop reading (wire byte `1`).
    Pause,
    /// Ask the peer to resume reading (wire byte `0`).
    UnPause,
}

impl PausePacket {
    /// Returns the one-byte wire encoding: `[1]` for [`PausePacket::Pause`], `[0]` otherwise.
    fn bytes(&self) -> [u8; 1] {
        let is_pause = matches!(self, PausePacket::Pause);
        [u8::from(is_pause)]
    }
}
51
52/// A connection that has been established with the other end and now just needs
53/// a socket to start transmitting.
54pub struct ReadyConnect<B, S> {
55    connections: Arc<fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
56    packet_filler: Arc<UsbPacketFiller<B>>,
57    address: Address,
58}
59
60impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> ReadyConnect<B, S> {
61    /// Finish establishing the connection by providing a socket for data transfer.
62    pub async fn finish_connect(self, socket: S) {
63        let (read_socket, write_socket) = socket.split();
64        let writer = {
65            let conns = self.connections.lock();
66            let Some(conn) = conns.get(&self.address) else {
67                warn!("Connection state was missing after connection success!");
68                return;
69            };
70            let VsockConnectionState::Connected { writer, reader_scope, pause_state, .. } =
71                &conn.state
72            else {
73                warn!("Connection state was invalid after connection success!");
74                return;
75            };
76            reader_scope.spawn(Connection::<B, S>::run_socket(
77                read_socket,
78                self.address,
79                self.packet_filler,
80                Arc::clone(pause_state),
81            ));
82            Arc::clone(writer)
83        };
84        let mut writer = writer.lock().await;
85        let ConnectionStateWriter::NotYetAvailable(wakers) = std::mem::replace(
86            &mut *writer,
87            ConnectionStateWriter::Available(OverflowWriter::new(write_socket)),
88        ) else {
89            unreachable!("Connection completed multiple times!")
90        };
91
92        wakers.into_iter().for_each(Waker::wake);
93    }
94}
95
96/// Manages the state of a vsock-over-usb connection and the sockets over which data is being
97/// transmitted for them.
98///
99/// This implementation aims to be agnostic to both the underlying transport and the buffers used
100/// to read and write from it. The buffer type must conform to [`PacketBuffer`], which is essentially
101/// a type that holds a mutable slice of bytes and is [`Send`] and [`Unpin`]-able.
102///
103/// The client of this library will:
104/// - Use methods on this struct to initiate actions like connecting and accepting
105/// connections to the other end.
106/// - Provide buffers to be filled and sent to the other end with [`Connection::fill_usb_packet`].
107/// - Pump usb packets received into it using [`Connection::handle_vsock_packet`].
108pub struct Connection<B, S> {
109    control_socket_writer: Option<Mutex<WriteHalf<S>>>,
110    packet_filler: Arc<UsbPacketFiller<B>>,
111    protocol_version: ProtocolVersion,
112    connections: Arc<fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
113    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
114    task_scope: Scope,
115}
116
117impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> Connection<B, S> {
118    /// Creates a new connection with:
119    /// - a `control_socket`, over which data addressed to and from cid 0, port 0 (a control channel
120    /// between host and device) can be read and written from. If this is `None`
121    /// we will discard control data.
122    /// - An `incoming_requests_tx` that is the sender half of a request queue for incoming
123    /// connection requests from the other side.
124    pub fn new(
125        protocol_version: ProtocolVersion,
126        control_socket: Option<S>,
127        incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
128    ) -> Self {
129        let packet_filler = Arc::new(UsbPacketFiller::default());
130        let connections = Default::default();
131        let task_scope = Scope::new_with_name("vsock_usb");
132        let control_socket_writer = control_socket.map(|control_socket| {
133            let (control_socket_reader, control_socket_writer) = control_socket.split();
134            task_scope.spawn(Self::run_socket(
135                control_socket_reader,
136                Address::default(),
137                packet_filler.clone(),
138                PauseState::new(),
139            ));
140            Mutex::new(control_socket_writer)
141        });
142        Self {
143            control_socket_writer,
144            packet_filler,
145            connections,
146            incoming_requests_tx,
147            protocol_version,
148            task_scope,
149        }
150    }
151
152    async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
153        let header = &mut Header::new(PacketType::Finish);
154        header.set_address(address);
155        let _: Result<_, ShutdownError> = usb_packet_filler
156            .write_vsock_packet(&Packet { header, payload: &[] })
157            .await
158            .expect_right_size("Finish packet should never be too big");
159    }
160
161    async fn run_socket(
162        mut reader: ReadHalf<S>,
163        address: Address,
164        usb_packet_filler: Arc<UsbPacketFiller<B>>,
165        pause_state: Arc<PauseState>,
166    ) {
167        let mut buf = [0; 4096];
168        loop {
169            log::trace!("reading from control socket");
170            let read = match pause_state.while_unpaused(reader.read(&mut buf)).await {
171                Ok(0) => {
172                    if !address.is_zeros() {
173                        Self::send_close_packet(&address, &usb_packet_filler).await;
174                    }
175                    return;
176                }
177                Ok(read) => read,
178                Err(err) => {
179                    if address.is_zeros() {
180                        log::error!("Error reading usb socket: {err:?}");
181                    } else {
182                        Self::send_close_packet(&address, &usb_packet_filler).await;
183                    }
184                    return;
185                }
186            };
187            log::trace!("writing {read} bytes to vsock packet");
188            if usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await.is_err() {
189                log::trace!("transport shut down during read");
190                return;
191            }
192            log::trace!("wrote {read} bytes to vsock packet");
193        }
194    }
195
196    fn set_connection(
197        &self,
198        address: Address,
199        state: VsockConnectionState<S>,
200    ) -> Result<(), Error> {
201        let mut connections = self.connections.lock();
202        if !connections.contains_key(&address) {
203            connections.insert(address.clone(), VsockConnection { _address: address, state });
204            Ok(())
205        } else {
206            Err(Error::other(format!("connection on address {address:?} already set")))
207        }
208    }
209
210    /// Sends an echo packet to the remote end that you don't care about the reply, so it doesn't
211    /// have a distinct target address or payload.
212    pub async fn send_empty_echo(&self) {
213        debug!("Sending empty echo packet");
214        let header = &mut Header::new(PacketType::Echo);
215        let _: Result<_, ShutdownError> = self
216            .packet_filler
217            .write_vsock_packet(&Packet { header, payload: &[] })
218            .await
219            .expect_right_size(
220                "empty echo packet should never be too large to fit in a usb packet",
221            );
222    }
223
224    /// Starts a connection attempt to the other end of the USB connection, and provides a socket
225    /// to read and write from. The function will complete when the other end has accepted or
226    /// rejected the connection, and the returned [`ConnectionState`] handle can be used to wait
227    /// for the connection to be closed.
228    pub async fn connect(&self, addr: Address, socket: S) -> Result<ConnectionState, Error> {
229        let (ready, state) = self.connect_late(addr).await?;
230        ready.finish_connect(socket).await;
231        Ok(state)
232    }
233
234    /// Same as [`connect`] but doesn't require the socket to be passed. Instead
235    /// we return a [`ReadyConnect`] which can be given the socket later. This
236    /// shouldn't be deferred very long but it is useful if the socket is
237    /// starting out speaking a different protocol and needs to execute a
238    /// protocol switch, but needs to know the connection status before doing
239    /// that switch.
240    pub async fn connect_late(
241        &self,
242        addr: Address,
243    ) -> Result<(ReadyConnect<B, S>, ConnectionState), Error> {
244        let (connected_tx, connected_rx) = oneshot::channel();
245
246        self.set_connection(addr.clone(), VsockConnectionState::ConnectingOutgoing(connected_tx))?;
247
248        let header = &mut Header::new(PacketType::Connect);
249        header.set_address(&addr);
250        self.packet_filler
251            .write_vsock_packet(&Packet { header, payload: &[] })
252            .await
253            .assert_right_size()?;
254        let Ok(conn_state) = connected_rx.await else {
255            return Err(Error::other("Accept was never received for {addr:?}"));
256        };
257
258        Ok((
259            ReadyConnect {
260                connections: Arc::clone(&self.connections),
261                packet_filler: Arc::clone(&self.packet_filler),
262                address: addr,
263            },
264            conn_state,
265        ))
266    }
267
268    /// Sends a request for the other end to close the connection.
269    pub async fn close(&self, address: &Address) {
270        Self::send_close_packet(address, &self.packet_filler).await
271    }
272
273    /// Resets the named connection without going through a close request.
274    pub async fn reset(&self, address: &Address) -> Result<(), Error> {
275        reset(address, &self.connections, &self.packet_filler).await
276    }
277
278    /// Accepts a connection for which an outstanding connection request has been made, and
279    /// provides a socket to read and write data packets to and from. The returned [`ConnectionState`]
280    /// can be used to wait for the connection to be closed.
281    pub async fn accept(
282        &self,
283        request: ConnectionRequest,
284        socket: S,
285    ) -> Result<ConnectionState, Error> {
286        let (ready, state) = self.accept_late(request).await?;
287        ready.finish_connect(socket).await;
288        Ok(state)
289    }
290
291    /// Accepts a connection for which an outstanding connection request has been made, and
292    /// provides a socket to read and write data packets to and from. The returned [`ConnectionState`]
293    /// can be used to wait for the connection to be closed.
294    pub async fn accept_late(
295        &self,
296        request: ConnectionRequest,
297    ) -> Result<(ReadyConnect<B, S>, ConnectionState), Error> {
298        let address = request.address;
299        let notify_closed_rx;
300        if let Some(conn) = self.connections.lock().get_mut(&address) {
301            let VsockConnectionState::ConnectingIncoming = &conn.state else {
302                return Err(Error::other(format!(
303                    "Attempted to accept connection that was not waiting at {address:?}"
304                )));
305            };
306
307            let notify_closed = mpsc::channel(2);
308            notify_closed_rx = notify_closed.1;
309            let notify_closed = notify_closed.0;
310            let pause_state = PauseState::new();
311
312            let reader_scope = Scope::new_with_name("connection-reader");
313
314            conn.state = VsockConnectionState::Connected {
315                writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
316                reader_scope,
317                notify_closed,
318                pause_state,
319            };
320        } else {
321            return Err(Error::other(format!(
322                "Attempting to accept connection that did not exist at {address:?}"
323            )));
324        }
325        let header = &mut Header::new(PacketType::Accept);
326        header.set_address(&address);
327        self.packet_filler
328            .write_vsock_packet(&Packet { header, payload: &[] })
329            .await
330            .assert_right_size()?;
331        Ok((
332            ReadyConnect {
333                connections: Arc::clone(&self.connections),
334                packet_filler: Arc::clone(&self.packet_filler),
335                address,
336            },
337            ConnectionState(notify_closed_rx),
338        ))
339    }
340
341    /// Rejects a pending connection request from the other side.
342    pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
343        let address = request.address;
344        match self.connections.lock().entry(address.clone()) {
345            Entry::Occupied(entry) => {
346                let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
347                    return Err(Error::other(format!(
348                        "Attempted to reject connection that was not waiting at {address:?}"
349                    )));
350                };
351                entry.remove();
352            }
353            Entry::Vacant(_) => {
354                return Err(Error::other(format!(
355                    "Attempted to reject connection that was not waiting at {address:?}"
356                )));
357            }
358        }
359
360        let header = &mut Header::new(PacketType::Reset);
361        header.set_address(&address);
362        self.packet_filler
363            .write_vsock_packet(&Packet { header, payload: &[] })
364            .await
365            .expect_right_size("accept packet should never be too large for packet buffer")?;
366        Ok(())
367    }
368
369    async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
370        // all zero data packets go to the control channel
371        if address.is_zeros() {
372            if let Some(writer) = self.control_socket_writer.as_ref() {
373                writer.lock().await.write_all(payload).await?;
374            } else {
375                trace!("Discarding {} bytes of data sent to control socket", payload.len());
376            }
377            Ok(())
378        } else {
379            let payload_socket;
380            if let Some(conn) = self.connections.lock().get_mut(&address) {
381                let VsockConnectionState::Connected { writer, .. } = &conn.state else {
382                    warn!(
383                        "Received data packet for connection in unexpected state for {address:?}"
384                    );
385                    return Ok(());
386                };
387                payload_socket = writer.clone();
388            } else {
389                warn!("Received data packet for connection that didn't exist at {address:?}");
390                return Ok(());
391            }
392            let mut socket_guard =
393                ConnectionStateWriter::wait_available(Arc::clone(&payload_socket)).await;
394            let ConnectionStateWriter::Available(socket) = &mut *socket_guard else {
395                unreachable!("wait_available didn't wait until socket was available!");
396            };
397            match socket.write_all(payload) {
398                Err(err) => {
399                    debug!(
400                        "Write to socket address {address:?} failed, \
401                         resetting connection immediately: {err:?}"
402                    );
403                    self.reset(&address)
404                        .await
405                        .inspect_err(|err| {
406                            warn!(
407                                "Attempt to reset connection to {address:?} \
408                                   failed after write error: {err:?}"
409                            );
410                        })
411                        .ok();
412                }
413                Ok(status) => {
414                    if status.overflowed() {
415                        if self.protocol_version.has_pause_packets() {
416                            let header = &mut Header::new(PacketType::Pause);
417                            let payload = &PausePacket::Pause.bytes();
418                            header.set_address(&address);
419                            header.payload_len.set(payload.len() as u32);
420                            self.packet_filler
421                                .write_vsock_packet(&Packet { header, payload })
422                                .await
423                                .expect_right_size(
424                                    "pause packet should never be too large to fit in a usb packet",
425                                )?;
426                        }
427
428                        let weak_payload_socket = Arc::downgrade(&payload_socket);
429                        let connections = Arc::clone(&self.connections);
430                        let has_pause_packets = self.protocol_version.has_pause_packets();
431                        let packet_filler = Arc::clone(&self.packet_filler);
432                        self.task_scope.spawn(async move {
433                            let res = OverflowHandleFut::new(weak_payload_socket).await;
434
435                            if let Err(err) = res {
436                                debug!(
437                                    "Write to socket address {address:?} failed while \
438                                     processing backlog, resetting connection at next poll: {err:?}"
439                                );
440                                if let Err(err) = reset(&address, &connections, &packet_filler).await {
441                                    debug!("Error sending reset frame after overflow write failed: {err:?}");
442                                }
443                            } else if has_pause_packets {
444                                let header = &mut Header::new(PacketType::Pause);
445                                let payload = &PausePacket::UnPause.bytes();
446                                header.set_address(&address);
447                                header.payload_len.set(payload.len() as u32);
448                                let _: Result<_, ShutdownError> =
449                                packet_filler
450                                    .write_vsock_packet(&Packet { header, payload })
451                                    .await
452                                    .expect_right_size("pause packet should never be too large to fit in a usb packet");
453                            }
454                        });
455                    }
456                }
457            }
458            Ok(())
459        }
460    }
461
462    async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
463        debug!("received echo for {address:?} with payload {payload:?}");
464        let header = &mut Header::new(PacketType::EchoReply);
465        header.payload_len.set(payload.len() as u32);
466        header.set_address(&address);
467        self.packet_filler.write_vsock_packet(&Packet { header, payload }).await.map_err(
468            |e| match e {
469                crate::WritePacketError::PacketTooBig(_) => {
470                    Error::other("Echo packet was too large to be sent back")
471                }
472                crate::WritePacketError::Shutdown(shutdown_error) => shutdown_error.into(),
473            },
474        )
475    }
476
477    async fn handle_echo_reply_packet(
478        &self,
479        address: Address,
480        payload: &[u8],
481    ) -> Result<(), Error> {
482        // ignore but log replies
483        debug!("received echo reply for {address:?} with payload {payload:?}");
484        Ok(())
485    }
486
487    async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
488        if let Some(conn) = self.connections.lock().get_mut(&address) {
489            let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
490            let VsockConnectionState::ConnectingOutgoing(connected_tx) = state else {
491                warn!("Received accept packet for connection in unexpected state for {address:?}");
492                return Ok(());
493            };
494            let (notify_closed, notify_closed_rx) = mpsc::channel(2);
495            if connected_tx.send(ConnectionState(notify_closed_rx)).is_err() {
496                warn!(
497                    "Accept packet received for {address:?} but connect caller stopped waiting for it"
498                );
499            }
500            let pause_state = PauseState::new();
501
502            let reader_scope = Scope::new_with_name("connection-reader");
503            conn.state = VsockConnectionState::Connected {
504                writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
505                reader_scope,
506                notify_closed,
507                pause_state,
508            };
509        } else {
510            warn!("Got accept packet for connection that was not being made at {address:?}");
511            return Ok(());
512        }
513        Ok(())
514    }
515
516    async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
517        trace!("received connect packet for {address:?}");
518        match self.connections.lock().entry(address.clone()) {
519            Entry::Vacant(entry) => {
520                debug!("valid connect request for {address:?}");
521                entry.insert(VsockConnection {
522                    _address: address,
523                    state: VsockConnectionState::ConnectingIncoming,
524                });
525            }
526            Entry::Occupied(_) => {
527                warn!(
528                    "Received connect packet for already existing \
529                     connection for address {address:?}. Ignoring"
530                );
531                return Ok(());
532            }
533        }
534
535        trace!("sending incoming connection request to client for {address:?}");
536        let connection_request = ConnectionRequest { address };
537        self.incoming_requests_tx
538            .clone()
539            .send(connection_request)
540            .await
541            .inspect(|_| trace!("sent incoming request for {address:?}"))
542            .map_err(|_| Error::other("Failed to send connection request"))
543    }
544
545    async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
546        trace!("received finish packet for {address:?}");
547        let mut notify;
548        if let Some(conn) = self.connections.lock().remove(&address) {
549            let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
550                warn!(
551                    "Received finish (close) packet for {address:?} \
552                     which was not in a connected state. Ignoring and dropping connection state."
553                );
554                return Ok(());
555            };
556            notify = notify_closed;
557        } else {
558            warn!(
559                "Received finish (close) packet for connection that didn't exist \
560                 on address {address:?}. Ignoring"
561            );
562            return Ok(());
563        }
564
565        notify.send(Ok(())).await.ok();
566
567        let header = &mut Header::new(PacketType::Reset);
568        header.set_address(&address);
569        self.packet_filler
570            .write_vsock_packet(&Packet { header, payload: &[] })
571            .await
572            .expect_right_size("accept packet should never be too large for packet buffer")?;
573        Ok(())
574    }
575
576    async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
577        trace!("received reset packet for {address:?}");
578        let mut notify = None;
579        if let Some(conn) = self.connections.lock().remove(&address) {
580            if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
581                notify = Some(notify_closed);
582            } else {
583                debug!(
584                    "Received reset packet for connection that wasn't in a connecting or \
585                    disconnected state on address {address:?}."
586                );
587            }
588        } else {
589            warn!(
590                "Received reset packet for connection that didn't \
591                exist on address {address:?}. Ignoring"
592            );
593        }
594
595        if let Some(mut notify) = notify {
596            notify.send(Ok(())).await.ok();
597        }
598        Ok(())
599    }
600
601    async fn handle_pause_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
602        if !self.protocol_version.has_pause_packets() {
603            warn!(
604                "Got a pause packet while using protocol \
605                 version {} which does not support them. Ignoring",
606                self.protocol_version
607            );
608            return Ok(());
609        }
610
611        let pause = match payload {
612            [1] => true,
613            [0] => false,
614            other => {
615                warn!("Ignoring unexpected pause packet payload {other:?}");
616                return Ok(());
617            }
618        };
619
620        if let Some(conn) = self.connections.lock().get(&address) {
621            if let VsockConnectionState::Connected { pause_state, .. } = &conn.state {
622                pause_state.set_paused(pause);
623            } else {
624                warn!("Received pause packet for unestablished connection. Ignoring");
625            };
626        } else {
627            warn!(
628                "Received pause packet for connection that didn't exist on address {address:?}. Ignoring"
629            );
630        }
631
632        Ok(())
633    }
634
635    /// Dispatches the given vsock packet type and handles its effect on any outstanding connections
636    /// or the overall state of the connection.
637    pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
638        trace!("received vsock packet {header:?}", header = packet.header);
639        let payload_len = packet.header.payload_len.get() as usize;
640        let payload = &packet.payload[..payload_len];
641        let address = Address::from(packet.header);
642        match packet.header.packet_type {
643            PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
644            PacketType::Data => self.handle_data_packet(address, payload).await,
645            PacketType::Accept => self.handle_accept_packet(address).await,
646            PacketType::Connect => self.handle_connect_packet(address).await,
647            PacketType::Finish => self.handle_finish_packet(address).await,
648            PacketType::Reset => self.handle_reset_packet(address).await,
649            PacketType::Echo => self.handle_echo_packet(address, payload).await,
650            PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
651            PacketType::Pause => self.handle_pause_packet(address, payload).await,
652        }
653    }
654
655    /// Provides a packet builder for the state machine to write packets to. Returns a future that
656    /// will be fulfilled when there is data available to send on the packet.
657    ///
658    /// # Panics
659    ///
660    /// Panics if called while another [`Self::fill_usb_packet`] future is pending.
661    pub async fn fill_usb_packet(
662        &self,
663        builder: UsbPacketBuilder<B>,
664    ) -> Result<UsbPacketBuilder<B>, ShutdownError> {
665        self.packet_filler.fill_usb_packet(builder).await
666    }
667}
668
669impl<B: PacketBuffer, S> Connection<B, S> {
670    /// Inform this connection that whatever transport was providing and
671    /// receiving packets has hung up and no more data will be written or read.
672    pub fn shutdown(&self) {
673        self.packet_filler.shutdown();
674        self.connections.lock().clear();
675    }
676}
677
678async fn reset<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static>(
679    address: &Address,
680    connections: &fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>,
681    packet_filler: &UsbPacketFiller<B>,
682) -> Result<(), Error> {
683    let mut notify = None;
684    if let Some(conn) = connections.lock().remove(&address) {
685        if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
686            notify = Some(notify_closed);
687        }
688    } else {
689        return Err(Error::other(
690            "Client asked to reset connection {address:?} that did not exist",
691        ));
692    }
693
694    if let Some(mut notify) = notify {
695        notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
696    }
697
698    let header = &mut Header::new(PacketType::Reset);
699    header.set_address(address);
700    packet_filler
701        .write_vsock_packet(&Packet { header, payload: &[] })
702        .await
703        .expect_right_size("Reset packet should never be too big")?;
704    Ok(())
705}
706
707/// A writer inside of a [`ConnectionState`]. This is essentially an
708/// option-monad around an [`OverflowWriter`], but unlike
709/// [`std::option::Option`] the empty variant stores wakers that by convention
710/// will be woken when we replace it with the occupied variant.
711enum ConnectionStateWriter<S> {
712    NotYetAvailable(Vec<Waker>),
713    Available(OverflowWriter<S>),
714}
715
716impl<S> ConnectionStateWriter<S> {
717    /// Wait for the given `ConnectionStateWriter` to contain an actual writer.
718    fn wait_available(this: Arc<Mutex<ConnectionStateWriter<S>>>) -> ConnectionStateWriterFut<S> {
719        ConnectionStateWriterFut { writer: this, lock_fut: None }
720    }
721}
722
/// Future returned by [`ConnectionStateWriter::wait_available`].
///
/// Resolves to an owned guard on `writer` once the slot contains
/// [`ConnectionStateWriter::Available`].
struct ConnectionStateWriterFut<S> {
    /// The shared writer slot being waited on.
    writer: Arc<Mutex<ConnectionStateWriter<S>>>,
    /// In-flight lock acquisition on `writer`, kept across polls while the lock is contended.
    /// `None` when no acquisition is in progress.
    lock_fut: Option<futures::lock::OwnedMutexLockFuture<ConnectionStateWriter<S>>>,
}
728
729impl<S> Future for ConnectionStateWriterFut<S> {
730    type Output = OwnedMutexGuard<ConnectionStateWriter<S>>;
731
732    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
733        let writer = Arc::clone(&self.writer);
734        let lock_fut = self.lock_fut.get_or_insert_with(|| writer.lock_owned());
735        let mut lock = ready!(lock_fut.poll_unpin(cx));
736        self.lock_fut = None;
737        match &mut *lock {
738            ConnectionStateWriter::Available(_) => Poll::Ready(lock),
739            ConnectionStateWriter::NotYetAvailable(queue) => {
740                queue.push(cx.waker().clone());
741                Poll::Pending
742            }
743        }
744    }
745}
746
/// Lifecycle state of a single vsock connection entry.
enum VsockConnectionState<S> {
    /// An outgoing connection attempt that has not completed yet.
    /// NOTE(review): presumably the sender delivers the [`ConnectionState`] to whoever
    /// initiated the connect once it is accepted; confirm against the connect path.
    ConnectingOutgoing(oneshot::Sender<ConnectionState>),
    /// An incoming connection request, presumably awaiting [`Connection::accept`] or
    /// [`Connection::reject`].
    ConnectingIncoming,
    /// An established connection.
    Connected {
        /// Shared slot for this connection's writer; see [`ConnectionStateWriter`].
        writer: Arc<Mutex<ConnectionStateWriter<S>>>,
        /// Close notifier; `reset` sends `Err(ErrorKind::ConnectionReset)` on it.
        /// Presumably paired with the receiver inside this connection's [`ConnectionState`].
        notify_closed: mpsc::Sender<Result<(), Error>>,
        /// Shared pause/unpause state for this connection.
        pause_state: Arc<PauseState>,
        /// NOTE(review): by its name, the scope running this connection's reader task(s);
        /// confirm.
        reader_scope: Scope,
    },
    /// NOTE(review): meaning not visible in this chunk; confirm where this is produced.
    Invalid,
}
758
/// An entry in the connection table, which maps an [`Address`] to a `VsockConnection`.
struct VsockConnection<S> {
    /// The address this entry is stored under. The leading `_` marks it as intentionally unread.
    _address: Address,
    /// Where this connection is in its lifecycle.
    state: VsockConnectionState<S>,
}
763
764/// A handle for the state of a connection established with either [`Connection::connect`] or
765/// [`Connection::accept`]. Use this to get notified when the connection has been closed without
766/// needing to hold on to the Socket end.
767#[derive(Debug)]
768pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
769
770impl ConnectionState {
771    /// Wait for this connection to close. Returns Ok(()) if the connection was closed without error,
772    /// and an error if it closed because of an error.
773    pub async fn wait_for_close(mut self) -> Result<(), Error> {
774        self.0
775            .next()
776            .await
777            .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
778    }
779}
780
781/// An outstanding connection request that needs to be either [`Connection::accept`]ed or
782/// [`Connection::reject`]ed.
783#[derive(Debug)]
784pub struct ConnectionRequest {
785    address: Address,
786}
787
788impl ConnectionRequest {
789    /// Creates a new connection request for the given address.
790    pub fn new(address: Address) -> Self {
791        Self { address }
792    }
793
794    /// The address this connection request is being made for.
795    pub fn address(&self) -> &Address {
796        &self.address
797    }
798}
799
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use test_case::test_case;

    use crate::VsockPacketIterator;

    use super::*;

    // `SyncSocket` is the socket type used to create socket pairs: an emulated handle when not
    // targeting Fuchsia, and `zx::Socket` on Fuchsia.
    #[cfg(not(target_os = "fuchsia"))]
    use fuchsia_async::emulated_handle::Socket as SyncSocket;
    use fuchsia_async::{Socket, Task};
    use futures::StreamExt;
    #[cfg(target_os = "fuchsia")]
    use zx::Socket as SyncSocket;

    /// Acts as a loopback peer for `echo_connection`.
    ///
    /// Repeatedly drains the USB packets the connection produces and feeds every vsock packet
    /// back into the same connection. `Connect` packets are answered with an `Accept`, and
    /// incoming `Accept` packets are dropped. Never returns; callers abort the task.
    async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>, Socket>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
        loop {
            println!("waiting for usb packet");
            builder = echo_connection.fill_usb_packet(builder).await.unwrap();
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            println!("got usb packet, echoing it back to the other side");
            let mut packet_count = 0;
            for packet in packets {
                let packet = packet.unwrap();
                match packet.header.packet_type {
                    PacketType::Connect => {
                        // respond with an accept packet
                        let mut reply_header = packet.header.clone();
                        reply_header.packet_type = PacketType::Accept;
                        echo_connection
                            .handle_vsock_packet(Packet { header: &reply_header, payload: &[] })
                            .await
                            .unwrap();
                    }
                    PacketType::Accept => {
                        // just ignore it
                    }
                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
                }
                packet_count += 1;
            }
            println!("handled {packet_count} packets");
        }
    }

    // Bytes written to the peer of the connection's control socket should come back unchanged
    // through the loopback echo server.
    #[fuchsia::test]
    async fn data_over_control_socket() {
        let (socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let mut socket = Socket::from_socket(socket);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    // An outgoing `connect` is accepted by the echo server, after which data written to the
    // connected socket comes back unchanged.
    #[fuchsia::test]
    async fn data_over_normal_outgoing_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection
            .connect(
                Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                Socket::from_socket(other_socket),
            )
            .await
            .unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    // A `Connect` packet injected directly into the connection surfaces as an incoming
    // request. Once accepted, data round-trips through the echo server.
    #[fuchsia::test]
    async fn data_over_normal_incoming_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let header = &mut Header::new(PacketType::Connect);
        header.set_address(&Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 });
        connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();

        let request = incoming_requests.next().await.unwrap();
        assert_eq!(
            request.address,
            Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
        );

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection.accept(request, Socket::from_socket(other_socket)).await.unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// Forwards every vsock packet that `from` emits over USB into `to`. Never returns;
    /// callers abort the task running it.
    async fn copy_connection(from: &Connection<Vec<u8>, Socket>, to: &Connection<Vec<u8>, Socket>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
        loop {
            builder = from.fill_usb_packet(builder).await.unwrap();
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            for packet in packets {
                println!("forwarding vsock packet");
                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
            }
        }
    }

    /// Shorthand for the async closures passed to [`end_to_end_test`]. Each receives one side's
    /// connection and that side's stream of incoming connection requests.
    pub(crate) trait EndToEndTestFn<R>:
        AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }
    impl<T, R> EndToEndTestFn<R> for T where
        T: AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }

    /// Wires two `Connection`s together by forwarding USB packets in both directions. It then
    /// runs `left_side` and `right_side` concurrently against them and returns both results.
    pub(crate) async fn end_to_end_test<R1, R2>(
        left_side: impl EndToEndTestFn<R1>,
        right_side: impl EndToEndTestFn<R2>,
    ) -> (R1, R2) {
        type Connection = crate::Connection<Vec<u8>, Socket>;
        let (_control_socket1, other_socket1) = SyncSocket::create_stream();
        let (_control_socket2, other_socket2) = SyncSocket::create_stream();
        let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
        let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);

        let connection1 = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket1)),
            incoming_requests_tx1,
        ));
        let connection2 = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket2)),
            incoming_requests_tx2,
        ));

        let conn1 = connection1.clone();
        let conn2 = connection2.clone();
        let passthrough_task = Task::spawn(async move {
            futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1),);
            println!("passthrough task loop ended");
        });

        let res = futures::join!(
            left_side(connection1, incoming_requests1),
            right_side(connection2, incoming_requests2)
        );
        passthrough_task.abort().await;
        res
    }

    // The left side connects, writes data and drops its socket. The right side accepts, reads
    // everything, then observes EOF. Both sides see a clean close.
    #[fuchsia::test]
    async fn data_over_end_to_end() {
        end_to_end_test(
            async |conn, _incoming| {
                println!("sending request on connection 1");
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn
                    .connect(
                        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                        Socket::from_socket(other_socket),
                    )
                    .await
                    .unwrap();

                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    println!("round tripping packet of size {size}");
                    socket.write_all(&vec![size; size as usize]).await.unwrap();
                }
                drop(socket);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(
                    request.address,
                    Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
                );

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();

                println!("accepted request on connection 2");
                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    let mut buf = vec![0u8; size as usize];
                    socket.read_exact(&mut buf).await.unwrap();
                    assert_eq!(buf, vec![size; size as usize]);
                }
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    // An explicit `close` on one side yields EOF and a clean close on both sides.
    #[fuchsia::test]
    async fn normal_close_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.close(&addr).await;
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    // A `reset` on one side gives EOF on both sides. The resetting side's `wait_for_close`
    // reports an error, while the peer's reports a clean close.
    #[fuchsia::test]
    async fn reset_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.reset(&addr).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.expect_err("expected reset");
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    // A `connect_late` that is still pending (optionally alongside a pending
    // `fill_usb_packet`) must resolve with an error once `shutdown` is called. The futures are
    // polled by hand with a no-op waker so the exact poll sequence is deterministic.
    #[test_case(false; "in packet handling")]
    #[test_case(true; "in reply wait")]
    #[fuchsia::test]
    async fn conn_shutdown(fill_packets: bool) {
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);

        let connection = Arc::new(Connection::<Vec<u8>, fuchsia_async::Socket>::new(
            ProtocolVersion::LATEST,
            None,
            incoming_requests_tx,
        ));

        let mut filler = if fill_packets {
            Some(std::pin::pin!(connection.fill_usb_packet(UsbPacketBuilder::new(Vec::new()))))
        } else {
            None
        };

        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        let mut fut = std::pin::pin!(connection.connect_late(addr));

        // Nothing answers the connect, so both futures must stay pending before shutdown.
        for _ in 0..5 {
            assert!(fut.as_mut().poll(&mut Context::from_waker(Waker::noop())).is_pending());
            if let Some(filler) = filler.as_mut() {
                assert!(filler.as_mut().poll(&mut Context::from_waker(Waker::noop())).is_pending())
            }
        }

        connection.shutdown();
        let Poll::Ready(res) = fut.poll(&mut Context::from_waker(Waker::noop())) else { panic!() };
        assert!(res.is_err());
        if let Some(filler) = filler {
            let Poll::Ready(res) = filler.poll(&mut Context::from_waker(Waker::noop())) else {
                panic!()
            };
            assert!(res.is_err());
        }
    }
}