trust_dns_proto/tcp/tcp_stream.rs
1// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This module contains all the TCP structures for demuxing TCP into streams of DNS packets.
9
10use std::io;
11use std::mem;
12use std::net::SocketAddr;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16
17use async_trait::async_trait;
18use futures_io::{AsyncRead, AsyncWrite};
19use futures_util::stream::Stream;
20use futures_util::{self, future::Future, ready, FutureExt};
21use tracing::debug;
22
23use crate::error::*;
24use crate::xfer::{SerialMessage, StreamReceiver};
25use crate::BufDnsStreamHandle;
26use crate::Time;
27
28/// Trait for TCP connection
29pub trait DnsTcpStream: AsyncRead + AsyncWrite + Unpin + Send + Sync + Sized + 'static {
30 /// Timer type to use with this TCP stream type
31 type Time: Time;
32}
33
34/// Trait for TCP connection
35#[async_trait]
36pub trait Connect: DnsTcpStream {
37 /// connect to tcp
38 async fn connect(addr: SocketAddr) -> io::Result<Self> {
39 Self::connect_with_bind(addr, None).await
40 }
41
42 /// connect to tcp with address to connect from
43 async fn connect_with_bind(addr: SocketAddr, bind_addr: Option<SocketAddr>)
44 -> io::Result<Self>;
45}
46
/// Current state while writing one length-prefixed message to the remote end of the
/// TCP connection.
enum WriteTcpState {
    /// Currently writing the two-byte length prefix that precedes the buffer.
    LenBytes {
        /// Current position in the length buffer being written
        pos: usize,
        /// Big-endian encoding of the buffer's length
        length: [u8; 2],
        /// Buffer to write after the length
        bytes: Vec<u8>,
    },
    /// Currently writing the buffer to the remote
    Bytes {
        /// Current position in the buffer written
        pos: usize,
        /// Buffer to write to the remote
        bytes: Vec<u8>,
    },
    /// Currently flushing the written bytes to the remote
    Flushing,
}
68
/// Progress made reading one length-prefixed DNS message from the TCP socket.
pub(crate) enum ReadTcpState {
    /// Receiving the two-byte big-endian length prefix.
    LenBytes {
        /// Storage for the length prefix as its bytes arrive.
        bytes: [u8; 2],
        /// How many prefix bytes have been received so far.
        pos: usize,
    },
    /// Receiving the message body announced by the length prefix.
    Bytes {
        /// Storage for the body, sized to the announced length.
        bytes: Vec<u8>,
        /// How many body bytes have been received so far.
        pos: usize,
    },
}
86
87/// A Stream used for sending data to and from a remote DNS endpoint (client or server).
88#[must_use = "futures do nothing unless polled"]
89pub struct TcpStream<S: DnsTcpStream> {
90 socket: S,
91 outbound_messages: StreamReceiver,
92 send_state: Option<WriteTcpState>,
93 read_state: ReadTcpState,
94 peer_addr: SocketAddr,
95}
96
97impl<S: Connect> TcpStream<S> {
98 /// Creates a new future of the eventually establish a IO stream connection or fail trying.
99 ///
100 /// Defaults to a 5 second timeout
101 ///
102 /// # Arguments
103 ///
104 /// * `name_server` - the IP and Port of the DNS server to connect to
105 #[allow(clippy::new_ret_no_self, clippy::type_complexity)]
106 pub fn new<E>(
107 name_server: SocketAddr,
108 ) -> (
109 impl Future<Output = Result<Self, io::Error>> + Send,
110 BufDnsStreamHandle,
111 )
112 where
113 E: FromProtoError,
114 {
115 Self::with_timeout(name_server, Duration::from_secs(5))
116 }
117
118 /// Creates a new future of the eventually establish a IO stream connection or fail trying
119 ///
120 /// # Arguments
121 ///
122 /// * `name_server` - the IP and Port of the DNS server to connect to
123 /// * `timeout` - connection timeout
124 #[allow(clippy::type_complexity)]
125 pub fn with_timeout(
126 name_server: SocketAddr,
127 timeout: Duration,
128 ) -> (
129 impl Future<Output = Result<Self, io::Error>> + Send,
130 BufDnsStreamHandle,
131 ) {
132 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
133
134 // This set of futures collapses the next tcp socket into a stream which can be used for
135 // sending and receiving tcp packets.
136 let stream_fut = Self::connect(name_server, None, timeout, outbound_messages);
137
138 (stream_fut, message_sender)
139 }
140
141 /// Creates a new future of the eventually establish a IO stream connection or fail trying
142 ///
143 /// # Arguments
144 ///
145 /// * `name_server` - the IP and Port of the DNS server to connect to
146 /// * `bind_addr` - the IP and port to connect from
147 /// * `timeout` - connection timeout
148 #[allow(clippy::type_complexity)]
149 pub fn with_bind_addr_and_timeout(
150 name_server: SocketAddr,
151 bind_addr: Option<SocketAddr>,
152 timeout: Duration,
153 ) -> (
154 impl Future<Output = Result<Self, io::Error>> + Send,
155 BufDnsStreamHandle,
156 ) {
157 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
158 let stream_fut = Self::connect(name_server, bind_addr, timeout, outbound_messages);
159
160 (stream_fut, message_sender)
161 }
162
163 async fn connect(
164 name_server: SocketAddr,
165 bind_addr: Option<SocketAddr>,
166 timeout: Duration,
167 outbound_messages: StreamReceiver,
168 ) -> Result<Self, io::Error> {
169 let tcp = S::connect_with_bind(name_server, bind_addr);
170 S::Time::timeout(timeout, tcp)
171 .map(move |tcp_stream: Result<Result<S, io::Error>, _>| {
172 tcp_stream
173 .and_then(|tcp_stream| tcp_stream)
174 .map(|tcp_stream| {
175 debug!("TCP connection established to: {}", name_server);
176 Self {
177 socket: tcp_stream,
178 outbound_messages,
179 send_state: None,
180 read_state: ReadTcpState::LenBytes {
181 pos: 0,
182 bytes: [0u8; 2],
183 },
184 peer_addr: name_server,
185 }
186 })
187 })
188 .await
189 }
190}
191
192impl<S: DnsTcpStream> TcpStream<S> {
193 /// Returns the address of the peer connection.
194 pub fn peer_addr(&self) -> SocketAddr {
195 self.peer_addr
196 }
197
198 fn pollable_split(
199 &mut self,
200 ) -> (
201 &mut S,
202 &mut StreamReceiver,
203 &mut Option<WriteTcpState>,
204 &mut ReadTcpState,
205 ) {
206 (
207 &mut self.socket,
208 &mut self.outbound_messages,
209 &mut self.send_state,
210 &mut self.read_state,
211 )
212 }
213
214 /// Initializes a TcpStream.
215 ///
216 /// This is intended for use with a TcpListener and Incoming.
217 ///
218 /// # Arguments
219 ///
220 /// * `stream` - the established IO stream for communication
221 /// * `peer_addr` - sources address of the stream
222 pub fn from_stream(stream: S, peer_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
223 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(peer_addr);
224 let stream = Self::from_stream_with_receiver(stream, peer_addr, outbound_messages);
225 (stream, message_sender)
226 }
227
228 /// Wraps a stream where a sender and receiver have already been established
229 pub fn from_stream_with_receiver(
230 socket: S,
231 peer_addr: SocketAddr,
232 outbound_messages: StreamReceiver,
233 ) -> Self {
234 Self {
235 socket,
236 outbound_messages,
237 send_state: None,
238 read_state: ReadTcpState::LenBytes {
239 pos: 0,
240 bytes: [0u8; 2],
241 },
242 peer_addr,
243 }
244 }
245}
246
247impl<S: DnsTcpStream> Stream for TcpStream<S> {
248 type Item = io::Result<SerialMessage>;
249
250 #[allow(clippy::cognitive_complexity)]
251 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
252 let peer = self.peer_addr;
253 let (socket, outbound_messages, send_state, read_state) = self.pollable_split();
254 let mut socket = Pin::new(socket);
255 let mut outbound_messages = Pin::new(outbound_messages);
256
257 // this will not accept incoming data while there is data to send
258 // makes this self throttling.
259 // TODO: it might be interesting to try and split the sending and receiving futures.
260 loop {
261 // in the case we are sending, send it all?
262 if send_state.is_some() {
263 // sending...
264 match send_state {
265 Some(WriteTcpState::LenBytes {
266 ref mut pos,
267 ref length,
268 ..
269 }) => {
270 let wrote = ready!(socket.as_mut().poll_write(cx, &length[*pos..]))?;
271 *pos += wrote;
272 }
273 Some(WriteTcpState::Bytes {
274 ref mut pos,
275 ref bytes,
276 }) => {
277 let wrote = ready!(socket.as_mut().poll_write(cx, &bytes[*pos..]))?;
278 *pos += wrote;
279 }
280 Some(WriteTcpState::Flushing) => {
281 ready!(socket.as_mut().poll_flush(cx))?;
282 }
283 _ => (),
284 }
285
286 // get current state
287 let current_state = send_state.take();
288
289 // switch states
290 match current_state {
291 Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
292 if pos < length.len() {
293 *send_state = Some(WriteTcpState::LenBytes { pos, length, bytes });
294 } else {
295 *send_state = Some(WriteTcpState::Bytes { pos: 0, bytes });
296 }
297 }
298 Some(WriteTcpState::Bytes { pos, bytes }) => {
299 if pos < bytes.len() {
300 *send_state = Some(WriteTcpState::Bytes { pos, bytes });
301 } else {
302 // At this point we successfully delivered the entire message.
303 // flush
304 *send_state = Some(WriteTcpState::Flushing);
305 }
306 }
307 Some(WriteTcpState::Flushing) => {
308 // At this point we successfully delivered the entire message.
309 send_state.take();
310 }
311 None => (),
312 };
313 } else {
314 // then see if there is more to send
315 match outbound_messages.as_mut().poll_next(cx)
316 // .map_err(|()| io::Error::new(io::ErrorKind::Other, "unknown"))?
317 {
318 // already handled above, here to make sure the poll() pops the next message
319 Poll::Ready(Some(message)) => {
320 // if there is no peer, this connection should die...
321 let (buffer, dst) = message.into();
322
323 // This is an error if the destination is not our peer (this is TCP after all)
324 // This will kill the connection...
325 if peer != dst {
326 return Poll::Ready(Some(Err(io::Error::new(
327 io::ErrorKind::InvalidData,
328 format!("mismatched peer: {} and dst: {}", peer, dst),
329 ))));
330 }
331
332 // will return if the socket will block
333 // the length is 16 bits
334 let len = u16::to_be_bytes(buffer.len() as u16);
335
336 debug!("sending message len: {} to: {}", buffer.len(), dst);
337 *send_state = Some(WriteTcpState::LenBytes {
338 pos: 0,
339 length: len,
340 bytes: buffer,
341 });
342 }
343 // now we get to drop through to the receives...
344 // TODO: should we also return None if there are no more messages to send?
345 Poll::Pending => break,
346 Poll::Ready(None) => {
347 debug!("no messages to send");
348 break;
349 }
350 }
351 }
352 }
353
354 let mut ret_buf: Option<Vec<u8>> = None;
355
356 // this will loop while there is data to read, or the data has been read, or an IO
357 // event would block
358 while ret_buf.is_none() {
359 // Evaluates the next state. If None is the result, then no state change occurs,
360 // if Some(_) is returned, then that will be used as the next state.
361 let new_state: Option<ReadTcpState> = match read_state {
362 ReadTcpState::LenBytes {
363 ref mut pos,
364 ref mut bytes,
365 } => {
366 // debug!("reading length {}", bytes.len());
367 let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
368 if read == 0 {
369 // the Stream was closed!
370 debug!("zero bytes read, stream closed?");
371 //try!(self.socket.shutdown(Shutdown::Both)); // TODO: add generic shutdown function
372
373 if *pos == 0 {
374 // Since this is the start of the next message, we have a clean end
375 return Poll::Ready(None);
376 } else {
377 return Poll::Ready(Some(Err(io::Error::new(
378 io::ErrorKind::BrokenPipe,
379 "closed while reading length",
380 ))));
381 }
382 }
383 debug!("in ReadTcpState::LenBytes: {}", pos);
384 *pos += read;
385
386 if *pos < bytes.len() {
387 debug!("remain ReadTcpState::LenBytes: {}", pos);
388 None
389 } else {
390 let length = u16::from_be_bytes(*bytes);
391 debug!("got length: {}", length);
392 let mut bytes = vec![0; length as usize];
393 bytes.resize(length as usize, 0);
394
395 debug!("move ReadTcpState::Bytes: {}", bytes.len());
396 Some(ReadTcpState::Bytes { pos: 0, bytes })
397 }
398 }
399 ReadTcpState::Bytes {
400 ref mut pos,
401 ref mut bytes,
402 } => {
403 let read = ready!(socket.as_mut().poll_read(cx, &mut bytes[*pos..]))?;
404 if read == 0 {
405 // the Stream was closed!
406 debug!("zero bytes read for message, stream closed?");
407
408 // Since this is the start of the next message, we have a clean end
409 // try!(self.socket.shutdown(Shutdown::Both)); // TODO: add generic shutdown function
410 return Poll::Ready(Some(Err(io::Error::new(
411 io::ErrorKind::BrokenPipe,
412 "closed while reading message",
413 ))));
414 }
415
416 debug!("in ReadTcpState::Bytes: {}", bytes.len());
417 *pos += read;
418
419 if *pos < bytes.len() {
420 debug!("remain ReadTcpState::Bytes: {}", bytes.len());
421 None
422 } else {
423 debug!("reset ReadTcpState::LenBytes: {}", 0);
424 Some(ReadTcpState::LenBytes {
425 pos: 0,
426 bytes: [0u8; 2],
427 })
428 }
429 }
430 };
431
432 // this will move to the next state,
433 // if it was a completed receipt of bytes, then it will move out the bytes
434 if let Some(state) = new_state {
435 if let ReadTcpState::Bytes { pos, bytes } = mem::replace(read_state, state) {
436 debug!("returning bytes");
437 assert_eq!(pos, bytes.len());
438 ret_buf = Some(bytes);
439 }
440 }
441 }
442
443 // if the buffer is ready, return it, if not we're Pending
444 if let Some(buffer) = ret_buf {
445 debug!("returning buffer");
446 let src_addr = self.peer_addr;
447 Poll::Ready(Some(Ok(SerialMessage::new(buffer, src_addr))))
448 } else {
449 debug!("bottomed out");
450 // at a minimum the outbound_messages should have been polled,
451 // which will wake this future up later...
452 Poll::Pending
453 }
454 }
455}
456
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;

    use crate::iocompat::AsyncIoTokioAsStd;
    use crate::tests::tcp_stream_test;
    use crate::TokioTime;

    /// Runs the shared `tcp_stream_test` against `server_ip` on a freshly built Tokio runtime.
    fn run_on_tokio(server_ip: IpAddr) {
        let runtime = Runtime::new().expect("failed to create tokio runtime");
        tcp_stream_test::<AsyncIoTokioAsStd<TokioTcpStream>, Runtime, TokioTime>(server_ip, runtime)
    }

    #[test]
    fn test_tcp_stream_ipv4() {
        run_on_tokio(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
    }

    #[test]
    #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
    fn test_tcp_stream_ipv6() {
        run_on_tokio(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)))
    }
}