1use std::cmp::Ordering;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
14use futures_util::future::FutureExt;
15use futures_util::stream::{once, FuturesUnordered, Stream, StreamExt};
16use smallvec::SmallVec;
17
18use proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
19use proto::Time;
20use tracing::debug;
21
22use crate::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts, ServerOrderingStrategy};
23use crate::error::{ResolveError, ResolveErrorKind};
24#[cfg(feature = "mdns")]
25use crate::name_server;
26use crate::name_server::{ConnectionProvider, NameServer, NameServerStats};
27#[cfg(test)]
28#[cfg(feature = "tokio-runtime")]
29use crate::name_server::{TokioConnection, TokioConnectionProvider, TokioHandle};
30
31#[derive(Clone)]
35pub struct NameServerPool<
36 C: DnsHandle<Error = ResolveError> + Send + Sync + 'static,
37 P: ConnectionProvider<Conn = C> + Send + 'static,
38> {
39 datagram_conns: Arc<[NameServer<C, P>]>, stream_conns: Arc<[NameServer<C, P>]>, #[cfg(feature = "mdns")]
43 mdns_conns: NameServer<C, P>, options: ResolverOpts,
45}
46
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
impl NameServerPool<TokioConnection, TokioConnectionProvider> {
    /// Test helper: builds a pool from `config` and `options`, backed by a
    /// [`TokioConnectionProvider`] created from `runtime`.
    pub(crate) fn tokio_from_config(
        config: &ResolverConfig,
        options: &ResolverOpts,
        runtime: TokioHandle,
    ) -> Self {
        let provider = TokioConnectionProvider::new(runtime);
        Self::from_config_with_provider(config, options, provider)
    }
}
58
59impl<C, P> NameServerPool<C, P>
60where
61 C: DnsHandle<Error = ResolveError> + Sync + 'static,
62 P: ConnectionProvider<Conn = C> + 'static,
63{
64 pub(crate) fn from_config_with_provider(
65 config: &ResolverConfig,
66 options: &ResolverOpts,
67 conn_provider: P,
68 ) -> Self {
69 let datagram_conns: Vec<NameServer<C, P>> = config
70 .name_servers()
71 .iter()
72 .filter(|ns_config| ns_config.protocol.is_datagram())
73 .map(|ns_config| {
74 #[cfg(feature = "dns-over-rustls")]
75 let ns_config = {
76 let mut ns_config = ns_config.clone();
77 ns_config.tls_config = config.client_config().clone();
78 ns_config
79 };
80 #[cfg(not(feature = "dns-over-rustls"))]
81 let ns_config = { ns_config.clone() };
82
83 NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
84 })
85 .collect();
86
87 let stream_conns: Vec<NameServer<C, P>> = config
88 .name_servers()
89 .iter()
90 .filter(|ns_config| ns_config.protocol.is_stream())
91 .map(|ns_config| {
92 #[cfg(feature = "dns-over-rustls")]
93 let ns_config = {
94 let mut ns_config = ns_config.clone();
95 ns_config.tls_config = config.client_config().clone();
96 ns_config
97 };
98 #[cfg(not(feature = "dns-over-rustls"))]
99 let ns_config = { ns_config.clone() };
100
101 NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
102 })
103 .collect();
104
105 Self {
106 datagram_conns: Arc::from(datagram_conns),
107 stream_conns: Arc::from(stream_conns),
108 #[cfg(feature = "mdns")]
109 mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
110 options: *options,
111 }
112 }
113
114 pub fn from_config(
116 name_servers: NameServerConfigGroup,
117 options: &ResolverOpts,
118 conn_provider: P,
119 ) -> Self {
120 let map_config_to_ns = |ns_config| {
121 NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
122 };
123
124 let (datagram, stream): (Vec<_>, Vec<_>) = name_servers
125 .into_inner()
126 .into_iter()
127 .partition(|ns| ns.protocol.is_datagram());
128
129 let datagram_conns: Vec<_> = datagram.into_iter().map(map_config_to_ns).collect();
130 let stream_conns: Vec<_> = stream.into_iter().map(map_config_to_ns).collect();
131
132 Self {
133 datagram_conns: Arc::from(datagram_conns),
134 stream_conns: Arc::from(stream_conns),
135 #[cfg(feature = "mdns")]
136 mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
137 options: *options,
138 }
139 }
140
141 #[doc(hidden)]
142 #[cfg(not(feature = "mdns"))]
143 pub fn from_nameservers(
144 options: &ResolverOpts,
145 datagram_conns: Vec<NameServer<C, P>>,
146 stream_conns: Vec<NameServer<C, P>>,
147 ) -> Self {
148 Self {
149 datagram_conns: Arc::from(datagram_conns),
150 stream_conns: Arc::from(stream_conns),
151 options: *options,
152 }
153 }
154
155 #[doc(hidden)]
156 #[cfg(feature = "mdns")]
157 pub fn from_nameservers(
158 options: &ResolverOpts,
159 datagram_conns: Vec<NameServer<C, P>>,
160 stream_conns: Vec<NameServer<C, P>>,
161 mdns_conns: NameServer<C, P>,
162 ) -> Self {
163 NameServerPool {
164 datagram_conns: Arc::from(datagram_conns),
165 stream_conns: Arc::from(stream_conns),
166 mdns_conns,
167 options: *options,
168 }
169 }
170
171 #[cfg(test)]
172 #[cfg(not(feature = "mdns"))]
173 #[allow(dead_code)]
174 fn from_nameservers_test(
175 options: &ResolverOpts,
176 datagram_conns: Arc<[NameServer<C, P>]>,
177 stream_conns: Arc<[NameServer<C, P>]>,
178 ) -> Self {
179 Self {
180 datagram_conns,
181 stream_conns,
182 options: *options,
183 }
184 }
185
186 #[cfg(test)]
187 #[cfg(feature = "mdns")]
188 fn from_nameservers_test(
189 options: &ResolverOpts,
190 datagram_conns: Arc<[NameServer<C, P>]>,
191 stream_conns: Arc<[NameServer<C, P>]>,
192 mdns_conns: NameServer<C, P>,
193 ) -> Self {
194 NameServerPool {
195 datagram_conns,
196 stream_conns,
197 mdns_conns,
198 options: *options,
199 conn_provider,
200 }
201 }
202
203 async fn try_send(
204 opts: ResolverOpts,
205 conns: Arc<[NameServer<C, P>]>,
206 request: DnsRequest,
207 ) -> Result<DnsResponse, ResolveError> {
208 let mut conns: Vec<NameServer<C, P>> = conns.to_vec();
209
210 match opts.server_ordering_strategy {
211 ServerOrderingStrategy::QueryStatistics => conns.sort_unstable(),
215 ServerOrderingStrategy::UserProvidedOrder => {}
216 }
217 let request_loop = request.clone();
218
219 parallel_conn_loop(conns, request_loop, opts).await
220 }
221
222 pub fn name_server_stats(&self) -> Vec<NameServerStats> {
223 self.datagram_conns.iter().chain(self.stream_conns.iter()).map(NameServer::stats).collect()
224 }
225}
226
227impl<C, P> DnsHandle for NameServerPool<C, P>
228where
229 C: DnsHandle<Error = ResolveError> + Sync + 'static,
230 P: ConnectionProvider<Conn = C> + 'static,
231{
232 type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
233 type Error = ResolveError;
234
235 fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
236 let opts = self.options;
237 let request = request.into();
238 let datagram_conns = Arc::clone(&self.datagram_conns);
239 let stream_conns = Arc::clone(&self.stream_conns);
240 let tcp_message = request.clone();
242
243 #[cfg(feature = "mdns")]
245 let mdns = mdns::maybe_local(&mut self.mdns_conns, request);
246
247 #[cfg(not(feature = "mdns"))]
249 let mdns = Local::NotMdns(request);
250
251 if mdns.is_local() {
253 return mdns.take_stream();
254 }
255
256 let request = mdns.take_request();
260 Box::pin(once(async move {
261 debug!("sending request: {:?}", request.queries());
262
263 let udp_res = match Self::try_send(opts, datagram_conns, request).await {
265 Ok(response) if response.truncated() => {
266 debug!("truncated response received, retrying over TCP");
267 Ok(response)
268 }
269 Err(e) if opts.try_tcp_on_error || e.is_no_connections() => {
270 debug!("error from UDP, retrying over TCP: {}", e);
271 Err(e)
272 }
273 result => return result,
274 };
275
276 if stream_conns.is_empty() {
277 debug!("no TCP connections available");
278 return udp_res;
279 }
280
281 let tcp_res = Self::try_send(opts, stream_conns, tcp_message).await;
284
285 let tcp_err = match tcp_res {
286 res @ Ok(..) => return res,
287 Err(e) => e,
288 };
289
290 let udp_err = match udp_res {
292 Ok(response) => return Ok(response),
293 Err(e) => e,
294 };
295
296 match udp_err.cmp_specificity(&tcp_err) {
297 Ordering::Greater => Err(udp_err),
298 _ => Err(tcp_err),
299 }
300 }))
301 }
302}
303
304async fn parallel_conn_loop<C, P>(
307 mut conns: Vec<NameServer<C, P>>,
308 request: DnsRequest,
309 opts: ResolverOpts,
310) -> Result<DnsResponse, ResolveError>
311where
312 C: DnsHandle<Error = ResolveError> + 'static,
313 P: ConnectionProvider<Conn = C> + 'static,
314{
315 let mut err = ResolveError::no_connections();
316 let mut backoff = Duration::from_millis(20);
327 let mut busy = SmallVec::<[NameServer<C, P>; 2]>::new();
328
329 loop {
330 let request_cont = request.clone();
331
332 let mut par_conns = SmallVec::<[NameServer<C, P>; 2]>::new();
334 let count = conns.len().min(opts.num_concurrent_reqs.max(1));
335 for conn in conns.drain(..count) {
336 par_conns.push(conn);
337 }
338
339 if par_conns.is_empty() {
340 if !busy.is_empty() && backoff < Duration::from_millis(300) {
341 P::Time::delay_for(backoff).await;
342 conns.extend(busy.drain(..));
343 backoff *= 2;
344 continue;
345 }
346 return Err(err);
347 }
348
349 let mut requests = par_conns
350 .into_iter()
351 .map(move |mut conn| {
352 conn.send(request_cont.clone())
353 .first_answer()
354 .map(|result| result.map_err(|e| (conn, e)))
355 })
356 .collect::<FuturesUnordered<_>>();
357
358 while let Some(result) = requests.next().await {
359 let (conn, e) = match result {
360 Ok(sent) => return Ok(sent),
361 Err((conn, e)) => (conn, e),
362 };
363
364 match e.kind() {
365 ResolveErrorKind::NoRecordsFound { trusted, .. } if *trusted => {
366 return Err(e);
367 }
368 ResolveErrorKind::Proto(e) if e.is_busy() => {
369 busy.push(conn);
370 }
371 _ if err.cmp_specificity(&e) == Ordering::Less => {
372 err = e;
373 }
374 _ => {}
375 }
376 }
377 }
378}
379
#[cfg(feature = "mdns")]
mod mdns {
    use super::*;

    use proto::rr::domain::usage;
    use proto::DnsHandle;

    /// Routes `request` to the mDNS `name_server` if any of its queries falls inside the zone
    /// named by `usage::LOCAL`.
    ///
    /// Returns [`Local::ResolveStream`] holding the mDNS response stream in that case.
    /// Otherwise returns [`Local::NotMdns`] with the request unchanged.
    pub(crate) fn maybe_local<C, P>(
        name_server: &mut NameServer<C, P>,
        request: DnsRequest,
    ) -> Local
    where
        C: DnsHandle<Error = ResolveError> + 'static,
        // `ConnectionProvider<Conn = C>` already implies `ConnectionProvider`; the former
        // duplicate bound was redundant.
        P: ConnectionProvider<Conn = C> + 'static,
    {
        let is_local = request
            .queries()
            .iter()
            .any(|query| usage::LOCAL.name().zone_of(query.name()));

        if is_local {
            Local::ResolveStream(name_server.send(request))
        } else {
            Local::NotMdns(request)
        }
    }
}
408
409pub(crate) enum Local {
410 #[allow(dead_code)]
411 ResolveStream(Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>),
412 NotMdns(DnsRequest),
413}
414
415impl Local {
416 fn is_local(&self) -> bool {
417 matches!(*self, Self::ResolveStream(..))
418 }
419
420 fn take_stream(self) -> Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>> {
426 match self {
427 Self::ResolveStream(future) => future,
428 _ => panic!("non Local queries have no future, see take_message()"),
429 }
430 }
431
432 fn take_request(self) -> DnsRequest {
438 match self {
439 Self::NotMdns(request) => request,
440 _ => panic!("Local queries must be polled, see take_future()"),
441 }
442 }
443}
444
445impl Stream for Local {
446 type Item = Result<DnsResponse, ResolveError>;
447
448 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
449 match *self {
450 Self::ResolveStream(ref mut ns) => ns.as_mut().poll_next(cx),
451 Self::NotMdns(..) => panic!("Local queries that are not mDNS should not be polled"), }
454 }
455}
456
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
    use std::str::FromStr;

    use tokio::runtime::Runtime;

    use proto::op::Query;
    use proto::rr::{Name, RecordType};
    use proto::xfer::{DnsHandle, DnsRequestOptions};
    use trust_dns_proto::rr::RData;

    use super::*;
    use crate::config::NameServerConfig;
    use crate::config::Protocol;

    /// Builds a pool with an unreachable UDP server (127.0.0.252:253) and 8.8.8.8. The test
    /// expects the first two lookups to fail and the next ten to succeed. It needs network
    /// access and is ignored by default.
    #[ignore]
    #[test]
    fn test_failed_then_success_pool() {
        let config1 = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 253),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let config2 = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let mut resolver_config = ResolverConfig::new();
        resolver_config.add_name_server(config1);
        resolver_config.add_name_server(config2);

        let io_loop = Runtime::new().unwrap();
        let mut pool = NameServerPool::<_, TokioConnectionProvider>::tokio_from_config(
            &resolver_config,
            &ResolverOpts::default(),
            TokioHandle,
        );

        let name = Name::parse("www.example.com.", None).unwrap();

        for i in 0..2 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_err(),
                "iter: {}",
                i
            );
        }

        for i in 0..10 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_ok(),
                "iter: {}",
                i
            );
        }
    }

    /// Checks that the pool reuses the caller-shared `NameServer` across lookups. The
    /// server's connection state must be visible through the test's own `Arc` handle after
    /// each lookup. Requires network access to 8.8.8.8.
    #[test]
    fn test_multi_use_conns() {
        let io_loop = Runtime::new().unwrap();
        let conn_provider = TokioConnectionProvider::new(TokioHandle);

        let tcp = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            protocol: Protocol::Tcp,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let opts = ResolverOpts {
            try_tcp_on_error: true,
            ..ResolverOpts::default()
        };
        let name_server = NameServer::new_with_provider(tcp, opts, conn_provider);
        let name_servers: Arc<[_]> = Arc::from([name_server]);

        let mut pool = NameServerPool::from_nameservers_test(
            &opts,
            Arc::from([]),
            Arc::clone(&name_servers),
            // Same arity as the constructor calls in `NameServerPool` (options, provider,
            // trust flag).
            #[cfg(feature = "mdns")]
            name_server::mdns_nameserver(opts, TokioConnectionProvider::new(TokioHandle), false),
        );

        let name = Name::from_str("www.example.com.").unwrap();

        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_a)
                .expect("no a record available"),
            Ipv4Addr::new(93, 184, 216, 34)
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );

        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name, RecordType::AAAA),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_aaaa)
                .expect("no aaaa record available"),
            Ipv6Addr::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );
    }
}