trust_dns_proto/tcp/tcp_client_stream.rs

// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use std::fmt::{self, Display};
9#[cfg(feature = "tokio-runtime")]
10use std::io;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16#[cfg(feature = "tokio-runtime")]
17use async_trait::async_trait;
18use futures_util::{future::Future, stream::Stream, StreamExt, TryFutureExt};
19use tracing::warn;
20
21use crate::error::ProtoError;
22#[cfg(feature = "tokio-runtime")]
23use crate::iocompat::AsyncIoTokioAsStd;
24use crate::tcp::{Connect, DnsTcpStream, TcpStream};
25use crate::xfer::{DnsClientStream, SerialMessage};
26use crate::BufDnsStreamHandle;
27#[cfg(feature = "tokio-runtime")]
28use crate::TokioTime;
29
30/// Tcp client stream
31///
32/// Use with `trust_dns_client::client::DnsMultiplexer` impls
33#[must_use = "futures do nothing unless polled"]
34pub struct TcpClientStream<S>
35where
36    S: DnsTcpStream,
37{
38    tcp_stream: TcpStream<S>,
39}
40
41impl<S: Connect> TcpClientStream<S> {
42    /// Constructs a new TcpStream for a client to the specified SocketAddr.
43    ///
44    /// Defaults to a 5 second timeout
45    ///
46    /// # Arguments
47    ///
48    /// * `name_server` - the IP and Port of the DNS server to connect to
49    #[allow(clippy::new_ret_no_self)]
50    pub fn new(name_server: SocketAddr) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
51        Self::with_timeout(name_server, Duration::from_secs(5))
52    }
53
54    /// Constructs a new TcpStream for a client to the specified SocketAddr.
55    ///
56    /// # Arguments
57    ///
58    /// * `name_server` - the IP and Port of the DNS server to connect to
59    /// * `timeout` - connection timeout
60    pub fn with_timeout(
61        name_server: SocketAddr,
62        timeout: Duration,
63    ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
64        Self::with_bind_addr_and_timeout(name_server, None, timeout)
65    }
66
67    /// Constructs a new TcpStream for a client to the specified SocketAddr.
68    ///
69    /// # Arguments
70    ///
71    /// * `name_server` - the IP and Port of the DNS server to connect to
72    /// * `bind_addr` - the IP and port to connect from
73    /// * `timeout` - connection timeout
74    #[allow(clippy::new_ret_no_self)]
75    pub fn with_bind_addr_and_timeout(
76        name_server: SocketAddr,
77        bind_addr: Option<SocketAddr>,
78        timeout: Duration,
79    ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
80        let (stream_future, sender) =
81            TcpStream::<S>::with_bind_addr_and_timeout(name_server, bind_addr, timeout);
82
83        let new_future = Box::pin(
84            stream_future
85                .map_ok(move |tcp_stream| Self { tcp_stream })
86                .map_err(ProtoError::from),
87        );
88
89        (TcpClientConnect(new_future), sender)
90    }
91}
92
93impl<S: DnsTcpStream> TcpClientStream<S> {
94    /// Wraps the TcpStream in TcpClientStream
95    pub fn from_stream(tcp_stream: TcpStream<S>) -> Self {
96        Self { tcp_stream }
97    }
98}
99
100impl<S: DnsTcpStream> Display for TcpClientStream<S> {
101    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
102        write!(formatter, "TCP({})", self.tcp_stream.peer_addr())
103    }
104}
105
106impl<S: DnsTcpStream> DnsClientStream for TcpClientStream<S> {
107    type Time = S::Time;
108
109    fn name_server_addr(&self) -> SocketAddr {
110        self.tcp_stream.peer_addr()
111    }
112}
113
114impl<S: DnsTcpStream> Stream for TcpClientStream<S> {
115    type Item = Result<SerialMessage, ProtoError>;
116
117    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
118        let message = try_ready_stream!(self.tcp_stream.poll_next_unpin(cx));
119
120        // this is busted if the tcp connection doesn't have a peer
121        let peer = self.tcp_stream.peer_addr();
122        if message.addr() != peer {
123            // TODO: this should be an error, right?
124            warn!("{} does not match name_server: {}", message.addr(), peer)
125        }
126
127        Poll::Ready(Some(Ok(message)))
128    }
129}
130
131// TODO: create unboxed future for the TCP Stream
132/// A future that resolves to an TcpClientStream
133pub struct TcpClientConnect<S: DnsTcpStream>(
134    Pin<Box<dyn Future<Output = Result<TcpClientStream<S>, ProtoError>> + Send + 'static>>,
135);
136
137impl<S: DnsTcpStream> Future for TcpClientConnect<S> {
138    type Output = Result<TcpClientStream<S>, ProtoError>;
139
140    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141        self.0.as_mut().poll(cx)
142    }
143}
144
145#[cfg(feature = "tokio-runtime")]
146use tokio::net::TcpStream as TokioTcpStream;
147
#[cfg(feature = "tokio-runtime")]
impl<T> DnsTcpStream for AsyncIoTokioAsStd<T>
where
    // `Sized` is implied for generic parameters, so it is not spelled out here.
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + Sync + 'static,
{
    // Tokio-backed streams use the Tokio timer implementation.
    type Time = TokioTime;
}
155
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl Connect for AsyncIoTokioAsStd<TokioTcpStream> {
    /// Opens a Tokio TCP connection to `addr` (optionally from `bind_addr`) via
    /// `super::tokio::connect_with_bind`, then wraps it in [`AsyncIoTokioAsStd`].
    async fn connect_with_bind(
        addr: SocketAddr,
        bind_addr: Option<SocketAddr>,
    ) -> io::Result<Self> {
        let tokio_stream = super::tokio::connect_with_bind(&addr, &bind_addr).await?;
        Ok(AsyncIoTokioAsStd(tokio_stream))
    }
}
168
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use super::AsyncIoTokioAsStd;
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;

    use crate::tests::tcp_client_stream_test;
    use crate::TokioTime;

    /// Runs the shared TCP client stream test against `server_addr` on a fresh Tokio runtime.
    fn run_tcp_client_stream_test(server_addr: IpAddr) {
        let io_loop = Runtime::new().expect("failed to create tokio runtime");
        tcp_client_stream_test::<AsyncIoTokioAsStd<TokioTcpStream>, Runtime, TokioTime>(
            server_addr,
            io_loop,
        );
    }

    #[test]
    fn test_tcp_stream_ipv4() {
        // 127.0.0.1
        run_tcp_client_stream_test(IpAddr::V4(Ipv4Addr::LOCALHOST));
    }

    #[test]
    #[cfg(not(target_os = "linux"))] // ignored until Travis-CI fixes IPv6
    fn test_tcp_stream_ipv6() {
        // ::1
        run_tcp_client_stream_test(IpAddr::V6(Ipv6Addr::LOCALHOST));
    }
}