fuchsia_async/
condition.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Implements a combined mutex and condition.
//!
//! Waiters are not woken automatically when the protected data changes: the code that modifies
//! the data is responsible for draining the registered wakers and waking them, as shown below.
//!
//! # Example
//!
//! ```no_run
//!     let condition = Condition::new(0);
//!     condition.when(|state| if *state == 1 { Poll::Ready(()) } else { Poll::Pending }).await;
//!
//!     // Elsewhere...
//!     let mut guard = condition.lock();
//!     *guard = 1;
//!     for waker in guard.drain_wakers() {
//!         waker.wake();
//!     }
//! ```
20
21use fuchsia_sync::{Mutex, MutexGuard};
22use std::future::poll_fn;
23use std::marker::PhantomPinned;
24use std::ops::{Deref, DerefMut};
25use std::pin::{Pin, pin};
26use std::ptr::NonNull;
27use std::sync::Arc;
28use std::task::{Poll, Waker};
29
30/// An async condition which combines a mutex and a condition variable.
31// Condition is implemented as an intrusive doubly linked list.  Typical use should avoid any
32// additional heap allocations after creation, as the nodes of the list are stored as part of the
33// caller's future.
34#[derive(Default)]
35pub struct Condition<T>(Arc<Mutex<Inner<T>>>);
36
37impl<T> Condition<T> {
38    /// Returns a new condition.
39    pub fn new(data: T) -> Self {
40        Self(Arc::new(Mutex::new(Inner { head: None, count: 0, data })))
41    }
42
43    /// Returns the number of wakers waiting on the condition.
44    pub fn waker_count(&self) -> usize {
45        self.0.lock().count
46    }
47
48    /// Same as `Mutex::lock`.
49    pub fn lock(&self) -> ConditionGuard<'_, T> {
50        ConditionGuard(self.0.lock())
51    }
52
53    /// Returns when `poll` resolves.
54    pub async fn when<R>(&self, poll: impl Fn(&mut T) -> Poll<R>) -> R {
55        let mut entry = pin!(self.waker_entry());
56        poll_fn(|cx| {
57            let mut guard = self.0.lock();
58            // SAFETY: We uphold the pin guarantee.
59            let entry = unsafe { entry.as_mut().get_unchecked_mut() };
60            let result = poll(&mut guard.data);
61            if result.is_pending() {
62                // SAFETY: We set list correctly above.
63                unsafe {
64                    entry.node.add(&mut *guard, cx.waker().clone());
65                }
66            }
67            result
68        })
69        .await
70    }
71
72    /// Returns a new waker entry.
73    pub fn waker_entry(&self) -> WakerEntry<T> {
74        WakerEntry {
75            list: self.0.clone(),
76            node: Node { next: None, prev: None, waker: None, _pinned: PhantomPinned },
77        }
78    }
79}
80
// State shared by a `Condition` and its `WakerEntry`s; always accessed through the condition's
// mutex (`Arc<Mutex<Inner<T>>>`).
#[derive(Default)]
struct Inner<T> {
    // First node of the intrusive list of registered waker entries, or `None` if the list is empty.
    head: Option<NonNull<Node>>,
    // Number of nodes currently linked into the list.
    count: usize,
    // The user data protected by the condition.
    data: T,
}

// `NonNull` is not `Send`, so `Send` has to be implemented manually.
// SAFETY: Safe because we always access `head` whilst holding the list lock.
unsafe impl<T: Send> Send for Inner<T> {}
90
91/// Guard returned by `lock`.
92pub struct ConditionGuard<'a, T>(MutexGuard<'a, Inner<T>>);
93
94impl<'a, T> ConditionGuard<'a, T> {
95    /// Adds the waker entry to the condition's list of wakers.
96    pub fn add_waker(&mut self, waker_entry: Pin<&mut WakerEntry<T>>, waker: Waker) {
97        // SAFETY: We never move the data out.
98        let waker_entry = unsafe { waker_entry.get_unchecked_mut() };
99        // SAFETY: We set list correctly above.
100        unsafe {
101            waker_entry.node.add(&mut *self.0, waker);
102        }
103    }
104
105    /// Returns an iterator that will drain all wakers.  Whilst the drainer exists, a lock is held
106    /// which will prevent new wakers from being added to the list, so depending on your use case,
107    /// you might wish to collect the wakers before calling `wake` on each waker.  NOTE: If the
108    /// drainer is dropped, this will *not* drain elements not visited.
109    pub fn drain_wakers<'b>(&'b mut self) -> Drainer<'b, 'a, T> {
110        Drainer(self)
111    }
112
113    /// Returns the number of wakers registered with the condition.
114    pub fn waker_count(&self) -> usize {
115        self.0.count
116    }
117}
118
119impl<T> Deref for ConditionGuard<'_, T> {
120    type Target = T;
121
122    fn deref(&self) -> &Self::Target {
123        &self.0.data
124    }
125}
126
127impl<T> DerefMut for ConditionGuard<'_, T> {
128    fn deref_mut(&mut self) -> &mut Self::Target {
129        &mut self.0.data
130    }
131}
132
133/// A waker entry that can be added to a list.
134pub struct WakerEntry<T> {
135    list: Arc<Mutex<Inner<T>>>,
136    node: Node,
137}
138
139impl<T> Drop for WakerEntry<T> {
140    fn drop(&mut self) {
141        self.node.remove(&mut *self.list.lock());
142    }
143}
144
// A link in a condition's intrusive doubly linked list of waiters.
// The members here must only be accessed whilst holding the mutex on the list.
// Invariant (maintained by `Node::add` / `Node::remove`): the node is linked into the list if and
// only if `waker` is `Some`.
struct Node {
    // Next node in the list, or `None` for the last node or an unlinked node.
    next: Option<NonNull<Node>>,
    // Previous node in the list, or `None` for the head or an unlinked node.
    prev: Option<NonNull<Node>>,
    // The waker to wake; `Some` exactly while the node is linked into the list.
    waker: Option<Waker>,
    // Neighbouring nodes and the list head hold raw pointers to this node, so it must not move.
    _pinned: PhantomPinned,
}

// SAFETY: Safe because we always access all members of `Node` whilst holding the list lock.
unsafe impl Send for Node {}
155
impl Node {
    // Links this node at the front of `inner`'s list (unless it is already linked) and sets its
    // waker, replacing any previously registered waker.
    //
    // # Safety
    //
    // The waker *must* have `list` set correctly.
    // NOTE(review): i.e. `inner` must be the locked `Inner` of the `WakerEntry::list` that owns
    // this node, and the node must stay pinned for as long as it is linked.
    unsafe fn add<T>(&mut self, inner: &mut Inner<T>, waker: Waker) {
        // `waker.is_none()` means the node is not currently linked (see the invariant on `Node`).
        if self.waker.is_none() {
            // Push onto the front of the list.
            self.prev = None;
            self.next = inner.head;
            inner.head = Some(self.into());
            if let Some(mut next) = self.next {
                // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set
                // correctly above.
                unsafe {
                    next.as_mut().prev = Some(self.into());
                }
            }
            inner.count += 1;
        }
        // An already-linked node keeps its position; only its waker is updated.
        self.waker = Some(waker);
    }

    // Unlinks this node from `inner`'s list and returns its waker. Returns `None`, without
    // touching the list, if the node was not linked.
    // NOTE(review): although this is a safe fn, it dereferences the neighbouring nodes' pointers,
    // so `inner` must be the list this node was added to (the same requirement as `add`).
    fn remove<T>(&mut self, inner: &mut Inner<T>) -> Option<Waker> {
        if self.waker.is_none() {
            debug_assert!(self.prev.is_none() && self.next.is_none());
            return None;
        }
        if let Some(mut next) = self.next {
            // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set correctly.
            unsafe { next.as_mut().prev = self.prev };
        }
        if let Some(mut prev) = self.prev {
            // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set correctly.
            unsafe { prev.as_mut().next = self.next };
        } else {
            // No predecessor: this node was the head.
            inner.head = self.next;
        }
        self.prev = None;
        self.next = None;
        inner.count -= 1;
        // Leaves `waker` as `None`, marking the node as unlinked.
        self.waker.take()
    }
}
198
199/// An iterator that will drain waiters.
200pub struct Drainer<'a, 'b, T>(&'a mut ConditionGuard<'b, T>);
201
202impl<T> Iterator for Drainer<'_, '_, T> {
203    type Item = Waker;
204    fn next(&mut self) -> Option<Self::Item> {
205        if let Some(mut head) = self.0.0.head {
206            // SAFETY: Safe because we have exclusive access to `Inner` and `head is set correctly.
207            unsafe { head.as_mut().remove(&mut self.0.0) }
208        } else {
209            None
210        }
211    }
212
213    fn size_hint(&self) -> (usize, Option<usize>) {
214        (self.0.0.count, Some(self.0.0.count))
215    }
216}
217
218impl<T> ExactSizeIterator for Drainer<'_, '_, T> {
219    fn len(&self) -> usize {
220        self.0.0.count
221    }
222}
223
#[cfg(all(target_os = "fuchsia", test))]
mod tests {
    use super::Condition;
    use crate::TestExecutor;
    use futures::StreamExt;
    use futures::stream::FuturesUnordered;
    use futures::task::noop_waker;
    use std::pin::pin;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::task::Poll;

    // Registers several waiters and checks that they all stay pending. Then drains and wakes them
    // all, and checks that each completes after being polled exactly once more.
    #[test]
    fn test_condition_can_wake_multiple_wakers() {
        let mut executor = TestExecutor::new();
        let condition = Condition::new(());

        const COUNT: u64 = 10;

        // Counts every invocation of the `when` closures across all futures.
        let counter = AtomicU64::new(0);

        // Use FuturesUnordered so that futures are only polled when explicitly woken.
        let mut futures = FuturesUnordered::new();

        for _ in 0..COUNT {
            // Pending on each future's first poll (counter < COUNT), ready on its second.
            futures.push(condition.when(|()| {
                if counter.fetch_add(1, Ordering::Relaxed) >= COUNT {
                    Poll::Ready(())
                } else {
                    Poll::Pending
                }
            }));
        }

        assert!(executor.run_until_stalled(&mut futures.next()).is_pending());

        assert_eq!(counter.load(Ordering::Relaxed), COUNT);
        assert_eq!(condition.waker_count(), COUNT as usize);

        {
            let mut guard = condition.lock();
            let drainer = guard.drain_wakers();
            assert_eq!(drainer.len(), COUNT as usize);
            for waker in drainer {
                waker.wake();
            }
        }

        assert!(executor.run_until_stalled(&mut futures.collect::<Vec<_>>()).is_ready());
        assert_eq!(counter.load(Ordering::Relaxed), COUNT * 2);
    }

    // Dropping a registered entry must unlink it. Draining must empty the list, and new entries
    // can be registered afterwards.
    #[test]
    fn test_dropping_waker_entry_removes_from_list() {
        let condition = Condition::new(());

        let entry1 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry1, noop_waker());

        {
            let entry2 = pin!(condition.waker_entry());
            condition.lock().add_waker(entry2, noop_waker());

            assert_eq!(condition.waker_count(), 2);
        }

        // `entry2` has been dropped, which removed it from the list.
        assert_eq!(condition.waker_count(), 1);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 1);
        }

        assert_eq!(condition.waker_count(), 0);

        let entry3 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry3, noop_waker());

        assert_eq!(condition.waker_count(), 1);
    }

    // An entry that has been drained can be registered again.
    #[test]
    fn test_waker_can_be_added_multiple_times() {
        let condition = Condition::new(());

        let mut entry1 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry1.as_mut(), noop_waker());

        let mut entry2 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry2.as_mut(), noop_waker());

        assert_eq!(condition.waker_count(), 2);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);

        // Re-register the same (still pinned) entries after they were drained.
        condition.lock().add_waker(entry1, noop_waker());
        condition.lock().add_waker(entry2, noop_waker());

        assert_eq!(condition.waker_count(), 2);

        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);
    }
}