Skip to main content

trust_dns_resolver/name_server/
name_server.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::cmp::Ordering;
9use std::fmt::{self, Debug, Formatter};
10use std::pin::Pin;
11use std::sync::Arc;
12use std::time::Instant;
13
14use futures_util::lock::Mutex;
15use futures_util::stream::{once, Stream};
16
17#[cfg(feature = "mdns")]
18use proto::multicast::MDNS_IPV4;
19use proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
20use tracing::debug;
21
22#[cfg(feature = "mdns")]
23use crate::config::Protocol;
24use crate::config::{NameServerConfig, ResolverOpts};
25use crate::error::ResolveError;
26use crate::name_server::{
27    ConnectionProvider, InternalNameServerStats, NameServerState, NameServerStats,
28};
29#[cfg(feature = "tokio-runtime")]
30use crate::name_server::{TokioConnection, TokioConnectionProvider, TokioHandle};
31
32/// Specifies the details of a remote NameServer used for lookups
33#[derive(Clone)]
34pub struct NameServer<
35    C: DnsHandle<Error = ResolveError> + Send,
36    P: ConnectionProvider<Conn = C> + Send,
37> {
38    config: NameServerConfig,
39    options: ResolverOpts,
40    client: Arc<Mutex<Option<C>>>,
41    state: Arc<NameServerState>,
42    stats: Arc<std::sync::Mutex<InternalNameServerStats>>,
43    conn_provider: P,
44}
45
46impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> Debug
47    for NameServer<C, P>
48{
49    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
50        write!(f, "config: {:?}, options: {:?}", self.config, self.options)
51    }
52}
53
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
impl NameServer<TokioConnection, TokioConnectionProvider> {
    /// A shortcut for constructing a nameserver usable in the Tokio runtime.
    ///
    /// Equivalent to calling `new_with_provider` with a `TokioConnectionProvider`
    /// built from `runtime`.
    pub fn new(config: NameServerConfig, options: ResolverOpts, runtime: TokioHandle) -> Self {
        let provider = TokioConnectionProvider::new(runtime);
        Self::new_with_provider(config, options, provider)
    }
}
62
63impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> NameServer<C, P> {
64    /// Construct a new Nameserver with the configuration and options. The connection provider will create UDP and TCP sockets
65    pub fn new_with_provider(
66        config: NameServerConfig,
67        options: ResolverOpts,
68        conn_provider: P,
69    ) -> Self {
70        let retained_errors = config.num_retained_errors;
71        Self {
72            config,
73            options,
74            client: Arc::new(Mutex::new(None)),
75            state: Arc::new(NameServerState::init(None)),
76            stats: Arc::new(std::sync::Mutex::new(InternalNameServerStats::new(retained_errors))),
77            conn_provider,
78        }
79    }
80
81    #[doc(hidden)]
82    pub fn from_conn(
83        config: NameServerConfig,
84        options: ResolverOpts,
85        client: C,
86        conn_provider: P,
87    ) -> Self {
88        let retained_errors = config.num_retained_errors;
89        Self {
90            config,
91            options,
92            client: Arc::new(Mutex::new(Some(client))),
93            state: Arc::new(NameServerState::init(None)),
94            stats: Arc::new(std::sync::Mutex::new(InternalNameServerStats::new(retained_errors))),
95            conn_provider,
96        }
97    }
98
99    #[cfg(test)]
100    #[allow(dead_code)]
101    pub(crate) fn is_connected(&self) -> bool {
102        !self.state.is_failed()
103            && if let Some(client) = self.client.try_lock() {
104                client.is_some()
105            } else {
106                // assuming that if someone has it locked it will be or is connected
107                true
108            }
109    }
110
111    /// This will return a mutable client to allows for sending messages.
112    ///
113    /// If the connection is in a failed state, then this will establish a new connection
114    async fn connected_mut_client(&mut self) -> Result<C, ResolveError> {
115        let mut client = self.client.lock().await;
116
117        // if this is in a failure state
118        if self.state.is_failed() || client.is_none() {
119            debug!("reconnecting: {:?}", self.config);
120
121            // TODO: we need the local EDNS options
122            self.state.reinit(None);
123
124            let new_client = self
125                .conn_provider
126                .new_connection(&self.config, &self.options)
127                .await?;
128
129            // establish a new connection
130            *client = Some(new_client);
131        } else {
132            debug!("existing connection: {:?}", self.config);
133        }
134
135        Ok((*client)
136            .clone()
137            .expect("bad state, client should be connected"))
138    }
139
140    async fn inner_send<R: Into<DnsRequest> + Unpin + Send + 'static>(
141        mut self,
142        request: R,
143    ) -> Result<DnsResponse, ResolveError> {
144        let mut client = self.connected_mut_client().await?;
145        let response = client.send(request).first_answer().await;
146
147        match response {
148            Ok(response) => {
149                // First evaluate if the message succeeded.
150                let response =
151                    match ResolveError::from_response(response, self.config.trust_nx_responses) {
152                        Ok(response) => response,
153                        Err(e) => {
154                            self.stats.lock().unwrap().next_failure(e.kind.clone());
155                            return Err(e);
156                        }
157                    };
158
159                // TODO: consider making message::take_edns...
160                let remote_edns = response.extensions().clone();
161
162                // take the remote edns options and store them
163                self.state.establish(remote_edns);
164
165                // record the success
166                self.stats.lock().unwrap().next_success();
167                Ok(response)
168            }
169            Err(error) => {
170                debug!("name_server connection failure: {}", error);
171
172                // this transitions the state to failure
173                self.state.fail(Instant::now());
174
175                // record the failure
176                self.stats.lock().unwrap().next_failure(error.kind.clone());
177
178                // These are connection failures, not lookup failures, that is handled in the resolver layer
179                Err(error)
180            }
181        }
182    }
183
184    /// Specifies that thie NameServer will treat negative responses as permanent failures and will not retry
185    pub fn trust_nx_responses(&self) -> bool {
186        self.config.trust_nx_responses
187    }
188
189    pub fn stats(&self) -> NameServerStats {
190        self.stats.lock().unwrap().export(self.config.socket_addr, self.config.protocol)
191    }
192}
193
194impl<C, P> DnsHandle for NameServer<C, P>
195where
196    C: DnsHandle<Error = ResolveError>,
197    P: ConnectionProvider<Conn = C>,
198{
199    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
200    type Error = ResolveError;
201
202    fn is_verifying_dnssec(&self) -> bool {
203        self.options.validate
204    }
205
206    // TODO: there needs to be some way of customizing the connection based on EDNS options from the server side...
207    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
208        let this = self.clone();
209        // if state is failed, return future::err(), unless retry delay expired..
210        Box::pin(once(this.inner_send(request)))
211    }
212}
213
214impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> Ord for NameServer<C, P> {
215    /// Custom implementation of Ord for NameServer which incorporates the performance of the connection into it's ranking
216    fn cmp(&self, other: &Self) -> Ordering {
217        // if they are literally equal, just return
218        if self == other {
219            return Ordering::Equal;
220        }
221
222        // otherwise, run our evaluation to determine the next to be returned from the Heap
223        //   this will prefer established connections, we should try other connections after
224        //   some number to make sure that all are used. This is more important for when
225        //   latency is started to be used.
226        match self.state.cmp(&other.state) {
227            Ordering::Equal => (),
228            o => {
229                return o;
230            }
231        }
232
233        // Avoid deadlock by cloning data out from under the locks.
234        let self_stats = self.stats.lock().unwrap().clone();
235        let other_stats = other.stats.lock().unwrap().clone();
236
237        self_stats.cmp(&other_stats)
238    }
239}
240
241impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> PartialOrd
242    for NameServer<C, P>
243{
244    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
245        Some(self.cmp(other))
246    }
247}
248
249impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> PartialEq
250    for NameServer<C, P>
251{
252    /// NameServers are equal if the config (connection information) are equal
253    fn eq(&self, other: &Self) -> bool {
254        self.config == other.config
255    }
256}
257
258impl<C: DnsHandle<Error = ResolveError>, P: ConnectionProvider<Conn = C>> Eq for NameServer<C, P> {}
259
// TODO: once IPv6 is better understood, also make this a binary keep.
/// Builds the name server used for multicast DNS queries.
///
/// It targets the IPv4 mDNS address `MDNS_IPV4` over `Protocol::Mdns`, with no TLS name,
/// no TLS config and no explicit bind address. `trust_nx_responses` is passed through.
#[cfg(feature = "mdns")]
pub(crate) fn mdns_nameserver<C, P>(
    options: ResolverOpts,
    conn_provider: P,
    trust_nx_responses: bool,
) -> NameServer<C, P>
where
    C: DnsHandle<Error = ResolveError>,
    P: ConnectionProvider<Conn = C>,
{
    // NOTE(review): `NameServer::new_with_provider` reads `config.num_retained_errors`,
    // but this literal does not set that field. If it exists on `NameServerConfig`, this
    // `mdns`-only path will not compile. Confirm its type/default and set it explicitly.
    let mdns_config = NameServerConfig {
        socket_addr: *MDNS_IPV4,
        protocol: Protocol::Mdns,
        tls_dns_name: None,
        trust_nx_responses,
        #[cfg(feature = "dns-over-rustls")]
        tls_config: None,
        bind_addr: None,
    };

    NameServer::new_with_provider(mdns_config, options, conn_provider)
}
282
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::time::Duration;

    use futures_util::{future, FutureExt};
    use tokio::runtime::Runtime;

    use proto::op::{Query, ResponseCode};
    use proto::rr::{Name, RecordType};
    use proto::xfer::{DnsHandle, DnsRequestOptions, FirstAnswer};

    use super::*;
    use crate::config::Protocol;

    /// Plain-UDP, non-TLS configuration for `ip:port` that does not trust NX responses.
    ///
    /// NOTE(review): `num_retained_errors` is read by the `NameServer` constructors but is
    /// not set here. If `NameServerConfig` has that field, this will not compile; confirm.
    fn udp_config(ip: Ipv4Addr, port: u16) -> NameServerConfig {
        NameServerConfig {
            socket_addr: SocketAddr::new(IpAddr::V4(ip), port),
            protocol: Protocol::Udp,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        }
    }

    /// The single query both tests send: `www.example.com.` / `A`.
    fn example_query() -> Query {
        let name = Name::parse("www.example.com.", None).unwrap();
        Query::query(name, RecordType::A)
    }

    /// Needs outbound UDP access to 8.8.8.8:53; fails without network connectivity.
    #[test]
    fn test_name_server() {
        //env_logger::try_init().ok();

        let config = udp_config(Ipv4Addr::new(8, 8, 8, 8), 53);
        let io_loop = Runtime::new().unwrap();
        let runtime_handle = TokioHandle;

        // Build the server lazily so it is constructed while the runtime drives the future.
        let name_server = future::lazy(|_| {
            NameServer::<_, TokioConnectionProvider>::new(
                config,
                ResolverOpts::default(),
                runtime_handle,
            )
        });

        let query = example_query();
        let lookup = name_server.then(|mut name_server| {
            name_server
                .lookup(query, DnsRequestOptions::default())
                .first_answer()
        });

        let response = io_loop.block_on(lookup).expect("query failed");
        assert_eq!(response.response_code(), ResponseCode::NoError);
    }

    /// Nothing is expected to answer on 127.0.0.252:252, so the lookup must error
    /// (a 1ms timeout keeps the test fast).
    #[test]
    fn test_failed_name_server() {
        let options = ResolverOpts {
            timeout: Duration::from_millis(1), // this is going to fail, make it fail fast...
            ..ResolverOpts::default()
        };
        let config = udp_config(Ipv4Addr::new(127, 0, 0, 252), 252);
        let io_loop = Runtime::new().unwrap();
        let runtime_handle = TokioHandle;

        let name_server = future::lazy(|_| {
            NameServer::<_, TokioConnectionProvider>::new(config, options, runtime_handle)
        });

        let query = example_query();
        let lookup = name_server.then(|mut name_server| {
            name_server
                .lookup(query, DnsRequestOptions::default())
                .first_answer()
        });

        assert!(io_loop.block_on(lookup).is_err());
    }
}