1use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{Error, Handle, ordinals};
8use fidl_fuchsia_fdomain as proto;
9use futures::FutureExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll, ready};
14
/// A socket, represented as a newtype over the [`Handle`] that refers to it.
///
/// The socket operations in this file are carried out by the client that the wrapped handle
/// reports via `Handle::client()`.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Socket(pub(crate) Handle);

// NOTE(review): `handle_type!` is defined in `crate::handle` and is not visible here; presumably
// it generates the common handle boilerplate for `Socket` as a peered `SOCKET` object — confirm.
handle_type!(Socket SOCKET peered);
20
/// A disposition that can be requested for a socket through
/// `Socket::set_socket_disposition`.
///
/// Each variant is sent as the `proto::SocketDisposition` variant of the same name; callers
/// express "leave unchanged" by passing `None` instead of a variant.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SocketDisposition {
    /// Sent on the wire as `proto::SocketDisposition::WriteEnabled`.
    WriteEnabled,
    /// Sent on the wire as `proto::SocketDisposition::WriteDisabled`.
    WriteDisabled,
}
27
28impl SocketDisposition {
29 fn proto(self) -> proto::SocketDisposition {
31 match self {
32 SocketDisposition::WriteEnabled => proto::SocketDisposition::WriteEnabled,
33 SocketDisposition::WriteDisabled => proto::SocketDisposition::WriteDisabled,
34 }
35 }
36}
37
38impl Socket {
39 pub fn read<'a>(&self, buf: &'a mut [u8]) -> impl Future<Output = Result<usize, Error>> + 'a {
41 let client = Arc::downgrade(&self.0.client());
42 let handle = self.0.proto();
43
44 futures::future::poll_fn(move |ctx| {
45 client
46 .upgrade()
47 .unwrap_or_else(|| Arc::clone(&crate::DEAD_CLIENT))
48 .poll_socket(handle, ctx, buf)
49 })
50 }
51
52 pub fn poll_socket(&self, ctx: &mut Context<'_>, out: &mut [u8]) -> Poll<Result<usize, Error>> {
56 let client = self.0.client();
57 client.poll_socket(self.0.proto(), ctx, out)
58 }
59
60 pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
62 let data = bytes.to_vec();
63 let len = bytes.len();
64 let hid = self.0.proto();
65
66 let client = self.0.client();
67 client
68 .transaction(
69 ordinals::WRITE_SOCKET,
70 proto::SocketWriteSocketRequest { handle: hid, data },
71 move |x| Responder::WriteSocket(x),
72 )
73 .map(move |x| x.map(|y| assert!(y.wrote as usize == len)))
74 }
75
76 pub fn set_socket_disposition(
78 &self,
79 disposition: Option<SocketDisposition>,
80 disposition_peer: Option<SocketDisposition>,
81 ) -> impl Future<Output = Result<(), Error>> {
82 let disposition =
83 disposition.map(SocketDisposition::proto).unwrap_or(proto::SocketDisposition::NoChange);
84 let disposition_peer = disposition_peer
85 .map(SocketDisposition::proto)
86 .unwrap_or(proto::SocketDisposition::NoChange);
87 let client = self.0.client();
88 let handle = self.0.proto();
89 client.transaction(
90 ordinals::SET_SOCKET_DISPOSITION,
91 proto::SocketSetSocketDispositionRequest { handle, disposition, disposition_peer },
92 Responder::SetSocketDisposition,
93 )
94 }
95
96 pub fn stream(self) -> Result<(SocketReadStream, SocketWriter), Error> {
103 self.0.client().start_socket_streaming(self.0.proto())?;
104
105 let a = Arc::new(self);
106 let b = Arc::clone(&a);
107
108 Ok((SocketReadStream(a), SocketWriter(b)))
109 }
110}
111
/// The writing half produced by `Socket::stream`; holds a shared reference to the socket.
pub struct SocketWriter(Arc<Socket>);
114
115impl SocketWriter {
116 pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
118 self.0.write_all(bytes)
119 }
120}
121
/// The reading half produced by `Socket::stream`; holds a shared reference to the socket.
///
/// Dropping this value asks the client (if it still exists) to stop streaming the socket.
pub struct SocketReadStream(Arc<Socket>);
124
125impl SocketReadStream {
126 pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
128 self.0.read(buf).await
129 }
130}
131
132impl Drop for SocketReadStream {
133 fn drop(&mut self) {
134 if let Some(client) = self.0.0.client.upgrade() {
135 client.stop_socket_streaming(self.0.0.proto());
136 }
137 }
138}
139
140fn convert_poll_res_to_async_read(
143 poll_res: Poll<Result<usize, Error>>,
144) -> Poll<std::io::Result<usize>> {
145 let res = ready!(poll_res).or_else(|e| match e {
146 Error::FDomain(proto::Error::TargetError(e))
147 if e == zx_status::Status::PEER_CLOSED.into_raw() =>
148 {
149 Ok(0)
150 }
151 other => Err(std::io::Error::other(other)),
152 });
153 Poll::Ready(res)
154}
155
156impl futures::AsyncRead for Socket {
157 fn poll_read(
158 self: Pin<&mut Self>,
159 cx: &mut Context<'_>,
160 buf: &mut [u8],
161 ) -> Poll<std::io::Result<usize>> {
162 convert_poll_res_to_async_read(self.poll_socket(cx, buf))
163 }
164}
165
166impl futures::AsyncRead for &Socket {
167 fn poll_read(
168 self: Pin<&mut Self>,
169 cx: &mut Context<'_>,
170 buf: &mut [u8],
171 ) -> Poll<std::io::Result<usize>> {
172 convert_poll_res_to_async_read(self.poll_socket(cx, buf))
173 }
174}
175
176impl futures::AsyncWrite for Socket {
177 fn poll_write(
178 self: Pin<&mut Self>,
179 _cx: &mut Context<'_>,
180 buf: &[u8],
181 ) -> Poll<std::io::Result<usize>> {
182 let _ = self.write_all(buf);
183 Poll::Ready(Ok(buf.len()))
184 }
185
186 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
187 Poll::Ready(Ok(()))
188 }
189
190 fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
191 self.0 = Handle::invalid();
192 Poll::Ready(Ok(()))
193 }
194}