trust_dns_resolver/
caching_client.rs

1// Copyright 2015-2017 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Caching related functionality for the Resolver.
9
10use std::borrow::Cow;
11use std::error::Error;
12use std::net::{Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14use std::sync::atomic::{AtomicU8, Ordering};
15use std::sync::Arc;
16use std::time::Instant;
17
18use futures_util::future::Future;
19
20use proto::error::ProtoError;
21use proto::op::{Query, ResponseCode};
22use proto::rr::domain::usage::{
23    ResolverUsage, DEFAULT, INVALID, IN_ADDR_ARPA_127, IP6_ARPA_1, LOCAL,
24    LOCALHOST as LOCALHOST_usage, ONION,
25};
26use proto::rr::{DNSClass, Name, RData, Record, RecordType};
27use proto::xfer::{DnsHandle, DnsRequestOptions, DnsResponse, FirstAnswer};
28
29use crate::dns_lru::DnsLru;
30use crate::dns_lru::{self, TtlConfig};
31use crate::error::*;
32use crate::lookup::Lookup;
33
34const MAX_QUERY_DEPTH: u8 = 8; // arbitrarily chosen number...
35
36lazy_static! {
37    static ref LOCALHOST: RData = RData::PTR(Name::from_ascii("localhost.").unwrap());
38    static ref LOCALHOST_V4: RData = RData::A(Ipv4Addr::new(127, 0, 0, 1));
39    static ref LOCALHOST_V6: RData = RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
40}
41
/// RAII guard that counts one level of lookup nesting.
///
/// Building the guard increments the shared counter and dropping it decrements the counter
/// again. The counter therefore equals the number of guards currently alive for that `Arc`.
/// Note that `AtomicU8::fetch_add`/`fetch_sub` wrap on overflow/underflow.
struct DepthTracker {
    query_depth: Arc<AtomicU8>,
}

impl DepthTracker {
    /// Increments `query_depth` and returns a guard that undoes the increment when dropped.
    fn track(query_depth: Arc<AtomicU8>) -> Self {
        let guard = DepthTracker { query_depth };
        guard.query_depth.fetch_add(1, Ordering::Release);
        guard
    }
}

impl Drop for DepthTracker {
    fn drop(&mut self) {
        // balance the increment performed in `track`
        let counter = &self.query_depth;
        counter.fetch_sub(1, Ordering::Release);
    }
}
58
59// TODO: need to consider this storage type as it compares to Authority in server...
60//       should it just be an variation on Authority?
61#[derive(Clone, Debug)]
62#[doc(hidden)]
63pub struct CachingClient<C, E>
64where
65    C: DnsHandle<Error = E>,
66    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
67{
68    lru: DnsLru,
69    client: C,
70    query_depth: Arc<AtomicU8>,
71    preserve_intermediates: bool,
72}
73
74impl<C, E> CachingClient<C, E>
75where
76    C: DnsHandle<Error = E> + Send + 'static,
77    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
78{
79    #[doc(hidden)]
80    pub fn new(max_size: usize, client: C, preserve_intermediates: bool) -> Self {
81        Self::with_cache(
82            DnsLru::new(max_size, TtlConfig::default()),
83            client,
84            preserve_intermediates,
85        )
86    }
87
88    pub(crate) fn with_cache(lru: DnsLru, client: C, preserve_intermediates: bool) -> Self {
89        let query_depth = Arc::new(AtomicU8::new(0));
90        Self {
91            lru,
92            client,
93            query_depth,
94            preserve_intermediates,
95        }
96    }
97
98    /// Perform a lookup against this caching client, looking first in the cache for a result
99    pub fn lookup(
100        &mut self,
101        query: Query,
102        options: DnsRequestOptions,
103    ) -> Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> {
104        Box::pin(Self::inner_lookup(query, options, self.clone(), vec![]))
105    }
106
107    async fn inner_lookup(
108        query: Query,
109        options: DnsRequestOptions,
110        mut client: Self,
111        preserved_records: Vec<(Record, u32)>,
112    ) -> Result<Lookup, ResolveError> {
113        // see https://tools.ietf.org/html/rfc6761
114        //
115        // ```text
116        // Name resolution APIs and libraries SHOULD recognize localhost
117        // names as special and SHOULD always return the IP loopback address
118        // for address queries and negative responses for all other query
119        // types.  Name resolution APIs SHOULD NOT send queries for
120        // localhost names to their configured caching DNS server(s).
121        // ```
122        // special use rules only apply to the IN Class
123        if query.query_class() == DNSClass::IN {
124            let usage = match query.name() {
125                n if LOCALHOST_usage.zone_of(n) => &*LOCALHOST_usage,
126                n if IN_ADDR_ARPA_127.zone_of(n) => &*LOCALHOST_usage,
127                n if IP6_ARPA_1.zone_of(n) => &*LOCALHOST_usage,
128                n if INVALID.zone_of(n) => &*INVALID,
129                n if LOCAL.zone_of(n) => &*LOCAL,
130                n if ONION.zone_of(n) => &*ONION,
131                _ => &*DEFAULT,
132            };
133
134            match usage.resolver() {
135                ResolverUsage::Loopback => match query.query_type() {
136                    // TODO: look in hosts for these ips/names first...
137                    RecordType::A => return Ok(Lookup::from_rdata(query, LOCALHOST_V4.clone())),
138                    RecordType::AAAA => return Ok(Lookup::from_rdata(query, LOCALHOST_V6.clone())),
139                    RecordType::PTR => return Ok(Lookup::from_rdata(query, LOCALHOST.clone())),
140                    _ => {
141                        return Err(ResolveError::nx_error(
142                            query,
143                            None,
144                            None,
145                            ResponseCode::NoError,
146                            false,
147                        ))
148                    } // Are there any other types we can use?
149                },
150                // when mdns is enabled we will follow a standard query path
151                #[cfg(feature = "mdns")]
152                ResolverUsage::LinkLocal => (),
153                // TODO: this requires additional config, as Kubernetes and other systems misuse the .local. zone.
154                // when mdns is not enabled we will return errors on LinkLocal ("*.local.") names
155                #[cfg(not(feature = "mdns"))]
156                ResolverUsage::LinkLocal => (),
157                ResolverUsage::NxDomain => {
158                    return Err(ResolveError::nx_error(
159                        query,
160                        None,
161                        None,
162                        ResponseCode::NXDomain,
163                        false,
164                    ))
165                }
166                ResolverUsage::Normal => (),
167            }
168        }
169
170        let _tracker = DepthTracker::track(client.query_depth.clone());
171        let is_dnssec = client.client.is_verifying_dnssec();
172
173        // first transition any polling that is needed (mutable refs...)
174        if let Some(cached_lookup) = client.lookup_from_cache(&query) {
175            return cached_lookup;
176        };
177
178        let response_message = client
179            .client
180            .lookup(query.clone(), options)
181            .first_answer()
182            .await
183            .map_err(E::into);
184
185        // TODO: technically this might be duplicating work, as name_server already performs this evaluation.
186        //  we may want to create a new type, if evaluated... but this is most generic to support any impl in LookupState...
187        let response_message = if let Ok(response) = response_message {
188            ResolveError::from_response(response, false)
189        } else {
190            response_message
191        };
192
193        // TODO: take all records and cache them?
194        //  if it's DNSSec they must be signed, otherwise?
195        let records: Result<Records, ResolveError> = match response_message {
196            // this is the only cacheable form
197            Err(ResolveError {
198                kind:
199                    ResolveErrorKind::NoRecordsFound {
200                        query,
201                        soa,
202                        negative_ttl,
203                        response_code,
204                        trusted,
205                    },
206                ..
207            }) => {
208                Err(Self::handle_nxdomain(
209                    is_dnssec,
210                    false, /*tbd*/
211                    *query,
212                    soa.map(|v| *v),
213                    negative_ttl,
214                    response_code,
215                    trusted,
216                ))
217            }
218            Err(e) => return Err(e),
219            Ok(response_message) => {
220                // allow the handle_noerror function to deal with any error codes
221                let records = Self::handle_noerror(
222                    &mut client,
223                    options,
224                    is_dnssec,
225                    &query,
226                    response_message,
227                    preserved_records,
228                )?;
229
230                Ok(records)
231            }
232        };
233
234        // after the request, evaluate if we have additional queries to perform
235        match records {
236            Ok(Records::CnameChain {
237                next: future,
238                min_ttl: ttl,
239            }) => client.cname(future.await?, query, ttl),
240            Ok(Records::Exists(rdata)) => client.cache(query, Ok(rdata)),
241            Err(e) => client.cache(query, Err(e)),
242        }
243    }
244
245    /// Check if this query is already cached
246    fn lookup_from_cache(&self, query: &Query) -> Option<Result<Lookup, ResolveError>> {
247        self.lru.get(query, Instant::now())
248    }
249
250    /// See https://tools.ietf.org/html/rfc2308
251    ///
252    /// For now we will regard NXDomain to strictly mean the query failed
253    ///  and a record for the name, regardless of CNAME presence, what have you
254    ///  ultimately does not exist.
255    ///
256    /// This also handles empty responses in the same way. When performing DNSSec enabled queries, we should
257    ///  never enter here, and should never cache unless verified requests.
258    ///
259    /// TODO: should this should be expanded to do a forward lookup? Today, this will fail even if there are
260    ///   forwarding options.
261    ///
262    /// # Arguments
263    ///
264    /// * `message` - message to extract SOA, etc, from for caching failed requests
265    /// * `valid_nsec` - species that in DNSSec mode, this request is safe to cache
266    /// * `negative_ttl` - this should be the SOA minimum for negative ttl
267    fn handle_nxdomain(
268        is_dnssec: bool,
269        valid_nsec: bool,
270        query: Query,
271        soa: Option<Record>,
272        negative_ttl: Option<u32>,
273        response_code: ResponseCode,
274        trusted: bool,
275    ) -> ResolveError {
276        if valid_nsec || !is_dnssec {
277            // only trust if there were validated NSEC records
278            ResolveErrorKind::NoRecordsFound {
279                query: Box::new(query),
280                soa: soa.map(Box::new),
281                negative_ttl,
282                response_code,
283                trusted: true,
284            }
285            .into()
286        } else {
287            // not cacheable, no ttl...
288            ResolveErrorKind::NoRecordsFound {
289                query: Box::new(query),
290                soa: soa.map(Box::new),
291                negative_ttl: None,
292                response_code,
293                trusted,
294            }
295            .into()
296        }
297    }
298
299    /// Handle the case where there is no error returned
300    fn handle_noerror(
301        client: &mut Self,
302        options: DnsRequestOptions,
303        is_dnssec: bool,
304        query: &Query,
305        mut response: DnsResponse,
306        mut preserved_records: Vec<(Record, u32)>,
307    ) -> Result<Records, ResolveError> {
308        // initial ttl is what CNAMES for min usage
309        const INITIAL_TTL: u32 = dns_lru::MAX_TTL;
310
311        // need to capture these before the subsequent and destructive record processing
312        let soa = response.soa().cloned();
313        let negative_ttl = response.negative_ttl();
314        let response_code = response.response_code();
315
316        // seek out CNAMES, this is only performed if the query is not a CNAME, ANY, or SRV
317        // FIXME: for SRV this evaluation is inadequate. CNAME is a single chain to a single record
318        //   for SRV, there could be many different targets. The search_name needs to be enhanced to
319        //   be a list of names found for SRV records.
320        let (search_name, cname_ttl, was_cname, preserved_records) = {
321            // this will only search for CNAMEs if the request was not meant to be for one of the triggers for recursion
322            let (search_name, cname_ttl, was_cname) =
323                if query.query_type().is_any() || query.query_type().is_cname() {
324                    (Cow::Borrowed(query.name()), INITIAL_TTL, false)
325                } else {
326                    // Folds any cnames from the answers section, into the final cname in the answers section
327                    //   this works by folding the last CNAME found into the final folded result.
328                    //   it assumes that the CNAMEs are in chained order in the DnsResponse Message...
329                    // For SRV, the name added for the search becomes the target name.
330                    //
331                    // TODO: should this include the additionals?
332                    response.answers().iter().fold(
333                        (Cow::Borrowed(query.name()), INITIAL_TTL, false),
334                        |(search_name, cname_ttl, was_cname), r| {
335                            match r.data() {
336                                Some(RData::CNAME(ref cname)) => {
337                                    // take the minimum TTL of the cname_ttl and the next record in the chain
338                                    let ttl = cname_ttl.min(r.ttl());
339                                    debug_assert_eq!(r.rr_type(), RecordType::CNAME);
340                                    if search_name.as_ref() == r.name() {
341                                        return (Cow::Owned(cname.clone()), ttl, true);
342                                    }
343                                }
344                                Some(RData::SRV(ref srv)) => {
345                                    // take the minimum TTL of the cname_ttl and the next record in the chain
346                                    let ttl = cname_ttl.min(r.ttl());
347                                    debug_assert_eq!(r.rr_type(), RecordType::SRV);
348
349                                    // the search name becomes the srv.target
350                                    return (Cow::Owned(srv.target().clone()), ttl, true);
351                                }
352                                _ => (),
353                            }
354
355                            (search_name, cname_ttl, was_cname)
356                        },
357                    )
358                };
359
360            // take all answers. // TODO: following CNAMES?
361            let answers = response.take_answers();
362            let additionals = response.take_additionals();
363            let name_servers = response.take_name_servers();
364
365            // set of names that still require resolution
366            // TODO: this needs to be enhanced for SRV
367            let mut found_name = false;
368
369            // After following all the CNAMES to the last one, try and lookup the final name
370            let records = answers
371                .into_iter()
372                // Chained records will generally exist in the additionals section
373                .chain(additionals.into_iter())
374                .chain(name_servers.into_iter())
375                .filter_map(|r| {
376                    // because this resolved potentially recursively, we want the min TTL from the chain
377                    let ttl = cname_ttl.min(r.ttl());
378                    // TODO: disable name validation with ResolverOpts? glibc feature...
379                    // restrict to the RData type requested
380                    if query.query_class() == r.dns_class() {
381                        // standard evaluation, it's an any type or it's the requested type and the search_name matches
382                        #[allow(clippy::suspicious_operation_groupings)]
383                        if (query.query_type().is_any() || query.query_type() == r.rr_type())
384                            && (search_name.as_ref() == r.name() || query.name() == r.name())
385                        {
386                            found_name = true;
387                            return Some((r, ttl));
388                        }
389                        // CNAME evaluation, it's an A/AAAA lookup and the record is from the CNAME lookup chain.
390                        if client.preserve_intermediates
391                            && r.rr_type() == RecordType::CNAME
392                            && (query.query_type() == RecordType::A
393                                || query.query_type() == RecordType::AAAA)
394                        {
395                            return Some((r, ttl));
396                        }
397                        // srv evaluation, it's an srv lookup and the srv_search_name/target matches this name
398                        //    and it's an IP
399                        if query.query_type().is_srv()
400                            && r.rr_type().is_ip_addr()
401                            && search_name.as_ref() == r.name()
402                        {
403                            found_name = true;
404                            Some((r, ttl))
405                        } else if query.query_type().is_ns() && r.rr_type().is_ip_addr() {
406                            Some((r, ttl))
407                        } else {
408                            None
409                        }
410                    } else {
411                        None
412                    }
413                })
414                .collect::<Vec<_>>();
415
416            // adding the newly collected records to the preserved records
417            preserved_records.extend(records);
418            if !preserved_records.is_empty() && found_name {
419                return Ok(Records::Exists(preserved_records));
420            }
421
422            (
423                search_name.into_owned(),
424                cname_ttl,
425                was_cname,
426                preserved_records,
427            )
428        };
429
430        // TODO: for SRV records we *could* do an implicit lookup, but, this requires knowing the type of IP desired
431        //    for now, we'll make the API require the user to perform a follow up to the lookups.
432        // It was a CNAME, but not included in the request...
433        if was_cname && client.query_depth.load(Ordering::Acquire) < MAX_QUERY_DEPTH {
434            let next_query = Query::query(search_name, query.query_type());
435            Ok(Records::CnameChain {
436                next: Box::pin(Self::inner_lookup(
437                    next_query,
438                    options,
439                    client.clone(),
440                    preserved_records,
441                )),
442                min_ttl: cname_ttl,
443            })
444        } else {
445            // TODO: review See https://tools.ietf.org/html/rfc2308 for NoData section
446            // Note on DNSSec, in secure_client_handle, if verify_nsec fails then the request fails.
447            //   this will mean that no unverified negative caches will make it to this point and be stored
448            Err(Self::handle_nxdomain(
449                is_dnssec,
450                true,
451                query.clone(),
452                soa,
453                negative_ttl,
454                response_code,
455                false,
456            ))
457        }
458    }
459
460    #[allow(clippy::unnecessary_wraps)]
461    fn cname(&self, lookup: Lookup, query: Query, cname_ttl: u32) -> Result<Lookup, ResolveError> {
462        // this duplicates the cache entry under the original query
463        Ok(self.lru.duplicate(query, lookup, cname_ttl, Instant::now()))
464    }
465
466    fn cache(
467        &self,
468        query: Query,
469        records: Result<Vec<(Record, u32)>, ResolveError>,
470    ) -> Result<Lookup, ResolveError> {
471        // this will put this object into an inconsistent state, but no one should call poll again...
472        match records {
473            Ok(rdata) => Ok(self.lru.insert(query, rdata, Instant::now())),
474            Err(err) => Err(self.lru.negative(query, err, Instant::now())),
475        }
476    }
477
478    /// Flushes/Removes all entries from the cache
479    pub fn clear_cache(&self) {
480        self.lru.clear();
481    }
482}
483
484enum Records {
485    /// The records exists, a vec of rdata with ttl
486    Exists(Vec<(Record, u32)>),
487    /// Future lookup for recursive cname records
488    CnameChain {
489        next: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
490        min_ttl: u32,
491    },
492}
493
494// see also the lookup_tests.rs in integration-tests crate
495#[cfg(test)]
496mod tests {
497    use std::net::*;
498    use std::str::FromStr;
499    use std::time::*;
500
501    use futures_executor::block_on;
502    use proto::op::{Message, Query};
503    use proto::rr::rdata::SRV;
504    use proto::rr::{Name, Record};
505
506    use super::*;
507    use crate::lookup_ip::tests::*;
508
509    #[test]
510    fn test_empty_cache() {
511        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
512        let client = mock(vec![empty()]);
513        let client = CachingClient::with_cache(cache, client, false);
514
515        if let ResolveErrorKind::NoRecordsFound {
516            query,
517            negative_ttl,
518            ..
519        } = block_on(CachingClient::inner_lookup(
520            Query::new(),
521            DnsRequestOptions::default(),
522            client,
523            vec![],
524        ))
525        .unwrap_err()
526        .kind()
527        {
528            assert_eq!(**query, Query::new());
529            assert_eq!(*negative_ttl, None);
530        } else {
531            panic!("wrong error received")
532        }
533    }
534
535    #[test]
536    fn test_from_cache() {
537        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
538        let query = Query::new();
539        cache.insert(
540            query.clone(),
541            vec![(
542                Record::from_rdata(
543                    query.name().clone(),
544                    u32::max_value(),
545                    RData::A(Ipv4Addr::new(127, 0, 0, 1)),
546                ),
547                u32::max_value(),
548            )],
549            Instant::now(),
550        );
551
552        let client = mock(vec![empty()]);
553        let client = CachingClient::with_cache(cache, client, false);
554
555        let ips = block_on(CachingClient::inner_lookup(
556            Query::new(),
557            DnsRequestOptions::default(),
558            client,
559            vec![],
560        ))
561        .unwrap();
562
563        assert_eq!(
564            ips.iter().cloned().collect::<Vec<_>>(),
565            vec![RData::A(Ipv4Addr::new(127, 0, 0, 1))]
566        );
567    }
568
569    #[test]
570    fn test_no_cache_insert() {
571        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
572        // first should come from client...
573        let client = mock(vec![v4_message()]);
574        let client = CachingClient::with_cache(cache.clone(), client, false);
575
576        let ips = block_on(CachingClient::inner_lookup(
577            Query::new(),
578            DnsRequestOptions::default(),
579            client,
580            vec![],
581        ))
582        .unwrap();
583
584        assert_eq!(
585            ips.iter().cloned().collect::<Vec<_>>(),
586            vec![RData::A(Ipv4Addr::new(127, 0, 0, 1))]
587        );
588
589        // next should come from cache...
590        let client = mock(vec![empty()]);
591        let client = CachingClient::with_cache(cache, client, false);
592
593        let ips = block_on(CachingClient::inner_lookup(
594            Query::new(),
595            DnsRequestOptions::default(),
596            client,
597            vec![],
598        ))
599        .unwrap();
600
601        assert_eq!(
602            ips.iter().cloned().collect::<Vec<_>>(),
603            vec![RData::A(Ipv4Addr::new(127, 0, 0, 1))]
604        );
605    }
606
607    #[allow(clippy::unnecessary_wraps)]
608    pub(crate) fn cname_message() -> Result<DnsResponse, ResolveError> {
609        let mut message = Message::new();
610        message.add_query(Query::query(
611            Name::from_str("www.example.com.").unwrap(),
612            RecordType::A,
613        ));
614        message.insert_answers(vec![Record::from_rdata(
615            Name::from_str("www.example.com.").unwrap(),
616            86400,
617            RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
618        )]);
619        Ok(message.into())
620    }
621
622    #[allow(clippy::unnecessary_wraps)]
623    pub(crate) fn srv_message() -> Result<DnsResponse, ResolveError> {
624        let mut message = Message::new();
625        message.add_query(Query::query(
626            Name::from_str("_443._tcp.www.example.com.").unwrap(),
627            RecordType::SRV,
628        ));
629        message.insert_answers(vec![Record::from_rdata(
630            Name::from_str("_443._tcp.www.example.com.").unwrap(),
631            86400,
632            RData::SRV(SRV::new(
633                1,
634                2,
635                443,
636                Name::from_str("www.example.com.").unwrap(),
637            )),
638        )]);
639        Ok(message.into())
640    }
641
642    #[allow(clippy::unnecessary_wraps)]
643    pub(crate) fn ns_message() -> Result<DnsResponse, ResolveError> {
644        let mut message = Message::new();
645        message.add_query(Query::query(
646            Name::from_str("www.example.com.").unwrap(),
647            RecordType::NS,
648        ));
649        message.insert_answers(vec![Record::from_rdata(
650            Name::from_str("www.example.com.").unwrap(),
651            86400,
652            RData::NS(Name::from_str("www.example.com.").unwrap()),
653        )]);
654        Ok(message.into())
655    }
656
657    fn no_recursion_on_query_test(query_type: RecordType) {
658        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
659
660        // the cname should succeed, we shouldn't query again after that, which would cause an error...
661        let client = mock(vec![error(), cname_message()]);
662        let client = CachingClient::with_cache(cache, client, false);
663
664        let ips = block_on(CachingClient::inner_lookup(
665            Query::query(Name::from_str("www.example.com.").unwrap(), query_type),
666            DnsRequestOptions::default(),
667            client,
668            vec![],
669        ))
670        .expect("lookup failed");
671
672        assert_eq!(
673            ips.iter().cloned().collect::<Vec<_>>(),
674            vec![RData::CNAME(Name::from_str("actual.example.com.").unwrap())]
675        );
676    }
677
678    #[test]
679    fn test_no_recursion_on_cname_query() {
680        no_recursion_on_query_test(RecordType::CNAME);
681    }
682
683    #[test]
684    fn test_no_recursion_on_all_query() {
685        no_recursion_on_query_test(RecordType::ANY);
686    }
687
688    #[test]
689    fn test_non_recursive_srv_query() {
690        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
691
692        // the cname should succeed, we shouldn't query again after that, which would cause an error...
693        let client = mock(vec![error(), srv_message()]);
694        let client = CachingClient::with_cache(cache, client, false);
695
696        let ips = block_on(CachingClient::inner_lookup(
697            Query::query(
698                Name::from_str("_443._tcp.www.example.com.").unwrap(),
699                RecordType::SRV,
700            ),
701            DnsRequestOptions::default(),
702            client,
703            vec![],
704        ))
705        .expect("lookup failed");
706
707        assert_eq!(
708            ips.iter().cloned().collect::<Vec<_>>(),
709            vec![RData::SRV(SRV::new(
710                1,
711                2,
712                443,
713                Name::from_str("www.example.com.").unwrap(),
714            ))]
715        );
716    }
717
718    #[test]
719    fn test_single_srv_query_response() {
720        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
721
722        let mut message = srv_message().unwrap();
723        message.add_answer(Record::from_rdata(
724            Name::from_str("www.example.com.").unwrap(),
725            86400,
726            RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
727        ));
728        message.insert_additionals(vec![
729            Record::from_rdata(
730                Name::from_str("actual.example.com.").unwrap(),
731                86400,
732                RData::A(Ipv4Addr::new(127, 0, 0, 1)),
733            ),
734            Record::from_rdata(
735                Name::from_str("actual.example.com.").unwrap(),
736                86400,
737                RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
738            ),
739        ]);
740
741        let client = mock(vec![error(), Ok(message)]);
742        let client = CachingClient::with_cache(cache, client, false);
743
744        let ips = block_on(CachingClient::inner_lookup(
745            Query::query(
746                Name::from_str("_443._tcp.www.example.com.").unwrap(),
747                RecordType::SRV,
748            ),
749            DnsRequestOptions::default(),
750            client,
751            vec![],
752        ))
753        .expect("lookup failed");
754
755        assert_eq!(
756            ips.iter().cloned().collect::<Vec<_>>(),
757            vec![
758                RData::SRV(SRV::new(
759                    1,
760                    2,
761                    443,
762                    Name::from_str("www.example.com.").unwrap(),
763                )),
764                RData::A(Ipv4Addr::new(127, 0, 0, 1)),
765                RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
766            ]
767        );
768    }
769
770    // TODO: if we ever enable recursive lookups for SRV, here are the tests...
771    // #[test]
772    // fn test_recursive_srv_query() {
773    //     let cache = Arc::new(Mutex::new(DnsLru::new(1)));
774
775    //     let mut message = Message::new();
776    //     message.add_answer(Record::from_rdata(
777    //         Name::from_str("www.example.com.").unwrap(),
778    //         86400,
779    //         RecordType::CNAME,
780    //         RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
781    //     ));
782    //     message.insert_additionals(vec![
783    //         Record::from_rdata(
784    //             Name::from_str("actual.example.com.").unwrap(),
785    //             86400,
786    //             RecordType::A,
787    //             RData::A(Ipv4Addr::new(127, 0, 0, 1)),
788    //         ),
789    //     ]);
790
791    //     let mut client = mock(vec![error(), Ok(message.into()), srv_message()]);
792
793    //     let ips = QueryState::lookup(
794    //         Query::query(
795    //             Name::from_str("_443._tcp.www.example.com.").unwrap(),
796    //             RecordType::SRV,
797    //         ),
798    //         Default::default(),
799    //         &mut client,
800    //         cache.clone(),
801    //     ).wait()
802    //         .expect("lookup failed");
803
804    //     assert_eq!(
805    //         ips.iter().cloned().collect::<Vec<_>>(),
806    //         vec![
807    //             RData::SRV(SRV::new(
808    //                 1,
809    //                 2,
810    //                 443,
811    //                 Name::from_str("www.example.com.").unwrap(),
812    //             )),
813    //             RData::A(Ipv4Addr::new(127, 0, 0, 1)),
814    //             //RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
815    //         ]
816    //     );
817    // }
818
819    #[test]
820    fn test_single_ns_query_response() {
821        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
822
823        let mut message = ns_message().unwrap();
824        message.add_answer(Record::from_rdata(
825            Name::from_str("www.example.com.").unwrap(),
826            86400,
827            RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
828        ));
829        message.insert_additionals(vec![
830            Record::from_rdata(
831                Name::from_str("actual.example.com.").unwrap(),
832                86400,
833                RData::A(Ipv4Addr::new(127, 0, 0, 1)),
834            ),
835            Record::from_rdata(
836                Name::from_str("actual.example.com.").unwrap(),
837                86400,
838                RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
839            ),
840        ]);
841
842        let client = mock(vec![error(), Ok(message)]);
843        let client = CachingClient::with_cache(cache, client, false);
844
845        let ips = block_on(CachingClient::inner_lookup(
846            Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::NS),
847            DnsRequestOptions::default(),
848            client,
849            vec![],
850        ))
851        .expect("lookup failed");
852
853        assert_eq!(
854            ips.iter().cloned().collect::<Vec<_>>(),
855            vec![
856                RData::NS(Name::from_str("www.example.com.").unwrap()),
857                RData::A(Ipv4Addr::new(127, 0, 0, 1)),
858                RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
859            ]
860        );
861    }
862
863    fn cname_ttl_test(first: u32, second: u32) {
864        let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
865        // expecting no queries to be performed
866        let mut client = CachingClient::with_cache(lru, mock(vec![error()]), false);
867
868        let mut message = Message::new();
869        message.insert_answers(vec![Record::from_rdata(
870            Name::from_str("ttl.example.com.").unwrap(),
871            first,
872            RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
873        )]);
874        message.insert_additionals(vec![Record::from_rdata(
875            Name::from_str("actual.example.com.").unwrap(),
876            second,
877            RData::A(Ipv4Addr::new(127, 0, 0, 1)),
878        )]);
879
880        let records = CachingClient::handle_noerror(
881            &mut client,
882            DnsRequestOptions::default(),
883            false,
884            &Query::query(Name::from_str("ttl.example.com.").unwrap(), RecordType::A),
885            message.into(),
886            vec![],
887        );
888
889        if let Ok(records) = records {
890            if let Records::Exists(records) = records {
891                for (record, ttl) in records.iter() {
892                    if record.record_type() == RecordType::CNAME {
893                        continue;
894                    }
895                    assert_eq!(ttl, &1);
896                }
897            } else {
898                panic!("records don't exist");
899            }
900        } else {
901            panic!("error getting records");
902        }
903    }
904
905    #[test]
906    fn test_cname_ttl() {
907        cname_ttl_test(1, 2);
908        cname_ttl_test(2, 1);
909    }
910
911    #[test]
912    fn test_early_return_localhost() {
913        let cache = DnsLru::new(0, dns_lru::TtlConfig::default());
914        let client = mock(vec![empty()]);
915        let mut client = CachingClient::with_cache(cache, client, false);
916
917        {
918            let query = Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::A);
919            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
920                .expect("should have returned localhost");
921            assert_eq!(lookup.query(), &query);
922            assert_eq!(
923                lookup.iter().cloned().collect::<Vec<_>>(),
924                vec![LOCALHOST_V4.clone()]
925            );
926        }
927
928        {
929            let query = Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::AAAA);
930            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
931                .expect("should have returned localhost");
932            assert_eq!(lookup.query(), &query);
933            assert_eq!(
934                lookup.iter().cloned().collect::<Vec<_>>(),
935                vec![LOCALHOST_V6.clone()]
936            );
937        }
938
939        {
940            let query = Query::query(Name::from(Ipv4Addr::new(127, 0, 0, 1)), RecordType::PTR);
941            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
942                .expect("should have returned localhost");
943            assert_eq!(lookup.query(), &query);
944            assert_eq!(
945                lookup.iter().cloned().collect::<Vec<_>>(),
946                vec![LOCALHOST.clone()]
947            );
948        }
949
950        {
951            let query = Query::query(
952                Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
953                RecordType::PTR,
954            );
955            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
956                .expect("should have returned localhost");
957            assert_eq!(lookup.query(), &query);
958            assert_eq!(
959                lookup.iter().cloned().collect::<Vec<_>>(),
960                vec![LOCALHOST.clone()]
961            );
962        }
963
964        assert!(block_on(client.lookup(
965            Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::MX),
966            DnsRequestOptions::default()
967        ))
968        .is_err());
969
970        assert!(block_on(client.lookup(
971            Query::query(Name::from(Ipv4Addr::new(127, 0, 0, 1)), RecordType::MX),
972            DnsRequestOptions::default()
973        ))
974        .is_err());
975
976        assert!(block_on(client.lookup(
977            Query::query(
978                Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
979                RecordType::MX
980            ),
981            DnsRequestOptions::default()
982        ))
983        .is_err());
984    }
985
986    #[test]
987    fn test_early_return_invalid() {
988        let cache = DnsLru::new(0, dns_lru::TtlConfig::default());
989        let client = mock(vec![empty()]);
990        let mut client = CachingClient::with_cache(cache, client, false);
991
992        assert!(block_on(client.lookup(
993            Query::query(
994                Name::from_ascii("horrible.invalid.").unwrap(),
995                RecordType::A,
996            ),
997            DnsRequestOptions::default()
998        ))
999        .is_err());
1000    }
1001
1002    #[test]
1003    fn test_no_error_on_dot_local_no_mdns() {
1004        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
1005
1006        let mut message = srv_message().unwrap();
1007        message.add_query(Query::query(
1008            Name::from_ascii("www.example.local.").unwrap(),
1009            RecordType::A,
1010        ));
1011        message.add_answer(Record::from_rdata(
1012            Name::from_str("www.example.local.").unwrap(),
1013            86400,
1014            RData::A(Ipv4Addr::new(127, 0, 0, 1)),
1015        ));
1016
1017        let client = mock(vec![error(), Ok(message)]);
1018        let mut client = CachingClient::with_cache(cache, client, false);
1019
1020        assert!(block_on(client.lookup(
1021            Query::query(
1022                Name::from_ascii("www.example.local.").unwrap(),
1023                RecordType::A,
1024            ),
1025            DnsRequestOptions::default()
1026        ))
1027        .is_ok());
1028    }
1029}