1use std::cmp::Ordering;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
14use futures_util::future::FutureExt;
15use futures_util::stream::{once, FuturesUnordered, Stream, StreamExt};
16use smallvec::SmallVec;
17
18use proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
19use proto::Time;
20use tracing::debug;
21
22use crate::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts, ServerOrderingStrategy};
23use crate::error::{ResolveError, ResolveErrorKind};
24#[cfg(feature = "mdns")]
25use crate::name_server;
26use crate::name_server::{ConnectionProvider, NameServer};
27#[cfg(test)]
28#[cfg(feature = "tokio-runtime")]
29use crate::name_server::{TokioConnection, TokioConnectionProvider, TokioHandle};
30
31#[derive(Clone)]
35pub struct NameServerPool<
36 C: DnsHandle<Error = ResolveError> + Send + Sync + 'static,
37 P: ConnectionProvider<Conn = C> + Send + 'static,
38> {
39 datagram_conns: Arc<[NameServer<C, P>]>, stream_conns: Arc<[NameServer<C, P>]>, #[cfg(feature = "mdns")]
43 mdns_conns: NameServer<C, P>, options: ResolverOpts,
45}
46
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
impl NameServerPool<TokioConnection, TokioConnectionProvider> {
    /// Test-only convenience constructor: wraps `runtime` in a
    /// `TokioConnectionProvider` and delegates to `from_config_with_provider`.
    pub(crate) fn tokio_from_config(
        config: &ResolverConfig,
        options: &ResolverOpts,
        runtime: TokioHandle,
    ) -> Self {
        let provider = TokioConnectionProvider::new(runtime);
        Self::from_config_with_provider(config, options, provider)
    }
}
58
59impl<C, P> NameServerPool<C, P>
60where
61 C: DnsHandle<Error = ResolveError> + Sync + 'static,
62 P: ConnectionProvider<Conn = C> + 'static,
63{
64 pub(crate) fn from_config_with_provider(
65 config: &ResolverConfig,
66 options: &ResolverOpts,
67 conn_provider: P,
68 ) -> Self {
69 let datagram_conns: Vec<NameServer<C, P>> = config
70 .name_servers()
71 .iter()
72 .filter(|ns_config| ns_config.protocol.is_datagram())
73 .map(|ns_config| {
74 #[cfg(feature = "dns-over-rustls")]
75 let ns_config = {
76 let mut ns_config = ns_config.clone();
77 ns_config.tls_config = config.client_config().clone();
78 ns_config
79 };
80 #[cfg(not(feature = "dns-over-rustls"))]
81 let ns_config = { ns_config.clone() };
82
83 NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
84 })
85 .collect();
86
87 let stream_conns: Vec<NameServer<C, P>> = config
88 .name_servers()
89 .iter()
90 .filter(|ns_config| ns_config.protocol.is_stream())
91 .map(|ns_config| {
92 #[cfg(feature = "dns-over-rustls")]
93 let ns_config = {
94 let mut ns_config = ns_config.clone();
95 ns_config.tls_config = config.client_config().clone();
96 ns_config
97 };
98 #[cfg(not(feature = "dns-over-rustls"))]
99 let ns_config = { ns_config.clone() };
100
101 NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
102 })
103 .collect();
104
105 Self {
106 datagram_conns: Arc::from(datagram_conns),
107 stream_conns: Arc::from(stream_conns),
108 #[cfg(feature = "mdns")]
109 mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
110 options: *options,
111 }
112 }
113
114 pub fn from_config(
116 name_servers: NameServerConfigGroup,
117 options: &ResolverOpts,
118 conn_provider: P,
119 ) -> Self {
120 let map_config_to_ns = |ns_config| {
121 NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
122 };
123
124 let (datagram, stream): (Vec<_>, Vec<_>) = name_servers
125 .into_inner()
126 .into_iter()
127 .partition(|ns| ns.protocol.is_datagram());
128
129 let datagram_conns: Vec<_> = datagram.into_iter().map(map_config_to_ns).collect();
130 let stream_conns: Vec<_> = stream.into_iter().map(map_config_to_ns).collect();
131
132 Self {
133 datagram_conns: Arc::from(datagram_conns),
134 stream_conns: Arc::from(stream_conns),
135 #[cfg(feature = "mdns")]
136 mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
137 options: *options,
138 }
139 }
140
141 #[doc(hidden)]
142 #[cfg(not(feature = "mdns"))]
143 pub fn from_nameservers(
144 options: &ResolverOpts,
145 datagram_conns: Vec<NameServer<C, P>>,
146 stream_conns: Vec<NameServer<C, P>>,
147 ) -> Self {
148 Self {
149 datagram_conns: Arc::from(datagram_conns),
150 stream_conns: Arc::from(stream_conns),
151 options: *options,
152 }
153 }
154
155 #[doc(hidden)]
156 #[cfg(feature = "mdns")]
157 pub fn from_nameservers(
158 options: &ResolverOpts,
159 datagram_conns: Vec<NameServer<C, P>>,
160 stream_conns: Vec<NameServer<C, P>>,
161 mdns_conns: NameServer<C, P>,
162 ) -> Self {
163 NameServerPool {
164 datagram_conns: Arc::from(datagram_conns),
165 stream_conns: Arc::from(stream_conns),
166 mdns_conns,
167 options: *options,
168 }
169 }
170
171 #[cfg(test)]
172 #[cfg(not(feature = "mdns"))]
173 #[allow(dead_code)]
174 fn from_nameservers_test(
175 options: &ResolverOpts,
176 datagram_conns: Arc<[NameServer<C, P>]>,
177 stream_conns: Arc<[NameServer<C, P>]>,
178 ) -> Self {
179 Self {
180 datagram_conns,
181 stream_conns,
182 options: *options,
183 }
184 }
185
186 #[cfg(test)]
187 #[cfg(feature = "mdns")]
188 fn from_nameservers_test(
189 options: &ResolverOpts,
190 datagram_conns: Arc<[NameServer<C, P>]>,
191 stream_conns: Arc<[NameServer<C, P>]>,
192 mdns_conns: NameServer<C, P>,
193 ) -> Self {
194 NameServerPool {
195 datagram_conns,
196 stream_conns,
197 mdns_conns,
198 options: *options,
199 conn_provider,
200 }
201 }
202
203 async fn try_send(
204 opts: ResolverOpts,
205 conns: Arc<[NameServer<C, P>]>,
206 request: DnsRequest,
207 ) -> Result<DnsResponse, ResolveError> {
208 let mut conns: Vec<NameServer<C, P>> = conns.to_vec();
209
210 match opts.server_ordering_strategy {
211 ServerOrderingStrategy::QueryStatistics => conns.sort_unstable(),
215 ServerOrderingStrategy::UserProvidedOrder => {}
216 }
217 let request_loop = request.clone();
218
219 parallel_conn_loop(conns, request_loop, opts).await
220 }
221}
222
223impl<C, P> DnsHandle for NameServerPool<C, P>
224where
225 C: DnsHandle<Error = ResolveError> + Sync + 'static,
226 P: ConnectionProvider<Conn = C> + 'static,
227{
228 type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
229 type Error = ResolveError;
230
231 fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
232 let opts = self.options;
233 let request = request.into();
234 let datagram_conns = Arc::clone(&self.datagram_conns);
235 let stream_conns = Arc::clone(&self.stream_conns);
236 let tcp_message = request.clone();
238
239 #[cfg(feature = "mdns")]
241 let mdns = mdns::maybe_local(&mut self.mdns_conns, request);
242
243 #[cfg(not(feature = "mdns"))]
245 let mdns = Local::NotMdns(request);
246
247 if mdns.is_local() {
249 return mdns.take_stream();
250 }
251
252 let request = mdns.take_request();
256 Box::pin(once(async move {
257 debug!("sending request: {:?}", request.queries());
258
259 let udp_res = match Self::try_send(opts, datagram_conns, request).await {
261 Ok(response) if response.truncated() => {
262 debug!("truncated response received, retrying over TCP");
263 Ok(response)
264 }
265 Err(e) if opts.try_tcp_on_error || e.is_no_connections() => {
266 debug!("error from UDP, retrying over TCP: {}", e);
267 Err(e)
268 }
269 result => return result,
270 };
271
272 if stream_conns.is_empty() {
273 debug!("no TCP connections available");
274 return udp_res;
275 }
276
277 let tcp_res = Self::try_send(opts, stream_conns, tcp_message).await;
280
281 let tcp_err = match tcp_res {
282 res @ Ok(..) => return res,
283 Err(e) => e,
284 };
285
286 let udp_err = match udp_res {
288 Ok(response) => return Ok(response),
289 Err(e) => e,
290 };
291
292 match udp_err.cmp_specificity(&tcp_err) {
293 Ordering::Greater => Err(udp_err),
294 _ => Err(tcp_err),
295 }
296 }))
297 }
298}
299
300async fn parallel_conn_loop<C, P>(
303 mut conns: Vec<NameServer<C, P>>,
304 request: DnsRequest,
305 opts: ResolverOpts,
306) -> Result<DnsResponse, ResolveError>
307where
308 C: DnsHandle<Error = ResolveError> + 'static,
309 P: ConnectionProvider<Conn = C> + 'static,
310{
311 let mut err = ResolveError::no_connections();
312 let mut backoff = Duration::from_millis(20);
323 let mut busy = SmallVec::<[NameServer<C, P>; 2]>::new();
324
325 loop {
326 let request_cont = request.clone();
327
328 let mut par_conns = SmallVec::<[NameServer<C, P>; 2]>::new();
330 let count = conns.len().min(opts.num_concurrent_reqs.max(1));
331 for conn in conns.drain(..count) {
332 par_conns.push(conn);
333 }
334
335 if par_conns.is_empty() {
336 if !busy.is_empty() && backoff < Duration::from_millis(300) {
337 P::Time::delay_for(backoff).await;
338 conns.extend(busy.drain(..));
339 backoff *= 2;
340 continue;
341 }
342 return Err(err);
343 }
344
345 let mut requests = par_conns
346 .into_iter()
347 .map(move |mut conn| {
348 conn.send(request_cont.clone())
349 .first_answer()
350 .map(|result| result.map_err(|e| (conn, e)))
351 })
352 .collect::<FuturesUnordered<_>>();
353
354 while let Some(result) = requests.next().await {
355 let (conn, e) = match result {
356 Ok(sent) => return Ok(sent),
357 Err((conn, e)) => (conn, e),
358 };
359
360 match e.kind() {
361 ResolveErrorKind::NoRecordsFound { trusted, .. } if *trusted => {
362 return Err(e);
363 }
364 ResolveErrorKind::Proto(e) if e.is_busy() => {
365 busy.push(conn);
366 }
367 _ if err.cmp_specificity(&e) == Ordering::Less => {
368 err = e;
369 }
370 _ => {}
371 }
372 }
373 }
374}
375
#[cfg(feature = "mdns")]
mod mdns {
    use super::*;

    use proto::rr::domain::usage;
    use proto::DnsHandle;

    /// Routes `request` to the mDNS name server when it targets the local zone.
    ///
    /// If any query name in `request` falls inside `usage::LOCAL`'s zone, the
    /// request is sent to `name_server` and the response stream is returned as
    /// `Local::ResolveStream`. Otherwise the request is handed back unchanged
    /// as `Local::NotMdns`.
    pub(crate) fn maybe_local<C, P>(
        name_server: &mut NameServer<C, P>,
        request: DnsRequest,
    ) -> Local
    where
        C: DnsHandle<Error = ResolveError> + 'static,
        P: ConnectionProvider<Conn = C> + 'static,
    {
        let targets_local_zone = request
            .queries()
            .iter()
            .any(|query| usage::LOCAL.name().zone_of(query.name()));

        if targets_local_zone {
            Local::ResolveStream(name_server.send(request))
        } else {
            Local::NotMdns(request)
        }
    }
}
404
405pub(crate) enum Local {
406 #[allow(dead_code)]
407 ResolveStream(Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>),
408 NotMdns(DnsRequest),
409}
410
411impl Local {
412 fn is_local(&self) -> bool {
413 matches!(*self, Self::ResolveStream(..))
414 }
415
416 fn take_stream(self) -> Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>> {
422 match self {
423 Self::ResolveStream(future) => future,
424 _ => panic!("non Local queries have no future, see take_message()"),
425 }
426 }
427
428 fn take_request(self) -> DnsRequest {
434 match self {
435 Self::NotMdns(request) => request,
436 _ => panic!("Local queries must be polled, see take_future()"),
437 }
438 }
439}
440
441impl Stream for Local {
442 type Item = Result<DnsResponse, ResolveError>;
443
444 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
445 match *self {
446 Self::ResolveStream(ref mut ns) => ns.as_mut().poll_next(cx),
447 Self::NotMdns(..) => panic!("Local queries that are not mDNS should not be polled"), }
450 }
451}
452
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
    use std::str::FromStr;

    use tokio::runtime::Runtime;

    use proto::op::Query;
    use proto::rr::{Name, RecordType};
    use proto::xfer::{DnsHandle, DnsRequestOptions};
    use trust_dns_proto::rr::RData;

    use super::*;
    use crate::config::NameServerConfig;
    use crate::config::Protocol;

    /// Builds a plain name server configuration for `ip:port` over `protocol`,
    /// with no TLS name, no bind address and untrusted NX responses.
    fn name_server_config(ip: Ipv4Addr, port: u16, protocol: Protocol) -> NameServerConfig {
        NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(ip), port),
            protocol,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        }
    }

    /// Pool with one address on 127.0.0.252:253 and one real resolver
    /// (8.8.8.8:53): the first two lookups are expected to fail, the next ten
    /// to succeed. Ignored by default.
    #[ignore]
    #[test]
    fn test_failed_then_success_pool() {
        let config1 = name_server_config(Ipv4Addr::new(127, 0, 0, 252), 253, Protocol::Udp);
        let config2 = name_server_config(Ipv4Addr::new(8, 8, 8, 8), 53, Protocol::Udp);

        let mut resolver_config = ResolverConfig::new();
        resolver_config.add_name_server(config1);
        resolver_config.add_name_server(config2);

        let io_loop = Runtime::new().unwrap();
        let mut pool = NameServerPool::<_, TokioConnectionProvider>::tokio_from_config(
            &resolver_config,
            &ResolverOpts::default(),
            TokioHandle,
        );

        let name = Name::parse("www.example.com.", None).unwrap();

        // One blocking A lookup for `name` through the pool.
        let mut lookup_a = || {
            io_loop.block_on(
                pool.lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
        };

        for i in 0..2 {
            assert!(lookup_a().is_err(), "iter: {}", i);
        }

        for i in 0..10 {
            assert!(lookup_a().is_ok(), "iter: {}", i);
        }
    }

    /// Checks that the pool and the test share the same `NameServer`
    /// instances: after a lookup through the pool, the test's own handle must
    /// report the server as connected.
    #[test]
    fn test_multi_use_conns() {
        let io_loop = Runtime::new().unwrap();
        let conn_provider = TokioConnectionProvider::new(TokioHandle);

        let tcp = name_server_config(Ipv4Addr::new(8, 8, 8, 8), 53, Protocol::Tcp);

        let opts = ResolverOpts {
            try_tcp_on_error: true,
            ..ResolverOpts::default()
        };
        let name_server = NameServer::new_with_provider(tcp, opts, conn_provider);
        let name_servers: Arc<[_]> = Arc::from([name_server]);

        let mut pool = NameServerPool::from_nameservers_test(
            &opts,
            Arc::from([]),
            Arc::clone(&name_servers),
            // Same three-argument form as the pool constructors use.
            #[cfg(feature = "mdns")]
            name_server::mdns_nameserver(opts, TokioConnectionProvider::new(TokioHandle), false),
        );

        let name = Name::from_str("www.example.com.").unwrap();

        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_a)
                .expect("no a record available"),
            Ipv4Addr::new(93, 184, 216, 34)
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );

        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name, RecordType::AAAA),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_aaaa)
                .expect("no aaaa record available"),
            Ipv6Addr::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );
    }
}