1use std::borrow::Borrow;
9use std::fmt::{self, Display};
10use std::marker::PhantomData;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15use std::time::{Duration, SystemTime, UNIX_EPOCH};
16
17use futures_util::{future::Future, stream::Stream};
18use tracing::{debug, warn};
19
20use crate::error::ProtoError;
21use crate::op::message::NoopMessageFinalizer;
22use crate::op::{MessageFinalizer, MessageVerifier};
23use crate::udp::udp_stream::{NextRandomUdpSocket, UdpSocket};
24use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage};
25use crate::Time;
26
27#[must_use = "futures do nothing unless polled"]
32pub struct UdpClientStream<S, MF = NoopMessageFinalizer>
33where
34 S: Send,
35 MF: MessageFinalizer,
36{
37 name_server: SocketAddr,
38 bind_addr: Option<SocketAddr>,
39 timeout: Duration,
40 is_shutdown: bool,
41 signer: Option<Arc<MF>>,
42 marker: PhantomData<S>,
43}
44
45impl<S: Send> UdpClientStream<S, NoopMessageFinalizer> {
46 #[allow(clippy::new_ret_no_self)]
55 pub fn new(name_server: SocketAddr) -> UdpClientConnect<S, NoopMessageFinalizer> {
56 Self::with_timeout(name_server, Duration::from_secs(5))
57 }
58
59 pub fn with_timeout(
66 name_server: SocketAddr,
67 timeout: Duration,
68 ) -> UdpClientConnect<S, NoopMessageFinalizer> {
69 Self::with_bind_addr_and_timeout(name_server, None, timeout)
70 }
71
72 pub fn with_bind_addr_and_timeout(
80 name_server: SocketAddr,
81 bind_addr: Option<SocketAddr>,
82 timeout: Duration,
83 ) -> UdpClientConnect<S, NoopMessageFinalizer> {
84 Self::with_timeout_and_signer_and_bind_addr(name_server, timeout, None, bind_addr)
85 }
86}
87
88impl<S: Send, MF: MessageFinalizer> UdpClientStream<S, MF> {
89 pub fn with_timeout_and_signer(
96 name_server: SocketAddr,
97 timeout: Duration,
98 signer: Option<Arc<MF>>,
99 ) -> UdpClientConnect<S, MF> {
100 UdpClientConnect {
101 name_server,
102 bind_addr: None,
103 timeout,
104 signer,
105 marker: PhantomData::<S>,
106 }
107 }
108
109 pub fn with_timeout_and_signer_and_bind_addr(
117 name_server: SocketAddr,
118 timeout: Duration,
119 signer: Option<Arc<MF>>,
120 bind_addr: Option<SocketAddr>,
121 ) -> UdpClientConnect<S, MF> {
122 UdpClientConnect {
123 name_server,
124 bind_addr,
125 timeout,
126 signer,
127 marker: PhantomData::<S>,
128 }
129 }
130}
131
132impl<S: Send, MF: MessageFinalizer> Display for UdpClientStream<S, MF> {
133 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
134 write!(formatter, "UDP({})", self.name_server)
135 }
136}
137
138fn random_query_id() -> u16 {
140 use rand::distributions::{Distribution, Standard};
141 let mut rand = rand::thread_rng();
142
143 Standard.sample(&mut rand)
144}
145
146impl<S: UdpSocket + Send + 'static, MF: MessageFinalizer> DnsRequestSender
147 for UdpClientStream<S, MF>
148{
149 fn send_message(&mut self, mut message: DnsRequest) -> DnsResponseStream {
150 if self.is_shutdown {
151 panic!("can not send messages after stream is shutdown")
152 }
153
154 message.set_id(random_query_id());
157
158 let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
159 Ok(now) => now.as_secs(),
160 Err(_) => return ProtoError::from("Current time is before the Unix epoch.").into(),
161 };
162
163 let now = now as u32;
165
166 let mut verifier = None;
167 if let Some(ref signer) = self.signer {
168 if signer.should_finalize_message(&message) {
169 match message.finalize::<MF>(signer.borrow(), now) {
170 Ok(answer_verifier) => verifier = answer_verifier,
171 Err(e) => {
172 debug!("could not sign message: {}", e);
173 return e.into();
174 }
175 }
176 }
177 }
178
179 let bytes = match message.to_vec() {
180 Ok(bytes) => bytes,
181 Err(err) => {
182 return err.into();
183 }
184 };
185
186 let message_id = message.id();
187 let message = SerialMessage::new(bytes, self.name_server);
188 let bind_addr = self.bind_addr;
189
190 debug!(
191 "final message: {}",
192 message
193 .to_message()
194 .expect("bizarre we just made this message")
195 );
196
197 S::Time::timeout::<Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>>(
198 self.timeout,
199 Box::pin(send_serial_message::<S>(
200 message, message_id, verifier, bind_addr,
201 )),
202 )
203 .into()
204 }
205
206 fn shutdown(&mut self) {
207 self.is_shutdown = true;
208 }
209
210 fn is_shutdown(&self) -> bool {
211 self.is_shutdown
212 }
213}
214
215impl<S: Send, MF: MessageFinalizer> Stream for UdpClientStream<S, MF> {
217 type Item = Result<(), ProtoError>;
218
219 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
220 if self.is_shutdown {
222 Poll::Ready(None)
223 } else {
224 Poll::Ready(Some(Ok(())))
225 }
226 }
227}
228
229pub struct UdpClientConnect<S, MF = NoopMessageFinalizer>
231where
232 S: Send,
233 MF: MessageFinalizer,
234{
235 name_server: SocketAddr,
236 bind_addr: Option<SocketAddr>,
237 timeout: Duration,
238 signer: Option<Arc<MF>>,
239 marker: PhantomData<S>,
240}
241
242impl<S: Send + Unpin, MF: MessageFinalizer> Future for UdpClientConnect<S, MF> {
243 type Output = Result<UdpClientStream<S, MF>, ProtoError>;
244
245 fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
246 Poll::Ready(Ok(UdpClientStream::<S, MF> {
248 name_server: self.name_server,
249 bind_addr: self.bind_addr,
250 is_shutdown: false,
251 timeout: self.timeout,
252 signer: self.signer.take(),
253 marker: PhantomData,
254 }))
255 }
256}
257
258async fn send_serial_message<S: UdpSocket + Send>(
259 msg: SerialMessage,
260 msg_id: u16,
261 verifier: Option<MessageVerifier>,
262 bind_addr: Option<SocketAddr>,
263) -> Result<DnsResponse, ProtoError> {
264 let name_server = msg.addr();
265 let socket: S = NextRandomUdpSocket::new(&name_server, &bind_addr).await?;
266 let bytes = msg.bytes();
267 let addr = msg.addr();
268 let len_sent: usize = socket.send_to(bytes, addr).await?;
269
270 if bytes.len() != len_sent {
271 return Err(ProtoError::from(format!(
272 "Not all bytes of message sent, {} of {}",
273 len_sent,
274 bytes.len()
275 )));
276 }
277
278 loop {
280 let mut recv_buf = [0u8; 2048];
282
283 let (len, src) = socket.recv_from(&mut recv_buf).await?;
284 let response = SerialMessage::new(recv_buf.iter().take(len).cloned().collect(), src);
285
286 let request_target = msg.addr();
288
289 if response.addr() != request_target {
290 warn!(
291 "ignoring response from {} because it does not match name_server: {}.",
292 response.addr(),
293 request_target,
294 );
295
296 continue;
298 }
299
300 match response.to_message() {
303 Ok(message) => {
304 if msg_id == message.id() {
305 debug!("received message id: {}", message.id());
306 if let Some(mut verifier) = verifier {
307 return verifier(response.bytes());
308 } else {
309 return Ok(DnsResponse::from(message));
310 }
311 } else {
312 warn!(
314 "expected message id: {} got: {}, dropped",
315 msg_id,
316 message.id()
317 );
318
319 continue;
320 }
321 }
322 Err(e) => {
323 warn!(
325 "dropped malformed message waiting for id: {} err: {}",
326 msg_id, e
327 );
328
329 continue;
330 }
331 }
332 }
333}
334
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use tokio::{net::UdpSocket as TokioUdpSocket, runtime::Runtime};

    use crate::tests::udp_client_stream_test;
    use crate::TokioTime;

    /// Builds a fresh tokio runtime for one test.
    fn new_runtime() -> Runtime {
        Runtime::new().expect("failed to create tokio runtime")
    }

    /// Exercises the shared UDP client stream test over IPv4 loopback (127.0.0.1).
    #[test]
    fn test_udp_client_stream_ipv4() {
        udp_client_stream_test::<TokioUdpSocket, Runtime, TokioTime>(
            IpAddr::V4(Ipv4Addr::LOCALHOST),
            new_runtime(),
        )
    }

    /// Exercises the shared UDP client stream test over IPv6 loopback (::1).
    // Compiled out on Linux. NOTE(review): presumably because IPv6 loopback is not reliably
    // available on Linux CI/container hosts — confirm.
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_udp_client_stream_ipv6() {
        udp_client_stream_test::<TokioUdpSocket, Runtime, TokioTime>(
            IpAddr::V6(Ipv6Addr::LOCALHOST),
            new_runtime(),
        )
    }
}