use super::super::handle::{Message, ProxyableHandle, ProxyableRW, ReadValue};
use super::super::stream::{Frame, StreamReader, StreamWriter, StreamWriterBinder};
use super::super::{Proxy, ProxyTransferInitiationReceiver, StreamRefSender};
use crate::labels::{generate_transfer_key, Endpoint, NodeId, TransferKey};
use crate::peer::{FramedStreamReader, FramedStreamWriter};
use crate::router::OpenedTransfer;
use anyhow::{bail, format_err, Error};
use futures::future::Either;
use futures::prelude::*;
use futures::task::{noop_waker_ref, Context, Poll};
use std::sync::{Arc, Weak};
use zx_status;
/// Follows a transfer that moves this proxy's handle to `new_destination_node`.
///
/// Two pieces of work are driven concurrently with `try_join`:
///
/// * `stream_reader` is expected to deliver a clean (`Ok(())`) shutdown;
/// * the transfer is acknowledged on `stream_writer`, the handle is taken out of
///   `proxy`, and the router is asked to open the transfer identified by
///   `transfer_key` towards the new destination.
///
/// If the router answers `OpenedTransfer::Remote`, a new proxy main loop is run over
/// the writer, reader and handle it returned. If it answers `OpenedTransfer::Fused`,
/// no further proxying is started.
///
/// # Errors
///
/// Fails if the handle was already taken from `proxy` ("Handle already taken"), if
/// the router can no longer be upgraded ("Router gone"), or if any of the stream,
/// handle-conversion, router or main-loop operations fails.
///
/// # Panics
///
/// In the `Fused` case, panics if awaiting `initiate_transfer` yields an error, or a
/// value whose `is_dropped()` is false.
pub(crate) async fn follow<Hdl: 'static + for<'a> ProxyableRW<'a>>(
    mut proxy: Proxy<Hdl>,
    initiate_transfer: ProxyTransferInitiationReceiver,
    stream_writer: StreamWriter<Hdl::Message>,
    new_destination_node: NodeId,
    transfer_key: TransferKey,
    stream_reader: StreamReader<Hdl::Message>,
) -> Result<(), Error> {
    let expect_clean_shutdown = stream_reader.expect_shutdown(Ok(()));

    let hand_over = async move {
        stream_writer.send_ack_transfer().await?;

        let proxied = proxy.hdl.take().ok_or_else(|| format_err!("Handle already taken"))?;
        let router = Weak::upgrade(&proxied.router()).ok_or_else(|| format_err!("Router gone"))?;
        let raw_handle = proxied.into_fidl_handle()?;
        // The now-empty proxy is released before the router is asked to open the transfer.
        drop(proxy);

        let opened =
            router.open_transfer(new_destination_node.into(), transfer_key, raw_handle).await?;
        match opened {
            OpenedTransfer::Fused => {
                assert!(initiate_transfer.await.unwrap().is_dropped());
                Ok(())
            }
            OpenedTransfer::Remote(new_writer, new_reader, returned_handle) => {
                let rebuilt =
                    Proxy::new(Hdl::from_fidl_handle(returned_handle)?, Arc::downgrade(&router));
                make_boxed_main_loop(
                    rebuilt,
                    initiate_transfer,
                    new_writer.into(),
                    None,
                    new_reader.into(),
                )
                .await
            }
        }
    };

    futures::future::try_join(expect_clean_shutdown, hand_over).await?;
    Ok(())
}
/// Starts `super::main::run_main_loop` with the given arguments and returns it as a
/// pinned, heap-allocated, `Send` trait-object future.
///
/// The return type hides the concrete future type of the main loop from the caller.
// NOTE(review): presumably the boxing exists to break a recursive future type
// (main loop -> transfer -> main loop) — confirm against `run_main_loop`.
fn make_boxed_main_loop<Hdl: 'static + for<'a> ProxyableRW<'a>>(
    proxy: Arc<Proxy<Hdl>>,
    initiate_transfer: ProxyTransferInitiationReceiver,
    stream_writer: FramedStreamWriter,
    initial_stream_reader: Option<FramedStreamReader>,
    stream_reader: FramedStreamReader,
) -> std::pin::Pin<Box<dyn Send + Future<Output = Result<(), Error>>>> {
    let main_loop = super::main::run_main_loop(
        proxy,
        initiate_transfer,
        stream_writer,
        initial_stream_reader,
        stream_reader,
    );
    Box::pin(main_loop)
}
/// Initiates a transfer of the handle `pair`.
///
/// A fresh transfer key is generated and `drain_stream` is bound to the proxy's
/// handle. Two pieces of work then run concurrently under `try_join`:
///
/// * `pair` is wrapped in a `ProxyableHandle` (sharing the proxy's router) and drained
///   into the bound drain stream;
/// * on the original stream: whatever is immediately pending is flushed, a
///   BeginTransfer naming the drain stream's peer node and the transfer key is sent,
///   and then either the original stream is drained (when the flush handed the
///   `StreamRefSender` back) or AckTransfer frames are exchanged with the peer (when
///   the flush already consumed it).
///
/// # Errors
///
/// Fails if `pair` cannot be converted into `Hdl`, or if either concurrent task fails.
pub(crate) async fn initiate<Hdl: 'static + for<'a> ProxyableRW<'a>>(
    proxy: Proxy<Hdl>,
    pair: fidl::Handle,
    mut stream_writer: StreamWriter<Hdl::Message>,
    mut stream_reader: StreamReader<Hdl::Message>,
    drain_stream: FramedStreamWriter,
    stream_ref_sender: StreamRefSender,
) -> Result<(), Error> {
    let transfer_key = generate_transfer_key();
    let drain_stream = drain_stream.bind(&proxy.hdl());
    // Both values are captured now because `drain_stream` is moved into the drain task below.
    let drain_stream_id = drain_stream.id();
    let peer_node_id = drain_stream.conn().peer_node_id();

    let transferred =
        ProxyableHandle::new(Hdl::from_fidl_handle(pair)?, proxy.hdl().router().clone());
    let drain_pair = drain_handle_to_stream(transferred, drain_stream);

    let wind_down_original = async move {
        let returned_sender = flush_outgoing_messages(
            &proxy,
            transfer_key,
            &mut stream_writer,
            &mut stream_reader,
            drain_stream_id,
            stream_ref_sender,
        )
        .await?;
        stream_writer.send_begin_transfer(peer_node_id, transfer_key).await?;
        match returned_sender {
            // The flush ended without seeing a BeginTransfer from the peer.
            Some(sender) => {
                drain_original_stream(
                    &proxy,
                    transfer_key,
                    stream_writer,
                    stream_reader,
                    drain_stream_id,
                    sender,
                )
                .await?;
            }
            // The flush already consumed the sender; only the acks remain.
            None => {
                stream_writer.send_ack_transfer().await?;
                stream_reader.expect_ack_transfer().await?;
            }
        }
        Ok(())
    };

    futures::future::try_join(drain_pair, wind_down_original).await?;
    Ok(())
}
/// Forwards everything readable from `hdl` to `stream_writer` until the handle
/// reports `PEER_CLOSED`, then sends an EndTransfer frame.
///
/// Messages are forwarded with `send_data` and signal updates with `send_signal`.
///
/// # Errors
///
/// Any read status other than `PEER_CLOSED`, and any stream write failure, ends the
/// drain with an error; no EndTransfer frame is sent in that case.
async fn drain_handle_to_stream<Hdl: 'static + for<'a> ProxyableRW<'a>>(
    hdl: ProxyableHandle<Hdl>,
    mut stream_writer: StreamWriter<Hdl::Message>,
) -> Result<(), Error> {
    // One buffer is reused for every read.
    let mut buffer = Default::default();
    loop {
        let event = match hdl.read(&mut buffer).await {
            Ok(event) => event,
            // A closed peer is the normal end of the drain.
            Err(zx_status::Status::PEER_CLOSED) => break,
            Err(status) => return Err(status.into()),
        };
        match event {
            ReadValue::Message => stream_writer.send_data(&mut buffer).await?,
            ReadValue::SignalUpdate(update) => stream_writer.send_signal(update).await?,
        }
    }
    stream_writer.send_end_transfer().await
}
/// Which of the two sources polled by `flush_outgoing_messages` produced a result.
#[derive(Debug)]
enum FlushOutgoingMsg<'a, Msg>
where
    Msg: Message,
{
    /// The read from the local handle completed successfully.
    FromChannel,
    /// A frame was received from the stream reader.
    FromStream(Frame<'a, Msg>),
}
async fn flush_outgoing_messages<Hdl: 'static + for<'a> ProxyableRW<'a>>(
proxy: &Proxy<Hdl>,
original_transfer_key: TransferKey,
stream_writer: &mut StreamWriter<Hdl::Message>,
stream_reader: &mut StreamReader<Hdl::Message>,
drain_stream_id: u64,
stream_ref_sender: StreamRefSender,
) -> Result<Option<StreamRefSender>, Error> {
let mut message = Default::default();
let endpoint = stream_reader.conn().endpoint();
loop {
let msg = match futures::future::select(
proxy.read_from_handle(&mut message),
stream_reader.next(),
)
.poll_unpin(&mut Context::from_waker(noop_waker_ref()))
{
Poll::Pending => return Ok(Some(stream_ref_sender)),
Poll::Ready(Either::Left((x, _))) => {
x?;
FlushOutgoingMsg::FromChannel
}
Poll::Ready(Either::Right((msg, _))) => FlushOutgoingMsg::FromStream(msg?),
};
match msg {
FlushOutgoingMsg::FromChannel => {
stream_writer.send_data(&mut message).await?;
}
FlushOutgoingMsg::FromStream(Frame::Data(msg)) => {
proxy.write_to_handle(msg).await?;
}
FlushOutgoingMsg::FromStream(Frame::SignalUpdate(signal_update)) => {
proxy.apply_signal_update(signal_update)?;
}
FlushOutgoingMsg::FromStream(Frame::BeginTransfer(
new_destination_node,
new_transfer_key,
)) => {
match endpoint {
Endpoint::Client => {
stream_ref_sender.draining_initiate(
drain_stream_id,
new_destination_node,
new_transfer_key,
)?;
}
Endpoint::Server => {
stream_ref_sender.draining_await(drain_stream_id, original_transfer_key)?;
}
}
proxy.drain_handle_to_stream(stream_writer).await?;
return Ok(None);
}
FlushOutgoingMsg::FromStream(Frame::Hello) => {
bail!("Hello frame received after stream established")
}
FlushOutgoingMsg::FromStream(Frame::AckTransfer) => {
bail!("AckTransfer received before BeginTransfer sent")
}
FlushOutgoingMsg::FromStream(Frame::EndTransfer) => {
bail!("EndTransfer received on a regular stream")
}
FlushOutgoingMsg::FromStream(Frame::Shutdown(r)) => {
bail!("Stream shutdown during transfer: {:?}", r)
}
}
}
}
/// Consumes frames from the original stream until the peer finishes the transfer
/// handshake.
///
/// * Data frames are written to the proxy's handle and signal updates are applied.
/// * AckTransfer: a clean Shutdown is sent and the sender is told to await the drain
///   under `original_transfer_key`.
/// * BeginTransfer: the sender is notified (`draining_initiate` with the peer's
///   destination and key for a client endpoint, `draining_await` with
///   `original_transfer_key` for a server endpoint), an AckTransfer is sent, and the
///   peer's AckTransfer is awaited.
///
/// # Errors
///
/// Fails on any stream or handle error, and on Hello, EndTransfer or Shutdown frames.
async fn drain_original_stream<Hdl: 'static + for<'a> ProxyableRW<'a>>(
    proxy: &Proxy<Hdl>,
    original_transfer_key: TransferKey,
    stream_writer: StreamWriter<Hdl::Message>,
    mut stream_reader: StreamReader<Hdl::Message>,
    drain_stream_id: u64,
    stream_ref_sender: StreamRefSender,
) -> Result<(), Error> {
    let endpoint = stream_reader.conn().endpoint();
    loop {
        let frame = stream_reader.next().await?;
        match frame {
            Frame::Data(mut payload) => {
                proxy.write_to_handle(&mut payload).await?;
            }
            Frame::SignalUpdate(update) => proxy.apply_signal_update(update)?,
            Frame::AckTransfer => {
                stream_writer.send_shutdown(Ok(())).await?;
                return stream_ref_sender.draining_await(drain_stream_id, original_transfer_key);
            }
            Frame::BeginTransfer(new_destination_node, new_transfer_key) => {
                // Only the sender notification depends on the endpoint role; the ack
                // exchange that follows is the same for both.
                match endpoint {
                    Endpoint::Client => {
                        stream_ref_sender.draining_initiate(
                            drain_stream_id,
                            new_destination_node,
                            new_transfer_key,
                        )?;
                    }
                    Endpoint::Server => {
                        stream_ref_sender.draining_await(drain_stream_id, original_transfer_key)?;
                    }
                }
                stream_writer.send_ack_transfer().await?;
                return stream_reader.expect_ack_transfer().await;
            }
            Frame::Hello => {
                bail!("Hello frame received after stream established");
            }
            Frame::EndTransfer => bail!("EndTransfer received on a regular stream"),
            Frame::Shutdown(r) => bail!("Stream shutdown during transfer: {:?}", r),
        }
    }
}