_usb_vsock_service_driver_rustc/
lib.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fdf_component::{driver_register, Driver, DriverContext, Node};
6use fidl::endpoints::create_endpoints;
7use fuchsia_async::scope::ScopeStream;
8use fuchsia_async::{Scope, Socket};
9use fuchsia_component::server::ServiceFs;
10use futures::channel::mpsc;
11use futures::future::{select, Either};
12use futures::io::{ReadHalf, WriteHalf};
13use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, StreamExt, TryStreamExt};
14use log::{debug, error, info, warn};
15use std::io::Error;
16use std::pin::pin;
17use std::sync::Arc;
18use usb_vsock::{
19    Connection, ConnectionRequest, Header, Packet, PacketType, UsbPacketBuilder,
20    VsockPacketIterator,
21};
22use zx::{SocketOpts, Status};
23use {fidl_fuchsia_hardware_overnet as overnet, fidl_fuchsia_hardware_vsock as vsock};
24
25mod vsock_service;
26
27use vsock_service::VsockService;
28
/// Size in bytes of the buffers used to read packets from, and build packets for, the USB
/// socket.
///
/// Declared as a `const` rather than a `static` because it is only ever used as a compile-time
/// value: as an array length and as a const-generic argument.
const MTU: usize = 1024;
30
31struct UsbVsockServiceDriver {
32    /// A scope for async tasks running under this driver
33    _scope: Scope,
34    /// The [`Node`] is our handle to the node we bound to. We need to keep this handle
35    /// open to keep the node around.
36    _node: Node,
37}
38
39driver_register!(UsbVsockServiceDriver);
40
41/// Processes a connection to the underlying USB device through a datagram socket where each
42/// packet received or sent corresponds to a USB bulk transfer buffer. It will call the callback
43/// with a new link and close the old one whenever a magic reset packet is received from the host.
44struct UsbConnection {
45    vsock_service: Arc<VsockService<Vec<u8>>>,
46    usb_socket_reader: ReadHalf<Socket>,
47    usb_socket_writer: WriteHalf<Socket>,
48    connection_tx: mpsc::Sender<ConnectionRequest>,
49}
50
51impl UsbConnection {
52    fn new(
53        vsock_service: Arc<VsockService<Vec<u8>>>,
54        usb_socket: zx::Socket,
55        connection_tx: mpsc::Sender<ConnectionRequest>,
56    ) -> Self {
57        assert!(
58            usb_socket.info().unwrap().options.contains(SocketOpts::DATAGRAM),
59            "USB socket must be a datagram socket"
60        );
61        let (usb_socket_reader, usb_socket_writer) = Socket::from_socket(usb_socket).split();
62        Self { vsock_service, usb_socket_reader, usb_socket_writer, connection_tx }
63    }
64
65    /// Waits for an [`PacketType::Sync`] packet and sends the reply back, and then returns a
66    /// fresh control socket for the new connection
67    async fn next_socket(&mut self, mut found_magic: Option<Vec<u8>>) -> Option<Socket> {
68        let mut data = [0; MTU];
69        while found_magic.is_none() {
70            let mut packets = match read_packet_stream(&mut self.usb_socket_reader, &mut data).await
71            {
72                Ok(None) => {
73                    debug!("Usb socket closed");
74                    return None;
75                }
76                Err(err) => {
77                    error!("Unexpected error on usb socket: {err}");
78                    return None;
79                }
80                Ok(Some(packets)) => packets,
81            };
82
83            while let Some(packet) = packets.next() {
84                // note: we will deliberately warn and ignore for any vsock packets in the same
85                // usb packet as a sync packet, regardless of whether they were before or after.
86                match packet {
87                    Ok(Packet {
88                        header: Header { packet_type: PacketType::Sync, .. },
89                        payload,
90                    }) => {
91                        found_magic = Some(payload.to_owned());
92                    }
93                    Ok(packet) => {
94                        warn!("Got unexpected packet of type {:?} and length {} while waiting for sync packet. Ignoring.", packet.header.packet_type, packet.header.payload_len);
95                    }
96                    Err(err) => {
97                        warn!("Got invalid vsock packet while waiting for sync packet: {err:?}");
98                    }
99                }
100            }
101        }
102        let found_magic =
103            found_magic.expect("read loop should not terminate until sync packet is read");
104
105        debug!("Read sync packet, sending it back and setting up a new link");
106        let mut header = Header::new(PacketType::Sync);
107        header.payload_len = (found_magic.len() as u32).into();
108        let packet = Packet { header: &header, payload: &found_magic };
109        packet.write_to_unchecked(&mut data);
110        if let Err(err) = self.usb_socket_writer.write(&data[..packet.size()]).await {
111            error!("Error writing overnet magic string to the usb socket: {err:?}");
112            return None;
113        }
114        let (next_control_socket, other_end) = zx::Socket::create_stream();
115        // TODO(406262417): this is only here because the host side has trouble with hanging
116        // gets and sending some data immediately after will help it clear and re-establish its state.
117        Socket::from_socket(other_end).write_all(b"hello").await.ok();
118        // after writing to the 'other_end' we drop this socket end because we don't expect any
119        // further data on the control socket, as it's currently unused. In the future if we want
120        // to have side channel data flow between the host and driver, this is the socket it would
121        // go in.
122        return Some(Socket::from_socket(next_control_socket));
123    }
124
125    async fn run(mut self) {
126        let mut found_magic = None;
127        loop {
128            let Some(control_socket) = self.next_socket(found_magic).await else {
129                info!("USB socket closed or failed");
130                return;
131            };
132            // reset whether we found the magic string last time around or not.
133            found_magic = None;
134            let connection = Arc::new(Connection::new(control_socket, self.connection_tx.clone()));
135            self.vsock_service.set_connection(connection.clone()).await;
136            let usb_socket_writer =
137                usb_socket_writer::<MTU>(&connection, &mut self.usb_socket_writer);
138            let usb_socket_reader = usb_socket_reader::<MTU>(
139                &mut found_magic,
140                &mut self.usb_socket_reader,
141                &connection,
142            );
143            let client_socket_copy = pin!(usb_socket_writer);
144            let usb_socket_copy = pin!(usb_socket_reader);
145            let res = select(client_socket_copy, usb_socket_copy).await;
146            match res {
147                Either::Left((Err(err), _)) => {
148                    warn!("Error on client to usb socket transfer: {err:?}");
149                }
150                Either::Left((Ok(_), _)) => {
151                    debug!("client to usb socket closed normally");
152                }
153                Either::Right((Err(err), _)) => {
154                    warn!("Error on usb to client socket transfer: {err:?}");
155                }
156                Either::Right((Ok(_), _)) => {
157                    info!("usb to client socket closed normally");
158                }
159            }
160        }
161    }
162}
163
164async fn read_packet_stream<'a>(
165    reader: &mut (impl AsyncRead + Unpin),
166    mut buffer: &'a mut [u8],
167) -> Result<Option<VsockPacketIterator<'a>>, std::io::Error> {
168    let size = reader.read(&mut buffer).await?;
169    if size == 0 {
170        return Ok(None);
171    }
172    Ok(Some(VsockPacketIterator::new(&buffer[0..size])))
173}
174
175async fn usb_socket_writer<const MTU: usize>(
176    connection: &Connection<Vec<u8>>,
177    usb_writer: &mut (impl AsyncWrite + Unpin),
178) -> Result<(), Error> {
179    let mut builder = UsbPacketBuilder::new(vec![0; MTU]);
180    loop {
181        builder = connection.fill_usb_packet(builder).await;
182        let buf = builder.take_usb_packet().unwrap();
183        assert_eq!(
184            buf.len(),
185            usb_writer.write(buf).await?,
186            "datagram socket sent incomplete packet"
187        );
188    }
189}
190
191async fn usb_socket_reader<const MTU: usize>(
192    found_magic: &mut Option<Vec<u8>>,
193    usb_reader: &mut (impl AsyncRead + Unpin),
194    connection: &Connection<Vec<u8>>,
195) -> Result<(), Error> {
196    let mut data = [0; MTU];
197    loop {
198        let Some(mut packets) = read_packet_stream(usb_reader, &mut data).await? else {
199            break;
200        };
201        while let Some(packet) = packets.next() {
202            match packet {
203                Ok(Packet { header: Header { packet_type: PacketType::Sync, .. }, payload }) => {
204                    debug!("Found sync packet, ending stream");
205                    *found_magic = Some(payload.to_owned());
206                    return Ok(());
207                }
208                Ok(packet) => connection.handle_vsock_packet(packet).await?,
209                Err(err) => {
210                    error!("Failed to parse vsock packet, going back to waiting for sync packet: {err:?}");
211                    break;
212                }
213            }
214        }
215    }
216    Ok(())
217}
218
219/// Processes a stream of device connections from the parent driver, and for each one initiates a
220/// [`UsbConnection`] process to handle individual connections to the host process.
221struct UsbCallbackHandler {
222    usb_callback_server: overnet::CallbackRequestStream,
223    connection_tx: mpsc::Sender<ConnectionRequest>,
224}
225
226impl UsbCallbackHandler {
227    async fn run(mut self, vsock_service: Arc<VsockService<Vec<u8>>>) -> Result<(), fidl::Error> {
228        use overnet::CallbackRequest::*;
229        while let Some(req) = self.usb_callback_server.try_next().await? {
230            let NewLink { socket, responder } = req;
231            responder.send()?;
232
233            debug!("Received new socket from usb driver");
234            UsbConnection::new(vsock_service.clone(), socket, self.connection_tx.clone())
235                .run()
236                .await;
237        }
238        Ok(())
239    }
240}
241
242impl Driver for UsbVsockServiceDriver {
243    const NAME: &str = "usb-vsock-service";
244
245    async fn start(mut context: DriverContext) -> Result<Self, Status> {
246        let node = context.take_node()?;
247        let scope = Scope::new_with_name(Self::NAME);
248        let mut outgoing = ServiceFs::new();
249
250        let usb_device = get_usb_device(&context)?;
251
252        info!("Offering a vsock service in the outgoing directory");
253        outgoing.dir("svc").add_fidl_service_instance("default", move |i| {
254            let vsock::ServiceRequest::Device(request_stream) = i;
255            request_stream
256        });
257
258        context.serve_outgoing(&mut outgoing)?;
259
260        scope.spawn(async move {
261            while let Some(request_stream) = outgoing.next().await {
262                let (usb_callback, usb_callback_server) = create_endpoints();
263                usb_device.set_callback(usb_callback).await.expect("usb device service went away");
264
265                run_connection(usb_callback_server.into_stream(), request_stream).await
266            }
267        });
268
269        Ok(Self { _scope: scope, _node: node })
270    }
271
272    async fn stop(&self) {}
273}
274
275async fn run_connection(
276    usb_callback_server: overnet::CallbackRequestStream,
277    mut request_stream: vsock::DeviceRequestStream,
278) {
279    debug!("Waiting for start message on vsock implementation service");
280    let (connection_tx, incoming_connections) = mpsc::channel(1);
281    let svc = match VsockService::wait_for_start(incoming_connections, &mut request_stream).await {
282        Ok(svc) => svc,
283        Err(err) => {
284            error!("Error while waiting for start message from vsock client: {err:?}");
285            return;
286        }
287    };
288    debug!(
289        "Received start message on vsock implementation service, waiting for usb socket handles"
290    );
291
292    let svc = Arc::new(svc);
293    let (mut scopes_stream, scopes) = ScopeStream::new_with_name("usb-vsock-connection".to_owned());
294
295    let usb_callback_handler =
296        UsbCallbackHandler { usb_callback_server, connection_tx: connection_tx.clone() };
297    let usb_svc = svc.clone();
298    scopes.push(async move {
299        if let Err(err) = usb_callback_handler.run(usb_svc).await {
300            error!("Error while waiting for usb device callbacks: {err:?}");
301        }
302    });
303    scopes.push(async move {
304        if let Err(err) = svc.run(request_stream).await {
305            error!("Error while servicing vsock client: {err:?}");
306        }
307    });
308    // wait for either to finish and then wait for a new client instead.
309    scopes_stream.next().await;
310}
311
312fn get_usb_device(context: &DriverContext) -> Result<overnet::UsbProxy, Status> {
313    let service_proxy = context.incoming.service_marker(overnet::UsbServiceMarker).connect()?;
314
315    service_proxy.connect_to_device().map_err(|err| {
316        error!("Error connecting to usb device proxy at driver startup: {err}");
317        Status::INTERNAL
318    })
319}
320
#[cfg(test)]
mod tests {
    use fidl::endpoints::create_endpoints;
    use fidl_fuchsia_vsock as vsock_api;
    use futures::channel::oneshot;
    use futures::future::join;
    use log::trace;

    use super::*;

    /// Wires [`run_connection`] up to a vsock API service on the device side and to a
    /// [`Connection`] acting as the host on the other side of a datagram socket, performs the
    /// sync handshake, and then runs `device_side` and `host_side` concurrently to completion.
    async fn end_to_end_test(
        device_side: impl AsyncFn(vsock_api::ConnectorProxy),
        host_side: impl AsyncFn(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>),
    ) {
        let scope = Scope::new();

        // The code under test: fed by a callback stream and a vsock device request stream.
        let (vsock_impl_client, vsock_impl_server) = create_endpoints::<vsock::DeviceMarker>();
        let (usb_callback_client, usb_callback_server) =
            create_endpoints::<overnet::CallbackMarker>();
        scope.spawn(run_connection(
            usb_callback_server.into_stream(),
            vsock_impl_server.into_stream(),
        ));
        let usb_callback_client = usb_callback_client.into_proxy();

        // The vsock API service that sits on top of the device implementation.
        let (vsock_api_service, vsock_api_future) =
            vsock_service_lib::Vsock::new(Some(vsock_impl_client.into_proxy()), None)
                .await
                .unwrap();
        scope.spawn_local(async move {
            vsock_api_future.await.unwrap();
        });

        let (vsock_api_client, vsock_api_server) = create_endpoints::<vsock_api::ConnectorMarker>();
        scope.spawn_local(vsock_api_service.run_client_connection(vsock_api_server.into_stream()));
        let vsock_api_client = vsock_api_client.into_proxy();

        // The datagram socket standing in for the usb link; one end goes to the driver.
        let (usb_packet_socket, usb_packet_server) = zx::Socket::create_datagram();
        let (mut usb_packet_reader, mut usb_packet_writer) =
            Socket::from_socket(usb_packet_socket).split();
        usb_callback_client.new_link(usb_packet_server).await.unwrap();

        // The host end of the vsock protocol.
        let (incoming_tx, incoming_rx) = mpsc::channel(1);
        let (_host_control_peer, host_control) = zx::Socket::create_stream();
        let host_connection =
            Arc::new(Connection::new(Socket::from_socket(host_control), incoming_tx));

        // Start the handshake by sending a sync packet to the driver.
        let payload = b"hello!";
        let mut header = Header::new(PacketType::Sync);
        header.payload_len.set(payload.len() as u32);
        let sync_packet = Packet { header: &header, payload };
        let mut buf = [0; 1024];
        sync_packet.write_to_unchecked(&mut buf);
        let sync_len = sync_packet.size();
        let sent = usb_packet_writer.write(&buf[..sync_len]).await.unwrap();
        assert_eq!(sync_len, sent);

        // Host -> device pump: forward whatever the host connection wants to send.
        let writer_connection = host_connection.clone();
        scope.spawn(async move {
            let mut builder = UsbPacketBuilder::new(vec![0; 4096]);
            loop {
                builder = writer_connection.fill_usb_packet(builder).await;
                let usb_packet = builder.take_usb_packet().unwrap();
                for packet in VsockPacketIterator::new(usb_packet) {
                    let packet = packet.unwrap();
                    trace!("sending packet {packet:?}");
                }
                let _written = usb_packet_writer.write(usb_packet).await.unwrap();
            }
        });

        // Device -> host pump: signal once the sync reply arrives, dispatch everything else.
        let reader_connection = host_connection.clone();
        let (synchronized_tx, synchronized) = oneshot::channel();
        let mut synchronized_tx = Some(synchronized_tx);
        scope.spawn(async move {
            let mut buf = vec![0; 4096];
            while let Ok(bytes) = usb_packet_reader.read(&mut buf).await {
                for packet in VsockPacketIterator::new(&buf[..bytes]) {
                    let packet = packet.unwrap();
                    trace!("received packet {packet:?}");
                    if packet.header.packet_type == PacketType::Sync {
                        assert_eq!(packet.payload, b"hello!");
                        synchronized_tx.take().unwrap().send(()).unwrap();
                    } else {
                        reader_connection.handle_vsock_packet(packet).await.unwrap();
                    }
                }
            }
        });

        synchronized.await.unwrap();

        let device = device_side(vsock_api_client);
        let host = host_side(host_connection, incoming_rx);
        join(device, host).await;
    }

    /// The device initiates a connection, the host accepts it, and data flows both ways.
    #[fuchsia::test(allow_stalls = false)]
    async fn test_device_to_host_connection() {
        end_to_end_test(
            async move |vsock_api_client| {
                let (local, data) = zx::Socket::create_stream();
                let mut local = Socket::from_socket(local);
                let (_con, con) = create_endpoints();
                vsock_api_client
                    .connect(2, 200, vsock_api::ConnectionTransport { data, con })
                    .await
                    .unwrap()
                    .unwrap();
                let mut received = [0; 4];
                local.read_exact(&mut received).await.unwrap();
                assert_eq!(&received, b"boom");
                local.write_all(b"zoom").await.unwrap();
                assert_eq!(0, local.read(&mut received).await.unwrap());
                trace!("vsock api fin");
            },
            async move |host_connection, mut incoming_rx| {
                let incoming = incoming_rx.next().await.unwrap();
                trace!("{incoming:?}");
                let (local, remote) = zx::Socket::create_stream();
                let mut local = Socket::from_socket(local);
                let _state =
                    host_connection.accept(incoming, Socket::from_socket(remote)).await.unwrap();
                local.write_all(b"boom").await.unwrap();
                let mut received = [0; 4];
                local.read_exact(&mut received).await.unwrap();
                assert_eq!(&received, b"zoom");
                trace!("host fin");
            },
        )
        .await;
    }

    /// The host initiates a connection to a listening device port and data flows both ways.
    #[fuchsia::test(allow_stalls = false)]
    async fn test_host_to_device_connection() {
        end_to_end_test(
            async move |vsock_api_client| {
                let (acceptor_client, acceptor) = create_endpoints::<vsock_api::AcceptorMarker>();
                let mut acceptor = acceptor.into_stream();
                vsock_api_client.listen(200, acceptor_client).await.unwrap().unwrap();
                let vsock_api::AcceptorRequest::Accept { addr, responder } =
                    acceptor.next().await.unwrap().unwrap();
                assert_eq!(addr, vsock::Addr { local_port: 200, remote_cid: 2, remote_port: 9000 });

                let (local, data) = zx::Socket::create_stream();
                let mut local = Socket::from_socket(local);
                let (_con, con) = create_endpoints();
                responder.send(Some(vsock_api::ConnectionTransport { data, con })).unwrap();

                let mut received = [0; 4];
                local.read_exact(&mut received).await.unwrap();
                assert_eq!(&received, b"boom");
                local.write_all(b"zoom").await.unwrap();
                assert_eq!(0, local.read(&mut received).await.unwrap());
                trace!("vsock api fin");
            },
            async move |host_connection, _incoming_rx| {
                let (local, remote) = zx::Socket::create_stream();
                let mut local = Socket::from_socket(local);
                let target = usb_vsock::Address {
                    host_cid: 2,
                    host_port: 9000,
                    device_cid: 3,
                    device_port: 200,
                };
                let _state =
                    host_connection.connect(target, Socket::from_socket(remote)).await.unwrap();

                local.write_all(b"boom").await.unwrap();
                let mut received = [0; 4];
                local.read_exact(&mut received).await.unwrap();
                assert_eq!(&received, b"zoom");
                trace!("host fin");
            },
        )
        .await;
    }
}