// fuchsia_async/runtime/fuchsia/executor/atomic_future/hooks.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use crate::instrument::Hooks;
6
7use super::{AtomicFutureHandle, Meta, VTable};
8use fuchsia_sync::Mutex;
9use std::collections::HashMap;
10use std::ptr::NonNull;
11use std::task::{Context, Poll};
12
13/// We don't want to pay a cost if there are no hooks, so we store a mapping from task ID to
14/// HooksWrapper in the executor.
15#[derive(Default)]
16pub struct HooksMap(Mutex<HashMap<usize, NonNull<()>>>);
17
18unsafe impl Send for HooksMap {}
19unsafe impl Sync for HooksMap {}
20
21struct HooksWrapper<H> {
22    orig_vtable: &'static VTable,
23    hooks: H,
24}
25
26impl<H: Hooks> HooksWrapper<H> {
27    // # Safety
28    //
29    // We rely on the fact that all these functions are called whilst we have exclusive
30    // access to the underlying future and the associated Hooks object.
31    const VTABLE: VTable = VTable {
32        drop: Self::drop,
33        drop_future: Self::drop_future,
34        poll: Self::poll,
35        get_result: Self::get_result,
36        drop_result: Self::drop_result,
37    };
38
39    // Returns a mutable reference to the wrapper. This will be safe from the functions below
40    // because they are all called when we have exclusive access.
41    unsafe fn wrapper<'a>(meta: NonNull<Meta>) -> &'a mut Self {
42        unsafe {
43            meta.as_ref()
44                .scope()
45                .executor()
46                .hooks_map
47                .0
48                .lock()
49                .get(&meta.as_ref().id)
50                .unwrap()
51                .cast::<Self>()
52                .as_mut()
53        }
54    }
55
56    unsafe fn drop(mut meta: NonNull<Meta>) {
57        let meta_ref = unsafe { meta.as_mut() };
58        // Remove the hooks entry from the map.
59        let hooks = unsafe {
60            Box::from_raw(
61                meta_ref
62                    .scope()
63                    .executor()
64                    .hooks_map
65                    .0
66                    .lock()
67                    .remove(&meta_ref.id)
68                    .unwrap()
69                    .cast::<Self>()
70                    .as_mut(),
71            )
72        };
73        // Restore the vtable because the drop implementation can call `drop_future` or
74        // `drop_result`, but we've removed the hooks from the map now.
75        meta_ref.vtable = hooks.orig_vtable;
76        unsafe { (hooks.orig_vtable.drop)(meta) };
77    }
78
79    unsafe fn poll(meta: NonNull<Meta>, cx: &mut Context<'_>) -> Poll<()> {
80        let wrapper = unsafe { Self::wrapper(meta) };
81        wrapper.hooks.task_poll_start();
82        let result = unsafe { (wrapper.orig_vtable.poll)(meta, cx) };
83        wrapper.hooks.task_poll_end();
84        if result.is_ready() {
85            wrapper.hooks.task_completed();
86        }
87        result
88    }
89
90    unsafe fn drop_future(meta: NonNull<Meta>) {
91        unsafe { (Self::wrapper(meta).orig_vtable.drop_future)(meta) };
92    }
93
94    unsafe fn get_result(meta: NonNull<Meta>) -> *const () {
95        unsafe { (Self::wrapper(meta).orig_vtable.get_result)(meta) }
96    }
97
98    unsafe fn drop_result(meta: NonNull<Meta>) {
99        unsafe { (Self::wrapper(meta).orig_vtable.drop_result)(meta) };
100    }
101}
102
103impl AtomicFutureHandle<'_> {
104    /// Adds hooks to the future.
105    pub fn add_hooks<H: Hooks>(&mut self, hooks: H) {
106        // SAFETY: This is safe because we have exclusive access.
107        let meta: &mut Meta = unsafe { self.0.as_mut() };
108        {
109            let mut hooks_map = meta.scope().executor().hooks_map.0.lock();
110            // SAFETY: Safe because `Box::into_raw` is guaranteed to give is a non-null pointer. We
111            // can use `Box::into_non_null` when it's stabilised.
112            assert!(
113                hooks_map
114                    .insert(meta.id, unsafe {
115                        NonNull::new_unchecked(Box::into_raw(Box::new(HooksWrapper {
116                            orig_vtable: meta.vtable,
117                            hooks,
118                        })))
119                        .cast::<()>()
120                    })
121                    .is_none()
122            );
123        }
124        // Inject our vtable.
125        meta.vtable = &HooksWrapper::<H>::VTABLE;
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::Hooks;
    use crate::runtime::fuchsia::executor::scope::Spawnable;
    use crate::{SpawnableFuture, TestExecutor, yield_now};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    /// Test hooks that count poll start/end events and record task completion.
    #[derive(Default)]
    struct CountingHooks {
        poll_start: AtomicU32,
        poll_end: AtomicU32,
        completed: AtomicBool,
    }

    impl Hooks for Arc<CountingHooks> {
        fn task_poll_start(&mut self) {
            self.poll_start.fetch_add(1, Ordering::Relaxed);
        }

        fn task_poll_end(&mut self) {
            self.poll_end.fetch_add(1, Ordering::Relaxed);
        }

        fn task_completed(&mut self) {
            // Completion must be reported at most once.
            let already_completed = self.completed.swap(true, Ordering::Relaxed);
            assert!(!already_completed);
        }
    }

    #[test]
    fn test_hooks() {
        let mut executor = TestExecutor::new();
        let scope = executor.global_scope();
        let counters = Arc::new(CountingHooks::default());

        // The task yields once, so it is expected to take exactly two polls to finish.
        let mut task = SpawnableFuture::new(async {
            yield_now().await;
        })
        .into_task(scope.clone());
        task.add_hooks(Arc::clone(&counters));
        scope.insert_task(task, false);

        // The main future never resolves; running until stalled drives the spawned task.
        let main_result = executor.run_until_stalled(&mut std::future::pending::<()>());
        assert!(main_result.is_pending());

        assert_eq!(counters.poll_start.load(Ordering::Relaxed), 2);
        assert_eq!(counters.poll_end.load(Ordering::Relaxed), 2);
        assert!(counters.completed.load(Ordering::Relaxed));
    }
}
171}