1#[cfg(feature = "http2")]
2use std::future::Future;
3
4use futures_util::FutureExt;
5use tokio::sync::{mpsc, oneshot};
6
7#[cfg(feature = "http2")]
8use crate::common::Pin;
9use crate::common::{task, Poll};
10
/// Receiver for the response to a request queued with a `try_send` method.
///
/// On failure the error tuple may carry the original request back
/// (`Option<T>`), so the caller can retry it.
pub(crate) type RetryPromise<T, U> = oneshot::Receiver<Result<U, (crate::Error, Option<T>)>>;
/// Receiver for the response to a request queued with `Sender::send`; errors
/// carry only a `crate::Error`.
pub(crate) type Promise<T> = oneshot::Receiver<Result<T, crate::Error>>;
13
14pub(crate) fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) {
15 let (tx, rx) = mpsc::unbounded_channel();
16 let (giver, taker) = want::new();
17 let tx = Sender {
18 buffered_once: false,
19 giver,
20 inner: tx,
21 };
22 let rx = Receiver { inner: rx, taker };
23 (tx, rx)
24}
25
/// Sending half of the dispatch channel returned by `channel`.
///
/// Sends are gated by the receiver's want-signal (see `can_send`), except for
/// a one-time allowance tracked by `buffered_once`.
pub(crate) struct Sender<T, U> {
    /// Set once `can_send` has permitted a send. Until then, one message may be
    /// queued even if the receiver has not signalled want.
    buffered_once: bool,
    /// Want-signal handle, created together with the `Receiver`'s `Taker` in
    /// `channel`.
    giver: want::Giver,
    /// Unbounded queue of request envelopes shared with the `Receiver`.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
43
#[cfg(feature = "http2")]
/// Cloneable sender produced by `Sender::unbound`.
///
/// Unlike `Sender`, its `try_send` never waits for the receiver's want-signal;
/// the shared giver is only consulted to detect cancellation.
pub(crate) struct UnboundedSender<T, U> {
    /// Shared want-signal handle; only its canceled state is inspected.
    giver: want::SharedGiver,
    /// Unbounded queue of request envelopes shared with the `Receiver`.
    inner: mpsc::UnboundedSender<Envelope<T, U>>,
}
54
55impl<T, U> Sender<T, U> {
56 pub(crate) fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
57 self.giver
58 .poll_want(cx)
59 .map_err(|_| crate::Error::new_closed())
60 }
61
62 pub(crate) fn is_ready(&self) -> bool {
63 self.giver.is_wanting()
64 }
65
66 pub(crate) fn is_closed(&self) -> bool {
67 self.giver.is_canceled()
68 }
69
70 fn can_send(&mut self) -> bool {
71 if self.giver.give() || !self.buffered_once {
72 self.buffered_once = true;
77 true
78 } else {
79 false
80 }
81 }
82
83 pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
84 if !self.can_send() {
85 return Err(val);
86 }
87 let (tx, rx) = oneshot::channel();
88 self.inner
89 .send(Envelope(Some((val, Callback::Retry(tx)))))
90 .map(move |_| rx)
91 .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
92 }
93
94 pub(crate) fn send(&mut self, val: T) -> Result<Promise<U>, T> {
95 if !self.can_send() {
96 return Err(val);
97 }
98 let (tx, rx) = oneshot::channel();
99 self.inner
100 .send(Envelope(Some((val, Callback::NoRetry(tx)))))
101 .map(move |_| rx)
102 .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
103 }
104
105 #[cfg(feature = "http2")]
106 pub(crate) fn unbound(self) -> UnboundedSender<T, U> {
107 UnboundedSender {
108 giver: self.giver.shared(),
109 inner: self.inner,
110 }
111 }
112}
113
#[cfg(feature = "http2")]
impl<T, U> UnboundedSender<T, U> {
    /// Ready whenever the want-signal has not been canceled; unbounded sends
    /// never wait for want.
    pub(crate) fn is_ready(&self) -> bool {
        !self.is_closed()
    }

    /// Returns `true` once the shared want-signal has been canceled.
    pub(crate) fn is_closed(&self) -> bool {
        self.giver.is_canceled()
    }

    /// Queues `val` with a retryable callback, without consulting want.
    ///
    /// Returns `Err(val)` if the queue rejects the envelope (its receiver is
    /// closed).
    pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
        let (callback_tx, promise) = oneshot::channel();
        let envelope = Envelope(Some((val, Callback::Retry(callback_tx))));
        match self.inner.send(envelope) {
            Ok(()) => Ok(promise),
            Err(mut rejected) => {
                // Pull the request back out so the envelope's Drop has nothing
                // to fail.
                let (request, _callback) = (rejected.0).0.take().expect("envelope not dropped");
                Err(request)
            }
        }
    }
}
132
#[cfg(feature = "http2")]
impl<T, U> Clone for UnboundedSender<T, U> {
    /// Produces another handle to the same queue and shared want-signal.
    fn clone(&self) -> Self {
        let giver = self.giver.clone();
        let inner = self.inner.clone();
        Self { giver, inner }
    }
}
142
/// Receiving half of the dispatch channel returned by `channel`.
pub(crate) struct Receiver<T, U> {
    /// Queue of request envelopes pushed by the sender side.
    inner: mpsc::UnboundedReceiver<Envelope<T, U>>,
    /// Want-signal handle: `want()` is called when the queue is found empty,
    /// and `cancel()` on close/drop.
    taker: want::Taker,
}
147
148impl<T, U> Receiver<T, U> {
149 pub(crate) fn poll_recv(
150 &mut self,
151 cx: &mut task::Context<'_>,
152 ) -> Poll<Option<(T, Callback<T, U>)>> {
153 match self.inner.poll_recv(cx) {
154 Poll::Ready(item) => {
155 Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped")))
156 }
157 Poll::Pending => {
158 self.taker.want();
159 Poll::Pending
160 }
161 }
162 }
163
164 #[cfg(feature = "http1")]
165 pub(crate) fn close(&mut self) {
166 self.taker.cancel();
167 self.inner.close();
168 }
169
170 #[cfg(feature = "http1")]
171 pub(crate) fn try_recv(&mut self) -> Option<(T, Callback<T, U>)> {
172 match self.inner.recv().now_or_never() {
173 Some(Some(mut env)) => env.0.take(),
174 _ => None,
175 }
176 }
177}
178
impl<T, U> Drop for Receiver<T, U> {
    fn drop(&mut self) {
        // `Drop::drop` runs before the struct's fields are dropped. The giver
        // is therefore notified (want-signal canceled) before the
        // `mpsc::UnboundedReceiver` in `inner` is dropped.
        self.taker.cancel();
    }
}
186
/// A queued request paired with its reply callback.
///
/// The payload sits in an `Option` so the receiving side can `take()` it out.
/// An envelope dropped while still holding its payload fails the callback
/// (see its `Drop` impl).
struct Envelope<T, U>(Option<(T, Callback<T, U>)>);
188
189impl<T, U> Drop for Envelope<T, U> {
190 fn drop(&mut self) {
191 if let Some((val, cb)) = self.0.take() {
192 cb.send(Err((
193 crate::Error::new_canceled().with("connection closed"),
194 Some(val),
195 )));
196 }
197 }
198}
199
/// Reply half of a queued request: how the response (or error) reaches the
/// caller's promise.
pub(crate) enum Callback<T, U> {
    /// Created by the `try_send` methods. Errors may carry the request back
    /// (`Option<T>`) so it can be retried.
    Retry(oneshot::Sender<Result<U, (crate::Error, Option<T>)>>),
    /// Created by `Sender::send`. Errors carry only the `crate::Error`.
    NoRetry(oneshot::Sender<Result<U, crate::Error>>),
}
204
205impl<T, U> Callback<T, U> {
206 #[cfg(feature = "http2")]
207 pub(crate) fn is_canceled(&self) -> bool {
208 match *self {
209 Callback::Retry(ref tx) => tx.is_closed(),
210 Callback::NoRetry(ref tx) => tx.is_closed(),
211 }
212 }
213
214 pub(crate) fn poll_canceled(&mut self, cx: &mut task::Context<'_>) -> Poll<()> {
215 match *self {
216 Callback::Retry(ref mut tx) => tx.poll_closed(cx),
217 Callback::NoRetry(ref mut tx) => tx.poll_closed(cx),
218 }
219 }
220
221 pub(crate) fn send(self, val: Result<U, (crate::Error, Option<T>)>) {
222 match self {
223 Callback::Retry(tx) => {
224 let _ = tx.send(val);
225 }
226 Callback::NoRetry(tx) => {
227 let _ = tx.send(val.map_err(|e| e.0));
228 }
229 }
230 }
231
232 #[cfg(feature = "http2")]
233 pub(crate) async fn send_when(
234 self,
235 mut when: impl Future<Output = Result<U, (crate::Error, Option<T>)>> + Unpin,
236 ) {
237 use futures_util::future;
238 use tracing::trace;
239
240 let mut cb = Some(self);
241
242 future::poll_fn(move |cx| {
244 match Pin::new(&mut when).poll(cx) {
245 Poll::Ready(Ok(res)) => {
246 cb.take().expect("polled after complete").send(Ok(res));
247 Poll::Ready(())
248 }
249 Poll::Pending => {
250 ready!(cb.as_mut().unwrap().poll_canceled(cx));
252 trace!("send_when canceled");
253 Poll::Ready(())
254 }
255 Poll::Ready(Err(err)) => {
256 cb.take().expect("polled after complete").send(Err(err));
257 Poll::Ready(())
258 }
259 }
260 })
261 .await
262 }
263}
264
#[cfg(test)]
mod tests {
    // The `test` crate is only needed by the `#[bench]` functions below.
    #[cfg(feature = "nightly")]
    extern crate test;

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use super::{channel, Callback, Receiver};

    /// Opaque request payload used by these tests.
    #[derive(Debug)]
    struct Custom(i32);

    // Test-only `Future` impl so a `Receiver` can be driven by `PollOnce`.
    // It simply forwards to `Receiver::poll_recv`.
    impl<T, U> Future for Receiver<T, U> {
        type Output = Option<(T, Callback<T, U>)>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_recv(cx)
        }
    }

    /// Polls the wrapped future exactly once. Resolves to `Some(())` if it was
    /// ready (its output is discarded) and to `None` if it returned `Pending`.
    struct PollOnce<'a, F>(&'a mut F);

    impl<F, T> Future for PollOnce<'_, F>
    where
        F: Future<Output = T> + Unpin,
    {
        type Output = Option<()>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            match Pin::new(&mut self.0).poll(cx) {
                Poll::Ready(_) => Poll::Ready(Some(())),
                Poll::Pending => Poll::Ready(None),
            }
        }
    }

    /// A request still queued when the `Receiver` is dropped should fail its
    /// promise with a `Canceled` error that hands the request back.
    #[tokio::test]
    async fn drop_receiver_sends_cancel_errors() {
        let _ = pretty_env_logger::try_init();

        let (mut tx, mut rx) = channel::<Custom, ()>();

        // Polling the empty receiver returns `Pending`, which also makes
        // `poll_recv` signal want.
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let promise = tx.try_send(Custom(43)).unwrap();
        drop(rx);

        let fulfilled = promise.await;
        let err = fulfilled
            .expect("fulfilled")
            .expect_err("promise should error");
        match (err.0.kind(), err.1) {
            (&crate::error::Kind::Canceled, Some(_)) => (),
            e => panic!("expected Error::Cancel(_), found {:?}", e),
        }
    }

    /// `Sender` lets one message through without want. After that, it only
    /// sends once the receiver has signalled want again.
    #[tokio::test]
    async fn sender_checks_for_want_on_send() {
        let (mut tx, mut rx) = channel::<Custom, ()>();

        // The first message may be buffered without want; the second may not.
        let _ = tx.try_send(Custom(1)).expect("1 buffered");
        tx.try_send(Custom(2)).expect_err("2 not ready");

        assert!(PollOnce(&mut rx).await.is_some(), "rx once");

        // Receiving a message does not itself signal want, and the one-time
        // buffering allowance has already been used.
        tx.try_send(Custom(2)).expect_err("2 still not ready");

        // Polling the now-empty receiver returns `Pending` and signals want.
        assert!(PollOnce(&mut rx).await.is_none(), "rx empty");

        let _ = tx.try_send(Custom(2)).expect("2 ready");
    }

    /// An `UnboundedSender` queues without waiting for want, but its sends
    /// fail (returning the value) once the receiver is gone.
    #[cfg(feature = "http2")]
    #[test]
    fn unbounded_sender_doesnt_bound_on_want() {
        let (tx, rx) = channel::<Custom, ()>();
        let mut tx = tx.unbound();

        let _ = tx.try_send(Custom(1)).unwrap();
        let _ = tx.try_send(Custom(2)).unwrap();
        let _ = tx.try_send(Custom(3)).unwrap();

        drop(rx);

        let _ = tx.try_send(Custom(4)).unwrap_err();
    }

    /// Measures one `send` followed by draining the receiver until it reports
    /// `Pending`.
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_throughput(b: &mut test::Bencher) {
        use crate::{Body, Request, Response};

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (mut tx, mut rx) = channel::<Request<Body>, Response<Body>>();

        b.iter(move || {
            let _ = tx.send(Request::default()).unwrap();
            rt.block_on(async {
                loop {
                    let poll_once = PollOnce(&mut rx);
                    let opt = poll_once.await;
                    if opt.is_none() {
                        break;
                    }
                }
            });
        })
    }

    /// Measures a single poll of an empty receiver.
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_not_ready(b: &mut test::Bencher) {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (_tx, mut rx) = channel::<i32, ()>();
        b.iter(move || {
            rt.block_on(async {
                let poll_once = PollOnce(&mut rx);
                assert!(poll_once.await.is_none());
            });
        })
    }

    /// Measures repeated `Taker::cancel` calls on the receiver's want-signal.
    #[cfg(feature = "nightly")]
    #[bench]
    fn giver_queue_cancel(b: &mut test::Bencher) {
        let (_tx, mut rx) = channel::<i32, ()>();

        b.iter(move || {
            rx.taker.cancel();
        })
    }
}