overnet_core/proxy/run/
main.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Main loops (and associated spawn functions) for proxying: these move data from one end of a
//! proxied handle to the other, and call into `super::xfer` once a handle transfer is required.
7
8use super::super::handle::ReadValue;
9use super::super::stream::{
10    Frame, StreamReader, StreamReaderBinder, StreamWriter, StreamWriterBinder,
11};
12use super::super::{
13    Proxy, ProxyTransferInitiationReceiver, Proxyable, ProxyableRW, RemoveFromProxyTable,
14    StreamRefSender,
15};
16use crate::labels::{NodeId, TransferKey};
17use crate::peer::{FramedStreamReader, FramedStreamWriter};
18use anyhow::{bail, format_err, Context as _, Error};
19use futures::future::Either;
20use futures::prelude::*;
21use std::sync::{Arc, Mutex};
22use zx_status;
23
24#[cfg(not(target_os = "fuchsia"))]
25use fuchsia_async::emulated_handle::ChannelProxyProtocol;
26
27// We run two tasks to proxy a handle - one to handle handle->stream, the other to handle
28// stream->handle. When we want to perform a transfer operation we end up wanting to think about
29// just one task, so we provide a join operation here.
30#[derive(Debug)]
31enum FinishProxyLoopAction<Hdl: Proxyable> {
32    InitiateTransfer {
33        paired_handle: fidl::Handle,
34        drain_stream: FramedStreamWriter,
35        stream_ref_sender: StreamRefSender,
36        stream_reader: StreamReader<Hdl::Message>,
37    },
38    FollowTransfer {
39        initiate_transfer: ProxyTransferInitiationReceiver,
40        new_destination_node: NodeId,
41        transfer_key: TransferKey,
42        stream_reader: StreamReader<Hdl::Message>,
43    },
44    Shutdown {
45        result: Result<(), zx_status::Status>,
46        stream_reader: StreamReader<Hdl::Message>,
47    },
48}
49
50struct FinishProxyLoopSender<Hdl: Proxyable> {
51    chan: futures::channel::oneshot::Sender<FinishProxyLoopAction<Hdl>>,
52}
53type FinishProxyLoopReceiver<Hdl> = futures::channel::oneshot::Receiver<FinishProxyLoopAction<Hdl>>;
54
55impl<Hdl: 'static + Proxyable> FinishProxyLoopSender<Hdl> {
56    fn and_then(self, then: FinishProxyLoopAction<Hdl>) -> Result<(), Error> {
57        Ok(self.chan.send(then).map_err(|_| format_err!("Join channel broken"))?)
58    }
59
60    // This join is to initiate a new transfer.
61    fn and_then_initiate(
62        self,
63        paired_handle: fidl::Handle,
64        drain_stream: FramedStreamWriter,
65        stream_ref_sender: StreamRefSender,
66        stream_reader: StreamReader<Hdl::Message>,
67    ) -> Result<(), Error> {
68        self.and_then(FinishProxyLoopAction::InitiateTransfer {
69            paired_handle,
70            drain_stream,
71            stream_ref_sender,
72            stream_reader,
73        })
74    }
75
76    // This join is to follow a transfer initiated by the remote end.
77    fn and_then_follow(
78        self,
79        initiate_transfer: ProxyTransferInitiationReceiver,
80        new_destination_node: NodeId,
81        transfer_key: TransferKey,
82        stream_reader: StreamReader<Hdl::Message>,
83    ) -> Result<(), Error> {
84        self.and_then(FinishProxyLoopAction::FollowTransfer {
85            initiate_transfer,
86            new_destination_node,
87            transfer_key,
88            stream_reader,
89        })
90    }
91
92    fn and_then_shutdown(
93        self,
94        result: Result<(), zx_status::Status>,
95        stream_reader: StreamReader<Hdl::Message>,
96    ) -> Result<(), Error> {
97        self.and_then(FinishProxyLoopAction::Shutdown { result, stream_reader })
98    }
99}
100
101fn new_task_joiner<Hdl: Proxyable>() -> (FinishProxyLoopSender<Hdl>, FinishProxyLoopReceiver<Hdl>) {
102    let (tx, rx) = futures::channel::oneshot::channel();
103    (FinishProxyLoopSender { chan: tx }, rx)
104}
105
106/// Store behind [`set_proxy_drop_event_handler`]
107static PROXY_DROP_EVENT: Mutex<Option<Box<dyn Fn(&Result<(), Error>) + 'static + Send>>> =
108    Mutex::new(None);
109
110/// Sets a global callback to call every time a proxy is dropped. It's given a
111/// reference to the error and can be used to send metrics events.
112pub fn set_proxy_drop_event_handler(handler: impl Fn(&Result<(), Error>) + 'static + Send) {
113    *PROXY_DROP_EVENT.lock().unwrap() = Some(Box::new(handler));
114}
115
116// Spawn a proxy (two tasks, one for each direction of proxying).
117pub(crate) async fn run_main_loop<Hdl: 'static + for<'a> ProxyableRW<'a>>(
118    proxy: Arc<Proxy<Hdl>>,
119    initiate_transfer: ProxyTransferInitiationReceiver,
120    stream_writer: FramedStreamWriter,
121    initial_stream_reader: Option<FramedStreamReader>,
122    stream_reader: FramedStreamReader,
123) -> Result<(), Error> {
124    #[cfg(not(target_os = "fuchsia"))]
125    proxy.set_channel_proxy_protocol(ChannelProxyProtocol::Cso);
126
127    assert!(Arc::strong_count(&proxy) == 1);
128    let (tx_join, rx_join) = new_task_joiner();
129    let hdl = proxy.hdl();
130    let mut stream_writer = stream_writer.bind(hdl);
131    let initial_stream_reader = initial_stream_reader.map(|s| s.bind(hdl));
132    let mut stream_reader = stream_reader.bind(hdl);
133    let res = futures::future::try_join(
134        async {
135            if !stream_reader.is_initiator() {
136                stream_reader.expect_hello().await?;
137            } else {
138                stream_writer.send_hello().await?;
139            }
140            Ok::<(), Error>(())
141        },
142        async {
143            if let Some(initial_stream_reader) = initial_stream_reader {
144                drain(proxy.clone(), initial_stream_reader).await?;
145            }
146            Ok(())
147        },
148    )
149    .await;
150
151    if let Err(e) = res {
152        Arc::try_unwrap(proxy).unwrap().close_with_reason(format!("{e:?}"));
153        return Err(e);
154    }
155
156    let mut my_proxy = Some(Arc::clone(&proxy));
157
158    let take_proxy = || {
159        my_proxy = None;
160    };
161
162    let res = futures::future::try_join(
163        stream_to_handle(proxy.clone(), initiate_transfer, stream_reader, tx_join)
164            .map_err(|e| e.context("stream_to_handle")),
165        handle_to_stream(proxy, stream_writer, rx_join, take_proxy)
166            .map_err(|e| e.context("handle_to_stream")),
167    )
168    .map_ok(drop)
169    .await;
170
171    if let Some(cb) = &*PROXY_DROP_EVENT.lock().unwrap() {
172        cb(&res)
173    }
174    if let Err(e) = res {
175        if let Some(proxy) = my_proxy {
176            Arc::try_unwrap(proxy).unwrap().close_with_reason(format!("{e:?}"));
177        }
178        Err(e)
179    } else {
180        Ok(())
181    }
182}
183
184async fn handle_to_stream<Hdl: 'static + for<'a> ProxyableRW<'a>>(
185    proxy: Arc<Proxy<Hdl>>,
186    mut stream: StreamWriter<Hdl::Message>,
187    mut finish_proxy_loop: FinishProxyLoopReceiver<Hdl>,
188    take_proxy: impl FnOnce(),
189) -> Result<(), Error> {
190    let mut message = Default::default();
191    let finish_proxy_loop_action = loop {
192        let sr =
193            futures::future::select(&mut finish_proxy_loop, proxy.read_from_handle(&mut message))
194                .await;
195        match sr {
196            Either::Left((finish_proxy_loop_action, _)) => {
197                // Note: Proxy guarantees that read_from_handle can be dropped safely without losing data.
198                break finish_proxy_loop_action;
199            }
200            Either::Right((Err(zx_status::Status::PEER_CLOSED), _)) => {
201                if let Some(finish_proxy_loop_action) = finish_proxy_loop.now_or_never() {
202                    break finish_proxy_loop_action;
203                }
204                stream.send_shutdown(Ok(())).await.context("send_shutdown")?;
205                return Ok(());
206            }
207            Either::Right((Err(x), _)) => {
208                stream.send_shutdown(Err(x)).await.context("send_shutdown")?;
209                return Err(x).context("read_from_handle");
210            }
211            Either::Right((Ok(ReadValue::Message), _)) => {
212                drop(sr);
213                stream.send_data(&mut message).await.context("send_data")?;
214            }
215            Either::Right((Ok(ReadValue::SignalUpdate(signal_update)), _)) => {
216                stream.send_signal(signal_update).await.context("send_signal")?;
217            }
218        };
219    };
220    take_proxy();
221    let proxy = Arc::try_unwrap(proxy).map_err(|_| format_err!("Proxy should be isolated"))?;
222    match finish_proxy_loop_action {
223        Ok(FinishProxyLoopAction::InitiateTransfer {
224            paired_handle,
225            drain_stream,
226            stream_ref_sender,
227            stream_reader,
228        }) => {
229            super::xfer::initiate(
230                proxy,
231                paired_handle,
232                stream,
233                stream_reader,
234                drain_stream,
235                stream_ref_sender,
236            )
237            .await
238        }
239        Ok(FinishProxyLoopAction::FollowTransfer {
240            initiate_transfer,
241            new_destination_node,
242            transfer_key,
243            stream_reader,
244        }) => {
245            super::xfer::follow(
246                proxy,
247                initiate_transfer,
248                stream,
249                new_destination_node,
250                transfer_key,
251                stream_reader,
252            )
253            .await
254        }
255        Ok(FinishProxyLoopAction::Shutdown { result, stream_reader }) => {
256            join_shutdown(proxy, stream, stream_reader, result).await
257        }
258        Err(futures::channel::oneshot::Canceled) => unreachable!(),
259    }
260}
261
262async fn join_shutdown<Hdl: 'static + Proxyable>(
263    proxy: Proxy<Hdl>,
264    stream_writer: StreamWriter<Hdl::Message>,
265    stream_reader: StreamReader<Hdl::Message>,
266    result: Result<(), zx_status::Status>,
267) -> Result<(), Error> {
268    stream_writer.send_shutdown(result).await?;
269    let _ = stream_reader.expect_shutdown(Ok(())).await;
270    proxy.close_with_reason(format!("Proxy shut down (result: {result:?})"));
271    Ok(())
272}
273
274async fn drain<Hdl: 'static + for<'a> ProxyableRW<'a>>(
275    proxy: Arc<Proxy<Hdl>>,
276    mut drain_stream: StreamReader<Hdl::Message>,
277) -> Result<(), Error> {
278    loop {
279        let frame = drain_stream.next().await?;
280        match frame {
281            Frame::Data(message) => proxy.write_to_handle(message).await?,
282            Frame::SignalUpdate(signal_update) => proxy.apply_signal_update(signal_update)?,
283            Frame::EndTransfer => break,
284            Frame::BeginTransfer(_, _) => bail!("BeginTransfer on drain stream"),
285            Frame::AckTransfer => bail!("AckTransfer on drain stream"),
286            Frame::Hello => bail!("Hello frame disallowed on drain streams"),
287            Frame::Shutdown(r) => bail!("Stream shutdown during drain: {:?}", r),
288        }
289    }
290    Ok(())
291}
292
293async fn stream_to_handle<Hdl: 'static + for<'a> ProxyableRW<'a>>(
294    proxy: Arc<Proxy<Hdl>>,
295    mut initiate_transfer: ProxyTransferInitiationReceiver,
296    mut stream: StreamReader<Hdl::Message>,
297    finish_proxy_loop: FinishProxyLoopSender<Hdl>,
298) -> Result<(), Error> {
299    let removed_from_proxy_table = loop {
300        let frame = match futures::future::select(&mut initiate_transfer, stream.next()).await {
301            Either::Left((removed_from_proxy_table, _)) => {
302                // Note: StreamReader guarantees it's safe to drop a partial read without
303                // losing data.
304                break removed_from_proxy_table;
305            }
306            Either::Right((frame, _)) => frame.context("stream.next()")?,
307        };
308        match frame {
309            Frame::Data(message) => {
310                if let Err(e) = proxy.write_to_handle(message).await {
311                    let _ = finish_proxy_loop.and_then_shutdown(Err(e), stream);
312                    match e {
313                        zx_status::Status::PEER_CLOSED => {
314                            return Ok(());
315                        }
316                        _ => return Err(e).context("write_to_handle"),
317                    }
318                }
319            }
320            Frame::SignalUpdate(signal_update) => proxy.apply_signal_update(signal_update)?,
321            Frame::BeginTransfer(new_destination_node, transfer_key) => {
322                return finish_proxy_loop
323                    .and_then_follow(initiate_transfer, new_destination_node, transfer_key, stream)
324                    .context("finish_proxy_loop")
325            }
326            Frame::EndTransfer => bail!("Received EndTransfer on a regular stream"),
327            Frame::AckTransfer => bail!("Received AckTransfer before sending a BeginTransfer"),
328            Frame::Hello => bail!("Hello frame received after stream established"),
329            Frame::Shutdown(result) => {
330                let _ = finish_proxy_loop.and_then_shutdown(result, stream);
331                return result.context("Remote shutdown");
332            }
333        }
334    };
335    match removed_from_proxy_table {
336        Err(e) => Err(e.into()),
337        Ok(RemoveFromProxyTable::Dropped) => unreachable!(),
338        Ok(RemoveFromProxyTable::InitiateTransfer {
339            paired_handle,
340            drain_stream,
341            stream_ref_sender,
342        }) => Ok(finish_proxy_loop.and_then_initiate(
343            paired_handle,
344            drain_stream,
345            stream_ref_sender,
346            stream,
347        )?),
348    }
349}