mysql_common/packets/
mod.rs

1// Copyright (c) 2017 Anatoly Ikorsky
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use btoi::btoi;
10use bytes::BufMut;
11use regex::bytes::Regex;
12use uuid::Uuid;
13
14use std::str::FromStr;
15use std::sync::{Arc, LazyLock};
16use std::{
17    borrow::Cow, cmp::max, collections::HashMap, convert::TryFrom, fmt, io, marker::PhantomData,
18};
19
20use crate::collations::CollationId;
21use crate::scramble::create_response_for_ed25519;
22use crate::{
23    constants::{
24        CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, MAX_PAYLOAD_LEN,
25        MariadbCapabilities, SessionStateType, StatusFlags, StmtExecuteParamFlags,
26        StmtExecuteParamsFlags,
27    },
28    io::{BufMutExt, ParseBuf},
29    misc::{
30        lenenc_str_len,
31        raw::{
32            Const, Either, RawBytes, RawConst, RawInt, Skip,
33            bytes::{
34                BareBytes, ConstBytes, ConstBytesValue, EofBytes, LenEnc, NullBytes, U8Bytes,
35                U32Bytes,
36            },
37            int::{ConstU8, ConstU32, LeU16, LeU24, LeU32, LeU32LowerHalf, LeU32UpperHalf, LeU64},
38            seq::Seq,
39        },
40        unexpected_buf_eof,
41    },
42    proto::{MyDeserialize, MySerialize},
43    value::{ClientSide, SerializationSide, Value},
44};
45
46use self::session_state_change::SessionStateChange;
47
48static MARIADB_VERSION_RE: LazyLock<Regex> =
49    LazyLock::new(|| Regex::new(r"^(?:5.5.5-)?(\d{1,2})\.(\d{1,2})\.(\d{1,3})-MariaDB").unwrap());
50static VERSION_RE: LazyLock<Regex> =
51    LazyLock::new(|| Regex::new(r"^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)").unwrap());
52
/// Declares a single-byte packet header type together with its dedicated error type.
///
/// Two forms are accepted:
///
/// * `define_header!(Alias, Error("message"), BYTE)` — the header byte is the given literal;
/// * `define_header!(Alias, COMMAND, Error)` — the header byte is `Command::COMMAND as u8`
///   and the error message is built from the command name.
///
/// In both forms `Error` is a unit struct and `Alias` is a type alias for
/// `ConstU8<Error, BYTE>`.
macro_rules! define_header {
    ($alias:ident, $error:ident($message:literal), $byte:literal) => {
        #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($message)]
        pub struct $error;
        pub type $alias = crate::misc::raw::int::ConstU8<$error, $byte>;
    };
    ($alias:ident, $command:ident, $error:ident) => {
        #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, thiserror::Error)]
        #[error("Invalid header for {}", stringify!($command))]
        pub struct $error;
        pub type $alias = crate::misc::raw::int::ConstU8<$error, { Command::$command as u8 }>;
    };
}
67
/// Declares a constant-valued field type together with its dedicated error type.
///
/// `define_const!(Wrapper, Alias, Error("message"), VALUE)` expands to a unit struct
/// `Error` (carrying `message`) and a type alias `Alias = Wrapper<Error, VALUE>`.
macro_rules! define_const {
    ($wrapper:ident, $alias:ident, $error:ident($message:literal), $value:literal) => {
        #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($message)]
        pub struct $error;
        pub type $alias = $wrapper<$error, $value>;
    };
}
76
/// Declares a constant byte-sequence field type.
///
/// `define_const_bytes!(Marker, Alias, Error("message"), BYTES, LEN)` expands to:
///
/// * a unit struct `Error` carrying `message`;
/// * a unit struct `Marker` implementing `ConstBytesValue<LEN>` with `VALUE = BYTES`
///   and `Error` as its error type;
/// * a type alias `Alias = ConstBytes<Marker, LEN>`.
macro_rules! define_const_bytes {
    (
        $marker:ident,
        $alias:ident,
        $error:ident($message:literal),
        $bytes:expr_2021,
        $len:literal
    ) => {
        #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($message)]
        pub struct $error;

        #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
        pub struct $marker;

        impl ConstBytesValue<$len> for $marker {
            type Error = $error;
            const VALUE: [u8; $len] = $bytes;
        }

        pub type $alias = ConstBytes<$marker, $len>;
    };
}
94
95pub mod binlog_request;
96pub mod caching_sha2_password;
97pub mod session_state_change;
98
// Declares `ColumnDefinitionCatalog`: a 4-byte constant field whose expected value is
// `\x03def`. `Catalog` is the value marker type and `InvalidCatalog` the error type.
define_const_bytes!(
    Catalog,
    ColumnDefinitionCatalog,
    InvalidCatalog("Invalid catalog value in the column definition"),
    *b"\x03def",
    4
);
106
// Declares `FixedLengthFieldsLen`: a one-byte constant field (`ConstU8`) whose expected
// value is `0x0c`, with `InvalidFixedLengthFieldsLen` as its error type.
define_const!(
    ConstU8,
    FixedLengthFieldsLen,
    InvalidFixedLengthFieldsLen("Invalid fixed length field length in the column definition"),
    0x0c
);
113
114/// Dynamically-sized column metadata — a part of the [`Column`] packet.
115#[derive(Debug, Default, Clone, Eq, PartialEq)]
116struct ColumnMeta<'a> {
117    schema: RawBytes<'a, LenEnc>,
118    table: RawBytes<'a, LenEnc>,
119    org_table: RawBytes<'a, LenEnc>,
120    name: RawBytes<'a, LenEnc>,
121    org_name: RawBytes<'a, LenEnc>,
122}
123
124impl ColumnMeta<'_> {
125    pub fn into_owned(self) -> ColumnMeta<'static> {
126        ColumnMeta {
127            schema: self.schema.into_owned(),
128            table: self.table.into_owned(),
129            org_table: self.org_table.into_owned(),
130            name: self.name.into_owned(),
131            org_name: self.org_name.into_owned(),
132        }
133    }
134
135    /// Returns the value of the [`ColumnMeta::schema`] field as a byte slice.
136    pub fn schema_ref(&self) -> &[u8] {
137        self.schema.as_bytes()
138    }
139
140    /// Returns the value of the [`ColumnMeta::schema`] field as a string (lossy converted).
141    pub fn schema_str(&self) -> Cow<'_, str> {
142        String::from_utf8_lossy(self.schema_ref())
143    }
144
145    /// Returns the value of the [`ColumnMeta::table`] field as a byte slice.
146    pub fn table_ref(&self) -> &[u8] {
147        self.table.as_bytes()
148    }
149
150    /// Returns the value of the [`ColumnMeta::table`] field as a string (lossy converted).
151    pub fn table_str(&self) -> Cow<'_, str> {
152        String::from_utf8_lossy(self.table_ref())
153    }
154
155    /// Returns the value of the [`ColumnMeta::org_table`] field as a byte slice.
156    ///
157    /// "org_table" is for original table name.
158    pub fn org_table_ref(&self) -> &[u8] {
159        self.org_table.as_bytes()
160    }
161
162    /// Returns the value of the [`ColumnMeta::org_table`] field as a string (lossy converted).
163    pub fn org_table_str(&self) -> Cow<'_, str> {
164        String::from_utf8_lossy(self.org_table_ref())
165    }
166
167    /// Returns the value of the [`ColumnMeta::name`] field as a byte slice.
168    pub fn name_ref(&self) -> &[u8] {
169        self.name.as_bytes()
170    }
171
172    /// Returns the value of the [`ColumnMeta::name`] field as a string (lossy converted).
173    pub fn name_str(&self) -> Cow<'_, str> {
174        String::from_utf8_lossy(self.name_ref())
175    }
176
177    /// Returns the value of the [`ColumnMeta::org_name`] field as a byte slice.
178    ///
179    /// "org_name" is for original column name.
180    pub fn org_name_ref(&self) -> &[u8] {
181        self.org_name.as_bytes()
182    }
183
184    /// Returns value of the [`ColumnMeta::org_name`] field as a string (lossy converted).
185    pub fn org_name_str(&self) -> Cow<'_, str> {
186        String::from_utf8_lossy(self.org_name_ref())
187    }
188}
189
190impl<'de> MyDeserialize<'de> for ColumnMeta<'de> {
191    const SIZE: Option<usize> = None;
192    type Ctx = ();
193
194    fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
195        Ok(Self {
196            schema: buf.parse_unchecked(())?,
197            table: buf.parse_unchecked(())?,
198            org_table: buf.parse_unchecked(())?,
199            name: buf.parse_unchecked(())?,
200            org_name: buf.parse_unchecked(())?,
201        })
202    }
203}
204
205impl MySerialize for ColumnMeta<'_> {
206    fn serialize(&self, buf: &mut Vec<u8>) {
207        self.schema.serialize(&mut *buf);
208        self.table.serialize(&mut *buf);
209        self.org_table.serialize(&mut *buf);
210        self.name.serialize(&mut *buf);
211        self.org_name.serialize(&mut *buf);
212    }
213}
214
215/// Represents MySql Column (column packet).
216#[derive(Debug, Clone, Eq, PartialEq)]
217pub struct Column {
218    catalog: ColumnDefinitionCatalog,
219    meta: Arc<ColumnMeta<'static>>,
220    fixed_length_fields_len: FixedLengthFieldsLen,
221    column_length: RawInt<LeU32>,
222    character_set: RawInt<LeU16>,
223    column_type: Const<ColumnType, u8>,
224    flags: Const<ColumnFlags, LeU16>,
225    decimals: RawInt<u8>,
226    __filler: Skip<2>,
227    // COM_FIELD_LIST is deprecated, so we won't support it
228}
229
230impl<'de> MyDeserialize<'de> for Column {
231    const SIZE: Option<usize> = None;
232    type Ctx = ();
233
234    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
235        let catalog = buf.parse(())?;
236        let meta = Arc::new(buf.parse::<ColumnMeta<'_>>(())?.into_owned());
237        let mut buf: ParseBuf<'_> = buf.parse(13)?;
238
239        Ok(Column {
240            catalog,
241            meta,
242            fixed_length_fields_len: buf.parse_unchecked(())?,
243            character_set: buf.parse_unchecked(())?,
244            column_length: buf.parse_unchecked(())?,
245            column_type: buf.parse_unchecked(())?,
246            flags: buf.parse_unchecked(())?,
247            decimals: buf.parse_unchecked(())?,
248            __filler: buf.parse_unchecked(())?,
249        })
250    }
251}
252
253impl MySerialize for Column {
254    fn serialize(&self, buf: &mut Vec<u8>) {
255        self.catalog.serialize(&mut *buf);
256        self.meta.serialize(&mut *buf);
257        self.fixed_length_fields_len.serialize(&mut *buf);
258        self.column_length.serialize(&mut *buf);
259        self.character_set.serialize(&mut *buf);
260        self.column_type.serialize(&mut *buf);
261        self.flags.serialize(&mut *buf);
262        self.decimals.serialize(&mut *buf);
263        self.__filler.serialize(&mut *buf);
264    }
265}
266
267impl Column {
268    pub fn new(column_type: ColumnType) -> Self {
269        Self {
270            catalog: Default::default(),
271            meta: Default::default(),
272            fixed_length_fields_len: Default::default(),
273            column_length: Default::default(),
274            character_set: Default::default(),
275            flags: Default::default(),
276            column_type: Const::new(column_type),
277            decimals: Default::default(),
278            __filler: Skip,
279        }
280    }
281
282    pub fn with_schema(mut self, schema: &[u8]) -> Self {
283        Arc::make_mut(&mut self.meta).schema = RawBytes::new(schema).into_owned();
284        self
285    }
286
287    pub fn with_table(mut self, table: &[u8]) -> Self {
288        Arc::make_mut(&mut self.meta).table = RawBytes::new(table).into_owned();
289        self
290    }
291
292    pub fn with_org_table(mut self, org_table: &[u8]) -> Self {
293        Arc::make_mut(&mut self.meta).org_table = RawBytes::new(org_table).into_owned();
294        self
295    }
296
297    pub fn with_name(mut self, name: &[u8]) -> Self {
298        Arc::make_mut(&mut self.meta).name = RawBytes::new(name).into_owned();
299        self
300    }
301
302    pub fn with_org_name(mut self, org_name: &[u8]) -> Self {
303        Arc::make_mut(&mut self.meta).org_name = RawBytes::new(org_name).into_owned();
304        self
305    }
306
307    pub fn with_flags(mut self, flags: ColumnFlags) -> Self {
308        self.flags = Const::new(flags);
309        self
310    }
311
312    pub fn with_column_length(mut self, column_length: u32) -> Self {
313        self.column_length = RawInt::new(column_length);
314        self
315    }
316
317    pub fn with_character_set(mut self, character_set: u16) -> Self {
318        self.character_set = RawInt::new(character_set);
319        self
320    }
321
322    pub fn with_decimals(mut self, decimals: u8) -> Self {
323        self.decimals = RawInt::new(decimals);
324        self
325    }
326
327    /// Returns value of the column_length field of a column packet.
328    ///
329    /// Can be used for text-output formatting.
330    pub fn column_length(&self) -> u32 {
331        *self.column_length
332    }
333
334    /// Returns value of the column_type field of a column packet.
335    pub fn column_type(&self) -> ColumnType {
336        *self.column_type
337    }
338
339    /// Returns value of the character_set field of a column packet.
340    pub fn character_set(&self) -> u16 {
341        *self.character_set
342    }
343
344    /// Returns value of the flags field of a column packet.
345    pub fn flags(&self) -> ColumnFlags {
346        *self.flags
347    }
348
349    /// Returns value of the decimals field of a column packet.
350    ///
351    /// Max shown decimal digits. Can be used for text-output formatting
352    ///
353    /// *   `0x00` for integers and static strings
354    /// *   `0x1f` for dynamic strings, double, float
355    /// *   `0x00..=0x51` for decimals
356    pub fn decimals(&self) -> u8 {
357        *self.decimals
358    }
359
360    /// Returns value of the schema field of a column packet as a byte slice.
361    #[inline(always)]
362    pub fn schema_ref(&self) -> &[u8] {
363        self.meta.schema_ref()
364    }
365
366    /// Returns value of the schema field of a column packet as a string (lossy converted).
367    #[inline(always)]
368    pub fn schema_str(&self) -> Cow<'_, str> {
369        self.meta.schema_str()
370    }
371
372    /// Returns value of the table field of a column packet as a byte slice.
373    #[inline(always)]
374    pub fn table_ref(&self) -> &[u8] {
375        self.meta.table_ref()
376    }
377
378    /// Returns value of the table field of a column packet as a string (lossy converted).
379    #[inline(always)]
380    pub fn table_str(&self) -> Cow<'_, str> {
381        self.meta.table_str()
382    }
383
384    /// Returns value of the org_table field of a column packet as a byte slice.
385    ///
386    /// "org_table" is for original table name.
387    #[inline(always)]
388    pub fn org_table_ref(&self) -> &[u8] {
389        self.meta.org_table_ref()
390    }
391
392    /// Returns value of the org_table field of a column packet as a string (lossy converted).
393    #[inline(always)]
394    pub fn org_table_str(&self) -> Cow<'_, str> {
395        self.meta.org_table_str()
396    }
397
398    /// Returns value of the name field of a column packet as a byte slice.
399    #[inline(always)]
400    pub fn name_ref(&self) -> &[u8] {
401        self.meta.name_ref()
402    }
403
404    /// Returns value of the name field of a column packet as a string (lossy converted).
405    #[inline(always)]
406    pub fn name_str(&self) -> Cow<'_, str> {
407        self.meta.name_str()
408    }
409
410    /// Returns value of the org_name field of a column packet as a byte slice.
411    ///
412    /// "org_name" is for original column name.
413    #[inline(always)]
414    pub fn org_name_ref(&self) -> &[u8] {
415        self.meta.org_name_ref()
416    }
417
418    /// Returns value of the org_name field of a column packet as a string (lossy converted).
419    #[inline(always)]
420    pub fn org_name_str(&self) -> Cow<'_, str> {
421        self.meta.org_name_str()
422    }
423}
424
425/// Represents change in session state (part of MySql's Ok packet).
426#[derive(Debug, Clone, Eq, PartialEq)]
427pub struct SessionStateInfo<'a> {
428    data_type: Const<SessionStateType, u8>,
429    data: RawBytes<'a, LenEnc>,
430}
431
432impl SessionStateInfo<'_> {
433    pub fn into_owned(self) -> SessionStateInfo<'static> {
434        let SessionStateInfo { data_type, data } = self;
435        SessionStateInfo {
436            data_type,
437            data: data.into_owned(),
438        }
439    }
440
441    pub fn data_type(&self) -> SessionStateType {
442        *self.data_type
443    }
444
445    /// Returns raw session state info data.
446    pub fn data_ref(&self) -> &[u8] {
447        self.data.as_bytes()
448    }
449
450    /// Tries to decode session state info data.
451    pub fn decode(&self) -> io::Result<SessionStateChange<'_>> {
452        ParseBuf(self.data.as_bytes()).parse_unchecked(*self.data_type)
453    }
454}
455
456impl<'de> MyDeserialize<'de> for SessionStateInfo<'de> {
457    const SIZE: Option<usize> = None;
458    type Ctx = ();
459
460    fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
461        Ok(SessionStateInfo {
462            data_type: buf.parse(())?,
463            data: buf.parse(())?,
464        })
465    }
466}
467
468impl MySerialize for SessionStateInfo<'_> {
469    fn serialize(&self, buf: &mut Vec<u8>) {
470        self.data_type.serialize(&mut *buf);
471        self.data.serialize(buf);
472    }
473}
474
475/// Represents MySql's Ok packet.
476#[derive(Debug, Clone, Eq, PartialEq)]
477pub struct OkPacketBody<'a> {
478    affected_rows: RawInt<LenEnc>,
479    last_insert_id: RawInt<LenEnc>,
480    status_flags: Const<StatusFlags, LeU16>,
481    warnings: RawInt<LeU16>,
482    info: RawBytes<'a, LenEnc>,
483    session_state_info: RawBytes<'a, LenEnc>,
484}
485
/// OK packet kind (see _OK packet identifier_ section of [WL#7766][1]).
///
/// Each implementor fixes the expected header byte and knows how to parse the bytes
/// that follow it.
///
/// [1]: https://dev.mysql.com/worklog/task/?id=7766
pub trait OkPacketKind {
    /// Header byte that identifies this kind of packet.
    const HEADER: u8;

    /// Parses the packet body from `buf`.
    ///
    /// `buf` is expected to be positioned right after the header byte; `capabilities`
    /// may affect which optional fields are read.
    fn parse_body<'de>(
        capabilities: CapabilityFlags,
        buf: &mut ParseBuf<'de>,
    ) -> io::Result<OkPacketBody<'de>>;
}
497
498/// Ok packet that terminates a result set (text or binary).
499#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
500pub struct ResultSetTerminator;
501
502impl OkPacketKind for ResultSetTerminator {
503    const HEADER: u8 = 0xFE;
504
505    fn parse_body<'de>(
506        capabilities: CapabilityFlags,
507        buf: &mut ParseBuf<'de>,
508    ) -> io::Result<OkPacketBody<'de>> {
509        // We need to skip affected_rows and insert_id here
510        // because valid content of EOF packet includes
511        // packet marker, server status and warning count only.
512        // (see `read_ok_ex` in sql-common/client.cc)
513        buf.parse::<RawInt<LenEnc>>(())?;
514        buf.parse::<RawInt<LenEnc>>(())?;
515
516        // assume CLIENT_PROTOCOL_41 flag
517        let mut sbuf: ParseBuf<'_> = buf.parse(4)?;
518        let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
519        let warnings = sbuf.parse_unchecked(())?;
520
521        let (info, session_state_info) =
522            if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
523                let info = buf.parse(())?;
524                let session_state_info =
525                    if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
526                        buf.parse(())?
527                    } else {
528                        RawBytes::default()
529                    };
530                (info, session_state_info)
531            } else if !buf.is_empty() && buf.0[0] > 0 {
532                // The `info` field is a `string<EOF>` according to the MySQL Internals
533                // Manual, but actually it's a `string<lenenc>`.
534                // SEE: sql/protocol_classics.cc `net_send_ok`
535                let info = buf.parse(())?;
536                (info, RawBytes::default())
537            } else {
538                (RawBytes::default(), RawBytes::default())
539            };
540
541        Ok(OkPacketBody {
542            affected_rows: RawInt::new(0),
543            last_insert_id: RawInt::new(0),
544            status_flags,
545            warnings,
546            info,
547            session_state_info,
548        })
549    }
550}
551
552/// Old deprecated EOF packet.
553#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
554pub struct OldEofPacket;
555
556impl OkPacketKind for OldEofPacket {
557    const HEADER: u8 = 0xFE;
558
559    fn parse_body<'de>(
560        _: CapabilityFlags,
561        buf: &mut ParseBuf<'de>,
562    ) -> io::Result<OkPacketBody<'de>> {
563        // We assume that CLIENT_PROTOCOL_41 was set
564        let mut buf: ParseBuf<'_> = buf.parse(4)?;
565        let warnings = buf.parse_unchecked(())?;
566        let status_flags = buf.parse_unchecked(())?;
567
568        Ok(OkPacketBody {
569            affected_rows: RawInt::new(0),
570            last_insert_id: RawInt::new(0),
571            status_flags,
572            warnings,
573            info: RawBytes::new(&[][..]),
574            session_state_info: RawBytes::new(&[][..]),
575        })
576    }
577}
578
579/// This packet terminates a binlog network stream.
580#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
581pub struct NetworkStreamTerminator;
582
583impl OkPacketKind for NetworkStreamTerminator {
584    const HEADER: u8 = 0xFE;
585
586    fn parse_body<'de>(
587        flags: CapabilityFlags,
588        buf: &mut ParseBuf<'de>,
589    ) -> io::Result<OkPacketBody<'de>> {
590        OldEofPacket::parse_body(flags, buf)
591    }
592}
593
594/// Ok packet that is not a result set terminator.
595#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
596pub struct CommonOkPacket;
597
598impl OkPacketKind for CommonOkPacket {
599    const HEADER: u8 = 0x00;
600
601    fn parse_body<'de>(
602        capabilities: CapabilityFlags,
603        buf: &mut ParseBuf<'de>,
604    ) -> io::Result<OkPacketBody<'de>> {
605        let affected_rows = buf.parse(())?;
606        let last_insert_id = buf.parse(())?;
607
608        // We assume that CLIENT_PROTOCOL_41 was set
609        let mut sbuf: ParseBuf<'_> = buf.parse(4)?;
610        let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
611        let warnings = sbuf.parse_unchecked(())?;
612
613        let (info, session_state_info) =
614            if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
615                let info = buf.parse(())?;
616                let session_state_info =
617                    if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
618                        buf.parse(())?
619                    } else {
620                        RawBytes::default()
621                    };
622                (info, session_state_info)
623            } else if !buf.is_empty() && buf.0[0] > 0 {
624                // The `info` field is a `string<EOF>` according to the MySQL Internals
625                // Manual, but actually it's a `string<lenenc>`.
626                // SEE: sql/protocol_classics.cc `net_send_ok`
627                let info = buf.parse(())?;
628                (info, RawBytes::default())
629            } else {
630                (RawBytes::default(), RawBytes::default())
631            };
632
633        Ok(OkPacketBody {
634            affected_rows,
635            last_insert_id,
636            status_flags,
637            warnings,
638            info,
639            session_state_info,
640        })
641    }
642}
643
644impl<'a> TryFrom<OkPacketBody<'a>> for OkPacket<'a> {
645    type Error = io::Error;
646
647    fn try_from(body: OkPacketBody<'a>) -> io::Result<Self> {
648        Ok(OkPacket {
649            affected_rows: *body.affected_rows,
650            last_insert_id: if *body.last_insert_id == 0 {
651                None
652            } else {
653                Some(*body.last_insert_id)
654            },
655            status_flags: *body.status_flags,
656            warnings: *body.warnings,
657            info: if !body.info.is_empty() {
658                Some(body.info)
659            } else {
660                None
661            },
662            session_state_info: if !body.session_state_info.is_empty() {
663                Some(body.session_state_info)
664            } else {
665                None
666            },
667        })
668    }
669}
670
671/// Represents MySql's Ok packet.
672#[derive(Debug, Clone, Eq, PartialEq)]
673pub struct OkPacket<'a> {
674    affected_rows: u64,
675    last_insert_id: Option<u64>,
676    status_flags: StatusFlags,
677    warnings: u16,
678    info: Option<RawBytes<'a, LenEnc>>,
679    session_state_info: Option<RawBytes<'a, LenEnc>>,
680}
681
682impl OkPacket<'_> {
683    pub fn into_owned(self) -> OkPacket<'static> {
684        OkPacket {
685            affected_rows: self.affected_rows,
686            last_insert_id: self.last_insert_id,
687            status_flags: self.status_flags,
688            warnings: self.warnings,
689            info: self.info.map(|x| x.into_owned()),
690            session_state_info: self.session_state_info.map(|x| x.into_owned()),
691        }
692    }
693
694    /// Value of the affected_rows field of an Ok packet.
695    pub fn affected_rows(&self) -> u64 {
696        self.affected_rows
697    }
698
699    /// Value of the last_insert_id field of an Ok packet.
700    pub fn last_insert_id(&self) -> Option<u64> {
701        self.last_insert_id
702    }
703
704    /// Value of the status_flags field of an Ok packet.
705    pub fn status_flags(&self) -> StatusFlags {
706        self.status_flags
707    }
708
709    /// Value of the warnings field of an Ok packet.
710    pub fn warnings(&self) -> u16 {
711        self.warnings
712    }
713
714    /// Value of the info field of an Ok packet as a byte slice.
715    pub fn info_ref(&self) -> Option<&[u8]> {
716        self.info.as_ref().map(|x| x.as_bytes())
717    }
718
719    /// Value of the info field of an Ok packet as a string (lossy converted).
720    pub fn info_str(&self) -> Option<Cow<'_, str>> {
721        self.info.as_ref().map(|x| x.as_str())
722    }
723
724    /// Returns raw reference to a session state info.
725    pub fn session_state_info_ref(&self) -> Option<&[u8]> {
726        self.session_state_info.as_ref().map(|x| x.as_bytes())
727    }
728
729    /// Tries to parse session state info, if any.
730    pub fn session_state_info(&self) -> io::Result<Vec<SessionStateInfo<'_>>> {
731        self.session_state_info_ref()
732            .map(|data| {
733                let mut data = ParseBuf(data);
734                let mut entries = Vec::new();
735                while !data.is_empty() {
736                    entries.push(data.parse(())?);
737                }
738                Ok(entries)
739            })
740            .transpose()
741            .map(|x| x.unwrap_or_default())
742    }
743}
744
745#[derive(Debug, Clone, PartialEq, Eq)]
746pub struct OkPacketDeserializer<'de, T>(OkPacket<'de>, PhantomData<T>);
747
748impl<'de, T> OkPacketDeserializer<'de, T> {
749    pub fn into_inner(self) -> OkPacket<'de> {
750        self.0
751    }
752}
753
754impl<'de, T> From<OkPacketDeserializer<'de, T>> for OkPacket<'de> {
755    fn from(x: OkPacketDeserializer<'de, T>) -> Self {
756        x.0
757    }
758}
759
760#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
761#[error("Invalid OK packet header")]
762pub struct InvalidOkPacketHeader;
763
764impl<'de, T: OkPacketKind> MyDeserialize<'de> for OkPacketDeserializer<'de, T> {
765    const SIZE: Option<usize> = None;
766    type Ctx = CapabilityFlags;
767
768    fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
769        if *buf.parse::<RawInt<u8>>(())? == T::HEADER {
770            let body = T::parse_body(capabilities, buf)?;
771            let ok = OkPacket::try_from(body)?;
772            Ok(Self(ok, PhantomData))
773        } else {
774            Err(io::Error::new(
775                io::ErrorKind::InvalidData,
776                InvalidOkPacketHeader,
777            ))
778        }
779    }
780}
781
782/// Progress report information (may be in an error packet of MariaDB server).
783#[derive(Debug, Clone, Eq, PartialEq)]
784pub struct ProgressReport<'a> {
785    stage: RawInt<u8>,
786    max_stage: RawInt<u8>,
787    progress: RawInt<LeU24>,
788    stage_info: RawBytes<'a, LenEnc>,
789}
790
791impl<'a> ProgressReport<'a> {
792    pub fn new(
793        stage: u8,
794        max_stage: u8,
795        progress: u32,
796        stage_info: impl Into<Cow<'a, [u8]>>,
797    ) -> ProgressReport<'a> {
798        ProgressReport {
799            stage: RawInt::new(stage),
800            max_stage: RawInt::new(max_stage),
801            progress: RawInt::new(progress),
802            stage_info: RawBytes::new(stage_info),
803        }
804    }
805
806    /// 1 to max_stage
807    pub fn stage(&self) -> u8 {
808        *self.stage
809    }
810
811    pub fn max_stage(&self) -> u8 {
812        *self.max_stage
813    }
814
815    /// Progress as '% * 1000'
816    pub fn progress(&self) -> u32 {
817        *self.progress
818    }
819
820    /// Status or state name as a byte slice.
821    pub fn stage_info_ref(&self) -> &[u8] {
822        self.stage_info.as_bytes()
823    }
824
825    /// Status or state name as a string (lossy converted).
826    pub fn stage_info_str(&self) -> Cow<'_, str> {
827        self.stage_info.as_str()
828    }
829
830    pub fn into_owned(self) -> ProgressReport<'static> {
831        ProgressReport {
832            stage: self.stage,
833            max_stage: self.max_stage,
834            progress: self.progress,
835            stage_info: self.stage_info.into_owned(),
836        }
837    }
838}
839
840impl<'de> MyDeserialize<'de> for ProgressReport<'de> {
841    const SIZE: Option<usize> = None;
842    type Ctx = ();
843
844    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
845        let mut sbuf: ParseBuf<'_> = buf.parse(6)?;
846
847        sbuf.skip(1); // Ignore number of strings.
848
849        Ok(ProgressReport {
850            stage: sbuf.parse_unchecked(())?,
851            max_stage: sbuf.parse_unchecked(())?,
852            progress: sbuf.parse_unchecked(())?,
853            stage_info: buf.parse(())?,
854        })
855    }
856}
857
858impl MySerialize for ProgressReport<'_> {
859    fn serialize(&self, buf: &mut Vec<u8>) {
860        buf.put_u8(1);
861        self.stage.serialize(&mut *buf);
862        self.max_stage.serialize(&mut *buf);
863        self.progress.serialize(&mut *buf);
864        self.stage_info.serialize(buf);
865    }
866}
867
868impl fmt::Display for ProgressReport<'_> {
869    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
870        write!(
871            f,
872            "Stage: {} of {} '{}'  {:.2}% of stage done",
873            self.stage(),
874            self.max_stage(),
875            self.progress(),
876            self.stage_info_str()
877        )
878    }
879}
880
// Declares `ErrPacketHeader`: the constant header byte `0xFF` that starts an error
// packet, with `InvalidErrPacketHeader` as its error type.
define_header!(
    ErrPacketHeader,
    InvalidErrPacketHeader("Invalid error packet header"),
    0xFF
);
886
887/// MySql error packet.
888///
889/// May hold an error or a progress report.
890#[derive(Debug, Clone, PartialEq)]
891pub enum ErrPacket<'a> {
892    Error(ServerError<'a>),
893    Progress(ProgressReport<'a>),
894}
895
896impl ErrPacket<'_> {
897    /// Returns false if this error packet contains progress report.
898    pub fn is_error(&self) -> bool {
899        matches!(self, ErrPacket::Error { .. })
900    }
901
902    /// Returns true if this error packet contains progress report.
903    pub fn is_progress_report(&self) -> bool {
904        !self.is_error()
905    }
906
907    /// Will panic if ErrPacket does not contains progress report
908    pub fn progress_report(&self) -> &ProgressReport<'_> {
909        match *self {
910            ErrPacket::Progress(ref progress_report) => progress_report,
911            _ => panic!("This ErrPacket does not contains progress report"),
912        }
913    }
914
915    /// Will panic if ErrPacket does not contains a `ServerError`.
916    pub fn server_error(&self) -> &ServerError<'_> {
917        match self {
918            ErrPacket::Error(error) => error,
919            ErrPacket::Progress(_) => panic!("This ErrPacket does not contain a ServerError"),
920        }
921    }
922}
923
924impl<'de> MyDeserialize<'de> for ErrPacket<'de> {
925    const SIZE: Option<usize> = None;
926    type Ctx = CapabilityFlags;
927
928    fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
929        let mut sbuf: ParseBuf<'_> = buf.parse(3)?;
930        sbuf.parse_unchecked::<ErrPacketHeader>(())?;
931        let code: RawInt<LeU16> = sbuf.parse_unchecked(())?;
932
933        if *code == 0xFFFF && capabilities.contains(CapabilityFlags::CLIENT_PROGRESS_OBSOLETE) {
934            buf.parse(()).map(ErrPacket::Progress)
935        } else {
936            buf.parse((
937                *code,
938                capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41),
939            ))
940            .map(ErrPacket::Error)
941        }
942    }
943}
944
945impl MySerialize for ErrPacket<'_> {
946    fn serialize(&self, buf: &mut Vec<u8>) {
947        ErrPacketHeader::new().serialize(&mut *buf);
948        match self {
949            ErrPacket::Error(server_error) => {
950                server_error.code.serialize(&mut *buf);
951                server_error.serialize(buf);
952            }
953            ErrPacket::Progress(progress_report) => {
954                RawInt::<LeU16>::new(0xFFFF).serialize(&mut *buf);
955                progress_report.serialize(buf);
956            }
957        }
958    }
959}
960
961impl fmt::Display for ErrPacket<'_> {
962    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
963        match self {
964            ErrPacket::Error(server_error) => write!(f, "{}", server_error),
965            ErrPacket::Progress(progress_report) => write!(f, "{}", progress_report),
966        }
967    }
968}
969
// Declares `SqlStateMarker`: the `#` byte that precedes an SQL state.
// `InvalidSqlStateMarker` is the error associated with a mismatching value
// (the exact expansion is defined by `define_header!`).
define_header!(
    SqlStateMarker,
    InvalidSqlStateMarker("Invalid SqlStateMarker value"),
    b'#'
);
975
/// MySql error state.
///
/// Serialized as the `#` marker followed by the five state bytes (six bytes total).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct SqlState {
    /// The `#` marker written/read right before the state bytes.
    __state_marker: SqlStateMarker,
    /// The five raw bytes of the state.
    state: [u8; 5],
}
982
983impl SqlState {
984    /// Creates new sql state.
985    pub fn new(state: [u8; 5]) -> Self {
986        Self {
987            __state_marker: SqlStateMarker::new(),
988            state,
989        }
990    }
991
992    /// Returns an sql state as bytes.
993    pub fn as_bytes(&self) -> [u8; 5] {
994        self.state
995    }
996
997    /// Returns an sql state as a string (lossy converted).
998    pub fn as_str(&self) -> Cow<'_, str> {
999        String::from_utf8_lossy(&self.state)
1000    }
1001}
1002
1003impl<'de> MyDeserialize<'de> for SqlState {
1004    const SIZE: Option<usize> = Some(6);
1005    type Ctx = ();
1006
1007    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1008        Ok(Self {
1009            __state_marker: buf.parse(())?,
1010            state: buf.parse(())?,
1011        })
1012    }
1013}
1014
1015impl MySerialize for SqlState {
1016    fn serialize(&self, buf: &mut Vec<u8>) {
1017        self.__state_marker.serialize(buf);
1018        self.state.serialize(buf);
1019    }
1020}
1021
/// MySql server error.
///
/// The error payload of an [`ErrPacket`]: an error code, an optional SQL state
/// and an error message.
#[derive(Debug, Clone, PartialEq)]
pub struct ServerError<'a> {
    /// Error code.
    code: RawInt<LeU16>,
    /// SQL state; it is only parsed when `CLIENT_PROTOCOL_41` was negotiated.
    state: Option<SqlState>,
    /// Error message bytes.
    message: RawBytes<'a, EofBytes>,
}
1031
1032impl<'a> ServerError<'a> {
1033    pub fn new(code: u16, state: Option<SqlState>, msg: impl Into<Cow<'a, [u8]>>) -> Self {
1034        Self {
1035            code: RawInt::new(code),
1036            state,
1037            message: RawBytes::new(msg),
1038        }
1039    }
1040
1041    /// Returns an error code.
1042    pub fn error_code(&self) -> u16 {
1043        *self.code
1044    }
1045
1046    /// Returns an sql state.
1047    pub fn sql_state_ref(&self) -> Option<&SqlState> {
1048        self.state.as_ref()
1049    }
1050
1051    /// Returns an error message.
1052    pub fn message_ref(&self) -> &[u8] {
1053        self.message.as_bytes()
1054    }
1055
1056    /// Returns an error message as a string (lossy converted).
1057    pub fn message_str(&self) -> Cow<'_, str> {
1058        self.message.as_str()
1059    }
1060
1061    pub fn into_owned(self) -> ServerError<'static> {
1062        ServerError {
1063            code: self.code,
1064            state: self.state,
1065            message: self.message.into_owned(),
1066        }
1067    }
1068}
1069
1070impl<'de> MyDeserialize<'de> for ServerError<'de> {
1071    const SIZE: Option<usize> = None;
1072    /// An error packet error code + whether CLIENT_PROTOCOL_41 capability was negotiated.
1073    type Ctx = (u16, bool);
1074
1075    fn deserialize((code, protocol_41): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1076        let server_error = if protocol_41 {
1077            ServerError {
1078                code: RawInt::new(code),
1079                state: Some(buf.parse(())?),
1080                message: buf.parse(())?,
1081            }
1082        } else {
1083            ServerError {
1084                code: RawInt::new(code),
1085                state: None,
1086                message: buf.parse(())?,
1087            }
1088        };
1089        Ok(server_error)
1090    }
1091}
1092
1093impl MySerialize for ServerError<'_> {
1094    fn serialize(&self, buf: &mut Vec<u8>) {
1095        if let Some(state) = &self.state {
1096            state.serialize(buf);
1097        }
1098        self.message.serialize(buf);
1099    }
1100}
1101
1102impl fmt::Display for ServerError<'_> {
1103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1104        let sql_state_str = self
1105            .sql_state_ref()
1106            .map(|s| format!(" ({})", s.as_str()))
1107            .unwrap_or_default();
1108
1109        write!(
1110            f,
1111            "ERROR {}{}: {}",
1112            self.error_code(),
1113            sql_state_str,
1114            self.message_str()
1115        )
1116    }
1117}
1118
// Declares `LocalInfileHeader`: the header value `0xFB` that introduces a local infile packet.
// `InvalidLocalInfileHeader` is the error associated with a mismatching value
// (the exact expansion is defined by `define_header!`).
define_header!(
    LocalInfileHeader,
    InvalidLocalInfileHeader("Invalid LOCAL_INFILE header"),
    0xFB
);
1124
/// Represents MySql's local infile packet.
///
/// Consists of the local infile header followed by a file name.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct LocalInfilePacket<'a> {
    /// Packet header marker.
    __header: LocalInfileHeader,
    /// File name bytes.
    file_name: RawBytes<'a, EofBytes>,
}
1131
1132impl<'a> LocalInfilePacket<'a> {
1133    pub fn new(file_name: impl Into<Cow<'a, [u8]>>) -> Self {
1134        Self {
1135            __header: LocalInfileHeader::new(),
1136            file_name: RawBytes::new(file_name),
1137        }
1138    }
1139
1140    /// Value of the file_name field of a local infile packet as a byte slice.
1141    pub fn file_name_ref(&self) -> &[u8] {
1142        self.file_name.as_bytes()
1143    }
1144
1145    /// Value of the file_name field of a local infile packet as a string (lossy converted).
1146    pub fn file_name_str(&self) -> Cow<'_, str> {
1147        self.file_name.as_str()
1148    }
1149
1150    pub fn into_owned(self) -> LocalInfilePacket<'static> {
1151        LocalInfilePacket {
1152            __header: self.__header,
1153            file_name: self.file_name.into_owned(),
1154        }
1155    }
1156}
1157
1158impl<'de> MyDeserialize<'de> for LocalInfilePacket<'de> {
1159    const SIZE: Option<usize> = None;
1160    type Ctx = ();
1161
1162    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1163        Ok(LocalInfilePacket {
1164            __header: buf.parse(())?,
1165            file_name: buf.parse(())?,
1166        })
1167    }
1168}
1169
1170impl MySerialize for LocalInfilePacket<'_> {
1171    fn serialize(&self, buf: &mut Vec<u8>) {
1172        self.__header.serialize(buf);
1173        self.file_name.serialize(buf);
1174    }
1175}
1176
// Names of the authentication plugins known to this module
// (without a terminating null byte). They map to the unit variants of `AuthPlugin`.
const MYSQL_OLD_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_old_password";
const MYSQL_NATIVE_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_native_password";
const CACHING_SHA2_PASSWORD_PLUGIN_NAME: &[u8] = b"caching_sha2_password";
const MYSQL_CLEAR_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_clear_password";
const ED25519_PLUGIN_NAME: &[u8] = b"client_ed25519";
1182
/// Authentication data produced for a particular authentication plugin.
///
/// Dereferences to the raw bytes of the contained data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthPluginData<'a> {
    /// Auth data for the `mysql_old_password` plugin.
    Old([u8; 8]),
    /// Auth data for the `mysql_native_password` plugin.
    Native([u8; 20]),
    /// Auth data for `sha2_password` and `caching_sha2_password` plugins.
    Sha2([u8; 32]),
    /// Clear password for `mysql_clear_password` plugin.
    Clear(Cow<'a, [u8]>),
    /// Auth data for MariaDB's `client_ed25519` plugin.
    ///
    /// This plugin is known to the library but the actual support is enabled
    /// by the `client_ed25519` feature.
    Ed25519([u8; 64]),
}
1199
1200impl AuthPluginData<'_> {
1201    pub fn into_owned(self) -> AuthPluginData<'static> {
1202        match self {
1203            AuthPluginData::Old(x) => AuthPluginData::Old(x),
1204            AuthPluginData::Native(x) => AuthPluginData::Native(x),
1205            AuthPluginData::Sha2(x) => AuthPluginData::Sha2(x),
1206            AuthPluginData::Clear(x) => AuthPluginData::Clear(Cow::Owned(x.into_owned())),
1207            AuthPluginData::Ed25519(x) => AuthPluginData::Ed25519(x),
1208        }
1209    }
1210}
1211
1212impl std::ops::Deref for AuthPluginData<'_> {
1213    type Target = [u8];
1214
1215    fn deref(&self) -> &Self::Target {
1216        match self {
1217            Self::Sha2(x) => &x[..],
1218            Self::Native(x) => &x[..],
1219            Self::Old(x) => &x[..],
1220            Self::Clear(x) => &x[..],
1221            Self::Ed25519(x) => &x[..],
1222        }
1223    }
1224}
1225
1226impl MySerialize for AuthPluginData<'_> {
1227    fn serialize(&self, buf: &mut Vec<u8>) {
1228        match self {
1229            Self::Sha2(x) => buf.put_slice(&x[..]),
1230            Self::Native(x) => buf.put_slice(&x[..]),
1231            Self::Old(x) => {
1232                buf.put_slice(&x[..]);
1233                buf.push(0);
1234            }
1235            Self::Clear(x) => {
1236                buf.put_slice(x);
1237                buf.push(0);
1238            }
1239            Self::Ed25519(x) => buf.put_slice(&x[..]),
1240        }
1241    }
1242}
1243
/// Authentication plugin
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum AuthPlugin<'a> {
    /// Old Password Authentication
    MysqlOldPassword,
    /// Client-Side Cleartext Pluggable Authentication
    MysqlClearPassword,
    /// Legacy authentication plugin
    MysqlNativePassword,
    /// Default since MySql v8.0.4
    CachingSha2Password,
    /// MariaDB's Ed25519 based authentication
    ///
    /// This plugin is known to the library but the actual support is enabled
    /// by the `client_ed25519` feature.
    Ed25519,
    /// Any other plugin, identified by its raw name. No auth data can be
    /// generated for it (see `gen_data`).
    Other(Cow<'a, [u8]>),
}
1262
1263impl<'de> MyDeserialize<'de> for AuthPlugin<'de> {
1264    const SIZE: Option<usize> = None;
1265    type Ctx = ();
1266
1267    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1268        Ok(Self::from_bytes(buf.eat_all()))
1269    }
1270}
1271
1272impl MySerialize for AuthPlugin<'_> {
1273    fn serialize(&self, buf: &mut Vec<u8>) {
1274        buf.put_slice(self.as_bytes());
1275        buf.put_u8(0);
1276    }
1277}
1278
1279impl<'a> AuthPlugin<'a> {
1280    pub fn from_bytes(name: &'a [u8]) -> AuthPlugin<'a> {
1281        let name = if let [name @ .., 0] = name {
1282            name
1283        } else {
1284            name
1285        };
1286        match name {
1287            CACHING_SHA2_PASSWORD_PLUGIN_NAME => AuthPlugin::CachingSha2Password,
1288            MYSQL_NATIVE_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlNativePassword,
1289            MYSQL_OLD_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlOldPassword,
1290            MYSQL_CLEAR_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlClearPassword,
1291            ED25519_PLUGIN_NAME => AuthPlugin::Ed25519,
1292            name => AuthPlugin::Other(Cow::Borrowed(name)),
1293        }
1294    }
1295
1296    pub fn as_bytes(&self) -> &[u8] {
1297        match self {
1298            AuthPlugin::CachingSha2Password => CACHING_SHA2_PASSWORD_PLUGIN_NAME,
1299            AuthPlugin::MysqlNativePassword => MYSQL_NATIVE_PASSWORD_PLUGIN_NAME,
1300            AuthPlugin::MysqlOldPassword => MYSQL_OLD_PASSWORD_PLUGIN_NAME,
1301            AuthPlugin::MysqlClearPassword => MYSQL_CLEAR_PASSWORD_PLUGIN_NAME,
1302            AuthPlugin::Ed25519 => ED25519_PLUGIN_NAME,
1303            AuthPlugin::Other(name) => name,
1304        }
1305    }
1306
1307    pub fn into_owned(self) -> AuthPlugin<'static> {
1308        match self {
1309            AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1310            AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1311            AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1312            AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1313            AuthPlugin::Ed25519 => AuthPlugin::Ed25519,
1314            AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Owned(name.into_owned())),
1315        }
1316    }
1317
1318    pub fn borrow(&self) -> AuthPlugin<'_> {
1319        match self {
1320            AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1321            AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1322            AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1323            AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1324            AuthPlugin::Ed25519 => AuthPlugin::Ed25519,
1325            AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Borrowed(name.as_ref())),
1326        }
1327    }
1328
1329    /// Generates auth plugin data for this plugin.
1330    ///
1331    /// It'll generate `None` if password is `None` or empty.
1332    ///
1333    /// Note, that you should trim terminating null character from the `nonce`.
1334    ///
1335    /// # Panic
1336    ///
1337    /// * [`AuthPlugin::Ed25519`] will panic if `client_ed25519` feature is disabled.
1338    pub fn gen_data<'b>(&self, pass: Option<&'b str>, nonce: &[u8]) -> Option<AuthPluginData<'b>> {
1339        use super::scramble::{scramble_323, scramble_native, scramble_sha256};
1340
1341        match pass {
1342            Some(pass) if !pass.is_empty() => match self {
1343                AuthPlugin::CachingSha2Password => {
1344                    scramble_sha256(nonce, pass.as_bytes()).map(AuthPluginData::Sha2)
1345                }
1346                AuthPlugin::MysqlNativePassword => {
1347                    scramble_native(nonce, pass.as_bytes()).map(AuthPluginData::Native)
1348                }
1349                AuthPlugin::MysqlOldPassword => {
1350                    scramble_323(nonce.chunks(8).next().unwrap(), pass.as_bytes())
1351                        .map(AuthPluginData::Old)
1352                }
1353                AuthPlugin::MysqlClearPassword => {
1354                    Some(AuthPluginData::Clear(Cow::Borrowed(pass.as_bytes())))
1355                }
1356                AuthPlugin::Ed25519 => Some(AuthPluginData::Ed25519(create_response_for_ed25519(
1357                    pass.as_bytes(),
1358                    nonce,
1359                ))),
1360                AuthPlugin::Other(_) => None,
1361            },
1362            _ => None,
1363        }
1364    }
1365}
1366
// Declares `AuthMoreDataHeader`: the header value `0x01` that introduces an AuthMoreData packet.
// `InvalidAuthMoreDataHeader` is the error associated with a mismatching value
// (the exact expansion is defined by `define_header!`).
define_header!(
    AuthMoreDataHeader,
    InvalidAuthMoreDataHeader("Invalid AuthMoreData header"),
    0x01
);
1372
/// Extra auth-data beyond the initial challenge.
///
/// Consists of the AuthMoreData header followed by the data bytes.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AuthMoreData<'a> {
    /// Packet header marker.
    __header: AuthMoreDataHeader,
    /// The extra auth data bytes.
    data: RawBytes<'a, EofBytes>,
}
1379
1380impl<'a> AuthMoreData<'a> {
1381    pub fn new(data: impl Into<Cow<'a, [u8]>>) -> Self {
1382        Self {
1383            __header: AuthMoreDataHeader::new(),
1384            data: RawBytes::new(data),
1385        }
1386    }
1387
1388    pub fn data(&self) -> &[u8] {
1389        self.data.as_bytes()
1390    }
1391
1392    pub fn into_owned(self) -> AuthMoreData<'static> {
1393        AuthMoreData {
1394            __header: self.__header,
1395            data: self.data.into_owned(),
1396        }
1397    }
1398}
1399
1400impl<'de> MyDeserialize<'de> for AuthMoreData<'de> {
1401    const SIZE: Option<usize> = None;
1402    type Ctx = ();
1403
1404    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1405        Ok(Self {
1406            __header: buf.parse(())?,
1407            data: buf.parse(())?,
1408        })
1409    }
1410}
1411
1412impl MySerialize for AuthMoreData<'_> {
1413    fn serialize(&self, buf: &mut Vec<u8>) {
1414        self.__header.serialize(&mut *buf);
1415        self.data.serialize(buf);
1416    }
1417}
1418
// Declares `PublicKeyResponseHeader`: the header value `0x01` that introduces a public key response.
// `InvalidPublicKeyResponse` is the error associated with a mismatching value
// (the exact expansion is defined by `define_header!`).
define_header!(
    PublicKeyResponseHeader,
    InvalidPublicKeyResponse("Invalid PublicKeyResponse header"),
    0x01
);
1424
/// A server response to a [`PublicKeyRequest`] containing a public RSA key for authentication protection.
///
/// Consists of the public key response header followed by the key bytes.
///
/// [`PublicKeyRequest`]: crate::packets::caching_sha2_password::PublicKeyRequest
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct PublicKeyResponse<'a> {
    /// Packet header marker.
    __header: PublicKeyResponseHeader,
    /// The RSA public key bytes.
    rsa_key: RawBytes<'a, EofBytes>,
}
1433
1434impl<'a> PublicKeyResponse<'a> {
1435    pub fn new(rsa_key: impl Into<Cow<'a, [u8]>>) -> Self {
1436        Self {
1437            __header: PublicKeyResponseHeader::new(),
1438            rsa_key: RawBytes::new(rsa_key),
1439        }
1440    }
1441
1442    /// The server's RSA public key in PEM format.
1443    pub fn rsa_key(&self) -> Cow<'_, str> {
1444        self.rsa_key.as_str()
1445    }
1446}
1447
1448impl<'de> MyDeserialize<'de> for PublicKeyResponse<'de> {
1449    const SIZE: Option<usize> = None;
1450    type Ctx = ();
1451
1452    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1453        Ok(Self {
1454            __header: buf.parse(())?,
1455            rsa_key: buf.parse(())?,
1456        })
1457    }
1458}
1459
1460impl MySerialize for PublicKeyResponse<'_> {
1461    fn serialize(&self, buf: &mut Vec<u8>) {
1462        self.__header.serialize(&mut *buf);
1463        self.rsa_key.serialize(buf);
1464    }
1465}
1466
// Declares `AuthSwitchRequestHeader`: the header value `0xFE` shared by both
// auth switch request packets (`OldAuthSwitchRequest` and `AuthSwitchRequest`).
// `InvalidAuthSwithRequestHeader` is the error associated with a mismatching value
// (the exact expansion is defined by `define_header!`).
define_header!(
    AuthSwitchRequestHeader,
    InvalidAuthSwithRequestHeader("Invalid auth switch request header"),
    0xFE
);
1472
/// Old Authentication Method Switch Request Packet.
///
/// It is sent by the server to request the client to switch to Old Password Authentication
/// if `CLIENT_PLUGIN_AUTH` capability is not supported (by either the client or the server).
///
/// The packet consists of the auth switch request header only.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OldAuthSwitchRequest {
    /// Packet header marker.
    __header: AuthSwitchRequestHeader,
}
1481
1482impl OldAuthSwitchRequest {
1483    pub fn new() -> Self {
1484        Self {
1485            __header: AuthSwitchRequestHeader::new(),
1486        }
1487    }
1488
1489    pub const fn auth_plugin(&self) -> AuthPlugin<'static> {
1490        AuthPlugin::MysqlOldPassword
1491    }
1492}
1493
1494impl Default for OldAuthSwitchRequest {
1495    fn default() -> Self {
1496        Self::new()
1497    }
1498}
1499
1500impl<'de> MyDeserialize<'de> for OldAuthSwitchRequest {
1501    const SIZE: Option<usize> = Some(1);
1502    type Ctx = ();
1503
1504    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1505        Ok(Self {
1506            __header: buf.parse(())?,
1507        })
1508    }
1509}
1510
1511impl MySerialize for OldAuthSwitchRequest {
1512    fn serialize(&self, buf: &mut Vec<u8>) {
1513        self.__header.serialize(&mut *buf);
1514    }
1515}
1516
/// Authentication Method Switch Request Packet.
///
/// If both server and client support `CLIENT_PLUGIN_AUTH` capability, server can send this packet
/// to ask client to use another authentication method.
///
/// Consists of the auth switch request header, the plugin name and the plugin data.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AuthSwitchRequest<'a> {
    /// Packet header marker.
    __header: AuthSwitchRequestHeader,
    /// Name of the plugin to switch to.
    auth_plugin: RawBytes<'a, NullBytes>,
    /// Plugin-specific data; may end with a zero byte (see `plugin_data`).
    plugin_data: RawBytes<'a, EofBytes>,
}
1527
1528impl<'a> AuthSwitchRequest<'a> {
1529    pub fn new(
1530        auth_plugin: impl Into<Cow<'a, [u8]>>,
1531        plugin_data: impl Into<Cow<'a, [u8]>>,
1532    ) -> Self {
1533        Self {
1534            __header: AuthSwitchRequestHeader::new(),
1535            auth_plugin: RawBytes::new(auth_plugin),
1536            plugin_data: RawBytes::new(plugin_data),
1537        }
1538    }
1539
1540    pub fn auth_plugin(&self) -> AuthPlugin<'_> {
1541        ParseBuf(self.auth_plugin.as_bytes())
1542            .parse(())
1543            .expect("infallible")
1544    }
1545
1546    pub fn plugin_data(&self) -> &[u8] {
1547        match self.plugin_data.as_bytes() {
1548            [head @ .., 0] => head,
1549            all => all,
1550        }
1551    }
1552
1553    pub fn into_owned(self) -> AuthSwitchRequest<'static> {
1554        AuthSwitchRequest {
1555            __header: self.__header,
1556            auth_plugin: self.auth_plugin.into_owned(),
1557            plugin_data: self.plugin_data.into_owned(),
1558        }
1559    }
1560}
1561
1562impl<'de> MyDeserialize<'de> for AuthSwitchRequest<'de> {
1563    const SIZE: Option<usize> = None;
1564    type Ctx = ();
1565
1566    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1567        Ok(Self {
1568            __header: buf.parse(())?,
1569            auth_plugin: buf.parse(())?,
1570            plugin_data: buf.parse(())?,
1571        })
1572    }
1573}
1574
1575impl MySerialize for AuthSwitchRequest<'_> {
1576    fn serialize(&self, buf: &mut Vec<u8>) {
1577        self.__header.serialize(&mut *buf);
1578        self.auth_plugin.serialize(&mut *buf);
1579        self.plugin_data.serialize(buf);
1580    }
1581}
1582
/// Represents MySql's initial handshake packet.
///
/// Fields are declared in the order they are parsed and serialized.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct HandshakePacket<'a> {
    protocol_version: RawInt<u8>,
    server_version: RawBytes<'a, NullBytes>,
    connection_id: RawInt<LeU32>,
    // first 8 bytes of the auth plugin nonce
    scramble_1: [u8; 8],
    __filler: Skip<1>,
    // lower 16 bits of the capability flags
    capabilities_1: Const<CapabilityFlags, LeU32LowerHalf>,
    default_collation: RawInt<u8>,
    status_flags: Const<StatusFlags, LeU16>,
    // upper 16 bits of the capability flags
    capabilities_2: Const<CapabilityFlags, LeU32UpperHalf>,
    // auth plugin data length as it was read from / given for the packet
    auth_plugin_data_len: RawInt<u8>,
    __reserved: Skip<6>,
    // MariaDB uses last 4 reserved bytes to pass its extended capabilities.
    mariadb_ext_capabilities: Const<MariadbCapabilities, LeU32>,
    // rest of the auth plugin nonce (parsed only if CLIENT_SECURE_CONNECTION is set)
    scramble_2: Option<RawBytes<'a, BareBytes<{ (u8::MAX as usize) - 8 }>>>,
    // auth plugin name (parsed only if CLIENT_PLUGIN_AUTH is set)
    auth_plugin_name: Option<RawBytes<'a, NullBytes>>,
}
1604
1605impl<'de> MyDeserialize<'de> for HandshakePacket<'de> {
1606    const SIZE: Option<usize> = None;
1607    type Ctx = ();
1608
1609    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1610        let protocol_version = buf.parse(())?;
1611        let server_version = buf.parse(())?;
1612
1613        // includes trailing 10 bytes filler
1614        let mut sbuf: ParseBuf<'_> = buf.parse(31)?;
1615        let connection_id = sbuf.parse_unchecked(())?;
1616        let scramble_1 = sbuf.parse_unchecked(())?;
1617        let __filler = sbuf.parse_unchecked(())?;
1618        let capabilities_1: RawConst<LeU32LowerHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1619        let default_collation = sbuf.parse_unchecked(())?;
1620        let status_flags = sbuf.parse_unchecked(())?;
1621        let capabilities_2: RawConst<LeU32UpperHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1622        let auth_plugin_data_len: RawInt<u8> = sbuf.parse_unchecked(())?;
1623        let __reserved = sbuf.parse_unchecked(())?;
1624        // If the server is MariaDB, it will pass its extended capabilities
1625        // in the last 4 reserved bytes.
1626        let mariadb_capabiities: RawConst<LeU32, MariadbCapabilities> = sbuf.parse_unchecked(())?;
1627        let mut scramble_2 = None;
1628        if capabilities_1.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
1629            let len = max(13, auth_plugin_data_len.0 as i8 - 8) as usize;
1630            scramble_2 = buf.parse(len).map(Some)?;
1631        }
1632        let mut auth_plugin_name = None;
1633        if capabilities_2.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
1634            auth_plugin_name = match buf.eat_all() {
1635                [head @ .., 0] => Some(RawBytes::new(head)),
1636                // missing trailing `0` is a known bug in mysql
1637                all => Some(RawBytes::new(all)),
1638            }
1639        }
1640
1641        Ok(Self {
1642            protocol_version,
1643            server_version,
1644            connection_id,
1645            scramble_1,
1646            __filler,
1647            capabilities_1: Const::new(CapabilityFlags::from_bits_truncate(capabilities_1.0)),
1648            default_collation,
1649            status_flags,
1650            capabilities_2: Const::new(CapabilityFlags::from_bits_truncate(capabilities_2.0)),
1651            auth_plugin_data_len,
1652            __reserved,
1653            mariadb_ext_capabilities: Const::new(MariadbCapabilities::from_bits_truncate(
1654                mariadb_capabiities.0,
1655            )),
1656            scramble_2,
1657            auth_plugin_name,
1658        })
1659    }
1660}
1661
1662impl MySerialize for HandshakePacket<'_> {
1663    fn serialize(&self, buf: &mut Vec<u8>) {
1664        self.protocol_version.serialize(&mut *buf);
1665        self.server_version.serialize(&mut *buf);
1666        self.connection_id.serialize(&mut *buf);
1667        self.scramble_1.serialize(&mut *buf);
1668        buf.put_u8(0x00);
1669        self.capabilities_1.serialize(&mut *buf);
1670        self.default_collation.serialize(&mut *buf);
1671        self.status_flags.serialize(&mut *buf);
1672        self.capabilities_2.serialize(&mut *buf);
1673
1674        if self
1675            .capabilities_2
1676            .contains(CapabilityFlags::CLIENT_PLUGIN_AUTH)
1677        {
1678            buf.put_u8(
1679                self.scramble_2
1680                    .as_ref()
1681                    .map(|x| (x.len() + 8) as u8)
1682                    .unwrap_or_default(),
1683            );
1684        } else {
1685            buf.put_u8(0);
1686        }
1687
1688        self.__reserved.serialize(&mut *buf);
1689        self.mariadb_ext_capabilities.serialize(&mut *buf);
1690
1691        // Assume that the packet is well formed:
1692        // * the CLIENT_SECURE_CONNECTION is set.
1693        if let Some(scramble_2) = &self.scramble_2 {
1694            scramble_2.serialize(&mut *buf);
1695        }
1696
1697        // Assume that the packet is well formed:
1698        // * the CLIENT_PLUGIN_AUTH is set.
1699        if let Some(client_plugin_auth) = &self.auth_plugin_name {
1700            client_plugin_auth.serialize(buf);
1701        }
1702    }
1703}
1704
1705impl<'a> HandshakePacket<'a> {
1706    #[allow(clippy::too_many_arguments)]
1707    pub fn new(
1708        protocol_version: u8,
1709        server_version: impl Into<Cow<'a, [u8]>>,
1710        connection_id: u32,
1711        scramble_1: [u8; 8],
1712        scramble_2: Option<impl Into<Cow<'a, [u8]>>>,
1713        capabilities: CapabilityFlags,
1714        default_collation: u8,
1715        status_flags: StatusFlags,
1716        auth_plugin_name: Option<impl Into<Cow<'a, [u8]>>>,
1717    ) -> Self {
1718        // Safety:
1719        // * capabilities are given as a valid CapabilityFlags instance
1720        // * the BitAnd operation can't set new bits
1721        let (capabilities_1, capabilities_2) = (
1722            CapabilityFlags::from_bits_retain(capabilities.bits() & 0x0000_FFFF),
1723            CapabilityFlags::from_bits_retain(capabilities.bits() & 0xFFFF_0000),
1724        );
1725
1726        let scramble_2 = scramble_2.map(RawBytes::new);
1727
1728        HandshakePacket {
1729            protocol_version: RawInt::new(protocol_version),
1730            server_version: RawBytes::new(server_version),
1731            connection_id: RawInt::new(connection_id),
1732            scramble_1,
1733            __filler: Skip,
1734            capabilities_1: Const::new(capabilities_1),
1735            default_collation: RawInt::new(default_collation),
1736            status_flags: Const::new(status_flags),
1737            capabilities_2: Const::new(capabilities_2),
1738            auth_plugin_data_len: RawInt::new(
1739                scramble_2
1740                    .as_ref()
1741                    .map(|x| x.len() as u8)
1742                    .unwrap_or_default(),
1743            ),
1744            __reserved: Skip,
1745            mariadb_ext_capabilities: Const::new(MariadbCapabilities::empty()),
1746            scramble_2,
1747            auth_plugin_name: auth_plugin_name.map(RawBytes::new),
1748        }
1749    }
1750
1751    pub fn with_mariadb_ext_capabilities(mut self, flags: MariadbCapabilities) -> Self {
1752        self.mariadb_ext_capabilities = Const::new(flags);
1753        self
1754    }
1755
1756    pub fn into_owned(self) -> HandshakePacket<'static> {
1757        HandshakePacket {
1758            protocol_version: self.protocol_version,
1759            server_version: self.server_version.into_owned(),
1760            connection_id: self.connection_id,
1761            scramble_1: self.scramble_1,
1762            __filler: self.__filler,
1763            capabilities_1: self.capabilities_1,
1764            default_collation: self.default_collation,
1765            status_flags: self.status_flags,
1766            capabilities_2: self.capabilities_2,
1767            auth_plugin_data_len: self.auth_plugin_data_len,
1768            __reserved: self.__reserved,
1769            mariadb_ext_capabilities: self.mariadb_ext_capabilities,
1770            scramble_2: self.scramble_2.map(|x| x.into_owned()),
1771            auth_plugin_name: self.auth_plugin_name.map(RawBytes::into_owned),
1772        }
1773    }
1774
1775    /// Value of the protocol_version field of an initial handshake packet.
1776    pub fn protocol_version(&self) -> u8 {
1777        self.protocol_version.0
1778    }
1779
1780    /// Value of the server_version field of an initial handshake packet as a byte slice.
1781    pub fn server_version_ref(&self) -> &[u8] {
1782        self.server_version.as_bytes()
1783    }
1784
1785    /// Value of the server_version field of an initial handshake packet as a string
1786    /// (lossy converted).
1787    pub fn server_version_str(&self) -> Cow<'_, str> {
1788        self.server_version.as_str()
1789    }
1790
1791    /// Parsed server version.
1792    ///
1793    /// Will parse first \d+.\d+.\d+ of a server version string (if any).
1794    pub fn server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1795        VERSION_RE
1796            .captures(self.server_version_ref())
1797            .map(|captures| {
1798                // Should not panic because validated with regex
1799                (
1800                    btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1801                    btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1802                    btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1803                )
1804            })
1805    }
1806
1807    /// Parsed mariadb server version.
1808    pub fn maria_db_server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1809        MARIADB_VERSION_RE
1810            .captures(self.server_version_ref())
1811            .map(|captures| {
1812                // Should not panic because validated with regex
1813                (
1814                    btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1815                    btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1816                    btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1817                )
1818            })
1819    }
1820
1821    /// Value of the connection_id field of an initial handshake packet.
1822    pub fn connection_id(&self) -> u32 {
1823        self.connection_id.0
1824    }
1825
1826    /// Value of the scramble_1 field of an initial handshake packet as a byte slice.
1827    pub fn scramble_1_ref(&self) -> &[u8] {
1828        self.scramble_1.as_ref()
1829    }
1830
1831    /// Value of the scramble_2 field of an initial handshake packet as a byte slice.
1832    ///
1833    /// Note that this may include a terminating null character.
1834    pub fn scramble_2_ref(&self) -> Option<&[u8]> {
1835        self.scramble_2.as_ref().map(|x| x.as_bytes())
1836    }
1837
1838    /// Returns concatenated auth plugin nonce.
1839    pub fn nonce(&self) -> Vec<u8> {
1840        let mut out = Vec::from(self.scramble_1_ref());
1841        out.extend_from_slice(self.scramble_2_ref().unwrap_or(&[][..]));
1842
1843        // Trim zero terminator. Fill with zeroes if nonce
1844        // is somehow smaller than 20 bytes.
1845        out.resize(20, 0);
1846        out
1847    }
1848
1849    /// Value of a server capabilities.
1850    pub fn capabilities(&self) -> CapabilityFlags {
1851        self.capabilities_1.0 | self.capabilities_2.0
1852    }
1853
1854    /// Value of MariaDB specific server capabilities
1855    pub fn mariadb_ext_capabilities(&self) -> MariadbCapabilities {
1856        self.mariadb_ext_capabilities.0
1857    }
1858    /// Value of the default_collation field of an initial handshake packet.
1859    pub fn default_collation(&self) -> u8 {
1860        self.default_collation.0
1861    }
1862
1863    /// Value of a status flags.
1864    pub fn status_flags(&self) -> StatusFlags {
1865        self.status_flags.0
1866    }
1867
1868    /// Value of the auth_plugin_name field of an initial handshake packet as a byte slice.
1869    pub fn auth_plugin_name_ref(&self) -> Option<&[u8]> {
1870        self.auth_plugin_name.as_ref().map(|x| x.as_bytes())
1871    }
1872
1873    /// Value of the auth_plugin_name field of an initial handshake packet as a string
1874    /// (lossy converted).
1875    pub fn auth_plugin_name_str(&self) -> Option<Cow<'_, str>> {
1876        self.auth_plugin_name.as_ref().map(|x| x.as_str())
1877    }
1878
1879    /// Auth plugin of a handshake packet
1880    pub fn auth_plugin(&self) -> Option<AuthPlugin<'_>> {
1881        self.auth_plugin_name.as_ref().map(|x| match x.as_bytes() {
1882            [name @ .., 0] => ParseBuf(name).parse_unchecked(()).expect("infallible"),
1883            all => ParseBuf(all).parse_unchecked(()).expect("infallible"),
1884        })
1885    }
1886}
1887
// Header of the `COM_CHANGE_USER` packet: invokes `define_header!` with the header
// name, the error name plus its message, and the header value `0x11`.
// NOTE(review): the items generated by `define_header!` are not visible here;
// confirm against the macro definition.
define_header!(
    ComChangeUserHeader,
    InvalidComChangeUserHeader("Invalid COM_CHANGE_USER header"),
    0x11
);
1893
/// `COM_CHANGE_USER` packet.
///
/// Fields are serialized and deserialized in declaration order. `more_data` is
/// only parsed when bytes remain in the buffer after `database`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComChangeUser<'a> {
    /// Packet header (see `ComChangeUserHeader`).
    __header: ComChangeUserHeader,
    /// User name.
    user: RawBytes<'a, NullBytes>,
    /// Auth plugin data.
    // Only CLIENT_SECURE_CONNECTION capable servers are supported
    auth_plugin_data: RawBytes<'a, U8Bytes>,
    /// Database name.
    database: RawBytes<'a, NullBytes>,
    /// Optional trailing section (character set, auth plugin, connect attributes).
    more_data: Option<ComChangeUserMoreData<'a>>,
}
1903
1904impl<'a> ComChangeUser<'a> {
1905    pub fn new() -> Self {
1906        Self {
1907            __header: ComChangeUserHeader::new(),
1908            user: Default::default(),
1909            auth_plugin_data: Default::default(),
1910            database: Default::default(),
1911            more_data: None,
1912        }
1913    }
1914
1915    pub fn with_user(mut self, user: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1916        self.user = user.map(RawBytes::new).unwrap_or_default();
1917        self
1918    }
1919
1920    pub fn with_database(mut self, database: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1921        self.database = database.map(RawBytes::new).unwrap_or_default();
1922        self
1923    }
1924
1925    pub fn with_auth_plugin_data(
1926        mut self,
1927        auth_plugin_data: Option<impl Into<Cow<'a, [u8]>>>,
1928    ) -> Self {
1929        self.auth_plugin_data = auth_plugin_data.map(RawBytes::new).unwrap_or_default();
1930        self
1931    }
1932
1933    pub fn with_more_data(mut self, more_data: Option<ComChangeUserMoreData<'a>>) -> Self {
1934        self.more_data = more_data;
1935        self
1936    }
1937
1938    pub fn into_owned(self) -> ComChangeUser<'static> {
1939        ComChangeUser {
1940            __header: self.__header,
1941            user: self.user.into_owned(),
1942            auth_plugin_data: self.auth_plugin_data.into_owned(),
1943            database: self.database.into_owned(),
1944            more_data: self.more_data.map(|x| x.into_owned()),
1945        }
1946    }
1947}
1948
1949impl Default for ComChangeUser<'_> {
1950    fn default() -> Self {
1951        Self::new()
1952    }
1953}
1954
1955impl<'de> MyDeserialize<'de> for ComChangeUser<'de> {
1956    const SIZE: Option<usize> = None;
1957
1958    type Ctx = CapabilityFlags;
1959
1960    fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1961        Ok(Self {
1962            __header: buf.parse(())?,
1963            user: buf.parse(())?,
1964            auth_plugin_data: buf.parse(())?,
1965            database: buf.parse(())?,
1966            more_data: if !buf.is_empty() {
1967                Some(buf.parse(flags)?)
1968            } else {
1969                None
1970            },
1971        })
1972    }
1973}
1974
1975impl MySerialize for ComChangeUser<'_> {
1976    fn serialize(&self, buf: &mut Vec<u8>) {
1977        self.__header.serialize(&mut *buf);
1978        self.user.serialize(&mut *buf);
1979        self.auth_plugin_data.serialize(&mut *buf);
1980        self.database.serialize(&mut *buf);
1981        if let Some(ref more_data) = self.more_data {
1982            more_data.serialize(&mut *buf);
1983        }
1984    }
1985}
1986
/// Optional trailing section of a `COM_CHANGE_USER` packet (see `ComChangeUser`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComChangeUserMoreData<'a> {
    /// Character set, always present in this section.
    character_set: RawInt<LeU16>,
    /// Auth plugin; parsed only when `CLIENT_PLUGIN_AUTH` is set in the flags.
    auth_plugin: Option<AuthPlugin<'a>>,
    /// Connect attributes; parsed only when `CLIENT_CONNECT_ATTRS` is set in the flags.
    connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
}
1993
1994impl<'a> ComChangeUserMoreData<'a> {
1995    pub fn new(character_set: u16) -> Self {
1996        Self {
1997            character_set: RawInt::new(character_set),
1998            auth_plugin: None,
1999            connect_attributes: None,
2000        }
2001    }
2002
2003    pub fn with_auth_plugin(mut self, auth_plugin: Option<AuthPlugin<'a>>) -> Self {
2004        self.auth_plugin = auth_plugin;
2005        self
2006    }
2007
2008    pub fn with_connect_attributes(
2009        mut self,
2010        connect_attributes: Option<HashMap<String, String>>,
2011    ) -> Self {
2012        self.connect_attributes = connect_attributes.map(|attrs| {
2013            attrs
2014                .into_iter()
2015                .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
2016                .collect()
2017        });
2018        self
2019    }
2020
2021    pub fn into_owned(self) -> ComChangeUserMoreData<'static> {
2022        ComChangeUserMoreData {
2023            character_set: self.character_set,
2024            auth_plugin: self.auth_plugin.map(|x| x.into_owned()),
2025            connect_attributes: self.connect_attributes.map(|x| {
2026                x.into_iter()
2027                    .map(|(k, v)| (k.into_owned(), v.into_owned()))
2028                    .collect()
2029            }),
2030        }
2031    }
2032}
2033
2034// Helper that deserializes connect attributes.
2035fn deserialize_connect_attrs<'de>(
2036    buf: &mut ParseBuf<'de>,
2037) -> io::Result<HashMap<RawBytes<'de, LenEnc>, RawBytes<'de, LenEnc>>> {
2038    let data_len = buf.parse::<RawInt<LenEnc>>(())?;
2039    let mut data: ParseBuf<'_> = buf.parse(data_len.0 as usize)?;
2040    let mut attrs = HashMap::new();
2041    while !data.is_empty() {
2042        let key = data.parse::<RawBytes<'_, LenEnc>>(())?;
2043        let value = data.parse::<RawBytes<'_, LenEnc>>(())?;
2044        attrs.insert(key, value);
2045    }
2046    Ok(attrs)
2047}
2048
2049// Helper that serializes connect attributes.
2050fn serialize_connect_attrs<'a>(
2051    connect_attributes: &HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>,
2052    buf: &mut Vec<u8>,
2053) {
2054    let len = connect_attributes
2055        .iter()
2056        .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2057        .sum::<u64>();
2058    buf.put_lenenc_int(len);
2059
2060    for (name, value) in connect_attributes {
2061        name.serialize(&mut *buf);
2062        value.serialize(&mut *buf);
2063    }
2064}
2065
2066impl<'de> MyDeserialize<'de> for ComChangeUserMoreData<'de> {
2067    const SIZE: Option<usize> = None;
2068    type Ctx = CapabilityFlags;
2069
2070    fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2071        // always assume CLIENT_PROTOCOL_41
2072        let character_set = buf.parse(())?;
2073        let mut auth_plugin = None;
2074        let mut connect_attributes = None;
2075
2076        if flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
2077            // plugin name is null-terminated here
2078            match buf.parse::<RawBytes<'_, NullBytes>>(())?.0 {
2079                Cow::Borrowed(bytes) => {
2080                    let mut auth_plugin_buf = ParseBuf(bytes);
2081                    auth_plugin = Some(auth_plugin_buf.parse(())?);
2082                }
2083                _ => unreachable!(),
2084            }
2085        };
2086
2087        if flags.contains(CapabilityFlags::CLIENT_CONNECT_ATTRS) {
2088            connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2089        };
2090
2091        Ok(Self {
2092            character_set,
2093            auth_plugin,
2094            connect_attributes,
2095        })
2096    }
2097}
2098
2099impl MySerialize for ComChangeUserMoreData<'_> {
2100    fn serialize(&self, buf: &mut Vec<u8>) {
2101        self.character_set.serialize(&mut *buf);
2102        if let Some(ref auth_plugin) = self.auth_plugin {
2103            auth_plugin.serialize(&mut *buf);
2104        }
2105        if let Some(ref connect_attributes) = self.connect_attributes {
2106            serialize_connect_attrs(connect_attributes, buf);
2107        } else {
2108            // We'll always act like CLIENT_CONNECT_ATTRS is set,
2109            // this is to avoid looking into the actual connection flags.
2110            serialize_connect_attrs(&Default::default(), buf);
2111        }
2112    }
2113}
2114
/// Actual serialization of this field depends on capability flags values.
///
/// `HandshakeResponse` picks the variant as follows:
///
/// * `Left` – `CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA` is set (`LenEnc` bytes);
/// * `Right(Left)` – otherwise, if `CLIENT_SECURE_CONNECTION` is set (`U8Bytes`);
/// * `Right(Right)` – neither flag is set (`NullBytes`).
type ScrambleBuf<'a> =
    Either<RawBytes<'a, LenEnc>, Either<RawBytes<'a, U8Bytes>, RawBytes<'a, NullBytes>>>;
2118
/// Handshake response packet.
///
/// Note that the wire order differs from the declaration order: capabilities,
/// max packet size, collation, 19 zero bytes, MariaDB extended capabilities, user,
/// scramble buffer, then the optional database name, auth plugin and connect
/// attributes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandshakeResponse<'a> {
    /// Client capability flags.
    capabilities: Const<CapabilityFlags, LeU32>,
    /// Value of the max packet size field.
    max_packet_size: RawInt<LeU32>,
    /// Collation identifier.
    collation: RawInt<u8>,
    /// Scramble buffer; its representation depends on the capabilities (see `ScrambleBuf`).
    scramble_buf: ScrambleBuf<'a>,
    /// User name.
    user: RawBytes<'a, NullBytes>,
    /// Database name (goes together with `CLIENT_CONNECT_WITH_DB`).
    db_name: Option<RawBytes<'a, NullBytes>>,
    /// Auth plugin (goes together with `CLIENT_PLUGIN_AUTH`).
    auth_plugin: Option<AuthPlugin<'a>>,
    /// Connect attributes (go together with `CLIENT_CONNECT_ATTRS`).
    connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
    /// MariaDB specific capability flags.
    mariadb_ext_capabilities: Const<MariadbCapabilities, LeU32>,
}
2131
2132impl<'a> HandshakeResponse<'a> {
2133    #[allow(clippy::too_many_arguments)]
2134    pub fn new(
2135        scramble_buf: Option<impl Into<Cow<'a, [u8]>>>,
2136        server_version: (u16, u16, u16),
2137        user: Option<impl Into<Cow<'a, [u8]>>>,
2138        db_name: Option<impl Into<Cow<'a, [u8]>>>,
2139        auth_plugin: Option<AuthPlugin<'a>>,
2140        mut capabilities: CapabilityFlags,
2141        connect_attributes: Option<HashMap<String, String>>,
2142        max_packet_size: u32,
2143    ) -> Self {
2144        let scramble_buf =
2145            if capabilities.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
2146                Either::Left(RawBytes::new(
2147                    scramble_buf.map(Into::into).unwrap_or_default(),
2148                ))
2149            } else if capabilities.contains(CapabilityFlags::CLIENT_SECURE_CONNECTION) {
2150                Either::Right(Either::Left(RawBytes::new(
2151                    scramble_buf.map(Into::into).unwrap_or_default(),
2152                )))
2153            } else {
2154                Either::Right(Either::Right(RawBytes::new(
2155                    scramble_buf.map(Into::into).unwrap_or_default(),
2156                )))
2157            };
2158
2159        if db_name.is_some() {
2160            capabilities.insert(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2161        } else {
2162            capabilities.remove(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2163        }
2164
2165        if auth_plugin.is_some() {
2166            capabilities.insert(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2167        } else {
2168            capabilities.remove(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2169        }
2170
2171        if connect_attributes.is_some() {
2172            capabilities.insert(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2173        } else {
2174            capabilities.remove(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2175        }
2176
2177        Self {
2178            scramble_buf,
2179            collation: if server_version >= (5, 5, 3) {
2180                RawInt::new(CollationId::UTF8MB4_GENERAL_CI as u8)
2181            } else {
2182                RawInt::new(CollationId::UTF8MB3_GENERAL_CI as u8)
2183            },
2184            user: user.map(RawBytes::new).unwrap_or_default(),
2185            db_name: db_name.map(RawBytes::new),
2186            auth_plugin,
2187            capabilities: Const::new(capabilities),
2188            connect_attributes: connect_attributes.map(|attrs| {
2189                attrs
2190                    .into_iter()
2191                    .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
2192                    .collect()
2193            }),
2194            max_packet_size: RawInt::new(max_packet_size),
2195            mariadb_ext_capabilities: Const::new(MariadbCapabilities::empty()),
2196        }
2197    }
2198
2199    pub fn with_mariadb_ext_capabilities(
2200        mut self,
2201        mariadb_ext_capabilities: MariadbCapabilities,
2202    ) -> Self {
2203        self.mariadb_ext_capabilities = Const::new(mariadb_ext_capabilities);
2204        self
2205    }
2206
2207    pub fn capabilities(&self) -> CapabilityFlags {
2208        self.capabilities.0
2209    }
2210
2211    pub fn mariadb_ext_capabilities(&self) -> MariadbCapabilities {
2212        self.mariadb_ext_capabilities.0
2213    }
2214
2215    pub fn collation(&self) -> u8 {
2216        self.collation.0
2217    }
2218
2219    pub fn scramble_buf(&self) -> &[u8] {
2220        match &self.scramble_buf {
2221            Either::Left(x) => x.as_bytes(),
2222            Either::Right(x) => match x {
2223                Either::Left(x) => x.as_bytes(),
2224                Either::Right(x) => x.as_bytes(),
2225            },
2226        }
2227    }
2228
2229    pub fn user(&self) -> &[u8] {
2230        self.user.as_bytes()
2231    }
2232
2233    pub fn db_name(&self) -> Option<&[u8]> {
2234        self.db_name.as_ref().map(|x| x.as_bytes())
2235    }
2236
2237    pub fn auth_plugin(&self) -> Option<&AuthPlugin<'a>> {
2238        self.auth_plugin.as_ref()
2239    }
2240
2241    #[must_use = "entails computation"]
2242    pub fn connect_attributes(&self) -> Option<HashMap<String, String>> {
2243        self.connect_attributes.as_ref().map(|attrs| {
2244            attrs
2245                .iter()
2246                .map(|(k, v)| (k.as_str().into_owned(), v.as_str().into_owned()))
2247                .collect()
2248        })
2249    }
2250}
2251
2252impl<'de> MyDeserialize<'de> for HandshakeResponse<'de> {
2253    const SIZE: Option<usize> = None;
2254    type Ctx = ();
2255
2256    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2257        let mut sbuf: ParseBuf<'_> = buf.parse(4 + 4 + 1 + 23)?;
2258        let client_flags: RawConst<LeU32, CapabilityFlags> = sbuf.parse_unchecked(())?;
2259        let max_packet_size: RawInt<LeU32> = sbuf.parse_unchecked(())?;
2260        let collation = sbuf.parse_unchecked(())?;
2261        sbuf.parse_unchecked::<Skip<19>>(())?;
2262        let mariadb_flags: RawConst<LeU32, MariadbCapabilities> = sbuf.parse_unchecked(())?;
2263
2264        let user = buf.parse(())?;
2265        let scramble_buf =
2266            if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA.bits() > 0 {
2267                Either::Left(buf.parse(())?)
2268            } else if client_flags.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
2269                Either::Right(Either::Left(buf.parse(())?))
2270            } else {
2271                Either::Right(Either::Right(buf.parse(())?))
2272            };
2273
2274        let mut db_name = None;
2275        if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_WITH_DB.bits() > 0 {
2276            db_name = buf.parse(()).map(Some)?;
2277        }
2278
2279        let mut auth_plugin = None;
2280        if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
2281            let auth_plugin_name = buf.eat_null_str();
2282            auth_plugin = Some(AuthPlugin::from_bytes(auth_plugin_name));
2283        }
2284
2285        let mut connect_attributes = None;
2286        if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_ATTRS.bits() > 0 {
2287            connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2288        }
2289
2290        Ok(Self {
2291            capabilities: Const::new(CapabilityFlags::from_bits_truncate(client_flags.0)),
2292            max_packet_size,
2293            collation,
2294            scramble_buf,
2295            user,
2296            db_name,
2297            auth_plugin,
2298            connect_attributes,
2299            mariadb_ext_capabilities: Const::new(MariadbCapabilities::from_bits_truncate(
2300                mariadb_flags.0,
2301            )),
2302        })
2303    }
2304}
2305
2306impl MySerialize for HandshakeResponse<'_> {
2307    fn serialize(&self, buf: &mut Vec<u8>) {
2308        self.capabilities.serialize(&mut *buf);
2309        self.max_packet_size.serialize(&mut *buf);
2310        self.collation.serialize(&mut *buf);
2311        buf.put_slice(&[0; 19]);
2312        self.mariadb_ext_capabilities.serialize(&mut *buf);
2313        self.user.serialize(&mut *buf);
2314        self.scramble_buf.serialize(&mut *buf);
2315
2316        if let Some(db_name) = &self.db_name {
2317            db_name.serialize(&mut *buf);
2318        }
2319
2320        if let Some(auth_plugin) = &self.auth_plugin {
2321            auth_plugin.serialize(&mut *buf);
2322        }
2323
2324        if let Some(attrs) = &self.connect_attributes {
2325            let len = attrs
2326                .iter()
2327                .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2328                .sum::<u64>();
2329            buf.put_lenenc_int(len);
2330
2331            for (name, value) in attrs {
2332                name.serialize(&mut *buf);
2333                value.serialize(&mut *buf);
2334            }
2335        }
2336    }
2337}
2338
/// SSL request packet: capabilities, max packet size, character set and
/// 23 skipped bytes (32 bytes in total, see `SIZE` of its `MyDeserialize` impl).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SslRequest {
    /// Client capability flags.
    capabilities: Const<CapabilityFlags, LeU32>,
    /// Value of the max packet size field.
    max_packet_size: RawInt<LeU32>,
    /// Character set identifier.
    character_set: RawInt<u8>,
    /// Trailing 23 bytes that carry no value in this structure.
    __skip: Skip<23>,
}
2346
2347impl SslRequest {
2348    pub fn new(capabilities: CapabilityFlags, max_packet_size: u32, character_set: u8) -> Self {
2349        Self {
2350            capabilities: Const::new(capabilities),
2351            max_packet_size: RawInt::new(max_packet_size),
2352            character_set: RawInt::new(character_set),
2353            __skip: Skip,
2354        }
2355    }
2356
2357    pub fn capabilities(&self) -> CapabilityFlags {
2358        self.capabilities.0
2359    }
2360
2361    pub fn max_packet_size(&self) -> u32 {
2362        self.max_packet_size.0
2363    }
2364
2365    pub fn character_set(&self) -> u8 {
2366        self.character_set.0
2367    }
2368}
2369
2370impl<'de> MyDeserialize<'de> for SslRequest {
2371    const SIZE: Option<usize> = Some(4 + 4 + 1 + 23);
2372    type Ctx = ();
2373
2374    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2375        let mut buf: ParseBuf<'_> = buf.parse(Self::SIZE.unwrap())?;
2376        let raw_capabilities = buf.parse_unchecked::<RawConst<LeU32, CapabilityFlags>>(())?;
2377        Ok(Self {
2378            capabilities: Const::new(CapabilityFlags::from_bits_truncate(raw_capabilities.0)),
2379            max_packet_size: buf.parse_unchecked(())?,
2380            character_set: buf.parse_unchecked(())?,
2381            __skip: buf.parse_unchecked(())?,
2382        })
2383    }
2384}
2385
2386impl MySerialize for SslRequest {
2387    fn serialize(&self, buf: &mut Vec<u8>) {
2388        self.capabilities.serialize(&mut *buf);
2389        self.max_packet_size.serialize(&mut *buf);
2390        self.character_set.serialize(&mut *buf);
2391        self.__skip.serialize(&mut *buf);
2392    }
2393}
2394
/// Error type attached to the `status` constant of `StmtPacket`
/// (the status byte is expected to be `0x00`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
#[error("Invalid statement packet status")]
pub struct InvalidStmtPacketStatus;
2398
/// Represents MySql's statement packet.
///
/// The packet is 12 bytes long (see `SIZE` of its `MyDeserialize` impl).
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct StmtPacket {
    /// Status byte, constant `0x00`.
    status: ConstU8<InvalidStmtPacketStatus, 0x00>,
    /// Statement identifier.
    statement_id: RawInt<LeU32>,
    /// Number of columns.
    num_columns: RawInt<LeU16>,
    /// Number of parameters.
    num_params: RawInt<LeU16>,
    /// One skipped byte.
    __skip: Skip<1>,
    /// Warning count.
    warning_count: RawInt<LeU16>,
}
2409
2410impl<'de> MyDeserialize<'de> for StmtPacket {
2411    const SIZE: Option<usize> = Some(12);
2412    type Ctx = ();
2413
2414    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2415        let mut buf: ParseBuf<'_> = buf.parse(Self::SIZE.unwrap())?;
2416        Ok(StmtPacket {
2417            status: buf.parse_unchecked(())?,
2418            statement_id: buf.parse_unchecked(())?,
2419            num_columns: buf.parse_unchecked(())?,
2420            num_params: buf.parse_unchecked(())?,
2421            __skip: buf.parse_unchecked(())?,
2422            warning_count: buf.parse_unchecked(())?,
2423        })
2424    }
2425}
2426
2427impl MySerialize for StmtPacket {
2428    fn serialize(&self, buf: &mut Vec<u8>) {
2429        self.status.serialize(&mut *buf);
2430        self.statement_id.serialize(&mut *buf);
2431        self.num_columns.serialize(&mut *buf);
2432        self.num_params.serialize(&mut *buf);
2433        self.__skip.serialize(&mut *buf);
2434        self.warning_count.serialize(&mut *buf);
2435    }
2436}
2437
2438impl StmtPacket {
2439    /// Value of the statement_id field of a statement packet.
2440    pub fn statement_id(&self) -> u32 {
2441        *self.statement_id
2442    }
2443
2444    /// Value of the num_columns field of a statement packet.
2445    pub fn num_columns(&self) -> u16 {
2446        *self.num_columns
2447    }
2448
2449    /// Value of the num_params field of a statement packet.
2450    pub fn num_params(&self) -> u16 {
2451        *self.num_params
2452    }
2453
2454    /// Value of the warning_count field of a statement packet.
2455    pub fn warning_count(&self) -> u16 {
2456        *self.warning_count
2457    }
2458}
2459
/// Null-bitmap.
///
/// Wraps the bitmap bytes (`U`). The type parameter `T` supplies the bit offset
/// (`T::BIT_OFFSET`) that is added to a column index before addressing a bit.
///
/// <http://dev.mysql.com/doc/internals/en/null-bitmap.html>
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct NullBitmap<T, U: AsRef<[u8]> = Vec<u8>>(U, PhantomData<T>);
2465
2466impl<'de, T: SerializationSide> MyDeserialize<'de> for NullBitmap<T, Cow<'de, [u8]>> {
2467    const SIZE: Option<usize> = None;
2468    type Ctx = usize;
2469
2470    fn deserialize(num_columns: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2471        let bitmap_len = Self::bitmap_len(num_columns);
2472        let bytes = buf.checked_eat(bitmap_len).ok_or_else(unexpected_buf_eof)?;
2473        Ok(Self::from_bytes(Cow::Borrowed(bytes)))
2474    }
2475}
2476
2477impl<T: SerializationSide> NullBitmap<T, Vec<u8>> {
2478    /// Creates new null-bitmap for a given number of columns.
2479    pub fn new(num_columns: usize) -> Self {
2480        Self::from_bytes(vec![0; Self::bitmap_len(num_columns)])
2481    }
2482
2483    /// Will read null-bitmap for a given number of columns from `input`.
2484    pub fn read(input: &mut &[u8], num_columns: usize) -> Self {
2485        let bitmap_len = Self::bitmap_len(num_columns);
2486        assert!(input.len() >= bitmap_len);
2487
2488        let bitmap = Self::from_bytes(input[..bitmap_len].to_vec());
2489        *input = &input[bitmap_len..];
2490
2491        bitmap
2492    }
2493}
2494
2495impl<T: SerializationSide, U: AsRef<[u8]>> NullBitmap<T, U> {
2496    pub fn bitmap_len(num_columns: usize) -> usize {
2497        (num_columns + 7 + T::BIT_OFFSET) / 8
2498    }
2499
2500    fn byte_and_bit(&self, column_index: usize) -> (usize, u8) {
2501        let offset = column_index + T::BIT_OFFSET;
2502        let byte = offset / 8;
2503        let bit = 1 << (offset % 8) as u8;
2504
2505        assert!(byte < self.0.as_ref().len());
2506
2507        (byte, bit)
2508    }
2509
2510    /// Creates new null-bitmap from given bytes.
2511    pub fn from_bytes(bytes: U) -> Self {
2512        Self(bytes, PhantomData)
2513    }
2514
2515    /// Returns `true` if given column is `NULL` in this `NullBitmap`.
2516    pub fn is_null(&self, column_index: usize) -> bool {
2517        let (byte, bit) = self.byte_and_bit(column_index);
2518        self.0.as_ref()[byte] & bit > 0
2519    }
2520}
2521
2522impl<T: SerializationSide, U: AsRef<[u8]> + AsMut<[u8]>> NullBitmap<T, U> {
2523    /// Sets flag value for given column.
2524    pub fn set(&mut self, column_index: usize, is_null: bool) {
2525        let (byte, bit) = self.byte_and_bit(column_index);
2526        if is_null {
2527            self.0.as_mut()[byte] |= bit
2528        } else {
2529            self.0.as_mut()[byte] &= !bit
2530        }
2531    }
2532}
2533
2534impl<T, U: AsRef<[u8]>> AsRef<[u8]> for NullBitmap<T, U> {
2535    fn as_ref(&self) -> &[u8] {
2536        self.0.as_ref()
2537    }
2538}
2539
/// Builder of `ComStmtExecuteRequest` for a particular statement.
#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtExecuteRequestBuilder {
    /// Identifier of the statement to execute.
    pub stmt_id: u32,
}
2544
2545impl ComStmtExecuteRequestBuilder {
2546    pub const NULL_BITMAP_OFFSET: usize = 10;
2547
2548    pub fn new(stmt_id: u32) -> Self {
2549        Self { stmt_id }
2550    }
2551}
2552
2553impl ComStmtExecuteRequestBuilder {
2554    pub fn build(self, params: &[Value]) -> (ComStmtExecuteRequest<'_>, bool) {
2555        let bitmap_len = NullBitmap::<ClientSide>::bitmap_len(params.len());
2556
2557        let mut bitmap_bytes = vec![0; bitmap_len];
2558        let mut bitmap = NullBitmap::<ClientSide, _>::from_bytes(&mut bitmap_bytes);
2559        let params = params.iter().collect::<Vec<_>>();
2560
2561        let meta_len = params.len() * 2;
2562
2563        let mut data_len = 0;
2564        for (i, param) in params.iter().enumerate() {
2565            match param.bin_len() as usize {
2566                0 => bitmap.set(i, true),
2567                x => data_len += x,
2568            }
2569        }
2570
2571        let total_len = 10 + bitmap_len + 1 + meta_len + data_len;
2572
2573        let as_long_data = total_len > MAX_PAYLOAD_LEN;
2574
2575        (
2576            ComStmtExecuteRequest {
2577                com_stmt_execute: ConstU8::new(),
2578                stmt_id: RawInt::new(self.stmt_id),
2579                flags: Const::new(CursorType::CURSOR_TYPE_NO_CURSOR),
2580                iteration_count: ConstU32::new(),
2581                params_flags: Const::new(StmtExecuteParamsFlags::NEW_PARAMS_BOUND),
2582                bitmap: RawBytes::new(bitmap_bytes),
2583                params,
2584                as_long_data,
2585            },
2586            as_long_data,
2587        )
2588    }
2589}
2590
// Header of the `COM_STMT_EXECUTE` packet: invokes `define_header!` with the header
// name, the `COM_STMT_EXECUTE` identifier and the error name.
// NOTE(review): this macro arm is not visible here — presumably the identifier
// names a `Command` variant; confirm against the macro definition.
define_header!(
    ComStmtExecuteHeader,
    COM_STMT_EXECUTE,
    InvalidComStmtExecuteHeader
);
2596
// Iteration count of `COM_STMT_EXECUTE`: invokes `define_const!` with `ConstU32`,
// the name `IterationCount`, the error name plus its message, and the value `1`.
// NOTE(review): `define_const!` is not visible here — presumably it declares a
// constant-valued field type; confirm against the macro definition.
define_const!(
    ConstU32,
    IterationCount,
    InvalidIterationCount("Invalid iteration count for COM_STMT_EXECUTE"),
    1
);
2603
/// `COM_STMT_EXECUTE` request, created by `ComStmtExecuteRequestBuilder::build`.
#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtExecuteRequest<'a> {
    /// Packet header.
    com_stmt_execute: ComStmtExecuteHeader,
    /// Identifier of the statement to execute.
    stmt_id: RawInt<LeU32>,
    /// Cursor type flags.
    flags: Const<CursorType, u8>,
    /// Iteration count constant.
    iteration_count: IterationCount,
    /// Null-bitmap of the parameters.
    // max params / bits per byte = 8192
    bitmap: RawBytes<'a, BareBytes<8192>>,
    /// Parameters flags.
    params_flags: Const<StmtExecuteParamsFlags, u8>,
    /// Statement parameters.
    params: Vec<&'a Value>,
    /// When set, `Value::Bytes` parameters are left out of the serialized payload.
    as_long_data: bool,
}
2616
2617impl<'a> ComStmtExecuteRequest<'a> {
2618    pub fn stmt_id(&self) -> u32 {
2619        self.stmt_id.0
2620    }
2621
2622    pub fn flags(&self) -> CursorType {
2623        self.flags.0
2624    }
2625
2626    pub fn bitmap(&self) -> &[u8] {
2627        self.bitmap.as_bytes()
2628    }
2629
2630    pub fn params_flags(&self) -> StmtExecuteParamsFlags {
2631        self.params_flags.0
2632    }
2633
2634    pub fn params(&self) -> &[&'a Value] {
2635        self.params.as_ref()
2636    }
2637
2638    pub fn as_long_data(&self) -> bool {
2639        self.as_long_data
2640    }
2641}
2642
2643impl MySerialize for ComStmtExecuteRequest<'_> {
2644    fn serialize(&self, buf: &mut Vec<u8>) {
2645        self.com_stmt_execute.serialize(&mut *buf);
2646        self.stmt_id.serialize(&mut *buf);
2647        self.flags.serialize(&mut *buf);
2648        self.iteration_count.serialize(&mut *buf);
2649
2650        if !self.params.is_empty() {
2651            self.bitmap.serialize(&mut *buf);
2652            self.params_flags.serialize(&mut *buf);
2653        }
2654
2655        for param in &self.params {
2656            let (column_type, flags) = match param {
2657                Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()),
2658                Value::Bytes(_) => (
2659                    ColumnType::MYSQL_TYPE_VAR_STRING,
2660                    StmtExecuteParamFlags::empty(),
2661                ),
2662                Value::Int(_) => (
2663                    ColumnType::MYSQL_TYPE_LONGLONG,
2664                    StmtExecuteParamFlags::empty(),
2665                ),
2666                Value::UInt(_) => (
2667                    ColumnType::MYSQL_TYPE_LONGLONG,
2668                    StmtExecuteParamFlags::UNSIGNED,
2669                ),
2670                Value::Float(_) => (ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty()),
2671                Value::Double(_) => (
2672                    ColumnType::MYSQL_TYPE_DOUBLE,
2673                    StmtExecuteParamFlags::empty(),
2674                ),
2675                Value::Date(..) => (
2676                    ColumnType::MYSQL_TYPE_DATETIME,
2677                    StmtExecuteParamFlags::empty(),
2678                ),
2679                Value::Time(..) => (ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty()),
2680            };
2681
2682            buf.put_slice(&[column_type as u8, flags.bits()]);
2683        }
2684
2685        for param in &self.params {
2686            match **param {
2687                Value::Int(_)
2688                | Value::UInt(_)
2689                | Value::Float(_)
2690                | Value::Double(_)
2691                | Value::Date(..)
2692                | Value::Time(..) => {
2693                    param.serialize(buf);
2694                }
2695                Value::Bytes(_) if !self.as_long_data => {
2696                    param.serialize(buf);
2697                }
2698                Value::Bytes(_) | Value::NULL => {}
2699            }
2700        }
2701    }
2702}
2703
2704define_header!(
2705    ComStmtSendLongDataHeader,
2706    COM_STMT_SEND_LONG_DATA,
2707    InvalidComStmtSendLongDataHeader
2708);
2709
2710#[derive(Debug, Clone, Eq, PartialEq)]
2711pub struct ComStmtSendLongData<'a> {
2712    __header: ComStmtSendLongDataHeader,
2713    stmt_id: RawInt<LeU32>,
2714    param_index: RawInt<LeU16>,
2715    data: RawBytes<'a, EofBytes>,
2716}
2717
2718impl<'a> ComStmtSendLongData<'a> {
2719    pub fn new(stmt_id: u32, param_index: u16, data: impl Into<Cow<'a, [u8]>>) -> Self {
2720        Self {
2721            __header: ComStmtSendLongDataHeader::new(),
2722            stmt_id: RawInt::new(stmt_id),
2723            param_index: RawInt::new(param_index),
2724            data: RawBytes::new(data),
2725        }
2726    }
2727
2728    pub fn into_owned(self) -> ComStmtSendLongData<'static> {
2729        ComStmtSendLongData {
2730            __header: self.__header,
2731            stmt_id: self.stmt_id,
2732            param_index: self.param_index,
2733            data: self.data.into_owned(),
2734        }
2735    }
2736}
2737
2738impl MySerialize for ComStmtSendLongData<'_> {
2739    fn serialize(&self, buf: &mut Vec<u8>) {
2740        self.__header.serialize(&mut *buf);
2741        self.stmt_id.serialize(&mut *buf);
2742        self.param_index.serialize(&mut *buf);
2743        self.data.serialize(&mut *buf);
2744    }
2745}
2746
/// `COM_STMT_CLOSE` command: refers to a statement by its identifier.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct ComStmtClose {
    /// Identifier of the statement.
    pub stmt_id: u32,
}

impl ComStmtClose {
    /// Creates the command for the statement with the given identifier.
    pub fn new(stmt_id: u32) -> Self {
        ComStmtClose { stmt_id }
    }
}
2757
2758impl MySerialize for ComStmtClose {
2759    fn serialize(&self, buf: &mut Vec<u8>) {
2760        buf.put_u8(Command::COM_STMT_CLOSE as u8);
2761        buf.put_u32_le(self.stmt_id);
2762    }
2763}
2764
2765define_header!(
2766    ComRegisterSlaveHeader,
2767    COM_REGISTER_SLAVE,
2768    InvalidComRegisterSlaveHeader
2769);
2770
2771/// Registers a slave at the master. Should be sent before requesting a binlog events
2772/// with `COM_BINLOG_DUMP`.
2773#[derive(Debug, Clone, Eq, PartialEq, Hash)]
2774pub struct ComRegisterSlave<'a> {
2775    header: ComRegisterSlaveHeader,
2776    /// The slaves server-id.
2777    server_id: RawInt<LeU32>,
2778    /// The host name or IP address of the slave to be reported to the master during slave
2779    /// registration. Usually empty.
2780    hostname: RawBytes<'a, U8Bytes>,
2781    /// The account user name of the slave to be reported to the master during slave registration.
2782    /// Usually empty.
2783    ///
2784    /// # Note
2785    ///
2786    /// Serialization will truncate this value if length is greater than 255 bytes.
2787    user: RawBytes<'a, U8Bytes>,
2788    /// The account password of the slave to be reported to the master during slave registration.
2789    /// Usually empty.
2790    ///
2791    /// # Note
2792    ///
2793    /// Serialization will truncate this value if length is greater than 255 bytes.
2794    password: RawBytes<'a, U8Bytes>,
2795    /// The TCP/IP port number for connecting to the slave, to be reported to the master during
2796    /// slave registration. Usually empty.
2797    ///
2798    /// # Note
2799    ///
2800    /// Serialization will truncate this value if length is greater than 255 bytes.
2801    port: RawInt<LeU16>,
2802    /// Ignored.
2803    replication_rank: RawInt<LeU32>,
2804    /// Usually 0. Appears as "master id" in `SHOW SLAVE HOSTS` on the master. Unknown what else
2805    /// it impacts.
2806    master_id: RawInt<LeU32>,
2807}
2808
2809impl<'a> ComRegisterSlave<'a> {
2810    /// Creates new `ComRegisterSlave` with the given server identifier. Other fields will be empty.
2811    pub fn new(server_id: u32) -> Self {
2812        Self {
2813            header: Default::default(),
2814            server_id: RawInt::new(server_id),
2815            hostname: Default::default(),
2816            user: Default::default(),
2817            password: Default::default(),
2818            port: Default::default(),
2819            replication_rank: Default::default(),
2820            master_id: Default::default(),
2821        }
2822    }
2823
2824    /// Sets the `hostname` field of the packet (maximum length is 255 bytes).
2825    pub fn with_hostname(mut self, hostname: impl Into<Cow<'a, [u8]>>) -> Self {
2826        self.hostname = RawBytes::new(hostname);
2827        self
2828    }
2829
2830    /// Sets the `user` field of the packet (maximum length is 255 bytes).
2831    pub fn with_user(mut self, user: impl Into<Cow<'a, [u8]>>) -> Self {
2832        self.user = RawBytes::new(user);
2833        self
2834    }
2835
2836    /// Sets the `password` field of the packet (maximum length is 255 bytes).
2837    pub fn with_password(mut self, password: impl Into<Cow<'a, [u8]>>) -> Self {
2838        self.password = RawBytes::new(password);
2839        self
2840    }
2841
2842    /// Sets the `port` field of the packet.
2843    pub fn with_port(mut self, port: u16) -> Self {
2844        self.port = RawInt::new(port);
2845        self
2846    }
2847
2848    /// Sets the `replication_rank` field of the packet.
2849    pub fn with_replication_rank(mut self, replication_rank: u32) -> Self {
2850        self.replication_rank = RawInt::new(replication_rank);
2851        self
2852    }
2853
2854    /// Sets the `master_id` field of the packet.
2855    pub fn with_master_id(mut self, master_id: u32) -> Self {
2856        self.master_id = RawInt::new(master_id);
2857        self
2858    }
2859
2860    /// Returns the `server_id` field of the packet.
2861    pub fn server_id(&self) -> u32 {
2862        self.server_id.0
2863    }
2864
2865    /// Returns the raw `hostname` field value.
2866    pub fn hostname_raw(&self) -> &[u8] {
2867        self.hostname.as_bytes()
2868    }
2869
2870    /// Returns the `hostname` field as a UTF-8 string (lossy converted).
2871    pub fn hostname(&'a self) -> Cow<'a, str> {
2872        self.hostname.as_str()
2873    }
2874
2875    /// Returns the raw `user` field value.
2876    pub fn user_raw(&self) -> &[u8] {
2877        self.user.as_bytes()
2878    }
2879
2880    /// Returns the `user` field as a UTF-8 string (lossy converted).
2881    pub fn user(&'a self) -> Cow<'a, str> {
2882        self.user.as_str()
2883    }
2884
2885    /// Returns the raw `password` field value.
2886    pub fn password_raw(&self) -> &[u8] {
2887        self.password.as_bytes()
2888    }
2889
2890    /// Returns the `password` field as a UTF-8 string (lossy converted).
2891    pub fn password(&'a self) -> Cow<'a, str> {
2892        self.password.as_str()
2893    }
2894
2895    /// Returns the `port` field of the packet.
2896    pub fn port(&self) -> u16 {
2897        self.port.0
2898    }
2899
2900    /// Returns the `replication_rank` field of the packet.
2901    pub fn replication_rank(&self) -> u32 {
2902        self.replication_rank.0
2903    }
2904
2905    /// Returns the `master_id` field of the packet.
2906    pub fn master_id(&self) -> u32 {
2907        self.master_id.0
2908    }
2909}
2910
2911impl MySerialize for ComRegisterSlave<'_> {
2912    fn serialize(&self, buf: &mut Vec<u8>) {
2913        self.header.serialize(&mut *buf);
2914        self.server_id.serialize(&mut *buf);
2915        self.hostname.serialize(&mut *buf);
2916        self.user.serialize(&mut *buf);
2917        self.password.serialize(&mut *buf);
2918        self.port.serialize(&mut *buf);
2919        self.replication_rank.serialize(&mut *buf);
2920        self.master_id.serialize(&mut *buf);
2921    }
2922}
2923
2924impl<'de> MyDeserialize<'de> for ComRegisterSlave<'de> {
2925    const SIZE: Option<usize> = None;
2926    type Ctx = ();
2927
2928    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2929        let mut sbuf: ParseBuf<'_> = buf.parse(5)?;
2930        let header = sbuf.parse_unchecked(())?;
2931        let server_id = sbuf.parse_unchecked(())?;
2932
2933        let hostname = buf.parse(())?;
2934        let user = buf.parse(())?;
2935        let password = buf.parse(())?;
2936
2937        let mut sbuf: ParseBuf<'_> = buf.parse(10)?;
2938        let port = sbuf.parse_unchecked(())?;
2939        let replication_rank = sbuf.parse_unchecked(())?;
2940        let master_id = sbuf.parse_unchecked(())?;
2941
2942        Ok(Self {
2943            header,
2944            server_id,
2945            hostname,
2946            user,
2947            password,
2948            port,
2949            replication_rank,
2950            master_id,
2951        })
2952    }
2953}
2954
2955define_header!(
2956    ComTableDumpHeader,
2957    COM_TABLE_DUMP,
2958    InvalidComTableDumpHeader
2959);
2960
2961/// COM_TABLE_DUMP command.
2962#[derive(Debug, Clone, Eq, PartialEq, Hash)]
2963pub struct ComTableDump<'a> {
2964    header: ComTableDumpHeader,
2965    /// Database name.
2966    ///
2967    /// # Note
2968    ///
2969    /// Serialization will truncate this value if length is greater than 255 bytes.
2970    database: RawBytes<'a, U8Bytes>,
2971    /// Table name.
2972    ///
2973    /// # Note
2974    ///
2975    /// Serialization will truncate this value if length is greater than 255 bytes.
2976    table: RawBytes<'a, U8Bytes>,
2977}
2978
2979impl<'a> ComTableDump<'a> {
2980    /// Creates new instance.
2981    pub fn new(database: impl Into<Cow<'a, [u8]>>, table: impl Into<Cow<'a, [u8]>>) -> Self {
2982        Self {
2983            header: Default::default(),
2984            database: RawBytes::new(database),
2985            table: RawBytes::new(table),
2986        }
2987    }
2988
2989    /// Returns the raw `database` field value.
2990    pub fn database_raw(&self) -> &[u8] {
2991        self.database.as_bytes()
2992    }
2993
2994    /// Returns the `database` field value as a UTF-8 string (lossy converted).
2995    pub fn database(&self) -> Cow<'_, str> {
2996        self.database.as_str()
2997    }
2998
2999    /// Returns the raw `table` field value.
3000    pub fn table_raw(&self) -> &[u8] {
3001        self.table.as_bytes()
3002    }
3003
3004    /// Returns the `table` field value as a UTF-8 string (lossy converted).
3005    pub fn table(&self) -> Cow<'_, str> {
3006        self.table.as_str()
3007    }
3008}
3009
3010impl MySerialize for ComTableDump<'_> {
3011    fn serialize(&self, buf: &mut Vec<u8>) {
3012        self.header.serialize(&mut *buf);
3013        self.database.serialize(&mut *buf);
3014        self.table.serialize(&mut *buf);
3015    }
3016}
3017
3018impl<'de> MyDeserialize<'de> for ComTableDump<'de> {
3019    const SIZE: Option<usize> = None;
3020    type Ctx = ();
3021
3022    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3023        Ok(Self {
3024            header: buf.parse(())?,
3025            database: buf.parse(())?,
3026            table: buf.parse(())?,
3027        })
3028    }
3029}
3030
my_bitflags! {
    BinlogDumpFlags,
    #[error("Unknown flags in the raw value of BinlogDumpFlags (raw={0:b})")]
    UnknownBinlogDumpFlags,
    u16,

    /// Flags of the binlog dump commands (`ComBinlogDump` and `ComBinlogDumpGtid`).
    #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
    pub struct BinlogDumpFlags: u16 {
        /// If there is no more event to send a EOF_Packet instead of blocking the connection
        const BINLOG_DUMP_NON_BLOCK = 0x01;
        // NOTE(review): no description is available in this file for this flag.
        const BINLOG_THROUGH_POSITION = 0x02;
        // `ComBinlogDumpGtid` keeps this flag set exactly when its SID block is non-empty.
        const BINLOG_THROUGH_GTID = 0x04;
    }
}
3046
3047define_header!(
3048    ComBinlogDumpHeader,
3049    COM_BINLOG_DUMP,
3050    InvalidComBinlogDumpHeader
3051);
3052
3053/// Command to request a binlog-stream from the master starting a given position.
3054#[derive(Clone, Debug, Eq, PartialEq, Hash)]
3055pub struct ComBinlogDump<'a> {
3056    header: ComBinlogDumpHeader,
3057    /// Position in the binlog-file to start the stream with (`0` by default).
3058    pos: RawInt<LeU32>,
3059    /// Command flags (empty by default).
3060    ///
3061    /// Only `BINLOG_DUMP_NON_BLOCK` is supported for this command.
3062    flags: Const<BinlogDumpFlags, LeU16>,
3063    /// Server id of this slave.
3064    server_id: RawInt<LeU32>,
3065    /// Filename of the binlog on the master.
3066    ///
3067    /// If the binlog-filename is empty, the server will send the binlog-stream of the first known
3068    /// binlog.
3069    filename: RawBytes<'a, EofBytes>,
3070}
3071
3072impl<'a> ComBinlogDump<'a> {
3073    /// Creates new instance with default values for `pos` and `flags`.
3074    pub fn new(server_id: u32) -> Self {
3075        Self {
3076            header: Default::default(),
3077            pos: Default::default(),
3078            flags: Default::default(),
3079            server_id: RawInt::new(server_id),
3080            filename: Default::default(),
3081        }
3082    }
3083
3084    /// Defines position for this instance.
3085    pub fn with_pos(mut self, pos: u32) -> Self {
3086        self.pos = RawInt::new(pos);
3087        self
3088    }
3089
3090    /// Defines flags for this instance.
3091    pub fn with_flags(mut self, flags: BinlogDumpFlags) -> Self {
3092        self.flags = Const::new(flags);
3093        self
3094    }
3095
3096    /// Defines filename for this instance.
3097    pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3098        self.filename = RawBytes::new(filename);
3099        self
3100    }
3101
3102    /// Returns parsed `pos` field with unknown bits truncated.
3103    pub fn pos(&self) -> u32 {
3104        *self.pos
3105    }
3106
3107    /// Returns parsed `flags` field with unknown bits truncated.
3108    pub fn flags(&self) -> BinlogDumpFlags {
3109        *self.flags
3110    }
3111
3112    /// Returns parsed `server_id` field with unknown bits truncated.
3113    pub fn server_id(&self) -> u32 {
3114        *self.server_id
3115    }
3116
3117    /// Returns the raw `filename` field value.
3118    pub fn filename_raw(&self) -> &[u8] {
3119        self.filename.as_bytes()
3120    }
3121
3122    /// Returns the `filename` field value as a UTF-8 string (lossy converted).
3123    pub fn filename(&self) -> Cow<'_, str> {
3124        self.filename.as_str()
3125    }
3126}
3127
3128impl MySerialize for ComBinlogDump<'_> {
3129    fn serialize(&self, buf: &mut Vec<u8>) {
3130        self.header.serialize(&mut *buf);
3131        self.pos.serialize(&mut *buf);
3132        self.flags.serialize(&mut *buf);
3133        self.server_id.serialize(&mut *buf);
3134        self.filename.serialize(&mut *buf);
3135    }
3136}
3137
3138impl<'de> MyDeserialize<'de> for ComBinlogDump<'de> {
3139    const SIZE: Option<usize> = None;
3140    type Ctx = ();
3141
3142    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3143        let mut sbuf: ParseBuf<'_> = buf.parse(11)?;
3144        Ok(Self {
3145            header: sbuf.parse_unchecked(())?,
3146            pos: sbuf.parse_unchecked(())?,
3147            flags: sbuf.parse_unchecked(())?,
3148            server_id: sbuf.parse_unchecked(())?,
3149            filename: buf.parse(())?,
3150        })
3151    }
3152}
3153
3154/// GnoInterval. Stored within [`Sid`]
3155#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
3156pub struct GnoInterval {
3157    start: RawInt<LeU64>,
3158    end: RawInt<LeU64>,
3159}
3160
3161impl GnoInterval {
3162    /// Creates a new interval.
3163    pub fn new(start: u64, end: u64) -> Self {
3164        Self {
3165            start: RawInt::new(start),
3166            end: RawInt::new(end),
3167        }
3168    }
3169    /// Checks if the [start, end) interval is valid and creates it.
3170    pub fn check_and_new(start: u64, end: u64) -> io::Result<Self> {
3171        if start >= end {
3172            return Err(io::Error::new(
3173                io::ErrorKind::InvalidData,
3174                format!("start({}) >= end({}) in GnoInterval", start, end),
3175            ));
3176        }
3177        if start == 0 || end == 0 {
3178            return Err(io::Error::new(
3179                io::ErrorKind::InvalidData,
3180                "Gno can't be zero",
3181            ));
3182        }
3183        Ok(Self::new(start, end))
3184    }
3185}
3186
3187impl MySerialize for GnoInterval {
3188    fn serialize(&self, buf: &mut Vec<u8>) {
3189        self.start.serialize(&mut *buf);
3190        self.end.serialize(&mut *buf);
3191    }
3192}
3193
3194impl<'de> MyDeserialize<'de> for GnoInterval {
3195    const SIZE: Option<usize> = Some(16);
3196    type Ctx = ();
3197
3198    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3199        Ok(Self {
3200            start: buf.parse_unchecked(())?,
3201            end: buf.parse_unchecked(())?,
3202        })
3203    }
3204}
3205
3206/// Length of a Uuid in `COM_BINLOG_DUMP_GTID` command packet.
3207pub const UUID_LEN: usize = 16;
3208
3209/// SID is a part of the `COM_BINLOG_DUMP_GTID` command. It's a GtidSet whose
3210/// has only one Uuid.
3211#[derive(Debug, Clone, Eq, PartialEq, Hash)]
3212pub struct Sid<'a> {
3213    uuid: [u8; UUID_LEN],
3214    intervals: Seq<'a, GnoInterval, LeU64>,
3215}
3216
3217impl Sid<'_> {
3218    /// Creates a new instance.
3219    pub fn new(uuid: [u8; UUID_LEN]) -> Self {
3220        Self {
3221            uuid,
3222            intervals: Default::default(),
3223        }
3224    }
3225
3226    /// Returns the `uuid` field value.
3227    pub fn uuid(&self) -> [u8; UUID_LEN] {
3228        self.uuid
3229    }
3230
3231    /// Returns the `intervals` field value.
3232    pub fn intervals(&self) -> &[GnoInterval] {
3233        &self.intervals[..]
3234    }
3235
3236    /// Appends an GnoInterval to this block.
3237    pub fn with_interval(mut self, interval: GnoInterval) -> Self {
3238        let mut intervals = self.intervals.0.into_owned();
3239        intervals.push(interval);
3240        self.intervals = Seq::new(intervals);
3241        self
3242    }
3243
3244    /// Sets the `intevals` value for this block.
3245    pub fn with_intervals(mut self, intervals: Vec<GnoInterval>) -> Self {
3246        self.intervals = Seq::new(intervals);
3247        self
3248    }
3249
3250    fn len(&self) -> u64 {
3251        use saturating::Saturating as S;
3252        let mut len = S(UUID_LEN as u64); // SID
3253        len += S(8); // n_intervals
3254        len += S((self.intervals.len() * 16) as u64);
3255        len.0
3256    }
3257}
3258
3259impl MySerialize for Sid<'_> {
3260    fn serialize(&self, buf: &mut Vec<u8>) {
3261        self.uuid.serialize(&mut *buf);
3262        self.intervals.serialize(buf);
3263    }
3264}
3265
3266impl<'de> MyDeserialize<'de> for Sid<'de> {
3267    const SIZE: Option<usize> = None;
3268    type Ctx = ();
3269
3270    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3271        Ok(Self {
3272            uuid: buf.parse(())?,
3273            intervals: buf.parse(())?,
3274        })
3275    }
3276}
3277
3278impl Sid<'_> {
3279    fn wrap_err(msg: String) -> io::Error {
3280        io::Error::new(io::ErrorKind::InvalidInput, msg)
3281    }
3282
3283    fn parse_interval_num(to_parse: &str, full: &str) -> Result<u64, io::Error> {
3284        let n: u64 = to_parse.parse().map_err(|e| {
3285            Sid::wrap_err(format!(
3286                "invalid GnoInterval format: {}, error: {}",
3287                full, e
3288            ))
3289        })?;
3290        Ok(n)
3291    }
3292}
3293
3294impl FromStr for Sid<'_> {
3295    type Err = io::Error;
3296
3297    fn from_str(s: &str) -> Result<Self, Self::Err> {
3298        let (uuid, intervals) = s
3299            .split_once(':')
3300            .ok_or_else(|| Sid::wrap_err(format!("invalid sid format: {}", s)))?;
3301        let uuid = Uuid::parse_str(uuid)
3302            .map_err(|e| Sid::wrap_err(format!("invalid uuid format: {}, error: {}", s, e)))?;
3303        let intervals = intervals
3304            .split(':')
3305            .map(|interval| {
3306                let nums = interval.split('-').collect::<Vec<_>>();
3307                if nums.len() != 1 && nums.len() != 2 {
3308                    return Err(Sid::wrap_err(format!("invalid GnoInterval format: {}", s)));
3309                }
3310                if nums.len() == 1 {
3311                    let start = Sid::parse_interval_num(nums[0], s)?;
3312                    let interval = GnoInterval::check_and_new(start, start + 1)?;
3313                    Ok(interval)
3314                } else {
3315                    let start = Sid::parse_interval_num(nums[0], s)?;
3316                    let end = Sid::parse_interval_num(nums[1], s)?;
3317                    let interval = GnoInterval::check_and_new(start, end + 1)?;
3318                    Ok(interval)
3319                }
3320            })
3321            .collect::<Result<Vec<_>, _>>()?;
3322        Ok(Self {
3323            uuid: *uuid.as_bytes(),
3324            intervals: Seq::new(intervals),
3325        })
3326    }
3327}
3328
3329define_header!(
3330    ComBinlogDumpGtidHeader,
3331    COM_BINLOG_DUMP_GTID,
3332    InvalidComBinlogDumpGtidHeader
3333);
3334
3335/// Command to request a binlog-stream from the master starting a given position.
3336#[derive(Debug, Clone, Eq, PartialEq, Hash)]
3337pub struct ComBinlogDumpGtid<'a> {
3338    header: ComBinlogDumpGtidHeader,
3339    /// Command flags (empty by default).
3340    flags: Const<BinlogDumpFlags, LeU16>,
3341    /// Server id of this slave.
3342    server_id: RawInt<LeU32>,
3343    /// Filename of the binlog on the master.
3344    ///
3345    /// If the binlog-filename is empty, the server will send the binlog-stream of the first known
3346    /// binlog.
3347    ///
3348    /// # Note
3349    ///
3350    /// Serialization will truncate this value if length is greater than 2^32 - 1 bytes.
3351    filename: RawBytes<'a, U32Bytes>,
3352    /// Position in the binlog-file to start the stream with (`0` by default).
3353    pos: RawInt<LeU64>,
3354    /// SID block.
3355    sid_block: Seq<'a, Sid<'a>, LeU64>,
3356}
3357
3358impl<'a> ComBinlogDumpGtid<'a> {
3359    /// Creates new instance with default values for `pos`, `data` and `flags` fields.
3360    pub fn new(server_id: u32) -> Self {
3361        Self {
3362            header: Default::default(),
3363            pos: Default::default(),
3364            flags: Default::default(),
3365            server_id: RawInt::new(server_id),
3366            filename: Default::default(),
3367            sid_block: Default::default(),
3368        }
3369    }
3370
3371    /// Returns the `server_id` field value.
3372    pub fn server_id(&self) -> u32 {
3373        self.server_id.0
3374    }
3375
3376    /// Returns the `flags` field value.
3377    pub fn flags(&self) -> BinlogDumpFlags {
3378        self.flags.0
3379    }
3380
3381    /// Returns the `filename` field value.
3382    pub fn filename_raw(&self) -> &[u8] {
3383        self.filename.as_bytes()
3384    }
3385
3386    /// Returns the `filename` field value as a UTF-8 string (lossy converted).
3387    pub fn filename(&self) -> Cow<'_, str> {
3388        self.filename.as_str()
3389    }
3390
3391    /// Returns the `pos` field value.
3392    pub fn pos(&self) -> u64 {
3393        self.pos.0
3394    }
3395
3396    /// Returns the sequence of sids in this packet.
3397    pub fn sids(&self) -> &[Sid<'a>] {
3398        &self.sid_block
3399    }
3400
3401    /// Defines filename for this instance.
3402    pub fn with_filename(self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3403        Self {
3404            header: self.header,
3405            flags: self.flags,
3406            server_id: self.server_id,
3407            filename: RawBytes::new(filename),
3408            pos: self.pos,
3409            sid_block: self.sid_block,
3410        }
3411    }
3412
3413    /// Sets the `server_id` field value.
3414    pub fn with_server_id(mut self, server_id: u32) -> Self {
3415        self.server_id.0 = server_id;
3416        self
3417    }
3418
3419    /// Sets the `flags` field value.
3420    pub fn with_flags(mut self, mut flags: BinlogDumpFlags) -> Self {
3421        if self.sid_block.is_empty() {
3422            flags.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3423        } else {
3424            flags.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3425        }
3426        self.flags.0 = flags;
3427        self
3428    }
3429
3430    /// Sets the `pos` field value.
3431    pub fn with_pos(mut self, pos: u64) -> Self {
3432        self.pos.0 = pos;
3433        self
3434    }
3435
3436    /// Sets the `sid_block` field value.
3437    pub fn with_sid(mut self, sid: Sid<'a>) -> Self {
3438        self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3439        self.sid_block.push(sid);
3440        self
3441    }
3442
3443    /// Sets the `sid_block` field value.
3444    pub fn with_sids(mut self, sids: impl Into<Cow<'a, [Sid<'a>]>>) -> Self {
3445        self.sid_block = Seq::new(sids);
3446        if self.sid_block.is_empty() {
3447            self.flags.0.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3448        } else {
3449            self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3450        }
3451        self
3452    }
3453
3454    fn sid_block_len(&self) -> u32 {
3455        use saturating::Saturating as S;
3456        let mut len = S(8); // n_sids
3457        for sid in self.sid_block.iter() {
3458            len += S(sid.len() as u32);
3459        }
3460        len.0
3461    }
3462}
3463
3464impl MySerialize for ComBinlogDumpGtid<'_> {
3465    fn serialize(&self, buf: &mut Vec<u8>) {
3466        self.header.serialize(&mut *buf);
3467        self.flags.serialize(&mut *buf);
3468        self.server_id.serialize(&mut *buf);
3469        self.filename.serialize(&mut *buf);
3470        self.pos.serialize(&mut *buf);
3471        buf.put_u32_le(self.sid_block_len());
3472        self.sid_block.serialize(&mut *buf);
3473    }
3474}
3475
3476impl<'de> MyDeserialize<'de> for ComBinlogDumpGtid<'de> {
3477    const SIZE: Option<usize> = None;
3478    type Ctx = ();
3479
3480    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3481        let mut sbuf: ParseBuf<'_> = buf.parse(7)?;
3482        let header = sbuf.parse_unchecked(())?;
3483        let flags: Const<BinlogDumpFlags, LeU16> = sbuf.parse_unchecked(())?;
3484        let server_id = sbuf.parse_unchecked(())?;
3485
3486        let filename = buf.parse(())?;
3487        let pos = buf.parse(())?;
3488
3489        // `flags` should contain `BINLOG_THROUGH_GTID` flag if sid_block isn't empty
3490        let sid_data_len: RawInt<LeU32> = buf.parse(())?;
3491        let mut buf: ParseBuf<'_> = buf.parse(sid_data_len.0 as usize)?;
3492        let sid_block = buf.parse(())?;
3493
3494        Ok(Self {
3495            header,
3496            flags,
3497            server_id,
3498            filename,
3499            pos,
3500            sid_block,
3501        })
3502    }
3503}
3504
3505define_header!(
3506    SemiSyncAckPacketPacketHeader,
3507    InvalidSemiSyncAckPacketPacketHeader("Invalid semi-sync ack packet header"),
3508    0xEF
3509);
3510
3511/// Each Semi Sync Binlog Event with the `SEMI_SYNC_ACK_REQ` flag set the slave has to acknowledge
3512/// with Semi-Sync ACK packet.
3513pub struct SemiSyncAckPacket<'a> {
3514    header: SemiSyncAckPacketPacketHeader,
3515    position: RawInt<LeU64>,
3516    filename: RawBytes<'a, EofBytes>,
3517}
3518
3519impl<'a> SemiSyncAckPacket<'a> {
3520    pub fn new(position: u64, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3521        Self {
3522            header: Default::default(),
3523            position: RawInt::new(position),
3524            filename: RawBytes::new(filename),
3525        }
3526    }
3527
3528    /// Sets the `position` field value.
3529    pub fn with_position(mut self, position: u64) -> Self {
3530        self.position.0 = position;
3531        self
3532    }
3533
3534    /// Sets the `filename` field value.
3535    pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3536        self.filename = RawBytes::new(filename);
3537        self
3538    }
3539
3540    /// Returns the `position` field value.
3541    pub fn position(&self) -> u64 {
3542        self.position.0
3543    }
3544
3545    /// Returns the raw `filename` field value.
3546    pub fn filename_raw(&self) -> &[u8] {
3547        self.filename.as_bytes()
3548    }
3549
3550    /// Returns the `filename` field value as a string (lossy converted).
3551    pub fn filename(&self) -> Cow<'_, str> {
3552        self.filename.as_str()
3553    }
3554}
3555
3556impl MySerialize for SemiSyncAckPacket<'_> {
3557    fn serialize(&self, buf: &mut Vec<u8>) {
3558        self.header.serialize(&mut *buf);
3559        self.position.serialize(&mut *buf);
3560        self.filename.serialize(&mut *buf);
3561    }
3562}
3563
3564impl<'de> MyDeserialize<'de> for SemiSyncAckPacket<'de> {
3565    const SIZE: Option<usize> = None;
3566    type Ctx = ();
3567
3568    fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3569        let mut sbuf: ParseBuf<'_> = buf.parse(9)?;
3570        Ok(Self {
3571            header: sbuf.parse_unchecked(())?,
3572            position: sbuf.parse_unchecked(())?,
3573            filename: buf.parse(())?,
3574        })
3575    }
3576}
3577
3578#[cfg(test)]
3579mod test {
3580    use super::*;
3581    use crate::{
3582        constants::{CapabilityFlags, ColumnFlags, ColumnType, StatusFlags},
3583        proto::{MyDeserialize, MySerialize},
3584    };
3585
    // Property-based roundtrip tests: every test builds a command from generated inputs,
    // serializes it and compares the result of deserialization with the original value.
    proptest::proptest! {
        // `ComTableDump`: serialize, then deserialize and compare.
        #[test]
        fn com_table_dump_roundtrip(database: Vec<u8>, table: Vec<u8>) {
            let cmd = ComTableDump::new(database, table);

            let mut output = Vec::new();
            cmd.serialize(&mut output);

            assert_eq!(cmd, ComTableDump::deserialize((), &mut ParseBuf(&output[..]))?);
        }

        // `ComBinlogDump`: unknown bits of the generated `flags` are dropped by
        // `from_bits_truncate` before the command is built.
        #[test]
        fn com_binlog_dump_roundtrip(
            server_id: u32,
            filename: Vec<u8>,
            pos: u32,
            flags: u16,
        ) {
            let cmd = ComBinlogDump::new(server_id)
                .with_filename(filename)
                .with_pos(pos)
                .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));

            let mut output = Vec::new();
            cmd.serialize(&mut output);

            assert_eq!(cmd, ComBinlogDump::deserialize((), &mut ParseBuf(&output[..]))?);
        }

        // `ComRegisterSlave`: the string fields are generated with up to 256 characters.
        #[test]
        fn com_register_slave_roundtrip(
            server_id: u32,
            hostname in r"\w{0,256}",
            user in r"\w{0,256}",
            password in r"\w{0,256}",
            port: u16,
            replication_rank: u32,
            master_id: u32,
        ) {
            let cmd = ComRegisterSlave::new(server_id)
                .with_hostname(hostname.as_bytes())
                .with_user(user.as_bytes())
                .with_password(password.as_bytes())
                .with_port(port)
                .with_replication_rank(replication_rank)
                .with_master_id(master_id);

            let mut output = Vec::new();
            cmd.serialize(&mut output);
            let parsed = ComRegisterSlave::deserialize((), &mut ParseBuf(&output[..]))?;

            // A field longer than 255 bytes (`len()` counts bytes) is expected to break the
            // roundtrip; the packet docs say serialization truncates such values.
            if hostname.len() > 255 || user.len() > 255 || password.len() > 255 {
                assert_ne!(cmd, parsed);
            } else {
                assert_eq!(cmd, parsed);
            }
        }

        // `ComBinlogDumpGtid`: the command gets `n_sid_blocks` SIDs, where the i-th SID
        // holds `i` intervals.
        #[test]
        fn com_binlog_dump_gtid_roundtrip(
            flags: u16,
            server_id: u32,
            filename: Vec<u8>,
            pos: u64,
            n_sid_blocks in 0_u64..1024,
        ) {
            let mut cmd = ComBinlogDumpGtid::new(server_id)
                .with_filename(filename)
                .with_pos(pos)
                .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));

            let mut sids = Vec::new();
            for i in 0..n_sid_blocks {
                // `i as u8` wraps for `i > 255`, so Uuids repeat; this is fine for a
                // roundtrip check.
                let mut block = Sid::new([i as u8; 16]);
                for j in 0..i {
                    // `GnoInterval::new` does not validate the bounds.
                    block = block.with_interval(GnoInterval::new(i, j));
                }
                sids.push(block);
            }

            cmd = cmd.with_sids(sids);

            let mut output = Vec::new();
            cmd.serialize(&mut output);

            assert_eq!(cmd, ComBinlogDumpGtid::deserialize((), &mut ParseBuf(&output[..]))?);
        }
    }
3674
3675    #[test]
3676    fn should_parse_local_infile_packet() {
3677        const LIP: &[u8] = b"\xfbfile_name";
3678
3679        let lip = LocalInfilePacket::deserialize((), &mut ParseBuf(LIP)).unwrap();
3680        assert_eq!(lip.file_name_str(), "file_name");
3681    }
3682
3683    #[test]
3684    fn should_parse_stmt_packet() {
3685        const SP: &[u8] = b"\x00\x01\x00\x00\x00\x01\x00\x02\x00\x00\x00\x00";
3686        const SP_2: &[u8] = b"\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
3687
3688        let sp = StmtPacket::deserialize((), &mut ParseBuf(SP)).unwrap();
3689        assert_eq!(sp.statement_id(), 0x01);
3690        assert_eq!(sp.num_columns(), 0x01);
3691        assert_eq!(sp.num_params(), 0x02);
3692        assert_eq!(sp.warning_count(), 0x00);
3693
3694        let sp = StmtPacket::deserialize((), &mut ParseBuf(SP_2)).unwrap();
3695        assert_eq!(sp.statement_id(), 0x01);
3696        assert_eq!(sp.num_columns(), 0x00);
3697        assert_eq!(sp.num_params(), 0x00);
3698        assert_eq!(sp.warning_count(), 0x00);
3699    }
3700
3701    #[test]
3702    fn should_parse_handshake_packet() {
3703        const HSP: &[u8] = b"\x0a5.5.5-10.0.17-MariaDB-log\x00\x0b\x00\
3704                             \x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\
3705                             \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x2a\x34\x64\
3706                             \x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00";
3707
3708        const HSP_2: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
3709                               \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
3710                               \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
3711                               \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
3712                               \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
3713                               \x6f\x72\x64\x00";
3714
3715        const HSP_3: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
3716                                \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
3717                                \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
3718                                \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
3719                                \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
3720                                \x6f\x72\x64\x00";
3721
3722        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap();
3723        assert_eq!(hsp.protocol_version(), 0x0a);
3724        assert_eq!(hsp.server_version_str(), "5.5.5-10.0.17-MariaDB-log");
3725        assert_eq!(hsp.server_version_parsed(), Some((5, 5, 5)));
3726        assert_eq!(hsp.maria_db_server_version_parsed(), Some((10, 0, 17)));
3727        assert_eq!(hsp.connection_id(), 0x0b);
3728        assert_eq!(hsp.scramble_1_ref(), b"dvH@I-CJ");
3729        assert_eq!(
3730            hsp.capabilities(),
3731            CapabilityFlags::from_bits_truncate(0xf7ff)
3732        );
3733        assert_eq!(hsp.default_collation(), 0x08);
3734        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3735        assert_eq!(hsp.scramble_2_ref(), Some(&b"*4d|cZwk4^]:\x00"[..]));
3736        assert_eq!(hsp.auth_plugin_name_ref(), None);
3737
3738        let mut output = Vec::new();
3739        hsp.serialize(&mut output);
3740        assert_eq!(&output, HSP);
3741
3742        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_2)).unwrap();
3743        assert_eq!(hsp.protocol_version(), 0x0a);
3744        assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
3745        assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
3746        assert_eq!(hsp.maria_db_server_version_parsed(), None);
3747        assert_eq!(hsp.connection_id(), 0x0a56);
3748        assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
3749        assert_eq!(
3750            hsp.capabilities(),
3751            CapabilityFlags::from_bits_truncate(0xc00fffff)
3752        );
3753        assert_eq!(hsp.default_collation(), 0x08);
3754        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3755        assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
3756        assert_eq!(
3757            hsp.auth_plugin_name_ref(),
3758            Some(&b"mysql_native_password"[..])
3759        );
3760
3761        let mut output = Vec::new();
3762        hsp.serialize(&mut output);
3763        assert_eq!(&output, HSP_2);
3764
3765        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_3)).unwrap();
3766        assert_eq!(hsp.protocol_version(), 0x0a);
3767        assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
3768        assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
3769        assert_eq!(hsp.maria_db_server_version_parsed(), None);
3770        assert_eq!(hsp.connection_id(), 0x0a56);
3771        assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
3772        assert_eq!(
3773            hsp.capabilities(),
3774            CapabilityFlags::from_bits_truncate(0xc00fffff)
3775        );
3776        assert_eq!(hsp.default_collation(), 0x08);
3777        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3778        assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
3779        assert_eq!(
3780            hsp.auth_plugin_name_ref(),
3781            Some(&b"mysql_native_password"[..])
3782        );
3783
3784        let mut output = Vec::new();
3785        hsp.serialize(&mut output);
3786        assert_eq!(&output, HSP_3);
3787    }
3788
3789    #[test]
3790    fn should_parse_err_packet() {
3791        const ERR_PACKET: &[u8] = b"\xff\x48\x04\x23\x48\x59\x30\x30\x30\x4e\x6f\x20\x74\x61\x62\
3792        \x6c\x65\x73\x20\x75\x73\x65\x64";
3793        const ERR_PACKET_NO_STATE: &[u8] = b"\xff\x10\x04\x54\x6f\x6f\x20\x6d\x61\x6e\x79\x20\x63\
3794        \x6f\x6e\x6e\x65\x63\x74\x69\x6f\x6e\x73";
3795        const PROGRESS_PACKET: &[u8] = b"\xff\xff\xff\x01\x01\x0a\xcc\x5b\x00\x0astage name";
3796
3797        let err_packet = ErrPacket::deserialize(
3798            CapabilityFlags::CLIENT_PROTOCOL_41,
3799            &mut ParseBuf(ERR_PACKET),
3800        )
3801        .unwrap();
3802        let err_packet = err_packet.server_error();
3803        assert_eq!(err_packet.error_code(), 1096);
3804        assert_eq!(err_packet.sql_state_ref().unwrap().as_str(), "HY000");
3805        assert_eq!(err_packet.message_str(), "No tables used");
3806
3807        let err_packet =
3808            ErrPacket::deserialize(CapabilityFlags::empty(), &mut ParseBuf(ERR_PACKET_NO_STATE))
3809                .unwrap();
3810        let server_error = err_packet.server_error();
3811        assert_eq!(server_error.error_code(), 1040);
3812        assert_eq!(server_error.sql_state_ref(), None);
3813        assert_eq!(server_error.message_str(), "Too many connections");
3814
3815        let err_packet = ErrPacket::deserialize(
3816            CapabilityFlags::CLIENT_PROGRESS_OBSOLETE,
3817            &mut ParseBuf(PROGRESS_PACKET),
3818        )
3819        .unwrap();
3820        assert!(err_packet.is_progress_report());
3821        let progress_report = err_packet.progress_report();
3822        assert_eq!(progress_report.stage(), 1);
3823        assert_eq!(progress_report.max_stage(), 10);
3824        assert_eq!(progress_report.progress(), 23500);
3825        assert_eq!(progress_report.stage_info_str(), "stage name");
3826    }
3827
3828    #[test]
3829    fn should_parse_column_packet() {
3830        const COLUMN_PACKET: &[u8] = b"\x03def\x06schema\x05table\x09org_table\x04name\
3831              \x08org_name\x0c\x21\x00\x0F\x00\x00\x00\x00\x01\x00\x08\x00\x00";
3832        let column = Column::deserialize((), &mut ParseBuf(COLUMN_PACKET)).unwrap();
3833        assert_eq!(column.schema_str(), "schema");
3834        assert_eq!(column.table_str(), "table");
3835        assert_eq!(column.org_table_str(), "org_table");
3836        assert_eq!(column.name_str(), "name");
3837        assert_eq!(column.org_name_str(), "org_name");
3838        assert_eq!(
3839            column.character_set(),
3840            CollationId::UTF8MB3_GENERAL_CI as u16
3841        );
3842        assert_eq!(column.column_length(), 15);
3843        assert_eq!(column.column_type(), ColumnType::MYSQL_TYPE_DECIMAL);
3844        assert_eq!(column.flags(), ColumnFlags::NOT_NULL_FLAG);
3845        assert_eq!(column.decimals(), 8);
3846    }
3847
3848    #[test]
3849    fn should_parse_auth_switch_request() {
3850        const PAYLOAD: &[u8] = b"\xfe\x6d\x79\x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\
3851                                 \x73\x73\x77\x6f\x72\x64\x00\x7a\x51\x67\x34\x69\x36\x6f\x4e\x79\
3852                                 \x36\x3d\x72\x48\x4e\x2f\x3e\x2d\x62\x29\x41\x00";
3853        let packet = AuthSwitchRequest::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();
3854        assert_eq!(packet.auth_plugin().as_bytes(), b"mysql_native_password",);
3855        assert_eq!(packet.plugin_data(), b"zQg4i6oNy6=rHN/>-b)A",)
3856    }
3857
3858    #[test]
3859    fn should_parse_auth_more_data() {
3860        const PAYLOAD: &[u8] = b"\x01\x04";
3861        let packet = AuthMoreData::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();
3862        assert_eq!(packet.data(), b"\x04",);
3863    }
3864
3865    #[test]
3866    fn should_parse_ok_packet() {
3867        const PLAIN_OK: &[u8] = b"\x00\x01\x00\x02\x00\x00\x00";
3868        const RESULT_SET_TERMINATOR: &[u8] = &[
3869            0xfe, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x42, 0x52, 0x65, 0x61, 0x64, 0x20, 0x31,
3870            0x20, 0x72, 0x6f, 0x77, 0x73, 0x2c, 0x20, 0x31, 0x2e, 0x30, 0x30, 0x20, 0x42, 0x20,
3871            0x69, 0x6e, 0x20, 0x30, 0x2e, 0x30, 0x30, 0x32, 0x20, 0x73, 0x65, 0x63, 0x2e, 0x2c,
3872            0x20, 0x36, 0x31, 0x31, 0x2e, 0x33, 0x34, 0x20, 0x72, 0x6f, 0x77, 0x73, 0x2f, 0x73,
3873            0x65, 0x63, 0x2e, 0x2c, 0x20, 0x36, 0x31, 0x31, 0x2e, 0x33, 0x34, 0x20, 0x42, 0x2f,
3874            0x73, 0x65, 0x63, 0x2e,
3875        ];
3876        const SESS_STATE_SYS_VAR_OK: &[u8] =
3877            b"\x00\x00\x00\x02\x40\x00\x00\x00\x11\x00\x0f\x0a\x61\
3878              \x75\x74\x6f\x63\x6f\x6d\x6d\x69\x74\x03\x4f\x46\x46";
3879        const SESS_STATE_SCHEMA_OK: &[u8] =
3880            b"\x00\x00\x00\x02\x40\x00\x00\x00\x07\x01\x05\x04\x74\x65\x73\x74";
3881        const SESS_STATE_TRACK_OK: &[u8] = b"\x00\x00\x00\x02\x40\x00\x00\x00\x04\x02\x02\x01\x31";
3882        const EOF: &[u8] = b"\xfe\x00\x00\x02\x00";
3883
3884        // packet starting with 0x00 is not an ok packet if it terminates a result set
3885        OkPacketDeserializer::<ResultSetTerminator>::deserialize(
3886            CapabilityFlags::empty(),
3887            &mut ParseBuf(PLAIN_OK),
3888        )
3889        .unwrap_err();
3890
3891        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3892            CapabilityFlags::empty(),
3893            &mut ParseBuf(PLAIN_OK),
3894        )
3895        .unwrap()
3896        .into();
3897        assert_eq!(ok_packet.affected_rows(), 1);
3898        assert_eq!(ok_packet.last_insert_id(), None);
3899        assert_eq!(
3900            ok_packet.status_flags(),
3901            StatusFlags::SERVER_STATUS_AUTOCOMMIT
3902        );
3903        assert_eq!(ok_packet.warnings(), 0);
3904        assert_eq!(ok_packet.info_ref(), None);
3905        assert_eq!(ok_packet.session_state_info_ref(), None);
3906
3907        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3908            CapabilityFlags::CLIENT_SESSION_TRACK,
3909            &mut ParseBuf(PLAIN_OK),
3910        )
3911        .unwrap()
3912        .into();
3913        assert_eq!(ok_packet.affected_rows(), 1);
3914        assert_eq!(ok_packet.last_insert_id(), None);
3915        assert_eq!(
3916            ok_packet.status_flags(),
3917            StatusFlags::SERVER_STATUS_AUTOCOMMIT
3918        );
3919        assert_eq!(ok_packet.warnings(), 0);
3920        assert_eq!(ok_packet.info_ref(), None);
3921        assert_eq!(ok_packet.session_state_info_ref(), None);
3922
3923        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<ResultSetTerminator>::deserialize(
3924            CapabilityFlags::CLIENT_SESSION_TRACK,
3925            &mut ParseBuf(RESULT_SET_TERMINATOR),
3926        )
3927        .unwrap()
3928        .into();
3929        assert_eq!(ok_packet.affected_rows(), 0);
3930        assert_eq!(ok_packet.last_insert_id(), None);
3931        assert_eq!(ok_packet.status_flags(), StatusFlags::empty());
3932        assert_eq!(ok_packet.warnings(), 0);
3933        assert_eq!(
3934            ok_packet.info_str(),
3935            Some(Cow::Borrowed(
3936                "Read 1 rows, 1.00 B in 0.002 sec., 611.34 rows/sec., 611.34 B/sec."
3937            ))
3938        );
3939        assert_eq!(ok_packet.session_state_info_ref(), None);
3940
3941        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3942            CapabilityFlags::CLIENT_SESSION_TRACK,
3943            &mut ParseBuf(SESS_STATE_SYS_VAR_OK),
3944        )
3945        .unwrap()
3946        .into();
3947        assert_eq!(ok_packet.affected_rows(), 0);
3948        assert_eq!(ok_packet.last_insert_id(), None);
3949        assert_eq!(
3950            ok_packet.status_flags(),
3951            StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3952        );
3953        assert_eq!(ok_packet.warnings(), 0);
3954        assert_eq!(ok_packet.info_ref(), None);
3955        let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3956
3957        match sess_state_info.decode().unwrap() {
3958            SessionStateChange::SystemVariables(mut vals) => {
3959                let val = vals.pop().unwrap();
3960                assert_eq!(val.name_bytes(), b"autocommit");
3961                assert_eq!(val.value_bytes(), b"OFF");
3962                assert!(vals.is_empty());
3963            }
3964            _ => panic!(),
3965        }
3966
3967        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3968            CapabilityFlags::CLIENT_SESSION_TRACK,
3969            &mut ParseBuf(SESS_STATE_SCHEMA_OK),
3970        )
3971        .unwrap()
3972        .into();
3973        assert_eq!(ok_packet.affected_rows(), 0);
3974        assert_eq!(ok_packet.last_insert_id(), None);
3975        assert_eq!(
3976            ok_packet.status_flags(),
3977            StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3978        );
3979        assert_eq!(ok_packet.warnings(), 0);
3980        assert_eq!(ok_packet.info_ref(), None);
3981        let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3982        match sess_state_info.decode().unwrap() {
3983            SessionStateChange::Schema(schema) => assert_eq!(schema.as_bytes(), b"test"),
3984            _ => panic!(),
3985        }
3986
3987        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3988            CapabilityFlags::CLIENT_SESSION_TRACK,
3989            &mut ParseBuf(SESS_STATE_TRACK_OK),
3990        )
3991        .unwrap()
3992        .into();
3993        assert_eq!(ok_packet.affected_rows(), 0);
3994        assert_eq!(ok_packet.last_insert_id(), None);
3995        assert_eq!(
3996            ok_packet.status_flags(),
3997            StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3998        );
3999        assert_eq!(ok_packet.warnings(), 0);
4000        assert_eq!(ok_packet.info_ref(), None);
4001        let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
4002        assert_eq!(
4003            sess_state_info.decode().unwrap(),
4004            SessionStateChange::IsTracked(true),
4005        );
4006
4007        let ok_packet: OkPacket<'_> = OkPacketDeserializer::<OldEofPacket>::deserialize(
4008            CapabilityFlags::CLIENT_SESSION_TRACK,
4009            &mut ParseBuf(EOF),
4010        )
4011        .unwrap()
4012        .into();
4013        assert_eq!(ok_packet.affected_rows(), 0);
4014        assert_eq!(ok_packet.last_insert_id(), None);
4015        assert_eq!(
4016            ok_packet.status_flags(),
4017            StatusFlags::SERVER_STATUS_AUTOCOMMIT
4018        );
4019        assert_eq!(ok_packet.warnings(), 0);
4020        assert_eq!(ok_packet.info_ref(), None);
4021        assert_eq!(ok_packet.session_state_info_ref(), None);
4022    }
4023
4024    #[test]
4025    fn should_build_handshake_response() {
4026        let flags_without_db_name = CapabilityFlags::from_bits_truncate(0x81aea205);
4027        let response = HandshakeResponse::new(
4028            Some(&[][..]),
4029            (5u16, 5, 5),
4030            Some(&b"root"[..]),
4031            None::<&'static [u8]>,
4032            Some(AuthPlugin::MysqlNativePassword),
4033            flags_without_db_name,
4034            None,
4035            1_u32.to_be(),
4036        );
4037        let mut actual = Vec::new();
4038        response.serialize(&mut actual);
4039
4040        let expected: Vec<u8> = [
4041            0x05, 0xa2, 0xae, 0x81, // client capabilities
4042            0x00, 0x00, 0x00, 0x01, // max packet
4043            0x2d, // charset
4044            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4045            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved
4046            0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
4047            0x00, // blank scramble
4048            0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4049            0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, // mysql_native_password
4050        ]
4051        .to_vec();
4052
4053        assert_eq!(expected, actual);
4054
4055        let flags_with_db_name = flags_without_db_name | CapabilityFlags::CLIENT_CONNECT_WITH_DB;
4056        let response = HandshakeResponse::new(
4057            Some(&[][..]),
4058            (5u16, 5, 5),
4059            Some(&b"root"[..]),
4060            Some(&b"mydb"[..]),
4061            Some(AuthPlugin::MysqlNativePassword),
4062            flags_with_db_name,
4063            None,
4064            1_u32.to_be(),
4065        );
4066        let mut actual = Vec::new();
4067        response.serialize(&mut actual);
4068
4069        let expected: Vec<u8> = [
4070            0x0d, 0xa2, 0xae, 0x81, // client capabilities
4071            0x00, 0x00, 0x00, 0x01, // max packet
4072            0x2d, // charset
4073            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4074            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved
4075            0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
4076            0x00, // blank scramble
4077            0x6d, 0x79, 0x64, 0x62, 0x00, // dbname
4078            0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4079            0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, // mysql_native_password
4080        ]
4081        .to_vec();
4082
4083        assert_eq!(expected, actual);
4084
4085        let response = HandshakeResponse::new(
4086            Some(&[][..]),
4087            (5u16, 5, 5),
4088            Some(&b"root"[..]),
4089            Some(&b"mydb"[..]),
4090            Some(AuthPlugin::MysqlNativePassword),
4091            flags_without_db_name,
4092            None,
4093            1_u32.to_be(),
4094        );
4095        let mut actual = Vec::new();
4096        response.serialize(&mut actual);
4097        assert_eq!(expected, actual);
4098
4099        let response = HandshakeResponse::new(
4100            Some(&[][..]),
4101            (5u16, 5, 5),
4102            Some(&b"root"[..]),
4103            Some(&[][..]),
4104            Some(AuthPlugin::MysqlNativePassword),
4105            flags_with_db_name,
4106            None,
4107            1_u32.to_be(),
4108        );
4109        let mut actual = Vec::new();
4110        response.serialize(&mut actual);
4111
4112        let expected: Vec<u8> = [
4113            0x0d, 0xa2, 0xae, 0x81, // client capabilities
4114            0x00, 0x00, 0x00, 0x01, // max packet
4115            0x2d, // charset
4116            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4117            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved
4118            0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
4119            0x00, // blank db_name
4120            0x00, // blank scramble
4121            0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4122            0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, // mysql_native_password
4123        ]
4124        .to_vec();
4125        assert_eq!(expected, actual);
4126    }
4127
4128    #[test]
4129    fn should_parse_handshake_packet_with_mariadb_ext_capabilities() {
4130        const HSP: &[u8] = b"\x0a5.5.5-11.4.7-MariaDB-log\x00\x0b\x00\
4131                             \x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\
4132                             \x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x2a\x34\x64\
4133                             \x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00";
4134
4135        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap();
4136        assert_eq!(hsp.protocol_version(), 0x0a);
4137        assert_eq!(hsp.server_version_str(), "5.5.5-11.4.7-MariaDB-log");
4138        assert_eq!(hsp.server_version_parsed(), Some((5, 5, 5)));
4139        assert_eq!(hsp.maria_db_server_version_parsed(), Some((11, 4, 7)));
4140        assert_eq!(hsp.connection_id(), 0x0b);
4141        assert_eq!(hsp.scramble_1_ref(), b"dvH@I-CJ");
4142        assert_eq!(
4143            hsp.capabilities(),
4144            CapabilityFlags::from_bits_truncate(0xf7ff)
4145        );
4146        assert_eq!(hsp.default_collation(), 0x08);
4147        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
4148        assert_eq!(hsp.scramble_2_ref(), Some(&b"*4d|cZwk4^]:\x00"[..]));
4149        assert_eq!(hsp.auth_plugin_name_ref(), None);
4150        assert_eq!(
4151            hsp.mariadb_ext_capabilities(),
4152            MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA
4153        );
4154        let mut output = Vec::new();
4155        hsp.serialize(&mut output);
4156        assert_eq!(&output, HSP);
4157    }
4158
4159    #[test]
4160    fn should_build_handshake_response_with_mariadb_capabilities() {
4161        let flags_without_db_name = CapabilityFlags::from_bits_truncate(0x81aea205);
4162        let response = HandshakeResponse::new(
4163            Some(&[][..]),
4164            (5u16, 5, 5),
4165            Some(&b"root"[..]),
4166            None::<&'static [u8]>,
4167            Some(AuthPlugin::MysqlNativePassword),
4168            flags_without_db_name,
4169            None,
4170            1_u32.to_be(),
4171        )
4172        .with_mariadb_ext_capabilities(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA);
4173        let mut actual = Vec::new();
4174        response.serialize(&mut actual);
4175
4176        let expected: Vec<u8> = [
4177            0x05, 0xa2, 0xae, 0x81, // client capabilities
4178            0x00, 0x00, 0x00, 0x01, // max packet
4179            0x2d, // charset
4180            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4181            0x00, 0x00, 0x00, 0x00, 0x00, // reserved
4182            0x10, 0x00, 0x00, 0x00, // mariadb capabilities
4183            0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
4184            0x00, // blank scramble
4185            0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4186            0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, // mysql_native_password
4187        ]
4188        .to_vec();
4189
4190        assert_eq!(expected, actual);
4191    }
4192
4193    #[test]
4194    fn parse_str_to_sid() {
4195        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:23";
4196        let sid = input.parse::<Sid<'_>>().unwrap();
4197        let expected_sid = Uuid::parse_str("3E11FA47-71CA-11E1-9E33-C80AA9429562").unwrap();
4198        assert_eq!(sid.uuid, *expected_sid.as_bytes());
4199        assert_eq!(sid.intervals.len(), 1);
4200        assert_eq!(sid.intervals[0].start.0, 23);
4201        assert_eq!(sid.intervals[0].end.0, 24);
4202
4203        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15";
4204        let sid = input.parse::<Sid<'_>>().unwrap();
4205        assert_eq!(sid.uuid, *expected_sid.as_bytes());
4206        assert_eq!(sid.intervals.len(), 2);
4207        assert_eq!(sid.intervals[0].start.0, 1);
4208        assert_eq!(sid.intervals[0].end.0, 6);
4209        assert_eq!(sid.intervals[1].start.0, 10);
4210        assert_eq!(sid.intervals[1].end.0, 16);
4211
4212        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562";
4213        let e = input.parse::<Sid<'_>>().unwrap_err();
4214        assert_eq!(
4215            e.to_string(),
4216            "invalid sid format: 3E11FA47-71CA-11E1-9E33-C80AA9429562".to_string()
4217        );
4218
4219        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-";
4220        let e = input.parse::<Sid<'_>>().unwrap_err();
4221        assert_eq!(e.to_string(), "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-, error: cannot parse integer from empty string".to_string());
4222
4223        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa";
4224        let e = input.parse::<Sid<'_>>().unwrap_err();
4225        assert_eq!(e.to_string(), "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa, error: invalid digit found in string".to_string());
4226
4227        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:0-3";
4228        let e = input.parse::<Sid<'_>>().unwrap_err();
4229        assert_eq!(e.to_string(), "Gno can't be zero".to_string());
4230
4231        let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:4-3";
4232        let e = input.parse::<Sid<'_>>().unwrap_err();
4233        assert_eq!(
4234            e.to_string(),
4235            "start(4) >= end(4) in GnoInterval".to_string()
4236        );
4237    }
4238
4239    #[test]
4240    fn should_parse_rsa_public_key_response_packet() {
4241        const PUBLIC_RSA_KEY_RESPONSE: &[u8] = b"\x01test";
4242
4243        let rsa_public_key_response =
4244            PublicKeyResponse::deserialize((), &mut ParseBuf(PUBLIC_RSA_KEY_RESPONSE));
4245
4246        assert!(rsa_public_key_response.is_ok());
4247        assert_eq!(rsa_public_key_response.unwrap().rsa_key(), "test");
4248    }
4249
4250    #[test]
4251    fn should_build_rsa_public_key_response_packet() {
4252        let rsa_public_key_response = PublicKeyResponse::new("test".as_bytes());
4253
4254        let mut actual = Vec::new();
4255        rsa_public_key_response.serialize(&mut actual);
4256
4257        let expected = b"\x01test".to_vec();
4258
4259        assert_eq!(expected, actual);
4260    }
4261}