vsock_service_lib/
service.rs

// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
// This module contains the bulk of the logic for connecting user applications to a
// vsock driver.
//
// Handling user requests is complicated as there are multiple communication channels
// involved. For example, a request to 'connect' will result in sending a message
// to the driver over the relevant DeviceProxy. If this returns with success then
// eventually a message will come over that device's Callbacks stream indicating
// whether the remote accepted or rejected.
//
// Fundamentally then there needs to be mutual exclusion in accessing DeviceProxy,
// and de-multiplexing of incoming messages on the Callbacks stream. There are
// two high-level options for doing this.
//  1. Force a single-task, event-driven model. This would mean that additional
//     asynchronous executions are never spawned, and any use of `.await` or otherwise
//     blocking on additional futures requires collecting futures into future sets
//     or writing custom polling logic, etc. Whilst this is probably the most resource
//     efficient, it restricts the service to be single task forever by its design,
//     is harder to reason about as it cannot be written very idiomatically with futures,
//     and makes it even more complicated to avoid blocking other requests whilst waiting
//     on responses from the driver.
//  2. Allow multiple asynchronous executions and use some form of message passing
//     and mutual exclusion checking to handle DeviceProxy access and sharing access
//     to the Callbacks stream. Potentially more resource intensive with unnecessary
//     RefCells etc., but allows for the potential to have actual concurrent execution
//     and makes the logic much simpler to write.
// The chosen option is (2): access to DeviceProxy is handled with an Rc<RefCell<State>>,
// and de-multiplexing of the Callbacks is done by registering an event whilst holding
// the RefCell, and having an asynchronous task per device that is dedicated to converting
// incoming Callbacks into signals on the registered events.
34
35use crate::{addr, port};
36use anyhow::{format_err, Context as _};
37use fidl::endpoints;
38use fidl::endpoints::{ControlHandle, RequestStream};
39use fidl_fuchsia_hardware_vsock::{
40    CallbacksMarker, CallbacksRequest, CallbacksRequestStream, DeviceProxy, VMADDR_CID_HOST,
41    VMADDR_CID_LOCAL,
42};
43use fidl_fuchsia_vsock::{
44    AcceptorProxy, ConnectionRequest, ConnectionRequestStream, ConnectionTransport,
45    ConnectorRequest, ConnectorRequestStream, ListenerControlHandle, ListenerRequest,
46    ListenerRequestStream, SIGNAL_STREAM_INCOMING,
47};
48use fuchsia_async as fasync;
49use futures::channel::{mpsc, oneshot};
50use futures::{future, select, Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
51use std::cell::{Ref, RefCell, RefMut};
52use std::collections::{HashMap, VecDeque};
53use std::convert::Infallible;
54use std::ops::Deref;
55use std::pin::Pin;
56use std::rc::Rc;
57use std::task::{Context, Poll};
58use thiserror::Error;
59
60const ZXIO_SIGNAL_INCOMING: zx::Signals = zx::Signals::from_bits(SIGNAL_STREAM_INCOMING).unwrap();
61
62type Cid = u32;
63type Port = u32;
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
65struct Addr(Cid, Port);
66
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
68enum EventType {
69    Shutdown,
70    Response,
71}
72
73#[derive(Debug, Clone, PartialEq, Eq, Hash)]
74struct Event {
75    action: EventType,
76    addr: addr::Vsock,
77}
78
79#[derive(Debug, Clone, Eq, PartialEq, Hash)]
80enum Deregister {
81    Event(Event),
82    Listen(Addr),
83    Port(Addr),
84}
85
86#[derive(Error, Debug)]
87enum Error {
88    #[error("Driver returned failure status {}", _0)]
89    Driver(#[source] zx::Status),
90    #[error("All ephemeral ports are allocated")]
91    OutOfPorts,
92    #[error("Addr has already been bound")]
93    AlreadyBound,
94    #[error("Connection refused by remote")]
95    ConnectionRefused,
96    #[error("Error whilst communication with client")]
97    ClientCommunication(#[source] anyhow::Error),
98    #[error("Error whilst communication with client")]
99    DriverCommunication(#[source] anyhow::Error),
100    #[error("Driver reset the connection")]
101    ConnectionReset,
102    #[error("There are no more connections in the accept queue")]
103    NoConnectionsInQueue,
104}
105
106impl From<oneshot::Canceled> for Error {
107    fn from(_: oneshot::Canceled) -> Error {
108        Error::ConnectionReset
109    }
110}
111
112impl Error {
113    pub fn into_status(&self) -> zx::Status {
114        match self {
115            Error::Driver(status) => *status,
116            Error::OutOfPorts => zx::Status::NO_RESOURCES,
117            Error::AlreadyBound => zx::Status::ALREADY_BOUND,
118            Error::ConnectionRefused => zx::Status::UNAVAILABLE,
119            Error::ClientCommunication(err) | Error::DriverCommunication(err) => {
120                *err.downcast_ref::<zx::Status>().unwrap_or(&zx::Status::INTERNAL)
121            }
122            Error::ConnectionReset => zx::Status::PEER_CLOSED,
123            Error::NoConnectionsInQueue => zx::Status::SHOULD_WAIT,
124        }
125    }
126    pub fn is_comm_failure(&self) -> bool {
127        match self {
128            Error::ClientCommunication(_) | Error::DriverCommunication(_) => true,
129            _ => false,
130        }
131    }
132}
133
134fn map_driver_result(result: Result<Result<(), i32>, fidl::Error>) -> Result<(), Error> {
135    result
136        .map_err(|x| Error::DriverCommunication(x.into()))?
137        .map_err(|e| Error::Driver(zx::Status::from_raw(e)))
138}
139
140struct SocketContextState {
141    port: Addr,
142    accept_queue: VecDeque<addr::Vsock>,
143    backlog: Option<u32>,
144    control: ListenerControlHandle,
145    signaled: bool,
146}
147
148#[derive(Clone)]
149pub struct SocketContext(Rc<RefCell<SocketContextState>>);
150
151impl SocketContext {
152    fn new(port: Addr, control: ListenerControlHandle) -> SocketContext {
153        SocketContext(Rc::new(RefCell::new(SocketContextState {
154            port,
155            accept_queue: VecDeque::new(),
156            backlog: None,
157            signaled: false,
158            control,
159        })))
160    }
161
162    fn listen(&self, backlog: u32) -> Result<(), Error> {
163        let mut ctx = self.0.borrow_mut();
164        if ctx.backlog.is_some() {
165            return Err(Error::AlreadyBound);
166        }
167        // TODO: Update listener?
168        ctx.backlog = Some(backlog);
169        Ok(())
170    }
171
172    fn push_addr(&self, addr: addr::Vsock) -> bool {
173        let mut ctx = self.0.borrow_mut();
174        if Addr(addr.remote_cid, addr.local_port) != ctx.port {
175            panic!("request address doesn't match local socket address");
176        }
177        let Some(ref mut backlog) = ctx.backlog else {
178            panic!("pushing address when not yet bound");
179        };
180        if *backlog == 0 {
181            return false;
182        }
183        *backlog -= 1;
184        ctx.accept_queue.push_back(addr);
185        if ctx.signaled == false {
186            let _ = ctx.control.signal_peer(zx::Signals::empty(), ZXIO_SIGNAL_INCOMING);
187            ctx.signaled = true
188        }
189        return true;
190    }
191
192    fn pop_addr(&self) -> Option<addr::Vsock> {
193        let mut ctx = self.0.borrow_mut();
194        if let Some(addr) = ctx.accept_queue.pop_front() {
195            let Some(ref mut backlog) = ctx.backlog else {
196                return None;
197            };
198            *backlog += 1;
199            if ctx.accept_queue.len() == 0 {
200                let _ = ctx.control.signal_peer(ZXIO_SIGNAL_INCOMING, zx::Signals::empty());
201                ctx.signaled = false;
202            }
203            Some(addr)
204        } else {
205            None
206        }
207    }
208
209    fn port(&self) -> Addr {
210        self.0.borrow_mut().port
211    }
212}
213
214enum Listener {
215    Bound,
216    Channel(mpsc::UnboundedSender<addr::Vsock>),
217    Queue(SocketContext),
218}
219
220struct State {
221    guest_vsock_device: Option<DeviceProxy>,
222    loopback_vsock_device: Option<DeviceProxy>,
223    local_cid: Cid,
224    events: HashMap<Event, oneshot::Sender<()>>,
225    used_ports: HashMap<Cid, port::Tracker>,
226    listeners: HashMap<Addr, Listener>,
227    tasks: fasync::TaskGroup,
228}
229
230impl State {
231    fn device(&self, addr: &addr::Vsock) -> &DeviceProxy {
232        match (addr.remote_cid, &self.guest_vsock_device, &self.loopback_vsock_device) {
233            (VMADDR_CID_LOCAL, _, Some(loopback)) => &loopback,
234            (VMADDR_CID_HOST, Some(guest), _) => &guest,
235            (VMADDR_CID_HOST, None, Some(loopback)) => &loopback,
236            (cid, None, Some(loopback)) if cid == self.local_cid => &loopback,
237            _ => unreachable!("Shouldn't be able to end up here!"),
238        }
239    }
240}
241
242#[derive(Clone)]
243pub struct Vsock(Rc<RefCell<State>>);
244
245impl Vsock {
246    /// Creates a new vsock service connected to the given `DeviceProxy`
247    ///
248    /// The creation is asynchronous due to need to invoke methods on the given `DeviceProxy`. On
249    /// success a pair of `Self, impl Future<Result<_, Error>>` is returned. The `impl Future` is
250    /// a future that is listening for and processing messages from the `device`. This future needs
251    /// to be evaluated for other methods on the returned `Self` to complete successfully. Unless
252    /// a fatal error occurs the future will never yield a result and will execute infinitely.
253    pub async fn new(
254        guest_vsock_device: Option<DeviceProxy>,
255        loopback_vsock_device: Option<DeviceProxy>,
256    ) -> Result<(Self, impl Future<Output = Result<Vec<Infallible>, anyhow::Error>>), anyhow::Error>
257    {
258        let mut server_streams = Vec::new();
259        let mut start_device = |device: &DeviceProxy| {
260            let (callbacks_client, callbacks_server) =
261                endpoints::create_endpoints::<CallbacksMarker>();
262            server_streams.push(callbacks_server.into_stream());
263
264            device.start(callbacks_client).map(map_driver_result).err_into::<anyhow::Error>()
265        };
266        let mut local_cid = VMADDR_CID_LOCAL;
267        if let Some(ref device) = guest_vsock_device {
268            start_device(device).await.context("Failed to start guest device")?;
269            local_cid = device.get_cid().await?;
270        }
271        if let Some(ref device) = loopback_vsock_device {
272            start_device(device).await.context("Failed to start loopback device")?;
273        }
274        let service = State {
275            guest_vsock_device,
276            loopback_vsock_device,
277            local_cid,
278            events: HashMap::new(),
279            used_ports: HashMap::new(),
280            listeners: HashMap::new(),
281            tasks: fasync::TaskGroup::new(),
282        };
283
284        let service = Vsock(Rc::new(RefCell::new(service)));
285        let callback_loops: Vec<_> = server_streams
286            .into_iter()
287            .map(|stream| service.clone().run_callbacks(stream))
288            .collect();
289
290        Ok((service, future::try_join_all(callback_loops)))
291    }
292    async fn run_callbacks(
293        self,
294        mut callbacks: CallbacksRequestStream,
295    ) -> Result<Infallible, anyhow::Error> {
296        while let Some(Ok(cb)) = callbacks.next().await {
297            self.borrow_mut().do_callback(cb);
298        }
299        // The only way to get here is if our callbacks stream ended, since our notifications
300        // cannot disconnect as we are holding a reference to them in |service|.
301        Err(format_err!("Driver disconnected"))
302    }
303
304    fn supported_cid(&self, cid: u32) -> bool {
305        cid == VMADDR_CID_HOST || cid == VMADDR_CID_LOCAL || cid == self.borrow().local_cid
306    }
307
308    // Spawns a new asynchronous task for listening for incoming connections on a port.
309    fn start_listener(
310        &self,
311        acceptor: fidl::endpoints::ClientEnd<fidl_fuchsia_vsock::AcceptorMarker>,
312        local_port: u32,
313    ) -> Result<(), Error> {
314        let acceptor = acceptor.into_proxy();
315        let stream = self.listen_port(local_port)?;
316        self.borrow_mut().tasks.local(
317            self.clone()
318                .run_connection_listener(stream, acceptor)
319                .unwrap_or_else(|err| log::warn!("Error {} running connection listener", err)),
320        );
321        Ok(())
322    }
323
324    // Spawns a new asynchronous task for listening for incoming connections on a port.
325    fn start_listener2(
326        &self,
327        listener: fidl::endpoints::ServerEnd<fidl_fuchsia_vsock::ListenerMarker>,
328        port: Addr,
329    ) -> Result<(), Error> {
330        let stream = listener.into_stream();
331        self.bind_port(port.clone())?;
332        self.borrow_mut().tasks.local(
333            self.clone()
334                .run_connection_listener2(stream, port)
335                .unwrap_or_else(|err| log::warn!("Error {} running connection listener", err)),
336        );
337        Ok(())
338    }
339
340    // Handles a single incoming client request.
341    async fn handle_request(&self, request: ConnectorRequest) -> Result<(), Error> {
342        match request {
343            ConnectorRequest::Connect { remote_cid, remote_port, con, responder } => responder
344                .send(
345                    self.make_connection(remote_cid, remote_port, con)
346                        .await
347                        .map_err(|e| e.into_status().into_raw()),
348                ),
349            ConnectorRequest::Listen { local_port, acceptor, responder } => responder.send(
350                self.start_listener(acceptor, local_port).map_err(|e| e.into_status().into_raw()),
351            ),
352            ConnectorRequest::Bind { remote_cid, local_port, listener, responder } => responder
353                .send(
354                    self.start_listener2(listener, Addr(remote_cid, local_port))
355                        .map_err(|e| e.into_status().into_raw()),
356                ),
357        }
358        .map_err(|e| Error::ClientCommunication(e.into()))
359    }
360
361    /// Evaluates messages on a `ConnectorRequestStream` until completion or error
362    ///
363    /// Takes ownership of a `RequestStream` that is most likely created from a `ServicesServer`
364    /// and processes any incoming requests on it.
365    pub async fn run_client_connection(self, request: ConnectorRequestStream) {
366        let self_ref = &self;
367        let fut = request
368            .map_err(|err| Error::ClientCommunication(err.into()))
369            // TODO: The parallel limit of 4 is currently invented with no basis and should
370            // made something more sensible.
371            .try_for_each_concurrent(4, |request| {
372                self_ref
373                    .handle_request(request)
374                    .or_else(|e| future::ready(if e.is_comm_failure() { Err(e) } else { Ok(()) }))
375            });
376        if let Err(e) = fut.await {
377            log::info!("Failed to handle request {}", e);
378        }
379    }
380    fn alloc_ephemeral_port(self, cid: Cid) -> Option<AllocatedPort> {
381        let p = self.borrow_mut().used_ports.entry(cid).or_default().allocate();
382        p.map(|p| AllocatedPort { port: Addr(cid, p), service: self })
383    }
384    // Creates a `ListenStream` that will retrieve raw incoming connection requests.
385    // These requests come from the device via the run_callbacks future.
386    fn listen_port(&self, port: u32) -> Result<ListenStream, Error> {
387        if port::is_ephemeral(port) {
388            log::info!("Rejecting request to listen on ephemeral port {}", port);
389            return Err(Error::ConnectionRefused);
390        }
391        match self.borrow_mut().listeners.entry(Addr(VMADDR_CID_HOST, port)) {
392            std::collections::hash_map::Entry::Vacant(entry) => {
393                let (sender, receiver) = mpsc::unbounded();
394                let listen =
395                    ListenStream { local_port: port, service: self.clone(), stream: receiver };
396                entry.insert(Listener::Channel(sender));
397                Ok(listen)
398            }
399            _ => {
400                log::info!("Attempt to listen on already bound port {}", port);
401                Err(Error::AlreadyBound)
402            }
403        }
404    }
405
406    fn bind_port(&self, port: Addr) -> Result<(), Error> {
407        if port::is_ephemeral(port.1) {
408            log::info!("Rejecting request to listen on ephemeral port {}", port.1);
409            return Err(Error::ConnectionRefused);
410        }
411        match self.borrow_mut().listeners.entry(port) {
412            std::collections::hash_map::Entry::Vacant(entry) => {
413                entry.insert(Listener::Bound);
414                Ok(())
415            }
416            _ => {
417                log::info!("Attempt to listen on already bound port {:?}", port);
418                Err(Error::AlreadyBound)
419            }
420        }
421    }
422
423    // Helper for inserting an event into the events hashmap
424    fn register_event(&self, event: Event) -> Result<OneshotEvent, Error> {
425        match self.borrow_mut().events.entry(event) {
426            std::collections::hash_map::Entry::Vacant(entry) => {
427                let (sender, receiver) = oneshot::channel();
428                let event = OneshotEvent {
429                    event: Some(entry.key().clone()),
430                    service: self.clone(),
431                    oneshot: receiver,
432                };
433                entry.insert(sender);
434                Ok(event)
435            }
436            _ => Err(Error::AlreadyBound),
437        }
438    }
439
440    // These helpers are wrappers around sending a message to the device, and creating events that
441    // will be signaled by the run_callbacks future when it receives a message from the device.
442    fn send_request(
443        &self,
444        addr: &addr::Vsock,
445        data: zx::Socket,
446    ) -> Result<impl Future<Output = Result<(OneshotEvent, OneshotEvent), Error>> + 'static, Error>
447    {
448        let shutdown_callback =
449            self.register_event(Event { action: EventType::Shutdown, addr: addr.clone() })?;
450        let response_callback =
451            self.register_event(Event { action: EventType::Response, addr: addr.clone() })?;
452
453        let send_request_fut = self.borrow_mut().send_request(&addr, data);
454
455        Ok(async move {
456            send_request_fut.await?;
457            Ok((shutdown_callback, response_callback))
458        })
459    }
460    fn send_response(
461        &self,
462        addr: &addr::Vsock,
463        data: zx::Socket,
464    ) -> Result<impl Future<Output = Result<OneshotEvent, Error>> + 'static, Error> {
465        let shutdown_callback =
466            self.register_event(Event { action: EventType::Shutdown, addr: addr.clone() })?;
467
468        let send_request_fut = self.borrow_mut().send_response(&addr, data);
469
470        Ok(async move {
471            send_request_fut.await?;
472            Ok(shutdown_callback)
473        })
474    }
475
476    // Runs a connected socket until completion. Processes any VMO sends and shutdown events.
477    async fn run_connection<ShutdownFut>(
478        self,
479        addr: addr::Vsock,
480        shutdown_event: ShutdownFut,
481        mut requests: ConnectionRequestStream,
482        _port: Option<AllocatedPort>,
483    ) -> Result<(), Error>
484    where
485        ShutdownFut:
486            Future<Output = Result<(), futures::channel::oneshot::Canceled>> + std::marker::Unpin,
487    {
488        let mut shutdown_event = shutdown_event.fuse();
489        select! {
490            shutdown_event = shutdown_event => {
491                let fut = future::ready(shutdown_event)
492                    .err_into()
493                    .and_then(|()| self.borrow_mut().send_rst(&addr));
494                return fut.await;
495            },
496            request = requests.next() => {
497                match request {
498                    Some(Ok(ConnectionRequest::Shutdown{control_handle: _control_handle})) => {
499                        let fut =
500                            self.borrow_mut().send_shutdown(&addr)
501                                // Wait to either receive the RST for the client or to be
502                                // shut down for some other reason
503                                .and_then(|()| shutdown_event.err_into());
504                        return fut.await;
505                    },
506                    // Generate a RST for a non graceful client disconnect.
507                    Some(Err(e)) => {
508                        let fut = self.borrow_mut().send_rst(&addr);
509                        fut.await?;
510                        return Err(Error::ClientCommunication(e.into()));
511                    },
512                    None => {
513                        let fut = self.borrow_mut().send_rst(&addr);
514                        return fut.await;
515                    },
516                }
517            }
518        }
519    }
520
521    fn listen(&self, socket: &SocketContext, backlog: u32) -> Result<(), Error> {
522        socket.listen(backlog)?;
523        // Replace "bound" listener with a socket accept queue.
524        match self.borrow_mut().listeners.entry(socket.port()) {
525            std::collections::hash_map::Entry::Vacant(_) => {
526                // We should be in bound state. Something went wrong if we end up here.
527                log::warn!("Expected listener to be in bound state, but listener not found!");
528                return Err(Error::AlreadyBound);
529            }
530            std::collections::hash_map::Entry::Occupied(mut entry) => {
531                if !matches!(entry.get(), Listener::Bound) {
532                    // Listen was probably already called. The call to socket.listen should
533                    // probably already have failed in this case.
534                    log::warn!("Listen called multiple times.");
535                    return Err(Error::AlreadyBound);
536                }
537                entry.insert(Listener::Queue(socket.clone()));
538            }
539        };
540
541        Ok(())
542    }
543
544    async fn accept(
545        &self,
546        socket: &SocketContext,
547        con: ConnectionTransport,
548    ) -> Result<addr::Vsock, Error> {
549        if let Some(addr) = socket.pop_addr() {
550            let data = con.data;
551            let con = con.con.into_stream();
552            let shutdown_event = self.send_response(&addr, data)?.await?;
553            self.borrow_mut().tasks.local(
554                self.clone()
555                    .run_connection(addr.clone(), shutdown_event, con, None)
556                    .map_err(|err| log::warn!("Error {} whilst running connection", err))
557                    .map(|_| ()),
558            );
559            // TODO: check if we want want to return the local port for the connection or the local
560            // port which the request came over.
561            Ok(addr)
562        } else {
563            Err(Error::NoConnectionsInQueue)
564        }
565    }
566
567    // Handles a single incoming client request.
568    async fn handle_listener_request(
569        &self,
570        socket: &SocketContext,
571        request: ListenerRequest,
572    ) -> Result<(), Error> {
573        match request {
574            ListenerRequest::Listen { backlog, responder } => {
575                responder.send(self.listen(socket, backlog).map_err(|e| e.into_status().into_raw()))
576            }
577            ListenerRequest::Accept { con, responder } => match self.accept(socket, con).await {
578                Ok(addr) => responder.send(Ok(&addr)),
579                Err(e) => responder.send(Err(e.into_status().into_raw())),
580            },
581        }
582        .map_err(|e| Error::ClientCommunication(e.into()))
583    }
584
585    async fn run_connection_listener2(
586        self,
587        request: ListenerRequestStream,
588        port: Addr,
589    ) -> Result<(), Error> {
590        let socket = SocketContext::new(port, request.control_handle());
591        let self_ref = &self;
592        let fut = request
593            .map_err(|err| Error::ClientCommunication(err.into()))
594            .try_for_each_concurrent(None, |request| {
595                self_ref
596                    .handle_listener_request(&socket, request)
597                    .or_else(|e| future::ready(if e.is_comm_failure() { Err(e) } else { Ok(()) }))
598            });
599        if let Err(e) = fut.await {
600            log::info!("Failed to handle request {}", e);
601        }
602        self.deregister(Deregister::Listen(socket.port()));
603        Ok(())
604    }
605
606    // Waits for incoming connections on the given `ListenStream`, checks with the
607    // user via the `acceptor` if it should be accepted, and if so spawns a new
608    // asynchronous task to run the connection.
609    async fn run_connection_listener(
610        self,
611        incoming: ListenStream,
612        acceptor: AcceptorProxy,
613    ) -> Result<(), Error> {
614        incoming
615            .then(|addr| acceptor.accept(&*addr.clone()).map_ok(|maybe_con| (maybe_con, addr)))
616            .map_err(|e| Error::ClientCommunication(e.into()))
617            .try_for_each(|(maybe_con, addr)| async {
618                match maybe_con {
619                    Some(con) => {
620                        let data = con.data;
621                        let con = con.con.into_stream();
622                        let shutdown_event = self.send_response(&addr, data)?.await?;
623                        self.borrow_mut().tasks.local(
624                            self.clone()
625                                .run_connection(addr, shutdown_event, con, None)
626                                .map_err(|err| {
627                                    log::warn!("Error {} whilst running connection", err)
628                                })
629                                .map(|_| ()),
630                        );
631                        Ok(())
632                    }
633                    None => {
634                        let fut = self.borrow_mut().send_rst(&addr);
635                        fut.await
636                    }
637                }
638            })
639            .await
640    }
641
642    // Attempts to connect to the given remote cid/port. If successful spawns a new
643    // asynchronous task to run the connection until completion.
644    async fn make_connection(
645        &self,
646        remote_cid: u32,
647        remote_port: u32,
648        con: ConnectionTransport,
649    ) -> Result<u32, Error> {
650        if !self.supported_cid(remote_cid) {
651            log::info!("Rejecting request to connect to unsupported CID {}", remote_cid);
652            return Err(Error::ConnectionRefused);
653        }
654        let data = con.data;
655        let con = con.con.into_stream();
656        let port = self.clone().alloc_ephemeral_port(remote_cid).ok_or(Error::OutOfPorts)?;
657        let port_value = port.port.1;
658        let addr = addr::Vsock::new(port_value, remote_port, remote_cid);
659        let (shutdown_event, response_event) = self.send_request(&addr, data)?.await?;
660        let mut shutdown_event = shutdown_event.fuse();
661        select! {
662            _shutdown_event = shutdown_event => {
663                // Getting a RST here just indicates a rejection and
664                // not any underlying issues.
665                return Err(Error::ConnectionRefused);
666            },
667            response_event = response_event.fuse() => response_event?,
668        }
669
670        self.borrow_mut().tasks.local(
671            self.clone()
672                .run_connection(addr, shutdown_event, con, Some(port))
673                .unwrap_or_else(|err| log::warn!("Error {} whilst running connection", err)),
674        );
675        Ok(port_value)
676    }
677
678    /// Mutably borrow the wrapped value.
679    fn borrow_mut(&self) -> RefMut<'_, State> {
680        self.0.borrow_mut()
681    }
682
683    fn borrow(&self) -> Ref<'_, State> {
684        self.0.borrow()
685    }
686
687    // Deregisters the specified event.
688    fn deregister(&self, event: Deregister) {
689        self.borrow_mut().deregister(event);
690    }
691}
692
693impl State {
694    // Remove the `event` from the `events` `HashMap`
695    fn deregister(&mut self, event: Deregister) {
696        match event {
697            Deregister::Event(e) => {
698                self.events.remove(&e);
699            }
700            Deregister::Listen(a) => {
701                self.listeners.remove(&a);
702            }
703            Deregister::Port(p) => {
704                self.used_ports.get_mut(&p.0).unwrap().free(p.1);
705            }
706        }
707    }
708
709    // Wrappers around device functions with nicer type signatures
710    fn send_request(
711        &mut self,
712        addr: &addr::Vsock,
713        data: zx::Socket,
714    ) -> impl Future<Output = Result<(), Error>> {
715        self.device(addr).send_request(&addr.clone(), data).map(map_driver_result)
716    }
717    fn send_response(
718        &mut self,
719        addr: &addr::Vsock,
720        data: zx::Socket,
721    ) -> impl Future<Output = Result<(), Error>> {
722        self.device(addr).send_response(&addr.clone(), data).map(map_driver_result)
723    }
724    fn send_rst(
725        &mut self,
726        addr: &addr::Vsock,
727    ) -> impl Future<Output = Result<(), Error>> + 'static {
728        self.device(addr).send_rst(&addr.clone()).map(map_driver_result)
729    }
730    fn send_shutdown(
731        &mut self,
732        addr: &addr::Vsock,
733    ) -> impl Future<Output = Result<(), Error>> + 'static {
734        self.device(addr).send_shutdown(&addr).map(map_driver_result)
735    }
736
737    // Processes a single callback from the `device`. This is intended to be used by
738    // `Vsock::run_callbacks`
739    fn do_callback(&mut self, callback: CallbacksRequest) {
740        match callback {
741            CallbacksRequest::Response { addr, control_handle: _control_handle } => {
742                self.events
743                    .remove(&Event { action: EventType::Response, addr: addr::Vsock::from(addr) })
744                    .map(|channel| channel.send(()));
745            }
746            CallbacksRequest::Rst { addr, control_handle: _control_handle } => {
747                self.events
748                    .remove(&Event { action: EventType::Shutdown, addr: addr::Vsock::from(addr) });
749            }
750            CallbacksRequest::Request { addr, control_handle: _control_handle } => {
751                let addr = addr::Vsock::from(addr);
752                let reset = |state: &mut State| {
753                    let task = state.send_rst(&addr).map(|_| ());
754                    state.tasks.local(task);
755                };
756                match self.listeners.get(&Addr(addr.remote_cid, addr.local_port)) {
757                    Some(Listener::Bound) => {
758                        log::warn!(
759                            "Request on port {} denied due to socket only bound, not yet listening",
760                            addr.local_port
761                        );
762                        reset(self);
763                    }
764                    Some(Listener::Channel(sender)) => {
765                        let _ = sender.unbounded_send(addr.clone());
766                    }
767                    Some(Listener::Queue(socket)) => {
768                        if !socket.push_addr(addr.clone()) {
769                            log::warn!(
770                                "Request on port {} denied due to full backlog",
771                                addr.local_port
772                            );
773                            reset(self);
774                        }
775                    }
776                    None => {
777                        log::warn!("Request on port {} with no listener", addr.local_port);
778                        reset(self);
779                    }
780                }
781            }
782            CallbacksRequest::Shutdown { addr, control_handle: _control_handle } => {
783                self.events
784                    .remove(&Event { action: EventType::Shutdown, addr: addr::Vsock::from(addr) })
785                    .map(|channel| channel.send(()));
786            }
787            CallbacksRequest::TransportReset { new_cid: _new_cid, responder } => {
788                self.events.clear();
789                let _ = responder.send();
790            }
791        }
792    }
793}
794
795struct AllocatedPort {
796    service: Vsock,
797    port: Addr,
798}
799
800impl Deref for AllocatedPort {
801    type Target = Addr;
802
803    fn deref(&self) -> &Addr {
804        &self.port
805    }
806}
807
808impl Drop for AllocatedPort {
809    fn drop(&mut self) {
810        self.service.deregister(Deregister::Port(self.port));
811    }
812}
813
814struct OneshotEvent {
815    event: Option<Event>,
816    service: Vsock,
817    oneshot: oneshot::Receiver<()>,
818}
819
820impl Drop for OneshotEvent {
821    fn drop(&mut self) {
822        self.event.take().map(|e| self.service.deregister(Deregister::Event(e)));
823    }
824}
825
826impl Future for OneshotEvent {
827    type Output = <oneshot::Receiver<()> as Future>::Output;
828
829    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
830        match self.oneshot.poll_unpin(cx) {
831            Poll::Ready(x) => {
832                // Take the event so that we don't try to deregister it later,
833                // as by having sent the message we just received the callbacks
834                // task will already have removed it
835                self.event.take();
836                Poll::Ready(x)
837            }
838            p => p,
839        }
840    }
841}
842
843struct ListenStream {
844    local_port: Port,
845    service: Vsock,
846    stream: mpsc::UnboundedReceiver<addr::Vsock>,
847}
848
849impl Drop for ListenStream {
850    fn drop(&mut self) {
851        self.service.deregister(Deregister::Listen(Addr(VMADDR_CID_HOST, self.local_port)));
852    }
853}
854
855impl Stream for ListenStream {
856    type Item = addr::Vsock;
857
858    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
859        self.stream.poll_next_unpin(cx)
860    }
861}