1use std::marker::PhantomData;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use futures_channel::mpsc;
15use futures_util::future::{Future, FutureExt};
16use futures_util::stream::{Peekable, Stream, StreamExt};
17use tracing::{debug, warn};
18
19use crate::error::*;
20use crate::xfer::dns_handle::DnsHandle;
21use crate::xfer::DnsResponseReceiver;
22use crate::xfer::{
23 BufDnsRequestStreamHandle, DnsRequest, DnsRequestSender, DnsResponse, OneshotDnsRequest,
24 CHANNEL_BUFFER_SIZE,
25};
26use crate::Time;
27
28#[must_use = "futures do nothing unless polled"]
32pub struct DnsExchange {
33 sender: BufDnsRequestStreamHandle,
34}
35
36impl DnsExchange {
37 pub fn from_stream<S, TE>(stream: S) -> (Self, DnsExchangeBackground<S, TE>)
45 where
46 S: DnsRequestSender + 'static + Send + Unpin,
47 {
48 let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
49 let message_sender = BufDnsRequestStreamHandle { sender };
50
51 Self::from_stream_with_receiver(stream, outbound_messages, message_sender)
52 }
53
54 pub fn from_stream_with_receiver<S, TE>(
56 stream: S,
57 receiver: mpsc::Receiver<OneshotDnsRequest>,
58 sender: BufDnsRequestStreamHandle,
59 ) -> (Self, DnsExchangeBackground<S, TE>)
60 where
61 S: DnsRequestSender + 'static + Send + Unpin,
62 {
63 let background = DnsExchangeBackground {
64 io_stream: stream,
65 outbound_messages: receiver.peekable(),
66 marker: PhantomData,
67 };
68
69 (Self { sender }, background)
70 }
71
72 pub fn connect<F, S, TE>(connect_future: F) -> DnsExchangeConnect<F, S, TE>
76 where
77 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
78 S: DnsRequestSender + 'static + Send + Unpin,
79 TE: Time + Unpin,
80 {
81 let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
82 let message_sender = BufDnsRequestStreamHandle { sender };
83
84 DnsExchangeConnect::connect(connect_future, outbound_messages, message_sender)
85 }
86}
87
88impl Clone for DnsExchange {
89 fn clone(&self) -> Self {
90 Self {
91 sender: self.sender.clone(),
92 }
93 }
94}
95
96impl DnsHandle for DnsExchange {
97 type Response = DnsExchangeSend;
98 type Error = ProtoError;
99
100 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
101 DnsExchangeSend {
102 result: self.sender.send(request),
103 _sender: self.sender.clone(), }
105 }
106}
107
108#[must_use = "futures do nothing unless polled"]
110pub struct DnsExchangeSend {
111 result: DnsResponseReceiver,
112 _sender: BufDnsRequestStreamHandle,
113}
114
115impl Stream for DnsExchangeSend {
116 type Item = Result<DnsResponse, ProtoError>;
117
118 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119 self.result.poll_next_unpin(cx)
121 }
122}
123
124#[must_use = "futures do nothing unless polled"]
128pub struct DnsExchangeBackground<S, TE>
129where
130 S: DnsRequestSender + 'static + Send + Unpin,
131{
132 io_stream: S,
133 outbound_messages: Peekable<mpsc::Receiver<OneshotDnsRequest>>,
134 marker: PhantomData<TE>,
135}
136
137impl<S, TE> DnsExchangeBackground<S, TE>
138where
139 S: DnsRequestSender + 'static + Send + Unpin,
140{
141 fn pollable_split(&mut self) -> (&mut S, &mut Peekable<mpsc::Receiver<OneshotDnsRequest>>) {
142 (&mut self.io_stream, &mut self.outbound_messages)
143 }
144}
145
146impl<S, TE> Future for DnsExchangeBackground<S, TE>
147where
148 S: DnsRequestSender + 'static + Send + Unpin,
149 TE: Time + Unpin,
150{
151 type Output = Result<(), ProtoError>;
152
153 #[allow(clippy::unused_unit)]
154 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
155 let (io_stream, outbound_messages) = self.pollable_split();
156 let mut io_stream = Pin::new(io_stream);
157 let mut outbound_messages = Pin::new(outbound_messages);
158
159 loop {
162 match io_stream.as_mut().poll_next(cx) {
164 Poll::Ready(Some(Ok(()))) => (),
166 Poll::Pending => {
167 if io_stream.is_shutdown() {
168 return Poll::Pending;
170 }
171
172 ()
174 } Poll::Ready(None) => {
176 debug!("io_stream is done, shutting down");
177 return Poll::Ready(Ok(()));
180 }
181 Poll::Ready(Some(Err(err))) => {
182 warn!("io_stream hit an error, shutting down: {}", err);
183
184 return Poll::Ready(Err(err));
185 }
186 }
187
188 match outbound_messages.as_mut().poll_next(cx) {
190 Poll::Ready(Some(dns_request)) => {
192 let (dns_request, serial_response): (DnsRequest, _) = dns_request.into_parts();
194
195 match serial_response.send_response(io_stream.send_message(dns_request)) {
200 Ok(()) => (),
201 Err(_) => {
202 warn!("failed to associate send_message response to the sender");
203 }
204 }
205 }
206 Poll::Pending => return Poll::Pending,
208 Poll::Ready(None) => {
209 io_stream.shutdown();
211
212 }
214 }
215
216 }
218 }
219}
220
221pub struct DnsExchangeConnect<F, S, TE>(DnsExchangeConnectInner<F, S, TE>)
230where
231 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
232 S: DnsRequestSender + 'static,
233 TE: Time + Unpin;
234
235impl<F, S, TE> DnsExchangeConnect<F, S, TE>
236where
237 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
238 S: DnsRequestSender + 'static,
239 TE: Time + Unpin,
240{
241 fn connect(
242 connect_future: F,
243 outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
244 sender: BufDnsRequestStreamHandle,
245 ) -> Self {
246 Self(DnsExchangeConnectInner::Connecting {
247 connect_future,
248 outbound_messages: Some(outbound_messages),
249 sender: Some(sender),
250 })
251 }
252}
253
254#[allow(clippy::type_complexity)]
255impl<F, S, TE> Future for DnsExchangeConnect<F, S, TE>
256where
257 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
258 S: DnsRequestSender + 'static + Send + Unpin,
259 TE: Time + Unpin,
260{
261 type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
262
263 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
264 self.0.poll_unpin(cx)
265 }
266}
267
268enum DnsExchangeConnectInner<F, S, TE>
269where
270 F: Future<Output = Result<S, ProtoError>> + 'static + Send,
271 S: DnsRequestSender + 'static + Send,
272 TE: Time + Unpin,
273{
274 Connecting {
275 connect_future: F,
276 outbound_messages: Option<mpsc::Receiver<OneshotDnsRequest>>,
277 sender: Option<BufDnsRequestStreamHandle>,
278 },
279 Connected {
280 exchange: DnsExchange,
281 background: Option<DnsExchangeBackground<S, TE>>,
282 },
283 FailAll {
284 error: ProtoError,
285 outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
286 },
287}
288
289#[allow(clippy::type_complexity)]
290impl<F, S, TE> Future for DnsExchangeConnectInner<F, S, TE>
291where
292 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
293 S: DnsRequestSender + 'static + Send + Unpin,
294 TE: Time + Unpin,
295{
296 type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
297
298 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
299 loop {
300 let next;
301 match *self {
302 Self::Connecting {
303 ref mut connect_future,
304 ref mut outbound_messages,
305 ref mut sender,
306 } => {
307 let connect_future = Pin::new(connect_future);
308 match connect_future.poll(cx) {
309 Poll::Ready(Ok(stream)) => {
310 let (exchange, background) = DnsExchange::from_stream_with_receiver(
313 stream,
314 outbound_messages
315 .take()
316 .expect("cannot poll after complete"),
317 sender.take().expect("cannot poll after complete"),
318 );
319
320 next = Self::Connected {
321 exchange,
322 background: Some(background),
323 };
324 }
325 Poll::Pending => return Poll::Pending,
326 Poll::Ready(Err(error)) => {
327 debug!("stream errored while connecting: {:?}", error);
328 next = Self::FailAll {
329 error,
330 outbound_messages: outbound_messages
331 .take()
332 .expect("cannot poll after complete"),
333 }
334 }
335 };
336 }
337 Self::Connected {
338 ref exchange,
339 ref mut background,
340 } => {
341 let exchange = exchange.clone();
342 let background = background.take().expect("cannot poll after complete");
343
344 return Poll::Ready(Ok((exchange, background)));
345 }
346 Self::FailAll {
347 ref error,
348 ref mut outbound_messages,
349 } => {
350 while let Some(outbound_message) = match outbound_messages.poll_next_unpin(cx) {
351 Poll::Ready(opt) => opt,
352 Poll::Pending => return Poll::Pending,
353 } {
354 outbound_message
356 .into_parts()
357 .1
358 .send_response(error.clone().into())
359 .ok();
360 }
361
362 return Poll::Ready(Err(error.clone()));
363 }
364 }
365
366 *self = next;
367 }
368 }
369}