// trust_dns_proto/tcp/tcp_client_stream.rs
1use std::fmt::{self, Display};
9#[cfg(feature = "tokio-runtime")]
10use std::io;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16#[cfg(feature = "tokio-runtime")]
17use async_trait::async_trait;
18use futures_util::{future::Future, stream::Stream, StreamExt, TryFutureExt};
19use tracing::warn;
20
21use crate::error::ProtoError;
22#[cfg(feature = "tokio-runtime")]
23use crate::iocompat::AsyncIoTokioAsStd;
24use crate::tcp::{Connect, DnsTcpStream, TcpStream};
25use crate::xfer::{DnsClientStream, SerialMessage};
26use crate::BufDnsStreamHandle;
27#[cfg(feature = "tokio-runtime")]
28use crate::TokioTime;
29
30#[must_use = "futures do nothing unless polled"]
34pub struct TcpClientStream<S>
35where
36 S: DnsTcpStream,
37{
38 tcp_stream: TcpStream<S>,
39}
40
41impl<S: Connect> TcpClientStream<S> {
42 #[allow(clippy::new_ret_no_self)]
50 pub fn new(name_server: SocketAddr) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
51 Self::with_timeout(name_server, Duration::from_secs(5))
52 }
53
54 pub fn with_timeout(
61 name_server: SocketAddr,
62 timeout: Duration,
63 ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
64 Self::with_bind_addr_and_timeout(name_server, None, timeout)
65 }
66
67 #[allow(clippy::new_ret_no_self)]
75 pub fn with_bind_addr_and_timeout(
76 name_server: SocketAddr,
77 bind_addr: Option<SocketAddr>,
78 timeout: Duration,
79 ) -> (TcpClientConnect<S>, BufDnsStreamHandle) {
80 let (stream_future, sender) =
81 TcpStream::<S>::with_bind_addr_and_timeout(name_server, bind_addr, timeout);
82
83 let new_future = Box::pin(
84 stream_future
85 .map_ok(move |tcp_stream| Self { tcp_stream })
86 .map_err(ProtoError::from),
87 );
88
89 (TcpClientConnect(new_future), sender)
90 }
91}
92
93impl<S: DnsTcpStream> TcpClientStream<S> {
94 pub fn from_stream(tcp_stream: TcpStream<S>) -> Self {
96 Self { tcp_stream }
97 }
98}
99
100impl<S: DnsTcpStream> Display for TcpClientStream<S> {
101 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
102 write!(formatter, "TCP({})", self.tcp_stream.peer_addr())
103 }
104}
105
106impl<S: DnsTcpStream> DnsClientStream for TcpClientStream<S> {
107 type Time = S::Time;
108
109 fn name_server_addr(&self) -> SocketAddr {
110 self.tcp_stream.peer_addr()
111 }
112}
113
114impl<S: DnsTcpStream> Stream for TcpClientStream<S> {
115 type Item = Result<SerialMessage, ProtoError>;
116
117 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
118 let message = try_ready_stream!(self.tcp_stream.poll_next_unpin(cx));
119
120 let peer = self.tcp_stream.peer_addr();
122 if message.addr() != peer {
123 warn!("{} does not match name_server: {}", message.addr(), peer)
125 }
126
127 Poll::Ready(Some(Ok(message)))
128 }
129}
130
131pub struct TcpClientConnect<S: DnsTcpStream>(
134 Pin<Box<dyn Future<Output = Result<TcpClientStream<S>, ProtoError>> + Send + 'static>>,
135);
136
137impl<S: DnsTcpStream> Future for TcpClientConnect<S> {
138 type Output = Result<TcpClientStream<S>, ProtoError>;
139
140 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141 self.0.as_mut().poll(cx)
142 }
143}
144
145#[cfg(feature = "tokio-runtime")]
146use tokio::net::TcpStream as TokioTcpStream;
147
/// Any tokio `AsyncRead + AsyncWrite` type wrapped in `AsyncIoTokioAsStd` is usable
/// as a DNS TCP stream, with [`TokioTime`] as its associated `Time` type.
#[cfg(feature = "tokio-runtime")]
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + Sync + Sized + 'static>
    DnsTcpStream for AsyncIoTokioAsStd<T>
{
    type Time = TokioTime;
}
155
/// Tokio-backed `Connect` implementation.
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl Connect for AsyncIoTokioAsStd<TokioTcpStream> {
    /// Connects to `addr` through `super::tokio::connect_with_bind`, forwarding the
    /// optional `bind_addr`, and wraps the resulting tokio stream in
    /// `AsyncIoTokioAsStd`. I/O errors are propagated unchanged.
    async fn connect_with_bind(
        addr: SocketAddr,
        bind_addr: Option<SocketAddr>,
    ) -> io::Result<Self> {
        let stream = super::tokio::connect_with_bind(&addr, &bind_addr).await?;
        Ok(AsyncIoTokioAsStd(stream))
    }
}
168
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use tokio::net::TcpStream as TokioTcpStream;
    use tokio::runtime::Runtime;

    use super::AsyncIoTokioAsStd;
    use crate::tests::tcp_client_stream_test;
    use crate::TokioTime;

    /// Builds a fresh tokio runtime and runs the shared TCP client stream test
    /// against a server on `addr`.
    fn run_against(addr: IpAddr) {
        let io_loop = Runtime::new().expect("failed to create tokio runtime");
        tcp_client_stream_test::<AsyncIoTokioAsStd<TokioTcpStream>, Runtime, TokioTime>(
            addr, io_loop,
        )
    }

    #[test]
    fn test_tcp_stream_ipv4() {
        run_against(IpAddr::V4(Ipv4Addr::LOCALHOST))
    }

    // NOTE(review): this test is compiled out on Linux; the reason is not recorded
    // here (possibly CI IPv6 availability) — confirm before re-enabling.
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_tcp_stream_ipv6() {
        run_against(IpAddr::V6(Ipv6Addr::LOCALHOST))
    }
}