use fdf_sys::*;

use core::cell::RefCell;
use core::ffi;
use core::marker::PhantomData;
use core::mem::ManuallyDrop;
use core::ptr::{NonNull, null_mut};
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::{Arc, Weak};

use zx::Status;

use crate::shutdown_observer::ShutdownObserver;

pub use fdf_sys::fdf_dispatcher_t;
pub use libasync::{
    AfterDeadline, AsyncDispatcher, AsyncDispatcherRef, JoinHandle, OnDispatcher, Task,
};

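/// A callback that is invoked when a [`Dispatcher`] has completed its shutdown. Implemented
/// automatically for any `FnOnce(DispatcherRef<'_>)` that is [`Send`] and `'static`.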
pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + 'static {}
impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + 'static {}

#[derive(Default)]
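/// A builder for configuring and creating a new [`Dispatcher`].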
pub struct DispatcherBuilder {
    #[doc(hidden)]
    pub options: u32,
    #[doc(hidden)]
    pub name: String,
    #[doc(hidden)]
    pub scheduler_role: String,
    #[doc(hidden)]
    pub shutdown_observer: Option<Box<dyn ShutdownObserverFn>>,
}

impl DispatcherBuilder {
    /// Option flag for an unsynchronized dispatcher, which may run its tasks in parallel.
    pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
    /// Option flag for a dispatcher that is allowed to make blocking (synchronous) calls.
    pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;

    /// Creates a new [`DispatcherBuilder`] with the default options: synchronized, and not
    /// allowed to block.
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks the dispatcher as unsynchronized, meaning that its tasks may run in parallel
    /// on more than one thread. Panics if the builder already allows thread blocking, since
    /// the two options are mutually exclusive.
    pub fn unsynchronized(mut self) -> Self {
        assert!(
            !self.allows_thread_blocking(),
            "you may not create an unsynchronized dispatcher that allows synchronous calls"
        );
        self.options |= Self::UNSYNCHRONIZED;
        self
    }

    /// Returns whether this builder will create an unsynchronized dispatcher.
    pub fn is_unsynchronized(&self) -> bool {
        (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
    }

    /// Allows the dispatcher to make blocking (synchronous) calls. Panics if the builder is
    /// already marked unsynchronized, since the two options are mutually exclusive.
    pub fn allow_thread_blocking(mut self) -> Self {
        assert!(
            !self.is_unsynchronized(),
            "you may not create an unsynchronized dispatcher that allows synchronous calls"
        );
        self.options |= Self::ALLOW_THREAD_BLOCKING;
        self
    }

    /// Returns whether this builder will create a dispatcher that allows blocking calls.
    pub fn allows_thread_blocking(&self) -> bool {
        (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
    }

    /// Sets a descriptive name for the dispatcher, which is useful for debugging.
    pub fn name(mut self, name: &str) -> Self {
        self.name = name.to_string();
        self
    }

    /// Sets a scheduler role hint for the dispatcher's threads.
    pub fn scheduler_role(mut self, role: &str) -> Self {
        self.scheduler_role = role.to_string();
        self
    }

    /// Registers a callback to be invoked when the dispatcher has completed its shutdown.
    pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
        self.shutdown_observer = Some(Box::new(shutdown_observer));
        self
    }

    /// Creates the dispatcher as configured by this builder.
    pub fn create(self) -> Result<Dispatcher, Status> {
        let mut out_dispatcher = null_mut();
        let options = self.options;
        let name = self.name.as_ptr() as *mut ffi::c_char;
        let name_len = self.name.len();
        let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
        let scheduler_role_len = self.scheduler_role.len();
        // If no shutdown observer was provided, install a no-op one so the runtime always
        // has a valid observer to call.
        let observer =
            ShutdownObserver::new(self.shutdown_observer.unwrap_or_else(|| Box::new(|_| {})))
                .into_ptr();
        // SAFETY: the name and scheduler role strings outlive the call, and
        // `out_dispatcher` is only treated as valid if the call returns ZX_OK.
        Status::ok(unsafe {
            fdf_dispatcher_create(
                options,
                name,
                name_len,
                scheduler_role,
                scheduler_role_len,
                observer,
                &mut out_dispatcher,
            )
        })?;
        Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
    }

    /// As [`DispatcherBuilder::create`], but immediately releases the new dispatcher as a
    /// [`DispatcherRef`] that will not be shut down when dropped.
    pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
        self.create().map(Dispatcher::release)
    }
}

#[derive(Debug)]
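/// An owned handle to a driver runtime dispatcher, created through [`DispatcherBuilder`].
/// Dropping it initiates an asynchronous shutdown of the dispatcher.
///
/// A minimal usage sketch (assuming a driver runtime environment is available):
///
/// ```ignore
/// let dispatcher = DispatcherBuilder::new()
///     .name("my-dispatcher")
///     .allow_thread_blocking()
///     .create()
///     .expect("failed to create dispatcher");
/// assert!(dispatcher.allows_thread_blocking());
/// ```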
pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);

// SAFETY: the driver runtime dispatcher API is usable from any thread.
unsafe impl Send for Dispatcher {}
unsafe impl Sync for Dispatcher {}
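// A thread-local override consulted first by [`CurrentDispatcher`], allowing callers (tests
// in particular) to redirect "current dispatcher" lookups to a specific dispatcher.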
thread_local! {
    pub(crate) static OVERRIDE_DISPATCHER: RefCell<Option<NonNull<fdf_dispatcher_t>>> =
        const { RefCell::new(None) };
}

impl Dispatcher {
    /// Wraps a raw dispatcher handle in an owned [`Dispatcher`].
    ///
    /// # Safety
    ///
    /// The caller must ensure that `handle` is a valid dispatcher handle and that ownership
    /// of it transfers to the returned [`Dispatcher`], which will shut it down when dropped.
    pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
        Self(handle)
    }

    /// Reads the option flags the dispatcher was created with.
    fn get_raw_flags(&self) -> u32 {
        // SAFETY: the inner handle is valid for the lifetime of `self`.
        unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
    }

    /// Returns whether this dispatcher runs unsynchronized.
    pub fn is_unsynchronized(&self) -> bool {
        (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
    }

    /// Returns whether this dispatcher is allowed to make blocking calls.
    pub fn allows_thread_blocking(&self) -> bool {
        (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
    }

    /// Returns whether the calling thread is currently executing on this dispatcher.
    pub fn is_current_dispatcher(&self) -> bool {
        // SAFETY: this call has no preconditions; it returns the dispatcher the calling
        // thread is running on, if any.
        self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
    }

    /// Releases ownership of the dispatcher without shutting it down, returning a
    /// [`DispatcherRef`] with `'static` lifetime. The caller becomes responsible for
    /// eventually shutting the dispatcher down.
    pub fn release(self) -> DispatcherRef<'static> {
        DispatcherRef(ManuallyDrop::new(self), PhantomData)
    }

    /// Returns a [`DispatcherRef`] that borrows this dispatcher.
    pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
        DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
    }
}

impl AsyncDispatcher for Dispatcher {
    fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
        let async_dispatcher =
            NonNull::new(unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) })
                .expect("No async dispatcher on driver dispatcher");
        unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
    }
}

impl Drop for Dispatcher {
    fn drop(&mut self) {
        // Initiate an asynchronous shutdown of the dispatcher. Completion is reported to
        // the shutdown observer registered at creation time.
        unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
    }
}

#[derive(Debug)]
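/// An owning wrapper around a released [`Dispatcher`] that hands out [`WeakDispatcher`]
/// handles. When dropped, it clears its handle and then blocks until every upgraded weak
/// handle has finished using the underlying dispatcher pointer.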
pub struct AutoReleaseDispatcher(Arc<AtomicPtr<fdf_dispatcher>>);

impl AutoReleaseDispatcher {
    /// Creates a [`WeakDispatcher`] handle that can access this dispatcher until it is
    /// dropped.
    pub fn downgrade(&self) -> WeakDispatcher {
        WeakDispatcher::from(self)
    }
}

impl AsyncDispatcher for AutoReleaseDispatcher {
    fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
        let dispatcher = NonNull::new(self.0.load(Ordering::Relaxed))
            .expect("tried to obtain async dispatcher after drop");
        // SAFETY: the pointer was non-null, so the dispatcher has not yet been dropped and
        // its async dispatcher remains valid.
        unsafe {
            AsyncDispatcherRef::from_raw(
                NonNull::new(fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())).unwrap(),
            )
        }
    }
}

impl From<Dispatcher> for AutoReleaseDispatcher {
    fn from(dispatcher: Dispatcher) -> Self {
        // Release the dispatcher so that its own `Drop` does not shut it down; this wrapper
        // takes over management of the raw handle.
        let dispatcher_ptr = dispatcher.release().0.0.as_ptr();
        Self(Arc::new(AtomicPtr::new(dispatcher_ptr)))
    }
}

impl Drop for AutoReleaseDispatcher {
    fn drop(&mut self) {
        // Clear the pointer so that no new users can observe the dispatcher handle.
        self.0.store(null_mut(), Ordering::Relaxed);
        // Wait for all other holders of the `Arc` (upgraded [`WeakDispatcher`]s) to finish
        // with their reference before letting the handle go away.
        while Arc::strong_count(&self.0) > 1 {
            std::thread::sleep(std::time::Duration::from_nanos(100))
        }
    }
}

#[derive(Clone, Debug)]
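/// A weak handle to an [`AutoReleaseDispatcher`]. Its [`OnDispatcher`] implementation
/// yields [`None`] once the underlying dispatcher has been dropped.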
pub struct WeakDispatcher(Weak<AtomicPtr<fdf_dispatcher>>);

impl From<&AutoReleaseDispatcher> for WeakDispatcher {
    fn from(value: &AutoReleaseDispatcher) -> Self {
        Self(Arc::downgrade(&value.0))
    }
}

impl OnDispatcher for WeakDispatcher {
    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
        let Some(dispatcher_ptr) = self.0.upgrade() else {
            return f(None);
        };
        let Some(dispatcher) = NonNull::new(dispatcher_ptr.load(Ordering::Relaxed)) else {
            return f(None);
        };
        // SAFETY: holding the upgraded `Arc` for the duration of this call keeps
        // [`AutoReleaseDispatcher::drop`] from completing, so the dispatcher stays alive
        // while `f` runs.
        f(Some(unsafe { DispatcherRef::from_raw(dispatcher) }.as_async_dispatcher_ref()))
    }
}

#[derive(Debug)]
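/// A borrowed reference to a [`Dispatcher`] that will not shut the dispatcher down when it
/// is dropped.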
pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);

impl<'a> DispatcherRef<'a> {
    /// Wraps a raw dispatcher handle in a borrowed [`DispatcherRef`].
    ///
    /// # Safety
    ///
    /// The caller must ensure that `handle` remains valid for the lifetime `'a`.
    pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
        Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
    }

    /// Recovers a [`DispatcherRef`] from an [`AsyncDispatcherRef`] that is backed by a
    /// driver runtime dispatcher, panicking if it is not.
    pub fn from_async_dispatcher(dispatcher: AsyncDispatcherRef<'a>) -> Self {
        let handle = NonNull::new(unsafe {
            fdf_dispatcher_downcast_async_dispatcher(dispatcher.inner().as_ptr())
        })
        .unwrap();
        unsafe { Self::from_raw(handle) }
    }

    /// Returns the raw dispatcher handle.
    ///
    /// # Safety
    ///
    /// The caller must not use the returned pointer beyond the lifetime of this reference.
    pub unsafe fn as_raw(&mut self) -> *mut fdf_dispatcher_t {
        unsafe { self.0.0.as_mut() }
    }
}

impl<'a> AsyncDispatcher for DispatcherRef<'a> {
    fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
        self.0.as_async_dispatcher_ref()
    }
}

impl<'a> Clone for DispatcherRef<'a> {
    fn clone(&self) -> Self {
        Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
    }
}

impl<'a> core::ops::Deref for DispatcherRef<'a> {
    type Target = Dispatcher;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl<'a> OnDispatcher for DispatcherRef<'a> {
    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
        f(Some(self.as_async_dispatcher_ref()))
    }
}

#[derive(Clone, Copy, Debug, Default, PartialEq)]
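/// An [`OnDispatcher`] implementation that resolves, at call time, to the dispatcher the
/// current thread is running on, or to the thread-local override if one is set.
///
/// A brief usage sketch, from code already running on a driver dispatcher:
///
/// ```ignore
/// CurrentDispatcher.on_dispatcher(|dispatcher| {
///     assert!(dispatcher.is_some(), "not running on a driver dispatcher");
/// });
/// ```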
pub struct CurrentDispatcher;

impl OnDispatcher for CurrentDispatcher {
    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
        let dispatcher = OVERRIDE_DISPATCHER
            .with(|global| *global.borrow())
            .or_else(|| {
                // Fall back to asking the runtime which dispatcher, if any, the current
                // thread is running on.
                NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
            })
            .map(|dispatcher| {
                // SAFETY: the current thread is running on this dispatcher, so it remains
                // valid for at least the duration of this call.
                let async_dispatcher = NonNull::new(unsafe {
                    fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())
                })
                .expect("No async dispatcher on driver dispatcher");
                unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
            });
        f(dispatcher)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, Once, Weak, mpsc};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};
    use zx::sys::ZX_OK;

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    static GLOBAL_DRIVER_ENV: Once = Once::new();

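    /// Starts the driver runtime's test environment, exactly once per process.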
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            // SAFETY: `fdf_env_start` is called once per process, before any dispatchers
            // are created.
            unsafe {
                assert_eq!(fdf_env_start(0), ZX_OK);
            }
        });
    }
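
    /// Runs `p` with a newly created dispatcher that allows thread blocking, then shuts the
    /// dispatcher down and asserts that no strong references to it escaped the test body.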
    pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T) -> T {
        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, p)
    }

    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut dispatcher = null_mut();
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            // By the time the shutdown observer fires, the dispatcher must not have any
            // tasks still queued on it.
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0.0.as_ptr()) });
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        // SAFETY: `driver_ptr` is used only as an opaque owner tag for the dispatcher, and
        // all of the other pointers passed here outlive the call.
        let res = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                "".as_ptr() as *const c_char,
                0_usize,
                observer,
                &mut dispatcher,
            )
        };
        assert_eq!(res, ZX_OK);
        let dispatcher = Arc::new(Dispatcher(NonNull::new(dispatcher).unwrap()));

        let res = p(Arc::downgrade(&dispatcher));

        // Drop the test body's dispatcher handle, which starts its shutdown, then wait for
        // the shutdown observer to confirm that shutdown completed.
        let weak_dispatcher = Arc::downgrade(&dispatcher);
        drop(dispatcher);
        shutdown_rx.recv().unwrap();
        assert_eq!(
            0,
            weak_dispatcher.strong_count(),
            "a dispatcher reference escaped the test body"
        );

        res
    }

    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }

    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    let inner = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        })
                        .create()
                        .unwrap();
                    inner
                        .post_task_sync(move |status| {
                            assert_eq!(status, Status::from_raw(ZX_OK));
                            tx.send(status).unwrap();
                        })
                        .unwrap();
                    // Send the inner dispatcher out of this task so that it is dropped (and
                    // shut down) by the enclosing test body rather than by this task.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        while let Some(next) = rx.next().await {
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();
            dispatcher.spawn(pong(fin_tx, pong_tx, pong_rx)).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            fuchsia_async::Timer::new(fuchsia_async::MonotonicInstant::after(
                MonotonicDuration::from_seconds(1),
            ))
            .await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            // Run the ping side on the driver dispatcher...
            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();

            // ...and the pong side on a fuchsia_async executor on this thread.
            let mut executor = fuchsia_async::LocalExecutor::default();
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}