circuit/
stream.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::error::{Error, Result};
6use crate::protocol;
7
8use std::collections::VecDeque;
9use std::sync::{Arc, Mutex as SyncMutex};
10use std::task::{ready, Context, Poll, Waker};
11
/// Granularity (1 MiB) used when trimming the stream's ring buffer.
///
/// After a read, the buffer's target capacity is its length rounded up to a
/// multiple of this value (and never below it). The buffer is shrunk to that
/// target once its capacity is at least twice as large.
///
/// 1 MiB was found experimentally to give acceptable performance on tests.
const BUFFER_TRIM_GRANULARITY: usize = 1 << 20;
18
19/// BUFFER_TRIM_GRANULARITY should be a power of 2.
20const _: () = assert!(BUFFER_TRIM_GRANULARITY.is_power_of_two());
21
/// Indicates whether a stream is open or closed, and if closed, why it closed.
#[derive(Debug, Clone)]
enum Status {
    /// Stream is open.
    Open,
    /// Stream is closed. The argument holds the reason for closure, or `None`
    /// when the stream was closed without one (e.g. by dropping a half).
    Closed(Option<String>),
}
30
31impl Status {
32    fn is_closed(&self) -> bool {
33        match self {
34            Status::Open => false,
35            Status::Closed(_) => true,
36        }
37    }
38
39    fn reason(&self) -> Option<String> {
40        match self {
41            Status::Open => None,
42            Status::Closed(x) => x.clone(),
43        }
44    }
45
46    fn close(&mut self) {
47        if let Status::Open = self {
48            *self = Status::Closed(None);
49        }
50    }
51}
52
/// Internal state of a stream. See `stream()`.
///
/// One `State` is shared by a `Reader` and its `Writer` behind an
/// `Arc<SyncMutex<_>>`.
#[derive(Debug)]
struct State {
    /// The ring buffer itself. The first `readable` bytes are data waiting to be
    /// read; any bytes after that are initialized but not yet written.
    deque: VecDeque<u8>,
    /// How many bytes are readable. This is different from the length of the deque, as we may allow
    /// bytes to be in the deque that are "initialized but unavailable." Mostly that just
    /// accommodates a quirk of Rust's memory model; we have to initialize all bytes before we show
    /// them to the user, even if they're there just to be overwritten, and if we pop the bytes out
    /// of the deque Rust counts them as uninitialized again, so to avoid duplicating the
    /// initialization process we just leave the initialized-but-unwritten bytes in the deque.
    readable: usize,
    /// If the reader needs to sleep, it puts a waker here so it can be woken up
    /// again. It also lists how many bytes should be available before it should
    /// be woken up.
    notify_readable: Option<(Waker, usize)>,
    /// Whether this stream is closed, and why. Set when either the `Reader` or
    /// the `Writer` is dropped or explicitly closed.
    closed: Status,
}
72
/// Read half of a stream. See `stream()`.
///
/// Dropping the reader marks the stream closed, after which writes fail.
pub struct Reader(Arc<SyncMutex<State>>);
75
76impl Reader {
77    /// Debug
78    pub fn inspect_shutdown(&self) -> String {
79        let lock = self.0.lock().unwrap();
80        if lock.closed.is_closed() {
81            lock.closed.reason().unwrap_or_else(|| "No epitaph".to_owned())
82        } else {
83            "Not closed".to_owned()
84        }
85    }
86
87    /// Read bytes from the stream.
88    ///
89    /// The reader will wait until there are *at least* `size` bytes to read,
90    /// Then it will call the given callback with a slice containing all
91    /// available bytes to read.
92    ///
93    /// If the callback processes data successfully, it should return `Ok` with
94    /// a tuple containing a value of the user's choice, and the number of bytes
95    /// used. If the number of bytes returned from the callback is less than
96    /// what was available in the buffer, the unused bytes will appear at the
97    /// start of the buffer for subsequent read calls. It is allowable to `peek`
98    /// at the bytes in the buffer by returning a number of bytes read that is
99    /// smaller than the number of bytes actually used.
100    ///
101    /// If the callback returns `Error::BufferTooShort` and the expected buffer
102    /// value contained in the error is larger than the data that was provided,
103    /// we will wait again until there are enough bytes to satisfy the error and
104    /// then call the callback again. If the callback returns
105    /// `Error::BufferTooShort` but the buffer should have been long enough
106    /// according to the error, `Error::CallbackRejectedBuffer` is returned.
107    /// Other errors from the callback are returned as-is from `read` itself.
108    ///
109    /// If there are no bytes available to read and the `Writer` for this stream
110    /// has already been dropped, `read` returns `Error::ConnectionClosed`. If
111    /// there are *not enough* bytes available to be read and the `Writer` has
112    /// been dropped, `read` returns `Error::BufferTooSmall`. This is the only
113    /// time `read` should return `Error::BufferTooSmall`.
114    ///
115    /// Panics if the callback returns a number of bytes greater than the size
116    /// of the buffer.
117    pub async fn read<F, U>(&self, mut size: usize, mut f: F) -> Result<U>
118    where
119        F: FnMut(&[u8]) -> Result<(U, usize)>,
120    {
121        let mut f = move |_: &mut Context<'_>, b: &'_ [u8]| Poll::Ready(f(b));
122        futures::future::poll_fn(move |ctx| self.poll_read(ctx, &mut size, &mut f)).await
123    }
124
125    /// Like `read` but a poll function rather than a future. Passes the
126    /// `Context` to the callback so the callback can also poll.
127    ///
128    /// The `size` argument is a mutable reference so that if we determine
129    /// during reading that more data is needed, we can update the size. It
130    /// should always increase. The value of `size` is not meaningful after
131    /// returning `Poll::Ready`.
132    pub fn poll_read<F, U>(
133        &self,
134        ctx: &mut Context<'_>,
135        size: &mut usize,
136        mut f: F,
137    ) -> Poll<Result<U>>
138    where
139        F: FnMut(&mut Context<'_>, &[u8]) -> Poll<Result<(U, usize)>>,
140    {
141        let mut state = self.0.lock().unwrap();
142
143        if let Status::Closed(reason) = &state.closed && *size == 0 {
144            return Poll::Ready(Err(Error::ConnectionClosed(reason.clone())));
145        }
146
147        if state.readable >= *size {
148            let (first, _) = state.deque.as_slices();
149
150            let first = if first.len() >= *size {
151                first
152            } else {
153                state.deque.make_contiguous();
154                state.deque.as_slices().0
155            };
156
157            debug_assert!(first.len() >= *size);
158
159            let first = &first[..std::cmp::min(first.len(), state.readable)];
160            let (ret, consumed) = match ready!(f(ctx, first)) {
161                Err(Error::BufferTooShort(s)) => {
162                    if s < first.len() {
163                        return Poll::Ready(Err(Error::CallbackRejectedBuffer(s, first.len())));
164                    }
165
166                    *size = s;
167                    ctx.waker().wake_by_ref();
168                    return Poll::Pending;
169                }
170                other => other?,
171            };
172
173            if consumed > first.len() {
174                panic!("Read claimed to consume more bytes than it was given!");
175            }
176
177            state.readable -= consumed;
178            state.deque.drain(..consumed);
179            let target_capacity = std::cmp::max(
180                state.deque.len().next_multiple_of(BUFFER_TRIM_GRANULARITY),
181                BUFFER_TRIM_GRANULARITY,
182            );
183
184            if target_capacity <= state.deque.capacity() / 2 {
185                state.deque.shrink_to(target_capacity);
186            }
187            return Poll::Ready(Ok(ret));
188        }
189
190        if let Status::Closed(reason) = &state.closed {
191            if state.readable > 0 {
192                return Poll::Ready(Err(Error::BufferTooShort(*size)));
193            } else {
194                return Poll::Ready(Err(Error::ConnectionClosed(reason.clone())));
195            }
196        }
197
198        state.notify_readable = Some((ctx.waker().clone(), *size));
199        Poll::Pending
200    }
201
202    /// Read a protocol message from the stream. This is just a quick way to wire
203    /// `ProtocolObject::try_from_bytes` in to `read`.
204    pub async fn read_protocol_message<P: protocol::ProtocolMessage>(&self) -> Result<P> {
205        self.read(P::MIN_SIZE, P::try_from_bytes).await
206    }
207
208    /// This writes the given protocol message to the stream at the *beginning* of the stream,
209    /// meaning that it will be the next thing read off the stream.
210    pub(crate) fn push_back_protocol_message<P: protocol::ProtocolMessage>(
211        &self,
212        message: &P,
213    ) -> Result<()> {
214        let size = message.byte_size();
215        let mut state = self.0.lock().unwrap();
216        let readable = state.readable;
217        state.deque.resize(readable + size, 0);
218        state.deque.rotate_right(size);
219        let (first, _) = state.deque.as_mut_slices();
220
221        let mut first = if first.len() >= size {
222            first
223        } else {
224            state.deque.make_contiguous();
225            state.deque.as_mut_slices().0
226        };
227
228        let got = message.write_bytes(&mut first)?;
229        debug_assert!(got == size);
230        state.readable += size;
231
232        if let Some((waker, size)) = state.notify_readable.take() {
233            if size <= state.readable {
234                waker.wake();
235            } else {
236                state.notify_readable = Some((waker, size));
237            }
238        }
239
240        Ok(())
241    }
242
243    /// Whether this stream is closed. Returns false so long as there is unread
244    /// data in the buffer, even if the writer has hung up.
245    pub fn is_closed(&self) -> bool {
246        let state = self.0.lock().unwrap();
247        state.closed.is_closed() && state.readable == 0
248    }
249
250    /// Get the reason this reader is closed. If the reader is not closed, or if
251    /// no reason was given, return `None`.
252    pub fn closed_reason(&self) -> Option<String> {
253        let state = self.0.lock().unwrap();
254        state.closed.reason()
255    }
256
257    /// Close this stream, giving a reason for the closure.
258    pub fn close(self, reason: String) {
259        let mut state = self.0.lock().unwrap();
260        match &state.closed {
261            Status::Closed(Some(_)) => (),
262            _ => state.closed = Status::Closed(Some(reason)),
263        }
264    }
265}
266
267impl std::fmt::Debug for Reader {
268    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269        write!(f, "Reader({:?})", Arc::as_ptr(&self.0))
270    }
271}
272
273impl Drop for Reader {
274    fn drop(&mut self) {
275        let mut state = self.0.lock().unwrap();
276        state.closed.close();
277    }
278}
279
/// Write half of a stream. See `stream()`.
///
/// Dropping the writer closes the stream and wakes any pending reader. The
/// reader can still drain bytes that were already written.
pub struct Writer(Arc<SyncMutex<State>>);
282
283impl Writer {
284    /// Write data to this stream.
285    ///
286    /// Space for `size` bytes is allocated in the stream immediately, and then the callback is
287    /// invoked with a mutable slice to that region so that it may populate it. The slice given to
288    /// the callback *may* be larger than requested but will never be smaller.
289    ///
290    /// The callback should return `Ok` with the number of bytes actually written, which may be less
291    /// than `size`. If the callback returns an error, that error is returned from `write` as-is.
292    /// Note that we do not specially process `Error::BufferTooSmall` as with `Reader::read`.
293    ///
294    /// Panics if the callback returns a number of bytes greater than the size of the buffer.
295    pub fn write<F>(&self, size: usize, f: F) -> Result<()>
296    where
297        F: FnOnce(&mut [u8]) -> Result<usize>,
298    {
299        let mut state = self.0.lock().unwrap();
300
301        if let Status::Closed(reason) = &state.closed {
302            return Err(Error::ConnectionClosed(reason.clone()));
303        }
304
305        let total_size = state.readable + size;
306
307        if state.deque.len() < total_size {
308            let total_size = std::cmp::max(total_size, state.deque.capacity());
309            state.deque.resize(total_size, 0);
310        }
311
312        let readable = state.readable;
313        let (first, second) = state.deque.as_mut_slices();
314
315        let slice = if first.len() > readable {
316            &mut first[readable..]
317        } else {
318            &mut second[(readable - first.len())..]
319        };
320
321        let slice = if slice.len() >= size {
322            slice
323        } else {
324            state.deque.make_contiguous();
325            &mut state.deque.as_mut_slices().0[readable..]
326        };
327
328        debug_assert!(slice.len() >= size);
329        let size = f(slice)?;
330
331        if size > slice.len() {
332            panic!("Write claimed to produce more bytes than buffer had space for!");
333        }
334
335        state.readable += size;
336
337        if let Some((waker, size)) = state.notify_readable.take() {
338            if size <= state.readable {
339                waker.wake();
340            } else {
341                state.notify_readable = Some((waker, size));
342            }
343        }
344
345        Ok(())
346    }
347
348    /// Write a protocol message to the stream. This is just a quick way to wire
349    /// `ProtocolObject::write_bytes` in to `write`.
350    pub fn write_protocol_message<P: protocol::ProtocolMessage>(&self, message: &P) -> Result<()> {
351        self.write(message.byte_size(), |mut buf| message.write_bytes(&mut buf))
352    }
353
354    /// Close this stream, giving a reason for the closure.
355    pub fn close(self, reason: String) {
356        self.0.lock().unwrap().closed = Status::Closed(Some(reason))
357    }
358
359    /// Whether this stream is closed. Returns false so long as there is unread
360    /// data in the buffer, even if the writer has hung up.
361    pub fn is_closed(&self) -> bool {
362        let state = self.0.lock().unwrap();
363        state.closed.is_closed() && state.readable == 0
364    }
365
366    /// Get the reason this writer is closed. If the writer is not closed, or if
367    /// no reason was given, return `None`.
368    pub fn closed_reason(&self) -> Option<String> {
369        let state = self.0.lock().unwrap();
370        state.closed.reason()
371    }
372}
373
374impl std::fmt::Debug for Writer {
375    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376        write!(f, "Writer({:?})", Arc::as_ptr(&self.0))
377    }
378}
379
380impl Drop for Writer {
381    fn drop(&mut self) {
382        let mut state = self.0.lock().unwrap();
383        state.closed.close();
384
385        if let Some((waker, _)) = state.notify_readable.take() {
386            waker.wake();
387        }
388    }
389}
390
391/// Creates a unidirectional stream of bytes.
392///
393/// The `Reader` and `Writer` share an expanding ring buffer. This allows
394/// sending bytes between tasks with minimal extra allocations or copies.
395pub fn stream() -> (Reader, Writer) {
396    let reader = Arc::new(SyncMutex::new(State {
397        deque: VecDeque::new(),
398        readable: 0,
399        notify_readable: None,
400        closed: Status::Open,
401    }));
402    let writer = Arc::clone(&reader);
403
404    (Reader(reader), Writer(writer))
405}
406
/// Unit tests for the `Reader`/`Writer` stream pair.
#[cfg(test)]
mod test {
    use futures::channel::oneshot;
    use futures::task::noop_waker;
    use futures::FutureExt;
    use std::future::Future;
    use std::pin::pin;
    use std::task::{Context, Poll};

    use super::*;

    // Test-only `ProtocolMessage` impl: a `[u8; 4]` is a fixed-size 4-byte
    // message, so push-back and `read_protocol_message` can be exercised.
    impl protocol::ProtocolMessage for [u8; 4] {
        const MIN_SIZE: usize = 4;
        fn byte_size(&self) -> usize {
            4
        }

        fn write_bytes<W: std::io::Write>(&self, out: &mut W) -> Result<usize> {
            out.write_all(self)?;
            Ok(4)
        }

        fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
            if bytes.len() < 4 {
                return Err(Error::BufferTooShort(4));
            }

            Ok((bytes[..4].try_into().unwrap(), 4))
        }
    }

    // Basic round trip: bytes come out in the order written, across two
    // writes and two reads.
    #[fuchsia::test]
    async fn stream_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
    }

    // A pushed-back message appears ahead of the unread data. The final read
    // looks at 10 bytes but reports consuming only 6.
    #[fuchsia::test]
    async fn push_back_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        reader.push_back_protocol_message(&[4, 3, 2, 1]).unwrap();

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        let got = reader.read(10, |buf| Ok((buf[..10].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![4, 3, 2, 1, 5, 6, 7, 8, 9, 10], got);
    }

    // Once the reader is dropped, further writes fail with a reason-less
    // `ConnectionClosed`.
    #[fuchsia::test]
    async fn writer_sees_close() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        std::mem::drop(reader);

        assert!(matches!(
            writer.write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            }),
            Err(Error::ConnectionClosed(None))
        ));
    }

    // After the writer is dropped, the reader behaves as follows:
    // - a request larger than the buffered data gets `BufferTooShort`;
    // - the buffered bytes can still be drained;
    // - once the buffer is empty, reads get `ConnectionClosed`.
    #[fuchsia::test]
    async fn reader_sees_closed() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        std::mem::drop(writer);

        assert!(matches!(reader.read(7, |_| Ok(((), 1))).await, Err(Error::BufferTooShort(7))));

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
        assert!(matches!(
            reader.read(1, |_| Ok(((), 1))).await,
            Err(Error::ConnectionClosed(None))
        ));
    }

    // A read that is already pending on an empty buffer resolves to
    // `ConnectionClosed` once the writer is dropped. The callback must never run.
    #[fuchsia::test]
    async fn reader_sees_closed_when_polling() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let fut = reader
            .read(1, |_| -> Result<((), usize)> { panic!("This read should never succeed!") });
        let mut fut = std::pin::pin!(fut);

        assert!(fut.poll_unpin(&mut Context::from_waker(&noop_waker())).is_pending());

        std::mem::drop(writer);

        assert!(matches!(
            fut.poll_unpin(&mut Context::from_waker(&noop_waker())),
            Poll::Ready(Err(Error::ConnectionClosed(None)))
        ));
    }

    // Same scenario in a spawned task. The writer is dropped from inside the
    // reader's own poll, right after the first poll returns `Pending`. The
    // waker registered by that poll must then be woken so the task finishes.
    #[fuchsia::test]
    async fn reader_sees_closed_separate_task() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let (sender, receiver) = oneshot::channel();
        let task = fuchsia_async::Task::spawn(async move {
            let fut = reader.read(1, |_| Ok(((), 1)));
            let mut fut = std::pin::pin!(fut);
            let mut writer = Some(writer);
            let fut = futures::future::poll_fn(move |cx| {
                let ret = fut.as_mut().poll(cx);

                // Only on the first poll: drop the writer and check that the
                // read had been pending.
                if writer.take().is_some() {
                    assert!(matches!(ret, Poll::Pending));
                }

                ret
            });
            assert!(matches!(fut.await, Err(Error::ConnectionClosed(None))));
            sender.send(()).unwrap();
        });

        receiver.await.unwrap();
        task.await;
    }

    // The callback rejects a 2-byte buffer with `BufferTooShort(4)`. The read
    // then waits for 4 bytes and retries. `sender.take().unwrap()` also checks
    // that the short-buffer path runs only once.
    #[fuchsia::test]
    async fn reader_buffer_too_short() {
        let (reader, writer) = stream();
        let (sender, receiver) = oneshot::channel();
        let mut sender = Some(sender);

        let reader_task = async move {
            let got = reader
                .read(1, |buf| {
                    if buf.len() != 4 {
                        sender.take().unwrap().send(buf.len()).unwrap();
                        Err(Error::BufferTooShort(4))
                    } else {
                        Ok((buf[..4].to_vec(), 4))
                    }
                })
                .await
                .unwrap();
            assert_eq!(vec![1, 2, 3, 4], got);
        };

        let writer_task = async move {
            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[1, 2]);
                    Ok(2)
                })
                .unwrap();

            // Wait until the reader has seen (and rejected) the 2-byte buffer.
            assert_eq!(2, receiver.await.unwrap());

            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[3, 4]);
                    Ok(2)
                })
                .unwrap();
        };

        futures::future::join(pin!(reader_task), pin!(writer_task)).await;
    }
}