1use super::super::handle::ReadValue;
9use super::super::stream::{
10 Frame, StreamReader, StreamReaderBinder, StreamWriter, StreamWriterBinder,
11};
12use super::super::{
13 Proxy, ProxyTransferInitiationReceiver, Proxyable, ProxyableRW, RemoveFromProxyTable,
14 StreamRefSender,
15};
16use crate::labels::{NodeId, TransferKey};
17use crate::peer::{FramedStreamReader, FramedStreamWriter};
18use anyhow::{Context as _, Error, bail, format_err};
19use fuchsia_sync::Mutex;
20use futures::future::Either;
21use futures::prelude::*;
22use std::sync::Arc;
23use zx_status;
24
/// What the proxy loop should do once `stream_to_handle` has stopped reading frames.
///
/// A value of this type is sent through a `FinishProxyLoopSender` and acted upon at the
/// end of `handle_to_stream`. Every variant carries the `StreamReader` that
/// `stream_to_handle` was reading from, handing it over to the receiving side.
#[derive(Debug)]
enum FinishProxyLoopAction<Hdl: Proxyable> {
    /// Built from `RemoveFromProxyTable::InitiateTransfer`; `handle_to_stream` passes the
    /// fields on to `super::xfer::initiate`.
    InitiateTransfer {
        paired_handle: fidl::NullableHandle,
        drain_stream: FramedStreamWriter,
        stream_ref_sender: StreamRefSender,
        stream_reader: StreamReader<Hdl::Message>,
    },
    /// Built when a `Frame::BeginTransfer` arrives on the stream; `handle_to_stream` passes
    /// the fields on to `super::xfer::follow`.
    FollowTransfer {
        initiate_transfer: ProxyTransferInitiationReceiver,
        new_destination_node: NodeId,
        transfer_key: TransferKey,
        stream_reader: StreamReader<Hdl::Message>,
    },
    /// Built when a `Frame::Shutdown` arrives or writing to the handle fails;
    /// `handle_to_stream` passes the fields on to `join_shutdown`.
    Shutdown {
        /// Status to report when shutting down.
        result: Result<(), zx_status::Status>,
        stream_reader: StreamReader<Hdl::Message>,
    },
}
47
/// Sending half of the one-shot channel that carries a `FinishProxyLoopAction`.
struct FinishProxyLoopSender<Hdl: Proxyable> {
    /// Underlying one-shot sender; sending an action consumes it.
    chan: futures::channel::oneshot::Sender<FinishProxyLoopAction<Hdl>>,
}
/// Receiving half of the one-shot channel that carries a `FinishProxyLoopAction`.
type FinishProxyLoopReceiver<Hdl> = futures::channel::oneshot::Receiver<FinishProxyLoopAction<Hdl>>;
52
53impl<Hdl: 'static + Proxyable> FinishProxyLoopSender<Hdl> {
54 fn and_then(self, then: FinishProxyLoopAction<Hdl>) -> Result<(), Error> {
55 Ok(self.chan.send(then).map_err(|_| format_err!("Join channel broken"))?)
56 }
57
58 fn and_then_initiate(
60 self,
61 paired_handle: fidl::NullableHandle,
62 drain_stream: FramedStreamWriter,
63 stream_ref_sender: StreamRefSender,
64 stream_reader: StreamReader<Hdl::Message>,
65 ) -> Result<(), Error> {
66 self.and_then(FinishProxyLoopAction::InitiateTransfer {
67 paired_handle,
68 drain_stream,
69 stream_ref_sender,
70 stream_reader,
71 })
72 }
73
74 fn and_then_follow(
76 self,
77 initiate_transfer: ProxyTransferInitiationReceiver,
78 new_destination_node: NodeId,
79 transfer_key: TransferKey,
80 stream_reader: StreamReader<Hdl::Message>,
81 ) -> Result<(), Error> {
82 self.and_then(FinishProxyLoopAction::FollowTransfer {
83 initiate_transfer,
84 new_destination_node,
85 transfer_key,
86 stream_reader,
87 })
88 }
89
90 fn and_then_shutdown(
91 self,
92 result: Result<(), zx_status::Status>,
93 stream_reader: StreamReader<Hdl::Message>,
94 ) -> Result<(), Error> {
95 self.and_then(FinishProxyLoopAction::Shutdown { result, stream_reader })
96 }
97}
98
99fn new_task_joiner<Hdl: Proxyable>() -> (FinishProxyLoopSender<Hdl>, FinishProxyLoopReceiver<Hdl>) {
100 let (tx, rx) = futures::channel::oneshot::channel();
101 (FinishProxyLoopSender { chan: tx }, rx)
102}
103
/// Optional global callback installed via `set_proxy_drop_event_handler`.
///
/// When set, `run_main_loop` calls it with a reference to the combined result of its two
/// loop halves, while the mutex is held.
static PROXY_DROP_EVENT: Mutex<Option<Box<dyn Fn(&Result<(), Error>) + 'static + Send>>> =
    Mutex::new(None);
107
108pub fn set_proxy_drop_event_handler(handler: impl Fn(&Result<(), Error>) + 'static + Send) {
111 *PROXY_DROP_EVENT.lock() = Some(Box::new(handler));
112}
113
114pub(crate) async fn run_main_loop<Hdl: 'static + for<'a> ProxyableRW<'a>>(
116 proxy: Arc<Proxy<Hdl>>,
117 initiate_transfer: ProxyTransferInitiationReceiver,
118 stream_writer: FramedStreamWriter,
119 initial_stream_reader: Option<FramedStreamReader>,
120 stream_reader: FramedStreamReader,
121) -> Result<(), Error> {
122 assert!(Arc::strong_count(&proxy) == 1);
123 let (tx_join, rx_join) = new_task_joiner();
124 let hdl = proxy.hdl();
125 let mut stream_writer = stream_writer.bind(hdl);
126 let initial_stream_reader = initial_stream_reader.map(|s| s.bind(hdl));
127 let mut stream_reader = stream_reader.bind(hdl);
128 let res = futures::future::try_join(
129 async {
130 if !stream_reader.is_initiator() {
131 stream_reader.expect_hello().await?;
132 } else {
133 stream_writer.send_hello().await?;
134 }
135 Ok::<(), Error>(())
136 },
137 async {
138 if let Some(initial_stream_reader) = initial_stream_reader {
139 drain(proxy.clone(), initial_stream_reader).await?;
140 }
141 Ok(())
142 },
143 )
144 .await;
145
146 if let Err(e) = res {
147 Arc::try_unwrap(proxy).unwrap().close_with_reason(format!("{e:?}"));
148 return Err(e);
149 }
150
151 let mut my_proxy = Some(Arc::clone(&proxy));
152
153 let take_proxy = || {
154 my_proxy = None;
155 };
156
157 let res = futures::future::try_join(
158 stream_to_handle(proxy.clone(), initiate_transfer, stream_reader, tx_join)
159 .map_err(|e| e.context("stream_to_handle")),
160 handle_to_stream(proxy, stream_writer, rx_join, take_proxy)
161 .map_err(|e| e.context("handle_to_stream")),
162 )
163 .map_ok(drop)
164 .await;
165
166 if let Some(cb) = &*PROXY_DROP_EVENT.lock() {
167 cb(&res)
168 }
169 if let Err(e) = res {
170 if let Some(proxy) = my_proxy {
171 Arc::try_unwrap(proxy).unwrap().close_with_reason(format!("{e:?}"));
172 }
173 Err(e)
174 } else {
175 Ok(())
176 }
177}
178
/// Forwards everything read from the proxy's handle onto `stream` until told to finish,
/// then carries out the requested finish action.
///
/// * `proxy` - shared proxy; must be uniquely owned again once the loop ends.
/// * `stream` - writer that receives data, signal updates and shutdown frames.
/// * `finish_proxy_loop` - one-shot receiver for the `FinishProxyLoopAction`.
/// * `take_proxy` - invoked once, right before the `Arc` is unwrapped; `run_main_loop`
///   passes a closure that clears its own extra reference.
///
/// # Errors
///
/// Fails if writing to `stream` fails, if reading from the handle fails with anything other
/// than `PEER_CLOSED`, if the proxy is still shared after the loop ("Proxy should be
/// isolated"), or if the chosen finish action fails.
async fn handle_to_stream<Hdl: 'static + for<'a> ProxyableRW<'a>>(
    proxy: Arc<Proxy<Hdl>>,
    mut stream: StreamWriter<Hdl::Message>,
    mut finish_proxy_loop: FinishProxyLoopReceiver<Hdl>,
    take_proxy: impl FnOnce(),
) -> Result<(), Error> {
    // Reused buffer that `read_from_handle` fills and `send_data` consumes.
    let mut message = Default::default();
    let finish_proxy_loop_action = loop {
        // Wait for whichever comes first: a finish action or something read from the handle.
        let sr =
            futures::future::select(&mut finish_proxy_loop, proxy.read_from_handle(&mut message))
                .await;
        match sr {
            Either::Left((finish_proxy_loop_action, _)) => {
                break finish_proxy_loop_action;
            }
            Either::Right((Err(zx_status::Status::PEER_CLOSED), _)) => {
                // Prefer a finish action that is already available over a plain shutdown.
                if let Some(finish_proxy_loop_action) = finish_proxy_loop.now_or_never() {
                    break finish_proxy_loop_action;
                }
                // Peer closure is reported as a clean (`Ok`) shutdown and is not an error.
                stream.send_shutdown(Ok(())).await.context("send_shutdown")?;
                return Ok(());
            }
            Either::Right((Err(x), _)) => {
                // Any other read failure is forwarded to the stream and returned as an error.
                stream.send_shutdown(Err(x)).await.context("send_shutdown")?;
                return Err(x).context("read_from_handle");
            }
            Either::Right((Ok(ReadValue::Message), _)) => {
                // NOTE(review): presumably drops the select result so its borrow of `message`
                // ends before `message` is handed to `send_data` — confirm.
                drop(sr);
                stream.send_data(&mut message).await.context("send_data")?;
            }
            Either::Right((Ok(ReadValue::SignalUpdate(signal_update)), _)) => {
                stream.send_signal(signal_update).await.context("send_signal")?;
            }
        };
    };
    // Let the caller release its reference, then reclaim sole ownership of the proxy.
    take_proxy();
    let proxy = Arc::try_unwrap(proxy).map_err(|_| format_err!("Proxy should be isolated"))?;
    match finish_proxy_loop_action {
        Ok(FinishProxyLoopAction::InitiateTransfer {
            paired_handle,
            drain_stream,
            stream_ref_sender,
            stream_reader,
        }) => {
            super::xfer::initiate(
                proxy,
                paired_handle,
                stream,
                stream_reader,
                drain_stream,
                stream_ref_sender,
            )
            .await
        }
        Ok(FinishProxyLoopAction::FollowTransfer {
            initiate_transfer,
            new_destination_node,
            transfer_key,
            stream_reader,
        }) => {
            super::xfer::follow(
                proxy,
                initiate_transfer,
                stream,
                new_destination_node,
                transfer_key,
                stream_reader,
            )
            .await
        }
        Ok(FinishProxyLoopAction::Shutdown { result, stream_reader }) => {
            join_shutdown(proxy, stream, stream_reader, result).await
        }
        // The sender being dropped without sending an action is treated as impossible here.
        Err(futures::channel::oneshot::Canceled) => unreachable!(),
    }
}
256
257async fn join_shutdown<Hdl: 'static + Proxyable>(
258 proxy: Proxy<Hdl>,
259 stream_writer: StreamWriter<Hdl::Message>,
260 stream_reader: StreamReader<Hdl::Message>,
261 result: Result<(), zx_status::Status>,
262) -> Result<(), Error> {
263 stream_writer.send_shutdown(result).await?;
264 let _ = stream_reader.expect_shutdown(Ok(())).await;
265 proxy.close_with_reason(format!("Proxy shut down (result: {result:?})"));
266 Ok(())
267}
268
269async fn drain<Hdl: 'static + for<'a> ProxyableRW<'a>>(
270 proxy: Arc<Proxy<Hdl>>,
271 mut drain_stream: StreamReader<Hdl::Message>,
272) -> Result<(), Error> {
273 loop {
274 let frame = drain_stream.next().await?;
275 match frame {
276 Frame::Data(message) => proxy.write_to_handle(message).await?,
277 Frame::SignalUpdate(signal_update) => proxy.apply_signal_update(signal_update)?,
278 Frame::EndTransfer => break,
279 Frame::BeginTransfer(_, _) => bail!("BeginTransfer on drain stream"),
280 Frame::AckTransfer => bail!("AckTransfer on drain stream"),
281 Frame::Hello => bail!("Hello frame disallowed on drain streams"),
282 Frame::Shutdown(r) => bail!("Stream shutdown during drain: {:?}", r),
283 }
284 }
285 Ok(())
286}
287
288async fn stream_to_handle<Hdl: 'static + for<'a> ProxyableRW<'a>>(
289 proxy: Arc<Proxy<Hdl>>,
290 mut initiate_transfer: ProxyTransferInitiationReceiver,
291 mut stream: StreamReader<Hdl::Message>,
292 finish_proxy_loop: FinishProxyLoopSender<Hdl>,
293) -> Result<(), Error> {
294 let removed_from_proxy_table = loop {
295 let frame = match futures::future::select(&mut initiate_transfer, stream.next()).await {
296 Either::Left((removed_from_proxy_table, _)) => {
297 break removed_from_proxy_table;
300 }
301 Either::Right((frame, _)) => frame.context("stream.next()")?,
302 };
303 match frame {
304 Frame::Data(message) => {
305 if let Err(e) = proxy.write_to_handle(message).await {
306 let _ = finish_proxy_loop.and_then_shutdown(Err(e), stream);
307 match e {
308 zx_status::Status::PEER_CLOSED => {
309 return Ok(());
310 }
311 _ => return Err(e).context("write_to_handle"),
312 }
313 }
314 }
315 Frame::SignalUpdate(signal_update) => proxy.apply_signal_update(signal_update)?,
316 Frame::BeginTransfer(new_destination_node, transfer_key) => {
317 return finish_proxy_loop
318 .and_then_follow(initiate_transfer, new_destination_node, transfer_key, stream)
319 .context("finish_proxy_loop");
320 }
321 Frame::EndTransfer => bail!("Received EndTransfer on a regular stream"),
322 Frame::AckTransfer => bail!("Received AckTransfer before sending a BeginTransfer"),
323 Frame::Hello => bail!("Hello frame received after stream established"),
324 Frame::Shutdown(result) => {
325 let _ = finish_proxy_loop.and_then_shutdown(result, stream);
326 return result.context("Remote shutdown");
327 }
328 }
329 };
330 match removed_from_proxy_table {
331 Err(e) => Err(e.into()),
332 Ok(RemoveFromProxyTable::Dropped) => unreachable!(),
333 Ok(RemoveFromProxyTable::InitiateTransfer {
334 paired_handle,
335 drain_stream,
336 stream_ref_sender,
337 }) => Ok(finish_proxy_loop.and_then_initiate(
338 paired_handle,
339 drain_stream,
340 stream_ref_sender,
341 stream,
342 )?),
343 }
344}