ring/arithmetic/bigint/
modulus.rs

1// Copyright 2015-2023 Brian Smith.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15use super::{BoxedLimbs, Elem, PublicModulus, Unencoded, N0};
16use crate::{
17    bits::BitLength,
18    cpu, error,
19    limb::{self, Limb, LimbMask, LIMB_BITS},
20    polyfill::LeadingZerosStripped,
21};
22use core::marker::PhantomData;
23
/// The minimum number of limbs a modulus must have;
/// `OwnedModulus::from_be_bytes` rejects anything shorter.
///
/// The x86 implementation of `bn_mul_mont`, at least, requires at least 4
/// limbs. For a long time we have required 4 limbs for all targets, though
/// this may be unnecessary. TODO: Replace this with
/// `n.len() < 256 / LIMB_BITS` so that 32-bit and 64-bit platforms behave the
/// same.
pub const MODULUS_MIN_LIMBS: usize = 4;

/// The maximum number of limbs a modulus may have;
/// `OwnedModulus::from_be_bytes` rejects anything longer. The value is taken
/// from `BIGINT_MODULUS_MAX_LIMBS` two modules up.
pub const MODULUS_MAX_LIMBS: usize = super::super::BIGINT_MODULUS_MAX_LIMBS;
32
/// The modulus *m* for a ring ℤ/mℤ, along with the precomputed values needed
/// for efficient Montgomery multiplication modulo *m*. The value must be odd
/// and larger than 2. The larger-than-1 requirement is imposed, at least, by
/// the modular inversion code.
///
/// Both properties (odd, and not less than 3) are checked by
/// `from_be_bytes`, which is the only constructor visible in this file.
pub struct OwnedModulus<M> {
    // The limbs of the modulus; index 0 is the lowest limb (see the
    // discussion of `n0` below, which reads the lowest limbs for `n % r`).
    limbs: BoxedLimbs<M>, // Also `value >= 3`.

    // n0 * N == -1 (mod r).
    //
    // r == 2**(N0::LIMBS_USED * LIMB_BITS) and LG_LITTLE_R == lg(r). This
    // ensures that we can do integer division by |r| by simply ignoring
    // `N0::LIMBS_USED` limbs. Similarly, we can calculate values modulo `r` by
    // just looking at the lowest `N0::LIMBS_USED` limbs. This is what makes
    // Montgomery multiplication efficient.
    //
    // As shown in Algorithm 1 of "Fast Prime Field Elliptic Curve Cryptography
    // with 256 Bit Primes" by Shay Gueron and Vlad Krasnov, in the loop of a
    // multi-limb Montgomery multiplication of a * b (mod n), given the
    // unreduced product t == a * b, we repeatedly calculate:
    //
    //    t1 := t % r         |t1| is |t|'s lowest limb (see previous paragraph).
    //    t2 := t1*n0*n
    //    t3 := t + t2
    //    t := t3 / r         copy all limbs of |t3| except the lowest to |t|.
    //
    // In the last step, it would only make sense to ignore the lowest limb of
    // |t3| if it were zero. The middle steps ensure that this is the case:
    //
    //                            t3 ==  0 (mod r)
    //                        t + t2 ==  0 (mod r)
    //                   t + t1*n0*n ==  0 (mod r)
    //                       t1*n0*n == -t (mod r)
    //                        t*n0*n == -t (mod r)
    //                          n0*n == -1 (mod r)
    //                            n0 == -1/n (mod r)
    //
    // Thus, in each iteration of the loop, we multiply by the constant factor
    // n0, the negative inverse of n (mod r).
    //
    // TODO(perf): Not all 32-bit platforms actually make use of n0[1]. For the
    // ones that don't, we could use a shorter `R` value and use faster `Limb`
    // calculations instead of double-precision `u64` calculations.
    n0: N0,

    // The bit length of the modulus, as computed by
    // `limb::limbs_minimal_bits` when the modulus is constructed.
    len_bits: BitLength,
}
79
80impl<M: PublicModulus> Clone for OwnedModulus<M> {
81    fn clone(&self) -> Self {
82        Self {
83            limbs: self.limbs.clone(),
84            n0: self.n0,
85            len_bits: self.len_bits,
86        }
87    }
88}
89
90impl<M> OwnedModulus<M> {
91    pub(crate) fn from_be_bytes(input: untrusted::Input) -> Result<Self, error::KeyRejected> {
92        let n = BoxedLimbs::positive_minimal_width_from_be_bytes(input)?;
93        if n.len() > MODULUS_MAX_LIMBS {
94            return Err(error::KeyRejected::too_large());
95        }
96        if n.len() < MODULUS_MIN_LIMBS {
97            return Err(error::KeyRejected::unexpected_error());
98        }
99        if limb::limbs_are_even_constant_time(&n) != LimbMask::False {
100            return Err(error::KeyRejected::invalid_component());
101        }
102        if limb::limbs_less_than_limb_constant_time(&n, 3) != LimbMask::False {
103            return Err(error::KeyRejected::unexpected_error());
104        }
105
106        // n_mod_r = n % r. As explained in the documentation for `n0`, this is
107        // done by taking the lowest `N0::LIMBS_USED` limbs of `n`.
108        #[allow(clippy::useless_conversion)]
109        let n0 = {
110            prefixed_extern! {
111                fn bn_neg_inv_mod_r_u64(n: u64) -> u64;
112            }
113
114            // XXX: u64::from isn't guaranteed to be constant time.
115            let mut n_mod_r: u64 = u64::from(n[0]);
116
117            if N0::LIMBS_USED == 2 {
118                // XXX: If we use `<< LIMB_BITS` here then 64-bit builds
119                // fail to compile because of `deny(exceeding_bitshifts)`.
120                debug_assert_eq!(LIMB_BITS, 32);
121                n_mod_r |= u64::from(n[1]) << 32;
122            }
123            N0::precalculated(unsafe { bn_neg_inv_mod_r_u64(n_mod_r) })
124        };
125
126        let len_bits = limb::limbs_minimal_bits(&n);
127
128        Ok(Self {
129            limbs: n,
130            n0,
131            len_bits,
132        })
133    }
134
135    pub fn verify_less_than<L>(&self, l: &Modulus<L>) -> Result<(), error::Unspecified> {
136        if self.len_bits() > l.len_bits()
137            || (self.limbs.len() == l.limbs().len()
138                && limb::limbs_less_than_limbs_consttime(&self.limbs, l.limbs()) != LimbMask::True)
139        {
140            return Err(error::Unspecified);
141        }
142        Ok(())
143    }
144
145    pub fn to_elem<L>(&self, l: &Modulus<L>) -> Result<Elem<L, Unencoded>, error::Unspecified> {
146        self.verify_less_than(l)?;
147        let mut limbs = BoxedLimbs::zero(l.limbs.len());
148        limbs[..self.limbs.len()].copy_from_slice(&self.limbs);
149        Ok(Elem {
150            limbs,
151            encoding: PhantomData,
152        })
153    }
154    pub(crate) fn modulus(&self, cpu_features: cpu::Features) -> Modulus<M> {
155        Modulus {
156            limbs: &self.limbs,
157            n0: self.n0,
158            len_bits: self.len_bits,
159            m: PhantomData,
160            cpu_features,
161        }
162    }
163
164    pub fn len_bits(&self) -> BitLength {
165        self.len_bits
166    }
167}
168
169impl<M: PublicModulus> OwnedModulus<M> {
170    pub fn be_bytes(&self) -> LeadingZerosStripped<impl ExactSizeIterator<Item = u8> + Clone + '_> {
171        LeadingZerosStripped::new(limb::unstripped_be_bytes(&self.limbs))
172    }
173}
174
/// A borrowed view of a modulus together with the CPU features to use with
/// it. `OwnedModulus::modulus` builds one by borrowing the owned limbs and
/// copying `n0` and the bit length.
pub struct Modulus<'a, M> {
    // The limbs of the modulus, borrowed from the `OwnedModulus`.
    limbs: &'a [Limb],
    // Copy of the owner's `n0` (see the `n0` notes on `OwnedModulus`).
    n0: N0,
    // Copy of the owner's bit length.
    len_bits: BitLength,
    // Ties the view to the modulus marker type `M` without storing an `M`.
    m: PhantomData<M>,
    // The CPU features supplied by the caller of `OwnedModulus::modulus`.
    cpu_features: cpu::Features,
}
182
183impl<M> Modulus<'_, M> {
184    pub(super) fn oneR(&self, out: &mut [Limb]) {
185        assert_eq!(self.limbs.len(), out.len());
186
187        let r = self.limbs.len() * LIMB_BITS;
188
189        // out = 2**r - m where m = self.
190        limb::limbs_negative_odd(out, self.limbs);
191
192        let lg_m = self.len_bits().as_bits();
193        let leading_zero_bits_in_m = r - lg_m;
194
195        // When m's length is a multiple of LIMB_BITS, which is the case we
196        // most want to optimize for, then we already have
197        // out == 2**r - m == 2**r (mod m).
198        if leading_zero_bits_in_m != 0 {
199            debug_assert!(leading_zero_bits_in_m < LIMB_BITS);
200            // Correct out to 2**(lg m) (mod m). `limbs_negative_odd` flipped
201            // all the leading zero bits to ones. Flip them back.
202            *out.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m;
203
204            // Now we have out == 2**(lg m) (mod m). Keep doubling until we get
205            // to 2**r (mod m).
206            for _ in 0..leading_zero_bits_in_m {
207                limb::limbs_double_mod(out, self.limbs)
208            }
209        }
210
211        // Now out == 2**r (mod m) == 1*R.
212    }
213
214    // TODO: XXX Avoid duplication with `Modulus`.
215    pub(super) fn zero<E>(&self) -> Elem<M, E> {
216        Elem {
217            limbs: BoxedLimbs::zero(self.limbs.len()),
218            encoding: PhantomData,
219        }
220    }
221
222    #[inline]
223    pub(super) fn limbs(&self) -> &[Limb] {
224        self.limbs
225    }
226
227    #[inline]
228    pub(super) fn n0(&self) -> &N0 {
229        &self.n0
230    }
231
232    pub fn len_bits(&self) -> BitLength {
233        self.len_bits
234    }
235
236    #[inline]
237    pub(crate) fn cpu_features(&self) -> cpu::Features {
238        self.cpu_features
239    }
240}