circuit/
connection.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use crate::{stream, Error, Node, Result};

use futures::channel::mpsc::{channel, Receiver, Sender};
use futures::channel::oneshot;
use futures::lock::Mutex;
use futures::stream::Stream;
use futures::StreamExt;
use rand::random;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::io::Write;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use std::time::Duration;

// We're stuffing enough u64s into streams that this is worth doing.
impl crate::protocol::ProtocolMessage for u64 {
    const MIN_SIZE: usize = 8;

    fn write_bytes<W: Write>(&self, out: &mut W) -> Result<usize> {
        out.write_all(&self.to_le_bytes())?;
        Ok(8)
    }

    fn byte_size(&self) -> usize {
        8
    }

    fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
        Ok((u64::from_le_bytes(bytes.try_into().map_err(|_| Error::BufferTooShort(8))?), 8))
    }
}

/// State of a single stream ID within a connection's `StreamMap`.
enum StreamMapEntry {
    /// A local caller asked for this stream ID via `bind_stream()` before the peer opened it.
    /// The reader/writer pair is delivered through this sender once the stream shows up.
    Waiting(oneshot::Sender<(stream::Reader, stream::Writer)>),
    /// The peer has opened a stream with this ID, but nobody on this end has claimed it with
    /// `bind_stream()` yet. The reader/writer pair is parked here until then.
    Ready(stream::Reader, stream::Writer),
    /// The stream ID has already been handed out. Seeing the peer open another stream with this
    /// ID means the protocol state machine has gone wrong.
    Taken,
}

/// Which end of a connection this node is: the client opened the connection, the server
/// accepted it from the peer.
#[derive(Copy, Clone, Debug)]
enum ClientOrServer {
    Client,
    Server,
}

impl ClientOrServer {
    /// Returns `true` for `Server` and `false` for `Client`.
    fn is_server(&self) -> bool {
        match self {
            Self::Client => false,
            Self::Server => true,
        }
    }
}

impl StreamMapEntry {
    /// Claims a stream that the peer has already opened.
    ///
    /// If this entry is `Ready`, it becomes `Taken` and the parked reader/writer pair is
    /// returned. In any other state the entry is left untouched and `None` is returned. In
    /// particular, a `Waiting` entry keeps its sender, so the caller already waiting on this ID
    /// is not disturbed.
    fn take(&mut self) -> Option<(stream::Reader, stream::Writer)> {
        match std::mem::replace(self, Self::Taken) {
            Self::Ready(reader, writer) => Some((reader, writer)),
            previous => {
                // Not `Ready`: put the original state back exactly as it was.
                *self = previous;
                None
            }
        }
    }

    /// Delivers a stream the peer just opened with this entry's ID.
    ///
    /// From `Waiting`, the pair is sent to the waiter and the entry becomes `Taken`. If the
    /// waiter's receiver is already gone, the pair is parked as `Ready` instead, so a later
    /// `bind_stream()` can still claim it. In any other state the new reader and writer are
    /// dropped and the entry is left untouched (an already-parked `Ready` pair is preserved).
    fn ready(&mut self, reader: stream::Reader, writer: stream::Writer) {
        match std::mem::replace(self, Self::Taken) {
            Self::Waiting(sender) => {
                // `send` hands the value back if the receiving `bind_stream()` was dropped.
                if let Err((reader, writer)) = sender.send((reader, writer)) {
                    *self = Self::Ready(reader, writer);
                }
            }
            // Not waiting: restore the previous state; `reader`/`writer` drop here.
            previous => *self = previous,
        }
    }
}

/// A Mutex-protected map from stream ID to the `StreamMapEntry` tracking that stream's state
/// (waited on, parked and ready to claim, or already taken).
type StreamMap = Mutex<HashMap<u64, StreamMapEntry>>;

/// A group of streams between this node and one peer on the circuit network, sharing a stream-ID
/// namespace that is private to the connection.
///
/// A circuit `Node` can open a stream to any other node, but those raw streams are independent:
/// they carry no name or identifier that would let you group them. A `Connection` adds that
/// grouping. Streams obtained through it (just as one would obtain them from the `Node` itself)
/// are tagged with a per-connection stream ID, and each connection only ever sees the streams
/// that belong to it. A node may hold many connections at once.
///
/// Supporting this costs a little extra protocol: each stream opened through a connection
/// carries its stream ID and the connection ID as a header, which slightly changes wire traffic.
#[derive(Clone)]
pub struct Connection {
    /// Identifier of this connection, shared by both ends.
    id: u64,
    /// Per-stream state. The node-wide dispatcher holds only a weak reference to it.
    streams: Arc<StreamMap>,
    /// The node through which streams are opened.
    node: Arc<Node>,
    /// Node ID of the peer at the other end.
    peer_node_id: String,
    /// Next stream ID to allocate locally. It advances by 2 so each end keeps its own parity.
    next_stream_id: Arc<AtomicU64>,
}

impl Connection {
    /// Waits for the peer to open the stream with the given `id` on this connection, and returns
    /// its reader/writer pair.
    ///
    /// Returns immediately if the peer has already opened the stream. Returns `None` if:
    /// - the stream has already been claimed,
    /// - another `bind_stream()` call is still waiting on the same ID, or
    /// - the pending wait is cancelled before the stream arrives.
    pub async fn bind_stream(&self, id: u64) -> Option<(stream::Reader, stream::Writer)> {
        let receiver = {
            let mut streams = self.streams.lock().await;
            match streams.entry(id) {
                Entry::Occupied(mut occupied) => {
                    // A `Waiting` entry whose receiver is gone was left behind by an earlier
                    // `bind_stream()` whose future was dropped. Take its place instead of
                    // treating the ID as already claimed.
                    let abandoned = matches!(
                        occupied.get(),
                        StreamMapEntry::Waiting(sender) if sender.is_canceled()
                    );
                    if !abandoned {
                        return occupied.get_mut().take();
                    }
                    let (sender, receiver) = oneshot::channel();
                    occupied.insert(StreamMapEntry::Waiting(sender));
                    receiver
                }
                Entry::Vacant(vacant) => {
                    let (sender, receiver) = oneshot::channel();
                    vacant.insert(StreamMapEntry::Waiting(sender));
                    receiver
                }
            }
        };

        // The map lock is released above, so the dispatcher can deliver the stream meanwhile.
        receiver.await.ok()
    }

    /// Returns the node ID of the peer at the other end of this connection.
    pub fn from(&self) -> &str {
        &self.peer_node_id
    }

    /// Create a new stream to the other end of this connection.
    ///
    /// `reader`/`writer` become the local end of the stream. The allocated stream ID and this
    /// connection's ID are pushed onto `reader` so the peer can route the stream to its side of
    /// this connection. Returns the allocated stream ID.
    ///
    /// # Errors
    /// Propagates failures from encoding the header or from `Node::connect_to_peer`. The stream
    /// ID is consumed even when this fails.
    pub async fn alloc_stream(
        &self,
        reader: stream::Reader,
        writer: stream::Writer,
    ) -> Result<u64> {
        let id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);
        reader.push_back_protocol_message(&id)?;
        reader.push_back_protocol_message(&self.id)?;
        self.node.connect_to_peer(reader, writer, &self.peer_node_id).await?;
        Ok(id)
    }

    /// Whether this end initiated the connection (the client, via
    /// `ConnectionNode::connect_to_peer`) as opposed to having accepted it from the peer (the
    /// server).
    pub fn is_client(&self) -> bool {
        // Clients allocate even stream IDs and servers odd ones. `next_stream_id` only ever
        // advances by 2, so its parity identifies the role.
        is_client_stream_id(self.next_stream_id.load(Ordering::Relaxed))
    }
}

impl std::fmt::Debug for Connection {
    /// Renders as `Connection(<connection id in hex> to <peer node id>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { id, peer_node_id, .. } = self;
        write!(f, "Connection({:#x} to {})", id, peer_node_id)
    }
}

/// A `Node` wrapper that hands out `Connection` objects rather than letting callers open raw
/// streams on the node directly.
pub struct ConnectionNode {
    /// The wrapped circuit node.
    node: Arc<Node>,
    /// Every known connection, keyed by connection ID. Each value holds a weak handle to the
    /// connection's stream map, plus which end of the connection this node is.
    conns: Arc<Mutex<HashMap<u64, (Weak<StreamMap>, ClientOrServer)>>>,
}

/// Reports whether `id` falls in the client's half of the stream-ID space.
///
/// The client (the end that initiated the connection) allocates even stream IDs. The server
/// (the end that accepted the connection) allocates odd ones.
fn is_client_stream_id(id: u64) -> bool {
    id % 2 == 0
}

impl ConnectionNode {
    /// Create a new `ConnectionNode`. We can create a `Connection` object for any peer via this
    /// node, and we can then create streams to that peer and have that peer create streams to us.
    /// Unlike with a raw `Node`, those streams will be associated with only our connection object,
    /// and the peer will get a connection object that will return to it only the streams we create
    /// from this connection object.
    ///
    /// Returns both a `ConnectionNode`, and a `futures::stream::Stream` of `Connection` objects,
    /// which are produced by other nodes connecting to us.
    pub fn new(
        node_id: &str,
        protocol: &str,
        new_peer_sender: Sender<String>,
    ) -> Result<(ConnectionNode, impl Stream<Item = Connection> + Send)> {
        let (incoming_stream_sender, incoming_stream_receiver) = channel(1);
        let node = Arc::new(Node::new(node_id, protocol, new_peer_sender, incoming_stream_sender)?);
        let conns = Arc::new(Mutex::new(HashMap::<u64, (Weak<StreamMap>, ClientOrServer)>::new()));
        let conn_stream =
            conn_stream(Arc::downgrade(&node), Arc::clone(&conns), incoming_stream_receiver);

        Ok((ConnectionNode { node, conns }, conn_stream))
    }

    /// Like `ConnectionNode::new` but creates a router task as well. See `Node::new_with_router`.
    ///
    /// The router future is folded into the returned connection stream, so the router only makes
    /// progress while that stream is being polled.
    pub fn new_with_router(
        node_id: &str,
        protocol: &str,
        interval: Duration,
        new_peer_sender: Sender<String>,
    ) -> Result<(ConnectionNode, impl Stream<Item = Connection> + Send)> {
        let (incoming_stream_sender, incoming_stream_receiver) = channel(1);
        let (node, router) = Node::new_with_router(
            node_id,
            protocol,
            interval,
            new_peer_sender,
            incoming_stream_sender,
        )?;
        let node = Arc::new(node);
        let conns = Arc::new(Mutex::new(HashMap::<u64, (Weak<StreamMap>, ClientOrServer)>::new()));

        // With some cleverness, we can make polling the conn stream poll the router as a side effect.
        let conn_stream =
            conn_stream(Arc::downgrade(&node), Arc::clone(&conns), incoming_stream_receiver)
                .map(Some);
        let router_stream = futures::stream::once(router).map(|()| None);
        let conn_stream =
            futures::stream::select(conn_stream, router_stream).filter_map(|x| async move { x });

        Ok((ConnectionNode { node, conns }, conn_stream))
    }

    /// Establish a connection to a peer. The `connection_reader` and `connection_writer` will be
    /// used to service stream ID 0, which is always created when we start a connection.
    ///
    /// The connection is registered before stream 0 is opened. Streams the peer opens back to us
    /// on this connection are therefore routed correctly even if they arrive before this call
    /// returns.
    ///
    /// # Errors
    /// Returns `Error::LoopbackUnsupported` if `node_id` is this node's own ID. Otherwise it
    /// propagates failures from encoding the stream header or from `Node::connect_to_peer`; in
    /// the latter case the connection is unregistered again.
    pub async fn connect_to_peer(
        &self,
        node_id: &str,
        connection_reader: stream::Reader,
        connection_writer: stream::Writer,
    ) -> Result<Connection> {
        if &*self.node.node_id() == node_id {
            return Err(Error::LoopbackUnsupported);
        }

        let id: u64 = random();

        // Stream 0 is created automatically. This will have the side effect of verifying
        // connectivity to the node as well.
        connection_reader.push_back_protocol_message(&0u64)?;
        connection_reader.push_back_protocol_message(&id)?;

        let streams: Arc<StreamMap> = Arc::new(Mutex::new(HashMap::new()));
        {
            let mut conns = self.conns.lock().await;
            // Forget connections whose `Connection` handles have all been dropped, so the map
            // doesn't grow without bound over the node's lifetime.
            conns.retain(|_, (weak_streams, _)| weak_streams.strong_count() > 0);
            conns.insert(id, (Arc::downgrade(&streams), ClientOrServer::Client));
        }

        // The `conns` lock must not be held here: the dispatcher needs it to route incoming
        // streams while we wait.
        let connected =
            self.node.connect_to_peer(connection_reader, connection_writer, node_id).await;
        if connected.is_err() {
            self.conns.lock().await.remove(&id);
        }
        connected?;

        Ok(Connection {
            id,
            streams,
            node: Arc::clone(&self.node),
            peer_node_id: node_id.to_string(),
            next_stream_id: Arc::new(AtomicU64::new(2)),
        })
    }

    /// Get this node as a plain old `Node`.
    pub fn node(&self) -> &Node {
        &self.node
    }
}

/// Creates a `futures::Stream` that yields a `Connection` each time a peer opens a new connection
/// to this node.
///
/// Polling the returned stream is also responsible for dispatching every other incoming stream
/// from the node to the connection it belongs to. Each incoming stream starts with a 16-byte
/// header holding a connection ID followed by a stream ID, both little-endian `u64`s. The header
/// is handled as follows:
///
/// * Known, live connection: the stream is handed to that connection's `StreamMap`, provided the
///   stream ID has the parity the peer's role allows. Otherwise it is logged and dropped.
/// * Known connection whose handles have all been dropped: the stream is dropped.
/// * Unknown connection with stream ID 0: a new server-side `Connection` is created and yielded.
/// * Anything else: logged and dropped.
///
/// If `node` has been dropped, incoming streams are discarded.
fn conn_stream(
    node: Weak<Node>,
    conns: Arc<Mutex<HashMap<u64, (Weak<StreamMap>, ClientOrServer)>>>,
    incoming_stream_receiver: Receiver<(stream::Reader, stream::Writer, String)>,
) -> impl Stream<Item = Connection> + Send {
    incoming_stream_receiver.filter_map(move |(reader, writer, peer_node_id)| {
        let conns = Arc::clone(&conns);
        let node = node.upgrade();
        async move {
            let node = node?;
            let got = reader
                .read(16, |buf| {
                    Ok((
                        (
                            u64::from_le_bytes(buf[..8].try_into().unwrap()),
                            u64::from_le_bytes(buf[8..16].try_into().unwrap()),
                        ),
                        16,
                    ))
                })
                .await;

            let (conn_id, stream_id) = match got {
                Ok(got) => got,
                Err(Error::ConnectionClosed(reason)) => {
                    let reason = reason.as_deref().unwrap_or("(No reason given)");
                    tracing::warn!(
                        reason,
                        "New stream closed without associating with a connection"
                    );
                    return None;
                }
                _ => unreachable!("Deserializing the connection ID should never fail!"),
            };

            let mut conns = conns.lock().await;

            if let Some((streams, client_or_server)) = conns.get(&conn_id) {
                if let Some(streams) = streams.upgrade() {
                    // Only accept IDs from the peer's half of the ID space: even IDs when we
                    // are the server, odd IDs when we are the client.
                    if is_client_stream_id(stream_id) == client_or_server.is_server() {
                        match streams.lock().await.entry(stream_id) {
                            Entry::Occupied(mut o) => o.get_mut().ready(reader, writer),
                            Entry::Vacant(v) => {
                                v.insert(StreamMapEntry::Ready(reader, writer));
                            }
                        }
                    } else {
                        tracing::warn!(
                            stream_id,
                            role = ?client_or_server,
                            "Peer initiated stream ID which does not match role"
                        );
                    }
                }
                None
            } else if stream_id == 0 {
                let mut streams = HashMap::new();
                streams.insert(stream_id, StreamMapEntry::Ready(reader, writer));
                let streams = Arc::new(Mutex::new(streams));
                // Forget connections whose `Connection` handles have all been dropped before
                // adding a new one, so the map doesn't grow without bound.
                conns.retain(|_, (weak_streams, _)| weak_streams.strong_count() > 0);
                conns.insert(conn_id, (Arc::downgrade(&streams), ClientOrServer::Server));
                Some(Connection {
                    id: conn_id,
                    streams,
                    node,
                    peer_node_id,
                    next_stream_id: Arc::new(AtomicU64::new(1)),
                })
            } else {
                tracing::warn!(conn_id, stream_id, "Connection does not exist");
                None
            }
        }
    })
}