async_lock/barrier.rs
1use event_listener::Event;
2
3use crate::Mutex;
4
5/// A counter to synchronize multiple tasks at the same time.
6#[derive(Debug)]
7pub struct Barrier {
8 n: usize,
9 state: Mutex<State>,
10 event: Event,
11}
12
/// Mutable bookkeeping shared by every task using the same barrier.
#[derive(Debug)]
struct State {
    /// How many tasks have arrived in the current round.
    count: usize,
    /// Round number; bumped (with wrap-around) each time the barrier releases its waiters.
    generation_id: u64,
}
18
19impl Barrier {
20 /// Creates a barrier that can block the given number of tasks.
21 ///
22 /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks
23 /// at once when the `n`th task calls [`wait()`].
24 ///
25 /// [`wait()`]: `Barrier::wait()`
26 ///
27 /// # Examples
28 ///
29 /// ```
30 /// use async_lock::Barrier;
31 ///
32 /// let barrier = Barrier::new(5);
33 /// ```
34 pub const fn new(n: usize) -> Barrier {
35 Barrier {
36 n,
37 state: Mutex::new(State {
38 count: 0,
39 generation_id: 0,
40 }),
41 event: Event::new(),
42 }
43 }
44
45 /// Blocks the current task until all tasks reach this point.
46 ///
47 /// Barriers are reusable after all tasks have synchronized, and can be used continuously.
48 ///
49 /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the
50 /// last task to call this method.
51 ///
52 /// # Examples
53 ///
54 /// ```
55 /// use async_lock::Barrier;
56 /// use futures_lite::future;
57 /// use std::sync::Arc;
58 /// use std::thread;
59 ///
60 /// let barrier = Arc::new(Barrier::new(5));
61 ///
62 /// for _ in 0..5 {
63 /// let b = barrier.clone();
64 /// thread::spawn(move || {
65 /// future::block_on(async {
66 /// // The same messages will be printed together.
67 /// // There will NOT be interleaving of "before" and "after".
68 /// println!("before wait");
69 /// b.wait().await;
70 /// println!("after wait");
71 /// });
72 /// });
73 /// }
74 /// ```
75 pub async fn wait(&self) -> BarrierWaitResult {
76 let mut state = self.state.lock().await;
77 let local_gen = state.generation_id;
78 state.count += 1;
79
80 if state.count < self.n {
81 while local_gen == state.generation_id && state.count < self.n {
82 let listener = self.event.listen();
83 drop(state);
84 listener.await;
85 state = self.state.lock().await;
86 }
87 BarrierWaitResult { is_leader: false }
88 } else {
89 state.count = 0;
90 state.generation_id = state.generation_id.wrapping_add(1);
91 self.event.notify(std::usize::MAX);
92 BarrierWaitResult { is_leader: true }
93 }
94 }
95}
96
/// The value produced by [`Barrier::wait()`] once every task of the round has arrived.
///
/// # Examples
///
/// ```
/// # futures_lite::future::block_on(async {
/// use async_lock::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// # });
/// ```
#[derive(Clone, Debug)]
pub struct BarrierWaitResult {
    /// `true` only for the task whose call completed the round.
    is_leader: bool,
}

impl BarrierWaitResult {
    /// Returns `true` if this task was the last one to call [`Barrier::wait()`] in its round.
    ///
    /// # Examples
    ///
    /// ```
    /// # futures_lite::future::block_on(async {
    /// use async_lock::Barrier;
    /// use futures_lite::future;
    ///
    /// let barrier = Barrier::new(2);
    /// let (a, b) = future::zip(barrier.wait(), barrier.wait()).await;
    /// assert_eq!(a.is_leader(), false);
    /// assert_eq!(b.is_leader(), true);
    /// # });
    /// ```
    pub fn is_leader(&self) -> bool {
        let Self { is_leader } = *self;
        is_leader
    }
}