1use fdf_sys::*;
8
9use core::cell::UnsafeCell;
10use core::ffi;
11use core::future::Future;
12use core::marker::PhantomData;
13use core::mem::ManuallyDrop;
14use core::ptr::{addr_of_mut, null_mut, NonNull};
15use core::task::Context;
16use std::sync::{Arc, Mutex};
17
18use zx::Status;
19
20use futures::future::{BoxFuture, FutureExt};
21use futures::task::{waker_ref, ArcWake};
22
23pub use fdf_sys::fdf_dispatcher_t;
24
25pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + Sync + 'static {}
26impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + Sync + 'static {}
27
28#[derive(Default)]
30pub struct DispatcherBuilder {
31 #[doc(hidden)]
32 pub options: u32,
33 #[doc(hidden)]
34 pub name: String,
35 #[doc(hidden)]
36 pub scheduler_role: String,
37 #[doc(hidden)]
38 pub shutdown_observer: Option<ShutdownObserver>,
39}
40
41impl DispatcherBuilder {
42 pub(crate) const UNSYNCHRONIZED: u32 = 0b01;
44 pub(crate) const ALLOW_THREAD_BLOCKING: u32 = 0b10;
46
47 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn unsynchronized(mut self) -> Self {
60 assert!(
61 !self.allows_thread_blocking(),
62 "you may not create an unsynchronized dispatcher that allows synchronous calls"
63 );
64 self.options = self.options | Self::UNSYNCHRONIZED;
65 self
66 }
67
68 pub fn is_unsynchronized(&self) -> bool {
70 (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
71 }
72
73 pub fn allow_thread_blocking(mut self) -> Self {
79 assert!(
80 !self.is_unsynchronized(),
81 "you may not create an unsynchronized dispatcher that allows synchronous calls"
82 );
83 self.options = self.options | Self::ALLOW_THREAD_BLOCKING;
84 self
85 }
86
87 pub fn allows_thread_blocking(&self) -> bool {
89 (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
90 }
91
92 pub fn name(mut self, name: &str) -> Self {
95 self.name = name.to_string();
96 self
97 }
98
99 pub fn scheduler_role(mut self, role: &str) -> Self {
103 self.scheduler_role = role.to_string();
104 self
105 }
106
107 pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
109 self.shutdown_observer = Some(ShutdownObserver::new(shutdown_observer));
110 self
111 }
112
113 pub fn create(self) -> Result<Dispatcher, Status> {
118 let mut out_dispatcher = null_mut();
119 let options = self.options;
120 let name = self.name.as_ptr() as *mut ffi::c_char;
121 let name_len = self.name.len();
122 let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
123 let scheduler_role_len = self.scheduler_role.len();
124 let observer =
125 self.shutdown_observer.unwrap_or_else(|| ShutdownObserver::new(|_| {})).into_ptr();
126 Status::ok(unsafe {
130 fdf_dispatcher_create(
131 options,
132 name,
133 name_len,
134 scheduler_role,
135 scheduler_role_len,
136 observer,
137 &mut out_dispatcher,
138 )
139 })?;
140 Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
143 }
144
145 pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
149 self.create().map(Dispatcher::release)
150 }
151}
152
153#[derive(Debug)]
154pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);
155
156unsafe impl Send for Dispatcher {}
158unsafe impl Sync for Dispatcher {}
159
160impl Dispatcher {
161 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
169 Self(handle)
170 }
171
172 #[doc(hidden)]
173 pub fn inner<'a>(&'a self) -> &'a NonNull<fdf_dispatcher_t> {
174 &self.0
175 }
176
177 fn get_raw_flags(&self) -> u32 {
178 unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
180 }
181
182 pub fn is_unsynchronized(&self) -> bool {
184 (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
185 }
186
187 pub fn allows_thread_blocking(&self) -> bool {
189 (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
190 }
191
192 pub fn post_task_sync(&self, p: impl TaskCallback) -> Result<(), Status> {
193 let async_dispatcher = unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) };
195 let task_arc = Arc::new(UnsafeCell::new(TaskFunc {
196 task: async_task { handler: Some(TaskFunc::call), ..Default::default() },
197 func: Box::new(p),
198 }));
199
200 let task_cell = Arc::into_raw(task_arc);
201 let res = unsafe {
208 let task_ptr = addr_of_mut!((*UnsafeCell::raw_get(task_cell)).task);
209 async_post_task(async_dispatcher, task_ptr)
210 };
211 if res != ZX_OK {
212 unsafe { Arc::decrement_strong_count(task_cell) }
215 Err(Status::from_raw(res))
216 } else {
217 Ok(())
218 }
219 }
220
221 pub fn spawn_task(
222 &self,
223 future: impl Future<Output = ()> + 'static + Send,
224 ) -> Result<(), Status> {
225 let task = Arc::new(Task {
226 future: Mutex::new(Some(future.boxed())),
227 dispatcher: ManuallyDrop::new(Dispatcher(self.0)),
228 });
229 task.queue()
230 }
231
232 pub fn release(self) -> DispatcherRef<'static> {
237 DispatcherRef(ManuallyDrop::new(self), PhantomData)
238 }
239
240 pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
243 DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
244 }
245}
246
247impl Drop for Dispatcher {
248 fn drop(&mut self) {
249 unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
252 }
253}
254
255#[derive(Debug)]
259pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
260
261impl<'a> DispatcherRef<'a> {
262 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
269 Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
271 }
272}
273
274impl<'a> Clone for DispatcherRef<'a> {
275 fn clone(&self) -> Self {
276 Self(ManuallyDrop::new(Dispatcher(self.0 .0)), PhantomData)
277 }
278}
279
280impl<'a> core::ops::Deref for DispatcherRef<'a> {
281 type Target = Dispatcher;
282 fn deref(&self) -> &Self::Target {
283 &self.0
284 }
285}
286
287impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
288 fn deref_mut(&mut self) -> &mut Self::Target {
289 &mut self.0
290 }
291}
292
293pub trait TaskCallback: FnOnce(Status) + 'static + Send + Sync {}
294impl<T> TaskCallback for T where T: FnOnce(Status) + 'static + Send + Sync {}
295
296struct Task {
297 future: Mutex<Option<BoxFuture<'static, ()>>>,
298 dispatcher: ManuallyDrop<Dispatcher>,
299}
300
301impl ArcWake for Task {
302 fn wake_by_ref(arc_self: &Arc<Self>) {
303 match arc_self.queue() {
304 Err(e) if e == Status::from_raw(ZX_ERR_BAD_STATE) => {
305 let mut future_slot = arc_self.future.lock().unwrap();
308 core::mem::drop(future_slot.take());
309 }
310 res => res.expect("Unexpected error waking dispatcher task"),
311 }
312 }
313}
314
315impl Task {
316 fn queue(self: &Arc<Self>) -> Result<(), Status> {
320 let arc_self = self.clone();
321 self.dispatcher
322 .post_task_sync(move |status| {
323 let mut future_slot = arc_self.future.lock().unwrap();
324 if status != Status::from_raw(ZX_OK) {
326 core::mem::drop(future_slot.take());
327 return;
328 }
329
330 let Some(mut future) = future_slot.take() else {
331 return;
332 };
333 let waker = waker_ref(&arc_self);
334 let context = &mut Context::from_waker(&waker);
335 if future.as_mut().poll(context).is_pending() {
336 *future_slot = Some(future);
337 }
338 })
339 .map(|_| ())
340 }
341}
342
343#[repr(C)]
344struct TaskFunc {
345 task: async_task,
346 func: Box<dyn TaskCallback>,
347}
348
349impl TaskFunc {
350 extern "C" fn call(_dispatcher: *mut async_dispatcher, task: *mut async_task, status: i32) {
351 let task = unsafe { Arc::from_raw(task as *const UnsafeCell<Self>) };
354 if let Some(task) = Arc::try_unwrap(task).ok() {
357 (task.into_inner().func)(Status::from_raw(status));
358 }
359 }
360}
361
362#[repr(C)]
371#[doc(hidden)]
372pub struct ShutdownObserver {
373 observer: fdf_dispatcher_shutdown_observer,
374 shutdown_fn: Box<dyn ShutdownObserverFn>,
375}
376
377impl ShutdownObserver {
378 pub fn new<F: ShutdownObserverFn>(f: F) -> Self {
381 let shutdown_fn = Box::new(f);
382 Self {
383 observer: fdf_dispatcher_shutdown_observer { handler: Some(Self::handler) },
384 shutdown_fn,
385 }
386 }
387
388 pub fn into_ptr(self) -> *mut fdf_dispatcher_shutdown_observer {
392 Box::leak(Box::new(self)) as *mut _ as *mut _
395 }
396
397 unsafe extern "C" fn handler(
407 dispatcher: *mut fdf_dispatcher_t,
408 observer: *mut fdf_dispatcher_shutdown_observer_t,
409 ) {
410 let observer = unsafe { Box::from_raw(observer as *mut ShutdownObserver) };
413 let dispatcher_ref = DispatcherRef(
415 ManuallyDrop::new(Dispatcher(unsafe { NonNull::new_unchecked(dispatcher) })),
416 PhantomData,
417 );
418 (observer.shutdown_fn)(dispatcher_ref);
419 unsafe { fdf_dispatcher_destroy(dispatcher) };
422 }
423}
424
425pub mod test {
426 use core::ffi::{c_char, c_void};
427 use core::ptr::null_mut;
428 use std::sync::{mpsc, Once};
429
430 use super::*;
431
432 static GLOBAL_DRIVER_ENV: Once = Once::new();
433
434 pub fn ensure_driver_env() {
435 GLOBAL_DRIVER_ENV.call_once(|| {
436 unsafe {
439 assert_eq!(fdf_env_start(0), ZX_OK);
440 }
441 });
442 }
443 pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(&Arc<Dispatcher>) -> T) -> T {
444 with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, p)
445 }
446
447 pub(crate) fn with_raw_dispatcher_flags<T>(
448 name: &str,
449 flags: u32,
450 p: impl for<'a> FnOnce(&Arc<Dispatcher>) -> T,
451 ) -> T {
452 ensure_driver_env();
453
454 let (shutdown_tx, shutdown_rx) = mpsc::channel();
455 let mut dispatcher = null_mut();
456 let mut observer = ShutdownObserver::new(move |dispatcher| {
457 assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0 .0.as_ptr()) });
460 shutdown_tx.send(()).unwrap();
461 })
462 .into_ptr();
463 let driver_ptr = &mut observer as *mut _ as *mut c_void;
464 let res = unsafe {
469 fdf_env_dispatcher_create_with_owner(
470 driver_ptr,
471 flags,
472 name.as_ptr() as *const c_char,
473 name.len(),
474 "".as_ptr() as *const c_char,
475 0 as usize,
476 observer,
477 &mut dispatcher,
478 )
479 };
480 assert_eq!(res, ZX_OK);
481 let dispatcher = Arc::new(Dispatcher(NonNull::new(dispatcher).unwrap()));
482
483 let res = p(&dispatcher);
484
485 let weak_dispatcher = Arc::downgrade(&dispatcher);
489 drop(dispatcher);
490 shutdown_rx.recv().unwrap();
491 assert_eq!(
492 0,
493 weak_dispatcher.strong_count(),
494 "a dispatcher reference escaped the test body"
495 );
496
497 res
498 }
499}
500
#[cfg(test)]
mod tests {
    use super::test::*;
    use super::*;

    use std::sync::mpsc;

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};

    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (status_tx, status_rx) = mpsc::channel();
            let posted = dispatcher.post_task_sync(move |status| {
                assert_eq!(status, Status::from_raw(ZX_OK));
                status_tx.send(status).unwrap();
            });
            posted.unwrap();
            let received = status_rx.recv().unwrap();
            assert_eq!(received, Status::from_raw(ZX_OK));
        });
    }

    #[test]
    fn post_task_on_subdispatcher() {
        let (observer_tx, observer_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (status_tx, status_rx) = mpsc::channel();
            let (handoff_tx, handoff_rx) = mpsc::channel();
            let posted = dispatcher.post_task_sync(move |status| {
                assert_eq!(status, Status::from_raw(ZX_OK));
                let inner = DispatcherBuilder::new()
                    .name("testing task second level")
                    .scheduler_role("")
                    .allow_thread_blocking()
                    .shutdown_observer(move |_dispatcher| {
                        println!("shutdown observer called");
                        observer_tx.send(1).unwrap();
                    })
                    .create()
                    .unwrap();
                let inner_posted = inner.post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    status_tx.send(status).unwrap();
                });
                inner_posted.unwrap();
                // Send the inner dispatcher back to the test thread, which drops it.
                handoff_tx.send(inner).unwrap();
            });
            posted.unwrap();
            let received = status_rx.recv().unwrap();
            assert_eq!(received, Status::from_raw(ZX_OK));
            // Receiving and discarding the inner dispatcher drops it, which begins its shutdown.
            handoff_rx.recv().unwrap();
        });
        let observed = observer_rx.recv().unwrap();
        assert_eq!(observed, 1);
    }

    /// Sends 0, then answers every received value `next` with `next + 1` until the input ends.
    async fn ping(mut outgoing: async_mpsc::Sender<u8>, mut incoming: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        outgoing.send(0).await.unwrap();
        while let Some(next) = incoming.next().await {
            println!("ping! {next}");
            outgoing.send(next + 1).await.unwrap();
        }
    }

    /// Answers every received value `next` with `next + 1` until a value above 10 arrives, then
    /// signals `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut outgoing: async_mpsc::Sender<u8>,
        mut incoming: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        while let Some(next) = incoming.next().await {
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            outgoing.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            // Each side sends on one channel and receives on the other.
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            let pinger = ping(ping_tx, ping_rx);
            let ponger = pong(fin_tx, pong_tx, pong_rx);
            dispatcher.spawn_task(pinger).unwrap();
            dispatcher.spawn_task(ponger).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    /// Like `pong`, but waits one second (on the fuchsia-async timer) before handling each value.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut outgoing: async_mpsc::Sender<u8>,
        mut incoming: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        while let Some(next) = incoming.next().await {
            println!("pong! {next}");
            let deadline =
                fuchsia_async::MonotonicInstant::after(MonotonicDuration::from_seconds(1));
            fuchsia_async::Timer::new(deadline).await;
            if next > 10 {
                println!("bye!");
                break;
            }
            outgoing.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            // `ping` runs on the driver dispatcher, while `slow_pong` runs on a fuchsia-async
            // executor on this thread.
            dispatcher.spawn_task(ping(ping_tx, ping_rx)).unwrap();

            let mut executor = fuchsia_async::LocalExecutor::new();
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}