libasync/dispatcher/
after_deadline.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::pin::Pin;
6use std::ptr::NonNull;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicI32, Ordering};
9use std::task::{Context, Poll};
10
11use libasync_sys::{async_cancel_task, async_dispatcher, async_post_task, async_task};
12
13use futures::task::AtomicWaker;
14use zx::Status;
15use zx::sys::{ZX_ERR_CANCELED, ZX_OK};
16
17use crate::callback_state::CallbackSharedState;
18use crate::dispatcher::{AsyncDispatcher, OnDispatcher};
19
20type SharedState = CallbackSharedState<async_task, AfterDeadlineState>;
21
22struct AfterDeadlineState {
23    async_dispatcher: NonNull<async_dispatcher>,
24    waker: AtomicWaker,
25    /// The status will initially be [`Status::SHOULD_WAIT`]. Once fired it will be the status
26    /// returned by the callback.
27    status: AtomicI32,
28}
29
30// SAFETY: All fields in AfterDeadlineState are either atomic or immutable.
31unsafe impl Send for AfterDeadlineState {}
32unsafe impl Sync for AfterDeadlineState {}
33
34impl AfterDeadlineState {
35    extern "C" fn call(_dispatcher: *mut async_dispatcher, task: *mut async_task, status: i32) {
36        debug_assert!(
37            status == ZX_OK || status == ZX_ERR_CANCELED,
38            "task callback called with status other than ok or canceled"
39        );
40        // SAFETY: This callback's copy of the `async_task` object was refcounted for when we
41        // started the wait.
42        let state = unsafe { SharedState::from_raw_ptr(task) };
43        state.status.store(status, Ordering::Relaxed);
44        state.waker.wake();
45    }
46}
47
48/// A future that represents a deferral to a future time.
49///
50/// See [`OnDispatcher::after_deadline`] for more information.
51pub struct AfterDeadline<D: OnDispatcher> {
52    dispatcher: D,
53    state: Option<Arc<SharedState>>,
54    deadline: zx::MonotonicInstant,
55}
56
57impl<D: OnDispatcher + Clone> AfterDeadline<D> {
58    pub(super) fn new(dispatcher: &D, deadline: zx::MonotonicInstant) -> Self {
59        let dispatcher = dispatcher.clone();
60        let state = None;
61        Self { dispatcher, state, deadline }
62    }
63}
64
65impl<D: OnDispatcher + Unpin> Future for AfterDeadline<D> {
66    type Output = Result<(), Status>;
67
68    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
69        // if we've already spawned a task then return based on the task's state.
70        if let Some(state) = &self.state {
71            let status = state.status.load(Ordering::Relaxed);
72            if status != Status::SHOULD_WAIT.into_raw() {
73                return Poll::Ready(Status::ok(status));
74            } else {
75                state.waker.register(cx.waker());
76                return Poll::Pending;
77            }
78        }
79
80        let deadline = self.deadline;
81
82        let now = self.dispatcher.on_maybe_dispatcher(|dispatcher| Ok(dispatcher.now()));
83        match now {
84            Ok(now) if deadline < now => {
85                return Poll::Ready(Ok(()));
86            }
87            Err(err) => {
88                return Poll::Ready(Err(err));
89            }
90            _ => {}
91        }
92
93        // otherwise we want to wait for a callback
94        let state = self.dispatcher.on_maybe_dispatcher(move |dispatcher| {
95            // SAFETY: the fdf dispatcher is valid by construction and can provide an async dispatcher.
96            let async_dispatcher = dispatcher.inner();
97
98            let task = async_task {
99                handler: Some(AfterDeadlineState::call),
100                deadline: deadline.into_nanos(),
101                ..Default::default()
102            };
103            let state = AfterDeadlineState {
104                async_dispatcher,
105                waker: AtomicWaker::new(),
106                status: AtomicI32::new(Status::SHOULD_WAIT.into_raw()),
107            };
108            let state = SharedState::new(task, state);
109            state.waker.register(cx.waker());
110
111            let state_ptr = SharedState::make_raw_ptr(state.clone());
112
113            // SAFETY: We know the `async_dispatcher` is valid because we're running inside
114            // `on_dispatcher` and we are giving ownership of the shared state object to the
115            // callback.
116            let res = Status::ok(unsafe { async_post_task(async_dispatcher.as_ptr(), state_ptr) });
117            match res {
118                Ok(_) => Ok(state),
119                Err(err) => {
120                    // SAFETY: Posting the task failed, so we now have an outstanding reference to
121                    // the state object that will never have a callback called on it.
122                    unsafe { SharedState::release_raw_ptr(state_ptr) };
123                    Err(err)
124                }
125            }
126        });
127
128        match state {
129            Ok(state) => {
130                self.state = Some(state);
131                Poll::Pending
132            }
133            Err(err) => Poll::Ready(Err(err)),
134        }
135    }
136}
137
138impl<D: OnDispatcher> Drop for AfterDeadline<D> {
139    fn drop(&mut self) {
140        let Some(state) = self.state.take() else {
141            // if we never spawned a task we can just return.
142            return;
143        };
144        self.dispatcher.on_dispatcher(|dispatcher| {
145            let Some(dispatcher) = dispatcher else {
146                // if the dispatcher is no longer alive then the callback will have been
147                // called with ZX_ERR_CANCELED and we can assume that freed the callback's
148                // Arc.
149                return;
150            };
151            if state.status.load(Ordering::Relaxed) != Status::SHOULD_WAIT.into_raw() {
152                // the callback has been called so we don't even need to try to cancel it.
153                return;
154            }
155            let async_dispatcher = dispatcher.inner();
156            if async_dispatcher != state.async_dispatcher {
157                panic!("Dropping a pending `AfterDeadline` future from a different dispatcher than the one it was awaited on.");
158            }
159            let state_ptr = SharedState::as_raw_ptr(&state);
160            // SAFETY: We know that the current async dispatcher is valid because we are running
161            // inside `on_dispatcher`, and we know the `state_ptr` is valid because the `Arc`
162            // holding it is still held.
163            let status = unsafe { async_cancel_task(async_dispatcher.as_ptr(), state_ptr) };
164            if Status::from_raw(status) == Status::OK {
165                // SAFETY: If the cancellation was successful, we know the callback won't be called
166                // so we need to deallocate the copy of the arc that was given to it.
167                unsafe { SharedState::release_raw_ptr(state_ptr) };
168            }
169        });
170    }
171}
172
#[cfg(all(not_yet, test))]
mod tests {
    use std::sync::mpsc;
    use std::thread::sleep;
    use std::time::Duration;

    use super::*;

    use futures::task::noop_waker;
    use futures::{FutureExt, poll};

    use crate::dispatcher::tests::with_raw_dispatcher;
    use crate::dispatcher::{CurrentDispatcher, OnDispatcher};

    /// A deadline infinitely far in the past resolves on the very first poll.
    #[test]
    fn after_the_past() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (done_tx, done_rx) = mpsc::channel();
            let task = async move {
                let fut = CurrentDispatcher.after_deadline(zx::MonotonicInstant::INFINITE_PAST);
                assert_eq!(poll!(fut), Poll::Ready(Ok(())));
                done_tx.send(()).unwrap();
            };
            dispatcher.spawn_task(task).unwrap();
            done_rx.recv().unwrap();
        });
    }

    /// A deadline taken from the dispatcher's current time resolves on the first poll.
    #[test]
    fn after_now() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (done_tx, done_rx) = mpsc::channel();
            let task = async move {
                let fut = CurrentDispatcher.after_deadline(CurrentDispatcher.now().unwrap());
                assert_eq!(poll!(fut), Poll::Ready(Ok(())));
                done_tx.send(()).unwrap();
            };
            dispatcher.spawn_task(task).unwrap();
            done_rx.recv().unwrap();
        });
    }

    /// A deadline a few seconds out is pending at first and completes no earlier than the
    /// deadline.
    #[test]
    fn after_future() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (done_tx, done_rx) = mpsc::channel();
            let task = async move {
                let deadline =
                    CurrentDispatcher.now().unwrap() + zx::MonotonicDuration::from_seconds(3);
                let mut fut = CurrentDispatcher.after_deadline(deadline);
                assert_eq!(poll!(&mut fut), Poll::Pending);
                assert!(fut.await.is_ok());
                assert!(CurrentDispatcher.now().unwrap() >= deadline);
                done_tx.send(()).unwrap();
            };
            dispatcher.spawn_task(task).unwrap();
            done_rx.recv().unwrap();
        });
    }

    /// Dropping a future that has been polled but has not fired must not panic.
    #[test]
    fn drop_after_poll() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (done_tx, done_rx) = mpsc::channel();
            let task = async move {
                let deadline =
                    CurrentDispatcher.now().unwrap() + zx::MonotonicDuration::from_minutes(3);
                let mut fut = CurrentDispatcher.after_deadline(deadline);
                assert_eq!(poll!(&mut fut), Poll::Pending);
                done_tx.send(()).unwrap();
            };
            dispatcher.spawn_task(task).unwrap();
            done_rx.recv().unwrap();
        });
    }

    /// A pending future handed out of the dispatcher eventually resolves with `CANCELED` once
    /// `with_raw_dispatcher` has returned.
    #[test]
    fn dispatcher_shutdown_cancel() {
        let (fut_tx, fut_rx) = mpsc::channel();
        with_raw_dispatcher("testing task", |dispatcher| {
            let (done_tx, done_rx) = mpsc::channel();
            let task = async move {
                let deadline =
                    CurrentDispatcher.now().unwrap() + zx::MonotonicDuration::from_minutes(3);
                let mut fut = CurrentDispatcher.after_deadline(deadline);
                assert_eq!(poll!(&mut fut), Poll::Pending);
                fut_tx.send(fut).unwrap();
                done_tx.send(()).unwrap();
            };
            dispatcher.spawn_task(task).unwrap();
            done_rx.recv().unwrap();
        });

        let mut fut = fut_rx.recv().unwrap();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        // Nothing will wake us, so poll on a short interval until the future resolves.
        let res = loop {
            match fut.poll_unpin(&mut cx) {
                Poll::Ready(res) => break res,
                Poll::Pending => sleep(Duration::from_millis(10)),
            }
        };
        assert_eq!(res, Err(Status::CANCELED));
    }
}