libasync/
after_deadline.rs1use std::pin::Pin;
6use std::ptr::NonNull;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicI32, Ordering};
9use std::task::{Context, Poll};
10
11use libasync_sys::{async_cancel_task, async_dispatcher, async_post_task, async_task};
12
13use futures::task::AtomicWaker;
14use zx::Status;
15use zx::sys::{ZX_ERR_CANCELED, ZX_OK};
16
17use crate::callback_state::CallbackSharedState;
18use crate::{AsyncDispatcher, OnDispatcher};
19
/// State an [`AfterDeadline`] future shares (through an `Arc`) with the dispatcher: the
/// `async_task` that is posted, bundled with the [`AfterDeadlineState`] its callback updates.
type SharedState = CallbackSharedState<async_task, AfterDeadlineState>;
21
22pub trait DispatcherTimerExt: OnDispatcher {
24 fn after_deadline(&self, deadline: zx::MonotonicInstant) -> AfterDeadline<Self>;
31}
32
33impl<T> DispatcherTimerExt for T
34where
35 T: OnDispatcher,
36{
37 fn after_deadline(&self, deadline: zx::MonotonicInstant) -> AfterDeadline<Self> {
38 AfterDeadline::new(self, deadline)
39 }
40}
41
42struct AfterDeadlineState {
43 async_dispatcher: NonNull<async_dispatcher>,
44 waker: AtomicWaker,
45 status: AtomicI32,
48}
49
50unsafe impl Send for AfterDeadlineState {}
52unsafe impl Sync for AfterDeadlineState {}
53
54impl AfterDeadlineState {
55 extern "C" fn call(_dispatcher: *mut async_dispatcher, task: *mut async_task, status: i32) {
56 debug_assert!(
57 status == ZX_OK || status == ZX_ERR_CANCELED,
58 "task callback called with status other than ok or canceled"
59 );
60 let state = unsafe { SharedState::from_raw_ptr(task) };
63 state.status.store(status, Ordering::Relaxed);
64 state.waker.wake();
65 }
66}
67
68pub struct AfterDeadline<D: OnDispatcher> {
72 dispatcher: D,
73 state: Option<Arc<SharedState>>,
74 deadline: zx::MonotonicInstant,
75}
76
77impl<D: OnDispatcher + Clone> AfterDeadline<D> {
78 pub(super) fn new(dispatcher: &D, deadline: zx::MonotonicInstant) -> Self {
79 let dispatcher = dispatcher.clone();
80 let state = None;
81 Self { dispatcher, state, deadline }
82 }
83}
84
85impl<D: OnDispatcher + Unpin> Future for AfterDeadline<D> {
86 type Output = Result<(), Status>;
87
88 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
89 if let Some(state) = &self.state {
91 let status = state.status.load(Ordering::Relaxed);
92 if status != Status::SHOULD_WAIT.into_raw() {
93 return Poll::Ready(Status::ok(status));
94 } else {
95 state.waker.register(cx.waker());
96 return Poll::Pending;
97 }
98 }
99
100 let deadline = self.deadline;
101
102 let now = self.dispatcher.on_maybe_dispatcher(|dispatcher| Ok(dispatcher.now()));
103 match now {
104 Ok(now) if deadline < zx::MonotonicInstant::from_nanos(now) => {
105 return Poll::Ready(Ok(()));
106 }
107 Err(err) => {
108 return Poll::Ready(Err(err));
109 }
110 _ => {}
111 }
112
113 let state = self.dispatcher.on_maybe_dispatcher(move |dispatcher| {
115 let async_dispatcher = dispatcher.inner();
117
118 let task = async_task {
119 handler: Some(AfterDeadlineState::call),
120 deadline: deadline.into_nanos(),
121 ..Default::default()
122 };
123 let state = AfterDeadlineState {
124 async_dispatcher,
125 waker: AtomicWaker::new(),
126 status: AtomicI32::new(Status::SHOULD_WAIT.into_raw()),
127 };
128 let state = SharedState::new(task, state);
129 state.waker.register(cx.waker());
130
131 let state_ptr = SharedState::make_raw_ptr(state.clone());
132
133 let res = Status::ok(unsafe { async_post_task(async_dispatcher.as_ptr(), state_ptr) });
137 match res {
138 Ok(_) => Ok(state),
139 Err(err) => {
140 unsafe { SharedState::release_raw_ptr(state_ptr) };
143 Err(err)
144 }
145 }
146 });
147
148 match state {
149 Ok(state) => {
150 self.state = Some(state);
151 Poll::Pending
152 }
153 Err(err) => Poll::Ready(Err(err)),
154 }
155 }
156}
157
158impl<D: OnDispatcher> Drop for AfterDeadline<D> {
159 fn drop(&mut self) {
160 let Some(state) = self.state.take() else {
161 return;
163 };
164 self.dispatcher.on_dispatcher(|dispatcher| {
165 let Some(dispatcher) = dispatcher else {
166 return;
170 };
171 if state.status.load(Ordering::Relaxed) != Status::SHOULD_WAIT.into_raw() {
172 return;
174 }
175 let async_dispatcher = dispatcher.inner();
176 if async_dispatcher != state.async_dispatcher {
177 panic!("Dropping a pending `AfterDeadline` future from a different dispatcher than the one it was awaited on.");
178 }
179 let state_ptr = SharedState::as_raw_ptr(&state);
180 let status = unsafe { async_cancel_task(async_dispatcher.as_ptr(), state_ptr) };
184 if Status::from_raw(status) == Status::OK {
185 unsafe { SharedState::release_raw_ptr(state_ptr) };
188 }
189 });
190 }
191}
192
// NOTE(review): this module is gated on the custom `not_yet` cfg as well as `test`, so it is
// compiled out unless `--cfg not_yet` is passed. TODO: confirm why these tests are disabled and
// re-enable them when possible.
#[cfg(all(not_yet, test))]
mod tests {
    use std::sync::mpsc;
    use std::thread::sleep;
    use std::time::Duration;

    use super::*;

    use futures::{FutureExt, poll};
    use std::task::Waker;

    use crate::dispatcher::tests::with_raw_dispatcher;
    use crate::dispatcher::{CurrentDispatcher, OnDispatcher};

    /// A deadline in the infinite past resolves on the very first poll.
    #[test]
    fn after_the_past() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            dispatcher
                .spawn_task(async move {
                    let fut = CurrentDispatcher.after_deadline(zx::MonotonicInstant::INFINITE_PAST);
                    assert_eq!(poll!(fut), Poll::Ready(Ok(())));
                    tx.send(()).unwrap();
                })
                .unwrap();
            // Block the test thread until the spawned task has finished its assertions.
            rx.recv().unwrap();
        });
    }

    /// A deadline equal to the dispatcher's time at construction resolves on the first poll.
    #[test]
    fn after_now() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            dispatcher
                .spawn_task(async move {
                    let fut = CurrentDispatcher.after_deadline(CurrentDispatcher.now().unwrap());
                    assert_eq!(poll!(fut), Poll::Ready(Ok(())));
                    tx.send(()).unwrap();
                })
                .unwrap();
            rx.recv().unwrap();
        });
    }

    /// A deadline three seconds ahead is pending on the first poll and resolves to `Ok(())` no
    /// earlier than the deadline.
    #[test]
    fn after_future() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            dispatcher
                .spawn_task(async move {
                    let deadline =
                        CurrentDispatcher.now().unwrap() + zx::MonotonicDuration::from_seconds(3);
                    let mut fut = CurrentDispatcher.after_deadline(deadline);
                    assert_eq!(poll!(&mut fut), Poll::Pending);
                    assert!(fut.await.is_ok());
                    assert!(CurrentDispatcher.now().unwrap() >= deadline);
                    tx.send(()).unwrap();
                })
                .unwrap();
            rx.recv().unwrap();
        });
    }

    /// Dropping a pending future on the dispatcher it was polled on (at the end of the async
    /// block) goes through the cancel path in `Drop` without panicking.
    #[test]
    fn drop_after_poll() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            dispatcher
                .spawn_task(async move {
                    let deadline =
                        CurrentDispatcher.now().unwrap() + zx::MonotonicDuration::from_minutes(3);
                    let mut fut = CurrentDispatcher.after_deadline(deadline);
                    assert_eq!(poll!(&mut fut), Poll::Pending);
                    tx.send(()).unwrap();
                })
                .unwrap();
            rx.recv().unwrap();
        });
    }

    /// A future still pending when its dispatcher goes away resolves to `Err(Status::CANCELED)`.
    /// Assumes `with_raw_dispatcher` shuts the dispatcher down before returning. TODO: confirm.
    #[test]
    fn dispatcher_shutdown_cancel() {
        let (fut_tx, fut_rx) = mpsc::channel();
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            dispatcher
                .spawn_task(async move {
                    let deadline =
                        CurrentDispatcher.now().unwrap() + zx::MonotonicDuration::from_minutes(3);
                    let mut fut = CurrentDispatcher.after_deadline(deadline);
                    assert_eq!(poll!(&mut fut), Poll::Pending);
                    // Smuggle the pending future out so it outlives the dispatcher.
                    fut_tx.send(fut).unwrap();
                    tx.send(()).unwrap();
                })
                .unwrap();
            rx.recv().unwrap();
        });
        let mut fut = fut_rx.recv().unwrap();
        // Polled from the test thread with a no-op waker, so nothing will wake us; poll in a
        // sleep loop until the cancellation status becomes visible.
        loop {
            let Poll::Ready(res) = fut.poll_unpin(&mut Context::from_waker(Waker::noop())) else {
                sleep(Duration::from_millis(10));
                continue;
            };
            assert_eq!(res, Err(Status::CANCELED));
            break;
        }
    }
}