Skip to main content

fdf_core/
dispatcher.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Safe bindings for the driver runtime dispatcher stable ABI
6
7use fdf_sys::*;
8
9use core::cell::RefCell;
10use core::ffi;
11use core::marker::PhantomData;
12use core::mem::ManuallyDrop;
13use core::ptr::{NonNull, null_mut};
14use std::sync::atomic::{AtomicPtr, Ordering};
15use std::sync::{Arc, Weak};
16
17use zx::Status;
18
19use crate::shutdown_observer::ShutdownObserver;
20
21pub use fdf_sys::fdf_dispatcher_t;
22pub use libasync::{
23    AfterDeadline, AsyncDispatcher, AsyncDispatcherRef, JoinHandle, OnDispatcher, Task,
24};
25
26/// A marker trait for a function type that can be used as a shutdown observer for [`Dispatcher`].
27pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + 'static {}
28impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + 'static {}
29
30/// A builder for [`Dispatcher`]s
31#[derive(Default)]
32pub struct DispatcherBuilder {
33    #[doc(hidden)]
34    pub options: u32,
35    #[doc(hidden)]
36    pub name: String,
37    #[doc(hidden)]
38    pub scheduler_role: String,
39    #[doc(hidden)]
40    pub shutdown_observer: Option<Box<dyn ShutdownObserverFn>>,
41}
42
43impl DispatcherBuilder {
44    /// See `FDF_DISPATCHER_OPTION_UNSYNCHRONIZED` in the C API
45    pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
46    /// See `FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS` in the C API
47    pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;
48    /// See `FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION` in the C API
49    pub(crate) const NO_THREAD_MIGRATION: u32 = fdf_sys::FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION;
50
51    /// Creates a new [`DispatcherBuilder`] that can be used to configure a new dispatcher.
52    /// For more information on the threading-related flags for the dispatcher, see
53    /// https://fuchsia.dev/fuchsia-src/concepts/drivers/driver-dispatcher-and-threads
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    /// Sets whether parallel callbacks in the callbacks set in the dispatcher are allowed. May
59    /// not be set with [`Self::allow_thread_blocking`].
60    ///
61    /// See https://fuchsia.dev/fuchsia-src/concepts/drivers/driver-dispatcher-and-threads
62    /// for more information on the threading model of driver dispatchers.
63    pub fn unsynchronized(mut self) -> Self {
64        assert!(
65            !self.allows_thread_blocking(),
66            "you may not create an unsynchronized dispatcher that allows synchronous calls"
67        );
68        self.options |= Self::UNSYNCHRONIZED;
69        self
70    }
71
72    /// Whether or not this is an unsynchronized dispatcher
73    pub fn is_unsynchronized(&self) -> bool {
74        (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
75    }
76
77    /// This dispatcher may not share zircon threads with other drivers. May not be set with
78    /// [`Self::unsynchronized`].
79    ///
80    /// See https://fuchsia.dev/fuchsia-src/concepts/drivers/driver-dispatcher-and-threads
81    /// for more information on the threading model of driver dispatchers.
82    pub fn allow_thread_blocking(mut self) -> Self {
83        assert!(
84            !self.is_unsynchronized(),
85            "you may not create an unsynchronized dispatcher that allows synchronous calls"
86        );
87        self.options |= Self::ALLOW_THREAD_BLOCKING;
88        self
89    }
90
91    /// Whether or not this dispatcher allows synchronous calls
92    pub fn allows_thread_blocking(&self) -> bool {
93        (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
94    }
95
96    /// This dispatcher may not run on more than one thread. This can only be set if the
97    /// dispatcher is being run on a scheduler role that does not allow sync calls on
98    /// any of its dispatchers.
99    ///
100    /// See https://fuchsia.dev/fuchsia-src/concepts/drivers/driver-dispatcher-and-threads
101    /// for more information on the threading model of driver dispatchers.
102    pub fn no_thread_migration(mut self) -> Self {
103        self.options |= Self::NO_THREAD_MIGRATION;
104        self
105    }
106
107    /// Whether or not this dispatcher is allowed to run on multiple threads
108    pub fn allows_thread_migration(&self) -> bool {
109        (self.options & Self::NO_THREAD_MIGRATION) == 0
110    }
111
112    /// A descriptive name for this dispatcher that is used in debug output and process
113    /// lists.
114    pub fn name(mut self, name: &str) -> Self {
115        self.name = name.to_string();
116        self
117    }
118
119    /// A hint string for the runtime that may or may not impact the priority the work scheduled
120    /// by this dispatcher is handled at. It may or may not impact the ability for other drivers
121    /// to share zircon threads with the dispatcher.
122    pub fn scheduler_role(mut self, role: &str) -> Self {
123        self.scheduler_role = role.to_string();
124        self
125    }
126
127    /// A callback to be called before after the dispatcher has completed asynchronous shutdown.
128    pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
129        self.shutdown_observer = Some(Box::new(shutdown_observer));
130        self
131    }
132
133    /// Create the dispatcher as configured by this object. This must be called from a
134    /// thread managed by the driver runtime. The dispatcher returned is owned by the caller,
135    /// and will initiate asynchronous shutdown when the object is dropped unless
136    /// [`Dispatcher::release`] is called on it to convert it into an unowned [`DispatcherRef`].
137    pub fn create(self) -> Result<Dispatcher, Status> {
138        let mut out_dispatcher = null_mut();
139        let options = self.options;
140        let name = self.name.as_ptr() as *mut ffi::c_char;
141        let name_len = self.name.len();
142        let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
143        let scheduler_role_len = self.scheduler_role.len();
144        let observer =
145            ShutdownObserver::new(self.shutdown_observer.unwrap_or_else(|| Box::new(|_| {})))
146                .into_ptr();
147        // SAFETY: all arguments point to memory that will be available for the duration
148        // of the call, except `observer`, which will be available until it is unallocated
149        // by the dispatcher exit handler.
150        Status::ok(unsafe {
151            fdf_dispatcher_create(
152                options,
153                name,
154                name_len,
155                scheduler_role,
156                scheduler_role_len,
157                observer,
158                &mut out_dispatcher,
159            )
160        })?;
161        // SAFETY: `out_dispatcher` is valid by construction if `fdf_dispatcher_create` returns
162        // ZX_OK.
163        Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
164    }
165
166    /// As with [`Self::create`], this creates a new dispatcher as configured by this object, but
167    /// instead of returning an owned reference it immediately releases the reference to be
168    /// managed by the driver runtime.
169    pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
170        self.create().map(Dispatcher::release)
171    }
172}
173
174/// An owned handle for a dispatcher managed by the driver runtime.
175#[derive(Debug)]
176pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);
177
178// SAFETY: The api of fdf_dispatcher_t is thread safe.
179unsafe impl Send for Dispatcher {}
180unsafe impl Sync for Dispatcher {}
181thread_local! {
182    pub(crate) static OVERRIDE_DISPATCHER: RefCell<Option<NonNull<fdf_dispatcher_t>>> = const { RefCell::new(None) };
183}
184
185impl Dispatcher {
186    /// Creates a dispatcher ref from a raw handle.
187    ///
188    /// # Safety
189    ///
190    /// Caller is responsible for ensuring that the given handle is valid and
191    /// not owned by any other wrapper that will free it at an arbitrary
192    /// time.
193    pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
194        Self(handle)
195    }
196
197    fn get_raw_flags(&self) -> u32 {
198        // SAFETY: the inner fdf_dispatcher_t is valid by construction
199        unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
200    }
201
202    /// Whether this dispatcher's tasks and futures can run on multiple threads at the same time.
203    pub fn is_unsynchronized(&self) -> bool {
204        (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
205    }
206
207    /// Whether this dispatcher is allowed to call blocking functions or not
208    pub fn allows_thread_blocking(&self) -> bool {
209        (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
210    }
211
212    /// Whether this dispatcher is allowed to migrate threads, in which case it can't
213    /// be used for non-[`Send`] tasks.
214    pub fn allows_thread_migration(&self) -> bool {
215        (self.get_raw_flags() & DispatcherBuilder::NO_THREAD_MIGRATION) == 0
216    }
217
218    /// Whether this is the dispatcher the current thread is running on
219    pub fn is_current_dispatcher(&self) -> bool {
220        // SAFETY: we don't do anything with the dispatcher pointer, and NULL is returned if this
221        // isn't a dispatcher-managed thread.
222        self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
223    }
224
225    /// Releases ownership over this dispatcher and returns a [`DispatcherRef`]
226    /// that can be used to access it. The lifetime of this reference is static because it will
227    /// exist so long as this current driver is loaded, but the driver runtime will shut it down
228    /// when the driver is unloaded.
229    pub fn release(self) -> DispatcherRef<'static> {
230        DispatcherRef(ManuallyDrop::new(self), PhantomData)
231    }
232
233    /// Returns a [`DispatcherRef`] that references this dispatcher with a lifetime constrained by
234    /// `self`.
235    pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
236        DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
237    }
238}
239
240impl AsyncDispatcher for Dispatcher {
241    fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
242        let async_dispatcher =
243            NonNull::new(unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) })
244                .expect("No async dispatcher on driver dispatcher");
245        unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
246    }
247}
248
249impl Drop for Dispatcher {
250    fn drop(&mut self) {
251        // SAFETY: we only ever provide an owned `Dispatcher` to one owner, so when
252        // that one is dropped we can invoke the shutdown of the dispatcher
253        unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
254    }
255}
256
257/// An owned reference to a driver runtime dispatcher that auto-releases when dropped. This gives
258/// you the best of both worlds of having an `Arc<Dispatcher>` and a `DispatcherRef<'static>`
259/// created by [`Dispatcher::release`]:
260///
261/// - You can vend [`Weak`]-like pointers to it that will not cause memory access errors if used
262///   after the dispatcher has shut down, like an [`Arc`].
263/// - You can tie its terminal lifetime to that of the driver itself.
264///
265/// This is particularly useful in tests.
266#[derive(Debug)]
267pub struct AutoReleaseDispatcher(Arc<AtomicPtr<fdf_dispatcher>>);
268
269impl AutoReleaseDispatcher {
270    /// Returns a weakened reference to this dispatcher. This weak reference will only be valid so
271    /// long as the [`AutoReleaseDispatcher`] object that spawned it is alive, after which it will
272    /// no longer be usable to spawn tasks on.
273    pub fn downgrade(&self) -> WeakDispatcher {
274        WeakDispatcher::from(self)
275    }
276}
277
278impl AsyncDispatcher for AutoReleaseDispatcher {
279    fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
280        let dispatcher = NonNull::new(self.0.load(Ordering::Relaxed))
281            .expect("tried to obtain async dispatcher after drop");
282        // SAFETY: the validity of this dispatcher is ensured by use of NonNull above and this
283        // object's exclusive ownership over the dispatcher while it's alive.
284        unsafe {
285            AsyncDispatcherRef::from_raw(
286                NonNull::new(fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())).unwrap(),
287            )
288        }
289    }
290}
291
292impl From<Dispatcher> for AutoReleaseDispatcher {
293    fn from(dispatcher: Dispatcher) -> Self {
294        let dispatcher_ptr = dispatcher.release().0.0.as_ptr();
295        Self(Arc::new(AtomicPtr::new(dispatcher_ptr)))
296    }
297}
298
299impl Drop for AutoReleaseDispatcher {
300    fn drop(&mut self) {
301        // Store nullptr into the atomic so that any future attempts to obtain a strong reference
302        // through a WeakDispatcher will not successfully upgrade.
303        self.0.store(null_mut(), Ordering::Relaxed);
304        // We want to allow for any outstanding `on_dispatcher` calls to finish before returning
305        // from drop, so we're going to loop until the strong reference count goes down to zero,
306        // after which any future attempts to call `on_dispatcher` on a `WeakDispatcher` will fail.
307        while Arc::strong_count(&self.0) > 1 {
308            // This sleep is kind of gross, but it should happen extremely rarely and
309            // `on_dispatcher` calls should not be performing any blocking work.
310            std::thread::sleep(std::time::Duration::from_nanos(100))
311        }
312    }
313}
314
315/// An unowned but reference counted reference to a dispatcher. This would usually come from
316/// an [`AutoReleaseDispatcher`] reference to a dispatcher.
317///
318/// The advantage to using this instead of using [`Weak`] directly is that it controls the lifetime
319/// of any given strong reference to the dispatcher, since the only way to access that strong
320/// reference is through [`OnDispatcher::on_dispatcher`]. This makes it much easier to be sure
321/// that you aren't leaving any dangling strong references to the dispatcher object around.
322#[derive(Clone, Debug)]
323pub struct WeakDispatcher(Weak<AtomicPtr<fdf_dispatcher>>);
324
325impl From<&AutoReleaseDispatcher> for WeakDispatcher {
326    fn from(value: &AutoReleaseDispatcher) -> Self {
327        Self(Arc::downgrade(&value.0))
328    }
329}
330
331impl OnDispatcher for WeakDispatcher {
332    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
333        let Some(dispatcher_ptr) = self.0.upgrade() else {
334            return f(None);
335        };
336        let Some(dispatcher) = NonNull::new(dispatcher_ptr.load(Ordering::Relaxed)) else {
337            return f(None);
338        };
339        // SAFETY: As long as we hold the strong reference in dispatcher_ptr, the
340        // AutoReleaseDispatcher will not allow its drop to finish and the dispatcher should still
341        // be valid.
342        f(Some(unsafe { DispatcherRef::from_raw(dispatcher) }.as_async_dispatcher_ref()))
343    }
344}
345
346impl OnDriverDispatcher for WeakDispatcher {}
347
348/// An unowned reference to a driver runtime dispatcher such as is produced by calling
349/// [`Dispatcher::release`]. When this object goes out of scope it won't shut down the dispatcher,
350/// leaving that up to the driver runtime or another owner.
351#[derive(Debug)]
352pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
353
354impl<'a> DispatcherRef<'a> {
355    /// Creates a dispatcher ref from a raw handle.
356    ///
357    /// # Safety
358    ///
359    /// Caller is responsible for ensuring that the given handle is valid for
360    /// the lifetime `'a`.
361    pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
362        // SAFETY: Caller promises the handle is valid.
363        Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
364    }
365
366    /// Creates a dispatcher ref from an [`AsyncDispatcherRef`].
367    ///
368    /// # Panics
369    ///
370    /// Note that this will cause an assert if the [`AsyncDispatcherRef`] was not created from a
371    /// driver dispatcher in the first place.
372    pub fn from_async_dispatcher(dispatcher: AsyncDispatcherRef<'a>) -> Self {
373        let handle = NonNull::new(unsafe {
374            fdf_dispatcher_downcast_async_dispatcher(dispatcher.inner().as_ptr())
375        })
376        .unwrap();
377        unsafe { Self::from_raw(handle) }
378    }
379
380    /// Gets the raw handle from this dispatcher ref.
381    ///
382    /// # Safety
383    ///
384    /// Caller is responsible for ensuring that the dispatcher handle is used safely.
385    pub unsafe fn as_raw(&mut self) -> *mut fdf_dispatcher_t {
386        unsafe { self.0.0.as_mut() }
387    }
388}
389
/// Wraps a future that is not [`Send`] so that it can be passed to APIs requiring [`Send`].
///
/// Only use this after checking at runtime that the target dispatcher is the one currently
/// running and that it never migrates threads, so the inner future is always polled on the
/// thread that created it.
///
/// This is an internal implementation detail and must never be made public.
struct AddSendFuture<T>(T);

impl<T: Future> Future for AddSendFuture<T> {
    type Output = T::Output;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: the wrapper never moves its only field out of a pinned `Self`, so pinning is
        // structural for that field.
        let wrapped = unsafe { &mut self.get_unchecked_mut().0 };
        // SAFETY: `wrapped` lives inside a pinned `Self`, so it is pinned too.
        let pinned = unsafe { std::pin::Pin::new_unchecked(wrapped) };
        pinned.poll(cx)
    }
}

// SAFETY: We are forcing this future to be [`Send`] even though the inner future may not be,
// because we validate at runtime, before spawning the task, that the dispatcher is correctly
// configured to always poll it from the same thread.
unsafe impl<T> Send for AddSendFuture<T> {}
416
417/// Makes available additional functionality available on driver dispatchers on top of what's
418/// available on [`OnDispatcher`].
419pub trait OnDriverDispatcher: OnDispatcher {
420    /// Spawn an asynchronous local task on this dispatcher. If this returns [`Ok`] then the task
421    /// has successfully been scheduled and will run or be cancelled and dropped when the dispatcher
422    /// shuts down. The returned future's result will be [`Ok`] if the future completed
423    /// successfully, or an [`Err`] if the task did not complete for some reason (like the
424    /// dispatcher shut down).
425    ///
426    /// Unlike [`OnDispatcher::spawn`], this will accept a future that does not implement [`Send`]. If
427    /// called from a thread other than the one the dispatcher is running on or the dispatcher
428    /// is not guaranteed to always poll from the same thread, this will return
429    /// [`Status::BAD_STATE`].
430    ///
431    /// Returns a [`JoinHandle`] that will detach the future when dropped.
432    fn spawn_local(
433        &self,
434        future: impl Future<Output = ()> + 'static,
435    ) -> Result<JoinHandle<()>, Status>
436    where
437        Self: 'static,
438    {
439        self.on_maybe_dispatcher(|dispatcher| {
440            let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
441            if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
442                OnDispatcher::spawn(self, AddSendFuture(future))
443            } else {
444                Err(Status::BAD_STATE)
445            }
446        })
447    }
448
449    /// Spawn a local asynchronous task that outputs type 'T' on this dispatcher. The returned future's
450    /// result will be [`Ok`] if the task was started and completed successfully, or an [`Err`] if
451    /// the task couldn't be started or failed to complete (for example because the dispatcher was
452    /// shutting down).
453    ///
454    /// Returns a [`Task`] that will cancel the future when dropped.
455    ///
456    /// Unlike [`OnDispatcher::compute`], this will accept a future that does not implement [`Send`]. If
457    /// called from a thread other than the one the dispatcher is running on or the dispatcher
458    /// is not guaranteed to always poll from the same thread, this will return
459    /// [`Status::BAD_STATE`].
460    ///
461    /// TODO(470088116): This may be the cause of some flakes, so care should be used with it
462    /// in critical paths for now.
463    fn compute_local<T: Send + 'static>(
464        &self,
465        future: impl Future<Output = T> + 'static,
466    ) -> Result<Task<T>, Status>
467    where
468        Self: 'static,
469    {
470        self.on_maybe_dispatcher(|dispatcher| {
471            let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
472            if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
473                Ok(OnDispatcher::compute(self, AddSendFuture(future)))
474            } else {
475                Err(Status::BAD_STATE)
476            }
477        })
478    }
479}
480
481impl OnDriverDispatcher for Arc<Dispatcher> {}
482impl OnDriverDispatcher for Weak<Dispatcher> {}
483
484impl<'a> AsyncDispatcher for DispatcherRef<'a> {
485    fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
486        self.0.as_async_dispatcher_ref()
487    }
488}
489
490impl<'a> Clone for DispatcherRef<'a> {
491    fn clone(&self) -> Self {
492        Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
493    }
494}
495
496impl<'a> core::ops::Deref for DispatcherRef<'a> {
497    type Target = Dispatcher;
498    fn deref(&self) -> &Self::Target {
499        &self.0
500    }
501}
502
503impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
504    fn deref_mut(&mut self) -> &mut Self::Target {
505        &mut self.0
506    }
507}
508
509impl<'a> OnDispatcher for DispatcherRef<'a> {
510    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
511        f(Some(self.as_async_dispatcher_ref()))
512    }
513}
514
515impl<'a> OnDriverDispatcher for DispatcherRef<'a> {}
516
517/// A placeholder for the currently active dispatcher. Use [`OnDispatcher::on_dispatcher`] to
518/// access it when needed.
519#[derive(Clone, Copy, Debug, Default, PartialEq)]
520pub struct CurrentDispatcher;
521
522impl OnDispatcher for CurrentDispatcher {
523    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
524        let dispatcher = OVERRIDE_DISPATCHER
525            .with(|global| *global.borrow())
526            .or_else(|| {
527                // SAFETY: NonNull::new will null-check that we have a current dispatcher.
528                NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
529            })
530            .map(|dispatcher| {
531                // SAFETY: We constrain the lifetime of the `DispatcherRef` we provide to the
532                // function below to the span of the current function. Since we are running on
533                // the dispatcher, or another dispatcher that is bound to the same lifetime (through
534                // override_dispatcher), we can be sure that the dispatcher will not be shut
535                // down before that function completes.
536                let async_dispatcher = NonNull::new(unsafe {
537                    fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())
538                })
539                .expect("No async dispatcher on driver dispatcher");
540                unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
541            });
542        f(dispatcher)
543    }
544}
545
546impl OnDriverDispatcher for CurrentDispatcher {}
547
548#[cfg(test)]
549mod tests {
550    use super::*;
551
552    use std::sync::{Arc, Once, Weak, mpsc};
553
554    use futures::channel::mpsc as async_mpsc;
555    use futures::{SinkExt, StreamExt};
556    use zx::sys::ZX_OK;
557
558    use core::ffi::{c_char, c_void};
559    use core::ptr::null_mut;
560
561    static GLOBAL_DRIVER_ENV: Once = Once::new();
562    const NO_SYNC_CALLS_ROLE: &str = "no sync calls role";
563
    /// Starts the driver runtime environment and registers `NO_SYNC_CALLS_ROLE` with the
    /// `FDF_SCHEDULER_ROLE_OPTION_NO_SYNC_CALLS` option.
    ///
    /// Guarded by `GLOBAL_DRIVER_ENV`, so only the first call does any work; later calls are
    /// no-ops. Panics if either runtime call does not return `ZX_OK`.
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            // SAFETY: calling fdf_env_start, which does not have any soundness
            // concerns for rust code, and this is only used in tests.
            unsafe {
                assert_eq!(fdf_env_start(0), ZX_OK);
                assert_eq!(
                    fdf_env_set_scheduler_role_opts(
                        NO_SYNC_CALLS_ROLE.as_ptr() as *const c_char,
                        NO_SYNC_CALLS_ROLE.len(),
                        FDF_SCHEDULER_ROLE_OPTION_NO_SYNC_CALLS
                    ),
                    ZX_OK
                );
            }
        });
    }
581    pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T) -> T {
582        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, "", p)
583    }
584
    /// Creates a raw test dispatcher with the given `flags` and `scheduler_role` and runs `p`
    /// with a [`Weak`] reference to it. It then drops the only strong reference, which starts
    /// shutdown, and blocks until the shutdown observer has run before returning `p`'s result.
    ///
    /// # Panics
    ///
    /// Panics if dispatcher creation does not return `ZX_OK`, or if `p` let a strong reference
    /// to the dispatcher escape. The shutdown observer also asserts that no tasks are left
    /// queued.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        scheduler_role: &str,
        p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut dispatcher = null_mut();
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            // SAFETY: we verify that the dispatcher has no tasks left queued in it,
            // just because this is testing code.
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0.0.as_ptr()) });
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        // The address of the local `observer` is passed as the dispatcher "owner"; presumably
        // the runtime only uses it as an opaque identity — TODO confirm.
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        // SAFETY: The pointers we pass to this function are all stable for the
        // duration of this function, and are not available to copy or clone to
        // client code (only through a ref to the non-`Clone` `Dispatcher`
        // wrapper).
        let res = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                scheduler_role.as_ptr() as *const c_char,
                scheduler_role.len(),
                observer,
                &mut dispatcher,
            )
        };
        assert_eq!(res, ZX_OK);
        let dispatcher = Arc::new(Dispatcher(NonNull::new(dispatcher).unwrap()));

        let res = p(Arc::downgrade(&dispatcher));

        // this initiates the dispatcher shutdown on a driver runtime
        // thread. When all tasks on the dispatcher have completed, the wait
        // on the shutdown_rx below will end and we can tear it down.
        let weak_dispatcher = Arc::downgrade(&dispatcher);
        drop(dispatcher);
        shutdown_rx.recv().unwrap();
        assert_eq!(
            0,
            weak_dispatcher.strong_count(),
            "a dispatcher reference escaped the test body"
        );

        res
    }
638
639    #[test]
640    fn start_test_dispatcher() {
641        with_raw_dispatcher("testing", |dispatcher| {
642            println!("hello {dispatcher:?}");
643        })
644    }
645
    /// Posts a synchronous task to a test dispatcher and checks that it runs with `ZX_OK`.
    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            // This strong reference is dropped when the closure returns, before the helper
            // checks that no reference escaped.
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }
660
    /// Creates a second-level dispatcher from inside a task on the test dispatcher and posts a
    /// task to it. Checks that the inner task runs and that the inner dispatcher's shutdown
    /// observer is called.
    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    let inner = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        })
                        .create()
                        .unwrap();
                    inner
                        .post_task_sync(move |status| {
                            assert_eq!(status, Status::from_raw(ZX_OK));
                            tx.send(status).unwrap();
                        })
                        .unwrap();
                    // we want to make sure the inner dispatcher lives long
                    // enough to run the task, so we send it out to the outer
                    // closure.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            // The received inner `Dispatcher` is dropped at the end of this statement, which
            // starts its asynchronous shutdown and eventually runs its shutdown observer.
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }
698
    /// `spawn_local` and `compute_local` must return `BAD_STATE` on a dispatcher that allows
    /// thread migration (the default test dispatcher), even when called from a task running on
    /// that dispatcher.
    #[test]
    fn spawn_local_fails_on_normal_dispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("spawn local failures", move |dispatcher| {
            let inside_dispatcher = dispatcher.clone();
            dispatcher
                .spawn(async move {
                    assert_eq!(
                        inside_dispatcher.spawn_local(futures::future::ready(())).unwrap_err(),
                        Status::BAD_STATE
                    );
                    assert_eq!(
                        inside_dispatcher.compute_local(futures::future::ready(())).unwrap_err(),
                        Status::BAD_STATE
                    );
                    shutdown_tx.send(()).unwrap();
                })
                .unwrap();
            // Wait for the task's assertions to run before the dispatcher is torn down.
            shutdown_rx.recv().unwrap();
        });
    }
720
    /// On a dispatcher created with `FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION` (on the
    /// no-sync-calls scheduler role), `spawn_local` and `compute_local` called from a task
    /// running on that dispatcher must succeed and run their futures.
    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_succeeds_on_no_thread_migration_dispatcher() {
        let (tx, rx) = mpsc::channel();
        with_raw_dispatcher_flags(
            "spawn local success",
            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |dispatcher| {
                let inside_dispatcher = dispatcher.clone();
                dispatcher
                    .spawn(async move {
                        let tx_clone = tx.clone();
                        inside_dispatcher
                            .spawn_local(async move {
                                tx_clone.send(()).unwrap();
                            })
                            .unwrap();
                        inside_dispatcher
                            .compute_local(async move {
                                tx.send(()).unwrap();
                            })
                            .unwrap()
                            .await
                            .unwrap();
                    })
                    .unwrap();
                // one empty object received each for spawn and compute _local.
                rx.recv().unwrap();
                rx.recv().unwrap();
            },
        );
    }
754
755    #[test]
756    #[ignore = "Pending resolution of b/488397193"]
757    fn spawn_local_fails_on_no_thread_migration_dispatcher_from_different_thread() {
758        with_raw_dispatcher_flags(
759            "spawn local success",
760            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
761            NO_SYNC_CALLS_ROLE,
762            move |dispatcher| {
763                // we are not currently running in any dispatcher here, so this is a context
764                // where the 'current dispatcher' is definitely not the one in question.
765                assert_eq!(
766                    dispatcher.spawn_local(futures::future::ready(())).unwrap_err(),
767                    Status::BAD_STATE
768                );
769                assert_eq!(
770                    dispatcher.compute_local(futures::future::ready(())).unwrap_err(),
771                    Status::BAD_STATE
772                );
773            },
774        );
775    }
776
777    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
778        println!("starting ping!");
779        tx.send(0).await.unwrap();
780        while let Some(next) = rx.next().await {
781            println!("ping! {next}");
782            tx.send(next + 1).await.unwrap();
783        }
784    }
785
786    async fn pong(
787        fin_tx: std::sync::mpsc::Sender<()>,
788        mut tx: async_mpsc::Sender<u8>,
789        mut rx: async_mpsc::Receiver<u8>,
790    ) {
791        println!("starting pong!");
792        while let Some(next) = rx.next().await {
793            println!("pong! {next}");
794            if next > 10 {
795                println!("bye!");
796                break;
797            }
798            tx.send(next + 1).await.unwrap();
799        }
800        fin_tx.send(()).unwrap();
801    }
802
803    #[test]
804    fn async_ping_pong() {
805        with_raw_dispatcher("async ping pong", |dispatcher| {
806            let (fin_tx, fin_rx) = mpsc::channel();
807            let (ping_tx, pong_rx) = async_mpsc::channel(10);
808            let (pong_tx, ping_rx) = async_mpsc::channel(10);
809            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();
810            dispatcher.spawn(pong(fin_tx, pong_tx, pong_rx)).unwrap();
811
812            fin_rx.recv().expect("to receive final value");
813        });
814    }
815
816    async fn slow_pong(
817        fin_tx: std::sync::mpsc::Sender<()>,
818        mut tx: async_mpsc::Sender<u8>,
819        mut rx: async_mpsc::Receiver<u8>,
820    ) {
821        use zx::MonotonicDuration;
822        println!("starting pong!");
823        while let Some(next) = rx.next().await {
824            println!("pong! {next}");
825            fuchsia_async::Timer::new(fuchsia_async::MonotonicInstant::after(
826                MonotonicDuration::from_seconds(1),
827            ))
828            .await;
829            if next > 10 {
830                println!("bye!");
831                break;
832            }
833            tx.send(next + 1).await.unwrap();
834        }
835        fin_tx.send(()).unwrap();
836    }
837
838    #[test]
839    fn mixed_executor_async_ping_pong() {
840        with_raw_dispatcher("async ping pong", |dispatcher| {
841            let (fin_tx, fin_rx) = mpsc::channel();
842            let (ping_tx, pong_rx) = async_mpsc::channel(10);
843            let (pong_tx, ping_rx) = async_mpsc::channel(10);
844
845            // spawn ping on the driver dispatcher
846            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();
847
848            // and run pong on the fuchsia_async executor
849            let mut executor = fuchsia_async::LocalExecutor::default();
850            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));
851
852            fin_rx.recv().expect("to receive final value");
853        });
854    }
855}