1use subtle::ConstantTimeEq;
6
7use crate::Field;
8
/// Extension trait for collections of mutable field elements that can be inverted
/// together as a single batch.
///
/// Any `I: IntoIterator<Item = &'a mut F>` with `F: Field + ConstantTimeEq` implements
/// this trait. The blanket implementation buffers the elements in a heap-allocated
/// `Vec`, so it is only available with the `alloc` feature; for allocation-free batch
/// inversion see [`BatchInverter`].
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub trait BatchInvert<F: Field> {
    /// Consumes `self` and replaces every nonzero element with its multiplicative
    /// inverse. Zero elements are left as zero.
    ///
    /// Returns the inverse of the product of all nonzero elements, which is
    /// `F::one()` when there are no nonzero elements.
    fn batch_invert(self) -> F;
}
25
#[cfg(feature = "alloc")]
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
impl<'a, F, I> BatchInvert<F> for I
where
    F: Field + ConstantTimeEq,
    I: IntoIterator<Item = &'a mut F>,
{
    /// Batch-inverts the yielded elements using Montgomery's trick: a single field
    /// inversion plus a linear number of multiplications. Zero elements are skipped
    /// through constant-time selection rather than by branching on their value.
    fn batch_invert(self) -> F {
        let zero = F::zero();
        let elements = self.into_iter();

        // Forward pass: pair each element with the product of the nonzero elements
        // that precede it, then fold it into the running product unless it is zero.
        let mut pending = alloc::vec::Vec::with_capacity(elements.size_hint().0);
        let product = elements.fold(F::one(), |running, elem| {
            let value = *elem;
            pending.push((running, elem));
            F::conditional_select(&(running * value), &running, value.ct_eq(&zero))
        });

        // A product of nonzero field elements is itself nonzero, so this cannot fail.
        let all_inverse = product.invert().unwrap();

        // Backward pass (last element first): `running_inverse` is the inverse of the
        // product of the nonzero elements up to and including the current one.
        let mut running_inverse = all_inverse;
        while let Some((prefix, elem)) = pending.pop() {
            let is_zero = elem.ct_eq(&zero);
            let inverse = prefix * running_inverse;
            running_inverse =
                F::conditional_select(&(running_inverse * *elem), &running_inverse, is_zero);
            *elem = F::conditional_select(&inverse, elem, is_zero);
        }

        all_inverse
    }
}
56
/// Allocation-free batch inversion of field elements, for callers that supply their
/// own scratch storage.
///
/// This type carries no state; it only namespaces the associated inversion functions.
#[derive(Debug)]
pub struct BatchInverter {}
59
60impl BatchInverter {
61 pub fn invert_with_external_scratch<F>(elements: &mut [F], scratch_space: &mut [F]) -> F
72 where
73 F: Field + ConstantTimeEq,
74 {
75 assert_eq!(elements.len(), scratch_space.len());
76
77 let mut acc = F::one();
78 for (p, scratch) in elements.iter().zip(scratch_space.iter_mut()) {
79 *scratch = acc;
80 acc = F::conditional_select(&(acc * *p), &acc, p.ct_eq(&F::zero()));
81 }
82 acc = acc.invert().unwrap();
83 let allinv = acc;
84
85 for (p, scratch) in elements.iter_mut().zip(scratch_space.iter()).rev() {
86 let tmp = *scratch * acc;
87 let skip = p.ct_eq(&F::zero());
88 acc = F::conditional_select(&(acc * *p), &acc, skip);
89 *p = F::conditional_select(&tmp, &p, skip);
90 }
91
92 allinv
93 }
94
95 pub fn invert_with_internal_scratch<F, T, TE, TS>(
103 items: &mut [T],
104 element: TE,
105 scratch_space: TS,
106 ) -> F
107 where
108 F: Field + ConstantTimeEq,
109 TE: Fn(&mut T) -> &mut F,
110 TS: Fn(&mut T) -> &mut F,
111 {
112 let mut acc = F::one();
113 for item in items.iter_mut() {
114 *(scratch_space)(item) = acc;
115 let p = (element)(item);
116 acc = F::conditional_select(&(acc * *p), &acc, p.ct_eq(&F::zero()));
117 }
118 acc = acc.invert().unwrap();
119 let allinv = acc;
120
121 for item in items.iter_mut().rev() {
122 let tmp = *(scratch_space)(item) * acc;
123 let p = (element)(item);
124 let skip = p.ct_eq(&F::zero());
125 acc = F::conditional_select(&(acc * *p), &acc, skip);
126 *p = F::conditional_select(&tmp, &p, skip);
127 }
128
129 allinv
130 }
131}