zstd/stream/
raw.rs
1use std::io;
8
9pub use zstd_safe::{CParameter, DParameter, InBuffer, OutBuffer, WriteBuf};
10
11use crate::dict::{DecoderDictionary, EncoderDictionary};
12use crate::map_error_code;
13
14pub trait Operation {
18 fn run<C: WriteBuf + ?Sized>(
25 &mut self,
26 input: &mut InBuffer<'_>,
27 output: &mut OutBuffer<'_, C>,
28 ) -> io::Result<usize>;
29
30 fn run_on_buffers(
35 &mut self,
36 input: &[u8],
37 output: &mut [u8],
38 ) -> io::Result<Status> {
39 let mut input = InBuffer::around(input);
40 let mut output = OutBuffer::around(output);
41
42 let remaining = self.run(&mut input, &mut output)?;
43
44 Ok(Status {
45 remaining,
46 bytes_read: input.pos(),
47 bytes_written: output.pos(),
48 })
49 }
50
51 fn flush<C: WriteBuf + ?Sized>(
56 &mut self,
57 output: &mut OutBuffer<'_, C>,
58 ) -> io::Result<usize> {
59 let _ = output;
60 Ok(0)
61 }
62
63 fn reinit(&mut self) -> io::Result<()> {
67 Ok(())
68 }
69
70 fn finish<C: WriteBuf + ?Sized>(
77 &mut self,
78 output: &mut OutBuffer<'_, C>,
79 finished_frame: bool,
80 ) -> io::Result<usize> {
81 let _ = output;
82 let _ = finished_frame;
83 Ok(0)
84 }
85}
86
87pub struct NoOp;
89
90impl Operation for NoOp {
91 fn run<C: WriteBuf + ?Sized>(
92 &mut self,
93 input: &mut InBuffer<'_>,
94 output: &mut OutBuffer<'_, C>,
95 ) -> io::Result<usize> {
96 let src = &input.src[input.pos..];
98 let dst = unsafe { output.dst.as_mut_ptr().add(output.pos()) };
100
101 let len = usize::min(src.len(), output.dst.capacity());
103 let src = &src[..len];
104
105 unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst, len) };
109 input.set_pos(input.pos() + len);
110 unsafe { output.set_pos(output.pos() + len) };
111
112 Ok(0)
113 }
114}
115
/// Outcome of a single [`Operation::run_on_buffers`] call.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Status {
    /// The value returned by the underlying [`Operation::run`] call.
    pub remaining: usize,

    /// Number of bytes consumed from the input slice.
    pub bytes_read: usize,

    /// Number of bytes written into the output slice.
    pub bytes_written: usize,
}
129
130pub struct Decoder<'a> {
132 context: zstd_safe::DCtx<'a>,
133}
134
135impl Decoder<'static> {
136 pub fn new() -> io::Result<Self> {
138 Self::with_dictionary(&[])
139 }
140
141 pub fn with_dictionary(dictionary: &[u8]) -> io::Result<Self> {
143 let mut context = zstd_safe::DCtx::create();
144 context.init();
145 context
146 .load_dictionary(dictionary)
147 .map_err(map_error_code)?;
148 Ok(Decoder { context })
149 }
150}
151
152impl<'a> Decoder<'a> {
153 pub fn with_prepared_dictionary<'b>(
155 dictionary: &DecoderDictionary<'b>,
156 ) -> io::Result<Self>
157 where
158 'b: 'a,
159 {
160 let mut context = zstd_safe::DCtx::create();
161 context
162 .ref_ddict(dictionary.as_ddict())
163 .map_err(map_error_code)?;
164 Ok(Decoder { context })
165 }
166
167 pub fn set_parameter(&mut self, parameter: DParameter) -> io::Result<()> {
169 self.context
170 .set_parameter(parameter)
171 .map_err(map_error_code)?;
172 Ok(())
173 }
174}
175
176impl Operation for Decoder<'_> {
177 fn run<C: WriteBuf + ?Sized>(
178 &mut self,
179 input: &mut InBuffer<'_>,
180 output: &mut OutBuffer<'_, C>,
181 ) -> io::Result<usize> {
182 self.context
183 .decompress_stream(output, input)
184 .map_err(map_error_code)
185 }
186
187 fn reinit(&mut self) -> io::Result<()> {
188 self.context.reset().map_err(map_error_code)?;
189 Ok(())
190 }
191
192 fn finish<C: WriteBuf + ?Sized>(
193 &mut self,
194 _output: &mut OutBuffer<'_, C>,
195 finished_frame: bool,
196 ) -> io::Result<usize> {
197 if finished_frame {
198 Ok(0)
199 } else {
200 Err(io::Error::new(
201 io::ErrorKind::UnexpectedEof,
202 "incomplete frame",
203 ))
204 }
205 }
206}
207
208pub struct Encoder<'a> {
210 context: zstd_safe::CCtx<'a>,
211}
212
213impl Encoder<'static> {
214 pub fn new(level: i32) -> io::Result<Self> {
216 Self::with_dictionary(level, &[])
217 }
218
219 pub fn with_dictionary(level: i32, dictionary: &[u8]) -> io::Result<Self> {
221 let mut context = zstd_safe::CCtx::create();
222
223 context
224 .set_parameter(CParameter::CompressionLevel(level))
225 .map_err(map_error_code)?;
226
227 context
228 .load_dictionary(dictionary)
229 .map_err(map_error_code)?;
230
231 Ok(Encoder { context })
232 }
233}
234
235impl<'a> Encoder<'a> {
236 pub fn with_prepared_dictionary<'b>(
238 dictionary: &EncoderDictionary<'b>,
239 ) -> io::Result<Self>
240 where
241 'b: 'a,
242 {
243 let mut context = zstd_safe::CCtx::create();
244 context
245 .ref_cdict(dictionary.as_cdict())
246 .map_err(map_error_code)?;
247 Ok(Encoder { context })
248 }
249
250 pub fn set_parameter(&mut self, parameter: CParameter) -> io::Result<()> {
252 self.context
253 .set_parameter(parameter)
254 .map_err(map_error_code)?;
255 Ok(())
256 }
257
258 pub fn set_pledged_src_size(
265 &mut self,
266 pledged_src_size: u64,
267 ) -> io::Result<()> {
268 self.context
269 .set_pledged_src_size(pledged_src_size)
270 .map_err(map_error_code)?;
271 Ok(())
272 }
273}
274
275impl<'a> Operation for Encoder<'a> {
276 fn run<C: WriteBuf + ?Sized>(
277 &mut self,
278 input: &mut InBuffer<'_>,
279 output: &mut OutBuffer<'_, C>,
280 ) -> io::Result<usize> {
281 self.context
282 .compress_stream(output, input)
283 .map_err(map_error_code)
284 }
285
286 fn flush<C: WriteBuf + ?Sized>(
287 &mut self,
288 output: &mut OutBuffer<'_, C>,
289 ) -> io::Result<usize> {
290 self.context.flush_stream(output).map_err(map_error_code)
291 }
292
293 fn finish<C: WriteBuf + ?Sized>(
294 &mut self,
295 output: &mut OutBuffer<'_, C>,
296 _finished_frame: bool,
297 ) -> io::Result<usize> {
298 self.context.end_stream(output).map_err(map_error_code)
299 }
300
301 fn reinit(&mut self) -> io::Result<()> {
302 self.context
303 .reset(zstd_safe::ResetDirective::ZSTD_reset_session_only)
304 .map_err(map_error_code)?;
305 Ok(())
306 }
307}
308
#[cfg(test)]
mod tests {

    // Fixed-size arrays are used as output buffers, which needs the `arrays` feature.
    #[cfg(feature = "arrays")]
    #[test]
    fn test_cycle() {
        use super::{Decoder, Encoder, InBuffer, Operation, OutBuffer};

        let mut encoder = Encoder::new(1).unwrap();
        let mut decoder = Decoder::new().unwrap();

        let mut input = InBuffer::around(b"AbcdefAbcdefabcdef");

        let mut output = [0u8; 128];
        let mut output = OutBuffer::around(&mut output);

        loop {
            encoder.run(&mut input, &mut output).unwrap();

            if input.pos == input.src.len() {
                break;
            }
        }

        // `finish` returns how many bytes are still pending, so keep calling it
        // until it reports 0 rather than assuming one call writes the whole
        // frame epilogue. Bounded so a stall fails the test instead of hanging.
        assert!(
            (0..16).any(|_| encoder.finish(&mut output, true).unwrap() == 0),
            "encoder did not finish flushing the frame"
        );

        let initial_data = input.src;

        let mut input = InBuffer::around(output.as_slice());
        let mut output = [0u8; 128];
        let mut output = OutBuffer::around(&mut output);

        let mut hint = decoder.run(&mut input, &mut output).unwrap();
        while input.pos < input.src.len() {
            hint = decoder.run(&mut input, &mut output).unwrap();
        }
        // Per the `ZSTD_decompressStream` contract, 0 means the frame was fully
        // decoded and flushed; anything else means the frame is incomplete.
        assert_eq!(hint, 0, "compressed frame was not fully decoded");

        assert_eq!(initial_data, output.as_slice());
    }
}