trust_dns_resolver/name_server/connection_provider.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::marker::Unpin;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use futures_util::future::{Future, FutureExt};
13use futures_util::ready;
14use futures_util::stream::{Stream, StreamExt};
15#[cfg(feature = "tokio-runtime")]
16use tokio::net::TcpStream as TokioTcpStream;
17#[cfg(all(feature = "dns-over-native-tls", not(feature = "dns-over-rustls")))]
18use tokio_native_tls::TlsStream as TokioTlsStream;
19#[cfg(all(
20    feature = "dns-over-openssl",
21    not(feature = "dns-over-rustls"),
22    not(feature = "dns-over-native-tls")
23))]
24use tokio_openssl::SslStream as TokioTlsStream;
25#[cfg(feature = "dns-over-rustls")]
26use tokio_rustls::client::TlsStream as TokioTlsStream;
27
28#[cfg(feature = "dns-over-https")]
29use proto::https::{HttpsClientConnect, HttpsClientStream};
30#[cfg(feature = "mdns")]
31use proto::multicast::{MdnsClientConnect, MdnsClientStream, MdnsQueryType};
32#[cfg(feature = "dns-over-quic")]
33use proto::quic::{QuicClientConnect, QuicClientStream};
34use proto::{
35    self,
36    error::ProtoError,
37    op::NoopMessageFinalizer,
38    tcp::Connect,
39    tcp::TcpClientConnect,
40    tcp::TcpClientStream,
41    udp::UdpClientConnect,
42    udp::{UdpClientStream, UdpSocket},
43    xfer::{
44        DnsExchange, DnsExchangeConnect, DnsExchangeSend, DnsHandle, DnsMultiplexer,
45        DnsMultiplexerConnect, DnsRequest, DnsResponse,
46    },
47    Time,
48};
49#[cfg(feature = "tokio-runtime")]
50use proto::{iocompat::AsyncIoTokioAsStd, TokioTime};
51
52use crate::config::Protocol;
53use crate::config::{NameServerConfig, ResolverOpts};
54use crate::error::ResolveError;
55
56/// A type to allow for custom ConnectionProviders. Needed mainly for mocking purposes.
57///
58/// ConnectionProvider is responsible for spawning any background tasks as necessary.
59pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
60    /// The handle to the connect for sending DNS requests.
61    type Conn: DnsHandle<Error = ResolveError> + Clone + Send + Sync + 'static;
62
63    /// Ths future is responsible for spawning any background tasks as necessary
64    type FutureConn: Future<Output = Result<Self::Conn, ResolveError>> + Send + 'static;
65
66    /// The type used to set up timeout futures
67    type Time: Time;
68
69    /// The returned handle should
70    fn new_connection(&self, config: &NameServerConfig, options: &ResolverOpts)
71        -> Self::FutureConn;
72}
73
74/// RuntimeProvider defines which async runtime that handles IO and timers.
75pub trait RuntimeProvider: Clone + 'static {
76    /// Handle to the executor;
77    type Handle: Clone + Send + Spawn + Sync + Unpin;
78
79    /// Timer
80    type Timer: Time + Send + Unpin;
81
82    /// UdpSocket
83    type Udp: UdpSocket + Send;
84
85    /// TcpStream
86    type Tcp: Connect;
87}
88
89/// A type defines the Handle which can spawn future.
90pub trait Spawn {
91    /// Spawn a future in the background
92    fn spawn_bg<F>(&mut self, future: F)
93    where
94        F: Future<Output = Result<(), ProtoError>> + Send + 'static;
95}
96
97/// Standard connection implements the default mechanism for creating new Connections
98#[derive(Clone)]
99pub struct GenericConnectionProvider<R: RuntimeProvider>(R::Handle);
100
101impl<R: RuntimeProvider> GenericConnectionProvider<R> {
102    /// construct a new Connection provider based on the Runtime Handle
103    pub fn new(handle: R::Handle) -> Self {
104        Self(handle)
105    }
106}
107
108impl<R> ConnectionProvider for GenericConnectionProvider<R>
109where
110    R: RuntimeProvider,
111    <R as RuntimeProvider>::Tcp: Connect,
112{
113    type Conn = GenericConnection;
114    type FutureConn = ConnectionFuture<R>;
115    type Time = R::Timer;
116
117    /// Constructs an initial constructor for the ConnectionHandle to be used to establish a
118    ///   future connection.
119    fn new_connection(
120        &self,
121        config: &NameServerConfig,
122        options: &ResolverOpts,
123    ) -> Self::FutureConn {
124        let dns_connect = match config.protocol {
125            Protocol::Udp => {
126                let stream = UdpClientStream::<R::Udp>::with_bind_addr_and_timeout(
127                    config.socket_addr,
128                    config.bind_addr,
129                    options.timeout,
130                );
131                let exchange = DnsExchange::connect(stream);
132                ConnectionConnect::Udp(exchange)
133            }
134            Protocol::Tcp => {
135                let socket_addr = config.socket_addr;
136                let bind_addr = config.bind_addr;
137                let timeout = options.timeout;
138
139                let (stream, handle) = TcpClientStream::<R::Tcp>::with_bind_addr_and_timeout(
140                    socket_addr,
141                    bind_addr,
142                    timeout,
143                );
144                // TODO: need config for Signer...
145                let dns_conn = DnsMultiplexer::with_timeout(
146                    stream,
147                    handle,
148                    timeout,
149                    NoopMessageFinalizer::new(),
150                );
151
152                let exchange = DnsExchange::connect(dns_conn);
153                ConnectionConnect::Tcp(exchange)
154            }
155            #[cfg(feature = "dns-over-tls")]
156            Protocol::Tls => {
157                let socket_addr = config.socket_addr;
158                let bind_addr = config.bind_addr;
159                let timeout = options.timeout;
160                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
161                #[cfg(feature = "dns-over-rustls")]
162                let client_config = config.tls_config.clone();
163
164                #[cfg(feature = "dns-over-rustls")]
165                let (stream, handle) = {
166                    crate::tls::new_tls_stream::<R>(
167                        socket_addr,
168                        bind_addr,
169                        tls_dns_name,
170                        client_config,
171                    )
172                };
173                #[cfg(not(feature = "dns-over-rustls"))]
174                let (stream, handle) =
175                    { crate::tls::new_tls_stream::<R>(socket_addr, bind_addr, tls_dns_name) };
176
177                let dns_conn = DnsMultiplexer::with_timeout(
178                    stream,
179                    handle,
180                    timeout,
181                    NoopMessageFinalizer::new(),
182                );
183
184                let exchange = DnsExchange::connect(dns_conn);
185                ConnectionConnect::Tls(exchange)
186            }
187            #[cfg(feature = "dns-over-https")]
188            Protocol::Https => {
189                let socket_addr = config.socket_addr;
190                let bind_addr = config.bind_addr;
191                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
192                #[cfg(feature = "dns-over-rustls")]
193                let client_config = config.tls_config.clone();
194
195                let exchange = crate::https::new_https_stream::<R>(
196                    socket_addr,
197                    bind_addr,
198                    tls_dns_name,
199                    client_config,
200                );
201                ConnectionConnect::Https(exchange)
202            }
203            #[cfg(feature = "dns-over-quic")]
204            Protocol::Quic => {
205                let socket_addr = config.socket_addr;
206                let bind_addr = config.bind_addr;
207                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
208                #[cfg(feature = "dns-over-rustls")]
209                let client_config = config.tls_config.clone();
210
211                let exchange = crate::quic::new_quic_stream(
212                    socket_addr,
213                    bind_addr,
214                    tls_dns_name,
215                    client_config,
216                );
217                ConnectionConnect::Quic(exchange)
218            }
219            #[cfg(feature = "mdns")]
220            Protocol::Mdns => {
221                let socket_addr = config.socket_addr;
222                let timeout = options.timeout;
223
224                let (stream, handle) =
225                    MdnsClientStream::new(socket_addr, MdnsQueryType::OneShot, None, None, None);
226                // TODO: need config for Signer...
227                let dns_conn = DnsMultiplexer::with_timeout(
228                    stream,
229                    handle,
230                    timeout,
231                    NoopMessageFinalizer::new(),
232                );
233
234                let exchange = DnsExchange::connect(dns_conn);
235                ConnectionConnect::Mdns(exchange)
236            }
237        };
238
239        ConnectionFuture {
240            connect: dns_connect,
241            spawner: self.0.clone(),
242        }
243    }
244}
245
/// TCP client stream that carries a TLS session over the runtime's TCP type `S`, with `S`
/// wrapped in the `proto::iocompat` adapters the TLS stream type requires.
#[cfg(feature = "dns-over-tls")]
type TlsClientStream<S> =
    TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<proto::iocompat::AsyncIoStdAsTokio<S>>>>;
250
251/// The variants of all supported connections for the Resolver
252#[allow(clippy::large_enum_variant, clippy::type_complexity)]
253pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
254    Udp(DnsExchangeConnect<UdpClientConnect<R::Udp>, UdpClientStream<R::Udp>, R::Timer>),
255    Tcp(
256        DnsExchangeConnect<
257            DnsMultiplexerConnect<
258                TcpClientConnect<<R as RuntimeProvider>::Tcp>,
259                TcpClientStream<<R as RuntimeProvider>::Tcp>,
260                NoopMessageFinalizer,
261            >,
262            DnsMultiplexer<TcpClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
263            R::Timer,
264        >,
265    ),
266    #[cfg(feature = "dns-over-tls")]
267    Tls(
268        DnsExchangeConnect<
269            DnsMultiplexerConnect<
270                Pin<
271                    Box<
272                        dyn Future<
273                                Output = Result<
274                                    TlsClientStream<<R as RuntimeProvider>::Tcp>,
275                                    ProtoError,
276                                >,
277                            > + Send
278                            + 'static,
279                    >,
280                >,
281                TlsClientStream<<R as RuntimeProvider>::Tcp>,
282                NoopMessageFinalizer,
283            >,
284            DnsMultiplexer<TlsClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
285            TokioTime,
286        >,
287    ),
288    #[cfg(feature = "dns-over-https")]
289    Https(DnsExchangeConnect<HttpsClientConnect<R::Tcp>, HttpsClientStream, TokioTime>),
290    #[cfg(feature = "dns-over-quic")]
291    Quic(DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>),
292    #[cfg(feature = "mdns")]
293    Mdns(
294        DnsExchangeConnect<
295            DnsMultiplexerConnect<MdnsClientConnect, MdnsClientStream, NoopMessageFinalizer>,
296            DnsMultiplexer<MdnsClientStream, NoopMessageFinalizer>,
297            TokioTime,
298        >,
299    ),
300}
301
302/// Resolves to a new Connection
303#[must_use = "futures do nothing unless polled"]
304pub struct ConnectionFuture<R: RuntimeProvider> {
305    connect: ConnectionConnect<R>,
306    spawner: R::Handle,
307}
308
309impl<R: RuntimeProvider> Future for ConnectionFuture<R> {
310    type Output = Result<GenericConnection, ResolveError>;
311
312    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
313        Poll::Ready(Ok(match &mut self.connect {
314            ConnectionConnect::Udp(ref mut conn) => {
315                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
316                self.spawner.spawn_bg(bg);
317                GenericConnection(conn)
318            }
319            ConnectionConnect::Tcp(ref mut conn) => {
320                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
321                self.spawner.spawn_bg(bg);
322                GenericConnection(conn)
323            }
324            #[cfg(feature = "dns-over-tls")]
325            ConnectionConnect::Tls(ref mut conn) => {
326                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
327                self.spawner.spawn_bg(bg);
328                GenericConnection(conn)
329            }
330            #[cfg(feature = "dns-over-https")]
331            ConnectionConnect::Https(ref mut conn) => {
332                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
333                self.spawner.spawn_bg(bg);
334                GenericConnection(conn)
335            }
336            #[cfg(feature = "dns-over-quic")]
337            ConnectionConnect::Quic(ref mut conn) => {
338                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
339                self.spawner.spawn_bg(bg);
340                GenericConnection(conn)
341            }
342            #[cfg(feature = "mdns")]
343            ConnectionConnect::Mdns(ref mut conn) => {
344                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
345                self.spawner.spawn_bg(bg);
346                GenericConnection(conn)
347            }
348        }))
349    }
350}
351
352/// A connected DNS handle
353#[derive(Clone)]
354pub struct GenericConnection(DnsExchange);
355
356impl DnsHandle for GenericConnection {
357    type Response = ConnectionResponse;
358    type Error = ResolveError;
359
360    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
361        ConnectionResponse(self.0.send(request))
362    }
363}
364
365/// A stream of response to a DNS request.
366#[must_use = "steam do nothing unless polled"]
367pub struct ConnectionResponse(DnsExchangeSend);
368
369impl Stream for ConnectionResponse {
370    type Item = Result<DnsResponse, ResolveError>;
371
372    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
373        Poll::Ready(ready!(self.0.poll_next_unpin(cx)).map(|r| r.map_err(ResolveError::from)))
374    }
375}
376
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
#[allow(unreachable_pub)]
pub mod tokio_runtime {
    use super::*;
    use tokio::net::UdpSocket as TokioUdpSocket;

    /// A handle to the Tokio runtime; it spawns background tasks with `tokio::spawn`.
    #[derive(Clone, Copy)]
    pub struct TokioHandle;

    impl Spawn for TokioHandle {
        fn spawn_bg<F>(&mut self, future: F)
        where
            F: Future<Output = Result<(), ProtoError>> + Send + 'static,
        {
            // Dropping the JoinHandle detaches the task; its result is never observed.
            drop(tokio::spawn(future));
        }
    }

    /// The Tokio runtime, supplying Tokio's sockets and timer to the resolver.
    #[derive(Clone, Copy)]
    pub struct TokioRuntime;

    impl RuntimeProvider for TokioRuntime {
        type Handle = TokioHandle;
        type Timer = TokioTime;
        type Udp = TokioUdpSocket;
        type Tcp = AsyncIoTokioAsStd<TokioTcpStream>;
    }

    /// Connection type produced by [`TokioConnectionProvider`].
    pub type TokioConnection = GenericConnection;

    /// Connection provider backed by [`TokioRuntime`].
    pub type TokioConnectionProvider = GenericConnectionProvider<TokioRuntime>;
}