1use btoi::btoi;
10use bytes::BufMut;
11use regex::bytes::Regex;
12use uuid::Uuid;
13
14use std::str::FromStr;
15use std::sync::{Arc, LazyLock};
16use std::{
17 borrow::Cow, cmp::max, collections::HashMap, convert::TryFrom, fmt, io, marker::PhantomData,
18};
19
20use crate::collations::CollationId;
21use crate::scramble::create_response_for_ed25519;
22use crate::{
23 constants::{
24 CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, MAX_PAYLOAD_LEN,
25 MariadbCapabilities, SessionStateType, StatusFlags, StmtExecuteParamFlags,
26 StmtExecuteParamsFlags,
27 },
28 io::{BufMutExt, ParseBuf},
29 misc::{
30 lenenc_str_len,
31 raw::{
32 Const, Either, RawBytes, RawConst, RawInt, Skip,
33 bytes::{
34 BareBytes, ConstBytes, ConstBytesValue, EofBytes, LenEnc, NullBytes, U8Bytes,
35 U32Bytes,
36 },
37 int::{ConstU8, ConstU32, LeU16, LeU24, LeU32, LeU32LowerHalf, LeU32UpperHalf, LeU64},
38 seq::Seq,
39 },
40 unexpected_buf_eof,
41 },
42 proto::{MyDeserialize, MySerialize},
43 value::{ClientSide, SerializationSide, Value},
44};
45
46use self::session_state_change::SessionStateChange;
47
48static MARIADB_VERSION_RE: LazyLock<Regex> =
49 LazyLock::new(|| Regex::new(r"^(?:5.5.5-)?(\d{1,2})\.(\d{1,2})\.(\d{1,3})-MariaDB").unwrap());
50static VERSION_RE: LazyLock<Regex> =
51 LazyLock::new(|| Regex::new(r"^(\d{1,2})\.(\d{1,2})\.(\d{1,3})(.*)").unwrap());
52
/// Defines a zero-sized error type `$err` and a type alias `$name` for
/// `ConstU8<$err, VALUE>`, i.e. a single-byte constant parameterized by its error type.
///
/// Two forms are accepted:
/// * `define_header!(Name, Err("message"), 0xNN)` — the byte is the given literal and
///   `Err` displays `"message"`;
/// * `define_header!(Name, Variant, Err)` — the byte is `Command::Variant as u8` and
///   `Err` displays `"Invalid header for Variant"`.
macro_rules! define_header {
    // Literal byte value with a custom error message.
    ($name:ident, $err:ident($msg:literal), $val:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($msg)]
        pub struct $err;
        pub type $name = crate::misc::raw::int::ConstU8<$err, $val>;
    };
    // Byte value taken from a `Command` variant; the message names that variant.
    ($name:ident, $cmd:ident, $err:ident) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error("Invalid header for {}", stringify!($cmd))]
        pub struct $err;
        pub type $name = crate::misc::raw::int::ConstU8<$err, { Command::$cmd as u8 }>;
    };
}
67
/// Defines a zero-sized error type `$err` (displaying `$msg`) and a type alias
/// `$name = $kind<$err, $val>`, where `$kind` is a constant-value wrapper such as
/// `ConstU8` parameterized by its error type and expected literal value.
macro_rules! define_const {
    ($kind:ident, $name:ident, $err:ident($msg:literal), $val:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($msg)]
        pub struct $err;
        pub type $name = $kind<$err, $val>;
    };
}
76
/// Defines a fixed byte-string constant type, made of three items:
/// * `$err` — zero-sized error type displaying `$msg`;
/// * `$vname` — marker type implementing `ConstBytesValue<$len>` with `VALUE = $val`
///   and `Error = $err`;
/// * `$name` — alias for `ConstBytes<$vname, $len>`.
///
/// `$len` must equal the length of the `$val` array, because `VALUE` is declared
/// as `[u8; $len]`.
macro_rules! define_const_bytes {
    ($vname:ident, $name:ident, $err:ident($msg:literal), $val:expr_2021, $len:literal) => {
        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
        #[error($msg)]
        pub struct $err;

        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
        pub struct $vname;

        impl ConstBytesValue<$len> for $vname {
            const VALUE: [u8; $len] = $val;
            type Error = $err;
        }

        pub type $name = ConstBytes<$vname, $len>;
    };
}
94
95pub mod binlog_request;
96pub mod caching_sha2_password;
97pub mod session_state_change;
98
// Catalog field of a column definition: expected to be the length-encoded string
// "def" (length byte 0x03 followed by `def`, four bytes in total).
define_const_bytes!(
    Catalog,
    ColumnDefinitionCatalog,
    InvalidCatalog("Invalid catalog value in the column definition"),
    *b"\x03def",
    4
);
106
// Length byte that introduces the fixed-length tail of a column definition;
// it must be 0x0c (the 12 bytes that follow it in `Column`'s 13-byte tail).
define_const!(
    ConstU8,
    FixedLengthFieldsLen,
    InvalidFixedLengthFieldsLen("Invalid fixed length field length in the column definition"),
    0x0c
);
113
/// The five length-encoded name fields of a column definition, in the order
/// they appear on the wire.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
struct ColumnMeta<'a> {
    /// Schema name.
    schema: RawBytes<'a, LenEnc>,
    /// Table name.
    table: RawBytes<'a, LenEnc>,
    /// `org_table` field (presumably the original, un-aliased table name).
    org_table: RawBytes<'a, LenEnc>,
    /// Column name.
    name: RawBytes<'a, LenEnc>,
    /// `org_name` field (presumably the original, un-aliased column name).
    org_name: RawBytes<'a, LenEnc>,
}
123
124impl ColumnMeta<'_> {
125 pub fn into_owned(self) -> ColumnMeta<'static> {
126 ColumnMeta {
127 schema: self.schema.into_owned(),
128 table: self.table.into_owned(),
129 org_table: self.org_table.into_owned(),
130 name: self.name.into_owned(),
131 org_name: self.org_name.into_owned(),
132 }
133 }
134
135 pub fn schema_ref(&self) -> &[u8] {
137 self.schema.as_bytes()
138 }
139
140 pub fn schema_str(&self) -> Cow<'_, str> {
142 String::from_utf8_lossy(self.schema_ref())
143 }
144
145 pub fn table_ref(&self) -> &[u8] {
147 self.table.as_bytes()
148 }
149
150 pub fn table_str(&self) -> Cow<'_, str> {
152 String::from_utf8_lossy(self.table_ref())
153 }
154
155 pub fn org_table_ref(&self) -> &[u8] {
159 self.org_table.as_bytes()
160 }
161
162 pub fn org_table_str(&self) -> Cow<'_, str> {
164 String::from_utf8_lossy(self.org_table_ref())
165 }
166
167 pub fn name_ref(&self) -> &[u8] {
169 self.name.as_bytes()
170 }
171
172 pub fn name_str(&self) -> Cow<'_, str> {
174 String::from_utf8_lossy(self.name_ref())
175 }
176
177 pub fn org_name_ref(&self) -> &[u8] {
181 self.org_name.as_bytes()
182 }
183
184 pub fn org_name_str(&self) -> Cow<'_, str> {
186 String::from_utf8_lossy(self.org_name_ref())
187 }
188}
189
190impl<'de> MyDeserialize<'de> for ColumnMeta<'de> {
191 const SIZE: Option<usize> = None;
192 type Ctx = ();
193
194 fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
195 Ok(Self {
196 schema: buf.parse_unchecked(())?,
197 table: buf.parse_unchecked(())?,
198 org_table: buf.parse_unchecked(())?,
199 name: buf.parse_unchecked(())?,
200 org_name: buf.parse_unchecked(())?,
201 })
202 }
203}
204
205impl MySerialize for ColumnMeta<'_> {
206 fn serialize(&self, buf: &mut Vec<u8>) {
207 self.schema.serialize(&mut *buf);
208 self.table.serialize(&mut *buf);
209 self.org_table.serialize(&mut *buf);
210 self.name.serialize(&mut *buf);
211 self.org_name.serialize(&mut *buf);
212 }
213}
214
/// A column definition packet.
///
/// The name fields live in a shared `Arc<ColumnMeta>` so cloning a `Column` does not
/// copy them; the remaining fields form the fixed-length tail of the packet.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Column {
    /// Catalog, expected to be the length-encoded "def".
    catalog: ColumnDefinitionCatalog,
    /// Schema/table/column names.
    meta: Arc<ColumnMeta<'static>>,
    /// Length byte of the fixed-length tail (0x0c).
    fixed_length_fields_len: FixedLengthFieldsLen,
    /// 4-byte little-endian column length.
    column_length: RawInt<LeU32>,
    /// 2-byte little-endian character set id.
    character_set: RawInt<LeU16>,
    /// 1-byte column type.
    column_type: Const<ColumnType, u8>,
    /// 2-byte little-endian column flags.
    flags: Const<ColumnFlags, LeU16>,
    /// Number of decimals.
    decimals: RawInt<u8>,
    /// Two trailing filler bytes.
    __filler: Skip<2>,
}
229
230impl<'de> MyDeserialize<'de> for Column {
231 const SIZE: Option<usize> = None;
232 type Ctx = ();
233
234 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
235 let catalog = buf.parse(())?;
236 let meta = Arc::new(buf.parse::<ColumnMeta<'_>>(())?.into_owned());
237 let mut buf: ParseBuf<'_> = buf.parse(13)?;
238
239 Ok(Column {
240 catalog,
241 meta,
242 fixed_length_fields_len: buf.parse_unchecked(())?,
243 character_set: buf.parse_unchecked(())?,
244 column_length: buf.parse_unchecked(())?,
245 column_type: buf.parse_unchecked(())?,
246 flags: buf.parse_unchecked(())?,
247 decimals: buf.parse_unchecked(())?,
248 __filler: buf.parse_unchecked(())?,
249 })
250 }
251}
252
253impl MySerialize for Column {
254 fn serialize(&self, buf: &mut Vec<u8>) {
255 self.catalog.serialize(&mut *buf);
256 self.meta.serialize(&mut *buf);
257 self.fixed_length_fields_len.serialize(&mut *buf);
258 self.column_length.serialize(&mut *buf);
259 self.character_set.serialize(&mut *buf);
260 self.column_type.serialize(&mut *buf);
261 self.flags.serialize(&mut *buf);
262 self.decimals.serialize(&mut *buf);
263 self.__filler.serialize(&mut *buf);
264 }
265}
266
267impl Column {
268 pub fn new(column_type: ColumnType) -> Self {
269 Self {
270 catalog: Default::default(),
271 meta: Default::default(),
272 fixed_length_fields_len: Default::default(),
273 column_length: Default::default(),
274 character_set: Default::default(),
275 flags: Default::default(),
276 column_type: Const::new(column_type),
277 decimals: Default::default(),
278 __filler: Skip,
279 }
280 }
281
282 pub fn with_schema(mut self, schema: &[u8]) -> Self {
283 Arc::make_mut(&mut self.meta).schema = RawBytes::new(schema).into_owned();
284 self
285 }
286
287 pub fn with_table(mut self, table: &[u8]) -> Self {
288 Arc::make_mut(&mut self.meta).table = RawBytes::new(table).into_owned();
289 self
290 }
291
292 pub fn with_org_table(mut self, org_table: &[u8]) -> Self {
293 Arc::make_mut(&mut self.meta).org_table = RawBytes::new(org_table).into_owned();
294 self
295 }
296
297 pub fn with_name(mut self, name: &[u8]) -> Self {
298 Arc::make_mut(&mut self.meta).name = RawBytes::new(name).into_owned();
299 self
300 }
301
302 pub fn with_org_name(mut self, org_name: &[u8]) -> Self {
303 Arc::make_mut(&mut self.meta).org_name = RawBytes::new(org_name).into_owned();
304 self
305 }
306
307 pub fn with_flags(mut self, flags: ColumnFlags) -> Self {
308 self.flags = Const::new(flags);
309 self
310 }
311
312 pub fn with_column_length(mut self, column_length: u32) -> Self {
313 self.column_length = RawInt::new(column_length);
314 self
315 }
316
317 pub fn with_character_set(mut self, character_set: u16) -> Self {
318 self.character_set = RawInt::new(character_set);
319 self
320 }
321
322 pub fn with_decimals(mut self, decimals: u8) -> Self {
323 self.decimals = RawInt::new(decimals);
324 self
325 }
326
327 pub fn column_length(&self) -> u32 {
331 *self.column_length
332 }
333
334 pub fn column_type(&self) -> ColumnType {
336 *self.column_type
337 }
338
339 pub fn character_set(&self) -> u16 {
341 *self.character_set
342 }
343
344 pub fn flags(&self) -> ColumnFlags {
346 *self.flags
347 }
348
349 pub fn decimals(&self) -> u8 {
357 *self.decimals
358 }
359
360 #[inline(always)]
362 pub fn schema_ref(&self) -> &[u8] {
363 self.meta.schema_ref()
364 }
365
366 #[inline(always)]
368 pub fn schema_str(&self) -> Cow<'_, str> {
369 self.meta.schema_str()
370 }
371
372 #[inline(always)]
374 pub fn table_ref(&self) -> &[u8] {
375 self.meta.table_ref()
376 }
377
378 #[inline(always)]
380 pub fn table_str(&self) -> Cow<'_, str> {
381 self.meta.table_str()
382 }
383
384 #[inline(always)]
388 pub fn org_table_ref(&self) -> &[u8] {
389 self.meta.org_table_ref()
390 }
391
392 #[inline(always)]
394 pub fn org_table_str(&self) -> Cow<'_, str> {
395 self.meta.org_table_str()
396 }
397
398 #[inline(always)]
400 pub fn name_ref(&self) -> &[u8] {
401 self.meta.name_ref()
402 }
403
404 #[inline(always)]
406 pub fn name_str(&self) -> Cow<'_, str> {
407 self.meta.name_str()
408 }
409
410 #[inline(always)]
414 pub fn org_name_ref(&self) -> &[u8] {
415 self.meta.org_name_ref()
416 }
417
418 #[inline(always)]
420 pub fn org_name_str(&self) -> Cow<'_, str> {
421 self.meta.org_name_str()
422 }
423}
424
/// One session-state-change entry from an OK packet: a type byte followed by a
/// length-encoded payload (decoded on demand by `decode`).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SessionStateInfo<'a> {
    /// Kind of state change, stored as a single byte.
    data_type: Const<SessionStateType, u8>,
    /// Undecoded payload of this entry.
    data: RawBytes<'a, LenEnc>,
}
431
432impl SessionStateInfo<'_> {
433 pub fn into_owned(self) -> SessionStateInfo<'static> {
434 let SessionStateInfo { data_type, data } = self;
435 SessionStateInfo {
436 data_type,
437 data: data.into_owned(),
438 }
439 }
440
441 pub fn data_type(&self) -> SessionStateType {
442 *self.data_type
443 }
444
445 pub fn data_ref(&self) -> &[u8] {
447 self.data.as_bytes()
448 }
449
450 pub fn decode(&self) -> io::Result<SessionStateChange<'_>> {
452 ParseBuf(self.data.as_bytes()).parse_unchecked(*self.data_type)
453 }
454}
455
456impl<'de> MyDeserialize<'de> for SessionStateInfo<'de> {
457 const SIZE: Option<usize> = None;
458 type Ctx = ();
459
460 fn deserialize(_ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
461 Ok(SessionStateInfo {
462 data_type: buf.parse(())?,
463 data: buf.parse(())?,
464 })
465 }
466}
467
468impl MySerialize for SessionStateInfo<'_> {
469 fn serialize(&self, buf: &mut Vec<u8>) {
470 self.data_type.serialize(&mut *buf);
471 self.data.serialize(buf);
472 }
473}
474
/// Raw fields of an OK-like packet body, as produced by `OkPacketKind::parse_body`
/// and converted into an `OkPacket` via `TryFrom`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OkPacketBody<'a> {
    /// Length-encoded affected-rows count.
    affected_rows: RawInt<LenEnc>,
    /// Length-encoded last insert id.
    last_insert_id: RawInt<LenEnc>,
    /// 2-byte little-endian server status flags.
    status_flags: Const<StatusFlags, LeU16>,
    /// 2-byte little-endian warning count.
    warnings: RawInt<LeU16>,
    /// Length-encoded info string (empty if absent).
    info: RawBytes<'a, LenEnc>,
    /// Length-encoded session-state-change data (empty if absent).
    session_state_info: RawBytes<'a, LenEnc>,
}
485
/// A flavour of OK-like packet: the header byte it must start with and the way the
/// remainder of the packet is parsed into an `OkPacketBody`.
pub trait OkPacketKind {
    /// Expected first byte; `OkPacketDeserializer` rejects packets that differ.
    const HEADER: u8;

    /// Parses the packet body that follows the header byte.
    fn parse_body<'de>(
        capabilities: CapabilityFlags,
        buf: &mut ParseBuf<'de>,
    ) -> io::Result<OkPacketBody<'de>>;
}
497
/// OK-packet flavour with header `0xFE` whose body always reports zero affected rows
/// and no last insert id.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct ResultSetTerminator;
501
impl OkPacketKind for ResultSetTerminator {
    const HEADER: u8 = 0xFE;

    /// Parses the body of a `0xFE`-headed OK packet.
    ///
    /// The two leading length-encoded integers (the positions `CommonOkPacket` reads as
    /// affected rows and last insert id) are consumed and discarded; the returned body
    /// reports both as zero.
    fn parse_body<'de>(
        capabilities: CapabilityFlags,
        buf: &mut ParseBuf<'de>,
    ) -> io::Result<OkPacketBody<'de>> {
        buf.parse::<RawInt<LenEnc>>(())?;
        buf.parse::<RawInt<LenEnc>>(())?;

        // Status flags and warning count occupy the next four bytes.
        let mut sbuf: ParseBuf<'_> = buf.parse(4)?;
        let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
        let warnings = sbuf.parse_unchecked(())?;

        let (info, session_state_info) =
            if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
                // With session tracking: an info string, followed by session state data
                // only when the server set SERVER_SESSION_STATE_CHANGED.
                let info = buf.parse(())?;
                let session_state_info =
                    if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
                        buf.parse(())?
                    } else {
                        RawBytes::default()
                    };
                (info, session_state_info)
            } else if !buf.is_empty() && buf.0[0] > 0 {
                // Without session tracking, a non-zero leading byte means a
                // length-encoded info string follows.
                let info = buf.parse(())?;
                (info, RawBytes::default())
            } else {
                (RawBytes::default(), RawBytes::default())
            };

        Ok(OkPacketBody {
            affected_rows: RawInt::new(0),
            last_insert_id: RawInt::new(0),
            status_flags,
            warnings,
            info,
            session_state_info,
        })
    }
}
551
/// EOF-packet flavour with header `0xFE`: only a warning count and status flags follow.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct OldEofPacket;
555
impl OkPacketKind for OldEofPacket {
    const HEADER: u8 = 0xFE;

    /// Parses the four-byte EOF body.
    ///
    /// The warning count comes first, then the status flags. Affected rows and last
    /// insert id are reported as zero, and info and session state as empty.
    fn parse_body<'de>(
        _: CapabilityFlags,
        buf: &mut ParseBuf<'de>,
    ) -> io::Result<OkPacketBody<'de>> {
        let mut buf: ParseBuf<'_> = buf.parse(4)?;
        // Note: warnings precede status flags here, unlike in the OK-packet layouts.
        let warnings = buf.parse_unchecked(())?;
        let status_flags = buf.parse_unchecked(())?;

        Ok(OkPacketBody {
            affected_rows: RawInt::new(0),
            last_insert_id: RawInt::new(0),
            status_flags,
            warnings,
            info: RawBytes::new(&[][..]),
            session_state_info: RawBytes::new(&[][..]),
        })
    }
}
578
/// `0xFE`-headed terminator parsed exactly like `OldEofPacket`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct NetworkStreamTerminator;
582
583impl OkPacketKind for NetworkStreamTerminator {
584 const HEADER: u8 = 0xFE;
585
586 fn parse_body<'de>(
587 flags: CapabilityFlags,
588 buf: &mut ParseBuf<'de>,
589 ) -> io::Result<OkPacketBody<'de>> {
590 OldEofPacket::parse_body(flags, buf)
591 }
592}
593
/// Regular OK packet with header `0x00`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct CommonOkPacket;
597
impl OkPacketKind for CommonOkPacket {
    const HEADER: u8 = 0x00;

    /// Parses a regular OK packet body:
    /// - affected rows and last insert id (length-encoded);
    /// - status flags and warnings (two bytes each);
    /// - then optional info / session-state data.
    fn parse_body<'de>(
        capabilities: CapabilityFlags,
        buf: &mut ParseBuf<'de>,
    ) -> io::Result<OkPacketBody<'de>> {
        let affected_rows = buf.parse(())?;
        let last_insert_id = buf.parse(())?;

        // Status flags and warning count occupy the next four bytes.
        let mut sbuf: ParseBuf<'_> = buf.parse(4)?;
        let status_flags: Const<StatusFlags, LeU16> = sbuf.parse_unchecked(())?;
        let warnings = sbuf.parse_unchecked(())?;

        let (info, session_state_info) =
            if capabilities.contains(CapabilityFlags::CLIENT_SESSION_TRACK) && !buf.is_empty() {
                // With session tracking: an info string, followed by session state data
                // only when the server set SERVER_SESSION_STATE_CHANGED.
                let info = buf.parse(())?;
                let session_state_info =
                    if status_flags.contains(StatusFlags::SERVER_SESSION_STATE_CHANGED) {
                        buf.parse(())?
                    } else {
                        RawBytes::default()
                    };
                (info, session_state_info)
            } else if !buf.is_empty() && buf.0[0] > 0 {
                // Without session tracking, a non-zero leading byte means a
                // length-encoded info string follows.
                let info = buf.parse(())?;
                (info, RawBytes::default())
            } else {
                (RawBytes::default(), RawBytes::default())
            };

        Ok(OkPacketBody {
            affected_rows,
            last_insert_id,
            status_flags,
            warnings,
            info,
            session_state_info,
        })
    }
}
643
644impl<'a> TryFrom<OkPacketBody<'a>> for OkPacket<'a> {
645 type Error = io::Error;
646
647 fn try_from(body: OkPacketBody<'a>) -> io::Result<Self> {
648 Ok(OkPacket {
649 affected_rows: *body.affected_rows,
650 last_insert_id: if *body.last_insert_id == 0 {
651 None
652 } else {
653 Some(*body.last_insert_id)
654 },
655 status_flags: *body.status_flags,
656 warnings: *body.warnings,
657 info: if !body.info.is_empty() {
658 Some(body.info)
659 } else {
660 None
661 },
662 session_state_info: if !body.session_state_info.is_empty() {
663 Some(body.session_state_info)
664 } else {
665 None
666 },
667 })
668 }
669}
670
/// Decoded OK packet.
///
/// Built from an `OkPacketBody` by its `TryFrom` conversion: a zero last insert id is
/// stored as `None`, and empty info / session-state payloads are stored as `None`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OkPacket<'a> {
    /// Number of affected rows.
    affected_rows: u64,
    /// Last insert id, if non-zero.
    last_insert_id: Option<u64>,
    /// Server status flags.
    status_flags: StatusFlags,
    /// Warning count.
    warnings: u16,
    /// Info string, if present and non-empty.
    info: Option<RawBytes<'a, LenEnc>>,
    /// Concatenated session-state-change entries, if present and non-empty.
    session_state_info: Option<RawBytes<'a, LenEnc>>,
}
681
682impl OkPacket<'_> {
683 pub fn into_owned(self) -> OkPacket<'static> {
684 OkPacket {
685 affected_rows: self.affected_rows,
686 last_insert_id: self.last_insert_id,
687 status_flags: self.status_flags,
688 warnings: self.warnings,
689 info: self.info.map(|x| x.into_owned()),
690 session_state_info: self.session_state_info.map(|x| x.into_owned()),
691 }
692 }
693
694 pub fn affected_rows(&self) -> u64 {
696 self.affected_rows
697 }
698
699 pub fn last_insert_id(&self) -> Option<u64> {
701 self.last_insert_id
702 }
703
704 pub fn status_flags(&self) -> StatusFlags {
706 self.status_flags
707 }
708
709 pub fn warnings(&self) -> u16 {
711 self.warnings
712 }
713
714 pub fn info_ref(&self) -> Option<&[u8]> {
716 self.info.as_ref().map(|x| x.as_bytes())
717 }
718
719 pub fn info_str(&self) -> Option<Cow<'_, str>> {
721 self.info.as_ref().map(|x| x.as_str())
722 }
723
724 pub fn session_state_info_ref(&self) -> Option<&[u8]> {
726 self.session_state_info.as_ref().map(|x| x.as_bytes())
727 }
728
729 pub fn session_state_info(&self) -> io::Result<Vec<SessionStateInfo<'_>>> {
731 self.session_state_info_ref()
732 .map(|data| {
733 let mut data = ParseBuf(data);
734 let mut entries = Vec::new();
735 while !data.is_empty() {
736 entries.push(data.parse(())?);
737 }
738 Ok(entries)
739 })
740 .transpose()
741 .map(|x| x.unwrap_or_default())
742 }
743}
744
/// Deserializes an `OkPacket` whose header byte and body layout are chosen by the
/// `OkPacketKind` type parameter `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OkPacketDeserializer<'de, T>(OkPacket<'de>, PhantomData<T>);
747
748impl<'de, T> OkPacketDeserializer<'de, T> {
749 pub fn into_inner(self) -> OkPacket<'de> {
750 self.0
751 }
752}
753
754impl<'de, T> From<OkPacketDeserializer<'de, T>> for OkPacket<'de> {
755 fn from(x: OkPacketDeserializer<'de, T>) -> Self {
756 x.0
757 }
758}
759
/// Error raised by `OkPacketDeserializer` when the first byte is not `T::HEADER`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
#[error("Invalid OK packet header")]
pub struct InvalidOkPacketHeader;
763
764impl<'de, T: OkPacketKind> MyDeserialize<'de> for OkPacketDeserializer<'de, T> {
765 const SIZE: Option<usize> = None;
766 type Ctx = CapabilityFlags;
767
768 fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
769 if *buf.parse::<RawInt<u8>>(())? == T::HEADER {
770 let body = T::parse_body(capabilities, buf)?;
771 let ok = OkPacket::try_from(body)?;
772 Ok(Self(ok, PhantomData))
773 } else {
774 Err(io::Error::new(
775 io::ErrorKind::InvalidData,
776 InvalidOkPacketHeader,
777 ))
778 }
779 }
780}
781
/// Progress report carried in an error packet with code `0xFFFF`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ProgressReport<'a> {
    /// Current stage number.
    stage: RawInt<u8>,
    /// Total number of stages.
    max_stage: RawInt<u8>,
    /// 3-byte little-endian progress value for the current stage.
    progress: RawInt<LeU24>,
    /// Length-encoded description of the current stage.
    stage_info: RawBytes<'a, LenEnc>,
}
790
791impl<'a> ProgressReport<'a> {
792 pub fn new(
793 stage: u8,
794 max_stage: u8,
795 progress: u32,
796 stage_info: impl Into<Cow<'a, [u8]>>,
797 ) -> ProgressReport<'a> {
798 ProgressReport {
799 stage: RawInt::new(stage),
800 max_stage: RawInt::new(max_stage),
801 progress: RawInt::new(progress),
802 stage_info: RawBytes::new(stage_info),
803 }
804 }
805
806 pub fn stage(&self) -> u8 {
808 *self.stage
809 }
810
811 pub fn max_stage(&self) -> u8 {
812 *self.max_stage
813 }
814
815 pub fn progress(&self) -> u32 {
817 *self.progress
818 }
819
820 pub fn stage_info_ref(&self) -> &[u8] {
822 self.stage_info.as_bytes()
823 }
824
825 pub fn stage_info_str(&self) -> Cow<'_, str> {
827 self.stage_info.as_str()
828 }
829
830 pub fn into_owned(self) -> ProgressReport<'static> {
831 ProgressReport {
832 stage: self.stage,
833 max_stage: self.max_stage,
834 progress: self.progress,
835 stage_info: self.stage_info.into_owned(),
836 }
837 }
838}
839
840impl<'de> MyDeserialize<'de> for ProgressReport<'de> {
841 const SIZE: Option<usize> = None;
842 type Ctx = ();
843
844 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
845 let mut sbuf: ParseBuf<'_> = buf.parse(6)?;
846
847 sbuf.skip(1); Ok(ProgressReport {
850 stage: sbuf.parse_unchecked(())?,
851 max_stage: sbuf.parse_unchecked(())?,
852 progress: sbuf.parse_unchecked(())?,
853 stage_info: buf.parse(())?,
854 })
855 }
856}
857
858impl MySerialize for ProgressReport<'_> {
859 fn serialize(&self, buf: &mut Vec<u8>) {
860 buf.put_u8(1);
861 self.stage.serialize(&mut *buf);
862 self.max_stage.serialize(&mut *buf);
863 self.progress.serialize(&mut *buf);
864 self.stage_info.serialize(buf);
865 }
866}
867
868impl fmt::Display for ProgressReport<'_> {
869 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
870 write!(
871 f,
872 "Stage: {} of {} '{}' {:.2}% of stage done",
873 self.stage(),
874 self.max_stage(),
875 self.progress(),
876 self.stage_info_str()
877 )
878 }
879}
880
// First byte of every error packet (0xFF).
define_header!(
    ErrPacketHeader,
    InvalidErrPacketHeader("Invalid error packet header"),
    0xFF
);
886
/// Contents of a `0xFF` error packet.
///
/// The packet holds either a progress report or a regular server error. It is a
/// progress report when the code is `0xFFFF` and `CLIENT_PROGRESS_OBSOLETE` is
/// negotiated; otherwise it is a server error.
#[derive(Debug, Clone, PartialEq)]
pub enum ErrPacket<'a> {
    /// Regular server error.
    Error(ServerError<'a>),
    /// Progress report.
    Progress(ProgressReport<'a>),
}
895
896impl ErrPacket<'_> {
897 pub fn is_error(&self) -> bool {
899 matches!(self, ErrPacket::Error { .. })
900 }
901
902 pub fn is_progress_report(&self) -> bool {
904 !self.is_error()
905 }
906
907 pub fn progress_report(&self) -> &ProgressReport<'_> {
909 match *self {
910 ErrPacket::Progress(ref progress_report) => progress_report,
911 _ => panic!("This ErrPacket does not contains progress report"),
912 }
913 }
914
915 pub fn server_error(&self) -> &ServerError<'_> {
917 match self {
918 ErrPacket::Error(error) => error,
919 ErrPacket::Progress(_) => panic!("This ErrPacket does not contain a ServerError"),
920 }
921 }
922}
923
924impl<'de> MyDeserialize<'de> for ErrPacket<'de> {
925 const SIZE: Option<usize> = None;
926 type Ctx = CapabilityFlags;
927
928 fn deserialize(capabilities: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
929 let mut sbuf: ParseBuf<'_> = buf.parse(3)?;
930 sbuf.parse_unchecked::<ErrPacketHeader>(())?;
931 let code: RawInt<LeU16> = sbuf.parse_unchecked(())?;
932
933 if *code == 0xFFFF && capabilities.contains(CapabilityFlags::CLIENT_PROGRESS_OBSOLETE) {
934 buf.parse(()).map(ErrPacket::Progress)
935 } else {
936 buf.parse((
937 *code,
938 capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41),
939 ))
940 .map(ErrPacket::Error)
941 }
942 }
943}
944
945impl MySerialize for ErrPacket<'_> {
946 fn serialize(&self, buf: &mut Vec<u8>) {
947 ErrPacketHeader::new().serialize(&mut *buf);
948 match self {
949 ErrPacket::Error(server_error) => {
950 server_error.code.serialize(&mut *buf);
951 server_error.serialize(buf);
952 }
953 ErrPacket::Progress(progress_report) => {
954 RawInt::<LeU16>::new(0xFFFF).serialize(&mut *buf);
955 progress_report.serialize(buf);
956 }
957 }
958 }
959}
960
961impl fmt::Display for ErrPacket<'_> {
962 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
963 match self {
964 ErrPacket::Error(server_error) => write!(f, "{}", server_error),
965 ErrPacket::Progress(progress_report) => write!(f, "{}", progress_report),
966 }
967 }
968}
969
970define_header!(
971 SqlStateMarker,
972 InvalidSqlStateMarker("Invalid SqlStateMarker value"),
973 b'#'
974);
975
976#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
978pub struct SqlState {
979 __state_marker: SqlStateMarker,
980 state: [u8; 5],
981}
982
983impl SqlState {
984 pub fn new(state: [u8; 5]) -> Self {
986 Self {
987 __state_marker: SqlStateMarker::new(),
988 state,
989 }
990 }
991
992 pub fn as_bytes(&self) -> [u8; 5] {
994 self.state
995 }
996
997 pub fn as_str(&self) -> Cow<'_, str> {
999 String::from_utf8_lossy(&self.state)
1000 }
1001}
1002
1003impl<'de> MyDeserialize<'de> for SqlState {
1004 const SIZE: Option<usize> = Some(6);
1005 type Ctx = ();
1006
1007 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1008 Ok(Self {
1009 __state_marker: buf.parse(())?,
1010 state: buf.parse(())?,
1011 })
1012 }
1013}
1014
1015impl MySerialize for SqlState {
1016 fn serialize(&self, buf: &mut Vec<u8>) {
1017 self.__state_marker.serialize(buf);
1018 self.state.serialize(buf);
1019 }
1020}
1021
/// Server error carried by an error packet.
#[derive(Debug, Clone, PartialEq)]
pub struct ServerError<'a> {
    /// 2-byte little-endian error code.
    code: RawInt<LeU16>,
    /// SQLSTATE; populated only when deserialized with protocol 41.
    state: Option<SqlState>,
    /// Error message (an `EofBytes` field, presumably running to the end of the packet).
    message: RawBytes<'a, EofBytes>,
}
1031
1032impl<'a> ServerError<'a> {
1033 pub fn new(code: u16, state: Option<SqlState>, msg: impl Into<Cow<'a, [u8]>>) -> Self {
1034 Self {
1035 code: RawInt::new(code),
1036 state,
1037 message: RawBytes::new(msg),
1038 }
1039 }
1040
1041 pub fn error_code(&self) -> u16 {
1043 *self.code
1044 }
1045
1046 pub fn sql_state_ref(&self) -> Option<&SqlState> {
1048 self.state.as_ref()
1049 }
1050
1051 pub fn message_ref(&self) -> &[u8] {
1053 self.message.as_bytes()
1054 }
1055
1056 pub fn message_str(&self) -> Cow<'_, str> {
1058 self.message.as_str()
1059 }
1060
1061 pub fn into_owned(self) -> ServerError<'static> {
1062 ServerError {
1063 code: self.code,
1064 state: self.state,
1065 message: self.message.into_owned(),
1066 }
1067 }
1068}
1069
1070impl<'de> MyDeserialize<'de> for ServerError<'de> {
1071 const SIZE: Option<usize> = None;
1072 type Ctx = (u16, bool);
1074
1075 fn deserialize((code, protocol_41): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1076 let server_error = if protocol_41 {
1077 ServerError {
1078 code: RawInt::new(code),
1079 state: Some(buf.parse(())?),
1080 message: buf.parse(())?,
1081 }
1082 } else {
1083 ServerError {
1084 code: RawInt::new(code),
1085 state: None,
1086 message: buf.parse(())?,
1087 }
1088 };
1089 Ok(server_error)
1090 }
1091}
1092
1093impl MySerialize for ServerError<'_> {
1094 fn serialize(&self, buf: &mut Vec<u8>) {
1095 if let Some(state) = &self.state {
1096 state.serialize(buf);
1097 }
1098 self.message.serialize(buf);
1099 }
1100}
1101
1102impl fmt::Display for ServerError<'_> {
1103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1104 let sql_state_str = self
1105 .sql_state_ref()
1106 .map(|s| format!(" ({})", s.as_str()))
1107 .unwrap_or_default();
1108
1109 write!(
1110 f,
1111 "ERROR {}{}: {}",
1112 self.error_code(),
1113 sql_state_str,
1114 self.message_str()
1115 )
1116 }
1117}
1118
// First byte (0xFB) of a LOCAL INFILE request packet.
define_header!(
    LocalInfileHeader,
    InvalidLocalInfileHeader("Invalid LOCAL_INFILE header"),
    0xFB
);
1124
/// LOCAL INFILE request: the `0xFB` header followed by the requested file name.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct LocalInfilePacket<'a> {
    /// The `0xFB` header byte.
    __header: LocalInfileHeader,
    /// Requested file name (an `EofBytes` field — the remaining packet bytes).
    file_name: RawBytes<'a, EofBytes>,
}
1131
1132impl<'a> LocalInfilePacket<'a> {
1133 pub fn new(file_name: impl Into<Cow<'a, [u8]>>) -> Self {
1134 Self {
1135 __header: LocalInfileHeader::new(),
1136 file_name: RawBytes::new(file_name),
1137 }
1138 }
1139
1140 pub fn file_name_ref(&self) -> &[u8] {
1142 self.file_name.as_bytes()
1143 }
1144
1145 pub fn file_name_str(&self) -> Cow<'_, str> {
1147 self.file_name.as_str()
1148 }
1149
1150 pub fn into_owned(self) -> LocalInfilePacket<'static> {
1151 LocalInfilePacket {
1152 __header: self.__header,
1153 file_name: self.file_name.into_owned(),
1154 }
1155 }
1156}
1157
1158impl<'de> MyDeserialize<'de> for LocalInfilePacket<'de> {
1159 const SIZE: Option<usize> = None;
1160 type Ctx = ();
1161
1162 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1163 Ok(LocalInfilePacket {
1164 __header: buf.parse(())?,
1165 file_name: buf.parse(())?,
1166 })
1167 }
1168}
1169
1170impl MySerialize for LocalInfilePacket<'_> {
1171 fn serialize(&self, buf: &mut Vec<u8>) {
1172 self.__header.serialize(buf);
1173 self.file_name.serialize(buf);
1174 }
1175}
1176
// Wire names of the authentication plugins that `AuthPlugin` recognises; any other
// name is kept verbatim in `AuthPlugin::Other`.
const MYSQL_OLD_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_old_password";
const MYSQL_NATIVE_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_native_password";
const CACHING_SHA2_PASSWORD_PLUGIN_NAME: &[u8] = b"caching_sha2_password";
const MYSQL_CLEAR_PASSWORD_PLUGIN_NAME: &[u8] = b"mysql_clear_password";
const ED25519_PLUGIN_NAME: &[u8] = b"client_ed25519";
1182
/// Authentication response produced by `AuthPlugin::gen_data`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthPluginData<'a> {
    /// `mysql_old_password` scramble; serialized with a trailing NUL.
    Old([u8; 8]),
    /// `mysql_native_password` scramble.
    Native([u8; 20]),
    /// `caching_sha2_password` scramble.
    Sha2([u8; 32]),
    /// `mysql_clear_password` password bytes; serialized with a trailing NUL.
    Clear(Cow<'a, [u8]>),
    /// `client_ed25519` response.
    Ed25519([u8; 64]),
}
1199
1200impl AuthPluginData<'_> {
1201 pub fn into_owned(self) -> AuthPluginData<'static> {
1202 match self {
1203 AuthPluginData::Old(x) => AuthPluginData::Old(x),
1204 AuthPluginData::Native(x) => AuthPluginData::Native(x),
1205 AuthPluginData::Sha2(x) => AuthPluginData::Sha2(x),
1206 AuthPluginData::Clear(x) => AuthPluginData::Clear(Cow::Owned(x.into_owned())),
1207 AuthPluginData::Ed25519(x) => AuthPluginData::Ed25519(x),
1208 }
1209 }
1210}
1211
1212impl std::ops::Deref for AuthPluginData<'_> {
1213 type Target = [u8];
1214
1215 fn deref(&self) -> &Self::Target {
1216 match self {
1217 Self::Sha2(x) => &x[..],
1218 Self::Native(x) => &x[..],
1219 Self::Old(x) => &x[..],
1220 Self::Clear(x) => &x[..],
1221 Self::Ed25519(x) => &x[..],
1222 }
1223 }
1224}
1225
1226impl MySerialize for AuthPluginData<'_> {
1227 fn serialize(&self, buf: &mut Vec<u8>) {
1228 match self {
1229 Self::Sha2(x) => buf.put_slice(&x[..]),
1230 Self::Native(x) => buf.put_slice(&x[..]),
1231 Self::Old(x) => {
1232 buf.put_slice(&x[..]);
1233 buf.push(0);
1234 }
1235 Self::Clear(x) => {
1236 buf.put_slice(x);
1237 buf.push(0);
1238 }
1239 Self::Ed25519(x) => buf.put_slice(&x[..]),
1240 }
1241 }
1242}
1243
/// Authentication plugin identified by its wire name.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum AuthPlugin<'a> {
    /// `mysql_old_password`.
    MysqlOldPassword,
    /// `mysql_clear_password`.
    MysqlClearPassword,
    /// `mysql_native_password`.
    MysqlNativePassword,
    /// `caching_sha2_password`.
    CachingSha2Password,
    /// `client_ed25519`.
    Ed25519,
    /// Any other plugin name, stored as received (minus a trailing NUL).
    Other(Cow<'a, [u8]>),
}
1262
1263impl<'de> MyDeserialize<'de> for AuthPlugin<'de> {
1264 const SIZE: Option<usize> = None;
1265 type Ctx = ();
1266
1267 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1268 Ok(Self::from_bytes(buf.eat_all()))
1269 }
1270}
1271
1272impl MySerialize for AuthPlugin<'_> {
1273 fn serialize(&self, buf: &mut Vec<u8>) {
1274 buf.put_slice(self.as_bytes());
1275 buf.put_u8(0);
1276 }
1277}
1278
1279impl<'a> AuthPlugin<'a> {
1280 pub fn from_bytes(name: &'a [u8]) -> AuthPlugin<'a> {
1281 let name = if let [name @ .., 0] = name {
1282 name
1283 } else {
1284 name
1285 };
1286 match name {
1287 CACHING_SHA2_PASSWORD_PLUGIN_NAME => AuthPlugin::CachingSha2Password,
1288 MYSQL_NATIVE_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlNativePassword,
1289 MYSQL_OLD_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlOldPassword,
1290 MYSQL_CLEAR_PASSWORD_PLUGIN_NAME => AuthPlugin::MysqlClearPassword,
1291 ED25519_PLUGIN_NAME => AuthPlugin::Ed25519,
1292 name => AuthPlugin::Other(Cow::Borrowed(name)),
1293 }
1294 }
1295
1296 pub fn as_bytes(&self) -> &[u8] {
1297 match self {
1298 AuthPlugin::CachingSha2Password => CACHING_SHA2_PASSWORD_PLUGIN_NAME,
1299 AuthPlugin::MysqlNativePassword => MYSQL_NATIVE_PASSWORD_PLUGIN_NAME,
1300 AuthPlugin::MysqlOldPassword => MYSQL_OLD_PASSWORD_PLUGIN_NAME,
1301 AuthPlugin::MysqlClearPassword => MYSQL_CLEAR_PASSWORD_PLUGIN_NAME,
1302 AuthPlugin::Ed25519 => ED25519_PLUGIN_NAME,
1303 AuthPlugin::Other(name) => name,
1304 }
1305 }
1306
1307 pub fn into_owned(self) -> AuthPlugin<'static> {
1308 match self {
1309 AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1310 AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1311 AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1312 AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1313 AuthPlugin::Ed25519 => AuthPlugin::Ed25519,
1314 AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Owned(name.into_owned())),
1315 }
1316 }
1317
1318 pub fn borrow(&self) -> AuthPlugin<'_> {
1319 match self {
1320 AuthPlugin::CachingSha2Password => AuthPlugin::CachingSha2Password,
1321 AuthPlugin::MysqlNativePassword => AuthPlugin::MysqlNativePassword,
1322 AuthPlugin::MysqlOldPassword => AuthPlugin::MysqlOldPassword,
1323 AuthPlugin::MysqlClearPassword => AuthPlugin::MysqlClearPassword,
1324 AuthPlugin::Ed25519 => AuthPlugin::Ed25519,
1325 AuthPlugin::Other(name) => AuthPlugin::Other(Cow::Borrowed(name.as_ref())),
1326 }
1327 }
1328
1329 pub fn gen_data<'b>(&self, pass: Option<&'b str>, nonce: &[u8]) -> Option<AuthPluginData<'b>> {
1339 use super::scramble::{scramble_323, scramble_native, scramble_sha256};
1340
1341 match pass {
1342 Some(pass) if !pass.is_empty() => match self {
1343 AuthPlugin::CachingSha2Password => {
1344 scramble_sha256(nonce, pass.as_bytes()).map(AuthPluginData::Sha2)
1345 }
1346 AuthPlugin::MysqlNativePassword => {
1347 scramble_native(nonce, pass.as_bytes()).map(AuthPluginData::Native)
1348 }
1349 AuthPlugin::MysqlOldPassword => {
1350 scramble_323(nonce.chunks(8).next().unwrap(), pass.as_bytes())
1351 .map(AuthPluginData::Old)
1352 }
1353 AuthPlugin::MysqlClearPassword => {
1354 Some(AuthPluginData::Clear(Cow::Borrowed(pass.as_bytes())))
1355 }
1356 AuthPlugin::Ed25519 => Some(AuthPluginData::Ed25519(create_response_for_ed25519(
1357 pass.as_bytes(),
1358 nonce,
1359 ))),
1360 AuthPlugin::Other(_) => None,
1361 },
1362 _ => None,
1363 }
1364 }
1365}
1366
1367define_header!(
1368 AuthMoreDataHeader,
1369 InvalidAuthMoreDataHeader("Invalid AuthMoreData header"),
1370 0x01
1371);
1372
1373#[derive(Debug, Clone, Eq, PartialEq)]
1375pub struct AuthMoreData<'a> {
1376 __header: AuthMoreDataHeader,
1377 data: RawBytes<'a, EofBytes>,
1378}
1379
1380impl<'a> AuthMoreData<'a> {
1381 pub fn new(data: impl Into<Cow<'a, [u8]>>) -> Self {
1382 Self {
1383 __header: AuthMoreDataHeader::new(),
1384 data: RawBytes::new(data),
1385 }
1386 }
1387
1388 pub fn data(&self) -> &[u8] {
1389 self.data.as_bytes()
1390 }
1391
1392 pub fn into_owned(self) -> AuthMoreData<'static> {
1393 AuthMoreData {
1394 __header: self.__header,
1395 data: self.data.into_owned(),
1396 }
1397 }
1398}
1399
1400impl<'de> MyDeserialize<'de> for AuthMoreData<'de> {
1401 const SIZE: Option<usize> = None;
1402 type Ctx = ();
1403
1404 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1405 Ok(Self {
1406 __header: buf.parse(())?,
1407 data: buf.parse(())?,
1408 })
1409 }
1410}
1411
1412impl MySerialize for AuthMoreData<'_> {
1413 fn serialize(&self, buf: &mut Vec<u8>) {
1414 self.__header.serialize(&mut *buf);
1415 self.data.serialize(buf);
1416 }
1417}
1418
1419define_header!(
1420 PublicKeyResponseHeader,
1421 InvalidPublicKeyResponse("Invalid PublicKeyResponse header"),
1422 0x01
1423);
1424
1425#[derive(Debug, Clone, Eq, PartialEq)]
1429pub struct PublicKeyResponse<'a> {
1430 __header: PublicKeyResponseHeader,
1431 rsa_key: RawBytes<'a, EofBytes>,
1432}
1433
1434impl<'a> PublicKeyResponse<'a> {
1435 pub fn new(rsa_key: impl Into<Cow<'a, [u8]>>) -> Self {
1436 Self {
1437 __header: PublicKeyResponseHeader::new(),
1438 rsa_key: RawBytes::new(rsa_key),
1439 }
1440 }
1441
1442 pub fn rsa_key(&self) -> Cow<'_, str> {
1444 self.rsa_key.as_str()
1445 }
1446}
1447
1448impl<'de> MyDeserialize<'de> for PublicKeyResponse<'de> {
1449 const SIZE: Option<usize> = None;
1450 type Ctx = ();
1451
1452 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1453 Ok(Self {
1454 __header: buf.parse(())?,
1455 rsa_key: buf.parse(())?,
1456 })
1457 }
1458}
1459
1460impl MySerialize for PublicKeyResponse<'_> {
1461 fn serialize(&self, buf: &mut Vec<u8>) {
1462 self.__header.serialize(&mut *buf);
1463 self.rsa_key.serialize(buf);
1464 }
1465}
1466
1467define_header!(
1468 AuthSwitchRequestHeader,
1469 InvalidAuthSwithRequestHeader("Invalid auth switch request header"),
1470 0xFE
1471);
1472
1473#[derive(Debug, Clone, Eq, PartialEq)]
1478pub struct OldAuthSwitchRequest {
1479 __header: AuthSwitchRequestHeader,
1480}
1481
1482impl OldAuthSwitchRequest {
1483 pub fn new() -> Self {
1484 Self {
1485 __header: AuthSwitchRequestHeader::new(),
1486 }
1487 }
1488
1489 pub const fn auth_plugin(&self) -> AuthPlugin<'static> {
1490 AuthPlugin::MysqlOldPassword
1491 }
1492}
1493
1494impl Default for OldAuthSwitchRequest {
1495 fn default() -> Self {
1496 Self::new()
1497 }
1498}
1499
1500impl<'de> MyDeserialize<'de> for OldAuthSwitchRequest {
1501 const SIZE: Option<usize> = Some(1);
1502 type Ctx = ();
1503
1504 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1505 Ok(Self {
1506 __header: buf.parse(())?,
1507 })
1508 }
1509}
1510
1511impl MySerialize for OldAuthSwitchRequest {
1512 fn serialize(&self, buf: &mut Vec<u8>) {
1513 self.__header.serialize(&mut *buf);
1514 }
1515}
1516
1517#[derive(Debug, Clone, Eq, PartialEq)]
1522pub struct AuthSwitchRequest<'a> {
1523 __header: AuthSwitchRequestHeader,
1524 auth_plugin: RawBytes<'a, NullBytes>,
1525 plugin_data: RawBytes<'a, EofBytes>,
1526}
1527
1528impl<'a> AuthSwitchRequest<'a> {
1529 pub fn new(
1530 auth_plugin: impl Into<Cow<'a, [u8]>>,
1531 plugin_data: impl Into<Cow<'a, [u8]>>,
1532 ) -> Self {
1533 Self {
1534 __header: AuthSwitchRequestHeader::new(),
1535 auth_plugin: RawBytes::new(auth_plugin),
1536 plugin_data: RawBytes::new(plugin_data),
1537 }
1538 }
1539
1540 pub fn auth_plugin(&self) -> AuthPlugin<'_> {
1541 ParseBuf(self.auth_plugin.as_bytes())
1542 .parse(())
1543 .expect("infallible")
1544 }
1545
1546 pub fn plugin_data(&self) -> &[u8] {
1547 match self.plugin_data.as_bytes() {
1548 [head @ .., 0] => head,
1549 all => all,
1550 }
1551 }
1552
1553 pub fn into_owned(self) -> AuthSwitchRequest<'static> {
1554 AuthSwitchRequest {
1555 __header: self.__header,
1556 auth_plugin: self.auth_plugin.into_owned(),
1557 plugin_data: self.plugin_data.into_owned(),
1558 }
1559 }
1560}
1561
1562impl<'de> MyDeserialize<'de> for AuthSwitchRequest<'de> {
1563 const SIZE: Option<usize> = None;
1564 type Ctx = ();
1565
1566 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1567 Ok(Self {
1568 __header: buf.parse(())?,
1569 auth_plugin: buf.parse(())?,
1570 plugin_data: buf.parse(())?,
1571 })
1572 }
1573}
1574
1575impl MySerialize for AuthSwitchRequest<'_> {
1576 fn serialize(&self, buf: &mut Vec<u8>) {
1577 self.__header.serialize(&mut *buf);
1578 self.auth_plugin.serialize(&mut *buf);
1579 self.plugin_data.serialize(buf);
1580 }
1581}
1582
1583#[derive(Debug, Clone, Eq, PartialEq)]
1585pub struct HandshakePacket<'a> {
1586 protocol_version: RawInt<u8>,
1587 server_version: RawBytes<'a, NullBytes>,
1588 connection_id: RawInt<LeU32>,
1589 scramble_1: [u8; 8],
1590 __filler: Skip<1>,
1591 capabilities_1: Const<CapabilityFlags, LeU32LowerHalf>,
1593 default_collation: RawInt<u8>,
1594 status_flags: Const<StatusFlags, LeU16>,
1595 capabilities_2: Const<CapabilityFlags, LeU32UpperHalf>,
1597 auth_plugin_data_len: RawInt<u8>,
1598 __reserved: Skip<6>,
1599 mariadb_ext_capabilities: Const<MariadbCapabilities, LeU32>,
1601 scramble_2: Option<RawBytes<'a, BareBytes<{ (u8::MAX as usize) - 8 }>>>,
1602 auth_plugin_name: Option<RawBytes<'a, NullBytes>>,
1603}
1604
1605impl<'de> MyDeserialize<'de> for HandshakePacket<'de> {
1606 const SIZE: Option<usize> = None;
1607 type Ctx = ();
1608
1609 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1610 let protocol_version = buf.parse(())?;
1611 let server_version = buf.parse(())?;
1612
1613 let mut sbuf: ParseBuf<'_> = buf.parse(31)?;
1615 let connection_id = sbuf.parse_unchecked(())?;
1616 let scramble_1 = sbuf.parse_unchecked(())?;
1617 let __filler = sbuf.parse_unchecked(())?;
1618 let capabilities_1: RawConst<LeU32LowerHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1619 let default_collation = sbuf.parse_unchecked(())?;
1620 let status_flags = sbuf.parse_unchecked(())?;
1621 let capabilities_2: RawConst<LeU32UpperHalf, CapabilityFlags> = sbuf.parse_unchecked(())?;
1622 let auth_plugin_data_len: RawInt<u8> = sbuf.parse_unchecked(())?;
1623 let __reserved = sbuf.parse_unchecked(())?;
1624 let mariadb_capabiities: RawConst<LeU32, MariadbCapabilities> = sbuf.parse_unchecked(())?;
1627 let mut scramble_2 = None;
1628 if capabilities_1.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
1629 let len = max(13, auth_plugin_data_len.0 as i8 - 8) as usize;
1630 scramble_2 = buf.parse(len).map(Some)?;
1631 }
1632 let mut auth_plugin_name = None;
1633 if capabilities_2.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
1634 auth_plugin_name = match buf.eat_all() {
1635 [head @ .., 0] => Some(RawBytes::new(head)),
1636 all => Some(RawBytes::new(all)),
1638 }
1639 }
1640
1641 Ok(Self {
1642 protocol_version,
1643 server_version,
1644 connection_id,
1645 scramble_1,
1646 __filler,
1647 capabilities_1: Const::new(CapabilityFlags::from_bits_truncate(capabilities_1.0)),
1648 default_collation,
1649 status_flags,
1650 capabilities_2: Const::new(CapabilityFlags::from_bits_truncate(capabilities_2.0)),
1651 auth_plugin_data_len,
1652 __reserved,
1653 mariadb_ext_capabilities: Const::new(MariadbCapabilities::from_bits_truncate(
1654 mariadb_capabiities.0,
1655 )),
1656 scramble_2,
1657 auth_plugin_name,
1658 })
1659 }
1660}
1661
1662impl MySerialize for HandshakePacket<'_> {
1663 fn serialize(&self, buf: &mut Vec<u8>) {
1664 self.protocol_version.serialize(&mut *buf);
1665 self.server_version.serialize(&mut *buf);
1666 self.connection_id.serialize(&mut *buf);
1667 self.scramble_1.serialize(&mut *buf);
1668 buf.put_u8(0x00);
1669 self.capabilities_1.serialize(&mut *buf);
1670 self.default_collation.serialize(&mut *buf);
1671 self.status_flags.serialize(&mut *buf);
1672 self.capabilities_2.serialize(&mut *buf);
1673
1674 if self
1675 .capabilities_2
1676 .contains(CapabilityFlags::CLIENT_PLUGIN_AUTH)
1677 {
1678 buf.put_u8(
1679 self.scramble_2
1680 .as_ref()
1681 .map(|x| (x.len() + 8) as u8)
1682 .unwrap_or_default(),
1683 );
1684 } else {
1685 buf.put_u8(0);
1686 }
1687
1688 self.__reserved.serialize(&mut *buf);
1689 self.mariadb_ext_capabilities.serialize(&mut *buf);
1690
1691 if let Some(scramble_2) = &self.scramble_2 {
1694 scramble_2.serialize(&mut *buf);
1695 }
1696
1697 if let Some(client_plugin_auth) = &self.auth_plugin_name {
1700 client_plugin_auth.serialize(buf);
1701 }
1702 }
1703}
1704
1705impl<'a> HandshakePacket<'a> {
1706 #[allow(clippy::too_many_arguments)]
1707 pub fn new(
1708 protocol_version: u8,
1709 server_version: impl Into<Cow<'a, [u8]>>,
1710 connection_id: u32,
1711 scramble_1: [u8; 8],
1712 scramble_2: Option<impl Into<Cow<'a, [u8]>>>,
1713 capabilities: CapabilityFlags,
1714 default_collation: u8,
1715 status_flags: StatusFlags,
1716 auth_plugin_name: Option<impl Into<Cow<'a, [u8]>>>,
1717 ) -> Self {
1718 let (capabilities_1, capabilities_2) = (
1722 CapabilityFlags::from_bits_retain(capabilities.bits() & 0x0000_FFFF),
1723 CapabilityFlags::from_bits_retain(capabilities.bits() & 0xFFFF_0000),
1724 );
1725
1726 let scramble_2 = scramble_2.map(RawBytes::new);
1727
1728 HandshakePacket {
1729 protocol_version: RawInt::new(protocol_version),
1730 server_version: RawBytes::new(server_version),
1731 connection_id: RawInt::new(connection_id),
1732 scramble_1,
1733 __filler: Skip,
1734 capabilities_1: Const::new(capabilities_1),
1735 default_collation: RawInt::new(default_collation),
1736 status_flags: Const::new(status_flags),
1737 capabilities_2: Const::new(capabilities_2),
1738 auth_plugin_data_len: RawInt::new(
1739 scramble_2
1740 .as_ref()
1741 .map(|x| x.len() as u8)
1742 .unwrap_or_default(),
1743 ),
1744 __reserved: Skip,
1745 mariadb_ext_capabilities: Const::new(MariadbCapabilities::empty()),
1746 scramble_2,
1747 auth_plugin_name: auth_plugin_name.map(RawBytes::new),
1748 }
1749 }
1750
1751 pub fn with_mariadb_ext_capabilities(mut self, flags: MariadbCapabilities) -> Self {
1752 self.mariadb_ext_capabilities = Const::new(flags);
1753 self
1754 }
1755
1756 pub fn into_owned(self) -> HandshakePacket<'static> {
1757 HandshakePacket {
1758 protocol_version: self.protocol_version,
1759 server_version: self.server_version.into_owned(),
1760 connection_id: self.connection_id,
1761 scramble_1: self.scramble_1,
1762 __filler: self.__filler,
1763 capabilities_1: self.capabilities_1,
1764 default_collation: self.default_collation,
1765 status_flags: self.status_flags,
1766 capabilities_2: self.capabilities_2,
1767 auth_plugin_data_len: self.auth_plugin_data_len,
1768 __reserved: self.__reserved,
1769 mariadb_ext_capabilities: self.mariadb_ext_capabilities,
1770 scramble_2: self.scramble_2.map(|x| x.into_owned()),
1771 auth_plugin_name: self.auth_plugin_name.map(RawBytes::into_owned),
1772 }
1773 }
1774
1775 pub fn protocol_version(&self) -> u8 {
1777 self.protocol_version.0
1778 }
1779
1780 pub fn server_version_ref(&self) -> &[u8] {
1782 self.server_version.as_bytes()
1783 }
1784
1785 pub fn server_version_str(&self) -> Cow<'_, str> {
1788 self.server_version.as_str()
1789 }
1790
1791 pub fn server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1795 VERSION_RE
1796 .captures(self.server_version_ref())
1797 .map(|captures| {
1798 (
1800 btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1801 btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1802 btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1803 )
1804 })
1805 }
1806
1807 pub fn maria_db_server_version_parsed(&self) -> Option<(u16, u16, u16)> {
1809 MARIADB_VERSION_RE
1810 .captures(self.server_version_ref())
1811 .map(|captures| {
1812 (
1814 btoi::<u16>(captures.get(1).unwrap().as_bytes()).unwrap(),
1815 btoi::<u16>(captures.get(2).unwrap().as_bytes()).unwrap(),
1816 btoi::<u16>(captures.get(3).unwrap().as_bytes()).unwrap(),
1817 )
1818 })
1819 }
1820
1821 pub fn connection_id(&self) -> u32 {
1823 self.connection_id.0
1824 }
1825
1826 pub fn scramble_1_ref(&self) -> &[u8] {
1828 self.scramble_1.as_ref()
1829 }
1830
1831 pub fn scramble_2_ref(&self) -> Option<&[u8]> {
1835 self.scramble_2.as_ref().map(|x| x.as_bytes())
1836 }
1837
1838 pub fn nonce(&self) -> Vec<u8> {
1840 let mut out = Vec::from(self.scramble_1_ref());
1841 out.extend_from_slice(self.scramble_2_ref().unwrap_or(&[][..]));
1842
1843 out.resize(20, 0);
1846 out
1847 }
1848
1849 pub fn capabilities(&self) -> CapabilityFlags {
1851 self.capabilities_1.0 | self.capabilities_2.0
1852 }
1853
1854 pub fn mariadb_ext_capabilities(&self) -> MariadbCapabilities {
1856 self.mariadb_ext_capabilities.0
1857 }
1858 pub fn default_collation(&self) -> u8 {
1860 self.default_collation.0
1861 }
1862
1863 pub fn status_flags(&self) -> StatusFlags {
1865 self.status_flags.0
1866 }
1867
1868 pub fn auth_plugin_name_ref(&self) -> Option<&[u8]> {
1870 self.auth_plugin_name.as_ref().map(|x| x.as_bytes())
1871 }
1872
1873 pub fn auth_plugin_name_str(&self) -> Option<Cow<'_, str>> {
1876 self.auth_plugin_name.as_ref().map(|x| x.as_str())
1877 }
1878
1879 pub fn auth_plugin(&self) -> Option<AuthPlugin<'_>> {
1881 self.auth_plugin_name.as_ref().map(|x| match x.as_bytes() {
1882 [name @ .., 0] => ParseBuf(name).parse_unchecked(()).expect("infallible"),
1883 all => ParseBuf(all).parse_unchecked(()).expect("infallible"),
1884 })
1885 }
1886}
1887
1888define_header!(
1889 ComChangeUserHeader,
1890 InvalidComChangeUserHeader("Invalid COM_CHANGE_USER header"),
1891 0x11
1892);
1893
1894#[derive(Debug, Clone, PartialEq, Eq)]
1895pub struct ComChangeUser<'a> {
1896 __header: ComChangeUserHeader,
1897 user: RawBytes<'a, NullBytes>,
1898 auth_plugin_data: RawBytes<'a, U8Bytes>,
1900 database: RawBytes<'a, NullBytes>,
1901 more_data: Option<ComChangeUserMoreData<'a>>,
1902}
1903
1904impl<'a> ComChangeUser<'a> {
1905 pub fn new() -> Self {
1906 Self {
1907 __header: ComChangeUserHeader::new(),
1908 user: Default::default(),
1909 auth_plugin_data: Default::default(),
1910 database: Default::default(),
1911 more_data: None,
1912 }
1913 }
1914
1915 pub fn with_user(mut self, user: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1916 self.user = user.map(RawBytes::new).unwrap_or_default();
1917 self
1918 }
1919
1920 pub fn with_database(mut self, database: Option<impl Into<Cow<'a, [u8]>>>) -> Self {
1921 self.database = database.map(RawBytes::new).unwrap_or_default();
1922 self
1923 }
1924
1925 pub fn with_auth_plugin_data(
1926 mut self,
1927 auth_plugin_data: Option<impl Into<Cow<'a, [u8]>>>,
1928 ) -> Self {
1929 self.auth_plugin_data = auth_plugin_data.map(RawBytes::new).unwrap_or_default();
1930 self
1931 }
1932
1933 pub fn with_more_data(mut self, more_data: Option<ComChangeUserMoreData<'a>>) -> Self {
1934 self.more_data = more_data;
1935 self
1936 }
1937
1938 pub fn into_owned(self) -> ComChangeUser<'static> {
1939 ComChangeUser {
1940 __header: self.__header,
1941 user: self.user.into_owned(),
1942 auth_plugin_data: self.auth_plugin_data.into_owned(),
1943 database: self.database.into_owned(),
1944 more_data: self.more_data.map(|x| x.into_owned()),
1945 }
1946 }
1947}
1948
1949impl Default for ComChangeUser<'_> {
1950 fn default() -> Self {
1951 Self::new()
1952 }
1953}
1954
1955impl<'de> MyDeserialize<'de> for ComChangeUser<'de> {
1956 const SIZE: Option<usize> = None;
1957
1958 type Ctx = CapabilityFlags;
1959
1960 fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
1961 Ok(Self {
1962 __header: buf.parse(())?,
1963 user: buf.parse(())?,
1964 auth_plugin_data: buf.parse(())?,
1965 database: buf.parse(())?,
1966 more_data: if !buf.is_empty() {
1967 Some(buf.parse(flags)?)
1968 } else {
1969 None
1970 },
1971 })
1972 }
1973}
1974
1975impl MySerialize for ComChangeUser<'_> {
1976 fn serialize(&self, buf: &mut Vec<u8>) {
1977 self.__header.serialize(&mut *buf);
1978 self.user.serialize(&mut *buf);
1979 self.auth_plugin_data.serialize(&mut *buf);
1980 self.database.serialize(&mut *buf);
1981 if let Some(ref more_data) = self.more_data {
1982 more_data.serialize(&mut *buf);
1983 }
1984 }
1985}
1986
1987#[derive(Debug, Clone, PartialEq, Eq)]
1988pub struct ComChangeUserMoreData<'a> {
1989 character_set: RawInt<LeU16>,
1990 auth_plugin: Option<AuthPlugin<'a>>,
1991 connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
1992}
1993
1994impl<'a> ComChangeUserMoreData<'a> {
1995 pub fn new(character_set: u16) -> Self {
1996 Self {
1997 character_set: RawInt::new(character_set),
1998 auth_plugin: None,
1999 connect_attributes: None,
2000 }
2001 }
2002
2003 pub fn with_auth_plugin(mut self, auth_plugin: Option<AuthPlugin<'a>>) -> Self {
2004 self.auth_plugin = auth_plugin;
2005 self
2006 }
2007
2008 pub fn with_connect_attributes(
2009 mut self,
2010 connect_attributes: Option<HashMap<String, String>>,
2011 ) -> Self {
2012 self.connect_attributes = connect_attributes.map(|attrs| {
2013 attrs
2014 .into_iter()
2015 .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
2016 .collect()
2017 });
2018 self
2019 }
2020
2021 pub fn into_owned(self) -> ComChangeUserMoreData<'static> {
2022 ComChangeUserMoreData {
2023 character_set: self.character_set,
2024 auth_plugin: self.auth_plugin.map(|x| x.into_owned()),
2025 connect_attributes: self.connect_attributes.map(|x| {
2026 x.into_iter()
2027 .map(|(k, v)| (k.into_owned(), v.into_owned()))
2028 .collect()
2029 }),
2030 }
2031 }
2032}
2033
2034fn deserialize_connect_attrs<'de>(
2036 buf: &mut ParseBuf<'de>,
2037) -> io::Result<HashMap<RawBytes<'de, LenEnc>, RawBytes<'de, LenEnc>>> {
2038 let data_len = buf.parse::<RawInt<LenEnc>>(())?;
2039 let mut data: ParseBuf<'_> = buf.parse(data_len.0 as usize)?;
2040 let mut attrs = HashMap::new();
2041 while !data.is_empty() {
2042 let key = data.parse::<RawBytes<'_, LenEnc>>(())?;
2043 let value = data.parse::<RawBytes<'_, LenEnc>>(())?;
2044 attrs.insert(key, value);
2045 }
2046 Ok(attrs)
2047}
2048
2049fn serialize_connect_attrs<'a>(
2051 connect_attributes: &HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>,
2052 buf: &mut Vec<u8>,
2053) {
2054 let len = connect_attributes
2055 .iter()
2056 .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2057 .sum::<u64>();
2058 buf.put_lenenc_int(len);
2059
2060 for (name, value) in connect_attributes {
2061 name.serialize(&mut *buf);
2062 value.serialize(&mut *buf);
2063 }
2064}
2065
2066impl<'de> MyDeserialize<'de> for ComChangeUserMoreData<'de> {
2067 const SIZE: Option<usize> = None;
2068 type Ctx = CapabilityFlags;
2069
2070 fn deserialize(flags: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2071 let character_set = buf.parse(())?;
2073 let mut auth_plugin = None;
2074 let mut connect_attributes = None;
2075
2076 if flags.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
2077 match buf.parse::<RawBytes<'_, NullBytes>>(())?.0 {
2079 Cow::Borrowed(bytes) => {
2080 let mut auth_plugin_buf = ParseBuf(bytes);
2081 auth_plugin = Some(auth_plugin_buf.parse(())?);
2082 }
2083 _ => unreachable!(),
2084 }
2085 };
2086
2087 if flags.contains(CapabilityFlags::CLIENT_CONNECT_ATTRS) {
2088 connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2089 };
2090
2091 Ok(Self {
2092 character_set,
2093 auth_plugin,
2094 connect_attributes,
2095 })
2096 }
2097}
2098
2099impl MySerialize for ComChangeUserMoreData<'_> {
2100 fn serialize(&self, buf: &mut Vec<u8>) {
2101 self.character_set.serialize(&mut *buf);
2102 if let Some(ref auth_plugin) = self.auth_plugin {
2103 auth_plugin.serialize(&mut *buf);
2104 }
2105 if let Some(ref connect_attributes) = self.connect_attributes {
2106 serialize_connect_attrs(connect_attributes, buf);
2107 } else {
2108 serialize_connect_attrs(&Default::default(), buf);
2111 }
2112 }
2113}
2114
2115type ScrambleBuf<'a> =
2117 Either<RawBytes<'a, LenEnc>, Either<RawBytes<'a, U8Bytes>, RawBytes<'a, NullBytes>>>;
2118
2119#[derive(Debug, Clone, PartialEq, Eq)]
2120pub struct HandshakeResponse<'a> {
2121 capabilities: Const<CapabilityFlags, LeU32>,
2122 max_packet_size: RawInt<LeU32>,
2123 collation: RawInt<u8>,
2124 scramble_buf: ScrambleBuf<'a>,
2125 user: RawBytes<'a, NullBytes>,
2126 db_name: Option<RawBytes<'a, NullBytes>>,
2127 auth_plugin: Option<AuthPlugin<'a>>,
2128 connect_attributes: Option<HashMap<RawBytes<'a, LenEnc>, RawBytes<'a, LenEnc>>>,
2129 mariadb_ext_capabilities: Const<MariadbCapabilities, LeU32>,
2130}
2131
2132impl<'a> HandshakeResponse<'a> {
2133 #[allow(clippy::too_many_arguments)]
2134 pub fn new(
2135 scramble_buf: Option<impl Into<Cow<'a, [u8]>>>,
2136 server_version: (u16, u16, u16),
2137 user: Option<impl Into<Cow<'a, [u8]>>>,
2138 db_name: Option<impl Into<Cow<'a, [u8]>>>,
2139 auth_plugin: Option<AuthPlugin<'a>>,
2140 mut capabilities: CapabilityFlags,
2141 connect_attributes: Option<HashMap<String, String>>,
2142 max_packet_size: u32,
2143 ) -> Self {
2144 let scramble_buf =
2145 if capabilities.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
2146 Either::Left(RawBytes::new(
2147 scramble_buf.map(Into::into).unwrap_or_default(),
2148 ))
2149 } else if capabilities.contains(CapabilityFlags::CLIENT_SECURE_CONNECTION) {
2150 Either::Right(Either::Left(RawBytes::new(
2151 scramble_buf.map(Into::into).unwrap_or_default(),
2152 )))
2153 } else {
2154 Either::Right(Either::Right(RawBytes::new(
2155 scramble_buf.map(Into::into).unwrap_or_default(),
2156 )))
2157 };
2158
2159 if db_name.is_some() {
2160 capabilities.insert(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2161 } else {
2162 capabilities.remove(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
2163 }
2164
2165 if auth_plugin.is_some() {
2166 capabilities.insert(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2167 } else {
2168 capabilities.remove(CapabilityFlags::CLIENT_PLUGIN_AUTH);
2169 }
2170
2171 if connect_attributes.is_some() {
2172 capabilities.insert(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2173 } else {
2174 capabilities.remove(CapabilityFlags::CLIENT_CONNECT_ATTRS);
2175 }
2176
2177 Self {
2178 scramble_buf,
2179 collation: if server_version >= (5, 5, 3) {
2180 RawInt::new(CollationId::UTF8MB4_GENERAL_CI as u8)
2181 } else {
2182 RawInt::new(CollationId::UTF8MB3_GENERAL_CI as u8)
2183 },
2184 user: user.map(RawBytes::new).unwrap_or_default(),
2185 db_name: db_name.map(RawBytes::new),
2186 auth_plugin,
2187 capabilities: Const::new(capabilities),
2188 connect_attributes: connect_attributes.map(|attrs| {
2189 attrs
2190 .into_iter()
2191 .map(|(k, v)| (RawBytes::new(k.into_bytes()), RawBytes::new(v.into_bytes())))
2192 .collect()
2193 }),
2194 max_packet_size: RawInt::new(max_packet_size),
2195 mariadb_ext_capabilities: Const::new(MariadbCapabilities::empty()),
2196 }
2197 }
2198
2199 pub fn with_mariadb_ext_capabilities(
2200 mut self,
2201 mariadb_ext_capabilities: MariadbCapabilities,
2202 ) -> Self {
2203 self.mariadb_ext_capabilities = Const::new(mariadb_ext_capabilities);
2204 self
2205 }
2206
2207 pub fn capabilities(&self) -> CapabilityFlags {
2208 self.capabilities.0
2209 }
2210
2211 pub fn mariadb_ext_capabilities(&self) -> MariadbCapabilities {
2212 self.mariadb_ext_capabilities.0
2213 }
2214
2215 pub fn collation(&self) -> u8 {
2216 self.collation.0
2217 }
2218
2219 pub fn scramble_buf(&self) -> &[u8] {
2220 match &self.scramble_buf {
2221 Either::Left(x) => x.as_bytes(),
2222 Either::Right(x) => match x {
2223 Either::Left(x) => x.as_bytes(),
2224 Either::Right(x) => x.as_bytes(),
2225 },
2226 }
2227 }
2228
2229 pub fn user(&self) -> &[u8] {
2230 self.user.as_bytes()
2231 }
2232
2233 pub fn db_name(&self) -> Option<&[u8]> {
2234 self.db_name.as_ref().map(|x| x.as_bytes())
2235 }
2236
2237 pub fn auth_plugin(&self) -> Option<&AuthPlugin<'a>> {
2238 self.auth_plugin.as_ref()
2239 }
2240
2241 #[must_use = "entails computation"]
2242 pub fn connect_attributes(&self) -> Option<HashMap<String, String>> {
2243 self.connect_attributes.as_ref().map(|attrs| {
2244 attrs
2245 .iter()
2246 .map(|(k, v)| (k.as_str().into_owned(), v.as_str().into_owned()))
2247 .collect()
2248 })
2249 }
2250}
2251
2252impl<'de> MyDeserialize<'de> for HandshakeResponse<'de> {
2253 const SIZE: Option<usize> = None;
2254 type Ctx = ();
2255
2256 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2257 let mut sbuf: ParseBuf<'_> = buf.parse(4 + 4 + 1 + 23)?;
2258 let client_flags: RawConst<LeU32, CapabilityFlags> = sbuf.parse_unchecked(())?;
2259 let max_packet_size: RawInt<LeU32> = sbuf.parse_unchecked(())?;
2260 let collation = sbuf.parse_unchecked(())?;
2261 sbuf.parse_unchecked::<Skip<19>>(())?;
2262 let mariadb_flags: RawConst<LeU32, MariadbCapabilities> = sbuf.parse_unchecked(())?;
2263
2264 let user = buf.parse(())?;
2265 let scramble_buf =
2266 if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA.bits() > 0 {
2267 Either::Left(buf.parse(())?)
2268 } else if client_flags.0 & CapabilityFlags::CLIENT_SECURE_CONNECTION.bits() > 0 {
2269 Either::Right(Either::Left(buf.parse(())?))
2270 } else {
2271 Either::Right(Either::Right(buf.parse(())?))
2272 };
2273
2274 let mut db_name = None;
2275 if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_WITH_DB.bits() > 0 {
2276 db_name = buf.parse(()).map(Some)?;
2277 }
2278
2279 let mut auth_plugin = None;
2280 if client_flags.0 & CapabilityFlags::CLIENT_PLUGIN_AUTH.bits() > 0 {
2281 let auth_plugin_name = buf.eat_null_str();
2282 auth_plugin = Some(AuthPlugin::from_bytes(auth_plugin_name));
2283 }
2284
2285 let mut connect_attributes = None;
2286 if client_flags.0 & CapabilityFlags::CLIENT_CONNECT_ATTRS.bits() > 0 {
2287 connect_attributes = Some(deserialize_connect_attrs(&mut *buf)?);
2288 }
2289
2290 Ok(Self {
2291 capabilities: Const::new(CapabilityFlags::from_bits_truncate(client_flags.0)),
2292 max_packet_size,
2293 collation,
2294 scramble_buf,
2295 user,
2296 db_name,
2297 auth_plugin,
2298 connect_attributes,
2299 mariadb_ext_capabilities: Const::new(MariadbCapabilities::from_bits_truncate(
2300 mariadb_flags.0,
2301 )),
2302 })
2303 }
2304}
2305
2306impl MySerialize for HandshakeResponse<'_> {
2307 fn serialize(&self, buf: &mut Vec<u8>) {
2308 self.capabilities.serialize(&mut *buf);
2309 self.max_packet_size.serialize(&mut *buf);
2310 self.collation.serialize(&mut *buf);
2311 buf.put_slice(&[0; 19]);
2312 self.mariadb_ext_capabilities.serialize(&mut *buf);
2313 self.user.serialize(&mut *buf);
2314 self.scramble_buf.serialize(&mut *buf);
2315
2316 if let Some(db_name) = &self.db_name {
2317 db_name.serialize(&mut *buf);
2318 }
2319
2320 if let Some(auth_plugin) = &self.auth_plugin {
2321 auth_plugin.serialize(&mut *buf);
2322 }
2323
2324 if let Some(attrs) = &self.connect_attributes {
2325 let len = attrs
2326 .iter()
2327 .map(|(k, v)| lenenc_str_len(k.as_bytes()) + lenenc_str_len(v.as_bytes()))
2328 .sum::<u64>();
2329 buf.put_lenenc_int(len);
2330
2331 for (name, value) in attrs {
2332 name.serialize(&mut *buf);
2333 value.serialize(&mut *buf);
2334 }
2335 }
2336 }
2337}
2338
2339#[derive(Debug, Clone, Eq, PartialEq)]
2340pub struct SslRequest {
2341 capabilities: Const<CapabilityFlags, LeU32>,
2342 max_packet_size: RawInt<LeU32>,
2343 character_set: RawInt<u8>,
2344 __skip: Skip<23>,
2345}
2346
2347impl SslRequest {
2348 pub fn new(capabilities: CapabilityFlags, max_packet_size: u32, character_set: u8) -> Self {
2349 Self {
2350 capabilities: Const::new(capabilities),
2351 max_packet_size: RawInt::new(max_packet_size),
2352 character_set: RawInt::new(character_set),
2353 __skip: Skip,
2354 }
2355 }
2356
2357 pub fn capabilities(&self) -> CapabilityFlags {
2358 self.capabilities.0
2359 }
2360
2361 pub fn max_packet_size(&self) -> u32 {
2362 self.max_packet_size.0
2363 }
2364
2365 pub fn character_set(&self) -> u8 {
2366 self.character_set.0
2367 }
2368}
2369
2370impl<'de> MyDeserialize<'de> for SslRequest {
2371 const SIZE: Option<usize> = Some(4 + 4 + 1 + 23);
2372 type Ctx = ();
2373
2374 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2375 let mut buf: ParseBuf<'_> = buf.parse(Self::SIZE.unwrap())?;
2376 let raw_capabilities = buf.parse_unchecked::<RawConst<LeU32, CapabilityFlags>>(())?;
2377 Ok(Self {
2378 capabilities: Const::new(CapabilityFlags::from_bits_truncate(raw_capabilities.0)),
2379 max_packet_size: buf.parse_unchecked(())?,
2380 character_set: buf.parse_unchecked(())?,
2381 __skip: buf.parse_unchecked(())?,
2382 })
2383 }
2384}
2385
2386impl MySerialize for SslRequest {
2387 fn serialize(&self, buf: &mut Vec<u8>) {
2388 self.capabilities.serialize(&mut *buf);
2389 self.max_packet_size.serialize(&mut *buf);
2390 self.character_set.serialize(&mut *buf);
2391 self.__skip.serialize(&mut *buf);
2392 }
2393}
2394
2395#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
2396#[error("Invalid statement packet status")]
2397pub struct InvalidStmtPacketStatus;
2398
2399#[derive(Debug, Clone, Eq, PartialEq, Hash)]
2401pub struct StmtPacket {
2402 status: ConstU8<InvalidStmtPacketStatus, 0x00>,
2403 statement_id: RawInt<LeU32>,
2404 num_columns: RawInt<LeU16>,
2405 num_params: RawInt<LeU16>,
2406 __skip: Skip<1>,
2407 warning_count: RawInt<LeU16>,
2408}
2409
2410impl<'de> MyDeserialize<'de> for StmtPacket {
2411 const SIZE: Option<usize> = Some(12);
2412 type Ctx = ();
2413
2414 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2415 let mut buf: ParseBuf<'_> = buf.parse(Self::SIZE.unwrap())?;
2416 Ok(StmtPacket {
2417 status: buf.parse_unchecked(())?,
2418 statement_id: buf.parse_unchecked(())?,
2419 num_columns: buf.parse_unchecked(())?,
2420 num_params: buf.parse_unchecked(())?,
2421 __skip: buf.parse_unchecked(())?,
2422 warning_count: buf.parse_unchecked(())?,
2423 })
2424 }
2425}
2426
2427impl MySerialize for StmtPacket {
2428 fn serialize(&self, buf: &mut Vec<u8>) {
2429 self.status.serialize(&mut *buf);
2430 self.statement_id.serialize(&mut *buf);
2431 self.num_columns.serialize(&mut *buf);
2432 self.num_params.serialize(&mut *buf);
2433 self.__skip.serialize(&mut *buf);
2434 self.warning_count.serialize(&mut *buf);
2435 }
2436}
2437
2438impl StmtPacket {
2439 pub fn statement_id(&self) -> u32 {
2441 *self.statement_id
2442 }
2443
2444 pub fn num_columns(&self) -> u16 {
2446 *self.num_columns
2447 }
2448
2449 pub fn num_params(&self) -> u16 {
2451 *self.num_params
2452 }
2453
2454 pub fn warning_count(&self) -> u16 {
2456 *self.warning_count
2457 }
2458}
2459
2460#[derive(Debug, Clone, Eq, PartialEq)]
2464pub struct NullBitmap<T, U: AsRef<[u8]> = Vec<u8>>(U, PhantomData<T>);
2465
2466impl<'de, T: SerializationSide> MyDeserialize<'de> for NullBitmap<T, Cow<'de, [u8]>> {
2467 const SIZE: Option<usize> = None;
2468 type Ctx = usize;
2469
2470 fn deserialize(num_columns: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2471 let bitmap_len = Self::bitmap_len(num_columns);
2472 let bytes = buf.checked_eat(bitmap_len).ok_or_else(unexpected_buf_eof)?;
2473 Ok(Self::from_bytes(Cow::Borrowed(bytes)))
2474 }
2475}
2476
2477impl<T: SerializationSide> NullBitmap<T, Vec<u8>> {
2478 pub fn new(num_columns: usize) -> Self {
2480 Self::from_bytes(vec![0; Self::bitmap_len(num_columns)])
2481 }
2482
2483 pub fn read(input: &mut &[u8], num_columns: usize) -> Self {
2485 let bitmap_len = Self::bitmap_len(num_columns);
2486 assert!(input.len() >= bitmap_len);
2487
2488 let bitmap = Self::from_bytes(input[..bitmap_len].to_vec());
2489 *input = &input[bitmap_len..];
2490
2491 bitmap
2492 }
2493}
2494
2495impl<T: SerializationSide, U: AsRef<[u8]>> NullBitmap<T, U> {
2496 pub fn bitmap_len(num_columns: usize) -> usize {
2497 (num_columns + 7 + T::BIT_OFFSET) / 8
2498 }
2499
2500 fn byte_and_bit(&self, column_index: usize) -> (usize, u8) {
2501 let offset = column_index + T::BIT_OFFSET;
2502 let byte = offset / 8;
2503 let bit = 1 << (offset % 8) as u8;
2504
2505 assert!(byte < self.0.as_ref().len());
2506
2507 (byte, bit)
2508 }
2509
2510 pub fn from_bytes(bytes: U) -> Self {
2512 Self(bytes, PhantomData)
2513 }
2514
2515 pub fn is_null(&self, column_index: usize) -> bool {
2517 let (byte, bit) = self.byte_and_bit(column_index);
2518 self.0.as_ref()[byte] & bit > 0
2519 }
2520}
2521
2522impl<T: SerializationSide, U: AsRef<[u8]> + AsMut<[u8]>> NullBitmap<T, U> {
2523 pub fn set(&mut self, column_index: usize, is_null: bool) {
2525 let (byte, bit) = self.byte_and_bit(column_index);
2526 if is_null {
2527 self.0.as_mut()[byte] |= bit
2528 } else {
2529 self.0.as_mut()[byte] &= !bit
2530 }
2531 }
2532}
2533
2534impl<T, U: AsRef<[u8]>> AsRef<[u8]> for NullBitmap<T, U> {
2535 fn as_ref(&self) -> &[u8] {
2536 self.0.as_ref()
2537 }
2538}
2539
2540#[derive(Debug, Clone, PartialEq)]
2541pub struct ComStmtExecuteRequestBuilder {
2542 pub stmt_id: u32,
2543}
2544
2545impl ComStmtExecuteRequestBuilder {
2546 pub const NULL_BITMAP_OFFSET: usize = 10;
2547
2548 pub fn new(stmt_id: u32) -> Self {
2549 Self { stmt_id }
2550 }
2551}
2552
2553impl ComStmtExecuteRequestBuilder {
2554 pub fn build(self, params: &[Value]) -> (ComStmtExecuteRequest<'_>, bool) {
2555 let bitmap_len = NullBitmap::<ClientSide>::bitmap_len(params.len());
2556
2557 let mut bitmap_bytes = vec![0; bitmap_len];
2558 let mut bitmap = NullBitmap::<ClientSide, _>::from_bytes(&mut bitmap_bytes);
2559 let params = params.iter().collect::<Vec<_>>();
2560
2561 let meta_len = params.len() * 2;
2562
2563 let mut data_len = 0;
2564 for (i, param) in params.iter().enumerate() {
2565 match param.bin_len() as usize {
2566 0 => bitmap.set(i, true),
2567 x => data_len += x,
2568 }
2569 }
2570
2571 let total_len = 10 + bitmap_len + 1 + meta_len + data_len;
2572
2573 let as_long_data = total_len > MAX_PAYLOAD_LEN;
2574
2575 (
2576 ComStmtExecuteRequest {
2577 com_stmt_execute: ConstU8::new(),
2578 stmt_id: RawInt::new(self.stmt_id),
2579 flags: Const::new(CursorType::CURSOR_TYPE_NO_CURSOR),
2580 iteration_count: ConstU32::new(),
2581 params_flags: Const::new(StmtExecuteParamsFlags::NEW_PARAMS_BOUND),
2582 bitmap: RawBytes::new(bitmap_bytes),
2583 params,
2584 as_long_data,
2585 },
2586 as_long_data,
2587 )
2588 }
2589}
2590
2591define_header!(
2592 ComStmtExecuteHeader,
2593 COM_STMT_EXECUTE,
2594 InvalidComStmtExecuteHeader
2595);
2596
2597define_const!(
2598 ConstU32,
2599 IterationCount,
2600 InvalidIterationCount("Invalid iteration count for COM_STMT_EXECUTE"),
2601 1
2602);
2603
2604#[derive(Debug, Clone, PartialEq)]
2605pub struct ComStmtExecuteRequest<'a> {
2606 com_stmt_execute: ComStmtExecuteHeader,
2607 stmt_id: RawInt<LeU32>,
2608 flags: Const<CursorType, u8>,
2609 iteration_count: IterationCount,
2610 bitmap: RawBytes<'a, BareBytes<8192>>,
2612 params_flags: Const<StmtExecuteParamsFlags, u8>,
2613 params: Vec<&'a Value>,
2614 as_long_data: bool,
2615}
2616
2617impl<'a> ComStmtExecuteRequest<'a> {
2618 pub fn stmt_id(&self) -> u32 {
2619 self.stmt_id.0
2620 }
2621
2622 pub fn flags(&self) -> CursorType {
2623 self.flags.0
2624 }
2625
2626 pub fn bitmap(&self) -> &[u8] {
2627 self.bitmap.as_bytes()
2628 }
2629
2630 pub fn params_flags(&self) -> StmtExecuteParamsFlags {
2631 self.params_flags.0
2632 }
2633
2634 pub fn params(&self) -> &[&'a Value] {
2635 self.params.as_ref()
2636 }
2637
2638 pub fn as_long_data(&self) -> bool {
2639 self.as_long_data
2640 }
2641}
2642
2643impl MySerialize for ComStmtExecuteRequest<'_> {
2644 fn serialize(&self, buf: &mut Vec<u8>) {
2645 self.com_stmt_execute.serialize(&mut *buf);
2646 self.stmt_id.serialize(&mut *buf);
2647 self.flags.serialize(&mut *buf);
2648 self.iteration_count.serialize(&mut *buf);
2649
2650 if !self.params.is_empty() {
2651 self.bitmap.serialize(&mut *buf);
2652 self.params_flags.serialize(&mut *buf);
2653 }
2654
2655 for param in &self.params {
2656 let (column_type, flags) = match param {
2657 Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()),
2658 Value::Bytes(_) => (
2659 ColumnType::MYSQL_TYPE_VAR_STRING,
2660 StmtExecuteParamFlags::empty(),
2661 ),
2662 Value::Int(_) => (
2663 ColumnType::MYSQL_TYPE_LONGLONG,
2664 StmtExecuteParamFlags::empty(),
2665 ),
2666 Value::UInt(_) => (
2667 ColumnType::MYSQL_TYPE_LONGLONG,
2668 StmtExecuteParamFlags::UNSIGNED,
2669 ),
2670 Value::Float(_) => (ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty()),
2671 Value::Double(_) => (
2672 ColumnType::MYSQL_TYPE_DOUBLE,
2673 StmtExecuteParamFlags::empty(),
2674 ),
2675 Value::Date(..) => (
2676 ColumnType::MYSQL_TYPE_DATETIME,
2677 StmtExecuteParamFlags::empty(),
2678 ),
2679 Value::Time(..) => (ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty()),
2680 };
2681
2682 buf.put_slice(&[column_type as u8, flags.bits()]);
2683 }
2684
2685 for param in &self.params {
2686 match **param {
2687 Value::Int(_)
2688 | Value::UInt(_)
2689 | Value::Float(_)
2690 | Value::Double(_)
2691 | Value::Date(..)
2692 | Value::Time(..) => {
2693 param.serialize(buf);
2694 }
2695 Value::Bytes(_) if !self.as_long_data => {
2696 param.serialize(buf);
2697 }
2698 Value::Bytes(_) | Value::NULL => {}
2699 }
2700 }
2701 }
2702}
2703
2704define_header!(
2705 ComStmtSendLongDataHeader,
2706 COM_STMT_SEND_LONG_DATA,
2707 InvalidComStmtSendLongDataHeader
2708);
2709
2710#[derive(Debug, Clone, Eq, PartialEq)]
2711pub struct ComStmtSendLongData<'a> {
2712 __header: ComStmtSendLongDataHeader,
2713 stmt_id: RawInt<LeU32>,
2714 param_index: RawInt<LeU16>,
2715 data: RawBytes<'a, EofBytes>,
2716}
2717
2718impl<'a> ComStmtSendLongData<'a> {
2719 pub fn new(stmt_id: u32, param_index: u16, data: impl Into<Cow<'a, [u8]>>) -> Self {
2720 Self {
2721 __header: ComStmtSendLongDataHeader::new(),
2722 stmt_id: RawInt::new(stmt_id),
2723 param_index: RawInt::new(param_index),
2724 data: RawBytes::new(data),
2725 }
2726 }
2727
2728 pub fn into_owned(self) -> ComStmtSendLongData<'static> {
2729 ComStmtSendLongData {
2730 __header: self.__header,
2731 stmt_id: self.stmt_id,
2732 param_index: self.param_index,
2733 data: self.data.into_owned(),
2734 }
2735 }
2736}
2737
2738impl MySerialize for ComStmtSendLongData<'_> {
2739 fn serialize(&self, buf: &mut Vec<u8>) {
2740 self.__header.serialize(&mut *buf);
2741 self.stmt_id.serialize(&mut *buf);
2742 self.param_index.serialize(&mut *buf);
2743 self.data.serialize(&mut *buf);
2744 }
2745}
2746
2747#[derive(Debug, Clone, Copy, Eq, PartialEq)]
2748pub struct ComStmtClose {
2749 pub stmt_id: u32,
2750}
2751
2752impl ComStmtClose {
2753 pub fn new(stmt_id: u32) -> Self {
2754 Self { stmt_id }
2755 }
2756}
2757
2758impl MySerialize for ComStmtClose {
2759 fn serialize(&self, buf: &mut Vec<u8>) {
2760 buf.put_u8(Command::COM_STMT_CLOSE as u8);
2761 buf.put_u32_le(self.stmt_id);
2762 }
2763}
2764
2765define_header!(
2766 ComRegisterSlaveHeader,
2767 COM_REGISTER_SLAVE,
2768 InvalidComRegisterSlaveHeader
2769);
2770
2771#[derive(Debug, Clone, Eq, PartialEq, Hash)]
2774pub struct ComRegisterSlave<'a> {
2775 header: ComRegisterSlaveHeader,
2776 server_id: RawInt<LeU32>,
2778 hostname: RawBytes<'a, U8Bytes>,
2781 user: RawBytes<'a, U8Bytes>,
2788 password: RawBytes<'a, U8Bytes>,
2795 port: RawInt<LeU16>,
2802 replication_rank: RawInt<LeU32>,
2804 master_id: RawInt<LeU32>,
2807}
2808
2809impl<'a> ComRegisterSlave<'a> {
2810 pub fn new(server_id: u32) -> Self {
2812 Self {
2813 header: Default::default(),
2814 server_id: RawInt::new(server_id),
2815 hostname: Default::default(),
2816 user: Default::default(),
2817 password: Default::default(),
2818 port: Default::default(),
2819 replication_rank: Default::default(),
2820 master_id: Default::default(),
2821 }
2822 }
2823
2824 pub fn with_hostname(mut self, hostname: impl Into<Cow<'a, [u8]>>) -> Self {
2826 self.hostname = RawBytes::new(hostname);
2827 self
2828 }
2829
2830 pub fn with_user(mut self, user: impl Into<Cow<'a, [u8]>>) -> Self {
2832 self.user = RawBytes::new(user);
2833 self
2834 }
2835
2836 pub fn with_password(mut self, password: impl Into<Cow<'a, [u8]>>) -> Self {
2838 self.password = RawBytes::new(password);
2839 self
2840 }
2841
2842 pub fn with_port(mut self, port: u16) -> Self {
2844 self.port = RawInt::new(port);
2845 self
2846 }
2847
2848 pub fn with_replication_rank(mut self, replication_rank: u32) -> Self {
2850 self.replication_rank = RawInt::new(replication_rank);
2851 self
2852 }
2853
2854 pub fn with_master_id(mut self, master_id: u32) -> Self {
2856 self.master_id = RawInt::new(master_id);
2857 self
2858 }
2859
2860 pub fn server_id(&self) -> u32 {
2862 self.server_id.0
2863 }
2864
2865 pub fn hostname_raw(&self) -> &[u8] {
2867 self.hostname.as_bytes()
2868 }
2869
2870 pub fn hostname(&'a self) -> Cow<'a, str> {
2872 self.hostname.as_str()
2873 }
2874
2875 pub fn user_raw(&self) -> &[u8] {
2877 self.user.as_bytes()
2878 }
2879
2880 pub fn user(&'a self) -> Cow<'a, str> {
2882 self.user.as_str()
2883 }
2884
2885 pub fn password_raw(&self) -> &[u8] {
2887 self.password.as_bytes()
2888 }
2889
2890 pub fn password(&'a self) -> Cow<'a, str> {
2892 self.password.as_str()
2893 }
2894
2895 pub fn port(&self) -> u16 {
2897 self.port.0
2898 }
2899
2900 pub fn replication_rank(&self) -> u32 {
2902 self.replication_rank.0
2903 }
2904
2905 pub fn master_id(&self) -> u32 {
2907 self.master_id.0
2908 }
2909}
2910
2911impl MySerialize for ComRegisterSlave<'_> {
2912 fn serialize(&self, buf: &mut Vec<u8>) {
2913 self.header.serialize(&mut *buf);
2914 self.server_id.serialize(&mut *buf);
2915 self.hostname.serialize(&mut *buf);
2916 self.user.serialize(&mut *buf);
2917 self.password.serialize(&mut *buf);
2918 self.port.serialize(&mut *buf);
2919 self.replication_rank.serialize(&mut *buf);
2920 self.master_id.serialize(&mut *buf);
2921 }
2922}
2923
2924impl<'de> MyDeserialize<'de> for ComRegisterSlave<'de> {
2925 const SIZE: Option<usize> = None;
2926 type Ctx = ();
2927
2928 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
2929 let mut sbuf: ParseBuf<'_> = buf.parse(5)?;
2930 let header = sbuf.parse_unchecked(())?;
2931 let server_id = sbuf.parse_unchecked(())?;
2932
2933 let hostname = buf.parse(())?;
2934 let user = buf.parse(())?;
2935 let password = buf.parse(())?;
2936
2937 let mut sbuf: ParseBuf<'_> = buf.parse(10)?;
2938 let port = sbuf.parse_unchecked(())?;
2939 let replication_rank = sbuf.parse_unchecked(())?;
2940 let master_id = sbuf.parse_unchecked(())?;
2941
2942 Ok(Self {
2943 header,
2944 server_id,
2945 hostname,
2946 user,
2947 password,
2948 port,
2949 replication_rank,
2950 master_id,
2951 })
2952 }
2953}
2954
2955define_header!(
2956 ComTableDumpHeader,
2957 COM_TABLE_DUMP,
2958 InvalidComTableDumpHeader
2959);
2960
2961#[derive(Debug, Clone, Eq, PartialEq, Hash)]
2963pub struct ComTableDump<'a> {
2964 header: ComTableDumpHeader,
2965 database: RawBytes<'a, U8Bytes>,
2971 table: RawBytes<'a, U8Bytes>,
2977}
2978
2979impl<'a> ComTableDump<'a> {
2980 pub fn new(database: impl Into<Cow<'a, [u8]>>, table: impl Into<Cow<'a, [u8]>>) -> Self {
2982 Self {
2983 header: Default::default(),
2984 database: RawBytes::new(database),
2985 table: RawBytes::new(table),
2986 }
2987 }
2988
2989 pub fn database_raw(&self) -> &[u8] {
2991 self.database.as_bytes()
2992 }
2993
2994 pub fn database(&self) -> Cow<'_, str> {
2996 self.database.as_str()
2997 }
2998
2999 pub fn table_raw(&self) -> &[u8] {
3001 self.table.as_bytes()
3002 }
3003
3004 pub fn table(&self) -> Cow<'_, str> {
3006 self.table.as_str()
3007 }
3008}
3009
3010impl MySerialize for ComTableDump<'_> {
3011 fn serialize(&self, buf: &mut Vec<u8>) {
3012 self.header.serialize(&mut *buf);
3013 self.database.serialize(&mut *buf);
3014 self.table.serialize(&mut *buf);
3015 }
3016}
3017
3018impl<'de> MyDeserialize<'de> for ComTableDump<'de> {
3019 const SIZE: Option<usize> = None;
3020 type Ctx = ();
3021
3022 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3023 Ok(Self {
3024 header: buf.parse(())?,
3025 database: buf.parse(())?,
3026 table: buf.parse(())?,
3027 })
3028 }
3029}
3030
// Flags of the `COM_BINLOG_DUMP` and `COM_BINLOG_DUMP_GTID` commands, stored on the wire as an
// LE `u16` (both packets hold them as `Const<BinlogDumpFlags, LeU16>`).
my_bitflags! {
    BinlogDumpFlags,
    #[error("Unknown flags in the raw value of BinlogDumpFlags (raw={0:b})")]
    UnknownBinlogDumpFlags,
    u16,

    #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
    pub struct BinlogDumpFlags: u16 {
        // NOTE(review): in MySQL this asks the server to reply with EOF instead of blocking
        // when no more events are available — confirm.
        const BINLOG_DUMP_NON_BLOCK = 0x01;
        // NOTE(review): semantics not shown in this module; presumably MySQL's
        // "through position" flag — confirm.
        const BINLOG_THROUGH_POSITION = 0x02;
        // `ComBinlogDumpGtid` sets this whenever its SID block is non-empty and clears it
        // when the block is empty.
        const BINLOG_THROUGH_GTID = 0x04;
    }
}
3046
3047define_header!(
3048 ComBinlogDumpHeader,
3049 COM_BINLOG_DUMP,
3050 InvalidComBinlogDumpHeader
3051);
3052
3053#[derive(Clone, Debug, Eq, PartialEq, Hash)]
3055pub struct ComBinlogDump<'a> {
3056 header: ComBinlogDumpHeader,
3057 pos: RawInt<LeU32>,
3059 flags: Const<BinlogDumpFlags, LeU16>,
3063 server_id: RawInt<LeU32>,
3065 filename: RawBytes<'a, EofBytes>,
3070}
3071
3072impl<'a> ComBinlogDump<'a> {
3073 pub fn new(server_id: u32) -> Self {
3075 Self {
3076 header: Default::default(),
3077 pos: Default::default(),
3078 flags: Default::default(),
3079 server_id: RawInt::new(server_id),
3080 filename: Default::default(),
3081 }
3082 }
3083
3084 pub fn with_pos(mut self, pos: u32) -> Self {
3086 self.pos = RawInt::new(pos);
3087 self
3088 }
3089
3090 pub fn with_flags(mut self, flags: BinlogDumpFlags) -> Self {
3092 self.flags = Const::new(flags);
3093 self
3094 }
3095
3096 pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3098 self.filename = RawBytes::new(filename);
3099 self
3100 }
3101
3102 pub fn pos(&self) -> u32 {
3104 *self.pos
3105 }
3106
3107 pub fn flags(&self) -> BinlogDumpFlags {
3109 *self.flags
3110 }
3111
3112 pub fn server_id(&self) -> u32 {
3114 *self.server_id
3115 }
3116
3117 pub fn filename_raw(&self) -> &[u8] {
3119 self.filename.as_bytes()
3120 }
3121
3122 pub fn filename(&self) -> Cow<'_, str> {
3124 self.filename.as_str()
3125 }
3126}
3127
3128impl MySerialize for ComBinlogDump<'_> {
3129 fn serialize(&self, buf: &mut Vec<u8>) {
3130 self.header.serialize(&mut *buf);
3131 self.pos.serialize(&mut *buf);
3132 self.flags.serialize(&mut *buf);
3133 self.server_id.serialize(&mut *buf);
3134 self.filename.serialize(&mut *buf);
3135 }
3136}
3137
3138impl<'de> MyDeserialize<'de> for ComBinlogDump<'de> {
3139 const SIZE: Option<usize> = None;
3140 type Ctx = ();
3141
3142 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3143 let mut sbuf: ParseBuf<'_> = buf.parse(11)?;
3144 Ok(Self {
3145 header: sbuf.parse_unchecked(())?,
3146 pos: sbuf.parse_unchecked(())?,
3147 flags: sbuf.parse_unchecked(())?,
3148 server_id: sbuf.parse_unchecked(())?,
3149 filename: buf.parse(())?,
3150 })
3151 }
3152}
3153
3154#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
3156pub struct GnoInterval {
3157 start: RawInt<LeU64>,
3158 end: RawInt<LeU64>,
3159}
3160
3161impl GnoInterval {
3162 pub fn new(start: u64, end: u64) -> Self {
3164 Self {
3165 start: RawInt::new(start),
3166 end: RawInt::new(end),
3167 }
3168 }
3169 pub fn check_and_new(start: u64, end: u64) -> io::Result<Self> {
3171 if start >= end {
3172 return Err(io::Error::new(
3173 io::ErrorKind::InvalidData,
3174 format!("start({}) >= end({}) in GnoInterval", start, end),
3175 ));
3176 }
3177 if start == 0 || end == 0 {
3178 return Err(io::Error::new(
3179 io::ErrorKind::InvalidData,
3180 "Gno can't be zero",
3181 ));
3182 }
3183 Ok(Self::new(start, end))
3184 }
3185}
3186
3187impl MySerialize for GnoInterval {
3188 fn serialize(&self, buf: &mut Vec<u8>) {
3189 self.start.serialize(&mut *buf);
3190 self.end.serialize(&mut *buf);
3191 }
3192}
3193
3194impl<'de> MyDeserialize<'de> for GnoInterval {
3195 const SIZE: Option<usize> = Some(16);
3196 type Ctx = ();
3197
3198 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3199 Ok(Self {
3200 start: buf.parse_unchecked(())?,
3201 end: buf.parse_unchecked(())?,
3202 })
3203 }
3204}
3205
3206pub const UUID_LEN: usize = 16;
3208
3209#[derive(Debug, Clone, Eq, PartialEq, Hash)]
3212pub struct Sid<'a> {
3213 uuid: [u8; UUID_LEN],
3214 intervals: Seq<'a, GnoInterval, LeU64>,
3215}
3216
3217impl Sid<'_> {
3218 pub fn new(uuid: [u8; UUID_LEN]) -> Self {
3220 Self {
3221 uuid,
3222 intervals: Default::default(),
3223 }
3224 }
3225
3226 pub fn uuid(&self) -> [u8; UUID_LEN] {
3228 self.uuid
3229 }
3230
3231 pub fn intervals(&self) -> &[GnoInterval] {
3233 &self.intervals[..]
3234 }
3235
3236 pub fn with_interval(mut self, interval: GnoInterval) -> Self {
3238 let mut intervals = self.intervals.0.into_owned();
3239 intervals.push(interval);
3240 self.intervals = Seq::new(intervals);
3241 self
3242 }
3243
3244 pub fn with_intervals(mut self, intervals: Vec<GnoInterval>) -> Self {
3246 self.intervals = Seq::new(intervals);
3247 self
3248 }
3249
3250 fn len(&self) -> u64 {
3251 use saturating::Saturating as S;
3252 let mut len = S(UUID_LEN as u64); len += S(8); len += S((self.intervals.len() * 16) as u64);
3255 len.0
3256 }
3257}
3258
3259impl MySerialize for Sid<'_> {
3260 fn serialize(&self, buf: &mut Vec<u8>) {
3261 self.uuid.serialize(&mut *buf);
3262 self.intervals.serialize(buf);
3263 }
3264}
3265
3266impl<'de> MyDeserialize<'de> for Sid<'de> {
3267 const SIZE: Option<usize> = None;
3268 type Ctx = ();
3269
3270 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3271 Ok(Self {
3272 uuid: buf.parse(())?,
3273 intervals: buf.parse(())?,
3274 })
3275 }
3276}
3277
3278impl Sid<'_> {
3279 fn wrap_err(msg: String) -> io::Error {
3280 io::Error::new(io::ErrorKind::InvalidInput, msg)
3281 }
3282
3283 fn parse_interval_num(to_parse: &str, full: &str) -> Result<u64, io::Error> {
3284 let n: u64 = to_parse.parse().map_err(|e| {
3285 Sid::wrap_err(format!(
3286 "invalid GnoInterval format: {}, error: {}",
3287 full, e
3288 ))
3289 })?;
3290 Ok(n)
3291 }
3292}
3293
3294impl FromStr for Sid<'_> {
3295 type Err = io::Error;
3296
3297 fn from_str(s: &str) -> Result<Self, Self::Err> {
3298 let (uuid, intervals) = s
3299 .split_once(':')
3300 .ok_or_else(|| Sid::wrap_err(format!("invalid sid format: {}", s)))?;
3301 let uuid = Uuid::parse_str(uuid)
3302 .map_err(|e| Sid::wrap_err(format!("invalid uuid format: {}, error: {}", s, e)))?;
3303 let intervals = intervals
3304 .split(':')
3305 .map(|interval| {
3306 let nums = interval.split('-').collect::<Vec<_>>();
3307 if nums.len() != 1 && nums.len() != 2 {
3308 return Err(Sid::wrap_err(format!("invalid GnoInterval format: {}", s)));
3309 }
3310 if nums.len() == 1 {
3311 let start = Sid::parse_interval_num(nums[0], s)?;
3312 let interval = GnoInterval::check_and_new(start, start + 1)?;
3313 Ok(interval)
3314 } else {
3315 let start = Sid::parse_interval_num(nums[0], s)?;
3316 let end = Sid::parse_interval_num(nums[1], s)?;
3317 let interval = GnoInterval::check_and_new(start, end + 1)?;
3318 Ok(interval)
3319 }
3320 })
3321 .collect::<Result<Vec<_>, _>>()?;
3322 Ok(Self {
3323 uuid: *uuid.as_bytes(),
3324 intervals: Seq::new(intervals),
3325 })
3326 }
3327}
3328
3329define_header!(
3330 ComBinlogDumpGtidHeader,
3331 COM_BINLOG_DUMP_GTID,
3332 InvalidComBinlogDumpGtidHeader
3333);
3334
3335#[derive(Debug, Clone, Eq, PartialEq, Hash)]
3337pub struct ComBinlogDumpGtid<'a> {
3338 header: ComBinlogDumpGtidHeader,
3339 flags: Const<BinlogDumpFlags, LeU16>,
3341 server_id: RawInt<LeU32>,
3343 filename: RawBytes<'a, U32Bytes>,
3352 pos: RawInt<LeU64>,
3354 sid_block: Seq<'a, Sid<'a>, LeU64>,
3356}
3357
3358impl<'a> ComBinlogDumpGtid<'a> {
3359 pub fn new(server_id: u32) -> Self {
3361 Self {
3362 header: Default::default(),
3363 pos: Default::default(),
3364 flags: Default::default(),
3365 server_id: RawInt::new(server_id),
3366 filename: Default::default(),
3367 sid_block: Default::default(),
3368 }
3369 }
3370
3371 pub fn server_id(&self) -> u32 {
3373 self.server_id.0
3374 }
3375
3376 pub fn flags(&self) -> BinlogDumpFlags {
3378 self.flags.0
3379 }
3380
3381 pub fn filename_raw(&self) -> &[u8] {
3383 self.filename.as_bytes()
3384 }
3385
3386 pub fn filename(&self) -> Cow<'_, str> {
3388 self.filename.as_str()
3389 }
3390
3391 pub fn pos(&self) -> u64 {
3393 self.pos.0
3394 }
3395
3396 pub fn sids(&self) -> &[Sid<'a>] {
3398 &self.sid_block
3399 }
3400
3401 pub fn with_filename(self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3403 Self {
3404 header: self.header,
3405 flags: self.flags,
3406 server_id: self.server_id,
3407 filename: RawBytes::new(filename),
3408 pos: self.pos,
3409 sid_block: self.sid_block,
3410 }
3411 }
3412
3413 pub fn with_server_id(mut self, server_id: u32) -> Self {
3415 self.server_id.0 = server_id;
3416 self
3417 }
3418
3419 pub fn with_flags(mut self, mut flags: BinlogDumpFlags) -> Self {
3421 if self.sid_block.is_empty() {
3422 flags.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3423 } else {
3424 flags.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3425 }
3426 self.flags.0 = flags;
3427 self
3428 }
3429
3430 pub fn with_pos(mut self, pos: u64) -> Self {
3432 self.pos.0 = pos;
3433 self
3434 }
3435
3436 pub fn with_sid(mut self, sid: Sid<'a>) -> Self {
3438 self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3439 self.sid_block.push(sid);
3440 self
3441 }
3442
3443 pub fn with_sids(mut self, sids: impl Into<Cow<'a, [Sid<'a>]>>) -> Self {
3445 self.sid_block = Seq::new(sids);
3446 if self.sid_block.is_empty() {
3447 self.flags.0.remove(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3448 } else {
3449 self.flags.0.insert(BinlogDumpFlags::BINLOG_THROUGH_GTID);
3450 }
3451 self
3452 }
3453
3454 fn sid_block_len(&self) -> u32 {
3455 use saturating::Saturating as S;
3456 let mut len = S(8); for sid in self.sid_block.iter() {
3458 len += S(sid.len() as u32);
3459 }
3460 len.0
3461 }
3462}
3463
3464impl MySerialize for ComBinlogDumpGtid<'_> {
3465 fn serialize(&self, buf: &mut Vec<u8>) {
3466 self.header.serialize(&mut *buf);
3467 self.flags.serialize(&mut *buf);
3468 self.server_id.serialize(&mut *buf);
3469 self.filename.serialize(&mut *buf);
3470 self.pos.serialize(&mut *buf);
3471 buf.put_u32_le(self.sid_block_len());
3472 self.sid_block.serialize(&mut *buf);
3473 }
3474}
3475
3476impl<'de> MyDeserialize<'de> for ComBinlogDumpGtid<'de> {
3477 const SIZE: Option<usize> = None;
3478 type Ctx = ();
3479
3480 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3481 let mut sbuf: ParseBuf<'_> = buf.parse(7)?;
3482 let header = sbuf.parse_unchecked(())?;
3483 let flags: Const<BinlogDumpFlags, LeU16> = sbuf.parse_unchecked(())?;
3484 let server_id = sbuf.parse_unchecked(())?;
3485
3486 let filename = buf.parse(())?;
3487 let pos = buf.parse(())?;
3488
3489 let sid_data_len: RawInt<LeU32> = buf.parse(())?;
3491 let mut buf: ParseBuf<'_> = buf.parse(sid_data_len.0 as usize)?;
3492 let sid_block = buf.parse(())?;
3493
3494 Ok(Self {
3495 header,
3496 flags,
3497 server_id,
3498 filename,
3499 pos,
3500 sid_block,
3501 })
3502 }
3503}
3504
3505define_header!(
3506 SemiSyncAckPacketPacketHeader,
3507 InvalidSemiSyncAckPacketPacketHeader("Invalid semi-sync ack packet header"),
3508 0xEF
3509);
3510
3511pub struct SemiSyncAckPacket<'a> {
3514 header: SemiSyncAckPacketPacketHeader,
3515 position: RawInt<LeU64>,
3516 filename: RawBytes<'a, EofBytes>,
3517}
3518
3519impl<'a> SemiSyncAckPacket<'a> {
3520 pub fn new(position: u64, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3521 Self {
3522 header: Default::default(),
3523 position: RawInt::new(position),
3524 filename: RawBytes::new(filename),
3525 }
3526 }
3527
3528 pub fn with_position(mut self, position: u64) -> Self {
3530 self.position.0 = position;
3531 self
3532 }
3533
3534 pub fn with_filename(mut self, filename: impl Into<Cow<'a, [u8]>>) -> Self {
3536 self.filename = RawBytes::new(filename);
3537 self
3538 }
3539
3540 pub fn position(&self) -> u64 {
3542 self.position.0
3543 }
3544
3545 pub fn filename_raw(&self) -> &[u8] {
3547 self.filename.as_bytes()
3548 }
3549
3550 pub fn filename(&self) -> Cow<'_, str> {
3552 self.filename.as_str()
3553 }
3554}
3555
3556impl MySerialize for SemiSyncAckPacket<'_> {
3557 fn serialize(&self, buf: &mut Vec<u8>) {
3558 self.header.serialize(&mut *buf);
3559 self.position.serialize(&mut *buf);
3560 self.filename.serialize(&mut *buf);
3561 }
3562}
3563
3564impl<'de> MyDeserialize<'de> for SemiSyncAckPacket<'de> {
3565 const SIZE: Option<usize> = None;
3566 type Ctx = ();
3567
3568 fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
3569 let mut sbuf: ParseBuf<'_> = buf.parse(9)?;
3570 Ok(Self {
3571 header: sbuf.parse_unchecked(())?,
3572 position: sbuf.parse_unchecked(())?,
3573 filename: buf.parse(())?,
3574 })
3575 }
3576}
3577
3578#[cfg(test)]
3579mod test {
3580 use super::*;
3581 use crate::{
3582 constants::{CapabilityFlags, ColumnFlags, ColumnType, StatusFlags},
3583 proto::{MyDeserialize, MySerialize},
3584 };
3585
3586 proptest::proptest! {
3587 #[test]
3588 fn com_table_dump_roundtrip(database: Vec<u8>, table: Vec<u8>) {
3589 let cmd = ComTableDump::new(database, table);
3590
3591 let mut output = Vec::new();
3592 cmd.serialize(&mut output);
3593
3594 assert_eq!(cmd, ComTableDump::deserialize((), &mut ParseBuf(&output[..]))?);
3595 }
3596
3597 #[test]
3598 fn com_binlog_dump_roundtrip(
3599 server_id: u32,
3600 filename: Vec<u8>,
3601 pos: u32,
3602 flags: u16,
3603 ) {
3604 let cmd = ComBinlogDump::new(server_id)
3605 .with_filename(filename)
3606 .with_pos(pos)
3607 .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));
3608
3609 let mut output = Vec::new();
3610 cmd.serialize(&mut output);
3611
3612 assert_eq!(cmd, ComBinlogDump::deserialize((), &mut ParseBuf(&output[..]))?);
3613 }
3614
3615 #[test]
3616 fn com_register_slave_roundtrip(
3617 server_id: u32,
3618 hostname in r"\w{0,256}",
3619 user in r"\w{0,256}",
3620 password in r"\w{0,256}",
3621 port: u16,
3622 replication_rank: u32,
3623 master_id: u32,
3624 ) {
3625 let cmd = ComRegisterSlave::new(server_id)
3626 .with_hostname(hostname.as_bytes())
3627 .with_user(user.as_bytes())
3628 .with_password(password.as_bytes())
3629 .with_port(port)
3630 .with_replication_rank(replication_rank)
3631 .with_master_id(master_id);
3632
3633 let mut output = Vec::new();
3634 cmd.serialize(&mut output);
3635 let parsed = ComRegisterSlave::deserialize((), &mut ParseBuf(&output[..]))?;
3636
3637 if hostname.len() > 255 || user.len() > 255 || password.len() > 255 {
3638 assert_ne!(cmd, parsed);
3639 } else {
3640 assert_eq!(cmd, parsed);
3641 }
3642 }
3643
3644 #[test]
3645 fn com_binlog_dump_gtid_roundtrip(
3646 flags: u16,
3647 server_id: u32,
3648 filename: Vec<u8>,
3649 pos: u64,
3650 n_sid_blocks in 0_u64..1024,
3651 ) {
3652 let mut cmd = ComBinlogDumpGtid::new(server_id)
3653 .with_filename(filename)
3654 .with_pos(pos)
3655 .with_flags(crate::packets::BinlogDumpFlags::from_bits_truncate(flags));
3656
3657 let mut sids = Vec::new();
3658 for i in 0..n_sid_blocks {
3659 let mut block = Sid::new([i as u8; 16]);
3660 for j in 0..i {
3661 block = block.with_interval(GnoInterval::new(i, j));
3662 }
3663 sids.push(block);
3664 }
3665
3666 cmd = cmd.with_sids(sids);
3667
3668 let mut output = Vec::new();
3669 cmd.serialize(&mut output);
3670
3671 assert_eq!(cmd, ComBinlogDumpGtid::deserialize((), &mut ParseBuf(&output[..]))?);
3672 }
3673 }
3674
3675 #[test]
3676 fn should_parse_local_infile_packet() {
3677 const LIP: &[u8] = b"\xfbfile_name";
3678
3679 let lip = LocalInfilePacket::deserialize((), &mut ParseBuf(LIP)).unwrap();
3680 assert_eq!(lip.file_name_str(), "file_name");
3681 }
3682
3683 #[test]
3684 fn should_parse_stmt_packet() {
3685 const SP: &[u8] = b"\x00\x01\x00\x00\x00\x01\x00\x02\x00\x00\x00\x00";
3686 const SP_2: &[u8] = b"\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
3687
3688 let sp = StmtPacket::deserialize((), &mut ParseBuf(SP)).unwrap();
3689 assert_eq!(sp.statement_id(), 0x01);
3690 assert_eq!(sp.num_columns(), 0x01);
3691 assert_eq!(sp.num_params(), 0x02);
3692 assert_eq!(sp.warning_count(), 0x00);
3693
3694 let sp = StmtPacket::deserialize((), &mut ParseBuf(SP_2)).unwrap();
3695 assert_eq!(sp.statement_id(), 0x01);
3696 assert_eq!(sp.num_columns(), 0x00);
3697 assert_eq!(sp.num_params(), 0x00);
3698 assert_eq!(sp.warning_count(), 0x00);
3699 }
3700
3701 #[test]
3702 fn should_parse_handshake_packet() {
3703 const HSP: &[u8] = b"\x0a5.5.5-10.0.17-MariaDB-log\x00\x0b\x00\
3704 \x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\
3705 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x2a\x34\x64\
3706 \x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00";
3707
3708 const HSP_2: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
3709 \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
3710 \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
3711 \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
3712 \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
3713 \x6f\x72\x64\x00";
3714
3715 const HSP_3: &[u8] = b"\x0a\x35\x2e\x36\x2e\x34\x2d\x6d\x37\x2d\x6c\x6f\
3716 \x67\x00\x56\x0a\x00\x00\x52\x42\x33\x76\x7a\x26\x47\x72\x00\xff\
3717 \xff\x08\x02\x00\x0f\xc0\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\
3718 \x00\x2b\x79\x44\x26\x2f\x5a\x5a\x33\x30\x35\x5a\x47\x00\x6d\x79\
3719 \x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\x73\x73\x77\
3720 \x6f\x72\x64\x00";
3721
3722 let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap();
3723 assert_eq!(hsp.protocol_version(), 0x0a);
3724 assert_eq!(hsp.server_version_str(), "5.5.5-10.0.17-MariaDB-log");
3725 assert_eq!(hsp.server_version_parsed(), Some((5, 5, 5)));
3726 assert_eq!(hsp.maria_db_server_version_parsed(), Some((10, 0, 17)));
3727 assert_eq!(hsp.connection_id(), 0x0b);
3728 assert_eq!(hsp.scramble_1_ref(), b"dvH@I-CJ");
3729 assert_eq!(
3730 hsp.capabilities(),
3731 CapabilityFlags::from_bits_truncate(0xf7ff)
3732 );
3733 assert_eq!(hsp.default_collation(), 0x08);
3734 assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3735 assert_eq!(hsp.scramble_2_ref(), Some(&b"*4d|cZwk4^]:\x00"[..]));
3736 assert_eq!(hsp.auth_plugin_name_ref(), None);
3737
3738 let mut output = Vec::new();
3739 hsp.serialize(&mut output);
3740 assert_eq!(&output, HSP);
3741
3742 let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_2)).unwrap();
3743 assert_eq!(hsp.protocol_version(), 0x0a);
3744 assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
3745 assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
3746 assert_eq!(hsp.maria_db_server_version_parsed(), None);
3747 assert_eq!(hsp.connection_id(), 0x0a56);
3748 assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
3749 assert_eq!(
3750 hsp.capabilities(),
3751 CapabilityFlags::from_bits_truncate(0xc00fffff)
3752 );
3753 assert_eq!(hsp.default_collation(), 0x08);
3754 assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3755 assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
3756 assert_eq!(
3757 hsp.auth_plugin_name_ref(),
3758 Some(&b"mysql_native_password"[..])
3759 );
3760
3761 let mut output = Vec::new();
3762 hsp.serialize(&mut output);
3763 assert_eq!(&output, HSP_2);
3764
3765 let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP_3)).unwrap();
3766 assert_eq!(hsp.protocol_version(), 0x0a);
3767 assert_eq!(hsp.server_version_str(), "5.6.4-m7-log");
3768 assert_eq!(hsp.server_version_parsed(), Some((5, 6, 4)));
3769 assert_eq!(hsp.maria_db_server_version_parsed(), None);
3770 assert_eq!(hsp.connection_id(), 0x0a56);
3771 assert_eq!(hsp.scramble_1_ref(), b"RB3vz&Gr");
3772 assert_eq!(
3773 hsp.capabilities(),
3774 CapabilityFlags::from_bits_truncate(0xc00fffff)
3775 );
3776 assert_eq!(hsp.default_collation(), 0x08);
3777 assert_eq!(hsp.status_flags(), StatusFlags::from_bits_truncate(0x0002));
3778 assert_eq!(hsp.scramble_2_ref(), Some(&b"+yD&/ZZ305ZG\0"[..]));
3779 assert_eq!(
3780 hsp.auth_plugin_name_ref(),
3781 Some(&b"mysql_native_password"[..])
3782 );
3783
3784 let mut output = Vec::new();
3785 hsp.serialize(&mut output);
3786 assert_eq!(&output, HSP_3);
3787 }
3788
#[test]
fn should_parse_err_packet() {
    // Error packet with a SQLSTATE: code 1096, state "HY000", message "No tables used".
    const ERR_PACKET: &[u8] = b"\xff\x48\x04\x23\x48\x59\x30\x30\x30\x4e\x6f\x20\x74\x61\x62\
        \x6c\x65\x73\x20\x75\x73\x65\x64";
    // Error packet without a SQLSTATE: code 1040, message "Too many connections".
    const ERR_PACKET_NO_STATE: &[u8] = b"\xff\x10\x04\x54\x6f\x6f\x20\x6d\x61\x6e\x79\x20\x63\
        \x6f\x6e\x6e\x65\x63\x74\x69\x6f\x6e\x73";
    // Progress report: stage 1 of 10, progress 23500, info "stage name".
    const PROGRESS_PACKET: &[u8] = b"\xff\xff\xff\x01\x01\x0a\xcc\x5b\x00\x0astage name";

    // With CLIENT_PROTOCOL_41 the SQLSTATE is decoded along with code and message.
    let with_state = ErrPacket::deserialize(
        CapabilityFlags::CLIENT_PROTOCOL_41,
        &mut ParseBuf(ERR_PACKET),
    )
    .unwrap();
    let stated_error = with_state.server_error();
    assert_eq!(stated_error.error_code(), 1096);
    assert_eq!(stated_error.sql_state_ref().unwrap().as_str(), "HY000");
    assert_eq!(stated_error.message_str(), "No tables used");

    // Without capabilities there is no SQLSTATE to report.
    let without_state =
        ErrPacket::deserialize(CapabilityFlags::empty(), &mut ParseBuf(ERR_PACKET_NO_STATE))
            .unwrap();
    let stateless_error = without_state.server_error();
    assert_eq!(stateless_error.error_code(), 1040);
    assert_eq!(stateless_error.sql_state_ref(), None);
    assert_eq!(stateless_error.message_str(), "Too many connections");

    // With CLIENT_PROGRESS_OBSOLETE the 0xffff pseudo-code is read as a progress report.
    let progress_packet = ErrPacket::deserialize(
        CapabilityFlags::CLIENT_PROGRESS_OBSOLETE,
        &mut ParseBuf(PROGRESS_PACKET),
    )
    .unwrap();
    assert!(progress_packet.is_progress_report());
    let report = progress_packet.progress_report();
    assert_eq!(report.stage(), 1);
    assert_eq!(report.max_stage(), 10);
    assert_eq!(report.progress(), 23500);
    assert_eq!(report.stage_info_str(), "stage name");
}
3827
#[test]
fn should_parse_column_packet() {
    // Length-prefixed names (catalog "def", schema, table, org_table, name, org_name)
    // followed by a 0x0c-byte block of fixed-length column attributes.
    const COLUMN_PACKET: &[u8] = b"\x03def\x06schema\x05table\x09org_table\x04name\
        \x08org_name\x0c\x21\x00\x0F\x00\x00\x00\x00\x01\x00\x08\x00\x00";

    let parsed = Column::deserialize((), &mut ParseBuf(COLUMN_PACKET)).unwrap();

    // Name fields.
    assert_eq!(parsed.schema_str(), "schema");
    assert_eq!(parsed.table_str(), "table");
    assert_eq!(parsed.org_table_str(), "org_table");
    assert_eq!(parsed.name_str(), "name");
    assert_eq!(parsed.org_name_str(), "org_name");

    // Fixed-length attributes.
    assert_eq!(parsed.character_set(), CollationId::UTF8MB3_GENERAL_CI as u16);
    assert_eq!(parsed.column_length(), 15);
    assert_eq!(parsed.column_type(), ColumnType::MYSQL_TYPE_DECIMAL);
    assert_eq!(parsed.flags(), ColumnFlags::NOT_NULL_FLAG);
    assert_eq!(parsed.decimals(), 8);
}
3847
#[test]
fn should_parse_auth_switch_request() {
    // 0xfe header, NUL-terminated plugin name "mysql_native_password", then plugin data.
    const PAYLOAD: &[u8] = b"\xfe\x6d\x79\x73\x71\x6c\x5f\x6e\x61\x74\x69\x76\x65\x5f\x70\x61\
        \x73\x73\x77\x6f\x72\x64\x00\x7a\x51\x67\x34\x69\x36\x6f\x4e\x79\
        \x36\x3d\x72\x48\x4e\x2f\x3e\x2d\x62\x29\x41\x00";

    let request = AuthSwitchRequest::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();

    let plugin = request.auth_plugin();
    assert_eq!(plugin.as_bytes(), b"mysql_native_password");
    assert_eq!(request.plugin_data(), b"zQg4i6oNy6=rHN/>-b)A");
}
3857
#[test]
fn should_parse_auth_more_data() {
    // 0x01 header followed by a single payload byte.
    const PAYLOAD: &[u8] = b"\x01\x04";

    let more_data = AuthMoreData::deserialize((), &mut ParseBuf(PAYLOAD)).unwrap();

    assert_eq!(more_data.data(), b"\x04");
}
3864
#[test]
fn should_parse_ok_packet() {
    // Plain OK: 0x00 header, one affected row, AUTOCOMMIT status, no warnings.
    const PLAIN_OK: &[u8] = b"\x00\x01\x00\x02\x00\x00\x00";
    // Result-set terminator (0xfe header) carrying a 0x42-byte length-encoded info string.
    const RESULT_SET_TERMINATOR: &[u8] = b"\xfe\x00\x00\x00\x00\x00\x00\x42\
        Read 1 rows, 1.00 B in 0.002 sec., \
        611.34 rows/sec., 611.34 B/sec.";
    // OK packets carrying session-state-change payloads: a system variable
    // (autocommit=OFF), a schema change ("test") and a state-change flag ("1").
    const SESS_STATE_SYS_VAR_OK: &[u8] = b"\x00\x00\x00\x02\x40\x00\x00\x00\x11\x00\x0f\x0a\x61\
        \x75\x74\x6f\x63\x6f\x6d\x6d\x69\x74\x03\x4f\x46\x46";
    const SESS_STATE_SCHEMA_OK: &[u8] =
        b"\x00\x00\x00\x02\x40\x00\x00\x00\x07\x01\x05\x04\x74\x65\x73\x74";
    const SESS_STATE_TRACK_OK: &[u8] = b"\x00\x00\x00\x02\x40\x00\x00\x00\x04\x02\x02\x01\x31";
    // Legacy EOF packet: 0xfe header, no warnings, AUTOCOMMIT status.
    const EOF: &[u8] = b"\xfe\x00\x00\x02\x00";

    /// Deserializes `payload` as an ordinary OK packet under the given capabilities.
    fn parse_common_ok(caps: CapabilityFlags, payload: &[u8]) -> OkPacket<'_> {
        OkPacketDeserializer::<CommonOkPacket>::deserialize(caps, &mut ParseBuf(payload))
            .unwrap()
            .into()
    }

    /// Checks the header fields shared by every zero-row packet in this test.
    fn assert_zero_row_header(packet: &OkPacket<'_>, status: StatusFlags) {
        assert_eq!(packet.affected_rows(), 0);
        assert_eq!(packet.last_insert_id(), None);
        assert_eq!(packet.status_flags(), status);
        assert_eq!(packet.warnings(), 0);
    }

    let tracked = CapabilityFlags::CLIENT_SESSION_TRACK;

    // A 0x00-headed packet must be rejected by the result-set-terminator deserializer.
    OkPacketDeserializer::<ResultSetTerminator>::deserialize(
        CapabilityFlags::empty(),
        &mut ParseBuf(PLAIN_OK),
    )
    .unwrap_err();

    // The plain OK decodes identically with and without session tracking.
    for caps in [CapabilityFlags::empty(), tracked] {
        let plain = parse_common_ok(caps, PLAIN_OK);
        assert_eq!(plain.affected_rows(), 1);
        assert_eq!(plain.last_insert_id(), None);
        assert_eq!(plain.status_flags(), StatusFlags::SERVER_STATUS_AUTOCOMMIT);
        assert_eq!(plain.warnings(), 0);
        assert_eq!(plain.info_ref(), None);
        assert_eq!(plain.session_state_info_ref(), None);
    }

    let terminator: OkPacket<'_> = OkPacketDeserializer::<ResultSetTerminator>::deserialize(
        tracked,
        &mut ParseBuf(RESULT_SET_TERMINATOR),
    )
    .unwrap()
    .into();
    assert_zero_row_header(&terminator, StatusFlags::empty());
    assert_eq!(
        terminator.info_str(),
        Some(Cow::Borrowed(
            "Read 1 rows, 1.00 B in 0.002 sec., 611.34 rows/sec., 611.34 B/sec."
        ))
    );
    assert_eq!(terminator.session_state_info_ref(), None);

    // System-variable change: exactly one variable, autocommit=OFF.
    let sys_var_ok = parse_common_ok(tracked, SESS_STATE_SYS_VAR_OK);
    assert_zero_row_header(
        &sys_var_ok,
        StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED,
    );
    assert_eq!(sys_var_ok.info_ref(), None);
    let sys_var_info = sys_var_ok.session_state_info().unwrap().pop().unwrap();
    let SessionStateChange::SystemVariables(mut variables) = sys_var_info.decode().unwrap() else {
        panic!()
    };
    let autocommit = variables.pop().unwrap();
    assert_eq!(autocommit.name_bytes(), b"autocommit");
    assert_eq!(autocommit.value_bytes(), b"OFF");
    assert!(variables.is_empty());

    // Schema change to "test".
    let schema_ok = parse_common_ok(tracked, SESS_STATE_SCHEMA_OK);
    assert_zero_row_header(
        &schema_ok,
        StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED,
    );
    assert_eq!(schema_ok.info_ref(), None);
    let schema_info = schema_ok.session_state_info().unwrap().pop().unwrap();
    let SessionStateChange::Schema(schema) = schema_info.decode().unwrap() else {
        panic!()
    };
    assert_eq!(schema.as_bytes(), b"test");

    // Generic state-change tracking flag.
    let track_ok = parse_common_ok(tracked, SESS_STATE_TRACK_OK);
    assert_zero_row_header(
        &track_ok,
        StatusFlags::SERVER_STATUS_AUTOCOMMIT | StatusFlags::SERVER_SESSION_STATE_CHANGED,
    );
    assert_eq!(track_ok.info_ref(), None);
    let track_info = track_ok.session_state_info().unwrap().pop().unwrap();
    assert_eq!(
        track_info.decode().unwrap(),
        SessionStateChange::IsTracked(true),
    );

    // Legacy EOF packets are accepted by the OldEofPacket deserializer.
    let eof: OkPacket<'_> =
        OkPacketDeserializer::<OldEofPacket>::deserialize(tracked, &mut ParseBuf(EOF))
            .unwrap()
            .into();
    assert_zero_row_header(&eof, StatusFlags::SERVER_STATUS_AUTOCOMMIT);
    assert_eq!(eof.info_ref(), None);
    assert_eq!(eof.session_state_info_ref(), None);
}
4023
#[test]
fn should_build_handshake_response() {
    /// Max packet size whose little-endian wire encoding is `00 00 00 01`.
    ///
    /// Previously written as `1_u32.to_be()`. That only yields 0x0100_0000 on
    /// little-endian hosts; on big-endian hosts `to_be()` is the identity, so the
    /// serialized bytes (and this test) depended on the host byte order.
    const MAX_PACKET_SIZE: u32 = u32::from_le_bytes([0x00, 0x00, 0x00, 0x01]);

    /// Serializes a response for user `root` with an empty scramble and the
    /// `mysql_native_password` plugin, varying only the database name and flags.
    fn serialize_response(db_name: Option<&[u8]>, flags: CapabilityFlags) -> Vec<u8> {
        let response = HandshakeResponse::new(
            Some(&[][..]),
            (5u16, 5, 5),
            Some(&b"root"[..]),
            db_name,
            Some(AuthPlugin::MysqlNativePassword),
            flags,
            None,
            MAX_PACKET_SIZE,
        );
        let mut out = Vec::new();
        response.serialize(&mut out);
        out
    }

    /// Builds the expected wire image of a response produced by `serialize_response`.
    ///
    /// `first_caps_byte` is the lowest byte of the little-endian capability flags
    /// (0x05 without CLIENT_CONNECT_WITH_DB, 0x0d with it). When `db_name` is given,
    /// it is written NUL-terminated between the auth response and the plugin name.
    fn expected_bytes(first_caps_byte: u8, db_name: Option<&[u8]>) -> Vec<u8> {
        let mut bytes = vec![first_caps_byte, 0xa2, 0xae, 0x81]; // capability flags
        bytes.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]); // max packet size
        bytes.push(0x2d); // character set / collation id
        bytes.extend_from_slice(&[0x00; 23]); // zero filler
        bytes.extend_from_slice(b"root\0"); // NUL-terminated user name
        bytes.push(0x00); // empty auth response
        if let Some(db) = db_name {
            bytes.extend_from_slice(db);
            bytes.push(0x00);
        }
        bytes.extend_from_slice(b"mysql_native_password\0"); // auth plugin name
        bytes
    }

    let flags_without_db_name = CapabilityFlags::from_bits_truncate(0x81aea205);
    let flags_with_db_name = flags_without_db_name | CapabilityFlags::CLIENT_CONNECT_WITH_DB;
    let mydb: &[u8] = b"mydb";
    let empty_db: &[u8] = b"";

    // No database: no db field is written.
    assert_eq!(
        expected_bytes(0x05, None),
        serialize_response(None, flags_without_db_name)
    );

    // Database name together with CLIENT_CONNECT_WITH_DB.
    assert_eq!(
        expected_bytes(0x0d, Some(mydb)),
        serialize_response(Some(mydb), flags_with_db_name)
    );

    // A database name without the flag yields the same bytes: the flag gets set.
    assert_eq!(
        expected_bytes(0x0d, Some(mydb)),
        serialize_response(Some(mydb), flags_without_db_name)
    );

    // An empty database name still writes its NUL terminator.
    assert_eq!(
        expected_bytes(0x0d, Some(empty_db)),
        serialize_response(Some(empty_db), flags_with_db_name)
    );
}
4127
#[test]
fn should_parse_handshake_packet_with_mariadb_ext_capabilities() {
    // Initial handshake from a MariaDB server, one field per continuation line:
    // protocol version, NUL-terminated server version, connection id, first scramble
    // part + NUL, lower capability bytes, collation, status flags, nine zero bytes,
    // four bytes of MariaDB extended capabilities, NUL-terminated second scramble part.
    const HSP: &[u8] = b"\x0a\
        5.5.5-11.4.7-MariaDB-log\x00\
        \x0b\x00\x00\x00\
        \x64\x76\x48\x40\x49\x2d\x43\x4a\x00\
        \xff\xf7\
        \x08\
        \x02\x00\
        \x00\x00\x00\x00\x00\x00\x00\x00\x00\
        \x10\x00\x00\x00\
        \x2a\x34\x64\x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00";

    let packet = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap();

    // Version information: the "5.5.5-" prefix parses as the server version and the
    // MariaDB version is extracted from what follows it.
    assert_eq!(packet.protocol_version(), 0x0a);
    assert_eq!(packet.server_version_str(), "5.5.5-11.4.7-MariaDB-log");
    assert_eq!(packet.server_version_parsed(), Some((5, 5, 5)));
    assert_eq!(packet.maria_db_server_version_parsed(), Some((11, 4, 7)));

    // Connection and authentication data.
    assert_eq!(packet.connection_id(), 0x0b);
    assert_eq!(packet.scramble_1_ref(), b"dvH@I-CJ");
    assert_eq!(packet.scramble_2_ref(), Some(&b"*4d|cZwk4^]:\x00"[..]));
    assert_eq!(packet.auth_plugin_name_ref(), None);

    // Capability, collation and status fields.
    assert_eq!(
        packet.capabilities(),
        CapabilityFlags::from_bits_truncate(0xf7ff)
    );
    assert_eq!(packet.default_collation(), 0x08);
    assert_eq!(packet.status_flags(), StatusFlags::from_bits_truncate(0x0002));
    assert_eq!(
        packet.mariadb_ext_capabilities(),
        MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA
    );

    // Re-serializing must reproduce the input byte-for-byte.
    let mut reserialized = Vec::new();
    packet.serialize(&mut reserialized);
    assert_eq!(&reserialized, HSP);
}
4158
#[test]
fn should_build_handshake_response_with_mariadb_capabilities() {
    /// Max packet size whose little-endian wire encoding is `00 00 00 01`.
    ///
    /// Previously `1_u32.to_be()`, which only produces 0x0100_0000 on little-endian
    /// hosts and made the expected bytes below host-dependent.
    const MAX_PACKET_SIZE: u32 = u32::from_le_bytes([0x00, 0x00, 0x00, 0x01]);

    let flags_without_db_name = CapabilityFlags::from_bits_truncate(0x81aea205);
    let response = HandshakeResponse::new(
        Some(&[][..]),
        (5u16, 5, 5),
        Some(&b"root"[..]),
        None::<&'static [u8]>,
        Some(AuthPlugin::MysqlNativePassword),
        flags_without_db_name,
        None,
        MAX_PACKET_SIZE,
    )
    .with_mariadb_ext_capabilities(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA);
    let mut actual = Vec::new();
    response.serialize(&mut actual);

    // Expected wire image. Compared with a plain response, the zero filler is 4 bytes
    // shorter and is followed by the MariaDB extended capabilities (0x10 = CACHE_METADATA).
    let segments: [&[u8]; 8] = [
        &[0x05, 0xa2, 0xae, 0x81], // capability flags
        &[0x00, 0x00, 0x00, 0x01], // max packet size
        &[0x2d],                   // character set / collation id
        &[0x00; 19],               // zero filler
        &[0x10, 0x00, 0x00, 0x00], // MariaDB extended capabilities
        b"root\0",                 // NUL-terminated user name
        &[0x00],                   // empty auth response
        b"mysql_native_password\0", // auth plugin name
    ];
    let expected: Vec<u8> = segments.concat();

    assert_eq!(expected, actual);
}
4192
#[test]
fn parse_str_to_sid() {
    let expected_uuid = Uuid::parse_str("3E11FA47-71CA-11E1-9E33-C80AA9429562").unwrap();
    // Parses `input` as a SID that is expected to be invalid and renders the error.
    let parse_error = |input: &str| input.parse::<Sid<'_>>().unwrap_err().to_string();

    // A single GNO becomes one interval whose end is one past the GNO.
    let single = "3E11FA47-71CA-11E1-9E33-C80AA9429562:23"
        .parse::<Sid<'_>>()
        .unwrap();
    assert_eq!(single.uuid, *expected_uuid.as_bytes());
    assert_eq!(single.intervals.len(), 1);
    assert_eq!(
        (single.intervals[0].start.0, single.intervals[0].end.0),
        (23, 24)
    );

    // Ranges are inclusive in text and stored with an exclusive end.
    let ranged = "3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15"
        .parse::<Sid<'_>>()
        .unwrap();
    assert_eq!(ranged.uuid, *expected_uuid.as_bytes());
    assert_eq!(ranged.intervals.len(), 2);
    assert_eq!(
        (ranged.intervals[0].start.0, ranged.intervals[0].end.0),
        (1, 6)
    );
    assert_eq!(
        (ranged.intervals[1].start.0, ranged.intervals[1].end.0),
        (10, 16)
    );

    // Malformed inputs.
    assert_eq!(
        parse_error("3E11FA47-71CA-11E1-9E33-C80AA9429562"),
        "invalid sid format: 3E11FA47-71CA-11E1-9E33-C80AA9429562"
    );
    assert_eq!(
        parse_error("3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-"),
        "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:10-15:20-, error: cannot parse integer from empty string"
    );
    assert_eq!(
        parse_error("3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa"),
        "invalid GnoInterval format: 3E11FA47-71CA-11E1-9E33-C80AA9429562:1-5:1aaa, error: invalid digit found in string"
    );
    assert_eq!(
        parse_error("3E11FA47-71CA-11E1-9E33-C80AA9429562:0-3"),
        "Gno can't be zero"
    );
    assert_eq!(
        parse_error("3E11FA47-71CA-11E1-9E33-C80AA9429562:4-3"),
        "start(4) >= end(4) in GnoInterval"
    );
}
4238
#[test]
fn should_parse_rsa_public_key_response_packet() {
    // 0x01 header byte followed by the key material ("test").
    const PUBLIC_RSA_KEY_RESPONSE: &[u8] = b"\x01test";

    // `expect` reports the deserialization error on failure. The previous
    // `assert!(result.is_ok())` followed by `unwrap()` discarded it and printed only
    // "assertion failed".
    let rsa_public_key_response =
        PublicKeyResponse::deserialize((), &mut ParseBuf(PUBLIC_RSA_KEY_RESPONSE))
            .expect("a well-formed public key response must deserialize");

    assert_eq!(rsa_public_key_response.rsa_key(), "test");
}
4249
#[test]
fn should_build_rsa_public_key_response_packet() {
    // Serializing a response around the key "test" yields a 0x01 header plus the key.
    let packet = PublicKeyResponse::new(&b"test"[..]);

    let mut serialized = Vec::new();
    packet.serialize(&mut serialized);

    assert_eq!(b"\x01test".to_vec(), serialized);
}
4261}