use fidl::endpoints::RequestStream;
use fuchsia_async as fasync;
use fuchsia_zircon as zx;
use futures::{
channel::oneshot::{self, Receiver},
ready, FutureExt, Stream, StreamExt,
};
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use zx::Duration;
/// Wraps `request_stream` so that it finishes once it has gone
/// `debounce_interval` without producing a message.
///
/// Returns a pair:
///
/// * a stream yielding the same items as `request_stream`, and
/// * a oneshot [`Receiver`] fed by the wrapper's unbind callback. It receives
///   `Some(channel)` when the wrapper stalled and recovered the underlying
///   channel, or `None` when the underlying stream ended by itself.
pub fn until_stalled<RS: RequestStream>(
    request_stream: RS,
    debounce_interval: Duration,
) -> (impl Stream<Item = <RS as Stream>::Item>, Receiver<Option<zx::Channel>>) {
    let (unbind_tx, unbind_rx) = oneshot::channel();

    // Forward whatever the wrapper reports to the receiver half. A failed send
    // only means the receiver was already dropped, so the result is discarded.
    let on_unbind = move |server_end: Option<zx::Channel>| {
        let _ = unbind_tx.send(server_end);
    };

    let stallable = StallableRequestStream::new(request_stream, debounce_interval, on_unbind);
    (stallable, unbind_rx)
}
/// A wrapper around a request stream `RS` that ends itself after the inner
/// stream has been pending for `debounce_interval`, handing the outcome to a
/// one-time callback `F`.
///
/// See the `Stream` implementation for the exact polling behavior.
pub struct StallableRequestStream<RS, F> {
/// The wrapped stream. It is briefly taken out when the idle timer fires so
/// its inner state can be unwrapped; it is put back if that state is still
/// shared, and left as `None` once the channel has been handed to the
/// callback.
stream: Option<RS>,
/// How long the inner stream may stay pending before an unbind is attempted.
/// Passed to `fasync::Timer::new` each time a new idle timer is armed.
debounce_interval: Duration,
/// Callback invoked at most once: with `None` when the inner stream ends, or
/// with `Some(channel)` when the stream stalled. `None` after it has run.
unbind_callback: Option<F>,
/// Idle timer. `None` until the inner stream first returns `Pending`, and
/// reset to `None` whenever the inner stream yields or the timer fires.
timer: Option<fasync::Timer>,
}
impl<RS, F> StallableRequestStream<RS, F> {
    /// Creates a wrapper around `stream`.
    ///
    /// * `debounce_interval` - how long the stream may stay pending before the
    ///   wrapper tries to unbind it.
    /// * `unbind_callback` - stored and later invoked once by the `Stream`
    ///   implementation, with the recovered channel or with `None`.
    ///
    /// No timer is armed here; one is created lazily on the first poll that
    /// finds the inner stream pending.
    pub fn new(stream: RS, debounce_interval: Duration, unbind_callback: F) -> Self {
        let stream = Some(stream);
        let unbind_callback = Some(unbind_callback);
        let timer = None;
        Self { stream, debounce_interval, unbind_callback, timer }
    }
}
impl<RS: RequestStream + Unpin, F: FnOnce(Option<zx::Channel>) + Unpin> Stream
    for StallableRequestStream<RS, F>
{
    /// Items are passed through unchanged from the wrapped stream.
    type Item = <RS as Stream>::Item;

    /// Polls the wrapped stream, falling back to the idle timer when it is
    /// pending.
    ///
    /// * If the wrapped stream is ready, any idle timer is discarded and the
    ///   item is returned as-is. When that item is `None`, the unbind callback
    ///   is invoked with `None` first.
    /// * If the wrapped stream is pending, an idle timer is armed (or the
    ///   existing one reused). Once it fires, the stream's inner state is
    ///   unwrapped: if nothing else holds a reference to it, the channel is
    ///   given to the unbind callback and the stream ends with `None`;
    ///   otherwise the stream is rebuilt and a fresh timer is armed.
    ///
    /// # Panics
    ///
    /// Panics with "Stream already resolved" when polled again after the
    /// channel was handed to the callback, and panics on the `unwrap` if the
    /// callback has already been consumed.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;

        // Fast path: the wrapped stream has something for us (an item or its end).
        let requests = this.stream.as_mut().expect("Stream already resolved");
        if let Poll::Ready(item) = requests.poll_next_unpin(cx) {
            // Any activity restarts the idle period.
            this.timer = None;
            if item.is_none() {
                let notify = this.unbind_callback.take().unwrap();
                notify(None);
            }
            return Poll::Ready(item);
        }

        // Slow path: nothing to read right now, so wait on the idle timer.
        let interval = this.debounce_interval;
        loop {
            let idle_timer = this.timer.get_or_insert_with(|| fasync::Timer::new(interval));
            if idle_timer.poll_unpin(cx).is_pending() {
                return Poll::Pending;
            }
            this.timer = None;

            // The timer fired. Take the stream apart to see whether we are the
            // sole owner of its inner state.
            let (shared_inner, is_terminated) = this.stream.take().unwrap().into_inner();
            let still_shared = match Arc::try_unwrap(shared_inner) {
                Ok(sole_inner) => {
                    let notify = this.unbind_callback.take().unwrap();
                    notify(Some(sole_inner.into_channel().into_zx_channel()));
                    return Poll::Ready(None);
                }
                Err(still_shared) => still_shared,
            };

            // Someone else still references the inner state. Reassemble the
            // stream and go around again, which arms and polls a new timer.
            this.stream = Some(RS::from_inner(still_shared, is_terminated));
        }
    }
}
// Tests for `until_stalled`. Each test pins the executor's time to a known
// starting point and then moves it forward explicitly with
// `TestExecutor::advance_to`, so idle timers only fire when a test says so.
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use fasync::TestExecutor;
use fidl::endpoints::Proxy;
use fidl::AsHandleRef;
use fidl_fuchsia_io as fio;
use fuchsia_async as fasync;
use fuchsia_zircon as zx;
use futures::{pin_mut, TryStreamExt};
/// With no traffic at all, advancing time by the idle duration ends the
/// stream and delivers the channel through `stalled`.
#[fuchsia::test(allow_stalls = false)]
async fn no_message() {
let initial = fasync::Time::from_nanos(0);
TestExecutor::advance_to(initial).await;
const DURATION_NANOS: i64 = 1_000_000;
let idle_duration = Duration::from_nanos(DURATION_NANOS);
let (_proxy, stream) =
fidl::endpoints::create_proxy_and_stream::<fio::DirectoryMarker>().unwrap();
let (mut stream, stalled) = until_stalled(stream, idle_duration);
// Poll the stream and advance time concurrently; `stalled` is only awaited
// after the clock has moved.
assert_matches!(
futures::join!(
stream.next(),
TestExecutor::advance_to(initial + idle_duration).then(|()| stalled)
),
(None, Ok(Some(_)))
);
}
/// A single request is delivered normally; the stream then stalls only after
/// it has been polled and a full idle duration has elapsed since that poll.
#[fuchsia::test(allow_stalls = false)]
async fn one_message() {
let initial = fasync::Time::from_nanos(0);
TestExecutor::advance_to(initial).await;
const DURATION_NANOS: i64 = 1_000_000;
let idle_duration = Duration::from_nanos(DURATION_NANOS);
let (proxy, stream) =
fidl::endpoints::create_proxy_and_stream::<fio::DirectoryMarker>().unwrap();
let (mut stream, stalled) = until_stalled(stream, idle_duration);
pin_mut!(stalled);
assert_matches!(TestExecutor::poll_until_stalled(&mut stalled).await, Poll::Pending);
// Send a request; the response future itself is not needed by the test.
let _ = proxy.get_flags();
let message = stream.next();
pin_mut!(message);
let message = TestExecutor::poll_until_stalled(&mut message).await;
let Poll::Ready(Some(Ok(fio::DirectoryRequest::GetFlags { responder }))) = message else {
panic!("Unexpected {message:?}");
};
responder.send(zx::Status::OK.into_raw(), fio::OpenFlags::empty()).unwrap();
// Time passing alone does not resolve `stalled`: the stream has not been
// polled since it yielded the request.
TestExecutor::advance_to(initial + idle_duration * 2).await;
assert!(TestExecutor::poll_until_stalled(&mut stalled).await.is_pending());
// Polling the stream now finds it pending; it is not stalled yet.
let message = stream.next();
pin_mut!(message);
assert_matches!(TestExecutor::poll_until_stalled(&mut message).await, Poll::Pending);
assert_matches!(TestExecutor::poll_until_stalled(&mut stalled).await, Poll::Pending);
// One more idle duration later the stream ends and the channel is returned.
TestExecutor::advance_to(initial + idle_duration * 3).await;
assert_matches!(message.await, None);
assert_matches!(stalled.await, Ok(Some(_)));
}
/// While a responder is still outstanding, neither the stream nor `stalled`
/// resolves; after the reply is sent the stream stalls as usual.
#[fuchsia::test(allow_stalls = false)]
async fn pending_reply_blocks_stalling() {
let initial = fasync::Time::from_nanos(0);
TestExecutor::advance_to(initial).await;
const DURATION_NANOS: i64 = 1_000_000;
let idle_duration = Duration::from_nanos(DURATION_NANOS);
let (proxy, stream) =
fidl::endpoints::create_proxy_and_stream::<fio::DirectoryMarker>().unwrap();
let (stream, mut stalled) = until_stalled(stream, idle_duration);
// `select!` below needs a fused stream.
let mut stream = stream.fuse();
let _ = proxy.get_flags();
let message_with_pending_reply = stream.next().await.unwrap();
let Ok(fio::DirectoryRequest::GetFlags { responder, .. }) = message_with_pending_reply
else {
panic!("Unexpected {message_with_pending_reply:?}");
};
TestExecutor::advance_to(initial + idle_duration * 2).await;
// With the responder still held, nothing may be ready: only the `default`
// arm is acceptable.
futures::select! {
_ = stream.next() => unreachable!(),
_ = stalled => unreachable!(),
default => {},
}
// Reply, then let another idle duration pass: the stream now stalls.
responder.send(zx::Status::OK.into_raw(), fio::OpenFlags::empty()).unwrap();
assert_matches!(
futures::join!(
stream.next(),
TestExecutor::advance_to(initial + idle_duration * 3).then(|()| stalled)
),
(None, Ok(Some(_)))
);
}
/// When the client end is dropped, the stream ends on its own and `stalled`
/// receives `None` rather than a channel.
#[fuchsia::test(allow_stalls = false)]
async fn completed_stream() {
let initial = fasync::Time::from_nanos(0);
TestExecutor::advance_to(initial).await;
const DURATION_NANOS: i64 = 1_000_000;
let idle_duration = Duration::from_nanos(DURATION_NANOS);
let (proxy, stream) =
fidl::endpoints::create_proxy_and_stream::<fio::DirectoryMarker>().unwrap();
let (mut stream, stalled) = until_stalled(stream, idle_duration);
pin_mut!(stalled);
assert_matches!(TestExecutor::poll_until_stalled(&mut stalled).await, Poll::Pending);
// Closing the client end terminates the request stream.
drop(proxy);
{
assert_matches!(stream.next().await, None);
drop(stream);
}
assert_matches!(stalled.await, Ok(None));
}
/// Serves a sequence of requests spaced at half the idle duration, checks the
/// stream never stalls in between, then lets it idle and verifies that the
/// returned channel is the peer of the client's channel.
#[fuchsia::test(allow_stalls = false)]
async fn end_to_end() {
let initial = fasync::Time::from_nanos(0);
TestExecutor::advance_to(initial).await;
use fidl_fuchsia_component_client_test::{ServiceAMarker, ServiceARequest};
const DURATION_NANOS: i64 = 40_000_000;
let idle_duration = Duration::from_nanos(DURATION_NANOS);
let (proxy, stream) = fidl::endpoints::create_proxy_and_stream::<ServiceAMarker>().unwrap();
let (mut stream, stalled) = until_stalled(stream, idle_duration);
// Server task: answer every `Foo` request until the stream ends.
let task = fasync::Task::spawn(async move {
while let Some(request) = stream.try_next().await.unwrap() {
match request {
ServiceARequest::Foo { responder } => responder.send().unwrap(),
}
}
});
// Wrap the result in `Arc` and share the future so it can be probed
// repeatedly through clones and still awaited at the end.
let stalled = fasync::Task::spawn(stalled).map(Arc::new).shared();
let request_duration = Duration::from_nanos(DURATION_NANOS / 2);
const NUM_REQUESTS: usize = 5;
let mut deadline = initial;
for _ in 0..NUM_REQUESTS {
proxy.foo().await.unwrap();
deadline += request_duration;
TestExecutor::advance_to(deadline).await;
// Requests arrive faster than the idle duration, so no stall yet.
assert!(stalled.clone().now_or_never().is_none());
}
// Let a full idle duration elapse with no further requests.
deadline += idle_duration;
TestExecutor::advance_to(deadline).await;
let server_end = stalled.await;
task.await;
// The recovered server end must be the other side of the client's channel.
let client = proxy.into_channel().unwrap().into_zx_channel();
assert_eq!(
client.basic_info().unwrap().koid,
(*server_end).as_ref().unwrap().as_ref().unwrap().basic_info().unwrap().related_koid
);
}
}