1use crate::triggers::{Trigger, TriggerAction, TriggersWatcher};
6use crate::{TracingError, trace_shutdown};
7use async_lock::Mutex;
8use flex_client::{AsyncSocket, ProxyHasDomain, socket_to_async};
9use flex_fuchsia_tracing_controller::{self as trace, StopResult, TraceConfig};
10use fuchsia_async::Task;
11use futures::io::AsyncWrite;
12use futures::prelude::*;
13use futures::task::{Context as FutContext, Poll};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64};
17use std::time::{Duration, Instant};
18use zstd::stream::raw::Operation;
19
/// Source of unique ids for [`TraceTask`]s; incremented (relaxed) for each
/// `TraceTask::new` call. Starts at 100 — presumably to make task ids stand
/// out in logs; the reason is not documented here.
static SERIAL: AtomicU64 = AtomicU64::new(100);
21
/// A running trace session: owns the FIDL session lifecycle, the socket the
/// trace data arrives on, and a background task that stops the session when a
/// timeout elapses or a trigger fires. Awaiting a `TraceTask` (it implements
/// `Future`) yields the session's final `StopResult`, if one was captured.
#[derive(Debug)]
pub struct TraceTask {
    // Unique id allocated from `SERIAL` in `TraceTask::new`.
    task_id: u64,
    // NOTE: after `new()` this holds the full logging prefix
    // "Task {task_id} ({debug_tag})", not the raw tag passed by the caller.
    debug_tag: String,
    // Config the session was initialized with; returned by `config()`.
    config: trace::TraceConfig,
    // Categories the caller asked for (empty if none were given).
    requested_categories: Vec<String>,
    // Optional auto-stop timeout; `None` means trace until explicitly stopped.
    duration: Option<Duration>,
    // Triggers handed to the `TriggersWatcher`; returned by `triggers()`.
    triggers: Vec<Trigger>,
    // Set to true exactly once when shutdown begins (see the
    // compare_exchange in the shutdown future built in `new()`).
    terminating: Arc<AtomicBool>,
    // Timestamp taken at the end of `new()`; exposed via `start_time()`.
    start_time: Instant,
    // Sends the one-shot shutdown signal consumed by the `TriggersWatcher`.
    shutdown_sender: async_channel::Sender<()>,
    // Background task created by `make_task`; resolves with the StopResult
    // recorded by the shutdown future, or `None` if shutdown failed.
    task: Task<Option<trace::StopResult>>,
    // Read side of the stream socket the controller writes trace data to.
    read_socket: AsyncSocket,
    // How `stop_and_receive_data`/`await_completion_and_receive_data` encode
    // the data stream (zstd or pass-through copy).
    compression: trace::CompressionType,
}
49
/// Awaiting a `TraceTask` delegates to its inner background task, which
/// completes once the session has terminated (via timeout or trigger). The
/// output is `Some(StopResult)` when shutdown captured a result, else `None`.
impl Future for TraceTask {
    type Output = Option<trace::StopResult>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut FutContext<'_>) -> Poll<Self::Output> {
        Pin::new(&mut self.task).poll(cx)
    }
}
58
59impl TraceTask {
60 pub async fn new(
61 debug_tag: String,
62 config: trace::TraceConfig,
63 duration: Option<Duration>,
64 triggers: Vec<Trigger>,
65 requested_categories: Option<Vec<String>>,
66 compression: trace::CompressionType,
67 provisioner: trace::ProvisionerProxy,
68 ) -> Result<Self, TracingError> {
69 log::info!("TraceTask::new called with compression: {:?}", compression);
72 let task_id = SERIAL.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
73 let (client, server) = provisioner.domain().create_stream_socket();
74 let (client_end, server_end) = provisioner.domain().create_proxy::<trace::SessionMarker>();
75 provisioner.initialize_tracing(server_end, &config, server)?;
76
77 client_end
78 .start_tracing(&trace::StartOptions::default())
79 .await?
80 .map_err(Into::<TracingError>::into)?;
81
82 let logging_prefix_og = format!("Task {task_id} ({debug_tag})");
83 let terminate_result = Arc::new(Mutex::new(None));
84 let (shutdown_sender, shutdown_receiver) = async_channel::bounded::<()>(1);
85
86 let controller = client_end.clone();
87 let shutdown_controller = client_end.clone();
88 let triggers_watcher =
89 TriggersWatcher::new(controller, triggers.clone(), shutdown_receiver);
90 let terminating = Arc::new(AtomicBool::new(false));
91 let terminating_clone = terminating.clone();
92 let terminate_result_clone = terminate_result.clone();
93 let shutdown_fut = {
94 let logging_prefix = logging_prefix_og.clone();
95 async move {
96 if terminating_clone
97 .compare_exchange(
98 false,
99 true,
100 std::sync::atomic::Ordering::SeqCst,
101 std::sync::atomic::Ordering::Relaxed,
102 )
103 .is_ok()
104 {
105 log::info!("{logging_prefix} Running shutdown future.");
106 let result = trace_shutdown(&shutdown_controller).await;
107
108 let mut done = terminate_result_clone.lock().await;
109 if done.is_none() {
110 match result {
111 Ok(stop) => {
112 log::info!("{logging_prefix} call to trace_shutdown successful.");
113 *done = Some(stop)
114 }
115 Err(e) => {
116 log::error!(
117 "{logging_prefix} call to trace_shutdown failed: {e:?}"
118 );
119 }
120 }
121 }
122 } else {
123 log::debug!("Shutdown already triggered");
124 }
125 "shutdown future completed"
126 }
127 };
128
129 Ok(Self {
130 task_id,
131 debug_tag: logging_prefix_og,
132 config,
133 duration,
134 triggers: triggers.clone(),
135 terminating,
136 requested_categories: requested_categories.unwrap_or_default(),
137 start_time: Instant::now(),
138 shutdown_sender,
139 read_socket: socket_to_async(client),
140 compression,
141 task: Self::make_task(
142 task_id,
143 debug_tag,
144 duration,
145 shutdown_fut,
146 triggers_watcher,
147 terminate_result,
148 ),
149 })
150 }
151
152 async fn shutdown(self) -> Result<trace::StopResult, TracingError> {
154 if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
155 log::info!("{} Sending shutdown message.", self.debug_tag);
156 if self.shutdown_sender.send(()).await.is_err() {
157 log::warn!(
158 "{} Shutdown channel was closed. Task may have already completed.",
159 self.debug_tag
160 );
161 }
162 } else {
163 log::debug!("{} Shutdown already in progress.", self.debug_tag);
164 }
165
166 self.await
167 .map(|r| Ok(r))
168 .unwrap_or_else(|| Err(TracingError::RecordingStop("Error awaiting".into())))
169 }
170
171 fn make_task(
172 task_id: u64,
173 debug_tag: String,
174 duration: Option<Duration>,
175 shutdown_fut: impl Future<Output = &'static str> + 'static + std::marker::Send,
176 trigger_watcher: TriggersWatcher<'static>,
177 terminate_result: Arc<Mutex<Option<StopResult>>>,
178 ) -> Task<Option<trace::StopResult>> {
179 Task::local(async move {
180 let mut timeout_fut = Box::pin(async move {
181 if let Some(duration) = duration {
182 fuchsia_async::Timer::new(duration).await;
183 } else {
184 std::future::pending::<()>().await;
185 }
186 })
187 .fuse();
188 let mut trigger_fut = trigger_watcher.fuse();
189
190 futures::select! {
191 _ = timeout_fut => {
193 log::info!("Trace {task_id} (debug_tag): timeout of {} successfully completed. Stopping and cleaning up.",
194 duration.map(|d| format!("{} secs", d.as_secs())).unwrap_or_else(|| "infinite?".into()));
195
196 shutdown_fut.await;
197 log::debug!("done with timeout!");
198
199 }
200
201 action = trigger_fut => {
203 if let Some(action) = action {
204 match action {
205 TriggerAction::Terminate => {
206 log::info!("Task {task_id} ({debug_tag}): received terminate trigger");
207 }
208 }
209 } else {
210 log::debug!("Task {task_id} ({debug_tag}): Trigger future completed without an action!");
212 }
213 shutdown_fut.await;
214 log::debug!("done with trigger future!");
215 }
216 };
217 log::debug!("end of task waiting for terminate_result lock");
218 let res = terminate_result.lock().await.clone();
219 log::debug!("got res in task is some: {}", res.is_some());
220 res
221 })
222 }
223
224 pub fn triggers(&self) -> Vec<Trigger> {
225 self.triggers.clone()
226 }
227 pub fn config(&self) -> TraceConfig {
228 self.config.clone()
229 }
230
231 pub fn start_time(&self) -> Instant {
232 self.start_time
233 }
234
235 pub fn duration(&self) -> Option<Duration> {
236 self.duration.clone()
237 }
238
239 pub fn requested_categories(&self) -> Vec<String> {
240 self.requested_categories.clone()
241 }
242
243 pub fn task_id(&self) -> u64 {
244 self.task_id
245 }
246
247 pub async fn stop_and_receive_data<W>(
250 self,
251 mut writer: W,
252 ) -> Result<trace::StopResult, TracingError>
253 where
254 W: AsyncWrite + Unpin + Send + 'static,
255 {
256 if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
257 log::info!("{} Sending shutdown message for task", self.debug_tag);
258 if self.shutdown_sender.send(()).await.is_err() {
259 log::warn!(
260 "{} Shutdown channel was closed. Task may have already completed.",
261 self.debug_tag
262 );
263 }
264 } else {
265 log::debug!("{} Shutdown already in progress.", self.debug_tag);
266 }
267
268 let res = match self.compression {
269 trace::CompressionType::Zstd => compress_zstd(&self.read_socket, &mut writer).await,
270 _ => futures::io::copy(&self.read_socket, &mut writer)
271 .await
272 .map(|_| ())
273 .map_err(|e| TracingError::GeneralError(format!("{e:?}"))),
274 };
275
276 if res.is_ok() { self.shutdown().await } else { Err(res.err().unwrap()) }
277 }
278
279 pub async fn await_completion_and_receive_data<W>(
282 self,
283 mut writer: W,
284 ) -> Result<StopResult, TracingError>
285 where
286 W: AsyncWrite + Unpin + Send + 'static,
287 {
288 let res = match self.compression {
289 trace::CompressionType::Zstd => compress_zstd(&self.read_socket, &mut writer).await,
290 _ => futures::io::copy(&self.read_socket, &mut writer)
291 .await
292 .map(|_| ())
293 .map_err(|e| TracingError::RecordingStop(e.to_string())),
294 };
295
296 match res {
297 Ok(_) => match self.await {
298 Some(r) => Ok(r),
299 None => Err(TracingError::RecordingStop("could not await task".into())),
300 },
301 Err(e) => Err(e),
302 }
303 }
304}
305
306async fn compress_zstd<R, W>(mut reader: R, mut writer: W) -> Result<(), TracingError>
307where
308 R: AsyncRead + Unpin,
309 W: AsyncWrite + Unpin,
310{
311 let mut encoder = zstd::stream::raw::Encoder::new(0)
312 .map_err(|e| TracingError::GeneralError(format!("zstd init: {e:?}")))?;
313 let mut input_buf = vec![0u8; 128 * 1024];
315 let mut output_buf = vec![0u8; 128 * 1024];
316
317 while let n = reader
319 .read(&mut input_buf)
320 .await
321 .map_err(|e| TracingError::GeneralError(format!("read: {e:?}")))?
322 && n > 0
323 {
324 let mut read_offset = 0;
325 while read_offset < n {
326 let status = encoder
327 .run_on_buffers(&input_buf[read_offset..n], &mut output_buf)
328 .map_err(|e| TracingError::GeneralError(format!("zstd run: {e:?}")))?;
329 read_offset += status.bytes_read;
330 if status.bytes_written > 0 {
331 writer
332 .write_all(&output_buf[..status.bytes_written])
333 .await
334 .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
335 }
336 }
337 }
338
339 loop {
341 let mut out_wrapper = zstd::stream::raw::OutBuffer::around(&mut output_buf);
342 let remaining = encoder
343 .finish(&mut out_wrapper, true)
344 .map_err(|e| TracingError::GeneralError(format!("zstd finish: {e:?}")))?;
345 let bytes = out_wrapper.as_slice();
346 if !bytes.is_empty() {
347 writer
348 .write_all(bytes)
349 .await
350 .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
351 }
352 if remaining == 0 {
353 break;
354 }
355 }
356 Ok(())
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use flex_fuchsia_tracing_controller::StartError;

    // Payload the fake controller writes to the trace output socket; tests
    // assert this exact string round-trips through the data path.
    const FAKE_CONTROLLER_TRACE_OUTPUT: &'static str = "HOWDY HOWDY HOWDY";

    /// Spins up a fake Provisioner server on a local task.
    ///
    /// `start_error` — when `Some`, StartTracing replies with that error and
    /// StopTracing replies `NotStarted`; when `None`, the session replies Ok,
    /// writes `FAKE_CONTROLLER_TRACE_OUTPUT` on stop, and returns a
    /// StopResult with empty provider_stats.
    /// `trigger_name` — the alert name sent in response to WatchAlert
    /// (empty string when `None`).
    fn setup_fake_provisioner_proxy(
        start_error: Option<StartError>,
        trigger_name: Option<&'static str>,
    ) -> trace::ProvisionerProxy {
        let (proxy, mut stream) =
            fidl::endpoints::create_proxy_and_stream::<trace::ProvisionerMarker>();
        fuchsia_async::Task::local(async move {
            while let Ok(Some(req)) = stream.try_next().await {
                match req {
                    trace::ProvisionerRequest::InitializeTracing { controller, output, .. } => {
                        // Serve the session protocol on the controller channel
                        // handed over by InitializeTracing.
                        let mut stream = controller.into_stream();
                        while let Ok(Some(req)) = stream.try_next().await {
                            match req {
                                trace::SessionRequest::StartTracing { responder, .. } => {
                                    let response = match start_error {
                                        Some(e) => Err(e),
                                        None => Ok(()),
                                    };
                                    responder.send(response).expect("Failed to start")
                                }
                                trace::SessionRequest::StopTracing { responder, payload } => {
                                    if start_error.is_some() {
                                        responder
                                            .send(Err(trace::StopError::NotStarted))
                                            .expect("Failed to stop")
                                    } else {
                                        assert_eq!(payload.write_results.unwrap(), true);
                                        // Emit the fake trace data on the
                                        // output socket before acknowledging
                                        // the stop.
                                        assert_eq!(
                                            FAKE_CONTROLLER_TRACE_OUTPUT.len(),
                                            output
                                                .write(FAKE_CONTROLLER_TRACE_OUTPUT.as_bytes())
                                                .unwrap()
                                        );
                                        let stop_result = trace::StopResult {
                                            provider_stats: Some(vec![]),
                                            ..Default::default()
                                        };
                                        responder.send(Ok(&stop_result)).expect("Failed to stop")
                                    }
                                    // Stop serving the session after the stop
                                    // request has been answered.
                                    break;
                                }
                                trace::SessionRequest::WatchAlert { responder } => {
                                    responder
                                        .send(trigger_name.unwrap_or(""))
                                        .expect("Unable to send alert");
                                }
                                r => panic!("unexpected request: {:#?}", r),
                            }
                        }
                    }
                    r => panic!("unexpected request: {:#?}", r),
                }
            }
        })
        .detach();
        proxy
    }

    // Start/stop round-trip via `shutdown()` only; verifies the StopResult.
    #[fuchsia::test]
    async fn test_trace_task_start_stop_write_check_with_vec() {
        let provisioner = setup_fake_provisioner_proxy(None, None);

        let trace_task = TraceTask::new(
            "test_trace_start_stop_write_check".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result = trace_task.shutdown().await.expect("tracing shutdown");
        assert_eq!(
            shutdown_result,
            trace::StopResult { provider_stats: Some(vec![]), ..Default::default() }.into()
        );
    }

    // Full data path via `stop_and_receive_data`: trace bytes land in a file.
    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_trace_task_start_stop_write_check_with_file() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");

        let provisioner = setup_fake_provisioner_proxy(None, None);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = TraceTask::new(
            "test_trace_start_stop_write_check".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result =
            trace_task.stop_and_receive_data(writer).await.expect("tracing shutdown");

        let res = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
        let expected = trace::StopResult { provider_stats: Some(vec![]), ..Default::default() };
        assert_eq!(shutdown_result, expected);
    }

    // A StartTracing failure must surface from `TraceTask::new`.
    #[fuchsia::test]
    async fn test_trace_error_handling_already_started() {
        let provisioner = setup_fake_provisioner_proxy(Some(StartError::AlreadyStarted), None);

        let trace_task_result = TraceTask::new(
            "test_trace_error_handling_already_started".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .err();

        assert_eq!(trace_task_result, Some(TracingError::RecordingAlreadyStarted));
    }

    // With a duration set, the task should stop itself; data is then fetched
    // via `await_completion_and_receive_data`.
    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_trace_task_start_with_duration() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");

        let provisioner = setup_fake_provisioner_proxy(None, None);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = TraceTask::new(
            "test_trace_task_start_with_duration".into(),
            trace::TraceConfig::default(),
            Some(Duration::from_millis(100)),
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let res = trace_task.await_completion_and_receive_data(writer).await;
        if let Some(ref stop_result) = res.as_ref().ok() {
            assert!(stop_result.provider_stats.is_some());
        } else {
            panic!("Expected stop result from trace_task.await: {res:?}");
        }

        let mut f = async_fs::File::open(std::path::PathBuf::from(output)).await.unwrap();
        let mut res = String::new();
        f.read_to_string(&mut res).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
    }

    // A Terminate trigger (driven by the fake WatchAlert reply) should end
    // the session and still deliver the trace data.
    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_triggers_valid() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");
        let alert_name = "some_alert";
        let provisioner = setup_fake_provisioner_proxy(None, Some(alert_name.into()));
        let writer = async_fs::File::create(output.clone()).await.unwrap();

        let trace_task = TraceTask::new(
            "test_triggers_valid".into(),
            trace::TraceConfig::default(),
            None,
            vec![Trigger {
                alert: Some(alert_name.into()),
                action: Some(TriggerAction::Terminate),
            }],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        trace_task.await_completion_and_receive_data(writer).await.unwrap();
        let res = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
    }
}