tcp_stream_ext/
lib.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::convert::TryInto as _;
6use std::os::unix::io::AsRawFd as _;
7
8pub trait TcpStreamExt {
9    /// Sets TCP_USER_TIMEOUT. Fuchsia supports `1..=i32::max_value()`
10    /// milliseconds.
11    fn set_user_timeout(&self, timeout: std::time::Duration) -> std::io::Result<()>;
12
13    /// Gets TCP_USER_TIMEOUT.
14    fn user_timeout(&self) -> Result<std::time::Duration, Error>;
15}
16
17#[derive(thiserror::Error, Debug)]
18pub enum Error {
19    #[error("netstack returned an error: {0}")]
20    Netstack(std::io::Error),
21    #[error("netstack returned a negative duration: {0}")]
22    NegativeDuration(i32),
23}
24
25impl TcpStreamExt for std::net::TcpStream {
26    fn set_user_timeout(&self, timeout: std::time::Duration) -> std::io::Result<()> {
27        set_tcp_option(
28            self,
29            libc::TCP_USER_TIMEOUT,
30            timeout.as_millis().try_into().map_err(|std::num::TryFromIntError { .. }| {
31                std::io::Error::new(
32                    std::io::ErrorKind::InvalidInput,
33                    "user timeout duration milliseconds does not fit in an i32",
34                )
35            })?,
36        )
37    }
38
39    fn user_timeout(&self) -> Result<std::time::Duration, Error> {
40        get_tcp_option(self, libc::TCP_USER_TIMEOUT).map_err(Error::Netstack).and_then(|timeout| {
41            Ok(std::time::Duration::from_millis(
42                timeout
43                    .try_into()
44                    .map_err(|std::num::TryFromIntError { .. }| Error::NegativeDuration(timeout))?,
45            ))
46        })
47    }
48}
49
50fn set_option(
51    stream: &std::net::TcpStream,
52    option_level: libc::c_int,
53    option_name: libc::c_int,
54    option_value: i32,
55) -> std::io::Result<()> {
56    let fd = stream.as_raw_fd();
57    // Safe because `setsockopt` does not retain memory passed to it.
58    if unsafe {
59        libc::setsockopt(
60            fd,
61            option_level,
62            option_name,
63            &option_value as *const _ as *const libc::c_void,
64            std::mem::size_of_val(&option_value) as libc::socklen_t,
65        )
66    } != 0
67    {
68        Err(std::io::Error::last_os_error())?;
69    }
70    Ok(())
71}
72
73fn set_tcp_option(
74    stream: &std::net::TcpStream,
75    option_name: libc::c_int,
76    option_value: i32,
77) -> std::io::Result<()> {
78    set_option(stream, libc::IPPROTO_TCP, option_name, option_value)
79}
80
81fn get_option(
82    stream: &std::net::TcpStream,
83    option_level: libc::c_int,
84    option_name: libc::c_int,
85) -> std::io::Result<i32> {
86    let fd = stream.as_raw_fd();
87    let mut option_value = 0i32;
88    let mut option_value_size = std::mem::size_of_val(&option_value) as libc::socklen_t;
89    // Safe because `getsockopt` does not retain memory passed to it.
90    if unsafe {
91        libc::getsockopt(
92            fd,
93            option_level,
94            option_name,
95            &mut option_value as *mut _ as *mut libc::c_void,
96            &mut option_value_size,
97        )
98    } != 0
99    {
100        Err(std::io::Error::last_os_error())?;
101    }
102    Ok(option_value)
103}
104
105fn get_tcp_option(stream: &std::net::TcpStream, option_name: libc::c_int) -> std::io::Result<i32> {
106    get_option(stream, libc::IPPROTO_TCP, option_name)
107}
108
#[cfg(test)]
mod test {
    use super::TcpStreamExt as _;

    /// Creates an unconnected IPv4 stream socket wrapped as a `TcpStream`.
    fn stream() -> std::io::Result<std::net::TcpStream> {
        use socket2::{Domain, Socket, Type};

        Socket::new(Domain::IPV4, Type::STREAM, None).map(std::net::TcpStream::from)
    }

    proptest::proptest! {
        // Any whole-millisecond timeout in `0..=i32::MAX` ms must survive a set/get roundtrip.
        #[test]
        fn user_timeout_roundtrip
            (millis in 0..=i32::max_value() as u64)
        {
            let stream = stream().expect("failed to create stream");
            let expected = std::time::Duration::from_millis(millis);

            let () = stream.set_user_timeout(expected).expect("failed to set user timeout");
            let actual = stream.user_timeout().expect("failed to get user timeout");
            proptest::prop_assert_eq!(actual, expected);
        }
    }
}
132}