1use fdf_sys::*;
8
9use core::cell::{RefCell, UnsafeCell};
10use core::ffi;
11use core::future::Future;
12use core::marker::PhantomData;
13use core::mem::ManuallyDrop;
14use core::ptr::{NonNull, addr_of_mut, null_mut};
15use core::task::Context;
16use std::sync::{Arc, Mutex, Weak};
17
18use zx::Status;
19
20use futures::future::{BoxFuture, FutureExt};
21use futures::task::{ArcWake, waker_ref};
22
23pub use fdf_sys::fdf_dispatcher_t;
24
25pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + 'static {}
27impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + 'static {}
28
29#[derive(Default)]
31pub struct DispatcherBuilder {
32 #[doc(hidden)]
33 pub options: u32,
34 #[doc(hidden)]
35 pub name: String,
36 #[doc(hidden)]
37 pub scheduler_role: String,
38 #[doc(hidden)]
39 pub shutdown_observer: Option<ShutdownObserver>,
40}
41
42impl DispatcherBuilder {
43 pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
45 pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;
47
48 pub fn new() -> Self {
52 Self::default()
53 }
54
55 pub fn unsynchronized(mut self) -> Self {
61 assert!(
62 !self.allows_thread_blocking(),
63 "you may not create an unsynchronized dispatcher that allows synchronous calls"
64 );
65 self.options |= Self::UNSYNCHRONIZED;
66 self
67 }
68
69 pub fn is_unsynchronized(&self) -> bool {
71 (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
72 }
73
74 pub fn allow_thread_blocking(mut self) -> Self {
80 assert!(
81 !self.is_unsynchronized(),
82 "you may not create an unsynchronized dispatcher that allows synchronous calls"
83 );
84 self.options |= Self::ALLOW_THREAD_BLOCKING;
85 self
86 }
87
88 pub fn allows_thread_blocking(&self) -> bool {
90 (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
91 }
92
93 pub fn name(mut self, name: &str) -> Self {
96 self.name = name.to_string();
97 self
98 }
99
100 pub fn scheduler_role(mut self, role: &str) -> Self {
104 self.scheduler_role = role.to_string();
105 self
106 }
107
108 pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
110 self.shutdown_observer = Some(ShutdownObserver::new(shutdown_observer));
111 self
112 }
113
114 pub fn create(self) -> Result<Dispatcher, Status> {
119 let mut out_dispatcher = null_mut();
120 let options = self.options;
121 let name = self.name.as_ptr() as *mut ffi::c_char;
122 let name_len = self.name.len();
123 let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
124 let scheduler_role_len = self.scheduler_role.len();
125 let observer =
126 self.shutdown_observer.unwrap_or_else(|| ShutdownObserver::new(|_| {})).into_ptr();
127 Status::ok(unsafe {
131 fdf_dispatcher_create(
132 options,
133 name,
134 name_len,
135 scheduler_role,
136 scheduler_role_len,
137 observer,
138 &mut out_dispatcher,
139 )
140 })?;
141 Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
144 }
145
146 pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
150 self.create().map(Dispatcher::release)
151 }
152}
153
154#[derive(Debug)]
156pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);
157
158unsafe impl Send for Dispatcher {}
160unsafe impl Sync for Dispatcher {}
161thread_local! {
162 static OVERRIDE_DISPATCHER: RefCell<Option<NonNull<fdf_dispatcher_t>>> = const { RefCell::new(None) };
163}
164
165impl Dispatcher {
166 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
174 Self(handle)
175 }
176
177 #[doc(hidden)]
178 pub fn inner(&self) -> &NonNull<fdf_dispatcher_t> {
179 &self.0
180 }
181
182 fn get_raw_flags(&self) -> u32 {
183 unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
185 }
186
187 pub fn is_unsynchronized(&self) -> bool {
189 (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
190 }
191
192 pub fn allows_thread_blocking(&self) -> bool {
194 (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
195 }
196
197 pub fn is_current_dispatcher(&self) -> bool {
199 self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
202 }
203
204 pub fn post_task_sync(&self, p: impl TaskCallback) -> Result<(), Status> {
206 let async_dispatcher = unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) };
208 #[expect(clippy::arc_with_non_send_sync)]
209 let task_arc = Arc::new(UnsafeCell::new(TaskFunc {
210 task: async_task { handler: Some(TaskFunc::call), ..Default::default() },
211 func: Box::new(p),
212 }));
213
214 let task_cell = Arc::into_raw(task_arc);
215 let res = unsafe {
222 let task_ptr = addr_of_mut!((*UnsafeCell::raw_get(task_cell)).task);
223 async_post_task(async_dispatcher, task_ptr)
224 };
225 if res != ZX_OK {
226 unsafe { Arc::decrement_strong_count(task_cell) }
229 Err(Status::from_raw(res))
230 } else {
231 Ok(())
232 }
233 }
234
235 pub fn release(self) -> DispatcherRef<'static> {
240 DispatcherRef(ManuallyDrop::new(self), PhantomData)
241 }
242
243 pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
246 DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
247 }
248
249 #[doc(hidden)]
252 pub fn override_current<R>(dispatcher: DispatcherRef<'_>, f: impl FnOnce() -> R) -> R {
253 OVERRIDE_DISPATCHER.with(|global| {
254 let previous = global.replace(Some(dispatcher.0.0));
255 let res = f();
256 global.replace(previous);
257 res
258 })
259 }
260}
261
262impl Drop for Dispatcher {
263 fn drop(&mut self) {
264 unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
267 }
268}
269
270#[derive(Debug)]
274pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
275
276impl<'a> DispatcherRef<'a> {
277 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
284 Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
286 }
287}
288
289impl<'a> Clone for DispatcherRef<'a> {
290 fn clone(&self) -> Self {
291 Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
292 }
293}
294
295impl<'a> core::ops::Deref for DispatcherRef<'a> {
296 type Target = Dispatcher;
297 fn deref(&self) -> &Self::Target {
298 &self.0
299 }
300}
301
302impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
303 fn deref_mut(&mut self) -> &mut Self::Target {
304 &mut self.0
305 }
306}
307
308pub trait OnDispatcher: Clone + Send + Sync + Unpin {
310 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R;
313
314 fn on_maybe_dispatcher<R, E: From<Status>>(
317 &self,
318 f: impl FnOnce(DispatcherRef<'_>) -> Result<R, E>,
319 ) -> Result<R, E> {
320 self.on_dispatcher(|dispatcher| {
321 let dispatcher = dispatcher.ok_or(Status::BAD_STATE)?;
322 f(dispatcher)
323 })
324 }
325
326 fn spawn_task(&self, future: impl Future<Output = ()> + Send + 'static) -> Result<(), Status>
330 where
331 Self: 'static,
332 {
333 let task =
334 Arc::new(Task { future: Mutex::new(Some(future.boxed())), dispatcher: self.clone() });
335 task.queue()
336 }
337}
338
339impl<D: OnDispatcher> OnDispatcher for &D {
340 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
341 D::on_dispatcher(*self, f)
342 }
343}
344
345impl OnDispatcher for &Dispatcher {
346 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
347 f(Some(self.as_dispatcher_ref()))
348 }
349}
350
351impl<'a> OnDispatcher for DispatcherRef<'a> {
352 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
353 f(Some(self.as_dispatcher_ref()))
354 }
355}
356
357impl OnDispatcher for Arc<Dispatcher> {
358 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
359 f(Some(self.as_dispatcher_ref()))
360 }
361}
362
363impl OnDispatcher for Weak<Dispatcher> {
364 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
365 let dispatcher = Weak::upgrade(self);
366 match dispatcher {
367 Some(dispatcher) => f(Some(dispatcher.as_dispatcher_ref())),
368 None => f(None),
369 }
370 }
371}
372
373#[derive(Clone, Copy, Debug, PartialEq)]
376pub struct CurrentDispatcher;
377
378impl OnDispatcher for CurrentDispatcher {
379 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
380 let dispatcher = OVERRIDE_DISPATCHER
381 .with(|global| *global.borrow())
382 .or_else(|| {
383 NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
385 })
386 .map(|dispatcher| {
387 DispatcherRef(
393 ManuallyDrop::new(unsafe { Dispatcher::from_raw(dispatcher) }),
394 Default::default(),
395 )
396 });
397 f(dispatcher)
398 }
399}
400
401pub trait TaskCallback: FnOnce(Status) + 'static + Send {}
403impl<T> TaskCallback for T where T: FnOnce(Status) + 'static + Send {}
404
405struct Task<D> {
406 future: Mutex<Option<BoxFuture<'static, ()>>>,
407 dispatcher: D,
408}
409
410impl<D: OnDispatcher + 'static> ArcWake for Task<D> {
411 fn wake_by_ref(arc_self: &Arc<Self>) {
412 match arc_self.queue() {
413 Err(e) if e == Status::from_raw(ZX_ERR_BAD_STATE) => {
414 let future_slot = arc_self.future.lock().unwrap().take();
417 core::mem::drop(future_slot);
418 }
419 res => res.expect("Unexpected error waking dispatcher task"),
420 }
421 }
422}
423
424impl<D: OnDispatcher + 'static> Task<D> {
425 fn queue(self: &Arc<Self>) -> Result<(), Status> {
429 let arc_self = self.clone();
430 self.dispatcher.on_maybe_dispatcher(move |dispatcher| {
431 dispatcher
432 .post_task_sync(move |status| {
433 let mut future_slot = arc_self.future.lock().unwrap();
434 if status != Status::from_raw(ZX_OK) {
436 core::mem::drop(future_slot.take());
437 return;
438 }
439
440 let Some(mut future) = future_slot.take() else {
441 return;
442 };
443 let waker = waker_ref(&arc_self);
444 let context = &mut Context::from_waker(&waker);
445 if future.as_mut().poll(context).is_pending() {
446 *future_slot = Some(future);
447 }
448 })
449 .map(|_| ())
450 })
451 }
452}
453
454#[repr(C)]
455struct TaskFunc {
456 task: async_task,
457 func: Box<dyn TaskCallback>,
458}
459
460impl TaskFunc {
461 extern "C" fn call(_dispatcher: *mut async_dispatcher, task: *mut async_task, status: i32) {
462 let task = unsafe { Arc::from_raw(task as *const UnsafeCell<Self>) };
465 if let Ok(task) = Arc::try_unwrap(task) {
468 (task.into_inner().func)(Status::from_raw(status));
469 }
470 }
471}
472
473#[repr(C)]
482#[doc(hidden)]
483pub struct ShutdownObserver {
484 observer: fdf_dispatcher_shutdown_observer,
485 shutdown_fn: Box<dyn ShutdownObserverFn>,
486}
487
488impl ShutdownObserver {
489 pub fn new<F: ShutdownObserverFn>(f: F) -> Self {
492 let shutdown_fn = Box::new(f);
493 Self {
494 observer: fdf_dispatcher_shutdown_observer { handler: Some(Self::handler) },
495 shutdown_fn,
496 }
497 }
498
499 pub fn into_ptr(self) -> *mut fdf_dispatcher_shutdown_observer {
503 Box::leak(Box::new(self)) as *mut _ as *mut _
506 }
507
508 unsafe extern "C" fn handler(
518 dispatcher: *mut fdf_dispatcher_t,
519 observer: *mut fdf_dispatcher_shutdown_observer_t,
520 ) {
521 let observer = unsafe { Box::from_raw(observer as *mut ShutdownObserver) };
524 let dispatcher_ref = DispatcherRef(
526 ManuallyDrop::new(Dispatcher(unsafe { NonNull::new_unchecked(dispatcher) })),
527 PhantomData,
528 );
529 (observer.shutdown_fn)(dispatcher_ref);
530 unsafe { fdf_dispatcher_destroy(dispatcher) };
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Once, mpsc};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    // Guards the one-time start of the driver runtime environment shared by all tests.
    static GLOBAL_DRIVER_ENV: Once = Once::new();

    /// Starts the driver runtime environment (`fdf_env_start`) at most once per process.
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            // SAFETY: NOTE(review) — relies on `fdf_env_start` being safe to call once per
            // process, which `Once` guarantees here.
            unsafe {
                assert_eq!(fdf_env_start(0), ZX_OK);
            }
        });
    }

    /// Runs `p` against a fresh test dispatcher that allows thread blocking.
    pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T) -> T {
        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, p)
    }

    /// Creates a dispatcher with `flags` and runs `p` with a weak handle to it.
    ///
    /// Afterwards the dispatcher is shut down, and this blocks until its shutdown observer
    /// has run. The observer asserts that no tasks were left queued.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut dispatcher = null_mut();
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            // SAFETY: the observer receives the live dispatcher that is shutting down.
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0.0.as_ptr()) });
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        // NOTE(review): the "owner" passed to the runtime is the address of the local
        // `observer` variable. Presumably it is only used as an opaque identity token —
        // confirm.
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        // SAFETY: the name and the empty role strings outlive the call. `observer` stays
        // allocated until the shutdown handler frees it.
        let res = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                "".as_ptr() as *const c_char,
                0_usize,
                observer,
                &mut dispatcher,
            )
        };
        assert_eq!(res, ZX_OK);
        let dispatcher = Arc::new(Dispatcher(NonNull::new(dispatcher).unwrap()));

        let res = p(Arc::downgrade(&dispatcher));

        // Keep a weak handle so we can check afterwards that the body did not keep the
        // dispatcher alive. Dropping the last strong `Arc` runs `Dispatcher::drop`, which
        // starts shutdown; then wait for the observer to signal completion.
        let weak_dispatcher = Arc::downgrade(&dispatcher);
        drop(dispatcher);
        shutdown_rx.recv().unwrap();
        assert_eq!(
            0,
            weak_dispatcher.strong_count(),
            "a dispatcher reference escaped the test body"
        );

        res
    }

    /// Smoke test: a dispatcher can be created and shut down.
    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    /// A task posted with `post_task_sync` runs with `ZX_OK`.
    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }

    /// A dispatcher built with `DispatcherBuilder` from inside a task can run tasks. Its
    /// shutdown observer fires once the dispatcher is dropped.
    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    let inner = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        })
                        .create()
                        .unwrap();
                    inner
                        .post_task_sync(move |status| {
                            assert_eq!(status, Status::from_raw(ZX_OK));
                            tx.send(status).unwrap();
                        })
                        .unwrap();
                    // Hand the inner dispatcher back to the test thread, which drops it and
                    // thereby starts its shutdown.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            // The received `Dispatcher` is a temporary dropped at the end of this statement.
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    /// Sends 0, then echoes every received value back incremented by one.
    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        while let Some(next) = rx.next().await {
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    /// Echoes received values back incremented by one until one exceeds 10, then signals
    /// `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    /// Two futures spawned on the same dispatcher exchange messages until `pong` finishes.
    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            dispatcher.spawn_task(ping(ping_tx, ping_rx)).unwrap();
            dispatcher.spawn_task(pong(fin_tx, pong_tx, pong_rx)).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    /// Like `pong`, but waits one second on a `fuchsia_async` timer before handling each
    /// value.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            fuchsia_async::Timer::new(fuchsia_async::MonotonicInstant::after(
                MonotonicDuration::from_seconds(1),
            ))
            .await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    /// `ping` runs on the driver dispatcher while `slow_pong` runs on a `fuchsia_async`
    /// `LocalExecutor` on the test thread. This shows the two executors can wake each other.
    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            dispatcher.spawn_task(ping(ping_tx, ping_rx)).unwrap();

            let mut executor = fuchsia_async::LocalExecutor::default();
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}