futures_test/
track_closed.rs

1use futures_io::AsyncWrite;
2use futures_sink::Sink;
3use std::{
4    io::{self, IoSlice},
5    pin::Pin,
6    task::{Context, Poll},
7};
8
/// Async wrapper that tracks whether it has been closed.
///
/// The wrapper starts out open. It is marked as closed once a `poll_close`
/// call forwarded to the wrapped object completes with `Poll::Ready(Ok(()))`;
/// the current state can be queried with
/// [`is_closed`](TrackClosed::is_closed).
///
/// See the `track_closed` methods on:
/// * [`SinkTestExt`](crate::sink::SinkTestExt::track_closed)
/// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::track_closed)
#[pin_project::pin_project]
#[derive(Debug)]
pub struct TrackClosed<T> {
    // The wrapped object. Marked `#[pin]` so the trait implementations can
    // obtain a `Pin<&mut T>` to it through the generated projection.
    #[pin]
    inner: T,
    // `false` on construction; set to `true` after the inner `poll_close`
    // reports successful completion.
    closed: bool,
}
21
22impl<T> TrackClosed<T> {
23    pub(crate) fn new(inner: T) -> Self {
24        Self { inner, closed: false }
25    }
26
27    /// Check whether this object has been closed.
28    pub fn is_closed(&self) -> bool {
29        self.closed
30    }
31
32    /// Acquires a reference to the underlying object that this adaptor is
33    /// wrapping.
34    pub fn get_ref(&self) -> &T {
35        &self.inner
36    }
37
38    /// Acquires a mutable reference to the underlying object that this
39    /// adaptor is wrapping.
40    pub fn get_mut(&mut self) -> &mut T {
41        &mut self.inner
42    }
43
44    /// Acquires a pinned mutable reference to the underlying object that
45    /// this adaptor is wrapping.
46    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
47        self.project().inner
48    }
49
50    /// Consumes this adaptor returning the underlying object.
51    pub fn into_inner(self) -> T {
52        self.inner
53    }
54}
55
56impl<T: AsyncWrite> AsyncWrite for TrackClosed<T> {
57    fn poll_write(
58        self: Pin<&mut Self>,
59        cx: &mut Context<'_>,
60        buf: &[u8],
61    ) -> Poll<io::Result<usize>> {
62        if self.is_closed() {
63            return Poll::Ready(Err(io::Error::new(
64                io::ErrorKind::Other,
65                "Attempted to write after stream was closed",
66            )));
67        }
68        self.project().inner.poll_write(cx, buf)
69    }
70
71    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72        if self.is_closed() {
73            return Poll::Ready(Err(io::Error::new(
74                io::ErrorKind::Other,
75                "Attempted to flush after stream was closed",
76            )));
77        }
78        assert!(!self.is_closed());
79        self.project().inner.poll_flush(cx)
80    }
81
82    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
83        if self.is_closed() {
84            return Poll::Ready(Err(io::Error::new(
85                io::ErrorKind::Other,
86                "Attempted to close after stream was closed",
87            )));
88        }
89        let this = self.project();
90        match this.inner.poll_close(cx) {
91            Poll::Ready(Ok(())) => {
92                *this.closed = true;
93                Poll::Ready(Ok(()))
94            }
95            other => other,
96        }
97    }
98
99    fn poll_write_vectored(
100        self: Pin<&mut Self>,
101        cx: &mut Context<'_>,
102        bufs: &[IoSlice<'_>],
103    ) -> Poll<io::Result<usize>> {
104        if self.is_closed() {
105            return Poll::Ready(Err(io::Error::new(
106                io::ErrorKind::Other,
107                "Attempted to write after stream was closed",
108            )));
109        }
110        self.project().inner.poll_write_vectored(cx, bufs)
111    }
112}
113
114impl<Item, T: Sink<Item>> Sink<Item> for TrackClosed<T> {
115    type Error = T::Error;
116
117    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118        assert!(!self.is_closed());
119        self.project().inner.poll_ready(cx)
120    }
121
122    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
123        assert!(!self.is_closed());
124        self.project().inner.start_send(item)
125    }
126
127    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
128        assert!(!self.is_closed());
129        self.project().inner.poll_flush(cx)
130    }
131
132    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
133        assert!(!self.is_closed());
134        let this = self.project();
135        match this.inner.poll_close(cx) {
136            Poll::Ready(Ok(())) => {
137                *this.closed = true;
138                Poll::Ready(Ok(()))
139            }
140            other => other,
141        }
142    }
143}