//! [`Archive`](crate::Archive) implementation for B-tree maps.
2
3use core::{
4    borrow::Borrow,
5    cmp::Ordering,
6    fmt,
7    marker::PhantomData,
8    mem::{size_of, MaybeUninit},
9    ops::{ControlFlow, Index},
10    ptr::addr_of_mut,
11    slice,
12};
13
14use munge::munge;
15use rancor::{fail, Fallible, Source};
16
17use crate::{
18    collections::util::IteratorLengthMismatch,
19    primitive::{ArchivedUsize, FixedUsize},
20    seal::Seal,
21    ser::{Allocator, Writer, WriterExt as _},
22    traits::NoUndef,
23    util::{InlineVec, SerVec},
24    Place, Portable, RelPtr, Serialize,
25};
26
27// TODO(#515): Get Iterator APIs working without the `alloc` feature enabled
28#[cfg(feature = "alloc")]
29mod iter;
30
31#[cfg(feature = "alloc")]
32pub use self::iter::*;
33
34// B-trees are typically characterized as having a branching factor of B.
35// However, in this implementation our B-trees are characterized as having a
36// number of entries per node E where E = B - 1. This is done because it's
37// easier to add an additional node pointer to each inner node than it is to
38// store one less entry per inner node. Because generic const exprs are not
39// stable, we can't declare a field `entries: [Entry; { B - 1 }]`. But we can
40// declare `branches: [RelPtr; E]` and then add another `last: RelPtr`
41// field. When the branching factor B is needed, it will be calculated as E + 1.
42
/// Returns the number of nodes in level `i` of a B-tree with `E` entries per
/// node.
const fn nodes_in_level<const E: usize>(i: u32) -> usize {
    // The root level (level 0) holds exactly one node, and every level below
    // holds B = E + 1 times as many nodes as the level above it. Level I
    // therefore holds B^I nodes.
    let branching_factor = E + 1;
    branching_factor.pow(i)
}
50
51const fn entries_in_full_tree<const E: usize>(h: u32) -> usize {
52    // The number of nodes in each layer I of a B-tree is equal to B^I. At layer
53    // I = 0, the number of nodes is exactly one. At layer I = 1, the number of
54    // nodes is B, at layer I = 2 the number of nodes is B^2, and so on. The
55    // total number of nodes is equal to the sum from 0 to H - 1 of B^I. Since
56    // this is the sum of a geometric progression, we have the closed-form
57    // solution N = (B^H - 1) / (B - 1). Since the number of entries per node is
58    // equal to B - 1, we thus have the solution that the number of entries in a
59    // B-tree of height H is equal to B^H - 1.
60
61    // Note that this is one less than the number of nodes in the level after
62    // the final level of the B-tree.
63
64    nodes_in_level::<E>(h) - 1
65}
66
/// Returns the height of the B-tree needed to hold `n` entries.
const fn entries_to_height<const E: usize>(n: usize) -> u32 {
    // A full tree of height H holds B^H - 1 entries, so solving for H gives
    // H = log_B(N + 1). Integer logarithms round down, however, which would
    // underestimate the height for partially-full trees. Instead, note that
    // a tree of height H is needed exactly when:
    // => B^(H - 1) - 1 < N <= B^H - 1
    // which, since N is an integer, is equivalent to:
    // => B^(H - 1) <= N < B^H
    // Taking the integer logarithm of that range:
    // => ilog_B(N) = H - 1
    // so the height is obtained exactly by:
    n.ilog(E + 1) + 1
}
105
106const fn ll_entries<const E: usize>(height: u32, n: usize) -> usize {
107    // The number of entries not in the last level is equal to the number of
108    // entries in a full B-tree of height H - 1. The number of entries in
109    // the last level is thus the total number of entries minus the number
110    // of entries not in the last level.
111    n - entries_in_full_tree::<E>(height - 1)
112}
113
/// Discriminant stored at the start of every node, identifying whether the
/// node may be cast to `LeafNode` or `InnerNode`.
#[derive(Clone, Copy, Portable)]
#[cfg_attr(feature = "bytecheck", derive(bytecheck::CheckBytes))]
#[rkyv(crate)]
#[repr(u8)]
enum NodeKind {
    // A node with a trailing entry count; only a prefix of its slots is
    // initialized.
    Leaf,
    // A node with child pointers; all `E` of its entry slots are initialized.
    Inner,
}

// SAFETY: `NodeKind` is `repr(u8)` and so always consists of a single
// well-defined byte.
unsafe impl NoUndef for NodeKind {}
126
/// The header layout shared by all B-tree nodes.
///
/// `kind` is read first to determine whether the node may be cast to
/// `LeafNode` or `InnerNode`. The entry arrays are `MaybeUninit` because only
/// a prefix of the `E` slots may be initialized (see `LeafNode::len`).
#[derive(Portable)]
#[rkyv(crate)]
#[repr(C)]
struct Node<K, V, const E: usize> {
    // Discriminates leaf nodes from inner nodes; checked before casting.
    kind: NodeKind,
    // Up to `E` keys; searches compare against these in ascending order.
    keys: [MaybeUninit<K>; E],
    // Values corresponding index-for-index to `keys`.
    values: [MaybeUninit<V>; E],
}
135
/// A leaf node: the shared node header followed by the number of initialized
/// key-value slots.
#[derive(Portable)]
#[rkyv(crate)]
#[repr(C)]
struct LeafNode<K, V, const E: usize> {
    node: Node<K, V, E>,
    // Number of initialized entries in `node`; at most `E`.
    len: ArchivedUsize,
}
143
/// An inner node: always fully populated with `E` entries, plus up to `E + 1`
/// child pointers.
#[cfg_attr(feature = "bytecheck", derive(bytecheck::CheckBytes))]
#[derive(Portable)]
#[rkyv(crate)]
#[repr(C)]
struct InnerNode<K, V, const E: usize> {
    node: Node<K, V, E>,
    // `lesser_nodes[i]` points to the subtree whose keys precede
    // `node.keys[i]`; an invalid pointer encodes the absence of that child.
    lesser_nodes: [RelPtr<Node<K, V, E>>; E],
    // Subtree whose keys follow every entry in this node; may be invalid.
    greater_node: RelPtr<Node<K, V, E>>,
}
153
/// An archived [`BTreeMap`](crate::alloc::collections::BTreeMap).
#[cfg_attr(
    feature = "bytecheck",
    derive(bytecheck::CheckBytes),
    bytecheck(verify)
)]
#[derive(Portable)]
#[rkyv(crate)]
#[repr(C)]
pub struct ArchivedBTreeMap<K, V, const E: usize = 5> {
    // The type of the root node is determined at runtime because it may point
    // to:
    // - Nothing if the length is zero
    // - A leaf node if there is only one node
    // - Or an inner node if there are multiple nodes
    root: RelPtr<Node<K, V, E>>,
    // Total number of entries; always checked before `root` is followed.
    len: ArchivedUsize,
    // Ties `K` and `V` to the type without storing them directly.
    _phantom: PhantomData<(K, V)>,
}
173
174impl<K, V, const E: usize> ArchivedBTreeMap<K, V, E> {
175    /// Returns whether the B-tree map contains the given key.
176    pub fn contains_key<Q>(&self, key: &Q) -> bool
177    where
178        Q: Ord + ?Sized,
179        K: Borrow<Q> + Ord,
180    {
181        self.get_key_value(key).is_some()
182    }
183
184    /// Returns the value associated with the given key, or `None` if the key is
185    /// not present in the B-tree map.
186    pub fn get<Q>(&self, key: &Q) -> Option<&V>
187    where
188        Q: Ord + ?Sized,
189        K: Borrow<Q> + Ord,
190    {
191        Some(self.get_key_value(key)?.1)
192    }
193
194    /// Returns the mutable value associated with the given key, or `None` if
195    /// the key is not present in the B-tree map.
196    pub fn get_seal<'a, Q>(this: Seal<'a, Self>, key: &Q) -> Option<Seal<'a, V>>
197    where
198        Q: Ord + ?Sized,
199        K: Borrow<Q> + Ord,
200    {
201        Some(Self::get_key_value_seal(this, key)?.1)
202    }
203
204    /// Returns true if the B-tree map contains no entries.
205    pub fn is_empty(&self) -> bool {
206        self.len() == 0
207    }
208
209    /// Returns the number of entries in the B-tree map.
210    pub fn len(&self) -> usize {
211        self.len.to_native() as usize
212    }
213
214    /// Gets the key-value pair associated with the given key, or `None` if the
215    /// key is not present in the B-tree map.
216    pub fn get_key_value<Q>(&self, key: &Q) -> Option<(&K, &V)>
217    where
218        Q: Ord + ?Sized,
219        K: Borrow<Q> + Ord,
220    {
221        self.get_key_value_with(key, |q, k| q.cmp(k.borrow()))
222    }
223
    /// Gets the key-value pair associated with the given key, or `None` if the
    /// key is not present in the B-tree map.
    ///
    /// This method uses the supplied comparison function to compare the key to
    /// elements.
    pub fn get_key_value_with<Q, C>(&self, key: &Q, cmp: C) -> Option<(&K, &V)>
    where
        Q: Ord + ?Sized,
        C: Fn(&Q, &K) -> Ordering,
        K: Ord,
    {
        // `get_key_value_raw` takes a `*mut Self` so it can be shared with the
        // sealed accessors; the returned pointers are only reborrowed
        // immutably here.
        let this = (self as *const Self).cast_mut();
        Self::get_key_value_raw(this, key, cmp)
            .map(|(k, v)| (unsafe { &*k }, unsafe { &*v }))
    }
239
240    /// Gets the mutable key-value pair associated with the given key, or `None`
241    /// if the key is not present in the B-tree map.
242    pub fn get_key_value_seal<'a, Q>(
243        this: Seal<'a, Self>,
244        key: &Q,
245    ) -> Option<(&'a K, Seal<'a, V>)>
246    where
247        Q: Ord + ?Sized,
248        K: Borrow<Q> + Ord,
249    {
250        Self::get_key_value_seal_with(this, key, |q, k| q.cmp(k.borrow()))
251    }
252
    /// Gets the mutable key-value pair associated with the given key, or `None`
    /// if the key is not present in the B-tree map.
    ///
    /// This method uses the supplied comparison function to compare the key to
    /// elements.
    pub fn get_key_value_seal_with<'a, Q, C>(
        this: Seal<'a, Self>,
        key: &Q,
        cmp: C,
    ) -> Option<(&'a K, Seal<'a, V>)>
    where
        Q: Ord + ?Sized,
        C: Fn(&Q, &K) -> Ordering,
        K: Ord,
    {
        // Unsealing is sound here because the value pointer is immediately
        // re-wrapped in a `Seal`, and the key is only reborrowed immutably.
        let this = unsafe { Seal::unseal_unchecked(this) as *mut Self };
        Self::get_key_value_raw(this, key, cmp)
            .map(|(k, v)| (unsafe { &*k }, Seal::new(unsafe { &mut *v })))
    }
272
    /// Shared lookup routine behind the `&self` and `Seal` accessors.
    ///
    /// Walks the tree from the root, returning raw pointers to the matching
    /// key and value so callers can choose shared or mutable reborrows.
    fn get_key_value_raw<Q, C>(
        this: *mut Self,
        key: &Q,
        cmp: C,
    ) -> Option<(*mut K, *mut V)>
    where
        Q: Ord + ?Sized,
        C: Fn(&Q, &K) -> Ordering,
        K: Ord,
    {
        let len = unsafe { (*this).len.to_native() };
        if len == 0 {
            // Empty maps have an invalid root pointer (see
            // `resolve_from_len`), so it must not be followed.
            return None;
        }

        let root_ptr = unsafe { addr_of_mut!((*this).root) };
        let mut current = unsafe { RelPtr::as_ptr_raw(root_ptr) };
        'outer: loop {
            // `kind` determines which concrete node type `current` may be
            // cast to.
            let kind = unsafe { (*current).kind };

            match kind {
                NodeKind::Leaf => {
                    let leaf = current.cast::<LeafNode<K, V, E>>();
                    // Only the first `len` key/value slots of a leaf are
                    // initialized.
                    let len = unsafe { (*leaf).len };

                    for i in 0..len.to_native() as usize {
                        let k = unsafe {
                            addr_of_mut!((*current).keys[i]).cast::<K>()
                        };
                        let ordering = cmp(key, unsafe { &*k });

                        match ordering {
                            Ordering::Equal => {
                                let v = unsafe {
                                    addr_of_mut!((*current).values[i])
                                        .cast::<V>()
                                };
                                return Some((k, v));
                            }
                            // Entries are scanned in ascending order, so once
                            // the search key compares less than an entry it
                            // cannot occur later in this leaf.
                            Ordering::Less => return None,
                            Ordering::Greater => (),
                        }
                    }

                    return None;
                }
                NodeKind::Inner => {
                    let inner = current.cast::<InnerNode<K, V, E>>();

                    // Inner nodes always hold all `E` entries.
                    for i in 0..E {
                        let k = unsafe {
                            addr_of_mut!((*current).keys[i]).cast::<K>()
                        };
                        let ordering = cmp(key, unsafe { &*k });

                        match ordering {
                            Ordering::Equal => {
                                let v = unsafe {
                                    addr_of_mut!((*current).values[i])
                                        .cast::<V>()
                                };
                                return Some((k, v));
                            }
                            Ordering::Less => {
                                // The search key precedes this entry, so it
                                // can only be in the lesser subtree. An
                                // invalid pointer means no such subtree.
                                let lesser = unsafe {
                                    addr_of_mut!((*inner).lesser_nodes[i])
                                };
                                let lesser_is_invalid =
                                    unsafe { RelPtr::is_invalid_raw(lesser) };
                                if !lesser_is_invalid {
                                    current =
                                        unsafe { RelPtr::as_ptr_raw(lesser) };
                                    continue 'outer;
                                } else {
                                    return None;
                                }
                            }
                            Ordering::Greater => (),
                        }
                    }

                    // The search key follows every entry in this node, so
                    // descend into the greater subtree if it exists.
                    let inner = current.cast::<InnerNode<K, V, E>>();
                    let greater =
                        unsafe { addr_of_mut!((*inner).greater_node) };
                    let greater_is_invalid =
                        unsafe { RelPtr::is_invalid_raw(greater) };
                    if !greater_is_invalid {
                        current = unsafe { RelPtr::as_ptr_raw(greater) };
                    } else {
                        return None;
                    }
                }
            }
        }
    }
368
    /// Resolves an `ArchivedBTreeMap` from the given length, resolver, and
    /// output place.
    pub fn resolve_from_len(
        len: usize,
        resolver: BTreeMapResolver,
        out: Place<Self>,
    ) {
        munge!(let ArchivedBTreeMap { root, len: out_len, _phantom: _ } = out);

        if len == 0 {
            // Empty maps have no root node. Readers check `len` before
            // following `root` (see `get_key_value_raw` and `visit`).
            RelPtr::emplace_invalid(root);
        } else {
            RelPtr::emplace(resolver.root_node_pos as usize, root);
        }

        out_len.write(ArchivedUsize::from_native(len as FixedUsize));
    }
386
    /// Serializes an `ArchivedBTreeMap` from the given iterator and serializer.
    ///
    /// The iterator is expected to yield key-value pairs in ascending key
    /// order; lookups rely on entries being sorted. The tree is built
    /// bottom-up: leaves are written out as they fill, and each completed
    /// child's position is threaded into the stack of currently-open inner
    /// nodes.
    pub fn serialize_from_ordered_iter<I, BKU, BVU, KU, VU, S>(
        mut iter: I,
        serializer: &mut S,
    ) -> Result<BTreeMapResolver, S::Error>
    where
        I: ExactSizeIterator<Item = (BKU, BVU)>,
        BKU: Borrow<KU>,
        BVU: Borrow<VU>,
        KU: Serialize<S, Archived = K>,
        VU: Serialize<S, Archived = V>,
        S: Fallible + Allocator + Writer + ?Sized,
        S::Error: Source,
    {
        let len = iter.len();

        if len == 0 {
            // Confirm the iterator really is empty before emitting nothing.
            let actual = iter.count();
            if actual != 0 {
                fail!(IteratorLengthMismatch {
                    expected: 0,
                    actual,
                });
            }
            // The root pointer is emplaced as invalid for empty maps, so this
            // position is never read.
            return Ok(BTreeMapResolver { root_node_pos: 0 });
        }

        let height = entries_to_height::<E>(len);
        let ll_entries = ll_entries::<E>(height, len);

        SerVec::with_capacity(
            serializer,
            height as usize - 1,
            |open_inners, serializer| {
                // One in-progress inner node per non-leaf level of the tree.
                for _ in 0..height - 1 {
                    open_inners
                        .push(InlineVec::<(BKU, BVU, Option<usize>), E>::new());
                }

                let mut open_leaf = InlineVec::<(BKU, BVU), E>::new();

                // Position of the most recently closed (written) node; it
                // becomes the lesser child of the next entry pushed into an
                // inner node, and ultimately the root position.
                let mut child_node_pos = None;
                let mut leaf_entries = 0;
                while let Some((key, value)) = iter.next() {
                    open_leaf.push((key, value));
                    leaf_entries += 1;

                    if leaf_entries == ll_entries
                        || open_leaf.len() == open_leaf.capacity()
                    {
                        // Close open leaf
                        child_node_pos =
                            Some(Self::close_leaf(&open_leaf, serializer)?);
                        open_leaf.clear();

                        // If on the transition node, fill and close open inner
                        if leaf_entries == ll_entries {
                            // The last level's quota is exhausted; the
                            // deepest open inner node is completed directly
                            // from the iterator, even if it isn't full.
                            if let Some(mut inner) = open_inners.pop() {
                                while inner.len() < inner.capacity() {
                                    if let Some((k, v)) = iter.next() {
                                        inner.push((k, v, child_node_pos));
                                        child_node_pos = None;
                                    } else {
                                        break;
                                    }
                                }

                                child_node_pos = Some(Self::close_inner(
                                    &inner,
                                    child_node_pos,
                                    serializer,
                                )?);
                            }
                        }

                        // Add closed node to open inner
                        let mut popped = 0;
                        while let Some(last_inner) = open_inners.last_mut() {
                            if last_inner.len() == last_inner.capacity() {
                                // Close open inner
                                child_node_pos = Some(Self::close_inner(
                                    last_inner,
                                    child_node_pos,
                                    serializer,
                                )?);
                                open_inners.pop();
                                popped += 1;
                            } else {
                                // `len` guarantees another entry here; a
                                // shorter-than-reported iterator will panic.
                                let (key, value) = iter.next().unwrap();
                                last_inner.push((key, value, child_node_pos));
                                child_node_pos = None;
                                // Reopen fresh inner nodes for the levels
                                // that were just closed.
                                for _ in 0..popped {
                                    open_inners.push(InlineVec::default());
                                }
                                break;
                            }
                        }
                    }
                }

                if !open_leaf.is_empty() {
                    // Close open leaf
                    child_node_pos =
                        Some(Self::close_leaf(&open_leaf, serializer)?);
                    open_leaf.clear();
                }

                // Close open inners
                while let Some(inner) = open_inners.pop() {
                    child_node_pos = Some(Self::close_inner(
                        &inner,
                        child_node_pos,
                        serializer,
                    )?);
                }

                debug_assert!(open_inners.is_empty());
                debug_assert!(open_leaf.is_empty());

                // Any remaining items mean the iterator reported too short a
                // length.
                let leftovers = iter.count();
                if leftovers != 0 {
                    fail!(IteratorLengthMismatch {
                        expected: len,
                        actual: len + leftovers,
                    });
                }

                // The last node closed is the root of the tree.
                Ok(BTreeMapResolver {
                    root_node_pos: child_node_pos.unwrap() as FixedUsize,
                })
            },
        )?
    }
520
    /// Writes a single leaf node containing `items` and returns its position
    /// in the output.
    fn close_leaf<BKU, BVU, KU, VU, S>(
        items: &[(BKU, BVU)],
        serializer: &mut S,
    ) -> Result<usize, S::Error>
    where
        BKU: Borrow<KU>,
        BVU: Borrow<VU>,
        KU: Serialize<S, Archived = K>,
        VU: Serialize<S, Archived = V>,
        S: Writer + Fallible + ?Sized,
    {
        // Serialize the out-of-line data of every entry before emitting the
        // node itself, collecting the resolvers needed to place each entry.
        let mut resolvers = InlineVec::<(KU::Resolver, VU::Resolver), E>::new();
        for (key, value) in items {
            resolvers.push((
                key.borrow().serialize(serializer)?,
                value.borrow().serialize(serializer)?,
            ));
        }

        let pos = serializer.align_for::<LeafNode<K, V, E>>()?;
        let mut node = MaybeUninit::<LeafNode<K, V, E>>::uninit();
        // Zero the node so that padding and unused key/value slots contain
        // well-defined bytes when copied to the output.
        // SAFETY: `node` is properly aligned and valid for writes of
        // `size_of::<LeafNode<K, V, E>>()` bytes.
        unsafe {
            node.as_mut_ptr().write_bytes(0, 1);
        }

        let node_place =
            unsafe { Place::new_unchecked(pos, node.as_mut_ptr()) };

        munge! {
            let LeafNode {
                node: Node {
                    kind,
                    keys,
                    values,
                },
                len,
            } = node_place;
        }
        kind.write(NodeKind::Leaf);
        len.write(ArchivedUsize::from_native(items.len() as FixedUsize));
        // Resolve each entry into its slot in the node.
        for (i, ((k, v), (kr, vr))) in
            items.iter().zip(resolvers.drain()).enumerate()
        {
            let out_key = unsafe { keys.index(i).cast_unchecked() };
            k.borrow().resolve(kr, out_key);
            let out_value = unsafe { values.index(i).cast_unchecked() };
            v.borrow().resolve(vr, out_value);
        }

        // Copy the fully-initialized node into the output stream.
        let bytes = unsafe {
            slice::from_raw_parts(
                node.as_ptr().cast::<u8>(),
                size_of::<LeafNode<K, V, E>>(),
            )
        };
        serializer.write(bytes)?;

        Ok(pos)
    }
582
    /// Writes a single inner node and returns its position in the output.
    ///
    /// `items` supplies the node's entries along with the position of each
    /// entry's lesser child, and `greater_node_pos` the position of the child
    /// whose keys follow every entry.
    fn close_inner<BKU, BVU, KU, VU, S>(
        items: &[(BKU, BVU, Option<usize>)],
        greater_node_pos: Option<usize>,
        serializer: &mut S,
    ) -> Result<usize, S::Error>
    where
        BKU: Borrow<KU>,
        BVU: Borrow<VU>,
        KU: Serialize<S, Archived = K>,
        VU: Serialize<S, Archived = V>,
        S: Writer + Fallible + ?Sized,
    {
        // Inner nodes are always completely full.
        debug_assert_eq!(items.len(), E);

        // Serialize the out-of-line data of every entry before emitting the
        // node itself.
        let mut resolvers = InlineVec::<(KU::Resolver, VU::Resolver), E>::new();
        for (key, value, _) in items {
            resolvers.push((
                key.borrow().serialize(serializer)?,
                value.borrow().serialize(serializer)?,
            ));
        }

        let pos = serializer.align_for::<InnerNode<K, V, E>>()?;
        let mut node = MaybeUninit::<InnerNode<K, V, E>>::uninit();
        // Zero the node so padding bytes are well-defined when copied out.
        // SAFETY: `node` is properly aligned and valid for writes of
        // `size_of::<InnerNode<K, V, E>>()` bytes.
        unsafe {
            node.as_mut_ptr().write_bytes(0, 1);
        }

        let node_place =
            unsafe { Place::new_unchecked(pos, node.as_mut_ptr()) };

        munge! {
            let InnerNode {
                node: Node {
                    kind,
                    keys,
                    values,
                },
                lesser_nodes,
                greater_node,
            } = node_place;
        }

        kind.write(NodeKind::Inner);
        for (i, ((k, v, l), (kr, vr))) in
            items.iter().zip(resolvers.drain()).enumerate()
        {
            let out_key = unsafe { keys.index(i).cast_unchecked() };
            k.borrow().resolve(kr, out_key);
            let out_value = unsafe { values.index(i).cast_unchecked() };
            v.borrow().resolve(vr, out_value);

            // Absent children are encoded as invalid relative pointers.
            let out_lesser_node = unsafe { lesser_nodes.index(i) };
            if let Some(lesser_node) = l {
                RelPtr::emplace(*lesser_node, out_lesser_node);
            } else {
                RelPtr::emplace_invalid(out_lesser_node);
            }
        }

        if let Some(greater_node_pos) = greater_node_pos {
            RelPtr::emplace(greater_node_pos, greater_node);
        } else {
            RelPtr::emplace_invalid(greater_node);
        }

        // Copy the fully-initialized node into the output stream.
        let bytes = unsafe {
            slice::from_raw_parts(
                node.as_ptr().cast::<u8>(),
                size_of::<InnerNode<K, V, E>>(),
            )
        };
        serializer.write(bytes)?;

        Ok(pos)
    }
661
    /// Visits every key-value pair in the B-tree with a function.
    ///
    /// If `f` returns `ControlFlow::Break`, `visit` will return `Some` with the
    /// broken value. If `f` returns `Continue` for every pair in the tree,
    /// `visit` will return `None`.
    pub fn visit<T>(
        &self,
        mut f: impl FnMut(&K, &V) -> ControlFlow<T>,
    ) -> Option<T> {
        if self.is_empty() {
            // Empty maps have an invalid root pointer which must not be
            // dereferenced.
            None
        } else {
            let root = &self.root;
            let root_ptr = unsafe { root.as_ptr().cast::<Node<K, V, E>>() };
            // Adapt the caller's reference-based callback to the raw-pointer
            // callback shared with `visit_seal`.
            let mut call_inner = |k: *mut K, v: *mut V| unsafe { f(&*k, &*v) };
            match Self::visit_raw(root_ptr.cast_mut(), &mut call_inner) {
                ControlFlow::Continue(()) => None,
                ControlFlow::Break(x) => Some(x),
            }
        }
    }
683
    /// Visits every mutable key-value pair in the B-tree with a function.
    ///
    /// If `f` returns `ControlFlow::Break`, `visit` will return `Some` with the
    /// broken value. If `f` returns `Continue` for every pair in the tree,
    /// `visit` will return `None`.
    pub fn visit_seal<T>(
        this: Seal<'_, Self>,
        mut f: impl FnMut(&K, Seal<'_, V>) -> ControlFlow<T>,
    ) -> Option<T> {
        if this.is_empty() {
            // Empty maps have an invalid root pointer which must not be
            // dereferenced.
            None
        } else {
            munge!(let Self { root, .. } = this);
            let root_ptr =
                unsafe { RelPtr::as_mut_ptr(root).cast::<Node<K, V, E>>() };
            // Keys stay shared; only values are handed out mutably, wrapped
            // in `Seal`.
            let mut call_inner =
                |k: *mut K, v: *mut V| unsafe { f(&*k, Seal::new(&mut *v)) };
            match Self::visit_raw(root_ptr, &mut call_inner) {
                ControlFlow::Continue(()) => None,
                ControlFlow::Break(x) => Some(x),
            }
        }
    }
707
    /// Recursively visits every key-value pair reachable from `current`,
    /// short-circuiting when `f` breaks.
    ///
    /// Traversal is in-order: each entry's lesser subtree is visited before
    /// the entry itself, and the greater subtree is visited last.
    fn visit_raw<T>(
        current: *mut Node<K, V, E>,
        f: &mut impl FnMut(*mut K, *mut V) -> ControlFlow<T>,
    ) -> ControlFlow<T> {
        let kind = unsafe { (*current).kind };

        match kind {
            NodeKind::Leaf => {
                let leaf = current.cast::<LeafNode<K, V, E>>();
                // Only the first `len` slots of a leaf are initialized.
                let len = unsafe { (*leaf).len };
                for i in 0..len.to_native() as usize {
                    Self::visit_key_value_raw(current, i, f)?;
                }
            }
            NodeKind::Inner => {
                let inner = current.cast::<InnerNode<K, V, E>>();

                // Visit lesser nodes and key-value pairs
                for i in 0..E {
                    let lesser =
                        unsafe { addr_of_mut!((*inner).lesser_nodes[i]) };
                    let lesser_is_invalid =
                        unsafe { RelPtr::is_invalid_raw(lesser) };
                    if !lesser_is_invalid {
                        let lesser_ptr = unsafe { RelPtr::as_ptr_raw(lesser) };
                        Self::visit_raw(lesser_ptr, f)?;
                    }
                    Self::visit_key_value_raw(current, i, f)?;
                }

                // Visit greater node
                let greater = unsafe { addr_of_mut!((*inner).greater_node) };
                let greater_is_invalid =
                    unsafe { RelPtr::is_invalid_raw(greater) };
                if !greater_is_invalid {
                    let greater_ptr = unsafe {
                        RelPtr::as_ptr_raw(greater).cast::<Node<K, V, E>>()
                    };
                    Self::visit_raw(greater_ptr, f)?;
                }
            }
        }

        ControlFlow::Continue(())
    }
753
    /// Invokes `f` with raw pointers to the `i`-th key and value of `current`.
    ///
    /// The caller must ensure that slot `i` of the node is initialized.
    fn visit_key_value_raw<T>(
        current: *mut Node<K, V, E>,
        i: usize,
        f: &mut impl FnMut(*mut K, *mut V) -> ControlFlow<T>,
    ) -> ControlFlow<T> {
        let key_ptr = unsafe { addr_of_mut!((*current).keys[i]).cast::<K>() };
        let value_ptr =
            unsafe { addr_of_mut!((*current).values[i]).cast::<V>() };
        f(key_ptr, value_ptr)
    }
764}
765
766impl<K, V, const E: usize> fmt::Debug for ArchivedBTreeMap<K, V, E>
767where
768    K: fmt::Debug,
769    V: fmt::Debug,
770{
771    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
772        let mut map = f.debug_map();
773        self.visit(|k, v| {
774            map.entry(k, v);
775            ControlFlow::<()>::Continue(())
776        });
777        map.finish()
778    }
779}
780
// TODO(#515): ungate this impl
// `Eq` adds no methods of its own; the comparison logic lives in the
// `PartialEq` impl below, which is why this marker is also gated on `alloc`.
#[cfg(feature = "alloc")]
impl<K, V, const E: usize> Eq for ArchivedBTreeMap<K, V, E>
where
    K: PartialEq,
    V: PartialEq,
{
}
789
impl<K, V, Q, const E: usize> Index<&Q> for ArchivedBTreeMap<K, V, E>
where
    Q: Ord + ?Sized,
    K: Borrow<Q> + Ord,
{
    type Output = V;

    /// Returns a reference to the value associated with the given key.
    ///
    /// # Panics
    ///
    /// Panics if the key is not present in the B-tree map.
    fn index(&self, key: &Q) -> &Self::Output {
        self.get(key).unwrap()
    }
}
801
// TODO(#515): ungate this impl
#[cfg(feature = "alloc")]
impl<K, V, const E1: usize, const E2: usize>
    PartialEq<ArchivedBTreeMap<K, V, E2>> for ArchivedBTreeMap<K, V, E1>
where
    K: PartialEq,
    V: PartialEq,
{
    /// Two maps are equal when they yield the same entries in the same order,
    /// regardless of their entries-per-node parameters.
    fn eq(&self, other: &ArchivedBTreeMap<K, V, E2>) -> bool {
        if self.len() != other.len() {
            return false;
        }
        // Walk both trees in lockstep. The `unwrap` cannot panic: both maps
        // yield exactly `len` entries because their lengths are equal.
        let mut i = other.iter();
        self.visit(|lk, lv| {
            let (rk, rv) = i.next().unwrap();
            if lk != rk || lv != rv {
                ControlFlow::Break(())
            } else {
                ControlFlow::Continue(())
            }
        })
        .is_none()
    }
}
826
827impl<K: core::hash::Hash, V: core::hash::Hash, const E: usize> core::hash::Hash
828    for ArchivedBTreeMap<K, V, E>
829{
830    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
831        self.visit(|k, v| {
832            (*k).hash(state);
833            (*v).hash(state);
834            ControlFlow::<()>::Continue(())
835        });
836    }
837}
838
/// The resolver for [`ArchivedBTreeMap`].
pub struct BTreeMapResolver {
    // Position of the serialized root node; meaningless for empty maps,
    // whose root pointer is emplaced as invalid (see `resolve_from_len`).
    root_node_pos: FixedUsize,
}
843
844#[cfg(feature = "bytecheck")]
845mod verify {
846    use core::{alloc::Layout, error::Error, fmt, ptr::addr_of};
847
848    use bytecheck::{CheckBytes, Verify};
849    use rancor::{fail, Fallible, Source};
850
851    use super::{ArchivedBTreeMap, InnerNode, Node};
852    use crate::{
853        collections::btree_map::{LeafNode, NodeKind},
854        validation::{ArchiveContext, ArchiveContextExt as _},
855        RelPtr,
856    };
857
    /// Validation error: a node's stored length exceeded its capacity.
    #[derive(Debug)]
    struct InvalidLength {
        // The length read from the archived node.
        len: usize,
        // The maximum number of entries the node can hold.
        maximum: usize,
    }
863
864    impl fmt::Display for InvalidLength {
865        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
866            write!(
867                f,
868                "Invalid length in B-tree node: len {} was greater than \
869                 maximum {}",
870                self.len, self.maximum
871            )
872        }
873    }
874
875    impl Error for InvalidLength {}
876
877    unsafe impl<C, K, V, const E: usize> Verify<C> for ArchivedBTreeMap<K, V, E>
878    where
879        C: Fallible + ArchiveContext + ?Sized,
880        C::Error: Source,
881        K: CheckBytes<C>,
882        V: CheckBytes<C>,
883    {
884        fn verify(&self, context: &mut C) -> Result<(), C::Error> {
885            let len = self.len();
886
887            if len == 0 {
888                return Ok(());
889            }
890
891            check_node_rel_ptr::<C, K, V, E>(&self.root, context)
892        }
893    }
894
895    fn check_node_rel_ptr<C, K, V, const E: usize>(
896        node_rel_ptr: &RelPtr<Node<K, V, E>>,
897        context: &mut C,
898    ) -> Result<(), C::Error>
899    where
900        C: Fallible + ArchiveContext + ?Sized,
901        C::Error: Source,
902        K: CheckBytes<C>,
903        V: CheckBytes<C>,
904    {
905        let node_ptr = node_rel_ptr.as_ptr_wrapping().cast::<Node<K, V, E>>();
906        context.check_subtree_ptr(
907            node_ptr.cast::<u8>(),
908            &Layout::new::<Node<K, V, E>>(),
909        )?;
910
911        // SAFETY: We checked to make sure that `node_ptr` is properly aligned
912        // and dereferenceable by calling `check_subtree_ptr`.
913        let kind_ptr = unsafe { addr_of!((*node_ptr).kind) };
914        // SAFETY: `kind_ptr` is a pointer to a subfield of `node_ptr` and so is
915        // also properly aligned and dereferenceable.
916        unsafe {
917            CheckBytes::check_bytes(kind_ptr, context)?;
918        }
919        // SAFETY: `kind_ptr` was always properly aligned and dereferenceable,
920        // and we just checked to make sure it pointed to a valid `NodeKind`.
921        let kind = unsafe { kind_ptr.read() };
922
923        match kind {
924            NodeKind::Leaf => {
925                // SAFETY:
926                // We checked to make sure that `node_ptr` is properly aligned,
927                // dereferenceable, and contained entirely within `context`'s
928                // buffer by calling `check_subtree_ptr`.
929                unsafe {
930                    check_leaf_node::<C, K, V, E>(node_ptr.cast(), context)?
931                }
932            }
933            NodeKind::Inner => {
934                // SAFETY:
935                // We checked to make sure that `node_ptr` is properly aligned
936                // and dereferenceable.
937                unsafe {
938                    check_inner_node::<C, K, V, E>(node_ptr.cast(), context)?
939                }
940            }
941        }
942
943        Ok(())
944    }
945
946    /// # Safety
947    ///
948    /// `node_ptr` must be properly aligned, dereferenceable, and contained
949    /// within `context`'s buffer.
950    unsafe fn check_leaf_node<C, K, V, const E: usize>(
951        node_ptr: *const LeafNode<K, V, E>,
952        context: &mut C,
953    ) -> Result<(), C::Error>
954    where
955        C: Fallible + ArchiveContext + ?Sized,
956        C::Error: Source,
957        K: CheckBytes<C>,
958        V: CheckBytes<C>,
959    {
960        context.in_subtree(node_ptr, |context| {
961            // SAFETY: We checked to make sure that `node_ptr` is properly
962            // aligned and dereferenceable by calling
963            // `check_subtree_ptr`.
964            let len_ptr = unsafe { addr_of!((*node_ptr).len) };
965            // SAFETY: `len_ptr` is a pointer to a subfield of `node_ptr` and so
966            // is also properly aligned and dereferenceable.
967            unsafe {
968                CheckBytes::check_bytes(len_ptr, context)?;
969            }
970            // SAFETY: `len_ptr` was always properly aligned and
971            // dereferenceable, and we just checked to make sure it
972            // pointed to a valid `ArchivedUsize`.
973            let len = unsafe { &*len_ptr };
974            let len = len.to_native() as usize;
975            if len > E {
976                fail!(InvalidLength { len, maximum: E });
977            }
978
979            // SAFETY: We checked that `node_ptr` is properly-aligned and
980            // dereferenceable.
981            let node_ptr = unsafe { addr_of!((*node_ptr).node) };
982            // SAFETY:
983            // - We checked that `node_ptr` is properly aligned and
984            //   dereferenceable.
985            // - We checked that `len` is less than or equal to `E`.
986            unsafe {
987                check_node_entries(node_ptr, len, context)?;
988            }
989
990            Ok(())
991        })
992    }
993
994    /// # Safety
995    ///
996    /// - `node_ptr` must point to a valid `Node<K, V, E>`.
997    /// - `len` must be less than or equal to `E`.
998    unsafe fn check_node_entries<C, K, V, const E: usize>(
999        node_ptr: *const Node<K, V, E>,
1000        len: usize,
1001        context: &mut C,
1002    ) -> Result<(), C::Error>
1003    where
1004        C: Fallible + ArchiveContext + ?Sized,
1005        C::Error: Source,
1006        K: CheckBytes<C>,
1007        V: CheckBytes<C>,
1008    {
1009        for i in 0..len {
1010            // SAFETY: The caller has guaranteed that `node_ptr` is properly
1011            // aligned and dereferenceable.
1012            let key_ptr = unsafe { addr_of!((*node_ptr).keys[i]).cast::<K>() };
1013            // SAFETY: The caller has guaranteed that `node_ptr` is properly
1014            // aligned and dereferenceable.
1015            let value_ptr =
1016                unsafe { addr_of!((*node_ptr).values[i]).cast::<V>() };
1017            unsafe {
1018                K::check_bytes(key_ptr, context)?;
1019            }
1020            // SAFETY: `value_ptr` is a subfield of a node, and so is guaranteed
1021            // to be properly aligned and point to enough bytes for a `V`.
1022            unsafe {
1023                V::check_bytes(value_ptr, context)?;
1024            }
1025        }
1026
1027        Ok(())
1028    }
1029
1030    /// # Safety
1031    ///
1032    /// - `node_ptr` must be properly aligned and dereferenceable.
1033    /// - `len` must be less than or equal to `E`.
1034    unsafe fn check_inner_node<C, K, V, const E: usize>(
1035        node_ptr: *const InnerNode<K, V, E>,
1036        context: &mut C,
1037    ) -> Result<(), C::Error>
1038    where
1039        C: Fallible + ArchiveContext + ?Sized,
1040        C::Error: Source,
1041        K: CheckBytes<C>,
1042        V: CheckBytes<C>,
1043    {
1044        context.in_subtree(node_ptr, |context| {
1045            for i in 0..E {
1046                // SAFETY: `in_subtree` guarantees that `node_ptr` is properly
1047                // aligned and dereferenceable.
1048                let lesser_node_ptr =
1049                    unsafe { addr_of!((*node_ptr).lesser_nodes[i]) };
1050                // SAFETY: `lesser_node_ptr` is a subfield of an inner node, and
1051                // so is guaranteed to be properly aligned and point to enough
1052                // bytes for a `RelPtr`.
1053                unsafe {
1054                    RelPtr::check_bytes(lesser_node_ptr, context)?;
1055                }
1056                // SAFETY: We just checked the `lesser_node_ptr` and it
1057                // succeeded, so it's safe to dereference.
1058                let lesser_node = unsafe { &*lesser_node_ptr };
1059                if !lesser_node.is_invalid() {
1060                    check_node_rel_ptr::<C, K, V, E>(lesser_node, context)?;
1061                }
1062            }
1063            // SAFETY: We checked that `node_ptr` is properly aligned and
1064            // dereferenceable.
1065            let greater_node_ptr =
1066                unsafe { addr_of!((*node_ptr).greater_node) };
1067            // SAFETY: `greater_node_ptr` is a subfield of an inner node, and so
1068            // is guaranteed to be properly aligned and point to enough bytes
1069            // for a `RelPtr`.
1070            unsafe {
1071                RelPtr::check_bytes(greater_node_ptr, context)?;
1072            }
1073            // SAFETY: We just checked the `greater_node_ptr` and it succeeded,
1074            // so it's safe to dereference.
1075            let greater_node = unsafe { &*greater_node_ptr };
1076            if !greater_node.is_invalid() {
1077                check_node_rel_ptr::<C, K, V, E>(greater_node, context)?;
1078            }
1079
1080            // SAFETY: We checked that `node_ptr` is properly aligned and
1081            // dereferenceable.
1082            let node_ptr = unsafe { addr_of!((*node_ptr).node) };
1083            // SAFETY:
1084            // - The caller has guaranteed that `node_ptr` points to a valid
1085            //   `Node<K, V, E>`.
1086            // - All inner nodes have `E` items, and `E` is less than or equal
1087            //   to `E`.
1088            unsafe {
1089                check_node_entries::<C, K, V, E>(node_ptr, E, context)?;
1090            }
1091
1092            Ok(())
1093        })
1094    }
1095}
1096
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use core::hash::{Hash, Hasher};

    use ahash::AHasher;

    use crate::{
        alloc::{collections::BTreeMap, string::ToString},
        api::test::to_archived,
        primitive::ArchivedU32,
    };

    /// Hashing an archived map must agree with hashing the source map's
    /// entries (key then value) in order.
    #[test]
    fn test_hash() {
        let mut original = BTreeMap::new();
        original.insert("a".to_string(), 1);
        original.insert("b".to_string(), 2);

        to_archived(&original, |archived| {
            let archived_hash = {
                let mut hasher = AHasher::default();
                archived.hash(&mut hasher);
                hasher.finish()
            };

            let expected_hash = {
                let mut hasher = AHasher::default();
                for (key, value) in &original {
                    key.hash(&mut hasher);
                    value.hash(&mut hasher);
                }
                hasher.finish()
            };

            assert_eq!(archived_hash, expected_hash);
        });
    }

    /// A full range over an empty archived map yields nothing.
    #[test]
    fn test_range_empty() {
        let empty = BTreeMap::<char, char>::new();
        to_archived(&empty, |archived| {
            let mut entries = archived.range(..);
            if entries.next().is_some() {
                panic!("ArchivedBTreeMap should be empty");
            }
        });
    }

    /// A range entirely above the single stored key yields nothing.
    #[test]
    fn test_range_one() {
        let mut map = BTreeMap::<i32, i32>::new();
        map.insert(1, 1);
        to_archived(&map, |archived| {
            let mut entries =
                archived.range_with(2.., |q, k| q.cmp(&k.to_native()));
            if entries.next().is_some() {
                panic!("ArchivedBTreeMap range should be empty");
            }
        })
    }

    /// Open-ended ranges falling entirely before or after the stored keys
    /// yield nothing.
    #[test]
    fn test_range_open() {
        let mut map = BTreeMap::new();
        ('a'..'z').for_each(|c| {
            map.insert(c, c);
        });

        to_archived(&map, |archived| {
            // Everything strictly below the smallest key.
            let mut below =
                archived.range_with(..'a', |q, k| q.cmp(&k.to_native()));
            if below.next().is_some() {
                panic!("Range should be empty");
            }
            // '|' sorts above 'z', so this starts past the largest key.
            let mut above =
                archived.range_with('|'.., |q, k| q.cmp(&k.to_native()));
            if above.next().is_some() {
                panic!("Range should be empty");
            }
        });
    }

    /// `range_with` using a custom comparator over string keys yields the
    /// expected contiguous slice of entries.
    #[test]
    fn test_range_str() {
        let mut map = BTreeMap::new();
        ('a'..'z').for_each(|c| {
            map.insert(c.to_string(), c.to_string());
        });

        to_archived(&map, |archived| {
            let (start, end) = ('d', 'w');

            // Compare a `char` query against each key's first character.
            let entries = archived.range_with(start..end, |q, k| {
                q.cmp(&k.chars().next().unwrap())
            });
            for ((key, value), expected_char) in entries.zip(start..end) {
                let expected = expected_char.to_string();
                assert_eq!(key.as_str(), expected);
                assert_eq!(value.as_str(), expected);
            }
        });
    }

    /// `range` with archived integer bounds yields the expected entries.
    #[test]
    fn test_range_u32() {
        let mut map = BTreeMap::new();
        for i in 0u32..200 {
            map.insert(i, i);
        }

        to_archived(&map, |archived| {
            const START: u32 = 32;
            const END: u32 = 100;

            let entries = archived.range(
                ArchivedU32::from_native(START)..ArchivedU32::from_native(END),
            );
            for ((key, value), expected) in entries.zip(START..END) {
                assert_eq!(key.to_native(), expected);
                assert_eq!(value.to_native(), expected);
            }
        });
    }
}