1use core::future::Future;
8use std::mem::ManuallyDrop;
9use zx::Status;
10
11use crate::{Arena, ArenaBox, DispatcherRef, DriverHandle, Message, MixedHandle};
12use fdf_sys::*;
13
14use core::marker::PhantomData;
15use core::mem::{size_of_val, MaybeUninit};
16use core::num::NonZero;
17use core::pin::Pin;
18use core::ptr::{null_mut, NonNull};
19use core::task::{Context, Poll, Waker};
20use std::sync::{Arc, Mutex};
21
22pub use fdf_sys::fdf_handle_t;
23
/// A typed handle to one end of a driver runtime channel.
///
/// `T` is the message payload type carried on this channel. It exists only at the type level
/// (via `PhantomData`). Incoming messages are reinterpreted as `T` without any runtime check
/// (see `try_read` / `read`).
#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct Channel<T: ?Sized + 'static>(pub(crate) DriverHandle, PhantomData<Message<T>>);
27
28impl<T: ?Sized + 'static> Channel<T> {
29 pub fn create() -> (Self, Self) {
32 let mut channel1 = 0;
33 let mut channel2 = 0;
34 Status::ok(unsafe { fdf_channel_create(0, &mut channel1, &mut channel2) })
37 .expect("failed to create channel pair");
38 unsafe {
41 (
42 Self::from_handle_unchecked(NonZero::new_unchecked(channel1)),
43 Self::from_handle_unchecked(NonZero::new_unchecked(channel2)),
44 )
45 }
46 }
47
48 pub fn into_driver_handle(self) -> DriverHandle {
51 self.0
52 }
53
54 unsafe fn from_handle_unchecked(handle: NonZero<fdf_handle_t>) -> Self {
61 Self(unsafe { DriverHandle::new_unchecked(handle) }, PhantomData)
63 }
64
65 pub unsafe fn from_driver_handle(handle: DriverHandle) -> Self {
73 Self(handle, PhantomData)
74 }
75
76 pub fn write(&self, message: Message<T>) -> Result<(), Status> {
81 let data_len = message.data().map_or(0, |data| size_of_val(&*data) as u32);
83 let handles_count = message.handles().map_or(0, |handles| handles.len() as u32);
84
85 let (arena, data, handles) = message.into_raw();
86
87 let data_ptr = data.map_or(null_mut(), |data| data.cast().as_ptr());
89 let handles_ptr = handles.map_or(null_mut(), |handles| handles.cast().as_ptr());
90
91 Status::ok(unsafe {
99 fdf_channel_write(
100 self.0.get_raw().get(),
101 0,
102 arena.as_ptr(),
103 data_ptr,
104 data_len,
105 handles_ptr,
106 handles_count,
107 )
108 })?;
109
110 unsafe { fdf_arena_drop_ref(arena.as_ptr()) };
114 Ok(())
115 }
116
117 pub fn write_with<F>(&self, arena: Arena, f: F) -> Result<(), Status>
119 where
120 F: for<'a> FnOnce(
121 &'a Arena,
122 )
123 -> (Option<ArenaBox<'a, T>>, Option<ArenaBox<'a, [Option<MixedHandle>]>>),
124 {
125 self.write(Message::new_with(arena, f))
126 }
127
128 pub fn write_with_data<F>(&self, arena: Arena, f: F) -> Result<(), Status>
130 where
131 F: for<'a> FnOnce(&'a Arena) -> ArenaBox<'a, T>,
132 {
133 self.write(Message::new_with_data(arena, f))
134 }
135}
136
137fn try_read_raw(channel: &DriverHandle) -> Result<Option<Message<[MaybeUninit<u8>]>>, Status> {
141 let mut out_arena = null_mut();
142 let mut out_data = null_mut();
143 let mut out_num_bytes = 0;
144 let mut out_handles = null_mut();
145 let mut out_num_handles = 0;
146 Status::ok(unsafe {
147 fdf_channel_read(
148 channel.get_raw().get(),
149 0,
150 &mut out_arena,
151 &mut out_data,
152 &mut out_num_bytes,
153 &mut out_handles,
154 &mut out_num_handles,
155 )
156 })?;
157 if out_arena == null_mut() {
159 return Ok(None);
160 }
161 let arena = Arena(unsafe { NonNull::new_unchecked(out_arena) });
163 let data_ptr = if !out_data.is_null() {
164 let ptr = core::ptr::slice_from_raw_parts_mut(out_data.cast(), out_num_bytes as usize);
165 Some(unsafe { ArenaBox::new(NonNull::new_unchecked(ptr)) })
168 } else {
169 None
170 };
171 let handles_ptr = if !out_handles.is_null() {
172 let ptr = core::ptr::slice_from_raw_parts_mut(out_handles.cast(), out_num_handles as usize);
173 Some(unsafe { ArenaBox::new(NonNull::new_unchecked(ptr)) })
176 } else {
177 None
178 };
179 Ok(Some(unsafe { Message::new_unchecked(arena, data_ptr, handles_ptr) }))
180}
181
182fn read_raw<'a>(channel: &'a DriverHandle, dispatcher: DispatcherRef<'a>) -> ReadMessageRawFut<'a> {
188 ReadMessageRawFut { raw_fut: unsafe { ReadMessageState::new(channel) }, dispatcher }
191}
192
193impl<T> Channel<T> {
194 pub fn try_read<'a>(&self) -> Result<Option<Message<T>>, Status> {
196 let Some(message) = try_read_raw(&self.0)? else {
198 return Ok(None);
199 };
200 Ok(Some(unsafe { message.cast_unchecked() }))
203 }
204
205 pub async fn read(&self, dispatcher: DispatcherRef<'_>) -> Result<Option<Message<T>>, Status> {
207 let Some(message) = read_raw(&self.0, dispatcher).await? else {
208 return Ok(None);
209 };
210 Ok(Some(unsafe { message.cast_unchecked() }))
213 }
214}
215
216impl Channel<[u8]> {
217 pub fn try_read_bytes<'a>(&self) -> Result<Option<Message<[u8]>>, Status> {
219 let Some(message) = try_read_raw(&self.0)? else {
221 return Ok(None);
222 };
223 Ok(Some(unsafe { message.assume_init() }))
226 }
227
228 pub async fn read_bytes(
230 &self,
231 dispatcher: DispatcherRef<'_>,
232 ) -> Result<Option<Message<[u8]>>, Status> {
233 let Some(message) = read_raw(&self.0, dispatcher).await? else {
235 return Ok(None);
236 };
237 Ok(Some(unsafe { message.assume_init() }))
240 }
241}
242
243impl<T> From<Channel<T>> for MixedHandle {
244 fn from(value: Channel<T>) -> Self {
245 MixedHandle::from(value.0)
246 }
247}
248
/// Shared state for one asynchronous channel-read wait.
///
/// This is `#[repr(C)]` with `read_op` as the first field. That lets the `*mut fdf_channel_read`
/// handed to the driver runtime be cast back to a pointer to this struct in
/// `ReadMessageStateOp::handler`. Do not reorder the fields.
#[repr(C)]
struct ReadMessageStateOp {
    // Runtime wait descriptor; must stay the first field (see above).
    read_op: fdf_channel_read,
    // Waker of the task waiting for the read. It is stored when a wait is requested and taken
    // by the handler when the wait fires.
    waker: Mutex<Option<Waker>>,
}
268
269impl ReadMessageStateOp {
270 unsafe extern "C" fn handler(
271 _dispatcher: *mut fdf_dispatcher,
272 read_op: *mut fdf_channel_read,
273 _status: i32,
274 ) {
275 let op: Arc<Self> = unsafe { Arc::from_raw(read_op.cast()) };
278 let Some(waker) = op.waker.lock().unwrap().take() else {
279 return;
281 };
282 waker.wake()
283 }
284}
285
/// Poll-driven state for reading one message from a driver channel. When no message is
/// available yet, it registers a runtime wait.
pub(crate) struct ReadMessageState {
    // Shared op. An extra strong reference is handed to the runtime (via `Arc::into_raw`)
    // whenever a wait is registered.
    op: Arc<ReadMessageStateOp>,
    // Non-owning copy of the borrowed channel's handle. It is wrapped in `ManuallyDrop` so this
    // copy is never dropped (and so presumably never closes the channel).
    channel: ManuallyDrop<DriverHandle>,
    // Set from `dispatcher.is_unsynchronized()` when a wait is registered. It tells `Drop`
    // whether the handler, rather than `Drop` itself, releases the runtime's `Arc` reference
    // after a successful cancel.
    callback_drops_arc: bool,
}
293
294impl ReadMessageState {
295 pub(crate) unsafe fn new(channel: &DriverHandle) -> Self {
304 let channel = unsafe { channel.get_raw() };
307 Self {
308 op: Arc::new(ReadMessageStateOp {
309 read_op: fdf_channel_read {
310 channel: channel.get(),
311 handler: Some(ReadMessageStateOp::handler),
312 ..Default::default()
313 },
314 waker: Mutex::new(None),
315 }),
316 channel: ManuallyDrop::new(unsafe { DriverHandle::new_unchecked(channel) }),
320 callback_drops_arc: false,
323 }
324 }
325
326 pub(crate) fn poll_with_dispatcher(
328 self: &mut Self,
329 cx: &mut Context<'_>,
330 dispatcher: DispatcherRef<'_>,
331 ) -> Poll<Result<Option<Message<[MaybeUninit<u8>]>>, Status>> {
332 let mut waker_lock = self.op.waker.lock().unwrap();
333
334 match try_read_raw(&self.channel) {
335 Ok(res) => Poll::Ready(Ok(res)),
336 Err(Status::SHOULD_WAIT) => {
337 if waker_lock.replace(cx.waker().clone()).is_none() {
340 let op = Arc::into_raw(self.op.clone());
343 let res = Status::ok(unsafe {
346 fdf_channel_wait_async(dispatcher.0.as_ptr(), op.cast_mut().cast(), 0)
347 });
348 match res {
349 Ok(()) => {
350 self.callback_drops_arc = dispatcher.is_unsynchronized();
354 }
355 Err(e) => return Poll::Ready(Err(e)),
356 }
357 }
358 Poll::Pending
359 }
360 Err(e) => Poll::Ready(Err(e)),
361 }
362 }
363}
364
impl Drop for ReadMessageState {
    /// Cancels any outstanding asynchronous wait, making sure the `Arc` reference handed to the
    /// runtime is released exactly once, either here or by the handler.
    fn drop(&mut self) {
        // The handler takes this same lock, so holding it for the rest of `drop` serializes the
        // cancellation below against a concurrently running handler.
        let mut waker_lock = self.op.waker.lock().unwrap();
        if waker_lock.is_none() {
            // No stored waker: either no wait was requested, or the handler already took it
            // (and therefore owns the runtime's reference).
            return;
        }

        // SAFETY: `self.channel` is a copy of a handle the caller guaranteed outlives `self`.
        let res = Status::ok(unsafe { fdf_channel_cancel_wait(self.channel.get_raw().get()) });
        match res {
            Ok(_) => {}
            Err(Status::NOT_FOUND) => {
                // The wait is no longer pending; presumably the handler has run or is running
                // and will drop its `Arc` itself.
                return;
            }
            Err(e) => panic!("Unexpected error {e:?} cancelling driver channel read wait"),
        }
        waker_lock.take();
        if !self.callback_drops_arc {
            // Synchronized dispatcher: presumably a successful cancel means the handler will not
            // run, so release the reference given to the runtime via `Arc::into_raw` here.
            // `self.op` still holds another reference, so this does not free the allocation
            // while `waker_lock` borrows it.
            unsafe { Arc::decrement_strong_count(Arc::as_ptr(&self.op)) };
        }
    }
}
396
/// Future returned by `read_raw`. It drives a [`ReadMessageState`] using a fixed dispatcher.
struct ReadMessageRawFut<'a> {
    // The underlying read state machine.
    raw_fut: ReadMessageState,
    // Dispatcher that any asynchronous wait is registered on.
    dispatcher: DispatcherRef<'a>,
}
401
402impl<'a> Future for ReadMessageRawFut<'a> {
403 type Output = Result<Option<Message<[MaybeUninit<u8>]>>, Status>;
404
405 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
406 let dispatcher = self.dispatcher.clone();
407 self.as_mut().raw_fut.poll_with_dispatcher(cx, dispatcher)
408 }
409}
410
// Tests for typed driver channels: synchronous and asynchronous reads, handle transfer, and the
// reference-count bookkeeping of cancelled read waits.
#[cfg(test)]
mod tests {
    use std::pin::pin;
    use std::sync::{mpsc, Weak};

    use crate::test::{with_raw_dispatcher, with_raw_dispatcher_flags};
    use crate::tests::DropSender;
    use crate::{Dispatcher, DispatcherBuilder, MixedHandleType};

    use super::*;

    // Byte messages round-trip in both directions with the non-blocking API. Reading an empty
    // channel yields SHOULD_WAIT, and writing after the peer is dropped yields PEER_CLOSED.
    #[test]
    fn send_and_receive_bytes_synchronously() {
        let (first, second) = Channel::create();
        let arena = Arena::new();
        assert_eq!(first.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        first.write_with_data(arena.clone(), |arena| arena.insert_slice(&[1, 2, 3, 4])).unwrap();
        assert_eq!(&*second.try_read_bytes().unwrap().unwrap().data().unwrap(), &[1, 2, 3, 4]);
        assert_eq!(second.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        second.write_with_data(arena.clone(), |arena| arena.insert_slice(&[5, 6, 7, 8])).unwrap();
        assert_eq!(&*first.try_read_bytes().unwrap().unwrap().data().unwrap(), &[5, 6, 7, 8]);
        assert_eq!(first.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        assert_eq!(second.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        drop(second);
        assert_eq!(
            first.write_with_data(arena.clone(), |arena| arena.insert_slice(&[9, 10, 11, 12])),
            Err(Status::from_raw(ZX_ERR_PEER_CLOSED))
        );
    }

    // A task awaiting `read_bytes` is woken and receives a message written afterwards from the
    // other end.
    #[test]
    fn send_and_receive_bytes_asynchronously() {
        with_raw_dispatcher("channel async", |dispatcher| {
            let arena = Arena::new();
            let (fin_tx, fin_rx) = mpsc::channel();
            let (first, second) = Channel::create();

            let dispatcher = dispatcher.clone();
            dispatcher
                .clone()
                .spawn_task(async move {
                    fin_tx
                        .send(first.read_bytes(dispatcher.as_dispatcher_ref()).await.unwrap())
                        .unwrap();
                })
                .unwrap();
            second.write_with_data(arena, |arena| arena.insert_slice(&[1, 2, 3, 4])).unwrap();
            assert_eq!(fin_rx.recv().unwrap().unwrap().data().unwrap(), &[1, 2, 3, 4]);
        });
    }

    // An object sent through a channel is neither dropped on send nor on receive; it is only
    // dropped when the received message itself is dropped (observed through `DropSender`).
    #[test]
    fn send_and_receive_objects_synchronously() {
        let arena = Arena::new();
        let (first, second) = Channel::create();
        let (tx, rx) = mpsc::channel();
        first
            .write_with_data(arena.clone(), |arena| arena.insert(DropSender::new(1, tx.clone())))
            .unwrap();
        rx.try_recv().expect_err("should not drop the object when sent");
        let message = second.try_read().unwrap().unwrap();
        assert_eq!(message.data().unwrap().0, 1);
        rx.try_recv().expect_err("should not drop the object when received");
        drop(message);
        rx.try_recv().expect("dropped when received");
    }

    // A channel endpoint can be sent as a handle through another channel and still works after
    // it is received.
    #[test]
    fn send_and_receive_handles_synchronously() {
        println!("Create channels and write one end of one of the channel pairs to the other");
        let (first, second) = Channel::<()>::create();
        let (inner_first, inner_second) = Channel::<String>::create();
        let message = Message::new_with(Arena::new(), |arena| {
            (None, Some(arena.insert_boxed_slice(Box::new([Some(inner_first.into())]))))
        });
        first.write(message).unwrap();

        println!("Receive the channel back on the other end of the first channel pair.");
        let mut arena = None;
        let message =
            second.try_read().unwrap().expect("Expected a message with contents to be received");
        let (_, received_handles) = message.into_arena_boxes(&mut arena);
        let mut first_handle_received =
            ArenaBox::take_boxed_slice(received_handles.expect("expected handles in the message"));
        let first_handle_received = first_handle_received
            .first_mut()
            .expect("expected one handle in the handle set")
            .take()
            .expect("expected the first handle to be non-null");
        let first_handle_received = first_handle_received.resolve();
        let MixedHandleType::Driver(driver_handle) = first_handle_received else {
            panic!("Got a non-driver handle when we sent a driver handle");
        };
        let inner_first_received = unsafe { Channel::from_driver_handle(driver_handle) };

        println!("Send and receive a string across the now-transmitted channel pair.");
        inner_first_received
            .write_with_data(Arena::new(), |arena| arena.insert("boom".to_string()))
            .unwrap();
        assert_eq!(inner_second.try_read().unwrap().unwrap().data().unwrap(), &"boom".to_string());
    }

    // Sends 0, then answers every received value with value + 1. It stops when a read returns
    // an error or no message.
    async fn ping(dispatcher: Arc<Dispatcher>, chan: Channel<u8>) {
        println!("starting ping!");
        chan.write_with_data(Arena::new(), |arena| arena.insert(0)).unwrap();
        while let Ok(Some(msg)) = chan.read(dispatcher.as_dispatcher_ref()).await {
            let next = *msg.data().unwrap();
            println!("ping! {next}");
            chan.write_with_data(msg.take_arena(), |arena| arena.insert(next + 1)).unwrap();
        }
    }

    // Answers every received value with value + 1, reusing the received arena. It stops once
    // it sees a value greater than 10 and then signals `fin_tx`.
    async fn pong(
        dispatcher: Arc<Dispatcher>,
        fin_tx: std::sync::mpsc::Sender<()>,
        chan: Channel<u8>,
    ) {
        println!("starting pong!");
        while let Some(msg) = chan.read(dispatcher.as_dispatcher_ref()).await.unwrap() {
            let next = *msg.data().unwrap();
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            chan.write_with_data(msg.take_arena(), |arena| arena.insert(next + 1)).unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    // `ping` and `pong` run as tasks on the same driver dispatcher.
    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_chan, pong_chan) = Channel::create();
            dispatcher.spawn_task(ping(dispatcher.clone(), ping_chan)).unwrap();
            dispatcher.spawn_task(pong(dispatcher.clone(), fin_tx, pong_chan)).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    // `pong` runs on the driver dispatcher, while `ping` runs on a `fuchsia_async` executor
    // hosted by a separate dispatcher that allows thread blocking.
    #[test]
    fn async_ping_pong_on_fuchsia_async() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_chan, pong_chan) = Channel::create();

            let dispatcher = dispatcher.clone();
            dispatcher
                .clone()
                .post_task_sync(move |_status| {
                    let rust_async_dispatcher_fin_tx = fin_tx.clone();
                    let rust_async_dispatcher = crate::DispatcherBuilder::new()
                        .name("fuchsia-async")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_| rust_async_dispatcher_fin_tx.send(()).unwrap())
                        .create()
                        .expect("failure creating blocking dispatcher for rust async");

                    dispatcher.spawn_task(pong(dispatcher.clone(), fin_tx, pong_chan)).unwrap();
                    let dispatcher = dispatcher.clone();
                    rust_async_dispatcher
                        .post_task_sync(move |_| {
                            let mut executor = fuchsia_async::LocalExecutor::new();
                            executor.run_singlethreaded(ping(dispatcher, ping_chan));
                        })
                        .unwrap();
                })
                .unwrap();

            // Drain until every sender (pong's and the shutdown observer's) has been dropped.
            while fin_rx.recv().is_ok() {}
        });
    }

    fn assert_strong_count<T>(arc: &Weak<T>, count: usize) {
        assert_eq!(Weak::strong_count(arc), count, "unexpected strong count on arc");
    }

    // Polls a raw read once, so that a wait is registered and the op's strong count rises from
    // 1 to 2. The future is then dropped on return. The returned weak handle lets callers check
    // when the op is finally freed.
    async fn read_and_drop<T: ?Sized + 'static>(
        channel: &Channel<T>,
        dispatcher: DispatcherRef<'_>,
    ) -> Weak<ReadMessageStateOp> {
        let fut = read_raw(&channel.0, dispatcher.as_dispatcher_ref());
        let op_arc = Arc::downgrade(&fut.raw_fut.op);
        assert_strong_count(&op_arc, 1);
        let mut fut = pin!(fut);
        let Poll::Pending = futures::poll!(fut.as_mut()) else {
            panic!("expected pending state after polling channel read once");
        };
        assert_strong_count(&op_arc, 2);
        op_arc
    }

    // After a pending read is dropped, a new read on the same channel still works.
    #[test]
    fn early_cancel_future() {
        with_raw_dispatcher("early cancellation", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (a, b) = Channel::create();
            let dispatcher = dispatcher.clone();
            dispatcher
                .clone()
                .spawn_task(async move {
                    read_and_drop(&a, dispatcher.as_dispatcher_ref()).await;
                    b.write_with_data(Arena::new(), |arena| arena.insert(1)).unwrap();
                    assert_eq!(
                        a.read(dispatcher.as_dispatcher_ref()).await.unwrap().unwrap().data(),
                        Some(&1)
                    );
                    fin_tx.send(()).unwrap();
                })
                .unwrap();
            fin_rx.recv().unwrap();
        })
    }

    // Dropping a read future that was never polled frees its op immediately.
    #[test]
    fn very_early_cancel_state_drops_correctly() {
        with_raw_dispatcher("early cancellation drop correctness", |dispatcher| {
            let (a, _b) = Channel::<[u8]>::create();
            let (fin_tx, fin_rx) = mpsc::channel();

            let dispatcher = dispatcher.clone();
            dispatcher
                .clone()
                .spawn_task(async move {
                    let fut = read_raw(&a.0, dispatcher.as_dispatcher_ref());
                    let op_arc = Arc::downgrade(&fut.raw_fut.op);
                    assert_strong_count(&op_arc, 1);
                    drop(fut);
                    assert_strong_count(&op_arc, 0);
                    fin_tx.send(()).unwrap();
                })
                .unwrap();
            fin_rx.recv().unwrap()
        })
    }

    // On a synchronized dispatcher, dropping a polled (waiting) read frees the op right away.
    #[test]
    fn synchronized_early_cancel_state_drops_correctly() {
        with_raw_dispatcher("early cancellation drop correctness", |dispatcher| {
            let (a, _b) = Channel::<[u8]>::create();
            let (fin_tx, fin_rx) = mpsc::channel();

            let dispatcher = dispatcher.clone();
            dispatcher
                .clone()
                .spawn_task(async move {
                    assert_strong_count(
                        &read_and_drop(&a, dispatcher.as_dispatcher_ref()).await,
                        0,
                    );
                    fin_tx.send(()).unwrap();
                })
                .unwrap();
            fin_rx.recv().unwrap()
        });
    }

    // On an unsynchronized dispatcher, the op is only checked after the dispatcher helper
    // returns. Presumably the cancelled handler still runs and drops the last reference.
    // The channel is sent back out of the task, so it outlives that point.
    #[test]
    fn unsynchronized_early_cancel_state_drops_correctly() {
        let (a, _b) = Channel::<[u8]>::create();
        let (unsync_op, _a) = with_raw_dispatcher_flags(
            "early cancellation drop correctness",
            DispatcherBuilder::UNSYNCHRONIZED,
            |dispatcher| {
                let (fin_tx, fin_rx) = mpsc::channel();

                let inner_dispatcher = dispatcher.clone();
                dispatcher
                    .spawn_task(async move {
                        let res = read_and_drop(&a, inner_dispatcher.as_dispatcher_ref()).await;
                        fin_tx.send((res, a)).unwrap();
                    })
                    .unwrap();
                fin_rx.recv().unwrap()
            },
        );

        assert_strong_count(&unsync_op, 0);
    }
}