trust_dns_resolver/name_server/
name_server.rs1use std::cmp::Ordering;
9use std::fmt::{self, Debug, Formatter};
10use std::pin::Pin;
11use std::sync::Arc;
12use std::time::Instant;
13
14use futures_util::lock::Mutex;
15use futures_util::stream::{once, Stream};
16
17#[cfg(feature = "mdns")]
18use proto::multicast::MDNS_IPV4;
19use proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
20use tracing::debug;
21
22#[cfg(feature = "mdns")]
23use crate::config::Protocol;
24use crate::config::{NameServerConfig, ResolverOpts};
25use crate::error::ResolveError;
26use crate::name_server::{ConnectionProvider, NameServerState, NameServerStats};
27#[cfg(feature = "tokio-runtime")]
28use crate::name_server::{TokioConnection, TokioConnectionProvider, TokioHandle};
29
30#[derive(Clone)]
32pub struct NameServer<
33 C: DnsHandle<Error = ResolveError> + Send,
34 P: ConnectionProvider<Conn = C> + Send,
35> {
36 config: NameServerConfig,
37 options: ResolverOpts,
38 client: Arc<Mutex<Option<C>>>,
39 state: Arc<NameServerState>,
40 stats: Arc<NameServerStats>,
41 conn_provider: P,
42}
43
44impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> Debug
45 for NameServer<C, P>
46{
47 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
48 write!(f, "config: {:?}, options: {:?}", self.config, self.options)
49 }
50}
51
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
impl NameServer<TokioConnection, TokioConnectionProvider> {
    /// Construct a name server that connects through a [`TokioConnectionProvider`]
    /// built from `runtime`.
    ///
    /// No connection is opened here; see `new_with_provider`.
    pub fn new(config: NameServerConfig, options: ResolverOpts, runtime: TokioHandle) -> Self {
        let provider = TokioConnectionProvider::new(runtime);
        Self::new_with_provider(config, options, provider)
    }
}
60
61impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> NameServer<C, P> {
62 pub fn new_with_provider(
64 config: NameServerConfig,
65 options: ResolverOpts,
66 conn_provider: P,
67 ) -> Self {
68 Self {
69 config,
70 options,
71 client: Arc::new(Mutex::new(None)),
72 state: Arc::new(NameServerState::init(None)),
73 stats: Arc::new(NameServerStats::default()),
74 conn_provider,
75 }
76 }
77
78 #[doc(hidden)]
79 pub fn from_conn(
80 config: NameServerConfig,
81 options: ResolverOpts,
82 client: C,
83 conn_provider: P,
84 ) -> Self {
85 Self {
86 config,
87 options,
88 client: Arc::new(Mutex::new(Some(client))),
89 state: Arc::new(NameServerState::init(None)),
90 stats: Arc::new(NameServerStats::default()),
91 conn_provider,
92 }
93 }
94
95 #[cfg(test)]
96 #[allow(dead_code)]
97 pub(crate) fn is_connected(&self) -> bool {
98 !self.state.is_failed()
99 && if let Some(client) = self.client.try_lock() {
100 client.is_some()
101 } else {
102 true
104 }
105 }
106
107 async fn connected_mut_client(&mut self) -> Result<C, ResolveError> {
111 let mut client = self.client.lock().await;
112
113 if self.state.is_failed() || client.is_none() {
115 debug!("reconnecting: {:?}", self.config);
116
117 self.state.reinit(None);
119
120 let new_client = self
121 .conn_provider
122 .new_connection(&self.config, &self.options)
123 .await?;
124
125 *client = Some(new_client);
127 } else {
128 debug!("existing connection: {:?}", self.config);
129 }
130
131 Ok((*client)
132 .clone()
133 .expect("bad state, client should be connected"))
134 }
135
136 async fn inner_send<R: Into<DnsRequest> + Unpin + Send + 'static>(
137 mut self,
138 request: R,
139 ) -> Result<DnsResponse, ResolveError> {
140 let mut client = self.connected_mut_client().await?;
141 let response = client.send(request).first_answer().await;
142
143 match response {
144 Ok(response) => {
145 let response =
147 ResolveError::from_response(response, self.config.trust_nx_responses)?;
148
149 let remote_edns = response.extensions().clone();
151
152 self.state.establish(remote_edns);
154
155 self.stats.next_success();
157 Ok(response)
158 }
159 Err(error) => {
160 debug!("name_server connection failure: {}", error);
161
162 self.state.fail(Instant::now());
164
165 self.stats.next_failure();
167
168 Err(error)
170 }
171 }
172 }
173
174 pub fn trust_nx_responses(&self) -> bool {
176 self.config.trust_nx_responses
177 }
178}
179
180impl<C, P> DnsHandle for NameServer<C, P>
181where
182 C: DnsHandle<Error = ResolveError>,
183 P: ConnectionProvider<Conn = C>,
184{
185 type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
186 type Error = ResolveError;
187
188 fn is_verifying_dnssec(&self) -> bool {
189 self.options.validate
190 }
191
192 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
194 let this = self.clone();
195 Box::pin(once(this.inner_send(request)))
197 }
198}
199
200impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> Ord for NameServer<C, P> {
201 fn cmp(&self, other: &Self) -> Ordering {
203 if self == other {
205 return Ordering::Equal;
206 }
207
208 match self.state.cmp(&other.state) {
213 Ordering::Equal => (),
214 o => {
215 return o;
216 }
217 }
218
219 self.stats.cmp(&other.stats)
220 }
221}
222
223impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> PartialOrd
224 for NameServer<C, P>
225{
226 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
227 Some(self.cmp(other))
228 }
229}
230
231impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> PartialEq
232 for NameServer<C, P>
233{
234 fn eq(&self, other: &Self) -> bool {
236 self.config == other.config
237 }
238}
239
240impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> Eq for NameServer<C, P> {}
241
/// Construct a name server that targets the IPv4 multicast DNS address using the
/// mDNS protocol, with no TLS name and no explicit bind address.
#[cfg(feature = "mdns")]
pub(crate) fn mdns_nameserver<C, P>(
    options: ResolverOpts,
    conn_provider: P,
    trust_nx_responses: bool,
) -> NameServer<C, P>
where
    C: DnsHandle<Error = ResolveError>,
    P: ConnectionProvider<Conn = C>,
{
    NameServer::new_with_provider(
        NameServerConfig {
            socket_addr: *MDNS_IPV4,
            protocol: Protocol::Mdns,
            tls_dns_name: None,
            trust_nx_responses,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        },
        options,
        conn_provider,
    )
}
264
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::time::Duration;

    use futures_util::{future, FutureExt};
    use tokio::runtime::Runtime;

    use proto::op::{Query, ResponseCode};
    use proto::rr::{Name, RecordType};
    use proto::xfer::{DnsHandle, DnsRequestOptions, FirstAnswer};

    use super::*;
    use crate::config::Protocol;

    /// Plain UDP configuration for `socket_addr`, without TLS or a bind address.
    fn udp_config(socket_addr: SocketAddr) -> NameServerConfig {
        NameServerConfig {
            socket_addr,
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        }
    }

    /// Builds the name server inside a fresh Tokio runtime, sends one `A` query for
    /// `www.example.com.` and returns the first answer.
    fn query_example_a(
        config: NameServerConfig,
        options: ResolverOpts,
    ) -> Result<DnsResponse, ResolveError> {
        let io_loop = Runtime::new().unwrap();
        let name_server = future::lazy(|_| {
            NameServer::<_, TokioConnectionProvider>::new(config, options, TokioHandle)
        });

        let name = Name::parse("www.example.com.", None).unwrap();
        io_loop.block_on(name_server.then(|mut name_server| {
            name_server
                .lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer()
        }))
    }

    #[test]
    fn test_name_server() {
        let google = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53);

        let response =
            query_example_a(udp_config(google), ResolverOpts::default()).expect("query failed");

        assert_eq!(response.response_code(), ResponseCode::NoError);
    }

    #[test]
    fn test_failed_name_server() {
        let options = ResolverOpts {
            timeout: Duration::from_millis(1),
            ..ResolverOpts::default()
        };
        let unreachable = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 252);

        assert!(query_example_a(udp_config(unreachable), options).is_err());
    }
}