use std::{
borrow::Cow,
cmp::min,
convert::TryFrom,
fmt,
io::{self, Read},
};
use byteorder::{LittleEndian, ReadBytesExt};
use saturating::Saturating as S;
use crate::{
binlog::{
consts::{BinlogVersion, EventType, StatusVarKey},
BinlogCtx, BinlogEvent, BinlogStruct,
},
constants::{Flags2, SqlMode},
io::ParseBuf,
misc::{
raw::{
bytes::{BareU16Bytes, BareU8Bytes, EofBytes, NullBytes, U8Bytes},
int::*,
RawBytes, RawFlags, Skip,
},
unexpected_buf_eof,
},
proto::{MyDeserialize, MySerialize},
};
use super::BinlogEventHeader;
/// Body of a binlog `QUERY_EVENT`.
///
/// Fields are declared in the order they are read and written: the 13-byte fixed
/// post-header (`thread_id` through `status_vars_len`), then the status variables,
/// the schema name, one skipped byte and finally the query text, which consumes the
/// rest of the event.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct QueryEvent<'a> {
    /// 4-byte little-endian integer (post-header).
    thread_id: RawInt<LeU32>,
    /// 4-byte little-endian integer (post-header).
    execution_time: RawInt<LeU32>,
    /// Length in bytes of `schema`, which is stored without its own length prefix.
    schema_len: RawInt<u8>,
    /// 2-byte little-endian integer (post-header).
    error_code: RawInt<LeU16>,
    /// Length in bytes of `status_vars`.
    status_vars_len: RawInt<LeU16>,
    /// Raw status-variable block, `status_vars_len` bytes long.
    status_vars: StatusVars<'a>,
    /// Schema name, `schema_len` bytes long.
    schema: RawBytes<'a, BareU8Bytes>,
    /// One byte between schema and query that is skipped when parsing
    /// (NOTE(review): presumably the schema's NUL terminator — confirm).
    __skip: Skip<1>,
    /// Query text; extends to the end of the buffer.
    query: RawBytes<'a, EofBytes>,
}
impl<'a> QueryEvent<'a> {
pub fn new(status_vars: impl Into<Cow<'a, [u8]>>, schema: impl Into<Cow<'a, [u8]>>) -> Self {
let status_vars = StatusVars(RawBytes::new(status_vars));
let schema = RawBytes::new(schema);
Self {
thread_id: Default::default(),
execution_time: Default::default(),
schema_len: RawInt::new(schema.len() as u8),
error_code: Default::default(),
status_vars_len: RawInt::new(status_vars.0.len() as u16),
status_vars,
schema,
__skip: Default::default(),
query: Default::default(),
}
}
pub fn with_thread_id(mut self, thread_id: u32) -> Self {
self.thread_id = RawInt::new(thread_id);
self
}
pub fn with_execution_time(mut self, execution_time: u32) -> Self {
self.execution_time = RawInt::new(execution_time);
self
}
pub fn with_error_code(mut self, error_code: u16) -> Self {
self.error_code = RawInt::new(error_code);
self
}
pub fn with_status_vars(mut self, status_vars: impl Into<Cow<'a, [u8]>>) -> Self {
self.status_vars = StatusVars(RawBytes::new(status_vars));
self.status_vars_len.0 = self.status_vars.0.len() as u16;
self
}
pub fn with_schema(mut self, schema: impl Into<Cow<'a, [u8]>>) -> Self {
self.schema = RawBytes::new(schema);
self.schema_len.0 = self.schema.len() as u8;
self
}
pub fn with_query(mut self, query: impl Into<Cow<'a, [u8]>>) -> Self {
self.query = RawBytes::new(query);
self
}
pub fn thread_id(&self) -> u32 {
self.thread_id.0
}
pub fn execution_time(&self) -> u32 {
self.execution_time.0
}
pub fn error_code(&self) -> u16 {
self.error_code.0
}
pub fn status_vars_raw(&'a self) -> &'a [u8] {
self.status_vars.0.as_bytes()
}
pub fn status_vars(&'a self) -> &'a StatusVars<'a> {
&self.status_vars
}
pub fn schema_raw(&'a self) -> &'a [u8] {
self.schema.as_bytes()
}
pub fn schema(&'a self) -> Cow<'a, str> {
self.schema.as_str()
}
pub fn query_raw(&'a self) -> &'a [u8] {
self.query.as_bytes()
}
pub fn query(&'a self) -> Cow<'a, str> {
self.query.as_str()
}
pub fn into_owned(self) -> QueryEvent<'static> {
QueryEvent {
thread_id: self.thread_id,
execution_time: self.execution_time,
schema_len: self.schema_len,
error_code: self.error_code,
status_vars_len: self.status_vars_len,
status_vars: self.status_vars.into_owned(),
schema: self.schema.into_owned(),
__skip: self.__skip,
query: self.query.into_owned(),
}
}
}
impl<'de> MyDeserialize<'de> for QueryEvent<'de> {
    const SIZE: Option<usize> = None;
    type Ctx = BinlogCtx<'de>;

    /// Parses a `QUERY_EVENT` body.
    ///
    /// Reads the 13 fixed post-header bytes, skips any further post-header bytes that
    /// the format description (`ctx.fde`) announces for this event type, then reads
    /// status vars, schema, one skipped byte and the query (to end of buffer).
    fn deserialize(ctx: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
        // Fixed part of the post-header.
        let mut post_header: ParseBuf = buf.parse(13)?;
        let thread_id = post_header.parse_unchecked(())?;
        let execution_time = post_header.parse_unchecked(())?;
        let schema_len: RawInt<u8> = post_header.parse_unchecked(())?;
        let error_code = post_header.parse_unchecked(())?;
        let status_vars_len: RawInt<LeU16> = post_header.parse_unchecked(())?;

        // Post-header bytes beyond the 13 understood here are ignored.
        let extra_post_header = ctx
            .fde
            .get_event_type_header_length(Self::EVENT_TYPE)
            .saturating_sub(13) as usize;
        if !buf.checked_skip(extra_post_header) {
            return Err(unexpected_buf_eof());
        }

        // Struct-expression fields are evaluated in the order written, which keeps
        // the reads below in wire order.
        Ok(Self {
            thread_id,
            execution_time,
            schema_len,
            error_code,
            status_vars_len,
            status_vars: buf.parse(*status_vars_len)?,
            schema: buf.parse(*schema_len as usize)?,
            __skip: buf.parse(())?,
            query: buf.parse(())?,
        })
    }
}
impl MySerialize for QueryEvent<'_> {
    /// Writes the event body in wire order: fixed post-header fields, then status
    /// vars, schema, the skipped byte and the query text.
    fn serialize(&self, buf: &mut Vec<u8>) {
        // Destructuring makes the compiler flag any field added later but not written.
        let Self {
            thread_id,
            execution_time,
            schema_len,
            error_code,
            status_vars_len,
            status_vars,
            schema,
            __skip,
            query,
        } = self;

        // Post-header.
        thread_id.serialize(buf);
        execution_time.serialize(buf);
        schema_len.serialize(buf);
        error_code.serialize(buf);
        status_vars_len.serialize(buf);

        // Variable-length part.
        status_vars.serialize(buf);
        schema.serialize(buf);
        __skip.serialize(buf);
        query.serialize(buf);
    }
}
/// Associates `QueryEvent` with `EventType::QUERY_EVENT`; `deserialize` uses this
/// type to look up the post-header length via `ctx.fde`.
impl<'a> BinlogEvent<'a> for QueryEvent<'a> {
    const EVENT_TYPE: EventType = EventType::QUERY_EVENT;
}
impl<'a> BinlogStruct<'a> for QueryEvent<'a> {
    /// Serialized length of the event body.
    ///
    /// Variable-length parts are capped at what their length fields can express,
    /// and the total is capped so that body plus event header fits in a `u32`.
    fn len(&self, _version: BinlogVersion) -> usize {
        let parts = [
            4, // thread_id
            4, // execution_time
            1, // schema_len
            2, // error_code
            2, // status_vars_len
            min(self.status_vars.0.len(), u16::MAX as usize),
            min(self.schema.0.len(), u8::MAX as usize),
            1, // skipped byte
            self.query.0.len(),
        ];

        let mut total = S(0);
        for part in parts.iter() {
            total += S(*part);
        }

        min(total.0, u32::MAX as usize - BinlogEventHeader::LEN)
    }
}
/// Decoded value of a query-event status variable, as produced by
/// `StatusVar::get_value`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum StatusVarVal<'a> {
    /// 4-byte little-endian flag set.
    Flags2(RawFlags<Flags2, LeU32>),
    /// 8-byte little-endian flag set.
    SqlMode(RawFlags<SqlMode, LeU64>),
    /// Raw value bytes, not decoded further (includes the length byte and one
    /// trailing byte sliced by the iterator).
    Catalog(&'a [u8]),
    /// Two little-endian `u16`s, in this order.
    AutoIncrement {
        increment: u16,
        offset: u16,
    },
    /// Three little-endian `u16`s, in this order.
    Charset {
        charset_client: u16,
        collation_connection: u16,
        collation_server: u16,
    },
    /// `u8`-length-prefixed string (prefix stripped).
    TimeZone(RawBytes<'a, U8Bytes>),
    /// `u8`-length-prefixed string (prefix stripped).
    CatalogNz(RawBytes<'a, U8Bytes>),
    /// Little-endian `u16`.
    LcTimeNames(u16),
    /// Little-endian `u16`.
    CharsetDatabase(u16),
    /// Little-endian `u64`.
    TableMapForUpdate(u64),
    /// Four raw bytes.
    MasterDataWritten([u8; 4]),
    /// Two consecutive `u8`-length-prefixed strings (prefixes stripped).
    Invoker {
        username: RawBytes<'a, U8Bytes>,
        hostname: RawBytes<'a, U8Bytes>,
    },
    /// Database names; on the wire a count byte followed by NUL-terminated names.
    UpdatedDbNames(Vec<RawBytes<'a, NullBytes>>),
    /// Microseconds value. NOTE(review): the iterator slices 3 bytes for this key.
    Microseconds(u32),
    /// Raw value bytes (the iterator slices zero bytes for this key).
    CommitTs(&'a [u8]),
    /// Raw value bytes (the iterator slices zero bytes for this key).
    CommitTs2(&'a [u8]),
    /// One byte; any non-zero value is `true`.
    ExplicitDefaultsForTimestamp(bool),
    /// Little-endian `u64`.
    DdlLoggedWithXid(u64),
    /// Little-endian `u16`.
    DefaultCollationForUtf8mb4(u16),
    /// One byte.
    SqlRequirePrimaryKey(u8),
    /// One byte.
    DefaultTableEncryption(u8),
}
/// A single status variable: its key and the raw bytes of its value.
///
/// `value` is the slice `StatusVarsIterator` cut out for this key; for
/// length-prefixed keys it still includes the length byte. Decode it with
/// `get_value`.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct StatusVar<'a> {
    key: StatusVarKey,
    value: &'a [u8],
}
impl StatusVar<'_> {
    /// Decodes this variable's raw bytes according to its key.
    ///
    /// # Errors
    ///
    /// Returns the raw value bytes unchanged when they are too short or otherwise
    /// malformed for the key. For `UpdatedDbNames`, raw bytes are also returned when
    /// the count exceeds the number of names MySQL ever lists.
    pub fn get_value(&self) -> Result<StatusVarVal, &[u8]> {
        match self.key {
            StatusVarKey::Flags2 => {
                let mut read = self.value;
                read.read_u32::<LittleEndian>()
                    .map(RawFlags::new)
                    .map(StatusVarVal::Flags2)
                    .map_err(|_| self.value)
            }
            StatusVarKey::SqlMode => {
                let mut read = self.value;
                read.read_u64::<LittleEndian>()
                    .map(RawFlags::new)
                    .map(StatusVarVal::SqlMode)
                    .map_err(|_| self.value)
            }
            StatusVarKey::Catalog => Ok(StatusVarVal::Catalog(self.value)),
            StatusVarKey::AutoIncrement => {
                let mut read = self.value;
                let increment = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                let offset = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::AutoIncrement { increment, offset })
            }
            StatusVarKey::Charset => {
                let mut read = self.value;
                let charset_client = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                let collation_connection =
                    read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                let collation_server = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::Charset {
                    charset_client,
                    collation_connection,
                    collation_server,
                })
            }
            StatusVarKey::TimeZone => {
                let mut read = self.value;
                let len = read.read_u8().map_err(|_| self.value)? as usize;
                let text = read.get(..len).ok_or(self.value)?;
                Ok(StatusVarVal::TimeZone(RawBytes::new(text)))
            }
            StatusVarKey::CatalogNz => {
                let mut read = self.value;
                let len = read.read_u8().map_err(|_| self.value)? as usize;
                let text = read.get(..len).ok_or(self.value)?;
                Ok(StatusVarVal::CatalogNz(RawBytes::new(text)))
            }
            StatusVarKey::LcTimeNames => {
                let mut read = self.value;
                let val = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::LcTimeNames(val))
            }
            StatusVarKey::CharsetDatabase => {
                let mut read = self.value;
                let val = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::CharsetDatabase(val))
            }
            StatusVarKey::TableMapForUpdate => {
                let mut read = self.value;
                let val = read.read_u64::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::TableMapForUpdate(val))
            }
            StatusVarKey::MasterDataWritten => {
                let mut read = self.value;
                let mut val = [0u8; 4];
                read.read_exact(&mut val).map_err(|_| self.value)?;
                Ok(StatusVarVal::MasterDataWritten(val))
            }
            StatusVarKey::Invoker => {
                let mut read = self.value;
                let len = read.read_u8().map_err(|_| self.value)? as usize;
                let username = read.get(..len).ok_or(self.value)?;
                read = &read[len..];
                let len = read.read_u8().map_err(|_| self.value)? as usize;
                let hostname = read.get(..len).ok_or(self.value)?;
                Ok(StatusVarVal::Invoker {
                    username: RawBytes::new(username),
                    hostname: RawBytes::new(hostname),
                })
            }
            StatusVarKey::UpdatedDbNames => {
                // MySQL's `MAX_DBS_IN_EVENT_MTS`. For a larger count (MySQL writes
                // `OVER_MAX_DBS_IN_EVENT_MTS` = 254), no names follow the count byte.
                // `StatusVarVal::UpdatedDbNames` cannot express that case, so the raw
                // bytes are returned.
                const MAX_DBS_IN_EVENT_MTS: usize = 16;
                let mut read = self.value;
                let count = read.read_u8().map_err(|_| self.value)? as usize;
                if count > MAX_DBS_IN_EVENT_MTS {
                    return Err(self.value);
                }
                let mut names = Vec::with_capacity(count);
                for _ in 0..count {
                    let index = read.iter().position(|x| *x == 0).ok_or(self.value)?;
                    names.push(RawBytes::new(&read[..index]));
                    // Step past the name *and* its NUL terminator. Stopping at the
                    // terminator made every following name decode as empty.
                    read = &read[index + 1..];
                }
                Ok(StatusVarVal::UpdatedDbNames(names))
            }
            StatusVarKey::Microseconds => {
                // The value is a 3-byte little-endian integer (the iterator slices
                // exactly 3 bytes for this key), so a 4-byte read could never succeed.
                let mut read = self.value;
                let val = read.read_u24::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::Microseconds(val))
            }
            StatusVarKey::CommitTs => Ok(StatusVarVal::CommitTs(self.value)),
            StatusVarKey::CommitTs2 => Ok(StatusVarVal::CommitTs2(self.value)),
            StatusVarKey::ExplicitDefaultsForTimestamp => {
                let mut read = self.value;
                let val = read.read_u8().map_err(|_| self.value)?;
                Ok(StatusVarVal::ExplicitDefaultsForTimestamp(val != 0))
            }
            StatusVarKey::DdlLoggedWithXid => {
                let mut read = self.value;
                let val = read.read_u64::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::DdlLoggedWithXid(val))
            }
            StatusVarKey::DefaultCollationForUtf8mb4 => {
                let mut read = self.value;
                let val = read.read_u16::<LittleEndian>().map_err(|_| self.value)?;
                Ok(StatusVarVal::DefaultCollationForUtf8mb4(val))
            }
            StatusVarKey::SqlRequirePrimaryKey => {
                let mut read = self.value;
                let val = read.read_u8().map_err(|_| self.value)?;
                Ok(StatusVarVal::SqlRequirePrimaryKey(val))
            }
            StatusVarKey::DefaultTableEncryption => {
                let mut read = self.value;
                let val = read.read_u8().map_err(|_| self.value)?;
                Ok(StatusVarVal::DefaultTableEncryption(val))
            }
        }
    }
}
impl fmt::Debug for StatusVar<'_> {
    /// Shows the key together with the decoded value, or with the raw bytes when
    /// decoding fails.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let decoded = self.get_value();
        let mut out = f.debug_struct("StatusVar");
        out.field("key", &self.key);
        out.field("value", &decoded);
        out.finish()
    }
}
/// Raw status-variable block of a `QueryEvent`.
///
/// The bytes are kept undecoded; use `iter` to walk the individual `StatusVar`s.
/// The block's length is not stored here — `QueryEvent` carries it in
/// `status_vars_len`.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct StatusVars<'a>(pub RawBytes<'a, BareU16Bytes>);
impl<'a> StatusVars<'a> {
pub fn iter(&'a self) -> StatusVarsIterator<'a> {
StatusVarsIterator::new(self.0.as_bytes())
}
pub fn get_status_var(&'a self, needle: StatusVarKey) -> Option<StatusVar<'a>> {
self.iter()
.find_map(|var| if var.key == needle { Some(var) } else { None })
}
pub fn into_owned(self) -> StatusVars<'static> {
StatusVars(self.0.into_owned())
}
}
impl<'de> MyDeserialize<'de> for StatusVars<'de> {
const SIZE: Option<usize> = None;
type Ctx = u16;
fn deserialize(len: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
Ok(Self(buf.parse(len as usize)?))
}
}
impl MySerialize for StatusVars<'_> {
    /// Writes the raw status-variable bytes. `QueryEvent` writes the block length
    /// (`status_vars_len`) itself before calling this.
    fn serialize(&self, buf: &mut Vec<u8>) {
        let Self(raw) = self;
        raw.serialize(buf)
    }
}
impl fmt::Debug for StatusVars<'_> {
    /// Formats the block as a list of its decoded status variables.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_list().entries(self.iter()).finish()
    }
}
/// Iterator over the `StatusVar`s in a raw status-variable buffer.
///
/// Iteration ends at the end of the buffer, at a key byte that does not convert
/// into a `StatusVarKey`, or at a value that would run past the end of the buffer.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct StatusVarsIterator<'a> {
    /// Offset in `status_vars` of the next key byte.
    pos: usize,
    /// The raw buffer being walked.
    status_vars: &'a [u8],
}
impl<'a> StatusVarsIterator<'a> {
pub fn new(status_vars: &'a [u8]) -> StatusVarsIterator<'a> {
Self {
pos: 0,
status_vars,
}
}
}
impl fmt::Debug for StatusVarsIterator<'_> {
    /// Lists the variables that remain to be yielded, without advancing `self`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut list = f.debug_list();
        for var in self.clone() {
            list.entry(&var);
        }
        list.finish()
    }
}
impl<'a> Iterator for StatusVarsIterator<'a> {
    type Item = StatusVar<'a>;

    /// Returns the next status variable.
    ///
    /// Returns `None` when the buffer is exhausted, when an unknown key is met (its
    /// value length is unknown, so the rest of the buffer cannot be walked), or when
    /// a value is truncated. Once `None` is returned, every later call returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.decode_next();
        if item.is_none() {
            // A failed decode can leave `pos` inside a value (e.g. on an `Invoker`
            // length byte). Jump to the end so a later call cannot misread value
            // bytes as a key.
            self.pos = self.status_vars.len();
        }
        item
    }
}

impl<'a> StatusVarsIterator<'a> {
    /// Decodes the variable starting at `self.pos` and advances past it.
    ///
    /// On failure `pos` may be left anywhere; `next` normalizes it.
    fn decode_next(&mut self) -> Option<StatusVar<'a>> {
        // MySQL's `MAX_DBS_IN_EVENT_MTS`. For a larger `UpdatedDbNames` count (MySQL
        // writes `OVER_MAX_DBS_IN_EVENT_MTS` = 254), no names follow the count byte.
        const MAX_DBS_IN_EVENT_MTS: usize = 16;

        let key = *self.status_vars.get(self.pos)?;
        let key = StatusVarKey::try_from(key).ok()?;
        self.pos += 1;

        // Takes the next `$len` bytes as the value and advances past them.
        macro_rules! get_fixed {
            ($len:expr) => {{
                let value_len = $len;
                let start = self.pos;
                self.pos += value_len;
                self.status_vars.get(start..self.pos)?
            }};
        }
        // Value = u8 length byte + that many bytes + `$suffix_len` further bytes.
        macro_rules! get_var {
            ($suffix_len:expr) => {{
                let len = *self.status_vars.get(self.pos)? as usize;
                get_fixed!(1 + len + $suffix_len)
            }};
        }

        let value = match key {
            StatusVarKey::Flags2 => get_fixed!(4),
            StatusVarKey::SqlMode => get_fixed!(8),
            StatusVarKey::Catalog => get_var!(1),
            StatusVarKey::AutoIncrement => get_fixed!(4),
            StatusVarKey::Charset => get_fixed!(6),
            StatusVarKey::TimeZone => get_var!(0),
            StatusVarKey::CatalogNz => get_var!(0),
            StatusVarKey::LcTimeNames => get_fixed!(2),
            StatusVarKey::CharsetDatabase => get_fixed!(2),
            StatusVarKey::TableMapForUpdate => get_fixed!(8),
            StatusVarKey::MasterDataWritten => get_fixed!(4),
            StatusVarKey::Invoker => {
                // Two consecutive u8-length-prefixed strings.
                let user_len = *self.status_vars.get(self.pos)? as usize;
                let host_len = *self.status_vars.get(self.pos + 1 + user_len)? as usize;
                get_fixed!(1 + user_len + 1 + host_len)
            }
            StatusVarKey::UpdatedDbNames => {
                let count = *self.status_vars.get(self.pos)? as usize;
                if count > MAX_DBS_IN_EVENT_MTS {
                    // Only the count byte is present. Scanning for names here would
                    // swallow the bytes of the following status variables.
                    get_fixed!(1)
                } else {
                    // Count byte, then `count` NUL-terminated names.
                    let mut total = 1;
                    for _ in 0..count {
                        while *self.status_vars.get(self.pos + total)? != 0x00 {
                            total += 1;
                        }
                        total += 1;
                    }
                    get_fixed!(total)
                }
            }
            StatusVarKey::Microseconds => get_fixed!(3),
            StatusVarKey::CommitTs => get_fixed!(0),
            StatusVarKey::CommitTs2 => get_fixed!(0),
            StatusVarKey::ExplicitDefaultsForTimestamp => get_fixed!(1),
            StatusVarKey::DdlLoggedWithXid => get_fixed!(8),
            StatusVarKey::DefaultCollationForUtf8mb4 => get_fixed!(2),
            StatusVarKey::SqlRequirePrimaryKey => get_fixed!(1),
            StatusVarKey::DefaultTableEncryption => get_fixed!(1),
        };
        Some(StatusVar { key, value })
    }
}