1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use std::{convert::TryInto as _, os::unix::io::AsRawFd as _};
/// Extension methods for writing and reading the `TCP_USER_TIMEOUT` socket option.
pub trait TcpStreamExt {
/// Sets `TCP_USER_TIMEOUT` on the socket to `timeout`, expressed in whole milliseconds
/// (`Duration::as_millis` drops any sub-millisecond remainder).
///
/// # Errors
///
/// In the `std::net::TcpStream` implementation, returns an `InvalidInput` error when the
/// millisecond count does not fit in an `i32`; otherwise forwards any `setsockopt` failure.
fn set_user_timeout(&self, timeout: std::time::Duration) -> std::io::Result<()>;
/// Reads `TCP_USER_TIMEOUT` from the socket, interpreting the value as milliseconds.
///
/// # Errors
///
/// In the `std::net::TcpStream` implementation, returns [`Error::Netstack`] when reading the
/// option fails and [`Error::NegativeDuration`] when the reported value is negative.
fn user_timeout(&self) -> Result<std::time::Duration, Error>;
}
/// Error type returned by [`TcpStreamExt::user_timeout`].
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// Reading the socket option failed; wraps the underlying `std::io::Error`.
#[error("netstack returned an error: {0}")]
Netstack(std::io::Error),
/// The option value read back was negative, so it cannot be converted to a `Duration`.
/// Carries the raw value that was reported.
#[error("netstack returned a negative duration: {0}")]
NegativeDuration(i32),
}
impl TcpStreamExt for std::net::TcpStream {
fn set_user_timeout(&self, timeout: std::time::Duration) -> std::io::Result<()> {
set_tcp_option(
self,
libc::TCP_USER_TIMEOUT,
timeout.as_millis().try_into().map_err(|std::num::TryFromIntError { .. }| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"user timeout duration milliseconds does not fit in an i32",
)
})?,
)
}
fn user_timeout(&self) -> Result<std::time::Duration, Error> {
get_tcp_option(self, libc::TCP_USER_TIMEOUT).map_err(Error::Netstack).and_then(|timeout| {
Ok(std::time::Duration::from_millis(
timeout
.try_into()
.map_err(|std::num::TryFromIntError { .. }| Error::NegativeDuration(timeout))?,
))
})
}
}
/// Sets an `i32`-valued socket option (`option_level` / `option_name`) on `stream`.
///
/// Returns the OS error reported via `errno` when `setsockopt` returns a non-zero status.
fn set_option(
    stream: &std::net::TcpStream,
    option_level: libc::c_int,
    option_name: libc::c_int,
    option_value: i32,
) -> std::io::Result<()> {
    let value_ptr: *const libc::c_void = (&option_value as *const i32).cast();
    let value_len = std::mem::size_of::<i32>() as libc::socklen_t;
    // SAFETY: `value_ptr` points at the live local `option_value` for the whole call, and
    // `value_len` is exactly its size, so `setsockopt` only reads within that local. The fd comes
    // from the borrowed `stream`, which keeps it open for the duration of the call.
    let status = unsafe {
        libc::setsockopt(stream.as_raw_fd(), option_level, option_name, value_ptr, value_len)
    };
    if status == 0 {
        Ok(())
    } else {
        // Read errno immediately, before any other call can overwrite it.
        Err(std::io::Error::last_os_error())
    }
}
/// Sets an `i32`-valued option at the `IPPROTO_TCP` level on `stream`.
fn set_tcp_option(
    stream: &std::net::TcpStream,
    option_name: libc::c_int,
    option_value: i32,
) -> std::io::Result<()> {
    let tcp_level = libc::IPPROTO_TCP;
    set_option(stream, tcp_level, option_name, option_value)
}
/// Reads an `i32`-valued socket option (`option_level` / `option_name`) from `stream`.
///
/// # Errors
///
/// - The OS error reported via `errno` when `getsockopt` returns a non-zero status.
/// - `InvalidData` when the kernel reports writing a byte count other than `size_of::<i32>()`.
///   In that case `option_value` would be only partially filled (or truncated), so returning it
///   would silently yield a wrong value.
fn get_option(
    stream: &std::net::TcpStream,
    option_level: libc::c_int,
    option_name: libc::c_int,
) -> std::io::Result<i32> {
    let fd = stream.as_raw_fd();
    let mut option_value = 0i32;
    let expected_size = std::mem::size_of_val(&option_value) as libc::socklen_t;
    let mut option_value_size = expected_size;
    // SAFETY: `option_value` and `option_value_size` are live, properly aligned locals for the
    // whole call, and `option_value_size` tells the kernel the buffer is exactly
    // `size_of::<i32>()` bytes, so it cannot write past `option_value`. The fd comes from the
    // borrowed `stream`, which keeps it open for the duration of the call.
    let status = unsafe {
        libc::getsockopt(
            fd,
            option_level,
            option_name,
            &mut option_value as *mut _ as *mut libc::c_void,
            &mut option_value_size,
        )
    };
    if status != 0 {
        return Err(std::io::Error::last_os_error());
    }
    // `getsockopt` writes back how many bytes it actually stored. Anything other than a full
    // `i32` means the option is not int-valued here, and `option_value` cannot be trusted.
    if option_value_size != expected_size {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "getsockopt(level={}, name={}) returned {} bytes, expected {}",
                option_level, option_name, option_value_size, expected_size
            ),
        ));
    }
    Ok(option_value)
}
/// Reads an `i32`-valued option at the `IPPROTO_TCP` level from `stream`.
fn get_tcp_option(stream: &std::net::TcpStream, option_name: libc::c_int) -> std::io::Result<i32> {
    let tcp_level = libc::IPPROTO_TCP;
    get_option(stream, tcp_level, option_name)
}
#[cfg(test)]
mod test {
    use super::TcpStreamExt as _;

    /// Creates an unconnected IPv4 stream socket and converts it into a `std::net::TcpStream`.
    /// This lets socket options be exercised without establishing a connection.
    fn stream() -> std::io::Result<std::net::TcpStream> {
        socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None).map(Into::into)
    }

    proptest::proptest! {
        /// Every millisecond value representable as an `i32` survives a set/get round trip.
        #[test]
        fn user_timeout_roundtrip(timeout in 0..=i32::MAX as u64) {
            let timeout = std::time::Duration::from_millis(timeout);
            let stream = stream().expect("failed to create stream");
            let () = stream.set_user_timeout(timeout).expect("failed to set user timeout");
            let read_back = stream.user_timeout().expect("failed to get user timeout");
            proptest::prop_assert_eq!(read_back, timeout);
        }
    }
}