1use std::cmp::min;
11use std::error::Error;
12use std::net::{Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14use std::slice::Iter;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::{Duration, Instant};
18
19use futures_util::stream::Stream;
20use futures_util::{future, future::Future, FutureExt};
21
22use proto::error::ProtoError;
23use proto::op::Query;
24use proto::rr::rdata;
25use proto::rr::{Name, RData, Record, RecordType};
26use proto::xfer::{DnsRequest, DnsRequestOptions, DnsResponse};
27#[cfg(feature = "dnssec")]
28use proto::DnssecDnsHandle;
29use proto::{DnsHandle, RetryDnsHandle};
30
31use crate::caching_client::CachingClient;
32use crate::dns_lru::MAX_TTL;
33use crate::error::*;
34use crate::lookup_ip::LookupIpIter;
35use crate::name_server::{ConnectionProvider, NameServerPool};
36
37#[derive(Clone, Debug, Eq, PartialEq)]
41pub struct Lookup {
42 query: Query,
43 records: Arc<[Record]>,
44 valid_until: Instant,
45}
46
47impl Lookup {
48 pub fn from_rdata(query: Query, rdata: RData) -> Self {
50 let record = Record::from_rdata(query.name().clone(), MAX_TTL, rdata);
51 Self::new_with_max_ttl(query, Arc::from([record]))
52 }
53
54 pub fn new_with_max_ttl(query: Query, records: Arc<[Record]>) -> Self {
56 let valid_until = Instant::now() + Duration::from_secs(u64::from(MAX_TTL));
57 Self {
58 query,
59 records,
60 valid_until,
61 }
62 }
63
64 pub fn new_with_deadline(query: Query, records: Arc<[Record]>, valid_until: Instant) -> Self {
66 Self {
67 query,
68 records,
69 valid_until,
70 }
71 }
72
73 pub fn query(&self) -> &Query {
75 &self.query
76 }
77
78 pub fn iter(&self) -> LookupIter<'_> {
80 LookupIter(self.records.iter())
81 }
82
83 pub fn record_iter(&self) -> LookupRecordIter<'_> {
85 LookupRecordIter(self.records.iter())
86 }
87
88 pub fn valid_until(&self) -> Instant {
90 self.valid_until
91 }
92
93 #[doc(hidden)]
94 pub fn is_empty(&self) -> bool {
95 self.records.is_empty()
96 }
97
98 pub(crate) fn len(&self) -> usize {
99 self.records.len()
100 }
101
102 pub fn records(&self) -> &[Record] {
104 self.records.as_ref()
105 }
106
107 pub(crate) fn append(&self, other: Self) -> Self {
109 let mut records = Vec::with_capacity(self.len() + other.len());
110 records.extend_from_slice(&self.records);
111 records.extend_from_slice(&other.records);
112
113 let valid_until = min(self.valid_until(), other.valid_until());
115 Self::new_with_deadline(self.query.clone(), Arc::from(records), valid_until)
116 }
117}
118
119pub struct LookupIter<'a>(Iter<'a, Record>);
121
122impl<'a> Iterator for LookupIter<'a> {
123 type Item = &'a RData;
124
125 fn next(&mut self) -> Option<Self::Item> {
126 self.0.next().and_then(Record::data)
127 }
128}
129
130pub struct LookupRecordIter<'a>(Iter<'a, Record>);
132
133impl<'a> Iterator for LookupRecordIter<'a> {
134 type Item = &'a Record;
135
136 fn next(&mut self) -> Option<Self::Item> {
137 self.0.next()
138 }
139}
140
141impl IntoIterator for Lookup {
143 type Item = RData;
144 type IntoIter = LookupIntoIter;
145
146 fn into_iter(self) -> Self::IntoIter {
149 LookupIntoIter {
150 records: Arc::clone(&self.records),
151 index: 0,
152 }
153 }
154}
155
156pub struct LookupIntoIter {
160 records: Arc<[Record]>,
162 index: usize,
163}
164
165impl Iterator for LookupIntoIter {
166 type Item = RData;
167
168 fn next(&mut self) -> Option<Self::Item> {
169 let rdata = self.records.get(self.index).and_then(Record::data);
170 self.index += 1;
171 rdata.cloned()
172 }
173}
174
175#[derive(Clone)]
177#[doc(hidden)]
178pub enum LookupEither<
179 C: DnsHandle<Error = ResolveError> + 'static,
180 P: ConnectionProvider<Conn = C> + 'static,
181> {
182 Retry(RetryDnsHandle<NameServerPool<C, P>>),
183 #[cfg(feature = "dnssec")]
184 #[cfg_attr(docsrs, doc(cfg(feature = "dnssec")))]
185 Secure(DnssecDnsHandle<RetryDnsHandle<NameServerPool<C, P>>>),
186}
187
188impl<C: DnsHandle<Error = ResolveError> + Sync, P: ConnectionProvider<Conn = C>> DnsHandle
189 for LookupEither<C, P>
190{
191 type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
192 type Error = ResolveError;
193
194 fn is_verifying_dnssec(&self) -> bool {
195 match *self {
196 Self::Retry(ref c) => c.is_verifying_dnssec(),
197 #[cfg(feature = "dnssec")]
198 Self::Secure(ref c) => c.is_verifying_dnssec(),
199 }
200 }
201
202 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
203 match *self {
204 Self::Retry(ref mut c) => c.send(request),
205 #[cfg(feature = "dnssec")]
206 Self::Secure(ref mut c) => c.send(request),
207 }
208 }
209}
210
211#[doc(hidden)]
213pub struct LookupFuture<C, E>
214where
215 C: DnsHandle<Error = E> + 'static,
216 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
217{
218 client_cache: CachingClient<C, E>,
219 names: Vec<Name>,
220 record_type: RecordType,
221 options: DnsRequestOptions,
222 query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
223}
224
225impl<C, E> LookupFuture<C, E>
226where
227 C: DnsHandle<Error = E> + 'static,
228 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
229{
230 #[doc(hidden)]
238 pub fn lookup(
239 mut names: Vec<Name>,
240 record_type: RecordType,
241 options: DnsRequestOptions,
242 mut client_cache: CachingClient<C, E>,
243 ) -> Self {
244 let name = names.pop().ok_or_else(|| {
245 ResolveError::from(ResolveErrorKind::Message("can not lookup for no names"))
246 });
247
248 let query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> = match name {
249 Ok(name) => client_cache
250 .lookup(Query::query(name, record_type), options)
251 .boxed(),
252 Err(err) => future::err(err).boxed(),
253 };
254
255 Self {
256 client_cache,
257 names,
258 record_type,
259 options,
260 query,
261 }
262 }
263}
264
265impl<C, E> Future for LookupFuture<C, E>
266where
267 C: DnsHandle<Error = E> + 'static,
268 E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
269{
270 type Output = Result<Lookup, ResolveError>;
271
272 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
273 loop {
274 let query = self.query.as_mut().poll_unpin(cx);
276
277 let should_retry = match query {
279 Poll::Pending => return Poll::Pending,
281 Poll::Ready(Ok(ref lookup)) => lookup.records.len() == 0,
285 Poll::Ready(Err(_)) => true,
287 };
288
289 if should_retry {
290 if let Some(name) = self.names.pop() {
291 let record_type = self.record_type;
292 let options = self.options;
293
294 self.query = self
297 .client_cache
298 .lookup(Query::query(name, record_type), options);
299 continue;
302 }
303 }
304 return query;
308 }
313 }
314}
315
316#[derive(Debug, Clone)]
318pub struct SrvLookup(Lookup);
319
320impl SrvLookup {
321 pub fn iter(&self) -> SrvLookupIter<'_> {
323 SrvLookupIter(self.0.iter())
324 }
325
326 pub fn query(&self) -> &Query {
328 self.0.query()
329 }
330
331 pub fn ip_iter(&self) -> LookupIpIter<'_> {
335 LookupIpIter(self.0.iter())
336 }
337
338 pub fn as_lookup(&self) -> &Lookup {
342 &self.0
343 }
344}
345
346impl From<Lookup> for SrvLookup {
347 fn from(lookup: Lookup) -> Self {
348 Self(lookup)
349 }
350}
351
352pub struct SrvLookupIter<'i>(LookupIter<'i>);
354
355impl<'i> Iterator for SrvLookupIter<'i> {
356 type Item = &'i rdata::SRV;
357
358 fn next(&mut self) -> Option<Self::Item> {
359 let iter: &mut _ = &mut self.0;
360 iter.filter_map(|rdata| match *rdata {
361 RData::SRV(ref data) => Some(data),
362 _ => None,
363 })
364 .next()
365 }
366}
367
368impl IntoIterator for SrvLookup {
369 type Item = rdata::SRV;
370 type IntoIter = SrvLookupIntoIter;
371
372 fn into_iter(self) -> Self::IntoIter {
375 SrvLookupIntoIter(self.0.into_iter())
376 }
377}
378
379pub struct SrvLookupIntoIter(LookupIntoIter);
381
382impl Iterator for SrvLookupIntoIter {
383 type Item = rdata::SRV;
384
385 fn next(&mut self) -> Option<Self::Item> {
386 let iter: &mut _ = &mut self.0;
387 iter.filter_map(|rdata| match rdata {
388 RData::SRV(data) => Some(data),
389 _ => None,
390 })
391 .next()
392 }
393}
394
/// Generates a typed wrapper around `Lookup` for a single `RData` variant.
///
/// * `$l`  - name of the wrapper type (a newtype over `Lookup`)
/// * `$i`  - name of the borrowing iterator type
/// * `$ii` - name of the owning iterator type
/// * `$r`  - path of the `RData` variant to extract, e.g. `RData::A`
/// * `$t`  - type carried by that variant; the iterators yield it
macro_rules! lookup_type {
    ($l:ident, $i:ident, $ii:ident, $r:path, $t:path) => {
        /// Typed view over a `Lookup` that exposes only one kind of record data.
        #[derive(Debug, Clone)]
        pub struct $l(Lookup);

        impl $l {
            /// Borrowing iterator over the matching record data.
            pub fn iter(&self) -> $i<'_> {
                $i(self.0.iter())
            }

            /// The query that produced this lookup.
            pub fn query(&self) -> &Query {
                self.0.query()
            }

            /// The instant after which this lookup should be treated as expired.
            pub fn valid_until(&self) -> Instant {
                self.0.valid_until()
            }

            /// The underlying untyped `Lookup`.
            pub fn as_lookup(&self) -> &Lookup {
                &self.0
            }
        }

        impl From<Lookup> for $l {
            fn from(lookup: Lookup) -> Self {
                Self(lookup)
            }
        }

        /// Borrowing iterator that yields only the matching record data.
        pub struct $i<'i>(LookupIter<'i>);

        impl<'i> Iterator for $i<'i> {
            type Item = &'i $t;

            fn next(&mut self) -> Option<Self::Item> {
                // Skip record data of any other variant.
                self.0.find_map(|rdata| match *rdata {
                    $r(ref data) => Some(data),
                    _ => None,
                })
            }
        }

        impl IntoIterator for $l {
            type Item = $t;
            type IntoIter = $ii;

            /// Consumes the lookup, yielding owned matching record data.
            fn into_iter(self) -> Self::IntoIter {
                $ii(self.0.into_iter())
            }
        }

        /// Owning iterator that yields only the matching record data.
        pub struct $ii(LookupIntoIter);

        impl Iterator for $ii {
            type Item = $t;

            fn next(&mut self) -> Option<Self::Item> {
                // Skip record data of any other variant.
                self.0.find_map(|rdata| match rdata {
                    $r(data) => Some(data),
                    _ => None,
                })
            }
        }
    };
}
476
// Typed lookup wrappers. Each invocation names the wrapper type, its
// borrowing and owning iterators, the `RData` variant to extract, and the
// type that variant carries (which the iterators yield).

// PTR records, yielding the pointed-to `Name`.
lookup_type!(
    ReverseLookup,
    ReverseLookupIter,
    ReverseLookupIntoIter,
    RData::PTR,
    Name
);
// A records, yielding `Ipv4Addr`.
lookup_type!(
    Ipv4Lookup,
    Ipv4LookupIter,
    Ipv4LookupIntoIter,
    RData::A,
    Ipv4Addr
);
// AAAA records, yielding `Ipv6Addr`.
lookup_type!(
    Ipv6Lookup,
    Ipv6LookupIter,
    Ipv6LookupIntoIter,
    RData::AAAA,
    Ipv6Addr
);
// MX records, yielding `rdata::MX`.
lookup_type!(
    MxLookup,
    MxLookupIter,
    MxLookupIntoIter,
    RData::MX,
    rdata::MX
);
// TLSA records, yielding `rdata::TLSA`.
lookup_type!(
    TlsaLookup,
    TlsaLookupIter,
    TlsaLookupIntoIter,
    RData::TLSA,
    rdata::TLSA
);
// TXT records, yielding `rdata::TXT`.
lookup_type!(
    TxtLookup,
    TxtLookupIter,
    TxtLookupIntoIter,
    RData::TXT,
    rdata::TXT
);
// SOA records, yielding `rdata::SOA`.
lookup_type!(
    SoaLookup,
    SoaLookupIter,
    SoaLookupIntoIter,
    RData::SOA,
    rdata::SOA
);
// NS records, yielding the name-server `Name`.
lookup_type!(NsLookup, NsLookupIter, NsLookupIntoIter, RData::NS, Name);
528
// Unit tests for `LookupFuture` and `LookupIntoIter`, plus a mock `DnsHandle`
// and canned responses. The helpers are `pub`, presumably so that other test
// modules can reuse them (not visible here).
#[cfg(test)]
pub mod tests {
    use std::net::{IpAddr, Ipv4Addr};
    use std::str::FromStr;
    use std::sync::{Arc, Mutex};

    use futures_executor::block_on;
    use futures_util::future;
    use futures_util::stream::once;

    use proto::op::{Message, Query};
    use proto::rr::{Name, RData, Record, RecordType};
    use proto::xfer::{DnsRequest, DnsRequestOptions};

    use super::*;
    use crate::error::ResolveError;

    /// Mock handle that replays a queue of prepared responses.
    #[derive(Clone)]
    pub struct MockDnsHandle {
        // Queue of responses. `pop` takes from the end, so the last element
        // is served first.
        messages: Arc<Mutex<Vec<Result<DnsResponse, ResolveError>>>>,
    }

    impl DnsHandle for MockDnsHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
        type Error = ResolveError;

        /// Ignores the request and returns a one-item stream holding the next
        /// queued response, or `empty()` once the queue is drained.
        fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response {
            Box::pin(once(
                future::ready(self.messages.lock().unwrap().pop().unwrap_or_else(empty)).boxed(),
            ))
        }
    }

    /// A response for the root name and type `A` that answers 127.0.0.1 with
    /// TTL 86400.
    pub fn v4_message() -> Result<DnsResponse, ResolveError> {
        let mut message = Message::new();
        message.add_query(Query::query(Name::root(), RecordType::A));
        message.insert_answers(vec![Record::from_rdata(
            Name::root(),
            86400,
            RData::A(Ipv4Addr::new(127, 0, 0, 1)),
        )]);

        let resp: DnsResponse = message.into();
        assert!(resp.contains_answer());
        Ok(resp)
    }

    /// A freshly constructed (`Message::new()`) response with nothing added.
    pub fn empty() -> Result<DnsResponse, ResolveError> {
        Ok(Message::new().into())
    }

    /// A failure built from a generic I/O error (`ErrorKind::Other`).
    pub fn error() -> Result<DnsResponse, ResolveError> {
        Err(ResolveError::from(ProtoError::from(std::io::Error::from(
            std::io::ErrorKind::Other,
        ))))
    }

    /// Builds a `MockDnsHandle` that serves `messages`, last element first.
    pub fn mock(messages: Vec<Result<DnsResponse, ResolveError>>) -> MockDnsHandle {
        MockDnsHandle {
            messages: Arc::new(Mutex::new(messages)),
        }
    }

    // A successful A lookup is exposed through `Lookup::iter`.
    #[test]
    fn test_lookup() {
        assert_eq!(
            block_on(LookupFuture::lookup(
                vec![Name::root()],
                RecordType::A,
                DnsRequestOptions::default(),
                CachingClient::new(0, mock(vec![v4_message()]), false),
            ))
            .unwrap()
            .iter()
            .map(|r| r.to_ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![Ipv4Addr::new(127, 0, 0, 1)]
        );
    }

    // The same result is reachable through `Lookup::records`.
    #[test]
    fn test_lookup_slice() {
        assert_eq!(
            Record::data(
                &block_on(LookupFuture::lookup(
                    vec![Name::root()],
                    RecordType::A,
                    DnsRequestOptions::default(),
                    CachingClient::new(0, mock(vec![v4_message()]), false),
                ))
                .unwrap()
                .records()[0]
            )
            .unwrap()
            .to_ip_addr()
            .unwrap(),
            Ipv4Addr::new(127, 0, 0, 1)
        );
    }

    // The same result is reachable through the owning `IntoIterator` impl.
    #[test]
    fn test_lookup_into_iter() {
        assert_eq!(
            block_on(LookupFuture::lookup(
                vec![Name::root()],
                RecordType::A,
                DnsRequestOptions::default(),
                CachingClient::new(0, mock(vec![v4_message()]), false),
            ))
            .unwrap()
            .into_iter()
            .map(|r| r.to_ip_addr().unwrap())
            .collect::<Vec<IpAddr>>(),
            vec![Ipv4Addr::new(127, 0, 0, 1)]
        );
    }

    // With a single name, a transport error surfaces as an error.
    #[test]
    fn test_error() {
        assert!(block_on(LookupFuture::lookup(
            vec![Name::root()],
            RecordType::A,
            DnsRequestOptions::default(),
            CachingClient::new(0, mock(vec![error()]), false),
        ))
        .is_err());
    }

    // An empty response surfaces as `NoRecordsFound` for the original query,
    // with no negative TTL.
    #[test]
    fn test_empty_no_response() {
        if let ResolveErrorKind::NoRecordsFound {
            query,
            negative_ttl,
            ..
        } = block_on(LookupFuture::lookup(
            vec![Name::root()],
            RecordType::A,
            DnsRequestOptions::default(),
            CachingClient::new(0, mock(vec![empty()]), false),
        ))
        .unwrap_err()
        .kind()
        {
            assert_eq!(**query, Query::query(Name::root(), RecordType::A));
            assert_eq!(*negative_ttl, None);
        } else {
            panic!("wrong error recieved");
        }
    }

    // `LookupIntoIter` yields every record's data in order, then `None`.
    #[test]
    fn test_lookup_into_iter_arc() {
        let mut lookup = LookupIntoIter {
            records: Arc::from([
                Record::from_rdata(
                    Name::from_str("www.example.com.").unwrap(),
                    80,
                    RData::A(Ipv4Addr::new(127, 0, 0, 1)),
                ),
                Record::from_rdata(
                    Name::from_str("www.example.com.").unwrap(),
                    80,
                    RData::A(Ipv4Addr::new(127, 0, 0, 2)),
                ),
            ]),
            index: 0,
        };

        assert_eq!(
            lookup.next().unwrap(),
            RData::A(Ipv4Addr::new(127, 0, 0, 1))
        );
        assert_eq!(
            lookup.next().unwrap(),
            RData::A(Ipv4Addr::new(127, 0, 0, 2))
        );
        assert_eq!(lookup.next(), None);
    }
}