// Copyright 2023 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Implementation of chunked-compression library in Rust. Archives can be created by making a new
//! [`ChunkedArchive`] and serializing/writing it. An archive's header can be verified and seek
//! table decoded using [`decode_archive`].
9use crc::Hasher32;
10use itertools::Itertools;
11use rayon::prelude::*;
12use std::ops::Range;
13use thiserror::Error;
14use zerocopy::byteorder::{LE, U16, U32, U64};
15use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned};
16
17mod compression_algorithm;
18pub use compression_algorithm::{
19    CompressionAlgorithm, Compressor, Decompressor, ThreadLocalCompressor, ThreadLocalDecompressor,
20};
21
/// Errors that can occur while creating or decoding a chunked archive.
// *NOTE*: Use caution when using the `#[source]` attribute or naming fields `source`. Some callers
// attempt to downcast library errors into the concrete type of the root cause.
// See https://docs.rs/thiserror/latest/thiserror/ for more information.
#[derive(Debug, Error)]
pub enum ChunkedArchiveError {
    #[error("Invalid or unsupported archive version.")]
    InvalidVersion,

    #[error("Archive header has incorrect magic.")]
    BadMagic,

    #[error("Integrity checks failed (e.g. incorrect CRC, inconsistent header fields).")]
    IntegrityError,

    #[error("Value is out of range or cannot be represented in specified type.")]
    OutOfRange,

    #[error("Error decompressing chunk {index}: `{error}`.")]
    DecompressionError { index: usize, error: std::io::Error },

    #[error("Error compressing chunk {index}: `{error}`.")]
    CompressionError { index: usize, error: std::io::Error },
}
47
/// Options for constructing a chunked archive.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ChunkedArchiveOptions {
    /// A chunked-compression V2 archive will be created.
    V2 {
        /// Chunked-compression V2 has a limit of 1023 chunks. If splitting the data up into
        /// `minimum_chunk_size`d chunks would exceed this limit then the chunk size is increased
        /// by `chunk_alignment` until fewer than 1024 are required. `minimum_chunk_size` must be
        /// a multiple of `chunk_alignment`.
        minimum_chunk_size: usize,
        /// The chosen uncompressed chunk size must always be a multiple of this value.
        chunk_alignment: usize,
        /// The Zstd compression level to use when compressing chunks.
        compression_level: i32,
    },
    /// A chunked-compression V3 archive will be created.
    V3 {
        /// The compression algorithm to use to compress the chunks.
        compression_algorithm: CompressionAlgorithm,
    },
}
69
70impl ChunkedArchiveOptions {
71    const V2_VERSION: u16 = 2;
72    const V2_MAX_CHUNKS: usize = 1023;
73
74    const V3_VERSION: u16 = 3;
75    const V3_MAX_CHUNKS: usize = u32::MAX as usize;
76    const V3_CHUNK_SIZE: usize = 32 * 1024;
77    const V3_ZSTD_COMPRESSION_LEVEL: i32 = 22;
78
79    /// Which version of chunked-compression archive should be constructed.
80    fn version(&self) -> u16 {
81        match self {
82            Self::V2 { .. } => Self::V2_VERSION,
83            Self::V3 { .. } => Self::V3_VERSION,
84        }
85    }
86
87    /// The compression algorithm to use to compress the chunks.
88    fn compression_algorithm(&self) -> CompressionAlgorithm {
89        match self {
90            Self::V2 { .. } => CompressionAlgorithm::Zstd,
91            Self::V3 { compression_algorithm } => *compression_algorithm,
92        }
93    }
94
95    /// Calculate how large chunks must be for a given amount of data.
96    fn chunk_size_for(&self, data_size: usize) -> usize {
97        match self {
98            Self::V2 { chunk_alignment, minimum_chunk_size: target_chunk_size, .. } => {
99                if data_size <= (Self::V2_MAX_CHUNKS * target_chunk_size) {
100                    *target_chunk_size
101                } else {
102                    let chunk_size = data_size.div_ceil(Self::V2_MAX_CHUNKS);
103                    chunk_size.checked_next_multiple_of(*chunk_alignment).unwrap()
104                }
105            }
106            Self::V3 { .. } => {
107                assert!(
108                    data_size.div_ceil(Self::V3_CHUNK_SIZE) <= Self::V3_MAX_CHUNKS,
109                    "Chunked-compression V3 only supports data up to ~140TB"
110                );
111                Self::V3_CHUNK_SIZE
112            }
113        }
114    }
115
116    /// Constructs a compressor to compress chunks based on the specified options.
117    pub fn compressor(&self) -> Compressor {
118        match self {
119            Self::V2 { compression_level, .. } => {
120                let mut compressor = zstd::bulk::Compressor::default();
121                compressor
122                    .set_parameter(zstd::zstd_safe::CParameter::CompressionLevel(
123                        *compression_level,
124                    ))
125                    .expect("setting the compression level should never fail");
126                Compressor::Zstd(compressor)
127            }
128            Self::V3 { compression_algorithm: CompressionAlgorithm::Zstd } => {
129                let mut compressor = zstd::bulk::Compressor::default();
130                compressor
131                    .set_parameter(zstd::zstd_safe::CParameter::CompressionLevel(
132                        Self::V3_ZSTD_COMPRESSION_LEVEL,
133                    ))
134                    .expect("setting the compression level should never fail");
135                Compressor::Zstd(compressor)
136            }
137            Self::V3 { compression_algorithm: CompressionAlgorithm::Lz4 } => {
138                Compressor::Lz4 { compression_level: lz4::HcCompressionLevel::custom(12) }
139            }
140        }
141    }
142
143    /// Constructs a compressor object that uses a thread local compressor to compress chunks based
144    /// on the specified options.
145    pub fn thread_local_compressor(&self) -> ThreadLocalCompressor {
146        match self {
147            Self::V2 { compression_level, .. } => {
148                ThreadLocalCompressor::Zstd { compression_level: *compression_level }
149            }
150            Self::V3 { compression_algorithm: CompressionAlgorithm::Zstd } => {
151                ThreadLocalCompressor::Zstd { compression_level: Self::V3_ZSTD_COMPRESSION_LEVEL }
152            }
153            Self::V3 { compression_algorithm: CompressionAlgorithm::Lz4 } => {
154                ThreadLocalCompressor::Lz4 {
155                    compression_level: lz4::HcCompressionLevel::custom(12),
156                }
157            }
158        }
159    }
160
161    /// Returns true if `version` is a valid chunked-compression version.
162    fn is_valid_version(version: u16) -> bool {
163        match version {
164            Self::V2_VERSION => true,
165            Self::V3_VERSION => true,
166            _ => false,
167        }
168    }
169
170    /// Returns the maximum number of chunks supported by the chunked-compression format at the
171    /// specified version.
172    fn max_chunks_for_version(version: u16) -> Result<usize, ChunkedArchiveError> {
173        match version {
174            Self::V2_VERSION => Ok(Self::V2_MAX_CHUNKS),
175            Self::V3_VERSION => Ok(Self::V3_MAX_CHUNKS),
176            _ => Err(ChunkedArchiveError::InvalidVersion),
177        }
178    }
179}
180
181/// Validated chunk information from an archive. Compressed ranges are relative to the start of
182/// compressed data (i.e. they start after the header and seek table).
183#[derive(Clone, Debug, Eq, PartialEq)]
184pub struct ChunkInfo {
185    pub decompressed_range: Range<usize>,
186    pub compressed_range: Range<usize>,
187}
188
189impl ChunkInfo {
190    fn from_entry(
191        entry: &SeekTableEntry,
192        header_length: usize,
193    ) -> Result<Self, ChunkedArchiveError> {
194        let decompressed_start = entry.decompressed_offset.get() as usize;
195        let decompressed_size = entry.decompressed_size.get() as usize;
196        let decompressed_range = decompressed_start
197            ..decompressed_start
198                .checked_add(decompressed_size)
199                .ok_or(ChunkedArchiveError::OutOfRange)?;
200
201        let compressed_offset = entry.compressed_offset.get() as usize;
202        let compressed_start = compressed_offset
203            .checked_sub(header_length)
204            .ok_or(ChunkedArchiveError::IntegrityError)?;
205        let compressed_size = entry.compressed_size.get() as usize;
206        let compressed_range = compressed_start
207            ..compressed_start
208                .checked_add(compressed_size)
209                .ok_or(ChunkedArchiveError::OutOfRange)?;
210
211        Ok(Self { decompressed_range, compressed_range })
212    }
213}
214
215/// Validated information from decoding an archive.
216#[derive(Debug)]
217pub struct DecodedArchive {
218    compression_algorithm: CompressionAlgorithm,
219    seek_table: Vec<ChunkInfo>,
220}
221
222impl DecodedArchive {
223    /// The total size of decompressing all of the chunks in the archive.
224    pub fn decompressed_size(&self) -> usize {
225        self.seek_table.last().map_or(0, |entry| entry.decompressed_range.end)
226    }
227}
228
/// Decodes a chunked archive header. Returns a `DecodedArchive` and any remaining bytes that are
/// part of the chunk data. Returns `Ok(None)` if `data` is not large enough to decode the archive
/// header & seek table.
pub fn decode_archive(
    data: &[u8],
    archive_length: usize,
) -> Result<Option<(DecodedArchive, /*archive_data*/ &[u8])>, ChunkedArchiveError> {
    // Split the fixed-size header off the front of `data`. The header type is `Unaligned`, so
    // the only way this split can fail is `data` being too short, which is reported as
    // `Ok(None)` rather than an error so the caller can retry with more data.
    match Ref::<_, ChunkedArchiveHeader>::from_prefix(data).map_err(Into::into) {
        Ok((header, data)) => header.decode_archive(data, archive_length as u64),
        Err(zerocopy::SizeError { .. }) => Ok(None), // Not enough data.
    }
}
241
/// Chunked archive header.
///
/// Layout matches the on-disk format: multi-byte fields are little-endian and the struct is
/// `repr(C)` + `Unaligned` so it can be read directly from a byte buffer via `zerocopy`.
#[derive(IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned, Clone, Copy, Debug)]
#[repr(C)]
struct ChunkedArchiveHeader {
    /// Must equal [`Self::CHUNKED_ARCHIVE_MAGIC`].
    magic: [u8; 8],
    /// Archive format version (2 or 3).
    version: U16<LE>,
    // This field was added in V3 and should not be used if `version` is 2. Technically, this field
    // should be 0 in V2, Zstd has the value 0, and V2 always uses Zstd so accessing this field in
    // V2 should give the correct result.
    compression_algorithm: u8,
    reserved_0: u8,
    /// Number of [`SeekTableEntry`] records that immediately follow the header.
    num_entries: U32<LE>,
    /// CRC32 over the header (excluding this field itself) and all seek table entries.
    checksum: U32<LE>,
    reserved_1: U32<LE>,
    reserved_2: U64<LE>,
}
258
/// Chunked archive seek table entry.
#[derive(IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned, Clone, Copy, Debug)]
#[repr(C)]
struct SeekTableEntry {
    /// Offset of this chunk's data within the decompressed output.
    decompressed_offset: U64<LE>,
    /// Size of this chunk once decompressed.
    decompressed_size: U64<LE>,
    /// Absolute offset of this chunk's compressed bytes within the archive (i.e. measured from
    /// the start of the header, not the start of the chunk data).
    compressed_offset: U64<LE>,
    /// Size of this chunk's compressed bytes.
    compressed_size: U64<LE>,
}
268
269impl ChunkedArchiveHeader {
270    const CHUNKED_ARCHIVE_MAGIC: [u8; 8] = [0x46, 0x9b, 0x78, 0xef, 0x0f, 0xd0, 0xb2, 0x03];
271    const CHUNKED_ARCHIVE_CHECKSUM_OFFSET: usize = 16;
272
273    fn new(
274        seek_table: &[SeekTableEntry],
275        options: ChunkedArchiveOptions,
276    ) -> Result<Self, ChunkedArchiveError> {
277        let header: ChunkedArchiveHeader = Self {
278            magic: Self::CHUNKED_ARCHIVE_MAGIC,
279            version: options.version().into(),
280            compression_algorithm: options.compression_algorithm().into(),
281            reserved_0: 0.into(),
282            num_entries: TryInto::<u32>::try_into(seek_table.len())
283                .or(Err(ChunkedArchiveError::OutOfRange))?
284                .into(),
285            checksum: 0.into(), // `checksum` is calculated below.
286            reserved_1: 0.into(),
287            reserved_2: 0.into(),
288        };
289        Ok(Self { checksum: header.checksum(seek_table).into(), ..header })
290    }
291
292    /// Calculate the checksum of the header + all seek table entries.
293    fn checksum(&self, entries: &[SeekTableEntry]) -> u32 {
294        let mut first_crc = crc::crc32::Digest::new(crc::crc32::IEEE);
295        first_crc.write(&self.as_bytes()[..Self::CHUNKED_ARCHIVE_CHECKSUM_OFFSET]);
296        let mut crc = crc::crc32::Digest::new_with_initial(crc::crc32::IEEE, first_crc.sum32());
297        crc.write(
298            &self.as_bytes()
299                [Self::CHUNKED_ARCHIVE_CHECKSUM_OFFSET + self.checksum.as_bytes().len()..],
300        );
301        crc.write(entries.as_bytes());
302        crc.sum32()
303    }
304
305    /// Calculate the total header length of an archive *including* all seek table entries.
306    fn header_length(num_entries: usize) -> usize {
307        std::mem::size_of::<ChunkedArchiveHeader>()
308            + (std::mem::size_of::<SeekTableEntry>() * num_entries)
309    }
310
311    /// Validates the archive header and decodes the seek table.
312    fn decode_archive(
313        self,
314        data: &[u8],
315        archive_length: u64,
316    ) -> Result<Option<(DecodedArchive, /*chunk_data*/ &[u8])>, ChunkedArchiveError> {
317        // Deserialize seek table.
318        let num_entries = self.num_entries.get() as usize;
319        let Ok((entries, chunk_data)) =
320            Ref::<_, [SeekTableEntry]>::from_prefix_with_elems(data, num_entries)
321        else {
322            return Ok(None);
323        };
324        let entries: &[SeekTableEntry] = Ref::into_ref(entries);
325
326        // Validate archive header.
327        if self.magic != Self::CHUNKED_ARCHIVE_MAGIC {
328            return Err(ChunkedArchiveError::BadMagic);
329        }
330        let version = self.version.get();
331        if !ChunkedArchiveOptions::is_valid_version(version) {
332            return Err(ChunkedArchiveError::InvalidVersion);
333        }
334        if self.checksum.get() != self.checksum(entries) {
335            return Err(ChunkedArchiveError::IntegrityError);
336        }
337        if entries.len() > ChunkedArchiveOptions::max_chunks_for_version(version)? {
338            return Err(ChunkedArchiveError::IntegrityError);
339        }
340        let compression_algorithm = CompressionAlgorithm::try_from(self.compression_algorithm)?;
341
342        // Validate seek table using invariants I0 through I5.
343
344        // I0: The first seek table entry, if any, must have decompressed offset 0.
345        if !entries.is_empty() && entries[0].decompressed_offset.get() != 0 {
346            return Err(ChunkedArchiveError::IntegrityError);
347        }
348
349        // I1: The compressed offsets of all seek table entries must not overlap with the header.
350        let header_length = Self::header_length(entries.len());
351        if entries.iter().any(|entry| entry.compressed_offset.get() < header_length as u64) {
352            return Err(ChunkedArchiveError::IntegrityError);
353        }
354
355        // I2: Each entry's decompressed offset must be equal to the end of the previous frame
356        //     (i.e. to the previous frame's decompressed offset + length).
357        for (prev, curr) in entries.iter().tuple_windows() {
358            if (prev.decompressed_offset.get() + prev.decompressed_size.get())
359                != curr.decompressed_offset.get()
360            {
361                return Err(ChunkedArchiveError::IntegrityError);
362            }
363        }
364
365        // I3: Each entry's compressed offset must be greater than or equal to the end of the
366        //     previous frame (i.e. to the previous frame's compressed offset + length).
367        for (prev, curr) in entries.iter().tuple_windows() {
368            if (prev.compressed_offset.get() + prev.compressed_size.get())
369                > curr.compressed_offset.get()
370            {
371                return Err(ChunkedArchiveError::IntegrityError);
372            }
373        }
374
375        // I4: Each entry must have a non-zero decompressed and compressed length.
376        for entry in entries.iter() {
377            if entry.decompressed_size.get() == 0 || entry.compressed_size.get() == 0 {
378                return Err(ChunkedArchiveError::IntegrityError);
379            }
380        }
381
382        // I5: Data referenced by each entry must fit within the specified file size.
383        for entry in entries.iter() {
384            let compressed_end = entry.compressed_offset.get() + entry.compressed_size.get();
385            if compressed_end > archive_length {
386                return Err(ChunkedArchiveError::IntegrityError);
387            }
388        }
389
390        let seek_table = entries
391            .into_iter()
392            .map(|entry| ChunkInfo::from_entry(entry, header_length))
393            .try_collect()?;
394        Ok(Some((DecodedArchive { seek_table, compression_algorithm }, chunk_data)))
395    }
396}
397
/// In-memory representation of a compressed chunk.
pub struct CompressedChunk {
    /// Compressed data for this chunk.
    pub compressed_data: Vec<u8>,
    /// Size of this chunk when decompressed.
    pub decompressed_size: usize,
}

/// In-memory representation of a compressed chunked archive.
pub struct ChunkedArchive {
    /// Chunks this archive contains, in order. Right now we only allow creating archives with
    /// contiguous compressed and decompressed space.
    chunks: Vec<CompressedChunk>,
    /// Size used to chunk input when creating this archive. Last chunk may be smaller than this
    /// amount.
    chunk_size: usize,
    /// The options used to construct this archive (determines version, algorithm, and the
    /// header written by `write`).
    options: ChunkedArchiveOptions,
}
417
418impl ChunkedArchive {
419    /// Create a ChunkedArchive for `data` compressing each chunk in parallel. This function uses
420    /// the `rayon` crate for parallelism. By default compression happens in the global thread pool,
421    /// but this function can also be executed within a locally scoped pool.
422    pub fn new(data: &[u8], options: ChunkedArchiveOptions) -> Result<Self, ChunkedArchiveError> {
423        let chunk_size = options.chunk_size_for(data.len());
424        let mut chunks: Vec<Result<CompressedChunk, ChunkedArchiveError>> = vec![];
425        let compressor = options.thread_local_compressor();
426        data.par_chunks(chunk_size)
427            .enumerate()
428            .map(|(index, chunk)| {
429                let compressed_data = compressor.compress(chunk, index)?;
430                Ok(CompressedChunk { compressed_data, decompressed_size: chunk.len() })
431            })
432            .collect_into_vec(&mut chunks);
433        let chunks: Vec<_> = chunks.into_iter().try_collect()?;
434        Ok(ChunkedArchive { chunks, chunk_size, options })
435    }
436
437    /// Accessor for compressed chunk data.
438    pub fn chunks(&self) -> &Vec<CompressedChunk> {
439        &self.chunks
440    }
441
442    /// The chunk size calculated for this archive during compression. Represents how input data
443    /// was chunked for compression. Note that the final chunk may be smaller than this amount
444    /// when decompressed.
445    pub fn chunk_size(&self) -> usize {
446        self.chunk_size
447    }
448
449    /// Sum of sizes of all compressed chunks.
450    pub fn compressed_data_size(&self) -> usize {
451        self.chunks.iter().map(|chunk| chunk.compressed_data.len()).sum()
452    }
453
454    /// Total size of the archive in bytes.
455    pub fn serialized_size(&self) -> usize {
456        ChunkedArchiveHeader::header_length(self.chunks.len()) + self.compressed_data_size()
457    }
458
459    /// Write the archive to `writer`.
460    pub fn write(self, mut writer: impl std::io::Write) -> Result<(), std::io::Error> {
461        let seek_table = self.make_seek_table();
462        let header = ChunkedArchiveHeader::new(&seek_table, self.options).unwrap();
463        writer.write_all(header.as_bytes())?;
464        writer.write_all(seek_table.as_slice().as_bytes())?;
465        for chunk in self.chunks {
466            writer.write_all(&chunk.compressed_data)?;
467        }
468        Ok(())
469    }
470
471    /// Create the seek table for this archive.
472    fn make_seek_table(&self) -> Vec<SeekTableEntry> {
473        let header_length = ChunkedArchiveHeader::header_length(self.chunks.len());
474        let mut seek_table = vec![];
475        seek_table.reserve(self.chunks.len());
476        let mut compressed_size: usize = 0;
477        let mut decompressed_offset: usize = 0;
478        for chunk in &self.chunks {
479            seek_table.push(SeekTableEntry {
480                decompressed_offset: (decompressed_offset as u64).into(),
481                decompressed_size: (chunk.decompressed_size as u64).into(),
482                compressed_offset: ((header_length + compressed_size) as u64).into(),
483                compressed_size: (chunk.compressed_data.len() as u64).into(),
484            });
485            compressed_size += chunk.compressed_data.len();
486            decompressed_offset += chunk.decompressed_size;
487        }
488        seek_table
489    }
490}
491
/// Streaming decompressor for chunked archives. Example:
/// ```
/// // Create a chunked archive:
/// let data: Vec<u8> = vec![3; 1024];
/// let archive = ChunkedArchive::new(&data, options)?;
/// let mut compressed: Vec<u8> = vec![];
/// archive.write(&mut compressed)?;
/// // Verify the header + decode the seek table:
/// let (decoded, archive_data) = decode_archive(&compressed, compressed.len())?.unwrap();
/// let mut decompressed: Vec<u8> = vec![];
/// let mut on_chunk = |data: &[u8]| { decompressed.extend_from_slice(data); };
/// let mut decompressor = ChunkedDecompressor::new(decoded)?;
/// // `on_chunk` is invoked as each chunk is decompressed. Archive can be provided in pieces.
/// decompressor.update(archive_data, &mut on_chunk)?;
/// assert_eq!(data.as_slice(), decompressed.as_slice());
/// ```
pub struct ChunkedDecompressor {
    /// Validated seek table decoded from the archive header.
    seek_table: Vec<ChunkInfo>,
    /// Carries partial chunk data between calls to `update`.
    buffer: Vec<u8>,
    /// Total number of compressed bytes passed to `update` so far.
    data_written: usize,
    /// Index into `seek_table` of the next chunk to decompress.
    curr_chunk: usize,
    /// End offset of the final chunk's compressed range; total compressed bytes expected.
    total_compressed_size: usize,
    /// Decompressor for the archive's compression algorithm.
    decompressor: Decompressor,
    /// Scratch buffer chunks are decompressed into; sized from the first chunk's decompressed
    /// length.
    decompressed_buffer: Vec<u8>,
    /// Optional callback invoked when a chunk fails to decompress.
    error_handler: Option<ErrorHandler>,
}

/// Callback invoked on decompression failure with (chunk index, chunk info, compressed bytes).
type ErrorHandler = Box<dyn Fn(usize, ChunkInfo, &[u8]) -> () + Send + 'static>;
518
impl ChunkedDecompressor {
    /// Create a new decompressor to decode an archive from a validated seek table.
    pub fn new(decoded_archive: DecodedArchive) -> Result<Self, ChunkedArchiveError> {
        let DecodedArchive { compression_algorithm, seek_table } = decoded_archive;
        // The end of the last chunk's compressed range bounds all compressed data in the
        // archive; `update` uses this to reject callers that provide too many bytes.
        let total_compressed_size =
            seek_table.last().map_or(0, |last_chunk| last_chunk.compressed_range.end);
        // Scratch space sized from the first chunk. NOTE(review): this assumes no later chunk
        // decompresses to more than the first one — true for archives produced by this library
        // (all chunks except the last share one size), but not enforced by decode-time
        // invariants; confirm for externally produced archives.
        let decompressed_buffer =
            vec![0u8; seek_table.first().map_or(0, |c| c.decompressed_range.len())];
        Ok(Self {
            seek_table,
            buffer: vec![],
            data_written: 0,
            curr_chunk: 0,
            total_compressed_size,
            decompressor: compression_algorithm.decompressor(),
            decompressed_buffer,
            error_handler: None,
        })
    }

    /// Creates a new decompressor with an additional error handler invoked when a chunk fails to be
    /// decompressed.
    pub fn new_with_error_handler(
        decoded_archive: DecodedArchive,
        error_handler: ErrorHandler,
    ) -> Result<Self, ChunkedArchiveError> {
        Ok(Self { error_handler: Some(error_handler), ..Self::new(decoded_archive)? })
    }

    /// The validated seek table this decompressor was created from.
    pub fn seek_table(&self) -> &Vec<ChunkInfo> {
        &self.seek_table
    }

    /// Decompresses one complete chunk (`data` must contain exactly the current chunk's
    /// compressed bytes), invokes `chunk_callback` with the decompressed bytes, and advances to
    /// the next chunk.
    fn finish_chunk(
        &mut self,
        data: &[u8],
        chunk_callback: &mut impl FnMut(&[u8]) -> (),
    ) -> Result<(), ChunkedArchiveError> {
        debug_assert_eq!(data.len(), self.seek_table[self.curr_chunk].compressed_range.len());
        let chunk = &self.seek_table[self.curr_chunk];
        let decompressed_size = self
            .decompressor
            .decompress_into(data, self.decompressed_buffer.as_mut_slice(), self.curr_chunk)
            .inspect_err(|_| {
                // Let the error handler observe the offending chunk before we propagate.
                if let Some(error_handler) = &self.error_handler {
                    error_handler(self.curr_chunk, chunk.clone(), data.as_bytes());
                }
            })?;
        // The chunk must decompress to exactly the size recorded in the seek table.
        if decompressed_size != chunk.decompressed_range.len() {
            return Err(ChunkedArchiveError::IntegrityError);
        }
        chunk_callback(&self.decompressed_buffer[..decompressed_size]);
        self.curr_chunk += 1;
        Ok(())
    }

    /// Update the decompressor with more data. `chunk_callback` is invoked once per chunk that
    /// becomes complete; incomplete trailing bytes are buffered until the next call.
    pub fn update(
        &mut self,
        mut data: &[u8],
        chunk_callback: &mut impl FnMut(&[u8]) -> (),
    ) -> Result<(), ChunkedArchiveError> {
        // Caller must not provide too much data.
        if self.data_written + data.len() > self.total_compressed_size {
            return Err(ChunkedArchiveError::OutOfRange);
        }
        self.data_written += data.len();

        // If we had leftover data from a previous read, append until we've filled a chunk.
        if !self.buffer.is_empty() {
            let to_read = std::cmp::min(
                data.len(),
                self.seek_table[self.curr_chunk]
                    .compressed_range
                    .len()
                    .checked_sub(self.buffer.len())
                    .unwrap(),
            );
            self.buffer.extend_from_slice(&data[..to_read]);
            if self.buffer.len() == self.seek_table[self.curr_chunk].compressed_range.len() {
                // Take self.buffer temporarily (so we don't have to split borrows).
                // That way we don't have to re-commit the pages we've already used in the buffer
                // for next time.
                let full_chunk = std::mem::take(&mut self.buffer);
                self.finish_chunk(&full_chunk[..], chunk_callback)?;
                self.buffer = full_chunk;
                // Draining the buffer will set the length to 0 but keep the capacity the same.
                self.buffer.drain(..);
            }
            data = &data[to_read..];
        }

        // Decode as many full chunks as we can.
        while !data.is_empty()
            && self.curr_chunk < self.seek_table.len()
            && self.seek_table[self.curr_chunk].compressed_range.len() <= data.len()
        {
            let len = self.seek_table[self.curr_chunk].compressed_range.len();
            self.finish_chunk(&data[..len], chunk_callback)?;
            data = &data[len..];
        }

        // Buffer the rest for the next call.
        if !data.is_empty() {
            debug_assert!(self.curr_chunk < self.seek_table.len());
            debug_assert!(self.data_written < self.total_compressed_size);
            self.buffer.extend_from_slice(data);
        }

        // Either more data is still expected, or every chunk has been decompressed.
        debug_assert!(
            self.data_written < self.total_compressed_size
                || self.curr_chunk == self.seek_table.len()
        );

        Ok(())
    }
}
636
637#[cfg(test)]
638mod tests {
639    use crate::Type1Blob;
640
641    use super::*;
642    use rand::Rng;
643    use std::matches;
644
    /// Create a compressed archive and ensure we can decode it as a valid archive that passes all
    /// required integrity checks.
    #[test]
    fn compress_simple() {
        // 512 KiB of zeroed (highly compressible) data.
        let data: Vec<u8> = vec![0; 32 * 1024 * 16];
        let archive = ChunkedArchive::new(&data, Type1Blob::CHUNKED_ARCHIVE_OPTIONS).unwrap();
        // This data is highly compressible, so the result should be smaller than the original.
        let mut compressed: Vec<u8> = vec![];
        archive.write(&mut compressed).unwrap();
        assert!(compressed.len() <= data.len());
        // We should be able to decode and verify the archive's integrity in-place.
        assert!(decode_archive(&compressed, compressed.len()).unwrap().is_some());
    }

    /// Generate a header + seek table for verifying invariants/integrity checks.
    fn generate_archive(
        num_entries: usize,
        options: ChunkedArchiveOptions,
    ) -> (ChunkedArchiveHeader, Vec<SeekTableEntry>, /*archive_length*/ u64) {
        let mut seek_table = Vec::with_capacity(num_entries);
        let header_length = ChunkedArchiveHeader::header_length(num_entries) as u64;
        const COMPRESSED_CHUNK_SIZE: u64 = 1024;
        const DECOMPRESSED_CHUNK_SIZE: u64 = 2048;
        // Entries are laid out back-to-back so the generated table satisfies invariants I0-I5.
        for n in 0..(num_entries as u64) {
            seek_table.push(SeekTableEntry {
                compressed_offset: (header_length + (n * COMPRESSED_CHUNK_SIZE)).into(),
                compressed_size: COMPRESSED_CHUNK_SIZE.into(),
                decompressed_offset: (n * DECOMPRESSED_CHUNK_SIZE).into(),
                decompressed_size: DECOMPRESSED_CHUNK_SIZE.into(),
            });
        }
        let header = ChunkedArchiveHeader::new(&seek_table, options).unwrap();
        let archive_length: u64 = header_length + (num_entries as u64 * COMPRESSED_CHUNK_SIZE);
        (header, seek_table, archive_length)
    }
680
    #[test]
    fn should_validate_self() {
        let (header, seek_table, archive_length) =
            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(header.decode_archive(serialized_table, archive_length).unwrap().is_some());
    }

    #[test]
    fn should_validate_empty() {
        // An archive with zero chunks is still valid.
        let (header, _, archive_length) = generate_archive(0, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
        assert!(header.decode_archive(&[], archive_length).unwrap().is_some());
    }

    #[test]
    fn should_detect_bad_magic() {
        let (header, seek_table, archive_length) =
            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
        // Invert the first magic byte to corrupt it.
        let mut corrupt_magic = ChunkedArchiveHeader::CHUNKED_ARCHIVE_MAGIC;
        corrupt_magic[0] = !corrupt_magic[0];
        let bad_magic = ChunkedArchiveHeader { magic: corrupt_magic, ..header };
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            bad_magic.decode_archive(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::BadMagic
        ));
    }
    #[test]
    fn should_detect_wrong_version() {
        let (header, seek_table, archive_length) =
            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
        // u16::MAX is not a defined chunked-compression version.
        let invalid_version = ChunkedArchiveHeader { version: u16::MAX.into(), ..header };
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            invalid_version.decode_archive(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::InvalidVersion
        ));
    }

    #[test]
    fn should_detect_corrupt_checksum() {
        let (header, seek_table, archive_length) =
            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
        let corrupt_checksum =
            ChunkedArchiveHeader { checksum: (!header.checksum.get()).into(), ..header };
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            corrupt_checksum.decode_archive(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::IntegrityError
        ));
    }

    #[test]
    fn should_reject_too_many_entries_v2() {
        // One more entry than V2's seek table limit of 1023.
        let (too_many_entries, seek_table, archive_length) = generate_archive(
            ChunkedArchiveOptions::V2_MAX_CHUNKS + 1,
            Type1Blob::CHUNKED_ARCHIVE_OPTIONS,
        );

        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            too_many_entries.decode_archive(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::IntegrityError
        ));
    }
746
747    #[test]
748    fn invariant_i0_first_entry_zero() {
749        let (header, mut seek_table, archive_length) =
750            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
751        assert_eq!(seek_table[0].decompressed_offset.get(), 0);
752        seek_table[0].decompressed_offset = 1.into();
753
754        let serialized_table = seek_table.as_slice().as_bytes();
755        assert!(matches!(
756            header.decode_archive(serialized_table, archive_length).unwrap_err(),
757            ChunkedArchiveError::IntegrityError
758        ));
759    }
760
761    #[test]
762    fn invariant_i1_no_header_overlap() {
763        let (header, mut seek_table, archive_length) =
764            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
765        let header_end = ChunkedArchiveHeader::header_length(seek_table.len()) as u64;
766        assert!(seek_table[0].compressed_offset.get() >= header_end);
767        seek_table[0].compressed_offset = (header_end - 1).into();
768        let serialized_table = seek_table.as_slice().as_bytes();
769        assert!(matches!(
770            header.decode_archive(serialized_table, archive_length).unwrap_err(),
771            ChunkedArchiveError::IntegrityError
772        ));
773    }
774
775    #[test]
776    fn invariant_i2_decompressed_monotonic() {
777        let (header, mut seek_table, archive_length) =
778            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
779        assert_eq!(
780            seek_table[0].decompressed_offset.get() + seek_table[0].decompressed_size.get(),
781            seek_table[1].decompressed_offset.get()
782        );
783        seek_table[1].decompressed_offset = (seek_table[1].decompressed_offset.get() - 1).into();
784        let serialized_table = seek_table.as_slice().as_bytes();
785        assert!(matches!(
786            header.decode_archive(serialized_table, archive_length).unwrap_err(),
787            ChunkedArchiveError::IntegrityError
788        ));
789    }
790
791    #[test]
792    fn invariant_i3_compressed_monotonic() {
793        let (header, mut seek_table, archive_length) =
794            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
795        assert!(
796            (seek_table[0].compressed_offset.get() + seek_table[0].compressed_size.get())
797                <= seek_table[1].compressed_offset.get()
798        );
799        seek_table[1].compressed_offset = (seek_table[1].compressed_offset.get() - 1).into();
800        let serialized_table = seek_table.as_slice().as_bytes();
801        assert!(matches!(
802            header.decode_archive(serialized_table, archive_length).unwrap_err(),
803            ChunkedArchiveError::IntegrityError
804        ));
805    }
806
807    #[test]
808    fn invariant_i4_nonzero_compressed_size() {
809        let (header, mut seek_table, archive_length) =
810            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
811        assert!(seek_table[0].compressed_size.get() > 0);
812        seek_table[0].compressed_size = 0.into();
813        let serialized_table = seek_table.as_slice().as_bytes();
814        assert!(matches!(
815            header.decode_archive(serialized_table, archive_length).unwrap_err(),
816            ChunkedArchiveError::IntegrityError
817        ));
818    }
819
820    #[test]
821    fn invariant_i4_nonzero_decompressed_size() {
822        let (header, mut seek_table, archive_length) =
823            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
824        assert!(seek_table[0].decompressed_size.get() > 0);
825        seek_table[0].decompressed_size = 0.into();
826        let serialized_table = seek_table.as_slice().as_bytes();
827        assert!(matches!(
828            header.decode_archive(serialized_table, archive_length).unwrap_err(),
829            ChunkedArchiveError::IntegrityError
830        ));
831    }
832
833    #[test]
834    fn invariant_i5_within_archive() {
835        let (header, mut seek_table, archive_length) =
836            generate_archive(4, Type1Blob::CHUNKED_ARCHIVE_OPTIONS);
837        let last_entry = seek_table.last_mut().unwrap();
838        assert!(
839            (last_entry.compressed_offset.get() + last_entry.compressed_size.get())
840                <= archive_length
841        );
842        last_entry.compressed_offset = (archive_length + 1).into();
843        let serialized_table = seek_table.as_slice().as_bytes();
844        assert!(matches!(
845            header.decode_archive(serialized_table, archive_length).unwrap_err(),
846            ChunkedArchiveError::IntegrityError
847        ));
848    }
849
850    #[test]
851    fn max_chunks() {
852        let ChunkedArchiveOptions::V2 { minimum_chunk_size, chunk_alignment, .. } =
853            Type1Blob::CHUNKED_ARCHIVE_OPTIONS
854        else {
855            panic!()
856        };
857        assert_eq!(
858            Type1Blob::CHUNKED_ARCHIVE_OPTIONS
859                .chunk_size_for(minimum_chunk_size * ChunkedArchiveOptions::V2_MAX_CHUNKS),
860            minimum_chunk_size
861        );
862        assert_eq!(
863            Type1Blob::CHUNKED_ARCHIVE_OPTIONS
864                .chunk_size_for(minimum_chunk_size * ChunkedArchiveOptions::V2_MAX_CHUNKS + 1),
865            minimum_chunk_size + chunk_alignment
866        );
867    }
868
869    #[test]
870    fn test_decompressor_empty_archive() {
871        let mut compressed: Vec<u8> = vec![];
872        ChunkedArchive::new(&[], Type1Blob::CHUNKED_ARCHIVE_OPTIONS)
873            .expect("compress")
874            .write(&mut compressed)
875            .expect("write archive");
876        let (decoded_archive, chunk_data) =
877            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
878        assert!(decoded_archive.seek_table.is_empty());
879        let mut decompressor = ChunkedDecompressor::new(decoded_archive).unwrap();
880        let mut chunk_callback = |_chunk: &[u8]| panic!("Archive doesn't have any chunks.");
881        // Stream data into the decompressor in small chunks to exhaust more edge cases.
882        chunk_data
883            .chunks(4)
884            .for_each(|data| decompressor.update(data, &mut chunk_callback).unwrap());
885    }
886
887    #[test]
888    fn test_decompressor() {
889        const UNCOMPRESSED_LENGTH: usize = 3_000_000;
890        let data: Vec<u8> = {
891            let range = rand::distr::Uniform::<u8>::new_inclusive(0, 255).unwrap();
892            rand::rng().sample_iter(&range).take(UNCOMPRESSED_LENGTH).collect()
893        };
894        let mut compressed: Vec<u8> = vec![];
895        ChunkedArchive::new(&data, Type1Blob::CHUNKED_ARCHIVE_OPTIONS)
896            .expect("compress")
897            .write(&mut compressed)
898            .expect("write archive");
899        let (decoded_archive, chunk_data) =
900            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
901
902        // Make sure we have multiple chunks for this test.
903        let num_chunks = decoded_archive.seek_table.len();
904        assert!(num_chunks > 1);
905
906        let mut decompressor = ChunkedDecompressor::new(decoded_archive).unwrap();
907
908        let mut decoded_chunks: usize = 0;
909        let mut decompressed_offset: usize = 0;
910        let mut chunk_callback = |decompressed_chunk: &[u8]| {
911            assert!(
912                decompressed_chunk
913                    == &data[decompressed_offset..decompressed_offset + decompressed_chunk.len()]
914            );
915            decompressed_offset += decompressed_chunk.len();
916            decoded_chunks += 1;
917        };
918
919        // Stream data into the decompressor in small chunks to exhaust more edge cases.
920        chunk_data
921            .chunks(4)
922            .for_each(|data| decompressor.update(data, &mut chunk_callback).unwrap());
923        assert_eq!(decoded_chunks, num_chunks);
924    }
925
926    #[test]
927    fn test_decompressor_corrupt_decompressed_size() {
928        let data = vec![0; 3_000_000];
929        let mut compressed: Vec<u8> = vec![];
930        ChunkedArchive::new(&data, Type1Blob::CHUNKED_ARCHIVE_OPTIONS)
931            .expect("compress")
932            .write(&mut compressed)
933            .expect("write archive");
934        let (mut decoded_archive, chunk_data) =
935            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
936
937        // Corrupt the decompressed size of the chunk.
938        decoded_archive.seek_table[0].decompressed_range =
939            decoded_archive.seek_table[0].decompressed_range.start
940                ..decoded_archive.seek_table[0].decompressed_range.end + 1;
941
942        let mut decompressor = ChunkedDecompressor::new(decoded_archive).unwrap();
943        assert!(matches!(
944            decompressor.update(&chunk_data, &mut |_chunk| {}),
945            Err(ChunkedArchiveError::IntegrityError)
946        ));
947    }
948
949    #[test]
950    fn test_decompressor_corrupt_compressed_size() {
951        let data = vec![0; 3_000_000];
952        let mut compressed: Vec<u8> = vec![];
953        ChunkedArchive::new(&data, Type1Blob::CHUNKED_ARCHIVE_OPTIONS)
954            .expect("compress")
955            .write(&mut compressed)
956            .expect("write archive");
957        let (mut decoded_archive, chunk_data) =
958            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
959
960        // Corrupt the compressed size of the chunk.
961        decoded_archive.seek_table[0].compressed_range =
962            decoded_archive.seek_table[0].compressed_range.start
963                ..decoded_archive.seek_table[0].compressed_range.end - 1;
964        let first_chunk_info = decoded_archive.seek_table[0].clone();
965        let error_handler = move |chunk_index: usize, chunk_info: ChunkInfo, chunk_data: &[u8]| {
966            assert_eq!(chunk_index, 0);
967            assert_eq!(chunk_info, first_chunk_info);
968            assert_eq!(chunk_data.len(), chunk_info.compressed_range.len());
969        };
970
971        let mut decompressor =
972            ChunkedDecompressor::new_with_error_handler(decoded_archive, Box::new(error_handler))
973                .unwrap();
974        assert!(matches!(
975            decompressor.update(&chunk_data, &mut |_chunk| {}),
976            Err(ChunkedArchiveError::DecompressionError { index: 0, .. })
977        ));
978    }
979
980    #[test]
981    fn test_v3_zstd_roundtrip() {
982        let data = vec![0; 3_000_000];
983        let options =
984            ChunkedArchiveOptions::V3 { compression_algorithm: CompressionAlgorithm::Zstd };
985        let mut compressed = vec![];
986        ChunkedArchive::new(&data, options)
987            .expect("compress")
988            .write(&mut compressed)
989            .expect("write");
990
991        // Verify header.
992        let (header, _) =
993            Ref::<_, ChunkedArchiveHeader>::from_prefix(compressed.as_slice()).unwrap();
994        assert_eq!(header.version.get(), 3);
995        assert_eq!(header.compression_algorithm, CompressionAlgorithm::Zstd as u8);
996
997        let (decoded_archive, chunk_data) =
998            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
999
1000        // Decompress.
1001        let mut decompressor = ChunkedDecompressor::new(decoded_archive).unwrap();
1002        let mut decompressed: Vec<u8> = vec![];
1003        let mut chunk_callback = |chunk: &[u8]| decompressed.extend_from_slice(chunk);
1004        decompressor.update(chunk_data, &mut chunk_callback).unwrap();
1005
1006        assert_eq!(decompressed, data);
1007    }
1008
1009    #[test]
1010    fn test_v3_lz4_roundtrip() {
1011        let data = vec![0; 3_000_000];
1012        let options =
1013            ChunkedArchiveOptions::V3 { compression_algorithm: CompressionAlgorithm::Lz4 };
1014        let mut compressed = vec![];
1015        ChunkedArchive::new(&data, options)
1016            .expect("compress")
1017            .write(&mut compressed)
1018            .expect("write");
1019
1020        // Verify header.
1021        let (header, _) =
1022            Ref::<_, ChunkedArchiveHeader>::from_prefix(compressed.as_slice()).unwrap();
1023        assert_eq!(header.version.get(), 3);
1024        assert_eq!(header.compression_algorithm, CompressionAlgorithm::Lz4 as u8);
1025
1026        let (decoded_archive, chunk_data) =
1027            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
1028
1029        // Decompress.
1030        let mut decompressor = ChunkedDecompressor::new(decoded_archive).unwrap();
1031        let mut decompressed: Vec<u8> = vec![];
1032        let mut chunk_callback = |chunk: &[u8]| decompressed.extend_from_slice(chunk);
1033        decompressor.update(chunk_data, &mut chunk_callback).unwrap();
1034
1035        assert_eq!(decompressed, data);
1036    }
1037}