flate2/gz/
write.rs

1use std::cmp;
2use std::io;
3use std::io::prelude::*;
4
5#[cfg(feature = "tokio")]
6use futures::Poll;
7#[cfg(feature = "tokio")]
8use tokio_io::{AsyncRead, AsyncWrite};
9
10use super::bufread::{corrupt, read_gz_header};
11use super::{GzBuilder, GzHeader};
12use crate::crc::{Crc, CrcWriter};
13use crate::zio;
14use crate::{Compress, Compression, Decompress, Status};
15
16/// A gzip streaming encoder
17///
18/// This structure exposes a [`Write`] interface that will emit compressed data
19/// to the underlying writer `W`.
20///
21/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html
22///
23/// # Examples
24///
25/// ```
26/// use std::io::prelude::*;
27/// use flate2::Compression;
28/// use flate2::write::GzEncoder;
29///
30/// // Vec<u8> implements Write to print the compressed bytes of sample string
31/// # fn main() {
32///
33/// let mut e = GzEncoder::new(Vec::new(), Compression::default());
34/// e.write_all(b"Hello World").unwrap();
35/// println!("{:?}", e.finish().unwrap());
36/// # }
37/// ```
38#[derive(Debug)]
39pub struct GzEncoder<W: Write> {
40    inner: zio::Writer<W, Compress>,
41    crc: Crc,
42    crc_bytes_written: usize,
43    header: Vec<u8>,
44}
45
46pub fn gz_encoder<W: Write>(header: Vec<u8>, w: W, lvl: Compression) -> GzEncoder<W> {
47    GzEncoder {
48        inner: zio::Writer::new(w, Compress::new(lvl, false)),
49        crc: Crc::new(),
50        header: header,
51        crc_bytes_written: 0,
52    }
53}
54
55impl<W: Write> GzEncoder<W> {
56    /// Creates a new encoder which will use the given compression level.
57    ///
58    /// The encoder is not configured specially for the emitted header. For
59    /// header configuration, see the `GzBuilder` type.
60    ///
61    /// The data written to the returned encoder will be compressed and then
62    /// written to the stream `w`.
63    pub fn new(w: W, level: Compression) -> GzEncoder<W> {
64        GzBuilder::new().write(w, level)
65    }
66
67    /// Acquires a reference to the underlying writer.
68    pub fn get_ref(&self) -> &W {
69        self.inner.get_ref()
70    }
71
72    /// Acquires a mutable reference to the underlying writer.
73    ///
74    /// Note that mutation of the writer may result in surprising results if
75    /// this encoder is continued to be used.
76    pub fn get_mut(&mut self) -> &mut W {
77        self.inner.get_mut()
78    }
79
80    /// Attempt to finish this output stream, writing out final chunks of data.
81    ///
82    /// Note that this function can only be used once data has finished being
83    /// written to the output stream. After this function is called then further
84    /// calls to `write` may result in a panic.
85    ///
86    /// # Panics
87    ///
88    /// Attempts to write data to this stream may result in a panic after this
89    /// function is called.
90    ///
91    /// # Errors
92    ///
93    /// This function will perform I/O to complete this stream, and any I/O
94    /// errors which occur will be returned from this function.
95    pub fn try_finish(&mut self) -> io::Result<()> {
96        self.write_header()?;
97        self.inner.finish()?;
98
99        while self.crc_bytes_written < 8 {
100            let (sum, amt) = (self.crc.sum() as u32, self.crc.amount());
101            let buf = [
102                (sum >> 0) as u8,
103                (sum >> 8) as u8,
104                (sum >> 16) as u8,
105                (sum >> 24) as u8,
106                (amt >> 0) as u8,
107                (amt >> 8) as u8,
108                (amt >> 16) as u8,
109                (amt >> 24) as u8,
110            ];
111            let inner = self.inner.get_mut();
112            let n = inner.write(&buf[self.crc_bytes_written..])?;
113            self.crc_bytes_written += n;
114        }
115        Ok(())
116    }
117
118    /// Finish encoding this stream, returning the underlying writer once the
119    /// encoding is done.
120    ///
121    /// Note that this function may not be suitable to call in a situation where
122    /// the underlying stream is an asynchronous I/O stream. To finish a stream
123    /// the `try_finish` (or `shutdown`) method should be used instead. To
124    /// re-acquire ownership of a stream it is safe to call this method after
125    /// `try_finish` or `shutdown` has returned `Ok`.
126    ///
127    /// # Errors
128    ///
129    /// This function will perform I/O to complete this stream, and any I/O
130    /// errors which occur will be returned from this function.
131    pub fn finish(mut self) -> io::Result<W> {
132        self.try_finish()?;
133        Ok(self.inner.take_inner())
134    }
135
136    fn write_header(&mut self) -> io::Result<()> {
137        while self.header.len() > 0 {
138            let n = self.inner.get_mut().write(&self.header)?;
139            self.header.drain(..n);
140        }
141        Ok(())
142    }
143}
144
145impl<W: Write> Write for GzEncoder<W> {
146    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
147        assert_eq!(self.crc_bytes_written, 0);
148        self.write_header()?;
149        let n = self.inner.write(buf)?;
150        self.crc.update(&buf[..n]);
151        Ok(n)
152    }
153
154    fn flush(&mut self) -> io::Result<()> {
155        assert_eq!(self.crc_bytes_written, 0);
156        self.write_header()?;
157        self.inner.flush()
158    }
159}
160
#[cfg(feature = "tokio")]
impl<W> AsyncWrite for GzEncoder<W>
where
    W: AsyncWrite,
{
    /// Finishes the gzip stream and then shuts down the underlying writer.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        self.try_finish()?;
        let sink = self.get_mut();
        sink.shutdown()
    }
}
168
169impl<R: Read + Write> Read for GzEncoder<R> {
170    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
171        self.get_mut().read(buf)
172    }
173}
174
// Marker impl: uses the trait's default methods on top of the `Read` impl.
#[cfg(feature = "tokio")]
impl<R> AsyncRead for GzEncoder<R> where R: AsyncRead + AsyncWrite {}
177
178impl<W: Write> Drop for GzEncoder<W> {
179    fn drop(&mut self) {
180        if self.inner.is_present() {
181            let _ = self.try_finish();
182        }
183    }
184}
185
186/// A gzip streaming decoder
187///
188/// This structure exposes a [`Write`] interface that will emit compressed data
189/// to the underlying writer `W`.
190///
191/// [`Write`]: https://doc.rust-lang.org/std/io/trait.Write.html
192///
193/// # Examples
194///
195/// ```
196/// use std::io::prelude::*;
197/// use std::io;
198/// use flate2::Compression;
199/// use flate2::write::{GzEncoder, GzDecoder};
200///
201/// # fn main() {
202/// #    let mut e = GzEncoder::new(Vec::new(), Compression::default());
203/// #    e.write(b"Hello World").unwrap();
204/// #    let bytes = e.finish().unwrap();
205/// #    assert_eq!("Hello World", decode_writer(bytes).unwrap());
206/// # }
207/// // Uncompresses a gzip encoded vector of bytes and returns a string or error
208/// // Here Vec<u8> implements Write
209/// fn decode_writer(bytes: Vec<u8>) -> io::Result<String> {
210///    let mut writer = Vec::new();
211///    let mut decoder = GzDecoder::new(writer);
212///    decoder.write_all(&bytes[..])?;
213///    writer = decoder.finish()?;
214///    let return_string = String::from_utf8(writer).expect("String parsing error");
215///    Ok(return_string)
216/// }
217/// ```
218#[derive(Debug)]
219pub struct GzDecoder<W: Write> {
220    inner: zio::Writer<CrcWriter<W>, Decompress>,
221    crc_bytes: Vec<u8>,
222    header: Option<GzHeader>,
223    header_buf: Vec<u8>,
224}
225
226const CRC_BYTES_LEN: usize = 8;
227
228impl<W: Write> GzDecoder<W> {
229    /// Creates a new decoder which will write uncompressed data to the stream.
230    ///
231    /// When this encoder is dropped or unwrapped the final pieces of data will
232    /// be flushed.
233    pub fn new(w: W) -> GzDecoder<W> {
234        GzDecoder {
235            inner: zio::Writer::new(CrcWriter::new(w), Decompress::new(false)),
236            crc_bytes: Vec::with_capacity(CRC_BYTES_LEN),
237            header: None,
238            header_buf: Vec::new(),
239        }
240    }
241
242    /// Returns the header associated with this stream.
243    pub fn header(&self) -> Option<&GzHeader> {
244        self.header.as_ref()
245    }
246
247    /// Acquires a reference to the underlying writer.
248    pub fn get_ref(&self) -> &W {
249        self.inner.get_ref().get_ref()
250    }
251
252    /// Acquires a mutable reference to the underlying writer.
253    ///
254    /// Note that mutating the output/input state of the stream may corrupt this
255    /// object, so care must be taken when using this method.
256    pub fn get_mut(&mut self) -> &mut W {
257        self.inner.get_mut().get_mut()
258    }
259
260    /// Attempt to finish this output stream, writing out final chunks of data.
261    ///
262    /// Note that this function can only be used once data has finished being
263    /// written to the output stream. After this function is called then further
264    /// calls to `write` may result in a panic.
265    ///
266    /// # Panics
267    ///
268    /// Attempts to write data to this stream may result in a panic after this
269    /// function is called.
270    ///
271    /// # Errors
272    ///
273    /// This function will perform I/O to finish the stream, returning any
274    /// errors which happen.
275    pub fn try_finish(&mut self) -> io::Result<()> {
276        self.finish_and_check_crc()?;
277        Ok(())
278    }
279
280    /// Consumes this decoder, flushing the output stream.
281    ///
282    /// This will flush the underlying data stream and then return the contained
283    /// writer if the flush succeeded.
284    ///
285    /// Note that this function may not be suitable to call in a situation where
286    /// the underlying stream is an asynchronous I/O stream. To finish a stream
287    /// the `try_finish` (or `shutdown`) method should be used instead. To
288    /// re-acquire ownership of a stream it is safe to call this method after
289    /// `try_finish` or `shutdown` has returned `Ok`.
290    ///
291    /// # Errors
292    ///
293    /// This function will perform I/O to complete this stream, and any I/O
294    /// errors which occur will be returned from this function.
295    pub fn finish(mut self) -> io::Result<W> {
296        self.finish_and_check_crc()?;
297        Ok(self.inner.take_inner().into_inner())
298    }
299
300    fn finish_and_check_crc(&mut self) -> io::Result<()> {
301        self.inner.finish()?;
302
303        if self.crc_bytes.len() != 8 {
304            return Err(corrupt());
305        }
306
307        let crc = ((self.crc_bytes[0] as u32) << 0)
308            | ((self.crc_bytes[1] as u32) << 8)
309            | ((self.crc_bytes[2] as u32) << 16)
310            | ((self.crc_bytes[3] as u32) << 24);
311        let amt = ((self.crc_bytes[4] as u32) << 0)
312            | ((self.crc_bytes[5] as u32) << 8)
313            | ((self.crc_bytes[6] as u32) << 16)
314            | ((self.crc_bytes[7] as u32) << 24);
315        if crc != self.inner.get_ref().crc().sum() as u32 {
316            return Err(corrupt());
317        }
318        if amt != self.inner.get_ref().crc().amount() {
319            return Err(corrupt());
320        }
321        Ok(())
322    }
323}
324
/// Reader adapter that keeps a running total of the bytes read through it.
struct Counter<T: Read> {
    // Wrapped reader.
    inner: T,
    // Total number of bytes the wrapped reader has returned so far.
    pos: usize,
}

impl<T: Read> Read for Counter<T> {
    /// Delegates to the wrapped reader and adds the byte count of every
    /// successful read to `pos`; failed reads leave `pos` untouched.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.inner.read(buf) {
            Ok(count) => {
                self.pos += count;
                Ok(count)
            }
            Err(err) => Err(err),
        }
    }
}
337
338impl<W: Write> Write for GzDecoder<W> {
339    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
340        if self.header.is_none() {
341            // trying to avoid buffer usage
342            let (res, pos) = {
343                let mut counter = Counter {
344                    inner: self.header_buf.chain(buf),
345                    pos: 0,
346                };
347                let res = read_gz_header(&mut counter);
348                (res, counter.pos)
349            };
350
351            match res {
352                Err(err) => {
353                    if err.kind() == io::ErrorKind::UnexpectedEof {
354                        // not enough data for header, save to the buffer
355                        self.header_buf.extend(buf);
356                        Ok(buf.len())
357                    } else {
358                        Err(err)
359                    }
360                }
361                Ok(header) => {
362                    self.header = Some(header);
363                    let pos = pos - self.header_buf.len();
364                    self.header_buf.truncate(0);
365                    Ok(pos)
366                }
367            }
368        } else {
369            let (n, status) = self.inner.write_with_status(buf)?;
370
371            if status == Status::StreamEnd {
372                if n < buf.len() && self.crc_bytes.len() < 8 {
373                    let remaining = buf.len() - n;
374                    let crc_bytes = cmp::min(remaining, CRC_BYTES_LEN - self.crc_bytes.len());
375                    self.crc_bytes.extend(&buf[n..n + crc_bytes]);
376                    return Ok(n + crc_bytes);
377                }
378            }
379            Ok(n)
380        }
381    }
382
383    fn flush(&mut self) -> io::Result<()> {
384        self.inner.flush()
385    }
386}
387
#[cfg(feature = "tokio")]
impl<W> AsyncWrite for GzDecoder<W>
where
    W: AsyncWrite,
{
    /// Finishes decoding (including the trailer check) and then shuts down
    /// the underlying writer.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        self.try_finish()?;
        let crc_writer = self.inner.get_mut();
        crc_writer.get_mut().shutdown()
    }
}
395
396impl<W: Read + Write> Read for GzDecoder<W> {
397    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
398        self.inner.get_mut().get_mut().read(buf)
399    }
400}
401
// Marker impl: uses the trait's default methods on top of the `Read` impl.
#[cfg(feature = "tokio")]
impl<W> AsyncRead for GzDecoder<W> where W: AsyncRead + AsyncWrite {}
404
#[cfg(test)]
mod tests {
    use super::*;

    const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
                       Hello World Hello World Hello World Hello World Hello World \
                       Hello World Hello World Hello World Hello World Hello World \
                       Hello World Hello World Hello World Hello World Hello World \
                       Hello World Hello World Hello World Hello World Hello World";

    /// Gzip-compresses `STR` at the default compression level.
    fn compressed_sample() -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write(STR.as_ref()).unwrap();
        encoder.finish().unwrap()
    }

    /// Finishes `decoder` and checks that it produced exactly `STR`.
    fn assert_decodes_to_sample(decoder: GzDecoder<Vec<u8>>) {
        let output = decoder.finish().unwrap();
        let text = String::from_utf8(output).expect("String parsing error");
        assert_eq!(text, STR);
    }

    // The whole stream is offered at once; a second write passes whatever the
    // first call did not accept.
    #[test]
    fn decode_writer_one_chunk() {
        let bytes = compressed_sample();

        let mut decoder = GzDecoder::new(Vec::new());
        let accepted = decoder.write(&bytes[..]).unwrap();
        decoder.write(&bytes[accepted..]).unwrap();
        decoder.try_finish().unwrap();
        assert_decodes_to_sample(decoder);
    }

    // The first write ends in the middle of the header.
    #[test]
    fn decode_writer_partial_header() {
        let bytes = compressed_sample();

        let mut decoder = GzDecoder::new(Vec::new());
        assert_eq!(decoder.write(&bytes[..5]).unwrap(), 5);
        let accepted = decoder.write(&bytes[5..]).unwrap();
        if accepted < bytes.len() - 5 {
            decoder.write(&bytes[accepted + 5..]).unwrap();
        }
        assert_decodes_to_sample(decoder);
    }

    // The first write contains exactly the 10-byte header.
    #[test]
    fn decode_writer_exact_header() {
        let bytes = compressed_sample();

        let mut decoder = GzDecoder::new(Vec::new());
        assert_eq!(decoder.write(&bytes[..10]).unwrap(), 10);
        decoder.write(&bytes[10..]).unwrap();
        assert_decodes_to_sample(decoder);
    }

    // The first write stops five bytes before the end, splitting the trailer.
    #[test]
    fn decode_writer_partial_crc() {
        let bytes = compressed_sample();

        let mut decoder = GzDecoder::new(Vec::new());
        let split = bytes.len() - 5;
        let accepted = decoder.write(&bytes[..split]).unwrap();
        decoder.write(&bytes[accepted..]).unwrap();
        assert_decodes_to_sample(decoder);
    }
}
476        let return_string = String::from_utf8(writer).expect("String parsing error");
477        assert_eq!(return_string, STR);
478    }
479}