overnet_core/proxy/
stream.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use super::handle::{Message, Proxyable, ProxyableHandle, RouterHolder, Serializer};
6use crate::coding::{decode_fidl, encode_fidl};
7use crate::labels::{NodeId, TransferKey};
8use crate::peer::{
9    FrameType, FramedStreamReadResult, FramedStreamReader, FramedStreamWriter, PeerConn,
10    PeerConnRef,
11};
12use crate::router::Router;
13use anyhow::{format_err, Context as _, Error};
14use fidl_fuchsia_overnet_protocol::{BeginTransfer, Empty, SignalUpdate, StreamControl};
15use futures::future::{poll_fn, BoxFuture};
16use futures::prelude::*;
17use futures::ready;
18use std::pin::Pin;
19use std::sync::Weak;
20use std::task::{Context, Poll};
21use zx_status;
22
/// Typed writer for one proxied stream: serializes `Msg` values and stream control messages
/// into frames on an underlying [`FramedStreamWriter`].
pub(crate) struct StreamWriter<Msg: Message> {
    stream: FramedStreamWriter,
    // Buffer that `send_data` serializes into before sending; kept across calls.
    send_buffer: Vec<u8>,
    // Router handle passed to the serializer (wrapped as `RouterHolder::Unused`).
    router: Weak<Router>,
    // Set once a final control message (ack/end transfer, shutdown) has been issued.
    closed: bool,
    _phantom_msg: std::marker::PhantomData<Msg>,
}
30
31impl<Msg: Message> StreamWriter<Msg> {
32    pub fn conn(&self) -> PeerConnRef<'_> {
33        self.stream.conn()
34    }
35
36    pub fn id(&self) -> u64 {
37        self.stream.id()
38    }
39
40    pub async fn send_data(&mut self, msg: &mut Msg) -> Result<(), Error> {
41        assert_ne!(self.closed, true);
42        let mut s = Msg::Serializer::new();
43        let send_buffer = &mut self.send_buffer;
44        let conn = self.stream.conn();
45        let mut rh = RouterHolder::Unused(&self.router);
46        poll_fn(|fut_ctx| s.poll_ser(msg, send_buffer, conn, &mut rh, fut_ctx))
47            .await
48            .with_context(|| format_err!("Serializing message {:?}", msg))?;
49        self.stream
50            .send(FrameType::Data, &self.send_buffer)
51            .await
52            .with_context(|| format_err!("sending data {:?} ser={:?}", msg, self.send_buffer))
53    }
54
55    async fn send_control(&mut self, mut msg: StreamControl, fin: bool) -> Result<(), Error> {
56        assert_ne!(self.closed, true);
57        let msg = encode_fidl(&mut msg)
58            .with_context(|| format_err!("encoding control message {:?}", msg))?;
59        if fin {
60            self.closed = true;
61        }
62        self.stream
63            .send(FrameType::Control, msg.as_slice())
64            .await
65            .with_context(|| format_err!("sending control message {:?}", msg))
66    }
67
68    pub async fn send_signal(&mut self, mut msg: SignalUpdate) -> Result<(), Error> {
69        assert_ne!(self.closed, true);
70        let msg = encode_fidl(&mut msg)
71            .with_context(|| format_err!("encoding control message {:?}", msg))?;
72        self.stream
73            .send(FrameType::Signal, msg.as_slice())
74            .await
75            .with_context(|| format_err!("sending control message {:?}", msg))
76    }
77
78    pub async fn send_ack_transfer(mut self) -> Result<(), Error> {
79        Ok(self.send_control(StreamControl::AckTransfer(Empty {}), true).await?)
80    }
81
82    pub async fn send_end_transfer(mut self) -> Result<(), Error> {
83        Ok(self.send_control(StreamControl::EndTransfer(Empty {}), true).await?)
84    }
85
86    pub async fn send_begin_transfer(
87        &mut self,
88        new_destination_node: NodeId,
89        transfer_key: TransferKey,
90    ) -> Result<(), Error> {
91        Ok(self
92            .send_control(
93                StreamControl::BeginTransfer(BeginTransfer {
94                    new_destination_node: new_destination_node.into(),
95                    transfer_key,
96                }),
97                false,
98            )
99            .await?)
100    }
101
102    pub async fn send_hello(&mut self) -> Result<(), Error> {
103        self.stream.send(FrameType::Hello, &[]).await.with_context(|| format_err!("sending hello"))
104    }
105
106    pub async fn send_shutdown(mut self, r: Result<(), zx_status::Status>) -> Result<(), Error> {
107        self.send_control(
108            StreamControl::Shutdown(
109                match r {
110                    Ok(()) => zx_status::Status::OK,
111                    Err(s) => s,
112                }
113                .into_raw(),
114            ),
115            true,
116        )
117        .await
118    }
119}
120
121pub(crate) trait StreamWriterBinder {
122    fn bind<Msg: Message, H: Proxyable<Message = Msg>>(
123        self,
124        hdl: &ProxyableHandle<H>,
125    ) -> StreamWriter<Msg>;
126}
127
128impl StreamWriterBinder for FramedStreamWriter {
129    fn bind<Msg: Message, H: Proxyable<Message = Msg>>(
130        self,
131        hdl: &ProxyableHandle<H>,
132    ) -> StreamWriter<Msg> {
133        StreamWriter {
134            stream: self,
135            send_buffer: Vec::new(),
136            router: hdl.router().clone(),
137            closed: false,
138            _phantom_msg: std::marker::PhantomData,
139        }
140    }
141}
142
/// One decoded frame read from a proxied stream by [`StreamReader::next`].
#[derive(PartialEq, Debug)]
pub(crate) enum Frame<'a, Msg: Message> {
    /// A `FrameType::Hello` frame (required to have an empty payload).
    Hello,
    /// A `FrameType::Data` frame, deserialized into the reader's incoming message, which is
    /// lent out by mutable reference.
    Data(&'a mut Msg),
    /// A decoded `FrameType::Signal` frame.
    SignalUpdate(SignalUpdate),
    /// `StreamControl::BeginTransfer`, carrying the new destination node and transfer key.
    BeginTransfer(NodeId, TransferKey),
    /// `StreamControl::AckTransfer`.
    AckTransfer,
    /// `StreamControl::EndTransfer`.
    EndTransfer,
    /// `StreamControl::Shutdown`, with the raw status converted via `zx_status::Status::ok`.
    Shutdown(Result<(), zx_status::Status>),
}
153
/// Typed reader for one proxied stream: reads frames from a [`FramedStreamReader`] and decodes
/// them into [`Frame`]s for message type `Msg`.
#[derive(Debug)]
pub(crate) struct StreamReader<Msg: Message> {
    stream: FramedStreamReader,
    // Destination for deserialized data frames; lent out via `Frame::Data`.
    incoming_message: Msg,
    // Router handle passed to the data parser (wrapped as `RouterHolder::Unused`).
    router: Weak<Router>,
    // Lives here (not in `ReadNext`) so a data frame whose deserialization was interrupted by
    // dropping a `ReadNext` is resumed by the next call to `next()`.
    state: ReadNextState<Msg::Parser>,
}
161
/// Future returned by [`StreamReader::next`]; resolves to the next complete [`Frame`].
#[derive(Debug)]
pub(crate) struct ReadNext<'a, Msg: Message> {
    // An in-flight read of the next raw frame when created in the `Reading` state; otherwise
    // only a connection reference, because a data frame is already being deserialized.
    read_next_frame_or_peer_conn_ref: ReadNextFrameOrPeerConnRef<'a>,
    // Borrowed from the owning `StreamReader`, so progress survives dropping this future.
    state: &'a mut ReadNextState<Msg::Parser>,
    // Connection handed to the parser while deserializing a data frame.
    conn: PeerConn,
    // The reader's incoming message; taken when `Frame::Data` is returned.
    incoming_message: Option<&'a mut Msg>,
    router_holder: RouterHolder<'a>,
}
170
/// What a [`ReadNext`] holds from the underlying stream: either a pending read of the next
/// frame, or just a connection reference when no new frame needs to be read.
enum ReadNextFrameOrPeerConnRef<'a> {
    /// Boxed future yielding the next raw frame from the stream.
    ReadNextFrame(BoxFuture<'a, Result<FramedStreamReadResult, Error>>),
    /// Chosen by `StreamReader::next` when resuming an interrupted data deserialization.
    /// NOTE(review): the reference is only read by `Debug` here; presumably it exists to keep
    /// the stream borrowed for the future's lifetime — confirm.
    PeerConnRef(PeerConnRef<'a>),
}
175
176impl std::fmt::Debug for ReadNextFrameOrPeerConnRef<'_> {
177    fn fmt(&self, writer: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
178        match self {
179            Self::ReadNextFrame(_) => {
180                write!(writer, "ReadNextFrameOrPeerConnRef::ReadNextFrame(...)")
181            }
182            Self::PeerConnRef(x) => {
183                write!(writer, "ReadNextFrameOrPeerConnRef::PeerConnRef({x:?})")
184            }
185        }
186    }
187}
188
189impl<'a> ReadNextFrameOrPeerConnRef<'a> {
190    fn as_read_next_frame_mut(
191        &mut self,
192    ) -> Option<&mut BoxFuture<'a, Result<FramedStreamReadResult, Error>>> {
193        match self {
194            Self::ReadNextFrame(x) => Some(x),
195            _ => None,
196        }
197    }
198}
199
/// Progress of reading one frame, stored in [`StreamReader`] so it persists across [`ReadNext`]
/// futures.
#[derive(Debug)]
enum ReadNextState<Parser> {
    /// Waiting for the next raw frame from the stream.
    Reading,
    /// A data frame's bytes were received and are being fed to `Parser`.
    DeserializingData(Vec<u8>, Parser),
}
205
206impl<'a, Msg: Message> ReadNext<'a, Msg> {
207    fn poll_inner(&mut self, ctx: &mut Context<'_>) -> Poll<Result<Frame<'a, Msg>, Error>> {
208        loop {
209            return Poll::Ready(Ok(match *self.state {
210                ReadNextState::Reading => {
211                    let (frame_type, mut bytes) = match ready!(self
212                        .read_next_frame_or_peer_conn_ref
213                        .as_read_next_frame_mut()
214                        .unwrap()
215                        .poll_unpin(ctx))?
216                    {
217                        FramedStreamReadResult::Frame(frame_type, bytes) => (frame_type, bytes),
218                        FramedStreamReadResult::Closed(Some(e)) => {
219                            return Poll::Ready(Err(format_err!("unexpected end of stream ({e})")))
220                        }
221                        FramedStreamReadResult::Closed(None) => {
222                            return Poll::Ready(Err(format_err!("unexpected end of stream")))
223                        }
224                    };
225
226                    match frame_type {
227                        FrameType::Hello => {
228                            if bytes.len() != 0 {
229                                return Poll::Ready(Err(format_err!("Hello frame must be empty")));
230                            }
231                            Frame::Hello
232                        }
233                        FrameType::Data => {
234                            *self.state =
235                                ReadNextState::DeserializingData(bytes, Msg::Parser::new());
236                            continue;
237                        }
238                        FrameType::Signal => Frame::SignalUpdate(decode_fidl(&mut bytes)?),
239                        FrameType::Control => match decode_fidl(&mut bytes)? {
240                            StreamControl::AckTransfer(Empty {}) => Frame::AckTransfer,
241                            StreamControl::EndTransfer(Empty {}) => Frame::EndTransfer,
242                            StreamControl::Shutdown(status_code) => {
243                                Frame::Shutdown(zx_status::Status::ok(status_code))
244                            }
245
246                            StreamControl::BeginTransfer(BeginTransfer {
247                                new_destination_node,
248                                transfer_key,
249                            }) => Frame::BeginTransfer(new_destination_node.into(), transfer_key),
250                        },
251                    }
252                }
253                ReadNextState::DeserializingData(ref mut bytes, ref mut parser) => {
254                    ready!(parser.poll_ser(
255                        self.incoming_message.as_mut().unwrap(),
256                        bytes,
257                        self.conn.as_ref(),
258                        &mut self.router_holder,
259                        ctx,
260                    ))?;
261                    *self.state = ReadNextState::Reading;
262                    Frame::Data(self.incoming_message.take().unwrap())
263                }
264            }));
265        }
266    }
267}
268
269impl<'a, Msg: Message> Future for ReadNext<'a, Msg> {
270    type Output = Result<Frame<'a, Msg>, Error>;
271    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
272        Pin::into_inner(self).poll_inner(ctx)
273    }
274}
275
276impl<Msg: Message> StreamReader<Msg> {
277    pub fn conn(&self) -> PeerConnRef<'_> {
278        self.stream.conn()
279    }
280
281    pub fn is_initiator(&self) -> bool {
282        self.stream.is_initiator()
283    }
284
285    pub fn next<'a>(&'a mut self) -> ReadNext<'a, Msg> {
286        let conn = self.stream.conn().into_peer_conn();
287        ReadNext {
288            read_next_frame_or_peer_conn_ref: match self.state {
289                ReadNextState::Reading => {
290                    ReadNextFrameOrPeerConnRef::ReadNextFrame(self.stream.next().boxed())
291                }
292                ReadNextState::DeserializingData(_, _) => {
293                    ReadNextFrameOrPeerConnRef::PeerConnRef(self.stream.conn())
294                }
295            },
296            state: &mut self.state,
297            conn,
298            incoming_message: Some(&mut self.incoming_message),
299            router_holder: RouterHolder::Unused(&self.router),
300        }
301    }
302
303    async fn expect(&mut self, frame: Frame<'_, Msg>) -> Result<(), Error> {
304        let received = self.next().await?;
305        if received != frame {
306            let msg = format_err!("Expected {:?} got {:?}", frame, received);
307            self.stream.abandon().await;
308            Err(msg)
309        } else {
310            Ok(())
311        }
312    }
313
314    pub async fn expect_ack_transfer(mut self) -> Result<(), Error> {
315        self.expect(Frame::AckTransfer).await
316    }
317
318    pub async fn expect_hello(&mut self) -> Result<(), Error> {
319        self.expect(Frame::Hello).await
320    }
321
322    pub async fn expect_shutdown(
323        mut self,
324        result: Result<(), zx_status::Status>,
325    ) -> Result<(), Error> {
326        self.expect(Frame::Shutdown(result)).await
327    }
328}
329
330pub(crate) trait StreamReaderBinder {
331    fn bind<Msg: Message, H: Proxyable<Message = Msg>>(
332        self,
333        hdl: &ProxyableHandle<H>,
334    ) -> StreamReader<Msg>;
335}
336
337impl StreamReaderBinder for FramedStreamReader {
338    fn bind<Msg: Message, H: Proxyable<Message = Msg>>(
339        self,
340        hdl: &ProxyableHandle<H>,
341    ) -> StreamReader<Msg> {
342        StreamReader {
343            stream: self,
344            incoming_message: Default::default(),
345            router: hdl.router().clone(),
346            state: ReadNextState::Reading,
347        }
348    }
349}