#![deny(missing_docs)]
#![allow(clippy::let_unit_value)]
use futures::channel::mpsc;
use futures::future::FusedFuture;
use futures::prelude::*;
use futures::stream::FusedStream;
use futures::task::{Context, Poll};
use pin_project::pin_project;
use std::pin::Pin;
/// Creates a [`Generator`] whose items are produced by an async body.
///
/// `cb` is called right away with a [`Yield`] handle connected to the returned generator.
/// It must return the body future. That future is only driven while the generator is polled
/// as a [`Stream`]:
///
/// - every value sent through [`Yield::yield_`] or [`Yield::yield_all`] is emitted as
///   [`GeneratorState::Yielded`];
/// - the body's output is emitted last as [`GeneratorState::Complete`].
///
/// Items travel over a bounded channel created with a buffer of zero, so the body cannot run
/// ahead of its consumer. The body need not be `'static` and may borrow from the caller.
/// The item type `I` and the result type `R` must be `Send + 'static`.
pub fn generate<'a, I, R, C, F>(cb: C) -> Generator<F, I, R>
where
    C: FnOnce(Yield<I>) -> F,
    F: Future<Output = R> + 'a,
    I: Send + 'static,
    R: Send + 'static,
{
    let (send, recv) = mpsc::channel(0);
    Generator { task: cb(Yield(send)).fuse(), stream: recv, res: None }
}
/// Handle a generator body uses to send items to its [`Generator`].
///
/// Dropping the handle closes the item channel. The generator then only waits for the body
/// to return its final result.
pub struct Yield<I>(mpsc::Sender<I>);

impl<I> Yield<I>
where
    I: Send + 'static,
{
    /// Sends `item` to the generator's consumer.
    ///
    /// The returned future must be awaited for the item to be delivered. If the receiving side
    /// has already gone away, the item is silently discarded.
    pub fn yield_(&mut self, item: I) -> impl Future<Output = ()> + '_ {
        // Any send error is deliberately discarded; this method has no way to report it.
        self.0.send(item).map(|_| ())
    }

    /// Sends every item produced by `items`, in order.
    ///
    /// The returned future must be awaited for the items to be delivered. Errors from the
    /// underlying channel (e.g. the receiving side having gone away) are ignored.
    pub fn yield_all<S>(&mut self, items: S) -> impl Future<Output = ()> + '_
    where
        S: IntoIterator<Item = I>,
        S::IntoIter: 'static,
    {
        // `send_all` consumes a fallible stream, so each item is wrapped in `Ok`.
        let mut items = futures::stream::iter(items.into_iter().map(Ok));
        async move {
            let _ = self.0.send_all(&mut items).await;
        }
    }
}
/// A value emitted by a [`Generator`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GeneratorState<I, R> {
    /// An item the body sent through its [`Yield`] handle.
    Yielded(I),
    /// The body's output. It is emitted once, after every yielded item, and is never followed
    /// by more values.
    Complete(R),
}

impl<I, R> GeneratorState<I, R> {
    /// Returns the item if this is [`GeneratorState::Yielded`], otherwise `None`.
    fn into_yielded(self) -> Option<I> {
        match self {
            GeneratorState::Yielded(item) => Some(item),
            GeneratorState::Complete(_) => None,
        }
    }

    /// Returns the result if this is [`GeneratorState::Complete`], otherwise `None`.
    fn into_complete(self) -> Option<R> {
        match self {
            GeneratorState::Yielded(_) => None,
            GeneratorState::Complete(res) => Some(res),
        }
    }
}
/// A stream driven by an async body created with [`generate`].
///
/// Each poll drives the body. The stream yields [`GeneratorState::Yielded`] for every item the
/// body sends, then exactly one [`GeneratorState::Complete`] carrying the body's output, after
/// which it ends.
#[pin_project]
#[must_use = "generators do nothing unless polled"]
#[derive(Debug)]
pub struct Generator<F, I, R>
where
    F: Future<Output = R>,
{
    /// The generator body; fused so it is not polled again once it has completed.
    #[pin]
    task: future::Fuse<F>,
    /// Receiving end of the channel whose sender is held by the body's [`Yield`] handle.
    #[pin]
    stream: mpsc::Receiver<I>,
    /// The body's output, held until every yielded item has been delivered.
    res: Option<R>,
}
impl<F, I, E> Generator<F, I, Result<(), E>>
where
    F: Future<Output = Result<(), E>>,
{
    /// Converts a generator whose body returns `Result<(), E>` into a stream of `Result<I, E>`.
    ///
    /// - Yielded items become `Ok(item)`.
    /// - If the body fails, its error is emitted as a final `Err(e)`.
    /// - A successful `Ok(())` completion produces no item.
    pub fn into_try_stream(self) -> impl FusedStream<Item = Result<I, E>> {
        self.filter_map(|state| {
            future::ready(match state {
                GeneratorState::Yielded(i) => Some(Ok(i)),
                GeneratorState::Complete(Ok(())) => None,
                GeneratorState::Complete(Err(e)) => Some(Err(e)),
            })
        })
    }
}
impl<F, I, R> Generator<F, I, R>
where
    F: Future<Output = R>,
{
    /// Drives the generator to completion and returns the body's output.
    ///
    /// Every yielded item is received and dropped, which lets the body advance past each of
    /// its yields.
    pub async fn into_complete(self) -> R {
        let results = self.filter_map(|state| future::ready(state.into_complete()));
        futures::pin_mut!(results);
        // `poll_next` only ends the stream after handing out `Complete`, so `None` here would
        // be a broken invariant rather than a recoverable condition.
        results
            .next()
            .await
            .expect("a generator always emits `Complete` before its stream ends")
    }
}
impl<F, I> Generator<F, I, ()>
where
    F: Future<Output = ()>,
{
    /// Converts a generator whose body returns `()` into a stream of just its yielded items.
    ///
    /// The final `Complete(())` is dropped, but the body is still driven to completion.
    pub fn into_yielded(self) -> impl FusedStream<Item = I> {
        self.filter_map(|state| future::ready(state.into_yielded()))
    }
}
impl<F, I, R> Stream for Generator<F, I, R>
where
    F: Future<Output = R>,
{
    type Item = GeneratorState<I, R>;

    /// Drives the body once, then reports the next yielded item, the body's result, or the end
    /// of the stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let proj = self.project();

        // The body is polled on every call, before the channel is inspected. Once the fused
        // future has completed, polling it just returns `Pending`.
        let was_finished = proj.task.is_terminated();
        let body_finished = match proj.task.poll(cx) {
            Poll::Ready(output) => {
                *proj.res = Some(output);
                true
            }
            Poll::Pending => was_finished,
        };

        // Items already sent must be delivered before the final result.
        if !proj.stream.is_terminated() {
            match proj.stream.poll_next(cx) {
                Poll::Ready(Some(item)) => {
                    return Poll::Ready(Some(GeneratorState::Yielded(item)));
                }
                // The sender is gone and the channel is drained; fall through to the result.
                Poll::Ready(None) => {}
                Poll::Pending => return Poll::Pending,
            }
        }

        if body_finished {
            // `take` hands the result out exactly once; later polls see `None` and end the
            // stream.
            Poll::Ready(proj.res.take().map(GeneratorState::Complete))
        } else {
            // The body is still running and returned `Pending` above.
            Poll::Pending
        }
    }
}
impl<F, I, R> FusedStream for Generator<F, I, R>
where
    F: Future<Output = R>,
{
    /// Returns `true` once all of the following hold:
    ///
    /// - the body has finished;
    /// - the item channel has been drained;
    /// - the final result has been handed out.
    fn is_terminated(&self) -> bool {
        let body_finished = self.task.is_terminated();
        let items_drained = self.stream.is_terminated();
        let result_delivered = self.res.is_none();
        body_finished && items_drained && result_delivered
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use std::sync::atomic;

    /// Returns a future that is `Pending` on its first poll and `Ready(())` on every later
    /// poll. On the first poll it wakes itself immediately, so the executor re-polls it.
    fn yield_once() -> impl Future<Output = ()> {
        let mut done = false;
        future::poll_fn(move |cx: &mut Context<'_>| {
            if !done {
                done = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            } else {
                Poll::Ready(())
            }
        })
    }

    /// Atomic counter used to observe how far a generator body has run.
    #[derive(Debug, Default)]
    struct Counter(atomic::AtomicU32);

    impl Counter {
        /// Increments the count by one.
        fn inc(&self) {
            self.0.fetch_add(1, atomic::Ordering::SeqCst);
        }

        /// Returns the current count and resets it to zero.
        fn take(&self) -> u32 {
            self.0.swap(0, atomic::Ordering::SeqCst)
        }
    }

    // The body must not run ahead of its consumer: each `next()` advances it only up to its
    // following yield. The counter deltas pin down exactly how much of the body ran per call.
    #[test]
    fn generator_waits_for_item_to_yield() {
        let counter = Counter::default();
        let s = generate(|mut co| {
            let counter = &counter;
            async move {
                counter.inc();
                co.yield_("first").await;
                counter.inc();
                yield_once().await;
                counter.inc();
                co.yield_("second").await;
                drop(co);
                yield_once().await;
                counter.inc();
            }
        });
        block_on(async {
            futures::pin_mut!(s);
            // Only the closure has run so far; the async body has not been polled.
            assert_eq!(counter.take(), 0);
            assert_eq!(s.next().await, Some(GeneratorState::Yielded("first")));
            // One increment before the first yield.
            assert_eq!(counter.take(), 1);
            assert_eq!(s.next().await, Some(GeneratorState::Yielded("second")));
            // Two increments between the first and second yields.
            assert_eq!(counter.take(), 2);
            assert_eq!(s.next().await, Some(GeneratorState::Complete(())));
            // The tail after the last yield ran before `Complete` was reported.
            assert_eq!(counter.take(), 1);
            assert_eq!(s.next().await, None);
            // The finished body is not run again.
            assert_eq!(counter.take(), 0);
        });
    }

    // `yield_all` emits every item in order and leaves the handle usable for later yields.
    #[test]
    fn yield_all_yields_all() {
        let s = generate(|mut co| async move {
            co.yield_all(1u32..4).await;
            co.yield_(42).await;
        });
        let res = block_on(s.collect::<Vec<GeneratorState<u32, ()>>>());
        assert_eq!(
            res,
            vec![
                GeneratorState::Yielded(1),
                GeneratorState::Yielded(2),
                GeneratorState::Yielded(3),
                GeneratorState::Yielded(42),
                GeneratorState::Complete(()),
            ]
        );
    }

    // `is_terminated` stays false until the final `Complete` has been returned, then stays
    // true.
    #[test]
    fn fused_impl() {
        let s = generate(|mut co| async move {
            co.yield_(1u32).await;
            drop(co);
            yield_once().await;
            "done"
        });
        block_on(async {
            futures::pin_mut!(s);
            assert!(!s.is_terminated());
            assert_eq!(s.next().await, Some(GeneratorState::Yielded(1)));
            assert!(!s.is_terminated());
            assert_eq!(s.next().await, Some(GeneratorState::Complete("done")));
            assert!(s.is_terminated());
            assert_eq!(s.next().await, None);
            assert!(s.is_terminated());
        });
    }

    // A failing body yields its items as `Ok` and then its error as a final `Err`.
    #[test]
    fn into_try_stream_transposes_generator_states() {
        let s = generate(|mut co| async move {
            co.yield_(1u8).await;
            co.yield_(2u8).await;
            Result::<(), &'static str>::Err("oops")
        })
        .into_try_stream();
        let res = block_on(s.collect::<Vec<Result<u8, &'static str>>>());
        assert_eq!(res, vec![Ok(1), Ok(2), Err("oops")]);
    }

    // A successful `Ok(())` completion adds nothing to the try-stream.
    #[test]
    fn into_try_stream_eats_unit_success() {
        let s = generate(|mut co| async move {
            co.yield_(1u8).await;
            co.yield_(2u8).await;
            Result::<(), &'static str>::Ok(())
        })
        .into_try_stream();
        let res = block_on(s.collect::<Vec<Result<u8, &'static str>>>());
        assert_eq!(res, vec![Ok(1), Ok(2)]);
    }

    // Code after the body's last yield (here, after dropping the `Yield` handle) still runs,
    // whichever way the generator is consumed: directly, via `into_yielded`, or via
    // `into_complete`.
    #[test]
    fn runs_task_to_completion() {
        let finished = Counter::default();
        let make_s = || {
            generate(|mut co| async {
                co.yield_(8u8).await;
                drop(co);
                yield_once().await;
                finished.inc();
            })
        };
        block_on(async {
            let res = make_s().collect::<Vec<GeneratorState<u8, ()>>>().await;
            assert_eq!(res, vec![GeneratorState::Yielded(8), GeneratorState::Complete(())]);
            assert_eq!(finished.take(), 1);
        });
        block_on(async {
            assert_eq!(make_s().into_yielded().collect::<Vec<_>>().await, vec![8]);
            assert_eq!(finished.take(), 1);
        });
        block_on(async {
            // `let ()` pins the result type to `()`; the crate-level
            // `allow(clippy::let_unit_value)` permits this pattern.
            let () = make_s().into_complete().await;
            assert_eq!(finished.take(), 1);
        });
    }

    // An endless body is fine: the consumer stops after the first ten items.
    #[test]
    fn fibonacci() {
        let fib = generate(|mut co| async move {
            let (mut a, mut b) = (0u32, 1u32);
            loop {
                co.yield_(a).await;
                let n = b;
                b += a;
                a = n;
            }
        })
        .into_yielded()
        .take(10)
        .collect::<Vec<_>>();
        assert_eq!(block_on(fib), vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
    }
}