// trace_task.rs
// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
use crate::triggers::{Trigger, TriggerAction, TriggersWatcher};
use crate::{TracingError, trace_shutdown};
use async_lock::Mutex;
use flex_client::{AsyncSocket, ProxyHasDomain, socket_to_async};
use flex_fuchsia_tracing_controller::{self as trace, StopResult, TraceConfig};
use fuchsia_async::Task;
use futures::io::AsyncWrite;
use futures::prelude::*;
use futures::task::{Context as FutContext, Poll};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64};
use std::time::{Duration, Instant};
use zstd::stream::raw::Operation;
19
20static SERIAL: AtomicU64 = AtomicU64::new(100);
21
/// A running trace session plus the background task that monitors it.
///
/// Awaiting a `TraceTask` (it implements `Future`) yields the session's
/// final `StopResult`, or `None` if the session was never stopped cleanly.
#[derive(Debug)]
pub struct TraceTask {
    /// Unique identifier for this task. The value of this id monotonically increases.
    task_id: u64,
    /// Tag used to identify this task in the log.
    /// Note: after construction this holds the full logging prefix
    /// ("Task <id> (<tag>)"), not the raw tag passed to `new`.
    debug_tag: String,
    /// Trace configuration.
    config: trace::TraceConfig,
    /// Requested categories. These are unexpanded from the user.
    requested_categories: Vec<String>,
    /// Duration to capture trace. None indicates capture until canceled.
    duration: Option<Duration>,
    /// Triggers for terminating the trace.
    triggers: Vec<Trigger>,
    /// True when the task is cleaning up.
    terminating: Arc<AtomicBool>,
    /// Start time of the task.
    start_time: Instant,
    /// Channel used to shutdown this task.
    shutdown_sender: async_channel::Sender<()>,
    /// The task.
    task: Task<Option<trace::StopResult>>,
    /// The socket to read the trace data from when tracing is completed.
    read_socket: AsyncSocket,
    /// The compression algorithm to use.
    compression: trace::CompressionType,
    /// True when the task was cancelled (aborted).
    cancelled: Arc<AtomicBool>,
}
51
52// This is just implemented for convenience so the wrapper is await-able.
53impl Future for TraceTask {
54    type Output = Option<trace::StopResult>;
55
56    fn poll(mut self: Pin<&mut Self>, cx: &mut FutContext<'_>) -> Poll<Self::Output> {
57        Pin::new(&mut self.task).poll(cx)
58    }
59}
60
61impl TraceTask {
    /// Creates a `TraceTask` and immediately starts a tracing session via
    /// `provisioner`.
    ///
    /// * `debug_tag` - human-readable tag included in log messages.
    /// * `config` - trace configuration forwarded to `initialize_tracing`.
    /// * `duration` - how long to capture; `None` captures until canceled.
    /// * `triggers` - alerts that may terminate the trace early.
    /// * `requested_categories` - user-requested (unexpanded) categories,
    ///   retained for later reporting.
    /// * `compression` - compression applied when the data is read out.
    /// * `provisioner` - proxy used to create the session and data socket.
    ///
    /// Returns an error if initializing or starting the session fails.
    pub async fn new(
        debug_tag: String,
        config: trace::TraceConfig,
        duration: Option<Duration>,
        triggers: Vec<Trigger>,
        requested_categories: Option<Vec<String>>,
        compression: trace::CompressionType,
        provisioner: trace::ProvisionerProxy,
    ) -> Result<Self, TracingError> {
        // Start the tracing session immediately. Maybe we should consider separating the creating
        // of the session and the actual starting of it. This seems like a side-effect.
        log::info!("TraceTask::new called with compression: {:?}", compression);
        let task_id = SERIAL.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        // The server half of this socket receives the trace data; the client
        // half is read back later by the *_receive_data methods.
        let (client, server) = provisioner.domain().create_stream_socket();
        let (client_end, server_end) = provisioner.domain().create_proxy::<trace::SessionMarker>();
        provisioner.initialize_tracing(server_end, &config, server)?;

        client_end
            .start_tracing(&trace::StartOptions::default())
            .await?
            .map_err(Into::<TracingError>::into)?;

        let logging_prefix_og = format!("Task {task_id} ({debug_tag})");
        // Holds the StopResult produced by the shutdown future; read by the
        // background task once the select! in make_task completes.
        let terminate_result = Arc::new(Mutex::new(None));
        let (shutdown_sender, shutdown_receiver) = async_channel::bounded::<()>(1);

        let controller = client_end.clone();
        let shutdown_controller = client_end.clone();
        let triggers_watcher =
            TriggersWatcher::new(controller, triggers.clone(), shutdown_receiver);
        let terminating = Arc::new(AtomicBool::new(false));
        let cancelled = Arc::new(AtomicBool::new(false));
        let terminating_clone = terminating.clone();
        let cancelled_clone = cancelled.clone();
        let terminate_result_clone = terminate_result.clone();

        // Future that performs the actual stop. The compare_exchange on
        // `terminating` ensures trace_shutdown runs at most once even when
        // shutdown is requested from multiple paths (timeout, trigger, user).
        let shutdown_fut = {
            let logging_prefix = logging_prefix_og.clone();
            async move {
                if terminating_clone
                    .compare_exchange(
                        false,
                        true,
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::Relaxed,
                    )
                    .is_ok()
                {
                    log::info!("{logging_prefix} Running shutdown future.");
                    let is_cancelled = cancelled_clone.load(std::sync::atomic::Ordering::Relaxed);
                    let result = trace_shutdown(&shutdown_controller, is_cancelled).await;

                    let mut done = terminate_result_clone.lock().await;
                    if done.is_none() {
                        match result {
                            Ok(stop) => {
                                log::info!("{logging_prefix} call to trace_shutdown successful.");
                                *done = Some(stop)
                            }
                            Err(e) => {
                                // Leave `done` as None; the task will resolve
                                // to None to signal the failed stop.
                                log::error!(
                                    "{logging_prefix} call to trace_shutdown failed: {e:?}"
                                );
                            }
                        }
                    }
                } else {
                    log::debug!("Shutdown already triggered");
                }
                "shutdown future completed"
            }
        };

        let task = Self::make_task(
            task_id,
            debug_tag,
            duration,
            shutdown_fut,
            triggers_watcher,
            terminate_result,
        );
        Ok(Self {
            task_id,
            // The stored tag is the full logging prefix, not the raw tag.
            debug_tag: logging_prefix_og,
            config,
            duration,
            triggers: triggers.clone(),
            terminating,
            requested_categories: requested_categories.unwrap_or_default(),
            start_time: Instant::now(),
            shutdown_sender,
            read_socket: socket_to_async(client),
            compression,
            task,
            cancelled,
        })
    }
159
160    /// Shutdown the tracing task.
161    async fn shutdown(self) -> Result<trace::StopResult, TracingError> {
162        if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
163            log::info!("{} Sending shutdown message.", self.debug_tag);
164            if self.shutdown_sender.send(()).await.is_err() {
165                log::warn!(
166                    "{} Shutdown channel was closed. Task may have already completed.",
167                    self.debug_tag
168                );
169            }
170        } else {
171            log::debug!("{} Shutdown already in progress.", self.debug_tag);
172        }
173
174        self.await
175            .map(|r| Ok(r))
176            .unwrap_or_else(|| Err(TracingError::RecordingStop("Error awaiting".into())))
177    }
178
    /// Abort the tracing task without writing results.
    ///
    /// Marks the task as cancelled, signals shutdown, closes the read socket
    /// without draining it, and then awaits task completion via `shutdown`.
    pub async fn abort(mut self) -> Result<trace::StopResult, TracingError> {
        if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
            log::info!("{} Sending cancel message for task", self.debug_tag);
            // Set the cancelled flag before signaling shutdown so the
            // shutdown future observes it when stopping the session
            // (tests expect write_results=false for aborted sessions).
            self.cancelled.store(true, std::sync::atomic::Ordering::SeqCst);
            if self.shutdown_sender.send(()).await.is_err() {
                log::warn!(
                    "{} Shutdown channel was closed. Task may have already completed.",
                    self.debug_tag
                );
            }
        } else {
            log::debug!("{} Shutdown already in progress.", self.debug_tag);
        }

        // Close the socket without reading; the captured data is discarded.
        let res = self.read_socket.close().await;
        if res.is_err() {
            log::warn!("{} Failed to close socket: {:?}", self.debug_tag, res);
        }
        self.shutdown().await
    }
201
202    fn make_task(
203        task_id: u64,
204        debug_tag: String,
205        duration: Option<Duration>,
206        shutdown_fut: impl Future<Output = &'static str> + 'static + std::marker::Send,
207        trigger_watcher: TriggersWatcher<'static>,
208        terminate_result: Arc<Mutex<Option<StopResult>>>,
209    ) -> Task<Option<trace::StopResult>> {
210        Task::local(async move {
211            let mut timeout_fut = Box::pin(async move {
212                if let Some(duration) = duration {
213                    fuchsia_async::Timer::new(duration).await;
214                } else {
215                    std::future::pending::<()>().await;
216                }
217            })
218            .fuse();
219            let mut trigger_fut = trigger_watcher.fuse();
220
221            futures::select! {
222                // Timeout, clean up and wait for copying to finish.
223                _ = timeout_fut => {
224                    log::info!("Trace {task_id} (debug_tag): timeout of {} successfully completed. Stopping and cleaning up.",
225                     duration.map(|d| format!("{} secs", d.as_secs())).unwrap_or_else(|| "infinite?".into()));
226
227                    shutdown_fut.await;
228                     log::debug!("done with timeout!");
229
230                }
231
232                // Trigger hit, shutdown and copy the trace.
233                action = trigger_fut => {
234                    if let Some(action) = action {
235                        match action {
236                            TriggerAction::Terminate => {
237                                log::info!("Task {task_id} ({debug_tag}): received terminate trigger");
238                            }
239                        }
240                    } else {
241                        // This usually means the proxy was closed.
242                        log::debug!("Task {task_id} ({debug_tag}): Trigger future completed without an action!");
243                    }
244                    shutdown_fut.await;
245                     log::debug!("done with trigger future!");
246                }
247            };
248            log::debug!("end of task waiting for terminate_result lock");
249            let res = terminate_result.lock().await.clone();
250            log::debug!("got res in task is some: {}", res.is_some());
251            res
252        })
253    }
254
255    pub fn triggers(&self) -> Vec<Trigger> {
256        self.triggers.clone()
257    }
258    pub fn config(&self) -> TraceConfig {
259        self.config.clone()
260    }
261
262    pub fn start_time(&self) -> Instant {
263        self.start_time
264    }
265
266    pub fn duration(&self) -> Option<Duration> {
267        self.duration.clone()
268    }
269
270    pub fn requested_categories(&self) -> Vec<String> {
271        self.requested_categories.clone()
272    }
273
274    pub fn task_id(&self) -> u64 {
275        self.task_id
276    }
277
278    /// Signals the trace session to stop, copies all trace data to the
279    /// provided writer, and awaits task completion.
280    pub async fn stop_and_receive_data<W>(
281        self,
282        mut writer: W,
283    ) -> Result<trace::StopResult, TracingError>
284    where
285        W: AsyncWrite + Unpin + Send + 'static,
286    {
287        if !self.terminating.load(std::sync::atomic::Ordering::SeqCst) {
288            log::info!("{} Sending shutdown message for task", self.debug_tag);
289            if self.shutdown_sender.send(()).await.is_err() {
290                log::warn!(
291                    "{} Shutdown channel was closed. Task may have already completed.",
292                    self.debug_tag
293                );
294            }
295        } else {
296            log::debug!("{} Shutdown already in progress.", self.debug_tag);
297        }
298
299        let res = match self.compression {
300            trace::CompressionType::Zstd => compress_zstd(&self.read_socket, &mut writer).await,
301            _ => futures::io::copy(&self.read_socket, &mut writer)
302                .await
303                .map(|_| ())
304                .map_err(|e| TracingError::GeneralError(format!("{e:?}"))),
305        };
306
307        if res.is_ok() { self.shutdown().await } else { Err(res.err().unwrap()) }
308    }
309
310    /// Waits for the tracing task to complete and copies the trace data to the writer.
311    /// If the tracing should be stopped vs. waiting, call |stop_and_receive_data|.
312    pub async fn await_completion_and_receive_data<W>(
313        self,
314        mut writer: W,
315    ) -> Result<StopResult, TracingError>
316    where
317        W: AsyncWrite + Unpin + Send + 'static,
318    {
319        let res = match self.compression {
320            trace::CompressionType::Zstd => compress_zstd(&self.read_socket, &mut writer).await,
321            _ => futures::io::copy(&self.read_socket, &mut writer)
322                .await
323                .map(|_| ())
324                .map_err(|e| TracingError::RecordingStop(e.to_string())),
325        };
326
327        match res {
328            Ok(_) => match self.await {
329                Some(r) => Ok(r),
330                None => Err(TracingError::RecordingStop("could not await task".into())),
331            },
332            Err(e) => Err(e),
333        }
334    }
335}
336
337async fn compress_zstd<R, W>(mut reader: R, mut writer: W) -> Result<(), TracingError>
338where
339    R: AsyncRead + Unpin,
340    W: AsyncWrite + Unpin,
341{
342    let mut encoder = zstd::stream::raw::Encoder::new(0)
343        .map_err(|e| TracingError::GeneralError(format!("zstd init: {e:?}")))?;
344    // 128KB is the recommended size for Zstd (ZSTD_CStreamInSize/ZSTD_CStreamOutSize)
345    let mut input_buf = vec![0u8; 128 * 1024];
346    let mut output_buf = vec![0u8; 128 * 1024];
347
348    // Read the stream until EOF, compressing and writing out fully compressed buffers as we go.
349    while let n = reader
350        .read(&mut input_buf)
351        .await
352        .map_err(|e| TracingError::GeneralError(format!("read: {e:?}")))?
353        && n > 0
354    {
355        let mut read_offset = 0;
356        while read_offset < n {
357            let status = encoder
358                .run_on_buffers(&input_buf[read_offset..n], &mut output_buf)
359                .map_err(|e| TracingError::GeneralError(format!("zstd run: {e:?}")))?;
360            read_offset += status.bytes_read;
361            if status.bytes_written > 0 {
362                writer
363                    .write_all(&output_buf[..status.bytes_written])
364                    .await
365                    .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
366            }
367        }
368    }
369
370    // Flush remaining of the last compressed buffer.
371    loop {
372        let mut out_wrapper = zstd::stream::raw::OutBuffer::around(&mut output_buf);
373        let remaining = encoder
374            .finish(&mut out_wrapper, true)
375            .map_err(|e| TracingError::GeneralError(format!("zstd finish: {e:?}")))?;
376        let bytes = out_wrapper.as_slice();
377        if !bytes.is_empty() {
378            writer
379                .write_all(bytes)
380                .await
381                .map_err(|e| TracingError::GeneralError(format!("write: {e:?}")))?;
382        }
383        if remaining == 0 {
384            break;
385        }
386    }
387    Ok(())
388}
389
390#[cfg(test)]
391mod tests {
392    use super::*;
393    use flex_fuchsia_tracing_controller::StartError;
394
395    const FAKE_CONTROLLER_TRACE_OUTPUT: &'static str = "HOWDY HOWDY HOWDY";
396
    /// Spawns a fake `Provisioner` server for tests.
    ///
    /// * `start_error` - if set, `StartTracing` responds with this error and
    ///   `StopTracing` responds with `StopError::NotStarted`.
    /// * `trigger_name` - alert name returned from `WatchAlert` (empty string
    ///   when `None`).
    /// * `expected_write_results` - asserted against the `write_results`
    ///   field of the `StopTracing` payload.
    ///
    /// On a successful stop, the fake writes `FAKE_CONTROLLER_TRACE_OUTPUT`
    /// to the trace output socket and replies with a default `StopResult`.
    fn setup_fake_provisioner_proxy(
        start_error: Option<StartError>,
        trigger_name: Option<&'static str>,
        expected_write_results: bool,
    ) -> trace::ProvisionerProxy {
        let (proxy, mut stream) =
            fidl::endpoints::create_proxy_and_stream::<trace::ProvisionerMarker>();
        fuchsia_async::Task::local(async move {
            while let Ok(Some(req)) = stream.try_next().await {
                match req {
                    trace::ProvisionerRequest::InitializeTracing { controller, output, .. } => {
                        // Serve the session protocol on the controller the
                        // client handed us.
                        let mut stream = controller.into_stream();
                        while let Ok(Some(req)) = stream.try_next().await {
                            match req {
                                trace::SessionRequest::StartTracing { responder, .. } => {
                                    let response = match start_error {
                                        Some(e) => Err(e),
                                        None => Ok(()),
                                    };
                                    responder.send(response).expect("Failed to start")
                                }
                                trace::SessionRequest::StopTracing { responder, payload } => {
                                    if start_error.is_some() {
                                        responder
                                            .send(Err(trace::StopError::NotStarted))
                                            .expect("Failed to stop")
                                    } else {
                                        assert_eq!(
                                            payload.write_results.unwrap(),
                                            expected_write_results
                                        );
                                        // Emit the canned trace data before
                                        // acknowledging the stop.
                                        assert_eq!(
                                            FAKE_CONTROLLER_TRACE_OUTPUT.len(),
                                            output
                                                .write(FAKE_CONTROLLER_TRACE_OUTPUT.as_bytes())
                                                .unwrap()
                                        );
                                        let stop_result = trace::StopResult {
                                            provider_stats: Some(vec![]),
                                            ..Default::default()
                                        };
                                        responder.send(Ok(&stop_result)).expect("Failed to stop")
                                    }
                                    // Stop serving this session after the stop.
                                    break;
                                }
                                trace::SessionRequest::WatchAlert { responder } => {
                                    responder
                                        .send(trigger_name.unwrap_or(""))
                                        .expect("Unable to send alert");
                                }
                                r => panic!("unexpected request: {:#?}", r),
                            }
                        }
                    }
                    r => panic!("unexpected request: {:#?}", r),
                }
            }
        })
        .detach();
        proxy
    }
458
459    #[fuchsia::test]
460    async fn test_trace_task_start_stop_write_check_with_vec() {
461        let provisioner = setup_fake_provisioner_proxy(None, None, true);
462
463        let trace_task = TraceTask::new(
464            "test_trace_start_stop_write_check".into(),
465            trace::TraceConfig::default(),
466            None,
467            vec![],
468            None,
469            trace::CompressionType::None,
470            provisioner,
471        )
472        .await
473        .expect("tracing task started");
474
475        let shutdown_result = trace_task.shutdown().await.expect("tracing shutdown");
476        assert_eq!(
477            shutdown_result,
478            trace::StopResult { provider_stats: Some(vec![]), ..Default::default() }.into()
479        );
480    }
481
    /// Stopping via `stop_and_receive_data` writes the fake's trace output to
    /// a file and returns the stop result.
    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_trace_task_start_stop_write_check_with_file() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");

        let provisioner = setup_fake_provisioner_proxy(None, None, true);
        let writer = async_fs::File::create(&output).await.unwrap();

        let trace_task = TraceTask::new(
            "test_trace_start_stop_write_check".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result =
            trace_task.stop_and_receive_data(writer).await.expect("tracing shutdown");

        // The file should contain exactly what the fake controller wrote.
        let res = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
        let expected = trace::StopResult { provider_stats: Some(vec![]), ..Default::default() };
        assert_eq!(shutdown_result, expected);
    }
511
    /// A `StartError::AlreadyStarted` from the session surfaces as
    /// `TracingError::RecordingAlreadyStarted` from `TraceTask::new`.
    #[fuchsia::test]
    async fn test_trace_error_handling_already_started() {
        let provisioner =
            setup_fake_provisioner_proxy(Some(StartError::AlreadyStarted), None, true);

        let trace_task_result = TraceTask::new(
            "test_trace_error_handling_already_started".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .err();

        assert_eq!(trace_task_result, Some(TracingError::RecordingAlreadyStarted));
    }
531
532    #[cfg(not(target_os = "fuchsia"))]
533    #[fuchsia::test]
534    async fn test_trace_task_start_with_duration() {
535        let temp_dir = tempfile::TempDir::new().unwrap();
536        let output = temp_dir.path().join("trace-test.fxt");
537
538        let provisioner = setup_fake_provisioner_proxy(None, None, true);
539        let writer = async_fs::File::create(&output).await.unwrap();
540
541        let trace_task = TraceTask::new(
542            "test_trace_task_start_with_duration".into(),
543            trace::TraceConfig::default(),
544            Some(Duration::from_millis(100)),
545            vec![],
546            None,
547            trace::CompressionType::None,
548            provisioner,
549        )
550        .await
551        .expect("tracing task started");
552
553        let res = trace_task.await_completion_and_receive_data(writer).await;
554        if let Some(ref stop_result) = res.as_ref().ok() {
555            assert!(stop_result.provider_stats.is_some());
556        } else {
557            panic!("Expected stop result from trace_task.await: {res:?}");
558        }
559
560        let mut f = async_fs::File::open(std::path::PathBuf::from(output)).await.unwrap();
561        let mut res = String::new();
562        f.read_to_string(&mut res).await.unwrap();
563        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
564    }
565
    /// A matching alert with a `Terminate` action ends the trace; the data is
    /// still copied to the writer.
    #[cfg(not(target_os = "fuchsia"))]
    #[fuchsia::test]
    async fn test_triggers_valid() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let output = temp_dir.path().join("trace-test.fxt");
        let alert_name = "some_alert";
        // The fake answers WatchAlert with this name, which matches the
        // trigger configured below.
        let provisioner = setup_fake_provisioner_proxy(None, Some(alert_name.into()), true);
        let writer = async_fs::File::create(output.clone()).await.unwrap();

        let trace_task = TraceTask::new(
            "test_triggers_valid".into(),
            trace::TraceConfig::default(),
            None,
            vec![Trigger {
                alert: Some(alert_name.into()),
                action: Some(TriggerAction::Terminate),
            }],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        trace_task.await_completion_and_receive_data(writer).await.unwrap();
        let res = async_fs::read_to_string(&output).await.unwrap();
        assert_eq!(res, FAKE_CONTROLLER_TRACE_OUTPUT.to_string());
    }
594
    /// Aborting a task stops the session with `write_results` disabled
    /// (asserted inside the fake) and still returns a stop result.
    #[fuchsia::test]
    async fn test_trace_task_abort() {
        // `expected_write_results = false`: abort must not write results.
        let provisioner = setup_fake_provisioner_proxy(None, None, false);

        let trace_task = TraceTask::new(
            "test_trace_task_abort".into(),
            trace::TraceConfig::default(),
            None,
            vec![],
            None,
            trace::CompressionType::None,
            provisioner,
        )
        .await
        .expect("tracing task started");

        let shutdown_result = trace_task.abort().await.expect("tracing abort");
        assert_eq!(
            shutdown_result,
            trace::StopResult { provider_stats: Some(vec![]), ..Default::default() }
        );
    }
617}