// futures_util/io/take.rs

1use futures_core::ready;
2use futures_core::task::{Context, Poll};
3use futures_io::{AsyncBufRead, AsyncRead};
4use pin_project_lite::pin_project;
5use std::pin::Pin;
6use std::{cmp, io};
7
8pin_project! {
9    /// Reader for the [`take`](super::AsyncReadExt::take) method.
10    #[derive(Debug)]
11    #[must_use = "readers do nothing unless you `.await` or poll them"]
12    pub struct Take<R> {
13        #[pin]
14        inner: R,
15        limit: u64,
16    }
17}
18
19impl<R: AsyncRead> Take<R> {
20    pub(super) fn new(inner: R, limit: u64) -> Self {
21        Self { inner, limit }
22    }
23
24    /// Returns the remaining number of bytes that can be
25    /// read before this instance will return EOF.
26    ///
27    /// # Note
28    ///
29    /// This instance may reach `EOF` after reading fewer bytes than indicated by
30    /// this method if the underlying [`AsyncRead`] instance reaches EOF.
31    ///
32    /// # Examples
33    ///
34    /// ```
35    /// # futures::executor::block_on(async {
36    /// use futures::io::{AsyncReadExt, Cursor};
37    ///
38    /// let reader = Cursor::new(&b"12345678"[..]);
39    /// let mut buffer = [0; 2];
40    ///
41    /// let mut take = reader.take(4);
42    /// let n = take.read(&mut buffer).await?;
43    ///
44    /// assert_eq!(take.limit(), 2);
45    /// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
46    /// ```
47    pub fn limit(&self) -> u64 {
48        self.limit
49    }
50
51    /// Sets the number of bytes that can be read before this instance will
52    /// return EOF. This is the same as constructing a new `Take` instance, so
53    /// the amount of bytes read and the previous limit value don't matter when
54    /// calling this method.
55    ///
56    /// # Examples
57    ///
58    /// ```
59    /// # futures::executor::block_on(async {
60    /// use futures::io::{AsyncReadExt, Cursor};
61    ///
62    /// let reader = Cursor::new(&b"12345678"[..]);
63    /// let mut buffer = [0; 4];
64    ///
65    /// let mut take = reader.take(4);
66    /// let n = take.read(&mut buffer).await?;
67    ///
68    /// assert_eq!(n, 4);
69    /// assert_eq!(take.limit(), 0);
70    ///
71    /// take.set_limit(10);
72    /// let n = take.read(&mut buffer).await?;
73    /// assert_eq!(n, 4);
74    ///
75    /// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
76    /// ```
77    pub fn set_limit(&mut self, limit: u64) {
78        self.limit = limit
79    }
80
81    delegate_access_inner!(inner, R, ());
82}
83
84impl<R: AsyncRead> AsyncRead for Take<R> {
85    fn poll_read(
86        self: Pin<&mut Self>,
87        cx: &mut Context<'_>,
88        buf: &mut [u8],
89    ) -> Poll<Result<usize, io::Error>> {
90        let this = self.project();
91
92        if *this.limit == 0 {
93            return Poll::Ready(Ok(0));
94        }
95
96        let max = cmp::min(buf.len() as u64, *this.limit) as usize;
97        let n = ready!(this.inner.poll_read(cx, &mut buf[..max]))?;
98        *this.limit -= n as u64;
99        Poll::Ready(Ok(n))
100    }
101}
102
103impl<R: AsyncBufRead> AsyncBufRead for Take<R> {
104    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
105        let this = self.project();
106
107        // Don't call into inner reader at all at EOF because it may still block
108        if *this.limit == 0 {
109            return Poll::Ready(Ok(&[]));
110        }
111
112        let buf = ready!(this.inner.poll_fill_buf(cx)?);
113        let cap = cmp::min(buf.len() as u64, *this.limit) as usize;
114        Poll::Ready(Ok(&buf[..cap]))
115    }
116
117    fn consume(self: Pin<&mut Self>, amt: usize) {
118        let this = self.project();
119
120        // Don't let callers reset the limit by passing an overlarge value
121        let amt = cmp::min(amt as u64, *this.limit) as usize;
122        *this.limit -= amt as u64;
123        this.inner.consume(amt);
124    }
125}