tcp_stream_ext/
lib.rs
1use std::convert::TryInto as _;
6use std::os::unix::io::AsRawFd as _;
7
8pub trait TcpStreamExt {
9 fn set_user_timeout(&self, timeout: std::time::Duration) -> std::io::Result<()>;
12
13 fn user_timeout(&self) -> Result<std::time::Duration, Error>;
15}
16
17#[derive(thiserror::Error, Debug)]
18pub enum Error {
19 #[error("netstack returned an error: {0}")]
20 Netstack(std::io::Error),
21 #[error("netstack returned a negative duration: {0}")]
22 NegativeDuration(i32),
23}
24
25impl TcpStreamExt for std::net::TcpStream {
26 fn set_user_timeout(&self, timeout: std::time::Duration) -> std::io::Result<()> {
27 set_tcp_option(
28 self,
29 libc::TCP_USER_TIMEOUT,
30 timeout.as_millis().try_into().map_err(|std::num::TryFromIntError { .. }| {
31 std::io::Error::new(
32 std::io::ErrorKind::InvalidInput,
33 "user timeout duration milliseconds does not fit in an i32",
34 )
35 })?,
36 )
37 }
38
39 fn user_timeout(&self) -> Result<std::time::Duration, Error> {
40 get_tcp_option(self, libc::TCP_USER_TIMEOUT).map_err(Error::Netstack).and_then(|timeout| {
41 Ok(std::time::Duration::from_millis(
42 timeout
43 .try_into()
44 .map_err(|std::num::TryFromIntError { .. }| Error::NegativeDuration(timeout))?,
45 ))
46 })
47 }
48}
49
50fn set_option(
51 stream: &std::net::TcpStream,
52 option_level: libc::c_int,
53 option_name: libc::c_int,
54 option_value: i32,
55) -> std::io::Result<()> {
56 let fd = stream.as_raw_fd();
57 if unsafe {
59 libc::setsockopt(
60 fd,
61 option_level,
62 option_name,
63 &option_value as *const _ as *const libc::c_void,
64 std::mem::size_of_val(&option_value) as libc::socklen_t,
65 )
66 } != 0
67 {
68 Err(std::io::Error::last_os_error())?;
69 }
70 Ok(())
71}
72
73fn set_tcp_option(
74 stream: &std::net::TcpStream,
75 option_name: libc::c_int,
76 option_value: i32,
77) -> std::io::Result<()> {
78 set_option(stream, libc::IPPROTO_TCP, option_name, option_value)
79}
80
81fn get_option(
82 stream: &std::net::TcpStream,
83 option_level: libc::c_int,
84 option_name: libc::c_int,
85) -> std::io::Result<i32> {
86 let fd = stream.as_raw_fd();
87 let mut option_value = 0i32;
88 let mut option_value_size = std::mem::size_of_val(&option_value) as libc::socklen_t;
89 if unsafe {
91 libc::getsockopt(
92 fd,
93 option_level,
94 option_name,
95 &mut option_value as *mut _ as *mut libc::c_void,
96 &mut option_value_size,
97 )
98 } != 0
99 {
100 Err(std::io::Error::last_os_error())?;
101 }
102 Ok(option_value)
103}
104
105fn get_tcp_option(stream: &std::net::TcpStream, option_name: libc::c_int) -> std::io::Result<i32> {
106 get_option(stream, libc::IPPROTO_TCP, option_name)
107}
108
#[cfg(test)]
mod test {
    use super::TcpStreamExt as _;

    /// Creates a fresh, unconnected IPv4 TCP socket converted into a `TcpStream`.
    fn stream() -> std::io::Result<std::net::TcpStream> {
        use socket2::{Domain, Socket, Type};

        let socket = Socket::new(Domain::IPV4, Type::STREAM, None)?;
        Ok(socket.into())
    }

    proptest::proptest! {
        /// Every whole-millisecond timeout that fits in an `i32` reads back unchanged.
        #[test]
        fn user_timeout_roundtrip
            (timeout in 0..=i32::MAX as u64)
        {
            let stream = stream().expect("failed to create stream");
            let timeout = std::time::Duration::from_millis(timeout);

            let () = stream.set_user_timeout(timeout).expect("failed to set user timeout");
            proptest::prop_assert_eq!(stream.user_timeout().expect("failed to get user timeout"), timeout);
        }
    }

    /// A timeout whose millisecond count overflows an `i32` is rejected before
    /// reaching the netstack.
    #[test]
    fn set_user_timeout_rejects_out_of_range() {
        let stream = stream().expect("failed to create stream");
        // One millisecond past the largest value representable as an i32.
        let timeout = std::time::Duration::from_millis(i32::MAX as u64 + 1);

        let err = stream
            .set_user_timeout(timeout)
            .expect_err("setting an out-of-range user timeout should fail");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }
}