async_generator/
lib.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![deny(missing_docs)]
6#![allow(clippy::let_unit_value)]
7
8//! Asynchronous generator-like functionality in stable Rust.
9
10use futures::channel::mpsc;
11use futures::future::FusedFuture;
12use futures::prelude::*;
13use futures::stream::FusedStream;
14use futures::task::{Context, Poll};
15use pin_project::pin_project;
16use std::pin::Pin;
17
18/// Produces an asynchronous `Stream` of [`GeneratorState<I, R>`] by invoking the given closure
19/// with a handle that can be used to yield items.
20///
21/// The returned `Stream` will produce a GeneratorState::Yielded variant for all yielded items
22/// from the asynchronous task, followed by a single GeneratorState::Complete variant, which will
23/// always be present as the final element in the stream.
24pub fn generate<'a, I, R, C, F>(cb: C) -> Generator<F, I, R>
25where
26    C: FnOnce(Yield<I>) -> F,
27    F: Future<Output = R> + 'a,
28    I: Send + 'static,
29    R: Send + 'static,
30{
31    let (send, recv) = mpsc::channel(0);
32    Generator { task: cb(Yield(send)).fuse(), stream: recv, res: None }
33}
34
35/// Control handle to yield items to the coroutine.
36pub struct Yield<I>(mpsc::Sender<I>);
37
38impl<I> Yield<I>
39where
40    I: Send + 'static,
41{
42    /// Yield a single item to the coroutine, waiting for it to receive the item.
43    pub fn yield_(&mut self, item: I) -> impl Future<Output = ()> + '_ {
44        // Ignore errors as Generator never drops the stream before the task.
45        self.0.send(item).map(|_| ())
46    }
47
48    /// Yield multiple items to the coroutine, waiting for it to receive all of them.
49    pub fn yield_all<S>(&mut self, items: S) -> impl Future<Output = ()> + '_
50    where
51        S: IntoIterator<Item = I>,
52        S::IntoIter: 'static,
53    {
54        let mut items = futures::stream::iter(items.into_iter().map(Ok));
55        async move {
56            let _ = self.0.send_all(&mut items).await;
57        }
58    }
59}
60
/// Emitted state from an async generator.
#[derive(Debug, PartialEq, Eq)]
pub enum GeneratorState<I, R> {
    /// The async generator yielded a value.
    Yielded(I),

    /// The async generator completed with a return value.
    Complete(R),
}

impl<I, R> GeneratorState<I, R> {
    /// Returns the item of a `Yielded` state, or `None` for a `Complete` state.
    fn into_yielded(self) -> Option<I> {
        // Both variants are spelled out so that adding a variant is a compile error here
        // instead of being silently swallowed by a wildcard arm.
        match self {
            GeneratorState::Yielded(item) => Some(item),
            GeneratorState::Complete(_) => None,
        }
    }

    /// Returns the result of a `Complete` state, or `None` for a `Yielded` state.
    fn into_complete(self) -> Option<R> {
        match self {
            GeneratorState::Complete(res) => Some(res),
            GeneratorState::Yielded(_) => None,
        }
    }
}
86
87/// An asynchronous generator.
88#[pin_project]
89#[derive(Debug)]
90pub struct Generator<F, I, R>
91where
92    F: Future<Output = R>,
93{
94    #[pin]
95    task: future::Fuse<F>,
96    #[pin]
97    stream: mpsc::Receiver<I>,
98    res: Option<R>,
99}
100
101impl<F, I, E> Generator<F, I, Result<(), E>>
102where
103    F: Future<Output = Result<(), E>>,
104{
105    /// Transforms this stream of `GeneratorState<I, Result<(), E>>` into a stream of `Result<I, E>`.
106    pub fn into_try_stream(self) -> impl FusedStream<Item = Result<I, E>> {
107        self.filter_map(|state| {
108            future::ready(match state {
109                GeneratorState::Yielded(i) => Some(Ok(i)),
110                GeneratorState::Complete(Ok(())) => None,
111                GeneratorState::Complete(Err(e)) => Some(Err(e)),
112            })
113        })
114    }
115}
116
117impl<F, I, R> Generator<F, I, R>
118where
119    F: Future<Output = R>,
120{
121    /// Discards all intermediate values produced by this generator, producing just the final result.
122    pub async fn into_complete(self) -> R {
123        let s = self.filter_map(|state| future::ready(state.into_complete()));
124        futures::pin_mut!(s);
125
126        // Generators always yield a complete item as the final element once the task
127        // completes.
128        s.next().await.unwrap()
129    }
130}
131
132impl<F, I> Generator<F, I, ()>
133where
134    F: Future<Output = ()>,
135{
136    /// Filters the states produced by this generator to only include intermediate yielded values,
137    /// discarding the final result.
138    pub fn into_yielded(self) -> impl FusedStream<Item = I> {
139        self.filter_map(|state| future::ready(state.into_yielded()))
140    }
141}
142
143impl<F, I, R> Stream for Generator<F, I, R>
144where
145    F: Future<Output = R>,
146{
147    type Item = GeneratorState<I, R>;
148
149    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150        let this = self.project();
151
152        // Always poll the task first to make forward progress and maybe push an item into the
153        // channel.
154        let mut task_done = this.task.is_terminated();
155        if let Poll::Ready(res) = this.task.poll(cx) {
156            // This stream might not be ready for the final result yet, store it for later.
157            this.res.replace(res);
158            task_done = true;
159        }
160
161        // Return anything available from the stream, ignoring stream termination to let the task
162        // termination yield the last value.
163        if !this.stream.is_terminated() {
164            match this.stream.poll_next(cx) {
165                Poll::Pending => return Poll::Pending,
166                Poll::Ready(Some(item)) => return Poll::Ready(Some(GeneratorState::Yielded(item))),
167                Poll::Ready(None) => {}
168            }
169        }
170
171        if !task_done {
172            return Poll::Pending;
173        }
174
175        // Flush the final result once all tasks are done.
176        match this.res.take() {
177            Some(res) => Poll::Ready(Some(GeneratorState::Complete(res))),
178            None => Poll::Ready(None),
179        }
180    }
181}
182
183impl<F, I, R> FusedStream for Generator<F, I, R>
184where
185    F: Future<Output = R>,
186{
187    fn is_terminated(&self) -> bool {
188        self.task.is_terminated() && self.stream.is_terminated() && self.res.is_none()
189    }
190}
191
192#[cfg(test)]
193mod tests {
194    use super::*;
195    use futures::executor::block_on;
196    use std::sync::atomic;
197
198    /// Returns a future that yields to the executor once before completing.
199    fn yield_once() -> impl Future<Output = ()> {
200        let mut done = false;
201        future::poll_fn(move |cx: &mut Context<'_>| {
202            if !done {
203                done = true;
204                cx.waker().wake_by_ref();
205                Poll::Pending
206            } else {
207                Poll::Ready(())
208            }
209        })
210    }
211
212    #[derive(Debug, Default)]
213    struct Counter(atomic::AtomicU32);
214
215    impl Counter {
216        fn inc(&self) {
217            self.0.fetch_add(1, atomic::Ordering::SeqCst);
218        }
219
220        fn take(&self) -> u32 {
221            self.0.swap(0, atomic::Ordering::SeqCst)
222        }
223    }
224
225    #[test]
226    fn generator_waits_for_item_to_yield() {
227        let counter = Counter::default();
228
229        let s = generate(|mut co| {
230            let counter = &counter;
231            async move {
232                counter.inc();
233                co.yield_("first").await;
234
235                // This yield should not be observable by the stream, but the extra increment will
236                // be.
237                counter.inc();
238                yield_once().await;
239
240                counter.inc();
241                co.yield_("second").await;
242
243                drop(co);
244                yield_once().await;
245
246                counter.inc();
247            }
248        });
249
250        block_on(async {
251            futures::pin_mut!(s);
252
253            assert_eq!(counter.take(), 0);
254
255            assert_eq!(s.next().await, Some(GeneratorState::Yielded("first")));
256            assert_eq!(counter.take(), 1);
257
258            assert_eq!(s.next().await, Some(GeneratorState::Yielded("second")));
259            assert_eq!(counter.take(), 2);
260
261            assert_eq!(s.next().await, Some(GeneratorState::Complete(())));
262            assert_eq!(counter.take(), 1);
263
264            assert_eq!(s.next().await, None);
265            assert_eq!(counter.take(), 0);
266        });
267    }
268
269    #[test]
270    fn yield_all_yields_all() {
271        let s = generate(|mut co| async move {
272            co.yield_all(1u32..4).await;
273            co.yield_(42).await;
274        });
275
276        let res = block_on(s.collect::<Vec<GeneratorState<u32, ()>>>());
277
278        assert_eq!(
279            res,
280            vec![
281                GeneratorState::Yielded(1),
282                GeneratorState::Yielded(2),
283                GeneratorState::Yielded(3),
284                GeneratorState::Yielded(42),
285                GeneratorState::Complete(()),
286            ]
287        );
288    }
289
290    #[test]
291    fn fused_impl() {
292        let s = generate(|mut co| async move {
293            co.yield_(1u32).await;
294            drop(co);
295
296            yield_once().await;
297
298            "done"
299        });
300
301        block_on(async {
302            futures::pin_mut!(s);
303
304            assert!(!s.is_terminated());
305            assert_eq!(s.next().await, Some(GeneratorState::Yielded(1)));
306
307            assert!(!s.is_terminated());
308            assert_eq!(s.next().await, Some(GeneratorState::Complete("done")));
309
310            // FusedStream's is_terminated typically returns false after yielding None to indicate
311            // no items are left, but it is also valid to return true when the stream is going to
312            // not make further progress.
313            assert!(s.is_terminated());
314            assert_eq!(s.next().await, None);
315
316            assert!(s.is_terminated());
317        });
318    }
319
320    #[test]
321    fn into_try_stream_transposes_generator_states() {
322        let s = generate(|mut co| async move {
323            co.yield_(1u8).await;
324            co.yield_(2u8).await;
325
326            Result::<(), &'static str>::Err("oops")
327        })
328        .into_try_stream();
329
330        let res = block_on(s.collect::<Vec<Result<u8, &'static str>>>());
331
332        assert_eq!(res, vec![Ok(1), Ok(2), Err("oops")]);
333    }
334
335    #[test]
336    fn into_try_stream_eats_unit_success() {
337        let s = generate(|mut co| async move {
338            co.yield_(1u8).await;
339            co.yield_(2u8).await;
340
341            Result::<(), &'static str>::Ok(())
342        })
343        .into_try_stream();
344
345        let res = block_on(s.collect::<Vec<Result<u8, &'static str>>>());
346
347        assert_eq!(res, vec![Ok(1), Ok(2)]);
348    }
349
350    #[test]
351    fn runs_task_to_completion() {
352        let finished = Counter::default();
353
354        let make_s = || {
355            generate(|mut co| async {
356                co.yield_(8u8).await;
357
358                // Try really hard to cause this task to be dropped without completing.
359                drop(co);
360                yield_once().await;
361
362                finished.inc();
363            })
364        };
365
366        // No matter which combinator is used.
367
368        block_on(async {
369            let res = make_s().collect::<Vec<GeneratorState<u8, ()>>>().await;
370            assert_eq!(res, vec![GeneratorState::Yielded(8), GeneratorState::Complete(())]);
371            assert_eq!(finished.take(), 1);
372        });
373
374        block_on(async {
375            assert_eq!(make_s().into_yielded().collect::<Vec<_>>().await, vec![8]);
376            assert_eq!(finished.take(), 1);
377        });
378
379        block_on(async {
380            let () = make_s().into_complete().await;
381            assert_eq!(finished.take(), 1);
382        });
383    }
384
385    #[test]
386    fn fibonacci() {
387        let fib = generate(|mut co| async move {
388            let (mut a, mut b) = (0u32, 1u32);
389            loop {
390                co.yield_(a).await;
391
392                let n = b;
393                b += a;
394                a = n;
395            }
396        })
397        .into_yielded()
398        .take(10)
399        .collect::<Vec<_>>();
400
401        assert_eq!(block_on(fib), vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
402    }
403}