1use crate::{Encoding, Integer, Limb, UInt, Zero};
4use core::{
5 fmt,
6 num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8},
7 ops::Deref,
8};
9use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11#[cfg(feature = "generic-array")]
12use crate::{ArrayEncoding, ByteArray};
13
14#[cfg(feature = "rand_core")]
15use {
16 crate::Random,
17 rand_core::{CryptoRng, RngCore},
18};
19
20#[cfg(feature = "serde")]
21use serdect::serde::{
22 de::{Error, Unexpected},
23 Deserialize, Deserializer, Serialize, Serializer,
24};
25
26#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
28pub struct NonZero<T: Zero>(T);
29
30impl<T> NonZero<T>
31where
32 T: Zero,
33{
34 pub fn new(n: T) -> CtOption<Self> {
36 let is_zero = n.is_zero();
37 CtOption::new(Self(n), !is_zero)
38 }
39}
40
41impl<T> NonZero<T>
42where
43 T: Integer,
44{
45 pub const ONE: Self = Self(T::ONE);
47
48 pub const MAX: Self = Self(T::MAX);
50}
51
52impl<T> NonZero<T>
53where
54 T: Encoding + Zero,
55{
56 pub fn from_be_bytes(bytes: T::Repr) -> CtOption<Self> {
58 Self::new(T::from_be_bytes(bytes))
59 }
60
61 pub fn from_le_bytes(bytes: T::Repr) -> CtOption<Self> {
63 Self::new(T::from_le_bytes(bytes))
64 }
65}
66
#[cfg(feature = "generic-array")]
#[cfg_attr(docsrs, doc(cfg(feature = "generic-array")))]
impl<T> NonZero<T>
where
    T: ArrayEncoding + Zero,
{
    /// Decode a big endian byte array with `T::from_be_byte_array`.
    ///
    /// The result is none if the decoded value is zero.
    pub fn from_be_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        Self::new(T::from_be_byte_array(bytes))
    }

    /// Decode a little endian byte array with `T::from_le_byte_array`.
    ///
    /// The result is none if the decoded value is zero.
    pub fn from_le_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        // Must use the little-endian decoder; using the big-endian one here would
        // silently byte-reverse the input.
        Self::new(T::from_le_byte_array(bytes))
    }
}
83
84impl<T> AsRef<T> for NonZero<T>
85where
86 T: Zero,
87{
88 fn as_ref(&self) -> &T {
89 &self.0
90 }
91}
92
93impl<T> ConditionallySelectable for NonZero<T>
94where
95 T: ConditionallySelectable + Zero,
96{
97 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
98 Self(T::conditional_select(&a.0, &b.0, choice))
99 }
100}
101
102impl<T> ConstantTimeEq for NonZero<T>
103where
104 T: Zero,
105{
106 fn ct_eq(&self, other: &Self) -> Choice {
107 self.0.ct_eq(&other.0)
108 }
109}
110
111impl<T> Deref for NonZero<T>
112where
113 T: Zero,
114{
115 type Target = T;
116
117 fn deref(&self) -> &T {
118 &self.0
119 }
120}
121
#[cfg(feature = "rand_core")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand_core")))]
impl<T> Random for NonZero<T>
where
    T: Random + Zero,
{
    /// Draw a random non-zero value by rejection sampling.
    ///
    /// Candidates are generated with `T::random` until one passes the non-zero
    /// check. The loop has no iteration bound.
    fn random(mut rng: impl CryptoRng + RngCore) -> Self {
        loop {
            let candidate: Option<Self> = Self::new(T::random(&mut rng)).into();
            if let Some(value) = candidate {
                return value;
            }
        }
    }
}
140
141impl<T> fmt::Display for NonZero<T>
142where
143 T: fmt::Display + Zero,
144{
145 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146 fmt::Display::fmt(&self.0, f)
147 }
148}
149
150impl<T> fmt::Binary for NonZero<T>
151where
152 T: fmt::Binary + Zero,
153{
154 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155 fmt::Binary::fmt(&self.0, f)
156 }
157}
158
159impl<T> fmt::Octal for NonZero<T>
160where
161 T: fmt::Octal + Zero,
162{
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 fmt::Octal::fmt(&self.0, f)
165 }
166}
167
168impl<T> fmt::LowerHex for NonZero<T>
169where
170 T: fmt::LowerHex + Zero,
171{
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 fmt::LowerHex::fmt(&self.0, f)
174 }
175}
176
177impl<T> fmt::UpperHex for NonZero<T>
178where
179 T: fmt::UpperHex + Zero,
180{
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 fmt::UpperHex::fmt(&self.0, f)
183 }
184}
185
186impl NonZero<Limb> {
187 pub const fn from_u8(n: NonZeroU8) -> Self {
190 Self(Limb::from_u8(n.get()))
191 }
192
193 pub const fn from_u16(n: NonZeroU16) -> Self {
196 Self(Limb::from_u16(n.get()))
197 }
198
199 pub const fn from_u32(n: NonZeroU32) -> Self {
202 Self(Limb::from_u32(n.get()))
203 }
204
205 #[cfg(target_pointer_width = "64")]
208 #[cfg_attr(docsrs, doc(cfg(target_pointer_width = "64")))]
209 pub const fn from_u64(n: NonZeroU64) -> Self {
210 Self(Limb::from_u64(n.get()))
211 }
212}
213
214impl From<NonZeroU8> for NonZero<Limb> {
215 fn from(integer: NonZeroU8) -> Self {
216 Self::from_u8(integer)
217 }
218}
219
220impl From<NonZeroU16> for NonZero<Limb> {
221 fn from(integer: NonZeroU16) -> Self {
222 Self::from_u16(integer)
223 }
224}
225
226impl From<NonZeroU32> for NonZero<Limb> {
227 fn from(integer: NonZeroU32) -> Self {
228 Self::from_u32(integer)
229 }
230}
231
232#[cfg(target_pointer_width = "64")]
233#[cfg_attr(docsrs, doc(cfg(target_pointer_width = "64")))]
234impl From<NonZeroU64> for NonZero<Limb> {
235 fn from(integer: NonZeroU64) -> Self {
236 Self::from_u64(integer)
237 }
238}
239
240impl<const LIMBS: usize> NonZero<UInt<LIMBS>> {
241 pub const fn from_uint(n: UInt<LIMBS>) -> Self {
243 let mut i = 0;
244 let mut found_non_zero = false;
245 while i < LIMBS {
246 if n.limbs()[i].0 != 0 {
247 found_non_zero = true;
248 }
249 i += 1;
250 }
251 assert!(found_non_zero, "found zero");
252 Self(n)
253 }
254
255 pub const fn from_u8(n: NonZeroU8) -> Self {
258 Self(UInt::from_u8(n.get()))
259 }
260
261 pub const fn from_u16(n: NonZeroU16) -> Self {
264 Self(UInt::from_u16(n.get()))
265 }
266
267 pub const fn from_u32(n: NonZeroU32) -> Self {
270 Self(UInt::from_u32(n.get()))
271 }
272
273 pub const fn from_u64(n: NonZeroU64) -> Self {
276 Self(UInt::from_u64(n.get()))
277 }
278
279 pub const fn from_u128(n: NonZeroU128) -> Self {
282 Self(UInt::from_u128(n.get()))
283 }
284}
285
286impl<const LIMBS: usize> From<NonZeroU8> for NonZero<UInt<LIMBS>> {
287 fn from(integer: NonZeroU8) -> Self {
288 Self::from_u8(integer)
289 }
290}
291
292impl<const LIMBS: usize> From<NonZeroU16> for NonZero<UInt<LIMBS>> {
293 fn from(integer: NonZeroU16) -> Self {
294 Self::from_u16(integer)
295 }
296}
297
298impl<const LIMBS: usize> From<NonZeroU32> for NonZero<UInt<LIMBS>> {
299 fn from(integer: NonZeroU32) -> Self {
300 Self::from_u32(integer)
301 }
302}
303
304impl<const LIMBS: usize> From<NonZeroU64> for NonZero<UInt<LIMBS>> {
305 fn from(integer: NonZeroU64) -> Self {
306 Self::from_u64(integer)
307 }
308}
309
310impl<const LIMBS: usize> From<NonZeroU128> for NonZero<UInt<LIMBS>> {
311 fn from(integer: NonZeroU128) -> Self {
312 Self::from_u128(integer)
313 }
314}
315
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de, T> Deserialize<'de> for NonZero<T>
where
    T: Deserialize<'de> + Zero,
{
    /// Deserialize a `T`, then reject it with an `invalid_value` error if it is zero.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = T::deserialize(deserializer)?;
        Option::<Self>::from(Self::new(value)).ok_or_else(|| {
            D::Error::invalid_value(Unexpected::Other("zero"), &"a non-zero value")
        })
    }
}
335
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<T> Serialize for NonZero<T>
where
    T: Serialize + Zero,
{
    /// Serialize as the wrapped value, with no extra framing, so the output matches `T`'s.
    // The previous impl declared an unused `'de` lifetime; serialization borrows nothing.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        self.0.serialize(serializer)
    }
}
346
#[cfg(all(test, feature = "serde"))]
mod tests {
    use crate::{NonZero, U64};
    use bincode::ErrorKind;

    /// Error text serde produces when a zero value is deserialized as `NonZero`.
    const ZERO_ERROR: &str = "invalid value: zero, expected a non-zero value";

    /// Fixed non-zero value shared by the round-trip tests.
    fn sample() -> NonZero<U64> {
        Option::<NonZero<U64>>::from(NonZero::new(U64::from_u64(0x0011223344556677))).unwrap()
    }

    /// True if `err` is the custom error emitted for a zero input.
    fn is_zero_error(err: &ErrorKind) -> bool {
        matches!(err, ErrorKind::Custom(message) if message == ZERO_ERROR)
    }

    #[test]
    fn serde() {
        let original = sample();
        let bytes = bincode::serialize(&original).unwrap();
        let decoded: NonZero<U64> = bincode::deserialize(&bytes).unwrap();
        assert_eq!(original, decoded);

        let zero_bytes = bincode::serialize(&U64::ZERO).unwrap();
        let err = bincode::deserialize::<NonZero<U64>>(&zero_bytes).unwrap_err();
        assert!(is_zero_error(&err));
    }

    #[test]
    fn serde_owned() {
        let original = sample();
        let bytes = bincode::serialize(&original).unwrap();
        let decoded: NonZero<U64> = bincode::deserialize_from(bytes.as_slice()).unwrap();
        assert_eq!(original, decoded);

        let zero_bytes = bincode::serialize(&U64::ZERO).unwrap();
        let err = bincode::deserialize_from::<_, NonZero<U64>>(zero_bytes.as_slice()).unwrap_err();
        assert!(is_zero_error(&err));
    }
}