overnet_core/proxy/run/
main.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Main loops (and associated spawn functions) for proxying: moves data between a handle and its
//! stream in both directions, and calls into crate::proxy::xfer when a handle transfer is needed.
7
8use super::super::handle::ReadValue;
9use super::super::stream::{
10    Frame, StreamReader, StreamReaderBinder, StreamWriter, StreamWriterBinder,
11};
12use super::super::{
13    Proxy, ProxyTransferInitiationReceiver, Proxyable, ProxyableRW, RemoveFromProxyTable,
14    StreamRefSender,
15};
16use crate::labels::{NodeId, TransferKey};
17use crate::peer::{FramedStreamReader, FramedStreamWriter};
18use anyhow::{Context as _, Error, bail, format_err};
19use fuchsia_sync::Mutex;
20use futures::future::Either;
21use futures::prelude::*;
22use std::sync::Arc;
23use zx_status;
24
/// Instruction passed from the stream->handle task to the handle->stream task, describing how the
/// proxy loop should be wound down.
///
/// We run two tasks to proxy a handle - one for handle->stream, the other for stream->handle.
/// Transfers and shutdowns need to be carried out by just one task, so the stream->handle task
/// reports what it saw through a one-shot "join" channel carrying this enum, and the
/// handle->stream task (which owns the stream writer) acts on it.
#[derive(Debug)]
enum FinishProxyLoopAction<Hdl: Proxyable> {
    /// A transfer was requested through the proxy's transfer-initiation receiver
    /// (`RemoveFromProxyTable::InitiateTransfer`).
    InitiateTransfer {
        paired_handle: fidl::NullableHandle,
        drain_stream: FramedStreamWriter,
        stream_ref_sender: StreamRefSender,
        /// The stream->handle reader, handed over so the transfer can keep using it.
        stream_reader: StreamReader<Hdl::Message>,
    },
    /// The remote end sent `Frame::BeginTransfer`.
    FollowTransfer {
        initiate_transfer: ProxyTransferInitiationReceiver,
        new_destination_node: NodeId,
        transfer_key: TransferKey,
        stream_reader: StreamReader<Hdl::Message>,
    },
    /// The remote end sent `Frame::Shutdown`, or writing to the local handle failed.
    Shutdown {
        result: Result<(), zx_status::Status>,
        stream_reader: StreamReader<Hdl::Message>,
    },
}
47
48struct FinishProxyLoopSender<Hdl: Proxyable> {
49    chan: futures::channel::oneshot::Sender<FinishProxyLoopAction<Hdl>>,
50}
51type FinishProxyLoopReceiver<Hdl> = futures::channel::oneshot::Receiver<FinishProxyLoopAction<Hdl>>;
52
53impl<Hdl: 'static + Proxyable> FinishProxyLoopSender<Hdl> {
54    fn and_then(self, then: FinishProxyLoopAction<Hdl>) -> Result<(), Error> {
55        Ok(self.chan.send(then).map_err(|_| format_err!("Join channel broken"))?)
56    }
57
58    // This join is to initiate a new transfer.
59    fn and_then_initiate(
60        self,
61        paired_handle: fidl::NullableHandle,
62        drain_stream: FramedStreamWriter,
63        stream_ref_sender: StreamRefSender,
64        stream_reader: StreamReader<Hdl::Message>,
65    ) -> Result<(), Error> {
66        self.and_then(FinishProxyLoopAction::InitiateTransfer {
67            paired_handle,
68            drain_stream,
69            stream_ref_sender,
70            stream_reader,
71        })
72    }
73
74    // This join is to follow a transfer initiated by the remote end.
75    fn and_then_follow(
76        self,
77        initiate_transfer: ProxyTransferInitiationReceiver,
78        new_destination_node: NodeId,
79        transfer_key: TransferKey,
80        stream_reader: StreamReader<Hdl::Message>,
81    ) -> Result<(), Error> {
82        self.and_then(FinishProxyLoopAction::FollowTransfer {
83            initiate_transfer,
84            new_destination_node,
85            transfer_key,
86            stream_reader,
87        })
88    }
89
90    fn and_then_shutdown(
91        self,
92        result: Result<(), zx_status::Status>,
93        stream_reader: StreamReader<Hdl::Message>,
94    ) -> Result<(), Error> {
95        self.and_then(FinishProxyLoopAction::Shutdown { result, stream_reader })
96    }
97}
98
99fn new_task_joiner<Hdl: Proxyable>() -> (FinishProxyLoopSender<Hdl>, FinishProxyLoopReceiver<Hdl>) {
100    let (tx, rx) = futures::channel::oneshot::channel();
101    (FinishProxyLoopSender { chan: tx }, rx)
102}
103
104/// Store behind [`set_proxy_drop_event_handler`]
105static PROXY_DROP_EVENT: Mutex<Option<Box<dyn Fn(&Result<(), Error>) + 'static + Send>>> =
106    Mutex::new(None);
107
108/// Sets a global callback to call every time a proxy is dropped. It's given a
109/// reference to the error and can be used to send metrics events.
110pub fn set_proxy_drop_event_handler(handler: impl Fn(&Result<(), Error>) + 'static + Send) {
111    *PROXY_DROP_EVENT.lock() = Some(Box::new(handler));
112}
113
/// Runs a proxy to completion. It first performs the Hello handshake, concurrently with draining
/// `initial_stream_reader` if one was given. It then drives both directions of proxying
/// (stream->handle and handle->stream) concurrently in this task until one fails or both finish.
///
/// On error the proxy is closed with the error's debug text as the reason (unless
/// `handle_to_stream` already took ownership of it) and the error is returned.
pub(crate) async fn run_main_loop<Hdl: 'static + for<'a> ProxyableRW<'a>>(
    proxy: Arc<Proxy<Hdl>>,
    initiate_transfer: ProxyTransferInitiationReceiver,
    stream_writer: FramedStreamWriter,
    initial_stream_reader: Option<FramedStreamReader>,
    stream_reader: FramedStreamReader,
) -> Result<(), Error> {
    // We must be the sole owner: the `Arc::try_unwrap(..).unwrap()` calls below depend on it.
    assert!(Arc::strong_count(&proxy) == 1);
    let (tx_join, rx_join) = new_task_joiner();
    let hdl = proxy.hdl();
    let mut stream_writer = stream_writer.bind(hdl);
    let initial_stream_reader = initial_stream_reader.map(|s| s.bind(hdl));
    let mut stream_reader = stream_reader.bind(hdl);
    // Phase 1: handshake and (optional) drain run concurrently.
    let res = futures::future::try_join(
        async {
            // The initiating side sends Hello; the other side waits for it.
            if !stream_reader.is_initiator() {
                stream_reader.expect_hello().await?;
            } else {
                stream_writer.send_hello().await?;
            }
            Ok::<(), Error>(())
        },
        async {
            if let Some(initial_stream_reader) = initial_stream_reader {
                drain(proxy.clone(), initial_stream_reader).await?;
            }
            Ok(())
        },
    )
    .await;

    // The joined futures (and the proxy clone given to `drain`) are dropped once the await
    // completes, so `proxy` is the only reference again here.
    if let Err(e) = res {
        Arc::try_unwrap(proxy).unwrap().close_with_reason(format!("{e:?}"));
        return Err(e);
    }

    // Our own reference, used to close the proxy with a reason if the loops fail before
    // handle_to_stream takes ownership. handle_to_stream calls `take_proxy` immediately before
    // it unwraps its Arc, releasing this reference.
    let mut my_proxy = Some(Arc::clone(&proxy));

    let take_proxy = || {
        my_proxy = None;
    };

    // Phase 2: both proxy directions. try_join finishes on the first error, and the other
    // future is then dropped along with the join.
    let res = futures::future::try_join(
        stream_to_handle(proxy.clone(), initiate_transfer, stream_reader, tx_join)
            .map_err(|e| e.context("stream_to_handle")),
        handle_to_stream(proxy, stream_writer, rx_join, take_proxy)
            .map_err(|e| e.context("handle_to_stream")),
    )
    // Collapse the `((), ())` output to `()`.
    .map_ok(drop)
    .await;

    // The lock guard lives for the whole `if let`, so the callback runs with the lock held.
    if let Some(cb) = &*PROXY_DROP_EVENT.lock() {
        cb(&res)
    }
    if let Err(e) = res {
        // If handle_to_stream never took the proxy, both futures are gone by now and this is
        // the last reference.
        if let Some(proxy) = my_proxy {
            Arc::try_unwrap(proxy).unwrap().close_with_reason(format!("{e:?}"));
        }
        Err(e)
    } else {
        Ok(())
    }
}
178
179async fn handle_to_stream<Hdl: 'static + for<'a> ProxyableRW<'a>>(
180    proxy: Arc<Proxy<Hdl>>,
181    mut stream: StreamWriter<Hdl::Message>,
182    mut finish_proxy_loop: FinishProxyLoopReceiver<Hdl>,
183    take_proxy: impl FnOnce(),
184) -> Result<(), Error> {
185    let mut message = Default::default();
186    let finish_proxy_loop_action = loop {
187        let sr =
188            futures::future::select(&mut finish_proxy_loop, proxy.read_from_handle(&mut message))
189                .await;
190        match sr {
191            Either::Left((finish_proxy_loop_action, _)) => {
192                // Note: Proxy guarantees that read_from_handle can be dropped safely without losing data.
193                break finish_proxy_loop_action;
194            }
195            Either::Right((Err(zx_status::Status::PEER_CLOSED), _)) => {
196                if let Some(finish_proxy_loop_action) = finish_proxy_loop.now_or_never() {
197                    break finish_proxy_loop_action;
198                }
199                stream.send_shutdown(Ok(())).await.context("send_shutdown")?;
200                return Ok(());
201            }
202            Either::Right((Err(x), _)) => {
203                stream.send_shutdown(Err(x)).await.context("send_shutdown")?;
204                return Err(x).context("read_from_handle");
205            }
206            Either::Right((Ok(ReadValue::Message), _)) => {
207                drop(sr);
208                stream.send_data(&mut message).await.context("send_data")?;
209            }
210            Either::Right((Ok(ReadValue::SignalUpdate(signal_update)), _)) => {
211                stream.send_signal(signal_update).await.context("send_signal")?;
212            }
213        };
214    };
215    take_proxy();
216    let proxy = Arc::try_unwrap(proxy).map_err(|_| format_err!("Proxy should be isolated"))?;
217    match finish_proxy_loop_action {
218        Ok(FinishProxyLoopAction::InitiateTransfer {
219            paired_handle,
220            drain_stream,
221            stream_ref_sender,
222            stream_reader,
223        }) => {
224            super::xfer::initiate(
225                proxy,
226                paired_handle,
227                stream,
228                stream_reader,
229                drain_stream,
230                stream_ref_sender,
231            )
232            .await
233        }
234        Ok(FinishProxyLoopAction::FollowTransfer {
235            initiate_transfer,
236            new_destination_node,
237            transfer_key,
238            stream_reader,
239        }) => {
240            super::xfer::follow(
241                proxy,
242                initiate_transfer,
243                stream,
244                new_destination_node,
245                transfer_key,
246                stream_reader,
247            )
248            .await
249        }
250        Ok(FinishProxyLoopAction::Shutdown { result, stream_reader }) => {
251            join_shutdown(proxy, stream, stream_reader, result).await
252        }
253        Err(futures::channel::oneshot::Canceled) => unreachable!(),
254    }
255}
256
257async fn join_shutdown<Hdl: 'static + Proxyable>(
258    proxy: Proxy<Hdl>,
259    stream_writer: StreamWriter<Hdl::Message>,
260    stream_reader: StreamReader<Hdl::Message>,
261    result: Result<(), zx_status::Status>,
262) -> Result<(), Error> {
263    stream_writer.send_shutdown(result).await?;
264    let _ = stream_reader.expect_shutdown(Ok(())).await;
265    proxy.close_with_reason(format!("Proxy shut down (result: {result:?})"));
266    Ok(())
267}
268
269async fn drain<Hdl: 'static + for<'a> ProxyableRW<'a>>(
270    proxy: Arc<Proxy<Hdl>>,
271    mut drain_stream: StreamReader<Hdl::Message>,
272) -> Result<(), Error> {
273    loop {
274        let frame = drain_stream.next().await?;
275        match frame {
276            Frame::Data(message) => proxy.write_to_handle(message).await?,
277            Frame::SignalUpdate(signal_update) => proxy.apply_signal_update(signal_update)?,
278            Frame::EndTransfer => break,
279            Frame::BeginTransfer(_, _) => bail!("BeginTransfer on drain stream"),
280            Frame::AckTransfer => bail!("AckTransfer on drain stream"),
281            Frame::Hello => bail!("Hello frame disallowed on drain streams"),
282            Frame::Shutdown(r) => bail!("Stream shutdown during drain: {:?}", r),
283        }
284    }
285    Ok(())
286}
287
288async fn stream_to_handle<Hdl: 'static + for<'a> ProxyableRW<'a>>(
289    proxy: Arc<Proxy<Hdl>>,
290    mut initiate_transfer: ProxyTransferInitiationReceiver,
291    mut stream: StreamReader<Hdl::Message>,
292    finish_proxy_loop: FinishProxyLoopSender<Hdl>,
293) -> Result<(), Error> {
294    let removed_from_proxy_table = loop {
295        let frame = match futures::future::select(&mut initiate_transfer, stream.next()).await {
296            Either::Left((removed_from_proxy_table, _)) => {
297                // Note: StreamReader guarantees it's safe to drop a partial read without
298                // losing data.
299                break removed_from_proxy_table;
300            }
301            Either::Right((frame, _)) => frame.context("stream.next()")?,
302        };
303        match frame {
304            Frame::Data(message) => {
305                if let Err(e) = proxy.write_to_handle(message).await {
306                    let _ = finish_proxy_loop.and_then_shutdown(Err(e), stream);
307                    match e {
308                        zx_status::Status::PEER_CLOSED => {
309                            return Ok(());
310                        }
311                        _ => return Err(e).context("write_to_handle"),
312                    }
313                }
314            }
315            Frame::SignalUpdate(signal_update) => proxy.apply_signal_update(signal_update)?,
316            Frame::BeginTransfer(new_destination_node, transfer_key) => {
317                return finish_proxy_loop
318                    .and_then_follow(initiate_transfer, new_destination_node, transfer_key, stream)
319                    .context("finish_proxy_loop");
320            }
321            Frame::EndTransfer => bail!("Received EndTransfer on a regular stream"),
322            Frame::AckTransfer => bail!("Received AckTransfer before sending a BeginTransfer"),
323            Frame::Hello => bail!("Hello frame received after stream established"),
324            Frame::Shutdown(result) => {
325                let _ = finish_proxy_loop.and_then_shutdown(result, stream);
326                return result.context("Remote shutdown");
327            }
328        }
329    };
330    match removed_from_proxy_table {
331        Err(e) => Err(e.into()),
332        Ok(RemoveFromProxyTable::Dropped) => unreachable!(),
333        Ok(RemoveFromProxyTable::InitiateTransfer {
334            paired_handle,
335            drain_stream,
336            stream_ref_sender,
337        }) => Ok(finish_proxy_loop.and_then_initiate(
338            paired_handle,
339            drain_stream,
340            stream_ref_sender,
341            stream,
342        )?),
343    }
344}