fdf_channel/
futures.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Internal helpers for implementing futures against channel objects
6
7use std::mem::ManuallyDrop;
8use std::task::Waker;
9use zx::Status;
10
11use crate::channel::{Channel, try_read_raw};
12use crate::message::Message;
13use fdf_core::dispatcher::{DispatcherRef, OnDispatcher};
14use fdf_core::handle::DriverHandle;
15use fdf_sys::*;
16
17use core::mem::MaybeUninit;
18use core::task::{Context, Poll};
19use fuchsia_sync::Mutex;
20use std::sync::Arc;
21
22pub use fdf_sys::fdf_handle_t;
23
/// Mutable state for a channel read operation, protected by the lock in
/// [`ReadMessageStateOp`].
#[derive(Default, Debug)]
struct ReadMessageStateOpLocked {
    /// The currently active waker for this read operation. Only set while there is an
    /// outstanding wait registered with the runtime, including a wait whose asynchronous
    /// cancelation has not yet delivered its callback. The runtime callback takes it.
    waker: Option<Waker>,
    /// Set if the channel was dropped while a callback was still pending, in which case the
    /// callback is responsible for closing the driver handle when it fires.
    channel_dropped: bool,
    /// Whether cancelation of the current wait will happen asynchronously through the callback
    /// or immediately when [`fdf_channel_cancel_wait`] is called. This is used to decide what is
    /// responsible for freeing the runtime's reference to this object when the future is
    /// canceled.
    cancelation_is_async: bool,
}
39
/// State for an async channel read, shared between the read futures of a channel and the
/// driver runtime. The `read_op` field is handed to the driver runtime when a wait is
/// registered; `state` is managed by the futures and the runtime callback under a lock.
///
/// It is reference counted with [`Arc`], and references are held by:
/// - the owning [`Channel`]'s `wait_state`, so the same op is reused by every read future on
///   that channel (which lets a new future observe a still-pending asynchronous cancelation
///   left behind by an earlier one);
/// - each live [`ReadMessageState`], released when it is dropped;
/// - the driver runtime, while a wait is registered via `fdf_channel_wait_async`. That
///   reference is reclaimed by [`ReadMessageStateOp::handler`] when the callback fires, or by
///   [`ReadMessageState`]'s `Drop` impl when the wait is canceled synchronously.
///
/// The callback will only ever be called *up to* one time per registered wait. Cancelation is
/// asynchronous (the callback still fires and releases the runtime's reference) when the wait
/// was registered while not running on the dispatcher being waited on (which sets
/// `FDF_CHANNEL_WAIT_OPTION_FORCE_ASYNC_CANCEL`) or when that dispatcher is unsynchronized.
/// Otherwise a successful cancel means the callback is not called.
#[repr(C)]
#[derive(Debug)]
pub(crate) struct ReadMessageStateOp {
    /// This must be at the start of the struct (together with `repr(C)`) so that a pointer to
    /// `ReadMessageStateOp` can be cast to and from a pointer to `fdf_channel_read`.
    read_op: fdf_channel_read,
    /// Lock-protected state shared between the futures and the runtime callback.
    state: Mutex<ReadMessageStateOpLocked>,
}
60
61impl ReadMessageStateOp {
62    unsafe extern "C" fn handler(
63        _dispatcher: *mut fdf_dispatcher,
64        read_op: *mut fdf_channel_read,
65        _status: i32,
66    ) {
67        // Note: we don't really do anything different based on whether the callback
68        // says canceled. If the future was canceled by being dropped, it won't poll
69        // again since it was dropped.
70        // The only unusual case is when the dispatcher is shutting down, and in that
71        // case we will wake the future and it will try to read and get a more useful
72        // error.
73        // Meanwhile, since we use the same state object across multiple
74        // futures due to needing to handle async cancelation, trying to track the
75        // underlying reason for the cancelation becomes more tricky than it's worth.
76
77        // SAFETY: When setting up the read op, we incremented the refcount of the `Arc` to allow
78        // for this handler to reconstitute it.
79        let op: Arc<Self> = unsafe { Arc::from_raw(read_op.cast()) };
80
81        let mut state = op.state.lock();
82        if state.channel_dropped {
83            // SAFETY: since the channel dropped we are the only outstanding owner of the
84            // channel object.
85            unsafe { fdf_handle_close(op.read_op.channel) };
86        }
87        let Some(waker) = state.waker.take() else {
88            // the waker was already taken, presumably because the future was dropped.
89            return;
90        };
91        // make sure to drop the lock before calling the waker.
92        drop(state);
93        waker.wake()
94    }
95
96    /// Called by the channel on drop to indicate that the channel has been dropped and
97    /// find out whether it needs to defer dropping the handle until the callback is called.
98    pub fn set_channel_dropped(&self) -> bool {
99        let mut state = self.state.lock();
100        if state.waker.is_some() {
101            state.channel_dropped = true;
102            false
103        } else {
104            true
105        }
106    }
107}
108
/// An object for managing the state of an async channel read message operation that can be
/// used to implement futures.
///
/// It holds a reference to the channel's shared [`ReadMessageStateOp`] and a non-owning copy
/// of the channel's driver handle.
pub struct ReadMessageState {
    /// Shared op state, also referenced by the owning channel and, while a wait is
    /// registered, by the driver runtime.
    op: Arc<ReadMessageStateOp>,
    /// Non-owning copy of the channel's handle, wrapped in [`ManuallyDrop`] so that dropping
    /// this object never closes it (preventing a double drop with the owning channel).
    channel: ManuallyDrop<DriverHandle>,
}
115
116impl ReadMessageState {
117    /// Creates a new raw read message state that can be used to implement a [`Future`] that reads
118    /// data from a channel and then converts it to the appropriate type. It also allows for
119    /// different ways of storing and managing the dispatcher we wait on by deferring the
120    /// dispatcher used to poll time. This state is registered with the given [`Channel`]
121    /// so that dropping the channel will correctly free resources.
122    ///
123    /// # Safety
124    ///
125    /// The caller is responsible for ensuring that the handle inside `channel` outlives this
126    /// object.
127    pub unsafe fn register_read_wait<T: ?Sized>(channel: &mut Channel<T>) -> Self {
128        // SAFETY: The caller is responsible for ensuring that the handle is a correct channel handle
129        // and that the handle will outlive the created [`ReadMessageState`].
130        let channel_handle = unsafe { channel.handle.get_raw() };
131        let op = channel
132            .wait_state
133            .get_or_insert_with(|| {
134                Arc::new(ReadMessageStateOp {
135                    read_op: fdf_channel_read {
136                        channel: channel_handle.get(),
137                        handler: Some(ReadMessageStateOp::handler),
138                        ..Default::default()
139                    },
140                    state: Mutex::new(ReadMessageStateOpLocked::default()),
141                })
142            })
143            .clone();
144        Self {
145            op,
146            // SAFETY: We know this is a valid driver handle by construction and we are
147            // storing this handle in a [`ManuallyDrop`] to prevent it from being double-dropped.
148            // The caller is responsible for ensuring that the handle outlives this object.
149            channel: ManuallyDrop::new(unsafe { DriverHandle::new_unchecked(channel_handle) }),
150        }
151    }
152
153    /// Polls this channel read operation against the given dispatcher.
154    #[expect(clippy::type_complexity)]
155    pub fn poll_with_dispatcher<D: OnDispatcher>(
156        &mut self,
157        cx: &mut Context<'_>,
158        dispatcher: D,
159    ) -> Poll<Result<Option<Message<[MaybeUninit<u8>]>>, Status>> {
160        let mut state = self.op.state.lock();
161
162        match try_read_raw(&self.channel) {
163            Ok(res) => Poll::Ready(Ok(res)),
164            Err(Status::SHOULD_WAIT) => {
165                // if we haven't yet set a waker, that means we haven't started the wait operation
166                // yet.
167                if state.waker.is_none() {
168                    // increment the reference count of the read op to account for the copy that will be given to
169                    // `fdf_channel_wait_async`.
170                    let op = Arc::into_raw(self.op.clone());
171                    let res = dispatcher.on_maybe_dispatcher(|dispatcher| {
172                        let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
173                        // if we're not running on the same dispatcher as we're waiting from, we
174                        // want to force async cancellation
175                        let options = if !dispatcher.is_current_dispatcher() {
176                            FDF_CHANNEL_WAIT_OPTION_FORCE_ASYNC_CANCEL
177                        } else {
178                            0
179                        };
180                        // SAFETY: the `ReadMessageStateOp` starts with an `fdf_channel_read` struct and
181                        // has `repr(C)` layout, so is safe to be cast to the latter.
182                        let res = Status::ok(unsafe {
183                            fdf_channel_wait_async(
184                                dispatcher.inner().as_ptr(),
185                                op.cast_mut().cast(),
186                                options,
187                            )
188                        });
189                        if res.is_ok() {
190                            // only replace the waker if we succeeded, so we'll try again next time
191                            // otherwise.
192                            state.waker.replace(cx.waker().clone());
193                        } else {
194                            // reconstitute the arc we made for the callback so it can be dropped
195                            // since the async wait didn't succeed.
196                            drop(unsafe { Arc::from_raw(op) });
197                        }
198                        // if the dispatcher we're waiting on is unsynchronized, the callback
199                        // will drop the Arc and we need to indicate to our own Drop impl
200                        // that it should not.
201                        res.map(|_| {
202                            options == FDF_CHANNEL_WAIT_OPTION_FORCE_ASYNC_CANCEL
203                                || dispatcher.is_unsynchronized()
204                        })
205                    });
206
207                    // the default state should be that `drop` will free the arc.
208                    state.cancelation_is_async = false;
209                    match res {
210                        Err(Status::BAD_STATE) => {
211                            return Poll::Pending; // a pending await is being cancelled
212                        }
213                        Ok(cancelation_is_async) => {
214                            state.cancelation_is_async = cancelation_is_async;
215                        }
216                        Err(e) => return Poll::Ready(Err(e)),
217                    }
218                }
219                Poll::Pending
220            }
221            Err(e) => Poll::Ready(Err(e)),
222        }
223    }
224}
225
226impl Drop for ReadMessageState {
227    fn drop(&mut self) {
228        let mut state = self.op.state.lock();
229        if state.waker.is_none() {
230            // if there's no waker either the callback has already fired or we never waited on this
231            // future in the first place, so just leave it be.
232            return;
233        }
234
235        // SAFETY: since we hold a lifetimed-reference to the channel object here, the channel must
236        // be valid.
237        let res = Status::ok(unsafe { fdf_channel_cancel_wait(self.channel.get_raw().get()) });
238        match res {
239            Ok(_) => {}
240            Err(Status::NOT_FOUND) => {
241                // the callback is already being called or the wait was already cancelled, so just
242                // return and leave it.
243                return;
244            }
245            Err(e) => panic!("Unexpected error {e:?} cancelling driver channel read wait"),
246        }
247        // SAFETY: if the channel was waited on by a synchronized dispatcher, and the cancel was
248        // successful, the callback will not be called and we will have to free the `Arc` that the
249        // callback would have consumed.
250        if !state.cancelation_is_async {
251            // steal the waker so it doesn't get called, if there is one.
252            state.waker.take();
253            unsafe { Arc::decrement_strong_count(Arc::as_ptr(&self.op)) };
254        }
255    }
256}
257
#[cfg(test)]
mod test {
    use std::pin::pin;
    use std::sync::Weak;

    use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
    use fdf_env::test::{spawn_in_driver, spawn_in_driver_etc};

    use crate::arena::Arena;
    use crate::channel::{Channel, read_raw};

    use super::*;

    /// Asserts that the `Arc` behind `arc` currently has exactly `count` strong references.
    #[track_caller]
    fn assert_strong_count<T>(arc: &Weak<T>, count: usize) {
        assert_eq!(Weak::strong_count(arc), count, "unexpected strong count on arc");
    }

    /// Creates a read future for `channel`, polls it once so it registers a wait, and drops it
    /// when this function returns, checking the internal op arc's refcount along the way.
    /// Returns a weak reference to the op so callers can verify how the count evolves after the
    /// future is dropped.
    ///
    /// The expected counts are 2 before polling (presumably the channel's `wait_state` plus the
    /// future) and 3 after polling (plus the reference handed to the runtime for the wait).
    async fn read_and_drop<T: ?Sized + 'static, D: OnDispatcher + Unpin>(
        channel: &mut Channel<T>,
        dispatcher: D,
    ) -> Weak<ReadMessageStateOp> {
        let fut = unsafe { read_raw(channel, dispatcher) };
        let op_arc = Arc::downgrade(&fut.raw_fut.op);
        assert_strong_count(&op_arc, 2);
        let mut fut = pin!(fut);
        let Poll::Pending = futures::poll!(fut.as_mut()) else {
            panic!("expected pending state after polling channel read once");
        };
        assert_strong_count(&op_arc, 3);
        op_arc
    }

    /// Dropping a polled read future must not prevent a later read on the same channel from
    /// receiving a message.
    #[test]
    fn early_cancel_future() {
        spawn_in_driver("early cancellation", async {
            let (mut a, b) = Channel::create();

            // create, poll, and then immediately drop a read future for channel `a`
            // so that it properly sets up the wait.
            read_and_drop(&mut a, CurrentDispatcher).await;
            b.write_with_data(Arena::new(), |arena| arena.insert(1)).unwrap();
            assert_eq!(a.read(CurrentDispatcher).await.unwrap().unwrap().data(), Some(&1));
        })
    }

    /// Dropping a read future that was never polled releases only the future's own reference.
    #[test]
    fn very_early_cancel_state_drops_correctly() {
        spawn_in_driver("early cancellation drop correctness", async {
            let (mut a, _b) = Channel::<[u8]>::create();

            // drop before even polling it should drop the arc correctly
            let fut = unsafe { read_raw(&mut a, CurrentDispatcher) };
            let op_arc = Arc::downgrade(&fut.raw_fut.op);
            assert_strong_count(&op_arc, 2);
            drop(fut);
            assert_strong_count(&op_arc, 1);
        })
    }

    /// Waiting on the current dispatcher uses synchronous cancelation, so dropping the polled
    /// future releases both its own reference and the runtime's, leaving only one.
    #[test]
    fn synchronized_early_cancel_state_drops_correctly() {
        spawn_in_driver("early cancellation drop correctness", async {
            let (mut a, _b) = Channel::<[u8]>::create();

            assert_strong_count(&read_and_drop(&mut a, CurrentDispatcher).await, 1);
        });
    }

    /// On an unsynchronized dispatcher cancelation is asynchronous, so the runtime's reference
    /// is released by the callback; once the driver has finished, no references may remain.
    #[test]
    fn unsynchronized_early_cancel_state_drops_correctly() {
        // the channel needs to outlive the dispatcher for this test because the channel shouldn't
        // be closed before the read wait has been cancelled.
        // NOTE(review): `a` is moved into the `async move` block below and is not returned, so
        // it is dropped when that block completes; closing of the handle is then deferred via
        // `set_channel_dropped` if a callback is still pending. Confirm this comment's intent.
        let (mut a, _b) = Channel::<[u8]>::create();
        let unsync_op =
            spawn_in_driver_etc("early cancellation drop correctness", false, true, async move {
                // We send the arc out to be checked after the dispatcher has shut down so
                // that we can be sure that the callback has had a chance to be called.
                // NOTE(review): only the weak op reference is returned here; the channel is
                // not sent back out.
                read_and_drop(&mut a, CurrentDispatcher).await
            });

        // check that there are no more owners of the inner op for the unsynchronized dispatcher.
        assert_strong_count(&unsync_op, 0);
    }

    /// Repeatedly creating, polling, and dropping read futures on an unsynchronized dispatcher
    /// (so each drop leaves an asynchronous cancelation in flight for the next future to
    /// encounter) must not panic or misbehave.
    #[test]
    fn unsynchronized_early_cancel_state_drops_repeatedly_correctly() {
        // the channel needs to outlive the dispatcher for this test because the channel shouldn't
        // be closed before the read wait has been cancelled.
        // NOTE(review): as above, `a` is moved into the `async move` block; confirm intent.
        let (mut a, _b) = Channel::<[u8]>::create();
        spawn_in_driver_etc("early cancellation drop correctness", false, true, async move {
            for _ in 0..10000 {
                let mut fut = unsafe { read_raw(&mut a, CurrentDispatcher) };
                let Poll::Pending = futures::poll!(&mut fut) else {
                    panic!("expected pending state after polling channel read once");
                };
                drop(fut);
            }
        });
    }
}