1use core::{
21 alloc::Layout,
22 borrow::Borrow,
23 error::Error,
24 fmt,
25 marker::PhantomData,
26 mem::{size_of, MaybeUninit},
27 ptr::{self, null, NonNull},
28 slice::from_raw_parts,
29};
30
31use munge::munge;
32use rancor::{fail, Fallible, ResultExt as _, Source};
33
34use crate::{
35 collections::util::IteratorLengthMismatch,
36 primitive::{ArchivedUsize, FixedUsize},
37 seal::Seal,
38 ser::{Allocator, Writer, WriterExt},
39 simd::{Bitmask, Group, MAX_GROUP_WIDTH},
40 util::SerVec,
41 Archive as _, Place, Portable, RawRelPtr, Serialize,
42};
43
/// An archived SwissTable-style hash table.
///
/// Layout in the archive: the buckets (`T` values) are written immediately
/// *before* the control bytes, and `ptr` targets the first control byte.
/// Buckets are therefore addressed by subtracting from `ptr` (see
/// `bucket_raw`), while control bytes are addressed forward from it (see
/// `control_raw`).
#[derive(Portable)]
#[cfg_attr(
    feature = "bytecheck",
    derive(bytecheck::CheckBytes),
    bytecheck(verify)
)]
#[rkyv(crate)]
#[repr(C)]
pub struct ArchivedHashTable<T> {
    // Relative pointer to the control bytes; the buckets precede this target.
    ptr: RawRelPtr,
    // Number of occupied entries.
    len: ArchivedUsize,
    // Number of buckets.
    cap: ArchivedUsize,
    _phantom: PhantomData<T>,
}
59
/// Primary hash: the full 64-bit hash, used to choose the starting probe
/// position (reduced modulo the capacity in `probe_seq`).
#[inline]
fn h1(hash: u64) -> u64 {
    hash
}
64
/// Secondary hash: the uppermost 7 bits of the hash, stored as the control
/// byte for an occupied bucket.
#[inline]
fn h2(hash: u64) -> u8 {
    // 64 - 57 = 7 bits survive the shift, so the result always fits in the
    // low 7 bits of a byte.
    let top_bits = hash >> 57;
    top_bits as u8
}
69
70struct ProbeSeq {
71 pos: usize,
72 stride: usize,
73}
74
75impl ProbeSeq {
76 #[inline]
77 fn move_next(&mut self, bucket_mask: usize) {
78 self.stride += MAX_GROUP_WIDTH;
79 self.pos += self.stride;
80 self.pos &= bucket_mask;
81 }
82}
83
impl<T> ArchivedHashTable<T> {
    /// Returns the probe sequence for `hash` in a table with `capacity`
    /// buckets, starting at `h1(hash) % capacity` with stride zero.
    fn probe_seq(hash: u64, capacity: usize) -> ProbeSeq {
        ProbeSeq {
            pos: (h1(hash) % capacity as u64) as usize,
            stride: 0,
        }
    }

    /// Returns a pointer to the control byte at `index`.
    ///
    /// # Safety
    ///
    /// `this` must point to a valid, non-empty table whose relative pointer
    /// resolves inside its control bytes, and `index` must be within the
    /// table's control byte count.
    unsafe fn control_raw(this: *mut Self, index: usize) -> *const u8 {
        debug_assert!(unsafe { !(*this).is_empty() });

        // `ptr` targets the first control byte; control bytes run forward
        // from there.
        let ptr =
            unsafe { RawRelPtr::as_ptr_raw(ptr::addr_of_mut!((*this).ptr)) };
        unsafe { ptr.cast::<u8>().add(index) }
    }

    /// Returns a pointer to the bucket at `index`.
    ///
    /// # Safety
    ///
    /// `this` must point to a valid, non-empty table and `index` must be
    /// less than the table's capacity.
    unsafe fn bucket_raw(this: *mut Self, index: usize) -> NonNull<T> {
        // Buckets are stored in reverse order immediately *before* the
        // control bytes, so bucket `index` lives at `ptr - (index + 1)`
        // in units of `T`.
        unsafe {
            NonNull::new_unchecked(
                RawRelPtr::as_ptr_raw(ptr::addr_of_mut!((*this).ptr))
                    .cast::<T>()
                    .sub(index + 1),
            )
        }
    }

    /// Mask used to wrap probe positions: one less than the next power of
    /// two of `capacity`.
    ///
    /// Panics if the next power of two overflows `usize`.
    fn bucket_mask(capacity: usize) -> usize {
        capacity.checked_next_power_of_two().unwrap() - 1
    }

    /// Finds the bucket whose control byte matches `h2(hash)` and whose
    /// value satisfies `cmp`, returning a pointer to it.
    ///
    /// # Safety
    ///
    /// `this` must point to a valid table.
    unsafe fn get_entry_raw<C>(
        this: *mut Self,
        hash: u64,
        cmp: C,
    ) -> Option<NonNull<T>>
    where
        C: Fn(&T) -> bool,
    {
        let is_empty = unsafe { (*this).is_empty() };
        if is_empty {
            return None;
        }

        let capacity = unsafe { (*this).capacity() };
        let probe_cap = Self::probe_cap(capacity);
        let control_count = Self::control_count(probe_cap);

        let h2_hash = h2(hash);
        let mut probe_seq = Self::probe_seq(hash, capacity);

        // NOTE(review): the wrap mask is derived from `control_count`, which
        // matches the mask used at serialization time below.
        let bucket_mask = Self::bucket_mask(control_count);

        loop {
            let mut any_empty = false;

            // Scan a maximal group one SIMD group at a time.
            for i in 0..MAX_GROUP_WIDTH / Group::WIDTH {
                let pos = probe_seq.pos + i * Group::WIDTH;

                let group =
                    unsafe { Group::read(Self::control_raw(this, pos)) };

                // Every control byte equal to `h2(hash)` is a candidate;
                // confirm with the caller's comparison.
                for bit in group.match_byte(h2_hash) {
                    // Control positions past `capacity` are mirrors of the
                    // low positions, hence the wrap back into bucket range.
                    let index = (pos + bit) % capacity;
                    let bucket_ptr = unsafe { Self::bucket_raw(this, index) };
                    let bucket = unsafe { bucket_ptr.as_ref() };

                    if cmp(bucket) {
                        return Some(bucket_ptr);
                    }
                }

                any_empty = any_empty || group.match_empty().any_bit_set();
            }

            // An empty slot in a probed group means insertion would have
            // stopped here too, so the key is absent.
            if any_empty {
                return None;
            }

            // Advance the probe sequence, skipping positions that fall
            // outside the probed region `[0, probe_cap)`.
            loop {
                probe_seq.move_next(bucket_mask);
                if probe_seq.pos < probe_cap {
                    break;
                }
            }
        }
    }

    /// Returns a reference to the entry with the given hash for which `cmp`
    /// returns `true`, if any.
    pub fn get_with<C>(&self, hash: u64, cmp: C) -> Option<&T>
    where
        C: Fn(&T) -> bool,
    {
        let this = (self as *const Self).cast_mut();
        let ptr = unsafe { Self::get_entry_raw(this, hash, |e| cmp(e))? };
        Some(unsafe { ptr.as_ref() })
    }

    /// Returns a sealed mutable reference to the entry with the given hash
    /// for which `cmp` returns `true`, if any.
    pub fn get_seal_with<C>(
        this: Seal<'_, Self>,
        hash: u64,
        cmp: C,
    ) -> Option<Seal<'_, T>>
    where
        C: Fn(&T) -> bool,
    {
        let mut ptr = unsafe {
            Self::get_entry_raw(this.unseal_unchecked(), hash, |e| cmp(e))?
        };
        Some(Seal::new(unsafe { ptr.as_mut() }))
    }

    /// Returns `true` if the hash table contains no entries.
    pub const fn is_empty(&self) -> bool {
        self.len.to_native() == 0
    }

    /// Returns the number of entries in the hash table.
    pub const fn len(&self) -> usize {
        self.len.to_native() as usize
    }

    /// Returns the number of buckets in the hash table.
    pub fn capacity(&self) -> usize {
        self.cap.to_native() as usize
    }

    /// Creates a control-byte iterator positioned at the first group.
    ///
    /// # Safety
    ///
    /// `this` must point to a valid, non-empty table.
    unsafe fn control_iter(this: *mut Self) -> ControlIter {
        ControlIter {
            current_mask: unsafe {
                Group::read(Self::control_raw(this, 0)).match_full()
            },
            next_group: unsafe { Self::control_raw(this, Group::WIDTH) },
        }
    }

    /// Returns a raw iterator over the entry pointers of this table.
    pub fn raw_iter(&self) -> RawIter<T> {
        if self.is_empty() {
            RawIter::empty()
        } else {
            let this = (self as *const Self).cast_mut();
            RawIter {
                controls: unsafe { Self::control_iter(this) },
                // `entries` starts at the control bytes; buckets are read by
                // subtracting from it (see `RawIter::next`).
                entries: unsafe {
                    NonNull::new_unchecked(self.ptr.as_ptr().cast_mut().cast())
                },
                items_left: self.len(),
            }
        }
    }

    /// Returns a raw iterator over the entry pointers of a sealed table.
    pub fn raw_iter_seal(mut this: Seal<'_, Self>) -> RawIter<T> {
        if this.is_empty() {
            RawIter::empty()
        } else {
            let controls =
                unsafe { Self::control_iter(this.as_mut().unseal_unchecked()) };
            let items_left = this.len();
            munge!(let Self { ptr, .. } = this);
            RawIter {
                controls,
                entries: unsafe {
                    NonNull::new_unchecked(RawRelPtr::as_mut_ptr(ptr).cast())
                },
                items_left,
            }
        }
    }

    /// Computes the bucket count for `len` entries at the given load factor,
    /// always leaving at least one spare bucket so the table is never
    /// completely full.
    fn capacity_from_len(len: usize, load_factor: (usize, usize)) -> usize {
        if len == 0 {
            0
        } else {
            usize::max(len * load_factor.1 / load_factor.0, len + 1)
        }
    }

    /// Rounds `capacity` up to a multiple of `MAX_GROUP_WIDTH`; probe
    /// positions are confined to `[0, probe_cap)`.
    fn probe_cap(capacity: usize) -> usize {
        capacity.next_multiple_of(MAX_GROUP_WIDTH)
    }

    /// Total number of control bytes: `probe_cap` plus `MAX_GROUP_WIDTH - 1`
    /// trailing bytes so a full group can be read starting from any probe
    /// position.
    fn control_count(probe_cap: usize) -> usize {
        probe_cap + MAX_GROUP_WIDTH - 1
    }

    /// Computes the combined layout of the buckets followed by the control
    /// bytes, returning the layout and the byte offset of the control bytes.
    // Only referenced by the `bytecheck`-gated validation module.
    #[allow(dead_code)]
    fn memory_layout<E: Source>(
        capacity: usize,
        control_count: usize,
    ) -> Result<(Layout, usize), E> {
        let buckets_layout = Layout::array::<T>(capacity).into_error()?;
        let control_layout = Layout::array::<u8>(control_count).into_error()?;
        buckets_layout.extend(control_layout).into_error()
    }

    /// Serializes a hash table from an iterator of entries and a matching
    /// iterator of their hashes.
    ///
    /// `load_factor` is a `(numerator, denominator)` fraction in `(0, 1]`
    /// giving the maximum fill ratio used to size the table.
    ///
    /// Returns an error if the load factor is invalid, if `items` misreports
    /// its length, or if serialization fails.
    pub fn serialize_from_iter<I, U, H, S>(
        items: I,
        hashes: H,
        load_factor: (usize, usize),
        serializer: &mut S,
    ) -> Result<HashTableResolver, S::Error>
    where
        I: Clone + ExactSizeIterator,
        I::Item: Borrow<U>,
        U: Serialize<S, Archived = T>,
        H: ExactSizeIterator<Item = u64>,
        S: Fallible + Writer + Allocator + ?Sized,
        S::Error: Source,
    {
        #[derive(Debug)]
        struct InvalidLoadFactor {
            numerator: usize,
            denominator: usize,
        }

        impl fmt::Display for InvalidLoadFactor {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(
                    f,
                    "invalid load factor {} / {}, load factor must be a \
                     fraction in the range (0, 1]",
                    self.numerator, self.denominator
                )
            }
        }

        impl Error for InvalidLoadFactor {}

        // The fraction must be well-formed and at most one.
        if load_factor.0 == 0
            || load_factor.1 == 0
            || load_factor.0 > load_factor.1
        {
            fail!(InvalidLoadFactor {
                numerator: load_factor.0,
                denominator: load_factor.1,
            });
        }

        let len = items.len();

        if len == 0 {
            // An `ExactSizeIterator` claiming zero length must actually be
            // empty; anything else is a caller bug worth surfacing.
            let count = items.count();
            if count != 0 {
                fail!(IteratorLengthMismatch {
                    expected: 0,
                    actual: count,
                });
            }

            // Empty tables store no out-of-line data.
            return Ok(HashTableResolver { pos: 0 });
        }

        let capacity = Self::capacity_from_len(len, load_factor);
        let probe_cap = Self::probe_cap(capacity);
        let control_count = Self::control_count(probe_cap);

        SerVec::with_capacity(
            serializer,
            capacity,
            |ordered_items, serializer| {
                // One slot per bucket; `None` marks an unoccupied bucket.
                for _ in 0..capacity {
                    unsafe {
                        ordered_items.push_unchecked(None);
                    }
                }

                SerVec::<u8>::with_capacity(
                    serializer,
                    control_count,
                    |control_bytes, serializer| {
                        // Initialize every control byte to 0xff, the empty
                        // marker that `match_empty` searches for below.
                        unsafe {
                            control_bytes
                                .as_mut_ptr()
                                .write_bytes(0xff, control_bytes.capacity());
                            control_bytes.set_len(control_bytes.capacity());
                        }

                        let bucket_mask = Self::bucket_mask(control_count);

                        // Insert each entry: probe for the first empty slot,
                        // record its control byte, and remember which bucket
                        // the entry occupies.
                        for (item, hash) in items.zip(hashes) {
                            let h2_hash = h2(hash);
                            let mut probe_seq = Self::probe_seq(hash, capacity);

                            'insert: loop {
                                for i in 0..MAX_GROUP_WIDTH / Group::WIDTH {
                                    let pos = probe_seq.pos + i * Group::WIDTH;
                                    let group = unsafe {
                                        Group::read(
                                            control_bytes.as_ptr().add(pos),
                                        )
                                    };

                                    if let Some(bit) =
                                        group.match_empty().lowest_set_bit()
                                    {
                                        let index = (pos + bit) % capacity;

                                        control_bytes[index] = h2_hash;
                                        // Mirror low control bytes past the
                                        // end so group reads that wrap see
                                        // the same values.
                                        if index < (control_count - capacity) {
                                            control_bytes[capacity + index] =
                                                h2_hash;
                                        }

                                        ordered_items[index] = Some(item);
                                        break 'insert;
                                    }
                                }

                                // Skip probe positions outside
                                // `[0, probe_cap)`.
                                loop {
                                    probe_seq.move_next(bucket_mask);
                                    if probe_seq.pos < probe_cap {
                                        break;
                                    }
                                }
                            }
                        }

                        // A zeroed `T`-sized byte pattern, written in place
                        // of each unoccupied bucket.
                        let mut zeros = MaybeUninit::<T>::uninit();
                        unsafe {
                            zeros.as_mut_ptr().write_bytes(0, 1);
                        }
                        let zeros = unsafe {
                            from_raw_parts(
                                zeros.as_ptr().cast::<u8>(),
                                size_of::<T>(),
                            )
                        };
                        SerVec::with_capacity(
                            serializer,
                            len,
                            |resolvers, serializer| {
                                // Serialize entries in bucket order,
                                // collecting their resolvers.
                                for item in ordered_items
                                    .iter()
                                    .filter_map(|x| x.as_ref())
                                {
                                    resolvers.push(
                                        item.borrow().serialize(serializer)?,
                                    );
                                }

                                serializer.align_for::<T>()?;

                                // Buckets are laid out back-to-front ending
                                // at the control bytes, so emit them in
                                // reverse bucket order.
                                let mut resolvers = resolvers.drain().rev();
                                for item in ordered_items.iter().rev() {
                                    if let Some(item) = item {
                                        unsafe {
                                            serializer.resolve_aligned(
                                                item.borrow(),
                                                resolvers.next().unwrap(),
                                            )?;
                                        }
                                    } else {
                                        serializer.write(zeros)?;
                                    }
                                }

                                // The relative pointer targets the control
                                // bytes, which begin at this position.
                                let pos = serializer.pos();
                                serializer.write(control_bytes)?;

                                Ok(HashTableResolver {
                                    pos: pos as FixedUsize,
                                })
                            },
                        )?
                    },
                )?
            },
        )?
    }

    /// Resolves an archived hash table from its length, load factor, and
    /// resolver into `out`.
    pub fn resolve_from_len(
        len: usize,
        load_factor: (usize, usize),
        resolver: HashTableResolver,
        out: Place<Self>,
    ) {
        munge!(let Self { ptr, len: out_len, cap, _phantom: _ } = out);

        if len == 0 {
            // Empty tables carry no data; the pointer is never dereferenced.
            RawRelPtr::emplace_invalid(ptr);
        } else {
            // Point at the control bytes written by `serialize_from_iter`.
            RawRelPtr::emplace(resolver.pos as usize, ptr);
        }

        len.resolve((), out_len);

        // Must recompute the same capacity the serializer used.
        let capacity = Self::capacity_from_len(len, load_factor);
        capacity.resolve((), cap);
    }
}
511
/// The resolver for an [`ArchivedHashTable`].
pub struct HashTableResolver {
    // Position of the control bytes in the archive; `resolve_from_len`
    // emplaces the table's relative pointer to target this position.
    pos: FixedUsize,
}
516
/// Cursor over groups of control bytes, yielding the full (occupied) slots
/// of one group at a time.
struct ControlIter {
    // Bitmask of full slots not yet yielded from the current group.
    current_mask: Bitmask,
    // Pointer to the first byte of the next group to load.
    next_group: *const u8,
}
521
// SAFETY(review): `ControlIter` only holds a `Bitmask` and a raw `*const u8`
// into archived control bytes, and only ever reads through the pointer.
// Presumably the backing archive buffer outlives the iterator and is not
// mutated through it, making cross-thread use sound — TODO confirm against
// the callers that construct these iterators.
unsafe impl Send for ControlIter {}
unsafe impl Sync for ControlIter {}
524
525impl ControlIter {
526 fn none() -> Self {
527 Self {
528 current_mask: Bitmask::EMPTY,
529 next_group: null(),
530 }
531 }
532
533 #[inline]
534 fn next_full(&mut self) -> Option<usize> {
535 let bit = self.current_mask.lowest_set_bit()?;
536 self.current_mask = self.current_mask.remove_lowest_bit();
537 Some(bit)
538 }
539
540 #[inline]
541 fn move_next(&mut self) {
542 self.current_mask =
543 unsafe { Group::read(self.next_group).match_full() };
544 self.next_group = unsafe { self.next_group.add(Group::WIDTH) };
545 }
546}
547
/// A raw iterator over the entry pointers of an [`ArchivedHashTable`].
pub struct RawIter<T> {
    // Walks the control bytes to locate occupied slots.
    controls: ControlIter,
    // Base pointer for the current group; entries are addressed by
    // subtracting from it, since buckets are stored back-to-front before the
    // control bytes.
    entries: NonNull<T>,
    // Number of occupied entries not yet yielded.
    items_left: usize,
}
554
555impl<T> RawIter<T> {
556 pub fn empty() -> Self {
558 Self {
559 controls: ControlIter::none(),
560 entries: NonNull::dangling(),
561 items_left: 0,
562 }
563 }
564}
565
impl<T> Iterator for RawIter<T> {
    type Item = NonNull<T>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.items_left == 0 {
            None
        } else {
            // Find the next full slot, advancing one group at a time.
            // Buckets are stored back-to-front before the control bytes, so
            // stepping to the next group moves `entries` *down* by one group
            // width.
            let bit = loop {
                if let Some(bit) = self.controls.next_full() {
                    break bit;
                }
                self.controls.move_next();
                self.entries = unsafe {
                    NonNull::new_unchecked(
                        self.entries.as_ptr().sub(Group::WIDTH),
                    )
                };
            };
            self.items_left -= 1;
            // Slot `bit` of the current group lives at `entries - (bit + 1)`,
            // mirroring `bucket_raw`'s addressing.
            let entry = unsafe {
                NonNull::new_unchecked(self.entries.as_ptr().sub(bit + 1))
            };
            Some(entry)
        }
    }
}
592
impl<T> ExactSizeIterator for RawIter<T> {
    /// Returns the exact number of entries remaining.
    fn len(&self) -> usize {
        self.items_left
    }
}
598
#[cfg(feature = "bytecheck")]
mod verify {
    use core::{error::Error, fmt};

    use bytecheck::{CheckBytes, Verify};
    use rancor::{fail, Fallible, Source};

    use super::ArchivedHashTable;
    use crate::{
        simd::Group,
        validation::{ArchiveContext, ArchiveContextExt as _},
    };

    /// Error raised when a table's length is not strictly less than its
    /// capacity.
    #[derive(Debug)]
    struct InvalidLength {
        len: usize,
        cap: usize,
    }

    impl fmt::Display for InvalidLength {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(
                f,
                "hash table length must be strictly less than its capacity \
                 (length: {}, capacity: {})",
                self.len, self.cap,
            )
        }
    }

    impl Error for InvalidLength {}

    /// Error raised when a mirrored ("wrapped") control byte does not equal
    /// the canonical control byte it duplicates.
    #[derive(Debug)]
    struct UnwrappedControlByte {
        index: usize,
    }

    impl fmt::Display for UnwrappedControlByte {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "unwrapped control byte at index {}", self.index,)
        }
    }

    impl Error for UnwrappedControlByte {}

    unsafe impl<C, T> Verify<C> for ArchivedHashTable<T>
    where
        C: Fallible + ArchiveContext + ?Sized,
        C::Error: Source,
        T: CheckBytes<C>,
    {
        fn verify(&self, context: &mut C) -> Result<(), C::Error> {
            let len = self.len();
            let cap = self.capacity();

            // An empty table stores no out-of-line data, so there is nothing
            // further to check.
            if len == 0 && cap == 0 {
                return Ok(());
            }

            if len >= cap {
                fail!(InvalidLength { len, cap });
            }

            let probe_cap = Self::probe_cap(cap);
            let control_count = Self::control_count(probe_cap);
            let (layout, control_offset) =
                Self::memory_layout(cap, control_count)?;
            // `ptr` targets the control bytes; the buckets precede them, so
            // the table's allocation begins `control_offset` bytes earlier.
            let ptr = self
                .ptr
                .as_ptr_wrapping()
                .cast::<u8>()
                .wrapping_sub(control_offset);

            context.in_subtree_raw(ptr, layout, |context| {
                let this = (self as *const Self).cast_mut();
                // Check every bucket that the control bytes mark as full.
                let mut controls = unsafe { Self::control_iter(this) };
                let mut base_index = 0;
                'outer: while base_index < cap {
                    while let Some(bit) = controls.next_full() {
                        let index = base_index + bit;
                        if index >= cap {
                            break 'outer;
                        }

                        unsafe {
                            T::check_bytes(
                                Self::bucket_raw(this, index).as_ptr(),
                                context,
                            )?;
                        }
                    }

                    controls.move_next();
                    base_index += Group::WIDTH;
                }

                // The serializer mirrors control byte `i` to position
                // `cap + i` whenever `cap + i < control_count`, so the
                // mirrored positions are exactly
                // `cap..min(2 * cap, control_count)`. Verify that region so a
                // lookup can never land on a full mirrored byte whose bucket
                // was not checked above.
                //
                // Fix: the previous upper bound `control_count - cap` was an
                // index-space bound used in position space; whenever
                // `cap > control_count / 2` (the common case) it made this
                // range empty and silently skipped the check. Every archive
                // produced by `serialize_from_iter` satisfies the corrected
                // check: occupied mirrors carry the same `h2` byte and
                // unoccupied positions are 0xff on both sides.
                for i in cap..usize::min(2 * cap, control_count) {
                    let byte = unsafe { *Self::control_raw(this, i) };
                    let wrapped = unsafe { *Self::control_raw(this, i % cap) };
                    if wrapped != byte {
                        fail!(UnwrappedControlByte { index: i })
                    }
                }

                Ok(())
            })
        }
    }
}