replace_with/
lib.rs

// Copyright 2023 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! `replace-with` provides the [`replace_with`] and [`replace_with_and`]
//! functions.

7use std::{mem, ptr};
8
9/// Uses `f` to replace the referent of `dst` with a new value.
10///
11/// Reads the current value in `dst`, calls `f` on that value, and overwrites
12/// `dst` using the new value returned by `f`. If `f` panics, the process is
13/// aborted.
14///
15/// This is useful for updating a value whose type is not [`Copy`].
16///
17/// # Examples
18///
19/// ```rust
20/// # use replace_with::replace_with;
21/// /// A value that might be stored on the heap (boxed) or on the stack (unboxed).
22/// pub enum MaybeBoxed<T> {
23///     Boxed(Box<T>),
24///     Unboxed(T),
25/// }
26///
27/// impl<T> MaybeBoxed<T> {
28///     /// Ensures that `self` is boxed, moving the value to the heap if necessary.
29///     pub fn ensure_boxed(&mut self) {
30///         replace_with(self, |m| match m {
31///             MaybeBoxed::Boxed(b) => MaybeBoxed::Boxed(b),
32///             MaybeBoxed::Unboxed(u) => MaybeBoxed::Boxed(Box::new(u)),
33///         })
34///     }
35/// }
36/// ```
37pub fn replace_with<T, F: FnOnce(T) -> T>(dst: &mut T, f: F) {
38    replace_with_and(dst, move |t| (f(t), ()))
39}
40
41/// Uses `f` to replace the referent of `dst` and returns a value from the
42/// transformation.
43///
44/// Like [`replace_with`] but the provided function returns a tuple of `(T, R`)
45/// where `T` is the new value for `dst` and `R` is returned from
46/// `replace_with_and`.
47pub fn replace_with_and<T, R, F: FnOnce(T) -> (T, R)>(dst: &mut T, f: F) -> R {
48    // This is not necessary today, but it may be necessary if the "strict
49    // pointer provenance" model [1] is adopted in the future.
50    //
51    // [1] https://github.com/rust-lang/rust/issues/95228
52    let dst = dst as *mut T;
53
54    // SAFETY:
55    // - The initial `ptr::read` is sound because `dst` is derived from a `&mut
56    //   T`, and so all of `ptr::read`'s safety preconditions are satisfied:
57    //   - `dst` is valid for reads
58    //   - `dst` is properly aligned
59    //   - `dst` points at a properly initialized value of type `T`
60    // - After `ptr::read` is called, we've created a copy of `*dst`. Since `T:
61    //   !Copy`, it is not guaranteed that operating on both copies would be
62    //   sound. Since we allow `f` to operate on `old`, we have to ensure that
63    //   no code operates on `*dst`. This could happen in a few circumstances:
64    //   - Code in this function could operate on `*dst`, which it doesn't.
65    //   - Code in `f` could operate on `*dst`. Since `dst` is a mutable
66    //     reference, and it is borrowed for the duration of this function call,
67    //     `f` cannot also access `dst` (code that attempted to do that would
68    //     fail to compile).
69    //   - The caller could operate on `dst` after the function returns. There
70    //     are two cases:
71    //     - In the success case, `f` returns without panicking. It returns a
72    //       new `T`, and we overwrite `*dst` with this new `T` using
73    //       `ptr::write`. At this point, it is sound for code to operate on
74    //       `*dst`, and so it is sound for this function to return.
75    //     - In the failure case, `f` panics. Since, at the point we call `f`,
76    //       we have not overwritten `*dst` yet, it would be unsound if the
77    //       panic were to unwind the stack, allowing code from the caller to
78    //       run. Since we call `f` within a call to `abort_on_panic`, we are
79    //       guaranteed that the process would abort, and no future code could
80    //       run.
81    // - The call to `ptr::write` itself is sound because, thanks to `dst`
82    //   being derived from a `&mut T`, all of `ptr::write`'s preconditions are
83    //   satisfied:
84    //   - `dst` is valid for writes
85    //   - `dst` is properly aligned
86    unsafe {
87        let old = ptr::read(dst);
88        let (new, ret) = abort_on_panic(move || f(old));
89        ptr::write(dst, new);
90        ret
91    }
92}
93
/// Calls `f` and returns its result, aborting the process if `f` panics.
///
/// A drop guard is armed before `f` runs:
/// - If `f` returns normally, the guard is disarmed with [`mem::forget`] and
///   the result is returned.
/// - If `f` panics, the guard is dropped during unwinding. Its drop handler
///   prints a backtrace to stderr (best effort) and then calls
///   [`std::process::abort`], so the unwind never propagates past this
///   function.
fn abort_on_panic<T, F: FnOnce() -> T>(f: F) -> T {
    /// Guard that runs the wrapped closure when it is dropped.
    struct CallOnDrop<O, F: Fn() -> O>(F);
    impl<O, F: Fn() -> O> Drop for CallOnDrop<O, F> {
        // Guards are only ever dropped on the panic path (the success path
        // forgets the outer guard), so keep this out of the hot code layout.
        #[cold]
        fn drop(&mut self) {
            (self.0)();
        }
    }

    let backtrace_and_abort_on_drop = CallOnDrop(|| {
        // This inner guard ensures that we abort in both of the following
        // two cases:
        // - The backtrace code completes normally. The guard is explicitly
        //   dropped at the end of this closure.
        // - The backtrace code itself panics. The guard is dropped during
        //   unwinding.
        //
        // No functions called from the backtrace code are documented to
        // panic, but this serves as a hedge in case there are undocumented
        // panic conditions.
        let abort_on_drop = CallOnDrop(std::process::abort);

        use std::io::Write as _;
        let backtrace = std::backtrace::Backtrace::force_capture();
        let mut stderr = std::io::stderr().lock();
        // Backtrace printing is best-effort, so any I/O errors are ignored.
        let _ = writeln!(stderr, "replace_with: callback panicked; backtrace:\n{backtrace}");
        let _ = stderr.flush();

        mem::drop(abort_on_drop);
    });

    let t = f();
    // `f` returned without panicking. Disarm the guard so it never runs.
    mem::forget(backtrace_and_abort_on_drop);
    t
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_replace_with() {
        let mut x = 1usize;
        replace_with(&mut x, |x| x * 2);
        assert_eq!(x, 2);
    }

    // The documented use case: transforming a non-`Copy` value in place.
    #[test]
    fn test_replace_with_non_copy() {
        let mut s = String::from("hello");
        replace_with(&mut s, |s| s + ", world");
        assert_eq!(s, "hello, world");
    }

    #[test]
    fn test_replace_with_and_returns_extra_value() {
        let mut v = vec![1, 2, 3];
        let popped = replace_with_and(&mut v, |mut vec| {
            let last = vec.pop();
            (vec, last)
        });
        assert_eq!(popped, Some(3));
        assert_eq!(v, [1, 2]);
    }

    // The value moved out of `dst` must be dropped exactly once: no leak and
    // no double drop.
    #[test]
    fn test_replace_with_drop_accounting() {
        let counter = std::rc::Rc::new(());
        let mut slot = std::rc::Rc::clone(&counter);
        replace_with(&mut slot, |old| {
            assert_eq!(std::rc::Rc::strong_count(&counter), 2);
            drop(old);
            assert_eq!(std::rc::Rc::strong_count(&counter), 1);
            std::rc::Rc::clone(&counter)
        });
        assert_eq!(std::rc::Rc::strong_count(&counter), 2);
        drop(slot);
        assert_eq!(std::rc::Rc::strong_count(&counter), 1);
    }

    #[test]
    fn test_abort_on_panic_returns_value() {
        assert_eq!(abort_on_panic(|| 42), 42);
    }
}