fdomain_client/
socket.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{Error, Handle, ordinals};
8use fidl_fuchsia_fdomain as proto;
9use futures::FutureExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll, ready};
14
15/// A socket in a remote FDomain.
16#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
17pub struct Socket(pub(crate) Handle);
18
// Invokes the crate-local `handle_type!` macro for `Socket` with the `SOCKET`
// and `peered` arguments.
// NOTE(review): the macro body is not visible from this file; presumably it
// generates the shared handle plumbing for this type — confirm in
// `crate::handle`.
handle_type!(Socket SOCKET peered);
20
/// Write disposition that can be requested for a socket endpoint.
///
/// Callers pass `Option<SocketDisposition>` to
/// [`Socket::set_socket_disposition`], where `None` means "leave unchanged".
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SocketDisposition {
    /// Request that writing be enabled.
    WriteEnabled,
    /// Request that writing be disabled.
    WriteDisabled,
}
27
28impl SocketDisposition {
29    /// Convert to a proto::SocketDisposition
30    fn proto(self) -> proto::SocketDisposition {
31        match self {
32            SocketDisposition::WriteEnabled => proto::SocketDisposition::WriteEnabled,
33            SocketDisposition::WriteDisabled => proto::SocketDisposition::WriteDisabled,
34        }
35    }
36}
37
38impl Socket {
39    /// Read up to the given buffer's length from the socket.
40    pub fn fdomain_read<'a>(
41        &self,
42        buf: &'a mut [u8],
43    ) -> impl Future<Output = Result<usize, Error>> + 'a {
44        let client = Arc::downgrade(&self.0.client());
45        let handle = self.0.proto();
46
47        futures::future::poll_fn(move |ctx| {
48            client
49                .upgrade()
50                .unwrap_or_else(|| Arc::clone(&crate::DEAD_CLIENT))
51                .poll_socket(handle, ctx, buf)
52        })
53    }
54
55    /// Polls on reading this socket. Not to be confused with `AsyncRead::poll_read` which has a
56    /// different method of error reporting. That will handle errors, whereas this will return them
57    /// directly.
58    pub fn poll_socket(&self, ctx: &mut Context<'_>, out: &mut [u8]) -> Poll<Result<usize, Error>> {
59        let client = self.0.client();
60        client.poll_socket(self.0.proto(), ctx, out)
61    }
62
63    /// Write all of the given data to the socket.
64    pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> + use<> {
65        let data = bytes.to_vec();
66        let len = bytes.len();
67        let hid = self.0.proto();
68
69        let client = self.0.client();
70        client
71            .transaction(
72                ordinals::WRITE_SOCKET,
73                proto::SocketWriteSocketRequest { handle: hid, data },
74                move |x| Responder::WriteSocket(x),
75            )
76            .map(move |x| x.map(|y| assert!(y.wrote as usize == len)))
77    }
78
79    /// Set the disposition of this socket and/or its peer.
80    pub fn set_socket_disposition(
81        &self,
82        disposition: Option<SocketDisposition>,
83        disposition_peer: Option<SocketDisposition>,
84    ) -> impl Future<Output = Result<(), Error>> {
85        let disposition =
86            disposition.map(SocketDisposition::proto).unwrap_or(proto::SocketDisposition::NoChange);
87        let disposition_peer = disposition_peer
88            .map(SocketDisposition::proto)
89            .unwrap_or(proto::SocketDisposition::NoChange);
90        let client = self.0.client();
91        let handle = self.0.proto();
92        client.transaction(
93            ordinals::SET_SOCKET_DISPOSITION,
94            proto::SocketSetSocketDispositionRequest { handle, disposition, disposition_peer },
95            Responder::SetSocketDisposition,
96        )
97    }
98
99    /// Split this socket into a streaming reader and a writer. This is more
100    /// efficient on the read side if you intend to consume all of the data from
101    /// the socket. However it will prevent you from transferring the handle in
102    /// the future. It also means data will build up in the buffer, so it may
103    /// lead to memory issues if you don't intend to use the data from the
104    /// socket as fast as it comes.
105    pub fn stream(self) -> Result<(SocketReadStream, SocketWriter), Error> {
106        self.0.client().start_socket_streaming(self.0.proto())?;
107
108        let a = Arc::new(self);
109        let b = Arc::clone(&a);
110
111        Ok((SocketReadStream(a), SocketWriter(b)))
112    }
113}
114
/// A write-only handle to a socket.
///
/// Wraps a shared (`Arc`) reference to the underlying [`Socket`]. Produced by
/// [`Socket::stream`] together with a [`SocketReadStream`] that shares the
/// same socket.
pub struct SocketWriter(Arc<Socket>);
117
118impl SocketWriter {
119    /// Write all of the given data to the socket.
120    pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
121        self.0.write_all(bytes)
122    }
123}
124
/// A stream of data issuing from a socket.
///
/// Wraps a shared (`Arc`) reference to the underlying [`Socket`]. Produced by
/// [`Socket::stream`]; dropping this value asks the client to stop streaming
/// for the socket.
pub struct SocketReadStream(Arc<Socket>);
127
128impl SocketReadStream {
129    /// Read from the socket into the supplied buffer. Returns the number of bytes read.
130    pub async fn fdomain_read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
131        self.0.fdomain_read(buf).await
132    }
133}
134
135impl futures::AsyncRead for SocketReadStream {
136    fn poll_read(
137        self: Pin<&mut Self>,
138        cx: &mut Context<'_>,
139        buf: &mut [u8],
140    ) -> Poll<std::io::Result<usize>> {
141        convert_poll_res_to_async_read(self.0.poll_socket(cx, buf))
142    }
143}
144
145impl futures::AsyncRead for &SocketReadStream {
146    fn poll_read(
147        self: Pin<&mut Self>,
148        cx: &mut Context<'_>,
149        buf: &mut [u8],
150    ) -> Poll<std::io::Result<usize>> {
151        convert_poll_res_to_async_read(self.0.poll_socket(cx, buf))
152    }
153}
154
155impl Drop for SocketReadStream {
156    fn drop(&mut self) {
157        if let Some(client) = self.0.0.client.upgrade() {
158            client.stop_socket_streaming(self.0.0.proto());
159        }
160    }
161}
162
163/// Wrapper for [`Client::poll_socket`] that adapts the return value semantics
164/// to what Unix prescribes, and what `futures::io` thus prescribes.
165fn convert_poll_res_to_async_read(
166    poll_res: Poll<Result<usize, Error>>,
167) -> Poll<std::io::Result<usize>> {
168    let res = ready!(poll_res).or_else(|e| match e {
169        Error::FDomain(proto::Error::TargetError(e))
170            if e == zx_status::Status::PEER_CLOSED.into_raw() =>
171        {
172            Ok(0)
173        }
174        other => Err(std::io::Error::other(other)),
175    });
176    Poll::Ready(res)
177}
178
179impl futures::AsyncRead for Socket {
180    fn poll_read(
181        self: Pin<&mut Self>,
182        cx: &mut Context<'_>,
183        buf: &mut [u8],
184    ) -> Poll<std::io::Result<usize>> {
185        convert_poll_res_to_async_read(self.poll_socket(cx, buf))
186    }
187}
188
189impl futures::AsyncRead for &Socket {
190    fn poll_read(
191        self: Pin<&mut Self>,
192        cx: &mut Context<'_>,
193        buf: &mut [u8],
194    ) -> Poll<std::io::Result<usize>> {
195        convert_poll_res_to_async_read(self.poll_socket(cx, buf))
196    }
197}
198
199impl futures::AsyncWrite for Socket {
200    fn poll_write(
201        self: Pin<&mut Self>,
202        _cx: &mut Context<'_>,
203        buf: &[u8],
204    ) -> Poll<std::io::Result<usize>> {
205        let _ = self.write_all(buf);
206        Poll::Ready(Ok(buf.len()))
207    }
208
209    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
210        Poll::Ready(Ok(()))
211    }
212
213    fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
214        self.0 = Handle::invalid();
215        Poll::Ready(Ok(()))
216    }
217}