1use crate::suspend::WakeSources;
6use anyhow::{Error, anyhow};
7use fuchsia_async as fasync;
8use fuchsia_sync::Mutex;
9use futures::FutureExt;
10use log::warn;
11use std::cell::RefCell;
12use std::mem::MaybeUninit;
13use std::pin::pin;
14use std::rc::Rc;
15use std::sync::Arc;
16
/// The pair of channel endpoints serviced by one proxy task, plus the
/// bookkeeping that travels with them.
pub struct ChannelProxy {
    /// Endpoint on the container side. Messages read from it are written to
    /// `remote_channel`; the message counter is not touched in this direction.
    pub container_channel: zx::Channel,

    /// Endpoint on the remote side. Messages read from it are written to
    /// `container_channel`, and `message_counter` is incremented for each one.
    pub remote_channel: zx::Channel,

    /// Counter incremented by one for every message forwarded from
    /// `remote_channel` to `container_channel`. Its koid is removed from the
    /// shared wake sources when the proxy task exits.
    pub message_counter: zx::Counter,

    /// Human-readable label attached to trace events and log messages emitted
    /// for this proxy.
    pub name: String,
}
36
/// Identifies which of the two proxied endpoints completed the wait in
/// `start_proxy`, and therefore in which direction a message is forwarded.
#[derive(Debug)]
enum WaitReturn {
    /// The wait on `ChannelProxy::container_channel` completed.
    Container,
    /// The wait on `ChannelProxy::remote_channel` completed.
    Remote,
}
43
/// Scheduler role name requested for the proxy thread via
/// `fuchsia_scheduler::set_role_for_this_thread`.
const PROXY_ROLE_NAME: &str = "fuchsia.starnix.runner.proxy";
46
47pub fn run_proxy_thread(
49 new_proxies: async_channel::Receiver<(ChannelProxy, Arc<Mutex<WakeSources>>)>,
50) {
51 let _ = std::thread::Builder::new().name("proxy_thread".to_string()).spawn(move || {
52 if let Err(e) = fuchsia_scheduler::set_role_for_this_thread(PROXY_ROLE_NAME) {
53 warn!(e:%; "failed to set thread role");
54 }
55 let mut executor = fasync::LocalExecutor::default();
56 executor.run_singlethreaded(async move {
57 let mut tasks = fasync::TaskGroup::new();
58 let bounce_bytes = Rc::new(RefCell::new(
59 [MaybeUninit::uninit(); zx::sys::ZX_CHANNEL_MAX_MSG_BYTES as usize],
60 ));
61 let bounce_handles = Rc::new(RefCell::new(
62 [const { MaybeUninit::uninit() }; zx::sys::ZX_CHANNEL_MAX_MSG_HANDLES as usize],
63 ));
64 while let Ok((proxy, events)) = new_proxies.recv().await {
65 let bytes_clone = bounce_bytes.clone();
66 let handles_clone = bounce_handles.clone();
67 tasks.local(start_proxy(proxy, events, bytes_clone, handles_clone));
68 }
69 });
70 });
71}
72
73async fn start_proxy(
79 proxy: ChannelProxy,
80 wake_sources: Arc<Mutex<WakeSources>>,
81 bounce_bytes: Rc<RefCell<[MaybeUninit<u8>; zx::sys::ZX_CHANNEL_MAX_MSG_BYTES as usize]>>,
82 bounce_handles: Rc<
83 RefCell<[MaybeUninit<zx::NullableHandle>; zx::sys::ZX_CHANNEL_MAX_MSG_HANDLES as usize]>,
84 >,
85) {
86 let proxy_name = proxy.name.as_str();
87 trace_instant("starnix_runner:start_proxy:loop:enter", proxy_name);
88
89 'outer: loop {
90 let mut container_wait = pin!(
92 fasync::OnSignals::new(
93 proxy.container_channel.as_handle_ref(),
94 zx::Signals::CHANNEL_READABLE | zx::Signals::CHANNEL_PEER_CLOSED,
95 )
96 .fuse()
97 );
98 let mut remote_wait = pin!(
99 fasync::OnSignals::new(
100 proxy.remote_channel.as_handle_ref(),
101 zx::Signals::CHANNEL_READABLE | zx::Signals::CHANNEL_PEER_CLOSED,
102 )
103 .fuse()
104 );
105
106 let (signals, finished_wait) = {
107 trace_duration("starnix_runner:start_proxy:wait_for_messages", proxy_name);
108 let result = futures::select! {
109 res = container_wait => {
110 trace_instant("starnix_runner:start_proxy:container_readable", proxy_name);
111 res.map(|s| (s, WaitReturn::Container))
112 },
113 res = remote_wait => {
114 trace_instant("starnix_runner:start_proxy:remote_readable", proxy_name);
115 res.map(|s| (s, WaitReturn::Remote))
116 },
117 };
118
119 match result {
120 Ok(result) => result,
121 Err(e) => {
122 trace_instant("starnix_runner:start_proxy:result:error", proxy_name);
123 log::warn!("Failed to wait on proxied channels in runner: {:?}", e);
124 break 'outer;
125 }
126 }
127 };
128
129 let name = proxy.name.as_str();
133 let result = match finished_wait {
134 WaitReturn::Container => forward_message(
135 &signals,
136 &proxy.container_channel,
137 &proxy.remote_channel,
138 None,
139 &mut bounce_bytes.borrow_mut(),
140 &mut bounce_handles.borrow_mut(),
141 name,
142 ),
143 WaitReturn::Remote => forward_message(
144 &signals,
145 &proxy.remote_channel,
146 &proxy.container_channel,
147 Some(&proxy.message_counter),
148 &mut bounce_bytes.borrow_mut(),
149 &mut bounce_handles.borrow_mut(),
150 name,
151 ),
152 };
153
154 if result.is_err() {
155 log::warn!(
156 "Proxy failed to forward message {} kernel: {}; {:?}",
157 match finished_wait {
158 WaitReturn::Container => "from",
159 WaitReturn::Remote => "to",
160 },
161 name,
162 result,
163 );
164 break 'outer;
165 }
166 }
167
168 trace_instant("starnix_runner:start_proxy:loop:exit", proxy_name);
169 if let Ok(koid) = proxy.message_counter.koid() {
170 wake_sources.lock().remove(&koid);
171 }
172}
173
174fn forward_message(
181 signals: &zx::Signals,
182 read_channel: &zx::Channel,
183 write_channel: &zx::Channel,
184 message_counter: Option<&zx::Counter>,
185 bytes: &mut [MaybeUninit<u8>; zx::sys::ZX_CHANNEL_MAX_MSG_BYTES as usize],
186 handles: &mut [MaybeUninit<zx::NullableHandle>; zx::sys::ZX_CHANNEL_MAX_MSG_HANDLES as usize],
187 name: &str,
188) -> Result<(), Error> {
189 trace_duration("starnix_runner:forward_message", name);
190
191 if signals.contains(zx::Signals::CHANNEL_READABLE) {
192 let (actual_bytes, actual_handles) = {
193 match read_channel.read_uninit(bytes, handles) {
194 zx::ChannelReadResult::Ok(r) => r,
195 _ => return Err(anyhow!("Failed to read from channel")),
196 }
197 };
198
199 if let Some(counter) = message_counter {
200 counter.add(1).expect("Failed to add to the proxy's message counter");
201 trace_instant("starnix_runner:forward_message:counter_incremented", name);
202 }
203
204 write_channel.write(actual_bytes, actual_handles)?;
205 }
206
207 if signals.contains(zx::Signals::CHANNEL_PEER_CLOSED) {
210 Err(anyhow!("Proxy peer was closed"))
211 } else {
212 Ok(())
213 }
214}
215
/// Records a duration trace event named `event` in the "power" category, with
/// `name` attached as the "name" argument.
///
/// NOTE(review): `fuchsia_trace::duration!` appears to end the duration when
/// the enclosing scope exits. If so, the recorded duration spans only this
/// helper call rather than the caller's work — confirm against the macro docs.
fn trace_duration(event: &'static str, name: &str) {
    fuchsia_trace::duration!("power", event, "name" => name);
}
219
/// Records a process-scoped instant trace event named `event` in the "power"
/// category, with `name` attached as the "name" argument.
fn trace_instant(event: &'static str, name: &str) {
    fuchsia_trace::instant!(
        "power",
        event,
        fuchsia_trace::Scope::Process,
        "name" => name
    );
}
228
#[cfg(test)]
mod test {
    use super::{ChannelProxy, fasync, start_proxy};
    use fidl::HandleBased;
    use std::cell::RefCell;
    use std::mem::MaybeUninit;
    use std::rc::Rc;

    /// Bytes written to the remote endpoint by the counter tests.
    const PAYLOAD: [u8; 3] = [0x0, 0x1, 0x2];

    /// Runs `proxy` as a local task with freshly allocated bounce buffers and
    /// default wake sources.
    fn run_proxy_for_test(proxy: ChannelProxy) -> fasync::Task<()> {
        let byte_buffer = [MaybeUninit::uninit(); zx::sys::ZX_CHANNEL_MAX_MSG_BYTES as usize];
        let handle_buffer =
            [const { MaybeUninit::uninit() }; zx::sys::ZX_CHANNEL_MAX_MSG_HANDLES as usize];
        fasync::Task::local(start_proxy(
            proxy,
            Default::default(),
            Rc::new(RefCell::new(byte_buffer)),
            Rc::new(RefCell::new(handle_buffer)),
        ))
    }

    /// Builds a proxy named "test" from the given endpoints and counter.
    fn test_proxy(
        container_channel: zx::Channel,
        remote_channel: zx::Channel,
        message_counter: zx::Counter,
    ) -> ChannelProxy {
        ChannelProxy { container_channel, remote_channel, message_counter, name: "test".to_string() }
    }

    /// Writes `PAYLOAD` (no handles) to `channel`, asserting that it succeeds.
    fn send_payload(channel: &zx::Channel) {
        assert!(channel.write(&PAYLOAD, &mut []).is_ok());
    }

    /// Waits until `counter` asserts `signals`.
    async fn wait_for(counter: &zx::Counter, signals: zx::Signals) {
        fasync::OnSignals::new(counter, signals).await.unwrap();
    }

    /// Closing the container-side client must close the remote side.
    #[fuchsia::test]
    async fn test_peer_closed_kernel() {
        let (local_client, local_server) = zx::Channel::create();
        let (remote_client, remote_server) = zx::Channel::create();
        let message_counter = zx::Counter::create();

        let _task = run_proxy_for_test(test_proxy(local_server, remote_client, message_counter));

        drop(local_client);

        fasync::OnSignals::new(remote_server, zx::Signals::CHANNEL_PEER_CLOSED).await.unwrap();
    }

    /// Closing the remote-side server must close the container side.
    #[fuchsia::test]
    async fn test_peer_closed_remote() {
        let (local_client, local_server) = zx::Channel::create();
        let (remote_client, remote_server) = zx::Channel::create();
        let message_counter = zx::Counter::create();

        let _task = run_proxy_for_test(test_proxy(local_server, remote_client, message_counter));

        drop(remote_server);

        fasync::OnSignals::new(local_client, zx::Signals::CHANNEL_PEER_CLOSED).await.unwrap();
    }

    /// The counter goes positive after each remote message, with the test
    /// decrementing it between messages.
    #[fuchsia::test]
    async fn test_counter_sequential() {
        // `_local_client` is kept alive so the proxy does not see a closed peer.
        let (_local_client, local_server) = zx::Channel::create();
        let (remote_client, remote_server) = zx::Channel::create();
        let message_counter = zx::Counter::create();
        let local_message_counter = message_counter
            .duplicate_handle(zx::Rights::SAME_RIGHTS)
            .expect("Failed to duplicate counter");

        let _task = run_proxy_for_test(test_proxy(local_server, remote_client, message_counter));

        wait_for(&local_message_counter, zx::Signals::COUNTER_NON_POSITIVE).await;
        send_payload(&remote_server);
        wait_for(&local_message_counter, zx::Signals::COUNTER_POSITIVE).await;

        local_message_counter.add(-1).expect("Failed add");
        wait_for(&local_message_counter, zx::Signals::COUNTER_NON_POSITIVE).await;
        send_payload(&remote_server);
        wait_for(&local_message_counter, zx::Signals::COUNTER_POSITIVE).await;
    }

    /// Three remote messages written back to back leave the counter at three.
    #[fuchsia::test]
    async fn test_counter_multiple() {
        // `_local_client` is kept alive so the proxy does not see a closed peer.
        let (_local_client, local_server) = zx::Channel::create();
        let (remote_client, remote_server) = zx::Channel::create();
        let message_counter = zx::Counter::create();
        let local_message_counter = message_counter
            .duplicate_handle(zx::Rights::SAME_RIGHTS)
            .expect("Failed to duplicate counter");

        let _task = run_proxy_for_test(test_proxy(local_server, remote_client, message_counter));

        send_payload(&remote_server);
        send_payload(&remote_server);
        send_payload(&remote_server);
        wait_for(&local_message_counter, zx::Signals::COUNTER_POSITIVE).await;
        assert_eq!(local_message_counter.read().expect("Failed to read counter"), 3);
    }
}