1use std::io;
9use std::marker::PhantomData;
10use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use async_trait::async_trait;
15use futures_util::stream::Stream;
16use futures_util::{future::Future, ready, TryFutureExt};
17use rand;
18use rand::distributions::{uniform::Uniform, Distribution};
19use tracing::{debug, warn};
20
21use crate::xfer::{BufDnsStreamHandle, SerialMessage, StreamReceiver};
22use crate::Time;
23
24#[async_trait]
26pub trait UdpSocket
27where
28 Self: Send + Sync + Sized + Unpin,
29{
30 type Time: Time;
32
33 async fn connect(addr: SocketAddr) -> io::Result<Self>;
35
36 async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self>;
38
39 async fn bind(addr: SocketAddr) -> io::Result<Self>;
41
42 fn poll_recv_from(
45 &self,
46 cx: &mut Context<'_>,
47 buf: &mut [u8],
48 ) -> Poll<io::Result<(usize, SocketAddr)>>;
49
50 async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
53 futures_util::future::poll_fn(|cx| self.poll_recv_from(cx, buf)).await
54 }
55
56 fn poll_send_to(
58 &self,
59 cx: &mut Context<'_>,
60 buf: &[u8],
61 target: SocketAddr,
62 ) -> Poll<io::Result<usize>>;
63
64 async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
66 futures_util::future::poll_fn(|cx| self.poll_send_to(cx, buf, target)).await
67 }
68}
69
70#[must_use = "futures do nothing unless polled"]
72pub struct UdpStream<S: Send> {
73 socket: S,
74 outbound_messages: StreamReceiver,
75}
76
77impl<S: UdpSocket + Send + 'static> UdpStream<S> {
78 #[allow(clippy::type_complexity)]
92 pub fn new(
93 remote_addr: SocketAddr,
94 bind_addr: Option<SocketAddr>,
95 ) -> (
96 Box<dyn Future<Output = Result<Self, io::Error>> + Send + Unpin>,
97 BufDnsStreamHandle,
98 ) {
99 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
100
101 let next_socket = NextRandomUdpSocket::new(&remote_addr, &bind_addr);
104
105 let stream = Box::new(next_socket.map_ok(move |socket| Self {
108 socket,
109 outbound_messages,
110 }));
111
112 (stream, message_sender)
113 }
114
115 pub fn with_bound(socket: S, remote_addr: SocketAddr) -> (Self, BufDnsStreamHandle) {
130 let (message_sender, outbound_messages) = BufDnsStreamHandle::new(remote_addr);
131 let stream = Self {
132 socket,
133 outbound_messages,
134 };
135
136 (stream, message_sender)
137 }
138
139 #[allow(unused)]
140 pub(crate) fn from_parts(socket: S, outbound_messages: StreamReceiver) -> Self {
141 Self {
142 socket,
143 outbound_messages,
144 }
145 }
146}
147
148impl<S: Send> UdpStream<S> {
149 #[allow(clippy::type_complexity)]
150 fn pollable_split(&mut self) -> (&mut S, &mut StreamReceiver) {
151 (&mut self.socket, &mut self.outbound_messages)
152 }
153}
154
155impl<S: UdpSocket + Send + 'static> Stream for UdpStream<S> {
156 type Item = Result<SerialMessage, io::Error>;
157
158 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
159 let (socket, outbound_messages) = self.pollable_split();
160 let socket = Pin::new(socket);
161 let mut outbound_messages = Pin::new(outbound_messages);
162
163 while let Poll::Ready(Some(message)) = outbound_messages.as_mut().poll_peek(cx) {
166 let addr = message.addr();
168
169 if let Err(e) = ready!(socket.poll_send_to(cx, message.bytes(), addr)) {
174 warn!(
176 "error sending message to {} on udp_socket, dropping response: {}",
177 addr, e
178 );
179 }
180
181 assert!(outbound_messages.as_mut().poll_next(cx).is_ready());
183 }
184
185 let mut buf = [0u8; 4096];
190 let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;
191
192 let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);
193 Poll::Ready(Some(Ok(serial_message)))
194 }
195}
196
197#[must_use = "futures do nothing unless polled"]
198pub(crate) struct NextRandomUdpSocket<S> {
199 bind_address: SocketAddr,
200 marker: PhantomData<S>,
201}
202
203impl<S: UdpSocket> NextRandomUdpSocket<S> {
204 pub(crate) fn new(name_server: &SocketAddr, bind_addr: &Option<SocketAddr>) -> Self {
209 let bind_address = match bind_addr {
210 Some(ba) => *ba,
211 None => match *name_server {
212 SocketAddr::V4(..) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
213 SocketAddr::V6(..) => {
214 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
215 }
216 },
217 };
218
219 Self {
220 bind_address,
221 marker: PhantomData,
222 }
223 }
224
225 async fn bind(addr: SocketAddr) -> Result<S, io::Error> {
226 S::bind(addr).await
227 }
228}
229
230impl<S: UdpSocket> Future for NextRandomUdpSocket<S> {
231 type Output = Result<S, io::Error>;
232
233 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238 if self.bind_address.port() == 0 {
239 let rand_port_range = Uniform::new_inclusive(49152_u16, u16::max_value());
244 let mut rand = rand::thread_rng();
245
246 for attempt in 0..10 {
247 let port = rand_port_range.sample(&mut rand);
248 let bind_addr = SocketAddr::new(self.bind_address.ip(), port);
249
250 match Box::pin(Self::bind(bind_addr)).as_mut().poll(cx) {
253 Poll::Ready(Ok(socket)) => {
254 debug!("created socket successfully");
255 return Poll::Ready(Ok(socket));
256 }
257 Poll::Ready(Err(err)) => match err.kind() {
258 io::ErrorKind::AddrInUse => {
259 debug!("unable to bind port, attempt: {}: {}", attempt, err);
260 }
261 _ => {
262 debug!("failed to bind port: {}", err);
263 return Poll::Ready(Err(err));
264 }
265 },
266 Poll::Pending => debug!("unable to bind port, attempt: {}", attempt),
267 }
268 }
269
270 debug!("could not get next random port, delaying");
271
272 cx.waker().wake_by_ref();
274
275 Poll::Pending
277 } else {
278 Box::pin(Self::bind(self.bind_address)).as_mut().poll(cx)
280 }
281 }
282}
283
/// [`UdpSocket`] implementation backed by tokio's UDP socket.
#[cfg(feature = "tokio-runtime")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    type Time = crate::TokioTime;

    /// Binds the unspecified address of `addr`'s family on port `0`
    /// (OS-assigned); see `connect_with_bind` for why `addr` is not connected.
    async fn connect(addr: SocketAddr) -> io::Result<Self> {
        let bind_addr = if addr.is_ipv4() {
            SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
        } else {
            SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))
        };

        Self::connect_with_bind(addr, bind_addr).await
    }

    /// Binds to `bind_addr` only. The remote address is intentionally
    /// ignored and the socket is never `connect`ed to it.
    async fn connect_with_bind(_addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
        // `Self::bind` resolves to tokio's inherent `UdpSocket::bind`.
        Self::bind(bind_addr).await
    }

    /// Delegates to tokio's inherent `UdpSocket::bind`.
    async fn bind(addr: SocketAddr) -> io::Result<Self> {
        Self::bind(addr).await
    }

    /// Adapts tokio's `ReadBuf`-based receive to the `(len, sender)` shape.
    fn poll_recv_from(
        &self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>> {
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        // Inherent tokio method: fills `read_buf` and yields the sender.
        let received = Self::poll_recv_from(self, cx, &mut read_buf);
        received.map_ok(|sender| (read_buf.filled().len(), sender))
    }

    /// Delegates to tokio's inherent `UdpSocket::poll_send_to`.
    fn poll_send_to(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
        target: SocketAddr,
    ) -> Poll<io::Result<usize>> {
        tokio::net::UdpSocket::poll_send_to(self, cx, buf, target)
    }
}
336
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    #[cfg(not(target_os = "linux"))]
    use std::net::Ipv6Addr;
    use std::net::{IpAddr, Ipv4Addr};

    use tokio::{net::UdpSocket as TokioUdpSocket, runtime::Runtime};

    use crate::tests::{next_random_socket_test, udp_stream_test};

    /// Builds a fresh tokio runtime for a single test.
    fn new_runtime() -> Runtime {
        Runtime::new().expect("failed to create tokio runtime")
    }

    #[test]
    fn test_next_random_socket() {
        next_random_socket_test::<TokioUdpSocket, Runtime>(new_runtime())
    }

    #[test]
    fn test_udp_stream_ipv4() {
        let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
        new_runtime().block_on(udp_stream_test::<TokioUdpSocket>(loopback));
    }

    // Compiled out on Linux targets by the `cfg` below.
    #[test]
    #[cfg(not(target_os = "linux"))]
    fn test_udp_stream_ipv6() {
        let loopback = IpAddr::V6(Ipv6Addr::LOCALHOST);
        new_runtime().block_on(udp_stream_test::<TokioUdpSocket>(loopback));
    }
}