// num_bigint/biguint/multiplication.rs

1use super::addition::{__add2, add2};
2use super::subtraction::sub2;
3#[cfg(not(u64_digit))]
4use super::u32_from_u128;
5use super::{biguint_from_vec, cmp_slice, BigUint, IntDigits};
6
7use crate::big_digit::{self, BigDigit, DoubleBigDigit};
8use crate::Sign::{self, Minus, NoSign, Plus};
9use crate::{BigInt, UsizePromotion};
10
11use core::cmp::Ordering;
12use core::iter::Product;
13use core::ops::{Mul, MulAssign};
14use num_traits::{CheckedMul, FromPrimitive, One, Zero};
15
16#[inline]
17pub(super) fn mac_with_carry(
18    a: BigDigit,
19    b: BigDigit,
20    c: BigDigit,
21    acc: &mut DoubleBigDigit,
22) -> BigDigit {
23    *acc += DoubleBigDigit::from(a);
24    *acc += DoubleBigDigit::from(b) * DoubleBigDigit::from(c);
25    let lo = *acc as BigDigit;
26    *acc >>= big_digit::BITS;
27    lo
28}
29
30#[inline]
31fn mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit {
32    *acc += DoubleBigDigit::from(a) * DoubleBigDigit::from(b);
33    let lo = *acc as BigDigit;
34    *acc >>= big_digit::BITS;
35    lo
36}
37
38/// Three argument multiply accumulate:
39/// acc += b * c
40fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
41    if c == 0 {
42        return;
43    }
44
45    let mut carry = 0;
46    let (a_lo, a_hi) = acc.split_at_mut(b.len());
47
48    for (a, &b) in a_lo.iter_mut().zip(b) {
49        *a = mac_with_carry(*a, b, c, &mut carry);
50    }
51
52    let (carry_hi, carry_lo) = big_digit::from_doublebigdigit(carry);
53
54    let final_carry = if carry_hi == 0 {
55        __add2(a_hi, &[carry_lo])
56    } else {
57        __add2(a_hi, &[carry_hi, carry_lo])
58    };
59    assert_eq!(final_carry, 0, "carry overflow during multiplication!");
60}
61
62fn bigint_from_slice(slice: &[BigDigit]) -> BigInt {
63    BigInt::from(biguint_from_vec(slice.to_vec()))
64}
65
/// Three argument multiply accumulate:
/// acc += b * c
///
/// `b` and `c` are little-endian digit slices. `acc` must be long enough to receive the
/// product plus any carry out of it. The long-multiplication path's `mac_digit` asserts
/// that no carry overflows.
///
/// The algorithm is picked by the length of the shorter operand: long multiplication up to
/// 32 digits, Karatsuba up to 256 digits, and Toom-3 beyond that.
#[allow(clippy::many_single_char_names)]
fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
    // Least-significant zeros have no effect on the output.
    // Skipping `nz` low zero digits of an operand shifts the product up by `nz` digits,
    // so `acc` is advanced by the same amount. An all-zero operand means the product is
    // zero and there is nothing to add.
    if let Some(&0) = b.first() {
        if let Some(nz) = b.iter().position(|&d| d != 0) {
            b = &b[nz..];
            acc = &mut acc[nz..];
        } else {
            return;
        }
    }
    if let Some(&0) = c.first() {
        if let Some(nz) = c.iter().position(|&d| d != 0) {
            c = &c[nz..];
            acc = &mut acc[nz..];
        } else {
            return;
        }
    }

    // Drop the `mut` binding: `acc` is not re-sliced from here on.
    let acc = acc;
    // `x` is the shorter operand, `y` the longer (or equal) one.
    let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };

    // We use three algorithms for different input sizes.
    //
    // - For small inputs, long multiplication is fastest.
    // - Next we use Karatsuba multiplication (Toom-2), which we have optimized
    //   to avoid unnecessary allocations for intermediate values.
    // - For the largest inputs we use Toom-3, which better optimizes the
    //   number of operations, but uses more temporary allocations.
    //
    // The thresholds are somewhat arbitrary, chosen by evaluating the results
    // of `cargo bench --bench bigint multiply`.

    if x.len() <= 32 {
        // Long multiplication: add y * x[i], shifted up by i digits, for each digit of x.
        for (i, xi) in x.iter().enumerate() {
            mac_digit(&mut acc[i..], y, *xi);
        }
    } else if x.len() <= 256 {
        // Karatsuba multiplication:
        //
        // The idea is that we break x and y up into two smaller numbers that each have about half
        // as many digits, like so (note that multiplying by b is just a shift):
        //
        // x = x0 + x1 * b
        // y = y0 + y1 * b
        //
        // With some algebra, we can compute x * y with three smaller products, where the inputs to
        // each of the smaller products have only about half as many digits as x and y:
        //
        // x * y = (x0 + x1 * b) * (y0 + y1 * b)
        //
        // x * y = x0 * y0
        //       + x0 * y1 * b
        //       + x1 * y0 * b
        //       + x1 * y1 * b^2
        //
        // Let p0 = x0 * y0 and p2 = x1 * y1:
        //
        // x * y = p0
        //       + (x0 * y1 + x1 * y0) * b
        //       + p2 * b^2
        //
        // The real trick is that middle term:
        //
        //         x0 * y1 + x1 * y0
        //
        //       = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
        //
        //       = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
        //
        // Now we factor the first four terms into a single product:
        //
        //       = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
        //
        //       = -((x1 - x0) * (y1 - y0)) + p0 + p2
        //
        // Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
        //
        // x * y = p0
        //       + (p0 + p2 - p1) * b
        //       + p2 * b^2
        //
        // Where the three intermediate products are:
        //
        // p0 = x0 * y0
        // p1 = (x1 - x0) * (y1 - y0)
        // p2 = x1 * y1
        //
        // In doing the computation, we take great care to avoid unnecessary temporary variables
        // (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
        // bit so we can use the same temporary variable for all the intermediate products:
        //
        // x * y = p2 * b^2 + p2 * b
        //       + p0 * b + p0
        //       - p1 * b
        //
        // The other trick we use is instead of doing explicit shifts, we slice acc at the
        // appropriate offset when doing the add.

        // When x is smaller than y, it's significantly faster to pick b such that x is split in
        // half, not y:
        let b = x.len() / 2;
        let (x0, x1) = x.split_at(b);
        let (y0, y1) = y.split_at(b);

        // We reuse the same BigUint for all the intermediate multiplies and have to size p
        // appropriately here: x1.len() >= x0.len() and y1.len() >= y0.len():
        let len = x1.len() + y1.len() + 1;
        let mut p = BigUint { data: vec![0; len] };

        // p2 = x1 * y1
        mac3(&mut p.data, x1, y1);

        // Not required, but the adds go faster if we drop any unneeded 0s from the end:
        p.normalize();

        // acc += p2 * b + p2 * b^2
        add2(&mut acc[b..], &p.data);
        add2(&mut acc[b * 2..], &p.data);

        // Zero out p before the next multiply:
        p.data.truncate(0);
        p.data.resize(len, 0);

        // p0 = x0 * y0
        mac3(&mut p.data, x0, y0);
        p.normalize();

        // acc += p0 + p0 * b
        add2(acc, &p.data);
        add2(&mut acc[b..], &p.data);

        // p1 = (x1 - x0) * (y1 - y0)
        // We do this one last, since it may be negative and acc can't ever be negative:
        let (j0_sign, j0) = sub_sign(x1, x0);
        let (j1_sign, j1) = sub_sign(y1, y0);

        match j0_sign * j1_sign {
            Plus => {
                // p1 is positive: compute it in the zeroed scratch buffer, then subtract p1 * b.
                p.data.truncate(0);
                p.data.resize(len, 0);

                mac3(&mut p.data, &j0.data, &j1.data);
                p.normalize();

                sub2(&mut acc[b..], &p.data);
            }
            Minus => {
                // p1 is negative, so subtracting p1 * b means adding |j0| * |j1| * b,
                // which can be accumulated straight into acc.
                mac3(&mut acc[b..], &j0.data, &j1.data);
            }
            // p1 is zero: nothing to subtract.
            NoSign => (),
        }
    } else {
        // Toom-3 multiplication:
        //
        // Toom-3 is like Karatsuba above, but dividing the inputs into three parts.
        // Both are instances of Toom-Cook, using `k=3` and `k=2` respectively.
        //
        // The general idea is to treat the large integers' digits as
        // polynomials of a certain degree and determine the coefficients/digits
        // of the product of the two via interpolation of the polynomial product.
        //
        // Each part has at most `i` digits. Since 3 * i > y.len() >= x.len(), three parts
        // cover all of y (and of x); the top parts may be shorter or, for x, empty.
        let i = y.len() / 3 + 1;

        let x0_len = Ord::min(x.len(), i);
        let x1_len = Ord::min(x.len() - x0_len, i);

        let y0_len = i;
        let y1_len = Ord::min(y.len() - y0_len, i);

        // Break x and y into three parts, representing an order two polynomial.
        // t is chosen to be a power of the digit base (t = base^i, i.e. `i` digits), so
        // multiplying by t is a shift by `i` digits rather than a real multiplication.
        //
        // x(t) = x2*t^2 + x1*t + x0
        let x0 = bigint_from_slice(&x[..x0_len]);
        let x1 = bigint_from_slice(&x[x0_len..x0_len + x1_len]);
        let x2 = bigint_from_slice(&x[x0_len + x1_len..]);

        // y(t) = y2*t^2 + y1*t + y0
        let y0 = bigint_from_slice(&y[..y0_len]);
        let y1 = bigint_from_slice(&y[y0_len..y0_len + y1_len]);
        let y2 = bigint_from_slice(&y[y0_len + y1_len..]);

        // Let w(t) = x(t) * y(t)
        //
        // This gives us the following order-4 polynomial.
        //
        // w(t) = w4*t^4 + w3*t^3 + w2*t^2 + w1*t + w0
        //
        // We need to find the coefficients w4, w3, w2, w1 and w0. Instead
        // of simply multiplying the x and y in total, we can evaluate w
        // at 5 points. An n-degree polynomial is uniquely identified by (n + 1)
        // points.
        //
        // It is arbitrary as to what points we evaluate w at but we use the
        // following.
        //
        // w(t) at t = 0, 1, -1, -2 and inf
        //
        // The values for w(t) in terms of x(t)*y(t) at these points are:
        //
        // let a = w(0)   = x0 * y0
        // let b = w(1)   = (x2 + x1 + x0) * (y2 + y1 + y0)
        // let c = w(-1)  = (x2 - x1 + x0) * (y2 - y1 + y0)
        // let d = w(-2)  = (4*x2 - 2*x1 + x0) * (4*y2 - 2*y1 + y0)
        // let e = w(inf) = x2 * y2 as t -> inf

        // x0 + x2, avoiding temporaries
        let p = &x0 + &x2;

        // y0 + y2, avoiding temporaries
        let q = &y0 + &y2;

        // x2 - x1 + x0, avoiding temporaries
        let p2 = &p - &x1;

        // y2 - y1 + y0, avoiding temporaries
        let q2 = &q - &y1;

        // w(0)
        let r0 = &x0 * &y0;

        // w(inf)
        let r4 = &x2 * &y2;

        // w(1)
        let r1 = (p + x1) * (q + y1);

        // w(-1)
        let r2 = &p2 * &q2;

        // w(-2): (p2 + x2) * 2 - x0 == 4*x2 - 2*x1 + x0, and likewise for y
        let r3 = ((p2 + x2) * 2 - x0) * ((q2 + y2) * 2 - y0);

        // Evaluating these points gives us the following system of linear equations.
        // The columns are the coefficients w4, w3, w2, w1, w0 (left to right):
        //
        //  0  0  0  0  1 | a
        //  1  1  1  1  1 | b
        //  1 -1  1 -1  1 | c
        // 16 -8  4 -2  1 | d
        //  1  0  0  0  0 | e
        //
        // The solution (after gaussian elimination or similar)
        // in terms of its coefficients:
        //
        // w0 = w(0)
        // w1 = w(0)/2 + w(1)/3 - w(-1) + w(-2)/6 - 2*w(inf)
        // w2 = -w(0) + w(1)/2 + w(-1)/2 - w(inf)
        // w3 = -w(0)/2 + w(1)/6 + w(-1)/2 - w(-2)/6 + 2*w(inf)
        // w4 = w(inf)
        //
        // This particular sequence is given by Bodrato and is an interpolation
        // of the above equations. Every division and shift in it is exact, and
        // afterwards comp1 = w1, comp2 = w2 and comp3 = w3.
        let mut comp3: BigInt = (r3 - &r1) / 3u32;
        let mut comp1: BigInt = (r1 - &r2) >> 1;
        let mut comp2: BigInt = r2 - &r0;
        comp3 = ((&comp2 - comp3) >> 1) + (&r4 << 1);
        comp2 += &comp1 - &r4;
        comp1 -= &comp3;

        // Recomposition. The coefficients of the polynomial are now known.
        //
        // Evaluate at w(t) where t is our given base to get the result.
        //
        //     let bits = u64::from(big_digit::BITS) * i as u64;
        //     let result = r0
        //         + (comp1 << bits)
        //         + (comp2 << (2 * bits))
        //         + (comp3 << (3 * bits))
        //         + (r4 << (4 * bits));
        //     let result_pos = result.to_biguint().unwrap();
        //     add2(&mut acc[..], &result_pos.data);
        //
        // But with less intermediate copying: coefficient j is added to (or, when
        // negative, subtracted from) acc shifted by i * j digits.
        for (j, result) in [&r0, &comp1, &comp2, &comp3, &r4].iter().enumerate().rev() {
            match result.sign() {
                Plus => add2(&mut acc[i * j..], result.digits()),
                Minus => sub2(&mut acc[i * j..], result.digits()),
                NoSign => {}
            }
        }
    }
}
351
352fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
353    let len = x.len() + y.len() + 1;
354    let mut prod = BigUint { data: vec![0; len] };
355
356    mac3(&mut prod.data, x, y);
357    prod.normalized()
358}
359
360fn scalar_mul(a: &mut BigUint, b: BigDigit) {
361    match b {
362        0 => a.set_zero(),
363        1 => {}
364        _ => {
365            if b.is_power_of_two() {
366                *a <<= b.trailing_zeros();
367            } else {
368                let mut carry = 0;
369                for a in a.data.iter_mut() {
370                    *a = mul_with_carry(*a, b, &mut carry);
371                }
372                if carry != 0 {
373                    a.data.push(carry as BigDigit);
374                }
375            }
376        }
377    }
378}
379
380fn sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint) {
381    // Normalize:
382    if let Some(&0) = a.last() {
383        a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
384    }
385    if let Some(&0) = b.last() {
386        b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
387    }
388
389    match cmp_slice(a, b) {
390        Ordering::Greater => {
391            let mut a = a.to_vec();
392            sub2(&mut a, b);
393            (Plus, biguint_from_vec(a))
394        }
395        Ordering::Less => {
396            let mut b = b.to_vec();
397            sub2(&mut b, a);
398            (Minus, biguint_from_vec(b))
399        }
400        Ordering::Equal => (NoSign, Zero::zero()),
401    }
402}
403
404macro_rules! impl_mul {
405    ($(impl<$($a:lifetime),*> Mul<$Other:ty> for $Self:ty;)*) => {$(
406        impl<$($a),*> Mul<$Other> for $Self {
407            type Output = BigUint;
408
409            #[inline]
410            fn mul(self, other: $Other) -> BigUint {
411                match (&*self.data, &*other.data) {
412                    // multiply by zero
413                    (&[], _) | (_, &[]) => BigUint::zero(),
414                    // multiply by a scalar
415                    (_, &[digit]) => self * digit,
416                    (&[digit], _) => other * digit,
417                    // full multiplication
418                    (x, y) => mul3(x, y),
419                }
420            }
421        }
422    )*}
423}
424impl_mul! {
425    impl<> Mul<BigUint> for BigUint;
426    impl<'b> Mul<&'b BigUint> for BigUint;
427    impl<'a> Mul<BigUint> for &'a BigUint;
428    impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint;
429}
430
431macro_rules! impl_mul_assign {
432    ($(impl<$($a:lifetime),*> MulAssign<$Other:ty> for BigUint;)*) => {$(
433        impl<$($a),*> MulAssign<$Other> for BigUint {
434            #[inline]
435            fn mul_assign(&mut self, other: $Other) {
436                match (&*self.data, &*other.data) {
437                    // multiply by zero
438                    (&[], _) => {},
439                    (_, &[]) => self.set_zero(),
440                    // multiply by a scalar
441                    (_, &[digit]) => *self *= digit,
442                    (&[digit], _) => *self = other * digit,
443                    // full multiplication
444                    (x, y) => *self = mul3(x, y),
445                }
446            }
447        }
448    )*}
449}
450impl_mul_assign! {
451    impl<> MulAssign<BigUint> for BigUint;
452    impl<'a> MulAssign<&'a BigUint> for BigUint;
453}
454
// NOTE(review): these macros are defined elsewhere in the crate. Judging by their names,
// the `promote_*` pair presumably forwards multiplication by the remaining unsigned
// primitive types to the u32/u64 impls below (cf. the `UsizePromotion` import). The
// `forward_all_scalar_binop_*` lines presumably generate the borrowed and scalar-on-the-left
// variants from the by-value `Mul<u32>`/`Mul<u64>`/`Mul<u128>` impls. Confirm both.
promote_unsigned_scalars!(impl Mul for BigUint, mul);
promote_unsigned_scalars_assign!(impl MulAssign for BigUint, mul_assign);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u32> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u64> for BigUint, mul);
forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u128> for BigUint, mul);
460
461impl Mul<u32> for BigUint {
462    type Output = BigUint;
463
464    #[inline]
465    fn mul(mut self, other: u32) -> BigUint {
466        self *= other;
467        self
468    }
469}
470impl MulAssign<u32> for BigUint {
471    #[inline]
472    fn mul_assign(&mut self, other: u32) {
473        scalar_mul(self, other as BigDigit);
474    }
475}
476
477impl Mul<u64> for BigUint {
478    type Output = BigUint;
479
480    #[inline]
481    fn mul(mut self, other: u64) -> BigUint {
482        self *= other;
483        self
484    }
485}
486impl MulAssign<u64> for BigUint {
487    #[cfg(not(u64_digit))]
488    #[inline]
489    fn mul_assign(&mut self, other: u64) {
490        if let Some(other) = BigDigit::from_u64(other) {
491            scalar_mul(self, other);
492        } else {
493            let (hi, lo) = big_digit::from_doublebigdigit(other);
494            *self = mul3(&self.data, &[lo, hi]);
495        }
496    }
497
498    #[cfg(u64_digit)]
499    #[inline]
500    fn mul_assign(&mut self, other: u64) {
501        scalar_mul(self, other);
502    }
503}
504
505impl Mul<u128> for BigUint {
506    type Output = BigUint;
507
508    #[inline]
509    fn mul(mut self, other: u128) -> BigUint {
510        self *= other;
511        self
512    }
513}
514
515impl MulAssign<u128> for BigUint {
516    #[cfg(not(u64_digit))]
517    #[inline]
518    fn mul_assign(&mut self, other: u128) {
519        if let Some(other) = BigDigit::from_u128(other) {
520            scalar_mul(self, other);
521        } else {
522            *self = match u32_from_u128(other) {
523                (0, 0, c, d) => mul3(&self.data, &[d, c]),
524                (0, b, c, d) => mul3(&self.data, &[d, c, b]),
525                (a, b, c, d) => mul3(&self.data, &[d, c, b, a]),
526            };
527        }
528    }
529
530    #[cfg(u64_digit)]
531    #[inline]
532    fn mul_assign(&mut self, other: u128) {
533        if let Some(other) = BigDigit::from_u128(other) {
534            scalar_mul(self, other);
535        } else {
536            let (hi, lo) = big_digit::from_doublebigdigit(other);
537            *self = mul3(&self.data, &[lo, hi]);
538        }
539    }
540}
541
542impl CheckedMul for BigUint {
543    #[inline]
544    fn checked_mul(&self, v: &BigUint) -> Option<BigUint> {
545        Some(self.mul(v))
546    }
547}
548
// NOTE(review): macro defined elsewhere in the crate; presumably implements
// `core::iter::Product` (imported at the top of this file) for iterators of `BigUint`.
// Confirm.
impl_product_iter_type!(BigUint);
550
#[test]
fn test_sub_sign() {
    use crate::BigInt;
    use num_traits::Num;

    /// Turns `sub_sign`'s (sign, magnitude) pair into a signed `BigInt` for comparison.
    fn sub_sign_i(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
        let (sign, val) = sub_sign(a, b);
        BigInt::from_biguint(sign, val)
    }

    let a = BigUint::from_str_radix("265252859812191058636308480000000", 10).unwrap();
    let b = BigUint::from_str_radix("26525285981219105863630848000000", 10).unwrap();
    let a_i = BigInt::from(a.clone());
    let b_i = BigInt::from(b.clone());

    // Greater and Less branches.
    assert_eq!(sub_sign_i(&a.data, &b.data), &a_i - &b_i);
    assert_eq!(sub_sign_i(&b.data, &a.data), &b_i - &a_i);

    // Equal branch: identical operands give zero.
    assert_eq!(sub_sign_i(&a.data, &a.data), BigInt::zero());

    // High-order zero digits are ignored, so these compare equal...
    assert_eq!(sub_sign_i(&[7, 0, 0], &[7]), BigInt::zero());
    // ...and they do not distort a nonzero difference in either direction.
    assert_eq!(sub_sign_i(&[9, 0], &[2]), BigInt::from(7u32));
    assert_eq!(sub_sign_i(&[2], &[9, 0]), -BigInt::from(7u32));
}