fdf_channel/futures.rs
1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Internal helpers for implementing futures against channel objects
6
7use std::mem::ManuallyDrop;
8use std::task::Waker;
9use zx::Status;
10
11use crate::channel::{Channel, try_read_raw};
12use crate::message::Message;
13use fdf_core::dispatcher::{DispatcherRef, OnDispatcher};
14use fdf_core::handle::DriverHandle;
15use fdf_sys::*;
16
17use core::mem::MaybeUninit;
18use core::task::{Context, Poll};
19use fuchsia_sync::Mutex;
20use std::sync::Arc;
21
22pub use fdf_sys::fdf_handle_t;
23
/// State for a channel read operation that is protected by the lock in [`ReadMessageStateOp`].
#[derive(Default, Debug)]
struct ReadMessageStateOpLocked {
    /// The currently active waker for this read operation. Only set if there is currently a
    /// pending read operation awaiting a callback. Its presence is also what the future's `Drop`
    /// impl and `set_channel_dropped` use to decide whether a wait is still outstanding.
    waker: Option<Waker>,
    /// Set if the channel was dropped while a callback was still pending, in which case the
    /// callback is responsible for closing the driver handle when it fires.
    channel_dropped: bool,
    /// Whether cancelation of this future will happen asynchronously through the callback or
    /// immediately when [`fdf_channel_cancel_wait`] is called. This is used to decide what's
    /// responsible for freeing the runtime's reference to this object when the future is
    /// canceled.
    cancelation_is_async: bool,
}
39
/// The state of an asynchronous channel read wait, shared between the read future(s) for a
/// channel and the driver runtime.
///
/// The first field (`read_op`) is the `fdf_channel_read` handed to the driver runtime by
/// `fdf_channel_wait_async`. The second (`state`) is locked both by the future side and by the
/// runtime callback (`ReadMessageStateOp::handler`).
///
/// The object is reference counted with [`Arc`]:
/// - the [`Channel`]'s `wait_state` holds one reference;
/// - each live [`ReadMessageState`] holds one;
/// - one more is leaked to the runtime (via [`Arc::into_raw`]) for as long as a wait is
///   registered.
///
/// A [`ReadMessageState`]'s reference is dropped when the future is either fulfilled or cancelled
/// through normal [`Drop`].
///
/// The runtime's reference is normally released by the callback, which is called *at most* once
/// per registered wait. The exception is a wait that is cancelled synchronously. That happens
/// when the wait was registered from the synchronized dispatcher it waits on, so
/// `FDF_CHANNEL_WAIT_OPTION_FORCE_ASYNC_CANCEL` was not requested. In that case a successful
/// `fdf_channel_cancel_wait` means the callback will never run, and the future's [`Drop`] impl
/// releases the runtime's reference instead. In every other case cancellation is delivered
/// through the callback.
#[repr(C)]
#[derive(Debug)]
pub(crate) struct ReadMessageStateOp {
    /// This must be the first field, and the struct must stay `repr(C)`, so that a pointer to a
    /// `ReadMessageStateOp` can be cast to and from a pointer to an `fdf_channel_read`.
    read_op: fdf_channel_read,
    /// Bookkeeping shared between the future and the runtime callback.
    state: Mutex<ReadMessageStateOpLocked>,
}
60
61impl ReadMessageStateOp {
62 unsafe extern "C" fn handler(
63 _dispatcher: *mut fdf_dispatcher,
64 read_op: *mut fdf_channel_read,
65 _status: i32,
66 ) {
67 // Note: we don't really do anything different based on whether the callback
68 // says canceled. If the future was canceled by being dropped, it won't poll
69 // again since it was dropped.
70 // The only unusual case is when the dispatcher is shutting down, and in that
71 // case we will wake the future and it will try to read and get a more useful
72 // error.
73 // Meanwhile, since we use the same state object across multiple
74 // futures due to needing to handle async cancelation, trying to track the
75 // underlying reason for the cancelation becomes more tricky than it's worth.
76
77 // SAFETY: When setting up the read op, we incremented the refcount of the `Arc` to allow
78 // for this handler to reconstitute it.
79 let op: Arc<Self> = unsafe { Arc::from_raw(read_op.cast()) };
80
81 let mut state = op.state.lock();
82 if state.channel_dropped {
83 // SAFETY: since the channel dropped we are the only outstanding owner of the
84 // channel object.
85 unsafe { fdf_handle_close(op.read_op.channel) };
86 }
87 let Some(waker) = state.waker.take() else {
88 // the waker was already taken, presumably because the future was dropped.
89 return;
90 };
91 // make sure to drop the lock before calling the waker.
92 drop(state);
93 waker.wake()
94 }
95
96 /// Called by the channel on drop to indicate that the channel has been dropped and
97 /// find out whether it needs to defer dropping the handle until the callback is called.
98 pub fn set_channel_dropped(&self) -> bool {
99 let mut state = self.state.lock();
100 if state.waker.is_some() {
101 state.channel_dropped = true;
102 false
103 } else {
104 true
105 }
106 }
107}
108
/// An object for managing the state of an async channel read message operation that can be used to
/// implement futures.
pub struct ReadMessageState {
    /// Shared wait state. This is a clone of the `Arc` stored in the channel's `wait_state`.
    op: Arc<ReadMessageStateOp>,
    /// A non-owning copy of the channel's driver handle. It is wrapped in [`ManuallyDrop`] so
    /// that dropping this object never closes the handle, which would otherwise be
    /// double-dropped.
    channel: ManuallyDrop<DriverHandle>,
}
115
116impl ReadMessageState {
117 /// Creates a new raw read message state that can be used to implement a [`Future`] that reads
118 /// data from a channel and then converts it to the appropriate type. It also allows for
119 /// different ways of storing and managing the dispatcher we wait on by deferring the
120 /// dispatcher used to poll time. This state is registered with the given [`Channel`]
121 /// so that dropping the channel will correctly free resources.
122 ///
123 /// # Safety
124 ///
125 /// The caller is responsible for ensuring that the handle inside `channel` outlives this
126 /// object.
127 pub unsafe fn register_read_wait<T: ?Sized>(channel: &mut Channel<T>) -> Self {
128 // SAFETY: The caller is responsible for ensuring that the handle is a correct channel handle
129 // and that the handle will outlive the created [`ReadMessageState`].
130 let channel_handle = unsafe { channel.handle.get_raw() };
131 let op = channel
132 .wait_state
133 .get_or_insert_with(|| {
134 Arc::new(ReadMessageStateOp {
135 read_op: fdf_channel_read {
136 channel: channel_handle.get(),
137 handler: Some(ReadMessageStateOp::handler),
138 ..Default::default()
139 },
140 state: Mutex::new(ReadMessageStateOpLocked::default()),
141 })
142 })
143 .clone();
144 Self {
145 op,
146 // SAFETY: We know this is a valid driver handle by construction and we are
147 // storing this handle in a [`ManuallyDrop`] to prevent it from being double-dropped.
148 // The caller is responsible for ensuring that the handle outlives this object.
149 channel: ManuallyDrop::new(unsafe { DriverHandle::new_unchecked(channel_handle) }),
150 }
151 }
152
153 /// Polls this channel read operation against the given dispatcher.
154 #[expect(clippy::type_complexity)]
155 pub fn poll_with_dispatcher<D: OnDispatcher>(
156 &mut self,
157 cx: &mut Context<'_>,
158 dispatcher: D,
159 ) -> Poll<Result<Option<Message<[MaybeUninit<u8>]>>, Status>> {
160 let mut state = self.op.state.lock();
161
162 match try_read_raw(&self.channel) {
163 Ok(res) => Poll::Ready(Ok(res)),
164 Err(Status::SHOULD_WAIT) => {
165 // if we haven't yet set a waker, that means we haven't started the wait operation
166 // yet.
167 if state.waker.is_none() {
168 // increment the reference count of the read op to account for the copy that will be given to
169 // `fdf_channel_wait_async`.
170 let op = Arc::into_raw(self.op.clone());
171 let res = dispatcher.on_maybe_dispatcher(|dispatcher| {
172 let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
173 // if we're not running on the same dispatcher as we're waiting from, we
174 // want to force async cancellation
175 let options = if !dispatcher.is_current_dispatcher() {
176 FDF_CHANNEL_WAIT_OPTION_FORCE_ASYNC_CANCEL
177 } else {
178 0
179 };
180 // SAFETY: the `ReadMessageStateOp` starts with an `fdf_channel_read` struct and
181 // has `repr(C)` layout, so is safe to be cast to the latter.
182 let res = Status::ok(unsafe {
183 fdf_channel_wait_async(
184 dispatcher.inner().as_ptr(),
185 op.cast_mut().cast(),
186 options,
187 )
188 });
189 if res.is_ok() {
190 // only replace the waker if we succeeded, so we'll try again next time
191 // otherwise.
192 state.waker.replace(cx.waker().clone());
193 } else {
194 // reconstitute the arc we made for the callback so it can be dropped
195 // since the async wait didn't succeed.
196 drop(unsafe { Arc::from_raw(op) });
197 }
198 // if the dispatcher we're waiting on is unsynchronized, the callback
199 // will drop the Arc and we need to indicate to our own Drop impl
200 // that it should not.
201 res.map(|_| {
202 options == FDF_CHANNEL_WAIT_OPTION_FORCE_ASYNC_CANCEL
203 || dispatcher.is_unsynchronized()
204 })
205 });
206
207 // the default state should be that `drop` will free the arc.
208 state.cancelation_is_async = false;
209 match res {
210 Err(Status::BAD_STATE) => {
211 return Poll::Pending; // a pending await is being cancelled
212 }
213 Ok(cancelation_is_async) => {
214 state.cancelation_is_async = cancelation_is_async;
215 }
216 Err(e) => return Poll::Ready(Err(e)),
217 }
218 }
219 Poll::Pending
220 }
221 Err(e) => Poll::Ready(Err(e)),
222 }
223 }
224}
225
226impl Drop for ReadMessageState {
227 fn drop(&mut self) {
228 let mut state = self.op.state.lock();
229 if state.waker.is_none() {
230 // if there's no waker either the callback has already fired or we never waited on this
231 // future in the first place, so just leave it be.
232 return;
233 }
234
235 // SAFETY: since we hold a lifetimed-reference to the channel object here, the channel must
236 // be valid.
237 let res = Status::ok(unsafe { fdf_channel_cancel_wait(self.channel.get_raw().get()) });
238 match res {
239 Ok(_) => {}
240 Err(Status::NOT_FOUND) => {
241 // the callback is already being called or the wait was already cancelled, so just
242 // return and leave it.
243 return;
244 }
245 Err(e) => panic!("Unexpected error {e:?} cancelling driver channel read wait"),
246 }
247 // SAFETY: if the channel was waited on by a synchronized dispatcher, and the cancel was
248 // successful, the callback will not be called and we will have to free the `Arc` that the
249 // callback would have consumed.
250 if !state.cancelation_is_async {
251 // steal the waker so it doesn't get called, if there is one.
252 state.waker.take();
253 unsafe { Arc::decrement_strong_count(Arc::as_ptr(&self.op)) };
254 }
255 }
256}
257
258#[cfg(test)]
259mod test {
260 use std::pin::pin;
261 use std::sync::Weak;
262
263 use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
264 use fdf_env::test::{spawn_in_driver, spawn_in_driver_etc};
265
266 use crate::arena::Arena;
267 use crate::channel::{Channel, read_raw};
268
269 use super::*;
270
/// Asserts that the `Arc` behind `arc` currently has exactly `count` strong references,
/// reporting the failure at the caller's location.
#[track_caller]
fn assert_strong_count<T>(arc: &Weak<T>, count: usize) {
    let actual = arc.strong_count();
    assert_eq!(actual, count, "unexpected strong count on arc");
}
276
/// Creates and polls a read future for `channel` on `dispatcher`, then drops it when this
/// function returns. The internal op `Arc`'s strong count is verified at each step.
///
/// Returns a [`Weak`] reference to the op (not a strong copy), so the caller can verify that the
/// count goes back down correctly after the drop.
async fn read_and_drop<T: ?Sized + 'static, D: OnDispatcher + Unpin>(
    channel: &mut Channel<T>,
    dispatcher: D,
) -> Weak<ReadMessageStateOp> {
    // NOTE(review): `read_raw` is `unsafe`; presumably (like `register_read_wait`) it requires the
    // channel handle to outlive the future, which holds here since the future is dropped before
    // this function returns. Confirm against `read_raw`'s contract.
    let fut = unsafe { read_raw(channel, dispatcher) };
    let op_arc = Arc::downgrade(&fut.raw_fut.op);
    // One reference in the channel's `wait_state`, one in the future.
    assert_strong_count(&op_arc, 2);
    let mut fut = pin!(fut);
    let Poll::Pending = futures::poll!(fut.as_mut()) else {
        panic!("expected pending state after polling channel read once");
    };
    // Polling registered a wait, which leaked one extra reference to the runtime for the
    // callback.
    assert_strong_count(&op_arc, 3);
    op_arc
}
295
/// A read wait that is cancelled by dropping its future must not prevent a later read on the
/// same channel from completing.
#[test]
fn early_cancel_future() {
    spawn_in_driver("early cancellation", async {
        let (mut a, b) = Channel::create();

        // create, poll, and then immediately drop a read future for channel `a`
        // so that it properly sets up the wait.
        read_and_drop(&mut a, CurrentDispatcher).await;
        // A fresh read on `a` must still see a message written after the cancelled wait.
        b.write_with_data(Arena::new(), |arena| arena.insert(1)).unwrap();
        assert_eq!(a.read(CurrentDispatcher).await.unwrap().unwrap().data(), Some(&1));
    })
}
308
/// Dropping a read future that was never polled must release only the future's own reference
/// to the shared op.
#[test]
fn very_early_cancel_state_drops_correctly() {
    spawn_in_driver("early cancellation drop correctness", async {
        let (mut a, _b) = Channel::<[u8]>::create();

        // drop before even polling it should drop the arc correctly
        let fut = unsafe { read_raw(&mut a, CurrentDispatcher) };
        let op_arc = Arc::downgrade(&fut.raw_fut.op);
        // The channel's `wait_state` and the future each hold a reference.
        assert_strong_count(&op_arc, 2);
        drop(fut);
        // No wait was ever registered, so only the channel's reference remains.
        assert_strong_count(&op_arc, 1);
    })
}
322
/// Dropping a polled read future on the current dispatcher must release both the future's
/// reference and the runtime's reference. This assumes `spawn_in_driver` runs on a
/// synchronized dispatcher, as the test name suggests (TODO confirm).
#[test]
fn synchronized_early_cancel_state_drops_correctly() {
    spawn_in_driver("early cancellation drop correctness", async {
        let (mut a, _b) = Channel::<[u8]>::create();

        // Only the channel's `wait_state` reference should be left after the drop.
        assert_strong_count(&read_and_drop(&mut a, CurrentDispatcher).await, 1);
    });
}
331
/// Dropping a polled read future on an unsynchronized dispatcher (presumably what the
/// `false, true` arguments to `spawn_in_driver_etc` select, TODO confirm) must eventually
/// release every reference to the shared op once the asynchronous cancelation callback has run.
#[test]
fn unsynchronized_early_cancel_state_drops_correctly() {
    // the channel needs to outlive the dispatcher for this test because the channel shouldn't
    // be closed before the read wait has been cancelled.
    // NOTE(review): `a` is moved into the `async move` block below, so it is dropped when that
    // block finishes rather than after the dispatcher shuts down, and it is never sent back out.
    // Confirm this matches the intent described above.
    let (mut a, _b) = Channel::<[u8]>::create();
    let unsync_op =
        spawn_in_driver_etc("early cancellation drop correctness", false, true, async move {
            // We send the arc out to be checked after the dispatcher has shut down so
            // that we can be sure that the callback has had a chance to be called.
            // We send the channel back out so that it lives long enough for the
            // cancellation to be called on it.
            read_and_drop(&mut a, CurrentDispatcher).await
        });

    // check that there are no more owners of the inner op for the unsynchronized dispatcher.
    assert_strong_count(&unsync_op, 0);
}
349
/// Repeatedly creating, polling, and dropping read futures on the same channel with an
/// unsynchronized dispatcher must keep working. This exercises the reuse of the shared op while
/// earlier asynchronous cancelations may still be in flight.
#[test]
fn unsynchronized_early_cancel_state_drops_repeatedly_correctly() {
    // the channel needs to outlive the dispatcher for this test because the channel shouldn't
    // be closed before the read wait has been cancelled.
    // NOTE(review): `a` is moved into the `async move` block below, so it is dropped when that
    // block finishes rather than after the dispatcher shuts down. Confirm this matches the intent
    // described above.
    let (mut a, _b) = Channel::<[u8]>::create();
    spawn_in_driver_etc("early cancellation drop correctness", false, true, async move {
        for _ in 0..10000 {
            let mut fut = unsafe { read_raw(&mut a, CurrentDispatcher) };
            // Every iteration must report pending (nothing was written), regardless of whether
            // the previous iteration's wait is still being cancelled.
            let Poll::Pending = futures::poll!(&mut fut) else {
                panic!("expected pending state after polling channel read once");
            };
            drop(fut);
        }
    });
}
365}