1use crate::token_registry::TokenRegistry;
20
21use fuchsia_async::{JoinHandle, Scope, ScopeHandle, SpawnableFuture};
22use fuchsia_sync::{MappedMutexGuard, Mutex, MutexGuard};
23use futures::Future;
24use futures::task::{self, Poll};
25use std::future::poll_fn;
26use std::pin::Pin;
27use std::sync::{Arc, Weak};
28use std::task::Context;
29
30#[cfg(target_os = "fuchsia")]
31use fuchsia_async::EHandle;
32
33pub use fuchsia_async::scope::ScopeActiveGuard as ActiveGuard;
34
35pub type SpawnError = task::SpawnError;
36
37#[derive(Clone)]
48pub struct ExecutionScope {
49 executor: Arc<Executor>,
50
51 #[cfg(feature = "fdomain")]
54 client: Arc<flex_client::Client>,
55}
56
57struct Executor {
58 token_registry: TokenRegistry,
59 scope: Mutex<Option<Scope>>,
60}
61
62impl ExecutionScope {
63 pub fn new(#[cfg(feature = "fdomain")] client: Arc<flex_client::Client>) -> Self {
66 Self::build().new(
67 #[cfg(feature = "fdomain")]
68 client,
69 )
70 }
71
72 #[cfg(feature = "fdomain")]
74 pub fn domain(&self) -> Arc<flex_client::Client> {
75 Arc::clone(&self.client)
76 }
77
78 #[cfg(not(feature = "fdomain"))]
80 pub fn domain(&self) -> fidl::endpoints::ZirconClient {
81 fidl::endpoints::ZirconClient
82 }
83
84 pub fn build() -> ExecutionScopeParams {
88 ExecutionScopeParams::default()
89 }
90
91 pub fn as_weak(&self) -> WeakExecutionScope {
92 WeakExecutionScope {
93 executor: Arc::downgrade(&self.executor),
94 #[cfg(feature = "fdomain")]
95 client: Arc::downgrade(&self.client),
96 }
97 }
98
99 pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
111 self.executor.scope().spawn(task)
112 }
113
114 pub fn new_task(self, task: impl Future<Output = ()> + Send + 'static) -> Task {
116 Task(self.executor, SpawnableFuture::new(task))
117 }
118
119 pub fn token_registry(&self) -> &TokenRegistry {
120 &self.executor.token_registry
121 }
122
123 pub fn shutdown(&self) {
124 self.executor.shutdown();
125 }
126
127 pub fn force_shutdown(&self) {
129 let _ = self.executor.scope().clone().abort();
130 }
131
132 pub fn resurrect(&self) {
135 *self.executor.scope.lock() = None;
138 }
139
140 pub async fn wait(&self) {
142 let scope = self.executor.scope().clone();
143 scope.on_no_tasks_and_guards().await;
144 }
145
146 pub fn try_active_guard(&self) -> Option<ActiveGuard> {
149 self.executor.scope().active_guard()
150 }
151}
152
153impl PartialEq for ExecutionScope {
154 fn eq(&self, other: &Self) -> bool {
155 Arc::as_ptr(&self.executor) == Arc::as_ptr(&other.executor)
156 }
157}
158
159impl Eq for ExecutionScope {}
160
161impl std::fmt::Debug for ExecutionScope {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.write_fmt(format_args!("ExecutionScope {:?}", Arc::as_ptr(&self.executor)))
164 }
165}
166
/// Builder for [`ExecutionScope`], obtained from [`ExecutionScope::build`].
#[derive(Default)]
pub struct ExecutionScopeParams {
    /// Executor whose global scope will parent the new scope.
    ///
    /// When unset, the scope is created lazily as a child of `Scope::global()`.
    #[cfg(target_os = "fuchsia")]
    async_executor: Option<EHandle>,
}
172
173impl ExecutionScopeParams {
174 #[cfg(target_os = "fuchsia")]
175 pub fn executor(mut self, value: EHandle) -> Self {
176 assert!(self.async_executor.is_none(), "`executor` is already set");
177 self.async_executor = Some(value);
178 self
179 }
180
181 pub fn new(
182 self,
183 #[cfg(feature = "fdomain")] client: Arc<flex_client::Client>,
184 ) -> ExecutionScope {
185 ExecutionScope {
186 executor: Arc::new(Executor {
187 token_registry: TokenRegistry::new(),
188 #[cfg(target_os = "fuchsia")]
189 scope: self.async_executor.map_or_else(
190 || Mutex::new(None),
191 |e| Mutex::new(Some(e.global_scope().new_child())),
192 ),
193 #[cfg(not(target_os = "fuchsia"))]
194 scope: Mutex::new(None),
195 }),
196 #[cfg(feature = "fdomain")]
197 client,
198 }
199 }
200}
201
202#[derive(Clone)]
205pub struct WeakExecutionScope {
206 executor: Weak<Executor>,
207 #[cfg(feature = "fdomain")]
208 client: Weak<flex_client::Client>,
209}
210
211impl WeakExecutionScope {
212 pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) {
215 let executor = self.executor.upgrade();
216 if let Some(executor) = executor {
217 _ = executor.scope().spawn(task)
218 }
219 }
220
221 #[cfg(feature = "fdomain")]
223 pub fn domain(&self) -> Option<Arc<flex_client::Client>> {
224 self.client.upgrade()
225 }
226
227 #[cfg(not(feature = "fdomain"))]
228 pub fn domain(&self) -> Option<fidl::endpoints::ZirconClient> {
229 Some(fidl::endpoints::ZirconClient)
230 }
231}
232
233impl Executor {
234 fn scope(&self) -> MappedMutexGuard<'_, Scope> {
235 MutexGuard::map(self.scope.lock(), |s| {
239 s.get_or_insert_with(|| {
240 #[cfg(target_os = "fuchsia")]
241 return Scope::global().new_child();
242 #[cfg(not(target_os = "fuchsia"))]
243 return Scope::new();
244 })
245 })
246 }
247
248 fn shutdown(&self) {
249 if let Some(scope) = &*self.scope.lock() {
250 scope.wake_all_with_active_guard();
251 let _ = ScopeHandle::clone(&*scope).cancel();
252 }
253 }
254}
255
256impl Drop for Executor {
257 fn drop(&mut self) {
258 self.shutdown();
259 if let Some(scope) = self.scope.get_mut().take() {
262 scope.detach();
263 }
264 }
265}
266
/// Yields control back to the executor exactly once.
///
/// On the first poll the future wakes its own waker and returns `Pending`, so it is
/// rescheduled. On the second poll it completes.
pub async fn yield_to_executor() {
    let mut already_yielded = false;
    poll_fn(|cx| {
        // `replace` returns the previous flag value and marks the yield as done.
        if std::mem::replace(&mut already_yielded, true) {
            Poll::Ready(())
        } else {
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    })
    .await;
}
281
282pub struct Task(Arc<Executor>, SpawnableFuture<'static, ()>);
283
284impl Task {
285 pub fn spawn(self) {
287 self.0.scope().spawn(self.1);
288 }
289}
290
291impl Future for Task {
292 type Output = ();
293
294 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
295 Pin::new(&mut &mut self.1).poll(cx)
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::{ExecutionScope, yield_to_executor};

    use fuchsia_async::{TestExecutor, Timer};
    use futures::Future;
    use futures::channel::oneshot;
    use std::pin::pin;
    use std::sync::Arc;
    #[cfg(target_os = "fuchsia")]
    use std::sync::atomic::{AtomicBool, Ordering};
    #[cfg(target_os = "fuchsia")]
    use std::task::Poll;
    use std::time::Duration;

    /// Builds a fresh scope, passes it to `get_test`, and runs the resulting future on a new
    /// `TestExecutor` until it stalls. Fails if the test future has not completed by then.
    #[cfg(target_os = "fuchsia")]
    fn run_test<GetTest, GetTestRes>(get_test: GetTest)
    where
        GetTest: FnOnce(ExecutionScope) -> GetTestRes,
        GetTestRes: Future<Output = ()>,
    {
        let mut exec = TestExecutor::new();

        #[cfg(feature = "fdomain")]
        let scope = crate::execution_scope::ExecutionScope::new(flex_local::local_client_empty());
        #[cfg(not(feature = "fdomain"))]
        let scope = crate::execution_scope::ExecutionScope::new();

        let test = get_test(scope);

        assert_eq!(
            exec.run_until_stalled(&mut pin!(test)),
            Poll::Ready(()),
            "Test did not complete"
        );
    }

    /// Host variant of `run_test`: runs the test future to completion on a single thread. The
    /// `on_stalled` wrapper presumably panics if the future makes no progress for 30 seconds.
    #[cfg(not(target_os = "fuchsia"))]
    fn run_test<GetTest, GetTestRes>(get_test: GetTest)
    where
        GetTest: FnOnce(ExecutionScope) -> GetTestRes,
        GetTestRes: Future<Output = ()>,
    {
        use fuchsia_async::TimeoutExt;
        let mut exec = TestExecutor::new();

        #[cfg(feature = "fdomain")]
        let scope = crate::execution_scope::ExecutionScope::new(flex_local::local_client_empty());
        #[cfg(not(feature = "fdomain"))]
        let scope = crate::execution_scope::ExecutionScope::new();

        let test =
            get_test(scope).on_stalled(Duration::from_secs(30), || panic!("Test did not complete"));

        exec.run_singlethreaded(&mut pin!(test));
    }

    /// A task that completes on its first poll is polled once and dropped once.
    #[test]
    fn simple() {
        run_test(|scope| {
            async move {
                let (sender, receiver) = oneshot::channel();
                let (counters, task) = mocks::ImmediateTask::new(sender);

                scope.spawn(task);

                receiver.await.unwrap();
                // The task signals from inside `poll`. This test relies on the executor having
                // dropped the completed task before this future resumes.

                assert_eq!(counters.drop_call(), 1);
                assert_eq!(counters.poll_call(), 1);
            }
        });
    }

    /// After `shutdown`, a task that has finished processing is dropped exactly once.
    #[test]
    fn simple_drop() {
        run_test(|scope| {
            async move {
                let (poll_sender, poll_receiver) = oneshot::channel();
                let (processing_done_sender, processing_done_receiver) = oneshot::channel();
                let (drop_sender, drop_receiver) = oneshot::channel();
                let (counters, task) =
                    mocks::ControlledTask::new(poll_sender, processing_done_receiver, drop_sender);

                scope.spawn(task);

                poll_receiver.await.unwrap();

                processing_done_sender.send(()).unwrap();

                scope.shutdown();

                drop_receiver.await.unwrap();

                let poll_count = counters.poll_call();
                // The exact number of polls depends on wakeup ordering, so only require that
                // the task was polled at least once.
                assert!(poll_count >= 1, "poll was not called");

                assert_eq!(counters.drop_call(), 1);
            }
        });
    }

    /// `wait` must not resolve until the spawned task has finished.
    #[test]
    fn test_wait_waits_for_tasks_to_finish() {
        let mut executor = TestExecutor::new();
        #[cfg(feature = "fdomain")]
        let scope = crate::execution_scope::ExecutionScope::new(flex_local::local_client_empty());
        #[cfg(not(feature = "fdomain"))]
        let scope = crate::execution_scope::ExecutionScope::new();
        executor.run_singlethreaded(async {
            let (poll_sender, poll_receiver) = oneshot::channel();
            let (processing_done_sender, processing_done_receiver) = oneshot::channel();
            // The receiver is kept alive so the mock's `Drop` can send on `drop_sender`.
            let (drop_sender, _drop_receiver) = oneshot::channel();
            let (_, task) =
                mocks::ControlledTask::new(poll_sender, processing_done_receiver, drop_sender);

            scope.spawn(task);

            poll_receiver.await.unwrap();

            // `done` is set immediately before the task is allowed to finish. If `wait` resolved
            // early, it would observe `false`.
            let done = fuchsia_sync::Mutex::new(false);
            futures::join!(
                async {
                    scope.wait().await;
                    assert_eq!(*done.lock(), true);
                },
                async {
                    Timer::new(Duration::from_millis(100)).await;
                    // Give `wait` time to (incorrectly) resolve before releasing the task.
                    *done.lock() = true;
                    processing_done_sender.send(()).unwrap();
                }
            );
        });
    }

    /// A message already written to a channel is still received by a task after `shutdown`.
    #[cfg(target_os = "fuchsia")]
    #[fuchsia::test]
    async fn test_shutdown_waits_for_channels() {
        use fuchsia_async as fasync;

        #[cfg(feature = "fdomain")]
        let scope = crate::execution_scope::ExecutionScope::new(flex_local::local_client_empty());
        #[cfg(not(feature = "fdomain"))]
        let scope = crate::execution_scope::ExecutionScope::new();
        let (rx, tx) = zx::Channel::create();
        let received_msg = Arc::new(AtomicBool::new(false));
        let (sender, receiver) = futures::channel::oneshot::channel();
        {
            let received_msg = received_msg.clone();
            scope.spawn(async move {
                let mut msg_buf = zx::MessageBuf::new();
                msg_buf.ensure_capacity_bytes(64);
                let _ = sender.send(());
                let _ = fasync::Channel::from_channel(rx).recv_msg(&mut msg_buf).await;
                received_msg.store(true, Ordering::Relaxed);
            });
        }
        let _ = receiver.await;
        // The task has started running. The message written below must still be delivered
        // even though shutdown is requested right after.

        tx.write(b"hello", &mut []).expect("write failed");
        scope.shutdown();
        scope.wait().await;
        assert!(received_msg.load(Ordering::Relaxed));
    }

    /// `force_shutdown` drops tasks even while they hold an active guard, and `resurrect` makes
    /// the scope usable again afterwards.
    #[fuchsia::test]
    async fn test_force_shutdown() {
        #[cfg(feature = "fdomain")]
        let scope = crate::execution_scope::ExecutionScope::new(flex_local::local_client_empty());
        #[cfg(not(feature = "fdomain"))]
        let scope = crate::execution_scope::ExecutionScope::new();
        let scope_clone = scope.clone();
        let ref_count = Arc::new(());
        let ref_count_clone = ref_count.clone();

        scope.spawn(async move {
            // Owning a clone of `ref_count` lets the test observe when this future is dropped
            // via `Arc::strong_count`.
            let _ref_count_clone = ref_count_clone;

            let _guard = scope_clone.try_active_guard().unwrap();
            // The active guard is held forever; `force_shutdown` must not wait for it.

            let _: () = std::future::pending().await;
        });

        scope.force_shutdown();
        scope.wait().await;

        // Only the test's own reference remains, so the aborted task's future was dropped.
        assert_eq!(Arc::strong_count(&ref_count), 1);

        scope.resurrect();
        // After resurrecting, newly spawned tasks run again.

        let ref_count_clone = ref_count.clone();
        scope.spawn(async move {
            yield_to_executor().await;
            // The count only reaches 3 once this task has actually been polled past the yield.

            let _ref_count = ref_count_clone.clone();

            let _: () = std::future::pending().await;
        });

        while Arc::strong_count(&ref_count) != 3 {
            yield_to_executor().await;
        }

        for _ in 0..5 {
            // The resurrected task must stay alive (pending) across further yields.
            yield_to_executor().await;
            assert_eq!(Arc::strong_count(&ref_count), 3);
        }
    }

    /// Mock futures that record how often they are polled and dropped.
    mod mocks {
        use futures::Future;
        use futures::channel::oneshot;
        use futures::task::{Context, Poll};
        use std::pin::Pin;
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        /// Test-side view of a mock task's poll and drop counters.
        pub(super) struct TaskCounters {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,
        }

        impl TaskCounters {
            /// Returns zeroed counters plus clones of them for the mock task to increment.
            fn new() -> (Arc<AtomicUsize>, Arc<AtomicUsize>, Self) {
                let poll_call_count = Arc::new(AtomicUsize::new(0));
                let drop_call_count = Arc::new(AtomicUsize::new(0));

                (
                    poll_call_count.clone(),
                    drop_call_count.clone(),
                    Self { poll_call_count, drop_call_count },
                )
            }

            /// Number of times the mock task's `poll` has been called.
            pub(super) fn poll_call(&self) -> usize {
                self.poll_call_count.load(Ordering::Relaxed)
            }

            /// Number of times the mock task has been dropped.
            pub(super) fn drop_call(&self) -> usize {
                self.drop_call_count.load(Ordering::Relaxed)
            }
        }

        /// A task that completes on its first poll, signalling `done_sender` as it does so.
        pub(super) struct ImmediateTask {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,
            done_sender: Option<oneshot::Sender<()>>,
        }

        impl ImmediateTask {
            pub(super) fn new(done_sender: oneshot::Sender<()>) -> (TaskCounters, Self) {
                let (poll_call_count, drop_call_count, counters) = TaskCounters::new();
                (
                    counters,
                    Self { poll_call_count, drop_call_count, done_sender: Some(done_sender) },
                )
            }
        }

        impl Future for ImmediateTask {
            type Output = ();

            fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                self.poll_call_count.fetch_add(1, Ordering::Relaxed);

                // Only the first poll signals; `take` leaves `None` behind.
                if let Some(sender) = self.done_sender.take() {
                    sender.send(()).unwrap();
                }

                Poll::Ready(())
            }
        }

        impl Drop for ImmediateTask {
            fn drop(&mut self) {
                self.drop_call_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        impl Unpin for ImmediateTask {}

        /// A task that signals `poll_sender` when first run, then completes once
        /// `processing_complete` fires. It signals `drop_sender` when dropped.
        pub(super) struct ControlledTask {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,

            drop_sender: Option<oneshot::Sender<()>>,
            future: Pin<Box<dyn Future<Output = ()> + Send>>,
        }

        impl ControlledTask {
            pub(super) fn new(
                poll_sender: oneshot::Sender<()>,
                processing_complete: oneshot::Receiver<()>,
                drop_sender: oneshot::Sender<()>,
            ) -> (TaskCounters, Self) {
                let (poll_call_count, drop_call_count, counters) = TaskCounters::new();
                (
                    counters,
                    Self {
                        poll_call_count,
                        drop_call_count,
                        drop_sender: Some(drop_sender),
                        future: Box::pin(async move {
                            poll_sender.send(()).unwrap();
                            processing_complete.await.unwrap();
                        }),
                    },
                )
            }
        }

        impl Future for ControlledTask {
            type Output = ();

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                self.poll_call_count.fetch_add(1, Ordering::Relaxed);
                self.future.as_mut().poll(cx)
            }
        }

        impl Drop for ControlledTask {
            // Panics if the test has already dropped the matching drop receiver.
            fn drop(&mut self) {
                self.drop_call_count.fetch_add(1, Ordering::Relaxed);
                self.drop_sender.take().unwrap().send(()).unwrap();
            }
        }
    }
}