1#![allow(clippy::implicit_hasher, clippy::ptr_arg)]
6
7use alloc::collections::BTreeMap;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11use core::mem;
12use core::str;
13
14use ::bytes::{Buf, BufMut, Bytes};
15
16use crate::DecodeError;
17use crate::Message;
18
19pub mod varint;
20pub use varint::{decode_varint, encode_varint, encoded_len_varint};
21
22pub mod length_delimiter;
23pub use length_delimiter::{
24 decode_length_delimiter, encode_length_delimiter, length_delimiter_len,
25};
26
27pub mod wire_type;
28pub use wire_type::{check_wire_type, WireType};
29
/// State threaded through every decoding call.
///
/// Unless the `no-recursion-limit` feature is enabled, it carries the number of further
/// nesting levels (nested messages, groups, map entries, skipped groups) that may still be
/// entered before decoding is rejected.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "no-recursion-limit", derive(Default))]
pub struct DecodeContext {
    /// Remaining nesting depth; `limit_reached` reports an error once this is zero.
    #[cfg(not(feature = "no-recursion-limit"))]
    recurse_count: u32,
}
46
47#[cfg(not(feature = "no-recursion-limit"))]
48impl Default for DecodeContext {
49 #[inline]
50 fn default() -> DecodeContext {
51 DecodeContext {
52 recurse_count: crate::RECURSION_LIMIT,
53 }
54 }
55}
56
57impl DecodeContext {
58 #[cfg(not(feature = "no-recursion-limit"))]
64 #[inline]
65 pub(crate) fn enter_recursion(&self) -> DecodeContext {
66 DecodeContext {
67 recurse_count: self.recurse_count - 1,
68 }
69 }
70
71 #[cfg(feature = "no-recursion-limit")]
72 #[inline]
73 pub(crate) fn enter_recursion(&self) -> DecodeContext {
74 DecodeContext {}
75 }
76
77 #[cfg(not(feature = "no-recursion-limit"))]
83 #[inline]
84 pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
85 if self.recurse_count == 0 {
86 Err(DecodeError::new("recursion limit reached"))
87 } else {
88 Ok(())
89 }
90 }
91
92 #[cfg(feature = "no-recursion-limit")]
93 #[inline]
94 #[allow(clippy::unnecessary_wraps)] pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
96 Ok(())
97 }
98}
99
/// Smallest field number accepted on the wire.
pub const MIN_TAG: u32 = 1;
/// Largest field number accepted on the wire.
///
/// Keys are built as `(tag << 3) | wire_type` in a `u32`, which leaves 29 bits for the tag.
pub const MAX_TAG: u32 = u32::MAX >> 3;
102
103#[inline]
106pub fn encode_key(tag: u32, wire_type: WireType, buf: &mut impl BufMut) {
107 debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
108 let key = (tag << 3) | wire_type as u32;
109 encode_varint(u64::from(key), buf);
110}
111
112#[inline(always)]
115pub fn decode_key(buf: &mut impl Buf) -> Result<(u32, WireType), DecodeError> {
116 let key = decode_varint(buf)?;
117 if key > u64::from(u32::MAX) {
118 return Err(DecodeError::new(format!("invalid key value: {}", key)));
119 }
120 let wire_type = WireType::try_from(key & 0x07)?;
121 let tag = key as u32 >> 3;
122
123 if tag < MIN_TAG {
124 return Err(DecodeError::new("invalid tag value: 0"));
125 }
126
127 Ok((tag, wire_type))
128}
129
130#[inline]
133pub const fn key_len(tag: u32) -> usize {
134 encoded_len_varint((tag << 3) as u64)
135}
136
137pub fn merge_loop<T, M, B>(
140 value: &mut T,
141 buf: &mut B,
142 ctx: DecodeContext,
143 mut merge: M,
144) -> Result<(), DecodeError>
145where
146 M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>,
147 B: Buf,
148{
149 let len = decode_varint(buf)?;
150 let remaining = buf.remaining();
151 if len > remaining as u64 {
152 return Err(DecodeError::new("buffer underflow"));
153 }
154
155 let limit = remaining - len as usize;
156 while buf.remaining() > limit {
157 merge(value, buf, ctx.clone())?;
158 }
159
160 if buf.remaining() != limit {
161 return Err(DecodeError::new("delimited length exceeded"));
162 }
163 Ok(())
164}
165
166pub fn skip_field(
167 wire_type: WireType,
168 tag: u32,
169 buf: &mut impl Buf,
170 ctx: DecodeContext,
171) -> Result<(), DecodeError> {
172 ctx.limit_reached()?;
173 let len = match wire_type {
174 WireType::Varint => decode_varint(buf).map(|_| 0)?,
175 WireType::ThirtyTwoBit => 4,
176 WireType::SixtyFourBit => 8,
177 WireType::LengthDelimited => decode_varint(buf)?,
178 WireType::StartGroup => loop {
179 let (inner_tag, inner_wire_type) = decode_key(buf)?;
180 match inner_wire_type {
181 WireType::EndGroup => {
182 if inner_tag != tag {
183 return Err(DecodeError::new("unexpected end group tag"));
184 }
185 break 0;
186 }
187 _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
188 }
189 },
190 WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
191 };
192
193 if len > buf.remaining() as u64 {
194 return Err(DecodeError::new("buffer underflow"));
195 }
196
197 buf.advance(len as usize);
198 Ok(())
199}
200
/// Emits an unpacked `encode_repeated` that writes each element as its own field using the
/// surrounding module's `encode`.
macro_rules! encode_repeated {
    ($ty:ty) => {
        pub fn encode_repeated(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
            values.iter().for_each(|value| encode(tag, value, buf));
        }
    };
}
211
/// Emits a `$merge_repeated` function for a numeric scalar that accepts both encodings:
/// a single unpacked element (wire type `$wire_type`), or a packed length-delimited run
/// of elements.
macro_rules! merge_repeated_numeric {
    ($ty:ty,
     $wire_type:expr,
     $merge:ident,
     $merge_repeated:ident) => {
        pub fn $merge_repeated(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut impl Buf,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError> {
            if wire_type != WireType::LengthDelimited {
                // Unpacked: exactly one element of the scalar's own wire type.
                check_wire_type($wire_type, wire_type)?;
                let mut value = Default::default();
                $merge(wire_type, &mut value, buf, ctx)?;
                values.push(value);
                return Ok(());
            }

            // Packed: the length-delimited payload holds elements back to back.
            merge_loop(values, buf, ctx, |values, buf, ctx| {
                let mut value = Default::default();
                $merge($wire_type, &mut value, buf, ctx)?;
                values.push(value);
                Ok(())
            })
        }
    };
}
243
/// Generates `pub mod $proto_ty` implementing a varint-encoded protobuf scalar.
///
/// The short form converts to and from the raw `u64` varint with `as` casts. The long form
/// takes explicit conversion expressions. `to_uint64(v) expr` must turn the borrowed value
/// `v: &$ty` into a `u64`. `from_uint64(v) expr` must turn the decoded `v: u64` into `$ty`.
macro_rules! varint {
    // Short form: plain `as` casts in both directions.
    ($ty:ty,
     $proto_ty:ident) => (
        varint!($ty,
                $proto_ty,
                to_uint64(value) { *value as u64 },
                from_uint64(value) { value as $ty });
    );

    // Long form: the caller-supplied identifier is bound to the value in every place
    // `$to_uint64` / `$from_uint64` is evaluated (function parameters, closures, loops).
    ($ty:ty,
     $proto_ty:ident,
     to_uint64($to_uint64_value:ident) $to_uint64:expr,
     from_uint64($from_uint64_value:ident) $from_uint64:expr) => (

         pub mod $proto_ty {
            use crate::encoding::*;

            /// Encodes one value as a varint field.
            pub fn encode(tag: u32, $to_uint64_value: &$ty, buf: &mut impl BufMut) {
                encode_key(tag, WireType::Varint, buf);
                encode_varint($to_uint64, buf);
            }

            /// Decodes one varint into `value`, replacing its previous contents.
            pub fn merge(wire_type: WireType, value: &mut $ty, buf: &mut impl Buf, _ctx: DecodeContext) -> Result<(), DecodeError> {
                check_wire_type(WireType::Varint, wire_type)?;
                let $from_uint64_value = decode_varint(buf)?;
                *value = $from_uint64;
                Ok(())
            }

            encode_repeated!($ty);

            /// Encodes all values as one packed length-delimited field; writes nothing for
            /// an empty slice.
            pub fn encode_packed(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
                if values.is_empty() { return; }

                encode_key(tag, WireType::LengthDelimited, buf);
                // Payload length is the sum of each element's varint length.
                let len: usize = values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum();
                encode_varint(len as u64, buf);

                for $to_uint64_value in values {
                    encode_varint($to_uint64, buf);
                }
            }

            merge_repeated_numeric!($ty, WireType::Varint, merge, merge_repeated);

            /// Encoded size of one field: key plus varint payload.
            #[inline]
            pub fn encoded_len(tag: u32, $to_uint64_value: &$ty) -> usize {
                key_len(tag) + encoded_len_varint($to_uint64)
            }

            /// Encoded size of the unpacked repeated form (one key per element).
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                key_len(tag) * values.len() + values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum::<usize>()
            }

            /// Encoded size of the packed form; zero for an empty slice, matching
            /// `encode_packed`, which writes nothing in that case.
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = values.iter()
                                    .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                                    .sum::<usize>();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            // Property tests: encode/merge round trips and length consistency for the
            // single, repeated and packed forms.
            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use crate::encoding::$proto_ty::*;
                use crate::encoding::test::{
                    check_collection_type,
                    check_type,
                };

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::Varint,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, WireType::Varint,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
         }

    );
}
351varint!(bool, bool,
352 to_uint64(value) u64::from(*value),
353 from_uint64(value) value != 0);
354varint!(i32, int32);
355varint!(i64, int64);
356varint!(u32, uint32);
357varint!(u64, uint64);
358varint!(i32, sint32,
359to_uint64(value) {
360 ((value << 1) ^ (value >> 31)) as u32 as u64
361},
362from_uint64(value) {
363 let value = value as u32;
364 ((value >> 1) as i32) ^ (-((value & 1) as i32))
365});
366varint!(i64, sint64,
367to_uint64(value) {
368 ((value << 1) ^ (value >> 63)) as u64
369},
370from_uint64(value) {
371 ((value >> 1) as i64) ^ (-((value & 1) as i64))
372});
373
/// Generates `pub mod $proto_ty` implementing a fixed-width protobuf scalar.
///
/// Each value occupies exactly `$width` bytes, written with `BufMut::$put` and read with
/// `Buf::$get`. Single values use wire type `$wire_type`; packed runs are length-delimited.
macro_rules! fixed_width {
    ($ty:ty,
     $width:expr,
     $wire_type:expr,
     $proto_ty:ident,
     $put:ident,
     $get:ident) => {
        pub mod $proto_ty {
            use crate::encoding::*;

            /// Encodes one value: key followed by the raw fixed-width bytes.
            pub fn encode(tag: u32, value: &$ty, buf: &mut impl BufMut) {
                encode_key(tag, $wire_type, buf);
                buf.$put(*value);
            }

            /// Decodes one value into `value` after checking that enough bytes remain.
            ///
            /// # Errors
            ///
            /// Wrong wire type, or "buffer underflow" when fewer bytes than the type's
            /// width are left.
            pub fn merge(
                wire_type: WireType,
                value: &mut $ty,
                buf: &mut impl Buf,
                _ctx: DecodeContext,
            ) -> Result<(), DecodeError> {
                check_wire_type($wire_type, wire_type)?;
                if buf.remaining() < $width {
                    return Err(DecodeError::new("buffer underflow"));
                }
                *value = buf.$get();
                Ok(())
            }

            encode_repeated!($ty);

            /// Encodes all values as one packed length-delimited field; writes nothing for
            /// an empty slice.
            pub fn encode_packed(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
                if values.is_empty() {
                    return;
                }

                encode_key(tag, WireType::LengthDelimited, buf);
                // Every element has the same width, so the payload length is known up front.
                // `len` is already a `u64`, so it is passed through without another cast.
                let len = values.len() as u64 * $width;
                encode_varint(len, buf);

                for value in values {
                    buf.$put(*value);
                }
            }

            merge_repeated_numeric!($ty, $wire_type, merge, merge_repeated);

            /// Encoded size of one field: key plus the fixed payload width.
            #[inline]
            pub fn encoded_len(tag: u32, _: &$ty) -> usize {
                key_len(tag) + $width
            }

            /// Encoded size of the unpacked repeated form (one key per element).
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                (key_len(tag) + $width) * values.len()
            }

            /// Encoded size of the packed form; zero for an empty slice, matching
            /// `encode_packed`.
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = $width * values.len();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            // Property tests: round trips and length consistency for the single,
            // repeated and packed forms.
            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use super::super::test::{check_collection_type, check_type};
                use super::*;

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, $wire_type,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, $wire_type,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
        }
    };
}
473fixed_width!(
474 f32,
475 4,
476 WireType::ThirtyTwoBit,
477 float,
478 put_f32_le,
479 get_f32_le
480);
481fixed_width!(
482 f64,
483 8,
484 WireType::SixtyFourBit,
485 double,
486 put_f64_le,
487 get_f64_le
488);
489fixed_width!(
490 u32,
491 4,
492 WireType::ThirtyTwoBit,
493 fixed32,
494 put_u32_le,
495 get_u32_le
496);
497fixed_width!(
498 u64,
499 8,
500 WireType::SixtyFourBit,
501 fixed64,
502 put_u64_le,
503 get_u64_le
504);
505fixed_width!(
506 i32,
507 4,
508 WireType::ThirtyTwoBit,
509 sfixed32,
510 put_i32_le,
511 get_i32_le
512);
513fixed_width!(
514 i64,
515 8,
516 WireType::SixtyFourBit,
517 sfixed64,
518 put_i64_le,
519 get_i64_le
520);
521
/// Emits the repeated-field and size helpers shared by length-delimited scalar types.
///
/// The invoking module must already define `encode` (used by `encode_repeated`) and
/// `merge` (used by `merge_repeated`), and `$ty` must provide `len()`.
macro_rules! length_delimited {
    ($ty:ty) => {
        encode_repeated!($ty);

        /// Decodes one length-delimited element and appends it to `values`.
        pub fn merge_repeated(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut impl Buf,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError> {
            check_wire_type(WireType::LengthDelimited, wire_type)?;
            let mut element = Default::default();
            merge(wire_type, &mut element, buf, ctx)?;
            values.push(element);
            Ok(())
        }

        /// Encoded size of one field: key, varint length prefix and payload.
        #[inline]
        pub fn encoded_len(tag: u32, value: &$ty) -> usize {
            let payload = value.len();
            key_len(tag) + encoded_len_varint(payload as u64) + payload
        }

        /// Encoded size of the repeated form: one key, prefix and payload per element.
        #[inline]
        pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
            let keys = key_len(tag) * values.len();
            let payloads = values
                .iter()
                .map(|value| encoded_len_varint(value.len() as u64) + value.len())
                .sum::<usize>();
            keys + payloads
        }
    };
}
555
556pub mod string {
557 use super::*;
558
559 pub fn encode(tag: u32, value: &String, buf: &mut impl BufMut) {
560 encode_key(tag, WireType::LengthDelimited, buf);
561 encode_varint(value.len() as u64, buf);
562 buf.put_slice(value.as_bytes());
563 }
564
565 pub fn merge(
566 wire_type: WireType,
567 value: &mut String,
568 buf: &mut impl Buf,
569 ctx: DecodeContext,
570 ) -> Result<(), DecodeError> {
571 unsafe {
585 struct DropGuard<'a>(&'a mut Vec<u8>);
586 impl Drop for DropGuard<'_> {
587 #[inline]
588 fn drop(&mut self) {
589 self.0.clear();
590 }
591 }
592
593 let drop_guard = DropGuard(value.as_mut_vec());
594 bytes::merge_one_copy(wire_type, drop_guard.0, buf, ctx)?;
595 match str::from_utf8(drop_guard.0) {
596 Ok(_) => {
597 mem::forget(drop_guard);
599 Ok(())
600 }
601 Err(_) => Err(DecodeError::new(
602 "invalid string value: data is not UTF-8 encoded",
603 )),
604 }
605 }
606 }
607
608 length_delimited!(String);
609
610 #[cfg(test)]
611 mod test {
612 use proptest::prelude::*;
613
614 use super::super::test::{check_collection_type, check_type};
615 use super::*;
616
617 proptest! {
618 #[test]
619 fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
620 super::test::check_type(value, tag, WireType::LengthDelimited,
621 encode, merge, encoded_len)?;
622 }
623 #[test]
624 fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
625 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
626 encode_repeated, merge_repeated,
627 encoded_len_repeated)?;
628 }
629 }
630 }
631}
632
633pub trait BytesAdapter: sealed::BytesAdapter {}
634
635mod sealed {
636 use super::{Buf, BufMut};
637
638 pub trait BytesAdapter: Default + Sized + 'static {
639 fn len(&self) -> usize;
640
641 fn replace_with(&mut self, buf: impl Buf);
643
644 fn append_to(&self, buf: &mut impl BufMut);
646
647 fn is_empty(&self) -> bool {
648 self.len() == 0
649 }
650 }
651}
652
653impl BytesAdapter for Bytes {}
654
655impl sealed::BytesAdapter for Bytes {
656 fn len(&self) -> usize {
657 Buf::remaining(self)
658 }
659
660 fn replace_with(&mut self, mut buf: impl Buf) {
661 *self = buf.copy_to_bytes(buf.remaining());
662 }
663
664 fn append_to(&self, buf: &mut impl BufMut) {
665 buf.put(self.clone())
666 }
667}
668
669impl BytesAdapter for Vec<u8> {}
670
671impl sealed::BytesAdapter for Vec<u8> {
672 fn len(&self) -> usize {
673 Vec::len(self)
674 }
675
676 fn replace_with(&mut self, buf: impl Buf) {
677 self.clear();
678 self.reserve(buf.remaining());
679 self.put(buf);
680 }
681
682 fn append_to(&self, buf: &mut impl BufMut) {
683 buf.put(self.as_slice())
684 }
685}
686
687pub mod bytes {
688 use super::*;
689
690 pub fn encode(tag: u32, value: &impl BytesAdapter, buf: &mut impl BufMut) {
691 encode_key(tag, WireType::LengthDelimited, buf);
692 encode_varint(value.len() as u64, buf);
693 value.append_to(buf);
694 }
695
696 pub fn merge(
697 wire_type: WireType,
698 value: &mut impl BytesAdapter,
699 buf: &mut impl Buf,
700 _ctx: DecodeContext,
701 ) -> Result<(), DecodeError> {
702 check_wire_type(WireType::LengthDelimited, wire_type)?;
703 let len = decode_varint(buf)?;
704 if len > buf.remaining() as u64 {
705 return Err(DecodeError::new("buffer underflow"));
706 }
707 let len = len as usize;
708
709 value.replace_with(buf.copy_to_bytes(len));
722 Ok(())
723 }
724
725 pub(super) fn merge_one_copy(
726 wire_type: WireType,
727 value: &mut impl BytesAdapter,
728 buf: &mut impl Buf,
729 _ctx: DecodeContext,
730 ) -> Result<(), DecodeError> {
731 check_wire_type(WireType::LengthDelimited, wire_type)?;
732 let len = decode_varint(buf)?;
733 if len > buf.remaining() as u64 {
734 return Err(DecodeError::new("buffer underflow"));
735 }
736 let len = len as usize;
737
738 value.replace_with(buf.take(len));
740 Ok(())
741 }
742
743 length_delimited!(impl BytesAdapter);
744
745 #[cfg(test)]
746 mod test {
747 use proptest::prelude::*;
748
749 use super::super::test::{check_collection_type, check_type};
750 use super::*;
751
752 proptest! {
753 #[test]
754 fn check_vec(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
755 super::test::check_type::<Vec<u8>, Vec<u8>>(value, tag, WireType::LengthDelimited,
756 encode, merge, encoded_len)?;
757 }
758
759 #[test]
760 fn check_bytes(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
761 let value = Bytes::from(value);
762 super::test::check_type::<Bytes, Bytes>(value, tag, WireType::LengthDelimited,
763 encode, merge, encoded_len)?;
764 }
765
766 #[test]
767 fn check_repeated_vec(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
768 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
769 encode_repeated, merge_repeated,
770 encoded_len_repeated)?;
771 }
772
773 #[test]
774 fn check_repeated_bytes(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
775 let value = value.into_iter().map(Bytes::from).collect();
776 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
777 encode_repeated, merge_repeated,
778 encoded_len_repeated)?;
779 }
780 }
781 }
782}
783
784pub mod message {
785 use super::*;
786
787 pub fn encode<M>(tag: u32, msg: &M, buf: &mut impl BufMut)
788 where
789 M: Message,
790 {
791 encode_key(tag, WireType::LengthDelimited, buf);
792 encode_varint(msg.encoded_len() as u64, buf);
793 msg.encode_raw(buf);
794 }
795
796 pub fn merge<M, B>(
797 wire_type: WireType,
798 msg: &mut M,
799 buf: &mut B,
800 ctx: DecodeContext,
801 ) -> Result<(), DecodeError>
802 where
803 M: Message,
804 B: Buf,
805 {
806 check_wire_type(WireType::LengthDelimited, wire_type)?;
807 ctx.limit_reached()?;
808 merge_loop(
809 msg,
810 buf,
811 ctx.enter_recursion(),
812 |msg: &mut M, buf: &mut B, ctx| {
813 let (tag, wire_type) = decode_key(buf)?;
814 msg.merge_field(tag, wire_type, buf, ctx)
815 },
816 )
817 }
818
819 pub fn encode_repeated<M>(tag: u32, messages: &[M], buf: &mut impl BufMut)
820 where
821 M: Message,
822 {
823 for msg in messages {
824 encode(tag, msg, buf);
825 }
826 }
827
828 pub fn merge_repeated<M>(
829 wire_type: WireType,
830 messages: &mut Vec<M>,
831 buf: &mut impl Buf,
832 ctx: DecodeContext,
833 ) -> Result<(), DecodeError>
834 where
835 M: Message + Default,
836 {
837 check_wire_type(WireType::LengthDelimited, wire_type)?;
838 let mut msg = M::default();
839 merge(WireType::LengthDelimited, &mut msg, buf, ctx)?;
840 messages.push(msg);
841 Ok(())
842 }
843
844 #[inline]
845 pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
846 where
847 M: Message,
848 {
849 let len = msg.encoded_len();
850 key_len(tag) + encoded_len_varint(len as u64) + len
851 }
852
853 #[inline]
854 pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
855 where
856 M: Message,
857 {
858 key_len(tag) * messages.len()
859 + messages
860 .iter()
861 .map(Message::encoded_len)
862 .map(|len| len + encoded_len_varint(len as u64))
863 .sum::<usize>()
864 }
865}
866
867pub mod group {
868 use super::*;
869
870 pub fn encode<M>(tag: u32, msg: &M, buf: &mut impl BufMut)
871 where
872 M: Message,
873 {
874 encode_key(tag, WireType::StartGroup, buf);
875 msg.encode_raw(buf);
876 encode_key(tag, WireType::EndGroup, buf);
877 }
878
879 pub fn merge<M>(
880 tag: u32,
881 wire_type: WireType,
882 msg: &mut M,
883 buf: &mut impl Buf,
884 ctx: DecodeContext,
885 ) -> Result<(), DecodeError>
886 where
887 M: Message,
888 {
889 check_wire_type(WireType::StartGroup, wire_type)?;
890
891 ctx.limit_reached()?;
892 loop {
893 let (field_tag, field_wire_type) = decode_key(buf)?;
894 if field_wire_type == WireType::EndGroup {
895 if field_tag != tag {
896 return Err(DecodeError::new("unexpected end group tag"));
897 }
898 return Ok(());
899 }
900
901 M::merge_field(msg, field_tag, field_wire_type, buf, ctx.enter_recursion())?;
902 }
903 }
904
905 pub fn encode_repeated<M>(tag: u32, messages: &[M], buf: &mut impl BufMut)
906 where
907 M: Message,
908 {
909 for msg in messages {
910 encode(tag, msg, buf);
911 }
912 }
913
914 pub fn merge_repeated<M>(
915 tag: u32,
916 wire_type: WireType,
917 messages: &mut Vec<M>,
918 buf: &mut impl Buf,
919 ctx: DecodeContext,
920 ) -> Result<(), DecodeError>
921 where
922 M: Message + Default,
923 {
924 check_wire_type(WireType::StartGroup, wire_type)?;
925 let mut msg = M::default();
926 merge(tag, WireType::StartGroup, &mut msg, buf, ctx)?;
927 messages.push(msg);
928 Ok(())
929 }
930
931 #[inline]
932 pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
933 where
934 M: Message,
935 {
936 2 * key_len(tag) + msg.encoded_len()
937 }
938
939 #[inline]
940 pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
941 where
942 M: Message,
943 {
944 2 * key_len(tag) * messages.len() + messages.iter().map(Message::encoded_len).sum::<usize>()
945 }
946}
947
/// Generates encode/merge/size helpers for protobuf `map<K, V>` fields backed by the map
/// type `$map_ty` (e.g. `HashMap` or `BTreeMap`).
///
/// Each map entry is written as a length-delimited message with the key as field 1 and the
/// value as field 2. The key/value codecs are passed in as functions.
macro_rules! map {
    ($map_ty:ident) => {
        use crate::encoding::*;
        use core::hash::Hash;

        /// Encodes every entry of `values`, omitting keys and values equal to their
        /// type's `Default` value.
        pub fn encode<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            encode_with_default(
                key_encode,
                key_encoded_len,
                val_encode,
                val_encoded_len,
                &V::default(),
                tag,
                values,
                buf,
            )
        }

        /// Decodes one map entry and inserts it into `values`; missing fields take their
        /// type's `Default` value.
        pub fn merge<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            V: Default,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx)
        }

        /// Encoded size of `values` as written by `encode`.
        pub fn encoded_len<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values)
        }

        /// Encodes every entry of `values` as a length-delimited entry under `tag`.
        ///
        /// Within an entry the key is field 1 and the value field 2. A key equal to
        /// `K::default()`, or a value equal to `val_default`, is left out of the entry.
        pub fn encode_with_default<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            for (key, val) in values.iter() {
                let skip_key = key == &K::default();
                let skip_val = val == val_default;

                // Entry body length counts only the fields that are actually written.
                let len = (if skip_key { 0 } else { key_encoded_len(1, key) })
                    + (if skip_val { 0 } else { val_encoded_len(2, val) });

                encode_key(tag, WireType::LengthDelimited, buf);
                encode_varint(len as u64, buf);
                if !skip_key {
                    key_encode(1, key, buf);
                }
                if !skip_val {
                    val_encode(2, val, buf);
                }
            }
        }

        /// Decodes one length-delimited map entry and inserts it into `values`.
        ///
        /// The key starts at `K::default()` and the value at `val_default`. Field 1 is
        /// merged into the key, field 2 into the value, and any other field is skipped.
        /// The entry is decoded one nesting level deeper, after checking the recursion
        /// limit.
        pub fn merge_with_default<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            val_default: V,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            let mut key = Default::default();
            let mut val = val_default;
            ctx.limit_reached()?;
            merge_loop(
                &mut (&mut key, &mut val),
                buf,
                ctx.enter_recursion(),
                |&mut (ref mut key, ref mut val), buf, ctx| {
                    let (tag, wire_type) = decode_key(buf)?;
                    match tag {
                        1 => key_merge(wire_type, key, buf, ctx),
                        2 => val_merge(wire_type, val, buf, ctx),
                        _ => skip_field(wire_type, tag, buf, ctx),
                    }
                },
            )?;
            values.insert(key, val);

            Ok(())
        }

        /// Encoded size of `values` as written by `encode_with_default` with the same
        /// `val_default`.
        pub fn encoded_len_with_default<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            key_len(tag) * values.len()
                + values
                    .iter()
                    .map(|(key, val)| {
                        let len = (if key == &K::default() {
                            0
                        } else {
                            key_encoded_len(1, key)
                        }) + (if val == val_default {
                            0
                        } else {
                            val_encoded_len(2, val)
                        });
                        encoded_len_varint(len as u64) + len
                    })
                    .sum::<usize>()
        }
    };
}
1134
/// Map-field helpers backed by `std::collections::HashMap` (requires the `std` feature).
#[cfg(feature = "std")]
pub mod hash_map {
    use ::std::collections::HashMap;

    map!(HashMap);
}
1140
1141pub mod btree_map {
1142 map!(BTreeMap);
1143}
1144
// Shared round-trip test harness for the codecs in this module, plus generated map tests.
#[cfg(test)]
mod test {
    #[cfg(not(feature = "std"))]
    use alloc::string::ToString;
    use core::borrow::Borrow;
    use core::fmt::Debug;

    use ::bytes::BytesMut;
    use proptest::{prelude::*, test_runner::TestCaseResult};

    use super::*;

    /// Round-trips a single field through `encode` and `merge`.
    ///
    /// Checks that `encoded_len` matches the bytes actually written, that the decoded key
    /// matches `tag`/`wire_type`, that fixed-width payloads have the expected size, and
    /// that decoding consumes the whole buffer and reproduces `value`.
    pub fn check_type<T, B>(
        value: T,
        tag: u32,
        wire_type: WireType,
        encode: fn(u32, &B, &mut BytesMut),
        merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
        encoded_len: fn(u32, &B) -> usize,
    ) -> TestCaseResult
    where
        T: Debug + Default + PartialEq + Borrow<B>,
        B: ?Sized,
    {
        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));

        let expected_len = encoded_len(tag, value.borrow());

        let mut buf = BytesMut::with_capacity(expected_len);
        encode(tag, value.borrow(), &mut buf);

        let mut buf = buf.freeze();

        prop_assert_eq!(
            buf.remaining(),
            expected_len,
            "encoded_len wrong; expected: {}, actual: {}",
            expected_len,
            buf.remaining()
        );

        if !buf.has_remaining() {
            // Nothing was encoded (e.g. an empty packed field), so there is nothing to decode.
            return Ok(());
        }

        let (decoded_tag, decoded_wire_type) =
            decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
        prop_assert_eq!(
            tag,
            decoded_tag,
            "decoded tag does not match; expected: {}, actual: {}",
            tag,
            decoded_tag
        );

        prop_assert_eq!(
            wire_type,
            decoded_wire_type,
            "decoded wire type does not match; expected: {:?}, actual: {:?}",
            wire_type,
            decoded_wire_type,
        );

        // Fixed-width wire types must leave exactly 4 or 8 payload bytes after the key.
        match wire_type {
            WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!(
                "64bit wire type illegal remaining: {}, tag: {}",
                buf.remaining(),
                tag
            ))),
            WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!(
                "32bit wire type illegal remaining: {}, tag: {}",
                buf.remaining(),
                tag
            ))),
            _ => Ok(()),
        }?;

        let mut roundtrip_value = T::default();
        merge(
            wire_type,
            &mut roundtrip_value,
            &mut buf,
            DecodeContext::default(),
        )
        .map_err(|error| TestCaseError::fail(error.to_string()))?;

        prop_assert!(
            !buf.has_remaining(),
            "expected buffer to be empty, remaining: {}",
            buf.remaining()
        );

        prop_assert_eq!(value, roundtrip_value);

        Ok(())
    }

    /// Round-trips a collection encoded as a sequence of fields that share one tag.
    ///
    /// Checks `encoded_len` against the written bytes, then decodes key after key until the
    /// buffer is empty, verifying each key and merging into one accumulated value that must
    /// equal the original.
    pub fn check_collection_type<T, B, E, M, L>(
        value: T,
        tag: u32,
        wire_type: WireType,
        encode: E,
        mut merge: M,
        encoded_len: L,
    ) -> TestCaseResult
    where
        T: Debug + Default + PartialEq + Borrow<B>,
        B: ?Sized,
        E: FnOnce(u32, &B, &mut BytesMut),
        M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
        L: FnOnce(u32, &B) -> usize,
    {
        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));

        let expected_len = encoded_len(tag, value.borrow());

        let mut buf = BytesMut::with_capacity(expected_len);
        encode(tag, value.borrow(), &mut buf);

        let mut buf = buf.freeze();

        prop_assert_eq!(
            buf.remaining(),
            expected_len,
            "encoded_len wrong; expected: {}, actual: {}",
            expected_len,
            buf.remaining()
        );

        let mut roundtrip_value = Default::default();
        while buf.has_remaining() {
            let (decoded_tag, decoded_wire_type) =
                decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;

            prop_assert_eq!(
                tag,
                decoded_tag,
                "decoded tag does not match; expected: {}, actual: {}",
                tag,
                decoded_tag
            );

            prop_assert_eq!(
                wire_type,
                decoded_wire_type,
                "decoded wire type does not match; expected: {:?}, actual: {:?}",
                wire_type,
                decoded_wire_type
            );

            merge(
                wire_type,
                &mut roundtrip_value,
                &mut buf,
                DecodeContext::default(),
            )
            .map_err(|error| TestCaseError::fail(error.to_string()))?;
        }

        prop_assert_eq!(value, roundtrip_value);

        Ok(())
    }

    /// Invalid UTF-8 must be rejected and leave the target string empty.
    #[test]
    fn string_merge_invalid_utf8() {
        let mut s = String::new();
        // Length prefix 2 followed by two bare continuation bytes (not valid UTF-8).
        let buf = b"\x02\x80\x80";

        let r = string::merge(
            WireType::LengthDelimited,
            &mut s,
            &mut &buf[..],
            DecodeContext::default(),
        );
        r.expect_err("must be an error");
        assert!(s.is_empty());
    }

    /// Generates one property test per (map type, key type, value type) combination.
    ///
    /// The first arm fans out over `HashMap` and `BTreeMap`. The second creates one module
    /// per key type. The third emits one proptest per value type, round-tripping a random
    /// map through the matching `hash_map`/`btree_map` helpers.
    #[cfg(feature = "std")]
    macro_rules! map_tests {
        (keys: $keys:tt,
         vals: $vals:tt) => {
            mod hash_map {
                map_tests!(@private HashMap, hash_map, $keys, $vals);
            }
            mod btree_map {
                map_tests!(@private BTreeMap, btree_map, $keys, $vals);
            }
        };

        (@private $map_type:ident,
                  $mod_name:ident,
                  [$(($key_ty:ty, $key_proto:ident)),*],
                  $vals:tt) => {
            $(
                mod $key_proto {
                    use std::collections::$map_type;

                    use proptest::prelude::*;

                    use crate::encoding::*;
                    use crate::encoding::test::check_collection_type;

                    map_tests!(@private $map_type, $mod_name, ($key_ty, $key_proto), $vals);
                }
            )*
        };

        (@private $map_type:ident,
                  $mod_name:ident,
                  ($key_ty:ty, $key_proto:ident),
                  [$(($val_ty:ty, $val_proto:ident)),*]) => {
            $(
                proptest! {
                    #[test]
                    fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(values, tag, WireType::LengthDelimited,
                                              |tag, values, buf| {
                                                  $mod_name::encode($key_proto::encode,
                                                                    $key_proto::encoded_len,
                                                                    $val_proto::encode,
                                                                    $val_proto::encoded_len,
                                                                    tag,
                                                                    values,
                                                                    buf)
                                              },
                                              |wire_type, values, buf, ctx| {
                                                  check_wire_type(WireType::LengthDelimited, wire_type)?;
                                                  $mod_name::merge($key_proto::merge,
                                                                   $val_proto::merge,
                                                                   values,
                                                                   buf,
                                                                   ctx)
                                              },
                                              |tag, values| {
                                                  $mod_name::encoded_len($key_proto::encoded_len,
                                                                         $val_proto::encoded_len,
                                                                         tag,
                                                                         values)
                                              })?;
                    }
                }
             )*
        };
    }

    // Every scalar type allowed as a map key, crossed with every tested value type.
    #[cfg(feature = "std")]
    map_tests!(keys: [
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string)
    ],
    vals: [
        (f32, float),
        (f64, double),
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string),
        (Vec<u8>, bytes)
    ]);
}