fuchsia_sync/
condvar.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::MutexGuard;
6use std::sync::atomic::Ordering;
7use std::time::Duration;
8
9#[cfg(not(target_os = "fuchsia"))]
10use parking_lot_core::{
11    DEFAULT_PARK_TOKEN, DEFAULT_UNPARK_TOKEN, ParkResult, park, unpark_all, unpark_one,
12};
13#[cfg(not(target_os = "fuchsia"))]
14use std::sync::atomic::AtomicU32;
15#[cfg(not(target_os = "fuchsia"))]
16use std::time::Instant;
17
18/// A [condition variable][wikipedia] that integrates with [`fuchsia_sync::Mutex`].
19///
20/// [wikipedia]: https://en.wikipedia.org/wiki/Monitor_(synchronization)#Condition_variables
21pub struct Condvar {
22    /// Incremented by 1 on each notification.
23    inner: Futex,
24}
25
26impl Condvar {
27    pub const fn new() -> Self {
28        Self { inner: Futex::new(0) }
29    }
30
31    pub fn notify_one(&self) {
32        // Relaxed because the futex operation synchronizes.
33        self.inner.fetch_add(1, Ordering::Relaxed);
34        self.inner.wake_one();
35    }
36
37    pub fn notify_all(&self) {
38        // Relaxed because the futex operation synchronizes.
39        self.inner.fetch_add(1, Ordering::Relaxed);
40        self.inner.wake_all();
41    }
42
43    pub fn wait<T: ?Sized>(&self, guard: &mut MutexGuard<'_, T>) {
44        assert!(!self.wait_inner(guard, None).timed_out, "an infinite wait should not timeout");
45    }
46
47    pub fn wait_while<'a, T: ?Sized, F>(&self, guard: &mut MutexGuard<'a, T>, mut condition: F)
48    where
49        F: FnMut(&mut T) -> bool,
50    {
51        while condition(&mut *guard) {
52            self.wait(guard);
53        }
54    }
55
56    pub fn wait_for<T: ?Sized>(
57        &self,
58        guard: &mut MutexGuard<'_, T>,
59        timeout: Duration,
60    ) -> WaitTimeoutResult {
61        self.wait_inner(guard, Some(timeout))
62    }
63
64    pub fn wait_while_for<'a, T: ?Sized, F>(
65        &self,
66        guard: &mut MutexGuard<'a, T>,
67        mut condition: F,
68        timeout: Duration,
69    ) -> WaitTimeoutResult
70    where
71        F: FnMut(&mut T) -> bool,
72    {
73        let mut result = WaitTimeoutResult { timed_out: false };
74
75        while !result.timed_out() && condition(&mut *guard) {
76            result = self.wait_for(guard, timeout);
77        }
78
79        result
80    }
81
82    fn wait_inner<T: ?Sized>(
83        &self,
84        guard: &mut MutexGuard<'_, T>,
85        timeout: Option<Duration>,
86    ) -> WaitTimeoutResult {
87        // Relaxed because the futex and mutex operations synchronize.
88        let current = self.inner.load(Ordering::Relaxed);
89        MutexGuard::unlocked(guard, || self.inner.wait(current, timeout))
90    }
91}
92
/// Outcome of a timed wait on a [`Condvar`]: records whether the wait ended because its
/// timeout elapsed.
///
/// This is a plain one-flag value, so it is `Copy` and fully comparable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct WaitTimeoutResult {
    /// `true` when the wait returned because the timeout elapsed.
    timed_out: bool,
}
97
98impl WaitTimeoutResult {
99    pub fn timed_out(&self) -> bool {
100        self.timed_out
101    }
102}
103
/// Futex used when building for Fuchsia: a thin wrapper around [`zx::Futex`] that exposes
/// the stored value as a `u32`.
#[cfg(target_os = "fuchsia")]
struct Futex(zx::Futex);
106
#[cfg(target_os = "fuchsia")]
impl Futex {
    /// Creates a futex holding `value` (stored as an `i32` by the underlying `zx::Futex`).
    const fn new(value: u32) -> Self {
        let initial = value as i32;
        Self(zx::Futex::new(initial))
    }

    /// Returns the current value of the futex.
    fn load(&self, order: Ordering) -> u32 {
        let raw = self.0.load(order);
        raw as u32
    }

    /// Adds `value` to the futex with the given ordering and returns the result of the
    /// underlying `fetch_add`, cast to `u32`.
    fn fetch_add(&self, value: u32, order: Ordering) -> u32 {
        let delta = value as i32;
        let returned = self.0.fetch_add(delta, order);
        returned as u32
    }

    /// Wakes one waiter using `wake_single_owner`.
    fn wake_one(&self) {
        self.0.wake_single_owner();
    }

    /// Wakes all waiters.
    fn wake_all(&self) {
        self.0.wake_all();
    }

    /// Waits on the futex, expecting it to still hold `current`, until woken or until
    /// `timeout` (if any) elapses.
    ///
    /// # Panics
    ///
    /// Panics on any wait error other than `BAD_STATE` or `TIMED_OUT`.
    fn wait(&self, current: u32, timeout: Option<Duration>) -> WaitTimeoutResult {
        // Without a timeout the wait has an infinite deadline.
        let deadline = match timeout {
            Some(limit) => zx::MonotonicInstant::after(limit.into()),
            None => zx::MonotonicInstant::INFINITE,
        };

        let timed_out = match self.0.wait(current as i32, None, deadline) {
            // The count only goes up. If `current` isn't the current value, a notification
            // was received in between reading `current` and waiting.
            Ok(()) | Err(zx::Status::BAD_STATE) => false,
            Err(zx::Status::TIMED_OUT) => true,
            Err(e) => panic!("unexpected wait error {e:?}"),
        };

        WaitTimeoutResult { timed_out }
    }
}
143
/// Futex used on every target other than Fuchsia: an atomic counter whose address serves as
/// the parking key for waiters.
#[cfg(not(target_os = "fuchsia"))]
struct Futex(AtomicU32);
146
147#[cfg(not(target_os = "fuchsia"))]
148impl Futex {
149    const fn new(value: u32) -> Self {
150        Self(AtomicU32::new(value))
151    }
152
153    fn load(&self, order: Ordering) -> u32 {
154        self.0.load(order)
155    }
156
157    fn fetch_add(&self, value: u32, order: Ordering) -> u32 {
158        self.0.fetch_add(value, order)
159    }
160
161    fn wake_one(&self) {
162        // SAFETY: the address of `inner` is controlled by us.
163        unsafe {
164            unpark_one(self.0.as_ptr() as usize, |_| DEFAULT_UNPARK_TOKEN);
165        }
166    }
167
168    fn wake_all(&self) {
169        // SAFETY: the address of `inner` is controlled by us.
170        unsafe {
171            unpark_all(self.0.as_ptr() as usize, DEFAULT_UNPARK_TOKEN);
172        }
173    }
174
175    fn wait(&self, current: u32, timeout: Option<Duration>) -> WaitTimeoutResult {
176        let key = self.0.as_ptr() as usize;
177        let deadline = timeout.map(|t| Instant::now() + t);
178
179        // SAFETY: the address of `inner` is controlled by us.
180        let park_result = unsafe {
181            park(
182                key,
183                || self.0.load(Ordering::Relaxed) == current,
184                || {},
185                |_, _| {},
186                DEFAULT_PARK_TOKEN,
187                deadline,
188            )
189        };
190
191        match park_result {
192            ParkResult::Unparked(token) => {
193                assert_eq!(token, DEFAULT_UNPARK_TOKEN);
194                WaitTimeoutResult { timed_out: false }
195            }
196            ParkResult::Invalid => {
197                // The validation function failed, meaning the state changed before we could park.
198                // This counts as a notification.
199                WaitTimeoutResult { timed_out: false }
200            }
201            ParkResult::TimedOut => WaitTimeoutResult { timed_out: true },
202        }
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Mutex;

    /// Timeout for the tests that only need a wait to expire. Kept short so the suite stays
    /// fast; nothing ever notifies in those tests, so a shorter value cannot make them flaky.
    const EXPIRING_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(50);

    /// A waiter blocked in `wait` is released by `notify_one`.
    #[test]
    fn notify_one_works() {
        let mutex = Mutex::new(());
        let condvar = Condvar::new();
        crossbeam::thread::scope(|s| {
            let mut locked = mutex.lock();
            s.spawn(|_| {
                // With the lock already held, this won't return until the below wait starts.
                let _locked = mutex.lock();
                condvar.notify_one();
            });

            // This will hang forever (and time out the test infra) if notification doesn't work.
            condvar.wait(&mut locked);
        })
        .unwrap();
    }

    /// Every waiter blocked in `wait` is released by a single `notify_all`.
    #[test]
    fn notify_all_works() {
        let num_threads = 10;
        let count = Mutex::new(0);
        let condvar = Condvar::new();
        let (send, recv) = std::sync::mpsc::channel();

        crossbeam::thread::scope(|s| {
            for _ in 0..num_threads {
                s.spawn(|_| {
                    let mut count = count.lock();
                    *count += 1;
                    if *count == num_threads {
                        // Notify the main thread that the last thread has acquired the lock.
                        send.send(()).unwrap();
                    }
                    while *count != 0 {
                        condvar.wait(&mut count);
                    }
                });
            }

            // Wait for all threads to have started waiting on their condvar.
            recv.recv().unwrap();

            let mut count = count.lock();
            *count = 0;
            condvar.notify_all();
            drop(count);

            // The crossbeam scope will now wait for all of the spawned threads to observe count=0.
        })
        .unwrap();
    }

    /// `wait_while` keeps waiting while the predicate is true and returns once it turns false.
    #[test]
    fn wait_while_works() {
        let pending = Mutex::new(true);
        let condvar = Condvar::new();

        crossbeam::thread::scope(|s| {
            let mut locked_pending = pending.lock();

            s.spawn(|_| {
                // With the lock already held, this won't return until the below wait starts.
                let mut locked_pending = pending.lock();
                *locked_pending = false;
                condvar.notify_one();
            });

            // Wait for as long as the work is still pending. The predicate must be true on
            // entry, otherwise `wait_while` returns immediately and nothing is exercised.
            //
            // This will hang forever (and time out the test infra) if the notification doesn't
            // work or if the condition never evaluates to false.
            condvar.wait_while(&mut locked_pending, |still_pending| *still_pending);

            assert!(!*locked_pending, "wait_while returned while the condition still held");
        })
        .unwrap();
    }

    /// `wait_for` reports a timeout when nobody notifies.
    #[test]
    fn wait_for_times_out() {
        let mutex = Mutex::new(());
        let condvar = Condvar::new();

        let mut locked = mutex.lock();

        // Account for possible spurious wakeups.
        loop {
            if condvar.wait_for(&mut locked, EXPIRING_TIMEOUT).timed_out() {
                break;
            }
        }
    }

    /// `wait_while_for` reports a timeout when the condition never becomes false.
    #[test]
    fn wait_while_for_times_out() {
        let mutex = Mutex::new(());
        let condvar = Condvar::new();

        let mut locked = mutex.lock();

        let result = condvar.wait_while_for(&mut locked, |_value| true, EXPIRING_TIMEOUT);

        assert!(result.timed_out());
    }
}
313}