zstd/stream/
raw.rs

1//! Raw in-memory stream compression/decompression.
2//!
3//! This module defines a `Decoder` and an `Encoder` to decode/encode streams
4//! of data using buffers.
5//!
6//! They are mostly thin wrappers around `zstd_safe::{DCtx, CCtx}`.
7use std::io;
8
9pub use zstd_safe::{CParameter, DParameter, InBuffer, OutBuffer, WriteBuf};
10
11use crate::dict::{DecoderDictionary, EncoderDictionary};
12use crate::map_error_code;
13
14/// Represents an abstract compression/decompression operation.
15///
16/// This trait covers both `Encoder` and `Decoder`.
17pub trait Operation {
18    /// Performs a single step of this operation.
19    ///
20    /// Should return a hint for the next input size.
21    ///
22    /// If the result is `Ok(0)`, it may indicate that a frame was just
23    /// finished.
24    fn run<C: WriteBuf + ?Sized>(
25        &mut self,
26        input: &mut InBuffer<'_>,
27        output: &mut OutBuffer<'_, C>,
28    ) -> io::Result<usize>;
29
30    /// Performs a single step of this operation.
31    ///
32    /// This is a comvenience wrapper around `Operation::run` if you don't
33    /// want to deal with `InBuffer`/`OutBuffer`.
34    fn run_on_buffers(
35        &mut self,
36        input: &[u8],
37        output: &mut [u8],
38    ) -> io::Result<Status> {
39        let mut input = InBuffer::around(input);
40        let mut output = OutBuffer::around(output);
41
42        let remaining = self.run(&mut input, &mut output)?;
43
44        Ok(Status {
45            remaining,
46            bytes_read: input.pos(),
47            bytes_written: output.pos(),
48        })
49    }
50
51    /// Flushes any internal buffer, if any.
52    ///
53    /// Returns the number of bytes still in the buffer.
54    /// To flush entirely, keep calling until it returns `Ok(0)`.
55    fn flush<C: WriteBuf + ?Sized>(
56        &mut self,
57        output: &mut OutBuffer<'_, C>,
58    ) -> io::Result<usize> {
59        let _ = output;
60        Ok(0)
61    }
62
63    /// Prepares the operation for a new frame.
64    ///
65    /// This is hopefully cheaper than creating a new operation.
66    fn reinit(&mut self) -> io::Result<()> {
67        Ok(())
68    }
69
70    /// Finishes the operation, writing any footer if necessary.
71    ///
72    /// Returns the number of bytes still to write.
73    ///
74    /// Keep calling this method until it returns `Ok(0)`,
75    /// and then don't ever call this method.
76    fn finish<C: WriteBuf + ?Sized>(
77        &mut self,
78        output: &mut OutBuffer<'_, C>,
79        finished_frame: bool,
80    ) -> io::Result<usize> {
81        let _ = output;
82        let _ = finished_frame;
83        Ok(0)
84    }
85}
86
87/// Dummy operation that just copies its input to the output.
88pub struct NoOp;
89
90impl Operation for NoOp {
91    fn run<C: WriteBuf + ?Sized>(
92        &mut self,
93        input: &mut InBuffer<'_>,
94        output: &mut OutBuffer<'_, C>,
95    ) -> io::Result<usize> {
96        // Skip the prelude
97        let src = &input.src[input.pos..];
98        // Safe because `output.pos() <= output.dst.capacity()`.
99        let dst = unsafe { output.dst.as_mut_ptr().add(output.pos()) };
100
101        // Ignore anything past the end
102        let len = usize::min(src.len(), output.dst.capacity());
103        let src = &src[..len];
104
105        // Safe because:
106        // * `len` is less than either of the two lengths
107        // * `src` and `dst` do not overlap because we have `&mut` to each.
108        unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst, len) };
109        input.set_pos(input.pos() + len);
110        unsafe { output.set_pos(output.pos() + len) };
111
112        Ok(0)
113    }
114}
115
/// Describes the result of an operation.
///
/// Produced by [`Operation::run_on_buffers`]. It is a plain bundle of counters,
/// so it is cheap to copy, compare and print.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Status {
    /// Number of bytes expected for next input.
    ///
    /// This is just a hint.
    pub remaining: usize,

    /// Number of bytes read from the input.
    pub bytes_read: usize,

    /// Number of bytes written to the output.
    pub bytes_written: usize,
}
129
130/// An in-memory decoder for streams of data.
131pub struct Decoder<'a> {
132    context: zstd_safe::DCtx<'a>,
133}
134
135impl Decoder<'static> {
136    /// Creates a new decoder.
137    pub fn new() -> io::Result<Self> {
138        Self::with_dictionary(&[])
139    }
140
141    /// Creates a new decoder initialized with the given dictionary.
142    pub fn with_dictionary(dictionary: &[u8]) -> io::Result<Self> {
143        let mut context = zstd_safe::DCtx::create();
144        context.init();
145        context
146            .load_dictionary(dictionary)
147            .map_err(map_error_code)?;
148        Ok(Decoder { context })
149    }
150}
151
152impl<'a> Decoder<'a> {
153    /// Creates a new decoder, using an existing `DecoderDictionary`.
154    pub fn with_prepared_dictionary<'b>(
155        dictionary: &DecoderDictionary<'b>,
156    ) -> io::Result<Self>
157    where
158        'b: 'a,
159    {
160        let mut context = zstd_safe::DCtx::create();
161        context
162            .ref_ddict(dictionary.as_ddict())
163            .map_err(map_error_code)?;
164        Ok(Decoder { context })
165    }
166
167    /// Sets a decompression parameter for this decoder.
168    pub fn set_parameter(&mut self, parameter: DParameter) -> io::Result<()> {
169        self.context
170            .set_parameter(parameter)
171            .map_err(map_error_code)?;
172        Ok(())
173    }
174}
175
176impl Operation for Decoder<'_> {
177    fn run<C: WriteBuf + ?Sized>(
178        &mut self,
179        input: &mut InBuffer<'_>,
180        output: &mut OutBuffer<'_, C>,
181    ) -> io::Result<usize> {
182        self.context
183            .decompress_stream(output, input)
184            .map_err(map_error_code)
185    }
186
187    fn reinit(&mut self) -> io::Result<()> {
188        self.context.reset().map_err(map_error_code)?;
189        Ok(())
190    }
191
192    fn finish<C: WriteBuf + ?Sized>(
193        &mut self,
194        _output: &mut OutBuffer<'_, C>,
195        finished_frame: bool,
196    ) -> io::Result<usize> {
197        if finished_frame {
198            Ok(0)
199        } else {
200            Err(io::Error::new(
201                io::ErrorKind::UnexpectedEof,
202                "incomplete frame",
203            ))
204        }
205    }
206}
207
208/// An in-memory encoder for streams of data.
209pub struct Encoder<'a> {
210    context: zstd_safe::CCtx<'a>,
211}
212
213impl Encoder<'static> {
214    /// Creates a new encoder.
215    pub fn new(level: i32) -> io::Result<Self> {
216        Self::with_dictionary(level, &[])
217    }
218
219    /// Creates a new encoder initialized with the given dictionary.
220    pub fn with_dictionary(level: i32, dictionary: &[u8]) -> io::Result<Self> {
221        let mut context = zstd_safe::CCtx::create();
222
223        context
224            .set_parameter(CParameter::CompressionLevel(level))
225            .map_err(map_error_code)?;
226
227        context
228            .load_dictionary(dictionary)
229            .map_err(map_error_code)?;
230
231        Ok(Encoder { context })
232    }
233}
234
235impl<'a> Encoder<'a> {
236    /// Creates a new encoder using an existing `EncoderDictionary`.
237    pub fn with_prepared_dictionary<'b>(
238        dictionary: &EncoderDictionary<'b>,
239    ) -> io::Result<Self>
240    where
241        'b: 'a,
242    {
243        let mut context = zstd_safe::CCtx::create();
244        context
245            .ref_cdict(dictionary.as_cdict())
246            .map_err(map_error_code)?;
247        Ok(Encoder { context })
248    }
249
250    /// Sets a compression parameter for this encoder.
251    pub fn set_parameter(&mut self, parameter: CParameter) -> io::Result<()> {
252        self.context
253            .set_parameter(parameter)
254            .map_err(map_error_code)?;
255        Ok(())
256    }
257
258    /// Sets the size of the input expected by zstd.
259    ///
260    /// May affect compression ratio.
261    ///
262    /// It is an error to give an incorrect size (an error _will_ be returned when closing the
263    /// stream).
264    pub fn set_pledged_src_size(
265        &mut self,
266        pledged_src_size: u64,
267    ) -> io::Result<()> {
268        self.context
269            .set_pledged_src_size(pledged_src_size)
270            .map_err(map_error_code)?;
271        Ok(())
272    }
273}
274
275impl<'a> Operation for Encoder<'a> {
276    fn run<C: WriteBuf + ?Sized>(
277        &mut self,
278        input: &mut InBuffer<'_>,
279        output: &mut OutBuffer<'_, C>,
280    ) -> io::Result<usize> {
281        self.context
282            .compress_stream(output, input)
283            .map_err(map_error_code)
284    }
285
286    fn flush<C: WriteBuf + ?Sized>(
287        &mut self,
288        output: &mut OutBuffer<'_, C>,
289    ) -> io::Result<usize> {
290        self.context.flush_stream(output).map_err(map_error_code)
291    }
292
293    fn finish<C: WriteBuf + ?Sized>(
294        &mut self,
295        output: &mut OutBuffer<'_, C>,
296        _finished_frame: bool,
297    ) -> io::Result<usize> {
298        self.context.end_stream(output).map_err(map_error_code)
299    }
300
301    fn reinit(&mut self) -> io::Result<()> {
302        self.context
303            .reset(zstd_safe::ResetDirective::ZSTD_reset_session_only)
304            .map_err(map_error_code)?;
305        Ok(())
306    }
307}
308
#[cfg(test)]
mod tests {

    // This requires impl for [u8; N] which is currently behind a feature.
    #[cfg(feature = "arrays")]
    #[test]
    fn test_cycle() {
        use super::{Decoder, Encoder, InBuffer, Operation, OutBuffer};

        let mut encoder = Encoder::new(1).unwrap();
        let mut decoder = Decoder::new().unwrap();

        // Step 1: compress
        let mut input = InBuffer::around(b"AbcdefAbcdefabcdef");

        let mut output = [0u8; 128];
        let mut output = OutBuffer::around(&mut output);

        loop {
            encoder.run(&mut input, &mut output).unwrap();

            if input.pos == input.src.len() {
                break;
            }
        }

        // `finish` returns the number of frame bytes still pending. A non-zero
        // value means the compressed frame in `output` is truncated, so fail
        // here with a clear message rather than later in decompression.
        let pending = encoder.finish(&mut output, true).unwrap();
        assert_eq!(pending, 0, "encoder could not write the complete frame");

        let initial_data = input.src;

        // Step 2: decompress
        let mut input = InBuffer::around(output.as_slice());
        let mut output = [0u8; 128];
        let mut output = OutBuffer::around(&mut output);

        loop {
            decoder.run(&mut input, &mut output).unwrap();

            if input.pos == input.src.len() {
                break;
            }
        }

        assert_eq!(initial_data, output.as_slice());
    }
}