use {
crate::{addr, port},
anyhow::{format_err, Context as _},
fidl::endpoints,
fidl_fuchsia_hardware_vsock::{
CallbacksMarker, CallbacksRequest, CallbacksRequestStream, DeviceProxy,
},
fidl_fuchsia_vsock::{
AcceptorProxy, ConnectionRequest, ConnectionRequestStream, ConnectionTransport,
ConnectorRequest, ConnectorRequestStream,
},
fuchsia_async as fasync,
fuchsia_sync::Mutex,
fuchsia_zircon as zx,
futures::{
channel::{mpsc, oneshot},
future, select, Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
},
std::{
collections::HashMap,
convert::Infallible,
ops::Deref,
pin::Pin,
sync::Arc,
task::{Context, Poll},
},
thiserror::Error,
};
/// Kinds of one-shot driver notifications a waiter can register for on a given address.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum EventType {
    /// Completed by the driver's `Shutdown` callback.
    Shutdown,
    /// Completed by the driver's `SendVmoComplete` callback.
    VmoComplete,
    /// Completed by the driver's `Response` callback.
    Response,
}
/// Key of a pending one-shot notification in `State::events`: which kind of event, on
/// which connection address.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct Event {
    action: EventType,
    addr: addr::Vsock,
}

/// Cleanup request applied to `State` by `State::deregister`, either immediately or
/// later from the deregistration queue when the state lock is contended.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum Deregister {
    /// Remove the pending sender registered under this event key.
    Event(Event),
    /// Remove the listener registered on this local port.
    Listen(u32),
    /// Return this port to the ephemeral port tracker.
    Port(u32),
}
#[derive(Error, Debug)]
enum Error {
#[error("Driver returned failure status {}", _0)]
Driver(#[source] zx::Status),
#[error("All ephemeral ports are allocated")]
OutOfPorts,
#[error("Addr has already been bound")]
AlreadyBound,
#[error("Connection refused by remote")]
ConnectionRefused,
#[error("Error whilst communication with client")]
ClientCommunication(#[source] anyhow::Error),
#[error("Error whilst communication with client")]
DriverCommunication(#[source] anyhow::Error),
#[error("Driver reset the connection")]
ConnectionReset,
}
/// A cancelled one-shot wait means the matching sender was dropped without firing (for
/// example, the driver's `Rst` callback discards the pending shutdown sender), so it is
/// reported as a connection reset.
impl From<oneshot::Canceled> for Error {
    fn from(_canceled: oneshot::Canceled) -> Self {
        Self::ConnectionReset
    }
}
impl Error {
pub fn into_status(&self) -> zx::Status {
match self {
Error::Driver(status) => *status,
Error::OutOfPorts => zx::Status::NO_RESOURCES,
Error::AlreadyBound => zx::Status::ALREADY_BOUND,
Error::ConnectionRefused => zx::Status::UNAVAILABLE,
Error::ClientCommunication(err) | Error::DriverCommunication(err) => {
*err.downcast_ref::<zx::Status>().unwrap_or(&zx::Status::INTERNAL)
}
Error::ConnectionReset => zx::Status::PEER_CLOSED,
}
}
pub fn is_comm_failure(&self) -> bool {
match self {
Error::ClientCommunication(_) | Error::DriverCommunication(_) => true,
_ => false,
}
}
}
fn map_driver_result(result: Result<i32, fidl::Error>) -> Result<(), Error> {
result
.map_err(|x| Error::DriverCommunication(x.into()))
.and_then(|x| zx::Status::ok(x).map_err(Error::Driver))
}
fn send_result<T>(
result: Result<T, Error>,
send: impl FnOnce(i32, Option<T>) -> Result<(), fidl::Error>,
) -> Result<(), Error> {
match result {
Ok(v) => send(zx::Status::OK.into_raw(), Some(v))
.map_err(|e| Error::ClientCommunication(e.into())),
Err(e) => {
send(e.into_status().into_raw(), None)
.map_err(|e| Error::ClientCommunication(e.into()))?;
Err(e)
}
}
}
/// Mutable service state. It is only reached through `LockedState::lock` / `try_lock`,
/// both of which first apply any queued deregistrations.
struct State {
    /// Proxy used to issue requests to the vsock driver.
    device: DeviceProxy,
    /// Pending one-shot waiters, fired or dropped by `State::do_callback`.
    events: HashMap<Event, oneshot::Sender<()>>,
    /// Ephemeral ports currently reserved by outgoing connections.
    used_ports: port::Tracker,
    /// Active listeners by local port; incoming driver `Request`s are forwarded here.
    listens: HashMap<u32, mpsc::UnboundedSender<addr::Vsock>>,
}
/// `State` behind a mutex, plus a queue for deregistrations that could not take the lock.
pub struct LockedState {
    inner: Mutex<State>,
    // `LockedState::deregister` pushes here when `try_lock` fails; the queue is drained on
    // the next successful `lock`/`try_lock`.
    deregister_tx: crossbeam::channel::Sender<Deregister>,
    deregister_rx: crossbeam::channel::Receiver<Deregister>,
}
/// Cheaply cloneable handle to the shared vsock service state.
#[derive(Clone)]
pub struct Vsock {
    inner: Arc<LockedState>,
}
impl Vsock {
pub async fn new(
device: DeviceProxy,
) -> Result<(Self, impl Future<Output = Result<Infallible, anyhow::Error>>), anyhow::Error>
{
let (callbacks_client, callbacks_server) = endpoints::create_endpoints::<CallbacksMarker>();
let server_stream = callbacks_server.into_stream()?;
device
.start(callbacks_client)
.map(|x| map_driver_result(x))
.err_into::<anyhow::Error>()
.await
.context("Failed to start device")?;
let service = State {
device,
events: HashMap::new(),
used_ports: port::Tracker::new(),
listens: HashMap::new(),
};
let (tx, rx) = crossbeam::channel::unbounded();
let service =
LockedState { inner: Mutex::new(service), deregister_tx: tx, deregister_rx: rx };
let service = Vsock { inner: Arc::new(service) };
let callback_loop = service.clone().run_callbacks(server_stream);
Ok((service, callback_loop))
}
async fn run_callbacks(
self,
mut callbacks: CallbacksRequestStream,
) -> Result<Infallible, anyhow::Error> {
while let Some(Ok(cb)) = callbacks.next().await {
self.lock().do_callback(cb);
}
Err(format_err!("Driver disconnected"))
}
fn start_listener(
&self,
acceptor: fidl::endpoints::ClientEnd<fidl_fuchsia_vsock::AcceptorMarker>,
local_port: u32,
) -> Result<(), Error> {
let acceptor = acceptor.into_proxy().map_err(|x| Error::ClientCommunication(x.into()))?;
let stream = self.listen_port(local_port)?;
fasync::Task::spawn(
self.clone()
.run_connection_listener(stream, acceptor)
.unwrap_or_else(|err| tracing::warn!("Error {} running connection listener", err)),
)
.detach();
Ok(())
}
async fn handle_request(&self, request: ConnectorRequest) -> Result<(), Error> {
match request {
ConnectorRequest::Connect { remote_cid, remote_port, con, responder } => {
send_result(self.make_connection(remote_cid, remote_port, con).await, |r, v| {
responder.send(r, v.unwrap_or(0))
})
}
ConnectorRequest::Listen { local_port, acceptor, responder } => {
send_result(self.start_listener(acceptor, local_port), |r, _| responder.send(r))
}
}
}
pub async fn run_client_connection(
self,
request: ConnectorRequestStream,
) -> Result<(), anyhow::Error> {
let self_ref = &self;
let fut = request
.map_err(|err| Error::ClientCommunication(err.into()))
.try_for_each_concurrent(4, |request| {
self_ref
.handle_request(request)
.or_else(|e| future::ready(if e.is_comm_failure() { Err(e) } else { Ok(()) }))
})
.err_into();
fut.await
}
fn alloc_ephemeral_port(self) -> Option<AllocatedPort> {
let p = self.lock().used_ports.allocate();
p.map(|p| AllocatedPort { port: p, service: self })
}
fn listen_port(&self, port: u32) -> Result<ListenStream, Error> {
if port::is_ephemeral(port) {
tracing::info!("Rejecting request to listen on ephemeral port {}", port);
return Err(Error::ConnectionRefused);
}
match self.lock().listens.entry(port) {
std::collections::hash_map::Entry::Vacant(entry) => {
let (sender, receiver) = mpsc::unbounded();
let listen =
ListenStream { local_port: port, service: self.clone(), stream: receiver };
entry.insert(sender);
Ok(listen)
}
_ => {
tracing::info!("Attempt to listen on already bound port {}", port);
Err(Error::AlreadyBound)
}
}
}
fn register_event(&self, event: Event) -> Result<OneshotEvent, Error> {
match self.lock().events.entry(event) {
std::collections::hash_map::Entry::Vacant(entry) => {
let (sender, receiver) = oneshot::channel();
let event = OneshotEvent {
event: Some(entry.key().clone()),
service: self.clone(),
oneshot: receiver,
};
entry.insert(sender);
Ok(event)
}
_ => Err(Error::AlreadyBound),
}
}
fn send_request(
&self,
addr: &addr::Vsock,
data: zx::Socket,
) -> Result<impl Future<Output = Result<(OneshotEvent, OneshotEvent), Error>>, Error> {
let shutdown_callback =
self.register_event(Event { action: EventType::Shutdown, addr: addr.clone() })?;
let response_callback =
self.register_event(Event { action: EventType::Response, addr: addr.clone() })?;
let send_request_fut = self.lock().device.send_request(&addr, data);
Ok(async move {
map_driver_result(send_request_fut.await)?;
Ok((shutdown_callback, response_callback))
})
}
fn send_response(
&self,
addr: &addr::Vsock,
data: zx::Socket,
) -> Result<impl Future<Output = Result<OneshotEvent, Error>>, Error> {
let shutdown_callback =
self.register_event(Event { action: EventType::Shutdown, addr: addr.clone() })?;
let send_request_fut = self.lock().device.send_response(&addr.clone(), data);
Ok(async move {
map_driver_result(send_request_fut.await)?;
Ok(shutdown_callback)
})
}
fn send_vmo(
&self,
addr: &addr::Vsock,
vmo: zx::Vmo,
off: u64,
len: u64,
) -> Result<impl Future<Output = Result<OneshotEvent, Error>>, Error> {
let vmo_callback =
self.register_event(Event { action: EventType::VmoComplete, addr: addr.clone() })?;
let send_request_fut = self.lock().device.send_vmo(&addr, vmo, off, len);
Ok(async move {
map_driver_result(send_request_fut.await)?;
Ok(vmo_callback)
})
}
async fn run_connection<ShutdownFut>(
self,
addr: addr::Vsock,
shutdown_event: ShutdownFut,
mut requests: ConnectionRequestStream,
_port: Option<AllocatedPort>,
) -> Result<(), Error>
where
ShutdownFut:
Future<Output = Result<(), futures::channel::oneshot::Canceled>> + std::marker::Unpin,
{
async fn wait_vmo_complete<ShutdownFut>(
mut shutdown_event: &mut futures::future::Fuse<ShutdownFut>,
cb: OneshotEvent,
) -> Result<zx::Status, Result<(), Error>>
where
ShutdownFut: Future<Output = Result<(), futures::channel::oneshot::Canceled>>
+ std::marker::Unpin,
{
select! {
shutdown_event = shutdown_event => Err(shutdown_event.map_err(|e| e.into())),
cb = cb.fuse() => match cb {
Ok(_) => Ok(zx::Status::OK),
Err(_) => Ok(Error::ConnectionReset.into_status()),
},
}
}
let mut shutdown_event = shutdown_event.fuse();
loop {
select! {
shutdown_event = shutdown_event => {
let fut = future::ready(shutdown_event)
.err_into()
.and_then(|()| self.lock().send_rst(&addr));
return fut.await;
},
request = requests.next() => {
match request {
Some(Ok(ConnectionRequest::Shutdown{control_handle: _control_handle})) => {
let fut =
self.lock().send_shutdown(&addr)
.and_then(|()| shutdown_event.err_into());
return fut.await;
},
Some(Ok(ConnectionRequest::SendVmo{vmo, off, len, responder})) => {
let result = self.send_vmo(&addr, vmo, off, len);
let result = match result {
Ok(fut) => fut.await,
Err(e) => Err(e),
};
let status = match result {
Ok(cb) => {
match wait_vmo_complete(&mut shutdown_event, cb).await {
Err(e) => return e,
Ok(o) => o,
}
},
Err(e) => e.into_status(),
};
let _ = responder.send(status.into_raw());
},
Some(Err(e)) => {
let fut = self.lock().send_rst(&addr);
fut.await?;
return Err(Error::ClientCommunication(e.into()));
},
None => {
let fut = self.lock().send_rst(&addr);
return fut.await;
},
}
},
}
}
}
async fn run_connection_listener(
self,
incoming: ListenStream,
acceptor: AcceptorProxy,
) -> Result<(), Error> {
incoming
.then(|addr| acceptor.accept(&*addr.clone()).map_ok(|maybe_con| (maybe_con, addr)))
.map_err(|e| Error::ClientCommunication(e.into()))
.try_for_each(|(maybe_con, addr)| async {
match maybe_con {
Some(con) => {
let data = con.data;
let con = con
.con
.into_stream()
.map_err(|x| Error::ClientCommunication(x.into()))?;
let shutdown_event = self.send_response(&addr, data)?.await?;
fasync::Task::spawn(
self.clone()
.run_connection(addr, shutdown_event, con, None)
.map_err(|err| {
tracing::warn!("Error {} whilst running connection", err)
})
.map(|_| ()),
)
.detach();
Ok(())
}
None => {
let fut = self.lock().send_rst(&addr);
fut.await
}
}
})
.await
}
async fn make_connection(
&self,
remote_cid: u32,
remote_port: u32,
con: ConnectionTransport,
) -> Result<u32, Error> {
let data = con.data;
let con = con.con.into_stream().map_err(|x| Error::ClientCommunication(x.into()))?;
let port = self.clone().alloc_ephemeral_port().ok_or(Error::OutOfPorts)?;
let port_value = *port;
let addr = addr::Vsock::new(port_value, remote_port, remote_cid);
let (shutdown_event, response_event) = self.send_request(&addr, data)?.await?;
let mut shutdown_event = shutdown_event.fuse();
select! {
_shutdown_event = shutdown_event => {
return Err(Error::ConnectionRefused);
},
response_event = response_event.fuse() => response_event?,
}
fasync::Task::spawn(
self.clone()
.run_connection(addr, shutdown_event, con, Some(port))
.unwrap_or_else(|err| tracing::warn!("Error {} whilst running connection", err)),
)
.detach();
Ok(port_value)
}
}
/// Lets a `Vsock` handle be used directly as the shared `LockedState` (e.g. `self.lock()`).
impl Deref for Vsock {
    type Target = LockedState;

    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}
impl LockedState {
    /// Applies every deregistration queued while the state lock was contended.
    fn drain_pending(&self, state: &mut State) {
        for pending in self.deregister_rx.try_iter() {
            state.deregister(pending);
        }
    }

    /// Blocks until the state lock is held, applies queued deregistrations, and returns
    /// the guard.
    fn lock(&self) -> fuchsia_sync::MutexGuard<'_, State> {
        let mut guard = self.inner.lock();
        self.drain_pending(&mut guard);
        guard
    }

    /// Like `lock`, but returns `None` instead of waiting when the lock is already held.
    fn try_lock(&self) -> Option<fuchsia_sync::MutexGuard<'_, State>> {
        let mut guard = self.inner.try_lock()?;
        self.drain_pending(&mut guard);
        Some(guard)
    }

    /// Applies `event` now if the lock is free, otherwise queues it for the next locker.
    ///
    /// Because it never blocks on the lock, this is safe to call from `Drop` impls even
    /// while the state lock is held.
    fn deregister(&self, event: Deregister) {
        match self.try_lock() {
            Some(mut state) => state.deregister(event),
            None => {
                let _ = self.deregister_tx.try_send(event);
            }
        }
    }
}
impl State {
fn deregister(&mut self, event: Deregister) {
match event {
Deregister::Event(e) => {
self.events.remove(&e);
}
Deregister::Listen(p) => {
self.listens.remove(&p);
}
Deregister::Port(p) => {
self.used_ports.free(p);
}
}
}
fn send_rst(&mut self, addr: &addr::Vsock) -> impl Future<Output = Result<(), Error>> {
self.device.send_rst(&addr.clone()).map(|x| map_driver_result(x))
}
fn send_shutdown(&mut self, addr: &addr::Vsock) -> impl Future<Output = Result<(), Error>> {
self.device.send_shutdown(&addr).map(|x| map_driver_result(x))
}
fn do_callback(&mut self, callback: CallbacksRequest) {
match callback {
CallbacksRequest::Response { addr, control_handle: _control_handle } => {
self.events
.remove(&Event { action: EventType::Response, addr: addr::Vsock::from(addr) })
.map(|channel| channel.send(()));
}
CallbacksRequest::Rst { addr, control_handle: _control_handle } => {
self.events
.remove(&Event { action: EventType::Shutdown, addr: addr::Vsock::from(addr) });
}
CallbacksRequest::SendVmoComplete { addr, control_handle: _control_handle } => {
self.events
.remove(&Event {
action: EventType::VmoComplete,
addr: addr::Vsock::from(addr),
})
.map(|channel| channel.send(()));
}
CallbacksRequest::Request { addr, control_handle: _control_handle } => {
let addr = addr::Vsock::from(addr);
match self.listens.get(&addr.local_port) {
Some(sender) => {
let _ = sender.unbounded_send(addr.clone());
}
None => {
tracing::warn!("Request on port {} with no listener", addr.local_port);
fasync::Task::spawn(self.send_rst(&addr).map(|_| ())).detach();
}
}
}
CallbacksRequest::Shutdown { addr, control_handle: _control_handle } => {
self.events
.remove(&Event { action: EventType::Shutdown, addr: addr::Vsock::from(addr) })
.map(|channel| channel.send(()));
}
CallbacksRequest::TransportReset { new_cid: _new_cid, responder } => {
self.events.clear();
let _ = responder.send();
}
}
}
}
/// An ephemeral port reserved from `State::used_ports`.
///
/// Dereferences to the port number and returns the port to the tracker when dropped.
struct AllocatedPort {
    service: Vsock,
    port: u32,
}

impl Deref for AllocatedPort {
    type Target = u32;

    fn deref(&self) -> &Self::Target {
        let AllocatedPort { port, .. } = self;
        port
    }
}

impl Drop for AllocatedPort {
    fn drop(&mut self) {
        let released = self.port;
        self.service.deregister(Deregister::Port(released));
    }
}
/// Receiver side of a one-shot waiter registered in `State::events`.
///
/// If it is dropped before resolving, its entry is removed from `State::events`, either
/// immediately or via the deregistration queue.
struct OneshotEvent {
    event: Option<Event>,
    service: Vsock,
    oneshot: oneshot::Receiver<()>,
}

impl Drop for OneshotEvent {
    fn drop(&mut self) {
        if let Some(event) = self.event.take() {
            self.service.deregister(Deregister::Event(event));
        }
    }
}

impl Future for OneshotEvent {
    type Output = <oneshot::Receiver<()> as Future>::Output;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let outcome = self.oneshot.poll_unpin(cx);
        if outcome.is_ready() {
            // The sender was fired or dropped, so there is nothing left to deregister when
            // this waiter is dropped.
            self.event = None;
        }
        outcome
    }
}
/// Stream of incoming connection-request addresses for one listening port.
///
/// Dropping it removes the port from `State::listens`.
struct ListenStream {
    local_port: u32,
    service: Vsock,
    stream: mpsc::UnboundedReceiver<addr::Vsock>,
}

impl Drop for ListenStream {
    fn drop(&mut self) {
        let port = self.local_port;
        self.service.deregister(Deregister::Listen(port));
    }
}

impl Stream for ListenStream {
    type Item = addr::Vsock;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.stream.poll_next_unpin(cx)
    }
}