1use btoi::btoi;
10use bytes::BufMut;
11use regex::bytes::Regex;
12use uuid::Uuid;
13
14use std::str::FromStr;
15use std::sync::{Arc, LazyLock};
16use std::{
17 borrow::Cow, cmp::max, collections::HashMap, convert::TryFrom, fmt, io, marker::PhantomData,
18};
19
20use crate::collations::CollationId;
21use crate::scramble::create_response_for_ed25519;
22use crate::{
23 constants::{
24 CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, MAX_PAYLOAD_LEN,
25 SessionStateType, StatusFlags, StmtExecuteParamFlags, StmtExecuteParamsFlags,
26 },
27 io::{BufMutExt, ParseBuf},
28 misc::{
29 lenenc_str_len,
30 raw::{
31 Const, Either, RawBytes, RawConst, RawInt, Skip,
32 bytes::{
33 BareBytes, ConstBytes, ConstBytesValue, EofBytes, LenEnc, NullBytes, U8Bytes,
34 U32Bytes,
35 },
36 int::{ConstU8, ConstU32, LeU16, LeU24, LeU32, LeU32LowerHalf, LeU32UpperHalf, LeU64},
37 seq::Seq,
38 },
39 unexpected_buf_eof,
40 },
41 proto::{MyDeserialize, MySerialize},
42 value::{ClientSide, SerializationSide, Value},
43};
44
45use self::session_state_change::SessionStateChange;
46
47static MARIADB_VERSION_RE: LazyLock<Regex> =
48 LazyLock::new(|| Regex::new(r"^(?:5.5.5-)?(\d{1,2})\.(\d{1,2})\.(\d{1,3})-MariaDB").unwrap());
49static VERSION_RE: LazyLock<Regex> =
50 LazyLock::new(|| Regex::new(r"^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)").unwrap());
51
52macro_rules! define_header {
53 ($name:ident, $err:ident($msg:literal), $val:literal) => {
54 #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
55 #[error($msg)]
56 pub struct $err;
57 pub type $name = crate::misc::raw::int::ConstU8<$err, $val>;
58 };
59 ($name:ident, $cmd:ident, $err:ident) => {
60 #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
61 #[error("Invalid header for {}", stringify!($cmd))]
62 pub struct $err;
63 pub type $name = crate::misc::raw::int::ConstU8<$err, { Command::$cmd as u8 }>;
64 };
65}
66
67macro_rules! define_const {
68 ($kind:ident, $name:ident, $err:ident($msg:literal), $val:literal) => {
69 #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
70 #[error($msg)]
71 pub struct $err;
72 pub type $name = $kind<$err, $val>;
73 };
74}
75
76macro_rules! define_const_bytes {
77 ($vname:ident, $name:ident, $err:ident($msg:literal), $val:expr_2021, $len:literal) => {
78 #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
79 #[error($msg)]
80 pub struct $err;
81
82 #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
83 pub struct $vname;
84
85 impl ConstBytesValue<$len> for $vname {
86 const VALUE: [u8; $len] = $val;
87 type Error = $err;
88 }
89
90 pub type $name = ConstBytes<$vname, $len>;
91 };
92}
93
94pub mod binlog_request;
95pub mod caching_sha2_password;
96pub mod session_state_change;
97
98define_const_bytes!(
99 Catalog,
100 ColumnDefinitionCatalog,
101 InvalidCatalog("Invalid catalog value in the column definition"),
102 *b"\x03def",
103 4
104);
105
106define_const!(
107 ConstU8,
108 FixedLengthFieldsLen,
109 InvalidFixedLengthFieldsLen("Invalid fixed length field length in the column definition"),
110 0x0c
111);
112
113#[derive(Debug, Default, Clone, Eq, PartialEq)]
115struct ColumnMeta<'a> {
116 schema: RawBytes<'a, LenEnc>,
117 table: RawBytes<'a, LenEnc>,
118 org_table: RawBytes<'a, LenEnc>,
119 name: RawBytes<'a, LenEnc>,
120 org_name: RawBytes<'a, LenEnc>,
121}
122
123impl ColumnMeta<'_> {
124 pub fn into_owned(self) -> ColumnMeta<'static> {
125 ColumnMeta {
126 schema: self.schema.into_owned(),
127 table: self.table.into_owned(),
128 org_table: self.org_table.into_owned(),
129 name: self.name.into_owned(),
130 org_name: self.org_name.into_owned(),
131 }
132 }
133
134 pub fn schema_ref(&self) -> &[u8] {
136 self.schema.as_bytes()
137 }
138
139 pub fn schema_str(&self) -> Cow<'_, str> {
141 String::from_utf8_lossy(self.schema_ref())
142 }
143
144 pub fn table_ref(&self) -> &[u8] {
146 self.table.as_bytes()
147 }
148
149 pub fn table_str(&self) -> Cow<'_, str> {
151 String::from_utf8_lossy(self.table_ref())
152 }
153
154 pub fn org_table_ref(&self) -> &[u8] {
158 self.org_table.as_bytes()
159 }
160
161 pub fn org_table_str(&self) -> Cow<'_, str> {
163 String::from_utf8_lossy(self.org_table_ref())
164 }
165
166 pub fn name_ref(&self) -> &[u8] {
168 self.name.as_bytes()
169 }
170
171 pub fn name_str(&self) -> Cow<'_, str> {
173 String::from_utf8_lossy(self.name_ref())
174 }
175
176 pub fn org_name_ref(&self) -> &[u8] {
180 self.org_name.as_bytes()
181 }
182
183 pub fn org_name_str(&self) -> Cow<'_, str> {
185 String::from_utf8_lossy(self.org_name_ref())
186 }
187}
188
189impl<'de> MyDeserialize<'de> for ColumnMeta<'de> {
190 const SIZE: Option<usize> = None;
191 type Ctx = ();
192
193 fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
194 Ok(Self {
195 schema: buf.parse_unchecked(())?,
196 table: buf.parse_unchecked(())?,
197 org_table: buf.parse_unchecked(())?,
198 name: buf.parse_unchecked(())?,
199 org_name: buf.parse_unchecked(())?,
200 })
201 }
202}
203
204impl MySerialize for ColumnMeta<'_> {
205 fn serialize(&self, buf: &mut Vec<u8>) {
206 self.schema.serialize(&mut *buf);
207 self.table.serialize(&mut *buf);
208 self.org_table.serialize(&mut *buf);
209 self.name.serialize(&mut *buf);
210 self.org_name.serialize(&mut *buf);
211 }
212}
213
214#[derive(Debug, Clone, Eq, PartialEq)]
216pub struct Column {
217 catalog: ColumnDefinitionCatalog,
218 meta: Arc<ColumnMeta<'static>>,
219 fixed_length_fields_len: FixedLengthFieldsLen,
220 column_length: RawInt<LeU32>,
221 character_set: RawInt<LeU16>,
222 column_type: Const<ColumnType, u8>,
223 flags: Const<ColumnFlags, LeU16>,
224 decimals: RawInt<u8>,
225 __filler: Skip<2>,
226 }
228
229impl<'de> MyDeserialize<'de> for Column {
230 const SIZE: Option<usize> = None;
231 type Ctx = ();
232
233 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
234 let catalog = buf.parse(())?;
235 let meta = Arc::new(buf.parse::<ColumnMeta<'_>>(())?.into_owned());
236 let mut buf: ParseBuf<'_> = buf.parse(13)?;
237
238 Ok(Column {
239 catalog,
240 meta,
241 fixed_length_fields_len: buf.parse_unchecked(())?,
242 character_set: buf.parse_unchecked(())?,
243 column_length: buf.parse_unchecked(())?,
244 column_type: buf.parse_unchecked(())?,
245 flags: buf.parse_unchecked(())?,
246 decimals: buf.parse_unchecked(())?,
247 __filler: buf.parse_unchecked(())?,
248 })
249 }
250}
251
252impl MySerialize for Column {
253 fn serialize(&self, buf: &mut Vec<u8>) {
254 self.catalog.serialize(&mut *buf);
255 self.meta.serialize(&mut *buf);
256 self.fixed_length_fields_len.serialize(&mut *buf);
257 self.column_length.serialize(&mut *buf);
258 self.character_set.serialize(&mut *buf);
259 self.column_type.serialize(&mut *buf);
260 self.flags.serialize(&mut *buf);
261 self.decimals.serialize(&mut *buf);
262 self.__filler.serialize(&mut *buf);
263 }
264}
265
266impl Column {
267 pub fn new(column_type: ColumnType) -> Self {
268 Self {
269 catalog: Default::default(),
270 meta: Default::default(),
271 fixed_length_fields_len: Default::default(),
272 column_length: Default::default(),
273 character_set: Default::default(),
274 flags: Default::default(),
275 column_type: Const::new(column_type),
276 decimals: Default::default(),
277 __filler: Skip,
278 }
279 }
280
281 pub fn with_schema(mut self, schema: &[u8]) -> Self {
282 Arc::make_mut(&mut self.meta).schema = RawBytes::new(schema).into_owned();
283 self
284 }
285
286 pub fn with_table(mut self, table: &[u8]) -> Self {
287 Arc::make_mut(&mut self.meta).table = RawBytes::new(table).into_owned();
288 self
289 }
290
291 pub fn with_org_table(mut self, org_table: &[u8]) -> Self {
292 Arc::make_mut(&mut self.meta).org_table = RawBytes::new(org_table).into_owned();
293 self
294 }
295
296 pub fn with_name(mut self, name: &[u8]) -> Self {
297 Arc::make_mut(&mut self.meta).name = RawBytes::new(name).into_owned();
298 self
299 }
300
301 pub fn with_org_name(mut self, org_name: &[u8]) -> Self {
302 Arc::make_mut(&mut self.meta).org_name = RawBytes::new(org_name).into_owned();
303 self
304 }
305
306 pub fn with_flags(mut self, flags: ColumnFlags) -> Self {
307 self.flags = Const::new(flags);
308 self
309 }
310
311 pub fn with_column_length(mut self, column_length: u32) -> Self {
312 self.column_length = RawInt::new(column_length);
313 self
314 }
315
316 pub fn with_character_set(mut self, character_set: u16) -> Self {
317 self.character_set = RawInt::new(character_set);
318 self
319 }
320
321 pub fn with_decimals(mut self, decimals: u8) -> Self {
322 self.decimals = RawInt::new(decimals);
323 self
324 }
325
326 pub fn column_length(&self) -> u32 {
330 *self.column_length
331 }
332
333 pub fn column_type(&self) -> ColumnType {
335 *self.column_type
336 }
337
338 pub fn character_set(&self) -> u16 {
340 *self.character_set
341 }
342
343 pub fn flags(&self) -> ColumnFlags {
345 *self.flags
346 }
347
348 pub fn decimals(&self) -> u8 {
356 *self.decimals
357 }
358
359 #[inline(always)]
361 pub fn schema_ref(&self) -> &[u8] {
362 self.meta.schema_ref()
363 }
364
365 #[inline(always)]
367 pub fn schema_str(&self) -> Cow<'_, str> {
368 self.meta.schema_str()
369 }
370
371 #[inline(always)]
373 pub fn table_ref(&self) -> &[u8] {
374 self.meta.table_ref()
375 }
376
377 #[inline(always)]
379 pub fn table_str(&self) -> Cow<'_, str> {
380 self.meta.table_str()
381 }
382
383 #[inline(always)]
387 pub fn org_table_ref(&self) -> &[u8] {
388 self.meta.org_table_ref()
389 }
390
391 #[inline(always)]
393 pub fn org_table_str(&self) -> Cow<'_, str> {
394 self.meta.org_table_str()
395 }
396
397 #[inline(always)]
399 pub fn name_ref(&self) -> &[u8] {
400 self.meta.name_ref()
401 }
402
403 #[inline(always)]
405 pub fn name_str(&self) -> Cow<'_, str> {
406 self.meta.name_str()
407 }
408
409 #[inline(always)]
413 pub fn org_name_ref(&self) -> &[u8] {
414 self.meta.org_name_ref()
415 }
416
417 #[inline(always)]
419 pub fn org_name_str(&self) -> Cow<'_, str> {
420 self.meta.org_name_str()
421 }
422}
423
424#[derive(Debug, Clone, Eq, PartialEq)]
426pub struct SessionStateInfo<'a> {
427 data_type: Const<SessionStateType, u8>,
428 data: RawBytes<'a, LenEnc>,
429}
430
431impl SessionStateInfo<'_> {
432 pub fn into_owned(self) -> SessionStateInfo<'static> {
433 let SessionStateInfo { data_type, data } = self;
434 SessionStateInfo {
435 data_type,
436 data: data.into_owned(),
437 }
438 }
439
440 pub fn data_type(&self) -> SessionStateType {
441 *self.data_type
442 }
443
444 pub fn data_ref(&self) -> &[u8] {
446 self.data.as_bytes()
447 }
448
449 pub fn decode(&self) -> io::Result<SessionStateChange<'_>> {
451 ParseBuf(self.data.as_bytes()).parse_unchecked(*self.data_type)
452 }
453}
454
455impl<'de> MyDeserialize<'de> for SessionStateInfo<'de> {
456 const SIZE: Option<usize> = None;
457 type Ctx = ();
458
459 fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
460 Ok(SessionStateInfo {
461 data_type: buf.parse(())?,
462 data: buf.parse(())?,
463 })
464 }
465}
466
467impl MySerialize for SessionStateInfo<'_> {
468 fn serialize(&self, buf: &mut Vec<u8>) {
469 self.data_type.serialize(&mut *buf);
470 self.data.serialize(buf);
471 }
472}
473
474#[derive(Debug, Clone, Eq, PartialEq)]
476pub struct OkPacketBody<'a> {
477 affected_rows: RawInt<LenEnc>,
478 last_insert_id: RawInt<LenEnc>,
479 status_flags: Const<StatusFlags, LeU16>,
480 warnings: RawInt<LeU16>,
481 info: RawBytes<'a, LenEnc>,
482 session_state_info: RawBytes<'a, LenEnc>,
483}
484
485pub trait OkPacketKind {
489 const HEADER: u8;
490
491 fn parse_body<'de>(
492 capabilities: CapabilityFlags,
493 buf: &mut ParseBuf<'de>,
494 ) -> io::Result<OkPacketBody<'de>>;
495}
496
497#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
499pub struct ResultSetTerminator;
500
501impl OkPacketKind for ResultSetTerminator {
502 const HEADER: u8 = 0xFE;
503
504 fn parse_body<'de>(
505 capabilities: CapabilityFlags,
506 buf: &mut ParseBuf<'de>,
507 ) -> io::Result<OkPacketBody<'de>> {
508 buf.parse::<RawInt<LenEnc>>(())?;
513 buf.parse::<RawInt<LenEnc>>(())?;
514
515 let mut sbuf: ParseBuf<'_> = buf.parse(4)?;
517 let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
518 let warnings = sbuf.parse_unchecked(())?;
519
520 let (info, session_state_info) =
521 if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
522 let info = buf.parse(())?;
523 let session_state_info =
524 if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
525 buf.parse(())?
526 } else {
527 RawBytes::default()
528 };
529 (info, session_state_info)
530 } else if !buf.is_empty() && buf.0[0] > 0 {
531 let info = buf.parse(())?;
535 (info, RawBytes::default())
536 } else {
537 (RawBytes::default(), RawBytes::default())
538 };
539
540 Ok(OkPacketBody {
541 affected_rows: RawInt::new(0),
542 last_insert_id: RawInt::new(0),
543 status_flags,
544 warnings,
545 info,
546 session_state_info,
547 })
548 }
549}
550
551#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
553pub struct OldEofPacket;
554
555impl OkPacketKind for OldEofPacket {
556 const HEADER: u8 = 0xFE;
557
558 fn parse_body<'de>(
559 _: CapabilityFlags,
560 buf: &mut ParseBuf<'de>,
561 ) -> io::Result<OkPacketBody<'de>> {
562 let mut buf: ParseBuf<'_> = buf.parse(4)?;
564 let warnings = buf.parse_unchecked(())?;
565 let status_flags = buf.parse_unchecked(())?;
566
567 Ok(OkPacketBody {
568 affected_rows: RawInt::new(0),
569 last_insert_id: RawInt::new(0),
570 status_flags,
571 warnings,
572 info: RawBytes::new(&[][..]),
573 session_state_info: RawBytes::new(&[][..]),
574 })
575 }
576}
577
578#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
580pub struct NetworkStreamTerminator;
581
582impl OkPacketKind for NetworkStreamTerminator {
583 const HEADER: u8 = 0xFE;
584
585 fn parse_body<'de>(
586 flags: CapabilityFlags,
587 buf: &mut ParseBuf<'de>,
588 ) -> io::Result<OkPacketBody<'de>> {
589 OldEofPacket::parse_body(flags, buf)
590 }
591}
592
593#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
595pub struct CommonOkPacket;
596
597impl OkPacketKind for CommonOkPacket {
598 const HEADER: u8 = 0x00;
599
600 fn parse_body<'de>(
601 capabilities: CapabilityFlags,
602 buf: &mut ParseBuf<'de>,
603 ) -> io::Result<OkPacketBody<'de>> {
604 let affected_rows = buf.parse(())?;
605 let last_insert_id = buf.parse(())?;
606
607 let mut sbuf: ParseBuf<'_> = buf.parse(4)?;
609 let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
610 let warnings = sbuf.parse_unchecked(())?;
611
612 let (info, session_state_info) =
613 if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
614 let info = buf.parse(())?;
615 let session_state_info =
616 if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
617 buf.parse(())?
618 } else {
619 RawBytes::default()
620 };
621 (info, session_state_info)
622 } else if !buf.is_empty() && buf.0[0] > 0 {
623 let info = buf.parse(())?;
627 (info, RawBytes::default())
628 } else {
629 (RawBytes::default(), RawBytes::default())
630 };
631
632 Ok(OkPacketBody {
633 affected_rows,
634 last_insert_id,
635 status_flags,
636 warnings,
637 info,
638 session_state_info,
639 })
640 }
641}
642
643impl<'a> TryFrom<OkPacketBody<'a>> for OkPacket<'a> {
644 type Error = io::Error;
645
646 fn try_from(body: OkPacketBody<'a>) -> io::Result<Self> {
647 Ok(OkPacket {
648 affected_rows: *body.affected_rows,
649 last_insert_id: if *body.last_insert_id == 0 {
650 None
651 } else {
652 Some(*body.last_insert_id)
653 },
654 status_flags: *body.status_flags,
655 warnings: *body.warnings,
656 info: if !body.info.is_empty() {
657 Some(body.info)
658 } else {
659 None
660 },
661 session_state_info: if !body.session_state_info.is_empty() {
662 Some(body.session_state_info)
663 } else {
664 None
665 },
666 })
667 }
668}
669
670#[derive(Debug, Clone, Eq, PartialEq)]
672pub struct OkPacket<'a> {
673 affected_rows: u64,
674 last_insert_id: Option<u64>,
675 status_flags: StatusFlags,
676 warnings: u16,
677 info: Option<RawBytes<'a, LenEnc>>,
678 session_state_info: Option<RawBytes<'a, LenEnc>>,
679}
680
681impl OkPacket<'_> {
682 pub fn into_owned(self) -> OkPacket<'static> {
683 OkPacket {
684 affected_rows: self.affected_rows,
685 last_insert_id: self.last_insert_id,
686 status_flags: self.status_flags,
687 warnings: self.warnings,
688 info: self.info.map(|x| x.into_owned()),
689 session_state_info: self.session_state_info.map(|x| x.into_owned()),
690 }
691 }
692
693 pub fn affected_rows(&self) -> u64 {
695 self.affected_rows
696 }
697
698 pub fn last_insert_id(&self) -> Option<u64> {
700 self.last_insert_id
701 }
702
703 pub fn status_flags(&self) -> StatusFlags {
705 self.status_flags
706 }
707
708 pub fn warnings(&self) -> u16 {
710 self.warnings
711 }
712
713 pub fn info_ref(&self) -> Option<&[u8]> {
715 self.info.as_ref().map(|x| x.as_bytes())
716 }
717
718 pub fn info_str(&self) -> Option<Cow<'_, str>> {
720 self.info.as_ref().map(|x| x.as_str())
721 }
722
723 pub fn session_state_info_ref(&self) -> Option<&[u8]> {
725 self.session_state_info.as_ref().map(|x| x.as_bytes())
726 }
727
728 pub fn session_state_info(&self) -> io::Result<Vec<SessionStateInfo<'_>>> {
730 self.session_state_info_ref()
731 .map(|data| {
732 let mut data = ParseBuf(data);
733 let mut entries = Vec::new();
734 while !data.is_empty() {
735 entries.push(data.parse(())?);
736 }
737 Ok(entries)
738 })
739 .transpose()
740 .map(|x| x.unwrap_or_default())
741 }
742}
743
744#[derive(Debug, Clone, PartialEq, Eq)]
745pub struct OkPacketDeserializer<'de, T>(OkPacket<'de>, PhantomData<T>);
746
747impl<'de, T> OkPacketDeserializer<'de, T> {
748 pub fn into_inner(self) -> OkPacket<'de> {
749 self.0
750 }
751}
752
753impl<'de, T> From<OkPacketDeserializer<'de, T>> for OkPacket<'de> {
754 fn from(x: OkPacketDeserializer<'de, T>) -> Self {
755 x.0
756 }
757}
758
759#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
760#[error("Invalid OK packet header")]
761pub struct InvalidOkPacketHeader;
762
763impl<'de, T: OkPacketKind> MyDeserialize<'de> for OkPacketDeserializer<'de, T> {
764 const SIZE: Option<usize> = None;
765 type Ctx = CapabilityFlags;
766
767 fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
768 if *buf.parse::<RawInt<u8>>(())? == T::HEADER {
769 let body = T::parse_body(capabilities, buf)?;
770 let ok = OkPacket::try_from(body)?;
771 Ok(Self(ok, PhantomData))
772 } else {
773 Err(io::Error::new(
774 io::ErrorKind::InvalidData,
775 InvalidOkPacketHeader,
776 ))
777 }
778 }
779}
780
781#[derive(Debug, Clone, Eq, PartialEq)]
783pub struct ProgressReport<'a> {
784 stage: RawInt<u8>,
785 max_stage: RawInt<u8>,
786 progress: RawInt<LeU24>,
787 stage_info: RawBytes<'a, LenEnc>,
788}
789
790impl<'a> ProgressReport<'a> {
791 pub fn new(
792 stage: u8,
793 max_stage: u8,
794 progress: u32,
795 stage_info: impl Into<Cow<'a, [u8]>>,
796 ) -> ProgressReport<'a> {
797 ProgressReport {
798 stage: RawInt::new(stage),
799 max_stage: RawInt::new(max_stage),
800 progress: RawInt::new(progress),
801 stage_info: RawBytes::new(stage_info),
802 }
803 }
804
805 pub fn stage(&self) -> u8 {
807 *self.stage
808 }
809
810 pub fn max_stage(&self) -> u8 {
811 *self.max_stage
812 }
813
814 pub fn progress(&self) -> u32 {
816 *self.progress
817 }
818
819 pub fn stage_info_ref(&self) -> &[u8] {
821 self.stage_info.as_bytes()
822 }
823
824 pub fn stage_info_str(&self) -> Cow<'_, str> {
826 self.stage_info.as_str()
827 }
828
829 pub fn into_owned(self) -> ProgressReport<'static> {
830 ProgressReport {
831 stage: self.stage,
832 max_stage: self.max_stage,
833 progress: self.progress,
834 stage_info: self.stage_info.into_owned(),
835 }
836 }
837}
838
839impl<'de> MyDeserialize<'de> for ProgressReport<'de> {
840 const SIZE: Option<usize> = None;
841 type Ctx = ();
842
843 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
844 let mut sbuf: ParseBuf<'_> = buf.parse(6)?;
845
846 sbuf.skip(1); Ok(ProgressReport {
849 stage: sbuf.parse_unchecked(())?,
850 max_stage: sbuf.parse_unchecked(())?,
851 progress: sbuf.parse_unchecked(())?,
852 stage_info: buf.parse(())?,
853 })
854 }
855}
856
857impl MySerialize for ProgressReport<'_> {
858 fn serialize(&self, buf: &mut Vec<u8>) {
859 buf.put_u8(1);
860 self.stage.serialize(&mut *buf);
861 self.max_stage.serialize(&mut *buf);
862 self.progress.serialize(&mut *buf);
863 self.stage_info.serialize(buf);
864 }
865}
866
867impl fmt::Display for ProgressReport<'_> {
868 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
869 write!(
870 f,
871 "Stage: {} of {} '{}' {:.2}% of stage done",
872 self.stage(),
873 self.max_stage(),
874 self.progress(),
875 self.stage_info_str()
876 )
877 }
878}
879
880define_header!(
881 ErrPacketHeader,
882 InvalidErrPacketHeader("Invalid error packet header"),
883 0xFF
884);
885
886#[derive(Debug, Clone, PartialEq)]
890pub enum ErrPacket<'a> {
891 Error(ServerError<'a>),
892 Progress(ProgressReport<'a>),
893}
894
895impl ErrPacket<'_> {
896 pub fn is_error(&self) -> bool {
898 matches!(self, ErrPacket::Error { .. })
899 }
900
901 pub fn is_progress_report(&self) -> bool {
903 !self.is_error()
904 }
905
906 pub fn progress_report(&self) -> &ProgressReport<'_> {
908 match *self {
909 ErrPacket::Progress(ref progress_report) => progress_report,
910 _ => panic!("This ErrPacket does not contains progress report"),
911 }
912 }
913
914 pub fn server_error(&self) -> &ServerError<'_> {
916 match self {
917 ErrPacket::Error(error) => error,
918 ErrPacket::Progress(_) => panic!("This ErrPacket does not contain a ServerError"),
919 }
920 }
921}
922
923impl<'de> MyDeserialize<'de> for ErrPacket<'de> {
924 const SIZE: Option<usize> = None;
925 type Ctx = CapabilityFlags;
926
927 fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
928 let mut sbuf: ParseBuf<'_> = buf.parse(3)?;
929 sbuf.parse_unchecked::<ErrPacketHeader>(())?;
930 let code: RawInt<LeU16> = sbuf.parse_unchecked(())?;
931
932 if *code == 0xFFFF && capabilities.contains(CapabilityFlags::CLIENT_PROGRESS_OBSOLETE) {
933 buf.parse(()).map(ErrPacket::Progress)
934 } else {
935 buf.parse((
936 *code,
937 capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41),
938 ))
939 .map(ErrPacket::Error)
940 }
941 }
942}
943
944impl MySerialize for ErrPacket<'_> {
945 fn serialize(&self, buf: &mut Vec<u8>) {
946 ErrPacketHeader::new().serialize(&mut *buf);
947 match self {
948 ErrPacket::Error(server_error) => {
949 server_error.code.serialize(&mut *buf);
950 server_error.serialize(buf);
951 }
952 ErrPacket::Progress(progress_report) => {
953 RawInt::<LeU16>::new(0xFFFF).serialize(&mut *buf);
954 progress_report.serialize(buf);
955 }
956 }
957 }
958}
959
960impl fmt::Display for ErrPacket<'_> {
961 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
962 match self {
963 ErrPacket::Error(server_error) => write!(f, "{}", server_error),
964 ErrPacket::Progress(progress_report) => write!(f, "{}", progress_report),
965 }
966 }
967}
968
969define_header!(
970 SqlStateMarker,
971 InvalidSqlStateMarker("Invalid SqlStateMarker value"),
972 b'#'
973);
974
975#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
977pub struct SqlState {
978 __state_marker: SqlStateMarker,
979 state: [u8; 5],
980}
981
982impl SqlState {
983 pub fn new(state: [u8; 5]) -> Self {
985 Self {
986 __state_marker: SqlStateMarker::new(),
987 state,
988 }
989 }
990
991 pub fn as_bytes(&self) -> [u8; 5] {
993 self.state
994 }
995
996 pub fn as_str(&self) -> Cow<'_, str> {
998 String::from_utf8_lossy(&self.state)
999 }
1000}
1001
1002impl<'de> MyDeserialize<'de> for SqlState {
1003 const SIZE: Option<usize> = Some(6);
1004 type Ctx = ();
1005
1006 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1007 Ok(Self {
1008 __state_marker: buf.parse(())?,
1009 state: buf.parse(())?,
1010 })
1011 }
1012}
1013
1014impl MySerialize for SqlState {
1015 fn serialize(&self, buf: &mut Vec<u8>) {
1016 self.__state_marker.serialize(buf);
1017 self.state.serialize(buf);
1018 }
1019}
1020
1021#[derive(Debug, Clone, PartialEq)]
1025pub struct ServerError<'a> {
1026 code: RawInt<LeU16>,
1027 state: Option<SqlState>,
1028 message: RawBytes<'a, EofBytes>,
1029}
1030
1031impl<'a> ServerError<'a> {
1032 pub fn new(code: u16, state: Option<SqlState>, msg: impl Into<Cow<'a, [u8]>>) -> Self {
1033 Self {
1034 code: RawInt::new(code),
1035 state,
1036 message: RawBytes::new(msg),
1037 }
1038 }
1039
1040 pub fn error_code(&self) -> u16 {
1042 *self.code
1043 }
1044
1045 pub fn sql_state_ref(&self) -> Option<&SqlState> {
1047 self.state.as_ref()
1048 }
1049
1050 pub fn message_ref(&self) -> &[u8] {
1052 self.message.as_bytes()
1053 }
1054
1055 pub fn message_str(&self) -> Cow<'_, str> {
1057 self.message.as_str()
1058 }
1059
1060 pub fn into_owned(self) -> ServerError<'static> {
1061 ServerError {
1062 code: self.code,
1063 state: self.state,
1064 message: self.message.into_owned(),
1065 }
1066 }
1067}
1068
1069impl<'de> MyDeserialize<'de> for ServerError<'de> {
1070 const SIZE: Option<usize> = None;
1071 type Ctx = (u16, bool);
1073
1074 fn deserialize((code, protocol_41): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1075 let server_error = if protocol_41 {
1076 ServerError {
1077 code: RawInt::new(code),
1078 state: Some(buf.parse(())?),
1079 message: buf.parse(())?,
1080 }
1081 } else {
1082 ServerError {
1083 code: RawInt::new(code),
1084 state: None,
1085 message: buf.parse(())?,
1086 }
1087 };
1088 Ok(server_error)
1089 }
1090}
1091
1092impl MySerialize for ServerError<'_> {
1093 fn serialize(&self, buf: &mut Vec<u8>) {
1094 if let Some(state) = &self.state {
1095 state.serialize(buf);
1096 }
1097 self.message.serialize(buf);
1098 }
1099}
1100
1101impl fmt::Display for ServerError<'_> {
1102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1103 let sql_state_str = self
1104 .sql_state_ref()
1105 .map(|s| format!(" ({})", s.as_str()))
1106 .unwrap_or_default();
1107
1108 write!(
1109 f,
1110 "ERROR {}{}: {}",
1111 self.error_code(),
1112 sql_state_str,
1113 self.message_str()
1114 )
1115 }
1116}
1117
1118define_header!(
1119 LocalInfileHeader,
1120 InvalidLocalInfileHeader("Invalid LOCAL_INFILE header"),
1121 0xFB
1122);
1123
1124#[derive(Debug, Clone, Eq, PartialEq)]
1126pub struct LocalInfilePacket<'a> {
1127 __header: LocalInfileHeader,
1128 file_name: RawBytes<'a, EofBytes>,
1129}
1130
1131impl<'a> LocalInfilePacket<'a> {
1132 pub fn new(file_name: impl Into<Cow<'a, [u8]>>) -> Self {
1133 Self {
1134 __header: LocalInfileHeader::new(),
1135 file_name: RawBytes::new(file_name),
1136 }
1137 }
1138
1139 pub fn file_name_ref(&self) -> &[u8] {
1141 self.file_name.as_bytes()
1142 }
1143
1144 pub fn file_name_str(&self) -> Cow<'_, str> {
1146 self.file_name.as_str()
1147 }
1148
1149 pub fn into_owned(self) -> LocalInfilePacket<'static> {
1150 LocalInfilePacket {
1151 __header: self.__header,
1152 file_name: self.file_name.into_owned(),
1153 }
1154 }
1155}
1156
1157impl<'de> MyDeserialize<'de> for LocalInfilePacket<'de> {
1158 const SIZE: Option<usize> = None;
1159 type Ctx = ();
1160
1161 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1162 Ok(LocalInfilePacket {
1163 __header: buf.parse(())?,
1164 file_name: buf.parse(())?,
1165 })
1166 }
1167}
1168
1169impl MySerialize for LocalInfilePacket<'_> {
1170 fn serialize(&self, buf: &mut Vec<u8>) {
1171 self.__header.serialize(buf);
1172 self.file_name.serialize(buf);
1173 }
1174}
1175
1176const MYSQL_OLD_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_old_password";
1177const MYSQL_NATIVE_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_native_password";
1178const CACHING_SHA2_PASSWORD_PLUGIN_NAME: &[u8] = b"caching_sha2_password";
1179const MYSQL_CLEAR_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_clear_password";
1180const ED25519_PLUGIN_NAME: &[u8] = b"client_ed25519";
1181
1182#[derive(Debug, Clone, PartialEq, Eq)]
1183pub enum AuthPluginData<'a> {
1184 Old([u8; 8]),
1186 Native([u8; 20]),
1188 Sha2([u8; 32]),
1190 Clear(Cow<'a, [u8]>),
1192 Ed25519([u8; 64]),
1197}
1198
1199impl AuthPluginData<'_> {
1200 pub fn into_owned(self) -> AuthPluginData<'static> {
1201 match self {
1202 AuthPluginData::Old(x) => AuthPluginData::Old(x),
1203 AuthPluginData::Native(x) => AuthPluginData::Native(x),
1204 AuthPluginData::Sha2(x) => AuthPluginData::Sha2(x),
1205 AuthPluginData::Clear(x) => AuthPluginData::Clear(Cow::Owned(x.into_owned())),
1206 AuthPluginData::Ed25519(x) => AuthPluginData::Ed25519(x),
1207 }
1208 }
1209}
1210
1211impl std::ops::Deref for AuthPluginData<'_> {
1212 type Target = [u8];
1213
1214 fn deref(&self) -> &Self::Target {
1215 match self {
1216 Self::Sha2(x) => &x[..],
1217 Self::Native(x) => &x[..],
1218 Self::Old(x) => &x[..],
1219 Self::Clear(x) => &x[..],
1220 Self::Ed25519(x) => &x[..],
1221 }
1222 }
1223}
1224
1225impl MySerialize for AuthPluginData<'_> {
1226 fn serialize(&self, buf: &mut Vec<u8>) {
1227 match self {
1228 Self::Sha2(x) => buf.put_slice(&x[..]),
1229 Self::Native(x) => buf.put_slice(&x[..]),
1230 Self::Old(x) => {
1231 buf.put_slice(&x[..]);
1232 buf.push(0);
1233 }
1234 Self::Clear(x) => {
1235 buf.put_slice(x);
1236 buf.push(0);
1237 }
1238 Self::Ed25519(x) => buf.put_slice(&x[..]),
1239 }
1240 }
1241}
1242
1243#[derive(Debug, Clone, Eq, PartialEq, Hash)]
1245pub enum AuthPlugin<'a> {
1246 MysqlOldPassword,
1248 MysqlClearPassword,
1250 MysqlNativePassword,
1252 CachingSha2Password,
1254 Ed25519,
1259 Other(Cow<'a, [u8]>),
1260}
1261
1262impl<'de> MyDeserialize<'de> for AuthPlugin<'de> {
1263 const SIZE: Option<usize> = None;
1264 type Ctx = ();
1265
1266 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1267 Ok(Self::from_bytes(buf.eat_all()))
1268 }
1269}
1270
1271impl MySerialize for AuthPlugin<'_> {
1272 fn serialize(&self, buf: &mut Vec<u8>) {
1273 buf.put_slice(self.as_bytes());
1274 buf.put_u8(0);
1275 }
1276}
1277
1278impl<'a> AuthPlugin<'a> {
1279 pub fn from_bytes(name: &'a [u8]) -> AuthPlugin<'a> {
1280 let name = if let [name @ .., 0] = name {
1281 name
1282 } else {
1283 name
1284 };
1285 match name {
1286 CACHING_SHA2_PASSWORD_PLUGIN_NAME => AuthPlugin::CachingSha2Password,
1287 MYSQL_NATIVE_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlNativePassword,
1288 MYSQL_OLD_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlOldPassword,
1289 MYSQL_CLEAR_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlClearPassword,
1290 ED25519_PLUGIN_NAME => AuthPlugin::Ed25519,
1291 name => AuthPlugin::Other(Cow::Borrowed(name)),
1292 }
1293 }
1294
1295 pub fn as_bytes(&self) -> &[u8] {
1296 match self {
1297 AuthPlugin::CachingSha2Password => CACHING_SHA2_PASSWORD_PLUGIN_NAME,
1298 AuthPlugin::MysqlNativePassword => MYSQL_NATIVE_PASSWORD_PLUGIN_NAME,
1299 AuthPlugin::MysqlOldPassword => MYSQL_OLD_PASSWORD_PLUGIN_NAME,
1300 AuthPlugin::MysqlClearPassword => MYSQL_CLEAR_PASSWORD_PLUGIN_NAME,
1301 AuthPlugin::Ed25519 => ED25519_PLUGIN_NAME,
1302 AuthPlugin::Other(name) => name,
1303 }
1304 }
1305
1306 pub fn into_owned(self) -> AuthPlugin<'static> {
1307 match self {
1308 AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1309 AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1310 AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1311 AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1312 AuthPlugin::Ed25519 => AuthPlugin::Ed25519,
1313 AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Owned(name.into_owned())),
1314 }
1315 }
1316
1317 pub fn borrow(&self) -> AuthPlugin<'_> {
1318 match self {
1319 AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1320 AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1321 AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1322 AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1323 AuthPlugin::Ed25519 => AuthPlugin::Ed25519,
1324 AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Borrowed(name.as_ref())),
1325 }
1326 }
1327
1328 pub fn gen_data<'b>(&self, pass: Option<&'b str>, nonce: &[u8]) -> Option<AuthPluginData<'b>> {
1338 use super::scramble::{scramble_323, scramble_native, scramble_sha256};
1339
1340 match pass {
1341 Some(pass) if !pass.is_empty() => match self {
1342 AuthPlugin::CachingSha2Password => {
1343 scramble_sha256(nonce, pass.as_bytes()).map(AuthPluginData::Sha2)
1344 }
1345 AuthPlugin::MysqlNativePassword => {
1346 scramble_native(nonce, pass.as_bytes()).map(AuthPluginData::Native)
1347 }
1348 AuthPlugin::MysqlOldPassword => {
1349 scramble_323(nonce.chunks(8).next().unwrap(), pass.as_bytes())
1350 .map(AuthPluginData::Old)
1351 }
1352 AuthPlugin::MysqlClearPassword => {
1353 Some(AuthPluginData::Clear(Cow::Borrowed(pass.as_bytes())))
1354 }
1355 AuthPlugin::Ed25519 => Some(AuthPluginData::Ed25519(create_response_for_ed25519(
1356 pass.as_bytes(),
1357 nonce,
1358 ))),
1359 AuthPlugin::Other(_) => None,
1360 },
1361 _ => None,
1362 }
1363 }
1364}
1365
1366define_header!(
1367 AuthMoreDataHeader,
1368 InvalidAuthMoreDataHeader("Invalid AuthMoreData header"),
1369 0x01
1370);
1371
1372#[derive(Debug, Clone, Eq, PartialEq)]
1374pub struct AuthMoreData<'a> {
1375 __header: AuthMoreDataHeader,
1376 data: RawBytes<'a, EofBytes>,
1377}
1378
1379impl<'a> AuthMoreData<'a> {
1380 pub fn new(data: impl Into<Cow<'a, [u8]>>) -> Self {
1381 Self {
1382 __header: AuthMoreDataHeader::new(),
1383 data: RawBytes::new(data),
1384 }
1385 }
1386
1387 pub fn data(&self) -> &[u8] {
1388 self.data.as_bytes()
1389 }
1390
1391 pub fn into_owned(self) -> AuthMoreData<'static> {
1392 AuthMoreData {
1393 __header: self.__header,
1394 data: self.data.into_owned(),
1395 }
1396 }
1397}
1398
1399impl<'de> MyDeserialize<'de> for AuthMoreData<'de> {
1400 const SIZE: Option<usize> = None;
1401 type Ctx = ();
1402
1403 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1404 Ok(Self {
1405 __header: buf.parse(())?,
1406 data: buf.parse(())?,
1407 })
1408 }
1409}
1410
1411impl MySerialize for AuthMoreData<'_> {
1412 fn serialize(&self, buf: &mut Vec<u8>) {
1413 self.__header.serialize(&mut *buf);
1414 self.data.serialize(buf);
1415 }
1416}
1417
1418define_header!(
1419 PublicKeyResponseHeader,
1420 InvalidPublicKeyResponse("Invalid PublicKeyResponse header"),
1421 0x01
1422);
1423
1424#[derive(Debug, Clone, Eq, PartialEq)]
1428pub struct PublicKeyResponse<'a> {
1429 __header: PublicKeyResponseHeader,
1430 rsa_key: RawBytes<'a, EofBytes>,
1431}
1432
1433impl<'a> PublicKeyResponse<'a> {
1434 pub fn new(rsa_key: impl Into<Cow<'a, [u8]>>) -> Self {
1435 Self {
1436 __header: PublicKeyResponseHeader::new(),
1437 rsa_key: RawBytes::new(rsa_key),
1438 }
1439 }
1440
1441 pub fn rsa_key(&self) -> Cow<'_, str> {
1443 self.rsa_key.as_str()
1444 }
1445}
1446
1447impl<'de> MyDeserialize<'de> for PublicKeyResponse<'de> {
1448 const SIZE: Option<usize> = None;
1449 type Ctx = ();
1450
1451 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1452 Ok(Self {
1453 __header: buf.parse(())?,
1454 rsa_key: buf.parse(())?,
1455 })
1456 }
1457}
1458
1459impl MySerialize for PublicKeyResponse<'_> {
1460 fn serialize(&self, buf: &mut Vec<u8>) {
1461 self.__header.serialize(&mut *buf);
1462 self.rsa_key.serialize(buf);
1463 }
1464}
1465
1466define_header!(
1467 AuthSwitchRequestHeader,
1468 InvalidAuthSwithRequestHeader("Invalid auth switch request header"),
1469 0xFE
1470);
1471
1472#[derive(Debug, Clone, Eq, PartialEq)]
1477pub struct OldAuthSwitchRequest {
1478 __header: AuthSwitchRequestHeader,
1479}
1480
1481impl OldAuthSwitchRequest {
1482 pub fn new() -> Self {
1483 Self {
1484 __header: AuthSwitchRequestHeader::new(),
1485 }
1486 }
1487
1488 pub const fn auth_plugin(&self) -> AuthPlugin<'static> {
1489 AuthPlugin::MysqlOldPassword
1490 }
1491}
1492
1493impl Default for OldAuthSwitchRequest {
1494 fn default() -> Self {
1495 Self::new()
1496 }
1497}
1498
1499impl<'de> MyDeserialize<'de> for OldAuthSwitchRequest {
1500 const SIZE: Option<usize> = Some(1);
1501 type Ctx = ();
1502
1503 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1504 Ok(Self {
1505 __header: buf.parse(())?,
1506 })
1507 }
1508}
1509
1510impl MySerialize for OldAuthSwitchRequest {
1511 fn serialize(&self, buf: &mut Vec<u8>) {
1512 self.__header.serialize(&mut *buf);
1513 }
1514}
1515
1516#[derive(Debug, Clone, Eq, PartialEq)]
1521pub struct AuthSwitchRequest<'a> {
1522 __header: AuthSwitchRequestHeader,
1523 auth_plugin: RawBytes<'a, NullBytes>,
1524 plugin_data: RawBytes<'a, EofBytes>,
1525}
1526
1527impl<'a> AuthSwitchRequest<'a> {
1528 pub fn new(
1529 auth_plugin: impl Into<Cow<'a, [u8]>>,
1530 plugin_data: impl Into<Cow<'a, [u8]>>,
1531 ) -> Self {
1532 Self {
1533 __header: AuthSwitchRequestHeader::new(),
1534 auth_plugin: RawBytes::new(auth_plugin),
1535 plugin_data: RawBytes::new(plugin_data),
1536 }
1537 }
1538
1539 pub fn auth_plugin(&self) -> AuthPlugin<'_> {
1540 ParseBuf(self.auth_plugin.as_bytes())
1541 .parse(())
1542 .expect("infallible")
1543 }
1544
1545 pub fn plugin_data(&self) -> &[u8] {
1546 match self.plugin_data.as_bytes() {
1547 [head @ .., 0] => head,
1548 all => all,
1549 }
1550 }
1551
1552 pub fn into_owned(self) -> AuthSwitchRequest<'static> {
1553 AuthSwitchRequest {
1554 __header: self.__header,
1555 auth_plugin: self.auth_plugin.into_owned(),
1556 plugin_data: self.plugin_data.into_owned(),
1557 }
1558 }
1559}
1560
1561impl<'de> MyDeserialize<'de> for AuthSwitchRequest<'de> {
1562 const SIZE: Option<usize> = None;
1563 type Ctx = ();
1564
1565 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1566 Ok(Self {
1567 __header: buf.parse(())?,
1568 auth_plugin: buf.parse(())?,
1569 plugin_data: buf.parse(())?,
1570 })
1571 }
1572}
1573
1574impl MySerialize for AuthSwitchRequest<'_> {
1575 fn serialize(&self, buf: &mut Vec<u8>) {
1576 self.__header.serialize(&mut *buf);
1577 self.auth_plugin.serialize(&mut *buf);
1578 self.plugin_data.serialize(buf);
1579 }
1580}
1581
1582#[derive(Debug, Clone, Eq, PartialEq)]
1584pub struct HandshakePacket<'a> {
1585 protocol_version: RawInt<u8>,
1586 server_version: RawBytes<'a, NullBytes>,
1587 connection_id: RawInt<LeU32>,
1588 scramble_1: [u8; 8],
1589 __filler: Skip<1>,
1590 capabilities_1: Const<CapabilityFlags, LeU32LowerHalf>,
1592 default_collation: RawInt<u8>,
1593 status_flags: Const<StatusFlags, LeU16>,
1594 capabilities_2: Const<CapabilityFlags, LeU32UpperHalf>,
1596 auth_plugin_data_len: RawInt<u8>,
1597 __reserved: Skip<10>,
1598 scramble_2: Option<RawBytes<'a, BareBytes<{ (u8::MAX as usize) - 8 }>>>,
1599 auth_plugin_name: Option<RawBytes<'a, NullBytes>>,
1600}
1601
1602impl<'de> MyDeserialize<'de> for HandshakePacket<'de> {
1603 const SIZE: Option<usize> = None;
1604 type Ctx = ();
1605
1606 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1607 let protocol_version = buf.parse(())?;
1608 let server_version = buf.parse(())?;
1609
1610 let mut sbuf: ParseBuf<'_> = buf.parse(31)?;
1612 let connection_id = sbuf.parse_unchecked(())?;
1613 let scramble_1 = sbuf.parse_unchecked(())?;
1614 let __filler = sbuf.parse_unchecked(())?;
1615 let capabilities_1: RawConst<LeU32LowerHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1616 let default_collation = sbuf.parse_unchecked(())?;
1617 let status_flags = sbuf.parse_unchecked(())?;
1618 let capabilities_2: RawConst<LeU32UpperHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1619 let auth_plugin_data_len: RawInt<u8> = sbuf.parse_unchecked(())?;
1620 let __reserved = sbuf.parse_unchecked(())?;
1621 let mut scramble_2 = None;
1622 if capabilities_1.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
1623 let len = max(13, auth_plugin_data_len.0 as i8 - 8) as usize;
1624 scramble_2 = buf.parse(len).map(Some)?;
1625 }
1626 let mut auth_plugin_name = None;
1627 if capabilities_2.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
1628 auth_plugin_name = match buf.eat_all() {
1629 [head @ .., 0] => Some(RawBytes::new(head)),
1630 all => Some(RawBytes::new(all)),
1632 }
1633 }
1634
1635 Ok(Self {
1636 protocol_version,
1637 server_version,
1638 connection_id,
1639 scramble_1,
1640 __filler,
1641 capabilities_1: Const::new(CapabilityFlags::from_bits_truncate(capabilities_1.0)),
1642 default_collation,
1643 status_flags,
1644 capabilities_2: Const::new(CapabilityFlags::from_bits_truncate(capabilities_2.0)),
1645 auth_plugin_data_len,
1646 __reserved,
1647 scramble_2,
1648 auth_plugin_name,
1649 })
1650 }
1651}
1652
1653impl MySerialize for HandshakePacket<'_> {
1654 fn serialize(&self, buf: &mut Vec<u8>) {
1655 self.protocol_version.serialize(&mut *buf);
1656 self.server_version.serialize(&mut *buf);
1657 self.connection_id.serialize(&mut *buf);
1658 self.scramble_1.serialize(&mut *buf);
1659 buf.put_u8(0x00);
1660 self.capabilities_1.serialize(&mut *buf);
1661 self.default_collation.serialize(&mut *buf);
1662 self.status_flags.serialize(&mut *buf);
1663 self.capabilities_2.serialize(&mut *buf);
1664
1665 if self
1666 .capabilities_2
1667 .contains(CapabilityFlags::CLIENT_PLUGIN_AUTH)
1668 {
1669 buf.put_u8(
1670 self.scramble_2
1671 .as_ref()
1672 .map(|x| (x.len() + 8) as u8)
1673 .unwrap_or_default(),
1674 );
1675 } else {
1676 buf.put_u8(0);
1677 }
1678
1679 buf.put_slice(&[0_u8; 10][..]);
1680
1681 if let Some(scramble_2) = &self.scramble_2 {
1684 scramble_2.serialize(&mut *buf);
1685 }
1686
1687 if let Some(client_plugin_auth) = &self.auth_plugin_name {
1690 client_plugin_auth.serialize(buf);
1691 }
1692 }
1693}
1694
1695impl<'a> HandshakePacket<'a> {
1696 #[allow(clippy::too_many_arguments)]
1697 pub fn new(
1698 protocol_version: u8,
1699 server_version: impl Into<Cow<'a, [u8]>>,
1700 connection_id: u32,
1701 scramble_1: [u8; 8],
1702 scramble_2: Option<impl Into<Cow<'a, [u8]>>>,
1703 capabilities: CapabilityFlags,
1704 default_collation: u8,
1705 status_flags: StatusFlags,
1706 auth_plugin_name: Option<impl Into<Cow<'a, [u8]>>>,
1707 ) -> Self {
1708 let (capabilities_1, capabilities_2) = (
1712 CapabilityFlags::from_bits_retain(capabilities.bits() & 0x0000_FFFF),
1713 CapabilityFlags::from_bits_retain(capabilities.bits() & 0xFFFF_0000),
1714 );
1715
1716 let scramble_2 = scramble_2.map(RawBytes::new);
1717
1718 HandshakePacket {
1719 protocol_version: RawInt::new(protocol_version),
1720 server_version: RawBytes::new(server_version),
1721 connection_id: RawInt::new(connection_id),
1722 scramble_1,
1723 __filler: Skip,
1724 capabilities_1: Const::new(capabilities_1),
1725 default_collation: RawInt::new(default_collation),
1726 status_flags: Const::new(status_flags),
1727 capabilities_2: Const::new(capabilities_2),
1728 auth_plugin_data_len: RawInt::new(
1729 scramble_2
1730 .as_ref()
1731 .map(|x| x.len() as u8)
1732 .unwrap_or_default(),
1733 ),
1734 __reserved: Skip,
1735 scramble_2,
1736 auth_plugin_name: auth_plugin_name.map(RawBytes::new),
1737 }
1738 }
1739
1740 pub fn into_owned(self) -> HandshakePacket<'static> {
1741 HandshakePacket {
1742 protocol_version: self.protocol_version,
1743 server_version: self.server_version.into_owned(),
1744 connection_id: self.connection_id,
1745 scramble_1: self.scramble_1,
1746 __filler: self.__filler,
1747 capabilities_1: self.capabilities_1,
1748 default_collation: self.default_collation,
1749 status_flags: self.status_flags,
1750 capabilities_2: self.capabilities_2,
1751 auth_plugin_data_len: self.auth_plugin_data_len,
1752 __reserved: self.__reserved,
1753 scramble_2: self.scramble_2.map(|x| x.into_owned()),
1754 auth_plugin_name: self.auth_plugin_name.map(RawBytes::into_owned),
1755 }
1756 }
1757
1758 pub fn protocol_version(&self) -> u8 {
1760 self.protocol_version.0
1761 }
1762
1763 pub fn server_version_ref(&self) -> &[u8] {
1765 self.server_version.as_bytes()
1766 }
1767
1768 pub fn server_version_str(&self) -> Cow<'_, str> {
1771 self.server_version.as_str()
1772 }
1773
1774 pub fn server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1778 VERSION_RE
1779 .captures(self.server_version_ref())
1780 .map(|captures| {
1781 (
1783 btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1784 btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1785 btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1786 )
1787 })
1788 }
1789
1790 pub fn maria_db_server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1792 MARIADB_VERSION_RE
1793 .captures(self.server_version_ref())
1794 .map(|captures| {
1795 (
1797 btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1798 btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1799 btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1800 )
1801 })
1802 }
1803
1804 pub fn connection_id(&self) -> u32 {
1806 self.connection_id.0
1807 }
1808
1809 pub fn scramble_1_ref(&self) -> &[u8] {
1811 self.scramble_1.as_ref()
1812 }
1813
1814 pub fn scramble_2_ref(&self) -> Option<&[u8]> {
1818 self.scramble_2.as_ref().map(|x| x.as_bytes())
1819 }
1820
1821 pub fn nonce(&self) -> Vec<u8> {
1823 let mut out = Vec::from(self.scramble_1_ref());
1824 out.extend_from_slice(self.scramble_2_ref().unwrap_or(&[][..]));
1825
1826 out.resize(20, 0);
1829 out
1830 }
1831
1832 pub fn capabilities(&self) -> CapabilityFlags {
1834 self.capabilities_1.0 | self.capabilities_2.0
1835 }
1836
1837 pub fn default_collation(&self) -> u8 {
1839 self.default_collation.0
1840 }
1841
1842 pub fn status_flags(&self) -> StatusFlags {
1844 self.status_flags.0
1845 }
1846
1847 pub fn auth_plugin_name_ref(&self) -> Option<&[u8]> {
1849 self.auth_plugin_name.as_ref().map(|x| x.as_bytes())
1850 }
1851
1852 pub fn auth_plugin_name_str(&self) -> Option<Cow<'_, str>> {
1855 self.auth_plugin_name.as_ref().map(|x| x.as_str())
1856 }
1857
1858 pub fn auth_plugin(&self) -> Option<AuthPlugin<'_>> {
1860 self.auth_plugin_name.as_ref().map(|x| match x.as_bytes() {
1861 [name @ .., 0] => ParseBuf(name).parse_unchecked(()).expect("infallible"),
1862 all => ParseBuf(all).parse_unchecked(()).expect("infallible"),
1863 })
1864 }
1865}
1866
1867define_header!(
1868 ComChangeUserHeader,
1869 InvalidComChangeUserHeader("Invalid COM_CHANGE_USER header"),
1870 0x11
1871);
1872
1873#[derive(Debug, Clone, PartialEq, Eq)]
1874pub struct ComChangeUser<'a> {
1875 __header: ComChangeUserHeader,
1876 user: RawBytes<'a, NullBytes>,
1877 auth_plugin_data: RawBytes<'a, U8Bytes>,
1879 database: RawBytes<'a, NullBytes>,
1880 more_data: Option<ComChangeUserMoreData<'a>>,
1881}
1882
1883impl<'a> ComChangeUser<'a> {
1884 pub fn new() -> Self {
1885 Self {
1886 __header: ComChangeUserHeader::new(),
1887 user: Default::default(),
1888 auth_plugin_data: Default::default(),
1889 database: Default::default(),
1890 more_data: None,
1891 }
1892 }
1893
1894 pub fn with_user(mut self, user: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1895 self.user = user.map(RawBytes::new).unwrap_or_default();
1896 self
1897 }
1898
1899 pub fn with_database(mut self, database: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1900 self.database = database.map(RawBytes::new).unwrap_or_default();
1901 self
1902 }
1903
1904 pub fn with_auth_plugin_data(
1905 mut self,
1906 auth_plugin_data: Option<impl Into<Cow<'a, [u8]>>>,
1907 ) -> Self {
1908 self.auth_plugin_data = auth_plugin_data.map(RawBytes::new).unwrap_or_default();
1909 self
1910 }
1911
1912 pub fn with_more_data(mut self, more_data: Option<ComChangeUserMoreData<'a>>) -> Self {
1913 self.more_data = more_data;
1914 self
1915 }
1916
1917 pub fn into_owned(self) -> ComChangeUser<'static> {
1918 ComChangeUser {
1919 __header: self.__header,
1920 user: self.user.into_owned(),
1921 auth_plugin_data: self.auth_plugin_data.into_owned(),
1922 database: self.database.into_owned(),
1923 more_data: self.more_data.map(|x| x.into_owned()),
1924 }
1925 }
1926}
1927
1928impl Default for ComChangeUser<'_> {
1929 fn default() -> Self {
1930 Self::new()
1931 }
1932}
1933
1934impl<'de> MyDeserialize<'de> for ComChangeUser<'de> {
1935 const SIZE: Option<usize> = None;
1936
1937 type Ctx = CapabilityFlags;
1938
1939 fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1940 Ok(Self {
1941 __header: buf.parse(())?,
1942 user: buf.parse(())?,
1943 auth_plugin_data: buf.parse(())?,
1944 database: buf.parse(())?,
1945 more_data: if !buf.is_empty() {
1946 Some(buf.parse(flags)?)
1947 } else {
1948 None
1949 },
1950 })
1951 }
1952}
1953
1954impl MySerialize for ComChangeUser<'_> {
1955 fn serialize(&self, buf: &mut Vec<u8>) {
1956 self.__header.serialize(&mut *buf);
1957 self.user.serialize(&mut *buf);
1958 self.auth_plugin_data.serialize(&mut *buf);
1959 self.database.serialize(&mut *buf);
1960 if let Some(ref more_data) = self.more_data {
1961 more_data.serialize(&mut *buf);
1962 }
1963 }
1964}
1965
1966#[derive(Debug, Clone, PartialEq, Eq)]
1967pub struct ComChangeUserMoreData<'a> {
1968 character_set: RawInt<LeU16>,
1969 auth_plugin: Option<AuthPlugin<'a>>,
1970 connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
1971}
1972
1973impl<'a> ComChangeUserMoreData<'a> {
1974 pub fn new(character_set: u16) -> Self {
1975 Self {
1976 character_set: RawInt::new(character_set),
1977 auth_plugin: None,
1978 connect_attributes: None,
1979 }
1980 }
1981
1982 pub fn with_auth_plugin(mut self, auth_plugin: Option<AuthPlugin<'a>>) -> Self {
1983 self.auth_plugin = auth_plugin;
1984 self
1985 }
1986
1987 pub fn with_connect_attributes(
1988 mut self,
1989 connect_attributes: Option<HashMap<String, String>>,
1990 ) -> Self {
1991 self.connect_attributes = connect_attributes.map(|attrs| {
1992 attrs
1993 .into_iter()
1994 .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
1995 .collect()
1996 });
1997 self
1998 }
1999
2000 pub fn into_owned(self) -> ComChangeUserMoreData<'static> {
2001 ComChangeUserMoreData {
2002 character_set: self.character_set,
2003 auth_plugin: self.auth_plugin.map(|x| x.into_owned()),
2004 connect_attributes: self.connect_attributes.map(|x| {
2005 x.into_iter()
2006 .map(|(k, v)| (k.into_owned(), v.into_owned()))
2007 .collect()
2008 }),
2009 }
2010 }
2011}
2012
2013fn deserialize_connect_attrs<'de>(
2015 buf: &mut ParseBuf<'de>,
2016) -> io::Result<HashMap<RawBytes<'de, LenEnc>, RawBytes<'de, LenEnc>>> {
2017 let data_len = buf.parse::<RawInt<LenEnc>>(())?;
2018 let mut data: ParseBuf<'_> = buf.parse(data_len.0 as usize)?;
2019 let mut attrs = HashMap::new();
2020 while !data.is_empty() {
2021 let key = data.parse::<RawBytes<'_, LenEnc>>(())?;
2022 let value = data.parse::<RawBytes<'_, LenEnc>>(())?;
2023 attrs.insert(key, value);
2024 }
2025 Ok(attrs)
2026}
2027
2028fn serialize_connect_attrs<'a>(
2030 connect_attributes: &HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>,
2031 buf: &mut Vec<u8>,
2032) {
2033 let len = connect_attributes
2034 .iter()
2035 .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2036 .sum::<u64>();
2037 buf.put_lenenc_int(len);
2038
2039 for (name, value) in connect_attributes {
2040 name.serialize(&mut *buf);
2041 value.serialize(&mut *buf);
2042 }
2043}
2044
2045impl<'de> MyDeserialize<'de> for ComChangeUserMoreData<'de> {
2046 const SIZE: Option<usize> = None;
2047 type Ctx = CapabilityFlags;
2048
2049 fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2050 let character_set = buf.parse(())?;
2052 let mut auth_plugin = None;
2053 let mut connect_attributes = None;
2054
2055 if flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
2056 match buf.parse::<RawBytes<'_, NullBytes>>(())?.0 {
2058 Cow::Borrowed(bytes) => {
2059 let mut auth_plugin_buf = ParseBuf(bytes);
2060 auth_plugin = Some(auth_plugin_buf.parse(())?);
2061 }
2062 _ => unreachable!(),
2063 }
2064 };
2065
2066 if flags.contains(CapabilityFlags::CLIENT_CONNECT_ATTRS) {
2067 connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2068 };
2069
2070 Ok(Self {
2071 character_set,
2072 auth_plugin,
2073 connect_attributes,
2074 })
2075 }
2076}
2077
2078impl MySerialize for ComChangeUserMoreData<'_> {
2079 fn serialize(&self, buf: &mut Vec<u8>) {
2080 self.character_set.serialize(&mut *buf);
2081 if let Some(ref auth_plugin) = self.auth_plugin {
2082 auth_plugin.serialize(&mut *buf);
2083 }
2084 if let Some(ref connect_attributes) = self.connect_attributes {
2085 serialize_connect_attrs(connect_attributes, buf);
2086 } else {
2087 serialize_connect_attrs(&Default::default(), buf);
2090 }
2091 }
2092}
2093
2094type ScrambleBuf<'a> =
2096 Either<RawBytes<'a, LenEnc>, Either<RawBytes<'a, U8Bytes>, RawBytes<'a, NullBytes>>>;
2097
2098#[derive(Debug, Clone, PartialEq, Eq)]
2099pub struct HandshakeResponse<'a> {
2100 capabilities: Const<CapabilityFlags, LeU32>,
2101 max_packet_size: RawInt<LeU32>,
2102 collation: RawInt<u8>,
2103 scramble_buf: ScrambleBuf<'a>,
2104 user: RawBytes<'a, NullBytes>,
2105 db_name: Option<RawBytes<'a, NullBytes>>,
2106 auth_plugin: Option<AuthPlugin<'a>>,
2107 connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
2108}
2109
2110impl<'a> HandshakeResponse<'a> {
2111 #[allow(clippy::too_many_arguments)]
2112 pub fn new(
2113 scramble_buf: Option<impl Into<Cow<'a, [u8]>>>,
2114 server_version: (u16, u16, u16),
2115 user: Option<impl Into<Cow<'a, [u8]>>>,
2116 db_name: Option<impl Into<Cow<'a, [u8]>>>,
2117 auth_plugin: Option<AuthPlugin<'a>>,
2118 mut capabilities: CapabilityFlags,
2119 connect_attributes: Option<HashMap<String, String>>,
2120 max_packet_size: u32,
2121 ) -> Self {
2122 let scramble_buf =
2123 if capabilities.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
2124 Either::Left(RawBytes::new(
2125 scramble_buf.map(Into::into).unwrap_or_default(),
2126 ))
2127 } else if capabilities.contains(CapabilityFlags::CLIENT_SECURE_CONNECTION) {
2128 Either::Right(Either::Left(RawBytes::new(
2129 scramble_buf.map(Into::into).unwrap_or_default(),
2130 )))
2131 } else {
2132 Either::Right(Either::Right(RawBytes::new(
2133 scramble_buf.map(Into::into).unwrap_or_default(),
2134 )))
2135 };
2136
2137 if db_name.is_some() {
2138 capabilities.insert(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2139 } else {
2140 capabilities.remove(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2141 }
2142
2143 if auth_plugin.is_some() {
2144 capabilities.insert(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2145 } else {
2146 capabilities.remove(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2147 }
2148
2149 if connect_attributes.is_some() {
2150 capabilities.insert(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2151 } else {
2152 capabilities.remove(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2153 }
2154
2155 Self {
2156 scramble_buf,
2157 collation: if server_version >= (5, 5, 3) {
2158 RawInt::new(CollationId::UTF8MB4_GENERAL_CI as u8)
2159 } else {
2160 RawInt::new(CollationId::UTF8MB3_GENERAL_CI as u8)
2161 },
2162 user: user.map(RawBytes::new).unwrap_or_default(),
2163 db_name: db_name.map(RawBytes::new),
2164 auth_plugin,
2165 capabilities: Const::new(capabilities),
2166 connect_attributes: connect_attributes.map(|attrs| {
2167 attrs
2168 .into_iter()
2169 .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
2170 .collect()
2171 }),
2172 max_packet_size: RawInt::new(max_packet_size),
2173 }
2174 }
2175
2176 pub fn capabilities(&self) -> CapabilityFlags {
2177 self.capabilities.0
2178 }
2179
2180 pub fn collation(&self) -> u8 {
2181 self.collation.0
2182 }
2183
2184 pub fn scramble_buf(&self) -> &[u8] {
2185 match &self.scramble_buf {
2186 Either::Left(x) => x.as_bytes(),
2187 Either::Right(x) => match x {
2188 Either::Left(x) => x.as_bytes(),
2189 Either::Right(x) => x.as_bytes(),
2190 },
2191 }
2192 }
2193
2194 pub fn user(&self) -> &[u8] {
2195 self.user.as_bytes()
2196 }
2197
2198 pub fn db_name(&self) -> Option<&[u8]> {
2199 self.db_name.as_ref().map(|x| x.as_bytes())
2200 }
2201
2202 pub fn auth_plugin(&self) -> Option<&AuthPlugin<'a>> {
2203 self.auth_plugin.as_ref()
2204 }
2205
2206 #[must_use = "entails computation"]
2207 pub fn connect_attributes(&self) -> Option<HashMap<String, String>> {
2208 self.connect_attributes.as_ref().map(|attrs| {
2209 attrs
2210 .iter()
2211 .map(|(k, v)| (k.as_str().into_owned(), v.as_str().into_owned()))
2212 .collect()
2213 })
2214 }
2215}
2216
2217impl<'de> MyDeserialize<'de> for HandshakeResponse<'de> {
2218 const SIZE: Option<usize> = None;
2219 type Ctx = ();
2220
2221 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2222 let mut sbuf: ParseBuf<'_> = buf.parse(4 + 4 + 1 + 23)?;
2223 let client_flags: RawConst<LeU32, CapabilityFlags> = sbuf.parse_unchecked(())?;
2224 let max_packet_size: RawInt<LeU32> = sbuf.parse_unchecked(())?;
2225 let collation = sbuf.parse_unchecked(())?;
2226 sbuf.parse_unchecked::<Skip<23>>(())?;
2227
2228 let user = buf.parse(())?;
2229 let scramble_buf =
2230 if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA.bits() > 0 {
2231 Either::Left(buf.parse(())?)
2232 } else if client_flags.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
2233 Either::Right(Either::Left(buf.parse(())?))
2234 } else {
2235 Either::Right(Either::Right(buf.parse(())?))
2236 };
2237
2238 let mut db_name = None;
2239 if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_WITH_DB.bits() > 0 {
2240 db_name = buf.parse(()).map(Some)?;
2241 }
2242
2243 let mut auth_plugin = None;
2244 if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
2245 let auth_plugin_name = buf.eat_null_str();
2246 auth_plugin = Some(AuthPlugin::from_bytes(auth_plugin_name));
2247 }
2248
2249 let mut connect_attributes = None;
2250 if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_ATTRS.bits() > 0 {
2251 connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2252 }
2253
2254 Ok(Self {
2255 capabilities: Const::new(CapabilityFlags::from_bits_truncate(client_flags.0)),
2256 max_packet_size,
2257 collation,
2258 scramble_buf,
2259 user,
2260 db_name,
2261 auth_plugin,
2262 connect_attributes,
2263 })
2264 }
2265}
2266
2267impl MySerialize for HandshakeResponse<'_> {
2268 fn serialize(&self, buf: &mut Vec<u8>) {
2269 self.capabilities.serialize(&mut *buf);
2270 self.max_packet_size.serialize(&mut *buf);
2271 self.collation.serialize(&mut *buf);
2272 buf.put_slice(&[0; 23]);
2273 self.user.serialize(&mut *buf);
2274 self.scramble_buf.serialize(&mut *buf);
2275
2276 if let Some(db_name) = &self.db_name {
2277 db_name.serialize(&mut *buf);
2278 }
2279
2280 if let Some(auth_plugin) = &self.auth_plugin {
2281 auth_plugin.serialize(&mut *buf);
2282 }
2283
2284 if let Some(attrs) = &self.connect_attributes {
2285 let len = attrs
2286 .iter()
2287 .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2288 .sum::<u64>();
2289 buf.put_lenenc_int(len);
2290
2291 for (name, value) in attrs {
2292 name.serialize(&mut *buf);
2293 value.serialize(&mut *buf);
2294 }
2295 }
2296 }
2297}
2298
2299#[derive(Debug, Clone, Eq, PartialEq)]
2300pub struct SslRequest {
2301 capabilities: Const<CapabilityFlags, LeU32>,
2302 max_packet_size: RawInt<LeU32>,
2303 character_set: RawInt<u8>,
2304 __skip: Skip<23>,
2305}
2306
2307impl SslRequest {
2308 pub fn new(capabilities: CapabilityFlags, max_packet_size: u32, character_set: u8) -> Self {
2309 Self {
2310 capabilities: Const::new(capabilities),
2311 max_packet_size: RawInt::new(max_packet_size),
2312 character_set: RawInt::new(character_set),
2313 __skip: Skip,
2314 }
2315 }
2316
2317 pub fn capabilities(&self) -> CapabilityFlags {
2318 self.capabilities.0
2319 }
2320
2321 pub fn max_packet_size(&self) -> u32 {
2322 self.max_packet_size.0
2323 }
2324
2325 pub fn character_set(&self) -> u8 {
2326 self.character_set.0
2327 }
2328}
2329
2330impl<'de> MyDeserialize<'de> for SslRequest {
2331 const SIZE: Option<usize> = Some(4 + 4 + 1 + 23);
2332 type Ctx = ();
2333
2334 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2335 let mut buf: ParseBuf<'_> = buf.parse(Self::SIZE.unwrap())?;
2336 let raw_capabilities = buf.parse_unchecked::<RawConst<LeU32, CapabilityFlags>>(())?;
2337 Ok(Self {
2338 capabilities: Const::new(CapabilityFlags::from_bits_truncate(raw_capabilities.0)),
2339 max_packet_size: buf.parse_unchecked(())?,
2340 character_set: buf.parse_unchecked(())?,
2341 __skip: buf.parse_unchecked(())?,
2342 })
2343 }
2344}
2345
2346impl MySerialize for SslRequest {
2347 fn serialize(&self, buf: &mut Vec<u8>) {
2348 self.capabilities.serialize(&mut *buf);
2349 self.max_packet_size.serialize(&mut *buf);
2350 self.character_set.serialize(&mut *buf);
2351 self.__skip.serialize(&mut *buf);
2352 }
2353}
2354
2355#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
2356#[error("Invalid statement packet status")]
2357pub struct InvalidStmtPacketStatus;
2358
2359#[derive(Debug, Clone, Eq, PartialEq, Hash)]
2361pub struct StmtPacket {
2362 status: ConstU8<InvalidStmtPacketStatus, 0x00>,
2363 statement_id: RawInt<LeU32>,
2364 num_columns: RawInt<LeU16>,
2365 num_params: RawInt<LeU16>,
2366 __skip: Skip<1>,
2367 warning_count: RawInt<LeU16>,
2368}
2369
2370impl<'de> MyDeserialize<'de> for StmtPacket {
2371 const SIZE: Option<usize> = Some(12);
2372 type Ctx = ();
2373
2374 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2375 let mut buf: ParseBuf<'_> = buf.parse(Self::SIZE.unwrap())?;
2376 Ok(StmtPacket {
2377 status: buf.parse_unchecked(())?,
2378 statement_id: buf.parse_unchecked(())?,
2379 num_columns: buf.parse_unchecked(())?,
2380 num_params: buf.parse_unchecked(())?,
2381 __skip: buf.parse_unchecked(())?,
2382 warning_count: buf.parse_unchecked(())?,
2383 })
2384 }
2385}
2386
2387impl MySerialize for StmtPacket {
2388 fn serialize(&self, buf: &mut Vec<u8>) {
2389 self.status.serialize(&mut *buf);
2390 self.statement_id.serialize(&mut *buf);
2391 self.num_columns.serialize(&mut *buf);
2392 self.num_params.serialize(&mut *buf);
2393 self.__skip.serialize(&mut *buf);
2394 self.warning_count.serialize(&mut *buf);
2395 }
2396}
2397
2398impl StmtPacket {
2399 pub fn statement_id(&self) -> u32 {
2401 *self.statement_id
2402 }
2403
2404 pub fn num_columns(&self) -> u16 {
2406 *self.num_columns
2407 }
2408
2409 pub fn num_params(&self) -> u16 {
2411 *self.num_params
2412 }
2413
2414 pub fn warning_count(&self) -> u16 {
2416 *self.warning_count
2417 }
2418}
2419
2420#[derive(Debug, Clone, Eq, PartialEq)]
2424pub struct NullBitmap<T, U: AsRef<[u8]> = Vec<u8>>(U, PhantomData<T>);
2425
2426impl<'de, T: SerializationSide> MyDeserialize<'de> for NullBitmap<T, Cow<'de, [u8]>> {
2427 const SIZE: Option<usize> = None;
2428 type Ctx = usize;
2429
2430 fn deserialize(num_columns: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2431 let bitmap_len = Self::bitmap_len(num_columns);
2432 let bytes = buf.checked_eat(bitmap_len).ok_or_else(unexpected_buf_eof)?;
2433 Ok(Self::from_bytes(Cow::Borrowed(bytes)))
2434 }
2435}
2436
2437impl<T: SerializationSide> NullBitmap<T, Vec<u8>> {
2438 pub fn new(num_columns: usize) -> Self {
2440 Self::from_bytes(vec![0; Self::bitmap_len(num_columns)])
2441 }
2442
2443 pub fn read(input: &mut &[u8], num_columns: usize) -> Self {
2445 let bitmap_len = Self::bitmap_len(num_columns);
2446 assert!(input.len() >= bitmap_len);
2447
2448 let bitmap = Self::from_bytes(input[..bitmap_len].to_vec());
2449 *input = &input[bitmap_len..];
2450
2451 bitmap
2452 }
2453}
2454
2455impl<T: SerializationSide, U: AsRef<[u8]>> NullBitmap<T, U> {
2456 pub fn bitmap_len(num_columns: usize) -> usize {
2457 (num_columns + 7 + T::BIT_OFFSET) / 8
2458 }
2459
2460 fn byte_and_bit(&self, column_index: usize) -> (usize, u8) {
2461 let offset = column_index + T::BIT_OFFSET;
2462 let byte = offset / 8;
2463 let bit = 1 << (offset % 8) as u8;
2464
2465 assert!(byte < self.0.as_ref().len());
2466
2467 (byte, bit)
2468 }
2469
2470 pub fn from_bytes(bytes: U) -> Self {
2472 Self(bytes, PhantomData)
2473 }
2474
2475 pub fn is_null(&self, column_index: usize) -> bool {
2477 let (byte, bit) = self.byte_and_bit(column_index);
2478 self.0.as_ref()[byte] & bit > 0
2479 }
2480}
2481
2482impl<T: SerializationSide, U: AsRef<[u8]> + AsMut<[u8]>> NullBitmap<T, U> {
2483 pub fn set(&mut self, column_index: usize, is_null: bool) {
2485 let (byte, bit) = self.byte_and_bit(column_index);
2486 if is_null {
2487 self.0.as_mut()[byte] |= bit
2488 } else {
2489 self.0.as_mut()[byte] &= !bit
2490 }
2491 }
2492}
2493
2494impl<T, U: AsRef<[u8]>> AsRef<[u8]> for NullBitmap<T, U> {
2495 fn as_ref(&self) -> &[u8] {
2496 self.0.as_ref()
2497 }
2498}
2499
2500#[derive(Debug, Clone, PartialEq)]
2501pub struct ComStmtExecuteRequestBuilder {
2502 pub stmt_id: u32,
2503}
2504
2505impl ComStmtExecuteRequestBuilder {
2506 pub const NULL_BITMAP_OFFSET: usize = 10;
2507
2508 pub fn new(stmt_id: u32) -> Self {
2509 Self { stmt_id }
2510 }
2511}
2512
2513impl ComStmtExecuteRequestBuilder {
2514 pub fn build(self, params: &[Value]) -> (ComStmtExecuteRequest<'_>, bool) {
2515 let bitmap_len = NullBitmap::<ClientSide>::bitmap_len(params.len());
2516
2517 let mut bitmap_bytes = vec![0; bitmap_len];
2518 let mut bitmap = NullBitmap::<ClientSide, _>::from_bytes(&mut bitmap_bytes);
2519 let params = params.iter().collect::<Vec<_>>();
2520
2521 let meta_len = params.len() * 2;
2522
2523 let mut data_len = 0;
2524 for (i, param) in params.iter().enumerate() {
2525 match param.bin_len() as usize {
2526 0 => bitmap.set(i, true),
2527 x => data_len += x,
2528 }
2529 }
2530
2531 let total_len = 10 + bitmap_len + 1 + meta_len + data_len;
2532
2533 let as_long_data = total_len > MAX_PAYLOAD_LEN;
2534
2535 (
2536 ComStmtExecuteRequest {
2537 com_stmt_execute: ConstU8::new(),
2538 stmt_id: RawInt::new(self.stmt_id),
2539 flags: Const::new(CursorType::CURSOR_TYPE_NO_CURSOR),
2540 iteration_count: ConstU32::new(),
2541 params_flags: Const::new(StmtExecuteParamsFlags::NEW_PARAMS_BOUND),
2542 bitmap: RawBytes::new(bitmap_bytes),
2543 params,
2544 as_long_data,
2545 },
2546 as_long_data,
2547 )
2548 }
2549}
2550
2551define_header!(
2552 ComStmtExecuteHeader,
2553 COM_STMT_EXECUTE,
2554 InvalidComStmtExecuteHeader
2555);
2556
2557define_const!(
2558 ConstU32,
2559 IterationCount,
2560 InvalidIterationCount("Invalid iteration count for COM_STMT_EXECUTE"),
2561 1
2562);
2563
2564#[derive(Debug, Clone, PartialEq)]
2565pub struct ComStmtExecuteRequest<'a> {
2566 com_stmt_execute: ComStmtExecuteHeader,
2567 stmt_id: RawInt<LeU32>,
2568 flags: Const<CursorType, u8>,
2569 iteration_count: IterationCount,
2570 bitmap: RawBytes<'a, BareBytes<8192>>,
2572 params_flags: Const<StmtExecuteParamsFlags, u8>,
2573 params: Vec<&'a Value>,
2574 as_long_data: bool,
2575}
2576
2577impl<'a> ComStmtExecuteRequest<'a> {
2578 pub fn stmt_id(&self) -> u32 {
2579 self.stmt_id.0
2580 }
2581
2582 pub fn flags(&self) -> CursorType {
2583 self.flags.0
2584 }
2585
2586 pub fn bitmap(&self) -> &[u8] {
2587 self.bitmap.as_bytes()
2588 }
2589
2590 pub fn params_flags(&self) -> StmtExecuteParamsFlags {
2591 self.params_flags.0
2592 }
2593
2594 pub fn params(&self) -> &[&'a Value] {
2595 self.params.as_ref()
2596 }
2597
2598 pub fn as_long_data(&self) -> bool {
2599 self.as_long_data
2600 }
2601}
2602
2603impl MySerialize for ComStmtExecuteRequest<'_> {
2604 fn serialize(&self, buf: &mut Vec<u8>) {
2605 self.com_stmt_execute.serialize(&mut *buf);
2606 self.stmt_id.serialize(&mut *buf);
2607 self.flags.serialize(&mut *buf);
2608 self.iteration_count.serialize(&mut *buf);
2609
2610 if !self.params.is_empty() {
2611 self.bitmap.serialize(&mut *buf);
2612 self.params_flags.serialize(&mut *buf);
2613 }
2614
2615 for param in &self.params {
2616 let (column_type, flags) = match param {
2617 Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()),
2618 Value::Bytes(_) => (
2619 ColumnType::MYSQL_TYPE_VAR_STRING,
2620 StmtExecuteParamFlags::empty(),
2621 ),
2622 Value::Int(_) => (
2623 ColumnType::MYSQL_TYPE_LONGLONG,
2624 StmtExecuteParamFlags::empty(),
2625 ),
2626 Value::UInt(_) => (
2627 ColumnType::MYSQL_TYPE_LONGLONG,
2628 StmtExecuteParamFlags::UNSIGNED,
2629 ),
2630 Value::Float(_) => (ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty()),
2631 Value::Double(_) => (
2632 ColumnType::MYSQL_TYPE_DOUBLE,
2633 StmtExecuteParamFlags::empty(),
2634 ),
2635 Value::Date(..) => (
2636 ColumnType::MYSQL_TYPE_DATETIME,
2637 StmtExecuteParamFlags::empty(),
2638 ),
2639 Value::Time(..) => (ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty()),
2640 };
2641
2642 buf.put_slice(&[column_type as u8, flags.bits()]);
2643 }
2644
2645 for param in &self.params {
2646 match **param {
2647 Value::Int(_)
2648 | Value::UInt(_)
2649 | Value::Float(_)
2650 | Value::Double(_)
2651 | Value::Date(..)
2652 | Value::Time(..) => {
2653 param.serialize(buf);
2654 }
2655 Value::Bytes(_) if !self.as_long_data => {
2656 param.serialize(buf);
2657 }
2658 Value::Bytes(_) | Value::NULL => {}
2659 }
2660 }
2661 }
2662}
2663
2664define_header!(
2665 ComStmtSendLongDataHeader,
2666 COM_STMT_SEND_LONG_DATA,
2667 InvalidComStmtSendLongDataHeader
2668);
2669
2670#[derive(Debug, Clone, Eq, PartialEq)]
2671pub struct ComStmtSendLongData<'a> {
2672 __header: ComStmtSendLongDataHeader,
2673 stmt_id: RawInt<LeU32>,
2674 param_index: RawInt<LeU16>,
2675 data: RawBytes<'a, EofBytes>,
2676}
2677
2678impl<'a> ComStmtSendLongData<'a> {
2679 pub fn new(stmt_id: u32, param_index: u16, data: impl Into<Cow<'a, [u8]>>) -> Self {
2680 Self {
2681 __header: ComStmtSendLongDataHeader::new(),
2682 stmt_id: RawInt::new(stmt_id),
2683 param_index: RawInt::new(param_index),
2684 data: RawBytes::new(data),
2685 }
2686 }
2687
2688 pub fn into_owned(self) -> ComStmtSendLongData<'static> {
2689 ComStmtSendLongData {
2690 __header: self.__header,
2691 stmt_id: self.stmt_id,
2692 param_index: self.param_index,
2693 data: self.data.into_owned(),
2694 }
2695 }
2696}
2697
2698impl MySerialize for ComStmtSendLongData<'_> {
2699 fn serialize(&self, buf: &mut Vec<u8>) {
2700 self.__header.serialize(&mut *buf);
2701 self.stmt_id.serialize(&mut *buf);
2702 self.param_index.serialize(&mut *buf);
2703 self.data.serialize(&mut *buf);
2704 }
2705}
2706
2707#[derive(Debug, Clone, Copy, Eq, PartialEq)]
2708pub struct ComStmtClose {
2709 pub stmt_id: u32,
2710}
2711
2712impl ComStmtClose {
2713 pub fn new(stmt_id: u32) -> Self {
2714 Self { stmt_id }
2715 }
2716}
2717
2718impl MySerialize for ComStmtClose {
2719 fn serialize(&self, buf: &mut Vec<u8>) {
2720 buf.put_u8(Command::COM_STMT_CLOSE as u8);
2721 buf.put_u32_le(self.stmt_id);
2722 }
2723}
2724
2725define_header!(
2726 ComRegisterSlaveHeader,
2727 COM_REGISTER_SLAVE,
2728 InvalidComRegisterSlaveHeader
2729);
2730
2731#[derive(Debug, Clone, Eq, PartialEq, Hash)]
2734pub struct ComRegisterSlave<'a> {
2735 header: ComRegisterSlaveHeader,
2736 server_id: RawInt<LeU32>,
2738 hostname: RawBytes<'a, U8Bytes>,
2741 user: RawBytes<'a, U8Bytes>,
2748 password: RawBytes<'a, U8Bytes>,
2755 port: RawInt<LeU16>,
2762 replication_rank: RawInt<LeU32>,
2764 master_id: RawInt<LeU32>,
2767}
2768
2769impl<'a> ComRegisterSlave<'a> {
2770 pub fn new(server_id: u32) -> Self {
2772 Self {
2773 header: Default::default(),
2774 server_id: RawInt::new(server_id),
2775 hostname: Default::default(),
2776 user: Default::default(),
2777 password: Default::default(),
2778 port: Default::default(),
2779 replication_rank: Default::default(),
2780 master_id: Default::default(),
2781 }
2782 }
2783
2784 pub fn with_hostname(mut self, hostname: impl Into<Cow<'a, [u8]>>) -> Self {
2786 self.hostname = RawBytes::new(hostname);
2787 self
2788 }
2789
2790 pub fn with_user(mut self, user: impl Into<Cow<'a, [u8]>>) -> Self {
2792 self.user = RawBytes::new(user);
2793 self
2794 }
2795
2796 pub fn with_password(mut self, password: impl Into<Cow<'a, [u8]>>) -> Self {
2798 self.password = RawBytes::new(password);
2799 self
2800 }
2801
2802 pub fn with_port(mut self, port: u16) -> Self {
2804 self.port = RawInt::new(port);
2805 self
2806 }
2807
2808 pub fn with_replication_rank(mut self, replication_rank: u32) -> Self {
2810 self.replication_rank = RawInt::new(replication_rank);
2811 self
2812 }
2813
2814 pub fn with_master_id(mut self, master_id: u32) -> Self {
2816 self.master_id = RawInt::new(master_id);
2817 self
2818 }
2819
2820 pub fn server_id(&self) -> u32 {
2822 self.server_id.0
2823 }
2824
2825 pub fn hostname_raw(&self) -> &[u8] {
2827 self.hostname.as_bytes()
2828 }
2829
2830 pub fn hostname(&'a self) -> Cow<'a, str> {
2832 self.hostname.as_str()
2833 }
2834
2835 pub fn user_raw(&self) -> &[u8] {
2837 self.user.as_bytes()
2838 }
2839
2840 pub fn user(&'a self) -> Cow<'a, str> {
2842 self.user.as_str()
2843 }
2844
2845 pub fn password_raw(&self) -> &[u8] {
2847 self.password.as_bytes()
2848 }
2849
2850 pub fn password(&'a self) -> Cow<'a, str> {
2852 self.password.as_str()
2853 }
2854
2855 pub fn port(&self) -> u16 {
2857 self.port.0
2858 }
2859
2860 pub fn replication_rank(&self) -> u32 {
2862 self.replication_rank.0
2863 }
2864
2865 pub fn master_id(&self) -> u32 {
2867 self.master_id.0
2868 }
2869}
2870
2871impl MySerialize for ComRegisterSlave<'_> {
2872 fn serialize(&self, buf: &mut Vec<u8>) {
2873 self.header.serialize(&mut *buf);
2874 self.server_id.serialize(&mut *buf);
2875 self.hostname.serialize(&mut *buf);
2876 self.user.serialize(&mut *buf);
2877 self.password.serialize(&mut *buf);
2878 self.port.serialize(&mut *buf);
2879 self.replication_rank.serialize(&mut *buf);
2880 self.master_id.serialize(&mut *buf);
2881 }
2882}
2883
2884impl<'de> MyDeserialize<'de> for ComRegisterSlave<'de> {
2885 const SIZE: Option<usize> = None;
2886 type Ctx = ();
2887
2888 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2889 let mut sbuf: ParseBuf<'_> = buf.parse(5)?;
2890 let header = sbuf.parse_unchecked(())?;
2891 let server_id = sbuf.parse_unchecked(())?;
2892
2893 let hostname = buf.parse(())?;
2894 let user = buf.parse(())?;
2895 let password = buf.parse(())?;
2896
2897 let mut sbuf: ParseBuf<'_> = buf.parse(10)?;
2898 let port = sbuf.parse_unchecked(())?;
2899 let replication_rank = sbuf.parse_unchecked(())?;
2900 let master_id = sbuf.parse_unchecked(())?;
2901
2902 Ok(Self {
2903 header,
2904 server_id,
2905 hostname,
2906 user,
2907 password,
2908 port,
2909 replication_rank,
2910 master_id,
2911 })
2912 }
2913}
2914
2915define_header!(
2916 ComTableDumpHeader,
2917 COM_TABLE_DUMP,
2918 InvalidComTableDumpHeader
2919);
2920
2921#[derive(Debug, Clone, Eq, PartialEq, Hash)]
2923pub struct ComTableDump<'a> {
2924 header: ComTableDumpHeader,
2925 database: RawBytes<'a, U8Bytes>,
2931 table: RawBytes<'a, U8Bytes>,
2937}
2938
2939impl<'a> ComTableDump<'a> {
2940 pub fn new(database: impl Into<Cow<'a, [u8]>>, table: impl Into<Cow<'a, [u8]>>) -> Self {
2942 Self {
2943 header: Default::default(),
2944 database: RawBytes::new(database),
2945 table: RawBytes::new(table),
2946 }
2947 }
2948
2949 pub fn database_raw(&self) -> &[u8] {
2951 self.database.as_bytes()
2952 }
2953
2954 pub fn database(&self) -> Cow<'_, str> {
2956 self.database.as_str()
2957 }
2958
2959 pub fn table_raw(&self) -> &[u8] {
2961 self.table.as_bytes()
2962 }
2963
2964 pub fn table(&self) -> Cow<'_, str> {
2966 self.table.as_str()
2967 }
2968}
2969
2970impl MySerialize for ComTableDump<'_> {
2971 fn serialize(&self, buf: &mut Vec<u8>) {
2972 self.header.serialize(&mut *buf);
2973 self.database.serialize(&mut *buf);
2974 self.table.serialize(&mut *buf);
2975 }
2976}
2977
2978impl<'de> MyDeserialize<'de> for ComTableDump<'de> {
2979 const SIZE: Option<usize> = None;
2980 type Ctx = ();
2981
2982 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2983 Ok(Self {
2984 header: buf.parse(())?,
2985 database: buf.parse(())?,
2986 table: buf.parse(())?,
2987 })
2988 }
2989}
2990
2991my_bitflags! {
2992 BinlogDumpFlags,
2993 #[error("Unknown flags in the raw value of BinlogDumpFlags (raw={0:b})")]
2994 UnknownBinlogDumpFlags,
2995 u16,
2996
2997 #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
2999 pub struct BinlogDumpFlags: u16 {
3000 const BINLOG_DUMP_NON_BLOCK = 0x01;
3002 const BINLOG_THROUGH_POSITION = 0x02;
3003 const BINLOG_THROUGH_GTID = 0x04;
3004 }
3005}
3006
3007define_header!(
3008 ComBinlogDumpHeader,
3009 COM_BINLOG_DUMP,
3010 InvalidComBinlogDumpHeader
3011);
3012
3013#[derive(Clone, Debug, Eq, PartialEq, Hash)]
3015pub struct ComBinlogDump<'a> {
3016 header: ComBinlogDumpHeader,
3017 pos: RawInt<LeU32>,
3019 flags: Const<BinlogDumpFlags, LeU16>,
3023 server_id: RawInt<LeU32>,
3025 filename: RawBytes<'a, EofBytes>,
3030}
3031
3032impl<'a> ComBinlogDump<'a> {
3033 pub fn new(server_id: u32) -> Self {
3035 Self {
3036 header: Default::default(),
3037 pos: Default::default(),
3038 flags: Default::default(),
3039 server_id: RawInt::new(server_id),
3040 filename: Default::default(),
3041 }
3042 }
3043
3044 pub fn with_pos(mut self, pos: u32) -> Self {
3046 self.pos = RawInt::new(pos);
3047 self
3048 }
3049
3050 pub fn with_flags(mut self, flags: BinlogDumpFlags) -> Self {
3052 self.flags = Const::new(flags);
3053 self
3054 }
3055
3056 pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3058 self.filename = RawBytes::new(filename);
3059 self
3060 }
3061
3062 pub fn pos(&self) -> u32 {
3064 *self.pos
3065 }
3066
3067 pub fn flags(&self) -> BinlogDumpFlags {
3069 *self.flags
3070 }
3071
3072 pub fn server_id(&self) -> u32 {
3074 *self.server_id
3075 }
3076
3077 pub fn filename_raw(&self) -> &[u8] {
3079 self.filename.as_bytes()
3080 }
3081
3082 pub fn filename(&self) -> Cow<'_, str> {
3084 self.filename.as_str()
3085 }
3086}
3087
3088impl MySerialize for ComBinlogDump<'_> {
3089 fn serialize(&self, buf: &mut Vec<u8>) {
3090 self.header.serialize(&mut *buf);
3091 self.pos.serialize(&mut *buf);
3092 self.flags.serialize(&mut *buf);
3093 self.server_id.serialize(&mut *buf);
3094 self.filename.serialize(&mut *buf);
3095 }
3096}
3097
3098impl<'de> MyDeserialize<'de> for ComBinlogDump<'de> {
3099 const SIZE: Option<usize> = None;
3100 type Ctx = ();
3101
3102 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3103 let mut sbuf: ParseBuf<'_> = buf.parse(11)?;
3104 Ok(Self {
3105 header: sbuf.parse_unchecked(())?,
3106 pos: sbuf.parse_unchecked(())?,
3107 flags: sbuf.parse_unchecked(())?,
3108 server_id: sbuf.parse_unchecked(())?,
3109 filename: buf.parse(())?,
3110 })
3111 }
3112}
3113
3114#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
3116pub struct GnoInterval {
3117 start: RawInt<LeU64>,
3118 end: RawInt<LeU64>,
3119}
3120
3121impl GnoInterval {
3122 pub fn new(start: u64, end: u64) -> Self {
3124 Self {
3125 start: RawInt::new(start),
3126 end: RawInt::new(end),
3127 }
3128 }
3129 pub fn check_and_new(start: u64, end: u64) -> io::Result<Self> {
3131 if start >= end {
3132 return Err(io::Error::new(
3133 io::ErrorKind::InvalidData,
3134 format!("start({}) >= end({}) in GnoInterval", start, end),
3135 ));
3136 }
3137 if start == 0 || end == 0 {
3138 return Err(io::Error::new(
3139 io::ErrorKind::InvalidData,
3140 "Gno can't be zero",
3141 ));
3142 }
3143 Ok(Self::new(start, end))
3144 }
3145}
3146
3147impl MySerialize for GnoInterval {
3148 fn serialize(&self, buf: &mut Vec<u8>) {
3149 self.start.serialize(&mut *buf);
3150 self.end.serialize(&mut *buf);
3151 }
3152}
3153
3154impl<'de> MyDeserialize<'de> for GnoInterval {
3155 const SIZE: Option<usize> = Some(16);
3156 type Ctx = ();
3157
3158 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3159 Ok(Self {
3160 start: buf.parse_unchecked(())?,
3161 end: buf.parse_unchecked(())?,
3162 })
3163 }
3164}
3165
3166pub const UUID_LEN: usize = 16;
3168
3169#[derive(Debug, Clone, Eq, PartialEq, Hash)]
3172pub struct Sid<'a> {
3173 uuid: [u8; UUID_LEN],
3174 intervals: Seq<'a, GnoInterval, LeU64>,
3175}
3176
3177impl Sid<'_> {
3178 pub fn new(uuid: [u8; UUID_LEN]) -> Self {
3180 Self {
3181 uuid,
3182 intervals: Default::default(),
3183 }
3184 }
3185
3186 pub fn uuid(&self) -> [u8; UUID_LEN] {
3188 self.uuid
3189 }
3190
3191 pub fn intervals(&self) -> &[GnoInterval] {
3193 &self.intervals[..]
3194 }
3195
3196 pub fn with_interval(mut self, interval: GnoInterval) -> Self {
3198 let mut intervals = self.intervals.0.into_owned();
3199 intervals.push(interval);
3200 self.intervals = Seq::new(intervals);
3201 self
3202 }
3203
3204 pub fn with_intervals(mut self, intervals: Vec<GnoInterval>) -> Self {
3206 self.intervals = Seq::new(intervals);
3207 self
3208 }
3209
3210 fn len(&self) -> u64 {
3211 use saturating::Saturating as S;
3212 let mut len = S(UUID_LEN as u64); len += S(8); len += S((self.intervals.len() * 16) as u64);
3215 len.0
3216 }
3217}
3218
3219impl MySerialize for Sid<'_> {
3220 fn serialize(&self, buf: &mut Vec<u8>) {
3221 self.uuid.serialize(&mut *buf);
3222 self.intervals.serialize(buf);
3223 }
3224}
3225
3226impl<'de> MyDeserialize<'de> for Sid<'de> {
3227 const SIZE: Option<usize> = None;
3228 type Ctx = ();
3229
3230 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3231 Ok(Self {
3232 uuid: buf.parse(())?,
3233 intervals: buf.parse(())?,
3234 })
3235 }
3236}
3237
3238impl Sid<'_> {
3239 fn wrap_err(msg: String) -> io::Error {
3240 io::Error::new(io::ErrorKind::InvalidInput, msg)
3241 }
3242
3243 fn parse_interval_num(to_parse: &str, full: &str) -> Result<u64, io::Error> {
3244 let n: u64 = to_parse.parse().map_err(|e| {
3245 Sid::wrap_err(format!(
3246 "invalid GnoInterval format: {}, error: {}",
3247 full, e
3248 ))
3249 })?;
3250 Ok(n)
3251 }
3252}
3253
3254impl FromStr for Sid<'_> {
3255 type Err = io::Error;
3256
3257 fn from_str(s: &str) -> Result<Self, Self::Err> {
3258 let (uuid, intervals) = s
3259 .split_once(':')
3260 .ok_or_else(|| Sid::wrap_err(format!("invalid sid format: {}", s)))?;
3261 let uuid = Uuid::parse_str(uuid)
3262 .map_err(|e| Sid::wrap_err(format!("invalid uuid format: {}, error: {}", s, e)))?;
3263 let intervals = intervals
3264 .split(':')
3265 .map(|interval| {
3266 let nums = interval.split('-').collect::<Vec<_>>();
3267 if nums.len() != 1 && nums.len() != 2 {
3268 return Err(Sid::wrap_err(format!("invalid GnoInterval format: {}", s)));
3269 }
3270 if nums.len() == 1 {
3271 let start = Sid::parse_interval_num(nums[0], s)?;
3272 let interval = GnoInterval::check_and_new(start, start + 1)?;
3273 Ok(interval)
3274 } else {
3275 let start = Sid::parse_interval_num(nums[0], s)?;
3276 let end = Sid::parse_interval_num(nums[1], s)?;
3277 let interval = GnoInterval::check_and_new(start, end + 1)?;
3278 Ok(interval)
3279 }
3280 })
3281 .collect::<Result<Vec<_>, _>>()?;
3282 Ok(Self {
3283 uuid: *uuid.as_bytes(),
3284 intervals: Seq::new(intervals),
3285 })
3286 }
3287}
3288
3289define_header!(
3290 ComBinlogDumpGtidHeader,
3291 COM_BINLOG_DUMP_GTID,
3292 InvalidComBinlogDumpGtidHeader
3293);
3294
3295#[derive(Debug, Clone, Eq, PartialEq, Hash)]
3297pub struct ComBinlogDumpGtid<'a> {
3298 header: ComBinlogDumpGtidHeader,
3299 flags: Const<BinlogDumpFlags, LeU16>,
3301 server_id: RawInt<LeU32>,
3303 filename: RawBytes<'a, U32Bytes>,
3312 pos: RawInt<LeU64>,
3314 sid_block: Seq<'a, Sid<'a>, LeU64>,
3316}
3317
3318impl<'a> ComBinlogDumpGtid<'a> {
3319 pub fn new(server_id: u32) -> Self {
3321 Self {
3322 header: Default::default(),
3323 pos: Default::default(),
3324 flags: Default::default(),
3325 server_id: RawInt::new(server_id),
3326 filename: Default::default(),
3327 sid_block: Default::default(),
3328 }
3329 }
3330
3331 pub fn server_id(&self) -> u32 {
3333 self.server_id.0
3334 }
3335
3336 pub fn flags(&self) -> BinlogDumpFlags {
3338 self.flags.0
3339 }
3340
3341 pub fn filename_raw(&self) -> &[u8] {
3343 self.filename.as_bytes()
3344 }
3345
3346 pub fn filename(&self) -> Cow<'_, str> {
3348 self.filename.as_str()
3349 }
3350
3351 pub fn pos(&self) -> u64 {
3353 self.pos.0
3354 }
3355
3356 pub fn sids(&self) -> &[Sid<'a>] {
3358 &self.sid_block
3359 }
3360
3361 pub fn with_filename(self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3363 Self {
3364 header: self.header,
3365 flags: self.flags,
3366 server_id: self.server_id,
3367 filename: RawBytes::new(filename),
3368 pos: self.pos,
3369 sid_block: self.sid_block,
3370 }
3371 }
3372
3373 pub fn with_server_id(mut self, server_id: u32) -> Self {
3375 self.server_id.0 = server_id;
3376 self
3377 }
3378
3379 pub fn with_flags(mut self, mut flags: BinlogDumpFlags) -> Self {
3381 if self.sid_block.is_empty() {
3382 flags.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3383 } else {
3384 flags.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3385 }
3386 self.flags.0 = flags;
3387 self
3388 }
3389
3390 pub fn with_pos(mut self, pos: u64) -> Self {
3392 self.pos.0 = pos;
3393 self
3394 }
3395
3396 pub fn with_sid(mut self, sid: Sid<'a>) -> Self {
3398 self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3399 self.sid_block.push(sid);
3400 self
3401 }
3402
3403 pub fn with_sids(mut self, sids: impl Into<Cow<'a, [Sid<'a>]>>) -> Self {
3405 self.sid_block = Seq::new(sids);
3406 if self.sid_block.is_empty() {
3407 self.flags.0.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3408 } else {
3409 self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3410 }
3411 self
3412 }
3413
3414 fn sid_block_len(&self) -> u32 {
3415 use saturating::Saturating as S;
3416 let mut len = S(8); for sid in self.sid_block.iter() {
3418 len += S(sid.len() as u32);
3419 }
3420 len.0
3421 }
3422}
3423
3424impl MySerialize for ComBinlogDumpGtid<'_> {
3425 fn serialize(&self, buf: &mut Vec<u8>) {
3426 self.header.serialize(&mut *buf);
3427 self.flags.serialize(&mut *buf);
3428 self.server_id.serialize(&mut *buf);
3429 self.filename.serialize(&mut *buf);
3430 self.pos.serialize(&mut *buf);
3431 buf.put_u32_le(self.sid_block_len());
3432 self.sid_block.serialize(&mut *buf);
3433 }
3434}
3435
3436impl<'de> MyDeserialize<'de> for ComBinlogDumpGtid<'de> {
3437 const SIZE: Option<usize> = None;
3438 type Ctx = ();
3439
3440 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3441 let mut sbuf: ParseBuf<'_> = buf.parse(7)?;
3442 let header = sbuf.parse_unchecked(())?;
3443 let flags: Const<BinlogDumpFlags, LeU16> = sbuf.parse_unchecked(())?;
3444 let server_id = sbuf.parse_unchecked(())?;
3445
3446 let filename = buf.parse(())?;
3447 let pos = buf.parse(())?;
3448
3449 let sid_data_len: RawInt<LeU32> = buf.parse(())?;
3451 let mut buf: ParseBuf<'_> = buf.parse(sid_data_len.0 as usize)?;
3452 let sid_block = buf.parse(())?;
3453
3454 Ok(Self {
3455 header,
3456 flags,
3457 server_id,
3458 filename,
3459 pos,
3460 sid_block,
3461 })
3462 }
3463}
3464
3465define_header!(
3466 SemiSyncAckPacketPacketHeader,
3467 InvalidSemiSyncAckPacketPacketHeader("Invalid semi-sync ack packet header"),
3468 0xEF
3469);
3470
3471pub struct SemiSyncAckPacket<'a> {
3474 header: SemiSyncAckPacketPacketHeader,
3475 position: RawInt<LeU64>,
3476 filename: RawBytes<'a, EofBytes>,
3477}
3478
3479impl<'a> SemiSyncAckPacket<'a> {
3480 pub fn new(position: u64, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3481 Self {
3482 header: Default::default(),
3483 position: RawInt::new(position),
3484 filename: RawBytes::new(filename),
3485 }
3486 }
3487
3488 pub fn with_position(mut self, position: u64) -> Self {
3490 self.position.0 = position;
3491 self
3492 }
3493
3494 pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3496 self.filename = RawBytes::new(filename);
3497 self
3498 }
3499
3500 pub fn position(&self) -> u64 {
3502 self.position.0
3503 }
3504
3505 pub fn filename_raw(&self) -> &[u8] {
3507 self.filename.as_bytes()
3508 }
3509
3510 pub fn filename(&self) -> Cow<'_, str> {
3512 self.filename.as_str()
3513 }
3514}
3515
3516impl MySerialize for SemiSyncAckPacket<'_> {
3517 fn serialize(&self, buf: &mut Vec<u8>) {
3518 self.header.serialize(&mut *buf);
3519 self.position.serialize(&mut *buf);
3520 self.filename.serialize(&mut *buf);
3521 }
3522}
3523
3524impl<'de> MyDeserialize<'de> for SemiSyncAckPacket<'de> {
3525 const SIZE: Option<usize> = None;
3526 type Ctx = ();
3527
3528 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3529 let mut sbuf: ParseBuf<'_> = buf.parse(9)?;
3530 Ok(Self {
3531 header: sbuf.parse_unchecked(())?,
3532 position: sbuf.parse_unchecked(())?,
3533 filename: buf.parse(())?,
3534 })
3535 }
3536}
3537
3538#[cfg(test)]
3539mod test {
3540 use super::*;
3541 use crate::{
3542 constants::{CapabilityFlags, ColumnFlags, ColumnType, StatusFlags},
3543 proto::{MyDeserialize, MySerialize},
3544 };
3545
3546 proptest::proptest! {
3547 #[test]
3548 fn com_table_dump_roundtrip(database: Vec<u8>, table: Vec<u8>) {
3549 let cmd = ComTableDump::new(database, table);
3550
3551 let mut output = Vec::new();
3552 cmd.serialize(&mut output);
3553
3554 assert_eq!(cmd, ComTableDump::deserialize((), &mut ParseBuf(&output[..]))?);
3555 }
3556
3557 #[test]
3558 fn com_binlog_dump_roundtrip(
3559 server_id: u32,
3560 filename: Vec<u8>,
3561 pos: u32,
3562 flags: u16,
3563 ) {
3564 let cmd = ComBinlogDump::new(server_id)
3565 .with_filename(filename)
3566 .with_pos(pos)
3567 .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));
3568
3569 let mut output = Vec::new();
3570 cmd.serialize(&mut output);
3571
3572 assert_eq!(cmd, ComBinlogDump::deserialize((), &mut ParseBuf(&output[..]))?);
3573 }
3574
3575 #[test]
3576 fn com_register_slave_roundtrip(
3577 server_id: u32,
3578 hostname in r"\w{0,256}",
3579 user in r"\w{0,256}",
3580 password in r"\w{0,256}",
3581 port: u16,
3582 replication_rank: u32,
3583 master_id: u32,
3584 ) {
3585 let cmd = ComRegisterSlave::new(server_id)
3586 .with_hostname(hostname.as_bytes())
3587 .with_user(user.as_bytes())
3588 .with_password(password.as_bytes())
3589 .with_port(port)
3590 .with_replication_rank(replication_rank)
3591 .with_master_id(master_id);
3592
3593 let mut output = Vec::new();
3594 cmd.serialize(&mut output);
3595 let parsed = ComRegisterSlave::deserialize((), &mut ParseBuf(&output[..]))?;
3596
3597 if hostname.len() > 255 || user.len() > 255 || password.len() > 255 {
3598 assert_ne!(cmd, parsed);
3599 } else {
3600 assert_eq!(cmd, parsed);
3601 }
3602 }
3603
3604 #[test]
3605 fn com_binlog_dump_gtid_roundtrip(
3606 flags: u16,
3607 server_id: u32,
3608 filename: Vec<u8>,
3609 pos: u64,
3610 n_sid_blocks in 0_u64..1024,
3611 ) {
3612 let mut cmd = ComBinlogDumpGtid::new(server_id)
3613 .with_filename(filename)
3614 .with_pos(pos)
3615 .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));
3616
3617 let mut sids = Vec::new();
3618 for i in 0..n_sid_blocks {
3619 let mut block = Sid::new([i as u8; 16]);
3620 for j in 0..i {
3621 block = block.with_interval(GnoInterval::new(i, j));
3622 }
3623 sids.push(block);
3624 }
3625
3626 cmd = cmd.with_sids(sids);
3627
3628 let mut output = Vec::new();
3629 cmd.serialize(&mut output);
3630
3631 assert_eq!(cmd, ComBinlogDumpGtid::deserialize((), &mut ParseBuf(&output[..]))?);
3632 }
3633 }
3634
3635 #[test]
3636 fn should_parse_local_infile_packet() {
3637 const LIP: &[u8] = b"\xfbfile_name";
3638
3639 let lip = LocalInfilePacket::deserialize((), &mut ParseBuf(LIP)).unwrap();
3640 assert_eq!(lip.file_name_str(), "file_name");
3641 }
3642
3643 #[test]
3644 fn should_parse_stmt_packet() {
3645 const SP: &[u8] = b"\x00\x01\x00\x00\x00\x01\x00\x02\x00\x00\x00\x00";
3646 const SP_2: &[u8] = b"\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
3647
3648 let sp = StmtPacket::deserialize((), &mut ParseBuf(SP)).unwrap();
3649 assert_eq!(sp.statement_id(), 0x01);
3650 assert_eq!(sp.num_columns(), 0x01);
3651 assert_eq!(sp.num_params(), 0x02);
3652 assert_eq!(sp.warning_count(), 0x00);
3653
3654 let sp = StmtPacket::deserialize((), &mut ParseBuf(SP_2)).unwrap();
3655 assert_eq!(sp.statement_id(), 0x01);
3656 assert_eq!(sp.num_columns(), 0x00);
3657 assert_eq!(sp.num_params(), 0x00);
3658 assert_eq!(sp.warning_count(), 0x00);
3659 }
3660
3661 #[test]
3662 fn should_parse_handshake_packet() {
3663 const HSP: &[u8] = b"\x0a5.5.5-10.0.17-MariaDB-log\x00\x0b\x00\
3664 \x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\
3665 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x2a\x34\x64\
3666 \x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00";
3667
3668 const HSP_2: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
3669 \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
3670 \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
3671 \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
3672 \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
3673 \x6f\x72\x64\x00";
3674
3675 const HSP_3: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
3676 \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
3677 \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
3678 \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
3679 \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
3680 \x6f\x72\x64\x00";
3681
3682 let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap();
3683 assert_eq!(hsp.protocol_version(), 0x0a);
3684 assert_eq!(hsp.server_version_str(), "5.5.5-10.0.17-MariaDB-log");
3685 assert_eq!(hsp.server_version_parsed(), Some((5, 5, 5)));
3686 assert_eq!(hsp.maria_db_server_version_parsed(), Some((10, 0, 17)));
3687 assert_eq!(hsp.connection_id(), 0x0b);
3688 assert_eq!(hsp.scramble_1_ref(), b"dvH@I-CJ");
3689 assert_eq!(
3690 hsp.capabilities(),
3691 CapabilityFlags::from_bits_truncate(0xf7ff)
3692 );
3693 assert_eq!(hsp.default_collation(), 0x08);
3694 assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3695 assert_eq!(hsp.scramble_2_ref(), Some(&b"*4d|cZwk4^]:\x00"[..]));
3696 assert_eq!(hsp.auth_plugin_name_ref(), None);
3697
3698 let mut output = Vec::new();
3699 hsp.serialize(&mut output);
3700 assert_eq!(&output, HSP);
3701
3702 let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_2)).unwrap();
3703 assert_eq!(hsp.protocol_version(), 0x0a);
3704 assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
3705 assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
3706 assert_eq!(hsp.maria_db_server_version_parsed(), None);
3707 assert_eq!(hsp.connection_id(), 0x0a56);
3708 assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
3709 assert_eq!(
3710 hsp.capabilities(),
3711 CapabilityFlags::from_bits_truncate(0xc00fffff)
3712 );
3713 assert_eq!(hsp.default_collation(), 0x08);
3714 assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3715 assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
3716 assert_eq!(
3717 hsp.auth_plugin_name_ref(),
3718 Some(&b"mysql_native_password"[..])
3719 );
3720
3721 let mut output = Vec::new();
3722 hsp.serialize(&mut output);
3723 assert_eq!(&output, HSP_2);
3724
3725 let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_3)).unwrap();
3726 assert_eq!(hsp.protocol_version(), 0x0a);
3727 assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
3728 assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
3729 assert_eq!(hsp.maria_db_server_version_parsed(), None);
3730 assert_eq!(hsp.connection_id(), 0x0a56);
3731 assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
3732 assert_eq!(
3733 hsp.capabilities(),
3734 CapabilityFlags::from_bits_truncate(0xc00fffff)
3735 );
3736 assert_eq!(hsp.default_collation(), 0x08);
3737 assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3738 assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
3739 assert_eq!(
3740 hsp.auth_plugin_name_ref(),
3741 Some(&b"mysql_native_password"[..])
3742 );
3743
3744 let mut output = Vec::new();
3745 hsp.serialize(&mut output);
3746 assert_eq!(&output, HSP_3);
3747 }
3748
3749 #[test]
3750 fn should_parse_err_packet() {
3751 const ERR_PACKET: &[u8] = b"\xff\x48\x04\x23\x48\x59\x30\x30\x30\x4e\x6f\x20\x74\x61\x62\
3752 \x6c\x65\x73\x20\x75\x73\x65\x64";
3753 const ERR_PACKET_NO_STATE: &[u8] = b"\xff\x10\x04\x54\x6f\x6f\x20\x6d\x61\x6e\x79\x20\x63\
3754 \x6f\x6e\x6e\x65\x63\x74\x69\x6f\x6e\x73";
3755 const PROGRESS_PACKET: &[u8] = b"\xff\xff\xff\x01\x01\x0a\xcc\x5b\x00\x0astage name";
3756
3757 let err_packet = ErrPacket::deserialize(
3758 CapabilityFlags::CLIENT_PROTOCOL_41,
3759 &mut ParseBuf(ERR_PACKET),
3760 )
3761 .unwrap();
3762 let err_packet = err_packet.server_error();
3763 assert_eq!(err_packet.error_code(), 1096);
3764 assert_eq!(err_packet.sql_state_ref().unwrap().as_str(), "HY000");
3765 assert_eq!(err_packet.message_str(), "No tables used");
3766
3767 let err_packet =
3768 ErrPacket::deserialize(CapabilityFlags::empty(), &mut ParseBuf(ERR_PACKET_NO_STATE))
3769 .unwrap();
3770 let server_error = err_packet.server_error();
3771 assert_eq!(server_error.error_code(), 1040);
3772 assert_eq!(server_error.sql_state_ref(), None);
3773 assert_eq!(server_error.message_str(), "Too many connections");
3774
3775 let err_packet = ErrPacket::deserialize(
3776 CapabilityFlags::CLIENT_PROGRESS_OBSOLETE,
3777 &mut ParseBuf(PROGRESS_PACKET),
3778 )
3779 .unwrap();
3780 assert!(err_packet.is_progress_report());
3781 let progress_report = err_packet.progress_report();
3782 assert_eq!(progress_report.stage(), 1);
3783 assert_eq!(progress_report.max_stage(), 10);
3784 assert_eq!(progress_report.progress(), 23500);
3785 assert_eq!(progress_report.stage_info_str(), "stage name");
3786 }
3787
3788 #[test]
3789 fn should_parse_column_packet() {
3790 const COLUMN_PACKET: &[u8] = b"\x03def\x06schema\x05table\x09org_table\x04name\
3791 \x08org_name\x0c\x21\x00\x0F\x00\x00\x00\x00\x01\x00\x08\x00\x00";
3792 let column = Column::deserialize((), &mut ParseBuf(COLUMN_PACKET)).unwrap();
3793 assert_eq!(column.schema_str(), "schema");
3794 assert_eq!(column.table_str(), "table");
3795 assert_eq!(column.org_table_str(), "org_table");
3796 assert_eq!(column.name_str(), "name");
3797 assert_eq!(column.org_name_str(), "org_name");
3798 assert_eq!(
3799 column.character_set(),
3800 CollationId::UTF8MB3_GENERAL_CI as u16
3801 );
3802 assert_eq!(column.column_length(), 15);
3803 assert_eq!(column.column_type(), ColumnType::MYSQL_TYPE_DECIMAL);
3804 assert_eq!(column.flags(), ColumnFlags::NOT_NULL_FLAG);
3805 assert_eq!(column.decimals(), 8);
3806 }
3807
3808 #[test]
3809 fn should_parse_auth_switch_request() {
3810 const PAYLOAD: &[u8] = b"\xfe\x6d\x79\x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\
3811 \x73\x73\x77\x6f\x72\x64\x00\x7a\x51\x67\x34\x69\x36\x6f\x4e\x79\
3812 \x36\x3d\x72\x48\x4e\x2f\x3e\x2d\x62\x29\x41\x00";
3813 let packet = AuthSwitchRequest::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();
3814 assert_eq!(packet.auth_plugin().as_bytes(), b"mysql_native_password",);
3815 assert_eq!(packet.plugin_data(), b"zQg4i6oNy6=rHN/>-b)A",)
3816 }
3817
3818 #[test]
3819 fn should_parse_auth_more_data() {
3820 const PAYLOAD: &[u8] = b"\x01\x04";
3821 let packet = AuthMoreData::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();
3822 assert_eq!(packet.data(), b"\x04",);
3823 }
3824
3825 #[test]
3826 fn should_parse_ok_packet() {
3827 const PLAIN_OK: &[u8] = b"\x00\x01\x00\x02\x00\x00\x00";
3828 const RESULT_SET_TERMINATOR: &[u8] = &[
3829 0xfe, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x42, 0x52, 0x65, 0x61, 0x64, 0x20, 0x31,
3830 0x20, 0x72, 0x6f, 0x77, 0x73, 0x2c, 0x20, 0x31, 0x2e, 0x30, 0x30, 0x20, 0x42, 0x20,
3831 0x69, 0x6e, 0x20, 0x30, 0x2e, 0x30, 0x30, 0x32, 0x20, 0x73, 0x65, 0x63, 0x2e, 0x2c,
3832 0x20, 0x36, 0x31, 0x31, 0x2e, 0x33, 0x34, 0x20, 0x72, 0x6f, 0x77, 0x73, 0x2f, 0x73,
3833 0x65, 0x63, 0x2e, 0x2c, 0x20, 0x36, 0x31, 0x31, 0x2e, 0x33, 0x34, 0x20, 0x42, 0x2f,
3834 0x73, 0x65, 0x63, 0x2e,
3835 ];
3836 const SESS_STATE_SYS_VAR_OK: &[u8] =
3837 b"\x00\x00\x00\x02\x40\x00\x00\x00\x11\x00\x0f\x0a\x61\
3838 \x75\x74\x6f\x63\x6f\x6d\x6d\x69\x74\x03\x4f\x46\x46";
3839 const SESS_STATE_SCHEMA_OK: &[u8] =
3840 b"\x00\x00\x00\x02\x40\x00\x00\x00\x07\x01\x05\x04\x74\x65\x73\x74";
3841 const SESS_STATE_TRACK_OK: &[u8] = b"\x00\x00\x00\x02\x40\x00\x00\x00\x04\x02\x02\x01\x31";
3842 const EOF: &[u8] = b"\xfe\x00\x00\x02\x00";
3843
3844 OkPacketDeserializer::<ResultSetTerminator>::deserialize(
3846 CapabilityFlags::empty(),
3847 &mut ParseBuf(PLAIN_OK),
3848 )
3849 .unwrap_err();
3850
3851 let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3852 CapabilityFlags::empty(),
3853 &mut ParseBuf(PLAIN_OK),
3854 )
3855 .unwrap()
3856 .into();
3857 assert_eq!(ok_packet.affected_rows(), 1);
3858 assert_eq!(ok_packet.last_insert_id(), None);
3859 assert_eq!(
3860 ok_packet.status_flags(),
3861 StatusFlags::SERVER_STATUS_AUTOCOMMIT
3862 );
3863 assert_eq!(ok_packet.warnings(), 0);
3864 assert_eq!(ok_packet.info_ref(), None);
3865 assert_eq!(ok_packet.session_state_info_ref(), None);
3866
3867 let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3868 CapabilityFlags::CLIENT_SESSION_TRACK,
3869 &mut ParseBuf(PLAIN_OK),
3870 )
3871 .unwrap()
3872 .into();
3873 assert_eq!(ok_packet.affected_rows(), 1);
3874 assert_eq!(ok_packet.last_insert_id(), None);
3875 assert_eq!(
3876 ok_packet.status_flags(),
3877 StatusFlags::SERVER_STATUS_AUTOCOMMIT
3878 );
3879 assert_eq!(ok_packet.warnings(), 0);
3880 assert_eq!(ok_packet.info_ref(), None);
3881 assert_eq!(ok_packet.session_state_info_ref(), None);
3882
3883 let ok_packet: OkPacket<'_> = OkPacketDeserializer::<ResultSetTerminator>::deserialize(
3884 CapabilityFlags::CLIENT_SESSION_TRACK,
3885 &mut ParseBuf(RESULT_SET_TERMINATOR),
3886 )
3887 .unwrap()
3888 .into();
3889 assert_eq!(ok_packet.affected_rows(), 0);
3890 assert_eq!(ok_packet.last_insert_id(), None);
3891 assert_eq!(ok_packet.status_flags(), StatusFlags::empty());
3892 assert_eq!(ok_packet.warnings(), 0);
3893 assert_eq!(
3894 ok_packet.info_str(),
3895 Some(Cow::Borrowed(
3896 "Read 1 rows, 1.00 B in 0.002 sec., 611.34 rows/sec., 611.34 B/sec."
3897 ))
3898 );
3899 assert_eq!(ok_packet.session_state_info_ref(), None);
3900
3901 let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3902 CapabilityFlags::CLIENT_SESSION_TRACK,
3903 &mut ParseBuf(SESS_STATE_SYS_VAR_OK),
3904 )
3905 .unwrap()
3906 .into();
3907 assert_eq!(ok_packet.affected_rows(), 0);
3908 assert_eq!(ok_packet.last_insert_id(), None);
3909 assert_eq!(
3910 ok_packet.status_flags(),
3911 StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3912 );
3913 assert_eq!(ok_packet.warnings(), 0);
3914 assert_eq!(ok_packet.info_ref(), None);
3915 let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3916
3917 match sess_state_info.decode().unwrap() {
3918 SessionStateChange::SystemVariables(mut vals) => {
3919 let val = vals.pop().unwrap();
3920 assert_eq!(val.name_bytes(), b"autocommit");
3921 assert_eq!(val.value_bytes(), b"OFF");
3922 assert!(vals.is_empty());
3923 }
3924 _ => panic!(),
3925 }
3926
3927 let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3928 CapabilityFlags::CLIENT_SESSION_TRACK,
3929 &mut ParseBuf(SESS_STATE_SCHEMA_OK),
3930 )
3931 .unwrap()
3932 .into();
3933 assert_eq!(ok_packet.affected_rows(), 0);
3934 assert_eq!(ok_packet.last_insert_id(), None);
3935 assert_eq!(
3936 ok_packet.status_flags(),
3937 StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3938 );
3939 assert_eq!(ok_packet.warnings(), 0);
3940 assert_eq!(ok_packet.info_ref(), None);
3941 let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3942 match sess_state_info.decode().unwrap() {
3943 SessionStateChange::Schema(schema) => assert_eq!(schema.as_bytes(), b"test"),
3944 _ => panic!(),
3945 }
3946
3947 let ok_packet: OkPacket<'_> = OkPacketDeserializer::<CommonOkPacket>::deserialize(
3948 CapabilityFlags::CLIENT_SESSION_TRACK,
3949 &mut ParseBuf(SESS_STATE_TRACK_OK),
3950 )
3951 .unwrap()
3952 .into();
3953 assert_eq!(ok_packet.affected_rows(), 0);
3954 assert_eq!(ok_packet.last_insert_id(), None);
3955 assert_eq!(
3956 ok_packet.status_flags(),
3957 StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED
3958 );
3959 assert_eq!(ok_packet.warnings(), 0);
3960 assert_eq!(ok_packet.info_ref(), None);
3961 let sess_state_info = ok_packet.session_state_info().unwrap().pop().unwrap();
3962 assert_eq!(
3963 sess_state_info.decode().unwrap(),
3964 SessionStateChange::IsTracked(true),
3965 );
3966
3967 let ok_packet: OkPacket<'_> = OkPacketDeserializer::<OldEofPacket>::deserialize(
3968 CapabilityFlags::CLIENT_SESSION_TRACK,
3969 &mut ParseBuf(EOF),
3970 )
3971 .unwrap()
3972 .into();
3973 assert_eq!(ok_packet.affected_rows(), 0);
3974 assert_eq!(ok_packet.last_insert_id(), None);
3975 assert_eq!(
3976 ok_packet.status_flags(),
3977 StatusFlags::SERVER_STATUS_AUTOCOMMIT
3978 );
3979 assert_eq!(ok_packet.warnings(), 0);
3980 assert_eq!(ok_packet.info_ref(), None);
3981 assert_eq!(ok_packet.session_state_info_ref(), None);
3982 }
3983
3984 #[test]
3985 fn should_build_handshake_response() {
3986 let flags_without_db_name = CapabilityFlags::from_bits_truncate(0x81aea205);
3987 let response = HandshakeResponse::new(
3988 Some(&[][..]),
3989 (5u16, 5, 5),
3990 Some(&b"root"[..]),
3991 None::<&'static [u8]>,
3992 Some(AuthPlugin::MysqlNativePassword),
3993 flags_without_db_name,
3994 None,
3995 1_u32.to_be(),
3996 );
3997 let mut actual = Vec::new();
3998 response.serialize(&mut actual);
3999
4000 let expected: Vec<u8> = [
4001 0x05, 0xa2, 0xae, 0x81, 0x00, 0x00, 0x00, 0x01, 0x2d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4005 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4009 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, ]
4011 .to_vec();
4012
4013 assert_eq!(expected, actual);
4014
4015 let flags_with_db_name = flags_without_db_name | CapabilityFlags::CLIENT_CONNECT_WITH_DB;
4016 let response = HandshakeResponse::new(
4017 Some(&[][..]),
4018 (5u16, 5, 5),
4019 Some(&b"root"[..]),
4020 Some(&b"mydb"[..]),
4021 Some(AuthPlugin::MysqlNativePassword),
4022 flags_with_db_name,
4023 None,
4024 1_u32.to_be(),
4025 );
4026 let mut actual = Vec::new();
4027 response.serialize(&mut actual);
4028
4029 let expected: Vec<u8> = [
4030 0x0d, 0xa2, 0xae, 0x81, 0x00, 0x00, 0x00, 0x01, 0x2d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4034 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x00, 0x6d, 0x79, 0x64, 0x62, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4039 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, ]
4041 .to_vec();
4042
4043 assert_eq!(expected, actual);
4044
4045 let response = HandshakeResponse::new(
4046 Some(&[][..]),
4047 (5u16, 5, 5),
4048 Some(&b"root"[..]),
4049 Some(&b"mydb"[..]),
4050 Some(AuthPlugin::MysqlNativePassword),
4051 flags_without_db_name,
4052 None,
4053 1_u32.to_be(),
4054 );
4055 let mut actual = Vec::new();
4056 response.serialize(&mut actual);
4057 assert_eq!(expected, actual);
4058
4059 let response = HandshakeResponse::new(
4060 Some(&[][..]),
4061 (5u16, 5, 5),
4062 Some(&b"root"[..]),
4063 Some(&[][..]),
4064 Some(AuthPlugin::MysqlNativePassword),
4065 flags_with_db_name,
4066 None,
4067 1_u32.to_be(),
4068 );
4069 let mut actual = Vec::new();
4070 response.serialize(&mut actual);
4071
4072 let expected: Vec<u8> = [
4073 0x0d, 0xa2, 0xae, 0x81, 0x00, 0x00, 0x00, 0x01, 0x2d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
4077 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x00, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
4082 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, ]
4084 .to_vec();
4085 assert_eq!(expected, actual);
4086 }
4087
4088 #[test]
4089 fn parse_str_to_sid() {
4090 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:23";
4091 let sid = input.parse::<Sid<'_>>().unwrap();
4092 let expected_sid = Uuid::parse_str("3E11FA47-71CA-11E1-9E33-C80AA9429562").unwrap();
4093 assert_eq!(sid.uuid, *expected_sid.as_bytes());
4094 assert_eq!(sid.intervals.len(), 1);
4095 assert_eq!(sid.intervals[0].start.0, 23);
4096 assert_eq!(sid.intervals[0].end.0, 24);
4097
4098 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15";
4099 let sid = input.parse::<Sid<'_>>().unwrap();
4100 assert_eq!(sid.uuid, *expected_sid.as_bytes());
4101 assert_eq!(sid.intervals.len(), 2);
4102 assert_eq!(sid.intervals[0].start.0, 1);
4103 assert_eq!(sid.intervals[0].end.0, 6);
4104 assert_eq!(sid.intervals[1].start.0, 10);
4105 assert_eq!(sid.intervals[1].end.0, 16);
4106
4107 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562";
4108 let e = input.parse::<Sid<'_>>().unwrap_err();
4109 assert_eq!(
4110 e.to_string(),
4111 "invalid sid format: 3E11FA47-71CA-11E1-9E33-C80AA9429562".to_string()
4112 );
4113
4114 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-";
4115 let e = input.parse::<Sid<'_>>().unwrap_err();
4116 assert_eq!(e.to_string(), "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-, error: cannot parse integer from empty string".to_string());
4117
4118 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa";
4119 let e = input.parse::<Sid<'_>>().unwrap_err();
4120 assert_eq!(e.to_string(), "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa, error: invalid digit found in string".to_string());
4121
4122 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:0-3";
4123 let e = input.parse::<Sid<'_>>().unwrap_err();
4124 assert_eq!(e.to_string(), "Gno can't be zero".to_string());
4125
4126 let input = "3E11FA47-71CA-11E1-9E33-C80AA9429562:4-3";
4127 let e = input.parse::<Sid<'_>>().unwrap_err();
4128 assert_eq!(
4129 e.to_string(),
4130 "start(4) >= end(4) in GnoInterval".to_string()
4131 );
4132 }
4133
4134 #[test]
4135 fn should_parse_rsa_public_key_response_packet() {
4136 const PUBLIC_RSA_KEY_RESPONSE: &[u8] = b"\x01test";
4137
4138 let rsa_public_key_response =
4139 PublicKeyResponse::deserialize((), &mut ParseBuf(PUBLIC_RSA_KEY_RESPONSE));
4140
4141 assert!(rsa_public_key_response.is_ok());
4142 assert_eq!(rsa_public_key_response.unwrap().rsa_key(), "test");
4143 }
4144
4145 #[test]
4146 fn should_build_rsa_public_key_response_packet() {
4147 let rsa_public_key_response = PublicKeyResponse::new("test".as_bytes());
4148
4149 let mut actual = Vec::new();
4150 rsa_public_key_response.serialize(&mut actual);
4151
4152 let expected = b"\x01test".to_vec();
4153
4154 assert_eq!(expected, actual);
4155 }
4156}