libasync/dispatcher/
task.rs1use zx::sys::ZX_OK;
8
9use core::task::Context;
10use fuchsia_sync::Mutex;
11use std::pin::Pin;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Arc, mpsc};
14use std::task::{Poll, Wake, Waker};
15
16use zx::Status;
17
18use futures::future::{BoxFuture, FutureExt};
19use futures::task::AtomicWaker;
20
21use crate::{AsyncDispatcher, OnDispatcher};
22
23#[must_use]
26#[derive(Debug)]
27pub struct Task<T> {
28 state: Arc<TaskFutureState>,
29 result_receiver: mpsc::Receiver<Result<T, Status>>,
30 detached: bool,
31}
32
33impl<T: Send + 'static> Task<T> {
34 fn new<D: OnDispatcher + 'static>(
35 future: impl Future<Output = T> + Send + 'static,
36 dispatcher: D,
37 ) -> (Self, Arc<TaskWakerState<T, D>>) {
38 let future_state = Arc::new(TaskFutureState {
39 waker: AtomicWaker::new(),
40 aborted: AtomicBool::new(false),
41 });
42 let (result_sender, result_receiver) = mpsc::sync_channel(1);
43 let state = Arc::new(TaskWakerState {
44 result_sender,
45 future_state: future_state.clone(),
46 future: Mutex::new(Some(future.boxed())),
47 dispatcher,
48 });
49 let future = Task { state: future_state, result_receiver, detached: false };
50 (future, state)
51 }
52
53 pub(crate) fn try_start<D: OnDispatcher + 'static>(
54 future: impl Future<Output = T> + Send + 'static,
55 dispatcher: D,
56 ) -> Result<Self, Status> {
57 let (future, state) = Self::new(future, dispatcher);
58 state.queue().map(|_| future)
59 }
60
61 pub(crate) fn start<D: OnDispatcher + 'static>(
62 future: impl Future<Output = T> + Send + 'static,
63 dispatcher: D,
64 ) -> Self {
65 let (future, state) = Self::new(future, dispatcher);
66
67 if let Err(err) = state.queue() {
70 drop(state.future.lock().take());
72 state.result_sender.try_send(Err(err)).unwrap();
75 }
76
77 future
78 }
79}
80
81impl<T> Task<T> {
82 pub fn detach(self) {
86 drop(self.detach_on_drop());
87 }
88
89 pub fn detach_on_drop(mut self) -> JoinHandle<T> {
95 self.detached = true;
96 JoinHandle(self)
97 }
98
99 pub fn abort(&self) {
103 self.state.aborted.store(true, Ordering::Relaxed);
104 }
105}
106
107impl<T> Drop for Task<T> {
108 fn drop(&mut self) {
109 if !self.detached {
110 self.state.aborted.store(true, Ordering::Relaxed);
111 }
112 }
113}
114
115#[derive(Debug)]
116struct TaskFutureState {
117 waker: AtomicWaker,
118 aborted: AtomicBool,
119}
120
121impl<T> Future for Task<T> {
122 type Output = Result<T, Status>;
123
124 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
125 use std::sync::mpsc::TryRecvError;
126 self.state.waker.register(cx.waker());
127 match self.result_receiver.try_recv() {
128 Ok(res) => Poll::Ready(res),
129 Err(TryRecvError::Disconnected) => Poll::Ready(Err(Status::CANCELED)),
130 Err(TryRecvError::Empty) => Poll::Pending,
131 }
132 }
133}
134
135#[derive(Debug)]
137pub struct JoinHandle<T>(Task<T>);
138
139impl<T> JoinHandle<T> {
140 pub fn abort(&self) {
144 self.0.abort()
145 }
146}
147
148impl<T> Future for JoinHandle<T> {
149 type Output = Result<T, Status>;
150
151 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152 self.0.poll_unpin(cx)
153 }
154}
155
156struct TaskWakerState<T, D> {
157 result_sender: mpsc::SyncSender<Result<T, Status>>,
158 future_state: Arc<TaskFutureState>,
159 future: Mutex<Option<BoxFuture<'static, T>>>,
160 dispatcher: D,
161}
162
163impl<T: Send + 'static, D: OnDispatcher + 'static> Wake for TaskWakerState<T, D> {
164 fn wake(self: Arc<Self>) {
165 self.wake_by_ref();
166 }
167 fn wake_by_ref(self: &Arc<Self>) {
168 match self.queue() {
169 Err(e) if e == Status::BAD_STATE => {
170 let future_slot = self.future.lock().take();
173 drop(future_slot);
174 self.send_result(Err(e));
175 }
176 res => res.expect("Unexpected error waking dispatcher task"),
177 }
178 }
179}
180
181impl<T: Send + 'static, D: OnDispatcher + 'static> TaskWakerState<T, D> {
182 fn send_result(&self, res: Result<T, Status>) {
184 self.result_sender.try_send(res).ok();
188 self.future_state.waker.wake();
189 }
190
191 pub(crate) fn queue(self: &Arc<Self>) -> Result<(), Status> {
195 let arc_self = self.clone();
196 self.dispatcher.on_maybe_dispatcher(move |dispatcher| {
197 dispatcher
198 .post_task_sync(move |status| {
199 let mut future_slot = arc_self.future.lock();
200 if status != Status::from_raw(ZX_OK) {
203 drop(future_slot.take());
204 arc_self.send_result(Err(status));
205 return;
206 }
207
208 if arc_self.future_state.aborted.load(Ordering::Relaxed) {
211 drop(future_slot.take());
212 arc_self.send_result(Err(Status::CANCELED));
213 return;
214 }
215
216 let Some(mut future) = future_slot.take() else {
217 return;
218 };
219 let waker = Waker::from(arc_self.clone());
220 let context = &mut Context::from_waker(&waker);
221 match future.as_mut().poll(context) {
222 Poll::Pending => *future_slot = Some(future),
223 Poll::Ready(res) => {
224 arc_self.send_result(Ok(res));
225 }
226 }
227 })
228 .map(|_| ())
229 })
230 }
231}