// prost/encoding/varint.rs

use core::cmp::min;
use core::num::NonZeroU64;

use ::bytes::{Buf, BufMut};

use crate::DecodeError;

8/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer.
9/// The buffer must have enough remaining space (maximum 10 bytes).
10#[inline]
11pub fn encode_varint(mut value: u64, buf: &mut impl BufMut) {
12    // Varints are never more than 10 bytes
13    for _ in 0..10 {
14        if value < 0x80 {
15            buf.put_u8(value as u8);
16            break;
17        } else {
18            buf.put_u8(((value & 0x7F) | 0x80) as u8);
19            value >>= 7;
20        }
21    }
22}
23
/// Returns the encoded length of the value in LEB128 variable length format.
/// The returned value will be between 1 and 10, inclusive.
#[inline]
pub const fn encoded_len_varint(value: u64) -> usize {
    // Based on [VarintSize64][1].
    // [1]: https://github.com/protocolbuffers/protobuf/blob/v28.3/src/google/protobuf/io/coded_stream.h#L1744-L1756
    //
    // A value whose highest set bit has index `n` needs `n / 7 + 1` bytes. For every `n` in
    // `0..=63` the expression `(n * 9 + 73) / 64` yields that same number while only dividing by
    // a power of two. OR-ing with 1 makes zero behave like one (one byte) and changes nothing
    // for any other value's highest set bit.

    // SAFETY: `value | 1` always has its lowest bit set, so it is non-zero.
    let log2value = unsafe { NonZeroU64::new_unchecked(value | 1) }.ilog2();
    ((log2value * 9 + (64 + 9)) / 64) as usize
}

35/// Decodes a LEB128-encoded variable length integer from the buffer.
36#[inline]
37pub fn decode_varint(buf: &mut impl Buf) -> Result<u64, DecodeError> {
38    let bytes = buf.chunk();
39    let len = bytes.len();
40    if len == 0 {
41        return Err(DecodeError::new("invalid varint"));
42    }
43
44    let byte = bytes[0];
45    if byte < 0x80 {
46        buf.advance(1);
47        Ok(u64::from(byte))
48    } else if len > 10 || bytes[len - 1] < 0x80 {
49        let (value, advance) = decode_varint_slice(bytes)?;
50        buf.advance(advance);
51        Ok(value)
52    } else {
53        decode_varint_slow(buf)
54    }
55}
56
57/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the
58/// number of bytes read.
59///
60/// Based loosely on [`ReadVarint64FromArray`][1] with a varint overflow check from
61/// [`ConsumeVarint`][2].
62///
63/// ## Safety
64///
65/// The caller must ensure that `bytes` is non-empty and either `bytes.len() >= 10` or the last
66/// element in bytes is < `0x80`.
67///
68/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406
69/// [2]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
70#[inline]
71fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
72    // Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance.
73
74    // Use assertions to ensure memory safety, but it should always be optimized after inline.
75    assert!(!bytes.is_empty());
76    assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);
77
78    let mut b: u8 = unsafe { *bytes.get_unchecked(0) };
79    let mut part0: u32 = u32::from(b);
80    if b < 0x80 {
81        return Ok((u64::from(part0), 1));
82    };
83    part0 -= 0x80;
84    b = unsafe { *bytes.get_unchecked(1) };
85    part0 += u32::from(b) << 7;
86    if b < 0x80 {
87        return Ok((u64::from(part0), 2));
88    };
89    part0 -= 0x80 << 7;
90    b = unsafe { *bytes.get_unchecked(2) };
91    part0 += u32::from(b) << 14;
92    if b < 0x80 {
93        return Ok((u64::from(part0), 3));
94    };
95    part0 -= 0x80 << 14;
96    b = unsafe { *bytes.get_unchecked(3) };
97    part0 += u32::from(b) << 21;
98    if b < 0x80 {
99        return Ok((u64::from(part0), 4));
100    };
101    part0 -= 0x80 << 21;
102    let value = u64::from(part0);
103
104    b = unsafe { *bytes.get_unchecked(4) };
105    let mut part1: u32 = u32::from(b);
106    if b < 0x80 {
107        return Ok((value + (u64::from(part1) << 28), 5));
108    };
109    part1 -= 0x80;
110    b = unsafe { *bytes.get_unchecked(5) };
111    part1 += u32::from(b) << 7;
112    if b < 0x80 {
113        return Ok((value + (u64::from(part1) << 28), 6));
114    };
115    part1 -= 0x80 << 7;
116    b = unsafe { *bytes.get_unchecked(6) };
117    part1 += u32::from(b) << 14;
118    if b < 0x80 {
119        return Ok((value + (u64::from(part1) << 28), 7));
120    };
121    part1 -= 0x80 << 14;
122    b = unsafe { *bytes.get_unchecked(7) };
123    part1 += u32::from(b) << 21;
124    if b < 0x80 {
125        return Ok((value + (u64::from(part1) << 28), 8));
126    };
127    part1 -= 0x80 << 21;
128    let value = value + ((u64::from(part1)) << 28);
129
130    b = unsafe { *bytes.get_unchecked(8) };
131    let mut part2: u32 = u32::from(b);
132    if b < 0x80 {
133        return Ok((value + (u64::from(part2) << 56), 9));
134    };
135    part2 -= 0x80;
136    b = unsafe { *bytes.get_unchecked(9) };
137    part2 += u32::from(b) << 7;
138    // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
139    // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
140    if b < 0x02 {
141        return Ok((value + (u64::from(part2) << 56), 10));
142    };
143
144    // We have overrun the maximum size of a varint (10 bytes) or the final byte caused an overflow.
145    // Assume the data is corrupt.
146    Err(DecodeError::new("invalid varint"))
147}
148
149/// Decodes a LEB128-encoded variable length integer from the buffer, advancing the buffer as
150/// necessary.
151///
152/// Contains a varint overflow check from [`ConsumeVarint`][1].
153///
154/// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
155#[inline(never)]
156#[cold]
157fn decode_varint_slow(buf: &mut impl Buf) -> Result<u64, DecodeError> {
158    let mut value = 0;
159    for count in 0..min(10, buf.remaining()) {
160        let byte = buf.get_u8();
161        value |= u64::from(byte & 0x7F) << (count * 7);
162        if byte <= 0x7F {
163            // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
164            // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
165            if count == 9 && byte >= 0x02 {
166                return Err(DecodeError::new("invalid varint"));
167            } else {
168                return Ok(value);
169            }
170        }
171    }
172
173    Err(DecodeError::new("invalid varint"))
174}
175
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips values on both sides of every 7-bit group boundary through the encoder, the
    /// length calculation, and both decoders.
    #[test]
    fn varint() {
        fn check(value: u64, encoded: &[u8]) {
            // Small buffer.
            let mut buf = Vec::with_capacity(1);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            // Large buffer.
            let mut buf = Vec::with_capacity(100);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            assert_eq!(encoded_len_varint(value), encoded.len());

            // See: https://github.com/tokio-rs/prost/pull/1008 for copying reasoning.
            let mut encoded_copy = encoded;
            let roundtrip_value = decode_varint(&mut encoded_copy).expect("decoding failed");
            assert_eq!(value, roundtrip_value);

            let mut encoded_copy = encoded;
            let roundtrip_value =
                decode_varint_slow(&mut encoded_copy).expect("slow decoding failed");
            assert_eq!(value, roundtrip_value);
        }

        check(2u64.pow(0) - 1, &[0x00]);
        check(2u64.pow(0), &[0x01]);

        check(2u64.pow(7) - 1, &[0x7F]);
        check(2u64.pow(7), &[0x80, 0x01]);
        check(300, &[0xAC, 0x02]);

        check(2u64.pow(14) - 1, &[0xFF, 0x7F]);
        check(2u64.pow(14), &[0x80, 0x80, 0x01]);

        check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]);
        check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(
            2u64.pow(49) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(49),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(56) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(56),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(63) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(63),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            u64::MAX,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        );
    }

    /// Ten-byte encoding whose final byte is 0x02, i.e. one more than `u64::MAX` can represent.
    const U64_MAX_PLUS_ONE: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];

    #[test]
    fn varint_overflow() {
        let mut copy = U64_MAX_PLUS_ONE;
        decode_varint(&mut copy).expect_err("decoding u64::MAX + 1 succeeded");
    }

    #[test]
    fn varint_slow_overflow() {
        let mut copy = U64_MAX_PLUS_ONE;
        decode_varint_slow(&mut copy).expect_err("slow decoding u64::MAX + 1 succeeded");
    }
}