1use crate::triggers::{Trigger, TriggerAction, TriggersWatcher};
6use crate::{TracingError, trace_shutdown};
7use async_lock::Mutex;
8use flex_client::{AsyncSocket, ProxyHasDomain, socket_to_async};
9use flex_fuchsia_tracing_controller::{self as trace, StopResult, TraceConfig};
10use fuchsia_async::Task;
11use futures::io::AsyncWrite;
12use futures::prelude::*;
13use futures::task::{Context as FutContext, Poll};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64};
17use std::time::{Duration, Instant};
18use zstd::stream::raw::Operation;
19
/// Source of `TraceTask` identifiers: each new task takes the current value and bumps it
/// by one, so the first task created gets ID 100.
static SERIAL: AtomicU64 = AtomicU64::new(100_u64);
21
22#[derive(Debug)]
23pub struct TraceTask {
24 task_id: u64,
26 debug_tag: String,
28 config: trace::TraceConfig,
30 requested_categories: Vec<String>,
32 duration: Option<Duration>,
34 triggers: Vec<Trigger>,
36 terminating: Arc<AtomicBool>,
38 start_time: Instant,
40 shutdown_sender: async_channel::Sender<()>,
42 task: Task<Option<trace::StopResult>>,
44 read_socket: AsyncSocket,
46 compression: trace::CompressionType,
48 cancelled: Arc<AtomicBool>,
50}
51
52impl Future for TraceTask {
54 type Output = Option<trace::StopResult>;
55
56 fn poll(mut self: Pin<&mut Self>, cx: &mut FutContext<'_>) -> Poll<Self::Output> {
57 Pin::new(&mut self.task).poll(cx)
58 }
59}
60
61impl TraceTask {
62 pub async fn new(
63 debug_tag: String,
64 config: trace::TraceConfig,
65 duration: Option<Duration>,
66 triggers: Vec<Trigger>,
67 requested_categories: Option<Vec<String>>,
68 compression: trace::CompressionType,
69 provisioner: trace::ProvisionerProxy,
70 ) -> Result<Self, TracingError> {
71 log::info!("TraceTask::new called with compression: {:?}", compression);
74 let task_id = SERIAL.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
75 let (client, server) = provisioner.domain().create_stream_socket();
76 let (client_end, server_end) = provisioner.domain().create_proxy::<trace::SessionMarker>();
77 provisioner.initialize_tracing(server_end, &config, server)?;
78
79 client_end
80 .start_tracing(&trace::StartOptions::default())
81 .await?
82 .map_err(Into::<TracingError>::into)?;
83
84 let logging_prefix_og = format!("Task {task_id} ({debug_tag})");
85 let terminate_result = Arc::new(Mutex::new(None));
86 let (shutdown_sender, shutdown_receiver) = async_channel::bounded::<()>(1);
87
88 let controller = client_end.clone();
89 let shutdown_controller = client_end.clone();
90 let triggers_watcher =
91 TriggersWatcher::new(controller, triggers.clone(), shutdown_receiver);
92 let terminating = Arc::new(AtomicBool::new(false));
93 let cancelled = Arc::new(AtomicBool::new(false));
94 let terminating_clone = terminating.clone();
95 let cancelled_clone = cancelled.clone();
96 let terminate_result_clone = terminate_result.clone();
97
98 let shutdown_fut = {
99 let logging_prefix = logging_prefix_og.clone();
100 async move {
101 if terminating_clone
102 .compare_exchange(
103 false,
104 true,
105 std::sync::atomic::Ordering::SeqCst,
106 std::sync::atomic::Ordering::Relaxed,
107 )
108 .is_ok()
109 {
110 log::info!("{logging_prefix} Running shutdown future.");
111 let is_cancelled = cancelled_clone.load(std::sync::atomic::Ordering::Relaxed);
112 let result = trace_shutdown(&shutdown_controller, is_cancelled).await;
113
114 let mut done = terminate_result_clone.lock().await;
115 if done.is_none() {
116 match result {
117 Ok(stop) => {
118 log::info!("{logging_prefix} call to trace_shutdown successful.");
119 *done = Some(stop)
120 }
121 Err(e) => {
122 log::error!(
123 "{logging_prefix} call to trace_shutdown failed: {e:?}"
124 );
125 }
126 }
127 }
128 } else {
129 log::debug!("Shutdown already triggered");
130 }
131 "shutdown future completed"
132 }
133 };
134
135 let task = Self::make_task(
136 task_id,
137 debug_tag,
138 duration,
139 shutdown_fut,
140 triggers_watcher,
141 terminate_result,
142 );
143 Ok(Self {
144 task_id,
145 debug_tag: logging_prefix_og,
146 config,
147 duration,
148 triggers: triggers.clone(),
149 terminating,
150 requested_categories: requested_categories.unwrap_or_default(),
151 start_time: Instant::now(),
152 shutdown_sender,
153 read_socket: socket_to_async(client),
154 compression,
155 task,
156 cancelled,
157 })
158 }
159
160 async fn shutdown(self) -> Result<trace::StopResult, TracingError> {
162 if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
163 log::info!("{} Sending shutdown message.", self.debug_tag);
164 if self.shutdown_sender.send(()).await.is_err() {
165 log::warn!(
166 "{} Shutdown channel was closed. Task may have already completed.",
167 self.debug_tag
168 );
169 }
170 } else {
171 log::debug!("{} Shutdown already in progress.", self.debug_tag);
172 }
173
174 self.await
175 .map(|r| Ok(r))
176 .unwrap_or_else(|| Err(TracingError::RecordingStop("Error awaiting".into())))
177 }
178
179 pub async fn abort(mut self) -> Result<trace::StopResult, TracingError> {
181 if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
182 log::info!("{} Sending cancel message for task", self.debug_tag);
183 self.cancelled.store(true, std::sync::atomic::Ordering::SeqCst);
184 if self.shutdown_sender.send(()).await.is_err() {
185 log::warn!(
186 "{} Shutdown channel was closed. Task may have already completed.",
187 self.debug_tag
188 );
189 }
190 } else {
191 log::debug!("{} Shutdown already in progress.", self.debug_tag);
192 }
193
194 let res = self.read_socket.close().await;
196 if res.is_err() {
197 log::warn!("{} Failed to close socket: {:?}", self.debug_tag, res);
198 }
199 self.shutdown().await
200 }
201
202 fn make_task(
203 task_id: u64,
204 debug_tag: String,
205 duration: Option<Duration>,
206 shutdown_fut: impl Future<Output = &'static str> + 'static + std::marker::Send,
207 trigger_watcher: TriggersWatcher<'static>,
208 terminate_result: Arc<Mutex<Option<StopResult>>>,
209 ) -> Task<Option<trace::StopResult>> {
210 Task::local(async move {
211 let mut timeout_fut = Box::pin(async move {
212 if let Some(duration) = duration {
213 fuchsia_async::Timer::new(duration).await;
214 } else {
215 std::future::pending::<()>().await;
216 }
217 })
218 .fuse();
219 let mut trigger_fut = trigger_watcher.fuse();
220
221 futures::select! {
222 _ = timeout_fut => {
224 log::info!("Trace {task_id} (debug_tag): timeout of {} successfully completed. Stopping and cleaning up.",
225 duration.map(|d| format!("{} secs", d.as_secs())).unwrap_or_else(|| "infinite?".into()));
226
227 shutdown_fut.await;
228 log::debug!("done with timeout!");
229
230 }
231
232 action = trigger_fut => {
234 if let Some(action) = action {
235 match action {
236 TriggerAction::Terminate => {
237 log::info!("Task {task_id} ({debug_tag}): received terminate trigger");
238 }
239 }
240 } else {
241 log::debug!("Task {task_id} ({debug_tag}): Trigger future completed without an action!");
243 }
244 shutdown_fut.await;
245 log::debug!("done with trigger future!");
246 }
247 };
248 log::debug!("end of task waiting for terminate_result lock");
249 let res = terminate_result.lock().await.clone();
250 log::debug!("got res in task is some: {}", res.is_some());
251 res
252 })
253 }
254
255 pub fn triggers(&self) -> Vec<Trigger> {
256 self.triggers.clone()
257 }
258 pub fn config(&self) -> TraceConfig {
259 self.config.clone()
260 }
261
262 pub fn start_time(&self) -> Instant {
263 self.start_time
264 }
265
266 pub fn duration(&self) -> Option<Duration> {
267 self.duration.clone()
268 }
269
270 pub fn requested_categories(&self) -> Vec<String> {
271 self.requested_categories.clone()
272 }
273
274 pub fn task_id(&self) -> u64 {
275 self.task_id
276 }
277
278 pub async fn stop_and_receive_data<W>(
281 self,
282 mut writer: W,
283 ) -> Result<trace::StopResult, TracingError>
284 where
285 W: AsyncWrite + Unpin + Send + 'static,
286 {
287 if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
288 log::info!("{} Sending shutdown message for task", self.debug_tag);
289 if self.shutdown_sender.send(()).await.is_err() {
290 log::warn!(
291 "{} Shutdown channel was closed. Task may have already completed.",
292 self.debug_tag
293 );
294 }
295 } else {
296 log::debug!("{} Shutdown already in progress.", self.debug_tag);
297 }
298
299 let res = match self.compression {
300 trace::CompressionType::Zstd => compress_zstd(&self.read_socket, &mut writer).await,
301 _ => futures::io::copy(&self.read_socket, &mut writer)
302 .await
303 .map(|_| ())
304 .map_err(|e| TracingError::GeneralError(format!("{e:?}"))),
305 };
306
307 if res.is_ok() { self.shutdown().await } else { Err(res.err().unwrap()) }
308 }
309
310 pub async fn await_completion_and_receive_data<W>(
313 self,
314 mut writer: W,
315 ) -> Result<StopResult, TracingError>
316 where
317 W: AsyncWrite + Unpin + Send + 'static,
318 {
319 let res = match self.compression {
320 trace::CompressionType::Zstd => compress_zstd(&self.read_socket, &mut writer).await,
321 _ => futures::io::copy(&self.read_socket, &mut writer)
322 .await
323 .map(|_| ())
324 .map_err(|e| TracingError::RecordingStop(e.to_string())),
325 };
326
327 match res {
328 Ok(_) => match self.await {
329 Some(r) => Ok(r),
330 None => Err(TracingError::RecordingStop("could not await task".into())),
331 },
332 Err(e) => Err(e),
333 }
334 }
335}
336
337async fn compress_zstd<R, W>(mut reader: R, mut writer: W) -> Result<(), TracingError>
338where
339 R: AsyncRead + Unpin,
340 W: AsyncWrite + Unpin,
341{
342 let mut encoder = zstd::stream::raw::Encoder::new(0)
343 .map_err(|e| TracingError::GeneralError(format!("zstd init: {e:?}")))?;
344 let mut input_buf = vec![0u8; 128 * 1024];
346 let mut output_buf = vec![0u8; 128 * 1024];
347
348 while let n = reader
350 .read(&mut input_buf)
351 .await
352 .map_err(|e| TracingError::GeneralError(format!("read: {e:?}")))?
353 && n > 0
354 {
355 let mut read_offset = 0;
356 while read_offset < n {
357 let status = encoder
358 .run_on_buffers(&input_buf[read_offset..n], &mut output_buf)
359 .map_err(|e| TracingError::GeneralError(format!("zstd run: {e:?}")))?;
360 read_offset += status.bytes_read;
361 if status.bytes_written > 0 {
362 writer
363 .write_all(&output_buf[..status.bytes_written])
364 .await
365 .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
366 }
367 }
368 }
369
370 loop {
372 let mut out_wrapper = zstd::stream::raw::OutBuffer::around(&mut output_buf);
373 let remaining = encoder
374 .finish(&mut out_wrapper, true)
375 .map_err(|e| TracingError::GeneralError(format!("zstd finish: {e:?}")))?;
376 let bytes = out_wrapper.as_slice();
377 if !bytes.is_empty() {
378 writer
379 .write_all(bytes)
380 .await
381 .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
382 }
383 if remaining == 0 {
384 break;
385 }
386 }
387 Ok(())
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use flex_fuchsia_tracing_controller::StartError;

    /// Payload the fake session writes into the trace socket when asked to stop.
    const FAKE_CONTROLLER_TRACE_OUTPUT: &str = "HOWDY HOWDY HOWDY";

    /// The stop result the fake session returns from a successful `StopTracing`.
    fn expected_stop_result() -> trace::StopResult {
        trace::StopResult { provider_stats: Some(vec![]), ..Default::default() }
    }

    /// Creates a `ProvisionerProxy` served by an in-process fake.
    ///
    /// The fake session behaves as follows:
    /// - `StartTracing`: replies with `start_error`, or success if it is `None`.
    /// - `WatchAlert`: always replies with `trigger_name`, or `""` if it is `None`.
    /// - `StopTracing`: if `start_error` is set, replies `StopError::NotStarted`.
    ///   Otherwise it asserts that `write_results == expected_write_results`, writes
    ///   `FAKE_CONTROLLER_TRACE_OUTPUT` into the trace socket, and replies with
    ///   `expected_stop_result()`. It then stops serving that session.
    ///
    /// Any other request panics.
    fn setup_fake_provisioner_proxy(
        start_error: Option<StartError>,
        trigger_name: Option<&'static str>,
        expected_write_results: bool,
    ) -> trace::ProvisionerProxy {
        let (proxy, mut stream) =
            fidl::endpoints::create_proxy_and_stream::<trace::ProvisionerMarker>();
        fuchsia_async::Task::local(async move {
            while let Ok(Some(req)) = stream.try_next().await {
                match req {
                    trace::ProvisionerRequest::InitializeTracing { controller, output, .. } => {
                        let mut session = controller.into_stream();
                        while let Ok(Some(req)) = session.try_next().await {
                            match req {
                                trace::SessionRequest::StartTracing { responder, .. } => {
                                    let response = match start_error {
                                        Some(e) => Err(e),
                                        None => Ok(()),
                                    };
                                    responder.send(response).expect("Failed to start")
                                }
                                trace::SessionRequest::StopTracing { responder, payload } => {
                                    if start_error.is_some() {
                                        responder
                                            .send(Err(trace::StopError::NotStarted))
                                            .expect("Failed to stop")
                                    } else {
                                        assert_eq!(
                                            payload.write_results.unwrap(),
                                            expected_write_results
                                        );
                                        assert_eq!(
                                            FAKE_CONTROLLER_TRACE_OUTPUT.len(),
                                            output
                                                .write(FAKE_CONTROLLER_TRACE_OUTPUT.as_bytes())
                                                .unwrap()
                                        );
                                        let stop_result = expected_stop_result();
                                        responder.send(Ok(&stop_result)).expect("Failed to stop")
                                    }
                                    break;
                                }
                                trace::SessionRequest::WatchAlert { responder } => {
                                    responder
                                        .send(trigger_name.unwrap_or(""))
                                        .expect("Unable to send alert");
                                }
                                r => panic!("unexpected request: {:#?}", r),
                            }
                        }
                    }
                    r => panic!("unexpected request: {:#?}", r),
                }
            }
        })
        .detach();
        proxy
    }

    /// Starts a `TraceTask` with a default config, no requested categories and no
    /// compression.
    async fn start_trace_task(
        debug_tag: &str,
        duration: Option<Duration>,
        triggers: Vec<Trigger>,
        provisioner: trace::ProvisionerProxy,
    ) -> Result<TraceTask, TracingError> {
        TraceTask::new(
            debug_tag.into(),
            trace::TraceConfig::default(),
            duration,
            triggers,
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
    }

    #[fuchsia::test]
    async fn test_trace_task_start_stop_write_check_with_vec() {
        let provisioner = setup_fake_provisioner_proxy(None, None, true);

        let trace_task =
            start_trace_task("test_trace_start_stop_write_check", None, vec![], provisioner)
                .await
                .expect("tracing task started");

        let shutdown_result = trace_task.shutdown().await.expect("tracing shutdown");
        assert_eq!(shutdown_result, expected_stop_result());
    }

    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_trace_task_start_stop_write_check_with_file() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");

        let provisioner = setup_fake_provisioner_proxy(None, None, true);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task =
            start_trace_task("test_trace_start_stop_write_check", None, vec![], provisioner)
                .await
                .expect("tracing task started");

        let shutdown_result =
            trace_task.stop_and_receive_data(writer).await.expect("tracing shutdown");

        let contents = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(contents, FAKE_CONTROLLER_TRACE_OUTPUT);
        assert_eq!(shutdown_result, expected_stop_result());
    }

    #[fuchsia::test]
    async fn test_trace_error_handling_already_started() {
        let provisioner =
            setup_fake_provisioner_proxy(Some(StartError::AlreadyStarted), None, true);

        let trace_task_result =
            start_trace_task("test_trace_error_handling_already_started", None, vec![], provisioner)
                .await
                .err();

        assert_eq!(trace_task_result, Some(TracingError::RecordingAlreadyStarted));
    }

    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_trace_task_start_with_duration() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");

        let provisioner = setup_fake_provisioner_proxy(None, None, true);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = start_trace_task(
            "test_trace_task_start_with_duration",
            Some(Duration::from_millis(100)),
            vec![],
            provisioner,
        )
        .await
        .expect("tracing task started");

        let res = trace_task.await_completion_and_receive_data(writer).await;
        let Ok(stop_result) = &res else {
            panic!("Expected stop result from trace_task.await: {res:?}");
        };
        assert!(stop_result.provider_stats.is_some());

        let contents = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(contents, FAKE_CONTROLLER_TRACE_OUTPUT);
    }

    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_triggers_valid() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");
        let alert_name = "some_alert";
        let provisioner = setup_fake_provisioner_proxy(None, Some(alert_name), true);
        let writer = async_fs::File::create(&output).await.unwrap();

        let triggers = vec![Trigger {
            alert: Some(alert_name.into()),
            action: Some(TriggerAction::Terminate),
        }];
        let trace_task = start_trace_task("test_triggers_valid", None, triggers, provisioner)
            .await
            .expect("tracing task started");

        trace_task.await_completion_and_receive_data(writer).await.unwrap();
        let contents = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(contents, FAKE_CONTROLLER_TRACE_OUTPUT);
    }

    #[fuchsia::test]
    async fn test_trace_task_abort() {
        let provisioner = setup_fake_provisioner_proxy(None, None, false);

        let trace_task = start_trace_task("test_trace_task_abort", None, vec![], provisioner)
            .await
            .expect("tracing task started");

        let shutdown_result = trace_task.abort().await.expect("tracing abort");
        assert_eq!(shutdown_result, expected_stop_result());
    }
}