omaha_client/
async_generator.rs

1// Copyright 2020 The Fuchsia Authors
2//
3// Licensed under a BSD-style license <LICENSE-BSD>, Apache License, Version 2.0
4// <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0>, or the MIT
5// license <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your option.
6// This file may not be copied, modified, or distributed except according to
7// those terms.
8
// Every public item in this module must carry documentation; a missing doc is a compile error.
#![deny(missing_docs)]
// NOTE(review): presumably allows the `let () = ...;` binding style used in the tests below;
// confirm before removing.
#![allow(clippy::let_unit_value)]
11
12//! Asynchronous generator-like functionality in stable Rust.
13
14use {
15    futures::{
16        channel::mpsc,
17        future::FusedFuture,
18        prelude::*,
19        stream::FusedStream,
20        task::{Context, Poll},
21    },
22    pin_project::pin_project,
23    std::pin::Pin,
24};
25
26/// Produces an asynchronous `Stream` of [`GeneratorState<I, R>`] by invoking the given closure
27/// with a handle that can be used to yield items.
28///
29/// The returned `Stream` will produce a GeneratorState::Yielded variant for all yielded items
30/// from the asynchronous task, followed by a single GeneratorState::Complete variant, which will
31/// always be present as the final element in the stream.
32pub fn generate<'a, I, R, C, F>(cb: C) -> Generator<F, I, R>
33where
34    C: FnOnce(Yield<I>) -> F,
35    F: Future<Output = R> + 'a,
36    I: Send + 'static,
37    R: Send + 'static,
38{
39    let (send, recv) = mpsc::channel(0);
40    Generator {
41        task: cb(Yield(send)).fuse(),
42        stream: recv,
43        res: None,
44    }
45}
46
47/// Control handle to yield items to the coroutine.
48pub struct Yield<I>(mpsc::Sender<I>);
49
50impl<I> Yield<I>
51where
52    I: Send + 'static,
53{
54    /// Yield a single item to the coroutine, waiting for it to receive the item.
55    pub fn yield_(&mut self, item: I) -> impl Future<Output = ()> + '_ {
56        // Ignore errors as Generator never drops the stream before the task.
57        self.0.send(item).map(|_| ())
58    }
59
60    /// Yield multiple items to the coroutine, waiting for it to receive all of them.
61    pub fn yield_all<S>(&mut self, items: S) -> impl Future<Output = ()> + '_
62    where
63        S: IntoIterator<Item = I>,
64        S::IntoIter: 'static,
65    {
66        let mut items = futures::stream::iter(items.into_iter().map(Ok));
67        async move {
68            let _ = self.0.send_all(&mut items).await;
69        }
70    }
71}
72
/// Emitted state from an async generator.
///
/// A [`Generator`] stream produces zero or more `Yielded` states. Once its task finishes, it
/// produces a single trailing `Complete` state.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum GeneratorState<I, R> {
    /// The async generator yielded a value.
    Yielded(I),

    /// The async generator completed with a return value.
    Complete(R),
}
82
83impl<I, R> GeneratorState<I, R> {
84    fn into_yielded(self) -> Option<I> {
85        match self {
86            GeneratorState::Yielded(item) => Some(item),
87            _ => None,
88        }
89    }
90
91    fn into_complete(self) -> Option<R> {
92        match self {
93            GeneratorState::Complete(res) => Some(res),
94            _ => None,
95        }
96    }
97}
98
/// An asynchronous generator.
///
/// Created by [`generate`]. As a [`Stream`], it produces one [`GeneratorState::Yielded`] per item
/// sent through the task's [`Yield`] handle. After that it produces a single
/// [`GeneratorState::Complete`] carrying the task's output, and then ends.
#[pin_project]
#[derive(Debug)]
pub struct Generator<F, I, R>
where
    F: Future<Output = R>,
{
    /// The coroutine future, wrapped in `Fuse` so it is not polled again after completing.
    // Declared before `stream`: struct fields drop in declaration order, so the task (and any
    // `Yield` handle it still owns) is dropped before the receiver. This matches the assumption
    // noted in `Yield::yield_` that the stream is never dropped before the task.
    #[pin]
    task: future::Fuse<F>,
    /// Receiving end of the channel fed by the task's [`Yield`] handle.
    #[pin]
    stream: mpsc::Receiver<I>,
    /// The task's output, held here until every yielded item has been emitted. `poll_next` takes
    /// it when emitting `Complete`.
    res: Option<R>,
}
112
113impl<F, I, E> Generator<F, I, Result<(), E>>
114where
115    F: Future<Output = Result<(), E>>,
116{
117    /// Transforms this stream of `GeneratorState<I, Result<(), E>>` into a stream of `Result<I, E>`.
118    pub fn into_try_stream(self) -> impl FusedStream<Item = Result<I, E>> {
119        self.filter_map(|state| {
120            future::ready(match state {
121                GeneratorState::Yielded(i) => Some(Ok(i)),
122                GeneratorState::Complete(Ok(())) => None,
123                GeneratorState::Complete(Err(e)) => Some(Err(e)),
124            })
125        })
126    }
127}
128
129impl<F, I, R> Generator<F, I, R>
130where
131    F: Future<Output = R>,
132{
133    /// Discards all intermediate values produced by this generator, producing just the final result.
134    pub async fn into_complete(self) -> R {
135        let s = self.filter_map(|state| future::ready(state.into_complete()));
136        futures::pin_mut!(s);
137
138        // Generators always yield a complete item as the final element once the task
139        // completes.
140        s.next().await.unwrap()
141    }
142}
143
144impl<F, I> Generator<F, I, ()>
145where
146    F: Future<Output = ()>,
147{
148    /// Filters the states produced by this generator to only include intermediate yielded values,
149    /// discarding the final result.
150    pub fn into_yielded(self) -> impl FusedStream<Item = I> {
151        self.filter_map(|state| future::ready(state.into_yielded()))
152    }
153}
154
/// Emits every item sent through the task's [`Yield`] handle as [`GeneratorState::Yielded`].
/// Once the task has finished and the channel has been drained, it emits the task's output as a
/// single [`GeneratorState::Complete`], and then `None`.
impl<F, I, R> Stream for Generator<F, I, R>
where
    F: Future<Output = R>,
{
    type Item = GeneratorState<I, R>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        // Always poll the task first to make forward progress and maybe push an item into the
        // channel.
        // Completion from an earlier call is read from `is_terminated` before polling, because a
        // finished `Fuse` is presumably not expected to hand its output out a second time.
        let mut task_done = this.task.is_terminated();
        if let Poll::Ready(res) = this.task.poll(cx) {
            // This stream might not be ready for the final result yet, so store it for later.
            this.res.replace(res);
            task_done = true;
        }

        // Return anything available from the stream. Stream termination is not reported here;
        // it falls through so that the task's completion can produce the last value.
        if !this.stream.is_terminated() {
            match this.stream.poll_next(cx) {
                // No item yet and the channel is still open: wait for the next item.
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(item)) => return Poll::Ready(Some(GeneratorState::Yielded(item))),
                // Channel closed and drained: fall through to the task's result.
                Poll::Ready(None) => {}
            }
        }

        // The channel is done but the task is still running (e.g. it dropped its `Yield` handle
        // early). It was polled above, so per the `Future` contract it will wake this stream when
        // it can make progress.
        if !task_done {
            return Poll::Pending;
        }

        // Emit the final result once the task has finished and the channel is drained. `take`
        // ensures `Complete` is produced only once; every later poll returns `None`.
        match this.res.take() {
            Some(res) => Poll::Ready(Some(GeneratorState::Complete(res))),
            None => Poll::Ready(None),
        }
    }
}
194
195impl<F, I, R> FusedStream for Generator<F, I, R>
196where
197    F: Future<Output = R>,
198{
199    fn is_terminated(&self) -> bool {
200        self.task.is_terminated() && self.stream.is_terminated() && self.res.is_none()
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use std::sync::atomic;

    /// Returns a future that yields to the executor once before completing.
    ///
    /// On the first poll it wakes its own waker and returns `Pending`, so it is polled again. The
    /// second poll completes.
    fn yield_once() -> impl Future<Output = ()> {
        let mut done = false;
        future::poll_fn(move |cx: &mut Context<'_>| {
            if !done {
                done = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            } else {
                Poll::Ready(())
            }
        })
    }

    /// Shared event counter. `take` reads and resets it, so each assertion sees only the events
    /// recorded since the previous check.
    #[derive(Debug, Default)]
    struct Counter(atomic::AtomicU32);

    impl Counter {
        /// Records one event.
        fn inc(&self) {
            self.0.fetch_add(1, atomic::Ordering::SeqCst);
        }

        /// Returns the number of events since the last `take` and resets the count to zero.
        fn take(&self) -> u32 {
            self.0.swap(0, atomic::Ordering::SeqCst)
        }
    }

    /// The task only runs as far as its next `yield_` before the stream produces an item, so the
    /// task and the consumer advance in lockstep.
    #[test]
    fn generator_waits_for_item_to_yield() {
        let counter = Counter::default();

        let s = generate(|mut co| {
            let counter = &counter;
            async move {
                counter.inc();
                co.yield_("first").await;

                // This yield should not be observable by the stream, but the extra increment will
                // be.
                counter.inc();
                yield_once().await;

                counter.inc();
                co.yield_("second").await;

                // Close the channel before the task finishes. The stream must then wait for the
                // task itself, not just the channel, before emitting `Complete`.
                drop(co);
                yield_once().await;

                counter.inc();
            }
        });

        block_on(async {
            futures::pin_mut!(s);

            // `generate` has built the task but nothing has polled it yet.
            assert_eq!(counter.take(), 0);

            assert_eq!(s.next().await, Some(GeneratorState::Yielded("first")));
            assert_eq!(counter.take(), 1);

            assert_eq!(s.next().await, Some(GeneratorState::Yielded("second")));
            assert_eq!(counter.take(), 2);

            assert_eq!(s.next().await, Some(GeneratorState::Complete(())));
            assert_eq!(counter.take(), 1);

            // After `Complete`, the stream ends and no further task code runs.
            assert_eq!(s.next().await, None);
            assert_eq!(counter.take(), 0);
        });
    }

    /// Items from `yield_all` and `yield_` arrive in order, followed by `Complete`.
    #[test]
    fn yield_all_yields_all() {
        let s = generate(|mut co| async move {
            co.yield_all(1u32..4).await;
            co.yield_(42).await;
        });

        let res = block_on(s.collect::<Vec<GeneratorState<u32, ()>>>());

        assert_eq!(
            res,
            vec![
                GeneratorState::Yielded(1),
                GeneratorState::Yielded(2),
                GeneratorState::Yielded(3),
                GeneratorState::Yielded(42),
                GeneratorState::Complete(()),
            ]
        );
    }

    /// `is_terminated` stays false until the final `Complete` has been emitted.
    #[test]
    fn fused_impl() {
        let s = generate(|mut co| async move {
            co.yield_(1u32).await;
            drop(co);

            yield_once().await;

            "done"
        });

        block_on(async {
            futures::pin_mut!(s);

            assert!(!s.is_terminated());
            assert_eq!(s.next().await, Some(GeneratorState::Yielded(1)));

            assert!(!s.is_terminated());
            assert_eq!(s.next().await, Some(GeneratorState::Complete("done")));

            // FusedStream's is_terminated typically returns false after yielding None to indicate
            // no items are left, but it is also valid to return true when the stream is going to
            // not make further progress.
            assert!(s.is_terminated());
            assert_eq!(s.next().await, None);

            assert!(s.is_terminated());
        });
    }

    /// A trailing `Err` from the task becomes the last element of the try-stream.
    #[test]
    fn into_try_stream_transposes_generator_states() {
        let s = generate(|mut co| async move {
            co.yield_(1u8).await;
            co.yield_(2u8).await;

            Result::<(), &'static str>::Err("oops")
        })
        .into_try_stream();

        let res = block_on(s.collect::<Vec<Result<u8, &'static str>>>());

        assert_eq!(res, vec![Ok(1), Ok(2), Err("oops")]);
    }

    /// A trailing `Ok(())` from the task produces no element.
    #[test]
    fn into_try_stream_eats_unit_success() {
        let s = generate(|mut co| async move {
            co.yield_(1u8).await;
            co.yield_(2u8).await;

            Result::<(), &'static str>::Ok(())
        })
        .into_try_stream();

        let res = block_on(s.collect::<Vec<Result<u8, &'static str>>>());

        assert_eq!(res, vec![Ok(1), Ok(2)]);
    }

    /// Each consuming combinator keeps polling the task after its channel has closed, so the code
    /// after the last yield still runs exactly once.
    #[test]
    fn runs_task_to_completion() {
        let finished = Counter::default();

        let make_s = || {
            generate(|mut co| async {
                co.yield_(8u8).await;

                // Try really hard to cause this task to be dropped without completing.
                drop(co);
                yield_once().await;

                finished.inc();
            })
        };

        // No matter which combinator is used.

        block_on(async {
            let res = make_s().collect::<Vec<GeneratorState<u8, ()>>>().await;
            assert_eq!(
                res,
                vec![GeneratorState::Yielded(8), GeneratorState::Complete(())]
            );
            assert_eq!(finished.take(), 1);
        });

        block_on(async {
            assert_eq!(make_s().into_yielded().collect::<Vec<_>>().await, vec![8]);
            assert_eq!(finished.take(), 1);
        });

        block_on(async {
            let () = make_s().into_complete().await;
            assert_eq!(finished.take(), 1);
        });
    }

    /// An infinite generator. Its task never completes; the consumer simply stops after `take(10)`.
    #[test]
    fn fibonacci() {
        let fib = generate(|mut co| async move {
            let (mut a, mut b) = (0u32, 1u32);
            loop {
                co.yield_(a).await;

                let n = b;
                b += a;
                a = n;
            }
        })
        .into_yielded()
        .take(10)
        .collect::<Vec<_>>();

        assert_eq!(block_on(fib), vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
    }
}