1use core::cmp::min;
2use core::num::NonZeroU64;
3
4use ::bytes::{Buf, BufMut};
5
6use crate::DecodeError;
7
8#[inline]
11pub fn encode_varint(mut value: u64, buf: &mut impl BufMut) {
12 for _ in 0..10 {
14 if value < 0x80 {
15 buf.put_u8(value as u8);
16 break;
17 } else {
18 buf.put_u8(((value & 0x7F) | 0x80) as u8);
19 value >>= 7;
20 }
21 }
22}
23
/// Returns how many bytes `encode_varint` would write for `value`.
///
/// The result is always in `1..=10`.
#[inline]
pub const fn encoded_len_varint(value: u64) -> usize {
    // SAFETY: `value | 1` always has its lowest bit set, so it is never zero.
    let nonzero = unsafe { NonZeroU64::new_unchecked(value | 1) };
    // Index of the highest set bit, in 0..=63. OR-ing in 1 makes zero behave
    // like one, and both encode to a single byte.
    let highest_bit = nonzero.ilog2();
    // Branch-free integer form that equals `highest_bit / 7 + 1` (one byte per
    // started 7-bit group) for every `highest_bit` in 0..=63; 73 is 64 + 9.
    let byte_count = (highest_bit * 9 + 73) / 64;
    byte_count as usize
}
34
35#[inline]
37pub fn decode_varint(buf: &mut impl Buf) -> Result<u64, DecodeError> {
38 let bytes = buf.chunk();
39 let len = bytes.len();
40 if len == 0 {
41 return Err(DecodeError::new("invalid varint"));
42 }
43
44 let byte = bytes[0];
45 if byte < 0x80 {
46 buf.advance(1);
47 Ok(u64::from(byte))
48 } else if len > 10 || bytes[len - 1] < 0x80 {
49 let (value, advance) = decode_varint_slice(bytes)?;
50 buf.advance(advance);
51 Ok(value)
52 } else {
53 decode_varint_slow(buf)
54 }
55}
56
/// Decodes a varint from the front of `bytes` with a fully unrolled loop.
///
/// On success, returns the decoded value and the number of bytes (1 to 10)
/// that encoded it.
///
/// # Panics
///
/// Panics if `bytes` is empty, or if it holds at most 10 bytes and its last
/// byte has the continuation bit (0x80) set. These two conditions make the
/// unchecked reads below sound. `decode_varint` routes any other input to
/// `decode_varint_slow` instead.
///
/// # Errors
///
/// Returns a `DecodeError` if the tenth byte is reached and is not 0x00 or
/// 0x01. Any larger tenth byte either sets the continuation bit (more than 10
/// bytes) or carries bits beyond bit 63 of a `u64`.
#[inline]
fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
    // These checks establish the bounds invariant that every
    // `get_unchecked` call below relies on.
    assert!(!bytes.is_empty());
    // A slice of at most 10 bytes must end with a terminating byte (< 0x80),
    // so decoding stops at or before its last index.
    assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);

    // The value is assembled in three 32-bit pieces: `part0` holds bits 0..28
    // (bytes 0-3), `part1` holds bits 28..56 (bytes 4-7) and `part2` holds the
    // bits from 56 upward (bytes 8-9).
    //
    // SAFETY (applies to every `get_unchecked(i)` below): index 0 is in bounds
    // because `bytes` is non-empty. Any index `i > 0` is only reached after
    // bytes 0..i all had the continuation bit set. If `bytes.len() > 10`, then
    // `i <= 9 < bytes.len()`. Otherwise the last byte is < 0x80, so it cannot
    // be one of bytes 0..i, and therefore `i <= bytes.len() - 1`.
    let mut b: u8 = unsafe { *bytes.get_unchecked(0) };
    let mut part0: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((u64::from(part0), 1));
    };
    // Remove the continuation bit that was added along with the byte.
    part0 -= 0x80;
    b = unsafe { *bytes.get_unchecked(1) };
    part0 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((u64::from(part0), 2));
    };
    part0 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(2) };
    part0 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((u64::from(part0), 3));
    };
    part0 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(3) };
    part0 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((u64::from(part0), 4));
    };
    part0 -= 0x80 << 21;
    // Bytes 0-3 are complete: bits 0..28 of the result.
    let value = u64::from(part0);

    b = unsafe { *bytes.get_unchecked(4) };
    let mut part1: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 5));
    };
    part1 -= 0x80;
    b = unsafe { *bytes.get_unchecked(5) };
    part1 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 6));
    };
    part1 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(6) };
    part1 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 7));
    };
    part1 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(7) };
    part1 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 8));
    };
    part1 -= 0x80 << 21;
    // Bytes 4-7 are complete: shift them into bits 28..56.
    let value = value + ((u64::from(part1)) << 28);

    b = unsafe { *bytes.get_unchecked(8) };
    let mut part2: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part2) << 56), 9));
    };
    part2 -= 0x80;
    b = unsafe { *bytes.get_unchecked(9) };
    part2 += u32::from(b) << 7;
    // The tenth byte may only supply bit 63, so 0x00 and 0x01 are its only
    // valid values; anything else overflows or continues past 10 bytes.
    if b < 0x02 {
        return Ok((value + (u64::from(part2) << 56), 10));
    };

    Err(DecodeError::new("invalid varint"))
}
148
149#[inline(never)]
156#[cold]
157fn decode_varint_slow(buf: &mut impl Buf) -> Result<u64, DecodeError> {
158 let mut value = 0;
159 for count in 0..min(10, buf.remaining()) {
160 let byte = buf.get_u8();
161 value |= u64::from(byte & 0x7F) << (count * 7);
162 if byte <= 0x7F {
163 if count == 9 && byte >= 0x02 {
166 return Err(DecodeError::new("invalid varint"));
167 } else {
168 return Ok(value);
169 }
170 }
171 }
172
173 Err(DecodeError::new("invalid varint"))
174}
175
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn varint() {
        /// Checks that `value` encodes to exactly `encoded`, that its reported
        /// length matches, and that both decoders read it back while consuming
        /// every byte of the encoding.
        fn check(value: u64, encoded: &[u8]) {
            // Starts with capacity 1, so multi-byte encodings must grow the Vec.
            let mut buf = Vec::with_capacity(1);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            // Starts with ample capacity, so no reallocation is needed.
            let mut buf = Vec::with_capacity(100);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            assert_eq!(encoded_len_varint(value), encoded.len());

            // Decode from a copy of the slice reference so `encoded` itself is
            // left intact for the next decoder.
            let mut encoded_copy = encoded;
            let roundtrip_value = decode_varint(&mut encoded_copy).expect("decoding failed");
            assert_eq!(value, roundtrip_value);
            assert!(
                encoded_copy.is_empty(),
                "decoding did not consume the whole varint"
            );

            let mut encoded_copy = encoded;
            let roundtrip_value =
                decode_varint_slow(&mut encoded_copy).expect("slow decoding failed");
            assert_eq!(value, roundtrip_value);
            assert!(
                encoded_copy.is_empty(),
                "slow decoding did not consume the whole varint"
            );
        }

        check(2u64.pow(0) - 1, &[0x00]);
        check(2u64.pow(0), &[0x01]);

        check(2u64.pow(7) - 1, &[0x7F]);
        check(2u64.pow(7), &[0x80, 0x01]);
        check(300, &[0xAC, 0x02]);

        check(2u64.pow(14) - 1, &[0xFF, 0x7F]);
        check(2u64.pow(14), &[0x80, 0x80, 0x01]);

        check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]);
        check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(
            2u64.pow(49) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(49),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(56) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(56),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(63) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(63),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            u64::MAX,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        );
    }

    #[test]
    fn varint_leaves_trailing_bytes() {
        // The chunk ends with a terminating byte, so the unrolled slice
        // decoder handles this input.
        let mut buf: &[u8] = &[0xAC, 0x02, 0x01];
        assert_eq!(decode_varint(&mut buf).expect("decoding failed"), 300);
        assert_eq!(buf, &[0x01u8][..]);

        // The chunk ends with a continuation byte and is short, so this input
        // goes through the byte-at-a-time fallback.
        let mut buf: &[u8] = &[0xAC, 0x02, 0xFF];
        assert_eq!(decode_varint(&mut buf).expect("decoding failed"), 300);
        assert_eq!(buf, &[0xFFu8][..]);
    }

    #[test]
    fn varint_truncated() {
        let inputs: [&[u8]; 3] = [&[], &[0x80], &[0xFF, 0xFF]];
        for input in inputs.iter() {
            let mut copy = *input;
            decode_varint(&mut copy).expect_err("decoding truncated varint succeeded");
            let mut copy = *input;
            decode_varint_slow(&mut copy)
                .expect_err("slow decoding truncated varint succeeded");
        }
    }

    #[test]
    fn varint_too_long() {
        // Eleven continuation bytes: longer than any valid u64 varint.
        let input: &[u8] = &[0xFF; 11];
        let mut copy = input;
        decode_varint(&mut copy).expect_err("decoding over-long varint succeeded");
        let mut copy = input;
        decode_varint_slow(&mut copy).expect_err("slow decoding over-long varint succeeded");
    }

    const U64_MAX_PLUS_ONE: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];

    #[test]
    fn varint_overflow() {
        let mut copy = U64_MAX_PLUS_ONE;
        decode_varint(&mut copy).expect_err("decoding u64::MAX + 1 succeeded");
    }

    #[test]
    fn varint_slow_overflow() {
        let mut copy = U64_MAX_PLUS_ONE;
        decode_varint_slow(&mut copy).expect_err("slow decoding u64::MAX + 1 succeeded");
    }
}
274}