zstd/stream/zio/writer.rs

1use std::io::{self, Write};
2
3use crate::stream::raw::{InBuffer, Operation, OutBuffer};
4
5// input -> [ zstd -> buffer -> writer ]
6
/// Implements the [`Write`] API around an [`Operation`].
///
/// This can be used to wrap a raw in-memory operation in a write-focused API.
///
/// It can be used with either compression or decompression, and forwards the
/// output to a wrapped `Write`.
///
/// Data written to this type is fed through the operation into an internal
/// buffer, which is then drained into the wrapped writer. Call
/// [`Writer::finish()`] once all input has been written; otherwise the output
/// may be incomplete.
pub struct Writer<W, D> {
    /// Destination for everything the operation produces.
    writer: W,
    /// The raw operation (compression or decompression) applied to written data.
    operation: D,

    /// Index of the first byte in `buffer` that has not yet been written to
    /// `writer`.
    offset: usize,
    /// Output produced by `operation`; bytes from `offset` onward are still
    /// pending and must be forwarded to `writer`.
    buffer: Vec<u8>,

    /// When `true`, indicates that nothing should be added to the buffer.
    /// All that's left is to empty the buffer.
    finished: bool,

    /// Set when the operation returned a hint of 0 while running; the next
    /// write then re-initializes the operation so that concatenated frames
    /// can be processed. Also passed to the operation's `finish`.
    finished_frame: bool,
}
26
27impl<W, D> Writer<W, D>
28where
29    W: Write,
30    D: Operation,
31{
32    /// Creates a new `Writer`.
33    ///
34    /// All output from the given operation will be forwarded to `writer`.
35    pub fn new(writer: W, operation: D) -> Self {
36        Writer {
37            writer,
38            operation,
39
40            offset: 0,
41            // 32KB buffer? That's what flate2 uses
42            buffer: Vec::with_capacity(32 * 1024),
43
44            finished: false,
45            finished_frame: false,
46        }
47    }
48
49    /// Ends the stream.
50    ///
51    /// This *must* be called after all data has been written to finish the
52    /// stream.
53    ///
54    /// If you forget to call this and just drop the `Writer`, you *will* have
55    /// an incomplete output.
56    ///
57    /// Keep calling it until it returns `Ok(())`, then don't call it again.
58    pub fn finish(&mut self) -> io::Result<()> {
59        loop {
60            // Keep trying until we're really done.
61            self.write_from_offset()?;
62
63            // At this point the buffer has been fully written out.
64
65            if self.finished {
66                return Ok(());
67            }
68
69            // Let's fill this buffer again!
70
71            let finished_frame = self.finished_frame;
72            let hint =
73                self.with_buffer(|dst, op| op.finish(dst, finished_frame));
74            self.offset = 0;
75            // println!("Hint: {:?}\nOut:{:?}", hint, &self.buffer);
76
77            // We return here if zstd had a problem.
78            // Could happen with invalid data, ...
79            let hint = hint?;
80
81            if hint != 0 && self.buffer.is_empty() {
82                // This happens if we are decoding an incomplete frame.
83                return Err(io::Error::new(
84                    io::ErrorKind::UnexpectedEof,
85                    "incomplete frame",
86                ));
87            }
88
89            // println!("Finishing {}, {}", bytes_written, hint);
90
91            self.finished = hint == 0;
92        }
93    }
94
95    /// Run the given closure on `self.buffer`.
96    ///
97    /// The buffer will be cleared, and made available wrapped in an `OutBuffer`.
98    fn with_buffer<F, T>(&mut self, f: F) -> T
99    where
100        F: FnOnce(&mut OutBuffer<'_, Vec<u8>>, &mut D) -> T,
101    {
102        self.buffer.clear();
103        let mut output = OutBuffer::around(&mut self.buffer);
104        // eprintln!("Output: {:?}", output);
105        f(&mut output, &mut self.operation)
106    }
107
108    /// Attempt to write `self.buffer` to the wrapped writer.
109    ///
110    /// Returns `Ok(())` once all the buffer has been written.
111    fn write_from_offset(&mut self) -> io::Result<()> {
112        // The code looks a lot like `write_all`, but keeps track of what has
113        // been written in case we're interrupted.
114        while self.offset < self.buffer.len() {
115            match self.writer.write(&self.buffer[self.offset..]) {
116                Ok(0) => {
117                    return Err(io::Error::new(
118                        io::ErrorKind::WriteZero,
119                        "writer will not accept any more data",
120                    ))
121                }
122                Ok(n) => self.offset += n,
123                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => (),
124                Err(e) => return Err(e),
125            }
126        }
127        Ok(())
128    }
129
130    /// Return the wrapped `Writer` and `Operation`.
131    ///
132    /// Careful: if you call this before calling [`Writer::finish()`], the
133    /// output may be incomplete.
134    pub fn into_inner(self) -> (W, D) {
135        (self.writer, self.operation)
136    }
137
138    /// Gives a reference to the inner writer.
139    pub fn writer(&self) -> &W {
140        &self.writer
141    }
142
143    /// Gives a mutable reference to the inner writer.
144    pub fn writer_mut(&mut self) -> &mut W {
145        &mut self.writer
146    }
147
148    /// Gives a reference to the inner operation.
149    pub fn operation(&self) -> &D {
150        &self.operation
151    }
152
153    /// Gives a mutable reference to the inner operation.
154    pub fn operation_mut(&mut self) -> &mut D {
155        &mut self.operation
156    }
157
158    /// Returns the offset in the current buffer. Only useful for debugging.
159    #[cfg(test)]
160    pub fn offset(&self) -> usize {
161        self.offset
162    }
163
164    /// Returns the current buffer. Only useful for debugging.
165    #[cfg(test)]
166    pub fn buffer(&self) -> &[u8] {
167        &self.buffer
168    }
169}
170
171impl<W, D> Write for Writer<W, D>
172where
173    W: Write,
174    D: Operation,
175{
176    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
177        // Keep trying until _something_ has been consumed.
178        // As soon as some input has been taken, we cannot afford
179        // to take any chance: if an error occurs, the user couldn't know
180        // that some data _was_ successfully written.
181        loop {
182            // First, write any pending data from `self.buffer`.
183            self.write_from_offset()?;
184            // At this point `self.buffer` can safely be discarded.
185
186            // Support writing concatenated frames by re-initializing the
187            // context.
188            if self.finished_frame {
189                self.operation.reinit()?;
190                self.finished_frame = false;
191            }
192
193            let mut src = InBuffer::around(buf);
194            let hint = self.with_buffer(|dst, op| op.run(&mut src, dst));
195            let bytes_read = src.pos;
196
197            // eprintln!(
198            //     "Write Hint: {:?}\n src: {:?}\n dst: {:?}",
199            //     hint, src, self.buffer
200            // );
201
202            self.offset = 0;
203            let hint = hint?;
204
205            if hint == 0 {
206                self.finished_frame = true;
207            }
208
209            // As we said, as soon as we've consumed something, return.
210            if bytes_read > 0 || buf.is_empty() {
211                // println!("Returning {}", bytes_read);
212                return Ok(bytes_read);
213            }
214        }
215    }
216
217    fn flush(&mut self) -> io::Result<()> {
218        let mut finished = self.finished;
219        loop {
220            // If the output is blocked or has an error, return now.
221            self.write_from_offset()?;
222
223            if finished {
224                return Ok(());
225            }
226
227            let hint = self.with_buffer(|dst, op| op.flush(dst));
228
229            self.offset = 0;
230            let hint = hint?;
231
232            finished = hint == 0;
233        }
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::Writer;
    use std::io::Write;

    #[test]
    fn test_noop() {
        use crate::stream::raw::NoOp;

        let input = b"AbcdefghAbcdefgh.";

        // Test writer
        let mut output = Vec::new();
        {
            let mut writer = Writer::new(&mut output, NoOp);
            writer.write_all(input).unwrap();
            writer.finish().unwrap();
        }
        assert_eq!(&output, input);
    }

    #[test]
    fn test_compress() {
        use crate::stream::raw::Encoder;

        let input = b"AbcdefghAbcdefgh.";

        // Test writer
        let mut output = Vec::new();
        {
            let mut writer =
                Writer::new(&mut output, Encoder::new(1).unwrap());
            writer.write_all(input).unwrap();
            writer.finish().unwrap();
        }
        let decoded = crate::decode_all(&output[..]).unwrap();
        assert_eq!(&decoded, input);
    }

    #[test]
    fn test_compress_empty() {
        use crate::stream::raw::Encoder;

        // Even with no input at all, `finish` must emit a complete frame
        // that decodes to nothing.
        let mut output = Vec::new();
        {
            let mut writer =
                Writer::new(&mut output, Encoder::new(1).unwrap());
            writer.finish().unwrap();
        }
        assert!(!output.is_empty());
        let decoded = crate::decode_all(&output[..]).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_decompress() {
        use crate::stream::raw::Decoder;

        let input = b"AbcdefghAbcdefgh.";
        let compressed = crate::encode_all(&input[..], 1).unwrap();

        // Test writer
        let mut output = Vec::new();
        {
            let mut writer = Writer::new(&mut output, Decoder::new().unwrap());
            writer.write_all(&compressed).unwrap();
            writer.finish().unwrap();
        }
        assert_eq!(&output, input);
    }

    #[test]
    fn test_decompress_concatenated_frames() {
        use crate::stream::raw::Decoder;

        let input = b"AbcdefghAbcdefgh.";
        let frame = crate::encode_all(&input[..], 1).unwrap();
        // Two complete frames back to back exercise the `reinit` path taken
        // by `write` once a frame has ended.
        let compressed = [&frame[..], &frame[..]].concat();

        let mut output = Vec::new();
        {
            let mut writer = Writer::new(&mut output, Decoder::new().unwrap());
            writer.write_all(&compressed).unwrap();
            writer.finish().unwrap();
        }
        assert_eq!(output, [&input[..], &input[..]].concat());
    }
}