1use btoi::btoi;
10use bytes::BufMut;
11use regex::bytes::Regex;
12use uuid::Uuid;
13
14use std::str::FromStr;
15use std::sync::Arc;
16use std::{
17 borrow::Cow, cmp::max, collections::HashMap, convert::TryFrom, fmt, io, marker::PhantomData,
18};
19
20use crate::collations::CollationId;
21use crate::{
22 constants::{
23 CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, SessionStateType,
24 StatusFlags, StmtExecuteParamFlags, StmtExecuteParamsFlags, MAX_PAYLOAD_LEN,
25 },
26 io::{BufMutExt, ParseBuf},
27 misc::{
28 lenenc_str_len,
29 raw::{
30 bytes::{
31 BareBytes, ConstBytes, ConstBytesValue, EofBytes, LenEnc, NullBytes, U32Bytes,
32 U8Bytes,
33 },
34 int::{ConstU32, ConstU8, LeU16, LeU24, LeU32, LeU32LowerHalf, LeU32UpperHalf, LeU64},
35 seq::Seq,
36 Const, Either, RawBytes, RawConst, RawInt, Skip,
37 },
38 unexpected_buf_eof,
39 },
40 proto::{MyDeserialize, MySerialize},
41 value::{ClientSide, SerializationSide, Value},
42};
43
44use self::session_state_change::SessionStateChange;
45
46lazy_static::lazy_static! {
47 static ref MARIADB_VERSION_RE: Regex =
48 Regex::new(r"^(?:5.5.5-)?(\d{1,2})\.(\d{1,2})\.(\d{1,3})-MariaDB").unwrap();
49 static ref VERSION_RE: Regex = Regex::new(r"^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)").unwrap();
50}
51
// Declares a single-byte packet header type.
//
// Both arms generate a unit error struct `$err` (deriving `thiserror::Error`) and a type alias
// `$name = ConstU8<$err, VALUE>`, i.e. a one-byte constant whose parse failure is `$err`.
//
// * `define_header!(Name, Err("message"), 0xNN)` uses a literal byte and a custom message.
// * `define_header!(Name, CmdVariant, Err)` uses `Command::CmdVariant as u8` and the message
//   "Invalid header for CmdVariant".
macro_rules! define_header {
    ($name:ident, $err:ident($msg:literal), $val:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($msg)]
        pub struct $err;
        pub type $name = crate::misc::raw::int::ConstU8<$err, $val>;
    };
    ($name:ident, $cmd:ident, $err:ident) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error("Invalid header for {}", stringify!($cmd))]
        pub struct $err;
        pub type $name = crate::misc::raw::int::ConstU8<$err, { Command::$cmd as u8 }>;
    };
}
66
// Declares a constant wire value of an arbitrary constant kind (e.g. `ConstU8`).
//
// Generates a unit error struct `$err` with message `$msg` and the alias
// `$name = $kind<$err, $val>`.
macro_rules! define_const {
    ($kind:ident, $name:ident, $err:ident($msg:literal), $val:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($msg)]
        pub struct $err;
        pub type $name = $kind<$err, $val>;
    };
}
75
// Declares a constant byte-string wire value of length `$len`.
//
// Generates:
// * an error struct `$err` with message `$msg`;
// * a marker struct `$vname` implementing `ConstBytesValue<$len>` with `VALUE = $val` and
//   `Error = $err`;
// * the alias `$name = ConstBytes<$vname, $len>`.
macro_rules! define_const_bytes {
    ($vname:ident, $name:ident, $err:ident($msg:literal), $val:expr, $len:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($msg)]
        pub struct $err;

        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
        pub struct $vname;

        impl ConstBytesValue<$len> for $vname {
            const VALUE: [u8; $len] = $val;
            type Error = $err;
        }

        pub type $name = ConstBytes<$vname, $len>;
    };
}
93
// Packet submodules. `session_state_change` provides the `SessionStateChange` type that
// `SessionStateInfo::decode` produces.
pub mod binlog_request;
pub mod caching_sha2_password;
pub mod session_state_change;
97
// The catalog field of a column definition: the constant 4 bytes `\x03def` (a length byte of 3
// followed by `def`). Presumably any other value is rejected with `InvalidCatalog` when parsed.
define_const_bytes!(
    Catalog,
    ColumnDefinitionCatalog,
    InvalidCatalog("Invalid catalog value in the column definition"),
    *b"\x03def",
    4
);
105
// The one-byte marker `0x0c` (12) that precedes the fixed-length part of a column definition.
// `Column`'s deserializer reads it as the first of 13 trailing bytes (1 + 12).
define_const!(
    ConstU8,
    FixedLengthFieldsLen,
    InvalidFixedLengthFieldsLen("Invalid fixed length field length in the column definition"),
    0x0c
);
112
/// Name metadata of a column definition: five length-encoded byte strings, parsed and written
/// in field order (schema, table, org_table, name, org_name).
#[derive(Debug, Default, Clone, Eq, PartialEq)]
struct ColumnMeta<'a> {
    /// Schema (database) name.
    schema: RawBytes<'a, LenEnc>,
    /// Table name as it appears in the query (possibly an alias).
    table: RawBytes<'a, LenEnc>,
    /// Original table name.
    org_table: RawBytes<'a, LenEnc>,
    /// Column name as it appears in the query (possibly an alias).
    name: RawBytes<'a, LenEnc>,
    /// Original column name.
    org_name: RawBytes<'a, LenEnc>,
}
122
123impl ColumnMeta<'_> {
124 pub fn into_owned(self) -> ColumnMeta<'static> {
125 ColumnMeta {
126 schema: self.schema.into_owned(),
127 table: self.table.into_owned(),
128 org_table: self.org_table.into_owned(),
129 name: self.name.into_owned(),
130 org_name: self.org_name.into_owned(),
131 }
132 }
133
134 pub fn schema_ref(&self) -> &[u8] {
136 self.schema.as_bytes()
137 }
138
139 pub fn schema_str(&self) -> Cow<'_, str> {
141 String::from_utf8_lossy(self.schema_ref())
142 }
143
144 pub fn table_ref(&self) -> &[u8] {
146 self.table.as_bytes()
147 }
148
149 pub fn table_str(&self) -> Cow<'_, str> {
151 String::from_utf8_lossy(self.table_ref())
152 }
153
154 pub fn org_table_ref(&self) -> &[u8] {
158 self.org_table.as_bytes()
159 }
160
161 pub fn org_table_str(&self) -> Cow<'_, str> {
163 String::from_utf8_lossy(self.org_table_ref())
164 }
165
166 pub fn name_ref(&self) -> &[u8] {
168 self.name.as_bytes()
169 }
170
171 pub fn name_str(&self) -> Cow<'_, str> {
173 String::from_utf8_lossy(self.name_ref())
174 }
175
176 pub fn org_name_ref(&self) -> &[u8] {
180 self.org_name.as_bytes()
181 }
182
183 pub fn org_name_str(&self) -> Cow<'_, str> {
185 String::from_utf8_lossy(self.org_name_ref())
186 }
187}
188
189impl<'de> MyDeserialize<'de> for ColumnMeta<'de> {
190 const SIZE: Option<usize> = None;
191 type Ctx = ();
192
193 fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
194 Ok(Self {
195 schema: buf.parse_unchecked(())?,
196 table: buf.parse_unchecked(())?,
197 org_table: buf.parse_unchecked(())?,
198 name: buf.parse_unchecked(())?,
199 org_name: buf.parse_unchecked(())?,
200 })
201 }
202}
203
204impl MySerialize for ColumnMeta<'_> {
205 fn serialize(&self, buf: &mut Vec<u8>) {
206 self.schema.serialize(&mut *buf);
207 self.table.serialize(&mut *buf);
208 self.org_table.serialize(&mut *buf);
209 self.name.serialize(&mut *buf);
210 self.org_name.serialize(&mut *buf);
211 }
212}
213
/// A column definition: the constant catalog, the name metadata, and a fixed-length trailer
/// (marker, character set, column length, type, flags, decimals, 2-byte filler).
///
/// The name metadata sits behind an `Arc`, so cloning a `Column` does not copy the names.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Column {
    /// Always the constant `\x03def`.
    catalog: ColumnDefinitionCatalog,
    /// Schema/table/column names, shared between clones.
    meta: Arc<ColumnMeta<'static>>,
    /// Always the constant `0x0c`.
    fixed_length_fields_len: FixedLengthFieldsLen,
    /// Little-endian 4-byte column length.
    column_length: RawInt<LeU32>,
    /// Little-endian 2-byte character set id.
    character_set: RawInt<LeU16>,
    /// One-byte column type.
    column_type: Const<ColumnType, u8>,
    /// Two-byte column flags.
    flags: Const<ColumnFlags, LeU16>,
    /// One-byte decimals count.
    decimals: RawInt<u8>,
    /// Two trailing filler bytes.
    __filler: Skip<2>,
}
228
229impl<'de> MyDeserialize<'de> for Column {
230 const SIZE: Option<usize> = None;
231 type Ctx = ();
232
233 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
234 let catalog = buf.parse(())?;
235 let meta = Arc::new(buf.parse::<ColumnMeta>(())?.into_owned());
236 let mut buf: ParseBuf = buf.parse(13)?;
237
238 Ok(Column {
239 catalog,
240 meta,
241 fixed_length_fields_len: buf.parse_unchecked(())?,
242 character_set: buf.parse_unchecked(())?,
243 column_length: buf.parse_unchecked(())?,
244 column_type: buf.parse_unchecked(())?,
245 flags: buf.parse_unchecked(())?,
246 decimals: buf.parse_unchecked(())?,
247 __filler: buf.parse_unchecked(())?,
248 })
249 }
250}
251
252impl MySerialize for Column {
253 fn serialize(&self, buf: &mut Vec<u8>) {
254 self.catalog.serialize(&mut *buf);
255 self.meta.serialize(&mut *buf);
256 self.fixed_length_fields_len.serialize(&mut *buf);
257 self.column_length.serialize(&mut *buf);
258 self.character_set.serialize(&mut *buf);
259 self.column_type.serialize(&mut *buf);
260 self.flags.serialize(&mut *buf);
261 self.decimals.serialize(&mut *buf);
262 self.__filler.serialize(&mut *buf);
263 }
264}
265
266impl Column {
267 pub fn new(column_type: ColumnType) -> Self {
268 Self {
269 catalog: Default::default(),
270 meta: Default::default(),
271 fixed_length_fields_len: Default::default(),
272 column_length: Default::default(),
273 character_set: Default::default(),
274 flags: Default::default(),
275 column_type: Const::new(column_type),
276 decimals: Default::default(),
277 __filler: Skip,
278 }
279 }
280
281 pub fn with_schema(mut self, schema: &[u8]) -> Self {
282 Arc::make_mut(&mut self.meta).schema = RawBytes::new(schema).into_owned();
283 self
284 }
285
286 pub fn with_table(mut self, table: &[u8]) -> Self {
287 Arc::make_mut(&mut self.meta).table = RawBytes::new(table).into_owned();
288 self
289 }
290
291 pub fn with_org_table(mut self, org_table: &[u8]) -> Self {
292 Arc::make_mut(&mut self.meta).org_table = RawBytes::new(org_table).into_owned();
293 self
294 }
295
296 pub fn with_name(mut self, name: &[u8]) -> Self {
297 Arc::make_mut(&mut self.meta).name = RawBytes::new(name).into_owned();
298 self
299 }
300
301 pub fn with_org_name(mut self, org_name: &[u8]) -> Self {
302 Arc::make_mut(&mut self.meta).org_name = RawBytes::new(org_name).into_owned();
303 self
304 }
305
306 pub fn with_flags(mut self, flags: ColumnFlags) -> Self {
307 self.flags = Const::new(flags);
308 self
309 }
310
311 pub fn with_column_length(mut self, column_length: u32) -> Self {
312 self.column_length = RawInt::new(column_length);
313 self
314 }
315
316 pub fn with_character_set(mut self, character_set: u16) -> Self {
317 self.character_set = RawInt::new(character_set);
318 self
319 }
320
321 pub fn with_decimals(mut self, decimals: u8) -> Self {
322 self.decimals = RawInt::new(decimals);
323 self
324 }
325
326 pub fn column_length(&self) -> u32 {
330 *self.column_length
331 }
332
333 pub fn column_type(&self) -> ColumnType {
335 *self.column_type
336 }
337
338 pub fn character_set(&self) -> u16 {
340 *self.character_set
341 }
342
343 pub fn flags(&self) -> ColumnFlags {
345 *self.flags
346 }
347
348 pub fn decimals(&self) -> u8 {
356 *self.decimals
357 }
358
359 #[inline(always)]
361 pub fn schema_ref(&self) -> &[u8] {
362 self.meta.schema_ref()
363 }
364
365 #[inline(always)]
367 pub fn schema_str(&self) -> Cow<'_, str> {
368 self.meta.schema_str()
369 }
370
371 #[inline(always)]
373 pub fn table_ref(&self) -> &[u8] {
374 self.meta.table_ref()
375 }
376
377 #[inline(always)]
379 pub fn table_str(&self) -> Cow<'_, str> {
380 self.meta.table_str()
381 }
382
383 #[inline(always)]
387 pub fn org_table_ref(&self) -> &[u8] {
388 self.meta.org_table_ref()
389 }
390
391 #[inline(always)]
393 pub fn org_table_str(&self) -> Cow<'_, str> {
394 self.meta.org_table_str()
395 }
396
397 #[inline(always)]
399 pub fn name_ref(&self) -> &[u8] {
400 self.meta.name_ref()
401 }
402
403 #[inline(always)]
405 pub fn name_str(&self) -> Cow<'_, str> {
406 self.meta.name_str()
407 }
408
409 #[inline(always)]
413 pub fn org_name_ref(&self) -> &[u8] {
414 self.meta.org_name_ref()
415 }
416
417 #[inline(always)]
419 pub fn org_name_str(&self) -> Cow<'_, str> {
420 self.meta.org_name_str()
421 }
422}
423
/// One entry of an OK packet's session-state payload: a one-byte type followed by
/// length-encoded data (use [`SessionStateInfo::decode`] to interpret the data).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SessionStateInfo<'a> {
    /// Kind of state change; also the context used to decode `data`.
    data_type: Const<SessionStateType, u8>,
    /// Undecoded payload of this entry.
    data: RawBytes<'a, LenEnc>,
}
430
431impl SessionStateInfo<'_> {
432 pub fn into_owned(self) -> SessionStateInfo<'static> {
433 let SessionStateInfo { data_type, data } = self;
434 SessionStateInfo {
435 data_type,
436 data: data.into_owned(),
437 }
438 }
439
440 pub fn data_type(&self) -> SessionStateType {
441 *self.data_type
442 }
443
444 pub fn data_ref(&self) -> &[u8] {
446 self.data.as_bytes()
447 }
448
449 pub fn decode(&self) -> io::Result<SessionStateChange<'_>> {
451 ParseBuf(self.data.as_bytes()).parse_unchecked(*self.data_type)
452 }
453}
454
455impl<'de> MyDeserialize<'de> for SessionStateInfo<'de> {
456 const SIZE: Option<usize> = None;
457 type Ctx = ();
458
459 fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
460 Ok(SessionStateInfo {
461 data_type: buf.parse(())?,
462 data: buf.parse(())?,
463 })
464 }
465}
466
467impl MySerialize for SessionStateInfo<'_> {
468 fn serialize(&self, buf: &mut Vec<u8>) {
469 self.data_type.serialize(&mut *buf);
470 self.data.serialize(buf);
471 }
472}
473
/// Wire-level body of an OK-like packet, as produced by [`OkPacketKind::parse_body`].
///
/// Converted into the user-facing [`OkPacket`] via `TryFrom`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OkPacketBody<'a> {
    /// Length-encoded affected-rows count (set to 0 by kinds that do not carry it).
    affected_rows: RawInt<LenEnc>,
    /// Length-encoded last insert id (set to 0 by kinds that do not carry it).
    last_insert_id: RawInt<LenEnc>,
    /// Two-byte server status flags.
    status_flags: Const<StatusFlags, LeU16>,
    /// Two-byte warning count.
    warnings: RawInt<LeU16>,
    /// Human-readable info string; empty when absent.
    info: RawBytes<'a, LenEnc>,
    /// Raw session-state-change data; empty when absent.
    session_state_info: RawBytes<'a, LenEnc>,
}
484
/// One flavour of OK-like packet (OK, EOF, result-set terminator, ...).
///
/// [`OkPacketDeserializer`] reads one byte, compares it with [`OkPacketKind::HEADER`] and, on
/// a match, hands the rest of the packet to [`OkPacketKind::parse_body`].
pub trait OkPacketKind {
    /// Expected value of the packet's first byte.
    const HEADER: u8;

    /// Parses everything after the header byte.
    ///
    /// `capabilities` is the capability-flag context given to the deserializer; an
    /// implementation may ignore it.
    fn parse_body<'de>(
        capabilities: CapabilityFlags,
        buf: &mut ParseBuf<'de>,
    ) -> io::Result<OkPacketBody<'de>>;
}
496
/// OK packet that terminates a result set (header byte `0xFE`).
///
/// The body is parsed like [`CommonOkPacket`]'s, except that `affected_rows` and
/// `last_insert_id` are read and discarded: the returned body always reports `0` for both.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct ResultSetTerminator;

impl OkPacketKind for ResultSetTerminator {
    const HEADER: u8 = 0xFE;

    fn parse_body<'de>(
        capabilities: CapabilityFlags,
        buf: &mut ParseBuf<'de>,
    ) -> io::Result<OkPacketBody<'de>> {
        // Skip affected_rows and last_insert_id (two length-encoded integers).
        // NOTE(review): presumably skipped because only status flags and warnings are
        // meaningful for a result-set terminator — confirm against the server's behaviour.
        buf.parse::<RawInt<LenEnc>>(())?;
        buf.parse::<RawInt<LenEnc>>(())?;

        // Status flags and warning count (2 bytes each) are treated as mandatory: a 4-byte
        // sub-buffer is split off first, then both fields are read with `parse_unchecked`.
        let mut sbuf: ParseBuf = buf.parse(4)?;
        let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
        let warnings = sbuf.parse_unchecked(())?;

        // With CLIENT_SESSION_TRACK and remaining bytes: read `info`, then the session-state
        // data only if the server flagged a state change. Without it: read `info` only when
        // the next byte is non-zero (a zero byte is a zero-length length-encoded string and is
        // left unread).
        let (info, session_state_info) =
            if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
                let info = buf.parse(())?;
                let session_state_info =
                    if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
                        buf.parse(())?
                    } else {
                        RawBytes::default()
                    };
                (info, session_state_info)
            } else if !buf.is_empty() && buf.0[0] > 0 {
                let info = buf.parse(())?;
                (info, RawBytes::default())
            } else {
                (RawBytes::default(), RawBytes::default())
            };

        Ok(OkPacketBody {
            affected_rows: RawInt::new(0),
            last_insert_id: RawInt::new(0),
            status_flags,
            warnings,
            info,
            session_state_info,
        })
    }
}
550
551#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
553pub struct OldEofPacket;
554
555impl OkPacketKind for OldEofPacket {
556 const HEADER: u8 = 0xFE;
557
558 fn parse_body<'de>(
559 _: CapabilityFlags,
560 buf: &mut ParseBuf<'de>,
561 ) -> io::Result<OkPacketBody<'de>> {
562 let mut buf: ParseBuf = buf.parse(4)?;
564 let warnings = buf.parse_unchecked(())?;
565 let status_flags = buf.parse_unchecked(())?;
566
567 Ok(OkPacketBody {
568 affected_rows: RawInt::new(0),
569 last_insert_id: RawInt::new(0),
570 status_flags,
571 warnings,
572 info: RawBytes::new(&[][..]),
573 session_state_info: RawBytes::new(&[][..]),
574 })
575 }
576}
577
578#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
580pub struct NetworkStreamTerminator;
581
582impl OkPacketKind for NetworkStreamTerminator {
583 const HEADER: u8 = 0xFE;
584
585 fn parse_body<'de>(
586 flags: CapabilityFlags,
587 buf: &mut ParseBuf<'de>,
588 ) -> io::Result<OkPacketBody<'de>> {
589 OldEofPacket::parse_body(flags, buf)
590 }
591}
592
/// Regular OK packet (header byte `0x00`).
///
/// Body: affected rows and last insert id (length-encoded), status flags and warnings
/// (2 bytes each), then optional `info` and session-state data depending on capabilities.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct CommonOkPacket;

impl OkPacketKind for CommonOkPacket {
    const HEADER: u8 = 0x00;

    fn parse_body<'de>(
        capabilities: CapabilityFlags,
        buf: &mut ParseBuf<'de>,
    ) -> io::Result<OkPacketBody<'de>> {
        let affected_rows = buf.parse(())?;
        let last_insert_id = buf.parse(())?;

        // Status flags and warning count are treated as mandatory: a 4-byte sub-buffer is
        // split off first, then both fields are read with `parse_unchecked`.
        let mut sbuf: ParseBuf = buf.parse(4)?;
        let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
        let warnings = sbuf.parse_unchecked(())?;

        // With CLIENT_SESSION_TRACK and remaining bytes: read `info`, then the session-state
        // data only if the server flagged a state change. Without it: read `info` only when
        // the next byte is non-zero (a zero byte is a zero-length length-encoded string and is
        // left unread).
        let (info, session_state_info) =
            if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
                let info = buf.parse(())?;
                let session_state_info =
                    if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
                        buf.parse(())?
                    } else {
                        RawBytes::default()
                    };
                (info, session_state_info)
            } else if !buf.is_empty() && buf.0[0] > 0 {
                let info = buf.parse(())?;
                (info, RawBytes::default())
            } else {
                (RawBytes::default(), RawBytes::default())
            };

        Ok(OkPacketBody {
            affected_rows,
            last_insert_id,
            status_flags,
            warnings,
            info,
            session_state_info,
        })
    }
}
642
643impl<'a> TryFrom<OkPacketBody<'a>> for OkPacket<'a> {
644 type Error = io::Error;
645
646 fn try_from(body: OkPacketBody<'a>) -> io::Result<Self> {
647 Ok(OkPacket {
648 affected_rows: *body.affected_rows,
649 last_insert_id: if *body.last_insert_id == 0 {
650 None
651 } else {
652 Some(*body.last_insert_id)
653 },
654 status_flags: *body.status_flags,
655 warnings: *body.warnings,
656 info: if !body.info.is_empty() {
657 Some(body.info)
658 } else {
659 None
660 },
661 session_state_info: if !body.session_state_info.is_empty() {
662 Some(body.session_state_info)
663 } else {
664 None
665 },
666 })
667 }
668}
669
/// Decoded OK-like packet (see [`OkPacketKind`] for the supported wire layouts).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OkPacket<'a> {
    /// Affected-rows count (always `0` for kinds that do not carry it).
    affected_rows: u64,
    /// Last insert id; `None` when the body reported `0`.
    last_insert_id: Option<u64>,
    /// Server status flags.
    status_flags: StatusFlags,
    /// Warning count.
    warnings: u16,
    /// Info string; `None` when absent or empty.
    info: Option<RawBytes<'a, LenEnc>>,
    /// Raw session-state data; `None` when absent or empty.
    session_state_info: Option<RawBytes<'a, LenEnc>>,
}
680
681impl OkPacket<'_> {
682 pub fn into_owned(self) -> OkPacket<'static> {
683 OkPacket {
684 affected_rows: self.affected_rows,
685 last_insert_id: self.last_insert_id,
686 status_flags: self.status_flags,
687 warnings: self.warnings,
688 info: self.info.map(|x| x.into_owned()),
689 session_state_info: self.session_state_info.map(|x| x.into_owned()),
690 }
691 }
692
693 pub fn affected_rows(&self) -> u64 {
695 self.affected_rows
696 }
697
698 pub fn last_insert_id(&self) -> Option<u64> {
700 self.last_insert_id
701 }
702
703 pub fn status_flags(&self) -> StatusFlags {
705 self.status_flags
706 }
707
708 pub fn warnings(&self) -> u16 {
710 self.warnings
711 }
712
713 pub fn info_ref(&self) -> Option<&[u8]> {
715 self.info.as_ref().map(|x| x.as_bytes())
716 }
717
718 pub fn info_str(&self) -> Option<Cow<str>> {
720 self.info.as_ref().map(|x| x.as_str())
721 }
722
723 pub fn session_state_info_ref(&self) -> Option<&[u8]> {
725 self.session_state_info.as_ref().map(|x| x.as_bytes())
726 }
727
728 pub fn session_state_info(&self) -> io::Result<Vec<SessionStateInfo<'_>>> {
730 self.session_state_info_ref()
731 .map(|data| {
732 let mut data = ParseBuf(data);
733 let mut entries = Vec::new();
734 while !data.is_empty() {
735 entries.push(data.parse(())?);
736 }
737 Ok(entries)
738 })
739 .transpose()
740 .map(|x| x.unwrap_or_default())
741 }
742}
743
/// Deserializes an [`OkPacket`] whose wire layout is selected by the [`OkPacketKind`] marker
/// `T`. The deserialization context is the [`CapabilityFlags`] value forwarded to
/// `T::parse_body`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OkPacketDeserializer<'de, T>(OkPacket<'de>, PhantomData<T>);
746
747impl<'de, T> OkPacketDeserializer<'de, T> {
748 pub fn into_inner(self) -> OkPacket<'de> {
749 self.0
750 }
751}
752
753impl<'de, T> From<OkPacketDeserializer<'de, T>> for OkPacket<'de> {
754 fn from(x: OkPacketDeserializer<'de, T>) -> Self {
755 x.0
756 }
757}
758
/// Error wrapped in an `InvalidData` I/O error when the first byte of a packet does not match
/// the expected [`OkPacketKind::HEADER`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
#[error("Invalid OK packet header")]
pub struct InvalidOkPacketHeader;
762
763impl<'de, T: OkPacketKind> MyDeserialize<'de> for OkPacketDeserializer<'de, T> {
764 const SIZE: Option<usize> = None;
765 type Ctx = CapabilityFlags;
766
767 fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
768 if *buf.parse::<RawInt<u8>>(())? == T::HEADER {
769 let body = T::parse_body(capabilities, buf)?;
770 let ok = OkPacket::try_from(body)?;
771 Ok(Self(ok, PhantomData))
772 } else {
773 Err(io::Error::new(
774 io::ErrorKind::InvalidData,
775 InvalidOkPacketHeader,
776 ))
777 }
778 }
779}
780
/// Progress report carried in an error packet whose code is `0xFFFF`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ProgressReport<'a> {
    /// Current stage (one byte).
    stage: RawInt<u8>,
    /// Total number of stages (one byte).
    max_stage: RawInt<u8>,
    /// Raw 3-byte little-endian progress value of the current stage.
    progress: RawInt<LeU24>,
    /// Length-encoded description of the current stage.
    stage_info: RawBytes<'a, LenEnc>,
}
789
790impl<'a> ProgressReport<'a> {
791 pub fn new(
792 stage: u8,
793 max_stage: u8,
794 progress: u32,
795 stage_info: impl Into<Cow<'a, [u8]>>,
796 ) -> ProgressReport<'a> {
797 ProgressReport {
798 stage: RawInt::new(stage),
799 max_stage: RawInt::new(max_stage),
800 progress: RawInt::new(progress),
801 stage_info: RawBytes::new(stage_info),
802 }
803 }
804
805 pub fn stage(&self) -> u8 {
807 *self.stage
808 }
809
810 pub fn max_stage(&self) -> u8 {
811 *self.max_stage
812 }
813
814 pub fn progress(&self) -> u32 {
816 *self.progress
817 }
818
819 pub fn stage_info_ref(&self) -> &[u8] {
821 self.stage_info.as_bytes()
822 }
823
824 pub fn stage_info_str(&self) -> Cow<'_, str> {
826 self.stage_info.as_str()
827 }
828
829 pub fn into_owned(self) -> ProgressReport<'static> {
830 ProgressReport {
831 stage: self.stage,
832 max_stage: self.max_stage,
833 progress: self.progress,
834 stage_info: self.stage_info.into_owned(),
835 }
836 }
837}
838
839impl<'de> MyDeserialize<'de> for ProgressReport<'de> {
840 const SIZE: Option<usize> = None;
841 type Ctx = ();
842
843 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
844 let mut sbuf: ParseBuf = buf.parse(6)?;
845
846 sbuf.skip(1); Ok(ProgressReport {
849 stage: sbuf.parse_unchecked(())?,
850 max_stage: sbuf.parse_unchecked(())?,
851 progress: sbuf.parse_unchecked(())?,
852 stage_info: buf.parse(())?,
853 })
854 }
855}
856
857impl MySerialize for ProgressReport<'_> {
858 fn serialize(&self, buf: &mut Vec<u8>) {
859 buf.put_u8(1);
860 self.stage.serialize(&mut *buf);
861 self.max_stage.serialize(&mut *buf);
862 self.progress.serialize(&mut *buf);
863 self.stage_info.serialize(buf);
864 }
865}
866
867impl fmt::Display for ProgressReport<'_> {
868 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
869 write!(
870 f,
871 "Stage: {} of {} '{}' {:.2}% of stage done",
872 self.stage(),
873 self.max_stage(),
874 self.progress(),
875 self.stage_info_str()
876 )
877 }
878}
879
// Leading byte of an error packet (0xFF); validated by `ErrPacket`'s deserializer.
define_header!(
    ErrPacketHeader,
    InvalidErrPacketHeader("Invalid error packet header"),
    0xFF
);
885
/// Packet starting with `0xFF`: either a server error or, when the code is `0xFFFF` and
/// `CLIENT_PROGRESS_OBSOLETE` is set, a progress report.
#[derive(Debug, Clone, PartialEq)]
pub enum ErrPacket<'a> {
    /// A regular server error.
    Error(ServerError<'a>),
    /// A progress report.
    Progress(ProgressReport<'a>),
}
894
895impl ErrPacket<'_> {
896 pub fn is_error(&self) -> bool {
898 matches!(self, ErrPacket::Error { .. })
899 }
900
901 pub fn is_progress_report(&self) -> bool {
903 !self.is_error()
904 }
905
906 pub fn progress_report(&self) -> &ProgressReport<'_> {
908 match *self {
909 ErrPacket::Progress(ref progress_report) => progress_report,
910 _ => panic!("This ErrPacket does not contains progress report"),
911 }
912 }
913
914 pub fn server_error(&self) -> &ServerError<'_> {
916 match self {
917 ErrPacket::Error(error) => error,
918 ErrPacket::Progress(_) => panic!("This ErrPacket does not contain a ServerError"),
919 }
920 }
921}
922
923impl<'de> MyDeserialize<'de> for ErrPacket<'de> {
924 const SIZE: Option<usize> = None;
925 type Ctx = CapabilityFlags;
926
927 fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
928 let mut sbuf: ParseBuf = buf.parse(3)?;
929 sbuf.parse_unchecked::<ErrPacketHeader>(())?;
930 let code: RawInt<LeU16> = sbuf.parse_unchecked(())?;
931
932 if *code == 0xFFFF && capabilities.contains(CapabilityFlags::CLIENT_PROGRESS_OBSOLETE) {
933 buf.parse(()).map(ErrPacket::Progress)
934 } else {
935 buf.parse((
936 *code,
937 capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41),
938 ))
939 .map(ErrPacket::Error)
940 }
941 }
942}
943
944impl MySerialize for ErrPacket<'_> {
945 fn serialize(&self, buf: &mut Vec<u8>) {
946 ErrPacketHeader::new().serialize(&mut *buf);
947 match self {
948 ErrPacket::Error(server_error) => {
949 server_error.code.serialize(&mut *buf);
950 server_error.serialize(buf);
951 }
952 ErrPacket::Progress(progress_report) => {
953 RawInt::<LeU16>::new(0xFFFF).serialize(&mut *buf);
954 progress_report.serialize(buf);
955 }
956 }
957 }
958}
959
960impl fmt::Display for ErrPacket<'_> {
961 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
962 match self {
963 ErrPacket::Error(server_error) => write!(f, "{}", server_error),
964 ErrPacket::Progress(progress_report) => write!(f, "{}", progress_report),
965 }
966 }
967}
968
// The `#` byte that precedes the five-byte SQL state in an error packet.
define_header!(
    SqlStateMarker,
    InvalidSqlStateMarker("Invalid SqlStateMarker value"),
    b'#'
);
974
/// SQL state of a server error: a `#` marker followed by five state bytes (6 bytes total).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct SqlState {
    /// Constant `#` marker.
    __state_marker: SqlStateMarker,
    /// The five state bytes.
    state: [u8; 5],
}
981
982impl SqlState {
983 pub fn new(state: [u8; 5]) -> Self {
985 Self {
986 __state_marker: SqlStateMarker::new(),
987 state,
988 }
989 }
990
991 pub fn as_bytes(&self) -> [u8; 5] {
993 self.state
994 }
995
996 pub fn as_str(&self) -> Cow<'_, str> {
998 String::from_utf8_lossy(&self.state)
999 }
1000}
1001
1002impl<'de> MyDeserialize<'de> for SqlState {
1003 const SIZE: Option<usize> = Some(6);
1004 type Ctx = ();
1005
1006 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1007 Ok(Self {
1008 __state_marker: buf.parse(())?,
1009 state: buf.parse(())?,
1010 })
1011 }
1012}
1013
1014impl MySerialize for SqlState {
1015 fn serialize(&self, buf: &mut Vec<u8>) {
1016 self.__state_marker.serialize(buf);
1017 self.state.serialize(buf);
1018 }
1019}
1020
/// Error reported by the server.
#[derive(Debug, Clone, PartialEq)]
pub struct ServerError<'a> {
    /// Two-byte error code.
    code: RawInt<LeU16>,
    /// SQL state; only parsed when the `protocol_41` context flag is set.
    state: Option<SqlState>,
    /// Error message (`EofBytes`: presumably the remainder of the packet — confirm).
    message: RawBytes<'a, EofBytes>,
}
1030
1031impl<'a> ServerError<'a> {
1032 pub fn new(code: u16, state: Option<SqlState>, msg: impl Into<Cow<'a, [u8]>>) -> Self {
1033 Self {
1034 code: RawInt::new(code),
1035 state,
1036 message: RawBytes::new(msg),
1037 }
1038 }
1039
1040 pub fn error_code(&self) -> u16 {
1042 *self.code
1043 }
1044
1045 pub fn sql_state_ref(&self) -> Option<&SqlState> {
1047 self.state.as_ref()
1048 }
1049
1050 pub fn message_ref(&self) -> &[u8] {
1052 self.message.as_bytes()
1053 }
1054
1055 pub fn message_str(&self) -> Cow<'_, str> {
1057 self.message.as_str()
1058 }
1059
1060 pub fn into_owned(self) -> ServerError<'static> {
1061 ServerError {
1062 code: self.code,
1063 state: self.state,
1064 message: self.message.into_owned(),
1065 }
1066 }
1067}
1068
1069impl<'de> MyDeserialize<'de> for ServerError<'de> {
1070 const SIZE: Option<usize> = None;
1071 type Ctx = (u16, bool);
1073
1074 fn deserialize((code, protocol_41): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1075 let server_error = if protocol_41 {
1076 ServerError {
1077 code: RawInt::new(code),
1078 state: Some(buf.parse(())?),
1079 message: buf.parse(())?,
1080 }
1081 } else {
1082 ServerError {
1083 code: RawInt::new(code),
1084 state: None,
1085 message: buf.parse(())?,
1086 }
1087 };
1088 Ok(server_error)
1089 }
1090}
1091
1092impl MySerialize for ServerError<'_> {
1093 fn serialize(&self, buf: &mut Vec<u8>) {
1094 if let Some(state) = &self.state {
1095 state.serialize(buf);
1096 }
1097 self.message.serialize(buf);
1098 }
1099}
1100
1101impl fmt::Display for ServerError<'_> {
1102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1103 let sql_state_str = self
1104 .sql_state_ref()
1105 .map(|s| format!(" ({})", s.as_str()))
1106 .unwrap_or_default();
1107
1108 write!(
1109 f,
1110 "ERROR {}{}: {}",
1111 self.error_code(),
1112 sql_state_str,
1113 self.message_str()
1114 )
1115 }
1116}
1117
// Leading byte (0xFB) of a LOCAL INFILE request packet.
define_header!(
    LocalInfileHeader,
    InvalidLocalInfileHeader("Invalid LOCAL_INFILE header"),
    0xFB
);
1123
/// LOCAL INFILE request: the `0xFB` header followed by the requested file name.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct LocalInfilePacket<'a> {
    /// Constant `0xFB` header.
    __header: LocalInfileHeader,
    /// Requested file name (`EofBytes`: presumably the remainder of the packet — confirm).
    file_name: RawBytes<'a, EofBytes>,
}
1130
1131impl<'a> LocalInfilePacket<'a> {
1132 pub fn new(file_name: impl Into<Cow<'a, [u8]>>) -> Self {
1133 Self {
1134 __header: LocalInfileHeader::new(),
1135 file_name: RawBytes::new(file_name),
1136 }
1137 }
1138
1139 pub fn file_name_ref(&self) -> &[u8] {
1141 self.file_name.as_bytes()
1142 }
1143
1144 pub fn file_name_str(&self) -> Cow<'_, str> {
1146 self.file_name.as_str()
1147 }
1148
1149 pub fn into_owned(self) -> LocalInfilePacket<'static> {
1150 LocalInfilePacket {
1151 __header: self.__header,
1152 file_name: self.file_name.into_owned(),
1153 }
1154 }
1155}
1156
1157impl<'de> MyDeserialize<'de> for LocalInfilePacket<'de> {
1158 const SIZE: Option<usize> = None;
1159 type Ctx = ();
1160
1161 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1162 Ok(LocalInfilePacket {
1163 __header: buf.parse(())?,
1164 file_name: buf.parse(())?,
1165 })
1166 }
1167}
1168
1169impl MySerialize for LocalInfilePacket<'_> {
1170 fn serialize(&self, buf: &mut Vec<u8>) {
1171 self.__header.serialize(buf);
1172 self.file_name.serialize(buf);
1173 }
1174}
1175
// Wire names of the known authentication plugins (without NUL terminator); matched by
// `AuthPlugin::from_bytes` and returned by `AuthPlugin::as_bytes`.
const MYSQL_OLD_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_old_password";
const MYSQL_NATIVE_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_native_password";
const CACHING_SHA2_PASSWORD_PLUGIN_NAME: &[u8] = b"caching_sha2_password";
const MYSQL_CLEAR_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_clear_password";
1180
/// Authentication response bytes produced by [`AuthPlugin::gen_data`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthPluginData<'a> {
    /// `mysql_old_password` scramble (8 bytes; serialized with a trailing NUL).
    Old([u8; 8]),
    /// `mysql_native_password` scramble (20 bytes).
    Native([u8; 20]),
    /// `caching_sha2_password` scramble (32 bytes).
    Sha2([u8; 32]),
    /// `mysql_clear_password` cleartext password (serialized with a trailing NUL).
    Clear(Cow<'a, [u8]>),
}
1192
1193impl AuthPluginData<'_> {
1194 pub fn into_owned(self) -> AuthPluginData<'static> {
1195 match self {
1196 AuthPluginData::Old(x) => AuthPluginData::Old(x),
1197 AuthPluginData::Native(x) => AuthPluginData::Native(x),
1198 AuthPluginData::Sha2(x) => AuthPluginData::Sha2(x),
1199 AuthPluginData::Clear(x) => AuthPluginData::Clear(Cow::Owned(x.into_owned())),
1200 }
1201 }
1202}
1203
1204impl std::ops::Deref for AuthPluginData<'_> {
1205 type Target = [u8];
1206
1207 fn deref(&self) -> &Self::Target {
1208 match self {
1209 Self::Sha2(x) => &x[..],
1210 Self::Native(x) => &x[..],
1211 Self::Old(x) => &x[..],
1212 Self::Clear(x) => &x[..],
1213 }
1214 }
1215}
1216
1217impl MySerialize for AuthPluginData<'_> {
1218 fn serialize(&self, buf: &mut Vec<u8>) {
1219 match self {
1220 Self::Sha2(x) => buf.put_slice(&x[..]),
1221 Self::Native(x) => buf.put_slice(&x[..]),
1222 Self::Old(x) => {
1223 buf.put_slice(&x[..]);
1224 buf.push(0);
1225 }
1226 Self::Clear(x) => {
1227 buf.put_slice(x);
1228 buf.push(0);
1229 }
1230 }
1231 }
1232}
1233
/// Authentication plugin, identified by its wire name.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum AuthPlugin<'a> {
    /// `mysql_old_password`.
    MysqlOldPassword,
    /// `mysql_clear_password`.
    MysqlClearPassword,
    /// `mysql_native_password`.
    MysqlNativePassword,
    /// `caching_sha2_password`.
    CachingSha2Password,
    /// Any other plugin name (without NUL terminator).
    Other(Cow<'a, [u8]>),
}
1247
1248impl<'de> MyDeserialize<'de> for AuthPlugin<'de> {
1249 const SIZE: Option<usize> = None;
1250 type Ctx = ();
1251
1252 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1253 Ok(Self::from_bytes(buf.eat_all()))
1254 }
1255}
1256
1257impl MySerialize for AuthPlugin<'_> {
1258 fn serialize(&self, buf: &mut Vec<u8>) {
1259 buf.put_slice(self.as_bytes());
1260 buf.put_u8(0);
1261 }
1262}
1263
1264impl<'a> AuthPlugin<'a> {
1265 pub fn from_bytes(name: &'a [u8]) -> AuthPlugin<'a> {
1266 let name = if let [name @ .., 0] = name {
1267 name
1268 } else {
1269 name
1270 };
1271 match name {
1272 CACHING_SHA2_PASSWORD_PLUGIN_NAME => AuthPlugin::CachingSha2Password,
1273 MYSQL_NATIVE_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlNativePassword,
1274 MYSQL_OLD_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlOldPassword,
1275 MYSQL_CLEAR_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlClearPassword,
1276 name => AuthPlugin::Other(Cow::Borrowed(name)),
1277 }
1278 }
1279
1280 pub fn as_bytes(&self) -> &[u8] {
1281 match self {
1282 AuthPlugin::CachingSha2Password => CACHING_SHA2_PASSWORD_PLUGIN_NAME,
1283 AuthPlugin::MysqlNativePassword => MYSQL_NATIVE_PASSWORD_PLUGIN_NAME,
1284 AuthPlugin::MysqlOldPassword => MYSQL_OLD_PASSWORD_PLUGIN_NAME,
1285 AuthPlugin::MysqlClearPassword => MYSQL_CLEAR_PASSWORD_PLUGIN_NAME,
1286 AuthPlugin::Other(name) => name,
1287 }
1288 }
1289
1290 pub fn into_owned(self) -> AuthPlugin<'static> {
1291 match self {
1292 AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1293 AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1294 AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1295 AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1296 AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Owned(name.into_owned())),
1297 }
1298 }
1299
1300 pub fn borrow(&self) -> AuthPlugin<'_> {
1301 match self {
1302 AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1303 AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1304 AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1305 AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1306 AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Borrowed(name.as_ref())),
1307 }
1308 }
1309
1310 pub fn gen_data<'b>(&self, pass: Option<&'b str>, nonce: &[u8]) -> Option<AuthPluginData<'b>> {
1316 use super::scramble::{scramble_323, scramble_native, scramble_sha256};
1317
1318 match pass {
1319 Some(pass) if !pass.is_empty() => match self {
1320 AuthPlugin::CachingSha2Password => {
1321 scramble_sha256(nonce, pass.as_bytes()).map(AuthPluginData::Sha2)
1322 }
1323 AuthPlugin::MysqlNativePassword => {
1324 scramble_native(nonce, pass.as_bytes()).map(AuthPluginData::Native)
1325 }
1326 AuthPlugin::MysqlOldPassword => {
1327 scramble_323(nonce.chunks(8).next().unwrap(), pass.as_bytes())
1328 .map(AuthPluginData::Old)
1329 }
1330 AuthPlugin::MysqlClearPassword => {
1331 Some(AuthPluginData::Clear(Cow::Borrowed(pass.as_bytes())))
1332 }
1333 AuthPlugin::Other(_) => None,
1334 },
1335 _ => None,
1336 }
1337 }
1338}
1339
// Constant `0x01` header byte of `AuthMoreData`. `InvalidAuthMoreDataHeader` (message
// "Invalid AuthMoreData header") is the error type bound to this constant.
define_header!(
    AuthMoreDataHeader,
    InvalidAuthMoreDataHeader("Invalid AuthMoreData header"),
    0x01
);
1345
/// Packet made of a constant `0x01` header byte followed by an opaque payload.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AuthMoreData<'a> {
    /// Constant `0x01` header.
    __header: AuthMoreDataHeader,
    /// Payload after the header, read as `EofBytes` (presumably the rest of the packet).
    data: RawBytes<'a, EofBytes>,
}
1352
1353impl<'a> AuthMoreData<'a> {
1354 pub fn new(data: impl Into<Cow<'a, [u8]>>) -> Self {
1355 Self {
1356 __header: AuthMoreDataHeader::new(),
1357 data: RawBytes::new(data),
1358 }
1359 }
1360
1361 pub fn data(&self) -> &[u8] {
1362 self.data.as_bytes()
1363 }
1364
1365 pub fn into_owned(self) -> AuthMoreData<'static> {
1366 AuthMoreData {
1367 __header: self.__header,
1368 data: self.data.into_owned(),
1369 }
1370 }
1371}
1372
1373impl<'de> MyDeserialize<'de> for AuthMoreData<'de> {
1374 const SIZE: Option<usize> = None;
1375 type Ctx = ();
1376
1377 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1378 Ok(Self {
1379 __header: buf.parse(())?,
1380 data: buf.parse(())?,
1381 })
1382 }
1383}
1384
1385impl MySerialize for AuthMoreData<'_> {
1386 fn serialize(&self, buf: &mut Vec<u8>) {
1387 self.__header.serialize(&mut *buf);
1388 self.data.serialize(buf);
1389 }
1390}
1391
// Constant `0x01` header byte of `PublicKeyResponse`. This is the same value as
// `AuthMoreDataHeader`, so the two packets cannot be told apart by their header alone.
define_header!(
    PublicKeyResponseHeader,
    InvalidPublicKeyResponse("Invalid PublicKeyResponse header"),
    0x01
);
1397
/// Packet made of a constant `0x01` header byte followed by RSA public key bytes.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct PublicKeyResponse<'a> {
    /// Constant `0x01` header.
    __header: PublicKeyResponseHeader,
    /// Key bytes after the header, read as `EofBytes`.
    rsa_key: RawBytes<'a, EofBytes>,
}
1406
1407impl<'a> PublicKeyResponse<'a> {
1408 pub fn new(rsa_key: impl Into<Cow<'a, [u8]>>) -> Self {
1409 Self {
1410 __header: PublicKeyResponseHeader::new(),
1411 rsa_key: RawBytes::new(rsa_key),
1412 }
1413 }
1414
1415 pub fn rsa_key(&self) -> Cow<'_, str> {
1417 self.rsa_key.as_str()
1418 }
1419}
1420
1421impl<'de> MyDeserialize<'de> for PublicKeyResponse<'de> {
1422 const SIZE: Option<usize> = None;
1423 type Ctx = ();
1424
1425 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1426 Ok(Self {
1427 __header: buf.parse(())?,
1428 rsa_key: buf.parse(())?,
1429 })
1430 }
1431}
1432
1433impl MySerialize for PublicKeyResponse<'_> {
1434 fn serialize(&self, buf: &mut Vec<u8>) {
1435 self.__header.serialize(&mut *buf);
1436 self.rsa_key.serialize(buf);
1437 }
1438}
1439
// Constant `0xFE` header byte shared by `OldAuthSwitchRequest` and `AuthSwitchRequest`.
// The error type name keeps its historical spelling ("Swith") because it is public API.
define_header!(
    AuthSwitchRequestHeader,
    InvalidAuthSwithRequestHeader("Invalid auth switch request header"),
    0xFE
);
1445
/// Auth switch request that consists only of the `0xFE` header byte.
///
/// Its plugin is always [`AuthPlugin::MysqlOldPassword`] (see `auth_plugin`).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OldAuthSwitchRequest {
    /// Constant `0xFE` header.
    __header: AuthSwitchRequestHeader,
}
1454
1455impl OldAuthSwitchRequest {
1456 pub fn new() -> Self {
1457 Self {
1458 __header: AuthSwitchRequestHeader::new(),
1459 }
1460 }
1461
1462 pub const fn auth_plugin(&self) -> AuthPlugin<'static> {
1463 AuthPlugin::MysqlOldPassword
1464 }
1465}
1466
1467impl Default for OldAuthSwitchRequest {
1468 fn default() -> Self {
1469 Self::new()
1470 }
1471}
1472
1473impl<'de> MyDeserialize<'de> for OldAuthSwitchRequest {
1474 const SIZE: Option<usize> = Some(1);
1475 type Ctx = ();
1476
1477 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1478 Ok(Self {
1479 __header: buf.parse(())?,
1480 })
1481 }
1482}
1483
1484impl MySerialize for OldAuthSwitchRequest {
1485 fn serialize(&self, buf: &mut Vec<u8>) {
1486 self.__header.serialize(&mut *buf);
1487 }
1488}
1489
/// Auth switch request: the `0xFE` header, a NUL-terminated plugin name, and plugin data.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AuthSwitchRequest<'a> {
    /// Constant `0xFE` header.
    __header: AuthSwitchRequestHeader,
    /// Plugin name, encoded as `NullBytes`.
    auth_plugin: RawBytes<'a, NullBytes>,
    /// Plugin data, read as `EofBytes`. It may end with a NUL byte, which `plugin_data()`
    /// strips.
    plugin_data: RawBytes<'a, EofBytes>,
}
1500
1501impl<'a> AuthSwitchRequest<'a> {
1502 pub fn new(
1503 auth_plugin: impl Into<Cow<'a, [u8]>>,
1504 plugin_data: impl Into<Cow<'a, [u8]>>,
1505 ) -> Self {
1506 Self {
1507 __header: AuthSwitchRequestHeader::new(),
1508 auth_plugin: RawBytes::new(auth_plugin),
1509 plugin_data: RawBytes::new(plugin_data),
1510 }
1511 }
1512
1513 pub fn auth_plugin(&self) -> AuthPlugin<'_> {
1514 ParseBuf(self.auth_plugin.as_bytes())
1515 .parse(())
1516 .expect("infallible")
1517 }
1518
1519 pub fn plugin_data(&self) -> &[u8] {
1520 match self.plugin_data.as_bytes() {
1521 [head @ .., 0] => head,
1522 all => all,
1523 }
1524 }
1525
1526 pub fn into_owned(self) -> AuthSwitchRequest<'static> {
1527 AuthSwitchRequest {
1528 __header: self.__header,
1529 auth_plugin: self.auth_plugin.into_owned(),
1530 plugin_data: self.plugin_data.into_owned(),
1531 }
1532 }
1533}
1534
1535impl<'de> MyDeserialize<'de> for AuthSwitchRequest<'de> {
1536 const SIZE: Option<usize> = None;
1537 type Ctx = ();
1538
1539 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1540 Ok(Self {
1541 __header: buf.parse(())?,
1542 auth_plugin: buf.parse(())?,
1543 plugin_data: buf.parse(())?,
1544 })
1545 }
1546}
1547
1548impl MySerialize for AuthSwitchRequest<'_> {
1549 fn serialize(&self, buf: &mut Vec<u8>) {
1550 self.__header.serialize(&mut *buf);
1551 self.auth_plugin.serialize(&mut *buf);
1552 self.plugin_data.serialize(buf);
1553 }
1554}
1555
/// Server handshake packet, with fields in wire order.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct HandshakePacket<'a> {
    /// Protocol version byte.
    protocol_version: RawInt<u8>,
    /// NUL-terminated server version string.
    server_version: RawBytes<'a, NullBytes>,
    /// Little-endian 32-bit connection id.
    connection_id: RawInt<LeU32>,
    /// First 8 bytes of the nonce (see `nonce()`).
    scramble_1: [u8; 8],
    /// One filler byte (written as `0x00`).
    __filler: Skip<1>,
    /// Lower 16 bits of the capability flags.
    capabilities_1: Const<CapabilityFlags, LeU32LowerHalf>,
    /// Default collation id.
    default_collation: RawInt<u8>,
    /// Server status flags.
    status_flags: Const<StatusFlags, LeU16>,
    /// Upper 16 bits of the capability flags.
    capabilities_2: Const<CapabilityFlags, LeU32UpperHalf>,
    /// Auth plugin data length byte. `serialize` recomputes it instead of writing this value.
    auth_plugin_data_len: RawInt<u8>,
    /// Ten reserved bytes (written as zeros).
    __reserved: Skip<10>,
    /// Remainder of the nonce, at most `255 - 8` bytes.
    scramble_2: Option<RawBytes<'a, BareBytes<{ (u8::MAX as usize) - 8 }>>>,
    /// Auth plugin name. `deserialize` stores it without its trailing NUL.
    auth_plugin_name: Option<RawBytes<'a, NullBytes>>,
}
1575
1576impl<'de> MyDeserialize<'de> for HandshakePacket<'de> {
1577 const SIZE: Option<usize> = None;
1578 type Ctx = ();
1579
1580 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1581 let protocol_version = buf.parse(())?;
1582 let server_version = buf.parse(())?;
1583
1584 let mut sbuf: ParseBuf = buf.parse(31)?;
1586 let connection_id = sbuf.parse_unchecked(())?;
1587 let scramble_1 = sbuf.parse_unchecked(())?;
1588 let __filler = sbuf.parse_unchecked(())?;
1589 let capabilities_1: RawConst<LeU32LowerHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1590 let default_collation = sbuf.parse_unchecked(())?;
1591 let status_flags = sbuf.parse_unchecked(())?;
1592 let capabilities_2: RawConst<LeU32UpperHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1593 let auth_plugin_data_len: RawInt<u8> = sbuf.parse_unchecked(())?;
1594 let __reserved = sbuf.parse_unchecked(())?;
1595 let mut scramble_2 = None;
1596 if capabilities_1.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
1597 let len = max(13, auth_plugin_data_len.0 as i8 - 8) as usize;
1598 scramble_2 = buf.parse(len).map(Some)?;
1599 }
1600 let mut auth_plugin_name = None;
1601 if capabilities_2.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
1602 auth_plugin_name = match buf.eat_all() {
1603 [head @ .., 0] => Some(RawBytes::new(head)),
1604 all => Some(RawBytes::new(all)),
1606 }
1607 }
1608
1609 Ok(Self {
1610 protocol_version,
1611 server_version,
1612 connection_id,
1613 scramble_1,
1614 __filler,
1615 capabilities_1: Const::new(CapabilityFlags::from_bits_truncate(capabilities_1.0)),
1616 default_collation,
1617 status_flags,
1618 capabilities_2: Const::new(CapabilityFlags::from_bits_truncate(capabilities_2.0)),
1619 auth_plugin_data_len,
1620 __reserved,
1621 scramble_2,
1622 auth_plugin_name,
1623 })
1624 }
1625}
1626
1627impl MySerialize for HandshakePacket<'_> {
1628 fn serialize(&self, buf: &mut Vec<u8>) {
1629 self.protocol_version.serialize(&mut *buf);
1630 self.server_version.serialize(&mut *buf);
1631 self.connection_id.serialize(&mut *buf);
1632 self.scramble_1.serialize(&mut *buf);
1633 buf.put_u8(0x00);
1634 self.capabilities_1.serialize(&mut *buf);
1635 self.default_collation.serialize(&mut *buf);
1636 self.status_flags.serialize(&mut *buf);
1637 self.capabilities_2.serialize(&mut *buf);
1638
1639 if self
1640 .capabilities_2
1641 .contains(CapabilityFlags::CLIENT_PLUGIN_AUTH)
1642 {
1643 buf.put_u8(
1644 self.scramble_2
1645 .as_ref()
1646 .map(|x| (x.len() + 8) as u8)
1647 .unwrap_or_default(),
1648 );
1649 } else {
1650 buf.put_u8(0);
1651 }
1652
1653 buf.put_slice(&[0_u8; 10][..]);
1654
1655 if let Some(scramble_2) = &self.scramble_2 {
1658 scramble_2.serialize(&mut *buf);
1659 }
1660
1661 if let Some(client_plugin_auth) = &self.auth_plugin_name {
1664 client_plugin_auth.serialize(buf);
1665 }
1666 }
1667}
1668
1669impl<'a> HandshakePacket<'a> {
1670 #[allow(clippy::too_many_arguments)]
1671 pub fn new(
1672 protocol_version: u8,
1673 server_version: impl Into<Cow<'a, [u8]>>,
1674 connection_id: u32,
1675 scramble_1: [u8; 8],
1676 scramble_2: Option<impl Into<Cow<'a, [u8]>>>,
1677 capabilities: CapabilityFlags,
1678 default_collation: u8,
1679 status_flags: StatusFlags,
1680 auth_plugin_name: Option<impl Into<Cow<'a, [u8]>>>,
1681 ) -> Self {
1682 let (capabilities_1, capabilities_2) = (
1686 CapabilityFlags::from_bits_retain(capabilities.bits() & 0x0000_FFFF),
1687 CapabilityFlags::from_bits_retain(capabilities.bits() & 0xFFFF_0000),
1688 );
1689
1690 let scramble_2 = scramble_2.map(RawBytes::new);
1691
1692 HandshakePacket {
1693 protocol_version: RawInt::new(protocol_version),
1694 server_version: RawBytes::new(server_version),
1695 connection_id: RawInt::new(connection_id),
1696 scramble_1,
1697 __filler: Skip,
1698 capabilities_1: Const::new(capabilities_1),
1699 default_collation: RawInt::new(default_collation),
1700 status_flags: Const::new(status_flags),
1701 capabilities_2: Const::new(capabilities_2),
1702 auth_plugin_data_len: RawInt::new(
1703 scramble_2
1704 .as_ref()
1705 .map(|x| x.len() as u8)
1706 .unwrap_or_default(),
1707 ),
1708 __reserved: Skip,
1709 scramble_2,
1710 auth_plugin_name: auth_plugin_name.map(RawBytes::new),
1711 }
1712 }
1713
1714 pub fn into_owned(self) -> HandshakePacket<'static> {
1715 HandshakePacket {
1716 protocol_version: self.protocol_version,
1717 server_version: self.server_version.into_owned(),
1718 connection_id: self.connection_id,
1719 scramble_1: self.scramble_1,
1720 __filler: self.__filler,
1721 capabilities_1: self.capabilities_1,
1722 default_collation: self.default_collation,
1723 status_flags: self.status_flags,
1724 capabilities_2: self.capabilities_2,
1725 auth_plugin_data_len: self.auth_plugin_data_len,
1726 __reserved: self.__reserved,
1727 scramble_2: self.scramble_2.map(|x| x.into_owned()),
1728 auth_plugin_name: self.auth_plugin_name.map(RawBytes::into_owned),
1729 }
1730 }
1731
1732 pub fn protocol_version(&self) -> u8 {
1734 self.protocol_version.0
1735 }
1736
1737 pub fn server_version_ref(&self) -> &[u8] {
1739 self.server_version.as_bytes()
1740 }
1741
1742 pub fn server_version_str(&self) -> Cow<'_, str> {
1745 self.server_version.as_str()
1746 }
1747
1748 pub fn server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1752 VERSION_RE
1753 .captures(self.server_version_ref())
1754 .map(|captures| {
1755 (
1757 btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1758 btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1759 btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1760 )
1761 })
1762 }
1763
1764 pub fn maria_db_server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1766 MARIADB_VERSION_RE
1767 .captures(self.server_version_ref())
1768 .map(|captures| {
1769 (
1771 btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1772 btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1773 btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1774 )
1775 })
1776 }
1777
1778 pub fn connection_id(&self) -> u32 {
1780 self.connection_id.0
1781 }
1782
1783 pub fn scramble_1_ref(&self) -> &[u8] {
1785 self.scramble_1.as_ref()
1786 }
1787
1788 pub fn scramble_2_ref(&self) -> Option<&[u8]> {
1792 self.scramble_2.as_ref().map(|x| x.as_bytes())
1793 }
1794
1795 pub fn nonce(&self) -> Vec<u8> {
1797 let mut out = Vec::from(self.scramble_1_ref());
1798 out.extend_from_slice(self.scramble_2_ref().unwrap_or(&[][..]));
1799
1800 out.resize(20, 0);
1803 out
1804 }
1805
1806 pub fn capabilities(&self) -> CapabilityFlags {
1808 self.capabilities_1.0 | self.capabilities_2.0
1809 }
1810
1811 pub fn default_collation(&self) -> u8 {
1813 self.default_collation.0
1814 }
1815
1816 pub fn status_flags(&self) -> StatusFlags {
1818 self.status_flags.0
1819 }
1820
1821 pub fn auth_plugin_name_ref(&self) -> Option<&[u8]> {
1823 self.auth_plugin_name.as_ref().map(|x| x.as_bytes())
1824 }
1825
1826 pub fn auth_plugin_name_str(&self) -> Option<Cow<'_, str>> {
1829 self.auth_plugin_name.as_ref().map(|x| x.as_str())
1830 }
1831
1832 pub fn auth_plugin(&self) -> Option<AuthPlugin<'_>> {
1834 self.auth_plugin_name.as_ref().map(|x| match x.as_bytes() {
1835 [name @ .., 0] => ParseBuf(name).parse_unchecked(()).expect("infallible"),
1836 all => ParseBuf(all).parse_unchecked(()).expect("infallible"),
1837 })
1838 }
1839}
1840
// Constant `0x11` header byte of `ComChangeUser`.
define_header!(
    ComChangeUserHeader,
    InvalidComChangeUserHeader("Invalid COM_CHANGE_USER header"),
    0x11
);
1846
/// `COM_CHANGE_USER` command packet.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComChangeUser<'a> {
    /// Constant `0x11` header.
    __header: ComChangeUserHeader,
    /// User name, encoded as `NullBytes`.
    user: RawBytes<'a, NullBytes>,
    /// Auth plugin data, encoded as `U8Bytes`.
    auth_plugin_data: RawBytes<'a, U8Bytes>,
    /// Database name, encoded as `NullBytes`.
    database: RawBytes<'a, NullBytes>,
    /// Optional trailing fields; parsed only when bytes remain after `database`.
    more_data: Option<ComChangeUserMoreData<'a>>,
}
1856
1857impl<'a> ComChangeUser<'a> {
1858 pub fn new() -> Self {
1859 Self {
1860 __header: ComChangeUserHeader::new(),
1861 user: Default::default(),
1862 auth_plugin_data: Default::default(),
1863 database: Default::default(),
1864 more_data: None,
1865 }
1866 }
1867
1868 pub fn with_user(mut self, user: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1869 self.user = user.map(RawBytes::new).unwrap_or_default();
1870 self
1871 }
1872
1873 pub fn with_database(mut self, database: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1874 self.database = database.map(RawBytes::new).unwrap_or_default();
1875 self
1876 }
1877
1878 pub fn with_auth_plugin_data(
1879 mut self,
1880 auth_plugin_data: Option<impl Into<Cow<'a, [u8]>>>,
1881 ) -> Self {
1882 self.auth_plugin_data = auth_plugin_data.map(RawBytes::new).unwrap_or_default();
1883 self
1884 }
1885
1886 pub fn with_more_data(mut self, more_data: Option<ComChangeUserMoreData<'a>>) -> Self {
1887 self.more_data = more_data;
1888 self
1889 }
1890
1891 pub fn into_owned(self) -> ComChangeUser<'static> {
1892 ComChangeUser {
1893 __header: self.__header,
1894 user: self.user.into_owned(),
1895 auth_plugin_data: self.auth_plugin_data.into_owned(),
1896 database: self.database.into_owned(),
1897 more_data: self.more_data.map(|x| x.into_owned()),
1898 }
1899 }
1900}
1901
1902impl Default for ComChangeUser<'_> {
1903 fn default() -> Self {
1904 Self::new()
1905 }
1906}
1907
1908impl<'de> MyDeserialize<'de> for ComChangeUser<'de> {
1909 const SIZE: Option<usize> = None;
1910
1911 type Ctx = CapabilityFlags;
1912
1913 fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1914 Ok(Self {
1915 __header: buf.parse(())?,
1916 user: buf.parse(())?,
1917 auth_plugin_data: buf.parse(())?,
1918 database: buf.parse(())?,
1919 more_data: if !buf.is_empty() {
1920 Some(buf.parse(flags)?)
1921 } else {
1922 None
1923 },
1924 })
1925 }
1926}
1927
1928impl MySerialize for ComChangeUser<'_> {
1929 fn serialize(&self, buf: &mut Vec<u8>) {
1930 self.__header.serialize(&mut *buf);
1931 self.user.serialize(&mut *buf);
1932 self.auth_plugin_data.serialize(&mut *buf);
1933 self.database.serialize(&mut *buf);
1934 if let Some(ref more_data) = self.more_data {
1935 more_data.serialize(&mut *buf);
1936 }
1937 }
1938}
1939
/// Optional trailing fields of a `COM_CHANGE_USER` packet.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComChangeUserMoreData<'a> {
    /// Little-endian 16-bit character set id.
    character_set: RawInt<LeU16>,
    /// Auth plugin; parsed only when `CLIENT_PLUGIN_AUTH` is in the context flags.
    auth_plugin: Option<AuthPlugin<'a>>,
    /// Connection attributes; parsed only when `CLIENT_CONNECT_ATTRS` is in the context flags.
    connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
}
1946
1947impl<'a> ComChangeUserMoreData<'a> {
1948 pub fn new(character_set: u16) -> Self {
1949 Self {
1950 character_set: RawInt::new(character_set),
1951 auth_plugin: None,
1952 connect_attributes: None,
1953 }
1954 }
1955
1956 pub fn with_auth_plugin(mut self, auth_plugin: Option<AuthPlugin<'a>>) -> Self {
1957 self.auth_plugin = auth_plugin;
1958 self
1959 }
1960
1961 pub fn with_connect_attributes(
1962 mut self,
1963 connect_attributes: Option<HashMap<String, String>>,
1964 ) -> Self {
1965 self.connect_attributes = connect_attributes.map(|attrs| {
1966 attrs
1967 .into_iter()
1968 .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
1969 .collect()
1970 });
1971 self
1972 }
1973
1974 pub fn into_owned(self) -> ComChangeUserMoreData<'static> {
1975 ComChangeUserMoreData {
1976 character_set: self.character_set,
1977 auth_plugin: self.auth_plugin.map(|x| x.into_owned()),
1978 connect_attributes: self.connect_attributes.map(|x| {
1979 x.into_iter()
1980 .map(|(k, v)| (k.into_owned(), v.into_owned()))
1981 .collect()
1982 }),
1983 }
1984 }
1985}
1986
1987fn deserialize_connect_attrs<'de>(
1989 buf: &mut ParseBuf<'de>,
1990) -> io::Result<HashMap<RawBytes<'de, LenEnc>, RawBytes<'de, LenEnc>>> {
1991 let data_len = buf.parse::<RawInt<LenEnc>>(())?;
1992 let mut data: ParseBuf = buf.parse(data_len.0 as usize)?;
1993 let mut attrs = HashMap::new();
1994 while !data.is_empty() {
1995 let key = data.parse::<RawBytes<LenEnc>>(())?;
1996 let value = data.parse::<RawBytes<LenEnc>>(())?;
1997 attrs.insert(key, value);
1998 }
1999 Ok(attrs)
2000}
2001
2002fn serialize_connect_attrs<'a>(
2004 connect_attributes: &HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>,
2005 buf: &mut Vec<u8>,
2006) {
2007 let len = connect_attributes
2008 .iter()
2009 .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2010 .sum::<u64>();
2011 buf.put_lenenc_int(len);
2012
2013 for (name, value) in connect_attributes {
2014 name.serialize(&mut *buf);
2015 value.serialize(&mut *buf);
2016 }
2017}
2018
2019impl<'de> MyDeserialize<'de> for ComChangeUserMoreData<'de> {
2020 const SIZE: Option<usize> = None;
2021 type Ctx = CapabilityFlags;
2022
2023 fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2024 let character_set = buf.parse(())?;
2026 let mut auth_plugin = None;
2027 let mut connect_attributes = None;
2028
2029 if flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
2030 match buf.parse::<RawBytes<NullBytes>>(())?.0 {
2032 Cow::Borrowed(bytes) => {
2033 let mut auth_plugin_buf = ParseBuf(bytes);
2034 auth_plugin = Some(auth_plugin_buf.parse(())?);
2035 }
2036 _ => unreachable!(),
2037 }
2038 };
2039
2040 if flags.contains(CapabilityFlags::CLIENT_CONNECT_ATTRS) {
2041 connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2042 };
2043
2044 Ok(Self {
2045 character_set,
2046 auth_plugin,
2047 connect_attributes,
2048 })
2049 }
2050}
2051
2052impl MySerialize for ComChangeUserMoreData<'_> {
2053 fn serialize(&self, buf: &mut Vec<u8>) {
2054 self.character_set.serialize(&mut *buf);
2055 if let Some(ref auth_plugin) = self.auth_plugin {
2056 auth_plugin.serialize(&mut *buf);
2057 }
2058 if let Some(ref connect_attributes) = self.connect_attributes {
2059 serialize_connect_attrs(connect_attributes, buf);
2060 } else {
2061 serialize_connect_attrs(&Default::default(), buf);
2064 }
2065 }
2066}
2067
/// Encodings of the client scramble in a `HandshakeResponse`: length-encoded (`LenEnc`),
/// `U8Bytes`, or NUL-terminated. Which one is used is chosen from the capability flags.
type ScrambleBuf<'a> =
    Either<RawBytes<'a, LenEnc>, Either<RawBytes<'a, U8Bytes>, RawBytes<'a, NullBytes>>>;
2071
/// Client handshake response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandshakeResponse<'a> {
    /// Client capability flags (little-endian 32-bit).
    capabilities: Const<CapabilityFlags, LeU32>,
    /// Maximum packet size (little-endian 32-bit).
    max_packet_size: RawInt<LeU32>,
    /// Collation id.
    collation: RawInt<u8>,
    /// Client scramble, encoded according to the capability flags.
    scramble_buf: ScrambleBuf<'a>,
    /// User name, encoded as `NullBytes`.
    user: RawBytes<'a, NullBytes>,
    /// Database name; present iff `CLIENT_CONNECT_WITH_DB` is set.
    db_name: Option<RawBytes<'a, NullBytes>>,
    /// Auth plugin; present iff `CLIENT_PLUGIN_AUTH` is set.
    auth_plugin: Option<AuthPlugin<'a>>,
    /// Connection attributes; present iff `CLIENT_CONNECT_ATTRS` is set.
    connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
}
2083
2084impl<'a> HandshakeResponse<'a> {
2085 #[allow(clippy::too_many_arguments)]
2086 pub fn new(
2087 scramble_buf: Option<impl Into<Cow<'a, [u8]>>>,
2088 server_version: (u16, u16, u16),
2089 user: Option<impl Into<Cow<'a, [u8]>>>,
2090 db_name: Option<impl Into<Cow<'a, [u8]>>>,
2091 auth_plugin: Option<AuthPlugin<'a>>,
2092 mut capabilities: CapabilityFlags,
2093 connect_attributes: Option<HashMap<String, String>>,
2094 max_packet_size: u32,
2095 ) -> Self {
2096 let scramble_buf =
2097 if capabilities.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
2098 Either::Left(RawBytes::new(
2099 scramble_buf.map(Into::into).unwrap_or_default(),
2100 ))
2101 } else if capabilities.contains(CapabilityFlags::CLIENT_SECURE_CONNECTION) {
2102 Either::Right(Either::Left(RawBytes::new(
2103 scramble_buf.map(Into::into).unwrap_or_default(),
2104 )))
2105 } else {
2106 Either::Right(Either::Right(RawBytes::new(
2107 scramble_buf.map(Into::into).unwrap_or_default(),
2108 )))
2109 };
2110
2111 if db_name.is_some() {
2112 capabilities.insert(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2113 } else {
2114 capabilities.remove(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2115 }
2116
2117 if auth_plugin.is_some() {
2118 capabilities.insert(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2119 } else {
2120 capabilities.remove(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2121 }
2122
2123 if connect_attributes.is_some() {
2124 capabilities.insert(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2125 } else {
2126 capabilities.remove(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2127 }
2128
2129 Self {
2130 scramble_buf,
2131 collation: if server_version >= (5, 5, 3) {
2132 RawInt::new(CollationId::UTF8MB4_GENERAL_CI as u8)
2133 } else {
2134 RawInt::new(CollationId::UTF8MB3_GENERAL_CI as u8)
2135 },
2136 user: user.map(RawBytes::new).unwrap_or_default(),
2137 db_name: db_name.map(RawBytes::new),
2138 auth_plugin,
2139 capabilities: Const::new(capabilities),
2140 connect_attributes: connect_attributes.map(|attrs| {
2141 attrs
2142 .into_iter()
2143 .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
2144 .collect()
2145 }),
2146 max_packet_size: RawInt::new(max_packet_size),
2147 }
2148 }
2149
2150 pub fn capabilities(&self) -> CapabilityFlags {
2151 self.capabilities.0
2152 }
2153
2154 pub fn collation(&self) -> u8 {
2155 self.collation.0
2156 }
2157
2158 pub fn scramble_buf(&self) -> &[u8] {
2159 match &self.scramble_buf {
2160 Either::Left(x) => x.as_bytes(),
2161 Either::Right(x) => match x {
2162 Either::Left(x) => x.as_bytes(),
2163 Either::Right(x) => x.as_bytes(),
2164 },
2165 }
2166 }
2167
2168 pub fn user(&self) -> &[u8] {
2169 self.user.as_bytes()
2170 }
2171
2172 pub fn db_name(&self) -> Option<&[u8]> {
2173 self.db_name.as_ref().map(|x| x.as_bytes())
2174 }
2175
2176 pub fn auth_plugin(&self) -> Option<&AuthPlugin<'a>> {
2177 self.auth_plugin.as_ref()
2178 }
2179
2180 #[must_use = "entails computation"]
2181 pub fn connect_attributes(&self) -> Option<HashMap<String, String>> {
2182 self.connect_attributes.as_ref().map(|attrs| {
2183 attrs
2184 .iter()
2185 .map(|(k, v)| (k.as_str().into_owned(), v.as_str().into_owned()))
2186 .collect()
2187 })
2188 }
2189}
2190
2191impl<'de> MyDeserialize<'de> for HandshakeResponse<'de> {
2192 const SIZE: Option<usize> = None;
2193 type Ctx = ();
2194
2195 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2196 let mut sbuf: ParseBuf = buf.parse(4 + 4 + 1 + 23)?;
2197 let client_flags: RawConst<LeU32, CapabilityFlags> = sbuf.parse_unchecked(())?;
2198 let max_packet_size: RawInt<LeU32> = sbuf.parse_unchecked(())?;
2199 let collation = sbuf.parse_unchecked(())?;
2200 sbuf.parse_unchecked::<Skip<23>>(())?;
2201
2202 let user = buf.parse(())?;
2203 let scramble_buf =
2204 if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA.bits() > 0 {
2205 Either::Left(buf.parse(())?)
2206 } else if client_flags.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
2207 Either::Right(Either::Left(buf.parse(())?))
2208 } else {
2209 Either::Right(Either::Right(buf.parse(())?))
2210 };
2211
2212 let mut db_name = None;
2213 if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_WITH_DB.bits() > 0 {
2214 db_name = buf.parse(()).map(Some)?;
2215 }
2216
2217 let mut auth_plugin = None;
2218 if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
2219 let auth_plugin_name = buf.eat_null_str();
2220 auth_plugin = Some(AuthPlugin::from_bytes(auth_plugin_name));
2221 }
2222
2223 let mut connect_attributes = None;
2224 if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_ATTRS.bits() > 0 {
2225 connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2226 }
2227
2228 Ok(Self {
2229 capabilities: Const::new(CapabilityFlags::from_bits_truncate(client_flags.0)),
2230 max_packet_size,
2231 collation,
2232 scramble_buf,
2233 user,
2234 db_name,
2235 auth_plugin,
2236 connect_attributes,
2237 })
2238 }
2239}
2240
2241impl MySerialize for HandshakeResponse<'_> {
2242 fn serialize(&self, buf: &mut Vec<u8>) {
2243 self.capabilities.serialize(&mut *buf);
2244 self.max_packet_size.serialize(&mut *buf);
2245 self.collation.serialize(&mut *buf);
2246 buf.put_slice(&[0; 23]);
2247 self.user.serialize(&mut *buf);
2248 self.scramble_buf.serialize(&mut *buf);
2249
2250 if let Some(db_name) = &self.db_name {
2251 db_name.serialize(&mut *buf);
2252 }
2253
2254 if let Some(auth_plugin) = &self.auth_plugin {
2255 auth_plugin.serialize(&mut *buf);
2256 }
2257
2258 if let Some(attrs) = &self.connect_attributes {
2259 let len = attrs
2260 .iter()
2261 .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2262 .sum::<u64>();
2263 buf.put_lenenc_int(len);
2264
2265 for (name, value) in attrs {
2266 name.serialize(&mut *buf);
2267 value.serialize(&mut *buf);
2268 }
2269 }
2270 }
2271}
2272
/// Fixed 32-byte SSL request packet: capabilities, max packet size, character set, and 23
/// reserved bytes.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SslRequest {
    /// Client capability flags (little-endian 32-bit).
    capabilities: Const<CapabilityFlags, LeU32>,
    /// Maximum packet size (little-endian 32-bit).
    max_packet_size: RawInt<LeU32>,
    /// Character set / collation id.
    character_set: RawInt<u8>,
    /// 23 reserved bytes.
    __skip: Skip<23>,
}
2280
2281impl SslRequest {
2282 pub fn new(capabilities: CapabilityFlags, max_packet_size: u32, character_set: u8) -> Self {
2283 Self {
2284 capabilities: Const::new(capabilities),
2285 max_packet_size: RawInt::new(max_packet_size),
2286 character_set: RawInt::new(character_set),
2287 __skip: Skip,
2288 }
2289 }
2290
2291 pub fn capabilities(&self) -> CapabilityFlags {
2292 self.capabilities.0
2293 }
2294
2295 pub fn max_packet_size(&self) -> u32 {
2296 self.max_packet_size.0
2297 }
2298
2299 pub fn character_set(&self) -> u8 {
2300 self.character_set.0
2301 }
2302}
2303
2304impl<'de> MyDeserialize<'de> for SslRequest {
2305 const SIZE: Option<usize> = Some(4 + 4 + 1 + 23);
2306 type Ctx = ();
2307
2308 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2309 let mut buf: ParseBuf = buf.parse(Self::SIZE.unwrap())?;
2310 let raw_capabilities = buf.parse_unchecked::<RawConst<LeU32, CapabilityFlags>>(())?;
2311 Ok(Self {
2312 capabilities: Const::new(CapabilityFlags::from_bits_truncate(raw_capabilities.0)),
2313 max_packet_size: buf.parse_unchecked(())?,
2314 character_set: buf.parse_unchecked(())?,
2315 __skip: buf.parse_unchecked(())?,
2316 })
2317 }
2318}
2319
2320impl MySerialize for SslRequest {
2321 fn serialize(&self, buf: &mut Vec<u8>) {
2322 self.capabilities.serialize(&mut *buf);
2323 self.max_packet_size.serialize(&mut *buf);
2324 self.character_set.serialize(&mut *buf);
2325 self.__skip.serialize(&mut *buf);
2326 }
2327}
2328
/// Error type bound to `StmtPacket`'s constant `0x00` status byte. It is presumably reported
/// when a parsed status byte differs from `0x00`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
#[error("Invalid statement packet status")]
pub struct InvalidStmtPacketStatus;
2332
/// Fixed 12-byte statement packet.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct StmtPacket {
    /// Constant `0x00` status byte.
    status: ConstU8<InvalidStmtPacketStatus, 0x00>,
    /// Statement id (little-endian 32-bit).
    statement_id: RawInt<LeU32>,
    /// Column count (little-endian 16-bit).
    num_columns: RawInt<LeU16>,
    /// Parameter count (little-endian 16-bit).
    num_params: RawInt<LeU16>,
    /// One skipped byte.
    __skip: Skip<1>,
    /// Warning count (little-endian 16-bit).
    warning_count: RawInt<LeU16>,
}
2343
2344impl<'de> MyDeserialize<'de> for StmtPacket {
2345 const SIZE: Option<usize> = Some(12);
2346 type Ctx = ();
2347
2348 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2349 let mut buf: ParseBuf = buf.parse(Self::SIZE.unwrap())?;
2350 Ok(StmtPacket {
2351 status: buf.parse_unchecked(())?,
2352 statement_id: buf.parse_unchecked(())?,
2353 num_columns: buf.parse_unchecked(())?,
2354 num_params: buf.parse_unchecked(())?,
2355 __skip: buf.parse_unchecked(())?,
2356 warning_count: buf.parse_unchecked(())?,
2357 })
2358 }
2359}
2360
2361impl MySerialize for StmtPacket {
2362 fn serialize(&self, buf: &mut Vec<u8>) {
2363 self.status.serialize(&mut *buf);
2364 self.statement_id.serialize(&mut *buf);
2365 self.num_columns.serialize(&mut *buf);
2366 self.num_params.serialize(&mut *buf);
2367 self.__skip.serialize(&mut *buf);
2368 self.warning_count.serialize(&mut *buf);
2369 }
2370}
2371
2372impl StmtPacket {
2373 pub fn statement_id(&self) -> u32 {
2375 *self.statement_id
2376 }
2377
2378 pub fn num_columns(&self) -> u16 {
2380 *self.num_columns
2381 }
2382
2383 pub fn num_params(&self) -> u16 {
2385 *self.num_params
2386 }
2387
2388 pub fn warning_count(&self) -> u16 {
2390 *self.warning_count
2391 }
2392}
2393
/// NULL bitmap over `U` bytes, one bit per column.
///
/// Bit positions are shifted by `T::BIT_OFFSET`, where `T` selects the serialization side.
/// `PhantomData<T>` records `T` without storing it.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct NullBitmap<T, U: AsRef<[u8]> = Vec<u8>>(U, PhantomData<T>);
2399
2400impl<'de, T: SerializationSide> MyDeserialize<'de> for NullBitmap<T, Cow<'de, [u8]>> {
2401 const SIZE: Option<usize> = None;
2402 type Ctx = usize;
2403
2404 fn deserialize(num_columns: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2405 let bitmap_len = Self::bitmap_len(num_columns);
2406 let bytes = buf.checked_eat(bitmap_len).ok_or_else(unexpected_buf_eof)?;
2407 Ok(Self::from_bytes(Cow::Borrowed(bytes)))
2408 }
2409}
2410
2411impl<T: SerializationSide> NullBitmap<T, Vec<u8>> {
2412 pub fn new(num_columns: usize) -> Self {
2414 Self::from_bytes(vec![0; Self::bitmap_len(num_columns)])
2415 }
2416
2417 pub fn read(input: &mut &[u8], num_columns: usize) -> Self {
2419 let bitmap_len = Self::bitmap_len(num_columns);
2420 assert!(input.len() >= bitmap_len);
2421
2422 let bitmap = Self::from_bytes(input[..bitmap_len].to_vec());
2423 *input = &input[bitmap_len..];
2424
2425 bitmap
2426 }
2427}
2428
2429impl<T: SerializationSide, U: AsRef<[u8]>> NullBitmap<T, U> {
2430 pub fn bitmap_len(num_columns: usize) -> usize {
2431 (num_columns + 7 + T::BIT_OFFSET) / 8
2432 }
2433
2434 fn byte_and_bit(&self, column_index: usize) -> (usize, u8) {
2435 let offset = column_index + T::BIT_OFFSET;
2436 let byte = offset / 8;
2437 let bit = 1 << (offset % 8) as u8;
2438
2439 assert!(byte < self.0.as_ref().len());
2440
2441 (byte, bit)
2442 }
2443
2444 pub fn from_bytes(bytes: U) -> Self {
2446 Self(bytes, PhantomData)
2447 }
2448
2449 pub fn is_null(&self, column_index: usize) -> bool {
2451 let (byte, bit) = self.byte_and_bit(column_index);
2452 self.0.as_ref()[byte] & bit > 0
2453 }
2454}
2455
2456impl<T: SerializationSide, U: AsRef<[u8]> + AsMut<[u8]>> NullBitmap<T, U> {
2457 pub fn set(&mut self, column_index: usize, is_null: bool) {
2459 let (byte, bit) = self.byte_and_bit(column_index);
2460 if is_null {
2461 self.0.as_mut()[byte] |= bit
2462 } else {
2463 self.0.as_mut()[byte] &= !bit
2464 }
2465 }
2466}
2467
2468impl<T, U: AsRef<[u8]>> AsRef<[u8]> for NullBitmap<T, U> {
2469 fn as_ref(&self) -> &[u8] {
2470 self.0.as_ref()
2471 }
2472}
2473
/// Builder for `ComStmtExecuteRequest`; see `build`.
#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtExecuteRequestBuilder {
    /// Id of the prepared statement to execute.
    pub stmt_id: u32,
}
2478
2479impl ComStmtExecuteRequestBuilder {
2480 pub const NULL_BITMAP_OFFSET: usize = 10;
2481
2482 pub fn new(stmt_id: u32) -> Self {
2483 Self { stmt_id }
2484 }
2485}
2486
2487impl ComStmtExecuteRequestBuilder {
2488 pub fn build(self, params: &[Value]) -> (ComStmtExecuteRequest<'_>, bool) {
2489 let bitmap_len = NullBitmap::<ClientSide>::bitmap_len(params.len());
2490
2491 let mut bitmap_bytes = vec![0; bitmap_len];
2492 let mut bitmap = NullBitmap::<ClientSide, _>::from_bytes(&mut bitmap_bytes);
2493 let params = params.iter().collect::<Vec<_>>();
2494
2495 let meta_len = params.len() * 2;
2496
2497 let mut data_len = 0;
2498 for (i, param) in params.iter().enumerate() {
2499 match param.bin_len() as usize {
2500 0 => bitmap.set(i, true),
2501 x => data_len += x,
2502 }
2503 }
2504
2505 let total_len = 10 + bitmap_len + 1 + meta_len + data_len;
2506
2507 let as_long_data = total_len > MAX_PAYLOAD_LEN;
2508
2509 (
2510 ComStmtExecuteRequest {
2511 com_stmt_execute: ConstU8::new(),
2512 stmt_id: RawInt::new(self.stmt_id),
2513 flags: Const::new(CursorType::CURSOR_TYPE_NO_CURSOR),
2514 iteration_count: ConstU32::new(),
2515 params_flags: Const::new(StmtExecuteParamsFlags::NEW_PARAMS_BOUND),
2516 bitmap: RawBytes::new(bitmap_bytes),
2517 params,
2518 as_long_data,
2519 },
2520 as_long_data,
2521 )
2522 }
2523}
2524
// Header byte of a `COM_STMT_EXECUTE` packet.
define_header!(
    ComStmtExecuteHeader,
    COM_STMT_EXECUTE,
    InvalidComStmtExecuteHeader
);

// Iteration count of a `COM_STMT_EXECUTE` packet; the only accepted value is 1.
define_const!(
    ConstU32,
    IterationCount,
    InvalidIterationCount("Invalid iteration count for COM_STMT_EXECUTE"),
    1
);

/// A `COM_STMT_EXECUTE` request, produced by `ComStmtExecuteRequestBuilder::build`.
#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtExecuteRequest<'a> {
    com_stmt_execute: ComStmtExecuteHeader,
    /// Id of the prepared statement.
    stmt_id: RawInt<LeU32>,
    /// Cursor type requested for the execution.
    flags: Const<CursorType, u8>,
    iteration_count: IterationCount,
    /// NULL-bitmap, one bit per parameter. NOTE(review): `BareBytes<8192>` presumably caps
    /// it at 8192 bytes — confirm.
    bitmap: RawBytes<'a, BareBytes<8192>>,
    /// New-params-bound flag; written only when there are parameters.
    params_flags: Const<StmtExecuteParamsFlags, u8>,
    /// Bound parameter values.
    params: Vec<&'a Value>,
    /// When set, `Value::Bytes` parameters are not serialized inline.
    as_long_data: bool,
}
2550
2551impl<'a> ComStmtExecuteRequest<'a> {
2552 pub fn stmt_id(&self) -> u32 {
2553 self.stmt_id.0
2554 }
2555
2556 pub fn flags(&self) -> CursorType {
2557 self.flags.0
2558 }
2559
2560 pub fn bitmap(&self) -> &[u8] {
2561 self.bitmap.as_bytes()
2562 }
2563
2564 pub fn params_flags(&self) -> StmtExecuteParamsFlags {
2565 self.params_flags.0
2566 }
2567
2568 pub fn params(&self) -> &[&'a Value] {
2569 self.params.as_ref()
2570 }
2571
2572 pub fn as_long_data(&self) -> bool {
2573 self.as_long_data
2574 }
2575}
2576
2577impl MySerialize for ComStmtExecuteRequest<'_> {
2578 fn serialize(&self, buf: &mut Vec<u8>) {
2579 self.com_stmt_execute.serialize(&mut *buf);
2580 self.stmt_id.serialize(&mut *buf);
2581 self.flags.serialize(&mut *buf);
2582 self.iteration_count.serialize(&mut *buf);
2583
2584 if !self.params.is_empty() {
2585 self.bitmap.serialize(&mut *buf);
2586 self.params_flags.serialize(&mut *buf);
2587 }
2588
2589 for param in &self.params {
2590 let (column_type, flags) = match param {
2591 Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()),
2592 Value::Bytes(_) => (
2593 ColumnType::MYSQL_TYPE_VAR_STRING,
2594 StmtExecuteParamFlags::empty(),
2595 ),
2596 Value::Int(_) => (
2597 ColumnType::MYSQL_TYPE_LONGLONG,
2598 StmtExecuteParamFlags::empty(),
2599 ),
2600 Value::UInt(_) => (
2601 ColumnType::MYSQL_TYPE_LONGLONG,
2602 StmtExecuteParamFlags::UNSIGNED,
2603 ),
2604 Value::Float(_) => (ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty()),
2605 Value::Double(_) => (
2606 ColumnType::MYSQL_TYPE_DOUBLE,
2607 StmtExecuteParamFlags::empty(),
2608 ),
2609 Value::Date(..) => (
2610 ColumnType::MYSQL_TYPE_DATETIME,
2611 StmtExecuteParamFlags::empty(),
2612 ),
2613 Value::Time(..) => (ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty()),
2614 };
2615
2616 buf.put_slice(&[column_type as u8, flags.bits()]);
2617 }
2618
2619 for param in &self.params {
2620 match **param {
2621 Value::Int(_)
2622 | Value::UInt(_)
2623 | Value::Float(_)
2624 | Value::Double(_)
2625 | Value::Date(..)
2626 | Value::Time(..) => {
2627 param.serialize(buf);
2628 }
2629 Value::Bytes(_) if !self.as_long_data => {
2630 param.serialize(buf);
2631 }
2632 Value::Bytes(_) | Value::NULL => {}
2633 }
2634 }
2635 }
2636}
2637
// Header byte of a `COM_STMT_SEND_LONG_DATA` packet.
define_header!(
    ComStmtSendLongDataHeader,
    COM_STMT_SEND_LONG_DATA,
    InvalidComStmtSendLongDataHeader
);

/// `COM_STMT_SEND_LONG_DATA` packet: a chunk of data for one parameter of a prepared statement.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ComStmtSendLongData<'a> {
    __header: ComStmtSendLongDataHeader,
    /// Id of the target prepared statement.
    stmt_id: RawInt<LeU32>,
    /// Index of the parameter this data belongs to.
    param_index: RawInt<LeU16>,
    /// Payload bytes; they run to the end of the packet (`EofBytes`).
    data: RawBytes<'a, EofBytes>,
}
2651
2652impl<'a> ComStmtSendLongData<'a> {
2653 pub fn new(stmt_id: u32, param_index: u16, data: impl Into<Cow<'a, [u8]>>) -> Self {
2654 Self {
2655 __header: ComStmtSendLongDataHeader::new(),
2656 stmt_id: RawInt::new(stmt_id),
2657 param_index: RawInt::new(param_index),
2658 data: RawBytes::new(data),
2659 }
2660 }
2661
2662 pub fn into_owned(self) -> ComStmtSendLongData<'static> {
2663 ComStmtSendLongData {
2664 __header: self.__header,
2665 stmt_id: self.stmt_id,
2666 param_index: self.param_index,
2667 data: self.data.into_owned(),
2668 }
2669 }
2670}
2671
2672impl MySerialize for ComStmtSendLongData<'_> {
2673 fn serialize(&self, buf: &mut Vec<u8>) {
2674 self.__header.serialize(&mut *buf);
2675 self.stmt_id.serialize(&mut *buf);
2676 self.param_index.serialize(&mut *buf);
2677 self.data.serialize(&mut *buf);
2678 }
2679}
2680
2681#[derive(Debug, Clone, Copy, Eq, PartialEq)]
2682pub struct ComStmtClose {
2683 pub stmt_id: u32,
2684}
2685
2686impl ComStmtClose {
2687 pub fn new(stmt_id: u32) -> Self {
2688 Self { stmt_id }
2689 }
2690}
2691
2692impl MySerialize for ComStmtClose {
2693 fn serialize(&self, buf: &mut Vec<u8>) {
2694 buf.put_u8(Command::COM_STMT_CLOSE as u8);
2695 buf.put_u32_le(self.stmt_id);
2696 }
2697}
2698
// Header byte of a `COM_REGISTER_SLAVE` packet.
define_header!(
    ComRegisterSlaveHeader,
    COM_REGISTER_SLAVE,
    InvalidComRegisterSlaveHeader
);

/// `COM_REGISTER_SLAVE` request (replica registration).
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ComRegisterSlave<'a> {
    header: ComRegisterSlaveHeader,
    /// Server id of the registering replica.
    server_id: RawInt<LeU32>,
    /// Replica hostname. NOTE(review): `U8Bytes` presumably means a 1-byte length prefix,
    /// so values over 255 bytes do not round-trip — confirm.
    hostname: RawBytes<'a, U8Bytes>,
    /// Replication user name (`U8Bytes`, see `hostname`).
    user: RawBytes<'a, U8Bytes>,
    /// Replication password (`U8Bytes`, see `hostname`).
    password: RawBytes<'a, U8Bytes>,
    /// Replica port.
    port: RawInt<LeU16>,
    /// Replication rank.
    replication_rank: RawInt<LeU32>,
    /// Master id.
    master_id: RawInt<LeU32>,
}
2742
2743impl<'a> ComRegisterSlave<'a> {
2744 pub fn new(server_id: u32) -> Self {
2746 Self {
2747 header: Default::default(),
2748 server_id: RawInt::new(server_id),
2749 hostname: Default::default(),
2750 user: Default::default(),
2751 password: Default::default(),
2752 port: Default::default(),
2753 replication_rank: Default::default(),
2754 master_id: Default::default(),
2755 }
2756 }
2757
2758 pub fn with_hostname(mut self, hostname: impl Into<Cow<'a, [u8]>>) -> Self {
2760 self.hostname = RawBytes::new(hostname);
2761 self
2762 }
2763
2764 pub fn with_user(mut self, user: impl Into<Cow<'a, [u8]>>) -> Self {
2766 self.user = RawBytes::new(user);
2767 self
2768 }
2769
2770 pub fn with_password(mut self, password: impl Into<Cow<'a, [u8]>>) -> Self {
2772 self.password = RawBytes::new(password);
2773 self
2774 }
2775
2776 pub fn with_port(mut self, port: u16) -> Self {
2778 self.port = RawInt::new(port);
2779 self
2780 }
2781
2782 pub fn with_replication_rank(mut self, replication_rank: u32) -> Self {
2784 self.replication_rank = RawInt::new(replication_rank);
2785 self
2786 }
2787
2788 pub fn with_master_id(mut self, master_id: u32) -> Self {
2790 self.master_id = RawInt::new(master_id);
2791 self
2792 }
2793
2794 pub fn server_id(&self) -> u32 {
2796 self.server_id.0
2797 }
2798
2799 pub fn hostname_raw(&self) -> &[u8] {
2801 self.hostname.as_bytes()
2802 }
2803
2804 pub fn hostname(&'a self) -> Cow<'a, str> {
2806 self.hostname.as_str()
2807 }
2808
2809 pub fn user_raw(&self) -> &[u8] {
2811 self.user.as_bytes()
2812 }
2813
2814 pub fn user(&'a self) -> Cow<'a, str> {
2816 self.user.as_str()
2817 }
2818
2819 pub fn password_raw(&self) -> &[u8] {
2821 self.password.as_bytes()
2822 }
2823
2824 pub fn password(&'a self) -> Cow<'a, str> {
2826 self.password.as_str()
2827 }
2828
2829 pub fn port(&self) -> u16 {
2831 self.port.0
2832 }
2833
2834 pub fn replication_rank(&self) -> u32 {
2836 self.replication_rank.0
2837 }
2838
2839 pub fn master_id(&self) -> u32 {
2841 self.master_id.0
2842 }
2843}
2844
2845impl MySerialize for ComRegisterSlave<'_> {
2846 fn serialize(&self, buf: &mut Vec<u8>) {
2847 self.header.serialize(&mut *buf);
2848 self.server_id.serialize(&mut *buf);
2849 self.hostname.serialize(&mut *buf);
2850 self.user.serialize(&mut *buf);
2851 self.password.serialize(&mut *buf);
2852 self.port.serialize(&mut *buf);
2853 self.replication_rank.serialize(&mut *buf);
2854 self.master_id.serialize(&mut *buf);
2855 }
2856}
2857
2858impl<'de> MyDeserialize<'de> for ComRegisterSlave<'de> {
2859 const SIZE: Option<usize> = None;
2860 type Ctx = ();
2861
2862 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2863 let mut sbuf: ParseBuf = buf.parse(5)?;
2864 let header = sbuf.parse_unchecked(())?;
2865 let server_id = sbuf.parse_unchecked(())?;
2866
2867 let hostname = buf.parse(())?;
2868 let user = buf.parse(())?;
2869 let password = buf.parse(())?;
2870
2871 let mut sbuf: ParseBuf = buf.parse(10)?;
2872 let port = sbuf.parse_unchecked(())?;
2873 let replication_rank = sbuf.parse_unchecked(())?;
2874 let master_id = sbuf.parse_unchecked(())?;
2875
2876 Ok(Self {
2877 header,
2878 server_id,
2879 hostname,
2880 user,
2881 password,
2882 port,
2883 replication_rank,
2884 master_id,
2885 })
2886 }
2887}
2888
// Header byte of a `COM_TABLE_DUMP` packet.
define_header!(
    ComTableDumpHeader,
    COM_TABLE_DUMP,
    InvalidComTableDumpHeader
);

/// `COM_TABLE_DUMP` request naming a database and a table.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ComTableDump<'a> {
    header: ComTableDumpHeader,
    /// Database name (`U8Bytes`-encoded).
    database: RawBytes<'a, U8Bytes>,
    /// Table name (`U8Bytes`-encoded).
    table: RawBytes<'a, U8Bytes>,
}
2912
2913impl<'a> ComTableDump<'a> {
2914 pub fn new(database: impl Into<Cow<'a, [u8]>>, table: impl Into<Cow<'a, [u8]>>) -> Self {
2916 Self {
2917 header: Default::default(),
2918 database: RawBytes::new(database),
2919 table: RawBytes::new(table),
2920 }
2921 }
2922
2923 pub fn database_raw(&self) -> &[u8] {
2925 self.database.as_bytes()
2926 }
2927
2928 pub fn database(&self) -> Cow<str> {
2930 self.database.as_str()
2931 }
2932
2933 pub fn table_raw(&self) -> &[u8] {
2935 self.table.as_bytes()
2936 }
2937
2938 pub fn table(&self) -> Cow<str> {
2940 self.table.as_str()
2941 }
2942}
2943
2944impl MySerialize for ComTableDump<'_> {
2945 fn serialize(&self, buf: &mut Vec<u8>) {
2946 self.header.serialize(&mut *buf);
2947 self.database.serialize(&mut *buf);
2948 self.table.serialize(&mut *buf);
2949 }
2950}
2951
2952impl<'de> MyDeserialize<'de> for ComTableDump<'de> {
2953 const SIZE: Option<usize> = None;
2954 type Ctx = ();
2955
2956 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2957 Ok(Self {
2958 header: buf.parse(())?,
2959 database: buf.parse(())?,
2960 table: buf.parse(())?,
2961 })
2962 }
2963}
2964
// Flags of `COM_BINLOG_DUMP` / `COM_BINLOG_DUMP_GTID` requests; both packets store them as
// `Const<BinlogDumpFlags, LeU16>` (a little-endian u16).
my_bitflags! {
    BinlogDumpFlags,
    #[error("Unknown flags in the raw value of BinlogDumpFlags (raw={:b})", _0)]
    UnknownBinlogDumpFlags,
    u16,

    #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
    /// Binlog dump request flags.
    pub struct BinlogDumpFlags: u16 {
        /// NOTE(review): per the MySQL protocol docs, the server ends the stream instead of
        /// blocking when no more events are available — confirm.
        const BINLOG_DUMP_NON_BLOCK = 0x01;
        /// NOTE(review): presumably requests dumping up to a binlog position — confirm.
        const BINLOG_THROUGH_POSITION = 0x02;
        /// GTID-driven dump; the `ComBinlogDumpGtid` builder methods set this exactly when
        /// the request carries SIDs.
        const BINLOG_THROUGH_GTID = 0x04;
    }
}
2980
// Header byte of a `COM_BINLOG_DUMP` packet.
define_header!(
    ComBinlogDumpHeader,
    COM_BINLOG_DUMP,
    InvalidComBinlogDumpHeader
);

/// `COM_BINLOG_DUMP` request: asks for a binlog stream from a file/position.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ComBinlogDump<'a> {
    header: ComBinlogDumpHeader,
    /// Starting position within the binlog file (4 bytes).
    pos: RawInt<LeU32>,
    /// Dump flags (2 bytes).
    flags: Const<BinlogDumpFlags, LeU16>,
    /// Server id of the requesting replica.
    server_id: RawInt<LeU32>,
    /// Binlog file name; runs to the end of the packet (`EofBytes`).
    filename: RawBytes<'a, EofBytes>,
}
3005
3006impl<'a> ComBinlogDump<'a> {
3007 pub fn new(server_id: u32) -> Self {
3009 Self {
3010 header: Default::default(),
3011 pos: Default::default(),
3012 flags: Default::default(),
3013 server_id: RawInt::new(server_id),
3014 filename: Default::default(),
3015 }
3016 }
3017
3018 pub fn with_pos(mut self, pos: u32) -> Self {
3020 self.pos = RawInt::new(pos);
3021 self
3022 }
3023
3024 pub fn with_flags(mut self, flags: BinlogDumpFlags) -> Self {
3026 self.flags = Const::new(flags);
3027 self
3028 }
3029
3030 pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3032 self.filename = RawBytes::new(filename);
3033 self
3034 }
3035
3036 pub fn pos(&self) -> u32 {
3038 *self.pos
3039 }
3040
3041 pub fn flags(&self) -> BinlogDumpFlags {
3043 *self.flags
3044 }
3045
3046 pub fn server_id(&self) -> u32 {
3048 *self.server_id
3049 }
3050
3051 pub fn filename_raw(&self) -> &[u8] {
3053 self.filename.as_bytes()
3054 }
3055
3056 pub fn filename(&self) -> Cow<str> {
3058 self.filename.as_str()
3059 }
3060}
3061
3062impl MySerialize for ComBinlogDump<'_> {
3063 fn serialize(&self, buf: &mut Vec<u8>) {
3064 self.header.serialize(&mut *buf);
3065 self.pos.serialize(&mut *buf);
3066 self.flags.serialize(&mut *buf);
3067 self.server_id.serialize(&mut *buf);
3068 self.filename.serialize(&mut *buf);
3069 }
3070}
3071
3072impl<'de> MyDeserialize<'de> for ComBinlogDump<'de> {
3073 const SIZE: Option<usize> = None;
3074 type Ctx = ();
3075
3076 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3077 let mut sbuf: ParseBuf = buf.parse(11)?;
3078 Ok(Self {
3079 header: sbuf.parse_unchecked(())?,
3080 pos: sbuf.parse_unchecked(())?,
3081 flags: sbuf.parse_unchecked(())?,
3082 server_id: sbuf.parse_unchecked(())?,
3083 filename: buf.parse(())?,
3084 })
3085 }
3086}
3087
3088#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
3090pub struct GnoInterval {
3091 start: RawInt<LeU64>,
3092 end: RawInt<LeU64>,
3093}
3094
3095impl GnoInterval {
3096 pub fn new(start: u64, end: u64) -> Self {
3098 Self {
3099 start: RawInt::new(start),
3100 end: RawInt::new(end),
3101 }
3102 }
3103 pub fn check_and_new(start: u64, end: u64) -> io::Result<Self> {
3105 if start >= end {
3106 return Err(io::Error::new(
3107 io::ErrorKind::InvalidData,
3108 format!("start({}) >= end({}) in GnoInterval", start, end),
3109 ));
3110 }
3111 if start == 0 || end == 0 {
3112 return Err(io::Error::new(
3113 io::ErrorKind::InvalidData,
3114 "Gno can't be zero",
3115 ));
3116 }
3117 Ok(Self::new(start, end))
3118 }
3119}
3120
3121impl MySerialize for GnoInterval {
3122 fn serialize(&self, buf: &mut Vec<u8>) {
3123 self.start.serialize(&mut *buf);
3124 self.end.serialize(&mut *buf);
3125 }
3126}
3127
3128impl<'de> MyDeserialize<'de> for GnoInterval {
3129 const SIZE: Option<usize> = Some(16);
3130 type Ctx = ();
3131
3132 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3133 Ok(Self {
3134 start: buf.parse_unchecked(())?,
3135 end: buf.parse_unchecked(())?,
3136 })
3137 }
3138}
3139
/// Length in bytes of a binary UUID.
pub const UUID_LEN: usize = 16;

/// One SID entry of a GTID set: a UUID plus the GNO intervals recorded for it.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Sid<'a> {
    /// UUID in binary form.
    uuid: [u8; UUID_LEN],
    /// GNO intervals, encoded with a `LeU64` count prefix.
    intervals: Seq<'a, GnoInterval, LeU64>,
}
3150
3151impl Sid<'_> {
3152 pub fn new(uuid: [u8; UUID_LEN]) -> Self {
3154 Self {
3155 uuid,
3156 intervals: Default::default(),
3157 }
3158 }
3159
3160 pub fn uuid(&self) -> [u8; UUID_LEN] {
3162 self.uuid
3163 }
3164
3165 pub fn intervals(&self) -> &[GnoInterval] {
3167 &self.intervals[..]
3168 }
3169
3170 pub fn with_interval(mut self, interval: GnoInterval) -> Self {
3172 let mut intervals = self.intervals.0.into_owned();
3173 intervals.push(interval);
3174 self.intervals = Seq::new(intervals);
3175 self
3176 }
3177
3178 pub fn with_intervals(mut self, intervals: Vec<GnoInterval>) -> Self {
3180 self.intervals = Seq::new(intervals);
3181 self
3182 }
3183
3184 fn len(&self) -> u64 {
3185 use saturating::Saturating as S;
3186 let mut len = S(UUID_LEN as u64); len += S(8); len += S((self.intervals.len() * 16) as u64);
3189 len.0
3190 }
3191}
3192
3193impl MySerialize for Sid<'_> {
3194 fn serialize(&self, buf: &mut Vec<u8>) {
3195 self.uuid.serialize(&mut *buf);
3196 self.intervals.serialize(buf);
3197 }
3198}
3199
3200impl<'de> MyDeserialize<'de> for Sid<'de> {
3201 const SIZE: Option<usize> = None;
3202 type Ctx = ();
3203
3204 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3205 Ok(Self {
3206 uuid: buf.parse(())?,
3207 intervals: buf.parse(())?,
3208 })
3209 }
3210}
3211
3212impl Sid<'_> {
3213 fn wrap_err(msg: String) -> io::Error {
3214 io::Error::new(io::ErrorKind::InvalidInput, msg)
3215 }
3216
3217 fn parse_interval_num(to_parse: &str, full: &str) -> Result<u64, io::Error> {
3218 let n: u64 = to_parse.parse().map_err(|e| {
3219 Sid::wrap_err(format!(
3220 "invalid GnoInterval format: {}, error: {}",
3221 full, e
3222 ))
3223 })?;
3224 Ok(n)
3225 }
3226}
3227
3228impl FromStr for Sid<'_> {
3229 type Err = io::Error;
3230
3231 fn from_str(s: &str) -> Result<Self, Self::Err> {
3232 let (uuid, intervals) = s
3233 .split_once(':')
3234 .ok_or_else(|| Sid::wrap_err(format!("invalid sid format: {}", s)))?;
3235 let uuid = Uuid::parse_str(uuid)
3236 .map_err(|e| Sid::wrap_err(format!("invalid uuid format: {}, error: {}", s, e)))?;
3237 let intervals = intervals
3238 .split(':')
3239 .map(|interval| {
3240 let nums = interval.split('-').collect::<Vec<_>>();
3241 if nums.len() != 1 && nums.len() != 2 {
3242 return Err(Sid::wrap_err(format!("invalid GnoInterval format: {}", s)));
3243 }
3244 if nums.len() == 1 {
3245 let start = Sid::parse_interval_num(nums[0], s)?;
3246 let interval = GnoInterval::check_and_new(start, start + 1)?;
3247 Ok(interval)
3248 } else {
3249 let start = Sid::parse_interval_num(nums[0], s)?;
3250 let end = Sid::parse_interval_num(nums[1], s)?;
3251 let interval = GnoInterval::check_and_new(start, end + 1)?;
3252 Ok(interval)
3253 }
3254 })
3255 .collect::<Result<Vec<_>, _>>()?;
3256 Ok(Self {
3257 uuid: *uuid.as_bytes(),
3258 intervals: Seq::new(intervals),
3259 })
3260 }
3261}
3262
// Header byte of a `COM_BINLOG_DUMP_GTID` packet.
define_header!(
    ComBinlogDumpGtidHeader,
    COM_BINLOG_DUMP_GTID,
    InvalidComBinlogDumpGtidHeader
);

/// `COM_BINLOG_DUMP_GTID` request: asks for a binlog stream, optionally driven by a GTID set.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ComBinlogDumpGtid<'a> {
    header: ComBinlogDumpGtidHeader,
    /// Dump flags (2 bytes).
    flags: Const<BinlogDumpFlags, LeU16>,
    /// Server id of the requesting replica.
    server_id: RawInt<LeU32>,
    /// Binlog file name (`U32Bytes`-encoded).
    filename: RawBytes<'a, U32Bytes>,
    /// Binlog position (8 bytes).
    pos: RawInt<LeU64>,
    /// SID entries. On the wire they are preceded by a u32 byte length of the whole block
    /// (written by `serialize`).
    sid_block: Seq<'a, Sid<'a>, LeU64>,
}
3291
3292impl<'a> ComBinlogDumpGtid<'a> {
3293 pub fn new(server_id: u32) -> Self {
3295 Self {
3296 header: Default::default(),
3297 pos: Default::default(),
3298 flags: Default::default(),
3299 server_id: RawInt::new(server_id),
3300 filename: Default::default(),
3301 sid_block: Default::default(),
3302 }
3303 }
3304
3305 pub fn server_id(&self) -> u32 {
3307 self.server_id.0
3308 }
3309
3310 pub fn flags(&self) -> BinlogDumpFlags {
3312 self.flags.0
3313 }
3314
3315 pub fn filename_raw(&self) -> &[u8] {
3317 self.filename.as_bytes()
3318 }
3319
3320 pub fn filename(&self) -> Cow<str> {
3322 self.filename.as_str()
3323 }
3324
3325 pub fn pos(&self) -> u64 {
3327 self.pos.0
3328 }
3329
3330 pub fn sids(&self) -> &[Sid<'a>] {
3332 &self.sid_block
3333 }
3334
3335 pub fn with_filename(self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3337 Self {
3338 header: self.header,
3339 flags: self.flags,
3340 server_id: self.server_id,
3341 filename: RawBytes::new(filename),
3342 pos: self.pos,
3343 sid_block: self.sid_block,
3344 }
3345 }
3346
3347 pub fn with_server_id(mut self, server_id: u32) -> Self {
3349 self.server_id.0 = server_id;
3350 self
3351 }
3352
3353 pub fn with_flags(mut self, mut flags: BinlogDumpFlags) -> Self {
3355 if self.sid_block.is_empty() {
3356 flags.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3357 } else {
3358 flags.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3359 }
3360 self.flags.0 = flags;
3361 self
3362 }
3363
3364 pub fn with_pos(mut self, pos: u64) -> Self {
3366 self.pos.0 = pos;
3367 self
3368 }
3369
3370 pub fn with_sid(mut self, sid: Sid<'a>) -> Self {
3372 self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3373 self.sid_block.push(sid);
3374 self
3375 }
3376
3377 pub fn with_sids(mut self, sids: impl Into<Cow<'a, [Sid<'a>]>>) -> Self {
3379 self.sid_block = Seq::new(sids);
3380 if self.sid_block.is_empty() {
3381 self.flags.0.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3382 } else {
3383 self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3384 }
3385 self
3386 }
3387
3388 fn sid_block_len(&self) -> u32 {
3389 use saturating::Saturating as S;
3390 let mut len = S(8); for sid in self.sid_block.iter() {
3392 len += S(sid.len() as u32);
3393 }
3394 len.0
3395 }
3396}
3397
3398impl MySerialize for ComBinlogDumpGtid<'_> {
3399 fn serialize(&self, buf: &mut Vec<u8>) {
3400 self.header.serialize(&mut *buf);
3401 self.flags.serialize(&mut *buf);
3402 self.server_id.serialize(&mut *buf);
3403 self.filename.serialize(&mut *buf);
3404 self.pos.serialize(&mut *buf);
3405 buf.put_u32_le(self.sid_block_len());
3406 self.sid_block.serialize(&mut *buf);
3407 }
3408}
3409
3410impl<'de> MyDeserialize<'de> for ComBinlogDumpGtid<'de> {
3411 const SIZE: Option<usize> = None;
3412 type Ctx = ();
3413
3414 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3415 let mut sbuf: ParseBuf = buf.parse(7)?;
3416 let header = sbuf.parse_unchecked(())?;
3417 let flags: Const<BinlogDumpFlags, LeU16> = sbuf.parse_unchecked(())?;
3418 let server_id = sbuf.parse_unchecked(())?;
3419
3420 let filename = buf.parse(())?;
3421 let pos = buf.parse(())?;
3422
3423 let sid_data_len: RawInt<LeU32> = buf.parse(())?;
3425 let mut buf: ParseBuf = buf.parse(sid_data_len.0 as usize)?;
3426 let sid_block = buf.parse(())?;
3427
3428 Ok(Self {
3429 header,
3430 flags,
3431 server_id,
3432 filename,
3433 pos,
3434 sid_block,
3435 })
3436 }
3437}
3438
3439define_header!(
3440 SemiSyncAckPacketPacketHeader,
3441 InvalidSemiSyncAckPacketPacketHeader("Invalid semi-sync ack packet header"),
3442 0xEF
3443);
3444
3445pub struct SemiSyncAckPacket<'a> {
3448 header: SemiSyncAckPacketPacketHeader,
3449 position: RawInt<LeU64>,
3450 filename: RawBytes<'a, EofBytes>,
3451}
3452
3453impl<'a> SemiSyncAckPacket<'a> {
3454 pub fn new(position: u64, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3455 Self {
3456 header: Default::default(),
3457 position: RawInt::new(position),
3458 filename: RawBytes::new(filename),
3459 }
3460 }
3461
3462 pub fn with_position(mut self, position: u64) -> Self {
3464 self.position.0 = position;
3465 self
3466 }
3467
3468 pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3470 self.filename = RawBytes::new(filename);
3471 self
3472 }
3473
3474 pub fn position(&self) -> u64 {
3476 self.position.0
3477 }
3478
3479 pub fn filename_raw(&self) -> &[u8] {
3481 self.filename.as_bytes()
3482 }
3483
3484 pub fn filename(&self) -> Cow<'_, str> {
3486 self.filename.as_str()
3487 }
3488}
3489
3490impl MySerialize for SemiSyncAckPacket<'_> {
3491 fn serialize(&self, buf: &mut Vec<u8>) {
3492 self.header.serialize(&mut *buf);
3493 self.position.serialize(&mut *buf);
3494 self.filename.serialize(&mut *buf);
3495 }
3496}
3497
3498impl<'de> MyDeserialize<'de> for SemiSyncAckPacket<'de> {
3499 const SIZE: Option<usize> = None;
3500 type Ctx = ();
3501
3502 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3503 let mut sbuf: ParseBuf = buf.parse(9)?;
3504 Ok(Self {
3505 header: sbuf.parse_unchecked(())?,
3506 position: sbuf.parse_unchecked(())?,
3507 filename: buf.parse(())?,
3508 })
3509 }
3510}
3511
3512#[cfg(test)]
3513mod test {
3514 use super::*;
3515 use crate::{
3516 constants::{CapabilityFlags, ColumnFlags, ColumnType, StatusFlags},
3517 proto::{MyDeserialize, MySerialize},
3518 };
3519
    // Property tests: serializing a packet and parsing it back must reproduce it.
    proptest::proptest! {
        // NOTE(review): database/table are `U8Bytes`; this presumably relies on proptest's
        // default `Vec` length staying below 256 bytes — confirm.
        #[test]
        fn com_table_dump_roundtrip(database: Vec<u8>, table: Vec<u8>) {
            let cmd = ComTableDump::new(database, table);

            let mut output = Vec::new();
            cmd.serialize(&mut output);

            assert_eq!(cmd, ComTableDump::deserialize((), &mut ParseBuf(&output[..]))?);
        }

        #[test]
        fn com_binlog_dump_roundtrip(
            server_id: u32,
            filename: Vec<u8>,
            pos: u32,
            flags: u16,
        ) {
            let cmd = ComBinlogDump::new(server_id)
                .with_filename(filename)
                .with_pos(pos)
                .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));

            let mut output = Vec::new();
            cmd.serialize(&mut output);

            assert_eq!(cmd, ComBinlogDump::deserialize((), &mut ParseBuf(&output[..]))?);
        }

        // Strings longer than 255 bytes cannot survive the `U8Bytes` encoding, so for those
        // the test asserts that the parsed packet differs.
        #[test]
        fn com_register_slave_roundtrip(
            server_id: u32,
            hostname in r"\w{0,256}",
            user in r"\w{0,256}",
            password in r"\w{0,256}",
            port: u16,
            replication_rank: u32,
            master_id: u32,
        ) {
            let cmd = ComRegisterSlave::new(server_id)
                .with_hostname(hostname.as_bytes())
                .with_user(user.as_bytes())
                .with_password(password.as_bytes())
                .with_port(port)
                .with_replication_rank(replication_rank)
                .with_master_id(master_id);

            let mut output = Vec::new();
            cmd.serialize(&mut output);
            let parsed = ComRegisterSlave::deserialize((), &mut ParseBuf(&output[..]))?;

            if hostname.len() > 255 || user.len() > 255 || password.len() > 255 {
                assert_ne!(cmd, parsed);
            } else {
                assert_eq!(cmd, parsed);
            }
        }

        // `with_flags` runs before `with_sids`, so `with_sids` has the final say on the
        // BINLOG_THROUGH_GTID bit. `GnoInterval::new` does not validate, so these
        // (start > end) intervals exercise only the encoding, not GTID semantics.
        #[test]
        fn com_binlog_dump_gtid_roundtrip(
            flags: u16,
            server_id: u32,
            filename: Vec<u8>,
            pos: u64,
            n_sid_blocks in 0_u64..1024,
        ) {
            let mut cmd = ComBinlogDumpGtid::new(server_id)
                .with_filename(filename)
                .with_pos(pos)
                .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));

            let mut sids = Vec::new();
            for i in 0..n_sid_blocks {
                let mut block = Sid::new([i as u8; 16]);
                for j in 0..i {
                    block = block.with_interval(GnoInterval::new(i, j));
                }
                sids.push(block);
            }

            cmd = cmd.with_sids(sids);

            let mut output = Vec::new();
            cmd.serialize(&mut output);

            assert_eq!(cmd, ComBinlogDumpGtid::deserialize((), &mut ParseBuf(&output[..]))?);
        }
    }
3608
    // A LOCAL INFILE request: 0xFB marker followed by the file name.
    #[test]
    fn should_parse_local_infile_packet() {
        const LIP: &[u8] = b"\xfbfile_name";

        let lip = LocalInfilePacket::deserialize((), &mut ParseBuf(LIP)).unwrap();
        assert_eq!(lip.file_name_str(), "file_name");
    }
3616
    // Prepared-statement OK packets with and without columns/params.
    #[test]
    fn should_parse_stmt_packet() {
        const SP: &[u8] = b"\x00\x01\x00\x00\x00\x01\x00\x02\x00\x00\x00\x00";
        const SP_2: &[u8] = b"\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";

        let sp = StmtPacket::deserialize((), &mut ParseBuf(SP)).unwrap();
        assert_eq!(sp.statement_id(), 0x01);
        assert_eq!(sp.num_columns(), 0x01);
        assert_eq!(sp.num_params(), 0x02);
        assert_eq!(sp.warning_count(), 0x00);

        let sp = StmtPacket::deserialize((), &mut ParseBuf(SP_2)).unwrap();
        assert_eq!(sp.statement_id(), 0x01);
        assert_eq!(sp.num_columns(), 0x00);
        assert_eq!(sp.num_params(), 0x00);
        assert_eq!(sp.warning_count(), 0x00);
    }
3634
    // Handshake packets: a MariaDB one (no auth plugin name) and a MySQL 5.6 one. Each is
    // parsed, checked field by field, and re-serialized byte-for-byte.
    #[test]
    fn should_parse_handshake_packet() {
        const HSP: &[u8] = b"\x0a5.5.5-10.0.17-MariaDB-log\x00\x0b\x00\
                             \x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\
                             \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x2a\x34\x64\
                             \x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00";

        const HSP_2: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
                               \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
                               \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
                               \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
                               \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
                               \x6f\x72\x64\x00";

        // NOTE(review): HSP_3 is byte-identical to HSP_2, so the third case adds no coverage;
        // consider replacing it with a distinct handshake variant.
        const HSP_3: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
                               \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
                               \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
                               \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
                               \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
                               \x6f\x72\x64\x00";

        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap();
        assert_eq!(hsp.protocol_version(), 0x0a);
        assert_eq!(hsp.server_version_str(), "5.5.5-10.0.17-MariaDB-log");
        assert_eq!(hsp.server_version_parsed(), Some((5, 5, 5)));
        assert_eq!(hsp.maria_db_server_version_parsed(), Some((10, 0, 17)));
        assert_eq!(hsp.connection_id(), 0x0b);
        assert_eq!(hsp.scramble_1_ref(), b"dvH@I-CJ");
        assert_eq!(
            hsp.capabilities(),
            CapabilityFlags::from_bits_truncate(0xf7ff)
        );
        assert_eq!(hsp.default_collation(), 0x08);
        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
        assert_eq!(hsp.scramble_2_ref(), Some(&b"*4d|cZwk4^]:\x00"[..]));
        assert_eq!(hsp.auth_plugin_name_ref(), None);

        let mut output = Vec::new();
        hsp.serialize(&mut output);
        assert_eq!(&output, HSP);

        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_2)).unwrap();
        assert_eq!(hsp.protocol_version(), 0x0a);
        assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
        assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
        assert_eq!(hsp.maria_db_server_version_parsed(), None);
        assert_eq!(hsp.connection_id(), 0x0a56);
        assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
        assert_eq!(
            hsp.capabilities(),
            CapabilityFlags::from_bits_truncate(0xc00fffff)
        );
        assert_eq!(hsp.default_collation(), 0x08);
        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
        assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
        assert_eq!(
            hsp.auth_plugin_name_ref(),
            Some(&b"mysql_native_password"[..])
        );

        let mut output = Vec::new();
        hsp.serialize(&mut output);
        assert_eq!(&output, HSP_2);

        let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_3)).unwrap();
        assert_eq!(hsp.protocol_version(), 0x0a);
        assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
        assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
        assert_eq!(hsp.maria_db_server_version_parsed(), None);
        assert_eq!(hsp.connection_id(), 0x0a56);
        assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
        assert_eq!(
            hsp.capabilities(),
            CapabilityFlags::from_bits_truncate(0xc00fffff)
        );
        assert_eq!(hsp.default_collation(), 0x08);
        assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
        assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
        assert_eq!(
            hsp.auth_plugin_name_ref(),
            Some(&b"mysql_native_password"[..])
        );

        let mut output = Vec::new();
        hsp.serialize(&mut output);
        assert_eq!(&output, HSP_3);
    }
3722
    // ERR packets: with SQL state (PROTOCOL_41), without it, and a MariaDB progress report
    // (error code 0xFFFF) parsed with CLIENT_PROGRESS_OBSOLETE.
    #[test]
    fn should_parse_err_packet() {
        const ERR_PACKET: &[u8] = b"\xff\x48\x04\x23\x48\x59\x30\x30\x30\x4e\x6f\x20\x74\x61\x62\
        \x6c\x65\x73\x20\x75\x73\x65\x64";
        const ERR_PACKET_NO_STATE: &[u8] = b"\xff\x10\x04\x54\x6f\x6f\x20\x6d\x61\x6e\x79\x20\x63\
        \x6f\x6e\x6e\x65\x63\x74\x69\x6f\x6e\x73";
        const PROGRESS_PACKET: &[u8] = b"\xff\xff\xff\x01\x01\x0a\xcc\x5b\x00\x0astage name";

        let err_packet = ErrPacket::deserialize(
            CapabilityFlags::CLIENT_PROTOCOL_41,
            &mut ParseBuf(ERR_PACKET),
        )
        .unwrap();
        let err_packet = err_packet.server_error();
        assert_eq!(err_packet.error_code(), 1096);
        assert_eq!(err_packet.sql_state_ref().unwrap().as_str(), "HY000");
        assert_eq!(err_packet.message_str(), "No tables used");

        let err_packet =
            ErrPacket::deserialize(CapabilityFlags::empty(), &mut ParseBuf(ERR_PACKET_NO_STATE))
                .unwrap();
        let server_error = err_packet.server_error();
        assert_eq!(server_error.error_code(), 1040);
        assert_eq!(server_error.sql_state_ref(), None);
        assert_eq!(server_error.message_str(), "Too many connections");

        let err_packet = ErrPacket::deserialize(
            CapabilityFlags::CLIENT_PROGRESS_OBSOLETE,
            &mut ParseBuf(PROGRESS_PACKET),
        )
        .unwrap();
        assert!(err_packet.is_progress_report());
        let progress_report = err_packet.progress_report();
        assert_eq!(progress_report.stage(), 1);
        assert_eq!(progress_report.max_stage(), 10);
        assert_eq!(progress_report.progress(), 23500);
        assert_eq!(progress_report.stage_info_str(), "stage name");
    }
3761
3762 #[test]
3763 fn should_parse_column_packet() {
3764 const COLUMN_PACKET: &[u8] = b"\x03def\x06schema\x05table\x09org_table\x04name\
3765 \x08org_name\x0c\x21\x00\x0F\x00\x00\x00\x00\x01\x00\x08\x00\x00";
3766 let column = Column::deserialize((), &mut ParseBuf(COLUMN_PACKET)).unwrap();
3767 assert_eq!(column.schema_str(), "schema");
3768 assert_eq!(column.table_str(), "table");
3769 assert_eq!(column.org_table_str(), "org_table");
3770 assert_eq!(column.name_str(), "name");
3771 assert_eq!(column.org_name_str(), "org_name");
3772 assert_eq!(
3773 column.character_set(),
3774 CollationId::UTF8MB3_GENERAL_CI as u16
3775 );
3776 assert_eq!(column.column_length(), 15);
3777 assert_eq!(column.column_type(), ColumnType::MYSQL_TYPE_DECIMAL);
3778 assert_eq!(column.flags(), ColumnFlags::NOT_NULL_FLAG);
3779 assert_eq!(column.decimals(), 8);
3780 }
3781
3782 #[test]
3783 fn should_parse_auth_switch_request() {
3784 const PAYLOAD: &[u8] = b"\xfe\x6d\x79\x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\
3785 \x73\x73\x77\x6f\x72\x64\x00\x7a\x51\x67\x34\x69\x36\x6f\x4e\x79\
3786 \x36\x3d\x72\x48\x4e\x2f\x3e\x2d\x62\x29\x41\x00";
3787 let packet = AuthSwitchRequest::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();
3788 assert_eq!(packet.auth_plugin().as_bytes(), b"mysql_native_password",);
3789 assert_eq!(packet.plugin_data(), b"zQg4i6oNy6=rHN/>-b)A",)
3790 }
3791
3792 #[test]
3793 fn should_parse_auth_more_data() {
3794 const PAYLOAD: &[u8] = b"\x01\x04";
3795 let packet = AuthMoreData::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();
3796 assert_eq!(packet.data(), b"\x04",);
3797 }
3798
3799 #[test]
3800 fn should_parse_ok_packet() {
3801 const PLAIN_OK: &[u8] = b"\x00\x01\x00\x02\x00\x00\x00";
3802 const RESULT_SET_TERMINATOR: &[u8] = &[
3803 0xfe, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x42, 0x52, 0x65, 0x61, 0x64, 0x20, 0x31,
3804 0x20, 0x72, 0x6f, 0x77, 0x73, 0x2c, 0x20, 0x31, 0x2e, 0x30, 0x30, 0x20, 0x42, 0x20,
3805 0x69, 0x6e, 0x20, 0x30, 0x2e, 0x30, 0x30, 0x32, 0x20, 0x73, 0x65, 0x63, 0x2e, 0x2c,
3806 0x20, 0x36, 0x31, 0x31, 0x2e, 0x33, 0x34, 0x20, 0x72, 0x6f, 0x77, 0x73, 0x2f, 0x73,
3807 0x65, 0x63, 0x2e, 0x2c, 0x20, 0x36, 0x31, 0x31, 0x2e, 0x33, 0x34, 0x20, 0x42, 0x2f,
3808 0x73, 0x65, 0x63, 0x2e,
3809 ];
3810 const SESS_STATE_SYS_VAR_OK: &[u8] =
3811 b"\x00\x00\x00\x02\x40\x00\x00\x00\x11\x00\x0f\x0a\x61\
3812 \x75\x74\x6f\x63\x6f\x6d\x6d\x69\x74\x03\x4f\x46\x46";
3813 const SESS_STATE_SCHEMA_OK: &[u8] =
3814 b"\x00\x00\x00\x02\x40\x00\x00\x00\x07\x01\x05\x04\x74\x65\x73\x74";
3815 const SESS_STATE_TRACK_OK: &[u8] = b"\x00\x00\x00\x02\x40\x00\x00\x00\x04\x02\x02\x01\x31";
3816 const EOF: &[u8] = b"\xfe\x00\x00\x02\x00";
3817
3818 OkPacketDeserializer::<ResultSetTerminator>::deserialize(
3820 CapabilityFlags::empty(),
3821 &mut ParseBuf(PLAIN_OK),
3822 )
3823 .unwrap_err();
3824
3825 let ok_packet: OkPacket = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3826 CapabilityFlags::empty(),
3827 &mut ParseBuf(PLAIN_OK),
3828 )
3829 .unwrap()
3830 .into();
3831 assert_eq!(ok_packet.affected_rows(), 1);
3832 assert_eq!(ok_packet.last_insert_id(), None);
3833 assert_eq!(
3834 ok_packet.status_flags(),
3835 StatusFlags::SERVER_STATUS_AUTOCOMMIT
3836 );
3837 assert_eq!(ok_packet.warnings(), 0);
3838 assert_eq!(ok_packet.info_ref(), None);
3839 assert_eq!(ok_packet.session_state_info_ref(), None);
3840
3841 let ok_packet: OkPacket = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3842 CapabilityFlags::CLIENT_SESSION_TRACK,
3843 &mut ParseBuf(PLAIN_OK),
3844 )
3845 .unwrap()
3846 .into();
3847 assert_eq!(ok_packet.affected_rows(), 1);
3848 assert_eq!(ok_packet.last_insert_id(), None);
3849 assert_eq!(
3850 ok_packet.status_flags(),
3851 StatusFlags::SERVER_STATUS_AUTOCOMMIT
3852 );
3853 assert_eq!(ok_packet.warnings(), 0);
3854 assert_eq!(ok_packet.info_ref(), None);
3855 assert_eq!(ok_packet.session_state_info_ref(), None);
3856
3857 let ok_packet: OkPacket = OkPacketDeserializer::<ResultSetTerminator>::deserialize(
3858 CapabilityFlags::CLIENT_SESSION_TRACK,
3859 &mut ParseBuf(RESULT_SET_TERMINATOR),
3860 )
3861 .unwrap()
3862 .into();
3863 assert_eq!(ok_packet.affected_rows(), 0);
3864 assert_eq!(ok_packet.last_insert_id(), None);
3865 assert_eq!(ok_packet.status_flags(), StatusFlags::empty());
3866 assert_eq!(ok_packet.warnings(), 0);
3867 assert_eq!(
3868 ok_packet.info_str(),
3869 Some(Cow::Borrowed(
3870 "Read 1 rows, 1.00 B in 0.002 sec., 611.34 rows/sec., 611.34 B/sec."
3871 ))
3872 );
3873 assert_eq!(ok_packet.session_state_info_ref(), None);
3874
3875 let ok_packet: OkPacket = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3876 CapabilityFlags::CLIENT_SESSION_TRACK,
3877 &mut ParseBuf(SESS_STATE_SYS_VAR_OK),
3878 )
3879 .unwrap()
3880 .into();
3881 assert_eq!(ok_packet.affected_rows(), 0);
3882 assert_eq!(ok_packet.last_insert_id(), None);
3883 assert_eq!(
3884 ok_packet.status_flags(),
3885 StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3886 );
3887 assert_eq!(ok_packet.warnings(), 0);
3888 assert_eq!(ok_packet.info_ref(), None);
3889 let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3890
3891 match sess_state_info.decode().unwrap() {
3892 SessionStateChange::SystemVariables(mut vals) => {
3893 let val = vals.pop().unwrap();
3894 assert_eq!(val.name_bytes(), b"autocommit");
3895 assert_eq!(val.value_bytes(), b"OFF");
3896 assert!(vals.is_empty());
3897 }
3898 _ => panic!(),
3899 }
3900
3901 let ok_packet: OkPacket = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3902 CapabilityFlags::CLIENT_SESSION_TRACK,
3903 &mut ParseBuf(SESS_STATE_SCHEMA_OK),
3904 )
3905 .unwrap()
3906 .into();
3907 assert_eq!(ok_packet.affected_rows(), 0);
3908 assert_eq!(ok_packet.last_insert_id(), None);
3909 assert_eq!(
3910 ok_packet.status_flags(),
3911 StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3912 );
3913 assert_eq!(ok_packet.warnings(), 0);
3914 assert_eq!(ok_packet.info_ref(), None);
3915 let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3916 match sess_state_info.decode().unwrap() {
3917 SessionStateChange::Schema(schema) => assert_eq!(schema.as_bytes(), b"test"),
3918 _ => panic!(),
3919 }
3920
3921 let ok_packet: OkPacket = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3922 CapabilityFlags::CLIENT_SESSION_TRACK,
3923 &mut ParseBuf(SESS_STATE_TRACK_OK),
3924 )
3925 .unwrap()
3926 .into();
3927 assert_eq!(ok_packet.affected_rows(), 0);
3928 assert_eq!(ok_packet.last_insert_id(), None);
3929 assert_eq!(
3930 ok_packet.status_flags(),
3931 StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3932 );
3933 assert_eq!(ok_packet.warnings(), 0);
3934 assert_eq!(ok_packet.info_ref(), None);
3935 let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3936 assert_eq!(
3937 sess_state_info.decode().unwrap(),
3938 SessionStateChange::IsTracked(true),
3939 );
3940
3941 let ok_packet: OkPacket = OkPacketDeserializer::<OldEofPacket>::deserialize(
3942 CapabilityFlags::CLIENT_SESSION_TRACK,
3943 &mut ParseBuf(EOF),
3944 )
3945 .unwrap()
3946 .into();
3947 assert_eq!(ok_packet.affected_rows(), 0);
3948 assert_eq!(ok_packet.last_insert_id(), None);
3949 assert_eq!(
3950 ok_packet.status_flags(),
3951 StatusFlags::SERVER_STATUS_AUTOCOMMIT
3952 );
3953 assert_eq!(ok_packet.warnings(), 0);
3954 assert_eq!(ok_packet.info_ref(), None);
3955 assert_eq!(ok_packet.session_state_info_ref(), None);
3956 }
3957
3958 #[test]
3959 fn should_build_handshake_response() {
3960 let flags_without_db_name = CapabilityFlags::from_bits_truncate(0x81aea205);
3961 let response = HandshakeResponse::new(
3962 Some(&[][..]),
3963 (5u16, 5, 5),
3964 Some(&b"root"[..]),
3965 None::<&'static [u8]>,
3966 Some(AuthPlugin::MysqlNativePassword),
3967 flags_without_db_name,
3968 None,
3969 1_u32.to_be(),
3970 );
3971 let mut actual = Vec::new();
3972 response.serialize(&mut actual);
3973
3974 let expected: Vec<u8> = [
3975 0x05, 0xa2, 0xae, 0x81, 0x00, 0x00, 0x00, 0x01, 0x2d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
3979 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
3983 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, ]
3985 .to_vec();
3986
3987 assert_eq!(expected, actual);
3988
3989 let flags_with_db_name = flags_without_db_name | CapabilityFlags::CLIENT_CONNECT_WITH_DB;
3990 let response = HandshakeResponse::new(
3991 Some(&[][..]),
3992 (5u16, 5, 5),
3993 Some(&b"root"[..]),
3994 Some(&b"mydb"[..]),
3995 Some(AuthPlugin::MysqlNativePassword),
3996 flags_with_db_name,
3997 None,
3998 1_u32.to_be(),
3999 );
4000 let mut actual = Vec::new();
4001 response.serialize(&mut actual);
4002
4003 let expected: Vec<u8> = [
4004 0x0d, 0xa2, 0xae, 0x81, 0x00, 0x00, 0x00, 0x01, 0x2d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4008 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x00, 0x6d, 0x79, 0x64, 0x62, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4013 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, ]
4015 .to_vec();
4016
4017 assert_eq!(expected, actual);
4018
4019 let response = HandshakeResponse::new(
4020 Some(&[][..]),
4021 (5u16, 5, 5),
4022 Some(&b"root"[..]),
4023 Some(&b"mydb"[..]),
4024 Some(AuthPlugin::MysqlNativePassword),
4025 flags_without_db_name,
4026 None,
4027 1_u32.to_be(),
4028 );
4029 let mut actual = Vec::new();
4030 response.serialize(&mut actual);
4031 assert_eq!(expected, actual);
4032
4033 let response = HandshakeResponse::new(
4034 Some(&[][..]),
4035 (5u16, 5, 5),
4036 Some(&b"root"[..]),
4037 Some(&[][..]),
4038 Some(AuthPlugin::MysqlNativePassword),
4039 flags_with_db_name,
4040 None,
4041 1_u32.to_be(),
4042 );
4043 let mut actual = Vec::new();
4044 response.serialize(&mut actual);
4045
4046 let expected: Vec<u8> = [
4047 0x0d, 0xa2, 0xae, 0x81, 0x00, 0x00, 0x00, 0x01, 0x2d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4051 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x00, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4056 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, ]
4058 .to_vec();
4059 assert_eq!(expected, actual);
4060 }
4061
4062 #[test]
4063 fn parse_str_to_sid() {
4064 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:23";
4065 let sid = input.parse::<Sid>().unwrap();
4066 let expected_sid = Uuid::parse_str("3E11FA47-71CA-11E1-9E33-C80AA9429562").unwrap();
4067 assert_eq!(sid.uuid, *expected_sid.as_bytes());
4068 assert_eq!(sid.intervals.len(), 1);
4069 assert_eq!(sid.intervals[0].start.0, 23);
4070 assert_eq!(sid.intervals[0].end.0, 24);
4071
4072 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15";
4073 let sid = input.parse::<Sid>().unwrap();
4074 assert_eq!(sid.uuid, *expected_sid.as_bytes());
4075 assert_eq!(sid.intervals.len(), 2);
4076 assert_eq!(sid.intervals[0].start.0, 1);
4077 assert_eq!(sid.intervals[0].end.0, 6);
4078 assert_eq!(sid.intervals[1].start.0, 10);
4079 assert_eq!(sid.intervals[1].end.0, 16);
4080
4081 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562";
4082 let e = input.parse::<Sid>().unwrap_err();
4083 assert_eq!(
4084 e.to_string(),
4085 "invalid sid format: 3E11FA47-71CA-11E1-9E33-C80AA9429562".to_string()
4086 );
4087
4088 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-";
4089 let e = input.parse::<Sid>().unwrap_err();
4090 assert_eq!(e.to_string(), "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-, error: cannot parse integer from empty string".to_string());
4091
4092 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa";
4093 let e = input.parse::<Sid>().unwrap_err();
4094 assert_eq!(e.to_string(), "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa, error: invalid digit found in string".to_string());
4095
4096 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:0-3";
4097 let e = input.parse::<Sid>().unwrap_err();
4098 assert_eq!(e.to_string(), "Gno can't be zero".to_string());
4099
4100 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:4-3";
4101 let e = input.parse::<Sid>().unwrap_err();
4102 assert_eq!(
4103 e.to_string(),
4104 "start(4) >= end(4) in GnoInterval".to_string()
4105 );
4106 }
4107
4108 #[test]
4109 fn should_parse_rsa_public_key_response_packet() {
4110 const PUBLIC_RSA_KEY_RESPONSE: &[u8] = b"\x01test";
4111
4112 let rsa_public_key_response =
4113 PublicKeyResponse::deserialize((), &mut ParseBuf(PUBLIC_RSA_KEY_RESPONSE));
4114
4115 assert!(rsa_public_key_response.is_ok());
4116 assert_eq!(rsa_public_key_response.unwrap().rsa_key(), "test");
4117 }
4118
4119 #[test]
4120 fn should_build_rsa_public_key_response_packet() {
4121 let rsa_public_key_response = PublicKeyResponse::new("test".as_bytes());
4122
4123 let mut actual = Vec::new();
4124 rsa_public_key_response.serialize(&mut actual);
4125
4126 let expected = b"\x01test".to_vec();
4127
4128 assert_eq!(expected, actual);
4129 }
4130}