Skip to main content

libasync/dispatcher/
task.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Safe bindings for the C libasync async dispatcher library
6
7use zx::sys::ZX_OK;
8
9use core::task::Context;
10use fuchsia_sync::Mutex;
11use std::pin::Pin;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Arc, mpsc};
14use std::task::{Poll, Wake, Waker};
15
16use zx::Status;
17
18use futures::future::{BoxFuture, FutureExt};
19use futures::task::AtomicWaker;
20
21use crate::{AsyncDispatcher, OnDispatcher};
22
/// The future returned by [`OnDispatcher::compute`] or [`OnDispatcher::try_compute`]. If this is
/// dropped without having been detached, the task will be cancelled.
///
/// Awaiting it yields `Ok` with the task's output, or `Err` with a [`Status`] when the task did
/// not produce one ([`Status::CANCELED`] if it was aborted).
#[must_use]
#[derive(Debug)]
pub struct Task<T> {
    /// State shared with the dispatcher side of the task: the waker to notify once a result has
    /// been sent, and the flag used to request that the task be aborted.
    state: Arc<TaskFutureState>,
    /// Receiving end of the channel on which the task delivers its result.
    result_receiver: mpsc::Receiver<Result<T, Status>>,
    /// When `true`, dropping this handle leaves the task running instead of aborting it.
    detached: bool,
}
32
33impl<T: Send + 'static> Task<T> {
34    fn new<D: OnDispatcher + 'static>(
35        future: impl Future<Output = T> + Send + 'static,
36        dispatcher: D,
37    ) -> (Self, Arc<TaskWakerState<T, D>>) {
38        let future_state = Arc::new(TaskFutureState {
39            waker: AtomicWaker::new(),
40            aborted: AtomicBool::new(false),
41        });
42        let (result_sender, result_receiver) = mpsc::sync_channel(1);
43        let state = Arc::new(TaskWakerState {
44            result_sender,
45            future_state: future_state.clone(),
46            future: Mutex::new(Some(future.boxed())),
47            dispatcher,
48        });
49        let future = Task { state: future_state, result_receiver, detached: false };
50        (future, state)
51    }
52
53    pub(crate) fn try_start<D: OnDispatcher + 'static>(
54        future: impl Future<Output = T> + Send + 'static,
55        dispatcher: D,
56    ) -> Result<Self, Status> {
57        let (future, state) = Self::new(future, dispatcher);
58        state.queue().map(|_| future)
59    }
60
61    pub(crate) fn start<D: OnDispatcher + 'static>(
62        future: impl Future<Output = T> + Send + 'static,
63        dispatcher: D,
64    ) -> Self {
65        let (future, state) = Self::new(future, dispatcher);
66
67        // try to queue the task and if it fails short circuit the delivery of failure to the
68        // caller.
69        if let Err(err) = state.queue() {
70            // drop the future we were given
71            drop(state.future.lock().take());
72            // send the error to the result receiver. This should never fail, since
73            // we just created both ends and the task queuing failed.
74            state.result_sender.try_send(Err(err)).unwrap();
75        }
76
77        future
78    }
79}
80
81impl<T> Task<T> {
82    /// Detaches this future from the task so that it will continue executing without waiting
83    /// on the future. If this is not called, and the future is dropped, the task will be aborted
84    /// the next time it is awoken.
85    pub fn detach(self) {
86        drop(self.detach_on_drop());
87    }
88
89    /// Detaches this future from the task so that it will continue executing without waiting
90    /// on the future. If this is not called, and the future is dropped, the task will be aborted
91    /// the next time it is awoken.
92    ///
93    /// Returns a future that can be awaited on or dropped without affecting the task.
94    pub fn detach_on_drop(mut self) -> JoinHandle<T> {
95        self.detached = true;
96        JoinHandle(self)
97    }
98
99    /// Aborts the task and returns a future that can be used to wait for the task to either
100    /// complete or cancel. If the task was canceled the result of the future will be
101    /// [`Status::CANCELED`].
102    pub fn abort(&self) {
103        self.state.aborted.store(true, Ordering::Relaxed);
104    }
105}
106
107impl<T> Drop for Task<T> {
108    fn drop(&mut self) {
109        if !self.detached {
110            self.state.aborted.store(true, Ordering::Relaxed);
111        }
112    }
113}
114
/// State shared between a [`Task`] handle and the dispatcher side of the task.
#[derive(Debug)]
struct TaskFutureState {
    /// Waker registered by whoever polls the [`Task`]; woken after a result has been sent.
    waker: AtomicWaker,
    /// Set to request that the task be aborted; checked each time the task runs on the
    /// dispatcher.
    aborted: AtomicBool,
}
120
121impl<T> Future for Task<T> {
122    type Output = Result<T, Status>;
123
124    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
125        use std::sync::mpsc::TryRecvError;
126        self.state.waker.register(cx.waker());
127        match self.result_receiver.try_recv() {
128            Ok(res) => Poll::Ready(res),
129            Err(TryRecvError::Disconnected) => Poll::Ready(Err(Status::CANCELED)),
130            Err(TryRecvError::Empty) => Poll::Pending,
131        }
132    }
133}
134
/// A handle for a task that will detach on drop. Returned by [`OnDispatcher::spawn`].
///
/// Awaiting it resolves to the same result as the wrapped [`Task`]; dropping it leaves the task
/// running.
#[derive(Debug)]
pub struct JoinHandle<T>(Task<T>);
138
139impl<T> JoinHandle<T> {
140    /// Aborts the task and returns a future that can be used to wait for the task to either
141    /// complete or cancel. If the task was canceled the result of the future will be
142    /// [`Status::CANCELED`].
143    pub fn abort(&self) {
144        self.0.abort()
145    }
146}
147
148impl<T> Future for JoinHandle<T> {
149    type Output = Result<T, Status>;
150
151    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152        self.0.poll_unpin(cx)
153    }
154}
155
/// Dispatcher-side state of a task. It is shared (via `Arc`) between the closures posted to the
/// dispatcher and the wakers handed to the task's future.
struct TaskWakerState<T, D> {
    /// Sending end of the result channel read by the [`Task`] handle.
    result_sender: mpsc::SyncSender<Result<T, Status>>,
    /// State shared with the [`Task`] handle: its waker and the abort flag.
    future_state: Arc<TaskFutureState>,
    /// The future being driven; `None` once it has completed or been cancelled.
    future: Mutex<Option<BoxFuture<'static, T>>>,
    /// Where polls of the future get posted.
    dispatcher: D,
}
162
163impl<T: Send + 'static, D: OnDispatcher + 'static> Wake for TaskWakerState<T, D> {
164    fn wake(self: Arc<Self>) {
165        self.wake_by_ref();
166    }
167    fn wake_by_ref(self: &Arc<Self>) {
168        match self.queue() {
169            Err(e) if e == Status::BAD_STATE => {
170                // the dispatcher is shutting down so drop the future, if there
171                // is one, to cancel it.
172                let future_slot = self.future.lock().take();
173                drop(future_slot);
174                self.send_result(Err(e));
175            }
176            res => res.expect("Unexpected error waking dispatcher task"),
177        }
178    }
179}
180
181impl<T: Send + 'static, D: OnDispatcher + 'static> TaskWakerState<T, D> {
182    /// Sends the result to the future end of this task, if it still exists.
183    fn send_result(&self, res: Result<T, Status>) {
184        // send the result and wake the waker if any has been registered.
185        // We ignore the result here because if the other end has dropped it's
186        // fine for the result to go nowhere.
187        self.result_sender.try_send(res).ok();
188        self.future_state.waker.wake();
189    }
190
191    /// Posts a task to progress the currently stored future. The task will
192    /// consume the future if the future is ready after the next poll.
193    /// Otherwise, the future is kept to be polled again after being woken.
194    pub(crate) fn queue(self: &Arc<Self>) -> Result<(), Status> {
195        let arc_self = self.clone();
196        self.dispatcher.on_maybe_dispatcher(move |dispatcher| {
197            dispatcher
198                .post_task_sync(move |status| {
199                    let mut future_slot = arc_self.future.lock();
200                    // if the executor is shutting down, drop the future we're waiting on and pass
201                    // on the error.
202                    if status != Status::from_raw(ZX_OK) {
203                        drop(future_slot.take());
204                        arc_self.send_result(Err(status));
205                        return;
206                    }
207
208                    // if the future has been dropped without being detached, drop the future and
209                    // send an Err(Status::CANCELED) if the caller is still listening.
210                    if arc_self.future_state.aborted.load(Ordering::Relaxed) {
211                        drop(future_slot.take());
212                        arc_self.send_result(Err(Status::CANCELED));
213                        return;
214                    }
215
216                    let Some(mut future) = future_slot.take() else {
217                        return;
218                    };
219                    let waker = Waker::from(arc_self.clone());
220                    let context = &mut Context::from_waker(&waker);
221                    match future.as_mut().poll(context) {
222                        Poll::Pending => *future_slot = Some(future),
223                        Poll::Ready(res) => {
224                            arc_self.send_result(Ok(res));
225                        }
226                    }
227                })
228                .map(|_| ())
229        })
230    }
231}