fuchsia_async/
condition.rs
1use std::future::poll_fn;
22use std::marker::PhantomPinned;
23use std::ops::{Deref, DerefMut};
24use std::pin::{pin, Pin};
25use std::ptr::NonNull;
26use std::sync::{Arc, Mutex, MutexGuard};
27use std::task::{Poll, Waker};
28
29pub struct Condition<T>(Arc<Mutex<Inner<T>>>);
34
35impl<T> Condition<T> {
36 pub fn new(data: T) -> Self {
38 Self(Arc::new(Mutex::new(Inner { head: None, count: 0, data })))
39 }
40
41 pub fn waker_count(&self) -> usize {
43 self.0.lock().unwrap().count
44 }
45
46 pub fn lock(&self) -> ConditionGuard<'_, T> {
48 ConditionGuard(&self.0, self.0.lock().unwrap())
49 }
50
51 pub async fn when<R>(&self, poll: impl Fn(&mut T) -> Poll<R>) -> R {
53 let mut entry = WakerEntry::new();
54 entry.list = Some(self.0.clone());
55 let mut entry = pin!(entry);
56 poll_fn(|cx| {
57 let mut guard = self.0.lock().unwrap();
58 let entry = unsafe { entry.as_mut().get_unchecked_mut() };
60 let result = poll(&mut guard.data);
61 if result.is_pending() {
62 unsafe {
64 entry.node.add(&mut *guard, cx.waker().clone());
65 }
66 }
67 result
68 })
69 .await
70 }
71}
72
73struct Inner<T> {
74 head: Option<NonNull<Node>>,
75 count: usize,
76 data: T,
77}
78
79unsafe impl<T: Send> Send for Inner<T> {}
81
82pub struct ConditionGuard<'a, T>(&'a Arc<Mutex<Inner<T>>>, MutexGuard<'a, Inner<T>>);
84
85impl<'a, T> ConditionGuard<'a, T> {
86 pub fn add_waker(&mut self, waker_entry: Pin<&mut WakerEntry<T>>, waker: Waker) {
88 let waker_entry = unsafe { waker_entry.get_unchecked_mut() };
90 waker_entry.list = Some(self.0.clone());
91 unsafe {
93 waker_entry.node.add(&mut *self.1, waker);
94 }
95 }
96
97 pub fn drain_wakers<'b>(&'b mut self) -> Drainer<'b, 'a, T> {
102 Drainer(self)
103 }
104
105 pub fn waker_count(&self) -> usize {
107 self.1.count
108 }
109}
110
111impl<T> Deref for ConditionGuard<'_, T> {
112 type Target = T;
113
114 fn deref(&self) -> &Self::Target {
115 &self.1.data
116 }
117}
118
119impl<T> DerefMut for ConditionGuard<'_, T> {
120 fn deref_mut(&mut self) -> &mut Self::Target {
121 &mut self.1.data
122 }
123}
124
125pub struct WakerEntry<T> {
127 list: Option<Arc<Mutex<Inner<T>>>>,
128 node: Node,
129}
130
131impl<T> WakerEntry<T> {
132 pub fn new() -> Self {
134 Self {
135 list: None,
136 node: Node { next: None, prev: None, waker: None, _pinned: PhantomPinned },
137 }
138 }
139}
140
141impl<T> Drop for WakerEntry<T> {
142 fn drop(&mut self) {
143 if let Some(list) = &self.list {
144 self.node.remove(&mut *list.lock().unwrap());
145 }
146 }
147}
148
149struct Node {
151 next: Option<NonNull<Node>>,
152 prev: Option<NonNull<Node>>,
153 waker: Option<Waker>,
154 _pinned: PhantomPinned,
155}
156
157unsafe impl Send for Node {}
159
160impl Node {
161 unsafe fn add<T>(&mut self, inner: &mut Inner<T>, waker: Waker) {
165 if self.waker.is_none() {
166 self.prev = None;
167 self.next = inner.head;
168 inner.head = Some(self.into());
169 if let Some(mut next) = self.next {
170 unsafe {
173 next.as_mut().prev = Some(self.into());
174 }
175 }
176 inner.count += 1;
177 }
178 self.waker = Some(waker);
179 }
180
181 fn remove<T>(&mut self, inner: &mut Inner<T>) -> Option<Waker> {
182 if self.waker.is_none() {
183 debug_assert!(self.prev.is_none() && self.next.is_none());
184 return None;
185 }
186 if let Some(mut next) = self.next {
187 unsafe { next.as_mut().prev = self.prev };
189 }
190 if let Some(mut prev) = self.prev {
191 unsafe { prev.as_mut().next = self.next };
193 } else {
194 inner.head = self.next;
195 }
196 self.prev = None;
197 self.next = None;
198 inner.count -= 1;
199 self.waker.take()
200 }
201}
202
203pub struct Drainer<'a, 'b, T>(&'a mut ConditionGuard<'b, T>);
205
206impl<T> Iterator for Drainer<'_, '_, T> {
207 type Item = Waker;
208 fn next(&mut self) -> Option<Self::Item> {
209 if let Some(mut head) = self.0 .1.head {
210 unsafe { head.as_mut().remove(&mut self.0 .1) }
212 } else {
213 None
214 }
215 }
216
217 fn size_hint(&self) -> (usize, Option<usize>) {
218 (self.0 .1.count, Some(self.0 .1.count))
219 }
220}
221
222impl<T> ExactSizeIterator for Drainer<'_, '_, T> {
223 fn len(&self) -> usize {
224 self.0 .1.count
225 }
226}
227
#[cfg(all(target_os = "fuchsia", test))]
mod tests {
    use super::{Condition, WakerEntry};
    use crate::TestExecutor;
    use futures::stream::FuturesUnordered;
    use futures::task::noop_waker;
    use futures::StreamExt;
    use std::pin::pin;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::task::Poll;

    /// Every pending `when` future registers a waker. Draining and waking them all lets each
    /// future re-evaluate its predicate and complete.
    #[test]
    fn test_condition_can_wake_multiple_wakers() {
        let mut executor = TestExecutor::new();
        let condition = Condition::new(());

        const COUNT: u64 = 10;

        let counter = AtomicU64::new(0);

        let mut futures = FuturesUnordered::new();

        // The shared counter makes the first COUNT predicate evaluations (one per future on
        // the first poll) return pending; every later evaluation returns ready.
        for _ in 0..COUNT {
            futures.push(condition.when(|()| {
                if counter.fetch_add(1, Ordering::Relaxed) >= COUNT {
                    Poll::Ready(())
                } else {
                    Poll::Pending
                }
            }));
        }

        assert!(executor.run_until_stalled(&mut futures.next()).is_pending());

        assert_eq!(counter.load(Ordering::Relaxed), COUNT);
        assert_eq!(condition.waker_count(), COUNT as usize);

        {
            let mut guard = condition.lock();
            let drainer = guard.drain_wakers();
            assert_eq!(drainer.len(), COUNT as usize);
            for waker in drainer {
                waker.wake();
            }
        }

        assert!(executor.run_until_stalled(&mut futures.collect::<Vec<_>>()).is_ready());
        assert_eq!(counter.load(Ordering::Relaxed), COUNT * 2);
    }

    /// Dropping a registered entry unlinks it, and the list stays usable afterwards.
    #[test]
    fn test_dropping_waker_entry_removes_from_list() {
        let condition = Condition::new(());

        let entry1 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry1, noop_waker());

        {
            let entry2 = pin!(WakerEntry::new());
            condition.lock().add_waker(entry2, noop_waker());

            assert_eq!(condition.waker_count(), 2);
        }

        assert_eq!(condition.waker_count(), 1);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 1);
        }

        assert_eq!(condition.waker_count(), 0);

        let entry3 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry3, noop_waker());

        assert_eq!(condition.waker_count(), 1);
    }

    /// A drained entry can be registered again.
    #[test]
    fn test_waker_can_be_added_multiple_times() {
        let condition = Condition::new(());

        let mut entry1 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry1.as_mut(), noop_waker());

        let mut entry2 = pin!(WakerEntry::new());
        condition.lock().add_waker(entry2.as_mut(), noop_waker());

        assert_eq!(condition.waker_count(), 2);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);

        condition.lock().add_waker(entry1, noop_waker());
        condition.lock().add_waker(entry2, noop_waker());

        assert_eq!(condition.waker_count(), 2);

        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);
    }
}