Skip to main content

rkyv/validation/
mod.rs

1//! Validation implementations and helper types.
2
3pub mod archive;
4pub mod shared;
5
6use core::{any::TypeId, ops::Range};
7
8pub use self::{
9    archive::{ArchiveContext, ArchiveContextExt},
10    shared::SharedContext,
11};
12
/// The default validator.
///
/// A `Validator` composes two independent validation contexts: an archive
/// context `A` and a shared context `S`. Calls made through the validator's
/// [`ArchiveContext`] implementation are forwarded to `A`, and calls made
/// through its [`SharedContext`] implementation are forwarded to `S`.
#[derive(Debug)]
pub struct Validator<A, S> {
    /// Receives the forwarded [`ArchiveContext`] calls.
    archive: A,
    /// Receives the forwarded [`SharedContext`] calls.
    shared: S,
}

impl<A, S> Validator<A, S> {
    /// Creates a new validator from an archive context and a shared context.
    ///
    /// `archive` receives the validator's [`ArchiveContext`] calls and
    /// `shared` receives its [`SharedContext`] calls; neither is inspected or
    /// modified here.
    #[inline]
    pub fn new(archive: A, shared: S) -> Self {
        Self { archive, shared }
    }
}
27
28unsafe impl<A, S, E> ArchiveContext<E> for Validator<A, S>
29where
30    A: ArchiveContext<E>,
31{
32    fn check_subtree_ptr(
33        &mut self,
34        ptr: *const u8,
35        layout: &core::alloc::Layout,
36    ) -> Result<(), E> {
37        self.archive.check_subtree_ptr(ptr, layout)
38    }
39
40    unsafe fn push_subtree_range(
41        &mut self,
42        root: *const u8,
43        end: *const u8,
44    ) -> Result<Range<usize>, E> {
45        // SAFETY: This just forwards the call to the underlying `CoreValidator`
46        // which has the same safety requirements.
47        unsafe { self.archive.push_subtree_range(root, end) }
48    }
49
50    unsafe fn pop_subtree_range(
51        &mut self,
52        range: Range<usize>,
53    ) -> Result<(), E> {
54        // SAFETY: This just forwards the call to the underlying `CoreValidator`
55        // which has the same safety requirements.
56        unsafe { self.archive.pop_subtree_range(range) }
57    }
58}
59
60impl<A, S, E> SharedContext<E> for Validator<A, S>
61where
62    S: SharedContext<E>,
63{
64    fn start_shared(
65        &mut self,
66        address: usize,
67        type_id: TypeId,
68    ) -> Result<shared::ValidationState, E> {
69        self.shared.start_shared(address, type_id)
70    }
71
72    fn finish_shared(
73        &mut self,
74        address: usize,
75        type_id: TypeId,
76    ) -> Result<(), E> {
77        self.shared.finish_shared(address, type_id)
78    }
79}
80
#[cfg(test)]
mod tests {
    use rancor::Failure;

    use crate::{
        api::low::{access, access_pos},
        boxed::ArchivedBox,
        option::ArchivedOption,
        util::Align,
        Archived,
    };

    /// Checks that a hand-built `Option<Box<[u8]>>` archive validates for the
    /// active pointer width and endianness, and that malformed positions and
    /// buffers are rejected.
    #[test]
    fn basic_functionality() {
        // Synthetic archive (correct): 16-bit pointers, little-endian.
        #[cfg(all(feature = "pointer_width_16", not(feature = "big_endian")))]
        let synthetic_buf = Align([
            // "Hello world"
            0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64,
            0u8, // padding to 2-alignment
            1u8, 0u8, // Some + padding
            0xf2u8, 0xffu8, // points 14 bytes backwards
            11u8, 0u8, // string is 11 characters long
        ]);

        // Synthetic archive (correct): 16-bit pointers, big-endian.
        #[cfg(all(feature = "pointer_width_16", feature = "big_endian"))]
        let synthetic_buf = Align([
            // "Hello world"
            0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64,
            0u8, // padding to 2-alignment
            1u8, 0u8, // Some + padding
            0xffu8, 0xf2u8, // points 14 bytes backwards
            0u8, 11u8, // string is 11 characters long
        ]);

        // Synthetic archive (correct): 32-bit pointers (neither 16 nor 64
        // selected), little-endian.
        #[cfg(all(
            not(any(
                feature = "pointer_width_16",
                feature = "pointer_width_64",
            )),
            not(feature = "big_endian"),
        ))]
        let synthetic_buf = Align([
            // "Hello world"
            0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64,
            0u8, // padding to 4-alignment
            1u8, 0u8, 0u8, 0u8, // Some + padding
            0xf0u8, 0xffu8, 0xffu8, 0xffu8, // points 16 bytes backward
            11u8, 0u8, 0u8, 0u8, // string is 11 characters long
        ]);

        // Synthetic archive (correct): 32-bit pointers (neither 16 nor 64
        // selected), big-endian.
        #[cfg(all(
            not(any(
                feature = "pointer_width_16",
                feature = "pointer_width_64",
            )),
            feature = "big_endian",
        ))]
        let synthetic_buf = Align([
            // "Hello world"
            0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64,
            0u8, // padding to 4-alignment
            1u8, 0u8, 0u8, 0u8, // Some + padding
            0xffu8, 0xffu8, 0xffu8, 0xf0u8, // points 16 bytes backward
            0u8, 0u8, 0u8, 11u8, // string is 11 characters long
        ]);

        // Synthetic archive (correct): 64-bit pointers, little-endian.
        #[cfg(all(feature = "pointer_width_64", not(feature = "big_endian")))]
        let synthetic_buf = Align([
            // "Hello world"
            0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64,
            0u8, 0u8, 0u8, 0u8, 0u8, // padding to 8-alignment
            1u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, // Some + padding
            // points 24 bytes backward
            0xe8u8, 0xffu8, 0xffu8, 0xffu8, 0xffu8, 0xffu8, 0xffu8, 0xffu8,
            // string is 11 characters long
            11u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
        ]);

        // Synthetic archive (correct): 64-bit pointers, big-endian.
        #[cfg(all(feature = "pointer_width_64", feature = "big_endian"))]
        let synthetic_buf = Align([
            // "Hello world"
            0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64,
            // "!!!!!": filler padding to 8-alignment (equivalent to the zero
            // padding in the little-endian buffer); the 11-byte length below
            // does not cover these bytes
            0x21, 0x21, 0x21, 0x21, 0x21,
            1u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, // Some + padding
            // points 24 bytes backward
            0xffu8, 0xffu8, 0xffu8, 0xffu8, 0xffu8, 0xffu8, 0xffu8, 0xe8u8,
            // string is 11 characters long
            0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 11u8,
        ]);

        let result = access::<ArchivedOption<ArchivedBox<[u8]>>, Failure>(
            &*synthetic_buf,
        );
        result.unwrap();

        // Out of bounds
        access_pos::<Archived<u32>, Failure>(&*Align([0, 1, 2, 3, 4]), 8)
            .expect_err("expected out of bounds error");
        // Overrun
        access_pos::<Archived<u32>, Failure>(&*Align([0, 1, 2, 3, 4]), 4)
            .expect_err("expected overrun error");
        // Unaligned
        access_pos::<Archived<u32>, Failure>(&*Align([0, 1, 2, 3, 4]), 1)
            .expect_err("expected unaligned error");
        // Underaligned
        access_pos::<Archived<u32>, Failure>(&Align([0, 1, 2, 3, 4])[1..], 0)
            .expect_err("expected underaligned error");
        // Undersized
        access::<Archived<u32>, Failure>(&*Align([]))
            .expect_err("expected out of bounds error");
    }

    /// An out-of-range `Option` tag must be rejected.
    #[cfg(feature = "pointer_width_32")]
    #[test]
    fn invalid_tags() {
        // Invalid archive (invalid tag)
        let synthetic_buf = Align([
            2u8, 0u8, 0u8, 0u8, // invalid tag + padding
            8u8, 0u8, 0u8, 0u8, // points 8 bytes forward
            11u8, 0u8, 0u8, 0u8, // string is 11 characters long
            // "Hello world"
            0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64,
        ]);

        let result = access_pos::<Archived<Option<Box<[u8]>>>, Failure>(
            &*synthetic_buf,
            0,
        );
        result.unwrap_err();
    }

    /// Two boxes that both point at the same bytes must be rejected.
    #[cfg(feature = "pointer_width_32")]
    #[test]
    fn overlapping_claims() {
        // Invalid archive (overlapping claims)
        let synthetic_buf = Align([
            // First string
            16u8, 0u8, 0u8, 0u8, // points 16 bytes forward
            11u8, 0u8, 0u8, 0u8, // string is 11 characters long
            // Second string
            8u8, 0u8, 0u8, 0u8, // points 8 bytes forward
            11u8, 0u8, 0u8, 0u8, // string is 11 characters long
            // "Hello world"
            0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64,
        ]);

        access_pos::<Archived<[Box<[u8]>; 2]>, Failure>(&*synthetic_buf, 0)
            .unwrap_err();
    }

    /// A `Cons` chain whose second node points back to the first must be
    /// rejected.
    #[cfg(feature = "pointer_width_32")]
    #[test]
    fn cycle_detection() {
        use bytecheck::CheckBytes;
        use rancor::{Fallible, Source};

        use crate::{
            ser::Writer, validation::ArchiveContext, Archive, Serialize,
        };

        // `Node` itself is never constructed; the test only validates bytes
        // as an `ArchivedNode`.
        #[allow(dead_code)]
        #[derive(Archive)]
        #[rkyv(crate, derive(Debug))]
        enum Node {
            Nil,
            Cons(#[omit_bounds] Box<Node>),
        }

        impl<S: Fallible + Writer + ?Sized> Serialize<S> for Node {
            fn serialize(
                &self,
                serializer: &mut S,
            ) -> Result<NodeResolver, S::Error> {
                Ok(match self {
                    Node::Nil => NodeResolver::Nil,
                    Node::Cons(inner) => {
                        NodeResolver::Cons(inner.serialize(serializer)?)
                    }
                })
            }
        }

        unsafe impl<C> CheckBytes<C> for ArchivedNode
        where
            C: Fallible + ArchiveContext + ?Sized,
            C::Error: Source,
        {
            unsafe fn check_bytes(
                value: *const Self,
                context: &mut C,
            ) -> Result<(), C::Error> {
                let bytes = value.cast::<u8>();
                // SAFETY: The caller of `check_bytes` guarantees that `value`
                // points to readable bytes for an `ArchivedNode`; the test
                // layout stores the tag in the first byte.
                let tag = unsafe { *bytes };
                match tag {
                    // Nil: no payload to check.
                    0 => (),
                    // Cons: the archived box follows the tag and its three
                    // padding bytes at offset 4, as in the buffer below.
                    // SAFETY: Offset 4 is within the `Cons` variant that the
                    // caller's guarantee for `value` covers.
                    1 => unsafe {
                        <Archived<Box<Node>> as CheckBytes<C>>::check_bytes(
                            bytes.add(4).cast(),
                            context,
                        )?;
                    },
                    _ => panic!(),
                }
                Ok(())
            }
        }

        // Invalid archive (cyclic claims)
        let synthetic_buf = Align([
            // First node
            1u8, 0u8, 0u8, 0u8, // Cons
            4u8, 0u8, 0u8, 0u8, // Node is 4 bytes forward
            // Second node
            1u8, 0u8, 0u8, 0u8, // Cons
            244u8, 255u8, 255u8, 255u8, // Node is 12 bytes back
        ]);

        access_pos::<ArchivedNode, Failure>(&*synthetic_buf, 0).unwrap_err();
    }
}