1use fdf_component::{driver_register, Driver, DriverContext, Node};
6use fidl::endpoints::create_endpoints;
7use fuchsia_async::scope::ScopeStream;
8use fuchsia_async::{Scope, Socket};
9use fuchsia_component::server::ServiceFs;
10use futures::channel::mpsc;
11use futures::future::{select, Either};
12use futures::io::{ReadHalf, WriteHalf};
13use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, StreamExt, TryStreamExt};
14use log::{debug, error, info, warn};
15use std::io::Error;
16use std::pin::pin;
17use std::sync::Arc;
18use usb_vsock::{
19 Connection, ConnectionRequest, Header, Packet, PacketType, UsbPacketBuilder,
20 VsockPacketIterator,
21};
22use zx::{SocketOpts, Status};
23use {fidl_fuchsia_hardware_overnet as overnet, fidl_fuchsia_hardware_vsock as vsock};
24
25mod vsock_service;
26
27use vsock_service::VsockService;
28
/// Size in bytes of the buffers used for each USB datagram read from or written to the USB
/// socket.
///
/// Declared as a `const` rather than a `static`: the value is only ever needed at compile time
/// (as an array length in `[0; MTU]` and as a const-generic argument such as
/// `usb_socket_writer::<MTU>`), and no fixed memory address is required.
const MTU: usize = 1024;
30
/// Driver that offers the `fuchsia.hardware.vsock` device service, carrying vsock traffic over
/// a USB link obtained from the `fuchsia.hardware.overnet` USB service in the driver's incoming
/// namespace.
struct UsbVsockServiceDriver {
    /// Scope in which `start` spawns the task serving vsock clients. It is stored (and otherwise
    /// unused) so that it lives as long as the driver does.
    _scope: Scope,
    /// The driver's node, taken from the context in `start` and held for the driver's lifetime.
    _node: Node,
}

driver_register!(UsbVsockServiceDriver);
40
/// State for one USB socket handed to this driver through the overnet `NewLink` callback.
///
/// A single USB socket may carry several successive links; each one is started by a `Sync`
/// handshake (see the `impl` below).
struct UsbConnection {
    /// Vsock service that is given each new [`Connection`] established over this socket.
    vsock_service: Arc<VsockService<Vec<u8>>>,
    /// Read half of the USB datagram socket.
    usb_socket_reader: ReadHalf<Socket>,
    /// Write half of the USB datagram socket.
    usb_socket_writer: WriteHalf<Socket>,
    /// Sender cloned into every [`Connection`] created for this socket.
    connection_tx: mpsc::Sender<ConnectionRequest>,
}
50
51impl UsbConnection {
52 fn new(
53 vsock_service: Arc<VsockService<Vec<u8>>>,
54 usb_socket: zx::Socket,
55 connection_tx: mpsc::Sender<ConnectionRequest>,
56 ) -> Self {
57 assert!(
58 usb_socket.info().unwrap().options.contains(SocketOpts::DATAGRAM),
59 "USB socket must be a datagram socket"
60 );
61 let (usb_socket_reader, usb_socket_writer) = Socket::from_socket(usb_socket).split();
62 Self { vsock_service, usb_socket_reader, usb_socket_writer, connection_tx }
63 }
64
65 async fn next_socket(&mut self, mut found_magic: Option<Vec<u8>>) -> Option<Socket> {
68 let mut data = [0; MTU];
69 while found_magic.is_none() {
70 let mut packets = match read_packet_stream(&mut self.usb_socket_reader, &mut data).await
71 {
72 Ok(None) => {
73 debug!("Usb socket closed");
74 return None;
75 }
76 Err(err) => {
77 error!("Unexpected error on usb socket: {err}");
78 return None;
79 }
80 Ok(Some(packets)) => packets,
81 };
82
83 while let Some(packet) = packets.next() {
84 match packet {
87 Ok(Packet {
88 header: Header { packet_type: PacketType::Sync, .. },
89 payload,
90 }) => {
91 found_magic = Some(payload.to_owned());
92 }
93 Ok(packet) => {
94 warn!("Got unexpected packet of type {:?} and length {} while waiting for sync packet. Ignoring.", packet.header.packet_type, packet.header.payload_len);
95 }
96 Err(err) => {
97 warn!("Got invalid vsock packet while waiting for sync packet: {err:?}");
98 }
99 }
100 }
101 }
102 let found_magic =
103 found_magic.expect("read loop should not terminate until sync packet is read");
104
105 debug!("Read sync packet, sending it back and setting up a new link");
106 let mut header = Header::new(PacketType::Sync);
107 header.payload_len = (found_magic.len() as u32).into();
108 let packet = Packet { header: &header, payload: &found_magic };
109 packet.write_to_unchecked(&mut data);
110 if let Err(err) = self.usb_socket_writer.write(&data[..packet.size()]).await {
111 error!("Error writing overnet magic string to the usb socket: {err:?}");
112 return None;
113 }
114 let (next_control_socket, other_end) = zx::Socket::create_stream();
115 Socket::from_socket(other_end).write_all(b"hello").await.ok();
118 return Some(Socket::from_socket(next_control_socket));
123 }
124
125 async fn run(mut self) {
126 let mut found_magic = None;
127 loop {
128 let Some(control_socket) = self.next_socket(found_magic).await else {
129 info!("USB socket closed or failed");
130 return;
131 };
132 found_magic = None;
134 let connection = Arc::new(Connection::new(control_socket, self.connection_tx.clone()));
135 self.vsock_service.set_connection(connection.clone()).await;
136 let usb_socket_writer =
137 usb_socket_writer::<MTU>(&connection, &mut self.usb_socket_writer);
138 let usb_socket_reader = usb_socket_reader::<MTU>(
139 &mut found_magic,
140 &mut self.usb_socket_reader,
141 &connection,
142 );
143 let client_socket_copy = pin!(usb_socket_writer);
144 let usb_socket_copy = pin!(usb_socket_reader);
145 let res = select(client_socket_copy, usb_socket_copy).await;
146 match res {
147 Either::Left((Err(err), _)) => {
148 warn!("Error on client to usb socket transfer: {err:?}");
149 }
150 Either::Left((Ok(_), _)) => {
151 debug!("client to usb socket closed normally");
152 }
153 Either::Right((Err(err), _)) => {
154 warn!("Error on usb to client socket transfer: {err:?}");
155 }
156 Either::Right((Ok(_), _)) => {
157 info!("usb to client socket closed normally");
158 }
159 }
160 }
161 }
162}
163
164async fn read_packet_stream<'a>(
165 reader: &mut (impl AsyncRead + Unpin),
166 mut buffer: &'a mut [u8],
167) -> Result<Option<VsockPacketIterator<'a>>, std::io::Error> {
168 let size = reader.read(&mut buffer).await?;
169 if size == 0 {
170 return Ok(None);
171 }
172 Ok(Some(VsockPacketIterator::new(&buffer[0..size])))
173}
174
175async fn usb_socket_writer<const MTU: usize>(
176 connection: &Connection<Vec<u8>>,
177 usb_writer: &mut (impl AsyncWrite + Unpin),
178) -> Result<(), Error> {
179 let mut builder = UsbPacketBuilder::new(vec![0; MTU]);
180 loop {
181 builder = connection.fill_usb_packet(builder).await;
182 let buf = builder.take_usb_packet().unwrap();
183 assert_eq!(
184 buf.len(),
185 usb_writer.write(buf).await?,
186 "datagram socket sent incomplete packet"
187 );
188 }
189}
190
191async fn usb_socket_reader<const MTU: usize>(
192 found_magic: &mut Option<Vec<u8>>,
193 usb_reader: &mut (impl AsyncRead + Unpin),
194 connection: &Connection<Vec<u8>>,
195) -> Result<(), Error> {
196 let mut data = [0; MTU];
197 loop {
198 let Some(mut packets) = read_packet_stream(usb_reader, &mut data).await? else {
199 break;
200 };
201 while let Some(packet) = packets.next() {
202 match packet {
203 Ok(Packet { header: Header { packet_type: PacketType::Sync, .. }, payload }) => {
204 debug!("Found sync packet, ending stream");
205 *found_magic = Some(payload.to_owned());
206 return Ok(());
207 }
208 Ok(packet) => connection.handle_vsock_packet(packet).await?,
209 Err(err) => {
210 error!("Failed to parse vsock packet, going back to waiting for sync packet: {err:?}");
211 break;
212 }
213 }
214 }
215 }
216 Ok(())
217}
218
/// Serves the overnet `Callback` protocol through which the USB device hands over new links
/// (USB sockets).
struct UsbCallbackHandler {
    /// Server end of the callback channel whose client end was registered with the USB device
    /// via `set_callback`.
    usb_callback_server: overnet::CallbackRequestStream,
    /// Sender cloned into every [`UsbConnection`] created for a received link.
    connection_tx: mpsc::Sender<ConnectionRequest>,
}
225
226impl UsbCallbackHandler {
227 async fn run(mut self, vsock_service: Arc<VsockService<Vec<u8>>>) -> Result<(), fidl::Error> {
228 use overnet::CallbackRequest::*;
229 while let Some(req) = self.usb_callback_server.try_next().await? {
230 let NewLink { socket, responder } = req;
231 responder.send()?;
232
233 debug!("Received new socket from usb driver");
234 UsbConnection::new(vsock_service.clone(), socket, self.connection_tx.clone())
235 .run()
236 .await;
237 }
238 Ok(())
239 }
240}
241
242impl Driver for UsbVsockServiceDriver {
243 const NAME: &str = "usb-vsock-service";
244
245 async fn start(mut context: DriverContext) -> Result<Self, Status> {
246 let node = context.take_node()?;
247 let scope = Scope::new_with_name(Self::NAME);
248 let mut outgoing = ServiceFs::new();
249
250 let usb_device = get_usb_device(&context)?;
251
252 info!("Offering a vsock service in the outgoing directory");
253 outgoing.dir("svc").add_fidl_service_instance("default", move |i| {
254 let vsock::ServiceRequest::Device(request_stream) = i;
255 request_stream
256 });
257
258 context.serve_outgoing(&mut outgoing)?;
259
260 scope.spawn(async move {
261 while let Some(request_stream) = outgoing.next().await {
262 let (usb_callback, usb_callback_server) = create_endpoints();
263 usb_device.set_callback(usb_callback).await.expect("usb device service went away");
264
265 run_connection(usb_callback_server.into_stream(), request_stream).await
266 }
267 });
268
269 Ok(Self { _scope: scope, _node: node })
270 }
271
272 async fn stop(&self) {}
273}
274
275async fn run_connection(
276 usb_callback_server: overnet::CallbackRequestStream,
277 mut request_stream: vsock::DeviceRequestStream,
278) {
279 debug!("Waiting for start message on vsock implementation service");
280 let (connection_tx, incoming_connections) = mpsc::channel(1);
281 let svc = match VsockService::wait_for_start(incoming_connections, &mut request_stream).await {
282 Ok(svc) => svc,
283 Err(err) => {
284 error!("Error while waiting for start message from vsock client: {err:?}");
285 return;
286 }
287 };
288 debug!(
289 "Received start message on vsock implementation service, waiting for usb socket handles"
290 );
291
292 let svc = Arc::new(svc);
293 let (mut scopes_stream, scopes) = ScopeStream::new_with_name("usb-vsock-connection".to_owned());
294
295 let usb_callback_handler =
296 UsbCallbackHandler { usb_callback_server, connection_tx: connection_tx.clone() };
297 let usb_svc = svc.clone();
298 scopes.push(async move {
299 if let Err(err) = usb_callback_handler.run(usb_svc).await {
300 error!("Error while waiting for usb device callbacks: {err:?}");
301 }
302 });
303 scopes.push(async move {
304 if let Err(err) = svc.run(request_stream).await {
305 error!("Error while servicing vsock client: {err:?}");
306 }
307 });
308 scopes_stream.next().await;
310}
311
312fn get_usb_device(context: &DriverContext) -> Result<overnet::UsbProxy, Status> {
313 let service_proxy = context.incoming.service_marker(overnet::UsbServiceMarker).connect()?;
314
315 service_proxy.connect_to_device().map_err(|err| {
316 error!("Error connecting to usb device proxy at driver startup: {err}");
317 Status::INTERNAL
318 })
319}
320
#[cfg(test)]
mod tests {
    use fidl::endpoints::create_endpoints;
    use fidl_fuchsia_vsock as vsock_api;
    use futures::channel::oneshot;
    use futures::future::join;
    use log::trace;

    use super::*;

    /// Runs `device_side` and `host_side` against each other over an in-memory USB link.
    ///
    /// The driver side is [`run_connection`], fed by a `vsock_service_lib::Vsock` instance
    /// (whose `Connector` client is handed to `device_side`) and by one end of a datagram
    /// socket pair delivered through the overnet `NewLink` callback. The test plays the host on
    /// the other end of that socket: it sends a `Sync` packet, then spawns tasks that pump
    /// packets between the socket and a host-side [`Connection`]. That connection is handed to
    /// `host_side` together with its incoming connection-request receiver. Both closures run
    /// concurrently once the driver has echoed the `Sync` packet back.
    async fn end_to_end_test(
        device_side: impl AsyncFn(vsock_api::ConnectorProxy),
        host_side: impl AsyncFn(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>),
    ) {
        let scope = Scope::new();
        let (vsock_impl_client, vsock_impl_server) = create_endpoints::<vsock::DeviceMarker>();
        let (usb_callback_client, usb_callback_server) =
            create_endpoints::<overnet::CallbackMarker>();
        scope.spawn(run_connection(
            usb_callback_server.into_stream(),
            vsock_impl_server.into_stream(),
        ));
        let usb_callback_client = usb_callback_client.into_proxy();

        let (vsock_api_service, vsock_api_future) =
            vsock_service_lib::Vsock::new(Some(vsock_impl_client.into_proxy()), None)
                .await
                .unwrap();
        scope.spawn_local(async move {
            vsock_api_future.await.unwrap();
        });

        let (vsock_api_client, vsock_api_server) = create_endpoints::<vsock_api::ConnectorMarker>();
        scope.spawn_local(vsock_api_service.run_client_connection(vsock_api_server.into_stream()));
        let vsock_api_client = vsock_api_client.into_proxy();

        // One end of this datagram pair is the driver's USB link; the test acts as the host on
        // the other end.
        let (usb_packet_socket, usb_packet_server) = zx::Socket::create_datagram();
        let (mut usb_packet_reader, mut usb_packet_writer) =
            Socket::from_socket(usb_packet_socket).split();
        usb_callback_client.new_link(usb_packet_server).await.unwrap();

        let (incoming_tx, incoming_rx) = mpsc::channel(1);
        let (_control_socket, other_end) = zx::Socket::create_stream();
        let host_connection =
            Arc::new(Connection::new(Socket::from_socket(other_end), incoming_tx));

        // Start the handshake: the reader task below waits for the driver to echo this back.
        let header = &mut Header::new(PacketType::Sync);
        let payload = b"hello!";
        header.payload_len.set(payload.len() as u32);
        let sync_packet = Packet { header, payload };
        let mut buf = [0; 1024];
        sync_packet.write_to_unchecked(&mut buf);
        assert_eq!(
            sync_packet.size(),
            usb_packet_writer.write(&buf[..sync_packet.size()]).await.unwrap()
        );

        // Host -> driver: drain packets from the host connection onto the USB socket.
        let writer_connection = host_connection.clone();
        scope.spawn(async move {
            let mut buf = UsbPacketBuilder::new(vec![0; 4096]);
            loop {
                buf = writer_connection.fill_usb_packet(buf).await;
                let buf = buf.take_usb_packet().unwrap();
                for packet in VsockPacketIterator::new(buf) {
                    let packet = packet.unwrap();
                    trace!("sending packet {packet:?}");
                }
                let _ = usb_packet_writer.write(buf).await.unwrap();
            }
        });

        // Driver -> host: feed packets read from the USB socket into the host connection, and
        // signal `synchronized` when the driver echoes the sync packet.
        let reader_connection = host_connection.clone();
        let (synchronized_tx, synchronized) = oneshot::channel();
        let mut synchronized_tx = Some(synchronized_tx);
        scope.spawn(async move {
            let mut buf = vec![0; 4096];
            // Stop on end-of-stream (`Ok(0)`) as well as on errors; otherwise a closed socket
            // would make this loop spin forever on empty reads.
            while let Ok(bytes @ 1..) = usb_packet_reader.read(&mut buf).await {
                for packet in VsockPacketIterator::new(&buf[..bytes]) {
                    let packet = packet.unwrap();
                    trace!("received packet {packet:?}");
                    if packet.header.packet_type == PacketType::Sync {
                        assert_eq!(packet.payload, b"hello!");
                        synchronized_tx.take().unwrap().send(()).unwrap();
                        continue;
                    }
                    reader_connection.handle_vsock_packet(packet).await.unwrap();
                }
            }
        });

        synchronized.await.unwrap();

        let device = device_side(vsock_api_client);
        let host = host_side(host_connection, incoming_rx);
        join(device, host).await;
    }

    /// The device connects out to CID 2, port 200. The host accepts the resulting request, the
    /// two sides exchange `boom`/`zoom`, and the device then expects end-of-stream.
    #[fuchsia::test(allow_stalls = false)]
    async fn test_device_to_host_connection() {
        end_to_end_test(
            async move |vsock_api_client| {
                let (socket, data) = zx::Socket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let (_con, con) = create_endpoints();
                vsock_api_client
                    .connect(2, 200, vsock_api::ConnectionTransport { data, con })
                    .await
                    .unwrap()
                    .unwrap();
                let mut buf = [0; 4];
                socket.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"boom");
                socket.write_all(b"zoom").await.unwrap();
                assert_eq!(0, socket.read(&mut buf).await.unwrap());
                trace!("vsock api fin");
            },
            async move |host_connection, mut incoming_rx| {
                let incoming = incoming_rx.next().await.unwrap();
                trace!("{incoming:?}");
                let (socket, other_end) = zx::Socket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let _state =
                    host_connection.accept(incoming, Socket::from_socket(other_end)).await.unwrap();
                socket.write_all(b"boom").await.unwrap();
                let mut buf = [0; 4];
                socket.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"zoom");
                trace!("host fin");
            },
        )
        .await;
    }

    /// The device listens on port 200, and the host connects from CID 2 port 9000 to CID 3
    /// port 200. The device checks the accepted address, the two sides exchange `boom`/`zoom`,
    /// and the device then expects end-of-stream.
    #[fuchsia::test(allow_stalls = false)]
    async fn test_host_to_device_connection() {
        end_to_end_test(
            async move |vsock_api_client| {
                let (other_end, acceptor) = create_endpoints::<vsock_api::AcceptorMarker>();
                let mut acceptor = acceptor.into_stream();
                vsock_api_client.listen(200, other_end).await.unwrap().unwrap();
                let vsock_api::AcceptorRequest::Accept { addr, responder } =
                    acceptor.next().await.unwrap().unwrap();
                assert_eq!(addr, vsock::Addr { local_port: 200, remote_cid: 2, remote_port: 9000 });

                let (socket, data) = zx::Socket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let (_con, con) = create_endpoints();
                responder.send(Some(vsock_api::ConnectionTransport { data, con })).unwrap();

                let mut buf = [0; 4];
                socket.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"boom");
                socket.write_all(b"zoom").await.unwrap();
                assert_eq!(0, socket.read(&mut buf).await.unwrap());
                trace!("vsock api fin");
            },
            async move |host_connection, _incoming_rx| {
                let (socket, other_end) = zx::Socket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let _state = host_connection
                    .connect(
                        usb_vsock::Address {
                            host_cid: 2,
                            host_port: 9000,
                            device_cid: 3,
                            device_port: 200,
                        },
                        Socket::from_socket(other_end),
                    )
                    .await
                    .unwrap();

                socket.write_all(b"boom").await.unwrap();
                let mut buf = [0; 4];
                socket.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"zoom");
                trace!("host fin");
            },
        )
        .await;
    }
}