fuchsia_async/
condition.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Implements a combined mutex and condition.
//!
//! # Example
//!
//! ```no_run
//!     let condition = Condition::new(0);
//!     condition.when(|state| if *state == 1 { Poll::Ready(()) } else { Poll::Pending }).await;
//!
//!     // Elsewhere...
//!     let mut guard = condition.lock();
//!     *guard = 1;
//!     for waker in guard.drain_wakers() {
//!         waker.wake();
//!     }
//! ```
20
21use std::future::poll_fn;
22use std::marker::PhantomPinned;
23use std::ops::{Deref, DerefMut};
24use std::pin::{pin, Pin};
25use std::ptr::NonNull;
26use std::sync::{Arc, Mutex, MutexGuard};
27use std::task::{Poll, Waker};
28
29/// An async condition which combines a mutex and a condition variable.
30// Condition is implemented as an intrusive doubly linked list.  Typical use should avoid any
31// additional heap allocations after creation, as the nodes of the list are stored as part of the
32// caller's future.
33pub struct Condition<T>(Arc<Mutex<Inner<T>>>);
34
35impl<T> Condition<T> {
36    /// Returns a new condition.
37    pub fn new(data: T) -> Self {
38        Self(Arc::new(Mutex::new(Inner { head: None, count: 0, data })))
39    }
40
41    /// Returns the number of wakers waiting on the condition.
42    pub fn waker_count(&self) -> usize {
43        self.0.lock().unwrap().count
44    }
45
46    /// Same as `Mutex::lock`.
47    pub fn lock(&self) -> ConditionGuard<'_, T> {
48        ConditionGuard(&self.0, self.0.lock().unwrap())
49    }
50
51    /// Returns when `poll` resolves.
52    pub async fn when<R>(&self, poll: impl Fn(&mut T) -> Poll<R>) -> R {
53        let mut entry = WakerEntry::new();
54        entry.list = Some(self.0.clone());
55        let mut entry = pin!(entry);
56        poll_fn(|cx| {
57            let mut guard = self.0.lock().unwrap();
58            // SAFETY: We uphold the pin guarantee.
59            let entry = unsafe { entry.as_mut().get_unchecked_mut() };
60            let result = poll(&mut guard.data);
61            if result.is_pending() {
62                // SAFETY: We set list correctly above.
63                unsafe {
64                    entry.node.add(&mut *guard, cx.waker().clone());
65                }
66            }
67            result
68        })
69        .await
70    }
71}
72
73struct Inner<T> {
74    head: Option<NonNull<Node>>,
75    count: usize,
76    data: T,
77}
78
79// SAFETY: Safe because we always access `head` whilst holding the list lock.
80unsafe impl<T: Send> Send for Inner<T> {}
81
82/// Guard returned by `lock`.
83pub struct ConditionGuard<'a, T>(&'a Arc<Mutex<Inner<T>>>, MutexGuard<'a, Inner<T>>);
84
85impl<'a, T> ConditionGuard<'a, T> {
86    /// Adds the waker entry to the condition's list of wakers.
87    pub fn add_waker(&mut self, waker_entry: Pin<&mut WakerEntry<T>>, waker: Waker) {
88        // SAFETY: We never move the data out.
89        let waker_entry = unsafe { waker_entry.get_unchecked_mut() };
90        waker_entry.list = Some(self.0.clone());
91        // SAFETY: We set list correctly above.
92        unsafe {
93            waker_entry.node.add(&mut *self.1, waker);
94        }
95    }
96
97    /// Returns an iterator that will drain all wakers.  Whilst the drainer exists, a lock is held
98    /// which will prevent new wakers from being added to the list, so depending on your use case,
99    /// you might wish to collect the wakers before calling `wake` on each waker.  NOTE: If the
100    /// drainer is dropped, this will *not* drain elements not visited.
101    pub fn drain_wakers<'b>(&'b mut self) -> Drainer<'b, 'a, T> {
102        Drainer(self)
103    }
104
105    /// Returns the number of wakers registered with the condition.
106    pub fn waker_count(&self) -> usize {
107        self.1.count
108    }
109}
110
111impl<T> Deref for ConditionGuard<'_, T> {
112    type Target = T;
113
114    fn deref(&self) -> &Self::Target {
115        &self.1.data
116    }
117}
118
119impl<T> DerefMut for ConditionGuard<'_, T> {
120    fn deref_mut(&mut self) -> &mut Self::Target {
121        &mut self.1.data
122    }
123}
124
125/// A waker entry that can be added to a list.
126pub struct WakerEntry<T> {
127    list: Option<Arc<Mutex<Inner<T>>>>,
128    node: Node,
129}
130
131impl<T> WakerEntry<T> {
132    /// Returns a new entry.
133    pub fn new() -> Self {
134        Self {
135            list: None,
136            node: Node { next: None, prev: None, waker: None, _pinned: PhantomPinned },
137        }
138    }
139}
140
141impl<T> Drop for WakerEntry<T> {
142    fn drop(&mut self) {
143        if let Some(list) = &self.list {
144            self.node.remove(&mut *list.lock().unwrap());
145        }
146    }
147}
148
/// A single link in a condition's intrusive list of wakers.
///
/// Every field must only be read or written whilst holding the mutex of the list that the owning
/// `WakerEntry` is associated with.
struct Node {
    /// The registered waker; `Some` exactly while this node is linked into a list.
    waker: Option<Waker>,
    /// The previous node, or `None` if this node is the head (or is not linked).
    prev: Option<NonNull<Node>>,
    /// The next node, or `None` if this node is the tail (or is not linked).
    next: Option<NonNull<Node>>,
    /// Neighbours refer to this node by raw pointer, so it must not move once linked.
    _pinned: PhantomPinned,
}

// SAFETY: The raw `prev`/`next` pointers (and `waker`) are only ever touched whilst holding the
// list lock, so moving a `Node` to another thread cannot race with other users of the list.
unsafe impl Send for Node {}
159
impl Node {
    /// Registers `waker` on this node, linking the node at the head of `inner`'s list if it is not
    /// already linked.
    ///
    /// A node is linked exactly when `self.waker` is `Some`.  On an already-linked node this only
    /// replaces the stored waker and leaves the list untouched.  Otherwise the node becomes the
    /// new head and `inner.count` is incremented.
    ///
    /// # Safety
    ///
    /// The `WakerEntry` containing this node *must* have `list` set to the list that `inner`
    /// belongs to, so that dropping the entry unlinks the node from the same list.  If the node is
    /// already linked, it must be linked into `inner`'s list.  The node must not move while it is
    /// linked, because `inner.head` and its neighbours hold raw pointers to it.
    unsafe fn add<T>(&mut self, inner: &mut Inner<T>, waker: Waker) {
        if self.waker.is_none() {
            // Not linked yet: push this node onto the front of the list.
            self.prev = None;
            self.next = inner.head;
            inner.head = Some(self.into());
            if let Some(mut next) = self.next {
                // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set
                // correctly above.
                unsafe {
                    next.as_mut().prev = Some(self.into());
                }
            }
            inner.count += 1;
        }
        // Setting the waker last also marks the node as linked.
        self.waker = Some(waker);
    }

    /// Unlinks this node from `inner`'s list and returns its waker.
    ///
    /// Returns `None` and leaves everything untouched if the node is not linked.  `inner` must be
    /// the list this node is linked into.  The callers here ensure that: `WakerEntry::drop` passes
    /// the entry's `list`, and `Drainer` passes the list whose `head` it took the node from.
    // NOTE(review): this is a safe fn, but it dereferences the raw neighbour pointers; its
    // soundness relies on the invariants documented on `add`.
    fn remove<T>(&mut self, inner: &mut Inner<T>) -> Option<Waker> {
        if self.waker.is_none() {
            debug_assert!(self.prev.is_none() && self.next.is_none());
            return None;
        }
        if let Some(mut next) = self.next {
            // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set correctly.
            unsafe { next.as_mut().prev = self.prev };
        }
        if let Some(mut prev) = self.prev {
            // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set correctly.
            unsafe { prev.as_mut().next = self.next };
        } else {
            // No predecessor means this node was the head.
            inner.head = self.next;
        }
        self.prev = None;
        self.next = None;
        inner.count -= 1;
        // Taking the waker marks the node as unlinked.
        self.waker.take()
    }
}
202
203/// An iterator that will drain waiters.
204pub struct Drainer<'a, 'b, T>(&'a mut ConditionGuard<'b, T>);
205
206impl<T> Iterator for Drainer<'_, '_, T> {
207    type Item = Waker;
208    fn next(&mut self) -> Option<Self::Item> {
209        if let Some(mut head) = self.0 .1.head {
210            // SAFETY: Safe because we have exclusive access to `Inner` and `head is set correctly.
211            unsafe { head.as_mut().remove(&mut self.0 .1) }
212        } else {
213            None
214        }
215    }
216
217    fn size_hint(&self) -> (usize, Option<usize>) {
218        (self.0 .1.count, Some(self.0 .1.count))
219    }
220}
221
222impl<T> ExactSizeIterator for Drainer<'_, '_, T> {
223    fn len(&self) -> usize {
224        self.0 .1.count
225    }
226}
227
#[cfg(all(target_os = "fuchsia", test))]
mod tests {
    use super::{Condition, WakerEntry};
    use crate::TestExecutor;
    use futures::stream::FuturesUnordered;
    use futures::task::noop_waker;
    use futures::StreamExt;
    use std::pin::pin;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::task::Poll;

    #[test]
    fn test_condition_can_wake_multiple_wakers() {
        let mut executor = TestExecutor::new();
        let condition = Condition::new(());

        const COUNT: u64 = 10;

        let counter = AtomicU64::new(0);

        // Use FuturesUnordered so that futures are only polled when explicitly woken.
        let mut futures = FuturesUnordered::new();

        for _ in 0..COUNT {
            futures.push(condition.when(|()| {
                if counter.fetch_add(1, Ordering::Relaxed) >= COUNT {
                    Poll::Ready(())
                } else {
                    Poll::Pending
                }
            }));
        }

        assert!(executor.run_until_stalled(&mut futures.next()).is_pending());

        assert_eq!(counter.load(Ordering::Relaxed), COUNT);
        assert_eq!(condition.waker_count(), COUNT as usize);

        {
            let mut guard = condition.lock();
            let drainer = guard.drain_wakers();
            assert_eq!(drainer.len(), COUNT as usize);
            for waker in drainer {
                waker.wake();
            }
        }

        assert!(executor.run_until_stalled(&mut futures.collect::<Vec<_>>()).is_ready());
        assert_eq!(counter.load(Ordering::Relaxed), COUNT * 2);
    }

    #[test]
    fn test_dropping_waker_entry_removes_from_list() {
        let condition = Condition::new(());

        let entry1 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry1, noop_waker());

        {
            let entry2 = pin!(WakerEntry::new());
            condition.lock().add_waker(entry2, noop_waker());

            assert_eq!(condition.waker_count(), 2);
        }

        assert_eq!(condition.waker_count(), 1);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 1);
        }

        assert_eq!(condition.waker_count(), 0);

        let entry3 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry3, noop_waker());

        assert_eq!(condition.waker_count(), 1);
    }

    #[test]
    fn test_waker_can_be_added_multiple_times() {
        let condition = Condition::new(());

        let mut entry1 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry1.as_mut(), noop_waker());

        let mut entry2 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry2.as_mut(), noop_waker());

        assert_eq!(condition.waker_count(), 2);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);

        condition.lock().add_waker(entry1, noop_waker());
        condition.lock().add_waker(entry2, noop_waker());

        assert_eq!(condition.waker_count(), 2);

        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);
    }
}