futures_util/stream/
select_with_strategy.rs

1use super::assert_stream;
2use core::{fmt, pin::Pin};
3use futures_core::stream::{FusedStream, Stream};
4use futures_core::task::{Context, Poll};
5use pin_project_lite::pin_project;
6
/// Type to tell [`SelectWithStrategy`] which stream to poll next.
///
/// The default value is [`PollNext::Left`].
#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash)]
pub enum PollNext {
    /// Poll the first stream.
    Left,
    /// Poll the second stream.
    Right,
}

impl PollNext {
    /// Toggle the value and return the old one.
    ///
    /// Starting from `Left`, successive calls return `Left`, `Right`, `Left`,
    /// ... so `|last: &mut PollNext| last.toggle()` is a round-robin strategy.
    pub fn toggle(&mut self) -> Self {
        let previous = *self;
        *self = previous.other();
        previous
    }

    /// Returns the opposite side without modifying `self`.
    fn other(&self) -> PollNext {
        match *self {
            Self::Left => Self::Right,
            Self::Right => Self::Left,
        }
    }
}

impl Default for PollNext {
    fn default() -> Self {
        Self::Left
    }
}
37
38enum InternalState {
39    Start,
40    LeftFinished,
41    RightFinished,
42    BothFinished,
43}
44
45impl InternalState {
46    fn finish(&mut self, ps: PollNext) {
47        match (&self, ps) {
48            (InternalState::Start, PollNext::Left) => {
49                *self = InternalState::LeftFinished;
50            }
51            (InternalState::Start, PollNext::Right) => {
52                *self = InternalState::RightFinished;
53            }
54            (InternalState::LeftFinished, PollNext::Right)
55            | (InternalState::RightFinished, PollNext::Left) => {
56                *self = InternalState::BothFinished;
57            }
58            _ => {}
59        }
60    }
61}
62
63pin_project! {
64    /// Stream for the [`select_with_strategy()`] function. See function docs for details.
65    #[must_use = "streams do nothing unless polled"]
66    #[project = SelectWithStrategyProj]
67    pub struct SelectWithStrategy<St1, St2, Clos, State> {
68        #[pin]
69        stream1: St1,
70        #[pin]
71        stream2: St2,
72        internal_state: InternalState,
73        state: State,
74        clos: Clos,
75    }
76}
77
78/// This function will attempt to pull items from both streams. You provide a
79/// closure to tell [`SelectWithStrategy`] which stream to poll. The closure can
80/// store state on `SelectWithStrategy` to which it will receive a `&mut` on every
81/// invocation. This allows basing the strategy on prior choices.
82///
83/// After one of the two input streams completes, the remaining one will be
84/// polled exclusively. The returned stream completes when both input
85/// streams have completed.
86///
87/// Note that this function consumes both streams and returns a wrapped
88/// version of them.
89///
90/// ## Examples
91///
92/// ### Priority
93/// This example shows how to always prioritize the left stream.
94///
95/// ```rust
96/// # futures::executor::block_on(async {
97/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt };
98///
99/// let left = repeat(1);
100/// let right = repeat(2);
101///
102/// // We don't need any state, so let's make it an empty tuple.
103/// // We must provide some type here, as there is no way for the compiler
104/// // to infer it. As we don't need to capture variables, we can just
105/// // use a function pointer instead of a closure.
106/// fn prio_left(_: &mut ()) -> PollNext { PollNext::Left }
107///
108/// let mut out = select_with_strategy(left, right, prio_left);
109///
110/// for _ in 0..100 {
111///     // Whenever we poll out, we will alwas get `1`.
112///     assert_eq!(1, out.select_next_some().await);
113/// }
114/// # });
115/// ```
116///
117/// ### Round Robin
118/// This example shows how to select from both streams round robin.
119/// Note: this special case is provided by [`futures-util::stream::select`].
120///
121/// ```rust
122/// # futures::executor::block_on(async {
123/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt };
124///
125/// let left = repeat(1);
126/// let right = repeat(2);
127///
128/// let rrobin = |last: &mut PollNext| last.toggle();
129///
130/// let mut out = select_with_strategy(left, right, rrobin);
131///
132/// for _ in 0..100 {
133///     // We should be alternating now.
134///     assert_eq!(1, out.select_next_some().await);
135///     assert_eq!(2, out.select_next_some().await);
136/// }
137/// # });
138/// ```
139pub fn select_with_strategy<St1, St2, Clos, State>(
140    stream1: St1,
141    stream2: St2,
142    which: Clos,
143) -> SelectWithStrategy<St1, St2, Clos, State>
144where
145    St1: Stream,
146    St2: Stream<Item = St1::Item>,
147    Clos: FnMut(&mut State) -> PollNext,
148    State: Default,
149{
150    assert_stream::<St1::Item, _>(SelectWithStrategy {
151        stream1,
152        stream2,
153        state: Default::default(),
154        internal_state: InternalState::Start,
155        clos: which,
156    })
157}
158
159impl<St1, St2, Clos, State> SelectWithStrategy<St1, St2, Clos, State> {
160    /// Acquires a reference to the underlying streams that this combinator is
161    /// pulling from.
162    pub fn get_ref(&self) -> (&St1, &St2) {
163        (&self.stream1, &self.stream2)
164    }
165
166    /// Acquires a mutable reference to the underlying streams that this
167    /// combinator is pulling from.
168    ///
169    /// Note that care must be taken to avoid tampering with the state of the
170    /// stream which may otherwise confuse this combinator.
171    pub fn get_mut(&mut self) -> (&mut St1, &mut St2) {
172        (&mut self.stream1, &mut self.stream2)
173    }
174
175    /// Acquires a pinned mutable reference to the underlying streams that this
176    /// combinator is pulling from.
177    ///
178    /// Note that care must be taken to avoid tampering with the state of the
179    /// stream which may otherwise confuse this combinator.
180    pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut St2>) {
181        let this = self.project();
182        (this.stream1, this.stream2)
183    }
184
185    /// Consumes this combinator, returning the underlying streams.
186    ///
187    /// Note that this may discard intermediate state of this combinator, so
188    /// care should be taken to avoid losing resources when this is called.
189    pub fn into_inner(self) -> (St1, St2) {
190        (self.stream1, self.stream2)
191    }
192}
193
194impl<St1, St2, Clos, State> FusedStream for SelectWithStrategy<St1, St2, Clos, State>
195where
196    St1: Stream,
197    St2: Stream<Item = St1::Item>,
198    Clos: FnMut(&mut State) -> PollNext,
199{
200    fn is_terminated(&self) -> bool {
201        match self.internal_state {
202            InternalState::BothFinished => true,
203            _ => false,
204        }
205    }
206}
207
208#[inline]
209fn poll_side<St1, St2, Clos, State>(
210    select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>,
211    side: PollNext,
212    cx: &mut Context<'_>,
213) -> Poll<Option<St1::Item>>
214where
215    St1: Stream,
216    St2: Stream<Item = St1::Item>,
217{
218    match side {
219        PollNext::Left => select.stream1.as_mut().poll_next(cx),
220        PollNext::Right => select.stream2.as_mut().poll_next(cx),
221    }
222}
223
224#[inline]
225fn poll_inner<St1, St2, Clos, State>(
226    select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>,
227    side: PollNext,
228    cx: &mut Context<'_>,
229) -> Poll<Option<St1::Item>>
230where
231    St1: Stream,
232    St2: Stream<Item = St1::Item>,
233{
234    let first_done = match poll_side(select, side, cx) {
235        Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
236        Poll::Ready(None) => {
237            select.internal_state.finish(side);
238            true
239        }
240        Poll::Pending => false,
241    };
242    let other = side.other();
243    match poll_side(select, other, cx) {
244        Poll::Ready(None) => {
245            select.internal_state.finish(other);
246            if first_done {
247                Poll::Ready(None)
248            } else {
249                Poll::Pending
250            }
251        }
252        a => a,
253    }
254}
255
256impl<St1, St2, Clos, State> Stream for SelectWithStrategy<St1, St2, Clos, State>
257where
258    St1: Stream,
259    St2: Stream<Item = St1::Item>,
260    Clos: FnMut(&mut State) -> PollNext,
261{
262    type Item = St1::Item;
263
264    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St1::Item>> {
265        let mut this = self.project();
266
267        match this.internal_state {
268            InternalState::Start => {
269                let next_side = (this.clos)(this.state);
270                poll_inner(&mut this, next_side, cx)
271            }
272            InternalState::LeftFinished => match this.stream2.poll_next(cx) {
273                Poll::Ready(None) => {
274                    *this.internal_state = InternalState::BothFinished;
275                    Poll::Ready(None)
276                }
277                a => a,
278            },
279            InternalState::RightFinished => match this.stream1.poll_next(cx) {
280                Poll::Ready(None) => {
281                    *this.internal_state = InternalState::BothFinished;
282                    Poll::Ready(None)
283                }
284                a => a,
285            },
286            InternalState::BothFinished => Poll::Ready(None),
287        }
288    }
289}
290
291impl<St1, St2, Clos, State> fmt::Debug for SelectWithStrategy<St1, St2, Clos, State>
292where
293    St1: fmt::Debug,
294    St2: fmt::Debug,
295    State: fmt::Debug,
296{
297    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298        f.debug_struct("SelectWithStrategy")
299            .field("stream1", &self.stream1)
300            .field("stream2", &self.stream2)
301            .field("state", &self.state)
302            .finish()
303    }
304}