async_utils/
fold.rs

// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Provides utilities to fold [`Stream`]s and [`TryStream`]s while allowing
//! the fold to be short-circuited with a result.

8use futures::{Future, FutureExt, Stream, StreamExt, TryStream, TryStreamExt};
9
/// Controls folding behavior.
///
/// Returned by the closure given to [`fold_while`] and [`try_fold_while`] to
/// decide, after each item, whether folding continues or stops early.
// Derives match those of `FoldResult` so step values can be compared and
// copied the same way fold results can.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum FoldWhile<C, D> {
    /// Continue folding with state `C`.
    Continue(C),
    /// Short-circuit folding with result `D`.
    Done(D),
}
18
/// The result of folding a stream.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum FoldResult<F, R> {
    /// The stream ended with folded state `F`.
    StreamEnded(F),
    /// The stream was short-circuited with result `R`.
    ShortCircuited(R),
}

impl<F, R> FoldResult<F, R> {
    /// Transforms into [`Result`] mapping the [`FoldResult::StreamEnded`]
    /// variant into `Ok`.
    ///
    /// # Errors
    ///
    /// Returns `Err` carrying the short-circuit value if this is
    /// [`FoldResult::ShortCircuited`].
    pub fn ended(self) -> Result<F, R> {
        match self {
            FoldResult::StreamEnded(folded) => Ok(folded),
            FoldResult::ShortCircuited(result) => Err(result),
        }
    }

    /// Transforms into [`Result`] mapping the [`FoldResult::ShortCircuited`]
    /// variant into `Ok`.
    ///
    /// # Errors
    ///
    /// Returns `Err` carrying the final folded state if this is
    /// [`FoldResult::StreamEnded`].
    pub fn short_circuited(self) -> Result<R, F> {
        match self {
            FoldResult::StreamEnded(folded) => Err(folded),
            FoldResult::ShortCircuited(result) => Ok(result),
        }
    }
}

impl<F> FoldResult<F, F> {
    /// Unwraps this [`FoldResult`] into its inner value, discarding the variant
    /// information.
    ///
    /// Only available when the folded state and the short-circuit result share
    /// the same type.
    pub fn into_inner(self) -> F {
        match self {
            FoldResult::StreamEnded(value) | FoldResult::ShortCircuited(value) => value,
        }
    }
}
57
58/// Similar to [`TryStreamExt::try_fold`], but the closure `f` can short-circuit
59/// the operation by returning [`FoldWhile::Done`].
60///
61/// Returns [`FoldResult::StreamEnded`] with the current folded value when the
62/// stream ends. Returns [`FoldResult::ShortCircuited`] with the value of
63/// [`FoldWhile::Done`] if `f` short-circuits the operation.
64/// Returns `Err` if either `s` or `f` returns an error.
65pub fn try_fold_while<S, T, D, F, Fut>(
66    s: S,
67    init: T,
68    mut f: F,
69) -> impl Future<Output = Result<FoldResult<T, D>, S::Error>>
70where
71    S: TryStream,
72    F: FnMut(T, S::Ok) -> Fut,
73    Fut: Future<Output = Result<FoldWhile<T, D>, S::Error>>,
74{
75    s.map_err(Err)
76        .try_fold(init, move |acc, n| {
77            f(acc, n).map(|r| match r {
78                Ok(FoldWhile::Continue(r)) => Ok(r),
79                Ok(FoldWhile::Done(d)) => Err(Ok(d)),
80                Err(e) => Err(Err(e)),
81            })
82        })
83        .map(|r| match r {
84            Ok(n) => Ok(FoldResult::StreamEnded(n)),
85            Err(Ok(n)) => Ok(FoldResult::ShortCircuited(n)),
86            Err(Err(e)) => Err(e),
87        })
88}
89
90/// Similar to [`StreamExt::fold`], but the closure `f` can short-circuit
91/// the operation by returning [`FoldWhile::Done`].
92///
93/// Returns [`FoldResult::StreamEnded`] with the current folded value when the
94/// stream ends. Returns [`FoldResult::ShortCircuited`] with the value of
95/// [`FoldWhile::Done`] if `f` short-circuits the operation.
96pub fn fold_while<S, T, D, F, Fut>(
97    s: S,
98    init: T,
99    mut f: F,
100) -> impl Future<Output = FoldResult<T, D>>
101where
102    S: Stream,
103    F: FnMut(T, S::Item) -> Fut,
104    Fut: Future<Output = FoldWhile<T, D>>,
105{
106    s.map(Ok)
107        .try_fold(init, move |acc, n| {
108            f(acc, n).map(|r| match r {
109                FoldWhile::Continue(r) => Ok(r),
110                FoldWhile::Done(d) => Err(d),
111            })
112        })
113        .map(|r| match r {
114            Ok(n) => FoldResult::StreamEnded(n),
115            Err(n) => FoldResult::ShortCircuited(n),
116        })
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_async as fasync;
    use futures::channel::mpsc;
    use futures::future;

    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_short_circuit() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        const STOP_AT: u32 = 5;
        let mut sum = 0;
        for i in 0..10 {
            if i < STOP_AT {
                sum += i;
            }
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        let (acc, stop) = try_fold_while(stream.map(Result::<_, ()>::Ok), 0, |acc, next| {
            future::ok(if next == STOP_AT {
                FoldWhile::Done((acc, next))
            } else {
                FoldWhile::Continue(next + acc)
            })
        })
        .await
        .expect("try_fold_while failed")
        .short_circuited()
        .expect("try_fold_while should've short-circuited");
        assert_eq!(stop, STOP_AT);
        assert_eq!(acc, sum);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_stream_ended() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        let mut sum = 0u32;
        for i in 0..10 {
            sum += i;
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        std::mem::drop(sender);
        let result =
            try_fold_while::<_, _, (), _, _>(stream.map(Result::<_, ()>::Ok), 0, |acc, next| {
                future::ok(FoldWhile::Continue(next + acc))
            })
            .await
            .expect("try_fold_while failed")
            .ended()
            .expect("try_fold_while should have seen the stream end");

        assert_eq!(result, sum);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_stream_error() {
        #[derive(Debug)]
        struct StreamErr;
        let (sender, stream) = mpsc::unbounded::<Result<u32, StreamErr>>();
        let () = sender.unbounded_send(Err(StreamErr {})).expect("failed to send item");
        let StreamErr {} = try_fold_while::<_, _, (), _, _>(stream, (), |(), _: u32| async {
            panic!("shouldn't receive error input");
        })
        .await
        .expect_err("try_fold_while should return error");
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_closure_error() {
        #[derive(Debug)]
        struct StreamErr {
            item: u32,
        }
        const ERROR_ITEM: u32 = 1234;
        let (sender, stream) = mpsc::unbounded::<Result<u32, StreamErr>>();
        let () = sender.unbounded_send(Ok(ERROR_ITEM)).expect("failed to send item");
        let StreamErr { item } = try_fold_while::<_, _, (), _, _>(stream, (), |(), item| {
            future::err(StreamErr { item })
        })
        .await
        .expect_err("try_fold_while should return error");
        assert_eq!(item, ERROR_ITEM);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_empty_stream() {
        // An empty stream never invokes the closure, so the initial state must
        // come back untouched as `StreamEnded`.
        const INIT: u32 = 42;
        let result = try_fold_while::<_, _, (), _, _>(
            futures::stream::empty::<Result<u32, ()>>(),
            INIT,
            |acc, next| future::ok(FoldWhile::Continue(acc + next)),
        )
        .await
        .expect("try_fold_while failed")
        .ended()
        .expect("try_fold_while should have seen the stream end");
        assert_eq!(result, INIT);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_fold_while_short_circuit() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        const STOP_AT: u32 = 5;
        let mut sum = 0;
        for i in 0..10 {
            if i < STOP_AT {
                sum += i;
            }
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        let (acc, stop) = fold_while(stream, 0, |acc, next| {
            future::ready(if next == STOP_AT {
                FoldWhile::Done((acc, next))
            } else {
                FoldWhile::Continue(next + acc)
            })
        })
        .await
        .short_circuited()
        .expect("fold_while should've short-circuited");
        assert_eq!(stop, STOP_AT);
        assert_eq!(acc, sum);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_fold_while_stream_ended() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        let mut sum = 0u32;
        for i in 0..10 {
            sum += i;
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        std::mem::drop(sender);
        let result = fold_while::<_, _, (), _, _>(stream, 0, |acc, next| {
            future::ready(FoldWhile::Continue(next + acc))
        })
        .await
        .ended()
        .expect("fold_while should have seen the stream end");

        assert_eq!(result, sum);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_fold_while_empty_stream() {
        // An empty stream never invokes the closure, so the initial state must
        // come back untouched as `StreamEnded`.
        const INIT: u32 = 42;
        let result = fold_while::<_, _, (), _, _>(
            futures::stream::empty::<u32>(),
            INIT,
            |acc, next| future::ready(FoldWhile::Continue(acc + next)),
        )
        .await
        .ended()
        .expect("fold_while should have seen the stream end");
        assert_eq!(result, INIT);
    }

    #[test]
    fn test_fold_result_into_inner() {
        let x = FoldResult::<u32, u32>::StreamEnded(1);
        let y = FoldResult::<u32, u32>::ShortCircuited(2);
        assert_eq!(x.into_inner(), 1);
        assert_eq!(y.into_inner(), 2);
    }

    #[test]
    fn test_fold_result_mapping() {
        type FoldResult = super::FoldResult<u32, bool>;
        assert_eq!(FoldResult::StreamEnded(1).ended(), Ok(1));
        assert_eq!(FoldResult::ShortCircuited(false).ended(), Err(false));

        assert_eq!(FoldResult::StreamEnded(2).short_circuited(), Err(2));
        assert_eq!(FoldResult::ShortCircuited(true).short_circuited(), Ok(true));
    }
}