use core::future::Future;
use zx::Status;
use crate::{Arena, ArenaBox, DispatcherRef, DriverHandle, Message, MixedHandle};
use fdf_sys::*;
use core::marker::PhantomData;
use core::mem::{size_of_val, MaybeUninit};
use core::num::NonZero;
use core::pin::Pin;
use core::ptr::{null_mut, NonNull};
use core::task::{Context, Poll, Waker};
use std::sync::{Arc, Mutex};
pub use fdf_sys::fdf_handle_t;
/// An endpoint of a driver runtime channel whose messages are expected to carry data of type `T`.
///
/// The endpoint holds its [`DriverHandle`] by value. `T` is only a compile-time marker (the
/// `PhantomData<Message<T>>` field); nothing at runtime records or checks it.
#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct Channel<T: ?Sized + 'static>(pub(crate) DriverHandle, PhantomData<Message<T>>);
impl<T: ?Sized + 'static> Channel<T> {
pub fn create() -> Result<(Self, Self), Status> {
let mut channel1 = 0;
let mut channel2 = 0;
Status::ok(unsafe { fdf_channel_create(0, &mut channel1, &mut channel2) })?;
unsafe {
Ok((
Self::from_handle_unchecked(NonZero::new_unchecked(channel1)),
Self::from_handle_unchecked(NonZero::new_unchecked(channel2)),
))
}
}
pub fn into_driver_handle(self) -> DriverHandle {
self.0
}
unsafe fn from_handle_unchecked(handle: NonZero<fdf_handle_t>) -> Self {
Self(unsafe { DriverHandle::new_unchecked(handle) }, PhantomData)
}
pub unsafe fn from_driver_handle(handle: DriverHandle) -> Self {
Self(handle, PhantomData)
}
pub fn write(&self, message: Message<T>) -> Result<(), Status> {
let data_len = message.data().map_or(0, |data| size_of_val(&*data) as u32);
let handles_count = message.handles().map_or(0, |handles| handles.len() as u32);
let (arena, data, handles) = message.into_raw();
let data_ptr = data.map_or(null_mut(), |data| data.cast().as_ptr());
let handles_ptr = handles.map_or(null_mut(), |handles| handles.cast().as_ptr());
Status::ok(unsafe {
fdf_channel_write(
self.0.get_raw().get(),
0,
arena.as_ptr(),
data_ptr,
data_len,
handles_ptr,
handles_count,
)
})
}
pub fn write_with<F>(&self, arena: Arena, f: F) -> Result<(), Status>
where
F: for<'a> FnOnce(
&'a Arena,
)
-> (Option<ArenaBox<'a, T>>, Option<ArenaBox<'a, [Option<MixedHandle>]>>),
{
self.write(Message::new_with(arena, f))
}
pub fn write_with_data<F>(&self, arena: Arena, f: F) -> Result<(), Status>
where
F: for<'a> FnOnce(&'a Arena) -> ArenaBox<'a, T>,
{
self.write(Message::new_with_data(arena, f))
}
fn try_read_raw<'a>(&self) -> Result<Option<Message<[MaybeUninit<u8>]>>, Status> {
let mut out_arena = null_mut();
let mut out_data = null_mut();
let mut out_num_bytes = 0;
let mut out_handles = null_mut();
let mut out_num_handles = 0;
Status::ok(unsafe {
fdf_channel_read(
self.0.get_raw().get(),
0,
&mut out_arena,
&mut out_data,
&mut out_num_bytes,
&mut out_handles,
&mut out_num_handles,
)
})?;
if out_arena == null_mut() {
return Ok(None);
}
let arena = Arena(unsafe { NonNull::new_unchecked(out_arena) });
let data_ptr = if !out_data.is_null() {
let ptr = core::ptr::slice_from_raw_parts_mut(out_data.cast(), out_num_bytes as usize);
Some(unsafe { ArenaBox::new(NonNull::new_unchecked(ptr)) })
} else {
None
};
let handles_ptr = if !out_handles.is_null() {
let ptr =
core::ptr::slice_from_raw_parts_mut(out_handles.cast(), out_num_handles as usize);
Some(unsafe { ArenaBox::new(NonNull::new_unchecked(ptr)) })
} else {
None
};
Ok(Some(unsafe { Message::new_unchecked(arena, data_ptr, handles_ptr) }))
}
fn read_raw<'a>(&'a self, dispatcher: DispatcherRef<'a>) -> ReadMessageRawFut<'a, T> {
ReadMessageRawFut {
op: Arc::new(ReadMessageRawOp {
read_op: fdf_channel_read {
channel: unsafe { self.0.get_raw() }.get(),
handler: Some(ReadMessageRawOp::handler),
..Default::default()
},
waker: Mutex::new(None),
}),
channel: self,
dispatcher,
}
}
}
impl<T> Channel<T> {
    /// Reads the next queued message, if any, without waiting, and returns it typed as `T`.
    ///
    /// Returns `Ok(None)` if the runtime reports a read that carries no arena.
    ///
    /// # Errors
    ///
    /// Propagates the read status, e.g. `ZX_ERR_SHOULD_WAIT` when no message is queued.
    pub fn try_read<'a>(&self) -> Result<Option<Message<T>>, Status> {
        let raw = self.try_read_raw()?;
        // SAFETY: a `Channel<T>` is only expected to carry messages holding a `T` (see the safety
        // contract of `Channel::from_driver_handle`).
        Ok(raw.map(|message| unsafe { message.cast_unchecked() }))
    }

    /// Reads the next message typed as `T`, waiting on `dispatcher` until one arrives.
    ///
    /// Returns `Ok(None)` if the runtime reports a read that carries no arena.
    ///
    /// # Errors
    ///
    /// Propagates any read or wait-registration error.
    pub async fn read(&self, dispatcher: DispatcherRef<'_>) -> Result<Option<Message<T>>, Status> {
        let raw = self.read_raw(dispatcher).await?;
        // SAFETY: same reasoning as in `try_read`.
        Ok(raw.map(|message| unsafe { message.cast_unchecked() }))
    }
}
impl Channel<[u8]> {
    /// Reads the next queued byte message, if any, without waiting.
    ///
    /// Returns `Ok(None)` if the runtime reports a read that carries no arena.
    ///
    /// # Errors
    ///
    /// Propagates the read status, e.g. `ZX_ERR_SHOULD_WAIT` when no message is queued.
    pub fn try_read_bytes<'a>(&self) -> Result<Option<Message<[u8]>>, Status> {
        let raw = self.try_read_raw()?;
        // SAFETY: the bytes handed back by the runtime are the ones the peer wrote, so they are
        // treated as initialized.
        Ok(raw.map(|message| unsafe { message.assume_init() }))
    }

    /// Reads the next byte message, waiting on `dispatcher` until one arrives.
    ///
    /// Returns `Ok(None)` if the runtime reports a read that carries no arena.
    ///
    /// # Errors
    ///
    /// Propagates any read or wait-registration error.
    pub async fn read_bytes(
        &self,
        dispatcher: DispatcherRef<'_>,
    ) -> Result<Option<Message<[u8]>>, Status> {
        let raw = self.read_raw(dispatcher).await?;
        // SAFETY: same reasoning as in `try_read_bytes`.
        Ok(raw.map(|message| unsafe { message.assume_init() }))
    }
}
/// Converts a channel endpoint into a [`MixedHandle`] that wraps its driver handle, e.g. so the
/// endpoint can be sent inside another channel's message.
///
/// Works for unsized payload types too, so byte channels (`Channel<[u8]>`) can be transferred.
impl<T: ?Sized + 'static> From<Channel<T>> for MixedHandle {
    fn from(value: Channel<T>) -> Self {
        MixedHandle::from(value.0)
    }
}
/// State shared between a pending [`ReadMessageRawFut`] and the driver runtime's wait callback.
///
/// `#[repr(C)]` with `read_op` as the first field means the `*mut fdf_channel_read` handed to
/// [`ReadMessageRawOp::handler`] has the same address as the enclosing `ReadMessageRawOp`. That is
/// what allows the callback to cast it straight back and rebuild the `Arc`.
#[repr(C)]
struct ReadMessageRawOp {
    /// Wait descriptor passed to `fdf_channel_wait_async`; must remain the first field.
    read_op: fdf_channel_read,
    /// Waker of the task awaiting the read. `poll` sets it before registering a wait, and either
    /// the wait callback or the future's `Drop` takes it.
    waker: Mutex<Option<Waker>>,
}
impl ReadMessageRawOp {
    /// C callback registered with `fdf_channel_wait_async` by `ReadMessageRawFut::poll`.
    ///
    /// Reclaims the strong reference that `poll` leaked with `Arc::into_raw`. It then wakes the
    /// task that registered the wait, if that task's waker is still registered.
    ///
    /// # Safety
    ///
    /// `read_op` must be the pointer given to `fdf_channel_wait_async`: obtained from
    /// `Arc::into_raw` on an `Arc<ReadMessageRawOp>` whose leaked reference has not been reclaimed.
    unsafe extern "C" fn handler(
        _dispatcher: *mut fdf_dispatcher,
        read_op: *mut fdf_channel_read,
        _status: i32,
    ) {
        // SAFETY: `read_op` is the first field of a `#[repr(C)]` `ReadMessageRawOp`, so it shares
        // the struct's address, and `poll` leaked exactly one strong reference for us to take back.
        let op: Arc<Self> = unsafe { Arc::from_raw(read_op.cast()) };
        // Take the waker in its own statement so the mutex guard is released before waking.
        let waker = op.waker.lock().unwrap().take();
        // The waker can legitimately be absent. When the future is dropped on an unsynchronized
        // dispatcher, `Drop` cancels the wait and clears the waker, but it leaves this callback to
        // run and release the reference reclaimed above. Panicking there (as the old `expect`
        // did) would unwind across the FFI boundary; there is simply nothing to wake.
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
/// Future returned by `Channel::read_raw`. It resolves once a message (or an error) can be read
/// from `channel`, and registers a wait on `dispatcher` while the channel is empty.
struct ReadMessageRawFut<'a, T: ?Sized + 'static> {
    /// Shared with the wait callback; `poll` leaks one extra strong reference while a wait is
    /// registered.
    op: Arc<ReadMessageRawOp>,
    /// Channel being read.
    channel: &'a Channel<T>,
    /// Dispatcher on which the async wait is registered.
    dispatcher: DispatcherRef<'a>,
}
impl<'a, T: ?Sized + 'static> Future for ReadMessageRawFut<'a, T> {
    type Output = Result<Option<Message<[MaybeUninit<u8>]>>, Status>;

    /// Tries to read a message. If the channel is empty, this registers one async wait on the
    /// dispatcher and stores the current waker so the wait callback can wake this task.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Hold the waker lock across the read attempt and the wait registration, so the wait
        // callback cannot observe or take the waker while a registration is in progress.
        let mut waker_lock = self.op.waker.lock().unwrap();
        match self.channel.try_read_raw() {
            Err(Status::SHOULD_WAIT) => {}
            // A message, or a definitive error, is available right now.
            other => return Poll::Ready(other),
        }
        // A waker that is already present means an earlier poll's wait is still registered; only
        // refresh the waker so the latest task context is the one woken.
        if waker_lock.replace(cx.waker().clone()).is_some() {
            return Poll::Pending;
        }
        // Leak one strong reference for the runtime. `ReadMessageRawOp::handler` reclaims it; so
        // does `Drop` when a cancellation on a synchronized dispatcher succeeds.
        let op = Arc::into_raw(Arc::clone(&self.op));
        // SAFETY: `op` points to a live `#[repr(C)]` `ReadMessageRawOp` whose first field is the
        // `fdf_channel_read` descriptor, and the leaked reference keeps it alive until reclaimed.
        let res = Status::ok(unsafe {
            fdf_channel_wait_async(self.dispatcher.0.as_ptr(), op.cast_mut().cast(), 0)
        });
        if let Err(e) = res {
            // The wait was never registered, so no callback will reclaim the leaked reference or
            // take the waker. Undo both, so the op is not leaked and `Drop` does not try to cancel
            // a wait that does not exist.
            // SAFETY: `op` came from `Arc::into_raw` just above and has not been reclaimed.
            drop(unsafe { Arc::from_raw(op) });
            *waker_lock = None;
            return Poll::Ready(Err(e));
        }
        Poll::Pending
    }
}
impl<'a, T: ?Sized + 'static> Drop for ReadMessageRawFut<'a, T> {
    /// Cancels any wait this future still has registered with the dispatcher.
    ///
    /// # Panics
    ///
    /// Panics if `fdf_channel_cancel_wait` reports an error other than `NOT_FOUND`.
    fn drop(&mut self) {
        // The lock stays held for the whole function, across the cancellation.
        let mut registered_waker = self.op.waker.lock().unwrap();
        // Only a present waker means a wait may still be outstanding.
        if registered_waker.is_some() {
            // SAFETY: the channel is still borrowed for `'a`, so its handle is alive.
            let cancel_status =
                Status::ok(unsafe { fdf_channel_cancel_wait(self.channel.0.get_raw().get()) });
            match cancel_status {
                // Nothing to cancel: the wait already completed, and its callback owns the leaked
                // reference.
                Err(Status::NOT_FOUND) => {}
                Err(e) => panic!("Unexpected error {e:?} cancelling driver channel read wait"),
                Ok(()) => {
                    registered_waker.take();
                    // On a synchronized dispatcher, the callback will not run after a successful
                    // cancel, so give back the reference `poll` leaked. On an unsynchronized one,
                    // the callback still runs and reclaims it itself.
                    if !self.dispatcher.is_unsynchronized() {
                        // SAFETY: `poll` leaked exactly one reference via `Arc::into_raw`, whose
                        // pointer equals `Arc::as_ptr(&self.op)`, and no callback will reclaim it.
                        unsafe { Arc::decrement_strong_count(Arc::as_ptr(&self.op)) };
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // Tests for `Channel`: synchronous and dispatcher-driven reads and writes, transferring a
    // channel handle through another channel, and dropping a pending read future.
    use std::sync::mpsc;
    use crate::test::with_raw_dispatcher;
    use crate::tests::DropSender;
    use crate::{Dispatcher, MixedHandleType};
    use super::*;

    // Round-trips byte messages in both directions with `try_read_bytes`. Also checks that an
    // empty channel reports `ZX_ERR_SHOULD_WAIT`, and that writing after the peer is dropped
    // reports `ZX_ERR_PEER_CLOSED`.
    #[test]
    fn send_and_receive_bytes_synchronously() {
        let (first, second) = Channel::create().unwrap();
        let arena = Arena::new().unwrap();
        assert_eq!(first.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        first.write_with_data(arena.clone(), |arena| arena.insert_slice(&[1, 2, 3, 4])).unwrap();
        assert_eq!(&*second.try_read_bytes().unwrap().unwrap().data().unwrap(), &[1, 2, 3, 4]);
        assert_eq!(second.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        second.write_with_data(arena.clone(), |arena| arena.insert_slice(&[5, 6, 7, 8])).unwrap();
        assert_eq!(&*first.try_read_bytes().unwrap().unwrap().data().unwrap(), &[5, 6, 7, 8]);
        assert_eq!(first.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        assert_eq!(second.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        drop(second);
        assert_eq!(
            first.write_with_data(arena.clone(), |arena| arena.insert_slice(&[9, 10, 11, 12])),
            Err(Status::from_raw(ZX_ERR_PEER_CLOSED))
        );
    }

    // A task spawned on the dispatcher awaits `read_bytes`; the test thread then writes, and the
    // received message is forwarded back through an mpsc channel.
    #[test]
    fn send_and_receive_bytes_asynchronously() {
        with_raw_dispatcher("channel async", |dispatcher| {
            let arena = Arena::new().unwrap();
            let (fin_tx, fin_rx) = mpsc::channel();
            let (first, second) = Channel::create().unwrap();
            dispatcher
                .spawn_task(async move {
                    fin_tx.send(first.read_bytes(dispatcher.as_ref()).await.unwrap()).unwrap();
                })
                .unwrap();
            second.write_with_data(arena, |arena| arena.insert_slice(&[1, 2, 3, 4])).unwrap();
            assert_eq!(fin_rx.recv().unwrap().unwrap().data().unwrap(), &[1, 2, 3, 4]);
        });
    }

    // Sends a `DropSender` object. It must not be dropped by sending or receiving, only when the
    // received message itself is dropped.
    #[test]
    fn send_and_receive_objects_synchronously() {
        let arena = Arena::new().unwrap();
        let (first, second) = Channel::create().unwrap();
        let (tx, rx) = mpsc::channel();
        first
            .write_with_data(arena.clone(), |arena| arena.insert(DropSender::new(1, tx.clone())))
            .unwrap();
        rx.try_recv().expect_err("should not drop the object when sent");
        let message = second.try_read().unwrap().unwrap();
        assert_eq!(message.data().unwrap().0, 1);
        rx.try_recv().expect_err("should not drop the object when received");
        drop(message);
        rx.try_recv().expect("dropped when received");
    }

    // Transfers one end of a `Channel<String>` pair through a `Channel<()>`. It then rebuilds the
    // received driver handle into a channel and checks that it still talks to its original peer.
    #[test]
    fn send_and_receive_handles_synchronously() {
        println!("Create channels and write one end of one of the channel pairs to the other");
        let (first, second) = Channel::<()>::create().unwrap();
        let (inner_first, inner_second) = Channel::<String>::create().unwrap();
        let message = Message::new_with(Arena::new().unwrap(), |arena| {
            (None, Some(arena.insert_boxed_slice(Box::new([Some(inner_first.into())]))))
        });
        first.write(message).unwrap();
        println!("Receive the channel back on the other end of the first channel pair.");
        let mut arena = None;
        let message =
            second.try_read().unwrap().expect("Expected a message with contents to be received");
        let (_, received_handles) = message.into_arena_boxes(&mut arena);
        let mut first_handle_received =
            ArenaBox::take_boxed_slice(received_handles.expect("expected handles in the message"));
        let first_handle_received = first_handle_received
            .first_mut()
            .expect("expected one handle in the handle set")
            .take()
            .expect("expected the first handle to be non-null");
        let first_handle_received = first_handle_received.resolve();
        let MixedHandleType::Driver(driver_handle) = first_handle_received else {
            panic!("Got a non-driver handle when we sent a driver handle");
        };
        // The handle came from a `Channel<String>`, so reading `String` messages from it is sound.
        let inner_first_received = unsafe { Channel::from_driver_handle(driver_handle) };
        println!("Send and receive a string across the now-transmitted channel pair.");
        inner_first_received
            .write_with_data(Arena::new().unwrap(), |arena| arena.insert("boom".to_string()))
            .unwrap();
        assert_eq!(inner_second.try_read().unwrap().unwrap().data().unwrap(), &"boom".to_string());
    }

    // Sends 0, then answers every received value with value + 1. The loop ends on the first read
    // that is not `Ok(Some(_))`, presumably an error once `pong` returns and drops its end.
    async fn ping(dispatcher: &Dispatcher, chan: Channel<u8>) {
        println!("starting ping!");
        chan.write_with_data(Arena::new().unwrap(), |arena| arena.insert(0)).unwrap();
        while let Ok(Some(msg)) = chan.read(dispatcher.as_ref()).await {
            let next = *msg.data().unwrap();
            println!("ping! {next}");
            // Reuse the received message's arena for the reply.
            chan.write_with_data(msg.take_arena(), |arena| arena.insert(next + 1)).unwrap();
        }
    }

    // Answers every received value with value + 1 until it receives a value greater than 10, then
    // signals completion on `fin_tx`.
    async fn pong(dispatcher: &Dispatcher, fin_tx: std::sync::mpsc::Sender<()>, chan: Channel<u8>) {
        println!("starting pong!");
        while let Some(msg) = chan.read(dispatcher.as_ref()).await.unwrap() {
            let next = *msg.data().unwrap();
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            chan.write_with_data(msg.take_arena(), |arena| arena.insert(next + 1)).unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    // Runs both `ping` and `pong` as tasks on the same raw dispatcher.
    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_chan, pong_chan) = Channel::create().unwrap();
            dispatcher.spawn_task(ping(&dispatcher, ping_chan)).unwrap();
            dispatcher.spawn_task(pong(&dispatcher, fin_tx, pong_chan)).unwrap();
            fin_rx.recv().expect("to receive final value");
        });
    }

    // Runs `pong` on the raw dispatcher and `ping` inside a `fuchsia_async::LocalExecutor`. The
    // executor runs in a task posted to a second dispatcher built with `allow_thread_blocking`.
    #[test]
    fn async_ping_pong_on_fuchsia_async() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_chan, pong_chan) = Channel::create().unwrap();
            dispatcher
                .post_task_sync(|_status| {
                    let rust_async_dispatcher = crate::DispatcherBuilder::new()
                        .name("fuchsia-async")
                        .allow_thread_blocking()
                        .create()
                        .expect("failure creating blocking dispatcher for rust async");
                    dispatcher.spawn_task(pong(&dispatcher, fin_tx, pong_chan)).unwrap();
                    rust_async_dispatcher
                        .post_task_sync(|_| {
                            let mut executor = fuchsia_async::LocalExecutor::new();
                            executor.run_singlethreaded(ping(&dispatcher, ping_chan));
                        })
                        .unwrap();
                })
                .unwrap();
            fin_rx.recv().expect("to receive final value");
        });
    }

    // Polls a read once so it registers a wait, then drops it, which exercises the `Drop`
    // cancellation path. A fresh read must still receive the next message.
    #[test]
    fn early_cancel_future() {
        with_raw_dispatcher("early cancellation", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (a, b) = Channel::create().unwrap();
            dispatcher
                .spawn_task(async move {
                    let fut = a.read(dispatcher.as_ref());
                    futures::pin_mut!(fut);
                    let Poll::Pending = futures::poll!(fut.as_mut()) else {
                        panic!("expected pending state after polling channel read once");
                    };
                    drop(fut);
                    b.write_with_data(Arena::new().unwrap(), |arena| arena.insert(1)).unwrap();
                    assert_eq!(
                        a.read(dispatcher.as_ref()).await.unwrap().unwrap().data(),
                        Some(&1)
                    );
                    fin_tx.send(()).unwrap();
                })
                .unwrap();
            fin_rx.recv().unwrap();
        })
    }
}