use bitflags::bitflags;
use crc::Hasher32 as _;
use static_assertions::const_assert_eq;
use zerocopy::byteorder::{LE, U32, U64};
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned};
use crate::{DeliveryBlobError, DeliveryBlobHeader, DeliveryBlobType, Type1Blob};
const DELIVERY_BLOB_MAGIC: [u8; 4] = [0xfc, 0x1a, 0xb1, 0x0b];
const_assert_eq!(std::mem::size_of::<SerializedHeader>(), 12);
const_assert_eq!(
std::mem::size_of::<SerializedType1Blob>(),
std::mem::size_of::<SerializedHeader>() + 16
);
bitflags! {
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct SerializedType1Flags : u32 {
const IS_COMPRESSED = 0x00000001;
const VALID_FLAGS_MASK = Self::IS_COMPRESSED.bits();
}
}
impl From<&Type1Blob> for SerializedType1Flags {
fn from(value: &Type1Blob) -> Self {
if value.is_compressed {
SerializedType1Flags::IS_COMPRESSED
} else {
SerializedType1Flags::empty()
}
}
}
#[derive(IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned, Clone, Copy, Debug)]
#[repr(C)]
pub(crate) struct SerializedHeader {
magic: [u8; 4],
delivery_type: U32<LE>,
header_length: U32<LE>,
}
impl SerializedHeader {
pub fn decode(&self) -> Result<DeliveryBlobHeader, DeliveryBlobError> {
if self.magic != DELIVERY_BLOB_MAGIC {
return Err(DeliveryBlobError::BadMagic);
}
Ok(DeliveryBlobHeader {
delivery_type: self.delivery_type.get().try_into()?,
header_length: self.header_length.get(),
})
}
}
impl From<&DeliveryBlobHeader> for SerializedHeader {
fn from(value: &DeliveryBlobHeader) -> Self {
Self {
magic: DELIVERY_BLOB_MAGIC,
delivery_type: Into::<u32>::into(value.delivery_type).into(),
header_length: value.header_length.into(),
}
}
}
#[derive(IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned, Clone, Copy, Debug)]
#[repr(C)]
pub(crate) struct SerializedType1Blob {
header: SerializedHeader,
payload_length: U64<LE>,
checksum: U32<LE>,
flags: U32<LE>,
}
impl From<Type1Blob> for SerializedType1Blob {
fn from(value: Type1Blob) -> Self {
let serialized = Self {
header: (&value.header).into(),
payload_length: (value.payload_length as u64).into(),
checksum: Default::default(), flags: SerializedType1Flags::from(&value).bits().into(),
};
Self { checksum: serialized.checksum().into(), ..serialized }
}
}
impl SerializedType1Blob {
pub fn checksum(&self) -> u32 {
let header = Self { checksum: 0.into(), ..*self };
let mut digest = crc::crc32::Digest::new(crc::crc32::IEEE);
digest.write(header.as_bytes());
digest.sum32()
}
pub fn decode(&self) -> Result<Type1Blob, DeliveryBlobError> {
if self.checksum.get() != self.checksum() {
return Err(DeliveryBlobError::IntegrityError);
}
let header: DeliveryBlobHeader = self.header.decode()?;
if header.delivery_type != DeliveryBlobType::Type1 {
return Err(DeliveryBlobError::InvalidType);
}
if header.header_length != Type1Blob::HEADER.header_length {
return Err(DeliveryBlobError::IntegrityError);
}
let payload_length = self.payload_length.get() as usize;
let flags = SerializedType1Flags::from_bits(self.flags.get())
.ok_or(DeliveryBlobError::IntegrityError)?;
Ok(Type1Blob {
header,
payload_length,
is_compressed: flags.contains(SerializedType1Flags::IS_COMPRESSED),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::CompressionMode;
const TEST_DATA: &[u8] = &[1, 2, 3, 4];
#[test]
fn type_1_round_trip_uncompressed() {
let delivery_blob = Type1Blob::generate(TEST_DATA, CompressionMode::Never);
assert_eq!(
delivery_blob.len(),
std::mem::size_of::<SerializedType1Blob>() + TEST_DATA.len()
);
let (metadata, payload) = Type1Blob::parse(&delivery_blob).unwrap().unwrap();
assert_eq!(metadata.header, Type1Blob::HEADER);
assert_eq!(metadata.payload_length, TEST_DATA.len());
assert_eq!(metadata.is_compressed, false);
assert_eq!(payload, TEST_DATA);
}
#[test]
fn type_1_round_trip_empty() {
let delivery_blob = Type1Blob::generate(&[], CompressionMode::Never);
assert_eq!(delivery_blob.len(), std::mem::size_of::<SerializedType1Blob>());
let (metadata, payload) = Type1Blob::parse(&delivery_blob).unwrap().unwrap();
assert_eq!(metadata.header, Type1Blob::HEADER);
assert_eq!(metadata.payload_length, 0);
assert_eq!(metadata.is_compressed, false);
assert!(payload.is_empty());
}
#[test]
fn type_1_not_enough_data() {
let delivery_blob = Type1Blob::generate(TEST_DATA, CompressionMode::Never);
let not_enough_data = &delivery_blob[..std::mem::size_of::<SerializedType1Blob>() - 1];
assert!(Type1Blob::parse(not_enough_data).unwrap().is_none());
}
#[test]
fn type_1_bad_magic() {
let valid: SerializedType1Blob =
Type1Blob { header: Type1Blob::HEADER, payload_length: 0, is_compressed: false }
.try_into()
.unwrap();
assert!(Type1Blob::parse(valid.as_bytes()).is_ok());
let mut has_corrupt_magic = SerializedType1Blob {
header: SerializedHeader { magic: [0, 0, 0, 0], ..valid.header },
..valid
};
has_corrupt_magic.checksum = has_corrupt_magic.checksum().into();
assert_eq!(
Type1Blob::parse(has_corrupt_magic.as_bytes()).unwrap_err(),
DeliveryBlobError::BadMagic
);
}
#[test]
fn type_1_invalid_type() {
let has_invalid_type: SerializedType1Blob = Type1Blob {
header: DeliveryBlobHeader {
delivery_type: DeliveryBlobType::Reserved,
header_length: 0,
},
payload_length: 0,
is_compressed: false,
}
.try_into()
.unwrap();
assert_eq!(
Type1Blob::parse(has_invalid_type.as_bytes()).unwrap_err(),
DeliveryBlobError::InvalidType
);
}
#[test]
fn type_1_invalid_header_length() {
let has_invalid_header_length: SerializedType1Blob = Type1Blob {
header: DeliveryBlobHeader {
delivery_type: DeliveryBlobType::Type1,
header_length: Type1Blob::HEADER.header_length + 1,
},
payload_length: 0,
is_compressed: false,
}
.try_into()
.unwrap();
assert_eq!(
Type1Blob::parse(has_invalid_header_length.as_bytes()).unwrap_err(),
DeliveryBlobError::IntegrityError
);
}
#[test]
fn type_1_verify_checksum() {
let serialized: SerializedType1Blob = Type1Blob {
header: Type1Blob::HEADER,
payload_length: TEST_DATA.len(),
is_compressed: false,
}
.try_into()
.unwrap();
assert!(serialized.decode().is_ok());
assert_eq!(serialized.checksum.get(), serialized.checksum());
let corrupted_checksum: u32 = !(serialized.checksum.get());
let corrupted_blob =
SerializedType1Blob { checksum: corrupted_checksum.into(), ..serialized };
assert_eq!(
Type1Blob::parse(corrupted_blob.as_bytes()).unwrap_err(),
DeliveryBlobError::IntegrityError
);
}
}