1use crate::{stream, Error, Node, Result};
6
7use futures::channel::mpsc::{channel, Receiver, Sender};
8use futures::channel::oneshot;
9use futures::lock::Mutex;
10use futures::stream::Stream;
11use futures::StreamExt;
12use rand::random;
13use std::collections::hash_map::Entry;
14use std::collections::HashMap;
15use std::io::Write;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::{Arc, Weak};
18use std::time::Duration;
19
20impl crate::protocol::ProtocolMessage for u64 {
22 const MIN_SIZE: usize = 8;
23
24 fn write_bytes<W: Write>(&self, out: &mut W) -> Result<usize> {
25 out.write_all(&self.to_le_bytes())?;
26 Ok(8)
27 }
28
29 fn byte_size(&self) -> usize {
30 8
31 }
32
33 fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
34 Ok((u64::from_le_bytes(bytes.try_into().map_err(|_| Error::BufferTooShort(8))?), 8))
35 }
36}
37
38enum StreamMapEntry {
40 Waiting(oneshot::Sender<(stream::Reader, stream::Writer)>),
43 Ready(stream::Reader, stream::Writer),
46 Taken,
50}
51
/// The role this side plays in a connection.
///
/// Connections opened locally via `ConnectionNode::connect_to_peer` are registered as
/// `Client`; connections accepted from a peer's stream 0 in `conn_stream` are `Server`.
#[derive(Copy, Clone, Debug)]
enum ClientOrServer {
    Client,
    Server,
}

impl ClientOrServer {
    /// `true` for the `Server` role, `false` for `Client`.
    fn is_server(&self) -> bool {
        match self {
            Self::Client => false,
            Self::Server => true,
        }
    }
}
63
64impl StreamMapEntry {
65 fn take(&mut self) -> Option<(stream::Reader, stream::Writer)> {
68 match std::mem::replace(self, Self::Taken) {
69 Self::Waiting(_) | Self::Taken => None,
70 Self::Ready(r, w) => Some((r, w)),
71 }
72 }
73
74 fn ready(&mut self, reader: stream::Reader, writer: stream::Writer) {
78 if let Self::Waiting(sender) = std::mem::replace(self, Self::Taken) {
79 if let Err((reader, writer)) = sender.send((reader, writer)) {
80 *self = Self::Ready(reader, writer)
81 }
82 }
83 }
84}
85
86type StreamMap = Mutex<HashMap<u64, StreamMapEntry>>;
88
89#[derive(Clone)]
103pub struct Connection {
104 id: u64,
105 streams: Arc<StreamMap>,
106 node: Arc<Node>,
107 peer_node_id: String,
108 next_stream_id: Arc<AtomicU64>,
109}
110
111impl Connection {
112 pub async fn bind_stream(&self, id: u64) -> Option<(stream::Reader, stream::Writer)> {
113 let receiver = {
114 match self.streams.lock().await.entry(id) {
115 Entry::Occupied(mut e) => return e.get_mut().take(),
116 Entry::Vacant(v) => {
117 let (sender, receiver) = oneshot::channel();
118 v.insert(StreamMapEntry::Waiting(sender));
119 receiver
120 }
121 }
122 };
123
124 receiver.await.ok()
125 }
126
127 pub fn from(&self) -> &str {
128 &self.peer_node_id
129 }
130
131 pub async fn alloc_stream(
133 &self,
134 reader: stream::Reader,
135 writer: stream::Writer,
136 ) -> Result<u64> {
137 let id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);
138 reader.push_back_protocol_message(&id)?;
139 reader.push_back_protocol_message(&self.id)?;
140 self.node.connect_to_peer(reader, writer, &self.peer_node_id).await?;
141 Ok(id)
142 }
143
144 pub fn is_client(&self) -> bool {
147 is_client_stream_id(self.next_stream_id.load(Ordering::Relaxed))
148 }
149}
150
151impl std::fmt::Debug for Connection {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 write!(f, "Connection({:#x} to {})", self.id, self.peer_node_id)
154 }
155}
156
157pub struct ConnectionNode {
160 node: Arc<Node>,
161 conns: Arc<Mutex<HashMap<u64, (Weak<StreamMap>, ClientOrServer)>>>,
162}
163
/// Returns `true` if `id` lies in the client's stream-ID space.
///
/// The client side numbers its streams with even IDs: stream 0 opens the connection, and
/// its counter starts at 2. The server side starts at 1, and both advance by 2, so the low
/// bit tells which side allocated a stream.
fn is_client_stream_id(id: u64) -> bool {
    id % 2 == 0
}
170
171impl ConnectionNode {
172 pub fn new(
181 node_id: &str,
182 protocol: &str,
183 new_peer_sender: Sender<String>,
184 ) -> Result<(ConnectionNode, impl Stream<Item = Connection> + Send)> {
185 let (incoming_stream_sender, incoming_stream_receiver) = channel(1);
186 let node = Arc::new(Node::new(node_id, protocol, new_peer_sender, incoming_stream_sender)?);
187 let conns = Arc::new(Mutex::new(HashMap::<u64, (Weak<StreamMap>, ClientOrServer)>::new()));
188 let conn_stream =
189 conn_stream(Arc::downgrade(&node), Arc::clone(&conns), incoming_stream_receiver);
190
191 Ok((ConnectionNode { node, conns }, conn_stream))
192 }
193
194 pub fn new_with_router(
196 node_id: &str,
197 protocol: &str,
198 interval: Duration,
199 new_peer_sender: Sender<String>,
200 ) -> Result<(ConnectionNode, impl Stream<Item = Connection> + Send)> {
201 let (incoming_stream_sender, incoming_stream_receiver) = channel(1);
202 let (node, router) = Node::new_with_router(
203 node_id,
204 protocol,
205 interval,
206 new_peer_sender,
207 incoming_stream_sender,
208 )?;
209 let node = Arc::new(node);
210 let conns = Arc::new(Mutex::new(HashMap::<u64, (Weak<StreamMap>, ClientOrServer)>::new()));
211
212 let conn_stream =
214 conn_stream(Arc::downgrade(&node), Arc::clone(&conns), incoming_stream_receiver)
215 .map(Some);
216 let router_stream = futures::stream::once(router).map(|()| None);
217 let conn_stream =
218 futures::stream::select(conn_stream, router_stream).filter_map(|x| async move { x });
219
220 Ok((ConnectionNode { node, conns }, conn_stream))
221 }
222
223 pub async fn connect_to_peer(
226 &self,
227 node_id: &str,
228 connection_reader: stream::Reader,
229 connection_writer: stream::Writer,
230 ) -> Result<Connection> {
231 if &*self.node.node_id() == node_id {
232 return Err(Error::LoopbackUnsupported);
233 }
234
235 let id = random();
236
237 connection_reader.push_back_protocol_message(&0u64)?;
240 connection_reader.push_back_protocol_message(&id)?;
241 self.node.connect_to_peer(connection_reader, connection_writer, node_id).await?;
242 let streams = Arc::new(Mutex::new(HashMap::new()));
243
244 self.conns.lock().await.insert(id, (Arc::downgrade(&streams), ClientOrServer::Client));
245
246 Ok(Connection {
247 id,
248 streams,
249 node: Arc::clone(&self.node),
250 peer_node_id: node_id.to_string(),
251 next_stream_id: Arc::new(AtomicU64::new(2)),
252 })
253 }
254
255 pub fn node(&self) -> &Node {
257 &*self.node
258 }
259}
260
261fn conn_stream(
267 node: Weak<Node>,
268 conns: Arc<Mutex<HashMap<u64, (Weak<StreamMap>, ClientOrServer)>>>,
269 incoming_stream_receiver: Receiver<(stream::Reader, stream::Writer, String)>,
270) -> impl Stream<Item = Connection> + Send {
271 incoming_stream_receiver.filter_map(move |(reader, writer, peer_node_id)| {
272 let conns = Arc::clone(&conns);
273 let node = node.upgrade();
274 async move {
275 let node = node?;
276 let got = reader
277 .read(16, |buf| {
278 Ok((
279 (
280 u64::from_le_bytes(buf[..8].try_into().unwrap()),
281 u64::from_le_bytes(buf[8..16].try_into().unwrap()),
282 ),
283 16,
284 ))
285 })
286 .await;
287
288 let (conn_id, stream_id) = match got {
289 Ok(got) => got,
290 Err(Error::ConnectionClosed(reason)) => {
291 let reason = reason.as_deref().unwrap_or("(No reason given)");
292 log::warn!(reason; "New stream closed without associating with a connection");
293 return None;
294 }
295 _ => unreachable!("Deserializing the connection ID should never fail!"),
296 };
297
298 let mut conns = conns.lock().await;
299
300 if let Some((streams, client_or_server)) = conns.get(&conn_id) {
301 if let Some(streams) = streams.upgrade() {
302 if is_client_stream_id(stream_id) == client_or_server.is_server() {
303 match streams.lock().await.entry(stream_id) {
304 Entry::Occupied(mut o) => o.get_mut().ready(reader, writer),
305 Entry::Vacant(v) => {
306 v.insert(StreamMapEntry::Ready(reader, writer));
307 }
308 }
309 } else {
310 log::warn!(stream_id, end:? = client_or_server.is_server();
311 "Peer initiated stream ID which does not match role");
312 }
313 }
314 None
315 } else if stream_id == 0 {
316 let mut streams = HashMap::new();
317 streams.insert(stream_id, StreamMapEntry::Ready(reader, writer));
318 let streams = Arc::new(Mutex::new(streams));
319 conns.insert(conn_id, (Arc::downgrade(&streams), ClientOrServer::Server));
320 Some(Connection {
321 id: conn_id,
322 streams,
323 node,
324 peer_node_id,
325 next_stream_id: Arc::new(AtomicU64::new(1)),
326 })
327 } else {
328 log::warn!(conn_id, stream_id; "Connection does not exist");
329 None
330 }
331 }
332 })
333}