1#[cfg(target_has_atomic = "ptr")]
2mod atomic;
3
4use core::alloc::LayoutError;
5
6use ptr_meta::{from_raw_parts_mut, Pointee};
7use rancor::{Fallible, Source};
8
9use crate::{
10 alloc::{
11 alloc::{alloc, handle_alloc_error},
12 boxed::Box,
13 rc,
14 },
15 de::{FromMetadata, Metadata, Pooling, PoolingExt as _, SharedPointer},
16 rc::{ArchivedRc, ArchivedRcWeak, RcFlavor, RcResolver, RcWeakResolver},
17 ser::{Sharing, Writer},
18 traits::{ArchivePointee, LayoutRaw},
19 Archive, ArchiveUnsized, Deserialize, DeserializeUnsized, Place, Serialize,
20 SerializeUnsized,
21};
22
23impl<T: ArchiveUnsized + ?Sized> Archive for rc::Rc<T> {
26 type Archived = ArchivedRc<T::Archived, RcFlavor>;
27 type Resolver = RcResolver;
28
29 fn resolve(&self, resolver: Self::Resolver, out: Place<Self::Archived>) {
30 ArchivedRc::resolve_from_ref(self.as_ref(), resolver, out);
31 }
32}
33
34impl<T, S> Serialize<S> for rc::Rc<T>
35where
36 T: SerializeUnsized<S> + ?Sized + 'static,
37 S: Fallible + Writer + Sharing + ?Sized,
38 S::Error: Source,
39{
40 fn serialize(
41 &self,
42 serializer: &mut S,
43 ) -> Result<Self::Resolver, S::Error> {
44 ArchivedRc::<T::Archived, RcFlavor>::serialize_from_ref(
45 self.as_ref(),
46 serializer,
47 )
48 }
49}
50
51unsafe impl<T: LayoutRaw + Pointee + ?Sized> SharedPointer<T> for rc::Rc<T> {
52 fn alloc(metadata: T::Metadata) -> Result<*mut T, LayoutError> {
53 let layout = T::layout_raw(metadata)?;
54 let data_address = if layout.size() > 0 {
55 let ptr = unsafe { alloc(layout) };
56 if ptr.is_null() {
57 handle_alloc_error(layout);
58 }
59 ptr
60 } else {
61 crate::polyfill::dangling(&layout).as_ptr()
62 };
63 let ptr = from_raw_parts_mut(data_address.cast(), metadata);
64 Ok(ptr)
65 }
66
67 unsafe fn from_value(ptr: *mut T) -> *mut T {
68 let rc = rc::Rc::<T>::from(unsafe { Box::from_raw(ptr) });
69 rc::Rc::into_raw(rc).cast_mut()
70 }
71
72 unsafe fn drop(ptr: *mut T) {
73 drop(unsafe { rc::Rc::from_raw(ptr) });
74 }
75}
76
77impl<T, D> Deserialize<rc::Rc<T>, D> for ArchivedRc<T::Archived, RcFlavor>
78where
79 T: ArchiveUnsized + LayoutRaw + Pointee + ?Sized + 'static,
80 T::Archived: DeserializeUnsized<T, D>,
81 T::Metadata: Into<Metadata> + FromMetadata,
82 D: Fallible + Pooling + ?Sized,
83 D::Error: Source,
84{
85 fn deserialize(&self, deserializer: &mut D) -> Result<rc::Rc<T>, D::Error> {
86 let raw_shared_ptr =
87 deserializer.deserialize_shared::<_, rc::Rc<T>>(self.get())?;
88 unsafe {
89 rc::Rc::<T>::increment_strong_count(raw_shared_ptr);
90 }
91 unsafe { Ok(rc::Rc::<T>::from_raw(raw_shared_ptr)) }
92 }
93}
94
95impl<T, U> PartialEq<rc::Rc<U>> for ArchivedRc<T, RcFlavor>
96where
97 T: ArchivePointee + PartialEq<U> + ?Sized,
98 U: ?Sized,
99{
100 fn eq(&self, other: &rc::Rc<U>) -> bool {
101 self.get().eq(other.as_ref())
102 }
103}
104
105impl<T: ArchiveUnsized + ?Sized> Archive for rc::Weak<T> {
108 type Archived = ArchivedRcWeak<T::Archived, RcFlavor>;
109 type Resolver = RcWeakResolver;
110
111 fn resolve(&self, resolver: Self::Resolver, out: Place<Self::Archived>) {
112 ArchivedRcWeak::resolve_from_ref(
113 self.upgrade().as_ref().map(|v| v.as_ref()),
114 resolver,
115 out,
116 );
117 }
118}
119
120impl<T, S> Serialize<S> for rc::Weak<T>
121where
122 T: SerializeUnsized<S> + ?Sized + 'static,
123 S: Fallible + Writer + Sharing + ?Sized,
124 S::Error: Source,
125{
126 fn serialize(
127 &self,
128 serializer: &mut S,
129 ) -> Result<Self::Resolver, S::Error> {
130 ArchivedRcWeak::<T::Archived, RcFlavor>::serialize_from_ref(
131 self.upgrade().as_ref().map(|v| v.as_ref()),
132 serializer,
133 )
134 }
135}
136
137impl<T, D> Deserialize<rc::Weak<T>, D> for ArchivedRcWeak<T::Archived, RcFlavor>
138where
139 T: ArchiveUnsized
142 + LayoutRaw
143 + Pointee + 'static,
145 T::Archived: DeserializeUnsized<T, D>,
146 T::Metadata: Into<Metadata> + FromMetadata,
147 D: Fallible + Pooling + ?Sized,
148 D::Error: Source,
149{
150 fn deserialize(
151 &self,
152 deserializer: &mut D,
153 ) -> Result<rc::Weak<T>, D::Error> {
154 Ok(match self.upgrade() {
155 None => rc::Weak::new(),
156 Some(r) => rc::Rc::downgrade(&r.deserialize(deserializer)?),
157 })
158 }
159}
160
#[cfg(test)]
mod tests {
    use munge::munge;
    use rancor::{Failure, Panic};

    use crate::{
        access_unchecked, access_unchecked_mut,
        alloc::{
            rc::{Rc, Weak},
            string::{String, ToString},
            vec,
        },
        api::{
            deserialize_using,
            test::{roundtrip, to_archived},
        },
        de::Pool,
        rc::{ArchivedRc, ArchivedRcWeak},
        to_bytes, Archive, Deserialize, Serialize,
    };

    /// Two fields sharing one `Rc` must stay shared through archiving,
    /// in-place mutation of the archive, and deserialization with a `Pool`.
    #[test]
    fn roundtrip_rc() {
        #[derive(Debug, Eq, PartialEq, Archive, Deserialize, Serialize)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct Test {
            a: Rc<u32>,
            b: Rc<u32>,
        }

        let shared = Rc::new(10);
        let value = Test {
            a: shared.clone(),
            b: shared.clone(),
        };

        to_archived(&value, |mut archived| {
            assert_eq!(*archived, value);

            // Both archived fields refer to the same archived value, so a
            // write through `a` must be visible through `b`.
            munge!(let ArchivedTest { a, .. } = archived.as_mut());
            unsafe {
                *ArchivedRc::get_seal_unchecked(a) = 42u32.into();
            }

            assert_eq!(*archived.a, 42);
            assert_eq!(*archived.b, 42);

            // ...and a write through `b` must be visible through `a`.
            munge!(let ArchivedTest { b, .. } = archived.as_mut());
            unsafe {
                *ArchivedRc::get_seal_unchecked(b) = 17u32.into();
            }

            assert_eq!(*archived.a, 17);
            assert_eq!(*archived.b, 17);

            let mut deserializer = Pool::new();
            let deserialized = deserialize_using::<Test, _, Panic>(
                &*archived,
                &mut deserializer,
            )
            .unwrap();

            // Deserialization reproduces the sharing: both fields point at
            // one allocation.
            assert_eq!(*deserialized.a, 17);
            assert_eq!(*deserialized.b, 17);
            assert_eq!(
                &*deserialized.a as *const u32,
                &*deserialized.b as *const u32
            );
            // While the pool is alive the strong count is 3: one per field
            // plus the reference retained by the pool.
            assert_eq!(Rc::strong_count(&deserialized.a), 3);
            assert_eq!(Rc::strong_count(&deserialized.b), 3);
            assert_eq!(Rc::weak_count(&deserialized.a), 0);
            assert_eq!(Rc::weak_count(&deserialized.b), 0);

            core::mem::drop(deserializer);

            // Dropping the pool releases only its own reference; the value
            // stays alive and shared.
            assert_eq!(*deserialized.a, 17);
            assert_eq!(*deserialized.b, 17);
            assert_eq!(
                &*deserialized.a as *const u32,
                &*deserialized.b as *const u32
            );
            assert_eq!(Rc::strong_count(&deserialized.a), 2);
            assert_eq!(Rc::strong_count(&deserialized.b), 2);
            assert_eq!(Rc::weak_count(&deserialized.a), 0);
            assert_eq!(Rc::weak_count(&deserialized.b), 0);
        });
    }

    /// Shared pointers to a zero-sized value must round-trip.
    #[test]
    fn roundtrip_rc_zst() {
        #[derive(Archive, Deserialize, Serialize, Debug, PartialEq)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct TestRcZST {
            a: Rc<()>,
            b: Rc<()>,
        }

        let rc_zst = Rc::new(());
        roundtrip(&TestRcZST {
            a: rc_zst.clone(),
            b: rc_zst.clone(),
        });
    }

    /// Shared pointers to an unsized slice must round-trip.
    #[test]
    fn roundtrip_unsized_shared_ptr() {
        #[derive(Archive, Serialize, Deserialize, Debug, PartialEq)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct Test {
            a: Rc<[String]>,
            b: Rc<[String]>,
        }

        let rc_slice = Rc::<[String]>::from(
            vec!["hello".to_string(), "world".to_string()].into_boxed_slice(),
        );
        let value = Test {
            a: rc_slice.clone(),
            b: rc_slice,
        };

        roundtrip(&value);
    }

    /// An empty unsized slice (no element data) must round-trip alongside a
    /// non-empty one.
    #[test]
    fn roundtrip_unsized_shared_ptr_empty() {
        #[derive(Archive, Serialize, Deserialize, Debug, PartialEq)]
        #[rkyv(crate, compare(PartialEq), derive(Debug))]
        struct Test {
            a: Rc<[u32]>,
            b: Rc<[u32]>,
        }

        let a_rc_slice = Rc::<[u32]>::from(vec![].into_boxed_slice());
        let b_rc_slice = Rc::<[u32]>::from(vec![100].into_boxed_slice());
        let value = Test {
            a: a_rc_slice,
            b: b_rc_slice.clone(),
        };

        roundtrip(&value);
    }

    /// A `Weak` that points at the same value as a sibling `Rc` must follow
    /// mutations of the archive and deserialize to a weak reference to the
    /// same allocation.
    #[test]
    fn roundtrip_weak_ptr() {
        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(crate)]
        struct Test {
            a: Rc<u32>,
            b: Weak<u32>,
        }

        let shared = Rc::new(10);
        let value = Test {
            a: shared.clone(),
            b: Rc::downgrade(&shared),
        };

        let mut buf = to_bytes::<Panic>(&value).unwrap();

        let archived =
            unsafe { access_unchecked::<ArchivedTest>(buf.as_ref()) };
        assert_eq!(*archived.a, 10);
        assert!(archived.b.upgrade().is_some());
        assert_eq!(**archived.b.upgrade().unwrap(), 10);

        // Mutate through the strong pointer; the weak pointer must observe it.
        let mut mutable_archived =
            unsafe { access_unchecked_mut::<ArchivedTest>(buf.as_mut()) };

        munge!(let ArchivedTest { a, .. } = mutable_archived.as_mut());
        unsafe {
            *ArchivedRc::get_seal_unchecked(a) = 42u32.into();
        }

        let archived =
            unsafe { access_unchecked::<ArchivedTest>(buf.as_ref()) };
        assert_eq!(*archived.a, 42);
        assert!(archived.b.upgrade().is_some());
        assert_eq!(**archived.b.upgrade().unwrap(), 42);

        // Mutate through the upgraded weak pointer; the strong pointer must
        // observe it.
        let mut mutable_archived =
            unsafe { access_unchecked_mut::<ArchivedTest>(buf.as_mut()) };
        munge!(let ArchivedTest { b, .. } = mutable_archived.as_mut());
        unsafe {
            *ArchivedRc::get_seal_unchecked(
                ArchivedRcWeak::upgrade_seal(b).unwrap(),
            ) = 17u32.into();
        }

        let archived =
            unsafe { access_unchecked::<ArchivedTest>(buf.as_ref()) };
        assert_eq!(*archived.a, 17);
        assert!(archived.b.upgrade().is_some());
        assert_eq!(**archived.b.upgrade().unwrap(), 17);

        let mut deserializer = Pool::new();
        let deserialized =
            deserialize_using::<Test, _, Panic>(archived, &mut deserializer)
                .unwrap();

        // While the pool is alive: strong count 2 (field `a` + the pool) and
        // weak count 1 (field `b`).
        assert_eq!(*deserialized.a, 17);
        assert!(deserialized.b.upgrade().is_some());
        assert_eq!(*deserialized.b.upgrade().unwrap(), 17);
        assert_eq!(
            &*deserialized.a as *const u32,
            &*deserialized.b.upgrade().unwrap() as *const u32
        );
        assert_eq!(Rc::strong_count(&deserialized.a), 2);
        assert_eq!(Weak::strong_count(&deserialized.b), 2);
        assert_eq!(Rc::weak_count(&deserialized.a), 1);
        assert_eq!(Weak::weak_count(&deserialized.b), 1);

        core::mem::drop(deserializer);

        // After the pool is gone only field `a` holds a strong reference, so
        // the weak pointer still upgrades.
        assert_eq!(*deserialized.a, 17);
        assert!(deserialized.b.upgrade().is_some());
        assert_eq!(*deserialized.b.upgrade().unwrap(), 17);
        assert_eq!(
            &*deserialized.a as *const u32,
            &*deserialized.b.upgrade().unwrap() as *const u32
        );
        assert_eq!(Rc::strong_count(&deserialized.a), 1);
        assert_eq!(Weak::strong_count(&deserialized.b), 1);
        assert_eq!(Rc::weak_count(&deserialized.a), 1);
        assert_eq!(Weak::weak_count(&deserialized.b), 1);
    }

    /// A value holding a weak reference to itself (built with
    /// `Rc::new_cyclic`) cannot be serialized; `to_bytes` must return an
    /// error.
    #[test]
    fn serialize_cyclic_error() {
        use rancor::{Fallible, Source};

        use crate::{
            de::Pooling,
            ser::{Sharing, Writer},
        };

        // Bounds are spelled out by hand and `omit_bounds` is used on the
        // field because `Inner` refers to itself through `Weak<Self>`.
        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(
            crate,
            serialize_bounds(
                __S: Sharing + Writer,
                <__S as Fallible>::Error: Source,
            ),
            deserialize_bounds(
                __D: Pooling,
                <__D as Fallible>::Error: Source,
            )
        )]
        #[cfg_attr(
            feature = "bytecheck",
            rkyv(bytecheck(bounds(
                __C: crate::validation::ArchiveContext
                    + crate::validation::SharedContext,
                <__C as Fallible>::Error: Source,
            ))),
        )]
        struct Inner {
            #[rkyv(omit_bounds)]
            weak: Weak<Self>,
        }

        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(crate)]
        struct Outer {
            inner: Rc<Inner>,
        }

        let value = Outer {
            inner: Rc::new_cyclic(|weak| Inner { weak: weak.clone() }),
        };

        assert!(to_bytes::<Failure>(&value).is_err());
    }

    /// Validating a crafted buffer whose archived `Rc`s point back into
    /// themselves must return an error; per the test name, the goal is to
    /// avoid a stack overflow.
    ///
    /// The cfg restricts this test to the default little-endian, 32-bit
    /// relative-pointer encoding that the hand-written bytes assume.
    #[cfg(all(
        feature = "bytecheck",
        not(feature = "big_endian"),
        not(any(feature = "pointer_width_16", feature = "pointer_width_64")),
    ))]
    #[test]
    fn recursive_stack_overflow() {
        use rancor::{Fallible, Source};

        use crate::{
            access,
            de::Pooling,
            util::Align,
            validation::{ArchiveContext, SharedContext},
        };

        #[derive(Archive, Deserialize)]
        #[rkyv(
            crate,
            bytecheck(bounds(__C: ArchiveContext + SharedContext)),
            deserialize_bounds(
                __D: Pooling,
                <__D as Fallible>::Error: Source,
            ),
            derive(Debug),
        )]
        enum AllValues {
            Rc(#[rkyv(omit_bounds)] Rc<AllValues>),
        }

        // NOTE(review): read as two 8-byte `ArchivedAllValues` records
        // (presumably tag byte, padding, then a little-endian i32 relative
        // offset: -4 and -12), both pointing back to offset 0 — TODO confirm.
        let data = Align([
            0x00, 0x00, 0x00, 0xff, 0xfc, 0xff, 0xff, 0xff,
            0x00, 0x00, 0xf6, 0xff, 0xf4, 0xff, 0xff, 0xff,
        ]);
        access::<ArchivedAllValues, Failure>(&*data).unwrap_err();
    }
}