Skip to main content

delivery_blob/
compression.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Implementation of chunked-compression library in Rust. Archives can be created by making a new
6//! [`ChunkedArchive`] and serializing/writing it. An archive's header can be verified and seek
7//! table decoded using [`decode_archive`].
8
9use itertools::Itertools;
10use rayon::prelude::*;
11use std::ops::Range;
12use thiserror::Error;
13use zerocopy::byteorder::{LE, U16, U32, U64};
14use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned};
15
16mod compression_algorithm;
17pub use compression_algorithm::{
18    CompressionAlgorithm, Compressor, Decompressor, ThreadLocalCompressor, ThreadLocalDecompressor,
19};
20
/// Errors that can occur while creating, decoding, or decompressing a chunked archive.
// *NOTE*: Use caution when using the `#[source]` attribute or naming fields `source`. Some callers
// attempt to downcast library errors into the concrete type of the root cause.
// See https://docs.rs/thiserror/latest/thiserror/ for more information.
#[derive(Debug, Error)]
pub enum ChunkedArchiveError {
    #[error("Invalid or unsupported archive version.")]
    InvalidVersion,

    #[error("Archive header has incorrect magic.")]
    BadMagic,

    #[error("Integrity checks failed (e.g. incorrect CRC, inconsistent header fields).")]
    IntegrityError,

    #[error("Value is out of range or cannot be represented in specified type.")]
    OutOfRange,

    #[error("Error decompressing chunk {index}: `{error}`.")]
    DecompressionError { index: usize, error: std::io::Error },

    #[error("Error compressing chunk {index}: `{error}`.")]
    CompressionError { index: usize, error: std::io::Error },
}
46
/// Options for constructing a chunked archive.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ChunkedArchiveOptions {
    /// A chunked-compression V2 archive will be created.
    V2 {
        /// Chunked-compression V2 has a limit of 1023 chunks. If splitting the data up into
        /// `minimum_chunk_size`d chunks would exceed this limit then the chunk size is increased
        /// by `chunk_alignment` until fewer than 1024 are required. `minimum_chunk_size` must be
        /// a multiple of `chunk_alignment`.
        minimum_chunk_size: usize,
        /// The chosen uncompressed chunk size must always be a multiple of this value.
        chunk_alignment: usize,
        /// The Zstd compression level to use when compressing chunks.
        compression_level: i32,
    },
    /// A chunked-compression V3 archive will be created.
    V3 {
        /// The compression algorithm to use to compress the chunks.
        compression_algorithm: CompressionAlgorithm,
    },
}
68
69impl ChunkedArchiveOptions {
70    const V2_VERSION: u16 = 2;
71    const V2_MAX_CHUNKS: usize = 1023;
72
73    const V3_VERSION: u16 = 3;
74    const V3_MAX_CHUNKS: usize = u32::MAX as usize;
75    const V3_CHUNK_SIZE: usize = 32 * 1024;
76    const V3_ZSTD_COMPRESSION_LEVEL: i32 = 22;
77
78    /// Which version of chunked-compression archive should be constructed.
79    fn version(&self) -> u16 {
80        match self {
81            Self::V2 { .. } => Self::V2_VERSION,
82            Self::V3 { .. } => Self::V3_VERSION,
83        }
84    }
85
86    /// The compression algorithm to use to compress the chunks.
87    fn compression_algorithm(&self) -> CompressionAlgorithm {
88        match self {
89            Self::V2 { .. } => CompressionAlgorithm::Zstd,
90            Self::V3 { compression_algorithm } => *compression_algorithm,
91        }
92    }
93
94    /// Calculate how large chunks must be for a given amount of data.
95    fn chunk_size_for(&self, data_size: usize) -> usize {
96        match self {
97            Self::V2 { chunk_alignment, minimum_chunk_size: target_chunk_size, .. } => {
98                if data_size <= (Self::V2_MAX_CHUNKS * target_chunk_size) {
99                    *target_chunk_size
100                } else {
101                    let chunk_size = data_size.div_ceil(Self::V2_MAX_CHUNKS);
102                    chunk_size.checked_next_multiple_of(*chunk_alignment).unwrap()
103                }
104            }
105            Self::V3 { .. } => {
106                assert!(
107                    data_size.div_ceil(Self::V3_CHUNK_SIZE) <= Self::V3_MAX_CHUNKS,
108                    "Chunked-compression V3 only supports data up to ~140TB"
109                );
110                Self::V3_CHUNK_SIZE
111            }
112        }
113    }
114
115    /// Constructs a compressor to compress chunks based on the specified options.
116    pub fn compressor(&self) -> Compressor {
117        match self {
118            Self::V2 { compression_level, .. } => {
119                let mut compressor = zstd::bulk::Compressor::default();
120                compressor
121                    .set_parameter(zstd::zstd_safe::CParameter::CompressionLevel(
122                        *compression_level,
123                    ))
124                    .expect("setting the compression level should never fail");
125                Compressor::Zstd(compressor)
126            }
127            Self::V3 { compression_algorithm: CompressionAlgorithm::Zstd } => {
128                let mut compressor = zstd::bulk::Compressor::default();
129                compressor
130                    .set_parameter(zstd::zstd_safe::CParameter::CompressionLevel(
131                        Self::V3_ZSTD_COMPRESSION_LEVEL,
132                    ))
133                    .expect("setting the compression level should never fail");
134                Compressor::Zstd(compressor)
135            }
136            Self::V3 { compression_algorithm: CompressionAlgorithm::Lz4 } => {
137                Compressor::Lz4 { compression_level: lz4::HcCompressionLevel::custom(12) }
138            }
139        }
140    }
141
142    /// Constructs a compressor object that uses a thread local compressor to compress chunks based
143    /// on the specified options.
144    pub fn thread_local_compressor(&self) -> ThreadLocalCompressor {
145        match self {
146            Self::V2 { compression_level, .. } => {
147                ThreadLocalCompressor::Zstd { compression_level: *compression_level }
148            }
149            Self::V3 { compression_algorithm: CompressionAlgorithm::Zstd } => {
150                ThreadLocalCompressor::Zstd { compression_level: Self::V3_ZSTD_COMPRESSION_LEVEL }
151            }
152            Self::V3 { compression_algorithm: CompressionAlgorithm::Lz4 } => {
153                ThreadLocalCompressor::Lz4 {
154                    compression_level: lz4::HcCompressionLevel::custom(12),
155                }
156            }
157        }
158    }
159
160    /// Returns true if `version` is a valid chunked-compression version.
161    fn is_valid_version(version: u16) -> bool {
162        match version {
163            Self::V2_VERSION => true,
164            Self::V3_VERSION => true,
165            _ => false,
166        }
167    }
168
169    /// Returns the maximum number of chunks supported by the chunked-compression format at the
170    /// specified version.
171    fn max_chunks_for_version(version: u16) -> Result<usize, ChunkedArchiveError> {
172        match version {
173            Self::V2_VERSION => Ok(Self::V2_MAX_CHUNKS),
174            Self::V3_VERSION => Ok(Self::V3_MAX_CHUNKS),
175            _ => Err(ChunkedArchiveError::InvalidVersion),
176        }
177    }
178}
179
/// Validated chunk information from an archive. Compressed ranges are relative to the start of
/// compressed data (i.e. they start after the header and seek table).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ChunkInfo {
    /// Byte range this chunk occupies in the decompressed output.
    pub decompressed_range: Range<usize>,
    /// Byte range of this chunk's compressed bytes, relative to the end of the header/seek table.
    pub compressed_range: Range<usize>,
}
187
188impl ChunkInfo {
189    fn from_entry(
190        entry: &SeekTableEntry,
191        header_length: usize,
192    ) -> Result<Self, ChunkedArchiveError> {
193        let decompressed_start = entry.decompressed_offset.get() as usize;
194        let decompressed_size = entry.decompressed_size.get() as usize;
195        let decompressed_range = decompressed_start
196            ..decompressed_start
197                .checked_add(decompressed_size)
198                .ok_or(ChunkedArchiveError::OutOfRange)?;
199
200        let compressed_offset = entry.compressed_offset.get() as usize;
201        let compressed_start = compressed_offset
202            .checked_sub(header_length)
203            .ok_or(ChunkedArchiveError::IntegrityError)?;
204        let compressed_size = entry.compressed_size.get() as usize;
205        let compressed_range = compressed_start
206            ..compressed_start
207                .checked_add(compressed_size)
208                .ok_or(ChunkedArchiveError::OutOfRange)?;
209
210        Ok(Self { decompressed_range, compressed_range })
211    }
212}
213
/// Validated information from decoding an archive.
#[derive(Debug)]
pub struct DecodedArchive {
    // Algorithm used to compress this archive's chunks.
    compression_algorithm: CompressionAlgorithm,
    // One validated entry per chunk, in archive order.
    seek_table: Vec<ChunkInfo>,
}
220
221impl DecodedArchive {
222    /// The total size of decompressing all of the chunks in the archive.
223    pub fn decompressed_size(&self) -> usize {
224        self.seek_table.last().map_or(0, |entry| entry.decompressed_range.end)
225    }
226}
227
/// Decodes a chunked archive header. Returns a `DecodedArchive` and any remaining bytes that are
/// part of the chunk data. Returns `Ok(None)` if `data` is not large enough to decode the archive
/// header & seek table.
pub fn decode_archive(
    data: &[u8],
    archive_length: usize,
) -> Result<Option<(DecodedArchive, /*archive_data*/ &[u8])>, ChunkedArchiveError> {
    // Split the fixed-size header off the front of `data`. The only failure zerocopy can report
    // here is insufficient data, which is surfaced as `Ok(None)` rather than an error so callers
    // can retry once more bytes arrive.
    match Ref::<_, ChunkedArchiveHeader>::from_prefix(data).map_err(Into::into) {
        Ok((header, data)) => header.decode_archive(data, archive_length as u64),
        Err(zerocopy::SizeError { .. }) => Ok(None), // Not enough data.
    }
}
240
/// Chunked archive header.
#[derive(IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned, Clone, Copy, Debug)]
#[repr(C)]
struct ChunkedArchiveHeader {
    /// Must equal `CHUNKED_ARCHIVE_MAGIC`.
    magic: [u8; 8],
    /// Archive format version (2 or 3).
    version: U16<LE>,
    // This field was added in V3 and should not be used if `version` is 2. Technically, this field
    // should be 0 in V2, Zstd has the value 0, and V2 always uses Zstd so accessing this field in
    // V2 should give the correct result.
    compression_algorithm: u8,
    reserved_0: u8,
    /// Number of seek table entries that immediately follow the header.
    num_entries: U32<LE>,
    /// CRC-32 (ISO HDLC) over the header bytes (excluding this field) plus all seek table entries.
    checksum: U32<LE>,
    reserved_1: U32<LE>,
    reserved_2: U64<LE>,
}
257
/// Chunked archive seek table entry.
#[derive(IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned, Clone, Copy, Debug)]
#[repr(C)]
struct SeekTableEntry {
    /// Offset of this chunk within the decompressed data.
    decompressed_offset: U64<LE>,
    /// Size of this chunk when decompressed.
    decompressed_size: U64<LE>,
    /// Absolute offset of this chunk's compressed bytes within the archive (header included).
    compressed_offset: U64<LE>,
    /// Size of this chunk's compressed bytes.
    compressed_size: U64<LE>,
}
267
impl ChunkedArchiveHeader {
    const CHUNKED_ARCHIVE_MAGIC: [u8; 8] = [0x46, 0x9b, 0x78, 0xef, 0x0f, 0xd0, 0xb2, 0x03];
    // Byte offset of the `checksum` field: magic (8) + version (2) + compression_algorithm (1)
    // + reserved_0 (1) + num_entries (4) = 16.
    const CHUNKED_ARCHIVE_CHECKSUM_OFFSET: usize = 16;

    /// Constructs a header describing `seek_table` with `options`, including the computed
    /// checksum. Returns `OutOfRange` if the number of entries does not fit in a `u32`.
    fn new(
        seek_table: &[SeekTableEntry],
        options: ChunkedArchiveOptions,
    ) -> Result<Self, ChunkedArchiveError> {
        let header: ChunkedArchiveHeader = Self {
            magic: Self::CHUNKED_ARCHIVE_MAGIC,
            version: options.version().into(),
            compression_algorithm: options.compression_algorithm().into(),
            reserved_0: 0.into(),
            num_entries: TryInto::<u32>::try_into(seek_table.len())
                .or(Err(ChunkedArchiveError::OutOfRange))?
                .into(),
            checksum: 0.into(), // `checksum` is calculated below.
            reserved_1: 0.into(),
            reserved_2: 0.into(),
        };
        Ok(Self { checksum: header.checksum(seek_table).into(), ..header })
    }

    /// Calculate the checksum of the header + all seek table entries.
    fn checksum(&self, entries: &[SeekTableEntry]) -> u32 {
        let crc_algo = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
        let mut digest = crc_algo.digest();
        // Hash the header with the stored checksum bytes skipped entirely (not zeroed), so the
        // same routine computes and verifies the checksum.
        digest.update(&self.as_bytes()[..Self::CHUNKED_ARCHIVE_CHECKSUM_OFFSET]);
        digest.update(
            &self.as_bytes()
                [Self::CHUNKED_ARCHIVE_CHECKSUM_OFFSET + self.checksum.as_bytes().len()..],
        );
        digest.update(entries.as_bytes());
        digest.finalize()
    }

    /// Calculate the total header length of an archive *including* all seek table entries.
    fn header_length(num_entries: usize) -> usize {
        std::mem::size_of::<ChunkedArchiveHeader>()
            + (std::mem::size_of::<SeekTableEntry>() * num_entries)
    }

    /// Validates the archive header and decodes the seek table.
    ///
    /// Returns `Ok(None)` if `data` does not yet contain the full seek table. On success, returns
    /// the validated [`DecodedArchive`] and the remaining bytes of `data` (the chunk data).
    ///
    /// # Errors
    ///
    /// * [`ChunkedArchiveError::BadMagic`] if the magic bytes are incorrect.
    /// * [`ChunkedArchiveError::InvalidVersion`] if the version is not a known version.
    /// * [`ChunkedArchiveError::IntegrityError`] if the checksum, entry-count limit, or any of
    ///   the seek table invariants I0 through I5 below fail.
    fn decode_archive(
        self,
        data: &[u8],
        archive_length: u64,
    ) -> Result<Option<(DecodedArchive, /*chunk_data*/ &[u8])>, ChunkedArchiveError> {
        // Deserialize seek table.
        let num_entries = self.num_entries.get() as usize;
        let Ok((entries, chunk_data)) =
            Ref::<_, [SeekTableEntry]>::from_prefix_with_elems(data, num_entries)
        else {
            return Ok(None);
        };
        let entries: &[SeekTableEntry] = Ref::into_ref(entries);

        // Validate archive header.
        if self.magic != Self::CHUNKED_ARCHIVE_MAGIC {
            return Err(ChunkedArchiveError::BadMagic);
        }
        let version = self.version.get();
        if !ChunkedArchiveOptions::is_valid_version(version) {
            return Err(ChunkedArchiveError::InvalidVersion);
        }
        if self.checksum.get() != self.checksum(entries) {
            return Err(ChunkedArchiveError::IntegrityError);
        }
        if entries.len() > ChunkedArchiveOptions::max_chunks_for_version(version)? {
            return Err(ChunkedArchiveError::IntegrityError);
        }
        let compression_algorithm = CompressionAlgorithm::try_from(self.compression_algorithm)?;

        // Validate seek table using invariants I0 through I5.

        // I0: The first seek table entry, if any, must have decompressed offset 0.
        if !entries.is_empty() && entries[0].decompressed_offset.get() != 0 {
            return Err(ChunkedArchiveError::IntegrityError);
        }

        // I1: The compressed offsets of all seek table entries must not overlap with the header.
        let header_length = Self::header_length(entries.len());
        if entries.iter().any(|entry| entry.compressed_offset.get() < header_length as u64) {
            return Err(ChunkedArchiveError::IntegrityError);
        }

        // NOTE(review): the `+` sums in I2, I3, and I5 below are unchecked u64 additions; with
        // hostile offsets near u64::MAX they can wrap in release builds (overflow checks off).
        // `ChunkInfo::from_entry` later redoes these sums with checked arithmetic, so overflow is
        // still caught before a `ChunkInfo` is produced — confirm this ordering is intentional,
        // or use `checked_add` here as well.

        // I2: Each entry's decompressed offset must be equal to the end of the previous frame
        //     (i.e. to the previous frame's decompressed offset + length).
        for (prev, curr) in entries.iter().tuple_windows() {
            if (prev.decompressed_offset.get() + prev.decompressed_size.get())
                != curr.decompressed_offset.get()
            {
                return Err(ChunkedArchiveError::IntegrityError);
            }
        }

        // I3: Each entry's compressed offset must be greater than or equal to the end of the
        //     previous frame (i.e. to the previous frame's compressed offset + length).
        for (prev, curr) in entries.iter().tuple_windows() {
            if (prev.compressed_offset.get() + prev.compressed_size.get())
                > curr.compressed_offset.get()
            {
                return Err(ChunkedArchiveError::IntegrityError);
            }
        }

        // I4: Each entry must have a non-zero decompressed and compressed length.
        for entry in entries.iter() {
            if entry.decompressed_size.get() == 0 || entry.compressed_size.get() == 0 {
                return Err(ChunkedArchiveError::IntegrityError);
            }
        }

        // I5: Data referenced by each entry must fit within the specified file size.
        for entry in entries.iter() {
            let compressed_end = entry.compressed_offset.get() + entry.compressed_size.get();
            if compressed_end > archive_length {
                return Err(ChunkedArchiveError::IntegrityError);
            }
        }

        let seek_table = entries
            .iter()
            .map(|entry| ChunkInfo::from_entry(entry, header_length))
            .try_collect()?;
        Ok(Some((DecodedArchive { seek_table, compression_algorithm }, chunk_data)))
    }
}
396
/// In-memory representation of a compressed chunk. Produced by [`ChunkedArchive::new`].
pub struct CompressedChunk {
    /// Compressed data for this chunk.
    pub compressed_data: Vec<u8>,
    /// Size of this chunk when decompressed.
    pub decompressed_size: usize,
}
404
/// In-memory representation of a compressed chunked archive. Create with [`ChunkedArchive::new`]
/// and serialize with [`ChunkedArchive::write`].
pub struct ChunkedArchive {
    /// Chunks this archive contains, in order. Right now we only allow creating archives with
    /// contiguous compressed and decompressed space.
    chunks: Vec<CompressedChunk>,
    /// Size used to chunk input when creating this archive. Last chunk may be smaller than this
    /// amount.
    chunk_size: usize,
    /// The options used to construct this archive.
    options: ChunkedArchiveOptions,
}
416
417impl ChunkedArchive {
418    /// Create a ChunkedArchive for `data` compressing each chunk in parallel. This function uses
419    /// the `rayon` crate for parallelism. By default compression happens in the global thread pool,
420    /// but this function can also be executed within a locally scoped pool.
421    pub fn new(data: &[u8], options: ChunkedArchiveOptions) -> Result<Self, ChunkedArchiveError> {
422        let chunk_size = options.chunk_size_for(data.len());
423        let mut chunks: Vec<Result<CompressedChunk, ChunkedArchiveError>> = vec![];
424        let compressor = options.thread_local_compressor();
425        data.par_chunks(chunk_size)
426            .enumerate()
427            .map(|(index, chunk)| {
428                let compressed_data = compressor.compress(chunk, index)?;
429                Ok(CompressedChunk { compressed_data, decompressed_size: chunk.len() })
430            })
431            .collect_into_vec(&mut chunks);
432        let chunks: Vec<_> = chunks.into_iter().try_collect()?;
433        Ok(ChunkedArchive { chunks, chunk_size, options })
434    }
435
436    /// Accessor for compressed chunk data.
437    pub fn chunks(&self) -> &Vec<CompressedChunk> {
438        &self.chunks
439    }
440
441    /// The chunk size calculated for this archive during compression. Represents how input data
442    /// was chunked for compression. Note that the final chunk may be smaller than this amount
443    /// when decompressed.
444    pub fn chunk_size(&self) -> usize {
445        self.chunk_size
446    }
447
448    /// Sum of sizes of all compressed chunks.
449    pub fn compressed_data_size(&self) -> usize {
450        self.chunks.iter().map(|chunk| chunk.compressed_data.len()).sum()
451    }
452
453    /// Total size of the archive in bytes.
454    pub fn serialized_size(&self) -> usize {
455        ChunkedArchiveHeader::header_length(self.chunks.len()) + self.compressed_data_size()
456    }
457
458    /// Write the archive to `writer`.
459    pub fn write(self, mut writer: impl std::io::Write) -> Result<(), std::io::Error> {
460        let seek_table = self.make_seek_table();
461        let header = ChunkedArchiveHeader::new(&seek_table, self.options).unwrap();
462        writer.write_all(header.as_bytes())?;
463        writer.write_all(seek_table.as_slice().as_bytes())?;
464        for chunk in self.chunks {
465            writer.write_all(&chunk.compressed_data)?;
466        }
467        Ok(())
468    }
469
470    /// Create the seek table for this archive.
471    fn make_seek_table(&self) -> Vec<SeekTableEntry> {
472        let header_length = ChunkedArchiveHeader::header_length(self.chunks.len());
473        let mut seek_table = vec![];
474        seek_table.reserve(self.chunks.len());
475        let mut compressed_size: usize = 0;
476        let mut decompressed_offset: usize = 0;
477        for chunk in &self.chunks {
478            seek_table.push(SeekTableEntry {
479                decompressed_offset: (decompressed_offset as u64).into(),
480                decompressed_size: (chunk.decompressed_size as u64).into(),
481                compressed_offset: ((header_length + compressed_size) as u64).into(),
482                compressed_size: (chunk.compressed_data.len() as u64).into(),
483            });
484            compressed_size += chunk.compressed_data.len();
485            decompressed_offset += chunk.decompressed_size;
486        }
487        seek_table
488    }
489}
490
/// Streaming decompressor for chunked archives. Example:
/// ```
/// // Create a chunked archive:
/// let data: Vec<u8> = vec![3; 1024];
/// let archive = ChunkedArchive::new(&data, options)?;
/// let mut compressed: Vec<u8> = vec![];
/// archive.write(&mut compressed)?;
/// // Verify the header + decode the seek table:
/// let (decoded, archive_data) = decode_archive(&compressed, compressed.len())?.unwrap();
/// let mut decompressed: Vec<u8> = vec![];
/// let mut on_chunk = |data: &[u8]| { decompressed.extend_from_slice(data); };
/// let mut decompressor = ChunkedDecompressor::new(decoded)?;
/// // `on_chunk` is invoked as each chunk is decompressed. The archive can be provided in
/// // arbitrary increments across multiple calls to `update`.
/// decompressor.update(archive_data, &mut on_chunk)?;
/// assert_eq!(data.as_slice(), decompressed.as_slice());
/// ```
pub struct ChunkedDecompressor {
    /// Validated seek table decoded from the archive header.
    seek_table: Vec<ChunkInfo>,
    /// Buffered bytes of a partially received chunk, carried across calls to `update`.
    buffer: Vec<u8>,
    /// Total number of compressed bytes received so far across all calls to `update`.
    data_written: usize,
    /// Index into `seek_table` of the next chunk to decompress.
    curr_chunk: usize,
    /// Total compressed size of all chunks (end of the last chunk's compressed range).
    total_compressed_size: usize,
    /// Decompressor for the archive's compression algorithm.
    decompressor: Decompressor,
    /// Scratch buffer each chunk is decompressed into before invoking the chunk callback.
    decompressed_buffer: Vec<u8>,
    /// Optional callback invoked when a chunk fails to decompress.
    error_handler: Option<ErrorHandler>,
}
515
516type ErrorHandler = Box<dyn Fn(usize, ChunkInfo, &[u8]) -> () + Send + 'static>;
517
impl ChunkedDecompressor {
    /// Create a new decompressor to decode an archive from a validated seek table.
    pub fn new(decoded_archive: DecodedArchive) -> Result<Self, ChunkedArchiveError> {
        let DecodedArchive { compression_algorithm, seek_table } = decoded_archive;
        // The end of the last chunk's compressed range is the total amount of chunk data that
        // `update` expects to receive overall.
        let total_compressed_size =
            seek_table.last().map_or(0, |last_chunk| last_chunk.compressed_range.end);
        // Scratch buffer sized to the first chunk's decompressed length.
        // NOTE(review): this assumes no later chunk decompresses to more than the first chunk
        // (i.e. a fixed chunk size with only the final chunk possibly smaller). A larger later
        // chunk would make `decompress_into` fail rather than overflow — confirm this matches the
        // archives produced by `ChunkedArchive`.
        let decompressed_buffer =
            vec![0u8; seek_table.first().map_or(0, |c| c.decompressed_range.len())];
        Ok(Self {
            seek_table,
            buffer: vec![],
            data_written: 0,
            curr_chunk: 0,
            total_compressed_size,
            decompressor: compression_algorithm.decompressor(),
            decompressed_buffer,
            error_handler: None,
        })
    }

    /// Creates a new decompressor with an additional error handler invoked when a chunk fails to be
    /// decompressed.
    pub fn new_with_error_handler(
        decoded_archive: DecodedArchive,
        error_handler: ErrorHandler,
    ) -> Result<Self, ChunkedArchiveError> {
        Ok(Self { error_handler: Some(error_handler), ..Self::new(decoded_archive)? })
    }

    /// The validated seek table this decompressor was constructed with.
    pub fn seek_table(&self) -> &Vec<ChunkInfo> {
        &self.seek_table
    }

    /// Decompresses one complete chunk (`data` must be exactly the current chunk's compressed
    /// bytes), invokes `chunk_callback` with the decompressed bytes, and advances to the next
    /// chunk. Invokes the error handler (if set) when decompression fails.
    fn finish_chunk(
        &mut self,
        data: &[u8],
        chunk_callback: &mut impl FnMut(&[u8]) -> (),
    ) -> Result<(), ChunkedArchiveError> {
        debug_assert_eq!(data.len(), self.seek_table[self.curr_chunk].compressed_range.len());
        let chunk = &self.seek_table[self.curr_chunk];
        let decompressed_size = self
            .decompressor
            .decompress_into(data, self.decompressed_buffer.as_mut_slice(), self.curr_chunk)
            .inspect_err(|_| {
                // Give the error handler the failing chunk's metadata and raw bytes before
                // propagating the error.
                if let Some(error_handler) = &self.error_handler {
                    error_handler(self.curr_chunk, chunk.clone(), data.as_bytes());
                }
            })?;
        // The chunk must decompress to exactly the size recorded in the seek table.
        if decompressed_size != chunk.decompressed_range.len() {
            return Err(ChunkedArchiveError::IntegrityError);
        }
        chunk_callback(&self.decompressed_buffer[..decompressed_size]);
        self.curr_chunk += 1;
        Ok(())
    }

    /// Update the decompressor with more data.
    pub fn update(
        &mut self,
        mut data: &[u8],
        chunk_callback: &mut impl FnMut(&[u8]) -> (),
    ) -> Result<(), ChunkedArchiveError> {
        // Caller must not provide too much data.
        if self.data_written + data.len() > self.total_compressed_size {
            return Err(ChunkedArchiveError::OutOfRange);
        }
        self.data_written += data.len();

        // If we had leftover data from a previous read, append until we've filled a chunk.
        if !self.buffer.is_empty() {
            let to_read = std::cmp::min(
                data.len(),
                self.seek_table[self.curr_chunk]
                    .compressed_range
                    .len()
                    .checked_sub(self.buffer.len())
                    .unwrap(),
            );
            self.buffer.extend_from_slice(&data[..to_read]);
            if self.buffer.len() == self.seek_table[self.curr_chunk].compressed_range.len() {
                // Take self.buffer temporarily (so we don't have to split borrows).
                // That way we don't have to re-commit the pages we've already used in the buffer
                // for next time.
                let full_chunk = std::mem::take(&mut self.buffer);
                self.finish_chunk(&full_chunk[..], chunk_callback)?;
                self.buffer = full_chunk;
                // Draining the buffer will set the length to 0 but keep the capacity the same.
                self.buffer.drain(..);
            }
            data = &data[to_read..];
        }

        // Decode as many full chunks as we can.
        while !data.is_empty()
            && self.curr_chunk < self.seek_table.len()
            && self.seek_table[self.curr_chunk].compressed_range.len() <= data.len()
        {
            let len = self.seek_table[self.curr_chunk].compressed_range.len();
            self.finish_chunk(&data[..len], chunk_callback)?;
            data = &data[len..];
        }

        // Buffer the rest for the next call.
        if !data.is_empty() {
            debug_assert!(self.curr_chunk < self.seek_table.len());
            debug_assert!(self.data_written < self.total_compressed_size);
            self.buffer.extend_from_slice(data);
        }

        // Either more data is still expected, or every chunk has been decompressed.
        debug_assert!(
            self.data_written < self.total_compressed_size
                || self.curr_chunk == self.seek_table.len()
        );

        Ok(())
    }
}
635
636#[cfg(test)]
637mod tests {
638    use crate::Type1Blob;
639
640    use super::*;
641    use rand::Rng;
642    use std::matches;
643
644    /// Create a compressed archive and ensure we can decode it as a valid archive that passes all
645    /// required integrity checks.
646    #[test]
647    fn compress_simple() {
648        let data: Vec<u8> = vec![0; 32 * 1024 * 16];
649        let archive = ChunkedArchive::new(&data, Type1Blob::CHUNKED_ARCHIVE_OPTIONS).unwrap();
650        // This data is highly compressible, so the result should be smaller than the original.
651        let mut compressed: Vec<u8> = vec![];
652        archive.write(&mut compressed).unwrap();
653        assert!(compressed.len() <= data.len());
654        // We should be able to decode and verify the archive's integrity in-place.
655        assert!(decode_archive(&compressed, compressed.len()).unwrap().is_some());
656    }
657
658    /// Generate a header + seek table for verifying invariants/integrity checks.
659    fn generate_archive(
660        num_entries: usize,
661        options: ChunkedArchiveOptions,
662    ) -> (ChunkedArchiveHeader, Vec<SeekTableEntry>, /*archive_length*/ u64) {
663        let mut seek_table = Vec::with_capacity(num_entries);
664        let header_length = ChunkedArchiveHeader::header_length(num_entries) as u64;
665        const COMPRESSED_CHUNK_SIZE: u64 = 1024;
666        const DECOMPRESSED_CHUNK_SIZE: u64 = 2048;
667        for n in 0..(num_entries as u64) {
668            seek_table.push(SeekTableEntry {
669                compressed_offset: (header_length + (n * COMPRESSED_CHUNK_SIZE)).into(),
670                compressed_size: COMPRESSED_CHUNK_SIZE.into(),
671                decompressed_offset: (n * DECOMPRESSED_CHUNK_SIZE).into(),
672                decompressed_size: DECOMPRESSED_CHUNK_SIZE.into(),
673            });
674        }
675        let header = ChunkedArchiveHeader::new(&seek_table, options).unwrap();
676        let archive_length: u64 = header_length + (num_entries as u64 * COMPRESSED_CHUNK_SIZE);
677        (header, seek_table, archive_length)
678    }
679
680    #[test]
681    fn should_validate_self() {
682        let (header, seek_table, archive_length) =
683            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
684        let serialized_table = seek_table.as_slice().as_bytes();
685        assert!(header.decode_archive(serialized_table, archive_length).unwrap().is_some());
686    }
687
688    #[test]
689    fn should_validate_empty() {
690        let (header, _, archive_length) = generate_archive(0, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
691        assert!(header.decode_archive(&[], archive_length).unwrap().is_some());
692    }
693
694    #[test]
695    fn should_detect_bad_magic() {
696        let (header, seek_table, archive_length) =
697            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
698        let mut corrupt_magic = ChunkedArchiveHeader::CHUNKED_ARCHIVE_MAGIC;
699        corrupt_magic[0] = !corrupt_magic[0];
700        let bad_magic = ChunkedArchiveHeader { magic: corrupt_magic, ..header };
701        let serialized_table = seek_table.as_slice().as_bytes();
702        assert!(matches!(
703            bad_magic.decode_archive(serialized_table, archive_length).unwrap_err(),
704            ChunkedArchiveError::BadMagic
705        ));
706    }
707    #[test]
708    fn should_detect_wrong_version() {
709        let (header, seek_table, archive_length) =
710            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
711        let invalid_version = ChunkedArchiveHeader { version: u16::MAX.into(), ..header };
712        let serialized_table = seek_table.as_slice().as_bytes();
713        assert!(matches!(
714            invalid_version.decode_archive(serialized_table, archive_length).unwrap_err(),
715            ChunkedArchiveError::InvalidVersion
716        ));
717    }
718
719    #[test]
720    fn should_detect_corrupt_checksum() {
721        let (header, seek_table, archive_length) =
722            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
723        let corrupt_checksum =
724            ChunkedArchiveHeader { checksum: (!header.checksum.get()).into(), ..header };
725        let serialized_table = seek_table.as_slice().as_bytes();
726        assert!(matches!(
727            corrupt_checksum.decode_archive(serialized_table, archive_length).unwrap_err(),
728            ChunkedArchiveError::IntegrityError
729        ));
730    }
731
732    #[test]
733    fn should_reject_too_many_entries_v2() {
734        let (too_many_entries, seek_table, archive_length) = generate_archive(
735            ChunkedArchiveOptions::V2_MAX_CHUNKS + 1,
736            Type1Blob::CHUNKED_ARCHIVE_OPTIONS,
737        );
738
739        let serialized_table = seek_table.as_slice().as_bytes();
740        assert!(matches!(
741            too_many_entries.decode_archive(serialized_table, archive_length).unwrap_err(),
742            ChunkedArchiveError::IntegrityError
743        ));
744    }
745
746    #[test]
747    fn invariant_i0_first_entry_zero() {
748        let (header, mut seek_table, archive_length) =
749            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
750        assert_eq!(seek_table[0].decompressed_offset.get(), 0);
751        seek_table[0].decompressed_offset = 1.into();
752
753        let serialized_table = seek_table.as_slice().as_bytes();
754        assert!(matches!(
755            header.decode_archive(serialized_table, archive_length).unwrap_err(),
756            ChunkedArchiveError::IntegrityError
757        ));
758    }
759
760    #[test]
761    fn invariant_i1_no_header_overlap() {
762        let (header, mut seek_table, archive_length) =
763            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
764        let header_end = ChunkedArchiveHeader::header_length(seek_table.len()) as u64;
765        assert!(seek_table[0].compressed_offset.get() >= header_end);
766        seek_table[0].compressed_offset = (header_end - 1).into();
767        let serialized_table = seek_table.as_slice().as_bytes();
768        assert!(matches!(
769            header.decode_archive(serialized_table, archive_length).unwrap_err(),
770            ChunkedArchiveError::IntegrityError
771        ));
772    }
773
774    #[test]
775    fn invariant_i2_decompressed_monotonic() {
776        let (header, mut seek_table, archive_length) =
777            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
778        assert_eq!(
779            seek_table[0].decompressed_offset.get() + seek_table[0].decompressed_size.get(),
780            seek_table[1].decompressed_offset.get()
781        );
782        seek_table[1].decompressed_offset = (seek_table[1].decompressed_offset.get() - 1).into();
783        let serialized_table = seek_table.as_slice().as_bytes();
784        assert!(matches!(
785            header.decode_archive(serialized_table, archive_length).unwrap_err(),
786            ChunkedArchiveError::IntegrityError
787        ));
788    }
789
790    #[test]
791    fn invariant_i3_compressed_monotonic() {
792        let (header, mut seek_table, archive_length) =
793            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
794        assert!(
795            (seek_table[0].compressed_offset.get() + seek_table[0].compressed_size.get())
796                <= seek_table[1].compressed_offset.get()
797        );
798        seek_table[1].compressed_offset = (seek_table[1].compressed_offset.get() - 1).into();
799        let serialized_table = seek_table.as_slice().as_bytes();
800        assert!(matches!(
801            header.decode_archive(serialized_table, archive_length).unwrap_err(),
802            ChunkedArchiveError::IntegrityError
803        ));
804    }
805
806    #[test]
807    fn invariant_i4_nonzero_compressed_size() {
808        let (header, mut seek_table, archive_length) =
809            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
810        assert!(seek_table[0].compressed_size.get() > 0);
811        seek_table[0].compressed_size = 0.into();
812        let serialized_table = seek_table.as_slice().as_bytes();
813        assert!(matches!(
814            header.decode_archive(serialized_table, archive_length).unwrap_err(),
815            ChunkedArchiveError::IntegrityError
816        ));
817    }
818
819    #[test]
820    fn invariant_i4_nonzero_decompressed_size() {
821        let (header, mut seek_table, archive_length) =
822            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
823        assert!(seek_table[0].decompressed_size.get() > 0);
824        seek_table[0].decompressed_size = 0.into();
825        let serialized_table = seek_table.as_slice().as_bytes();
826        assert!(matches!(
827            header.decode_archive(serialized_table, archive_length).unwrap_err(),
828            ChunkedArchiveError::IntegrityError
829        ));
830    }
831
832    #[test]
833    fn invariant_i5_within_archive() {
834        let (header, mut seek_table, archive_length) =
835            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
836        let last_entry = seek_table.last_mut().unwrap();
837        assert!(
838            (last_entry.compressed_offset.get() + last_entry.compressed_size.get())
839                <= archive_length
840        );
841        last_entry.compressed_offset = (archive_length + 1).into();
842        let serialized_table = seek_table.as_slice().as_bytes();
843        assert!(matches!(
844            header.decode_archive(serialized_table, archive_length).unwrap_err(),
845            ChunkedArchiveError::IntegrityError
846        ));
847    }
848
    #[test]
    fn max_chunks() {
        // Verify that `chunk_size_for` caps the chunk count at V2_MAX_CHUNKS: once the
        // input no longer fits in that many minimum-sized chunks, the chunk size grows
        // (by one alignment step, per the second assertion) instead of the chunk count.
        let ChunkedArchiveOptions::V2 { minimum_chunk_size, chunk_alignment, .. } =
            Type1Blob::CHUNKED_ARCHIVE_OPTIONS
        else {
            panic!()
        };
        // Exactly V2_MAX_CHUNKS minimum-sized chunks: the minimum chunk size suffices.
        assert_eq!(
            Type1Blob::CHUNKED_ARCHIVE_OPTIONS
                .chunk_size_for(minimum_chunk_size * ChunkedArchiveOptions::V2_MAX_CHUNKS),
            minimum_chunk_size
        );
        // One byte more forces a larger chunk size rather than an extra chunk.
        assert_eq!(
            Type1Blob::CHUNKED_ARCHIVE_OPTIONS
                .chunk_size_for(minimum_chunk_size * ChunkedArchiveOptions::V2_MAX_CHUNKS + 1),
            minimum_chunk_size + chunk_alignment
        );
    }
867
868    #[test]
869    fn test_decompressor_empty_archive() {
870        let mut compressed: Vec<u8> = vec![];
871        ChunkedArchive::new(&[], Type1Blob::CHUNKED_ARCHIVE_OPTIONS)
872            .expect("compress")
873            .write(&mut compressed)
874            .expect("write archive");
875        let (decoded_archive, chunk_data) =
876            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
877        assert!(decoded_archive.seek_table.is_empty());
878        let mut decompressor = ChunkedDecompressor::new(decoded_archive).unwrap();
879        let mut chunk_callback = |_chunk: &[u8]| panic!("Archive doesn't have any chunks.");
880        // Stream data into the decompressor in small chunks to exhaust more edge cases.
881        chunk_data
882            .chunks(4)
883            .for_each(|data| decompressor.update(data, &mut chunk_callback).unwrap());
884    }
885
    #[test]
    fn test_decompressor() {
        // End-to-end roundtrip over random (incompressible) data: compress, decode the
        // archive, stream it through the decompressor, and verify every emitted chunk
        // against the corresponding slice of the original input.
        const UNCOMPRESSED_LENGTH: usize = 3_000_000;
        let data: Vec<u8> = {
            let range = rand::distr::Uniform::<u8>::new_inclusive(0, 255).unwrap();
            rand::rng().sample_iter(&range).take(UNCOMPRESSED_LENGTH).collect()
        };
        let mut compressed: Vec<u8> = vec![];
        ChunkedArchive::new(&data, Type1Blob::CHUNKED_ARCHIVE_OPTIONS)
            .expect("compress")
            .write(&mut compressed)
            .expect("write archive");
        let (decoded_archive, chunk_data) =
            decode_archive(&compressed, compressed.len()).unwrap().unwrap();

        // Make sure we have multiple chunks for this test.
        let num_chunks = decoded_archive.seek_table.len();
        assert!(num_chunks > 1);

        let mut decompressor = ChunkedDecompressor::new(decoded_archive).unwrap();

        // The callback tracks how far through `data` we are; each decompressed chunk
        // must match the original bytes at the current offset.
        let mut decoded_chunks: usize = 0;
        let mut decompressed_offset: usize = 0;
        let mut chunk_callback = |decompressed_chunk: &[u8]| {
            assert!(
                decompressed_chunk
                    == &data[decompressed_offset..decompressed_offset + decompressed_chunk.len()]
            );
            decompressed_offset += decompressed_chunk.len();
            decoded_chunks += 1;
        };

        // Stream data into the decompressor in small chunks to exhaust more edge cases.
        chunk_data
            .chunks(4)
            .for_each(|data| decompressor.update(data, &mut chunk_callback).unwrap());
        assert_eq!(decoded_chunks, num_chunks);
    }
924
925    #[test]
926    fn test_decompressor_corrupt_decompressed_size() {
927        let data = vec![0; 3_000_000];
928        let mut compressed: Vec<u8> = vec![];
929        ChunkedArchive::new(&data, Type1Blob::CHUNKED_ARCHIVE_OPTIONS)
930            .expect("compress")
931            .write(&mut compressed)
932            .expect("write archive");
933        let (mut decoded_archive, chunk_data) =
934            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
935
936        // Corrupt the decompressed size of the chunk.
937        decoded_archive.seek_table[0].decompressed_range =
938            decoded_archive.seek_table[0].decompressed_range.start
939                ..decoded_archive.seek_table[0].decompressed_range.end + 1;
940
941        let mut decompressor = ChunkedDecompressor::new(decoded_archive).unwrap();
942        assert!(matches!(
943            decompressor.update(&chunk_data, &mut |_chunk| {}),
944            Err(ChunkedArchiveError::IntegrityError)
945        ));
946    }
947
    #[test]
    fn test_decompressor_corrupt_compressed_size() {
        // Shrinking a chunk's compressed range by one byte must surface as a
        // DecompressionError for that chunk, and the registered error handler must be
        // invoked with the corrupted chunk's index, metadata, and compressed bytes.
        let data = vec![0; 3_000_000];
        let mut compressed: Vec<u8> = vec![];
        ChunkedArchive::new(&data, Type1Blob::CHUNKED_ARCHIVE_OPTIONS)
            .expect("compress")
            .write(&mut compressed)
            .expect("write archive");
        let (mut decoded_archive, chunk_data) =
            decode_archive(&compressed, compressed.len()).unwrap().unwrap();

        // Corrupt the compressed size of the chunk.
        decoded_archive.seek_table[0].compressed_range =
            decoded_archive.seek_table[0].compressed_range.start
                ..decoded_archive.seek_table[0].compressed_range.end - 1;
        let first_chunk_info = decoded_archive.seek_table[0].clone();
        // The handler should see chunk 0 with the (corrupted) metadata above, and the
        // chunk payload should match the truncated compressed range's length.
        let error_handler = move |chunk_index: usize, chunk_info: ChunkInfo, chunk_data: &[u8]| {
            assert_eq!(chunk_index, 0);
            assert_eq!(chunk_info, first_chunk_info);
            assert_eq!(chunk_data.len(), chunk_info.compressed_range.len());
        };

        let mut decompressor =
            ChunkedDecompressor::new_with_error_handler(decoded_archive, Box::new(error_handler))
                .unwrap();
        assert!(matches!(
            decompressor.update(&chunk_data, &mut |_chunk| {}),
            Err(ChunkedArchiveError::DecompressionError { index: 0, .. })
        ));
    }
978
979    #[test]
980    fn test_v3_zstd_roundtrip() {
981        let data = vec![0; 3_000_000];
982        let options =
983            ChunkedArchiveOptions::V3 { compression_algorithm: CompressionAlgorithm::Zstd };
984        let mut compressed = vec![];
985        ChunkedArchive::new(&data, options)
986            .expect("compress")
987            .write(&mut compressed)
988            .expect("write");
989
990        // Verify header.
991        let (header, _) =
992            Ref::<_, ChunkedArchiveHeader>::from_prefix(compressed.as_slice()).unwrap();
993        assert_eq!(header.version.get(), 3);
994        assert_eq!(header.compression_algorithm, CompressionAlgorithm::Zstd as u8);
995
996        let (decoded_archive, chunk_data) =
997            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
998
999        // Decompress.
1000        let mut decompressor = ChunkedDecompressor::new(decoded_archive).unwrap();
1001        let mut decompressed: Vec<u8> = vec![];
1002        let mut chunk_callback = |chunk: &[u8]| decompressed.extend_from_slice(chunk);
1003        decompressor.update(chunk_data, &mut chunk_callback).unwrap();
1004
1005        assert_eq!(decompressed, data);
1006    }
1007
1008    #[test]
1009    fn test_v3_lz4_roundtrip() {
1010        let data = vec![0; 3_000_000];
1011        let options =
1012            ChunkedArchiveOptions::V3 { compression_algorithm: CompressionAlgorithm::Lz4 };
1013        let mut compressed = vec![];
1014        ChunkedArchive::new(&data, options)
1015            .expect("compress")
1016            .write(&mut compressed)
1017            .expect("write");
1018
1019        // Verify header.
1020        let (header, _) =
1021            Ref::<_, ChunkedArchiveHeader>::from_prefix(compressed.as_slice()).unwrap();
1022        assert_eq!(header.version.get(), 3);
1023        assert_eq!(header.compression_algorithm, CompressionAlgorithm::Lz4 as u8);
1024
1025        let (decoded_archive, chunk_data) =
1026            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
1027
1028        // Decompress.
1029        let mut decompressor = ChunkedDecompressor::new(decoded_archive).unwrap();
1030        let mut decompressed: Vec<u8> = vec![];
1031        let mut chunk_callback = |chunk: &[u8]| decompressed.extend_from_slice(chunk);
1032        decompressor.update(chunk_data, &mut chunk_callback).unwrap();
1033
1034        assert_eq!(decompressed, data);
1035    }
1036}