fuchsia_async/
condition.rs1use fuchsia_sync::{Mutex, MutexGuard};
22use std::future::poll_fn;
23use std::marker::PhantomPinned;
24use std::ops::{Deref, DerefMut};
25use std::pin::{Pin, pin};
26use std::ptr::NonNull;
27use std::sync::Arc;
28use std::task::{Poll, Waker};
29
30#[derive(Default)]
35pub struct Condition<T>(Arc<Mutex<Inner<T>>>);
36
37impl<T> Condition<T> {
38 pub fn new(data: T) -> Self {
40 Self(Arc::new(Mutex::new(Inner { head: None, count: 0, data })))
41 }
42
43 pub fn waker_count(&self) -> usize {
45 self.0.lock().count
46 }
47
48 pub fn lock(&self) -> ConditionGuard<'_, T> {
50 ConditionGuard(self.0.lock())
51 }
52
53 pub async fn when<R>(&self, poll: impl Fn(&mut T) -> Poll<R>) -> R {
55 let mut entry = pin!(self.waker_entry());
56 poll_fn(|cx| {
57 let mut guard = self.0.lock();
58 let entry = unsafe { entry.as_mut().get_unchecked_mut() };
60 let result = poll(&mut guard.data);
61 if result.is_pending() {
62 unsafe {
64 entry.node.add(&mut *guard, cx.waker().clone());
65 }
66 }
67 result
68 })
69 .await
70 }
71
72 pub fn waker_entry(&self) -> WakerEntry<T> {
74 WakerEntry {
75 list: self.0.clone(),
76 node: Node { next: None, prev: None, waker: None, _pinned: PhantomPinned },
77 }
78 }
79}
80
/// State shared by a [`Condition`] and all of its [`WakerEntry`]s, always accessed through the
/// `Mutex` that wraps it.
#[derive(Default)]
struct Inner<T> {
    // Most recently linked node of the intrusive waker list, or `None` when the list is empty.
    head: Option<NonNull<Node>>,
    // Number of nodes currently linked into the list.
    count: usize,
    // The user data exposed through `ConditionGuard`'s `Deref`/`DerefMut`.
    data: T,
}

// SAFETY: the `NonNull` in `head` makes `Inner` `!Send` by default. The node pointers reachable
// from `head` are only dereferenced by `Node::add` / `Node::remove`, and every caller of those
// holds the mutex wrapping this `Inner`. The nodes themselves live in pinned `WakerEntry`s that
// unlink themselves (under that same mutex) before being dropped. Sending `Inner` is therefore
// as safe as sending `T`, hence the `T: Send` bound.
unsafe impl<T: Send> Send for Inner<T> {}
90
91pub struct ConditionGuard<'a, T>(MutexGuard<'a, Inner<T>>);
93
94impl<'a, T> ConditionGuard<'a, T> {
95 pub fn add_waker(&mut self, waker_entry: Pin<&mut WakerEntry<T>>, waker: Waker) {
101 assert!(
103 waker_entry.list.data_ptr() == &mut *self.0,
104 "Cannot add waker to different Condition"
105 );
106 let waker_entry = unsafe { waker_entry.get_unchecked_mut() };
108 unsafe {
110 waker_entry.node.add(&mut *self.0, waker);
111 }
112 }
113
114 pub fn drain_wakers<'b>(&'b mut self) -> Drainer<'b, 'a, T> {
119 Drainer(self)
120 }
121
122 pub fn waker_count(&self) -> usize {
124 self.0.count
125 }
126}
127
128impl<T> Deref for ConditionGuard<'_, T> {
129 type Target = T;
130
131 fn deref(&self) -> &Self::Target {
132 &self.0.data
133 }
134}
135
136impl<T> DerefMut for ConditionGuard<'_, T> {
137 fn deref_mut(&mut self) -> &mut Self::Target {
138 &mut self.0.data
139 }
140}
141
142pub struct WakerEntry<T> {
144 list: Arc<Mutex<Inner<T>>>,
145 node: Node,
146}
147
148impl<T> Drop for WakerEntry<T> {
149 fn drop(&mut self) {
150 self.node.remove(&mut *self.list.lock());
151 }
152}
153
/// A node of the intrusive, doubly-linked waker list rooted at `Inner::head`.
///
/// Nodes live inside pinned `WakerEntry`s and the list stores raw pointers to them, which is why
/// the node is `!Unpin` (`_pinned`). A node is linked into the list exactly when `waker` is
/// `Some`: `add` links it and stores the waker, `remove` unlinks it and takes the waker.
struct Node {
    // Next (older, i.e. linked earlier) node in the list, if any.
    next: Option<NonNull<Node>>,
    // Previous (newer) node in the list; `None` when this node is the head or is unlinked.
    prev: Option<NonNull<Node>>,
    // Waker handed out when the list is drained; `Some` iff the node is linked.
    waker: Option<Waker>,
    _pinned: PhantomPinned,
}

// SAFETY: `NonNull` makes `Node` `!Send` by default. The neighbour pointers are only dereferenced
// inside `Node::add` / `Node::remove`, and every caller of those holds the lock of the `Inner`
// whose list the node is linked into. The remaining field is a `Waker`, which is `Send`.
unsafe impl Send for Node {}

impl Node {
    /// Stores `waker` on this node and, if the node is not linked yet, pushes it onto the front
    /// of `inner`'s list and increments `inner.count`.
    ///
    /// If the node is already linked, only the stored waker is replaced: the node keeps its
    /// position and `count` is unchanged.
    ///
    /// # Safety
    ///
    /// - `inner` must be the `Inner` this node's owning `WakerEntry` refers to (its `list`), and
    ///   the caller must hold that mutex for the duration of the call.
    /// - `self` must be pinned and stay at its current address until it is unlinked again by
    ///   `remove` (which `WakerEntry`'s `Drop` does).
    unsafe fn add<T>(&mut self, inner: &mut Inner<T>, waker: Waker) {
        if self.waker.is_none() {
            // Not linked yet: push onto the front of the list.
            self.prev = None;
            self.next = inner.head;
            inner.head = Some(self.into());
            if let Some(mut next) = self.next {
                // SAFETY: `next` was the list head, so it is a live, pinned node of this same
                // list, and the caller holds the list's lock (see `# Safety`).
                unsafe {
                    next.as_mut().prev = Some(self.into());
                }
            }
            inner.count += 1;
        }
        // Storing a waker is also what marks the node as linked (see `remove`).
        self.waker = Some(waker);
    }

    /// Unlinks this node from `inner`'s list, decrements `inner.count` and returns the stored
    /// waker. Returns `None` without touching the list if the node is not linked.
    ///
    /// NOTE(review): although this is a safe `fn`, it dereferences the raw neighbour pointers and
    /// so relies on the same invariants as `add`: `inner` must be the list the node is linked
    /// into, and its lock must be held. Both callers (`WakerEntry::drop`, `Drainer::next`) hold
    /// the lock; only the head case is cross-checked, by a `debug_assert`.
    fn remove<T>(&mut self, inner: &mut Inner<T>) -> Option<Waker> {
        if self.waker.is_none() {
            // An unlinked node must not carry stale neighbour pointers.
            debug_assert!(self.prev.is_none() && self.next.is_none());
            return None;
        }
        if let Some(mut next) = self.next {
            // SAFETY: neighbours of a linked node are live, pinned nodes of the same list, and
            // the caller holds that list's lock.
            unsafe { next.as_mut().prev = self.prev };
        }
        if let Some(mut prev) = self.prev {
            // SAFETY: as above.
            unsafe { prev.as_mut().next = self.next };
        } else {
            // No predecessor means this node is the head of the list.
            debug_assert_eq!(inner.head, Some(self.into()));
            inner.head = self.next;
        }
        self.prev = None;
        self.next = None;
        inner.count -= 1;
        self.waker.take()
    }
}
208
209pub struct Drainer<'a, 'b, T>(&'a mut ConditionGuard<'b, T>);
211
212impl<T> Iterator for Drainer<'_, '_, T> {
213 type Item = Waker;
214 fn next(&mut self) -> Option<Self::Item> {
215 if let Some(mut head) = self.0.0.head {
216 unsafe { head.as_mut().remove(&mut self.0.0) }
218 } else {
219 None
220 }
221 }
222
223 fn size_hint(&self) -> (usize, Option<usize>) {
224 (self.0.0.count, Some(self.0.0.count))
225 }
226}
227
228impl<T> ExactSizeIterator for Drainer<'_, '_, T> {
229 fn len(&self) -> usize {
230 self.0.0.count
231 }
232}
233
#[cfg(all(target_os = "fuchsia", test))]
mod tests {
    use super::Condition;
    use crate::TestExecutor;
    use futures::StreamExt;
    use futures::stream::FuturesUnordered;
    use std::pin::pin;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::task::{Poll, Waker};

    // NOTE(review): "can_waker" is presumably meant to read "can_wake"; the name is kept so that
    // existing test filters/expectations keep matching.
    #[test]
    fn test_condition_can_waker_multiple_wakers() {
        let mut executor = TestExecutor::new();
        let condition = Condition::new(());

        const COUNT: u64 = 10;

        // Every predicate evaluation bumps the counter. The first round of polls sees values
        // below `COUNT` (pending); after the wake-up, the second round sees values >= `COUNT`.
        let counter = AtomicU64::new(0);

        let mut futures = FuturesUnordered::new();

        for _ in 0..COUNT {
            futures.push(condition.when(|()| {
                if counter.fetch_add(1, Ordering::Relaxed) >= COUNT {
                    Poll::Ready(())
                } else {
                    Poll::Pending
                }
            }));
        }

        assert!(executor.run_until_stalled(&mut futures.next()).is_pending());

        assert_eq!(counter.load(Ordering::Relaxed), COUNT);
        assert_eq!(condition.waker_count(), COUNT as usize);

        {
            let mut guard = condition.lock();
            let drainer = guard.drain_wakers();
            assert_eq!(drainer.len(), COUNT as usize);
            for waker in drainer {
                waker.wake();
            }
        }

        assert!(executor.run_until_stalled(&mut futures.collect::<Vec<_>>()).is_ready());
        assert_eq!(counter.load(Ordering::Relaxed), COUNT * 2);
    }

    #[test]
    fn test_when_ready_immediately_registers_no_waker() {
        let mut executor = TestExecutor::new();
        let condition = Condition::new(7u32);

        let mut fut = pin!(condition.when(|data| Poll::Ready(*data)));
        assert_eq!(executor.run_until_stalled(&mut fut), Poll::Ready(7));
        assert_eq!(condition.waker_count(), 0);
    }

    #[test]
    fn test_when_observes_data_updated_under_lock() {
        let mut executor = TestExecutor::new();
        let condition = Condition::new(false);

        let mut fut = pin!(condition.when(|ready| if *ready { Poll::Ready(()) } else { Poll::Pending }));
        assert!(executor.run_until_stalled(&mut fut).is_pending());
        assert_eq!(condition.waker_count(), 1);

        {
            let mut guard = condition.lock();
            *guard = true;
            for waker in guard.drain_wakers() {
                waker.wake();
            }
        }

        assert!(executor.run_until_stalled(&mut fut).is_ready());
        assert_eq!(condition.waker_count(), 0);
    }

    #[test]
    fn test_dropping_waker_entry_removes_from_list() {
        let condition = Condition::new(());

        let entry1 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry1, Waker::noop().clone());

        {
            let entry2 = pin!(condition.waker_entry());
            condition.lock().add_waker(entry2, Waker::noop().clone());

            assert_eq!(condition.waker_count(), 2);
        }

        assert_eq!(condition.waker_count(), 1);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 1);
        }

        assert_eq!(condition.waker_count(), 0);

        let entry3 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry3, Waker::noop().clone());

        assert_eq!(condition.waker_count(), 1);
    }

    #[test]
    fn test_waker_can_be_added_multiple_times() {
        let condition = Condition::new(());

        let mut entry1 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry1.as_mut(), Waker::noop().clone());

        let mut entry2 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry2.as_mut(), Waker::noop().clone());

        assert_eq!(condition.waker_count(), 2);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);

        condition.lock().add_waker(entry1, Waker::noop().clone());
        condition.lock().add_waker(entry2, Waker::noop().clone());

        assert_eq!(condition.waker_count(), 2);

        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);
    }

    #[test]
    #[should_panic]
    fn test_adding_waker_to_different_condition() {
        let condition1 = Condition::new(());
        let condition2 = Condition::new(());

        let entry2 = pin!(condition2.waker_entry());

        let mut guard = condition1.lock();
        guard.add_waker(entry2, Waker::noop().clone());
    }
}