libasync/dispatcher/
task.rs1use zx::sys::ZX_OK;
8
9use core::task::Context;
10use fuchsia_sync::Mutex;
11use std::pin::Pin;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Arc, mpsc};
14use std::task::{Poll, Wake, Waker};
15
16use zx::Status;
17
18use futures::future::{BoxFuture, FutureExt};
19use futures::task::AtomicWaker;
20
21use crate::{AsyncDispatcher, OnDispatcher};
22
23#[must_use]
26pub struct Task<T> {
27 state: Arc<TaskFutureState>,
28 result_receiver: mpsc::Receiver<Result<T, Status>>,
29 detached: bool,
30}
31
32impl<T: Send + 'static> Task<T> {
33 fn new<D: OnDispatcher + 'static>(
34 future: impl Future<Output = T> + Send + 'static,
35 dispatcher: D,
36 ) -> (Self, Arc<TaskWakerState<T, D>>) {
37 let future_state = Arc::new(TaskFutureState {
38 waker: AtomicWaker::new(),
39 aborted: AtomicBool::new(false),
40 });
41 let (result_sender, result_receiver) = mpsc::sync_channel(1);
42 let state = Arc::new(TaskWakerState {
43 result_sender,
44 future_state: future_state.clone(),
45 future: Mutex::new(Some(future.boxed())),
46 dispatcher,
47 });
48 let future = Task { state: future_state, result_receiver, detached: false };
49 (future, state)
50 }
51
52 pub(crate) fn try_start<D: OnDispatcher + 'static>(
53 future: impl Future<Output = T> + Send + 'static,
54 dispatcher: D,
55 ) -> Result<Self, Status> {
56 let (future, state) = Self::new(future, dispatcher);
57 state.queue().map(|_| future)
58 }
59
60 pub(crate) fn start<D: OnDispatcher + 'static>(
61 future: impl Future<Output = T> + Send + 'static,
62 dispatcher: D,
63 ) -> Self {
64 let (future, state) = Self::new(future, dispatcher);
65
66 if let Err(err) = state.queue() {
69 drop(state.future.lock().take());
71 state.result_sender.try_send(Err(err)).unwrap();
74 }
75
76 future
77 }
78}
79
80impl<T> Task<T> {
81 pub fn detach(self) {
85 drop(self.detach_on_drop());
86 }
87
88 pub fn detach_on_drop(mut self) -> JoinHandle<T> {
94 self.detached = true;
95 JoinHandle(self)
96 }
97
98 pub fn abort(&self) {
102 self.state.aborted.store(true, Ordering::Relaxed);
103 }
104}
105
106impl<T> Drop for Task<T> {
107 fn drop(&mut self) {
108 if !self.detached {
109 self.state.aborted.store(true, Ordering::Relaxed);
110 }
111 }
112}
113
114struct TaskFutureState {
115 waker: AtomicWaker,
116 aborted: AtomicBool,
117}
118
119impl<T> Future for Task<T> {
120 type Output = Result<T, Status>;
121
122 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123 use std::sync::mpsc::TryRecvError;
124 self.state.waker.register(cx.waker());
125 match self.result_receiver.try_recv() {
126 Ok(res) => Poll::Ready(res),
127 Err(TryRecvError::Disconnected) => Poll::Ready(Err(Status::CANCELED)),
128 Err(TryRecvError::Empty) => Poll::Pending,
129 }
130 }
131}
132
133pub struct JoinHandle<T>(Task<T>);
135
136impl<T> JoinHandle<T> {
137 pub fn abort(&self) {
141 self.0.abort()
142 }
143}
144
145impl<T> Future for JoinHandle<T> {
146 type Output = Result<T, Status>;
147
148 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
149 self.0.poll_unpin(cx)
150 }
151}
152
153struct TaskWakerState<T, D> {
154 result_sender: mpsc::SyncSender<Result<T, Status>>,
155 future_state: Arc<TaskFutureState>,
156 future: Mutex<Option<BoxFuture<'static, T>>>,
157 dispatcher: D,
158}
159
160impl<T: Send + 'static, D: OnDispatcher + 'static> Wake for TaskWakerState<T, D> {
161 fn wake(self: Arc<Self>) {
162 self.wake_by_ref();
163 }
164 fn wake_by_ref(self: &Arc<Self>) {
165 match self.queue() {
166 Err(e) if e == Status::BAD_STATE => {
167 let future_slot = self.future.lock().take();
170 drop(future_slot);
171 self.send_result(Err(e));
172 }
173 res => res.expect("Unexpected error waking dispatcher task"),
174 }
175 }
176}
177
178impl<T: Send + 'static, D: OnDispatcher + 'static> TaskWakerState<T, D> {
179 fn send_result(&self, res: Result<T, Status>) {
181 self.result_sender.try_send(res).ok();
185 self.future_state.waker.wake();
186 }
187
188 pub(crate) fn queue(self: &Arc<Self>) -> Result<(), Status> {
192 let arc_self = self.clone();
193 self.dispatcher.on_maybe_dispatcher(move |dispatcher| {
194 dispatcher
195 .post_task_sync(move |status| {
196 let mut future_slot = arc_self.future.lock();
197 if status != Status::from_raw(ZX_OK) {
200 drop(future_slot.take());
201 arc_self.send_result(Err(status));
202 return;
203 }
204
205 if arc_self.future_state.aborted.load(Ordering::Relaxed) {
208 drop(future_slot.take());
209 arc_self.send_result(Err(Status::CANCELED));
210 return;
211 }
212
213 let Some(mut future) = future_slot.take() else {
214 return;
215 };
216 let waker = Waker::from(arc_self.clone());
217 let context = &mut Context::from_waker(&waker);
218 match future.as_mut().poll(context) {
219 Poll::Pending => *future_slot = Some(future),
220 Poll::Ready(res) => {
221 arc_self.send_result(Ok(res));
222 }
223 }
224 })
225 .map(|_| ())
226 })
227 }
228}