elliptic_curve/scalar/nonzero.rs

1//! Non-zero scalar type.
2
3use crate::{
4    bigint::Encoding as _,
5    ops::{Invert, Reduce, ReduceNonZero},
6    rand_core::{CryptoRng, RngCore},
7    Curve, Error, FieldBytes, IsHigh, PrimeCurve, Result, Scalar, ScalarArithmetic, ScalarCore,
8    SecretKey,
9};
10use base16ct::HexDisplay;
11use core::{
12    fmt,
13    ops::{Deref, Mul, Neg},
14    str,
15};
16use crypto_bigint::{ArrayEncoding, Integer};
17use ff::{Field, PrimeField};
18use generic_array::GenericArray;
19use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
20use zeroize::Zeroize;
21
22/// Non-zero scalar type.
23///
24/// This type ensures that its value is not zero, ala `core::num::NonZero*`.
25/// To do this, the generic `S` type must impl both `Default` and
26/// `ConstantTimeEq`, with the requirement that `S::default()` returns 0.
27///
28/// In the context of ECC, it's useful for ensuring that scalar multiplication
29/// cannot result in the point at infinity.
30#[cfg_attr(docsrs, doc(cfg(feature = "arithmetic")))]
31#[derive(Clone)]
32pub struct NonZeroScalar<C>
33where
34    C: Curve + ScalarArithmetic,
35{
36    scalar: Scalar<C>,
37}
38
39impl<C> NonZeroScalar<C>
40where
41    C: Curve + ScalarArithmetic,
42{
43    /// Generate a random `NonZeroScalar`.
44    pub fn random(mut rng: impl CryptoRng + RngCore) -> Self {
45        // Use rejection sampling to eliminate zero values.
46        // While this method isn't constant-time, the attacker shouldn't learn
47        // anything about unrelated outputs so long as `rng` is a secure `CryptoRng`.
48        loop {
49            if let Some(result) = Self::new(Field::random(&mut rng)).into() {
50                break result;
51            }
52        }
53    }
54
55    /// Create a [`NonZeroScalar`] from a scalar.
56    pub fn new(scalar: Scalar<C>) -> CtOption<Self> {
57        CtOption::new(Self { scalar }, !scalar.is_zero())
58    }
59
60    /// Decode a [`NonZeroScalar`] from a big endian-serialized field element.
61    pub fn from_repr(repr: FieldBytes<C>) -> CtOption<Self> {
62        Scalar::<C>::from_repr(repr).and_then(Self::new)
63    }
64
65    /// Create a [`NonZeroScalar`] from a `C::UInt`.
66    pub fn from_uint(uint: C::UInt) -> CtOption<Self> {
67        ScalarCore::new(uint).and_then(|scalar| Self::new(scalar.into()))
68    }
69}
70
71impl<C> AsRef<Scalar<C>> for NonZeroScalar<C>
72where
73    C: Curve + ScalarArithmetic,
74{
75    fn as_ref(&self) -> &Scalar<C> {
76        &self.scalar
77    }
78}
79
80impl<C> ConditionallySelectable for NonZeroScalar<C>
81where
82    C: Curve + ScalarArithmetic,
83{
84    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
85        Self {
86            scalar: Scalar::<C>::conditional_select(&a.scalar, &b.scalar, choice),
87        }
88    }
89}
90
91impl<C> ConstantTimeEq for NonZeroScalar<C>
92where
93    C: Curve + ScalarArithmetic,
94{
95    fn ct_eq(&self, other: &Self) -> Choice {
96        self.scalar.ct_eq(&other.scalar)
97    }
98}
99
100impl<C> Copy for NonZeroScalar<C> where C: Curve + ScalarArithmetic {}
101
102impl<C> Deref for NonZeroScalar<C>
103where
104    C: Curve + ScalarArithmetic,
105{
106    type Target = Scalar<C>;
107
108    fn deref(&self) -> &Scalar<C> {
109        &self.scalar
110    }
111}
112
113impl<C> From<NonZeroScalar<C>> for FieldBytes<C>
114where
115    C: Curve + ScalarArithmetic,
116{
117    fn from(scalar: NonZeroScalar<C>) -> FieldBytes<C> {
118        Self::from(&scalar)
119    }
120}
121
122impl<C> From<&NonZeroScalar<C>> for FieldBytes<C>
123where
124    C: Curve + ScalarArithmetic,
125{
126    fn from(scalar: &NonZeroScalar<C>) -> FieldBytes<C> {
127        scalar.to_repr()
128    }
129}
130
131impl<C> From<NonZeroScalar<C>> for ScalarCore<C>
132where
133    C: Curve + ScalarArithmetic,
134{
135    fn from(scalar: NonZeroScalar<C>) -> ScalarCore<C> {
136        ScalarCore::from_be_bytes(scalar.to_repr()).unwrap()
137    }
138}
139
140impl<C> From<&NonZeroScalar<C>> for ScalarCore<C>
141where
142    C: Curve + ScalarArithmetic,
143{
144    fn from(scalar: &NonZeroScalar<C>) -> ScalarCore<C> {
145        ScalarCore::from_be_bytes(scalar.to_repr()).unwrap()
146    }
147}
148
149impl<C> From<SecretKey<C>> for NonZeroScalar<C>
150where
151    C: Curve + ScalarArithmetic,
152{
153    fn from(sk: SecretKey<C>) -> NonZeroScalar<C> {
154        Self::from(&sk)
155    }
156}
157
158impl<C> From<&SecretKey<C>> for NonZeroScalar<C>
159where
160    C: Curve + ScalarArithmetic,
161{
162    fn from(sk: &SecretKey<C>) -> NonZeroScalar<C> {
163        let scalar = sk.as_scalar_core().to_scalar();
164        debug_assert!(!bool::from(scalar.is_zero()));
165        Self { scalar }
166    }
167}
168
169impl<C> Invert for NonZeroScalar<C>
170where
171    C: Curve + ScalarArithmetic,
172{
173    type Output = Self;
174
175    fn invert(&self) -> Self {
176        Self {
177            // This will always succeed since `scalar` will never be 0
178            scalar: ff::Field::invert(&self.scalar).unwrap(),
179        }
180    }
181}
182
183impl<C> IsHigh for NonZeroScalar<C>
184where
185    C: Curve + ScalarArithmetic,
186{
187    fn is_high(&self) -> Choice {
188        self.scalar.is_high()
189    }
190}
191
192impl<C> Neg for NonZeroScalar<C>
193where
194    C: Curve + ScalarArithmetic,
195{
196    type Output = NonZeroScalar<C>;
197
198    fn neg(self) -> NonZeroScalar<C> {
199        let scalar = -self.scalar;
200        debug_assert!(!bool::from(scalar.is_zero()));
201        NonZeroScalar { scalar }
202    }
203}
204
205impl<C> Mul<NonZeroScalar<C>> for NonZeroScalar<C>
206where
207    C: PrimeCurve + ScalarArithmetic,
208{
209    type Output = Self;
210
211    #[inline]
212    fn mul(self, other: Self) -> Self {
213        Self::mul(self, &other)
214    }
215}
216
217impl<C> Mul<&NonZeroScalar<C>> for NonZeroScalar<C>
218where
219    C: PrimeCurve + ScalarArithmetic,
220{
221    type Output = Self;
222
223    fn mul(self, other: &Self) -> Self {
224        // Multiplication is modulo a prime, so the product of two non-zero
225        // scalars is also non-zero.
226        let scalar = self.scalar * other.scalar;
227        debug_assert!(!bool::from(scalar.is_zero()));
228        NonZeroScalar { scalar }
229    }
230}
231
232/// Note: implementation is the same as `ReduceNonZero`
233impl<C, I> Reduce<I> for NonZeroScalar<C>
234where
235    C: Curve + ScalarArithmetic,
236    I: Integer + ArrayEncoding,
237    Scalar<C>: ReduceNonZero<I>,
238{
239    fn from_uint_reduced(n: I) -> Self {
240        Self::from_uint_reduced_nonzero(n)
241    }
242}
243
244impl<C, I> ReduceNonZero<I> for NonZeroScalar<C>
245where
246    C: Curve + ScalarArithmetic,
247    I: Integer + ArrayEncoding,
248    Scalar<C>: ReduceNonZero<I>,
249{
250    fn from_uint_reduced_nonzero(n: I) -> Self {
251        let scalar = Scalar::<C>::from_uint_reduced_nonzero(n);
252        debug_assert!(!bool::from(scalar.is_zero()));
253        Self::new(scalar).unwrap()
254    }
255}
256
257impl<C> TryFrom<&[u8]> for NonZeroScalar<C>
258where
259    C: Curve + ScalarArithmetic,
260{
261    type Error = Error;
262
263    fn try_from(bytes: &[u8]) -> Result<Self> {
264        if bytes.len() == C::UInt::BYTE_SIZE {
265            Option::from(NonZeroScalar::from_repr(GenericArray::clone_from_slice(
266                bytes,
267            )))
268            .ok_or(Error)
269        } else {
270            Err(Error)
271        }
272    }
273}
274
275impl<C> Zeroize for NonZeroScalar<C>
276where
277    C: Curve + ScalarArithmetic,
278{
279    fn zeroize(&mut self) {
280        // Use zeroize's volatile writes to ensure value is cleared.
281        self.scalar.zeroize();
282
283        // Write a 1 instead of a 0 to ensure this type's non-zero invariant
284        // is upheld.
285        self.scalar = Scalar::<C>::one();
286    }
287}
288
289impl<C> fmt::Display for NonZeroScalar<C>
290where
291    C: Curve + ScalarArithmetic,
292{
293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294        write!(f, "{:X}", self)
295    }
296}
297
298impl<C> fmt::LowerHex for NonZeroScalar<C>
299where
300    C: Curve + ScalarArithmetic,
301{
302    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303        write!(f, "{:x}", HexDisplay(&self.to_repr()))
304    }
305}
306
307impl<C> fmt::UpperHex for NonZeroScalar<C>
308where
309    C: Curve + ScalarArithmetic,
310{
311    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312        write!(f, "{:}", HexDisplay(&self.to_repr()))
313    }
314}
315
316impl<C> str::FromStr for NonZeroScalar<C>
317where
318    C: Curve + ScalarArithmetic,
319{
320    type Err = Error;
321
322    fn from_str(hex: &str) -> Result<Self> {
323        let mut bytes = FieldBytes::<C>::default();
324
325        if base16ct::mixed::decode(hex, &mut bytes)?.len() == bytes.len() {
326            Option::from(Self::from_repr(bytes)).ok_or(Error)
327        } else {
328            Err(Error)
329        }
330    }
331}
332
#[cfg(all(test, feature = "dev"))]
mod tests {
    use crate::dev::{NonZeroScalar, Scalar};
    use ff::{Field, PrimeField};
    use hex_literal::hex;
    use zeroize::Zeroize;

    /// A known-valid, non-zero scalar encoding for the mock curve.
    const EXAMPLE_SCALAR: [u8; 32] =
        hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721");

    #[test]
    fn round_trip() {
        let bytes = EXAMPLE_SCALAR;
        let scalar = NonZeroScalar::from_repr(bytes.into()).unwrap();
        assert_eq!(&bytes, scalar.to_repr().as_slice());
    }

    #[test]
    fn zeroize() {
        let mut scalar = NonZeroScalar::new(Scalar::from(42u64)).unwrap();
        scalar.zeroize();
        assert_eq!(*scalar, Scalar::one());
    }

    /// The type's core invariant: zero is never accepted.
    #[test]
    fn rejects_zero() {
        assert!(bool::from(NonZeroScalar::new(Scalar::zero()).is_none()));
        assert!(bool::from(
            NonZeroScalar::from_repr(Default::default()).is_none()
        ));
    }

    /// Slices of the wrong length are rejected with an error, not a panic.
    #[test]
    fn try_from_checks_length() {
        let bytes = EXAMPLE_SCALAR;
        assert!(NonZeroScalar::try_from(&bytes[..]).is_ok());
        assert!(NonZeroScalar::try_from(&bytes[..31]).is_err());
        assert!(NonZeroScalar::try_from(&[0x01u8; 33][..]).is_err());
    }
}