fdomain_client/
socket.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{Error, Handle, ordinals};
8use fidl_fuchsia_fdomain as proto;
9use futures::FutureExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll, ready};
14
/// A socket in a remote FDomain.
///
/// This is a newtype over the crate's [`Handle`]; the wrapped handle is visible to the rest of
/// the crate but not to external users.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Socket(pub(crate) Handle);

// NOTE(review): `handle_type!` comes from `crate::handle`; it presumably generates the shared
// handle boilerplate for `Socket`, with `peered` selecting peer-related helpers — confirm against
// the macro definition.
handle_type!(Socket SOCKET peered);
20
/// Disposition of a socket.
///
/// Each variant is translated to the `proto::SocketDisposition` variant of the same name by
/// `SocketDisposition::proto`. There is no "no change" variant here; callers of
/// `Socket::set_socket_disposition` express that by passing `None`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SocketDisposition {
    /// Maps to `proto::SocketDisposition::WriteEnabled`.
    /// NOTE(review): presumably means writes are permitted — confirm against the FDomain protocol.
    WriteEnabled,
    /// Maps to `proto::SocketDisposition::WriteDisabled`.
    /// NOTE(review): presumably means writes are rejected — confirm against the FDomain protocol.
    WriteDisabled,
}
27
28impl SocketDisposition {
29    /// Convert to a proto::SocketDisposition
30    fn proto(self) -> proto::SocketDisposition {
31        match self {
32            SocketDisposition::WriteEnabled => proto::SocketDisposition::WriteEnabled,
33            SocketDisposition::WriteDisabled => proto::SocketDisposition::WriteDisabled,
34        }
35    }
36}
37
38impl Socket {
39    /// Read up to the given buffer's length from the socket.
40    pub fn read<'a>(&self, buf: &'a mut [u8]) -> impl Future<Output = Result<usize, Error>> + 'a {
41        let client = Arc::downgrade(&self.0.client());
42        let handle = self.0.proto();
43
44        futures::future::poll_fn(move |ctx| {
45            client
46                .upgrade()
47                .unwrap_or_else(|| Arc::clone(&crate::DEAD_CLIENT))
48                .poll_socket(handle, ctx, buf)
49        })
50    }
51
52    /// Polls on reading this socket. Not to be confused with `AsyncRead::poll_read` which has a
53    /// different method of error reporting. That will handle errors, whereas this will return them
54    /// directly.
55    pub fn poll_socket(&self, ctx: &mut Context<'_>, out: &mut [u8]) -> Poll<Result<usize, Error>> {
56        let client = self.0.client();
57        client.poll_socket(self.0.proto(), ctx, out)
58    }
59
60    /// Write all of the given data to the socket.
61    pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
62        let data = bytes.to_vec();
63        let len = bytes.len();
64        let hid = self.0.proto();
65
66        let client = self.0.client();
67        client
68            .transaction(
69                ordinals::WRITE_SOCKET,
70                proto::SocketWriteSocketRequest { handle: hid, data },
71                move |x| Responder::WriteSocket(x),
72            )
73            .map(move |x| x.map(|y| assert!(y.wrote as usize == len)))
74    }
75
76    /// Set the disposition of this socket and/or its peer.
77    pub fn set_socket_disposition(
78        &self,
79        disposition: Option<SocketDisposition>,
80        disposition_peer: Option<SocketDisposition>,
81    ) -> impl Future<Output = Result<(), Error>> {
82        let disposition =
83            disposition.map(SocketDisposition::proto).unwrap_or(proto::SocketDisposition::NoChange);
84        let disposition_peer = disposition_peer
85            .map(SocketDisposition::proto)
86            .unwrap_or(proto::SocketDisposition::NoChange);
87        let client = self.0.client();
88        let handle = self.0.proto();
89        client.transaction(
90            ordinals::SET_SOCKET_DISPOSITION,
91            proto::SocketSetSocketDispositionRequest { handle, disposition, disposition_peer },
92            Responder::SetSocketDisposition,
93        )
94    }
95
96    /// Split this socket into a streaming reader and a writer. This is more
97    /// efficient on the read side if you intend to consume all of the data from
98    /// the socket. However it will prevent you from transferring the handle in
99    /// the future. It also means data will build up in the buffer, so it may
100    /// lead to memory issues if you don't intend to use the data from the
101    /// socket as fast as it comes.
102    pub fn stream(self) -> Result<(SocketReadStream, SocketWriter), Error> {
103        self.0.client().start_socket_streaming(self.0.proto())?;
104
105        let a = Arc::new(self);
106        let b = Arc::clone(&a);
107
108        Ok((SocketReadStream(a), SocketWriter(b)))
109    }
110}
111
112/// A write-only handle to a socket.
113pub struct SocketWriter(Arc<Socket>);
114
115impl SocketWriter {
116    /// Write all of the given data to the socket.
117    pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
118        self.0.write_all(bytes)
119    }
120}
121
122/// A stream of data issuing from a socket.
123pub struct SocketReadStream(Arc<Socket>);
124
125impl SocketReadStream {
126    /// Read from the socket into the supplied buffer. Returns the number of bytes read.
127    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
128        self.0.read(buf).await
129    }
130}
131
132impl Drop for SocketReadStream {
133    fn drop(&mut self) {
134        if let Some(client) = self.0.0.client.upgrade() {
135            client.stop_socket_streaming(self.0.0.proto());
136        }
137    }
138}
139
140/// Wrapper for [`Client::poll_socket`] that adapts the return value semantics
141/// to what Unix prescribes, and what `futures::io` thus prescribes.
142fn convert_poll_res_to_async_read(
143    poll_res: Poll<Result<usize, Error>>,
144) -> Poll<std::io::Result<usize>> {
145    let res = ready!(poll_res).or_else(|e| match e {
146        Error::FDomain(proto::Error::TargetError(e))
147            if e == zx_status::Status::PEER_CLOSED.into_raw() =>
148        {
149            Ok(0)
150        }
151        other => Err(std::io::Error::other(other)),
152    });
153    Poll::Ready(res)
154}
155
156impl futures::AsyncRead for Socket {
157    fn poll_read(
158        self: Pin<&mut Self>,
159        cx: &mut Context<'_>,
160        buf: &mut [u8],
161    ) -> Poll<std::io::Result<usize>> {
162        convert_poll_res_to_async_read(self.poll_socket(cx, buf))
163    }
164}
165
166impl futures::AsyncRead for &Socket {
167    fn poll_read(
168        self: Pin<&mut Self>,
169        cx: &mut Context<'_>,
170        buf: &mut [u8],
171    ) -> Poll<std::io::Result<usize>> {
172        convert_poll_res_to_async_read(self.poll_socket(cx, buf))
173    }
174}
175
176impl futures::AsyncWrite for Socket {
177    fn poll_write(
178        self: Pin<&mut Self>,
179        _cx: &mut Context<'_>,
180        buf: &[u8],
181    ) -> Poll<std::io::Result<usize>> {
182        let _ = self.write_all(buf);
183        Poll::Ready(Ok(buf.len()))
184    }
185
186    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
187        Poll::Ready(Ok(()))
188    }
189
190    fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
191        self.0 = Handle::invalid();
192        Poll::Ready(Ok(()))
193    }
194}