prost/
encoding.rs

1//! Utility functions and types for encoding and decoding Protobuf types.
2//!
3//! Meant to be used only from `Message` implementations.
4
5#![allow(clippy::implicit_hasher, clippy::ptr_arg)]
6
7use alloc::collections::BTreeMap;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11use core::cmp::min;
12use core::convert::TryFrom;
13use core::mem;
14use core::str;
15use core::u32;
16use core::usize;
17
18use ::bytes::{Buf, BufMut, Bytes};
19
20use crate::DecodeError;
21use crate::Message;
22
23/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer.
24/// The buffer must have enough remaining space (maximum 10 bytes).
25#[inline]
26pub fn encode_varint<B>(mut value: u64, buf: &mut B)
27where
28    B: BufMut,
29{
30    loop {
31        if value < 0x80 {
32            buf.put_u8(value as u8);
33            break;
34        } else {
35            buf.put_u8(((value & 0x7F) | 0x80) as u8);
36            value >>= 7;
37        }
38    }
39}
40
41/// Decodes a LEB128-encoded variable length integer from the buffer.
42#[inline]
43pub fn decode_varint<B>(buf: &mut B) -> Result<u64, DecodeError>
44where
45    B: Buf,
46{
47    let bytes = buf.chunk();
48    let len = bytes.len();
49    if len == 0 {
50        return Err(DecodeError::new("invalid varint"));
51    }
52
53    let byte = bytes[0];
54    if byte < 0x80 {
55        buf.advance(1);
56        Ok(u64::from(byte))
57    } else if len > 10 || bytes[len - 1] < 0x80 {
58        let (value, advance) = decode_varint_slice(bytes)?;
59        buf.advance(advance);
60        Ok(value)
61    } else {
62        decode_varint_slow(buf)
63    }
64}
65
/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the
/// number of bytes read.
///
/// Based loosely on [`ReadVarint64FromArray`][1] with a varint overflow check from
/// [`ConsumeVarint`][2].
///
/// Returns an "invalid varint" error when the tenth byte is `0x02` or greater, i.e. the varint
/// does not terminate within ten bytes or its value would not fit in 64 bits.
///
/// ## Safety
///
/// The caller must ensure that `bytes` is non-empty and either `bytes.len() > 10` or the last
/// element in bytes is < `0x80`. Both conditions are re-checked by assertions at the top of the
/// function, so a violation panics instead of reading out of bounds. (The unchecked reads only
/// touch indices `0..=9`; the assertion is slightly stricter than that requires.)
///
/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406
/// [2]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
#[inline]
fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
    // Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance.

    // Use assertions to ensure memory safety, but it should always be optimized after inline.
    assert!(!bytes.is_empty());
    assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);

    // SAFETY (applies to every `get_unchecked` below): index `i` is only read after bytes
    // `0..i` all had their continuation bit set. Given the assertions above, either the slice
    // holds more than ten bytes, or its last byte terminates the varint, so a byte with the
    // continuation bit set is never the last one and index `i` is in bounds.

    // Bytes 0-3 accumulate into `part0` (value bits 0..28). Each continuation bit is added in
    // along with the payload and then subtracted again once the byte is known not to be last.
    let mut b: u8 = unsafe { *bytes.get_unchecked(0) };
    let mut part0: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((u64::from(part0), 1));
    };
    part0 -= 0x80;
    b = unsafe { *bytes.get_unchecked(1) };
    part0 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((u64::from(part0), 2));
    };
    part0 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(2) };
    part0 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((u64::from(part0), 3));
    };
    part0 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(3) };
    part0 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((u64::from(part0), 4));
    };
    part0 -= 0x80 << 21;
    let value = u64::from(part0);

    // Bytes 4-7 accumulate into `part1`, which contributes value bits 28..56.
    b = unsafe { *bytes.get_unchecked(4) };
    let mut part1: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 5));
    };
    part1 -= 0x80;
    b = unsafe { *bytes.get_unchecked(5) };
    part1 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 6));
    };
    part1 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(6) };
    part1 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 7));
    };
    part1 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(7) };
    part1 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 8));
    };
    part1 -= 0x80 << 21;
    let value = value + ((u64::from(part1)) << 28);

    // Bytes 8-9 accumulate into `part2`, which contributes value bits 56..64.
    b = unsafe { *bytes.get_unchecked(8) };
    let mut part2: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part2) << 56), 9));
    };
    part2 -= 0x80;
    b = unsafe { *bytes.get_unchecked(9) };
    part2 += u32::from(b) << 7;
    // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
    // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
    if b < 0x02 {
        return Ok((value + (u64::from(part2) << 56), 10));
    };

    // We have overrun the maximum size of a varint (10 bytes) or the final byte caused an overflow.
    // Assume the data is corrupt.
    Err(DecodeError::new("invalid varint"))
}
157
158/// Decodes a LEB128-encoded variable length integer from the buffer, advancing the buffer as
159/// necessary.
160///
161/// Contains a varint overflow check from [`ConsumeVarint`][1].
162///
163/// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
164#[inline(never)]
165#[cold]
166fn decode_varint_slow<B>(buf: &mut B) -> Result<u64, DecodeError>
167where
168    B: Buf,
169{
170    let mut value = 0;
171    for count in 0..min(10, buf.remaining()) {
172        let byte = buf.get_u8();
173        value |= u64::from(byte & 0x7F) << (count * 7);
174        if byte <= 0x7F {
175            // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
176            // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
177            if count == 9 && byte >= 0x02 {
178                return Err(DecodeError::new("invalid varint"));
179            } else {
180                return Ok(value);
181            }
182        }
183    }
184
185    Err(DecodeError::new("invalid varint"))
186}
187
/// Additional information passed to every decode/merge function.
///
/// The context should be passed by value and can be freely cloned. When passing
/// to a function which is decoding a nested object, then use `enter_recursion`.
///
/// When the `no-recursion-limit` feature is enabled the struct has no fields and
/// `Default` is derived; otherwise it carries the remaining recursion budget.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "no-recursion-limit", derive(Default))]
pub struct DecodeContext {
    /// How many times we can recurse in the current decode stack before we hit
    /// the recursion limit.
    ///
    /// The recursion limit is defined by `RECURSION_LIMIT` and cannot be
    /// customized. The recursion limit can be ignored by building the Prost
    /// crate with the `no-recursion-limit` feature.
    #[cfg(not(feature = "no-recursion-limit"))]
    recurse_count: u32,
}
204
205#[cfg(not(feature = "no-recursion-limit"))]
206impl Default for DecodeContext {
207    #[inline]
208    fn default() -> DecodeContext {
209        DecodeContext {
210            recurse_count: crate::RECURSION_LIMIT,
211        }
212    }
213}
214
215impl DecodeContext {
216    /// Call this function before recursively decoding.
217    ///
218    /// There is no `exit` function since this function creates a new `DecodeContext`
219    /// to be used at the next level of recursion. Continue to use the old context
220    // at the previous level of recursion.
221    #[cfg(not(feature = "no-recursion-limit"))]
222    #[inline]
223    pub(crate) fn enter_recursion(&self) -> DecodeContext {
224        DecodeContext {
225            recurse_count: self.recurse_count - 1,
226        }
227    }
228
229    #[cfg(feature = "no-recursion-limit")]
230    #[inline]
231    pub(crate) fn enter_recursion(&self) -> DecodeContext {
232        DecodeContext {}
233    }
234
235    /// Checks whether the recursion limit has been reached in the stack of
236    /// decodes described by the `DecodeContext` at `self.ctx`.
237    ///
238    /// Returns `Ok<()>` if it is ok to continue recursing.
239    /// Returns `Err<DecodeError>` if the recursion limit has been reached.
240    #[cfg(not(feature = "no-recursion-limit"))]
241    #[inline]
242    pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
243        if self.recurse_count == 0 {
244            Err(DecodeError::new("recursion limit reached"))
245        } else {
246            Ok(())
247        }
248    }
249
250    #[cfg(feature = "no-recursion-limit")]
251    #[inline]
252    #[allow(clippy::unnecessary_wraps)] // needed in other features
253    pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
254        Ok(())
255    }
256}
257
/// Computes how many bytes `value` occupies when written as a LEB128 varint.
///
/// The result is always in the range `1..=10`.
#[inline]
pub fn encoded_len_varint(value: u64) -> usize {
    // Based on [VarintSize64][1].
    // [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.h#L1301-L1309

    // Index of the most significant set bit; OR-ing with 1 makes zero behave like one.
    let top_bit = 63 - (value | 1).leading_zeros();
    // Branch-free form of `top_bit / 7 + 1` for `top_bit` in `0..=63`.
    let byte_count = (top_bit * 9 + 73) / 64;
    byte_count as usize
}
266
/// The wire type designator stored in the low three bits of every field key.
///
/// The discriminant of each variant is the value written to / read from the wire.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
pub enum WireType {
    /// A LEB128 variable length integer (used by the `varint!`-generated modules).
    Varint = 0,
    /// A fixed eight-byte value (used for `double`, `fixed64` and `sfixed64`).
    SixtyFourBit = 1,
    /// A varint length prefix followed by that many bytes (strings, bytes, packed fields).
    LengthDelimited = 2,
    /// Marks the start of a group; `skip_field` skips fields until the matching `EndGroup`.
    StartGroup = 3,
    /// Marks the end of a group.
    EndGroup = 4,
    /// A fixed four-byte value (used for `float`, `fixed32` and `sfixed32`).
    ThirtyTwoBit = 5,
}
277
/// Smallest valid field tag; `decode_key` rejects a tag of zero.
pub const MIN_TAG: u32 = 1;
/// Largest valid field tag: a key packs the tag above three wire-type bits in a `u32`,
/// leaving 29 bits for the tag.
pub const MAX_TAG: u32 = (1 << 29) - 1;
280
281impl TryFrom<u64> for WireType {
282    type Error = DecodeError;
283
284    #[inline]
285    fn try_from(value: u64) -> Result<Self, Self::Error> {
286        match value {
287            0 => Ok(WireType::Varint),
288            1 => Ok(WireType::SixtyFourBit),
289            2 => Ok(WireType::LengthDelimited),
290            3 => Ok(WireType::StartGroup),
291            4 => Ok(WireType::EndGroup),
292            5 => Ok(WireType::ThirtyTwoBit),
293            _ => Err(DecodeError::new(format!(
294                "invalid wire type value: {}",
295                value
296            ))),
297        }
298    }
299}
300
301/// Encodes a Protobuf field key, which consists of a wire type designator and
302/// the field tag.
303#[inline]
304pub fn encode_key<B>(tag: u32, wire_type: WireType, buf: &mut B)
305where
306    B: BufMut,
307{
308    debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
309    let key = (tag << 3) | wire_type as u32;
310    encode_varint(u64::from(key), buf);
311}
312
313/// Decodes a Protobuf field key, which consists of a wire type designator and
314/// the field tag.
315#[inline(always)]
316pub fn decode_key<B>(buf: &mut B) -> Result<(u32, WireType), DecodeError>
317where
318    B: Buf,
319{
320    let key = decode_varint(buf)?;
321    if key > u64::from(u32::MAX) {
322        return Err(DecodeError::new(format!("invalid key value: {}", key)));
323    }
324    let wire_type = WireType::try_from(key & 0x07)?;
325    let tag = key as u32 >> 3;
326
327    if tag < MIN_TAG {
328        return Err(DecodeError::new("invalid tag value: 0"));
329    }
330
331    Ok((tag, wire_type))
332}
333
334/// Returns the width of an encoded Protobuf field key with the given tag.
335/// The returned width will be between 1 and 5 bytes (inclusive).
336#[inline]
337pub fn key_len(tag: u32) -> usize {
338    encoded_len_varint(u64::from(tag << 3))
339}
340
341/// Checks that the expected wire type matches the actual wire type,
342/// or returns an error result.
343#[inline]
344pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
345    if expected != actual {
346        return Err(DecodeError::new(format!(
347            "invalid wire type: {:?} (expected {:?})",
348            actual, expected
349        )));
350    }
351    Ok(())
352}
353
354/// Helper function which abstracts reading a length delimiter prefix followed
355/// by decoding values until the length of bytes is exhausted.
356pub fn merge_loop<T, M, B>(
357    value: &mut T,
358    buf: &mut B,
359    ctx: DecodeContext,
360    mut merge: M,
361) -> Result<(), DecodeError>
362where
363    M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>,
364    B: Buf,
365{
366    let len = decode_varint(buf)?;
367    let remaining = buf.remaining();
368    if len > remaining as u64 {
369        return Err(DecodeError::new("buffer underflow"));
370    }
371
372    let limit = remaining - len as usize;
373    while buf.remaining() > limit {
374        merge(value, buf, ctx.clone())?;
375    }
376
377    if buf.remaining() != limit {
378        return Err(DecodeError::new("delimited length exceeded"));
379    }
380    Ok(())
381}
382
383pub fn skip_field<B>(
384    wire_type: WireType,
385    tag: u32,
386    buf: &mut B,
387    ctx: DecodeContext,
388) -> Result<(), DecodeError>
389where
390    B: Buf,
391{
392    ctx.limit_reached()?;
393    let len = match wire_type {
394        WireType::Varint => decode_varint(buf).map(|_| 0)?,
395        WireType::ThirtyTwoBit => 4,
396        WireType::SixtyFourBit => 8,
397        WireType::LengthDelimited => decode_varint(buf)?,
398        WireType::StartGroup => loop {
399            let (inner_tag, inner_wire_type) = decode_key(buf)?;
400            match inner_wire_type {
401                WireType::EndGroup => {
402                    if inner_tag != tag {
403                        return Err(DecodeError::new("unexpected end group tag"));
404                    }
405                    break 0;
406                }
407                _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
408            }
409        },
410        WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
411    };
412
413    if len > buf.remaining() as u64 {
414        return Err(DecodeError::new("buffer underflow"));
415    }
416
417    buf.advance(len as usize);
418    Ok(())
419}
420
/// Helper macro which emits an `encode_repeated` function for the type.
///
/// The generated function writes every element as its own field by calling the `encode`
/// function that is in scope at the expansion site.
macro_rules! encode_repeated {
    ($ty:ty) => {
        pub fn encode_repeated<B>(tag: u32, values: &[$ty], buf: &mut B)
        where
            B: BufMut,
        {
            values.iter().for_each(|value| encode(tag, value, buf));
        }
    };
}
434
/// Helper macro which emits a `merge_repeated` function for the numeric type.
///
/// The generated function accepts both encodings of a repeated numeric field: a
/// length-delimited (packed) run of values, or a single value with the type's own wire type.
macro_rules! merge_repeated_numeric {
    ($ty:ty,
     $wire_type:expr,
     $merge:ident,
     $merge_repeated:ident) => {
        pub fn $merge_repeated<B>(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            B: Buf,
        {
            match wire_type {
                // Packed: decode elements until the delimited region is exhausted.
                WireType::LengthDelimited => merge_loop(values, buf, ctx, |packed, buf, ctx| {
                    let mut decoded: $ty = Default::default();
                    $merge($wire_type, &mut decoded, buf, ctx)?;
                    packed.push(decoded);
                    Ok(())
                }),
                // Unpacked: exactly one element, which must use the expected wire type.
                _ => {
                    check_wire_type($wire_type, wire_type)?;
                    let mut decoded: $ty = Default::default();
                    $merge(wire_type, &mut decoded, buf, ctx)?;
                    values.push(decoded);
                    Ok(())
                }
            }
        }
    };
}
469
/// Macro which emits a module containing a set of encoding functions for a
/// variable width numeric type.
///
/// The two-argument form converts to and from the wire `u64` with plain `as` casts. The long
/// form lets the caller supply the conversions: `to_uint64(ident) expr` is evaluated with
/// `ident` bound to a `&$ty`, and `from_uint64(ident) expr` with `ident` bound to the decoded
/// `u64`.
///
/// The generated module `$proto_ty` contains `encode`, `merge`, `encode_repeated`,
/// `encode_packed`, `merge_repeated`, `encoded_len`, `encoded_len_repeated` and
/// `encoded_len_packed`, plus property tests.
macro_rules! varint {
    // Short form: delegate to the long form with `as`-cast conversions.
    ($ty:ty,
     $proto_ty:ident) => (
        varint!($ty,
                $proto_ty,
                to_uint64(value) { *value as u64 },
                from_uint64(value) { value as $ty });
    );

    // Long form: caller-provided conversions between `$ty` and the wire `u64`.
    ($ty:ty,
     $proto_ty:ident,
     to_uint64($to_uint64_value:ident) $to_uint64:expr,
     from_uint64($from_uint64_value:ident) $from_uint64:expr) => (

         pub mod $proto_ty {
            use crate::encoding::*;

            // Writes the field key followed by the value as a varint. The parameter is named
            // after the caller's identifier so that `$to_uint64` can refer to it.
            pub fn encode<B>(tag: u32, $to_uint64_value: &$ty, buf: &mut B) where B: BufMut {
                encode_key(tag, WireType::Varint, buf);
                encode_varint($to_uint64, buf);
            }

            // Decodes one varint and stores the converted value, overwriting `*value`.
            // Fails if the wire type is not `Varint` or the varint is malformed.
            pub fn merge<B>(wire_type: WireType, value: &mut $ty, buf: &mut B, _ctx: DecodeContext) -> Result<(), DecodeError> where B: Buf {
                check_wire_type(WireType::Varint, wire_type)?;
                let $from_uint64_value = decode_varint(buf)?;
                *value = $from_uint64;
                Ok(())
            }

            encode_repeated!($ty);

            // Writes all values as one length-delimited field; nothing is written for an
            // empty slice.
            pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B) where B: BufMut {
                if values.is_empty() { return; }

                encode_key(tag, WireType::LengthDelimited, buf);
                // The length prefix is the total size of the encoded varints.
                let len: usize = values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum();
                encode_varint(len as u64, buf);

                for $to_uint64_value in values {
                    encode_varint($to_uint64, buf);
                }
            }

            merge_repeated_numeric!($ty, WireType::Varint, merge, merge_repeated);

            // Encoded size of a single field: key plus varint payload.
            #[inline]
            pub fn encoded_len(tag: u32, $to_uint64_value: &$ty) -> usize {
                key_len(tag) + encoded_len_varint($to_uint64)
            }

            // Encoded size of an unpacked repeated field: one key per element plus payloads.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                key_len(tag) * values.len() + values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum::<usize>()
            }

            // Encoded size of a packed repeated field: key, length prefix and payloads, or
            // zero for an empty slice (matching `encode_packed`, which writes nothing).
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = values.iter()
                                    .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                                    .sum::<usize>();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use crate::encoding::$proto_ty::*;
                use crate::encoding::test::{
                    check_collection_type,
                    check_type,
                };

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::Varint,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, WireType::Varint,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
         }

    );
}
// `bool`: encoded as the varint 1 or 0; any non-zero varint decodes to `true`.
varint!(bool, bool,
        to_uint64(value) if *value { 1u64 } else { 0u64 },
        from_uint64(value) value != 0);
// Plain integer types use the default `as`-cast conversions.
varint!(i32, int32);
varint!(i64, int64);
varint!(u32, uint32);
varint!(u64, uint64);
// `sint32`: ZigZag mapping, which interleaves negative and non-negative values so that
// small magnitudes of either sign produce small unsigned numbers.
varint!(i32, sint32,
to_uint64(value) {
    ((value << 1) ^ (value >> 31)) as u32 as u64
},
from_uint64(value) {
    // Only the low 32 bits of the decoded varint are used.
    let value = value as u32;
    ((value >> 1) as i32) ^ (-((value & 1) as i32))
});
// `sint64`: the same ZigZag mapping at 64-bit width.
varint!(i64, sint64,
to_uint64(value) {
    ((value << 1) ^ (value >> 63)) as u64
},
from_uint64(value) {
    ((value >> 1) as i64) ^ (-((value & 1) as i64))
});
599
/// Macro which emits a module containing a set of encoding functions for a
/// fixed width numeric type.
///
/// Arguments: the Rust type, its encoded width in bytes, the wire type it uses, the name of
/// the module to generate, and the names of the `BufMut` / `Buf` methods used to write and
/// read a value.
///
/// The generated module `$proto_ty` contains `encode`, `merge`, `encode_repeated`,
/// `encode_packed`, `merge_repeated`, `encoded_len`, `encoded_len_repeated` and
/// `encoded_len_packed`, plus property tests.
macro_rules! fixed_width {
    ($ty:ty,
     $width:expr,
     $wire_type:expr,
     $proto_ty:ident,
     $put:ident,
     $get:ident) => {
        pub mod $proto_ty {
            use crate::encoding::*;

            // Writes the field key followed by the value via `$put`.
            pub fn encode<B>(tag: u32, value: &$ty, buf: &mut B)
            where
                B: BufMut,
            {
                encode_key(tag, $wire_type, buf);
                buf.$put(*value);
            }

            // Reads one value via `$get`, overwriting `*value`. Fails on a wire type
            // mismatch or when fewer than `$width` bytes remain.
            pub fn merge<B>(
                wire_type: WireType,
                value: &mut $ty,
                buf: &mut B,
                _ctx: DecodeContext,
            ) -> Result<(), DecodeError>
            where
                B: Buf,
            {
                check_wire_type($wire_type, wire_type)?;
                // Checked up front so the read below is never attempted on a short buffer.
                if buf.remaining() < $width {
                    return Err(DecodeError::new("buffer underflow"));
                }
                *value = buf.$get();
                Ok(())
            }

            encode_repeated!($ty);

            // Writes all values as one length-delimited field; nothing is written for an
            // empty slice.
            pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B)
            where
                B: BufMut,
            {
                if values.is_empty() {
                    return;
                }

                encode_key(tag, WireType::LengthDelimited, buf);
                // Every element has the same width, so the payload length is a product.
                let len = values.len() as u64 * $width;
                encode_varint(len as u64, buf);

                for value in values {
                    buf.$put(*value);
                }
            }

            merge_repeated_numeric!($ty, $wire_type, merge, merge_repeated);

            // Encoded size of a single field: key plus the fixed width.
            #[inline]
            pub fn encoded_len(tag: u32, _: &$ty) -> usize {
                key_len(tag) + $width
            }

            // Encoded size of an unpacked repeated field: key plus width per element.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                (key_len(tag) + $width) * values.len()
            }

            // Encoded size of a packed repeated field: key, length prefix and payload, or
            // zero for an empty slice (matching `encode_packed`, which writes nothing).
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = $width * values.len();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use super::super::test::{check_collection_type, check_type};
                use super::*;

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, $wire_type,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, $wire_type,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
        }
    };
}
// Instantiations for the fixed-width scalar types. Each one pairs a Rust type with its
// byte width, wire type, module name and the little-endian put/get methods used for it.

// `float`: 4-byte `f32`.
fixed_width!(
    f32,
    4,
    WireType::ThirtyTwoBit,
    float,
    put_f32_le,
    get_f32_le
);
// `double`: 8-byte `f64`.
fixed_width!(
    f64,
    8,
    WireType::SixtyFourBit,
    double,
    put_f64_le,
    get_f64_le
);
// `fixed32`: 4-byte `u32`.
fixed_width!(
    u32,
    4,
    WireType::ThirtyTwoBit,
    fixed32,
    put_u32_le,
    get_u32_le
);
// `fixed64`: 8-byte `u64`.
fixed_width!(
    u64,
    8,
    WireType::SixtyFourBit,
    fixed64,
    put_u64_le,
    get_u64_le
);
// `sfixed32`: 4-byte `i32`.
fixed_width!(
    i32,
    4,
    WireType::ThirtyTwoBit,
    sfixed32,
    put_i32_le,
    get_i32_le
);
// `sfixed64`: 8-byte `i64`.
fixed_width!(
    i64,
    8,
    WireType::SixtyFourBit,
    sfixed64,
    put_i64_le,
    get_i64_le
);
756
/// Macro which emits encoding functions for a length-delimited type.
///
/// Expects `encode` and `merge` functions for the type to be in scope at the expansion site
/// and generates `encode_repeated`, `merge_repeated`, `encoded_len` and
/// `encoded_len_repeated` on top of them.
macro_rules! length_delimited {
    ($ty:ty) => {
        encode_repeated!($ty);

        pub fn merge_repeated<B>(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            B: Buf,
        {
            check_wire_type(WireType::LengthDelimited, wire_type)?;
            // Decode into a fresh element and append it only once decoding succeeded.
            let mut decoded = Default::default();
            merge(wire_type, &mut decoded, buf, ctx)?;
            values.push(decoded);
            Ok(())
        }

        #[inline]
        pub fn encoded_len(tag: u32, value: &$ty) -> usize {
            let payload = value.len();
            key_len(tag) + encoded_len_varint(payload as u64) + payload
        }

        #[inline]
        pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
            let keys = key_len(tag) * values.len();
            let bodies = values
                .iter()
                .map(|value| {
                    let payload = value.len();
                    encoded_len_varint(payload as u64) + payload
                })
                .sum::<usize>();
            keys + bodies
        }
    };
}
793
/// Encoding functions for `String` fields.
pub mod string {
    use super::*;

    /// Writes the field key, the string's byte length as a varint, and then its UTF-8 bytes.
    pub fn encode<B>(tag: u32, value: &String, buf: &mut B)
    where
        B: BufMut,
    {
        encode_key(tag, WireType::LengthDelimited, buf);
        encode_varint(value.len() as u64, buf);
        buf.put_slice(value.as_bytes());
    }
    /// Replaces the contents of `value` with a length-delimited string read from `buf`.
    ///
    /// On any error, including data that is not valid UTF-8, `value` is left empty.
    pub fn merge<B>(
        wire_type: WireType,
        value: &mut String,
        buf: &mut B,
        ctx: DecodeContext,
    ) -> Result<(), DecodeError>
    where
        B: Buf,
    {
        // ## Unsafety
        //
        // `string::merge` reuses `bytes::merge`, with an additional check of utf-8
        // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the
        // string is cleared, so as to avoid leaking a string field with invalid data.
        //
        // This implementation uses the unsafe `String::as_mut_vec` method instead of the safe
        // alternative of temporarily swapping an empty `String` into the field, because it results
        // in up to 10% better performance on the protobuf message decoding benchmarks.
        //
        // It's required when using `String::as_mut_vec` that invalid utf-8 data not be leaked into
        // the backing `String`. To enforce this, even in the event of a panic in `bytes::merge` or
        // in the buf implementation, a drop guard is used.
        unsafe {
            // Clears the wrapped byte vector when dropped. It is dropped on every exit path
            // except the success path, where it is explicitly forgotten.
            struct DropGuard<'a>(&'a mut Vec<u8>);
            impl<'a> Drop for DropGuard<'a> {
                #[inline]
                fn drop(&mut self) {
                    self.0.clear();
                }
            }

            let drop_guard = DropGuard(value.as_mut_vec());
            // An early return through `?` drops the guard and therefore clears the string.
            bytes::merge_one_copy(wire_type, drop_guard.0, buf, ctx)?;
            match str::from_utf8(drop_guard.0) {
                Ok(_) => {
                    // Success; do not clear the bytes.
                    mem::forget(drop_guard);
                    Ok(())
                }
                // The guard is dropped when this arm returns, clearing the invalid bytes.
                Err(_) => Err(DecodeError::new(
                    "invalid string value: data is not UTF-8 encoded",
                )),
            }
        }
    }

    length_delimited!(String);

    #[cfg(test)]
    mod test {
        use proptest::prelude::*;

        use super::super::test::{check_collection_type, check_type};
        use super::*;

        proptest! {
            #[test]
            fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
                super::test::check_type(value, tag, WireType::LengthDelimited,
                                        encode, merge, encoded_len)?;
            }
            #[test]
            fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
                                                   encode_repeated, merge_repeated,
                                                   encoded_len_repeated)?;
            }
        }
    }
}
875
/// Public marker trait for the byte container types accepted by the `bytes` encoding
/// functions.
///
/// All behavior lives in the private `sealed::BytesAdapter` supertrait, so the trait cannot
/// be implemented outside this module.
pub trait BytesAdapter: sealed::BytesAdapter {}
877
/// Private module holding the real `BytesAdapter` interface; keeping it private seals the
/// public `BytesAdapter` trait against outside implementations.
mod sealed {
    use super::{Buf, BufMut};

    /// Operations the `bytes` encoding functions need from a byte container.
    pub trait BytesAdapter: Default + Sized + 'static {
        /// Returns the number of bytes held by this container.
        fn len(&self) -> usize;

        /// Replace contents of this buffer with the contents of another buffer.
        fn replace_with<B>(&mut self, buf: B)
        where
            B: Buf;

        /// Appends this buffer to the (contents of) other buffer.
        fn append_to<B>(&self, buf: &mut B)
        where
            B: BufMut;

        /// Returns `true` when `len()` is zero.
        fn is_empty(&self) -> bool {
            self.len() == 0
        }
    }
}
899
900impl BytesAdapter for Bytes {}
901
902impl sealed::BytesAdapter for Bytes {
903    fn len(&self) -> usize {
904        Buf::remaining(self)
905    }
906
907    fn replace_with<B>(&mut self, mut buf: B)
908    where
909        B: Buf,
910    {
911        *self = buf.copy_to_bytes(buf.remaining());
912    }
913
914    fn append_to<B>(&self, buf: &mut B)
915    where
916        B: BufMut,
917    {
918        buf.put(self.clone())
919    }
920}
921
922impl BytesAdapter for Vec<u8> {}
923
924impl sealed::BytesAdapter for Vec<u8> {
925    fn len(&self) -> usize {
926        Vec::len(self)
927    }
928
929    fn replace_with<B>(&mut self, buf: B)
930    where
931        B: Buf,
932    {
933        self.clear();
934        self.reserve(buf.remaining());
935        self.put(buf);
936    }
937
938    fn append_to<B>(&self, buf: &mut B)
939    where
940        B: BufMut,
941    {
942        buf.put(self.as_slice())
943    }
944}
945
946pub mod bytes {
947    use super::*;
948
949    pub fn encode<A, B>(tag: u32, value: &A, buf: &mut B)
950    where
951        A: BytesAdapter,
952        B: BufMut,
953    {
954        encode_key(tag, WireType::LengthDelimited, buf);
955        encode_varint(value.len() as u64, buf);
956        value.append_to(buf);
957    }
958
959    pub fn merge<A, B>(
960        wire_type: WireType,
961        value: &mut A,
962        buf: &mut B,
963        _ctx: DecodeContext,
964    ) -> Result<(), DecodeError>
965    where
966        A: BytesAdapter,
967        B: Buf,
968    {
969        check_wire_type(WireType::LengthDelimited, wire_type)?;
970        let len = decode_varint(buf)?;
971        if len > buf.remaining() as u64 {
972            return Err(DecodeError::new("buffer underflow"));
973        }
974        let len = len as usize;
975
976        // Clear the existing value. This follows from the following rule in the encoding guide[1]:
977        //
978        // > Normally, an encoded message would never have more than one instance of a non-repeated
979        // > field. However, parsers are expected to handle the case in which they do. For numeric
980        // > types and strings, if the same field appears multiple times, the parser accepts the
981        // > last value it sees.
982        //
983        // [1]: https://developers.google.com/protocol-buffers/docs/encoding#optional
984        //
985        // This is intended for A and B both being Bytes so it is zero-copy.
986        // Some combinations of A and B types may cause a double-copy,
987        // in which case merge_one_copy() should be used instead.
988        value.replace_with(buf.copy_to_bytes(len));
989        Ok(())
990    }
991
992    pub(super) fn merge_one_copy<A, B>(
993        wire_type: WireType,
994        value: &mut A,
995        buf: &mut B,
996        _ctx: DecodeContext,
997    ) -> Result<(), DecodeError>
998    where
999        A: BytesAdapter,
1000        B: Buf,
1001    {
1002        check_wire_type(WireType::LengthDelimited, wire_type)?;
1003        let len = decode_varint(buf)?;
1004        if len > buf.remaining() as u64 {
1005            return Err(DecodeError::new("buffer underflow"));
1006        }
1007        let len = len as usize;
1008
1009        // If we must copy, make sure to copy only once.
1010        value.replace_with(buf.take(len));
1011        Ok(())
1012    }
1013
1014    length_delimited!(impl BytesAdapter);
1015
1016    #[cfg(test)]
1017    mod test {
1018        use proptest::prelude::*;
1019
1020        use super::super::test::{check_collection_type, check_type};
1021        use super::*;
1022
1023        proptest! {
1024            #[test]
1025            fn check_vec(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
1026                super::test::check_type::<Vec<u8>, Vec<u8>>(value, tag, WireType::LengthDelimited,
1027                                                            encode, merge, encoded_len)?;
1028            }
1029
1030            #[test]
1031            fn check_bytes(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
1032                let value = Bytes::from(value);
1033                super::test::check_type::<Bytes, Bytes>(value, tag, WireType::LengthDelimited,
1034                                                        encode, merge, encoded_len)?;
1035            }
1036
1037            #[test]
1038            fn check_repeated_vec(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
1039                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
1040                                                   encode_repeated, merge_repeated,
1041                                                   encoded_len_repeated)?;
1042            }
1043
1044            #[test]
1045            fn check_repeated_bytes(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
1046                let value = value.into_iter().map(Bytes::from).collect();
1047                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
1048                                                   encode_repeated, merge_repeated,
1049                                                   encoded_len_repeated)?;
1050            }
1051        }
1052    }
1053}
1054
1055pub mod message {
1056    use super::*;
1057
1058    pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B)
1059    where
1060        M: Message,
1061        B: BufMut,
1062    {
1063        encode_key(tag, WireType::LengthDelimited, buf);
1064        encode_varint(msg.encoded_len() as u64, buf);
1065        msg.encode_raw(buf);
1066    }
1067
1068    pub fn merge<M, B>(
1069        wire_type: WireType,
1070        msg: &mut M,
1071        buf: &mut B,
1072        ctx: DecodeContext,
1073    ) -> Result<(), DecodeError>
1074    where
1075        M: Message,
1076        B: Buf,
1077    {
1078        check_wire_type(WireType::LengthDelimited, wire_type)?;
1079        ctx.limit_reached()?;
1080        merge_loop(
1081            msg,
1082            buf,
1083            ctx.enter_recursion(),
1084            |msg: &mut M, buf: &mut B, ctx| {
1085                let (tag, wire_type) = decode_key(buf)?;
1086                msg.merge_field(tag, wire_type, buf, ctx)
1087            },
1088        )
1089    }
1090
1091    pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B)
1092    where
1093        M: Message,
1094        B: BufMut,
1095    {
1096        for msg in messages {
1097            encode(tag, msg, buf);
1098        }
1099    }
1100
1101    pub fn merge_repeated<M, B>(
1102        wire_type: WireType,
1103        messages: &mut Vec<M>,
1104        buf: &mut B,
1105        ctx: DecodeContext,
1106    ) -> Result<(), DecodeError>
1107    where
1108        M: Message + Default,
1109        B: Buf,
1110    {
1111        check_wire_type(WireType::LengthDelimited, wire_type)?;
1112        let mut msg = M::default();
1113        merge(WireType::LengthDelimited, &mut msg, buf, ctx)?;
1114        messages.push(msg);
1115        Ok(())
1116    }
1117
1118    #[inline]
1119    pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
1120    where
1121        M: Message,
1122    {
1123        let len = msg.encoded_len();
1124        key_len(tag) + encoded_len_varint(len as u64) + len
1125    }
1126
1127    #[inline]
1128    pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
1129    where
1130        M: Message,
1131    {
1132        key_len(tag) * messages.len()
1133            + messages
1134                .iter()
1135                .map(Message::encoded_len)
1136                .map(|len| len + encoded_len_varint(len as u64))
1137                .sum::<usize>()
1138    }
1139}
1140
1141pub mod group {
1142    use super::*;
1143
1144    pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B)
1145    where
1146        M: Message,
1147        B: BufMut,
1148    {
1149        encode_key(tag, WireType::StartGroup, buf);
1150        msg.encode_raw(buf);
1151        encode_key(tag, WireType::EndGroup, buf);
1152    }
1153
1154    pub fn merge<M, B>(
1155        tag: u32,
1156        wire_type: WireType,
1157        msg: &mut M,
1158        buf: &mut B,
1159        ctx: DecodeContext,
1160    ) -> Result<(), DecodeError>
1161    where
1162        M: Message,
1163        B: Buf,
1164    {
1165        check_wire_type(WireType::StartGroup, wire_type)?;
1166
1167        ctx.limit_reached()?;
1168        loop {
1169            let (field_tag, field_wire_type) = decode_key(buf)?;
1170            if field_wire_type == WireType::EndGroup {
1171                if field_tag != tag {
1172                    return Err(DecodeError::new("unexpected end group tag"));
1173                }
1174                return Ok(());
1175            }
1176
1177            M::merge_field(msg, field_tag, field_wire_type, buf, ctx.enter_recursion())?;
1178        }
1179    }
1180
1181    pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B)
1182    where
1183        M: Message,
1184        B: BufMut,
1185    {
1186        for msg in messages {
1187            encode(tag, msg, buf);
1188        }
1189    }
1190
1191    pub fn merge_repeated<M, B>(
1192        tag: u32,
1193        wire_type: WireType,
1194        messages: &mut Vec<M>,
1195        buf: &mut B,
1196        ctx: DecodeContext,
1197    ) -> Result<(), DecodeError>
1198    where
1199        M: Message + Default,
1200        B: Buf,
1201    {
1202        check_wire_type(WireType::StartGroup, wire_type)?;
1203        let mut msg = M::default();
1204        merge(tag, WireType::StartGroup, &mut msg, buf, ctx)?;
1205        messages.push(msg);
1206        Ok(())
1207    }
1208
1209    #[inline]
1210    pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
1211    where
1212        M: Message,
1213    {
1214        2 * key_len(tag) + msg.encoded_len()
1215    }
1216
1217    #[inline]
1218    pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
1219    where
1220        M: Message,
1221    {
1222        2 * key_len(tag) * messages.len() + messages.iter().map(Message::encoded_len).sum::<usize>()
1223    }
1224}
1225
/// Rust doesn't have a `Map` trait, so macros are currently the best way to be
/// generic over `HashMap` and `BTreeMap`.
///
/// Invoking `map!(MapType)` inside a module defines the `encode`, `merge` and `encoded_len`
/// functions for `MapType<K, V>`, together with `*_with_default` variants that take an explicit
/// default for the value type.
macro_rules! map {
    ($map_ty:ident) => {
        use crate::encoding::*;
        use core::hash::Hash;

        /// Generic protobuf map encode function.
        ///
        /// Uses `V::default()` as the value default; see `encode_with_default`.
        pub fn encode<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            encode_with_default(
                key_encode,
                key_encoded_len,
                val_encode,
                val_encoded_len,
                &V::default(),
                tag,
                values,
                buf,
            )
        }

        /// Generic protobuf map merge function.
        ///
        /// Uses `V::default()` as the value default; see `merge_with_default`.
        pub fn merge<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            V: Default,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx)
        }

        /// Generic protobuf map encoded length function.
        ///
        /// Uses `V::default()` as the value default; see `encoded_len_with_default`.
        pub fn encoded_len<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values)
        }

        /// Generic protobuf map encode function with an overridden value default.
        ///
        /// This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        ///
        /// Every map entry is written as its own length-delimited field with `tag`, holding the
        /// key as field 1 and the value as field 2. A key equal to `K::default()` or a value
        /// equal to `val_default` is left out of its entry.
        pub fn encode_with_default<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            // The default key is the same for every entry, so build it once up front.
            let key_default = K::default();

            for (key, val) in values.iter() {
                let skip_key = key == &key_default;
                let skip_val = val == val_default;

                // Length of the entry body, counting only the parts that are actually written.
                let len = (if skip_key { 0 } else { key_encoded_len(1, key) })
                    + (if skip_val { 0 } else { val_encoded_len(2, val) });

                encode_key(tag, WireType::LengthDelimited, buf);
                encode_varint(len as u64, buf);
                if !skip_key {
                    key_encode(1, key, buf);
                }
                if !skip_val {
                    val_encode(2, val, buf);
                }
            }
        }

        /// Generic protobuf map merge function with an overridden value default.
        ///
        /// This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        ///
        /// Decodes a single map entry from `buf` and inserts it into `values`. The key starts out
        /// as `K::default()` and the value as `val_default`, so an entry that omits either part
        /// falls back to that default. Fields with tags other than 1 and 2 are skipped.
        pub fn merge_with_default<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            val_default: V,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            let mut key = Default::default();
            let mut val = val_default;
            ctx.limit_reached()?;
            merge_loop(
                &mut (&mut key, &mut val),
                buf,
                ctx.enter_recursion(),
                |&mut (ref mut key, ref mut val), buf, ctx| {
                    let (tag, wire_type) = decode_key(buf)?;
                    match tag {
                        1 => key_merge(wire_type, key, buf, ctx),
                        2 => val_merge(wire_type, val, buf, ctx),
                        _ => skip_field(wire_type, tag, buf, ctx),
                    }
                },
            )?;
            // Only reached when the whole entry decoded successfully.
            values.insert(key, val);

            Ok(())
        }

        /// Generic protobuf map encoded length function with an overridden value default.
        ///
        /// This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        ///
        /// Returns the number of bytes `encode_with_default` writes for `values`, applying the
        /// same rule for leaving out default keys and values.
        pub fn encoded_len_with_default<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            // The default key is the same for every entry, so build it once up front.
            let key_default = K::default();

            key_len(tag) * values.len()
                + values
                    .iter()
                    .map(|(key, val)| {
                        let len = (if key == &key_default {
                            0
                        } else {
                            key_encoded_len(1, key)
                        }) + (if val == val_default {
                            0
                        } else {
                            val_encoded_len(2, val)
                        });
                        // Entry body plus its varint length prefix.
                        encoded_len_varint(len as u64) + len
                    })
                    .sum::<usize>()
        }
    };
}
1412
/// Map encoding functions for `std::collections::HashMap`, generated by the `map!` macro.
///
/// Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
pub mod hash_map {
    use std::collections::HashMap;
    map!(HashMap);
}
1418
/// Map encoding functions for `BTreeMap`, generated by the `map!` macro.
pub mod btree_map {
    map!(BTreeMap);
}
1422
/// Shared round-trip helpers and unit tests for the encoding functions.
#[cfg(test)]
mod test {
    use alloc::string::ToString;
    use core::borrow::Borrow;
    use core::fmt::Debug;
    use core::u64;

    use ::bytes::{Bytes, BytesMut};
    use proptest::{prelude::*, test_runner::TestCaseResult};

    use crate::encoding::*;

    /// Round-trips a single field value through `encode` and `merge`.
    ///
    /// Checks, in order, that:
    /// - `encoded_len` predicts exactly the number of bytes `encode` wrote;
    /// - the decoded key carries `tag` and `wire_type`;
    /// - fixed-width wire types are followed by exactly 8 or 4 bytes;
    /// - `merge` consumes the whole buffer and reproduces `value`.
    ///
    /// An empty encoding (an empty packed value) passes immediately after the length check.
    pub fn check_type<T, B>(
        value: T,
        tag: u32,
        wire_type: WireType,
        encode: fn(u32, &B, &mut BytesMut),
        merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
        encoded_len: fn(u32, &B) -> usize,
    ) -> TestCaseResult
    where
        T: Debug + Default + PartialEq + Borrow<B>,
        B: ?Sized,
    {
        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));

        let expected_len = encoded_len(tag, value.borrow());

        let mut buf = BytesMut::with_capacity(expected_len);
        encode(tag, value.borrow(), &mut buf);

        let mut buf = buf.freeze();

        prop_assert_eq!(
            buf.remaining(),
            expected_len,
            "encoded_len wrong; expected: {}, actual: {}",
            expected_len,
            buf.remaining()
        );

        if !buf.has_remaining() {
            // Short circuit for empty packed values.
            return Ok(());
        }

        let (decoded_tag, decoded_wire_type) =
            decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
        prop_assert_eq!(
            tag,
            decoded_tag,
            "decoded tag does not match; expected: {}, actual: {}",
            tag,
            decoded_tag
        );

        prop_assert_eq!(
            wire_type,
            decoded_wire_type,
            "decoded wire type does not match; expected: {:?}, actual: {:?}",
            wire_type,
            decoded_wire_type,
        );

        // Fixed-width wire types must be followed by exactly their payload and nothing else.
        match wire_type {
            WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!(
                "64bit wire type illegal remaining: {}, tag: {}",
                buf.remaining(),
                tag
            ))),
            WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!(
                "32bit wire type illegal remaining: {}, tag: {}",
                buf.remaining(),
                tag
            ))),
            _ => Ok(()),
        }?;

        let mut roundtrip_value = T::default();
        merge(
            wire_type,
            &mut roundtrip_value,
            &mut buf,
            DecodeContext::default(),
        )
        .map_err(|error| TestCaseError::fail(error.to_string()))?;

        prop_assert!(
            !buf.has_remaining(),
            "expected buffer to be empty, remaining: {}",
            buf.remaining()
        );

        prop_assert_eq!(value, roundtrip_value);

        Ok(())
    }

    /// Round-trips a collection (repeated field or map) through `encode` and `merge`.
    ///
    /// After checking that `encoded_len` matches the bytes written, the buffer is drained one
    /// key at a time: every key must carry `tag` and `wire_type`, and `merge` is called once per
    /// key to accumulate into a default-constructed collection, which must finally equal `value`.
    pub fn check_collection_type<T, B, E, M, L>(
        value: T,
        tag: u32,
        wire_type: WireType,
        encode: E,
        mut merge: M,
        encoded_len: L,
    ) -> TestCaseResult
    where
        T: Debug + Default + PartialEq + Borrow<B>,
        B: ?Sized,
        E: FnOnce(u32, &B, &mut BytesMut),
        M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
        L: FnOnce(u32, &B) -> usize,
    {
        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));

        let expected_len = encoded_len(tag, value.borrow());

        let mut buf = BytesMut::with_capacity(expected_len);
        encode(tag, value.borrow(), &mut buf);

        let mut buf = buf.freeze();

        prop_assert_eq!(
            buf.remaining(),
            expected_len,
            "encoded_len wrong; expected: {}, actual: {}",
            expected_len,
            buf.remaining()
        );

        let mut roundtrip_value = Default::default();
        while buf.has_remaining() {
            let (decoded_tag, decoded_wire_type) =
                decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;

            prop_assert_eq!(
                tag,
                decoded_tag,
                "decoded tag does not match; expected: {}, actual: {}",
                tag,
                decoded_tag
            );

            prop_assert_eq!(
                wire_type,
                decoded_wire_type,
                "decoded wire type does not match; expected: {:?}, actual: {:?}",
                wire_type,
                decoded_wire_type
            );

            merge(
                wire_type,
                &mut roundtrip_value,
                &mut buf,
                DecodeContext::default(),
            )
            .map_err(|error| TestCaseError::fail(error.to_string()))?;
        }

        prop_assert_eq!(value, roundtrip_value);

        Ok(())
    }

    /// Merging a string payload that is not valid UTF-8 must fail and leave the target empty.
    #[test]
    fn string_merge_invalid_utf8() {
        let mut s = String::new();
        // Length prefix 2, followed by two continuation bytes that do not form valid UTF-8.
        let buf = b"\x02\x80\x80";

        let r = string::merge(
            WireType::LengthDelimited,
            &mut s,
            &mut &buf[..],
            DecodeContext::default(),
        );
        r.expect_err("must be an error");
        assert!(s.is_empty());
    }

    /// Checks varint encoding, encoded length, and both decoders at every 7-bit boundary.
    #[test]
    fn varint() {
        /// Asserts that `value` encodes to exactly `encoded` and decodes back from it through
        /// both `decode_varint` and `decode_varint_slow`.
        fn check(value: u64, mut encoded: &[u8]) {
            // Small buffer.
            let mut buf = Vec::with_capacity(1);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            // Large buffer.
            let mut buf = Vec::with_capacity(100);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            assert_eq!(encoded_len_varint(value), encoded.len());

            // `&[u8]` is `Copy`, so this gives the fast decoder its own cursor over the same
            // bytes and leaves `encoded` untouched for the slow decoder below.
            let mut fast_input = encoded;
            let roundtrip_value = decode_varint(&mut fast_input).expect("decoding failed");
            assert_eq!(value, roundtrip_value);

            let roundtrip_value = decode_varint_slow(&mut encoded).expect("slow decoding failed");
            assert_eq!(value, roundtrip_value);
        }

        check(2u64.pow(0) - 1, &[0x00]);
        check(2u64.pow(0), &[0x01]);

        check(2u64.pow(7) - 1, &[0x7F]);
        check(2u64.pow(7), &[0x80, 0x01]);
        check(300, &[0xAC, 0x02]);

        check(2u64.pow(14) - 1, &[0xFF, 0x7F]);
        check(2u64.pow(14), &[0x80, 0x80, 0x01]);

        check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]);
        check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(
            2u64.pow(49) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(49),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(56) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(56),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(63) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(63),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            u64::MAX,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        );
    }

    /// A ten-byte varint encoding `u64::MAX + 1` must be rejected by both decoders.
    #[test]
    fn varint_overflow() {
        const U64_MAX_PLUS_ONE: &[u8] =
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];

        // Each decoder gets a fresh cursor over the full input, so the slow-path check does not
        // depend on how far the fast path advanced its buffer before failing.
        let mut fast_input = U64_MAX_PLUS_ONE;
        decode_varint(&mut fast_input).expect_err("decoding u64::MAX + 1 succeeded");

        let mut slow_input = U64_MAX_PLUS_ONE;
        decode_varint_slow(&mut slow_input).expect_err("slow decoding u64::MAX + 1 succeeded");
    }

    /// This big bowl o' macro soup generates an encoding property test for each combination of map
    /// type, scalar map key, and value type.
    /// TODO: these tests take a long time to compile, can this be improved?
    #[cfg(feature = "std")]
    macro_rules! map_tests {
        // Entry point: one test module per map type.
        (keys: $keys:tt,
         vals: $vals:tt) => {
            mod hash_map {
                map_tests!(@private HashMap, hash_map, $keys, $vals);
            }
            mod btree_map {
                map_tests!(@private BTreeMap, btree_map, $keys, $vals);
            }
        };

        // For one map type: one nested module per key type, named after the key's proto type.
        (@private $map_type:ident,
                  $mod_name:ident,
                  [$(($key_ty:ty, $key_proto:ident)),*],
                  $vals:tt) => {
            $(
                mod $key_proto {
                    use std::collections::$map_type;

                    use proptest::prelude::*;

                    use crate::encoding::*;
                    use crate::encoding::test::check_collection_type;

                    map_tests!(@private $map_type, $mod_name, ($key_ty, $key_proto), $vals);
                }
            )*
        };

        // For one map type and key type: one property test per value type, named after the
        // value's proto type.
        (@private $map_type:ident,
                  $mod_name:ident,
                  ($key_ty:ty, $key_proto:ident),
                  [$(($val_ty:ty, $val_proto:ident)),*]) => {
            $(
                proptest! {
                    #[test]
                    fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(values, tag, WireType::LengthDelimited,
                                              |tag, values, buf| {
                                                  $mod_name::encode($key_proto::encode,
                                                                    $key_proto::encoded_len,
                                                                    $val_proto::encode,
                                                                    $val_proto::encoded_len,
                                                                    tag,
                                                                    values,
                                                                    buf)
                                              },
                                              |wire_type, values, buf, ctx| {
                                                  check_wire_type(WireType::LengthDelimited, wire_type)?;
                                                  $mod_name::merge($key_proto::merge,
                                                                   $val_proto::merge,
                                                                   values,
                                                                   buf,
                                                                   ctx)
                                              },
                                              |tag, values| {
                                                  $mod_name::encoded_len($key_proto::encoded_len,
                                                                         $val_proto::encoded_len,
                                                                         tag,
                                                                         values)
                                              })?;
                    }
                }
             )*
        };
    }

    // Each pair is (Rust type, name of the encoding module that handles its proto type).
    #[cfg(feature = "std")]
    map_tests!(keys: [
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string)
    ],
    vals: [
        (f32, float),
        (f64, double),
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string),
        (Vec<u8>, bytes)
    ]);
}