1#![deny(missing_docs)]
6#![allow(clippy::let_unit_value)]
7
8use futures::channel::mpsc;
11use futures::future::FusedFuture;
12use futures::prelude::*;
13use futures::stream::FusedStream;
14use futures::task::{Context, Poll};
15use pin_project::pin_project;
16use std::pin::Pin;
17
18pub fn generate<'a, I, R, C, F>(cb: C) -> Generator<F, I, R>
25where
26 C: FnOnce(Yield<I>) -> F,
27 F: Future<Output = R> + 'a,
28 I: Send + 'static,
29 R: Send + 'static,
30{
31 let (send, recv) = mpsc::channel(0);
32 Generator { task: cb(Yield(send)).fuse(), stream: recv, res: None }
33}
34
/// Handle given to a generator body for emitting items to the consumer.
///
/// Wraps the sending half of the generator's item channel. Dropping the handle
/// lets the consumer observe the end of the yielded items; the generator then
/// waits only for the body to finish.
pub struct Yield<I>(mpsc::Sender<I>);
37
38impl<I> Yield<I>
39where
40 I: Send + 'static,
41{
42 pub fn yield_(&mut self, item: I) -> impl Future<Output = ()> + '_ {
44 self.0.send(item).map(|_| ())
46 }
47
48 pub fn yield_all<S>(&mut self, items: S) -> impl Future<Output = ()> + '_
50 where
51 S: IntoIterator<Item = I>,
52 S::IntoIter: 'static,
53 {
54 let mut items = futures::stream::iter(items.into_iter().map(Ok));
55 async move {
56 let _ = self.0.send_all(&mut items).await;
57 }
58 }
59}
60
/// One event produced by polling a [`Generator`].
#[derive(Debug, PartialEq, Eq)]
pub enum GeneratorState<I, R> {
    /// An item the generator body sent through its [`Yield`] handle.
    Yielded(I),

    /// The generator body finished with this output. It is produced at most
    /// once, after the stream of yielded items has ended.
    Complete(R),
}
70
71impl<I, R> GeneratorState<I, R> {
72 fn into_yielded(self) -> Option<I> {
73 match self {
74 GeneratorState::Yielded(item) => Some(item),
75 _ => None,
76 }
77 }
78
79 fn into_complete(self) -> Option<R> {
80 match self {
81 GeneratorState::Complete(res) => Some(res),
82 _ => None,
83 }
84 }
85}
86
/// A [`Stream`] of [`GeneratorState`] values driven by a generator body.
///
/// Created by [`generate`]. Polling it drives the body future and reports every
/// item the body yields, followed by a single [`GeneratorState::Complete`]
/// carrying the body's output, after which the stream ends.
#[pin_project]
#[derive(Debug)]
pub struct Generator<F, I, R>
where
    F: Future<Output = R>,
{
    // The generator body, fused so that it is not polled again after completing.
    #[pin]
    task: future::Fuse<F>,
    // Receiving half of the channel the body's `Yield` handle sends into.
    #[pin]
    stream: mpsc::Receiver<I>,
    // The body's output, held until the item channel has been drained.
    res: Option<R>,
}
100
101impl<F, I, E> Generator<F, I, Result<(), E>>
102where
103 F: Future<Output = Result<(), E>>,
104{
105 pub fn into_try_stream(self) -> impl FusedStream<Item = Result<I, E>> {
107 self.filter_map(|state| {
108 future::ready(match state {
109 GeneratorState::Yielded(i) => Some(Ok(i)),
110 GeneratorState::Complete(Ok(())) => None,
111 GeneratorState::Complete(Err(e)) => Some(Err(e)),
112 })
113 })
114 }
115}
116
117impl<F, I, R> Generator<F, I, R>
118where
119 F: Future<Output = R>,
120{
121 pub async fn into_complete(self) -> R {
123 let s = self.filter_map(|state| future::ready(state.into_complete()));
124 futures::pin_mut!(s);
125
126 s.next().await.unwrap()
129 }
130}
131
132impl<F, I> Generator<F, I, ()>
133where
134 F: Future<Output = ()>,
135{
136 pub fn into_yielded(self) -> impl FusedStream<Item = I> {
139 self.filter_map(|state| future::ready(state.into_yielded()))
140 }
141}
142
143impl<F, I, R> Stream for Generator<F, I, R>
144where
145 F: Future<Output = R>,
146{
147 type Item = GeneratorState<I, R>;
148
149 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150 let this = self.project();
151
152 let mut task_done = this.task.is_terminated();
155 if let Poll::Ready(res) = this.task.poll(cx) {
156 this.res.replace(res);
158 task_done = true;
159 }
160
161 if !this.stream.is_terminated() {
164 match this.stream.poll_next(cx) {
165 Poll::Pending => return Poll::Pending,
166 Poll::Ready(Some(item)) => return Poll::Ready(Some(GeneratorState::Yielded(item))),
167 Poll::Ready(None) => {}
168 }
169 }
170
171 if !task_done {
172 return Poll::Pending;
173 }
174
175 match this.res.take() {
177 Some(res) => Poll::Ready(Some(GeneratorState::Complete(res))),
178 None => Poll::Ready(None),
179 }
180 }
181}
182
183impl<F, I, R> FusedStream for Generator<F, I, R>
184where
185 F: Future<Output = R>,
186{
187 fn is_terminated(&self) -> bool {
188 self.task.is_terminated() && self.stream.is_terminated() && self.res.is_none()
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use std::sync::atomic;

    /// Returns a future that is `Pending` on its first poll (waking itself
    /// immediately) and `Ready` on every later poll. It forces one extra trip
    /// through the executor.
    fn yield_once() -> impl Future<Output = ()> {
        let mut done = false;
        future::poll_fn(move |cx: &mut Context<'_>| {
            if !done {
                done = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            } else {
                Poll::Ready(())
            }
        })
    }

    /// Atomic counter used to observe how far a generator body has run.
    #[derive(Debug, Default)]
    struct Counter(atomic::AtomicU32);

    impl Counter {
        /// Increments the count by one.
        fn inc(&self) {
            self.0.fetch_add(1, atomic::Ordering::SeqCst);
        }

        /// Returns the current count and resets it to zero.
        fn take(&self) -> u32 {
            self.0.swap(0, atomic::Ordering::SeqCst)
        }
    }

    // The body should only advance as far as the consumer demands. Each
    // `counter.take()` checks how many increments ran since the previous check.
    #[test]
    fn generator_waits_for_item_to_yield() {
        let counter = Counter::default();

        let s = generate(|mut co| {
            let counter = &counter;
            async move {
                counter.inc();
                co.yield_("first").await;

                counter.inc();
                // Suspend once without yielding an item.
                yield_once().await;

                counter.inc();
                co.yield_("second").await;

                // Close the item channel before the body finishes.
                drop(co);
                yield_once().await;

                counter.inc();
            }
        });

        block_on(async {
            futures::pin_mut!(s);

            // Nothing in the body has run before the first poll.
            assert_eq!(counter.take(), 0);

            assert_eq!(s.next().await, Some(GeneratorState::Yielded("first")));
            assert_eq!(counter.take(), 1);

            assert_eq!(s.next().await, Some(GeneratorState::Yielded("second")));
            assert_eq!(counter.take(), 2);

            assert_eq!(s.next().await, Some(GeneratorState::Complete(())));
            assert_eq!(counter.take(), 1);

            assert_eq!(s.next().await, None);
            assert_eq!(counter.take(), 0);
        });
    }

    // Items from `yield_all` arrive in order, followed by the later `yield_`
    // and then the completion.
    #[test]
    fn yield_all_yields_all() {
        let s = generate(|mut co| async move {
            co.yield_all(1u32..4).await;
            co.yield_(42).await;
        });

        let res = block_on(s.collect::<Vec<GeneratorState<u32, ()>>>());

        assert_eq!(
            res,
            vec![
                GeneratorState::Yielded(1),
                GeneratorState::Yielded(2),
                GeneratorState::Yielded(3),
                GeneratorState::Yielded(42),
                GeneratorState::Complete(()),
            ]
        );
    }

    // `is_terminated` stays false until the completion value has been
    // returned, and stays true afterwards.
    #[test]
    fn fused_impl() {
        let s = generate(|mut co| async move {
            co.yield_(1u32).await;
            drop(co);

            yield_once().await;

            "done"
        });

        block_on(async {
            futures::pin_mut!(s);

            assert!(!s.is_terminated());
            assert_eq!(s.next().await, Some(GeneratorState::Yielded(1)));

            assert!(!s.is_terminated());
            assert_eq!(s.next().await, Some(GeneratorState::Complete("done")));

            assert!(s.is_terminated());
            assert_eq!(s.next().await, None);

            assert!(s.is_terminated());
        });
    }

    // An `Err` completion shows up as a final `Err` element.
    #[test]
    fn into_try_stream_transposes_generator_states() {
        let s = generate(|mut co| async move {
            co.yield_(1u8).await;
            co.yield_(2u8).await;

            Result::<(), &'static str>::Err("oops")
        })
        .into_try_stream();

        let res = block_on(s.collect::<Vec<Result<u8, &'static str>>>());

        assert_eq!(res, vec![Ok(1), Ok(2), Err("oops")]);
    }

    // An `Ok(())` completion produces no element of its own.
    #[test]
    fn into_try_stream_eats_unit_success() {
        let s = generate(|mut co| async move {
            co.yield_(1u8).await;
            co.yield_(2u8).await;

            Result::<(), &'static str>::Ok(())
        })
        .into_try_stream();

        let res = block_on(s.collect::<Vec<Result<u8, &'static str>>>());

        assert_eq!(res, vec![Ok(1), Ok(2)]);
    }

    // Each way of consuming a generator (raw stream, `into_yielded`,
    // `into_complete`) must drive the body past its last yield so that
    // `finished.inc()` runs.
    #[test]
    fn runs_task_to_completion() {
        let finished = Counter::default();

        let make_s = || {
            generate(|mut co| async {
                co.yield_(8u8).await;

                drop(co);
                yield_once().await;

                finished.inc();
            })
        };

        block_on(async {
            let res = make_s().collect::<Vec<GeneratorState<u8, ()>>>().await;
            assert_eq!(res, vec![GeneratorState::Yielded(8), GeneratorState::Complete(())]);
            assert_eq!(finished.take(), 1);
        });

        block_on(async {
            assert_eq!(make_s().into_yielded().collect::<Vec<_>>().await, vec![8]);
            assert_eq!(finished.take(), 1);
        });

        block_on(async {
            let () = make_s().into_complete().await;
            assert_eq!(finished.take(), 1);
        });
    }

    // An endless body works as long as the consumer stops pulling (`take(10)`).
    #[test]
    fn fibonacci() {
        let fib = generate(|mut co| async move {
            let (mut a, mut b) = (0u32, 1u32);
            loop {
                co.yield_(a).await;

                let n = b;
                b += a;
                a = n;
            }
        })
        .into_yielded()
        .take(10)
        .collect::<Vec<_>>();

        assert_eq!(block_on(fib), vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
    }
}