fuchsia_rcu/state_machine.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use crate::atomic_stack::AtomicStack;
use fuchsia_sync::Mutex;
use std::cell::Cell;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
use std::thread_local;

type RcuCallback = Box<dyn FnOnce() + Send + Sync + 'static>;

/// The length of the queue of waiting callbacks.
///
/// The state machine waits for this many generations to complete before running these callbacks.
const QUEUE_LENGTH: usize = 2;

/// The queue of waiting callbacks.
///
/// The queue is a ring buffer with `QUEUE_LENGTH` slots, each holding a set of callbacks.
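///
/// A minimal sketch of the intended flow (illustrative only; `callbacks_for_generation_0` is a
/// placeholder `Vec<RcuCallback>`, and the real driver of this queue is `rcu_grace_period`):
///
/// ```ignore
/// let mut queue = CallbackQueue::new();
/// // Grace period for generation 0: freshly drained callbacks park in slot 0.
/// queue.enqueue(0, callbacks_for_generation_0);
/// // Nothing is ready yet; slot 1 (the slot the next generation will occupy) is still empty.
/// assert!(queue.take_ready(0).is_empty());
/// // Grace period for generation 1: the callbacks parked during generation 0 become ready.
/// let ready = queue.take_ready(1);
/// ```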
struct CallbackQueue {
    /// The callbacks that are waiting to be run.
    ///
    /// The callbacks are stored in a ring buffer.
    callbacks: [Vec<RcuCallback>; QUEUE_LENGTH],
}

impl CallbackQueue {
    /// Create an empty callback queue.
    const fn new() -> Self {
        Self { callbacks: [Vec::new(), Vec::new()] }
    }

    /// Add a set of callbacks to the queue for the given generation.
    ///
    /// # Panics
    ///
    /// Panics if the slot for the given generation is already occupied.
    fn enqueue(&mut self, generation: usize, callbacks: Vec<RcuCallback>) {
        let index = generation % QUEUE_LENGTH;
        assert!(
            self.callbacks[index].is_empty(),
            "Queue slot for generation {} is already occupied",
            generation
        );
        self.callbacks[index] = callbacks;
    }

    /// Pops the set of callbacks that are ready to be run after the given generation completes.
    fn take_ready(&mut self, generation: usize) -> Vec<RcuCallback> {
        // We take the callbacks that have reached the end of the queue, which is the same as the
        // slot in the queue that the next generation will occupy.
        let index = (generation + 1) % QUEUE_LENGTH;
        std::mem::take(&mut self.callbacks[index])
    }
}

struct RcuControlBlock {
    /// The generation counter.
    ///
    /// The generation counter is incremented whenever the state machine leaves the `Idle` state.
    generation: AtomicUsize,

    /// The read counters.
    ///
    /// Readers increment the counter for the generation that they are reading from. For example,
    /// if the `generation` is even, readers increment `read_counters[0]`; if the `generation` is
    /// odd, readers increment `read_counters[1]`.
    read_counters: [AtomicUsize; 2],

    /// The chain of callbacks that are waiting to be run.
    ///
    /// Writers add callbacks to this chain after writing to the object. The callbacks are run when
    /// all currently in-flight read operations have completed.
    callback_chain: AtomicStack<RcuCallback>,

    /// The futex used to wait for the state machine to advance.
    advancer: zx::Futex,

    /// The queue of waiting callbacks.
    ///
    /// Callbacks are added to this queue when the state machine leaves the `Idle` state. They are
    /// run when the state machine leaves the `Waiting` state after `QUEUE_LENGTH` generations
    /// have completed.
    waiting_callbacks: Mutex<CallbackQueue>,
}

const ADVANCER_IDLE: i32 = 0;
const ADVANCER_WAITING: i32 = 1;

impl RcuControlBlock {
    /// Create a new control block for the RCU state machine.
    const fn new() -> Self {
        Self {
            generation: AtomicUsize::new(0),
            read_counters: [AtomicUsize::new(0), AtomicUsize::new(0)],
            callback_chain: AtomicStack::new(),
            advancer: zx::Futex::new(ADVANCER_IDLE),
            waiting_callbacks: Mutex::new(CallbackQueue::new()),
        }
    }
}

/// The control block for the RCU state machine.
static RCU_CONTROL_BLOCK: RcuControlBlock = RcuControlBlock::new();

#[derive(Default)]
struct RcuThreadBlock {
    /// The number of times the thread has nested into a read lock.
    nesting_level: Cell<usize>,

    /// The index of the read counter that the thread incremented when it entered its outermost read
    /// lock.
    counter_index: Cell<u8>,

    /// Whether this thread has scheduled callbacks since the last time the thread called
    /// `rcu_synchronize`.
    has_pending_callbacks: Cell<bool>,
}

impl RcuThreadBlock {
    /// Returns true if the thread is holding a read lock.
    fn holding_read_lock(&self) -> bool {
        self.nesting_level.get() > 0
    }
}

thread_local! {
    /// Thread-specific data for the RCU state machine.
    ///
    /// This data is used to track the nesting level of read locks and the index of the read counter
    /// that the thread incremented when it entered its outermost read lock.
    static RCU_THREAD_BLOCK: RcuThreadBlock = RcuThreadBlock::default();
}

/// Acquire a read lock.
///
/// This function is used to acquire a read lock on the RCU state machine. The RCU state machine
/// defers calling callbacks until all currently in-flight read operations have completed.
///
/// Must be balanced by a call to `rcu_read_unlock` on the same thread.
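///
/// A minimal usage sketch (illustrative; read locks may be nested, but every lock must be
/// balanced by an unlock on the same thread):
///
/// ```ignore
/// rcu_read_lock();
/// rcu_read_lock(); // Nested: only bumps the thread-local nesting level.
/// rcu_read_unlock();
/// rcu_read_unlock(); // Outermost unlock: decrements the shared read counter.
/// ```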
pub(crate) fn rcu_read_lock() {
    RCU_THREAD_BLOCK.with(|block| {
        let nesting_level = block.nesting_level.get();
        if nesting_level > 0 {
            // If this thread already has a read lock, increment the nesting level instead of the
            // read counter. This approach is a performance optimization to reduce the number of
            // atomic operations that need to be performed.
            block.nesting_level.set(nesting_level + 1);
        } else {
            // This is the outermost read lock. Increment the read counter.
            let index = RCU_CONTROL_BLOCK.generation.load(Ordering::Relaxed) & 1;
            // Synchronization point [A] (see design.md)
            RCU_CONTROL_BLOCK.read_counters[index].fetch_add(1, Ordering::SeqCst);
            block.counter_index.set(index as u8);
            block.nesting_level.set(1);
        }
    });
}

/// Release a read lock.
///
/// This function is used to release a read lock on the RCU state machine. See `rcu_read_lock` for
/// more details.
pub(crate) fn rcu_read_unlock() {
    RCU_THREAD_BLOCK.with(|block| {
        let nesting_level = block.nesting_level.get();
        if nesting_level > 1 {
            // If the nesting level is greater than 1, this is not the outermost read lock.
            // Decrement the nesting level instead of the read counter.
            block.nesting_level.set(nesting_level - 1);
        } else {
            // This is the outermost read lock. Decrement the read counter.
            let index = block.counter_index.get() as usize;
            // Synchronization point [B] (see design.md)
            let previous_count =
                RCU_CONTROL_BLOCK.read_counters[index].fetch_sub(1, Ordering::SeqCst);
            if previous_count == 1 {
                rcu_advancer_wake_all();
            }
            block.nesting_level.set(0);
            block.counter_index.set(u8::MAX);
        }
    });
}

/// Read the value of an RCU pointer.
///
/// This function cannot be called unless the current thread is holding a read lock. The returned
/// pointer is valid until the read lock is released.
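///
/// A minimal usage sketch (illustrative; `DATA` is a hypothetical `AtomicPtr<u32>` published
/// elsewhere with `rcu_assign_pointer`):
///
/// ```ignore
/// rcu_read_lock();
/// let ptr = rcu_read_pointer(&DATA);
/// if !ptr.is_null() {
///     // Safe to dereference until the matching `rcu_read_unlock`.
///     let value = unsafe { *ptr };
/// }
/// rcu_read_unlock();
/// ```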
pub(crate) fn rcu_read_pointer<T>(ptr: &AtomicPtr<T>) -> *const T {
    // Synchronization point [D] (see design.md)
    ptr.load(Ordering::Acquire)
}

/// Assign a new value to an RCU pointer.
///
/// Concurrent readers may continue to reference the old value of the pointer until the RCU state
/// machine has made sufficient progress. To clean up the old value of the pointer, use `rcu_call`
/// or `rcu_drop`, which defer processing until all in-flight read operations have completed.
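///
/// A minimal usage sketch (illustrative; `DATA` is a hypothetical `AtomicPtr<u32>`):
///
/// ```ignore
/// let new_value = Box::into_raw(Box::new(42u32));
/// rcu_assign_pointer(&DATA, new_value);
/// // Readers that start after this point observe `new_value`; readers already in flight may
/// // still observe the previous pointer.
/// ```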
pub(crate) fn rcu_assign_pointer<T>(ptr: &AtomicPtr<T>, new_ptr: *mut T) {
    // Synchronization point [E] (see design.md)
    ptr.store(new_ptr, Ordering::Release);
}

/// Replace the value of an RCU pointer.
///
/// Concurrent readers may continue to reference the old value of the pointer until the RCU state
/// machine has made sufficient progress. To clean up the old value of the pointer, use `rcu_call`
/// or `rcu_drop`, which defer processing until all in-flight read operations have completed.
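///
/// A minimal update sketch (illustrative; `DATA` is a hypothetical `AtomicPtr<u32>` whose
/// contents are always allocated with `Box`):
///
/// ```ignore
/// let old = rcu_replace_pointer(&DATA, Box::into_raw(Box::new(43u32)));
/// if !old.is_null() {
///     // Defer reclaiming the old allocation until in-flight readers have finished.
///     rcu_drop(unsafe { Box::from_raw(old) });
/// }
/// ```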
pub(crate) fn rcu_replace_pointer<T>(ptr: &AtomicPtr<T>, new_ptr: *mut T) -> *mut T {
    // Synchronization point [F] (see design.md)
    ptr.swap(new_ptr, Ordering::AcqRel)
}

/// Schedule a callback to run after all in-flight read operations have completed.
///
/// To wait until the callback is run, call `rcu_synchronize()`. The callback might be called from
/// an arbitrary thread.
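///
/// A minimal usage sketch (illustrative):
///
/// ```ignore
/// rcu_call(|| println!("all readers that were in flight at scheduling time have finished"));
/// rcu_synchronize(); // Optional: block until the callback above has run.
/// ```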
pub(crate) fn rcu_call(callback: impl FnOnce() + Send + Sync + 'static) {
    RCU_THREAD_BLOCK.with(|block| {
        block.has_pending_callbacks.set(true);
    });

    // Even though we push the callback onto the front of the stack, `rcu_grace_period` reverses
    // the callbacks when it drains the stack, so they run in the order in which they were
    // scheduled.
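    //
    // For example (illustrative): if a thread schedules callbacks A and then B, the chain holds
    // [B, A] after the pushes, and `rcu_grace_period` runs A before B.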

    // Synchronization point [G] (see design.md)
    RCU_CONTROL_BLOCK.callback_chain.push_front(Box::new(callback));
}

/// Schedule the object to be dropped after all in-flight read operations have completed.
///
/// To wait until the object is dropped, call `rcu_synchronize()`. The object might be dropped from
/// an arbitrary thread.
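///
/// A minimal usage sketch (illustrative; `old_entry` is a hypothetical heap-allocated value that
/// was just unlinked from an RCU-protected structure):
///
/// ```ignore
/// rcu_drop(old_entry); // Dropped once all in-flight readers have finished.
/// ```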
pub fn rcu_drop<T: Send + Sync + 'static>(value: T) {
    rcu_call(move || {
        std::mem::drop(value);
    });
}

/// Check if there are any active readers for the given generation.
fn has_active_readers(generation: usize) -> bool {
    let i = generation & 1;
    // Synchronization point [C] (see design.md)
    RCU_CONTROL_BLOCK.read_counters[i].load(Ordering::SeqCst) > 0
}

/// Wake up all the threads that are waiting to advance the state machine.
///
/// Does nothing if no threads are waiting.
fn rcu_advancer_wake_all() {
    let advancer = &RCU_CONTROL_BLOCK.advancer;
    if advancer.load(Ordering::SeqCst) == ADVANCER_WAITING {
        advancer.store(ADVANCER_IDLE, Ordering::Relaxed);
        advancer.wake_all();
    }
}

/// Blocks the current thread until all in-flight read operations have completed for the given
/// generation.
///
/// Postcondition: The number of active readers for the given generation is zero and the advancer
/// futex contains `ADVANCER_IDLE`.
fn rcu_advancer_wait(generation: usize) {
    let advancer = &RCU_CONTROL_BLOCK.advancer;
    loop {
        // In order to avoid a race with `rcu_advancer_wake_all`, we must store `ADVANCER_WAITING`
        // before checking if there are any active readers.
        //
        // In the single total order, either this store or the last decrement to the reader counter
        // must happen first.
        //
        //  (1) If this store happens first, then the last thread to decrement the reader counter
        //      for this generation will observe `ADVANCER_WAITING` and will reset the value to
        //      `ADVANCER_IDLE` and wake the futex, unblocking this thread.
        //
        //  (2) If the last decrement to the reader counter happens first, then this thread will see
        //      that there are no active readers in this generation and avoid blocking on the futex.
        advancer.store(ADVANCER_WAITING, Ordering::SeqCst);
        if !has_active_readers(generation) {
            break;
        }
        let _ = advancer.wait(ADVANCER_WAITING, None, zx::MonotonicInstant::INFINITE);
    }
    advancer.store(ADVANCER_IDLE, Ordering::SeqCst);
}

/// Advance the RCU state machine.
///
/// This function blocks until all in-flight read operations have completed for the current
/// generation and then runs any callbacks that have become ready.
fn rcu_grace_period() {
    let callbacks = {
        let mut waiting_callbacks = RCU_CONTROL_BLOCK.waiting_callbacks.lock();
        // We are in the *Idle* state.

        // Synchronization point [H] (see design.md)
        let callbacks = RCU_CONTROL_BLOCK.callback_chain.drain();
        let generation = RCU_CONTROL_BLOCK.generation.fetch_add(1, Ordering::Relaxed);

        waiting_callbacks.enqueue(generation, callbacks);

        // Enter the *Waiting* state.
        rcu_advancer_wait(generation);
        waiting_callbacks.take_ready(generation)

        // Return to the *Idle* state.
    };

    // Run the callbacks in reverse order so that they execute in the order in which they were
    // scheduled.
    for callback in callbacks.into_iter().rev() {
        callback();
    }
}

/// Block until all in-flight read operations and callbacks have completed.
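///
/// A minimal writer-side sketch (illustrative; `DATA` is a hypothetical `AtomicPtr<u32>` whose
/// contents are always allocated with `Box`):
///
/// ```ignore
/// let old = rcu_replace_pointer(&DATA, std::ptr::null_mut());
/// rcu_synchronize(); // After this returns, no reader can still be using `old`.
/// if !old.is_null() {
///     drop(unsafe { Box::from_raw(old) });
/// }
/// ```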
pub fn rcu_synchronize() {
    RCU_THREAD_BLOCK.with(|block| {
        assert!(!block.holding_read_lock());
        block.has_pending_callbacks.set(false);
    });
    for _ in 0..QUEUE_LENGTH {
        rcu_grace_period();
    }
}

/// Run all callbacks that have been scheduled from this thread.
///
/// If any callbacks have been scheduled from this thread, this function will block until all
/// callbacks have been run. If no callbacks have been scheduled from this thread, this function
/// will return immediately.
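///
/// A minimal usage sketch (illustrative; `cleanup` is a hypothetical function):
///
/// ```ignore
/// rcu_call(|| cleanup()); // Marks this thread as having pending callbacks.
/// rcu_run_callbacks();    // Blocks until the callback above has run.
/// rcu_run_callbacks();    // Returns immediately; nothing new was scheduled from this thread.
/// ```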
pub fn rcu_run_callbacks() {
    RCU_THREAD_BLOCK.with(|block| {
        assert!(!block.holding_read_lock());
        if block.has_pending_callbacks.get() {
            rcu_synchronize();
        }
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    #[test]
    fn test_rcu_delay_regression() {
        // This test relies on the global RCU state machine.
        // It verifies that callbacks are NOT executed immediately after one grace period.

        let flag = Arc::new(AtomicBool::new(false));
        let moved_flag = flag.clone();

        rcu_call(move || {
            moved_flag.store(true, Ordering::SeqCst);
        });

        for _ in 0..QUEUE_LENGTH - 1 {
            rcu_grace_period();

            assert!(
                !flag.load(Ordering::SeqCst),
                "Callback executed too early! RCU requires a delay of QUEUE_LENGTH grace periods."
            );
        }

        rcu_grace_period();
        assert!(
            flag.load(Ordering::SeqCst),
            "Callback should have executed after QUEUE_LENGTH grace periods"
        );
    }

    #[test]
    fn test_rcu_synchronize() {
        // This test relies on the global RCU state machine.
        // It verifies that rcu_synchronize() blocks until all callbacks have been run.

        let flag = Arc::new(AtomicBool::new(false));
        let moved_flag = flag.clone();

        rcu_call(move || {
            moved_flag.store(true, Ordering::SeqCst);
        });

        rcu_synchronize();
        assert!(
            flag.load(Ordering::SeqCst),
            "Callback should have executed after rcu_synchronize()"
        );
    }

    #[test]
    fn test_rcu_run_callbacks() {
        // This test relies on the global RCU state machine.
        // It verifies that rcu_run_callbacks() blocks until all callbacks have been run.

        let flag = Arc::new(AtomicBool::new(false));
        let moved_flag = flag.clone();

        rcu_call(move || {
            moved_flag.store(true, Ordering::SeqCst);
        });

        rcu_run_callbacks();
        assert!(
            flag.load(Ordering::SeqCst),
            "Callback should have executed after rcu_run_callbacks()"
        );
    }
}