1use fdf_sys::*;
8
9use core::cell::RefCell;
10use core::ffi;
11use core::marker::PhantomData;
12use core::mem::ManuallyDrop;
13use core::ptr::{NonNull, null_mut};
14use std::sync::{Arc, Weak};
15
16use zx::Status;
17
18pub use fdf_sys::fdf_dispatcher_t;
19pub use libasync::{AfterDeadline, AsyncDispatcher, AsyncDispatcherRef, OnDispatcher};
20
21pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + 'static {}
23impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + 'static {}
24
/// Builder for a driver-runtime [`Dispatcher`].
///
/// Configure it with the chained setter methods, then finish with [`DispatcherBuilder::create`]
/// (owned dispatcher, shut down on drop) or [`DispatcherBuilder::create_released`].
#[derive(Default)]
pub struct DispatcherBuilder {
    /// Raw `FDF_DISPATCHER_OPTION_*` bit flags forwarded to `fdf_dispatcher_create`.
    #[doc(hidden)]
    pub options: u32,
    /// Dispatcher name, forwarded to the runtime as a pointer/length pair.
    #[doc(hidden)]
    pub name: String,
    /// Scheduler role, forwarded to the runtime as a pointer/length pair; empty by default.
    #[doc(hidden)]
    pub scheduler_role: String,
    /// Callback to run after shutdown; `create` substitutes a no-op observer when this is `None`.
    #[doc(hidden)]
    pub shutdown_observer: Option<ShutdownObserver>,
}
37
38impl DispatcherBuilder {
39 pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
41 pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;
43
44 pub fn new() -> Self {
48 Self::default()
49 }
50
51 pub fn unsynchronized(mut self) -> Self {
57 assert!(
58 !self.allows_thread_blocking(),
59 "you may not create an unsynchronized dispatcher that allows synchronous calls"
60 );
61 self.options |= Self::UNSYNCHRONIZED;
62 self
63 }
64
65 pub fn is_unsynchronized(&self) -> bool {
67 (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
68 }
69
70 pub fn allow_thread_blocking(mut self) -> Self {
76 assert!(
77 !self.is_unsynchronized(),
78 "you may not create an unsynchronized dispatcher that allows synchronous calls"
79 );
80 self.options |= Self::ALLOW_THREAD_BLOCKING;
81 self
82 }
83
84 pub fn allows_thread_blocking(&self) -> bool {
86 (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
87 }
88
89 pub fn name(mut self, name: &str) -> Self {
92 self.name = name.to_string();
93 self
94 }
95
96 pub fn scheduler_role(mut self, role: &str) -> Self {
100 self.scheduler_role = role.to_string();
101 self
102 }
103
104 pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
106 self.shutdown_observer = Some(ShutdownObserver::new(shutdown_observer));
107 self
108 }
109
110 pub fn create(self) -> Result<Dispatcher, Status> {
115 let mut out_dispatcher = null_mut();
116 let options = self.options;
117 let name = self.name.as_ptr() as *mut ffi::c_char;
118 let name_len = self.name.len();
119 let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
120 let scheduler_role_len = self.scheduler_role.len();
121 let observer =
122 self.shutdown_observer.unwrap_or_else(|| ShutdownObserver::new(|_| {})).into_ptr();
123 Status::ok(unsafe {
127 fdf_dispatcher_create(
128 options,
129 name,
130 name_len,
131 scheduler_role,
132 scheduler_role_len,
133 observer,
134 &mut out_dispatcher,
135 )
136 })?;
137 Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
140 }
141
142 pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
146 self.create().map(Dispatcher::release)
147 }
148}
149
/// An owned handle to a driver-runtime dispatcher.
///
/// Dropping a `Dispatcher` calls `fdf_dispatcher_shutdown_async` on the handle. Use
/// [`Dispatcher::release`] to give up ownership without shutting the dispatcher down.
#[derive(Debug)]
pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);

// SAFETY: NOTE(review): assumes the driver runtime allows a dispatcher handle to be used and
// shut down from any thread — confirm against the fdf dispatcher API documentation.
unsafe impl Send for Dispatcher {}
// SAFETY: same assumption as `Send`; the `&self` methods only pass the raw handle to runtime
// query functions.
unsafe impl Sync for Dispatcher {}
thread_local! {
    /// Per-thread dispatcher that, when set, takes precedence over the runtime's current
    /// dispatcher in `CurrentDispatcher`. Set and restored by `Dispatcher::override_current`.
    static OVERRIDE_DISPATCHER: RefCell<Option<NonNull<fdf_dispatcher_t>>> = const { RefCell::new(None) };
}
160
161impl Dispatcher {
162 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
170 Self(handle)
171 }
172
173 #[doc(hidden)]
174 pub fn inner(&self) -> &NonNull<fdf_dispatcher_t> {
175 &self.0
176 }
177
178 fn get_raw_flags(&self) -> u32 {
179 unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
181 }
182
183 pub fn is_unsynchronized(&self) -> bool {
185 (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
186 }
187
188 pub fn allows_thread_blocking(&self) -> bool {
190 (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
191 }
192
193 pub fn is_current_dispatcher(&self) -> bool {
195 self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
198 }
199
200 pub fn release(self) -> DispatcherRef<'static> {
205 DispatcherRef(ManuallyDrop::new(self), PhantomData)
206 }
207
208 pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
211 DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
212 }
213
214 #[doc(hidden)]
217 pub fn override_current<R>(dispatcher: DispatcherRef<'_>, f: impl FnOnce() -> R) -> R {
218 OVERRIDE_DISPATCHER.with(|global| {
219 let previous = global.replace(Some(dispatcher.0.0));
220 let res = f();
221 global.replace(previous);
222 res
223 })
224 }
225}
226
227impl AsyncDispatcher for Dispatcher {
228 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
229 let async_dispatcher =
230 NonNull::new(unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) })
231 .expect("No async dispatcher on driver dispatcher");
232 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
233 }
234}
235
236impl Drop for Dispatcher {
237 fn drop(&mut self) {
238 unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
241 }
242}
243
/// A [`Dispatcher`] that is released, rather than shut down, when dropped.
///
/// It holds one strong `Arc<Dispatcher>` (stored as the raw pointer from `Arc::into_raw`) and a
/// `Weak` that can be handed out as a [`WeakDispatcher`]. Dropping it panics if any other strong
/// reference is still alive at that moment.
#[derive(Debug)]
pub struct AutoReleaseDispatcher(*const Dispatcher, Weak<Dispatcher>);

// SAFETY: the raw pointer is an `Arc<Dispatcher>` from `Arc::into_raw`, and `Dispatcher` is
// `Send + Sync`. Moving this value across threads is therefore no different from moving that
// `Arc`.
unsafe impl Send for AutoReleaseDispatcher {}
// SAFETY: `&AutoReleaseDispatcher` only exposes a clone of the `Weak`; see `Send` above.
unsafe impl Sync for AutoReleaseDispatcher {}
260
261impl From<Dispatcher> for AutoReleaseDispatcher {
262 fn from(value: Dispatcher) -> Self {
263 let dispatcher = Arc::new(value);
264 let weak = Arc::downgrade(&dispatcher);
265 Self(Arc::into_raw(dispatcher), weak)
266 }
267}
268
269impl Drop for AutoReleaseDispatcher {
270 fn drop(&mut self) {
271 let dispatcher = unsafe { Arc::from_raw(self.0) };
274 Arc::try_unwrap(dispatcher)
275 .expect("Outstanding strong reference to `AutoReleaseDispatcher` at drop time")
276 .release();
277 }
278}
279
/// A weak reference to a [`Dispatcher`] held in an `Arc`.
///
/// Obtainable from an `&Arc<Dispatcher>`, a `Weak<Dispatcher>`, or an
/// [`AutoReleaseDispatcher`]. Its `OnDispatcher` implementation delegates to the inner `Weak`.
#[derive(Clone, Debug)]
pub struct WeakDispatcher(Weak<Dispatcher>);
289
290impl From<&Arc<Dispatcher>> for WeakDispatcher {
291 fn from(value: &Arc<Dispatcher>) -> Self {
292 Self(Arc::downgrade(value))
293 }
294}
295
296impl From<Weak<Dispatcher>> for WeakDispatcher {
297 fn from(value: Weak<Dispatcher>) -> Self {
298 Self(value)
299 }
300}
301
302impl From<&AutoReleaseDispatcher> for WeakDispatcher {
303 fn from(value: &AutoReleaseDispatcher) -> Self {
304 Self(value.1.clone())
305 }
306}
307
308impl OnDispatcher for WeakDispatcher {
309 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
310 self.0.on_dispatcher(f)
311 }
312}
313
/// A non-owning reference to a driver dispatcher.
///
/// The inner [`Dispatcher`] is wrapped in `ManuallyDrop`, so dropping a `DispatcherRef` never
/// shuts the dispatcher down. The lifetime `'a` ties the reference to whatever it was obtained
/// from.
#[derive(Debug)]
pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
319
320impl<'a> DispatcherRef<'a> {
321 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
328 Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
330 }
331
332 pub fn from_async_dispatcher(dispatcher: AsyncDispatcherRef<'a>) -> Self {
339 let handle = NonNull::new(unsafe {
340 fdf_dispatcher_downcast_async_dispatcher(dispatcher.inner().as_ptr())
341 })
342 .unwrap();
343 unsafe { Self::from_raw(handle) }
344 }
345}
346
347impl<'a> AsyncDispatcher for DispatcherRef<'a> {
348 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
349 self.0.as_async_dispatcher_ref()
350 }
351}
352
353impl<'a> Clone for DispatcherRef<'a> {
354 fn clone(&self) -> Self {
355 Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
356 }
357}
358
359impl<'a> core::ops::Deref for DispatcherRef<'a> {
360 type Target = Dispatcher;
361 fn deref(&self) -> &Self::Target {
362 &self.0
363 }
364}
365
366impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
367 fn deref_mut(&mut self) -> &mut Self::Target {
368 &mut self.0
369 }
370}
371
372impl<'a> OnDispatcher for DispatcherRef<'a> {
373 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
374 f(Some(self.as_async_dispatcher_ref()))
375 }
376}
377
378#[derive(Clone, Copy, Debug, PartialEq)]
381pub struct CurrentDispatcher;
382
383impl OnDispatcher for CurrentDispatcher {
384 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
385 let dispatcher = OVERRIDE_DISPATCHER
386 .with(|global| *global.borrow())
387 .or_else(|| {
388 NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
390 })
391 .map(|dispatcher| {
392 let async_dispatcher = NonNull::new(unsafe {
398 fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())
399 })
400 .expect("No async dispatcher on driver dispatcher");
401 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
402 });
403 f(dispatcher)
404 }
405}
406
/// Shutdown observer handed to the driver runtime.
///
/// `#[repr(C)]` keeps `observer` at offset 0. A pointer to this struct can therefore be given to
/// the runtime as a `*mut fdf_dispatcher_shutdown_observer` and cast back to `ShutdownObserver`
/// inside `ShutdownObserver::handler`.
#[repr(C)]
#[doc(hidden)]
pub struct ShutdownObserver {
    /// C-visible header; its `handler` field points at `ShutdownObserver::handler`.
    observer: fdf_dispatcher_shutdown_observer,
    /// User callback invoked from `ShutdownObserver::handler`.
    shutdown_fn: Box<dyn ShutdownObserverFn>,
}
421
422impl ShutdownObserver {
423 pub fn new<F: ShutdownObserverFn>(f: F) -> Self {
426 let shutdown_fn = Box::new(f);
427 Self {
428 observer: fdf_dispatcher_shutdown_observer { handler: Some(Self::handler) },
429 shutdown_fn,
430 }
431 }
432
433 pub fn into_ptr(self) -> *mut fdf_dispatcher_shutdown_observer {
437 Box::leak(Box::new(self)) as *mut _ as *mut _
440 }
441
442 unsafe extern "C" fn handler(
452 dispatcher: *mut fdf_dispatcher_t,
453 observer: *mut fdf_dispatcher_shutdown_observer_t,
454 ) {
455 let observer = unsafe { Box::from_raw(observer as *mut ShutdownObserver) };
458 let dispatcher_ref = DispatcherRef(
460 ManuallyDrop::new(Dispatcher(unsafe { NonNull::new_unchecked(dispatcher) })),
461 PhantomData,
462 );
463 (observer.shutdown_fn)(dispatcher_ref);
464 unsafe { fdf_dispatcher_destroy(dispatcher) };
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, Once, Weak, mpsc};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};
    use zx::sys::ZX_OK;

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    /// Guards the one-time start of the driver runtime environment for this test binary.
    static GLOBAL_DRIVER_ENV: Once = Once::new();

    /// Starts the driver runtime environment with `fdf_env_start(0)`, at most once per process.
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            // SAFETY: `Once` guarantees a single call. NOTE(review): assumes `fdf_env_start` has
            // no other preconditions in a test process.
            unsafe {
                assert_eq!(fdf_env_start(0), ZX_OK);
            }
        });
    }
    /// Runs `p` with a test dispatcher created with `ALLOW_THREAD_BLOCKING`. Afterwards the
    /// dispatcher is shut down and the call waits for shutdown to complete.
    pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T) -> T {
        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, p)
    }

    /// Runs `p` with a freshly created dispatcher using the raw option bits `flags`.
    ///
    /// `p` receives only a `Weak<Dispatcher>`. After `p` returns, the helper drops its strong
    /// reference (starting an async shutdown), waits for the shutdown observer to fire, and
    /// asserts that no strong reference is left.
    ///
    /// NOTE(review): if `p` leaks a strong reference, shutdown starts only when that reference is
    /// dropped elsewhere. If it is never dropped, `recv` blocks instead of the final assertion
    /// firing.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut dispatcher = null_mut();
        // On shutdown: check that no tasks remain queued, then wake the waiting test thread.
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0.0.as_ptr()) });
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        // The address of the local `observer` variable is used as the owner token for
        // `fdf_env_dispatcher_create_with_owner`. NOTE(review): presumably only its identity
        // matters; the address is only meaningful until this function returns.
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        // SAFETY: the name is passed as pointer/length, the scheduler role is empty (length 0),
        // `observer` is a live allocation from `into_ptr`, and `dispatcher` is a valid
        // out-parameter.
        let res = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                "".as_ptr() as *const c_char,
                0_usize,
                observer,
                &mut dispatcher,
            )
        };
        assert_eq!(res, ZX_OK);
        let dispatcher = Arc::new(Dispatcher(NonNull::new(dispatcher).unwrap()));

        let res = p(Arc::downgrade(&dispatcher));

        // Dropping the final strong reference runs `Dispatcher::drop`, which requests shutdown.
        // The observer above then signals `shutdown_rx`.
        let weak_dispatcher = Arc::downgrade(&dispatcher);
        drop(dispatcher);
        shutdown_rx.recv().unwrap();
        assert_eq!(
            0,
            weak_dispatcher.strong_count(),
            "a dispatcher reference escaped the test body"
        );

        res
    }

    /// Smoke test: create a dispatcher and tear it down.
    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    /// A synchronous task posted to the dispatcher runs and observes `ZX_OK`.
    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }

    /// A dispatcher built with `DispatcherBuilder` from inside a task can run tasks, and its
    /// shutdown observer fires.
    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    let inner = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        })
                        .create()
                        .unwrap();
                    inner
                        .post_task_sync(move |status| {
                            assert_eq!(status, Status::from_raw(ZX_OK));
                            tx.send(status).unwrap();
                        })
                        .unwrap();
                    // Hand the inner dispatcher to the test thread rather than dropping it here.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            // The received dispatcher is dropped immediately, which requests its shutdown.
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    /// Sends 0, then replies to every received value with that value plus one.
    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        while let Some(next) = rx.next().await {
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    /// Replies with value plus one until a value above 10 arrives, then signals `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    /// Both halves of the ping/pong exchange run as tasks on the driver dispatcher.
    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            dispatcher.spawn_task(ping(ping_tx, ping_rx)).unwrap();
            dispatcher.spawn_task(pong(fin_tx, pong_tx, pong_rx)).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    /// Like `pong`, but waits one second on a `fuchsia_async` timer before handling each value.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            fuchsia_async::Timer::new(fuchsia_async::MonotonicInstant::after(
                MonotonicDuration::from_seconds(1),
            ))
            .await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    /// `ping` runs on the driver dispatcher while `slow_pong` runs on a local `fuchsia_async`
    /// executor on the test thread.
    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            dispatcher.spawn_task(ping(ping_tx, ping_rx)).unwrap();

            let mut executor = fuchsia_async::LocalExecutor::default();
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}
689}