async_utils/
mutex_ticket.rs

// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

5use futures::lock::{Mutex, MutexGuard, MutexLockFuture};
6use futures::prelude::*;
7use std::task::{Context, Poll};
8
9/// Helper to poll a mutex.
10///
11/// Since Mutex::lock futures keep track of where in queue the lock request is,
12/// this is different to `mutex.lock().poll(ctx)` as that construction will create
13/// a new lock request at each poll.
14/// This can often be useful when we need to poll something that is contained under
15/// this mutex.
16///
17/// Typical usage:
18///   let mut ticket = MutexTicket::new();
19///   poll_fn(|ctx| {
20///     let mutex_guard = ready!(ticket.poll(ctx));
21///     mutex_guard.some_child_future.poll(ctx)
22///   }).await;
23///
24/// What this means:
25///   Attempt to acquire the mutex. If it's not available, wait until it's available.
26///   With the mutex acquired, check some_child_future.
27///   If it's completed, complete the poll_fn.
28///   *If it's not completed* drop the mutex guard (unblock other tasks) and wait for
29///   some_child_future to be awoken.
30#[derive(Debug)]
31pub struct MutexTicket<'a, T> {
32    mutex: &'a Mutex<T>,
33    lock: Option<MutexLockFuture<'a, T>>,
34}
35
36impl<'a, T> MutexTicket<'a, T> {
37    /// Create a new `MutexTicket`
38    pub fn new(mutex: &'a Mutex<T>) -> MutexTicket<'a, T> {
39        MutexTicket { mutex, lock: None }
40    }
41
42    /// Poll once to see if the lock has been acquired.
43    /// This is not Future::poll because it's intended to be a helper used during a Future::poll
44    /// implementation, but never as a Future itself -- one can simply call Mutex::lock.await in that
45    /// case!
46    pub fn poll(&mut self, ctx: &mut Context<'_>) -> Poll<MutexGuard<'a, T>> {
47        let mut lock_fut = match self.lock.take() {
48            None => self.mutex.lock(),
49            Some(lock_fut) => lock_fut,
50        };
51        match lock_fut.poll_unpin(ctx) {
52            Poll::Pending => {
53                self.lock = Some(lock_fut);
54                Poll::Pending
55            }
56            Poll::Ready(g) => Poll::Ready(g),
57        }
58    }
59
60    /// Finish locking. This should be used instead of the Mutex.lock function *if* there
61    /// is a `MutexTicket` constructed already - it may be that said `MutexTicket` has already been
62    /// granted ownership of the Mutex - if this is the case, the Mutex.lock call will never succeed.
63    pub async fn lock(&mut self) -> MutexGuard<'a, T> {
64        match self.lock.take() {
65            None => self.mutex.lock(),
66            Some(lock_fut) => lock_fut,
67        }
68        .await
69    }
70}
71
#[cfg(test)]
mod tests {

    use super::MutexTicket;
    use anyhow::{format_err, Error};
    use assert_matches::assert_matches;
    use fuchsia_async::Timer;
    use futures::channel::oneshot;
    use futures::future::{poll_fn, try_join};
    use futures::lock::Mutex;
    use futures::task::noop_waker_ref;
    use std::task::{Context, Poll};
    use std::time::Duration;

    /// An uncontended ticket acquires the lock on every poll. While another guard is held,
    /// the ticket stays `Pending`, and it acquires the lock once that guard is dropped.
    #[fuchsia_async::run_singlethreaded(test)]
    async fn basics(run: usize) {
        let mutex = Mutex::new(run);
        let mut noop_ctx = Context::from_waker(noop_waker_ref());
        let mut ticket = MutexTicket::new(&mutex);

        // Each acquired guard is a temporary, released at the end of its statement.
        for _ in 0..3 {
            assert_matches!(ticket.poll(&mut noop_ctx), Poll::Ready(_));
        }

        let held_elsewhere = mutex.lock().await;
        for _ in 0..2 {
            assert_matches!(ticket.poll(&mut noop_ctx), Poll::Pending);
        }

        drop(held_elsewhere);
        assert_matches!(ticket.poll(&mut noop_ctx), Poll::Ready(_));
    }

    /// A ticket that was left `Pending` completes after the competing guard is dropped, and
    /// the guard it yields sees the mutex's value.
    #[fuchsia_async::run_singlethreaded(test)]
    async fn wakes_up(run: usize) -> Result<(), Error> {
        let mutex = Mutex::new(run);
        let (saw_pending_tx, saw_pending_rx) = oneshot::channel();
        let mut ticket = MutexTicket::new(&mutex);
        let held_elsewhere = mutex.lock().await;

        let waiter = async move {
            // Queue the lock request with a no-op waker while the mutex is held...
            assert_matches!(
                ticket.poll(&mut Context::from_waker(noop_waker_ref())),
                Poll::Pending
            );
            saw_pending_tx.send(()).map_err(|_| format_err!("cancelled"))?;
            // ...then keep polling that same request with the task's real waker.
            let value = *poll_fn(|ctx| ticket.poll(ctx)).await;
            assert_eq!(value, run);
            Ok::<(), Error>(())
        };
        let releaser = async move {
            saw_pending_rx.await?;
            Timer::new(Duration::from_millis(300)).await;
            drop(held_elsewhere);
            Ok::<(), Error>(())
        };

        try_join(waiter, releaser).await.map(|_| ())
    }
}