// usb_vsock/connection/pause_state.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

5use fuchsia_sync::Mutex;
6use std::future::Future;
7use std::pin::pin;
8use std::sync::Arc;
9use std::task::{Poll, Waker};
10
11/// Maintains whether a connection is paused. A paused connection should not
12/// send any more data to the peer.
13pub struct PauseState(Mutex<PauseStateInner>);
14
15/// Mutex-protected interior of [`PauseState`]
16struct PauseStateInner {
17    paused: bool,
18    wakers: Vec<Waker>,
19}
20
21impl PauseState {
22    /// Create a new [`PauseState`]. The initial state is un-paused.
23    pub fn new() -> Arc<Self> {
24        Arc::new(PauseState(Mutex::new(PauseStateInner { paused: false, wakers: Vec::new() })))
25    }
26
27    /// Polls the given future, but pauses polling when we are in the paused
28    /// state.
29    pub async fn while_unpaused<T>(&self, f: impl Future<Output = T>) -> T {
30        let mut f = pin!(f);
31        futures::future::poll_fn(move |ctx| {
32            {
33                let mut this = self.0.lock();
34
35                if this.wakers.iter().all(|x| !x.will_wake(ctx.waker())) {
36                    this.wakers.push(ctx.waker().clone());
37                }
38
39                if this.paused {
40                    return Poll::Pending;
41                }
42            }
43
44            f.as_mut().poll(ctx)
45        })
46        .await
47    }
48
49    /// Set the paused state.
50    pub fn set_paused(&self, paused: bool) {
51        let mut this = self.0.lock();
52
53        this.paused = paused;
54        this.wakers.drain(..).for_each(Waker::wake);
55    }
56}
57
#[cfg(test)]
mod test {
    use super::*;
    use futures::{Stream, StreamExt};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{Context, Wake};

    /// Waker that counts how many times it has been woken.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[fuchsia::test]
    async fn test_pause() {
        let pause_state = PauseState::new();
        let pause_state_clone = Arc::clone(&pause_state);
        let stream = futures::stream::iter(1..)
            .then(|x| pause_state_clone.while_unpaused(futures::future::ready(x)));
        let mut stream = pin!(stream);
        let mut ctx = Context::from_waker(&Waker::noop());

        for expected in 1..=5 {
            assert_eq!(Poll::Ready(Some(expected)), stream.as_mut().poll_next(&mut ctx));
        }

        pause_state.set_paused(true);
        for _ in 0..5 {
            assert_eq!(Poll::Pending, stream.as_mut().poll_next(&mut ctx));
        }

        // Pausing an already-paused state must leave it paused.
        pause_state.set_paused(true);
        for _ in 0..5 {
            assert_eq!(Poll::Pending, stream.as_mut().poll_next(&mut ctx));
        }

        // Un-pausing resumes exactly where the stream left off.
        pause_state.set_paused(false);
        for expected in 6..=10 {
            assert_eq!(Poll::Ready(Some(expected)), stream.as_mut().poll_next(&mut ctx));
        }
    }

    #[fuchsia::test]
    async fn test_unpause_wakes_paused_waiter() {
        let pause_state = PauseState::new();
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&counter));
        let mut ctx = Context::from_waker(&waker);

        pause_state.set_paused(true);
        let mut fut = pin!(pause_state.while_unpaused(futures::future::ready(42)));
        assert_eq!(Poll::Pending, fut.as_mut().poll(&mut ctx));
        assert_eq!(0, counter.0.load(Ordering::SeqCst));

        // The parked task must be woken, otherwise it would never be re-polled.
        pause_state.set_paused(false);
        assert_eq!(1, counter.0.load(Ordering::SeqCst));
        assert_eq!(Poll::Ready(42), fut.as_mut().poll(&mut ctx));
    }
}