prost/
encoding.rs

1//! Utility functions and types for encoding and decoding Protobuf types.
2//!
3//! Meant to be used only from `Message` implementations.
4
5#![allow(clippy::implicit_hasher, clippy::ptr_arg)]
6
7use alloc::collections::BTreeMap;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11use core::mem;
12use core::str;
13
14use ::bytes::{Buf, BufMut, Bytes};
15
16use crate::DecodeError;
17use crate::Message;
18
19pub mod varint;
20pub use varint::{decode_varint, encode_varint, encoded_len_varint};
21
22pub mod length_delimiter;
23pub use length_delimiter::{
24    decode_length_delimiter, encode_length_delimiter, length_delimiter_len,
25};
26
27pub mod wire_type;
28pub use wire_type::{check_wire_type, WireType};
29
/// Additional information passed to every decode/merge function.
///
/// The context should be passed by value and can be freely cloned. When
/// handing it to a function which decodes a nested object, derive the nested
/// context with `enter_recursion` instead of reusing the current one.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "no-recursion-limit", derive(Default))]
pub struct DecodeContext {
    /// How many times we can recurse in the current decode stack before we hit
    /// the recursion limit.
    ///
    /// The recursion limit is defined by `RECURSION_LIMIT` and cannot be
    /// customized. The recursion limit can be ignored by building the Prost
    /// crate with the `no-recursion-limit` feature, in which case this field
    /// does not exist at all.
    #[cfg(not(feature = "no-recursion-limit"))]
    recurse_count: u32,
}
46
47#[cfg(not(feature = "no-recursion-limit"))]
48impl Default for DecodeContext {
49    #[inline]
50    fn default() -> DecodeContext {
51        DecodeContext {
52            recurse_count: crate::RECURSION_LIMIT,
53        }
54    }
55}
56
57impl DecodeContext {
58    /// Call this function before recursively decoding.
59    ///
60    /// There is no `exit` function since this function creates a new `DecodeContext`
61    /// to be used at the next level of recursion. Continue to use the old context
62    // at the previous level of recursion.
63    #[cfg(not(feature = "no-recursion-limit"))]
64    #[inline]
65    pub(crate) fn enter_recursion(&self) -> DecodeContext {
66        DecodeContext {
67            recurse_count: self.recurse_count - 1,
68        }
69    }
70
71    #[cfg(feature = "no-recursion-limit")]
72    #[inline]
73    pub(crate) fn enter_recursion(&self) -> DecodeContext {
74        DecodeContext {}
75    }
76
77    /// Checks whether the recursion limit has been reached in the stack of
78    /// decodes described by the `DecodeContext` at `self.ctx`.
79    ///
80    /// Returns `Ok<()>` if it is ok to continue recursing.
81    /// Returns `Err<DecodeError>` if the recursion limit has been reached.
82    #[cfg(not(feature = "no-recursion-limit"))]
83    #[inline]
84    pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
85        if self.recurse_count == 0 {
86            Err(DecodeError::new("recursion limit reached"))
87        } else {
88            Ok(())
89        }
90    }
91
92    #[cfg(feature = "no-recursion-limit")]
93    #[inline]
94    #[allow(clippy::unnecessary_wraps)] // needed in other features
95    pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
96        Ok(())
97    }
98}
99
/// Smallest valid field tag; `decode_key` rejects any tag below this value.
pub const MIN_TAG: u32 = 1;
/// Largest valid field tag (2^29 - 1). `encode_key` shifts the tag left by
/// three bits to make room for the wire type inside a `u32` key, which leaves
/// 29 bits for the tag itself.
pub const MAX_TAG: u32 = (1 << 29) - 1;
102
103/// Encodes a Protobuf field key, which consists of a wire type designator and
104/// the field tag.
105#[inline]
106pub fn encode_key(tag: u32, wire_type: WireType, buf: &mut impl BufMut) {
107    debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
108    let key = (tag << 3) | wire_type as u32;
109    encode_varint(u64::from(key), buf);
110}
111
112/// Decodes a Protobuf field key, which consists of a wire type designator and
113/// the field tag.
114#[inline(always)]
115pub fn decode_key(buf: &mut impl Buf) -> Result<(u32, WireType), DecodeError> {
116    let key = decode_varint(buf)?;
117    if key > u64::from(u32::MAX) {
118        return Err(DecodeError::new(format!("invalid key value: {}", key)));
119    }
120    let wire_type = WireType::try_from(key & 0x07)?;
121    let tag = key as u32 >> 3;
122
123    if tag < MIN_TAG {
124        return Err(DecodeError::new("invalid tag value: 0"));
125    }
126
127    Ok((tag, wire_type))
128}
129
130/// Returns the width of an encoded Protobuf field key with the given tag.
131/// The returned width will be between 1 and 5 bytes (inclusive).
132#[inline]
133pub const fn key_len(tag: u32) -> usize {
134    encoded_len_varint((tag << 3) as u64)
135}
136
137/// Helper function which abstracts reading a length delimiter prefix followed
138/// by decoding values until the length of bytes is exhausted.
139pub fn merge_loop<T, M, B>(
140    value: &mut T,
141    buf: &mut B,
142    ctx: DecodeContext,
143    mut merge: M,
144) -> Result<(), DecodeError>
145where
146    M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>,
147    B: Buf,
148{
149    let len = decode_varint(buf)?;
150    let remaining = buf.remaining();
151    if len > remaining as u64 {
152        return Err(DecodeError::new("buffer underflow"));
153    }
154
155    let limit = remaining - len as usize;
156    while buf.remaining() > limit {
157        merge(value, buf, ctx.clone())?;
158    }
159
160    if buf.remaining() != limit {
161        return Err(DecodeError::new("delimited length exceeded"));
162    }
163    Ok(())
164}
165
166pub fn skip_field(
167    wire_type: WireType,
168    tag: u32,
169    buf: &mut impl Buf,
170    ctx: DecodeContext,
171) -> Result<(), DecodeError> {
172    ctx.limit_reached()?;
173    let len = match wire_type {
174        WireType::Varint => decode_varint(buf).map(|_| 0)?,
175        WireType::ThirtyTwoBit => 4,
176        WireType::SixtyFourBit => 8,
177        WireType::LengthDelimited => decode_varint(buf)?,
178        WireType::StartGroup => loop {
179            let (inner_tag, inner_wire_type) = decode_key(buf)?;
180            match inner_wire_type {
181                WireType::EndGroup => {
182                    if inner_tag != tag {
183                        return Err(DecodeError::new("unexpected end group tag"));
184                    }
185                    break 0;
186                }
187                _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
188            }
189        },
190        WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
191    };
192
193    if len > buf.remaining() as u64 {
194        return Err(DecodeError::new("buffer underflow"));
195    }
196
197    buf.advance(len as usize);
198    Ok(())
199}
200
/// Helper macro which emits an `encode_repeated` function for the type.
///
/// The generated function writes each element of the slice as its own field
/// by calling whichever `encode` function is in scope at the expansion site.
macro_rules! encode_repeated {
    ($ty:ty) => {
        pub fn encode_repeated(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
            for item in values {
                encode(tag, item, buf);
            }
        }
    };
}
211
/// Helper macro which emits a `merge_repeated` function for the numeric type.
///
/// The generated function accepts both representations of a repeated numeric
/// field: a single unpacked element with wire type `$wire_type`, or a packed,
/// length-delimited run of elements.
macro_rules! merge_repeated_numeric {
    ($ty:ty,
     $wire_type:expr,
     $merge:ident,
     $merge_repeated:ident) => {
        pub fn $merge_repeated(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut impl Buf,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError> {
            if wire_type != WireType::LengthDelimited {
                // Unpacked: exactly one element follows the key.
                check_wire_type($wire_type, wire_type)?;
                let mut element = Default::default();
                $merge(wire_type, &mut element, buf, ctx)?;
                values.push(element);
                return Ok(());
            }

            // Packed: decode elements until the delimited region is used up.
            merge_loop(values, buf, ctx, |packed, buf, ctx| {
                let mut element = Default::default();
                $merge($wire_type, &mut element, buf, ctx)?;
                packed.push(element);
                Ok(())
            })
        }
    };
}
243
/// Macro which emits a module containing a set of encoding functions for a
/// variable width numeric type.
///
/// The short form uses plain `as` casts to and from `u64`. The long form lets
/// the caller supply the two conversions as expressions, each together with
/// the identifier that the expression uses to refer to its input.
macro_rules! varint {
    ($ty:ty,
     $proto_ty:ident) => (
        varint!($ty,
                $proto_ty,
                to_uint64(value) { *value as u64 },
                from_uint64(value) { value as $ty });
    );

    ($ty:ty,
     $proto_ty:ident,
     to_uint64($to_uint64_value:ident) $to_uint64:expr,
     from_uint64($from_uint64_value:ident) $from_uint64:expr) => (

         pub mod $proto_ty {
            use crate::encoding::*;

            /// Writes the field key with the varint wire type, followed by
            /// the value converted to `u64` and encoded as a varint.
            // The parameter is named by the caller-supplied identifier so that
            // `$to_uint64` can refer to it.
            pub fn encode(tag: u32, $to_uint64_value: &$ty, buf: &mut impl BufMut) {
                encode_key(tag, WireType::Varint, buf);
                encode_varint($to_uint64, buf);
            }

            /// Decodes one varint from `buf` and stores the converted result
            /// in `value`. Fails if `wire_type` is not the varint wire type.
            pub fn merge(wire_type: WireType, value: &mut $ty, buf: &mut impl Buf, _ctx: DecodeContext) -> Result<(), DecodeError> {
                check_wire_type(WireType::Varint, wire_type)?;
                // Bind the raw `u64` under the caller-supplied name used by
                // `$from_uint64`.
                let $from_uint64_value = decode_varint(buf)?;
                *value = $from_uint64;
                Ok(())
            }

            encode_repeated!($ty);

            /// Writes all values as one length-delimited field. Nothing is
            /// written for an empty slice.
            pub fn encode_packed(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
                if values.is_empty() { return; }

                encode_key(tag, WireType::LengthDelimited, buf);
                // The payload length must be written before the payload, so
                // the encoded widths are summed in a first pass.
                let len: usize = values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum();
                encode_varint(len as u64, buf);

                for $to_uint64_value in values {
                    encode_varint($to_uint64, buf);
                }
            }

            merge_repeated_numeric!($ty, WireType::Varint, merge, merge_repeated);

            /// Number of bytes `encode` writes for this tag and value.
            #[inline]
            pub fn encoded_len(tag: u32, $to_uint64_value: &$ty) -> usize {
                key_len(tag) + encoded_len_varint($to_uint64)
            }

            /// Number of bytes `encode_repeated` writes: one key per element
            /// plus each element's varint width.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                key_len(tag) * values.len() + values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum::<usize>()
            }

            /// Number of bytes `encode_packed` writes: zero for an empty
            /// slice, otherwise key, length prefix and payload.
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = values.iter()
                                    .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                                    .sum::<usize>();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use crate::encoding::$proto_ty::*;
                use crate::encoding::test::{
                    check_collection_type,
                    check_type,
                };

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::Varint,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, WireType::Varint,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
         }

    );
}
// `bool`: encoded as 0 or 1; any non-zero varint decodes to `true`.
varint!(bool, bool,
        to_uint64(value) u64::from(*value),
        from_uint64(value) value != 0);
// Plain integer types use the default `as` conversions to and from `u64`.
varint!(i32, int32);
varint!(i64, int64);
varint!(u32, uint32);
varint!(u64, uint64);
// `sint32`: ZigZag mapping, which interleaves negative and non-negative
// values so that numbers of small magnitude get short varints.
varint!(i32, sint32,
to_uint64(value) {
    ((value << 1) ^ (value >> 31)) as u32 as u64
},
from_uint64(value) {
    let value = value as u32;
    ((value >> 1) as i32) ^ (-((value & 1) as i32))
});
// `sint64`: the same ZigZag mapping over 64 bits.
varint!(i64, sint64,
to_uint64(value) {
    ((value << 1) ^ (value >> 63)) as u64
},
from_uint64(value) {
    ((value >> 1) as i64) ^ (-((value & 1) as i64))
});
373
/// Macro which emits a module containing a set of encoding functions for a
/// fixed width numeric type.
///
/// `$width` is the encoded size of one value in bytes, `$wire_type` the wire
/// type used for a single value, and `$put` / `$get` the buffer methods that
/// write and read one value.
macro_rules! fixed_width {
    ($ty:ty,
     $width:expr,
     $wire_type:expr,
     $proto_ty:ident,
     $put:ident,
     $get:ident) => {
        pub mod $proto_ty {
            use crate::encoding::*;

            /// Writes the field key followed by the value itself.
            pub fn encode(tag: u32, value: &$ty, buf: &mut impl BufMut) {
                encode_key(tag, $wire_type, buf);
                buf.$put(*value);
            }

            /// Reads one value from `buf` into `value`.
            ///
            /// Fails on a wire type mismatch or when fewer than `$width`
            /// bytes remain.
            pub fn merge(
                wire_type: WireType,
                value: &mut $ty,
                buf: &mut impl Buf,
                _ctx: DecodeContext,
            ) -> Result<(), DecodeError> {
                check_wire_type($wire_type, wire_type)?;
                if buf.remaining() >= $width {
                    *value = buf.$get();
                    Ok(())
                } else {
                    Err(DecodeError::new("buffer underflow"))
                }
            }

            encode_repeated!($ty);

            /// Writes all values as one length-delimited field. Nothing is
            /// written for an empty slice.
            pub fn encode_packed(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
                if !values.is_empty() {
                    encode_key(tag, WireType::LengthDelimited, buf);
                    let payload_len = values.len() as u64 * $width;
                    encode_varint(payload_len, buf);

                    for item in values {
                        buf.$put(*item);
                    }
                }
            }

            merge_repeated_numeric!($ty, $wire_type, merge, merge_repeated);

            /// Number of bytes `encode` writes; independent of the value.
            #[inline]
            pub fn encoded_len(tag: u32, _: &$ty) -> usize {
                key_len(tag) + $width
            }

            /// Number of bytes `encode_repeated` writes: key plus value for
            /// every element.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                values.len() * (key_len(tag) + $width)
            }

            /// Number of bytes `encode_packed` writes: zero for an empty
            /// slice, otherwise key, length prefix and payload.
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    return 0;
                }
                let payload_len = values.len() * $width;
                key_len(tag) + encoded_len_varint(payload_len as u64) + payload_len
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use super::super::test::{check_collection_type, check_type};
                use super::*;

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, $wire_type,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, $wire_type,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
        }
    };
}
// One module per fixed-width Protobuf scalar type. Each invocation lists the
// Rust type, its encoded width in bytes, the wire type, the module name, and
// the little-endian put/get buffer methods used to write and read a value.

// `float`: `f32`, 4 bytes.
fixed_width!(
    f32,
    4,
    WireType::ThirtyTwoBit,
    float,
    put_f32_le,
    get_f32_le
);
// `double`: `f64`, 8 bytes.
fixed_width!(
    f64,
    8,
    WireType::SixtyFourBit,
    double,
    put_f64_le,
    get_f64_le
);
// `fixed32`: `u32`, 4 bytes.
fixed_width!(
    u32,
    4,
    WireType::ThirtyTwoBit,
    fixed32,
    put_u32_le,
    get_u32_le
);
// `fixed64`: `u64`, 8 bytes.
fixed_width!(
    u64,
    8,
    WireType::SixtyFourBit,
    fixed64,
    put_u64_le,
    get_u64_le
);
// `sfixed32`: `i32`, 4 bytes.
fixed_width!(
    i32,
    4,
    WireType::ThirtyTwoBit,
    sfixed32,
    put_i32_le,
    get_i32_le
);
// `sfixed64`: `i64`, 8 bytes.
fixed_width!(
    i64,
    8,
    WireType::SixtyFourBit,
    sfixed64,
    put_i64_le,
    get_i64_le
);
521
/// Macro which emits encoding functions for a length-delimited type.
///
/// Expands to `encode_repeated`, `merge_repeated`, `encoded_len` and
/// `encoded_len_repeated`. The expansion site has to provide the `encode` and
/// `merge` functions that the generated code calls.
macro_rules! length_delimited {
    ($ty:ty) => {
        encode_repeated!($ty);

        /// Decodes one element and appends it to `values`.
        pub fn merge_repeated(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut impl Buf,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError> {
            check_wire_type(WireType::LengthDelimited, wire_type)?;
            let mut element = Default::default();
            merge(wire_type, &mut element, buf, ctx)?;
            values.push(element);
            Ok(())
        }

        /// Number of bytes needed for the key, the length prefix and the
        /// payload of a single value.
        #[inline]
        pub fn encoded_len(tag: u32, value: &$ty) -> usize {
            let payload_len = value.len();
            key_len(tag) + encoded_len_varint(payload_len as u64) + payload_len
        }

        /// Number of bytes needed when every element is written as its own
        /// field.
        #[inline]
        pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
            let payload: usize = values
                .iter()
                .map(|item| {
                    let item_len = item.len();
                    encoded_len_varint(item_len as u64) + item_len
                })
                .sum();
            key_len(tag) * values.len() + payload
        }
    };
}
555
/// Encoding functions for Protobuf `string` fields, backed by `String`.
pub mod string {
    use super::*;

    /// Writes `value` as a length-delimited field: the key, the byte length
    /// as a varint, then the string's bytes.
    pub fn encode(tag: u32, value: &String, buf: &mut impl BufMut) {
        encode_key(tag, WireType::LengthDelimited, buf);
        encode_varint(value.len() as u64, buf);
        buf.put_slice(value.as_bytes());
    }

    /// Replaces the contents of `value` with the string decoded from `buf`.
    ///
    /// On any error, including bytes that are not valid UTF-8, `value` is
    /// left empty and the error is returned.
    pub fn merge(
        wire_type: WireType,
        value: &mut String,
        buf: &mut impl Buf,
        ctx: DecodeContext,
    ) -> Result<(), DecodeError> {
        // ## Unsafety
        //
        // `string::merge` reuses `bytes::merge`, with an additional check of utf-8
        // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the
        // string is cleared, so as to avoid leaking a string field with invalid data.
        //
        // This implementation uses the unsafe `String::as_mut_vec` method instead of the safe
        // alternative of temporarily swapping an empty `String` into the field, because it results
        // in up to 10% better performance on the protobuf message decoding benchmarks.
        //
        // It's required when using `String::as_mut_vec` that invalid utf-8 data not be leaked into
        // the backing `String`. To enforce this, even in the event of a panic in `bytes::merge` or
        // in the buf implementation, a drop guard is used.
        unsafe {
            // Clears the backing vector when dropped, i.e. on every exit path
            // (early return via `?`, the error arm below, or unwinding) except
            // the success path, which forgets the guard.
            struct DropGuard<'a>(&'a mut Vec<u8>);
            impl Drop for DropGuard<'_> {
                #[inline]
                fn drop(&mut self) {
                    self.0.clear();
                }
            }

            let drop_guard = DropGuard(value.as_mut_vec());
            // Decode straight into the string's backing vector.
            bytes::merge_one_copy(wire_type, drop_guard.0, buf, ctx)?;
            match str::from_utf8(drop_guard.0) {
                Ok(_) => {
                    // Success; do not clear the bytes.
                    mem::forget(drop_guard);
                    Ok(())
                }
                // The guard is dropped here, which clears the invalid bytes.
                Err(_) => Err(DecodeError::new(
                    "invalid string value: data is not UTF-8 encoded",
                )),
            }
        }
    }

    length_delimited!(String);

    #[cfg(test)]
    mod test {
        use proptest::prelude::*;

        use super::super::test::{check_collection_type, check_type};
        use super::*;

        proptest! {
            #[test]
            fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
                super::test::check_type(value, tag, WireType::LengthDelimited,
                                        encode, merge, encoded_len)?;
            }
            #[test]
            fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
                                                   encode_repeated, merge_repeated,
                                                   encoded_len_repeated)?;
            }
        }
    }
}
632
/// Marker trait for the buffer types accepted by the `bytes` encoding
/// functions.
///
/// All operations live on the supertrait `sealed::BytesAdapter`, which is
/// declared in a private module.
pub trait BytesAdapter: sealed::BytesAdapter {}
634
// Private module holding the operations behind the public `BytesAdapter`
// marker trait. Being private, its trait cannot be named from outside this
// module tree (presumably the usual sealed-trait arrangement; confirm intent).
mod sealed {
    use super::{Buf, BufMut};

    /// Operations the `bytes` encoding functions need from a byte container.
    pub trait BytesAdapter: Default + Sized + 'static {
        /// Number of bytes currently held.
        fn len(&self) -> usize;

        /// Replace contents of this buffer with the contents of another buffer.
        fn replace_with(&mut self, buf: impl Buf);

        /// Appends this buffer to the (contents of) other buffer.
        fn append_to(&self, buf: &mut impl BufMut);

        /// Whether the container holds no bytes; defaults to `len() == 0`.
        fn is_empty(&self) -> bool {
            self.len() == 0
        }
    }
}
652
653impl BytesAdapter for Bytes {}
654
655impl sealed::BytesAdapter for Bytes {
656    fn len(&self) -> usize {
657        Buf::remaining(self)
658    }
659
660    fn replace_with(&mut self, mut buf: impl Buf) {
661        *self = buf.copy_to_bytes(buf.remaining());
662    }
663
664    fn append_to(&self, buf: &mut impl BufMut) {
665        buf.put(self.clone())
666    }
667}
668
669impl BytesAdapter for Vec<u8> {}
670
671impl sealed::BytesAdapter for Vec<u8> {
672    fn len(&self) -> usize {
673        Vec::len(self)
674    }
675
676    fn replace_with(&mut self, buf: impl Buf) {
677        self.clear();
678        self.reserve(buf.remaining());
679        self.put(buf);
680    }
681
682    fn append_to(&self, buf: &mut impl BufMut) {
683        buf.put(self.as_slice())
684    }
685}
686
687pub mod bytes {
688    use super::*;
689
690    pub fn encode(tag: u32, value: &impl BytesAdapter, buf: &mut impl BufMut) {
691        encode_key(tag, WireType::LengthDelimited, buf);
692        encode_varint(value.len() as u64, buf);
693        value.append_to(buf);
694    }
695
696    pub fn merge(
697        wire_type: WireType,
698        value: &mut impl BytesAdapter,
699        buf: &mut impl Buf,
700        _ctx: DecodeContext,
701    ) -> Result<(), DecodeError> {
702        check_wire_type(WireType::LengthDelimited, wire_type)?;
703        let len = decode_varint(buf)?;
704        if len > buf.remaining() as u64 {
705            return Err(DecodeError::new("buffer underflow"));
706        }
707        let len = len as usize;
708
709        // Clear the existing value. This follows from the following rule in the encoding guide[1]:
710        //
711        // > Normally, an encoded message would never have more than one instance of a non-repeated
712        // > field. However, parsers are expected to handle the case in which they do. For numeric
713        // > types and strings, if the same field appears multiple times, the parser accepts the
714        // > last value it sees.
715        //
716        // [1]: https://developers.google.com/protocol-buffers/docs/encoding#optional
717        //
718        // This is intended for A and B both being Bytes so it is zero-copy.
719        // Some combinations of A and B types may cause a double-copy,
720        // in which case merge_one_copy() should be used instead.
721        value.replace_with(buf.copy_to_bytes(len));
722        Ok(())
723    }
724
725    pub(super) fn merge_one_copy(
726        wire_type: WireType,
727        value: &mut impl BytesAdapter,
728        buf: &mut impl Buf,
729        _ctx: DecodeContext,
730    ) -> Result<(), DecodeError> {
731        check_wire_type(WireType::LengthDelimited, wire_type)?;
732        let len = decode_varint(buf)?;
733        if len > buf.remaining() as u64 {
734            return Err(DecodeError::new("buffer underflow"));
735        }
736        let len = len as usize;
737
738        // If we must copy, make sure to copy only once.
739        value.replace_with(buf.take(len));
740        Ok(())
741    }
742
743    length_delimited!(impl BytesAdapter);
744
745    #[cfg(test)]
746    mod test {
747        use proptest::prelude::*;
748
749        use super::super::test::{check_collection_type, check_type};
750        use super::*;
751
752        proptest! {
753            #[test]
754            fn check_vec(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
755                super::test::check_type::<Vec<u8>, Vec<u8>>(value, tag, WireType::LengthDelimited,
756                                                            encode, merge, encoded_len)?;
757            }
758
759            #[test]
760            fn check_bytes(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
761                let value = Bytes::from(value);
762                super::test::check_type::<Bytes, Bytes>(value, tag, WireType::LengthDelimited,
763                                                        encode, merge, encoded_len)?;
764            }
765
766            #[test]
767            fn check_repeated_vec(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
768                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
769                                                   encode_repeated, merge_repeated,
770                                                   encoded_len_repeated)?;
771            }
772
773            #[test]
774            fn check_repeated_bytes(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
775                let value = value.into_iter().map(Bytes::from).collect();
776                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
777                                                   encode_repeated, merge_repeated,
778                                                   encoded_len_repeated)?;
779            }
780        }
781    }
782}
783
784pub mod message {
785    use super::*;
786
787    pub fn encode<M>(tag: u32, msg: &M, buf: &mut impl BufMut)
788    where
789        M: Message,
790    {
791        encode_key(tag, WireType::LengthDelimited, buf);
792        encode_varint(msg.encoded_len() as u64, buf);
793        msg.encode_raw(buf);
794    }
795
796    pub fn merge<M, B>(
797        wire_type: WireType,
798        msg: &mut M,
799        buf: &mut B,
800        ctx: DecodeContext,
801    ) -> Result<(), DecodeError>
802    where
803        M: Message,
804        B: Buf,
805    {
806        check_wire_type(WireType::LengthDelimited, wire_type)?;
807        ctx.limit_reached()?;
808        merge_loop(
809            msg,
810            buf,
811            ctx.enter_recursion(),
812            |msg: &mut M, buf: &mut B, ctx| {
813                let (tag, wire_type) = decode_key(buf)?;
814                msg.merge_field(tag, wire_type, buf, ctx)
815            },
816        )
817    }
818
819    pub fn encode_repeated<M>(tag: u32, messages: &[M], buf: &mut impl BufMut)
820    where
821        M: Message,
822    {
823        for msg in messages {
824            encode(tag, msg, buf);
825        }
826    }
827
828    pub fn merge_repeated<M>(
829        wire_type: WireType,
830        messages: &mut Vec<M>,
831        buf: &mut impl Buf,
832        ctx: DecodeContext,
833    ) -> Result<(), DecodeError>
834    where
835        M: Message + Default,
836    {
837        check_wire_type(WireType::LengthDelimited, wire_type)?;
838        let mut msg = M::default();
839        merge(WireType::LengthDelimited, &mut msg, buf, ctx)?;
840        messages.push(msg);
841        Ok(())
842    }
843
844    #[inline]
845    pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
846    where
847        M: Message,
848    {
849        let len = msg.encoded_len();
850        key_len(tag) + encoded_len_varint(len as u64) + len
851    }
852
853    #[inline]
854    pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
855    where
856        M: Message,
857    {
858        key_len(tag) * messages.len()
859            + messages
860                .iter()
861                .map(Message::encoded_len)
862                .map(|len| len + encoded_len_varint(len as u64))
863                .sum::<usize>()
864    }
865}
866
867pub mod group {
868    use super::*;
869
870    pub fn encode<M>(tag: u32, msg: &M, buf: &mut impl BufMut)
871    where
872        M: Message,
873    {
874        encode_key(tag, WireType::StartGroup, buf);
875        msg.encode_raw(buf);
876        encode_key(tag, WireType::EndGroup, buf);
877    }
878
879    pub fn merge<M>(
880        tag: u32,
881        wire_type: WireType,
882        msg: &mut M,
883        buf: &mut impl Buf,
884        ctx: DecodeContext,
885    ) -> Result<(), DecodeError>
886    where
887        M: Message,
888    {
889        check_wire_type(WireType::StartGroup, wire_type)?;
890
891        ctx.limit_reached()?;
892        loop {
893            let (field_tag, field_wire_type) = decode_key(buf)?;
894            if field_wire_type == WireType::EndGroup {
895                if field_tag != tag {
896                    return Err(DecodeError::new("unexpected end group tag"));
897                }
898                return Ok(());
899            }
900
901            M::merge_field(msg, field_tag, field_wire_type, buf, ctx.enter_recursion())?;
902        }
903    }
904
905    pub fn encode_repeated<M>(tag: u32, messages: &[M], buf: &mut impl BufMut)
906    where
907        M: Message,
908    {
909        for msg in messages {
910            encode(tag, msg, buf);
911        }
912    }
913
914    pub fn merge_repeated<M>(
915        tag: u32,
916        wire_type: WireType,
917        messages: &mut Vec<M>,
918        buf: &mut impl Buf,
919        ctx: DecodeContext,
920    ) -> Result<(), DecodeError>
921    where
922        M: Message + Default,
923    {
924        check_wire_type(WireType::StartGroup, wire_type)?;
925        let mut msg = M::default();
926        merge(tag, WireType::StartGroup, &mut msg, buf, ctx)?;
927        messages.push(msg);
928        Ok(())
929    }
930
931    #[inline]
932    pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
933    where
934        M: Message,
935    {
936        2 * key_len(tag) + msg.encoded_len()
937    }
938
939    #[inline]
940    pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
941    where
942        M: Message,
943    {
944        2 * key_len(tag) * messages.len() + messages.iter().map(Message::encoded_len).sum::<usize>()
945    }
946}
947
948/// Rust doesn't have a `Map` trait, so macros are currently the best way to be
949/// generic over `HashMap` and `BTreeMap`.
950macro_rules! map {
951    ($map_ty:ident) => {
952        use crate::encoding::*;
953        use core::hash::Hash;
954
955        /// Generic protobuf map encode function.
956        pub fn encode<K, V, B, KE, KL, VE, VL>(
957            key_encode: KE,
958            key_encoded_len: KL,
959            val_encode: VE,
960            val_encoded_len: VL,
961            tag: u32,
962            values: &$map_ty<K, V>,
963            buf: &mut B,
964        ) where
965            K: Default + Eq + Hash + Ord,
966            V: Default + PartialEq,
967            B: BufMut,
968            KE: Fn(u32, &K, &mut B),
969            KL: Fn(u32, &K) -> usize,
970            VE: Fn(u32, &V, &mut B),
971            VL: Fn(u32, &V) -> usize,
972        {
973            encode_with_default(
974                key_encode,
975                key_encoded_len,
976                val_encode,
977                val_encoded_len,
978                &V::default(),
979                tag,
980                values,
981                buf,
982            )
983        }
984
985        /// Generic protobuf map merge function.
986        pub fn merge<K, V, B, KM, VM>(
987            key_merge: KM,
988            val_merge: VM,
989            values: &mut $map_ty<K, V>,
990            buf: &mut B,
991            ctx: DecodeContext,
992        ) -> Result<(), DecodeError>
993        where
994            K: Default + Eq + Hash + Ord,
995            V: Default,
996            B: Buf,
997            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
998            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
999        {
1000            merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx)
1001        }
1002
1003        /// Generic protobuf map encode function.
1004        pub fn encoded_len<K, V, KL, VL>(
1005            key_encoded_len: KL,
1006            val_encoded_len: VL,
1007            tag: u32,
1008            values: &$map_ty<K, V>,
1009        ) -> usize
1010        where
1011            K: Default + Eq + Hash + Ord,
1012            V: Default + PartialEq,
1013            KL: Fn(u32, &K) -> usize,
1014            VL: Fn(u32, &V) -> usize,
1015        {
1016            encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values)
1017        }
1018
1019        /// Generic protobuf map encode function with an overridden value default.
1020        ///
1021        /// This is necessary because enumeration values can have a default value other
1022        /// than 0 in proto2.
1023        pub fn encode_with_default<K, V, B, KE, KL, VE, VL>(
1024            key_encode: KE,
1025            key_encoded_len: KL,
1026            val_encode: VE,
1027            val_encoded_len: VL,
1028            val_default: &V,
1029            tag: u32,
1030            values: &$map_ty<K, V>,
1031            buf: &mut B,
1032        ) where
1033            K: Default + Eq + Hash + Ord,
1034            V: PartialEq,
1035            B: BufMut,
1036            KE: Fn(u32, &K, &mut B),
1037            KL: Fn(u32, &K) -> usize,
1038            VE: Fn(u32, &V, &mut B),
1039            VL: Fn(u32, &V) -> usize,
1040        {
1041            for (key, val) in values.iter() {
1042                let skip_key = key == &K::default();
1043                let skip_val = val == val_default;
1044
1045                let len = (if skip_key { 0 } else { key_encoded_len(1, key) })
1046                    + (if skip_val { 0 } else { val_encoded_len(2, val) });
1047
1048                encode_key(tag, WireType::LengthDelimited, buf);
1049                encode_varint(len as u64, buf);
1050                if !skip_key {
1051                    key_encode(1, key, buf);
1052                }
1053                if !skip_val {
1054                    val_encode(2, val, buf);
1055                }
1056            }
1057        }
1058
1059        /// Generic protobuf map merge function with an overridden value default.
1060        ///
1061        /// This is necessary because enumeration values can have a default value other
1062        /// than 0 in proto2.
1063        pub fn merge_with_default<K, V, B, KM, VM>(
1064            key_merge: KM,
1065            val_merge: VM,
1066            val_default: V,
1067            values: &mut $map_ty<K, V>,
1068            buf: &mut B,
1069            ctx: DecodeContext,
1070        ) -> Result<(), DecodeError>
1071        where
1072            K: Default + Eq + Hash + Ord,
1073            B: Buf,
1074            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
1075            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
1076        {
1077            let mut key = Default::default();
1078            let mut val = val_default;
1079            ctx.limit_reached()?;
1080            merge_loop(
1081                &mut (&mut key, &mut val),
1082                buf,
1083                ctx.enter_recursion(),
1084                |&mut (ref mut key, ref mut val), buf, ctx| {
1085                    let (tag, wire_type) = decode_key(buf)?;
1086                    match tag {
1087                        1 => key_merge(wire_type, key, buf, ctx),
1088                        2 => val_merge(wire_type, val, buf, ctx),
1089                        _ => skip_field(wire_type, tag, buf, ctx),
1090                    }
1091                },
1092            )?;
1093            values.insert(key, val);
1094
1095            Ok(())
1096        }
1097
1098        /// Generic protobuf map encode function with an overridden value default.
1099        ///
1100        /// This is necessary because enumeration values can have a default value other
1101        /// than 0 in proto2.
1102        pub fn encoded_len_with_default<K, V, KL, VL>(
1103            key_encoded_len: KL,
1104            val_encoded_len: VL,
1105            val_default: &V,
1106            tag: u32,
1107            values: &$map_ty<K, V>,
1108        ) -> usize
1109        where
1110            K: Default + Eq + Hash + Ord,
1111            V: PartialEq,
1112            KL: Fn(u32, &K) -> usize,
1113            VL: Fn(u32, &V) -> usize,
1114        {
1115            key_len(tag) * values.len()
1116                + values
1117                    .iter()
1118                    .map(|(key, val)| {
1119                        let len = (if key == &K::default() {
1120                            0
1121                        } else {
1122                            key_encoded_len(1, key)
1123                        }) + (if val == val_default {
1124                            0
1125                        } else {
1126                            val_encoded_len(2, val)
1127                        });
1128                        encoded_len_varint(len as u64) + len
1129                    })
1130                    .sum::<usize>()
1131        }
1132    };
1133}
1134
/// Map encoding functions for `std::collections::HashMap`, generated by the
/// `map!` macro. Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
pub mod hash_map {
    use std::collections::HashMap;

    map!(HashMap);
}
1140
1141pub mod btree_map {
1142    map!(BTreeMap);
1143}
1144
1145#[cfg(test)]
1146mod test {
1147    #[cfg(not(feature = "std"))]
1148    use alloc::string::ToString;
1149    use core::borrow::Borrow;
1150    use core::fmt::Debug;
1151
1152    use ::bytes::BytesMut;
1153    use proptest::{prelude::*, test_runner::TestCaseResult};
1154
1155    use super::*;
1156
1157    pub fn check_type<T, B>(
1158        value: T,
1159        tag: u32,
1160        wire_type: WireType,
1161        encode: fn(u32, &B, &mut BytesMut),
1162        merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
1163        encoded_len: fn(u32, &B) -> usize,
1164    ) -> TestCaseResult
1165    where
1166        T: Debug + Default + PartialEq + Borrow<B>,
1167        B: ?Sized,
1168    {
1169        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));
1170
1171        let expected_len = encoded_len(tag, value.borrow());
1172
1173        let mut buf = BytesMut::with_capacity(expected_len);
1174        encode(tag, value.borrow(), &mut buf);
1175
1176        let mut buf = buf.freeze();
1177
1178        prop_assert_eq!(
1179            buf.remaining(),
1180            expected_len,
1181            "encoded_len wrong; expected: {}, actual: {}",
1182            expected_len,
1183            buf.remaining()
1184        );
1185
1186        if !buf.has_remaining() {
1187            // Short circuit for empty packed values.
1188            return Ok(());
1189        }
1190
1191        let (decoded_tag, decoded_wire_type) =
1192            decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
1193        prop_assert_eq!(
1194            tag,
1195            decoded_tag,
1196            "decoded tag does not match; expected: {}, actual: {}",
1197            tag,
1198            decoded_tag
1199        );
1200
1201        prop_assert_eq!(
1202            wire_type,
1203            decoded_wire_type,
1204            "decoded wire type does not match; expected: {:?}, actual: {:?}",
1205            wire_type,
1206            decoded_wire_type,
1207        );
1208
1209        match wire_type {
1210            WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!(
1211                "64bit wire type illegal remaining: {}, tag: {}",
1212                buf.remaining(),
1213                tag
1214            ))),
1215            WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!(
1216                "32bit wire type illegal remaining: {}, tag: {}",
1217                buf.remaining(),
1218                tag
1219            ))),
1220            _ => Ok(()),
1221        }?;
1222
1223        let mut roundtrip_value = T::default();
1224        merge(
1225            wire_type,
1226            &mut roundtrip_value,
1227            &mut buf,
1228            DecodeContext::default(),
1229        )
1230        .map_err(|error| TestCaseError::fail(error.to_string()))?;
1231
1232        prop_assert!(
1233            !buf.has_remaining(),
1234            "expected buffer to be empty, remaining: {}",
1235            buf.remaining()
1236        );
1237
1238        prop_assert_eq!(value, roundtrip_value);
1239
1240        Ok(())
1241    }
1242
1243    pub fn check_collection_type<T, B, E, M, L>(
1244        value: T,
1245        tag: u32,
1246        wire_type: WireType,
1247        encode: E,
1248        mut merge: M,
1249        encoded_len: L,
1250    ) -> TestCaseResult
1251    where
1252        T: Debug + Default + PartialEq + Borrow<B>,
1253        B: ?Sized,
1254        E: FnOnce(u32, &B, &mut BytesMut),
1255        M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
1256        L: FnOnce(u32, &B) -> usize,
1257    {
1258        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));
1259
1260        let expected_len = encoded_len(tag, value.borrow());
1261
1262        let mut buf = BytesMut::with_capacity(expected_len);
1263        encode(tag, value.borrow(), &mut buf);
1264
1265        let mut buf = buf.freeze();
1266
1267        prop_assert_eq!(
1268            buf.remaining(),
1269            expected_len,
1270            "encoded_len wrong; expected: {}, actual: {}",
1271            expected_len,
1272            buf.remaining()
1273        );
1274
1275        let mut roundtrip_value = Default::default();
1276        while buf.has_remaining() {
1277            let (decoded_tag, decoded_wire_type) =
1278                decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
1279
1280            prop_assert_eq!(
1281                tag,
1282                decoded_tag,
1283                "decoded tag does not match; expected: {}, actual: {}",
1284                tag,
1285                decoded_tag
1286            );
1287
1288            prop_assert_eq!(
1289                wire_type,
1290                decoded_wire_type,
1291                "decoded wire type does not match; expected: {:?}, actual: {:?}",
1292                wire_type,
1293                decoded_wire_type
1294            );
1295
1296            merge(
1297                wire_type,
1298                &mut roundtrip_value,
1299                &mut buf,
1300                DecodeContext::default(),
1301            )
1302            .map_err(|error| TestCaseError::fail(error.to_string()))?;
1303        }
1304
1305        prop_assert_eq!(value, roundtrip_value);
1306
1307        Ok(())
1308    }
1309
1310    #[test]
1311    fn string_merge_invalid_utf8() {
1312        let mut s = String::new();
1313        let buf = b"\x02\x80\x80";
1314
1315        let r = string::merge(
1316            WireType::LengthDelimited,
1317            &mut s,
1318            &mut &buf[..],
1319            DecodeContext::default(),
1320        );
1321        r.expect_err("must be an error");
1322        assert!(s.is_empty());
1323    }
1324
1325    /// This big bowl o' macro soup generates an encoding property test for each combination of map
1326    /// type, scalar map key, and value type.
1327    /// TODO: these tests take a long time to compile, can this be improved?
1328    #[cfg(feature = "std")]
1329    macro_rules! map_tests {
1330        (keys: $keys:tt,
1331         vals: $vals:tt) => {
1332            mod hash_map {
1333                map_tests!(@private HashMap, hash_map, $keys, $vals);
1334            }
1335            mod btree_map {
1336                map_tests!(@private BTreeMap, btree_map, $keys, $vals);
1337            }
1338        };
1339
1340        (@private $map_type:ident,
1341                  $mod_name:ident,
1342                  [$(($key_ty:ty, $key_proto:ident)),*],
1343                  $vals:tt) => {
1344            $(
1345                mod $key_proto {
1346                    use std::collections::$map_type;
1347
1348                    use proptest::prelude::*;
1349
1350                    use crate::encoding::*;
1351                    use crate::encoding::test::check_collection_type;
1352
1353                    map_tests!(@private $map_type, $mod_name, ($key_ty, $key_proto), $vals);
1354                }
1355            )*
1356        };
1357
1358        (@private $map_type:ident,
1359                  $mod_name:ident,
1360                  ($key_ty:ty, $key_proto:ident),
1361                  [$(($val_ty:ty, $val_proto:ident)),*]) => {
1362            $(
1363                proptest! {
1364                    #[test]
1365                    fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) {
1366                        check_collection_type(values, tag, WireType::LengthDelimited,
1367                                              |tag, values, buf| {
1368                                                  $mod_name::encode($key_proto::encode,
1369                                                                    $key_proto::encoded_len,
1370                                                                    $val_proto::encode,
1371                                                                    $val_proto::encoded_len,
1372                                                                    tag,
1373                                                                    values,
1374                                                                    buf)
1375                                              },
1376                                              |wire_type, values, buf, ctx| {
1377                                                  check_wire_type(WireType::LengthDelimited, wire_type)?;
1378                                                  $mod_name::merge($key_proto::merge,
1379                                                                   $val_proto::merge,
1380                                                                   values,
1381                                                                   buf,
1382                                                                   ctx)
1383                                              },
1384                                              |tag, values| {
1385                                                  $mod_name::encoded_len($key_proto::encoded_len,
1386                                                                         $val_proto::encoded_len,
1387                                                                         tag,
1388                                                                         values)
1389                                              })?;
1390                    }
1391                }
1392             )*
1393        };
1394    }
1395
1396    #[cfg(feature = "std")]
1397    map_tests!(keys: [
1398        (i32, int32),
1399        (i64, int64),
1400        (u32, uint32),
1401        (u64, uint64),
1402        (i32, sint32),
1403        (i64, sint64),
1404        (u32, fixed32),
1405        (u64, fixed64),
1406        (i32, sfixed32),
1407        (i64, sfixed64),
1408        (bool, bool),
1409        (String, string)
1410    ],
1411    vals: [
1412        (f32, float),
1413        (f64, double),
1414        (i32, int32),
1415        (i64, int64),
1416        (u32, uint32),
1417        (u64, uint64),
1418        (i32, sint32),
1419        (i64, sint64),
1420        (u32, fixed32),
1421        (u64, fixed64),
1422        (i32, sfixed32),
1423        (i64, sfixed64),
1424        (bool, bool),
1425        (String, string),
1426        (Vec<u8>, bytes)
1427    ]);
1428}