surpass/
extend.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
/// Generates `$name`, an adapter implementing `rayon::iter::ParallelExtend` for a tuple of
/// `&mut Vec`s: field `i` of every item yielded by the parallel iterator is appended to vector
/// `i`, in iterator order.
///
/// Invoke as `extend_tuple!(Name, (0 A), (1 B), ...)`, pairing each tuple index with the type
/// parameter name to use for that field. At least two fields are required: with a single field
/// the generated `(&mut Vec<A>)` is a parenthesized type rather than a tuple. The expansion
/// refers to `::rayon`, so the invoking crate must depend on rayon.
macro_rules! extend_tuple {
    ( $name:ident, $( ( $fields:tt $types:ident ) ),+ ) => {
        /// Extends each vector in a tuple of vectors with the matching field of every item of a
        /// parallel iterator of tuples.
        pub struct $name<'a, $($types),+> {
            tuple: ($(&'a mut Vec<$types>),+),
        }

        impl<'a, $($types),+> $name<'a, $($types),+> {
            /// Wraps the vectors that `par_extend` appends to; existing contents are kept.
            pub fn new(tuple: ($(&'a mut Vec<$types>),+)) -> Self {
                Self { tuple }
            }
        }

        impl<$($types),+> ::rayon::iter::ParallelExtend<($($types),+)> for $name<'_, $($types),+>
        where
            $(
                $types: Send,
            )+
        {
            fn par_extend<PI>(&mut self, par_iter: PI)
            where
                PI: ::rayon::iter::IntoParallelIterator<Item = ($($types),+)>,
            {
                use ::std::{
                    collections::LinkedList,
                    mem::MaybeUninit,
                    slice,
                    sync::atomic::{AtomicUsize, Ordering},
                };

                use ::rayon::{
                    iter::plumbing::{Consumer, Folder, Reducer, UnindexedConsumer},
                    prelude::*,
                };

                /// Results are written in place, so there is nothing to combine.
                struct NoopReducer;

                impl Reducer<()> for NoopReducer {
                    fn reduce(self, _left: (), _right: ()) {}
                }

                /// Writes item fields into disjoint windows of each vector's uninitialized
                /// spare capacity.
                struct CollectTupleConsumer<'c, $($types: Send),+> {
                    /// Item count shared by every split of this consumer.
                    writes: &'c AtomicUsize,
                    targets: ($(&'c mut [MaybeUninit<$types>]),+),
                }

                /// Sequential writer over one window per vector, filled front to back.
                struct CollectTupleFolder<'c, $($types: Send),+> {
                    global_writes: &'c AtomicUsize,
                    local_writes: usize,
                    targets: ($(slice::IterMut<'c, MaybeUninit<$types>>),+),
                }

                impl<'c, $($types: Send + 'c),+> Consumer<($($types),+)>
                for CollectTupleConsumer<'c, $($types),+>
                {
                    type Folder = CollectTupleFolder<'c, $($types),+>;
                    type Reducer = NoopReducer;
                    type Result = ();

                    fn split_at(self, index: usize) -> (Self, Self, NoopReducer) {
                        let CollectTupleConsumer { writes, targets } = self;

                        let splits = (
                            $(
                                targets.$fields.split_at_mut(index),
                            )+
                        );

                        (
                            CollectTupleConsumer {
                                writes,
                                targets: (
                                    $(
                                        splits.$fields.0,
                                    )+
                                ),
                            },
                            CollectTupleConsumer {
                                writes,
                                targets: (
                                    $(
                                        splits.$fields.1,
                                    )+
                                ),
                            },
                            NoopReducer,
                        )
                    }

                    fn into_folder(self) -> CollectTupleFolder<'c, $($types),+> {
                        CollectTupleFolder {
                            global_writes: self.writes,
                            local_writes: 0,
                            targets: (
                                $(
                                    self.targets.$fields.iter_mut(),
                                )+
                            ),
                        }
                    }

                    fn full(&self) -> bool {
                        false
                    }
                }

                impl<'c, $($types: Send + 'c),+> Folder<($($types),+)>
                for CollectTupleFolder<'c, $($types),+>
                {
                    type Result = ();

                    fn consume(
                        mut self,
                        item: ($($types),+),
                    ) -> CollectTupleFolder<'c, $($types),+> {
                        $(
                            let slot = self
                                .targets
                                .$fields
                                .next()
                                .expect("too many values pushed to consumer");
                            // Assigning to a `MaybeUninit` never drops the old (uninitialized)
                            // contents, so this safely initializes the slot.
                            *slot = MaybeUninit::new(item.$fields);
                        )+

                        self.local_writes += 1;
                        self
                    }

                    fn complete(self) {
                        // Relaxed is assumed sufficient: `drive_unindexed` joins every worker
                        // before `CollectTuple::complete` reads the total.
                        self.global_writes.fetch_add(self.local_writes, Ordering::Relaxed);
                    }

                    fn full(&self) -> bool {
                        false
                    }
                }

                impl<'c, $($types: Send + 'c),+> UnindexedConsumer<($($types),+)>
                for CollectTupleConsumer<'c, $($types),+>
                {
                    fn split_off_left(&self) -> Self {
                        unreachable!("CollectTupleConsumer must be indexed!")
                    }
                    fn to_reducer(&self) -> Self::Reducer {
                        NoopReducer
                    }
                }

                /// Owns the bookkeeping for writing exactly `len` items into the spare capacity
                /// of every target vector.
                struct CollectTuple<'c, $($types: Send),+> {
                    writes: AtomicUsize,
                    tuple: ($(&'c mut Vec<$types>),+),
                    len: usize,
                }

                impl<'c, $($types: Send),+> CollectTuple<'c, $($types),+> {
                    pub fn new(tuple: ($(&'c mut Vec<$types>),+), len: usize) -> Self {
                        Self {
                            writes: AtomicUsize::new(0),
                            tuple,
                            len,
                        }
                    }

                    /// Reserves room for `len` more items in every vector and returns a consumer
                    /// over the uninitialized tail `vec.len()..vec.len() + len` of each.
                    pub fn as_consumer(&mut self) -> CollectTupleConsumer<'_, $($types),+> {
                        let len = self.len;
                        CollectTupleConsumer {
                            writes: &self.writes,
                            targets: (
                                $(
                                    {
                                        let vec: &mut Vec<$types> = &mut *self.tuple.$fields;
                                        vec.reserve(len);
                                        let start = vec.len();
                                        // SAFETY: `reserve(len)` guarantees capacity for at
                                        // least `start + len` elements, so the window lies
                                        // inside the allocation. The pointer comes from
                                        // `as_mut_ptr` (the whole buffer), not from an empty
                                        // subslice, so it is valid for the full window.
                                        // `MaybeUninit` needs no initialization. The vector
                                        // stays borrowed through `self` while the consumer
                                        // lives.
                                        unsafe {
                                            slice::from_raw_parts_mut(
                                                vec.as_mut_ptr().add(start)
                                                    as *mut MaybeUninit<$types>,
                                                len,
                                            )
                                        }
                                    }
                                ),+
                            ),
                        }
                    }

                    /// Commits the written items by growing every vector's length by `len`.
                    ///
                    /// Panics if the number of items written differs from `len`. The vector
                    /// lengths are then left unchanged, and any items already written are
                    /// leaked rather than dropped.
                    pub fn complete(self) {
                        let actual_writes = self.writes.load(Ordering::Relaxed);
                        assert!(
                            actual_writes == self.len,
                            "expected {} total writes, but got {}",
                            self.len,
                            actual_writes
                        );

                        $(
                            let vec = self.tuple.$fields;
                            let new_len = vec.len() + self.len;
                            // SAFETY: the consumer's windows are disjoint and none can be
                            // overfilled (`next()` would panic). So `len` total writes means
                            // every slot in `old_len..new_len` is initialized, and it lies
                            // within the capacity reserved in `as_consumer`.
                            unsafe { vec.set_len(new_len) };
                        )+
                    }
                }

                let par_iter = par_iter.into_par_iter();
                match par_iter.opt_len() {
                    // Exact length known: write directly into the vectors' spare capacity.
                    Some(len) => {
                        let mut collect =
                            CollectTuple::new(($(&mut *self.tuple.$fields),+), len);
                        par_iter.drive_unindexed(collect.as_consumer());
                        collect.complete()
                    }
                    // Unknown length: gather per-split batches and append them in order.
                    None => {
                        let list = par_iter
                            .fold(|| ($(Vec::<$types>::new()),+), |mut vecs, elem| {
                                $(
                                    vecs.$fields.push(elem.$fields);
                                )+
                                vecs
                            })
                            .map(|item| {
                                let mut list = LinkedList::new();
                                list.push_back(item);
                                list
                            })
                            .reduce(LinkedList::new, |mut list1, mut list2| {
                                list1.append(&mut list2);
                                list1
                            });
                        let len: usize = list.iter().map(|vecs| vecs.0.len()).sum();

                        $(
                            self.tuple.$fields.reserve(len);
                        )+
                        for mut vecs in list {
                            $(
                                self.tuple.$fields.append(&mut vecs.$fields);
                            )+
                        }
                    }
                }
            }
        }
    };
}
251
252extend_tuple!(ExtendTuple3, (0 A), (1 B), (2 C));
253extend_tuple!(ExtendTuple10, (0 A), (1 B), (2 C), (3 D), (4 E), (5 F), (6 G), (7 H), (8 I), (9 J));
254
255pub struct ExtendVec<'a, T> {
256    vec: &'a mut Vec<T>,
257}
258
259impl<'a, T> ExtendVec<'a, T> {
260    pub fn new(vec: &'a mut Vec<T>) -> Self {
261        Self { vec }
262    }
263}
264
265impl<T: Send> rayon::iter::ParallelExtend<T> for ExtendVec<'_, T> {
266    fn par_extend<I>(&mut self, par_iter: I)
267    where
268        I: rayon::iter::IntoParallelIterator<Item = T>,
269    {
270        self.vec.par_extend(par_iter);
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    use rayon::prelude::*;

    /// Indexed path (`opt_len` is `Some`) with all ten fields.
    #[test]
    fn tuple10() {
        let mut vec0 = vec![];
        let mut vec1 = vec![];
        let mut vec2 = vec![];
        let mut vec3 = vec![];
        let mut vec4 = vec![];
        let mut vec5 = vec![];
        let mut vec6 = vec![];
        let mut vec7 = vec![];
        let mut vec8 = vec![];
        let mut vec9 = vec![];

        ExtendTuple10::new((
            &mut vec0, &mut vec1, &mut vec2, &mut vec3, &mut vec4, &mut vec5, &mut vec6, &mut vec7,
            &mut vec8, &mut vec9,
        ))
        .par_extend(
            (0..3)
                .into_par_iter()
                .map(|i| (i, i + 1, i + 2, i + 3, i + 4, i + 5, i + 6, i + 7, i + 8, i + 9)),
        );

        assert_eq!(vec0, [0, 1, 2]);
        assert_eq!(vec1, [1, 2, 3]);
        assert_eq!(vec2, [2, 3, 4]);
        assert_eq!(vec3, [3, 4, 5]);
        assert_eq!(vec4, [4, 5, 6]);
        assert_eq!(vec5, [5, 6, 7]);
        assert_eq!(vec6, [6, 7, 8]);
        assert_eq!(vec7, [7, 8, 9]);
        assert_eq!(vec8, [8, 9, 10]);
        assert_eq!(vec9, [9, 10, 11]);
    }

    /// Indexed path must write after, not over, existing contents.
    #[test]
    fn tuple3_appends_to_existing_contents() {
        let mut a = vec![-1];
        let mut b = vec![-2];
        let mut c = vec![-3];

        ExtendTuple3::new((&mut a, &mut b, &mut c))
            .par_extend((0..4).into_par_iter().map(|i| (i, i * 10, i * 100)));

        assert_eq!(a, [-1, 0, 1, 2, 3]);
        assert_eq!(b, [-2, 0, 10, 20, 30]);
        assert_eq!(c, [-3, 0, 100, 200, 300]);
    }

    /// `filter` hides the length, exercising the fold/reduce fallback path.
    #[test]
    fn tuple3_unindexed_preserves_order() {
        let mut a = vec![7];
        let mut b = vec![];
        let mut c = vec![];

        ExtendTuple3::new((&mut a, &mut b, &mut c)).par_extend(
            (0..10).into_par_iter().filter(|i| i % 2 == 0).map(|i| (i, i + 1, i + 2)),
        );

        assert_eq!(a, [7, 0, 2, 4, 6, 8]);
        assert_eq!(b, [1, 3, 5, 7, 9]);
        assert_eq!(c, [2, 4, 6, 8, 10]);
    }

    /// Extending with nothing leaves the vectors untouched.
    #[test]
    fn tuple3_empty() {
        let mut a = vec![1];
        let mut b = vec![2];
        let mut c = vec![3];

        ExtendTuple3::new((&mut a, &mut b, &mut c))
            .par_extend((0..0).into_par_iter().map(|i| (i, i, i)));

        assert_eq!(a, [1]);
        assert_eq!(b, [2]);
        assert_eq!(c, [3]);
    }

    #[test]
    fn extend_vec() {
        let mut vec = vec![0];
        ExtendVec::new(&mut vec).par_extend((1..4).into_par_iter());
        assert_eq!(vec, [0, 1, 2, 3]);
    }
}