1use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer};
6use rayon::prelude::*;
7
/// Iterator over `(group, offset)` pairs for the elements described by inclusive prefix sums.
///
/// `sums[g]` is the number of elements in groups `0..=g`, so group `g` covers the global
/// positions `sums[g - 1]..sums[g]` (with the sum before group 0 taken as `0`). For every
/// position `p` in `start..end` the iterator yields `(g, p - sums[g - 1])`, where `g` is the
/// group covering `p`. Groups with an empty range yield nothing. For example, `[2, 5, 9, 15]`
/// describes groups of 2, 3, 4 and 6 elements.
#[must_use = "iterators are lazy and do nothing unless consumed"]
#[derive(Clone, Debug, Default)]
pub struct GroupedIter<'s> {
    /// Inclusive prefix sums of the group sizes; searched with binary search, so it is
    /// expected to be non-decreasing.
    sums: &'s [u32],
    /// Front cursor: a group index at or before the group that holds `start`.
    group_start: u32,
    /// Back cursor: a group index at or after the group that holds `end - 1`.
    group_end: u32,
    /// Global position of the next element yielded from the front.
    start: u32,
    /// One past the global position of the next element yielded from the back.
    end: u32,
}
16
17impl Iterator for GroupedIter<'_> {
18 type Item = (u32, u32);
19
20 #[inline]
21 fn next(&mut self) -> Option<Self::Item> {
22 loop {
23 if self.start >= self.end {
24 return None;
25 }
26
27 let exclusive_sum =
28 self.group_start.checked_sub(1).map(|i| self.sums[i as usize]).unwrap_or_default();
29 let inclusive_sum = self.sums[self.group_start as usize];
30
31 if exclusive_sum == inclusive_sum {
32 self.group_start += 1;
33 continue;
34 }
35
36 let result = (self.group_start, self.start as u32 - exclusive_sum);
37
38 self.start += 1;
39
40 if self.start == inclusive_sum {
41 self.group_start += 1;
42 }
43
44 return Some(result);
45 }
46 }
47}
48
49impl DoubleEndedIterator for GroupedIter<'_> {
50 #[inline]
51 fn next_back(&mut self) -> Option<Self::Item> {
52 loop {
53 if self.start >= self.end {
54 return None;
55 }
56
57 let exclusive_sum =
58 self.group_end.checked_sub(1).map(|i| self.sums[i as usize]).unwrap_or_default();
59 let inclusive_sum = self.sums[self.group_end as usize];
60
61 if exclusive_sum == inclusive_sum {
62 self.group_end = self.group_end.saturating_sub(1);
63 continue;
64 }
65
66 let result = (self.group_end, self.end as u32 - 1 - exclusive_sum);
67
68 self.end -= 1;
69
70 if self.end == exclusive_sum {
71 self.group_end = self.group_end.saturating_sub(1);
72 }
73
74 return Some(result);
75 }
76 }
77}
78
79impl ExactSizeIterator for GroupedIter<'_> {
80 fn len(&self) -> usize {
81 (self.end - self.start) as usize
82 }
83}
84
85struct GroupedIterProducer<'s> {
86 inner: GroupedIter<'s>,
87}
88
89impl<'s> Producer for GroupedIterProducer<'s> {
90 type Item = (u32, u32);
91
92 type IntoIter = GroupedIter<'s>;
93
94 #[inline]
95 fn into_iter(self) -> Self::IntoIter {
96 self.inner
97 }
98
99 #[inline]
100 fn split_at(self, index: usize) -> (Self, Self) {
101 let index = index as u32 + self.inner.start;
102
103 let mid = match self.inner.sums.binary_search(&(index as u32)) {
104 Ok(mid) => mid + 1,
105 Err(mid) => mid,
106 } as u32;
107
108 (
109 Self {
110 inner: GroupedIter {
111 sums: self.inner.sums,
112 group_start: self.inner.group_start,
113 group_end: mid as u32,
114 start: self.inner.start,
115 end: index as u32,
116 },
117 },
118 Self {
119 inner: GroupedIter {
120 sums: self.inner.sums,
121 group_start: mid as u32,
122 group_end: self.inner.group_end,
123 start: index as u32,
124 end: self.inner.end,
125 },
126 },
127 )
128 }
129}
130
131impl<'s> GroupedIter<'s> {
132 #[inline]
133 pub fn new(sums: &'s [u32]) -> Self {
134 Self {
135 sums,
136 group_start: 0,
137 group_end: sums.len().saturating_sub(1) as u32,
138 start: 0,
139 end: sums.last().copied().unwrap_or_default(),
140 }
141 }
142}
143
144impl<'s> IntoParallelIterator for GroupedIter<'s> {
145 type Iter = GroupedParIter<'s>;
146
147 type Item = (u32, u32);
148
149 #[inline]
150 fn into_par_iter(self) -> Self::Iter {
151 GroupedParIter { iter: self }
152 }
153}
154
155pub struct GroupedParIter<'s> {
156 iter: GroupedIter<'s>,
157}
158
159impl ParallelIterator for GroupedParIter<'_> {
160 type Item = (u32, u32);
161
162 #[inline]
163 fn drive_unindexed<C>(self, consumer: C) -> C::Result
164 where
165 C: UnindexedConsumer<Self::Item>,
166 {
167 bridge(self, consumer)
168 }
169
170 #[inline]
171 fn opt_len(&self) -> Option<usize> {
172 Some(self.iter.len())
173 }
174}
175
176impl IndexedParallelIterator for GroupedParIter<'_> {
177 #[inline]
178 fn len(&self) -> usize {
179 self.iter.len()
180 }
181
182 #[inline]
183 fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> C::Result {
184 bridge(self, consumer)
185 }
186
187 #[inline]
188 fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
189 callback.callback(GroupedIterProducer { inner: self.iter })
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_iter() {
        let sums = &[];

        assert_eq!(GroupedIter::new(sums).collect::<Vec<_>>(), []);
    }

    #[test]
    fn local_iter() {
        let sums = &[2, 5, 9, 15];
        let iter = GroupedIter::new(sums);

        assert_eq!(
            iter.collect::<Vec<_>>(),
            [
                (0, 0),
                (0, 1),
                (1, 0),
                (1, 1),
                (1, 2),
                (2, 0),
                (2, 1),
                (2, 2),
                (2, 3),
                (3, 0),
                (3, 1),
                (3, 2),
                (3, 3),
                (3, 4),
                (3, 5),
            ]
        );
    }

    #[test]
    fn local_iter_rev() {
        let sums = &[2, 5, 9, 15];
        let iter = GroupedIter::new(sums);

        assert_eq!(
            iter.rev().collect::<Vec<_>>(),
            [
                (3, 5),
                (3, 4),
                (3, 3),
                (3, 2),
                (3, 1),
                (3, 0),
                (2, 3),
                (2, 2),
                (2, 1),
                (2, 0),
                (1, 2),
                (1, 1),
                (1, 0),
                (0, 1),
                (0, 0),
            ]
        );
    }

    #[test]
    fn both_ends() {
        let sums = &[2, 5, 9, 15];
        let mut iter = GroupedIter::new(sums);

        assert_eq!(iter.len(), 15);

        assert_eq!(iter.next(), Some((0, 0)));
        assert_eq!(iter.next_back(), Some((3, 5)));
        assert_eq!(iter.next(), Some((0, 1)));
        assert_eq!(iter.next_back(), Some((3, 4)));
        assert_eq!(iter.next(), Some((1, 0)));
        assert_eq!(iter.next_back(), Some((3, 3)));
        assert_eq!(iter.next(), Some((1, 1)));

        assert_eq!(iter.len(), 8);

        assert_eq!(iter.next_back(), Some((3, 2)));
        assert_eq!(iter.next(), Some((1, 2)));
        assert_eq!(iter.next_back(), Some((3, 1)));
        assert_eq!(iter.next(), Some((2, 0)));
        assert_eq!(iter.next_back(), Some((3, 0)));
        assert_eq!(iter.next(), Some((2, 1)));
        assert_eq!(iter.next_back(), Some((2, 3)));
        assert_eq!(iter.next(), Some((2, 2)));
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next_back(), None);

        assert_eq!(iter.len(), 0);
    }

    #[test]
    fn empty_groups() {
        let sums = &[2, 2, 5, 5];
        let iter = GroupedIter::new(sums);

        assert_eq!(iter.collect::<Vec<_>>(), [(0, 0), (0, 1), (2, 0), (2, 1), (2, 2),]);
    }

    #[test]
    fn empty_groups_rev() {
        let sums = &[2, 2, 5, 5];
        let iter = GroupedIter::new(sums);

        assert_eq!(iter.rev().collect::<Vec<_>>(), [(2, 2), (2, 1), (2, 0), (0, 1), (0, 0),]);
    }

    #[test]
    fn par_iter() {
        let sums = &[2, 5, 9, 15];
        let iter = GroupedIter::new(sums);

        assert_eq!(
            iter.into_par_iter().collect::<Vec<_>>(),
            [
                (0, 0),
                (0, 1),
                (1, 0),
                (1, 1),
                (1, 2),
                (2, 0),
                (2, 1),
                (2, 2),
                (2, 3),
                (3, 0),
                (3, 1),
                (3, 2),
                (3, 3),
                (3, 4),
                (3, 5),
            ]
        );
    }

    #[test]
    fn par_iter2() {
        let sums = &[3, 6, 10, 11];
        let iter = GroupedIter::new(sums);

        assert_eq!(
            iter.into_par_iter().collect::<Vec<_>>(),
            [
                (0, 0),
                (0, 1),
                (0, 2),
                (1, 0),
                (1, 1),
                (1, 2),
                (2, 0),
                (2, 1),
                (2, 2),
                (2, 3),
                (3, 0),
            ]
        );
    }

    // The par tests above only split as often as the thread pool decides. `with_max_len(1)`
    // forces rayon to split down to single elements, so `Producer::split_at` runs at every
    // position, including exact group boundaries, independent of the machine's thread count.
    #[test]
    fn par_iter_split_at_every_position() {
        let sums = &[2, 5, 9, 15];
        let expected = GroupedIter::new(sums).collect::<Vec<_>>();

        assert_eq!(
            GroupedIter::new(sums).into_par_iter().with_max_len(1).collect::<Vec<_>>(),
            expected
        );
    }

    #[test]
    fn par_iter_empty_groups_split_at_every_position() {
        let sums = &[0, 2, 2, 5, 5];
        let iter = GroupedIter::new(sums);

        assert_eq!(
            iter.into_par_iter().with_max_len(1).collect::<Vec<_>>(),
            [(1, 0), (1, 1), (3, 0), (3, 1), (3, 2)]
        );
    }

    #[test]
    fn par_len_matches_sequential_len() {
        let sums = &[3, 6, 10, 11];

        assert_eq!(GroupedIter::new(sums).len(), 11);
        assert_eq!(GroupedIter::new(sums).into_par_iter().len(), 11);
    }
}