1use std::cmp::min;
11use std::error::Error;
12use std::net::{Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14use std::slice::Iter;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::{Duration, Instant};
18
19use futures_util::stream::Stream;
20use futures_util::{future, future::Future, FutureExt};
21
22use proto::error::ProtoError;
23use proto::op::Query;
24use proto::rr::rdata;
25use proto::rr::{Name, RData, Record, RecordType};
26use proto::xfer::{DnsRequest, DnsRequestOptions, DnsResponse};
27#[cfg(feature = "dnssec")]
28use proto::DnssecDnsHandle;
29use proto::{DnsHandle, RetryDnsHandle};
30
31use crate::caching_client::CachingClient;
32use crate::dns_lru::MAX_TTL;
33use crate::error::*;
34use crate::lookup_ip::LookupIpIter;
35use crate::name_server::{ConnectionProvider, NameServerPool};
36
37#[derive(Clone, Debug, Eq, PartialEq)]
41pub struct Lookup {
42 query: Query,
43 records: Arc<[Record]>,
44 valid_until: Instant,
45}
46
47impl Lookup {
48 pub fn from_rdata(query: Query, rdata: RData) -> Self {
50 let record = Record::from_rdata(query.name().clone(), MAX_TTL, rdata);
51 Self::new_with_max_ttl(query, Arc::from([record]))
52 }
53
54 pub fn new_with_max_ttl(query: Query, records: Arc<[Record]>) -> Self {
56 let valid_until = Instant::now() + Duration::from_secs(u64::from(MAX_TTL));
57 Self {
58 query,
59 records,
60 valid_until,
61 }
62 }
63
64 pub fn new_with_deadline(query: Query, records: Arc<[Record]>, valid_until: Instant) -> Self {
66 Self {
67 query,
68 records,
69 valid_until,
70 }
71 }
72
73 pub fn query(&self) -> &Query {
75 &self.query
76 }
77
78 pub fn iter(&self) -> LookupIter<'_> {
80 LookupIter(self.records.iter())
81 }
82
83 pub fn record_iter(&self) -> LookupRecordIter<'_> {
85 LookupRecordIter(self.records.iter())
86 }
87
88 pub fn valid_until(&self) -> Instant {
90 self.valid_until
91 }
92
93 #[doc(hidden)]
94 pub fn is_empty(&self) -> bool {
95 self.records.is_empty()
96 }
97
98 pub(crate) fn len(&self) -> usize {
99 self.records.len()
100 }
101
102 pub fn records(&self) -> &[Record] {
104 self.records.as_ref()
105 }
106
107 pub(crate) fn append(&self, other: Self) -> Self {
109 let mut records = Vec::with_capacity(self.len() + other.len());
110 records.extend_from_slice(&self.records);
111 records.extend_from_slice(&other.records);
112
113 let valid_until = min(self.valid_until(), other.valid_until());
115 Self::new_with_deadline(self.query.clone(), Arc::from(records), valid_until)
116 }
117}
118
119pub struct LookupIter<'a>(Iter<'a, Record>);
121
122impl<'a> Iterator for LookupIter<'a> {
123 type Item = &'a RData;
124
125 fn next(&mut self) -> Option<Self::Item> {
126 self.0.next().and_then(Record::data)
127 }
128}
129
130pub struct LookupRecordIter<'a>(Iter<'a, Record>);
132
133impl<'a> Iterator for LookupRecordIter<'a> {
134 type Item = &'a Record;
135
136 fn next(&mut self) -> Option<Self::Item> {
137 self.0.next()
138 }
139}
140
141impl IntoIterator for Lookup {
143 type Item = RData;
144 type IntoIter = LookupIntoIter;
145
146 fn into_iter(self) -> Self::IntoIter {
149 LookupIntoIter {
150 records: Arc::clone(&self.records),
151 index: 0,
152 }
153 }
154}
155
156pub struct LookupIntoIter {
160 records: Arc<[Record]>,
162 index: usize,
163}
164
165impl Iterator for LookupIntoIter {
166 type Item = RData;
167
168 fn next(&mut self) -> Option<Self::Item> {
169 let rdata = self.records.get(self.index).and_then(Record::data);
170 self.index += 1;
171 rdata.cloned()
172 }
173}
174
175#[derive(Clone)]
177#[doc(hidden)]
178pub enum LookupEither<
179 C: DnsHandle<Error = ResolveError> + 'static,
180 P: ConnectionProvider<Conn = C> + 'static,
181> {
182 Retry(RetryDnsHandle<NameServerPool<C, P>>),
183 #[cfg(feature = "dnssec")]
184 #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))]
185 Secure(DnssecDnsHandle<RetryDnsHandle<NameServerPool<C, P>>>),
186}
187
188impl<C, P> LookupEither<C, P>
189where
190 C: DnsHandle<Error = ResolveError> + 'static,
191 P: ConnectionProvider<Conn = C> + 'static,
192{
193 pub fn pool(&self) -> &NameServerPool<C, P> {
194 match self {
195 Self::Retry(ref c) => c.handle(),
196 #[cfg(feature = "dnssec")]
197 Self::Secure(ref c) => c.handle().handle(),
198 }
199 }
200}
201
202impl<C: DnsHandle<Error = ResolveError> + Sync, P: ConnectionProvider<Conn = C>> DnsHandle
203 for LookupEither<C, P>
204{
205 type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
206 type Error = ResolveError;
207
208 fn is_verifying_dnssec(&self) -> bool {
209 match *self {
210 Self::Retry(ref c) => c.is_verifying_dnssec(),
211 #[cfg(feature = "dnssec")]
212 Self::Secure(ref c) => c.is_verifying_dnssec(),
213 }
214 }
215
216 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
217 match *self {
218 Self::Retry(ref mut c) => c.send(request),
219 #[cfg(feature = "dnssec")]
220 Self::Secure(ref mut c) => c.send(request),
221 }
222 }
223}
224
225#[doc(hidden)]
227pub struct LookupFuture<C, E>
228where
229 C: DnsHandle<Error = E> + 'static,
230 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
231{
232 client_cache: CachingClient<C, E>,
233 names: Vec<Name>,
234 record_type: RecordType,
235 options: DnsRequestOptions,
236 query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
237}
238
239impl<C, E> LookupFuture<C, E>
240where
241 C: DnsHandle<Error = E> + 'static,
242 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
243{
244 #[doc(hidden)]
252 pub fn lookup(
253 mut names: Vec<Name>,
254 record_type: RecordType,
255 options: DnsRequestOptions,
256 mut client_cache: CachingClient<C, E>,
257 ) -> Self {
258 let name = names.pop().ok_or_else(|| {
259 ResolveError::from(ResolveErrorKind::Message("can not lookup for no names"))
260 });
261
262 let query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> = match name {
263 Ok(name) => client_cache
264 .lookup(Query::query(name, record_type), options)
265 .boxed(),
266 Err(err) => future::err(err).boxed(),
267 };
268
269 Self {
270 client_cache,
271 names,
272 record_type,
273 options,
274 query,
275 }
276 }
277}
278
279impl<C, E> Future for LookupFuture<C, E>
280where
281 C: DnsHandle<Error = E> + 'static,
282 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
283{
284 type Output = Result<Lookup, ResolveError>;
285
286 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
287 loop {
288 let query = self.query.as_mut().poll_unpin(cx);
290
291 let should_retry = match query {
293 Poll::Pending => return Poll::Pending,
295 Poll::Ready(Ok(ref lookup)) => lookup.records.len() == 0,
299 Poll::Ready(Err(_)) => true,
301 };
302
303 if should_retry {
304 if let Some(name) = self.names.pop() {
305 let record_type = self.record_type;
306 let options = self.options;
307
308 self.query = self
311 .client_cache
312 .lookup(Query::query(name, record_type), options);
313 continue;
316 }
317 }
318 return query;
322 }
327 }
328}
329
330#[derive(Debug, Clone)]
332pub struct SrvLookup(Lookup);
333
334impl SrvLookup {
335 pub fn iter(&self) -> SrvLookupIter<'_> {
337 SrvLookupIter(self.0.iter())
338 }
339
340 pub fn query(&self) -> &Query {
342 self.0.query()
343 }
344
345 pub fn ip_iter(&self) -> LookupIpIter<'_> {
349 LookupIpIter(self.0.iter())
350 }
351
352 pub fn as_lookup(&self) -> &Lookup {
356 &self.0
357 }
358}
359
360impl From<Lookup> for SrvLookup {
361 fn from(lookup: Lookup) -> Self {
362 Self(lookup)
363 }
364}
365
366pub struct SrvLookupIter<'i>(LookupIter<'i>);
368
369impl<'i> Iterator for SrvLookupIter<'i> {
370 type Item = &'i rdata::SRV;
371
372 fn next(&mut self) -> Option<Self::Item> {
373 let iter: &mut _ = &mut self.0;
374 iter.filter_map(|rdata| match *rdata {
375 RData::SRV(ref data) => Some(data),
376 _ => None,
377 })
378 .next()
379 }
380}
381
382impl IntoIterator for SrvLookup {
383 type Item = rdata::SRV;
384 type IntoIter = SrvLookupIntoIter;
385
386 fn into_iter(self) -> Self::IntoIter {
389 SrvLookupIntoIter(self.0.into_iter())
390 }
391}
392
393pub struct SrvLookupIntoIter(LookupIntoIter);
395
396impl Iterator for SrvLookupIntoIter {
397 type Item = rdata::SRV;
398
399 fn next(&mut self) -> Option<Self::Item> {
400 let iter: &mut _ = &mut self.0;
401 iter.filter_map(|rdata| match rdata {
402 RData::SRV(data) => Some(data),
403 _ => None,
404 })
405 .next()
406 }
407}
408
// Generates a typed wrapper around `Lookup`, plus a borrowing and an owning iterator
// that yield only the record data of a single `RData` variant.
//
// Parameters, in order:
//   $l  - name of the wrapper type (holds a `Lookup`)
//   $i  - name of the borrowing iterator returned by `$l::iter`
//   $ii - name of the owning iterator returned by `IntoIterator for $l`
//   $r  - the `RData` variant to match, e.g. `RData::A`
//   $t  - the payload type carried by that variant
macro_rules! lookup_type {
    ($l:ident, $i:ident, $ii:ident, $r:path, $t:path) => {
        /// Typed view of a [`Lookup`] that exposes one kind of record data.
        #[derive(Debug, Clone)]
        pub struct $l(Lookup);

        impl $l {
            /// Returns an iterator over the matching record data, skipping other variants.
            pub fn iter(&self) -> $i<'_> {
                $i(self.0.iter())
            }

            /// Returns the query that produced this lookup.
            pub fn query(&self) -> &Query {
                self.0.query()
            }

            /// Returns the instant until which this lookup is valid.
            pub fn valid_until(&self) -> Instant {
                self.0.valid_until()
            }

            /// Returns the untyped [`Lookup`] this view wraps.
            pub fn as_lookup(&self) -> &Lookup {
                &self.0
            }
        }

        impl From<Lookup> for $l {
            fn from(lookup: Lookup) -> Self {
                Self(lookup)
            }
        }

        /// Borrowing iterator over the matching record data.
        pub struct $i<'i>(LookupIter<'i>);

        impl<'i> Iterator for $i<'i> {
            type Item = &'i $t;

            fn next(&mut self) -> Option<Self::Item> {
                self.0.find_map(|rdata| match rdata {
                    $r(data) => Some(data),
                    _ => None,
                })
            }
        }

        impl IntoIterator for $l {
            type Item = $t;
            type IntoIter = $ii;

            fn into_iter(self) -> Self::IntoIter {
                let Self(lookup) = self;
                $ii(lookup.into_iter())
            }
        }

        /// Owning iterator over the matching record data.
        pub struct $ii(LookupIntoIter);

        impl Iterator for $ii {
            type Item = $t;

            fn next(&mut self) -> Option<Self::Item> {
                self.0.find_map(|rdata| match rdata {
                    $r(data) => Some(data),
                    _ => None,
                })
            }
        }
    };
}
490
491lookup_type!(
493 ReverseLookup,
494 ReverseLookupIter,
495 ReverseLookupIntoIter,
496 RData::PTR,
497 Name
498);
499lookup_type!(
500 Ipv4Lookup,
501 Ipv4LookupIter,
502 Ipv4LookupIntoIter,
503 RData::A,
504 Ipv4Addr
505);
506lookup_type!(
507 Ipv6Lookup,
508 Ipv6LookupIter,
509 Ipv6LookupIntoIter,
510 RData::AAAA,
511 Ipv6Addr
512);
513lookup_type!(
514 MxLookup,
515 MxLookupIter,
516 MxLookupIntoIter,
517 RData::MX,
518 rdata::MX
519);
520lookup_type!(
521 TlsaLookup,
522 TlsaLookupIter,
523 TlsaLookupIntoIter,
524 RData::TLSA,
525 rdata::TLSA
526);
527lookup_type!(
528 TxtLookup,
529 TxtLookupIter,
530 TxtLookupIntoIter,
531 RData::TXT,
532 rdata::TXT
533);
534lookup_type!(
535 SoaLookup,
536 SoaLookupIter,
537 SoaLookupIntoIter,
538 RData::SOA,
539 rdata::SOA
540);
541lookup_type!(NsLookup, NsLookupIter, NsLookupIntoIter, RData::NS, Name);
542
#[cfg(test)]
pub mod tests {
    use std::net::{IpAddr, Ipv4Addr};
    use std::str::FromStr;
    use std::sync::{Arc, Mutex};

    use futures_executor::block_on;
    use futures_util::future;
    use futures_util::stream::once;

    use proto::op::{Message, Query};
    use proto::rr::{Name, RData, Record, RecordType};
    use proto::xfer::{DnsRequest, DnsRequestOptions};

    use super::*;
    use crate::error::ResolveError;

    /// Test double for [`DnsHandle`] that replays canned responses.
    ///
    /// Each `send` pops the next response from the end of `messages`. Once the
    /// queue is exhausted, every further request is answered with [`empty`].
    #[derive(Clone)]
    pub struct MockDnsHandle {
        messages: Arc<Mutex<Vec<Result<DnsResponse, ResolveError>>>>,
    }

    impl DnsHandle for MockDnsHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
        type Error = ResolveError;

        fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response {
            Box::pin(once(
                future::ready(self.messages.lock().unwrap().pop().unwrap_or_else(empty)).boxed(),
            ))
        }
    }

    /// A response to a root `A` query with one answer: `127.0.0.1`, TTL 86400.
    pub fn v4_message() -> Result<DnsResponse, ResolveError> {
        let mut message = Message::new();
        message.add_query(Query::query(Name::root(), RecordType::A));
        message.insert_answers(vec![Record::from_rdata(
            Name::root(),
            86400,
            RData::A(Ipv4Addr::new(127, 0, 0, 1)),
        )]);

        let resp: DnsResponse = message.into();
        assert!(resp.contains_answer());
        Ok(resp)
    }

    /// A successful response built from an empty `Message`.
    pub fn empty() -> Result<DnsResponse, ResolveError> {
        Ok(Message::new().into())
    }

    /// A failure: an `io::ErrorKind::Other` error converted via `ProtoError`.
    pub fn error() -> Result<DnsResponse, ResolveError> {
        Err(ResolveError::from(ProtoError::from(std::io::Error::from(
            std::io::ErrorKind::Other,
        ))))
    }

    /// Builds a [`MockDnsHandle`] that replays `messages`, last element first.
    pub fn mock(messages: Vec<Result<DnsResponse, ResolveError>>) -> MockDnsHandle {
        MockDnsHandle {
            messages: Arc::new(Mutex::new(messages)),
        }
    }

    /// Runs a root-name `A` lookup through `CachingClient::new(0, .., false)`,
    /// backed by a mock that replays `messages`.
    fn lookup_root_a(
        messages: Vec<Result<DnsResponse, ResolveError>>,
    ) -> Result<Lookup, ResolveError> {
        block_on(LookupFuture::lookup(
            vec![Name::root()],
            RecordType::A,
            DnsRequestOptions::default(),
            CachingClient::new(0, mock(messages), false),
        ))
    }

    #[test]
    fn test_lookup() {
        assert_eq!(
            lookup_root_a(vec![v4_message()])
                .unwrap()
                .iter()
                .map(|r| r.to_ip_addr().unwrap())
                .collect::<Vec<IpAddr>>(),
            vec![Ipv4Addr::new(127, 0, 0, 1)]
        );
    }

    #[test]
    fn test_lookup_slice() {
        assert_eq!(
            Record::data(&lookup_root_a(vec![v4_message()]).unwrap().records()[0])
                .unwrap()
                .to_ip_addr()
                .unwrap(),
            Ipv4Addr::new(127, 0, 0, 1)
        );
    }

    #[test]
    fn test_lookup_into_iter() {
        assert_eq!(
            lookup_root_a(vec![v4_message()])
                .unwrap()
                .into_iter()
                .map(|r| r.to_ip_addr().unwrap())
                .collect::<Vec<IpAddr>>(),
            vec![Ipv4Addr::new(127, 0, 0, 1)]
        );
    }

    #[test]
    fn test_error() {
        assert!(lookup_root_a(vec![error()]).is_err());
    }

    #[test]
    fn test_empty_no_response() {
        if let ResolveErrorKind::NoRecordsFound {
            query,
            negative_ttl,
            ..
        } = lookup_root_a(vec![empty()]).unwrap_err().kind()
        {
            assert_eq!(**query, Query::query(Name::root(), RecordType::A));
            assert_eq!(*negative_ttl, None);
        } else {
            panic!("wrong error received");
        }
    }

    #[test]
    fn test_lookup_into_iter_arc() {
        let mut lookup = LookupIntoIter {
            records: Arc::from([
                Record::from_rdata(
                    Name::from_str("www.example.com.").unwrap(),
                    80,
                    RData::A(Ipv4Addr::new(127, 0, 0, 1)),
                ),
                Record::from_rdata(
                    Name::from_str("www.example.com.").unwrap(),
                    80,
                    RData::A(Ipv4Addr::new(127, 0, 0, 2)),
                ),
            ]),
            index: 0,
        };

        assert_eq!(
            lookup.next().unwrap(),
            RData::A(Ipv4Addr::new(127, 0, 0, 1))
        );
        assert_eq!(
            lookup.next().unwrap(),
            RData::A(Ipv4Addr::new(127, 0, 0, 2))
        );
        assert_eq!(lookup.next(), None);
    }
}