1use fdf_sys::*;
8
9use core::cell::RefCell;
10use core::ffi;
11use core::marker::PhantomData;
12use core::mem::ManuallyDrop;
13use core::ptr::{NonNull, null_mut};
14use std::sync::atomic::{AtomicPtr, Ordering};
15use std::sync::{Arc, Weak};
16
17use zx::Status;
18
19use crate::shutdown_observer::ShutdownObserver;
20
21pub use fdf_sys::fdf_dispatcher_t;
22pub use libasync::{
23 AfterDeadline, AsyncDispatcher, AsyncDispatcherRef, JoinHandle, OnDispatcher, Task,
24};
25
/// Trait alias for callbacks that may be registered to run when a dispatcher
/// finishes shutting down; blanket-implemented for every matching `FnOnce`.
pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + 'static {}
impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + 'static {}
29
/// Builder for creating a new driver [`Dispatcher`].
#[derive(Default)]
pub struct DispatcherBuilder {
    // Raw option bits (unsynchronized / allow-sync-calls / no-thread-migration)
    // passed to the dispatcher creation call.
    #[doc(hidden)]
    pub options: u32,
    // Debug name given to the dispatcher at creation time.
    #[doc(hidden)]
    pub name: String,
    // Scheduler role hint for the dispatcher's threads; may be empty.
    #[doc(hidden)]
    pub scheduler_role: String,
    // Callback invoked once dispatcher shutdown completes.
    #[doc(hidden)]
    pub shutdown_observer: Option<Box<dyn ShutdownObserverFn>>,
}
42
43impl DispatcherBuilder {
44 pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
46 pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;
48 pub(crate) const NO_THREAD_MIGRATION: u32 = fdf_sys::FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION;
50
51 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn unsynchronized(mut self) -> Self {
64 assert!(
65 !self.allows_thread_blocking(),
66 "you may not create an unsynchronized dispatcher that allows synchronous calls"
67 );
68 self.options |= Self::UNSYNCHRONIZED;
69 self
70 }
71
72 pub fn is_unsynchronized(&self) -> bool {
74 (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
75 }
76
77 pub fn allow_thread_blocking(mut self) -> Self {
83 assert!(
84 !self.is_unsynchronized(),
85 "you may not create an unsynchronized dispatcher that allows synchronous calls"
86 );
87 self.options |= Self::ALLOW_THREAD_BLOCKING;
88 self
89 }
90
91 pub fn allows_thread_blocking(&self) -> bool {
93 (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
94 }
95
96 pub fn no_thread_migration(mut self) -> Self {
103 self.options |= Self::NO_THREAD_MIGRATION;
104 self
105 }
106
107 pub fn allows_thread_migration(&self) -> bool {
109 (self.options & Self::NO_THREAD_MIGRATION) == 0
110 }
111
112 pub fn name(mut self, name: &str) -> Self {
115 self.name = name.to_string();
116 self
117 }
118
119 pub fn scheduler_role(mut self, role: &str) -> Self {
123 self.scheduler_role = role.to_string();
124 self
125 }
126
127 pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
129 self.shutdown_observer = Some(Box::new(shutdown_observer));
130 self
131 }
132
133 pub fn create(self) -> Result<Dispatcher, Status> {
138 let mut out_dispatcher = null_mut();
139 let options = self.options;
140 let name = self.name.as_ptr() as *mut ffi::c_char;
141 let name_len = self.name.len();
142 let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
143 let scheduler_role_len = self.scheduler_role.len();
144 let observer =
145 ShutdownObserver::new(self.shutdown_observer.unwrap_or_else(|| Box::new(|_| {})))
146 .into_ptr();
147 Status::ok(unsafe {
151 fdf_dispatcher_create(
152 options,
153 name,
154 name_len,
155 scheduler_role,
156 scheduler_role_len,
157 observer,
158 &mut out_dispatcher,
159 )
160 })?;
161 Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
164 }
165
166 pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
170 self.create().map(Dispatcher::release)
171 }
172}
173
/// An owned handle to a driver runtime dispatcher; asynchronous shutdown is
/// initiated when this is dropped.
#[derive(Debug)]
pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);
177
// SAFETY: NOTE(review): these assume the raw dispatcher handle may be used and
// shut down from any thread — confirm against the fdf dispatcher contract.
unsafe impl Send for Dispatcher {}
unsafe impl Sync for Dispatcher {}
thread_local! {
    // Per-thread override consulted by `CurrentDispatcher` before falling back
    // to `fdf_dispatcher_get_current_dispatcher`.
    pub(crate) static OVERRIDE_DISPATCHER: RefCell<Option<NonNull<fdf_dispatcher_t>>> = const { RefCell::new(None) };
}
184
185impl Dispatcher {
186 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
194 Self(handle)
195 }
196
197 fn get_raw_flags(&self) -> u32 {
198 unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
200 }
201
202 pub fn is_unsynchronized(&self) -> bool {
204 (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
205 }
206
207 pub fn allows_thread_blocking(&self) -> bool {
209 (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
210 }
211
212 pub fn allows_thread_migration(&self) -> bool {
215 (self.get_raw_flags() & DispatcherBuilder::NO_THREAD_MIGRATION) == 0
216 }
217
218 pub fn is_current_dispatcher(&self) -> bool {
220 self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
223 }
224
225 pub fn release(self) -> DispatcherRef<'static> {
230 DispatcherRef(ManuallyDrop::new(self), PhantomData)
231 }
232
233 pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
236 DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
237 }
238}
239
240impl AsyncDispatcher for Dispatcher {
241 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
242 let async_dispatcher =
243 NonNull::new(unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) })
244 .expect("No async dispatcher on driver dispatcher");
245 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
246 }
247}
248
impl Drop for Dispatcher {
    fn drop(&mut self) {
        // Begin asynchronous shutdown; the shutdown observer registered at
        // creation runs once shutdown completes.
        // NOTE(review): there is no explicit destroy call here — presumably
        // final destruction happens in the observer/runtime. TODO confirm.
        unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
    }
}
256
/// A dispatcher wrapper whose `Drop` waits until all outstanding strong
/// references produced by [`WeakDispatcher`] upgrades have been released.
#[derive(Debug)]
pub struct AutoReleaseDispatcher(Arc<AtomicPtr<fdf_dispatcher>>);
268
269impl AutoReleaseDispatcher {
270 pub fn downgrade(&self) -> WeakDispatcher {
274 WeakDispatcher::from(self)
275 }
276}
277
278impl AsyncDispatcher for AutoReleaseDispatcher {
279 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
280 let dispatcher = NonNull::new(self.0.load(Ordering::Relaxed))
281 .expect("tried to obtain async dispatcher after drop");
282 unsafe {
285 AsyncDispatcherRef::from_raw(
286 NonNull::new(fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())).unwrap(),
287 )
288 }
289 }
290}
291
292impl From<Dispatcher> for AutoReleaseDispatcher {
293 fn from(dispatcher: Dispatcher) -> Self {
294 let dispatcher_ptr = dispatcher.release().0.0.as_ptr();
295 Self(Arc::new(AtomicPtr::new(dispatcher_ptr)))
296 }
297}
298
impl Drop for AutoReleaseDispatcher {
    fn drop(&mut self) {
        // Null out the shared cell so `WeakDispatcher` holders observe the
        // dispatcher as gone before the drop completes.
        self.0.store(null_mut(), Ordering::Relaxed);
        // Busy-wait until every transient strong `Arc` (created by
        // `WeakDispatcher::on_dispatcher` upgrades) has been released.
        // NOTE(review): the raw dispatcher pointer is discarded here without
        // an explicit shutdown/destroy — presumably intentional because the
        // handle was `release`d in `From<Dispatcher>`, but confirm the
        // dispatcher's teardown is handled elsewhere.
        while Arc::strong_count(&self.0) > 1 {
            std::thread::sleep(std::time::Duration::from_nanos(100))
        }
    }
}
314
/// A weak handle to an [`AutoReleaseDispatcher`]; dispatch attempts observe
/// `None` once the dispatcher has been dropped.
#[derive(Clone, Debug)]
pub struct WeakDispatcher(Weak<AtomicPtr<fdf_dispatcher>>);
324
325impl From<&AutoReleaseDispatcher> for WeakDispatcher {
326 fn from(value: &AutoReleaseDispatcher) -> Self {
327 Self(Arc::downgrade(&value.0))
328 }
329}
330
331impl OnDispatcher for WeakDispatcher {
332 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
333 let Some(dispatcher_ptr) = self.0.upgrade() else {
334 return f(None);
335 };
336 let Some(dispatcher) = NonNull::new(dispatcher_ptr.load(Ordering::Relaxed)) else {
337 return f(None);
338 };
339 f(Some(unsafe { DispatcherRef::from_raw(dispatcher) }.as_async_dispatcher_ref()))
343 }
344}
345
346impl OnDriverDispatcher for WeakDispatcher {}
347
/// A borrowed, non-owning reference to a [`Dispatcher`]; dropping it does not
/// shut the dispatcher down (the inner value is wrapped in `ManuallyDrop`).
#[derive(Debug)]
pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
353
354impl<'a> DispatcherRef<'a> {
355 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
362 Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
364 }
365
366 pub fn from_async_dispatcher(dispatcher: AsyncDispatcherRef<'a>) -> Self {
373 let handle = NonNull::new(unsafe {
374 fdf_dispatcher_downcast_async_dispatcher(dispatcher.inner().as_ptr())
375 })
376 .unwrap();
377 unsafe { Self::from_raw(handle) }
378 }
379
380 pub unsafe fn as_raw(&mut self) -> *mut fdf_dispatcher_t {
386 unsafe { self.0.0.as_mut() }
387 }
388}
389
390struct AddSendFuture<T>(T);
398
impl<T: Future> Future for AddSendFuture<T> {
    type Output = T::Output;

    // Transparently polls the wrapped future.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: structural pin projection — the inner future is never moved
        // out of the wrapper, so pinning the wrapper pins the field.
        let fut = unsafe { self.map_unchecked_mut(|fut| &mut fut.0) };
        fut.poll(cx)
    }
}
411
412unsafe impl<T> Send for AddSendFuture<T> {}
416
/// Extension trait for driver dispatchers that can run non-`Send` futures
/// when the dispatcher pins tasks to one thread.
pub trait OnDriverDispatcher: OnDispatcher {
    /// Spawns a non-`Send` future on this dispatcher.
    ///
    /// Succeeds only when called from the dispatcher's own thread and the
    /// dispatcher disallows thread migration; otherwise returns
    /// `Status::BAD_STATE`. Those two checks are what make wrapping the
    /// future in `AddSendFuture` acceptable here.
    fn spawn_local(
        &self,
        future: impl Future<Output = ()> + 'static,
    ) -> Result<JoinHandle<()>, Status>
    where
        Self: 'static,
    {
        self.on_maybe_dispatcher(|dispatcher| {
            let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
            if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
                OnDispatcher::spawn(self, AddSendFuture(future))
            } else {
                Err(Status::BAD_STATE)
            }
        })
    }

    /// Like [`Self::spawn_local`] but returns a [`Task`] yielding the
    /// future's output.
    ///
    /// Fails with `Status::BAD_STATE` under the same conditions as
    /// [`Self::spawn_local`].
    fn compute_local<T: Send + 'static>(
        &self,
        future: impl Future<Output = T> + 'static,
    ) -> Result<Task<T>, Status>
    where
        Self: 'static,
    {
        self.on_maybe_dispatcher(|dispatcher| {
            let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
            if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
                Ok(OnDispatcher::compute(self, AddSendFuture(future)))
            } else {
                Err(Status::BAD_STATE)
            }
        })
    }
}
480
// Shared and weak dispatcher handles also get the local-spawn helpers.
impl OnDriverDispatcher for Arc<Dispatcher> {}
impl OnDriverDispatcher for Weak<Dispatcher> {}
483
484impl<'a> AsyncDispatcher for DispatcherRef<'a> {
485 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
486 self.0.as_async_dispatcher_ref()
487 }
488}
489
490impl<'a> Clone for DispatcherRef<'a> {
491 fn clone(&self) -> Self {
492 Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
493 }
494}
495
impl<'a> core::ops::Deref for DispatcherRef<'a> {
    type Target = Dispatcher;
    fn deref(&self) -> &Self::Target {
        // ManuallyDrop derefs transparently to the wrapped Dispatcher.
        &self.0
    }
}
502
impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Mutable access to the wrapped Dispatcher, via ManuallyDrop's Deref.
        &mut self.0
    }
}
508
509impl<'a> OnDispatcher for DispatcherRef<'a> {
510 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
511 f(Some(self.as_async_dispatcher_ref()))
512 }
513}
514
515impl<'a> OnDriverDispatcher for DispatcherRef<'a> {}
516
/// Zero-sized handle that targets whichever dispatcher is current on the
/// calling thread (or the thread-local override, when set).
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct CurrentDispatcher;
521
522impl OnDispatcher for CurrentDispatcher {
523 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
524 let dispatcher = OVERRIDE_DISPATCHER
525 .with(|global| *global.borrow())
526 .or_else(|| {
527 NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
529 })
530 .map(|dispatcher| {
531 let async_dispatcher = NonNull::new(unsafe {
537 fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())
538 })
539 .expect("No async dispatcher on driver dispatcher");
540 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
541 });
542 f(dispatcher)
543 }
544}
545
546impl OnDriverDispatcher for CurrentDispatcher {}
547
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, Once, Weak, mpsc};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};
    use zx::sys::ZX_OK;

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    // One-time guard for starting the global driver test environment.
    static GLOBAL_DRIVER_ENV: Once = Once::new();
    // Scheduler role registered below with the "no sync calls" option.
    const NO_SYNC_CALLS_ROLE: &str = "no sync calls role";

    /// Starts the driver environment exactly once per test process and
    /// registers the no-sync-calls scheduler role.
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            // SAFETY: one-time calls into the driver-env test support library.
            unsafe {
                assert_eq!(fdf_env_start(0), ZX_OK);
                assert_eq!(
                    fdf_env_set_scheduler_role_opts(
                        NO_SYNC_CALLS_ROLE.as_ptr() as *const c_char,
                        NO_SYNC_CALLS_ROLE.len(),
                        FDF_SCHEDULER_ROLE_OPTION_NO_SYNC_CALLS
                    ),
                    ZX_OK
                );
            }
        });
    }
    /// Runs `p` with a fresh dispatcher that allows blocking calls, then
    /// shuts the dispatcher down and waits for shutdown to complete.
    pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T) -> T {
        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, "", p)
    }

    /// Like [`with_raw_dispatcher`] but with explicit option `flags` and a
    /// `scheduler_role`. Asserts that no dispatcher reference escapes `p`.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        scheduler_role: &str,
        p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut dispatcher = null_mut();
        // The observer signals `shutdown_rx` once shutdown has completed.
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            // No tasks should remain queued by the time shutdown completes.
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0.0.as_ptr()) });
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        // SAFETY: the name/role strings outlive the call and `observer` was
        // just created above.
        let res = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                scheduler_role.as_ptr() as *const c_char,
                scheduler_role.len(),
                observer,
                &mut dispatcher,
            )
        };
        assert_eq!(res, ZX_OK);
        let dispatcher = Arc::new(Dispatcher(NonNull::new(dispatcher).unwrap()));

        let res = p(Arc::downgrade(&dispatcher));

        // Drop our strong reference to trigger shutdown, wait for the
        // observer, then verify no strong reference escaped the closure.
        let weak_dispatcher = Arc::downgrade(&dispatcher);
        drop(dispatcher);
        shutdown_rx.recv().unwrap();
        assert_eq!(
            0,
            weak_dispatcher.strong_count(),
            "a dispatcher reference escaped the test body"
        );

        res
    }

    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }

    // Creates a second dispatcher from within a task on the first and posts
    // to it; the outer assert checks the inner shutdown observer fired.
    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    let inner = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        })
                        .create()
                        .unwrap();
                    inner
                        .post_task_sync(move |status| {
                            assert_eq!(status, Status::from_raw(ZX_OK));
                            tx.send(status).unwrap();
                        })
                        .unwrap();
                    // Hand the inner dispatcher out so it is dropped (and
                    // shut down) outside this task.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    // `spawn_local`/`compute_local` must refuse dispatchers that allow
    // thread migration.
    #[test]
    fn spawn_local_fails_on_normal_dispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("spawn local failures", move |dispatcher| {
            let inside_dispatcher = dispatcher.clone();
            dispatcher
                .spawn(async move {
                    assert_eq!(
                        inside_dispatcher.spawn_local(futures::future::ready(())).unwrap_err(),
                        Status::BAD_STATE
                    );
                    assert_eq!(
                        inside_dispatcher.compute_local(futures::future::ready(())).unwrap_err(),
                        Status::BAD_STATE
                    );
                    shutdown_tx.send(()).unwrap();
                })
                .unwrap();
            shutdown_rx.recv().unwrap();
        });
    }

    // On a pinned (no-thread-migration) dispatcher, local spawns from the
    // dispatcher's own thread should succeed.
    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_succeeds_on_no_thread_migration_dispatcher() {
        let (tx, rx) = mpsc::channel();
        with_raw_dispatcher_flags(
            "spawn local success",
            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |dispatcher| {
                let inside_dispatcher = dispatcher.clone();
                dispatcher
                    .spawn(async move {
                        let tx_clone = tx.clone();
                        inside_dispatcher
                            .spawn_local(async move {
                                tx_clone.send(()).unwrap();
                            })
                            .unwrap();
                        inside_dispatcher
                            .compute_local(async move {
                                tx.send(()).unwrap();
                            })
                            .unwrap()
                            .await
                            .unwrap();
                    })
                    .unwrap();
                // One receive per local task spawned above.
                rx.recv().unwrap();
                rx.recv().unwrap();
            },
        );
    }

    // Even on a pinned dispatcher, local spawns must fail when initiated
    // from a thread other than the dispatcher's own.
    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_fails_on_no_thread_migration_dispatcher_from_different_thread() {
        with_raw_dispatcher_flags(
            "spawn local success",
            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |dispatcher| {
                assert_eq!(
                    dispatcher.spawn_local(futures::future::ready(())).unwrap_err(),
                    Status::BAD_STATE
                );
                assert_eq!(
                    dispatcher.compute_local(futures::future::ready(())).unwrap_err(),
                    Status::BAD_STATE
                );
            },
        );
    }

    // Sends an incrementing counter back and forth with `pong` below.
    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        while let Some(next) = rx.next().await {
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    // Counterpart to `ping`; stops after the counter exceeds 10 and signals
    // completion on `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();
            dispatcher.spawn(pong(fin_tx, pong_tx, pong_rx)).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    // Like `pong`, but sleeps between messages on a fuchsia-async timer to
    // exercise cross-executor interaction.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            fuchsia_async::Timer::new(fuchsia_async::MonotonicInstant::after(
                MonotonicDuration::from_seconds(1),
            ))
            .await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    // Runs `ping` on the driver dispatcher while `slow_pong` runs on a
    // fuchsia-async LocalExecutor on the test thread.
    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();

            let mut executor = fuchsia_async::LocalExecutor::default();
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}