1use std::num::{NonZeroU32, TryFromIntError};
6use std::time::{Duration, Instant};
7use std::u64;
8
9use flex_fuchsia_developer_ffx_speedtest as fspeedtest;
10use futures::{AsyncReadExt, AsyncWriteExt as _};
11use thiserror::Error;
12
13pub struct Transfer {
14 pub socket: flex_client::Socket,
15 pub params: TransferParams,
16}
17
/// Sizing parameters for a speed-test transfer.
///
/// Both lengths are non-zero by construction, so a transfer always moves at
/// least one byte and every socket call uses a non-empty buffer.
///
/// `PartialEq`/`Eq` are derived so parameters can be compared by value (for
/// example when checking a FIDL round-trip).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransferParams {
    /// Total number of bytes to transfer.
    pub data_len: NonZeroU32,
    /// Maximum number of bytes moved by a single socket read or write.
    pub buffer_len: NonZeroU32,
}
27 let fspeedtest::TransferParams { len_bytes, buffer_bytes, __source_breaking } = value;
28 Ok(Self {
29 data_len: len_bytes.unwrap_or(fspeedtest::DEFAULT_TRANSFER_SIZE).try_into()?,
30 buffer_len: buffer_bytes.unwrap_or(fspeedtest::DEFAULT_BUFFER_SIZE).try_into()?,
31 })
32 }
33}
34
35impl TryFrom<TransferParams> for fspeedtest::TransferParams {
36 type Error = TryFromIntError;
37 fn try_from(value: TransferParams) -> Result<Self, Self::Error> {
38 let TransferParams { data_len, buffer_len } = value;
39 Ok(Self {
40 len_bytes: Some(data_len.try_into()?),
41 buffer_bytes: Some(buffer_len.try_into()?),
42 __source_breaking: fidl::marker::SourceBreaking,
43 })
44 }
45}
46
/// Outcome of a completed transfer.
#[derive(Debug)]
pub struct Report {
    /// Wall-clock time measured around the transfer's socket loop.
    pub duration: Duration,
}
51
52impl From<Report> for fspeedtest::TransferReport {
53 fn from(value: Report) -> Self {
54 let Report { duration } = value;
55 Self {
56 duration_nsec: Some(duration.as_nanos().try_into().unwrap_or(u64::MAX)),
57 __source_breaking: fidl::marker::SourceBreaking,
58 }
59 }
60}
61
62#[derive(Error, Debug)]
63#[error("missing mandatory field")]
64pub struct MissingFieldError;
65
66impl TryFrom<fspeedtest::TransferReport> for Report {
67 type Error = MissingFieldError;
68
69 fn try_from(value: fspeedtest::TransferReport) -> Result<Self, Self::Error> {
70 let fspeedtest::TransferReport { duration_nsec, __source_breaking } = value;
71 Ok(Self { duration: Duration::from_nanos(duration_nsec.ok_or(MissingFieldError)?) })
72 }
73}
74
75#[derive(Error, Debug)]
76pub enum TransferError {
77 #[error(transparent)]
78 IntConversion(#[from] TryFromIntError),
79 #[error(transparent)]
80 Io(#[from] std::io::Error),
81 #[error(transparent)]
82 FDomain(#[from] fdomain_client::Error),
83 #[error("remote hung up before terminating transfer")]
84 Hangup,
85}
86
87impl Transfer {
88 pub async fn send(self) -> Result<Report, TransferError> {
89 let Self { socket, params: TransferParams { data_len, buffer_len } } = self;
90 let mut socket = flex_client::socket_to_async(socket);
91 let buffer_len = usize::try_from(buffer_len.get())?;
92 let mut data_len = usize::try_from(data_len.get())?;
93 let buffer = vec![0xAA; buffer_len];
94 let start = Instant::now();
95 while data_len != 0 {
96 let send = buffer_len.min(data_len);
97 let written = socket.write(&buffer[..send]).await?;
98 data_len -= written;
99 }
100 let end = Instant::now();
101 Ok(Report { duration: end - start })
102 }
103
104 pub async fn receive(self) -> Result<Report, TransferError> {
105 let Self { socket, params: TransferParams { data_len, buffer_len } } = self;
106 let mut socket = flex_client::socket_to_async(socket);
107 let buffer_len = usize::try_from(buffer_len.get())?;
108 let mut data_len = usize::try_from(data_len.get())?;
109 let mut buffer = vec![0x00; buffer_len];
110 let start = Instant::now();
111 while data_len != 0 {
112 let recv = buffer_len.min(data_len);
113 let recv = AsyncReadExt::read(&mut socket, &mut buffer[..recv]).await?;
114 if recv == 0 {
115 return Err(TransferError::Hangup);
116 }
117 data_len -= recv;
118 }
119 let end = Instant::now();
120 Ok(Report { duration: end - start })
121 }
122}
123
#[cfg(test)]
mod test {
    use super::*;

    use assert_matches::assert_matches;

    /// Receiving on a socket whose peer is already gone must fail with
    /// `TransferError::Hangup` rather than succeed or block.
    #[fuchsia::test]
    async fn receive_hangup() {
        #[cfg(feature = "fdomain")]
        let client = fdomain_local::local_client(|| Err(zx_status::Status::NOT_SUPPORTED));
        #[cfg(not(feature = "fdomain"))]
        let client = fidl::endpoints::ZirconClient;

        // The peer end is matched with `_` (not bound to a name), so it is
        // dropped right away and the socket under test starts with no remote.
        let (socket, _) = client.create_stream_socket();
        let params = TransferParams {
            data_len: NonZeroU32::new(10).unwrap(),
            buffer_len: NonZeroU32::new(100).unwrap(),
        };

        let outcome = Transfer { socket, params }.receive().await;
        assert_matches!(outcome, Err(TransferError::Hangup));
    }
}