Skip to main content

rkyv/impls/std/
with.rs

1use core::{error::Error, fmt, hash::BuildHasher};
2use std::{
3    borrow::Cow,
4    collections::{HashMap, HashSet},
5    ffi::{CStr, OsString},
6    hash::Hash,
7    marker::PhantomData,
8    path::{Path, PathBuf},
9    str::FromStr,
10    sync::{Mutex, RwLock},
11    time::{SystemTime, UNIX_EPOCH},
12};
13
14use rancor::{Fallible, OptionExt, ResultExt, Source};
15
16use crate::{
17    collections::{
18        swiss_table::{ArchivedHashMap, HashMapResolver},
19        util::{Entry, EntryAdapter},
20    },
21    ffi::{ArchivedCString, CStringResolver},
22    hash::FxHasher64,
23    impls::core::with::RefWrapper,
24    ser::{Allocator, Writer},
25    string::{ArchivedString, StringResolver},
26    time::ArchivedDuration,
27    vec::{ArchivedVec, VecResolver},
28    with::{
29        ArchiveWith, AsOwned, AsString, AsUnixTime, AsVec, DeserializeWith,
30        Lock, MapKV, SerializeWith,
31    },
32    Archive, Deserialize, Place, Serialize, SerializeUnsized,
33};
34
35// MapKV
36impl<A, B, K, V, H> ArchiveWith<HashMap<K, V, H>> for MapKV<A, B>
37where
38    A: ArchiveWith<K>,
39    B: ArchiveWith<V>,
40    H: Default + BuildHasher,
41{
42    type Archived = ArchivedHashMap<
43        <A as ArchiveWith<K>>::Archived,
44        <B as ArchiveWith<V>>::Archived,
45    >;
46    type Resolver = HashMapResolver;
47
48    fn resolve_with(
49        field: &HashMap<K, V, H>,
50        resolver: Self::Resolver,
51        out: Place<Self::Archived>,
52    ) {
53        ArchivedHashMap::resolve_from_len(field.len(), (7, 8), resolver, out)
54    }
55}
56
57impl<A, B, K, V, S, H> SerializeWith<HashMap<K, V, H>, S> for MapKV<A, B>
58where
59    A: ArchiveWith<K> + SerializeWith<K, S>,
60    B: ArchiveWith<V> + SerializeWith<V, S>,
61    K: Hash + Eq,
62    <A as ArchiveWith<K>>::Archived: Eq + Hash,
63    S: Fallible + Allocator + Writer + ?Sized,
64    S::Error: Source,
65    H: Default + BuildHasher,
66    H::Hasher: Default,
67{
68    fn serialize_with(
69        field: &HashMap<K, V, H>,
70        serializer: &mut S,
71    ) -> Result<Self::Resolver, <S as Fallible>::Error> {
72        ArchivedHashMap::<_, _, FxHasher64>::serialize_from_iter(
73            field.iter().map(|(k, v)| {
74                (
75                    RefWrapper::<'_, A, K>(k, PhantomData::<A>),
76                    RefWrapper::<'_, B, V>(v, PhantomData::<B>),
77                )
78            }),
79            (7, 8),
80            serializer,
81        )
82    }
83}
84
85impl<A, B, K, V, D, S>
86    DeserializeWith<
87        ArchivedHashMap<
88            <A as ArchiveWith<K>>::Archived,
89            <B as ArchiveWith<V>>::Archived,
90        >,
91        HashMap<K, V, S>,
92        D,
93    > for MapKV<A, B>
94where
95    A: ArchiveWith<K> + DeserializeWith<<A as ArchiveWith<K>>::Archived, K, D>,
96    B: ArchiveWith<V> + DeserializeWith<<B as ArchiveWith<V>>::Archived, V, D>,
97    K: Hash + Eq,
98    D: Fallible + ?Sized,
99    S: Default + BuildHasher,
100{
101    fn deserialize_with(
102        field: &ArchivedHashMap<
103            <A as ArchiveWith<K>>::Archived,
104            <B as ArchiveWith<V>>::Archived,
105        >,
106        deserializer: &mut D,
107    ) -> Result<HashMap<K, V, S>, <D as Fallible>::Error> {
108        let mut result =
109            HashMap::with_capacity_and_hasher(field.len(), S::default());
110        for (k, v) in field.iter() {
111            result.insert(
112                A::deserialize_with(k, deserializer)?,
113                B::deserialize_with(v, deserializer)?,
114            );
115        }
116        Ok(result)
117    }
118}
119
120// AsString
121
/// Error traced when an `OsString` or `PathBuf` serialized with `AsString`
/// does not contain valid UTF-8.
#[derive(Debug)]
struct InvalidUtf8;

impl fmt::Display for InvalidUtf8 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid UTF-8")
    }
}

impl Error for InvalidUtf8 {}
132
133impl ArchiveWith<OsString> for AsString {
134    type Archived = ArchivedString;
135    type Resolver = StringResolver;
136
137    #[inline]
138    fn resolve_with(
139        field: &OsString,
140        resolver: Self::Resolver,
141        out: Place<Self::Archived>,
142    ) {
143        // It's safe to unwrap here because if the OsString wasn't valid UTF-8
144        // it would have failed to serialize
145        ArchivedString::resolve_from_str(
146            field.to_str().unwrap(),
147            resolver,
148            out,
149        );
150    }
151}
152
153impl<S> SerializeWith<OsString, S> for AsString
154where
155    S: Fallible + ?Sized,
156    S::Error: Source,
157    str: SerializeUnsized<S>,
158{
159    fn serialize_with(
160        field: &OsString,
161        serializer: &mut S,
162    ) -> Result<Self::Resolver, S::Error> {
163        ArchivedString::serialize_from_str(
164            field.to_str().into_trace(InvalidUtf8)?,
165            serializer,
166        )
167    }
168}
169
170impl<D> DeserializeWith<ArchivedString, OsString, D> for AsString
171where
172    D: Fallible + ?Sized,
173{
174    fn deserialize_with(
175        field: &ArchivedString,
176        _: &mut D,
177    ) -> Result<OsString, D::Error> {
178        Ok(OsString::from_str(field.as_str()).unwrap())
179    }
180}
181
182impl ArchiveWith<PathBuf> for AsString {
183    type Archived = ArchivedString;
184    type Resolver = StringResolver;
185
186    #[inline]
187    fn resolve_with(
188        field: &PathBuf,
189        resolver: Self::Resolver,
190        out: Place<Self::Archived>,
191    ) {
192        // It's safe to unwrap here because if the OsString wasn't valid UTF-8
193        // it would have failed to serialize
194        ArchivedString::resolve_from_str(
195            field.to_str().unwrap(),
196            resolver,
197            out,
198        );
199    }
200}
201
202impl<S> SerializeWith<PathBuf, S> for AsString
203where
204    S: Fallible + ?Sized,
205    S::Error: Source,
206    str: SerializeUnsized<S>,
207{
208    fn serialize_with(
209        field: &PathBuf,
210        serializer: &mut S,
211    ) -> Result<Self::Resolver, S::Error> {
212        ArchivedString::serialize_from_str(
213            field.to_str().into_trace(InvalidUtf8)?,
214            serializer,
215        )
216    }
217}
218
219impl<D> DeserializeWith<ArchivedString, PathBuf, D> for AsString
220where
221    D: Fallible + ?Sized,
222{
223    fn deserialize_with(
224        field: &ArchivedString,
225        _: &mut D,
226    ) -> Result<PathBuf, D::Error> {
227        Ok(Path::new(field.as_str()).to_path_buf())
228    }
229}
230
231// Lock
232
/// Error traced when a `Mutex` or `RwLock` archived with `Lock` turns out to
/// be poisoned during serialization.
#[derive(Debug)]
struct LockPoisoned;

impl fmt::Display for LockPoisoned {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("lock poisoned")
    }
}

impl Error for LockPoisoned {}
243
244impl<F: Archive> ArchiveWith<Mutex<F>> for Lock {
245    type Archived = F::Archived;
246    type Resolver = F::Resolver;
247
248    fn resolve_with(
249        field: &Mutex<F>,
250        resolver: Self::Resolver,
251        out: Place<Self::Archived>,
252    ) {
253        // Unfortunately, we have to unwrap here because resolve must be
254        // infallible
255        //
256        // An alternative would be to only implement ArchiveWith for
257        // Arc<Mutex<F>>, copy an Arc into the resolver, and hold the
258        // guard in there as well (as a reference to the internal mutex).
259        // This unfortunately will cause a deadlock if two Arcs to the same
260        // Mutex are serialized before the first is resolved. The
261        // compromise is, unfortunately, to just unwrap poison
262        // errors here and document it.
263        field.lock().unwrap().resolve(resolver, out);
264    }
265}
266
267impl<F, S> SerializeWith<Mutex<F>, S> for Lock
268where
269    F: Serialize<S>,
270    S: Fallible + ?Sized,
271    S::Error: Source,
272{
273    fn serialize_with(
274        field: &Mutex<F>,
275        serializer: &mut S,
276    ) -> Result<Self::Resolver, S::Error> {
277        field
278            .lock()
279            .ok()
280            .into_trace(LockPoisoned)?
281            .serialize(serializer)
282    }
283}
284
285impl<F, T, D> DeserializeWith<F, Mutex<T>, D> for Lock
286where
287    F: Deserialize<T, D>,
288    D: Fallible + ?Sized,
289{
290    fn deserialize_with(
291        field: &F,
292        deserializer: &mut D,
293    ) -> Result<Mutex<T>, D::Error> {
294        Ok(Mutex::new(field.deserialize(deserializer)?))
295    }
296}
297
298impl<F: Archive> ArchiveWith<RwLock<F>> for Lock {
299    type Archived = F::Archived;
300    type Resolver = F::Resolver;
301
302    fn resolve_with(
303        field: &RwLock<F>,
304        resolver: Self::Resolver,
305        out: Place<Self::Archived>,
306    ) {
307        // Unfortunately, we have to unwrap here because resolve must be
308        // infallible
309        //
310        // An alternative would be to only implement ArchiveWith for
311        // Arc<Mutex<F>>, copy a an Arc into the resolver, and hold the
312        // guard in there as well (as a reference to the internal
313        // mutex). This unfortunately will cause a deadlock if two Arcs to the
314        // same Mutex are serialized before the first is resolved. The
315        // compromise is, unfortunately, to just unwrap poison errors
316        // here and document it.
317        field.read().unwrap().resolve(resolver, out);
318    }
319}
320
321impl<F, S> SerializeWith<RwLock<F>, S> for Lock
322where
323    F: Serialize<S>,
324    S: Fallible + ?Sized,
325    S::Error: Source,
326{
327    fn serialize_with(
328        field: &RwLock<F>,
329        serializer: &mut S,
330    ) -> Result<Self::Resolver, S::Error> {
331        field
332            .read()
333            .ok()
334            .into_trace(LockPoisoned)?
335            .serialize(serializer)
336    }
337}
338
339impl<F, T, D> DeserializeWith<F, RwLock<T>, D> for Lock
340where
341    F: Deserialize<T, D>,
342    D: Fallible + ?Sized,
343{
344    fn deserialize_with(
345        field: &F,
346        deserializer: &mut D,
347    ) -> Result<RwLock<T>, D::Error> {
348        Ok(RwLock::new(field.deserialize(deserializer)?))
349    }
350}
351
352// AsVec
353
354impl<K: Archive, V: Archive, H> ArchiveWith<HashMap<K, V, H>> for AsVec {
355    type Archived = ArchivedVec<Entry<K::Archived, V::Archived>>;
356    type Resolver = VecResolver;
357
358    fn resolve_with(
359        field: &HashMap<K, V, H>,
360        resolver: Self::Resolver,
361        out: Place<Self::Archived>,
362    ) {
363        ArchivedVec::resolve_from_len(field.len(), resolver, out);
364    }
365}
366
367impl<K, V, H, S> SerializeWith<HashMap<K, V, H>, S> for AsVec
368where
369    K: Serialize<S>,
370    V: Serialize<S>,
371    S: Fallible + Allocator + Writer + ?Sized,
372{
373    fn serialize_with(
374        field: &HashMap<K, V, H>,
375        serializer: &mut S,
376    ) -> Result<Self::Resolver, S::Error> {
377        ArchivedVec::serialize_from_iter(
378            field.iter().map(|(key, value)| {
379                EntryAdapter::<_, _, K, V>::new(key, value)
380            }),
381            serializer,
382        )
383    }
384}
385
386impl<K, V, H, D>
387    DeserializeWith<
388        ArchivedVec<Entry<K::Archived, V::Archived>>,
389        HashMap<K, V, H>,
390        D,
391    > for AsVec
392where
393    K: Archive + Hash + Eq,
394    V: Archive,
395    K::Archived: Deserialize<K, D>,
396    V::Archived: Deserialize<V, D>,
397    H: BuildHasher + Default,
398    D: Fallible + ?Sized,
399{
400    fn deserialize_with(
401        field: &ArchivedVec<Entry<K::Archived, V::Archived>>,
402        deserializer: &mut D,
403    ) -> Result<HashMap<K, V, H>, D::Error> {
404        let mut result =
405            HashMap::with_capacity_and_hasher(field.len(), H::default());
406        for entry in field.iter() {
407            result.insert(
408                entry.key.deserialize(deserializer)?,
409                entry.value.deserialize(deserializer)?,
410            );
411        }
412        Ok(result)
413    }
414}
415
416impl<T: Archive, H> ArchiveWith<HashSet<T, H>> for AsVec {
417    type Archived = ArchivedVec<T::Archived>;
418    type Resolver = VecResolver;
419
420    fn resolve_with(
421        field: &HashSet<T, H>,
422        resolver: Self::Resolver,
423        out: Place<Self::Archived>,
424    ) {
425        ArchivedVec::resolve_from_len(field.len(), resolver, out);
426    }
427}
428
429impl<T, H, S> SerializeWith<HashSet<T, H>, S> for AsVec
430where
431    T: Serialize<S>,
432    S: Fallible + Allocator + Writer + ?Sized,
433{
434    fn serialize_with(
435        field: &HashSet<T, H>,
436        serializer: &mut S,
437    ) -> Result<Self::Resolver, S::Error> {
438        ArchivedVec::<T::Archived>::serialize_from_iter::<T, _, _>(
439            field.iter(),
440            serializer,
441        )
442    }
443}
444
445impl<T, H, D> DeserializeWith<ArchivedVec<T::Archived>, HashSet<T, H>, D>
446    for AsVec
447where
448    T: Archive + Hash + Eq,
449    T::Archived: Deserialize<T, D>,
450    H: BuildHasher + Default,
451    D: Fallible + ?Sized,
452{
453    fn deserialize_with(
454        field: &ArchivedVec<T::Archived>,
455        deserializer: &mut D,
456    ) -> Result<HashSet<T, H>, D::Error> {
457        let mut result =
458            HashSet::with_capacity_and_hasher(field.len(), H::default());
459        for key in field.iter() {
460            result.insert(key.deserialize(deserializer)?);
461        }
462        Ok(result)
463    }
464}
465
// AsUnixTime
467
468impl ArchiveWith<SystemTime> for AsUnixTime {
469    type Archived = ArchivedDuration;
470    type Resolver = ();
471
472    #[inline]
473    fn resolve_with(
474        field: &SystemTime,
475        resolver: Self::Resolver,
476        out: Place<Self::Archived>,
477    ) {
478        // We already checked the duration during serialize_with
479        let duration = field.duration_since(UNIX_EPOCH).unwrap();
480        Archive::resolve(&duration, resolver, out);
481    }
482}
483
484impl<S> SerializeWith<SystemTime, S> for AsUnixTime
485where
486    S: Fallible + ?Sized,
487    S::Error: Source,
488{
489    fn serialize_with(
490        field: &SystemTime,
491        _: &mut S,
492    ) -> Result<Self::Resolver, S::Error> {
493        field.duration_since(UNIX_EPOCH).into_error()?;
494        Ok(())
495    }
496}
497
498impl<D> DeserializeWith<ArchivedDuration, SystemTime, D> for AsUnixTime
499where
500    D: Fallible + ?Sized,
501{
502    fn deserialize_with(
503        field: &ArchivedDuration,
504        _: &mut D,
505    ) -> Result<SystemTime, D::Error> {
506        // `checked_add` forces correct type deduction when multiple `Duration`
507        // are present.
508        Ok(UNIX_EPOCH.checked_add((*field).into()).unwrap())
509    }
510}
511
512// AsOwned
513
514impl<'a> ArchiveWith<Cow<'a, CStr>> for AsOwned {
515    type Archived = ArchivedCString;
516    type Resolver = CStringResolver;
517
518    #[inline]
519    fn resolve_with(
520        field: &Cow<'a, CStr>,
521        resolver: Self::Resolver,
522        out: Place<Self::Archived>,
523    ) {
524        ArchivedCString::resolve_from_c_str(field, resolver, out);
525    }
526}
527
528impl<'a, S> SerializeWith<Cow<'a, CStr>, S> for AsOwned
529where
530    S: Fallible + Writer + ?Sized,
531{
532    fn serialize_with(
533        field: &Cow<'a, CStr>,
534        serializer: &mut S,
535    ) -> Result<Self::Resolver, S::Error> {
536        ArchivedCString::serialize_from_c_str(field, serializer)
537    }
538}
539
540impl<'a, D> DeserializeWith<ArchivedCString, Cow<'a, CStr>, D> for AsOwned
541where
542    D: Fallible + ?Sized,
543    D::Error: Source,
544{
545    fn deserialize_with(
546        field: &ArchivedCString,
547        deserializer: &mut D,
548    ) -> Result<Cow<'a, CStr>, D::Error> {
549        Ok(Cow::Owned(field.deserialize(deserializer)?))
550    }
551}
552
#[cfg(test)]
mod tests {
    use std::{
        collections::BTreeMap,
        ffi::OsString,
        path::PathBuf,
        sync::{Mutex, RwLock},
    };

    use crate::{
        alloc::collections::HashMap,
        api::test::{roundtrip_with, to_archived},
        with::{AsString, InlineAsBox, Lock, MapKV},
        Archive, Deserialize, Serialize,
    };

    /// Roundtrips a `Mutex` field archived with `Lock`.
    #[test]
    fn roundtrip_mutex() {
        #[derive(Archive, Serialize, Deserialize, Debug)]
        #[rkyv(crate, derive(Debug, PartialEq))]
        struct Test {
            #[rkyv(with = Lock)]
            value: Mutex<i32>,
        }

        // `Mutex` has no `PartialEq`, so compare the values behind the locks.
        impl PartialEq for Test {
            fn eq(&self, other: &Self) -> bool {
                *self.value.lock().unwrap() == *other.value.lock().unwrap()
            }
        }

        let original = Test {
            value: Mutex::new(10),
        };
        roundtrip_with(&original, |expected, archived| {
            let guarded = expected.value.lock().unwrap();
            assert_eq!(archived.value, *guarded);
        });
    }

    /// Archives a `HashMap` through `MapKV<InlineAsBox, InlineAsBox>` and
    /// looks a key up in the archived map.
    #[test]
    fn with_hash_map_mapkv() {
        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(crate)]
        struct Test<'a> {
            #[rkyv(with = MapKV<InlineAsBox, InlineAsBox>)]
            inner: HashMap<&'a str, &'a str>,
        }

        let value = Test {
            inner: {
                let mut map = HashMap::new();
                map.insert("cat", "hat");
                map
            },
        };

        to_archived(&value, |archived| {
            let hat = archived.inner.get("cat").unwrap();
            assert_eq!(&**hat, "hat");
        });
    }

    /// Archives a `BTreeMap` through `MapKV<InlineAsBox, InlineAsBox>` and
    /// looks a key up in the archived map.
    #[test]
    fn with_btree_map_mapkv() {
        #[derive(Archive, Serialize, Deserialize)]
        #[rkyv(crate)]
        struct Test<'a> {
            #[rkyv(with = MapKV<InlineAsBox, InlineAsBox>)]
            inner: BTreeMap<&'a str, &'a str>,
        }

        let value = Test {
            inner: {
                let mut map = BTreeMap::new();
                map.insert("cat", "hat");
                map
            },
        };

        to_archived(&value, |archived| {
            let hat = archived.inner.get("cat").unwrap();
            assert_eq!(&**hat, "hat");
        });
    }

    /// Roundtrips an `RwLock` field archived with `Lock`.
    #[test]
    fn roundtrip_rwlock() {
        #[derive(Archive, Serialize, Deserialize, Debug)]
        #[rkyv(crate, derive(Debug, PartialEq))]
        struct Test {
            #[rkyv(with = Lock)]
            value: RwLock<i32>,
        }

        // `RwLock` has no `PartialEq`, so compare the values behind the read
        // guards.
        impl PartialEq for Test {
            fn eq(&self, other: &Self) -> bool {
                *self.value.try_read().unwrap()
                    == *other.value.try_read().unwrap()
            }
        }

        let original = Test {
            value: RwLock::new(10),
        };
        roundtrip_with(&original, |expected, archived| {
            let guarded = expected.value.try_read().unwrap();
            assert_eq!(archived.value, *guarded);
        });
    }

    /// Roundtrips an `OsString` field archived with `AsString`.
    #[test]
    fn roundtrip_os_string() {
        #[derive(Archive, Serialize, Deserialize, Debug, PartialEq)]
        #[rkyv(crate, derive(Debug, PartialEq))]
        struct Test {
            #[rkyv(with = AsString)]
            value: OsString,
        }

        let original = Test {
            value: OsString::from("hello world"),
        };
        roundtrip_with(&original, |expected, archived| {
            let text = expected.value.as_os_str().to_str().unwrap();
            assert_eq!(text, archived.value);
        });
    }

    /// Roundtrips a `PathBuf` field archived with `AsString`.
    #[test]
    fn roundtrip_path_buf() {
        #[derive(Archive, Serialize, Deserialize, Debug, PartialEq)]
        #[rkyv(crate, derive(Debug, PartialEq))]
        struct Test {
            #[rkyv(with = AsString)]
            value: PathBuf,
        }

        let original = Test {
            value: PathBuf::from("hello world"),
        };
        roundtrip_with(&original, |expected, archived| {
            let text = expected.value.as_os_str().to_str().unwrap();
            assert_eq!(text, archived.value);
        });
    }
}