1use byteorder::{LittleEndian as LE, ReadBytesExt, WriteBytesExt};
10use bytes::BufMut;
11use std::{cmp::min, io};
12
13use crate::proto::MyDeserialize;
14
15pub trait BufMutExt: BufMut {
16 fn put_lenenc_int(&mut self, n: u64) {
18 if n < 251 {
19 self.put_u8(n as u8);
20 } else if n < 65_536 {
21 self.put_u8(0xFC);
22 self.put_uint_le(n, 2);
23 } else if n < 16_777_216 {
24 self.put_u8(0xFD);
25 self.put_uint_le(n, 3);
26 } else {
27 self.put_u8(0xFE);
28 self.put_uint_le(n, 8);
29 }
30 }
31
32 fn put_lenenc_str(&mut self, s: &[u8]) {
34 self.put_lenenc_int(s.len() as u64);
35 self.put_slice(s);
36 }
37
38 fn put_u24_le(&mut self, x: u32) {
40 self.put_uint_le(x as u64, 3);
41 }
42
43 fn put_i24_le(&mut self, x: i32) {
45 self.put_int_le(x as i64, 3);
46 }
47
48 fn put_u48_le(&mut self, x: u64) {
50 self.put_uint_le(x, 6);
51 }
52
53 fn put_u56_le(&mut self, x: u64) {
55 self.put_uint_le(x, 7);
56 }
57
58 fn put_i56_le(&mut self, x: i64) {
60 self.put_int_le(x, 7);
61 }
62
63 fn put_u8_str(&mut self, s: &[u8]) {
65 let len = std::cmp::min(s.len(), u8::MAX as usize);
66 self.put_u8(len as u8);
67 self.put_slice(&s[..len]);
68 }
69
70 fn put_u32_str(&mut self, s: &[u8]) {
72 let len = std::cmp::min(s.len(), u32::MAX as usize);
73 self.put_u32_le(len as u32);
74 self.put_slice(&s[..len]);
75 }
76}
77
78impl<T: BufMut> BufMutExt for T {}
79
/// A cursor over a borrowed byte slice; the parsing methods consume bytes from its head.
///
/// The wrapped slice is public and holds the bytes that have not been consumed yet.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ParseBuf<'a>(pub &'a [u8]);
82
83impl io::Read for ParseBuf<'_> {
84 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
85 let count = min(self.0.len(), buf.len());
86 (buf[..count]).copy_from_slice(&self.0[..count]);
87 self.0 = &self.0[count..];
88 Ok(count)
89 }
90}
91
// Generates a pair of methods that consume a fixed-width number from the head of a buffer:
//
// * `$name` takes the bytes via `self.eat(..)`, which is expected to panic if too few remain;
// * `$checked` first compares `self.len()` with the width and returns `None` when the buffer
//   is too short, without touching it.
//
// Form 1 (`$t::$fn`) is for types whose width is exactly `size_of::<$t>()`; `$fn` is a
// `from_le_bytes`/`from_be_bytes` constructor.
//
// Form 2 (`$size`, `$offset`, `$t::$fn`) is for odd widths such as 3, 5, 6 or 7 bytes. The
// `$size` bytes are copied into a zeroed `size_of::<$t>()`-byte array starting at index
// `$offset`: 0 for little-endian, `size_of::<$t>() - $size` for big-endian. The array is then
// read in native byte order and `$fn` (`from_le`/`from_be`) converts it to the host value.
// Because the bytes are placed by position rather than shifted arithmetically, the result is
// the same on little- and big-endian hosts. The padding bytes are zero, so signed results are
// zero-extended, not sign-extended (e.g. three `0xff` bytes read as `i32` give `0x00ff_ffff`).
//
// The calling type must provide `fn len(&self) -> usize` and an `eat(&mut self, n: usize)`
// that returns a byte slice of exactly `n` bytes.
macro_rules! eat_num {
    ($name:ident, $checked:ident, $t:ident::$fn:ident) => {
        #[doc = "Consumes a number from the head of the buffer."]
        pub fn $name(&mut self) -> $t {
            const SIZE: usize = std::mem::size_of::<$t>();
            // A checked copy instead of a raw-pointer cast: `copy_from_slice` also asserts
            // that `eat` really returned `SIZE` bytes, so no `unsafe` is needed.
            let mut bytes = [0_u8; SIZE];
            bytes.copy_from_slice(self.eat(SIZE));
            $t::$fn(bytes)
        }

        #[doc = "Consumes a number from the head of the buffer. Returns `None` if buffer is too small."]
        pub fn $checked(&mut self) -> Option<$t> {
            if self.len() >= std::mem::size_of::<$t>() {
                Some(self.$name())
            } else {
                None
            }
        }
    };
    ($name:ident, $checked:ident, $size:literal, $offset:literal, $t:ident::$fn:ident) => {
        #[doc = "Consumes a number from the head of the buffer."]
        pub fn $name(&mut self) -> $t {
            const SIZE: usize = $size;
            // Lay the bytes out exactly as they would sit in memory for an `$fn`-ordered
            // value, then let `from_ne_bytes` + `$fn` do a host-independent conversion.
            let mut bytes = [0_u8; std::mem::size_of::<$t>()];
            bytes[$offset..$offset + SIZE].copy_from_slice(self.eat(SIZE));
            $t::$fn($t::from_ne_bytes(bytes))
        }

        #[doc = "Consumes a number from the head of the buffer. Returns `None` if buffer is too small."]
        pub fn $checked(&mut self) -> Option<$t> {
            if self.len() >= $size {
                Some(self.$name())
            } else {
                None
            }
        }
    };
}
132
133impl<'a> ParseBuf<'a> {
134 #[inline(always)]
138 pub fn parse_unchecked<T>(&mut self, ctx: T::Ctx) -> io::Result<T>
139 where
140 T: MyDeserialize<'a>,
141 {
142 T::deserialize(ctx, self)
143 }
144
145 #[inline(always)]
147 pub fn parse<T>(&mut self, ctx: T::Ctx) -> io::Result<T>
148 where
149 T: MyDeserialize<'a>,
150 {
151 match T::SIZE {
152 Some(size) => {
153 let mut buf: ParseBuf = self.parse_unchecked(size)?;
154 buf.parse_unchecked(ctx)
155 }
156 None => self.parse_unchecked(ctx),
157 }
158 }
159
160 pub fn is_empty(&self) -> bool {
162 self.len() == 0
163 }
164
165 pub fn len(&self) -> usize {
167 self.0.len()
168 }
169
170 pub fn skip(&mut self, cnt: usize) {
174 self.0 = &self.0[cnt..];
175 }
176
177 pub fn checked_skip(&mut self, cnt: usize) -> bool {
179 if self.len() >= cnt {
180 self.skip(cnt);
181 true
182 } else {
183 false
184 }
185 }
186
187 pub fn eat(&mut self, n: usize) -> &'a [u8] {
195 let (left, right) = self.0.split_at(n);
196 self.0 = right;
197 left
198 }
199
200 pub fn eat_buf(&mut self, n: usize) -> Self {
201 Self(self.eat(n))
202 }
203
204 pub fn checked_eat(&mut self, n: usize) -> Option<&'a [u8]> {
206 if self.len() >= n {
207 Some(self.eat(n))
208 } else {
209 None
210 }
211 }
212
213 pub fn checked_eat_buf(&mut self, n: usize) -> Option<Self> {
214 Some(Self(self.checked_eat(n)?))
215 }
216
217 pub fn eat_all(&mut self) -> &'a [u8] {
218 self.eat(self.len())
219 }
220
221 eat_num!(eat_u8, checked_eat_u8, u8::from_le_bytes);
222 eat_num!(eat_i8, checked_eat_i8, i8::from_le_bytes);
223 eat_num!(eat_u16_le, checked_eat_u16_le, u16::from_le_bytes);
224 eat_num!(eat_i16_le, checked_eat_i16_le, i16::from_le_bytes);
225 eat_num!(eat_u16_be, checked_eat_u16_be, u16::from_be_bytes);
226 eat_num!(eat_i16_be, checked_eat_i16_be, i16::from_be_bytes);
227 eat_num!(eat_u24_le, checked_eat_u24_le, 3, 0, u32::from_le);
228 eat_num!(eat_i24_le, checked_eat_i24_le, 3, 0, i32::from_le);
229 eat_num!(eat_u24_be, checked_eat_u24_be, 3, 1, u32::from_be);
230 eat_num!(eat_i24_be, checked_eat_i24_be, 3, 1, i32::from_be);
231 eat_num!(eat_u32_le, checked_eat_u32_le, u32::from_le_bytes);
232 eat_num!(eat_i32_le, checked_eat_i32_le, i32::from_le_bytes);
233 eat_num!(eat_u32_be, checked_eat_u32_be, u32::from_be_bytes);
234 eat_num!(eat_i32_be, checked_eat_i32_be, i32::from_be_bytes);
235 eat_num!(eat_u40_le, checked_eat_u40_le, 5, 0, u64::from_le);
236 eat_num!(eat_i40_le, checked_eat_i40_le, 5, 0, i64::from_le);
237 eat_num!(eat_u40_be, checked_eat_u40_be, 5, 3, u64::from_be);
238 eat_num!(eat_i40_be, checked_eat_i40_be, 5, 3, i64::from_be);
239 eat_num!(eat_u48_le, checked_eat_u48_le, 6, 0, u64::from_le);
240 eat_num!(eat_i48_le, checked_eat_i48_le, 6, 0, i64::from_le);
241 eat_num!(eat_u48_be, checked_eat_u48_be, 6, 2, u64::from_be);
242 eat_num!(eat_i48_be, checked_eat_i48_be, 6, 2, i64::from_be);
243 eat_num!(eat_u56_le, checked_eat_u56_le, 7, 0, u64::from_le);
244 eat_num!(eat_i56_le, checked_eat_i56_le, 7, 0, i64::from_le);
245 eat_num!(eat_u56_be, checked_eat_u56_be, 7, 1, u64::from_be);
246 eat_num!(eat_i56_be, checked_eat_i56_be, 7, 1, i64::from_be);
247 eat_num!(eat_u64_le, checked_eat_u64_le, u64::from_le_bytes);
248 eat_num!(eat_i64_le, checked_eat_i64_le, i64::from_le_bytes);
249 eat_num!(eat_u64_be, checked_eat_u64_be, u64::from_be_bytes);
250 eat_num!(eat_i64_be, checked_eat_i64_be, i64::from_be_bytes);
251 eat_num!(eat_u128_le, checked_eat_u128_le, u128::from_le_bytes);
252 eat_num!(eat_i128_le, checked_eat_i128_le, i128::from_le_bytes);
253 eat_num!(eat_u128_be, checked_eat_u128_be, u128::from_be_bytes);
254 eat_num!(eat_i128_be, checked_eat_i128_be, i128::from_be_bytes);
255
256 eat_num!(eat_f32_le, checked_eat_f32_le, f32::from_le_bytes);
257 eat_num!(eat_f32_be, checked_eat_f32_be, f32::from_be_bytes);
258
259 eat_num!(eat_f64_le, checked_eat_f64_le, f64::from_le_bytes);
260 eat_num!(eat_f64_be, checked_eat_f64_be, f64::from_be_bytes);
261
262 pub fn eat_lenenc_int(&mut self) -> u64 {
266 match self.eat_u8() {
267 x @ 0..=0xfa => x as u64,
268 0xfc => self.eat_u16_le() as u64,
269 0xfd => self.eat_u24_le() as u64,
270 0xfe => self.eat_u64_le(),
271 0xfb | 0xff => 0,
272 }
273 }
274
275 pub fn checked_eat_lenenc_int(&mut self) -> Option<u64> {
277 match self.checked_eat_u8()? {
278 x @ 0..=0xfa => Some(x as u64),
279 0xfc => self.checked_eat_u16_le().map(|x| x as u64),
280 0xfd => self.checked_eat_u24_le().map(|x| x as u64),
281 0xfe => self.checked_eat_u64_le(),
282 0xfb | 0xff => Some(0),
283 }
284 }
285
286 pub fn eat_lenenc_str(&mut self) -> &'a [u8] {
290 let len = self.eat_lenenc_int();
291 self.eat(len as usize)
292 }
293
294 pub fn checked_eat_lenenc_str(&mut self) -> Option<&'a [u8]> {
296 let len = self.checked_eat_lenenc_int()?;
297 self.checked_eat(len as usize)
298 }
299
300 pub fn eat_u8_str(&mut self) -> &'a [u8] {
302 let len = self.eat_u8();
303 self.eat(len as usize)
304 }
305
306 pub fn checked_eat_u8_str(&mut self) -> Option<&'a [u8]> {
308 let len = self.checked_eat_u8()?;
309 self.checked_eat(len as usize)
310 }
311
312 pub fn eat_u32_str(&mut self) -> &'a [u8] {
314 let len = self.eat_u32_le();
315 self.eat(len as usize)
316 }
317
318 pub fn checked_eat_u32_str(&mut self) -> Option<&'a [u8]> {
320 let len = self.checked_eat_u32_le()?;
321 self.checked_eat(len as usize)
322 }
323
324 pub fn eat_null_str(&mut self) -> &'a [u8] {
328 let pos = self
329 .0
330 .iter()
331 .position(|x| *x == 0)
332 .map(|x| x + 1)
333 .unwrap_or_else(|| self.len());
334 match self.eat(pos) {
335 [head @ .., 0_u8] => head,
336 x => x,
337 }
338 }
339}
340
341#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
342#[error("Invalid length-encoded integer value (starts with 0xfb|0xff)")]
343pub struct InvalidLenghEncodedInteger;
344
345pub trait ReadMysqlExt: ReadBytesExt {
346 fn read_lenenc_int(&mut self) -> io::Result<u64> {
348 match self.read_u8()? {
349 x if x <= 0xfa => Ok(x.into()),
350 0xfc => self.read_uint::<LE>(2),
351 0xfd => self.read_uint::<LE>(3),
352 0xfe => self.read_uint::<LE>(8),
353 0xfb | 0xff => Err(io::Error::new(
354 io::ErrorKind::Other,
355 InvalidLenghEncodedInteger,
356 )),
357 _ => unreachable!(),
358 }
359 }
360
361 fn read_lenenc_str(&mut self) -> io::Result<Vec<u8>> {
363 let len = self.read_lenenc_int()?;
364 let mut output = vec![0_u8; len as usize];
365 self.read_exact(&mut output)?;
366 Ok(output)
367 }
368}
369
370pub trait WriteMysqlExt: WriteBytesExt {
371 fn write_lenenc_int(&mut self, x: u64) -> io::Result<u64> {
373 if x < 251 {
374 self.write_u8(x as u8)?;
375 Ok(1)
376 } else if x < 65_536 {
377 self.write_u8(0xFC)?;
378 self.write_uint::<LE>(x, 2)?;
379 Ok(3)
380 } else if x < 16_777_216 {
381 self.write_u8(0xFD)?;
382 self.write_uint::<LE>(x, 3)?;
383 Ok(4)
384 } else {
385 self.write_u8(0xFE)?;
386 self.write_uint::<LE>(x, 8)?;
387 Ok(9)
388 }
389 }
390
391 fn write_lenenc_str(&mut self, bytes: &[u8]) -> io::Result<u64> {
393 let written = self.write_lenenc_int(bytes.len() as u64)?;
394 self.write_all(bytes)?;
395 Ok(written + bytes.len() as u64)
396 }
397}
398
399impl<T> ReadMysqlExt for T where T: ReadBytesExt {}
400impl<T> WriteMysqlExt for T where T: WriteBytesExt {}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Odd-width readers place bytes correctly for both byte orders.
    #[test]
    fn be_le() {
        let buf = ParseBuf(&[0, 1, 2]);
        assert_eq!(buf.clone().eat_u24_le(), 0x00020100);
        assert_eq!(buf.clone().eat_u24_be(), 0x00000102);
        let buf = ParseBuf(&[0, 1, 2, 3, 4]);
        assert_eq!(buf.clone().eat_u40_le(), 0x0000000403020100);
        assert_eq!(buf.clone().eat_u40_be(), 0x0000000001020304);
        let buf = ParseBuf(&[0, 1, 2, 3, 4, 5]);
        assert_eq!(buf.clone().eat_u48_le(), 0x0000050403020100);
        assert_eq!(buf.clone().eat_u48_be(), 0x0000000102030405);
        let buf = ParseBuf(&[0, 1, 2, 3, 4, 5, 6]);
        assert_eq!(buf.clone().eat_u56_le(), 0x0006050403020100);
        assert_eq!(buf.clone().eat_u56_be(), 0x0000010203040506);
    }

    /// `put_lenenc_int` and `eat_lenenc_int` agree at every width boundary.
    #[test]
    fn lenenc_int_round_trip() {
        for &(value, encoded_len) in &[
            (0_u64, 1_usize),
            (250, 1),
            (251, 3),
            (0xFFFF, 3),
            (0x1_0000, 4),
            (0xFF_FFFF, 4),
            (0x100_0000, 9),
            (u64::MAX, 9),
        ] {
            let mut out: Vec<u8> = Vec::new();
            out.put_lenenc_int(value);
            assert_eq!(out.len(), encoded_len, "encoded length of {}", value);

            let mut buf = ParseBuf(&out);
            assert_eq!(buf.eat_lenenc_int(), value);
            assert!(buf.is_empty());
            assert_eq!(ParseBuf(&out).checked_eat_lenenc_int(), Some(value));
        }
    }

    /// `put_lenenc_str` output is read back by `eat_lenenc_str`.
    #[test]
    fn lenenc_str_round_trip() {
        let mut out: Vec<u8> = Vec::new();
        out.put_lenenc_str(b"hello");
        let mut buf = ParseBuf(&out);
        assert_eq!(buf.eat_lenenc_str(), b"hello");
        assert!(buf.is_empty());
    }

    /// `checked_*` readers return `None` on short input instead of panicking.
    #[test]
    fn checked_eat_rejects_short_input() {
        let mut buf = ParseBuf(&[1, 2]);
        assert_eq!(buf.checked_eat_u24_le(), None);
        assert_eq!(buf.checked_eat_u32_be(), None);
        assert_eq!(buf.len(), 2);
        assert_eq!(buf.checked_eat_u16_le(), Some(0x0201));
        assert!(buf.is_empty());

        assert_eq!(ParseBuf(&[0xFC, 0x01]).checked_eat_lenenc_int(), None);
        assert_eq!(ParseBuf(&[3, b'a', b'b']).checked_eat_lenenc_str(), None);
        assert_eq!(ParseBuf(&[2, b'a']).checked_eat_u8_str(), None);
    }

    /// `eat_null_str` stops at the first NUL and otherwise takes the rest of the buffer.
    #[test]
    fn eat_null_str_splits_on_first_nul() {
        let mut buf = ParseBuf(b"abc\0def");
        assert_eq!(buf.eat_null_str(), b"abc");
        assert_eq!(buf.eat_null_str(), b"def");
        assert!(buf.is_empty());
    }
}