_usb_vsock_service_driver_rustc/
vsock_service.rs
1use fidl_fuchsia_hardware_vsock::{self as vsock, Addr, CallbacksProxy};
6use fuchsia_async::{Scope, Socket};
7use futures::channel::mpsc;
8use futures::{StreamExt, TryStreamExt};
9use log::{debug, error, info};
10use std::io::Error;
11use std::sync::{self, Arc, Weak};
12use usb_vsock::{Address, Connection, PacketBuffer};
13use zx::Status;
14
15use crate::ConnectionRequest;
16
17pub struct VsockService<B> {
19 connection: sync::Mutex<Option<Weak<Connection<B>>>>,
20 callback: CallbacksProxy,
21 scope: Scope,
22}
23
24impl<B: PacketBuffer> VsockService<B> {
25 pub async fn wait_for_start(
27 incoming_connections: mpsc::Receiver<ConnectionRequest>,
28 requests: &mut vsock::DeviceRequestStream,
29 ) -> Result<Self, Error> {
30 use vsock::DeviceRequest::*;
31
32 let scope = Scope::new_with_name("vsock-service");
33 let Some(req) = requests.try_next().await.map_err(Error::other)? else {
34 return Err(Error::other(
35 "vsock client connected and disconnected without sending start message",
36 ));
37 };
38
39 match req {
40 Start { cb, responder } => {
41 info!("Client callback set for vsock client");
42 let connection = Default::default();
43 let callback = cb.into_proxy();
44 scope.spawn(Self::run_incoming_loop(incoming_connections, callback.clone()));
45 responder.send(Ok(())).map_err(Error::other)?;
46 Ok(Self { connection, callback, scope })
47 }
48 other => {
49 Err(Error::other(format!("unexpected message before start message: {other:?}")))
50 }
51 }
52 }
53
54 pub async fn set_connection(&self, conn: Arc<Connection<B>>) {
60 self.callback.transport_reset(3).await.unwrap_or_else(log_callback_error);
61 let mut current = self.connection.lock().unwrap();
62 if current.as_ref().and_then(Weak::upgrade).is_some() {
63 panic!("Can only have one active connection set at a time");
64 }
65 current.replace(Arc::downgrade(&conn));
66 }
67
68 fn get_connection(&self) -> Option<Arc<Connection<B>>> {
70 self.connection.lock().unwrap().as_ref().and_then(Weak::upgrade)
71 }
72
73 async fn send_request(&self, addr: Addr, data: zx::Socket) -> Result<(), Status> {
74 let cb = self.callback.clone();
75 let connection = self.get_connection();
76 self.scope.spawn(async move {
77 let Some(connection) = connection else {
78 cb.rst(&addr).unwrap_or_else(log_callback_error);
81 return;
82 };
83 let status = match connection
84 .connect(from_fidl_addr(3, addr), Socket::from_socket(data))
85 .await
86 {
87 Ok(status) => status,
88 Err(err) => {
89 debug!("Connection request failed to connect with err {err:?}");
91 cb.rst(&addr).unwrap_or_else(log_callback_error);
92 return;
93 }
94 };
95 cb.response(&addr).unwrap_or_else(log_callback_error);
96 status.wait_for_close().await.ok();
97 cb.rst(&addr).unwrap_or_else(log_callback_error);
98 });
99 Ok(())
100 }
101
102 async fn send_shutdown(&self, addr: Addr) -> Result<(), Status> {
103 if let Some(connection) = self.get_connection() {
104 connection.close(&from_fidl_addr(3, addr)).await;
105 } else {
106 self.callback.rst(&addr).unwrap_or_else(log_callback_error);
108 }
109 Ok(())
110 }
111
112 async fn send_rst(&self, addr: Addr) -> Result<(), Status> {
113 if let Some(connection) = self.get_connection() {
114 connection.reset(&from_fidl_addr(3, addr)).await.ok();
115 }
116 Ok(())
117 }
118
119 async fn send_response(&self, addr: Addr, data: zx::Socket) -> Result<(), Status> {
120 let address = from_fidl_addr(3, addr);
125 let request = ConnectionRequest::new(address.clone());
126 let Some(connection) = self.get_connection() else {
127 error!("Tried to accept connection for {address:?} on usb connection that is not open");
128 return Err(Status::BAD_STATE);
129 };
130 connection.accept(request, Socket::from_socket(data)).await.map_err(|err| {
131 error!("Failed to accept connection for {address:?}: {err:?}");
132 Err(Status::ADDRESS_UNREACHABLE)
133 })?;
134
135 Ok(())
136 }
137
138 async fn run_incoming_loop(
139 mut incoming_connections: mpsc::Receiver<ConnectionRequest>,
140 proxy: CallbacksProxy,
141 ) {
142 loop {
143 let Some(next) = incoming_connections.next().await else {
144 return;
145 };
146 if let Err(err) = proxy.request(&from_vsock_addr(*next.address())) {
147 error!("Error calling callback for incoming connection request: {err:?}");
148 return;
149 }
150 }
151 }
152
153 pub async fn run(&self, mut requests: vsock::DeviceRequestStream) -> Result<(), Error> {
156 use vsock::DeviceRequest::*;
157
158 while let Some(req) = requests.try_next().await.map_err(Error::other)? {
159 match req {
160 start @ Start { .. } => {
161 return Err(Error::other(format!(
162 "unexpected start message after one was already sent {start:?}"
163 )))
164 }
165 SendRequest { addr, data, responder } => responder
166 .send(self.send_request(addr, data).await.map_err(Status::into_raw))
167 .map_err(Error::other)?,
168 SendShutdown { addr, responder } => responder
169 .send(self.send_shutdown(addr).await.map_err(Status::into_raw))
170 .map_err(Error::other)?,
171 SendRst { addr, responder } => responder
172 .send(self.send_rst(addr).await.map_err(Status::into_raw))
173 .map_err(Error::other)?,
174 SendResponse { addr, data, responder } => responder
175 .send(self.send_response(addr, data).await.map_err(Status::into_raw))
176 .map_err(Error::other)?,
177 GetCid { responder } => responder.send(3).map_err(Error::other)?,
178 }
179 }
180 Ok(())
181 }
182}
183
184fn log_callback_error<E: std::error::Error>(err: E) {
185 error!("Error sending callback to vsock client: {err:?}")
186}
187
188fn from_fidl_addr(device_cid: u32, value: Addr) -> Address {
189 Address {
190 device_cid,
191 host_cid: value.remote_cid,
192 device_port: value.local_port,
193 host_port: value.remote_port,
194 }
195}
196
197fn from_vsock_addr(value: Address) -> Addr {
199 Addr { local_port: value.device_port, remote_cid: value.host_cid, remote_port: value.host_port }
200}