_usb_vsock_service_driver_rustc/
vsock_service.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fidl_fuchsia_hardware_vsock::{self as vsock, Addr, CallbacksProxy};
6use fuchsia_async::{Scope, Socket};
7use futures::channel::mpsc;
8use futures::{StreamExt, TryStreamExt};
9use log::{debug, error, info};
10use std::io::Error;
11use std::sync::{self, Arc, Weak};
12use usb_vsock::{Address, Connection, PacketBuffer};
13use zx::Status;
14
15use crate::ConnectionRequest;
16
17/// Implements the fuchsia.hardware.vsock service against a [`Connection`].
18pub struct VsockService<B> {
19    connection: sync::Mutex<Option<Weak<Connection<B>>>>,
20    callback: CallbacksProxy,
21    scope: Scope,
22}
23
24impl<B: PacketBuffer> VsockService<B> {
25    /// Waits for the start message from the client and returns a constructed [`VsockService`]
26    pub async fn wait_for_start(
27        incoming_connections: mpsc::Receiver<ConnectionRequest>,
28        requests: &mut vsock::DeviceRequestStream,
29    ) -> Result<Self, Error> {
30        use vsock::DeviceRequest::*;
31
32        let scope = Scope::new_with_name("vsock-service");
33        let Some(req) = requests.try_next().await.map_err(Error::other)? else {
34            return Err(Error::other(
35                "vsock client connected and disconnected without sending start message",
36            ));
37        };
38
39        match req {
40            Start { cb, responder } => {
41                info!("Client callback set for vsock client");
42                let connection = Default::default();
43                let callback = cb.into_proxy();
44                scope.spawn(Self::run_incoming_loop(incoming_connections, callback.clone()));
45                responder.send(Ok(())).map_err(Error::other)?;
46                Ok(Self { connection, callback, scope })
47            }
48            other => {
49                Err(Error::other(format!("unexpected message before start message: {other:?}")))
50            }
51        }
52    }
53
54    /// Set the current connection to be used by the vsock service server.
55    ///
56    /// # Panics
57    ///
58    /// Panics if the current socket is already set.
59    pub async fn set_connection(&self, conn: Arc<Connection<B>>) {
60        self.callback.transport_reset(3).await.unwrap_or_else(log_callback_error);
61        let mut current = self.connection.lock().unwrap();
62        if current.as_ref().and_then(Weak::upgrade).is_some() {
63            panic!("Can only have one active connection set at a time");
64        }
65        current.replace(Arc::downgrade(&conn));
66    }
67
68    /// Gets the current connection if one is set.
69    fn get_connection(&self) -> Option<Arc<Connection<B>>> {
70        self.connection.lock().unwrap().as_ref().and_then(Weak::upgrade)
71    }
72
73    async fn send_request(&self, addr: Addr, data: zx::Socket) -> Result<(), Status> {
74        let cb = self.callback.clone();
75        let connection = self.get_connection();
76        self.scope.spawn(async move {
77            let Some(connection) = connection else {
78                // immediately reject a connection request if we don't have a usb connection to
79                // put it on
80                cb.rst(&addr).unwrap_or_else(log_callback_error);
81                return;
82            };
83            let status = match connection
84                .connect(from_fidl_addr(3, addr), Socket::from_socket(data))
85                .await
86            {
87                Ok(status) => status,
88                Err(err) => {
89                    // connection failed
90                    debug!("Connection request failed to connect with err {err:?}");
91                    cb.rst(&addr).unwrap_or_else(log_callback_error);
92                    return;
93                }
94            };
95            cb.response(&addr).unwrap_or_else(log_callback_error);
96            status.wait_for_close().await.ok();
97            cb.rst(&addr).unwrap_or_else(log_callback_error);
98        });
99        Ok(())
100    }
101
102    async fn send_shutdown(&self, addr: Addr) -> Result<(), Status> {
103        if let Some(connection) = self.get_connection() {
104            connection.close(&from_fidl_addr(3, addr)).await;
105        } else {
106            // this connection can't exist so just tell the caller that it was reset.
107            self.callback.rst(&addr).unwrap_or_else(log_callback_error);
108        }
109        Ok(())
110    }
111
112    async fn send_rst(&self, addr: Addr) -> Result<(), Status> {
113        if let Some(connection) = self.get_connection() {
114            connection.reset(&from_fidl_addr(3, addr)).await.ok();
115        }
116        Ok(())
117    }
118
119    async fn send_response(&self, addr: Addr, data: zx::Socket) -> Result<(), Status> {
120        // We cheat here and reconstitute the ConnectionRequest ourselves rather than try to thread
121        // it through the state machine. Since the main client of this particular api should be
122        // keeping track on its own, and we will ignore accepts of unknown addresses, this should be
123        // fine.
124        let address = from_fidl_addr(3, addr);
125        let request = ConnectionRequest::new(address.clone());
126        let Some(connection) = self.get_connection() else {
127            error!("Tried to accept connection for {address:?} on usb connection that is not open");
128            return Err(Status::BAD_STATE);
129        };
130        connection.accept(request, Socket::from_socket(data)).await.map_err(|err| {
131            error!("Failed to accept connection for {address:?}: {err:?}");
132            Err(Status::ADDRESS_UNREACHABLE)
133        })?;
134
135        Ok(())
136    }
137
138    async fn run_incoming_loop(
139        mut incoming_connections: mpsc::Receiver<ConnectionRequest>,
140        proxy: CallbacksProxy,
141    ) {
142        loop {
143            let Some(next) = incoming_connections.next().await else {
144                return;
145            };
146            if let Err(err) = proxy.request(&from_vsock_addr(*next.address())) {
147                error!("Error calling callback for incoming connection request: {err:?}");
148                return;
149            }
150        }
151    }
152
153    /// Runs the request loop for [`vsock::DeviceRequest`] against whatever the current [`Connection`]
154    /// is.
155    pub async fn run(&self, mut requests: vsock::DeviceRequestStream) -> Result<(), Error> {
156        use vsock::DeviceRequest::*;
157
158        while let Some(req) = requests.try_next().await.map_err(Error::other)? {
159            match req {
160                start @ Start { .. } => {
161                    return Err(Error::other(format!(
162                        "unexpected start message after one was already sent {start:?}"
163                    )))
164                }
165                SendRequest { addr, data, responder } => responder
166                    .send(self.send_request(addr, data).await.map_err(Status::into_raw))
167                    .map_err(Error::other)?,
168                SendShutdown { addr, responder } => responder
169                    .send(self.send_shutdown(addr).await.map_err(Status::into_raw))
170                    .map_err(Error::other)?,
171                SendRst { addr, responder } => responder
172                    .send(self.send_rst(addr).await.map_err(Status::into_raw))
173                    .map_err(Error::other)?,
174                SendResponse { addr, data, responder } => responder
175                    .send(self.send_response(addr, data).await.map_err(Status::into_raw))
176                    .map_err(Error::other)?,
177                GetCid { responder } => responder.send(3).map_err(Error::other)?,
178            }
179        }
180        Ok(())
181    }
182}
183
184fn log_callback_error<E: std::error::Error>(err: E) {
185    error!("Error sending callback to vsock client: {err:?}")
186}
187
188fn from_fidl_addr(device_cid: u32, value: Addr) -> Address {
189    Address {
190        device_cid,
191        host_cid: value.remote_cid,
192        device_port: value.local_port,
193        host_port: value.remote_port,
194    }
195}
196
197/// Leaves [`Address::device_cid`] blank, to be filled in by the caller
198fn from_vsock_addr(value: Address) -> Addr {
199    Addr { local_port: value.device_port, remote_cid: value.host_cid, remote_port: value.host_port }
200}