1use std::borrow::Borrow;
11use std::collections::hash_map::Entry;
12use std::collections::HashMap;
13use std::fmt::{self, Display};
14use std::marker::Unpin;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use std::time::{Duration, SystemTime, UNIX_EPOCH};
19
20use futures_channel::mpsc;
21use futures_util::stream::{Stream, StreamExt};
22use futures_util::{future::Future, ready, FutureExt};
23use rand;
24use rand::distributions::{Distribution, Standard};
25use tracing::{debug, warn};
26
27use crate::error::*;
28use crate::op::{MessageFinalizer, MessageVerifier};
29use crate::xfer::{
30 ignore_send, BufDnsStreamHandle, DnsClientStream, DnsRequest, DnsRequestSender, DnsResponse,
31 DnsResponseStream, SerialMessage, CHANNEL_BUFFER_SIZE,
32};
33use crate::DnsStreamHandle;
34use crate::Time;
35
36const QOS_MAX_RECEIVE_MSGS: usize = 100; struct ActiveRequest {
39 completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
41 request_id: u16,
42 timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
43 verifier: Option<MessageVerifier>,
44}
45
46impl ActiveRequest {
47 fn new(
48 completion: mpsc::Sender<Result<DnsResponse, ProtoError>>,
49 request_id: u16,
50 timeout: Box<dyn Future<Output = ()> + Send + Unpin>,
51 verifier: Option<MessageVerifier>,
52 ) -> Self {
53 Self {
54 completion,
55 request_id,
56 timeout,
58 verifier,
59 }
60 }
61
62 fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> {
64 self.timeout.poll_unpin(cx)
65 }
66
67 fn is_canceled(&self) -> bool {
69 self.completion.is_closed()
70 }
71
72 fn request_id(&self) -> u16 {
74 self.request_id
75 }
76
77 fn complete_with_error(mut self, error: ProtoError) {
79 ignore_send(self.completion.try_send(Err(error)));
80 }
81}
82
83#[must_use = "futures do nothing unless polled"]
89pub struct DnsMultiplexer<S, MF>
90where
91 S: DnsClientStream + 'static,
92 MF: MessageFinalizer,
93{
94 stream: S,
95 timeout_duration: Duration,
96 stream_handle: BufDnsStreamHandle,
97 active_requests: HashMap<u16, ActiveRequest>,
98 signer: Option<Arc<MF>>,
99 is_shutdown: bool,
100}
101
102impl<S, MF> DnsMultiplexer<S, MF>
103where
104 S: DnsClientStream + Unpin + 'static,
105 MF: MessageFinalizer,
106{
107 #[allow(clippy::new_ret_no_self)]
116 pub fn new<F>(
117 stream: F,
118 stream_handle: BufDnsStreamHandle,
119 signer: Option<Arc<MF>>,
120 ) -> DnsMultiplexerConnect<F, S, MF>
121 where
122 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
123 {
124 Self::with_timeout(stream, stream_handle, Duration::from_secs(5), signer)
125 }
126
127 pub fn with_timeout<F>(
138 stream: F,
139 stream_handle: BufDnsStreamHandle,
140 timeout_duration: Duration,
141 signer: Option<Arc<MF>>,
142 ) -> DnsMultiplexerConnect<F, S, MF>
143 where
144 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
145 {
146 DnsMultiplexerConnect {
147 stream,
148 stream_handle: Some(stream_handle),
149 timeout_duration,
150 signer,
151 }
152 }
153
154 fn drop_cancelled(&mut self, cx: &mut Context<'_>) {
157 let mut canceled = HashMap::<u16, ProtoError>::new();
158 for (&id, ref mut active_req) in &mut self.active_requests {
159 if active_req.is_canceled() {
160 canceled.insert(id, ProtoError::from("requestor canceled"));
161 }
162
163 match active_req.poll_timeout(cx) {
165 Poll::Ready(()) => {
166 debug!("request timed out: {}", id);
167 canceled.insert(id, ProtoError::from(ProtoErrorKind::Timeout));
168 }
169 Poll::Pending => (),
170 }
171 }
172
173 for (id, error) in canceled {
175 if let Some(active_request) = self.active_requests.remove(&id) {
176 active_request.complete_with_error(error);
178 }
179 }
180 }
181
182 fn next_random_query_id(&self) -> Result<u16, ProtoError> {
184 let mut rand = rand::thread_rng();
185
186 for _ in 0..100 {
187 let id: u16 = Standard.sample(&mut rand); if !self.active_requests.contains_key(&id) {
190 return Ok(id);
191 }
192 }
193
194 Err(ProtoError::from(
195 "id space exhausted, consider filing an issue",
196 ))
197 }
198
199 fn stream_closed_close_all(&mut self, error: ProtoError) {
201 if !self.active_requests.is_empty() {
202 warn!("stream {} error: {}", self.stream, error);
203 } else {
204 debug!("stream {} error: {}", self.stream, error);
205 }
206
207 for (_, active_request) in self.active_requests.drain() {
208 active_request.complete_with_error(error.clone());
210 }
211 }
212}
213
214#[must_use = "futures do nothing unless polled"]
216pub struct DnsMultiplexerConnect<F, S, MF>
217where
218 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
219 S: Stream<Item = Result<SerialMessage, ProtoError>> + Unpin,
220 MF: MessageFinalizer + Send + Sync + 'static,
221{
222 stream: F,
223 stream_handle: Option<BufDnsStreamHandle>,
224 timeout_duration: Duration,
225 signer: Option<Arc<MF>>,
226}
227
228impl<F, S, MF> Future for DnsMultiplexerConnect<F, S, MF>
229where
230 F: Future<Output = Result<S, ProtoError>> + Send + Unpin + 'static,
231 S: DnsClientStream + Unpin + 'static,
232 MF: MessageFinalizer + Send + Sync + 'static,
233{
234 type Output = Result<DnsMultiplexer<S, MF>, ProtoError>;
235
236 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237 let stream: S = ready!(self.stream.poll_unpin(cx))?;
238
239 Poll::Ready(Ok(DnsMultiplexer {
240 stream,
241 timeout_duration: self.timeout_duration,
242 stream_handle: self
243 .stream_handle
244 .take()
245 .expect("must not poll after complete"),
246 active_requests: HashMap::new(),
247 signer: self.signer.clone(),
248 is_shutdown: false,
249 }))
250 }
251}
252
253impl<S, MF> Display for DnsMultiplexer<S, MF>
254where
255 S: DnsClientStream + 'static,
256 MF: MessageFinalizer + Send + Sync + 'static,
257{
258 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
259 write!(formatter, "{}", self.stream)
260 }
261}
262
263impl<S, MF> DnsRequestSender for DnsMultiplexer<S, MF>
264where
265 S: DnsClientStream + Unpin + 'static,
266 MF: MessageFinalizer + Send + Sync + 'static,
267{
268 fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream {
269 if self.is_shutdown {
270 panic!("can not send messages after stream is shutdown")
271 }
272
273 if self.active_requests.len() > CHANNEL_BUFFER_SIZE {
274 return ProtoError::from(ProtoErrorKind::Busy).into();
275 }
276
277 let query_id = match self.next_random_query_id() {
278 Ok(id) => id,
279 Err(e) => return e.into(),
280 };
281
282 let (mut request, _) = request.into_parts();
283 request.set_id(query_id);
284
285 let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
286 Ok(now) => now.as_secs(),
287 Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
288 };
289
290 let now = now as u32;
292
293 let mut verifier = None;
294 if let Some(ref signer) = self.signer {
295 if signer.should_finalize_message(&request) {
296 match request.finalize::<MF>(signer.borrow(), now) {
297 Ok(answer_verifier) => verifier = answer_verifier,
298 Err(e) => {
299 debug!("could not sign message: {}", e);
300 return e.into();
301 }
302 }
303 }
304 }
305
306 let timeout = S::Time::delay_for(self.timeout_duration);
308
309 let (complete, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
310
311 let active_request =
313 ActiveRequest::new(complete, request.id(), Box::new(timeout), verifier);
314
315 match request.to_vec() {
316 Ok(buffer) => {
317 debug!("sending message id: {}", active_request.request_id());
318 let serial_message = SerialMessage::new(buffer, self.stream.name_server_addr());
319
320 debug!(
321 "final message: {}",
322 serial_message
323 .to_message()
324 .expect("bizarre we just made this message")
325 );
326
327 match self.stream_handle.send(serial_message) {
330 Ok(()) => self
331 .active_requests
332 .insert(active_request.request_id(), active_request),
333 Err(err) => return err.into(),
334 };
335 }
336 Err(e) => {
337 debug!(
338 "error message id: {} error: {}",
339 active_request.request_id(),
340 e
341 );
342 return e.into();
344 }
345 }
346
347 receiver.into()
348 }
349
350 fn shutdown(&mut self) {
351 self.is_shutdown = true;
352 }
353
354 fn is_shutdown(&self) -> bool {
355 self.is_shutdown
356 }
357}
358
359impl<S, MF> Stream for DnsMultiplexer<S, MF>
360where
361 S: DnsClientStream + Unpin + 'static,
362 MF: MessageFinalizer + Send + Sync + 'static,
363{
364 type Item = Result<(), ProtoError>;
365
366 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
367 self.drop_cancelled(cx);
369
370 if self.is_shutdown && self.active_requests.is_empty() {
371 debug!("stream is done: {}", self);
372 return Poll::Ready(None);
373 }
374
375 let mut messages_received = 0;
379 for i in 0..QOS_MAX_RECEIVE_MSGS {
380 match self.stream.poll_next_unpin(cx) {
381 Poll::Ready(Some(Ok(buffer))) => {
382 messages_received = i;
383
384 match buffer.to_message() {
386 Ok(message) => match self.active_requests.entry(message.id()) {
387 Entry::Occupied(mut request_entry) => {
388 let active_request = request_entry.get_mut();
390 if let Some(ref mut verifier) = active_request.verifier {
391 ignore_send(
392 active_request
393 .completion
394 .try_send(verifier(buffer.bytes())),
395 );
396 } else {
397 ignore_send(
398 active_request.completion.try_send(Ok(message.into())),
399 );
400 }
401 }
402 Entry::Vacant(..) => debug!("unexpected request_id: {}", message.id()),
403 },
404 Err(e) => debug!("error decoding message: {}", e),
406 }
407 }
408 Poll::Ready(err) => {
409 let err = match err {
410 Some(Err(e)) => e,
411 None => ProtoError::from("stream closed"),
412 _ => unreachable!(),
413 };
414
415 self.stream_closed_close_all(err);
416 self.is_shutdown = true;
417 return Poll::Ready(None);
418 }
419 Poll::Pending => break,
420 }
421 }
422
423 if messages_received == QOS_MAX_RECEIVE_MSGS {
427 cx.waker().wake_by_ref();
429 }
430
431 Poll::Pending
433 }
434}
435
#[cfg(test)]
mod test {
    use super::*;
    use crate::op::message::NoopMessageFinalizer;
    use crate::op::op_code::OpCode;
    use crate::op::{Message, MessageType, Query};
    use crate::rr::record_type::RecordType;
    use crate::rr::{DNSClass, Name, RData, Record};
    use crate::serialize::binary::BinEncodable;
    use crate::xfer::StreamReceiver;
    use crate::xfer::{DnsClientStream, DnsRequestOptions};
    use futures_util::future;
    use futures_util::stream::TryStreamExt;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    /// In-memory client stream that replays canned responses.
    ///
    /// It learns the query id from the first message sent through the
    /// multiplexer's stream handle, then stamps that id onto each canned response.
    struct MockClientStream {
        // Canned responses, stored reversed so `pop` yields them in original order.
        messages: Vec<Message>,
        addr: SocketAddr,
        // Query id captured from the first outbound message.
        id: Option<u16>,
        // Receiving half of the multiplexer's BufDnsStreamHandle; set by `get_mocked_multiplexer`.
        receiver: Option<StreamReceiver>,
    }

    impl MockClientStream {
        /// Returns an already-resolved "connect" future yielding the mock stream.
        fn new(
            mut messages: Vec<Message>,
            addr: SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<Self, ProtoError>> + Send>> {
            // Reverse so that `pop()` in `poll_next` returns messages in the given order.
            messages.reverse();
            Box::pin(future::ok(Self {
                messages,
                addr,
                id: None,
                receiver: None,
            }))
        }
    }

    impl fmt::Display for MockClientStream {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
            write!(formatter, "TestClientStream")
        }
    }

    impl Stream for MockClientStream {
        type Item = Result<SerialMessage, ProtoError>;

        /// Waits for the first outbound query to learn its id, then yields one canned
        /// response per poll. Once responses run out it returns `Pending` without
        /// registering a waker, so the stream simply goes quiet.
        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let id = if let Some(id) = self.id {
                id
            } else {
                let serial = ready!(self
                    .receiver
                    .as_mut()
                    .expect("should only be polled after receiver has been set")
                    .poll_next_unpin(cx));
                let message = serial.unwrap().to_message().unwrap();
                self.id = Some(message.id());
                message.id()
            };

            if let Some(mut message) = self.messages.pop() {
                // Match the id the multiplexer assigned so the response is routed correctly.
                message.set_id(id);
                Poll::Ready(Some(Ok(SerialMessage::new(
                    message.to_bytes().unwrap(),
                    self.addr,
                ))))
            } else {
                Poll::Pending
            }
        }
    }

    impl DnsClientStream for MockClientStream {
        type Time = crate::TokioTime;

        fn name_server_addr(&self) -> SocketAddr {
            self.addr
        }
    }

    /// Builds a multiplexer over a `MockClientStream` that will answer with
    /// `mock_response`, using a short 100ms per-request timeout.
    async fn get_mocked_multiplexer(
        mock_response: Vec<Message>,
    ) -> DnsMultiplexer<MockClientStream, NoopMessageFinalizer> {
        let addr = SocketAddr::from(([127, 0, 0, 1], 1234));
        let mock_response = MockClientStream::new(mock_response, addr);
        let (handler, receiver) = BufDnsStreamHandle::new(addr);
        let mut multiplexer =
            DnsMultiplexer::with_timeout(mock_response, handler, Duration::from_millis(100), None)
                .await
                .unwrap();

        // Wire the handle's receiver into the mock so it can observe the outbound query id.
        multiplexer.stream.receiver = Some(receiver);
        multiplexer
    }

    /// An `A` query for www.example.com plus a single matching response message.
    fn a_query_answer() -> (DnsRequest, Vec<Message>) {
        let name = Name::from_ascii("www.example.com").unwrap();

        let mut msg = Message::new();
        msg.add_query({
            let mut query = Query::query(name.clone(), RecordType::A);
            query.set_query_class(DNSClass::IN);
            query
        })
        .set_message_type(MessageType::Query)
        .set_op_code(OpCode::Query)
        .set_recursion_desired(true);

        let query = msg.clone();
        msg.set_message_type(MessageType::Response).add_answer(
            Record::new()
                .set_name(name)
                .set_ttl(86400)
                .set_rr_type(RecordType::A)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::A(Ipv4Addr::new(93, 184, 216, 34))))
                .clone(),
        );
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg],
        )
    }

    /// An AXFR query message for example.com.
    fn axfr_query() -> Message {
        let name = Name::from_ascii("example.com").unwrap();

        let mut msg = Message::new();
        msg.add_query({
            let mut query = Query::query(name, RecordType::AXFR);
            query.set_query_class(DNSClass::IN);
            query
        })
        .set_message_type(MessageType::Query)
        .set_op_code(OpCode::Query)
        .set_recursion_desired(true);
        msg
    }

    /// Records of a zone transfer for example.com, bracketed by the SOA record.
    fn axfr_response() -> Vec<Record> {
        use crate::rr::rdata::*;
        let origin = Name::from_ascii("example.com").unwrap();
        let soa = Record::new()
            .set_name(origin.clone())
            .set_ttl(3600)
            .set_rr_type(RecordType::SOA)
            .set_dns_class(DNSClass::IN)
            .set_data(Some(RData::SOA(SOA::new(
                Name::parse("sns.dns.icann.org.", None).unwrap(),
                Name::parse("noc.dns.icann.org.", None).unwrap(),
                2015082403,
                7200,
                3600,
                1209600,
                3600,
            ))))
            .clone();

        vec![
            soa.clone(),
            Record::new()
                .set_name(origin.clone())
                .set_ttl(86400)
                .set_rr_type(RecordType::NS)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::NS(
                    Name::parse("a.iana-servers.net.", None).unwrap(),
                )))
                .clone(),
            Record::new()
                .set_name(origin.clone())
                .set_ttl(86400)
                .set_rr_type(RecordType::NS)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::NS(
                    Name::parse("b.iana-servers.net.", None).unwrap(),
                )))
                .clone(),
            Record::new()
                .set_name(origin.clone())
                .set_ttl(86400)
                .set_rr_type(RecordType::A)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::A(Ipv4Addr::new(93, 184, 216, 34))))
                .clone(),
            Record::new()
                .set_name(origin)
                .set_ttl(86400)
                .set_rr_type(RecordType::AAAA)
                .set_dns_class(DNSClass::IN)
                .set_data(Some(RData::AAAA(Ipv6Addr::new(
                    0x2606, 0x2800, 0x220, 0x1, 0x248, 0x1893, 0x25c8, 0x1946,
                ))))
                .clone(),
            soa,
        ]
    }

    /// AXFR query answered by one response message carrying every record.
    fn axfr_query_answer() -> (DnsRequest, Vec<Message>) {
        let mut msg = axfr_query();

        let query = msg.clone();
        msg.set_message_type(MessageType::Response)
            .insert_answers(axfr_response());
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg],
        )
    }

    /// AXFR query answered by two response messages (records split after the third).
    fn axfr_query_answer_multi() -> (DnsRequest, Vec<Message>) {
        let base = axfr_query();

        let query = base.clone();
        let mut rr = axfr_response();
        let rr2 = rr.split_off(3);
        let mut msg1 = base.clone();
        msg1.set_message_type(MessageType::Response)
            .insert_answers(rr);
        let mut msg2 = base;
        msg2.set_message_type(MessageType::Response)
            .insert_answers(rr2);
        (
            DnsRequest::new(query, DnsRequestOptions::default()),
            vec![msg1, msg2],
        )
    }

    // In each test below, the multiplexer is driven concurrently with collecting the
    // response stream. `DnsMultiplexer::poll_next` never yields `Some`, so its branch
    // completing means the multiplexer terminated unexpectedly. The response stream is
    // expected to end once the 100ms timeout reaps the request.
    // NOTE(review): this relies on DnsResponseStream treating a Timeout error as
    // end-of-stream rather than an `Err` item — confirm in `crate::xfer`.

    #[tokio::test]
    async fn test_multiplexer_a() {
        let (query, answer) = a_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
    }

    #[tokio::test]
    async fn test_multiplexer_axfr() {
        let (query, answer) = axfr_query_answer();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        assert_eq!(response.len(), 1);
        assert_eq!(response[0].answers().len(), axfr_response().len());
    }

    #[tokio::test]
    async fn test_multiplexer_axfr_multi() {
        let (query, answer) = axfr_query_answer_multi();
        let mut multiplexer = get_mocked_multiplexer(answer).await;
        let response = multiplexer.send_message(query);
        let response = tokio::select! {
            _ = multiplexer.next() => {
                panic!("should never end")
            },
            r = response.try_collect::<Vec<_>>() => r.unwrap(),
        };
        // Both response messages for the same id must be delivered to the one requester.
        assert_eq!(response.len(), 2);
        assert_eq!(
            response.iter().map(|m| m.answers().len()).sum::<usize>(),
            axfr_response().len()
        );
    }
}