trust_dns_proto/tcp/
tcp_stream.rs

1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This module contains all the TCP structures for demuxing TCP into streams of DNS packets.
9
10use std::io;
11use std::mem;
12use std::net::SocketAddr;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16
17use async_trait::async_trait;
18use futures_io::{AsyncRead, AsyncWrite};
19use futures_util::stream::Stream;
20use futures_util::{self, future::Future, ready, FutureExt};
21use tracing::debug;
22
23use crate::error::*;
24use crate::xfer::{SerialMessage, StreamReceiver};
25use crate::BufDnsStreamHandle;
26use crate::Time;
27
/// A connected byte stream over which DNS messages are exchanged with TCP framing.
///
/// Implementors must support asynchronous reads and writes (`AsyncRead` + `AsyncWrite`).
/// They must also be `Unpin`, because `TcpStream` polls the socket through `Pin::new`.
pub trait DnsTcpStream: AsyncRead + AsyncWrite + Unpin + Send + Sync + Sized + 'static {
    /// Timer implementation paired with this stream type.
    ///
    /// `TcpStream` uses it, for example, to bound how long connection establishment may take.
    type Time: Time;
}
33
/// A [`DnsTcpStream`] that can open its own outbound TCP connection.
#[async_trait]
pub trait Connect: DnsTcpStream {
    /// Connects to `addr` without requesting a specific local address.
    ///
    /// The default implementation forwards to [`Connect::connect_with_bind`] with
    /// `bind_addr` set to `None`.
    async fn connect(addr: SocketAddr) -> io::Result<Self> {
        Self::connect_with_bind(addr, None).await
    }

    /// Connects to `addr`, using `bind_addr` as the local address to connect from when one
    /// is given.
    ///
    /// NOTE(review): the behavior for `bind_addr == None` (presumably an OS-chosen local
    /// address) is left to the implementor.
    async fn connect_with_bind(addr: SocketAddr, bind_addr: Option<SocketAddr>)
        -> io::Result<Self>;
}
46
/// Progress of the outbound message currently being written to the remote end of the TCP
/// connection.
///
/// Every message moves through `LenBytes` -> `Bytes` -> `Flushing`. The transitions are driven
/// by the `Stream` implementation of `TcpStream`.
enum WriteTcpState {
    /// Writing the 2-byte, big-endian length prefix that precedes the message.
    LenBytes {
        /// Number of bytes of `length` already written
        pos: usize,
        /// Big-endian encoding of the message length
        length: [u8; 2],
        /// Message body to write once the length prefix has been sent
        bytes: Vec<u8>,
    },
    /// Writing the message body to the remote
    Bytes {
        /// Number of bytes of `bytes` already written
        pos: usize,
        /// Message body being written
        bytes: Vec<u8>,
    },
    /// The whole message has been written; flushing the socket
    Flushing,
}
68
/// Progress of the inbound DNS message currently being read from the TCP connection.
pub(crate) enum ReadTcpState {
    /// Reading the 2-byte, big-endian length prefix of the next message
    LenBytes {
        /// Number of prefix bytes read so far
        pos: usize,
        /// Buffer receiving the length prefix
        bytes: [u8; 2],
    },
    /// Reading the message body whose size was announced by the length prefix
    Bytes {
        /// Number of body bytes read so far
        pos: usize,
        /// Buffer sized to exactly the announced message length
        bytes: Vec<u8>,
    },
}
86
/// A Stream used for sending data to and from a remote DNS endpoint (client or server).
///
/// Each DNS message on the wire is preceded by a 2-byte, big-endian length prefix. Messages
/// queued through the paired `BufDnsStreamHandle` are written while the stream is polled, and
/// each item the stream yields is one inbound message from the peer.
#[must_use = "futures do nothing unless polled"]
pub struct TcpStream<S: DnsTcpStream> {
    /// The underlying connection
    socket: S,
    /// Queue of messages waiting to be written to the peer
    outbound_messages: StreamReceiver,
    /// Progress of the outbound message being written, or `None` when idle
    send_state: Option<WriteTcpState>,
    /// Progress of the inbound message being read
    read_state: ReadTcpState,
    /// Remote address; outbound messages addressed elsewhere are rejected
    peer_addr: SocketAddr,
}
96
97impl<S: Connect> TcpStream<S> {
98    /// Creates a new future of the eventually establish a IO stream connection or fail trying.
99    ///
100    /// Defaults to a 5 second timeout
101    ///
102    /// # Arguments
103    ///
104    /// * `name_server` - the IP and Port of the DNS server to connect to
105    #[allow(clippy::new_ret_no_self, clippy::type_complexity)]
106    pub fn new<E>(
107        name_server: SocketAddr,
108    ) -> (
109        impl Future<Output = Result<Self, io::Error>> + Send,
110        BufDnsStreamHandle,
111    )
112    where
113        E: FromProtoError,
114    {
115        Self::with_timeout(name_server, Duration::from_secs(5))
116    }
117
118    /// Creates a new future of the eventually establish a IO stream connection or fail trying
119    ///
120    /// # Arguments
121    ///
122    /// * `name_server` - the IP and Port of the DNS server to connect to
123    /// * `timeout` - connection timeout
124    #[allow(clippy::type_complexity)]
125    pub fn with_timeout(
126        name_server: SocketAddr,
127        timeout: Duration,
128    ) -> (
129        impl Future<Output = Result<Self, io::Error>> + Send,
130        BufDnsStreamHandle,
131    ) {
132        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
133
134        // This set of futures collapses the next tcp socket into a stream which can be used for
135        //  sending and receiving tcp packets.
136        let stream_fut = Self::connect(name_server, None, timeout, outbound_messages);
137
138        (stream_fut, message_sender)
139    }
140
141    /// Creates a new future of the eventually establish a IO stream connection or fail trying
142    ///
143    /// # Arguments
144    ///
145    /// * `name_server` - the IP and Port of the DNS server to connect to
146    /// * `bind_addr` - the IP and port to connect from
147    /// * `timeout` - connection timeout
148    #[allow(clippy::type_complexity)]
149    pub fn with_bind_addr_and_timeout(
150        name_server: SocketAddr,
151        bind_addr: Option<SocketAddr>,
152        timeout: Duration,
153    ) -> (
154        impl Future<Output = Result<Self, io::Error>> + Send,
155        BufDnsStreamHandle,
156    ) {
157        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
158        let stream_fut = Self::connect(name_server, bind_addr, timeout, outbound_messages);
159
160        (stream_fut, message_sender)
161    }
162
163    async fn connect(
164        name_server: SocketAddr,
165        bind_addr: Option<SocketAddr>,
166        timeout: Duration,
167        outbound_messages: StreamReceiver,
168    ) -> Result<Self, io::Error> {
169        let tcp = S::connect_with_bind(name_server, bind_addr);
170        S::Time::timeout(timeout, tcp)
171            .map(move |tcp_stream: Result<Result<S, io::Error>, _>| {
172                tcp_stream
173                    .and_then(|tcp_stream| tcp_stream)
174                    .map(|tcp_stream| {
175                        debug!("TCP connection established to: {}", name_server);
176                        Self {
177                            socket: tcp_stream,
178                            outbound_messages,
179                            send_state: None,
180                            read_state: ReadTcpState::LenBytes {
181                                pos: 0,
182                                bytes: [0u8; 2],
183                            },
184                            peer_addr: name_server,
185                        }
186                    })
187            })
188            .await
189    }
190}
191
192impl<S: DnsTcpStream> TcpStream<S> {
193    /// Returns the address of the peer connection.
194    pub fn peer_addr(&self) -> SocketAddr {
195        self.peer_addr
196    }
197
198    fn pollable_split(
199        &mut self,
200    ) -> (
201        &mut S,
202        &mut StreamReceiver,
203        &mut Option<WriteTcpState>,
204        &mut ReadTcpState,
205    ) {
206        (
207            &mut self.socket,
208            &mut self.outbound_messages,
209            &mut self.send_state,
210            &mut self.read_state,
211        )
212    }
213
214    /// Initializes a TcpStream.
215    ///
216    /// This is intended for use with a TcpListener and Incoming.
217    ///
218    /// # Arguments
219    ///
220    /// * `stream` - the established IO stream for communication
221    /// * `peer_addr` - sources address of the stream
222    pub fn from_stream(stream: S, peer_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
223        let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
224        let stream = Self::from_stream_with_receiver(stream, peer_addr, outbound_messages);
225        (stream, message_sender)
226    }
227
228    /// Wraps a stream where a sender and receiver have already been established
229    pub fn from_stream_with_receiver(
230        socket: S,
231        peer_addr: SocketAddr,
232        outbound_messages: StreamReceiver,
233    ) -> Self {
234        Self {
235            socket,
236            outbound_messages,
237            send_state: None,
238            read_state: ReadTcpState::LenBytes {
239                pos: 0,
240                bytes: [0u8; 2],
241            },
242            peer_addr,
243        }
244    }
245}
246
247impl<S: DnsTcpStream> Stream for TcpStream<S> {
248    type Item = io::Result<SerialMessage>;
249
250    #[allow(clippy::cognitive_complexity)]
251    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
252        let peer = self.peer_addr;
253        let (socket, outbound_messages, send_state, read_state) = self.pollable_split();
254        let mut socket = Pin::new(socket);
255        let mut outbound_messages = Pin::new(outbound_messages);
256
257        // this will not accept incoming data while there is data to send
258        //  makes this self throttling.
259        // TODO: it might be interesting to try and split the sending and receiving futures.
260        loop {
261            // in the case we are sending, send it all?
262            if send_state.is_some() {
263                // sending...
264                match send_state {
265                    Some(WriteTcpState::LenBytes {
266                        ref mut pos,
267                        ref length,
268                        ..
269                    }) => {
270                        let wrote = ready!(socket.as_mut().poll_write(cx, &length[*pos..]))?;
271                        *pos += wrote;
272                    }
273                    Some(WriteTcpState::Bytes {
274                        ref mut pos,
275                        ref bytes,
276                    }) => {
277                        let wrote = ready!(socket.as_mut().poll_write(cx, &bytes[*pos..]))?;
278                        *pos += wrote;
279                    }
280                    Some(WriteTcpState::Flushing) => {
281                        ready!(socket.as_mut().poll_flush(cx))?;
282                    }
283                    _ => (),
284                }
285
286                // get current state
287                let current_state = send_state.take();
288
289                // switch states
290                match current_state {
291                    Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
292                        if pos < length.len() {
293                            *send_state = Some(WriteTcpState::LenBytes { pos, length, bytes });
294                        } else {
295                            *send_state = Some(WriteTcpState::Bytes { pos: 0, bytes });
296                        }
297                    }
298                    Some(WriteTcpState::Bytes { pos, bytes }) => {
299                        if pos < bytes.len() {
300                            *send_state = Some(WriteTcpState::Bytes { pos, bytes });
301                        } else {
302                            // At this point we successfully delivered the entire message.
303                            //  flush
304                            *send_state = Some(WriteTcpState::Flushing);
305                        }
306                    }
307                    Some(WriteTcpState::Flushing) => {
308                        // At this point we successfully delivered the entire message.
309                        send_state.take();
310                    }
311                    None => (),
312                };
313            } else {
314                // then see if there is more to send
315                match outbound_messages.as_mut().poll_next(cx)
316                    // .map_err(|()| io::Error::new(io::ErrorKind::Other, "unknown"))?
317                {
318                    // already handled above, here to make sure the poll() pops the next message
319                    Poll::Ready(Some(message)) => {
320                        // if there is no peer, this connection should die...
321                        let (buffer, dst) = message.into();
322
323                        // This is an error if the destination is not our peer (this is TCP after all)
324                        //  This will kill the connection...
325                        if peer != dst {
326                            return Poll::Ready(Some(Err(io::Error::new(
327                                io::ErrorKind::InvalidData,
328                                format!("mismatched peer: {} and dst: {}", peer, dst),
329                            ))));
330                        }
331
332                        // will return if the socket will block
333                        // the length is 16 bits
334                        let len = u16::to_be_bytes(buffer.len() as u16);
335
336                        debug!("sending message len: {} to: {}", buffer.len(), dst);
337                        *send_state = Some(WriteTcpState::LenBytes {
338                            pos: 0,
339                            length: len,
340                            bytes: buffer,
341                        });
342                    }
343                    // now we get to drop through to the receives...
344                    // TODO: should we also return None if there are no more messages to send?
345                    Poll::Pending => break,
346                    Poll::Ready(None) => {
347                        debug!("no messages to send");
348                        break;
349                    }
350                }
351            }
352        }
353
354        let mut ret_buf: Option<Vec<u8>> = None;
355
356        // this will loop while there is data to read, or the data has been read, or an IO
357        //  event would block
358        while ret_buf.is_none() {
359            // Evaluates the next state. If None is the result, then no state change occurs,
360            //  if Some(_) is returned, then that will be used as the next state.
361            let new_state: Option<ReadTcpState> = match read_state {
362                ReadTcpState::LenBytes {
363                    ref mut pos,
364                    ref mut bytes,
365                } => {
366                    // debug!("reading length {}", bytes.len());
367                    let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
368                    if read == 0 {
369                        // the Stream was closed!
370                        debug!("zero bytes read, stream closed?");
371                        //try!(self.socket.shutdown(Shutdown::Both)); // TODO: add generic shutdown function
372
373                        if *pos == 0 {
374                            // Since this is the start of the next message, we have a clean end
375                            return Poll::Ready(None);
376                        } else {
377                            return Poll::Ready(Some(Err(io::Error::new(
378                                io::ErrorKind::BrokenPipe,
379                                "closed while reading length",
380                            ))));
381                        }
382                    }
383                    debug!("in ReadTcpState::LenBytes: {}", pos);
384                    *pos += read;
385
386                    if *pos < bytes.len() {
387                        debug!("remain ReadTcpState::LenBytes: {}", pos);
388                        None
389                    } else {
390                        let length = u16::from_be_bytes(*bytes);
391                        debug!("got length: {}", length);
392                        let mut bytes = vec![0; length as usize];
393                        bytes.resize(length as usize, 0);
394
395                        debug!("move ReadTcpState::Bytes: {}", bytes.len());
396                        Some(ReadTcpState::Bytes { pos: 0, bytes })
397                    }
398                }
399                ReadTcpState::Bytes {
400                    ref mut pos,
401                    ref mut bytes,
402                } => {
403                    let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
404                    if read == 0 {
405                        // the Stream was closed!
406                        debug!("zero bytes read for message, stream closed?");
407
408                        // Since this is the start of the next message, we have a clean end
409                        // try!(self.socket.shutdown(Shutdown::Both));  // TODO: add generic shutdown function
410                        return Poll::Ready(Some(Err(io::Error::new(
411                            io::ErrorKind::BrokenPipe,
412                            "closed while reading message",
413                        ))));
414                    }
415
416                    debug!("in ReadTcpState::Bytes: {}", bytes.len());
417                    *pos += read;
418
419                    if *pos < bytes.len() {
420                        debug!("remain ReadTcpState::Bytes: {}", bytes.len());
421                        None
422                    } else {
423                        debug!("reset ReadTcpState::LenBytes: {}", 0);
424                        Some(ReadTcpState::LenBytes {
425                            pos: 0,
426                            bytes: [0u8; 2],
427                        })
428                    }
429                }
430            };
431
432            // this will move to the next state,
433            //  if it was a completed receipt of bytes, then it will move out the bytes
434            if let Some(state) = new_state {
435                if let ReadTcpState::Bytes { pos, bytes } = mem::replace(read_state, state) {
436                    debug!("returning bytes");
437                    assert_eq!(pos, bytes.len());
438                    ret_buf = Some(bytes);
439                }
440            }
441        }
442
443        // if the buffer is ready, return it, if not we're Pending
444        if let Some(buffer) = ret_buf {
445            debug!("returning buffer");
446            let src_addr = self.peer_addr;
447            Poll::Ready(Some(Ok(SerialMessage::new(buffer, src_addr))))
448        } else {
449            debug!("bottomed out");
450            // at a minimum the outbound_messages should have been polled,
451            //  which will wake this future up later...
452            Poll::Pending
453        }
454    }
455}
456
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;

    use crate::iocompat::AsyncIoTokioAsStd;
    use crate::tests::tcp_stream_test;
    use crate::TokioTime;

    /// Runs the shared `tcp_stream_test` scenario against `addr` on a fresh Tokio runtime.
    ///
    /// The socket type is Tokio's TCP stream wrapped in `AsyncIoTokioAsStd`.
    fn run_on_tokio(addr: IpAddr) {
        let runtime = Runtime::new().expect("failed to create tokio runtime");
        tcp_stream_test::<AsyncIoTokioAsStd<TokioTcpStream>, Runtime, TokioTime>(addr, runtime);
    }

    #[test]
    fn test_tcp_stream_ipv4() {
        run_on_tokio(IpAddr::V4(Ipv4Addr::LOCALHOST));
    }

    #[test]
    #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
    fn test_tcp_stream_ipv6() {
        run_on_tokio(IpAddr::V6(Ipv6Addr::LOCALHOST));
    }
}