hyper/proto/h1/encode.rs

1use std::fmt;
2use std::io::IoSlice;
3
4use bytes::buf::{Chain, Take};
5use bytes::Buf;
6use tracing::trace;
7
8use super::io::WriteBuf;
9
10type StaticBuf = &'static [u8];
11
12/// Encoders to handle different Transfer-Encodings.
13#[derive(Debug, Clone, PartialEq)]
14pub(crate) struct Encoder {
15    kind: Kind,
16    is_last: bool,
17}
18
19#[derive(Debug)]
20pub(crate) struct EncodedBuf<B> {
21    kind: BufKind<B>,
22}
23
24#[derive(Debug)]
25pub(crate) struct NotEof(u64);
26
/// The framing strategy an `Encoder` applies to a message body.
#[derive(Debug, PartialEq, Clone)]
enum Kind {
    /// An Encoder for when Transfer-Encoding includes `chunked`.
    Chunked,
    /// An Encoder for when Content-Length is set.
    ///
    /// Enforces that the body is not longer than the Content-Length header.
    /// The payload is the number of body bytes that may still be written;
    /// it is decremented as data is encoded.
    Length(u64),
    /// An Encoder for when neither Content-Length nor Chunked encoding is set.
    ///
    /// This is mostly only used with HTTP/1.0 without a length. This kind
    /// requires the connection to be closed when the body is finished.
    #[cfg(feature = "server")]
    CloseDelimited,
}
42
43#[derive(Debug)]
44enum BufKind<B> {
45    Exact(B),
46    Limited(Take<B>),
47    Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
48    ChunkedEnd(StaticBuf),
49}
50
51impl Encoder {
52    fn new(kind: Kind) -> Encoder {
53        Encoder {
54            kind,
55            is_last: false,
56        }
57    }
58    pub(crate) fn chunked() -> Encoder {
59        Encoder::new(Kind::Chunked)
60    }
61
62    pub(crate) fn length(len: u64) -> Encoder {
63        Encoder::new(Kind::Length(len))
64    }
65
66    #[cfg(feature = "server")]
67    pub(crate) fn close_delimited() -> Encoder {
68        Encoder::new(Kind::CloseDelimited)
69    }
70
71    pub(crate) fn is_eof(&self) -> bool {
72        matches!(self.kind, Kind::Length(0))
73    }
74
75    #[cfg(feature = "server")]
76    pub(crate) fn set_last(mut self, is_last: bool) -> Self {
77        self.is_last = is_last;
78        self
79    }
80
81    pub(crate) fn is_last(&self) -> bool {
82        self.is_last
83    }
84
85    pub(crate) fn is_close_delimited(&self) -> bool {
86        match self.kind {
87            #[cfg(feature = "server")]
88            Kind::CloseDelimited => true,
89            _ => false,
90        }
91    }
92
93    pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
94        match self.kind {
95            Kind::Length(0) => Ok(None),
96            Kind::Chunked => Ok(Some(EncodedBuf {
97                kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
98            })),
99            #[cfg(feature = "server")]
100            Kind::CloseDelimited => Ok(None),
101            Kind::Length(n) => Err(NotEof(n)),
102        }
103    }
104
105    pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B>
106    where
107        B: Buf,
108    {
109        let len = msg.remaining();
110        debug_assert!(len > 0, "encode() called with empty buf");
111
112        let kind = match self.kind {
113            Kind::Chunked => {
114                trace!("encoding chunked {}B", len);
115                let buf = ChunkSize::new(len)
116                    .chain(msg)
117                    .chain(b"\r\n" as &'static [u8]);
118                BufKind::Chunked(buf)
119            }
120            Kind::Length(ref mut remaining) => {
121                trace!("sized write, len = {}", len);
122                if len as u64 > *remaining {
123                    let limit = *remaining as usize;
124                    *remaining = 0;
125                    BufKind::Limited(msg.take(limit))
126                } else {
127                    *remaining -= len as u64;
128                    BufKind::Exact(msg)
129                }
130            }
131            #[cfg(feature = "server")]
132            Kind::CloseDelimited => {
133                trace!("close delimited write {}B", len);
134                BufKind::Exact(msg)
135            }
136        };
137        EncodedBuf { kind }
138    }
139
140    pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
141    where
142        B: Buf,
143    {
144        let len = msg.remaining();
145        debug_assert!(len > 0, "encode() called with empty buf");
146
147        match self.kind {
148            Kind::Chunked => {
149                trace!("encoding chunked {}B", len);
150                let buf = ChunkSize::new(len)
151                    .chain(msg)
152                    .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
153                dst.buffer(buf);
154                !self.is_last
155            }
156            Kind::Length(remaining) => {
157                use std::cmp::Ordering;
158
159                trace!("sized write, len = {}", len);
160                match (len as u64).cmp(&remaining) {
161                    Ordering::Equal => {
162                        dst.buffer(msg);
163                        !self.is_last
164                    }
165                    Ordering::Greater => {
166                        dst.buffer(msg.take(remaining as usize));
167                        !self.is_last
168                    }
169                    Ordering::Less => {
170                        dst.buffer(msg);
171                        false
172                    }
173                }
174            }
175            #[cfg(feature = "server")]
176            Kind::CloseDelimited => {
177                trace!("close delimited write {}B", len);
178                dst.buffer(msg);
179                false
180            }
181        }
182    }
183
184    /// Encodes the full body, without verifying the remaining length matches.
185    ///
186    /// This is used in conjunction with HttpBody::__hyper_full_data(), which
187    /// means we can trust that the buf has the correct size (the buf itself
188    /// was checked to make the headers).
189    pub(super) fn danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>)
190    where
191        B: Buf,
192    {
193        debug_assert!(msg.remaining() > 0, "encode() called with empty buf");
194        debug_assert!(
195            match self.kind {
196                Kind::Length(len) => len == msg.remaining() as u64,
197                _ => true,
198            },
199            "danger_full_buf length mismatches"
200        );
201
202        match self.kind {
203            Kind::Chunked => {
204                let len = msg.remaining();
205                trace!("encoding chunked {}B", len);
206                let buf = ChunkSize::new(len)
207                    .chain(msg)
208                    .chain(b"\r\n0\r\n\r\n" as &'static [u8]);
209                dst.buffer(buf);
210            }
211            _ => {
212                dst.buffer(msg);
213            }
214        }
215    }
216}
217
218impl<B> Buf for EncodedBuf<B>
219where
220    B: Buf,
221{
222    #[inline]
223    fn remaining(&self) -> usize {
224        match self.kind {
225            BufKind::Exact(ref b) => b.remaining(),
226            BufKind::Limited(ref b) => b.remaining(),
227            BufKind::Chunked(ref b) => b.remaining(),
228            BufKind::ChunkedEnd(ref b) => b.remaining(),
229        }
230    }
231
232    #[inline]
233    fn chunk(&self) -> &[u8] {
234        match self.kind {
235            BufKind::Exact(ref b) => b.chunk(),
236            BufKind::Limited(ref b) => b.chunk(),
237            BufKind::Chunked(ref b) => b.chunk(),
238            BufKind::ChunkedEnd(ref b) => b.chunk(),
239        }
240    }
241
242    #[inline]
243    fn advance(&mut self, cnt: usize) {
244        match self.kind {
245            BufKind::Exact(ref mut b) => b.advance(cnt),
246            BufKind::Limited(ref mut b) => b.advance(cnt),
247            BufKind::Chunked(ref mut b) => b.advance(cnt),
248            BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
249        }
250    }
251
252    #[inline]
253    fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
254        match self.kind {
255            BufKind::Exact(ref b) => b.chunks_vectored(dst),
256            BufKind::Limited(ref b) => b.chunks_vectored(dst),
257            BufKind::Chunked(ref b) => b.chunks_vectored(dst),
258            BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
259        }
260    }
261}
262
/// Number of bytes in a `usize` on the target platform.
///
/// Derived from the type itself rather than from one `cfg` per
/// `target_pointer_width`. The explicit 32/64-bit list left every other
/// pointer width (e.g. 16-bit targets) without a definition, which failed
/// to compile.
const USIZE_BYTES: usize = std::mem::size_of::<usize>();

/// Maximum number of hex digits needed to print any `usize` chunk length:
/// each byte becomes 2 hex digits.
const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
271
272#[derive(Clone, Copy)]
273struct ChunkSize {
274    bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2],
275    pos: u8,
276    len: u8,
277}
278
279impl ChunkSize {
280    fn new(len: usize) -> ChunkSize {
281        use std::fmt::Write;
282        let mut size = ChunkSize {
283            bytes: [0; CHUNK_SIZE_MAX_BYTES + 2],
284            pos: 0,
285            len: 0,
286        };
287        write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize");
288        size
289    }
290}
291
292impl Buf for ChunkSize {
293    #[inline]
294    fn remaining(&self) -> usize {
295        (self.len - self.pos).into()
296    }
297
298    #[inline]
299    fn chunk(&self) -> &[u8] {
300        &self.bytes[self.pos.into()..self.len.into()]
301    }
302
303    #[inline]
304    fn advance(&mut self, cnt: usize) {
305        assert!(cnt <= self.remaining());
306        self.pos += cnt as u8; // just asserted cnt fits in u8
307    }
308}
309
310impl fmt::Debug for ChunkSize {
311    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312        f.debug_struct("ChunkSize")
313            .field("bytes", &&self.bytes[..self.len.into()])
314            .field("pos", &self.pos)
315            .finish()
316    }
317}
318
319impl fmt::Write for ChunkSize {
320    fn write_str(&mut self, num: &str) -> fmt::Result {
321        use std::io::Write;
322        (&mut self.bytes[self.len.into()..])
323            .write_all(num.as_bytes())
324            .expect("&mut [u8].write() cannot error");
325        self.len += num.len() as u8; // safe because bytes is never bigger than 256
326        Ok(())
327    }
328}
329
330impl<B: Buf> From<B> for EncodedBuf<B> {
331    fn from(buf: B) -> Self {
332        EncodedBuf {
333            kind: BufKind::Exact(buf),
334        }
335    }
336}
337
338impl<B: Buf> From<Take<B>> for EncodedBuf<B> {
339    fn from(buf: Take<B>) -> Self {
340        EncodedBuf {
341            kind: BufKind::Limited(buf),
342        }
343    }
344}
345
346impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> {
347    fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self {
348        EncodedBuf {
349            kind: BufKind::Chunked(buf),
350        }
351    }
352}
353
354impl fmt::Display for NotEof {
355    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
356        write!(f, "early end, expected {} more bytes", self.0)
357    }
358}
359
360impl std::error::Error for NotEof {}
361
#[cfg(test)]
mod tests {
    use bytes::BufMut;

    use super::super::io::Cursor;
    use super::Encoder;

    #[test]
    fn chunked() {
        let mut encoder = Encoder::chunked();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);
        assert_eq!(dst, b"7\r\nfoo bar\r\n");

        let msg2 = b"baz quux herp".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");

        let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
        dst.put(end);

        assert_eq!(
            dst,
            b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
        );
    }

    /// Chunk sizes with several hex digits are written in uppercase.
    #[test]
    fn chunked_multi_digit_hex_size() {
        let mut encoder = Encoder::chunked();
        let mut dst = Vec::new();

        let msg = [b'a'; 0xAB];
        dst.put(encoder.encode(&msg[..]));

        assert_eq!(&dst[..4], b"AB\r\n");
        assert_eq!(dst.len(), 4 + 0xAB + 2);
        assert!(dst.ends_with(b"a\r\n"));
    }

    #[test]
    fn length() {
        let max_len = 8;
        let mut encoder = Encoder::length(max_len as u64);
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap_err();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst.len(), max_len);
        assert_eq!(dst, b"foo barb");
        assert!(encoder.is_eof());
        assert!(encoder.end::<()>().unwrap().is_none());
    }

    // `Encoder::close_delimited` only exists with the `server` feature, so
    // this test must be gated the same way to keep client-only builds compiling.
    #[cfg(feature = "server")]
    #[test]
    fn eof() {
        let mut encoder = Encoder::close_delimited();
        let mut dst = Vec::new();

        let msg1 = b"foo bar".as_ref();
        let buf1 = encoder.encode(msg1);
        dst.put(buf1);

        assert_eq!(dst, b"foo bar");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();

        let msg2 = b"baz".as_ref();
        let buf2 = encoder.encode(msg2);
        dst.put(buf2);

        assert_eq!(dst, b"foo barbaz");
        assert!(!encoder.is_eof());
        encoder.end::<()>().unwrap();
    }
}