overnet_core/peer/
framed_stream.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Framing and deframing datagrams onto QUIC streams
6
7use super::PeerConnRef;
8use crate::labels::NodeId;
9use anyhow::{format_err, Error};
10
/// Kind of frame carried on a framed stream.
///
/// The variant is encoded as a numeric type code in the upper 32 bits of the serialized frame
/// header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FrameType {
    /// Serialized with type code 0.
    Hello,
    /// Serialized with type code 4.
    Data,
    /// Serialized with type code 5.
    Control,
    /// Serialized with type code 6.
    Signal,
}
19
20/// Header for one frame of data on a QUIC stream
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22struct FrameHeader {
23    /// Type of the frame
24    frame_type: FrameType,
25    /// Length of the frame (usize here to avoid casts in client code; this is checked to fit in a
26    /// u32 before serialization)
27    length: usize,
28}
29
/// Number of bytes occupied by a serialized frame header (the header is one little-endian `u64`).
const FRAME_HEADER_LENGTH: usize = std::mem::size_of::<u64>();
32
33impl FrameHeader {
34    fn to_bytes(&self) -> Result<[u8; FRAME_HEADER_LENGTH], Error> {
35        let length = self.length;
36        if length > std::u32::MAX as usize {
37            return Err(anyhow::format_err!("Message too long: {}", length));
38        }
39        let length = length as u32;
40        let hdr: u64 = (length as u64)
41            | (match self.frame_type {
42                FrameType::Hello => 0,
43                FrameType::Data => 4,
44                FrameType::Control => 5,
45                FrameType::Signal => 6,
46            } << 32);
47        Ok(hdr.to_le_bytes())
48    }
49
50    fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
51        let hdr: &[u8; FRAME_HEADER_LENGTH] = bytes[0..FRAME_HEADER_LENGTH].try_into()?;
52        let hdr = u64::from_le_bytes(*hdr);
53        let length = (hdr & 0xffff_ffff) as usize;
54        let frame_type = match hdr >> 32 {
55            0 => FrameType::Hello,
56            1 | 2 | 3 => return Err(anyhow::format_err!("Frame with no persistence header")),
57            4 => FrameType::Data,
58            5 => FrameType::Control,
59            6 => FrameType::Signal,
60            _ => return Err(anyhow::format_err!("Unknown frame type {}", hdr >> 32)),
61        };
62        Ok(FrameHeader { frame_type, length })
63    }
64}
65
66#[derive(Debug)]
67pub(crate) struct FramedStreamWriter {
68    /// Underlying writer
69    writer: circuit::stream::Writer,
70    /// The circuit's ID number
71    id: u64,
72    /// The connection supporting the writer,
73    conn: circuit::Connection,
74    /// The peer node id
75    peer_node_id: NodeId,
76}
77
78impl FramedStreamWriter {
79    pub fn from_circuit(
80        writer: circuit::stream::Writer,
81        id: u64,
82        conn: circuit::Connection,
83        peer_node_id: NodeId,
84    ) -> Self {
85        Self { writer, id, conn, peer_node_id }
86    }
87
88    pub async fn abandon(&mut self) {
89        let (_reader, dead_writer) = circuit::stream::stream();
90        self.writer = dead_writer;
91    }
92
93    pub fn conn(&self) -> PeerConnRef<'_> {
94        PeerConnRef::from_circuit(&self.conn, self.peer_node_id)
95    }
96
97    pub fn id(&self) -> u64 {
98        self.id
99    }
100
101    pub async fn send(&mut self, frame_type: FrameType, bytes: &[u8]) -> Result<(), Error> {
102        let r = self.send_inner(frame_type, bytes).await;
103        if r.is_err() {
104            self.abandon().await;
105        }
106        r
107    }
108
109    async fn send_inner(&mut self, frame_type: FrameType, bytes: &[u8]) -> Result<(), Error> {
110        let frame_len = bytes.len();
111        assert!(frame_len <= 0xffff_ffff);
112        let header = FrameHeader { frame_type, length: frame_len }.to_bytes()?;
113        log::trace!(header:?; "");
114        self.writer.write(header.len(), |buf| {
115            buf[..header.len()].copy_from_slice(&header);
116            Ok(header.len())
117        })?;
118
119        if !bytes.is_empty() {
120            self.writer.write(bytes.len(), |buf| {
121                buf[..bytes.len()].copy_from_slice(bytes);
122                Ok(bytes.len())
123            })?;
124        }
125        Ok(())
126    }
127}
128
129pub(crate) enum FramedStreamReadResult {
130    Frame(FrameType, Vec<u8>),
131    Closed(Option<String>),
132}
133
134#[derive(Debug)]
135pub(crate) struct FramedStreamReader {
136    /// The underlying reader,
137    reader: circuit::stream::Reader,
138    /// The connection supporting th reader.
139    conn: circuit::Connection,
140    /// Peer node id
141    peer_node_id: NodeId,
142    /// Current read state
143    read_state: ReadState,
144    /// Scratch space for reading the frame header
145    hdr: [u8; FRAME_HEADER_LENGTH],
146}
147
148impl FramedStreamReader {
149    pub fn from_circuit(
150        reader: circuit::stream::Reader,
151        conn: circuit::Connection,
152        peer_node_id: NodeId,
153    ) -> Self {
154        Self {
155            reader,
156            conn,
157            peer_node_id,
158            read_state: ReadState::Initial,
159            hdr: [0u8; FRAME_HEADER_LENGTH],
160        }
161    }
162
163    pub(crate) async fn abandon(&mut self) {
164        let (dead_reader, _writer) = circuit::stream::stream();
165        self.reader = dead_reader;
166    }
167
168    pub fn conn(&self) -> PeerConnRef<'_> {
169        PeerConnRef::from_circuit(&self.conn, self.peer_node_id)
170    }
171
172    pub fn is_initiator(&self) -> bool {
173        self.conn.is_client()
174    }
175
176    pub(crate) async fn next<'b>(&'b mut self) -> Result<FramedStreamReadResult, Error> {
177        if let ReadState::Initial = self.read_state {
178            if !read_exact(&self.reader, &mut self.hdr).await? {
179                return Ok(FramedStreamReadResult::Closed(self.reader.closed_reason()));
180            }
181            let hdr = FrameHeader::from_bytes(&self.hdr)?;
182
183            if hdr.length == 0 {
184                return Ok(FramedStreamReadResult::Frame(hdr.frame_type, Vec::new()));
185            }
186
187            self.read_state = ReadState::GotHeader(hdr);
188        }
189
190        let ReadState::GotHeader(hdr) = &self.read_state else {
191            unreachable!();
192        };
193
194        let mut payload = vec![0; hdr.length];
195        payload.resize(hdr.length, 0);
196        if !read_exact(&self.reader, &mut payload).await? {
197            return Err(format_err!("Unexpected end of stream"));
198        }
199        let frame_type = hdr.frame_type;
200        self.read_state = ReadState::Initial;
201        Ok(FramedStreamReadResult::Frame(frame_type, payload))
202    }
203}
204
205async fn read_exact(reader: &circuit::stream::Reader, buf: &mut [u8]) -> Result<bool, Error> {
206    reader
207        .read(buf.len(), |input| {
208            buf.copy_from_slice(&input[..buf.len()]);
209            Ok((true, buf.len()))
210        })
211        .await
212        .or_else(|x| match x {
213            circuit::Error::ConnectionClosed(reason) => {
214                if let Some(reason) = reason {
215                    log::debug!(reason:?; "");
216                }
217                Ok(false)
218            }
219            other => Err(other.into()),
220        })
221}
222
223#[derive(Debug)]
224enum ReadState {
225    Initial,
226    GotHeader(FrameHeader),
227}
228
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that `h` is unchanged by serialization followed by deserialization.
    fn roundtrip(h: FrameHeader) {
        assert_eq!(h, FrameHeader::from_bytes(&h.to_bytes().unwrap()).unwrap());
    }

    /// Every frame type must round-trip at the smallest and largest representable lengths.
    #[fuchsia::test]
    fn roundtrips() {
        for frame_type in
            [FrameType::Hello, FrameType::Data, FrameType::Control, FrameType::Signal]
        {
            for length in [0, 1, u32::MAX as usize] {
                roundtrip(FrameHeader { frame_type, length });
            }
        }
    }

    /// A length one past `u32::MAX` cannot be serialized.
    #[fuchsia::test]
    fn too_long() {
        FrameHeader { frame_type: FrameType::Data, length: (u32::MAX as usize) + 1 }
            .to_bytes()
            .expect_err("Should fail");
    }

    /// An unassigned type code is reported together with its numeric value.
    #[fuchsia::test]
    fn bad_frame_type() {
        let err = FrameHeader::from_bytes(&[0, 0, 0, 0, 11, 0, 0, 0]).expect_err("should fail");
        assert!(err.to_string().contains("Unknown frame type 11"));
    }

    /// Type codes 1 through 3 are rejected with their own dedicated error.
    #[fuchsia::test]
    fn reserved_frame_type() {
        for code in 1u8..=3 {
            let err = FrameHeader::from_bytes(&[0, 0, 0, 0, code, 0, 0, 0])
                .expect_err("should fail");
            assert!(err.to_string().contains("Frame with no persistence header"));
        }
    }
}