1use crate::token_registry::TokenRegistry;
20
21use fuchsia_async::{self as fasync, JoinHandle, Scope, SpawnableFuture};
22use futures::task::{self, Poll};
23use futures::Future;
24use pin_project::pin_project;
25use std::future::{pending, poll_fn};
26use std::pin::{pin, Pin};
27use std::sync::{Arc, Mutex, OnceLock};
28use std::task::{ready, Context};
29
30#[cfg(target_os = "fuchsia")]
31use fuchsia_async::EHandle;
32
/// Alias for `futures::task::SpawnError`, re-exported from this module.
pub type SpawnError = task::SpawnError;
34
/// A handle to a group of tasks that can be spawned, waited on and shut down together.
///
/// Clones share the same underlying `Executor` state through an `Arc`, so a shutdown requested
/// through one clone is observed by all of them.
#[derive(Clone)]
pub struct ExecutionScope {
    /// State shared by every clone of this scope.
    executor: Arc<Executor>,
}
49
/// State shared by all clones of an [`ExecutionScope`].
struct Executor {
    /// Mutable bookkeeping (shutdown state, active-guard count, fake task), behind a mutex.
    inner: Mutex<Inner>,
    /// Registry handed out by `ExecutionScope::token_registry`.
    token_registry: TokenRegistry,
    /// The async scope tasks are spawned on. Filled at construction when an executor is supplied
    /// on Fuchsia; otherwise created lazily by `Executor::scope`.
    scope: OnceLock<Scope>,
}
55
/// Mutable state of an [`Executor`], protected by `Executor::inner`.
struct Inner {
    /// Current lifecycle state of the scope.
    shutdown_state: ShutdownState,

    /// Number of live `ActiveGuard`s; incremented when a guard is handed out and decremented
    /// when it is dropped.
    active_count: usize,

    /// A never-completing task that `ExecutionScope::wait` parks on the scope while guards are
    /// still outstanding, so the scope is not seen as empty. `ActiveGuard::drop` takes and
    /// cancels it once the active count reaches zero.
    fake_active_task: Option<fasync::Task<()>>,
}
67
/// Lifecycle state of an [`ExecutionScope`].
///
/// `Debug` is derived so the state can appear in diagnostics and assertion messages, and `Eq`
/// because equality on a fieldless enum is a total equivalence.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum ShutdownState {
    /// Tasks run normally and active guards may be taken.
    Active,
    /// Graceful shutdown requested: pending tasks stop once no active guards remain.
    Shutdown,
    /// Forced shutdown requested: pending tasks stop regardless of active guards.
    ForceShutdown,
}
74
75impl ExecutionScope {
76 pub fn new() -> Self {
79 Self::build().new()
80 }
81
82 pub fn build() -> ExecutionScopeParams {
86 ExecutionScopeParams::default()
87 }
88
89 pub fn active_count(&self) -> usize {
91 self.executor.inner.lock().unwrap().active_count
92 }
93
94 pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
106 self.executor.scope().spawn(FutureWithShutdown { executor: self.executor.clone(), task })
107 }
108
109 pub fn new_task(self, task: impl Future<Output = ()> + Send + 'static) -> Task {
111 Task(
112 self.executor.clone(),
113 SpawnableFuture::new(FutureWithShutdown { executor: self.executor, task }),
114 )
115 }
116
117 pub fn token_registry(&self) -> &TokenRegistry {
118 &self.executor.token_registry
119 }
120
121 pub fn shutdown(&self) {
122 self.executor.shutdown();
123 }
124
125 pub fn force_shutdown(&self) {
127 let mut inner = self.executor.inner.lock().unwrap();
128 inner.shutdown_state = ShutdownState::ForceShutdown;
129 self.executor.scope().wake_all();
130 }
131
132 pub fn resurrect(&self) {
135 self.executor.inner.lock().unwrap().shutdown_state = ShutdownState::Active;
136 }
137
138 pub async fn wait(&self) {
140 let mut on_no_tasks = pin!(self.executor.scope().on_no_tasks());
141 poll_fn(|cx| {
142 let mut inner = self.executor.inner.lock().unwrap();
144 ready!(on_no_tasks.as_mut().poll(cx));
145 if inner.active_count == 0 {
146 Poll::Ready(())
147 } else {
148 let scope = self.executor.scope();
153 inner.fake_active_task = Some(scope.compute(pending::<()>()));
154 on_no_tasks.set(scope.on_no_tasks());
155 assert!(on_no_tasks.as_mut().poll(cx).is_pending());
156 Poll::Pending
157 }
158 })
159 .await;
160 }
161
162 pub fn try_active_guard(&self) -> Option<ActiveGuard> {
165 let mut inner = self.executor.inner.lock().unwrap();
166 if inner.shutdown_state != ShutdownState::Active {
167 return None;
168 }
169 inner.active_count += 1;
170 Some(ActiveGuard(self.executor.clone()))
171 }
172
173 pub fn active_guard(&self) -> ActiveGuard {
176 self.executor.inner.lock().unwrap().active_count += 1;
177 ActiveGuard(self.executor.clone())
178 }
179}
180
181impl PartialEq for ExecutionScope {
182 fn eq(&self, other: &Self) -> bool {
183 Arc::as_ptr(&self.executor) == Arc::as_ptr(&other.executor)
184 }
185}
186
187impl Eq for ExecutionScope {}
188
189impl std::fmt::Debug for ExecutionScope {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 f.write_fmt(format_args!("ExecutionScope {:?}", Arc::as_ptr(&self.executor)))
192 }
193}
194
/// Builder for an [`ExecutionScope`], obtained from `ExecutionScope::build`.
#[derive(Default)]
pub struct ExecutionScopeParams {
    /// When set, the new scope's task scope is created as a child of this executor's global
    /// scope at construction time.
    #[cfg(target_os = "fuchsia")]
    async_executor: Option<EHandle>,
}
200
201impl ExecutionScopeParams {
202 #[cfg(target_os = "fuchsia")]
203 pub fn executor(mut self, value: EHandle) -> Self {
204 assert!(self.async_executor.is_none(), "`executor` is already set");
205 self.async_executor = Some(value);
206 self
207 }
208
209 pub fn new(self) -> ExecutionScope {
210 ExecutionScope {
211 executor: Arc::new(Executor {
212 token_registry: TokenRegistry::new(),
213 inner: Mutex::new(Inner {
214 shutdown_state: ShutdownState::Active,
215 active_count: 0,
216 fake_active_task: None,
217 }),
218 #[cfg(target_os = "fuchsia")]
219 scope: self
220 .async_executor
221 .map_or_else(|| OnceLock::new(), |e| e.global_scope().new_child().into()),
222 #[cfg(not(target_os = "fuchsia"))]
223 scope: OnceLock::new(),
224 }),
225 }
226 }
227}
228
229impl Executor {
230 fn scope(&self) -> &Scope {
231 self.scope.get_or_init(|| {
235 #[cfg(target_os = "fuchsia")]
236 return Scope::global().new_child();
237 #[cfg(not(target_os = "fuchsia"))]
238 return Scope::new();
239 })
240 }
241
242 fn shutdown(&self) {
243 let wake_all = {
244 let mut inner = self.inner.lock().unwrap();
245 inner.shutdown_state = ShutdownState::Shutdown;
246 inner.active_count == 0
247 };
248 if wake_all {
249 if let Some(scope) = self.scope.get() {
250 scope.wake_all();
251 }
252 }
253 }
254}
255
256impl Drop for Executor {
257 fn drop(&mut self) {
258 self.shutdown();
259 }
260}
261
/// While held, keeps a graceful shutdown of its [`ExecutionScope`] from stopping tasks.
///
/// Obtained from `ExecutionScope::try_active_guard` or `ExecutionScope::active_guard`; dropping
/// it decrements the scope's active count.
pub struct ActiveGuard(Arc<Executor>);
264
265impl Drop for ActiveGuard {
266 fn drop(&mut self) {
267 let wake_all = {
268 let mut inner = self.0.inner.lock().unwrap();
269 inner.active_count -= 1;
270 if inner.active_count == 0 {
271 if let Some(task) = inner.fake_active_task.take() {
272 let _ = task.cancel();
273 }
274 }
275 inner.active_count == 0 && inner.shutdown_state == ShutdownState::Shutdown
276 };
277 if wake_all {
278 self.0.scope().wake_all();
279 }
280 }
281}
282
/// Yields control back to the executor exactly once.
///
/// The first poll wakes its own waker and returns `Pending`, letting the executor schedule other
/// work before this future is polled again; the second poll completes.
pub async fn yield_to_executor() {
    let mut yielded = false;
    poll_fn(move |cx| {
        if std::mem::replace(&mut yielded, true) {
            return Poll::Ready(());
        }
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await
}
297
/// Wraps a task run by an [`ExecutionScope`] so that it completes early once the scope is shut
/// down (see the `Future` impl for the exact rule).
#[pin_project]
struct FutureWithShutdown<Task: Future<Output = ()> + Send + 'static> {
    /// Shared scope state, consulted on every poll.
    executor: Arc<Executor>,
    /// The wrapped user future; structurally pinned.
    #[pin]
    task: Task,
}
305
306impl<Task: Future<Output = ()> + Send + 'static> Future for FutureWithShutdown<Task> {
307 type Output = ();
308
309 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
310 let this = self.project();
311 let shutdown_state = this.executor.inner.lock().unwrap().shutdown_state;
312 match this.task.poll(cx) {
313 Poll::Ready(()) => Poll::Ready(()),
314 Poll::Pending => match shutdown_state {
315 ShutdownState::Active => Poll::Pending,
316 ShutdownState::Shutdown if this.executor.inner.lock().unwrap().active_count > 0 => {
317 Poll::Pending
318 }
319 _ => Poll::Ready(()),
320 },
321 }
322 }
323}
324
/// A task created by `ExecutionScope::new_task` that has not been spawned yet.
///
/// It holds its scope's shared state and the wrapped future; it can be spawned onto the scope
/// with [`Task::spawn`] or awaited directly.
pub struct Task(Arc<Executor>, SpawnableFuture<'static, ()>);
326
327impl Task {
328 pub fn spawn(self) {
330 self.0.scope().spawn(self.1);
331 }
332}
333
/// Awaiting a `Task` drives the wrapped future in the caller's own context, without spawning it
/// on the scope.
impl Future for Task {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // NOTE(review): polls through `&mut &mut SpawnableFuture`, presumably to use the
        // `Future` impl for `&mut F`; confirm fuchsia_async's `SpawnableFuture` trait impls
        // before simplifying this expression.
        Pin::new(&mut &mut self.1).poll(cx)
    }
}
341
#[cfg(test)]
mod tests {
    use super::{yield_to_executor, ExecutionScope};

    use fuchsia_async::{Task, TestExecutor, Timer};
    use futures::channel::oneshot;
    use futures::stream::FuturesUnordered;
    use futures::task::Poll;
    use futures::{Future, StreamExt};
    use std::pin::pin;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Runs the future built by `get_test` (given a fresh scope) on a new `TestExecutor` and
    /// asserts that it completes before the executor stalls.
    #[cfg(target_os = "fuchsia")]
    fn run_test<GetTest, GetTestRes>(get_test: GetTest)
    where
        GetTest: FnOnce(ExecutionScope) -> GetTestRes,
        GetTestRes: Future<Output = ()>,
    {
        let mut exec = TestExecutor::new();

        let scope = ExecutionScope::new();

        let test = get_test(scope);

        assert_eq!(
            exec.run_until_stalled(&mut pin!(test)),
            Poll::Ready(()),
            "Test did not complete"
        );
    }

    /// Host variant of `run_test`: runs the test future to completion and panics if it stalls
    /// for 30 seconds.
    #[cfg(not(target_os = "fuchsia"))]
    fn run_test<GetTest, GetTestRes>(get_test: GetTest)
    where
        GetTest: FnOnce(ExecutionScope) -> GetTestRes,
        GetTestRes: Future<Output = ()>,
    {
        use fuchsia_async::TimeoutExt;
        let mut exec = TestExecutor::new();

        let scope = ExecutionScope::new();

        // NOTE(review): uses an `on_stalled` timeout instead of `run_until_stalled`; presumably
        // the host executor cannot detect stalls itself — confirm.
        let test =
            get_test(scope).on_stalled(Duration::from_secs(30), || panic!("Test did not complete"));

        exec.run_singlethreaded(&mut pin!(test));
    }

    // A task that completes on its first poll is polled once and dropped once.
    #[test]
    fn simple() {
        run_test(|scope| {
            async move {
                let (sender, receiver) = oneshot::channel();
                let (counters, task) = mocks::ImmediateTask::new(sender);

                scope.spawn(task);

                // The mock signals from inside `poll`, so this resolves once it has run.
                receiver.await.unwrap();

                assert_eq!(counters.drop_call(), 1);
                assert_eq!(counters.poll_call(), 1);
            }
        });
    }

    // A running task is dropped after `shutdown` once it finishes its work.
    #[test]
    fn simple_drop() {
        run_test(|scope| {
            async move {
                let (poll_sender, poll_receiver) = oneshot::channel();
                let (processing_done_sender, processing_done_receiver) = oneshot::channel();
                let (drop_sender, drop_receiver) = oneshot::channel();
                let (counters, task) =
                    mocks::ControlledTask::new(poll_sender, processing_done_receiver, drop_sender);

                scope.spawn(task);

                poll_receiver.await.unwrap();

                processing_done_sender.send(()).unwrap();

                scope.shutdown();

                drop_receiver.await.unwrap();

                // The exact number of polls depends on scheduling; only a lower bound is checked.
                let poll_count = counters.poll_call();
                assert!(poll_count >= 1, "poll was not called");

                assert_eq!(counters.drop_call(), 1);
            }
        });
    }

    // `wait` does not resolve until spawned tasks have finished.
    #[test]
    fn test_wait_waits_for_tasks_to_finish() {
        let mut executor = TestExecutor::new();
        let scope = ExecutionScope::new();
        executor.run_singlethreaded(async {
            let (poll_sender, poll_receiver) = oneshot::channel();
            let (processing_done_sender, processing_done_receiver) = oneshot::channel();
            let (drop_sender, _drop_receiver) = oneshot::channel();
            let (_, task) =
                mocks::ControlledTask::new(poll_sender, processing_done_receiver, drop_sender);

            scope.spawn(task);

            poll_receiver.await.unwrap();

            // `done` is only set just before the task is allowed to finish, so the first branch
            // observing `true` shows `wait` did not return early.
            let done = std::sync::Mutex::new(false);
            futures::join!(
                async {
                    scope.wait().await;
                    assert_eq!(*done.lock().unwrap(), true);
                },
                async {
                    Timer::new(Duration::from_millis(100)).await;
                    // Mark completion before releasing the task.
                    *done.lock().unwrap() = true;
                    processing_done_sender.send(()).unwrap();
                }
            );
        });
    }

    // After a graceful shutdown, `wait` stays pending while any active guard is held, including
    // one taken from a `Drop` impl during shutdown.
    #[fuchsia::test]
    async fn test_active_guard() {
        let scope = ExecutionScope::new();
        let (guard_taken_tx, guard_taken_rx) = oneshot::channel();
        let (shutdown_triggered_tx, shutdown_triggered_rx) = oneshot::channel();
        let (drop_task_tx, drop_task_rx) = oneshot::channel();
        let scope_clone = scope.clone();
        let done = Arc::new(AtomicBool::new(false));
        let done_clone = done.clone();
        scope.spawn(async move {
            {
                // On drop, takes a new guard (allowed during shutdown) and hands it to a detached
                // task that releases it only when `drop_task_rx` fires.
                struct OnDrop((ExecutionScope, Option<oneshot::Receiver<()>>));
                impl Drop for OnDrop {
                    fn drop(&mut self) {
                        let guard = self.0 .0.active_guard();
                        let rx = self.0 .1.take().unwrap();
                        Task::spawn(async move {
                            rx.await.unwrap();
                            std::mem::drop(guard);
                        })
                        .detach();
                    }
                }
                let _guard = scope_clone.try_active_guard().unwrap();
                let _on_drop = OnDrop((scope_clone, Some(drop_task_rx)));
                guard_taken_tx.send(()).unwrap();
                shutdown_triggered_rx.await.unwrap();
                // Give the scope time to (incorrectly) stop this task while the guard is held.
                Timer::new(std::time::Duration::from_millis(100)).await;
                done_clone.store(true, Ordering::SeqCst);
            }
        });
        guard_taken_rx.await.unwrap();
        scope.shutdown();

        // The task still holds a guard, so waiting must not complete yet.
        Timer::new(std::time::Duration::from_millis(100)).await;
        let mut shutdown_wait = std::pin::pin!(scope.wait());
        assert_eq!(futures::poll!(shutdown_wait.as_mut()), Poll::Pending);

        shutdown_triggered_tx.send(()).unwrap();

        // The guard taken in `OnDrop::drop` is still outstanding.
        Timer::new(std::time::Duration::from_millis(100)).await;
        assert_eq!(futures::poll!(shutdown_wait.as_mut()), Poll::Pending);

        drop_task_tx.send(()).unwrap();

        shutdown_wait.await;

        assert!(done.load(Ordering::SeqCst));
    }

    // A task blocked on a channel that already has a message still receives it after shutdown.
    #[cfg(target_os = "fuchsia")]
    #[fuchsia::test]
    async fn test_shutdown_waits_for_channels() {
        use fuchsia_async as fasync;

        let scope = ExecutionScope::new();
        let (rx, tx) = zx::Channel::create();
        let received_msg = Arc::new(AtomicBool::new(false));
        let (sender, receiver) = futures::channel::oneshot::channel();
        {
            let received_msg = received_msg.clone();
            scope.spawn(async move {
                let mut msg_buf = zx::MessageBuf::new();
                msg_buf.ensure_capacity_bytes(64);
                let _ = sender.send(());
                let _ = fasync::Channel::from_channel(rx).recv_msg(&mut msg_buf).await;
                received_msg.store(true, Ordering::Relaxed);
            });
        }
        // Wait until the task has started running.
        let _ = receiver.await;

        tx.write(b"hello", &mut []).expect("write failed");
        scope.shutdown();
        scope.wait().await;
        assert!(received_msg.load(Ordering::Relaxed));
    }

    // `force_shutdown` stops tasks even while they hold guards; `resurrect` makes the scope
    // usable again.
    #[fuchsia::test]
    async fn test_force_shutdown() {
        let scope = ExecutionScope::new();
        let scope_clone = scope.clone();
        let ref_count = Arc::new(());
        let ref_count_clone = ref_count.clone();

        scope.spawn(async move {
            // Dropped together with the task; used to observe that the task was dropped.
            let _ref_count_clone = ref_count_clone;

            // A held guard would block a graceful shutdown, but not a forced one.
            let _guard = scope_clone.active_guard();

            let _: () = std::future::pending().await;
        });

        scope.force_shutdown();
        scope.wait().await;

        // The task (and its clone of `ref_count`) has been dropped.
        assert_eq!(Arc::strong_count(&ref_count), 1);

        scope.resurrect();

        // After resurrecting, newly spawned tasks keep running.
        let ref_count_clone = ref_count.clone();
        scope.spawn(async move {
            yield_to_executor().await;

            // A second reference is taken after the first poll, so a count of 3 shows the task
            // was polled again.
            let _ref_count = ref_count_clone.clone();

            let _: () = std::future::pending().await;
        });

        while Arc::strong_count(&ref_count) != 3 {
            yield_to_executor().await;
        }

        for _ in 0..5 {
            // The task must stay alive.
            yield_to_executor().await;
            assert_eq!(Arc::strong_count(&ref_count), 3);
        }
    }

    // A task spawned after shutdown is still polled (once) before `wait` completes.
    #[fuchsia::test]
    async fn test_task_runs_once() {
        let scope = ExecutionScope::new();

        // Spawn a task so the underlying scope exists before shutting down.
        scope.spawn(async {});

        scope.shutdown();

        let polled = Arc::new(AtomicBool::new(false));
        let polled_clone = polled.clone();

        let scope_clone = scope.clone();

        // A concurrent `wait` held in a `FuturesUnordered` so it can be polled manually.
        let mut futures = FuturesUnordered::new();
        futures.push(async move { scope_clone.wait().await });

        // The task spawned above has not run yet, so `wait` is pending.
        assert_eq!(futures::poll!(futures.next()), Poll::Pending);

        scope.spawn(async move {
            // While this task is alive the scope is not empty, so `wait` must still be pending.
            assert_eq!(futures::poll!(futures.next()), Poll::Pending);
            polled_clone.store(true, Ordering::Relaxed);
        });

        scope.wait().await;

        // The task spawned after shutdown did run.
        assert!(polled.load(Ordering::Relaxed));
    }

    /// Mock tasks that count their polls and drops.
    mod mocks {
        use futures::channel::oneshot;
        use futures::task::{Context, Poll};
        use futures::Future;
        use std::pin::Pin;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        /// Read-side view of a mock task's poll and drop counters.
        pub(super) struct TaskCounters {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,
        }

        impl TaskCounters {
            /// Returns the shared poll counter, the shared drop counter, and a `TaskCounters`
            /// reading both.
            fn new() -> (Arc<AtomicUsize>, Arc<AtomicUsize>, Self) {
                let poll_call_count = Arc::new(AtomicUsize::new(0));
                let drop_call_count = Arc::new(AtomicUsize::new(0));

                (
                    poll_call_count.clone(),
                    drop_call_count.clone(),
                    Self { poll_call_count, drop_call_count },
                )
            }

            /// Number of times the mock task was polled.
            pub(super) fn poll_call(&self) -> usize {
                self.poll_call_count.load(Ordering::Relaxed)
            }

            /// Number of times the mock task was dropped.
            pub(super) fn drop_call(&self) -> usize {
                self.drop_call_count.load(Ordering::Relaxed)
            }
        }

        /// A task that completes on its first poll, signalling `done_sender` as it does.
        pub(super) struct ImmediateTask {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,
            done_sender: Option<oneshot::Sender<()>>,
        }

        impl ImmediateTask {
            pub(super) fn new(done_sender: oneshot::Sender<()>) -> (TaskCounters, Self) {
                let (poll_call_count, drop_call_count, counters) = TaskCounters::new();
                (
                    counters,
                    Self { poll_call_count, drop_call_count, done_sender: Some(done_sender) },
                )
            }
        }

        impl Future for ImmediateTask {
            type Output = ();

            fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                self.poll_call_count.fetch_add(1, Ordering::Relaxed);

                // Only the first poll has a sender to signal.
                if let Some(sender) = self.done_sender.take() {
                    sender.send(()).unwrap();
                }

                Poll::Ready(())
            }
        }

        impl Drop for ImmediateTask {
            fn drop(&mut self) {
                self.drop_call_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        impl Unpin for ImmediateTask {}

        /// A task that signals `poll_sender` when first run, then waits for
        /// `processing_complete`; it signals `drop_sender` when dropped.
        pub(super) struct ControlledTask {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,

            drop_sender: Option<oneshot::Sender<()>>,
            future: Pin<Box<dyn Future<Output = ()> + Send>>,
        }

        impl ControlledTask {
            pub(super) fn new(
                poll_sender: oneshot::Sender<()>,
                processing_complete: oneshot::Receiver<()>,
                drop_sender: oneshot::Sender<()>,
            ) -> (TaskCounters, Self) {
                let (poll_call_count, drop_call_count, counters) = TaskCounters::new();
                (
                    counters,
                    Self {
                        poll_call_count,
                        drop_call_count,
                        drop_sender: Some(drop_sender),
                        future: Box::pin(async move {
                            poll_sender.send(()).unwrap();
                            processing_complete.await.unwrap();
                        }),
                    },
                )
            }
        }

        impl Future for ControlledTask {
            type Output = ();

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                self.poll_call_count.fetch_add(1, Ordering::Relaxed);
                self.future.as_mut().poll(cx)
            }
        }

        impl Drop for ControlledTask {
            fn drop(&mut self) {
                self.drop_call_count.fetch_add(1, Ordering::Relaxed);
                // Panics if the receiving side was already dropped.
                self.drop_sender.take().unwrap().send(()).unwrap();
            }
        }
    }
}