Skip to main content

trust_dns_proto/xfer/
retry_dns_handle.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! `RetryDnsHandle` allows for DnsQueries to be reattempted on failure
9
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13use futures_util::stream::{Stream, StreamExt};
14
15use crate::error::{ProtoError, ProtoErrorKind};
16use crate::xfer::{DnsRequest, DnsResponse};
17use crate::DnsHandle;
18
19/// Can be used to reattempt queries if they fail
20///
21/// Note: this does not reattempt queries that fail with a negative response.
22/// For example, if a query gets a `NODATA` response from a name server, the
23/// query will not be retried. It only reattempts queries that effectively
24/// failed to get a response, such as queries that resulted in IO or timeout
25/// errors.
26///
27/// Whether an error is retryable by the [`RetryDnsHandle`] is determined by the
28/// [`RetryableError`] trait.
29///
30/// *note* Current value of this is not clear, it may be removed
31#[derive(Clone)]
32#[must_use = "queries can only be sent through a ClientHandle"]
33pub struct RetryDnsHandle<H>
34where
35    H: DnsHandle + Unpin + Send,
36    H::Error: RetryableError,
37{
38    handle: H,
39    attempts: usize,
40}
41
42impl<H> RetryDnsHandle<H>
43where
44    H: DnsHandle + Unpin + Send,
45    H::Error: RetryableError,
46{
47    /// Creates a new Client handler for reattempting requests on failures.
48    ///
49    /// # Arguments
50    ///
51    /// * `handle` - handle to the dns connection
52    /// * `attempts` - number of attempts before failing
53    pub fn new(handle: H, attempts: usize) -> Self {
54        Self { handle, attempts }
55    }
56
57    /// Returns a shared reference to the underlying handle.
58    pub fn handle(&self) -> &H {
59        &self.handle
60    }
61}
62
63impl<H> DnsHandle for RetryDnsHandle<H>
64where
65    H: DnsHandle + Send + Unpin + 'static,
66    H::Error: RetryableError,
67{
68    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, Self::Error>> + Send + Unpin>>;
69    type Error = <H as DnsHandle>::Error;
70
71    fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
72        let request = request.into();
73
74        // need to clone here so that the retry can resend if necessary...
75        //  obviously it would be nice to be lazy about this...
76        let stream = self.handle.send(request.clone());
77
78        Box::pin(RetrySendStream {
79            request,
80            handle: self.handle.clone(),
81            stream,
82            remaining_attempts: self.attempts,
83        })
84    }
85}
86
87/// A stream for retrying (on failure, for the remaining number of times specified)
88struct RetrySendStream<H>
89where
90    H: DnsHandle,
91{
92    request: DnsRequest,
93    handle: H,
94    stream: <H as DnsHandle>::Response,
95    remaining_attempts: usize,
96}
97
98impl<H: DnsHandle + Unpin> Stream for RetrySendStream<H>
99where
100    <H as DnsHandle>::Error: RetryableError,
101{
102    type Item = Result<DnsResponse, <H as DnsHandle>::Error>;
103
104    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
105        // loop over the stream, on errors, spawn a new stream
106        //  on ready and not ready return.
107        loop {
108            match self.stream.poll_next_unpin(cx) {
109                Poll::Ready(Some(Err(e))) => {
110                    if self.remaining_attempts == 0 || !e.should_retry() {
111                        return Poll::Ready(Some(Err(e)));
112                    }
113
114                    if e.attempted() {
115                        self.remaining_attempts -= 1;
116                    }
117
118                    // TODO: if the "sent" Message is part of the error result,
119                    //  then we can just reuse it... and no clone necessary
120                    let request = self.request.clone();
121                    self.stream = self.handle.send(request);
122                }
123                poll => return poll,
124            }
125        }
126    }
127}
128
/// What errors should be retried, and whether a failure counts as an attempt.
///
/// [`RetryDnsHandle`] consults both methods each time an attempt fails.
pub trait RetryableError {
    /// Returns `true` if the query should be sent again after this error.
    fn should_retry(&self) -> bool;

    /// Returns `true` if this error consumes one of the remaining attempts;
    /// `false` allows a resend without touching the retry budget.
    fn attempted(&self) -> bool;
}
136
137impl RetryableError for ProtoError {
138    fn should_retry(&self) -> bool {
139        true
140    }
141
142    fn attempted(&self) -> bool {
143        !matches!(self.kind(), ProtoErrorKind::Busy)
144    }
145}
146
#[cfg(test)]
mod test {
    // `use super::*` also brings the parent's `DnsHandle` import into scope.
    use super::*;
    use crate::error::*;
    use crate::op::*;
    use crate::xfer::FirstAnswer;
    use futures_executor::block_on;
    use futures_util::future::*;
    use futures_util::stream::*;
    use std::sync::{
        atomic::{AtomicU16, Ordering},
        Arc,
    };

    /// Fails until `retries` failures have been recorded in `attempts`, then
    /// (if `last_succeed`) answers with a message whose id is that count.
    #[derive(Clone)]
    struct TestClient {
        last_succeed: bool,
        retries: u16,
        attempts: Arc<AtomicU16>,
    }

    impl DnsHandle for TestClient {
        type Response = Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin>;
        type Error = ProtoError;

        fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response {
            let i = self.attempts.load(Ordering::SeqCst);

            if i >= self.retries && self.last_succeed {
                let mut message = Message::new();
                message.set_id(i);
                return Box::new(once(ok(message.into())));
            }

            self.attempts.fetch_add(1, Ordering::SeqCst);
            Box::new(once(err(ProtoError::from("last retry set to fail"))))
        }
    }

    /// Answers `Busy` for the first `busy_responses` sends, then succeeds with a
    /// message whose id is the zero-based index of the successful send.
    #[derive(Clone)]
    struct BusyClient {
        busy_responses: u16,
        sends: Arc<AtomicU16>,
    }

    impl DnsHandle for BusyClient {
        type Response = Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin>;
        type Error = ProtoError;

        fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response {
            let i = self.sends.fetch_add(1, Ordering::SeqCst);

            if i < self.busy_responses {
                return Box::new(once(err(ProtoError::from(ProtoErrorKind::Busy))));
            }

            let mut message = Message::new();
            message.set_id(i);
            Box::new(once(ok(message.into())))
        }
    }

    #[test]
    fn test_retry() {
        let mut handle = RetryDnsHandle::new(
            TestClient {
                last_succeed: true,
                retries: 1,
                attempts: Arc::new(AtomicU16::new(0)),
            },
            2,
        );
        let test1 = Message::new();
        let result = block_on(handle.send(test1).first_answer()).expect("should have succeeded");
        assert_eq!(result.id(), 1); // this is checking the number of iterations the TestClient ran
    }

    #[test]
    fn test_error() {
        let attempts = Arc::new(AtomicU16::new(0));
        let mut client = RetryDnsHandle::new(
            TestClient {
                last_succeed: false,
                retries: 1,
                attempts: Arc::clone(&attempts),
            },
            2,
        );
        let test1 = Message::new();
        assert!(block_on(client.send(test1).first_answer()).is_err());
        // one initial send plus two retries
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_busy_does_not_consume_attempts() {
        let sends = Arc::new(AtomicU16::new(0));
        let mut handle = RetryDnsHandle::new(
            BusyClient {
                busy_responses: 3,
                sends: Arc::clone(&sends),
            },
            1,
        );
        let result = block_on(handle.send(Message::new()).first_answer())
            .expect("busy errors should not exhaust the retry budget");
        // three Busy failures were resent for free, the fourth send succeeded
        assert_eq!(result.id(), 3);
        assert_eq!(sends.load(Ordering::SeqCst), 4);
    }
}