trust_dns_resolver/name_server/
connection_provider.rs1use std::marker::Unpin;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use futures_util::future::{Future, FutureExt};
13use futures_util::ready;
14use futures_util::stream::{Stream, StreamExt};
15#[cfg(feature = "tokio-runtime")]
16use tokio::net::TcpStream as TokioTcpStream;
17#[cfg(all(feature = "dns-over-native-tls", not(feature = "dns-over-rustls")))]
18use tokio_native_tls::TlsStream as TokioTlsStream;
19#[cfg(all(
20 feature = "dns-over-openssl",
21 not(feature = "dns-over-rustls"),
22 not(feature = "dns-over-native-tls")
23))]
24use tokio_openssl::SslStream as TokioTlsStream;
25#[cfg(feature = "dns-over-rustls")]
26use tokio_rustls::client::TlsStream as TokioTlsStream;
27
28#[cfg(feature = "dns-over-https")]
29use proto::https::{HttpsClientConnect, HttpsClientStream};
30#[cfg(feature = "mdns")]
31use proto::multicast::{MdnsClientConnect, MdnsClientStream, MdnsQueryType};
32#[cfg(feature = "dns-over-quic")]
33use proto::quic::{QuicClientConnect, QuicClientStream};
34use proto::{
35 self,
36 error::ProtoError,
37 op::NoopMessageFinalizer,
38 tcp::Connect,
39 tcp::TcpClientConnect,
40 tcp::TcpClientStream,
41 udp::UdpClientConnect,
42 udp::{UdpClientStream, UdpSocket},
43 xfer::{
44 DnsExchange, DnsExchangeConnect, DnsExchangeSend, DnsHandle, DnsMultiplexer,
45 DnsMultiplexerConnect, DnsRequest, DnsResponse,
46 },
47 Time,
48};
49#[cfg(feature = "tokio-runtime")]
50use proto::{iocompat::AsyncIoTokioAsStd, TokioTime};
51
52use crate::config::Protocol;
53use crate::config::{NameServerConfig, ResolverOpts};
54use crate::error::ResolveError;
55
56pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
60 type Conn: DnsHandle<Error = ResolveError> + Clone + Send + Sync + 'static;
62
63 type FutureConn: Future<Output = Result<Self::Conn, ResolveError>> + Send + 'static;
65
66 type Time: Time;
68
69 fn new_connection(&self, config: &NameServerConfig, options: &ResolverOpts)
71 -> Self::FutureConn;
72}
73
74pub trait RuntimeProvider: Clone + 'static {
76 type Handle: Clone + Send + Spawn + Sync + Unpin;
78
79 type Timer: Time + Send + Unpin;
81
82 type Udp: UdpSocket + Send;
84
85 type Tcp: Connect;
87}
88
89pub trait Spawn {
91 fn spawn_bg<F>(&mut self, future: F)
93 where
94 F: Future<Output = Result<(), ProtoError>> + Send + 'static;
95}
96
97#[derive(Clone)]
99pub struct GenericConnectionProvider<R: RuntimeProvider>(R::Handle);
100
101impl<R: RuntimeProvider> GenericConnectionProvider<R> {
102 pub fn new(handle: R::Handle) -> Self {
104 Self(handle)
105 }
106}
107
108impl<R> ConnectionProvider for GenericConnectionProvider<R>
109where
110 R: RuntimeProvider,
111 <R as RuntimeProvider>::Tcp: Connect,
112{
113 type Conn = GenericConnection;
114 type FutureConn = ConnectionFuture<R>;
115 type Time = R::Timer;
116
117 fn new_connection(
120 &self,
121 config: &NameServerConfig,
122 options: &ResolverOpts,
123 ) -> Self::FutureConn {
124 let dns_connect = match config.protocol {
125 Protocol::Udp => {
126 let stream = UdpClientStream::<R::Udp>::with_bind_addr_and_timeout(
127 config.socket_addr,
128 config.bind_addr,
129 options.timeout,
130 );
131 let exchange = DnsExchange::connect(stream);
132 ConnectionConnect::Udp(exchange)
133 }
134 Protocol::Tcp => {
135 let socket_addr = config.socket_addr;
136 let bind_addr = config.bind_addr;
137 let timeout = options.timeout;
138
139 let (stream, handle) = TcpClientStream::<R::Tcp>::with_bind_addr_and_timeout(
140 socket_addr,
141 bind_addr,
142 timeout,
143 );
144 let dns_conn = DnsMultiplexer::with_timeout(
146 stream,
147 handle,
148 timeout,
149 NoopMessageFinalizer::new(),
150 );
151
152 let exchange = DnsExchange::connect(dns_conn);
153 ConnectionConnect::Tcp(exchange)
154 }
155 #[cfg(feature = "dns-over-tls")]
156 Protocol::Tls => {
157 let socket_addr = config.socket_addr;
158 let bind_addr = config.bind_addr;
159 let timeout = options.timeout;
160 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
161 #[cfg(feature = "dns-over-rustls")]
162 let client_config = config.tls_config.clone();
163
164 #[cfg(feature = "dns-over-rustls")]
165 let (stream, handle) = {
166 crate::tls::new_tls_stream::<R>(
167 socket_addr,
168 bind_addr,
169 tls_dns_name,
170 client_config,
171 )
172 };
173 #[cfg(not(feature = "dns-over-rustls"))]
174 let (stream, handle) =
175 { crate::tls::new_tls_stream::<R>(socket_addr, bind_addr, tls_dns_name) };
176
177 let dns_conn = DnsMultiplexer::with_timeout(
178 stream,
179 handle,
180 timeout,
181 NoopMessageFinalizer::new(),
182 );
183
184 let exchange = DnsExchange::connect(dns_conn);
185 ConnectionConnect::Tls(exchange)
186 }
187 #[cfg(feature = "dns-over-https")]
188 Protocol::Https => {
189 let socket_addr = config.socket_addr;
190 let bind_addr = config.bind_addr;
191 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
192 #[cfg(feature = "dns-over-rustls")]
193 let client_config = config.tls_config.clone();
194
195 let exchange = crate::https::new_https_stream::<R>(
196 socket_addr,
197 bind_addr,
198 tls_dns_name,
199 client_config,
200 );
201 ConnectionConnect::Https(exchange)
202 }
203 #[cfg(feature = "dns-over-quic")]
204 Protocol::Quic => {
205 let socket_addr = config.socket_addr;
206 let bind_addr = config.bind_addr;
207 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
208 #[cfg(feature = "dns-over-rustls")]
209 let client_config = config.tls_config.clone();
210
211 let exchange = crate::quic::new_quic_stream(
212 socket_addr,
213 bind_addr,
214 tls_dns_name,
215 client_config,
216 );
217 ConnectionConnect::Quic(exchange)
218 }
219 #[cfg(feature = "mdns")]
220 Protocol::Mdns => {
221 let socket_addr = config.socket_addr;
222 let timeout = options.timeout;
223
224 let (stream, handle) =
225 MdnsClientStream::new(socket_addr, MdnsQueryType::OneShot, None, None, None);
226 let dns_conn = DnsMultiplexer::with_timeout(
228 stream,
229 handle,
230 timeout,
231 NoopMessageFinalizer::new(),
232 );
233
234 let exchange = DnsExchange::connect(dns_conn);
235 ConnectionConnect::Mdns(exchange)
236 }
237 };
238
239 ConnectionFuture {
240 connect: dns_connect,
241 spawner: self.0.clone(),
242 }
243 }
244}
245
/// DNS-over-TCP client stream running on top of a TLS stream.
///
/// `TokioTlsStream` is whichever TLS stream type the enabled TLS feature
/// selects. The inner stream `S` is adapted with `AsyncIoStdAsTokio` and the
/// TLS stream with `AsyncIoTokioAsStd`; presumably these convert between the
/// std/futures-io and Tokio I/O traits (confirm against `proto::iocompat`).
#[cfg(feature = "dns-over-tls")]
type TlsClientStream<S> = TcpClientStream<
    AsyncIoTokioAsStd<TokioTlsStream<proto::iocompat::AsyncIoStdAsTokio<S>>>,
>;
250
251#[allow(clippy::large_enum_variant, clippy::type_complexity)]
253pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
254 Udp(DnsExchangeConnect<UdpClientConnect<R::Udp>, UdpClientStream<R::Udp>, R::Timer>),
255 Tcp(
256 DnsExchangeConnect<
257 DnsMultiplexerConnect<
258 TcpClientConnect<<R as RuntimeProvider>::Tcp>,
259 TcpClientStream<<R as RuntimeProvider>::Tcp>,
260 NoopMessageFinalizer,
261 >,
262 DnsMultiplexer<TcpClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
263 R::Timer,
264 >,
265 ),
266 #[cfg(feature = "dns-over-tls")]
267 Tls(
268 DnsExchangeConnect<
269 DnsMultiplexerConnect<
270 Pin<
271 Box<
272 dyn Future<
273 Output = Result<
274 TlsClientStream<<R as RuntimeProvider>::Tcp>,
275 ProtoError,
276 >,
277 > + Send
278 + 'static,
279 >,
280 >,
281 TlsClientStream<<R as RuntimeProvider>::Tcp>,
282 NoopMessageFinalizer,
283 >,
284 DnsMultiplexer<TlsClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
285 TokioTime,
286 >,
287 ),
288 #[cfg(feature = "dns-over-https")]
289 Https(DnsExchangeConnect<HttpsClientConnect<R::Tcp>, HttpsClientStream, TokioTime>),
290 #[cfg(feature = "dns-over-quic")]
291 Quic(DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>),
292 #[cfg(feature = "mdns")]
293 Mdns(
294 DnsExchangeConnect<
295 DnsMultiplexerConnect<MdnsClientConnect, MdnsClientStream, NoopMessageFinalizer>,
296 DnsMultiplexer<MdnsClientStream, NoopMessageFinalizer>,
297 TokioTime,
298 >,
299 ),
300}
301
302#[must_use = "futures do nothing unless polled"]
304pub struct ConnectionFuture<R: RuntimeProvider> {
305 connect: ConnectionConnect<R>,
306 spawner: R::Handle,
307}
308
309impl<R: RuntimeProvider> Future for ConnectionFuture<R> {
310 type Output = Result<GenericConnection, ResolveError>;
311
312 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
313 Poll::Ready(Ok(match &mut self.connect {
314 ConnectionConnect::Udp(ref mut conn) => {
315 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
316 self.spawner.spawn_bg(bg);
317 GenericConnection(conn)
318 }
319 ConnectionConnect::Tcp(ref mut conn) => {
320 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
321 self.spawner.spawn_bg(bg);
322 GenericConnection(conn)
323 }
324 #[cfg(feature = "dns-over-tls")]
325 ConnectionConnect::Tls(ref mut conn) => {
326 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
327 self.spawner.spawn_bg(bg);
328 GenericConnection(conn)
329 }
330 #[cfg(feature = "dns-over-https")]
331 ConnectionConnect::Https(ref mut conn) => {
332 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
333 self.spawner.spawn_bg(bg);
334 GenericConnection(conn)
335 }
336 #[cfg(feature = "dns-over-quic")]
337 ConnectionConnect::Quic(ref mut conn) => {
338 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
339 self.spawner.spawn_bg(bg);
340 GenericConnection(conn)
341 }
342 #[cfg(feature = "mdns")]
343 ConnectionConnect::Mdns(ref mut conn) => {
344 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
345 self.spawner.spawn_bg(bg);
346 GenericConnection(conn)
347 }
348 }))
349 }
350}
351
352#[derive(Clone)]
354pub struct GenericConnection(DnsExchange);
355
356impl DnsHandle for GenericConnection {
357 type Response = ConnectionResponse;
358 type Error = ResolveError;
359
360 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
361 ConnectionResponse(self.0.send(request))
362 }
363}
364
365#[must_use = "steam do nothing unless polled"]
367pub struct ConnectionResponse(DnsExchangeSend);
368
369impl Stream for ConnectionResponse {
370 type Item = Result<DnsResponse, ResolveError>;
371
372 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
373 Poll::Ready(ready!(self.0.poll_next_unpin(cx)).map(|r| r.map_err(ResolveError::from)))
374 }
375}
376
/// Tokio-backed runtime types for [`GenericConnectionProvider`].
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
#[allow(unreachable_pub)]
pub mod tokio_runtime {
    use super::*;
    use tokio::net::UdpSocket as TokioUdpSocket;

    /// [`Spawn`] implementation that hands futures to `tokio::spawn`.
    #[derive(Clone, Copy)]
    pub struct TokioHandle;

    impl Spawn for TokioHandle {
        fn spawn_bg<F>(&mut self, future: F)
        where
            F: Future<Output = Result<(), ProtoError>> + Send + 'static,
        {
            // The JoinHandle is dropped right away, detaching the task; its
            // result is never observed.
            drop(tokio::spawn(future));
        }
    }

    /// [`RuntimeProvider`] built from Tokio sockets and [`TokioTime`].
    #[derive(Clone, Copy)]
    pub struct TokioRuntime;

    impl RuntimeProvider for TokioRuntime {
        type Handle = TokioHandle;
        type Timer = TokioTime;
        type Udp = TokioUdpSocket;
        type Tcp = AsyncIoTokioAsStd<TokioTcpStream>;
    }

    /// Connection type produced by [`TokioConnectionProvider`].
    pub type TokioConnection = GenericConnection;

    /// Connection provider running on the Tokio runtime.
    pub type TokioConnectionProvider = GenericConnectionProvider<TokioRuntime>;
}