trust_dns_proto/tests/
tcp.rs

1use std::io::{Read, Write};
2use std::net::{IpAddr, SocketAddr};
3use std::sync::{atomic::AtomicBool, Arc};
4
5use futures_util::stream::StreamExt;
6
7use crate::error::ProtoError;
8use crate::tcp::{Connect, TcpClientStream, TcpStream};
9use crate::xfer::dns_handle::DnsStreamHandle;
10use crate::xfer::SerialMessage;
11use crate::{Executor, Time};
12
/// Payload the client sends and the echo server is expected to bounce back.
const TEST_BYTES: &[u8; 8] = b"DEADBEEF";
/// Length of [`TEST_BYTES`]; derived from the payload so the two cannot drift apart.
const TEST_BYTES_LEN: usize = TEST_BYTES.len();
/// Number of request/response round trips each test performs on one connection.
const SEND_RECV_TIMES: usize = 4;
16
17fn tcp_server_setup(
18    server_name: &str,
19    server_addr: IpAddr,
20) -> (Arc<AtomicBool>, std::thread::JoinHandle<()>, SocketAddr) {
21    let succeeded = Arc::new(AtomicBool::new(false));
22    let succeeded_clone = succeeded.clone();
23    std::thread::Builder::new()
24        .name("thread_killer".to_string())
25        .spawn(move || {
26            let succeeded = succeeded_clone;
27            for _ in 0..15 {
28                std::thread::sleep(std::time::Duration::from_secs(1));
29                if succeeded.load(std::sync::atomic::Ordering::Relaxed) {
30                    return;
31                }
32            }
33
34            println!("Thread Killer has been awoken, killing process");
35            std::process::exit(-1);
36        })
37        .expect("Thread spawning failed");
38
39    // TODO: need a timeout on listen
40    let server = std::net::TcpListener::bind(SocketAddr::new(server_addr, 0))
41        .expect("Unable to bind a TCP socket");
42    let server_addr = server.local_addr().unwrap();
43
44    // an in and out server
45    let server_handle = std::thread::Builder::new()
46        .name(server_name.to_string())
47        .spawn(move || {
48            let (mut socket, _) = server.accept().expect("accept failed");
49
50            socket
51                .set_read_timeout(Some(std::time::Duration::from_secs(5)))
52                .unwrap(); // should receive something within 5 seconds...
53            socket
54                .set_write_timeout(Some(std::time::Duration::from_secs(5)))
55                .unwrap(); // should receive something within 5 seconds...
56
57            for _ in 0..SEND_RECV_TIMES {
58                // wait for some bytes...
59                let mut len_bytes = [0_u8; 2];
60                socket
61                    .read_exact(&mut len_bytes)
62                    .expect("SERVER: receive failed");
63                let length =
64                    u16::from(len_bytes[0]) << 8 & 0xFF00 | u16::from(len_bytes[1]) & 0x00FF;
65                assert_eq!(length as usize, TEST_BYTES_LEN);
66
67                let mut buffer = [0_u8; TEST_BYTES_LEN];
68                socket.read_exact(&mut buffer).unwrap();
69
70                // println!("read bytes iter: {}", i);
71                assert_eq!(&buffer, TEST_BYTES);
72
73                // bounce them right back...
74                socket
75                    .write_all(&len_bytes)
76                    .expect("SERVER: send length failed");
77                socket
78                    .write_all(&buffer)
79                    .expect("SERVER: send buffer failed");
80                // println!("wrote bytes iter: {}", i);
81                std::thread::yield_now();
82            }
83        })
84        .unwrap();
85    (succeeded, server_handle, server_addr)
86}
87
88/// Test tcp_stream.
89pub fn tcp_stream_test<S: Connect, E: Executor, TE: Time>(server_addr: IpAddr, mut exec: E) {
90    let (succeeded, server_handle, server_addr) =
91        tcp_server_setup("test_tcp_stream:server", server_addr);
92
93    // setup the client, which is going to run on the testing thread...
94
95    // the tests should run within 5 seconds... right?
96    // TODO: add timeout here, so that test never hangs...
97    // let timeout = Timeout::new(Duration::from_secs(5));
98    let (stream, mut sender) = TcpStream::<S>::new::<ProtoError>(server_addr);
99
100    let mut stream = exec.block_on(stream).expect("run failed to get stream");
101
102    for _ in 0..SEND_RECV_TIMES {
103        // test once
104        sender
105            .send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr))
106            .expect("send failed");
107
108        let (buffer, stream_tmp) = exec.block_on(stream.into_future());
109        stream = stream_tmp;
110        let message = buffer
111            .expect("no buffer received")
112            .expect("error receiving buffer");
113        assert_eq!(message.bytes(), TEST_BYTES);
114    }
115
116    succeeded.store(true, std::sync::atomic::Ordering::Relaxed);
117    server_handle.join().expect("server thread failed");
118}
119
120/// Test tcp_client_stream.
121pub fn tcp_client_stream_test<S: Connect, E: Executor, TE: Time + 'static>(
122    server_addr: IpAddr,
123    mut exec: E,
124) {
125    let (succeeded, server_handle, server_addr) =
126        tcp_server_setup("test_tcp_client_stream:server", server_addr);
127
128    // setup the client, which is going to run on the testing thread...
129
130    // the tests should run within 5 seconds... right?
131    // TODO: add timeout here, so that test never hangs...
132    // let timeout = Timeout::new(Duration::from_secs(5));
133    let (stream, mut sender) = TcpClientStream::<S>::new(server_addr);
134
135    let mut stream = exec.block_on(stream).expect("run failed to get stream");
136
137    for _ in 0..SEND_RECV_TIMES {
138        // test once
139        sender
140            .send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr))
141            .expect("send failed");
142        let (buffer, stream_tmp) = exec.block_on(stream.into_future());
143        stream = stream_tmp;
144        let buffer = buffer
145            .expect("no buffer received")
146            .expect("error receiving buffer");
147        assert_eq!(buffer.bytes(), TEST_BYTES);
148    }
149
150    succeeded.store(true, std::sync::atomic::Ordering::Relaxed);
151    server_handle.join().expect("server thread failed");
152}