Skip to main content

trust_dns_resolver/
caching_client.rs

1// Copyright 2015-2017 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Caching related functionality for the Resolver.
9
10use std::borrow::Cow;
11use std::error::Error;
12use std::net::{Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14use std::sync::atomic::{AtomicU8, Ordering};
15use std::sync::Arc;
16use std::time::Instant;
17
18use futures_util::future::Future;
19
20use proto::error::ProtoError;
21use proto::op::{Query, ResponseCode};
22use proto::rr::domain::usage::{
23    ResolverUsage, DEFAULT, INVALID, IN_ADDR_ARPA_127, IP6_ARPA_1, LOCAL,
24    LOCALHOST as LOCALHOST_usage, ONION,
25};
26use proto::rr::{DNSClass, Name, RData, Record, RecordType};
27use proto::xfer::{DnsHandle, DnsRequestOptions, DnsResponse, FirstAnswer};
28
29use crate::dns_lru::DnsLru;
30use crate::dns_lru::{self, TtlConfig};
31use crate::error::*;
32use crate::lookup::Lookup;
33
34const MAX_QUERY_DEPTH: u8 = 8; // arbitrarily chosen number...
35
36lazy_static! {
37    static ref LOCALHOST: RData = RData::PTR(Name::from_ascii("localhost.").unwrap());
38    static ref LOCALHOST_V4: RData = RData::A(Ipv4Addr::new(127, 0, 0, 1));
39    static ref LOCALHOST_V6: RData = RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
40}
41
/// Scope guard that records one level of lookup nesting.
///
/// Creating a tracker increments the shared counter and dropping it decrements the counter
/// again, so the counter equals the number of trackers built from the same `Arc` that are
/// currently alive.
struct DepthTracker {
    query_depth: Arc<AtomicU8>,
}

impl DepthTracker {
    /// Increments `query_depth` and returns a guard that undoes the increment when dropped.
    fn track(query_depth: Arc<AtomicU8>) -> Self {
        let guard = Self { query_depth };
        guard.query_depth.fetch_add(1, Ordering::Release);
        guard
    }
}

impl Drop for DepthTracker {
    fn drop(&mut self) {
        // balance the increment performed in `track`
        let _ = self.query_depth.fetch_sub(1, Ordering::Release);
    }
}
58
59// TODO: need to consider this storage type as it compares to Authority in server...
60//       should it just be an variation on Authority?
61#[derive(Clone, Debug)]
62#[doc(hidden)]
63pub struct CachingClient<C, E>
64where
65    C: DnsHandle<Error = E>,
66    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
67{
68    lru: DnsLru,
69    client: C,
70    query_depth: Arc<AtomicU8>,
71    preserve_intermediates: bool,
72}
73
74impl<C, E> CachingClient<C, E>
75where
76    C: DnsHandle<Error = E> + Send + 'static,
77    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
78{
79    #[doc(hidden)]
80    pub fn new(max_size: usize, client: C, preserve_intermediates: bool) -> Self {
81        Self::with_cache(
82            DnsLru::new(max_size, TtlConfig::default()),
83            client,
84            preserve_intermediates,
85        )
86    }
87
88    pub(crate) fn with_cache(lru: DnsLru, client: C, preserve_intermediates: bool) -> Self {
89        let query_depth = Arc::new(AtomicU8::new(0));
90        Self {
91            lru,
92            client,
93            query_depth,
94            preserve_intermediates,
95        }
96    }
97
98    /// Perform a lookup against this caching client, looking first in the cache for a result
99    pub fn lookup(
100        &mut self,
101        query: Query,
102        options: DnsRequestOptions,
103    ) -> Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> {
104        Box::pin(Self::inner_lookup(query, options, self.clone(), vec![]))
105    }
106
107    async fn inner_lookup(
108        query: Query,
109        options: DnsRequestOptions,
110        mut client: Self,
111        preserved_records: Vec<(Record, u32)>,
112    ) -> Result<Lookup, ResolveError> {
113        // see https://tools.ietf.org/html/rfc6761
114        //
115        // ```text
116        // Name resolution APIs and libraries SHOULD recognize localhost
117        // names as special and SHOULD always return the IP loopback address
118        // for address queries and negative responses for all other query
119        // types.  Name resolution APIs SHOULD NOT send queries for
120        // localhost names to their configured caching DNS server(s).
121        // ```
122        // special use rules only apply to the IN Class
123        if query.query_class() == DNSClass::IN {
124            let usage = match query.name() {
125                n if LOCALHOST_usage.zone_of(n) => &*LOCALHOST_usage,
126                n if IN_ADDR_ARPA_127.zone_of(n) => &*LOCALHOST_usage,
127                n if IP6_ARPA_1.zone_of(n) => &*LOCALHOST_usage,
128                n if INVALID.zone_of(n) => &*INVALID,
129                n if LOCAL.zone_of(n) => &*LOCAL,
130                n if ONION.zone_of(n) => &*ONION,
131                _ => &*DEFAULT,
132            };
133
134            match usage.resolver() {
135                ResolverUsage::Loopback => match query.query_type() {
136                    // TODO: look in hosts for these ips/names first...
137                    RecordType::A => return Ok(Lookup::from_rdata(query, LOCALHOST_V4.clone())),
138                    RecordType::AAAA => return Ok(Lookup::from_rdata(query, LOCALHOST_V6.clone())),
139                    RecordType::PTR => return Ok(Lookup::from_rdata(query, LOCALHOST.clone())),
140                    _ => {
141                        return Err(ResolveError::nx_error(
142                            query,
143                            None,
144                            None,
145                            ResponseCode::NoError,
146                            false,
147                        ))
148                    } // Are there any other types we can use?
149                },
150                // when mdns is enabled we will follow a standard query path
151                #[cfg(feature = "mdns")]
152                ResolverUsage::LinkLocal => (),
153                // TODO: this requires additional config, as Kubernetes and other systems misuse the .local. zone.
154                // when mdns is not enabled we will return errors on LinkLocal ("*.local.") names
155                #[cfg(not(feature = "mdns"))]
156                ResolverUsage::LinkLocal => (),
157                ResolverUsage::NxDomain => {
158                    return Err(ResolveError::nx_error(
159                        query,
160                        None,
161                        None,
162                        ResponseCode::NXDomain,
163                        false,
164                    ))
165                }
166                ResolverUsage::Normal => (),
167            }
168        }
169
170        let _tracker = DepthTracker::track(client.query_depth.clone());
171        let is_dnssec = client.client.is_verifying_dnssec();
172
173        // first transition any polling that is needed (mutable refs...)
174        if let Some(cached_lookup) = client.lookup_from_cache(&query) {
175            return cached_lookup;
176        };
177
178        let response_message = client
179            .client
180            .lookup(query.clone(), options)
181            .first_answer()
182            .await
183            .map_err(E::into);
184
185        // TODO: technically this might be duplicating work, as name_server already performs this evaluation.
186        //  we may want to create a new type, if evaluated... but this is most generic to support any impl in LookupState...
187        let response_message = if let Ok(response) = response_message {
188            ResolveError::from_response(response, false)
189        } else {
190            response_message
191        };
192
193        // TODO: take all records and cache them?
194        //  if it's DNSSec they must be signed, otherwise?
195        let records: Result<Records, ResolveError> = match response_message {
196            // this is the only cacheable form
197            Err(ResolveError {
198                kind:
199                    ResolveErrorKind::NoRecordsFound {
200                        query,
201                        soa,
202                        negative_ttl,
203                        response_code,
204                        trusted,
205                    },
206                ..
207            }) => {
208                Err(Self::handle_nxdomain(
209                    is_dnssec,
210                    false, /*tbd*/
211                    *query,
212                    soa.map(|v| *v),
213                    negative_ttl,
214                    response_code,
215                    trusted,
216                ))
217            }
218            Err(e) => return Err(e),
219            Ok(response_message) => {
220                // allow the handle_noerror function to deal with any error codes
221                let records = Self::handle_noerror(
222                    &mut client,
223                    options,
224                    is_dnssec,
225                    &query,
226                    response_message,
227                    preserved_records,
228                )?;
229
230                Ok(records)
231            }
232        };
233
234        // after the request, evaluate if we have additional queries to perform
235        match records {
236            Ok(Records::CnameChain {
237                next: future,
238                min_ttl: ttl,
239            }) => client.cname(future.await?, query, ttl),
240            Ok(Records::Exists(rdata)) => client.cache(query, Ok(rdata)),
241            Err(e) => client.cache(query, Err(e)),
242        }
243    }
244
245    /// Check if this query is already cached
246    fn lookup_from_cache(&self, query: &Query) -> Option<Result<Lookup, ResolveError>> {
247        self.lru.get(query, Instant::now())
248    }
249
250    /// See https://tools.ietf.org/html/rfc2308
251    ///
252    /// For now we will regard NXDomain to strictly mean the query failed
253    ///  and a record for the name, regardless of CNAME presence, what have you
254    ///  ultimately does not exist.
255    ///
256    /// This also handles empty responses in the same way. When performing DNSSec enabled queries, we should
257    ///  never enter here, and should never cache unless verified requests.
258    ///
259    /// TODO: should this should be expanded to do a forward lookup? Today, this will fail even if there are
260    ///   forwarding options.
261    ///
262    /// # Arguments
263    ///
264    /// * `message` - message to extract SOA, etc, from for caching failed requests
265    /// * `valid_nsec` - species that in DNSSec mode, this request is safe to cache
266    /// * `negative_ttl` - this should be the SOA minimum for negative ttl
267    fn handle_nxdomain(
268        is_dnssec: bool,
269        valid_nsec: bool,
270        query: Query,
271        soa: Option<Record>,
272        negative_ttl: Option<u32>,
273        response_code: ResponseCode,
274        trusted: bool,
275    ) -> ResolveError {
276        if valid_nsec || !is_dnssec {
277            // only trust if there were validated NSEC records
278            ResolveErrorKind::NoRecordsFound {
279                query: Box::new(query),
280                soa: soa.map(Box::new),
281                negative_ttl,
282                response_code,
283                trusted: true,
284            }
285            .into()
286        } else {
287            // not cacheable, no ttl...
288            ResolveErrorKind::NoRecordsFound {
289                query: Box::new(query),
290                soa: soa.map(Box::new),
291                negative_ttl: None,
292                response_code,
293                trusted,
294            }
295            .into()
296        }
297    }
298
299    /// Handle the case where there is no error returned
300    fn handle_noerror(
301        client: &mut Self,
302        options: DnsRequestOptions,
303        is_dnssec: bool,
304        query: &Query,
305        mut response: DnsResponse,
306        mut preserved_records: Vec<(Record, u32)>,
307    ) -> Result<Records, ResolveError> {
308        // initial ttl is what CNAMES for min usage
309        const INITIAL_TTL: u32 = dns_lru::MAX_TTL;
310
311        // need to capture these before the subsequent and destructive record processing
312        let soa = response.soa().cloned();
313        let negative_ttl = response.negative_ttl();
314        let response_code = response.response_code();
315
316        // seek out CNAMES, this is only performed if the query is not a CNAME, ANY, or SRV
317        // FIXME: for SRV this evaluation is inadequate. CNAME is a single chain to a single record
318        //   for SRV, there could be many different targets. The search_name needs to be enhanced to
319        //   be a list of names found for SRV records.
320        let (search_name, cname_ttl, was_cname, preserved_records) = {
321            // this will only search for CNAMEs if the request was not meant to be for one of the triggers for recursion
322            let (search_name, cname_ttl, was_cname) =
323                if query.query_type().is_any() || query.query_type().is_cname() {
324                    (Cow::Borrowed(query.name()), INITIAL_TTL, false)
325                } else {
326                    // Folds any cnames from the answers section, into the final cname in the answers section
327                    //   this works by folding the last CNAME found into the final folded result.
328                    //   it assumes that the CNAMEs are in chained order in the DnsResponse Message...
329                    // For SRV, the name added for the search becomes the target name.
330                    //
331                    // TODO: should this include the additionals?
332                    response.answers().iter().fold(
333                        (Cow::Borrowed(query.name()), INITIAL_TTL, false),
334                        |(search_name, cname_ttl, was_cname), r| {
335                            match r.data() {
336                                Some(RData::CNAME(ref cname)) => {
337                                    // take the minimum TTL of the cname_ttl and the next record in the chain
338                                    let ttl = cname_ttl.min(r.ttl());
339                                    debug_assert_eq!(r.rr_type(), RecordType::CNAME);
340                                    if search_name.as_ref() == r.name() {
341                                        return (Cow::Owned(cname.clone()), ttl, true);
342                                    }
343                                }
344                                Some(RData::SRV(ref srv)) => {
345                                    // take the minimum TTL of the cname_ttl and the next record in the chain
346                                    let ttl = cname_ttl.min(r.ttl());
347                                    debug_assert_eq!(r.rr_type(), RecordType::SRV);
348
349                                    // the search name becomes the srv.target
350                                    return (Cow::Owned(srv.target().clone()), ttl, true);
351                                }
352                                _ => (),
353                            }
354
355                            (search_name, cname_ttl, was_cname)
356                        },
357                    )
358                };
359
360            // take all answers. // TODO: following CNAMES?
361            let answers = response.take_answers();
362            let additionals = response.take_additionals();
363            let name_servers = response.take_name_servers();
364
365            // set of names that still require resolution
366            // TODO: this needs to be enhanced for SRV
367            let mut found_name = false;
368
369            // After following all the CNAMES to the last one, try and lookup the final name
370            let records = answers
371                .into_iter()
372                // Chained records will generally exist in the additionals section
373                .chain(additionals.into_iter())
374                .chain(name_servers.into_iter())
375                .filter_map(|r| {
376                    // because this resolved potentially recursively, we want the min TTL from the chain
377                    let ttl = cname_ttl.min(r.ttl());
378                    // TODO: disable name validation with ResolverOpts? glibc feature...
379                    // restrict to the RData type requested
380                    if query.query_class() == r.dns_class() {
381                        // standard evaluation, it's an any type or it's the requested type and the search_name matches
382                        #[allow(clippy::suspicious_operation_groupings)]
383                        if (query.query_type().is_any() || query.query_type() == r.rr_type())
384                            && (search_name.as_ref() == r.name() || query.name() == r.name())
385                        {
386                            found_name = true;
387                            return Some((r, ttl));
388                        }
389                        // CNAME evaluation, it's an A/AAAA lookup and the record is from the CNAME lookup chain.
390                        if client.preserve_intermediates
391                            && r.rr_type() == RecordType::CNAME
392                            && (query.query_type() == RecordType::A
393                                || query.query_type() == RecordType::AAAA)
394                        {
395                            return Some((r, ttl));
396                        }
397                        // srv evaluation, it's an srv lookup and the srv_search_name/target matches this name
398                        //    and it's an IP
399                        if query.query_type().is_srv()
400                            && r.rr_type().is_ip_addr()
401                            && search_name.as_ref() == r.name()
402                        {
403                            found_name = true;
404                            Some((r, ttl))
405                        } else if query.query_type().is_ns() && r.rr_type().is_ip_addr() {
406                            Some((r, ttl))
407                        } else {
408                            None
409                        }
410                    } else {
411                        None
412                    }
413                })
414                .collect::<Vec<_>>();
415
416            // adding the newly collected records to the preserved records
417            preserved_records.extend(records);
418            if !preserved_records.is_empty() && found_name {
419                return Ok(Records::Exists(preserved_records));
420            }
421
422            (
423                search_name.into_owned(),
424                cname_ttl,
425                was_cname,
426                preserved_records,
427            )
428        };
429
430        // TODO: for SRV records we *could* do an implicit lookup, but, this requires knowing the type of IP desired
431        //    for now, we'll make the API require the user to perform a follow up to the lookups.
432        // It was a CNAME, but not included in the request...
433        if was_cname && client.query_depth.load(Ordering::Acquire) < MAX_QUERY_DEPTH {
434            let next_query = Query::query(search_name, query.query_type());
435            Ok(Records::CnameChain {
436                next: Box::pin(Self::inner_lookup(
437                    next_query,
438                    options,
439                    client.clone(),
440                    preserved_records,
441                )),
442                min_ttl: cname_ttl,
443            })
444        } else {
445            // TODO: review See https://tools.ietf.org/html/rfc2308 for NoData section
446            // Note on DNSSec, in secure_client_handle, if verify_nsec fails then the request fails.
447            //   this will mean that no unverified negative caches will make it to this point and be stored
448            Err(Self::handle_nxdomain(
449                is_dnssec,
450                true,
451                query.clone(),
452                soa,
453                negative_ttl,
454                response_code,
455                false,
456            ))
457        }
458    }
459
460    #[allow(clippy::unnecessary_wraps)]
461    fn cname(&self, lookup: Lookup, query: Query, cname_ttl: u32) -> Result<Lookup, ResolveError> {
462        // this duplicates the cache entry under the original query
463        Ok(self.lru.duplicate(query, lookup, cname_ttl, Instant::now()))
464    }
465
466    fn cache(
467        &self,
468        query: Query,
469        records: Result<Vec<(Record, u32)>, ResolveError>,
470    ) -> Result<Lookup, ResolveError> {
471        // this will put this object into an inconsistent state, but no one should call poll again...
472        match records {
473            Ok(rdata) => Ok(self.lru.insert(query, rdata, Instant::now())),
474            Err(err) => Err(self.lru.negative(query, err, Instant::now())),
475        }
476    }
477
478    /// Flushes/Removes all entries from the cache
479    pub fn clear_cache(&self) {
480        self.lru.clear();
481    }
482
483    /// Returns a shared reference to the underlying client.
484    pub fn client(&self) -> &C {
485        &self.client
486    }
487}
488
489enum Records {
490    /// The records exists, a vec of rdata with ttl
491    Exists(Vec<(Record, u32)>),
492    /// Future lookup for recursive cname records
493    CnameChain {
494        next: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
495        min_ttl: u32,
496    },
497}
498
499// see also the lookup_tests.rs in integration-tests crate
500#[cfg(test)]
501mod tests {
502    use std::net::*;
503    use std::str::FromStr;
504    use std::time::*;
505
506    use futures_executor::block_on;
507    use proto::op::{Message, Query};
508    use proto::rr::rdata::SRV;
509    use proto::rr::{Name, Record};
510
511    use super::*;
512    use crate::lookup_ip::tests::*;
513
514    #[test]
515    fn test_empty_cache() {
516        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
517        let client = mock(vec![empty()]);
518        let client = CachingClient::with_cache(cache, client, false);
519
520        if let ResolveErrorKind::NoRecordsFound {
521            query,
522            negative_ttl,
523            ..
524        } = block_on(CachingClient::inner_lookup(
525            Query::new(),
526            DnsRequestOptions::default(),
527            client,
528            vec![],
529        ))
530        .unwrap_err()
531        .kind()
532        {
533            assert_eq!(**query, Query::new());
534            assert_eq!(*negative_ttl, None);
535        } else {
536            panic!("wrong error received")
537        }
538    }
539
540    #[test]
541    fn test_from_cache() {
542        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
543        let query = Query::new();
544        cache.insert(
545            query.clone(),
546            vec![(
547                Record::from_rdata(
548                    query.name().clone(),
549                    u32::max_value(),
550                    RData::A(Ipv4Addr::new(127, 0, 0, 1)),
551                ),
552                u32::max_value(),
553            )],
554            Instant::now(),
555        );
556
557        let client = mock(vec![empty()]);
558        let client = CachingClient::with_cache(cache, client, false);
559
560        let ips = block_on(CachingClient::inner_lookup(
561            Query::new(),
562            DnsRequestOptions::default(),
563            client,
564            vec![],
565        ))
566        .unwrap();
567
568        assert_eq!(
569            ips.iter().cloned().collect::<Vec<_>>(),
570            vec![RData::A(Ipv4Addr::new(127, 0, 0, 1))]
571        );
572    }
573
574    #[test]
575    fn test_no_cache_insert() {
576        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
577        // first should come from client...
578        let client = mock(vec![v4_message()]);
579        let client = CachingClient::with_cache(cache.clone(), client, false);
580
581        let ips = block_on(CachingClient::inner_lookup(
582            Query::new(),
583            DnsRequestOptions::default(),
584            client,
585            vec![],
586        ))
587        .unwrap();
588
589        assert_eq!(
590            ips.iter().cloned().collect::<Vec<_>>(),
591            vec![RData::A(Ipv4Addr::new(127, 0, 0, 1))]
592        );
593
594        // next should come from cache...
595        let client = mock(vec![empty()]);
596        let client = CachingClient::with_cache(cache, client, false);
597
598        let ips = block_on(CachingClient::inner_lookup(
599            Query::new(),
600            DnsRequestOptions::default(),
601            client,
602            vec![],
603        ))
604        .unwrap();
605
606        assert_eq!(
607            ips.iter().cloned().collect::<Vec<_>>(),
608            vec![RData::A(Ipv4Addr::new(127, 0, 0, 1))]
609        );
610    }
611
612    #[allow(clippy::unnecessary_wraps)]
613    pub(crate) fn cname_message() -> Result<DnsResponse, ResolveError> {
614        let mut message = Message::new();
615        message.add_query(Query::query(
616            Name::from_str("www.example.com.").unwrap(),
617            RecordType::A,
618        ));
619        message.insert_answers(vec![Record::from_rdata(
620            Name::from_str("www.example.com.").unwrap(),
621            86400,
622            RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
623        )]);
624        Ok(message.into())
625    }
626
627    #[allow(clippy::unnecessary_wraps)]
628    pub(crate) fn srv_message() -> Result<DnsResponse, ResolveError> {
629        let mut message = Message::new();
630        message.add_query(Query::query(
631            Name::from_str("_443._tcp.www.example.com.").unwrap(),
632            RecordType::SRV,
633        ));
634        message.insert_answers(vec![Record::from_rdata(
635            Name::from_str("_443._tcp.www.example.com.").unwrap(),
636            86400,
637            RData::SRV(SRV::new(
638                1,
639                2,
640                443,
641                Name::from_str("www.example.com.").unwrap(),
642            )),
643        )]);
644        Ok(message.into())
645    }
646
647    #[allow(clippy::unnecessary_wraps)]
648    pub(crate) fn ns_message() -> Result<DnsResponse, ResolveError> {
649        let mut message = Message::new();
650        message.add_query(Query::query(
651            Name::from_str("www.example.com.").unwrap(),
652            RecordType::NS,
653        ));
654        message.insert_answers(vec![Record::from_rdata(
655            Name::from_str("www.example.com.").unwrap(),
656            86400,
657            RData::NS(Name::from_str("www.example.com.").unwrap()),
658        )]);
659        Ok(message.into())
660    }
661
662    fn no_recursion_on_query_test(query_type: RecordType) {
663        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
664
665        // the cname should succeed, we shouldn't query again after that, which would cause an error...
666        let client = mock(vec![error(), cname_message()]);
667        let client = CachingClient::with_cache(cache, client, false);
668
669        let ips = block_on(CachingClient::inner_lookup(
670            Query::query(Name::from_str("www.example.com.").unwrap(), query_type),
671            DnsRequestOptions::default(),
672            client,
673            vec![],
674        ))
675        .expect("lookup failed");
676
677        assert_eq!(
678            ips.iter().cloned().collect::<Vec<_>>(),
679            vec![RData::CNAME(Name::from_str("actual.example.com.").unwrap())]
680        );
681    }
682
683    #[test]
684    fn test_no_recursion_on_cname_query() {
685        no_recursion_on_query_test(RecordType::CNAME);
686    }
687
688    #[test]
689    fn test_no_recursion_on_all_query() {
690        no_recursion_on_query_test(RecordType::ANY);
691    }
692
693    #[test]
694    fn test_non_recursive_srv_query() {
695        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
696
697        // the cname should succeed, we shouldn't query again after that, which would cause an error...
698        let client = mock(vec![error(), srv_message()]);
699        let client = CachingClient::with_cache(cache, client, false);
700
701        let ips = block_on(CachingClient::inner_lookup(
702            Query::query(
703                Name::from_str("_443._tcp.www.example.com.").unwrap(),
704                RecordType::SRV,
705            ),
706            DnsRequestOptions::default(),
707            client,
708            vec![],
709        ))
710        .expect("lookup failed");
711
712        assert_eq!(
713            ips.iter().cloned().collect::<Vec<_>>(),
714            vec![RData::SRV(SRV::new(
715                1,
716                2,
717                443,
718                Name::from_str("www.example.com.").unwrap(),
719            ))]
720        );
721    }
722
723    #[test]
724    fn test_single_srv_query_response() {
725        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
726
727        let mut message = srv_message().unwrap();
728        message.add_answer(Record::from_rdata(
729            Name::from_str("www.example.com.").unwrap(),
730            86400,
731            RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
732        ));
733        message.insert_additionals(vec![
734            Record::from_rdata(
735                Name::from_str("actual.example.com.").unwrap(),
736                86400,
737                RData::A(Ipv4Addr::new(127, 0, 0, 1)),
738            ),
739            Record::from_rdata(
740                Name::from_str("actual.example.com.").unwrap(),
741                86400,
742                RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
743            ),
744        ]);
745
746        let client = mock(vec![error(), Ok(message)]);
747        let client = CachingClient::with_cache(cache, client, false);
748
749        let ips = block_on(CachingClient::inner_lookup(
750            Query::query(
751                Name::from_str("_443._tcp.www.example.com.").unwrap(),
752                RecordType::SRV,
753            ),
754            DnsRequestOptions::default(),
755            client,
756            vec![],
757        ))
758        .expect("lookup failed");
759
760        assert_eq!(
761            ips.iter().cloned().collect::<Vec<_>>(),
762            vec![
763                RData::SRV(SRV::new(
764                    1,
765                    2,
766                    443,
767                    Name::from_str("www.example.com.").unwrap(),
768                )),
769                RData::A(Ipv4Addr::new(127, 0, 0, 1)),
770                RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
771            ]
772        );
773    }
774
775    // TODO: if we ever enable recursive lookups for SRV, here are the tests...
776    // #[test]
777    // fn test_recursive_srv_query() {
778    //     let cache = Arc::new(Mutex::new(DnsLru::new(1)));
779
780    //     let mut message = Message::new();
781    //     message.add_answer(Record::from_rdata(
782    //         Name::from_str("www.example.com.").unwrap(),
783    //         86400,
784    //         RecordType::CNAME,
785    //         RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
786    //     ));
787    //     message.insert_additionals(vec![
788    //         Record::from_rdata(
789    //             Name::from_str("actual.example.com.").unwrap(),
790    //             86400,
791    //             RecordType::A,
792    //             RData::A(Ipv4Addr::new(127, 0, 0, 1)),
793    //         ),
794    //     ]);
795
796    //     let mut client = mock(vec![error(), Ok(message.into()), srv_message()]);
797
798    //     let ips = QueryState::lookup(
799    //         Query::query(
800    //             Name::from_str("_443._tcp.www.example.com.").unwrap(),
801    //             RecordType::SRV,
802    //         ),
803    //         Default::default(),
804    //         &mut client,
805    //         cache.clone(),
806    //     ).wait()
807    //         .expect("lookup failed");
808
809    //     assert_eq!(
810    //         ips.iter().cloned().collect::<Vec<_>>(),
811    //         vec![
812    //             RData::SRV(SRV::new(
813    //                 1,
814    //                 2,
815    //                 443,
816    //                 Name::from_str("www.example.com.").unwrap(),
817    //             )),
818    //             RData::A(Ipv4Addr::new(127, 0, 0, 1)),
819    //             //RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
820    //         ]
821    //     );
822    // }
823
824    #[test]
825    fn test_single_ns_query_response() {
826        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
827
828        let mut message = ns_message().unwrap();
829        message.add_answer(Record::from_rdata(
830            Name::from_str("www.example.com.").unwrap(),
831            86400,
832            RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
833        ));
834        message.insert_additionals(vec![
835            Record::from_rdata(
836                Name::from_str("actual.example.com.").unwrap(),
837                86400,
838                RData::A(Ipv4Addr::new(127, 0, 0, 1)),
839            ),
840            Record::from_rdata(
841                Name::from_str("actual.example.com.").unwrap(),
842                86400,
843                RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
844            ),
845        ]);
846
847        let client = mock(vec![error(), Ok(message)]);
848        let client = CachingClient::with_cache(cache, client, false);
849
850        let ips = block_on(CachingClient::inner_lookup(
851            Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::NS),
852            DnsRequestOptions::default(),
853            client,
854            vec![],
855        ))
856        .expect("lookup failed");
857
858        assert_eq!(
859            ips.iter().cloned().collect::<Vec<_>>(),
860            vec![
861                RData::NS(Name::from_str("www.example.com.").unwrap()),
862                RData::A(Ipv4Addr::new(127, 0, 0, 1)),
863                RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
864            ]
865        );
866    }
867
868    fn cname_ttl_test(first: u32, second: u32) {
869        let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
870        // expecting no queries to be performed
871        let mut client = CachingClient::with_cache(lru, mock(vec![error()]), false);
872
873        let mut message = Message::new();
874        message.insert_answers(vec![Record::from_rdata(
875            Name::from_str("ttl.example.com.").unwrap(),
876            first,
877            RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
878        )]);
879        message.insert_additionals(vec![Record::from_rdata(
880            Name::from_str("actual.example.com.").unwrap(),
881            second,
882            RData::A(Ipv4Addr::new(127, 0, 0, 1)),
883        )]);
884
885        let records = CachingClient::handle_noerror(
886            &mut client,
887            DnsRequestOptions::default(),
888            false,
889            &Query::query(Name::from_str("ttl.example.com.").unwrap(), RecordType::A),
890            message.into(),
891            vec![],
892        );
893
894        if let Ok(records) = records {
895            if let Records::Exists(records) = records {
896                for (record, ttl) in records.iter() {
897                    if record.record_type() == RecordType::CNAME {
898                        continue;
899                    }
900                    assert_eq!(ttl, &1);
901                }
902            } else {
903                panic!("records don't exist");
904            }
905        } else {
906            panic!("error getting records");
907        }
908    }
909
910    #[test]
911    fn test_cname_ttl() {
912        cname_ttl_test(1, 2);
913        cname_ttl_test(2, 1);
914    }
915
916    #[test]
917    fn test_early_return_localhost() {
918        let cache = DnsLru::new(0, dns_lru::TtlConfig::default());
919        let client = mock(vec![empty()]);
920        let mut client = CachingClient::with_cache(cache, client, false);
921
922        {
923            let query = Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::A);
924            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
925                .expect("should have returned localhost");
926            assert_eq!(lookup.query(), &query);
927            assert_eq!(
928                lookup.iter().cloned().collect::<Vec<_>>(),
929                vec![LOCALHOST_V4.clone()]
930            );
931        }
932
933        {
934            let query = Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::AAAA);
935            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
936                .expect("should have returned localhost");
937            assert_eq!(lookup.query(), &query);
938            assert_eq!(
939                lookup.iter().cloned().collect::<Vec<_>>(),
940                vec![LOCALHOST_V6.clone()]
941            );
942        }
943
944        {
945            let query = Query::query(Name::from(Ipv4Addr::new(127, 0, 0, 1)), RecordType::PTR);
946            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
947                .expect("should have returned localhost");
948            assert_eq!(lookup.query(), &query);
949            assert_eq!(
950                lookup.iter().cloned().collect::<Vec<_>>(),
951                vec![LOCALHOST.clone()]
952            );
953        }
954
955        {
956            let query = Query::query(
957                Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
958                RecordType::PTR,
959            );
960            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
961                .expect("should have returned localhost");
962            assert_eq!(lookup.query(), &query);
963            assert_eq!(
964                lookup.iter().cloned().collect::<Vec<_>>(),
965                vec![LOCALHOST.clone()]
966            );
967        }
968
969        assert!(block_on(client.lookup(
970            Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::MX),
971            DnsRequestOptions::default()
972        ))
973        .is_err());
974
975        assert!(block_on(client.lookup(
976            Query::query(Name::from(Ipv4Addr::new(127, 0, 0, 1)), RecordType::MX),
977            DnsRequestOptions::default()
978        ))
979        .is_err());
980
981        assert!(block_on(client.lookup(
982            Query::query(
983                Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
984                RecordType::MX
985            ),
986            DnsRequestOptions::default()
987        ))
988        .is_err());
989    }
990
991    #[test]
992    fn test_early_return_invalid() {
993        let cache = DnsLru::new(0, dns_lru::TtlConfig::default());
994        let client = mock(vec![empty()]);
995        let mut client = CachingClient::with_cache(cache, client, false);
996
997        assert!(block_on(client.lookup(
998            Query::query(
999                Name::from_ascii("horrible.invalid.").unwrap(),
1000                RecordType::A,
1001            ),
1002            DnsRequestOptions::default()
1003        ))
1004        .is_err());
1005    }
1006
1007    #[test]
1008    fn test_no_error_on_dot_local_no_mdns() {
1009        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
1010
1011        let mut message = srv_message().unwrap();
1012        message.add_query(Query::query(
1013            Name::from_ascii("www.example.local.").unwrap(),
1014            RecordType::A,
1015        ));
1016        message.add_answer(Record::from_rdata(
1017            Name::from_str("www.example.local.").unwrap(),
1018            86400,
1019            RData::A(Ipv4Addr::new(127, 0, 0, 1)),
1020        ));
1021
1022        let client = mock(vec![error(), Ok(message)]);
1023        let mut client = CachingClient::with_cache(cache, client, false);
1024
1025        assert!(block_on(client.lookup(
1026            Query::query(
1027                Name::from_ascii("www.example.local.").unwrap(),
1028                RecordType::A,
1029            ),
1030            DnsRequestOptions::default()
1031        ))
1032        .is_ok());
1033    }
1034}