dhcp_protocol/
size_constrained.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::size_of_contents::SizeOfContents;
6use serde::{Deserialize, Serialize};
7use std::ops::Deref;
8
/// `u8::MAX` (255) widened to `usize` so it can be used as a const-generic
/// byte bound; the `u8` -> `usize` cast is lossless.
pub const U8_MAX_AS_USIZE: usize = u8::MAX as usize;
10
11/// AtLeast encodes a lower bound on the inner container's number of elements.
12///
13/// If `inner` is an `AtMostBytes<Vec<U>>`, it must contain at least
14/// `LOWER_BOUND_ON_NUMBER_OF_ELEMENTS` instances of `U`.
15#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
16pub struct AtLeast<const LOWER_BOUND_ON_NUMBER_OF_ELEMENTS: usize, T> {
17    inner: T,
18}
19
20// We'd have liked to make this more general, but we run into
21// "unconstrained type parameter" issues.
22impl<const LOWER_BOUND_ON_NUMBER_OF_ELEMENTS: usize, const UPPER_BOUND_ON_SIZE_IN_BYTES: usize, T>
23    TryFrom<AtMostBytes<UPPER_BOUND_ON_SIZE_IN_BYTES, Vec<T>>>
24    for AtLeast<
25        LOWER_BOUND_ON_NUMBER_OF_ELEMENTS,
26        AtMostBytes<UPPER_BOUND_ON_SIZE_IN_BYTES, Vec<T>>,
27    >
28{
29    type Error = (Error, Vec<T>);
30
31    fn try_from(
32        value: AtMostBytes<UPPER_BOUND_ON_SIZE_IN_BYTES, Vec<T>>,
33    ) -> Result<Self, Self::Error> {
34        if value.len() >= LOWER_BOUND_ON_NUMBER_OF_ELEMENTS {
35            Ok(AtLeast { inner: value })
36        } else {
37            let AtMostBytes { inner } = value;
38            Err((Error::SizeConstraintViolated, inner))
39        }
40    }
41}
42
43/// AtMostBytes encodes an upper bound on the inner container's size in bytes.
44///
45/// If `inner` is a Vec<U>, its size in bytes (as defined by
46/// `SizeOfContents::size_of_contents_in_bytes`) must be at most
47/// `UPPER_BOUND_ON_SIZE_IN_BYTES`.
48#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
49pub struct AtMostBytes<const UPPER_BOUND_ON_SIZE_IN_BYTES: usize, T> {
50    inner: T,
51}
52
53impl<const UPPER_BOUND_ON_SIZE_IN_BYTES: usize, T> TryFrom<Vec<T>>
54    for AtMostBytes<UPPER_BOUND_ON_SIZE_IN_BYTES, Vec<T>>
55{
56    type Error = (Error, Vec<T>);
57    fn try_from(value: Vec<T>) -> Result<Self, Self::Error> {
58        if value.size_of_contents_in_bytes() <= UPPER_BOUND_ON_SIZE_IN_BYTES {
59            Ok(AtMostBytes { inner: value })
60        } else {
61            Err((Error::SizeConstraintViolated, value))
62        }
63    }
64}
65
66impl<const LOWER_BOUND_ON_NUMBER_OF_ELEMENTS: usize, const UPPER_BOUND_ON_SIZE_IN_BYTES: usize, T>
67    TryFrom<Vec<T>>
68    for AtLeast<
69        LOWER_BOUND_ON_NUMBER_OF_ELEMENTS,
70        AtMostBytes<UPPER_BOUND_ON_SIZE_IN_BYTES, Vec<T>>,
71    >
72{
73    type Error = (Error, Vec<T>);
74
75    fn try_from(value: Vec<T>) -> Result<Self, Self::Error> {
76        AtLeast::try_from(AtMostBytes::try_from(value)?)
77    }
78}
79
80impl<const N: usize, T> From<AtMostBytes<N, Vec<T>>> for Vec<T> {
81    fn from(value: AtMostBytes<N, Vec<T>>) -> Self {
82        let AtMostBytes { inner } = value;
83        inner
84    }
85}
86
87impl<const M: usize, const N: usize, T> From<AtLeast<M, AtMostBytes<N, Vec<T>>>> for Vec<T> {
88    fn from(value: AtLeast<M, AtMostBytes<N, Vec<T>>>) -> Self {
89        let AtLeast { inner: AtMostBytes { inner } } = value;
90        inner
91    }
92}
93
/// Emits the same trait `impl` body for both `AtMostBytes<N, Vec<T>>` and
/// `AtLeast<M, AtMostBytes<N, Vec<T>>>`, so the two wrappers expose the same
/// trait surface. The body may name `T`, the element type.
macro_rules! impl_for_both {
    ($tr:ty, { $($body:tt)* }) => {
        impl<const N: usize, T> $tr for AtMostBytes<N, Vec<T>> {
            $($body)*
        }

        impl<const M: usize, const N: usize, T> $tr for AtLeast<M, AtMostBytes<N, Vec<T>>> {
            $($body)*
        }
    };
}
105
106impl_for_both!(Deref, {
107    type Target = Vec<T>;
108
109    fn deref(&self) -> &Self::Target {
110        let Self { inner } = self;
111        inner
112    }
113});
114
115impl_for_both!(IntoIterator, {
116    type Item = T;
117    type IntoIter = <Vec<T> as IntoIterator>::IntoIter;
118
119    fn into_iter(self) -> Self::IntoIter {
120        let vec = Vec::from(self);
121        vec.into_iter()
122    }
123});
124
125impl_for_both!(AsRef<[T]>, {
126    fn as_ref(&self) -> &[T] {
127        self.deref().as_ref()
128    }
129});
130
131// We can delete a lot of this manual repetition once generic_const_exprs
132// is stabilized. https://github.com/rust-lang/rust/issues/76560
133
134trait IsAtMost4Bytes {}
135
136macro_rules! impl_at_most_4_bytes {
137    ($t:ty) => {
138        impl IsAtMost4Bytes for $t {}
139        static_assertions::const_assert!(
140            { std::mem::size_of::<$t>() } <= { std::mem::size_of::<u8>() * 4 }
141        );
142    };
143    ($t:ty, $($tail:ty),*) => {
144        impl_at_most_4_bytes!($t);
145        impl_at_most_4_bytes!($($tail),*);
146    }
147}
148
149impl_at_most_4_bytes!(u8, u16, u32, crate::OptionCode, std::net::Ipv4Addr);
150
/// `u8::MAX / 4` (63): the largest count of 4-byte elements whose total size
/// (252 bytes) still fits within `u8::MAX` bytes.
const U8_MAX_DIVIDED_BY_4: usize = (u8::MAX / 4) as usize;

/// Uninhabited type that lifts a `usize` into the type system, so that marker
/// traits can be implemented for specific values of `X`.
enum Num<const X: usize> {}
154
155trait GreaterThanOrEqualTo<const Y: usize> {}
156
157macro_rules! impl_ge {
158    ([ $lhs:literal ] >= $rhs:literal) => {
159        impl GreaterThanOrEqualTo<$rhs> for Num<$lhs> {}
160        ::static_assertions::const_assert!($lhs >= $rhs);
161    };
162    ([ $lhs:literal, $($ltail:literal),* ] >= $rhs:literal) => {
163        impl_ge!([ $lhs ] >= $rhs);
164
165        impl_ge!([ $($ltail),* ] >= $rhs);
166    }
167}
168
169impl_ge!([1, 2, 3, 4, 5, 6, 7, 8, 9, 10] >= 1);
170impl_ge!([2, 3, 4, 5, 6, 7, 8, 9, 10] >= 2);
171
172trait LessThanOrEqualTo<const Y: usize> {}
173
174macro_rules! impl_le {
175    ([ $lhs:literal ] <= $rhs:literal) => {
176        impl LessThanOrEqualTo<$rhs> for Num<$lhs> {}
177        ::static_assertions::const_assert!($lhs <= $rhs);
178    };
179    ([ $lhs:literal, $($ltail:literal),* ] <= $rhs:literal) => {
180        impl_le!([ $lhs ] <= $rhs);
181
182        impl_le!([ $($ltail),* ] <= $rhs);
183    }
184}
185
186impl_le!([1, 2, 3, 4, 5, 6, 7, 8, 9, 10] <= 63);
187
188impl<T, const LOWER_BOUND_ON_NUMBER_OF_ELEMENTS: usize, const N: usize> From<[T; N]>
189    for AtLeast<LOWER_BOUND_ON_NUMBER_OF_ELEMENTS, AtMostBytes<U8_MAX_AS_USIZE, Vec<T>>>
190where
191    T: IsAtMost4Bytes,
192    Num<N>: GreaterThanOrEqualTo<LOWER_BOUND_ON_NUMBER_OF_ELEMENTS>,
193    Num<N>: LessThanOrEqualTo<{ U8_MAX_DIVIDED_BY_4 }>,
194{
195    fn from(value: [T; N]) -> Self {
196        value.into_iter().collect::<Vec<_>>().try_into().unwrap_or_else(
197            |(Error::SizeConstraintViolated, _)| {
198                panic!(
199                    "should be statically known that \
200                 {N} >= {LOWER_BOUND_ON_NUMBER_OF_ELEMENTS} and \
201                 [{type_name}; {N}] fits within {U8_MAX_AS_USIZE} bytes",
202                    type_name = std::any::type_name::<T>()
203                )
204            },
205        )
206    }
207}
208
209#[derive(Debug, PartialEq, thiserror::Error)]
210pub enum Error {
211    #[error("size constraint violated")]
212    SizeConstraintViolated,
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use std::fmt::Debug;

    /// Asserts that exactly `max_number_allowed` copies of `item` fit within
    /// `U8_MAX_AS_USIZE` bytes, and that one more copy is rejected with the
    /// input handed back unchanged.
    fn run_edge_case<T: Copy + Debug + PartialEq>(item: T, max_number_allowed: impl Into<usize>) {
        let max_number_allowed = max_number_allowed.into();
        let v = std::iter::repeat(item).take(max_number_allowed).collect::<Vec<_>>();
        assert_eq!(
            AtLeast::<1, AtMostBytes<{ U8_MAX_AS_USIZE }, _>>::try_from(v.clone()),
            Ok(AtLeast { inner: AtMostBytes { inner: v } }),
            "{max_number_allowed} instances of {} should fit in 255 bytes",
            std::any::type_name::<T>(),
        );

        // Report the count actually attempted (one past the limit) on failure.
        let too_many = max_number_allowed + 1;
        let v = std::iter::repeat(item).take(too_many).collect::<Vec<_>>();
        assert_eq!(
            AtLeast::<1, AtMostBytes<{ U8_MAX_AS_USIZE }, _>>::try_from(v.clone()),
            Err((Error::SizeConstraintViolated, v)),
            "{too_many} instances of {} should not fit in 255 bytes",
            std::any::type_name::<T>(),
        );
    }

    #[test]
    fn edge_cases() {
        run_edge_case(1u8, u8::MAX);
        run_edge_case(1u32, u8::MAX / 4);
        run_edge_case(1u64, u8::MAX / 8);
    }

    #[test]
    fn disallows_empty() {
        assert_eq!(
            AtLeast::<1, AtMostBytes<{ U8_MAX_AS_USIZE }, _>>::try_from(Vec::<u8>::new()),
            Err((Error::SizeConstraintViolated, Vec::new()))
        )
    }

    #[test]
    fn enforces_lower_bound_above_one() {
        assert_eq!(
            AtLeast::<2, AtMostBytes<{ U8_MAX_AS_USIZE }, _>>::try_from(vec![7u8]),
            Err((Error::SizeConstraintViolated, vec![7u8]))
        );
        assert_eq!(
            AtLeast::<2, AtMostBytes<{ U8_MAX_AS_USIZE }, _>>::try_from(vec![7u8, 8]),
            Ok(AtLeast { inner: AtMostBytes { inner: vec![7u8, 8] } })
        );
    }

    #[test]
    fn from_array_round_trips() {
        let bounded =
            AtLeast::<2, AtMostBytes<{ U8_MAX_AS_USIZE }, Vec<u32>>>::from([1u32, 2, 3]);
        assert_eq!(*bounded, vec![1u32, 2, 3]);
        assert_eq!(bounded.into_iter().collect::<Vec<_>>(), vec![1u32, 2, 3]);
    }
}