use futures_core::ready;
use mysql_common::{
binlog::{
consts::{BinlogVersion::Version4, EventType},
events::{Event, TableMapEvent, TransactionPayloadEvent},
EventStreamReader,
},
io::ParseBuf,
packets::{ComRegisterSlave, ErrPacket, NetworkStreamTerminator, OkPacketDeserializer},
};

use std::{
future::Future,
io::{Cursor, ErrorKind},
pin::Pin,
task::{Context, Poll},
};

use crate::{connection_like::Connection, queryable::Queryable};
use crate::{error::DriverError, io::ReadPacket, Conn, Error, IoError, Result};

use self::request::BinlogStreamRequest;

pub mod request;

impl super::Conn {
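    /// Turns this connection into a binlog stream.
    ///
    /// Consumes `self` and returns a [`BinlogStream`] that yields events according to
    /// `request`.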
pub async fn get_binlog_stream(
mut self,
request: BinlogStreamRequest<'_>,
) -> Result<BinlogStream> {
self.request_binlog(request).await?;
Ok(BinlogStream::new(self))
}
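
    /// Registers this connection as a replica: sets `@master_binlog_checksum='ALL'` and
    /// sends `COM_REGISTER_SLAVE`, then waits for the server's reply.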
async fn register_as_slave(
&mut self,
com_register_slave: ComRegisterSlave<'_>,
) -> crate::Result<()> {
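        // Advertise that we understand any checksum algorithm so the server streams the
        // binlog even when `binlog_checksum` is enabled.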
self.query_drop("SET @master_binlog_checksum='ALL'").await?;
self.write_command(&com_register_slave).await?;
self.read_packet().await?;
Ok(())
}
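
    /// Registers as a replica and sends the binlog dump command built from `request`.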
async fn request_binlog(&mut self, request: BinlogStreamRequest<'_>) -> crate::Result<()> {
self.register_as_slave(request.register_slave).await?;
self.write_command(&request.binlog_request.as_cmd()).await?;
Ok(())
}
}
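
/// Binlog event stream.
///
/// Yields [`Event`]s read from the server. Events packed into a `Transaction_payload_event`
/// are decompressed and yielded one by one after the payload event itself.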
pub struct BinlogStream {
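    /// Owns the connection and reads raw server packets from it.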
read_packet: ReadPacket<'static, 'static>,
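    /// Incremental binlog event parser (binlog protocol version 4).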
esr: EventStreamReader,
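    /// Decompressed payload of the last `Transaction_payload_event`, if not yet fully drained.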
tpe: Option<Cursor<Vec<u8>>>,
}

impl BinlogStream {
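    /// Wraps a connection that has already requested a binlog dump
    /// (see [`Conn::get_binlog_stream`]).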
pub(super) fn new(conn: Conn) -> Self {
BinlogStream {
read_packet: ReadPacket::new(conn),
esr: EventStreamReader::new(Version4),
tpe: None,
}
}
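
    /// Returns the table map event for the given table id, if the stream has seen one.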
pub fn get_tme(&self, table_id: u64) -> Option<&TableMapEvent<'static>> {
self.esr.get_tme(table_id)
}
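
    /// Closes the underlying connection if this stream owns it.
    ///
    /// A `BrokenPipe` error during shutdown is ignored, since the server may have already
    /// dropped the connection.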
pub async fn close(self) -> Result<()> {
match self.read_packet.0 {
Connection::Conn(conn) => {
if let Err(Error::Io(IoError::Io(ref error))) = conn.close_conn().await {
if error.kind() == ErrorKind::BrokenPipe {
return Ok(());
}
}
}
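            // `BinlogStream::new` always wraps an owned `Conn`, so these arms are not
            // expected to be reached.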
Connection::ConnMut(_) => {}
Connection::Tx(_) => {}
}
Ok(())
}
}

impl futures_core::stream::Stream for BinlogStream {
type Item = Result<Event>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
{
let Self {
ref mut tpe,
ref mut esr,
..
} = *self;
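
            // If a Transaction_payload_event was decompressed earlier, drain its buffered
            // events before reading anything else from the network.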
if let Some(tpe) = tpe.as_mut() {
match esr.read_decompressed(tpe) {
Ok(Some(event)) => return Poll::Ready(Some(Ok(event))),
Ok(None) => self.tpe = None,
Err(err) => return Poll::Ready(Some(Err(err.into()))),
}
}
}
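
        // Nothing buffered, so read the next packet from the server.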
let packet = match ready!(Pin::new(&mut self.read_packet).poll(cx)) {
Ok(packet) => packet,
Err(err) => return Poll::Ready(Some(Err(err.into()))),
};
let first_byte = packet.first().copied();
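
        // A first byte of 0xff signals an error packet.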
if first_byte == Some(255) {
if let Ok(ErrPacket::Error(err)) =
ParseBuf(&packet).parse(self.read_packet.conn_ref().capabilities())
{
return Poll::Ready(Some(Err(From::from(err))));
}
}
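
        // A short packet starting with 0xfe is an EOF packet: the server has no more events
        // to send (e.g. when the dump was requested in non-blocking mode).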
if first_byte == Some(254)
&& packet.len() < 8
&& ParseBuf(&packet)
.parse::<OkPacketDeserializer<NetworkStreamTerminator>>(
self.read_packet.conn_ref().capabilities(),
)
.is_ok()
{
return Poll::Ready(None);
}
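
        // A first byte of 0x00 precedes binlog event data.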
if first_byte == Some(0) {
let event_data = &packet[1..];
match self.esr.read(event_data) {
Ok(Some(event)) => {
                    if event.header().event_type_raw()
                        == EventType::TRANSACTION_PAYLOAD_EVENT as u8
                    {
                        // Buffer the decompressed payload so that subsequent polls drain
                        // the events it contains before reading from the network again.
                        if let Ok(e) = event.read_event::<TransactionPayloadEvent<'_>>() {
                            self.tpe = Some(Cursor::new(e.danger_decompress()));
                        }
                    }
Poll::Ready(Some(Ok(event)))
}
Ok(None) => Poll::Ready(None),
Err(err) => Poll::Ready(Some(Err(err.into()))),
}
} else {
Poll::Ready(Some(Err(DriverError::UnexpectedPacket {
payload: packet.to_vec(),
}
.into())))
}
}
}

#[cfg(test)]
mod tests {
use std::time::Duration;
use futures_util::StreamExt;
use mysql_common::binlog::events::EventData;
use tokio::time::timeout;
use crate::prelude::*;
use crate::{test_misc::get_opts, *};
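
    /// Produces some binlog traffic: creates a table, inserts a hundred rows inside a
    /// transaction and drops the table again.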
async fn gen_dummy_data(conn: &mut Conn) -> super::Result<()> {
"CREATE TABLE IF NOT EXISTS customers (customer_id int not null)"
.ignore(&mut *conn)
.await?;
let mut tx = conn.start_transaction(Default::default()).await?;
for i in 0_u8..100 {
"INSERT INTO customers(customer_id) VALUES (?)"
.with((i,))
.ignore(&mut tx)
.await?;
}
tx.commit().await?;
"DROP TABLE customers".ignore(conn).await?;
Ok(())
}
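
    /// Opens a connection (from `pool` if given), picks a binary log file name and position
    /// via `SHOW BINARY LOGS`, and generates some dummy binlog traffic.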
async fn create_binlog_stream_conn(pool: Option<&Pool>) -> super::Result<(Conn, Vec<u8>, u64)> {
let mut conn = match pool {
None => Conn::new(get_opts()).await.unwrap(),
Some(pool) => pool.get_conn().await.unwrap(),
};
if conn.server_version() >= (8, 0, 31) && conn.server_version() < (9, 0, 0) {
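            // Best effort: enable binlog transaction compression where the server supports
            // it; failures (e.g. missing privileges) are ignored.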
let _ = "SET binlog_transaction_compression=ON"
.ignore(&mut conn)
.await;
}
if let Ok(Some(gtid_mode)) = "SELECT @@GLOBAL.GTID_MODE"
.first::<String, _>(&mut conn)
.await
{
if !gtid_mode.starts_with("ON") {
panic!(
"GTID_MODE is disabled \
(enable using --gtid_mode=ON --enforce_gtid_consistency=ON)"
);
}
}
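
        // Pick a starting binlog file name and position for the dump.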
let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn).await.unwrap().unwrap();
let filename = row.get(0).unwrap();
let position = row.get(1).unwrap();
gen_dummy_data(&mut conn).await.unwrap();
Ok((conn, filename, position))
}
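
    /// Runs the scenario twice: once over a plain connection and once through a pool.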
#[tokio::test]
async fn should_read_binlog() -> super::Result<()> {
read_binlog_streams_and_close_their_connections(None, (12, 13, 14))
.await
.unwrap();
let pool = Pool::new(get_opts());
read_binlog_streams_and_close_their_connections(Some(&pool), (15, 16, 17))
.await
.unwrap();
timeout(Duration::from_secs(10), pool.disconnect())
.await
.unwrap()
.unwrap();
Ok(())
}
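
    /// Reads the binlog three times: by file name and position, additionally with GTID
    /// (skipped on MariaDB), and in non-blocking mode, closing each stream afterwards.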
async fn read_binlog_streams_and_close_their_connections(
pool: Option<&Pool>,
binlog_server_ids: (u32, u32, u32),
) -> super::Result<()> {
let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
let is_mariadb = conn.inner.is_mariadb;
let mut binlog_stream = conn
.get_binlog_stream(
BinlogStreamRequest::new(binlog_server_ids.0)
.with_filename(&filename)
.with_pos(pos),
)
.await
.unwrap();
let mut events_num = 0;
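
        // A blocking binlog dump never ends on its own, so the timeout terminates the loop.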
while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await {
let event = event.unwrap();
events_num += 1;
event.header().event_type().unwrap();
if let EventData::RowsEvent(re) = event.read_data()?.unwrap() {
let tme = binlog_stream.get_tme(re.table_id());
for row in re.rows(tme.unwrap()) {
row.unwrap();
}
}
}
assert!(events_num > 0);
timeout(Duration::from_secs(10), binlog_stream.close())
.await
.unwrap()
.unwrap();
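
        // `with_gtid()` uses MySQL's GTID flavour, so this part is skipped on MariaDB.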
if !is_mariadb {
let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
let mut binlog_stream = conn
.get_binlog_stream(
BinlogStreamRequest::new(binlog_server_ids.1)
.with_gtid()
.with_filename(&filename)
.with_pos(pos),
)
.await
.unwrap();
events_num = 0;
while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await
{
let event = event.unwrap();
events_num += 1;
event.header().event_type().unwrap();
if let EventData::RowsEvent(re) = event.read_data()?.unwrap() {
let tme = binlog_stream.get_tme(re.table_id());
for row in re.rows(tme.unwrap()) {
row.unwrap();
}
}
}
assert!(events_num > 0);
timeout(Duration::from_secs(10), binlog_stream.close())
.await
.unwrap()
.unwrap();
}
let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap();
let mut binlog_stream = conn
.get_binlog_stream(
BinlogStreamRequest::new(binlog_server_ids.2)
.with_filename(&filename)
.with_pos(pos)
.with_non_blocking(),
)
.await
.unwrap();
events_num = 0;
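
        // With `with_non_blocking()` the server sends an EOF packet once it runs out of
        // events, so the stream ends without needing a timeout.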
while let Some(event) = binlog_stream.next().await {
let event = event.unwrap();
events_num += 1;
event.header().event_type().unwrap();
event.read_data().unwrap();
}
assert!(events_num > 0);
timeout(Duration::from_secs(10), binlog_stream.close())
.await
.unwrap()
.unwrap();
Ok(())
}
}