mysql_common/packets/
mod.rs

1// Copyright (c) 2017 Anatoly Ikorsky
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use btoi::btoi;
10use bytes::BufMut;
11use regex::bytes::Regex;
12use uuid::Uuid;
13
14use std::str::FromStr;
15use std::sync::Arc;
16use std::{
17    borrow::Cow, cmp::max, collections::HashMap, convert::TryFrom, fmt, io, marker::PhantomData,
18};
19
20use crate::collations::CollationId;
21use crate::{
22    constants::{
23        CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, SessionStateType,
24        StatusFlags, StmtExecuteParamFlags, StmtExecuteParamsFlags, MAX_PAYLOAD_LEN,
25    },
26    io::{BufMutExt, ParseBuf},
27    misc::{
28        lenenc_str_len,
29        raw::{
30            bytes::{
31                BareBytes, ConstBytes, ConstBytesValue, EofBytes, LenEnc, NullBytes, U32Bytes,
32                U8Bytes,
33            },
34            int::{ConstU32, ConstU8, LeU16, LeU24, LeU32, LeU32LowerHalf, LeU32UpperHalf, LeU64},
35            seq::Seq,
36            Const, Either, RawBytes, RawConst, RawInt, Skip,
37        },
38        unexpected_buf_eof,
39    },
40    proto::{MyDeserialize, MySerialize},
41    value::{ClientSide, SerializationSide, Value},
42};
43
44use self::session_state_change::SessionStateChange;
45
46lazy_static::lazy_static! {
47    static ref MARIADB_VERSION_RE: Regex =
48        Regex::new(r"^(?:5.5.5-)?(\d{1,2})\.(\d{1,2})\.(\d{1,3})-MariaDB").unwrap();
49    static ref VERSION_RE: Regex = Regex::new(r"^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)").unwrap();
50}
51
/// Declares a one-byte constant header type alias together with its unit error type.
///
/// Two forms are accepted:
///
/// * `Alias, Error("message"), <literal>` — the expected byte is given as a literal;
/// * `Alias, Variant, Error` — the expected byte is `Command::Variant as u8`.
macro_rules! define_header {
    ($alias:ident, $error:ident($message:literal), $byte:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($message)]
        pub struct $error;
        pub type $alias = crate::misc::raw::int::ConstU8<$error, $byte>;
    };
    ($alias:ident, $command:ident, $error:ident) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error("Invalid header for {}", stringify!($command))]
        pub struct $error;
        pub type $alias = crate::misc::raw::int::ConstU8<$error, { Command::$command as u8 }>;
    };
}
66
/// Declares a constant-valued type alias `Alias = Wrapper<Error, value>` together with
/// the unit error type `Error`, whose display message is the given literal.
macro_rules! define_const {
    ($wrapper:ident, $alias:ident, $error:ident($message:literal), $value:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($message)]
        pub struct $error;
        pub type $alias = $wrapper<$error, $value>;
    };
}
75
/// Declares a constant byte-string type.
///
/// Generates a unit error type, a unit marker type implementing [`ConstBytesValue`]
/// with the given bytes, and the alias `Alias = ConstBytes<Marker, len>`.
macro_rules! define_const_bytes {
    ($marker:ident, $alias:ident, $error:ident($message:literal), $bytes:expr, $size:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($message)]
        pub struct $error;

        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
        pub struct $marker;

        impl ConstBytesValue<$size> for $marker {
            const VALUE: [u8; $size] = $bytes;
            type Error = $error;
        }

        pub type $alias = ConstBytes<$marker, $size>;
    };
}
93
94pub mod binlog_request;
95pub mod caching_sha2_password;
96pub mod session_state_change;
97
98define_const_bytes!(
99    Catalog,
100    ColumnDefinitionCatalog,
101    InvalidCatalog("Invalid catalog value in the column definition"),
102    *b"\x03def",
103    4
104);
105
106define_const!(
107    ConstU8,
108    FixedLengthFieldsLen,
109    InvalidFixedLengthFieldsLen("Invalid fixed length field length in the column definition"),
110    0x0c
111);
112
113/// Dynamically-sized column metadata — a part of the [`Column`] packet.
114#[derive(Debug, Default, Clone, Eq, PartialEq)]
115struct ColumnMeta<'a> {
116    schema: RawBytes<'a, LenEnc>,
117    table: RawBytes<'a, LenEnc>,
118    org_table: RawBytes<'a, LenEnc>,
119    name: RawBytes<'a, LenEnc>,
120    org_name: RawBytes<'a, LenEnc>,
121}
122
123impl ColumnMeta<'_> {
124    pub fn into_owned(self) -> ColumnMeta<'static> {
125        ColumnMeta {
126            schema: self.schema.into_owned(),
127            table: self.table.into_owned(),
128            org_table: self.org_table.into_owned(),
129            name: self.name.into_owned(),
130            org_name: self.org_name.into_owned(),
131        }
132    }
133
134    /// Returns the value of the [`ColumnMeta::schema`] field as a byte slice.
135    pub fn schema_ref(&self) -> &[u8] {
136        self.schema.as_bytes()
137    }
138
139    /// Returns the value of the [`ColumnMeta::schema`] field as a string (lossy converted).
140    pub fn schema_str(&self) -> Cow<'_, str> {
141        String::from_utf8_lossy(self.schema_ref())
142    }
143
144    /// Returns the value of the [`ColumnMeta::table`] field as a byte slice.
145    pub fn table_ref(&self) -> &[u8] {
146        self.table.as_bytes()
147    }
148
149    /// Returns the value of the [`ColumnMeta::table`] field as a string (lossy converted).
150    pub fn table_str(&self) -> Cow<'_, str> {
151        String::from_utf8_lossy(self.table_ref())
152    }
153
154    /// Returns the value of the [`ColumnMeta::org_table`] field as a byte slice.
155    ///
156    /// "org_table" is for original table name.
157    pub fn org_table_ref(&self) -> &[u8] {
158        self.org_table.as_bytes()
159    }
160
161    /// Returns the value of the [`ColumnMeta::org_table`] field as a string (lossy converted).
162    pub fn org_table_str(&self) -> Cow<'_, str> {
163        String::from_utf8_lossy(self.org_table_ref())
164    }
165
166    /// Returns the value of the [`ColumnMeta::name`] field as a byte slice.
167    pub fn name_ref(&self) -> &[u8] {
168        self.name.as_bytes()
169    }
170
171    /// Returns the value of the [`ColumnMeta::name`] field as a string (lossy converted).
172    pub fn name_str(&self) -> Cow<'_, str> {
173        String::from_utf8_lossy(self.name_ref())
174    }
175
176    /// Returns the value of the [`ColumnMeta::org_name`] field as a byte slice.
177    ///
178    /// "org_name" is for original column name.
179    pub fn org_name_ref(&self) -> &[u8] {
180        self.org_name.as_bytes()
181    }
182
183    /// Returns value of the [`ColumnMeta::org_name`] field as a string (lossy converted).
184    pub fn org_name_str(&self) -> Cow<'_, str> {
185        String::from_utf8_lossy(self.org_name_ref())
186    }
187}
188
189impl<'de> MyDeserialize<'de> for ColumnMeta<'de> {
190    const SIZE: Option<usize> = None;
191    type Ctx = ();
192
193    fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
194        Ok(Self {
195            schema: buf.parse_unchecked(())?,
196            table: buf.parse_unchecked(())?,
197            org_table: buf.parse_unchecked(())?,
198            name: buf.parse_unchecked(())?,
199            org_name: buf.parse_unchecked(())?,
200        })
201    }
202}
203
204impl MySerialize for ColumnMeta<'_> {
205    fn serialize(&self, buf: &mut Vec<u8>) {
206        self.schema.serialize(&mut *buf);
207        self.table.serialize(&mut *buf);
208        self.org_table.serialize(&mut *buf);
209        self.name.serialize(&mut *buf);
210        self.org_name.serialize(&mut *buf);
211    }
212}
213
214/// Represents MySql Column (column packet).
215#[derive(Debug, Clone, Eq, PartialEq)]
216pub struct Column {
217    catalog: ColumnDefinitionCatalog,
218    meta: Arc<ColumnMeta<'static>>,
219    fixed_length_fields_len: FixedLengthFieldsLen,
220    column_length: RawInt<LeU32>,
221    character_set: RawInt<LeU16>,
222    column_type: Const<ColumnType, u8>,
223    flags: Const<ColumnFlags, LeU16>,
224    decimals: RawInt<u8>,
225    __filler: Skip<2>,
226    // COM_FIELD_LIST is deprecated, so we won't support it
227}
228
229impl<'de> MyDeserialize<'de> for Column {
230    const SIZE: Option<usize> = None;
231    type Ctx = ();
232
233    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
234        let catalog = buf.parse(())?;
235        let meta = Arc::new(buf.parse::<ColumnMeta>(())?.into_owned());
236        let mut buf: ParseBuf = buf.parse(13)?;
237
238        Ok(Column {
239            catalog,
240            meta,
241            fixed_length_fields_len: buf.parse_unchecked(())?,
242            character_set: buf.parse_unchecked(())?,
243            column_length: buf.parse_unchecked(())?,
244            column_type: buf.parse_unchecked(())?,
245            flags: buf.parse_unchecked(())?,
246            decimals: buf.parse_unchecked(())?,
247            __filler: buf.parse_unchecked(())?,
248        })
249    }
250}
251
252impl MySerialize for Column {
253    fn serialize(&self, buf: &mut Vec<u8>) {
254        self.catalog.serialize(&mut *buf);
255        self.meta.serialize(&mut *buf);
256        self.fixed_length_fields_len.serialize(&mut *buf);
257        self.column_length.serialize(&mut *buf);
258        self.character_set.serialize(&mut *buf);
259        self.column_type.serialize(&mut *buf);
260        self.flags.serialize(&mut *buf);
261        self.decimals.serialize(&mut *buf);
262        self.__filler.serialize(&mut *buf);
263    }
264}
265
266impl Column {
267    pub fn new(column_type: ColumnType) -> Self {
268        Self {
269            catalog: Default::default(),
270            meta: Default::default(),
271            fixed_length_fields_len: Default::default(),
272            column_length: Default::default(),
273            character_set: Default::default(),
274            flags: Default::default(),
275            column_type: Const::new(column_type),
276            decimals: Default::default(),
277            __filler: Skip,
278        }
279    }
280
281    pub fn with_schema(mut self, schema: &[u8]) -> Self {
282        Arc::make_mut(&mut self.meta).schema = RawBytes::new(schema).into_owned();
283        self
284    }
285
286    pub fn with_table(mut self, table: &[u8]) -> Self {
287        Arc::make_mut(&mut self.meta).table = RawBytes::new(table).into_owned();
288        self
289    }
290
291    pub fn with_org_table(mut self, org_table: &[u8]) -> Self {
292        Arc::make_mut(&mut self.meta).org_table = RawBytes::new(org_table).into_owned();
293        self
294    }
295
296    pub fn with_name(mut self, name: &[u8]) -> Self {
297        Arc::make_mut(&mut self.meta).name = RawBytes::new(name).into_owned();
298        self
299    }
300
301    pub fn with_org_name(mut self, org_name: &[u8]) -> Self {
302        Arc::make_mut(&mut self.meta).org_name = RawBytes::new(org_name).into_owned();
303        self
304    }
305
306    pub fn with_flags(mut self, flags: ColumnFlags) -> Self {
307        self.flags = Const::new(flags);
308        self
309    }
310
311    pub fn with_column_length(mut self, column_length: u32) -> Self {
312        self.column_length = RawInt::new(column_length);
313        self
314    }
315
316    pub fn with_character_set(mut self, character_set: u16) -> Self {
317        self.character_set = RawInt::new(character_set);
318        self
319    }
320
321    pub fn with_decimals(mut self, decimals: u8) -> Self {
322        self.decimals = RawInt::new(decimals);
323        self
324    }
325
326    /// Returns value of the column_length field of a column packet.
327    ///
328    /// Can be used for text-output formatting.
329    pub fn column_length(&self) -> u32 {
330        *self.column_length
331    }
332
333    /// Returns value of the column_type field of a column packet.
334    pub fn column_type(&self) -> ColumnType {
335        *self.column_type
336    }
337
338    /// Returns value of the character_set field of a column packet.
339    pub fn character_set(&self) -> u16 {
340        *self.character_set
341    }
342
343    /// Returns value of the flags field of a column packet.
344    pub fn flags(&self) -> ColumnFlags {
345        *self.flags
346    }
347
348    /// Returns value of the decimals field of a column packet.
349    ///
350    /// Max shown decimal digits. Can be used for text-output formatting
351    ///
352    /// *   `0x00` for integers and static strings
353    /// *   `0x1f` for dynamic strings, double, float
354    /// *   `0x00..=0x51` for decimals
355    pub fn decimals(&self) -> u8 {
356        *self.decimals
357    }
358
359    /// Returns value of the schema field of a column packet as a byte slice.
360    #[inline(always)]
361    pub fn schema_ref(&self) -> &[u8] {
362        self.meta.schema_ref()
363    }
364
365    /// Returns value of the schema field of a column packet as a string (lossy converted).
366    #[inline(always)]
367    pub fn schema_str(&self) -> Cow<'_, str> {
368        self.meta.schema_str()
369    }
370
371    /// Returns value of the table field of a column packet as a byte slice.
372    #[inline(always)]
373    pub fn table_ref(&self) -> &[u8] {
374        self.meta.table_ref()
375    }
376
377    /// Returns value of the table field of a column packet as a string (lossy converted).
378    #[inline(always)]
379    pub fn table_str(&self) -> Cow<'_, str> {
380        self.meta.table_str()
381    }
382
383    /// Returns value of the org_table field of a column packet as a byte slice.
384    ///
385    /// "org_table" is for original table name.
386    #[inline(always)]
387    pub fn org_table_ref(&self) -> &[u8] {
388        self.meta.org_table_ref()
389    }
390
391    /// Returns value of the org_table field of a column packet as a string (lossy converted).
392    #[inline(always)]
393    pub fn org_table_str(&self) -> Cow<'_, str> {
394        self.meta.org_table_str()
395    }
396
397    /// Returns value of the name field of a column packet as a byte slice.
398    #[inline(always)]
399    pub fn name_ref(&self) -> &[u8] {
400        self.meta.name_ref()
401    }
402
403    /// Returns value of the name field of a column packet as a string (lossy converted).
404    #[inline(always)]
405    pub fn name_str(&self) -> Cow<'_, str> {
406        self.meta.name_str()
407    }
408
409    /// Returns value of the org_name field of a column packet as a byte slice.
410    ///
411    /// "org_name" is for original column name.
412    #[inline(always)]
413    pub fn org_name_ref(&self) -> &[u8] {
414        self.meta.org_name_ref()
415    }
416
417    /// Returns value of the org_name field of a column packet as a string (lossy converted).
418    #[inline(always)]
419    pub fn org_name_str(&self) -> Cow<'_, str> {
420        self.meta.org_name_str()
421    }
422}
423
424/// Represents change in session state (part of MySql's Ok packet).
425#[derive(Debug, Clone, Eq, PartialEq)]
426pub struct SessionStateInfo<'a> {
427    data_type: Const<SessionStateType, u8>,
428    data: RawBytes<'a, LenEnc>,
429}
430
431impl SessionStateInfo<'_> {
432    pub fn into_owned(self) -> SessionStateInfo<'static> {
433        let SessionStateInfo { data_type, data } = self;
434        SessionStateInfo {
435            data_type,
436            data: data.into_owned(),
437        }
438    }
439
440    pub fn data_type(&self) -> SessionStateType {
441        *self.data_type
442    }
443
444    /// Returns raw session state info data.
445    pub fn data_ref(&self) -> &[u8] {
446        self.data.as_bytes()
447    }
448
449    /// Tries to decode session state info data.
450    pub fn decode(&self) -> io::Result<SessionStateChange<'_>> {
451        ParseBuf(self.data.as_bytes()).parse_unchecked(*self.data_type)
452    }
453}
454
455impl<'de> MyDeserialize<'de> for SessionStateInfo<'de> {
456    const SIZE: Option<usize> = None;
457    type Ctx = ();
458
459    fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
460        Ok(SessionStateInfo {
461            data_type: buf.parse(())?,
462            data: buf.parse(())?,
463        })
464    }
465}
466
467impl MySerialize for SessionStateInfo<'_> {
468    fn serialize(&self, buf: &mut Vec<u8>) {
469        self.data_type.serialize(&mut *buf);
470        self.data.serialize(buf);
471    }
472}
473
474/// Represents MySql's Ok packet.
475#[derive(Debug, Clone, Eq, PartialEq)]
476pub struct OkPacketBody<'a> {
477    affected_rows: RawInt<LenEnc>,
478    last_insert_id: RawInt<LenEnc>,
479    status_flags: Const<StatusFlags, LeU16>,
480    warnings: RawInt<LeU16>,
481    info: RawBytes<'a, LenEnc>,
482    session_state_info: RawBytes<'a, LenEnc>,
483}
484
485/// OK packet kind (see _OK packet identifier_ section of [WL#7766][1]).
486///
487/// [1]: https://dev.mysql.com/worklog/task/?id=7766
488pub trait OkPacketKind {
489    const HEADER: u8;
490
491    fn parse_body<'de>(
492        capabilities: CapabilityFlags,
493        buf: &mut ParseBuf<'de>,
494    ) -> io::Result<OkPacketBody<'de>>;
495}
496
497/// Ok packet that terminates a result set (text or binary).
498#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
499pub struct ResultSetTerminator;
500
501impl OkPacketKind for ResultSetTerminator {
502    const HEADER: u8 = 0xFE;
503
504    fn parse_body<'de>(
505        capabilities: CapabilityFlags,
506        buf: &mut ParseBuf<'de>,
507    ) -> io::Result<OkPacketBody<'de>> {
508        // We need to skip affected_rows and insert_id here
509        // because valid content of EOF packet includes
510        // packet marker, server status and warning count only.
511        // (see `read_ok_ex` in sql-common/client.cc)
512        buf.parse::<RawInt<LenEnc>>(())?;
513        buf.parse::<RawInt<LenEnc>>(())?;
514
515        // assume CLIENT_PROTOCOL_41 flag
516        let mut sbuf: ParseBuf = buf.parse(4)?;
517        let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
518        let warnings = sbuf.parse_unchecked(())?;
519
520        let (info, session_state_info) =
521            if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
522                let info = buf.parse(())?;
523                let session_state_info =
524                    if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
525                        buf.parse(())?
526                    } else {
527                        RawBytes::default()
528                    };
529                (info, session_state_info)
530            } else if !buf.is_empty() && buf.0[0] > 0 {
531                // The `info` field is a `string<EOF>` according to the MySQL Internals
532                // Manual, but actually it's a `string<lenenc>`.
533                // SEE: sql/protocol_classics.cc `net_send_ok`
534                let info = buf.parse(())?;
535                (info, RawBytes::default())
536            } else {
537                (RawBytes::default(), RawBytes::default())
538            };
539
540        Ok(OkPacketBody {
541            affected_rows: RawInt::new(0),
542            last_insert_id: RawInt::new(0),
543            status_flags,
544            warnings,
545            info,
546            session_state_info,
547        })
548    }
549}
550
551/// Old deprecated EOF packet.
552#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
553pub struct OldEofPacket;
554
555impl OkPacketKind for OldEofPacket {
556    const HEADER: u8 = 0xFE;
557
558    fn parse_body<'de>(
559        _: CapabilityFlags,
560        buf: &mut ParseBuf<'de>,
561    ) -> io::Result<OkPacketBody<'de>> {
562        // We assume that CLIENT_PROTOCOL_41 was set
563        let mut buf: ParseBuf = buf.parse(4)?;
564        let warnings = buf.parse_unchecked(())?;
565        let status_flags = buf.parse_unchecked(())?;
566
567        Ok(OkPacketBody {
568            affected_rows: RawInt::new(0),
569            last_insert_id: RawInt::new(0),
570            status_flags,
571            warnings,
572            info: RawBytes::new(&[][..]),
573            session_state_info: RawBytes::new(&[][..]),
574        })
575    }
576}
577
578/// This packet terminates a binlog network stream.
579#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
580pub struct NetworkStreamTerminator;
581
582impl OkPacketKind for NetworkStreamTerminator {
583    const HEADER: u8 = 0xFE;
584
585    fn parse_body<'de>(
586        flags: CapabilityFlags,
587        buf: &mut ParseBuf<'de>,
588    ) -> io::Result<OkPacketBody<'de>> {
589        OldEofPacket::parse_body(flags, buf)
590    }
591}
592
593/// Ok packet that is not a result set terminator.
594#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
595pub struct CommonOkPacket;
596
597impl OkPacketKind for CommonOkPacket {
598    const HEADER: u8 = 0x00;
599
600    fn parse_body<'de>(
601        capabilities: CapabilityFlags,
602        buf: &mut ParseBuf<'de>,
603    ) -> io::Result<OkPacketBody<'de>> {
604        let affected_rows = buf.parse(())?;
605        let last_insert_id = buf.parse(())?;
606
607        // We assume that CLIENT_PROTOCOL_41 was set
608        let mut sbuf: ParseBuf = buf.parse(4)?;
609        let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
610        let warnings = sbuf.parse_unchecked(())?;
611
612        let (info, session_state_info) =
613            if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
614                let info = buf.parse(())?;
615                let session_state_info =
616                    if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
617                        buf.parse(())?
618                    } else {
619                        RawBytes::default()
620                    };
621                (info, session_state_info)
622            } else if !buf.is_empty() && buf.0[0] > 0 {
623                // The `info` field is a `string<EOF>` according to the MySQL Internals
624                // Manual, but actually it's a `string<lenenc>`.
625                // SEE: sql/protocol_classics.cc `net_send_ok`
626                let info = buf.parse(())?;
627                (info, RawBytes::default())
628            } else {
629                (RawBytes::default(), RawBytes::default())
630            };
631
632        Ok(OkPacketBody {
633            affected_rows,
634            last_insert_id,
635            status_flags,
636            warnings,
637            info,
638            session_state_info,
639        })
640    }
641}
642
643impl<'a> TryFrom<OkPacketBody<'a>> for OkPacket<'a> {
644    type Error = io::Error;
645
646    fn try_from(body: OkPacketBody<'a>) -> io::Result<Self> {
647        Ok(OkPacket {
648            affected_rows: *body.affected_rows,
649            last_insert_id: if *body.last_insert_id == 0 {
650                None
651            } else {
652                Some(*body.last_insert_id)
653            },
654            status_flags: *body.status_flags,
655            warnings: *body.warnings,
656            info: if !body.info.is_empty() {
657                Some(body.info)
658            } else {
659                None
660            },
661            session_state_info: if !body.session_state_info.is_empty() {
662                Some(body.session_state_info)
663            } else {
664                None
665            },
666        })
667    }
668}
669
670/// Represents MySql's Ok packet.
671#[derive(Debug, Clone, Eq, PartialEq)]
672pub struct OkPacket<'a> {
673    affected_rows: u64,
674    last_insert_id: Option<u64>,
675    status_flags: StatusFlags,
676    warnings: u16,
677    info: Option<RawBytes<'a, LenEnc>>,
678    session_state_info: Option<RawBytes<'a, LenEnc>>,
679}
680
681impl OkPacket<'_> {
682    pub fn into_owned(self) -> OkPacket<'static> {
683        OkPacket {
684            affected_rows: self.affected_rows,
685            last_insert_id: self.last_insert_id,
686            status_flags: self.status_flags,
687            warnings: self.warnings,
688            info: self.info.map(|x| x.into_owned()),
689            session_state_info: self.session_state_info.map(|x| x.into_owned()),
690        }
691    }
692
693    /// Value of the affected_rows field of an Ok packet.
694    pub fn affected_rows(&self) -> u64 {
695        self.affected_rows
696    }
697
698    /// Value of the last_insert_id field of an Ok packet.
699    pub fn last_insert_id(&self) -> Option<u64> {
700        self.last_insert_id
701    }
702
703    /// Value of the status_flags field of an Ok packet.
704    pub fn status_flags(&self) -> StatusFlags {
705        self.status_flags
706    }
707
708    /// Value of the warnings field of an Ok packet.
709    pub fn warnings(&self) -> u16 {
710        self.warnings
711    }
712
713    /// Value of the info field of an Ok packet as a byte slice.
714    pub fn info_ref(&self) -> Option<&[u8]> {
715        self.info.as_ref().map(|x| x.as_bytes())
716    }
717
718    /// Value of the info field of an Ok packet as a string (lossy converted).
719    pub fn info_str(&self) -> Option<Cow<str>> {
720        self.info.as_ref().map(|x| x.as_str())
721    }
722
723    /// Returns raw reference to a session state info.
724    pub fn session_state_info_ref(&self) -> Option<&[u8]> {
725        self.session_state_info.as_ref().map(|x| x.as_bytes())
726    }
727
728    /// Tries to parse session state info, if any.
729    pub fn session_state_info(&self) -> io::Result<Vec<SessionStateInfo<'_>>> {
730        self.session_state_info_ref()
731            .map(|data| {
732                let mut data = ParseBuf(data);
733                let mut entries = Vec::new();
734                while !data.is_empty() {
735                    entries.push(data.parse(())?);
736                }
737                Ok(entries)
738            })
739            .transpose()
740            .map(|x| x.unwrap_or_default())
741    }
742}
743
744#[derive(Debug, Clone, PartialEq, Eq)]
745pub struct OkPacketDeserializer<'de, T>(OkPacket<'de>, PhantomData<T>);
746
747impl<'de, T> OkPacketDeserializer<'de, T> {
748    pub fn into_inner(self) -> OkPacket<'de> {
749        self.0
750    }
751}
752
753impl<'de, T> From<OkPacketDeserializer<'de, T>> for OkPacket<'de> {
754    fn from(x: OkPacketDeserializer<'de, T>) -> Self {
755        x.0
756    }
757}
758
759#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
760#[error("Invalid OK packet header")]
761pub struct InvalidOkPacketHeader;
762
763impl<'de, T: OkPacketKind> MyDeserialize<'de> for OkPacketDeserializer<'de, T> {
764    const SIZE: Option<usize> = None;
765    type Ctx = CapabilityFlags;
766
767    fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
768        if *buf.parse::<RawInt<u8>>(())? == T::HEADER {
769            let body = T::parse_body(capabilities, buf)?;
770            let ok = OkPacket::try_from(body)?;
771            Ok(Self(ok, PhantomData))
772        } else {
773            Err(io::Error::new(
774                io::ErrorKind::InvalidData,
775                InvalidOkPacketHeader,
776            ))
777        }
778    }
779}
780
781/// Progress report information (may be in an error packet of MariaDB server).
782#[derive(Debug, Clone, Eq, PartialEq)]
783pub struct ProgressReport<'a> {
784    stage: RawInt<u8>,
785    max_stage: RawInt<u8>,
786    progress: RawInt<LeU24>,
787    stage_info: RawBytes<'a, LenEnc>,
788}
789
790impl<'a> ProgressReport<'a> {
791    pub fn new(
792        stage: u8,
793        max_stage: u8,
794        progress: u32,
795        stage_info: impl Into<Cow<'a, [u8]>>,
796    ) -> ProgressReport<'a> {
797        ProgressReport {
798            stage: RawInt::new(stage),
799            max_stage: RawInt::new(max_stage),
800            progress: RawInt::new(progress),
801            stage_info: RawBytes::new(stage_info),
802        }
803    }
804
805    /// 1 to max_stage
806    pub fn stage(&self) -> u8 {
807        *self.stage
808    }
809
810    pub fn max_stage(&self) -> u8 {
811        *self.max_stage
812    }
813
814    /// Progress as '% * 1000'
815    pub fn progress(&self) -> u32 {
816        *self.progress
817    }
818
819    /// Status or state name as a byte slice.
820    pub fn stage_info_ref(&self) -> &[u8] {
821        self.stage_info.as_bytes()
822    }
823
824    /// Status or state name as a string (lossy converted).
825    pub fn stage_info_str(&self) -> Cow<'_, str> {
826        self.stage_info.as_str()
827    }
828
829    pub fn into_owned(self) -> ProgressReport<'static> {
830        ProgressReport {
831            stage: self.stage,
832            max_stage: self.max_stage,
833            progress: self.progress,
834            stage_info: self.stage_info.into_owned(),
835        }
836    }
837}
838
839impl<'de> MyDeserialize<'de> for ProgressReport<'de> {
840    const SIZE: Option<usize> = None;
841    type Ctx = ();
842
843    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
844        let mut sbuf: ParseBuf = buf.parse(6)?;
845
846        sbuf.skip(1); // Ignore number of strings.
847
848        Ok(ProgressReport {
849            stage: sbuf.parse_unchecked(())?,
850            max_stage: sbuf.parse_unchecked(())?,
851            progress: sbuf.parse_unchecked(())?,
852            stage_info: buf.parse(())?,
853        })
854    }
855}
856
857impl MySerialize for ProgressReport<'_> {
858    fn serialize(&self, buf: &mut Vec<u8>) {
859        buf.put_u8(1);
860        self.stage.serialize(&mut *buf);
861        self.max_stage.serialize(&mut *buf);
862        self.progress.serialize(&mut *buf);
863        self.stage_info.serialize(buf);
864    }
865}
866
867impl fmt::Display for ProgressReport<'_> {
868    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
869        write!(
870            f,
871            "Stage: {} of {} '{}'  {:.2}% of stage done",
872            self.stage(),
873            self.max_stage(),
874            self.progress(),
875            self.stage_info_str()
876        )
877    }
878}
879
880define_header!(
881    ErrPacketHeader,
882    InvalidErrPacketHeader("Invalid error packet header"),
883    0xFF
884);
885
886/// MySql error packet.
887///
888/// May hold an error or a progress report.
889#[derive(Debug, Clone, PartialEq)]
890pub enum ErrPacket<'a> {
891    Error(ServerError<'a>),
892    Progress(ProgressReport<'a>),
893}
894
895impl ErrPacket<'_> {
896    /// Returns false if this error packet contains progress report.
897    pub fn is_error(&self) -> bool {
898        matches!(self, ErrPacket::Error { .. })
899    }
900
901    /// Returns true if this error packet contains progress report.
902    pub fn is_progress_report(&self) -> bool {
903        !self.is_error()
904    }
905
906    /// Will panic if ErrPacket does not contains progress report
907    pub fn progress_report(&self) -> &ProgressReport<'_> {
908        match *self {
909            ErrPacket::Progress(ref progress_report) => progress_report,
910            _ => panic!("This ErrPacket does not contains progress report"),
911        }
912    }
913
914    /// Will panic if ErrPacket does not contains a `ServerError`.
915    pub fn server_error(&self) -> &ServerError<'_> {
916        match self {
917            ErrPacket::Error(error) => error,
918            ErrPacket::Progress(_) => panic!("This ErrPacket does not contain a ServerError"),
919        }
920    }
921}
922
923impl<'de> MyDeserialize<'de> for ErrPacket<'de> {
924    const SIZE: Option<usize> = None;
925    type Ctx = CapabilityFlags;
926
927    fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
928        let mut sbuf: ParseBuf = buf.parse(3)?;
929        sbuf.parse_unchecked::<ErrPacketHeader>(())?;
930        let code: RawInt<LeU16> = sbuf.parse_unchecked(())?;
931
932        if *code == 0xFFFF && capabilities.contains(CapabilityFlags::CLIENT_PROGRESS_OBSOLETE) {
933            buf.parse(()).map(ErrPacket::Progress)
934        } else {
935            buf.parse((
936                *code,
937                capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41),
938            ))
939            .map(ErrPacket::Error)
940        }
941    }
942}
943
944impl MySerialize for ErrPacket<'_> {
945    fn serialize(&self, buf: &mut Vec<u8>) {
946        ErrPacketHeader::new().serialize(&mut *buf);
947        match self {
948            ErrPacket::Error(server_error) => {
949                server_error.code.serialize(&mut *buf);
950                server_error.serialize(buf);
951            }
952            ErrPacket::Progress(progress_report) => {
953                RawInt::<LeU16>::new(0xFFFF).serialize(&mut *buf);
954                progress_report.serialize(buf);
955            }
956        }
957    }
958}
959
960impl fmt::Display for ErrPacket<'_> {
961    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
962        match self {
963            ErrPacket::Error(server_error) => write!(f, "{}", server_error),
964            ErrPacket::Progress(progress_report) => write!(f, "{}", progress_report),
965        }
966    }
967}
968
// Marker byte `#` that precedes the five sql state bytes (see `SqlState`).
// `InvalidSqlStateMarker` is the error name handed to `define_header!`.
define_header!(
    SqlStateMarker,
    InvalidSqlStateMarker("Invalid SqlStateMarker value"),
    b'#'
);
974
/// MySql error state.
///
/// On the wire this is the `#` marker followed by five state bytes.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct SqlState {
    // The `#` marker that precedes the state bytes.
    __state_marker: SqlStateMarker,
    // The five raw bytes of the state.
    state: [u8; 5],
}
981
982impl SqlState {
983    /// Creates new sql state.
984    pub fn new(state: [u8; 5]) -> Self {
985        Self {
986            __state_marker: SqlStateMarker::new(),
987            state,
988        }
989    }
990
991    /// Returns an sql state as bytes.
992    pub fn as_bytes(&self) -> [u8; 5] {
993        self.state
994    }
995
996    /// Returns an sql state as a string (lossy converted).
997    pub fn as_str(&self) -> Cow<'_, str> {
998        String::from_utf8_lossy(&self.state)
999    }
1000}
1001
1002impl<'de> MyDeserialize<'de> for SqlState {
1003    const SIZE: Option<usize> = Some(6);
1004    type Ctx = ();
1005
1006    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1007        Ok(Self {
1008            __state_marker: buf.parse(())?,
1009            state: buf.parse(())?,
1010        })
1011    }
1012}
1013
1014impl MySerialize for SqlState {
1015    fn serialize(&self, buf: &mut Vec<u8>) {
1016        self.__state_marker.serialize(buf);
1017        self.state.serialize(buf);
1018    }
1019}
1020
/// MySql server error.
///
/// Holds the error code, an optional [`SqlState`] and the error message
/// (this is the payload of the [`ErrPacket::Error`] variant).
#[derive(Debug, Clone, PartialEq)]
pub struct ServerError<'a> {
    // Error code.
    code: RawInt<LeU16>,
    // Sql state; parsed only when `CLIENT_PROTOCOL_41` was negotiated.
    state: Option<SqlState>,
    // Error message; takes the rest of the packet.
    message: RawBytes<'a, EofBytes>,
}
1030
1031impl<'a> ServerError<'a> {
1032    pub fn new(code: u16, state: Option<SqlState>, msg: impl Into<Cow<'a, [u8]>>) -> Self {
1033        Self {
1034            code: RawInt::new(code),
1035            state,
1036            message: RawBytes::new(msg),
1037        }
1038    }
1039
1040    /// Returns an error code.
1041    pub fn error_code(&self) -> u16 {
1042        *self.code
1043    }
1044
1045    /// Returns an sql state.
1046    pub fn sql_state_ref(&self) -> Option<&SqlState> {
1047        self.state.as_ref()
1048    }
1049
1050    /// Returns an error message.
1051    pub fn message_ref(&self) -> &[u8] {
1052        self.message.as_bytes()
1053    }
1054
1055    /// Returns an error message as a string (lossy converted).
1056    pub fn message_str(&self) -> Cow<'_, str> {
1057        self.message.as_str()
1058    }
1059
1060    pub fn into_owned(self) -> ServerError<'static> {
1061        ServerError {
1062            code: self.code,
1063            state: self.state,
1064            message: self.message.into_owned(),
1065        }
1066    }
1067}
1068
1069impl<'de> MyDeserialize<'de> for ServerError<'de> {
1070    const SIZE: Option<usize> = None;
1071    /// An error packet error code + whether CLIENT_PROTOCOL_41 capability was negotiated.
1072    type Ctx = (u16, bool);
1073
1074    fn deserialize((code, protocol_41): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1075        let server_error = if protocol_41 {
1076            ServerError {
1077                code: RawInt::new(code),
1078                state: Some(buf.parse(())?),
1079                message: buf.parse(())?,
1080            }
1081        } else {
1082            ServerError {
1083                code: RawInt::new(code),
1084                state: None,
1085                message: buf.parse(())?,
1086            }
1087        };
1088        Ok(server_error)
1089    }
1090}
1091
1092impl MySerialize for ServerError<'_> {
1093    fn serialize(&self, buf: &mut Vec<u8>) {
1094        if let Some(state) = &self.state {
1095            state.serialize(buf);
1096        }
1097        self.message.serialize(buf);
1098    }
1099}
1100
1101impl fmt::Display for ServerError<'_> {
1102    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1103        let sql_state_str = self
1104            .sql_state_ref()
1105            .map(|s| format!(" ({})", s.as_str()))
1106            .unwrap_or_default();
1107
1108        write!(
1109            f,
1110            "ERROR {}{}: {}",
1111            self.error_code(),
1112            sql_state_str,
1113            self.message_str()
1114        )
1115    }
1116}
1117
// Header marker of a local infile packet: the constant byte `0xFB`.
// `InvalidLocalInfileHeader` is the error name handed to `define_header!`.
define_header!(
    LocalInfileHeader,
    InvalidLocalInfileHeader("Invalid LOCAL_INFILE header"),
    0xFB
);
1123
/// Represents MySql's local infile packet.
///
/// Consists of the local infile header followed by a file name.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct LocalInfilePacket<'a> {
    // Packet header marker.
    __header: LocalInfileHeader,
    // File name; takes the rest of the packet.
    file_name: RawBytes<'a, EofBytes>,
}
1130
1131impl<'a> LocalInfilePacket<'a> {
1132    pub fn new(file_name: impl Into<Cow<'a, [u8]>>) -> Self {
1133        Self {
1134            __header: LocalInfileHeader::new(),
1135            file_name: RawBytes::new(file_name),
1136        }
1137    }
1138
1139    /// Value of the file_name field of a local infile packet as a byte slice.
1140    pub fn file_name_ref(&self) -> &[u8] {
1141        self.file_name.as_bytes()
1142    }
1143
1144    /// Value of the file_name field of a local infile packet as a string (lossy converted).
1145    pub fn file_name_str(&self) -> Cow<'_, str> {
1146        self.file_name.as_str()
1147    }
1148
1149    pub fn into_owned(self) -> LocalInfilePacket<'static> {
1150        LocalInfilePacket {
1151            __header: self.__header,
1152            file_name: self.file_name.into_owned(),
1153        }
1154    }
1155}
1156
1157impl<'de> MyDeserialize<'de> for LocalInfilePacket<'de> {
1158    const SIZE: Option<usize> = None;
1159    type Ctx = ();
1160
1161    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1162        Ok(LocalInfilePacket {
1163            __header: buf.parse(())?,
1164            file_name: buf.parse(())?,
1165        })
1166    }
1167}
1168
1169impl MySerialize for LocalInfilePacket<'_> {
1170    fn serialize(&self, buf: &mut Vec<u8>) {
1171        self.__header.serialize(buf);
1172        self.file_name.serialize(buf);
1173    }
1174}
1175
// Names of the authentication plugins known to `AuthPlugin`
// (used by `AuthPlugin::from_bytes` and `AuthPlugin::as_bytes`).
const MYSQL_OLD_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_old_password";
const MYSQL_NATIVE_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_native_password";
const CACHING_SHA2_PASSWORD_PLUGIN_NAME: &[u8] = b"caching_sha2_password";
const MYSQL_CLEAR_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_clear_password";
1180
/// Authentication data generated for a particular auth plugin
/// (see [`AuthPlugin::gen_data`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthPluginData<'a> {
    /// Auth data for the `mysql_old_password` plugin.
    Old([u8; 8]),
    /// Auth data for the `mysql_native_password` plugin.
    Native([u8; 20]),
    /// Auth data for `sha2_password` and `caching_sha2_password` plugins.
    Sha2([u8; 32]),
    /// Clear password for `mysql_clear_password` plugin.
    Clear(Cow<'a, [u8]>),
}
1192
1193impl AuthPluginData<'_> {
1194    pub fn into_owned(self) -> AuthPluginData<'static> {
1195        match self {
1196            AuthPluginData::Old(x) => AuthPluginData::Old(x),
1197            AuthPluginData::Native(x) => AuthPluginData::Native(x),
1198            AuthPluginData::Sha2(x) => AuthPluginData::Sha2(x),
1199            AuthPluginData::Clear(x) => AuthPluginData::Clear(Cow::Owned(x.into_owned())),
1200        }
1201    }
1202}
1203
1204impl std::ops::Deref for AuthPluginData<'_> {
1205    type Target = [u8];
1206
1207    fn deref(&self) -> &Self::Target {
1208        match self {
1209            Self::Sha2(x) => &x[..],
1210            Self::Native(x) => &x[..],
1211            Self::Old(x) => &x[..],
1212            Self::Clear(x) => &x[..],
1213        }
1214    }
1215}
1216
1217impl MySerialize for AuthPluginData<'_> {
1218    fn serialize(&self, buf: &mut Vec<u8>) {
1219        match self {
1220            Self::Sha2(x) => buf.put_slice(&x[..]),
1221            Self::Native(x) => buf.put_slice(&x[..]),
1222            Self::Old(x) => {
1223                buf.put_slice(&x[..]);
1224                buf.push(0);
1225            }
1226            Self::Clear(x) => {
1227                buf.put_slice(x);
1228                buf.push(0);
1229            }
1230        }
1231    }
1232}
1233
/// Authentication plugin
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum AuthPlugin<'a> {
    /// Old Password Authentication
    MysqlOldPassword,
    /// Client-Side Cleartext Pluggable Authentication
    MysqlClearPassword,
    /// Legacy authentication plugin
    MysqlNativePassword,
    /// Default since MySql v8.0.4
    CachingSha2Password,
    /// Any other plugin, identified by its raw name.
    Other(Cow<'a, [u8]>),
}
1247
1248impl<'de> MyDeserialize<'de> for AuthPlugin<'de> {
1249    const SIZE: Option<usize> = None;
1250    type Ctx = ();
1251
1252    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1253        Ok(Self::from_bytes(buf.eat_all()))
1254    }
1255}
1256
1257impl MySerialize for AuthPlugin<'_> {
1258    fn serialize(&self, buf: &mut Vec<u8>) {
1259        buf.put_slice(self.as_bytes());
1260        buf.put_u8(0);
1261    }
1262}
1263
1264impl<'a> AuthPlugin<'a> {
1265    pub fn from_bytes(name: &'a [u8]) -> AuthPlugin<'a> {
1266        let name = if let [name @ .., 0] = name {
1267            name
1268        } else {
1269            name
1270        };
1271        match name {
1272            CACHING_SHA2_PASSWORD_PLUGIN_NAME => AuthPlugin::CachingSha2Password,
1273            MYSQL_NATIVE_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlNativePassword,
1274            MYSQL_OLD_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlOldPassword,
1275            MYSQL_CLEAR_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlClearPassword,
1276            name => AuthPlugin::Other(Cow::Borrowed(name)),
1277        }
1278    }
1279
1280    pub fn as_bytes(&self) -> &[u8] {
1281        match self {
1282            AuthPlugin::CachingSha2Password => CACHING_SHA2_PASSWORD_PLUGIN_NAME,
1283            AuthPlugin::MysqlNativePassword => MYSQL_NATIVE_PASSWORD_PLUGIN_NAME,
1284            AuthPlugin::MysqlOldPassword => MYSQL_OLD_PASSWORD_PLUGIN_NAME,
1285            AuthPlugin::MysqlClearPassword => MYSQL_CLEAR_PASSWORD_PLUGIN_NAME,
1286            AuthPlugin::Other(name) => name,
1287        }
1288    }
1289
1290    pub fn into_owned(self) -> AuthPlugin<'static> {
1291        match self {
1292            AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1293            AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1294            AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1295            AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1296            AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Owned(name.into_owned())),
1297        }
1298    }
1299
1300    pub fn borrow(&self) -> AuthPlugin<'_> {
1301        match self {
1302            AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1303            AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1304            AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1305            AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1306            AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Borrowed(name.as_ref())),
1307        }
1308    }
1309
1310    /// Generates auth plugin data for this plugin.
1311    ///
1312    /// It'll generate `None` if password is `None` or empty.
1313    ///
1314    /// Note, that you should trim terminating null character from the `nonce`.
1315    pub fn gen_data<'b>(&self, pass: Option<&'b str>, nonce: &[u8]) -> Option<AuthPluginData<'b>> {
1316        use super::scramble::{scramble_323, scramble_native, scramble_sha256};
1317
1318        match pass {
1319            Some(pass) if !pass.is_empty() => match self {
1320                AuthPlugin::CachingSha2Password => {
1321                    scramble_sha256(nonce, pass.as_bytes()).map(AuthPluginData::Sha2)
1322                }
1323                AuthPlugin::MysqlNativePassword => {
1324                    scramble_native(nonce, pass.as_bytes()).map(AuthPluginData::Native)
1325                }
1326                AuthPlugin::MysqlOldPassword => {
1327                    scramble_323(nonce.chunks(8).next().unwrap(), pass.as_bytes())
1328                        .map(AuthPluginData::Old)
1329                }
1330                AuthPlugin::MysqlClearPassword => {
1331                    Some(AuthPluginData::Clear(Cow::Borrowed(pass.as_bytes())))
1332                }
1333                AuthPlugin::Other(_) => None,
1334            },
1335            _ => None,
1336        }
1337    }
1338}
1339
// Header marker of an `AuthMoreData` packet: the constant byte `0x01`.
// `InvalidAuthMoreDataHeader` is the error name handed to `define_header!`.
define_header!(
    AuthMoreDataHeader,
    InvalidAuthMoreDataHeader("Invalid AuthMoreData header"),
    0x01
);
1345
/// Extra auth-data beyond the initial challenge.
///
/// Consists of the `AuthMoreData` header followed by the data itself.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AuthMoreData<'a> {
    // Packet header marker.
    __header: AuthMoreDataHeader,
    // Payload; takes the rest of the packet.
    data: RawBytes<'a, EofBytes>,
}
1352
1353impl<'a> AuthMoreData<'a> {
1354    pub fn new(data: impl Into<Cow<'a, [u8]>>) -> Self {
1355        Self {
1356            __header: AuthMoreDataHeader::new(),
1357            data: RawBytes::new(data),
1358        }
1359    }
1360
1361    pub fn data(&self) -> &[u8] {
1362        self.data.as_bytes()
1363    }
1364
1365    pub fn into_owned(self) -> AuthMoreData<'static> {
1366        AuthMoreData {
1367            __header: self.__header,
1368            data: self.data.into_owned(),
1369        }
1370    }
1371}
1372
1373impl<'de> MyDeserialize<'de> for AuthMoreData<'de> {
1374    const SIZE: Option<usize> = None;
1375    type Ctx = ();
1376
1377    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1378        Ok(Self {
1379            __header: buf.parse(())?,
1380            data: buf.parse(())?,
1381        })
1382    }
1383}
1384
1385impl MySerialize for AuthMoreData<'_> {
1386    fn serialize(&self, buf: &mut Vec<u8>) {
1387        self.__header.serialize(&mut *buf);
1388        self.data.serialize(buf);
1389    }
1390}
1391
// Header marker of a `PublicKeyResponse` packet: the constant byte `0x01`.
// `InvalidPublicKeyResponse` is the error name handed to `define_header!`.
define_header!(
    PublicKeyResponseHeader,
    InvalidPublicKeyResponse("Invalid PublicKeyResponse header"),
    0x01
);
1397
/// A server response to a [`PublicKeyRequest`] containing a public RSA key for authentication protection.
///
/// [`PublicKeyRequest`]: crate::packets::caching_sha2_password::PublicKeyRequest
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct PublicKeyResponse<'a> {
    // Packet header marker.
    __header: PublicKeyResponseHeader,
    // The key; takes the rest of the packet.
    rsa_key: RawBytes<'a, EofBytes>,
}
1406
1407impl<'a> PublicKeyResponse<'a> {
1408    pub fn new(rsa_key: impl Into<Cow<'a, [u8]>>) -> Self {
1409        Self {
1410            __header: PublicKeyResponseHeader::new(),
1411            rsa_key: RawBytes::new(rsa_key),
1412        }
1413    }
1414
1415    /// The server's RSA public key in PEM format.
1416    pub fn rsa_key(&self) -> Cow<'_, str> {
1417        self.rsa_key.as_str()
1418    }
1419}
1420
1421impl<'de> MyDeserialize<'de> for PublicKeyResponse<'de> {
1422    const SIZE: Option<usize> = None;
1423    type Ctx = ();
1424
1425    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1426        Ok(Self {
1427            __header: buf.parse(())?,
1428            rsa_key: buf.parse(())?,
1429        })
1430    }
1431}
1432
1433impl MySerialize for PublicKeyResponse<'_> {
1434    fn serialize(&self, buf: &mut Vec<u8>) {
1435        self.__header.serialize(&mut *buf);
1436        self.rsa_key.serialize(buf);
1437    }
1438}
1439
// Header marker shared by `OldAuthSwitchRequest` and `AuthSwitchRequest`:
// the constant byte `0xFE`.
// `InvalidAuthSwithRequestHeader` is the error name handed to `define_header!`.
define_header!(
    AuthSwitchRequestHeader,
    InvalidAuthSwithRequestHeader("Invalid auth switch request header"),
    0xFE
);
1445
/// Old Authentication Method Switch Request Packet.
///
/// It is sent by the server to request the client to switch to Old Password Authentication
/// if `CLIENT_PLUGIN_AUTH` capability is not supported (by either the client or the server).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OldAuthSwitchRequest {
    // The packet consists of the header marker only.
    __header: AuthSwitchRequestHeader,
}
1454
1455impl OldAuthSwitchRequest {
1456    pub fn new() -> Self {
1457        Self {
1458            __header: AuthSwitchRequestHeader::new(),
1459        }
1460    }
1461
1462    pub const fn auth_plugin(&self) -> AuthPlugin<'static> {
1463        AuthPlugin::MysqlOldPassword
1464    }
1465}
1466
1467impl Default for OldAuthSwitchRequest {
1468    fn default() -> Self {
1469        Self::new()
1470    }
1471}
1472
1473impl<'de> MyDeserialize<'de> for OldAuthSwitchRequest {
1474    const SIZE: Option<usize> = Some(1);
1475    type Ctx = ();
1476
1477    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1478        Ok(Self {
1479            __header: buf.parse(())?,
1480        })
1481    }
1482}
1483
1484impl MySerialize for OldAuthSwitchRequest {
1485    fn serialize(&self, buf: &mut Vec<u8>) {
1486        self.__header.serialize(&mut *buf);
1487    }
1488}
1489
/// Authentication Method Switch Request Packet.
///
/// If both server and client support `CLIENT_PLUGIN_AUTH` capability, server can send this packet
/// to ask client to use another authentication method.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AuthSwitchRequest<'a> {
    // Packet header marker.
    __header: AuthSwitchRequestHeader,
    // Name of the plugin to switch to.
    auth_plugin: RawBytes<'a, NullBytes>,
    // Plugin data; takes the rest of the packet (may end with a zero byte,
    // which the `plugin_data` accessor strips).
    plugin_data: RawBytes<'a, EofBytes>,
}
1500
1501impl<'a> AuthSwitchRequest<'a> {
1502    pub fn new(
1503        auth_plugin: impl Into<Cow<'a, [u8]>>,
1504        plugin_data: impl Into<Cow<'a, [u8]>>,
1505    ) -> Self {
1506        Self {
1507            __header: AuthSwitchRequestHeader::new(),
1508            auth_plugin: RawBytes::new(auth_plugin),
1509            plugin_data: RawBytes::new(plugin_data),
1510        }
1511    }
1512
1513    pub fn auth_plugin(&self) -> AuthPlugin<'_> {
1514        ParseBuf(self.auth_plugin.as_bytes())
1515            .parse(())
1516            .expect("infallible")
1517    }
1518
1519    pub fn plugin_data(&self) -> &[u8] {
1520        match self.plugin_data.as_bytes() {
1521            [head @ .., 0] => head,
1522            all => all,
1523        }
1524    }
1525
1526    pub fn into_owned(self) -> AuthSwitchRequest<'static> {
1527        AuthSwitchRequest {
1528            __header: self.__header,
1529            auth_plugin: self.auth_plugin.into_owned(),
1530            plugin_data: self.plugin_data.into_owned(),
1531        }
1532    }
1533}
1534
1535impl<'de> MyDeserialize<'de> for AuthSwitchRequest<'de> {
1536    const SIZE: Option<usize> = None;
1537    type Ctx = ();
1538
1539    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1540        Ok(Self {
1541            __header: buf.parse(())?,
1542            auth_plugin: buf.parse(())?,
1543            plugin_data: buf.parse(())?,
1544        })
1545    }
1546}
1547
1548impl MySerialize for AuthSwitchRequest<'_> {
1549    fn serialize(&self, buf: &mut Vec<u8>) {
1550        self.__header.serialize(&mut *buf);
1551        self.auth_plugin.serialize(&mut *buf);
1552        self.plugin_data.serialize(buf);
1553    }
1554}
1555
/// Represents MySql's initial handshake packet.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct HandshakePacket<'a> {
    protocol_version: RawInt<u8>,
    server_version: RawBytes<'a, NullBytes>,
    connection_id: RawInt<LeU32>,
    // first 8 bytes of the auth plugin nonce
    scramble_1: [u8; 8],
    __filler: Skip<1>,
    // lower 16 bits of the capability flags
    capabilities_1: Const<CapabilityFlags, LeU32LowerHalf>,
    default_collation: RawInt<u8>,
    status_flags: Const<StatusFlags, LeU16>,
    // upper 16 bits of the capability flags
    capabilities_2: Const<CapabilityFlags, LeU32UpperHalf>,
    auth_plugin_data_len: RawInt<u8>,
    __reserved: Skip<10>,
    // rest of the auth plugin nonce; only parsed if `CLIENT_SECURE_CONNECTION` is set
    scramble_2: Option<RawBytes<'a, BareBytes<{ (u8::MAX as usize) - 8 }>>>,
    // only parsed if `CLIENT_PLUGIN_AUTH` is set
    auth_plugin_name: Option<RawBytes<'a, NullBytes>>,
}
1575
1576impl<'de> MyDeserialize<'de> for HandshakePacket<'de> {
1577    const SIZE: Option<usize> = None;
1578    type Ctx = ();
1579
1580    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1581        let protocol_version = buf.parse(())?;
1582        let server_version = buf.parse(())?;
1583
1584        // includes trailing 10 bytes filler
1585        let mut sbuf: ParseBuf = buf.parse(31)?;
1586        let connection_id = sbuf.parse_unchecked(())?;
1587        let scramble_1 = sbuf.parse_unchecked(())?;
1588        let __filler = sbuf.parse_unchecked(())?;
1589        let capabilities_1: RawConst<LeU32LowerHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1590        let default_collation = sbuf.parse_unchecked(())?;
1591        let status_flags = sbuf.parse_unchecked(())?;
1592        let capabilities_2: RawConst<LeU32UpperHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1593        let auth_plugin_data_len: RawInt<u8> = sbuf.parse_unchecked(())?;
1594        let __reserved = sbuf.parse_unchecked(())?;
1595        let mut scramble_2 = None;
1596        if capabilities_1.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
1597            let len = max(13, auth_plugin_data_len.0 as i8 - 8) as usize;
1598            scramble_2 = buf.parse(len).map(Some)?;
1599        }
1600        let mut auth_plugin_name = None;
1601        if capabilities_2.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
1602            auth_plugin_name = match buf.eat_all() {
1603                [head @ .., 0] => Some(RawBytes::new(head)),
1604                // missing trailing `0` is a known bug in mysql
1605                all => Some(RawBytes::new(all)),
1606            }
1607        }
1608
1609        Ok(Self {
1610            protocol_version,
1611            server_version,
1612            connection_id,
1613            scramble_1,
1614            __filler,
1615            capabilities_1: Const::new(CapabilityFlags::from_bits_truncate(capabilities_1.0)),
1616            default_collation,
1617            status_flags,
1618            capabilities_2: Const::new(CapabilityFlags::from_bits_truncate(capabilities_2.0)),
1619            auth_plugin_data_len,
1620            __reserved,
1621            scramble_2,
1622            auth_plugin_name,
1623        })
1624    }
1625}
1626
1627impl MySerialize for HandshakePacket<'_> {
1628    fn serialize(&self, buf: &mut Vec<u8>) {
1629        self.protocol_version.serialize(&mut *buf);
1630        self.server_version.serialize(&mut *buf);
1631        self.connection_id.serialize(&mut *buf);
1632        self.scramble_1.serialize(&mut *buf);
1633        buf.put_u8(0x00);
1634        self.capabilities_1.serialize(&mut *buf);
1635        self.default_collation.serialize(&mut *buf);
1636        self.status_flags.serialize(&mut *buf);
1637        self.capabilities_2.serialize(&mut *buf);
1638
1639        if self
1640            .capabilities_2
1641            .contains(CapabilityFlags::CLIENT_PLUGIN_AUTH)
1642        {
1643            buf.put_u8(
1644                self.scramble_2
1645                    .as_ref()
1646                    .map(|x| (x.len() + 8) as u8)
1647                    .unwrap_or_default(),
1648            );
1649        } else {
1650            buf.put_u8(0);
1651        }
1652
1653        buf.put_slice(&[0_u8; 10][..]);
1654
1655        // Assume that the packet is well formed:
1656        // * the CLIENT_SECURE_CONNECTION is set.
1657        if let Some(scramble_2) = &self.scramble_2 {
1658            scramble_2.serialize(&mut *buf);
1659        }
1660
1661        // Assume that the packet is well formed:
1662        // * the CLIENT_PLUGIN_AUTH is set.
1663        if let Some(client_plugin_auth) = &self.auth_plugin_name {
1664            client_plugin_auth.serialize(buf);
1665        }
1666    }
1667}
1668
1669impl<'a> HandshakePacket<'a> {
1670    #[allow(clippy::too_many_arguments)]
1671    pub fn new(
1672        protocol_version: u8,
1673        server_version: impl Into<Cow<'a, [u8]>>,
1674        connection_id: u32,
1675        scramble_1: [u8; 8],
1676        scramble_2: Option<impl Into<Cow<'a, [u8]>>>,
1677        capabilities: CapabilityFlags,
1678        default_collation: u8,
1679        status_flags: StatusFlags,
1680        auth_plugin_name: Option<impl Into<Cow<'a, [u8]>>>,
1681    ) -> Self {
1682        // Safety:
1683        // * capabilities are given as a valid CapabilityFlags instance
1684        // * the BitAnd operation can't set new bits
1685        let (capabilities_1, capabilities_2) = (
1686            CapabilityFlags::from_bits_retain(capabilities.bits() & 0x0000_FFFF),
1687            CapabilityFlags::from_bits_retain(capabilities.bits() & 0xFFFF_0000),
1688        );
1689
1690        let scramble_2 = scramble_2.map(RawBytes::new);
1691
1692        HandshakePacket {
1693            protocol_version: RawInt::new(protocol_version),
1694            server_version: RawBytes::new(server_version),
1695            connection_id: RawInt::new(connection_id),
1696            scramble_1,
1697            __filler: Skip,
1698            capabilities_1: Const::new(capabilities_1),
1699            default_collation: RawInt::new(default_collation),
1700            status_flags: Const::new(status_flags),
1701            capabilities_2: Const::new(capabilities_2),
1702            auth_plugin_data_len: RawInt::new(
1703                scramble_2
1704                    .as_ref()
1705                    .map(|x| x.len() as u8)
1706                    .unwrap_or_default(),
1707            ),
1708            __reserved: Skip,
1709            scramble_2,
1710            auth_plugin_name: auth_plugin_name.map(RawBytes::new),
1711        }
1712    }
1713
1714    pub fn into_owned(self) -> HandshakePacket<'static> {
1715        HandshakePacket {
1716            protocol_version: self.protocol_version,
1717            server_version: self.server_version.into_owned(),
1718            connection_id: self.connection_id,
1719            scramble_1: self.scramble_1,
1720            __filler: self.__filler,
1721            capabilities_1: self.capabilities_1,
1722            default_collation: self.default_collation,
1723            status_flags: self.status_flags,
1724            capabilities_2: self.capabilities_2,
1725            auth_plugin_data_len: self.auth_plugin_data_len,
1726            __reserved: self.__reserved,
1727            scramble_2: self.scramble_2.map(|x| x.into_owned()),
1728            auth_plugin_name: self.auth_plugin_name.map(RawBytes::into_owned),
1729        }
1730    }
1731
1732    /// Value of the protocol_version field of an initial handshake packet.
1733    pub fn protocol_version(&self) -> u8 {
1734        self.protocol_version.0
1735    }
1736
1737    /// Value of the server_version field of an initial handshake packet as a byte slice.
1738    pub fn server_version_ref(&self) -> &[u8] {
1739        self.server_version.as_bytes()
1740    }
1741
1742    /// Value of the server_version field of an initial handshake packet as a string
1743    /// (lossy converted).
1744    pub fn server_version_str(&self) -> Cow<'_, str> {
1745        self.server_version.as_str()
1746    }
1747
1748    /// Parsed server version.
1749    ///
1750    /// Will parse first \d+.\d+.\d+ of a server version string (if any).
1751    pub fn server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1752        VERSION_RE
1753            .captures(self.server_version_ref())
1754            .map(|captures| {
1755                // Should not panic because validated with regex
1756                (
1757                    btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1758                    btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1759                    btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1760                )
1761            })
1762    }
1763
1764    /// Parsed mariadb server version.
1765    pub fn maria_db_server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1766        MARIADB_VERSION_RE
1767            .captures(self.server_version_ref())
1768            .map(|captures| {
1769                // Should not panic because validated with regex
1770                (
1771                    btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1772                    btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1773                    btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1774                )
1775            })
1776    }
1777
1778    /// Value of the connection_id field of an initial handshake packet.
1779    pub fn connection_id(&self) -> u32 {
1780        self.connection_id.0
1781    }
1782
1783    /// Value of the scramble_1 field of an initial handshake packet as a byte slice.
1784    pub fn scramble_1_ref(&self) -> &[u8] {
1785        self.scramble_1.as_ref()
1786    }
1787
1788    /// Value of the scramble_2 field of an initial handshake packet as a byte slice.
1789    ///
1790    /// Note that this may include a terminating null character.
1791    pub fn scramble_2_ref(&self) -> Option<&[u8]> {
1792        self.scramble_2.as_ref().map(|x| x.as_bytes())
1793    }
1794
1795    /// Returns concatenated auth plugin nonce.
1796    pub fn nonce(&self) -> Vec<u8> {
1797        let mut out = Vec::from(self.scramble_1_ref());
1798        out.extend_from_slice(self.scramble_2_ref().unwrap_or(&[][..]));
1799
1800        // Trim zero terminator. Fill with zeroes if nonce
1801        // is somehow smaller than 20 bytes.
1802        out.resize(20, 0);
1803        out
1804    }
1805
1806    /// Value of a server capabilities.
1807    pub fn capabilities(&self) -> CapabilityFlags {
1808        self.capabilities_1.0 | self.capabilities_2.0
1809    }
1810
1811    /// Value of the default_collation field of an initial handshake packet.
1812    pub fn default_collation(&self) -> u8 {
1813        self.default_collation.0
1814    }
1815
1816    /// Value of a status flags.
1817    pub fn status_flags(&self) -> StatusFlags {
1818        self.status_flags.0
1819    }
1820
1821    /// Value of the auth_plugin_name field of an initial handshake packet as a byte slice.
1822    pub fn auth_plugin_name_ref(&self) -> Option<&[u8]> {
1823        self.auth_plugin_name.as_ref().map(|x| x.as_bytes())
1824    }
1825
1826    /// Value of the auth_plugin_name field of an initial handshake packet as a string
1827    /// (lossy converted).
1828    pub fn auth_plugin_name_str(&self) -> Option<Cow<'_, str>> {
1829        self.auth_plugin_name.as_ref().map(|x| x.as_str())
1830    }
1831
1832    /// Auth plugin of a handshake packet
1833    pub fn auth_plugin(&self) -> Option<AuthPlugin<'_>> {
1834        self.auth_plugin_name.as_ref().map(|x| match x.as_bytes() {
1835            [name @ .., 0] => ParseBuf(name).parse_unchecked(()).expect("infallible"),
1836            all => ParseBuf(all).parse_unchecked(()).expect("infallible"),
1837        })
1838    }
1839}
1840
// Declares `ComChangeUserHeader` (header value `0x11`) together with the
// `InvalidComChangeUserHeader` error type and its message.
define_header!(
    ComChangeUserHeader,
    InvalidComChangeUserHeader("Invalid COM_CHANGE_USER header"),
    0x11
);
1846
/// `COM_CHANGE_USER` command packet.
///
/// The fields are serialized in declaration order; `more_data` is only written
/// when it is present.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComChangeUser<'a> {
    /// Command header.
    __header: ComChangeUserHeader,
    /// User name (`NullBytes` representation).
    user: RawBytes<'a, NullBytes>,
    /// Authentication plugin data (`U8Bytes` representation).
    // Only CLIENT_SECURE_CONNECTION capable servers are supported
    auth_plugin_data: RawBytes<'a, U8Bytes>,
    /// Database name (`NullBytes` representation).
    database: RawBytes<'a, NullBytes>,
    /// Optional trailing part of the packet (character set, auth plugin, connect attributes).
    more_data: Option<ComChangeUserMoreData<'a>>,
}
1856
1857impl<'a> ComChangeUser<'a> {
1858    pub fn new() -> Self {
1859        Self {
1860            __header: ComChangeUserHeader::new(),
1861            user: Default::default(),
1862            auth_plugin_data: Default::default(),
1863            database: Default::default(),
1864            more_data: None,
1865        }
1866    }
1867
1868    pub fn with_user(mut self, user: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1869        self.user = user.map(RawBytes::new).unwrap_or_default();
1870        self
1871    }
1872
1873    pub fn with_database(mut self, database: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1874        self.database = database.map(RawBytes::new).unwrap_or_default();
1875        self
1876    }
1877
1878    pub fn with_auth_plugin_data(
1879        mut self,
1880        auth_plugin_data: Option<impl Into<Cow<'a, [u8]>>>,
1881    ) -> Self {
1882        self.auth_plugin_data = auth_plugin_data.map(RawBytes::new).unwrap_or_default();
1883        self
1884    }
1885
1886    pub fn with_more_data(mut self, more_data: Option<ComChangeUserMoreData<'a>>) -> Self {
1887        self.more_data = more_data;
1888        self
1889    }
1890
1891    pub fn into_owned(self) -> ComChangeUser<'static> {
1892        ComChangeUser {
1893            __header: self.__header,
1894            user: self.user.into_owned(),
1895            auth_plugin_data: self.auth_plugin_data.into_owned(),
1896            database: self.database.into_owned(),
1897            more_data: self.more_data.map(|x| x.into_owned()),
1898        }
1899    }
1900}
1901
1902impl Default for ComChangeUser<'_> {
1903    fn default() -> Self {
1904        Self::new()
1905    }
1906}
1907
1908impl<'de> MyDeserialize<'de> for ComChangeUser<'de> {
1909    const SIZE: Option<usize> = None;
1910
1911    type Ctx = CapabilityFlags;
1912
1913    fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1914        Ok(Self {
1915            __header: buf.parse(())?,
1916            user: buf.parse(())?,
1917            auth_plugin_data: buf.parse(())?,
1918            database: buf.parse(())?,
1919            more_data: if !buf.is_empty() {
1920                Some(buf.parse(flags)?)
1921            } else {
1922                None
1923            },
1924        })
1925    }
1926}
1927
1928impl MySerialize for ComChangeUser<'_> {
1929    fn serialize(&self, buf: &mut Vec<u8>) {
1930        self.__header.serialize(&mut *buf);
1931        self.user.serialize(&mut *buf);
1932        self.auth_plugin_data.serialize(&mut *buf);
1933        self.database.serialize(&mut *buf);
1934        if let Some(ref more_data) = self.more_data {
1935            more_data.serialize(&mut *buf);
1936        }
1937    }
1938}
1939
/// Optional trailing part of the [`ComChangeUser`] packet.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComChangeUserMoreData<'a> {
    /// Character set (two-byte little-endian integer).
    character_set: RawInt<LeU16>,
    /// Authentication plugin; parsed only if `CLIENT_PLUGIN_AUTH` is set.
    auth_plugin: Option<AuthPlugin<'a>>,
    /// Connection attributes; parsed only if `CLIENT_CONNECT_ATTRS` is set.
    connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
}
1946
1947impl<'a> ComChangeUserMoreData<'a> {
1948    pub fn new(character_set: u16) -> Self {
1949        Self {
1950            character_set: RawInt::new(character_set),
1951            auth_plugin: None,
1952            connect_attributes: None,
1953        }
1954    }
1955
1956    pub fn with_auth_plugin(mut self, auth_plugin: Option<AuthPlugin<'a>>) -> Self {
1957        self.auth_plugin = auth_plugin;
1958        self
1959    }
1960
1961    pub fn with_connect_attributes(
1962        mut self,
1963        connect_attributes: Option<HashMap<String, String>>,
1964    ) -> Self {
1965        self.connect_attributes = connect_attributes.map(|attrs| {
1966            attrs
1967                .into_iter()
1968                .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
1969                .collect()
1970        });
1971        self
1972    }
1973
1974    pub fn into_owned(self) -> ComChangeUserMoreData<'static> {
1975        ComChangeUserMoreData {
1976            character_set: self.character_set,
1977            auth_plugin: self.auth_plugin.map(|x| x.into_owned()),
1978            connect_attributes: self.connect_attributes.map(|x| {
1979                x.into_iter()
1980                    .map(|(k, v)| (k.into_owned(), v.into_owned()))
1981                    .collect()
1982            }),
1983        }
1984    }
1985}
1986
1987// Helper that deserializes connect attributes.
1988fn deserialize_connect_attrs<'de>(
1989    buf: &mut ParseBuf<'de>,
1990) -> io::Result<HashMap<RawBytes<'de, LenEnc>, RawBytes<'de, LenEnc>>> {
1991    let data_len = buf.parse::<RawInt<LenEnc>>(())?;
1992    let mut data: ParseBuf = buf.parse(data_len.0 as usize)?;
1993    let mut attrs = HashMap::new();
1994    while !data.is_empty() {
1995        let key = data.parse::<RawBytes<LenEnc>>(())?;
1996        let value = data.parse::<RawBytes<LenEnc>>(())?;
1997        attrs.insert(key, value);
1998    }
1999    Ok(attrs)
2000}
2001
2002// Helper that serializes connect attributes.
2003fn serialize_connect_attrs<'a>(
2004    connect_attributes: &HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>,
2005    buf: &mut Vec<u8>,
2006) {
2007    let len = connect_attributes
2008        .iter()
2009        .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2010        .sum::<u64>();
2011    buf.put_lenenc_int(len);
2012
2013    for (name, value) in connect_attributes {
2014        name.serialize(&mut *buf);
2015        value.serialize(&mut *buf);
2016    }
2017}
2018
2019impl<'de> MyDeserialize<'de> for ComChangeUserMoreData<'de> {
2020    const SIZE: Option<usize> = None;
2021    type Ctx = CapabilityFlags;
2022
2023    fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2024        // always assume CLIENT_PROTOCOL_41
2025        let character_set = buf.parse(())?;
2026        let mut auth_plugin = None;
2027        let mut connect_attributes = None;
2028
2029        if flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
2030            // plugin name is null-terminated here
2031            match buf.parse::<RawBytes<NullBytes>>(())?.0 {
2032                Cow::Borrowed(bytes) => {
2033                    let mut auth_plugin_buf = ParseBuf(bytes);
2034                    auth_plugin = Some(auth_plugin_buf.parse(())?);
2035                }
2036                _ => unreachable!(),
2037            }
2038        };
2039
2040        if flags.contains(CapabilityFlags::CLIENT_CONNECT_ATTRS) {
2041            connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2042        };
2043
2044        Ok(Self {
2045            character_set,
2046            auth_plugin,
2047            connect_attributes,
2048        })
2049    }
2050}
2051
2052impl MySerialize for ComChangeUserMoreData<'_> {
2053    fn serialize(&self, buf: &mut Vec<u8>) {
2054        self.character_set.serialize(&mut *buf);
2055        if let Some(ref auth_plugin) = self.auth_plugin {
2056            auth_plugin.serialize(&mut *buf);
2057        }
2058        if let Some(ref connect_attributes) = self.connect_attributes {
2059            serialize_connect_attrs(connect_attributes, buf);
2060        } else {
2061            // We'll always act like CLIENT_CONNECT_ATTRS is set,
2062            // this is to avoid looking into the actual connection flags.
2063            serialize_connect_attrs(&Default::default(), buf);
2064        }
2065    }
2066}
2067
/// Actual serialization of this field depends on capability flags values.
///
/// * `Left` (`LenEnc`) is used with `CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA`;
/// * `Right(Left)` (`U8Bytes`) is used with `CLIENT_SECURE_CONNECTION`;
/// * `Right(Right)` (`NullBytes`) is used otherwise.
type ScrambleBuf<'a> =
    Either<RawBytes<'a, LenEnc>, Either<RawBytes<'a, U8Bytes>, RawBytes<'a, NullBytes>>>;
2071
/// Handshake response packet sent by a client.
///
/// Which optional fields are present on the wire is determined by `capabilities`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandshakeResponse<'a> {
    /// Client capability flags.
    capabilities: Const<CapabilityFlags, LeU32>,
    /// Value of the `max_packet_size` field.
    max_packet_size: RawInt<LeU32>,
    /// Collation identifier.
    collation: RawInt<u8>,
    /// Authentication data; its encoding depends on the capability flags.
    scramble_buf: ScrambleBuf<'a>,
    /// User name.
    user: RawBytes<'a, NullBytes>,
    /// Database name (tied to `CLIENT_CONNECT_WITH_DB`).
    db_name: Option<RawBytes<'a, NullBytes>>,
    /// Authentication plugin (tied to `CLIENT_PLUGIN_AUTH`).
    auth_plugin: Option<AuthPlugin<'a>>,
    /// Connection attributes (tied to `CLIENT_CONNECT_ATTRS`).
    connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
}
2083
2084impl<'a> HandshakeResponse<'a> {
2085    #[allow(clippy::too_many_arguments)]
2086    pub fn new(
2087        scramble_buf: Option<impl Into<Cow<'a, [u8]>>>,
2088        server_version: (u16, u16, u16),
2089        user: Option<impl Into<Cow<'a, [u8]>>>,
2090        db_name: Option<impl Into<Cow<'a, [u8]>>>,
2091        auth_plugin: Option<AuthPlugin<'a>>,
2092        mut capabilities: CapabilityFlags,
2093        connect_attributes: Option<HashMap<String, String>>,
2094        max_packet_size: u32,
2095    ) -> Self {
2096        let scramble_buf =
2097            if capabilities.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
2098                Either::Left(RawBytes::new(
2099                    scramble_buf.map(Into::into).unwrap_or_default(),
2100                ))
2101            } else if capabilities.contains(CapabilityFlags::CLIENT_SECURE_CONNECTION) {
2102                Either::Right(Either::Left(RawBytes::new(
2103                    scramble_buf.map(Into::into).unwrap_or_default(),
2104                )))
2105            } else {
2106                Either::Right(Either::Right(RawBytes::new(
2107                    scramble_buf.map(Into::into).unwrap_or_default(),
2108                )))
2109            };
2110
2111        if db_name.is_some() {
2112            capabilities.insert(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2113        } else {
2114            capabilities.remove(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2115        }
2116
2117        if auth_plugin.is_some() {
2118            capabilities.insert(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2119        } else {
2120            capabilities.remove(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2121        }
2122
2123        if connect_attributes.is_some() {
2124            capabilities.insert(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2125        } else {
2126            capabilities.remove(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2127        }
2128
2129        Self {
2130            scramble_buf,
2131            collation: if server_version >= (5, 5, 3) {
2132                RawInt::new(CollationId::UTF8MB4_GENERAL_CI as u8)
2133            } else {
2134                RawInt::new(CollationId::UTF8MB3_GENERAL_CI as u8)
2135            },
2136            user: user.map(RawBytes::new).unwrap_or_default(),
2137            db_name: db_name.map(RawBytes::new),
2138            auth_plugin,
2139            capabilities: Const::new(capabilities),
2140            connect_attributes: connect_attributes.map(|attrs| {
2141                attrs
2142                    .into_iter()
2143                    .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
2144                    .collect()
2145            }),
2146            max_packet_size: RawInt::new(max_packet_size),
2147        }
2148    }
2149
2150    pub fn capabilities(&self) -> CapabilityFlags {
2151        self.capabilities.0
2152    }
2153
2154    pub fn collation(&self) -> u8 {
2155        self.collation.0
2156    }
2157
2158    pub fn scramble_buf(&self) -> &[u8] {
2159        match &self.scramble_buf {
2160            Either::Left(x) => x.as_bytes(),
2161            Either::Right(x) => match x {
2162                Either::Left(x) => x.as_bytes(),
2163                Either::Right(x) => x.as_bytes(),
2164            },
2165        }
2166    }
2167
2168    pub fn user(&self) -> &[u8] {
2169        self.user.as_bytes()
2170    }
2171
2172    pub fn db_name(&self) -> Option<&[u8]> {
2173        self.db_name.as_ref().map(|x| x.as_bytes())
2174    }
2175
2176    pub fn auth_plugin(&self) -> Option<&AuthPlugin<'a>> {
2177        self.auth_plugin.as_ref()
2178    }
2179
2180    #[must_use = "entails computation"]
2181    pub fn connect_attributes(&self) -> Option<HashMap<String, String>> {
2182        self.connect_attributes.as_ref().map(|attrs| {
2183            attrs
2184                .iter()
2185                .map(|(k, v)| (k.as_str().into_owned(), v.as_str().into_owned()))
2186                .collect()
2187        })
2188    }
2189}
2190
2191impl<'de> MyDeserialize<'de> for HandshakeResponse<'de> {
2192    const SIZE: Option<usize> = None;
2193    type Ctx = ();
2194
2195    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2196        let mut sbuf: ParseBuf = buf.parse(4 + 4 + 1 + 23)?;
2197        let client_flags: RawConst<LeU32, CapabilityFlags> = sbuf.parse_unchecked(())?;
2198        let max_packet_size: RawInt<LeU32> = sbuf.parse_unchecked(())?;
2199        let collation = sbuf.parse_unchecked(())?;
2200        sbuf.parse_unchecked::<Skip<23>>(())?;
2201
2202        let user = buf.parse(())?;
2203        let scramble_buf =
2204            if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA.bits() > 0 {
2205                Either::Left(buf.parse(())?)
2206            } else if client_flags.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
2207                Either::Right(Either::Left(buf.parse(())?))
2208            } else {
2209                Either::Right(Either::Right(buf.parse(())?))
2210            };
2211
2212        let mut db_name = None;
2213        if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_WITH_DB.bits() > 0 {
2214            db_name = buf.parse(()).map(Some)?;
2215        }
2216
2217        let mut auth_plugin = None;
2218        if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
2219            let auth_plugin_name = buf.eat_null_str();
2220            auth_plugin = Some(AuthPlugin::from_bytes(auth_plugin_name));
2221        }
2222
2223        let mut connect_attributes = None;
2224        if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_ATTRS.bits() > 0 {
2225            connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2226        }
2227
2228        Ok(Self {
2229            capabilities: Const::new(CapabilityFlags::from_bits_truncate(client_flags.0)),
2230            max_packet_size,
2231            collation,
2232            scramble_buf,
2233            user,
2234            db_name,
2235            auth_plugin,
2236            connect_attributes,
2237        })
2238    }
2239}
2240
2241impl MySerialize for HandshakeResponse<'_> {
2242    fn serialize(&self, buf: &mut Vec<u8>) {
2243        self.capabilities.serialize(&mut *buf);
2244        self.max_packet_size.serialize(&mut *buf);
2245        self.collation.serialize(&mut *buf);
2246        buf.put_slice(&[0; 23]);
2247        self.user.serialize(&mut *buf);
2248        self.scramble_buf.serialize(&mut *buf);
2249
2250        if let Some(db_name) = &self.db_name {
2251            db_name.serialize(&mut *buf);
2252        }
2253
2254        if let Some(auth_plugin) = &self.auth_plugin {
2255            auth_plugin.serialize(&mut *buf);
2256        }
2257
2258        if let Some(attrs) = &self.connect_attributes {
2259            let len = attrs
2260                .iter()
2261                .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2262                .sum::<u64>();
2263            buf.put_lenenc_int(len);
2264
2265            for (name, value) in attrs {
2266                name.serialize(&mut *buf);
2267                value.serialize(&mut *buf);
2268            }
2269        }
2270    }
2271}
2272
/// SSL request packet.
///
/// Fixed-size packet: capabilities, max packet size, character set and 23 filler bytes.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SslRequest {
    /// Client capability flags.
    capabilities: Const<CapabilityFlags, LeU32>,
    /// Value of the `max_packet_size` field.
    max_packet_size: RawInt<LeU32>,
    /// Character set identifier.
    character_set: RawInt<u8>,
    /// 23 filler bytes.
    __skip: Skip<23>,
}
2280
2281impl SslRequest {
2282    pub fn new(capabilities: CapabilityFlags, max_packet_size: u32, character_set: u8) -> Self {
2283        Self {
2284            capabilities: Const::new(capabilities),
2285            max_packet_size: RawInt::new(max_packet_size),
2286            character_set: RawInt::new(character_set),
2287            __skip: Skip,
2288        }
2289    }
2290
2291    pub fn capabilities(&self) -> CapabilityFlags {
2292        self.capabilities.0
2293    }
2294
2295    pub fn max_packet_size(&self) -> u32 {
2296        self.max_packet_size.0
2297    }
2298
2299    pub fn character_set(&self) -> u8 {
2300        self.character_set.0
2301    }
2302}
2303
2304impl<'de> MyDeserialize<'de> for SslRequest {
2305    const SIZE: Option<usize> = Some(4 + 4 + 1 + 23);
2306    type Ctx = ();
2307
2308    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2309        let mut buf: ParseBuf = buf.parse(Self::SIZE.unwrap())?;
2310        let raw_capabilities = buf.parse_unchecked::<RawConst<LeU32, CapabilityFlags>>(())?;
2311        Ok(Self {
2312            capabilities: Const::new(CapabilityFlags::from_bits_truncate(raw_capabilities.0)),
2313            max_packet_size: buf.parse_unchecked(())?,
2314            character_set: buf.parse_unchecked(())?,
2315            __skip: buf.parse_unchecked(())?,
2316        })
2317    }
2318}
2319
2320impl MySerialize for SslRequest {
2321    fn serialize(&self, buf: &mut Vec<u8>) {
2322        self.capabilities.serialize(&mut *buf);
2323        self.max_packet_size.serialize(&mut *buf);
2324        self.character_set.serialize(&mut *buf);
2325        self.__skip.serialize(&mut *buf);
2326    }
2327}
2328
/// Error type attached to the constant `status` byte of a statement packet.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
#[error("Invalid statement packet status")]
pub struct InvalidStmtPacketStatus;
2332
/// Represents MySql's statement packet.
///
/// The packet has a fixed size of 12 bytes.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct StmtPacket {
    /// Status byte, a constant `0x00`.
    status: ConstU8<InvalidStmtPacketStatus, 0x00>,
    /// Statement identifier.
    statement_id: RawInt<LeU32>,
    /// Number of columns.
    num_columns: RawInt<LeU16>,
    /// Number of parameters.
    num_params: RawInt<LeU16>,
    /// One filler byte.
    __skip: Skip<1>,
    /// Warning count.
    warning_count: RawInt<LeU16>,
}
2343
2344impl<'de> MyDeserialize<'de> for StmtPacket {
2345    const SIZE: Option<usize> = Some(12);
2346    type Ctx = ();
2347
2348    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2349        let mut buf: ParseBuf = buf.parse(Self::SIZE.unwrap())?;
2350        Ok(StmtPacket {
2351            status: buf.parse_unchecked(())?,
2352            statement_id: buf.parse_unchecked(())?,
2353            num_columns: buf.parse_unchecked(())?,
2354            num_params: buf.parse_unchecked(())?,
2355            __skip: buf.parse_unchecked(())?,
2356            warning_count: buf.parse_unchecked(())?,
2357        })
2358    }
2359}
2360
2361impl MySerialize for StmtPacket {
2362    fn serialize(&self, buf: &mut Vec<u8>) {
2363        self.status.serialize(&mut *buf);
2364        self.statement_id.serialize(&mut *buf);
2365        self.num_columns.serialize(&mut *buf);
2366        self.num_params.serialize(&mut *buf);
2367        self.__skip.serialize(&mut *buf);
2368        self.warning_count.serialize(&mut *buf);
2369    }
2370}
2371
2372impl StmtPacket {
2373    /// Value of the statement_id field of a statement packet.
2374    pub fn statement_id(&self) -> u32 {
2375        *self.statement_id
2376    }
2377
2378    /// Value of the num_columns field of a statement packet.
2379    pub fn num_columns(&self) -> u16 {
2380        *self.num_columns
2381    }
2382
2383    /// Value of the num_params field of a statement packet.
2384    pub fn num_params(&self) -> u16 {
2385        *self.num_params
2386    }
2387
2388    /// Value of the warning_count field of a statement packet.
2389    pub fn warning_count(&self) -> u16 {
2390        *self.warning_count
2391    }
2392}
2393
/// Null-bitmap.
///
/// `T` supplies the bit offset of the first column (`T::BIT_OFFSET`),
/// `U` is the storage holding the bitmap bytes.
///
/// <http://dev.mysql.com/doc/internals/en/null-bitmap.html>
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct NullBitmap<T, U: AsRef<[u8]> = Vec<u8>>(U, PhantomData<T>);
2399
2400impl<'de, T: SerializationSide> MyDeserialize<'de> for NullBitmap<T, Cow<'de, [u8]>> {
2401    const SIZE: Option<usize> = None;
2402    type Ctx = usize;
2403
2404    fn deserialize(num_columns: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2405        let bitmap_len = Self::bitmap_len(num_columns);
2406        let bytes = buf.checked_eat(bitmap_len).ok_or_else(unexpected_buf_eof)?;
2407        Ok(Self::from_bytes(Cow::Borrowed(bytes)))
2408    }
2409}
2410
2411impl<T: SerializationSide> NullBitmap<T, Vec<u8>> {
2412    /// Creates new null-bitmap for a given number of columns.
2413    pub fn new(num_columns: usize) -> Self {
2414        Self::from_bytes(vec![0; Self::bitmap_len(num_columns)])
2415    }
2416
2417    /// Will read null-bitmap for a given number of columns from `input`.
2418    pub fn read(input: &mut &[u8], num_columns: usize) -> Self {
2419        let bitmap_len = Self::bitmap_len(num_columns);
2420        assert!(input.len() >= bitmap_len);
2421
2422        let bitmap = Self::from_bytes(input[..bitmap_len].to_vec());
2423        *input = &input[bitmap_len..];
2424
2425        bitmap
2426    }
2427}
2428
2429impl<T: SerializationSide, U: AsRef<[u8]>> NullBitmap<T, U> {
2430    pub fn bitmap_len(num_columns: usize) -> usize {
2431        (num_columns + 7 + T::BIT_OFFSET) / 8
2432    }
2433
2434    fn byte_and_bit(&self, column_index: usize) -> (usize, u8) {
2435        let offset = column_index + T::BIT_OFFSET;
2436        let byte = offset / 8;
2437        let bit = 1 << (offset % 8) as u8;
2438
2439        assert!(byte < self.0.as_ref().len());
2440
2441        (byte, bit)
2442    }
2443
2444    /// Creates new null-bitmap from given bytes.
2445    pub fn from_bytes(bytes: U) -> Self {
2446        Self(bytes, PhantomData)
2447    }
2448
2449    /// Returns `true` if given column is `NULL` in this `NullBitmap`.
2450    pub fn is_null(&self, column_index: usize) -> bool {
2451        let (byte, bit) = self.byte_and_bit(column_index);
2452        self.0.as_ref()[byte] & bit > 0
2453    }
2454}
2455
2456impl<T: SerializationSide, U: AsRef<[u8]> + AsMut<[u8]>> NullBitmap<T, U> {
2457    /// Sets flag value for given column.
2458    pub fn set(&mut self, column_index: usize, is_null: bool) {
2459        let (byte, bit) = self.byte_and_bit(column_index);
2460        if is_null {
2461            self.0.as_mut()[byte] |= bit
2462        } else {
2463            self.0.as_mut()[byte] &= !bit
2464        }
2465    }
2466}
2467
2468impl<T, U: AsRef<[u8]>> AsRef<[u8]> for NullBitmap<T, U> {
2469    fn as_ref(&self) -> &[u8] {
2470        self.0.as_ref()
2471    }
2472}
2473
/// Builder for [`ComStmtExecuteRequest`].
#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtExecuteRequestBuilder {
    /// Identifier of the statement to execute.
    pub stmt_id: u32,
}
2478
impl ComStmtExecuteRequestBuilder {
    /// Number of bytes that precede the null-bitmap in a serialized request:
    /// header (1) + statement id (4) + flags (1) + iteration count (4).
    pub const NULL_BITMAP_OFFSET: usize = 10;

    /// Creates a builder for the statement with the given identifier.
    pub fn new(stmt_id: u32) -> Self {
        Self { stmt_id }
    }
}
2486
2487impl ComStmtExecuteRequestBuilder {
2488    pub fn build(self, params: &[Value]) -> (ComStmtExecuteRequest<'_>, bool) {
2489        let bitmap_len = NullBitmap::<ClientSide>::bitmap_len(params.len());
2490
2491        let mut bitmap_bytes = vec![0; bitmap_len];
2492        let mut bitmap = NullBitmap::<ClientSide, _>::from_bytes(&mut bitmap_bytes);
2493        let params = params.iter().collect::<Vec<_>>();
2494
2495        let meta_len = params.len() * 2;
2496
2497        let mut data_len = 0;
2498        for (i, param) in params.iter().enumerate() {
2499            match param.bin_len() as usize {
2500                0 => bitmap.set(i, true),
2501                x => data_len += x,
2502            }
2503        }
2504
2505        let total_len = 10 + bitmap_len + 1 + meta_len + data_len;
2506
2507        let as_long_data = total_len > MAX_PAYLOAD_LEN;
2508
2509        (
2510            ComStmtExecuteRequest {
2511                com_stmt_execute: ConstU8::new(),
2512                stmt_id: RawInt::new(self.stmt_id),
2513                flags: Const::new(CursorType::CURSOR_TYPE_NO_CURSOR),
2514                iteration_count: ConstU32::new(),
2515                params_flags: Const::new(StmtExecuteParamsFlags::NEW_PARAMS_BOUND),
2516                bitmap: RawBytes::new(bitmap_bytes),
2517                params,
2518                as_long_data,
2519            },
2520            as_long_data,
2521        )
2522    }
2523}
2524
// Declares `ComStmtExecuteHeader` for the `COM_STMT_EXECUTE` command; the last
// argument names the associated error type.
define_header!(
    ComStmtExecuteHeader,
    COM_STMT_EXECUTE,
    InvalidComStmtExecuteHeader
);
2530
// Declares `IterationCount`, a `ConstU32`-based constant with the value `1`,
// together with the `InvalidIterationCount` error type and its message.
define_const!(
    ConstU32,
    IterationCount,
    InvalidIterationCount("Invalid iteration count for COM_STMT_EXECUTE"),
    1
);
2537
/// `COM_STMT_EXECUTE` request.
///
/// Created by [`ComStmtExecuteRequestBuilder::build`].
#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtExecuteRequest<'a> {
    /// Command header.
    com_stmt_execute: ComStmtExecuteHeader,
    /// Identifier of the statement to execute.
    stmt_id: RawInt<LeU32>,
    /// Cursor type flags.
    flags: Const<CursorType, u8>,
    /// Iteration count (a constant).
    iteration_count: IterationCount,
    /// Null-bitmap of the parameters.
    // max params / bits per byte = 8192
    bitmap: RawBytes<'a, BareBytes<8192>>,
    /// Flags that follow the null-bitmap.
    params_flags: Const<StmtExecuteParamsFlags, u8>,
    /// Statement parameters.
    params: Vec<&'a Value>,
    /// If `true`, byte parameters are not written by `serialize`.
    as_long_data: bool,
}
2550
2551impl<'a> ComStmtExecuteRequest<'a> {
2552    pub fn stmt_id(&self) -> u32 {
2553        self.stmt_id.0
2554    }
2555
2556    pub fn flags(&self) -> CursorType {
2557        self.flags.0
2558    }
2559
2560    pub fn bitmap(&self) -> &[u8] {
2561        self.bitmap.as_bytes()
2562    }
2563
2564    pub fn params_flags(&self) -> StmtExecuteParamsFlags {
2565        self.params_flags.0
2566    }
2567
2568    pub fn params(&self) -> &[&'a Value] {
2569        self.params.as_ref()
2570    }
2571
2572    pub fn as_long_data(&self) -> bool {
2573        self.as_long_data
2574    }
2575}
2576
2577impl MySerialize for ComStmtExecuteRequest<'_> {
2578    fn serialize(&self, buf: &mut Vec<u8>) {
2579        self.com_stmt_execute.serialize(&mut *buf);
2580        self.stmt_id.serialize(&mut *buf);
2581        self.flags.serialize(&mut *buf);
2582        self.iteration_count.serialize(&mut *buf);
2583
2584        if !self.params.is_empty() {
2585            self.bitmap.serialize(&mut *buf);
2586            self.params_flags.serialize(&mut *buf);
2587        }
2588
2589        for param in &self.params {
2590            let (column_type, flags) = match param {
2591                Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()),
2592                Value::Bytes(_) => (
2593                    ColumnType::MYSQL_TYPE_VAR_STRING,
2594                    StmtExecuteParamFlags::empty(),
2595                ),
2596                Value::Int(_) => (
2597                    ColumnType::MYSQL_TYPE_LONGLONG,
2598                    StmtExecuteParamFlags::empty(),
2599                ),
2600                Value::UInt(_) => (
2601                    ColumnType::MYSQL_TYPE_LONGLONG,
2602                    StmtExecuteParamFlags::UNSIGNED,
2603                ),
2604                Value::Float(_) => (ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty()),
2605                Value::Double(_) => (
2606                    ColumnType::MYSQL_TYPE_DOUBLE,
2607                    StmtExecuteParamFlags::empty(),
2608                ),
2609                Value::Date(..) => (
2610                    ColumnType::MYSQL_TYPE_DATETIME,
2611                    StmtExecuteParamFlags::empty(),
2612                ),
2613                Value::Time(..) => (ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty()),
2614            };
2615
2616            buf.put_slice(&[column_type as u8, flags.bits()]);
2617        }
2618
2619        for param in &self.params {
2620            match **param {
2621                Value::Int(_)
2622                | Value::UInt(_)
2623                | Value::Float(_)
2624                | Value::Double(_)
2625                | Value::Date(..)
2626                | Value::Time(..) => {
2627                    param.serialize(buf);
2628                }
2629                Value::Bytes(_) if !self.as_long_data => {
2630                    param.serialize(buf);
2631                }
2632                Value::Bytes(_) | Value::NULL => {}
2633            }
2634        }
2635    }
2636}
2637
// Declares `ComStmtSendLongDataHeader` for the `COM_STMT_SEND_LONG_DATA` command;
// the last argument names the associated error type.
define_header!(
    ComStmtSendLongDataHeader,
    COM_STMT_SEND_LONG_DATA,
    InvalidComStmtSendLongDataHeader
);
2643
/// `COM_STMT_SEND_LONG_DATA` command packet.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ComStmtSendLongData<'a> {
    /// Command header.
    __header: ComStmtSendLongDataHeader,
    /// Statement identifier.
    stmt_id: RawInt<LeU32>,
    /// Index of the parameter the data belongs to.
    param_index: RawInt<LeU16>,
    /// The data itself (`EofBytes` representation).
    data: RawBytes<'a, EofBytes>,
}
2651
2652impl<'a> ComStmtSendLongData<'a> {
2653    pub fn new(stmt_id: u32, param_index: u16, data: impl Into<Cow<'a, [u8]>>) -> Self {
2654        Self {
2655            __header: ComStmtSendLongDataHeader::new(),
2656            stmt_id: RawInt::new(stmt_id),
2657            param_index: RawInt::new(param_index),
2658            data: RawBytes::new(data),
2659        }
2660    }
2661
2662    pub fn into_owned(self) -> ComStmtSendLongData<'static> {
2663        ComStmtSendLongData {
2664            __header: self.__header,
2665            stmt_id: self.stmt_id,
2666            param_index: self.param_index,
2667            data: self.data.into_owned(),
2668        }
2669    }
2670}
2671
2672impl MySerialize for ComStmtSendLongData<'_> {
2673    fn serialize(&self, buf: &mut Vec<u8>) {
2674        self.__header.serialize(&mut *buf);
2675        self.stmt_id.serialize(&mut *buf);
2676        self.param_index.serialize(&mut *buf);
2677        self.data.serialize(&mut *buf);
2678    }
2679}
2680
/// `COM_STMT_CLOSE` command packet.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct ComStmtClose {
    /// Identifier of the statement to close.
    pub stmt_id: u32,
}
2685
impl ComStmtClose {
    /// Creates the command for the statement with the given identifier.
    pub fn new(stmt_id: u32) -> Self {
        Self { stmt_id }
    }
}
2691
2692impl MySerialize for ComStmtClose {
2693    fn serialize(&self, buf: &mut Vec<u8>) {
2694        buf.put_u8(Command::COM_STMT_CLOSE as u8);
2695        buf.put_u32_le(self.stmt_id);
2696    }
2697}
2698
// Declares `ComRegisterSlaveHeader` for the `COM_REGISTER_SLAVE` command;
// the last argument names the associated error type.
define_header!(
    ComRegisterSlaveHeader,
    COM_REGISTER_SLAVE,
    InvalidComRegisterSlaveHeader
);
2704
/// Registers a slave at the master. Should be sent before requesting binlog events
/// with `COM_BINLOG_DUMP`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ComRegisterSlave<'a> {
    /// Command header.
    header: ComRegisterSlaveHeader,
    /// The slave's server id.
    server_id: RawInt<LeU32>,
    /// The host name or IP address of the slave to be reported to the master during slave
    /// registration. Usually empty.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 255 bytes.
    hostname: RawBytes<'a, U8Bytes>,
    /// The account user name of the slave to be reported to the master during slave registration.
    /// Usually empty.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 255 bytes.
    user: RawBytes<'a, U8Bytes>,
    /// The account password of the slave to be reported to the master during slave registration.
    /// Usually empty.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 255 bytes.
    password: RawBytes<'a, U8Bytes>,
    /// The TCP/IP port number for connecting to the slave, to be reported to the master during
    /// slave registration. Usually left at its default value.
    port: RawInt<LeU16>,
    /// Ignored.
    replication_rank: RawInt<LeU32>,
    /// Usually 0. Appears as "master id" in `SHOW SLAVE HOSTS` on the master. Unknown what else
    /// it impacts.
    master_id: RawInt<LeU32>,
}
2742
2743impl<'a> ComRegisterSlave<'a> {
2744    /// Creates new `ComRegisterSlave` with the given server identifier. Other fields will be empty.
2745    pub fn new(server_id: u32) -> Self {
2746        Self {
2747            header: Default::default(),
2748            server_id: RawInt::new(server_id),
2749            hostname: Default::default(),
2750            user: Default::default(),
2751            password: Default::default(),
2752            port: Default::default(),
2753            replication_rank: Default::default(),
2754            master_id: Default::default(),
2755        }
2756    }
2757
2758    /// Sets the `hostname` field of the packet (maximum length is 255 bytes).
2759    pub fn with_hostname(mut self, hostname: impl Into<Cow<'a, [u8]>>) -> Self {
2760        self.hostname = RawBytes::new(hostname);
2761        self
2762    }
2763
2764    /// Sets the `user` field of the packet (maximum length is 255 bytes).
2765    pub fn with_user(mut self, user: impl Into<Cow<'a, [u8]>>) -> Self {
2766        self.user = RawBytes::new(user);
2767        self
2768    }
2769
2770    /// Sets the `password` field of the packet (maximum length is 255 bytes).
2771    pub fn with_password(mut self, password: impl Into<Cow<'a, [u8]>>) -> Self {
2772        self.password = RawBytes::new(password);
2773        self
2774    }
2775
2776    /// Sets the `port` field of the packet.
2777    pub fn with_port(mut self, port: u16) -> Self {
2778        self.port = RawInt::new(port);
2779        self
2780    }
2781
2782    /// Sets the `replication_rank` field of the packet.
2783    pub fn with_replication_rank(mut self, replication_rank: u32) -> Self {
2784        self.replication_rank = RawInt::new(replication_rank);
2785        self
2786    }
2787
2788    /// Sets the `master_id` field of the packet.
2789    pub fn with_master_id(mut self, master_id: u32) -> Self {
2790        self.master_id = RawInt::new(master_id);
2791        self
2792    }
2793
2794    /// Returns the `server_id` field of the packet.
2795    pub fn server_id(&self) -> u32 {
2796        self.server_id.0
2797    }
2798
2799    /// Returns the raw `hostname` field value.
2800    pub fn hostname_raw(&self) -> &[u8] {
2801        self.hostname.as_bytes()
2802    }
2803
2804    /// Returns the `hostname` field as a UTF-8 string (lossy converted).
2805    pub fn hostname(&'a self) -> Cow<'a, str> {
2806        self.hostname.as_str()
2807    }
2808
2809    /// Returns the raw `user` field value.
2810    pub fn user_raw(&self) -> &[u8] {
2811        self.user.as_bytes()
2812    }
2813
2814    /// Returns the `user` field as a UTF-8 string (lossy converted).
2815    pub fn user(&'a self) -> Cow<'a, str> {
2816        self.user.as_str()
2817    }
2818
2819    /// Returns the raw `password` field value.
2820    pub fn password_raw(&self) -> &[u8] {
2821        self.password.as_bytes()
2822    }
2823
2824    /// Returns the `password` field as a UTF-8 string (lossy converted).
2825    pub fn password(&'a self) -> Cow<'a, str> {
2826        self.password.as_str()
2827    }
2828
2829    /// Returns the `port` field of the packet.
2830    pub fn port(&self) -> u16 {
2831        self.port.0
2832    }
2833
2834    /// Returns the `replication_rank` field of the packet.
2835    pub fn replication_rank(&self) -> u32 {
2836        self.replication_rank.0
2837    }
2838
2839    /// Returns the `master_id` field of the packet.
2840    pub fn master_id(&self) -> u32 {
2841        self.master_id.0
2842    }
2843}
2844
2845impl MySerialize for ComRegisterSlave<'_> {
2846    fn serialize(&self, buf: &mut Vec<u8>) {
2847        self.header.serialize(&mut *buf);
2848        self.server_id.serialize(&mut *buf);
2849        self.hostname.serialize(&mut *buf);
2850        self.user.serialize(&mut *buf);
2851        self.password.serialize(&mut *buf);
2852        self.port.serialize(&mut *buf);
2853        self.replication_rank.serialize(&mut *buf);
2854        self.master_id.serialize(&mut *buf);
2855    }
2856}
2857
2858impl<'de> MyDeserialize<'de> for ComRegisterSlave<'de> {
2859    const SIZE: Option<usize> = None;
2860    type Ctx = ();
2861
2862    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2863        let mut sbuf: ParseBuf = buf.parse(5)?;
2864        let header = sbuf.parse_unchecked(())?;
2865        let server_id = sbuf.parse_unchecked(())?;
2866
2867        let hostname = buf.parse(())?;
2868        let user = buf.parse(())?;
2869        let password = buf.parse(())?;
2870
2871        let mut sbuf: ParseBuf = buf.parse(10)?;
2872        let port = sbuf.parse_unchecked(())?;
2873        let replication_rank = sbuf.parse_unchecked(())?;
2874        let master_id = sbuf.parse_unchecked(())?;
2875
2876        Ok(Self {
2877            header,
2878            server_id,
2879            hostname,
2880            user,
2881            password,
2882            port,
2883            replication_rank,
2884            master_id,
2885        })
2886    }
2887}
2888
// Declares `ComTableDumpHeader`, the type of the `header` field of `ComTableDump`, for the
// `COM_TABLE_DUMP` command. `InvalidComTableDumpHeader` is presumably the error reported
// for a mismatching header (this macro arm is defined elsewhere — confirm there).
define_header!(
    ComTableDumpHeader,
    COM_TABLE_DUMP,
    InvalidComTableDumpHeader
);
2894
/// `COM_TABLE_DUMP` command.
///
/// Consists of the command header followed by the database and table names.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ComTableDump<'a> {
    /// Command header.
    header: ComTableDumpHeader,
    /// Database name.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 255 bytes.
    database: RawBytes<'a, U8Bytes>,
    /// Table name.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 255 bytes.
    table: RawBytes<'a, U8Bytes>,
}
2912
2913impl<'a> ComTableDump<'a> {
2914    /// Creates new instance.
2915    pub fn new(database: impl Into<Cow<'a, [u8]>>, table: impl Into<Cow<'a, [u8]>>) -> Self {
2916        Self {
2917            header: Default::default(),
2918            database: RawBytes::new(database),
2919            table: RawBytes::new(table),
2920        }
2921    }
2922
2923    /// Returns the raw `database` field value.
2924    pub fn database_raw(&self) -> &[u8] {
2925        self.database.as_bytes()
2926    }
2927
2928    /// Returns the `database` field value as a UTF-8 string (lossy converted).
2929    pub fn database(&self) -> Cow<str> {
2930        self.database.as_str()
2931    }
2932
2933    /// Returns the raw `table` field value.
2934    pub fn table_raw(&self) -> &[u8] {
2935        self.table.as_bytes()
2936    }
2937
2938    /// Returns the `table` field value as a UTF-8 string (lossy converted).
2939    pub fn table(&self) -> Cow<str> {
2940        self.table.as_str()
2941    }
2942}
2943
2944impl MySerialize for ComTableDump<'_> {
2945    fn serialize(&self, buf: &mut Vec<u8>) {
2946        self.header.serialize(&mut *buf);
2947        self.database.serialize(&mut *buf);
2948        self.table.serialize(&mut *buf);
2949    }
2950}
2951
2952impl<'de> MyDeserialize<'de> for ComTableDump<'de> {
2953    const SIZE: Option<usize> = None;
2954    type Ctx = ();
2955
2956    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2957        Ok(Self {
2958            header: buf.parse(())?,
2959            database: buf.parse(())?,
2960            table: buf.parse(())?,
2961        })
2962    }
2963}
2964
my_bitflags! {
    BinlogDumpFlags,
    #[error("Unknown flags in the raw value of BinlogDumpFlags (raw={:b})", _0)]
    UnknownBinlogDumpFlags,
    u16,

    /// Flags of the `COM_BINLOG_DUMP` and `COM_BINLOG_DUMP_GTID` commands.
    #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
    pub struct BinlogDumpFlags: u16 {
        /// If there are no more events to send, send an EOF_Packet instead of blocking
        /// the connection.
        const BINLOG_DUMP_NON_BLOCK = 0x01;
        const BINLOG_THROUGH_POSITION = 0x02;
        const BINLOG_THROUGH_GTID = 0x04;
    }
}
2980
// Declares `ComBinlogDumpHeader`, the type of the `header` field of `ComBinlogDump`, for the
// `COM_BINLOG_DUMP` command. `InvalidComBinlogDumpHeader` is presumably the error reported
// for a mismatching header (this macro arm is defined elsewhere — confirm there).
define_header!(
    ComBinlogDumpHeader,
    COM_BINLOG_DUMP,
    InvalidComBinlogDumpHeader
);
2986
/// Command to request a binlog-stream from the master starting at a given position.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ComBinlogDump<'a> {
    /// Command header.
    header: ComBinlogDumpHeader,
    /// Position in the binlog-file to start the stream with (`0` by default).
    pos: RawInt<LeU32>,
    /// Command flags (empty by default).
    ///
    /// Only `BINLOG_DUMP_NON_BLOCK` is supported for this command.
    flags: Const<BinlogDumpFlags, LeU16>,
    /// Server id of this slave.
    server_id: RawInt<LeU32>,
    /// Filename of the binlog on the master.
    ///
    /// If the binlog-filename is empty, the server will send the binlog-stream of the first known
    /// binlog.
    filename: RawBytes<'a, EofBytes>,
}
3005
3006impl<'a> ComBinlogDump<'a> {
3007    /// Creates new instance with default values for `pos` and `flags`.
3008    pub fn new(server_id: u32) -> Self {
3009        Self {
3010            header: Default::default(),
3011            pos: Default::default(),
3012            flags: Default::default(),
3013            server_id: RawInt::new(server_id),
3014            filename: Default::default(),
3015        }
3016    }
3017
3018    /// Defines position for this instance.
3019    pub fn with_pos(mut self, pos: u32) -> Self {
3020        self.pos = RawInt::new(pos);
3021        self
3022    }
3023
3024    /// Defines flags for this instance.
3025    pub fn with_flags(mut self, flags: BinlogDumpFlags) -> Self {
3026        self.flags = Const::new(flags);
3027        self
3028    }
3029
3030    /// Defines filename for this instance.
3031    pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3032        self.filename = RawBytes::new(filename);
3033        self
3034    }
3035
3036    /// Returns parsed `pos` field with unknown bits truncated.
3037    pub fn pos(&self) -> u32 {
3038        *self.pos
3039    }
3040
3041    /// Returns parsed `flags` field with unknown bits truncated.
3042    pub fn flags(&self) -> BinlogDumpFlags {
3043        *self.flags
3044    }
3045
3046    /// Returns parsed `server_id` field with unknown bits truncated.
3047    pub fn server_id(&self) -> u32 {
3048        *self.server_id
3049    }
3050
3051    /// Returns the raw `filename` field value.
3052    pub fn filename_raw(&self) -> &[u8] {
3053        self.filename.as_bytes()
3054    }
3055
3056    /// Returns the `filename` field value as a UTF-8 string (lossy converted).
3057    pub fn filename(&self) -> Cow<str> {
3058        self.filename.as_str()
3059    }
3060}
3061
3062impl MySerialize for ComBinlogDump<'_> {
3063    fn serialize(&self, buf: &mut Vec<u8>) {
3064        self.header.serialize(&mut *buf);
3065        self.pos.serialize(&mut *buf);
3066        self.flags.serialize(&mut *buf);
3067        self.server_id.serialize(&mut *buf);
3068        self.filename.serialize(&mut *buf);
3069    }
3070}
3071
3072impl<'de> MyDeserialize<'de> for ComBinlogDump<'de> {
3073    const SIZE: Option<usize> = None;
3074    type Ctx = ();
3075
3076    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3077        let mut sbuf: ParseBuf = buf.parse(11)?;
3078        Ok(Self {
3079            header: sbuf.parse_unchecked(())?,
3080            pos: sbuf.parse_unchecked(())?,
3081            flags: sbuf.parse_unchecked(())?,
3082            server_id: sbuf.parse_unchecked(())?,
3083            filename: buf.parse(())?,
3084        })
3085    }
3086}
3087
/// GnoInterval. Stored within [`Sid`]
///
/// Both bounds are serialized as 8-byte little-endian integers, `start` first.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct GnoInterval {
    /// First bound of the interval.
    start: RawInt<LeU64>,
    /// Second bound of the interval (`GnoInterval::check_and_new` treats it as exclusive).
    end: RawInt<LeU64>,
}
3094
3095impl GnoInterval {
3096    /// Creates a new interval.
3097    pub fn new(start: u64, end: u64) -> Self {
3098        Self {
3099            start: RawInt::new(start),
3100            end: RawInt::new(end),
3101        }
3102    }
3103    /// Checks if the [start, end) interval is valid and creates it.
3104    pub fn check_and_new(start: u64, end: u64) -> io::Result<Self> {
3105        if start >= end {
3106            return Err(io::Error::new(
3107                io::ErrorKind::InvalidData,
3108                format!("start({}) >= end({}) in GnoInterval", start, end),
3109            ));
3110        }
3111        if start == 0 || end == 0 {
3112            return Err(io::Error::new(
3113                io::ErrorKind::InvalidData,
3114                "Gno can't be zero",
3115            ));
3116        }
3117        Ok(Self::new(start, end))
3118    }
3119}
3120
3121impl MySerialize for GnoInterval {
3122    fn serialize(&self, buf: &mut Vec<u8>) {
3123        self.start.serialize(&mut *buf);
3124        self.end.serialize(&mut *buf);
3125    }
3126}
3127
3128impl<'de> MyDeserialize<'de> for GnoInterval {
3129    const SIZE: Option<usize> = Some(16);
3130    type Ctx = ();
3131
3132    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3133        Ok(Self {
3134            start: buf.parse_unchecked(())?,
3135            end: buf.parse_unchecked(())?,
3136        })
3137    }
3138}
3139
/// Length (in bytes) of a Uuid in `COM_BINLOG_DUMP_GTID` command packet.
pub const UUID_LEN: usize = 16;
3142
/// SID is a part of the `COM_BINLOG_DUMP_GTID` command. It's a GtidSet that
/// has only one Uuid.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Sid<'a> {
    /// Raw bytes of the Uuid.
    uuid: [u8; UUID_LEN],
    /// Intervals that belong to this Uuid.
    intervals: Seq<'a, GnoInterval, LeU64>,
}
3150
3151impl Sid<'_> {
3152    /// Creates a new instance.
3153    pub fn new(uuid: [u8; UUID_LEN]) -> Self {
3154        Self {
3155            uuid,
3156            intervals: Default::default(),
3157        }
3158    }
3159
3160    /// Returns the `uuid` field value.
3161    pub fn uuid(&self) -> [u8; UUID_LEN] {
3162        self.uuid
3163    }
3164
3165    /// Returns the `intervals` field value.
3166    pub fn intervals(&self) -> &[GnoInterval] {
3167        &self.intervals[..]
3168    }
3169
3170    /// Appends an GnoInterval to this block.
3171    pub fn with_interval(mut self, interval: GnoInterval) -> Self {
3172        let mut intervals = self.intervals.0.into_owned();
3173        intervals.push(interval);
3174        self.intervals = Seq::new(intervals);
3175        self
3176    }
3177
3178    /// Sets the `intevals` value for this block.
3179    pub fn with_intervals(mut self, intervals: Vec<GnoInterval>) -> Self {
3180        self.intervals = Seq::new(intervals);
3181        self
3182    }
3183
3184    fn len(&self) -> u64 {
3185        use saturating::Saturating as S;
3186        let mut len = S(UUID_LEN as u64); // SID
3187        len += S(8); // n_intervals
3188        len += S((self.intervals.len() * 16) as u64);
3189        len.0
3190    }
3191}
3192
3193impl MySerialize for Sid<'_> {
3194    fn serialize(&self, buf: &mut Vec<u8>) {
3195        self.uuid.serialize(&mut *buf);
3196        self.intervals.serialize(buf);
3197    }
3198}
3199
3200impl<'de> MyDeserialize<'de> for Sid<'de> {
3201    const SIZE: Option<usize> = None;
3202    type Ctx = ();
3203
3204    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3205        Ok(Self {
3206            uuid: buf.parse(())?,
3207            intervals: buf.parse(())?,
3208        })
3209    }
3210}
3211
3212impl Sid<'_> {
3213    fn wrap_err(msg: String) -> io::Error {
3214        io::Error::new(io::ErrorKind::InvalidInput, msg)
3215    }
3216
3217    fn parse_interval_num(to_parse: &str, full: &str) -> Result<u64, io::Error> {
3218        let n: u64 = to_parse.parse().map_err(|e| {
3219            Sid::wrap_err(format!(
3220                "invalid GnoInterval format: {}, error: {}",
3221                full, e
3222            ))
3223        })?;
3224        Ok(n)
3225    }
3226}
3227
3228impl FromStr for Sid<'_> {
3229    type Err = io::Error;
3230
3231    fn from_str(s: &str) -> Result<Self, Self::Err> {
3232        let (uuid, intervals) = s
3233            .split_once(':')
3234            .ok_or_else(|| Sid::wrap_err(format!("invalid sid format: {}", s)))?;
3235        let uuid = Uuid::parse_str(uuid)
3236            .map_err(|e| Sid::wrap_err(format!("invalid uuid format: {}, error: {}", s, e)))?;
3237        let intervals = intervals
3238            .split(':')
3239            .map(|interval| {
3240                let nums = interval.split('-').collect::<Vec<_>>();
3241                if nums.len() != 1 && nums.len() != 2 {
3242                    return Err(Sid::wrap_err(format!("invalid GnoInterval format: {}", s)));
3243                }
3244                if nums.len() == 1 {
3245                    let start = Sid::parse_interval_num(nums[0], s)?;
3246                    let interval = GnoInterval::check_and_new(start, start + 1)?;
3247                    Ok(interval)
3248                } else {
3249                    let start = Sid::parse_interval_num(nums[0], s)?;
3250                    let end = Sid::parse_interval_num(nums[1], s)?;
3251                    let interval = GnoInterval::check_and_new(start, end + 1)?;
3252                    Ok(interval)
3253                }
3254            })
3255            .collect::<Result<Vec<_>, _>>()?;
3256        Ok(Self {
3257            uuid: *uuid.as_bytes(),
3258            intervals: Seq::new(intervals),
3259        })
3260    }
3261}
3262
// Declares `ComBinlogDumpGtidHeader`, the type of the `header` field of `ComBinlogDumpGtid`,
// for the `COM_BINLOG_DUMP_GTID` command. `InvalidComBinlogDumpGtidHeader` is presumably the
// error reported for a mismatching header (this macro arm is defined elsewhere — confirm there).
define_header!(
    ComBinlogDumpGtidHeader,
    COM_BINLOG_DUMP_GTID,
    InvalidComBinlogDumpGtidHeader
);
3268
/// Command to request a binlog-stream from the master starting at a given position
/// (`COM_BINLOG_DUMP_GTID`).
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ComBinlogDumpGtid<'a> {
    /// Command header.
    header: ComBinlogDumpGtidHeader,
    /// Command flags (empty by default).
    flags: Const<BinlogDumpFlags, LeU16>,
    /// Server id of this slave.
    server_id: RawInt<LeU32>,
    /// Filename of the binlog on the master.
    ///
    /// If the binlog-filename is empty, the server will send the binlog-stream of the first known
    /// binlog.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 2^32 - 1 bytes.
    filename: RawBytes<'a, U32Bytes>,
    /// Position in the binlog-file to start the stream with (`0` by default).
    pos: RawInt<LeU64>,
    /// SID block.
    sid_block: Seq<'a, Sid<'a>, LeU64>,
}
3291
3292impl<'a> ComBinlogDumpGtid<'a> {
3293    /// Creates new instance with default values for `pos`, `data` and `flags` fields.
3294    pub fn new(server_id: u32) -> Self {
3295        Self {
3296            header: Default::default(),
3297            pos: Default::default(),
3298            flags: Default::default(),
3299            server_id: RawInt::new(server_id),
3300            filename: Default::default(),
3301            sid_block: Default::default(),
3302        }
3303    }
3304
3305    /// Returns the `server_id` field value.
3306    pub fn server_id(&self) -> u32 {
3307        self.server_id.0
3308    }
3309
3310    /// Returns the `flags` field value.
3311    pub fn flags(&self) -> BinlogDumpFlags {
3312        self.flags.0
3313    }
3314
3315    /// Returns the `filename` field value.
3316    pub fn filename_raw(&self) -> &[u8] {
3317        self.filename.as_bytes()
3318    }
3319
3320    /// Returns the `filename` field value as a UTF-8 string (lossy converted).
3321    pub fn filename(&self) -> Cow<str> {
3322        self.filename.as_str()
3323    }
3324
3325    /// Returns the `pos` field value.
3326    pub fn pos(&self) -> u64 {
3327        self.pos.0
3328    }
3329
3330    /// Returns the sequence of sids in this packet.
3331    pub fn sids(&self) -> &[Sid<'a>] {
3332        &self.sid_block
3333    }
3334
3335    /// Defines filename for this instance.
3336    pub fn with_filename(self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3337        Self {
3338            header: self.header,
3339            flags: self.flags,
3340            server_id: self.server_id,
3341            filename: RawBytes::new(filename),
3342            pos: self.pos,
3343            sid_block: self.sid_block,
3344        }
3345    }
3346
3347    /// Sets the `server_id` field value.
3348    pub fn with_server_id(mut self, server_id: u32) -> Self {
3349        self.server_id.0 = server_id;
3350        self
3351    }
3352
3353    /// Sets the `flags` field value.
3354    pub fn with_flags(mut self, mut flags: BinlogDumpFlags) -> Self {
3355        if self.sid_block.is_empty() {
3356            flags.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3357        } else {
3358            flags.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3359        }
3360        self.flags.0 = flags;
3361        self
3362    }
3363
3364    /// Sets the `pos` field value.
3365    pub fn with_pos(mut self, pos: u64) -> Self {
3366        self.pos.0 = pos;
3367        self
3368    }
3369
3370    /// Sets the `sid_block` field value.
3371    pub fn with_sid(mut self, sid: Sid<'a>) -> Self {
3372        self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3373        self.sid_block.push(sid);
3374        self
3375    }
3376
3377    /// Sets the `sid_block` field value.
3378    pub fn with_sids(mut self, sids: impl Into<Cow<'a, [Sid<'a>]>>) -> Self {
3379        self.sid_block = Seq::new(sids);
3380        if self.sid_block.is_empty() {
3381            self.flags.0.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3382        } else {
3383            self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3384        }
3385        self
3386    }
3387
3388    fn sid_block_len(&self) -> u32 {
3389        use saturating::Saturating as S;
3390        let mut len = S(8); // n_sids
3391        for sid in self.sid_block.iter() {
3392            len += S(sid.len() as u32);
3393        }
3394        len.0
3395    }
3396}
3397
3398impl MySerialize for ComBinlogDumpGtid<'_> {
3399    fn serialize(&self, buf: &mut Vec<u8>) {
3400        self.header.serialize(&mut *buf);
3401        self.flags.serialize(&mut *buf);
3402        self.server_id.serialize(&mut *buf);
3403        self.filename.serialize(&mut *buf);
3404        self.pos.serialize(&mut *buf);
3405        buf.put_u32_le(self.sid_block_len());
3406        self.sid_block.serialize(&mut *buf);
3407    }
3408}
3409
3410impl<'de> MyDeserialize<'de> for ComBinlogDumpGtid<'de> {
3411    const SIZE: Option<usize> = None;
3412    type Ctx = ();
3413
3414    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3415        let mut sbuf: ParseBuf = buf.parse(7)?;
3416        let header = sbuf.parse_unchecked(())?;
3417        let flags: Const<BinlogDumpFlags, LeU16> = sbuf.parse_unchecked(())?;
3418        let server_id = sbuf.parse_unchecked(())?;
3419
3420        let filename = buf.parse(())?;
3421        let pos = buf.parse(())?;
3422
3423        // `flags` should contain `BINLOG_THROUGH_GTID` flag if sid_block isn't empty
3424        let sid_data_len: RawInt<LeU32> = buf.parse(())?;
3425        let mut buf: ParseBuf = buf.parse(sid_data_len.0 as usize)?;
3426        let sid_block = buf.parse(())?;
3427
3428        Ok(Self {
3429            header,
3430            flags,
3431            server_id,
3432            filename,
3433            pos,
3434            sid_block,
3435        })
3436    }
3437}
3438
// Declares `SemiSyncAckPacketPacketHeader`, the type of the `header` field of
// `SemiSyncAckPacket`, with the header value `0xEF`, together with the
// `InvalidSemiSyncAckPacketPacketHeader` error type carrying the given message.
define_header!(
    SemiSyncAckPacketPacketHeader,
    InvalidSemiSyncAckPacketPacketHeader("Invalid semi-sync ack packet header"),
    0xEF
);
3444
/// For each Semi Sync Binlog Event with the `SEMI_SYNC_ACK_REQ` flag set, the slave has to
/// acknowledge it with a Semi-Sync ACK packet.
// NOTE(review): unlike the other packets in this module, this type derives no traits
// (`Debug`, `Clone`, `Eq`, ...) — confirm whether that is intentional.
pub struct SemiSyncAckPacket<'a> {
    /// Packet header.
    header: SemiSyncAckPacketPacketHeader,
    /// Binlog position being acknowledged.
    position: RawInt<LeU64>,
    /// Binlog filename being acknowledged.
    filename: RawBytes<'a, EofBytes>,
}
3452
3453impl<'a> SemiSyncAckPacket<'a> {
3454    pub fn new(position: u64, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3455        Self {
3456            header: Default::default(),
3457            position: RawInt::new(position),
3458            filename: RawBytes::new(filename),
3459        }
3460    }
3461
3462    /// Sets the `position` field value.
3463    pub fn with_position(mut self, position: u64) -> Self {
3464        self.position.0 = position;
3465        self
3466    }
3467
3468    /// Sets the `filename` field value.
3469    pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3470        self.filename = RawBytes::new(filename);
3471        self
3472    }
3473
3474    /// Returns the `position` field value.
3475    pub fn position(&self) -> u64 {
3476        self.position.0
3477    }
3478
3479    /// Returns the raw `filename` field value.
3480    pub fn filename_raw(&self) -> &[u8] {
3481        self.filename.as_bytes()
3482    }
3483
3484    /// Returns the `filename` field value as a string (lossy converted).
3485    pub fn filename(&self) -> Cow<'_, str> {
3486        self.filename.as_str()
3487    }
3488}
3489
3490impl MySerialize for SemiSyncAckPacket<'_> {
3491    fn serialize(&self, buf: &mut Vec<u8>) {
3492        self.header.serialize(&mut *buf);
3493        self.position.serialize(&mut *buf);
3494        self.filename.serialize(&mut *buf);
3495    }
3496}
3497
3498impl<'de> MyDeserialize<'de> for SemiSyncAckPacket<'de> {
3499    const SIZE: Option<usize> = None;
3500    type Ctx = ();
3501
3502    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3503        let mut sbuf: ParseBuf = buf.parse(9)?;
3504        Ok(Self {
3505            header: sbuf.parse_unchecked(())?,
3506            position: sbuf.parse_unchecked(())?,
3507            filename: buf.parse(())?,
3508        })
3509    }
3510}
3511
3512#[cfg(test)]
3513mod test {
3514    use super::*;
3515    use crate::{
3516        constants::{CapabilityFlags, ColumnFlags, ColumnType, StatusFlags},
3517        proto::{MyDeserialize, MySerialize},
3518    };
3519
3520    proptest::proptest! {
3521        #[test]
3522        fn com_table_dump_roundtrip(database: Vec<u8>, table: Vec<u8>) {
3523            let cmd = ComTableDump::new(database, table);
3524
3525            let mut output = Vec::new();
3526            cmd.serialize(&mut output);
3527
3528            assert_eq!(cmd, ComTableDump::deserialize((), &mut ParseBuf(&output[..]))?);
3529        }
3530
3531        #[test]
3532        fn com_binlog_dump_roundtrip(
3533            server_id: u32,
3534            filename: Vec<u8>,
3535            pos: u32,
3536            flags: u16,
3537        ) {
3538            let cmd = ComBinlogDump::new(server_id)
3539                .with_filename(filename)
3540                .with_pos(pos)
3541                .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));
3542
3543            let mut output = Vec::new();
3544            cmd.serialize(&mut output);
3545
3546            assert_eq!(cmd, ComBinlogDump::deserialize((), &mut ParseBuf(&output[..]))?);
3547        }
3548
3549        #[test]
3550        fn com_register_slave_roundtrip(
3551            server_id: u32,
3552            hostname in r"\w{0,256}",
3553            user in r"\w{0,256}",
3554            password in r"\w{0,256}",
3555            port: u16,
3556            replication_rank: u32,
3557            master_id: u32,
3558        ) {
3559            let cmd = ComRegisterSlave::new(server_id)
3560                .with_hostname(hostname.as_bytes())
3561                .with_user(user.as_bytes())
3562                .with_password(password.as_bytes())
3563                .with_port(port)
3564                .with_replication_rank(replication_rank)
3565                .with_master_id(master_id);
3566
3567            let mut output = Vec::new();
3568            cmd.serialize(&mut output);
3569            let parsed = ComRegisterSlave::deserialize((), &mut ParseBuf(&output[..]))?;
3570
3571            if hostname.len() > 255 || user.len() > 255 || password.len() > 255 {
3572                assert_ne!(cmd, parsed);
3573            } else {
3574                assert_eq!(cmd, parsed);
3575            }
3576        }
3577
3578        #[test]
3579        fn com_binlog_dump_gtid_roundtrip(
3580            flags: u16,
3581            server_id: u32,
3582            filename: Vec<u8>,
3583            pos: u64,
3584            n_sid_blocks in 0_u64..1024,
3585        ) {
3586            let mut cmd = ComBinlogDumpGtid::new(server_id)
3587                .with_filename(filename)
3588                .with_pos(pos)
3589                .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));
3590
3591            let mut sids = Vec::new();
3592            for i in 0..n_sid_blocks {
3593                let mut block = Sid::new([i as u8; 16]);
3594                for j in 0..i {
3595                    block = block.with_interval(GnoInterval::new(i, j));
3596                }
3597                sids.push(block);
3598            }
3599
3600            cmd = cmd.with_sids(sids);
3601
3602            let mut output = Vec::new();
3603            cmd.serialize(&mut output);
3604
3605            assert_eq!(cmd, ComBinlogDumpGtid::deserialize((), &mut ParseBuf(&output[..]))?);
3606        }
3607    }
3608
3609    #[test]
3610    fn should_parse_local_infile_packet() {
3611        const LIP: &[u8] = b"\xfbfile_name";
3612
3613        let lip = LocalInfilePacket::deserialize((), &mut ParseBuf(LIP)).unwrap();
3614        assert_eq!(lip.file_name_str(), "file_name");
3615    }
3616
3617    #[test]
3618    fn should_parse_stmt_packet() {
3619        const SP: &[u8] = b"\x00\x01\x00\x00\x00\x01\x00\x02\x00\x00\x00\x00";
3620        const SP_2: &[u8] = b"\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
3621
3622        let sp = StmtPacket::deserialize((), &mut ParseBuf(SP)).unwrap();
3623        assert_eq!(sp.statement_id(), 0x01);
3624        assert_eq!(sp.num_columns(), 0x01);
3625        assert_eq!(sp.num_params(), 0x02);
3626        assert_eq!(sp.warning_count(), 0x00);
3627
3628        let sp = StmtPacket::deserialize((), &mut ParseBuf(SP_2)).unwrap();
3629        assert_eq!(sp.statement_id(), 0x01);
3630        assert_eq!(sp.num_columns(), 0x00);
3631        assert_eq!(sp.num_params(), 0x00);
3632        assert_eq!(sp.warning_count(), 0x00);
3633    }
3634
3635    #[test]
3636    fn should_parse_handshake_packet() {
3637        const HSP: &[u8] = b"\x0a5.5.5-10.0.17-MariaDB-log\x00\x0b\x00\
3638                             \x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\
3639                             \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x2a\x34\x64\
3640                             \x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00";
3641
3642        const HSP_2: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
3643                               \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
3644                               \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
3645                               \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
3646                               \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
3647                               \x6f\x72\x64\x00";
3648
3649        const HSP_3: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
3650                                \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
3651                                \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
3652                                \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
3653                                \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
3654                                \x6f\x72\x64\x00";
3655
3656        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap();
3657        assert_eq!(hsp.protocol_version(), 0x0a);
3658        assert_eq!(hsp.server_version_str(), "5.5.5-10.0.17-MariaDB-log");
3659        assert_eq!(hsp.server_version_parsed(), Some((5, 5, 5)));
3660        assert_eq!(hsp.maria_db_server_version_parsed(), Some((10, 0, 17)));
3661        assert_eq!(hsp.connection_id(), 0x0b);
3662        assert_eq!(hsp.scramble_1_ref(), b"dvH@I-CJ");
3663        assert_eq!(
3664            hsp.capabilities(),
3665            CapabilityFlags::from_bits_truncate(0xf7ff)
3666        );
3667        assert_eq!(hsp.default_collation(), 0x08);
3668        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3669        assert_eq!(hsp.scramble_2_ref(), Some(&b"*4d|cZwk4^]:\x00"[..]));
3670        assert_eq!(hsp.auth_plugin_name_ref(), None);
3671
3672        let mut output = Vec::new();
3673        hsp.serialize(&mut output);
3674        assert_eq!(&output, HSP);
3675
3676        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_2)).unwrap();
3677        assert_eq!(hsp.protocol_version(), 0x0a);
3678        assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
3679        assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
3680        assert_eq!(hsp.maria_db_server_version_parsed(), None);
3681        assert_eq!(hsp.connection_id(), 0x0a56);
3682        assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
3683        assert_eq!(
3684            hsp.capabilities(),
3685            CapabilityFlags::from_bits_truncate(0xc00fffff)
3686        );
3687        assert_eq!(hsp.default_collation(), 0x08);
3688        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3689        assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
3690        assert_eq!(
3691            hsp.auth_plugin_name_ref(),
3692            Some(&b"mysql_native_password"[..])
3693        );
3694
3695        let mut output = Vec::new();
3696        hsp.serialize(&mut output);
3697        assert_eq!(&output, HSP_2);
3698
3699        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_3)).unwrap();
3700        assert_eq!(hsp.protocol_version(), 0x0a);
3701        assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
3702        assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
3703        assert_eq!(hsp.maria_db_server_version_parsed(), None);
3704        assert_eq!(hsp.connection_id(), 0x0a56);
3705        assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
3706        assert_eq!(
3707            hsp.capabilities(),
3708            CapabilityFlags::from_bits_truncate(0xc00fffff)
3709        );
3710        assert_eq!(hsp.default_collation(), 0x08);
3711        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3712        assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
3713        assert_eq!(
3714            hsp.auth_plugin_name_ref(),
3715            Some(&b"mysql_native_password"[..])
3716        );
3717
3718        let mut output = Vec::new();
3719        hsp.serialize(&mut output);
3720        assert_eq!(&output, HSP_3);
3721    }
3722
3723    #[test]
3724    fn should_parse_err_packet() {
3725        const ERR_PACKET: &[u8] = b"\xff\x48\x04\x23\x48\x59\x30\x30\x30\x4e\x6f\x20\x74\x61\x62\
3726        \x6c\x65\x73\x20\x75\x73\x65\x64";
3727        const ERR_PACKET_NO_STATE: &[u8] = b"\xff\x10\x04\x54\x6f\x6f\x20\x6d\x61\x6e\x79\x20\x63\
3728        \x6f\x6e\x6e\x65\x63\x74\x69\x6f\x6e\x73";
3729        const PROGRESS_PACKET: &[u8] = b"\xff\xff\xff\x01\x01\x0a\xcc\x5b\x00\x0astage name";
3730
3731        let err_packet = ErrPacket::deserialize(
3732            CapabilityFlags::CLIENT_PROTOCOL_41,
3733            &mut ParseBuf(ERR_PACKET),
3734        )
3735        .unwrap();
3736        let err_packet = err_packet.server_error();
3737        assert_eq!(err_packet.error_code(), 1096);
3738        assert_eq!(err_packet.sql_state_ref().unwrap().as_str(), "HY000");
3739        assert_eq!(err_packet.message_str(), "No tables used");
3740
3741        let err_packet =
3742            ErrPacket::deserialize(CapabilityFlags::empty(), &mut ParseBuf(ERR_PACKET_NO_STATE))
3743                .unwrap();
3744        let server_error = err_packet.server_error();
3745        assert_eq!(server_error.error_code(), 1040);
3746        assert_eq!(server_error.sql_state_ref(), None);
3747        assert_eq!(server_error.message_str(), "Too many connections");
3748
3749        let err_packet = ErrPacket::deserialize(
3750            CapabilityFlags::CLIENT_PROGRESS_OBSOLETE,
3751            &mut ParseBuf(PROGRESS_PACKET),
3752        )
3753        .unwrap();
3754        assert!(err_packet.is_progress_report());
3755        let progress_report = err_packet.progress_report();
3756        assert_eq!(progress_report.stage(), 1);
3757        assert_eq!(progress_report.max_stage(), 10);
3758        assert_eq!(progress_report.progress(), 23500);
3759        assert_eq!(progress_report.stage_info_str(), "stage name");
3760    }
3761
3762    #[test]
3763    fn should_parse_column_packet() {
3764        const COLUMN_PACKET: &[u8] = b"\x03def\x06schema\x05table\x09org_table\x04name\
3765              \x08org_name\x0c\x21\x00\x0F\x00\x00\x00\x00\x01\x00\x08\x00\x00";
3766        let column = Column::deserialize((), &mut ParseBuf(COLUMN_PACKET)).unwrap();
3767        assert_eq!(column.schema_str(), "schema");
3768        assert_eq!(column.table_str(), "table");
3769        assert_eq!(column.org_table_str(), "org_table");
3770        assert_eq!(column.name_str(), "name");
3771        assert_eq!(column.org_name_str(), "org_name");
3772        assert_eq!(
3773            column.character_set(),
3774            CollationId::UTF8MB3_GENERAL_CI as u16
3775        );
3776        assert_eq!(column.column_length(), 15);
3777        assert_eq!(column.column_type(), ColumnType::MYSQL_TYPE_DECIMAL);
3778        assert_eq!(column.flags(), ColumnFlags::NOT_NULL_FLAG);
3779        assert_eq!(column.decimals(), 8);
3780    }
3781
3782    #[test]
3783    fn should_parse_auth_switch_request() {
3784        const PAYLOAD: &[u8] = b"\xfe\x6d\x79\x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\
3785                                 \x73\x73\x77\x6f\x72\x64\x00\x7a\x51\x67\x34\x69\x36\x6f\x4e\x79\
3786                                 \x36\x3d\x72\x48\x4e\x2f\x3e\x2d\x62\x29\x41\x00";
3787        let packet = AuthSwitchRequest::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();
3788        assert_eq!(packet.auth_plugin().as_bytes(), b"mysql_native_password",);
3789        assert_eq!(packet.plugin_data(), b"zQg4i6oNy6=rHN/>-b)A",)
3790    }
3791
3792    #[test]
3793    fn should_parse_auth_more_data() {
3794        const PAYLOAD: &[u8] = b"\x01\x04";
3795        let packet = AuthMoreData::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();
3796        assert_eq!(packet.data(), b"\x04",);
3797    }
3798
3799    #[test]
3800    fn should_parse_ok_packet() {
3801        const PLAIN_OK: &[u8] = b"\x00\x01\x00\x02\x00\x00\x00";
3802        const RESULT_SET_TERMINATOR: &[u8] = &[
3803            0xfe, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x42, 0x52, 0x65, 0x61, 0x64, 0x20, 0x31,
3804            0x20, 0x72, 0x6f, 0x77, 0x73, 0x2c, 0x20, 0x31, 0x2e, 0x30, 0x30, 0x20, 0x42, 0x20,
3805            0x69, 0x6e, 0x20, 0x30, 0x2e, 0x30, 0x30, 0x32, 0x20, 0x73, 0x65, 0x63, 0x2e, 0x2c,
3806            0x20, 0x36, 0x31, 0x31, 0x2e, 0x33, 0x34, 0x20, 0x72, 0x6f, 0x77, 0x73, 0x2f, 0x73,
3807            0x65, 0x63, 0x2e, 0x2c, 0x20, 0x36, 0x31, 0x31, 0x2e, 0x33, 0x34, 0x20, 0x42, 0x2f,
3808            0x73, 0x65, 0x63, 0x2e,
3809        ];
3810        const SESS_STATE_SYS_VAR_OK: &[u8] =
3811            b"\x00\x00\x00\x02\x40\x00\x00\x00\x11\x00\x0f\x0a\x61\
3812              \x75\x74\x6f\x63\x6f\x6d\x6d\x69\x74\x03\x4f\x46\x46";
3813        const SESS_STATE_SCHEMA_OK: &[u8] =
3814            b"\x00\x00\x00\x02\x40\x00\x00\x00\x07\x01\x05\x04\x74\x65\x73\x74";
3815        const SESS_STATE_TRACK_OK: &[u8] = b"\x00\x00\x00\x02\x40\x00\x00\x00\x04\x02\x02\x01\x31";
3816        const EOF: &[u8] = b"\xfe\x00\x00\x02\x00";
3817
3818        // packet starting with 0x00 is not an ok packet if it terminates a result set
3819        OkPacketDeserializer::<ResultSetTerminator>::deserialize(
3820            CapabilityFlags::empty(),
3821            &mut ParseBuf(PLAIN_OK),
3822        )
3823        .unwrap_err();
3824
3825        let ok_packet: OkPacket = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3826            CapabilityFlags::empty(),
3827            &mut ParseBuf(PLAIN_OK),
3828        )
3829        .unwrap()
3830        .into();
3831        assert_eq!(ok_packet.affected_rows(), 1);
3832        assert_eq!(ok_packet.last_insert_id(), None);
3833        assert_eq!(
3834            ok_packet.status_flags(),
3835            StatusFlags::SERVER_STATUS_AUTOCOMMIT
3836        );
3837        assert_eq!(ok_packet.warnings(), 0);
3838        assert_eq!(ok_packet.info_ref(), None);
3839        assert_eq!(ok_packet.session_state_info_ref(), None);
3840
3841        let ok_packet: OkPacket = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3842            CapabilityFlags::CLIENT_SESSION_TRACK,
3843            &mut ParseBuf(PLAIN_OK),
3844        )
3845        .unwrap()
3846        .into();
3847        assert_eq!(ok_packet.affected_rows(), 1);
3848        assert_eq!(ok_packet.last_insert_id(), None);
3849        assert_eq!(
3850            ok_packet.status_flags(),
3851            StatusFlags::SERVER_STATUS_AUTOCOMMIT
3852        );
3853        assert_eq!(ok_packet.warnings(), 0);
3854        assert_eq!(ok_packet.info_ref(), None);
3855        assert_eq!(ok_packet.session_state_info_ref(), None);
3856
3857        let ok_packet: OkPacket = OkPacketDeserializer::<ResultSetTerminator>::deserialize(
3858            CapabilityFlags::CLIENT_SESSION_TRACK,
3859            &mut ParseBuf(RESULT_SET_TERMINATOR),
3860        )
3861        .unwrap()
3862        .into();
3863        assert_eq!(ok_packet.affected_rows(), 0);
3864        assert_eq!(ok_packet.last_insert_id(), None);
3865        assert_eq!(ok_packet.status_flags(), StatusFlags::empty());
3866        assert_eq!(ok_packet.warnings(), 0);
3867        assert_eq!(
3868            ok_packet.info_str(),
3869            Some(Cow::Borrowed(
3870                "Read 1 rows, 1.00 B in 0.002 sec., 611.34 rows/sec., 611.34 B/sec."
3871            ))
3872        );
3873        assert_eq!(ok_packet.session_state_info_ref(), None);
3874
3875        let ok_packet: OkPacket = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3876            CapabilityFlags::CLIENT_SESSION_TRACK,
3877            &mut ParseBuf(SESS_STATE_SYS_VAR_OK),
3878        )
3879        .unwrap()
3880        .into();
3881        assert_eq!(ok_packet.affected_rows(), 0);
3882        assert_eq!(ok_packet.last_insert_id(), None);
3883        assert_eq!(
3884            ok_packet.status_flags(),
3885            StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3886        );
3887        assert_eq!(ok_packet.warnings(), 0);
3888        assert_eq!(ok_packet.info_ref(), None);
3889        let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3890
3891        match sess_state_info.decode().unwrap() {
3892            SessionStateChange::SystemVariables(mut vals) => {
3893                let val = vals.pop().unwrap();
3894                assert_eq!(val.name_bytes(), b"autocommit");
3895                assert_eq!(val.value_bytes(), b"OFF");
3896                assert!(vals.is_empty());
3897            }
3898            _ => panic!(),
3899        }
3900
3901        let ok_packet: OkPacket = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3902            CapabilityFlags::CLIENT_SESSION_TRACK,
3903            &mut ParseBuf(SESS_STATE_SCHEMA_OK),
3904        )
3905        .unwrap()
3906        .into();
3907        assert_eq!(ok_packet.affected_rows(), 0);
3908        assert_eq!(ok_packet.last_insert_id(), None);
3909        assert_eq!(
3910            ok_packet.status_flags(),
3911            StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3912        );
3913        assert_eq!(ok_packet.warnings(), 0);
3914        assert_eq!(ok_packet.info_ref(), None);
3915        let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3916        match sess_state_info.decode().unwrap() {
3917            SessionStateChange::Schema(schema) => assert_eq!(schema.as_bytes(), b"test"),
3918            _ => panic!(),
3919        }
3920
3921        let ok_packet: OkPacket = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3922            CapabilityFlags::CLIENT_SESSION_TRACK,
3923            &mut ParseBuf(SESS_STATE_TRACK_OK),
3924        )
3925        .unwrap()
3926        .into();
3927        assert_eq!(ok_packet.affected_rows(), 0);
3928        assert_eq!(ok_packet.last_insert_id(), None);
3929        assert_eq!(
3930            ok_packet.status_flags(),
3931            StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3932        );
3933        assert_eq!(ok_packet.warnings(), 0);
3934        assert_eq!(ok_packet.info_ref(), None);
3935        let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3936        assert_eq!(
3937            sess_state_info.decode().unwrap(),
3938            SessionStateChange::IsTracked(true),
3939        );
3940
3941        let ok_packet: OkPacket = OkPacketDeserializer::<OldEofPacket>::deserialize(
3942            CapabilityFlags::CLIENT_SESSION_TRACK,
3943            &mut ParseBuf(EOF),
3944        )
3945        .unwrap()
3946        .into();
3947        assert_eq!(ok_packet.affected_rows(), 0);
3948        assert_eq!(ok_packet.last_insert_id(), None);
3949        assert_eq!(
3950            ok_packet.status_flags(),
3951            StatusFlags::SERVER_STATUS_AUTOCOMMIT
3952        );
3953        assert_eq!(ok_packet.warnings(), 0);
3954        assert_eq!(ok_packet.info_ref(), None);
3955        assert_eq!(ok_packet.session_state_info_ref(), None);
3956    }
3957
3958    #[test]
3959    fn should_build_handshake_response() {
3960        let flags_without_db_name = CapabilityFlags::from_bits_truncate(0x81aea205);
3961        let response = HandshakeResponse::new(
3962            Some(&[][..]),
3963            (5u16, 5, 5),
3964            Some(&b"root"[..]),
3965            None::<&'static [u8]>,
3966            Some(AuthPlugin::MysqlNativePassword),
3967            flags_without_db_name,
3968            None,
3969            1_u32.to_be(),
3970        );
3971        let mut actual = Vec::new();
3972        response.serialize(&mut actual);
3973
3974        let expected: Vec<u8> = [
3975            0x05, 0xa2, 0xae, 0x81, // client capabilities
3976            0x00, 0x00, 0x00, 0x01, // max packet
3977            0x2d, // charset
3978            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
3979            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved
3980            0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
3981            0x00, // blank scramble
3982            0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
3983            0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, // mysql_native_password
3984        ]
3985        .to_vec();
3986
3987        assert_eq!(expected, actual);
3988
3989        let flags_with_db_name = flags_without_db_name | CapabilityFlags::CLIENT_CONNECT_WITH_DB;
3990        let response = HandshakeResponse::new(
3991            Some(&[][..]),
3992            (5u16, 5, 5),
3993            Some(&b"root"[..]),
3994            Some(&b"mydb"[..]),
3995            Some(AuthPlugin::MysqlNativePassword),
3996            flags_with_db_name,
3997            None,
3998            1_u32.to_be(),
3999        );
4000        let mut actual = Vec::new();
4001        response.serialize(&mut actual);
4002
4003        let expected: Vec<u8> = [
4004            0x0d, 0xa2, 0xae, 0x81, // client capabilities
4005            0x00, 0x00, 0x00, 0x01, // max packet
4006            0x2d, // charset
4007            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4008            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved
4009            0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
4010            0x00, // blank scramble
4011            0x6d, 0x79, 0x64, 0x62, 0x00, // dbname
4012            0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4013            0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, // mysql_native_password
4014        ]
4015        .to_vec();
4016
4017        assert_eq!(expected, actual);
4018
4019        let response = HandshakeResponse::new(
4020            Some(&[][..]),
4021            (5u16, 5, 5),
4022            Some(&b"root"[..]),
4023            Some(&b"mydb"[..]),
4024            Some(AuthPlugin::MysqlNativePassword),
4025            flags_without_db_name,
4026            None,
4027            1_u32.to_be(),
4028        );
4029        let mut actual = Vec::new();
4030        response.serialize(&mut actual);
4031        assert_eq!(expected, actual);
4032
4033        let response = HandshakeResponse::new(
4034            Some(&[][..]),
4035            (5u16, 5, 5),
4036            Some(&b"root"[..]),
4037            Some(&[][..]),
4038            Some(AuthPlugin::MysqlNativePassword),
4039            flags_with_db_name,
4040            None,
4041            1_u32.to_be(),
4042        );
4043        let mut actual = Vec::new();
4044        response.serialize(&mut actual);
4045
4046        let expected: Vec<u8> = [
4047            0x0d, 0xa2, 0xae, 0x81, // client capabilities
4048            0x00, 0x00, 0x00, 0x01, // max packet
4049            0x2d, // charset
4050            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4051            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved
4052            0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
4053            0x00, // blank db_name
4054            0x00, // blank scramble
4055            0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4056            0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, // mysql_native_password
4057        ]
4058        .to_vec();
4059        assert_eq!(expected, actual);
4060    }
4061
4062    #[test]
4063    fn parse_str_to_sid() {
4064        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:23";
4065        let sid = input.parse::<Sid>().unwrap();
4066        let expected_sid = Uuid::parse_str("3E11FA47-71CA-11E1-9E33-C80AA9429562").unwrap();
4067        assert_eq!(sid.uuid, *expected_sid.as_bytes());
4068        assert_eq!(sid.intervals.len(), 1);
4069        assert_eq!(sid.intervals[0].start.0, 23);
4070        assert_eq!(sid.intervals[0].end.0, 24);
4071
4072        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15";
4073        let sid = input.parse::<Sid>().unwrap();
4074        assert_eq!(sid.uuid, *expected_sid.as_bytes());
4075        assert_eq!(sid.intervals.len(), 2);
4076        assert_eq!(sid.intervals[0].start.0, 1);
4077        assert_eq!(sid.intervals[0].end.0, 6);
4078        assert_eq!(sid.intervals[1].start.0, 10);
4079        assert_eq!(sid.intervals[1].end.0, 16);
4080
4081        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562";
4082        let e = input.parse::<Sid>().unwrap_err();
4083        assert_eq!(
4084            e.to_string(),
4085            "invalid sid format: 3E11FA47-71CA-11E1-9E33-C80AA9429562".to_string()
4086        );
4087
4088        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-";
4089        let e = input.parse::<Sid>().unwrap_err();
4090        assert_eq!(e.to_string(), "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-, error: cannot parse integer from empty string".to_string());
4091
4092        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa";
4093        let e = input.parse::<Sid>().unwrap_err();
4094        assert_eq!(e.to_string(), "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa, error: invalid digit found in string".to_string());
4095
4096        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:0-3";
4097        let e = input.parse::<Sid>().unwrap_err();
4098        assert_eq!(e.to_string(), "Gno can't be zero".to_string());
4099
4100        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:4-3";
4101        let e = input.parse::<Sid>().unwrap_err();
4102        assert_eq!(
4103            e.to_string(),
4104            "start(4) >= end(4) in GnoInterval".to_string()
4105        );
4106    }
4107
4108    #[test]
4109    fn should_parse_rsa_public_key_response_packet() {
4110        const PUBLIC_RSA_KEY_RESPONSE: &[u8] = b"\x01test";
4111
4112        let rsa_public_key_response =
4113            PublicKeyResponse::deserialize((), &mut ParseBuf(PUBLIC_RSA_KEY_RESPONSE));
4114
4115        assert!(rsa_public_key_response.is_ok());
4116        assert_eq!(rsa_public_key_response.unwrap().rsa_key(), "test");
4117    }
4118
4119    #[test]
4120    fn should_build_rsa_public_key_response_packet() {
4121        let rsa_public_key_response = PublicKeyResponse::new("test".as_bytes());
4122
4123        let mut actual = Vec::new();
4124        rsa_public_key_response.serialize(&mut actual);
4125
4126        let expected = b"\x01test".to_vec();
4127
4128        assert_eq!(expected, actual);
4129    }
4130}