trace_task/
trace_task.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::triggers::{Trigger, TriggerAction, TriggersWatcher};
6use crate::{TracingError, trace_shutdown};
7use async_lock::Mutex;
8use flex_client::{AsyncSocket, ProxyHasDomain, socket_to_async};
9use flex_fuchsia_tracing_controller::{self as trace, StopResult, TraceConfig};
10use fuchsia_async::Task;
11use futures::io::AsyncWrite;
12use futures::prelude::*;
13use futures::task::{Context as FutContext, Poll};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64};
17use std::time::{Duration, Instant};
18use zstd::stream::raw::Operation;
19
/// Next task id to hand out. `TraceTask::new` takes the current value with `fetch_add(1, ..)`,
/// so ids start at 100 and grow by one for every call to `TraceTask::new`.
static SERIAL: AtomicU64 = AtomicU64::new(100);
21
22#[derive(Debug)]
23pub struct TraceTask {
24    /// Unique identifier for this task. The value of this id monotonicallly increases.
25    task_id: u64,
26    /// Tag used to identify this task in the log.
27    debug_tag: String,
28    /// Trace configuration.
29    config: trace::TraceConfig,
30    /// Requested categories. These are unexpanded from the user.
31    requested_categories: Vec<String>,
32    /// Duration to capture trace. None indicates capture until canceled.
33    duration: Option<Duration>,
34    /// Triggers for terminating the trace.
35    triggers: Vec<Trigger>,
36    /// True when the task is cleaning up.
37    terminating: Arc<AtomicBool>,
38    /// Start time of the task.
39    start_time: Instant,
40    /// Channel used to shutdown this task.
41    shutdown_sender: async_channel::Sender<()>,
42    /// The task.
43    task: Task<Option<trace::StopResult>>,
44    /// The socket to read the trace data from when tracing is completed.
45    read_socket: AsyncSocket,
46    /// The compression algorithm to use.
47    compression: trace::CompressionType,
48}
49
50// This is just implemented for convenience so the wrapper is await-able.
51impl Future for TraceTask {
52    type Output = Option<trace::StopResult>;
53
54    fn poll(mut self: Pin<&mut Self>, cx: &mut FutContext<'_>) -> Poll<Self::Output> {
55        Pin::new(&mut self.task).poll(cx)
56    }
57}
58
59impl TraceTask {
60    pub async fn new(
61        debug_tag: String,
62        config: trace::TraceConfig,
63        duration: Option<Duration>,
64        triggers: Vec<Trigger>,
65        requested_categories: Option<Vec<String>>,
66        compression: trace::CompressionType,
67        provisioner: trace::ProvisionerProxy,
68    ) -> Result<Self, TracingError> {
69        // Start the tracing session immediately. Maybe we should consider separating the creating
70        // of the session and the actual starting of it. This seems like a side-effect.
71        log::info!("TraceTask::new called with compression: {:?}", compression);
72        let task_id = SERIAL.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
73        let (client, server) = provisioner.domain().create_stream_socket();
74        let (client_end, server_end) = provisioner.domain().create_proxy::<trace::SessionMarker>();
75        provisioner.initialize_tracing(server_end, &config, server)?;
76
77        client_end
78            .start_tracing(&trace::StartOptions::default())
79            .await?
80            .map_err(Into::<TracingError>::into)?;
81
82        let logging_prefix_og = format!("Task {task_id} ({debug_tag})");
83        let terminate_result = Arc::new(Mutex::new(None));
84        let (shutdown_sender, shutdown_receiver) = async_channel::bounded::<()>(1);
85
86        let controller = client_end.clone();
87        let shutdown_controller = client_end.clone();
88        let triggers_watcher =
89            TriggersWatcher::new(controller, triggers.clone(), shutdown_receiver);
90        let terminating = Arc::new(AtomicBool::new(false));
91        let terminating_clone = terminating.clone();
92        let terminate_result_clone = terminate_result.clone();
93        let shutdown_fut = {
94            let logging_prefix = logging_prefix_og.clone();
95            async move {
96                if terminating_clone
97                    .compare_exchange(
98                        false,
99                        true,
100                        std::sync::atomic::Ordering::SeqCst,
101                        std::sync::atomic::Ordering::Relaxed,
102                    )
103                    .is_ok()
104                {
105                    log::info!("{logging_prefix} Running shutdown future.");
106                    let result = trace_shutdown(&shutdown_controller).await;
107
108                    let mut done = terminate_result_clone.lock().await;
109                    if done.is_none() {
110                        match result {
111                            Ok(stop) => {
112                                log::info!("{logging_prefix} call to trace_shutdown successful.");
113                                *done = Some(stop)
114                            }
115                            Err(e) => {
116                                log::error!(
117                                    "{logging_prefix} call to trace_shutdown failed: {e:?}"
118                                );
119                            }
120                        }
121                    }
122                } else {
123                    log::debug!("Shutdown already triggered");
124                }
125                "shutdown future completed"
126            }
127        };
128
129        Ok(Self {
130            task_id,
131            debug_tag: logging_prefix_og,
132            config,
133            duration,
134            triggers: triggers.clone(),
135            terminating,
136            requested_categories: requested_categories.unwrap_or_default(),
137            start_time: Instant::now(),
138            shutdown_sender,
139            read_socket: socket_to_async(client),
140            compression,
141            task: Self::make_task(
142                task_id,
143                debug_tag,
144                duration,
145                shutdown_fut,
146                triggers_watcher,
147                terminate_result,
148            ),
149        })
150    }
151
152    /// Shutdown the tracing task.
153    async fn shutdown(self) -> Result<trace::StopResult, TracingError> {
154        if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
155            log::info!("{} Sending shutdown message.", self.debug_tag);
156            if self.shutdown_sender.send(()).await.is_err() {
157                log::warn!(
158                    "{} Shutdown channel was closed. Task may have already completed.",
159                    self.debug_tag
160                );
161            }
162        } else {
163            log::debug!("{} Shutdown already in progress.", self.debug_tag);
164        }
165
166        self.await
167            .map(|r| Ok(r))
168            .unwrap_or_else(|| Err(TracingError::RecordingStop("Error awaiting".into())))
169    }
170
171    fn make_task(
172        task_id: u64,
173        debug_tag: String,
174        duration: Option<Duration>,
175        shutdown_fut: impl Future<Output = &'static str> + 'static + std::marker::Send,
176        trigger_watcher: TriggersWatcher<'static>,
177        terminate_result: Arc<Mutex<Option<StopResult>>>,
178    ) -> Task<Option<trace::StopResult>> {
179        Task::local(async move {
180            let mut timeout_fut = Box::pin(async move {
181                if let Some(duration) = duration {
182                    fuchsia_async::Timer::new(duration).await;
183                } else {
184                    std::future::pending::<()>().await;
185                }
186            })
187            .fuse();
188            let mut trigger_fut = trigger_watcher.fuse();
189
190            futures::select! {
191                // Timeout, clean up and wait for copying to finish.
192                _ = timeout_fut => {
193                    log::info!("Trace {task_id} (debug_tag): timeout of {} successfully completed. Stopping and cleaning up.",
194                     duration.map(|d| format!("{} secs", d.as_secs())).unwrap_or_else(|| "infinite?".into()));
195
196                    shutdown_fut.await;
197                     log::debug!("done with timeout!");
198
199                }
200
201                // Trigger hit, shutdown and copy the trace.
202                action = trigger_fut => {
203                    if let Some(action) = action {
204                        match action {
205                            TriggerAction::Terminate => {
206                                log::info!("Task {task_id} ({debug_tag}): received terminate trigger");
207                            }
208                        }
209                    } else {
210                        // This usually means the proxy was closed.
211                        log::debug!("Task {task_id} ({debug_tag}): Trigger future completed without an action!");
212                    }
213                    shutdown_fut.await;
214                     log::debug!("done with trigger future!");
215                }
216            };
217            log::debug!("end of task waiting for terminate_result lock");
218            let res = terminate_result.lock().await.clone();
219            log::debug!("got res in task is some: {}", res.is_some());
220            res
221        })
222    }
223
224    pub fn triggers(&self) -> Vec<Trigger> {
225        self.triggers.clone()
226    }
227    pub fn config(&self) -> TraceConfig {
228        self.config.clone()
229    }
230
231    pub fn start_time(&self) -> Instant {
232        self.start_time
233    }
234
235    pub fn duration(&self) -> Option<Duration> {
236        self.duration.clone()
237    }
238
239    pub fn requested_categories(&self) -> Vec<String> {
240        self.requested_categories.clone()
241    }
242
243    pub fn task_id(&self) -> u64 {
244        self.task_id
245    }
246
247    /// Signals the trace session to stop, copies all trace data to the
248    /// provided writer, and awaits task completion.
249    pub async fn stop_and_receive_data<W>(
250        self,
251        mut writer: W,
252    ) -> Result<trace::StopResult, TracingError>
253    where
254        W: AsyncWrite + Unpin + Send + 'static,
255    {
256        if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
257            log::info!("{} Sending shutdown message for task", self.debug_tag);
258            if self.shutdown_sender.send(()).await.is_err() {
259                log::warn!(
260                    "{} Shutdown channel was closed. Task may have already completed.",
261                    self.debug_tag
262                );
263            }
264        } else {
265            log::debug!("{} Shutdown already in progress.", self.debug_tag);
266        }
267
268        let res = match self.compression {
269            trace::CompressionType::Zstd => compress_zstd(&self.read_socket, &mut writer).await,
270            _ => futures::io::copy(&self.read_socket, &mut writer)
271                .await
272                .map(|_| ())
273                .map_err(|e| TracingError::GeneralError(format!("{e:?}"))),
274        };
275
276        if res.is_ok() { self.shutdown().await } else { Err(res.err().unwrap()) }
277    }
278
279    /// Waits for the tracing task to complete and copies the trace data to the writer.
280    /// If the tracing should be stopped vs. waiting, call |stop_and_receive_data|.
281    pub async fn await_completion_and_receive_data<W>(
282        self,
283        mut writer: W,
284    ) -> Result<StopResult, TracingError>
285    where
286        W: AsyncWrite + Unpin + Send + 'static,
287    {
288        let res = match self.compression {
289            trace::CompressionType::Zstd => compress_zstd(&self.read_socket, &mut writer).await,
290            _ => futures::io::copy(&self.read_socket, &mut writer)
291                .await
292                .map(|_| ())
293                .map_err(|e| TracingError::RecordingStop(e.to_string())),
294        };
295
296        match res {
297            Ok(_) => match self.await {
298                Some(r) => Ok(r),
299                None => Err(TracingError::RecordingStop("could not await task".into())),
300            },
301            Err(e) => Err(e),
302        }
303    }
304}
305
306async fn compress_zstd<R, W>(mut reader: R, mut writer: W) -> Result<(), TracingError>
307where
308    R: AsyncRead + Unpin,
309    W: AsyncWrite + Unpin,
310{
311    let mut encoder = zstd::stream::raw::Encoder::new(0)
312        .map_err(|e| TracingError::GeneralError(format!("zstd init: {e:?}")))?;
313    // 128KB is the recommended size for Zstd (ZSTD_CStreamInSize/ZSTD_CStreamOutSize)
314    let mut input_buf = vec![0u8; 128 * 1024];
315    let mut output_buf = vec![0u8; 128 * 1024];
316
317    // Read the stream until EOF, compressing and writing out fully compressed buffers as we go.
318    while let n = reader
319        .read(&mut input_buf)
320        .await
321        .map_err(|e| TracingError::GeneralError(format!("read: {e:?}")))?
322        && n > 0
323    {
324        let mut read_offset = 0;
325        while read_offset < n {
326            let status = encoder
327                .run_on_buffers(&input_buf[read_offset..n], &mut output_buf)
328                .map_err(|e| TracingError::GeneralError(format!("zstd run: {e:?}")))?;
329            read_offset += status.bytes_read;
330            if status.bytes_written > 0 {
331                writer
332                    .write_all(&output_buf[..status.bytes_written])
333                    .await
334                    .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
335            }
336        }
337    }
338
339    // Flush remaining of the last compressed buffer.
340    loop {
341        let mut out_wrapper = zstd::stream::raw::OutBuffer::around(&mut output_buf);
342        let remaining = encoder
343            .finish(&mut out_wrapper, true)
344            .map_err(|e| TracingError::GeneralError(format!("zstd finish: {e:?}")))?;
345        let bytes = out_wrapper.as_slice();
346        if !bytes.is_empty() {
347            writer
348                .write_all(bytes)
349                .await
350                .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
351        }
352        if remaining == 0 {
353            break;
354        }
355    }
356    Ok(())
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use flex_fuchsia_tracing_controller::StartError;

    const FAKE_CONTROLLER_TRACE_OUTPUT: &str = "HOWDY HOWDY HOWDY";

    /// Spawns a fake provisioner and returns a proxy connected to it.
    ///
    /// * `start_error` - when set, `StartTracing` replies with this error and `StopTracing`
    ///   replies with `StopError::NotStarted`.
    /// * `trigger_name` - alert name returned from `WatchAlert`; an empty string when `None`.
    ///
    /// On a successful `StopTracing` the fake writes `FAKE_CONTROLLER_TRACE_OUTPUT` to the
    /// output socket and stops serving the session.
    fn setup_fake_provisioner_proxy(
        start_error: Option<StartError>,
        trigger_name: Option<&'static str>,
    ) -> trace::ProvisionerProxy {
        let (proxy, mut stream) =
            fidl::endpoints::create_proxy_and_stream::<trace::ProvisionerMarker>();
        fuchsia_async::Task::local(async move {
            while let Ok(Some(req)) = stream.try_next().await {
                match req {
                    trace::ProvisionerRequest::InitializeTracing { controller, output, .. } => {
                        let mut stream = controller.into_stream();
                        while let Ok(Some(req)) = stream.try_next().await {
                            match req {
                                trace::SessionRequest::StartTracing { responder, .. } => {
                                    let response = match start_error {
                                        Some(e) => Err(e),
                                        None => Ok(()),
                                    };
                                    responder.send(response).expect("Failed to start")
                                }
                                trace::SessionRequest::StopTracing { responder, payload } => {
                                    if start_error.is_some() {
                                        responder
                                            .send(Err(trace::StopError::NotStarted))
                                            .expect("Failed to stop")
                                    } else {
                                        assert!(payload.write_results.unwrap());
                                        assert_eq!(
                                            FAKE_CONTROLLER_TRACE_OUTPUT.len(),
                                            output
                                                .write(FAKE_CONTROLLER_TRACE_OUTPUT.as_bytes())
                                                .unwrap()
                                        );
                                        let stop_result = trace::StopResult {
                                            provider_stats: Some(vec![]),
                                            ..Default::default()
                                        };
                                        responder.send(Ok(&stop_result)).expect("Failed to stop")
                                    }
                                    // Leaving the loop drops the session stream and `output`.
                                    break;
                                }
                                trace::SessionRequest::WatchAlert { responder } => {
                                    responder
                                        .send(trigger_name.unwrap_or(""))
                                        .expect("Unable to send alert");
                                }
                                r => panic!("unexpected request: {:#?}", r),
                            }
                        }
                    }
                    r => panic!("unexpected request: {:#?}", r),
                }
            }
        })
        .detach();
        proxy
    }

    #[fuchsia::test]
    async fn test_trace_task_start_stop_write_check_with_vec() {
        let provisioner = setup_fake_provisioner_proxy(None, None);

        let trace_task = TraceTask::new(
            "test_trace_start_stop_write_check".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result = trace_task.shutdown().await.expect("tracing shutdown");
        let expected = trace::StopResult { provider_stats: Some(vec![]), ..Default::default() };
        assert_eq!(shutdown_result, expected);
    }

    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_trace_task_start_stop_write_check_with_file() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");

        let provisioner = setup_fake_provisioner_proxy(None, None);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = TraceTask::new(
            "test_trace_start_stop_write_check".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result =
            trace_task.stop_and_receive_data(writer).await.expect("tracing shutdown");

        let res = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
        let expected = trace::StopResult { provider_stats: Some(vec![]), ..Default::default() };
        assert_eq!(shutdown_result, expected);
    }

    #[fuchsia::test]
    async fn test_trace_error_handling_already_started() {
        let provisioner = setup_fake_provisioner_proxy(Some(StartError::AlreadyStarted), None);

        let trace_task_result = TraceTask::new(
            "test_trace_error_handling_already_started".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .err();

        assert_eq!(trace_task_result, Some(TracingError::RecordingAlreadyStarted));
    }

    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_trace_task_start_with_duration() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");

        let provisioner = setup_fake_provisioner_proxy(None, None);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = TraceTask::new(
            "test_trace_task_start_with_duration".into(),
            trace::TraceConfig::default(),
            Some(Duration::from_millis(100)),
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let res = trace_task.await_completion_and_receive_data(writer).await;
        match &res {
            Ok(stop_result) => assert!(stop_result.provider_stats.is_some()),
            Err(_) => panic!("Expected stop result from trace_task.await: {res:?}"),
        }

        let res = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
    }

    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_triggers_valid() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");
        let alert_name = "some_alert";
        let provisioner = setup_fake_provisioner_proxy(None, Some(alert_name));
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = TraceTask::new(
            "test_triggers_valid".into(),
            trace::TraceConfig::default(),
            None,
            vec![Trigger {
                alert: Some(alert_name.into()),
                action: Some(TriggerAction::Terminate),
            }],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        trace_task.await_completion_and_receive_data(writer).await.unwrap();
        let res = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
    }

    // Input larger than the 128KB chunk size, so `compress_zstd` must perform several reads.
    #[fuchsia::test]
    async fn test_compress_zstd_round_trip() {
        let input = FAKE_CONTROLLER_TRACE_OUTPUT.repeat(16 * 1024).into_bytes();
        assert!(input.len() > 128 * 1024);

        let mut compressed = Vec::new();
        compress_zstd(futures::io::Cursor::new(input.clone()), &mut compressed)
            .await
            .expect("compression succeeds");

        let decompressed =
            zstd::stream::decode_all(compressed.as_slice()).expect("output is a valid zstd frame");
        assert_eq!(decompressed, input);
    }

    // An empty stream must still produce a decodable (empty) frame.
    #[fuchsia::test]
    async fn test_compress_zstd_empty_input() {
        let mut compressed = Vec::new();
        compress_zstd(futures::io::Cursor::new(Vec::<u8>::new()), &mut compressed)
            .await
            .expect("compression succeeds");

        let decompressed =
            zstd::stream::decode_all(compressed.as_slice()).expect("output is a valid zstd frame");
        assert!(decompressed.is_empty());
    }
}