Skip to main content

trust_dns_resolver/name_server/
name_server_pool.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::cmp::Ordering;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
14use futures_util::future::FutureExt;
15use futures_util::stream::{once, FuturesUnordered, Stream, StreamExt};
16use smallvec::SmallVec;
17
18use proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
19use proto::Time;
20use tracing::debug;
21
22use crate::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts, ServerOrderingStrategy};
23use crate::error::{ResolveError, ResolveErrorKind};
24#[cfg(feature = "mdns")]
25use crate::name_server;
26use crate::name_server::{ConnectionProvider, NameServer, NameServerStats};
27#[cfg(test)]
28#[cfg(feature = "tokio-runtime")]
29use crate::name_server::{TokioConnection, TokioConnectionProvider, TokioHandle};
30
31/// A pool of NameServers
32///
33/// This is not expected to be used directly, see [crate::AsyncResolver].
34#[derive(Clone)]
35pub struct NameServerPool<
36    C: DnsHandle<Error = ResolveError> + Send + Sync + 'static,
37    P: ConnectionProvider<Conn = C> + Send + 'static,
38> {
39    // TODO: switch to FuturesMutex (Mutex will have some undesireable locking)
40    datagram_conns: Arc<[NameServer<C, P>]>, /* All NameServers must be the same type */
41    stream_conns: Arc<[NameServer<C, P>]>,   /* All NameServers must be the same type */
42    #[cfg(feature = "mdns")]
43    mdns_conns: NameServer<C, P>, /* All NameServers must be the same type */
44    options: ResolverOpts,
45}
46
#[cfg(all(test, feature = "tokio-runtime"))]
impl NameServerPool<TokioConnection, TokioConnectionProvider> {
    /// Test helper: builds a pool from `config` whose connections are driven by `runtime`.
    pub(crate) fn tokio_from_config(
        config: &ResolverConfig,
        options: &ResolverOpts,
        runtime: TokioHandle,
    ) -> Self {
        let provider = TokioConnectionProvider::new(runtime);
        Self::from_config_with_provider(config, options, provider)
    }
}
58
59impl<C, P> NameServerPool<C, P>
60where
61    C: DnsHandle<Error = ResolveError> + Sync + 'static,
62    P: ConnectionProvider<Conn = C> + 'static,
63{
64    pub(crate) fn from_config_with_provider(
65        config: &ResolverConfig,
66        options: &ResolverOpts,
67        conn_provider: P,
68    ) -> Self {
69        let datagram_conns: Vec<NameServer<C, P>> = config
70            .name_servers()
71            .iter()
72            .filter(|ns_config| ns_config.protocol.is_datagram())
73            .map(|ns_config| {
74                #[cfg(feature = "dns-over-rustls")]
75                let ns_config = {
76                    let mut ns_config = ns_config.clone();
77                    ns_config.tls_config = config.client_config().clone();
78                    ns_config
79                };
80                #[cfg(not(feature = "dns-over-rustls"))]
81                let ns_config = { ns_config.clone() };
82
83                NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
84            })
85            .collect();
86
87        let stream_conns: Vec<NameServer<C, P>> = config
88            .name_servers()
89            .iter()
90            .filter(|ns_config| ns_config.protocol.is_stream())
91            .map(|ns_config| {
92                #[cfg(feature = "dns-over-rustls")]
93                let ns_config = {
94                    let mut ns_config = ns_config.clone();
95                    ns_config.tls_config = config.client_config().clone();
96                    ns_config
97                };
98                #[cfg(not(feature = "dns-over-rustls"))]
99                let ns_config = { ns_config.clone() };
100
101                NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
102            })
103            .collect();
104
105        Self {
106            datagram_conns: Arc::from(datagram_conns),
107            stream_conns: Arc::from(stream_conns),
108            #[cfg(feature = "mdns")]
109            mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
110            options: *options,
111        }
112    }
113
114    /// Construct a NameServerPool from a set of name server configs
115    pub fn from_config(
116        name_servers: NameServerConfigGroup,
117        options: &ResolverOpts,
118        conn_provider: P,
119    ) -> Self {
120        let map_config_to_ns = |ns_config| {
121            NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
122        };
123
124        let (datagram, stream): (Vec<_>, Vec<_>) = name_servers
125            .into_inner()
126            .into_iter()
127            .partition(|ns| ns.protocol.is_datagram());
128
129        let datagram_conns: Vec<_> = datagram.into_iter().map(map_config_to_ns).collect();
130        let stream_conns: Vec<_> = stream.into_iter().map(map_config_to_ns).collect();
131
132        Self {
133            datagram_conns: Arc::from(datagram_conns),
134            stream_conns: Arc::from(stream_conns),
135            #[cfg(feature = "mdns")]
136            mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
137            options: *options,
138        }
139    }
140
141    #[doc(hidden)]
142    #[cfg(not(feature = "mdns"))]
143    pub fn from_nameservers(
144        options: &ResolverOpts,
145        datagram_conns: Vec<NameServer<C, P>>,
146        stream_conns: Vec<NameServer<C, P>>,
147    ) -> Self {
148        Self {
149            datagram_conns: Arc::from(datagram_conns),
150            stream_conns: Arc::from(stream_conns),
151            options: *options,
152        }
153    }
154
155    #[doc(hidden)]
156    #[cfg(feature = "mdns")]
157    pub fn from_nameservers(
158        options: &ResolverOpts,
159        datagram_conns: Vec<NameServer<C, P>>,
160        stream_conns: Vec<NameServer<C, P>>,
161        mdns_conns: NameServer<C, P>,
162    ) -> Self {
163        NameServerPool {
164            datagram_conns: Arc::from(datagram_conns),
165            stream_conns: Arc::from(stream_conns),
166            mdns_conns,
167            options: *options,
168        }
169    }
170
171    #[cfg(test)]
172    #[cfg(not(feature = "mdns"))]
173    #[allow(dead_code)]
174    fn from_nameservers_test(
175        options: &ResolverOpts,
176        datagram_conns: Arc<[NameServer<C, P>]>,
177        stream_conns: Arc<[NameServer<C, P>]>,
178    ) -> Self {
179        Self {
180            datagram_conns,
181            stream_conns,
182            options: *options,
183        }
184    }
185
186    #[cfg(test)]
187    #[cfg(feature = "mdns")]
188    fn from_nameservers_test(
189        options: &ResolverOpts,
190        datagram_conns: Arc<[NameServer<C, P>]>,
191        stream_conns: Arc<[NameServer<C, P>]>,
192        mdns_conns: NameServer<C, P>,
193    ) -> Self {
194        NameServerPool {
195            datagram_conns,
196            stream_conns,
197            mdns_conns,
198            options: *options,
199            conn_provider,
200        }
201    }
202
203    async fn try_send(
204        opts: ResolverOpts,
205        conns: Arc<[NameServer<C, P>]>,
206        request: DnsRequest,
207    ) -> Result<DnsResponse, ResolveError> {
208        let mut conns: Vec<NameServer<C, P>> = conns.to_vec();
209
210        match opts.server_ordering_strategy {
211            // select the highest priority connection
212            //   reorder the connections based on current view...
213            //   this reorders the inner set
214            ServerOrderingStrategy::QueryStatistics => conns.sort_unstable(),
215            ServerOrderingStrategy::UserProvidedOrder => {}
216        }
217        let request_loop = request.clone();
218
219        parallel_conn_loop(conns, request_loop, opts).await
220    }
221
222    pub fn name_server_stats(&self) -> Vec<NameServerStats> {
223        self.datagram_conns.iter().chain(self.stream_conns.iter()).map(NameServer::stats).collect()
224    }
225}
226
227impl<C, P> DnsHandle for NameServerPool<C, P>
228where
229    C: DnsHandle<Error = ResolveError> + Sync + 'static,
230    P: ConnectionProvider<Conn = C> + 'static,
231{
232    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
233    type Error = ResolveError;
234
235    fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
236        let opts = self.options;
237        let request = request.into();
238        let datagram_conns = Arc::clone(&self.datagram_conns);
239        let stream_conns = Arc::clone(&self.stream_conns);
240        // TODO: remove this clone, return the Message in the error?
241        let tcp_message = request.clone();
242
243        // if it's a .local. query, then we *only* query mDNS, these should never be sent on to upstream resolvers
244        #[cfg(feature = "mdns")]
245        let mdns = mdns::maybe_local(&mut self.mdns_conns, request);
246
247        // TODO: limited to only when mDNS is enabled, but this should probably always be enforced?
248        #[cfg(not(feature = "mdns"))]
249        let mdns = Local::NotMdns(request);
250
251        // local queries are queried through mDNS
252        if mdns.is_local() {
253            return mdns.take_stream();
254        }
255
256        // TODO: should we allow mDNS to be used for standard lookups as well?
257
258        // it wasn't a local query, continue with standard lookup path
259        let request = mdns.take_request();
260        Box::pin(once(async move {
261            debug!("sending request: {:?}", request.queries());
262
263            // First try the UDP connections
264            let udp_res = match Self::try_send(opts, datagram_conns, request).await {
265                Ok(response) if response.truncated() => {
266                    debug!("truncated response received, retrying over TCP");
267                    Ok(response)
268                }
269                Err(e) if opts.try_tcp_on_error || e.is_no_connections() => {
270                    debug!("error from UDP, retrying over TCP: {}", e);
271                    Err(e)
272                }
273                result => return result,
274            };
275
276            if stream_conns.is_empty() {
277                debug!("no TCP connections available");
278                return udp_res;
279            }
280
281            // Try query over TCP, as response to query over UDP was either truncated or was an
282            // error.
283            let tcp_res = Self::try_send(opts, stream_conns, tcp_message).await;
284
285            let tcp_err = match tcp_res {
286                res @ Ok(..) => return res,
287                Err(e) => e,
288            };
289
290            // Even if the UDP result was truncated, return that
291            let udp_err = match udp_res {
292                Ok(response) => return Ok(response),
293                Err(e) => e,
294            };
295
296            match udp_err.cmp_specificity(&tcp_err) {
297                Ordering::Greater => Err(udp_err),
298                _ => Err(tcp_err),
299            }
300        }))
301    }
302}
303
304// TODO: we should be able to have a self-referential future here with Pin and not require cloned conns
305/// An async function that will loop over all the conns with a max parallel request count of ops.num_concurrent_req
306async fn parallel_conn_loop<C, P>(
307    mut conns: Vec<NameServer<C, P>>,
308    request: DnsRequest,
309    opts: ResolverOpts,
310) -> Result<DnsResponse, ResolveError>
311where
312    C: DnsHandle<Error = ResolveError> + 'static,
313    P: ConnectionProvider<Conn = C> + 'static,
314{
315    let mut err = ResolveError::no_connections();
316    // If the name server we're trying is giving us backpressure by returning ProtoErrorKind::Busy,
317    // we will first try the other name servers (as for other error types). However, if the other
318    // servers are also busy, we're going to wait for a little while and then retry each server that
319    // returned Busy in the previous round. If the server is still Busy, this continues, while
320    // the backoff increases exponentially (by a factor of 2), until it hits 300ms, in which case we
321    // give up. The request might still be retried by the caller (likely the DnsRetryHandle).
322    //
323    // TODO: more principled handling of timeouts. Currently, timeouts appear to be handled mostly
324    // close to the connection, which means the top level resolution might take substantially longer
325    // to fire than the timeout configured in `ResolverOpts`.
326    let mut backoff = Duration::from_millis(20);
327    let mut busy = SmallVec::<[NameServer<C, P>; 2]>::new();
328
329    loop {
330        let request_cont = request.clone();
331
332        // construct the parallel requests, 2 is the default
333        let mut par_conns = SmallVec::<[NameServer<C, P>; 2]>::new();
334        let count = conns.len().min(opts.num_concurrent_reqs.max(1));
335        for conn in conns.drain(..count) {
336            par_conns.push(conn);
337        }
338
339        if par_conns.is_empty() {
340            if !busy.is_empty() && backoff < Duration::from_millis(300) {
341                P::Time::delay_for(backoff).await;
342                conns.extend(busy.drain(..));
343                backoff *= 2;
344                continue;
345            }
346            return Err(err);
347        }
348
349        let mut requests = par_conns
350            .into_iter()
351            .map(move |mut conn| {
352                conn.send(request_cont.clone())
353                    .first_answer()
354                    .map(|result| result.map_err(|e| (conn, e)))
355            })
356            .collect::<FuturesUnordered<_>>();
357
358        while let Some(result) = requests.next().await {
359            let (conn, e) = match result {
360                Ok(sent) => return Ok(sent),
361                Err((conn, e)) => (conn, e),
362            };
363
364            match e.kind() {
365                ResolveErrorKind::NoRecordsFound { trusted, .. } if *trusted => {
366                    return Err(e);
367                }
368                ResolveErrorKind::Proto(e) if e.is_busy() => {
369                    busy.push(conn);
370                }
371                _ if err.cmp_specificity(&e) == Ordering::Less => {
372                    err = e;
373                }
374                _ => {}
375            }
376        }
377    }
378}
379
#[cfg(feature = "mdns")]
mod mdns {
    use super::*;

    use proto::rr::domain::usage;
    use proto::DnsHandle;

    /// Routes `request` to `name_server` (the mDNS server) when any of its queries is for a
    /// name inside the `usage::LOCAL` zone (`.local.`).
    ///
    /// Returns `Local::ResolveStream` holding the mDNS response stream in that case; otherwise
    /// the request is handed back unchanged as `Local::NotMdns`.
    pub(crate) fn maybe_local<C, P>(
        name_server: &mut NameServer<C, P>,
        request: DnsRequest,
    ) -> Local
    where
        C: DnsHandle<Error = ResolveError> + 'static,
        P: ConnectionProvider<Conn = C> + 'static,
    {
        let is_local_query = request
            .queries()
            .iter()
            .any(|query| usage::LOCAL.name().zone_of(query.name()));

        if is_local_query {
            Local::ResolveStream(name_server.send(request))
        } else {
            Local::NotMdns(request)
        }
    }
}
408
409pub(crate) enum Local {
410    #[allow(dead_code)]
411    ResolveStream(Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>),
412    NotMdns(DnsRequest),
413}
414
415impl Local {
416    fn is_local(&self) -> bool {
417        matches!(*self, Self::ResolveStream(..))
418    }
419
420    /// Takes the stream
421    ///
422    /// # Panics
423    ///
424    /// Panics if this is in fact a Local::NotMdns
425    fn take_stream(self) -> Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>> {
426        match self {
427            Self::ResolveStream(future) => future,
428            _ => panic!("non Local queries have no future, see take_message()"),
429        }
430    }
431
432    /// Takes the message
433    ///
434    /// # Panics
435    ///
436    /// Panics if this is in fact a Local::ResolveStream
437    fn take_request(self) -> DnsRequest {
438        match self {
439            Self::NotMdns(request) => request,
440            _ => panic!("Local queries must be polled, see take_future()"),
441        }
442    }
443}
444
445impl Stream for Local {
446    type Item = Result<DnsResponse, ResolveError>;
447
448    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
449        match *self {
450            Self::ResolveStream(ref mut ns) => ns.as_mut().poll_next(cx),
451            // TODO: making this a panic for now
452            Self::NotMdns(..) => panic!("Local queries that are not mDNS should not be polled"), //Local::NotMdns(message) => return Err(ResolveErrorKind::Message("not mDNS")),
453        }
454    }
455}
456
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
    use std::str::FromStr;

    use tokio::runtime::Runtime;

    use proto::op::Query;
    use proto::rr::{Name, RecordType};
    use proto::xfer::{DnsHandle, DnsRequestOptions};
    use trust_dns_proto::rr::RData;

    use super::*;
    use crate::config::NameServerConfig;
    use crate::config::Protocol;

    // Ignored because there is a real connection that needs a reasonable timeout.
    #[ignore]
    #[test]
    fn test_failed_then_success_pool() {
        let config1 = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 253),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let config2 = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let mut resolver_config = ResolverConfig::new();
        resolver_config.add_name_server(config1);
        resolver_config.add_name_server(config2);

        let io_loop = Runtime::new().unwrap();
        let mut pool = NameServerPool::<_, TokioConnectionProvider>::tokio_from_config(
            &resolver_config,
            &ResolverOpts::default(),
            TokioHandle,
        );

        let name = Name::parse("www.example.com.", None).unwrap();

        // TODO: it's not clear why there are two failures before the success
        for i in 0..2 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_err(),
                "iter: {}",
                i
            );
        }

        for i in 0..10 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_ok(),
                "iter: {}",
                i
            );
        }
    }

    // NOTE(review): unlike the test above this one is not #[ignore]d, yet it talks to 8.8.8.8
    // over TCP and asserts fixed addresses for www.example.com; it fails offline or whenever
    // those records change — consider ignoring it or loosening the address checks.
    #[test]
    fn test_multi_use_conns() {
        let io_loop = Runtime::new().unwrap();
        let conn_provider = TokioConnectionProvider::new(TokioHandle);

        let tcp = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            protocol: Protocol::Tcp,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let opts = ResolverOpts {
            try_tcp_on_error: true,
            ..ResolverOpts::default()
        };
        let name_server = NameServer::new_with_provider(tcp, opts, conn_provider);
        let name_servers: Arc<[_]> = Arc::from([name_server]);

        let mut pool = NameServerPool::from_nameservers_test(
            &opts,
            Arc::from([]),
            Arc::clone(&name_servers),
            // Same 3-argument form as the pool constructors; the 2-argument call did not compile.
            #[cfg(feature = "mdns")]
            name_server::mdns_nameserver(opts, TokioConnectionProvider::new(TokioHandle), false),
        );

        let name = Name::from_str("www.example.com.").unwrap();

        // first lookup
        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_a)
                .expect("no a record available"),
            Ipv4Addr::new(93, 184, 216, 34)
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );

        // second lookup, reusing the same shared connection
        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name, RecordType::AAAA),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_aaaa)
                .expect("no aaaa record available"),
            Ipv6Addr::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );
    }
}