fuchsia_sync/
condvar.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::MutexGuard;
6
7use std::sync::atomic::Ordering;
8use std::time::Duration;
9
10/// A [condition variable][wikipedia] that integrates with [`fuchsia_sync::Mutex`].
11///
12/// [wikipedia]: https://en.wikipedia.org/wiki/Monitor_(synchronization)#Condition_variables
13pub struct Condvar {
14    /// Incremented by 1 on each notification.
15    inner: zx::Futex,
16}
17
18impl Condvar {
19    pub const fn new() -> Self {
20        Self { inner: zx::Futex::new(0) }
21    }
22
23    pub fn notify_one(&self) {
24        // Relaxed because the futex operation synchronizes.
25        self.inner.fetch_add(1, Ordering::Relaxed);
26        self.inner.wake_single_owner();
27    }
28
29    pub fn notify_all(&self) {
30        // Relaxed because the futex operation synchronizes.
31        self.inner.fetch_add(1, Ordering::Relaxed);
32        self.inner.wake_all();
33    }
34
35    pub fn wait<T: ?Sized>(&self, guard: &mut MutexGuard<'_, T>) {
36        assert!(
37            !self.wait_inner(guard, zx::MonotonicInstant::INFINITE).timed_out,
38            "an infinite wait should not timeout"
39        );
40    }
41
42    pub fn wait_while<'a, T: ?Sized, F>(&self, guard: &mut MutexGuard<'a, T>, mut condition: F)
43    where
44        F: FnMut(&mut T) -> bool,
45    {
46        while condition(&mut *guard) {
47            self.wait(guard);
48        }
49    }
50
51    pub fn wait_for<T: ?Sized>(
52        &self,
53        guard: &mut MutexGuard<'_, T>,
54        timeout: Duration,
55    ) -> WaitTimeoutResult {
56        self.wait_inner(guard, zx::MonotonicInstant::after(timeout.into()))
57    }
58
59    pub fn wait_while_for<'a, T: ?Sized, F>(
60        &self,
61        guard: &mut MutexGuard<'a, T>,
62        mut condition: F,
63        timeout: Duration,
64    ) -> WaitTimeoutResult
65    where
66        F: FnMut(&mut T) -> bool,
67    {
68        let mut result = WaitTimeoutResult { timed_out: false };
69
70        while !result.timed_out() && condition(&mut *guard) {
71            result = self.wait_for(guard, timeout);
72        }
73
74        result
75    }
76
77    fn wait_inner<T: ?Sized>(
78        &self,
79        guard: &mut MutexGuard<'_, T>,
80        deadline: zx::MonotonicInstant,
81    ) -> WaitTimeoutResult {
82        // Relaxed because the futex and mutex operations synchronize.
83        let current = self.inner.load(Ordering::Relaxed);
84        MutexGuard::unlocked(guard, || {
85            match self.inner.wait(current, None, deadline) {
86                // The count only goes up. If `current` isn't the current value, a notification
87                // was received in between reading `current` and waiting.
88                Ok(()) | Err(zx::Status::BAD_STATE) => WaitTimeoutResult { timed_out: false },
89                Err(zx::Status::TIMED_OUT) => WaitTimeoutResult { timed_out: true },
90                Err(e) => panic!("unexpected wait error {e:?}"),
91            }
92        })
93    }
94}
95
/// The outcome of a timed wait on a [`Condvar`]: whether the wait ended because its deadline
/// elapsed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct WaitTimeoutResult {
    /// `true` if the wait returned because its deadline passed.
    timed_out: bool,
}
100
101impl WaitTimeoutResult {
102    pub fn timed_out(&self) -> bool {
103        self.timed_out
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Mutex;

    #[test]
    fn notify_one_works() {
        let mutex = Mutex::new(());
        let condvar = Condvar::new();
        crossbeam::thread::scope(|s| {
            let mut locked = mutex.lock();
            s.spawn(|_| {
                // With the lock already held, this won't return until the below wait starts.
                let _locked = mutex.lock();
                condvar.notify_one();
            });

            // This will hang forever (and time out the test infra) if notification doesn't work.
            condvar.wait(&mut locked);
        })
        .unwrap();
    }

    #[test]
    fn notify_all_works() {
        let num_threads = 10;
        let count = Mutex::new(0);
        let condvar = Condvar::new();
        let (send, recv) = std::sync::mpsc::channel();

        crossbeam::thread::scope(|s| {
            for _ in 0..num_threads {
                s.spawn(|_| {
                    let mut count = count.lock();
                    *count += 1;
                    if *count == num_threads {
                        // Notify the main thread that the last thread has acquired the lock.
                        send.send(()).unwrap();
                    }
                    while *count != 0 {
                        condvar.wait(&mut count);
                    }
                });
            }

            // Wait for all threads to have started waiting on their condvar.
            recv.recv().unwrap();

            let mut count = count.lock();
            *count = 0;
            condvar.notify_all();
            drop(count);

            // The crossbeam scope will now wait for all of the spawned threads to observe count=0.
        })
        .unwrap();
    }

    #[test]
    fn wait_while_works() {
        let pending = Mutex::new(true);
        let condvar = Condvar::new();

        crossbeam::thread::scope(|s| {
            let mut locked_pending = pending.lock();

            s.spawn(|_| {
                // With the lock already held, this won't return until the below wait starts.
                let mut locked_pending = pending.lock();
                *locked_pending = false;
                condvar.notify_one();
            });

            // Wait for as long as the work is still pending. This will hang forever (and time
            // out the test infra) if the notification doesn't work or if the condition never
            // evaluates to false.
            condvar.wait_while(&mut locked_pending, |pending| *pending);
            assert!(!*locked_pending, "wait_while returned while still pending");
        })
        .unwrap();
    }

    #[test]
    fn wait_for_times_out() {
        let mutex = Mutex::new(());
        let condvar = Condvar::new();

        let mut locked = mutex.lock();

        // Account for possible spurious wakeups.
        loop {
            if condvar.wait_for(&mut locked, std::time::Duration::from_secs(1)).timed_out() {
                break;
            }
        }
    }

    #[test]
    fn wait_while_for_times_out() {
        let mutex = Mutex::new(());
        let condvar = Condvar::new();

        let mut locked = mutex.lock();

        let result =
            condvar.wait_while_for(&mut locked, |_value| true, std::time::Duration::from_secs(1));

        assert!(result.timed_out());
    }
}