1use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{Error, Handle, ordinals};
8use fidl_fuchsia_fdomain as proto;
9use futures::FutureExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll, ready};
14
/// A socket handle used through an FDomain connection (`fidl_fuchsia_fdomain`).
///
/// This is a thin wrapper around a crate [`Handle`]. Reads, writes and disposition
/// changes are issued through the client that the handle belongs to.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Socket(pub(crate) Handle);

// NOTE(review): `handle_type!` comes from `crate::handle`. Presumably it generates the
// handle-type plumbing for `Socket` with object type SOCKET and marks it as a peered
// handle — confirm against the macro definition.
handle_type!(Socket SOCKET peered);
20
21#[derive(Copy, Clone, Debug, PartialEq, Eq)]
23pub enum SocketDisposition {
24 WriteEnabled,
25 WriteDisabled,
26}
27
28impl SocketDisposition {
29 fn proto(self) -> proto::SocketDisposition {
31 match self {
32 SocketDisposition::WriteEnabled => proto::SocketDisposition::WriteEnabled,
33 SocketDisposition::WriteDisabled => proto::SocketDisposition::WriteDisabled,
34 }
35 }
36}
37
38impl Socket {
39 pub fn fdomain_read<'a>(
41 &self,
42 buf: &'a mut [u8],
43 ) -> impl Future<Output = Result<usize, Error>> + 'a {
44 let client = Arc::downgrade(&self.0.client());
45 let handle = self.0.proto();
46
47 futures::future::poll_fn(move |ctx| {
48 client
49 .upgrade()
50 .unwrap_or_else(|| Arc::clone(&crate::DEAD_CLIENT))
51 .poll_socket(handle, ctx, buf)
52 })
53 }
54
55 pub fn poll_socket(&self, ctx: &mut Context<'_>, out: &mut [u8]) -> Poll<Result<usize, Error>> {
59 let client = self.0.client();
60 client.poll_socket(self.0.proto(), ctx, out)
61 }
62
63 pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> + use<> {
65 let data = bytes.to_vec();
66 let len = bytes.len();
67 let hid = self.0.proto();
68
69 let client = self.0.client();
70 client
71 .transaction(
72 ordinals::WRITE_SOCKET,
73 proto::SocketWriteSocketRequest { handle: hid, data },
74 move |x| Responder::WriteSocket(x),
75 )
76 .map(move |x| x.map(|y| assert!(y.wrote as usize == len)))
77 }
78
79 pub fn set_socket_disposition(
81 &self,
82 disposition: Option<SocketDisposition>,
83 disposition_peer: Option<SocketDisposition>,
84 ) -> impl Future<Output = Result<(), Error>> {
85 let disposition =
86 disposition.map(SocketDisposition::proto).unwrap_or(proto::SocketDisposition::NoChange);
87 let disposition_peer = disposition_peer
88 .map(SocketDisposition::proto)
89 .unwrap_or(proto::SocketDisposition::NoChange);
90 let client = self.0.client();
91 let handle = self.0.proto();
92 client.transaction(
93 ordinals::SET_SOCKET_DISPOSITION,
94 proto::SocketSetSocketDispositionRequest { handle, disposition, disposition_peer },
95 Responder::SetSocketDisposition,
96 )
97 }
98
99 pub fn stream(self) -> Result<(SocketReadStream, SocketWriter), Error> {
106 self.0.client().start_socket_streaming(self.0.proto())?;
107
108 let a = Arc::new(self);
109 let b = Arc::clone(&a);
110
111 Ok((SocketReadStream(a), SocketWriter(b)))
112 }
113}
114
115pub struct SocketWriter(Arc<Socket>);
117
118impl SocketWriter {
119 pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
121 self.0.write_all(bytes)
122 }
123}
124
125pub struct SocketReadStream(Arc<Socket>);
127
128impl SocketReadStream {
129 pub async fn fdomain_read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
131 self.0.fdomain_read(buf).await
132 }
133}
134
135impl futures::AsyncRead for SocketReadStream {
136 fn poll_read(
137 self: Pin<&mut Self>,
138 cx: &mut Context<'_>,
139 buf: &mut [u8],
140 ) -> Poll<std::io::Result<usize>> {
141 convert_poll_res_to_async_read(self.0.poll_socket(cx, buf))
142 }
143}
144
145impl futures::AsyncRead for &SocketReadStream {
146 fn poll_read(
147 self: Pin<&mut Self>,
148 cx: &mut Context<'_>,
149 buf: &mut [u8],
150 ) -> Poll<std::io::Result<usize>> {
151 convert_poll_res_to_async_read(self.0.poll_socket(cx, buf))
152 }
153}
154
155impl Drop for SocketReadStream {
156 fn drop(&mut self) {
157 if let Some(client) = self.0.0.client.upgrade() {
158 client.stop_socket_streaming(self.0.0.proto());
159 }
160 }
161}
162
163fn convert_poll_res_to_async_read(
166 poll_res: Poll<Result<usize, Error>>,
167) -> Poll<std::io::Result<usize>> {
168 let res = ready!(poll_res).or_else(|e| match e {
169 Error::FDomain(proto::Error::TargetError(e))
170 if e == zx_status::Status::PEER_CLOSED.into_raw() =>
171 {
172 Ok(0)
173 }
174 other => Err(std::io::Error::other(other)),
175 });
176 Poll::Ready(res)
177}
178
179impl futures::AsyncRead for Socket {
180 fn poll_read(
181 self: Pin<&mut Self>,
182 cx: &mut Context<'_>,
183 buf: &mut [u8],
184 ) -> Poll<std::io::Result<usize>> {
185 convert_poll_res_to_async_read(self.poll_socket(cx, buf))
186 }
187}
188
189impl futures::AsyncRead for &Socket {
190 fn poll_read(
191 self: Pin<&mut Self>,
192 cx: &mut Context<'_>,
193 buf: &mut [u8],
194 ) -> Poll<std::io::Result<usize>> {
195 convert_poll_res_to_async_read(self.poll_socket(cx, buf))
196 }
197}
198
199impl futures::AsyncWrite for Socket {
200 fn poll_write(
201 self: Pin<&mut Self>,
202 _cx: &mut Context<'_>,
203 buf: &[u8],
204 ) -> Poll<std::io::Result<usize>> {
205 let _ = self.write_all(buf);
206 Poll::Ready(Ok(buf.len()))
207 }
208
209 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
210 Poll::Ready(Ok(()))
211 }
212
213 fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
214 self.0 = Handle::invalid();
215 Poll::Ready(Ok(()))
216 }
217}