fuchsia_async/runtime/epoch.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Epoch-based deferred execution

use fuchsia_sync::{Mutex, MutexGuard};
use smallvec::SmallVec;
use std::collections::VecDeque;
use std::future::{Future, poll_fn};
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops::Deref;
use std::sync::LazyLock;
use std::task::{Poll, RawWakerVTable, Waker};

/// Epoch implements epoch-based deferred execution.
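///
/// # Example
///
/// A minimal usage sketch (illustrative only; it mirrors the tests at the bottom of this file):
///
/// ```ignore
/// let epoch = Epoch::default();
/// let guard = epoch.guard();
/// // The callback is deferred until every guard taken before this call has been dropped.
/// let cb = epoch.defer(|| println!("epoch closed"));
/// assert!(!cb.has_fired());
/// drop(guard); // last outstanding guard released; the callback runs here
/// assert!(cb.has_fired());
/// ```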
#[derive(Default)]
pub struct Epoch {
    inner: Mutex<Inner>,
}

#[derive(Default)]
struct Inner {
    // Contains a list of deferred callbacks intermixed with special entries which hold the
    // reference counts for each epoch.  The first entry, if there is one, should always be a
    // reference count.  Once it reaches zero, callbacks that follow are called until we encounter
    // another non-zero reference count entry.
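    //
    // For example (illustrative), after taking two guards and deferring two callbacks the queue
    // looks like:
    //
    //     [count = 2, cb, cb]
    //
    // and after dropping one of those guards, then taking a new guard and deferring a third
    // callback:
    //
    //     [count = 1, cb, cb, count = 1, cb]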
    callbacks: VecDeque<Callback>,

    // Every entry in the queue is given an increasing sequence number.  `sequence` here is the
    // sequence number assigned to the first entry in `callbacks`.
    sequence: usize,
}

// Callback is either a callback that has been added using `defer` or a guard count (if `vtable`
// is None).
struct Callback {
    data: Data,

    // We use RawWakerVTable so that we can easily make this work for wakers, but we only use the
    // `wake` and `drop` functions.
    vtable: Option<&'static RawWakerVTable>,
}

#[repr(C)]
union Data {
    // This is used if `vtable` is Some.
    data: *const (),

    // This is used as a reference count when `vtable` is None.
    count: usize,
}

// SAFETY: This is safe so long as the contract for RawWakerVTable is upheld.
unsafe impl Send for Data {}

/// An epoch guard. See `guard` below.
pub struct EpochGuard<'a> {
    epoch: &'a Epoch,

    // This is the sequence number of the entry in `callbacks` that has the count.
    sequence: usize,
}

/// A reference to a deferred callback.
pub struct CallbackRef<'a> {
    epoch: &'a Epoch,

    // This is the sequence number for the entry in `callbacks`.
    sequence: usize,
}

// PendingCallbacks must be dropped after the guard.
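// Fields of a tuple struct drop in declaration order, so the MutexGuard is released first and the
// pending callbacks then run outside the lock.  That allows a callback to re-enter the Epoch
// (e.g. by calling `defer` again) without deadlocking; see `test_callback_reentrancy` below.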
struct InnerGuard<'a>(MutexGuard<'a, Inner>, PendingCallbacks);

#[derive(Default)]
struct PendingCallbacks(SmallVec<[Callback; 4]>);

impl Epoch {
    /// Schedules `callback` to be executed when all prior guards have been released. If
    /// `callback` is no bigger than `usize`, typically no heap allocation will be incurred.  If
    /// there are no outstanding guards, `callback` will be called immediately.
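    ///
    /// Illustrative sketch of the immediate case (mirrors `test_defer_when_no_guards` below):
    ///
    /// ```ignore
    /// let epoch = Epoch::default();
    /// // No guards are outstanding, so the callback runs before `defer` returns.
    /// let cb = epoch.defer(|| println!("ran immediately"));
    /// assert!(cb.has_fired());
    /// ```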
    pub fn defer<F: FnOnce() + Send + Unpin + 'static>(&self, callback: F) -> CallbackRef<'_> {
        CallbackRef { epoch: self, sequence: self.inner_guard().defer(callback.into()) }
    }

    /// Same as `defer` but for a waker.
    pub fn defer_waker(&self, waker: &Waker) -> CallbackRef<'_> {
        CallbackRef { epoch: self, sequence: self.inner_guard().defer(waker.clone().into()) }
    }

    /// Takes a guard on the current epoch. Subsequent callbacks queued via `defer` are guaranteed
    /// not to be called before this guard is dropped.
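    ///
    /// Guards can be cloned; the epoch stays open until every clone has been dropped
    /// (illustrative sketch):
    ///
    /// ```ignore
    /// let epoch = Epoch::default();
    /// let guard = epoch.guard();
    /// let guard2 = guard.clone();
    /// let cb = epoch.defer(|| {});
    /// drop(guard);
    /// assert!(!cb.has_fired()); // `guard2` still holds the epoch open
    /// drop(guard2);
    /// assert!(cb.has_fired());
    /// ```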
    pub fn guard(&self) -> EpochGuard<'_> {
        EpochGuard { epoch: self, sequence: self.inner.lock().add_ref() }
    }

    /// Returns a reference to the global Epoch instance.
    pub fn global() -> &'static Epoch {
        static GLOBAL: LazyLock<Epoch> = LazyLock::new(Epoch::default);
        &GLOBAL
    }

    /// Waits for all previous guards to be released.  The barrier is initialized when the future is
    /// created, not when it is first polled.
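    ///
    /// Illustrative sketch of use from async code:
    ///
    /// ```ignore
    /// let epoch = Epoch::default();
    /// let guard = epoch.guard();
    /// let barrier = epoch.barrier(); // waits on the guards outstanding *now*
    /// drop(guard);
    /// barrier.await; // completes once the guard it was waiting on is gone
    /// ```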
    pub fn barrier(&self) -> impl Future<Output = ()> + '_ {
        let cb = self.defer_waker(Waker::noop());
        poll_fn(move |cx| {
            if cb.replace_waker(cx.waker()) { Poll::Pending } else { Poll::Ready(()) }
        })
    }

    fn inner_guard(&self) -> InnerGuard<'_> {
        InnerGuard(self.inner.lock(), PendingCallbacks::default())
    }
}

impl Inner {
    // Add a reference to the current epoch.
    fn add_ref(&mut self) -> usize {
        if let Some(count) = self.callbacks.back_mut().and_then(|cb| cb.count_mut()) {
            *count += 1;
        } else {
            self.callbacks.push_back(Callback { data: Data { count: 1 }, vtable: None });
        }
        self.sequence + self.callbacks.len() - 1
    }
}

impl InnerGuard<'_> {
    fn defer(&mut self, callback: Callback) -> usize {
        let InnerGuard(inner, pending_callbacks) = self;
        if inner.callbacks.front().is_none_or(|cb| cb.count().unwrap() == 0) {
            // There are no outstanding guards, so call the callback immediately.
            pending_callbacks.push(callback);

            // Return a sequence of 0, which `has_fired` below will always return true for.
            0
        } else {
            inner.callbacks.push_back(callback);
            inner.sequence + inner.callbacks.len() - 1
        }
    }

    // Decrement a reference to the epoch at `sequence`.
    fn sub_ref(&mut self, sequence: usize) {
        let InnerGuard(inner, pending_callbacks) = self;
        let index = sequence - inner.sequence;
        let count = inner.callbacks[index].count_mut().unwrap();
        *count -= 1;
        // We need to call the callbacks if the count has reached zero, and the count is at the
        // beginning of `callbacks` *and* there are actually callbacks queued.
        if *count == 0 && index == 0 && inner.callbacks.len() > 1 {
            while let Some(callback) = inner.callbacks.front() {
                if let Some(count) = callback.count() {
                    if count > 0 || inner.callbacks.len() == 1 {
                        // We've encountered a count element which is either non-zero or has no
                        // callbacks after it, so we're done.
                        break;
                    }
                    inner.callbacks.pop_front();
                } else {
                    pending_callbacks.push(inner.callbacks.pop_front().unwrap());
                }
                inner.sequence += 1;
            }
        }
    }
}

impl PendingCallbacks {
    fn push(&mut self, callback: Callback) {
        self.0.push(callback);
    }
}

impl Drop for PendingCallbacks {
    fn drop(&mut self) {
        for callback in self.0.drain(..) {
            callback.call();
        }
    }
}

impl Callback {
    fn new(data: *const (), vtable: &'static RawWakerVTable) -> Self {
        Self { data: Data { data }, vtable: Some(vtable) }
    }

    fn count(&self) -> Option<usize> {
        if self.vtable.is_none() {
            // SAFETY: If vtable is None, then it must be a count.
            Some(unsafe { self.data.count })
        } else {
            None
        }
    }

    fn count_mut(&mut self) -> Option<&mut usize> {
        if self.vtable.is_none() {
            // SAFETY: If vtable is None, then it must be a count.
            Some(unsafe { &mut self.data.count })
        } else {
            None
        }
    }

    fn call(mut self) {
        // SAFETY: This is safe so long as the contract for RawWakerVTable is upheld.
        unsafe {
            Waker::new(self.data.data, self.vtable.take().unwrap()).wake();
        }
    }
}

impl Drop for Callback {
    fn drop(&mut self) {
        if let Some(vtable) = self.vtable {
            // SAFETY: Safe so long as the contract for RawWakerVTable is upheld.
            drop(unsafe { Waker::new(self.data.data, vtable) });
        }
    }
}

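// Converts an `FnOnce` into a `Callback`.  Two storage strategies are used: if the closure is no
// bigger than a pointer, its bytes are stored inline in the `data` field (no allocation);
// otherwise the closure is boxed and the box's raw pointer is stored instead.  In both cases the
// matching vtable knows how to run or drop the stored closure exactly once.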
impl<F: FnOnce() + Send + Unpin + 'static> From<F> for Callback {
    fn from(value: F) -> Self {
        if std::mem::size_of::<F>() <= std::mem::size_of::<*const ()>() {
            struct InlineCallback<F>(PhantomData<F>);

            impl<F: FnOnce() + Send + Unpin + 'static> InlineCallback<F> {
                const VTABLE: RawWakerVTable = RawWakerVTable::new(
                    |_| unreachable!(),
                    Self::wake,
                    |_| unreachable!(),
                    Self::drop,
                );

                unsafe fn wake(data: *const ()) {
                    // SAFETY: We know `data` must be valid for size_of::<F>() bytes because we
                    // copied that many bytes below.
                    let callback = unsafe { std::mem::transmute_copy::<*const (), F>(&data) };
                    callback();
                }

                unsafe fn drop(data: *const ()) {
                    // SAFETY: We know `data` must be valid for size_of::<F>() bytes because we
                    // copied that many bytes below.
                    drop(unsafe { std::mem::transmute_copy::<*const (), F>(&data) });
                }
            }

            let mut data = std::ptr::null();
            let callback = ManuallyDrop::new(value);
            // SAFETY: We checked the size of `F` above.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    callback.deref() as *const F as *const u8,
                    &mut data as *mut _ as *mut u8,
                    std::mem::size_of::<F>(),
                );
            }
            Self::new(data, &InlineCallback::<F>::VTABLE)
        } else {
            struct BoxCallback<F>(PhantomData<F>);

            impl<F: FnOnce() + Send + Unpin + 'static> BoxCallback<F> {
                const VTABLE: RawWakerVTable = RawWakerVTable::new(
                    |_| unreachable!(),
                    Self::wake,
                    |_| unreachable!(),
                    Self::drop,
                );

                unsafe fn wake(data: *const ()) {
                    // SAFETY: This is just the reverse of what we do below.
                    let callback = unsafe { Box::from_raw(data as *mut F) };
                    callback();
                }

                unsafe fn drop(data: *const ()) {
                    // SAFETY: This is just the reverse of what we do below.
                    drop(unsafe { Box::from_raw(data as *mut F) });
                }
            }
            Callback::new(Box::into_raw(Box::new(value)) as *const (), &BoxCallback::<F>::VTABLE)
        }
    }
}

impl From<Waker> for Callback {
    fn from(waker: Waker) -> Self {
        // We consume the waker.
        let waker = ManuallyDrop::new(waker);
        Callback::new(waker.data(), waker.vtable())
    }
}

impl Clone for EpochGuard<'_> {
    fn clone(&self) -> Self {
        let mut inner = self.epoch.inner.lock();
        let index = self.sequence - inner.sequence;
        *inner.callbacks[index].count_mut().unwrap() += 1;
        Self { epoch: self.epoch, sequence: self.sequence }
    }
}

impl Drop for EpochGuard<'_> {
    fn drop(&mut self) {
        self.epoch.inner_guard().sub_ref(self.sequence);
    }
}

impl CallbackRef<'_> {
    /// Returns true if the callback has fired.
    pub fn has_fired(&self) -> bool {
        // We use <= because the first entry in callbacks should always be a reference count, and we
        // use sequence 0 when a callback is immediately called.
        self.sequence <= self.epoch.inner.lock().sequence
    }

    /// Replaces the callback with a different callback. Returns `true` if successful, or `false` if
    /// the existing callback has already been called.
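    ///
    /// Illustrative sketch (mirrors `test_replace` below):
    ///
    /// ```ignore
    /// let epoch = Epoch::default();
    /// let guard = epoch.guard();
    /// let cb = epoch.defer(|| println!("original"));
    /// assert!(cb.replace(|| println!("replacement"))); // only the replacement will run
    /// drop(guard);
    /// assert!(!cb.replace(|| {})); // too late; the callback has already fired
    /// ```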
    #[must_use]
    pub fn replace<F: FnOnce() + Send + Unpin + 'static>(&self, callback: F) -> bool {
        let mut inner = self.epoch.inner.lock();
        if self.sequence <= inner.sequence {
            return false;
        }
        let index = self.sequence - inner.sequence;
        inner.callbacks[index] = callback.into();
        true
    }

    /// Same as `replace` but for a waker.
    #[must_use]
    pub fn replace_waker(&self, waker: &Waker) -> bool {
        let mut inner = self.epoch.inner.lock();
        if self.sequence <= inner.sequence {
            return false;
        }
        let index = self.sequence - inner.sequence;
        inner.callbacks[index] = waker.clone().into();
        true
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use futures::stream::{FuturesUnordered, StreamExt};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
    use std::{iter, thread};

    #[test]
    fn test_defer() {
        let epoch = Epoch::default();
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let guard = epoch.guard();
        let _cb = epoch.defer(move || called_clone.store(true, Ordering::Relaxed));
        assert!(!called.load(Ordering::Relaxed));
        drop(guard);
        assert!(called.load(Ordering::Relaxed));
    }

    #[test]
    fn test_defer_large_callback() {
        let epoch = Epoch::default();
        let large_data = [0u8; 1024];
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let guard = epoch.guard();
        let _cb = epoch.defer(move || {
            assert_eq!(large_data.len(), 1024);
            called_clone.store(true, Ordering::Relaxed);
        });
        assert!(!called.load(Ordering::Relaxed));
        drop(guard);
        assert!(called.load(Ordering::Relaxed));
    }

    #[test]
    fn test_defer_small_callback() {
        let epoch = Epoch::default();
        epoch.defer(|| {});
        let b = 13u8;
        let callback = move || assert_eq!(b, 13);
        epoch.defer(callback);
    }

    #[test]
    fn test_defer_when_no_guards() {
        let epoch = Epoch::default();
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let cb = epoch.defer(move || called_clone.store(true, Ordering::Relaxed));
        assert!(called.load(Ordering::Relaxed));
        assert!(cb.has_fired());
        assert!(!cb.replace(|| {}));
        assert!(!cb.replace_waker(Waker::noop()));
    }

    #[test]
    fn test_multiple_guards() {
        let epoch = Epoch::default();
        let guard1 = epoch.guard();
        let guard2 = epoch.guard();
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let _cb = epoch.defer(move || called_clone.store(true, Ordering::Relaxed));
        assert!(!called.load(Ordering::Relaxed));
        drop(guard1);
        assert!(!called.load(Ordering::Relaxed));
        drop(guard2);
        assert!(called.load(Ordering::Relaxed));
    }

    #[test]
    fn test_multiple_guards_in_different_epoch() {
        let epoch = Epoch::default();
        let guard1 = epoch.guard();
        let called1 = Arc::new(AtomicBool::new(false));
        let called1_clone = called1.clone();
        let _cb = epoch.defer(move || called1_clone.store(true, Ordering::Relaxed));
        let guard2 = epoch.guard();
        let called2 = Arc::new(AtomicBool::new(false));
        let called2_clone = called2.clone();
        let _cb = epoch.defer(move || called2_clone.store(true, Ordering::Relaxed));
        assert!(!called1.load(Ordering::Relaxed));
        assert!(!called2.load(Ordering::Relaxed));
        drop(guard1);
        assert!(called1.load(Ordering::Relaxed));
        assert!(!called2.load(Ordering::Relaxed));
        drop(guard2);
        assert!(called2.load(Ordering::Relaxed));
    }

    #[test]
    fn test_multiple_guards_in_different_epoch_reverse_order() {
        let epoch = Epoch::default();
        let guard1 = epoch.guard();
        let called1 = Arc::new(AtomicBool::new(false));
        let called1_clone = called1.clone();
        let _cb = epoch.defer(move || called1_clone.store(true, Ordering::Relaxed));
        let guard2 = epoch.guard();
        let called2 = Arc::new(AtomicBool::new(false));
        let called2_clone = called2.clone();
        let _cb = epoch.defer(move || called2_clone.store(true, Ordering::Relaxed));
        assert!(!called1.load(Ordering::Relaxed));
        assert!(!called2.load(Ordering::Relaxed));
        // Drop guard2 first.
        drop(guard2);
        assert!(!called1.load(Ordering::Relaxed));
        assert!(!called2.load(Ordering::Relaxed));
        drop(guard1);
        assert!(called1.load(Ordering::Relaxed));
        assert!(called2.load(Ordering::Relaxed));
    }

    #[test]
    fn test_barrier() {
        let epoch = Epoch::default();
        let guard = epoch.guard();
        let barrier_future = epoch.barrier();
        // Use `FuturesUnordered` because it uses its own wakers and so this will check that
        // the waker is actually called.
        let mut barrier_future: FuturesUnordered<_> = iter::once(barrier_future).collect();
        let mut cx = std::task::Context::from_waker(Waker::noop());
        assert!(barrier_future.poll_next_unpin(&mut cx).is_pending());
        drop(guard);
        assert!(barrier_future.poll_next_unpin(&mut cx).is_ready());
    }

    #[test]
    fn test_has_fired() {
        let epoch = Epoch::default();
        let guard = epoch.guard();
        let cb = epoch.defer(|| {});
        assert!(!cb.has_fired());
        drop(guard);
        assert!(cb.has_fired());
    }

    #[test]
    fn test_replace() {
        let epoch = Epoch::default();
        let called1 = Arc::new(AtomicBool::new(false));
        let called2 = Arc::new(AtomicBool::new(false));
        let called1_clone = called1.clone();
        let called2_clone = called2.clone();
        let guard = epoch.guard();
        let cb = epoch.defer(move || called1_clone.store(true, Ordering::Relaxed));
        assert!(cb.replace(move || called2_clone.store(true, Ordering::Relaxed)));
        drop(guard);
        assert!(!called1.load(Ordering::Relaxed));
        assert!(called2.load(Ordering::Relaxed));
        assert!(!cb.replace(|| {}));
    }

    #[test]
    fn test_barrier_race() {
        let epoch = Epoch::default();
        thread::scope(|s| {
            s.spawn(|| {
                for _ in 0..1000 {
                    let _guard = epoch.guard();
                }
            });
            s.spawn(|| {
                for _ in 0..1000 {
                    let _guard = epoch.guard();
                }
            });
            s.spawn(|| {
                for _ in 0..1000 {
                    futures::executor::block_on(async {
                        epoch.barrier().await;
                    });
                }
            });
        });
    }

    #[test]
    fn test_callback_reentrancy() {
        let epoch = Arc::new(Epoch::default());
        let counter = Arc::new(AtomicU8::new(0));

        let epoch_clone = epoch.clone();
        let counter_clone = counter.clone();
        epoch.defer(move || {
            epoch_clone.defer(move || {
                counter_clone.fetch_add(1, Ordering::Relaxed);
            });
        });

        assert_eq!(counter.load(Ordering::Relaxed), 1);

        let guard = epoch.guard();
        let epoch_clone = epoch.clone();
        let counter_clone = counter.clone();
        epoch.defer(move || {
            epoch_clone.defer(move || {
                counter_clone.fetch_add(1, Ordering::Relaxed);
            });
        });

        assert_eq!(counter.load(Ordering::Relaxed), 1);

        drop(guard);

        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }
}