circuit/
stream.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::error::{Error, Result};
6use crate::protocol;
7
8use fuchsia_sync::Mutex as SyncMutex;
9use std::collections::VecDeque;
10use std::sync::Arc;
11use std::task::{Context, Poll, Waker, ready};
12
/// Granularity used when trimming the internal ring buffer after a read.
///
/// The buffer's target capacity is the amount of data it still holds, rounded
/// up to a multiple of this value (and never less than one granule). So the
/// buffer is kept at most about this much larger than its contents.
///
/// 1MiB (2^20 bytes) was found experimentally to give acceptable performance
/// on tests.
const BUFFER_TRIM_GRANULARITY: usize = 1 << 20;

/// Compile-time check that `BUFFER_TRIM_GRANULARITY` is a power of two.
const _: () = assert!(BUFFER_TRIM_GRANULARITY.is_power_of_two());
22
/// Open/closed state of a stream; once closed it may carry a reason ("epitaph").
#[derive(Debug, Clone)]
enum Status {
    /// The stream is still open.
    Open,
    /// The stream has been closed, optionally with a human-readable reason.
    Closed(Option<String>),
}

impl Status {
    /// Returns `true` once the stream has been closed, with or without a reason.
    fn is_closed(&self) -> bool {
        matches!(self, Status::Closed(_))
    }

    /// Returns a copy of the closure reason. Yields `None` while the stream is
    /// open, and also when it was closed without a reason.
    fn reason(&self) -> Option<String> {
        if let Status::Closed(why) = self { why.clone() } else { None }
    }

    /// Marks the stream closed without a reason. If it is already closed, the
    /// existing status (including any reason) is left untouched.
    fn close(&mut self) {
        if !self.is_closed() {
            *self = Status::Closed(None);
        }
    }
}
53
54/// Internal state of a stream. See `stream()`.
55#[derive(Debug)]
56struct State {
57    /// The ring buffer itself.
58    deque: VecDeque<u8>,
59    /// How many bytes are readable. This is different from the length of the deque, as we may allow
60    /// bytes to be in the deque that are "initialized but unavailable." Mostly that just
61    /// accommodates a quirk of Rust's memory model; we have to initialize all bytes before we show
62    /// them to the user, even if they're there just to be overwritten, and if we pop the bytes out
63    /// of the deque Rust counts them as uninitialized again, so to avoid duplicating the
64    /// initialization process we just leave the initialized-but-unwritten bytes in the deque.
65    readable: usize,
66    /// If the reader needs to sleep, it puts a waker here so it can be woken up
67    /// again. It also lists how many bytes should be available before it should
68    /// be woken up.
69    notify_readable: Option<(Waker, usize)>,
70    /// Whether this stream is closed. I.e. whether either the `Reader` or `Writer` has been dropped.
71    closed: Status,
72}
73
74/// Read half of a stream. See `stream()`.
75pub struct Reader(Arc<SyncMutex<State>>);
76
77impl Reader {
78    /// Debug
79    pub fn inspect_shutdown(&self) -> String {
80        let lock = self.0.lock();
81        if lock.closed.is_closed() {
82            lock.closed.reason().unwrap_or_else(|| "No epitaph".to_owned())
83        } else {
84            "Not closed".to_owned()
85        }
86    }
87
88    /// Read bytes from the stream.
89    ///
90    /// The reader will wait until there are *at least* `size` bytes to read,
91    /// Then it will call the given callback with a slice containing all
92    /// available bytes to read.
93    ///
94    /// If the callback processes data successfully, it should return `Ok` with
95    /// a tuple containing a value of the user's choice, and the number of bytes
96    /// used. If the number of bytes returned from the callback is less than
97    /// what was available in the buffer, the unused bytes will appear at the
98    /// start of the buffer for subsequent read calls. It is allowable to `peek`
99    /// at the bytes in the buffer by returning a number of bytes read that is
100    /// smaller than the number of bytes actually used.
101    ///
102    /// If the callback returns `Error::BufferTooShort` and the expected buffer
103    /// value contained in the error is larger than the data that was provided,
104    /// we will wait again until there are enough bytes to satisfy the error and
105    /// then call the callback again. If the callback returns
106    /// `Error::BufferTooShort` but the buffer should have been long enough
107    /// according to the error, `Error::CallbackRejectedBuffer` is returned.
108    /// Other errors from the callback are returned as-is from `read` itself.
109    ///
110    /// If there are no bytes available to read and the `Writer` for this stream
111    /// has already been dropped, `read` returns `Error::ConnectionClosed`. If
112    /// there are *not enough* bytes available to be read and the `Writer` has
113    /// been dropped, `read` returns `Error::BufferTooSmall`. This is the only
114    /// time `read` should return `Error::BufferTooSmall`.
115    ///
116    /// Panics if the callback returns a number of bytes greater than the size
117    /// of the buffer.
118    pub async fn read<F, U>(&self, mut size: usize, mut f: F) -> Result<U>
119    where
120        F: FnMut(&[u8]) -> Result<(U, usize)>,
121    {
122        let mut f = move |_: &mut Context<'_>, b: &'_ [u8]| Poll::Ready(f(b));
123        futures::future::poll_fn(move |ctx| self.poll_read(ctx, &mut size, &mut f)).await
124    }
125
126    /// Like `read` but a poll function rather than a future. Passes the
127    /// `Context` to the callback so the callback can also poll.
128    ///
129    /// The `size` argument is a mutable reference so that if we determine
130    /// during reading that more data is needed, we can update the size. It
131    /// should always increase. The value of `size` is not meaningful after
132    /// returning `Poll::Ready`.
133    pub fn poll_read<F, U>(
134        &self,
135        ctx: &mut Context<'_>,
136        size: &mut usize,
137        mut f: F,
138    ) -> Poll<Result<U>>
139    where
140        F: FnMut(&mut Context<'_>, &[u8]) -> Poll<Result<(U, usize)>>,
141    {
142        let mut state = self.0.lock();
143
144        if let Status::Closed(reason) = &state.closed
145            && *size == 0
146        {
147            return Poll::Ready(Err(Error::ConnectionClosed(reason.clone())));
148        }
149
150        if state.readable >= *size {
151            let (first, _) = state.deque.as_slices();
152
153            let first = if first.len() >= *size {
154                first
155            } else {
156                state.deque.make_contiguous();
157                state.deque.as_slices().0
158            };
159
160            debug_assert!(first.len() >= *size);
161
162            let first = &first[..std::cmp::min(first.len(), state.readable)];
163            let (ret, consumed) = match ready!(f(ctx, first)) {
164                Err(Error::BufferTooShort(s)) => {
165                    if s < first.len() {
166                        return Poll::Ready(Err(Error::CallbackRejectedBuffer(s, first.len())));
167                    }
168
169                    *size = s;
170                    ctx.waker().wake_by_ref();
171                    return Poll::Pending;
172                }
173                other => other?,
174            };
175
176            if consumed > first.len() {
177                panic!("Read claimed to consume more bytes than it was given!");
178            }
179
180            state.readable -= consumed;
181            state.deque.drain(..consumed);
182            let target_capacity = std::cmp::max(
183                state.deque.len().next_multiple_of(BUFFER_TRIM_GRANULARITY),
184                BUFFER_TRIM_GRANULARITY,
185            );
186
187            if target_capacity <= state.deque.capacity() / 2 {
188                state.deque.shrink_to(target_capacity);
189            }
190            return Poll::Ready(Ok(ret));
191        }
192
193        if let Status::Closed(reason) = &state.closed {
194            if state.readable > 0 {
195                return Poll::Ready(Err(Error::BufferTooShort(*size)));
196            } else {
197                return Poll::Ready(Err(Error::ConnectionClosed(reason.clone())));
198            }
199        }
200
201        state.notify_readable = Some((ctx.waker().clone(), *size));
202        Poll::Pending
203    }
204
205    /// Read a protocol message from the stream. This is just a quick way to wire
206    /// `ProtocolObject::try_from_bytes` in to `read`.
207    pub async fn read_protocol_message<P: protocol::ProtocolMessage>(&self) -> Result<P> {
208        self.read(P::MIN_SIZE, P::try_from_bytes).await
209    }
210
211    /// This writes the given protocol message to the stream at the *beginning* of the stream,
212    /// meaning that it will be the next thing read off the stream.
213    pub(crate) fn push_back_protocol_message<P: protocol::ProtocolMessage>(
214        &self,
215        message: &P,
216    ) -> Result<()> {
217        let size = message.byte_size();
218        let mut state = self.0.lock();
219        let readable = state.readable;
220        state.deque.resize(readable + size, 0);
221        state.deque.rotate_right(size);
222        let (first, _) = state.deque.as_mut_slices();
223
224        let mut first = if first.len() >= size {
225            first
226        } else {
227            state.deque.make_contiguous();
228            state.deque.as_mut_slices().0
229        };
230
231        let got = message.write_bytes(&mut first)?;
232        debug_assert!(got == size);
233        state.readable += size;
234
235        if let Some((waker, size)) = state.notify_readable.take() {
236            if size <= state.readable {
237                waker.wake();
238            } else {
239                state.notify_readable = Some((waker, size));
240            }
241        }
242
243        Ok(())
244    }
245
246    /// Whether this stream is closed. Returns false so long as there is unread
247    /// data in the buffer, even if the writer has hung up.
248    pub fn is_closed(&self) -> bool {
249        let state = self.0.lock();
250        state.closed.is_closed() && state.readable == 0
251    }
252
253    /// Get the reason this reader is closed. If the reader is not closed, or if
254    /// no reason was given, return `None`.
255    pub fn closed_reason(&self) -> Option<String> {
256        let state = self.0.lock();
257        state.closed.reason()
258    }
259
260    /// Close this stream, giving a reason for the closure.
261    pub fn close(self, reason: String) {
262        let mut state = self.0.lock();
263        match &state.closed {
264            Status::Closed(Some(_)) => (),
265            _ => state.closed = Status::Closed(Some(reason)),
266        }
267    }
268}
269
270impl std::fmt::Debug for Reader {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        write!(f, "Reader({:?})", Arc::as_ptr(&self.0))
273    }
274}
275
276impl Drop for Reader {
277    fn drop(&mut self) {
278        let mut state = self.0.lock();
279        state.closed.close();
280    }
281}
282
283/// Write half of a stream. See `stream()`.
284pub struct Writer(Arc<SyncMutex<State>>);
285
286impl Writer {
287    /// Write data to this stream.
288    ///
289    /// Space for `size` bytes is allocated in the stream immediately, and then the callback is
290    /// invoked with a mutable slice to that region so that it may populate it. The slice given to
291    /// the callback *may* be larger than requested but will never be smaller.
292    ///
293    /// The callback should return `Ok` with the number of bytes actually written, which may be less
294    /// than `size`. If the callback returns an error, that error is returned from `write` as-is.
295    /// Note that we do not specially process `Error::BufferTooSmall` as with `Reader::read`.
296    ///
297    /// Panics if the callback returns a number of bytes greater than the size of the buffer.
298    pub fn write<F>(&self, size: usize, f: F) -> Result<()>
299    where
300        F: FnOnce(&mut [u8]) -> Result<usize>,
301    {
302        let mut state = self.0.lock();
303
304        if let Status::Closed(reason) = &state.closed {
305            return Err(Error::ConnectionClosed(reason.clone()));
306        }
307
308        let total_size = state.readable + size;
309
310        if state.deque.len() < total_size {
311            let total_size = std::cmp::max(total_size, state.deque.capacity());
312            state.deque.resize(total_size, 0);
313        }
314
315        let readable = state.readable;
316        let (first, second) = state.deque.as_mut_slices();
317
318        let slice = if first.len() > readable {
319            &mut first[readable..]
320        } else {
321            &mut second[(readable - first.len())..]
322        };
323
324        let slice = if slice.len() >= size {
325            slice
326        } else {
327            state.deque.make_contiguous();
328            &mut state.deque.as_mut_slices().0[readable..]
329        };
330
331        debug_assert!(slice.len() >= size);
332        let size = f(slice)?;
333
334        if size > slice.len() {
335            panic!("Write claimed to produce more bytes than buffer had space for!");
336        }
337
338        state.readable += size;
339
340        if let Some((waker, size)) = state.notify_readable.take() {
341            if size <= state.readable {
342                waker.wake();
343            } else {
344                state.notify_readable = Some((waker, size));
345            }
346        }
347
348        Ok(())
349    }
350
351    /// Write a protocol message to the stream. This is just a quick way to wire
352    /// `ProtocolObject::write_bytes` in to `write`.
353    pub fn write_protocol_message<P: protocol::ProtocolMessage>(&self, message: &P) -> Result<()> {
354        self.write(message.byte_size(), |mut buf| message.write_bytes(&mut buf))
355    }
356
357    /// Close this stream, giving a reason for the closure.
358    pub fn close(self, reason: String) {
359        self.0.lock().closed = Status::Closed(Some(reason))
360    }
361
362    /// Whether this stream is closed. Returns false so long as there is unread
363    /// data in the buffer, even if the writer has hung up.
364    pub fn is_closed(&self) -> bool {
365        let state = self.0.lock();
366        state.closed.is_closed() && state.readable == 0
367    }
368
369    /// Get the reason this writer is closed. If the writer is not closed, or if
370    /// no reason was given, return `None`.
371    pub fn closed_reason(&self) -> Option<String> {
372        let state = self.0.lock();
373        state.closed.reason()
374    }
375}
376
377impl std::fmt::Debug for Writer {
378    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379        write!(f, "Writer({:?})", Arc::as_ptr(&self.0))
380    }
381}
382
383impl Drop for Writer {
384    fn drop(&mut self) {
385        let mut state = self.0.lock();
386        state.closed.close();
387
388        if let Some((waker, _)) = state.notify_readable.take() {
389            waker.wake();
390        }
391    }
392}
393
394/// Creates a unidirectional stream of bytes.
395///
396/// The `Reader` and `Writer` share an expanding ring buffer. This allows
397/// sending bytes between tasks with minimal extra allocations or copies.
398pub fn stream() -> (Reader, Writer) {
399    let reader = Arc::new(SyncMutex::new(State {
400        deque: VecDeque::new(),
401        readable: 0,
402        notify_readable: None,
403        closed: Status::Open,
404    }));
405    let writer = Arc::clone(&reader);
406
407    (Reader(reader), Writer(writer))
408}
409
#[cfg(test)]
mod test {
    use futures::FutureExt;
    use futures::channel::oneshot;
    use futures::task::noop_waker;
    use std::future::Future;
    use std::pin::pin;
    use std::task::{Context, Poll};

    use super::*;

    // Minimal fixed-size (4-byte) protocol message, so `push_back_protocol_message`
    // can be exercised without real protocol types.
    impl protocol::ProtocolMessage for [u8; 4] {
        const MIN_SIZE: usize = 4;
        fn byte_size(&self) -> usize {
            4
        }

        fn write_bytes<W: std::io::Write>(&self, out: &mut W) -> Result<usize> {
            out.write_all(self)?;
            Ok(4)
        }

        fn try_from_bytes(bytes: &[u8]) -> Result<(Self, usize)> {
            if bytes.len() < 4 {
                return Err(Error::BufferTooShort(4));
            }

            Ok((bytes[..4].try_into().unwrap(), 4))
        }
    }

    // Basic round trip: bytes come back in write order across interleaved
    // writes and partial reads.
    #[fuchsia::test]
    async fn stream_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
    }

    // A pushed-back message is read before both the bytes already buffered and
    // the bytes written afterwards. The final read examines 10 bytes but only
    // reports 6 as consumed.
    #[fuchsia::test]
    async fn push_back_test() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        reader.push_back_protocol_message(&[4, 3, 2, 1]).unwrap();

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        let got = reader.read(10, |buf| Ok((buf[..10].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![4, 3, 2, 1, 5, 6, 7, 8, 9, 10], got);
    }

    // After the reader is dropped, writes fail with `ConnectionClosed(None)`.
    #[fuchsia::test]
    async fn writer_sees_close() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        std::mem::drop(reader);

        assert!(matches!(
            writer.write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            }),
            Err(Error::ConnectionClosed(None))
        ));
    }

    // After the writer is dropped with 6 bytes still buffered:
    // - asking for more than is buffered yields `BufferTooShort`;
    // - the buffered bytes can still be read;
    // - once drained, reads yield `ConnectionClosed(None)`.
    #[fuchsia::test]
    async fn reader_sees_closed() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(4, |buf| Ok((buf[..4].to_vec(), 4))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4], got);

        writer
            .write(2, |buf| {
                buf[..2].copy_from_slice(&[9, 10]);
                Ok(2)
            })
            .unwrap();

        std::mem::drop(writer);

        assert!(matches!(reader.read(7, |_| Ok(((), 1))).await, Err(Error::BufferTooShort(7))));

        let got = reader.read(6, |buf| Ok((buf[..6].to_vec(), 6))).await.unwrap();

        assert_eq!(vec![5, 6, 7, 8, 9, 10], got);
        assert!(matches!(
            reader.read(1, |_| Ok(((), 1))).await,
            Err(Error::ConnectionClosed(None))
        ));
    }

    // A read that is already pending (buffer drained) resolves to
    // `ConnectionClosed(None)` on the next poll after the writer is dropped.
    // Polled by hand with a no-op waker.
    #[fuchsia::test]
    async fn reader_sees_closed_when_polling() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let fut = reader
            .read(1, |_| -> Result<((), usize)> { panic!("This read should never succeed!") });
        let mut fut = std::pin::pin!(fut);

        assert!(fut.poll_unpin(&mut Context::from_waker(&noop_waker())).is_pending());

        std::mem::drop(writer);

        assert!(matches!(
            fut.poll_unpin(&mut Context::from_waker(&noop_waker())),
            Poll::Ready(Err(Error::ConnectionClosed(None)))
        ));
    }

    // Same as above, but on a real executor. The wrapper drops the writer right
    // after the first poll (which must be `Pending`); the test relies on the
    // writer's `Drop` waking the parked reader so the task can finish.
    #[fuchsia::test]
    async fn reader_sees_closed_separate_task() {
        let (reader, writer) = stream();
        writer
            .write(8, |buf| {
                buf[..8].copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
                Ok(8)
            })
            .unwrap();

        let got = reader.read(8, |buf| Ok((buf[..8].to_vec(), 8))).await.unwrap();

        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], got);

        let (sender, receiver) = oneshot::channel();
        let task = fuchsia_async::Task::spawn(async move {
            let fut = reader.read(1, |_| Ok(((), 1)));
            let mut fut = std::pin::pin!(fut);
            let mut writer = Some(writer);
            let fut = futures::future::poll_fn(move |cx| {
                let ret = fut.as_mut().poll(cx);

                // The `take()` temporary drops the writer on the first pass only.
                if writer.take().is_some() {
                    assert!(matches!(ret, Poll::Pending));
                }

                ret
            });
            assert!(matches!(fut.await, Err(Error::ConnectionClosed(None))));
            sender.send(()).unwrap();
        });

        receiver.await.unwrap();
        task.await;
    }

    // The callback rejects a 2-byte buffer with `BufferTooShort(4)`; the read
    // then waits until 4 bytes are buffered and calls the callback again. The
    // oneshot makes the writer send the second chunk only after the rejection.
    #[fuchsia::test]
    async fn reader_buffer_too_short() {
        let (reader, writer) = stream();
        let (sender, receiver) = oneshot::channel();
        let mut sender = Some(sender);

        let reader_task = async move {
            let got = reader
                .read(1, |buf| {
                    if buf.len() != 4 {
                        sender.take().unwrap().send(buf.len()).unwrap();
                        Err(Error::BufferTooShort(4))
                    } else {
                        Ok((buf[..4].to_vec(), 4))
                    }
                })
                .await
                .unwrap();
            assert_eq!(vec![1, 2, 3, 4], got);
        };

        let writer_task = async move {
            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[1, 2]);
                    Ok(2)
                })
                .unwrap();

            assert_eq!(2, receiver.await.unwrap());

            writer
                .write(2, |buf| {
                    buf[..2].copy_from_slice(&[3, 4]);
                    Ok(2)
                })
                .unwrap();
        };

        futures::future::join(pin!(reader_task), pin!(writer_task)).await;
    }
}