Skip to main content

rkyv/impls/std/collections/
hash_map.rs

1use core::{
2    borrow::Borrow,
3    hash::{BuildHasher, Hash},
4};
5use std::collections::HashMap;
6
7use rancor::{Fallible, Source};
8
9use crate::{
10    collections::swiss_table::map::{ArchivedHashMap, HashMapResolver},
11    ser::{Allocator, Writer},
12    Archive, Deserialize, Place, Serialize,
13};
14
15impl<K, V: Archive, S> Archive for HashMap<K, V, S>
16where
17    K: Archive + Hash + Eq,
18    K::Archived: Hash + Eq,
19{
20    type Archived = ArchivedHashMap<K::Archived, V::Archived>;
21    type Resolver = HashMapResolver;
22
23    fn resolve(&self, resolver: Self::Resolver, out: Place<Self::Archived>) {
24        ArchivedHashMap::resolve_from_len(self.len(), (7, 8), resolver, out);
25    }
26}
27
28impl<K, V, S, RandomState> Serialize<S> for HashMap<K, V, RandomState>
29where
30    K: Serialize<S> + Hash + Eq,
31    K::Archived: Hash + Eq,
32    V: Serialize<S>,
33    S: Fallible + Writer + Allocator + ?Sized,
34    S::Error: Source,
35{
36    fn serialize(
37        &self,
38        serializer: &mut S,
39    ) -> Result<Self::Resolver, S::Error> {
40        ArchivedHashMap::<K::Archived, V::Archived>::serialize_from_iter::<
41            _,
42            _,
43            _,
44            K,
45            V,
46            _,
47        >(self.iter(), (7, 8), serializer)
48    }
49}
50
51impl<K, V, D, S> Deserialize<HashMap<K, V, S>, D>
52    for ArchivedHashMap<K::Archived, V::Archived>
53where
54    K: Archive + Hash + Eq,
55    K::Archived: Deserialize<K, D> + Hash + Eq,
56    V: Archive,
57    V::Archived: Deserialize<V, D>,
58    D: Fallible + ?Sized,
59    S: Default + BuildHasher,
60{
61    fn deserialize(
62        &self,
63        deserializer: &mut D,
64    ) -> Result<HashMap<K, V, S>, D::Error> {
65        let mut result =
66            HashMap::with_capacity_and_hasher(self.len(), S::default());
67        for (k, v) in self.iter() {
68            result.insert(
69                k.deserialize(deserializer)?,
70                v.deserialize(deserializer)?,
71            );
72        }
73        Ok(result)
74    }
75}
76
77impl<K, V, AK, AV, S> PartialEq<HashMap<K, V, S>> for ArchivedHashMap<AK, AV>
78where
79    K: Hash + Eq + Borrow<AK>,
80    AK: Hash + Eq,
81    AV: PartialEq<V>,
82    S: BuildHasher,
83{
84    fn eq(&self, other: &HashMap<K, V, S>) -> bool {
85        if self.len() != other.len() {
86            false
87        } else {
88            self.iter()
89                .all(|(key, value)| other.get(key).is_some_and(|v| value.eq(v)))
90        }
91    }
92}
93
94impl<K, V, AK, AV> PartialEq<ArchivedHashMap<AK, AV>> for HashMap<K, V>
95where
96    K: Hash + Eq + Borrow<AK>,
97    AK: Hash + Eq,
98    AV: PartialEq<V>,
99{
100    fn eq(&self, other: &ArchivedHashMap<AK, AV>) -> bool {
101        other.eq(self)
102    }
103}
104
#[cfg(test)]
mod tests {
    use core::{fmt::Debug, hash::BuildHasher};
    use std::collections::HashMap;

    use ahash::RandomState;

    use crate::{
        api::test::{roundtrip, roundtrip_with, to_archived},
        collections::swiss_table::ArchivedHashMap,
        string::ArchivedString,
        Archive, Archived, Deserialize, Serialize,
    };

    /// Asserts that `a` and its archived form `b` hold exactly the same
    /// entries.
    ///
    /// Checks that the lengths match, then checks membership and value
    /// equality in both directions.
    fn assert_equal<V, S: BuildHasher>(
        a: &HashMap<String, V, S>,
        b: &Archived<HashMap<String, V, S>>,
    ) where
        V: Archive + Debug + PartialEq<V::Archived>,
        V::Archived: Debug + PartialEq<V>,
    {
        assert_eq!(a.len(), b.len());

        for (key, value) in a.iter() {
            assert!(b.contains_key(key.as_str()));
            assert_eq!(&b[key.as_str()], value);
        }

        for (key, value) in b.iter() {
            assert!(a.contains_key(key.as_str()));
            assert_eq!(&a[key.as_str()], value);
        }
    }

    #[test]
    fn roundtrip_empty_hash_map() {
        roundtrip(&HashMap::<i8, i32>::default());
    }

    #[test]
    fn roundtrip_hash_map_string_int() {
        let mut map = HashMap::new();
        map.insert("Hello".to_string(), 12);
        map.insert("world".to_string(), 34);
        map.insert("foo".to_string(), 56);
        map.insert("bar".to_string(), 78);
        map.insert("baz".to_string(), 90);
        roundtrip_with(&map, assert_equal);
    }

    #[test]
    fn roundtrip_hash_map_string_string() {
        let mut hash_map = HashMap::new();
        hash_map.insert("hello".to_string(), "world".to_string());
        hash_map.insert("foo".to_string(), "bar".to_string());
        hash_map.insert("baz".to_string(), "bat".to_string());

        roundtrip_with(&hash_map, assert_equal);
    }

    #[test]
    fn roundtrip_hash_map_zsts() {
        let mut value = HashMap::new();
        value.insert((), 10);
        roundtrip(&value);

        let mut value = HashMap::new();
        value.insert((), ());
        roundtrip(&value);
    }

    #[test]
    fn roundtrip_hash_map_with_custom_hasher_empty() {
        roundtrip(&HashMap::<i8, i32, RandomState>::default());
    }

    #[test]
    fn roundtrip_hash_map_with_custom_hasher() {
        let mut hash_map: HashMap<i8, _, RandomState> = HashMap::default();
        hash_map.insert(1, 2);
        hash_map.insert(3, 4);
        hash_map.insert(5, 6);
        hash_map.insert(7, 8);

        roundtrip(&hash_map);
    }

    #[test]
    fn roundtrip_hash_map_with_custom_hasher_strings() {
        let mut hash_map: HashMap<_, _, RandomState> = HashMap::default();
        hash_map.insert("hello".to_string(), "world".to_string());
        hash_map.insert("foo".to_string(), "bar".to_string());
        hash_map.insert("baz".to_string(), "bat".to_string());

        roundtrip_with(&hash_map, assert_equal);
    }

    #[test]
    fn get_with() {
        #[derive(Archive, Serialize, Deserialize, Eq, Hash, PartialEq)]
        #[rkyv(crate, derive(Eq, Hash, PartialEq))]
        pub struct Pair(String, String);

        let mut hash_map = HashMap::new();
        hash_map.insert(
            Pair("my".to_string(), "key".to_string()),
            "value".to_string(),
        );
        hash_map.insert(
            Pair("wrong".to_string(), "key".to_string()),
            "wrong value".to_string(),
        );

        to_archived(&hash_map, |archived| {
            // Look up by a borrowed `(&str, &str)` tuple, matched against
            // each archived `Pair` with a custom comparison closure.
            let get_with = archived
                .get_with(&("my", "key"), |input_key, key| {
                    &(key.0.as_str(), key.1.as_str()) == input_key
                })
                .unwrap();

            assert_eq!(get_with.as_str(), "value");
        });
    }

    #[test]
    fn get_seal() {
        let mut hash_map: HashMap<_, _, RandomState> = HashMap::default();
        hash_map.insert("hello".to_string(), "world".to_string());
        hash_map.insert("foo".to_string(), "bar".to_string());
        hash_map.insert("baz".to_string(), "bat".to_string());

        to_archived(&hash_map, |archived| {
            let mut value =
                ArchivedHashMap::get_seal(archived, "hello").unwrap();
            assert_eq!("world", &*value);
            let mut string = ArchivedString::as_str_seal(value.as_mut());
            string.make_ascii_uppercase();
            assert_eq!("WORLD", &*value);
        });
    }

    #[test]
    fn iter_seal() {
        let mut hash_map: HashMap<_, _, RandomState> = HashMap::default();
        hash_map.insert("hello".to_string(), "world".to_string());
        hash_map.insert("foo".to_string(), "bar".to_string());
        hash_map.insert("baz".to_string(), "bat".to_string());

        to_archived(&hash_map, |mut archived| {
            for value in ArchivedHashMap::values_seal(archived.as_mut()) {
                let mut string = ArchivedString::as_str_seal(value);
                string.make_ascii_uppercase();
            }
            assert_eq!(archived.get("hello").unwrap(), "WORLD");
            assert_eq!(archived.get("foo").unwrap(), "BAR");
            assert_eq!(archived.get("baz").unwrap(), "BAT");
        });
    }

    #[test]
    fn large_hash_map() {
        let mut map = HashMap::new();
        for i in 0..100 {
            map.insert(i.to_string(), i);
        }
        roundtrip_with(&map, assert_equal);
    }

    #[cfg(feature = "bytecheck")]
    #[test]
    fn nested_hash_map() {
        use rancor::{Error, Panic};

        use crate::{access, to_bytes};

        #[derive(
            Hash, PartialEq, Eq, Archive, Serialize, Deserialize, Debug,
        )]
        #[rkyv(crate, derive(Hash, PartialEq, Eq, Debug))]
        struct Key(u8, u8);

        let mut nested_map = HashMap::new();
        nested_map.insert(1337u16, 42u16);

        type MyHashMap = HashMap<Key, HashMap<u16, u16>>;
        let mut map: MyHashMap = HashMap::new();
        map.insert(Key(1, 2), nested_map.clone());
        map.insert(Key(3, 4), nested_map);

        let encoded = to_bytes::<Error>(&map).unwrap();

        // Regression check: validating a map whose values are themselves
        // hash maps must succeed. This unwrap has failed here in the past;
        // the test now expects it to pass.
        let _decoded = access::<Archived<MyHashMap>, Panic>(&encoded).unwrap();
    }
}