1use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{Error, Handle, ordinals};
8use fidl_fuchsia_fdomain as proto;
9use futures::FutureExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll, ready};
14
15#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
17pub struct Socket(pub(crate) Handle);
18
19handle_type!(Socket SOCKET peered);
20
/// Write disposition that can be requested for one end of a socket.
///
/// Passed (wrapped in `Option`) to [`Socket::set_socket_disposition`], where `None` means
/// "leave unchanged".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketDisposition {
    /// Request that writing be enabled.
    WriteEnabled,
    /// Request that writing be disabled.
    WriteDisabled,
}
27
28impl SocketDisposition {
29 fn proto(self) -> proto::SocketDisposition {
31 match self {
32 SocketDisposition::WriteEnabled => proto::SocketDisposition::WriteEnabled,
33 SocketDisposition::WriteDisabled => proto::SocketDisposition::WriteDisabled,
34 }
35 }
36}
37
38impl Socket {
39 pub fn read<'a>(&self, buf: &'a mut [u8]) -> impl Future<Output = Result<usize, Error>> + 'a {
41 let client = self.0.client();
42 let handle = self.0.proto();
43
44 futures::future::poll_fn(move |ctx| client.poll_socket(handle, ctx, buf))
45 }
46
47 pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
49 let data = bytes.to_vec();
50 let len = bytes.len();
51 let hid = self.0.proto();
52
53 let client = self.0.client();
54 client
55 .transaction(
56 ordinals::WRITE_SOCKET,
57 proto::SocketWriteSocketRequest { handle: hid, data },
58 move |x| Responder::WriteSocket(x),
59 )
60 .map(move |x| x.map(|y| assert!(y.wrote as usize == len)))
61 }
62
63 pub fn set_socket_disposition(
65 &self,
66 disposition: Option<SocketDisposition>,
67 disposition_peer: Option<SocketDisposition>,
68 ) -> impl Future<Output = Result<(), Error>> {
69 let disposition =
70 disposition.map(SocketDisposition::proto).unwrap_or(proto::SocketDisposition::NoChange);
71 let disposition_peer = disposition_peer
72 .map(SocketDisposition::proto)
73 .unwrap_or(proto::SocketDisposition::NoChange);
74 let client = self.0.client();
75 let handle = self.0.proto();
76 client.transaction(
77 ordinals::SET_SOCKET_DISPOSITION,
78 proto::SocketSetSocketDispositionRequest { handle, disposition, disposition_peer },
79 Responder::SetSocketDisposition,
80 )
81 }
82
83 pub fn stream(self) -> Result<(SocketReadStream, SocketWriter), Error> {
90 self.0.client().start_socket_streaming(self.0.proto())?;
91
92 let a = Arc::new(self);
93 let b = Arc::clone(&a);
94
95 Ok((SocketReadStream(a), SocketWriter(b)))
96 }
97}
98
99pub struct SocketWriter(Arc<Socket>);
101
102impl SocketWriter {
103 pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
105 self.0.write_all(bytes)
106 }
107}
108
109pub struct SocketReadStream(Arc<Socket>);
111
112impl SocketReadStream {
113 pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
115 self.0.read(buf).await
116 }
117}
118
119impl Drop for SocketReadStream {
120 fn drop(&mut self) {
121 if let Some(client) = self.0.0.client.upgrade() {
122 client.stop_socket_streaming(self.0.0.proto());
123 }
124 }
125}
126
127fn async_read_poll_socket(
130 client: Arc<crate::Client>,
131 proto: proto::HandleId,
132 cx: &mut Context<'_>,
133 buf: &mut [u8],
134) -> Poll<std::io::Result<usize>> {
135 let res = ready!(client.poll_socket(proto, cx, buf)).or_else(|e| match e {
136 Error::FDomain(proto::Error::TargetError(e))
137 if e == zx_status::Status::PEER_CLOSED.into_raw() =>
138 {
139 Ok(0)
140 }
141 other => Err(std::io::Error::other(other)),
142 });
143 Poll::Ready(res)
144}
145
146impl futures::AsyncRead for Socket {
147 fn poll_read(
148 self: Pin<&mut Self>,
149 cx: &mut Context<'_>,
150 buf: &mut [u8],
151 ) -> Poll<std::io::Result<usize>> {
152 let client = self.0.client();
153 async_read_poll_socket(client, self.0.proto(), cx, buf)
154 }
155}
156
157impl futures::AsyncRead for &Socket {
158 fn poll_read(
159 self: Pin<&mut Self>,
160 cx: &mut Context<'_>,
161 buf: &mut [u8],
162 ) -> Poll<std::io::Result<usize>> {
163 let client = self.0.client();
164 async_read_poll_socket(client, self.0.proto(), cx, buf)
165 }
166}
167
168impl futures::AsyncWrite for Socket {
169 fn poll_write(
170 self: Pin<&mut Self>,
171 _cx: &mut Context<'_>,
172 buf: &[u8],
173 ) -> Poll<std::io::Result<usize>> {
174 let _ = self.write_all(buf);
175 Poll::Ready(Ok(buf.len()))
176 }
177
178 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
179 Poll::Ready(Ok(()))
180 }
181
182 fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
183 self.0 = Handle::invalid();
184 Poll::Ready(Ok(()))
185 }
186}