_core_rustc_static/
recursion_guard.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::cell::Cell;
6
7/// Executes the given function in a special context that cannot be recursively re-entered
8/// on the same execution stack.
9///
10/// It can be used to prevent unwanted recursive function executions.
11pub fn with_recursion_guard<T>(f: impl FnOnce() -> T) -> T {
12    match with_recursion_guard_impl(f) {
13        Ok(result) => result,
14        Err(UnwantedRecursionError) => {
15            // WARNING! Do not call panic! because it may itself allocate and cause further
16            // recursion. This still results in a backtrace in the log.
17            std::process::abort()
18        }
19    }
20}
21
thread_local!(
    /// Per-thread flag: `true` while a `with_recursion_guard_impl` call is running on this
    /// thread, `false` otherwise. The `const` initializer lets std skip lazy initialization
    /// when the value is accessed.
    static RECURSION_GUARD: Cell<bool> = const { Cell::new(false) }
);
26
/// Error returned by `with_recursion_guard_impl` when the recursion guard is already held on the
/// current thread, i.e. the call would re-enter a guarded section.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct UnwantedRecursionError;
29
30fn with_recursion_guard_impl<T>(f: impl FnOnce() -> T) -> Result<T, UnwantedRecursionError> {
31    RECURSION_GUARD.with(|cell| {
32        let was_already_acquired = cell.replace(true);
33        if was_already_acquired {
34            return Err(UnwantedRecursionError);
35        }
36
37        let result = f();
38
39        cell.set(false);
40
41        Ok(result)
42    })
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    // Verify that executing a non-recursive function succeeds.
    #[test]
    fn test_recursion_guard_ok() {
        let result = with_recursion_guard_impl(|| 42);
        assert_matches!(result, Ok(42));
    }

    // Verify that the inner recursive call fails.
    #[test]
    fn test_recursion_guard_violation() {
        let result = with_recursion_guard_impl(|| with_recursion_guard_impl(|| 42));
        assert_matches!(result, Ok(Err(UnwantedRecursionError)));
    }

    // Verify that the guard is released once the guarded function returns, so later sequential
    // (non-nested) calls on the same thread succeed.
    #[test]
    fn test_recursion_guard_released_after_return() {
        assert_matches!(with_recursion_guard_impl(|| 1), Ok(1));
        assert_matches!(with_recursion_guard_impl(|| 2), Ok(2));
    }

    // Verify that a rejected inner call does not release the guard still held by the outer call.
    #[test]
    fn test_recursion_guard_violation_keeps_outer_guard() {
        let result = with_recursion_guard_impl(|| {
            let first = with_recursion_guard_impl(|| 1);
            let second = with_recursion_guard_impl(|| 2);
            (first, second)
        });
        assert_matches!(result, Ok((Err(UnwantedRecursionError), Err(UnwantedRecursionError))));
    }

    // Verify that the guard is per-thread: holding it on one thread does not block another.
    #[test]
    fn test_recursion_guard_is_per_thread() {
        let result = with_recursion_guard_impl(|| {
            std::thread::spawn(|| with_recursion_guard_impl(|| 42)).join().unwrap()
        });
        assert_matches!(result, Ok(Ok(42)));
    }
}