// fuchsia_async/condition.rs
use std::future::poll_fn;
use std::marker::PhantomPinned;
use std::ops::{Deref, DerefMut};
use std::pin::{pin, Pin};
use std::ptr::NonNull;
use std::sync::{Arc, Mutex, MutexGuard};
use std::task::{Poll, Waker};
/// A mutex-protected value of type `T` paired with a list of wakers for tasks waiting on it.
///
/// The state lives behind an `Arc`, which lets a waker entry hold on to the list it is
/// registered with until the entry has unlinked itself.
pub struct Condition<T>(Arc<Mutex<Inner<T>>>);

impl<T> Condition<T> {
    /// Creates a condition guarding `data`, with no wakers registered.
    pub fn new(data: T) -> Self {
        let inner = Inner { head: None, count: 0, data };
        Self(Arc::new(Mutex::new(inner)))
    }

    /// Returns the number of wakers currently registered on the condition.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub fn waker_count(&self) -> usize {
        let inner = self.0.lock().unwrap();
        inner.count
    }

    /// Locks the condition and returns a guard that dereferences to the protected data.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub fn lock(&self) -> ConditionGuard<'_, T> {
        let shared = &self.0;
        ConditionGuard(shared, shared.lock().unwrap())
    }

    /// Completes with the value produced by `poll` once it returns `Poll::Ready`.
    ///
    /// `poll` runs with the lock held each time the returned future is polled. Whenever it
    /// returns `Poll::Pending`, the polling task's waker is (re-)registered on the condition's
    /// waker list. The registration is removed when the returned future is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub async fn when<R>(&self, poll: impl Fn(&mut T) -> Poll<R>) -> R {
        let mut waiter = WakerEntry::new();
        waiter.list = Some(Arc::clone(&self.0));
        let mut waiter = pin!(waiter);
        poll_fn(|cx| {
            let mut inner = self.0.lock().unwrap();
            match poll(&mut inner.data) {
                Poll::Pending => {
                    // SAFETY: The entry is only mutated in place; it is never moved out of
                    // its pinned location.
                    let entry = unsafe { waiter.as_mut().get_unchecked_mut() };
                    // SAFETY: `entry.list` was set to this condition's list above, and
                    // `inner` is the locked state of that same list.
                    unsafe { entry.node.add(&mut *inner, cx.waker().clone()) };
                    Poll::Pending
                }
                ready => ready,
            }
        })
        .await
    }
}
/// State stored behind the condition's mutex: the waker list bookkeeping plus the user data.
struct Inner<T> {
/// Head of the intrusive list of waker nodes; `None` when no node is linked.
head: Option<NonNull<Node>>,
/// Number of nodes currently linked into the list.
count: usize,
/// The user-supplied value that the condition protects.
data: T,
}
// SAFETY: NOTE(review): rationale inferred from this file, please confirm. `head` is a raw
// pointer, which is what prevents the automatic `Send` impl. Within this file the list pointers
// are only followed by code that holds `&mut Inner`, which is obtained through the mutex.
unsafe impl<T: Send> Send for Inner<T> {}
/// Guard returned by `Condition::lock`.
///
/// Holds the condition's mutex for as long as it lives. Field `0` is the shared list handle
/// (cloned into waker entries), field `1` is the lock guard over the list state and the data.
pub struct ConditionGuard<'a, T>(&'a Arc<Mutex<Inner<T>>>, MutexGuard<'a, Inner<T>>);

impl<'a, T> ConditionGuard<'a, T> {
    /// Registers `waker` for `waker_entry` in this condition's list of wakers.
    ///
    /// If the entry is already queued on this condition, only its stored waker is replaced.
    ///
    /// If the entry was last registered with a *different* condition, it is first unlinked from
    /// that condition's list, which briefly takes that condition's lock. Previously the entry's
    /// list handle was simply overwritten, leaving the node linked into the old list while the
    /// entry's `Drop` would later unlink it under the wrong lock. Because the other lock is
    /// taken while this guard is held, do not move an entry between conditions while holding
    /// the other condition's guard (that would deadlock).
    ///
    /// # Panics
    ///
    /// Panics if the mutex of the condition the entry was previously registered with is
    /// poisoned.
    pub fn add_waker(&mut self, waker_entry: Pin<&mut WakerEntry<T>>, waker: Waker) {
        // SAFETY: The entry is only mutated in place; it is never moved out of the pin.
        let waker_entry = unsafe { waker_entry.get_unchecked_mut() };

        let registered_here =
            matches!(&waker_entry.list, Some(list) if Arc::ptr_eq(list, self.0));
        if !registered_here {
            if let Some(previous) = &waker_entry.list {
                // The node may still be linked into the other condition's list, so it has to
                // be unlinked under *that* list's lock before it can join this one. This is a
                // no-op if the node has already been drained.
                let mut previous_inner = previous.lock().unwrap();
                waker_entry.node.remove(&mut *previous_inner);
            }
            waker_entry.list = Some(Arc::clone(self.0));
        }

        // SAFETY: `waker_entry.list` now refers to this condition's list and `self.1` is the
        // lock guard for that same list, so the node is linked under the matching lock.
        unsafe {
            waker_entry.node.add(&mut *self.1, waker);
        }
    }

    /// Returns an iterator that unlinks the registered wakers one at a time and yields them.
    ///
    /// Wakers that are not consumed from the iterator stay registered.
    pub fn drain_wakers<'b>(&'b mut self) -> Drainer<'b, 'a, T> {
        Drainer(self)
    }

    /// Returns the number of wakers currently registered on the condition.
    pub fn waker_count(&self) -> usize {
        self.1.count
    }
}
impl<T> Deref for ConditionGuard<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.1.data
}
}
impl<T> DerefMut for ConditionGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.1.data
}
}
/// Storage for one waker registration on a condition.
///
/// The entry embeds the list node, so it must stay pinned while it is registered.
pub struct WakerEntry<T> {
    /// The waker list this entry was last registered with, if any; used on drop to unlink.
    list: Option<Arc<Mutex<Inner<T>>>>,
    /// The intrusive list node that gets linked into that list.
    node: Node,
}

impl<T> WakerEntry<T> {
    /// Returns a new entry that is not registered with any condition.
    pub fn new() -> Self {
        let node = Node { _pinned: PhantomPinned, waker: None, prev: None, next: None };
        Self { node, list: None }
    }
}
/// Unlinks the entry from the waker list it is registered with, if any.
impl<T> Drop for WakerEntry<T> {
    fn drop(&mut self) {
        if let Some(list) = &self.list {
            // The node has to be unlinked before its memory goes away, so a poisoned mutex
            // must not stop us. Poisoning only says that some holder of the lock panicked
            // (e.g. a `poll` callback or user code working on the data); the list links are
            // only touched by the short pointer updates in `Node::add`/`Node::remove`.
            // Panicking here instead would abort the process whenever the entry is dropped
            // during unwinding, and would otherwise leave a dangling node in the list.
            let mut inner = list.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            self.node.remove(&mut *inner);
        }
    }
}
/// A link in the intrusive, doubly linked list of registered wakers.
///
/// A node is linked into a list exactly when `waker` is `Some`.
struct Node {
    /// The node that follows this one in the list (the previous head at insertion time).
    next: Option<NonNull<Node>>,
    /// The node that precedes this one, or `None` if this node is the head (or unlinked).
    prev: Option<NonNull<Node>>,
    /// The registered waker; `None` while the node is not linked.
    waker: Option<Waker>,
    /// Other nodes and the list head point at this node, so it must not move while linked.
    _pinned: PhantomPinned,
}

// SAFETY: NOTE(review): rationale inferred from this file, please confirm. The raw `next`/`prev`
// pointers are only followed by code that holds `&mut Inner` for the list the node is linked
// into.
unsafe impl Send for Node {}

impl Node {
    /// Stores `waker` in the node, linking the node at the head of `inner`'s list first if it
    /// is not linked yet.
    ///
    /// # Safety
    ///
    /// NOTE(review): contract inferred from the callers, please confirm. The node must stay at
    /// its current address while linked, and `inner` must be the list that the owning entry
    /// will later use to unlink the node.
    unsafe fn add<T>(&mut self, inner: &mut Inner<T>, waker: Waker) {
        let already_linked = self.waker.is_some();
        if !already_linked {
            // Become the new head: no predecessor, and the old head (if any) follows us.
            self.prev = None;
            self.next = inner.head;
            let this = NonNull::from(&mut *self);
            inner.head = Some(this);
            if let Some(mut old_head) = self.next {
                // SAFETY: `old_head` was the head of `inner`'s list, to which we have exclusive
                // access, so it points at a live, linked node.
                unsafe { old_head.as_mut().prev = Some(this) };
            }
            inner.count += 1;
        }
        self.waker = Some(waker);
    }

    /// Unlinks the node from `inner`'s list and returns its waker, or returns `None` if the
    /// node is not linked.
    fn remove<T>(&mut self, inner: &mut Inner<T>) -> Option<Waker> {
        if self.waker.is_some() {
            let (before, after) = (self.prev.take(), self.next.take());
            if let Some(mut after) = after {
                // SAFETY: neighbours of a linked node are live nodes of the same list.
                unsafe { after.as_mut().prev = before };
            }
            match before {
                // SAFETY: as above.
                Some(mut before) => unsafe { before.as_mut().next = after },
                // No predecessor means this node was the head.
                None => inner.head = after,
            }
            inner.count -= 1;
            self.waker.take()
        } else {
            debug_assert!(self.prev.is_none() && self.next.is_none());
            None
        }
    }
}
pub struct Drainer<'a, 'b, T>(&'a mut ConditionGuard<'b, T>);
impl<T> Iterator for Drainer<'_, '_, T> {
type Item = Waker;
fn next(&mut self) -> Option<Self::Item> {
if let Some(mut head) = self.0 .1.head {
unsafe { head.as_mut().remove(&mut self.0 .1) }
} else {
None
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.0 .1.count, Some(self.0 .1.count))
}
}
impl<T> ExactSizeIterator for Drainer<'_, '_, T> {
fn len(&self) -> usize {
self.0 .1.count
}
}
// Unit tests for `Condition` and `WakerEntry`. Only built for Fuchsia targets because they use
// `crate::TestExecutor`.
#[cfg(all(target_os = "fuchsia", test))]
mod tests {
use super::{Condition, WakerEntry};
use crate::TestExecutor;
use futures::stream::FuturesUnordered;
use futures::task::noop_waker;
use futures::StreamExt;
use std::pin::pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::Poll;
// Registers COUNT `when` futures on one condition, checks that all of them park, then drains
// and wakes every waker and checks that all futures complete.
#[test]
fn test_condition_can_waker_multiple_wakers() {
let mut executor = TestExecutor::new();
let condition = Condition::new(());
static COUNT: u64 = 10;
let counter = AtomicU64::new(0);
let mut futures = FuturesUnordered::new();
for _ in 0..COUNT {
// The closure reports `Pending` for the first COUNT invocations overall (counter values
// 0..COUNT) and `Ready` for every invocation after that.
futures.push(condition.when(|()| {
if counter.fetch_add(1, Ordering::Relaxed) >= COUNT {
Poll::Ready(())
} else {
Poll::Pending
}
}));
}
// Nothing can complete yet; the closure has run COUNT times in total and COUNT wakers are
// registered.
assert!(executor.run_until_stalled(&mut futures.next()).is_pending());
assert_eq!(counter.load(Ordering::Relaxed), COUNT);
assert_eq!(condition.waker_count(), COUNT as usize);
{
// Drain and wake all registered wakers while holding the lock.
let mut guard = condition.lock();
let drainer = guard.drain_wakers();
assert_eq!(drainer.len(), COUNT as usize);
for waker in drainer {
waker.wake();
}
}
// All futures now finish, after the closure has run another COUNT times.
assert!(executor.run_until_stalled(&mut futures.collect::<Vec<_>>()).is_ready());
assert_eq!(counter.load(Ordering::Relaxed), COUNT * 2);
}
// Checks that an entry going out of scope is removed from the waker list, and that draining
// empties the list.
#[test]
fn test_dropping_waker_entry_removes_from_list() {
let condition = Condition::new(());
let entry1 = pin!(WakerEntry::new());
condition.lock().add_waker(entry1, noop_waker());
{
// `entry2` only lives for this block.
let entry2 = pin!(WakerEntry::new());
condition.lock().add_waker(entry2, noop_waker());
assert_eq!(condition.waker_count(), 2);
}
// Only `entry1` is left once `entry2` has been dropped.
assert_eq!(condition.waker_count(), 1);
{
let mut guard = condition.lock();
assert_eq!(guard.drain_wakers().count(), 1);
}
assert_eq!(condition.waker_count(), 0);
// A fresh entry can still be added after the list has been drained.
let entry3 = pin!(WakerEntry::new());
condition.lock().add_waker(entry3, noop_waker());
assert_eq!(condition.waker_count(), 1);
}
// Checks that the same entries can be registered again after they have been drained.
#[test]
fn test_waker_can_be_added_multiple_times() {
let condition = Condition::new(());
let mut entry1 = pin!(WakerEntry::new());
condition.lock().add_waker(entry1.as_mut(), noop_waker());
let mut entry2 = pin!(WakerEntry::new());
condition.lock().add_waker(entry2.as_mut(), noop_waker());
assert_eq!(condition.waker_count(), 2);
{
let mut guard = condition.lock();
assert_eq!(guard.drain_wakers().count(), 2);
}
assert_eq!(condition.waker_count(), 0);
// Second round with the very same (already drained) entries.
condition.lock().add_waker(entry1, noop_waker());
condition.lock().add_waker(entry2, noop_waker());
assert_eq!(condition.waker_count(), 2);
{
let mut guard = condition.lock();
assert_eq!(guard.drain_wakers().count(), 2);
}
assert_eq!(condition.waker_count(), 0);
}
}