p256/arithmetic/
scalar.rs

1//! Scalar field arithmetic modulo n = 115792089210356248762697446949407573529996955224135760342422259061068512044369
2
3pub mod blinded;
4
5use crate::{
6    arithmetic::util::{adc, mac, sbb},
7    FieldBytes, NistP256, SecretKey,
8};
9use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
10use elliptic_curve::{
11    bigint::{prelude::*, Limb, U256},
12    generic_array::arr,
13    group::ff::{Field, PrimeField},
14    ops::{Reduce, ReduceNonZero},
15    rand_core::RngCore,
16    subtle::{
17        Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
18        CtOption,
19    },
20    zeroize::DefaultIsZeroes,
21    Curve, IsHigh, ScalarArithmetic, ScalarCore,
22};
23
24#[cfg(feature = "bits")]
25use {crate::ScalarBits, elliptic_curve::group::ff::PrimeFieldBits};
26
27#[cfg(feature = "serde")]
28use serdect::serde::{de, ser, Deserialize, Serialize};
29
/// Array containing 4 x 64-bit unsigned integers.
///
/// Used as the working representation of a 256-bit value by the hand-written
/// limb arithmetic in this module (see `u256_to_u64x4`).
// NOTE(review): limbs appear to be ordered least-significant first (as in `MU`) — confirm.
// TODO(tarcieri): replace this entirely with `U256`
type U64x4 = [u64; 4];
33
/// Constant representing the modulus
/// n = FFFFFFFF 00000000 FFFFFFFF FFFFFFFF BCE6FAAD A7179E84 F3B9CAC2 FC632551
///
/// This is `NistP256::ORDER` converted to four 64-bit limbs so that the
/// Barrett reduction below can index individual limbs.
const MODULUS: U64x4 = u256_to_u64x4(NistP256::ORDER);
37
/// The group order shifted right by one bit, i.e. `(n - 1) / 2` since n is odd.
///
/// Used as the threshold in `IsHigh` and for halving odd values modulo n in
/// `Scalar::invert_vartime`.
const FRAC_MODULUS_2: Scalar = Scalar(NistP256::ORDER.shr_vartime(1));
39
/// MU = floor(2^512 / n)
///    = 115792089264276142090721624801893421302707618245269942344307673200490803338238
///    = 0x100000000fffffffffffffffeffffffff43190552df1a6c21012ffd85eedf9bfe
///
/// Precomputed Barrett constant, stored as five 64-bit limbs with the
/// least-significant limb first (`MU[4]` holds the leading `1`).
pub const MU: [u64; 5] = [
    0x012f_fd85_eedf_9bfe,
    0x4319_0552_df1a_6c21,
    0xffff_fffe_ffff_ffff,
    0x0000_0000_ffff_ffff,
    0x0000_0000_0000_0001,
];
50
// Registers this module's `Scalar` as the scalar type of the `NistP256` curve.
impl ScalarArithmetic for NistP256 {
    type Scalar = Scalar;
}
54
/// Scalars are elements in the finite field modulo n.
///
/// A scalar is a thin wrapper around a [`U256`]; the arithmetic in this module
/// keeps the wrapped integer reduced modulo n.
///
/// # Trait impls
///
/// Much of the important functionality of scalars is provided by traits from
/// the [`ff`](https://docs.rs/ff/) crate, which is re-exported as
/// `p256::elliptic_curve::ff`:
///
/// - [`Field`](https://docs.rs/ff/latest/ff/trait.Field.html) -
///   represents elements of finite fields and provides:
///   - [`Field::random`](https://docs.rs/ff/latest/ff/trait.Field.html#tymethod.random) -
///     generate a random scalar
///   - `double`, `square`, and `invert` operations
///   - Bounds for [`Add`], [`Sub`], [`Mul`], and [`Neg`] (as well as `*Assign` equivalents)
///   - Bounds for [`ConditionallySelectable`] from the `subtle` crate
/// - [`PrimeField`](https://docs.rs/ff/latest/ff/trait.PrimeField.html) -
///   represents elements of prime fields and provides:
///   - `from_repr`/`to_repr` for converting field elements from/to big integers.
///   - `multiplicative_generator` and `root_of_unity` constants.
/// - [`PrimeFieldBits`](https://docs.rs/ff/latest/ff/trait.PrimeFieldBits.html) -
///   operations over field elements represented as bits (requires `bits` feature)
///
/// Please see the documentation for the relevant traits for more information.
///
/// # `serde` support
///
/// When the `serde` feature of this crate is enabled, the `Serialize` and
/// `Deserialize` traits are impl'd for this type.
///
/// The serialization is a fixed-width big endian encoding. When used with
/// textual formats, the binary data is encoded as hexadecimal.
#[derive(Clone, Copy, Debug, Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "arithmetic")))]
pub struct Scalar(pub(crate) U256);
89
90impl Scalar {
    /// Zero scalar (the additive identity).
    pub const ZERO: Self = Self(U256::ZERO);

    /// Multiplicative identity.
    pub const ONE: Self = Self(U256::ONE);
96
97    /// Returns the SEC1 encoding of this scalar.
98    pub fn to_bytes(&self) -> FieldBytes {
99        self.0.to_be_byte_array()
100    }
101
102    /// Returns self + rhs mod n
103    pub const fn add(&self, rhs: &Self) -> Self {
104        Self(self.0.add_mod(&rhs.0, &NistP256::ORDER))
105    }
106
107    /// Returns 2*self.
108    pub const fn double(&self) -> Self {
109        self.add(self)
110    }
111
112    /// Returns self - rhs mod n.
113    pub const fn sub(&self, rhs: &Self) -> Self {
114        Self(self.0.sub_mod(&rhs.0, &NistP256::ORDER))
115    }
116
117    /// Returns self * rhs mod n
118    pub const fn mul(&self, rhs: &Self) -> Self {
119        let (lo, hi) = self.0.mul_wide(&rhs.0);
120        Self::barrett_reduce(lo, hi)
121    }
122
123    /// Returns self * self mod p
124    pub const fn square(&self) -> Self {
125        // Schoolbook multiplication.
126        self.mul(self)
127    }
128
129    /// Returns the multiplicative inverse of self, if self is non-zero
130    pub fn invert(&self) -> CtOption<Self> {
131        // We need to find b such that b * a ≡ 1 mod p. As we are in a prime
132        // field, we can apply Fermat's Little Theorem:
133        //
134        //    a^p         ≡ a mod p
135        //    a^(p-1)     ≡ 1 mod p
136        //    a^(p-2) * a ≡ 1 mod p
137        //
138        // Thus inversion can be implemented with a single exponentiation.
139        //
140        // This is `n - 2`, so the top right two digits are `4f` instead of `51`.
141        let inverse = self.pow_vartime(&[
142            0xf3b9_cac2_fc63_254f,
143            0xbce6_faad_a717_9e84,
144            0xffff_ffff_ffff_ffff,
145            0xffff_ffff_0000_0000,
146        ]);
147
148        CtOption::new(inverse, !self.is_zero())
149    }
150
151    /// Faster inversion using Stein's algorithm
152    #[allow(non_snake_case)]
153    pub fn invert_vartime(&self) -> CtOption<Self> {
154        // https://link.springer.com/article/10.1007/s13389-016-0135-4
155
156        let mut u = *self;
157        // currently an invalid scalar
158        let mut v = Scalar(NistP256::ORDER);
159        let mut A = Self::one();
160        let mut C = Self::zero();
161
162        while !bool::from(u.is_zero()) {
163            // u-loop
164            while bool::from(u.is_even()) {
165                u.shr1();
166
167                let was_odd: bool = A.is_odd().into();
168                A.shr1();
169
170                if was_odd {
171                    A += FRAC_MODULUS_2;
172                    A += Self::one();
173                }
174            }
175
176            // v-loop
177            while bool::from(v.is_even()) {
178                v.shr1();
179
180                let was_odd: bool = C.is_odd().into();
181                C.shr1();
182
183                if was_odd {
184                    C += FRAC_MODULUS_2;
185                    C += Self::one();
186                }
187            }
188
189            // sub-step
190            if u >= v {
191                u -= &v;
192                A -= &C;
193            } else {
194                v -= &u;
195                C -= &A;
196            }
197        }
198
199        CtOption::new(C, !self.is_zero())
200    }
201
202    /// Is integer representing equivalence class odd?
203    pub fn is_odd(&self) -> Choice {
204        self.0.is_odd()
205    }
206
207    /// Is integer representing equivalence class even?
208    pub fn is_even(&self) -> Choice {
209        !self.is_odd()
210    }
211
    /// Barrett Reduction
    ///
    /// Reduces the 512-bit value `hi * 2^256 + lo` modulo n.
    ///
    /// The general algorithm is:
    /// ```text
    /// p = n = order of group
    /// b = 2^64 = 64bit machine word
    /// k = 4
    /// a \in [0, 2^512)
    /// mu := floor(b^{2k} / p)
    /// q1 := floor(a / b^{k - 1})
    /// q2 := q1 * mu
    /// q3 := floor(q2 / b^{k + 1})
    /// r1 := a mod b^{k + 1}
    /// r2 := q3 * p mod b^{k + 1}
    /// r := r1 - r2
    ///
    /// if r < 0: r := r + b^{k + 1}
    /// while r >= p: do r := r - p (at most twice)
    /// ```
    ///
    /// References:
    /// - Handbook of Applied Cryptography, Chapter 14
    ///   Algorithm 14.42
    ///   http://cacr.uwaterloo.ca/hac/about/chap14.pdf
    ///
    /// - Efficient and Secure Elliptic Curve Cryptography Implementation of Curve P-256
    ///   Algorithm 6) Barrett Reduction modulo p
    ///   https://csrc.nist.gov/csrc/media/events/workshop-on-elliptic-curve-cryptography-standards/documents/papers/session6-adalier-mehmet.pdf
    #[inline]
    #[allow(clippy::too_many_arguments)]
    const fn barrett_reduce(lo: U256, hi: U256) -> Self {
        // Split the 512-bit input into eight 64-bit limbs a0 (lowest) .. a7 (highest).
        let lo = u256_to_u64x4(lo);
        let hi = u256_to_u64x4(hi);
        let a0 = lo[0];
        let a1 = lo[1];
        let a2 = lo[2];
        let a3 = lo[3];
        let a4 = hi[0];
        let a5 = hi[1];
        let a6 = hi[2];
        let a7 = hi[3];
        // q1 = floor(a / b^3): the input with its three lowest limbs dropped.
        let q1: [u64; 5] = [a3, a4, a5, a6, a7];

        // Computes q3 = floor(q1 * MU / b^5): the full ten-limb product is
        // accumulated row by row and only its five highest limbs are returned.
        const fn q1_times_mu_shift_five(q1: &[u64; 5]) -> [u64; 5] {
            // Schoolbook multiplication.

            let (_w0, carry) = mac(0, q1[0], MU[0], 0);
            let (w1, carry) = mac(0, q1[0], MU[1], carry);
            let (w2, carry) = mac(0, q1[0], MU[2], carry);
            let (w3, carry) = mac(0, q1[0], MU[3], carry);
            let (w4, w5) = mac(0, q1[0], MU[4], carry);

            let (_w1, carry) = mac(w1, q1[1], MU[0], 0);
            let (w2, carry) = mac(w2, q1[1], MU[1], carry);
            let (w3, carry) = mac(w3, q1[1], MU[2], carry);
            let (w4, carry) = mac(w4, q1[1], MU[3], carry);
            let (w5, w6) = mac(w5, q1[1], MU[4], carry);

            let (_w2, carry) = mac(w2, q1[2], MU[0], 0);
            let (w3, carry) = mac(w3, q1[2], MU[1], carry);
            let (w4, carry) = mac(w4, q1[2], MU[2], carry);
            let (w5, carry) = mac(w5, q1[2], MU[3], carry);
            let (w6, w7) = mac(w6, q1[2], MU[4], carry);

            let (_w3, carry) = mac(w3, q1[3], MU[0], 0);
            let (w4, carry) = mac(w4, q1[3], MU[1], carry);
            let (w5, carry) = mac(w5, q1[3], MU[2], carry);
            let (w6, carry) = mac(w6, q1[3], MU[3], carry);
            let (w7, w8) = mac(w7, q1[3], MU[4], carry);

            let (_w4, carry) = mac(w4, q1[4], MU[0], 0);
            let (w5, carry) = mac(w5, q1[4], MU[1], carry);
            let (w6, carry) = mac(w6, q1[4], MU[2], carry);
            let (w7, carry) = mac(w7, q1[4], MU[3], carry);
            let (w8, w9) = mac(w8, q1[4], MU[4], carry);

            // let q2 = [_w0, _w1, _w2, _w3, _w4, w5, w6, w7, w8, w9];
            [w5, w6, w7, w8, w9]
        }

        let q3 = q1_times_mu_shift_five(&q1);

        // r1 = a mod b^5: the five lowest limbs of the input.
        let r1: [u64; 5] = [a0, a1, a2, a3, a4];

        // Computes r2 = q3 * n mod b^5: only partial products that land in the
        // five lowest limbs are accumulated; carries out of limb 4 are discarded.
        const fn q3_times_n_keep_five(q3: &[u64; 5]) -> [u64; 5] {
            // Schoolbook multiplication.

            let (w0, carry) = mac(0, q3[0], MODULUS[0], 0);
            let (w1, carry) = mac(0, q3[0], MODULUS[1], carry);
            let (w2, carry) = mac(0, q3[0], MODULUS[2], carry);
            let (w3, carry) = mac(0, q3[0], MODULUS[3], carry);
            let (w4, _) = mac(0, q3[0], 0, carry);

            let (w1, carry) = mac(w1, q3[1], MODULUS[0], 0);
            let (w2, carry) = mac(w2, q3[1], MODULUS[1], carry);
            let (w3, carry) = mac(w3, q3[1], MODULUS[2], carry);
            let (w4, _) = mac(w4, q3[1], MODULUS[3], carry);

            let (w2, carry) = mac(w2, q3[2], MODULUS[0], 0);
            let (w3, carry) = mac(w3, q3[2], MODULUS[1], carry);
            let (w4, _) = mac(w4, q3[2], MODULUS[2], carry);

            let (w3, carry) = mac(w3, q3[3], MODULUS[0], 0);
            let (w4, _) = mac(w4, q3[3], MODULUS[1], carry);

            let (w4, _) = mac(w4, q3[4], MODULUS[0], 0);

            [w0, w1, w2, w3, w4]
        }

        let r2: [u64; 5] = q3_times_n_keep_five(&q3);

        // Computes l - r over five limbs, wrapping modulo b^5.
        #[inline]
        #[allow(clippy::too_many_arguments)]
        const fn sub_inner_five(l: [u64; 5], r: [u64; 5]) -> [u64; 5] {
            let (w0, borrow) = sbb(l[0], r[0], 0);
            let (w1, borrow) = sbb(l[1], r[1], borrow);
            let (w2, borrow) = sbb(l[2], r[2], borrow);
            let (w3, borrow) = sbb(l[3], r[3], borrow);
            let (w4, _borrow) = sbb(l[4], r[4], borrow);

            // If underflow occurred on the final limb - don't care (= add b^{k+1}).
            [w0, w1, w2, w3, w4]
        }

        let r: [u64; 5] = sub_inner_five(r1, r2);

        // Subtracts n from the five-limb value and adds it back (masked by the
        // final borrow) when the subtraction underflowed, without branching.
        #[inline]
        #[allow(clippy::too_many_arguments)]
        const fn subtract_n_if_necessary(r0: u64, r1: u64, r2: u64, r3: u64, r4: u64) -> [u64; 5] {
            let (w0, borrow) = sbb(r0, MODULUS[0], 0);
            let (w1, borrow) = sbb(r1, MODULUS[1], borrow);
            let (w2, borrow) = sbb(r2, MODULUS[2], borrow);
            let (w3, borrow) = sbb(r3, MODULUS[3], borrow);
            let (w4, borrow) = sbb(r4, 0, borrow);

            // If underflow occurred on the final limb, borrow = 0xfff...fff, otherwise
            // borrow = 0x000...000. Thus, we use it as a mask to conditionally add the
            // modulus.
            let (w0, carry) = adc(w0, MODULUS[0] & borrow, 0);
            let (w1, carry) = adc(w1, MODULUS[1] & borrow, carry);
            let (w2, carry) = adc(w2, MODULUS[2] & borrow, carry);
            let (w3, carry) = adc(w3, MODULUS[3] & borrow, carry);
            let (w4, _carry) = adc(w4, 0, carry);

            [w0, w1, w2, w3, w4]
        }

        // Result is in range (0, 3*n - 1),
        // and 90% of the time, no subtraction will be needed.
        // Two conditional subtractions are always performed to cover the worst case.
        let r = subtract_n_if_necessary(r[0], r[1], r[2], r[3], r[4]);
        let r = subtract_n_if_necessary(r[0], r[1], r[2], r[3], r[4]);
        // The fifth limb is dropped: only the four low limbs form the result.
        Scalar::from_u64x4_unchecked([r[0], r[1], r[2], r[3]])
    }
366
367    /// Perform unchecked conversion from a U64x4 to a Scalar.
368    ///
369    /// Note: this does *NOT* ensure that the provided value is less than `MODULUS`.
370    // TODO(tarcieri): implement all algorithms in terms of `U256`?
371    #[cfg(target_pointer_width = "32")]
372    const fn from_u64x4_unchecked(limbs: U64x4) -> Self {
373        Self(U256::from_uint_array([
374            (limbs[0] & 0xFFFFFFFF) as u32,
375            (limbs[0] >> 32) as u32,
376            (limbs[1] & 0xFFFFFFFF) as u32,
377            (limbs[1] >> 32) as u32,
378            (limbs[2] & 0xFFFFFFFF) as u32,
379            (limbs[2] >> 32) as u32,
380            (limbs[3] & 0xFFFFFFFF) as u32,
381            (limbs[3] >> 32) as u32,
382        ]))
383    }
384
385    /// Perform unchecked conversion from a U64x4 to a Scalar.
386    ///
387    /// Note: this does *NOT* ensure that the provided value is less than `MODULUS`.
388    // TODO(tarcieri): implement all algorithms in terms of `U256`?
389    #[cfg(target_pointer_width = "64")]
390    const fn from_u64x4_unchecked(limbs: U64x4) -> Self {
391        Self(U256::from_uint_array(limbs))
392    }
393
394    /// Shift right by one bit
395    fn shr1(&mut self) {
396        self.0 >>= 1;
397    }
398}
399
400impl Field for Scalar {
401    fn random(mut rng: impl RngCore) -> Self {
402        let mut bytes = FieldBytes::default();
403
404        // Generate a uniformly random scalar using rejection sampling,
405        // which produces a uniformly random distribution of scalars.
406        //
407        // This method is not constant time, but should be secure so long as
408        // rejected RNG outputs are unrelated to future ones (which is a
409        // necessary property of a `CryptoRng`).
410        //
411        // With an unbiased RNG, the probability of failing to complete after 4
412        // iterations is vanishingly small.
413        loop {
414            rng.fill_bytes(&mut bytes);
415            if let Some(scalar) = Scalar::from_repr(bytes).into() {
416                return scalar;
417            }
418        }
419    }
420
421    fn zero() -> Self {
422        Self::ZERO
423    }
424
425    fn one() -> Self {
426        Self::ONE
427    }
428
    /// Squares this scalar modulo n.
    ///
    /// Delegates to the inherent `Scalar::square`; the explicit path avoids
    /// any ambiguity with this trait method of the same name.
    #[must_use]
    fn square(&self) -> Self {
        Scalar::square(self)
    }
433
434    #[must_use]
435    fn double(&self) -> Self {
436        self.add(self)
437    }
438
    /// Returns the multiplicative inverse of this scalar, if it is non-zero.
    ///
    /// Delegates to the inherent `Scalar::invert`; the explicit path avoids
    /// any ambiguity with this trait method of the same name.
    fn invert(&self) -> CtOption<Self> {
        Scalar::invert(self)
    }
442
    /// Tonelli-Shank's algorithm for q mod 16 = 1
    /// <https://eprint.iacr.org/2012/685.pdf> (page 12, algorithm 5)
    ///
    /// Returns a square root of `self` wrapped in a `CtOption`; the option is
    /// empty when the computed candidate does not square back to `self`.
    #[allow(clippy::many_single_char_names)]
    fn sqrt(&self) -> CtOption<Self> {
        // Note: `pow_vartime` is constant-time with respect to `self`
        //
        // The exponent (64-bit limbs, least-significant first) is n >> 5, i.e.
        // (t - 1) / 2 where n - 1 = 2^S * t with S = 4.
        let w = self.pow_vartime(&[
            0x279dce5617e3192a,
            0xfde737d56d38bcf4,
            0x07ffffffffffffff,
            0x07fffffff8000000,
        ]);

        let mut v = Self::S;
        let mut x = *self * w;
        let mut b = x * w;
        let mut z = Self::root_of_unity();

        // The loop bounds depend only on the constant `S`; every data-dependent
        // decision below is made with `conditional_select` rather than a branch.
        for max_v in (1..=Self::S).rev() {
            let mut k = 1;
            let mut tmp = b.square();
            let mut j_less_than_v = Choice::from(1);

            for j in 2..max_v {
                let tmp_is_one = tmp.ct_eq(&Self::one());
                // Square `z` once `tmp` has reached one, otherwise keep squaring `tmp`.
                let squared = Self::conditional_select(&tmp, &z, tmp_is_one).square();
                tmp = Self::conditional_select(&squared, &tmp, tmp_is_one);
                let new_z = Self::conditional_select(&z, &squared, tmp_is_one);
                // Becomes (and stays) false from the iteration where j == v onwards.
                j_less_than_v &= !j.ct_eq(&v);
                k = u32::conditional_select(&j, &k, tmp_is_one);
                z = Self::conditional_select(&z, &new_z, j_less_than_v);
            }

            // Multiply `z` into the candidate root unless `b` is already one.
            let result = x * z;
            x = Self::conditional_select(&result, &x, b.ct_eq(&Self::one()));
            z = z.square();
            b *= z;
            v = k;
        }

        // Only report `x` if it really is a square root of `self`.
        CtOption::new(x, x.square().ct_eq(self))
    }
484}
485
486impl PrimeField for Scalar {
487    type Repr = FieldBytes;
488
489    const NUM_BITS: u32 = 256;
490    const CAPACITY: u32 = 255;
491    const S: u32 = 4;
492
493    /// Attempts to parse the given byte array as an SEC1-encoded scalar.
494    ///
495    /// Returns None if the byte array does not contain a big-endian integer in the range
496    /// [0, p).
497    fn from_repr(bytes: FieldBytes) -> CtOption<Self> {
498        let inner = U256::from_be_byte_array(bytes);
499        CtOption::new(Self(inner), inner.ct_lt(&NistP256::ORDER))
500    }
501
502    fn to_repr(&self) -> FieldBytes {
503        self.to_bytes()
504    }
505
506    fn is_odd(&self) -> Choice {
507        self.0.is_odd()
508    }
509
510    fn multiplicative_generator() -> Self {
511        7u64.into()
512    }
513
514    fn root_of_unity() -> Self {
515        Scalar::from_repr(arr![u8;
516            0xff, 0xc9, 0x7f, 0x06, 0x2a, 0x77, 0x09, 0x92, 0xba, 0x80, 0x7a, 0xce, 0x84, 0x2a,
517            0x3d, 0xfc, 0x15, 0x46, 0xca, 0xd0, 0x04, 0x37, 0x8d, 0xaf, 0x05, 0x92, 0xd7, 0xfb,
518            0xb4, 0x1e, 0x66, 0x02,
519        ])
520        .unwrap()
521    }
522}
523
#[cfg(feature = "bits")]
#[cfg_attr(docsrs, doc(cfg(feature = "bits")))]
impl PrimeFieldBits for Scalar {
    #[cfg(target_pointer_width = "32")]
    type ReprBits = [u32; 8];

    #[cfg(target_pointer_width = "64")]
    type ReprBits = [u64; 4];

    /// Returns the bits of this scalar, built from its native-width limbs.
    fn to_le_bits(&self) -> ScalarBits {
        ScalarBits::from(self)
    }

    /// Returns the bits of the group order n, built from its native-width limbs.
    fn char_le_bits() -> ScalarBits {
        let order_limbs = NistP256::ORDER.to_uint_array();
        order_limbs.into()
    }
}
541
// Marker impl with no methods.
// NOTE(review): presumably lets `zeroize` wipe a scalar by overwriting it with
// `Scalar::default()` — confirm against the `zeroize` documentation.
impl DefaultIsZeroes for Scalar {}
543
// Marker impl: scalar equality (see the `PartialEq` impl) is a total equivalence.
impl Eq for Scalar {}
545
546impl IsHigh for Scalar {
547    fn is_high(&self) -> Choice {
548        self.0.ct_gt(&FRAC_MODULUS_2.0)
549    }
550}
551
552impl PartialEq for Scalar {
553    fn eq(&self, other: &Self) -> bool {
554        self.ct_eq(other).into()
555    }
556}
557
558impl PartialOrd for Scalar {
559    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
560        Some(self.cmp(other))
561    }
562}
563
564impl Ord for Scalar {
565    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
566        self.0.cmp(&other.0)
567    }
568}
569
570impl From<u64> for Scalar {
571    fn from(k: u64) -> Self {
572        Scalar(k.into())
573    }
574}
575
576impl From<Scalar> for FieldBytes {
577    fn from(scalar: Scalar) -> Self {
578        scalar.to_bytes()
579    }
580}
581
582impl From<&Scalar> for FieldBytes {
583    fn from(scalar: &Scalar) -> Self {
584        scalar.to_bytes()
585    }
586}
587
588impl From<ScalarCore<NistP256>> for Scalar {
589    fn from(scalar: ScalarCore<NistP256>) -> Scalar {
590        Scalar(*scalar.as_uint())
591    }
592}
593
594impl From<&ScalarCore<NistP256>> for Scalar {
595    fn from(scalar: &ScalarCore<NistP256>) -> Scalar {
596        Scalar(*scalar.as_uint())
597    }
598}
599
600impl From<Scalar> for ScalarCore<NistP256> {
601    fn from(scalar: Scalar) -> ScalarCore<NistP256> {
602        ScalarCore::from(&scalar)
603    }
604}
605
606impl From<&Scalar> for ScalarCore<NistP256> {
607    fn from(scalar: &Scalar) -> ScalarCore<NistP256> {
608        ScalarCore::new(scalar.0).unwrap()
609    }
610}
611
612impl From<&SecretKey> for Scalar {
613    fn from(secret_key: &SecretKey) -> Scalar {
614        *secret_key.to_nonzero_scalar()
615    }
616}
617
618impl From<Scalar> for U256 {
619    fn from(scalar: Scalar) -> U256 {
620        scalar.0
621    }
622}
623
624impl From<&Scalar> for U256 {
625    fn from(scalar: &Scalar) -> U256 {
626        scalar.0
627    }
628}
629
#[cfg(feature = "bits")]
#[cfg_attr(docsrs, doc(cfg(feature = "bits")))]
impl From<&Scalar> for ScalarBits {
    /// Builds the bit representation of a scalar from its native-width limbs.
    fn from(scalar: &Scalar) -> ScalarBits {
        let limbs = scalar.0.to_uint_array();
        limbs.into()
    }
}
637
638impl Add<Scalar> for Scalar {
639    type Output = Scalar;
640
641    fn add(self, other: Scalar) -> Scalar {
642        Scalar::add(&self, &other)
643    }
644}
645
646impl Add<&Scalar> for &Scalar {
647    type Output = Scalar;
648
649    fn add(self, other: &Scalar) -> Scalar {
650        Scalar::add(self, other)
651    }
652}
653
654impl Add<&Scalar> for Scalar {
655    type Output = Scalar;
656
657    fn add(self, other: &Scalar) -> Scalar {
658        Scalar::add(&self, other)
659    }
660}
661
662impl AddAssign<Scalar> for Scalar {
663    fn add_assign(&mut self, rhs: Scalar) {
664        *self = Scalar::add(self, &rhs);
665    }
666}
667
668impl AddAssign<&Scalar> for Scalar {
669    fn add_assign(&mut self, rhs: &Scalar) {
670        *self = Scalar::add(self, rhs);
671    }
672}
673
674impl Sub<Scalar> for Scalar {
675    type Output = Scalar;
676
677    fn sub(self, other: Scalar) -> Scalar {
678        Scalar::sub(&self, &other)
679    }
680}
681
682impl Sub<&Scalar> for &Scalar {
683    type Output = Scalar;
684
685    fn sub(self, other: &Scalar) -> Scalar {
686        Scalar::sub(self, other)
687    }
688}
689
690impl Sub<&Scalar> for Scalar {
691    type Output = Scalar;
692
693    fn sub(self, other: &Scalar) -> Scalar {
694        Scalar::sub(&self, other)
695    }
696}
697
698impl SubAssign<Scalar> for Scalar {
699    fn sub_assign(&mut self, rhs: Scalar) {
700        *self = Scalar::sub(self, &rhs);
701    }
702}
703
704impl SubAssign<&Scalar> for Scalar {
705    fn sub_assign(&mut self, rhs: &Scalar) {
706        *self = Scalar::sub(self, rhs);
707    }
708}
709
710impl Mul<Scalar> for Scalar {
711    type Output = Scalar;
712
713    fn mul(self, other: Scalar) -> Scalar {
714        Scalar::mul(&self, &other)
715    }
716}
717
718impl Mul<&Scalar> for &Scalar {
719    type Output = Scalar;
720
721    fn mul(self, other: &Scalar) -> Scalar {
722        Scalar::mul(self, other)
723    }
724}
725
726impl Mul<&Scalar> for Scalar {
727    type Output = Scalar;
728
729    fn mul(self, other: &Scalar) -> Scalar {
730        Scalar::mul(&self, other)
731    }
732}
733
734impl MulAssign<Scalar> for Scalar {
735    fn mul_assign(&mut self, rhs: Scalar) {
736        *self = Scalar::mul(self, &rhs);
737    }
738}
739
740impl MulAssign<&Scalar> for Scalar {
741    fn mul_assign(&mut self, rhs: &Scalar) {
742        *self = Scalar::mul(self, rhs);
743    }
744}
745
746impl Neg for Scalar {
747    type Output = Scalar;
748
749    fn neg(self) -> Scalar {
750        Scalar::zero() - self
751    }
752}
753
754impl<'a> Neg for &'a Scalar {
755    type Output = Scalar;
756
757    fn neg(self) -> Scalar {
758        Scalar::zero() - self
759    }
760}
761
762impl Reduce<U256> for Scalar {
763    fn from_uint_reduced(w: U256) -> Self {
764        let (r, underflow) = w.sbb(&NistP256::ORDER, Limb::ZERO);
765        let underflow = Choice::from((underflow.0 >> (Limb::BIT_SIZE - 1)) as u8);
766        Self(U256::conditional_select(&w, &r, !underflow))
767    }
768}
769
770impl ReduceNonZero<U256> for Scalar {
771    fn from_uint_reduced_nonzero(w: U256) -> Self {
772        const ORDER_MINUS_ONE: U256 = NistP256::ORDER.wrapping_sub(&U256::ONE);
773        let (r, underflow) = w.sbb(&ORDER_MINUS_ONE, Limb::ZERO);
774        let underflow = Choice::from((underflow.0 >> (Limb::BIT_SIZE - 1)) as u8);
775        Self(U256::conditional_select(&w, &r, !underflow).wrapping_add(&U256::ONE))
776    }
777}
778
779impl ConditionallySelectable for Scalar {
780    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
781        Self(U256::conditional_select(&a.0, &b.0, choice))
782    }
783}
784
785impl ConstantTimeEq for Scalar {
786    fn ct_eq(&self, other: &Self) -> Choice {
787        self.0.ct_eq(&other.0)
788    }
789}
790
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl Serialize for Scalar {
    /// Serializes the scalar by converting it to a `ScalarCore` and reusing
    /// that type's `Serialize` implementation.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let core = ScalarCore::<NistP256>::from(self);
        core.serialize(serializer)
    }
}
801
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de> Deserialize<'de> for Scalar {
    /// Deserializes a `ScalarCore` and converts it into a `Scalar`; any error
    /// from the underlying deserializer is propagated unchanged.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let core = ScalarCore::<NistP256>::deserialize(deserializer)?;
        Ok(Self::from(core))
    }
}
812
/// Convert to a [`U64x4`] array.
///
/// On 32-bit targets the `U256` exposes eight 32-bit words; each consecutive
/// pair is merged into one 64-bit limb (lower-indexed word in the low half).
// TODO(tarcieri): implement all algorithms in terms of `U256`?
#[cfg(target_pointer_width = "32")]
pub(crate) const fn u256_to_u64x4(u256: U256) -> U64x4 {
    /// Joins a low and a high 32-bit word into one 64-bit limb.
    const fn join(lo: u32, hi: u32) -> u64 {
        (lo as u64) | ((hi as u64) << 32)
    }

    let [w0, w1, w2, w3, w4, w5, w6, w7] = u256.to_uint_array();

    [join(w0, w1), join(w2, w3), join(w4, w5), join(w6, w7)]
}
826
827/// Convert to a [`U64x4`] array.
828// TODO(tarcieri): implement all algorithms in terms of `U256`?
829#[cfg(target_pointer_width = "64")]
830pub(crate) const fn u256_to_u64x4(u256: U256) -> U64x4 {
831    u256.to_uint_array()
832}
833
#[cfg(test)]
mod tests {
    use super::Scalar;
    use crate::{FieldBytes, SecretKey};
    use elliptic_curve::group::ff::{Field, PrimeField};

    /// Parsing a big-endian encoding and re-encoding it yields the same bytes.
    #[test]
    fn from_to_bytes_roundtrip() {
        let k: u64 = 42;
        let mut bytes = FieldBytes::default();
        // Place `k` in the last eight bytes, i.e. the least-significant end.
        bytes[24..].copy_from_slice(k.to_be_bytes().as_ref());

        let scalar = Scalar::from_repr(bytes).unwrap();
        assert_eq!(bytes, scalar.to_bytes());
    }

    /// Basic tests that multiplication works.
    #[test]
    fn multiply() {
        let one = Scalar::one();
        let two = one + &one;
        let three = two + &one;
        let six = three + &three;
        assert_eq!(six, two * &three);

        let minus_two = -two;
        let minus_three = -three;
        assert_eq!(two, -minus_two);

        // Multiplication commutes, and (-2) * (-3) = 6.
        assert_eq!(minus_three * &minus_two, minus_two * &minus_three);
        assert_eq!(six, minus_two * &minus_three);
    }

    /// Basic tests that scalar inversion works.
    #[test]
    fn invert() {
        let one = Scalar::one();
        let three = one + &one + &one;
        let inv_three = three.invert().unwrap();
        assert_eq!(three * &inv_three, one);

        let minus_three = -three;
        let inv_minus_three = minus_three.invert().unwrap();
        // 1 / (-3) = -(1 / 3), hence 3 * (1 / (-3)) = -1.
        assert_eq!(inv_minus_three, -inv_three);
        assert_eq!(three * &inv_minus_three, -one);
    }

    /// Basic tests that sqrt works.
    #[test]
    fn sqrt() {
        // Perfect squares, so a root is guaranteed to exist.
        for &n in &[1u64, 4, 9, 16, 25, 36, 49, 64] {
            let scalar = Scalar::from(n);
            let sqrt = scalar.sqrt().unwrap();
            assert_eq!(sqrt.square(), scalar);
        }
    }

    /// Tests that a Scalar can be safely converted to a SecretKey and back
    #[test]
    fn from_ec_secret() {
        let scalar = Scalar::one();
        let secret = SecretKey::from_be_bytes(&scalar.to_bytes()).unwrap();
        let rederived_scalar = Scalar::from(&secret);
        assert_eq!(scalar.0, rederived_scalar.0);
    }

    /// Checks the bit representation of `-1` (i.e. `n - 1`) on 32-bit targets.
    #[test]
    #[cfg(all(feature = "bits", target_pointer_width = "32"))]
    fn scalar_into_scalarbits() {
        use crate::ScalarBits;

        // n - 1 as 32-bit words, least-significant word first.
        let minus_one = ScalarBits::from([
            0xfc63_2550,
            0xf3b9_cac2,
            0xa717_9e84,
            0xbce6_faad,
            0xffff_ffff,
            0xffff_ffff,
            0x0000_0000,
            0xffff_ffff,
        ]);

        let scalar_bits = ScalarBits::from(&-Scalar::from(1));
        assert_eq!(minus_one, scalar_bits);
    }
}