1use fdf_sys::*;
8
9use core::cell::RefCell;
10use core::ffi;
11use core::marker::PhantomData;
12use core::mem::ManuallyDrop;
13use core::ptr::{NonNull, null_mut};
14
15use zx::Status;
16
17use crate::shutdown_observer::ShutdownObserver;
18
19pub use fdf_sys::fdf_dispatcher_t;
20pub use libasync::{
21 AfterDeadline, AsAsyncDispatcherRef, AsyncDispatcher, AsyncDispatcherRef, DispatcherTimerExt,
22 JoinHandle, OnDispatcher, Task,
23};
24
25pub trait ShutdownObserverFn: FnOnce(DriverDispatcherRef<'_>) + Send + 'static {}
27impl<T> ShutdownObserverFn for T where T: FnOnce(DriverDispatcherRef<'_>) + Send + 'static {}
28
29#[derive(Default)]
31pub struct DispatcherBuilder {
32 #[doc(hidden)]
33 pub options: u32,
34 #[doc(hidden)]
35 pub name: String,
36 #[doc(hidden)]
37 pub scheduler_role: String,
38 #[doc(hidden)]
39 pub shutdown_observer: Option<Box<dyn ShutdownObserverFn>>,
40}
41
42impl DispatcherBuilder {
43 pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
45 pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;
47 pub(crate) const NO_THREAD_MIGRATION: u32 = fdf_sys::FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION;
49
50 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn unsynchronized(mut self) -> Self {
63 assert!(
64 !self.allows_thread_blocking(),
65 "you may not create an unsynchronized dispatcher that allows synchronous calls"
66 );
67 self.options |= Self::UNSYNCHRONIZED;
68 self
69 }
70
71 pub fn is_unsynchronized(&self) -> bool {
73 (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
74 }
75
76 pub fn allow_thread_blocking(mut self) -> Self {
82 assert!(
83 !self.is_unsynchronized(),
84 "you may not create an unsynchronized dispatcher that allows synchronous calls"
85 );
86 self.options |= Self::ALLOW_THREAD_BLOCKING;
87 self
88 }
89
90 pub fn allows_thread_blocking(&self) -> bool {
92 (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
93 }
94
95 pub fn no_thread_migration(mut self) -> Self {
102 self.options |= Self::NO_THREAD_MIGRATION;
103 self
104 }
105
106 pub fn allows_thread_migration(&self) -> bool {
108 (self.options & Self::NO_THREAD_MIGRATION) == 0
109 }
110
111 pub fn name(mut self, name: &str) -> Self {
114 self.name = name.to_string();
115 self
116 }
117
118 pub fn scheduler_role(mut self, role: &str) -> Self {
122 self.scheduler_role = role.to_string();
123 self
124 }
125
126 pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
128 self.shutdown_observer = Some(Box::new(shutdown_observer));
129 self
130 }
131
132 pub fn create(self) -> Result<Dispatcher, Status> {
137 let mut out_dispatcher = null_mut();
138 let options = self.options;
139 let name = self.name.as_ptr() as *mut ffi::c_char;
140 let name_len = self.name.len();
141 let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
142 let scheduler_role_len = self.scheduler_role.len();
143 let observer =
144 ShutdownObserver::new(self.shutdown_observer.unwrap_or_else(|| Box::new(|_| {})))
145 .into_ptr();
146 Status::ok(unsafe {
150 fdf_dispatcher_create(
151 options,
152 name,
153 name_len,
154 scheduler_role,
155 scheduler_role_len,
156 observer,
157 &mut out_dispatcher,
158 )
159 })?;
160 Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
163 }
164
165 pub fn create_released(self) -> Result<AutoReleaseDispatcher, Status> {
169 self.create().map(Dispatcher::release)
170 }
171}
172
173#[derive(Debug)]
175pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);
176
177unsafe impl Send for Dispatcher {}
179unsafe impl Sync for Dispatcher {}
180thread_local! {
181 pub(crate) static OVERRIDE_DISPATCHER: RefCell<Option<NonNull<fdf_dispatcher_t>>> = const { RefCell::new(None) };
182}
183
184impl Dispatcher {
185 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
193 Self(handle)
194 }
195
196 fn get_raw_flags(&self) -> u32 {
197 unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
199 }
200
201 pub fn is_unsynchronized(&self) -> bool {
203 (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
204 }
205
206 pub fn allows_thread_blocking(&self) -> bool {
208 (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
209 }
210
211 pub fn allows_thread_migration(&self) -> bool {
214 (self.get_raw_flags() & DispatcherBuilder::NO_THREAD_MIGRATION) == 0
215 }
216
217 pub fn is_current_dispatcher(&self) -> bool {
219 self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
222 }
223
224 pub fn release(self) -> AutoReleaseDispatcher {
229 AutoReleaseDispatcher { dispatcher: ManuallyDrop::new(self) }
230 }
231
232 pub fn as_dispatcher_ref(&self) -> DriverDispatcherRef<'_> {
235 DriverDispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
236 }
237}
238
239impl AsAsyncDispatcherRef for Dispatcher {
240 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
241 let async_dispatcher =
242 NonNull::new(unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) })
243 .expect("No async dispatcher on driver dispatcher");
244 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
245 }
246}
247
248impl Drop for Dispatcher {
249 fn drop(&mut self) {
250 unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
253 }
254}
255
256#[derive(Debug)]
266pub struct AutoReleaseDispatcher {
267 dispatcher: ManuallyDrop<Dispatcher>,
268}
269
270impl AutoReleaseDispatcher {
271 pub unsafe fn from_raw(dispatcher: NonNull<fdf_dispatcher_t>) -> Self {
279 let dispatcher = ManuallyDrop::new(Dispatcher(dispatcher));
280 Self { dispatcher }
281 }
282
283 pub fn as_async_dispatcher(&self) -> AsyncDispatcher {
287 AsyncDispatcher::new(self)
288 }
289
290 pub fn as_dispatcher_ref(&self) -> DriverDispatcherRef<'_> {
293 DriverDispatcherRef(ManuallyDrop::new(Dispatcher(self.dispatcher.0)), PhantomData)
294 }
295
296 pub fn always_on_dispatcher(&self) -> AutoReleaseDispatcher {
298 let dispatcher_ref = unsafe { DriverDispatcherRef::from_raw(self.dispatcher.0) };
301 let dispatcher = unsafe { Dispatcher::from_raw(dispatcher_ref.always_on_dispatcher().0.0) };
306 Self { dispatcher: ManuallyDrop::new(dispatcher) }
307 }
308}
309
310impl AsAsyncDispatcherRef for AutoReleaseDispatcher {
311 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
312 self.dispatcher.as_async_dispatcher_ref()
313 }
314}
315
316impl From<Dispatcher> for AutoReleaseDispatcher {
317 fn from(dispatcher: Dispatcher) -> Self {
318 Self { dispatcher: ManuallyDrop::new(dispatcher) }
319 }
320}
321
322#[derive(Debug)]
326pub struct DriverDispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
327
328impl<'a> DriverDispatcherRef<'a> {
329 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
336 Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
338 }
339
340 pub fn from_async_dispatcher(dispatcher: AsyncDispatcherRef<'a>) -> Self {
347 let handle = NonNull::new(unsafe {
348 fdf_dispatcher_downcast_async_dispatcher(dispatcher.inner().as_ptr())
349 })
350 .unwrap();
351 unsafe { Self::from_raw(handle) }
352 }
353
354 pub unsafe fn as_raw(&mut self) -> *mut fdf_dispatcher_t {
360 unsafe { self.0.0.as_mut() }
361 }
362
363 pub fn always_on_dispatcher(&self) -> DriverDispatcherRef<'a> {
366 let ptr = unsafe { fdf_dispatcher_get_always_on_dispatcher(self.0.0.as_ptr()) };
368 DriverDispatcherRef(
369 ManuallyDrop::new(Dispatcher(NonNull::new(ptr).expect("Always-on dispatcher is NULL"))),
370 PhantomData,
371 )
372 }
373}
374
/// Wrapper that unconditionally marks a future as [`Send`].
///
/// In this file it is only constructed by [`OnDriverDispatcher::spawn_local`] and
/// [`OnDriverDispatcher::compute_local`]. Both first check that the caller is running on the
/// target dispatcher and that the dispatcher does not allow thread migration. The premise is
/// that the wrapped future is then only ever polled on the thread that created it.
struct AddSendFuture<T>(T);

impl<T: Future> Future for AddSendFuture<T> {
    type Output = T::Output;

    /// Forwards the poll to the wrapped future.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: structural pin projection. The wrapped future is never moved out of the
        // wrapper, and the wrapper has no `Drop` impl or manual `Unpin` impl.
        let inner: std::pin::Pin<&mut T> =
            unsafe { self.map_unchecked_mut(|wrapper| &mut wrapper.0) };
        Future::poll(inner, cx)
    }
}

// SAFETY: upheld by the constructors described on the type. The value is only created once the
// future is confined to the current, non-migrating dispatcher thread.
unsafe impl<T> Send for AddSendFuture<T> {}
401
402pub trait OnDriverDispatcher: OnDispatcher {
405 fn spawn_local(
418 &self,
419 future: impl Future<Output = ()> + 'static,
420 ) -> Result<JoinHandle<()>, Status>
421 where
422 Self: 'static,
423 {
424 self.on_maybe_dispatcher(|dispatcher| {
425 let dispatcher = DriverDispatcherRef::from_async_dispatcher(dispatcher);
426 if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
427 Ok(OnDispatcher::spawn(self, AddSendFuture(future)))
428 } else {
429 Err(Status::BAD_STATE)
430 }
431 })
432 }
433
434 fn compute_local<T: Send + 'static>(
449 &self,
450 future: impl Future<Output = T> + 'static,
451 ) -> Result<Task<T>, Status>
452 where
453 Self: 'static,
454 {
455 self.on_maybe_dispatcher(|dispatcher| {
456 let dispatcher = DriverDispatcherRef::from_async_dispatcher(dispatcher);
457 if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
458 Ok(OnDispatcher::compute(self, AddSendFuture(future)))
459 } else {
460 Err(Status::BAD_STATE)
461 }
462 })
463 }
464}
465
466impl<'a> AsAsyncDispatcherRef for DriverDispatcherRef<'a> {
467 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
468 self.0.as_async_dispatcher_ref()
469 }
470}
471
472impl<'a> Clone for DriverDispatcherRef<'a> {
473 fn clone(&self) -> Self {
474 Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
475 }
476}
477
478impl<'a> core::ops::Deref for DriverDispatcherRef<'a> {
479 type Target = Dispatcher;
480 fn deref(&self) -> &Self::Target {
481 &self.0
482 }
483}
484
485impl<'a> core::ops::DerefMut for DriverDispatcherRef<'a> {
486 fn deref_mut(&mut self) -> &mut Self::Target {
487 &mut self.0
488 }
489}
490
491impl<T> OnDriverDispatcher for T where T: AsAsyncDispatcherRef + Clone {}
494
495#[derive(Clone, Copy, Debug, Default, PartialEq)]
498pub struct CurrentDispatcher;
499
500impl OnDispatcher for CurrentDispatcher {
501 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
502 let dispatcher = OVERRIDE_DISPATCHER
503 .with(|global| *global.borrow())
504 .or_else(|| {
505 NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
507 })
508 .map(|dispatcher| {
509 let async_dispatcher = NonNull::new(unsafe {
515 fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())
516 })
517 .expect("No async dispatcher on driver dispatcher");
518 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
519 });
520 f(dispatcher)
521 }
522}
523
524impl OnDriverDispatcher for CurrentDispatcher {}
525
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Once, mpsc};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};
    use zx::sys::ZX_OK;

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    /// Guards the one-time driver environment setup shared by every test.
    static GLOBAL_DRIVER_ENV: Once = Once::new();
    /// Scheduler role registered with `FDF_SCHEDULER_ROLE_OPTION_NO_SYNC_CALLS` in
    /// [`ensure_driver_env`]. The no-thread-migration tests use it.
    const NO_SYNC_CALLS_ROLE: &str = "no sync calls role";

    /// Starts the driver runtime environment once per process and registers
    /// [`NO_SYNC_CALLS_ROLE`].
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            // SAFETY: runs once (guarded by `GLOBAL_DRIVER_ENV`). The role pointer/length pair
            // comes from a `'static` string.
            unsafe {
                assert_eq!(fdf_env_start(0), ZX_OK);
                assert_eq!(
                    fdf_env_set_scheduler_role_opts(
                        NO_SYNC_CALLS_ROLE.as_ptr() as *const c_char,
                        NO_SYNC_CALLS_ROLE.len(),
                        FDF_SCHEDULER_ROLE_OPTION_NO_SYNC_CALLS
                    ),
                    ZX_OK
                );
            }
        });
    }

    /// Runs `p` with a fresh dispatcher that allows thread blocking and has no scheduler role.
    /// The dispatcher is shut down, and its shutdown observer awaited, before `p`'s result is
    /// returned.
    pub fn with_raw_dispatcher<T>(name: &str, p: impl FnOnce(AsyncDispatcher) -> T) -> T {
        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, "", p)
    }

    /// Like [`with_raw_dispatcher`], but with explicit option `flags` and `scheduler_role`.
    ///
    /// The dispatcher is created through `fdf_env_dispatcher_create_with_owner`. Its shutdown
    /// observer asserts that no tasks remain queued before signalling completion.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        scheduler_role: &str,
        p: impl FnOnce(AsyncDispatcher) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut dispatcher = null_mut();
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            // SAFETY: the observer receives the dispatcher being shut down.
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0.0.as_ptr()) });
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        // The address of the local `observer` slot serves as the dispatcher's owner token.
        // Presumably it is used only as an opaque identity; it stays valid until shutdown
        // completes below.
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        // SAFETY: the name and role pointer/length pairs are valid for the call, `observer` came
        // from `into_ptr`, and `dispatcher` is a valid out-pointer.
        let res = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                scheduler_role.as_ptr() as *const c_char,
                scheduler_role.len(),
                observer,
                &mut dispatcher,
            )
        };
        assert_eq!(res, ZX_OK);
        let dispatcher = Dispatcher(NonNull::new(dispatcher).unwrap());

        let res = p(AsyncDispatcher::new(&dispatcher));

        // Dropping the owned `Dispatcher` requests shutdown. Wait for the observer to fire.
        drop(dispatcher);
        shutdown_rx.recv().unwrap();

        res
    }

    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }

    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    let inner = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        })
                        .create()
                        .unwrap();
                    inner
                        .post_task_sync(move |status| {
                            assert_eq!(status, Status::from_raw(ZX_OK));
                            tx.send(status).unwrap();
                        })
                        .unwrap();
                    // Hand the inner dispatcher to the test thread, where it is dropped (and
                    // therefore shut down) after being received below.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    #[test]
    fn spawn_local_fails_on_normal_dispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("spawn local failures", move |dispatcher| {
            let inside_dispatcher = dispatcher.clone();
            // `with_raw_dispatcher` does not set NO_THREAD_MIGRATION, so both helpers must refuse
            // even when called from the dispatcher itself.
            dispatcher.spawn(async move {
                assert_eq!(
                    inside_dispatcher.spawn_local(futures::future::ready(())).unwrap_err(),
                    Status::BAD_STATE
                );
                assert_eq!(
                    inside_dispatcher.compute_local(futures::future::ready(())).unwrap_err(),
                    Status::BAD_STATE
                );
                shutdown_tx.send(()).unwrap();
            });
            shutdown_rx.recv().unwrap();
        });
    }

    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_succeeds_on_no_thread_migration_dispatcher() {
        let (tx, rx) = mpsc::channel();
        with_raw_dispatcher_flags(
            "spawn local success",
            DispatcherBuilder::NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |dispatcher| {
                let inside_dispatcher = dispatcher.clone();
                dispatcher.spawn(async move {
                    let tx_clone = tx.clone();
                    inside_dispatcher
                        .spawn_local(async move {
                            tx_clone.send(()).unwrap();
                        })
                        .unwrap();
                    inside_dispatcher
                        .compute_local(async move {
                            tx.send(()).unwrap();
                        })
                        .unwrap()
                        .await
                        .unwrap();
                });
                // One message from the `spawn_local` future, one from the `compute_local` future.
                rx.recv().unwrap();
                rx.recv().unwrap();
            },
        );
    }

    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_fails_on_no_thread_migration_dispatcher_from_different_thread() {
        with_raw_dispatcher_flags(
            "spawn local failure off dispatcher",
            DispatcherBuilder::NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |dispatcher| {
                // This closure runs on the test thread, not on the dispatcher, so both helpers
                // must refuse.
                assert_eq!(
                    dispatcher.spawn_local(futures::future::ready(())).unwrap_err(),
                    Status::BAD_STATE
                );
                assert_eq!(
                    dispatcher.compute_local(futures::future::ready(())).unwrap_err(),
                    Status::BAD_STATE
                );
            },
        );
    }

    /// Sends 0, then echoes every received value plus one until the channel closes.
    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        while let Some(next) = rx.next().await {
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    /// Echoes received values plus one until one exceeds 10, then signals `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            dispatcher.spawn(ping(ping_tx, ping_rx));
            dispatcher.spawn(pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }

    /// Like [`pong`], but waits one second on a `fuchsia_async` timer before handling each value.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            fuchsia_async::Timer::new(fuchsia_async::MonotonicInstant::after(
                MonotonicDuration::from_seconds(1),
            ))
            .await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            // `ping` runs on the driver dispatcher...
            dispatcher.spawn(ping(ping_tx, ping_rx));

            // ...while `slow_pong` runs on a `fuchsia_async` executor on the test thread.
            let mut executor = fuchsia_async::LocalExecutor::default();
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}