1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
45use crate::Task;
67use futures::channel::mpsc;
8use futures::Future;
910use super::Scope;
1112/// Errors that can be returned by this crate.
13#[derive(Debug, thiserror::Error)]
14enum Error {
15/// Return when a task cannot be added to a [`TaskGroup`] or [`TaskSink`].
16#[error("Failed to add Task: {0}")]
17GroupDropped(#[from] mpsc::TrySendError<Task<()>>),
18}
1920/// Allows the user to spawn multiple Tasks and await them as a unit.
21///
22/// Tasks can be added to this group using [`TaskGroup::add`].
23/// All pending tasks in the group can be awaited using [`TaskGroup::join`].
24///
25/// New code should prefer to use [`Scope`] instead.
26pub struct TaskGroup {
27 scope: Scope,
28}
2930impl Default for TaskGroup {
31fn default() -> Self {
32Self::new()
33 }
34}
3536impl TaskGroup {
37/// Creates a new TaskGroup.
38 ///
39 /// The TaskGroup can be used to await an arbitrary number of Tasks and may
40 /// consume an arbitrary amount of memory.
41pub fn new() -> Self {
42#[cfg(target_os = "fuchsia")]
43return Self { scope: Scope::global().new_child() };
44#[cfg(not(target_os = "fuchsia"))]
45return Self { scope: Scope::new() };
46 }
4748/// Spawns a new task in this TaskGroup.
49 ///
50 /// To add a future that is not [`Send`] to this TaskGroup, use [`TaskGroup::local`].
51 ///
52 /// # Panics
53 ///
54 /// `spawn` may panic if not called in the context of an executor (e.g.
55 /// within a call to `run` or `run_singlethreaded`).
56pub fn spawn(&mut self, future: impl Future<Output = ()> + Send + 'static) {
57self.scope.spawn(future);
58 }
5960/// Spawns a new task in this TaskGroup.
61 ///
62 /// # Panics
63 ///
64 /// `spawn` may panic if not called in the context of a single threaded executor
65 /// (e.g. within a call to `run_singlethreaded`).
66pub fn local(&mut self, future: impl Future<Output = ()> + 'static) {
67self.scope.spawn_local(future);
68 }
6970/// Waits for all Tasks in this TaskGroup to finish.
71 ///
72 /// Call this only after all Tasks have been added.
73pub async fn join(self) {
74self.scope.on_no_tasks().await;
75 }
76}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SendExecutor;
    use futures::StreamExt;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    // Drop guard: pushes one `()` onto its channel when dropped, marking one
    // unit of work as finished.
    #[derive(Clone)]
    struct DoneSignaler {
        done: mpsc::UnboundedSender<()>,
    }

    impl Drop for DoneSignaler {
        fn drop(&mut self) {
            let sender = &mut self.done;
            // Report completion, then detach this sender from the channel.
            sender.unbounded_send(()).unwrap();
            sender.disconnect();
        }
    }

    // Hands out drop guards via `WaitGroup::add_one`; `wait` resolves once
    // every guard handed out so far has been dropped.
    struct WaitGroup {
        tx: mpsc::UnboundedSender<()>,
        rx: mpsc::UnboundedReceiver<()>,
    }

    impl WaitGroup {
        fn new() -> Self {
            let (tx, rx) = mpsc::unbounded();
            WaitGroup { tx, rx }
        }

        fn add_one(&self) -> impl Drop {
            let done = self.tx.clone();
            DoneSignaler { done }
        }

        async fn wait(self) {
            let WaitGroup { tx, mut rx } = self;
            // Release our own sender so the stream ends when the last guard is gone.
            drop(tx);
            while rx.next().await.is_some() {}
        }
    }

    #[test]
    fn test_task_group_join_waits_for_tasks() {
        let task_count = 20;

        SendExecutor::new(task_count).run(async move {
            let mut group = TaskGroup::new();
            let counter = Arc::new(AtomicU64::new(0));

            for _ in 0..task_count {
                let counter = Arc::clone(&counter);
                group.spawn(async move {
                    counter.fetch_add(1, Ordering::Relaxed);
                });
            }

            group.join().await;
            assert_eq!(counter.load(Ordering::Relaxed), task_count as u64);
        });
    }

    #[test]
    fn test_task_group_empty_join_completes() {
        SendExecutor::new(1).run(async move {
            let group = TaskGroup::new();
            group.join().await;
        });
    }

    #[test]
    fn test_task_group_added_tasks_are_cancelled_on_drop() {
        let wait_group = WaitGroup::new();
        let task_count = 10;

        SendExecutor::new(task_count).run(async move {
            let mut group = TaskGroup::new();
            for _ in 0..task_count {
                let guard = wait_group.add_one();

                // This task never finishes on its own; `guard` is released only
                // when the task is cancelled.
                group.spawn(async move {
                    // Move the guard into the task's state.
                    let _guard = guard;
                    std::future::pending::<()>().await;
                });
            }

            // Dropping the group must cancel everything it spawned.
            drop(group);
            wait_group.wait().await;
            // Reaching this point means every task was cancelled.
        });
    }

    #[test]
    fn test_task_group_spawn() {
        SendExecutor::new(3).run(async move {
            let mut group = TaskGroup::new();

            // `spawn` accepts any `Future<Output = ()>`, for example:

            // a ready-made future,
            group.spawn(std::future::ready(()));

            // an async block,
            group.spawn(async move {
                std::future::ready(()).await;
            });

            // or another task.
            group.spawn(Task::spawn(std::future::ready(())));

            group.join().await;
        });
    }
}