// num_bigint/biguint/power.rs

1use super::monty::monty_modpow;
2use super::BigUint;
3
4use crate::big_digit::{self, BigDigit};
5
6use num_integer::Integer;
7use num_traits::{One, Pow, ToPrimitive, Zero};
8
9impl<'b> Pow<&'b BigUint> for BigUint {
10    type Output = BigUint;
11
12    #[inline]
13    fn pow(self, exp: &BigUint) -> BigUint {
14        if self.is_one() || exp.is_zero() {
15            BigUint::one()
16        } else if self.is_zero() {
17            BigUint::zero()
18        } else if let Some(exp) = exp.to_u64() {
19            self.pow(exp)
20        } else if let Some(exp) = exp.to_u128() {
21            self.pow(exp)
22        } else {
23            // At this point, `self >= 2` and `exp >= 2¹²⁸`. The smallest possible result given
24            // `2.pow(2¹²⁸)` would require far more memory than 64-bit targets can address!
25            panic!("memory overflow")
26        }
27    }
28}
29
30impl Pow<BigUint> for BigUint {
31    type Output = BigUint;
32
33    #[inline]
34    fn pow(self, exp: BigUint) -> BigUint {
35        Pow::pow(self, &exp)
36    }
37}
38
39impl<'a, 'b> Pow<&'b BigUint> for &'a BigUint {
40    type Output = BigUint;
41
42    #[inline]
43    fn pow(self, exp: &BigUint) -> BigUint {
44        if self.is_one() || exp.is_zero() {
45            BigUint::one()
46        } else if self.is_zero() {
47            BigUint::zero()
48        } else {
49            self.clone().pow(exp)
50        }
51    }
52}
53
54impl<'a> Pow<BigUint> for &'a BigUint {
55    type Output = BigUint;
56
57    #[inline]
58    fn pow(self, exp: BigUint) -> BigUint {
59        Pow::pow(self, &exp)
60    }
61}
62
// Generates `Pow<$T>` and `Pow<&$T>` for both `BigUint` and `&BigUint`, where `$T` is a
// primitive unsigned integer exponent type.
macro_rules! pow_impl {
    ($T:ty) => {
        impl Pow<$T> for BigUint {
            type Output = BigUint;

            // Binary (square-and-multiply) exponentiation, scanning exponent bits from the
            // least significant upward.
            fn pow(self, exp: $T) -> BigUint {
                if exp == 0 {
                    return BigUint::one();
                }

                // Trailing zero bits only square the base; nothing is multiplied into the
                // result until the first set bit is reached.
                let mut square = self;
                let mut bits = exp;
                while bits & 1 == 0 {
                    square = &square * &square;
                    bits >>= 1;
                }

                // Consume the lowest set bit: its contribution is `square` itself.
                bits >>= 1;
                if bits == 0 {
                    // The exponent was a power of two, so there are no further factors.
                    return square;
                }

                let mut result = square.clone();
                while bits != 0 {
                    square = &square * &square;
                    if bits & 1 == 1 {
                        result *= &square;
                    }
                    bits >>= 1;
                }
                result
            }
        }

        impl<'b> Pow<&'b $T> for BigUint {
            type Output = BigUint;

            #[inline]
            fn pow(self, exp: &$T) -> BigUint {
                // Primitive exponents are `Copy`: read the value and reuse the by-value impl.
                let exp: $T = *exp;
                Pow::pow(self, exp)
            }
        }

        impl<'a> Pow<$T> for &'a BigUint {
            type Output = BigUint;

            #[inline]
            fn pow(self, exp: $T) -> BigUint {
                // `x^0 == 1` is answered without cloning the base.
                if exp != 0 {
                    Pow::pow(self.clone(), exp)
                } else {
                    BigUint::one()
                }
            }
        }

        impl<'a, 'b> Pow<&'b $T> for &'a BigUint {
            type Output = BigUint;

            #[inline]
            fn pow(self, exp: &$T) -> BigUint {
                let exp: $T = *exp;
                Pow::pow(self, exp)
            }
        }
    };
}
126
// Provide owned and borrowed `Pow` impls (via `pow_impl!`) for every primitive unsigned
// exponent type.
pow_impl!(u8);
pow_impl!(u16);
pow_impl!(u32);
pow_impl!(u64);
pow_impl!(usize);
pow_impl!(u128);
133
134pub(super) fn modpow(x: &BigUint, exponent: &BigUint, modulus: &BigUint) -> BigUint {
135    assert!(
136        !modulus.is_zero(),
137        "attempt to calculate with zero modulus!"
138    );
139
140    if modulus.is_odd() {
141        // For an odd modulus, we can use Montgomery multiplication in base 2^32.
142        monty_modpow(x, exponent, modulus)
143    } else {
144        // Otherwise do basically the same as `num::pow`, but with a modulus.
145        plain_modpow(x, &exponent.data, modulus)
146    }
147}
148
149fn plain_modpow(base: &BigUint, exp_data: &[BigDigit], modulus: &BigUint) -> BigUint {
150    assert!(
151        !modulus.is_zero(),
152        "attempt to calculate with zero modulus!"
153    );
154
155    let i = match exp_data.iter().position(|&r| r != 0) {
156        None => return BigUint::one(),
157        Some(i) => i,
158    };
159
160    let mut base = base % modulus;
161    for _ in 0..i {
162        for _ in 0..big_digit::BITS {
163            base = &base * &base % modulus;
164        }
165    }
166
167    let mut r = exp_data[i];
168    let mut b = 0u8;
169    while r.is_even() {
170        base = &base * &base % modulus;
171        r >>= 1;
172        b += 1;
173    }
174
175    let mut exp_iter = exp_data[i + 1..].iter();
176    if exp_iter.len() == 0 && r.is_one() {
177        return base;
178    }
179
180    let mut acc = base.clone();
181    r >>= 1;
182    b += 1;
183
184    {
185        let mut unit = |exp_is_odd| {
186            base = &base * &base % modulus;
187            if exp_is_odd {
188                acc *= &base;
189                acc %= modulus;
190            }
191        };
192
193        if let Some(&last) = exp_iter.next_back() {
194            // consume exp_data[i]
195            for _ in b..big_digit::BITS {
196                unit(r.is_odd());
197                r >>= 1;
198            }
199
200            // consume all other digits before the last
201            for &r in exp_iter {
202                let mut r = r;
203                for _ in 0..big_digit::BITS {
204                    unit(r.is_odd());
205                    r >>= 1;
206                }
207            }
208            r = last;
209        }
210
211        debug_assert_ne!(r, 0);
212        while !r.is_zero() {
213            unit(r.is_odd());
214            r >>= 1;
215        }
216    }
217    acc
218}
219
#[test]
fn test_plain_modpow() {
    let two = &BigUint::from(2u32);
    let modulus = BigUint::from(0x1100u32);

    // `exp` digits are little-endian and each spans `big_digit::BITS` bits. For example,
    // `[0, 0b1]` is `2^BITS`, not the `0b1_00000000` (= 2^8) passed to `pow` below.
    // The comparisons still hold because `0x1100 = 2^8 * 17` and 2 has multiplicative
    // order 8 modulo 17. So for every k >= 8, `2^k % 0x1100` depends only on `k % 8`,
    // and each pair of exponents below agrees modulo 8 (for any digit width >= 3 bits).
    let exp = [0, 0b1];
    assert_eq!(
        two.pow(0b1_00000000_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = [0, 0b10];
    assert_eq!(
        two.pow(0b10_00000000_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = [0, 0b110010];
    assert_eq!(
        two.pow(0b110010_00000000_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = [0b1, 0b1];
    assert_eq!(
        two.pow(0b1_00000001_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );
    let exp = [0b1100, 0, 0b1];
    assert_eq!(
        two.pow(0b1_00000000_00001100_u32) % &modulus,
        plain_modpow(two, &exp, &modulus)
    );

    // An empty or all-zero exponent is `2^0 = 1`, already reduced for this modulus.
    assert_eq!(BigUint::one(), plain_modpow(two, &[], &modulus));
    assert_eq!(BigUint::one(), plain_modpow(two, &[0, 0], &modulus));
}
251
#[test]
fn test_pow_biguint() {
    // 5³ = 125, computed through the owned `Pow<BigUint> for BigUint` impl.
    let expected = BigUint::from(125u8);
    let actual = BigUint::from(5u8).pow(BigUint::from(3u8));
    assert_eq!(expected, actual);
}