libasync/dispatcher/
task.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Safe bindings for the C libasync async dispatcher library
6
7use zx::sys::ZX_OK;
8
9use core::task::Context;
10use fuchsia_sync::Mutex;
11use std::pin::Pin;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Arc, mpsc};
14use std::task::Poll;
15
16use zx::Status;
17
18use futures::future::{BoxFuture, FutureExt};
19use futures::task::{ArcWake, AtomicWaker, waker_ref};
20
21use crate::{AsyncDispatcher, OnDispatcher};
22
/// The future returned by [`OnDispatcher::compute`] or [`OnDispatcher::try_compute`]. If this is
/// dropped, the task will be cancelled.
///
/// Resolves to `Ok` with the task's output, or to `Err` with the [`Status`] that ended the task.
#[must_use]
pub struct Task<T> {
    /// State shared with the dispatcher side of the task: the waker to notify when a result has
    /// been sent, and the abort flag.
    state: Arc<TaskFutureState>,
    /// Receiving end of the channel the task's result is delivered on.
    result_receiver: mpsc::Receiver<Result<T, Status>>,
    /// When `true`, dropping this value does not flag the task as aborted.
    detached: bool,
}
31
32impl<T: Send + 'static> Task<T> {
33    fn new<D: OnDispatcher + 'static>(
34        future: impl Future<Output = T> + Send + 'static,
35        dispatcher: D,
36    ) -> (Self, Arc<TaskWakerState<T, D>>) {
37        let future_state = Arc::new(TaskFutureState {
38            waker: AtomicWaker::new(),
39            aborted: AtomicBool::new(false),
40        });
41        let (result_sender, result_receiver) = mpsc::sync_channel(1);
42        let state = Arc::new(TaskWakerState {
43            result_sender,
44            future_state: future_state.clone(),
45            future: Mutex::new(Some(future.boxed())),
46            dispatcher,
47        });
48        let future = Task { state: future_state, result_receiver, detached: false };
49        (future, state)
50    }
51
52    pub(crate) fn try_start<D: OnDispatcher + 'static>(
53        future: impl Future<Output = T> + Send + 'static,
54        dispatcher: D,
55    ) -> Result<Self, Status> {
56        let (future, state) = Self::new(future, dispatcher);
57        state.queue().map(|_| future)
58    }
59
60    pub(crate) fn start<D: OnDispatcher + 'static>(
61        future: impl Future<Output = T> + Send + 'static,
62        dispatcher: D,
63    ) -> Self {
64        let (future, state) = Self::new(future, dispatcher);
65
66        // try to queue the task and if it fails short circuit the delivery of failure to the
67        // caller.
68        if let Err(err) = state.queue() {
69            // drop the future we were given
70            drop(state.future.lock().take());
71            // send the error to the result receiver. This should never fail, since
72            // we just created both ends and the task queuing failed.
73            state.result_sender.try_send(Err(err)).unwrap();
74        }
75
76        future
77    }
78}
79
80impl<T> Task<T> {
81    /// Detaches this future from the task so that it will continue executing without waiting
82    /// on the future. If this is not called, and the future is dropped, the task will be aborted
83    /// the next time it is awoken.
84    pub fn detach(self) {
85        drop(self.detach_on_drop());
86    }
87
88    /// Detaches this future from the task so that it will continue executing without waiting
89    /// on the future. If this is not called, and the future is dropped, the task will be aborted
90    /// the next time it is awoken.
91    ///
92    /// Returns a future that can be awaited on or dropped without affecting the task.
93    pub fn detach_on_drop(mut self) -> JoinHandle<T> {
94        self.detached = true;
95        JoinHandle(self)
96    }
97
98    /// Aborts the task and returns a future that can be used to wait for the task to either
99    /// complete or cancel. If the task was canceled the result of the future will be
100    /// [`Status::CANCELED`].
101    pub fn abort(&self) {
102        self.state.aborted.store(true, Ordering::Relaxed);
103    }
104}
105
106impl<T> Drop for Task<T> {
107    fn drop(&mut self) {
108        if !self.detached {
109            self.state.aborted.store(true, Ordering::Relaxed);
110        }
111    }
112}
113
/// State shared between a [`Task`] and the dispatcher-side [`TaskWakerState`].
struct TaskFutureState {
    /// Waker registered when the [`Task`] is polled; woken after a result has been sent.
    waker: AtomicWaker,
    /// Set when the task is aborted or its [`Task`] is dropped without being detached. Read
    /// each time the dispatcher runs the task.
    aborted: AtomicBool,
}
118
119impl<T> Future for Task<T> {
120    type Output = Result<T, Status>;
121
122    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123        match self.result_receiver.try_recv() {
124            Ok(res) => Poll::Ready(res),
125            Err(_) => {
126                self.state.waker.register(cx.waker());
127                Poll::Pending
128            }
129        }
130    }
131}
132
/// A handle for a task that will detach on drop. Returned by [`OnDispatcher::spawn`].
///
/// Wraps a [`Task`] and can be awaited to obtain the same result the wrapped [`Task`] would
/// produce.
pub struct JoinHandle<T>(Task<T>);
135
136impl<T> JoinHandle<T> {
137    /// Aborts the task and returns a future that can be used to wait for the task to either
138    /// complete or cancel. If the task was canceled the result of the future will be
139    /// [`Status::CANCELED`].
140    pub fn abort(&self) {
141        self.0.abort()
142    }
143}
144
145impl<T> Future for JoinHandle<T> {
146    type Output = Result<T, Status>;
147
148    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
149        self.0.poll_unpin(cx)
150    }
151}
152
/// The dispatcher-side state of a task. Shared through an `Arc`, which also serves as the
/// task's waker via its [`ArcWake`] implementation.
struct TaskWakerState<T, D> {
    /// Sending end of the channel on which the task's result is delivered to the [`Task`].
    result_sender: mpsc::SyncSender<Result<T, Status>>,
    /// State shared with the [`Task`]: its registered waker and the abort flag.
    future_state: Arc<TaskFutureState>,
    /// The future being driven. `None` once it has completed or has been dropped.
    future: Mutex<Option<BoxFuture<'static, T>>>,
    /// The dispatcher that work for this task is queued on.
    dispatcher: D,
}
159
160impl<T: Send + 'static, D: OnDispatcher + 'static> ArcWake for TaskWakerState<T, D> {
161    fn wake_by_ref(arc_self: &Arc<Self>) {
162        match arc_self.queue() {
163            Err(e) if e == Status::BAD_STATE => {
164                // the dispatcher is shutting down so drop the future, if there
165                // is one, to cancel it.
166                let future_slot = arc_self.future.lock().take();
167                drop(future_slot);
168                arc_self.send_result(Err(e));
169            }
170            res => res.expect("Unexpected error waking dispatcher task"),
171        }
172    }
173}
174
175impl<T: Send + 'static, D: OnDispatcher + 'static> TaskWakerState<T, D> {
176    /// Sends the result to the future end of this task, if it still exists.
177    fn send_result(&self, res: Result<T, Status>) {
178        // send the result and wake the waker if any has been registered.
179        // We ignore the result here because if the other end has dropped it's
180        // fine for the result to go nowhere.
181        self.result_sender.try_send(res).ok();
182        self.future_state.waker.wake();
183    }
184
185    /// Posts a task to progress the currently stored future. The task will
186    /// consume the future if the future is ready after the next poll.
187    /// Otherwise, the future is kept to be polled again after being woken.
188    pub(crate) fn queue(self: &Arc<Self>) -> Result<(), Status> {
189        let arc_self = self.clone();
190        self.dispatcher.on_maybe_dispatcher(move |dispatcher| {
191            dispatcher
192                .post_task_sync(move |status| {
193                    let mut future_slot = arc_self.future.lock();
194                    // if the executor is shutting down, drop the future we're waiting on and pass
195                    // on the error.
196                    if status != Status::from_raw(ZX_OK) {
197                        drop(future_slot.take());
198                        arc_self.send_result(Err(status));
199                        return;
200                    }
201
202                    // if the future has been dropped without being detached, drop the future and
203                    // send an Err(Status::CANCELED) if the caller is still listening.
204                    if arc_self.future_state.aborted.load(Ordering::Relaxed) {
205                        drop(future_slot.take());
206                        arc_self.send_result(Err(Status::CANCELED));
207                        return;
208                    }
209
210                    let Some(mut future) = future_slot.take() else {
211                        return;
212                    };
213                    let waker = waker_ref(&arc_self);
214                    let context = &mut Context::from_waker(&waker);
215                    match future.as_mut().poll(context) {
216                        Poll::Pending => *future_slot = Some(future),
217                        Poll::Ready(res) => {
218                            arc_self.send_result(Ok(res));
219                        }
220                    }
221                })
222                .map(|_| ())
223        })
224    }
225}