1use crate::{
4 bigint::Encoding as _,
5 ops::{Invert, Reduce, ReduceNonZero},
6 rand_core::{CryptoRng, RngCore},
7 Curve, Error, FieldBytes, IsHigh, PrimeCurve, Result, Scalar, ScalarArithmetic, ScalarCore,
8 SecretKey,
9};
10use base16ct::HexDisplay;
11use core::{
12 fmt,
13 ops::{Deref, Mul, Neg},
14 str,
15};
16use crypto_bigint::{ArrayEncoding, Integer};
17use ff::{Field, PrimeField};
18use generic_array::GenericArray;
19use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
20use zeroize::Zeroize;
21
22#[cfg_attr(docsrs, doc(cfg(feature = "arithmetic")))]
31#[derive(Clone)]
32pub struct NonZeroScalar<C>
33where
34 C: Curve + ScalarArithmetic,
35{
36 scalar: Scalar<C>,
37}
38
39impl<C> NonZeroScalar<C>
40where
41 C: Curve + ScalarArithmetic,
42{
43 pub fn random(mut rng: impl CryptoRng + RngCore) -> Self {
45 loop {
49 if let Some(result) = Self::new(Field::random(&mut rng)).into() {
50 break result;
51 }
52 }
53 }
54
55 pub fn new(scalar: Scalar<C>) -> CtOption<Self> {
57 CtOption::new(Self { scalar }, !scalar.is_zero())
58 }
59
60 pub fn from_repr(repr: FieldBytes<C>) -> CtOption<Self> {
62 Scalar::<C>::from_repr(repr).and_then(Self::new)
63 }
64
65 pub fn from_uint(uint: C::UInt) -> CtOption<Self> {
67 ScalarCore::new(uint).and_then(|scalar| Self::new(scalar.into()))
68 }
69}
70
71impl<C> AsRef<Scalar<C>> for NonZeroScalar<C>
72where
73 C: Curve + ScalarArithmetic,
74{
75 fn as_ref(&self) -> &Scalar<C> {
76 &self.scalar
77 }
78}
79
80impl<C> ConditionallySelectable for NonZeroScalar<C>
81where
82 C: Curve + ScalarArithmetic,
83{
84 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
85 Self {
86 scalar: Scalar::<C>::conditional_select(&a.scalar, &b.scalar, choice),
87 }
88 }
89}
90
91impl<C> ConstantTimeEq for NonZeroScalar<C>
92where
93 C: Curve + ScalarArithmetic,
94{
95 fn ct_eq(&self, other: &Self) -> Choice {
96 self.scalar.ct_eq(&other.scalar)
97 }
98}
99
100impl<C> Copy for NonZeroScalar<C> where C: Curve + ScalarArithmetic {}
101
102impl<C> Deref for NonZeroScalar<C>
103where
104 C: Curve + ScalarArithmetic,
105{
106 type Target = Scalar<C>;
107
108 fn deref(&self) -> &Scalar<C> {
109 &self.scalar
110 }
111}
112
113impl<C> From<NonZeroScalar<C>> for FieldBytes<C>
114where
115 C: Curve + ScalarArithmetic,
116{
117 fn from(scalar: NonZeroScalar<C>) -> FieldBytes<C> {
118 Self::from(&scalar)
119 }
120}
121
122impl<C> From<&NonZeroScalar<C>> for FieldBytes<C>
123where
124 C: Curve + ScalarArithmetic,
125{
126 fn from(scalar: &NonZeroScalar<C>) -> FieldBytes<C> {
127 scalar.to_repr()
128 }
129}
130
131impl<C> From<NonZeroScalar<C>> for ScalarCore<C>
132where
133 C: Curve + ScalarArithmetic,
134{
135 fn from(scalar: NonZeroScalar<C>) -> ScalarCore<C> {
136 ScalarCore::from_be_bytes(scalar.to_repr()).unwrap()
137 }
138}
139
140impl<C> From<&NonZeroScalar<C>> for ScalarCore<C>
141where
142 C: Curve + ScalarArithmetic,
143{
144 fn from(scalar: &NonZeroScalar<C>) -> ScalarCore<C> {
145 ScalarCore::from_be_bytes(scalar.to_repr()).unwrap()
146 }
147}
148
149impl<C> From<SecretKey<C>> for NonZeroScalar<C>
150where
151 C: Curve + ScalarArithmetic,
152{
153 fn from(sk: SecretKey<C>) -> NonZeroScalar<C> {
154 Self::from(&sk)
155 }
156}
157
158impl<C> From<&SecretKey<C>> for NonZeroScalar<C>
159where
160 C: Curve + ScalarArithmetic,
161{
162 fn from(sk: &SecretKey<C>) -> NonZeroScalar<C> {
163 let scalar = sk.as_scalar_core().to_scalar();
164 debug_assert!(!bool::from(scalar.is_zero()));
165 Self { scalar }
166 }
167}
168
169impl<C> Invert for NonZeroScalar<C>
170where
171 C: Curve + ScalarArithmetic,
172{
173 type Output = Self;
174
175 fn invert(&self) -> Self {
176 Self {
177 scalar: ff::Field::invert(&self.scalar).unwrap(),
179 }
180 }
181}
182
183impl<C> IsHigh for NonZeroScalar<C>
184where
185 C: Curve + ScalarArithmetic,
186{
187 fn is_high(&self) -> Choice {
188 self.scalar.is_high()
189 }
190}
191
192impl<C> Neg for NonZeroScalar<C>
193where
194 C: Curve + ScalarArithmetic,
195{
196 type Output = NonZeroScalar<C>;
197
198 fn neg(self) -> NonZeroScalar<C> {
199 let scalar = -self.scalar;
200 debug_assert!(!bool::from(scalar.is_zero()));
201 NonZeroScalar { scalar }
202 }
203}
204
205impl<C> Mul<NonZeroScalar<C>> for NonZeroScalar<C>
206where
207 C: PrimeCurve + ScalarArithmetic,
208{
209 type Output = Self;
210
211 #[inline]
212 fn mul(self, other: Self) -> Self {
213 Self::mul(self, &other)
214 }
215}
216
217impl<C> Mul<&NonZeroScalar<C>> for NonZeroScalar<C>
218where
219 C: PrimeCurve + ScalarArithmetic,
220{
221 type Output = Self;
222
223 fn mul(self, other: &Self) -> Self {
224 let scalar = self.scalar * other.scalar;
227 debug_assert!(!bool::from(scalar.is_zero()));
228 NonZeroScalar { scalar }
229 }
230}
231
232impl<C, I> Reduce<I> for NonZeroScalar<C>
234where
235 C: Curve + ScalarArithmetic,
236 I: Integer + ArrayEncoding,
237 Scalar<C>: ReduceNonZero<I>,
238{
239 fn from_uint_reduced(n: I) -> Self {
240 Self::from_uint_reduced_nonzero(n)
241 }
242}
243
244impl<C, I> ReduceNonZero<I> for NonZeroScalar<C>
245where
246 C: Curve + ScalarArithmetic,
247 I: Integer + ArrayEncoding,
248 Scalar<C>: ReduceNonZero<I>,
249{
250 fn from_uint_reduced_nonzero(n: I) -> Self {
251 let scalar = Scalar::<C>::from_uint_reduced_nonzero(n);
252 debug_assert!(!bool::from(scalar.is_zero()));
253 Self::new(scalar).unwrap()
254 }
255}
256
257impl<C> TryFrom<&[u8]> for NonZeroScalar<C>
258where
259 C: Curve + ScalarArithmetic,
260{
261 type Error = Error;
262
263 fn try_from(bytes: &[u8]) -> Result<Self> {
264 if bytes.len() == C::UInt::BYTE_SIZE {
265 Option::from(NonZeroScalar::from_repr(GenericArray::clone_from_slice(
266 bytes,
267 )))
268 .ok_or(Error)
269 } else {
270 Err(Error)
271 }
272 }
273}
274
275impl<C> Zeroize for NonZeroScalar<C>
276where
277 C: Curve + ScalarArithmetic,
278{
279 fn zeroize(&mut self) {
280 self.scalar.zeroize();
282
283 self.scalar = Scalar::<C>::one();
286 }
287}
288
289impl<C> fmt::Display for NonZeroScalar<C>
290where
291 C: Curve + ScalarArithmetic,
292{
293 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294 write!(f, "{:X}", self)
295 }
296}
297
298impl<C> fmt::LowerHex for NonZeroScalar<C>
299where
300 C: Curve + ScalarArithmetic,
301{
302 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303 write!(f, "{:x}", HexDisplay(&self.to_repr()))
304 }
305}
306
307impl<C> fmt::UpperHex for NonZeroScalar<C>
308where
309 C: Curve + ScalarArithmetic,
310{
311 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312 write!(f, "{:}", HexDisplay(&self.to_repr()))
313 }
314}
315
316impl<C> str::FromStr for NonZeroScalar<C>
317where
318 C: Curve + ScalarArithmetic,
319{
320 type Err = Error;
321
322 fn from_str(hex: &str) -> Result<Self> {
323 let mut bytes = FieldBytes::<C>::default();
324
325 if base16ct::mixed::decode(hex, &mut bytes)?.len() == bytes.len() {
326 Option::from(Self::from_repr(bytes)).ok_or(Error)
327 } else {
328 Err(Error)
329 }
330 }
331}
332
#[cfg(all(test, feature = "dev"))]
mod tests {
    use crate::dev::{NonZeroScalar, Scalar};
    use ff::{Field, PrimeField};
    use hex_literal::hex;
    use zeroize::Zeroize;

    // Decoding a fixed byte string and re-encoding it must give back the
    // same bytes.
    #[test]
    fn round_trip() {
        let bytes = hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721");
        let decoded = NonZeroScalar::from_repr(bytes.into()).unwrap();
        let reencoded = decoded.to_repr();
        assert_eq!(&bytes, reencoded.as_slice());
    }

    // After `zeroize()` the wrapper must hold one, not zero.
    #[test]
    fn zeroize() {
        let forty_two = Scalar::from(42u64);
        let mut wrapped = NonZeroScalar::new(forty_two).unwrap();
        wrapped.zeroize();
        assert_eq!(*wrapped, Scalar::one());
    }
}