circuit/
connection.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::{stream, Error, Node, Result};
6
7use futures::channel::mpsc::{channel, Receiver, Sender};
8use futures::channel::oneshot;
9use futures::lock::Mutex;
10use futures::stream::Stream;
11use futures::StreamExt;
12use rand::random;
13use std::collections::hash_map::Entry;
14use std::collections::HashMap;
15use std::io::Write;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::{Arc, Weak};
18use std::time::Duration;
19
20// We're stuffing enough u64s into streams that this is worth doing.
21impl crate::protocol::ProtocolMessage for u64 {
22    const MIN_SIZE: usize = 8;
23
24    fn write_bytes<W: Write>(&self, out: &mut W) -> Result<usize> {
25        out.write_all(&self.to_le_bytes())?;
26        Ok(8)
27    }
28
29    fn byte_size(&self) -> usize {
30        8
31    }
32
33    fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
34        Ok((u64::from_le_bytes(bytes.try_into().map_err(|_| Error::BufferTooShort(8))?), 8))
35    }
36}
37
38/// Entry in a stream map. See `StreamMap` below.
39enum StreamMapEntry {
40    /// The user is expecting the other end of the connection to start a stream with a certain ID,
41    /// but we haven't actually seen the stream show up yet.
42    Waiting(oneshot::Sender<(stream::Reader, stream::Writer)>),
43    /// The other end of the connection has started a stream with a given ID, but we're still
44    /// waiting on the user on this end to accept it by invoking `bind_stream()`.
45    Ready(stream::Reader, stream::Writer),
46    /// A stream with a given ID was started by the other end of the connection, and accepted by
47    /// this end. If we see another stream with that ID something has gone wrong in the protocol
48    /// state machine.
49    Taken,
50}
51
/// Which end of a connection this node is.
#[derive(Copy, Clone, Debug)]
enum ClientOrServer {
    /// Recorded for connections this node opened via `ConnectionNode::connect_to_peer`.
    Client,
    /// Recorded for connections that were opened by a peer and accepted in `conn_stream`.
    Server,
}

impl ClientOrServer {
    /// Returns `true` for the `Server` role and `false` for the `Client` role.
    fn is_server(&self) -> bool {
        match self {
            Self::Server => true,
            Self::Client => false,
        }
    }
}
63
64impl StreamMapEntry {
65    /// Turns the `Ready` state into the `Taken` state and returns the reader and writer that were
66    /// consumed. Returns `None` if we were not in the `Ready` state.
67    fn take(&mut self) -> Option<(stream::Reader, stream::Writer)> {
68        match std::mem::replace(self, Self::Taken) {
69            Self::Waiting(_) | Self::Taken => None,
70            Self::Ready(r, w) => Some((r, w)),
71        }
72    }
73
74    /// Turns the `Waiting` state into the `Taken` state, passing the given reader and writer to the
75    /// waiter. If we're not in the `Waiting` state this drops the reader and writer and does
76    /// nothing.
77    fn ready(&mut self, reader: stream::Reader, writer: stream::Writer) {
78        if let Self::Waiting(sender) = std::mem::replace(self, Self::Taken) {
79            if let Err((reader, writer)) = sender.send((reader, writer)) {
80                *self = Self::Ready(reader, writer)
81            }
82        }
83    }
84}
85
86/// A Mutex-protected map of stream IDs to an optional reader/writer pair for the stream.
87type StreamMap = Mutex<HashMap<u64, StreamMapEntry>>;
88
89/// A connection is a group of streams that span from one node to another on the circuit network.
90///
91/// When we establish a network of circuit nodes, we can create a stream from any one to any other.
92/// These streams are independent, however; they have no name or identifier, nor anything else by
93/// which you might group them.
94///
95/// A `Connection` is a link from a node to a peer that we can obtain streams from, just as we can
96/// from the `Node` itself, but the streams have IDs which are in a namespace unique to the
97/// connection. Multiple connections can exist per node, and each sees only the streams related to
98/// itself.
99///
100/// There's a small bit of added protocol associated with this, so there will be some change to
101/// traffic on the wire.
102#[derive(Clone)]
103pub struct Connection {
104    id: u64,
105    streams: Arc<StreamMap>,
106    node: Arc<Node>,
107    peer_node_id: String,
108    next_stream_id: Arc<AtomicU64>,
109}
110
111impl Connection {
112    pub async fn bind_stream(&self, id: u64) -> Option<(stream::Reader, stream::Writer)> {
113        let receiver = {
114            match self.streams.lock().await.entry(id) {
115                Entry::Occupied(mut e) => return e.get_mut().take(),
116                Entry::Vacant(v) => {
117                    let (sender, receiver) = oneshot::channel();
118                    v.insert(StreamMapEntry::Waiting(sender));
119                    receiver
120                }
121            }
122        };
123
124        receiver.await.ok()
125    }
126
127    pub fn from(&self) -> &str {
128        &self.peer_node_id
129    }
130
131    /// Create a new stream to the other end of this connection.
132    pub async fn alloc_stream(
133        &self,
134        reader: stream::Reader,
135        writer: stream::Writer,
136    ) -> Result<u64> {
137        let id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);
138        reader.push_back_protocol_message(&id)?;
139        reader.push_back_protocol_message(&self.id)?;
140        self.node.connect_to_peer(reader, writer, &self.peer_node_id).await?;
141        Ok(id)
142    }
143
144    /// Whether this connection is a client (initiated by another node) as opposed to a server
145    /// (initiated by this node).
146    pub fn is_client(&self) -> bool {
147        is_client_stream_id(self.next_stream_id.load(Ordering::Relaxed))
148    }
149}
150
151impl std::fmt::Debug for Connection {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        write!(f, "Connection({:#x} to {})", self.id, self.peer_node_id)
154    }
155}
156
157/// Wrapper class for a `Node` that lets us create `Connection` objects instead of creating raw
158/// streams on the node itself.
159pub struct ConnectionNode {
160    node: Arc<Node>,
161    conns: Arc<Mutex<HashMap<u64, (Weak<StreamMap>, ClientOrServer)>>>,
162}
163
/// Reports whether `id` is a stream ID reserved for the client end of a connection.
///
/// The client (the end that initiated the connection) starts streams with even IDs, while the
/// server (the end that received the connection) starts streams with odd IDs.
fn is_client_stream_id(id: u64) -> bool {
    id % 2 == 0
}
170
171impl ConnectionNode {
172    /// Create a new `ConnectionNode`. We can create a `Connection` object for any peer via this
173    /// node, and we can then create streams to that peer and have that peer create streams to us.
174    /// Unlike with a raw `Node`, those streams will be associated with only our connection object,
175    /// and the peer will get a connection object that will return to it only the streams we create
176    /// from this connection object.
177    ///
178    /// Returns both a `ConnectionNode`, and a `futures::stream::Stream` of `Connection` objects,
179    /// which are produced by other nodes connecting to us.
180    pub fn new(
181        node_id: &str,
182        protocol: &str,
183        new_peer_sender: Sender<String>,
184    ) -> Result<(ConnectionNode, impl Stream<Item = Connection> + Send)> {
185        let (incoming_stream_sender, incoming_stream_receiver) = channel(1);
186        let node = Arc::new(Node::new(node_id, protocol, new_peer_sender, incoming_stream_sender)?);
187        let conns = Arc::new(Mutex::new(HashMap::<u64, (Weak<StreamMap>, ClientOrServer)>::new()));
188        let conn_stream =
189            conn_stream(Arc::downgrade(&node), Arc::clone(&conns), incoming_stream_receiver);
190
191        Ok((ConnectionNode { node, conns }, conn_stream))
192    }
193
194    /// Like `ConnectionNode::new` but creates a router task as well. See `Node::new_with_router`.
195    pub fn new_with_router(
196        node_id: &str,
197        protocol: &str,
198        interval: Duration,
199        new_peer_sender: Sender<String>,
200    ) -> Result<(ConnectionNode, impl Stream<Item = Connection> + Send)> {
201        let (incoming_stream_sender, incoming_stream_receiver) = channel(1);
202        let (node, router) = Node::new_with_router(
203            node_id,
204            protocol,
205            interval,
206            new_peer_sender,
207            incoming_stream_sender,
208        )?;
209        let node = Arc::new(node);
210        let conns = Arc::new(Mutex::new(HashMap::<u64, (Weak<StreamMap>, ClientOrServer)>::new()));
211
212        // With some cleverness, we can make polling the conn stream poll the router as a side effect.
213        let conn_stream =
214            conn_stream(Arc::downgrade(&node), Arc::clone(&conns), incoming_stream_receiver)
215                .map(Some);
216        let router_stream = futures::stream::once(router).map(|()| None);
217        let conn_stream =
218            futures::stream::select(conn_stream, router_stream).filter_map(|x| async move { x });
219
220        Ok((ConnectionNode { node, conns }, conn_stream))
221    }
222
223    /// Establish a connection to a peer. The `connection_reader` and `connection_writer` will be
224    /// used to service stream ID 0, which is always created when we start a connection.
225    pub async fn connect_to_peer(
226        &self,
227        node_id: &str,
228        connection_reader: stream::Reader,
229        connection_writer: stream::Writer,
230    ) -> Result<Connection> {
231        if &*self.node.node_id() == node_id {
232            return Err(Error::LoopbackUnsupported);
233        }
234
235        let id = random();
236
237        // Create stream 0 automatically. This will have the side effect of verifying connectivity
238        // to the node as well.
239        connection_reader.push_back_protocol_message(&0u64)?;
240        connection_reader.push_back_protocol_message(&id)?;
241        self.node.connect_to_peer(connection_reader, connection_writer, node_id).await?;
242        let streams = Arc::new(Mutex::new(HashMap::new()));
243
244        self.conns.lock().await.insert(id, (Arc::downgrade(&streams), ClientOrServer::Client));
245
246        Ok(Connection {
247            id,
248            streams,
249            node: Arc::clone(&self.node),
250            peer_node_id: node_id.to_string(),
251            next_stream_id: Arc::new(AtomicU64::new(2)),
252        })
253    }
254
255    /// Get this node as a plain old `Node`.
256    pub fn node(&self) -> &Node {
257        &*self.node
258    }
259}
260
261/// Creates a futures::Stream that when polled, will yield incoming streams on a particular
262/// connection.
263///
264/// Polling is also responsible for dispatching incoming streams from the node to existing
265/// connections.
266fn conn_stream(
267    node: Weak<Node>,
268    conns: Arc<Mutex<HashMap<u64, (Weak<StreamMap>, ClientOrServer)>>>,
269    incoming_stream_receiver: Receiver<(stream::Reader, stream::Writer, String)>,
270) -> impl Stream<Item = Connection> + Send {
271    incoming_stream_receiver.filter_map(move |(reader, writer, peer_node_id)| {
272        let conns = Arc::clone(&conns);
273        let node = node.upgrade();
274        async move {
275            let node = node?;
276            let got = reader
277                .read(16, |buf| {
278                    Ok((
279                        (
280                            u64::from_le_bytes(buf[..8].try_into().unwrap()),
281                            u64::from_le_bytes(buf[8..16].try_into().unwrap()),
282                        ),
283                        16,
284                    ))
285                })
286                .await;
287
288            let (conn_id, stream_id) = match got {
289                Ok(got) => got,
290                Err(Error::ConnectionClosed(reason)) => {
291                    let reason = reason.as_deref().unwrap_or("(No reason given)");
292                    log::warn!(reason; "New stream closed without associating with a connection");
293                    return None;
294                }
295                _ => unreachable!("Deserializing the connection ID should never fail!"),
296            };
297
298            let mut conns = conns.lock().await;
299
300            if let Some((streams, client_or_server)) = conns.get(&conn_id) {
301                if let Some(streams) = streams.upgrade() {
302                    if is_client_stream_id(stream_id) == client_or_server.is_server() {
303                        match streams.lock().await.entry(stream_id) {
304                            Entry::Occupied(mut o) => o.get_mut().ready(reader, writer),
305                            Entry::Vacant(v) => {
306                                v.insert(StreamMapEntry::Ready(reader, writer));
307                            }
308                        }
309                    } else {
310                        log::warn!(stream_id, end:? = client_or_server.is_server();
311                        "Peer initiated stream ID which does not match role");
312                    }
313                }
314                None
315            } else if stream_id == 0 {
316                let mut streams = HashMap::new();
317                streams.insert(stream_id, StreamMapEntry::Ready(reader, writer));
318                let streams = Arc::new(Mutex::new(streams));
319                conns.insert(conn_id, (Arc::downgrade(&streams), ClientOrServer::Server));
320                Some(Connection {
321                    id: conn_id,
322                    streams,
323                    node,
324                    peer_node_id,
325                    next_stream_id: Arc::new(AtomicU64::new(1)),
326                })
327            } else {
328                log::warn!(conn_id, stream_id; "Connection does not exist");
329                None
330            }
331        }
332    })
333}