1use futures::{Future, FutureExt, Stream, StreamExt, TryStream, TryStreamExt};
9
/// The value a folding closure hands back to [`fold_while`] or
/// [`try_fold_while`] after processing one stream item, telling the fold
/// whether to keep going or to stop early.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum FoldWhile<C, D> {
    /// Keep folding; the payload becomes the accumulator passed to the
    /// closure for the next item.
    Continue(C),
    /// Stop folding immediately; the payload is surfaced to the caller as
    /// [`FoldResult::ShortCircuited`].
    Done(D),
}
18
/// The outcome of [`fold_while`] or [`try_fold_while`]: records whether the
/// fold consumed the whole stream or was stopped early by the closure.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum FoldResult<F, R> {
    /// The stream ended before the closure asked to stop; carries the final
    /// accumulator value.
    StreamEnded(F),
    /// The closure returned [`FoldWhile::Done`]; carries the value it
    /// supplied.
    ShortCircuited(R),
}

impl<F, R> FoldResult<F, R> {
    /// Converts into a `Result` that treats reaching the end of the stream
    /// as success.
    ///
    /// Returns `Ok` with the final accumulator for
    /// [`FoldResult::StreamEnded`], and `Err` with the short-circuit value
    /// for [`FoldResult::ShortCircuited`].
    pub fn ended(self) -> Result<F, R> {
        match self {
            FoldResult::StreamEnded(acc) => Ok(acc),
            FoldResult::ShortCircuited(done) => Err(done),
        }
    }

    /// Converts into a `Result` that treats short-circuiting as success.
    ///
    /// Returns `Ok` with the short-circuit value for
    /// [`FoldResult::ShortCircuited`], and `Err` with the final accumulator
    /// for [`FoldResult::StreamEnded`].
    pub fn short_circuited(self) -> Result<R, F> {
        match self {
            FoldResult::StreamEnded(acc) => Err(acc),
            FoldResult::ShortCircuited(done) => Ok(done),
        }
    }
}

impl<F> FoldResult<F, F> {
    /// Unwraps the payload regardless of which variant holds it.
    ///
    /// Only available when both variants carry the same type.
    pub fn into_inner(self) -> F {
        match self {
            FoldResult::StreamEnded(value) | FoldResult::ShortCircuited(value) => value,
        }
    }
}
57
58pub fn try_fold_while<S, T, D, F, Fut>(
66 s: S,
67 init: T,
68 mut f: F,
69) -> impl Future<Output = Result<FoldResult<T, D>, S::Error>>
70where
71 S: TryStream,
72 F: FnMut(T, S::Ok) -> Fut,
73 Fut: Future<Output = Result<FoldWhile<T, D>, S::Error>>,
74{
75 s.map_err(Err)
76 .try_fold(init, move |acc, n| {
77 f(acc, n).map(|r| match r {
78 Ok(FoldWhile::Continue(r)) => Ok(r),
79 Ok(FoldWhile::Done(d)) => Err(Ok(d)),
80 Err(e) => Err(Err(e)),
81 })
82 })
83 .map(|r| match r {
84 Ok(n) => Ok(FoldResult::StreamEnded(n)),
85 Err(Ok(n)) => Ok(FoldResult::ShortCircuited(n)),
86 Err(Err(e)) => Err(e),
87 })
88}
89
90pub fn fold_while<S, T, D, F, Fut>(
97 s: S,
98 init: T,
99 mut f: F,
100) -> impl Future<Output = FoldResult<T, D>>
101where
102 S: Stream,
103 F: FnMut(T, S::Item) -> Fut,
104 Fut: Future<Output = FoldWhile<T, D>>,
105{
106 s.map(Ok)
107 .try_fold(init, move |acc, n| {
108 f(acc, n).map(|r| match r {
109 FoldWhile::Continue(r) => Ok(r),
110 FoldWhile::Done(d) => Err(d),
111 })
112 })
113 .map(|r| match r {
114 Ok(n) => FoldResult::StreamEnded(n),
115 Err(n) => FoldResult::ShortCircuited(n),
116 })
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_async as fasync;
    use futures::channel::mpsc;
    use futures::future;

    // The closure returns `Done` at `STOP_AT`; the accumulator must hold the
    // sum of only the items seen before that point.
    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_short_circuit() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        const STOP_AT: u32 = 5;
        let mut sum = 0;
        for i in 0..10 {
            if i < STOP_AT {
                sum += i;
            }
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        let (acc, stop) = try_fold_while(stream.map(Result::<_, ()>::Ok), 0, |acc, next| {
            future::ok(if next == STOP_AT {
                FoldWhile::Done((acc, next))
            } else {
                FoldWhile::Continue(next + acc)
            })
        })
        .await
        .expect("try_fold_while failed")
        .short_circuited()
        .expect("try_fold_while should've short-circuited");
        assert_eq!(stop, STOP_AT);
        assert_eq!(acc, sum);
    }

    // The closure never returns `Done`; dropping the sender ends the stream
    // and the fold must report the sum of every item.
    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_stream_ended() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        let mut sum = 0u32;
        for i in 0..10 {
            sum += i;
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        std::mem::drop(sender);
        let result =
            try_fold_while::<_, _, (), _, _>(stream.map(Result::<_, ()>::Ok), 0, |acc, next| {
                future::ok(FoldWhile::Continue(next + acc))
            })
            .await
            .expect("try_fold_while failed")
            .ended()
            .expect("try_fold_while should have seen the stream end");

        assert_eq!(result, sum);
    }

    // An `Err` item from the stream must be returned as-is without ever
    // invoking the closure.
    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_stream_error() {
        #[derive(Debug)]
        struct StreamErr;
        let (sender, stream) = mpsc::unbounded::<Result<u32, StreamErr>>();
        let () = sender.unbounded_send(Err(StreamErr {})).expect("failed to send item");
        let StreamErr {} = try_fold_while::<_, _, (), _, _>(stream, (), |(), _: u32| async {
            panic!("shouldn't receive error input");
        })
        .await
        .expect_err("try_fold_while should return error");
    }

    // An error produced by the closure must be returned to the caller; the
    // item is echoed inside the error to prove which call failed.
    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_closure_error() {
        #[derive(Debug)]
        struct StreamErr {
            item: u32,
        }
        const ERROR_ITEM: u32 = 1234;
        let (sender, stream) = mpsc::unbounded::<Result<u32, StreamErr>>();
        let () = sender.unbounded_send(Ok(ERROR_ITEM)).expect("failed to send item");
        let StreamErr { item } = try_fold_while::<_, _, (), _, _>(stream, (), |(), item| {
            future::err(StreamErr { item })
        })
        .await
        .expect_err("try_fold_while should return error");
        assert_eq!(item, ERROR_ITEM);
    }

    // Infallible counterpart of `test_try_fold_while_short_circuit`.
    #[fasync::run_singlethreaded(test)]
    async fn test_fold_while_short_circuit() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        const STOP_AT: u32 = 5;
        let mut sum = 0;
        for i in 0..10 {
            if i < STOP_AT {
                sum += i;
            }
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        let (acc, stop) = fold_while(stream, 0, |acc, next| {
            future::ready(if next == STOP_AT {
                FoldWhile::Done((acc, next))
            } else {
                FoldWhile::Continue(next + acc)
            })
        })
        .await
        .short_circuited()
        .expect("fold_while should've short-circuited");
        assert_eq!(stop, STOP_AT);
        assert_eq!(acc, sum);
    }

    // Infallible counterpart of `test_try_fold_while_stream_ended`.
    #[fasync::run_singlethreaded(test)]
    async fn test_fold_while_stream_ended() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        let mut sum = 0u32;
        for i in 0..10 {
            sum += i;
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        std::mem::drop(sender);
        let result = fold_while::<_, _, (), _, _>(stream, 0, |acc, next| {
            future::ready(FoldWhile::Continue(next + acc))
        })
        .await
        .ended()
        .expect("fold_while should have seen the stream end");

        assert_eq!(result, sum);
    }

    // `into_inner` returns the payload of whichever variant is present.
    #[test]
    fn test_fold_result_into_inner() {
        let x = FoldResult::<u32, u32>::StreamEnded(1);
        let y = FoldResult::<u32, u32>::ShortCircuited(2);
        assert_eq!(x.into_inner(), 1);
        assert_eq!(y.into_inner(), 2);
    }

    // `ended` and `short_circuited` map the two variants to opposite sides
    // of a `Result`.
    #[test]
    fn test_fold_result_mapping() {
        type FoldResult = super::FoldResult<u32, bool>;
        assert_eq!(FoldResult::StreamEnded(1).ended(), Ok(1));
        assert_eq!(FoldResult::ShortCircuited(false).ended(), Err(false));

        assert_eq!(FoldResult::StreamEnded(2).short_circuited(), Err(2));
        assert_eq!(FoldResult::ShortCircuited(true).short_circuited(), Ok(true));
    }
}