Skip to main content

fbl/
ref_ptr.rs

1// Copyright 2026 The Fuchsia Authors
2//
3// Use of this source code is governed by a MIT-style
4// license that can be found in the LICENSE file or at
5// https://opensource.org/licenses/MIT
6
7use crate::recyclable::{Recyclable, UninitRecyclable};
8use crate::ref_counted::HasRefCount;
9use core::mem::MaybeUninit;
10use core::ops::Deref;
11use core::ptr::NonNull;
12use kalloc::AllocError;
13
14use pin_init::{Init, PinInit};
15
16/// `RefPtr<T>` holds a reference to an intrusively-refcounted object of type
17/// T that deletes the object when the refcount drops to 0.
18///
19/// T should be a struct that contains a `fbl::RefCounted` field and implements
20/// `HasRefCount` and `Destroy` traits.
21#[repr(C)]
22pub struct RefPtr<T>
23where
24    T: HasRefCount + Recyclable,
25{
26    ptr: NonNull<T>,
27}
28
29impl<T: HasRefCount + Recyclable> RefPtr<T> {
30    /// Constructs a `RefPtr` from a raw pointer that has already been adopted.
31    ///
32    /// # Safety
33    ///
34    /// - The caller must ensure that `ptr` is valid and has a ref count already
35    ///   acquired.
36    /// - `ptr` must have been allocated in such a way that calling `T::recycle(ptr)` is a
37    ///   correct way to deallocate the pointer.
38    pub unsafe fn from_raw(ptr: *const T) -> Self {
39        // SAFETY: The caller must ensure that ptr is valid.
40        unsafe { RefPtr { ptr: NonNull::new_unchecked(ptr as *mut T) } }
41    }
42
43    /// Helper function that allocates a new instance of `T` using `T::allocate` and
44    /// returns a `RefPtr` wrapping it.
45    ///
46    /// This is an internal helper function that should not be used directly.
47    /// Use the `make_ref_counted!(...)` macro instead of this function to properly
48    /// initialize the ref count.
49    ///
50    /// # Safety
51    ///
52    /// The caller must ensure that `T` has a RefCounted field that is not
53    /// already adopted.
54    pub unsafe fn try_new(value: T) -> Result<RefPtr<T>, AllocError> {
55        let mut ptr = T::allocate(value)?;
56        // SAFETY: The caller must ensure that T has a RefCounted field that is not
57        // already adopted.
58        unsafe { ptr.as_mut().ref_count().adopt() };
59        Ok(RefPtr { ptr })
60    }
61
62    /// Returns the raw pointer to the object.
63    pub fn as_ptr(this: &Self) -> *const T {
64        this.ptr.as_ptr()
65    }
66
67    /// Returns `true` if the two `RefPtr`s point to the same object.
68    pub fn ptr_eq(a: &Self, b: &Self) -> bool {
69        a.ptr == b.ptr
70    }
71
72    /// Consume the `RefPtr` and return the raw pointer without modifying the ref count.
73    ///
74    /// The caller is responsible for maintaining the reference count.
75    pub fn into_raw(this: Self) -> *const T {
76        let ptr = this.ptr;
77        core::mem::forget(this);
78        ptr.as_ptr()
79    }
80
81    /// Use the given pin-initializer to pin-initialize a `T` inside of a new `RefPtr`.
82    pub fn try_pin_init<E>(init: impl PinInit<T, E>) -> Result<Self, E>
83    where
84        T: UninitRecyclable,
85        E: From<AllocError>,
86    {
87        let ptr = T::allocate_uninit()?;
88        let guard = UninitRefGuard { ptr };
89        let slot = guard.ptr.as_ptr() as *mut T;
90        // SAFETY: `slot` is valid and will not be moved.
91        unsafe { init.__pinned_init(slot)? };
92        // SAFETY: The object is now initialized, so we can access its ref_count.
93        unsafe { (*slot).ref_count().adopt() };
94        let initialized_ptr = guard.ptr.cast::<T>();
95        core::mem::forget(guard);
96        let initialized_ref = RefPtr { ptr: initialized_ptr };
97        Ok(initialized_ref)
98    }
99
100    /// Use the given initializer to in-place initialize a `T` inside of a new `RefPtr`.
101    pub fn try_init<E>(init: impl Init<T, E>) -> Result<Self, E>
102    where
103        T: UninitRecyclable,
104        E: From<AllocError>,
105    {
106        let ptr = T::allocate_uninit()?;
107        let guard = UninitRefGuard { ptr };
108        let slot = guard.ptr.as_ptr() as *mut T;
109        // SAFETY: `slot` is valid.
110        unsafe { init.__init(slot)? };
111        // SAFETY: The object is now initialized, so we can access its ref_count.
112        unsafe { (*slot).ref_count().adopt() };
113        let initialized_ptr = guard.ptr.cast::<T>();
114        core::mem::forget(guard);
115        Ok(RefPtr { ptr: initialized_ptr })
116    }
117
118    /// Use the given pin-initializer to pin-initialize a `T` inside of a new `RefPtr`.
119    #[inline]
120    pub fn pin_init(init: impl PinInit<T, core::convert::Infallible>) -> Result<Self, AllocError>
121    where
122        T: UninitRecyclable,
123    {
124        let init = unsafe {
125            ::pin_init::pin_init_from_closure(|slot| {
126                init.__pinned_init(slot).map_err(|i| match i {})
127            })
128        };
129        Self::try_pin_init(init)
130    }
131
132    /// Use the given initializer to in-place initialize a `T` inside of a new `RefPtr`.
133    #[inline]
134    pub fn init(init: impl Init<T, core::convert::Infallible>) -> Result<Self, AllocError>
135    where
136        T: UninitRecyclable,
137    {
138        let init = unsafe {
139            ::pin_init::init_from_closure(|slot| init.__init(slot).map_err(|i| match i {}))
140        };
141        Self::try_init(init)
142    }
143}
144
145impl<T: HasRefCount + Recyclable> Deref for RefPtr<T> {
146    type Target = T;
147    fn deref(&self) -> &Self::Target {
148        unsafe { self.ptr.as_ref() }
149    }
150}
151
152impl<T: HasRefCount + Recyclable> Clone for RefPtr<T> {
153    fn clone(&self) -> Self {
154        self.deref().ref_count().add_ref();
155        RefPtr { ptr: self.ptr }
156    }
157}
158
159impl<T: HasRefCount + Recyclable> Drop for RefPtr<T> {
160    fn drop(&mut self) {
161        if self.deref().ref_count().release() {
162            unsafe {
163                T::recycle(self.ptr);
164            }
165        }
166    }
167}
168
169impl<T: HasRefCount + Recyclable> PartialEq for RefPtr<T> {
170    fn eq(&self, other: &Self) -> bool {
171        RefPtr::ptr_eq(self, other)
172    }
173}
174
175impl<T: HasRefCount + Recyclable> Eq for RefPtr<T> {}
176
177unsafe impl<T: HasRefCount + Recyclable + Send + Sync> Send for RefPtr<T> {}
178unsafe impl<T: HasRefCount + Recyclable + Send + Sync> Sync for RefPtr<T> {}
179
180struct UninitRefGuard<T: UninitRecyclable> {
181    ptr: NonNull<MaybeUninit<T>>,
182}
183
184impl<T: UninitRecyclable> Drop for UninitRefGuard<T> {
185    fn drop(&mut self) {
186        unsafe {
187            T::recycle_uninit(self.ptr);
188        }
189    }
190}
191
/// Macro to construct a RefPtr, automatically populating the ref_count field.
///
/// `make_ref_counted!(Ty { a: expr_a, b: expr_b })` evaluates to the
/// `Result` returned by `RefPtr::try_new`. The struct must have a `ref_count`
/// field and a `__fbl_ref_counted_guard: ()` field (typically added by the
/// `#[fbl::ref_counted]` attribute). Both are filled in by the macro; every
/// other field must be listed as `name: value`.
///
/// The field value expressions are evaluated outside of the macro's `unsafe`
/// block, so they cannot perform unsafe operations without the caller writing
/// `unsafe` themselves.
#[macro_export]
macro_rules! make_ref_counted {
    ($ty:ident { $($field:ident : $val:expr),* $(,)? }) => {{
        let value = $ty {
            ref_count: $crate::RefCounted::new(),
            __fbl_ref_counted_guard: (),
            $($field : $val),*
        };
        // SAFETY: `value` was just built with a fresh `RefCounted`, so its ref
        // count has not been adopted yet.
        unsafe { $crate::RefPtr::try_new(value) }
    }};
}
206
/// Builds a `RefPtr` through pin-initialization, supplying the `ref_count` and
/// `__fbl_ref_counted_guard` fields automatically.
///
/// The remaining tokens are forwarded verbatim to `pin_init!`, so any field
/// syntax accepted there is accepted here. The result is whatever
/// `RefPtr::pin_init` returns.
#[macro_export]
macro_rules! pin_make_ref_counted {
    ($name:ident { $($rest:tt)* }) => {
        $crate::RefPtr::pin_init(
            $crate::pin_init::pin_init!($name {
                ref_count: $crate::RefCounted::new(),
                __fbl_ref_counted_guard: (),
                $($rest)*
            })
        )
    };
}
218
/// Builds a `RefPtr` through fallible pin-initialization, supplying the
/// `ref_count` and `__fbl_ref_counted_guard` fields automatically.
///
/// The remaining tokens are forwarded verbatim to `pin_init!`. The result is
/// whatever `RefPtr::try_pin_init` returns, so initializer errors propagate to
/// the caller.
#[macro_export]
macro_rules! try_pin_make_ref_counted {
    ($name:ident { $($rest:tt)* }) => {
        $crate::RefPtr::try_pin_init(
            $crate::pin_init::pin_init!($name {
                ref_count: $crate::RefCounted::new(),
                __fbl_ref_counted_guard: (),
                $($rest)*
            })
        )
    };
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use core::ffi::c_void;
    use core::pin::Pin;
    use core::sync::atomic::{AtomicBool, Ordering};

    extern crate alloc;
    use alloc::sync::Arc;

    /// Recycle hook exported under a fixed symbol name for the C++ side of the
    /// FFI test.
    #[unsafe(no_mangle)]
    pub extern "C" fn rust_recycle_test_rust_ref_counted(ptr: *mut c_void) {
        // SAFETY: Presumably only invoked from C++ with a pointer to a
        // `TestRustRefCounted` whose last reference was just released.
        // NOTE(review): confirm against the C++ caller's contract.
        unsafe { TestRustRefCounted::recycle_ffi(ptr) }
    }

    unsafe extern "C" {
        fn test_import_rust_ref_counted(ptr: *mut c_void);
    }

    #[fbl::ref_counted]
    #[pin_init::pin_data(PinnedDrop)]
    #[derive(crate::Recyclable)]
    #[repr(C)]
    pub struct TestRustRefCounted {
        destroyed: Arc<AtomicBool>,
    }

    // `RefPtr` wraps a `NonNull`, so it is exactly pointer-sized and pointer-
    // aligned, and `Option<RefPtr<_>>` uses the null niche instead of a tag.
    ::zr::static_assert!(
        core::mem::size_of::<RefPtr<TestRustRefCounted>>() == core::mem::size_of::<*const ()>()
    );
    ::zr::static_assert!(
        core::mem::align_of::<RefPtr<TestRustRefCounted>>() == core::mem::align_of::<*const ()>()
    );
    ::zr::static_assert!(
        core::mem::size_of::<Option<RefPtr<TestRustRefCounted>>>()
            == core::mem::size_of::<*const ()>()
    );
    ::zr::static_assert!(
        core::mem::align_of::<Option<RefPtr<TestRustRefCounted>>>()
            == core::mem::align_of::<*const ()>()
    );

    #[pin_init::pinned_drop]
    impl pin_init::PinnedDrop for TestRustRefCounted {
        fn drop(self: Pin<&mut Self>) {
            self.destroyed.store(true, Ordering::Relaxed);
        }
    }

    #[test]
    fn test_rust_drops_reference() {
        let destroyed = Arc::new(AtomicBool::new(false));
        {
            let ref_ptr =
                make_ref_counted!(TestRustRefCounted { destroyed: destroyed.clone() }).unwrap();
            assert!(!destroyed.load(Ordering::Relaxed));
            let ref_ptr_clone = ref_ptr.clone();
            drop(ref_ptr_clone);
            assert!(!destroyed.load(Ordering::Relaxed));
        } // Drop ref_ptr -> count becomes 0 -> object is recycled -> PinnedDrop runs.

        assert!(destroyed.load(Ordering::Relaxed));
    }

    #[test]
    #[cfg_attr(miri, ignore = "miri does not support calling foreign functions")]
    fn test_cpp_drops_reference() {
        let destroyed = Arc::new(AtomicBool::new(false));
        let ref_ptr =
            make_ref_counted!(TestRustRefCounted { destroyed: destroyed.clone() }).unwrap();
        let raw_ptr = RefPtr::into_raw(ref_ptr);
        assert!(!destroyed.load(Ordering::Relaxed));

        // Hand our only reference to C++. C++ is expected to take it over and
        // release it; since the count was 1, that should destroy the object.
        // SAFETY: `raw_ptr` came from `RefPtr::into_raw`, so it points to a live
        // object carrying one reference that is transferred to C++ here.
        unsafe { test_import_rust_ref_counted(raw_ptr.cast_mut().cast::<c_void>()) };

        assert!(destroyed.load(Ordering::Relaxed));
    }

    #[test]
    // The self-comparison below is intentional: it checks that `==` is reflexive.
    #[allow(clippy::eq_op)]
    fn test_ref_ptr_compare() {
        let destroyed1 = Arc::new(AtomicBool::new(false));
        let destroyed2 = Arc::new(AtomicBool::new(false));
        let ptr1 = make_ref_counted!(TestRustRefCounted { destroyed: destroyed1.clone() }).unwrap();
        let ptr2 = make_ref_counted!(TestRustRefCounted { destroyed: destroyed2.clone() }).unwrap();
        let ptr1_clone = ptr1.clone();

        assert!(ptr1 == ptr1);
        assert!(ptr1 != ptr2);
        assert!(ptr1 == ptr1_clone);
        assert!(ptr2 != ptr1_clone);
    }

    #[test]
    fn test_rust_pin_init() {
        let destroyed = Arc::new(AtomicBool::new(false));
        let destroyed_clone = destroyed.clone();
        {
            let ref_ptr =
                pin_make_ref_counted!(TestRustRefCounted { destroyed: destroyed_clone }).unwrap();
            assert!(!destroyed.load(Ordering::Relaxed));
            let ref_ptr_clone = ref_ptr.clone();
            drop(ref_ptr_clone);
            assert!(!destroyed.load(Ordering::Relaxed));
        } // Drop ref_ptr
        assert!(destroyed.load(Ordering::Relaxed));
    }

    #[fbl::ref_counted]
    #[pin_init::pin_data]
    #[derive(crate::Recyclable)]
    #[repr(C)]
    struct FallibleInit {
        value: i32,
    }

    #[test]
    fn test_rust_try_pin_init_fail() {
        // SAFETY: The closure fails without writing to `slot`, so the slot is
        // left fully uninitialized and can be deallocated without dropping.
        let init = unsafe {
            ::pin_init::pin_init_from_closure(
                |_slot: *mut FallibleInit| -> Result<(), AllocError> { Err(AllocError) },
            )
        };
        let res = RefPtr::try_pin_init(init);
        assert!(res.is_err());
    }
}