// trace_task.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::triggers::{Trigger, TriggerAction, TriggersWatcher};
6use crate::{TracingError, trace_shutdown};
7use async_lock::Mutex;
8use flex_client::{AsyncSocket, ProxyHasDomain, socket_to_async};
9use flex_fuchsia_tracing_controller::{self as trace, StopResult, TraceConfig};
10use fuchsia_async::Task;
11use futures::io::AsyncWrite;
12use futures::prelude::*;
13use futures::task::{Context as FutContext, Poll};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64};
17use std::time::{Duration, Instant};
18use zstd::stream::raw::Operation;
19
20static SERIAL: AtomicU64 = AtomicU64::new(100);
21
#[derive(Debug)]
pub struct TraceTask {
    /// Unique identifier for this task. The value of this id monotonically increases.
    task_id: u64,
    /// Tag used to identify this task in the log. Note: after construction this
    /// holds the full logging prefix ("Task <id> (<tag>)"), not the raw tag.
    debug_tag: String,
    /// Trace configuration.
    config: trace::TraceConfig,
    /// Requested categories. These are unexpanded from the user.
    requested_categories: Vec<String>,
    /// Duration to capture trace. None indicates capture until canceled.
    duration: Option<Duration>,
    /// Triggers for terminating the trace.
    triggers: Vec<Trigger>,
    /// True when the task is cleaning up.
    terminating: Arc<AtomicBool>,
    /// Start time of the task.
    start_time: Instant,
    /// Channel used to shutdown this task.
    shutdown_sender: async_channel::Sender<()>,
    /// The task.
    task: Task<Option<trace::StopResult>>,
    /// The socket to read the trace data from when tracing is completed.
    read_socket: Option<AsyncSocket>,
    /// The compression algorithm to use.
    compression: trace::CompressionType,
    /// True when the task was cancelled (aborted).
    cancelled: Arc<AtomicBool>,
}
51
52// This is just implemented for convenience so the wrapper is await-able.
53impl Future for TraceTask {
54    type Output = Option<trace::StopResult>;
55
56    fn poll(mut self: Pin<&mut Self>, cx: &mut FutContext<'_>) -> Poll<Self::Output> {
57        Pin::new(&mut self.task).poll(cx)
58    }
59}
60
61impl TraceTask {
    /// Creates a new `TraceTask`, immediately initializing and starting the
    /// tracing session via `provisioner`.
    ///
    /// A background task is spawned (see `make_task`) that waits for either
    /// the optional `duration` timeout or a trigger and then shuts the
    /// session down.
    ///
    /// # Errors
    /// Returns a `TracingError` if initializing the session or starting
    /// tracing fails.
    pub async fn new(
        debug_tag: String,
        config: trace::TraceConfig,
        duration: Option<Duration>,
        triggers: Vec<Trigger>,
        requested_categories: Option<Vec<String>>,
        compression: trace::CompressionType,
        provisioner: trace::ProvisionerProxy,
    ) -> Result<Self, TracingError> {
        // Start the tracing session immediately. Maybe we should consider separating the creating
        // of the session and the actual starting of it. This seems like a side-effect.
        log::info!("TraceTask::new called with compression: {:?}", compression);
        let task_id = SERIAL.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        // The stream socket carries the trace data; the session proxy controls the trace.
        let (client, server) = provisioner.domain().create_stream_socket();
        let (client_end, server_end) = provisioner.domain().create_proxy::<trace::SessionMarker>();
        provisioner.initialize_tracing(server_end, &config, server)?;

        client_end
            .start_tracing(&trace::StartOptions::default())
            .await?
            .map_err(Into::<TracingError>::into)?;

        let logging_prefix_og = format!("Task {task_id} ({debug_tag})");
        // Holds the StopResult recorded by the shutdown future; the background
        // task reads it back as its return value.
        let terminate_result = Arc::new(Mutex::new(None));
        let (shutdown_sender, shutdown_receiver) = async_channel::bounded::<()>(1);

        let controller = client_end.clone();
        let shutdown_controller = client_end.clone();
        let triggers_watcher =
            TriggersWatcher::new(controller, triggers.clone(), shutdown_receiver);
        let terminating = Arc::new(AtomicBool::new(false));
        let cancelled = Arc::new(AtomicBool::new(false));
        let terminating_clone = terminating.clone();
        let cancelled_clone = cancelled.clone();
        let terminate_result_clone = terminate_result.clone();

        // Future that performs the actual shutdown. The compare_exchange on
        // `terminating` ensures trace_shutdown runs at most once even if both
        // the timeout and a trigger fire.
        let shutdown_fut = {
            let logging_prefix = logging_prefix_og.clone();
            async move {
                if terminating_clone
                    .compare_exchange(
                        false,
                        true,
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::Relaxed,
                    )
                    .is_ok()
                {
                    log::info!("{logging_prefix} Running shutdown future.");
                    // Set by abort(); passed through so the shutdown can skip writing results.
                    let is_cancelled = cancelled_clone.load(std::sync::atomic::Ordering::Relaxed);
                    let result = trace_shutdown(&shutdown_controller, is_cancelled).await;

                    // Only record a stop result if one hasn't been recorded yet.
                    let mut done = terminate_result_clone.lock().await;
                    if done.is_none() {
                        match result {
                            Ok(stop) => {
                                log::info!("{logging_prefix} call to trace_shutdown successful.");
                                *done = Some(stop)
                            }
                            Err(e) => {
                                log::error!(
                                    "{logging_prefix} call to trace_shutdown failed: {e:?}"
                                );
                            }
                        }
                    }
                } else {
                    log::debug!("Shutdown already triggered");
                }
                "shutdown future completed"
            }
        };

        let task = Self::make_task(
            task_id,
            debug_tag,
            duration,
            shutdown_fut,
            triggers_watcher,
            terminate_result,
        );
        Ok(Self {
            task_id,
            // NOTE: the stored debug_tag is the full logging prefix.
            debug_tag: logging_prefix_og,
            config,
            duration,
            triggers: triggers.clone(),
            terminating,
            requested_categories: requested_categories.unwrap_or_default(),
            start_time: Instant::now(),
            shutdown_sender,
            read_socket: Some(socket_to_async(client)),
            compression,
            task,
            cancelled,
        })
    }
159
160    /// Shutdown the tracing task.
161    async fn shutdown(self) -> Result<trace::StopResult, TracingError> {
162        if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
163            log::info!("{} Sending shutdown message.", self.debug_tag);
164            if self.shutdown_sender.send(()).await.is_err() {
165                log::warn!(
166                    "{} Shutdown channel was closed. Task may have already completed.",
167                    self.debug_tag
168                );
169            }
170        } else {
171            log::debug!("{} Shutdown already in progress.", self.debug_tag);
172        }
173
174        self.await
175            .map(|r| Ok(r))
176            .unwrap_or_else(|| Err(TracingError::RecordingStop("Error awaiting".into())))
177    }
178
    /// Abort the tracing task without writing results.
    ///
    /// Sets the `cancelled` flag (read by the shutdown future and passed to
    /// `trace_shutdown`), signals shutdown, closes the read socket without
    /// draining it, and then awaits task completion.
    pub async fn abort(mut self) -> Result<trace::StopResult, TracingError> {
        if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
            log::info!("{} Sending cancel message for task", self.debug_tag);
            // Must be set before the shutdown message so the shutdown future
            // observes the cancellation.
            self.cancelled.store(true, std::sync::atomic::Ordering::SeqCst);
            if self.shutdown_sender.send(()).await.is_err() {
                log::warn!(
                    "{} Shutdown channel was closed. Task may have already completed.",
                    self.debug_tag
                );
            }
        } else {
            log::debug!("{} Shutdown already in progress.", self.debug_tag);
        }

        // Close the socket without reading.
        if let Some(mut socket) = self.read_socket.take() {
            let res = socket.close().await;
            if res.is_err() {
                log::warn!("{} Failed to close socket: {:?}", self.debug_tag, res);
            }
            drop(socket);
        }
        self.shutdown().await
    }
204
205    fn make_task(
206        task_id: u64,
207        debug_tag: String,
208        duration: Option<Duration>,
209        shutdown_fut: impl Future<Output = &'static str> + 'static + std::marker::Send,
210        trigger_watcher: TriggersWatcher<'static>,
211        terminate_result: Arc<Mutex<Option<StopResult>>>,
212    ) -> Task<Option<trace::StopResult>> {
213        Task::local(async move {
214            let mut timeout_fut = Box::pin(async move {
215                if let Some(duration) = duration {
216                    fuchsia_async::Timer::new(duration).await;
217                } else {
218                    std::future::pending::<()>().await;
219                }
220            })
221            .fuse();
222            let mut trigger_fut = trigger_watcher.fuse();
223
224            futures::select! {
225                // Timeout, clean up and wait for copying to finish.
226                _ = timeout_fut => {
227                    log::info!("Trace {task_id} (debug_tag): timeout of {} successfully completed. Stopping and cleaning up.",
228                     duration.map(|d| format!("{} secs", d.as_secs())).unwrap_or_else(|| "infinite?".into()));
229
230                    shutdown_fut.await;
231                     log::debug!("done with timeout!");
232
233                }
234
235                // Trigger hit, shutdown and copy the trace.
236                action = trigger_fut => {
237                    if let Some(action) = action {
238                        match action {
239                            TriggerAction::Terminate => {
240                                log::info!("Task {task_id} ({debug_tag}): received terminate trigger");
241                            }
242                        }
243                    } else {
244                        // This usually means the proxy was closed.
245                        log::debug!("Task {task_id} ({debug_tag}): Trigger future completed without an action!");
246                    }
247                    shutdown_fut.await;
248                     log::debug!("done with trigger future!");
249                }
250            };
251            log::debug!("end of task waiting for terminate_result lock");
252            let res = terminate_result.lock().await.clone();
253            log::debug!("got res in task is some: {}", res.is_some());
254            res
255        })
256    }
257
258    pub fn triggers(&self) -> Vec<Trigger> {
259        self.triggers.clone()
260    }
261    pub fn config(&self) -> TraceConfig {
262        self.config.clone()
263    }
264
265    pub fn start_time(&self) -> Instant {
266        self.start_time
267    }
268
269    pub fn duration(&self) -> Option<Duration> {
270        self.duration.clone()
271    }
272
273    pub fn requested_categories(&self) -> Vec<String> {
274        self.requested_categories.clone()
275    }
276
277    pub fn task_id(&self) -> u64 {
278        self.task_id
279    }
280
281    /// Signals the trace session to stop, copies all trace data to the
282    /// provided writer, and awaits task completion.
283    pub async fn stop_and_receive_data<W>(
284        mut self,
285        mut writer: W,
286    ) -> Result<trace::StopResult, TracingError>
287    where
288        W: AsyncWrite + Unpin + Send + 'static,
289    {
290        if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
291            log::info!("{} Sending shutdown message for task", self.debug_tag);
292            if self.shutdown_sender.send(()).await.is_err() {
293                log::warn!(
294                    "{} Shutdown channel was closed. Task may have already completed.",
295                    self.debug_tag
296                );
297            }
298        } else {
299            log::debug!("{} Shutdown already in progress.", self.debug_tag);
300        }
301
302        let mut read_socket = self.read_socket.take().unwrap();
303        let res = match self.compression {
304            trace::CompressionType::Zstd => compress_zstd(&mut read_socket, &mut writer).await,
305            _ => futures::io::copy(&mut read_socket, &mut writer)
306                .await
307                .map(|_| ())
308                .map_err(|e| TracingError::GeneralError(format!("{e:?}"))),
309        };
310
311        if res.is_ok() { self.shutdown().await } else { Err(res.err().unwrap()) }
312    }
313
314    /// Waits for the tracing task to complete and copies the trace data to the writer.
315    /// If the tracing should be stopped vs. waiting, call |stop_and_receive_data|.
316    pub async fn await_completion_and_receive_data<W>(
317        mut self,
318        mut writer: W,
319    ) -> Result<StopResult, TracingError>
320    where
321        W: AsyncWrite + Unpin + Send + 'static,
322    {
323        let mut read_socket = self.read_socket.take().unwrap();
324        let res = match self.compression {
325            trace::CompressionType::Zstd => compress_zstd(&mut read_socket, &mut writer).await,
326            _ => futures::io::copy(&mut read_socket, &mut writer)
327                .await
328                .map(|_| ())
329                .map_err(|e| TracingError::RecordingStop(e.to_string())),
330        };
331
332        match res {
333            Ok(_) => match self.await {
334                Some(r) => Ok(r),
335                None => Err(TracingError::RecordingStop("could not await task".into())),
336            },
337            Err(e) => Err(e),
338        }
339    }
340}
341
342async fn compress_zstd<R, W>(mut reader: R, mut writer: W) -> Result<(), TracingError>
343where
344    R: AsyncRead + Unpin,
345    W: AsyncWrite + Unpin,
346{
347    let mut encoder = zstd::stream::raw::Encoder::new(0)
348        .map_err(|e| TracingError::GeneralError(format!("zstd init: {e:?}")))?;
349    // 128KB is the recommended size for Zstd (ZSTD_CStreamInSize/ZSTD_CStreamOutSize)
350    let mut input_buf = vec![0u8; 128 * 1024];
351    let mut output_buf = vec![0u8; 128 * 1024];
352
353    // Read the stream until EOF, compressing and writing out fully compressed buffers as we go.
354    while let n = reader
355        .read(&mut input_buf)
356        .await
357        .map_err(|e| TracingError::GeneralError(format!("read: {e:?}")))?
358        && n > 0
359    {
360        let mut read_offset = 0;
361        while read_offset < n {
362            let status = encoder
363                .run_on_buffers(&input_buf[read_offset..n], &mut output_buf)
364                .map_err(|e| TracingError::GeneralError(format!("zstd run: {e:?}")))?;
365            read_offset += status.bytes_read;
366            if status.bytes_written > 0 {
367                writer
368                    .write_all(&output_buf[..status.bytes_written])
369                    .await
370                    .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
371            }
372        }
373    }
374
375    // Flush remaining of the last compressed buffer.
376    loop {
377        let mut out_wrapper = zstd::stream::raw::OutBuffer::around(&mut output_buf);
378        let remaining = encoder
379            .finish(&mut out_wrapper, true)
380            .map_err(|e| TracingError::GeneralError(format!("zstd finish: {e:?}")))?;
381        let bytes = out_wrapper.as_slice();
382        if !bytes.is_empty() {
383            writer
384                .write_all(bytes)
385                .await
386                .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
387        }
388        if remaining == 0 {
389            break;
390        }
391    }
392    Ok(())
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use flex_fuchsia_tracing_controller::StartError;

    // Payload the fake controller writes to the trace output socket.
    const FAKE_CONTROLLER_TRACE_OUTPUT: &'static str = "HOWDY HOWDY HOWDY";

    /// Spawns a fake `Provisioner` FIDL server on a detached local task.
    ///
    /// * `start_error` - when `Some`, `StartTracing` replies with this error
    ///   and `StopTracing` replies `StopError::NotStarted`.
    /// * `trigger_name` - alert name returned by `WatchAlert` ("" when `None`).
    /// * `expected_write_results` - asserted against the `write_results`
    ///   field of the `StopTracing` payload.
    fn setup_fake_provisioner_proxy(
        start_error: Option<StartError>,
        trigger_name: Option<&'static str>,
        expected_write_results: bool,
    ) -> trace::ProvisionerProxy {
        let (proxy, mut stream) =
            fidl::endpoints::create_proxy_and_stream::<trace::ProvisionerMarker>();
        fuchsia_async::Task::local(async move {
            while let Ok(Some(req)) = stream.try_next().await {
                match req {
                    trace::ProvisionerRequest::InitializeTracing { controller, output, .. } => {
                        let mut stream = controller.into_stream();
                        while let Ok(Some(req)) = stream.try_next().await {
                            match req {
                                trace::SessionRequest::StartTracing { responder, .. } => {
                                    let response = match start_error {
                                        Some(e) => Err(e),
                                        None => Ok(()),
                                    };
                                    responder.send(response).expect("Failed to start")
                                }
                                trace::SessionRequest::StopTracing { responder, payload } => {
                                    if start_error.is_some() {
                                        responder
                                            .send(Err(trace::StopError::NotStarted))
                                            .expect("Failed to stop")
                                    } else {
                                        assert_eq!(
                                            payload.write_results.unwrap(),
                                            expected_write_results
                                        );
                                        // Emit the fake trace data on stop.
                                        let _ =
                                            output.write(FAKE_CONTROLLER_TRACE_OUTPUT.as_bytes());
                                        let stop_result = trace::StopResult {
                                            provider_stats: Some(vec![]),
                                            ..Default::default()
                                        };
                                        responder.send(Ok(&stop_result)).expect("Failed to stop")
                                    }
                                    // Stop ends the session; exit the request loop.
                                    break;
                                }
                                trace::SessionRequest::WatchAlert { responder } => {
                                    responder
                                        .send(trigger_name.unwrap_or(""))
                                        .expect("Unable to send alert");
                                }
                                r => panic!("unexpected request: {:#?}", r),
                            }
                        }
                    }
                    r => panic!("unexpected request: {:#?}", r),
                }
            }
        })
        .detach();
        proxy
    }

    // Start + shutdown without reading data; verifies the StopResult is
    // propagated from the fake controller.
    #[fuchsia::test]
    async fn test_trace_task_start_stop_write_check_with_vec() {
        let provisioner = setup_fake_provisioner_proxy(None, None, true);

        let trace_task = TraceTask::new(
            "test_trace_start_stop_write_check".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result = trace_task.shutdown().await.expect("tracing shutdown");
        assert_eq!(
            shutdown_result,
            trace::StopResult { provider_stats: Some(vec![]), ..Default::default() }.into()
        );
    }

    // Stop-and-receive into a file; verifies both the StopResult and the
    // uncompressed trace bytes written to disk.
    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_trace_task_start_stop_write_check_with_file() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");

        let provisioner = setup_fake_provisioner_proxy(None, None, true);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = TraceTask::new(
            "test_trace_start_stop_write_check".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result =
            trace_task.stop_and_receive_data(writer).await.expect("tracing shutdown");

        let res = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
        let expected = trace::StopResult { provider_stats: Some(vec![]), ..Default::default() };
        assert_eq!(shutdown_result, expected);
    }

    // A StartTracing error from the controller surfaces as a TracingError
    // from TraceTask::new.
    #[fuchsia::test]
    async fn test_trace_error_handling_already_started() {
        let provisioner =
            setup_fake_provisioner_proxy(Some(StartError::AlreadyStarted), None, true);

        let trace_task_result = TraceTask::new(
            "test_trace_error_handling_already_started".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .err();

        assert_eq!(trace_task_result, Some(TracingError::RecordingAlreadyStarted));
    }

    // With a duration set, the task stops itself; await_completion should
    // yield a StopResult and the trace data should land in the file.
    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_trace_task_start_with_duration() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");

        let provisioner = setup_fake_provisioner_proxy(None, None, true);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = TraceTask::new(
            "test_trace_task_start_with_duration".into(),
            trace::TraceConfig::default(),
            Some(Duration::from_millis(100)),
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let res = trace_task.await_completion_and_receive_data(writer).await;
        if let Some(ref stop_result) = res.as_ref().ok() {
            assert!(stop_result.provider_stats.is_some());
        } else {
            panic!("Expected stop result from trace_task.await: {res:?}");
        }

        let mut f = async_fs::File::open(std::path::PathBuf::from(output)).await.unwrap();
        let mut res = String::new();
        f.read_to_string(&mut res).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
    }

    // A matching alert fires the Terminate trigger, which stops the trace.
    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_triggers_valid() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");
        let alert_name = "some_alert";
        let provisioner = setup_fake_provisioner_proxy(None, Some(alert_name.into()), true);
        let writer = async_fs::File::create(output.clone()).await.unwrap();

        let trace_task = TraceTask::new(
            "test_triggers_valid".into(),
            trace::TraceConfig::default(),
            None,
            vec![Trigger {
                alert: Some(alert_name.into()),
                action: Some(TriggerAction::Terminate),
            }],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        trace_task.await_completion_and_receive_data(writer).await.unwrap();
        let res = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
    }

    // abort() must request write_results = false from the controller (third
    // argument to the fake is false) and still return the StopResult.
    #[fuchsia::test]
    async fn test_trace_task_abort() {
        let provisioner = setup_fake_provisioner_proxy(None, None, false);

        let trace_task = TraceTask::new(
            "test_trace_task_abort".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result = trace_task.abort().await.expect("tracing abort");
        assert_eq!(
            shutdown_result,
            trace::StopResult { provider_stats: Some(vec![]), ..Default::default() }
        );
    }
}