trust_dns_proto/xfer/
dns_exchange.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This module contains all the types for demuxing DNS oriented streams.
9
10use std::marker::PhantomData;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use futures_channel::mpsc;
15use futures_util::future::{Future, FutureExt};
16use futures_util::stream::{Peekable, Stream, StreamExt};
17use tracing::{debug, warn};
18
19use crate::error::*;
20use crate::xfer::dns_handle::DnsHandle;
21use crate::xfer::DnsResponseReceiver;
22use crate::xfer::{
23    BufDnsRequestStreamHandle, DnsRequest, DnsRequestSender, DnsResponse, OneshotDnsRequest,
24    CHANNEL_BUFFER_SIZE,
25};
26use crate::Time;
27
28/// This is a generic Exchange implemented over multiplexed DNS connection providers.
29///
30/// The underlying `DnsRequestSender` is expected to multiplex any I/O connections. DnsExchange assumes that the underlying stream is responsible for this.
31#[must_use = "futures do nothing unless polled"]
32pub struct DnsExchange {
33    sender: BufDnsRequestStreamHandle,
34}
35
36impl DnsExchange {
37    /// Initializes a TcpStream with an existing tcp::TcpStream.
38    ///
39    /// This is intended for use with a TcpListener and Incoming.
40    ///
41    /// # Arguments
42    ///
43    /// * `stream` - the established IO stream for communication
44    pub fn from_stream<S, TE>(stream: S) -> (Self, DnsExchangeBackground<S, TE>)
45    where
46        S: DnsRequestSender + 'static + Send + Unpin,
47    {
48        let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
49        let message_sender = BufDnsRequestStreamHandle { sender };
50
51        Self::from_stream_with_receiver(stream, outbound_messages, message_sender)
52    }
53
54    /// Wraps a stream where a sender and receiver have already been established
55    pub fn from_stream_with_receiver<S, TE>(
56        stream: S,
57        receiver: mpsc::Receiver<OneshotDnsRequest>,
58        sender: BufDnsRequestStreamHandle,
59    ) -> (Self, DnsExchangeBackground<S, TE>)
60    where
61        S: DnsRequestSender + 'static + Send + Unpin,
62    {
63        let background = DnsExchangeBackground {
64            io_stream: stream,
65            outbound_messages: receiver.peekable(),
66            marker: PhantomData,
67        };
68
69        (Self { sender }, background)
70    }
71
72    /// Returns a future, which itself wraps a future which is awaiting connection.
73    ///
74    /// The connect_future should be lazy.
75    pub fn connect<F, S, TE>(connect_future: F) -> DnsExchangeConnect<F, S, TE>
76    where
77        F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
78        S: DnsRequestSender + 'static + Send + Unpin,
79        TE: Time + Unpin,
80    {
81        let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
82        let message_sender = BufDnsRequestStreamHandle { sender };
83
84        DnsExchangeConnect::connect(connect_future, outbound_messages, message_sender)
85    }
86}
87
88impl Clone for DnsExchange {
89    fn clone(&self) -> Self {
90        Self {
91            sender: self.sender.clone(),
92        }
93    }
94}
95
96impl DnsHandle for DnsExchange {
97    type Response = DnsExchangeSend;
98    type Error = ProtoError;
99
100    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
101        DnsExchangeSend {
102            result: self.sender.send(request),
103            _sender: self.sender.clone(), // TODO: this shouldn't be necessary, currently the presence of Senders is what allows the background to track current users, it generally is dropped right after send, this makes sure that there is at least one active after send
104        }
105    }
106}
107
108/// A Stream that will resolve to Responses after sending the request
109#[must_use = "futures do nothing unless polled"]
110pub struct DnsExchangeSend {
111    result: DnsResponseReceiver,
112    _sender: BufDnsRequestStreamHandle,
113}
114
115impl Stream for DnsExchangeSend {
116    type Item = Result<DnsResponse, ProtoError>;
117
118    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119        // as long as there is no result, poll the exchange
120        self.result.poll_next_unpin(cx)
121    }
122}
123
124/// This background future is responsible for driving all network operations for the DNS protocol.
125///
126/// It must be spawned before any DNS messages are sent.
127#[must_use = "futures do nothing unless polled"]
128pub struct DnsExchangeBackground<S, TE>
129where
130    S: DnsRequestSender + 'static + Send + Unpin,
131{
132    io_stream: S,
133    outbound_messages: Peekable<mpsc::Receiver<OneshotDnsRequest>>,
134    marker: PhantomData<TE>,
135}
136
137impl<S, TE> DnsExchangeBackground<S, TE>
138where
139    S: DnsRequestSender + 'static + Send + Unpin,
140{
141    fn pollable_split(&mut self) -> (&mut S, &mut Peekable<mpsc::Receiver<OneshotDnsRequest>>) {
142        (&mut self.io_stream, &mut self.outbound_messages)
143    }
144}
145
146impl<S, TE> Future for DnsExchangeBackground<S, TE>
147where
148    S: DnsRequestSender + 'static + Send + Unpin,
149    TE: Time + Unpin,
150{
151    type Output = Result<(), ProtoError>;
152
153    #[allow(clippy::unused_unit)]
154    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
155        let (io_stream, outbound_messages) = self.pollable_split();
156        let mut io_stream = Pin::new(io_stream);
157        let mut outbound_messages = Pin::new(outbound_messages);
158
159        // this will not accept incoming data while there is data to send
160        //  makes this self throttling.
161        loop {
162            // poll the underlying stream, to drive it...
163            match io_stream.as_mut().poll_next(cx) {
164                // The stream is ready
165                Poll::Ready(Some(Ok(()))) => (),
166                Poll::Pending => {
167                    if io_stream.is_shutdown() {
168                        // the io_stream is in a shutdown state, we are only waiting for final results...
169                        return Poll::Pending;
170                    }
171
172                    // NotReady and not shutdown, see if there are more messages to send
173                    ()
174                } // underlying stream is complete.
175                Poll::Ready(None) => {
176                    debug!("io_stream is done, shutting down");
177                    // TODO: return shutdown error to anything in the stream?
178
179                    return Poll::Ready(Ok(()));
180                }
181                Poll::Ready(Some(Err(err))) => {
182                    warn!("io_stream hit an error, shutting down: {}", err);
183
184                    return Poll::Ready(Err(err));
185                }
186            }
187
188            // then see if there is more to send
189            match outbound_messages.as_mut().poll_next(cx) {
190                // already handled above, here to make sure the poll() pops the next message
191                Poll::Ready(Some(dns_request)) => {
192                    // if there is no peer, this connection should die...
193                    let (dns_request, serial_response): (DnsRequest, _) = dns_request.into_parts();
194
195                    // Try to forward the `DnsResponseStream` to the requesting task. If we fail,
196                    // it must be because the requesting task has gone away / is no longer
197                    // interested. In that case, we can just log a warning, but there's no need
198                    // to take any more serious measures (such as shutting down this task).
199                    match serial_response.send_response(io_stream.send_message(dns_request)) {
200                        Ok(()) => (),
201                        Err(_) => {
202                            warn!("failed to associate send_message response to the sender");
203                        }
204                    }
205                }
206                // On not ready, this is our time to return...
207                Poll::Pending => return Poll::Pending,
208                Poll::Ready(None) => {
209                    // if there is nothing that can use this connection to send messages, then this is done...
210                    io_stream.shutdown();
211
212                    // now we'll await the stream to shutdown... see io_stream poll above
213                }
214            }
215
216            // else we loop to poll on the outbound_messages
217        }
218    }
219}
220
221/// A wrapper for a future DnsExchange connection.
222///
223/// DnsExchangeConnect is clonable, making it possible to share this if the connection
224///  will be shared across threads.
225///
226/// The future will return a tuple of the DnsExchange (for sending messages) and a background
227///  for running the background tasks. The background is optional as only one thread should run
228///  the background. If returned, it must be spawned before any dns requests will function.
229pub struct DnsExchangeConnect<F, S, TE>(DnsExchangeConnectInner<F, S, TE>)
230where
231    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
232    S: DnsRequestSender + 'static,
233    TE: Time + Unpin;
234
235impl<F, S, TE> DnsExchangeConnect<F, S, TE>
236where
237    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
238    S: DnsRequestSender + 'static,
239    TE: Time + Unpin,
240{
241    fn connect(
242        connect_future: F,
243        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
244        sender: BufDnsRequestStreamHandle,
245    ) -> Self {
246        Self(DnsExchangeConnectInner::Connecting {
247            connect_future,
248            outbound_messages: Some(outbound_messages),
249            sender: Some(sender),
250        })
251    }
252}
253
254#[allow(clippy::type_complexity)]
255impl<F, S, TE> Future for DnsExchangeConnect<F, S, TE>
256where
257    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
258    S: DnsRequestSender + 'static + Send + Unpin,
259    TE: Time + Unpin,
260{
261    type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
262
263    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
264        self.0.poll_unpin(cx)
265    }
266}
267
268enum DnsExchangeConnectInner<F, S, TE>
269where
270    F: Future<Output = Result<S, ProtoError>> + 'static + Send,
271    S: DnsRequestSender + 'static + Send,
272    TE: Time + Unpin,
273{
274    Connecting {
275        connect_future: F,
276        outbound_messages: Option<mpsc::Receiver<OneshotDnsRequest>>,
277        sender: Option<BufDnsRequestStreamHandle>,
278    },
279    Connected {
280        exchange: DnsExchange,
281        background: Option<DnsExchangeBackground<S, TE>>,
282    },
283    FailAll {
284        error: ProtoError,
285        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
286    },
287}
288
289#[allow(clippy::type_complexity)]
290impl<F, S, TE> Future for DnsExchangeConnectInner<F, S, TE>
291where
292    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
293    S: DnsRequestSender + 'static + Send + Unpin,
294    TE: Time + Unpin,
295{
296    type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
297
298    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
299        loop {
300            let next;
301            match *self {
302                Self::Connecting {
303                    ref mut connect_future,
304                    ref mut outbound_messages,
305                    ref mut sender,
306                } => {
307                    let connect_future = Pin::new(connect_future);
308                    match connect_future.poll(cx) {
309                        Poll::Ready(Ok(stream)) => {
310                            //debug!("connection established: {}", stream);
311
312                            let (exchange, background) = DnsExchange::from_stream_with_receiver(
313                                stream,
314                                outbound_messages
315                                    .take()
316                                    .expect("cannot poll after complete"),
317                                sender.take().expect("cannot poll after complete"),
318                            );
319
320                            next = Self::Connected {
321                                exchange,
322                                background: Some(background),
323                            };
324                        }
325                        Poll::Pending => return Poll::Pending,
326                        Poll::Ready(Err(error)) => {
327                            debug!("stream errored while connecting: {:?}", error);
328                            next = Self::FailAll {
329                                error,
330                                outbound_messages: outbound_messages
331                                    .take()
332                                    .expect("cannot poll after complete"),
333                            }
334                        }
335                    };
336                }
337                Self::Connected {
338                    ref exchange,
339                    ref mut background,
340                } => {
341                    let exchange = exchange.clone();
342                    let background = background.take().expect("cannot poll after complete");
343
344                    return Poll::Ready(Ok((exchange, background)));
345                }
346                Self::FailAll {
347                    ref error,
348                    ref mut outbound_messages,
349                } => {
350                    while let Some(outbound_message) = match outbound_messages.poll_next_unpin(cx) {
351                        Poll::Ready(opt) => opt,
352                        Poll::Pending => return Poll::Pending,
353                    } {
354                        // ignoring errors... best effort send...
355                        outbound_message
356                            .into_parts()
357                            .1
358                            .send_response(error.clone().into())
359                            .ok();
360                    }
361
362                    return Poll::Ready(Err(error.clone()));
363                }
364            }
365
366            *self = next;
367        }
368    }
369}