1use crate::FatalError;
6use anyhow::Error;
7use async_trait::async_trait;
8use fidl_fuchsia_media::{FormatDetails, StreamOutputFormat};
9use fidl_table_validation::*;
10use fuchsia_stream_processors::*;
11use hex::{decode, encode};
12use log::info;
13use mundane::hash::{Digest, Hasher, Sha256};
14use num_traits::PrimInt;
15use std::fmt;
16use std::io::Write;
17use std::rc::Rc;
18
19#[derive(ValidFidlTable, Debug, PartialEq)]
20#[fidl_table_src(StreamOutputFormat)]
21pub struct ValidStreamOutputFormat {
22 pub stream_lifetime_ordinal: u64,
23 pub format_details: FormatDetails,
24}
25
26#[derive(Debug, PartialEq)]
28pub struct OutputPacket {
29 pub data: Vec<u8>,
30 pub format: Rc<ValidStreamOutputFormat>,
31 pub packet: ValidPacket,
32}
33
34pub fn output_packets(output: &[Output]) -> impl Iterator<Item = &OutputPacket> {
36 output.iter().filter_map(|output| match output {
37 Output::Packet(packet) => Some(packet),
38 _ => None,
39 })
40}
41
42#[derive(Debug, PartialEq)]
47pub enum Output {
48 Packet(OutputPacket),
49 Eos { stream_lifetime_ordinal: u64 },
50 CodecChannelClose,
51}
52
53#[async_trait(?Send)]
59pub trait OutputValidator {
60 async fn validate(&self, output: &[Output]) -> Result<(), Error>;
61}
62
63pub struct OutputPacketCountValidator {
65 pub expected_output_packet_count: usize,
66}
67
68#[async_trait(?Send)]
69impl OutputValidator for OutputPacketCountValidator {
70 async fn validate(&self, output: &[Output]) -> Result<(), Error> {
71 let actual_output_packet_count: usize = output
72 .iter()
73 .filter(|output| match output {
74 Output::Packet(_) => true,
75 _ => false,
76 })
77 .count();
78
79 if actual_output_packet_count != self.expected_output_packet_count {
80 return Err(FatalError(format!(
81 "actual output packet count: {}; expected output packet count: {}",
82 actual_output_packet_count, self.expected_output_packet_count
83 ))
84 .into());
85 }
86
87 Ok(())
88 }
89}
90
91pub struct OutputDataSizeValidator {
93 pub expected_output_data_size: usize,
94}
95
96#[async_trait(?Send)]
97impl OutputValidator for OutputDataSizeValidator {
98 async fn validate(&self, output: &[Output]) -> Result<(), Error> {
99 let actual_output_data_size: usize = output
100 .iter()
101 .map(|output| match output {
102 Output::Packet(p) => p.data.len(),
103 _ => 0,
104 })
105 .sum();
106
107 if actual_output_data_size != self.expected_output_data_size {
108 return Err(FatalError(format!(
109 "actual output data size: {}; expected output data size: {}",
110 actual_output_data_size, self.expected_output_data_size
111 ))
112 .into());
113 }
114
115 Ok(())
116 }
117}
118
119pub struct TerminatesWithValidator {
121 pub expected_terminal_output: Output,
122}
123
124#[async_trait(?Send)]
125impl OutputValidator for TerminatesWithValidator {
126 async fn validate(&self, output: &[Output]) -> Result<(), Error> {
127 let actual_terminal_output = output.last().ok_or_else(|| {
128 FatalError(format!(
129 "In terminal output: expected {:?}; found: None",
130 Some(&self.expected_terminal_output)
131 ))
132 })?;
133
134 if *actual_terminal_output == self.expected_terminal_output {
135 Ok(())
136 } else {
137 Err(FatalError(format!(
138 "In terminal output: expected {:?}; found: {:?}",
139 Some(&self.expected_terminal_output),
140 actual_terminal_output
141 ))
142 .into())
143 }
144 }
145}
146
147pub struct FormatValidator {
149 pub expected_format: FormatDetails,
150}
151
152#[async_trait(?Send)]
153impl OutputValidator for FormatValidator {
154 async fn validate(&self, output: &[Output]) -> Result<(), Error> {
155 let packets: Vec<&OutputPacket> = output_packets(output).collect();
156 let format = &packets
157 .first()
158 .ok_or_else(|| FatalError(String::from("No packets in output")))?
159 .format
160 .format_details;
161
162 if self.expected_format != *format {
163 return Err(FatalError(format!(
164 "Expected {:?}; got {:?}",
165 self.expected_format, format
166 ))
167 .into());
168 }
169
170 Ok(())
171 }
172}
173
174pub struct BytesValidator {
176 pub output_file: Option<&'static str>,
177 pub expected_digests: Vec<ExpectedDigest>,
178}
179
180impl BytesValidator {
181 fn write_and_hash(
182 &self,
183 mut writer: impl Write,
184 oob: &[u8],
185 packets: &[&OutputPacket],
186 ) -> Result<(), Error> {
187 let mut hasher = Sha256::default();
188
189 hasher.update(oob);
190
191 for packet in packets {
192 writer.write_all(&packet.data)?;
193 hasher.update(&packet.data);
194 }
195 writer.flush()?;
196
197 let digest = hasher.finish().bytes();
198
199 if let None = self.expected_digests.iter().find(|e| e.bytes == digest) {
200 return Err(FatalError(format!(
201 "Expected one of {:?}; got {}",
202 self.expected_digests,
203 encode(digest)
204 ))
205 .into());
206 }
207
208 Ok(())
209 }
210}
211
212fn output_writer(output_file: Option<&'static str>) -> Result<impl Write, Error> {
213 Ok(if let Some(file) = output_file {
214 Box::new(std::fs::File::create(file)?) as Box<dyn Write>
215 } else {
216 Box::new(std::io::sink()) as Box<dyn Write>
217 })
218}
219
220#[async_trait(?Send)]
221impl OutputValidator for BytesValidator {
222 async fn validate(&self, output: &[Output]) -> Result<(), Error> {
223 let packets: Vec<&OutputPacket> = output_packets(output).collect();
224 let oob = packets
225 .first()
226 .ok_or_else(|| FatalError(String::from("No packets in output")))?
227 .format
228 .format_details
229 .oob_bytes
230 .clone()
231 .unwrap_or(vec![]);
232
233 self.write_and_hash(output_writer(self.output_file)?, oob.as_slice(), &packets)
234 }
235}
236
237#[derive(Clone)]
238pub struct ExpectedDigest {
239 pub label: &'static str,
240 pub bytes: <<Sha256 as Hasher>::Digest as Digest>::Bytes,
241 pub per_frame_bytes: Option<Vec<<<Sha256 as Hasher>::Digest as Digest>::Bytes>>,
242}
243
244impl ExpectedDigest {
245 pub fn new(label: &'static str, hex: impl AsRef<[u8]>) -> Self {
246 Self {
247 label,
248 bytes: decode(hex)
249 .expect("Decoding static compile-time test hash as valid hex")
250 .as_slice()
251 .try_into()
252 .expect("Taking 32 bytes from compile-time test hash"),
253 per_frame_bytes: None,
254 }
255 }
256 pub fn new_with_per_frame_digest(
257 label: &'static str,
258 hex: impl AsRef<[u8]>,
259 per_frame_hexen: Vec<impl AsRef<[u8]>>,
260 ) -> Self {
261 Self {
262 per_frame_bytes: Some(
263 per_frame_hexen
264 .into_iter()
265 .map(|per_frame_hex| {
266 decode(per_frame_hex)
267 .expect("Decoding static compile-time test hash as valid hex")
268 .as_slice()
269 .try_into()
270 .expect("Taking 32 bytes from compile-time test hash")
271 })
272 .collect(),
273 ),
274 ..Self::new(label, hex)
275 }
276 }
277
278 pub fn new_from_raw(label: &'static str, raw_data: Vec<u8>) -> Self {
279 Self {
280 label,
281 bytes: <Sha256 as Hasher>::hash(raw_data.as_slice()).bytes(),
282 per_frame_bytes: None,
283 }
284 }
285}
286
287impl fmt::Display for ExpectedDigest {
288 fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
289 write!(w, "{:?}", self)
290 }
291}
292
293impl fmt::Debug for ExpectedDigest {
294 fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
295 write!(w, "ExpectedDigest {{\n")?;
296 write!(w, "\tlabel: {}", self.label)?;
297 write!(w, "\tbytes: {}", encode(self.bytes))?;
298 write!(w, "}}")
299 }
300}
301
/// Validator that compares output against reference data by root-mean-square
/// error (RMSE) rather than by exact bytes.
// NOTE(review): the reason for `allow(unused)` is not visible here; presumably
// not every user of this module constructs this validator — confirm.
#[allow(unused)]
pub struct RmseValidator<T> {
    /// When set, path of a file that receives the packet payloads.
    pub output_file: Option<&'static str>,
    /// Reference samples the converted output is compared against.
    pub expected_data: Vec<T>,
    /// RMSE the output is expected to have relative to `expected_data`.
    pub expected_rmse: f64,
    /// Largest allowed absolute difference between the measured RMSE and
    /// `expected_rmse`.
    pub rmse_diff_tolerance: f64,
    /// Largest allowed difference, in samples, between the lengths of the
    /// expected and actual data.
    pub data_len_diff_tolerance: u32,
    /// Converts the concatenated output bytes into samples of type `T`.
    pub output_converter: fn(Vec<u8>) -> Vec<T>,
}
315
316pub fn calculate_rmse<T: PrimInt + std::fmt::Debug>(
317 expected_data: &[T],
318 actual_data: &[T],
319 acceptable_len_diff: u32,
320) -> Result<f64, Error> {
321 let compare_len = std::cmp::min(expected_data.len(), actual_data.len());
328 if std::cmp::max(expected_data.len(), actual_data.len()) - compare_len
329 > acceptable_len_diff.try_into().unwrap()
330 {
331 return Err(FatalError(format!(
332 "Expected data (len {}) and the actual data (len {}) have significant length difference and cannot be compared.",
333 expected_data.len(), actual_data.len(),
334 )).into());
335 }
336 let expected_data = &expected_data[..compare_len];
337 let actual_data = &actual_data[..compare_len];
338
339 let mut rmse = 0.0;
340 let mut n = 0;
341 for data in std::iter::zip(actual_data.iter(), expected_data.iter()) {
342 let b1: f64 = num_traits::cast::cast(*data.0).unwrap();
343 let b2: f64 = num_traits::cast::cast(*data.1).unwrap();
344 rmse += (b1 - b2).powi(2);
345 n += 1;
346 }
347 Ok((rmse / n as f64).sqrt())
348}
349
350impl<T: PrimInt + std::fmt::Debug> RmseValidator<T> {
351 fn write_and_calc_rsme(
352 &self,
353 mut writer: impl Write,
354 packets: &[&OutputPacket],
355 ) -> Result<(), Error> {
356 let mut output_data: Vec<u8> = Vec::new();
357 for packet in packets {
358 writer.write_all(&packet.data)?;
359 packet.data.iter().for_each(|item| output_data.push(*item));
360 }
361
362 let actual_data = (self.output_converter)(output_data);
363
364 let rmse = calculate_rmse(
365 self.expected_data.as_slice(),
366 actual_data.as_slice(),
367 self.data_len_diff_tolerance,
368 )?;
369 info!("RMSE is {}", rmse);
370 if (rmse - self.expected_rmse).abs() > self.rmse_diff_tolerance {
371 return Err(FatalError(format!(
372 "expected rmse: {}; actual rmse: {}; rmse diff tolerance {}",
373 self.expected_rmse, rmse, self.rmse_diff_tolerance,
374 ))
375 .into());
376 }
377 Ok(())
378 }
379}
380
381#[async_trait(?Send)]
382impl<T: PrimInt + std::fmt::Debug> OutputValidator for RmseValidator<T> {
383 async fn validate(&self, output: &[Output]) -> Result<(), Error> {
384 let packets: Vec<&OutputPacket> = output_packets(output).collect();
385 self.write_and_calc_rsme(output_writer(self.output_file)?, &packets)
386 }
387}