circuit/
stream.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::error::{Error, Result};
6use crate::protocol;
7
8use std::collections::VecDeque;
9use std::sync::{Arc, Mutex as SyncMutex};
10use tokio::sync::oneshot;
11
/// Granularity used when trimming the stream's internal buffer.
///
/// After a read, the buffer's capacity may be shrunk toward the amount of buffered data rounded
/// up to a multiple of this value, so spare capacity stays on the order of this many bytes.
///
/// 1 MiB was settled on experimentally as giving acceptable performance in tests.
const BUFFER_TRIM_GRANULARITY: usize = 1 << 20;

/// Compile-time check that `BUFFER_TRIM_GRANULARITY` is a power of two.
const _: () = assert!(BUFFER_TRIM_GRANULARITY.is_power_of_two());
21
/// Lifecycle of a stream: either still open, or closed with an optional explanation.
#[derive(Debug, Clone)]
enum Status {
    /// The stream has not been closed.
    Open,
    /// The stream has been closed; holds the reason for closure if one was supplied.
    Closed(Option<String>),
}
30
31impl Status {
32    fn is_closed(&self) -> bool {
33        match self {
34            Status::Open => false,
35            Status::Closed(_) => true,
36        }
37    }
38
39    fn reason(&self) -> Option<String> {
40        match self {
41            Status::Open => None,
42            Status::Closed(x) => x.clone(),
43        }
44    }
45
46    fn close(&mut self) {
47        if let Status::Open = self {
48            *self = Status::Closed(None);
49        }
50    }
51}
52
53/// Internal state of a stream. See `stream()`.
54#[derive(Debug)]
55struct State {
56    /// The ring buffer itself.
57    deque: VecDeque<u8>,
58    /// How many bytes are readable. This is different from the length of the deque, as we may allow
59    /// bytes to be in the deque that are "initialized but unavailable." Mostly that just
60    /// accommodates a quirk of Rust's memory model; we have to initialize all bytes before we show
61    /// them to the user, even if they're there just to be overwritten, and if we pop the bytes out
62    /// of the deque Rust counts them as uninitialized again, so to avoid duplicating the
63    /// initialization process we just leave the initialized-but-unwritten bytes in the deque.
64    readable: usize,
65    /// If the reader needs to sleep, it puts a oneshot sender here so it can be woken up again. It
66    /// also lists how many bytes should be available before it should be woken up.
67    notify_readable: Option<(oneshot::Sender<()>, usize)>,
68    /// Whether this stream is closed. I.e. whether either the `Reader` or `Writer` has been dropped.
69    closed: Status,
70}
71
72/// Read half of a stream. See `stream()`.
73pub struct Reader(Arc<SyncMutex<State>>);
74
75impl Reader {
76    /// Debug
77    pub fn inspect_shutdown(&self) -> String {
78        let lock = self.0.lock().unwrap();
79        if lock.closed.is_closed() {
80            lock.closed.reason().unwrap_or_else(|| "No epitaph".to_owned())
81        } else {
82            "Not closed".to_owned()
83        }
84    }
85
86    /// Read bytes from the stream.
87    ///
88    /// The reader will wait until there are *at least* `size` bytes to read, Then it will call the
89    /// given callback with a slice containing all available bytes to read.
90    ///
91    /// If the callback processes data successfully, it should return `Ok` with a tuple containing
92    /// a value of the user's choice, and the number of bytes used. If the number of bytes returned
93    /// from the callback is less than what was available in the buffer, the unused bytes will
94    /// appear at the start of the buffer for subsequent read calls. It is allowable to `peek` at
95    /// the bytes in the buffer by returning a number of bytes read that is smaller than the number
96    /// of bytes actually used.
97    ///
98    /// If the callback returns `Error::BufferTooShort` and the expected buffer value contained in
99    /// the error is larger than the data that was provided, we will wait again until there are
100    /// enough bytes to satisfy the error and then call the callback again. If the callback returns
101    /// `Error::BufferTooShort` but the buffer should have been long enough according to the error,
102    /// `Error::CallbackRejectedBuffer` is returned. Other errors from the callback are returned
103    /// as-is from `read` itself.
104    ///
105    /// If there are no bytes available to read and the `Writer` for this stream has already been
106    /// dropped, `read` returns `Error::ConnectionClosed`. If there are *not enough* bytes available
107    /// to be read and the `Writer` has been dropped, `read` returns `Error::BufferTooSmall`. This
108    /// is the only time `read` should return `Error::BufferTooSmall`.
109    ///
110    /// Panics if the callback returns a number of bytes greater than the size of the buffer.
111    pub async fn read<F, U>(&self, mut size: usize, mut f: F) -> Result<U>
112    where
113        F: FnMut(&[u8]) -> Result<(U, usize)>,
114    {
115        loop {
116            let receiver = {
117                let mut state = self.0.lock().unwrap();
118
119                if let Status::Closed(reason) = &state.closed {
120                    if size == 0 {
121                        return Err(Error::ConnectionClosed(reason.clone()));
122                    }
123                }
124
125                if state.readable >= size {
126                    let (first, _) = state.deque.as_slices();
127
128                    let first = if first.len() >= size {
129                        first
130                    } else {
131                        state.deque.make_contiguous();
132                        state.deque.as_slices().0
133                    };
134
135                    debug_assert!(first.len() >= size);
136
137                    let first = &first[..std::cmp::min(first.len(), state.readable)];
138                    let (ret, consumed) = match f(first) {
139                        Err(Error::BufferTooShort(s)) => {
140                            if s < first.len() {
141                                return Err(Error::CallbackRejectedBuffer(s, first.len()));
142                            }
143
144                            size = s;
145                            continue;
146                        }
147                        other => other?,
148                    };
149
150                    if consumed > first.len() {
151                        panic!("Read claimed to consume more bytes than it was given!");
152                    }
153
154                    state.readable -= consumed;
155                    state.deque.drain(..consumed);
156                    let target_capacity = std::cmp::max(
157                        state.deque.len().next_multiple_of(BUFFER_TRIM_GRANULARITY),
158                        BUFFER_TRIM_GRANULARITY,
159                    );
160
161                    if target_capacity <= state.deque.capacity() / 2 {
162                        state.deque.shrink_to(target_capacity);
163                    }
164                    return Ok(ret);
165                }
166
167                if let Status::Closed(reason) = &state.closed {
168                    if state.readable > 0 {
169                        return Err(Error::BufferTooShort(size));
170                    } else {
171                        return Err(Error::ConnectionClosed(reason.clone()));
172                    }
173                }
174
175                let (sender, receiver) = oneshot::channel();
176                state.notify_readable = Some((sender, size));
177                receiver
178            };
179
180            let _ = receiver.await;
181        }
182    }
183
184    /// Read a protocol message from the stream. This is just a quick way to wire
185    /// `ProtocolObject::try_from_bytes` in to `read`.
186    pub async fn read_protocol_message<P: protocol::ProtocolMessage>(&self) -> Result<P> {
187        self.read(P::MIN_SIZE, P::try_from_bytes).await
188    }
189
190    /// This writes the given protocol message to the stream at the *beginning* of the stream,
191    /// meaning that it will be the next thing read off the stream.
192    pub(crate) fn push_back_protocol_message<P: protocol::ProtocolMessage>(
193        &self,
194        message: &P,
195    ) -> Result<()> {
196        let size = message.byte_size();
197        let mut state = self.0.lock().unwrap();
198        let readable = state.readable;
199        state.deque.resize(readable + size, 0);
200        state.deque.rotate_right(size);
201        let (first, _) = state.deque.as_mut_slices();
202
203        let mut first = if first.len() >= size {
204            first
205        } else {
206            state.deque.make_contiguous();
207            state.deque.as_mut_slices().0
208        };
209
210        let got = message.write_bytes(&mut first)?;
211        debug_assert!(got == size);
212        state.readable += size;
213
214        if let Some((sender, size)) = state.notify_readable.take() {
215            if size <= state.readable {
216                let _ = sender.send(());
217            } else {
218                state.notify_readable = Some((sender, size));
219            }
220        }
221
222        Ok(())
223    }
224
225    /// Whether this stream is closed. Returns false so long as there is unread
226    /// data in the buffer, even if the writer has hung up.
227    pub fn is_closed(&self) -> bool {
228        let state = self.0.lock().unwrap();
229        state.closed.is_closed() && state.readable == 0
230    }
231
232    /// Get the reason this reader is closed. If the reader is not closed, or if
233    /// no reason was given, return `None`.
234    pub fn closed_reason(&self) -> Option<String> {
235        let state = self.0.lock().unwrap();
236        state.closed.reason()
237    }
238
239    /// Close this stream, giving a reason for the closure.
240    pub fn close(self, reason: String) {
241        let mut state = self.0.lock().unwrap();
242        match &state.closed {
243            Status::Closed(Some(_)) => (),
244            _ => state.closed = Status::Closed(Some(reason)),
245        }
246    }
247}
248
249impl std::fmt::Debug for Reader {
250    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251        write!(f, "Reader({:?})", Arc::as_ptr(&self.0))
252    }
253}
254
255impl Drop for Reader {
256    fn drop(&mut self) {
257        let mut state = self.0.lock().unwrap();
258        state.closed.close();
259    }
260}
261
262/// Write half of a stream. See `stream()`.
263pub struct Writer(Arc<SyncMutex<State>>);
264
265impl Writer {
266    /// Write data to this stream.
267    ///
268    /// Space for `size` bytes is allocated in the stream immediately, and then the callback is
269    /// invoked with a mutable slice to that region so that it may populate it. The slice given to
270    /// the callback *may* be larger than requested but will never be smaller.
271    ///
272    /// The callback should return `Ok` with the number of bytes actually written, which may be less
273    /// than `size`. If the callback returns an error, that error is returned from `write` as-is.
274    /// Note that we do not specially process `Error::BufferTooSmall` as with `Reader::read`.
275    ///
276    /// Panics if the callback returns a number of bytes greater than the size of the buffer.
277    pub fn write<F>(&self, size: usize, f: F) -> Result<()>
278    where
279        F: FnOnce(&mut [u8]) -> Result<usize>,
280    {
281        let mut state = self.0.lock().unwrap();
282
283        if let Status::Closed(reason) = &state.closed {
284            return Err(Error::ConnectionClosed(reason.clone()));
285        }
286
287        let total_size = state.readable + size;
288
289        if state.deque.len() < total_size {
290            let total_size = std::cmp::max(total_size, state.deque.capacity());
291            state.deque.resize(total_size, 0);
292        }
293
294        let readable = state.readable;
295        let (first, second) = state.deque.as_mut_slices();
296
297        let slice = if first.len() > readable {
298            &mut first[readable..]
299        } else {
300            &mut second[(readable - first.len())..]
301        };
302
303        let slice = if slice.len() >= size {
304            slice
305        } else {
306            state.deque.make_contiguous();
307            &mut state.deque.as_mut_slices().0[readable..]
308        };
309
310        debug_assert!(slice.len() >= size);
311        let size = f(slice)?;
312
313        if size > slice.len() {
314            panic!("Write claimed to produce more bytes than buffer had space for!");
315        }
316
317        state.readable += size;
318
319        if let Some((sender, size)) = state.notify_readable.take() {
320            if size <= state.readable {
321                let _ = sender.send(());
322            } else {
323                state.notify_readable = Some((sender, size));
324            }
325        }
326
327        Ok(())
328    }
329
330    /// Write a protocol message to the stream. This is just a quick way to wire
331    /// `ProtocolObject::write_bytes` in to `write`.
332    pub fn write_protocol_message<P: protocol::ProtocolMessage>(&self, message: &P) -> Result<()> {
333        self.write(message.byte_size(), |mut buf| message.write_bytes(&mut buf))
334    }
335
336    /// Close this stream, giving a reason for the closure.
337    pub fn close(self, reason: String) {
338        self.0.lock().unwrap().closed = Status::Closed(Some(reason))
339    }
340
341    /// Whether this stream is closed. Returns false so long as there is unread
342    /// data in the buffer, even if the writer has hung up.
343    pub fn is_closed(&self) -> bool {
344        let state = self.0.lock().unwrap();
345        state.closed.is_closed() && state.readable == 0
346    }
347
348    /// Get the reason this writer is closed. If the writer is not closed, or if
349    /// no reason was given, return `None`.
350    pub fn closed_reason(&self) -> Option<String> {
351        let state = self.0.lock().unwrap();
352        state.closed.reason()
353    }
354}
355
356impl std::fmt::Debug for Writer {
357    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358        write!(f, "Writer({:?})", Arc::as_ptr(&self.0))
359    }
360}
361
362impl Drop for Writer {
363    fn drop(&mut self) {
364        let Some(x) = ({
365            let mut state = self.0.lock().unwrap();
366            state.closed.close();
367
368            state.notify_readable.take()
369        }) else {
370            return;
371        };
372        let _ = x.0.send(());
373    }
374}
375
376/// Creates a unidirectional stream of bytes.
377///
378/// The `Reader` and `Writer` share an expanding ring buffer. This allows sending bytes between
379/// tasks with minimal extra allocations or copies.
380pub fn stream() -> (Reader, Writer) {
381    let reader = Arc::new(SyncMutex::new(State {
382        deque: VecDeque::new(),
383        readable: 0,
384        notify_readable: None,
385        closed: Status::Open,
386    }));
387    let writer = Arc::clone(&reader);
388
389    (Reader(reader), Writer(writer))
390}
391
// Unit tests for the `Reader`/`Writer` stream pair.
#[cfg(test)]
mod test {
    use futures::task::noop_waker;
    use futures::FutureExt;
    use std::future::Future;
    use std::pin::pin;
    use std::task::{Context, Poll};

    use super::*;

    // Minimal `ProtocolMessage` for the tests: a fixed 4-byte message copied verbatim.
    impl protocol::ProtocolMessage for [u8; 4] {
        const MIN_SIZE: usize = 4;
        fn byte_size(&self) -> usize {
            4
        }

        fn write_bytes<W: std::io::Write>(&self, out: &mut W) -> Result<usize> {
            out.write_all(self)?;
            Ok(4)
        }

        // Asks for 4 bytes (via `BufferTooShort`) when given fewer; otherwise consumes exactly 4.
        fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
            if bytes.len() < 4 {
                return Err(Error::BufferTooShort(4));
            }

            Ok((bytes[..4].try_into().unwrap(), 4))
        }
    }

    // Bytes come out in the order written, across interleaved writes and partial reads.
    #[fuchsia::test]
    async fn stream_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        // The 4 leftover bytes from the first write are followed by the 2 new ones.
        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
    }

    // A pushed-back message is placed ahead of both already-buffered data and data written
    // afterwards.
    #[fuchsia::test]
    async fn push_back_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        reader.push_back_protocol_message(&[4, 3, 2, 1]).unwrap();

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        // Copies all 10 bytes but reports only 6 consumed; the assertion checks ordering only.
        let got = reader.read(10, |buf| Ok((buf[..10].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![4, 3, 2, 1, 5, 6, 7, 8, 9, 10], got);
    }

    // Once the reader is dropped, writes fail with `ConnectionClosed` carrying no reason.
    #[fuchsia::test]
    async fn writer_sees_close() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        std::mem::drop(reader);

        assert!(matches!(
            writer.write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            }),
            Err(Error::ConnectionClosed(None))
        ));
    }

    // After the writer is dropped, buffered data stays readable. Asking for more than remains
    // yields `BufferTooShort`, an exact-size read succeeds, and reading a drained stream yields
    // `ConnectionClosed`.
    #[fuchsia::test]
    async fn reader_sees_closed() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        std::mem::drop(writer);

        // 6 bytes remain; asking for 7 on a closed stream cannot be satisfied.
        assert!(matches!(reader.read(7, |_| Ok(((), 1))).await, Err(Error::BufferTooShort(7))));

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
        assert!(matches!(
            reader.read(1, |_| Ok(((), 1))).await,
            Err(Error::ConnectionClosed(None))
        ));
    }

    // A read that is already pending completes with `ConnectionClosed` once the writer is
    // dropped. The future is polled by hand with a no-op waker.
    #[fuchsia::test]
    async fn reader_sees_closed_when_polling() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let fut = reader
            .read(1, |_| -> Result<((), usize)> { panic!("This read should never succeed!") });
        let mut fut = std::pin::pin!(fut);

        // First poll: the buffer is empty, so the read registers for wakeup and stays pending.
        assert!(fut.poll_unpin(&mut Context::from_waker(&noop_waker())).is_pending());

        std::mem::drop(writer);

        // Second poll observes the closure.
        assert!(matches!(
            fut.poll_unpin(&mut Context::from_waker(&noop_waker())),
            Poll::Ready(Err(Error::ConnectionClosed(None)))
        ));
    }

    // Same scenario inside a spawned task, so the executor drives the wakeup. The writer is
    // dropped from inside the first poll, right after the read has gone pending.
    #[fuchsia::test]
    async fn reader_sees_closed_separate_task() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let (sender, receiver) = oneshot::channel();
        let task = fuchsia_async::Task::spawn(async move {
            let fut = reader.read(1, |_| Ok(((), 1)));
            let mut fut = std::pin::pin!(fut);
            let mut writer = Some(writer);
            let fut = futures::future::poll_fn(move |cx| {
                let ret = fut.as_mut().poll(cx);

                // Only on the first poll: the read must be pending, and the writer is dropped.
                if writer.take().is_some() {
                    assert!(matches!(ret, Poll::Pending));
                }

                ret
            });
            assert!(matches!(fut.await, Err(Error::ConnectionClosed(None))));
            sender.send(()).unwrap();
        });

        receiver.await.unwrap();
        task.await;
    }

    // A callback that answers `BufferTooShort(4)` to a 2-byte slice makes `read` wait for more
    // data and call it again once 4 bytes are available. The oneshot sequences the writer's
    // second write after the callback's first (rejected) invocation.
    #[fuchsia::test]
    async fn reader_buffer_too_short() {
        let (reader, writer) = stream();
        let (sender, receiver) = oneshot::channel();
        let mut sender = Some(sender);

        let reader_task = async move {
            let got = reader
                .read(1, |buf| {
                    if buf.len() != 4 {
                        sender.take().unwrap().send(buf.len()).unwrap();
                        Err(Error::BufferTooShort(4))
                    } else {
                        Ok((buf[..4].to_vec(), 4))
                    }
                })
                .await
                .unwrap();
            assert_eq!(vec![1, 2, 3, 4], got);
        };

        let writer_task = async move {
            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[1, 2]);
                    Ok(2)
                })
                .unwrap();

            assert_eq!(2, receiver.await.unwrap());

            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[3, 4]);
                    Ok(2)
                })
                .unwrap();
        };

        futures::future::join(pin!(reader_task), pin!(writer_task)).await;
    }
}
640}