1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
45use crate::error::{Error, ExtendBufferTooShortError, Result};
6use std::io::Write;
78/// A serialization trait for objects which we convert to and from a wire format.
9pub trait ProtocolMessage: Sized {
10/// The minimum size an object of this type can have when serialized via `write_bytes`.
11const MIN_SIZE: usize;
12/// Encode this value into the given buffer as a stream of bytes.
13fn write_bytes<W: Write>(&self, out: &mut W) -> Result<usize>;
14/// Returns the size of the data that `write_bytes` will write for this value. Useful for buffer
15 /// allocation.
16fn byte_size(&self) -> usize;
17/// Try to read a serialized form of this value from the given buffer. On success returns both
18 /// the value and how many bytes were consumed. If we return `Error::BufferTooShort`, we may
19 /// have only part of the value and can try again with an extension of the same data.
20fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)>;
21}
/// A string that is guaranteed to fit the common wire encoding for strings: a single length
/// byte (`u8`) followed by that many bytes of UTF-8.
///
/// Because the length prefix is one byte, the UTF-8 encoding of the string can be at most 255
/// bytes long. An `EncodableString` can only be built from a string that respects that limit,
/// so code holding one can serialize it without any further error handling (or, worse,
/// unwraps!). Apart from that guarantee it is meant to be used much like a plain `String`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct EncodableString(String);
3233impl ProtocolMessage for EncodableString {
34const MIN_SIZE: usize = 1;
35fn write_bytes<W: Write>(&self, out: &mut W) -> Result<usize> {
36let len: u8 =
37self.0.as_bytes().len().try_into().expect("EncodableString wasn't encodable!");
38 out.write_all(&[len])?;
39 out.write_all(self.0.as_bytes())?;
40Ok(usize::from(len) + 1)
41 }
4243fn byte_size(&self) -> usize {
44self.0.as_bytes().len() + 1
45}
4647fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
48if bytes.is_empty() {
49Err(Error::BufferTooShort(1))
50 } else if bytes.len() - 1 < bytes[0] as usize {
51Err(Error::BufferTooShort(bytes[0] as usize + 1))
52 } else {
53let len = bytes[0] as usize;
54let bytes = &bytes[1..][..len];
55Ok((
56 std::str::from_utf8(bytes)
57 .map_err(|_| Error::BadUTF8(String::from_utf8_lossy(bytes).to_string()))?
58.to_owned()
59 .try_into()
60 .expect("String wasn't decodable right after encoding!"),
61 len + 1,
62 ))
63 }
64 }
65}
6667impl TryFrom<String> for EncodableString {
68type Error = Error;
69fn try_from(src: String) -> Result<EncodableString> {
70let _: u8 =
71 src.as_bytes().len().try_into().map_err(|_| Error::StringTooBig(src.clone()))?;
72Ok(EncodableString(src))
73 }
74}
7576impl std::ops::Deref for EncodableString {
77type Target = String;
78fn deref(&self) -> &String {
79&self.0
80}
81}
8283impl<T> PartialEq<T> for EncodableString
84where
85String: PartialEq<T>,
86{
87fn eq(&self, other: &T) -> bool {
88 PartialEq::eq(&self.0, other)
89 }
90}
9192impl std::fmt::Display for EncodableString {
93fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 std::fmt::Display::fmt(&self.0, f)
95 }
96}
9798/// The initial packet that goes out on the main control stream when a circuit node first connects
99/// to another. Contains basic version info, an implementation-usable string indicating what
100/// protocol will be run atop the circuit network.
101#[derive(Debug)]
102pub struct Identify {
103pub circuit_version: u8,
104pub protocol: EncodableString,
105}
106107impl Identify {
108/// Construct a new Identify header.
109pub fn new(protocol: EncodableString) -> Self {
110 Identify { circuit_version: crate::CIRCUIT_VERSION, protocol }
111 }
112}
113114impl ProtocolMessage for Identify {
115const MIN_SIZE: usize = 1 + EncodableString::MIN_SIZE;
116fn byte_size(&self) -> usize {
117self.protocol.byte_size() + 1
118}
119120fn write_bytes<W: Write>(&self, out: &mut W) -> Result<usize> {
121let mut bytes = 0;
122 out.write_all(&[self.circuit_version])?;
123 bytes += 1;
124 bytes += self.protocol.write_bytes(out)?;
125Ok(bytes)
126 }
127128fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
129if bytes.len() < 2 {
130return Err(Error::BufferTooShort(2));
131 }
132133let circuit_version = bytes[0];
134let (protocol, proto_len) =
135 EncodableString::try_from_bytes(&bytes[1..]).extend_buffer_too_short(1)?;
136137Ok((Identify { circuit_version, protocol }, 1 + proto_len))
138 }
139}
/// Information about the quality of a link. A lower value for the contained `u8` is better, with
/// 0 usually meaning a node linked to itself with no intermediate connection.
///
/// The value is kept below 255: that byte has a reserved meaning in the wire encoding (it marks
/// an offline [`NodeState`]). Build a `Quality` from a raw byte with `TryFrom<u8>`, which
/// rejects 255; the constants and [`Quality::combine`] never produce it either.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Quality(u8);

impl Quality {
    /// Quality of connecting from a node to itself in a loop.
    pub const SELF: Quality = Quality(0);
    /// Quality of connecting two nodes in the same process directly.
    pub const IN_PROCESS: Quality = Quality(1);
    /// Quality of connecting two nodes over a local IPC mechanism.
    pub const LOCAL_SOCKET: Quality = Quality(2);
    /// Quality of connecting two nodes via a USB link.
    pub const USB: Quality = Quality(5);
    /// Quality of connecting two nodes over the network.
    pub const NETWORK: Quality = Quality(20);
    /// Worst quality value (254); 255 is reserved and never used as a quality.
    pub const WORST: Quality = Quality(u8::MAX - 1);
    /// Unknown quality value; treated the same as [`Quality::WORST`].
    pub const UNKNOWN: Quality = Quality::WORST;

    /// Add two quality values together. If we are routing a stream across two links, we can add
    /// the quality of those links to get the quality of the combined link the stream is on.
    ///
    /// The sum saturates at [`Quality::WORST`], so the result never takes the reserved value 255.
    pub fn combine(self, other: Self) -> Self {
        Quality(self.0.saturating_add(other.0).min(Self::WORST.0))
    }
}

impl TryFrom<u8> for Quality {
    type Error = ();

    /// Wraps `value` as a `Quality`, failing only for the reserved value `u8::MAX`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        if value != u8::MAX {
            Ok(Quality(value))
        } else {
            Err(())
        }
    }
}
180181/// Information about the state of a node. We transmit this information from node to node so that
182/// each node knows what peers are available to establish circuits with.
183#[derive(Debug)]
184pub enum NodeState {
185/// Node is online.
186Online(EncodableString, Quality),
187/// Node is offline.
188Offline(EncodableString),
189}
190191impl NodeState {
192/// Same as `write_bytes` but specifically takes a vector, and thus cannot return an error.
193pub fn write_bytes_vec(&self, out: &mut Vec<u8>) -> usize {
194self.write_bytes(out).expect("Write to vector should't fail but did!")
195 }
196}
197198impl ProtocolMessage for NodeState {
199const MIN_SIZE: usize = 1 + EncodableString::MIN_SIZE;
200fn byte_size(&self) -> usize {
201let s = match self {
202 NodeState::Online(s, _) => s,
203 NodeState::Offline(s) => s,
204 };
205206 s.byte_size() + 1
207}
208209fn write_bytes<W: Write>(&self, out: &mut W) -> Result<usize> {
210let (st, speed) = match self {
211 NodeState::Online(s, quality) => {
212debug_assert!(quality.0 != u8::MAX);
213 (s, quality.0)
214 }
215 NodeState::Offline(s) => (s, u8::MAX),
216 };
217let mut bytes = 0;
218 out.write_all(&[speed])?;
219 bytes += 1;
220 bytes += st.write_bytes(out)?;
221Ok(bytes)
222 }
223224fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
225if bytes.len() < 2 {
226return Err(Error::BufferTooShort(2));
227 }
228229let quality = bytes[0];
230let (node, node_len) =
231 EncodableString::try_from_bytes(&bytes[1..]).extend_buffer_too_short(1)?;
232233let state = if let Ok(quality) = quality.try_into() {
234 NodeState::Online(node, quality)
235 } else {
236 NodeState::Offline(node)
237 };
238239Ok((state, 1 + node_len))
240 }
241}