crypto_bigint/
non_zero.rs

1//! Wrapper type for non-zero integers.
2
3use crate::{Encoding, Integer, Limb, UInt, Zero};
4use core::{
5    fmt,
6    num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8},
7    ops::Deref,
8};
9use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11#[cfg(feature = "generic-array")]
12use crate::{ArrayEncoding, ByteArray};
13
14#[cfg(feature = "rand_core")]
15use {
16    crate::Random,
17    rand_core::{CryptoRng, RngCore},
18};
19
20#[cfg(feature = "serde")]
21use serdect::serde::{
22    de::{Error, Unexpected},
23    Deserialize, Deserializer, Serialize, Serializer,
24};
25
26/// Wrapper type for non-zero integers.
27#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
28pub struct NonZero<T: Zero>(T);
29
30impl<T> NonZero<T>
31where
32    T: Zero,
33{
34    /// Create a new non-zero integer.
35    pub fn new(n: T) -> CtOption<Self> {
36        let is_zero = n.is_zero();
37        CtOption::new(Self(n), !is_zero)
38    }
39}
40
41impl<T> NonZero<T>
42where
43    T: Integer,
44{
45    /// The value `1`.
46    pub const ONE: Self = Self(T::ONE);
47
48    /// Maximum value this integer can express.
49    pub const MAX: Self = Self(T::MAX);
50}
51
52impl<T> NonZero<T>
53where
54    T: Encoding + Zero,
55{
56    /// Decode from big endian bytes.
57    pub fn from_be_bytes(bytes: T::Repr) -> CtOption<Self> {
58        Self::new(T::from_be_bytes(bytes))
59    }
60
61    /// Decode from little endian bytes.
62    pub fn from_le_bytes(bytes: T::Repr) -> CtOption<Self> {
63        Self::new(T::from_le_bytes(bytes))
64    }
65}
66
#[cfg(feature = "generic-array")]
#[cfg_attr(docsrs, doc(cfg(feature = "generic-array")))]
impl<T> NonZero<T>
where
    T: ArrayEncoding + Zero,
{
    /// Decode a non-zero integer from big endian bytes.
    ///
    /// Returns a [`CtOption`] which is none if the decoded integer is zero.
    pub fn from_be_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        Self::new(T::from_be_byte_array(bytes))
    }

    /// Decode a non-zero integer from little endian bytes.
    ///
    /// Returns a [`CtOption`] which is none if the decoded integer is zero.
    pub fn from_le_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        // Use the little-endian decoder to match this method's name. Decoding
        // with the big-endian one would interpret the bytes in reverse order.
        Self::new(T::from_le_byte_array(bytes))
    }
}
83
84impl<T> AsRef<T> for NonZero<T>
85where
86    T: Zero,
87{
88    fn as_ref(&self) -> &T {
89        &self.0
90    }
91}
92
93impl<T> ConditionallySelectable for NonZero<T>
94where
95    T: ConditionallySelectable + Zero,
96{
97    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
98        Self(T::conditional_select(&a.0, &b.0, choice))
99    }
100}
101
102impl<T> ConstantTimeEq for NonZero<T>
103where
104    T: Zero,
105{
106    fn ct_eq(&self, other: &Self) -> Choice {
107        self.0.ct_eq(&other.0)
108    }
109}
110
111impl<T> Deref for NonZero<T>
112where
113    T: Zero,
114{
115    type Target = T;
116
117    fn deref(&self) -> &T {
118        &self.0
119    }
120}
121
#[cfg(feature = "rand_core")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand_core")))]
impl<T> Random for NonZero<T>
where
    T: Random + Zero,
{
    /// Generate a random `NonZero<T>`.
    ///
    /// Samples `T` repeatedly and returns the first value that is not zero.
    fn random(mut rng: impl CryptoRng + RngCore) -> Self {
        // Rejection sampling. The number of iterations is not constant-time,
        // but a rejected zero sample reveals nothing about the accepted output
        // as long as `rng` is a secure `CryptoRng`.
        loop {
            let candidate: Option<Self> = Self::new(T::random(&mut rng)).into();
            match candidate {
                Some(non_zero) => return non_zero,
                None => {}
            }
        }
    }
}
140
141impl<T> fmt::Display for NonZero<T>
142where
143    T: fmt::Display + Zero,
144{
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        fmt::Display::fmt(&self.0, f)
147    }
148}
149
150impl<T> fmt::Binary for NonZero<T>
151where
152    T: fmt::Binary + Zero,
153{
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        fmt::Binary::fmt(&self.0, f)
156    }
157}
158
159impl<T> fmt::Octal for NonZero<T>
160where
161    T: fmt::Octal + Zero,
162{
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        fmt::Octal::fmt(&self.0, f)
165    }
166}
167
168impl<T> fmt::LowerHex for NonZero<T>
169where
170    T: fmt::LowerHex + Zero,
171{
172    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173        fmt::LowerHex::fmt(&self.0, f)
174    }
175}
176
177impl<T> fmt::UpperHex for NonZero<T>
178where
179    T: fmt::UpperHex + Zero,
180{
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        fmt::UpperHex::fmt(&self.0, f)
183    }
184}
185
186impl NonZero<Limb> {
187    /// Create a [`NonZero<Limb>`] from a [`NonZeroU8`] (const-friendly)
188    // TODO(tarcieri): replace with `const impl From<NonZeroU8>` when stable
189    pub const fn from_u8(n: NonZeroU8) -> Self {
190        Self(Limb::from_u8(n.get()))
191    }
192
193    /// Create a [`NonZero<Limb>`] from a [`NonZeroU16`] (const-friendly)
194    // TODO(tarcieri): replace with `const impl From<NonZeroU16>` when stable
195    pub const fn from_u16(n: NonZeroU16) -> Self {
196        Self(Limb::from_u16(n.get()))
197    }
198
199    /// Create a [`NonZero<Limb>`] from a [`NonZeroU32`] (const-friendly)
200    // TODO(tarcieri): replace with `const impl From<NonZeroU32>` when stable
201    pub const fn from_u32(n: NonZeroU32) -> Self {
202        Self(Limb::from_u32(n.get()))
203    }
204
205    /// Create a [`NonZero<Limb>`] from a [`NonZeroU64`] (const-friendly)
206    // TODO(tarcieri): replace with `const impl From<NonZeroU64>` when stable
207    #[cfg(target_pointer_width = "64")]
208    #[cfg_attr(docsrs, doc(cfg(target_pointer_width = "64")))]
209    pub const fn from_u64(n: NonZeroU64) -> Self {
210        Self(Limb::from_u64(n.get()))
211    }
212}
213
214impl From<NonZeroU8> for NonZero<Limb> {
215    fn from(integer: NonZeroU8) -> Self {
216        Self::from_u8(integer)
217    }
218}
219
220impl From<NonZeroU16> for NonZero<Limb> {
221    fn from(integer: NonZeroU16) -> Self {
222        Self::from_u16(integer)
223    }
224}
225
226impl From<NonZeroU32> for NonZero<Limb> {
227    fn from(integer: NonZeroU32) -> Self {
228        Self::from_u32(integer)
229    }
230}
231
232#[cfg(target_pointer_width = "64")]
233#[cfg_attr(docsrs, doc(cfg(target_pointer_width = "64")))]
234impl From<NonZeroU64> for NonZero<Limb> {
235    fn from(integer: NonZeroU64) -> Self {
236        Self::from_u64(integer)
237    }
238}
239
240impl<const LIMBS: usize> NonZero<UInt<LIMBS>> {
241    /// Create a [`NonZero<UInt>`] from a [`UInt`] (const-friendly)
242    pub const fn from_uint(n: UInt<LIMBS>) -> Self {
243        let mut i = 0;
244        let mut found_non_zero = false;
245        while i < LIMBS {
246            if n.limbs()[i].0 != 0 {
247                found_non_zero = true;
248            }
249            i += 1;
250        }
251        assert!(found_non_zero, "found zero");
252        Self(n)
253    }
254
255    /// Create a [`NonZero<UInt>`] from a [`NonZeroU8`] (const-friendly)
256    // TODO(tarcieri): replace with `const impl From<NonZeroU8>` when stable
257    pub const fn from_u8(n: NonZeroU8) -> Self {
258        Self(UInt::from_u8(n.get()))
259    }
260
261    /// Create a [`NonZero<UInt>`] from a [`NonZeroU16`] (const-friendly)
262    // TODO(tarcieri): replace with `const impl From<NonZeroU16>` when stable
263    pub const fn from_u16(n: NonZeroU16) -> Self {
264        Self(UInt::from_u16(n.get()))
265    }
266
267    /// Create a [`NonZero<UInt>`] from a [`NonZeroU32`] (const-friendly)
268    // TODO(tarcieri): replace with `const impl From<NonZeroU32>` when stable
269    pub const fn from_u32(n: NonZeroU32) -> Self {
270        Self(UInt::from_u32(n.get()))
271    }
272
273    /// Create a [`NonZero<UInt>`] from a [`NonZeroU64`] (const-friendly)
274    // TODO(tarcieri): replace with `const impl From<NonZeroU64>` when stable
275    pub const fn from_u64(n: NonZeroU64) -> Self {
276        Self(UInt::from_u64(n.get()))
277    }
278
279    /// Create a [`NonZero<UInt>`] from a [`NonZeroU128`] (const-friendly)
280    // TODO(tarcieri): replace with `const impl From<NonZeroU128>` when stable
281    pub const fn from_u128(n: NonZeroU128) -> Self {
282        Self(UInt::from_u128(n.get()))
283    }
284}
285
286impl<const LIMBS: usize> From<NonZeroU8> for NonZero<UInt<LIMBS>> {
287    fn from(integer: NonZeroU8) -> Self {
288        Self::from_u8(integer)
289    }
290}
291
292impl<const LIMBS: usize> From<NonZeroU16> for NonZero<UInt<LIMBS>> {
293    fn from(integer: NonZeroU16) -> Self {
294        Self::from_u16(integer)
295    }
296}
297
298impl<const LIMBS: usize> From<NonZeroU32> for NonZero<UInt<LIMBS>> {
299    fn from(integer: NonZeroU32) -> Self {
300        Self::from_u32(integer)
301    }
302}
303
304impl<const LIMBS: usize> From<NonZeroU64> for NonZero<UInt<LIMBS>> {
305    fn from(integer: NonZeroU64) -> Self {
306        Self::from_u64(integer)
307    }
308}
309
310impl<const LIMBS: usize> From<NonZeroU128> for NonZero<UInt<LIMBS>> {
311    fn from(integer: NonZeroU128) -> Self {
312        Self::from_u128(integer)
313    }
314}
315
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de, T: Deserialize<'de> + Zero> Deserialize<'de> for NonZero<T> {
    /// Deserialize a `T` and reject it if it is zero.
    ///
    /// A zero value yields an `invalid_value` error with `Unexpected::Other("zero")`
    /// and the expectation "a non-zero value".
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = T::deserialize(deserializer)?;
        let checked: Option<Self> = Self::new(value).into();

        checked.ok_or_else(|| {
            D::Error::invalid_value(Unexpected::Other("zero"), &"a non-zero value")
        })
    }
}
335
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<T: Serialize + Zero> Serialize for NonZero<T> {
    /// Serialize exactly as the wrapped `T` would be, with no extra framing.
    ///
    /// The wire format is therefore the same as `T`'s, and the `Deserialize`
    /// side only has to reject zero.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.0.serialize(serializer)
    }
}
346
#[cfg(all(test, feature = "serde"))]
mod tests {
    use crate::{NonZero, U64};
    use bincode::ErrorKind;

    /// Error text serde produces when `NonZero`'s deserializer sees a zero.
    const ZERO_REJECTED: &str = "invalid value: zero, expected a non-zero value";

    /// A fixed, known-non-zero sample value.
    fn sample() -> NonZero<U64> {
        Option::<NonZero<U64>>::from(NonZero::new(U64::from_u64(0x0011223344556677))).unwrap()
    }

    /// Whether `err` is a custom bincode error carrying [`ZERO_REJECTED`].
    fn is_zero_rejection(err: &ErrorKind) -> bool {
        matches!(err, ErrorKind::Custom(message) if message == ZERO_REJECTED)
    }

    #[test]
    fn serde() {
        let original = sample();

        let bytes = bincode::serialize(&original).unwrap();
        let round_tripped: NonZero<U64> = bincode::deserialize(&bytes).unwrap();
        assert_eq!(original, round_tripped);

        let zero_bytes = bincode::serialize(&U64::ZERO).unwrap();
        let err = bincode::deserialize::<NonZero<U64>>(&zero_bytes).unwrap_err();
        assert!(is_zero_rejection(&err));
    }

    #[test]
    fn serde_owned() {
        let original = sample();

        let bytes = bincode::serialize(&original).unwrap();
        let round_tripped: NonZero<U64> = bincode::deserialize_from(bytes.as_slice()).unwrap();
        assert_eq!(original, round_tripped);

        let zero_bytes = bincode::serialize(&U64::ZERO).unwrap();
        let err =
            bincode::deserialize_from::<_, NonZero<U64>>(zero_bytes.as_slice()).unwrap_err();
        assert!(is_zero_rejection(&err));
    }
}