ff/
batch.rs

1//! Batched field inversion APIs, using [Montgomery's trick].
2//!
3//! [Montgomery's trick]: https://zcash.github.io/halo2/background/fields.html#montgomerys-trick
4
5use subtle::ConstantTimeEq;
6
7use crate::Field;
8
/// Extension trait that lets a collection of mutable references to field elements be
/// inverted together as a single batch.
///
/// With the `alloc` feature flag enabled, this trait is implemented for every
/// `I: IntoIterator<Item = &mut F>` where `F: Field + ConstantTimeEq`.
///
/// For non-allocating contexts, see the [`BatchInverter`] struct.
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub trait BatchInvert<F>
where
    F: Field,
{
    /// Consumes this iterator and replaces every nonzero field element with its inverse.
    /// Elements equal to zero are left unchanged (still zero).
    ///
    /// Returns the inverse of the product of all nonzero field elements.
    fn batch_invert(self) -> F;
}
25
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
impl<'a, F, I> BatchInvert<F> for I
where
    F: Field + ConstantTimeEq,
    I: IntoIterator<Item = &'a mut F>,
{
    fn batch_invert(self) -> F {
        let elements = self.into_iter();
        let (lower_bound, _) = elements.size_hint();

        // Forward pass: pair each element with the product of all nonzero elements that
        // precede it, while accumulating the running product of nonzero elements.
        // Zero elements are skipped via `conditional_select` rather than a branch.
        let mut prefixes = alloc::vec::Vec::with_capacity(lower_bound);
        let mut running = F::one();
        for element in elements {
            let value = *element;
            prefixes.push((running, element));
            let is_zero = value.ct_eq(&F::zero());
            running = F::conditional_select(&(running * value), &running, is_zero);
        }

        // `running` is one or a product of nonzero elements, so it is invertible.
        let total_inverse = running.invert().unwrap();

        // Backward pass (last element first): `remaining` is the inverse of the product
        // of the nonzero elements up to and including the current one, so multiplying
        // it by the stored prefix yields the current element's inverse.
        let mut remaining = total_inverse;
        while let Some((prefix, element)) = prefixes.pop() {
            let is_zero = element.ct_eq(&F::zero());
            let inverse = prefix * remaining;
            remaining = F::conditional_select(&(remaining * *element), &remaining, is_zero);
            *element = F::conditional_select(&inverse, element, is_zero);
        }

        total_inverse
    }
}
56
/// A non-allocating batch inverter.
///
/// This type has no fields; its functionality is provided entirely by associated
/// functions that operate on caller-supplied storage.
#[derive(Debug)]
pub struct BatchInverter {}
59
60impl BatchInverter {
61    /// Inverts each field element in `elements` (when nonzero). Zero-valued elements are
62    /// left as zero.
63    ///
64    /// - `scratch_space` is a slice of field elements that can be freely overwritten.
65    ///
66    /// Returns the inverse of the product of all nonzero field elements.
67    ///
68    /// # Panics
69    ///
70    /// This function will panic if `elements.len() != scratch_space.len()`.
71    pub fn invert_with_external_scratch<F>(elements: &mut [F], scratch_space: &mut [F]) -> F
72    where
73        F: Field + ConstantTimeEq,
74    {
75        assert_eq!(elements.len(), scratch_space.len());
76
77        let mut acc = F::one();
78        for (p, scratch) in elements.iter().zip(scratch_space.iter_mut()) {
79            *scratch = acc;
80            acc = F::conditional_select(&(acc * *p), &acc, p.ct_eq(&F::zero()));
81        }
82        acc = acc.invert().unwrap();
83        let allinv = acc;
84
85        for (p, scratch) in elements.iter_mut().zip(scratch_space.iter()).rev() {
86            let tmp = *scratch * acc;
87            let skip = p.ct_eq(&F::zero());
88            acc = F::conditional_select(&(acc * *p), &acc, skip);
89            *p = F::conditional_select(&tmp, &p, skip);
90        }
91
92        allinv
93    }
94
95    /// Inverts each field element in `items` (when nonzero). Zero-valued elements are
96    /// left as zero.
97    ///
98    /// - `element` is a function that extracts the element to be inverted from `items`.
99    /// - `scratch_space` is a function that extracts the scratch space from `items`.
100    ///
101    /// Returns the inverse of the product of all nonzero field elements.
102    pub fn invert_with_internal_scratch<F, T, TE, TS>(
103        items: &mut [T],
104        element: TE,
105        scratch_space: TS,
106    ) -> F
107    where
108        F: Field + ConstantTimeEq,
109        TE: Fn(&mut T) -> &mut F,
110        TS: Fn(&mut T) -> &mut F,
111    {
112        let mut acc = F::one();
113        for item in items.iter_mut() {
114            *(scratch_space)(item) = acc;
115            let p = (element)(item);
116            acc = F::conditional_select(&(acc * *p), &acc, p.ct_eq(&F::zero()));
117        }
118        acc = acc.invert().unwrap();
119        let allinv = acc;
120
121        for item in items.iter_mut().rev() {
122            let tmp = *(scratch_space)(item) * acc;
123            let p = (element)(item);
124            let skip = p.ct_eq(&F::zero());
125            acc = F::conditional_select(&(acc * *p), &acc, skip);
126            *p = F::conditional_select(&tmp, &p, skip);
127        }
128
129        allinv
130    }
131}