stream_processor_test/
output_validator.rs

1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::FatalError;
6use anyhow::Error;
7use async_trait::async_trait;
8use fidl_fuchsia_media::{FormatDetails, StreamOutputFormat};
9use fidl_table_validation::*;
10use fuchsia_stream_processors::*;
11use hex::{decode, encode};
12use log::info;
13use mundane::hash::{Digest, Hasher, Sha256};
14use num_traits::PrimInt;
15use std::fmt;
16use std::io::Write;
17use std::rc::Rc;
18
19#[derive(ValidFidlTable, Debug, PartialEq)]
20#[fidl_table_src(StreamOutputFormat)]
21pub struct ValidStreamOutputFormat {
22    pub stream_lifetime_ordinal: u64,
23    pub format_details: FormatDetails,
24}
25
26/// An output packet from the stream.
27#[derive(Debug, PartialEq)]
28pub struct OutputPacket {
29    pub data: Vec<u8>,
30    pub format: Rc<ValidStreamOutputFormat>,
31    pub packet: ValidPacket,
32}
33
34/// Returns all the packets in the output with preserved order.
35pub fn output_packets(output: &[Output]) -> impl Iterator<Item = &OutputPacket> {
36    output.iter().filter_map(|output| match output {
37        Output::Packet(packet) => Some(packet),
38        _ => None,
39    })
40}
41
42/// Output represents any output from a stream we might want to validate programmatically.
43///
44/// This may extend to contain not just explicit events but certain stream control behaviors or
45/// even errors.
46#[derive(Debug, PartialEq)]
47pub enum Output {
48    Packet(OutputPacket),
49    Eos { stream_lifetime_ordinal: u64 },
50    CodecChannelClose,
51}
52
53/// Checks all output packets, which are provided to the validator in the order in which they
54/// were received from the stream processor.
55///
56/// Failure should be indicated by returning an error, not by panic, so that the full context of
57/// the error will be available in the failure output.
58#[async_trait(?Send)]
59pub trait OutputValidator {
60    async fn validate(&self, output: &[Output]) -> Result<(), Error>;
61}
62
63/// Validates that the output contains the expected number of packets.
64pub struct OutputPacketCountValidator {
65    pub expected_output_packet_count: usize,
66}
67
68#[async_trait(?Send)]
69impl OutputValidator for OutputPacketCountValidator {
70    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
71        let actual_output_packet_count: usize = output
72            .iter()
73            .filter(|output| match output {
74                Output::Packet(_) => true,
75                _ => false,
76            })
77            .count();
78
79        if actual_output_packet_count != self.expected_output_packet_count {
80            return Err(FatalError(format!(
81                "actual output packet count: {}; expected output packet count: {}",
82                actual_output_packet_count, self.expected_output_packet_count
83            ))
84            .into());
85        }
86
87        Ok(())
88    }
89}
90
91/// Validates that the output contains the expected number of bytes.
92pub struct OutputDataSizeValidator {
93    pub expected_output_data_size: usize,
94}
95
96#[async_trait(?Send)]
97impl OutputValidator for OutputDataSizeValidator {
98    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
99        let actual_output_data_size: usize = output
100            .iter()
101            .map(|output| match output {
102                Output::Packet(p) => p.data.len(),
103                _ => 0,
104            })
105            .sum();
106
107        if actual_output_data_size != self.expected_output_data_size {
108            return Err(FatalError(format!(
109                "actual output data size: {}; expected output data size: {}",
110                actual_output_data_size, self.expected_output_data_size
111            ))
112            .into());
113        }
114
115        Ok(())
116    }
117}
118
119/// Validates that a stream terminates with Eos.
120pub struct TerminatesWithValidator {
121    pub expected_terminal_output: Output,
122}
123
124#[async_trait(?Send)]
125impl OutputValidator for TerminatesWithValidator {
126    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
127        let actual_terminal_output = output.last().ok_or_else(|| {
128            FatalError(format!(
129                "In terminal output: expected {:?}; found: None",
130                Some(&self.expected_terminal_output)
131            ))
132        })?;
133
134        if *actual_terminal_output == self.expected_terminal_output {
135            Ok(())
136        } else {
137            Err(FatalError(format!(
138                "In terminal output: expected {:?}; found: {:?}",
139                Some(&self.expected_terminal_output),
140                actual_terminal_output
141            ))
142            .into())
143        }
144    }
145}
146
147/// Validates that an output's format matches expected
148pub struct FormatValidator {
149    pub expected_format: FormatDetails,
150}
151
152#[async_trait(?Send)]
153impl OutputValidator for FormatValidator {
154    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
155        let packets: Vec<&OutputPacket> = output_packets(output).collect();
156        let format = &packets
157            .first()
158            .ok_or_else(|| FatalError(String::from("No packets in output")))?
159            .format
160            .format_details;
161
162        if self.expected_format != *format {
163            return Err(FatalError(format!(
164                "Expected {:?}; got {:?}",
165                self.expected_format, format
166            ))
167            .into());
168        }
169
170        Ok(())
171    }
172}
173
174/// Validates that an output's data exactly matches an expected hash, including oob_bytes
175pub struct BytesValidator {
176    pub output_file: Option<&'static str>,
177    pub expected_digests: Vec<ExpectedDigest>,
178}
179
180impl BytesValidator {
181    fn write_and_hash(
182        &self,
183        mut writer: impl Write,
184        oob: &[u8],
185        packets: &[&OutputPacket],
186    ) -> Result<(), Error> {
187        let mut hasher = Sha256::default();
188
189        hasher.update(oob);
190
191        for packet in packets {
192            writer.write_all(&packet.data)?;
193            hasher.update(&packet.data);
194        }
195        writer.flush()?;
196
197        let digest = hasher.finish().bytes();
198
199        if let None = self.expected_digests.iter().find(|e| e.bytes == digest) {
200            return Err(FatalError(format!(
201                "Expected one of {:?}; got {}",
202                self.expected_digests,
203                encode(digest)
204            ))
205            .into());
206        }
207
208        Ok(())
209    }
210}
211
212fn output_writer(output_file: Option<&'static str>) -> Result<impl Write, Error> {
213    Ok(if let Some(file) = output_file {
214        Box::new(std::fs::File::create(file)?) as Box<dyn Write>
215    } else {
216        Box::new(std::io::sink()) as Box<dyn Write>
217    })
218}
219
220#[async_trait(?Send)]
221impl OutputValidator for BytesValidator {
222    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
223        let packets: Vec<&OutputPacket> = output_packets(output).collect();
224        let oob = packets
225            .first()
226            .ok_or_else(|| FatalError(String::from("No packets in output")))?
227            .format
228            .format_details
229            .oob_bytes
230            .clone()
231            .unwrap_or(vec![]);
232
233        self.write_and_hash(output_writer(self.output_file)?, oob.as_slice(), &packets)
234    }
235}
236
237#[derive(Clone)]
238pub struct ExpectedDigest {
239    pub label: &'static str,
240    pub bytes: <<Sha256 as Hasher>::Digest as Digest>::Bytes,
241    pub per_frame_bytes: Option<Vec<<<Sha256 as Hasher>::Digest as Digest>::Bytes>>,
242}
243
244impl ExpectedDigest {
245    pub fn new(label: &'static str, hex: impl AsRef<[u8]>) -> Self {
246        Self {
247            label,
248            bytes: decode(hex)
249                .expect("Decoding static compile-time test hash as valid hex")
250                .as_slice()
251                .try_into()
252                .expect("Taking 32 bytes from compile-time test hash"),
253            per_frame_bytes: None,
254        }
255    }
256    pub fn new_with_per_frame_digest(
257        label: &'static str,
258        hex: impl AsRef<[u8]>,
259        per_frame_hexen: Vec<impl AsRef<[u8]>>,
260    ) -> Self {
261        Self {
262            per_frame_bytes: Some(
263                per_frame_hexen
264                    .into_iter()
265                    .map(|per_frame_hex| {
266                        decode(per_frame_hex)
267                            .expect("Decoding static compile-time test hash as valid hex")
268                            .as_slice()
269                            .try_into()
270                            .expect("Taking 32 bytes from compile-time test hash")
271                    })
272                    .collect(),
273            ),
274            ..Self::new(label, hex)
275        }
276    }
277
278    pub fn new_from_raw(label: &'static str, raw_data: Vec<u8>) -> Self {
279        Self {
280            label,
281            bytes: <Sha256 as Hasher>::hash(raw_data.as_slice()).bytes(),
282            per_frame_bytes: None,
283        }
284    }
285}
286
287impl fmt::Display for ExpectedDigest {
288    fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
289        write!(w, "{:?}", self)
290    }
291}
292
293impl fmt::Debug for ExpectedDigest {
294    fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
295        write!(w, "ExpectedDigest {{\n")?;
296        write!(w, "\tlabel: {}", self.label)?;
297        write!(w, "\tbytes: {}", encode(self.bytes))?;
298        write!(w, "}}")
299    }
300}
301
/// Validates that the RMSE between the output data and the expected data is within
/// `rmse_diff_tolerance` of `expected_rmse`.
#[allow(unused)]
pub struct RmseValidator<T> {
    /// If set, the raw output bytes are also written to a file at this path.
    pub output_file: Option<&'static str>,
    /// Reference data the converted output is compared against.
    pub expected_data: Vec<T>,
    /// RMSE the output is expected to have relative to `expected_data`.
    pub expected_rmse: f64,
    /// Maximum allowed absolute difference between the calculated RMSE and `expected_rmse`.
    /// This is an absolute value in the same units as the data, not a percentage.
    pub rmse_diff_tolerance: f64,
    /// Maximum allowed difference, in elements of `T`, between the lengths of `expected_data`
    /// and the converted output. Only the common prefix of the two is compared.
    pub data_len_diff_tolerance: u32,
    /// Converts the concatenated output packet bytes into values of type `T`.
    pub output_converter: fn(Vec<u8>) -> Vec<T>,
}
315
316pub fn calculate_rmse<T: PrimInt + std::fmt::Debug>(
317    expected_data: &[T],
318    actual_data: &[T],
319    acceptable_len_diff: u32,
320) -> Result<f64, Error> {
321    // There could be a slight difference to the length of the expected data
322    // and the actual data due to the way some codecs deal with left over data
323    // at the end of the stream. This can be caused by minimum block size and
324    // how some codecs may choose to pad out the last block or insert a silence
325    // data at the start. Ensure the difference in length between expected and
326    // actual data is not too much.
327    let compare_len = std::cmp::min(expected_data.len(), actual_data.len());
328    if std::cmp::max(expected_data.len(), actual_data.len()) - compare_len
329        > acceptable_len_diff.try_into().unwrap()
330    {
331        return Err(FatalError(format!(
332            "Expected data (len {}) and the actual data (len {}) have significant length difference and cannot be compared.",
333            expected_data.len(), actual_data.len(),
334        )).into());
335    }
336    let expected_data = &expected_data[..compare_len];
337    let actual_data = &actual_data[..compare_len];
338
339    let mut rmse = 0.0;
340    let mut n = 0;
341    for data in std::iter::zip(actual_data.iter(), expected_data.iter()) {
342        let b1: f64 = num_traits::cast::cast(*data.0).unwrap();
343        let b2: f64 = num_traits::cast::cast(*data.1).unwrap();
344        rmse += (b1 - b2).powi(2);
345        n += 1;
346    }
347    Ok((rmse / n as f64).sqrt())
348}
349
350impl<T: PrimInt + std::fmt::Debug> RmseValidator<T> {
351    fn write_and_calc_rsme(
352        &self,
353        mut writer: impl Write,
354        packets: &[&OutputPacket],
355    ) -> Result<(), Error> {
356        let mut output_data: Vec<u8> = Vec::new();
357        for packet in packets {
358            writer.write_all(&packet.data)?;
359            packet.data.iter().for_each(|item| output_data.push(*item));
360        }
361
362        let actual_data = (self.output_converter)(output_data);
363
364        let rmse = calculate_rmse(
365            self.expected_data.as_slice(),
366            actual_data.as_slice(),
367            self.data_len_diff_tolerance,
368        )?;
369        info!("RMSE is {}", rmse);
370        if (rmse - self.expected_rmse).abs() > self.rmse_diff_tolerance {
371            return Err(FatalError(format!(
372                "expected rmse: {}; actual rmse: {}; rmse diff tolerance {}",
373                self.expected_rmse, rmse, self.rmse_diff_tolerance,
374            ))
375            .into());
376        }
377        Ok(())
378    }
379}
380
381#[async_trait(?Send)]
382impl<T: PrimInt + std::fmt::Debug> OutputValidator for RmseValidator<T> {
383    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
384        let packets: Vec<&OutputPacket> = output_packets(output).collect();
385        self.write_and_calc_rsme(output_writer(self.output_file)?, &packets)
386    }
387}