trust_dns_proto/udp/
udp_stream.rs

// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use std::io;
9use std::marker::PhantomData;
10use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use async_trait::async_trait;
15use futures_util::stream::Stream;
16use futures_util::{future::Future, ready, TryFutureExt};
17use rand;
18use rand::distributions::{uniform::Uniform, Distribution};
19use tracing::{debug, warn};
20
21use crate::xfer::{BufDnsStreamHandle, SerialMessage, StreamReceiver};
22use crate::Time;
23
24/// Trait for UdpSocket
25#[async_trait]
26pub trait UdpSocket
27where
28    Self: Send + Sync + Sized + Unpin,
29{
30    /// Time implementation used for this type
31    type Time: Time;
32
33    /// setups up a "client" udp connection that will only receive packets from the associated address
34    async fn connect(addr: SocketAddr) -> io::Result<Self>;
35
36    /// same as connect, but binds to the specified local address for seding address
37    async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self>;
38
39    /// a "server" UDP socket, that bind to the local listening address, and unbound remote address (can receive from anything)
40    async fn bind(addr: SocketAddr) -> io::Result<Self>;
41
42    /// Poll once Receive data from the socket and returns the number of bytes read and the address from
43    /// where the data came on success.
44    fn poll_recv_from(
45        &self,
46        cx: &mut Context<'_>,
47        buf: &mut [u8],
48    ) -> Poll<io::Result<(usize, SocketAddr)>>;
49
50    /// Receive data from the socket and returns the number of bytes read and the address from
51    /// where the data came on success.
52    async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
53        futures_util::future::poll_fn(|cx| self.poll_recv_from(cx, buf)).await
54    }
55
56    /// Poll once to send data to the given address.
57    fn poll_send_to(
58        &self,
59        cx: &mut Context<'_>,
60        buf: &[u8],
61        target: SocketAddr,
62    ) -> Poll<io::Result<usize>>;
63
64    /// Send data to the given address.
65    async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
66        futures_util::future::poll_fn(|cx| self.poll_send_to(cx, buf, target)).await
67    }
68}
69
70/// A UDP stream of DNS binary packets
71#[must_use = "futures do nothing unless polled"]
72pub struct UdpStream<S: Send> {
73    socket: S,
74    outbound_messages: StreamReceiver,
75}
76
77impl<S: UdpSocket + Send + 'static> UdpStream<S> {
78    /// This method is intended for client connections, see `with_bound` for a method better for
79    ///  straight listening. It is expected that the resolver wrapper will be responsible for
80    ///  creating and managing new UdpStreams such that each new client would have a random port
81    ///  (reduce chance of cache poisoning). This will return a randomly assigned local port.
82    ///
83    /// # Arguments
84    ///
85    /// * `remote_addr` - socket address for the remote connection (used to determine IPv4 or IPv6)
86    ///
87    /// # Return
88    ///
89    /// a tuple of a Future Stream which will handle sending and receiving messages, and a
90    ///  handle which can be used to send messages into the stream.
91    #[allow(clippy::type_complexity)]
92    pub fn new(
93        remote_addr: SocketAddr,
94        bind_addr: Option<SocketAddr>,
95    ) -> (
96        Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
97        BufDnsStreamHandle,
98    ) {
99        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
100
101        // TODO: allow the bind address to be specified...
102        // constructs a future for getting the next randomly bound port to a UdpSocket
103        let next_socket = NextRandomUdpSocket::new(&remote_addr, &bind_addr);
104
105        // This set of futures collapses the next udp socket into a stream which can be used for
106        //  sending and receiving udp packets.
107        let stream = Box::new(next_socket.map_ok(move |socket| Self {
108            socket,
109            outbound_messages,
110        }));
111
112        (stream, message_sender)
113    }
114
115    /// Initialize the Stream with an already bound socket. Generally this should be only used for
116    ///  server listening sockets. See `new` for a client oriented socket. Specifically, this there
117    ///  is already a bound socket in this context, whereas `new` makes sure to randomize ports
118    ///  for additional cache poison prevention.
119    ///
120    /// # Arguments
121    ///
122    /// * `socket` - an already bound UDP socket
123    /// * `remote_addr` - remote side of this connection
124    ///
125    /// # Return
126    ///
127    /// a tuple of a Future Stream which will handle sending and receiving messsages, and a
128    ///  handle which can be used to send messages into the stream.
129    pub fn with_bound(socket: S, remote_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
130        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
131        let stream = Self {
132            socket,
133            outbound_messages,
134        };
135
136        (stream, message_sender)
137    }
138
139    #[allow(unused)]
140    pub(crate) fn from_parts(socket: S, outbound_messages: StreamReceiver) -> Self {
141        Self {
142            socket,
143            outbound_messages,
144        }
145    }
146}
147
148impl<S: Send> UdpStream<S> {
149    #[allow(clippy::type_complexity)]
150    fn pollable_split(&mut self) -> (&mut S, &mut StreamReceiver) {
151        (&mut self.socket, &mut self.outbound_messages)
152    }
153}
154
155impl<S: UdpSocket + Send + 'static> Stream for UdpStream<S> {
156    type Item = Result<SerialMessage, io::Error>;
157
158    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
159        let (socket, outbound_messages) = self.pollable_split();
160        let socket = Pin::new(socket);
161        let mut outbound_messages = Pin::new(outbound_messages);
162
163        // this will not accept incoming data while there is data to send
164        //  makes this self throttling.
165        while let Poll::Ready(Some(message)) = outbound_messages.as_mut().poll_peek(cx) {
166            // first try to send
167            let addr = message.addr();
168
169            // this wiil return if not ready,
170            //   meaning that sending will be prefered over receiving...
171
172            // TODO: shouldn't this return the error to send to the sender?
173            if let Err(e) = ready!(socket.poll_send_to(cx, message.bytes(), addr)) {
174                // Drop the UDP packet and continue
175                warn!(
176                    "error sending message to {} on udp_socket, dropping response: {}",
177                    addr, e
178                );
179            }
180
181            // message sent, need to pop the message
182            assert!(outbound_messages.as_mut().poll_next(cx).is_ready());
183        }
184
185        // For QoS, this will only accept one message and output that
186        // receive all inbound messages
187
188        // TODO: this should match edns settings
189        let mut buf = [0u8; 4096];
190        let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;
191
192        let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);
193        Poll::Ready(Some(Ok(serial_message)))
194    }
195}
196
/// Future that resolves to a bound UDP socket of type `S`.
///
/// A `bind_address` with port `0` requests a randomly chosen local port; any other port
/// is used as given.
#[must_use = "futures do nothing unless polled"]
pub(crate) struct NextRandomUdpSocket<S> {
    // local address to bind; port 0 means "pick a random port"
    bind_address: SocketAddr,
    // ties the future to the socket type it produces without storing one
    marker: PhantomData<S>,
}
202
203impl<S: UdpSocket> NextRandomUdpSocket<S> {
204    /// Creates a future for randomly binding to a local socket address for client connections,
205    /// if no port is specified.
206    ///
207    /// If a port is specified in the bind address it is used.
208    pub(crate) fn new(name_server: &SocketAddr, bind_addr: &Option<SocketAddr>) -> Self {
209        let bind_address = match bind_addr {
210            Some(ba) => *ba,
211            None => match *name_server {
212                SocketAddr::V4(..) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
213                SocketAddr::V6(..) => {
214                    SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
215                }
216            },
217        };
218
219        Self {
220            bind_address,
221            marker: PhantomData,
222        }
223    }
224
225    async fn bind(addr: SocketAddr) -> Result<S, io::Error> {
226        S::bind(addr).await
227    }
228}
229
230impl<S: UdpSocket> Future for NextRandomUdpSocket<S> {
231    type Output = Result<S, io::Error>;
232
233    /// polls until there is an available next random UDP port,
234    /// if no port has been specified in bind_addr.
235    ///
236    /// if there is no port available after 10 attempts, returns NotReady
237    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238        if self.bind_address.port() == 0 {
239            // Per RFC 6056 Section 2.1:
240            //
241            //    The dynamic port range defined by IANA consists of the 49152-65535
242            //    range, and is meant for the selection of ephemeral ports.
243            let rand_port_range = Uniform::new_inclusive(49152_u16, u16::max_value());
244            let mut rand = rand::thread_rng();
245
246            for attempt in 0..10 {
247                let port = rand_port_range.sample(&mut rand);
248                let bind_addr = SocketAddr::new(self.bind_address.ip(), port);
249
250                // TODO: allow TTL to be adjusted...
251                // TODO: this immediate poll might be wrong in some cases...
252                match Box::pin(Self::bind(bind_addr)).as_mut().poll(cx) {
253                    Poll::Ready(Ok(socket)) => {
254                        debug!("created socket successfully");
255                        return Poll::Ready(Ok(socket));
256                    }
257                    Poll::Ready(Err(err)) => match err.kind() {
258                        io::ErrorKind::AddrInUse => {
259                            debug!("unable to bind port, attempt: {}: {}", attempt, err);
260                        }
261                        _ => {
262                            debug!("failed to bind port: {}", err);
263                            return Poll::Ready(Err(err));
264                        }
265                    },
266                    Poll::Pending => debug!("unable to bind port, attempt: {}", attempt),
267                }
268            }
269
270            debug!("could not get next random port, delaying");
271
272            // TODO: because no interest is registered anywhere, we must awake.
273            cx.waker().wake_by_ref();
274
275            // returning NotReady here, perhaps the next poll there will be some more socket available.
276            Poll::Pending
277        } else {
278            // Use port that was specified in bind address.
279            Box::pin(Self::bind(self.bind_address)).as_mut().poll(cx)
280        }
281    }
282}
283
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    type Time = crate::TokioTime;

    /// Sets up a "client" UDP socket for talking to `addr`.
    ///
    /// If `addr` is IPv4 the local address is bound to `0.0.0.0:0`; for IPv6, `[::]:0`.
    async fn connect(addr: SocketAddr) -> io::Result<Self> {
        let wildcard = if addr.is_ipv4() {
            IpAddr::V4(Ipv4Addr::UNSPECIFIED)
        } else {
            IpAddr::V6(Ipv6Addr::UNSPECIFIED)
        };

        Self::connect_with_bind(addr, SocketAddr::new(wildcard, 0)).await
    }

    /// Same as `connect`, but binds to the specified local address for sending.
    ///
    /// The remote address is currently unused: the socket is only bound, not connected.
    async fn connect_with_bind(_addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
        // TODO: research connect more, it appears to break UDP receiving tests, etc...
        //   (the socket is deliberately not connected to `_addr` for now)
        tokio::net::UdpSocket::bind(bind_addr).await
    }

    /// Binds a tokio UDP socket to `addr`.
    async fn bind(addr: SocketAddr) -> io::Result<Self> {
        // The full path names tokio's own inherent `bind`, not this trait method.
        tokio::net::UdpSocket::bind(addr).await
    }

    /// Polls tokio's inherent `poll_recv_from`, translating its `ReadBuf` result into the
    /// number of bytes filled plus the sender's address.
    fn poll_recv_from(
        &self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>> {
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        let outcome = ready!(tokio::net::UdpSocket::poll_recv_from(self, cx, &mut read_buf));

        Poll::Ready(outcome.map(|peer| (read_buf.filled().len(), peer)))
    }

    /// Polls tokio's inherent `poll_send_to`.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
        target: SocketAddr,
    ) -> Poll<io::Result<usize>> {
        tokio::net::UdpSocket::poll_send_to(self, cx, buf, target)
    }
}
336
#[cfg(all(test, feature = "tokio-runtime"))]
mod tests {
    #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use tokio::{net::UdpSocket as TokioUdpSocket, runtime::Runtime};

    use crate::tests::{next_random_socket_test, udp_stream_test};

    /// Builds the tokio runtime shared by every test in this module.
    fn new_runtime() -> Runtime {
        Runtime::new().expect("failed to create tokio runtime")
    }

    #[test]
    fn test_next_random_socket() {
        let io_loop = new_runtime();
        next_random_socket_test::<TokioUdpSocket, Runtime>(io_loop)
    }

    #[test]
    fn test_udp_stream_ipv4() {
        let io_loop = new_runtime();
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        io_loop.block_on(udp_stream_test::<TokioUdpSocket>(loopback));
    }

    #[test]
    #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
    fn test_udp_stream_ipv6() {
        let io_loop = new_runtime();
        let loopback = IpAddr::V6(Ipv6Addr::LOCALHOST);
        io_loop.block_on(udp_stream_test::<TokioUdpSocket>(loopback));
    }
}