protobuf/coded_input_stream/mod.rs

1mod buf_read_iter;
2mod buf_read_or_reader;
3mod input_buf;
4mod input_source;
5
6use std::io;
7use std::io::BufRead;
8use std::io::Read;
9use std::mem;
10use std::mem::MaybeUninit;
11
12#[cfg(feature = "bytes")]
13use ::bytes::Bytes;
14
15#[cfg(feature = "bytes")]
16use crate::chars::Chars;
17use crate::coded_input_stream::buf_read_iter::BufReadIter;
18use crate::enums::Enum;
19use crate::error::ProtobufError;
20use crate::error::WireError;
21use crate::misc::maybe_ununit_array_assume_init;
22use crate::reflect::types::ProtobufTypeBool;
23use crate::reflect::types::ProtobufTypeDouble;
24use crate::reflect::types::ProtobufTypeFixed;
25use crate::reflect::types::ProtobufTypeFixed32;
26use crate::reflect::types::ProtobufTypeFixed64;
27use crate::reflect::types::ProtobufTypeFloat;
28use crate::reflect::types::ProtobufTypeInt32;
29use crate::reflect::types::ProtobufTypeInt64;
30use crate::reflect::types::ProtobufTypeSfixed32;
31use crate::reflect::types::ProtobufTypeSfixed64;
32use crate::reflect::types::ProtobufTypeSint32;
33use crate::reflect::types::ProtobufTypeSint64;
34use crate::reflect::types::ProtobufTypeTrait;
35use crate::reflect::types::ProtobufTypeUint32;
36use crate::reflect::types::ProtobufTypeUint64;
37use crate::reflect::MessageDescriptor;
38use crate::unknown::UnknownValue;
39use crate::varint::decode::decode_varint32;
40use crate::varint::decode::decode_varint64;
41use crate::varint::MAX_VARINT_ENCODED_LEN;
42use crate::wire_format;
43use crate::wire_format::WireType;
44use crate::zigzag::decode_zig_zag_32;
45use crate::zigzag::decode_zig_zag_64;
46use crate::EnumOrUnknown;
47use crate::Message;
48use crate::MessageDyn;
49
/// Maximum nesting depth of messages (and skipped groups) accepted by default.
///
/// 100 matches the default used by the C++ implementation.
const DEFAULT_RECURSION_LIMIT: u32 = 100;

/// Upper bound on how much is pre-allocated up front when reading
/// length-delimited data whose declared length cannot be trusted.
pub(crate) const READ_RAW_BYTES_MAX_ALLOC: usize = 10_000_000;
55
/// Buffered read with handy utilities.
///
/// Decodes protobuf wire-format primitives (varints, fixed-width integers,
/// length-delimited payloads, tags, messages) from a byte source, which can be
/// a `Read`, a `BufRead`, a byte slice or, with the `bytes` feature, a `Bytes`
/// object. It also tracks how deeply nested the current message or group is,
/// so that input nested beyond the recursion limit is rejected with
/// `WireError::OverRecursionLimit`.
#[derive(Debug)]
pub struct CodedInputStream<'a> {
    /// Underlying buffered source. Position and length limits are also
    /// delegated to it (`pos`, `push_limit`, `pop_limit`, `eof`).
    source: BufReadIter<'a>,
    /// Current nesting depth. Raised by `incr_recursion`, lowered by
    /// `decr_recursion`.
    recursion_level: u32,
    /// Maximum allowed nesting depth. Defaults to `DEFAULT_RECURSION_LIMIT`
    /// and can be changed with `set_recursion_limit`.
    recursion_limit: u32,
}
63
64impl<'a> CodedInputStream<'a> {
65    /// Wrap a `Read`.
66    ///
67    /// Note resulting `CodedInputStream` is buffered.
68    ///
69    /// If `Read` is buffered, the resulting stream will be double buffered,
70    /// consider using [`from_buf_read`](Self::from_buf_read) instead.
71    pub fn new(read: &'a mut dyn Read) -> CodedInputStream<'a> {
72        CodedInputStream::from_buf_read_iter(BufReadIter::from_read(read))
73    }
74
75    /// Create from `BufRead`.
76    ///
77    /// `CodedInputStream` will utilize `BufRead` buffer.
78    pub fn from_buf_read(buf_read: &'a mut dyn BufRead) -> CodedInputStream<'a> {
79        CodedInputStream::from_buf_read_iter(BufReadIter::from_buf_read(buf_read))
80    }
81
82    /// Read from byte slice
83    pub fn from_bytes(bytes: &'a [u8]) -> CodedInputStream<'a> {
84        CodedInputStream::from_buf_read_iter(BufReadIter::from_byte_slice(bytes))
85    }
86
87    /// Read from `Bytes`.
88    ///
89    /// `CodedInputStream` operations like
90    /// [`read_tokio_bytes`](crate::CodedInputStream::read_tokio_bytes)
91    /// will return a shared copy of this bytes object.
92    #[cfg(feature = "bytes")]
93    pub fn from_tokio_bytes(bytes: &'a Bytes) -> CodedInputStream<'a> {
94        CodedInputStream::from_buf_read_iter(BufReadIter::from_bytes(bytes))
95    }
96
97    fn from_buf_read_iter(source: BufReadIter<'a>) -> CodedInputStream<'a> {
98        CodedInputStream {
99            source,
100            recursion_level: 0,
101            recursion_limit: DEFAULT_RECURSION_LIMIT,
102        }
103    }
104
105    /// Set the recursion limit.
106    pub fn set_recursion_limit(&mut self, limit: u32) {
107        self.recursion_limit = limit;
108    }
109
110    #[inline]
111    pub(crate) fn incr_recursion(&mut self) -> crate::Result<()> {
112        if self.recursion_level >= self.recursion_limit {
113            return Err(ProtobufError::WireError(WireError::OverRecursionLimit).into());
114        }
115        self.recursion_level += 1;
116        Ok(())
117    }
118
119    #[inline]
120    pub(crate) fn decr_recursion(&mut self) {
121        self.recursion_level -= 1;
122    }
123
124    /// How many bytes processed
125    pub fn pos(&self) -> u64 {
126        self.source.pos()
127    }
128
129    /// How many bytes until current limit
130    pub fn bytes_until_limit(&self) -> u64 {
131        self.source.bytes_until_limit()
132    }
133
134    /// Read bytes into given `buf`.
135    #[inline]
136    pub fn read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()> {
137        self.source.read_exact(buf)
138    }
139
140    /// Read exact number of bytes as `Bytes` object.
141    ///
142    /// This operation returns a shared view if `CodedInputStream` is
143    /// constructed with `Bytes` parameter.
144    #[cfg(feature = "bytes")]
145    fn read_raw_tokio_bytes(&mut self, count: usize) -> crate::Result<Bytes> {
146        self.source.read_exact_bytes(count)
147    }
148
149    /// Read one byte
150    #[inline(always)]
151    pub fn read_raw_byte(&mut self) -> crate::Result<u8> {
152        self.source.read_byte()
153    }
154
155    /// Push new limit, return previous limit.
156    pub fn push_limit(&mut self, limit: u64) -> crate::Result<u64> {
157        self.source.push_limit(limit)
158    }
159
160    /// Restore previous limit.
161    pub fn pop_limit(&mut self, old_limit: u64) {
162        self.source.pop_limit(old_limit);
163    }
164
165    /// Are we at EOF?
166    #[inline(always)]
167    pub fn eof(&mut self) -> crate::Result<bool> {
168        self.source.eof()
169    }
170
171    /// Check we are at EOF.
172    ///
173    /// Return error if we are not at EOF.
174    pub fn check_eof(&mut self) -> crate::Result<()> {
175        let eof = self.eof()?;
176        if !eof {
177            return Err(ProtobufError::WireError(WireError::UnexpectedEof).into());
178        }
179        Ok(())
180    }
181
182    fn read_raw_varint64_slow(&mut self) -> crate::Result<u64> {
183        let mut r: u64 = 0;
184        let mut i = 0;
185        loop {
186            if i == MAX_VARINT_ENCODED_LEN {
187                return Err(ProtobufError::WireError(WireError::IncorrectVarint).into());
188            }
189            let b = self.read_raw_byte()?;
190            if i == 9 && (b & 0x7f) > 1 {
191                return Err(ProtobufError::WireError(WireError::IncorrectVarint).into());
192            }
193            r = r | (((b & 0x7f) as u64) << (i * 7));
194            i += 1;
195            if b < 0x80 {
196                return Ok(r);
197            }
198        }
199    }
200
201    fn read_raw_varint32_slow(&mut self) -> crate::Result<u32> {
202        let v = self.read_raw_varint64_slow()?;
203        if v > u32::MAX as u64 {
204            return Err(ProtobufError::WireError(WireError::U32Overflow(v)).into());
205        }
206        Ok(v as u32)
207    }
208
209    /// Read varint
210    #[inline]
211    pub fn read_raw_varint64(&mut self) -> crate::Result<u64> {
212        let rem = self.source.remaining_in_buf();
213
214        match decode_varint64(rem)? {
215            Some((r, c)) => {
216                self.source.consume(c);
217                Ok(r)
218            }
219            None => self.read_raw_varint64_slow(),
220        }
221    }
222
223    /// Read varint
224    #[inline]
225    pub fn read_raw_varint32(&mut self) -> crate::Result<u32> {
226        let rem = self.source.remaining_in_buf();
227
228        match decode_varint32(rem)? {
229            Some((r, c)) => {
230                self.source.consume(c);
231                Ok(r)
232            }
233            None => self.read_raw_varint32_slow(),
234        }
235    }
236
237    #[inline]
238    fn read_raw_varint32_or_eof(&mut self) -> crate::Result<Option<u32>> {
239        let rem = self.source.remaining_in_buf();
240        let v = decode_varint32(rem)?;
241        match v {
242            Some((r, c)) => {
243                self.source.consume(c);
244                Ok(Some(r))
245            }
246            None => {
247                if self.eof()? {
248                    Ok(None)
249                } else {
250                    let v = self.read_raw_varint32_slow()?;
251                    Ok(Some(v))
252                }
253            }
254        }
255    }
256
257    /// Read little-endian 32-bit integer
258    pub fn read_raw_little_endian32(&mut self) -> crate::Result<u32> {
259        let mut bytes = [MaybeUninit::uninit(); 4];
260        self.read_exact(&mut bytes)?;
261        // SAFETY: `read_exact` guarantees that the buffer is filled.
262        let bytes = unsafe { maybe_ununit_array_assume_init(bytes) };
263        Ok(u32::from_le_bytes(bytes))
264    }
265
266    /// Read little-endian 64-bit integer
267    pub fn read_raw_little_endian64(&mut self) -> crate::Result<u64> {
268        let mut bytes = [MaybeUninit::uninit(); 8];
269        self.read_exact(&mut bytes)?;
270        // SAFETY: `read_exact` guarantees that the buffer is filled.
271        let bytes = unsafe { maybe_ununit_array_assume_init(bytes) };
272        Ok(u64::from_le_bytes(bytes))
273    }
274
275    /// Read tag number as `u32` or None if EOF is reached.
276    #[inline]
277    pub fn read_raw_tag_or_eof(&mut self) -> crate::Result<Option<u32>> {
278        self.read_raw_varint32_or_eof()
279    }
280
281    /// Read tag
282    #[inline]
283    pub(crate) fn read_tag(&mut self) -> crate::Result<wire_format::Tag> {
284        let v = self.read_raw_varint32()?;
285        wire_format::Tag::new(v)
286    }
287
288    /// Read tag, return it is pair (field number, wire type)
289    #[inline]
290    pub(crate) fn read_tag_unpack(&mut self) -> crate::Result<(u32, WireType)> {
291        self.read_tag().map(|t| t.unpack())
292    }
293
294    /// Read `double`
295    pub fn read_double(&mut self) -> crate::Result<f64> {
296        let bits = self.read_raw_little_endian64()?;
297        Ok(f64::from_bits(bits))
298    }
299
300    /// Read `float`
301    pub fn read_float(&mut self) -> crate::Result<f32> {
302        let bits = self.read_raw_little_endian32()?;
303        Ok(f32::from_bits(bits))
304    }
305
306    /// Read `int64`
307    pub fn read_int64(&mut self) -> crate::Result<i64> {
308        self.read_raw_varint64().map(|v| v as i64)
309    }
310
311    /// Read `int32`
312    pub fn read_int32(&mut self) -> crate::Result<i32> {
313        let v = self.read_int64()?;
314        i32::try_from(v).map_err(|_| WireError::I32Overflow(v).into())
315    }
316
317    /// Read `uint64`
318    pub fn read_uint64(&mut self) -> crate::Result<u64> {
319        self.read_raw_varint64()
320    }
321
322    /// Read `uint32`
323    pub fn read_uint32(&mut self) -> crate::Result<u32> {
324        self.read_raw_varint32()
325    }
326
327    /// Read `sint64`
328    pub fn read_sint64(&mut self) -> crate::Result<i64> {
329        self.read_uint64().map(decode_zig_zag_64)
330    }
331
332    /// Read `sint32`
333    pub fn read_sint32(&mut self) -> crate::Result<i32> {
334        self.read_uint32().map(decode_zig_zag_32)
335    }
336
337    /// Read `fixed64`
338    pub fn read_fixed64(&mut self) -> crate::Result<u64> {
339        self.read_raw_little_endian64()
340    }
341
342    /// Read `fixed32`
343    pub fn read_fixed32(&mut self) -> crate::Result<u32> {
344        self.read_raw_little_endian32()
345    }
346
347    /// Read `sfixed64`
348    pub fn read_sfixed64(&mut self) -> crate::Result<i64> {
349        self.read_raw_little_endian64().map(|v| v as i64)
350    }
351
352    /// Read `sfixed32`
353    pub fn read_sfixed32(&mut self) -> crate::Result<i32> {
354        self.read_raw_little_endian32().map(|v| v as i32)
355    }
356
357    /// Read `bool`
358    pub fn read_bool(&mut self) -> crate::Result<bool> {
359        self.read_raw_varint64().map(|v| v != 0)
360    }
361
362    pub(crate) fn read_enum_value(&mut self) -> crate::Result<i32> {
363        self.read_int32()
364    }
365
366    /// Read `enum` as `ProtobufEnum`
367    pub fn read_enum<E: Enum>(&mut self) -> crate::Result<E> {
368        let i = self.read_enum_value()?;
369        match Enum::from_i32(i) {
370            Some(e) => Ok(e),
371            None => Err(ProtobufError::WireError(WireError::InvalidEnumValue(E::NAME, i)).into()),
372        }
373    }
374
375    /// Read `enum` as `ProtobufEnumOrUnknown`
376    pub fn read_enum_or_unknown<E: Enum>(&mut self) -> crate::Result<EnumOrUnknown<E>> {
377        Ok(EnumOrUnknown::from_i32(self.read_int32()?))
378    }
379
380    fn read_repeated_packed_fixed_into<T: ProtobufTypeFixed>(
381        &mut self,
382        target: &mut Vec<T::ProtobufValue>,
383    ) -> crate::Result<()> {
384        let len_bytes = self.read_raw_varint64()?;
385
386        let reserve = if len_bytes <= READ_RAW_BYTES_MAX_ALLOC as u64 {
387            (len_bytes as usize) / (T::ENCODED_SIZE as usize)
388        } else {
389            // prevent OOM on malformed input
390            // probably should truncate
391            READ_RAW_BYTES_MAX_ALLOC / (T::ENCODED_SIZE as usize)
392        };
393
394        target.reserve(reserve);
395
396        let old_limit = self.push_limit(len_bytes)?;
397        while !self.eof()? {
398            target.push(T::read(self)?);
399        }
400        self.pop_limit(old_limit);
401        Ok(())
402    }
403
404    fn read_repeated_packed_into<T: ProtobufTypeTrait>(
405        &mut self,
406        target: &mut Vec<T::ProtobufValue>,
407    ) -> crate::Result<()> {
408        let len_bytes = self.read_raw_varint64()?;
409
410        // value is at least 1 bytes, so this is lower bound of element count
411        let reserve = if len_bytes <= READ_RAW_BYTES_MAX_ALLOC as u64 {
412            len_bytes as usize
413        } else {
414            // prevent OOM on malformed input
415            READ_RAW_BYTES_MAX_ALLOC
416        };
417
418        target.reserve(reserve);
419
420        let old_limit = self.push_limit(len_bytes)?;
421        while !self.eof()? {
422            target.push(T::read(self)?);
423        }
424        self.pop_limit(old_limit);
425        Ok(())
426    }
427
428    /// Read repeated packed `double`
429    pub fn read_repeated_packed_double_into(&mut self, target: &mut Vec<f64>) -> crate::Result<()> {
430        self.read_repeated_packed_fixed_into::<ProtobufTypeDouble>(target)
431    }
432
433    /// Read repeated packed `float`
434    pub fn read_repeated_packed_float_into(&mut self, target: &mut Vec<f32>) -> crate::Result<()> {
435        self.read_repeated_packed_fixed_into::<ProtobufTypeFloat>(target)
436    }
437
438    /// Read repeated packed `int64`
439    pub fn read_repeated_packed_int64_into(&mut self, target: &mut Vec<i64>) -> crate::Result<()> {
440        self.read_repeated_packed_into::<ProtobufTypeInt64>(target)
441    }
442
443    /// Read repeated packed `int32`
444    pub fn read_repeated_packed_int32_into(&mut self, target: &mut Vec<i32>) -> crate::Result<()> {
445        self.read_repeated_packed_into::<ProtobufTypeInt32>(target)
446    }
447
448    /// Read repeated packed `uint64`
449    pub fn read_repeated_packed_uint64_into(&mut self, target: &mut Vec<u64>) -> crate::Result<()> {
450        self.read_repeated_packed_into::<ProtobufTypeUint64>(target)
451    }
452
453    /// Read repeated packed `uint32`
454    pub fn read_repeated_packed_uint32_into(&mut self, target: &mut Vec<u32>) -> crate::Result<()> {
455        self.read_repeated_packed_into::<ProtobufTypeUint32>(target)
456    }
457
458    /// Read repeated packed `sint64`
459    pub fn read_repeated_packed_sint64_into(&mut self, target: &mut Vec<i64>) -> crate::Result<()> {
460        self.read_repeated_packed_into::<ProtobufTypeSint64>(target)
461    }
462
463    /// Read repeated packed `sint32`
464    pub fn read_repeated_packed_sint32_into(&mut self, target: &mut Vec<i32>) -> crate::Result<()> {
465        self.read_repeated_packed_into::<ProtobufTypeSint32>(target)
466    }
467
468    /// Read repeated packed `fixed64`
469    pub fn read_repeated_packed_fixed64_into(
470        &mut self,
471        target: &mut Vec<u64>,
472    ) -> crate::Result<()> {
473        self.read_repeated_packed_fixed_into::<ProtobufTypeFixed64>(target)
474    }
475
476    /// Read repeated packed `fixed32`
477    pub fn read_repeated_packed_fixed32_into(
478        &mut self,
479        target: &mut Vec<u32>,
480    ) -> crate::Result<()> {
481        self.read_repeated_packed_fixed_into::<ProtobufTypeFixed32>(target)
482    }
483
484    /// Read repeated packed `sfixed64`
485    pub fn read_repeated_packed_sfixed64_into(
486        &mut self,
487        target: &mut Vec<i64>,
488    ) -> crate::Result<()> {
489        self.read_repeated_packed_fixed_into::<ProtobufTypeSfixed64>(target)
490    }
491
492    /// Read repeated packed `sfixed32`
493    pub fn read_repeated_packed_sfixed32_into(
494        &mut self,
495        target: &mut Vec<i32>,
496    ) -> crate::Result<()> {
497        self.read_repeated_packed_fixed_into::<ProtobufTypeSfixed32>(target)
498    }
499
500    /// Read repeated packed `bool`
501    pub fn read_repeated_packed_bool_into(&mut self, target: &mut Vec<bool>) -> crate::Result<()> {
502        self.read_repeated_packed_into::<ProtobufTypeBool>(target)
503    }
504
505    /// Read repeated packed enum values into the vector.
506    pub(crate) fn read_repeated_packed_enum_values_into(
507        &mut self,
508        target: &mut Vec<i32>,
509    ) -> crate::Result<()> {
510        self.read_repeated_packed_into::<ProtobufTypeInt32>(target)
511    }
512
513    fn skip_group(&mut self) -> crate::Result<()> {
514        self.incr_recursion()?;
515        let ret = self.skip_group_no_depth_check();
516        self.decr_recursion();
517        ret
518    }
519
520    fn skip_group_no_depth_check(&mut self) -> crate::Result<()> {
521        while !self.eof()? {
522            let wire_type = self.read_tag_unpack()?.1;
523            if wire_type == WireType::EndGroup {
524                break;
525            }
526            self.skip_field(wire_type)?;
527        }
528        Ok(())
529    }
530
531    /// Read `UnknownValue`
532    pub fn read_unknown(&mut self, wire_type: WireType) -> crate::Result<UnknownValue> {
533        match wire_type {
534            WireType::Varint => self.read_raw_varint64().map(|v| UnknownValue::Varint(v)),
535            WireType::Fixed64 => self.read_fixed64().map(|v| UnknownValue::Fixed64(v)),
536            WireType::Fixed32 => self.read_fixed32().map(|v| UnknownValue::Fixed32(v)),
537            WireType::LengthDelimited => {
538                let len = self.read_raw_varint32()?;
539                self.read_raw_bytes(len)
540                    .map(|v| UnknownValue::LengthDelimited(v))
541            }
542            WireType::StartGroup => {
543                self.skip_group()?;
544                // We do not support groups, so just return something.
545                Ok(UnknownValue::LengthDelimited(Vec::new()))
546            }
547            WireType::EndGroup => {
548                Err(ProtobufError::WireError(WireError::UnexpectedWireType(wire_type)).into())
549            }
550        }
551    }
552
553    /// Skip field.
554    pub fn skip_field(&mut self, wire_type: WireType) -> crate::Result<()> {
555        match wire_type {
556            WireType::Varint => self.read_raw_varint64().map(|_| ()),
557            WireType::Fixed64 => self.read_fixed64().map(|_| ()),
558            WireType::Fixed32 => self.read_fixed32().map(|_| ()),
559            WireType::LengthDelimited => {
560                let len = self.read_raw_varint32()?;
561                self.skip_raw_bytes(len)
562            }
563            WireType::StartGroup => self.skip_group(),
564            WireType::EndGroup => {
565                Err(ProtobufError::WireError(WireError::UnexpectedWireType(wire_type)).into())
566            }
567        }
568    }
569
570    /// Read raw bytes into the supplied vector.  The vector will be resized as needed and
571    /// overwritten.
572    pub fn read_raw_bytes_into(&mut self, count: u32, target: &mut Vec<u8>) -> crate::Result<()> {
573        self.source.read_exact_to_vec(count as usize, target)
574    }
575
576    /// Read exact number of bytes
577    pub fn read_raw_bytes(&mut self, count: u32) -> crate::Result<Vec<u8>> {
578        let mut r = Vec::new();
579        self.read_raw_bytes_into(count, &mut r)?;
580        Ok(r)
581    }
582
583    /// Skip exact number of bytes
584    pub fn skip_raw_bytes(&mut self, count: u32) -> crate::Result<()> {
585        self.source.skip_bytes(count)
586    }
587
588    /// Read `bytes` field, length delimited
589    pub fn read_bytes(&mut self) -> crate::Result<Vec<u8>> {
590        let mut r = Vec::new();
591        self.read_bytes_into(&mut r)?;
592        Ok(r)
593    }
594
595    /// Read `bytes` field, length delimited
596    #[cfg(feature = "bytes")]
597    pub fn read_tokio_bytes(&mut self) -> crate::Result<Bytes> {
598        let len = self.read_raw_varint32()?;
599        self.read_raw_tokio_bytes(len as usize)
600    }
601
602    /// Read `string` field, length delimited
603    #[cfg(feature = "bytes")]
604    pub fn read_tokio_chars(&mut self) -> crate::Result<Chars> {
605        let bytes = self.read_tokio_bytes()?;
606        Ok(Chars::from_bytes(bytes).map_err(ProtobufError::Utf8)?)
607    }
608
609    /// Read `bytes` field, length delimited
610    pub fn read_bytes_into(&mut self, target: &mut Vec<u8>) -> crate::Result<()> {
611        let len = self.read_raw_varint32()?;
612        self.read_raw_bytes_into(len, target)?;
613        Ok(())
614    }
615
616    /// Read `string` field, length delimited
617    pub fn read_string(&mut self) -> crate::Result<String> {
618        let mut r = String::new();
619        self.read_string_into(&mut r)?;
620        Ok(r)
621    }
622
623    /// Read `string` field, length delimited
624    pub fn read_string_into(&mut self, target: &mut String) -> crate::Result<()> {
625        target.clear();
626        // take target's buffer
627        let mut vec = mem::replace(target, String::new()).into_bytes();
628        self.read_bytes_into(&mut vec)?;
629
630        let s = match String::from_utf8(vec) {
631            Ok(t) => t,
632            Err(_) => return Err(ProtobufError::WireError(WireError::Utf8Error).into()),
633        };
634        *target = s;
635        Ok(())
636    }
637
638    /// Read message, do not check if message is initialized
639    pub fn merge_message<M: Message>(&mut self, message: &mut M) -> crate::Result<()> {
640        self.incr_recursion()?;
641        let ret = self.merge_message_no_depth_check(message);
642        self.decr_recursion();
643        ret
644    }
645
646    fn merge_message_no_depth_check<M: Message>(&mut self, message: &mut M) -> crate::Result<()> {
647        let len = self.read_raw_varint64()?;
648        let old_limit = self.push_limit(len)?;
649        message.merge_from(self)?;
650        self.pop_limit(old_limit);
651        Ok(())
652    }
653
654    /// Like `merge_message`, but for dynamic messages.
655    pub fn merge_message_dyn(&mut self, message: &mut dyn MessageDyn) -> crate::Result<()> {
656        let len = self.read_raw_varint64()?;
657        let old_limit = self.push_limit(len)?;
658        message.merge_from_dyn(self)?;
659        self.pop_limit(old_limit);
660        Ok(())
661    }
662
663    /// Read message
664    pub fn read_message<M: Message>(&mut self) -> crate::Result<M> {
665        let mut r: M = Message::new();
666        self.merge_message(&mut r)?;
667        r.check_initialized()?;
668        Ok(r)
669    }
670
671    /// Read message.
672    pub fn read_message_dyn(
673        &mut self,
674        descriptor: &MessageDescriptor,
675    ) -> crate::Result<Box<dyn MessageDyn>> {
676        let mut r = descriptor.new_instance();
677        self.merge_message_dyn(&mut *r)?;
678        r.check_initialized_dyn()?;
679        Ok(r)
680    }
681}
682
683impl<'a> Read for CodedInputStream<'a> {
684    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
685        self.source.read(buf).map_err(Into::into)
686    }
687}
688
689impl<'a> BufRead for CodedInputStream<'a> {
690    fn fill_buf(&mut self) -> io::Result<&[u8]> {
691        self.source.fill_buf().map_err(Into::into)
692    }
693
694    fn consume(&mut self, amt: usize) {
695        self.source.consume(amt)
696    }
697}
698
699#[cfg(test)]
700mod test {
701
702    use std::fmt::Debug;
703    use std::io;
704    use std::io::BufRead;
705    use std::io::Read;
706
707    use super::CodedInputStream;
708    use super::READ_RAW_BYTES_MAX_ALLOC;
709    use crate::error::ProtobufError;
710    use crate::hex::decode_hex;
711    use crate::wire_format::Tag;
712    use crate::wire_format::WireType;
713    use crate::CodedOutputStream;
714
715    fn test_read_partial<F>(hex: &str, mut callback: F)
716    where
717        F: FnMut(&mut CodedInputStream),
718    {
719        let d = decode_hex(hex);
720        // Test with buffered reader.
721        {
722            let mut reader = io::Cursor::new(&d);
723            let mut is = CodedInputStream::from_buf_read(&mut reader as &mut dyn BufRead);
724            assert_eq!(0, is.pos());
725            callback(&mut is);
726        }
727        // Test from bytes.
728        {
729            let mut is = CodedInputStream::from_bytes(&d);
730            assert_eq!(0, is.pos());
731            callback(&mut is);
732        }
733    }
734
735    fn test_read<F>(hex: &str, mut callback: F)
736    where
737        F: FnMut(&mut CodedInputStream),
738    {
739        let len = decode_hex(hex).len();
740        test_read_partial(hex, |reader| {
741            callback(reader);
742            assert!(reader.eof().expect("eof"));
743            assert_eq!(len as u64, reader.pos());
744        });
745    }
746
747    fn test_read_v<F, V>(hex: &str, v: V, mut callback: F)
748    where
749        F: FnMut(&mut CodedInputStream) -> crate::Result<V>,
750        V: PartialEq + Debug,
751    {
752        test_read(hex, |reader| {
753            assert_eq!(v, callback(reader).unwrap());
754        });
755    }
756
757    #[test]
758    fn test_input_stream_read_raw_byte() {
759        test_read("17", |is| {
760            assert_eq!(23, is.read_raw_byte().unwrap());
761        });
762    }
763
764    #[test]
765    fn test_input_stream_read_raw_varint() {
766        test_read_v("07", 7, |reader| reader.read_raw_varint32());
767        test_read_v("07", 7, |reader| reader.read_raw_varint64());
768
769        test_read_v("96 01", 150, |reader| reader.read_raw_varint32());
770        test_read_v("96 01", 150, |reader| reader.read_raw_varint64());
771
772        test_read_v(
773            "ff ff ff ff ff ff ff ff ff 01",
774            0xffffffffffffffff,
775            |reader| reader.read_raw_varint64(),
776        );
777
778        test_read_v("ff ff ff ff 0f", 0xffffffff, |reader| {
779            reader.read_raw_varint32()
780        });
781        test_read_v("ff ff ff ff 0f", 0xffffffff, |reader| {
782            reader.read_raw_varint64()
783        });
784    }
785
786    #[test]
787    fn test_input_stream_read_raw_varint_out_of_range() {
788        test_read_partial("ff ff ff ff ff ff ff ff ff 02", |is| {
789            assert!(is.read_raw_varint64().is_err());
790        });
791        test_read_partial("ff ff ff ff ff ff ff ff ff 02", |is| {
792            assert!(is.read_raw_varint32().is_err());
793        });
794    }
795
796    #[test]
797    fn test_input_stream_read_raw_varint_too_long() {
798        // varint cannot have length > 10
799        test_read_partial("ff ff ff ff ff ff ff ff ff ff 01", |reader| {
800            let error = reader.read_raw_varint64().unwrap_err().0;
801            match *error {
802                ProtobufError::WireError(..) => (),
803                _ => panic!(),
804            }
805        });
806        test_read_partial("ff ff ff ff ff ff ff ff ff ff 01", |reader| {
807            let error = reader.read_raw_varint32().unwrap_err().0;
808            match *error {
809                ProtobufError::WireError(..) => (),
810                _ => panic!(),
811            }
812        });
813    }
814
815    #[test]
816    fn test_input_stream_read_raw_varint_unexpected_eof() {
817        test_read_partial("96 97", |reader| {
818            let error = reader.read_raw_varint32().unwrap_err().0;
819            match *error {
820                ProtobufError::WireError(..) => (),
821                _ => panic!(),
822            }
823        });
824    }
825
826    #[test]
827    fn test_input_stream_read_raw_varint_pos() {
828        test_read_partial("95 01 98", |reader| {
829            assert_eq!(149, reader.read_raw_varint32().unwrap());
830            assert_eq!(2, reader.pos());
831        });
832    }
833
834    #[test]
835    fn test_input_stream_read_int32() {
836        test_read_v("02", 2, |reader| reader.read_int32());
837    }
838
839    #[test]
840    fn test_input_stream_read_float() {
841        test_read_v("95 73 13 61", 17e19, |is| is.read_float());
842    }
843
844    #[test]
845    fn test_input_stream_read_double() {
846        test_read_v("40 d5 ab 68 b3 07 3d 46", 23e29, |is| is.read_double());
847    }
848
849    #[test]
850    fn test_input_stream_skip_raw_bytes() {
851        test_read("", |reader| {
852            reader.skip_raw_bytes(0).unwrap();
853        });
854        test_read("aa bb", |reader| {
855            reader.skip_raw_bytes(2).unwrap();
856        });
857        test_read("aa bb cc dd ee ff", |reader| {
858            reader.skip_raw_bytes(6).unwrap();
859        });
860    }
861
862    #[test]
863    fn test_input_stream_read_raw_bytes() {
864        test_read("", |reader| {
865            assert_eq!(
866                Vec::from(&b""[..]),
867                reader.read_raw_bytes(0).expect("read_raw_bytes")
868            );
869        })
870    }
871
872    #[test]
873    fn test_input_stream_limits() {
874        test_read("aa bb cc", |is| {
875            let old_limit = is.push_limit(1).unwrap();
876            assert_eq!(1, is.bytes_until_limit());
877            let r1 = is.read_raw_bytes(1).unwrap();
878            assert_eq!(&[0xaa as u8], &r1[..]);
879            is.pop_limit(old_limit);
880            let r2 = is.read_raw_bytes(2).unwrap();
881            assert_eq!(&[0xbb as u8, 0xcc], &r2[..]);
882        });
883    }
884
885    #[test]
886    fn test_input_stream_io_read() {
887        test_read("aa bb cc", |is| {
888            let mut buf = [0; 3];
889            assert_eq!(Read::read(is, &mut buf).expect("io::Read"), 3);
890            assert_eq!(buf, [0xaa, 0xbb, 0xcc]);
891        });
892    }
893
894    #[test]
895    fn test_input_stream_io_bufread() {
896        test_read("aa bb cc", |is| {
897            assert_eq!(
898                BufRead::fill_buf(is).expect("io::BufRead::fill_buf"),
899                &[0xaa, 0xbb, 0xcc]
900            );
901            BufRead::consume(is, 3);
902        });
903    }
904
905    #[test]
906    #[cfg_attr(miri, ignore)] // Miri is too slow for this test.
907    fn test_input_stream_read_raw_bytes_into_huge() {
908        let mut v = Vec::new();
909        for i in 0..READ_RAW_BYTES_MAX_ALLOC + 1000 {
910            v.push((i % 10) as u8);
911        }
912
913        let mut slice: &[u8] = v.as_slice();
914
915        let mut is = CodedInputStream::new(&mut slice);
916
917        let mut buf = Vec::new();
918
919        is.read_raw_bytes_into(READ_RAW_BYTES_MAX_ALLOC as u32 + 10, &mut buf)
920            .expect("read");
921
922        assert_eq!(READ_RAW_BYTES_MAX_ALLOC + 10, buf.len());
923
924        buf.clear();
925
926        is.read_raw_bytes_into(1000 - 10, &mut buf).expect("read");
927
928        assert_eq!(1000 - 10, buf.len());
929
930        assert!(is.eof().expect("eof"));
931    }
932
933    // Copy of this test: https://tinyurl.com/34hfavtz
934    #[test]
935    fn test_skip_group() {
936        // Create an output stream with a group in:
937        // Field 1: string "field 1"
938        // Field 2: group containing:
939        //   Field 1: fixed int32 value 100
940        //   Field 2: string "ignore me"
941        //   Field 3: nested group containing
942        //      Field 1: fixed int64 value 1000
943        // Field 3: string "field 3"
944
945        let mut vec = Vec::new();
946        let mut os = CodedOutputStream::new(&mut vec);
947        os.write_tag(1, WireType::LengthDelimited).unwrap();
948        os.write_string_no_tag("field 1").unwrap();
949
950        // The outer group...
951        os.write_tag(2, WireType::StartGroup).unwrap();
952        os.write_tag(1, WireType::Fixed32).unwrap();
953        os.write_fixed32_no_tag(100).unwrap();
954        os.write_tag(3, WireType::LengthDelimited).unwrap();
955        os.write_string_no_tag("ignore me").unwrap();
956        // The nested group...
957        os.write_tag(3, WireType::StartGroup).unwrap();
958        os.write_tag(1, WireType::Fixed64).unwrap();
959        os.write_fixed64_no_tag(1000).unwrap();
960        // Note: Not sure the field number is relevant for end group...
961        os.write_tag(3, WireType::EndGroup).unwrap();
962
963        // End the outer group
964        os.write_tag(2, WireType::EndGroup).unwrap();
965
966        os.write_tag(3, WireType::LengthDelimited).unwrap();
967        os.write_string_no_tag("field 3").unwrap();
968        os.flush().unwrap();
969        drop(os);
970
971        let mut input = CodedInputStream::from_bytes(&vec);
972        // Now act like a generated client
973        assert_eq!(
974            Tag::make(1, WireType::LengthDelimited),
975            input.read_tag().unwrap()
976        );
977        assert_eq!("field 1", &input.read_string().unwrap());
978        assert_eq!(
979            Tag::make(2, WireType::StartGroup),
980            input.read_tag().unwrap()
981        );
982        input.skip_field(WireType::StartGroup).unwrap();
983        assert_eq!(
984            Tag::make(3, WireType::LengthDelimited),
985            input.read_tag().unwrap()
986        );
987        assert_eq!("field 3", input.read_string().unwrap());
988    }
989
990    #[test]
991    fn test_shallow_nested_unknown_groups() {
992        // Test skip_group() succeeds on a start group tag 50 times
993        // followed by end group tag 50 times. We should be able to
994        // successfully skip the outermost group.
995        let mut vec = Vec::new();
996        let mut os = CodedOutputStream::new(&mut vec);
997        for _ in 0..50 {
998            os.write_tag(1, WireType::StartGroup).unwrap();
999        }
1000        for _ in 0..50 {
1001            os.write_tag(1, WireType::EndGroup).unwrap();
1002        }
1003        drop(os);
1004
1005        let mut input = CodedInputStream::from_bytes(&vec);
1006        assert!(input.skip_group().is_ok());
1007    }
1008
1009    #[test]
1010    fn test_deeply_nested_unknown_groups() {
1011        // Create an output stream that has groups nested recursively 1000
1012        // deep, and try to skip the group.
1013        // This should fail the default depth limit of 100 which ensures we
1014        // don't blow the stack on adversial input.
1015        let mut vec = Vec::new();
1016        let mut os = CodedOutputStream::new(&mut vec);
1017        for _ in 0..1000 {
1018            os.write_tag(1, WireType::StartGroup).unwrap();
1019        }
1020        for _ in 0..1000 {
1021            os.write_tag(1, WireType::EndGroup).unwrap();
1022        }
1023        drop(os);
1024
1025        let mut input = CodedInputStream::from_bytes(&vec);
1026        assert!(input
1027            .skip_group()
1028            .unwrap_err()
1029            .to_string()
1030            .contains("Over recursion limit"));
1031    }
1032}