protobuf/coded_input_stream/
buf_read_iter.rs
1use std::cmp;
2use std::io::BufRead;
3use std::io::BufReader;
4use std::io::Read;
5use std::mem;
6use std::mem::MaybeUninit;
7
8#[cfg(feature = "bytes")]
9use bytes::buf::UninitSlice;
10#[cfg(feature = "bytes")]
11use bytes::BufMut;
12#[cfg(feature = "bytes")]
13use bytes::Bytes;
14#[cfg(feature = "bytes")]
15use bytes::BytesMut;
16
17use crate::coded_input_stream::buf_read_or_reader::BufReadOrReader;
18use crate::coded_input_stream::input_buf::InputBuf;
19use crate::coded_input_stream::input_source::InputSource;
20use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC;
21use crate::error::ProtobufError;
22use crate::error::WireError;
23
/// Capacity, in bytes, of the `BufReader` that `BufReadIter::from_read`
/// wraps around a plain `Read`.
24const INPUT_STREAM_BUFFER_SIZE: usize = 4096;
27
/// Sentinel stored in `BufReadIter::limit` when no limit has been pushed;
/// `bytes_until_limit` returns it unchanged instead of subtracting the position.
28const NO_LIMIT: u64 = u64::MAX;
29
/// Position- and limit-tracking cursor over an input source.
///
/// The source is either a reader (`InputSource::Read`), a borrowed byte
/// slice, or (with the `bytes` feature) a borrowed `Bytes`. Reads are served
/// from `buf`; for reader sources `buf` is refreshed from the reader's own
/// buffer by `fill_buf_slow`.
30#[derive(Debug)]
45pub(crate) struct BufReadIter<'a> {
    /// Where the bytes ultimately come from.
46 input_source: InputSource<'a>,
    /// Window of bytes that can currently be read without touching the source.
47 buf: InputBuf<'a>,
    /// Absolute stream position of the first byte of `buf`
    /// (`pos()` is this value plus `buf.pos_within_buf()`).
48 pos_of_buf_start: u64,
    /// Absolute position reads must not go past; `NO_LIMIT` when unset.
49 limit: u64,
50}
51
52impl<'a> Drop for BufReadIter<'a> {
53 fn drop(&mut self) {
54 match self.input_source {
55 InputSource::Read(ref mut buf_read) => buf_read.consume(self.buf.pos_within_buf()),
56 _ => {}
57 }
58 }
59}
60
61impl<'a> BufReadIter<'a> {
62 pub(crate) fn from_read(read: &'a mut dyn Read) -> BufReadIter<'a> {
63 BufReadIter {
64 input_source: InputSource::Read(BufReadOrReader::BufReader(BufReader::with_capacity(
65 INPUT_STREAM_BUFFER_SIZE,
66 read,
67 ))),
68 buf: InputBuf::empty(),
69 pos_of_buf_start: 0,
70 limit: NO_LIMIT,
71 }
72 }
73
74 pub(crate) fn from_buf_read(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> {
75 BufReadIter {
76 input_source: InputSource::Read(BufReadOrReader::BufRead(buf_read)),
77 buf: InputBuf::empty(),
78 pos_of_buf_start: 0,
79 limit: NO_LIMIT,
80 }
81 }
82
83 pub(crate) fn from_byte_slice(bytes: &'a [u8]) -> BufReadIter<'a> {
84 BufReadIter {
85 input_source: InputSource::Slice(bytes),
86 buf: InputBuf::from_bytes(bytes),
87 pos_of_buf_start: 0,
88 limit: NO_LIMIT,
89 }
90 }
91
/// Creates an iterator over a borrowed `Bytes`.
///
/// The source is remembered as `InputSource::Bytes`, which lets
/// `read_exact_bytes` hand out slices of the original `Bytes` instead of
/// copying. The iterator starts at position 0 with no limit.
#[cfg(feature = "bytes")]
pub(crate) fn from_bytes(bytes: &'a Bytes) -> BufReadIter<'a> {
    let whole_input = InputBuf::from_bytes(&bytes);
    Self {
        input_source: InputSource::Bytes(bytes),
        buf: whole_input,
        pos_of_buf_start: 0,
        limit: NO_LIMIT,
    }
}
101
102 #[inline]
103 fn assertions(&self) {
104 debug_assert!(self.pos() <= self.limit);
105 self.buf.assertions();
106 }
107
108 #[inline(always)]
109 pub(crate) fn pos(&self) -> u64 {
110 self.pos_of_buf_start + self.buf.pos_within_buf() as u64
111 }
112
113 #[inline]
115 fn update_limit_within_buf(&mut self) {
116 assert!(self.limit >= self.pos_of_buf_start);
117 self.buf.update_limit(self.limit - self.pos_of_buf_start);
118 self.assertions();
119 }
120
121 pub(crate) fn push_limit(&mut self, limit: u64) -> crate::Result<u64> {
122 let new_limit = match self.pos().checked_add(limit) {
123 Some(new_limit) => new_limit,
124 None => return Err(ProtobufError::WireError(WireError::LimitOverflow).into()),
125 };
126
127 if new_limit > self.limit {
128 return Err(ProtobufError::WireError(WireError::LimitIncrease).into());
129 }
130
131 let prev_limit = mem::replace(&mut self.limit, new_limit);
132
133 self.update_limit_within_buf();
134
135 Ok(prev_limit)
136 }
137
138 #[inline]
139 pub(crate) fn pop_limit(&mut self, limit: u64) {
140 assert!(limit >= self.limit);
141
142 self.limit = limit;
143
144 self.update_limit_within_buf();
145 }
146
147 #[inline(always)]
148 pub(crate) fn remaining_in_buf(&self) -> &[u8] {
149 self.buf.remaining_in_buf()
150 }
151
152 #[inline]
153 pub(crate) fn consume(&mut self, amt: usize) {
154 self.buf.consume(amt);
155 }
156
157 #[inline(always)]
158 pub(crate) fn remaining_in_buf_len(&self) -> usize {
159 self.remaining_in_buf().len()
160 }
161
162 #[inline(always)]
163 pub(crate) fn bytes_until_limit(&self) -> u64 {
164 if self.limit == NO_LIMIT {
165 NO_LIMIT
166 } else {
167 self.limit - self.pos()
168 }
169 }
170
171 #[inline(always)]
172 pub(crate) fn eof(&mut self) -> crate::Result<bool> {
173 if self.remaining_in_buf_len() != 0 {
174 Ok(false)
175 } else {
176 Ok(self.fill_buf()?.is_empty())
177 }
178 }
179
180 fn read_byte_slow(&mut self) -> crate::Result<u8> {
181 self.fill_buf_slow()?;
182
183 if let Some(b) = self.buf.read_byte() {
184 return Ok(b);
185 }
186
187 Err(WireError::UnexpectedEof.into())
188 }
189
190 #[inline(always)]
191 pub(crate) fn read_byte(&mut self) -> crate::Result<u8> {
192 if let Some(b) = self.buf.read_byte() {
193 return Ok(b);
194 }
195
196 self.read_byte_slow()
197 }
198
/// Reads exactly `len` bytes and returns them as `Bytes`.
///
/// * For a `Bytes` source the result is a slice of the original `Bytes`.
/// * Otherwise, requests of at least `READ_RAW_BYTES_MAX_ALLOC` bytes go
///   through `read_exact_to_vec`, which grows its buffer incrementally.
/// * Smaller requests are read straight into a `BytesMut` of capacity `len`.
///
/// # Errors
///
/// Returns `WireError::UnexpectedEof` when a `Bytes` source has fewer than
/// `len` bytes left in the buffer. Otherwise it propagates errors from
/// `read_exact_to_vec` / `read_exact`.
#[cfg(feature = "bytes")]
pub(crate) fn read_exact_bytes(&mut self, len: usize) -> crate::Result<Bytes> {
    if let InputSource::Bytes(bytes) = self.input_source {
        if len > self.remaining_in_buf_len() {
            return Err(ProtobufError::WireError(WireError::UnexpectedEof).into());
        }
        let start = self.buf.pos_within_buf();
        let shared = bytes.slice(start..start + len);
        self.buf.consume(len);
        return Ok(shared);
    }

    if len >= READ_RAW_BYTES_MAX_ALLOC {
        let mut collected = Vec::new();
        self.read_exact_to_vec(len, &mut collected)?;
        return Ok(Bytes::from(collected));
    }

    let mut owned = BytesMut::with_capacity(len);
    // SAFETY (review): `owned` was created with capacity `len`, so the first
    // `len` bytes of `chunk_mut()` are in bounds. `advance_mut(len)` is only
    // reached when `read_exact` succeeded; this presumably relies on
    // `read_exact` initializing every byte of the slice on success — confirm
    // against `InputBuf::read_bytes` and `read_exact_uninit`.
    unsafe {
        let spare = Self::uninit_slice_as_mut_slice(&mut owned.chunk_mut()[..len]);
        self.read_exact(spare)?;
        owned.advance_mut(len);
    }
    Ok(owned.freeze())
}
228
/// Reinterprets an `UninitSlice` as a `[MaybeUninit<u8>]` covering the same
/// memory and length.
///
/// # Safety
///
/// NOTE(review): the original states no caller contract. Presumably callers
/// must not de-initialize bytes through the returned slice — confirm against
/// the `bytes` crate's `UninitSlice` documentation.
#[cfg(feature = "bytes")]
unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>] {
    let len = slice.len();
    let base = slice.as_mut_ptr() as *mut MaybeUninit<u8>;
    std::slice::from_raw_parts_mut(base, len)
}
234
235 pub(crate) fn read(&mut self, buf: &mut [u8]) -> crate::Result<usize> {
237 let rem = self.fill_buf()?;
238
239 let len = cmp::min(rem.len(), buf.len());
240 buf[..len].copy_from_slice(&rem[..len]);
241 self.buf.consume(len);
242 Ok(len)
243 }
244
245 fn consume_buf(&mut self) -> crate::Result<()> {
246 match &mut self.input_source {
247 InputSource::Read(read) => {
248 read.consume(self.buf.pos_within_buf());
249 self.pos_of_buf_start += self.buf.pos_within_buf() as u64;
250 self.buf = InputBuf::empty();
251 self.assertions();
252 Ok(())
253 }
254 _ => Err(WireError::UnexpectedEof.into()),
255 }
256 }
257
258 fn read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> crate::Result<usize> {
262 let rem = self.fill_buf()?;
263
264 let len = cmp::min(rem.len(), max);
265 vec.extend_from_slice(&rem[..len]);
266 self.buf.consume(len);
267 Ok(len)
268 }
269
270 fn read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()> {
271 if self.bytes_until_limit() < buf.len() as u64 {
272 return Err(ProtobufError::WireError(WireError::UnexpectedEof).into());
273 }
274
275 self.consume_buf()?;
276
277 match &mut self.input_source {
278 InputSource::Read(buf_read) => {
279 buf_read.read_exact_uninit(buf)?;
280 self.pos_of_buf_start += buf.len() as u64;
281 self.assertions();
282 Ok(())
283 }
284 _ => unreachable!(),
285 }
286 }
287
288 #[inline]
289 pub(crate) fn read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()> {
290 if self.remaining_in_buf_len() >= buf.len() {
291 self.buf.read_bytes(buf);
292 return Ok(());
293 }
294
295 self.read_exact_slow(buf)
296 }
297
298 pub(crate) fn read_exact_to_vec(
301 &mut self,
302 count: usize,
303 target: &mut Vec<u8>,
304 ) -> crate::Result<()> {
305 if count as u64 > self.bytes_until_limit() {
307 return Err(ProtobufError::WireError(WireError::TruncatedMessage).into());
308 }
309
310 target.clear();
311
312 if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() {
313 target.reserve(READ_RAW_BYTES_MAX_ALLOC);
316
317 while target.len() < count {
318 if count - target.len() <= target.len() {
319 target.reserve_exact(count - target.len());
320 } else {
321 target.reserve(1);
322 }
323
324 let max = cmp::min(target.capacity() - target.len(), count - target.len());
325 let read = self.read_to_vec(target, max)?;
326 if read == 0 {
327 return Err(ProtobufError::WireError(WireError::TruncatedMessage).into());
328 }
329 }
330 } else {
331 target.reserve_exact(count);
332
333 unsafe {
334 self.read_exact(&mut target.spare_capacity_mut()[..count])?;
335 target.set_len(count);
336 }
337 }
338
339 debug_assert_eq!(count, target.len());
340
341 Ok(())
342 }
343
344 pub(crate) fn skip_bytes(&mut self, count: u32) -> crate::Result<()> {
345 if count as usize <= self.remaining_in_buf_len() {
346 self.buf.consume(count as usize);
347 return Ok(());
348 }
349
350 if count as u64 > self.bytes_until_limit() {
351 return Err(WireError::TruncatedMessage.into());
352 }
353
354 self.consume_buf()?;
355
356 match &mut self.input_source {
357 InputSource::Read(read) => {
358 read.skip_bytes(count as usize)?;
359 self.pos_of_buf_start += count as u64;
360 self.assertions();
361 Ok(())
362 }
363 _ => unreachable!(),
364 }
365 }
366
367 fn fill_buf_slow(&mut self) -> crate::Result<()> {
368 self.assertions();
369 if self.limit == self.pos() {
370 return Ok(());
371 }
372
373 match self.input_source {
374 InputSource::Read(..) => {}
375 _ => return Ok(()),
376 }
377
378 self.consume_buf()?;
379
380 match self.input_source {
381 InputSource::Read(ref mut buf_read) => {
382 self.buf = unsafe { InputBuf::from_bytes_ignore_lifetime(buf_read.fill_buf()?) };
383 self.update_limit_within_buf();
384 Ok(())
385 }
386 _ => {
387 unreachable!();
388 }
389 }
390 }
391
392 #[inline(always)]
393 pub(crate) fn fill_buf(&mut self) -> crate::Result<&[u8]> {
394 let rem = self.buf.remaining_in_buf();
395 if !rem.is_empty() {
396 return Ok(rem);
397 }
398
399 if self.limit == self.pos() {
400 return Ok(&[]);
401 }
402
403 self.fill_buf_slow()?;
404
405 Ok(self.buf.remaining_in_buf())
406 }
407}
408
// Tests for `read_exact_bytes`; only built with the `bytes` feature.
#[cfg(all(test, feature = "bytes"))]
mod test_bytes {
    use std::io::Write;

    use super::*;

    /// Builds exactly `len` bytes of ASCII digits by repeatedly appending the
    /// decimal value of the current length, then truncating.
    fn make_long_string(len: usize) -> Vec<u8> {
        let mut text = Vec::new();
        loop {
            let written = text.len();
            if written >= len {
                break;
            }
            write!(&mut text, "{}", written).expect("unexpected");
        }
        text.truncate(len);
        text
    }

    /// A slice source returns the requested prefix and leaves the position
    /// right after it.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn read_exact_bytes_from_slice() {
        let source = make_long_string(100);
        let mut iter = BufReadIter::from_byte_slice(&source[..]);
        let head = iter.read_exact_bytes(90).unwrap();
        assert_eq!(&source[..90], &head[..]);
        assert_eq!(source[90], iter.read_byte().expect("read_byte"));
    }

    /// A `Bytes` source returns a view that starts at the same address as
    /// the original data.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn read_exact_bytes_from_bytes() {
        let source = Bytes::from(make_long_string(100));
        let mut iter = BufReadIter::from_bytes(&source);
        let head = iter.read_exact_bytes(90).unwrap();
        assert_eq!(&source[..90], &head[..]);
        assert_eq!(source[..90].as_ptr(), head.as_ptr());
        assert_eq!(source[90], iter.read_byte().expect("read_byte"));
    }
}
445
#[cfg(test)]
mod test {
    use std::io;

    use super::*;

    /// Once the pushed limit is reached, `eof` must report `true` without
    /// asking the reader for more data.
    #[test]
    fn eof_at_limit() {
        /// Reader that serves five bytes once; its assertions fail if it is
        /// asked to fill again after those bytes were consumed.
        struct Read5ThenPanic {
            pos: usize,
        }

        impl Read for Read5ThenPanic {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                unreachable!();
            }
        }

        impl BufRead for Read5ThenPanic {
            fn fill_buf(&mut self) -> io::Result<&[u8]> {
                assert_eq!(0, self.pos);
                static ZERO_TO_FIVE: &[u8] = &[0, 1, 2, 3, 4];
                Ok(ZERO_TO_FIVE)
            }

            fn consume(&mut self, amt: usize) {
                // `consume_buf` reports zero consumed bytes before the first
                // fill; that call carries no information.
                if amt == 0 {
                    return;
                }

                assert_eq!(0, self.pos);
                assert_eq!(5, amt);
                self.pos += amt;
            }
        }

        let mut read = Read5ThenPanic { pos: 0 };
        let mut buf_read_iter = BufReadIter::from_buf_read(&mut read);
        assert_eq!(0, buf_read_iter.pos());
        // Fail here, not in a later assertion, if the limit cannot be set.
        let _prev_limit = buf_read_iter.push_limit(5).expect("push_limit");
        buf_read_iter.read_byte().expect("read_byte");
        buf_read_iter
            .read_exact(&mut [
                MaybeUninit::uninit(),
                MaybeUninit::uninit(),
                MaybeUninit::uninit(),
                MaybeUninit::uninit(),
            ])
            .expect("read_exact");
        assert!(buf_read_iter.eof().expect("eof"));
    }
}
498}