fdf/channel.rs

// Copyright 2024 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Safe bindings for the driver runtime channel stable ABI

use core::future::Future;
use std::mem::ManuallyDrop;
use zx::Status;

use crate::{Arena, ArenaBox, DispatcherRef, DriverHandle, Message, MixedHandle};
use fdf_sys::*;

use core::marker::PhantomData;
use core::mem::{size_of_val, MaybeUninit};
use core::num::NonZero;
use core::pin::Pin;
use core::ptr::{null_mut, NonNull};
use core::task::{Context, Poll, Waker};
use std::sync::{Arc, Mutex};

pub use fdf_sys::fdf_handle_t;

/// Implements a message channel through the Fuchsia Driver Runtime
#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct Channel<T: ?Sized + 'static>(pub(crate) DriverHandle, PhantomData<Message<T>>);

impl<T: ?Sized + 'static> Channel<T> {
    /// Creates a new channel pair that can be used to send messages of type `T`
    /// between threads managed by the driver runtime.
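    ///
    /// # Example
    ///
    /// A minimal sketch of creating a pair and exchanging bytes synchronously, mirroring this
    /// module's tests (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let (first, second) = Channel::<[u8]>::create();
    /// first.write_with_data(Arena::new(), |arena| arena.insert_slice(&[1, 2, 3, 4])).unwrap();
    /// assert_eq!(&*second.try_read_bytes().unwrap().unwrap().data().unwrap(), &[1, 2, 3, 4]);
    /// ```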
    pub fn create() -> (Self, Self) {
        let mut channel1 = 0;
        let mut channel2 = 0;
        // This call cannot fail as the only reason it would fail is due to invalid
        // option flags, and 0 is a valid option.
        Status::ok(unsafe { fdf_channel_create(0, &mut channel1, &mut channel2) })
            .expect("failed to create channel pair");
        // SAFETY: if fdf_channel_create returned ZX_OK, it will have placed valid, non-zero
        // channel handles in `channel1` and `channel2`.
        unsafe {
            (
                Self::from_handle_unchecked(NonZero::new_unchecked(channel1)),
                Self::from_handle_unchecked(NonZero::new_unchecked(channel2)),
            )
        }
    }

    /// Takes the inner handle to the channel. The caller is responsible for ensuring
    /// that the handle is freed.
    pub fn into_driver_handle(self) -> DriverHandle {
        self.0
    }

    /// Initializes a [`Channel`] object from the given non-zero handle.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the handle is not invalid and that it is
    /// part of a driver runtime channel pair of type `T`.
    unsafe fn from_handle_unchecked(handle: NonZero<fdf_handle_t>) -> Self {
        // SAFETY: caller is responsible for ensuring that it is a valid channel
        Self(unsafe { DriverHandle::new_unchecked(handle) }, PhantomData)
    }

    /// Initializes a [`Channel`] object from the given [`DriverHandle`],
    /// assuming that it is a channel of type `T`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the handle is a [`Channel`]-based handle that is
    /// using type `T` as its wire format.
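    ///
    /// # Example
    ///
    /// A sketch of recovering a typed channel from a handle previously taken with
    /// [`Self::into_driver_handle`] (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let (first, _second) = Channel::<u8>::create();
    /// let handle = first.into_driver_handle();
    /// // SAFETY: the handle came from a `Channel<u8>`, so the wire format matches.
    /// let first = unsafe { Channel::<u8>::from_driver_handle(handle) };
    /// ```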
    pub unsafe fn from_driver_handle(handle: DriverHandle) -> Self {
        Self(handle, PhantomData)
    }

    /// Writes the [`Message`] given to the channel. This will complete asynchronously and can't
    /// be cancelled.
    ///
    /// The channel will take ownership of the data and handles passed in.
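    ///
    /// # Example
    ///
    /// A sketch of building a [`Message`] with [`Message::new_with`] and writing it, based on the
    /// handle-passing test in this module (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let (first, _second) = Channel::<()>::create();
    /// let (inner, _inner_peer) = Channel::<String>::create();
    /// let message = Message::new_with(Arena::new(), |arena| {
    ///     (None, Some(arena.insert_boxed_slice(Box::new([Some(inner.into())]))))
    /// });
    /// first.write(message).unwrap();
    /// ```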
    pub fn write(&self, message: Message<T>) -> Result<(), Status> {
        // get the sizes while we still have refs to the data and handles
        let data_len = message.data().map_or(0, |data| size_of_val(&*data) as u32);
        let handles_count = message.handles().map_or(0, |handles| handles.len() as u32);

        let (arena, data, handles) = message.into_raw();

        // transform the `Option<NonNull<T>>` into just `*mut T`
        let data_ptr = data.map_or(null_mut(), |data| data.cast().as_ptr());
        let handles_ptr = handles.map_or(null_mut(), |handles| handles.cast().as_ptr());

        // SAFETY:
        // - Normally, we could be reading uninit bytes here. However, as long as fdf_channel_write
        //   isn't subject to cross-language LTO, it won't care whether the bytes are initialized.
        // - The `Message` will generally only construct correctly if the data and handles pointers
        //   inside it are from the arena it holds. In any case, `fdf_channel_write` checks that
        //   they belong to the arena we pass it, so we do not need to re-verify that here.
        Status::ok(unsafe {
            fdf_channel_write(
                self.0.get_raw().get(),
                0,
                arena.as_ptr(),
                data_ptr,
                data_len,
                handles_ptr,
                handles_count,
            )
        })?;

        // SAFETY: this is the valid-by-construction arena we were passed in through the [`Message`]
        // object, and now that we have completed `fdf_channel_write` it is safe to drop our copy
        // of it.
        unsafe { fdf_arena_drop_ref(arena.as_ptr()) };
        Ok(())
    }

    /// Shorthand for calling [`Self::write`] with the result of [`Message::new_with`]
    pub fn write_with<F>(&self, arena: Arena, f: F) -> Result<(), Status>
    where
        F: for<'a> FnOnce(
            &'a Arena,
        )
            -> (Option<ArenaBox<'a, T>>, Option<ArenaBox<'a, [Option<MixedHandle>]>>),
    {
        self.write(Message::new_with(arena, f))
    }

    /// Shorthand for calling [`Self::write`] with the result of [`Message::new_with_data`]
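    ///
    /// A sketch of sending a slice of bytes with a fresh [`Arena`] (illustrative only, not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let (first, _second) = Channel::<[u8]>::create();
    /// first.write_with_data(Arena::new(), |arena| arena.insert_slice(&[5, 6, 7, 8])).unwrap();
    /// ```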
    pub fn write_with_data<F>(&self, arena: Arena, f: F) -> Result<(), Status>
    where
        F: for<'a> FnOnce(&'a Arena) -> ArenaBox<'a, T>,
    {
        self.write(Message::new_with_data(arena, f))
    }
}

/// Attempts to read from the channel, returning a [`Message`] object that can be used to
/// access or take the data received if there was any. This is the basic building block
/// on which the other `try_read_*` methods are built.
fn try_read_raw(channel: &DriverHandle) -> Result<Option<Message<[MaybeUninit<u8>]>>, Status> {
    let mut out_arena = null_mut();
    let mut out_data = null_mut();
    let mut out_num_bytes = 0;
    let mut out_handles = null_mut();
    let mut out_num_handles = 0;
    Status::ok(unsafe {
        fdf_channel_read(
            channel.get_raw().get(),
            0,
            &mut out_arena,
            &mut out_data,
            &mut out_num_bytes,
            &mut out_handles,
            &mut out_num_handles,
        )
    })?;
    // if no arena was returned, that means no data was returned.
    if out_arena.is_null() {
        return Ok(None);
    }
    // SAFETY: we just checked that the `out_arena` is non-null
    let arena = Arena(unsafe { NonNull::new_unchecked(out_arena) });
    let data_ptr = if !out_data.is_null() {
        let ptr = core::ptr::slice_from_raw_parts_mut(out_data.cast(), out_num_bytes as usize);
        // SAFETY: we just checked that the pointer was non-null, the slice version of it should
        // be too.
        Some(unsafe { ArenaBox::new(NonNull::new_unchecked(ptr)) })
    } else {
        None
    };
    let handles_ptr = if !out_handles.is_null() {
        let ptr = core::ptr::slice_from_raw_parts_mut(out_handles.cast(), out_num_handles as usize);
        // SAFETY: we just checked that the pointer was non-null, the slice version of it should
        // be too.
        Some(unsafe { ArenaBox::new(NonNull::new_unchecked(ptr)) })
    } else {
        None
    };
    Ok(Some(unsafe { Message::new_unchecked(arena, data_ptr, handles_ptr) }))
}

/// Reads a message from the channel asynchronously
///
/// # Panics
///
/// Panics if this is not run from a driver framework dispatcher.
fn read_raw<'a>(channel: &'a DriverHandle, dispatcher: DispatcherRef<'a>) -> ReadMessageRawFut<'a> {
    // SAFETY: Since the future's lifetime is bound to the original driver handle and it
    // holds the message state, the message state object can't outlive the handle.
    ReadMessageRawFut { raw_fut: unsafe { ReadMessageState::new(channel) }, dispatcher }
}

impl<T> Channel<T> {
    /// Attempts to read an object of type `T` and a handle set from the channel
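    ///
    /// Returns `Ok(None)` if a message without data was received, and `Err(Status::SHOULD_WAIT)`
    /// if there is no message pending on the channel. A sketch, mirroring this module's tests
    /// (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let (first, second) = Channel::<u8>::create();
    /// second.write_with_data(Arena::new(), |arena| arena.insert(42u8)).unwrap();
    /// match first.try_read() {
    ///     Ok(Some(message)) => assert_eq!(message.data(), Some(&42u8)),
    ///     Ok(None) => panic!("received a message with no data"),
    ///     Err(status) => panic!("nothing to read yet: {status:?}"),
    /// }
    /// ```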
    pub fn try_read<'a>(&self) -> Result<Option<Message<T>>, Status> {
        // read a message from the channel
        let Some(message) = try_read_raw(&self.0)? else {
            return Ok(None);
        };
        // SAFETY: It is an invariant of Channel<T> that messages sent or received are always of
        // type T.
        Ok(Some(unsafe { message.cast_unchecked() }))
    }

    /// Reads an object of type `T` and a handle set from the channel asynchronously
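    ///
    /// A sketch of awaiting messages in a loop, assuming `chan: Channel<u8>` and
    /// `dispatcher: Arc<Dispatcher>` as in this module's tests (illustrative only, not compiled
    /// as a doctest):
    ///
    /// ```ignore
    /// while let Ok(Some(msg)) = chan.read(dispatcher.as_dispatcher_ref()).await {
    ///     println!("received: {}", msg.data().unwrap());
    /// }
    /// ```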
    pub async fn read(&self, dispatcher: DispatcherRef<'_>) -> Result<Option<Message<T>>, Status> {
        let Some(message) = read_raw(&self.0, dispatcher).await? else {
            return Ok(None);
        };
        // SAFETY: It is an invariant of Channel<T> that messages sent or received are always of
        // type T.
        Ok(Some(unsafe { message.cast_unchecked() }))
    }
}

impl Channel<[u8]> {
    /// Attempts to read a slice of bytes and a handle set from the channel
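    ///
    /// A sketch of the empty-channel case, mirroring this module's tests (illustrative only,
    /// not compiled as a doctest):
    ///
    /// ```ignore
    /// let (first, _second) = Channel::<[u8]>::create();
    /// // Nothing has been written yet, so the read fails with `SHOULD_WAIT`.
    /// assert_eq!(first.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
    /// ```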
    pub fn try_read_bytes<'a>(&self) -> Result<Option<Message<[u8]>>, Status> {
        // read a message from the channel
        let Some(message) = try_read_raw(&self.0)? else {
            return Ok(None);
        };
        // SAFETY: It is an invariant of Channel<[u8]> that messages sent or received are always of
        // type [u8].
        Ok(Some(unsafe { message.assume_init() }))
    }

    /// Reads a slice of bytes and a handle set from the channel asynchronously
    pub async fn read_bytes(
        &self,
        dispatcher: DispatcherRef<'_>,
    ) -> Result<Option<Message<[u8]>>, Status> {
        // read a message from the channel
        let Some(message) = read_raw(&self.0, dispatcher).await? else {
            return Ok(None);
        };
        // SAFETY: It is an invariant of Channel<[u8]> that messages sent or received are always of
        // type [u8].
        Ok(Some(unsafe { message.assume_init() }))
    }
}

impl<T> From<Channel<T>> for MixedHandle {
    fn from(value: Channel<T>) -> Self {
        MixedHandle::from(value.0)
    }
}

/// This struct is shared between the future and the driver runtime, with the first field
/// being managed by the driver runtime and the second by the future. It will be held by two
/// [`Arc`]s, one for each of the future and the runtime.
///
/// The future's [`Arc`] will be dropped when the future is either fulfilled or cancelled through
/// normal [`Drop`] of the future.
///
/// When the runtime's [`Arc`] is dropped depends on whether the dispatcher it was registered on
/// was synchronized or not, and on whether the wait was cancelled. The callback will only ever be
/// called *at most* once.
///
/// If the dispatcher is synchronized, then the callback will *only* be called on fulfillment of the
/// read wait.
#[repr(C)]
struct ReadMessageStateOp {
    /// This must be at the start of the struct so that `ReadMessageStateOp` can be cast to and
    /// from `fdf_channel_read`.
    read_op: fdf_channel_read,
    waker: Mutex<Option<Waker>>,
}

impl ReadMessageStateOp {
    unsafe extern "C" fn handler(
        _dispatcher: *mut fdf_dispatcher,
        read_op: *mut fdf_channel_read,
        _status: i32,
    ) {
        // SAFETY: When setting up the read op, we incremented the refcount of the `Arc` to allow
        // for this handler to reconstitute it.
        let op: Arc<Self> = unsafe { Arc::from_raw(read_op.cast()) };
        let Some(waker) = op.waker.lock().unwrap().take() else {
            // the waker was already taken, presumably because the future was dropped.
            return;
        };
        waker.wake()
    }
}

/// An object for managing the state of an async channel read message operation that can be used to
/// implement futures.
pub(crate) struct ReadMessageState {
    op: Arc<ReadMessageStateOp>,
    channel: ManuallyDrop<DriverHandle>,
    callback_drops_arc: bool,
}

impl ReadMessageState {
    /// Creates a new raw read message state that can be used to implement a [`Future`] that reads
    /// data from a channel and then converts it to the appropriate type. It also allows for
    /// different ways of storing and managing the dispatcher we wait on, by deferring the choice
    /// of dispatcher until poll time.
    ///
    /// # Safety
    ///
    /// The caller is responsible for ensuring that `channel` outlives this object.
    pub(crate) unsafe fn new(channel: &DriverHandle) -> Self {
        // SAFETY: The caller is responsible for ensuring that the handle is a correct channel handle
        // and that the handle will outlive the created [`ReadMessageState`].
        let channel = unsafe { channel.get_raw() };
        Self {
            op: Arc::new(ReadMessageStateOp {
                read_op: fdf_channel_read {
                    channel: channel.get(),
                    handler: Some(ReadMessageStateOp::handler),
                    ..Default::default()
                },
                waker: Mutex::new(None),
            }),
            // SAFETY: We know this is a valid driver handle by construction and we are
            // storing this handle in a [`ManuallyDrop`] to prevent it from being double-dropped.
            // The caller is responsible for ensuring that the handle outlives this object.
            channel: ManuallyDrop::new(unsafe { DriverHandle::new_unchecked(channel) }),
            // We haven't waited on it yet so we are responsible for dropping the arc for now,
            // regardless of what kind of dispatcher it's intended to be used with.
            callback_drops_arc: false,
        }
    }

    /// Polls this channel read operation against the given dispatcher.
    pub(crate) fn poll_with_dispatcher(
        &mut self,
        cx: &mut Context<'_>,
        dispatcher: DispatcherRef<'_>,
    ) -> Poll<Result<Option<Message<[MaybeUninit<u8>]>>, Status>> {
        let mut waker_lock = self.op.waker.lock().unwrap();

        match try_read_raw(&self.channel) {
            Ok(res) => Poll::Ready(Ok(res)),
            Err(Status::SHOULD_WAIT) => {
                // if we haven't yet set a waker, that means we haven't started the wait operation
                // yet.
                if waker_lock.replace(cx.waker().clone()).is_none() {
                    // increment the reference count of the read op to account for the copy
                    // that will be given to `fdf_channel_wait_async`.
                    let op = Arc::into_raw(self.op.clone());
                    // SAFETY: the `ReadMessageStateOp` starts with an `fdf_channel_read` struct and
                    // has `repr(C)` layout, so is safe to be cast to the latter.
                    let res = Status::ok(unsafe {
                        fdf_channel_wait_async(dispatcher.0.as_ptr(), op.cast_mut().cast(), 0)
                    });
                    match res {
                        Ok(()) => {
                            // if the dispatcher we're waiting on is unsynchronized, the callback
                            // will drop the Arc and we need to indicate to our own Drop impl
                            // that it should not.
                            self.callback_drops_arc = dispatcher.is_unsynchronized();
                        }
                        Err(e) => return Poll::Ready(Err(e)),
                    }
                }
                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}

impl Drop for ReadMessageState {
    fn drop(&mut self) {
        let mut waker_lock = self.op.waker.lock().unwrap();
        if waker_lock.is_none() {
            // if there's no waker, either the callback has already fired or we never waited on this
            // future in the first place, so just leave it be.
            return;
        }

        // SAFETY: since we hold a lifetimed reference to the channel object here, the channel must
        // be valid.
        let res = Status::ok(unsafe { fdf_channel_cancel_wait(self.channel.get_raw().get()) });
        match res {
            Ok(_) => {}
            Err(Status::NOT_FOUND) => {
                // the callback is already being called or the wait was already cancelled, so just
                // return and leave it.
                return;
            }
            Err(e) => panic!("Unexpected error {e:?} cancelling driver channel read wait"),
        }
        // steal the waker so it doesn't get called, if there is one.
        waker_lock.take();
        // SAFETY: if the channel was waited on by a synchronized dispatcher, and the cancel was
        // successful, the callback will not be called and we will have to free the `Arc` that the
        // callback would have consumed.
        if !self.callback_drops_arc {
            unsafe { Arc::decrement_strong_count(Arc::as_ptr(&self.op)) };
        }
    }
}

struct ReadMessageRawFut<'a> {
    raw_fut: ReadMessageState,
    dispatcher: DispatcherRef<'a>,
}

impl<'a> Future for ReadMessageRawFut<'a> {
    type Output = Result<Option<Message<[MaybeUninit<u8>]>>, Status>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let dispatcher = self.dispatcher.clone();
        self.as_mut().raw_fut.poll_with_dispatcher(cx, dispatcher)
    }
}

#[cfg(test)]
mod tests {
    use std::pin::pin;
    use std::sync::{mpsc, Weak};

    use crate::test::{with_raw_dispatcher, with_raw_dispatcher_flags};
    use crate::tests::DropSender;
    use crate::{Dispatcher, DispatcherBuilder, MixedHandleType};

    use super::*;

    #[test]
    fn send_and_receive_bytes_synchronously() {
        let (first, second) = Channel::create();
        let arena = Arena::new();
        assert_eq!(first.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        first.write_with_data(arena.clone(), |arena| arena.insert_slice(&[1, 2, 3, 4])).unwrap();
        assert_eq!(&*second.try_read_bytes().unwrap().unwrap().data().unwrap(), &[1, 2, 3, 4]);
        assert_eq!(second.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        second.write_with_data(arena.clone(), |arena| arena.insert_slice(&[5, 6, 7, 8])).unwrap();
        assert_eq!(&*first.try_read_bytes().unwrap().unwrap().data().unwrap(), &[5, 6, 7, 8]);
        assert_eq!(first.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        assert_eq!(second.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        drop(second);
        assert_eq!(
            first.write_with_data(arena.clone(), |arena| arena.insert_slice(&[9, 10, 11, 12])),
            Err(Status::from_raw(ZX_ERR_PEER_CLOSED))
        );
    }

    #[test]
    fn send_and_receive_bytes_asynchronously() {
        with_raw_dispatcher("channel async", |dispatcher| {
            let arena = Arena::new();
            let (fin_tx, fin_rx) = mpsc::channel();
            let (first, second) = Channel::create();

            let dispatcher = dispatcher.clone();
            dispatcher
                .clone()
                .spawn_task(async move {
                    fin_tx
                        .send(first.read_bytes(dispatcher.as_dispatcher_ref()).await.unwrap())
                        .unwrap();
                })
                .unwrap();
            second.write_with_data(arena, |arena| arena.insert_slice(&[1, 2, 3, 4])).unwrap();
            assert_eq!(fin_rx.recv().unwrap().unwrap().data().unwrap(), &[1, 2, 3, 4]);
        });
    }

    #[test]
    fn send_and_receive_objects_synchronously() {
        let arena = Arena::new();
        let (first, second) = Channel::create();
        let (tx, rx) = mpsc::channel();
        first
            .write_with_data(arena.clone(), |arena| arena.insert(DropSender::new(1, tx.clone())))
            .unwrap();
        rx.try_recv().expect_err("should not drop the object when sent");
        let message = second.try_read().unwrap().unwrap();
        assert_eq!(message.data().unwrap().0, 1);
        rx.try_recv().expect_err("should not drop the object when received");
        drop(message);
        rx.try_recv().expect("dropped when received");
    }

    #[test]
    fn send_and_receive_handles_synchronously() {
        println!("Create channels and write one end of one of the channel pairs to the other");
        let (first, second) = Channel::<()>::create();
        let (inner_first, inner_second) = Channel::<String>::create();
        let message = Message::new_with(Arena::new(), |arena| {
            (None, Some(arena.insert_boxed_slice(Box::new([Some(inner_first.into())]))))
        });
        first.write(message).unwrap();

        println!("Receive the channel back on the other end of the first channel pair.");
        let mut arena = None;
        let message =
            second.try_read().unwrap().expect("Expected a message with contents to be received");
        let (_, received_handles) = message.into_arena_boxes(&mut arena);
        let mut first_handle_received =
            ArenaBox::take_boxed_slice(received_handles.expect("expected handles in the message"));
        let first_handle_received = first_handle_received
            .first_mut()
            .expect("expected one handle in the handle set")
            .take()
            .expect("expected the first handle to be non-null");
        let first_handle_received = first_handle_received.resolve();
        let MixedHandleType::Driver(driver_handle) = first_handle_received else {
            panic!("Got a non-driver handle when we sent a driver handle");
        };
        let inner_first_received = unsafe { Channel::from_driver_handle(driver_handle) };

        println!("Send and receive a string across the now-transmitted channel pair.");
        inner_first_received
            .write_with_data(Arena::new(), |arena| arena.insert("boom".to_string()))
            .unwrap();
        assert_eq!(inner_second.try_read().unwrap().unwrap().data().unwrap(), &"boom".to_string());
    }

    async fn ping(dispatcher: Arc<Dispatcher>, chan: Channel<u8>) {
        println!("starting ping!");
        chan.write_with_data(Arena::new(), |arena| arena.insert(0)).unwrap();
        while let Ok(Some(msg)) = chan.read(dispatcher.as_dispatcher_ref()).await {
            let next = *msg.data().unwrap();
            println!("ping! {next}");
            chan.write_with_data(msg.take_arena(), |arena| arena.insert(next + 1)).unwrap();
        }
    }

    async fn pong(
        dispatcher: Arc<Dispatcher>,
        fin_tx: std::sync::mpsc::Sender<()>,
        chan: Channel<u8>,
    ) {
        println!("starting pong!");
        while let Some(msg) = chan.read(dispatcher.as_dispatcher_ref()).await.unwrap() {
            let next = *msg.data().unwrap();
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            chan.write_with_data(msg.take_arena(), |arena| arena.insert(next + 1)).unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_chan, pong_chan) = Channel::create();
            dispatcher.spawn_task(ping(dispatcher.clone(), ping_chan)).unwrap();
            dispatcher.spawn_task(pong(dispatcher.clone(), fin_tx, pong_chan)).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    #[test]
    fn async_ping_pong_on_fuchsia_async() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_chan, pong_chan) = Channel::create();

            let dispatcher = dispatcher.clone();
            dispatcher
                .clone()
                .post_task_sync(move |_status| {
                    let rust_async_dispatcher_fin_tx = fin_tx.clone();
                    let rust_async_dispatcher = crate::DispatcherBuilder::new()
                        .name("fuchsia-async")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_| rust_async_dispatcher_fin_tx.send(()).unwrap())
                        .create()
                        .expect("failure creating blocking dispatcher for rust async");

                    dispatcher.spawn_task(pong(dispatcher.clone(), fin_tx, pong_chan)).unwrap();
                    let dispatcher = dispatcher.clone();
                    rust_async_dispatcher
                        .post_task_sync(move |_| {
                            let mut executor = fuchsia_async::LocalExecutor::new();
                            executor.run_singlethreaded(ping(dispatcher, ping_chan));
                        })
                        .unwrap();
                })
                .unwrap();

            // wait for everything to shut down.
            while fin_rx.recv().is_ok() {}
        });
    }

    /// Asserts that the strong count of an arc is correct.
    fn assert_strong_count<T>(arc: &Weak<T>, count: usize) {
        assert_eq!(Weak::strong_count(arc), count, "unexpected strong count on arc");
    }

    /// Creates, polls, and then immediately drops a read future for a channel, verifying
    /// that the internal op arc has the right refcount at each step. Returns a [`Weak`]
    /// reference to the op arc at the end so the caller can verify that the count goes
    /// down to zero correctly.
    async fn read_and_drop<T: ?Sized + 'static>(
        channel: &Channel<T>,
        dispatcher: DispatcherRef<'_>,
    ) -> Weak<ReadMessageStateOp> {
        let fut = read_raw(&channel.0, dispatcher.as_dispatcher_ref());
        let op_arc = Arc::downgrade(&fut.raw_fut.op);
        assert_strong_count(&op_arc, 1);
        let mut fut = pin!(fut);
        let Poll::Pending = futures::poll!(fut.as_mut()) else {
            panic!("expected pending state after polling channel read once");
        };
        assert_strong_count(&op_arc, 2);
        op_arc
    }

    #[test]
    fn early_cancel_future() {
        with_raw_dispatcher("early cancellation", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (a, b) = Channel::create();
            let dispatcher = dispatcher.clone();
            dispatcher
                .clone()
                .spawn_task(async move {
                    // create, poll, and then immediately drop a read future for channel `a`
                    // so that it properly sets up the wait.
                    read_and_drop(&a, dispatcher.as_dispatcher_ref()).await;
                    b.write_with_data(Arena::new(), |arena| arena.insert(1)).unwrap();
                    assert_eq!(
                        a.read(dispatcher.as_dispatcher_ref()).await.unwrap().unwrap().data(),
                        Some(&1)
                    );
                    fin_tx.send(()).unwrap();
                })
                .unwrap();
            fin_rx.recv().unwrap();
        })
    }

    #[test]
    fn very_early_cancel_state_drops_correctly() {
        with_raw_dispatcher("early cancellation drop correctness", |dispatcher| {
            let (a, _b) = Channel::<[u8]>::create();
            let (fin_tx, fin_rx) = mpsc::channel();

            let dispatcher = dispatcher.clone();
            dispatcher
                .clone()
                .spawn_task(async move {
                    // drop before even polling it should drop the arc correctly
                    let fut = read_raw(&a.0, dispatcher.as_dispatcher_ref());
                    let op_arc = Arc::downgrade(&fut.raw_fut.op);
                    assert_strong_count(&op_arc, 1);
                    drop(fut);
                    assert_strong_count(&op_arc, 0);
                    fin_tx.send(()).unwrap();
                })
                .unwrap();
            fin_rx.recv().unwrap()
        })
    }

    #[test]
    fn synchronized_early_cancel_state_drops_correctly() {
        with_raw_dispatcher("early cancellation drop correctness", |dispatcher| {
            let (a, _b) = Channel::<[u8]>::create();
            let (fin_tx, fin_rx) = mpsc::channel();

            let dispatcher = dispatcher.clone();
            dispatcher
                .clone()
                .spawn_task(async move {
                    assert_strong_count(
                        &read_and_drop(&a, dispatcher.as_dispatcher_ref()).await,
                        0,
                    );
                    fin_tx.send(()).unwrap();
                })
                .unwrap();
            fin_rx.recv().unwrap()
        });
    }

    #[test]
    fn unsynchronized_early_cancel_state_drops_correctly() {
        // the channel needs to outlive the dispatcher for this test because the channel shouldn't
        // be closed before the read wait has been cancelled.
        let (a, _b) = Channel::<[u8]>::create();
        let (unsync_op, _a) = with_raw_dispatcher_flags(
            "early cancellation drop correctness",
            DispatcherBuilder::UNSYNCHRONIZED,
            |dispatcher| {
                let (fin_tx, fin_rx) = mpsc::channel();

                let inner_dispatcher = dispatcher.clone();
                dispatcher
                    .spawn_task(async move {
                        // We send the arc out to be checked after the dispatcher has shut down so
                        // that we can be sure that the callback has had a chance to be called.
                        // We send the channel back out so that it lives long enough for the
                        // cancellation to be called on it.
                        let res = read_and_drop(&a, inner_dispatcher.as_dispatcher_ref()).await;
                        fin_tx.send((res, a)).unwrap();
                    })
                    .unwrap();
                fin_rx.recv().unwrap()
            },
        );

        // check that there are no more owners of the inner op for the unsynchronized dispatcher.
        assert_strong_count(&unsync_op, 0);
    }
}