1use super::assert_stream;
2use core::{fmt, pin::Pin};
3use futures_core::stream::{FusedStream, Stream};
4use futures_core::task::{Context, Poll};
5use pin_project_lite::pin_project;
6
/// Names one of the two streams combined by [`SelectWithStrategy`].
///
/// The strategy closure passed to [`select_with_strategy`] returns a value of
/// this type to say which stream should be polled first.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PollNext {
    /// The first stream (`stream1`).
    Left,
    /// The second stream (`stream2`).
    Right,
}
15
16impl PollNext {
17 pub fn toggle(&mut self) -> Self {
19 let old = *self;
20 *self = self.other();
21 old
22 }
23
24 fn other(&self) -> PollNext {
25 match self {
26 PollNext::Left => PollNext::Right,
27 PollNext::Right => PollNext::Left,
28 }
29 }
30}
31
32impl Default for PollNext {
33 fn default() -> Self {
34 PollNext::Left
35 }
36}
37
/// Records which of the two underlying streams have already returned
/// `Poll::Ready(None)`.
enum InternalState {
    /// Neither stream has finished yet.
    Start,
    /// The left stream (`stream1`) has finished; only the right one is polled.
    LeftFinished,
    /// The right stream (`stream2`) has finished; only the left one is polled.
    RightFinished,
    /// Both streams have finished; the combined stream is terminated.
    BothFinished,
}
44
45impl InternalState {
46 fn finish(&mut self, ps: PollNext) {
47 match (&self, ps) {
48 (InternalState::Start, PollNext::Left) => {
49 *self = InternalState::LeftFinished;
50 }
51 (InternalState::Start, PollNext::Right) => {
52 *self = InternalState::RightFinished;
53 }
54 (InternalState::LeftFinished, PollNext::Right)
55 | (InternalState::RightFinished, PollNext::Left) => {
56 *self = InternalState::BothFinished;
57 }
58 _ => {}
59 }
60 }
61}
62
pin_project! {
    /// Stream returned by [`select_with_strategy`]: yields items from two
    /// streams of the same item type, using a caller-supplied closure to
    /// decide which stream is polled first.
    #[must_use = "streams do nothing unless polled"]
    #[project = SelectWithStrategyProj]
    pub struct SelectWithStrategy<St1, St2, Clos, State> {
        // The "left" stream (`PollNext::Left`).
        #[pin]
        stream1: St1,
        // The "right" stream (`PollNext::Right`).
        #[pin]
        stream2: St2,
        // Which of the two streams have already finished.
        internal_state: InternalState,
        // Caller-defined state handed to `clos` by mutable reference.
        state: State,
        // Strategy closure: picks the side to poll first.
        clos: Clos,
    }
}
77
78pub fn select_with_strategy<St1, St2, Clos, State>(
140 stream1: St1,
141 stream2: St2,
142 which: Clos,
143) -> SelectWithStrategy<St1, St2, Clos, State>
144where
145 St1: Stream,
146 St2: Stream<Item = St1::Item>,
147 Clos: FnMut(&mut State) -> PollNext,
148 State: Default,
149{
150 assert_stream::<St1::Item, _>(SelectWithStrategy {
151 stream1,
152 stream2,
153 state: Default::default(),
154 internal_state: InternalState::Start,
155 clos: which,
156 })
157}
158
159impl<St1, St2, Clos, State> SelectWithStrategy<St1, St2, Clos, State> {
160 pub fn get_ref(&self) -> (&St1, &St2) {
163 (&self.stream1, &self.stream2)
164 }
165
166 pub fn get_mut(&mut self) -> (&mut St1, &mut St2) {
172 (&mut self.stream1, &mut self.stream2)
173 }
174
175 pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut St2>) {
181 let this = self.project();
182 (this.stream1, this.stream2)
183 }
184
185 pub fn into_inner(self) -> (St1, St2) {
190 (self.stream1, self.stream2)
191 }
192}
193
194impl<St1, St2, Clos, State> FusedStream for SelectWithStrategy<St1, St2, Clos, State>
195where
196 St1: Stream,
197 St2: Stream<Item = St1::Item>,
198 Clos: FnMut(&mut State) -> PollNext,
199{
200 fn is_terminated(&self) -> bool {
201 match self.internal_state {
202 InternalState::BothFinished => true,
203 _ => false,
204 }
205 }
206}
207
208#[inline]
209fn poll_side<St1, St2, Clos, State>(
210 select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>,
211 side: PollNext,
212 cx: &mut Context<'_>,
213) -> Poll<Option<St1::Item>>
214where
215 St1: Stream,
216 St2: Stream<Item = St1::Item>,
217{
218 match side {
219 PollNext::Left => select.stream1.as_mut().poll_next(cx),
220 PollNext::Right => select.stream2.as_mut().poll_next(cx),
221 }
222}
223
224#[inline]
225fn poll_inner<St1, St2, Clos, State>(
226 select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>,
227 side: PollNext,
228 cx: &mut Context<'_>,
229) -> Poll<Option<St1::Item>>
230where
231 St1: Stream,
232 St2: Stream<Item = St1::Item>,
233{
234 let first_done = match poll_side(select, side, cx) {
235 Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
236 Poll::Ready(None) => {
237 select.internal_state.finish(side);
238 true
239 }
240 Poll::Pending => false,
241 };
242 let other = side.other();
243 match poll_side(select, other, cx) {
244 Poll::Ready(None) => {
245 select.internal_state.finish(other);
246 if first_done {
247 Poll::Ready(None)
248 } else {
249 Poll::Pending
250 }
251 }
252 a => a,
253 }
254}
255
256impl<St1, St2, Clos, State> Stream for SelectWithStrategy<St1, St2, Clos, State>
257where
258 St1: Stream,
259 St2: Stream<Item = St1::Item>,
260 Clos: FnMut(&mut State) -> PollNext,
261{
262 type Item = St1::Item;
263
264 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St1::Item>> {
265 let mut this = self.project();
266
267 match this.internal_state {
268 InternalState::Start => {
269 let next_side = (this.clos)(this.state);
270 poll_inner(&mut this, next_side, cx)
271 }
272 InternalState::LeftFinished => match this.stream2.poll_next(cx) {
273 Poll::Ready(None) => {
274 *this.internal_state = InternalState::BothFinished;
275 Poll::Ready(None)
276 }
277 a => a,
278 },
279 InternalState::RightFinished => match this.stream1.poll_next(cx) {
280 Poll::Ready(None) => {
281 *this.internal_state = InternalState::BothFinished;
282 Poll::Ready(None)
283 }
284 a => a,
285 },
286 InternalState::BothFinished => Poll::Ready(None),
287 }
288 }
289}
290
291impl<St1, St2, Clos, State> fmt::Debug for SelectWithStrategy<St1, St2, Clos, State>
292where
293 St1: fmt::Debug,
294 St2: fmt::Debug,
295 State: fmt::Debug,
296{
297 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298 f.debug_struct("SelectWithStrategy")
299 .field("stream1", &self.stream1)
300 .field("stream2", &self.stream2)
301 .field("state", &self.state)
302 .finish()
303 }
304}