trust_dns_proto/udp/
udp_client_stream.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::borrow::Borrow;
9use std::fmt::{self, Display};
10use std::marker::PhantomData;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15use std::time::{Duration, SystemTime, UNIX_EPOCH};
16
17use futures_util::{future::Future, stream::Stream};
18use tracing::{debug, warn};
19
20use crate::error::ProtoError;
21use crate::op::message::NoopMessageFinalizer;
22use crate::op::{MessageFinalizer, MessageVerifier};
23use crate::udp::udp_stream::{NextRandomUdpSocket, UdpSocket};
24use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage};
25use crate::Time;
26
27/// A UDP client stream of DNS binary packets
28///
29/// This stream will create a new UDP socket for every request. This is to avoid potential cache
30///   poisoning during use by UDP based attacks.
31#[must_use = "futures do nothing unless polled"]
32pub struct UdpClientStream<S, MF = NoopMessageFinalizer>
33where
34    S: Send,
35    MF: MessageFinalizer,
36{
37    name_server: SocketAddr,
38    bind_addr: Option<SocketAddr>,
39    timeout: Duration,
40    is_shutdown: bool,
41    signer: Option<Arc<MF>>,
42    marker: PhantomData<S>,
43}
44
45impl<S: Send> UdpClientStream<S, NoopMessageFinalizer> {
46    /// it is expected that the resolver wrapper will be responsible for creating and managing
47    ///  new UdpClients such that each new client would have a random port (reduce chance of cache
48    ///  poisoning)
49    ///
50    /// # Return
51    ///
52    /// a tuple of a Future Stream which will handle sending and receiving messages, and a
53    ///  handle which can be used to send messages into the stream.
54    #[allow(clippy::new_ret_no_self)]
55    pub fn new(name_server: SocketAddr) -> UdpClientConnect<S, NoopMessageFinalizer> {
56        Self::with_timeout(name_server, Duration::from_secs(5))
57    }
58
59    /// Constructs a new UdpStream for a client to the specified SocketAddr.
60    ///
61    /// # Arguments
62    ///
63    /// * `name_server` - the IP and Port of the DNS server to connect to
64    /// * `timeout` - connection timeout
65    pub fn with_timeout(
66        name_server: SocketAddr,
67        timeout: Duration,
68    ) -> UdpClientConnect<S, NoopMessageFinalizer> {
69        Self::with_bind_addr_and_timeout(name_server, None, timeout)
70    }
71
72    /// Constructs a new UdpStream for a client to the specified SocketAddr.
73    ///
74    /// # Arguments
75    ///
76    /// * `name_server` - the IP and Port of the DNS server to connect to
77    /// * `bind_addr` - the IP and port to connect from
78    /// * `timeout` - connection timeout
79    pub fn with_bind_addr_and_timeout(
80        name_server: SocketAddr,
81        bind_addr: Option<SocketAddr>,
82        timeout: Duration,
83    ) -> UdpClientConnect<S, NoopMessageFinalizer> {
84        Self::with_timeout_and_signer_and_bind_addr(name_server, timeout, None, bind_addr)
85    }
86}
87
88impl<S: Send, MF: MessageFinalizer> UdpClientStream<S, MF> {
89    /// Constructs a new TcpStream for a client to the specified SocketAddr.
90    ///
91    /// # Arguments
92    ///
93    /// * `name_server` - the IP and Port of the DNS server to connect to
94    /// * `timeout` - connection timeout
95    pub fn with_timeout_and_signer(
96        name_server: SocketAddr,
97        timeout: Duration,
98        signer: Option<Arc<MF>>,
99    ) -> UdpClientConnect<S, MF> {
100        UdpClientConnect {
101            name_server,
102            bind_addr: None,
103            timeout,
104            signer,
105            marker: PhantomData::<S>,
106        }
107    }
108
109    /// Constructs a new TcpStream for a client to the specified SocketAddr.
110    ///
111    /// # Arguments
112    ///
113    /// * `name_server` - the IP and Port of the DNS server to connect to
114    /// * `timeout` - connection timeout
115    /// * `bind_addr` - the IP address and port to connect from
116    pub fn with_timeout_and_signer_and_bind_addr(
117        name_server: SocketAddr,
118        timeout: Duration,
119        signer: Option<Arc<MF>>,
120        bind_addr: Option<SocketAddr>,
121    ) -> UdpClientConnect<S, MF> {
122        UdpClientConnect {
123            name_server,
124            bind_addr,
125            timeout,
126            signer,
127            marker: PhantomData::<S>,
128        }
129    }
130}
131
132impl<S: Send, MF: MessageFinalizer> Display for UdpClientStream<S, MF> {
133    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
134        write!(formatter, "UDP({})", self.name_server)
135    }
136}
137
138/// creates random query_id, each socket is unique, no need for global uniqueness
139fn random_query_id() -> u16 {
140    use rand::distributions::{Distribution, Standard};
141    let mut rand = rand::thread_rng();
142
143    Standard.sample(&mut rand)
144}
145
146impl<S: UdpSocket + Send + 'static, MF: MessageFinalizer> DnsRequestSender
147    for UdpClientStream<S, MF>
148{
149    fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
150        if self.is_shutdown {
151            panic!("can not send messages after stream is shutdown")
152        }
153
154        // associated the ID for this request, b/c this connection is unique to socket port, the ID
155        //   does not need to be globally unique
156        message.set_id(random_query_id());
157
158        let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
159            Ok(now) => now.as_secs(),
160            Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
161        };
162
163        // TODO: truncates u64 to u32, error on overflow?
164        let now = now as u32;
165
166        let mut verifier = None;
167        if let Some(ref signer) = self.signer {
168            if signer.should_finalize_message(&message) {
169                match message.finalize::<MF>(signer.borrow(), now) {
170                    Ok(answer_verifier) => verifier = answer_verifier,
171                    Err(e) => {
172                        debug!("could not sign message: {}", e);
173                        return e.into();
174                    }
175                }
176            }
177        }
178
179        let bytes = match message.to_vec() {
180            Ok(bytes) => bytes,
181            Err(err) => {
182                return err.into();
183            }
184        };
185
186        let message_id = message.id();
187        let message = SerialMessage::new(bytes, self.name_server);
188        let bind_addr = self.bind_addr;
189
190        debug!(
191            "final message: {}",
192            message
193                .to_message()
194                .expect("bizarre we just made this message")
195        );
196
197        S::Time::timeout::<Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>>(
198            self.timeout,
199            Box::pin(send_serial_message::<S>(
200                message, message_id, verifier, bind_addr,
201            )),
202        )
203        .into()
204    }
205
206    fn shutdown(&mut self) {
207        self.is_shutdown = true;
208    }
209
210    fn is_shutdown(&self) -> bool {
211        self.is_shutdown
212    }
213}
214
215// TODO: is this impl necessary? there's nothing being driven here...
216impl<S: Send, MF: MessageFinalizer> Stream for UdpClientStream<S, MF> {
217    type Item = Result<(), ProtoError>;
218
219    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
220        // Technically the Stream doesn't actually do anything.
221        if self.is_shutdown {
222            Poll::Ready(None)
223        } else {
224            Poll::Ready(Some(Ok(())))
225        }
226    }
227}
228
229/// A future that resolves to an UdpClientStream
230pub struct UdpClientConnect<S, MF = NoopMessageFinalizer>
231where
232    S: Send,
233    MF: MessageFinalizer,
234{
235    name_server: SocketAddr,
236    bind_addr: Option<SocketAddr>,
237    timeout: Duration,
238    signer: Option<Arc<MF>>,
239    marker: PhantomData<S>,
240}
241
242impl<S: Send + Unpin, MF: MessageFinalizer> Future for UdpClientConnect<S, MF> {
243    type Output = Result<UdpClientStream<S, MF>, ProtoError>;
244
245    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
246        // TODO: this doesn't need to be a future?
247        Poll::Ready(Ok(UdpClientStream::<S, MF> {
248            name_server: self.name_server,
249            bind_addr: self.bind_addr,
250            is_shutdown: false,
251            timeout: self.timeout,
252            signer: self.signer.take(),
253            marker: PhantomData,
254        }))
255    }
256}
257
258async fn send_serial_message<S: UdpSocket + Send>(
259    msg: SerialMessage,
260    msg_id: u16,
261    verifier: Option<MessageVerifier>,
262    bind_addr: Option<SocketAddr>,
263) -> Result<DnsResponse, ProtoError> {
264    let name_server = msg.addr();
265    let socket: S = NextRandomUdpSocket::new(&name_server, &bind_addr).await?;
266    let bytes = msg.bytes();
267    let addr = msg.addr();
268    let len_sent: usize = socket.send_to(bytes, addr).await?;
269
270    if bytes.len() != len_sent {
271        return Err(ProtoError::from(format!(
272            "Not all bytes of message sent, {} of {}",
273            len_sent,
274            bytes.len()
275        )));
276    }
277
278    // TODO: limit the max number of attempted messages? this relies on a timeout to die...
279    loop {
280        // TODO: consider making this heap based? need to verify it matches EDNS settings
281        let mut recv_buf = [0u8; 2048];
282
283        let (len, src) = socket.recv_from(&mut recv_buf).await?;
284        let response = SerialMessage::new(recv_buf.iter().take(len).cloned().collect(), src);
285
286        // compare expected src to received packet
287        let request_target = msg.addr();
288
289        if response.addr() != request_target {
290            warn!(
291                "ignoring response from {} because it does not match name_server: {}.",
292                response.addr(),
293                request_target,
294            );
295
296            // await an answer from the correct NameServer
297            continue;
298        }
299
300        // TODO: match query strings from request and response?
301
302        match response.to_message() {
303            Ok(message) => {
304                if msg_id == message.id() {
305                    debug!("received message id: {}", message.id());
306                    if let Some(mut verifier) = verifier {
307                        return verifier(response.bytes());
308                    } else {
309                        return Ok(DnsResponse::from(message));
310                    }
311                } else {
312                    // on wrong id, attempted poison?
313                    warn!(
314                        "expected message id: {} got: {}, dropped",
315                        msg_id,
316                        message.id()
317                    );
318
319                    continue;
320                }
321            }
322            Err(e) => {
323                // on errors deserializing, continue
324                warn!(
325                    "dropped malformed message waiting for id: {} err: {}",
326                    msg_id, e
327                );
328
329                continue;
330            }
331        }
332    }
333}
334
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]
    use crate::tests::udp_client_stream_test;
    use crate::TokioTime;
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::{net::UdpSocket as TokioUdpSocket, runtime::Runtime};

    /// Runs the shared UDP client stream test against loopback `addr` on a fresh Tokio runtime.
    fn run_on_loopback(addr: IpAddr) {
        let io_loop = Runtime::new().expect("failed to create tokio runtime");
        udp_client_stream_test::<TokioUdpSocket, Runtime, TokioTime>(addr, io_loop)
    }

    #[test]
    fn test_udp_client_stream_ipv4() {
        run_on_loopback(IpAddr::V4(Ipv4Addr::LOCALHOST))
    }

    #[test]
    #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
    fn test_udp_client_stream_ipv6() {
        run_on_loopback(IpAddr::V6(Ipv6Addr::LOCALHOST))
    }
}