trust_dns_proto/xfer/
dns_multiplexer.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! `DnsMultiplexer` and associated types implement the state machines for sending DNS messages while using the underlying streams.
9
10use std::borrow::Borrow;
11use std::collections::hash_map::Entry;
12use std::collections::HashMap;
13use std::fmt::{self, Display};
14use std::marker::Unpin;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use std::time::{Duration, SystemTime, UNIX_EPOCH};
19
20use futures_channel::mpsc;
21use futures_util::stream::{Stream, StreamExt};
22use futures_util::{future::Future, ready, FutureExt};
23use rand;
24use rand::distributions::{Distribution, Standard};
25use tracing::{debug, warn};
26
27use crate::error::*;
28use crate::op::{MessageFinalizer, MessageVerifier};
29use crate::xfer::{
30    ignore_send, BufDnsStreamHandle, DnsClientStream, DnsRequest, DnsRequestSender, DnsResponse,
31    DnsResponseStream, SerialMessage, CHANNEL_BUFFER_SIZE,
32};
33use crate::DnsStreamHandle;
34use crate::Time;
35
/// Upper bound on how many inbound messages one `poll_next` call on the multiplexer will
/// process before returning, so a burst of traffic on the underlying stream cannot monopolize
/// the task.
const QOS_MAX_RECEIVE_MSGS: usize = 100;
37
/// Bookkeeping for a single in-flight request tracked by a `DnsMultiplexer`.
struct ActiveRequest {
    // the completion is the channel for a response to the original request
    completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
    // the DNS message id that was assigned to the outgoing request
    request_id: u16,
    // resolves when the request should be failed with a timeout error
    timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
    // when present, applied to the raw response bytes instead of forwarding the decoded message
    verifier: Option<MessageVerifier>,
}
45
46impl ActiveRequest {
47    fn new(
48        completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
49        request_id: u16,
50        timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
51        verifier: Option<MessageVerifier>,
52    ) -> Self {
53        Self {
54            completion,
55            request_id,
56            // request,
57            timeout,
58            verifier,
59        }
60    }
61
62    /// polls the timeout and converts the error
63    fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> {
64        self.timeout.poll_unpin(cx)
65    }
66
67    /// Returns true of the other side canceled the request
68    fn is_canceled(&self) -> bool {
69        self.completion.is_closed()
70    }
71
72    /// the request id of the message that was sent
73    fn request_id(&self) -> u16 {
74        self.request_id
75    }
76
77    /// Sends an error
78    fn complete_with_error(mut self, error: ProtoError) {
79        ignore_send(self.completion.try_send(Err(error)));
80    }
81}
82
/// A DNS Client implemented over futures-rs.
///
/// This Client is generic and capable of wrapping UDP, TCP, and other underlying DNS protocol
/// implementations. This should be used for underlying protocols that do not natively support
/// multiplexed sessions.
///
/// Each outgoing request is given a random message id. Messages read from the underlying
/// stream are routed back to the active request with the matching id. The multiplexer is itself
/// a `Stream` and must be polled to make progress.
#[must_use = "futures do nothing unless polled"]
pub struct DnsMultiplexer<S, MF>
where
    S: DnsClientStream + 'static,
    MF: MessageFinalizer,
{
    // source of inbound messages from the remote side
    stream: S,
    // how long each request waits before it is failed with a timeout
    timeout_duration: Duration,
    // handle used to send outbound serialized requests
    stream_handle: BufDnsStreamHandle,
    // in-flight requests keyed by DNS message id
    active_requests: HashMap<u16, ActiveRequest>,
    // optional finalizer (e.g. a signer) applied to requests that ask for it
    signer: Option<Arc<MF>>,
    // set by `shutdown` or when the underlying stream fails/ends
    is_shutdown: bool,
}
101
102impl<S, MF> DnsMultiplexer<S, MF>
103where
104    S: DnsClientStream + Unpin + 'static,
105    MF: MessageFinalizer,
106{
107    /// Spawns a new DnsMultiplexer Stream. This uses a default timeout of 5 seconds for all requests.
108    ///
109    /// # Arguments
110    ///
111    /// * `stream` - A stream of bytes that can be used to send/receive DNS messages
112    ///              (see TcpClientStream or UdpClientStream)
113    /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received.
114    /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed
115    #[allow(clippy::new_ret_no_self)]
116    pub fn new<F>(
117        stream: F,
118        stream_handle: BufDnsStreamHandle,
119        signer: Option<Arc<MF>>,
120    ) -> DnsMultiplexerConnect<F, S, MF>
121    where
122        F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
123    {
124        Self::with_timeout(stream, stream_handle, Duration::from_secs(5), signer)
125    }
126
127    /// Spawns a new DnsMultiplexer Stream.
128    ///
129    /// # Arguments
130    ///
131    /// * `stream` - A stream of bytes that can be used to send/receive DNS messages
132    ///              (see TcpClientStream or UdpClientStream)
133    /// * `timeout_duration` - All requests may fail due to lack of response, this is the time to
134    ///                        wait for a response before canceling the request.
135    /// * `stream_handle` - The handle for the `stream` on which bytes can be sent/received.
136    /// * `signer` - An optional signer for requests, needed for Updates with Sig0, otherwise not needed
137    pub fn with_timeout<F>(
138        stream: F,
139        stream_handle: BufDnsStreamHandle,
140        timeout_duration: Duration,
141        signer: Option<Arc<MF>>,
142    ) -> DnsMultiplexerConnect<F, S, MF>
143    where
144        F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
145    {
146        DnsMultiplexerConnect {
147            stream,
148            stream_handle: Some(stream_handle),
149            timeout_duration,
150            signer,
151        }
152    }
153
154    /// loop over active_requests and remove cancelled requests
155    ///  this should free up space if we already had 4096 active requests
156    fn drop_cancelled(&mut self, cx: &mut Context<'_>) {
157        let mut canceled = HashMap::<u16, ProtoError>::new();
158        for (&id, ref mut active_req) in &mut self.active_requests {
159            if active_req.is_canceled() {
160                canceled.insert(id, ProtoError::from("requestor canceled"));
161            }
162
163            // check for timeouts...
164            match active_req.poll_timeout(cx) {
165                Poll::Ready(()) => {
166                    debug!("request timed out: {}", id);
167                    canceled.insert(id, ProtoError::from(ProtoErrorKind::Timeout));
168                }
169                Poll::Pending => (),
170            }
171        }
172
173        // drop all the canceled requests
174        for (id, error) in canceled {
175            if let Some(active_request) = self.active_requests.remove(&id) {
176                // complete the request, it's failed...
177                active_request.complete_with_error(error);
178            }
179        }
180    }
181
182    /// creates random query_id, validates against all active queries
183    fn next_random_query_id(&self) -> Result<u16, ProtoError> {
184        let mut rand = rand::thread_rng();
185
186        for _ in 0..100 {
187            let id: u16 = Standard.sample(&mut rand); // the range is [0 ... u16::max]
188
189            if !self.active_requests.contains_key(&id) {
190                return Ok(id);
191            }
192        }
193
194        Err(ProtoError::from(
195            "id space exhausted, consider filing an issue",
196        ))
197    }
198
199    /// Closes all outstanding completes with a closed stream error
200    fn stream_closed_close_all(&mut self, error: ProtoError) {
201        if !self.active_requests.is_empty() {
202            warn!("stream {} error: {}", self.stream, error);
203        } else {
204            debug!("stream {} error: {}", self.stream, error);
205        }
206
207        for (_, active_request) in self.active_requests.drain() {
208            // complete the request, it's failed...
209            active_request.complete_with_error(error.clone());
210        }
211    }
212}
213
214/// A wrapper for a future DnsExchange connection
215#[must_use = "futures do nothing unless polled"]
216pub struct DnsMultiplexerConnect<F, S, MF>
217where
218    F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
219    S: Stream<Item = Result<SerialMessage, ProtoError>> + Unpin,
220    MF: MessageFinalizer + Send + Sync + 'static,
221{
222    stream: F,
223    stream_handle: Option<BufDnsStreamHandle>,
224    timeout_duration: Duration,
225    signer: Option<Arc<MF>>,
226}
227
228impl<F, S, MF> Future for DnsMultiplexerConnect<F, S, MF>
229where
230    F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
231    S: DnsClientStream + Unpin + 'static,
232    MF: MessageFinalizer + Send + Sync + 'static,
233{
234    type Output = Result<DnsMultiplexer<S, MF>, ProtoError>;
235
236    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237        let stream: S = ready!(self.stream.poll_unpin(cx))?;
238
239        Poll::Ready(Ok(DnsMultiplexer {
240            stream,
241            timeout_duration: self.timeout_duration,
242            stream_handle: self
243                .stream_handle
244                .take()
245                .expect("must not poll after complete"),
246            active_requests: HashMap::new(),
247            signer: self.signer.clone(),
248            is_shutdown: false,
249        }))
250    }
251}
252
253impl<S, MF> Display for DnsMultiplexer<S, MF>
254where
255    S: DnsClientStream + 'static,
256    MF: MessageFinalizer + Send + Sync + 'static,
257{
258    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
259        write!(formatter, "{}", self.stream)
260    }
261}
262
263impl<S, MF> DnsRequestSender for DnsMultiplexer<S, MF>
264where
265    S: DnsClientStream + Unpin + 'static,
266    MF: MessageFinalizer + Send + Sync + 'static,
267{
268    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
269        if self.is_shutdown {
270            panic!("can not send messages after stream is shutdown")
271        }
272
273        if self.active_requests.len() > CHANNEL_BUFFER_SIZE {
274            return ProtoError::from(ProtoErrorKind::Busy).into();
275        }
276
277        let query_id = match self.next_random_query_id() {
278            Ok(id) => id,
279            Err(e) => return e.into(),
280        };
281
282        let (mut request, _) = request.into_parts();
283        request.set_id(query_id);
284
285        let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
286            Ok(now) => now.as_secs(),
287            Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
288        };
289
290        // TODO: truncates u64 to u32, error on overflow?
291        let now = now as u32;
292
293        let mut verifier = None;
294        if let Some(ref signer) = self.signer {
295            if signer.should_finalize_message(&request) {
296                match request.finalize::<MF>(signer.borrow(), now) {
297                    Ok(answer_verifier) => verifier = answer_verifier,
298                    Err(e) => {
299                        debug!("could not sign message: {}", e);
300                        return e.into();
301                    }
302                }
303            }
304        }
305
306        // store a Timeout for this message before sending
307        let timeout = S::Time::delay_for(self.timeout_duration);
308
309        let (complete, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
310
311        // send the message
312        let active_request =
313            ActiveRequest::new(complete, request.id(), Box::new(timeout), verifier);
314
315        match request.to_vec() {
316            Ok(buffer) => {
317                debug!("sending message id: {}", active_request.request_id());
318                let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
319
320                debug!(
321                    "final message: {}",
322                    serial_message
323                        .to_message()
324                        .expect("bizarre we just made this message")
325                );
326
327                // add to the map -after- the client send b/c we don't want to put it in the map if
328                //  we ended up returning an error from the send.
329                match self.stream_handle.send(serial_message) {
330                    Ok(()) => self
331                        .active_requests
332                        .insert(active_request.request_id(), active_request),
333                    Err(err) => return err.into(),
334                };
335            }
336            Err(e) => {
337                debug!(
338                    "error message id: {} error: {}",
339                    active_request.request_id(),
340                    e
341                );
342                // complete with the error, don't add to the map of active requests
343                return e.into();
344            }
345        }
346
347        receiver.into()
348    }
349
350    fn shutdown(&mut self) {
351        self.is_shutdown = true;
352    }
353
354    fn is_shutdown(&self) -> bool {
355        self.is_shutdown
356    }
357}
358
359impl<S, MF> Stream for DnsMultiplexer<S, MF>
360where
361    S: DnsClientStream + Unpin + 'static,
362    MF: MessageFinalizer + Send + Sync + 'static,
363{
364    type Item = Result<(), ProtoError>;
365
366    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
367        // Always drop the cancelled queries first
368        self.drop_cancelled(cx);
369
370        if self.is_shutdown && self.active_requests.is_empty() {
371            debug!("stream is done: {}", self);
372            return Poll::Ready(None);
373        }
374
375        // Collect all inbound requests, max 100 at a time for QoS
376        //   by having a max we will guarantee that the client can't be DOSed in this loop
377        // TODO: make the QoS configurable
378        let mut messages_received = 0;
379        for i in 0..QOS_MAX_RECEIVE_MSGS {
380            match self.stream.poll_next_unpin(cx) {
381                Poll::Ready(Some(Ok(buffer))) => {
382                    messages_received = i;
383
384                    //   deserialize or log decode_error
385                    match buffer.to_message() {
386                        Ok(message) => match self.active_requests.entry(message.id()) {
387                            Entry::Occupied(mut request_entry) => {
388                                // send the response, complete the request...
389                                let active_request = request_entry.get_mut();
390                                if let Some(ref mut verifier) = active_request.verifier {
391                                    ignore_send(
392                                        active_request
393                                            .completion
394                                            .try_send(verifier(buffer.bytes())),
395                                    );
396                                } else {
397                                    ignore_send(
398                                        active_request.completion.try_send(Ok(message.into())),
399                                    );
400                                }
401                            }
402                            Entry::Vacant(..) => debug!("unexpected request_id: {}", message.id()),
403                        },
404                        // TODO: return src address for diagnostics
405                        Err(e) => debug!("error decoding message: {}", e),
406                    }
407                }
408                Poll::Ready(err) => {
409                    let err = match err {
410                        Some(Err(e)) => e,
411                        None => ProtoError::from("stream closed"),
412                        _ => unreachable!(),
413                    };
414
415                    self.stream_closed_close_all(err);
416                    self.is_shutdown = true;
417                    return Poll::Ready(None);
418                }
419                Poll::Pending => break,
420            }
421        }
422
423        // If still active, then if the qos (for _ in 0..100 loop) limit
424        // was hit then "yield". This'll make sure that the future is
425        // woken up immediately on the next turn of the event loop.
426        if messages_received == QOS_MAX_RECEIVE_MSGS {
427            // FIXME: this was a task::current().notify(); is this right?
428            cx.waker().wake_by_ref();
429        }
430
431        // Finally, return not ready to keep the 'driver task' alive.
432        Poll::Pending
433    }
434}
435
#[cfg(test)]
mod test {
    use super::*;
    use crate::op::message::NoopMessageFinalizer;
    use crate::op::op_code::OpCode;
    use crate::op::{Message, MessageType, Query};
    use crate::rr::record_type::RecordType;
    use crate::rr::{DNSClass, Name, RData, Record};
    use crate::serialize::binary::BinEncodable;
    use crate::xfer::StreamReceiver;
    use crate::xfer::{DnsClientStream, DnsRequestOptions};
    use futures_util::future;
    use futures_util::stream::TryStreamExt;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    /// A fake `DnsClientStream` that answers the first request it sees with a scripted list of
    /// messages.
    struct MockClientStream {
        // scripted responses, stored reversed so `pop()` yields them in their original order
        messages: Vec<Message>,
        // reported as the name server address and used as the source of every response
        addr: SocketAddr,
        // id of the request read from `receiver`; each response is rewritten to carry it
        id: Option<u16>,
        // receiving side of the multiplexer's stream handle, set by `get_mocked_multiplexer`
        receiver: Option<StreamReceiver>,
    }

    impl MockClientStream {
        /// Returns an already-completed connect future yielding a mock with `messages` queued.
        fn new(
            mut messages: Vec<Message>,
            addr: SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self, ProtoError>> + Send>> {
            messages.reverse(); // so we can pop() and get messages in order
            Box::pin(future::ok(Self {
                messages,
                addr,
                id: None,
                receiver: None,
            }))
        }
    }

    impl fmt::Display for MockClientStream {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
            write!(formatter, "TestClientStream")
        }
    }

    impl Stream for MockClientStream {
        type Item = Result<SerialMessage, ProtoError>;

        /// On first use, reads the outgoing request from `receiver` to learn its id. After that
        /// it emits one scripted message per poll, rewritten to that id. Once the script is
        /// exhausted it returns `Pending` without registering a waker.
        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let id = if let Some(id) = self.id {
                id
            } else {
                let serial = ready!(self
                    .receiver
                    .as_mut()
                    .expect("should only be polled after receiver has been set")
                    .poll_next_unpin(cx));
                let message = serial.unwrap().to_message().unwrap();
                self.id = Some(message.id());
                message.id()
            };

            if let Some(mut message) = self.messages.pop() {
                message.set_id(id);
                Poll::Ready(Some(Ok(SerialMessage::new(
                    message.to_bytes().unwrap(),
                    self.addr,
                ))))
            } else {
                Poll::Pending
            }
        }
    }

    impl DnsClientStream for MockClientStream {
        type Time = crate::TokioTime;

        fn name_server_addr(&self) -> SocketAddr {
            self.addr
        }
    }

    /// Builds a multiplexer over a `MockClientStream` with a 100 ms request timeout. The mock is
    /// wired to read the requests the multiplexer sends.
    async fn get_mocked_multiplexer(
        mock_response: Vec<Message>,
    ) -> DnsMultiplexer<MockClientStream, NoopMessageFinalizer> {
        let addr = SocketAddr::from(([127, 0, 0, 1], 1234));
        let mock_response = MockClientStream::new(mock_response, addr);
        let (handler, receiver) = BufDnsStreamHandle::new(addr);
        let mut multiplexer =
            DnsMultiplexer::with_timeout(mock_response, handler, Duration::from_millis(100), None)
                .await
                .unwrap();

        multiplexer.stream.receiver = Some(receiver); // so it can get the correct request id

        multiplexer
    }

    /// An A query for www.example.com plus a single response message with one A record.
    fn a_query_answer() -> (DnsRequest, Vec<Message>) {
        let name = Name::from_ascii("www.example.com").unwrap();

        let mut msg = Message::new();
        msg.add_query({
            let mut query = Query::query(name.clone(), RecordType::A);
            query.set_query_class(DNSClass::IN);
            query
        })
        .set_message_type(MessageType::Query)
        .set_op_code(OpCode::Query)
        .set_recursion_desired(true);

        let query = msg.clone();
        msg.set_message_type(MessageType::Response).add_answer(
            Record::new()
                .set_name(name)
                .set_ttl(86400)
                .set_rr_type(RecordType::A)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::A(Ipv4Addr::new(93, 184, 216, 34))))
                .clone(),
        );
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg],
        )
    }

    /// An AXFR query message for example.com.
    fn axfr_query() -> Message {
        let name = Name::from_ascii("example.com").unwrap();

        let mut msg = Message::new();
        msg.add_query({
            let mut query = Query::query(name, RecordType::AXFR);
            query.set_query_class(DNSClass::IN);
            query
        })
        .set_message_type(MessageType::Query)
        .set_op_code(OpCode::Query)
        .set_recursion_desired(true);
        msg
    }

    /// The six answer records of the mocked transfer: SOA, two NS, A, AAAA, then the SOA again.
    fn axfr_response() -> Vec<Record> {
        use crate::rr::rdata::*;
        let origin = Name::from_ascii("example.com").unwrap();
        let soa = Record::new()
            .set_name(origin.clone())
            .set_ttl(3600)
            .set_rr_type(RecordType::SOA)
            .set_dns_class(DNSClass::IN)
            .set_data(Some(RData::SOA(SOA::new(
                Name::parse("sns.dns.icann.org.", None).unwrap(),
                Name::parse("noc.dns.icann.org.", None).unwrap(),
                2015082403,
                7200,
                3600,
                1209600,
                3600,
            ))))
            .clone();

        vec![
            soa.clone(),
            Record::new()
                .set_name(origin.clone())
                .set_ttl(86400)
                .set_rr_type(RecordType::NS)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::NS(
                    Name::parse("a.iana-servers.net.", None).unwrap(),
                )))
                .clone(),
            Record::new()
                .set_name(origin.clone())
                .set_ttl(86400)
                .set_rr_type(RecordType::NS)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::NS(
                    Name::parse("b.iana-servers.net.", None).unwrap(),
                )))
                .clone(),
            Record::new()
                .set_name(origin.clone())
                .set_ttl(86400)
                .set_rr_type(RecordType::A)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::A(Ipv4Addr::new(93, 184, 216, 34))))
                .clone(),
            Record::new()
                .set_name(origin)
                .set_ttl(86400)
                .set_rr_type(RecordType::AAAA)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::AAAA(Ipv6Addr::new(
                    0x2606, 0x2800, 0x220, 0x1, 0x248, 0x1893, 0x25c8, 0x1946,
                ))))
                .clone(),
            soa,
        ]
    }

    /// The AXFR query plus a single response message carrying every transfer record.
    fn axfr_query_answer() -> (DnsRequest, Vec<Message>) {
        let mut msg = axfr_query();

        let query = msg.clone();
        msg.set_message_type(MessageType::Response)
            .insert_answers(axfr_response());
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg],
        )
    }

    /// The AXFR query plus two response messages: the first 3 records, then the remaining 3.
    fn axfr_query_answer_multi() -> (DnsRequest, Vec<Message>) {
        let base = axfr_query();

        let query = base.clone();
        let mut rr = axfr_response();
        let rr2 = rr.split_off(3);
        let mut msg1 = base.clone();
        msg1.set_message_type(MessageType::Response)
            .insert_answers(rr);
        let mut msg2 = base;
        msg2.set_message_type(MessageType::Response)
            .insert_answers(rr2);
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg1, msg2],
        )
    }

    // In each test below, `select!` drives the multiplexer concurrently with collecting the
    // response stream. The multiplexer branch panics because the driver must not finish before
    // the responses do. NOTE(review): the response stream presumably ends via the 100 ms
    // request timeout; confirm against `DnsResponseStream`.

    #[tokio::test]
    async fn test_multiplexer_a() {
        let (query, answer) = a_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                // polling multiplexer to make it run
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
    }

    #[tokio::test]
    async fn test_multiplexer_axfr() {
        let (query, answer) = axfr_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                // polling multiplexer to make it run
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
        assert_eq!(response[0].answers().len(), axfr_response().len());
    }

    #[tokio::test]
    async fn test_multiplexer_axfr_multi() {
        let (query, answer) = axfr_query_answer_multi();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                // polling multiplexer to make it run
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        // both messages share the request id, so both are routed to the same response stream
        assert_eq!(response.len(), 2);
        assert_eq!(
            response.iter().map(|m| m.answers().len()).sum::<usize>(),
            axfr_response().len()
        );
    }
}