1use super::PeerConnRef;
8use crate::labels::NodeId;
9use anyhow::{format_err, Error};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq)]
13pub(crate) enum FrameType {
14 Hello,
15 Data,
16 Control,
17 Signal,
18}
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22struct FrameHeader {
23 frame_type: FrameType,
25 length: usize,
28}
29
30const FRAME_HEADER_LENGTH: usize = 8;
32
33impl FrameHeader {
34 fn to_bytes(&self) -> Result<[u8; FRAME_HEADER_LENGTH], Error> {
35 let length = self.length;
36 if length > std::u32::MAX as usize {
37 return Err(anyhow::format_err!("Message too long: {}", length));
38 }
39 let length = length as u32;
40 let hdr: u64 = (length as u64)
41 | (match self.frame_type {
42 FrameType::Hello => 0,
43 FrameType::Data => 4,
44 FrameType::Control => 5,
45 FrameType::Signal => 6,
46 } << 32);
47 Ok(hdr.to_le_bytes())
48 }
49
50 fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
51 let hdr: &[u8; FRAME_HEADER_LENGTH] = bytes[0..FRAME_HEADER_LENGTH].try_into()?;
52 let hdr = u64::from_le_bytes(*hdr);
53 let length = (hdr & 0xffff_ffff) as usize;
54 let frame_type = match hdr >> 32 {
55 0 => FrameType::Hello,
56 1 | 2 | 3 => return Err(anyhow::format_err!("Frame with no persistence header")),
57 4 => FrameType::Data,
58 5 => FrameType::Control,
59 6 => FrameType::Signal,
60 _ => return Err(anyhow::format_err!("Unknown frame type {}", hdr >> 32)),
61 };
62 Ok(FrameHeader { frame_type, length })
63 }
64}
65
66#[derive(Debug)]
67pub(crate) struct FramedStreamWriter {
68 writer: circuit::stream::Writer,
70 id: u64,
72 conn: circuit::Connection,
74 peer_node_id: NodeId,
76}
77
78impl FramedStreamWriter {
79 pub fn from_circuit(
80 writer: circuit::stream::Writer,
81 id: u64,
82 conn: circuit::Connection,
83 peer_node_id: NodeId,
84 ) -> Self {
85 Self { writer, id, conn, peer_node_id }
86 }
87
88 pub async fn abandon(&mut self) {
89 let (_reader, dead_writer) = circuit::stream::stream();
90 self.writer = dead_writer;
91 }
92
93 pub fn conn(&self) -> PeerConnRef<'_> {
94 PeerConnRef::from_circuit(&self.conn, self.peer_node_id)
95 }
96
97 pub fn id(&self) -> u64 {
98 self.id
99 }
100
101 pub async fn send(&mut self, frame_type: FrameType, bytes: &[u8]) -> Result<(), Error> {
102 let r = self.send_inner(frame_type, bytes).await;
103 if r.is_err() {
104 self.abandon().await;
105 }
106 r
107 }
108
109 async fn send_inner(&mut self, frame_type: FrameType, bytes: &[u8]) -> Result<(), Error> {
110 let frame_len = bytes.len();
111 assert!(frame_len <= 0xffff_ffff);
112 let header = FrameHeader { frame_type, length: frame_len }.to_bytes()?;
113 log::trace!(header:?; "");
114 self.writer.write(header.len(), |buf| {
115 buf[..header.len()].copy_from_slice(&header);
116 Ok(header.len())
117 })?;
118
119 if !bytes.is_empty() {
120 self.writer.write(bytes.len(), |buf| {
121 buf[..bytes.len()].copy_from_slice(bytes);
122 Ok(bytes.len())
123 })?;
124 }
125 Ok(())
126 }
127}
128
129pub(crate) enum FramedStreamReadResult {
130 Frame(FrameType, Vec<u8>),
131 Closed(Option<String>),
132}
133
134#[derive(Debug)]
135pub(crate) struct FramedStreamReader {
136 reader: circuit::stream::Reader,
138 conn: circuit::Connection,
140 peer_node_id: NodeId,
142 read_state: ReadState,
144 hdr: [u8; FRAME_HEADER_LENGTH],
146}
147
148impl FramedStreamReader {
149 pub fn from_circuit(
150 reader: circuit::stream::Reader,
151 conn: circuit::Connection,
152 peer_node_id: NodeId,
153 ) -> Self {
154 Self {
155 reader,
156 conn,
157 peer_node_id,
158 read_state: ReadState::Initial,
159 hdr: [0u8; FRAME_HEADER_LENGTH],
160 }
161 }
162
163 pub(crate) async fn abandon(&mut self) {
164 let (dead_reader, _writer) = circuit::stream::stream();
165 self.reader = dead_reader;
166 }
167
168 pub fn conn(&self) -> PeerConnRef<'_> {
169 PeerConnRef::from_circuit(&self.conn, self.peer_node_id)
170 }
171
172 pub fn is_initiator(&self) -> bool {
173 self.conn.is_client()
174 }
175
176 pub(crate) async fn next<'b>(&'b mut self) -> Result<FramedStreamReadResult, Error> {
177 if let ReadState::Initial = self.read_state {
178 if !read_exact(&self.reader, &mut self.hdr).await? {
179 return Ok(FramedStreamReadResult::Closed(self.reader.closed_reason()));
180 }
181 let hdr = FrameHeader::from_bytes(&self.hdr)?;
182
183 if hdr.length == 0 {
184 return Ok(FramedStreamReadResult::Frame(hdr.frame_type, Vec::new()));
185 }
186
187 self.read_state = ReadState::GotHeader(hdr);
188 }
189
190 let ReadState::GotHeader(hdr) = &self.read_state else {
191 unreachable!();
192 };
193
194 let mut payload = vec![0; hdr.length];
195 payload.resize(hdr.length, 0);
196 if !read_exact(&self.reader, &mut payload).await? {
197 return Err(format_err!("Unexpected end of stream"));
198 }
199 let frame_type = hdr.frame_type;
200 self.read_state = ReadState::Initial;
201 Ok(FramedStreamReadResult::Frame(frame_type, payload))
202 }
203}
204
205async fn read_exact(reader: &circuit::stream::Reader, buf: &mut [u8]) -> Result<bool, Error> {
206 reader
207 .read(buf.len(), |input| {
208 buf.copy_from_slice(&input[..buf.len()]);
209 Ok((true, buf.len()))
210 })
211 .await
212 .or_else(|x| match x {
213 circuit::Error::ConnectionClosed(reason) => {
214 if let Some(reason) = reason {
215 log::debug!(reason:?; "");
216 }
217 Ok(false)
218 }
219 other => Err(other.into()),
220 })
221}
222
223#[derive(Debug)]
224enum ReadState {
225 Initial,
226 GotHeader(FrameHeader),
227}
228
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that `h` is unchanged by an encode/decode cycle.
    fn roundtrip(h: FrameHeader) {
        assert_eq!(h, FrameHeader::from_bytes(&h.to_bytes().unwrap()).unwrap());
    }

    #[fuchsia::test]
    fn roundtrips() {
        // Cover every frame type at both ends of the encodable length range.
        for frame_type in [FrameType::Hello, FrameType::Data, FrameType::Control, FrameType::Signal]
        {
            roundtrip(FrameHeader { frame_type, length: 0 });
            roundtrip(FrameHeader { frame_type, length: u32::MAX as usize });
        }
    }

    #[fuchsia::test]
    fn wire_layout() {
        // Low 32 bits: length (little-endian). High 32 bits: type code (Control = 5).
        let bytes =
            FrameHeader { frame_type: FrameType::Control, length: 0x0102_0304 }.to_bytes().unwrap();
        assert_eq!(bytes, [4, 3, 2, 1, 5, 0, 0, 0]);

        // Bytes after the header are ignored when decoding.
        let decoded = FrameHeader::from_bytes(&[4, 3, 2, 1, 5, 0, 0, 0, 0xff]).unwrap();
        assert_eq!(decoded, FrameHeader { frame_type: FrameType::Control, length: 0x0102_0304 });
    }

    #[fuchsia::test]
    fn too_long() {
        // A length above u32::MAX is only representable where usize is wider than 32 bits.
        if let Some(length) = (u32::MAX as usize).checked_add(1) {
            FrameHeader { frame_type: FrameType::Data, length }
                .to_bytes()
                .expect_err("Should fail");
        }
    }

    #[fuchsia::test]
    fn bad_frame_type() {
        assert!(FrameHeader::from_bytes(&[0, 0, 0, 0, 11, 0, 0, 0])
            .expect_err("should fail")
            .to_string()
            .contains("Unknown frame type 11"));
    }

    #[fuchsia::test]
    fn frame_without_persistence_header() {
        for code in 1u8..=3 {
            let err = FrameHeader::from_bytes(&[0, 0, 0, 0, code, 0, 0, 0])
                .expect_err("should fail");
            assert!(err.to_string().contains("Frame with no persistence header"));
        }
    }
}