1use crate::triggers::{Trigger, TriggerAction, TriggersWatcher};
6use crate::{TracingError, trace_shutdown};
7use async_lock::Mutex;
8use flex_client::{AsyncSocket, ProxyHasDomain, socket_to_async};
9use flex_fuchsia_tracing_controller::{self as trace, StopResult, TraceConfig};
10use fuchsia_async::Task;
11use futures::io::AsyncWrite;
12use futures::prelude::*;
13use futures::task::{Context as FutContext, Poll};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64};
17use std::time::{Duration, Instant};
18use zstd::stream::raw::Operation;
19
/// Source of task ids: every `TraceTask::new` call takes the current value and bumps it by one,
/// so ids handed out by this process start at 100 and never repeat.
static SERIAL: AtomicU64 = AtomicU64::new(100u64);
21
/// Handle to a single tracing session started by `TraceTask::new`.
///
/// The handle implements `Future`: awaiting it waits for the background task that stops the
/// session and yields the `StopResult` that task recorded (or `None` if none was recorded).
#[derive(Debug)]
pub struct TraceTask {
    /// Id taken from `SERIAL` when the task was created.
    task_id: u64,
    /// Log prefix of the form `"Task {task_id} ({debug_tag})"`.
    debug_tag: String,
    /// Configuration passed to the provisioner when tracing was initialized.
    config: trace::TraceConfig,
    /// Categories the caller asked for (empty when none were given).
    requested_categories: Vec<String>,
    /// Optional timeout after which the background task stops tracing on its own.
    duration: Option<Duration>,
    /// Triggers supplied at creation; a copy is also handed to the `TriggersWatcher`.
    triggers: Vec<Trigger>,
    /// Set (once) by the shutdown future when it starts stopping the session.
    terminating: Arc<AtomicBool>,
    /// Time at which this handle was built, i.e. after tracing started successfully.
    start_time: Instant,
    /// Sends on the channel whose receiver was given to the `TriggersWatcher`; used to ask the
    /// background task to shut down.
    shutdown_sender: async_channel::Sender<()>,
    /// Background local task that stops tracing and yields the recorded `StopResult`.
    task: Task<Option<trace::StopResult>>,
    /// Client end of the socket trace data is read from; `None` once it has been taken.
    read_socket: Option<AsyncSocket>,
    /// Selects whether socket data is zstd-compressed or copied as-is to the caller's writer.
    compression: trace::CompressionType,
    /// Set by `abort`; read by the shutdown future and forwarded to `trace_shutdown`.
    cancelled: Arc<AtomicBool>,
}
51
52impl Future for TraceTask {
54 type Output = Option<trace::StopResult>;
55
56 fn poll(mut self: Pin<&mut Self>, cx: &mut FutContext<'_>) -> Poll<Self::Output> {
57 Pin::new(&mut self.task).poll(cx)
58 }
59}
60
61impl TraceTask {
62 pub async fn new(
63 debug_tag: String,
64 config: trace::TraceConfig,
65 duration: Option<Duration>,
66 triggers: Vec<Trigger>,
67 requested_categories: Option<Vec<String>>,
68 compression: trace::CompressionType,
69 provisioner: trace::ProvisionerProxy,
70 ) -> Result<Self, TracingError> {
71 log::info!("TraceTask::new called with compression: {:?}", compression);
74 let task_id = SERIAL.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
75 let (client, server) = provisioner.domain().create_stream_socket();
76 let (client_end, server_end) = provisioner.domain().create_proxy::<trace::SessionMarker>();
77 provisioner.initialize_tracing(server_end, &config, server)?;
78
79 client_end
80 .start_tracing(&trace::StartOptions::default())
81 .await?
82 .map_err(Into::<TracingError>::into)?;
83
84 let logging_prefix_og = format!("Task {task_id} ({debug_tag})");
85 let terminate_result = Arc::new(Mutex::new(None));
86 let (shutdown_sender, shutdown_receiver) = async_channel::bounded::<()>(1);
87
88 let controller = client_end.clone();
89 let shutdown_controller = client_end.clone();
90 let triggers_watcher =
91 TriggersWatcher::new(controller, triggers.clone(), shutdown_receiver);
92 let terminating = Arc::new(AtomicBool::new(false));
93 let cancelled = Arc::new(AtomicBool::new(false));
94 let terminating_clone = terminating.clone();
95 let cancelled_clone = cancelled.clone();
96 let terminate_result_clone = terminate_result.clone();
97
98 let shutdown_fut = {
99 let logging_prefix = logging_prefix_og.clone();
100 async move {
101 if terminating_clone
102 .compare_exchange(
103 false,
104 true,
105 std::sync::atomic::Ordering::SeqCst,
106 std::sync::atomic::Ordering::Relaxed,
107 )
108 .is_ok()
109 {
110 log::info!("{logging_prefix} Running shutdown future.");
111 let is_cancelled = cancelled_clone.load(std::sync::atomic::Ordering::Relaxed);
112 let result = trace_shutdown(&shutdown_controller, is_cancelled).await;
113
114 let mut done = terminate_result_clone.lock().await;
115 if done.is_none() {
116 match result {
117 Ok(stop) => {
118 log::info!("{logging_prefix} call to trace_shutdown successful.");
119 *done = Some(stop)
120 }
121 Err(e) => {
122 log::error!(
123 "{logging_prefix} call to trace_shutdown failed: {e:?}"
124 );
125 }
126 }
127 }
128 } else {
129 log::debug!("Shutdown already triggered");
130 }
131 "shutdown future completed"
132 }
133 };
134
135 let task = Self::make_task(
136 task_id,
137 debug_tag,
138 duration,
139 shutdown_fut,
140 triggers_watcher,
141 terminate_result,
142 );
143 Ok(Self {
144 task_id,
145 debug_tag: logging_prefix_og,
146 config,
147 duration,
148 triggers: triggers.clone(),
149 terminating,
150 requested_categories: requested_categories.unwrap_or_default(),
151 start_time: Instant::now(),
152 shutdown_sender,
153 read_socket: Some(socket_to_async(client)),
154 compression,
155 task,
156 cancelled,
157 })
158 }
159
160 async fn shutdown(self) -> Result<trace::StopResult, TracingError> {
162 if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
163 log::info!("{} Sending shutdown message.", self.debug_tag);
164 if self.shutdown_sender.send(()).await.is_err() {
165 log::warn!(
166 "{} Shutdown channel was closed. Task may have already completed.",
167 self.debug_tag
168 );
169 }
170 } else {
171 log::debug!("{} Shutdown already in progress.", self.debug_tag);
172 }
173
174 self.await
175 .map(|r| Ok(r))
176 .unwrap_or_else(|| Err(TracingError::RecordingStop("Error awaiting".into())))
177 }
178
179 pub async fn abort(mut self) -> Result<trace::StopResult, TracingError> {
181 if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
182 log::info!("{} Sending cancel message for task", self.debug_tag);
183 self.cancelled.store(true, std::sync::atomic::Ordering::SeqCst);
184 if self.shutdown_sender.send(()).await.is_err() {
185 log::warn!(
186 "{} Shutdown channel was closed. Task may have already completed.",
187 self.debug_tag
188 );
189 }
190 } else {
191 log::debug!("{} Shutdown already in progress.", self.debug_tag);
192 }
193
194 if let Some(mut socket) = self.read_socket.take() {
196 let res = socket.close().await;
197 if res.is_err() {
198 log::warn!("{} Failed to close socket: {:?}", self.debug_tag, res);
199 }
200 drop(socket);
201 }
202 self.shutdown().await
203 }
204
205 fn make_task(
206 task_id: u64,
207 debug_tag: String,
208 duration: Option<Duration>,
209 shutdown_fut: impl Future<Output = &'static str> + 'static + std::marker::Send,
210 trigger_watcher: TriggersWatcher<'static>,
211 terminate_result: Arc<Mutex<Option<StopResult>>>,
212 ) -> Task<Option<trace::StopResult>> {
213 Task::local(async move {
214 let mut timeout_fut = Box::pin(async move {
215 if let Some(duration) = duration {
216 fuchsia_async::Timer::new(duration).await;
217 } else {
218 std::future::pending::<()>().await;
219 }
220 })
221 .fuse();
222 let mut trigger_fut = trigger_watcher.fuse();
223
224 futures::select! {
225 _ = timeout_fut => {
227 log::info!("Trace {task_id} (debug_tag): timeout of {} successfully completed. Stopping and cleaning up.",
228 duration.map(|d| format!("{} secs", d.as_secs())).unwrap_or_else(|| "infinite?".into()));
229
230 shutdown_fut.await;
231 log::debug!("done with timeout!");
232
233 }
234
235 action = trigger_fut => {
237 if let Some(action) = action {
238 match action {
239 TriggerAction::Terminate => {
240 log::info!("Task {task_id} ({debug_tag}): received terminate trigger");
241 }
242 }
243 } else {
244 log::debug!("Task {task_id} ({debug_tag}): Trigger future completed without an action!");
246 }
247 shutdown_fut.await;
248 log::debug!("done with trigger future!");
249 }
250 };
251 log::debug!("end of task waiting for terminate_result lock");
252 let res = terminate_result.lock().await.clone();
253 log::debug!("got res in task is some: {}", res.is_some());
254 res
255 })
256 }
257
258 pub fn triggers(&self) -> Vec<Trigger> {
259 self.triggers.clone()
260 }
261 pub fn config(&self) -> TraceConfig {
262 self.config.clone()
263 }
264
265 pub fn start_time(&self) -> Instant {
266 self.start_time
267 }
268
269 pub fn duration(&self) -> Option<Duration> {
270 self.duration.clone()
271 }
272
273 pub fn requested_categories(&self) -> Vec<String> {
274 self.requested_categories.clone()
275 }
276
277 pub fn task_id(&self) -> u64 {
278 self.task_id
279 }
280
281 pub async fn stop_and_receive_data<W>(
284 mut self,
285 mut writer: W,
286 ) -> Result<trace::StopResult, TracingError>
287 where
288 W: AsyncWrite + Unpin + Send + 'static,
289 {
290 if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
291 log::info!("{} Sending shutdown message for task", self.debug_tag);
292 if self.shutdown_sender.send(()).await.is_err() {
293 log::warn!(
294 "{} Shutdown channel was closed. Task may have already completed.",
295 self.debug_tag
296 );
297 }
298 } else {
299 log::debug!("{} Shutdown already in progress.", self.debug_tag);
300 }
301
302 let mut read_socket = self.read_socket.take().unwrap();
303 let res = match self.compression {
304 trace::CompressionType::Zstd => compress_zstd(&mut read_socket, &mut writer).await,
305 _ => futures::io::copy(&mut read_socket, &mut writer)
306 .await
307 .map(|_| ())
308 .map_err(|e| TracingError::GeneralError(format!("{e:?}"))),
309 };
310
311 if res.is_ok() { self.shutdown().await } else { Err(res.err().unwrap()) }
312 }
313
314 pub async fn await_completion_and_receive_data<W>(
317 mut self,
318 mut writer: W,
319 ) -> Result<StopResult, TracingError>
320 where
321 W: AsyncWrite + Unpin + Send + 'static,
322 {
323 let mut read_socket = self.read_socket.take().unwrap();
324 let res = match self.compression {
325 trace::CompressionType::Zstd => compress_zstd(&mut read_socket, &mut writer).await,
326 _ => futures::io::copy(&mut read_socket, &mut writer)
327 .await
328 .map(|_| ())
329 .map_err(|e| TracingError::RecordingStop(e.to_string())),
330 };
331
332 match res {
333 Ok(_) => match self.await {
334 Some(r) => Ok(r),
335 None => Err(TracingError::RecordingStop("could not await task".into())),
336 },
337 Err(e) => Err(e),
338 }
339 }
340}
341
342async fn compress_zstd<R, W>(mut reader: R, mut writer: W) -> Result<(), TracingError>
343where
344 R: AsyncRead + Unpin,
345 W: AsyncWrite + Unpin,
346{
347 let mut encoder = zstd::stream::raw::Encoder::new(0)
348 .map_err(|e| TracingError::GeneralError(format!("zstd init: {e:?}")))?;
349 let mut input_buf = vec![0u8; 128 * 1024];
351 let mut output_buf = vec![0u8; 128 * 1024];
352
353 while let n = reader
355 .read(&mut input_buf)
356 .await
357 .map_err(|e| TracingError::GeneralError(format!("read: {e:?}")))?
358 && n > 0
359 {
360 let mut read_offset = 0;
361 while read_offset < n {
362 let status = encoder
363 .run_on_buffers(&input_buf[read_offset..n], &mut output_buf)
364 .map_err(|e| TracingError::GeneralError(format!("zstd run: {e:?}")))?;
365 read_offset += status.bytes_read;
366 if status.bytes_written > 0 {
367 writer
368 .write_all(&output_buf[..status.bytes_written])
369 .await
370 .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
371 }
372 }
373 }
374
375 loop {
377 let mut out_wrapper = zstd::stream::raw::OutBuffer::around(&mut output_buf);
378 let remaining = encoder
379 .finish(&mut out_wrapper, true)
380 .map_err(|e| TracingError::GeneralError(format!("zstd finish: {e:?}")))?;
381 let bytes = out_wrapper.as_slice();
382 if !bytes.is_empty() {
383 writer
384 .write_all(bytes)
385 .await
386 .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
387 }
388 if remaining == 0 {
389 break;
390 }
391 }
392 Ok(())
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use flex_fuchsia_tracing_controller::StartError;

    /// Bytes the fake session writes into the trace socket when tracing stops successfully.
    const FAKE_CONTROLLER_TRACE_OUTPUT: &str = "HOWDY HOWDY HOWDY";

    /// Serves a fake `Provisioner` on a detached local task and returns a proxy to it.
    ///
    /// * `start_error` — when set, `StartTracing` fails with it and `StopTracing` answers
    ///   `StopError::NotStarted`.
    /// * `trigger_name` — alert name sent back for every `WatchAlert` (`""` when `None`).
    /// * `expected_write_results` — value `StopTracing`'s `write_results` must carry.
    ///
    /// On a successful stop the fake writes `FAKE_CONTROLLER_TRACE_OUTPUT` into the output
    /// socket. After answering `StopTracing` it stops serving that session and drops its end
    /// of the socket.
    fn setup_fake_provisioner_proxy(
        start_error: Option<StartError>,
        trigger_name: Option<&'static str>,
        expected_write_results: bool,
    ) -> trace::ProvisionerProxy {
        let (proxy, mut stream) =
            fidl::endpoints::create_proxy_and_stream::<trace::ProvisionerMarker>();
        fuchsia_async::Task::local(async move {
            while let Ok(Some(req)) = stream.try_next().await {
                match req {
                    trace::ProvisionerRequest::InitializeTracing { controller, output, .. } => {
                        let mut stream = controller.into_stream();
                        while let Ok(Some(req)) = stream.try_next().await {
                            match req {
                                trace::SessionRequest::StartTracing { responder, .. } => {
                                    let response = match start_error {
                                        Some(e) => Err(e),
                                        None => Ok(()),
                                    };
                                    responder.send(response).expect("Failed to start")
                                }
                                trace::SessionRequest::StopTracing { responder, payload } => {
                                    if start_error.is_some() {
                                        responder
                                            .send(Err(trace::StopError::NotStarted))
                                            .expect("Failed to stop")
                                    } else {
                                        assert_eq!(
                                            payload.write_results.unwrap(),
                                            expected_write_results
                                        );
                                        let _ =
                                            output.write(FAKE_CONTROLLER_TRACE_OUTPUT.as_bytes());
                                        let stop_result = trace::StopResult {
                                            provider_stats: Some(vec![]),
                                            ..Default::default()
                                        };
                                        responder.send(Ok(&stop_result)).expect("Failed to stop")
                                    }
                                    break;
                                }
                                trace::SessionRequest::WatchAlert { responder } => {
                                    responder
                                        .send(trigger_name.unwrap_or(""))
                                        .expect("Unable to send alert");
                                }
                                r => panic!("unexpected request: {:#?}", r),
                            }
                        }
                    }
                    r => panic!("unexpected request: {:#?}", r),
                }
            }
        })
        .detach();
        proxy
    }

    /// Starting and then shutting down yields the fake's `StopResult`.
    #[fuchsia::test]
    async fn test_trace_task_start_stop_write_check_with_vec() {
        let provisioner = setup_fake_provisioner_proxy(None, None, true);

        let trace_task = TraceTask::new(
            "test_trace_start_stop_write_check".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result = trace_task.shutdown().await.expect("tracing shutdown");
        assert_eq!(
            shutdown_result,
            trace::StopResult { provider_stats: Some(vec![]), ..Default::default() }
        );
    }

    /// `stop_and_receive_data` writes the uncompressed trace bytes to a file.
    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_trace_task_start_stop_write_check_with_file() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");

        let provisioner = setup_fake_provisioner_proxy(None, None, true);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = TraceTask::new(
            "test_trace_start_stop_write_check".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result =
            trace_task.stop_and_receive_data(writer).await.expect("tracing shutdown");

        let res = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT);
        let expected = trace::StopResult { provider_stats: Some(vec![]), ..Default::default() };
        assert_eq!(shutdown_result, expected);
    }

    /// A start failure from the session surfaces as `RecordingAlreadyStarted`.
    #[fuchsia::test]
    async fn test_trace_error_handling_already_started() {
        let provisioner =
            setup_fake_provisioner_proxy(Some(StartError::AlreadyStarted), None, true);

        let trace_task_result = TraceTask::new(
            "test_trace_error_handling_already_started".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .err();

        assert_eq!(trace_task_result, Some(TracingError::RecordingAlreadyStarted));
    }

    /// With a duration the session stops on its own and the data still reaches the file.
    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_trace_task_start_with_duration() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");

        let provisioner = setup_fake_provisioner_proxy(None, None, true);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = TraceTask::new(
            "test_trace_task_start_with_duration".into(),
            trace::TraceConfig::default(),
            Some(Duration::from_millis(100)),
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let res = trace_task.await_completion_and_receive_data(writer).await;
        match &res {
            Ok(stop_result) => assert!(stop_result.provider_stats.is_some()),
            Err(_) => panic!("Expected stop result from trace_task.await: {res:?}"),
        }

        let contents = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(contents, FAKE_CONTROLLER_TRACE_OUTPUT);
    }

    /// A matching alert with a `Terminate` action ends the session.
    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_triggers_valid() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");
        let alert_name = "some_alert";
        let provisioner = setup_fake_provisioner_proxy(None, Some(alert_name), true);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = TraceTask::new(
            "test_triggers_valid".into(),
            trace::TraceConfig::default(),
            None,
            vec![Trigger {
                alert: Some(alert_name.into()),
                action: Some(TriggerAction::Terminate),
            }],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        trace_task.await_completion_and_receive_data(writer).await.unwrap();
        let res = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT);
    }

    /// Aborting stops the session with `write_results == false` and still returns a result.
    #[fuchsia::test]
    async fn test_trace_task_abort() {
        let provisioner = setup_fake_provisioner_proxy(None, None, false);

        let trace_task = TraceTask::new(
            "test_trace_task_abort".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result = trace_task.abort().await.expect("tracing abort");
        assert_eq!(
            shutdown_result,
            trace::StopResult { provider_stats: Some(vec![]), ..Default::default() }
        );
    }

    /// `compress_zstd` produces a valid frame even when the input spans several read buffers.
    #[fuchsia::test]
    async fn test_compress_zstd_round_trip() {
        let input = FAKE_CONTROLLER_TRACE_OUTPUT.repeat(10_000);
        let mut compressed = Vec::new();
        compress_zstd(input.as_bytes(), &mut compressed).await.expect("compress");
        let decompressed = zstd::stream::decode_all(&compressed[..]).expect("decompress");
        assert_eq!(decompressed, input.as_bytes());
    }
}