//! Asynchronous generators built on top of `async`/`await`.
//!
//! Call [`generate`] with a closure that receives a [`Yield`] handle and
//! returns a future (the generator *body*). The body hands values to the
//! consumer with [`Yield::yield_`] / [`Yield::yield_all`], and its output
//! becomes the generator's final result.
//!
//! The returned [`Generator`] is a stream of [`GeneratorState`] values. Every
//! yielded item arrives as [`GeneratorState::Yielded`], followed by a single
//! [`GeneratorState::Complete`] carrying the body's output, after which the
//! stream ends.
#![deny(missing_docs)]
// NOTE(review): presumably allowed for unit-typed `let` bindings such as the
// tests' `let () = ...`; confirm the lint still fires before removing this.
#![allow(clippy::let_unit_value)]
11
12use {
15 futures::{
16 channel::mpsc,
17 future::FusedFuture,
18 prelude::*,
19 stream::FusedStream,
20 task::{Context, Poll},
21 },
22 pin_project::pin_project,
23 std::pin::Pin,
24};
25
26pub fn generate<'a, I, R, C, F>(cb: C) -> Generator<F, I, R>
33where
34 C: FnOnce(Yield<I>) -> F,
35 F: Future<Output = R> + 'a,
36 I: Send + 'static,
37 R: Send + 'static,
38{
39 let (send, recv) = mpsc::channel(0);
40 Generator {
41 task: cb(Yield(send)).fuse(),
42 stream: recv,
43 res: None,
44 }
45}
46
47pub struct Yield<I>(mpsc::Sender<I>);
49
50impl<I> Yield<I>
51where
52 I: Send + 'static,
53{
54 pub fn yield_(&mut self, item: I) -> impl Future<Output = ()> + '_ {
56 self.0.send(item).map(|_| ())
58 }
59
60 pub fn yield_all<S>(&mut self, items: S) -> impl Future<Output = ()> + '_
62 where
63 S: IntoIterator<Item = I>,
64 S::IntoIter: 'static,
65 {
66 let mut items = futures::stream::iter(items.into_iter().map(Ok));
67 async move {
68 let _ = self.0.send_all(&mut items).await;
69 }
70 }
71}
72
/// A value emitted by a [`Generator`].
#[derive(Debug, PartialEq, Eq)]
pub enum GeneratorState<I, R> {
    /// An item the generator body passed to [`Yield::yield_`] or
    /// [`Yield::yield_all`].
    Yielded(I),

    /// The output of the generator body. It is emitted once, after every
    /// yielded item has been delivered.
    Complete(R),
}
82
83impl<I, R> GeneratorState<I, R> {
84 fn into_yielded(self) -> Option<I> {
85 match self {
86 GeneratorState::Yielded(item) => Some(item),
87 _ => None,
88 }
89 }
90
91 fn into_complete(self) -> Option<R> {
92 match self {
93 GeneratorState::Complete(res) => Some(res),
94 _ => None,
95 }
96 }
97}
98
99#[pin_project]
101#[derive(Debug)]
102pub struct Generator<F, I, R>
103where
104 F: Future<Output = R>,
105{
106 #[pin]
107 task: future::Fuse<F>,
108 #[pin]
109 stream: mpsc::Receiver<I>,
110 res: Option<R>,
111}
112
113impl<F, I, E> Generator<F, I, Result<(), E>>
114where
115 F: Future<Output = Result<(), E>>,
116{
117 pub fn into_try_stream(self) -> impl FusedStream<Item = Result<I, E>> {
119 self.filter_map(|state| {
120 future::ready(match state {
121 GeneratorState::Yielded(i) => Some(Ok(i)),
122 GeneratorState::Complete(Ok(())) => None,
123 GeneratorState::Complete(Err(e)) => Some(Err(e)),
124 })
125 })
126 }
127}
128
129impl<F, I, R> Generator<F, I, R>
130where
131 F: Future<Output = R>,
132{
133 pub async fn into_complete(self) -> R {
135 let s = self.filter_map(|state| future::ready(state.into_complete()));
136 futures::pin_mut!(s);
137
138 s.next().await.unwrap()
141 }
142}
143
144impl<F, I> Generator<F, I, ()>
145where
146 F: Future<Output = ()>,
147{
148 pub fn into_yielded(self) -> impl FusedStream<Item = I> {
151 self.filter_map(|state| future::ready(state.into_yielded()))
152 }
153}
154
155impl<F, I, R> Stream for Generator<F, I, R>
156where
157 F: Future<Output = R>,
158{
159 type Item = GeneratorState<I, R>;
160
161 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
162 let this = self.project();
163
164 let mut task_done = this.task.is_terminated();
167 if let Poll::Ready(res) = this.task.poll(cx) {
168 this.res.replace(res);
170 task_done = true;
171 }
172
173 if !this.stream.is_terminated() {
176 match this.stream.poll_next(cx) {
177 Poll::Pending => return Poll::Pending,
178 Poll::Ready(Some(item)) => return Poll::Ready(Some(GeneratorState::Yielded(item))),
179 Poll::Ready(None) => {}
180 }
181 }
182
183 if !task_done {
184 return Poll::Pending;
185 }
186
187 match this.res.take() {
189 Some(res) => Poll::Ready(Some(GeneratorState::Complete(res))),
190 None => Poll::Ready(None),
191 }
192 }
193}
194
195impl<F, I, R> FusedStream for Generator<F, I, R>
196where
197 F: Future<Output = R>,
198{
199 fn is_terminated(&self) -> bool {
200 self.task.is_terminated() && self.stream.is_terminated() && self.res.is_none()
201 }
202}
203
// Unit tests: each drives a generator with `futures::executor::block_on` and
// checks the exact sequence of emitted states.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use std::sync::atomic;

    /// Returns a future that wakes its own task and returns `Pending` on the
    /// first poll, then `Ready(())` on the next poll: a bare suspension point.
    fn yield_once() -> impl Future<Output = ()> {
        let mut done = false;
        future::poll_fn(move |cx: &mut Context<'_>| {
            if !done {
                done = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            } else {
                Poll::Ready(())
            }
        })
    }

    /// Atomic counter used to observe how far a generator body has run.
    #[derive(Debug, Default)]
    struct Counter(atomic::AtomicU32);

    impl Counter {
        /// Adds one to the count.
        fn inc(&self) {
            self.0.fetch_add(1, atomic::Ordering::SeqCst);
        }

        /// Returns the current count and resets it to zero.
        fn take(&self) -> u32 {
            self.0.swap(0, atomic::Ordering::SeqCst)
        }
    }

    /// The body must not run past a yield until the consumer asks for the
    /// next element.
    #[test]
    fn generator_waits_for_item_to_yield() {
        let counter = Counter::default();

        let s = generate(|mut co| {
            let counter = &counter;
            async move {
                counter.inc();
                co.yield_("first").await;

                counter.inc();
                // Suspend without yielding, so the consumer's poll sees no
                // item yet and has to poll again.
                yield_once().await;

                counter.inc();
                co.yield_("second").await;

                // Close the item channel while the body still has work left;
                // `Complete` must wait for that work to finish.
                drop(co);
                yield_once().await;

                counter.inc();
            }
        });

        block_on(async {
            futures::pin_mut!(s);

            // Nothing in the body has run before the first poll.
            assert_eq!(counter.take(), 0);

            assert_eq!(s.next().await, Some(GeneratorState::Yielded("first")));
            assert_eq!(counter.take(), 1);

            assert_eq!(s.next().await, Some(GeneratorState::Yielded("second")));
            assert_eq!(counter.take(), 2);

            assert_eq!(s.next().await, Some(GeneratorState::Complete(())));
            assert_eq!(counter.take(), 1);

            assert_eq!(s.next().await, None);
            assert_eq!(counter.take(), 0);
        });
    }

    /// `yield_all` emits every item of the iterator, in order.
    #[test]
    fn yield_all_yields_all() {
        let s = generate(|mut co| async move {
            co.yield_all(1u32..4).await;
            co.yield_(42).await;
        });

        let res = block_on(s.collect::<Vec<GeneratorState<u32, ()>>>());

        assert_eq!(
            res,
            vec![
                GeneratorState::Yielded(1),
                GeneratorState::Yielded(2),
                GeneratorState::Yielded(3),
                GeneratorState::Yielded(42),
                GeneratorState::Complete(()),
            ]
        );
    }

    /// `is_terminated` flips to `true` only after `Complete` has been emitted.
    #[test]
    fn fused_impl() {
        let s = generate(|mut co| async move {
            co.yield_(1u32).await;
            drop(co);

            yield_once().await;

            "done"
        });

        block_on(async {
            futures::pin_mut!(s);

            assert!(!s.is_terminated());
            assert_eq!(s.next().await, Some(GeneratorState::Yielded(1)));

            assert!(!s.is_terminated());
            assert_eq!(s.next().await, Some(GeneratorState::Complete("done")));

            assert!(s.is_terminated());
            // Polling a terminated generator again still returns `None` and
            // leaves it terminated.
            assert_eq!(s.next().await, None);

            assert!(s.is_terminated());
        });
    }

    /// An `Err` result is appended after the `Ok` items.
    #[test]
    fn into_try_stream_transposes_generator_states() {
        let s = generate(|mut co| async move {
            co.yield_(1u8).await;
            co.yield_(2u8).await;

            Result::<(), &'static str>::Err("oops")
        })
        .into_try_stream();

        let res = block_on(s.collect::<Vec<Result<u8, &'static str>>>());

        assert_eq!(res, vec![Ok(1), Ok(2), Err("oops")]);
    }

    /// An `Ok(())` result adds no element to the stream.
    #[test]
    fn into_try_stream_eats_unit_success() {
        let s = generate(|mut co| async move {
            co.yield_(1u8).await;
            co.yield_(2u8).await;

            Result::<(), &'static str>::Ok(())
        })
        .into_try_stream();

        let res = block_on(s.collect::<Vec<Result<u8, &'static str>>>());

        assert_eq!(res, vec![Ok(1), Ok(2)]);
    }

    /// Every way of consuming a generator runs the body to its last statement.
    #[test]
    fn runs_task_to_completion() {
        let finished = Counter::default();

        let make_s = || {
            generate(|mut co| async {
                co.yield_(8u8).await;

                drop(co);
                // Work after the `Yield` handle is gone must still run.
                yield_once().await;

                finished.inc();
            })
        };

        block_on(async {
            // Consuming the raw `GeneratorState` stream reaches `finished.inc()`.
            let res = make_s().collect::<Vec<GeneratorState<u8, ()>>>().await;
            assert_eq!(
                res,
                vec![GeneratorState::Yielded(8), GeneratorState::Complete(())]
            );
            assert_eq!(finished.take(), 1);
        });

        block_on(async {
            assert_eq!(make_s().into_yielded().collect::<Vec<_>>().await, vec![8]);
            assert_eq!(finished.take(), 1);
        });

        block_on(async {
            let () = make_s().into_complete().await;
            assert_eq!(finished.take(), 1);
        });
    }

    /// An infinite generator can be consumed lazily with `take`.
    #[test]
    fn fibonacci() {
        let fib = generate(|mut co| async move {
            let (mut a, mut b) = (0u32, 1u32);
            loop {
                co.yield_(a).await;

                let n = b;
                b += a;
                a = n;
            }
        })
        .into_yielded()
        .take(10)
        .collect::<Vec<_>>();

        assert_eq!(block_on(fib), vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
    }
}
418}