speedtest/
socket.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::num::{NonZeroU32, TryFromIntError};
6use std::time::{Duration, Instant};
7use std::u64;
8
9use flex_fuchsia_developer_ffx_speedtest as fspeedtest;
10use futures::{AsyncReadExt, AsyncWriteExt as _};
11use thiserror::Error;
12
13pub struct Transfer {
14    pub socket: flex_client::Socket,
15    pub params: TransferParams,
16}
17
/// Sizing parameters for a transfer.
#[derive(Clone, Debug)]
pub struct TransferParams {
    /// Total number of bytes to move over the socket.
    pub data_len: NonZeroU32,
    /// Size in bytes of the scratch buffer, which is also the upper bound on
    /// the size of each individual read or write.
    pub buffer_len: NonZeroU32,
}
23
24impl TryFrom<fspeedtest::TransferParams> for TransferParams {
25    type Error = TryFromIntError;
26    fn try_from(value: fspeedtest::TransferParams) -> Result<Self, Self::Error> {
27        let fspeedtest::TransferParams { len_bytes, buffer_bytes, __source_breaking } = value;
28        Ok(Self {
29            data_len: len_bytes.unwrap_or(fspeedtest::DEFAULT_TRANSFER_SIZE).try_into()?,
30            buffer_len: buffer_bytes.unwrap_or(fspeedtest::DEFAULT_BUFFER_SIZE).try_into()?,
31        })
32    }
33}
34
35impl TryFrom<TransferParams> for fspeedtest::TransferParams {
36    type Error = TryFromIntError;
37    fn try_from(value: TransferParams) -> Result<Self, Self::Error> {
38        let TransferParams { data_len, buffer_len } = value;
39        Ok(Self {
40            len_bytes: Some(data_len.try_into()?),
41            buffer_bytes: Some(buffer_len.try_into()?),
42            __source_breaking: fidl::marker::SourceBreaking,
43        })
44    }
45}
46
/// Outcome of a completed transfer.
#[derive(Debug)]
pub struct Report {
    /// Time that elapsed between the start and the end of the transfer loop.
    pub duration: Duration,
}
51
52impl From<Report> for fspeedtest::TransferReport {
53    fn from(value: Report) -> Self {
54        let Report { duration } = value;
55        Self {
56            duration_nsec: Some(duration.as_nanos().try_into().unwrap_or(u64::MAX)),
57            __source_breaking: fidl::marker::SourceBreaking,
58        }
59    }
60}
61
62#[derive(Error, Debug)]
63#[error("missing mandatory field")]
64pub struct MissingFieldError;
65
66impl TryFrom<fspeedtest::TransferReport> for Report {
67    type Error = MissingFieldError;
68
69    fn try_from(value: fspeedtest::TransferReport) -> Result<Self, Self::Error> {
70        let fspeedtest::TransferReport { duration_nsec, __source_breaking } = value;
71        Ok(Self { duration: Duration::from_nanos(duration_nsec.ok_or(MissingFieldError)?) })
72    }
73}
74
75#[derive(Error, Debug)]
76pub enum TransferError {
77    #[error(transparent)]
78    IntConversion(#[from] TryFromIntError),
79    #[error(transparent)]
80    Io(#[from] std::io::Error),
81    #[error(transparent)]
82    FDomain(#[from] fdomain_client::Error),
83    #[error("remote hung up before terminating transfer")]
84    Hangup,
85}
86
87impl Transfer {
88    pub async fn send(self) -> Result<Report, TransferError> {
89        let Self { socket, params: TransferParams { data_len, buffer_len } } = self;
90        let mut socket = flex_client::socket_to_async(socket);
91        let buffer_len = usize::try_from(buffer_len.get())?;
92        let mut data_len = usize::try_from(data_len.get())?;
93        let buffer = vec![0xAA; buffer_len];
94        let start = Instant::now();
95        while data_len != 0 {
96            let send = buffer_len.min(data_len);
97            let written = socket.write(&buffer[..send]).await?;
98            data_len -= written;
99        }
100        let end = Instant::now();
101        Ok(Report { duration: end - start })
102    }
103
104    pub async fn receive(self) -> Result<Report, TransferError> {
105        let Self { socket, params: TransferParams { data_len, buffer_len } } = self;
106        let mut socket = flex_client::socket_to_async(socket);
107        let buffer_len = usize::try_from(buffer_len.get())?;
108        let mut data_len = usize::try_from(data_len.get())?;
109        let mut buffer = vec![0x00; buffer_len];
110        let start = Instant::now();
111        while data_len != 0 {
112            let recv = buffer_len.min(data_len);
113            let recv = AsyncReadExt::read(&mut socket, &mut buffer[..recv]).await?;
114            if recv == 0 {
115                return Err(TransferError::Hangup);
116            }
117            data_len -= recv;
118        }
119        let end = Instant::now();
120        Ok(Report { duration: end - start })
121    }
122}
123
#[cfg(test)]
mod test {
    use super::*;

    use assert_matches::assert_matches;

    /// Receiving on a socket whose peer endpoint has been discarded must
    /// fail with `TransferError::Hangup`.
    #[fuchsia::test]
    async fn receive_hangup() {
        #[cfg(feature = "fdomain")]
        let client = fdomain_local::local_client(|| Err(zx_status::Status::NOT_SUPPORTED));
        #[cfg(not(feature = "fdomain"))]
        let client = fidl::endpoints::ZirconClient;

        // The peer endpoint is matched with `_`, so it is not kept alive and
        // nothing is ever written to the socket under test.
        let (socket, _) = client.create_stream_socket();
        let params = TransferParams {
            data_len: NonZeroU32::new(10).unwrap(),
            buffer_len: NonZeroU32::new(100).unwrap(),
        };

        let result = Transfer { socket, params }.receive().await;

        assert_matches!(result, Err(TransferError::Hangup));
    }
}