mysql_common/
io.rs

1// Copyright (c) 2017 Anatoly Ikorsky
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use byteorder::{LittleEndian as LE, ReadBytesExt, WriteBytesExt};
10use bytes::BufMut;
11use std::{cmp::min, io};
12
13use crate::proto::MyDeserialize;
14
15pub trait BufMutExt: BufMut {
16    /// Writes an unsigned integer to self as MySql length-encoded integer.
17    fn put_lenenc_int(&mut self, n: u64) {
18        if n < 251 {
19            self.put_u8(n as u8);
20        } else if n < 65_536 {
21            self.put_u8(0xFC);
22            self.put_uint_le(n, 2);
23        } else if n < 16_777_216 {
24            self.put_u8(0xFD);
25            self.put_uint_le(n, 3);
26        } else {
27            self.put_u8(0xFE);
28            self.put_uint_le(n, 8);
29        }
30    }
31
32    /// Writes a slice to self as MySql length-encoded string.
33    fn put_lenenc_str(&mut self, s: &[u8]) {
34        self.put_lenenc_int(s.len() as u64);
35        self.put_slice(s);
36    }
37
38    /// Writes a 3-bytes unsigned integer.
39    fn put_u24_le(&mut self, x: u32) {
40        self.put_uint_le(x as u64, 3);
41    }
42
43    /// Writes a 3-bytes signed integer.
44    fn put_i24_le(&mut self, x: i32) {
45        self.put_int_le(x as i64, 3);
46    }
47
48    /// Writes a 6-bytes unsigned integer.
49    fn put_u48_le(&mut self, x: u64) {
50        self.put_uint_le(x, 6);
51    }
52
53    /// Writes a 7-bytes unsigned integer.
54    fn put_u56_le(&mut self, x: u64) {
55        self.put_uint_le(x, 7);
56    }
57
58    /// Writes a 7-bytes signed integer.
59    fn put_i56_le(&mut self, x: i64) {
60        self.put_int_le(x, 7);
61    }
62
63    /// Writes a string with u8 length prefix. Truncates, if the length is greater that `u8::MAX`.
64    fn put_u8_str(&mut self, s: &[u8]) {
65        let len = std::cmp::min(s.len(), u8::MAX as usize);
66        self.put_u8(len as u8);
67        self.put_slice(&s[..len]);
68    }
69
70    /// Writes a string with u32 length prefix. Truncates, if the length is greater that `u32::MAX`.
71    fn put_u32_str(&mut self, s: &[u8]) {
72        let len = std::cmp::min(s.len(), u32::MAX as usize);
73        self.put_u32_le(len as u32);
74        self.put_slice(&s[..len]);
75    }
76}
77
78impl<T: BufMut> BufMutExt for T {}
79
/// A borrowed byte buffer that is consumed from the front while parsing.
///
/// The wrapped slice is the not-yet-consumed remainder. Because `ParseBuf` is `Copy`, the
/// current parse position can be saved simply by copying the value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ParseBuf<'a>(pub &'a [u8]);
82
83impl io::Read for ParseBuf<'_> {
84    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
85        let count = min(self.0.len(), buf.len());
86        (buf[..count]).copy_from_slice(&self.0[..count]);
87        self.0 = &self.0[count..];
88        Ok(count)
89    }
90}
91
/// Generates two methods that consume a fixed-size number from the head of the buffer:
/// `$name`, which panics on short input, and `$checked`, which returns `Option`.
///
/// * `eat_num!(name, checked, T::from_xx_bytes)` reads `size_of::<T>()` bytes and decodes
///   them with one of the `from_{le,be}_bytes` constructors.
/// * `eat_num!(name, checked, SIZE, OFFSET, T::from_xx)` reads `SIZE` bytes (fewer than
///   `size_of::<T>()`). It places them at byte `OFFSET` of a zero-filled
///   `size_of::<T>()`-byte image and decodes that image with `T::from_le` / `T::from_be`.
///   Use `OFFSET = 0` for little-endian and `OFFSET = size_of::<T>() - SIZE` for big-endian,
///   so the input bytes end up in the low-order part of the result.
///
/// The missing high bytes are zero-filled, so the signed sub-width variants are NOT
/// sign-extended.
///
/// The expansion calls `self.eat(n)` and `self.len()`, so it must be used inside an `impl`
/// that provides both.
macro_rules! eat_num {
    ($name:ident, $checked:ident, $t:ident::$fn:ident) => {
        #[doc = "Consumes a number from the head of the buffer."]
        pub fn $name(&mut self) -> $t {
            const SIZE: usize = std::mem::size_of::<$t>();
            // `eat(SIZE)` yields exactly `SIZE` bytes (or panics), so a plain copy into a
            // fixed-size array replaces the former `unsafe` pointer cast.
            let mut bytes = [0_u8; SIZE];
            bytes.copy_from_slice(self.eat(SIZE));
            $t::$fn(bytes)
        }

        #[doc = "Consumes a number from the head of the buffer. Returns `None` if buffer is too small."]
        pub fn $checked(&mut self) -> Option<$t> {
            if self.len() >= std::mem::size_of::<$t>() {
                Some(self.$name())
            } else {
                None
            }
        }
    };
    ($name:ident, $checked:ident, $size:literal, $offset:literal, $t:ident::$fn:ident) => {
        #[doc = "Consumes a number from the head of the buffer."]
        pub fn $name(&mut self) -> $t {
            const SIZE: usize = $size;
            // Lay out the bytes as an explicit in-memory image, then reinterpret it with
            // `from_ne_bytes`. `from_le(from_ne_bytes(b)) == from_le_bytes(b)` (and likewise
            // for big-endian) on hosts of either endianness. Assembling the integer with
            // shifts before calling `from_le`/`from_be` is only correct on little-endian hosts.
            let mut image = [0_u8; std::mem::size_of::<$t>()];
            image[$offset..$offset + SIZE].copy_from_slice(self.eat(SIZE));
            $t::$fn($t::from_ne_bytes(image))
        }

        #[doc = "Consumes a number from the head of the buffer. Returns `None` if buffer is too small."]
        pub fn $checked(&mut self) -> Option<$t> {
            if self.len() >= $size {
                Some(self.$name())
            } else {
                None
            }
        }
    };
}
132
133impl<'a> ParseBuf<'a> {
134    /// Returns `T: MyDeserialize` deserialized from `self`.
135    ///
136    /// Note, that this may panic if `T::SIZE.is_some()` and less than `self.0.len()`.
137    #[inline(always)]
138    pub fn parse_unchecked<T>(&mut self, ctx: T::Ctx) -> io::Result<T>
139    where
140        T: MyDeserialize<'a>,
141    {
142        T::deserialize(ctx, self)
143    }
144
145    /// Checked `parse`.
146    #[inline(always)]
147    pub fn parse<T>(&mut self, ctx: T::Ctx) -> io::Result<T>
148    where
149        T: MyDeserialize<'a>,
150    {
151        match T::SIZE {
152            Some(size) => {
153                let mut buf: ParseBuf = self.parse_unchecked(size)?;
154                buf.parse_unchecked(ctx)
155            }
156            None => self.parse_unchecked(ctx),
157        }
158    }
159
160    /// Returns true if buffer is empty.
161    pub fn is_empty(&self) -> bool {
162        self.len() == 0
163    }
164
165    /// Returns the number of bytes in the buffer.
166    pub fn len(&self) -> usize {
167        self.0.len()
168    }
169
170    /// Skips the given number of bytes.
171    ///
172    /// Afterwards self contains elements `[cnt, len)`.
173    pub fn skip(&mut self, cnt: usize) {
174        self.0 = &self.0[cnt..];
175    }
176
177    /// Same as `skip` but returns `false` if buffer is too small.
178    pub fn checked_skip(&mut self, cnt: usize) -> bool {
179        if self.len() >= cnt {
180            self.skip(cnt);
181            true
182        } else {
183            false
184        }
185    }
186
187    /// Splits the buffer into two at the given index. Returns elements `[0, n)`.
188    ///
189    /// Afterwards self contains elements `[n, len)`.
190    ///
191    /// # Panic
192    ///
193    /// Will panic if `n > self.len()`.
194    pub fn eat(&mut self, n: usize) -> &'a [u8] {
195        let (left, right) = self.0.split_at(n);
196        self.0 = right;
197        left
198    }
199
200    pub fn eat_buf(&mut self, n: usize) -> Self {
201        Self(self.eat(n))
202    }
203
204    /// Same as `eat`. Returns `None` if buffer is too small.
205    pub fn checked_eat(&mut self, n: usize) -> Option<&'a [u8]> {
206        if self.len() >= n {
207            Some(self.eat(n))
208        } else {
209            None
210        }
211    }
212
213    pub fn checked_eat_buf(&mut self, n: usize) -> Option<Self> {
214        Some(Self(self.checked_eat(n)?))
215    }
216
217    pub fn eat_all(&mut self) -> &'a [u8] {
218        self.eat(self.len())
219    }
220
221    eat_num!(eat_u8, checked_eat_u8, u8::from_le_bytes);
222    eat_num!(eat_i8, checked_eat_i8, i8::from_le_bytes);
223    eat_num!(eat_u16_le, checked_eat_u16_le, u16::from_le_bytes);
224    eat_num!(eat_i16_le, checked_eat_i16_le, i16::from_le_bytes);
225    eat_num!(eat_u16_be, checked_eat_u16_be, u16::from_be_bytes);
226    eat_num!(eat_i16_be, checked_eat_i16_be, i16::from_be_bytes);
227    eat_num!(eat_u24_le, checked_eat_u24_le, 3, 0, u32::from_le);
228    eat_num!(eat_i24_le, checked_eat_i24_le, 3, 0, i32::from_le);
229    eat_num!(eat_u24_be, checked_eat_u24_be, 3, 1, u32::from_be);
230    eat_num!(eat_i24_be, checked_eat_i24_be, 3, 1, i32::from_be);
231    eat_num!(eat_u32_le, checked_eat_u32_le, u32::from_le_bytes);
232    eat_num!(eat_i32_le, checked_eat_i32_le, i32::from_le_bytes);
233    eat_num!(eat_u32_be, checked_eat_u32_be, u32::from_be_bytes);
234    eat_num!(eat_i32_be, checked_eat_i32_be, i32::from_be_bytes);
235    eat_num!(eat_u40_le, checked_eat_u40_le, 5, 0, u64::from_le);
236    eat_num!(eat_i40_le, checked_eat_i40_le, 5, 0, i64::from_le);
237    eat_num!(eat_u40_be, checked_eat_u40_be, 5, 3, u64::from_be);
238    eat_num!(eat_i40_be, checked_eat_i40_be, 5, 3, i64::from_be);
239    eat_num!(eat_u48_le, checked_eat_u48_le, 6, 0, u64::from_le);
240    eat_num!(eat_i48_le, checked_eat_i48_le, 6, 0, i64::from_le);
241    eat_num!(eat_u48_be, checked_eat_u48_be, 6, 2, u64::from_be);
242    eat_num!(eat_i48_be, checked_eat_i48_be, 6, 2, i64::from_be);
243    eat_num!(eat_u56_le, checked_eat_u56_le, 7, 0, u64::from_le);
244    eat_num!(eat_i56_le, checked_eat_i56_le, 7, 0, i64::from_le);
245    eat_num!(eat_u56_be, checked_eat_u56_be, 7, 1, u64::from_be);
246    eat_num!(eat_i56_be, checked_eat_i56_be, 7, 1, i64::from_be);
247    eat_num!(eat_u64_le, checked_eat_u64_le, u64::from_le_bytes);
248    eat_num!(eat_i64_le, checked_eat_i64_le, i64::from_le_bytes);
249    eat_num!(eat_u64_be, checked_eat_u64_be, u64::from_be_bytes);
250    eat_num!(eat_i64_be, checked_eat_i64_be, i64::from_be_bytes);
251    eat_num!(eat_u128_le, checked_eat_u128_le, u128::from_le_bytes);
252    eat_num!(eat_i128_le, checked_eat_i128_le, i128::from_le_bytes);
253    eat_num!(eat_u128_be, checked_eat_u128_be, u128::from_be_bytes);
254    eat_num!(eat_i128_be, checked_eat_i128_be, i128::from_be_bytes);
255
256    eat_num!(eat_f32_le, checked_eat_f32_le, f32::from_le_bytes);
257    eat_num!(eat_f32_be, checked_eat_f32_be, f32::from_be_bytes);
258
259    eat_num!(eat_f64_le, checked_eat_f64_le, f64::from_le_bytes);
260    eat_num!(eat_f64_be, checked_eat_f64_be, f64::from_be_bytes);
261
262    /// Consumes MySql length-encoded integer from the head of the buffer.
263    ///
264    /// Returns `0` if integer is maliformed (starts with 0xff or 0xfb). First byte will be eaten.
265    pub fn eat_lenenc_int(&mut self) -> u64 {
266        match self.eat_u8() {
267            x @ 0..=0xfa => x as u64,
268            0xfc => self.eat_u16_le() as u64,
269            0xfd => self.eat_u24_le() as u64,
270            0xfe => self.eat_u64_le(),
271            0xfb | 0xff => 0,
272        }
273    }
274
275    /// Same as `eat_lenenc_int`. Returns `None` if buffer is too small.
276    pub fn checked_eat_lenenc_int(&mut self) -> Option<u64> {
277        match self.checked_eat_u8()? {
278            x @ 0..=0xfa => Some(x as u64),
279            0xfc => self.checked_eat_u16_le().map(|x| x as u64),
280            0xfd => self.checked_eat_u24_le().map(|x| x as u64),
281            0xfe => self.checked_eat_u64_le(),
282            0xfb | 0xff => Some(0),
283        }
284    }
285
286    /// Consumes MySql length-encoded string from the head of the buffer.
287    ///
288    /// Returns an empty slice if length is maliformed (starts with 0xff). First byte will be eaten.
289    pub fn eat_lenenc_str(&mut self) -> &'a [u8] {
290        let len = self.eat_lenenc_int();
291        self.eat(len as usize)
292    }
293
294    /// Same as `eat_lenenc_str`. Returns `None` if buffer is too small.
295    pub fn checked_eat_lenenc_str(&mut self) -> Option<&'a [u8]> {
296        let len = self.checked_eat_lenenc_int()?;
297        self.checked_eat(len as usize)
298    }
299
300    /// Consumes MySql string with u8 length prefix from the head of the buffer.
301    pub fn eat_u8_str(&mut self) -> &'a [u8] {
302        let len = self.eat_u8();
303        self.eat(len as usize)
304    }
305
306    /// Same as `eat_u8_str`. Returns `None` if buffer is too small.
307    pub fn checked_eat_u8_str(&mut self) -> Option<&'a [u8]> {
308        let len = self.checked_eat_u8()?;
309        self.checked_eat(len as usize)
310    }
311
312    /// Consumes MySql string with u32 length prefix from the head of the buffer.
313    pub fn eat_u32_str(&mut self) -> &'a [u8] {
314        let len = self.eat_u32_le();
315        self.eat(len as usize)
316    }
317
318    /// Same as `eat_u32_str`. Returns `None` if buffer is too small.
319    pub fn checked_eat_u32_str(&mut self) -> Option<&'a [u8]> {
320        let len = self.checked_eat_u32_le()?;
321        self.checked_eat(len as usize)
322    }
323
324    /// Consumes null-terminated string from the head of the buffer.
325    ///
326    /// Consumes whole buffer if there is no `0`-byte.
327    pub fn eat_null_str(&mut self) -> &'a [u8] {
328        let pos = self
329            .0
330            .iter()
331            .position(|x| *x == 0)
332            .map(|x| x + 1)
333            .unwrap_or_else(|| self.len());
334        match self.eat(pos) {
335            [head @ .., 0_u8] => head,
336            x => x,
337        }
338    }
339}
340
341#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
342#[error("Invalid length-encoded integer value (starts with 0xfb|0xff)")]
343pub struct InvalidLenghEncodedInteger;
344
345pub trait ReadMysqlExt: ReadBytesExt {
346    /// Reads MySql's length-encoded integer.
347    fn read_lenenc_int(&mut self) -> io::Result<u64> {
348        match self.read_u8()? {
349            x if x <= 0xfa => Ok(x.into()),
350            0xfc => self.read_uint::<LE>(2),
351            0xfd => self.read_uint::<LE>(3),
352            0xfe => self.read_uint::<LE>(8),
353            0xfb | 0xff => Err(io::Error::new(
354                io::ErrorKind::Other,
355                InvalidLenghEncodedInteger,
356            )),
357            _ => unreachable!(),
358        }
359    }
360
361    /// Reads MySql's length-encoded string.
362    fn read_lenenc_str(&mut self) -> io::Result<Vec<u8>> {
363        let len = self.read_lenenc_int()?;
364        let mut output = vec![0_u8; len as usize];
365        self.read_exact(&mut output)?;
366        Ok(output)
367    }
368}
369
370pub trait WriteMysqlExt: WriteBytesExt {
371    /// Writes MySql's length-encoded integer.
372    fn write_lenenc_int(&mut self, x: u64) -> io::Result<u64> {
373        if x < 251 {
374            self.write_u8(x as u8)?;
375            Ok(1)
376        } else if x < 65_536 {
377            self.write_u8(0xFC)?;
378            self.write_uint::<LE>(x, 2)?;
379            Ok(3)
380        } else if x < 16_777_216 {
381            self.write_u8(0xFD)?;
382            self.write_uint::<LE>(x, 3)?;
383            Ok(4)
384        } else {
385            self.write_u8(0xFE)?;
386            self.write_uint::<LE>(x, 8)?;
387            Ok(9)
388        }
389    }
390
391    /// Writes MySql's length-encoded string.
392    fn write_lenenc_str(&mut self, bytes: &[u8]) -> io::Result<u64> {
393        let written = self.write_lenenc_int(bytes.len() as u64)?;
394        self.write_all(bytes)?;
395        Ok(written + bytes.len() as u64)
396    }
397}
398
399impl<T> ReadMysqlExt for T where T: ReadBytesExt {}
400impl<T> WriteMysqlExt for T where T: WriteBytesExt {}
401
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    #[test]
    fn be_le() {
        // `ParseBuf` is `Copy`, so each assertion builds a fresh buffer instead of cloning one.
        const B3: &[u8] = &[0, 1, 2];
        assert_eq!(ParseBuf(B3).eat_u24_le(), 0x0002_0100);
        assert_eq!(ParseBuf(B3).eat_u24_be(), 0x0000_0102);
        const B5: &[u8] = &[0, 1, 2, 3, 4];
        assert_eq!(ParseBuf(B5).eat_u40_le(), 0x0000_0004_0302_0100);
        assert_eq!(ParseBuf(B5).eat_u40_be(), 0x0000_0000_0102_0304);
        const B6: &[u8] = &[0, 1, 2, 3, 4, 5];
        assert_eq!(ParseBuf(B6).eat_u48_le(), 0x0000_0504_0302_0100);
        assert_eq!(ParseBuf(B6).eat_u48_be(), 0x0000_0001_0203_0405);
        const B7: &[u8] = &[0, 1, 2, 3, 4, 5, 6];
        assert_eq!(ParseBuf(B7).eat_u56_le(), 0x0006_0504_0302_0100);
        assert_eq!(ParseBuf(B7).eat_u56_be(), 0x0000_0102_0304_0506);
    }

    #[test]
    fn lenenc_int_round_trip() {
        let cases: &[(u64, usize)] = &[
            (0, 1),
            (250, 1),
            (251, 3),
            (65_535, 3),
            (65_536, 4),
            (16_777_215, 4),
            (16_777_216, 9),
            (u64::MAX, 9),
        ];
        for &(value, encoded_len) in cases {
            let mut encoded: Vec<u8> = Vec::new();
            encoded.put_lenenc_int(value);
            assert_eq!(encoded.len(), encoded_len, "encoded length of {}", value);

            let mut buf = ParseBuf(&encoded);
            assert_eq!(buf.eat_lenenc_int(), value);
            assert!(buf.is_empty());
            assert_eq!(ParseBuf(&encoded).checked_eat_lenenc_int(), Some(value));

            let mut written: Vec<u8> = Vec::new();
            assert_eq!(written.write_lenenc_int(value).unwrap(), encoded_len as u64);
            assert_eq!(written, encoded);
            assert_eq!((&encoded[..]).read_lenenc_int().unwrap(), value);
        }
    }

    #[test]
    fn checked_eat_reports_short_buffers() {
        assert_eq!(ParseBuf(&[]).checked_eat_u8(), None);
        assert_eq!(ParseBuf(&[0xfc, 0x01]).checked_eat_lenenc_int(), None);
        assert_eq!(ParseBuf(&[1, 2]).checked_eat_u24_le(), None);
        assert_eq!(ParseBuf(&[3, b'a', b'b']).checked_eat_u8_str(), None);
        assert_eq!(
            ParseBuf(&[2, b'a', b'b']).checked_eat_u8_str(),
            Some(&b"ab"[..])
        );

        let mut buf = ParseBuf(&[1, 2, 3]);
        assert!(!buf.checked_skip(4));
        assert_eq!(buf.len(), 3);
        assert!(buf.checked_skip(3));
        assert!(buf.is_empty());
    }

    #[test]
    fn null_terminated_strings() {
        let mut buf = ParseBuf(b"abc\0def");
        assert_eq!(buf.eat_null_str(), b"abc");
        assert_eq!(buf.0, b"def");
        assert_eq!(buf.eat_null_str(), b"def");
        assert!(buf.is_empty());
    }

    #[test]
    fn lenenc_str_round_trip_and_errors() {
        let mut written: Vec<u8> = Vec::new();
        assert_eq!(written.write_lenenc_str(b"hello").unwrap(), 6);
        assert_eq!(ParseBuf(&written).eat_lenenc_str(), b"hello");
        assert_eq!(ParseBuf(&written).checked_eat_lenenc_str(), Some(&b"hello"[..]));
        assert_eq!((&written[..]).read_lenenc_str().unwrap(), b"hello".to_vec());

        let err = (&[0x05, b'a'][..]).read_lenenc_str().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);

        let err = (&[0xfb_u8][..]).read_lenenc_int().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::Other);
    }
}