fuchsia_async/
condition.rs1use fuchsia_sync::{Mutex, MutexGuard};
22use std::future::poll_fn;
23use std::marker::PhantomPinned;
24use std::ops::{Deref, DerefMut};
25use std::pin::{Pin, pin};
26use std::ptr::NonNull;
27use std::sync::Arc;
28use std::task::{Poll, Waker};
29
30#[derive(Default)]
35pub struct Condition<T>(Arc<Mutex<Inner<T>>>);
36
37impl<T> Condition<T> {
38 pub fn new(data: T) -> Self {
40 Self(Arc::new(Mutex::new(Inner { head: None, count: 0, data })))
41 }
42
43 pub fn waker_count(&self) -> usize {
45 self.0.lock().count
46 }
47
48 pub fn lock(&self) -> ConditionGuard<'_, T> {
50 ConditionGuard(self.0.lock())
51 }
52
53 pub async fn when<R>(&self, poll: impl Fn(&mut T) -> Poll<R>) -> R {
55 let mut entry = pin!(self.waker_entry());
56 poll_fn(|cx| {
57 let mut guard = self.0.lock();
58 let entry = unsafe { entry.as_mut().get_unchecked_mut() };
60 let result = poll(&mut guard.data);
61 if result.is_pending() {
62 unsafe {
64 entry.node.add(&mut *guard, cx.waker().clone());
65 }
66 }
67 result
68 })
69 .await
70 }
71
72 pub fn waker_entry(&self) -> WakerEntry<T> {
74 WakerEntry {
75 list: self.0.clone(),
76 node: Node { next: None, prev: None, waker: None, _pinned: PhantomPinned },
77 }
78 }
79}
80
81#[derive(Default)]
82struct Inner<T> {
83 head: Option<NonNull<Node>>,
84 count: usize,
85 data: T,
86}
87
88unsafe impl<T: Send> Send for Inner<T> {}
90
91pub struct ConditionGuard<'a, T>(MutexGuard<'a, Inner<T>>);
93
94impl<'a, T> ConditionGuard<'a, T> {
95 pub fn add_waker(&mut self, waker_entry: Pin<&mut WakerEntry<T>>, waker: Waker) {
97 let waker_entry = unsafe { waker_entry.get_unchecked_mut() };
99 unsafe {
101 waker_entry.node.add(&mut *self.0, waker);
102 }
103 }
104
105 pub fn drain_wakers<'b>(&'b mut self) -> Drainer<'b, 'a, T> {
110 Drainer(self)
111 }
112
113 pub fn waker_count(&self) -> usize {
115 self.0.count
116 }
117}
118
119impl<T> Deref for ConditionGuard<'_, T> {
120 type Target = T;
121
122 fn deref(&self) -> &Self::Target {
123 &self.0.data
124 }
125}
126
127impl<T> DerefMut for ConditionGuard<'_, T> {
128 fn deref_mut(&mut self) -> &mut Self::Target {
129 &mut self.0.data
130 }
131}
132
133pub struct WakerEntry<T> {
135 list: Arc<Mutex<Inner<T>>>,
136 node: Node,
137}
138
139impl<T> Drop for WakerEntry<T> {
140 fn drop(&mut self) {
141 self.node.remove(&mut *self.list.lock());
142 }
143}
144
/// One element of a condition's intrusive, doubly linked waker list.
///
/// A node is linked into a list exactly when `waker` is `Some`.
struct Node {
    /// Following node in the list.
    next: Option<NonNull<Self>>,
    /// Preceding node in the list; `None` for the head.
    prev: Option<NonNull<Self>>,
    /// The registered waker, or `None` while the node is unlinked.
    waker: Option<Waker>,
    /// Neighbours refer to this node by raw pointer, so it must not move once linked.
    _pinned: PhantomPinned,
}

// SAFETY: the link pointers are only followed while the owning condition's mutex is held,
// and `Waker` is itself `Send`.
unsafe impl Send for Node {}
155
impl Node {
    /// Links this node at the head of `inner`'s waker list and stores `waker` in it.
    ///
    /// If the node is already linked (its `waker` is `Some`), only the stored waker is
    /// replaced; the list links and `inner.count` are left unchanged.
    ///
    /// # Safety
    ///
    /// - `self` must not move while it is linked, because `inner.head` and neighbouring
    ///   nodes keep raw pointers to it.
    /// - `inner` must be the state of the condition this node's `WakerEntry` was created
    ///   from, because `WakerEntry::drop` unlinks the node from that condition's list.
    /// - Every node already linked into `inner` must still be alive.
    unsafe fn add<T>(&mut self, inner: &mut Inner<T>, waker: Waker) {
        if self.waker.is_none() {
            // Not linked yet: push this node onto the front of the list.
            self.prev = None;
            self.next = inner.head;
            inner.head = Some(self.into());
            if let Some(mut next) = self.next {
                // SAFETY: `next` was the previous head, which the caller guarantees is a
                // live, pinned node of this list.
                unsafe {
                    next.as_mut().prev = Some(self.into());
                }
            }
            inner.count += 1;
        }
        self.waker = Some(waker);
    }

    /// Unlinks this node from `inner`'s waker list and returns its waker.
    ///
    /// Returns `None` without touching the list if the node is not linked.
    ///
    /// NOTE(review): although this is a safe fn, it dereferences the neighbour pointers. It
    /// relies on `inner` being the list this node was linked into and on those neighbours
    /// still being alive.
    fn remove<T>(&mut self, inner: &mut Inner<T>) -> Option<Waker> {
        if self.waker.is_none() {
            // An unlinked node never keeps neighbour pointers.
            debug_assert!(self.prev.is_none() && self.next.is_none());
            return None;
        }
        if let Some(mut next) = self.next {
            // SAFETY: a linked neighbour is a live, pinned node (see the NOTE above).
            unsafe { next.as_mut().prev = self.prev };
        }
        if let Some(mut prev) = self.prev {
            // SAFETY: a linked neighbour is a live, pinned node (see the NOTE above).
            unsafe { prev.as_mut().next = self.next };
        } else {
            // No predecessor means this node was the head.
            inner.head = self.next;
        }
        self.prev = None;
        self.next = None;
        inner.count -= 1;
        self.waker.take()
    }
}
198
199pub struct Drainer<'a, 'b, T>(&'a mut ConditionGuard<'b, T>);
201
202impl<T> Iterator for Drainer<'_, '_, T> {
203 type Item = Waker;
204 fn next(&mut self) -> Option<Self::Item> {
205 if let Some(mut head) = self.0.0.head {
206 unsafe { head.as_mut().remove(&mut self.0.0) }
208 } else {
209 None
210 }
211 }
212
213 fn size_hint(&self) -> (usize, Option<usize>) {
214 (self.0.0.count, Some(self.0.0.count))
215 }
216}
217
218impl<T> ExactSizeIterator for Drainer<'_, '_, T> {
219 fn len(&self) -> usize {
220 self.0.0.count
221 }
222}
223
#[cfg(all(target_os = "fuchsia", test))]
mod tests {
    use super::Condition;
    use crate::TestExecutor;
    use futures::StreamExt;
    use futures::stream::FuturesUnordered;
    use futures::task::noop_waker;
    use std::pin::pin;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::task::Poll;

    #[test]
    fn test_condition_can_wake_multiple_wakers() {
        let mut executor = TestExecutor::new();
        let condition = Condition::new(());

        const COUNT: u64 = 10;

        let counter = AtomicU64::new(0);

        // The first `COUNT` polls (one per waiter) return `Pending`; every later poll
        // returns `Ready`.
        let mut futures = FuturesUnordered::new();

        for _ in 0..COUNT {
            futures.push(condition.when(|()| {
                if counter.fetch_add(1, Ordering::Relaxed) >= COUNT {
                    Poll::Ready(())
                } else {
                    Poll::Pending
                }
            }));
        }

        assert!(executor.run_until_stalled(&mut futures.next()).is_pending());

        assert_eq!(counter.load(Ordering::Relaxed), COUNT);
        assert_eq!(condition.waker_count(), COUNT as usize);

        {
            let mut guard = condition.lock();
            let drainer = guard.drain_wakers();
            assert_eq!(drainer.len(), COUNT as usize);
            for waker in drainer {
                waker.wake();
            }
        }

        assert!(executor.run_until_stalled(&mut futures.collect::<Vec<_>>()).is_ready());
        assert_eq!(counter.load(Ordering::Relaxed), COUNT * 2);
    }

    #[test]
    fn test_dropping_waker_entry_removes_from_list() {
        let condition = Condition::new(());

        let entry1 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry1, noop_waker());

        {
            let entry2 = pin!(condition.waker_entry());
            condition.lock().add_waker(entry2, noop_waker());

            assert_eq!(condition.waker_count(), 2);
        }

        assert_eq!(condition.waker_count(), 1);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 1);
        }

        assert_eq!(condition.waker_count(), 0);

        let entry3 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry3, noop_waker());

        assert_eq!(condition.waker_count(), 1);
    }

    #[test]
    fn test_waker_can_be_added_multiple_times() {
        let condition = Condition::new(());

        let mut entry1 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry1.as_mut(), noop_waker());

        let mut entry2 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry2.as_mut(), noop_waker());

        assert_eq!(condition.waker_count(), 2);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);

        condition.lock().add_waker(entry1, noop_waker());
        condition.lock().add_waker(entry2, noop_waker());

        assert_eq!(condition.waker_count(), 2);

        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);
    }
}