async_utils/
mutex_ticket.rs1use futures::lock::{Mutex, MutexGuard, MutexLockFuture};
6use futures::prelude::*;
7use std::task::{Context, Poll};
8
9#[derive(Debug)]
31pub struct MutexTicket<'a, T> {
32 mutex: &'a Mutex<T>,
33 lock: Option<MutexLockFuture<'a, T>>,
34}
35
36impl<'a, T> MutexTicket<'a, T> {
37 pub fn new(mutex: &'a Mutex<T>) -> MutexTicket<'a, T> {
39 MutexTicket { mutex, lock: None }
40 }
41
42 pub fn poll(&mut self, ctx: &mut Context<'_>) -> Poll<MutexGuard<'a, T>> {
47 let mut lock_fut = match self.lock.take() {
48 None => self.mutex.lock(),
49 Some(lock_fut) => lock_fut,
50 };
51 match lock_fut.poll_unpin(ctx) {
52 Poll::Pending => {
53 self.lock = Some(lock_fut);
54 Poll::Pending
55 }
56 Poll::Ready(g) => Poll::Ready(g),
57 }
58 }
59
60 pub async fn lock(&mut self) -> MutexGuard<'a, T> {
64 match self.lock.take() {
65 None => self.mutex.lock(),
66 Some(lock_fut) => lock_fut,
67 }
68 .await
69 }
70}
71
#[cfg(test)]
mod tests {

    use super::MutexTicket;
    use anyhow::{Error, format_err};
    use assert_matches::assert_matches;
    use fuchsia_async::Timer;
    use futures::channel::oneshot;
    use futures::future::{poll_fn, try_join};
    use futures::lock::Mutex;
    use std::task::{Context, Poll, Waker};
    use std::time::Duration;

    // Checks the ready -> pending -> ready sequence of a ticket as the mutex is taken and
    // released by someone else.
    #[fuchsia_async::run_singlethreaded(test)]
    async fn basics(run: usize) {
        let mutex = Mutex::new(run);
        let mut ctx = Context::from_waker(Waker::noop());
        let mut ticket = MutexTicket::new(&mutex);

        // Nothing keeps the returned guards, so every poll can take the lock again.
        for _ in 0..3 {
            assert_matches!(ticket.poll(&mut ctx), Poll::Ready(_));
        }

        // While another guard is alive the ticket must keep reporting pending.
        let held = mutex.lock().await;
        for _ in 0..2 {
            assert_matches!(ticket.poll(&mut ctx), Poll::Pending);
        }

        drop(held);
        assert_matches!(ticket.poll(&mut ctx), Poll::Ready(_));
    }

    // Checks that a ticket which saw `Pending` completes once the guard held elsewhere is
    // dropped.
    #[fuchsia_async::run_singlethreaded(test)]
    async fn wakes_up(run: usize) -> Result<(), Error> {
        let mutex = Mutex::new(run);
        let (tx_saw_first_pending, rx_saw_first_pending) = oneshot::channel();
        let mut ticket = MutexTicket::new(&mutex);
        let held = mutex.lock().await;

        // Observes one pending poll, signals the other side, then waits for the lock.
        let waiter = async move {
            assert_matches!(
                ticket.poll(&mut Context::from_waker(Waker::noop())),
                Poll::Pending
            );
            tx_saw_first_pending.send(()).map_err(|_| format_err!("cancelled"))?;
            assert_eq!(*poll_fn(|ctx| ticket.poll(ctx)).await, run);
            Ok(())
        };

        // Releases the mutex only after the waiter has reported its pending poll.
        let releaser = async move {
            rx_saw_first_pending.await?;
            Timer::new(Duration::from_millis(300)).await;
            drop(held);
            Ok(())
        };

        try_join(waiter, releaser).await.map(|_| ())
    }
}