1use libasync_sys::*;
8use zx::sys::ZX_OK;
9
10use core::cell::UnsafeCell;
11use core::future::Future;
12use core::marker::PhantomData;
13use core::ptr::NonNull;
14use core::task::Context;
15use fuchsia_sync::Mutex;
16use std::sync::{Arc, Weak};
17
18use zx::Status;
19
20use futures::future::{BoxFuture, FutureExt};
21use futures::task::{ArcWake, waker_ref};
22
23mod after_deadline;
24
25pub use after_deadline::*;
26
/// A handle to a C `async_dispatcher_t`, valid for the lifetime `'a`.
///
/// Wraps a non-null raw pointer. The `PhantomData<&'a async_dispatcher_t>` field carries the
/// `'a` lifetime so the handle behaves like a borrow of the dispatcher. Construct one with
/// [`AsyncDispatcherRef::from_raw`].
#[derive(Debug)]
pub struct AsyncDispatcherRef<'a>(NonNull<async_dispatcher_t>, PhantomData<&'a async_dispatcher_t>);

// SAFETY: NOTE(review): raw pointers are neither `Send` nor `Sync` by default. These impls
// assert that the dispatcher behind the pointer may be used (e.g. via `async_post_task` /
// `async_now`) from any thread. That presumably relies on libasync's thread-safety guarantees
// and on the `from_raw` caller contract — confirm.
unsafe impl<'a> Send for AsyncDispatcherRef<'a> {}
unsafe impl<'a> Sync for AsyncDispatcherRef<'a> {}
35
36impl<'a> AsyncDispatcherRef<'a> {
37 pub unsafe fn from_raw(ptr: NonNull<async_dispatcher_t>) -> Self {
44 Self(ptr, PhantomData)
46 }
47
48 pub fn inner(&self) -> NonNull<async_dispatcher_t> {
50 self.0
51 }
52}
53
54impl<'a> Clone for AsyncDispatcherRef<'a> {
55 fn clone(&self) -> Self {
56 Self(self.0, PhantomData)
57 }
58}
59
60pub trait AsyncDispatcher: Send + Sync {
62 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_>;
64
65 fn post_task_sync(&self, p: impl TaskCallback) -> Result<(), Status> {
67 #[expect(clippy::arc_with_non_send_sync)]
68 let task_arc = Arc::new(UnsafeCell::new(TaskFunc {
69 task: async_task { handler: Some(TaskFunc::call), ..Default::default() },
70 func: Box::new(p),
71 }));
72
73 let task_cell = Arc::into_raw(task_arc);
74 let res = unsafe {
81 let task_ptr = &raw mut (*UnsafeCell::raw_get(task_cell)).task;
82 Status::ok(async_post_task(self.as_async_dispatcher_ref().0.as_ptr(), task_ptr))
83 };
84 if res.is_err() {
85 unsafe { Arc::decrement_strong_count(task_cell) }
88 }
89 res
90 }
91
92 fn now(&self) -> zx::MonotonicInstant {
94 let async_dispatcher = self.as_async_dispatcher_ref().0.as_ptr();
95 let now_nanos = unsafe { async_now(async_dispatcher) };
96 zx::MonotonicInstant::from_nanos(now_nanos)
97 }
98}
99
100impl<'a> AsyncDispatcher for AsyncDispatcherRef<'a> {
101 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
102 self.clone()
103 }
104}
105
106pub trait OnDispatcher: Clone + Send + Sync {
108 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R;
111
112 fn on_maybe_dispatcher<R, E: From<Status>>(
115 &self,
116 f: impl FnOnce(AsyncDispatcherRef<'_>) -> Result<R, E>,
117 ) -> Result<R, E> {
118 self.on_dispatcher(|dispatcher| {
119 let dispatcher = dispatcher.ok_or(Status::BAD_STATE)?;
120 f(dispatcher)
121 })
122 }
123
124 fn spawn_task(&self, future: impl Future<Output = ()> + Send + 'static) -> Result<(), Status>
128 where
129 Self: 'static,
130 {
131 let task =
132 Arc::new(Task { future: Mutex::new(Some(future.boxed())), dispatcher: self.clone() });
133 task.queue()
134 }
135
136 fn after_deadline(&self, deadline: zx::MonotonicInstant) -> AfterDeadline<Self> {
143 AfterDeadline::new(self, deadline)
144 }
145}
146
147impl<D: AsyncDispatcher> OnDispatcher for &D {
148 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
149 f(Some(D::as_async_dispatcher_ref(*self)))
150 }
151}
152
153impl<'a> OnDispatcher for AsyncDispatcherRef<'a> {
154 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
155 f(Some(self.clone()))
156 }
157}
158
159impl<T: AsyncDispatcher> OnDispatcher for Arc<T> {
160 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
161 f(Some(self.as_async_dispatcher_ref()))
162 }
163}
164
165impl<T: AsyncDispatcher> OnDispatcher for Weak<T> {
166 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
167 let dispatcher = Weak::upgrade(self);
168 match dispatcher {
169 Some(dispatcher) => f(Some(dispatcher.as_async_dispatcher_ref())),
170 None => f(None),
171 }
172 }
173}
174
175pub trait TaskCallback: FnOnce(Status) + 'static + Send {}
177impl<T> TaskCallback for T where T: FnOnce(Status) + 'static + Send {}
178
/// State for a future spawned with [`OnDispatcher::spawn_task`].
///
/// `future` becomes `None` once the future has completed or has been abandoned.
struct Task<D> {
    // The spawned future. It is taken out while being polled and put back if still pending.
    future: Mutex<Option<BoxFuture<'static, ()>>>,
    // Handle used to (re)post poll tasks for this future.
    dispatcher: D,
}
183
184impl<D: OnDispatcher + 'static> ArcWake for Task<D> {
185 fn wake_by_ref(arc_self: &Arc<Self>) {
186 match arc_self.queue() {
187 Err(e) if e == Status::BAD_STATE => {
188 let future_slot = arc_self.future.lock().take();
191 core::mem::drop(future_slot);
192 }
193 res => res.expect("Unexpected error waking dispatcher task"),
194 }
195 }
196}
197
198impl<D: OnDispatcher + 'static> Task<D> {
199 fn queue(self: &Arc<Self>) -> Result<(), Status> {
203 let arc_self = self.clone();
204 self.dispatcher.on_maybe_dispatcher(move |dispatcher| {
205 dispatcher
206 .post_task_sync(move |status| {
207 let mut future_slot = arc_self.future.lock();
208 if status != Status::from_raw(ZX_OK) {
210 core::mem::drop(future_slot.take());
211 return;
212 }
213
214 let Some(mut future) = future_slot.take() else {
215 return;
216 };
217 let waker = waker_ref(&arc_self);
218 let context = &mut Context::from_waker(&waker);
219 if future.as_mut().poll(context).is_pending() {
220 *future_slot = Some(future);
221 }
222 })
223 .map(|_| ())
224 })
225 }
226}
227
/// Heap record handed to libasync for a single posted task.
///
/// `repr(C)` with `task` as the first field is load-bearing. The dispatcher gives
/// [`TaskFunc::call`] a pointer to `task`, which is cast back to the enclosing
/// `UnsafeCell<TaskFunc>` allocation.
#[repr(C)]
struct TaskFunc {
    // The libasync task record passed to `async_post_task`; its handler is `TaskFunc::call`.
    task: async_task,
    // The callback to run once with the handler's status.
    func: Box<dyn TaskCallback>,
}
233
234impl TaskFunc {
235 extern "C" fn call(_dispatcher: *mut async_dispatcher, task: *mut async_task, status: i32) {
236 let task = unsafe { Arc::from_raw(task as *const UnsafeCell<Self>) };
239 if let Ok(task) = Arc::try_unwrap(task) {
242 (task.into_inner().func)(Status::from_raw(status));
243 }
244 }
245}