fuchsia_async/runtime/
task_group.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::Task;
6
7use futures::channel::mpsc;
8use futures::Future;
9
10use super::Scope;
11
12/// Errors that can be returned by this crate.
13#[derive(Debug, thiserror::Error)]
14enum Error {
15    /// Return when a task cannot be added to a [`TaskGroup`] or [`TaskSink`].
16    #[error("Failed to add Task: {0}")]
17    GroupDropped(#[from] mpsc::TrySendError<Task<()>>),
18}
19
20/// Allows the user to spawn multiple Tasks and await them as a unit.
21///
22/// Tasks can be added to this group using [`TaskGroup::add`].
23/// All pending tasks in the group can be awaited using [`TaskGroup::join`].
24///
25/// New code should prefer to use [`Scope`] instead.
26pub struct TaskGroup {
27    scope: Scope,
28}
29
30impl Default for TaskGroup {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl TaskGroup {
37    /// Creates a new TaskGroup.
38    ///
39    /// The TaskGroup can be used to await an arbitrary number of Tasks and may
40    /// consume an arbitrary amount of memory.
41    pub fn new() -> Self {
42        #[cfg(target_os = "fuchsia")]
43        return Self { scope: Scope::global().new_child() };
44        #[cfg(not(target_os = "fuchsia"))]
45        return Self { scope: Scope::new() };
46    }
47
48    /// Spawns a new task in this TaskGroup.
49    ///
50    /// To add a future that is not [`Send`] to this TaskGroup, use [`TaskGroup::local`].
51    ///
52    /// # Panics
53    ///
54    /// `spawn` may panic if not called in the context of an executor (e.g.
55    /// within a call to `run` or `run_singlethreaded`).
56    pub fn spawn(&mut self, future: impl Future<Output = ()> + Send + 'static) {
57        self.scope.spawn(future);
58    }
59
60    /// Spawns a new task in this TaskGroup.
61    ///
62    /// # Panics
63    ///
64    /// `spawn` may panic if not called in the context of a single threaded executor
65    /// (e.g. within a call to `run_singlethreaded`).
66    pub fn local(&mut self, future: impl Future<Output = ()> + 'static) {
67        self.scope.spawn_local(future);
68    }
69
70    /// Waits for all Tasks in this TaskGroup to finish.
71    ///
72    /// Call this only after all Tasks have been added.
73    pub async fn join(self) {
74        self.scope.on_no_tasks().await;
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SendExecutor;
    use futures::StreamExt;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    /// Guard that, when dropped, sends one message on `done` and then
    /// disconnects. This lets a listener observe that the guard's owner went away.
    #[derive(Clone)]
    struct DoneSignaler {
        done: mpsc::UnboundedSender<()>,
    }

    impl Drop for DoneSignaler {
        fn drop(&mut self) {
            let DoneSignaler { done } = self;
            done.unbounded_send(()).unwrap();
            done.disconnect();
        }
    }

    /// Lets a test wait until every guard handed out by `add_one` has been
    /// dropped. Hand out as many guards as needed, then await `wait`.
    struct WaitGroup {
        tx: mpsc::UnboundedSender<()>,
        rx: mpsc::UnboundedReceiver<()>,
    }

    impl WaitGroup {
        fn new() -> Self {
            let (tx, rx) = mpsc::unbounded();
            WaitGroup { tx, rx }
        }

        /// Returns a guard that signals this group when it is dropped.
        fn add_one(&self) -> impl Drop {
            let done = self.tx.clone();
            DoneSignaler { done }
        }

        /// Resolves once every sender, i.e. every outstanding guard, is gone.
        async fn wait(self) {
            let WaitGroup { tx, rx } = self;
            // Release our own sender so the stream can end when the last guard drops.
            drop(tx);
            rx.collect::<()>().await;
        }
    }

    #[test]
    fn test_task_group_join_waits_for_tasks() {
        let task_count = 20;

        SendExecutor::new(task_count).run(async move {
            let mut group = TaskGroup::new();
            let counter = Arc::new(AtomicU64::new(0));

            for _ in 0..task_count {
                let counter = Arc::clone(&counter);
                group.spawn(async move {
                    counter.fetch_add(1, Ordering::Relaxed);
                });
            }

            group.join().await;
            assert_eq!(counter.load(Ordering::Relaxed), task_count as u64);
        });
    }

    #[test]
    fn test_task_group_empty_join_completes() {
        SendExecutor::new(1).run(async move {
            let group = TaskGroup::new();
            group.join().await;
        });
    }

    #[test]
    fn test_task_group_added_tasks_are_cancelled_on_drop() {
        let wait_group = WaitGroup::new();
        let task_count = 10;

        SendExecutor::new(task_count).run(async move {
            let mut group = TaskGroup::new();
            for _ in 0..task_count {
                let guard = wait_group.add_one();

                // Parks forever; `guard` is only released if the task is cancelled.
                group.spawn(async move {
                    let _guard = guard;
                    std::future::pending::<()>().await;
                });
            }

            drop(group);
            // Resolving here proves every task was cancelled (each guard was dropped).
            wait_group.wait().await;
        });
    }

    #[test]
    fn test_task_group_spawn() {
        let task_count = 3;
        SendExecutor::new(task_count).run(async move {
            let mut group = TaskGroup::new();

            // `spawn` accepts any `Future<Output = ()>`, for example:
            // a plain future value,
            group.spawn(std::future::ready(()));

            // the future produced by an async block,
            group.spawn(async move {
                std::future::ready(()).await;
            });

            // or another Task.
            group.spawn(Task::spawn(std::future::ready(())));

            group.join().await;
        });
    }
}