circuit/
stream.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use crate::error::{Error, Result};
use crate::protocol;

use std::collections::VecDeque;
use std::sync::{Arc, Mutex as SyncMutex};
use tokio::sync::oneshot;

/// Granularity, in bytes, used when trimming the reader's ring buffer.
///
/// After a successful read, the buffer's length is rounded up to a multiple of this value (and to
/// at least this value); if the buffer's capacity is at least twice that rounded size, the buffer
/// is shrunk down to it. The value of 1MiB was discovered experimentally to yield acceptable
/// performance on tests.
const BUFFER_TRIM_GRANULARITY: usize = 1 << 20;

// BUFFER_TRIM_GRANULARITY should be a power of 2.
const _: () = assert!(BUFFER_TRIM_GRANULARITY.is_power_of_two());

/// Open/closed state of a stream. A closed stream may carry the reason it was closed.
#[derive(Clone, Debug)]
enum Status {
    /// Stream is open.
    Open,
    /// Stream is closed, optionally with a human-readable reason for the closure.
    Closed(Option<String>),
}

impl Status {
    /// Returns `true` once the stream has been closed, with or without a reason.
    fn is_closed(&self) -> bool {
        matches!(self, Status::Closed(_))
    }

    /// Returns a copy of the closure reason. Returns `None` if the stream is still open or was
    /// closed without a reason.
    fn reason(&self) -> Option<String> {
        if let Status::Closed(why) = self {
            why.clone()
        } else {
            None
        }
    }

    /// Marks an open stream as closed without a reason. An already-closed status, including any
    /// reason it carries, is left untouched.
    fn close(&mut self) {
        if !self.is_closed() {
            *self = Status::Closed(None);
        }
    }
}

/// Internal state of a stream, shared by its `Reader` and `Writer` behind an
/// `Arc<SyncMutex<_>>`. See `stream()`.
#[derive(Debug)]
struct State {
    /// The ring buffer itself. The readable bytes always occupy the front of the deque: writers
    /// append after them and reads drain from the front.
    deque: VecDeque<u8>,
    /// How many bytes are readable. This is different from the length of the deque, as we may allow
    /// bytes to be in the deque that are "initialized but unavailable." Mostly that just
    /// accommodates a quirk of Rust's memory model; we have to initialize all bytes before we show
    /// them to the user, even if they're there just to be overwritten, and if we pop the bytes out
    /// of the deque Rust counts them as uninitialized again, so to avoid duplicating the
    /// initialization process we just leave the initialized-but-unwritten bytes in the deque.
    readable: usize,
    /// If the reader needs to sleep, it puts a oneshot sender here so it can be woken up again. It
    /// also lists how many bytes should be available before it should be woken up. Writes and
    /// push-backs fire the sender once `readable` reaches that count; dropping the `Writer` fires
    /// it unconditionally.
    notify_readable: Option<(oneshot::Sender<()>, usize)>,
    /// Whether this stream is closed, i.e. whether either the `Reader` or `Writer` has been closed
    /// or dropped, along with the reason for the closure if one was given.
    closed: Status,
}

/// Read half of a stream. See `stream()`.
pub struct Reader(Arc<SyncMutex<State>>);

impl Reader {
    /// Debug
    pub fn inspect_shutdown(&self) -> String {
        let lock = self.0.lock().unwrap();
        if lock.closed.is_closed() {
            lock.closed.reason().unwrap_or_else(|| "No epitaph".to_owned())
        } else {
            "Not closed".to_owned()
        }
    }

    /// Read bytes from the stream.
    ///
    /// The reader will wait until there are *at least* `size` bytes to read, Then it will call the
    /// given callback with a slice containing all available bytes to read.
    ///
    /// If the callback processes data successfully, it should return `Ok` with a tuple containing
    /// a value of the user's choice, and the number of bytes used. If the number of bytes returned
    /// from the callback is less than what was available in the buffer, the unused bytes will
    /// appear at the start of the buffer for subsequent read calls. It is allowable to `peek` at
    /// the bytes in the buffer by returning a number of bytes read that is smaller than the number
    /// of bytes actually used.
    ///
    /// If the callback returns `Error::BufferTooShort` and the expected buffer value contained in
    /// the error is larger than the data that was provided, we will wait again until there are
    /// enough bytes to satisfy the error and then call the callback again. If the callback returns
    /// `Error::BufferTooShort` but the buffer should have been long enough according to the error,
    /// `Error::CallbackRejectedBuffer` is returned. Other errors from the callback are returned
    /// as-is from `read` itself.
    ///
    /// If there are no bytes available to read and the `Writer` for this stream has already been
    /// dropped, `read` returns `Error::ConnectionClosed`. If there are *not enough* bytes available
    /// to be read and the `Writer` has been dropped, `read` returns `Error::BufferTooSmall`. This
    /// is the only time `read` should return `Error::BufferTooSmall`.
    ///
    /// Panics if the callback returns a number of bytes greater than the size of the buffer.
    pub async fn read<F, U>(&self, mut size: usize, mut f: F) -> Result<U>
    where
        F: FnMut(&[u8]) -> Result<(U, usize)>,
    {
        loop {
            let receiver = {
                let mut state = self.0.lock().unwrap();

                if let Status::Closed(reason) = &state.closed {
                    if size == 0 {
                        return Err(Error::ConnectionClosed(reason.clone()));
                    }
                }

                if state.readable >= size {
                    let (first, _) = state.deque.as_slices();

                    let first = if first.len() >= size {
                        first
                    } else {
                        state.deque.make_contiguous();
                        state.deque.as_slices().0
                    };

                    debug_assert!(first.len() >= size);

                    let first = &first[..std::cmp::min(first.len(), state.readable)];
                    let (ret, consumed) = match f(first) {
                        Err(Error::BufferTooShort(s)) => {
                            if s < first.len() {
                                return Err(Error::CallbackRejectedBuffer(s, first.len()));
                            }

                            size = s;
                            continue;
                        }
                        other => other?,
                    };

                    if consumed > first.len() {
                        panic!("Read claimed to consume more bytes than it was given!");
                    }

                    state.readable -= consumed;
                    state.deque.drain(..consumed);
                    let target_capacity = std::cmp::max(
                        state.deque.len().next_multiple_of(BUFFER_TRIM_GRANULARITY),
                        BUFFER_TRIM_GRANULARITY,
                    );

                    if target_capacity <= state.deque.capacity() / 2 {
                        state.deque.shrink_to(target_capacity);
                    }
                    return Ok(ret);
                }

                if let Status::Closed(reason) = &state.closed {
                    if state.readable > 0 {
                        return Err(Error::BufferTooShort(size));
                    } else {
                        return Err(Error::ConnectionClosed(reason.clone()));
                    }
                }

                let (sender, receiver) = oneshot::channel();
                state.notify_readable = Some((sender, size));
                receiver
            };

            let _ = receiver.await;
        }
    }

    /// Read a protocol message from the stream. This is just a quick way to wire
    /// `ProtocolObject::try_from_bytes` in to `read`.
    pub async fn read_protocol_message<P: protocol::ProtocolMessage>(&self) -> Result<P> {
        self.read(P::MIN_SIZE, P::try_from_bytes).await
    }

    /// This writes the given protocol message to the stream at the *beginning* of the stream,
    /// meaning that it will be the next thing read off the stream.
    pub(crate) fn push_back_protocol_message<P: protocol::ProtocolMessage>(
        &self,
        message: &P,
    ) -> Result<()> {
        let size = message.byte_size();
        let mut state = self.0.lock().unwrap();
        let readable = state.readable;
        state.deque.resize(readable + size, 0);
        state.deque.rotate_right(size);
        let (first, _) = state.deque.as_mut_slices();

        let mut first = if first.len() >= size {
            first
        } else {
            state.deque.make_contiguous();
            state.deque.as_mut_slices().0
        };

        let got = message.write_bytes(&mut first)?;
        debug_assert!(got == size);
        state.readable += size;

        if let Some((sender, size)) = state.notify_readable.take() {
            if size <= state.readable {
                let _ = sender.send(());
            } else {
                state.notify_readable = Some((sender, size));
            }
        }

        Ok(())
    }

    /// Whether this stream is closed. Returns false so long as there is unread
    /// data in the buffer, even if the writer has hung up.
    pub fn is_closed(&self) -> bool {
        let state = self.0.lock().unwrap();
        state.closed.is_closed() && state.readable == 0
    }

    /// Get the reason this reader is closed. If the reader is not closed, or if
    /// no reason was given, return `None`.
    pub fn closed_reason(&self) -> Option<String> {
        let state = self.0.lock().unwrap();
        state.closed.reason()
    }

    /// Close this stream, giving a reason for the closure.
    pub fn close(self, reason: String) {
        let mut state = self.0.lock().unwrap();
        match &state.closed {
            Status::Closed(Some(_)) => (),
            _ => state.closed = Status::Closed(Some(reason)),
        }
    }
}

impl std::fmt::Debug for Reader {
    /// Formats the reader by the address of its shared state, so both halves of one stream can be
    /// matched up in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shared = Arc::as_ptr(&self.0);
        write!(f, "Reader({shared:?})")
    }
}

impl Drop for Reader {
    /// Marks the stream closed (without a reason, unless one was already recorded) so the writer's
    /// subsequent writes fail.
    fn drop(&mut self) {
        self.0.lock().unwrap().closed.close();
    }
}

/// Write half of a stream. See `stream()`.
pub struct Writer(Arc<SyncMutex<State>>);

impl Writer {
    /// Write data to this stream.
    ///
    /// Space for `size` bytes is allocated in the stream immediately, and then the callback is
    /// invoked with a mutable slice to that region so that it may populate it. The slice given to
    /// the callback *may* be larger than requested but will never be smaller.
    ///
    /// The callback should return `Ok` with the number of bytes actually written, which may be less
    /// than `size`. If the callback returns an error, that error is returned from `write` as-is.
    /// Note that we do not specially process `Error::BufferTooShort` as `Reader::read` does.
    ///
    /// If the stream is already closed, returns `Error::ConnectionClosed` without invoking the
    /// callback.
    ///
    /// Panics if the callback returns a number of bytes greater than the size of the buffer.
    pub fn write<F>(&self, size: usize, f: F) -> Result<()>
    where
        F: FnOnce(&mut [u8]) -> Result<usize>,
    {
        let mut state = self.0.lock().unwrap();

        if let Status::Closed(reason) = &state.closed {
            return Err(Error::ConnectionClosed(reason.clone()));
        }

        let total_size = state.readable + size;

        if state.deque.len() < total_size {
            // Grow to at least the current capacity. The whole existing allocation becomes
            // initialized (zeroed) space that later writes can fill without resizing again.
            let total_size = std::cmp::max(total_size, state.deque.capacity());
            state.deque.resize(total_size, 0);
        }

        let readable = state.readable;
        let (first, second) = state.deque.as_mut_slices();

        // Free space begins right after the readable bytes, which may end in either half of the
        // ring buffer.
        let slice = if first.len() > readable {
            &mut first[readable..]
        } else {
            &mut second[(readable - first.len())..]
        };

        // If the free space wraps around before `size` bytes, make the deque contiguous so the
        // callback receives a single slice of sufficient length.
        let slice = if slice.len() >= size {
            slice
        } else {
            state.deque.make_contiguous();
            &mut state.deque.as_mut_slices().0[readable..]
        };

        debug_assert!(slice.len() >= size);
        let size = f(slice)?;

        if size > slice.len() {
            panic!("Write claimed to produce more bytes than buffer had space for!");
        }

        state.readable += size;

        // Wake a sleeping reader once as many bytes as it asked for are readable.
        if let Some((sender, size)) = state.notify_readable.take() {
            if size <= state.readable {
                let _ = sender.send(());
            } else {
                state.notify_readable = Some((sender, size));
            }
        }

        Ok(())
    }

    /// Write a protocol message to the stream. This is just a quick way to wire
    /// `ProtocolObject::write_bytes` in to `write`.
    pub fn write_protocol_message<P: protocol::ProtocolMessage>(&self, message: &P) -> Result<()> {
        self.write(message.byte_size(), |mut buf| message.write_bytes(&mut buf))
    }

    /// Close this stream, giving a reason for the closure.
    ///
    /// If a reason has already been recorded, it is kept and `reason` is discarded, matching
    /// `Reader::close`.
    pub fn close(self, reason: String) {
        let mut state = self.0.lock().unwrap();
        if !matches!(state.closed, Status::Closed(Some(_))) {
            state.closed = Status::Closed(Some(reason));
        }
    }

    /// Whether this stream is closed, i.e. whether `write` will fail with
    /// `Error::ConnectionClosed`.
    ///
    /// Unlike `Reader::is_closed`, unread data left in the buffer does not keep the stream open
    /// from the writer's side: once the reader has gone away nothing can consume that data.
    pub fn is_closed(&self) -> bool {
        self.0.lock().unwrap().closed.is_closed()
    }

    /// Get the reason this writer is closed. If the writer is not closed, or if
    /// no reason was given, return `None`.
    pub fn closed_reason(&self) -> Option<String> {
        let state = self.0.lock().unwrap();
        state.closed.reason()
    }
}

impl std::fmt::Debug for Writer {
    /// Formats the writer by the address of its shared state, so both halves of one stream can be
    /// matched up in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shared = Arc::as_ptr(&self.0);
        write!(f, "Writer({shared:?})")
    }
}

impl Drop for Writer {
    /// Marks the stream closed and wakes any reader sleeping in `Reader::read`, so it can observe
    /// that no more data is coming.
    fn drop(&mut self) {
        // Pull the pending wakeup out under the lock; the guard is released before it fires.
        let pending = {
            let mut state = self.0.lock().unwrap();
            state.closed.close();
            state.notify_readable.take()
        };

        if let Some((sender, _)) = pending {
            let _ = sender.send(());
        }
    }
}

/// Creates a unidirectional stream of bytes.
///
/// The returned `Reader` and `Writer` share one mutex-guarded, expanding ring buffer. This allows
/// sending bytes between tasks with minimal extra allocations or copies.
pub fn stream() -> (Reader, Writer) {
    let shared = Arc::new(SyncMutex::new(State {
        deque: VecDeque::new(),
        readable: 0,
        notify_readable: None,
        closed: Status::Open,
    }));

    (Reader(Arc::clone(&shared)), Writer(shared))
}

#[cfg(test)]
mod test {
    use futures::task::noop_waker;
    use futures::FutureExt;
    use std::future::Future;
    use std::pin::pin;
    use std::task::{Context, Poll};

    use super::*;

    // Minimal `ProtocolMessage` for a fixed 4-byte array, so the tests can exercise
    // `push_back_protocol_message` without real protocol types.
    impl protocol::ProtocolMessage for [u8; 4] {
        const MIN_SIZE: usize = 4;
        fn byte_size(&self) -> usize {
            4
        }

        fn write_bytes<W: std::io::Write>(&self, out: &mut W) -> Result<usize> {
            out.write_all(self)?;
            Ok(4)
        }

        fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
            if bytes.len() < 4 {
                return Err(Error::BufferTooShort(4));
            }

            Ok((bytes[..4].try_into().unwrap(), 4))
        }
    }

    // Basic round trip: after a partial read, the unread bytes stay at the front of the buffer
    // and are followed by later writes.
    #[fuchsia::test]
    async fn stream_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
    }

    // A pushed-back message is read ahead of both the data already buffered and data written
    // afterwards. The last read inspects 10 bytes but reports consuming only 6.
    #[fuchsia::test]
    async fn push_back_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        reader.push_back_protocol_message(&[4, 3, 2, 1]).unwrap();

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        let got = reader.read(10, |buf| Ok((buf[..10].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![4, 3, 2, 1, 5, 6, 7, 8, 9, 10], got);
    }

    // Once the reader is dropped, writes fail with `ConnectionClosed(None)`, even though unread
    // bytes remain in the buffer.
    #[fuchsia::test]
    async fn writer_sees_close() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        std::mem::drop(reader);

        assert!(matches!(
            writer.write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            }),
            Err(Error::ConnectionClosed(None))
        ));
    }

    // After the writer is dropped: asking for more bytes than remain fails with
    // `BufferTooShort`, the remaining bytes can still be read, and a read on the drained stream
    // reports `ConnectionClosed`.
    #[fuchsia::test]
    async fn reader_sees_closed() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        std::mem::drop(writer);

        assert!(matches!(reader.read(7, |_| Ok(((), 1))).await, Err(Error::BufferTooShort(7))));

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
        assert!(matches!(
            reader.read(1, |_| Ok(((), 1))).await,
            Err(Error::ConnectionClosed(None))
        ));
    }

    // A read that is already pending completes with `ConnectionClosed` after the writer is
    // dropped. The future is polled by hand with a no-op waker.
    #[fuchsia::test]
    async fn reader_sees_closed_when_polling() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let fut = reader
            .read(1, |_| -> Result<((), usize)> { panic!("This read should never succeed!") });
        let mut fut = std::pin::pin!(fut);

        assert!(fut.poll_unpin(&mut Context::from_waker(&noop_waker())).is_pending());

        std::mem::drop(writer);

        assert!(matches!(
            fut.poll_unpin(&mut Context::from_waker(&noop_waker())),
            Poll::Ready(Err(Error::ConnectionClosed(None)))
        ));
    }

    // Same as above, but the pending read runs in a spawned task and relies on a real wakeup.
    // The first poll must return `Pending`; taking `writer` out of its `Option` then drops it,
    // which should wake the read so it finishes with `ConnectionClosed`.
    #[fuchsia::test]
    async fn reader_sees_closed_separate_task() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let (sender, receiver) = oneshot::channel();
        let task = fuchsia_async::Task::spawn(async move {
            let fut = reader.read(1, |_| Ok(((), 1)));
            let mut fut = std::pin::pin!(fut);
            let mut writer = Some(writer);
            let fut = futures::future::poll_fn(move |cx| {
                let ret = fut.as_mut().poll(cx);

                if writer.take().is_some() {
                    assert!(matches!(ret, Poll::Pending));
                }

                ret
            });
            assert!(matches!(fut.await, Err(Error::ConnectionClosed(None))));
            sender.send(()).unwrap();
        });

        receiver.await.unwrap();
        task.await;
    }

    // A callback that returns `BufferTooShort(4)` after seeing only 2 bytes is called again once
    // 4 bytes are available. The oneshot channel reports the length seen on the first call.
    #[fuchsia::test]
    async fn reader_buffer_too_short() {
        let (reader, writer) = stream();
        let (sender, receiver) = oneshot::channel();
        let mut sender = Some(sender);

        let reader_task = async move {
            let got = reader
                .read(1, |buf| {
                    if buf.len() != 4 {
                        sender.take().unwrap().send(buf.len()).unwrap();
                        Err(Error::BufferTooShort(4))
                    } else {
                        Ok((buf[..4].to_vec(), 4))
                    }
                })
                .await
                .unwrap();
            assert_eq!(vec![1, 2, 3, 4], got);
        };

        let writer_task = async move {
            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[1, 2]);
                    Ok(2)
                })
                .unwrap();

            assert_eq!(2, receiver.await.unwrap());

            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[3, 4]);
                    Ok(2)
                })
                .unwrap();
        };

        futures::future::join(pin!(reader_task), pin!(writer_task)).await;
    }
}