1use fdf_sys::*;
8
9use core::cell::RefCell;
10use core::ffi;
11use core::marker::PhantomData;
12use core::mem::ManuallyDrop;
13use core::ptr::{NonNull, null_mut};
14use std::sync::atomic::{AtomicPtr, Ordering};
15use std::sync::{Arc, Weak};
16
17use zx::Status;
18
19use crate::shutdown_observer::ShutdownObserver;
20
21pub use fdf_sys::fdf_dispatcher_t;
22pub use libasync::{
23 AfterDeadline, AsyncDispatcher, AsyncDispatcherRef, JoinHandle, OnDispatcher, Task,
24};
25
/// A marker trait for closures that may be registered as a dispatcher
/// shutdown observer: called at most once, sendable across threads, and
/// handed a [`DispatcherRef`] for the dispatcher being shut down.
pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + 'static {}
// Blanket impl: any closure matching the bounds automatically implements the
// trait, so callers never implement it by hand.
impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + 'static {}
29
/// Builder for a driver framework [`Dispatcher`].
///
/// Configure via the chained methods and then call
/// [`DispatcherBuilder::create`] (or [`DispatcherBuilder::create_released`])
/// to construct the dispatcher. The default-constructed builder creates a
/// synchronized dispatcher that does not allow thread blocking.
#[derive(Default)]
pub struct DispatcherBuilder {
    // Raw FDF option bits (`FDF_DISPATCHER_OPTION_*`) passed straight to
    // `fdf_dispatcher_create`.
    #[doc(hidden)]
    pub options: u32,
    // Human-readable dispatcher name; passed as a pointer/length pair, so it
    // need not be NUL-terminated.
    #[doc(hidden)]
    pub name: String,
    // Scheduler role hint forwarded to the driver runtime; empty means no
    // specific role.
    #[doc(hidden)]
    pub scheduler_role: String,
    // Optional callback invoked when the dispatcher finishes shutting down.
    #[doc(hidden)]
    pub shutdown_observer: Option<Box<dyn ShutdownObserverFn>>,
}
42
impl DispatcherBuilder {
    // Option bit for an unsynchronized dispatcher (callbacks may run
    // concurrently).
    pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
    // Option bit allowing tasks on the dispatcher to make blocking
    // (synchronous) calls.
    pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;

    /// Creates a builder with default settings (synchronized, no thread
    /// blocking, empty name and scheduler role, no shutdown observer).
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks the dispatcher as unsynchronized.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::allow_thread_blocking`] was already requested; the
    /// two options are mutually exclusive.
    pub fn unsynchronized(mut self) -> Self {
        assert!(
            !self.allows_thread_blocking(),
            "you may not create an unsynchronized dispatcher that allows synchronous calls"
        );
        self.options |= Self::UNSYNCHRONIZED;
        self
    }

    /// Returns whether the unsynchronized option has been set on this
    /// builder.
    pub fn is_unsynchronized(&self) -> bool {
        (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
    }

    /// Allows tasks on the dispatcher to make blocking calls.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::unsynchronized`] was already requested; the two
    /// options are mutually exclusive.
    pub fn allow_thread_blocking(mut self) -> Self {
        assert!(
            !self.is_unsynchronized(),
            "you may not create an unsynchronized dispatcher that allows synchronous calls"
        );
        self.options |= Self::ALLOW_THREAD_BLOCKING;
        self
    }

    /// Returns whether the thread-blocking option has been set on this
    /// builder.
    pub fn allows_thread_blocking(&self) -> bool {
        (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
    }

    /// Sets the dispatcher's name (copied into the builder).
    pub fn name(mut self, name: &str) -> Self {
        self.name = name.to_string();
        self
    }

    /// Sets the scheduler role hint (copied into the builder).
    pub fn scheduler_role(mut self, role: &str) -> Self {
        self.scheduler_role = role.to_string();
        self
    }

    /// Registers a callback to run when the dispatcher has finished shutting
    /// down. Replaces any previously registered observer.
    pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
        self.shutdown_observer = Some(Box::new(shutdown_observer));
        self
    }

    /// Creates the [`Dispatcher`], consuming the builder.
    ///
    /// # Errors
    ///
    /// Returns the [`Status`] reported by `fdf_dispatcher_create` on failure.
    pub fn create(self) -> Result<Dispatcher, Status> {
        let mut out_dispatcher = null_mut();
        let options = self.options;
        // Name and role are passed as (pointer, length) pairs; the backing
        // `String`s are owned by `self`, which stays alive across the FFI
        // call below, keeping the pointers valid.
        let name = self.name.as_ptr() as *mut ffi::c_char;
        let name_len = self.name.len();
        let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
        let scheduler_role_len = self.scheduler_role.len();
        // Always hand the C API an observer: a no-op one if none was
        // registered. `into_ptr` transfers ownership of the observer to the
        // framework; presumably it is reclaimed when the shutdown callback
        // fires — confirm against the fdf C API contract.
        let observer =
            ShutdownObserver::new(self.shutdown_observer.unwrap_or_else(|| Box::new(|_| {})))
                .into_ptr();
        Status::ok(unsafe {
            fdf_dispatcher_create(
                options,
                name,
                name_len,
                scheduler_role,
                scheduler_role_len,
                observer,
                &mut out_dispatcher,
            )
        })?;
        // SAFETY: `fdf_dispatcher_create` returned ZX_OK, so
        // `out_dispatcher` was populated with a non-null dispatcher handle.
        Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
    }

    /// Like [`Self::create`], but immediately releases the dispatcher into a
    /// `'static` [`DispatcherRef`], so dropping the ref will not shut the
    /// dispatcher down.
    pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
        self.create().map(Dispatcher::release)
    }
}
155
/// An owned handle to a driver framework dispatcher.
///
/// Dropping a `Dispatcher` starts an *asynchronous* shutdown of the
/// underlying `fdf_dispatcher_t` (see the `Drop` impl); the shutdown
/// observer registered at creation time fires once shutdown completes.
#[derive(Debug)]
pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);
159
// SAFETY: NOTE(review) — these impls assert that the raw `fdf_dispatcher_t`
// handle may be moved to and shared between threads. That relies on the
// driver runtime's thread-safety guarantees for dispatcher handles; confirm
// against the fdf C API documentation.
unsafe impl Send for Dispatcher {}
unsafe impl Sync for Dispatcher {}
// Thread-local override consulted by `CurrentDispatcher::on_dispatcher`
// before falling back to `fdf_dispatcher_get_current_dispatcher`, allowing
// callers (e.g. tests) to redirect "current dispatcher" lookups on a
// per-thread basis.
thread_local! {
    pub(crate) static OVERRIDE_DISPATCHER: RefCell<Option<NonNull<fdf_dispatcher_t>>> = const { RefCell::new(None) };
}
166
impl Dispatcher {
    /// Takes ownership of a raw dispatcher handle.
    ///
    /// # Safety
    ///
    /// `handle` must point to a valid, live `fdf_dispatcher_t`, and the
    /// caller must not retain other owners: the returned `Dispatcher` will
    /// initiate shutdown of the handle when dropped.
    pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
        Self(handle)
    }

    // Reads the raw option bits the dispatcher was created with.
    fn get_raw_flags(&self) -> u32 {
        // SAFETY: `self.0` is a valid dispatcher handle for the lifetime of
        // `self`.
        unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
    }

    /// Returns whether this dispatcher was created with the unsynchronized
    /// option.
    pub fn is_unsynchronized(&self) -> bool {
        (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
    }

    /// Returns whether this dispatcher was created with the
    /// allow-thread-blocking option.
    pub fn allows_thread_blocking(&self) -> bool {
        (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
    }

    /// Returns whether the calling thread is currently running on this
    /// dispatcher, by comparing raw handles with the runtime's notion of
    /// the current dispatcher.
    pub fn is_current_dispatcher(&self) -> bool {
        self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
    }

    /// Converts this owned dispatcher into a `'static` [`DispatcherRef`],
    /// giving up the automatic shutdown-on-drop behavior (the inner value is
    /// wrapped in `ManuallyDrop`, so `Dispatcher::drop` never runs).
    pub fn release(self) -> DispatcherRef<'static> {
        DispatcherRef(ManuallyDrop::new(self), PhantomData)
    }

    /// Borrows this dispatcher as a [`DispatcherRef`] tied to `&self`'s
    /// lifetime. The ref shares the raw handle but never shuts it down.
    pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
        DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
    }
}
215
impl AsyncDispatcher for Dispatcher {
    /// Returns the async-dispatcher view of this driver dispatcher so it can
    /// be used with the `libasync` machinery.
    ///
    /// Panics if the runtime reports no async dispatcher for this handle.
    fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
        let async_dispatcher =
            NonNull::new(unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) })
                .expect("No async dispatcher on driver dispatcher");
        // SAFETY: the returned reference borrows `self`, and the async
        // dispatcher is presumably owned by (and lives as long as) the
        // driver dispatcher — confirm against the fdf C API contract.
        unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
    }
}
224
impl Drop for Dispatcher {
    fn drop(&mut self) {
        // Begin an asynchronous shutdown. Shutdown is NOT complete when
        // `drop` returns; the shutdown observer registered at creation is
        // invoked once it finishes (the test helpers below rely on this to
        // synchronize teardown).
        // SAFETY: `self.0` is a valid dispatcher handle owned by `self`.
        unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
    }
}
232
/// A dispatcher wrapper whose `Drop` blocks until every outstanding
/// [`WeakDispatcher`] user has released its temporary access to the raw
/// pointer.
///
/// The raw dispatcher pointer lives in an `AtomicPtr` shared (via `Arc`)
/// with the [`WeakDispatcher`] handles produced by
/// [`AutoReleaseDispatcher::downgrade`].
#[derive(Debug)]
pub struct AutoReleaseDispatcher(Arc<AtomicPtr<fdf_dispatcher>>);
244
245impl AutoReleaseDispatcher {
246 pub fn downgrade(&self) -> WeakDispatcher {
250 WeakDispatcher::from(self)
251 }
252}
253
impl From<Dispatcher> for AutoReleaseDispatcher {
    fn from(dispatcher: Dispatcher) -> Self {
        // `release()` wraps the dispatcher in `ManuallyDrop`, so
        // `Dispatcher::drop` (which would start an async shutdown) never
        // runs; ownership of the raw pointer moves into the shared
        // `AtomicPtr`.
        // NOTE(review): neither this conversion nor
        // `AutoReleaseDispatcher::drop` shuts the dispatcher down — confirm
        // that shutdown is managed elsewhere (e.g. by the driver host).
        let dispatcher_ptr = dispatcher.release().0.0.as_ptr();
        Self(Arc::new(AtomicPtr::new(dispatcher_ptr)))
    }
}
260
impl Drop for AutoReleaseDispatcher {
    fn drop(&mut self) {
        // Null the pointer first so any `WeakDispatcher` that upgrades from
        // now on observes "no dispatcher" instead of a stale handle.
        self.0.store(null_mut(), Ordering::Relaxed);
        // Spin until every temporary strong reference (taken by
        // `WeakDispatcher::on_dispatcher` while it uses the raw pointer) has
        // been dropped, so no caller can still be touching the pointer when
        // we return. NOTE(review): this busy-waits in 100ns sleeps and
        // assumes upgraded references are held only briefly; the `Relaxed`
        // store is presumably adequate because `Arc`'s reference counting
        // supplies the needed synchronization — confirm.
        while Arc::strong_count(&self.0) > 1 {
            std::thread::sleep(std::time::Duration::from_nanos(100))
        }
    }
}
276
/// A non-owning handle to an [`AutoReleaseDispatcher`]'s dispatcher.
///
/// Resolving (via [`OnDispatcher::on_dispatcher`]) succeeds only while the
/// owning [`AutoReleaseDispatcher`] is alive and not yet mid-drop (its
/// `Drop` nulls the shared pointer before waiting out users).
#[derive(Clone, Debug)]
pub struct WeakDispatcher(Weak<AtomicPtr<fdf_dispatcher>>);
286
287impl From<&AutoReleaseDispatcher> for WeakDispatcher {
288 fn from(value: &AutoReleaseDispatcher) -> Self {
289 Self(Arc::downgrade(&value.0))
290 }
291}
292
impl OnDispatcher for WeakDispatcher {
    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
        // Owner already fully dropped: report no dispatcher.
        let Some(dispatcher_ptr) = self.0.upgrade() else {
            return f(None);
        };
        // A null pointer means the owner is mid-drop and has revoked access.
        let Some(dispatcher) = NonNull::new(dispatcher_ptr.load(Ordering::Relaxed)) else {
            return f(None);
        };
        // SAFETY: holding the upgraded `Arc` (`dispatcher_ptr`) keeps
        // `AutoReleaseDispatcher::drop` spinning on the strong count until
        // this statement (and thus `f`) completes, so the raw dispatcher
        // pointer remains valid while it is in use.
        f(Some(unsafe { DispatcherRef::from_raw(dispatcher) }.as_async_dispatcher_ref()))
    }
}
307
/// A borrowed reference to a [`Dispatcher`], valid for lifetime `'a`.
///
/// The inner [`Dispatcher`] is wrapped in `ManuallyDrop`, so dropping a
/// `DispatcherRef` never triggers dispatcher shutdown.
#[derive(Debug)]
pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
313
impl<'a> DispatcherRef<'a> {
    /// Constructs a `DispatcherRef` from a raw dispatcher handle.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `handle` points to a valid
    /// `fdf_dispatcher_t` that remains live for the chosen lifetime `'a`.
    pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
        Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
    }

    /// Recovers the driver dispatcher backing the given async dispatcher
    /// reference.
    ///
    /// Panics (via `unwrap`) if the downcast returns null, i.e. the async
    /// dispatcher is not backed by a driver dispatcher.
    pub fn from_async_dispatcher(dispatcher: AsyncDispatcherRef<'a>) -> Self {
        let handle = NonNull::new(unsafe {
            fdf_dispatcher_downcast_async_dispatcher(dispatcher.inner().as_ptr())
        })
        .unwrap();
        // SAFETY: `handle` was just obtained from an async dispatcher
        // borrowed for `'a`, so it is valid for at least that lifetime.
        unsafe { Self::from_raw(handle) }
    }
}
340
341impl<'a> AsyncDispatcher for DispatcherRef<'a> {
342 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
343 self.0.as_async_dispatcher_ref()
344 }
345}
346
347impl<'a> Clone for DispatcherRef<'a> {
348 fn clone(&self) -> Self {
349 Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
350 }
351}
352
353impl<'a> core::ops::Deref for DispatcherRef<'a> {
354 type Target = Dispatcher;
355 fn deref(&self) -> &Self::Target {
356 &self.0
357 }
358}
359
360impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
361 fn deref_mut(&mut self) -> &mut Self::Target {
362 &mut self.0
363 }
364}
365
366impl<'a> OnDispatcher for DispatcherRef<'a> {
367 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
368 f(Some(self.as_async_dispatcher_ref()))
369 }
370}
371
/// Zero-sized [`OnDispatcher`] implementation that targets whichever
/// dispatcher is "current": the thread-local `OVERRIDE_DISPATCHER` if set,
/// otherwise the dispatcher the calling thread is running on.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct CurrentDispatcher;
376
impl OnDispatcher for CurrentDispatcher {
    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
        // Prefer the thread-local override, then fall back to the runtime's
        // notion of the current dispatcher. `f` receives `None` when neither
        // is available (e.g. when called from a non-dispatcher thread).
        let dispatcher = OVERRIDE_DISPATCHER
            .with(|global| *global.borrow())
            .or_else(|| {
                NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
            })
            .map(|dispatcher| {
                let async_dispatcher = NonNull::new(unsafe {
                    fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())
                })
                .expect("No async dispatcher on driver dispatcher");
                // SAFETY: NOTE(review) — assumes the current dispatcher (and
                // its async view) stays valid for the duration of `f`;
                // confirm against the driver runtime's lifetime guarantees.
                unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
            });
        f(dispatcher)
    }
}
400
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, Once, Weak, mpsc};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};
    use zx::sys::ZX_OK;

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    // Ensures the process-wide driver environment is started at most once.
    static GLOBAL_DRIVER_ENV: Once = Once::new();

    /// Starts the fdf test environment exactly once per process.
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            unsafe {
                assert_eq!(fdf_env_start(0), ZX_OK);
            }
        });
    }

    /// Runs `p` with a weak handle to a freshly created, blocking-capable
    /// dispatcher, then tears the dispatcher down and waits for shutdown to
    /// complete before returning `p`'s result.
    pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T) -> T {
        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, p)
    }

    /// Like [`with_raw_dispatcher`], but with caller-supplied dispatcher
    /// option flags. Asserts after `p` returns that no strong reference to
    /// the dispatcher escaped the test body.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T,
    ) -> T {
        ensure_driver_env();

        // The shutdown observer signals this channel once the dispatcher has
        // fully shut down, and verifies no tasks were left queued.
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut dispatcher = null_mut();
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0.0.as_ptr()) });
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        // The observer pointer doubles as the "owner" token for the test
        // dispatcher. NOTE(review): this is the address of the local
        // `observer` variable (a stack slot), which is only valid for the
        // duration of this call — confirm the env API only uses it as an
        // opaque identity token.
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        let res = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                "".as_ptr() as *const c_char,
                0_usize,
                observer,
                &mut dispatcher,
            )
        };
        assert_eq!(res, ZX_OK);
        let dispatcher = Arc::new(Dispatcher(NonNull::new(dispatcher).unwrap()));

        // The test body only gets a weak handle so we can verify below that
        // it did not stash a strong reference anywhere.
        let res = p(Arc::downgrade(&dispatcher));

        // Dropping the last strong reference triggers `Dispatcher::drop`,
        // which starts the async shutdown; the observer then signals
        // `shutdown_rx` when shutdown completes.
        let weak_dispatcher = Arc::downgrade(&dispatcher);
        drop(dispatcher);
        shutdown_rx.recv().unwrap();
        assert_eq!(
            0,
            weak_dispatcher.strong_count(),
            "a dispatcher reference escaped the test body"
        );

        res
    }

    // Smoke test: dispatcher creation and teardown alone must succeed.
    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    // A task posted to the dispatcher runs and observes ZX_OK status.
    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }

    // A dispatcher created from inside a task on another dispatcher works,
    // and its shutdown observer fires after the outer test completes.
    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    let inner = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        })
                        .create()
                        .unwrap();
                    inner
                        .post_task_sync(move |status| {
                            assert_eq!(status, Status::from_raw(ZX_OK));
                            tx.send(status).unwrap();
                        })
                        .unwrap();
                    // Hand the inner dispatcher back out so it is dropped
                    // (and thus shut down) by the test body, not inside this
                    // task.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    // Sends an initial value, then echoes each received value incremented.
    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        while let Some(next) = rx.next().await {
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    // Echoes incremented values back until one exceeds 10, then signals
    // completion on `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    // Two futures spawned on the same dispatcher can exchange messages.
    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();
            dispatcher.spawn(pong(fin_tx, pong_tx, pong_rx)).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    // Like `pong`, but sleeps on a fuchsia-async timer between messages to
    // exercise interleaving with another executor.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            fuchsia_async::Timer::new(fuchsia_async::MonotonicInstant::after(
                MonotonicDuration::from_seconds(1),
            ))
            .await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    // One side runs on the driver dispatcher, the other on a fuchsia-async
    // LocalExecutor, verifying messages flow across the two executors.
    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();

            let mut executor = fuchsia_async::LocalExecutor::default();
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}