hyper/common/
drain.rs

1use std::mem;
2
3use pin_project_lite::pin_project;
4use tokio::sync::watch;
5
6use super::{task, Future, Pin, Poll};
7
8pub(crate) fn channel() -> (Signal, Watch) {
9    let (tx, rx) = watch::channel(());
10    (Signal { tx }, Watch { rx })
11}
12
/// Sending half of the drain channel returned by [`channel`].
///
/// Consuming it with [`Signal::drain`] publishes the drain notification to
/// every receiver and produces a [`Draining`] future.
pub(crate) struct Signal {
    // Sending `()` on this channel is the drain notification.
    tx: watch::Sender<()>,
}
16
/// Future returned by [`Signal::drain`].
///
/// It resolves when the drain sender's `closed()` resolves. That happens once
/// every receiver has been dropped, meaning each [`Watch`] and each live
/// `Watching` future (which keeps its own receiver).
pub(crate) struct Draining(Pin<Box<dyn Future<Output = ()> + Send + Sync>>);
18
/// Receiving half of the drain channel returned by [`channel`].
///
/// It is cloneable, so several futures can each be wrapped with
/// [`Watch::watch`]. Every clone holds a receiver and therefore keeps
/// [`Draining`] pending until that clone is dropped.
#[derive(Clone)]
pub(crate) struct Watch {
    rx: watch::Receiver<()>,
}
23
pin_project! {
    // Future returned by `Watch::watch`. It drives `future` and, the first
    // time the drain notification is observed, passes it to the `on_drain`
    // callback stored in `state`.
    #[allow(missing_debug_implementations)]
    pub struct Watching<F, FN> {
        // The wrapped future; structurally pinned.
        #[pin]
        future: F,
        // Holds the `on_drain` callback until the drain has been observed.
        state: State<FN>,
        // Boxed wait on the receiver's `changed()`; completes when that
        // wait finishes.
        watch: Pin<Box<dyn Future<Output = ()> + Send + Sync>>,
        // Extra receiver kept alive for as long as this future exists, so
        // that `Draining` only resolves after this future has been dropped.
        _rx: watch::Receiver<()>,
    }
}
34
/// Progress of a `Watching` future with respect to the drain notification.
enum State<F> {
    /// Drain not yet observed; holds the `on_drain` callback to run once it is.
    Watch(F),
    /// Drain observed and the callback already consumed.
    Draining,
}
39
40impl Signal {
41    pub(crate) fn drain(self) -> Draining {
42        let _ = self.tx.send(());
43        Draining(Box::pin(async move { self.tx.closed().await }))
44    }
45}
46
47impl Future for Draining {
48    type Output = ();
49
50    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
51        Pin::new(&mut self.as_mut().0).poll(cx)
52    }
53}
54
55impl Watch {
56    pub(crate) fn watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN>
57    where
58        F: Future,
59        FN: FnOnce(Pin<&mut F>),
60    {
61        let Self { mut rx } = self;
62        let _rx = rx.clone();
63        Watching {
64            future,
65            state: State::Watch(on_drain),
66            watch: Box::pin(async move {
67                let _ = rx.changed().await;
68            }),
69            // Keep the receiver alive until the future completes, so that
70            // dropping it can signal that draining has completed.
71            _rx,
72        }
73    }
74}
75
76impl<F, FN> Future for Watching<F, FN>
77where
78    F: Future,
79    FN: FnOnce(Pin<&mut F>),
80{
81    type Output = F::Output;
82
83    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
84        let mut me = self.project();
85        loop {
86            match mem::replace(me.state, State::Draining) {
87                State::Watch(on_drain) => {
88                    match Pin::new(&mut me.watch).poll(cx) {
89                        Poll::Ready(()) => {
90                            // Drain has been triggered!
91                            on_drain(me.future.as_mut());
92                        }
93                        Poll::Pending => {
94                            *me.state = State::Watch(on_drain);
95                            return me.future.poll(cx);
96                        }
97                    }
98                }
99                State::Draining => return me.future.poll(cx),
100            }
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Manually driven future. It stays pending until `finished` is set,
    /// counts how often it is polled, and exposes a `draining` flag for
    /// `on_drain` callbacks to set.
    struct TestMe {
        draining: bool,
        finished: bool,
        poll_cnt: usize,
    }

    impl TestMe {
        fn new() -> Self {
            TestMe {
                draining: false,
                finished: false,
                poll_cnt: 0,
            }
        }
    }

    impl Future for TestMe {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
            let me = self.get_mut();
            me.poll_cnt += 1;
            if me.finished {
                Poll::Ready(())
            } else {
                Poll::Pending
            }
        }
    }

    /// `on_drain` callback shared by the tests: flags the inner future.
    fn mark_draining(mut fut: Pin<&mut TestMe>) {
        fut.draining = true;
    }

    #[test]
    fn watch() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (tx, rx) = channel();
            let mut watch = rx.watch(TestMe::new(), mark_draining);
            assert_eq!(watch.future.poll_cnt, 0);

            // Without a drain signal, every poll simply reaches the inner future.
            for expected_polls in 1..=2 {
                assert!(Pin::new(&mut watch).poll(cx).is_pending());
                assert_eq!(watch.future.poll_cnt, expected_polls);
            }

            let mut draining = tx.drain();
            // The signal has been sent but is only noticed on the next poll.
            assert!(!watch.future.draining);
            assert_eq!(watch.future.poll_cnt, 2);

            // This poll observes the drain and runs the callback.
            assert!(Pin::new(&mut watch).poll(cx).is_pending());
            assert_eq!(watch.future.poll_cnt, 3);
            assert!(watch.future.draining);

            // The watcher is still alive, so draining cannot finish yet.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // Let the inner future complete, then drop the watcher.
            watch.future.finished = true;
            assert!(Pin::new(&mut watch).poll(cx).is_ready());
            assert_eq!(watch.future.poll_cnt, 4);
            drop(watch);

            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        });
    }

    #[test]
    fn watch_clones() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (tx, rx) = channel();

            let first = rx.clone().watch(TestMe::new(), mark_draining);
            let second = rx.watch(TestMe::new(), mark_draining);

            let mut draining = tx.drain();

            // Both watchers are still outstanding.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // Dropping one of them is not enough...
            drop(first);
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // ...but once the last one is gone, draining completes.
            drop(second);
            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        });
    }
}