// protobuf/coded_input_stream/buf_read_iter.rs

use std::cmp;
use std::io::BufRead;
use std::io::BufReader;
use std::io::Read;
use std::mem;
use std::mem::MaybeUninit;

#[cfg(feature = "bytes")]
use bytes::buf::UninitSlice;
#[cfg(feature = "bytes")]
use bytes::BufMut;
#[cfg(feature = "bytes")]
use bytes::Bytes;
#[cfg(feature = "bytes")]
use bytes::BytesMut;

use crate::coded_input_stream::buf_read_or_reader::BufReadOrReader;
use crate::coded_input_stream::input_buf::InputBuf;
use crate::coded_input_stream::input_source::InputSource;
use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC;
use crate::error::ProtobufError;
use crate::error::WireError;
// Capacity, in bytes, of the `BufReader` that `BufReadIter::from_read`
// wraps around a plain `Read` source.
const INPUT_STREAM_BUFFER_SIZE: usize = 4 * 1024;

// Sentinel stored in `BufReadIter::limit` while no limit is in effect.
const NO_LIMIT: u64 = u64::MAX;
30/// Dangerous implementation of `BufRead`.
31///
32/// Unsafe wrapper around BufRead which assumes that `BufRead` buf is
33/// not moved when `BufRead` is moved.
34///
35/// This assumption is generally incorrect, however, in practice
36/// `BufReadIter` is created either from `BufRead` reference (which
37/// cannot  be moved, because it is locked by `CodedInputStream`) or from
38/// `BufReader` which does not move its buffer (we know that from
39/// inspecting rust standard library).
40///
41/// It is important for `CodedInputStream` performance that small reads
42/// (e. g. 4 bytes reads) do not involve virtual calls or switches.
43/// This is achievable with `BufReadIter`.
44#[derive(Debug)]
45pub(crate) struct BufReadIter<'a> {
46    input_source: InputSource<'a>,
47    buf: InputBuf<'a>,
48    pos_of_buf_start: u64,
49    limit: u64,
50}
51
52impl<'a> Drop for BufReadIter<'a> {
53    fn drop(&mut self) {
54        match self.input_source {
55            InputSource::Read(ref mut buf_read) => buf_read.consume(self.buf.pos_within_buf()),
56            _ => {}
57        }
58    }
59}
60
impl<'a> BufReadIter<'a> {
    /// Creates an iterator over a plain `Read`.
    ///
    /// The reader is wrapped in a `BufReader` with an
    /// `INPUT_STREAM_BUFFER_SIZE`-byte buffer, whose contents become the
    /// windows stored in `buf`.
    pub(crate) fn from_read(read: &'a mut dyn Read) -> BufReadIter<'a> {
        BufReadIter {
            input_source: InputSource::Read(BufReadOrReader::BufReader(BufReader::with_capacity(
                INPUT_STREAM_BUFFER_SIZE,
                read,
            ))),
            // Nothing is buffered until the first `fill_buf_slow`.
            buf: InputBuf::empty(),
            pos_of_buf_start: 0,
            limit: NO_LIMIT,
        }
    }

    /// Creates an iterator over a caller-supplied `BufRead`, reading directly
    /// from its internal buffer.
    pub(crate) fn from_buf_read(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> {
        BufReadIter {
            input_source: InputSource::Read(BufReadOrReader::BufRead(buf_read)),
            buf: InputBuf::empty(),
            pos_of_buf_start: 0,
            limit: NO_LIMIT,
        }
    }

    /// Creates an iterator over an in-memory byte slice.
    ///
    /// The whole slice is the buffer from the start, so it is never refilled.
    pub(crate) fn from_byte_slice(bytes: &'a [u8]) -> BufReadIter<'a> {
        BufReadIter {
            input_source: InputSource::Slice(bytes),
            buf: InputBuf::from_bytes(bytes),
            pos_of_buf_start: 0,
            limit: NO_LIMIT,
        }
    }

    /// Creates an iterator over a `Bytes` value.
    ///
    /// Like `from_byte_slice`, but `read_exact_bytes` can then return
    /// zero-copy sub-slices of `bytes` instead of copying.
    #[cfg(feature = "bytes")]
    pub(crate) fn from_bytes(bytes: &'a Bytes) -> BufReadIter<'a> {
        BufReadIter {
            input_source: InputSource::Bytes(bytes),
            buf: InputBuf::from_bytes(&bytes),
            pos_of_buf_start: 0,
            limit: NO_LIMIT,
        }
    }

    /// Checks internal invariants: the current position never exceeds
    /// `limit` (debug builds only), plus whatever `InputBuf::assertions`
    /// verifies.
    #[inline]
    fn assertions(&self) {
        debug_assert!(self.pos() <= self.limit);
        self.buf.assertions();
    }

    /// Absolute number of bytes consumed since this iterator was created.
    #[inline(always)]
    pub(crate) fn pos(&self) -> u64 {
        self.pos_of_buf_start + self.buf.pos_within_buf() as u64
    }

    /// Recompute `limit_within_buf` after update of `limit`.
    ///
    /// Translates the absolute `limit` into an offset relative to the start
    /// of `buf` and hands it to `InputBuf::update_limit`.
    ///
    /// # Panics
    ///
    /// If `limit` lies before the start of the current buffer.
    #[inline]
    fn update_limit_within_buf(&mut self) {
        assert!(self.limit >= self.pos_of_buf_start);
        self.buf.update_limit(self.limit - self.pos_of_buf_start);
        self.assertions();
    }

    /// Restricts reading to the next `limit` bytes, counted from the current
    /// position.
    ///
    /// Returns the previous absolute limit, which is meant to be passed back
    /// to `pop_limit` later.
    ///
    /// # Errors
    ///
    /// * `LimitOverflow` if `pos() + limit` overflows `u64`.
    /// * `LimitIncrease` if the new limit would extend past the current one
    ///   (limits can only nest inward).
    pub(crate) fn push_limit(&mut self, limit: u64) -> crate::Result<u64> {
        let new_limit = match self.pos().checked_add(limit) {
            Some(new_limit) => new_limit,
            None => return Err(ProtobufError::WireError(WireError::LimitOverflow).into()),
        };

        if new_limit > self.limit {
            return Err(ProtobufError::WireError(WireError::LimitIncrease).into());
        }

        let prev_limit = mem::replace(&mut self.limit, new_limit);

        self.update_limit_within_buf();

        Ok(prev_limit)
    }

    /// Restores an outer limit, typically one returned by `push_limit`.
    ///
    /// # Panics
    ///
    /// If `limit` is tighter than the limit currently in effect.
    #[inline]
    pub(crate) fn pop_limit(&mut self, limit: u64) {
        assert!(limit >= self.limit);

        self.limit = limit;

        self.update_limit_within_buf();
    }

    /// Unread bytes currently buffered. Never triggers a refill.
    #[inline(always)]
    pub(crate) fn remaining_in_buf(&self) -> &[u8] {
        self.buf.remaining_in_buf()
    }

    /// Marks `amt` buffered bytes as read.
    ///
    /// NOTE(review): presumably requires `amt <= remaining_in_buf_len()`;
    /// `InputBuf::consume` defines what happens otherwise.
    #[inline]
    pub(crate) fn consume(&mut self, amt: usize) {
        self.buf.consume(amt);
    }

    /// Number of unread bytes currently buffered.
    #[inline(always)]
    pub(crate) fn remaining_in_buf_len(&self) -> usize {
        self.remaining_in_buf().len()
    }

    /// Bytes left before the current limit, or `NO_LIMIT` when no limit is
    /// in effect.
    #[inline(always)]
    pub(crate) fn bytes_until_limit(&self) -> u64 {
        if self.limit == NO_LIMIT {
            NO_LIMIT
        } else {
            self.limit - self.pos()
        }
    }

    /// Returns `true` when nothing more can be read, either because the
    /// limit has been reached or because the source is exhausted.
    ///
    /// May refill the buffer from the source.
    #[inline(always)]
    pub(crate) fn eof(&mut self) -> crate::Result<bool> {
        if self.remaining_in_buf_len() != 0 {
            Ok(false)
        } else {
            Ok(self.fill_buf()?.is_empty())
        }
    }

    /// Out-of-line path of `read_byte`: refill the buffer, then retry once.
    ///
    /// # Errors
    ///
    /// `UnexpectedEof` if no byte is available after the refill (end of
    /// input or limit reached).
    fn read_byte_slow(&mut self) -> crate::Result<u8> {
        self.fill_buf_slow()?;

        if let Some(b) = self.buf.read_byte() {
            return Ok(b);
        }

        Err(WireError::UnexpectedEof.into())
    }

    /// Reads one byte, serving it straight from the buffer when possible.
    #[inline(always)]
    pub(crate) fn read_byte(&mut self) -> crate::Result<u8> {
        if let Some(b) = self.buf.read_byte() {
            return Ok(b);
        }

        self.read_byte_slow()
    }

    /// Reads exactly `len` bytes as a `Bytes` value.
    ///
    /// For a `Bytes` source this returns a zero-copy slice of the input.
    /// Every other source copies the data into a new buffer.
    ///
    /// # Errors
    ///
    /// `UnexpectedEof` for a `Bytes` source holding fewer than `len` bytes
    /// before the limit. Otherwise, whatever `read_exact_to_vec` or
    /// `read_exact` report.
    #[cfg(feature = "bytes")]
    pub(crate) fn read_exact_bytes(&mut self, len: usize) -> crate::Result<Bytes> {
        if let InputSource::Bytes(bytes) = self.input_source {
            if len > self.remaining_in_buf_len() {
                return Err(ProtobufError::WireError(WireError::UnexpectedEof).into());
            }
            // For a `Bytes` source `buf` is the whole `Bytes` and is never
            // refilled, so positions within `buf` are offsets into `bytes`.
            let end = self.buf.pos_within_buf() + len;

            let r = bytes.slice(self.buf.pos_within_buf()..end);
            self.buf.consume(len);
            Ok(r)
        } else {
            if len >= READ_RAW_BYTES_MAX_ALLOC {
                // We cannot trust `len` because protobuf message could be malformed.
                // Reading should not result in OOM when allocating a buffer.
                let mut v = Vec::new();
                self.read_exact_to_vec(len, &mut v)?;
                Ok(Bytes::from(v))
            } else {
                let mut r = BytesMut::with_capacity(len);
                // SAFETY: `with_capacity(len)` leaves room for `len` bytes in
                // `chunk_mut()`. `advance_mut(len)` runs only after
                // `read_exact` returned `Ok`, which is expected to have
                // initialised all `len` bytes. On error, `r` is dropped
                // without exposing uninitialised memory.
                unsafe {
                    let buf = Self::uninit_slice_as_mut_slice(&mut r.chunk_mut()[..len]);
                    self.read_exact(buf)?;
                    r.advance_mut(len);
                }
                Ok(r.freeze())
            }
        }
    }

    /// Views a `bytes::UninitSlice` as a `[MaybeUninit<u8>]` of the same
    /// pointer and length.
    ///
    /// # Safety
    ///
    /// NOTE(review): the cast only reinterprets pointer and length. Callers
    /// should write only initialised bytes through the result; confirm the
    /// exact contract against the `bytes::UninitSlice` documentation.
    #[cfg(feature = "bytes")]
    unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>] {
        use std::slice;
        slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut MaybeUninit<u8>, slice.len())
    }

    /// Copies up to `buf.len()` bytes into `buf`, refilling the internal
    /// buffer at most once, and returns how many bytes were copied.
    ///
    /// Returns 0 when EOF or limit reached.
    pub(crate) fn read(&mut self, buf: &mut [u8]) -> crate::Result<usize> {
        let rem = self.fill_buf()?;

        let len = cmp::min(rem.len(), buf.len());
        buf[..len].copy_from_slice(&rem[..len]);
        self.buf.consume(len);
        Ok(len)
    }

    /// Returns the consumed prefix of `buf` to the underlying `BufRead` and
    /// resets `buf` to empty. `pos_of_buf_start` advances accordingly, so
    /// `pos()` is unchanged.
    ///
    /// Unconsumed bytes of the old window stay inside the `BufRead` and are
    /// returned again by its next read.
    ///
    /// # Errors
    ///
    /// `UnexpectedEof` for non-`Read` sources (byte slice, `Bytes`). Their
    /// `buf` already holds the entire input, so nothing lies beyond it.
    fn consume_buf(&mut self) -> crate::Result<()> {
        match &mut self.input_source {
            InputSource::Read(read) => {
                read.consume(self.buf.pos_within_buf());
                self.pos_of_buf_start += self.buf.pos_within_buf() as u64;
                self.buf = InputBuf::empty();
                self.assertions();
                Ok(())
            }
            _ => Err(WireError::UnexpectedEof.into()),
        }
    }

    /// Read at most `max` bytes, appending them to `vec`.
    ///
    /// Returns 0 when EOF or limit reached.
    fn read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> crate::Result<usize> {
        let rem = self.fill_buf()?;

        let len = cmp::min(rem.len(), max);
        vec.extend_from_slice(&rem[..len]);
        self.buf.consume(len);
        Ok(len)
    }

    /// Slow path of `read_exact`, used when the current window is too short.
    ///
    /// Reads directly from the underlying `BufRead`, bypassing `buf`.
    ///
    /// # Errors
    ///
    /// `UnexpectedEof` if the limit leaves fewer than `buf.len()` bytes, or
    /// if the source is not `Read`-backed (from `consume_buf`). Otherwise,
    /// any error from `read_exact_uninit`.
    fn read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()> {
        if self.bytes_until_limit() < buf.len() as u64 {
            return Err(ProtobufError::WireError(WireError::UnexpectedEof).into());
        }

        // Hand the consumed prefix back. The unread tail of the window is
        // still inside the `BufRead`, so the read below starts exactly at
        // `pos()`.
        self.consume_buf()?;

        match &mut self.input_source {
            InputSource::Read(buf_read) => {
                buf_read.read_exact_uninit(buf)?;
                self.pos_of_buf_start += buf.len() as u64;
                self.assertions();
                Ok(())
            }
            // `consume_buf` already returned `Err` for every other source.
            _ => unreachable!(),
        }
    }

    /// Fills `buf` completely.
    ///
    /// Takes the fast path when the current window already holds enough
    /// bytes; otherwise falls back to `read_exact_slow`.
    #[inline]
    pub(crate) fn read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()> {
        if self.remaining_in_buf_len() >= buf.len() {
            self.buf.read_bytes(buf);
            return Ok(());
        }

        self.read_exact_slow(buf)
    }

    /// Read raw bytes into the supplied vector.  The vector will be resized as needed and
    /// overwritten.
    ///
    /// # Errors
    ///
    /// `TruncatedMessage` if `count` exceeds the bytes left before the limit,
    /// or if the input ends before `count` bytes arrive on the large-read
    /// path. Other errors (e.g. `UnexpectedEof`) may come from the
    /// underlying reads.
    pub(crate) fn read_exact_to_vec(
        &mut self,
        count: usize,
        target: &mut Vec<u8>,
    ) -> crate::Result<()> {
        // TODO: also do some limits when reading from unlimited source
        if count as u64 > self.bytes_until_limit() {
            return Err(ProtobufError::WireError(WireError::TruncatedMessage).into());
        }

        target.clear();

        if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() {
            // avoid calling `reserve` on buf with very large buffer: could be a malformed message

            target.reserve(READ_RAW_BYTES_MAX_ALLOC);

            while target.len() < count {
                // Grow capacity in step with the data actually received. Once
                // the outstanding amount is no larger than what has already
                // been read, reserve exactly the rest. Otherwise let `Vec`
                // grow by its usual amortized strategy: `reserve(1)` only
                // reallocates when there is no spare capacity left.
                if count - target.len() <= target.len() {
                    target.reserve_exact(count - target.len());
                } else {
                    target.reserve(1);
                }

                // Read no more than the spare capacity, so `read_to_vec`
                // never reallocates `target`.
                let max = cmp::min(target.capacity() - target.len(), count - target.len());
                let read = self.read_to_vec(target, max)?;
                // Zero means EOF or limit reached before `count` bytes arrived.
                if read == 0 {
                    return Err(ProtobufError::WireError(WireError::TruncatedMessage).into());
                }
            }
        } else {
            target.reserve_exact(count);

            // SAFETY: `target` was cleared and then given room for at least
            // `count` elements, so `spare_capacity_mut()[..count]` is in
            // bounds. `set_len(count)` runs only after `read_exact` returned
            // `Ok`, which is expected to have initialised all `count` bytes.
            unsafe {
                self.read_exact(&mut target.spare_capacity_mut()[..count])?;
                target.set_len(count);
            }
        }

        debug_assert_eq!(count, target.len());

        Ok(())
    }

    /// Advances past `count` bytes without copying them.
    ///
    /// # Errors
    ///
    /// * `TruncatedMessage` if `count` exceeds the bytes left before the
    ///   limit.
    /// * `UnexpectedEof` (from `consume_buf`) for slice/`Bytes` sources that
    ///   do not hold `count` more bytes.
    /// * Any error from the underlying reader's `skip_bytes`.
    pub(crate) fn skip_bytes(&mut self, count: u32) -> crate::Result<()> {
        if count as usize <= self.remaining_in_buf_len() {
            self.buf.consume(count as usize);
            return Ok(());
        }

        if count as u64 > self.bytes_until_limit() {
            return Err(WireError::TruncatedMessage.into());
        }

        // The unread tail of the window stays in the `BufRead`, so skipping
        // `count` bytes from there starts exactly at `pos()`.
        self.consume_buf()?;

        match &mut self.input_source {
            InputSource::Read(read) => {
                read.skip_bytes(count as usize)?;
                self.pos_of_buf_start += count as u64;
                self.assertions();
                Ok(())
            }
            // `consume_buf` already returned `Err` for every other source.
            _ => unreachable!(),
        }
    }

    /// Replaces the current window with the next chunk from a
    /// `Read`/`BufRead` source.
    ///
    /// Does nothing when the limit has been reached, or for slice/`Bytes`
    /// sources, whose buffer already holds all input.
    fn fill_buf_slow(&mut self) -> crate::Result<()> {
        self.assertions();
        if self.limit == self.pos() {
            return Ok(());
        }

        match self.input_source {
            InputSource::Read(..) => {}
            _ => return Ok(()),
        }

        self.consume_buf()?;

        match self.input_source {
            InputSource::Read(ref mut buf_read) => {
                // SAFETY: `from_bytes_ignore_lifetime` detaches the window
                // from the borrow of `buf_read`. This relies on the
                // `BufRead`'s buffer not moving while `self.buf` points into
                // it. Every later read from `buf_read` in this impl
                // (`fill_buf_slow`, `read_exact_slow`, `skip_bytes`) first
                // calls `consume_buf`, which replaces `self.buf` with an
                // empty buffer.
                self.buf = unsafe { InputBuf::from_bytes_ignore_lifetime(buf_read.fill_buf()?) };
                self.update_limit_within_buf();
                Ok(())
            }
            _ => {
                // Non-`Read` sources returned early above.
                unreachable!();
            }
        }
    }

    /// Returns the unread bytes of the current window, refilling it from the
    /// source first if it is empty.
    ///
    /// An empty result means end of input or limit reached. (The window is
    /// presumably clamped to `limit` by `InputBuf::update_limit`; confirm
    /// in `input_buf.rs`.)
    #[inline(always)]
    pub(crate) fn fill_buf(&mut self) -> crate::Result<&[u8]> {
        let rem = self.buf.remaining_in_buf();
        if !rem.is_empty() {
            return Ok(rem);
        }

        if self.limit == self.pos() {
            return Ok(&[]);
        }

        self.fill_buf_slow()?;

        Ok(self.buf.remaining_in_buf())
    }
}
#[cfg(all(test, feature = "bytes"))]
mod test_bytes {
    use std::io::Write;

    use super::*;

    /// Produces exactly `len` bytes of deterministic ASCII digits. It
    /// repeatedly appends the current length in decimal, then truncates to
    /// `len`.
    fn make_long_string(len: usize) -> Vec<u8> {
        let mut out = Vec::new();
        while out.len() < len {
            let written_so_far = out.len();
            write!(out, "{}", written_so_far).expect("unexpected");
        }
        out.truncate(len);
        out
    }

    #[test]
    #[cfg_attr(miri, ignore)] // bytes violates SB, see https://github.com/tokio-rs/bytes/issues/522
    fn read_exact_bytes_from_slice() {
        let source = make_long_string(100);
        let mut iter = BufReadIter::from_byte_slice(&source);
        let head = iter.read_exact_bytes(90).unwrap();
        assert_eq!(&source[..90], &head[..]);
        assert_eq!(source[90], iter.read_byte().expect("read_byte"));
    }

    #[test]
    #[cfg_attr(miri, ignore)] // bytes violates SB, see https://github.com/tokio-rs/bytes/issues/522
    fn read_exact_bytes_from_bytes() {
        let source = Bytes::from(make_long_string(100));
        let mut iter = BufReadIter::from_bytes(&source);
        let head = iter.read_exact_bytes(90).unwrap();
        assert_eq!(&source[..90], &head[..]);
        // A `Bytes` source must be sliced, not copied: same start address.
        assert_eq!(source.as_ptr(), head.as_ptr());
        assert_eq!(source[90], iter.read_byte().expect("read_byte"));
    }
}
#[cfg(test)]
mod test {
    use std::io;

    use super::*;

    /// Reading exactly up to a pushed limit must let `eof()` answer from the
    /// limit alone, without asking the underlying `BufRead` for more data.
    #[test]
    fn eof_at_limit() {
        /// `BufRead` that serves the bytes `0..5` from its first `fill_buf`.
        /// It panics (through the `pos` assertions) if it is refilled or
        /// consumed again after those bytes have been handed back.
        struct Read5ThenPanic {
            pos: usize,
        }

        impl Read for Read5ThenPanic {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                // `BufReadIter` must only use the `BufRead` interface here.
                unreachable!();
            }
        }

        impl BufRead for Read5ThenPanic {
            fn fill_buf(&mut self) -> io::Result<&[u8]> {
                assert_eq!(0, self.pos);
                static ZERO_TO_FIVE: &[u8] = &[0, 1, 2, 3, 4];
                Ok(ZERO_TO_FIVE)
            }

            fn consume(&mut self, amt: usize) {
                if amt == 0 {
                    // `BufReadIter::consume_buf` hands back its still-empty
                    // initial buffer right before the first refill.
                    return;
                }

                // The only non-empty consume is the one from `BufReadIter`'s
                // `Drop`, covering all five bytes at once.
                assert_eq!(0, self.pos);
                assert_eq!(5, amt);
                self.pos += amt;
            }
        }

        let mut read = Read5ThenPanic { pos: 0 };
        let mut buf_read_iter = BufReadIter::from_buf_read(&mut read);
        assert_eq!(0, buf_read_iter.pos());
        // Fail here, with a clear message, rather than later with an
        // unrelated panic inside `Read5ThenPanic`.
        buf_read_iter.push_limit(5).expect("push_limit");
        assert_eq!(0, buf_read_iter.read_byte().expect("read_byte"));
        buf_read_iter
            .read_exact(&mut [MaybeUninit::<u8>::uninit(); 4])
            .expect("read_exact");
        assert_eq!(5, buf_read_iter.pos());
        assert!(buf_read_iter.eof().expect("eof"));
    }
}