1use std::sync::{mpsc, Arc};
8
9use fdf::{CurrentDispatcher, OnDispatcher};
10
11use super::*;
12
13pub fn run_in_driver<T: Send + 'static>(name: &str, p: impl FnOnce() -> T + Send + 'static) -> T {
16 run_in_driver_raw(name, true, false, |tx| tx.send(p()).unwrap())
17}
18
19pub fn run_in_driver_etc<T: Send + 'static>(
22 name: &str,
23 allow_thread_blocking: bool,
24 unsynchronized: bool,
25 p: impl FnOnce() -> T + Send + 'static,
26) -> T {
27 run_in_driver_raw(name, allow_thread_blocking, unsynchronized, |tx| tx.send(p()).unwrap())
28}
29
30pub fn spawn_in_driver<T: Send + 'static>(
33 name: &str,
34 p: impl Future<Output = T> + Send + 'static,
35) -> T {
36 run_in_driver_raw(name, true, false, |tx| {
37 CurrentDispatcher.spawn_task(async move { tx.send(p.await).unwrap() }).unwrap();
38 })
39}
40
41pub fn spawn_in_driver_etc<T: Send + 'static>(
44 name: &str,
45 allow_thread_blocking: bool,
46 unsynchronized: bool,
47 p: impl Future<Output = T> + Send + 'static,
48) -> T {
49 run_in_driver_raw(name, allow_thread_blocking, unsynchronized, |tx| {
50 CurrentDispatcher.spawn_task(async move { tx.send(p.await).unwrap() }).unwrap();
51 })
52}
53
54fn run_in_driver_raw<T: Send + 'static>(
55 name: &str,
56 allow_thread_blocking: bool,
57 unsynchronized: bool,
58 p: impl FnOnce(mpsc::Sender<T>) + Send + 'static,
59) -> T {
60 let env = Arc::new(Environment::start(0).unwrap());
61 let env_clone = env.clone();
62
63 let (shutdown_tx, shutdown_rx) = mpsc::channel();
64 let driver_value: u32 = 0x1337;
65 let driver_value_ptr = &driver_value as *const u32;
66 let driver = env.new_driver(driver_value_ptr);
67 let dispatcher = DispatcherBuilder::new().name(name);
68 let dispatcher =
69 if allow_thread_blocking { dispatcher.allow_thread_blocking() } else { dispatcher };
70 let dispatcher = if unsynchronized { dispatcher.unsynchronized() } else { dispatcher };
71 let dispatcher = dispatcher.shutdown_observer(move |dispatcher| {
72 assert!(!env_clone.dispatcher_has_queued_tasks(dispatcher.as_dispatcher_ref()));
75 });
76 let dispatcher = driver.new_dispatcher(dispatcher).unwrap();
77
78 let (finished_tx, finished_rx) = mpsc::channel();
79 dispatcher
80 .post_task_sync(move |_| {
81 p(finished_tx);
82 })
83 .unwrap();
84 let res = finished_rx.recv().unwrap();
85
86 driver.shutdown(move |driver| {
87 assert!(unsafe { *driver.0 } == 0x1337);
89 shutdown_tx.send(()).unwrap();
90 });
91
92 shutdown_rx.recv().unwrap();
93
94 res
95}