fuchsia_async/
condition.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Implements a combined mutex and condition variable.
//!
//! # Example:
//!
//! ```no_run
//!     let condition = Condition::new(0);
//!     condition.when(|state| if *state == 1 { Poll::Ready(()) } else { Poll::Pending }).await;
//!
//!     // Elsewhere...
//!     let mut guard = condition.lock();
//!     *guard = 1;
//!     for waker in guard.drain_wakers() {
//!         waker.wake();
//!     }
//! ```
20
21use std::future::poll_fn;
22use std::marker::PhantomPinned;
23use std::ops::{Deref, DerefMut};
24use std::pin::{pin, Pin};
25use std::ptr::NonNull;
26use std::sync::{Arc, Mutex, MutexGuard};
27use std::task::{Poll, Waker};
28
29/// An async condition which combines a mutex and a condition variable.
30// Condition is implemented as an intrusive doubly linked list.  Typical use should avoid any
31// additional heap allocations after creation, as the nodes of the list are stored as part of the
32// caller's future.
33pub struct Condition<T>(Arc<Mutex<Inner<T>>>);
34
35impl<T> Condition<T> {
36    /// Returns a new condition.
37    pub fn new(data: T) -> Self {
38        Self(Arc::new(Mutex::new(Inner { head: None, count: 0, data })))
39    }
40
41    /// Returns the number of wakers waiting on the condition.
42    pub fn waker_count(&self) -> usize {
43        self.0.lock().unwrap().count
44    }
45
46    /// Same as `Mutex::lock`.
47    pub fn lock(&self) -> ConditionGuard<'_, T> {
48        ConditionGuard(&self.0, self.0.lock().unwrap())
49    }
50
51    /// Returns when `poll` resolves.
52    pub async fn when<R>(&self, poll: impl Fn(&mut T) -> Poll<R>) -> R {
53        let mut entry = WakerEntry::new();
54        entry.list = Some(self.0.clone());
55        let mut entry = pin!(entry);
56        poll_fn(|cx| {
57            let mut guard = self.0.lock().unwrap();
58            // SAFETY: We uphold the pin guarantee.
59            let entry = unsafe { entry.as_mut().get_unchecked_mut() };
60            let result = poll(&mut guard.data);
61            if result.is_pending() {
62                // SAFETY: We set list correctly above.
63                unsafe {
64                    entry.node.add(&mut *guard, cx.waker().clone());
65                }
66            }
67            result
68        })
69        .await
70    }
71}
72
73struct Inner<T> {
74    head: Option<NonNull<Node>>,
75    count: usize,
76    data: T,
77}
78
79// SAFETY: Safe because we always access `head` whilst holding the list lock.
80unsafe impl<T: Send> Send for Inner<T> {}
81
82/// Guard returned by `lock`.
83pub struct ConditionGuard<'a, T>(&'a Arc<Mutex<Inner<T>>>, MutexGuard<'a, Inner<T>>);
84
85impl<'a, T> ConditionGuard<'a, T> {
86    /// Adds the waker entry to the condition's list of wakers.
87    pub fn add_waker(&mut self, waker_entry: Pin<&mut WakerEntry<T>>, waker: Waker) {
88        // SAFETY: We never move the data out.
89        let waker_entry = unsafe { waker_entry.get_unchecked_mut() };
90        waker_entry.list = Some(self.0.clone());
91        // SAFETY: We set list correctly above.
92        unsafe {
93            waker_entry.node.add(&mut *self.1, waker);
94        }
95    }
96
97    /// Returns an iterator that will drain all wakers.  Whilst the drainer exists, a lock is held
98    /// which will prevent new wakers from being added to the list, so depending on your use case,
99    /// you might wish to collect the wakers before calling `wake` on each waker.  NOTE: If the
100    /// drainer is dropped, this will *not* drain elements not visited.
101    pub fn drain_wakers<'b>(&'b mut self) -> Drainer<'b, 'a, T> {
102        Drainer(self)
103    }
104
105    /// Returns the number of wakers registered with the condition.
106    pub fn waker_count(&self) -> usize {
107        self.1.count
108    }
109}
110
111impl<T> Deref for ConditionGuard<'_, T> {
112    type Target = T;
113
114    fn deref(&self) -> &Self::Target {
115        &self.1.data
116    }
117}
118
119impl<T> DerefMut for ConditionGuard<'_, T> {
120    fn deref_mut(&mut self) -> &mut Self::Target {
121        &mut self.1.data
122    }
123}
124
125/// A waker entry that can be added to a list.
126pub struct WakerEntry<T> {
127    list: Option<Arc<Mutex<Inner<T>>>>,
128    node: Node,
129}
130
131impl<T> WakerEntry<T> {
132    /// Returns a new entry.
133    pub fn new() -> Self {
134        Self {
135            list: None,
136            node: Node { next: None, prev: None, waker: None, _pinned: PhantomPinned },
137        }
138    }
139}
140
141impl<T> Default for WakerEntry<T> {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147impl<T> Drop for WakerEntry<T> {
148    fn drop(&mut self) {
149        if let Some(list) = &self.list {
150            self.node.remove(&mut *list.lock().unwrap());
151        }
152    }
153}
154
// A link in the intrusive, doubly linked list of wakers held by a `Condition`. A node is linked
// into a list exactly when `waker` is `Some`: `add` links only unlinked nodes and then stores a
// waker, and `remove` unlinks only nodes that have a waker and takes it.
//
// The members here must only be accessed whilst holding the mutex on the list.
struct Node {
    // The node after this one in the list.
    next: Option<NonNull<Node>>,
    // The node before this one in the list; `None` when this node is the head.
    prev: Option<NonNull<Node>>,
    // The registered waker; `Some` exactly when the node is linked.
    waker: Option<Waker>,
    // The list stores raw pointers to this node, so it must not move whilst linked.
    _pinned: PhantomPinned,
}

// SAFETY: Safe because we always access all members of `Node` whilst holding the list lock.
unsafe impl Send for Node {}

impl Node {
    // Pushes this node onto the front of `inner`'s list if it is not already linked, then stores
    // `waker`, replacing any previously registered waker.
    //
    // # Safety
    //
    // The `WakerEntry` that contains this node *must* have its `list` field set to the list that
    // `inner` belongs to, so that dropping the entry unlinks the node from the right list. The
    // node must also stay pinned in place for as long as it is linked, because the list keeps
    // raw pointers to it.
    unsafe fn add<T>(&mut self, inner: &mut Inner<T>, waker: Waker) {
        // Only link an unlinked node; an already-linked node just has its waker replaced below.
        if self.waker.is_none() {
            self.prev = None;
            self.next = inner.head;
            inner.head = Some(self.into());
            if let Some(mut next) = self.next {
                // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set
                // correctly above.
                unsafe {
                    next.as_mut().prev = Some(self.into());
                }
            }
            inner.count += 1;
        }
        self.waker = Some(waker);
    }

    // Unlinks this node from `inner`'s list and returns its waker, or returns `None` when the
    // node is not linked.
    //
    // NOTE(review): although this is a safe fn, it dereferences the neighbouring nodes' raw
    // pointers and may rewrite `inner.head`, so `inner` must be the list this node is linked
    // into. Confirm that every caller upholds this.
    fn remove<T>(&mut self, inner: &mut Inner<T>) -> Option<Waker> {
        // An unlinked node has no waker and no neighbours; there is nothing to do.
        if self.waker.is_none() {
            debug_assert!(self.prev.is_none() && self.next.is_none());
            return None;
        }
        if let Some(mut next) = self.next {
            // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set correctly.
            unsafe { next.as_mut().prev = self.prev };
        }
        if let Some(mut prev) = self.prev {
            // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set correctly.
            unsafe { prev.as_mut().next = self.next };
        } else {
            // No predecessor means this node was the head of the list.
            inner.head = self.next;
        }
        self.prev = None;
        self.next = None;
        inner.count -= 1;
        self.waker.take()
    }
}
208
209/// An iterator that will drain waiters.
210pub struct Drainer<'a, 'b, T>(&'a mut ConditionGuard<'b, T>);
211
212impl<T> Iterator for Drainer<'_, '_, T> {
213    type Item = Waker;
214    fn next(&mut self) -> Option<Self::Item> {
215        if let Some(mut head) = self.0 .1.head {
216            // SAFETY: Safe because we have exclusive access to `Inner` and `head is set correctly.
217            unsafe { head.as_mut().remove(&mut self.0 .1) }
218        } else {
219            None
220        }
221    }
222
223    fn size_hint(&self) -> (usize, Option<usize>) {
224        (self.0 .1.count, Some(self.0 .1.count))
225    }
226}
227
228impl<T> ExactSizeIterator for Drainer<'_, '_, T> {
229    fn len(&self) -> usize {
230        self.0 .1.count
231    }
232}
233
// Unit tests; only compiled when testing on Fuchsia targets.
#[cfg(all(target_os = "fuchsia", test))]
mod tests {
    use super::{Condition, WakerEntry};
    use crate::TestExecutor;
    use futures::stream::FuturesUnordered;
    use futures::task::noop_waker;
    use futures::StreamExt;
    use std::pin::pin;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::task::Poll;

    // Registers `COUNT` waiters, drains and wakes them all at once, and checks that each waiter
    // is polled again and completes.
    #[test]
    fn test_condition_can_waker_multiple_wakers() {
        let mut executor = TestExecutor::new();
        let condition = Condition::new(());

        static COUNT: u64 = 10;

        // Counts calls to the `when` closures: the first `COUNT` calls return `Pending`, and every
        // later call returns `Ready`.
        let counter = AtomicU64::new(0);

        // Use FuturesUnordered so that futures are only polled when explicitly woken.
        let mut futures = FuturesUnordered::new();

        for _ in 0..COUNT {
            futures.push(condition.when(|()| {
                if counter.fetch_add(1, Ordering::Relaxed) >= COUNT {
                    Poll::Ready(())
                } else {
                    Poll::Pending
                }
            }));
        }

        // First round: every waiter is polled once, stays pending and registers its waker.
        assert!(executor.run_until_stalled(&mut futures.next()).is_pending());

        assert_eq!(counter.load(Ordering::Relaxed), COUNT);
        assert_eq!(condition.waker_count(), COUNT as usize);

        {
            let mut guard = condition.lock();
            let drainer = guard.drain_wakers();
            assert_eq!(drainer.len(), COUNT as usize);
            for waker in drainer {
                waker.wake();
            }
        }

        // Second round: each woken waiter is polled once more and now resolves.
        assert!(executor.run_until_stalled(&mut futures.collect::<Vec<_>>()).is_ready());
        assert_eq!(counter.load(Ordering::Relaxed), COUNT * 2);
    }

    // Checks that dropping a registered entry unlinks it, and that the list is still usable
    // afterwards.
    #[test]
    fn test_dropping_waker_entry_removes_from_list() {
        let condition = Condition::new(());

        let entry1 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry1, noop_waker());

        {
            let entry2 = pin!(WakerEntry::new());
            condition.lock().add_waker(entry2, noop_waker());

            assert_eq!(condition.waker_count(), 2);
        }

        // `entry2` has been dropped, which removed it from the list.
        assert_eq!(condition.waker_count(), 1);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 1);
        }

        assert_eq!(condition.waker_count(), 0);

        let entry3 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry3, noop_waker());

        assert_eq!(condition.waker_count(), 1);
    }

    // Checks that entries which were drained can be registered again.
    #[test]
    fn test_waker_can_be_added_multiple_times() {
        let condition = Condition::new(());

        let mut entry1 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry1.as_mut(), noop_waker());

        let mut entry2 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry2.as_mut(), noop_waker());

        assert_eq!(condition.waker_count(), 2);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);

        // Re-register the same (now unlinked) entries.
        condition.lock().add_waker(entry1, noop_waker());
        condition.lock().add_waker(entry2, noop_waker());

        assert_eq!(condition.waker_count(), 2);

        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);
    }
}