mysql_common/packets/
mod.rs

1// Copyright (c) 2017 Anatoly Ikorsky
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use btoi::btoi;
10use bytes::BufMut;
11use regex::bytes::Regex;
12use uuid::Uuid;
13
14use std::str::FromStr;
15use std::sync::{Arc, LazyLock};
16use std::{
17    borrow::Cow, cmp::max, collections::HashMap, convert::TryFrom, fmt, io, marker::PhantomData,
18};
19
20use crate::collations::CollationId;
21use crate::scramble::create_response_for_ed25519;
22use crate::{
23    constants::{
24        CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, MAX_PAYLOAD_LEN,
25        SessionStateType, StatusFlags, StmtExecuteParamFlags, StmtExecuteParamsFlags,
26    },
27    io::{BufMutExt, ParseBuf},
28    misc::{
29        lenenc_str_len,
30        raw::{
31            Const, Either, RawBytes, RawConst, RawInt, Skip,
32            bytes::{
33                BareBytes, ConstBytes, ConstBytesValue, EofBytes, LenEnc, NullBytes, U8Bytes,
34                U32Bytes,
35            },
36            int::{ConstU8, ConstU32, LeU16, LeU24, LeU32, LeU32LowerHalf, LeU32UpperHalf, LeU64},
37            seq::Seq,
38        },
39        unexpected_buf_eof,
40    },
41    proto::{MyDeserialize, MySerialize},
42    value::{ClientSide, SerializationSide, Value},
43};
44
45use self::session_state_change::SessionStateChange;
46
47static MARIADB_VERSION_RE: LazyLock<Regex> =
48    LazyLock::new(|| Regex::new(r"^(?:5.5.5-)?(\d{1,2})\.(\d{1,2})\.(\d{1,3})-MariaDB").unwrap());
49static VERSION_RE: LazyLock<Regex> =
50    LazyLock::new(|| Regex::new(r"^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)").unwrap());
51
/// Defines a zero-sized error type and a single-byte header type alias.
///
/// * `define_header!(Header, Error("message"), 0xNN)` — `Header` is
///   `ConstU8<Error, 0xNN>`, and `Error` displays `"message"`.
/// * `define_header!(Header, Cmd, Error)` — `Header` is `ConstU8<Error, { Command::Cmd as u8 }>`,
///   and `Error`'s message is derived from the command name.
macro_rules! define_header {
    ($header:ident, $error:ident($message:literal), $byte:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($message)]
        pub struct $error;

        pub type $header = crate::misc::raw::int::ConstU8<$error, $byte>;
    };
    ($header:ident, $command:ident, $error:ident) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error("Invalid header for {}", stringify!($command))]
        pub struct $error;

        pub type $header =
            crate::misc::raw::int::ConstU8<$error, { Command::$command as u8 }>;
    };
}
66
/// Defines a zero-sized error type `$error` that displays `$message`, together with the alias
/// `$alias = $wrapper<$error, $value>` (e.g. `$wrapper` = `ConstU8`).
macro_rules! define_const {
    ($wrapper:ident, $alias:ident, $error:ident($message:literal), $value:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($message)]
        pub struct $error;

        pub type $alias = $wrapper<$error, $value>;
    };
}
75
/// Defines a constant byte-string type.
///
/// Expands to:
/// * an error type `$error` displaying `$message`;
/// * a marker type `$marker` implementing `ConstBytesValue<$len>`, with `VALUE = $bytes`;
/// * the alias `$alias = ConstBytes<$marker, $len>`.
macro_rules! define_const_bytes {
    ($marker:ident, $alias:ident, $error:ident($message:literal), $bytes:expr_2021, $len:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($message)]
        pub struct $error;

        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
        pub struct $marker;

        impl ConstBytesValue<$len> for $marker {
            type Error = $error;
            const VALUE: [u8; $len] = $bytes;
        }

        pub type $alias = ConstBytes<$marker, $len>;
    };
}
93
94pub mod binlog_request;
95pub mod caching_sha2_password;
96pub mod session_state_change;
97
98define_const_bytes!(
99    Catalog,
100    ColumnDefinitionCatalog,
101    InvalidCatalog("Invalid catalog value in the column definition"),
102    *b"\x03def",
103    4
104);
105
106define_const!(
107    ConstU8,
108    FixedLengthFieldsLen,
109    InvalidFixedLengthFieldsLen("Invalid fixed length field length in the column definition"),
110    0x0c
111);
112
/// Dynamically-sized column metadata — a part of the [`Column`] packet.
///
/// Every field is a length-encoded byte string. Fields are declared in the order in which
/// this type's `MyDeserialize` impl reads them.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
struct ColumnMeta<'a> {
    // Schema name.
    schema: RawBytes<'a, LenEnc>,
    // Table name (the original table name is kept separately in `org_table`).
    table: RawBytes<'a, LenEnc>,
    // Original table name.
    org_table: RawBytes<'a, LenEnc>,
    // Column name (the original column name is kept separately in `org_name`).
    name: RawBytes<'a, LenEnc>,
    // Original column name.
    org_name: RawBytes<'a, LenEnc>,
}
122
123impl ColumnMeta<'_> {
124    pub fn into_owned(self) -> ColumnMeta<'static> {
125        ColumnMeta {
126            schema: self.schema.into_owned(),
127            table: self.table.into_owned(),
128            org_table: self.org_table.into_owned(),
129            name: self.name.into_owned(),
130            org_name: self.org_name.into_owned(),
131        }
132    }
133
134    /// Returns the value of the [`ColumnMeta::schema`] field as a byte slice.
135    pub fn schema_ref(&self) -> &[u8] {
136        self.schema.as_bytes()
137    }
138
139    /// Returns the value of the [`ColumnMeta::schema`] field as a string (lossy converted).
140    pub fn schema_str(&self) -> Cow<'_, str> {
141        String::from_utf8_lossy(self.schema_ref())
142    }
143
144    /// Returns the value of the [`ColumnMeta::table`] field as a byte slice.
145    pub fn table_ref(&self) -> &[u8] {
146        self.table.as_bytes()
147    }
148
149    /// Returns the value of the [`ColumnMeta::table`] field as a string (lossy converted).
150    pub fn table_str(&self) -> Cow<'_, str> {
151        String::from_utf8_lossy(self.table_ref())
152    }
153
154    /// Returns the value of the [`ColumnMeta::org_table`] field as a byte slice.
155    ///
156    /// "org_table" is for original table name.
157    pub fn org_table_ref(&self) -> &[u8] {
158        self.org_table.as_bytes()
159    }
160
161    /// Returns the value of the [`ColumnMeta::org_table`] field as a string (lossy converted).
162    pub fn org_table_str(&self) -> Cow<'_, str> {
163        String::from_utf8_lossy(self.org_table_ref())
164    }
165
166    /// Returns the value of the [`ColumnMeta::name`] field as a byte slice.
167    pub fn name_ref(&self) -> &[u8] {
168        self.name.as_bytes()
169    }
170
171    /// Returns the value of the [`ColumnMeta::name`] field as a string (lossy converted).
172    pub fn name_str(&self) -> Cow<'_, str> {
173        String::from_utf8_lossy(self.name_ref())
174    }
175
176    /// Returns the value of the [`ColumnMeta::org_name`] field as a byte slice.
177    ///
178    /// "org_name" is for original column name.
179    pub fn org_name_ref(&self) -> &[u8] {
180        self.org_name.as_bytes()
181    }
182
183    /// Returns value of the [`ColumnMeta::org_name`] field as a string (lossy converted).
184    pub fn org_name_str(&self) -> Cow<'_, str> {
185        String::from_utf8_lossy(self.org_name_ref())
186    }
187}
188
189impl<'de> MyDeserialize<'de> for ColumnMeta<'de> {
190    const SIZE: Option<usize> = None;
191    type Ctx = ();
192
193    fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
194        Ok(Self {
195            schema: buf.parse_unchecked(())?,
196            table: buf.parse_unchecked(())?,
197            org_table: buf.parse_unchecked(())?,
198            name: buf.parse_unchecked(())?,
199            org_name: buf.parse_unchecked(())?,
200        })
201    }
202}
203
204impl MySerialize for ColumnMeta<'_> {
205    fn serialize(&self, buf: &mut Vec<u8>) {
206        self.schema.serialize(&mut *buf);
207        self.table.serialize(&mut *buf);
208        self.org_table.serialize(&mut *buf);
209        self.name.serialize(&mut *buf);
210        self.org_name.serialize(&mut *buf);
211    }
212}
213
/// Represents MySql Column (column packet).
///
/// Field declaration order is not the wire order of the fixed-length tail. See this type's
/// `MyDeserialize` impl for the order in which the fields are read.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Column {
    // Constant catalog value (length-prefixed "def").
    catalog: ColumnDefinitionCatalog,
    // Length-encoded names, shared behind an `Arc`.
    meta: Arc<ColumnMeta<'static>>,
    // Length of the fixed-length tail that follows (constant 0x0c).
    fixed_length_fields_len: FixedLengthFieldsLen,
    column_length: RawInt<LeU32>,
    character_set: RawInt<LeU16>,
    column_type: Const<ColumnType, u8>,
    flags: Const<ColumnFlags, LeU16>,
    decimals: RawInt<u8>,
    // Two trailing filler bytes.
    __filler: Skip<2>,
    // COM_FIELD_LIST is deprecated, so we won't support it
}
228
229impl<'de> MyDeserialize<'de> for Column {
230    const SIZE: Option<usize> = None;
231    type Ctx = ();
232
233    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
234        let catalog = buf.parse(())?;
235        let meta = Arc::new(buf.parse::<ColumnMeta<'_>>(())?.into_owned());
236        let mut buf: ParseBuf<'_> = buf.parse(13)?;
237
238        Ok(Column {
239            catalog,
240            meta,
241            fixed_length_fields_len: buf.parse_unchecked(())?,
242            character_set: buf.parse_unchecked(())?,
243            column_length: buf.parse_unchecked(())?,
244            column_type: buf.parse_unchecked(())?,
245            flags: buf.parse_unchecked(())?,
246            decimals: buf.parse_unchecked(())?,
247            __filler: buf.parse_unchecked(())?,
248        })
249    }
250}
251
252impl MySerialize for Column {
253    fn serialize(&self, buf: &mut Vec<u8>) {
254        self.catalog.serialize(&mut *buf);
255        self.meta.serialize(&mut *buf);
256        self.fixed_length_fields_len.serialize(&mut *buf);
257        self.column_length.serialize(&mut *buf);
258        self.character_set.serialize(&mut *buf);
259        self.column_type.serialize(&mut *buf);
260        self.flags.serialize(&mut *buf);
261        self.decimals.serialize(&mut *buf);
262        self.__filler.serialize(&mut *buf);
263    }
264}
265
266impl Column {
267    pub fn new(column_type: ColumnType) -> Self {
268        Self {
269            catalog: Default::default(),
270            meta: Default::default(),
271            fixed_length_fields_len: Default::default(),
272            column_length: Default::default(),
273            character_set: Default::default(),
274            flags: Default::default(),
275            column_type: Const::new(column_type),
276            decimals: Default::default(),
277            __filler: Skip,
278        }
279    }
280
281    pub fn with_schema(mut self, schema: &[u8]) -> Self {
282        Arc::make_mut(&mut self.meta).schema = RawBytes::new(schema).into_owned();
283        self
284    }
285
286    pub fn with_table(mut self, table: &[u8]) -> Self {
287        Arc::make_mut(&mut self.meta).table = RawBytes::new(table).into_owned();
288        self
289    }
290
291    pub fn with_org_table(mut self, org_table: &[u8]) -> Self {
292        Arc::make_mut(&mut self.meta).org_table = RawBytes::new(org_table).into_owned();
293        self
294    }
295
296    pub fn with_name(mut self, name: &[u8]) -> Self {
297        Arc::make_mut(&mut self.meta).name = RawBytes::new(name).into_owned();
298        self
299    }
300
301    pub fn with_org_name(mut self, org_name: &[u8]) -> Self {
302        Arc::make_mut(&mut self.meta).org_name = RawBytes::new(org_name).into_owned();
303        self
304    }
305
306    pub fn with_flags(mut self, flags: ColumnFlags) -> Self {
307        self.flags = Const::new(flags);
308        self
309    }
310
311    pub fn with_column_length(mut self, column_length: u32) -> Self {
312        self.column_length = RawInt::new(column_length);
313        self
314    }
315
316    pub fn with_character_set(mut self, character_set: u16) -> Self {
317        self.character_set = RawInt::new(character_set);
318        self
319    }
320
321    pub fn with_decimals(mut self, decimals: u8) -> Self {
322        self.decimals = RawInt::new(decimals);
323        self
324    }
325
326    /// Returns value of the column_length field of a column packet.
327    ///
328    /// Can be used for text-output formatting.
329    pub fn column_length(&self) -> u32 {
330        *self.column_length
331    }
332
333    /// Returns value of the column_type field of a column packet.
334    pub fn column_type(&self) -> ColumnType {
335        *self.column_type
336    }
337
338    /// Returns value of the character_set field of a column packet.
339    pub fn character_set(&self) -> u16 {
340        *self.character_set
341    }
342
343    /// Returns value of the flags field of a column packet.
344    pub fn flags(&self) -> ColumnFlags {
345        *self.flags
346    }
347
348    /// Returns value of the decimals field of a column packet.
349    ///
350    /// Max shown decimal digits. Can be used for text-output formatting
351    ///
352    /// *   `0x00` for integers and static strings
353    /// *   `0x1f` for dynamic strings, double, float
354    /// *   `0x00..=0x51` for decimals
355    pub fn decimals(&self) -> u8 {
356        *self.decimals
357    }
358
359    /// Returns value of the schema field of a column packet as a byte slice.
360    #[inline(always)]
361    pub fn schema_ref(&self) -> &[u8] {
362        self.meta.schema_ref()
363    }
364
365    /// Returns value of the schema field of a column packet as a string (lossy converted).
366    #[inline(always)]
367    pub fn schema_str(&self) -> Cow<'_, str> {
368        self.meta.schema_str()
369    }
370
371    /// Returns value of the table field of a column packet as a byte slice.
372    #[inline(always)]
373    pub fn table_ref(&self) -> &[u8] {
374        self.meta.table_ref()
375    }
376
377    /// Returns value of the table field of a column packet as a string (lossy converted).
378    #[inline(always)]
379    pub fn table_str(&self) -> Cow<'_, str> {
380        self.meta.table_str()
381    }
382
383    /// Returns value of the org_table field of a column packet as a byte slice.
384    ///
385    /// "org_table" is for original table name.
386    #[inline(always)]
387    pub fn org_table_ref(&self) -> &[u8] {
388        self.meta.org_table_ref()
389    }
390
391    /// Returns value of the org_table field of a column packet as a string (lossy converted).
392    #[inline(always)]
393    pub fn org_table_str(&self) -> Cow<'_, str> {
394        self.meta.org_table_str()
395    }
396
397    /// Returns value of the name field of a column packet as a byte slice.
398    #[inline(always)]
399    pub fn name_ref(&self) -> &[u8] {
400        self.meta.name_ref()
401    }
402
403    /// Returns value of the name field of a column packet as a string (lossy converted).
404    #[inline(always)]
405    pub fn name_str(&self) -> Cow<'_, str> {
406        self.meta.name_str()
407    }
408
409    /// Returns value of the org_name field of a column packet as a byte slice.
410    ///
411    /// "org_name" is for original column name.
412    #[inline(always)]
413    pub fn org_name_ref(&self) -> &[u8] {
414        self.meta.org_name_ref()
415    }
416
417    /// Returns value of the org_name field of a column packet as a string (lossy converted).
418    #[inline(always)]
419    pub fn org_name_str(&self) -> Cow<'_, str> {
420        self.meta.org_name_str()
421    }
422}
423
/// Represents change in session state (part of MySql's Ok packet).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SessionStateInfo<'a> {
    // One-byte type tag; it selects how `data` is decoded in `decode`.
    data_type: Const<SessionStateType, u8>,
    // Length-encoded, still-encoded payload of this entry.
    data: RawBytes<'a, LenEnc>,
}
430
431impl SessionStateInfo<'_> {
432    pub fn into_owned(self) -> SessionStateInfo<'static> {
433        let SessionStateInfo { data_type, data } = self;
434        SessionStateInfo {
435            data_type,
436            data: data.into_owned(),
437        }
438    }
439
440    pub fn data_type(&self) -> SessionStateType {
441        *self.data_type
442    }
443
444    /// Returns raw session state info data.
445    pub fn data_ref(&self) -> &[u8] {
446        self.data.as_bytes()
447    }
448
449    /// Tries to decode session state info data.
450    pub fn decode(&self) -> io::Result<SessionStateChange<'_>> {
451        ParseBuf(self.data.as_bytes()).parse_unchecked(*self.data_type)
452    }
453}
454
455impl<'de> MyDeserialize<'de> for SessionStateInfo<'de> {
456    const SIZE: Option<usize> = None;
457    type Ctx = ();
458
459    fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
460        Ok(SessionStateInfo {
461            data_type: buf.parse(())?,
462            data: buf.parse(())?,
463        })
464    }
465}
466
467impl MySerialize for SessionStateInfo<'_> {
468    fn serialize(&self, buf: &mut Vec<u8>) {
469        self.data_type.serialize(&mut *buf);
470        self.data.serialize(buf);
471    }
472}
473
/// Body of MySql's Ok packet in its raw, wire-level form.
///
/// Produced by [`OkPacketKind::parse_body`] and converted into [`OkPacket`] via `TryFrom`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OkPacketBody<'a> {
    // Length-encoded integer; set to 0 by the EOF-style parsers.
    affected_rows: RawInt<LenEnc>,
    // Length-encoded integer; set to 0 by the EOF-style parsers.
    last_insert_id: RawInt<LenEnc>,
    status_flags: Const<StatusFlags, LeU16>,
    warnings: RawInt<LeU16>,
    // Info string; empty when the packet carried none.
    info: RawBytes<'a, LenEnc>,
    // Raw session state change data; empty when the packet carried none.
    session_state_info: RawBytes<'a, LenEnc>,
}
484
/// OK packet kind (see _OK packet identifier_ section of [WL#7766][1]).
///
/// Each kind fixes the header byte it expects and knows how to parse the body that follows.
///
/// [1]: https://dev.mysql.com/worklog/task/?id=7766
pub trait OkPacketKind {
    /// Header byte that packets of this kind start with.
    const HEADER: u8;

    /// Parses the packet body that follows the header byte.
    fn parse_body<'de>(
        capabilities: CapabilityFlags,
        buf: &mut ParseBuf<'de>,
    ) -> io::Result<OkPacketBody<'de>>;
}
496
497/// Ok packet that terminates a result set (text or binary).
498#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
499pub struct ResultSetTerminator;
500
501impl OkPacketKind for ResultSetTerminator {
502    const HEADER: u8 = 0xFE;
503
504    fn parse_body<'de>(
505        capabilities: CapabilityFlags,
506        buf: &mut ParseBuf<'de>,
507    ) -> io::Result<OkPacketBody<'de>> {
508        // We need to skip affected_rows and insert_id here
509        // because valid content of EOF packet includes
510        // packet marker, server status and warning count only.
511        // (see `read_ok_ex` in sql-common/client.cc)
512        buf.parse::<RawInt<LenEnc>>(())?;
513        buf.parse::<RawInt<LenEnc>>(())?;
514
515        // assume CLIENT_PROTOCOL_41 flag
516        let mut sbuf: ParseBuf<'_> = buf.parse(4)?;
517        let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
518        let warnings = sbuf.parse_unchecked(())?;
519
520        let (info, session_state_info) =
521            if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
522                let info = buf.parse(())?;
523                let session_state_info =
524                    if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
525                        buf.parse(())?
526                    } else {
527                        RawBytes::default()
528                    };
529                (info, session_state_info)
530            } else if !buf.is_empty() && buf.0[0] > 0 {
531                // The `info` field is a `string<EOF>` according to the MySQL Internals
532                // Manual, but actually it's a `string<lenenc>`.
533                // SEE: sql/protocol_classics.cc `net_send_ok`
534                let info = buf.parse(())?;
535                (info, RawBytes::default())
536            } else {
537                (RawBytes::default(), RawBytes::default())
538            };
539
540        Ok(OkPacketBody {
541            affected_rows: RawInt::new(0),
542            last_insert_id: RawInt::new(0),
543            status_flags,
544            warnings,
545            info,
546            session_state_info,
547        })
548    }
549}
550
551/// Old deprecated EOF packet.
552#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
553pub struct OldEofPacket;
554
555impl OkPacketKind for OldEofPacket {
556    const HEADER: u8 = 0xFE;
557
558    fn parse_body<'de>(
559        _: CapabilityFlags,
560        buf: &mut ParseBuf<'de>,
561    ) -> io::Result<OkPacketBody<'de>> {
562        // We assume that CLIENT_PROTOCOL_41 was set
563        let mut buf: ParseBuf<'_> = buf.parse(4)?;
564        let warnings = buf.parse_unchecked(())?;
565        let status_flags = buf.parse_unchecked(())?;
566
567        Ok(OkPacketBody {
568            affected_rows: RawInt::new(0),
569            last_insert_id: RawInt::new(0),
570            status_flags,
571            warnings,
572            info: RawBytes::new(&[][..]),
573            session_state_info: RawBytes::new(&[][..]),
574        })
575    }
576}
577
578/// This packet terminates a binlog network stream.
579#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
580pub struct NetworkStreamTerminator;
581
582impl OkPacketKind for NetworkStreamTerminator {
583    const HEADER: u8 = 0xFE;
584
585    fn parse_body<'de>(
586        flags: CapabilityFlags,
587        buf: &mut ParseBuf<'de>,
588    ) -> io::Result<OkPacketBody<'de>> {
589        OldEofPacket::parse_body(flags, buf)
590    }
591}
592
593/// Ok packet that is not a result set terminator.
594#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
595pub struct CommonOkPacket;
596
597impl OkPacketKind for CommonOkPacket {
598    const HEADER: u8 = 0x00;
599
600    fn parse_body<'de>(
601        capabilities: CapabilityFlags,
602        buf: &mut ParseBuf<'de>,
603    ) -> io::Result<OkPacketBody<'de>> {
604        let affected_rows = buf.parse(())?;
605        let last_insert_id = buf.parse(())?;
606
607        // We assume that CLIENT_PROTOCOL_41 was set
608        let mut sbuf: ParseBuf<'_> = buf.parse(4)?;
609        let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
610        let warnings = sbuf.parse_unchecked(())?;
611
612        let (info, session_state_info) =
613            if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
614                let info = buf.parse(())?;
615                let session_state_info =
616                    if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
617                        buf.parse(())?
618                    } else {
619                        RawBytes::default()
620                    };
621                (info, session_state_info)
622            } else if !buf.is_empty() && buf.0[0] > 0 {
623                // The `info` field is a `string<EOF>` according to the MySQL Internals
624                // Manual, but actually it's a `string<lenenc>`.
625                // SEE: sql/protocol_classics.cc `net_send_ok`
626                let info = buf.parse(())?;
627                (info, RawBytes::default())
628            } else {
629                (RawBytes::default(), RawBytes::default())
630            };
631
632        Ok(OkPacketBody {
633            affected_rows,
634            last_insert_id,
635            status_flags,
636            warnings,
637            info,
638            session_state_info,
639        })
640    }
641}
642
643impl<'a> TryFrom<OkPacketBody<'a>> for OkPacket<'a> {
644    type Error = io::Error;
645
646    fn try_from(body: OkPacketBody<'a>) -> io::Result<Self> {
647        Ok(OkPacket {
648            affected_rows: *body.affected_rows,
649            last_insert_id: if *body.last_insert_id == 0 {
650                None
651            } else {
652                Some(*body.last_insert_id)
653            },
654            status_flags: *body.status_flags,
655            warnings: *body.warnings,
656            info: if !body.info.is_empty() {
657                Some(body.info)
658            } else {
659                None
660            },
661            session_state_info: if !body.session_state_info.is_empty() {
662                Some(body.session_state_info)
663            } else {
664                None
665            },
666        })
667    }
668}
669
/// Represents MySql's Ok packet.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OkPacket<'a> {
    affected_rows: u64,
    // The `TryFrom<OkPacketBody>` conversion stores `None` for a zero last insert id.
    last_insert_id: Option<u64>,
    status_flags: StatusFlags,
    warnings: u16,
    // The `TryFrom<OkPacketBody>` conversion stores `None` for an empty info string.
    info: Option<RawBytes<'a, LenEnc>>,
    // Still-encoded session state data (parsed on demand by `session_state_info`);
    // the `TryFrom<OkPacketBody>` conversion stores `None` when it is empty.
    session_state_info: Option<RawBytes<'a, LenEnc>>,
}
680
681impl OkPacket<'_> {
682    pub fn into_owned(self) -> OkPacket<'static> {
683        OkPacket {
684            affected_rows: self.affected_rows,
685            last_insert_id: self.last_insert_id,
686            status_flags: self.status_flags,
687            warnings: self.warnings,
688            info: self.info.map(|x| x.into_owned()),
689            session_state_info: self.session_state_info.map(|x| x.into_owned()),
690        }
691    }
692
693    /// Value of the affected_rows field of an Ok packet.
694    pub fn affected_rows(&self) -> u64 {
695        self.affected_rows
696    }
697
698    /// Value of the last_insert_id field of an Ok packet.
699    pub fn last_insert_id(&self) -> Option<u64> {
700        self.last_insert_id
701    }
702
703    /// Value of the status_flags field of an Ok packet.
704    pub fn status_flags(&self) -> StatusFlags {
705        self.status_flags
706    }
707
708    /// Value of the warnings field of an Ok packet.
709    pub fn warnings(&self) -> u16 {
710        self.warnings
711    }
712
713    /// Value of the info field of an Ok packet as a byte slice.
714    pub fn info_ref(&self) -> Option<&[u8]> {
715        self.info.as_ref().map(|x| x.as_bytes())
716    }
717
718    /// Value of the info field of an Ok packet as a string (lossy converted).
719    pub fn info_str(&self) -> Option<Cow<'_, str>> {
720        self.info.as_ref().map(|x| x.as_str())
721    }
722
723    /// Returns raw reference to a session state info.
724    pub fn session_state_info_ref(&self) -> Option<&[u8]> {
725        self.session_state_info.as_ref().map(|x| x.as_bytes())
726    }
727
728    /// Tries to parse session state info, if any.
729    pub fn session_state_info(&self) -> io::Result<Vec<SessionStateInfo<'_>>> {
730        self.session_state_info_ref()
731            .map(|data| {
732                let mut data = ParseBuf(data);
733                let mut entries = Vec::new();
734                while !data.is_empty() {
735                    entries.push(data.parse(())?);
736                }
737                Ok(entries)
738            })
739            .transpose()
740            .map(|x| x.unwrap_or_default())
741    }
742}
743
744#[derive(Debug, Clone, PartialEq, Eq)]
745pub struct OkPacketDeserializer<'de, T>(OkPacket<'de>, PhantomData<T>);
746
747impl<'de, T> OkPacketDeserializer<'de, T> {
748    pub fn into_inner(self) -> OkPacket<'de> {
749        self.0
750    }
751}
752
753impl<'de, T> From<OkPacketDeserializer<'de, T>> for OkPacket<'de> {
754    fn from(x: OkPacketDeserializer<'de, T>) -> Self {
755        x.0
756    }
757}
758
759#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
760#[error("Invalid OK packet header")]
761pub struct InvalidOkPacketHeader;
762
763impl<'de, T: OkPacketKind> MyDeserialize<'de> for OkPacketDeserializer<'de, T> {
764    const SIZE: Option<usize> = None;
765    type Ctx = CapabilityFlags;
766
767    fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
768        if *buf.parse::<RawInt<u8>>(())? == T::HEADER {
769            let body = T::parse_body(capabilities, buf)?;
770            let ok = OkPacket::try_from(body)?;
771            Ok(Self(ok, PhantomData))
772        } else {
773            Err(io::Error::new(
774                io::ErrorKind::InvalidData,
775                InvalidOkPacketHeader,
776            ))
777        }
778    }
779}
780
/// Progress report information (may be in an error packet of MariaDB server).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ProgressReport<'a> {
    // Current stage, 1 to `max_stage`.
    stage: RawInt<u8>,
    // Total number of stages.
    max_stage: RawInt<u8>,
    // Progress within the stage as '% * 1000' (3-byte little-endian on the wire).
    progress: RawInt<LeU24>,
    // Status or state name.
    stage_info: RawBytes<'a, LenEnc>,
}
789
790impl<'a> ProgressReport<'a> {
791    pub fn new(
792        stage: u8,
793        max_stage: u8,
794        progress: u32,
795        stage_info: impl Into<Cow<'a, [u8]>>,
796    ) -> ProgressReport<'a> {
797        ProgressReport {
798            stage: RawInt::new(stage),
799            max_stage: RawInt::new(max_stage),
800            progress: RawInt::new(progress),
801            stage_info: RawBytes::new(stage_info),
802        }
803    }
804
805    /// 1 to max_stage
806    pub fn stage(&self) -> u8 {
807        *self.stage
808    }
809
810    pub fn max_stage(&self) -> u8 {
811        *self.max_stage
812    }
813
814    /// Progress as '% * 1000'
815    pub fn progress(&self) -> u32 {
816        *self.progress
817    }
818
819    /// Status or state name as a byte slice.
820    pub fn stage_info_ref(&self) -> &[u8] {
821        self.stage_info.as_bytes()
822    }
823
824    /// Status or state name as a string (lossy converted).
825    pub fn stage_info_str(&self) -> Cow<'_, str> {
826        self.stage_info.as_str()
827    }
828
829    pub fn into_owned(self) -> ProgressReport<'static> {
830        ProgressReport {
831            stage: self.stage,
832            max_stage: self.max_stage,
833            progress: self.progress,
834            stage_info: self.stage_info.into_owned(),
835        }
836    }
837}
838
839impl<'de> MyDeserialize<'de> for ProgressReport<'de> {
840    const SIZE: Option<usize> = None;
841    type Ctx = ();
842
843    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
844        let mut sbuf: ParseBuf<'_> = buf.parse(6)?;
845
846        sbuf.skip(1); // Ignore number of strings.
847
848        Ok(ProgressReport {
849            stage: sbuf.parse_unchecked(())?,
850            max_stage: sbuf.parse_unchecked(())?,
851            progress: sbuf.parse_unchecked(())?,
852            stage_info: buf.parse(())?,
853        })
854    }
855}
856
857impl MySerialize for ProgressReport<'_> {
858    fn serialize(&self, buf: &mut Vec<u8>) {
859        buf.put_u8(1);
860        self.stage.serialize(&mut *buf);
861        self.max_stage.serialize(&mut *buf);
862        self.progress.serialize(&mut *buf);
863        self.stage_info.serialize(buf);
864    }
865}
866
867impl fmt::Display for ProgressReport<'_> {
868    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
869        write!(
870            f,
871            "Stage: {} of {} '{}'  {:.2}% of stage done",
872            self.stage(),
873            self.max_stage(),
874            self.progress(),
875            self.stage_info_str()
876        )
877    }
878}
879
880define_header!(
881    ErrPacketHeader,
882    InvalidErrPacketHeader("Invalid error packet header"),
883    0xFF
884);
885
/// MySql error packet.
///
/// May hold an error or a progress report.
#[derive(Debug, Clone, PartialEq)]
pub enum ErrPacket<'a> {
    /// A regular server error.
    Error(ServerError<'a>),
    /// A progress report: error code `0xFFFF` received while
    /// `CLIENT_PROGRESS_OBSOLETE` is set.
    Progress(ProgressReport<'a>),
}
894
895impl ErrPacket<'_> {
896    /// Returns false if this error packet contains progress report.
897    pub fn is_error(&self) -> bool {
898        matches!(self, ErrPacket::Error { .. })
899    }
900
901    /// Returns true if this error packet contains progress report.
902    pub fn is_progress_report(&self) -> bool {
903        !self.is_error()
904    }
905
906    /// Will panic if ErrPacket does not contains progress report
907    pub fn progress_report(&self) -> &ProgressReport<'_> {
908        match *self {
909            ErrPacket::Progress(ref progress_report) => progress_report,
910            _ => panic!("This ErrPacket does not contains progress report"),
911        }
912    }
913
914    /// Will panic if ErrPacket does not contains a `ServerError`.
915    pub fn server_error(&self) -> &ServerError<'_> {
916        match self {
917            ErrPacket::Error(error) => error,
918            ErrPacket::Progress(_) => panic!("This ErrPacket does not contain a ServerError"),
919        }
920    }
921}
922
923impl<'de> MyDeserialize<'de> for ErrPacket<'de> {
924    const SIZE: Option<usize> = None;
925    type Ctx = CapabilityFlags;
926
927    fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
928        let mut sbuf: ParseBuf<'_> = buf.parse(3)?;
929        sbuf.parse_unchecked::<ErrPacketHeader>(())?;
930        let code: RawInt<LeU16> = sbuf.parse_unchecked(())?;
931
932        if *code == 0xFFFF && capabilities.contains(CapabilityFlags::CLIENT_PROGRESS_OBSOLETE) {
933            buf.parse(()).map(ErrPacket::Progress)
934        } else {
935            buf.parse((
936                *code,
937                capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41),
938            ))
939            .map(ErrPacket::Error)
940        }
941    }
942}
943
944impl MySerialize for ErrPacket<'_> {
945    fn serialize(&self, buf: &mut Vec<u8>) {
946        ErrPacketHeader::new().serialize(&mut *buf);
947        match self {
948            ErrPacket::Error(server_error) => {
949                server_error.code.serialize(&mut *buf);
950                server_error.serialize(buf);
951            }
952            ErrPacket::Progress(progress_report) => {
953                RawInt::<LeU16>::new(0xFFFF).serialize(&mut *buf);
954                progress_report.serialize(buf);
955            }
956        }
957    }
958}
959
960impl fmt::Display for ErrPacket<'_> {
961    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
962        match self {
963            ErrPacket::Error(server_error) => write!(f, "{}", server_error),
964            ErrPacket::Progress(progress_report) => write!(f, "{}", progress_report),
965        }
966    }
967}
968
// The `#` marker byte that precedes the five-byte SQL state.
define_header!(
    SqlStateMarker,
    InvalidSqlStateMarker("Invalid SqlStateMarker value"),
    b'#'
);

/// MySql error state.
///
/// On the wire this is six bytes: the `#` marker followed by the five state bytes.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct SqlState {
    /// The `#` marker byte.
    __state_marker: SqlStateMarker,
    /// The five state bytes.
    state: [u8; 5],
}
981
982impl SqlState {
983    /// Creates new sql state.
984    pub fn new(state: [u8; 5]) -> Self {
985        Self {
986            __state_marker: SqlStateMarker::new(),
987            state,
988        }
989    }
990
991    /// Returns an sql state as bytes.
992    pub fn as_bytes(&self) -> [u8; 5] {
993        self.state
994    }
995
996    /// Returns an sql state as a string (lossy converted).
997    pub fn as_str(&self) -> Cow<'_, str> {
998        String::from_utf8_lossy(&self.state)
999    }
1000}
1001
1002impl<'de> MyDeserialize<'de> for SqlState {
1003    const SIZE: Option<usize> = Some(6);
1004    type Ctx = ();
1005
1006    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1007        Ok(Self {
1008            __state_marker: buf.parse(())?,
1009            state: buf.parse(())?,
1010        })
1011    }
1012}
1013
1014impl MySerialize for SqlState {
1015    fn serialize(&self, buf: &mut Vec<u8>) {
1016        self.__state_marker.serialize(buf);
1017        self.state.serialize(buf);
1018    }
1019}
1020
/// Error payload of a MySql error packet.
///
/// Holds the error code, the SQL state (present only when `CLIENT_PROTOCOL_41` was
/// negotiated) and the error message. Progress reports are not represented by this type;
/// see [`ErrPacket::Progress`].
#[derive(Debug, Clone, PartialEq)]
pub struct ServerError<'a> {
    /// Two-byte little-endian error code.
    code: RawInt<LeU16>,
    /// `#`-prefixed SQL state; `None` when parsed without `CLIENT_PROTOCOL_41`.
    state: Option<SqlState>,
    /// Error message bytes.
    message: RawBytes<'a, EofBytes>,
}
1030
1031impl<'a> ServerError<'a> {
1032    pub fn new(code: u16, state: Option<SqlState>, msg: impl Into<Cow<'a, [u8]>>) -> Self {
1033        Self {
1034            code: RawInt::new(code),
1035            state,
1036            message: RawBytes::new(msg),
1037        }
1038    }
1039
1040    /// Returns an error code.
1041    pub fn error_code(&self) -> u16 {
1042        *self.code
1043    }
1044
1045    /// Returns an sql state.
1046    pub fn sql_state_ref(&self) -> Option<&SqlState> {
1047        self.state.as_ref()
1048    }
1049
1050    /// Returns an error message.
1051    pub fn message_ref(&self) -> &[u8] {
1052        self.message.as_bytes()
1053    }
1054
1055    /// Returns an error message as a string (lossy converted).
1056    pub fn message_str(&self) -> Cow<'_, str> {
1057        self.message.as_str()
1058    }
1059
1060    pub fn into_owned(self) -> ServerError<'static> {
1061        ServerError {
1062            code: self.code,
1063            state: self.state,
1064            message: self.message.into_owned(),
1065        }
1066    }
1067}
1068
1069impl<'de> MyDeserialize<'de> for ServerError<'de> {
1070    const SIZE: Option<usize> = None;
1071    /// An error packet error code + whether CLIENT_PROTOCOL_41 capability was negotiated.
1072    type Ctx = (u16, bool);
1073
1074    fn deserialize((code, protocol_41): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1075        let server_error = if protocol_41 {
1076            ServerError {
1077                code: RawInt::new(code),
1078                state: Some(buf.parse(())?),
1079                message: buf.parse(())?,
1080            }
1081        } else {
1082            ServerError {
1083                code: RawInt::new(code),
1084                state: None,
1085                message: buf.parse(())?,
1086            }
1087        };
1088        Ok(server_error)
1089    }
1090}
1091
1092impl MySerialize for ServerError<'_> {
1093    fn serialize(&self, buf: &mut Vec<u8>) {
1094        if let Some(state) = &self.state {
1095            state.serialize(buf);
1096        }
1097        self.message.serialize(buf);
1098    }
1099}
1100
1101impl fmt::Display for ServerError<'_> {
1102    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1103        let sql_state_str = self
1104            .sql_state_ref()
1105            .map(|s| format!(" ({})", s.as_str()))
1106            .unwrap_or_default();
1107
1108        write!(
1109            f,
1110            "ERROR {}{}: {}",
1111            self.error_code(),
1112            sql_state_str,
1113            self.message_str()
1114        )
1115    }
1116}
1117
// Leading byte (0xFB) of a LOCAL INFILE request packet.
define_header!(
    LocalInfileHeader,
    InvalidLocalInfileHeader("Invalid LOCAL_INFILE header"),
    0xFB
);

/// Represents MySql's local infile packet.
///
/// Consists of the `0xFB` header followed by the requested file name bytes.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct LocalInfilePacket<'a> {
    __header: LocalInfileHeader,
    /// Requested file name (raw bytes).
    file_name: RawBytes<'a, EofBytes>,
}
1130
1131impl<'a> LocalInfilePacket<'a> {
1132    pub fn new(file_name: impl Into<Cow<'a, [u8]>>) -> Self {
1133        Self {
1134            __header: LocalInfileHeader::new(),
1135            file_name: RawBytes::new(file_name),
1136        }
1137    }
1138
1139    /// Value of the file_name field of a local infile packet as a byte slice.
1140    pub fn file_name_ref(&self) -> &[u8] {
1141        self.file_name.as_bytes()
1142    }
1143
1144    /// Value of the file_name field of a local infile packet as a string (lossy converted).
1145    pub fn file_name_str(&self) -> Cow<'_, str> {
1146        self.file_name.as_str()
1147    }
1148
1149    pub fn into_owned(self) -> LocalInfilePacket<'static> {
1150        LocalInfilePacket {
1151            __header: self.__header,
1152            file_name: self.file_name.into_owned(),
1153        }
1154    }
1155}
1156
1157impl<'de> MyDeserialize<'de> for LocalInfilePacket<'de> {
1158    const SIZE: Option<usize> = None;
1159    type Ctx = ();
1160
1161    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1162        Ok(LocalInfilePacket {
1163            __header: buf.parse(())?,
1164            file_name: buf.parse(())?,
1165        })
1166    }
1167}
1168
1169impl MySerialize for LocalInfilePacket<'_> {
1170    fn serialize(&self, buf: &mut Vec<u8>) {
1171        self.__header.serialize(buf);
1172        self.file_name.serialize(buf);
1173    }
1174}
1175
// Wire names of the authentication plugins known to this library (no trailing NUL).
const MYSQL_OLD_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_old_password";
const MYSQL_NATIVE_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_native_password";
const CACHING_SHA2_PASSWORD_PLUGIN_NAME: &[u8] = b"caching_sha2_password";
const MYSQL_CLEAR_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_clear_password";
const ED25519_PLUGIN_NAME: &[u8] = b"client_ed25519";

/// Authentication response data, as produced by [`AuthPlugin::gen_data`].
///
/// Dereferences to the raw bytes. When serialized, `Old` and `Clear` are followed by a NUL
/// byte; the other variants are written as-is.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthPluginData<'a> {
    /// Auth data for the `mysql_old_password` plugin.
    Old([u8; 8]),
    /// Auth data for the `mysql_native_password` plugin.
    Native([u8; 20]),
    /// Auth data for `sha2_password` and `caching_sha2_password` plugins.
    Sha2([u8; 32]),
    /// Clear password for `mysql_clear_password` plugin.
    Clear(Cow<'a, [u8]>),
    /// Auth data for MariaDB's `client_ed25519` plugin.
    ///
    /// This plugin is known to the library but the actual support is enabled
    /// by the `client_ed25519` feature.
    Ed25519([u8; 64]),
}
1198
1199impl AuthPluginData<'_> {
1200    pub fn into_owned(self) -> AuthPluginData<'static> {
1201        match self {
1202            AuthPluginData::Old(x) => AuthPluginData::Old(x),
1203            AuthPluginData::Native(x) => AuthPluginData::Native(x),
1204            AuthPluginData::Sha2(x) => AuthPluginData::Sha2(x),
1205            AuthPluginData::Clear(x) => AuthPluginData::Clear(Cow::Owned(x.into_owned())),
1206            AuthPluginData::Ed25519(x) => AuthPluginData::Ed25519(x),
1207        }
1208    }
1209}
1210
1211impl std::ops::Deref for AuthPluginData<'_> {
1212    type Target = [u8];
1213
1214    fn deref(&self) -> &Self::Target {
1215        match self {
1216            Self::Sha2(x) => &x[..],
1217            Self::Native(x) => &x[..],
1218            Self::Old(x) => &x[..],
1219            Self::Clear(x) => &x[..],
1220            Self::Ed25519(x) => &x[..],
1221        }
1222    }
1223}
1224
1225impl MySerialize for AuthPluginData<'_> {
1226    fn serialize(&self, buf: &mut Vec<u8>) {
1227        match self {
1228            Self::Sha2(x) => buf.put_slice(&x[..]),
1229            Self::Native(x) => buf.put_slice(&x[..]),
1230            Self::Old(x) => {
1231                buf.put_slice(&x[..]);
1232                buf.push(0);
1233            }
1234            Self::Clear(x) => {
1235                buf.put_slice(x);
1236                buf.push(0);
1237            }
1238            Self::Ed25519(x) => buf.put_slice(&x[..]),
1239        }
1240    }
1241}
1242
/// Authentication plugin
///
/// Known plugins map to dedicated variants; any other plugin name is kept in
/// [`AuthPlugin::Other`].
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum AuthPlugin<'a> {
    /// Old Password Authentication
    MysqlOldPassword,
    /// Client-Side Cleartext Pluggable Authentication
    MysqlClearPassword,
    /// Legacy authentication plugin
    MysqlNativePassword,
    /// Default since MySql v8.0.4
    CachingSha2Password,
    /// MariaDB's Ed25519 based authentication
    ///
    /// This plugin is known to the library but the actual support is enabled
    /// by the `client_ed25519` feature.
    Ed25519,
    /// Any other plugin, identified by its raw name.
    Other(Cow<'a, [u8]>),
}
1261
1262impl<'de> MyDeserialize<'de> for AuthPlugin<'de> {
1263    const SIZE: Option<usize> = None;
1264    type Ctx = ();
1265
1266    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1267        Ok(Self::from_bytes(buf.eat_all()))
1268    }
1269}
1270
1271impl MySerialize for AuthPlugin<'_> {
1272    fn serialize(&self, buf: &mut Vec<u8>) {
1273        buf.put_slice(self.as_bytes());
1274        buf.put_u8(0);
1275    }
1276}
1277
1278impl<'a> AuthPlugin<'a> {
1279    pub fn from_bytes(name: &'a [u8]) -> AuthPlugin<'a> {
1280        let name = if let [name @ .., 0] = name {
1281            name
1282        } else {
1283            name
1284        };
1285        match name {
1286            CACHING_SHA2_PASSWORD_PLUGIN_NAME => AuthPlugin::CachingSha2Password,
1287            MYSQL_NATIVE_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlNativePassword,
1288            MYSQL_OLD_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlOldPassword,
1289            MYSQL_CLEAR_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlClearPassword,
1290            ED25519_PLUGIN_NAME => AuthPlugin::Ed25519,
1291            name => AuthPlugin::Other(Cow::Borrowed(name)),
1292        }
1293    }
1294
1295    pub fn as_bytes(&self) -> &[u8] {
1296        match self {
1297            AuthPlugin::CachingSha2Password => CACHING_SHA2_PASSWORD_PLUGIN_NAME,
1298            AuthPlugin::MysqlNativePassword => MYSQL_NATIVE_PASSWORD_PLUGIN_NAME,
1299            AuthPlugin::MysqlOldPassword => MYSQL_OLD_PASSWORD_PLUGIN_NAME,
1300            AuthPlugin::MysqlClearPassword => MYSQL_CLEAR_PASSWORD_PLUGIN_NAME,
1301            AuthPlugin::Ed25519 => ED25519_PLUGIN_NAME,
1302            AuthPlugin::Other(name) => name,
1303        }
1304    }
1305
1306    pub fn into_owned(self) -> AuthPlugin<'static> {
1307        match self {
1308            AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1309            AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1310            AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1311            AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1312            AuthPlugin::Ed25519 => AuthPlugin::Ed25519,
1313            AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Owned(name.into_owned())),
1314        }
1315    }
1316
1317    pub fn borrow(&self) -> AuthPlugin<'_> {
1318        match self {
1319            AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1320            AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1321            AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1322            AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1323            AuthPlugin::Ed25519 => AuthPlugin::Ed25519,
1324            AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Borrowed(name.as_ref())),
1325        }
1326    }
1327
1328    /// Generates auth plugin data for this plugin.
1329    ///
1330    /// It'll generate `None` if password is `None` or empty.
1331    ///
1332    /// Note, that you should trim terminating null character from the `nonce`.
1333    ///
1334    /// # Panic
1335    ///
1336    /// * [`AuthPlugin::Ed25519`] will panic if `client_ed25519` feature is disabled.
1337    pub fn gen_data<'b>(&self, pass: Option<&'b str>, nonce: &[u8]) -> Option<AuthPluginData<'b>> {
1338        use super::scramble::{scramble_323, scramble_native, scramble_sha256};
1339
1340        match pass {
1341            Some(pass) if !pass.is_empty() => match self {
1342                AuthPlugin::CachingSha2Password => {
1343                    scramble_sha256(nonce, pass.as_bytes()).map(AuthPluginData::Sha2)
1344                }
1345                AuthPlugin::MysqlNativePassword => {
1346                    scramble_native(nonce, pass.as_bytes()).map(AuthPluginData::Native)
1347                }
1348                AuthPlugin::MysqlOldPassword => {
1349                    scramble_323(nonce.chunks(8).next().unwrap(), pass.as_bytes())
1350                        .map(AuthPluginData::Old)
1351                }
1352                AuthPlugin::MysqlClearPassword => {
1353                    Some(AuthPluginData::Clear(Cow::Borrowed(pass.as_bytes())))
1354                }
1355                AuthPlugin::Ed25519 => Some(AuthPluginData::Ed25519(create_response_for_ed25519(
1356                    pass.as_bytes(),
1357                    nonce,
1358                ))),
1359                AuthPlugin::Other(_) => None,
1360            },
1361            _ => None,
1362        }
1363    }
1364}
1365
// Leading byte (0x01) of an AuthMoreData packet.
define_header!(
    AuthMoreDataHeader,
    InvalidAuthMoreDataHeader("Invalid AuthMoreData header"),
    0x01
);

/// Extra auth-data beyond the initial challenge.
///
/// Consists of the `0x01` header followed by the data bytes.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AuthMoreData<'a> {
    __header: AuthMoreDataHeader,
    /// Payload following the header.
    data: RawBytes<'a, EofBytes>,
}
1378
1379impl<'a> AuthMoreData<'a> {
1380    pub fn new(data: impl Into<Cow<'a, [u8]>>) -> Self {
1381        Self {
1382            __header: AuthMoreDataHeader::new(),
1383            data: RawBytes::new(data),
1384        }
1385    }
1386
1387    pub fn data(&self) -> &[u8] {
1388        self.data.as_bytes()
1389    }
1390
1391    pub fn into_owned(self) -> AuthMoreData<'static> {
1392        AuthMoreData {
1393            __header: self.__header,
1394            data: self.data.into_owned(),
1395        }
1396    }
1397}
1398
1399impl<'de> MyDeserialize<'de> for AuthMoreData<'de> {
1400    const SIZE: Option<usize> = None;
1401    type Ctx = ();
1402
1403    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1404        Ok(Self {
1405            __header: buf.parse(())?,
1406            data: buf.parse(())?,
1407        })
1408    }
1409}
1410
1411impl MySerialize for AuthMoreData<'_> {
1412    fn serialize(&self, buf: &mut Vec<u8>) {
1413        self.__header.serialize(&mut *buf);
1414        self.data.serialize(buf);
1415    }
1416}
1417
// Leading byte (0x01) of a PublicKeyResponse packet (same value as the AuthMoreData header).
define_header!(
    PublicKeyResponseHeader,
    InvalidPublicKeyResponse("Invalid PublicKeyResponse header"),
    0x01
);

/// A server response to a [`PublicKeyRequest`] containing a public RSA key for authentication protection.
///
/// [`PublicKeyRequest`]: crate::packets::caching_sha2_password::PublicKeyRequest
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct PublicKeyResponse<'a> {
    __header: PublicKeyResponseHeader,
    /// Key bytes following the header.
    rsa_key: RawBytes<'a, EofBytes>,
}
1432
1433impl<'a> PublicKeyResponse<'a> {
1434    pub fn new(rsa_key: impl Into<Cow<'a, [u8]>>) -> Self {
1435        Self {
1436            __header: PublicKeyResponseHeader::new(),
1437            rsa_key: RawBytes::new(rsa_key),
1438        }
1439    }
1440
1441    /// The server's RSA public key in PEM format.
1442    pub fn rsa_key(&self) -> Cow<'_, str> {
1443        self.rsa_key.as_str()
1444    }
1445}
1446
1447impl<'de> MyDeserialize<'de> for PublicKeyResponse<'de> {
1448    const SIZE: Option<usize> = None;
1449    type Ctx = ();
1450
1451    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1452        Ok(Self {
1453            __header: buf.parse(())?,
1454            rsa_key: buf.parse(())?,
1455        })
1456    }
1457}
1458
1459impl MySerialize for PublicKeyResponse<'_> {
1460    fn serialize(&self, buf: &mut Vec<u8>) {
1461        self.__header.serialize(&mut *buf);
1462        self.rsa_key.serialize(buf);
1463    }
1464}
1465
// Leading byte (0xFE) shared by `OldAuthSwitchRequest` and `AuthSwitchRequest`.
define_header!(
    AuthSwitchRequestHeader,
    InvalidAuthSwithRequestHeader("Invalid auth switch request header"),
    0xFE
);

/// Old Authentication Method Switch Request Packet.
///
/// Sent by the server to request the client to switch to Old Password Authentication
/// if the `CLIENT_PLUGIN_AUTH` capability is not supported (by either the client or the server).
/// The packet consists of the `0xFE` header byte only.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OldAuthSwitchRequest {
    __header: AuthSwitchRequestHeader,
}
1480
1481impl OldAuthSwitchRequest {
1482    pub fn new() -> Self {
1483        Self {
1484            __header: AuthSwitchRequestHeader::new(),
1485        }
1486    }
1487
1488    pub const fn auth_plugin(&self) -> AuthPlugin<'static> {
1489        AuthPlugin::MysqlOldPassword
1490    }
1491}
1492
1493impl Default for OldAuthSwitchRequest {
1494    fn default() -> Self {
1495        Self::new()
1496    }
1497}
1498
1499impl<'de> MyDeserialize<'de> for OldAuthSwitchRequest {
1500    const SIZE: Option<usize> = Some(1);
1501    type Ctx = ();
1502
1503    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1504        Ok(Self {
1505            __header: buf.parse(())?,
1506        })
1507    }
1508}
1509
1510impl MySerialize for OldAuthSwitchRequest {
1511    fn serialize(&self, buf: &mut Vec<u8>) {
1512        self.__header.serialize(&mut *buf);
1513    }
1514}
1515
/// Authentication Method Switch Request Packet.
///
/// If both server and client support `CLIENT_PLUGIN_AUTH` capability, server can send this packet
/// to ask client to use another authentication method.
///
/// Layout: the `0xFE` header, the plugin name (`NullBytes`), then the plugin data, which may
/// carry a trailing NUL byte.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AuthSwitchRequest<'a> {
    __header: AuthSwitchRequestHeader,
    /// Name of the plugin to switch to.
    auth_plugin: RawBytes<'a, NullBytes>,
    /// Plugin data as received (see `plugin_data()` for the NUL-stripped view).
    plugin_data: RawBytes<'a, EofBytes>,
}
1526
1527impl<'a> AuthSwitchRequest<'a> {
1528    pub fn new(
1529        auth_plugin: impl Into<Cow<'a, [u8]>>,
1530        plugin_data: impl Into<Cow<'a, [u8]>>,
1531    ) -> Self {
1532        Self {
1533            __header: AuthSwitchRequestHeader::new(),
1534            auth_plugin: RawBytes::new(auth_plugin),
1535            plugin_data: RawBytes::new(plugin_data),
1536        }
1537    }
1538
1539    pub fn auth_plugin(&self) -> AuthPlugin<'_> {
1540        ParseBuf(self.auth_plugin.as_bytes())
1541            .parse(())
1542            .expect("infallible")
1543    }
1544
1545    pub fn plugin_data(&self) -> &[u8] {
1546        match self.plugin_data.as_bytes() {
1547            [head @ .., 0] => head,
1548            all => all,
1549        }
1550    }
1551
1552    pub fn into_owned(self) -> AuthSwitchRequest<'static> {
1553        AuthSwitchRequest {
1554            __header: self.__header,
1555            auth_plugin: self.auth_plugin.into_owned(),
1556            plugin_data: self.plugin_data.into_owned(),
1557        }
1558    }
1559}
1560
1561impl<'de> MyDeserialize<'de> for AuthSwitchRequest<'de> {
1562    const SIZE: Option<usize> = None;
1563    type Ctx = ();
1564
1565    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1566        Ok(Self {
1567            __header: buf.parse(())?,
1568            auth_plugin: buf.parse(())?,
1569            plugin_data: buf.parse(())?,
1570        })
1571    }
1572}
1573
1574impl MySerialize for AuthSwitchRequest<'_> {
1575    fn serialize(&self, buf: &mut Vec<u8>) {
1576        self.__header.serialize(&mut *buf);
1577        self.auth_plugin.serialize(&mut *buf);
1578        self.plugin_data.serialize(buf);
1579    }
1580}
1581
/// Represents MySql's initial handshake packet.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct HandshakePacket<'a> {
    protocol_version: RawInt<u8>,
    /// Server version string (`NullBytes`-encoded).
    server_version: RawBytes<'a, NullBytes>,
    connection_id: RawInt<LeU32>,
    /// First 8 bytes of the auth-plugin nonce.
    scramble_1: [u8; 8],
    __filler: Skip<1>,
    // lower 16 bits of the capability flags
    capabilities_1: Const<CapabilityFlags, LeU32LowerHalf>,
    default_collation: RawInt<u8>,
    status_flags: Const<StatusFlags, LeU16>,
    // upper 16 bits of the capability flags
    capabilities_2: Const<CapabilityFlags, LeU32UpperHalf>,
    /// Auth-plugin-data length byte as read from the wire (covers `scramble_1` plus
    /// `scramble_2`); `serialize` recomputes this byte instead of writing the field.
    auth_plugin_data_len: RawInt<u8>,
    __reserved: Skip<10>,
    /// Second part of the nonce; only read when `CLIENT_SECURE_CONNECTION` is set.
    /// May include a terminating NUL.
    scramble_2: Option<RawBytes<'a, BareBytes<{ (u8::MAX as usize) - 8 }>>>,
    /// Auth plugin name; only read when `CLIENT_PLUGIN_AUTH` is set (the parser strips a
    /// trailing NUL if present).
    auth_plugin_name: Option<RawBytes<'a, NullBytes>>,
}
1601
1602impl<'de> MyDeserialize<'de> for HandshakePacket<'de> {
1603    const SIZE: Option<usize> = None;
1604    type Ctx = ();
1605
1606    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1607        let protocol_version = buf.parse(())?;
1608        let server_version = buf.parse(())?;
1609
1610        // includes trailing 10 bytes filler
1611        let mut sbuf: ParseBuf<'_> = buf.parse(31)?;
1612        let connection_id = sbuf.parse_unchecked(())?;
1613        let scramble_1 = sbuf.parse_unchecked(())?;
1614        let __filler = sbuf.parse_unchecked(())?;
1615        let capabilities_1: RawConst<LeU32LowerHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1616        let default_collation = sbuf.parse_unchecked(())?;
1617        let status_flags = sbuf.parse_unchecked(())?;
1618        let capabilities_2: RawConst<LeU32UpperHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1619        let auth_plugin_data_len: RawInt<u8> = sbuf.parse_unchecked(())?;
1620        let __reserved = sbuf.parse_unchecked(())?;
1621        let mut scramble_2 = None;
1622        if capabilities_1.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
1623            let len = max(13, auth_plugin_data_len.0 as i8 - 8) as usize;
1624            scramble_2 = buf.parse(len).map(Some)?;
1625        }
1626        let mut auth_plugin_name = None;
1627        if capabilities_2.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
1628            auth_plugin_name = match buf.eat_all() {
1629                [head @ .., 0] => Some(RawBytes::new(head)),
1630                // missing trailing `0` is a known bug in mysql
1631                all => Some(RawBytes::new(all)),
1632            }
1633        }
1634
1635        Ok(Self {
1636            protocol_version,
1637            server_version,
1638            connection_id,
1639            scramble_1,
1640            __filler,
1641            capabilities_1: Const::new(CapabilityFlags::from_bits_truncate(capabilities_1.0)),
1642            default_collation,
1643            status_flags,
1644            capabilities_2: Const::new(CapabilityFlags::from_bits_truncate(capabilities_2.0)),
1645            auth_plugin_data_len,
1646            __reserved,
1647            scramble_2,
1648            auth_plugin_name,
1649        })
1650    }
1651}
1652
1653impl MySerialize for HandshakePacket<'_> {
1654    fn serialize(&self, buf: &mut Vec<u8>) {
1655        self.protocol_version.serialize(&mut *buf);
1656        self.server_version.serialize(&mut *buf);
1657        self.connection_id.serialize(&mut *buf);
1658        self.scramble_1.serialize(&mut *buf);
1659        buf.put_u8(0x00);
1660        self.capabilities_1.serialize(&mut *buf);
1661        self.default_collation.serialize(&mut *buf);
1662        self.status_flags.serialize(&mut *buf);
1663        self.capabilities_2.serialize(&mut *buf);
1664
1665        if self
1666            .capabilities_2
1667            .contains(CapabilityFlags::CLIENT_PLUGIN_AUTH)
1668        {
1669            buf.put_u8(
1670                self.scramble_2
1671                    .as_ref()
1672                    .map(|x| (x.len() + 8) as u8)
1673                    .unwrap_or_default(),
1674            );
1675        } else {
1676            buf.put_u8(0);
1677        }
1678
1679        buf.put_slice(&[0_u8; 10][..]);
1680
1681        // Assume that the packet is well formed:
1682        // * the CLIENT_SECURE_CONNECTION is set.
1683        if let Some(scramble_2) = &self.scramble_2 {
1684            scramble_2.serialize(&mut *buf);
1685        }
1686
1687        // Assume that the packet is well formed:
1688        // * the CLIENT_PLUGIN_AUTH is set.
1689        if let Some(client_plugin_auth) = &self.auth_plugin_name {
1690            client_plugin_auth.serialize(buf);
1691        }
1692    }
1693}
1694
1695impl<'a> HandshakePacket<'a> {
1696    #[allow(clippy::too_many_arguments)]
1697    pub fn new(
1698        protocol_version: u8,
1699        server_version: impl Into<Cow<'a, [u8]>>,
1700        connection_id: u32,
1701        scramble_1: [u8; 8],
1702        scramble_2: Option<impl Into<Cow<'a, [u8]>>>,
1703        capabilities: CapabilityFlags,
1704        default_collation: u8,
1705        status_flags: StatusFlags,
1706        auth_plugin_name: Option<impl Into<Cow<'a, [u8]>>>,
1707    ) -> Self {
1708        // Safety:
1709        // * capabilities are given as a valid CapabilityFlags instance
1710        // * the BitAnd operation can't set new bits
1711        let (capabilities_1, capabilities_2) = (
1712            CapabilityFlags::from_bits_retain(capabilities.bits() & 0x0000_FFFF),
1713            CapabilityFlags::from_bits_retain(capabilities.bits() & 0xFFFF_0000),
1714        );
1715
1716        let scramble_2 = scramble_2.map(RawBytes::new);
1717
1718        HandshakePacket {
1719            protocol_version: RawInt::new(protocol_version),
1720            server_version: RawBytes::new(server_version),
1721            connection_id: RawInt::new(connection_id),
1722            scramble_1,
1723            __filler: Skip,
1724            capabilities_1: Const::new(capabilities_1),
1725            default_collation: RawInt::new(default_collation),
1726            status_flags: Const::new(status_flags),
1727            capabilities_2: Const::new(capabilities_2),
1728            auth_plugin_data_len: RawInt::new(
1729                scramble_2
1730                    .as_ref()
1731                    .map(|x| x.len() as u8)
1732                    .unwrap_or_default(),
1733            ),
1734            __reserved: Skip,
1735            scramble_2,
1736            auth_plugin_name: auth_plugin_name.map(RawBytes::new),
1737        }
1738    }
1739
1740    pub fn into_owned(self) -> HandshakePacket<'static> {
1741        HandshakePacket {
1742            protocol_version: self.protocol_version,
1743            server_version: self.server_version.into_owned(),
1744            connection_id: self.connection_id,
1745            scramble_1: self.scramble_1,
1746            __filler: self.__filler,
1747            capabilities_1: self.capabilities_1,
1748            default_collation: self.default_collation,
1749            status_flags: self.status_flags,
1750            capabilities_2: self.capabilities_2,
1751            auth_plugin_data_len: self.auth_plugin_data_len,
1752            __reserved: self.__reserved,
1753            scramble_2: self.scramble_2.map(|x| x.into_owned()),
1754            auth_plugin_name: self.auth_plugin_name.map(RawBytes::into_owned),
1755        }
1756    }
1757
1758    /// Value of the protocol_version field of an initial handshake packet.
1759    pub fn protocol_version(&self) -> u8 {
1760        self.protocol_version.0
1761    }
1762
1763    /// Value of the server_version field of an initial handshake packet as a byte slice.
1764    pub fn server_version_ref(&self) -> &[u8] {
1765        self.server_version.as_bytes()
1766    }
1767
1768    /// Value of the server_version field of an initial handshake packet as a string
1769    /// (lossy converted).
1770    pub fn server_version_str(&self) -> Cow<'_, str> {
1771        self.server_version.as_str()
1772    }
1773
1774    /// Parsed server version.
1775    ///
1776    /// Will parse first \d+.\d+.\d+ of a server version string (if any).
1777    pub fn server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1778        VERSION_RE
1779            .captures(self.server_version_ref())
1780            .map(|captures| {
1781                // Should not panic because validated with regex
1782                (
1783                    btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1784                    btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1785                    btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1786                )
1787            })
1788    }
1789
1790    /// Parsed mariadb server version.
1791    pub fn maria_db_server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1792        MARIADB_VERSION_RE
1793            .captures(self.server_version_ref())
1794            .map(|captures| {
1795                // Should not panic because validated with regex
1796                (
1797                    btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1798                    btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1799                    btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1800                )
1801            })
1802    }
1803
1804    /// Value of the connection_id field of an initial handshake packet.
1805    pub fn connection_id(&self) -> u32 {
1806        self.connection_id.0
1807    }
1808
1809    /// Value of the scramble_1 field of an initial handshake packet as a byte slice.
1810    pub fn scramble_1_ref(&self) -> &[u8] {
1811        self.scramble_1.as_ref()
1812    }
1813
1814    /// Value of the scramble_2 field of an initial handshake packet as a byte slice.
1815    ///
1816    /// Note that this may include a terminating null character.
1817    pub fn scramble_2_ref(&self) -> Option<&[u8]> {
1818        self.scramble_2.as_ref().map(|x| x.as_bytes())
1819    }
1820
1821    /// Returns concatenated auth plugin nonce.
1822    pub fn nonce(&self) -> Vec<u8> {
1823        let mut out = Vec::from(self.scramble_1_ref());
1824        out.extend_from_slice(self.scramble_2_ref().unwrap_or(&[][..]));
1825
1826        // Trim zero terminator. Fill with zeroes if nonce
1827        // is somehow smaller than 20 bytes.
1828        out.resize(20, 0);
1829        out
1830    }
1831
1832    /// Value of a server capabilities.
1833    pub fn capabilities(&self) -> CapabilityFlags {
1834        self.capabilities_1.0 | self.capabilities_2.0
1835    }
1836
1837    /// Value of the default_collation field of an initial handshake packet.
1838    pub fn default_collation(&self) -> u8 {
1839        self.default_collation.0
1840    }
1841
1842    /// Value of a status flags.
1843    pub fn status_flags(&self) -> StatusFlags {
1844        self.status_flags.0
1845    }
1846
1847    /// Value of the auth_plugin_name field of an initial handshake packet as a byte slice.
1848    pub fn auth_plugin_name_ref(&self) -> Option<&[u8]> {
1849        self.auth_plugin_name.as_ref().map(|x| x.as_bytes())
1850    }
1851
1852    /// Value of the auth_plugin_name field of an initial handshake packet as a string
1853    /// (lossy converted).
1854    pub fn auth_plugin_name_str(&self) -> Option<Cow<'_, str>> {
1855        self.auth_plugin_name.as_ref().map(|x| x.as_str())
1856    }
1857
1858    /// Auth plugin of a handshake packet
1859    pub fn auth_plugin(&self) -> Option<AuthPlugin<'_>> {
1860        self.auth_plugin_name.as_ref().map(|x| match x.as_bytes() {
1861            [name @ .., 0] => ParseBuf(name).parse_unchecked(()).expect("infallible"),
1862            all => ParseBuf(all).parse_unchecked(()).expect("infallible"),
1863        })
1864    }
1865}
1866
// Command header of a `COM_CHANGE_USER` packet. `0x11` is the expected header byte, and
// `InvalidComChangeUserHeader` (with the message below) is the error type the macro generates
// (presumably returned when a parsed header byte does not match).
define_header!(
    ComChangeUserHeader,
    InvalidComChangeUserHeader("Invalid COM_CHANGE_USER header"),
    0x11
);
1872
/// `COM_CHANGE_USER` command packet.
///
/// Fields are declared in the order in which they are serialized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComChangeUser<'a> {
    /// Command header (`0x11`).
    __header: ComChangeUserHeader,
    /// Null-terminated user name.
    user: RawBytes<'a, NullBytes>,
    // Only CLIENT_SECURE_CONNECTION capable servers are supported
    // (so the auth data always uses the `U8Bytes` encoding).
    auth_plugin_data: RawBytes<'a, U8Bytes>,
    /// Null-terminated database name.
    database: RawBytes<'a, NullBytes>,
    /// Optional trailing section; when parsing, it is read only if bytes remain after
    /// `database`.
    more_data: Option<ComChangeUserMoreData<'a>>,
}
1882
1883impl<'a> ComChangeUser<'a> {
1884    pub fn new() -> Self {
1885        Self {
1886            __header: ComChangeUserHeader::new(),
1887            user: Default::default(),
1888            auth_plugin_data: Default::default(),
1889            database: Default::default(),
1890            more_data: None,
1891        }
1892    }
1893
1894    pub fn with_user(mut self, user: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1895        self.user = user.map(RawBytes::new).unwrap_or_default();
1896        self
1897    }
1898
1899    pub fn with_database(mut self, database: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1900        self.database = database.map(RawBytes::new).unwrap_or_default();
1901        self
1902    }
1903
1904    pub fn with_auth_plugin_data(
1905        mut self,
1906        auth_plugin_data: Option<impl Into<Cow<'a, [u8]>>>,
1907    ) -> Self {
1908        self.auth_plugin_data = auth_plugin_data.map(RawBytes::new).unwrap_or_default();
1909        self
1910    }
1911
1912    pub fn with_more_data(mut self, more_data: Option<ComChangeUserMoreData<'a>>) -> Self {
1913        self.more_data = more_data;
1914        self
1915    }
1916
1917    pub fn into_owned(self) -> ComChangeUser<'static> {
1918        ComChangeUser {
1919            __header: self.__header,
1920            user: self.user.into_owned(),
1921            auth_plugin_data: self.auth_plugin_data.into_owned(),
1922            database: self.database.into_owned(),
1923            more_data: self.more_data.map(|x| x.into_owned()),
1924        }
1925    }
1926}
1927
1928impl Default for ComChangeUser<'_> {
1929    fn default() -> Self {
1930        Self::new()
1931    }
1932}
1933
1934impl<'de> MyDeserialize<'de> for ComChangeUser<'de> {
1935    const SIZE: Option<usize> = None;
1936
1937    type Ctx = CapabilityFlags;
1938
1939    fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1940        Ok(Self {
1941            __header: buf.parse(())?,
1942            user: buf.parse(())?,
1943            auth_plugin_data: buf.parse(())?,
1944            database: buf.parse(())?,
1945            more_data: if !buf.is_empty() {
1946                Some(buf.parse(flags)?)
1947            } else {
1948                None
1949            },
1950        })
1951    }
1952}
1953
1954impl MySerialize for ComChangeUser<'_> {
1955    fn serialize(&self, buf: &mut Vec<u8>) {
1956        self.__header.serialize(&mut *buf);
1957        self.user.serialize(&mut *buf);
1958        self.auth_plugin_data.serialize(&mut *buf);
1959        self.database.serialize(&mut *buf);
1960        if let Some(ref more_data) = self.more_data {
1961            more_data.serialize(&mut *buf);
1962        }
1963    }
1964}
1965
/// Optional trailing section of a `COM_CHANGE_USER` packet.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComChangeUserMoreData<'a> {
    /// Character set id (little-endian `u16`).
    character_set: RawInt<LeU16>,
    /// Auth plugin. When parsing, it is read only if `CLIENT_PLUGIN_AUTH` is set.
    auth_plugin: Option<AuthPlugin<'a>>,
    /// Connection attributes as length-encoded key/value pairs. When parsing, they are read
    /// only if `CLIENT_CONNECT_ATTRS` is set.
    connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
}
1972
1973impl<'a> ComChangeUserMoreData<'a> {
1974    pub fn new(character_set: u16) -> Self {
1975        Self {
1976            character_set: RawInt::new(character_set),
1977            auth_plugin: None,
1978            connect_attributes: None,
1979        }
1980    }
1981
1982    pub fn with_auth_plugin(mut self, auth_plugin: Option<AuthPlugin<'a>>) -> Self {
1983        self.auth_plugin = auth_plugin;
1984        self
1985    }
1986
1987    pub fn with_connect_attributes(
1988        mut self,
1989        connect_attributes: Option<HashMap<String, String>>,
1990    ) -> Self {
1991        self.connect_attributes = connect_attributes.map(|attrs| {
1992            attrs
1993                .into_iter()
1994                .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
1995                .collect()
1996        });
1997        self
1998    }
1999
2000    pub fn into_owned(self) -> ComChangeUserMoreData<'static> {
2001        ComChangeUserMoreData {
2002            character_set: self.character_set,
2003            auth_plugin: self.auth_plugin.map(|x| x.into_owned()),
2004            connect_attributes: self.connect_attributes.map(|x| {
2005                x.into_iter()
2006                    .map(|(k, v)| (k.into_owned(), v.into_owned()))
2007                    .collect()
2008            }),
2009        }
2010    }
2011}
2012
2013// Helper that deserializes connect attributes.
2014fn deserialize_connect_attrs<'de>(
2015    buf: &mut ParseBuf<'de>,
2016) -> io::Result<HashMap<RawBytes<'de, LenEnc>, RawBytes<'de, LenEnc>>> {
2017    let data_len = buf.parse::<RawInt<LenEnc>>(())?;
2018    let mut data: ParseBuf<'_> = buf.parse(data_len.0 as usize)?;
2019    let mut attrs = HashMap::new();
2020    while !data.is_empty() {
2021        let key = data.parse::<RawBytes<'_, LenEnc>>(())?;
2022        let value = data.parse::<RawBytes<'_, LenEnc>>(())?;
2023        attrs.insert(key, value);
2024    }
2025    Ok(attrs)
2026}
2027
2028// Helper that serializes connect attributes.
2029fn serialize_connect_attrs<'a>(
2030    connect_attributes: &HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>,
2031    buf: &mut Vec<u8>,
2032) {
2033    let len = connect_attributes
2034        .iter()
2035        .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2036        .sum::<u64>();
2037    buf.put_lenenc_int(len);
2038
2039    for (name, value) in connect_attributes {
2040        name.serialize(&mut *buf);
2041        value.serialize(&mut *buf);
2042    }
2043}
2044
2045impl<'de> MyDeserialize<'de> for ComChangeUserMoreData<'de> {
2046    const SIZE: Option<usize> = None;
2047    type Ctx = CapabilityFlags;
2048
2049    fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2050        // always assume CLIENT_PROTOCOL_41
2051        let character_set = buf.parse(())?;
2052        let mut auth_plugin = None;
2053        let mut connect_attributes = None;
2054
2055        if flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
2056            // plugin name is null-terminated here
2057            match buf.parse::<RawBytes<'_, NullBytes>>(())?.0 {
2058                Cow::Borrowed(bytes) => {
2059                    let mut auth_plugin_buf = ParseBuf(bytes);
2060                    auth_plugin = Some(auth_plugin_buf.parse(())?);
2061                }
2062                _ => unreachable!(),
2063            }
2064        };
2065
2066        if flags.contains(CapabilityFlags::CLIENT_CONNECT_ATTRS) {
2067            connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2068        };
2069
2070        Ok(Self {
2071            character_set,
2072            auth_plugin,
2073            connect_attributes,
2074        })
2075    }
2076}
2077
2078impl MySerialize for ComChangeUserMoreData<'_> {
2079    fn serialize(&self, buf: &mut Vec<u8>) {
2080        self.character_set.serialize(&mut *buf);
2081        if let Some(ref auth_plugin) = self.auth_plugin {
2082            auth_plugin.serialize(&mut *buf);
2083        }
2084        if let Some(ref connect_attributes) = self.connect_attributes {
2085            serialize_connect_attrs(connect_attributes, buf);
2086        } else {
2087            // We'll always act like CLIENT_CONNECT_ATTRS is set,
2088            // this is to avoid looking into the actual connection flags.
2089            serialize_connect_attrs(&Default::default(), buf);
2090        }
2091    }
2092}
2093
/// Auth response of a handshake response. Its serialization depends on the capability flags:
///
/// * `Left`: length-encoded, used when `CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA` is set;
/// * `Right(Left)`: `U8Bytes` encoding, used when `CLIENT_SECURE_CONNECTION` is set;
/// * `Right(Right)`: null-terminated, used otherwise.
type ScrambleBuf<'a> =
    Either<RawBytes<'a, LenEnc>, Either<RawBytes<'a, U8Bytes>, RawBytes<'a, NullBytes>>>;
2097
/// Client's handshake response packet.
///
/// On the wire, `user` is written before `scramble_buf`; the other fields follow the
/// declaration order below.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandshakeResponse<'a> {
    /// Client capability flags (little-endian `u32`).
    capabilities: Const<CapabilityFlags, LeU32>,
    /// Maximum packet size (little-endian `u32`).
    max_packet_size: RawInt<LeU32>,
    /// Collation id (one byte).
    collation: RawInt<u8>,
    /// Auth response. Its encoding is chosen from the capability flags (see `ScrambleBuf`).
    scramble_buf: ScrambleBuf<'a>,
    /// Null-terminated user name.
    user: RawBytes<'a, NullBytes>,
    /// Null-terminated database name; present iff `CLIENT_CONNECT_WITH_DB` is set.
    db_name: Option<RawBytes<'a, NullBytes>>,
    /// Auth plugin; present iff `CLIENT_PLUGIN_AUTH` is set.
    auth_plugin: Option<AuthPlugin<'a>>,
    /// Connection attributes; present iff `CLIENT_CONNECT_ATTRS` is set.
    connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
}
2109
2110impl<'a> HandshakeResponse<'a> {
2111    #[allow(clippy::too_many_arguments)]
2112    pub fn new(
2113        scramble_buf: Option<impl Into<Cow<'a, [u8]>>>,
2114        server_version: (u16, u16, u16),
2115        user: Option<impl Into<Cow<'a, [u8]>>>,
2116        db_name: Option<impl Into<Cow<'a, [u8]>>>,
2117        auth_plugin: Option<AuthPlugin<'a>>,
2118        mut capabilities: CapabilityFlags,
2119        connect_attributes: Option<HashMap<String, String>>,
2120        max_packet_size: u32,
2121    ) -> Self {
2122        let scramble_buf =
2123            if capabilities.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
2124                Either::Left(RawBytes::new(
2125                    scramble_buf.map(Into::into).unwrap_or_default(),
2126                ))
2127            } else if capabilities.contains(CapabilityFlags::CLIENT_SECURE_CONNECTION) {
2128                Either::Right(Either::Left(RawBytes::new(
2129                    scramble_buf.map(Into::into).unwrap_or_default(),
2130                )))
2131            } else {
2132                Either::Right(Either::Right(RawBytes::new(
2133                    scramble_buf.map(Into::into).unwrap_or_default(),
2134                )))
2135            };
2136
2137        if db_name.is_some() {
2138            capabilities.insert(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2139        } else {
2140            capabilities.remove(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2141        }
2142
2143        if auth_plugin.is_some() {
2144            capabilities.insert(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2145        } else {
2146            capabilities.remove(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2147        }
2148
2149        if connect_attributes.is_some() {
2150            capabilities.insert(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2151        } else {
2152            capabilities.remove(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2153        }
2154
2155        Self {
2156            scramble_buf,
2157            collation: if server_version >= (5, 5, 3) {
2158                RawInt::new(CollationId::UTF8MB4_GENERAL_CI as u8)
2159            } else {
2160                RawInt::new(CollationId::UTF8MB3_GENERAL_CI as u8)
2161            },
2162            user: user.map(RawBytes::new).unwrap_or_default(),
2163            db_name: db_name.map(RawBytes::new),
2164            auth_plugin,
2165            capabilities: Const::new(capabilities),
2166            connect_attributes: connect_attributes.map(|attrs| {
2167                attrs
2168                    .into_iter()
2169                    .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
2170                    .collect()
2171            }),
2172            max_packet_size: RawInt::new(max_packet_size),
2173        }
2174    }
2175
2176    pub fn capabilities(&self) -> CapabilityFlags {
2177        self.capabilities.0
2178    }
2179
2180    pub fn collation(&self) -> u8 {
2181        self.collation.0
2182    }
2183
2184    pub fn scramble_buf(&self) -> &[u8] {
2185        match &self.scramble_buf {
2186            Either::Left(x) => x.as_bytes(),
2187            Either::Right(x) => match x {
2188                Either::Left(x) => x.as_bytes(),
2189                Either::Right(x) => x.as_bytes(),
2190            },
2191        }
2192    }
2193
2194    pub fn user(&self) -> &[u8] {
2195        self.user.as_bytes()
2196    }
2197
2198    pub fn db_name(&self) -> Option<&[u8]> {
2199        self.db_name.as_ref().map(|x| x.as_bytes())
2200    }
2201
2202    pub fn auth_plugin(&self) -> Option<&AuthPlugin<'a>> {
2203        self.auth_plugin.as_ref()
2204    }
2205
2206    #[must_use = "entails computation"]
2207    pub fn connect_attributes(&self) -> Option<HashMap<String, String>> {
2208        self.connect_attributes.as_ref().map(|attrs| {
2209            attrs
2210                .iter()
2211                .map(|(k, v)| (k.as_str().into_owned(), v.as_str().into_owned()))
2212                .collect()
2213        })
2214    }
2215}
2216
2217impl<'de> MyDeserialize<'de> for HandshakeResponse<'de> {
2218    const SIZE: Option<usize> = None;
2219    type Ctx = ();
2220
2221    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2222        let mut sbuf: ParseBuf<'_> = buf.parse(4 + 4 + 1 + 23)?;
2223        let client_flags: RawConst<LeU32, CapabilityFlags> = sbuf.parse_unchecked(())?;
2224        let max_packet_size: RawInt<LeU32> = sbuf.parse_unchecked(())?;
2225        let collation = sbuf.parse_unchecked(())?;
2226        sbuf.parse_unchecked::<Skip<23>>(())?;
2227
2228        let user = buf.parse(())?;
2229        let scramble_buf =
2230            if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA.bits() > 0 {
2231                Either::Left(buf.parse(())?)
2232            } else if client_flags.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
2233                Either::Right(Either::Left(buf.parse(())?))
2234            } else {
2235                Either::Right(Either::Right(buf.parse(())?))
2236            };
2237
2238        let mut db_name = None;
2239        if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_WITH_DB.bits() > 0 {
2240            db_name = buf.parse(()).map(Some)?;
2241        }
2242
2243        let mut auth_plugin = None;
2244        if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
2245            let auth_plugin_name = buf.eat_null_str();
2246            auth_plugin = Some(AuthPlugin::from_bytes(auth_plugin_name));
2247        }
2248
2249        let mut connect_attributes = None;
2250        if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_ATTRS.bits() > 0 {
2251            connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2252        }
2253
2254        Ok(Self {
2255            capabilities: Const::new(CapabilityFlags::from_bits_truncate(client_flags.0)),
2256            max_packet_size,
2257            collation,
2258            scramble_buf,
2259            user,
2260            db_name,
2261            auth_plugin,
2262            connect_attributes,
2263        })
2264    }
2265}
2266
2267impl MySerialize for HandshakeResponse<'_> {
2268    fn serialize(&self, buf: &mut Vec<u8>) {
2269        self.capabilities.serialize(&mut *buf);
2270        self.max_packet_size.serialize(&mut *buf);
2271        self.collation.serialize(&mut *buf);
2272        buf.put_slice(&[0; 23]);
2273        self.user.serialize(&mut *buf);
2274        self.scramble_buf.serialize(&mut *buf);
2275
2276        if let Some(db_name) = &self.db_name {
2277            db_name.serialize(&mut *buf);
2278        }
2279
2280        if let Some(auth_plugin) = &self.auth_plugin {
2281            auth_plugin.serialize(&mut *buf);
2282        }
2283
2284        if let Some(attrs) = &self.connect_attributes {
2285            let len = attrs
2286                .iter()
2287                .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2288                .sum::<u64>();
2289            buf.put_lenenc_int(len);
2290
2291            for (name, value) in attrs {
2292                name.serialize(&mut *buf);
2293                value.serialize(&mut *buf);
2294            }
2295        }
2296    }
2297}
2298
/// SSL request packet: a fixed 32-byte packet made of capabilities (4 bytes), max packet size
/// (4 bytes), character set (1 byte) and 23 filler bytes.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SslRequest {
    /// Client capability flags (little-endian `u32`).
    capabilities: Const<CapabilityFlags, LeU32>,
    /// Maximum packet size (little-endian `u32`).
    max_packet_size: RawInt<LeU32>,
    /// Character set id (one byte).
    character_set: RawInt<u8>,
    /// 23 filler bytes.
    __skip: Skip<23>,
}
2306
2307impl SslRequest {
2308    pub fn new(capabilities: CapabilityFlags, max_packet_size: u32, character_set: u8) -> Self {
2309        Self {
2310            capabilities: Const::new(capabilities),
2311            max_packet_size: RawInt::new(max_packet_size),
2312            character_set: RawInt::new(character_set),
2313            __skip: Skip,
2314        }
2315    }
2316
2317    pub fn capabilities(&self) -> CapabilityFlags {
2318        self.capabilities.0
2319    }
2320
2321    pub fn max_packet_size(&self) -> u32 {
2322        self.max_packet_size.0
2323    }
2324
2325    pub fn character_set(&self) -> u8 {
2326        self.character_set.0
2327    }
2328}
2329
2330impl<'de> MyDeserialize<'de> for SslRequest {
2331    const SIZE: Option<usize> = Some(4 + 4 + 1 + 23);
2332    type Ctx = ();
2333
2334    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2335        let mut buf: ParseBuf<'_> = buf.parse(Self::SIZE.unwrap())?;
2336        let raw_capabilities = buf.parse_unchecked::<RawConst<LeU32, CapabilityFlags>>(())?;
2337        Ok(Self {
2338            capabilities: Const::new(CapabilityFlags::from_bits_truncate(raw_capabilities.0)),
2339            max_packet_size: buf.parse_unchecked(())?,
2340            character_set: buf.parse_unchecked(())?,
2341            __skip: buf.parse_unchecked(())?,
2342        })
2343    }
2344}
2345
2346impl MySerialize for SslRequest {
2347    fn serialize(&self, buf: &mut Vec<u8>) {
2348        self.capabilities.serialize(&mut *buf);
2349        self.max_packet_size.serialize(&mut *buf);
2350        self.character_set.serialize(&mut *buf);
2351        self.__skip.serialize(&mut *buf);
2352    }
2353}
2354
/// Error type of the status byte of a [`StmtPacket`], which must be `0x00`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
#[error("Invalid statement packet status")]
pub struct InvalidStmtPacketStatus;
2358
/// Represents MySql's statement packet.
///
/// This is a fixed 12-byte packet. Presumably it is the server's reply to a statement
/// prepare request — confirm against the caller.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct StmtPacket {
    /// Status byte; must be `0x00`.
    status: ConstU8<InvalidStmtPacketStatus, 0x00>,
    /// Statement id (little-endian `u32`).
    statement_id: RawInt<LeU32>,
    /// Number of columns (little-endian `u16`).
    num_columns: RawInt<LeU16>,
    /// Number of parameters (little-endian `u16`).
    num_params: RawInt<LeU16>,
    /// One filler byte.
    __skip: Skip<1>,
    /// Number of warnings (little-endian `u16`).
    warning_count: RawInt<LeU16>,
}
2369
2370impl<'de> MyDeserialize<'de> for StmtPacket {
2371    const SIZE: Option<usize> = Some(12);
2372    type Ctx = ();
2373
2374    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2375        let mut buf: ParseBuf<'_> = buf.parse(Self::SIZE.unwrap())?;
2376        Ok(StmtPacket {
2377            status: buf.parse_unchecked(())?,
2378            statement_id: buf.parse_unchecked(())?,
2379            num_columns: buf.parse_unchecked(())?,
2380            num_params: buf.parse_unchecked(())?,
2381            __skip: buf.parse_unchecked(())?,
2382            warning_count: buf.parse_unchecked(())?,
2383        })
2384    }
2385}
2386
2387impl MySerialize for StmtPacket {
2388    fn serialize(&self, buf: &mut Vec<u8>) {
2389        self.status.serialize(&mut *buf);
2390        self.statement_id.serialize(&mut *buf);
2391        self.num_columns.serialize(&mut *buf);
2392        self.num_params.serialize(&mut *buf);
2393        self.__skip.serialize(&mut *buf);
2394        self.warning_count.serialize(&mut *buf);
2395    }
2396}
2397
2398impl StmtPacket {
2399    /// Value of the statement_id field of a statement packet.
2400    pub fn statement_id(&self) -> u32 {
2401        *self.statement_id
2402    }
2403
2404    /// Value of the num_columns field of a statement packet.
2405    pub fn num_columns(&self) -> u16 {
2406        *self.num_columns
2407    }
2408
2409    /// Value of the num_params field of a statement packet.
2410    pub fn num_params(&self) -> u16 {
2411        *self.num_params
2412    }
2413
2414    /// Value of the warning_count field of a statement packet.
2415    pub fn warning_count(&self) -> u16 {
2416        *self.warning_count
2417    }
2418}
2419
/// Null-bitmap.
///
/// Holds one bit per column, shifted by `T::BIT_OFFSET` bits (`T` is the serialization side).
/// A set bit marks the column as `NULL`. `U` is the backing byte storage.
///
/// <http://dev.mysql.com/doc/internals/en/null-bitmap.html>
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct NullBitmap<T, U: AsRef<[u8]> = Vec<u8>>(U, PhantomData<T>);
2425
2426impl<'de, T: SerializationSide> MyDeserialize<'de> for NullBitmap<T, Cow<'de, [u8]>> {
2427    const SIZE: Option<usize> = None;
2428    type Ctx = usize;
2429
2430    fn deserialize(num_columns: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2431        let bitmap_len = Self::bitmap_len(num_columns);
2432        let bytes = buf.checked_eat(bitmap_len).ok_or_else(unexpected_buf_eof)?;
2433        Ok(Self::from_bytes(Cow::Borrowed(bytes)))
2434    }
2435}
2436
2437impl<T: SerializationSide> NullBitmap<T, Vec<u8>> {
2438    /// Creates new null-bitmap for a given number of columns.
2439    pub fn new(num_columns: usize) -> Self {
2440        Self::from_bytes(vec![0; Self::bitmap_len(num_columns)])
2441    }
2442
2443    /// Will read null-bitmap for a given number of columns from `input`.
2444    pub fn read(input: &mut &[u8], num_columns: usize) -> Self {
2445        let bitmap_len = Self::bitmap_len(num_columns);
2446        assert!(input.len() >= bitmap_len);
2447
2448        let bitmap = Self::from_bytes(input[..bitmap_len].to_vec());
2449        *input = &input[bitmap_len..];
2450
2451        bitmap
2452    }
2453}
2454
2455impl<T: SerializationSide, U: AsRef<[u8]>> NullBitmap<T, U> {
2456    pub fn bitmap_len(num_columns: usize) -> usize {
2457        (num_columns + 7 + T::BIT_OFFSET) / 8
2458    }
2459
2460    fn byte_and_bit(&self, column_index: usize) -> (usize, u8) {
2461        let offset = column_index + T::BIT_OFFSET;
2462        let byte = offset / 8;
2463        let bit = 1 << (offset % 8) as u8;
2464
2465        assert!(byte < self.0.as_ref().len());
2466
2467        (byte, bit)
2468    }
2469
2470    /// Creates new null-bitmap from given bytes.
2471    pub fn from_bytes(bytes: U) -> Self {
2472        Self(bytes, PhantomData)
2473    }
2474
2475    /// Returns `true` if given column is `NULL` in this `NullBitmap`.
2476    pub fn is_null(&self, column_index: usize) -> bool {
2477        let (byte, bit) = self.byte_and_bit(column_index);
2478        self.0.as_ref()[byte] & bit > 0
2479    }
2480}
2481
2482impl<T: SerializationSide, U: AsRef<[u8]> + AsMut<[u8]>> NullBitmap<T, U> {
2483    /// Sets flag value for given column.
2484    pub fn set(&mut self, column_index: usize, is_null: bool) {
2485        let (byte, bit) = self.byte_and_bit(column_index);
2486        if is_null {
2487            self.0.as_mut()[byte] |= bit
2488        } else {
2489            self.0.as_mut()[byte] &= !bit
2490        }
2491    }
2492}
2493
2494impl<T, U: AsRef<[u8]>> AsRef<[u8]> for NullBitmap<T, U> {
2495    fn as_ref(&self) -> &[u8] {
2496        self.0.as_ref()
2497    }
2498}
2499
/// Builder for [`ComStmtExecuteRequest`].
#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtExecuteRequestBuilder {
    /// Id of the prepared statement to execute.
    pub stmt_id: u32,
}
2504
2505impl ComStmtExecuteRequestBuilder {
2506    pub const NULL_BITMAP_OFFSET: usize = 10;
2507
2508    pub fn new(stmt_id: u32) -> Self {
2509        Self { stmt_id }
2510    }
2511}
2512
2513impl ComStmtExecuteRequestBuilder {
2514    pub fn build(self, params: &[Value]) -> (ComStmtExecuteRequest<'_>, bool) {
2515        let bitmap_len = NullBitmap::<ClientSide>::bitmap_len(params.len());
2516
2517        let mut bitmap_bytes = vec![0; bitmap_len];
2518        let mut bitmap = NullBitmap::<ClientSide, _>::from_bytes(&mut bitmap_bytes);
2519        let params = params.iter().collect::<Vec<_>>();
2520
2521        let meta_len = params.len() * 2;
2522
2523        let mut data_len = 0;
2524        for (i, param) in params.iter().enumerate() {
2525            match param.bin_len() as usize {
2526                0 => bitmap.set(i, true),
2527                x => data_len += x,
2528            }
2529        }
2530
2531        let total_len = 10 + bitmap_len + 1 + meta_len + data_len;
2532
2533        let as_long_data = total_len > MAX_PAYLOAD_LEN;
2534
2535        (
2536            ComStmtExecuteRequest {
2537                com_stmt_execute: ConstU8::new(),
2538                stmt_id: RawInt::new(self.stmt_id),
2539                flags: Const::new(CursorType::CURSOR_TYPE_NO_CURSOR),
2540                iteration_count: ConstU32::new(),
2541                params_flags: Const::new(StmtExecuteParamsFlags::NEW_PARAMS_BOUND),
2542                bitmap: RawBytes::new(bitmap_bytes),
2543                params,
2544                as_long_data,
2545            },
2546            as_long_data,
2547        )
2548    }
2549}
2550
// Command header of a `COM_STMT_EXECUTE` packet. This invocation uses the
// `(name, command, error)` form of `define_header!`. That arm is not visible here, so
// presumably it expects the `COM_STMT_EXECUTE` command byte — confirm.
define_header!(
    ComStmtExecuteHeader,
    COM_STMT_EXECUTE,
    InvalidComStmtExecuteHeader
);
2556
// Iteration count field of `COM_STMT_EXECUTE`: a `ConstU32` fixed to `1`.
// `InvalidIterationCount` is the error type generated for it.
define_const!(
    ConstU32,
    IterationCount,
    InvalidIterationCount("Invalid iteration count for COM_STMT_EXECUTE"),
    1
);
2563
/// `COM_STMT_EXECUTE` request, created via [`ComStmtExecuteRequestBuilder::build`].
#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtExecuteRequest<'a> {
    /// Command header.
    com_stmt_execute: ComStmtExecuteHeader,
    /// Statement id (little-endian `u32`).
    stmt_id: RawInt<LeU32>,
    /// Cursor type (one byte).
    flags: Const<CursorType, u8>,
    /// Iteration count (always `1`).
    iteration_count: IterationCount,
    // max params / bits per byte = 8192
    /// Null bitmap; written only when there are parameters.
    bitmap: RawBytes<'a, BareBytes<8192>>,
    /// Params flags byte; written only when there are parameters.
    params_flags: Const<StmtExecuteParamsFlags, u8>,
    /// Parameter values.
    params: Vec<&'a Value>,
    /// Whether `Bytes` parameters are sent separately instead of inline.
    as_long_data: bool,
}
2576
2577impl<'a> ComStmtExecuteRequest<'a> {
2578    pub fn stmt_id(&self) -> u32 {
2579        self.stmt_id.0
2580    }
2581
2582    pub fn flags(&self) -> CursorType {
2583        self.flags.0
2584    }
2585
2586    pub fn bitmap(&self) -> &[u8] {
2587        self.bitmap.as_bytes()
2588    }
2589
2590    pub fn params_flags(&self) -> StmtExecuteParamsFlags {
2591        self.params_flags.0
2592    }
2593
2594    pub fn params(&self) -> &[&'a Value] {
2595        self.params.as_ref()
2596    }
2597
2598    pub fn as_long_data(&self) -> bool {
2599        self.as_long_data
2600    }
2601}
2602
2603impl MySerialize for ComStmtExecuteRequest<'_> {
2604    fn serialize(&self, buf: &mut Vec<u8>) {
2605        self.com_stmt_execute.serialize(&mut *buf);
2606        self.stmt_id.serialize(&mut *buf);
2607        self.flags.serialize(&mut *buf);
2608        self.iteration_count.serialize(&mut *buf);
2609
2610        if !self.params.is_empty() {
2611            self.bitmap.serialize(&mut *buf);
2612            self.params_flags.serialize(&mut *buf);
2613        }
2614
2615        for param in &self.params {
2616            let (column_type, flags) = match param {
2617                Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()),
2618                Value::Bytes(_) => (
2619                    ColumnType::MYSQL_TYPE_VAR_STRING,
2620                    StmtExecuteParamFlags::empty(),
2621                ),
2622                Value::Int(_) => (
2623                    ColumnType::MYSQL_TYPE_LONGLONG,
2624                    StmtExecuteParamFlags::empty(),
2625                ),
2626                Value::UInt(_) => (
2627                    ColumnType::MYSQL_TYPE_LONGLONG,
2628                    StmtExecuteParamFlags::UNSIGNED,
2629                ),
2630                Value::Float(_) => (ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty()),
2631                Value::Double(_) => (
2632                    ColumnType::MYSQL_TYPE_DOUBLE,
2633                    StmtExecuteParamFlags::empty(),
2634                ),
2635                Value::Date(..) => (
2636                    ColumnType::MYSQL_TYPE_DATETIME,
2637                    StmtExecuteParamFlags::empty(),
2638                ),
2639                Value::Time(..) => (ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty()),
2640            };
2641
2642            buf.put_slice(&[column_type as u8, flags.bits()]);
2643        }
2644
2645        for param in &self.params {
2646            match **param {
2647                Value::Int(_)
2648                | Value::UInt(_)
2649                | Value::Float(_)
2650                | Value::Double(_)
2651                | Value::Date(..)
2652                | Value::Time(..) => {
2653                    param.serialize(buf);
2654                }
2655                Value::Bytes(_) if !self.as_long_data => {
2656                    param.serialize(buf);
2657                }
2658                Value::Bytes(_) | Value::NULL => {}
2659            }
2660        }
2661    }
2662}
2663
// Command header of a `COM_STMT_SEND_LONG_DATA` packet. This uses the `(name, command, error)`
// form of `define_header!`, which presumably expects the `COM_STMT_SEND_LONG_DATA` command
// byte — confirm.
define_header!(
    ComStmtSendLongDataHeader,
    COM_STMT_SEND_LONG_DATA,
    InvalidComStmtSendLongDataHeader
);
2669
/// `COM_STMT_SEND_LONG_DATA` packet: carries data for a single parameter of a prepared
/// statement.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ComStmtSendLongData<'a> {
    /// Command header.
    __header: ComStmtSendLongDataHeader,
    /// Statement id (little-endian `u32`).
    stmt_id: RawInt<LeU32>,
    /// Parameter index (little-endian `u16`).
    param_index: RawInt<LeU16>,
    /// Parameter data (`EofBytes`: presumably extends to the end of the packet).
    data: RawBytes<'a, EofBytes>,
}
2677
2678impl<'a> ComStmtSendLongData<'a> {
2679    pub fn new(stmt_id: u32, param_index: u16, data: impl Into<Cow<'a, [u8]>>) -> Self {
2680        Self {
2681            __header: ComStmtSendLongDataHeader::new(),
2682            stmt_id: RawInt::new(stmt_id),
2683            param_index: RawInt::new(param_index),
2684            data: RawBytes::new(data),
2685        }
2686    }
2687
2688    pub fn into_owned(self) -> ComStmtSendLongData<'static> {
2689        ComStmtSendLongData {
2690            __header: self.__header,
2691            stmt_id: self.stmt_id,
2692            param_index: self.param_index,
2693            data: self.data.into_owned(),
2694        }
2695    }
2696}
2697
2698impl MySerialize for ComStmtSendLongData<'_> {
2699    fn serialize(&self, buf: &mut Vec<u8>) {
2700        self.__header.serialize(&mut *buf);
2701        self.stmt_id.serialize(&mut *buf);
2702        self.param_index.serialize(&mut *buf);
2703        self.data.serialize(&mut *buf);
2704    }
2705}
2706
/// `COM_STMT_CLOSE` packet. It serializes as the command byte followed by the little-endian
/// statement id.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct ComStmtClose {
    /// Id of the statement to close.
    pub stmt_id: u32,
}
2711
2712impl ComStmtClose {
2713    pub fn new(stmt_id: u32) -> Self {
2714        Self { stmt_id }
2715    }
2716}
2717
2718impl MySerialize for ComStmtClose {
2719    fn serialize(&self, buf: &mut Vec<u8>) {
2720        buf.put_u8(Command::COM_STMT_CLOSE as u8);
2721        buf.put_u32_le(self.stmt_id);
2722    }
2723}
2724
// Command header of a `COM_REGISTER_SLAVE` packet. This uses the `(name, command, error)` form
// of `define_header!`, which presumably expects the `COM_REGISTER_SLAVE` command byte —
// confirm.
define_header!(
    ComRegisterSlaveHeader,
    COM_REGISTER_SLAVE,
    InvalidComRegisterSlaveHeader
);
2730
/// Registers a slave at the master. Should be sent before requesting binlog events
/// with `COM_BINLOG_DUMP`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ComRegisterSlave<'a> {
    /// Command header byte.
    header: ComRegisterSlaveHeader,
    /// The slave's server-id.
    server_id: RawInt<LeU32>,
    /// The host name or IP address of the slave to be reported to the master during slave
    /// registration. Usually empty.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 255 bytes.
    hostname: RawBytes<'a, U8Bytes>,
    /// The account user name of the slave to be reported to the master during slave registration.
    /// Usually empty.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 255 bytes.
    user: RawBytes<'a, U8Bytes>,
    /// The account password of the slave to be reported to the master during slave registration.
    /// Usually empty.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 255 bytes.
    password: RawBytes<'a, U8Bytes>,
    /// The TCP/IP port number for connecting to the slave, to be reported to the master during
    /// slave registration. `0` unless set via `with_port`.
    port: RawInt<LeU16>,
    /// Ignored.
    replication_rank: RawInt<LeU32>,
    /// Usually 0. Appears as "master id" in `SHOW SLAVE HOSTS` on the master. Unknown what else
    /// it impacts.
    master_id: RawInt<LeU32>,
}
2768
2769impl<'a> ComRegisterSlave<'a> {
2770    /// Creates new `ComRegisterSlave` with the given server identifier. Other fields will be empty.
2771    pub fn new(server_id: u32) -> Self {
2772        Self {
2773            header: Default::default(),
2774            server_id: RawInt::new(server_id),
2775            hostname: Default::default(),
2776            user: Default::default(),
2777            password: Default::default(),
2778            port: Default::default(),
2779            replication_rank: Default::default(),
2780            master_id: Default::default(),
2781        }
2782    }
2783
2784    /// Sets the `hostname` field of the packet (maximum length is 255 bytes).
2785    pub fn with_hostname(mut self, hostname: impl Into<Cow<'a, [u8]>>) -> Self {
2786        self.hostname = RawBytes::new(hostname);
2787        self
2788    }
2789
2790    /// Sets the `user` field of the packet (maximum length is 255 bytes).
2791    pub fn with_user(mut self, user: impl Into<Cow<'a, [u8]>>) -> Self {
2792        self.user = RawBytes::new(user);
2793        self
2794    }
2795
2796    /// Sets the `password` field of the packet (maximum length is 255 bytes).
2797    pub fn with_password(mut self, password: impl Into<Cow<'a, [u8]>>) -> Self {
2798        self.password = RawBytes::new(password);
2799        self
2800    }
2801
2802    /// Sets the `port` field of the packet.
2803    pub fn with_port(mut self, port: u16) -> Self {
2804        self.port = RawInt::new(port);
2805        self
2806    }
2807
2808    /// Sets the `replication_rank` field of the packet.
2809    pub fn with_replication_rank(mut self, replication_rank: u32) -> Self {
2810        self.replication_rank = RawInt::new(replication_rank);
2811        self
2812    }
2813
2814    /// Sets the `master_id` field of the packet.
2815    pub fn with_master_id(mut self, master_id: u32) -> Self {
2816        self.master_id = RawInt::new(master_id);
2817        self
2818    }
2819
2820    /// Returns the `server_id` field of the packet.
2821    pub fn server_id(&self) -> u32 {
2822        self.server_id.0
2823    }
2824
2825    /// Returns the raw `hostname` field value.
2826    pub fn hostname_raw(&self) -> &[u8] {
2827        self.hostname.as_bytes()
2828    }
2829
2830    /// Returns the `hostname` field as a UTF-8 string (lossy converted).
2831    pub fn hostname(&'a self) -> Cow<'a, str> {
2832        self.hostname.as_str()
2833    }
2834
2835    /// Returns the raw `user` field value.
2836    pub fn user_raw(&self) -> &[u8] {
2837        self.user.as_bytes()
2838    }
2839
2840    /// Returns the `user` field as a UTF-8 string (lossy converted).
2841    pub fn user(&'a self) -> Cow<'a, str> {
2842        self.user.as_str()
2843    }
2844
2845    /// Returns the raw `password` field value.
2846    pub fn password_raw(&self) -> &[u8] {
2847        self.password.as_bytes()
2848    }
2849
2850    /// Returns the `password` field as a UTF-8 string (lossy converted).
2851    pub fn password(&'a self) -> Cow<'a, str> {
2852        self.password.as_str()
2853    }
2854
2855    /// Returns the `port` field of the packet.
2856    pub fn port(&self) -> u16 {
2857        self.port.0
2858    }
2859
2860    /// Returns the `replication_rank` field of the packet.
2861    pub fn replication_rank(&self) -> u32 {
2862        self.replication_rank.0
2863    }
2864
2865    /// Returns the `master_id` field of the packet.
2866    pub fn master_id(&self) -> u32 {
2867        self.master_id.0
2868    }
2869}
2870
2871impl MySerialize for ComRegisterSlave<'_> {
2872    fn serialize(&self, buf: &mut Vec<u8>) {
2873        self.header.serialize(&mut *buf);
2874        self.server_id.serialize(&mut *buf);
2875        self.hostname.serialize(&mut *buf);
2876        self.user.serialize(&mut *buf);
2877        self.password.serialize(&mut *buf);
2878        self.port.serialize(&mut *buf);
2879        self.replication_rank.serialize(&mut *buf);
2880        self.master_id.serialize(&mut *buf);
2881    }
2882}
2883
2884impl<'de> MyDeserialize<'de> for ComRegisterSlave<'de> {
2885    const SIZE: Option<usize> = None;
2886    type Ctx = ();
2887
2888    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2889        let mut sbuf: ParseBuf<'_> = buf.parse(5)?;
2890        let header = sbuf.parse_unchecked(())?;
2891        let server_id = sbuf.parse_unchecked(())?;
2892
2893        let hostname = buf.parse(())?;
2894        let user = buf.parse(())?;
2895        let password = buf.parse(())?;
2896
2897        let mut sbuf: ParseBuf<'_> = buf.parse(10)?;
2898        let port = sbuf.parse_unchecked(())?;
2899        let replication_rank = sbuf.parse_unchecked(())?;
2900        let master_id = sbuf.parse_unchecked(())?;
2901
2902        Ok(Self {
2903            header,
2904            server_id,
2905            hostname,
2906            user,
2907            password,
2908            port,
2909            replication_rank,
2910            master_id,
2911        })
2912    }
2913}
2914
// Header type for `ComTableDump`: the `COM_TABLE_DUMP` command byte, with
// `InvalidComTableDumpHeader` as the associated error type (see `define_header!`).
define_header!(
    ComTableDumpHeader,
    COM_TABLE_DUMP,
    InvalidComTableDumpHeader
);
2920
/// `COM_TABLE_DUMP` command packet naming a `table` within a `database`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ComTableDump<'a> {
    /// Command header byte.
    header: ComTableDumpHeader,
    /// Database name.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 255 bytes.
    database: RawBytes<'a, U8Bytes>,
    /// Table name.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 255 bytes.
    table: RawBytes<'a, U8Bytes>,
}
2938
2939impl<'a> ComTableDump<'a> {
2940    /// Creates new instance.
2941    pub fn new(database: impl Into<Cow<'a, [u8]>>, table: impl Into<Cow<'a, [u8]>>) -> Self {
2942        Self {
2943            header: Default::default(),
2944            database: RawBytes::new(database),
2945            table: RawBytes::new(table),
2946        }
2947    }
2948
2949    /// Returns the raw `database` field value.
2950    pub fn database_raw(&self) -> &[u8] {
2951        self.database.as_bytes()
2952    }
2953
2954    /// Returns the `database` field value as a UTF-8 string (lossy converted).
2955    pub fn database(&self) -> Cow<'_, str> {
2956        self.database.as_str()
2957    }
2958
2959    /// Returns the raw `table` field value.
2960    pub fn table_raw(&self) -> &[u8] {
2961        self.table.as_bytes()
2962    }
2963
2964    /// Returns the `table` field value as a UTF-8 string (lossy converted).
2965    pub fn table(&self) -> Cow<'_, str> {
2966        self.table.as_str()
2967    }
2968}
2969
2970impl MySerialize for ComTableDump<'_> {
2971    fn serialize(&self, buf: &mut Vec<u8>) {
2972        self.header.serialize(&mut *buf);
2973        self.database.serialize(&mut *buf);
2974        self.table.serialize(&mut *buf);
2975    }
2976}
2977
2978impl<'de> MyDeserialize<'de> for ComTableDump<'de> {
2979    const SIZE: Option<usize> = None;
2980    type Ctx = ();
2981
2982    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2983        Ok(Self {
2984            header: buf.parse(())?,
2985            database: buf.parse(())?,
2986            table: buf.parse(())?,
2987        })
2988    }
2989}
2990
my_bitflags! {
    BinlogDumpFlags,
    #[error("Unknown flags in the raw value of BinlogDumpFlags (raw={0:b})")]
    UnknownBinlogDumpFlags,
    u16,

    /// Flags of the `COM_BINLOG_DUMP` and `COM_BINLOG_DUMP_GTID` commands.
    #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
    pub struct BinlogDumpFlags: u16 {
        /// If there are no more events to send, an EOF_Packet is sent instead of blocking the
        /// connection.
        const BINLOG_DUMP_NON_BLOCK = 0x01;
        /// NOTE(review): semantics not shown in this file; presumably requests streaming by
        /// binlog position — confirm against the server documentation.
        const BINLOG_THROUGH_POSITION = 0x02;
        /// Set by `ComBinlogDumpGtid`'s builder methods whenever its SID block is non-empty.
        const BINLOG_THROUGH_GTID = 0x04;
    }
}
3006
// Header type for `ComBinlogDump`: the `COM_BINLOG_DUMP` command byte, with
// `InvalidComBinlogDumpHeader` as the associated error type (see `define_header!`).
define_header!(
    ComBinlogDumpHeader,
    COM_BINLOG_DUMP,
    InvalidComBinlogDumpHeader
);
3012
/// Command to request a binlog-stream from the master starting at a given position
/// (`COM_BINLOG_DUMP`).
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ComBinlogDump<'a> {
    /// Command header byte.
    header: ComBinlogDumpHeader,
    /// Position in the binlog-file to start the stream with (`0` by default).
    pos: RawInt<LeU32>,
    /// Command flags (empty by default).
    ///
    /// Only `BINLOG_DUMP_NON_BLOCK` is supported for this command.
    flags: Const<BinlogDumpFlags, LeU16>,
    /// Server id of this slave.
    server_id: RawInt<LeU32>,
    /// Filename of the binlog on the master.
    ///
    /// If the binlog-filename is empty, the server will send the binlog-stream of the first known
    /// binlog. Serialized as the last field using `EofBytes` (presumably unprefixed, running to
    /// the end of the packet).
    filename: RawBytes<'a, EofBytes>,
}
3031
3032impl<'a> ComBinlogDump<'a> {
3033    /// Creates new instance with default values for `pos` and `flags`.
3034    pub fn new(server_id: u32) -> Self {
3035        Self {
3036            header: Default::default(),
3037            pos: Default::default(),
3038            flags: Default::default(),
3039            server_id: RawInt::new(server_id),
3040            filename: Default::default(),
3041        }
3042    }
3043
3044    /// Defines position for this instance.
3045    pub fn with_pos(mut self, pos: u32) -> Self {
3046        self.pos = RawInt::new(pos);
3047        self
3048    }
3049
3050    /// Defines flags for this instance.
3051    pub fn with_flags(mut self, flags: BinlogDumpFlags) -> Self {
3052        self.flags = Const::new(flags);
3053        self
3054    }
3055
3056    /// Defines filename for this instance.
3057    pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3058        self.filename = RawBytes::new(filename);
3059        self
3060    }
3061
3062    /// Returns parsed `pos` field with unknown bits truncated.
3063    pub fn pos(&self) -> u32 {
3064        *self.pos
3065    }
3066
3067    /// Returns parsed `flags` field with unknown bits truncated.
3068    pub fn flags(&self) -> BinlogDumpFlags {
3069        *self.flags
3070    }
3071
3072    /// Returns parsed `server_id` field with unknown bits truncated.
3073    pub fn server_id(&self) -> u32 {
3074        *self.server_id
3075    }
3076
3077    /// Returns the raw `filename` field value.
3078    pub fn filename_raw(&self) -> &[u8] {
3079        self.filename.as_bytes()
3080    }
3081
3082    /// Returns the `filename` field value as a UTF-8 string (lossy converted).
3083    pub fn filename(&self) -> Cow<'_, str> {
3084        self.filename.as_str()
3085    }
3086}
3087
3088impl MySerialize for ComBinlogDump<'_> {
3089    fn serialize(&self, buf: &mut Vec<u8>) {
3090        self.header.serialize(&mut *buf);
3091        self.pos.serialize(&mut *buf);
3092        self.flags.serialize(&mut *buf);
3093        self.server_id.serialize(&mut *buf);
3094        self.filename.serialize(&mut *buf);
3095    }
3096}
3097
3098impl<'de> MyDeserialize<'de> for ComBinlogDump<'de> {
3099    const SIZE: Option<usize> = None;
3100    type Ctx = ();
3101
3102    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3103        let mut sbuf: ParseBuf<'_> = buf.parse(11)?;
3104        Ok(Self {
3105            header: sbuf.parse_unchecked(())?,
3106            pos: sbuf.parse_unchecked(())?,
3107            flags: sbuf.parse_unchecked(())?,
3108            server_id: sbuf.parse_unchecked(())?,
3109            filename: buf.parse(())?,
3110        })
3111    }
3112}
3113
/// Half-open `[start, end)` interval of GTID sequence numbers (GNOs), stored within a [`Sid`].
///
/// `GnoInterval::new` does not validate the bounds; `GnoInterval::check_and_new` does.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct GnoInterval {
    /// First GNO of the interval (inclusive); little-endian `u64` on the wire.
    start: RawInt<LeU64>,
    /// End of the interval (exclusive); little-endian `u64` on the wire.
    end: RawInt<LeU64>,
}
3120
3121impl GnoInterval {
3122    /// Creates a new interval.
3123    pub fn new(start: u64, end: u64) -> Self {
3124        Self {
3125            start: RawInt::new(start),
3126            end: RawInt::new(end),
3127        }
3128    }
3129    /// Checks if the [start, end) interval is valid and creates it.
3130    pub fn check_and_new(start: u64, end: u64) -> io::Result<Self> {
3131        if start >= end {
3132            return Err(io::Error::new(
3133                io::ErrorKind::InvalidData,
3134                format!("start({}) >= end({}) in GnoInterval", start, end),
3135            ));
3136        }
3137        if start == 0 || end == 0 {
3138            return Err(io::Error::new(
3139                io::ErrorKind::InvalidData,
3140                "Gno can't be zero",
3141            ));
3142        }
3143        Ok(Self::new(start, end))
3144    }
3145}
3146
3147impl MySerialize for GnoInterval {
3148    fn serialize(&self, buf: &mut Vec<u8>) {
3149        self.start.serialize(&mut *buf);
3150        self.end.serialize(&mut *buf);
3151    }
3152}
3153
3154impl<'de> MyDeserialize<'de> for GnoInterval {
3155    const SIZE: Option<usize> = Some(16);
3156    type Ctx = ();
3157
3158    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3159        Ok(Self {
3160            start: buf.parse_unchecked(())?,
3161            end: buf.parse_unchecked(())?,
3162        })
3163    }
3164}
3165
/// Length in bytes of a UUID in the `COM_BINLOG_DUMP_GTID` command packet (128 bits).
pub const UUID_LEN: usize = std::mem::size_of::<u128>();
3168
/// SID is a part of the `COM_BINLOG_DUMP_GTID` command. It's a GTID set that
/// has only one UUID.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Sid<'a> {
    /// Raw UUID bytes.
    uuid: [u8; UUID_LEN],
    /// GNO intervals for this UUID, serialized after an 8-byte interval count.
    intervals: Seq<'a, GnoInterval, LeU64>,
}
3176
3177impl Sid<'_> {
3178    /// Creates a new instance.
3179    pub fn new(uuid: [u8; UUID_LEN]) -> Self {
3180        Self {
3181            uuid,
3182            intervals: Default::default(),
3183        }
3184    }
3185
3186    /// Returns the `uuid` field value.
3187    pub fn uuid(&self) -> [u8; UUID_LEN] {
3188        self.uuid
3189    }
3190
3191    /// Returns the `intervals` field value.
3192    pub fn intervals(&self) -> &[GnoInterval] {
3193        &self.intervals[..]
3194    }
3195
3196    /// Appends an GnoInterval to this block.
3197    pub fn with_interval(mut self, interval: GnoInterval) -> Self {
3198        let mut intervals = self.intervals.0.into_owned();
3199        intervals.push(interval);
3200        self.intervals = Seq::new(intervals);
3201        self
3202    }
3203
3204    /// Sets the `intevals` value for this block.
3205    pub fn with_intervals(mut self, intervals: Vec<GnoInterval>) -> Self {
3206        self.intervals = Seq::new(intervals);
3207        self
3208    }
3209
3210    fn len(&self) -> u64 {
3211        use saturating::Saturating as S;
3212        let mut len = S(UUID_LEN as u64); // SID
3213        len += S(8); // n_intervals
3214        len += S((self.intervals.len() * 16) as u64);
3215        len.0
3216    }
3217}
3218
3219impl MySerialize for Sid<'_> {
3220    fn serialize(&self, buf: &mut Vec<u8>) {
3221        self.uuid.serialize(&mut *buf);
3222        self.intervals.serialize(buf);
3223    }
3224}
3225
3226impl<'de> MyDeserialize<'de> for Sid<'de> {
3227    const SIZE: Option<usize> = None;
3228    type Ctx = ();
3229
3230    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3231        Ok(Self {
3232            uuid: buf.parse(())?,
3233            intervals: buf.parse(())?,
3234        })
3235    }
3236}
3237
3238impl Sid<'_> {
3239    fn wrap_err(msg: String) -> io::Error {
3240        io::Error::new(io::ErrorKind::InvalidInput, msg)
3241    }
3242
3243    fn parse_interval_num(to_parse: &str, full: &str) -> Result<u64, io::Error> {
3244        let n: u64 = to_parse.parse().map_err(|e| {
3245            Sid::wrap_err(format!(
3246                "invalid GnoInterval format: {}, error: {}",
3247                full, e
3248            ))
3249        })?;
3250        Ok(n)
3251    }
3252}
3253
3254impl FromStr for Sid<'_> {
3255    type Err = io::Error;
3256
3257    fn from_str(s: &str) -> Result<Self, Self::Err> {
3258        let (uuid, intervals) = s
3259            .split_once(':')
3260            .ok_or_else(|| Sid::wrap_err(format!("invalid sid format: {}", s)))?;
3261        let uuid = Uuid::parse_str(uuid)
3262            .map_err(|e| Sid::wrap_err(format!("invalid uuid format: {}, error: {}", s, e)))?;
3263        let intervals = intervals
3264            .split(':')
3265            .map(|interval| {
3266                let nums = interval.split('-').collect::<Vec<_>>();
3267                if nums.len() != 1 && nums.len() != 2 {
3268                    return Err(Sid::wrap_err(format!("invalid GnoInterval format: {}", s)));
3269                }
3270                if nums.len() == 1 {
3271                    let start = Sid::parse_interval_num(nums[0], s)?;
3272                    let interval = GnoInterval::check_and_new(start, start + 1)?;
3273                    Ok(interval)
3274                } else {
3275                    let start = Sid::parse_interval_num(nums[0], s)?;
3276                    let end = Sid::parse_interval_num(nums[1], s)?;
3277                    let interval = GnoInterval::check_and_new(start, end + 1)?;
3278                    Ok(interval)
3279                }
3280            })
3281            .collect::<Result<Vec<_>, _>>()?;
3282        Ok(Self {
3283            uuid: *uuid.as_bytes(),
3284            intervals: Seq::new(intervals),
3285        })
3286    }
3287}
3288
// Header type for `ComBinlogDumpGtid`: the `COM_BINLOG_DUMP_GTID` command byte, with
// `InvalidComBinlogDumpGtidHeader` as the associated error type (see `define_header!`).
define_header!(
    ComBinlogDumpGtidHeader,
    COM_BINLOG_DUMP_GTID,
    InvalidComBinlogDumpGtidHeader
);
3294
/// Command to request a binlog-stream from the master (`COM_BINLOG_DUMP_GTID`), starting at a
/// given position; the SID block optionally carries GTID sets (see [`Sid`]).
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ComBinlogDumpGtid<'a> {
    /// Command header byte.
    header: ComBinlogDumpGtidHeader,
    /// Command flags (empty by default).
    ///
    /// The builder methods keep `BINLOG_THROUGH_GTID` set exactly when `sid_block` is
    /// non-empty; deserialization does not enforce this.
    flags: Const<BinlogDumpFlags, LeU16>,
    /// Server id of this slave.
    server_id: RawInt<LeU32>,
    /// Filename of the binlog on the master.
    ///
    /// If the binlog-filename is empty, the server will send the binlog-stream of the first known
    /// binlog.
    ///
    /// # Note
    ///
    /// Serialization will truncate this value if length is greater than 2^32 - 1 bytes.
    filename: RawBytes<'a, U32Bytes>,
    /// Position in the binlog-file to start the stream with (`0` by default).
    pos: RawInt<LeU64>,
    /// SID block. On the wire it is preceded by its byte length as a little-endian `u32` and
    /// starts with an 8-byte SID count.
    sid_block: Seq<'a, Sid<'a>, LeU64>,
}
3317
3318impl<'a> ComBinlogDumpGtid<'a> {
3319    /// Creates new instance with default values for `pos`, `data` and `flags` fields.
3320    pub fn new(server_id: u32) -> Self {
3321        Self {
3322            header: Default::default(),
3323            pos: Default::default(),
3324            flags: Default::default(),
3325            server_id: RawInt::new(server_id),
3326            filename: Default::default(),
3327            sid_block: Default::default(),
3328        }
3329    }
3330
3331    /// Returns the `server_id` field value.
3332    pub fn server_id(&self) -> u32 {
3333        self.server_id.0
3334    }
3335
3336    /// Returns the `flags` field value.
3337    pub fn flags(&self) -> BinlogDumpFlags {
3338        self.flags.0
3339    }
3340
3341    /// Returns the `filename` field value.
3342    pub fn filename_raw(&self) -> &[u8] {
3343        self.filename.as_bytes()
3344    }
3345
3346    /// Returns the `filename` field value as a UTF-8 string (lossy converted).
3347    pub fn filename(&self) -> Cow<'_, str> {
3348        self.filename.as_str()
3349    }
3350
3351    /// Returns the `pos` field value.
3352    pub fn pos(&self) -> u64 {
3353        self.pos.0
3354    }
3355
3356    /// Returns the sequence of sids in this packet.
3357    pub fn sids(&self) -> &[Sid<'a>] {
3358        &self.sid_block
3359    }
3360
3361    /// Defines filename for this instance.
3362    pub fn with_filename(self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3363        Self {
3364            header: self.header,
3365            flags: self.flags,
3366            server_id: self.server_id,
3367            filename: RawBytes::new(filename),
3368            pos: self.pos,
3369            sid_block: self.sid_block,
3370        }
3371    }
3372
3373    /// Sets the `server_id` field value.
3374    pub fn with_server_id(mut self, server_id: u32) -> Self {
3375        self.server_id.0 = server_id;
3376        self
3377    }
3378
3379    /// Sets the `flags` field value.
3380    pub fn with_flags(mut self, mut flags: BinlogDumpFlags) -> Self {
3381        if self.sid_block.is_empty() {
3382            flags.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3383        } else {
3384            flags.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3385        }
3386        self.flags.0 = flags;
3387        self
3388    }
3389
3390    /// Sets the `pos` field value.
3391    pub fn with_pos(mut self, pos: u64) -> Self {
3392        self.pos.0 = pos;
3393        self
3394    }
3395
3396    /// Sets the `sid_block` field value.
3397    pub fn with_sid(mut self, sid: Sid<'a>) -> Self {
3398        self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3399        self.sid_block.push(sid);
3400        self
3401    }
3402
3403    /// Sets the `sid_block` field value.
3404    pub fn with_sids(mut self, sids: impl Into<Cow<'a, [Sid<'a>]>>) -> Self {
3405        self.sid_block = Seq::new(sids);
3406        if self.sid_block.is_empty() {
3407            self.flags.0.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3408        } else {
3409            self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3410        }
3411        self
3412    }
3413
3414    fn sid_block_len(&self) -> u32 {
3415        use saturating::Saturating as S;
3416        let mut len = S(8); // n_sids
3417        for sid in self.sid_block.iter() {
3418            len += S(sid.len() as u32);
3419        }
3420        len.0
3421    }
3422}
3423
3424impl MySerialize for ComBinlogDumpGtid<'_> {
3425    fn serialize(&self, buf: &mut Vec<u8>) {
3426        self.header.serialize(&mut *buf);
3427        self.flags.serialize(&mut *buf);
3428        self.server_id.serialize(&mut *buf);
3429        self.filename.serialize(&mut *buf);
3430        self.pos.serialize(&mut *buf);
3431        buf.put_u32_le(self.sid_block_len());
3432        self.sid_block.serialize(&mut *buf);
3433    }
3434}
3435
3436impl<'de> MyDeserialize<'de> for ComBinlogDumpGtid<'de> {
3437    const SIZE: Option<usize> = None;
3438    type Ctx = ();
3439
3440    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3441        let mut sbuf: ParseBuf<'_> = buf.parse(7)?;
3442        let header = sbuf.parse_unchecked(())?;
3443        let flags: Const<BinlogDumpFlags, LeU16> = sbuf.parse_unchecked(())?;
3444        let server_id = sbuf.parse_unchecked(())?;
3445
3446        let filename = buf.parse(())?;
3447        let pos = buf.parse(())?;
3448
3449        // `flags` should contain `BINLOG_THROUGH_GTID` flag if sid_block isn't empty
3450        let sid_data_len: RawInt<LeU32> = buf.parse(())?;
3451        let mut buf: ParseBuf<'_> = buf.parse(sid_data_len.0 as usize)?;
3452        let sid_block = buf.parse(())?;
3453
3454        Ok(Self {
3455            header,
3456            flags,
3457            server_id,
3458            filename,
3459            pos,
3460            sid_block,
3461        })
3462    }
3463}
3464
// Header type for `SemiSyncAckPacket`: the single byte `0xEF`, with
// `InvalidSemiSyncAckPacketPacketHeader` ("Invalid semi-sync ack packet header") as the
// associated error type (see `define_header!`).
define_header!(
    SemiSyncAckPacketPacketHeader,
    InvalidSemiSyncAckPacketPacketHeader("Invalid semi-sync ack packet header"),
    0xEF
);
3470
3471/// Each Semi Sync Binlog Event with the `SEMI_SYNC_ACK_REQ` flag set the slave has to acknowledge
3472/// with Semi-Sync ACK packet.
3473pub struct SemiSyncAckPacket<'a> {
3474    header: SemiSyncAckPacketPacketHeader,
3475    position: RawInt<LeU64>,
3476    filename: RawBytes<'a, EofBytes>,
3477}
3478
3479impl<'a> SemiSyncAckPacket<'a> {
3480    pub fn new(position: u64, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3481        Self {
3482            header: Default::default(),
3483            position: RawInt::new(position),
3484            filename: RawBytes::new(filename),
3485        }
3486    }
3487
3488    /// Sets the `position` field value.
3489    pub fn with_position(mut self, position: u64) -> Self {
3490        self.position.0 = position;
3491        self
3492    }
3493
3494    /// Sets the `filename` field value.
3495    pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3496        self.filename = RawBytes::new(filename);
3497        self
3498    }
3499
3500    /// Returns the `position` field value.
3501    pub fn position(&self) -> u64 {
3502        self.position.0
3503    }
3504
3505    /// Returns the raw `filename` field value.
3506    pub fn filename_raw(&self) -> &[u8] {
3507        self.filename.as_bytes()
3508    }
3509
3510    /// Returns the `filename` field value as a string (lossy converted).
3511    pub fn filename(&self) -> Cow<'_, str> {
3512        self.filename.as_str()
3513    }
3514}
3515
3516impl MySerialize for SemiSyncAckPacket<'_> {
3517    fn serialize(&self, buf: &mut Vec<u8>) {
3518        self.header.serialize(&mut *buf);
3519        self.position.serialize(&mut *buf);
3520        self.filename.serialize(&mut *buf);
3521    }
3522}
3523
3524impl<'de> MyDeserialize<'de> for SemiSyncAckPacket<'de> {
3525    const SIZE: Option<usize> = None;
3526    type Ctx = ();
3527
3528    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3529        let mut sbuf: ParseBuf<'_> = buf.parse(9)?;
3530        Ok(Self {
3531            header: sbuf.parse_unchecked(())?,
3532            position: sbuf.parse_unchecked(())?,
3533            filename: buf.parse(())?,
3534        })
3535    }
3536}
3537
3538#[cfg(test)]
3539mod test {
3540    use super::*;
3541    use crate::{
3542        constants::{CapabilityFlags, ColumnFlags, ColumnType, StatusFlags},
3543        proto::{MyDeserialize, MySerialize},
3544    };
3545
3546    proptest::proptest! {
3547        #[test]
3548        fn com_table_dump_roundtrip(database: Vec<u8>, table: Vec<u8>) {
3549            let cmd = ComTableDump::new(database, table);
3550
3551            let mut output = Vec::new();
3552            cmd.serialize(&mut output);
3553
3554            assert_eq!(cmd, ComTableDump::deserialize((), &mut ParseBuf(&output[..]))?);
3555        }
3556
3557        #[test]
3558        fn com_binlog_dump_roundtrip(
3559            server_id: u32,
3560            filename: Vec<u8>,
3561            pos: u32,
3562            flags: u16,
3563        ) {
3564            let cmd = ComBinlogDump::new(server_id)
3565                .with_filename(filename)
3566                .with_pos(pos)
3567                .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));
3568
3569            let mut output = Vec::new();
3570            cmd.serialize(&mut output);
3571
3572            assert_eq!(cmd, ComBinlogDump::deserialize((), &mut ParseBuf(&output[..]))?);
3573        }
3574
3575        #[test]
3576        fn com_register_slave_roundtrip(
3577            server_id: u32,
3578            hostname in r"\w{0,256}",
3579            user in r"\w{0,256}",
3580            password in r"\w{0,256}",
3581            port: u16,
3582            replication_rank: u32,
3583            master_id: u32,
3584        ) {
3585            let cmd = ComRegisterSlave::new(server_id)
3586                .with_hostname(hostname.as_bytes())
3587                .with_user(user.as_bytes())
3588                .with_password(password.as_bytes())
3589                .with_port(port)
3590                .with_replication_rank(replication_rank)
3591                .with_master_id(master_id);
3592
3593            let mut output = Vec::new();
3594            cmd.serialize(&mut output);
3595            let parsed = ComRegisterSlave::deserialize((), &mut ParseBuf(&output[..]))?;
3596
3597            if hostname.len() > 255 || user.len() > 255 || password.len() > 255 {
3598                assert_ne!(cmd, parsed);
3599            } else {
3600                assert_eq!(cmd, parsed);
3601            }
3602        }
3603
3604        #[test]
3605        fn com_binlog_dump_gtid_roundtrip(
3606            flags: u16,
3607            server_id: u32,
3608            filename: Vec<u8>,
3609            pos: u64,
3610            n_sid_blocks in 0_u64..1024,
3611        ) {
3612            let mut cmd = ComBinlogDumpGtid::new(server_id)
3613                .with_filename(filename)
3614                .with_pos(pos)
3615                .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));
3616
3617            let mut sids = Vec::new();
3618            for i in 0..n_sid_blocks {
3619                let mut block = Sid::new([i as u8; 16]);
3620                for j in 0..i {
3621                    block = block.with_interval(GnoInterval::new(i, j));
3622                }
3623                sids.push(block);
3624            }
3625
3626            cmd = cmd.with_sids(sids);
3627
3628            let mut output = Vec::new();
3629            cmd.serialize(&mut output);
3630
3631            assert_eq!(cmd, ComBinlogDumpGtid::deserialize((), &mut ParseBuf(&output[..]))?);
3632        }
3633    }
3634
3635    #[test]
3636    fn should_parse_local_infile_packet() {
3637        const LIP: &[u8] = b"\xfbfile_name";
3638
3639        let lip = LocalInfilePacket::deserialize((), &mut ParseBuf(LIP)).unwrap();
3640        assert_eq!(lip.file_name_str(), "file_name");
3641    }
3642
3643    #[test]
3644    fn should_parse_stmt_packet() {
3645        const SP: &[u8] = b"\x00\x01\x00\x00\x00\x01\x00\x02\x00\x00\x00\x00";
3646        const SP_2: &[u8] = b"\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
3647
3648        let sp = StmtPacket::deserialize((), &mut ParseBuf(SP)).unwrap();
3649        assert_eq!(sp.statement_id(), 0x01);
3650        assert_eq!(sp.num_columns(), 0x01);
3651        assert_eq!(sp.num_params(), 0x02);
3652        assert_eq!(sp.warning_count(), 0x00);
3653
3654        let sp = StmtPacket::deserialize((), &mut ParseBuf(SP_2)).unwrap();
3655        assert_eq!(sp.statement_id(), 0x01);
3656        assert_eq!(sp.num_columns(), 0x00);
3657        assert_eq!(sp.num_params(), 0x00);
3658        assert_eq!(sp.warning_count(), 0x00);
3659    }
3660
3661    #[test]
3662    fn should_parse_handshake_packet() {
3663        const HSP: &[u8] = b"\x0a5.5.5-10.0.17-MariaDB-log\x00\x0b\x00\
3664                             \x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\
3665                             \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x2a\x34\x64\
3666                             \x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00";
3667
3668        const HSP_2: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
3669                               \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
3670                               \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
3671                               \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
3672                               \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
3673                               \x6f\x72\x64\x00";
3674
3675        const HSP_3: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
3676                                \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
3677                                \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
3678                                \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
3679                                \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
3680                                \x6f\x72\x64\x00";
3681
3682        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap();
3683        assert_eq!(hsp.protocol_version(), 0x0a);
3684        assert_eq!(hsp.server_version_str(), "5.5.5-10.0.17-MariaDB-log");
3685        assert_eq!(hsp.server_version_parsed(), Some((5, 5, 5)));
3686        assert_eq!(hsp.maria_db_server_version_parsed(), Some((10, 0, 17)));
3687        assert_eq!(hsp.connection_id(), 0x0b);
3688        assert_eq!(hsp.scramble_1_ref(), b"dvH@I-CJ");
3689        assert_eq!(
3690            hsp.capabilities(),
3691            CapabilityFlags::from_bits_truncate(0xf7ff)
3692        );
3693        assert_eq!(hsp.default_collation(), 0x08);
3694        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3695        assert_eq!(hsp.scramble_2_ref(), Some(&b"*4d|cZwk4^]:\x00"[..]));
3696        assert_eq!(hsp.auth_plugin_name_ref(), None);
3697
3698        let mut output = Vec::new();
3699        hsp.serialize(&mut output);
3700        assert_eq!(&output, HSP);
3701
3702        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_2)).unwrap();
3703        assert_eq!(hsp.protocol_version(), 0x0a);
3704        assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
3705        assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
3706        assert_eq!(hsp.maria_db_server_version_parsed(), None);
3707        assert_eq!(hsp.connection_id(), 0x0a56);
3708        assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
3709        assert_eq!(
3710            hsp.capabilities(),
3711            CapabilityFlags::from_bits_truncate(0xc00fffff)
3712        );
3713        assert_eq!(hsp.default_collation(), 0x08);
3714        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3715        assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
3716        assert_eq!(
3717            hsp.auth_plugin_name_ref(),
3718            Some(&b"mysql_native_password"[..])
3719        );
3720
3721        let mut output = Vec::new();
3722        hsp.serialize(&mut output);
3723        assert_eq!(&output, HSP_2);
3724
3725        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_3)).unwrap();
3726        assert_eq!(hsp.protocol_version(), 0x0a);
3727        assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
3728        assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
3729        assert_eq!(hsp.maria_db_server_version_parsed(), None);
3730        assert_eq!(hsp.connection_id(), 0x0a56);
3731        assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
3732        assert_eq!(
3733            hsp.capabilities(),
3734            CapabilityFlags::from_bits_truncate(0xc00fffff)
3735        );
3736        assert_eq!(hsp.default_collation(), 0x08);
3737        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3738        assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
3739        assert_eq!(
3740            hsp.auth_plugin_name_ref(),
3741            Some(&b"mysql_native_password"[..])
3742        );
3743
3744        let mut output = Vec::new();
3745        hsp.serialize(&mut output);
3746        assert_eq!(&output, HSP_3);
3747    }
3748
3749    #[test]
3750    fn should_parse_err_packet() {
3751        const ERR_PACKET: &[u8] = b"\xff\x48\x04\x23\x48\x59\x30\x30\x30\x4e\x6f\x20\x74\x61\x62\
3752        \x6c\x65\x73\x20\x75\x73\x65\x64";
3753        const ERR_PACKET_NO_STATE: &[u8] = b"\xff\x10\x04\x54\x6f\x6f\x20\x6d\x61\x6e\x79\x20\x63\
3754        \x6f\x6e\x6e\x65\x63\x74\x69\x6f\x6e\x73";
3755        const PROGRESS_PACKET: &[u8] = b"\xff\xff\xff\x01\x01\x0a\xcc\x5b\x00\x0astage name";
3756
3757        let err_packet = ErrPacket::deserialize(
3758            CapabilityFlags::CLIENT_PROTOCOL_41,
3759            &mut ParseBuf(ERR_PACKET),
3760        )
3761        .unwrap();
3762        let err_packet = err_packet.server_error();
3763        assert_eq!(err_packet.error_code(), 1096);
3764        assert_eq!(err_packet.sql_state_ref().unwrap().as_str(), "HY000");
3765        assert_eq!(err_packet.message_str(), "No tables used");
3766
3767        let err_packet =
3768            ErrPacket::deserialize(CapabilityFlags::empty(), &mut ParseBuf(ERR_PACKET_NO_STATE))
3769                .unwrap();
3770        let server_error = err_packet.server_error();
3771        assert_eq!(server_error.error_code(), 1040);
3772        assert_eq!(server_error.sql_state_ref(), None);
3773        assert_eq!(server_error.message_str(), "Too many connections");
3774
3775        let err_packet = ErrPacket::deserialize(
3776            CapabilityFlags::CLIENT_PROGRESS_OBSOLETE,
3777            &mut ParseBuf(PROGRESS_PACKET),
3778        )
3779        .unwrap();
3780        assert!(err_packet.is_progress_report());
3781        let progress_report = err_packet.progress_report();
3782        assert_eq!(progress_report.stage(), 1);
3783        assert_eq!(progress_report.max_stage(), 10);
3784        assert_eq!(progress_report.progress(), 23500);
3785        assert_eq!(progress_report.stage_info_str(), "stage name");
3786    }
3787
3788    #[test]
3789    fn should_parse_column_packet() {
3790        const COLUMN_PACKET: &[u8] = b"\x03def\x06schema\x05table\x09org_table\x04name\
3791              \x08org_name\x0c\x21\x00\x0F\x00\x00\x00\x00\x01\x00\x08\x00\x00";
3792        let column = Column::deserialize((), &mut ParseBuf(COLUMN_PACKET)).unwrap();
3793        assert_eq!(column.schema_str(), "schema");
3794        assert_eq!(column.table_str(), "table");
3795        assert_eq!(column.org_table_str(), "org_table");
3796        assert_eq!(column.name_str(), "name");
3797        assert_eq!(column.org_name_str(), "org_name");
3798        assert_eq!(
3799            column.character_set(),
3800            CollationId::UTF8MB3_GENERAL_CI as u16
3801        );
3802        assert_eq!(column.column_length(), 15);
3803        assert_eq!(column.column_type(), ColumnType::MYSQL_TYPE_DECIMAL);
3804        assert_eq!(column.flags(), ColumnFlags::NOT_NULL_FLAG);
3805        assert_eq!(column.decimals(), 8);
3806    }
3807
3808    #[test]
3809    fn should_parse_auth_switch_request() {
3810        const PAYLOAD: &[u8] = b"\xfe\x6d\x79\x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\
3811                                 \x73\x73\x77\x6f\x72\x64\x00\x7a\x51\x67\x34\x69\x36\x6f\x4e\x79\
3812                                 \x36\x3d\x72\x48\x4e\x2f\x3e\x2d\x62\x29\x41\x00";
3813        let packet = AuthSwitchRequest::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();
3814        assert_eq!(packet.auth_plugin().as_bytes(), b"mysql_native_password",);
3815        assert_eq!(packet.plugin_data(), b"zQg4i6oNy6=rHN/>-b)A",)
3816    }
3817
3818    #[test]
3819    fn should_parse_auth_more_data() {
3820        const PAYLOAD: &[u8] = b"\x01\x04";
3821        let packet = AuthMoreData::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();
3822        assert_eq!(packet.data(), b"\x04",);
3823    }
3824
3825    #[test]
3826    fn should_parse_ok_packet() {
3827        const PLAIN_OK: &[u8] = b"\x00\x01\x00\x02\x00\x00\x00";
3828        const RESULT_SET_TERMINATOR: &[u8] = &[
3829            0xfe, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x42, 0x52, 0x65, 0x61, 0x64, 0x20, 0x31,
3830            0x20, 0x72, 0x6f, 0x77, 0x73, 0x2c, 0x20, 0x31, 0x2e, 0x30, 0x30, 0x20, 0x42, 0x20,
3831            0x69, 0x6e, 0x20, 0x30, 0x2e, 0x30, 0x30, 0x32, 0x20, 0x73, 0x65, 0x63, 0x2e, 0x2c,
3832            0x20, 0x36, 0x31, 0x31, 0x2e, 0x33, 0x34, 0x20, 0x72, 0x6f, 0x77, 0x73, 0x2f, 0x73,
3833            0x65, 0x63, 0x2e, 0x2c, 0x20, 0x36, 0x31, 0x31, 0x2e, 0x33, 0x34, 0x20, 0x42, 0x2f,
3834            0x73, 0x65, 0x63, 0x2e,
3835        ];
3836        const SESS_STATE_SYS_VAR_OK: &[u8] =
3837            b"\x00\x00\x00\x02\x40\x00\x00\x00\x11\x00\x0f\x0a\x61\
3838              \x75\x74\x6f\x63\x6f\x6d\x6d\x69\x74\x03\x4f\x46\x46";
3839        const SESS_STATE_SCHEMA_OK: &[u8] =
3840            b"\x00\x00\x00\x02\x40\x00\x00\x00\x07\x01\x05\x04\x74\x65\x73\x74";
3841        const SESS_STATE_TRACK_OK: &[u8] = b"\x00\x00\x00\x02\x40\x00\x00\x00\x04\x02\x02\x01\x31";
3842        const EOF: &[u8] = b"\xfe\x00\x00\x02\x00";
3843
3844        // packet starting with 0x00 is not an ok packet if it terminates a result set
3845        OkPacketDeserializer::<ResultSetTerminator>::deserialize(
3846            CapabilityFlags::empty(),
3847            &mut ParseBuf(PLAIN_OK),
3848        )
3849        .unwrap_err();
3850
3851        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3852            CapabilityFlags::empty(),
3853            &mut ParseBuf(PLAIN_OK),
3854        )
3855        .unwrap()
3856        .into();
3857        assert_eq!(ok_packet.affected_rows(), 1);
3858        assert_eq!(ok_packet.last_insert_id(), None);
3859        assert_eq!(
3860            ok_packet.status_flags(),
3861            StatusFlags::SERVER_STATUS_AUTOCOMMIT
3862        );
3863        assert_eq!(ok_packet.warnings(), 0);
3864        assert_eq!(ok_packet.info_ref(), None);
3865        assert_eq!(ok_packet.session_state_info_ref(), None);
3866
3867        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3868            CapabilityFlags::CLIENT_SESSION_TRACK,
3869            &mut ParseBuf(PLAIN_OK),
3870        )
3871        .unwrap()
3872        .into();
3873        assert_eq!(ok_packet.affected_rows(), 1);
3874        assert_eq!(ok_packet.last_insert_id(), None);
3875        assert_eq!(
3876            ok_packet.status_flags(),
3877            StatusFlags::SERVER_STATUS_AUTOCOMMIT
3878        );
3879        assert_eq!(ok_packet.warnings(), 0);
3880        assert_eq!(ok_packet.info_ref(), None);
3881        assert_eq!(ok_packet.session_state_info_ref(), None);
3882
3883        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<ResultSetTerminator>::deserialize(
3884            CapabilityFlags::CLIENT_SESSION_TRACK,
3885            &mut ParseBuf(RESULT_SET_TERMINATOR),
3886        )
3887        .unwrap()
3888        .into();
3889        assert_eq!(ok_packet.affected_rows(), 0);
3890        assert_eq!(ok_packet.last_insert_id(), None);
3891        assert_eq!(ok_packet.status_flags(), StatusFlags::empty());
3892        assert_eq!(ok_packet.warnings(), 0);
3893        assert_eq!(
3894            ok_packet.info_str(),
3895            Some(Cow::Borrowed(
3896                "Read 1 rows, 1.00 B in 0.002 sec., 611.34 rows/sec., 611.34 B/sec."
3897            ))
3898        );
3899        assert_eq!(ok_packet.session_state_info_ref(), None);
3900
3901        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3902            CapabilityFlags::CLIENT_SESSION_TRACK,
3903            &mut ParseBuf(SESS_STATE_SYS_VAR_OK),
3904        )
3905        .unwrap()
3906        .into();
3907        assert_eq!(ok_packet.affected_rows(), 0);
3908        assert_eq!(ok_packet.last_insert_id(), None);
3909        assert_eq!(
3910            ok_packet.status_flags(),
3911            StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3912        );
3913        assert_eq!(ok_packet.warnings(), 0);
3914        assert_eq!(ok_packet.info_ref(), None);
3915        let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3916
3917        match sess_state_info.decode().unwrap() {
3918            SessionStateChange::SystemVariables(mut vals) => {
3919                let val = vals.pop().unwrap();
3920                assert_eq!(val.name_bytes(), b"autocommit");
3921                assert_eq!(val.value_bytes(), b"OFF");
3922                assert!(vals.is_empty());
3923            }
3924            _ => panic!(),
3925        }
3926
3927        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3928            CapabilityFlags::CLIENT_SESSION_TRACK,
3929            &mut ParseBuf(SESS_STATE_SCHEMA_OK),
3930        )
3931        .unwrap()
3932        .into();
3933        assert_eq!(ok_packet.affected_rows(), 0);
3934        assert_eq!(ok_packet.last_insert_id(), None);
3935        assert_eq!(
3936            ok_packet.status_flags(),
3937            StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3938        );
3939        assert_eq!(ok_packet.warnings(), 0);
3940        assert_eq!(ok_packet.info_ref(), None);
3941        let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3942        match sess_state_info.decode().unwrap() {
3943            SessionStateChange::Schema(schema) => assert_eq!(schema.as_bytes(), b"test"),
3944            _ => panic!(),
3945        }
3946
3947        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3948            CapabilityFlags::CLIENT_SESSION_TRACK,
3949            &mut ParseBuf(SESS_STATE_TRACK_OK),
3950        )
3951        .unwrap()
3952        .into();
3953        assert_eq!(ok_packet.affected_rows(), 0);
3954        assert_eq!(ok_packet.last_insert_id(), None);
3955        assert_eq!(
3956            ok_packet.status_flags(),
3957            StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3958        );
3959        assert_eq!(ok_packet.warnings(), 0);
3960        assert_eq!(ok_packet.info_ref(), None);
3961        let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3962        assert_eq!(
3963            sess_state_info.decode().unwrap(),
3964            SessionStateChange::IsTracked(true),
3965        );
3966
3967        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<OldEofPacket>::deserialize(
3968            CapabilityFlags::CLIENT_SESSION_TRACK,
3969            &mut ParseBuf(EOF),
3970        )
3971        .unwrap()
3972        .into();
3973        assert_eq!(ok_packet.affected_rows(), 0);
3974        assert_eq!(ok_packet.last_insert_id(), None);
3975        assert_eq!(
3976            ok_packet.status_flags(),
3977            StatusFlags::SERVER_STATUS_AUTOCOMMIT
3978        );
3979        assert_eq!(ok_packet.warnings(), 0);
3980        assert_eq!(ok_packet.info_ref(), None);
3981        assert_eq!(ok_packet.session_state_info_ref(), None);
3982    }
3983
3984    #[test]
3985    fn should_build_handshake_response() {
3986        let flags_without_db_name = CapabilityFlags::from_bits_truncate(0x81aea205);
3987        let response = HandshakeResponse::new(
3988            Some(&[][..]),
3989            (5u16, 5, 5),
3990            Some(&b"root"[..]),
3991            None::<&'static [u8]>,
3992            Some(AuthPlugin::MysqlNativePassword),
3993            flags_without_db_name,
3994            None,
3995            1_u32.to_be(),
3996        );
3997        let mut actual = Vec::new();
3998        response.serialize(&mut actual);
3999
4000        let expected: Vec<u8> = [
4001            0x05, 0xa2, 0xae, 0x81, // client capabilities
4002            0x00, 0x00, 0x00, 0x01, // max packet
4003            0x2d, // charset
4004            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4005            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved
4006            0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
4007            0x00, // blank scramble
4008            0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4009            0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, // mysql_native_password
4010        ]
4011        .to_vec();
4012
4013        assert_eq!(expected, actual);
4014
4015        let flags_with_db_name = flags_without_db_name | CapabilityFlags::CLIENT_CONNECT_WITH_DB;
4016        let response = HandshakeResponse::new(
4017            Some(&[][..]),
4018            (5u16, 5, 5),
4019            Some(&b"root"[..]),
4020            Some(&b"mydb"[..]),
4021            Some(AuthPlugin::MysqlNativePassword),
4022            flags_with_db_name,
4023            None,
4024            1_u32.to_be(),
4025        );
4026        let mut actual = Vec::new();
4027        response.serialize(&mut actual);
4028
4029        let expected: Vec<u8> = [
4030            0x0d, 0xa2, 0xae, 0x81, // client capabilities
4031            0x00, 0x00, 0x00, 0x01, // max packet
4032            0x2d, // charset
4033            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4034            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved
4035            0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
4036            0x00, // blank scramble
4037            0x6d, 0x79, 0x64, 0x62, 0x00, // dbname
4038            0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4039            0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, // mysql_native_password
4040        ]
4041        .to_vec();
4042
4043        assert_eq!(expected, actual);
4044
4045        let response = HandshakeResponse::new(
4046            Some(&[][..]),
4047            (5u16, 5, 5),
4048            Some(&b"root"[..]),
4049            Some(&b"mydb"[..]),
4050            Some(AuthPlugin::MysqlNativePassword),
4051            flags_without_db_name,
4052            None,
4053            1_u32.to_be(),
4054        );
4055        let mut actual = Vec::new();
4056        response.serialize(&mut actual);
4057        assert_eq!(expected, actual);
4058
4059        let response = HandshakeResponse::new(
4060            Some(&[][..]),
4061            (5u16, 5, 5),
4062            Some(&b"root"[..]),
4063            Some(&[][..]),
4064            Some(AuthPlugin::MysqlNativePassword),
4065            flags_with_db_name,
4066            None,
4067            1_u32.to_be(),
4068        );
4069        let mut actual = Vec::new();
4070        response.serialize(&mut actual);
4071
4072        let expected: Vec<u8> = [
4073            0x0d, 0xa2, 0xae, 0x81, // client capabilities
4074            0x00, 0x00, 0x00, 0x01, // max packet
4075            0x2d, // charset
4076            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4077            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved
4078            0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
4079            0x00, // blank db_name
4080            0x00, // blank scramble
4081            0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4082            0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, // mysql_native_password
4083        ]
4084        .to_vec();
4085        assert_eq!(expected, actual);
4086    }
4087
4088    #[test]
4089    fn parse_str_to_sid() {
4090        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:23";
4091        let sid = input.parse::<Sid<'_>>().unwrap();
4092        let expected_sid = Uuid::parse_str("3E11FA47-71CA-11E1-9E33-C80AA9429562").unwrap();
4093        assert_eq!(sid.uuid, *expected_sid.as_bytes());
4094        assert_eq!(sid.intervals.len(), 1);
4095        assert_eq!(sid.intervals[0].start.0, 23);
4096        assert_eq!(sid.intervals[0].end.0, 24);
4097
4098        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15";
4099        let sid = input.parse::<Sid<'_>>().unwrap();
4100        assert_eq!(sid.uuid, *expected_sid.as_bytes());
4101        assert_eq!(sid.intervals.len(), 2);
4102        assert_eq!(sid.intervals[0].start.0, 1);
4103        assert_eq!(sid.intervals[0].end.0, 6);
4104        assert_eq!(sid.intervals[1].start.0, 10);
4105        assert_eq!(sid.intervals[1].end.0, 16);
4106
4107        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562";
4108        let e = input.parse::<Sid<'_>>().unwrap_err();
4109        assert_eq!(
4110            e.to_string(),
4111            "invalid sid format: 3E11FA47-71CA-11E1-9E33-C80AA9429562".to_string()
4112        );
4113
4114        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-";
4115        let e = input.parse::<Sid<'_>>().unwrap_err();
4116        assert_eq!(e.to_string(), "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-, error: cannot parse integer from empty string".to_string());
4117
4118        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa";
4119        let e = input.parse::<Sid<'_>>().unwrap_err();
4120        assert_eq!(e.to_string(), "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa, error: invalid digit found in string".to_string());
4121
4122        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:0-3";
4123        let e = input.parse::<Sid<'_>>().unwrap_err();
4124        assert_eq!(e.to_string(), "Gno can't be zero".to_string());
4125
4126        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:4-3";
4127        let e = input.parse::<Sid<'_>>().unwrap_err();
4128        assert_eq!(
4129            e.to_string(),
4130            "start(4) >= end(4) in GnoInterval".to_string()
4131        );
4132    }
4133
4134    #[test]
4135    fn should_parse_rsa_public_key_response_packet() {
4136        const PUBLIC_RSA_KEY_RESPONSE: &[u8] = b"\x01test";
4137
4138        let rsa_public_key_response =
4139            PublicKeyResponse::deserialize((), &mut ParseBuf(PUBLIC_RSA_KEY_RESPONSE));
4140
4141        assert!(rsa_public_key_response.is_ok());
4142        assert_eq!(rsa_public_key_response.unwrap().rsa_key(), "test");
4143    }
4144
4145    #[test]
4146    fn should_build_rsa_public_key_response_packet() {
4147        let rsa_public_key_response = PublicKeyResponse::new("test".as_bytes());
4148
4149        let mut actual = Vec::new();
4150        rsa_public_key_response.serialize(&mut actual);
4151
4152        let expected = b"\x01test".to_vec();
4153
4154        assert_eq!(expected, actual);
4155    }
4156}