1use super::arrays::SimpleArrayView;
6use super::parser::{PolicyCursor, PolicyData, PolicyOffset};
7use super::{Counted, Parse, PolicyValidationContext, Validate};
8
9use hashbrown::hash_table::HashTable;
10use static_assertions::const_assert;
11use std::fmt::Debug;
12use std::hash::{DefaultHasher, Hash, Hasher};
13use std::marker::PhantomData;
14use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned};
15
/// Implemented by policy objects whose encoding begins with a fixed-size metadata header.
///
/// `View::metadata` reads this header from the same start offset as the object itself, so the
/// header must be the first thing in the object's byte representation.
pub trait HasMetadata {
    /// The header type, read directly from the policy bytes via `zerocopy::FromBytes`.
    type Metadata: FromBytes + Sized;
}
25
/// Implemented by policy objects whose encoded extent can be found by scanning the policy data.
pub trait Walk {
    /// Returns the offset immediately following the object that starts at `offset` in
    /// `policy_data`.
    ///
    /// Iterators in this module use the returned offset both as the end of the current
    /// object and as the start of the next one.
    fn walk(policy_data: &PolicyData, offset: PolicyOffset) -> PolicyOffset;
}
35
/// A typed view of a `T` occupying the byte range `[start, end)` of the policy data.
///
/// A `View` stores only offsets. The bytes are read (`read`) or parsed (`parse`) on demand
/// from a `PolicyData` supplied by the caller.
#[derive(Debug, Clone, Copy)]
pub struct View<T> {
    /// Records the viewed type without storing a value of it.
    phantom: PhantomData<T>,

    /// Offset of the first byte of the viewed object.
    start: PolicyOffset,

    /// Offset one past the last byte of the viewed object.
    end: PolicyOffset,
}
50
51impl<T> View<T> {
52 pub fn new(start: PolicyOffset, end: PolicyOffset) -> Self {
54 Self { phantom: PhantomData, start, end }
55 }
56
57 fn start(&self) -> PolicyOffset {
59 self.start
60 }
61}
62
63impl<T: Sized> View<T> {
64 pub fn at(start: PolicyOffset) -> Self {
68 let end = start + std::mem::size_of::<T>() as u32;
69 Self::new(start, end)
70 }
71}
72
73impl<T: FromBytes + Sized> View<T> {
74 pub fn read(&self, policy_data: &PolicyData) -> T {
81 debug_assert_eq!(self.end - self.start, std::mem::size_of::<T>() as u32);
82 let start = self.start as usize;
83 let end = self.end as usize;
84 T::read_from_bytes(&policy_data[start..end]).unwrap()
85 }
86}
87
88impl<T: HasMetadata> View<T> {
89 pub fn metadata(&self) -> View<T::Metadata> {
93 View::<T::Metadata>::at(self.start)
94 }
95
96 pub fn read_metadata(&self, policy_data: &PolicyData) -> T::Metadata {
98 self.metadata().read(policy_data)
99 }
100}
101
102impl<T: Parse> View<T> {
103 pub fn parse(&self, policy_data: &PolicyData) -> T {
109 let cursor = PolicyCursor::new_at(policy_data.clone(), self.start);
110 let (object, _) =
111 T::parse(cursor).map_err(Into::<anyhow::Error>::into).expect("policy should be valid");
112 object
113 }
114}
115
116impl<T: Validate + Parse> Validate for View<T> {
117 type Error = anyhow::Error;
118
119 fn validate(&self, context: &mut PolicyValidationContext) -> Result<(), Self::Error> {
120 let object = self.parse(&context.data);
121 object.validate(context).map_err(Into::<anyhow::Error>::into)
122 }
123}
124
/// A view of `count` consecutive `D` elements whose data begins at `start`.
///
/// Element boundaries are not stored. `ArrayDataView::iter` finds them on the fly with
/// `Walk::walk`.
#[derive(Debug, Clone, Copy)]
pub struct ArrayDataView<D> {
    /// Records the element type without storing a value of it.
    phantom: PhantomData<D>,
    /// Offset of the first element.
    start: PolicyOffset,
    /// Number of elements in the array.
    count: u32,
}
135
136impl<D> ArrayDataView<D> {
137 pub fn new(start: PolicyOffset, count: u32) -> Self {
139 Self { phantom: PhantomData, start, count }
140 }
141
142 pub fn iter(self, policy_data: &PolicyData) -> ArrayDataViewIter<D> {
149 ArrayDataViewIter::new(policy_data.clone(), self.start, self.count)
150 }
151}
152
/// Iterator over the elements of an `ArrayDataView`, yielding a `View<D>` per element.
pub struct ArrayDataViewIter<D> {
    /// Records the element type without storing a value of it.
    phantom: PhantomData<D>,
    /// The iterator's own handle to the policy bytes (cloned in `ArrayDataView::iter`).
    policy_data: PolicyData,
    /// Offset at which the next element to be yielded begins.
    offset: PolicyOffset,
    /// Number of elements not yet yielded.
    remaining: u32,
}
163
164impl<T> ArrayDataViewIter<T> {
165 fn new(policy_data: PolicyData, offset: PolicyOffset, remaining: u32) -> Self {
167 Self { phantom: PhantomData, policy_data, offset, remaining }
168 }
169}
170
171impl<D: Walk> std::iter::Iterator for ArrayDataViewIter<D> {
172 type Item = View<D>;
173
174 fn next(&mut self) -> Option<Self::Item> {
175 if self.remaining > 0 {
176 let start = self.offset;
177 self.offset = D::walk(&self.policy_data, start);
178 self.remaining -= 1;
179 Some(View::new(start, self.offset))
180 } else {
181 None
182 }
183 }
184}
185
/// A view of an array encoded as a fixed-size metadata header `M` followed immediately by
/// `count` elements of type `D`.
#[derive(Debug, Clone, Copy)]
pub(super) struct ArrayView<M, D> {
    /// Records the header and element types without storing values of them.
    phantom: PhantomData<(M, D)>,
    /// Offset of the metadata header (and of the array as a whole).
    start: PolicyOffset,
    /// Number of `D` elements following the header.
    count: u32,
}
196
197impl<M, D> ArrayView<M, D> {
198 pub fn new(start: PolicyOffset, count: u32) -> Self {
200 Self { phantom: PhantomData, start, count }
201 }
202}
203
204impl<M: Sized, D> ArrayView<M, D> {
205 pub fn metadata(&self) -> View<M> {
207 View::<M>::at(self.start)
208 }
209
210 pub fn data(&self) -> ArrayDataView<D> {
212 ArrayDataView::new(self.metadata().end, self.count)
213 }
214}
215
216fn parse_array_data<D: Parse>(
217 cursor: PolicyCursor,
218 count: u32,
219) -> Result<PolicyCursor, anyhow::Error> {
220 let mut tail = cursor;
221 for _ in 0..count {
222 let (_, next) = D::parse(tail).map_err(Into::<anyhow::Error>::into)?;
223 tail = next;
224 }
225 Ok(tail)
226}
227
228impl<M: Counted + Parse + Sized, D: Parse> Parse for ArrayView<M, D> {
229 type Error = anyhow::Error;
232
233 fn parse(cursor: PolicyCursor) -> Result<(Self, PolicyCursor), Self::Error> {
234 let start = cursor.offset();
235 let (metadata, cursor) = M::parse(cursor).map_err(Into::<anyhow::Error>::into)?;
236 let count = metadata.count();
237 let cursor = parse_array_data::<D>(cursor, count)?;
238 Ok((Self::new(start, count), cursor))
239 }
240}
241
/// Iterator over the contiguous run of elements whose metadata equals `metadata`, starting at
/// `offset`.
struct HashedArrayViewEntryIter<'a, D: HasMetadata> {
    /// Policy bytes the elements are read from.
    policy_data: &'a PolicyData,
    /// Iteration stops once `offset` reaches this offset (the end of the array data).
    limit: PolicyOffset,
    /// The metadata value every yielded element must carry.
    metadata: D::Metadata,
    /// Offset of the next candidate element. `None` once a mismatching element has been seen
    /// (or if no starting element was found).
    offset: Option<PolicyOffset>,
}
250
/// A 24-bit unsigned integer stored as three bytes, least-significant byte first, with byte
/// alignment.
///
/// Used as the element type of `HashedArrayView::index`, where it holds policy offsets.
#[derive(Clone, Copy, Debug, KnownLayout, FromBytes, Immutable, PartialEq, Unaligned)]
#[repr(C, packed)]
pub(super) struct U24 {
    /// Bits 0..8.
    low: u8,
    /// Bits 8..16.
    middle: u8,
    /// Bits 16..24.
    high: u8,
}
264
// Pin the layout `#[repr(C, packed)]` is meant to give `U24`: exactly three bytes, alignment 1.
// `U24` is the element type of `HashedArrayView::index`, so its size directly sets the index's
// per-entry footprint.
const_assert!(std::mem::size_of::<U24>() == 3);
const_assert!(std::mem::align_of::<U24>() == 1);
268
269impl TryFrom<u32> for U24 {
270 type Error = ();
271
272 fn try_from(value: u32) -> Result<Self, Self::Error> {
273 if 0x1000000 <= value {
274 Err(())
275 } else {
276 Ok(Self {
277 low: (value & 0xff) as u8,
278 middle: ((value >> 8) & 0xff) as u8,
279 high: ((value >> 16) & 0xff) as u8,
280 })
281 }
282 }
283}
284
285impl From<U24> for u32 {
286 fn from(value: U24) -> u32 {
287 ((value.high as u32) << 16) + ((value.middle as u32) << 8) + (value.low as u32)
288 }
289}
290
291impl<'a, D: HasMetadata + Walk> Iterator for HashedArrayViewEntryIter<'a, D>
292where
293 D::Metadata: Eq,
294{
295 type Item = View<D>;
296
297 fn next(&mut self) -> Option<Self::Item> {
298 if let Some(offset) = self.offset
299 && offset < self.limit
300 {
301 let element = View::<D>::at(offset);
302 let metadata = element.read_metadata(&self.policy_data);
303 if metadata == self.metadata {
304 self.offset = Some(D::walk(&self.policy_data, offset));
305 Some(element)
306 } else {
307 self.offset = None;
308 None
309 }
310 } else {
311 None
312 }
313 }
314}
315
/// An array of `D` elements together with a hash index keyed by each element's metadata.
#[derive(Debug, Clone)]
pub(super) struct HashedArrayView<D: HasMetadata> {
    /// Records the element type without storing a value of it.
    phantom: PhantomData<D>,
    /// Maps each distinct metadata value to the offset of the first element carrying it. The
    /// key is hashed with `metadata_hash`, and the metadata is re-read from the policy bytes
    /// when comparing.
    index: HashTable<U24>,
    /// Offset just past the array's element data; entry iteration never goes beyond it.
    limit: PolicyOffset,
}
329
330impl<D: HasMetadata> HashedArrayView<D>
331where
332 D::Metadata: Hash,
333{
334 fn metadata_hash(metadata: &D::Metadata) -> u64 {
335 let mut hasher = DefaultHasher::new();
336 metadata.hash(&mut hasher);
337 hasher.finish()
338 }
339}
340
341impl<D: Parse + HasMetadata + Walk> HashedArrayView<D>
342where
343 D::Metadata: Eq + PartialEq + Hash + Debug,
344{
345 pub fn find(&self, key: D::Metadata, policy_data: &PolicyData) -> Option<D> {
350 let key_hash = Self::metadata_hash(&key);
351 let offset = self.index.find(key_hash, |&offset| {
352 let element = View::<D>::at(u32::from(offset));
353 key == element.read_metadata(policy_data)
354 })?;
355 let element = View::<D>::at(u32::from(*offset));
356 Some(element.parse(policy_data))
357 }
358
359 pub(super) fn find_all(
362 &self,
363 key: D::Metadata,
364 policy_data: &PolicyData,
365 ) -> impl Iterator<Item = D> {
366 let key_hash = Self::metadata_hash(&key);
367 let offset = self.index.find(key_hash, |&offset| {
368 let element = View::<D>::at(u32::from(offset));
369 key == element.read_metadata(policy_data)
370 });
371 (HashedArrayViewEntryIter {
372 policy_data: policy_data,
373 limit: self.limit,
374 metadata: key,
375 offset: offset.map(|offset| u32::from(*offset)),
376 })
377 .map(|element| element.parse(policy_data))
378 }
379
380 pub(super) fn iter(&self, policy_data: &PolicyData) -> impl Iterator<Item = View<D>> {
382 self.index
383 .iter()
384 .map(|offset| {
385 let element = View::<D>::at(u32::from(*offset));
386 HashedArrayViewEntryIter {
387 policy_data: policy_data,
388 limit: self.limit,
389 metadata: element.read_metadata(policy_data),
390 offset: Some(u32::from(*offset)),
391 }
392 })
393 .flatten()
394 }
395}
396
397impl<D: Parse + HasMetadata + Walk> Parse for HashedArrayView<D>
398where
399 D::Metadata: Eq + Debug + PartialEq + Parse + Hash,
400{
401 type Error = anyhow::Error;
402
403 fn parse(cursor: PolicyCursor) -> Result<(Self, PolicyCursor), Self::Error> {
404 let (array_view, cursor) = SimpleArrayView::<D>::parse(cursor)?;
405
406 let mut index = HashTable::with_capacity(array_view.count as usize);
408
409 let limit = cursor.offset();
411
412 for view in array_view.data().iter(cursor.data()) {
415 let metadata = view.read_metadata(cursor.data());
416
417 index
418 .entry(
419 Self::metadata_hash(&metadata),
420 |&offset| {
421 let element = View::<D>::at(u32::from(offset));
422 metadata == element.read_metadata(cursor.data())
423 },
424 |&offset| {
425 let element = View::<D>::at(u32::from(offset));
426 Self::metadata_hash(&element.read_metadata(cursor.data()))
427 },
428 )
429 .or_insert(U24::try_from(view.start()).expect("Policy offsets ought fit in U24!"));
430 }
431
432 Ok((Self { phantom: PhantomData, index, limit }, cursor))
433 }
434}
435
436impl<D: Validate + Parse + HasMetadata + Walk> Validate for HashedArrayView<D>
437where
438 D::Metadata: Eq,
439{
440 type Error = anyhow::Error;
441
442 fn validate(&self, context: &mut PolicyValidationContext) -> Result<(), Self::Error> {
443 let policy_data = context.data.clone();
444 for element in self
445 .index
446 .iter()
447 .map(|offset| {
448 let element = View::<D>::at(u32::from(*offset));
449 HashedArrayViewEntryIter::<D> {
450 policy_data: &policy_data,
451 limit: self.limit,
452 metadata: element.read_metadata(&policy_data),
453 offset: Some(u32::from(*offset)),
454 }
455 })
456 .flatten()
457 {
458 element.validate(context)?;
459 }
460
461 Ok(())
462 }
463}
464
#[cfg(test)]
mod tests {
    use super::U24;

    /// Asserts that `value` converts into a `U24` whose bytes match `value`'s little-endian
    /// layout, and that converting back yields `value` unchanged.
    fn assert_round_trip(value: u32) {
        let u24 = U24::try_from(value).expect("values below 2^24 must convert");
        assert_eq!((value >> 16) & 0xff, u24.high as u32);
        assert_eq!((value >> 8) & 0xff, u24.middle as u32);
        assert_eq!(value & 0xff, u24.low as u32);
        assert_eq!(value, u32::from(u24));
    }

    #[test]
    fn to_and_from_u24() {
        // Exhaustively cover the low two bytes.
        for i in 0u32..0x10000 {
            assert_round_trip(i);
        }

        // Sample the rest of the 24-bit range so the high byte is exercised too. The odd stride
        // makes the samples land on varied low/middle byte values.
        for i in (0x10000u32..0x100_0000).step_by(0x3ff) {
            assert_round_trip(i);
        }

        // Largest representable value.
        assert_round_trip(0xff_ffff);

        // Anything with a bit above bit 23 set must be rejected.
        assert!(U24::try_from(0x100_0000).is_err());
        assert!(U24::try_from(u32::MAX).is_err());
    }
}
484}