futures_test/
assert_unmoved.rs

1use futures_core::future::{FusedFuture, Future};
2use futures_core::stream::{FusedStream, Stream};
3use futures_core::task::{Context, Poll};
4use futures_io::{
5    self as io, AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, IoSlice, IoSliceMut, SeekFrom,
6};
7use futures_sink::Sink;
8use pin_project::{pin_project, pinned_drop};
9use std::pin::Pin;
10use std::thread::panicking;
11
12/// Combinator that asserts that the underlying type is not moved after being polled.
13///
14/// See the `assert_unmoved` methods on:
15/// * [`FutureTestExt`](crate::future::FutureTestExt::assert_unmoved)
16/// * [`StreamTestExt`](crate::stream::StreamTestExt::assert_unmoved)
17/// * [`SinkTestExt`](crate::sink::SinkTestExt::assert_unmoved_sink)
18/// * [`AsyncReadTestExt`](crate::io::AsyncReadTestExt::assert_unmoved)
19/// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::assert_unmoved_write)
20#[pin_project(PinnedDrop, !Unpin)]
21#[derive(Debug, Clone)]
22#[must_use = "futures do nothing unless you `.await` or poll them"]
23pub struct AssertUnmoved<T> {
24    #[pin]
25    inner: T,
26    this_addr: usize,
27}
28
29impl<T> AssertUnmoved<T> {
30    pub(crate) fn new(inner: T) -> Self {
31        Self { inner, this_addr: 0 }
32    }
33
34    fn poll_with<'a, U>(mut self: Pin<&'a mut Self>, f: impl FnOnce(Pin<&'a mut T>) -> U) -> U {
35        let cur_this = &*self as *const Self as usize;
36        if self.this_addr == 0 {
37            // First time being polled
38            *self.as_mut().project().this_addr = cur_this;
39        } else {
40            assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved between poll calls");
41        }
42        f(self.project().inner)
43    }
44}
45
46impl<Fut: Future> Future for AssertUnmoved<Fut> {
47    type Output = Fut::Output;
48
49    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
50        self.poll_with(|f| f.poll(cx))
51    }
52}
53
54impl<Fut: FusedFuture> FusedFuture for AssertUnmoved<Fut> {
55    fn is_terminated(&self) -> bool {
56        self.inner.is_terminated()
57    }
58}
59
60impl<St: Stream> Stream for AssertUnmoved<St> {
61    type Item = St::Item;
62
63    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
64        self.poll_with(|s| s.poll_next(cx))
65    }
66}
67
68impl<St: FusedStream> FusedStream for AssertUnmoved<St> {
69    fn is_terminated(&self) -> bool {
70        self.inner.is_terminated()
71    }
72}
73
74impl<Si: Sink<Item>, Item> Sink<Item> for AssertUnmoved<Si> {
75    type Error = Si::Error;
76
77    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78        self.poll_with(|s| s.poll_ready(cx))
79    }
80
81    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
82        self.poll_with(|s| s.start_send(item))
83    }
84
85    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        self.poll_with(|s| s.poll_flush(cx))
87    }
88
89    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90        self.poll_with(|s| s.poll_close(cx))
91    }
92}
93
94impl<R: AsyncRead> AsyncRead for AssertUnmoved<R> {
95    fn poll_read(
96        self: Pin<&mut Self>,
97        cx: &mut Context<'_>,
98        buf: &mut [u8],
99    ) -> Poll<io::Result<usize>> {
100        self.poll_with(|r| r.poll_read(cx, buf))
101    }
102
103    fn poll_read_vectored(
104        self: Pin<&mut Self>,
105        cx: &mut Context<'_>,
106        bufs: &mut [IoSliceMut<'_>],
107    ) -> Poll<io::Result<usize>> {
108        self.poll_with(|r| r.poll_read_vectored(cx, bufs))
109    }
110}
111
112impl<W: AsyncWrite> AsyncWrite for AssertUnmoved<W> {
113    fn poll_write(
114        self: Pin<&mut Self>,
115        cx: &mut Context<'_>,
116        buf: &[u8],
117    ) -> Poll<io::Result<usize>> {
118        self.poll_with(|w| w.poll_write(cx, buf))
119    }
120
121    fn poll_write_vectored(
122        self: Pin<&mut Self>,
123        cx: &mut Context<'_>,
124        bufs: &[IoSlice<'_>],
125    ) -> Poll<io::Result<usize>> {
126        self.poll_with(|w| w.poll_write_vectored(cx, bufs))
127    }
128
129    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
130        self.poll_with(|w| w.poll_flush(cx))
131    }
132
133    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
134        self.poll_with(|w| w.poll_close(cx))
135    }
136}
137
138impl<S: AsyncSeek> AsyncSeek for AssertUnmoved<S> {
139    fn poll_seek(
140        self: Pin<&mut Self>,
141        cx: &mut Context<'_>,
142        pos: SeekFrom,
143    ) -> Poll<io::Result<u64>> {
144        self.poll_with(|s| s.poll_seek(cx, pos))
145    }
146}
147
148impl<R: AsyncBufRead> AsyncBufRead for AssertUnmoved<R> {
149    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
150        self.poll_with(|r| r.poll_fill_buf(cx))
151    }
152
153    fn consume(self: Pin<&mut Self>, amt: usize) {
154        self.poll_with(|r| r.consume(amt))
155    }
156}
157
158#[pinned_drop]
159impl<T> PinnedDrop for AssertUnmoved<T> {
160    fn drop(self: Pin<&mut Self>) {
161        // If the thread is panicking then we can't panic again as that will
162        // cause the process to be aborted.
163        if !panicking() && self.this_addr != 0 {
164            let cur_this = &*self as *const Self as usize;
165            assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved before drop");
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use futures_core::future::Future;
    use futures_core::task::{Context, Poll};
    use futures_util::future::pending;
    use futures_util::task::noop_waker;
    use std::pin::Pin;

    use super::AssertUnmoved;

    /// Compile-time check: wrapping a `Send + Sync` type (`()`) yields a `Send + Sync` wrapper.
    #[test]
    fn assert_send_sync() {
        fn assert<T: Send + Sync>() {}
        assert::<AssertUnmoved<()>>();
    }

    /// Dropping a wrapper that was never polled must not trigger the drop-time address check.
    #[test]
    fn dont_panic_when_not_polled() {
        // This shouldn't panic.
        let future = AssertUnmoved::new(pending::<()>());
        drop(future);
    }

    /// A second poll after a move must panic. The drop-time check must then stay silent, so
    /// the thread unwinds instead of aborting on a double panic.
    #[test]
    #[should_panic(expected = "AssertUnmoved moved between poll calls")]
    fn dont_double_panic() {
        // This test should only panic, not abort the process.
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        // First we allocate the future on the stack and poll it.
        let mut future = AssertUnmoved::new(pending::<()>());
        // NOTE: this deliberately breaks the `Pin` contract. `future` is moved below after
        // being pinned and polled, which is exactly the misuse `AssertUnmoved` must detect.
        let pinned_future = unsafe { Pin::new_unchecked(&mut future) };
        assert_eq!(pinned_future.poll(&mut cx), Poll::Pending);

        // Next we move it to the heap and poll it again. This second call
        // should panic (as the future is moved), but we shouldn't panic again
        // whilst dropping `AssertUnmoved`.
        let mut future = Box::new(future);
        let pinned_boxed_future = unsafe { Pin::new_unchecked(&mut *future) };
        assert_eq!(pinned_boxed_future.poll(&mut cx), Poll::Pending);
    }
}