trust_dns_resolver/name_server/
name_server_pool.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::cmp::Ordering;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::task::{Context, Poll};
12use std::time::Duration;
13
14use futures_util::future::FutureExt;
15use futures_util::stream::{once, FuturesUnordered, Stream, StreamExt};
16use smallvec::SmallVec;
17
18use proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
19use proto::Time;
20use tracing::debug;
21
22use crate::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts, ServerOrderingStrategy};
23use crate::error::{ResolveError, ResolveErrorKind};
24#[cfg(feature = "mdns")]
25use crate::name_server;
26use crate::name_server::{ConnectionProvider, NameServer};
27#[cfg(test)]
28#[cfg(feature = "tokio-runtime")]
29use crate::name_server::{TokioConnection, TokioConnectionProvider, TokioHandle};
30
31/// A pool of NameServers
32///
33/// This is not expected to be used directly, see [crate::AsyncResolver].
34#[derive(Clone)]
35pub struct NameServerPool<
36    C: DnsHandle<Error = ResolveError> + Send + Sync + 'static,
37    P: ConnectionProvider<Conn = C> + Send + 'static,
38> {
39    // TODO: switch to FuturesMutex (Mutex will have some undesireable locking)
40    datagram_conns: Arc<[NameServer<C, P>]>, /* All NameServers must be the same type */
41    stream_conns: Arc<[NameServer<C, P>]>,   /* All NameServers must be the same type */
42    #[cfg(feature = "mdns")]
43    mdns_conns: NameServer<C, P>, /* All NameServers must be the same type */
44    options: ResolverOpts,
45}
46
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
impl NameServerPool<TokioConnection, TokioConnectionProvider> {
    /// Test helper: builds a pool from `config` whose connections are created through a
    /// `TokioConnectionProvider` bound to `runtime`.
    pub(crate) fn tokio_from_config(
        config: &ResolverConfig,
        options: &ResolverOpts,
        runtime: TokioHandle,
    ) -> Self {
        let provider = TokioConnectionProvider::new(runtime);
        Self::from_config_with_provider(config, options, provider)
    }
}
58
59impl<C, P> NameServerPool<C, P>
60where
61    C: DnsHandle<Error = ResolveError> + Sync + 'static,
62    P: ConnectionProvider<Conn = C> + 'static,
63{
64    pub(crate) fn from_config_with_provider(
65        config: &ResolverConfig,
66        options: &ResolverOpts,
67        conn_provider: P,
68    ) -> Self {
69        let datagram_conns: Vec<NameServer<C, P>> = config
70            .name_servers()
71            .iter()
72            .filter(|ns_config| ns_config.protocol.is_datagram())
73            .map(|ns_config| {
74                #[cfg(feature = "dns-over-rustls")]
75                let ns_config = {
76                    let mut ns_config = ns_config.clone();
77                    ns_config.tls_config = config.client_config().clone();
78                    ns_config
79                };
80                #[cfg(not(feature = "dns-over-rustls"))]
81                let ns_config = { ns_config.clone() };
82
83                NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
84            })
85            .collect();
86
87        let stream_conns: Vec<NameServer<C, P>> = config
88            .name_servers()
89            .iter()
90            .filter(|ns_config| ns_config.protocol.is_stream())
91            .map(|ns_config| {
92                #[cfg(feature = "dns-over-rustls")]
93                let ns_config = {
94                    let mut ns_config = ns_config.clone();
95                    ns_config.tls_config = config.client_config().clone();
96                    ns_config
97                };
98                #[cfg(not(feature = "dns-over-rustls"))]
99                let ns_config = { ns_config.clone() };
100
101                NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
102            })
103            .collect();
104
105        Self {
106            datagram_conns: Arc::from(datagram_conns),
107            stream_conns: Arc::from(stream_conns),
108            #[cfg(feature = "mdns")]
109            mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
110            options: *options,
111        }
112    }
113
114    /// Construct a NameServerPool from a set of name server configs
115    pub fn from_config(
116        name_servers: NameServerConfigGroup,
117        options: &ResolverOpts,
118        conn_provider: P,
119    ) -> Self {
120        let map_config_to_ns = |ns_config| {
121            NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
122        };
123
124        let (datagram, stream): (Vec<_>, Vec<_>) = name_servers
125            .into_inner()
126            .into_iter()
127            .partition(|ns| ns.protocol.is_datagram());
128
129        let datagram_conns: Vec<_> = datagram.into_iter().map(map_config_to_ns).collect();
130        let stream_conns: Vec<_> = stream.into_iter().map(map_config_to_ns).collect();
131
132        Self {
133            datagram_conns: Arc::from(datagram_conns),
134            stream_conns: Arc::from(stream_conns),
135            #[cfg(feature = "mdns")]
136            mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
137            options: *options,
138        }
139    }
140
141    #[doc(hidden)]
142    #[cfg(not(feature = "mdns"))]
143    pub fn from_nameservers(
144        options: &ResolverOpts,
145        datagram_conns: Vec<NameServer<C, P>>,
146        stream_conns: Vec<NameServer<C, P>>,
147    ) -> Self {
148        Self {
149            datagram_conns: Arc::from(datagram_conns),
150            stream_conns: Arc::from(stream_conns),
151            options: *options,
152        }
153    }
154
155    #[doc(hidden)]
156    #[cfg(feature = "mdns")]
157    pub fn from_nameservers(
158        options: &ResolverOpts,
159        datagram_conns: Vec<NameServer<C, P>>,
160        stream_conns: Vec<NameServer<C, P>>,
161        mdns_conns: NameServer<C, P>,
162    ) -> Self {
163        NameServerPool {
164            datagram_conns: Arc::from(datagram_conns),
165            stream_conns: Arc::from(stream_conns),
166            mdns_conns,
167            options: *options,
168        }
169    }
170
171    #[cfg(test)]
172    #[cfg(not(feature = "mdns"))]
173    #[allow(dead_code)]
174    fn from_nameservers_test(
175        options: &ResolverOpts,
176        datagram_conns: Arc<[NameServer<C, P>]>,
177        stream_conns: Arc<[NameServer<C, P>]>,
178    ) -> Self {
179        Self {
180            datagram_conns,
181            stream_conns,
182            options: *options,
183        }
184    }
185
186    #[cfg(test)]
187    #[cfg(feature = "mdns")]
188    fn from_nameservers_test(
189        options: &ResolverOpts,
190        datagram_conns: Arc<[NameServer<C, P>]>,
191        stream_conns: Arc<[NameServer<C, P>]>,
192        mdns_conns: NameServer<C, P>,
193    ) -> Self {
194        NameServerPool {
195            datagram_conns,
196            stream_conns,
197            mdns_conns,
198            options: *options,
199            conn_provider,
200        }
201    }
202
203    async fn try_send(
204        opts: ResolverOpts,
205        conns: Arc<[NameServer<C, P>]>,
206        request: DnsRequest,
207    ) -> Result<DnsResponse, ResolveError> {
208        let mut conns: Vec<NameServer<C, P>> = conns.to_vec();
209
210        match opts.server_ordering_strategy {
211            // select the highest priority connection
212            //   reorder the connections based on current view...
213            //   this reorders the inner set
214            ServerOrderingStrategy::QueryStatistics => conns.sort_unstable(),
215            ServerOrderingStrategy::UserProvidedOrder => {}
216        }
217        let request_loop = request.clone();
218
219        parallel_conn_loop(conns, request_loop, opts).await
220    }
221}
222
223impl<C, P> DnsHandle for NameServerPool<C, P>
224where
225    C: DnsHandle<Error = ResolveError> + Sync + 'static,
226    P: ConnectionProvider<Conn = C> + 'static,
227{
228    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
229    type Error = ResolveError;
230
231    fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
232        let opts = self.options;
233        let request = request.into();
234        let datagram_conns = Arc::clone(&self.datagram_conns);
235        let stream_conns = Arc::clone(&self.stream_conns);
236        // TODO: remove this clone, return the Message in the error?
237        let tcp_message = request.clone();
238
239        // if it's a .local. query, then we *only* query mDNS, these should never be sent on to upstream resolvers
240        #[cfg(feature = "mdns")]
241        let mdns = mdns::maybe_local(&mut self.mdns_conns, request);
242
243        // TODO: limited to only when mDNS is enabled, but this should probably always be enforced?
244        #[cfg(not(feature = "mdns"))]
245        let mdns = Local::NotMdns(request);
246
247        // local queries are queried through mDNS
248        if mdns.is_local() {
249            return mdns.take_stream();
250        }
251
252        // TODO: should we allow mDNS to be used for standard lookups as well?
253
254        // it wasn't a local query, continue with standard lookup path
255        let request = mdns.take_request();
256        Box::pin(once(async move {
257            debug!("sending request: {:?}", request.queries());
258
259            // First try the UDP connections
260            let udp_res = match Self::try_send(opts, datagram_conns, request).await {
261                Ok(response) if response.truncated() => {
262                    debug!("truncated response received, retrying over TCP");
263                    Ok(response)
264                }
265                Err(e) if opts.try_tcp_on_error || e.is_no_connections() => {
266                    debug!("error from UDP, retrying over TCP: {}", e);
267                    Err(e)
268                }
269                result => return result,
270            };
271
272            if stream_conns.is_empty() {
273                debug!("no TCP connections available");
274                return udp_res;
275            }
276
277            // Try query over TCP, as response to query over UDP was either truncated or was an
278            // error.
279            let tcp_res = Self::try_send(opts, stream_conns, tcp_message).await;
280
281            let tcp_err = match tcp_res {
282                res @ Ok(..) => return res,
283                Err(e) => e,
284            };
285
286            // Even if the UDP result was truncated, return that
287            let udp_err = match udp_res {
288                Ok(response) => return Ok(response),
289                Err(e) => e,
290            };
291
292            match udp_err.cmp_specificity(&tcp_err) {
293                Ordering::Greater => Err(udp_err),
294                _ => Err(tcp_err),
295            }
296        }))
297    }
298}
299
300// TODO: we should be able to have a self-referential future here with Pin and not require cloned conns
301/// An async function that will loop over all the conns with a max parallel request count of ops.num_concurrent_req
302async fn parallel_conn_loop<C, P>(
303    mut conns: Vec<NameServer<C, P>>,
304    request: DnsRequest,
305    opts: ResolverOpts,
306) -> Result<DnsResponse, ResolveError>
307where
308    C: DnsHandle<Error = ResolveError> + 'static,
309    P: ConnectionProvider<Conn = C> + 'static,
310{
311    let mut err = ResolveError::no_connections();
312    // If the name server we're trying is giving us backpressure by returning ProtoErrorKind::Busy,
313    // we will first try the other name servers (as for other error types). However, if the other
314    // servers are also busy, we're going to wait for a little while and then retry each server that
315    // returned Busy in the previous round. If the server is still Busy, this continues, while
316    // the backoff increases exponentially (by a factor of 2), until it hits 300ms, in which case we
317    // give up. The request might still be retried by the caller (likely the DnsRetryHandle).
318    //
319    // TODO: more principled handling of timeouts. Currently, timeouts appear to be handled mostly
320    // close to the connection, which means the top level resolution might take substantially longer
321    // to fire than the timeout configured in `ResolverOpts`.
322    let mut backoff = Duration::from_millis(20);
323    let mut busy = SmallVec::<[NameServer<C, P>; 2]>::new();
324
325    loop {
326        let request_cont = request.clone();
327
328        // construct the parallel requests, 2 is the default
329        let mut par_conns = SmallVec::<[NameServer<C, P>; 2]>::new();
330        let count = conns.len().min(opts.num_concurrent_reqs.max(1));
331        for conn in conns.drain(..count) {
332            par_conns.push(conn);
333        }
334
335        if par_conns.is_empty() {
336            if !busy.is_empty() && backoff < Duration::from_millis(300) {
337                P::Time::delay_for(backoff).await;
338                conns.extend(busy.drain(..));
339                backoff *= 2;
340                continue;
341            }
342            return Err(err);
343        }
344
345        let mut requests = par_conns
346            .into_iter()
347            .map(move |mut conn| {
348                conn.send(request_cont.clone())
349                    .first_answer()
350                    .map(|result| result.map_err(|e| (conn, e)))
351            })
352            .collect::<FuturesUnordered<_>>();
353
354        while let Some(result) = requests.next().await {
355            let (conn, e) = match result {
356                Ok(sent) => return Ok(sent),
357                Err((conn, e)) => (conn, e),
358            };
359
360            match e.kind() {
361                ResolveErrorKind::NoRecordsFound { trusted, .. } if *trusted => {
362                    return Err(e);
363                }
364                ResolveErrorKind::Proto(e) if e.is_busy() => {
365                    busy.push(conn);
366                }
367                _ if err.cmp_specificity(&e) == Ordering::Less => {
368                    err = e;
369                }
370                _ => {}
371            }
372        }
373    }
374}
375
#[cfg(feature = "mdns")]
mod mdns {
    use super::*;

    use proto::rr::domain::usage;
    use proto::DnsHandle;

    /// Routes `.local.` queries to mDNS.
    ///
    /// If any query in `request` is inside the `local.` zone, the whole request is sent through
    /// `name_server` and `Local::ResolveStream` wraps the resulting stream. Otherwise the request
    /// is handed back untouched as `Local::NotMdns`.
    pub(crate) fn maybe_local<C, P>(
        name_server: &mut NameServer<C, P>,
        request: DnsRequest,
    ) -> Local
    where
        C: DnsHandle<Error = ResolveError> + 'static,
        P: ConnectionProvider<Conn = C> + 'static,
    {
        if request
            .queries()
            .iter()
            .any(|query| usage::LOCAL.name().zone_of(query.name()))
        {
            Local::ResolveStream(name_server.send(request))
        } else {
            Local::NotMdns(request)
        }
    }
}
404
405pub(crate) enum Local {
406    #[allow(dead_code)]
407    ResolveStream(Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>),
408    NotMdns(DnsRequest),
409}
410
411impl Local {
412    fn is_local(&self) -> bool {
413        matches!(*self, Self::ResolveStream(..))
414    }
415
416    /// Takes the stream
417    ///
418    /// # Panics
419    ///
420    /// Panics if this is in fact a Local::NotMdns
421    fn take_stream(self) -> Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>> {
422        match self {
423            Self::ResolveStream(future) => future,
424            _ => panic!("non Local queries have no future, see take_message()"),
425        }
426    }
427
428    /// Takes the message
429    ///
430    /// # Panics
431    ///
432    /// Panics if this is in fact a Local::ResolveStream
433    fn take_request(self) -> DnsRequest {
434        match self {
435            Self::NotMdns(request) => request,
436            _ => panic!("Local queries must be polled, see take_future()"),
437        }
438    }
439}
440
441impl Stream for Local {
442    type Item = Result<DnsResponse, ResolveError>;
443
444    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
445        match *self {
446            Self::ResolveStream(ref mut ns) => ns.as_mut().poll_next(cx),
447            // TODO: making this a panic for now
448            Self::NotMdns(..) => panic!("Local queries that are not mDNS should not be polled"), //Local::NotMdns(message) => return Err(ResolveErrorKind::Message("not mDNS")),
449        }
450    }
451}
452
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
    use std::str::FromStr;

    use tokio::runtime::Runtime;

    use proto::op::Query;
    use proto::rr::{Name, RecordType};
    use proto::xfer::{DnsHandle, DnsRequestOptions};
    use trust_dns_proto::rr::RData;

    use super::*;
    use crate::config::NameServerConfig;
    use crate::config::Protocol;

    // ignored because there is a real connection that needs a reasonable timeout
    #[ignore]
    #[test]
    fn test_failed_then_success_pool() {
        let config1 = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 253),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let config2 = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let mut resolver_config = ResolverConfig::new();
        resolver_config.add_name_server(config1);
        resolver_config.add_name_server(config2);

        let io_loop = Runtime::new().unwrap();
        let mut pool = NameServerPool::<_, TokioConnectionProvider>::tokio_from_config(
            &resolver_config,
            &ResolverOpts::default(),
            TokioHandle,
        );

        let name = Name::parse("www.example.com.", None).unwrap();

        // TODO: it's not clear why there are two failures before the success
        for i in 0..2 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_err(),
                "iter: {}",
                i
            );
        }

        for i in 0..10 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_ok(),
                "iter: {}",
                i
            );
        }
    }

    // NOTE: requires network access: queries 8.8.8.8 over TCP and checks the published
    // example.com addresses.
    #[test]
    fn test_multi_use_conns() {
        let io_loop = Runtime::new().unwrap();
        let conn_provider = TokioConnectionProvider::new(TokioHandle);

        let tcp = NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            protocol: Protocol::Tcp,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        };

        let opts = ResolverOpts {
            try_tcp_on_error: true,
            ..ResolverOpts::default()
        };
        let name_server = NameServer::new_with_provider(tcp, opts, conn_provider);
        let name_servers: Arc<[_]> = Arc::from([name_server]);

        let mut pool = NameServerPool::from_nameservers_test(
            &opts,
            Arc::from([]),
            Arc::clone(&name_servers),
            // same arity as the non-test constructors: (options, provider, trust_nx_responses)
            #[cfg(feature = "mdns")]
            name_server::mdns_nameserver(opts, TokioConnectionProvider::new(TokioHandle), false),
        );

        let name = Name::from_str("www.example.com.").unwrap();

        // first lookup
        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_a)
                .expect("no a record available"),
            Ipv4Addr::new(93, 184, 216, 34)
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );

        // second lookup, over the same shared connection
        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name, RecordType::AAAA),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_aaaa)
                .expect("no aaaa record available"),
            Ipv6Addr::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );
    }
}