// hyper/common/drain.rs
1use std::mem;
2
3use pin_project_lite::pin_project;
4use tokio::sync::watch;
5
6use super::{task, Future, Pin, Poll};
7
8pub(crate) fn channel() -> (Signal, Watch) {
9 let (tx, rx) = watch::channel(());
10 (Signal { tx }, Watch { rx })
11}
12
13pub(crate) struct Signal {
14 tx: watch::Sender<()>,
15}
16
/// Future returned by `Signal::drain`.
///
/// It wraps the sender's `closed()` future, so it resolves once every receiver
/// (each `Watch` and each `Watching` built from one) has been dropped.
pub(crate) struct Draining(Pin<Box<dyn Future<Output = ()> + Send + Sync>>);
18
19#[derive(Clone)]
20pub(crate) struct Watch {
21 rx: watch::Receiver<()>,
22}
23
24pin_project! {
25 #[allow(missing_debug_implementations)]
26 pub struct Watching<F, FN> {
27 #[pin]
28 future: F,
29 state: State<FN>,
30 watch: Pin<Box<dyn Future<Output = ()> + Send + Sync>>,
31 _rx: watch::Receiver<()>,
32 }
33}
34
/// Progress of a `Watching` future with respect to the drain signal.
enum State<Cb> {
    /// No drain observed yet; holds the `on_drain` callback that has not run.
    Watch(Cb),
    /// The callback has been taken out, so only the inner future is polled.
    Draining,
}
39
40impl Signal {
41 pub(crate) fn drain(self) -> Draining {
42 let _ = self.tx.send(());
43 Draining(Box::pin(async move { self.tx.closed().await }))
44 }
45}
46
47impl Future for Draining {
48 type Output = ();
49
50 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
51 Pin::new(&mut self.as_mut().0).poll(cx)
52 }
53}
54
55impl Watch {
56 pub(crate) fn watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN>
57 where
58 F: Future,
59 FN: FnOnce(Pin<&mut F>),
60 {
61 let Self { mut rx } = self;
62 let _rx = rx.clone();
63 Watching {
64 future,
65 state: State::Watch(on_drain),
66 watch: Box::pin(async move {
67 let _ = rx.changed().await;
68 }),
69 _rx,
72 }
73 }
74}
75
76impl<F, FN> Future for Watching<F, FN>
77where
78 F: Future,
79 FN: FnOnce(Pin<&mut F>),
80{
81 type Output = F::Output;
82
83 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
84 let mut me = self.project();
85 loop {
86 match mem::replace(me.state, State::Draining) {
87 State::Watch(on_drain) => {
88 match Pin::new(&mut me.watch).poll(cx) {
89 Poll::Ready(()) => {
90 on_drain(me.future.as_mut());
92 }
93 Poll::Pending => {
94 *me.state = State::Watch(on_drain);
95 return me.future.poll(cx);
96 }
97 }
98 }
99 State::Draining => return me.future.poll(cx),
100 }
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-driven inner future: counts its polls, records whether `on_drain`
    /// ran, and completes only once `finished` is set.
    #[derive(Default)]
    struct TestMe {
        draining: bool,
        finished: bool,
        poll_cnt: usize,
    }

    impl Future for TestMe {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();
            this.poll_cnt += 1;
            if this.finished {
                Poll::Ready(())
            } else {
                Poll::Pending
            }
        }
    }

    /// `on_drain` callback shared by the tests: flags the inner future.
    fn mark_draining(mut fut: Pin<&mut TestMe>) {
        fut.draining = true;
    }

    #[test]
    fn watch() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (tx, rx) = channel();
            let mut watch = rx.watch(TestMe::default(), mark_draining);

            // Nothing is polled until the `Watching` itself is polled.
            assert_eq!(watch.future.poll_cnt, 0);

            // Without a drain signal, polls are simply forwarded.
            assert!(Pin::new(&mut watch).poll(cx).is_pending());
            assert_eq!(watch.future.poll_cnt, 1);

            assert!(Pin::new(&mut watch).poll(cx).is_pending());
            assert_eq!(watch.future.poll_cnt, 2);

            // Signalling a drain does not touch the inner future by itself...
            let mut draining = tx.drain();
            assert!(!watch.future.draining);
            assert_eq!(watch.future.poll_cnt, 2);

            // ...the next poll runs `on_drain` and then polls the inner future.
            assert!(Pin::new(&mut watch).poll(cx).is_pending());
            assert_eq!(watch.future.poll_cnt, 3);
            assert!(watch.future.draining);

            // Draining stays pending while the `Watching` is alive.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // Finish the inner future, then drop the `Watching`.
            watch.future.finished = true;
            assert!(Pin::new(&mut watch).poll(cx).is_ready());
            assert_eq!(watch.future.poll_cnt, 4);
            drop(watch);

            // With every receiver gone, draining completes.
            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        })
    }

    #[test]
    fn watch_clones() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (tx, rx) = channel();

            let first = rx.clone().watch(TestMe::default(), mark_draining);
            let second = rx.watch(TestMe::default(), mark_draining);

            let mut draining = tx.drain();

            // Both watchers still hold receivers.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            drop(first);

            // One watcher is still alive.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            drop(second);

            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        });
    }
}