async_lock/
barrier.rs

1use event_listener::Event;
2
3use crate::Mutex;
4
5/// A counter to synchronize multiple tasks at the same time.
6#[derive(Debug)]
7pub struct Barrier {
8    n: usize,
9    state: Mutex<State>,
10    event: Event,
11}
12
13#[derive(Debug)]
14struct State {
15    count: usize,
16    generation_id: u64,
17}
18
19impl Barrier {
20    /// Creates a barrier that can block the given number of tasks.
21    ///
22    /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks
23    /// at once when the `n`th task calls [`wait()`].
24    ///
25    /// [`wait()`]: `Barrier::wait()`
26    ///
27    /// # Examples
28    ///
29    /// ```
30    /// use async_lock::Barrier;
31    ///
32    /// let barrier = Barrier::new(5);
33    /// ```
34    pub const fn new(n: usize) -> Barrier {
35        Barrier {
36            n,
37            state: Mutex::new(State {
38                count: 0,
39                generation_id: 0,
40            }),
41            event: Event::new(),
42        }
43    }
44
45    /// Blocks the current task until all tasks reach this point.
46    ///
47    /// Barriers are reusable after all tasks have synchronized, and can be used continuously.
48    ///
49    /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the
50    /// last task to call this method.
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// use async_lock::Barrier;
56    /// use futures_lite::future;
57    /// use std::sync::Arc;
58    /// use std::thread;
59    ///
60    /// let barrier = Arc::new(Barrier::new(5));
61    ///
62    /// for _ in 0..5 {
63    ///     let b = barrier.clone();
64    ///     thread::spawn(move || {
65    ///         future::block_on(async {
66    ///             // The same messages will be printed together.
67    ///             // There will NOT be interleaving of "before" and "after".
68    ///             println!("before wait");
69    ///             b.wait().await;
70    ///             println!("after wait");
71    ///         });
72    ///     });
73    /// }
74    /// ```
75    pub async fn wait(&self) -> BarrierWaitResult {
76        let mut state = self.state.lock().await;
77        let local_gen = state.generation_id;
78        state.count += 1;
79
80        if state.count < self.n {
81            while local_gen == state.generation_id && state.count < self.n {
82                let listener = self.event.listen();
83                drop(state);
84                listener.await;
85                state = self.state.lock().await;
86            }
87            BarrierWaitResult { is_leader: false }
88        } else {
89            state.count = 0;
90            state.generation_id = state.generation_id.wrapping_add(1);
91            self.event.notify(std::usize::MAX);
92            BarrierWaitResult { is_leader: true }
93        }
94    }
95}
96
/// Returned by [`Barrier::wait()`] when all tasks have called it.
///
/// The result records whether the receiving task was the *leader*: the task whose
/// arrival released the barrier. Exactly one task per round receives a leader result.
///
/// # Examples
///
/// ```
/// # futures_lite::future::block_on(async {
/// use async_lock::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// assert!(barrier_wait_result.is_leader());
/// # });
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BarrierWaitResult {
    /// `true` if this task's arrival released the barrier.
    is_leader: bool,
}
113
114impl BarrierWaitResult {
115    /// Returns `true` if this task was the last to call to [`Barrier::wait()`].
116    ///
117    /// # Examples
118    ///
119    /// ```
120    /// # futures_lite::future::block_on(async {
121    /// use async_lock::Barrier;
122    /// use futures_lite::future;
123    ///
124    /// let barrier = Barrier::new(2);
125    /// let (a, b) = future::zip(barrier.wait(), barrier.wait()).await;
126    /// assert_eq!(a.is_leader(), false);
127    /// assert_eq!(b.is_leader(), true);
128    /// # });
129    /// ```
130    pub fn is_leader(&self) -> bool {
131        self.is_leader
132    }
133}