//! `Archive`, `Serialize`, and `Deserialize` implementations for the
//! shared pointers `alloc::rc::Rc` and `alloc::rc::Weak`.
//!
//! (rkyv/impls/alloc/rc/mod.rs)

1#[cfg(target_has_atomic = "ptr")]
2mod atomic;
3
4use core::alloc::LayoutError;
5
6use ptr_meta::{from_raw_parts_mut, Pointee};
7use rancor::{Fallible, Source};
8
9use crate::{
10    alloc::{
11        alloc::{alloc, handle_alloc_error},
12        boxed::Box,
13        rc,
14    },
15    de::{FromMetadata, Metadata, Pooling, PoolingExt as _, SharedPointer},
16    rc::{ArchivedRc, ArchivedRcWeak, RcFlavor, RcResolver, RcWeakResolver},
17    ser::{Sharing, Writer},
18    traits::{ArchivePointee, LayoutRaw},
19    Archive, ArchiveUnsized, Deserialize, DeserializeUnsized, Place, Serialize,
20    SerializeUnsized,
21};
22
23// Rc
24
25impl<T: ArchiveUnsized + ?Sized> Archive for rc::Rc<T> {
26    type Archived = ArchivedRc<T::Archived, RcFlavor>;
27    type Resolver = RcResolver;
28
29    fn resolve(&self, resolver: Self::Resolver, out: Place<Self::Archived>) {
30        ArchivedRc::resolve_from_ref(self.as_ref(), resolver, out);
31    }
32}
33
34impl<T, S> Serialize<S> for rc::Rc<T>
35where
36    T: SerializeUnsized<S> + ?Sized + 'static,
37    S: Fallible + Writer + Sharing + ?Sized,
38    S::Error: Source,
39{
40    fn serialize(
41        &self,
42        serializer: &mut S,
43    ) -> Result<Self::Resolver, S::Error> {
44        ArchivedRc::<T::Archived, RcFlavor>::serialize_from_ref(
45            self.as_ref(),
46            serializer,
47        )
48    }
49}
50
unsafe impl<T: LayoutRaw + Pointee + ?Sized> SharedPointer<T> for rc::Rc<T> {
    /// Allocates uninitialized storage suitable for a value of `T` with the
    /// given metadata and returns a (possibly wide) pointer to it.
    fn alloc(metadata: T::Metadata) -> Result<*mut T, LayoutError> {
        let layout = T::layout_raw(metadata)?;
        let data_address = if layout.size() > 0 {
            // SAFETY: `layout` has nonzero size, as `alloc` requires.
            let ptr = unsafe { alloc(layout) };
            if ptr.is_null() {
                // Diverges on allocation failure.
                handle_alloc_error(layout);
            }
            ptr
        } else {
            // Zero-sized layouts must not go through the allocator; use a
            // dangling but suitably-aligned pointer instead.
            crate::polyfill::dangling(&layout).as_ptr()
        };
        // Attach the metadata to form a (possibly wide) `*mut T`.
        let ptr = from_raw_parts_mut(data_address.cast(), metadata);
        Ok(ptr)
    }

    /// Converts a pointer to an initialized value (as returned by `alloc`)
    /// into a raw pointer managed by an `Rc`.
    unsafe fn from_value(ptr: *mut T) -> *mut T {
        // SAFETY (caller contract): `ptr` points to an initialized,
        // uniquely-owned allocation from the global allocator, so it may be
        // reboxed and moved into an `Rc`.
        let rc = rc::Rc::<T>::from(unsafe { Box::from_raw(ptr) });
        // Leak the `Rc` into a raw pointer; `Self::drop` rebalances the
        // strong count later.
        rc::Rc::into_raw(rc).cast_mut()
    }

    /// Releases one strong count of the `Rc` behind `ptr`.
    unsafe fn drop(ptr: *mut T) {
        // SAFETY (caller contract): `ptr` was produced by `Rc::into_raw`, so
        // reconstructing and dropping the `Rc` releases exactly one strong
        // count.
        drop(unsafe { rc::Rc::from_raw(ptr) });
    }
}
76
impl<T, D> Deserialize<rc::Rc<T>, D> for ArchivedRc<T::Archived, RcFlavor>
where
    T: ArchiveUnsized + LayoutRaw + Pointee + ?Sized + 'static,
    T::Archived: DeserializeUnsized<T, D>,
    T::Metadata: Into<Metadata> + FromMetadata,
    D: Fallible + Pooling + ?Sized,
    D::Error: Source,
{
    /// Deserializes into an `Rc`, pooling the underlying allocation so every
    /// archived pointer to the same value yields `Rc`s that alias one
    /// deserialized allocation.
    fn deserialize(&self, deserializer: &mut D) -> Result<rc::Rc<T>, D::Error> {
        // The pooling deserializer keeps (and owns) the shared allocation for
        // this pointer; it returns a raw pointer without transferring a
        // strong count to us.
        let raw_shared_ptr =
            deserializer.deserialize_shared::<_, rc::Rc<T>>(self.get())?;
        // SAFETY: `raw_shared_ptr` originates from `Rc::into_raw` (see the
        // `SharedPointer` impl above); bump the strong count for the `Rc` we
        // are about to materialize so the pool's own count stays balanced.
        unsafe {
            rc::Rc::<T>::increment_strong_count(raw_shared_ptr);
        }
        // SAFETY: the pointer came from `Rc::into_raw` and the strong count
        // this `Rc` takes ownership of was just added.
        unsafe { Ok(rc::Rc::<T>::from_raw(raw_shared_ptr)) }
    }
}
94
95impl<T, U> PartialEq<rc::Rc<U>> for ArchivedRc<T, RcFlavor>
96where
97    T: ArchivePointee + PartialEq<U> + ?Sized,
98    U: ?Sized,
99{
100    fn eq(&self, other: &rc::Rc<U>) -> bool {
101        self.get().eq(other.as_ref())
102    }
103}
104
105// rc::Weak
106
107impl<T: ArchiveUnsized + ?Sized> Archive for rc::Weak<T> {
108    type Archived = ArchivedRcWeak<T::Archived, RcFlavor>;
109    type Resolver = RcWeakResolver;
110
111    fn resolve(&self, resolver: Self::Resolver, out: Place<Self::Archived>) {
112        ArchivedRcWeak::resolve_from_ref(
113            self.upgrade().as_ref().map(|v| v.as_ref()),
114            resolver,
115            out,
116        );
117    }
118}
119
120impl<T, S> Serialize<S> for rc::Weak<T>
121where
122    T: SerializeUnsized<S> + ?Sized + 'static,
123    S: Fallible + Writer + Sharing + ?Sized,
124    S::Error: Source,
125{
126    fn serialize(
127        &self,
128        serializer: &mut S,
129    ) -> Result<Self::Resolver, S::Error> {
130        ArchivedRcWeak::<T::Archived, RcFlavor>::serialize_from_ref(
131            self.upgrade().as_ref().map(|v| v.as_ref()),
132            serializer,
133        )
134    }
135}
136
137impl<T, D> Deserialize<rc::Weak<T>, D> for ArchivedRcWeak<T::Archived, RcFlavor>
138where
139    // Deserialize can only be implemented for sized types because weak pointers
140    // to unsized types don't have `new` functions.
141    T: ArchiveUnsized
142        + LayoutRaw
143        + Pointee // + ?Sized
144        + 'static,
145    T::Archived: DeserializeUnsized<T, D>,
146    T::Metadata: Into<Metadata> + FromMetadata,
147    D: Fallible + Pooling + ?Sized,
148    D::Error: Source,
149{
150    fn deserialize(
151        &self,
152        deserializer: &mut D,
153    ) -> Result<rc::Weak<T>, D::Error> {
154        Ok(match self.upgrade() {
155            None => rc::Weak::new(),
156            Some(r) => rc::Rc::downgrade(&r.deserialize(deserializer)?),
157        })
158    }
159}
160
161#[cfg(test)]
162mod tests {
163    use munge::munge;
164    use rancor::{Failure, Panic};
165
166    use crate::{
167        access_unchecked, access_unchecked_mut,
168        alloc::{
169            rc::{Rc, Weak},
170            string::{String, ToString},
171            vec,
172        },
173        api::{
174            deserialize_using,
175            test::{roundtrip, to_archived},
176        },
177        de::Pool,
178        rc::{ArchivedRc, ArchivedRcWeak},
179        to_bytes, Archive, Deserialize, Serialize,
180    };
181
    #[test]
    fn roundtrip_rc() {
        #[derive(Debug, Eq, PartialEq, Archive, Deserialize, Serialize)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct Test {
            a: Rc<u32>,
            b: Rc<u32>,
        }

        // Both fields alias the same allocation; archiving should preserve
        // that sharing.
        let shared = Rc::new(10);
        let value = Test {
            a: shared.clone(),
            b: shared.clone(),
        };

        to_archived(&value, |mut archived| {
            assert_eq!(*archived, value);

            // Mutating through `a` must be visible through `b`, proving the
            // archived pointers alias a single value.
            munge!(let ArchivedTest { a, .. } = archived.as_mut());
            unsafe {
                *ArchivedRc::get_seal_unchecked(a) = 42u32.into();
            }

            assert_eq!(*archived.a, 42);
            assert_eq!(*archived.b, 42);

            // And the other way around: mutate through `b`, observe via `a`.
            munge!(let ArchivedTest { b, .. } = archived.as_mut());
            unsafe {
                *ArchivedRc::get_seal_unchecked(b) = 17u32.into();
            }

            assert_eq!(*archived.a, 17);
            assert_eq!(*archived.b, 17);

            let mut deserializer = Pool::new();
            let deserialized = deserialize_using::<Test, _, Panic>(
                &*archived,
                &mut deserializer,
            )
            .unwrap();

            // Deserialization reconstructs the sharing: both fields point at
            // the same allocation.
            assert_eq!(*deserialized.a, 17);
            assert_eq!(*deserialized.b, 17);
            assert_eq!(
                &*deserialized.a as *const u32,
                &*deserialized.b as *const u32
            );
            // The pool itself holds one extra strong count while alive
            // (2 from the fields + 1 from the pool).
            assert_eq!(Rc::strong_count(&deserialized.a), 3);
            assert_eq!(Rc::strong_count(&deserialized.b), 3);
            assert_eq!(Rc::weak_count(&deserialized.a), 0);
            assert_eq!(Rc::weak_count(&deserialized.b), 0);

            core::mem::drop(deserializer);

            // Dropping the pool releases its strong count but leaves the
            // deserialized values intact and still aliased.
            assert_eq!(*deserialized.a, 17);
            assert_eq!(*deserialized.b, 17);
            assert_eq!(
                &*deserialized.a as *const u32,
                &*deserialized.b as *const u32
            );
            assert_eq!(Rc::strong_count(&deserialized.a), 2);
            assert_eq!(Rc::strong_count(&deserialized.b), 2);
            assert_eq!(Rc::weak_count(&deserialized.a), 0);
            assert_eq!(Rc::weak_count(&deserialized.b), 0);
        });
    }
248
249    #[test]
250    fn roundtrip_rc_zst() {
251        #[derive(Archive, Deserialize, Serialize, Debug, PartialEq)]
252        #[rkyv(crate, compare(PartialEq), derive(Debug))]
253        struct TestRcZST {
254            a: Rc<()>,
255            b: Rc<()>,
256        }
257
258        let rc_zst = Rc::new(());
259        roundtrip(&TestRcZST {
260            a: rc_zst.clone(),
261            b: rc_zst.clone(),
262        });
263    }
264
265    #[test]
266    fn roundtrip_unsized_shared_ptr() {
267        #[derive(Archive, Serialize, Deserialize, Debug, PartialEq)]
268        #[rkyv(crate, compare(PartialEq), derive(Debug))]
269        struct Test {
270            a: Rc<[String]>,
271            b: Rc<[String]>,
272        }
273
274        let rc_slice = Rc::<[String]>::from(
275            vec!["hello".to_string(), "world".to_string()].into_boxed_slice(),
276        );
277        let value = Test {
278            a: rc_slice.clone(),
279            b: rc_slice,
280        };
281
282        roundtrip(&value);
283    }
284
285    #[test]
286    fn roundtrip_unsized_shared_ptr_empty() {
287        #[derive(Archive, Serialize, Deserialize, Debug, PartialEq)]
288        #[rkyv(crate, compare(PartialEq), derive(Debug))]
289        struct Test {
290            a: Rc<[u32]>,
291            b: Rc<[u32]>,
292        }
293
294        let a_rc_slice = Rc::<[u32]>::from(vec![].into_boxed_slice());
295        let b_rc_slice = Rc::<[u32]>::from(vec![100].into_boxed_slice());
296        let value = Test {
297            a: a_rc_slice,
298            b: b_rc_slice.clone(),
299        };
300
301        roundtrip(&value);
302    }
303
    #[test]
    fn roundtrip_weak_ptr() {
        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(crate)]
        struct Test {
            a: Rc<u32>,
            b: Weak<u32>,
        }

        // `b` is a weak pointer into the allocation strongly held by `a`.
        let shared = Rc::new(10);
        let value = Test {
            a: shared.clone(),
            b: Rc::downgrade(&shared),
        };

        let mut buf = to_bytes::<Panic>(&value).unwrap();

        // The archived weak pointer must upgrade to the shared value.
        let archived =
            unsafe { access_unchecked::<ArchivedTest>(buf.as_ref()) };
        assert_eq!(*archived.a, 10);
        assert!(archived.b.upgrade().is_some());
        assert_eq!(**archived.b.upgrade().unwrap(), 10);

        // Mutating through the strong pointer is visible through the weak.
        let mut mutable_archived =
            unsafe { access_unchecked_mut::<ArchivedTest>(buf.as_mut()) };

        munge!(let ArchivedTest { a, .. } = mutable_archived.as_mut());
        unsafe {
            *ArchivedRc::get_seal_unchecked(a) = 42u32.into();
        }

        let archived =
            unsafe { access_unchecked::<ArchivedTest>(buf.as_ref()) };
        assert_eq!(*archived.a, 42);
        assert!(archived.b.upgrade().is_some());
        assert_eq!(**archived.b.upgrade().unwrap(), 42);

        // And mutating through the upgraded weak is visible through `a`.
        let mut mutable_archived =
            unsafe { access_unchecked_mut::<ArchivedTest>(buf.as_mut()) };
        munge!(let ArchivedTest { b, .. } = mutable_archived.as_mut());
        unsafe {
            *ArchivedRc::get_seal_unchecked(
                ArchivedRcWeak::upgrade_seal(b).unwrap(),
            ) = 17u32.into();
        }

        let archived =
            unsafe { access_unchecked::<ArchivedTest>(buf.as_ref()) };
        assert_eq!(*archived.a, 17);
        assert!(archived.b.upgrade().is_some());
        assert_eq!(**archived.b.upgrade().unwrap(), 17);

        let mut deserializer = Pool::new();
        let deserialized =
            deserialize_using::<Test, _, Panic>(archived, &mut deserializer)
                .unwrap();

        // Sharing is reconstructed: the weak upgrades to the very allocation
        // that `a` strongly holds.
        assert_eq!(*deserialized.a, 17);
        assert!(deserialized.b.upgrade().is_some());
        assert_eq!(*deserialized.b.upgrade().unwrap(), 17);
        assert_eq!(
            &*deserialized.a as *const u32,
            &*deserialized.b.upgrade().unwrap() as *const u32
        );
        // The pool holds one extra strong count while it is alive.
        assert_eq!(Rc::strong_count(&deserialized.a), 2);
        assert_eq!(Weak::strong_count(&deserialized.b), 2);
        assert_eq!(Rc::weak_count(&deserialized.a), 1);
        assert_eq!(Weak::weak_count(&deserialized.b), 1);

        core::mem::drop(deserializer);

        // Dropping the pool releases its strong count; the weak pointer
        // still upgrades because `a` keeps the value alive.
        assert_eq!(*deserialized.a, 17);
        assert!(deserialized.b.upgrade().is_some());
        assert_eq!(*deserialized.b.upgrade().unwrap(), 17);
        assert_eq!(
            &*deserialized.a as *const u32,
            &*deserialized.b.upgrade().unwrap() as *const u32
        );
        assert_eq!(Rc::strong_count(&deserialized.a), 1);
        assert_eq!(Weak::strong_count(&deserialized.b), 1);
        assert_eq!(Rc::weak_count(&deserialized.a), 1);
        assert_eq!(Weak::weak_count(&deserialized.b), 1);
    }
387
    #[test]
    fn serialize_cyclic_error() {
        use rancor::{Fallible, Source};

        use crate::{
            de::Pooling,
            ser::{Sharing, Writer},
        };

        // `Inner` holds a weak pointer back to itself, so the derive needs
        // `omit_bounds` plus explicit serialize/deserialize/bytecheck bounds
        // to break the otherwise-recursive trait requirements.
        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(
            crate,
            serialize_bounds(
                __S: Sharing + Writer,
                <__S as Fallible>::Error: Source,
            ),
            deserialize_bounds(
                __D: Pooling,
                <__D as Fallible>::Error: Source,
            )
        )]
        #[cfg_attr(
            feature = "bytecheck",
            rkyv(bytecheck(bounds(
                __C: crate::validation::ArchiveContext
                    + crate::validation::SharedContext,
                <__C as Fallible>::Error: Source,
            ))),
        )]
        struct Inner {
            #[rkyv(omit_bounds)]
            weak: Weak<Self>,
        }

        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(crate)]
        struct Outer {
            inner: Rc<Inner>,
        }

        // Build a strong/weak cycle: the `Rc` owns an `Inner` whose weak
        // pointer refers back to that same `Rc`.
        let value = Outer {
            inner: Rc::new_cyclic(|weak| Inner { weak: weak.clone() }),
        };

        // Serializing the cycle must fail with an error rather than recurse
        // forever.
        assert!(to_bytes::<Failure>(&value).is_err());
    }
434
    // The hand-written archive bytes below encode little-endian, 32-bit
    // relative pointers, so this test only applies to that layout
    // configuration.
    #[cfg(all(
        feature = "bytecheck",
        not(feature = "big_endian"),
        not(any(feature = "pointer_width_16", feature = "pointer_width_64")),
    ))]
    #[test]
    fn recursive_stack_overflow() {
        use rancor::{Fallible, Source};

        use crate::{
            access,
            de::Pooling,
            util::Align,
            validation::{ArchiveContext, SharedContext},
        };

        // A self-referential enum: validating one variant requires validating
        // the `Rc` it points to, which can point back into the same buffer.
        #[derive(Archive, Deserialize)]
        #[rkyv(
            crate,
            bytecheck(bounds(__C: ArchiveContext + SharedContext)),
            deserialize_bounds(
                __D: Pooling,
                <__D as Fallible>::Error: Source,
            ),
            derive(Debug),
        )]
        enum AllValues {
            Rc(#[rkyv(omit_bounds)] Rc<AllValues>),
        }

        // Hand-crafted archive whose relative pointers form a cycle; the
        // validator must report an error instead of recursing until the
        // stack overflows.
        let data = Align([
            0x00, 0x00, 0x00, 0xff, // B: AllValues::Rc
            0xfc, 0xff, 0xff, 0xff, // RelPtr with offset -4 (B)
            0x00, 0x00, 0xf6, 0xff, // A: AllValues::Rc
            0xf4, 0xff, 0xff, 0xff, // RelPtr with offset -12 (B)
        ]);
        access::<ArchivedAllValues, Failure>(&*data).unwrap_err();
    }
473}