fuchsia_async/runtime/epoch.rs

use fuchsia_sync::{Mutex, MutexGuard};
use smallvec::SmallVec;
use std::collections::VecDeque;
use std::future::{Future, poll_fn};
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops::Deref;
use std::sync::LazyLock;
use std::task::{Poll, RawWakerVTable, Waker};
/// Tracks outstanding guards and runs deferred callbacks once every guard
/// taken in or before a callback's epoch has been dropped.
#[derive(Default)]
pub struct Epoch {
    inner: Mutex<Inner>,
}

#[derive(Default)]
struct Inner {
    /// Entries ordered by epoch; each entry is either a guard reference count
    /// or a deferred callback.
    callbacks: VecDeque<Callback>,

    /// Sequence number of the entry at the front of `callbacks`.
    sequence: usize,
}

/// A queue entry: a deferred callback when `vtable` is `Some`, otherwise a
/// guard reference count.
struct Callback {
    data: Data,

    vtable: Option<&'static RawWakerVTable>,
}

#[repr(C)]
union Data {
    /// The data pointer handed to `vtable` (callback entries).
    data: *const (),

    /// The number of outstanding guards (reference-count entries).
    count: usize,
}

// SAFETY: callback data pointers come from `Send` closures or wakers, and
// counts are plain integers.
unsafe impl Send for Data {}

/// Keeps the current epoch open whilst it is alive.
pub struct EpochGuard<'a> {
    epoch: &'a Epoch,

    sequence: usize,
}

/// A handle to a deferred callback that can be queried or replaced.
pub struct CallbackRef<'a> {
    epoch: &'a Epoch,

    sequence: usize,
}

/// Holds the inner lock together with callbacks that have become runnable; the
/// callbacks run when `PendingCallbacks` is dropped, after the lock has been
/// released.
struct InnerGuard<'a>(MutexGuard<'a, Inner>, PendingCallbacks);

#[derive(Default)]
struct PendingCallbacks(SmallVec<[Callback; 4]>);
impl Epoch {
    /// Defers `callback` until every guard outstanding at this point has been
    /// dropped. If there are no outstanding guards, the callback runs
    /// immediately.
    pub fn defer<F: FnOnce() + Send + Unpin + 'static>(&self, callback: F) -> CallbackRef<'_> {
        CallbackRef { epoch: self, sequence: self.inner_guard().defer(callback.into()) }
    }

    /// Defers waking `waker` until every guard outstanding at this point has
    /// been dropped.
    pub fn defer_waker(&self, waker: &Waker) -> CallbackRef<'_> {
        CallbackRef { epoch: self, sequence: self.inner_guard().defer(waker.clone().into()) }
    }

    /// Returns a guard that keeps the current epoch open.
    pub fn guard(&self) -> EpochGuard<'_> {
        EpochGuard { epoch: self, sequence: self.inner.lock().add_ref() }
    }

    /// Returns the process-wide epoch.
    pub fn global() -> &'static Epoch {
        static GLOBAL: LazyLock<Epoch> = LazyLock::new(Epoch::default);
        &GLOBAL
    }

    /// Returns a future that resolves once every guard outstanding at the time
    /// of the call has been dropped.
    pub fn barrier(&self) -> impl Future<Output = ()> + '_ {
        let cb = self.defer_waker(Waker::noop());
        poll_fn(move |cx| {
            if cb.replace_waker(cx.waker()) { Poll::Pending } else { Poll::Ready(()) }
        })
    }

    fn inner_guard(&self) -> InnerGuard<'_> {
        InnerGuard(self.inner.lock(), PendingCallbacks::default())
    }
}
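// A minimal usage sketch (not part of the original file): `example_usage` is a
// hypothetical helper showing how `guard`, `defer`, and `barrier` compose.
// Callbacks deferred while a guard is held run only once that guard (and any
// earlier guards) have been dropped.
#[cfg(test)]
#[allow(dead_code)]
async fn example_usage(epoch: &Epoch) {
    // Hold the current epoch open.
    let guard = epoch.guard();
    // Queued now, so it cannot run until `guard` is dropped.
    let _cb = epoch.defer(|| println!("epoch retired"));
    // Dropping the last guard of the epoch fires the deferred callback.
    drop(guard);
    // Wait until every guard outstanding at this point has been dropped.
    epoch.barrier().await;
}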
impl Inner {
    /// Bumps the reference count for the current epoch, starting a new epoch
    /// if the newest entry is not a count, and returns the guard's sequence
    /// number.
    fn add_ref(&mut self) -> usize {
        if let Some(count) = self.callbacks.back_mut().and_then(|cb| cb.count_mut()) {
            *count += 1;
        } else {
            self.callbacks.push_back(Callback { data: Data { count: 1 }, vtable: None });
        }
        self.sequence + self.callbacks.len() - 1
    }
}

impl InnerGuard<'_> {
    /// Queues `callback` behind the current epoch, or schedules it to run as
    /// soon as the lock is released if no guards are outstanding. Returns the
    /// callback's sequence number.
    fn defer(&mut self, callback: Callback) -> usize {
        let InnerGuard(inner, pending_callbacks) = self;
        if inner.callbacks.front().is_none_or(|cb| cb.count().unwrap() == 0) {
            pending_callbacks.push(callback);

            // No guards are outstanding: the callback runs once the lock is
            // released, and sequence 0 always reads as already fired.
            0
        } else {
            inner.callbacks.push_back(callback);
            inner.sequence + inner.callbacks.len() - 1
        }
    }

    /// Drops one reference from the epoch at `sequence`. If that retires the
    /// front epoch, pops any callbacks that have become runnable so they fire
    /// once the lock is released.
    fn sub_ref(&mut self, sequence: usize) {
        let InnerGuard(inner, pending_callbacks) = self;
        let index = sequence - inner.sequence;
        let count = inner.callbacks[index].count_mut().unwrap();
        *count -= 1;
        if *count == 0 && index == 0 && inner.callbacks.len() > 1 {
            while let Some(callback) = inner.callbacks.front() {
                if let Some(count) = callback.count() {
                    if count > 0 || inner.callbacks.len() == 1 {
                        break;
                    }
                    inner.callbacks.pop_front();
                } else {
                    pending_callbacks.push(inner.callbacks.pop_front().unwrap());
                }
                inner.sequence += 1;
            }
        }
    }
}
impl PendingCallbacks {
    fn push(&mut self, callback: Callback) {
        self.0.push(callback);
    }
}

impl Drop for PendingCallbacks {
    fn drop(&mut self) {
        // Runs after the inner lock has been released, so callbacks are free
        // to re-enter the `Epoch`.
        for callback in self.0.drain(..) {
            callback.call();
        }
    }
}

impl Callback {
    fn new(data: *const (), vtable: &'static RawWakerVTable) -> Self {
        Self { data: Data { data }, vtable: Some(vtable) }
    }

    /// Returns the guard count if this entry is a reference count rather than
    /// a callback.
    fn count(&self) -> Option<usize> {
        if self.vtable.is_none() {
            Some(unsafe { self.data.count })
        } else {
            None
        }
    }

    fn count_mut(&mut self) -> Option<&mut usize> {
        if self.vtable.is_none() {
            Some(unsafe { &mut self.data.count })
        } else {
            None
        }
    }

    fn call(mut self) {
        // Taking the vtable ensures `Drop for Callback` does not release the
        // data a second time.
        unsafe {
            Waker::new(self.data.data, self.vtable.take().unwrap()).wake();
        }
    }
}

impl Drop for Callback {
    fn drop(&mut self) {
        if let Some(vtable) = self.vtable {
            // Dropping the reconstructed waker releases the callback's data.
            drop(unsafe { Waker::new(self.data.data, vtable) });
        }
    }
}
impl<F: FnOnce() + Send + Unpin + 'static> From<F> for Callback {
    fn from(value: F) -> Self {
        if std::mem::size_of::<F>() <= std::mem::size_of::<*const ()>() {
            // The closure fits in a pointer, so its bytes are stored directly
            // in the data word and no allocation is needed.
            struct InlineCallback<F>(PhantomData<F>);

            impl<F: FnOnce() + Send + Unpin + 'static> InlineCallback<F> {
                const VTABLE: RawWakerVTable = RawWakerVTable::new(
                    // The clone and wake_by_ref slots are never reached for
                    // deferred callbacks.
                    |_| unreachable!(),
                    Self::wake,
                    |_| unreachable!(),
                    Self::drop,
                );

                unsafe fn wake(data: *const ()) {
                    let callback = unsafe { std::mem::transmute_copy::<*const (), F>(&data) };
                    callback();
                }

                unsafe fn drop(data: *const ()) {
                    drop(unsafe { std::mem::transmute_copy::<*const (), F>(&data) });
                }
            }

            let mut data = std::ptr::null();
            let callback = ManuallyDrop::new(value);
            unsafe {
                std::ptr::copy_nonoverlapping(
                    callback.deref() as *const F as *const u8,
                    &mut data as *mut _ as *mut u8,
                    std::mem::size_of::<F>(),
                );
            }
            Self::new(data, &InlineCallback::<F>::VTABLE)
        } else {
            // The closure is larger than a pointer, so it is boxed and the raw
            // pointer is stored in the data word.
            struct BoxCallback<F>(PhantomData<F>);

            impl<F: FnOnce() + Send + Unpin + 'static> BoxCallback<F> {
                const VTABLE: RawWakerVTable = RawWakerVTable::new(
                    |_| unreachable!(),
                    Self::wake,
                    |_| unreachable!(),
                    Self::drop,
                );

                unsafe fn wake(data: *const ()) {
                    let callback = unsafe { Box::from_raw(data as *mut F) };
                    callback();
                }

                unsafe fn drop(data: *const ()) {
                    drop(unsafe { Box::from_raw(data as *mut F) });
                }
            }
            Callback::new(Box::into_raw(Box::new(value)) as *const (), &BoxCallback::<F>::VTABLE)
        }
    }
}
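// An illustrative sketch (not part of the original file): which branch of the
// `From<F>` impl above a closure takes depends only on the size of its
// captures relative to a pointer. `representation_examples` is a hypothetical
// helper making that size check explicit.
#[cfg(test)]
#[allow(dead_code)]
fn representation_examples() {
    let small = 7u8;
    let large = [0u8; 64];
    // Captures a single byte: fits in the pointer-sized slot, stored inline.
    assert!(std::mem::size_of_val(&move || drop(small)) <= std::mem::size_of::<*const ()>());
    // Captures 64 bytes: larger than a pointer, so the closure is boxed.
    assert!(std::mem::size_of_val(&move || drop(large)) > std::mem::size_of::<*const ()>());
}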
impl From<Waker> for Callback {
    fn from(waker: Waker) -> Self {
        // ManuallyDrop transfers ownership of the waker's data pointer to the
        // callback entry without running the waker's destructor here.
        let waker = ManuallyDrop::new(waker);
        Callback::new(waker.data(), waker.vtable())
    }
}

impl Clone for EpochGuard<'_> {
    fn clone(&self) -> Self {
        let mut inner = self.epoch.inner.lock();
        let index = self.sequence - inner.sequence;
        *inner.callbacks[index].count_mut().unwrap() += 1;
        Self { epoch: self.epoch, sequence: self.sequence }
    }
}

impl Drop for EpochGuard<'_> {
    fn drop(&mut self) {
        self.epoch.inner_guard().sub_ref(self.sequence);
    }
}

impl CallbackRef<'_> {
    /// Returns true if the deferred callback has already fired.
    pub fn has_fired(&self) -> bool {
        self.sequence <= self.epoch.inner.lock().sequence
    }

    /// Replaces the pending callback with `callback`. Returns false if the
    /// original callback has already fired, in which case `callback` is
    /// dropped without running.
    #[must_use]
    pub fn replace<F: FnOnce() + Send + Unpin + 'static>(&self, callback: F) -> bool {
        let mut inner = self.epoch.inner.lock();
        if self.sequence <= inner.sequence {
            return false;
        }
        let index = self.sequence - inner.sequence;
        inner.callbacks[index] = callback.into();
        true
    }

    /// Replaces the pending callback with one that wakes `waker`. Returns
    /// false if the original callback has already fired.
    #[must_use]
    pub fn replace_waker(&self, waker: &Waker) -> bool {
        let mut inner = self.epoch.inner.lock();
        if self.sequence <= inner.sequence {
            return false;
        }
        let index = self.sequence - inner.sequence;
        inner.callbacks[index] = waker.clone().into();
        true
    }
}
349
350#[cfg(test)]
351mod test {
352 use super::*;
353 use futures::stream::{FuturesUnordered, StreamExt};
354 use std::sync::Arc;
355 use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
356 use std::{iter, thread};
357
358 #[test]
359 fn test_defer() {
360 let epoch = Epoch::default();
361 let called = Arc::new(AtomicBool::new(false));
362 let called_clone = called.clone();
363 let guard = epoch.guard();
364 let _cb = epoch.defer(move || called_clone.store(true, Ordering::Relaxed));
365 assert!(!called.load(Ordering::Relaxed));
366 drop(guard);
367 assert!(called.load(Ordering::Relaxed));
368 }
369
370 #[test]
371 fn test_defer_large_callback() {
372 let epoch = Epoch::default();
373 let large_data = [0u8; 1024];
374 let called = Arc::new(AtomicBool::new(false));
375 let called_clone = called.clone();
376 let guard = epoch.guard();
377 let _cb = epoch.defer(move || {
378 assert_eq!(large_data.len(), 1024);
379 called_clone.store(true, Ordering::Relaxed);
380 });
381 assert!(!called.load(Ordering::Relaxed));
382 drop(guard);
383 assert!(called.load(Ordering::Relaxed));
384 }
385
386 #[test]
387 fn test_defer_small_callback() {
388 let epoch = Epoch::default();
389 epoch.defer(|| {});
390 let b = 13u8;
391 let callback = move || assert_eq!(b, 13);
392 epoch.defer(callback);
393 }
394
395 #[test]
396 fn test_defer_when_no_guards() {
397 let epoch = Epoch::default();
398 let called = Arc::new(AtomicBool::new(false));
399 let called_clone = called.clone();
400 let cb = epoch.defer(move || called_clone.store(true, Ordering::Relaxed));
401 assert!(called.load(Ordering::Relaxed));
402 assert!(cb.has_fired());
403 assert!(!cb.replace(|| {}));
404 assert!(!cb.replace_waker(Waker::noop()));
405 }
406
407 #[test]
408 fn test_multiple_guards() {
409 let epoch = Epoch::default();
410 let guard1 = epoch.guard();
411 let guard2 = epoch.guard();
412 let called = Arc::new(AtomicBool::new(false));
413 let called_clone = called.clone();
414 let _cb = epoch.defer(move || called_clone.store(true, Ordering::Relaxed));
415 assert!(!called.load(Ordering::Relaxed));
416 drop(guard1);
417 assert!(!called.load(Ordering::Relaxed));
418 drop(guard2);
419 assert!(called.load(Ordering::Relaxed));
420 }
421
422 #[test]
423 fn test_multiple_guards_in_different_epoch() {
424 let epoch = Epoch::default();
425 let guard1 = epoch.guard();
426 let called1 = Arc::new(AtomicBool::new(false));
427 let called1_clone = called1.clone();
428 let _cb = epoch.defer(move || called1_clone.store(true, Ordering::Relaxed));
429 let guard2 = epoch.guard();
430 let called2 = Arc::new(AtomicBool::new(false));
431 let called2_clone = called2.clone();
432 let _cb = epoch.defer(move || called2_clone.store(true, Ordering::Relaxed));
433 assert!(!called1.load(Ordering::Relaxed));
434 assert!(!called2.load(Ordering::Relaxed));
435 drop(guard1);
436 assert!(called1.load(Ordering::Relaxed));
437 assert!(!called2.load(Ordering::Relaxed));
438 drop(guard2);
439 assert!(called2.load(Ordering::Relaxed));
440 }
441
442 #[test]
443 fn test_multiple_guards_in_different_epoch_reverse_order() {
444 let epoch = Epoch::default();
445 let guard1 = epoch.guard();
446 let called1 = Arc::new(AtomicBool::new(false));
447 let called1_clone = called1.clone();
448 let _cb = epoch.defer(move || called1_clone.store(true, Ordering::Relaxed));
449 let guard2 = epoch.guard();
450 let called2 = Arc::new(AtomicBool::new(false));
451 let called2_clone = called2.clone();
452 let _cb = epoch.defer(move || called2_clone.store(true, Ordering::Relaxed));
453 assert!(!called1.load(Ordering::Relaxed));
454 assert!(!called2.load(Ordering::Relaxed));
455 drop(guard2);
457 assert!(!called1.load(Ordering::Relaxed));
458 assert!(!called2.load(Ordering::Relaxed));
459 drop(guard1);
460 assert!(called1.load(Ordering::Relaxed));
461 assert!(called2.load(Ordering::Relaxed));
462 }
463
464 #[test]
465 fn test_barrier() {
466 let epoch = Epoch::default();
467 let guard = epoch.guard();
468 let barrier_future = epoch.barrier();
469 let mut barrier_future: FuturesUnordered<_> = iter::once(barrier_future).collect();
472 let mut cx = std::task::Context::from_waker(Waker::noop());
473 assert!(barrier_future.poll_next_unpin(&mut cx).is_pending());
474 drop(guard);
475 assert!(barrier_future.poll_next_unpin(&mut cx).is_ready());
476 }
477
478 #[test]
479 fn test_has_fired() {
480 let epoch = Epoch::default();
481 let guard = epoch.guard();
482 let cb = epoch.defer(|| {});
483 assert!(!cb.has_fired());
484 drop(guard);
485 assert!(cb.has_fired());
486 }
487
488 #[test]
489 fn test_replace() {
490 let epoch = Epoch::default();
491 let called1 = Arc::new(AtomicBool::new(false));
492 let called2 = Arc::new(AtomicBool::new(false));
493 let called1_clone = called1.clone();
494 let called2_clone = called2.clone();
495 let guard = epoch.guard();
496 let cb = epoch.defer(move || called1_clone.store(true, Ordering::Relaxed));
497 assert!(cb.replace(move || called2_clone.store(true, Ordering::Relaxed)));
498 drop(guard);
499 assert!(!called1.load(Ordering::Relaxed));
500 assert!(called2.load(Ordering::Relaxed));
501 assert!(!cb.replace(|| {}));
502 }
503
504 #[test]
505 fn test_barrier_race() {
506 let epoch = Epoch::default();
507 thread::scope(|s| {
508 s.spawn(|| {
509 for _ in 0..1000 {
510 let _guard = epoch.guard();
511 }
512 });
513 s.spawn(|| {
514 for _ in 0..1000 {
515 let _guard = epoch.guard();
516 }
517 });
518 s.spawn(|| {
519 for _ in 0..1000 {
520 futures::executor::block_on(async {
521 epoch.barrier().await;
522 });
523 }
524 });
525 });
526 }
527
528 #[test]
529 fn test_callback_reentrancy() {
530 let epoch = Arc::new(Epoch::default());
531 let counter = Arc::new(AtomicU8::new(0));
532
533 let epoch_clone = epoch.clone();
534 let counter_clone = counter.clone();
535 epoch.defer(move || {
536 epoch_clone.defer(move || {
537 counter_clone.fetch_add(1, Ordering::Relaxed);
538 });
539 });
540
541 assert_eq!(counter.load(Ordering::Relaxed), 1);
542
543 let guard = epoch.guard();
544 let epoch_clone = epoch.clone();
545 let counter_clone = counter.clone();
546 epoch.defer(move || {
547 epoch_clone.defer(move || {
548 counter_clone.fetch_add(1, Ordering::Relaxed);
549 });
550 });
551
552 assert_eq!(counter.load(Ordering::Relaxed), 1);
553
554 drop(guard);
555
556 assert_eq!(counter.load(Ordering::Relaxed), 2);
557 }
558}