trust_dns_proto/xfer/
retry_dns_handle.rs1use std::pin::Pin;
11use std::task::{Context, Poll};
12
13use futures_util::stream::{Stream, StreamExt};
14
15use crate::error::{ProtoError, ProtoErrorKind};
16use crate::xfer::{DnsRequest, DnsResponse};
17use crate::DnsHandle;
18
/// A [`DnsHandle`] wrapper that re-sends a query when the wrapped handle reports an error.
///
/// Whether an error leads to a retry, and whether that retry consumes one of the configured
/// `attempts`, is decided by the error type's [`RetryableError`] implementation.
///
/// The struct itself carries no trait bounds; they live on the `impl` blocks that need them,
/// so e.g. the derived `Clone` only requires `H: Clone`.
#[derive(Clone)]
#[must_use = "queries can only be sent through a ClientHandle"]
pub struct RetryDnsHandle<H> {
    /// Handle every query (and every retry) is sent through.
    handle: H,
    /// Number of counted retries allowed after the first send fails.
    attempts: usize,
}
41
42impl<H> RetryDnsHandle<H>
43where
44 H: DnsHandle + Unpin + Send,
45 H::Error: RetryableError,
46{
47 pub fn new(handle: H, attempts: usize) -> Self {
54 Self { handle, attempts }
55 }
56
57 pub fn handle(&self) -> &H {
59 &self.handle
60 }
61}
62
63impl<H> DnsHandle for RetryDnsHandle<H>
64where
65 H: DnsHandle + Send + Unpin + 'static,
66 H::Error: RetryableError,
67{
68 type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, Self::Error>> + Send + Unpin>>;
69 type Error = <H as DnsHandle>::Error;
70
71 fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
72 let request = request.into();
73
74 let stream = self.handle.send(request.clone());
77
78 Box::pin(RetrySendStream {
79 request,
80 handle: self.handle.clone(),
81 stream,
82 remaining_attempts: self.attempts,
83 })
84 }
85}
86
87struct RetrySendStream<H>
89where
90 H: DnsHandle,
91{
92 request: DnsRequest,
93 handle: H,
94 stream: <H as DnsHandle>::Response,
95 remaining_attempts: usize,
96}
97
98impl<H: DnsHandle + Unpin> Stream for RetrySendStream<H>
99where
100 <H as DnsHandle>::Error: RetryableError,
101{
102 type Item = Result<DnsResponse, <H as DnsHandle>::Error>;
103
104 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
105 loop {
108 match self.stream.poll_next_unpin(cx) {
109 Poll::Ready(Some(Err(e))) => {
110 if self.remaining_attempts == 0 || !e.should_retry() {
111 return Poll::Ready(Some(Err(e)));
112 }
113
114 if e.attempted() {
115 self.remaining_attempts -= 1;
116 }
117
118 let request = self.request.clone();
121 self.stream = self.handle.send(request);
122 }
123 poll => return poll,
124 }
125 }
126 }
127}
128
/// Classifies errors for `RetryDnsHandle`.
///
/// An implementation decides whether a failed query should be sent again and whether the
/// failure consumes one of the handle's retry attempts.
pub trait RetryableError {
    /// Returns `true` if the failed query should be sent again, subject to the remaining
    /// attempt budget.
    fn should_retry(&self) -> bool;

    /// Returns `true` if this failure counts against the retry budget. `false` means a retry
    /// does not consume an attempt.
    fn attempted(&self) -> bool;
}
136
137impl RetryableError for ProtoError {
138 fn should_retry(&self) -> bool {
139 true
140 }
141
142 fn attempted(&self) -> bool {
143 !matches!(self.kind(), ProtoErrorKind::Busy)
144 }
145}
146
#[cfg(test)]
mod test {
    // `DnsHandle` comes in through `super::*`; the former `use DnsHandle;` was redundant.
    use super::*;
    use crate::error::*;
    use crate::op::*;
    use crate::xfer::FirstAnswer;
    use futures_executor::block_on;
    use futures_util::future::*;
    use futures_util::stream::*;
    use std::sync::{
        atomic::{AtomicU16, Ordering},
        Arc,
    };

    /// Mock handle that fails until `retries` failures have been recorded in `attempts`.
    ///
    /// After that, if `last_succeed` is set, it answers with a message whose id equals the
    /// number of failures seen. `attempts` is shared across clones, so it counts every
    /// failed send made through any copy of the handle.
    #[derive(Clone)]
    struct TestClient {
        last_succeed: bool,
        retries: u16,
        attempts: Arc<AtomicU16>,
    }

    impl DnsHandle for TestClient {
        type Response = Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin>;
        type Error = ProtoError;

        fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response {
            let i = self.attempts.load(Ordering::SeqCst);

            if i >= self.retries && self.last_succeed {
                let mut message = Message::new();
                message.set_id(i);
                return Box::new(once(ok(message.into())));
            }

            self.attempts.fetch_add(1, Ordering::SeqCst);
            Box::new(once(err(ProtoError::from("last retry set to fail"))))
        }
    }

    #[test]
    fn test_retry() {
        let attempts = Arc::new(AtomicU16::new(0));
        let mut handle = RetryDnsHandle::new(
            TestClient {
                last_succeed: true,
                retries: 1,
                attempts: Arc::clone(&attempts),
            },
            2,
        );
        let test1 = Message::new();
        let result = block_on(handle.send(test1).first_answer()).expect("should have succeeded");
        assert_eq!(result.id(), 1);
        // Exactly one failed send preceded the successful one.
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_error() {
        let attempts = Arc::new(AtomicU16::new(0));
        let mut client = RetryDnsHandle::new(
            TestClient {
                last_succeed: false,
                retries: 1,
                attempts: Arc::clone(&attempts),
            },
            2,
        );
        let test1 = Message::new();
        assert!(block_on(client.send(test1).first_answer()).is_err());
        // The initial send plus the two permitted retries, all failing.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }
}