use crate::runtime::{BootInstant, EHandle, MonotonicInstant, WakeupTime};
use crate::PacketReceiver;
use fuchsia_sync::Mutex;
use futures::future::FusedFuture;
use futures::stream::FusedStream;
use futures::task::{AtomicWaker, Context};
use futures::{FutureExt, Stream};
use std::cell::UnsafeCell;
use std::fmt;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use std::task::{ready, Poll, Waker};
use zx::AsHandleRef as _;
/// Abstraction over the instant types that `Timers` can schedule against.
///
/// Implemented in this file for `MonotonicInstant` and `BootInstant`. All
/// deadlines are exchanged as raw `i64` nanosecond counts.
pub trait TimeInterface:
Clone + Copy + fmt::Debug + PartialEq + PartialOrd + Ord + Send + Sync + 'static
{
/// The `zx` timeline that the underlying kernel timer object is created on.
type Timeline: zx::Timeline + Send + Sync + 'static;
/// Builds an instant from a nanosecond count.
fn from_nanos(nanos: i64) -> Self;
/// Converts the instant into its nanosecond count.
fn into_nanos(self) -> i64;
/// Builds the `zx` instant (on `Self::Timeline`) for the given nanosecond count.
fn zx_instant(nanos: i64) -> zx::Instant<Self::Timeline>;
/// Returns the current time, in nanoseconds, as reported by the local executor.
fn now() -> i64;
}
impl TimeInterface for MonotonicInstant {
type Timeline = zx::MonotonicTimeline;
fn from_nanos(nanos: i64) -> Self {
Self::from_nanos(nanos)
}
fn into_nanos(self) -> i64 {
self.into_nanos()
}
fn zx_instant(nanos: i64) -> zx::MonotonicInstant {
zx::MonotonicInstant::from_nanos(nanos)
}
fn now() -> i64 {
EHandle::local().inner().now().into_nanos()
}
}
impl TimeInterface for BootInstant {
type Timeline = zx::BootTimeline;
fn from_nanos(nanos: i64) -> Self {
Self::from_nanos(nanos)
}
fn into_nanos(self) -> i64 {
self.into_nanos()
}
fn zx_instant(nanos: i64) -> zx::BootInstant {
zx::BootInstant::from_nanos(nanos)
}
fn now() -> i64 {
EHandle::local().inner().boot_now().into_nanos()
}
}
/// Lets a `std::time::Instant` be used as a timer deadline by translating it
/// onto the monotonic timeline.
impl WakeupTime for std::time::Instant {
    fn into_timer(self) -> Timer {
        // Sample both clocks back to back so the remaining duration measured on
        // the std clock can be re-applied to the executor's monotonic clock.
        let std_now = std::time::Instant::now();
        let mono_now = MonotonicInstant::now();
        // A deadline that is already in the past saturates to a zero duration.
        let remaining = self.saturating_duration_since(std_now);
        let executor = EHandle::local();
        executor.mono_timers().new_timer(mono_now + remaining.into())
    }
}
/// A `MonotonicInstant` deadline is scheduled on the local executor's
/// monotonic timer set.
impl WakeupTime for MonotonicInstant {
    fn into_timer(self) -> Timer {
        let executor = EHandle::local();
        executor.mono_timers().new_timer(self)
    }
}
/// A `BootInstant` deadline is scheduled on the local executor's boot timer set.
impl WakeupTime for BootInstant {
    fn into_timer(self) -> Timer {
        let executor = EHandle::local();
        executor.boot_timers().new_timer(self)
    }
}
impl WakeupTime for zx::MonotonicInstant {
fn into_timer(self) -> Timer {
EHandle::local().mono_timers().new_timer(self.into())
}
}
impl WakeupTime for zx::BootInstant {
fn into_timer(self) -> Timer {
EHandle::local().boot_timers().new_timer(self.into())
}
}
/// A future that completes once its deadline has been reached.
///
/// The timer's bookkeeping lives in the wrapped `TimerState`. While the timer
/// is registered, the owning `Timers` heap holds a raw pointer to that state,
/// which is why the state is marked `PhantomPinned`.
#[must_use = "futures do nothing unless polled"]
pub struct Timer(TimerState);
impl Timer {
/// Creates a timer that fires at `time` by delegating to
/// `WakeupTime::into_timer`.
pub fn new(time: impl WakeupTime) -> Self {
time.into_timer()
}
/// Re-arms the timer so that it fires at `time`.
///
/// If the timer is currently registered, the owning `Timers` is asked to
/// move it within its heap. Otherwise the deadline is stored directly and
/// the state goes back to unregistered (0), so the next poll registers it.
pub fn reset(self: Pin<&mut Self>, time: MonotonicInstant) {
let nanos = time.into_nanos();
// Short-circuit: `try_reset_timer` is only attempted when the state is
// REGISTERED; it returns false if the timer is no longer in the heap.
if self.0.state.load(Ordering::Relaxed) != REGISTERED
|| !self.0.timers.try_reset_timer(&self.0, nanos)
{
// NOTE(review): this write happens without the `Timers` lock; it relies on
// the timer not being in the heap at this point, so nothing else reads
// `nanos` through a `StateRef` — confirm.
unsafe { *self.0.nanos.get() = nanos };
// 0 is the unregistered state checked by `poll`.
self.0.state.store(0, Ordering::Relaxed);
}
}
}
/// Debug output shows only the deadline cell, under the field name `time`.
impl fmt::Debug for Timer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Timer");
        builder.field("time", &self.0.nanos);
        builder.finish()
    }
}
/// Dropping a timer asks its owning timer set to forget it, so that no
/// pointer to this state is left behind in the heap.
impl Drop for Timer {
    fn drop(&mut self) {
        let state = &self.0;
        state.timers.unregister(state);
    }
}
impl Future for Timer {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
unsafe { self.0.timers.poll(self.as_ref(), cx) }
}
}
/// Per-timer bookkeeping shared between a `Timer` and the `Timers` heap.
struct TimerState {
// The timer set this timer belongs to; used for poll, reset and unregister.
timers: Arc<dyn TimersInterface>,
// Deadline in nanoseconds. Read and written through raw access; see the
// `unsafe` blocks that touch it for the conditions they rely on.
nanos: UnsafeCell<i64>,
// Waker of the task that last polled the timer while it was pending.
waker: AtomicWaker,
// One of 0 (unregistered), REGISTERED, FIRED or TERMINATED.
state: AtomicU8,
// Slot of this timer in the heap, or `HeapIndex::NULL` when not in the heap.
index: UnsafeCell<HeapIndex>,
// The heap stores raw pointers to this struct, so it must not move once polled.
_pinned: PhantomPinned,
}
// NOTE(review): these impls are presumably sound because the `UnsafeCell`
// fields are only accessed under the `Timers` lock or while the timer is known
// not to be in the heap — confirm against every access site.
unsafe impl Send for TimerState {}
unsafe impl Sync for TimerState {}
// The timer is in the heap of its `Timers`.
const REGISTERED: u8 = 1;
// The timer has been popped from the heap because its deadline was reached.
const FIRED: u8 = 2;
// The future has returned `Poll::Ready`; reported by `FusedFuture::is_terminated`.
const TERMINATED: u8 = 3;
/// Position of a timer inside the heap, with `usize::MAX` reserved as the
/// "not in the heap" marker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct HeapIndex(usize);

impl HeapIndex {
    /// Sentinel meaning the timer is not currently stored in the heap.
    const NULL: HeapIndex = HeapIndex(usize::MAX);

    /// Returns the slot, or `None` when this is the `NULL` sentinel.
    fn get(&self) -> Option<usize> {
        Some(self.0).filter(|&slot| slot != Self::NULL.0)
    }
}

impl From<usize> for HeapIndex {
    /// Wraps a raw slot number without any validation.
    fn from(value: usize) -> Self {
        HeapIndex(value)
    }
}
/// A timer reports itself terminated once its state has reached `TERMINATED`.
impl FusedFuture for Timer {
    fn is_terminated(&self) -> bool {
        let current = self.0.state.load(Ordering::Relaxed);
        matches!(current, TERMINATED)
    }
}
/// Raw, copyable pointer to a `TimerState`, as stored in the heap.
///
/// Every accessor dereferences the raw pointer, so the pointee must still be
/// alive whenever one is called.
#[derive(Copy, Clone, Debug)]
struct StateRef(*const TimerState);
// NOTE(review): presumably sound because a `StateRef` is only dereferenced
// while the `Timers` lock is held — confirm.
unsafe impl Send for StateRef {}
unsafe impl Sync for StateRef {}
impl StateRef {
/// Takes the registered waker (if any) and marks the timer as FIRED.
///
/// `_inner` is unused; requiring `&mut Inner` means the caller must be
/// holding the `Timers` lock.
fn into_waker(self, _inner: &mut Inner) -> Option<Waker> {
unsafe {
// NOTE(review): the waker is taken before the state becomes FIRED. `poll`
// registers its waker without holding the lock, so a waker registered
// between these two statements is not returned here — confirm this
// ordering is intended.
let waker = (*self.0).waker.take();
(*self.0).state.store(FIRED, Ordering::Relaxed);
waker
}
}
/// Reads the deadline in nanoseconds.
///
/// # Safety
///
/// The pointee must be alive and nothing may be writing `nanos` concurrently.
unsafe fn nanos(&self) -> i64 {
*(*self.0).nanos.get()
}
/// Returns a mutable reference to the deadline.
///
/// # Safety
///
/// The pointee must be alive and no other access to `nanos` may overlap
/// with the returned reference.
unsafe fn nanos_mut(&mut self) -> &mut i64 {
&mut *(*self.0).nanos.get()
}
/// Stores a new heap slot in the timer and returns the previous one.
///
/// # Safety
///
/// The pointee must be alive and nothing else may be accessing `index`.
unsafe fn set_index(&mut self, index: HeapIndex) -> HeapIndex {
std::mem::replace(&mut *(*self.0).index.get(), index)
}
}
/// A stream that yields `()` each time a fixed period elapses on the
/// monotonic timeline.
#[derive(Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct Interval {
    /// Timer armed for the upcoming tick.
    timer: Pin<Box<Timer>>,
    /// Deadline the timer is currently armed for.
    next: MonotonicInstant,
    /// Period added to `next` after every tick.
    duration: zx::MonotonicDuration,
}

impl Interval {
    /// Creates an interval whose first deadline is
    /// `MonotonicInstant::after(duration)`; each later deadline is the previous
    /// one plus `duration`.
    pub fn new(duration: zx::MonotonicDuration) -> Self {
        let next = MonotonicInstant::after(duration);
        let timer = Box::pin(Timer::new(next));
        Self { timer, next, duration }
    }
}
/// An `Interval` never reports itself as terminated; `poll_next` only ever
/// returns `Pending` or `Ready(Some(()))`.
impl FusedStream for Interval {
fn is_terminated(&self) -> bool {
false
}
}
/// Yields one item per elapsed period and immediately re-arms the timer.
impl Stream for Interval {
    type Item = ();

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.timer.poll_unpin(cx).is_pending() {
            return Poll::Pending;
        }
        // Advance from the previous deadline rather than from "now", so ticks
        // stay exactly `duration` apart even when polling is delayed.
        let upcoming = self.next + self.duration;
        self.next = upcoming;
        self.timer.as_mut().reset(upcoming);
        Poll::Ready(Some(()))
    }
}
/// The set of pending timers for one timeline, backed by a single `zx` timer.
pub(crate) struct Timers<T: TimeInterface> {
// Heap of registered timers plus the port-wait flag, guarded by one lock.
inner: Mutex<Inner>,
// Key passed to `wait_async_handle` when waiting on `timer`.
port_key: u64,
// When true, wakeups are driven by the USER_0 signal (raised in
// `maybe_notify`) instead of TIMER_SIGNALED.
fake: bool,
// Kernel timer object that is set to the earliest registered deadline.
timer: zx::Timer<T::Timeline>,
}
/// State protected by the `Timers` lock.
struct Inner {
// Min-heap of registered timers, earliest deadline first.
timers: Heap,
// True while an async wait on the timer handle is outstanding on the port.
async_wait: bool,
}
impl Timers<MonotonicInstant> {
    /// Creates an empty monotonic timer set.
    ///
    /// `port_key` is the key later used for the async wait on the timer
    /// handle; `fake` selects the fake-time signalling mode.
    pub fn new(port_key: u64, fake: bool) -> Self {
        let inner = Mutex::new(Inner { timers: Heap::default(), async_wait: false });
        let timer = zx::MonotonicTimer::create();
        Self { inner, port_key, fake, timer }
    }
}
impl Timers<BootInstant> {
    /// Creates an empty boot-timeline timer set.
    ///
    /// `port_key` is the key later used for the async wait on the timer
    /// handle; `fake` selects the fake-time signalling mode.
    pub fn new(port_key: u64, fake: bool) -> Self {
        let inner = Mutex::new(Inner { timers: Heap::default(), async_wait: false });
        let timer = zx::BootTimer::create();
        Self { inner, port_key, fake, timer }
    }
}
impl<T: TimeInterface> Timers<T> {
/// Creates an unregistered timer (state 0, no heap slot) that fires at `time`.
///
/// The timer keeps a clone of this `Arc` and is only inserted into the heap
/// when it is first polled.
pub fn new_timer(self: &Arc<Self>, time: T) -> Timer {
let nanos = time.into_nanos();
Timer(TimerState {
timers: self.clone(),
nanos: UnsafeCell::new(nanos),
waker: AtomicWaker::new(),
state: AtomicU8::new(0),
index: UnsafeCell::new(HeapIndex::NULL),
_pinned: PhantomPinned,
})
}
/// Sets the kernel timer to `time` and makes sure an async wait on its
/// handle is outstanding on the local executor's port.
///
/// Taking `&mut Inner` means the caller holds the lock.
fn set_timer(&self, inner: &mut Inner, time: i64) {
self.timer.set(T::zx_instant(time), zx::Duration::ZERO).unwrap();
if !inner.async_wait {
if self.fake {
// Clear USER_0 before waiting on it; `maybe_notify` raises it again.
self.timer.signal_handle(zx::Signals::USER_0, zx::Signals::empty()).unwrap();
}
self.timer
.wait_async_handle(
EHandle::local().port(),
self.port_key,
if self.fake { zx::Signals::USER_0 } else { zx::Signals::TIMER_SIGNALED },
zx::WaitAsyncOpts::empty(),
)
.unwrap();
inner.async_wait = true;
}
}
/// Returns the port key this timer set was created with.
pub fn port_key(&self) -> u64 {
self.port_key
}
/// Wakes every timer whose deadline is at or before the current time.
///
/// Returns true if at least one timer was popped from the heap.
pub fn wake_timers(&self) -> bool {
self.wake_timers_impl(false)
}
/// Shared implementation of `wake_timers` and `receive_packet`.
///
/// `from_receive_packet` is true when called because the port delivered a
/// packet; the `async_wait` flag is then cleared so that `set_timer` issues
/// a new wait.
fn wake_timers_impl(&self, from_receive_packet: bool) -> bool {
// The time is sampled once; timers that expire while the loop runs are
// left for a later call.
let now = T::now();
let mut timers_woken = false;
loop {
// The lock is taken per iteration and released before the waker runs.
let waker = {
let mut inner = self.inner.lock();
if from_receive_packet {
inner.async_wait = false;
}
match inner.timers.peek() {
Some(timer) => {
let nanos = unsafe { timer.nanos() };
if nanos <= now {
let timer = inner.timers.pop().unwrap();
timer.into_waker(&mut inner)
} else {
// Earliest remaining deadline is in the future: re-arm and stop.
self.set_timer(&mut inner, nanos);
break;
}
}
_ => break,
}
};
if let Some(waker) = waker {
waker.wake()
}
timers_woken = true;
}
timers_woken
}
/// Pops the earliest timer regardless of the current time, wakes it, and
/// returns its deadline. Returns `None` when the heap is empty.
pub fn wake_next_timer(&self) -> Option<T> {
let (nanos, waker) = {
let mut inner = self.inner.lock();
let Some(timer) = inner.timers.pop() else { return None };
let nanos = unsafe { timer.nanos() };
(nanos, timer.into_waker(&mut inner))
};
// Wake outside the lock scope above.
if let Some(waker) = waker {
waker.wake();
}
Some(T::from_nanos(nanos))
}
/// Returns the deadline of the earliest registered timer, if any.
pub fn next_timer(&self) -> Option<T> {
self.inner.lock().timers.peek().map(|state| T::from_nanos(unsafe { state.nanos() }))
}
/// Fake-time only: raises USER_0 on the timer handle if the earliest
/// registered deadline is at or before `now`.
///
/// # Panics
///
/// Panics if this timer set was not created with `fake == true`.
pub fn maybe_notify(&self, now: T) {
assert!(self.fake, "calling this function requires using fake time.");
if self
.inner
.lock()
.timers
.peek()
.map_or(false, |state| unsafe { state.nanos() } <= now.into_nanos())
{
self.timer.signal_handle(zx::Signals::empty(), zx::Signals::USER_0).unwrap();
}
}
}
/// A packet on the timer's port key means the wait completed, so run the
/// wake-up pass in "from packet" mode; the packet contents are not inspected.
impl<T: TimeInterface> PacketReceiver for Timers<T> {
    fn receive_packet(&self, _packet: zx::Packet) {
        let _timers_woken = self.wake_timers_impl(true);
    }
}
/// Type-erased view of `Timers<T>` stored inside every `TimerState`, so that
/// `Timer` does not need to be generic over the timeline.
trait TimersInterface: Send + Sync + 'static {
/// Polls `timer`, registering it in the heap if it is not registered yet.
///
/// # Safety
///
/// NOTE(review): the contract is not stated in this file. The implementation
/// stores a raw pointer to `timer`'s state in the heap, so presumably `timer`
/// must belong to this timer set and stay pinned until dropped — confirm.
unsafe fn poll(&self, timer: Pin<&Timer>, cx: &mut Context<'_>) -> Poll<()>;
/// Removes `state` from the heap if it is currently registered.
fn unregister(&self, state: &TimerState);
/// Moves a registered timer to the deadline `nanos`. Returns false if the
/// timer is not in the heap, in which case nothing is changed.
fn try_reset_timer(&self, timer: &TimerState, nanos: i64) -> bool;
}
impl<T: TimeInterface> TimersInterface for Timers<T> {
    /// Polls `timer`, inserting it into the heap on the first pending poll.
    ///
    /// Returns `Ready` once the deadline has been reached. Every path that
    /// returns `Ready` leaves the timer in the `TERMINATED` state, so
    /// `FusedFuture::is_terminated` is true as soon as the future completes.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not stated in this file. A raw pointer to
    /// `timer`'s state is stored in the heap, so presumably `timer` must belong
    /// to this timer set and remain pinned until it is dropped — confirm.
    unsafe fn poll(&self, timer: Pin<&Timer>, cx: &mut Context<'_>) -> Poll<()> {
        let state = timer.0.state.load(Ordering::Relaxed);
        if state == TERMINATED {
            return Poll::Ready(());
        }
        if state == FIRED {
            timer.0.state.store(TERMINATED, Ordering::Relaxed);
            return Poll::Ready(());
        }
        if state == 0 {
            // Not registered, so the heap holds no pointer to this timer and
            // `nanos` can be read directly.
            let nanos = unsafe { *timer.0.nanos.get() };
            if nanos <= T::now() {
                // The deadline has already passed: complete without touching
                // the heap. The future is finished here, so it is marked
                // TERMINATED (not FIRED), consistent with the other `Ready`
                // paths; otherwise `is_terminated` would still report false
                // after completion.
                timer.0.state.store(TERMINATED, Ordering::Relaxed);
                return Poll::Ready(());
            }
            let mut inner = self.inner.lock();
            // Only re-arm the kernel timer if this becomes the earliest deadline.
            if inner.timers.peek().map_or(true, |s| nanos < unsafe { s.nanos() }) {
                self.set_timer(&mut inner, nanos);
            }
            inner.timers.push(StateRef(&timer.0));
            timer.0.state.store(REGISTERED, Ordering::Relaxed);
        }
        // Register the waker first and re-check the state afterwards, so a
        // timer that fired in between is still observed.
        timer.0.waker.register(cx.waker());
        if timer.0.state.load(Ordering::Relaxed) == FIRED {
            timer.0.state.store(TERMINATED, Ordering::Relaxed);
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }

    /// Removes `timer` from the heap if it is registered, re-arming or
    /// cancelling the kernel timer when the earliest entry was the one removed.
    fn unregister(&self, timer: &TimerState) {
        // Fast path: anything other than REGISTERED is not in the heap.
        if timer.state.load(Ordering::Relaxed) != REGISTERED {
            return;
        }
        let mut inner = self.inner.lock();
        // The index is read with the lock held.
        let index = unsafe { *timer.index.get() };
        if let Some(index) = index.get() {
            inner.timers.remove(index);
            if index == 0 {
                // The head changed, so the kernel timer must follow it.
                match inner.timers.peek() {
                    Some(next) => {
                        let nanos = unsafe { next.nanos() };
                        self.set_timer(&mut inner, nanos);
                    }
                    None => self.timer.cancel().unwrap(),
                }
            }
            timer.state.store(0, Ordering::Relaxed);
        } else {
            // The timer was popped between the state check above and taking
            // the lock, so it must have fired.
            assert_eq!(timer.state.load(Ordering::Relaxed), FIRED);
        }
    }

    /// Changes the deadline of a timer that is still in the heap.
    ///
    /// Returns false, changing nothing, if the timer is no longer in the heap.
    fn try_reset_timer(&self, timer: &TimerState, nanos: i64) -> bool {
        let mut inner = self.inner.lock();
        let index = unsafe { *timer.index.get() };
        if let Some(old_index) = index.get() {
            if inner.timers.reset(old_index, nanos) == 0 {
                // This timer is now the earliest: arm for its new deadline.
                self.set_timer(&mut inner, nanos);
            } else if old_index == 0 {
                // It used to be the earliest but moved down: arm for the new head.
                let nanos = unsafe { inner.timers.peek().unwrap().nanos() };
                self.set_timer(&mut inner, nanos);
            }
            timer.state.store(REGISTERED, Ordering::Relaxed);
            true
        } else {
            false
        }
    }
}
/// Binary min-heap of registered timers, keyed by deadline in nanoseconds.
///
/// Every entry is a raw pointer to a `TimerState`. The heap writes each entry's
/// current slot back into that state (`set_index`) so a timer can be removed
/// or re-keyed at an arbitrary position.
///
/// The `unsafe` blocks below dereference those pointers. The heap is stored in
/// `Inner`, behind the `Timers` mutex, so `&mut self` implies the lock is held;
/// entries are expected to be removed before their `TimerState` is dropped.
#[derive(Default)]
struct Heap(Vec<StateRef>);

impl Heap {
    /// Inserts `timer` and restores the heap order.
    fn push(&mut self, mut timer: StateRef) {
        let index = self.0.len();
        self.0.push(timer);
        // `StateRef` is `Copy`; the local copy points at the same state as the
        // stored entry.
        unsafe {
            timer.set_index(index.into());
        }
        self.fix_up(index);
    }

    /// Returns the entry with the earliest deadline without removing it.
    fn peek(&self) -> Option<&StateRef> {
        self.0.first()
    }

    /// Removes and returns the entry with the earliest deadline.
    fn pop(&mut self) -> Option<StateRef> {
        if let Some(&first) = self.0.first() {
            self.remove(0);
            Some(first)
        } else {
            None
        }
    }

    /// Swaps two slots and updates the slot recorded in both timers.
    fn swap(&mut self, a: usize, b: usize) {
        self.0.swap(a, b);
        unsafe {
            self.0[a].set_index(a.into());
            self.0[b].set_index(b.into());
        }
    }

    /// Changes the deadline of the entry at `index` to `nanos` and returns the
    /// slot it ends up in.
    fn reset(&mut self, index: usize, nanos: i64) -> usize {
        // An earlier deadline can only move towards the root, a later (or
        // equal) one only towards the leaves.
        if nanos < std::mem::replace(unsafe { self.0[index].nanos_mut() }, nanos) {
            self.fix_up(index)
        } else {
            self.fix_down(index)
        }
    }

    /// Removes the entry at `index`, marking its timer as no longer in the heap.
    ///
    /// Panics if `index` is out of bounds.
    fn remove(&mut self, index: usize) {
        unsafe {
            let old_index = self.0[index].set_index(HeapIndex::NULL);
            debug_assert_eq!(old_index, index.into());
        }
        let last = self.0.len() - 1;
        if index < last {
            // Fill the hole with the last entry, then restore the heap order.
            self.0[index] = self.0[last];
            unsafe {
                self.0[index].set_index(index.into());
            }
            self.0.truncate(last);
            // The moved entry comes from an unrelated part of the tree, so it
            // may be earlier than its new parent as well as later than its new
            // children. Sift up first; only if it did not move can it still
            // need to go down.
            if self.fix_up(index) == index {
                self.fix_down(index);
            }
        } else {
            // The removed entry was the last one: nothing to re-order.
            self.0.truncate(last);
        }
    }

    /// Moves the entry at `index` towards the root until its parent is not
    /// later than it. Returns the final slot.
    fn fix_up(&mut self, mut index: usize) -> usize {
        while index > 0 {
            let parent = (index - 1) / 2;
            if unsafe { self.0[parent].nanos() <= self.0[index].nanos() } {
                return index;
            }
            self.swap(parent, index);
            index = parent;
        }
        index
    }

    /// Moves the entry at `index` towards the leaves until neither child is
    /// earlier than it. Returns the final slot.
    fn fix_down(&mut self, mut index: usize) -> usize {
        let len = self.0.len();
        loop {
            let left = index * 2 + 1;
            if left >= len {
                return index;
            }
            // Pick the earliest of the entry and its children.
            let mut swap_with = None;
            unsafe {
                let mut nanos = self.0[index].nanos();
                let left_nanos = self.0[left].nanos();
                if left_nanos < nanos {
                    swap_with = Some(left);
                    nanos = left_nanos;
                }
                let right = left + 1;
                if right < len && self.0[right].nanos() < nanos {
                    swap_with = Some(right);
                }
            }
            let Some(swap_with) = swap_with else { return index };
            self.swap(index, swap_with);
            index = swap_with;
        }
    }
}
// Unit tests for `Timer`, `Interval` and the timer heap.
#[cfg(test)]
mod test {
use super::*;
use crate::{LocalExecutor, SendExecutor, Task, TestExecutor};
use assert_matches::assert_matches;
use futures::channel::oneshot::channel;
use futures::future::Either;
use futures::prelude::*;
use rand::seq::SliceRandom;
use rand::{thread_rng, Rng};
use std::future::poll_fn;
use std::pin::pin;
use zx::MonotonicDuration;
// Extra operations the generic tests need on top of `TimeInterface`, so that
// each test can run against both the monotonic and the boot timeline.
trait TestTimeInterface:
TimeInterface
+ WakeupTime
+ std::ops::Sub<zx::Duration<Self::Timeline>, Output = Self>
+ std::ops::Add<zx::Duration<Self::Timeline>, Output = Self>
{
// Returns the instant `duration` after the current time.
fn after(duration: zx::Duration<Self::Timeline>) -> Self;
}
impl TestTimeInterface for MonotonicInstant {
fn after(duration: zx::MonotonicDuration) -> Self {
Self::after(duration)
}
}
impl TestTimeInterface for BootInstant {
fn after(duration: zx::BootDuration) -> Self {
Self::after(duration)
}
}
// Races a 100ms timer against a 1s timer on a single-threaded executor and
// expects the shorter one to win.
fn test_shorter_fires_first<T: TestTimeInterface>() {
let mut exec = LocalExecutor::new();
let shorter = pin!(Timer::new(T::after(zx::Duration::<T::Timeline>::from_millis(100))));
let longer = pin!(Timer::new(T::after(zx::Duration::<T::Timeline>::from_seconds(1))));
match exec.run_singlethreaded(future::select(shorter, longer)) {
Either::Left(_) => {}
Either::Right(_) => panic!("wrong timer fired"),
}
}
#[test]
fn shorter_fires_first() {
test_shorter_fires_first::<MonotonicInstant>();
test_shorter_fires_first::<BootInstant>();
}
// Same race as above, but on a `SendExecutor` created with 4 threads.
fn test_shorter_fires_first_multithreaded<T: TestTimeInterface>() {
SendExecutor::new(4).run(async {
let shorter = pin!(Timer::new(T::after(zx::Duration::<T::Timeline>::from_millis(100))));
let longer = pin!(Timer::new(T::after(zx::Duration::<T::Timeline>::from_seconds(1))));
match future::select(shorter, longer).await {
Either::Left(_) => {}
Either::Right(_) => panic!("wrong timer fired"),
}
});
}
#[test]
fn shorter_fires_first_multithreaded() {
test_shorter_fires_first_multithreaded::<MonotonicInstant>();
test_shorter_fires_first_multithreaded::<BootInstant>();
}
// A timer whose deadline is 1ns in the past must win against one whose
// deadline is 1ns in the future.
fn test_timer_before_now_fires_immediately<T: TestTimeInterface>() {
let mut exec = TestExecutor::new();
let now = T::now();
let before = pin!(Timer::new(T::from_nanos(now - 1)));
let after = pin!(Timer::new(T::from_nanos(now + 1)));
assert_matches!(
exec.run_singlethreaded(futures::future::select(before, after)),
Either::Left(_),
"Timer in the past should fire first"
);
}
#[test]
fn timer_before_now_fires_immediately() {
test_timer_before_now_fires_immediately::<MonotonicInstant>();
test_timer_before_now_fires_immediately::<BootInstant>();
}
// With fake time, a timer stays pending until the fake clock reaches its
// deadline.
#[test]
fn fires_after_timeout() {
let mut exec = TestExecutor::new_with_fake_time();
exec.set_fake_time(MonotonicInstant::from_nanos(0));
let deadline = MonotonicInstant::after(MonotonicDuration::from_seconds(5));
let mut future = pin!(Timer::new(deadline));
assert_eq!(Poll::Pending, exec.run_until_stalled(&mut future));
exec.set_fake_time(deadline);
assert_eq!(Poll::Ready(()), exec.run_until_stalled(&mut future));
}
// Drives a 5s `Interval` with fake time and checks that it ticks once per
// deadline and that consecutive deadlines are exactly 5s apart.
#[test]
fn interval() {
let mut exec = TestExecutor::new_with_fake_time();
let start = MonotonicInstant::from_nanos(0);
exec.set_fake_time(start);
let counter = Arc::new(::std::sync::atomic::AtomicUsize::new(0));
let mut future = pin!({
let counter = counter.clone();
Interval::new(MonotonicDuration::from_seconds(5))
.map(move |()| {
counter.fetch_add(1, Ordering::SeqCst);
})
.collect::<()>()
});
// No tick before the first deadline.
assert_eq!(Poll::Pending, exec.run_until_stalled(&mut future));
assert_eq!(0, counter.load(Ordering::SeqCst));
let first_deadline = TestExecutor::next_timer().expect("Expected a pending timeout (1)");
assert!(first_deadline >= MonotonicDuration::from_seconds(5) + start);
exec.set_fake_time(first_deadline);
assert_eq!(Poll::Pending, exec.run_until_stalled(&mut future));
assert_eq!(1, counter.load(Ordering::SeqCst));
// Running again without advancing time must not produce another tick.
assert_eq!(Poll::Pending, exec.run_until_stalled(&mut future));
assert_eq!(1, counter.load(Ordering::SeqCst));
let second_deadline = TestExecutor::next_timer().expect("Expected a pending timeout (2)");
exec.set_fake_time(second_deadline);
assert_eq!(Poll::Pending, exec.run_until_stalled(&mut future));
assert_eq!(2, counter.load(Ordering::SeqCst));
assert_eq!(second_deadline, first_deadline + MonotonicDuration::from_seconds(5));
}
// A 1s timer under fake time fires once the fake clock is advanced to 1s
// after the current (fake) time.
#[test]
fn timer_fake_time() {
let mut exec = TestExecutor::new_with_fake_time();
exec.set_fake_time(MonotonicInstant::from_nanos(0));
let mut timer =
pin!(Timer::new(MonotonicInstant::after(MonotonicDuration::from_seconds(1))));
assert_eq!(Poll::Pending, exec.run_until_stalled(&mut timer));
exec.set_fake_time(MonotonicInstant::after(MonotonicDuration::from_seconds(1)));
assert_eq!(Poll::Ready(()), exec.run_until_stalled(&mut timer));
}
// Creates one timer per entry of `nanos`, polls each once with a no-op waker
// (polling is what inserts a pending timer into the heap), and appends the
// pinned timers to `timer_futures`.
fn create_timers(
timers: &Arc<Timers<MonotonicInstant>>,
nanos: &[i64],
timer_futures: &mut Vec<Pin<Box<Timer>>>,
) {
let waker = futures::task::noop_waker();
let mut cx = Context::from_waker(&waker);
for &n in nanos {
let mut timer = Box::pin(timers.new_timer(MonotonicInstant::from_nanos(n)));
let _ = timer.poll_unpin(&mut cx);
timer_futures.push(timer);
}
}
// Exercises the heap through a standalone `Timers`: ordered wake-up after
// shuffled insertion, removal by dropping in random order, and re-keying via
// `reset`.
#[test]
fn timer_heap() {
let _exec = TestExecutor::new_with_fake_time();
let timers = Arc::new(Timers::<MonotonicInstant>::new(0, true));
let mut timer_futures = Vec::new();
let mut nanos: Vec<_> = (0..1000).collect();
let mut rng = thread_rng();
nanos.shuffle(&mut rng);
create_timers(&timers, &nanos, &mut timer_futures);
// Timers must come out in deadline order regardless of insertion order.
for i in 0..1000 {
assert_eq!(timers.wake_next_timer(), Some(MonotonicInstant::from_nanos(i)));
}
timer_futures.clear();
create_timers(&timers, &nanos, &mut timer_futures);
timer_futures.shuffle(&mut rng);
// Dropping every timer, in random order, must leave the heap empty.
for _timer_fut in timer_futures.drain(..) {}
assert_eq!(timers.wake_next_timer(), None);
create_timers(&timers, &nanos, &mut timer_futures);
timer_futures.shuffle(&mut rng);
let mut nanos: Vec<_> = (1000..2000).collect();
nanos.shuffle(&mut rng);
// Reset every timer to a new, shuffled deadline and check the order again.
for (fut, n) in timer_futures.iter_mut().zip(nanos) {
fut.as_mut().reset(MonotonicInstant::from_nanos(n));
}
for i in 1000..2000 {
assert_eq!(timers.wake_next_timer(), Some(MonotonicInstant::from_nanos(i)));
}
}
// Mixes 99 distinct deadlines with 100 timers sharing one random deadline and
// checks that they are still woken in sorted order.
#[test]
fn timer_heap_with_same_time() {
let _exec = TestExecutor::new_with_fake_time();
let timers = Arc::new(Timers::<MonotonicInstant>::new(0, true));
let mut timer_futures = Vec::new();
let mut nanos: Vec<_> = (1..100).collect();
let mut rng = thread_rng();
nanos.shuffle(&mut rng);
create_timers(&timers, &nanos, &mut timer_futures);
let time = rng.gen_range(0..101);
let same_time = [time; 100];
create_timers(&timers, &same_time, &mut timer_futures);
nanos.extend(&same_time);
nanos.sort();
for n in nanos {
assert_eq!(timers.wake_next_timer(), Some(MonotonicInstant::from_nanos(n)));
}
}
// Resets a pending 100ms timer to 1ms and expects it to fire before the
// original deadline. Real time is used, so the check is retried up to 100
// times and the test passes on the first attempt that is early enough.
#[test]
fn timer_reset_to_earlier_time() {
let mut exec = LocalExecutor::new();
for _ in 0..100 {
let instant = MonotonicInstant::after(MonotonicDuration::from_millis(100));
let (sender, receiver) = channel();
let task = Task::spawn(async move {
let mut timer = pin!(Timer::new(instant));
let mut receiver = pin!(receiver.fuse());
poll_fn(|cx| loop {
if timer.as_mut().poll_unpin(cx).is_ready() {
return Poll::Ready(());
}
// Once the message arrives, shorten the deadline and loop to poll the
// timer again with the new deadline.
if !receiver.is_terminated() && receiver.poll_unpin(cx).is_ready() {
timer
.as_mut()
.reset(MonotonicInstant::after(MonotonicDuration::from_millis(1)));
} else {
return Poll::Pending;
}
})
.await;
});
sender.send(()).unwrap();
exec.run_singlethreaded(task);
// Success if the task finished more than 1ms before the original deadline.
if MonotonicInstant::after(MonotonicDuration::from_millis(1)) < instant {
return;
}
}
panic!("Timer fired late in all 100 attempts");
}
}