1use crate::recyclable::{Recyclable, UninitRecyclable};
8use crate::ref_counted::HasRefCount;
9use core::mem::MaybeUninit;
10use core::ops::Deref;
11use core::ptr::NonNull;
12use kalloc::AllocError;
13
14use pin_init::{Init, PinInit};
15
16#[repr(C)]
22pub struct RefPtr<T>
23where
24 T: HasRefCount + Recyclable,
25{
26 ptr: NonNull<T>,
27}
28
29impl<T: HasRefCount + Recyclable> RefPtr<T> {
30 pub unsafe fn from_raw(ptr: *const T) -> Self {
39 unsafe { RefPtr { ptr: NonNull::new_unchecked(ptr as *mut T) } }
41 }
42
43 pub unsafe fn try_new(value: T) -> Result<RefPtr<T>, AllocError> {
55 let mut ptr = T::allocate(value)?;
56 unsafe { ptr.as_mut().ref_count().adopt() };
59 Ok(RefPtr { ptr })
60 }
61
62 pub fn as_ptr(this: &Self) -> *const T {
64 this.ptr.as_ptr()
65 }
66
67 pub fn ptr_eq(a: &Self, b: &Self) -> bool {
69 a.ptr == b.ptr
70 }
71
72 pub fn into_raw(this: Self) -> *const T {
76 let ptr = this.ptr;
77 core::mem::forget(this);
78 ptr.as_ptr()
79 }
80
81 pub fn try_pin_init<E>(init: impl PinInit<T, E>) -> Result<Self, E>
83 where
84 T: UninitRecyclable,
85 E: From<AllocError>,
86 {
87 let ptr = T::allocate_uninit()?;
88 let guard = UninitRefGuard { ptr };
89 let slot = guard.ptr.as_ptr() as *mut T;
90 unsafe { init.__pinned_init(slot)? };
92 unsafe { (*slot).ref_count().adopt() };
94 let initialized_ptr = guard.ptr.cast::<T>();
95 core::mem::forget(guard);
96 let initialized_ref = RefPtr { ptr: initialized_ptr };
97 Ok(initialized_ref)
98 }
99
100 pub fn try_init<E>(init: impl Init<T, E>) -> Result<Self, E>
102 where
103 T: UninitRecyclable,
104 E: From<AllocError>,
105 {
106 let ptr = T::allocate_uninit()?;
107 let guard = UninitRefGuard { ptr };
108 let slot = guard.ptr.as_ptr() as *mut T;
109 unsafe { init.__init(slot)? };
111 unsafe { (*slot).ref_count().adopt() };
113 let initialized_ptr = guard.ptr.cast::<T>();
114 core::mem::forget(guard);
115 Ok(RefPtr { ptr: initialized_ptr })
116 }
117
118 #[inline]
120 pub fn pin_init(init: impl PinInit<T, core::convert::Infallible>) -> Result<Self, AllocError>
121 where
122 T: UninitRecyclable,
123 {
124 let init = unsafe {
125 ::pin_init::pin_init_from_closure(|slot| {
126 init.__pinned_init(slot).map_err(|i| match i {})
127 })
128 };
129 Self::try_pin_init(init)
130 }
131
132 #[inline]
134 pub fn init(init: impl Init<T, core::convert::Infallible>) -> Result<Self, AllocError>
135 where
136 T: UninitRecyclable,
137 {
138 let init = unsafe {
139 ::pin_init::init_from_closure(|slot| init.__init(slot).map_err(|i| match i {}))
140 };
141 Self::try_init(init)
142 }
143}
144
145impl<T: HasRefCount + Recyclable> Deref for RefPtr<T> {
146 type Target = T;
147 fn deref(&self) -> &Self::Target {
148 unsafe { self.ptr.as_ref() }
149 }
150}
151
152impl<T: HasRefCount + Recyclable> Clone for RefPtr<T> {
153 fn clone(&self) -> Self {
154 self.deref().ref_count().add_ref();
155 RefPtr { ptr: self.ptr }
156 }
157}
158
159impl<T: HasRefCount + Recyclable> Drop for RefPtr<T> {
160 fn drop(&mut self) {
161 if self.deref().ref_count().release() {
162 unsafe {
163 T::recycle(self.ptr);
164 }
165 }
166 }
167}
168
169impl<T: HasRefCount + Recyclable> PartialEq for RefPtr<T> {
170 fn eq(&self, other: &Self) -> bool {
171 RefPtr::ptr_eq(self, other)
172 }
173}
174
175impl<T: HasRefCount + Recyclable> Eq for RefPtr<T> {}
176
177unsafe impl<T: HasRefCount + Recyclable + Send + Sync> Send for RefPtr<T> {}
178unsafe impl<T: HasRefCount + Recyclable + Send + Sync> Sync for RefPtr<T> {}
179
180struct UninitRefGuard<T: UninitRecyclable> {
181 ptr: NonNull<MaybeUninit<T>>,
182}
183
184impl<T: UninitRecyclable> Drop for UninitRefGuard<T> {
185 fn drop(&mut self) {
186 unsafe {
187 T::recycle_uninit(self.ptr);
188 }
189 }
190}
191
/// Builds a `$ty` with a fresh reference count and moves it into a new `RefPtr`.
///
/// Expands to a `Result<RefPtr<$ty>, AllocError>`. The caller's field initializers are
/// evaluated outside of any `unsafe` block. An unsafe operation inside them therefore still
/// needs its own `unsafe` at the call site instead of being silently covered by the macro.
#[macro_export]
macro_rules! make_ref_counted {
    ($ty:ident { $($field:ident : $val:expr),* $(,)? }) => {{
        let value = $ty {
            ref_count: $crate::RefCounted::new(),
            __fbl_ref_counted_guard: (),
            $($field : $val),*
        };
        // SAFETY: `value` was built just above with a brand-new `RefCounted`, so its count has
        // never been adopted or shared.
        unsafe { $crate::RefPtr::try_new(value) }
    }};
}
206
/// Pin-initializes a `$ty` with a fresh reference count in newly allocated storage.
///
/// The remaining fields use `pin_init!` field syntax. Expands to a
/// `Result<RefPtr<$ty>, AllocError>`.
#[macro_export]
macro_rules! pin_make_ref_counted {
    ($ty:ident { $($field:tt)* }) => {{
        let initializer = $crate::pin_init::pin_init!($ty {
            ref_count: $crate::RefCounted::new(),
            __fbl_ref_counted_guard: (),
            $($field)*
        });
        $crate::RefPtr::pin_init(initializer)
    }};
}
218
/// Fallibly pin-initializes a `$ty` with a fresh reference count in newly allocated storage.
///
/// Name the initializer's error type with a `? Error` suffix, as with
/// `pin_init::try_pin_init!`. `Error` must implement `From<AllocError>` so that allocation
/// failure can be reported through the same type:
///
/// ```ignore
/// let r = try_pin_make_ref_counted!(Foo { value: parse()? }? MyError);
/// ```
///
/// That form expands to a `Result<RefPtr<$ty>, Error>`.
///
/// The form without `? Error` is kept for compatibility. It builds the initializer with
/// `pin_init!`, whose error type is `Infallible`. It therefore only compiles if that type
/// satisfies `RefPtr::try_pin_init`'s `E: From<AllocError>` bound.
#[macro_export]
macro_rules! try_pin_make_ref_counted {
    ($ty:ident { $($field:tt)* }? $err:ty) => {
        $crate::RefPtr::try_pin_init($crate::pin_init::try_pin_init!($ty {
            ref_count: $crate::RefCounted::new(),
            __fbl_ref_counted_guard: (),
            $($field)*
        }? $err))
    };
    ($ty:ident { $($field:tt)* }) => {
        $crate::RefPtr::try_pin_init($crate::pin_init::pin_init!($ty {
            ref_count: $crate::RefCounted::new(),
            __fbl_ref_counted_guard: (),
            $($field)*
        }))
    };
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use core::ffi::c_void;
    use core::pin::Pin;
    use core::sync::atomic::{AtomicBool, Ordering};

    extern crate alloc;
    use alloc::sync::Arc;

    /// Recycle hook exported for the C++ half of the FFI test.
    ///
    /// NOTE(review): presumably called by C++ after it releases the last reference to a
    /// `TestRustRefCounted`; confirm in the C++ test.
    #[unsafe(no_mangle)]
    pub extern "C" fn rust_recycle_test_rust_ref_counted(ptr: *mut c_void) {
        // SAFETY: relies on the C++ caller passing a pointer to a `TestRustRefCounted` whose
        // last reference has just been released (see the NOTE above).
        unsafe { TestRustRefCounted::recycle_ffi(ptr) }
    }

    unsafe extern "C" {
        /// C++ test helper. `test_cpp_drops_reference` expects it to take over the reference
        /// carried by `ptr` and release it.
        fn test_import_rust_ref_counted(ptr: *mut c_void);
    }

    #[fbl::ref_counted]
    #[pin_init::pin_data(PinnedDrop)]
    #[derive(crate::Recyclable)]
    #[repr(C)]
    pub struct TestRustRefCounted {
        // Set to `true` by the `PinnedDrop` impl so tests can observe destruction.
        destroyed: Arc<AtomicBool>,
    }

    ::zr::static_assert!(core::mem::size_of::<RefPtr<TestRustRefCounted>>() == 8);
    ::zr::static_assert!(core::mem::align_of::<RefPtr<TestRustRefCounted>>() == 8);
    ::zr::static_assert!(core::mem::size_of::<Option<RefPtr<TestRustRefCounted>>>() == 8);
    ::zr::static_assert!(core::mem::align_of::<Option<RefPtr<TestRustRefCounted>>>() == 8);

    #[pin_init::pinned_drop]
    impl pin_init::PinnedDrop for TestRustRefCounted {
        fn drop(self: Pin<&mut Self>) {
            self.destroyed.store(true, Ordering::Relaxed);
        }
    }

    #[test]
    fn test_rust_drops_reference() {
        let destroyed = Arc::new(AtomicBool::new(false));
        {
            let ref_ptr =
                make_ref_counted!(TestRustRefCounted { destroyed: destroyed.clone() }).unwrap();
            assert!(!destroyed.load(Ordering::Relaxed));
            let ref_ptr_clone = ref_ptr.clone();
            drop(ref_ptr_clone);
            // Dropping one of two references must not destroy the object.
            assert!(!destroyed.load(Ordering::Relaxed));
        }
        // Leaving the scope dropped the last reference.
        assert!(destroyed.load(Ordering::Relaxed));
    }

    #[test]
    #[cfg_attr(miri, ignore = "miri does not support calling foreign functions")]
    fn test_cpp_drops_reference() {
        let destroyed = Arc::new(AtomicBool::new(false));
        let ref_ptr =
            make_ref_counted!(TestRustRefCounted { destroyed: destroyed.clone() }).unwrap();
        let raw_ptr = RefPtr::into_raw(ref_ptr);

        assert!(!destroyed.load(Ordering::Relaxed));
        // SAFETY: `raw_ptr` carries the single reference leaked by `into_raw`. The C++ side is
        // expected to take it over and release it, which the assertion below checks.
        unsafe { test_import_rust_ref_counted(raw_ptr.cast_mut().cast::<c_void>()) };
        assert!(destroyed.load(Ordering::Relaxed));
    }

    #[test]
    // `ptr1 == ptr1` is deliberate: it checks that equality is reflexive.
    #[allow(clippy::eq_op)]
    fn test_ref_ptr_compare() {
        let destroyed1 = Arc::new(AtomicBool::new(false));
        let destroyed2 = Arc::new(AtomicBool::new(false));
        let ptr1 = make_ref_counted!(TestRustRefCounted { destroyed: destroyed1.clone() }).unwrap();
        let ptr2 = make_ref_counted!(TestRustRefCounted { destroyed: destroyed2.clone() }).unwrap();
        let ptr1_clone = ptr1.clone();

        assert!(ptr1 == ptr1);
        assert!(ptr1 != ptr2);
        assert!(ptr1 == ptr1_clone);
    }

    #[test]
    fn test_rust_pin_init() {
        let destroyed = Arc::new(AtomicBool::new(false));
        let destroyed_clone = destroyed.clone();
        {
            let ref_ptr =
                pin_make_ref_counted!(TestRustRefCounted { destroyed: destroyed_clone }).unwrap();
            assert!(!destroyed.load(Ordering::Relaxed));
            let ref_ptr_clone = ref_ptr.clone();
            drop(ref_ptr_clone);
            // Dropping one of two references must not destroy the object.
            assert!(!destroyed.load(Ordering::Relaxed));
        }
        // Leaving the scope dropped the last reference.
        assert!(destroyed.load(Ordering::Relaxed));
    }

    #[fbl::ref_counted]
    #[pin_init::pin_data]
    #[derive(crate::Recyclable)]
    #[repr(C)]
    struct FallibleInit {
        value: i32,
    }

    #[test]
    fn test_rust_try_pin_init_fail() {
        // SAFETY: the closure never writes to `_slot` and always reports failure, so it
        // leaves no partially-initialized state behind.
        let init = unsafe {
            ::pin_init::pin_init_from_closure(
                |_slot: *mut FallibleInit| -> Result<(), AllocError> { Err(AllocError) },
            )
        };
        let res = RefPtr::try_pin_init(init);
        assert!(res.is_err());
    }
}
349}