fdomain_client/
socket.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{Error, Handle, ordinals};
8use fidl_fuchsia_fdomain as proto;
9use futures::FutureExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll, ready};
14
15/// A socket in a remote FDomain.
16#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
17pub struct Socket(pub(crate) Handle);
18
// Generates the shared handle-wrapper plumbing for `Socket`, tagging it with the
// `SOCKET` object type and marking it `peered`. NOTE(review): the exact impls this
// expands to are defined by `crate::handle::handle_type`; confirm there.
handle_type!(Socket SOCKET peered);
20
/// Write disposition for one end of a socket, as passed to
/// `Socket::set_socket_disposition`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SocketDisposition {
    /// Writing is enabled (sent as `proto::SocketDisposition::WriteEnabled`).
    WriteEnabled,
    /// Writing is disabled (sent as `proto::SocketDisposition::WriteDisabled`).
    WriteDisabled,
}
27
28impl SocketDisposition {
29    /// Convert to a proto::SocketDisposition
30    fn proto(self) -> proto::SocketDisposition {
31        match self {
32            SocketDisposition::WriteEnabled => proto::SocketDisposition::WriteEnabled,
33            SocketDisposition::WriteDisabled => proto::SocketDisposition::WriteDisabled,
34        }
35    }
36}
37
38impl Socket {
39    /// Read up to the given buffer's length from the socket.
40    pub fn read<'a>(&self, buf: &'a mut [u8]) -> impl Future<Output = Result<usize, Error>> + 'a {
41        let client = self.0.client();
42        let handle = self.0.proto();
43
44        futures::future::poll_fn(move |ctx| client.poll_socket(handle, ctx, buf))
45    }
46
47    /// Write all of the given data to the socket.
48    pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
49        let data = bytes.to_vec();
50        let len = bytes.len();
51        let hid = self.0.proto();
52
53        let client = self.0.client();
54        client
55            .transaction(
56                ordinals::WRITE_SOCKET,
57                proto::SocketWriteSocketRequest { handle: hid, data },
58                move |x| Responder::WriteSocket(x),
59            )
60            .map(move |x| x.map(|y| assert!(y.wrote as usize == len)))
61    }
62
63    /// Set the disposition of this socket and/or its peer.
64    pub fn set_socket_disposition(
65        &self,
66        disposition: Option<SocketDisposition>,
67        disposition_peer: Option<SocketDisposition>,
68    ) -> impl Future<Output = Result<(), Error>> {
69        let disposition =
70            disposition.map(SocketDisposition::proto).unwrap_or(proto::SocketDisposition::NoChange);
71        let disposition_peer = disposition_peer
72            .map(SocketDisposition::proto)
73            .unwrap_or(proto::SocketDisposition::NoChange);
74        let client = self.0.client();
75        let handle = self.0.proto();
76        client.transaction(
77            ordinals::SET_SOCKET_DISPOSITION,
78            proto::SocketSetSocketDispositionRequest { handle, disposition, disposition_peer },
79            Responder::SetSocketDisposition,
80        )
81    }
82
83    /// Split this socket into a streaming reader and a writer. This is more
84    /// efficient on the read side if you intend to consume all of the data from
85    /// the socket. However it will prevent you from transferring the handle in
86    /// the future. It also means data will build up in the buffer, so it may
87    /// lead to memory issues if you don't intend to use the data from the
88    /// socket as fast as it comes.
89    pub fn stream(self) -> Result<(SocketReadStream, SocketWriter), Error> {
90        self.0.client().start_socket_streaming(self.0.proto())?;
91
92        let a = Arc::new(self);
93        let b = Arc::clone(&a);
94
95        Ok((SocketReadStream(a), SocketWriter(b)))
96    }
97}
98
99/// A write-only handle to a socket.
100pub struct SocketWriter(Arc<Socket>);
101
102impl SocketWriter {
103    /// Write all of the given data to the socket.
104    pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
105        self.0.write_all(bytes)
106    }
107}
108
109/// A stream of data issuing from a socket.
110pub struct SocketReadStream(Arc<Socket>);
111
112impl SocketReadStream {
113    /// Read from the socket into the supplied buffer. Returns the number of bytes read.
114    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
115        self.0.read(buf).await
116    }
117}
118
119impl Drop for SocketReadStream {
120    fn drop(&mut self) {
121        if let Some(client) = self.0.0.client.upgrade() {
122            client.stop_socket_streaming(self.0.0.proto());
123        }
124    }
125}
126
127/// Wrapper for [`Client::poll_socket`] that adapts the return value semantics
128/// to what Unix prescribes, and what `futures::io` thus prescribes.
129fn async_read_poll_socket(
130    client: Arc<crate::Client>,
131    proto: proto::HandleId,
132    cx: &mut Context<'_>,
133    buf: &mut [u8],
134) -> Poll<std::io::Result<usize>> {
135    let res = ready!(client.poll_socket(proto, cx, buf)).or_else(|e| match e {
136        Error::FDomain(proto::Error::TargetError(e))
137            if e == zx_status::Status::PEER_CLOSED.into_raw() =>
138        {
139            Ok(0)
140        }
141        other => Err(std::io::Error::other(other)),
142    });
143    Poll::Ready(res)
144}
145
146impl futures::AsyncRead for Socket {
147    fn poll_read(
148        self: Pin<&mut Self>,
149        cx: &mut Context<'_>,
150        buf: &mut [u8],
151    ) -> Poll<std::io::Result<usize>> {
152        let client = self.0.client();
153        async_read_poll_socket(client, self.0.proto(), cx, buf)
154    }
155}
156
157impl futures::AsyncRead for &Socket {
158    fn poll_read(
159        self: Pin<&mut Self>,
160        cx: &mut Context<'_>,
161        buf: &mut [u8],
162    ) -> Poll<std::io::Result<usize>> {
163        let client = self.0.client();
164        async_read_poll_socket(client, self.0.proto(), cx, buf)
165    }
166}
167
168impl futures::AsyncWrite for Socket {
169    fn poll_write(
170        self: Pin<&mut Self>,
171        _cx: &mut Context<'_>,
172        buf: &[u8],
173    ) -> Poll<std::io::Result<usize>> {
174        let _ = self.write_all(buf);
175        Poll::Ready(Ok(buf.len()))
176    }
177
178    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
179        Poll::Ready(Ok(()))
180    }
181
182    fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
183        self.0 = Handle::invalid();
184        Poll::Ready(Ok(()))
185    }
186}