1use super::handle::{Message, Proxyable, ProxyableHandle, RouterHolder, Serializer};
6use crate::coding::{decode_fidl, encode_fidl};
7use crate::labels::{NodeId, TransferKey};
8use crate::peer::{
9 FrameType, FramedStreamReadResult, FramedStreamReader, FramedStreamWriter, PeerConn,
10 PeerConnRef,
11};
12use crate::router::Router;
13use anyhow::{format_err, Context as _, Error};
14use fidl_fuchsia_overnet_protocol::{BeginTransfer, Empty, SignalUpdate, StreamControl};
15use futures::future::{poll_fn, BoxFuture};
16use futures::prelude::*;
17use futures::ready;
18use std::pin::Pin;
19use std::sync::Weak;
20use std::task::{Context, Poll};
21use zx_status;
22
23pub(crate) struct StreamWriter<Msg: Message> {
24 stream: FramedStreamWriter,
25 send_buffer: Vec<u8>,
26 router: Weak<Router>,
27 closed: bool,
28 _phantom_msg: std::marker::PhantomData<Msg>,
29}
30
31impl<Msg: Message> StreamWriter<Msg> {
32 pub fn conn(&self) -> PeerConnRef<'_> {
33 self.stream.conn()
34 }
35
36 pub fn id(&self) -> u64 {
37 self.stream.id()
38 }
39
40 pub async fn send_data(&mut self, msg: &mut Msg) -> Result<(), Error> {
41 assert_ne!(self.closed, true);
42 let mut s = Msg::Serializer::new();
43 let send_buffer = &mut self.send_buffer;
44 let conn = self.stream.conn();
45 let mut rh = RouterHolder::Unused(&self.router);
46 poll_fn(|fut_ctx| s.poll_ser(msg, send_buffer, conn, &mut rh, fut_ctx))
47 .await
48 .with_context(|| format_err!("Serializing message {:?}", msg))?;
49 self.stream
50 .send(FrameType::Data, &self.send_buffer)
51 .await
52 .with_context(|| format_err!("sending data {:?} ser={:?}", msg, self.send_buffer))
53 }
54
55 async fn send_control(&mut self, mut msg: StreamControl, fin: bool) -> Result<(), Error> {
56 assert_ne!(self.closed, true);
57 let msg = encode_fidl(&mut msg)
58 .with_context(|| format_err!("encoding control message {:?}", msg))?;
59 if fin {
60 self.closed = true;
61 }
62 self.stream
63 .send(FrameType::Control, msg.as_slice())
64 .await
65 .with_context(|| format_err!("sending control message {:?}", msg))
66 }
67
68 pub async fn send_signal(&mut self, mut msg: SignalUpdate) -> Result<(), Error> {
69 assert_ne!(self.closed, true);
70 let msg = encode_fidl(&mut msg)
71 .with_context(|| format_err!("encoding control message {:?}", msg))?;
72 self.stream
73 .send(FrameType::Signal, msg.as_slice())
74 .await
75 .with_context(|| format_err!("sending control message {:?}", msg))
76 }
77
78 pub async fn send_ack_transfer(mut self) -> Result<(), Error> {
79 Ok(self.send_control(StreamControl::AckTransfer(Empty {}), true).await?)
80 }
81
82 pub async fn send_end_transfer(mut self) -> Result<(), Error> {
83 Ok(self.send_control(StreamControl::EndTransfer(Empty {}), true).await?)
84 }
85
86 pub async fn send_begin_transfer(
87 &mut self,
88 new_destination_node: NodeId,
89 transfer_key: TransferKey,
90 ) -> Result<(), Error> {
91 Ok(self
92 .send_control(
93 StreamControl::BeginTransfer(BeginTransfer {
94 new_destination_node: new_destination_node.into(),
95 transfer_key,
96 }),
97 false,
98 )
99 .await?)
100 }
101
102 pub async fn send_hello(&mut self) -> Result<(), Error> {
103 self.stream.send(FrameType::Hello, &[]).await.with_context(|| format_err!("sending hello"))
104 }
105
106 pub async fn send_shutdown(mut self, r: Result<(), zx_status::Status>) -> Result<(), Error> {
107 self.send_control(
108 StreamControl::Shutdown(
109 match r {
110 Ok(()) => zx_status::Status::OK,
111 Err(s) => s,
112 }
113 .into_raw(),
114 ),
115 true,
116 )
117 .await
118 }
119}
120
121pub(crate) trait StreamWriterBinder {
122 fn bind<Msg: Message, H: Proxyable<Message = Msg>>(
123 self,
124 hdl: &ProxyableHandle<H>,
125 ) -> StreamWriter<Msg>;
126}
127
128impl StreamWriterBinder for FramedStreamWriter {
129 fn bind<Msg: Message, H: Proxyable<Message = Msg>>(
130 self,
131 hdl: &ProxyableHandle<H>,
132 ) -> StreamWriter<Msg> {
133 StreamWriter {
134 stream: self,
135 send_buffer: Vec::new(),
136 router: hdl.router().clone(),
137 closed: false,
138 _phantom_msg: std::marker::PhantomData,
139 }
140 }
141}
142
143#[derive(PartialEq, Debug)]
144pub(crate) enum Frame<'a, Msg: Message> {
145 Hello,
146 Data(&'a mut Msg),
147 SignalUpdate(SignalUpdate),
148 BeginTransfer(NodeId, TransferKey),
149 AckTransfer,
150 EndTransfer,
151 Shutdown(Result<(), zx_status::Status>),
152}
153
154#[derive(Debug)]
155pub(crate) struct StreamReader<Msg: Message> {
156 stream: FramedStreamReader,
157 incoming_message: Msg,
158 router: Weak<Router>,
159 state: ReadNextState<Msg::Parser>,
160}
161
162#[derive(Debug)]
163pub(crate) struct ReadNext<'a, Msg: Message> {
164 read_next_frame_or_peer_conn_ref: ReadNextFrameOrPeerConnRef<'a>,
165 state: &'a mut ReadNextState<Msg::Parser>,
166 conn: PeerConn,
167 incoming_message: Option<&'a mut Msg>,
168 router_holder: RouterHolder<'a>,
169}
170
171enum ReadNextFrameOrPeerConnRef<'a> {
172 ReadNextFrame(BoxFuture<'a, Result<FramedStreamReadResult, Error>>),
173 PeerConnRef(PeerConnRef<'a>),
174}
175
176impl std::fmt::Debug for ReadNextFrameOrPeerConnRef<'_> {
177 fn fmt(&self, writer: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
178 match self {
179 Self::ReadNextFrame(_) => {
180 write!(writer, "ReadNextFrameOrPeerConnRef::ReadNextFrame(...)")
181 }
182 Self::PeerConnRef(x) => {
183 write!(writer, "ReadNextFrameOrPeerConnRef::PeerConnRef({x:?})")
184 }
185 }
186 }
187}
188
189impl<'a> ReadNextFrameOrPeerConnRef<'a> {
190 fn as_read_next_frame_mut(
191 &mut self,
192 ) -> Option<&mut BoxFuture<'a, Result<FramedStreamReadResult, Error>>> {
193 match self {
194 Self::ReadNextFrame(x) => Some(x),
195 _ => None,
196 }
197 }
198}
199
/// Progress of reading the next frame. It is kept in `StreamReader` so it
/// survives across individual `ReadNext` futures.
#[derive(Debug)]
enum ReadNextState<P> {
    /// Waiting for the next frame from the stream.
    Reading,
    /// A data frame's bytes have arrived and are being parsed by `P`.
    DeserializingData(Vec<u8>, P),
}
205
206impl<'a, Msg: Message> ReadNext<'a, Msg> {
207 fn poll_inner(&mut self, ctx: &mut Context<'_>) -> Poll<Result<Frame<'a, Msg>, Error>> {
208 loop {
209 return Poll::Ready(Ok(match *self.state {
210 ReadNextState::Reading => {
211 let (frame_type, mut bytes) = match ready!(self
212 .read_next_frame_or_peer_conn_ref
213 .as_read_next_frame_mut()
214 .unwrap()
215 .poll_unpin(ctx))?
216 {
217 FramedStreamReadResult::Frame(frame_type, bytes) => (frame_type, bytes),
218 FramedStreamReadResult::Closed(Some(e)) => {
219 return Poll::Ready(Err(format_err!("unexpected end of stream ({e})")))
220 }
221 FramedStreamReadResult::Closed(None) => {
222 return Poll::Ready(Err(format_err!("unexpected end of stream")))
223 }
224 };
225
226 match frame_type {
227 FrameType::Hello => {
228 if bytes.len() != 0 {
229 return Poll::Ready(Err(format_err!("Hello frame must be empty")));
230 }
231 Frame::Hello
232 }
233 FrameType::Data => {
234 *self.state =
235 ReadNextState::DeserializingData(bytes, Msg::Parser::new());
236 continue;
237 }
238 FrameType::Signal => Frame::SignalUpdate(decode_fidl(&mut bytes)?),
239 FrameType::Control => match decode_fidl(&mut bytes)? {
240 StreamControl::AckTransfer(Empty {}) => Frame::AckTransfer,
241 StreamControl::EndTransfer(Empty {}) => Frame::EndTransfer,
242 StreamControl::Shutdown(status_code) => {
243 Frame::Shutdown(zx_status::Status::ok(status_code))
244 }
245
246 StreamControl::BeginTransfer(BeginTransfer {
247 new_destination_node,
248 transfer_key,
249 }) => Frame::BeginTransfer(new_destination_node.into(), transfer_key),
250 },
251 }
252 }
253 ReadNextState::DeserializingData(ref mut bytes, ref mut parser) => {
254 ready!(parser.poll_ser(
255 self.incoming_message.as_mut().unwrap(),
256 bytes,
257 self.conn.as_ref(),
258 &mut self.router_holder,
259 ctx,
260 ))?;
261 *self.state = ReadNextState::Reading;
262 Frame::Data(self.incoming_message.take().unwrap())
263 }
264 }));
265 }
266 }
267}
268
269impl<'a, Msg: Message> Future for ReadNext<'a, Msg> {
270 type Output = Result<Frame<'a, Msg>, Error>;
271 fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
272 Pin::into_inner(self).poll_inner(ctx)
273 }
274}
275
276impl<Msg: Message> StreamReader<Msg> {
277 pub fn conn(&self) -> PeerConnRef<'_> {
278 self.stream.conn()
279 }
280
281 pub fn is_initiator(&self) -> bool {
282 self.stream.is_initiator()
283 }
284
285 pub fn next<'a>(&'a mut self) -> ReadNext<'a, Msg> {
286 let conn = self.stream.conn().into_peer_conn();
287 ReadNext {
288 read_next_frame_or_peer_conn_ref: match self.state {
289 ReadNextState::Reading => {
290 ReadNextFrameOrPeerConnRef::ReadNextFrame(self.stream.next().boxed())
291 }
292 ReadNextState::DeserializingData(_, _) => {
293 ReadNextFrameOrPeerConnRef::PeerConnRef(self.stream.conn())
294 }
295 },
296 state: &mut self.state,
297 conn,
298 incoming_message: Some(&mut self.incoming_message),
299 router_holder: RouterHolder::Unused(&self.router),
300 }
301 }
302
303 async fn expect(&mut self, frame: Frame<'_, Msg>) -> Result<(), Error> {
304 let received = self.next().await?;
305 if received != frame {
306 let msg = format_err!("Expected {:?} got {:?}", frame, received);
307 self.stream.abandon().await;
308 Err(msg)
309 } else {
310 Ok(())
311 }
312 }
313
314 pub async fn expect_ack_transfer(mut self) -> Result<(), Error> {
315 self.expect(Frame::AckTransfer).await
316 }
317
318 pub async fn expect_hello(&mut self) -> Result<(), Error> {
319 self.expect(Frame::Hello).await
320 }
321
322 pub async fn expect_shutdown(
323 mut self,
324 result: Result<(), zx_status::Status>,
325 ) -> Result<(), Error> {
326 self.expect(Frame::Shutdown(result)).await
327 }
328}
329
330pub(crate) trait StreamReaderBinder {
331 fn bind<Msg: Message, H: Proxyable<Message = Msg>>(
332 self,
333 hdl: &ProxyableHandle<H>,
334 ) -> StreamReader<Msg>;
335}
336
337impl StreamReaderBinder for FramedStreamReader {
338 fn bind<Msg: Message, H: Proxyable<Message = Msg>>(
339 self,
340 hdl: &ProxyableHandle<H>,
341 ) -> StreamReader<Msg> {
342 StreamReader {
343 stream: self,
344 incoming_message: Default::default(),
345 router: hdl.router().clone(),
346 state: ReadNextState::Reading,
347 }
348 }
349}