Skip to main content

driver_manager_shutdown/
node_removal_tracker.rs

1// Copyright 2026 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use driver_manager_driver_host::DriverHost;
6use driver_manager_types::{Collection, ShutdownState};
7use fuchsia_async as fasync;
8use futures::channel::oneshot;
9use log::{info, warn};
10use std::cell::RefCell;
11use std::collections::{HashMap, HashSet};
12use std::rc::{Rc, Weak};
13
/// Opaque identifier handed out by `NodeRemovalTracker::register_node` and passed back to
/// `NodeRemovalTracker::notify` to report state changes for the same node.
pub type NodeId = u32;
15
16#[derive(Clone)]
17pub struct NodeInfo {
18    pub name: String,
19    pub driver_url: String,
20    pub collection: Collection,
21    pub state: ShutdownState,
22    pub host: Option<Rc<dyn DriverHost>>,
23}
24
25const REMOVAL_TIMEOUT_DURATION: zx::MonotonicDuration = zx::MonotonicDuration::from_seconds(15);
26
27pub struct NodeRemovalTracker {
28    fully_enumerated: bool,
29    next_node_id: NodeId,
30    remaining_pkg_nodes: HashSet<NodeId>,
31    remaining_non_pkg_nodes: HashSet<NodeId>,
32    nodes: HashMap<NodeId, NodeInfo>,
33    pkg_callback: Option<oneshot::Sender<()>>,
34    all_callback: Option<oneshot::Sender<()>>,
35    on_removal_timeout_callback: Option<Box<dyn Fn()>>,
36    timeout_count: u32,
37    timeout_task: Option<fasync::Task<()>>,
38}
39
40impl NodeRemovalTracker {
41    pub fn new() -> Rc<RefCell<Self>> {
42        Rc::new(RefCell::new(Self {
43            fully_enumerated: false,
44            next_node_id: 0,
45            remaining_pkg_nodes: HashSet::new(),
46            remaining_non_pkg_nodes: HashSet::new(),
47            nodes: HashMap::new(),
48            pkg_callback: None,
49            all_callback: None,
50            on_removal_timeout_callback: None,
51            timeout_count: 0,
52            timeout_task: None,
53        }))
54    }
55
56    fn start_timeout_task(&mut self, weak_self: Weak<RefCell<Self>>) {
57        if let Some(task) = self.timeout_task.take() {
58            std::mem::drop(task.abort());
59        }
60
61        self.timeout_task = Some(fasync::Task::local(async move {
62            fasync::Timer::new(REMOVAL_TIMEOUT_DURATION).await;
63            if let Some(strong_self) = weak_self.upgrade() {
64                strong_self.borrow_mut().on_removal_timeout(weak_self);
65            }
66        }));
67    }
68
69    pub fn register_node(&mut self, info: NodeInfo) -> NodeId {
70        if info.state == ShutdownState::Destroyed {
71            return self.next_node_id;
72        }
73
74        if info.collection == Collection::Package {
75            self.remaining_pkg_nodes.insert(self.next_node_id);
76        } else {
77            self.remaining_non_pkg_nodes.insert(self.next_node_id);
78        }
79        self.nodes.insert(self.next_node_id, info);
80        let id = self.next_node_id;
81        self.next_node_id += 1;
82        id
83    }
84
85    pub fn notify(&mut self, id: NodeId, state: ShutdownState, weak_self: Weak<RefCell<Self>>) {
86        let collection = {
87            let node_info = self.nodes.get_mut(&id).expect("Tried to Notify without registering!");
88            node_info.state = state;
89            node_info.collection
90        };
91
92        if self.timeout_task.is_some() {
93            self.start_timeout_task(weak_self);
94        }
95
96        if state == ShutdownState::Destroyed {
97            if collection == Collection::Package {
98                self.remaining_pkg_nodes.remove(&id);
99            } else {
100                self.remaining_non_pkg_nodes.remove(&id);
101            }
102            self.check_removal_done();
103        }
104    }
105
106    pub fn finish_enumeration(&mut self, weak_self: Weak<RefCell<Self>>) {
107        self.fully_enumerated = true;
108        self.start_timeout_task(weak_self);
109        self.check_removal_done();
110    }
111
112    pub fn set_pkg_callback(&mut self, callback: oneshot::Sender<()>) {
113        self.pkg_callback = Some(callback);
114    }
115
116    pub fn set_all_callback(&mut self, callback: oneshot::Sender<()>) {
117        self.all_callback = Some(callback);
118    }
119
120    pub fn set_on_removal_timeout_callback(&mut self, callback: Box<dyn Fn()>) {
121        self.on_removal_timeout_callback = Some(callback);
122    }
123
124    fn on_removal_timeout(&mut self, weak_self: Weak<RefCell<Self>>) {
125        self.timeout_count += 1;
126        warn!(
127            "Removal hanging, nodes remaining: {} pkg, {} pkg+boot",
128            self.remaining_pkg_nodes.len(),
129            self.remaining_pkg_nodes.len() + self.remaining_non_pkg_nodes.len()
130        );
131        for node in self.nodes.values() {
132            if node.state != ShutdownState::Destroyed && node.state != ShutdownState::Prestop {
133                if node.state == ShutdownState::WaitingOnDriver
134                    && let Some(host) = &node.host
135                {
136                    host.trigger_stack_trace();
137                }
138                warn!("  '{}' ('{}'): {}", node.name, node.driver_url, node.state);
139            }
140        }
141        if self.timeout_count >= 3
142            && let Some(callback) = &self.on_removal_timeout_callback
143        {
144            callback();
145        }
146        self.start_timeout_task(weak_self);
147    }
148
149    fn check_removal_done(&mut self) {
150        if !self.fully_enumerated {
151            return;
152        }
153
154        if self.pkg_callback.is_some() && self.remaining_pkg_nodes.is_empty() {
155            info!("NodeRemovalTracker: package removal completed");
156            if let Some(sender) = self.pkg_callback.take() {
157                let _ = sender.send(());
158            }
159        }
160        if self.all_callback.is_some()
161            && self.remaining_pkg_nodes.is_empty()
162            && self.remaining_non_pkg_nodes.is_empty()
163        {
164            info!("NodeRemovalTracker: all nodes removed");
165            if let Some(sender) = self.all_callback.take() {
166                let _ = sender.send(());
167            }
168            // Cancel timeout task.
169            if let Some(timeout_task) = self.timeout_task.take() {
170                std::mem::drop(timeout_task.abort());
171            }
172            self.nodes.clear();
173        }
174    }
175}