use std::borrow::Cow;
use std::future::IntoFuture;
use std::pin::Pin;
use std::sync::Arc;
use anyhow::Context;
use derivative::Derivative;
use futures::future::BoxFuture;
use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
use mz_ore::result::ResultExt;
use smallvec::{smallvec, SmallVec};
use tiberius::ToSql;
use tokio::net::TcpStream;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
pub mod cdc;
pub mod desc;
pub mod inspect;
pub use tiberius::Config;
/// Handle for issuing queries against SQL Server.
///
/// Every method enqueues a request on an unbounded channel. The paired [`Connection`], which
/// must be driven (awaited) separately, executes the requests one at a time in queue order and
/// sends each result back.
#[derive(Debug)]
pub struct Client {
    /// Sender half of the request queue consumed by [`Connection`].
    tx: UnboundedSender<Request>,
}
// `Client` must not be `Clone`. NOTE(review): presumably this is so that [`Transaction`], which
// holds the `&mut Client`, is guaranteed to be the only thing sending statements between
// `BEGIN TRANSACTION` and `COMMIT`/`ROLLBACK` on the shared connection — confirm.
static_assertions::assert_not_impl_all!(Client: Clone);
impl Client {
pub async fn connect(config: tiberius::Config) -> Result<(Self, Connection), SqlServerError> {
let tcp = TcpStream::connect(config.get_addr()).await?;
tcp.set_nodelay(true)?;
Self::connect_raw(config, tcp).await
}
pub async fn connect_raw(
config: tiberius::Config,
tcp: tokio::net::TcpStream,
) -> Result<(Self, Connection), SqlServerError> {
let client = tiberius::Client::connect(config, tcp.compat_write())
.await
.context("connecting to SQL Server")?;
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
Ok((Client { tx }, Connection { rx, client }))
}
pub async fn execute<'a>(
&mut self,
query: impl Into<Cow<'a, str>>,
params: &[&dyn ToSql],
) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
let (tx, rx) = tokio::sync::oneshot::channel();
let params = params
.iter()
.map(|p| OwnedColumnData::from(p.to_sql()))
.collect();
let kind = RequestKind::Execute {
query: query.into().to_string(),
params,
};
self.tx
.send(Request { tx, kind })
.context("sending request")?;
let response = rx.await.context("channel")?.context("execute")?;
match response {
Response::Execute { rows_affected } => Ok(rows_affected),
other @ Response::Rows(_) | other @ Response::RowStream { .. } => {
Err(SqlServerError::ProgrammingError(format!(
"expected Response::Execute, got {other:?}"
)))
}
}
}
pub async fn query<'a>(
&mut self,
query: impl Into<Cow<'a, str>>,
params: &[&dyn tiberius::ToSql],
) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
let (tx, rx) = tokio::sync::oneshot::channel();
let params = params
.iter()
.map(|p| OwnedColumnData::from(p.to_sql()))
.collect();
let kind = RequestKind::Query {
query: query.into().to_string(),
params,
};
self.tx
.send(Request { tx, kind })
.context("sending request")?;
let response = rx.await.context("channel")?.context("query")?;
match response {
Response::Rows(rows) => Ok(rows),
other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
),
}
}
pub fn query_streaming<'a>(
&mut self,
query: impl Into<Cow<'a, str>>,
params: &[&dyn tiberius::ToSql],
) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + '_ {
let (tx, rx) = tokio::sync::oneshot::channel();
let params = params
.iter()
.map(|p| OwnedColumnData::from(p.to_sql()))
.collect();
let kind = RequestKind::QueryStreamed {
query: query.into().to_string(),
params,
};
let request_future = async move {
self.tx
.send(Request { tx, kind })
.context("sending request")?;
let response = rx.await.context("channel")??;
match response {
Response::RowStream { stream } => {
Ok(tokio_stream::wrappers::ReceiverStream::new(stream))
}
other @ Response::Execute { .. } | other @ Response::Rows(_) => {
Err(SqlServerError::ProgrammingError(format!(
"expected Response::Rows, got {other:?}"
)))
}
}
};
futures::stream::once(request_future).try_flatten()
}
pub async fn simple_query<'a>(
&mut self,
query: impl Into<Cow<'a, str>>,
) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
let (tx, rx) = tokio::sync::oneshot::channel();
let kind = RequestKind::SimpleQuery {
query: query.into().to_string(),
};
self.tx
.send(Request { tx, kind })
.context("sending request")?;
let response = rx.await.context("channel")?.context("simple_query")?;
match response {
Response::Rows(rows) => Ok(rows),
other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
),
}
}
pub async fn transaction(&mut self) -> Result<Transaction<'_>, SqlServerError> {
Transaction::new(self).await
}
pub async fn set_transaction_isolation(
&mut self,
level: TransactionIsolationLevel,
) -> Result<(), SqlServerError> {
let query = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_str());
self.simple_query(query).await?;
Ok(())
}
pub async fn get_transaction_isolation(
&mut self,
) -> Result<TransactionIsolationLevel, SqlServerError> {
const QUERY: &str = "SELECT transaction_isolation_level FROM sys.dm_exec_sessions where session_id = @@SPID;";
let rows = self.simple_query(QUERY).await?;
match &rows[..] {
[row] => {
let val: i16 = row
.try_get(0)
.context("getting 0th column")?
.ok_or_else(|| anyhow::anyhow!("no 0th column?"))?;
let level = TransactionIsolationLevel::try_from_sql_server(val)?;
Ok(level)
}
other => Err(SqlServerError::InvariantViolated(format!(
"expected one row, got {other:?}"
))),
}
}
pub fn cdc<I>(&mut self, capture_instances: I) -> crate::cdc::CdcStream<'_>
where
I: IntoIterator,
I::Item: Into<Arc<str>>,
{
let instances = capture_instances
.into_iter()
.map(|i| (i.into(), None))
.collect();
crate::cdc::CdcStream::new(self, instances)
}
}
/// A pinned, boxed, `Send` stream of rows (or errors) that may borrow data for `'a`.
pub type RowStream<'a> =
    Pin<Box<dyn Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + 'a>>;
/// An open SQL Server transaction, created by [`Client::transaction`].
///
/// Holds the `&mut Client` for its whole lifetime. Statements run through it therefore cannot
/// interleave with other uses of the same `Client`.
#[derive(Debug)]
pub struct Transaction<'a> {
    client: &'a mut Client,
    /// Set once `commit` or `rollback` has started. The `Drop` impl checks it to decide whether
    /// a rollback is still needed.
    closed: bool,
}
impl<'a> Transaction<'a> {
    /// Sends `BEGIN TRANSACTION` through `client` and returns a guard for the open transaction.
    ///
    /// Fails with [`SqlServerError::InvariantViolated`] if the statement unexpectedly returns
    /// rows.
    async fn new(client: &'a mut Client) -> Result<Self, SqlServerError> {
        let results = client
            .simple_query("BEGIN TRANSACTION")
            .await
            .context("begin")?;
        if results.is_empty() {
            return Ok(Transaction {
                client,
                closed: false,
            });
        }
        Err(SqlServerError::InvariantViolated(format!(
            "expected empty result from BEGIN TRANSACTION. Got: {results:?}"
        )))
    }

    /// Runs [`Client::execute`] inside this transaction.
    pub async fn execute<'q>(
        &mut self,
        query: impl Into<Cow<'q, str>>,
        params: &[&dyn ToSql],
    ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
        let client = &mut *self.client;
        client.execute(query, params).await
    }

    /// Runs [`Client::query`] inside this transaction.
    pub async fn query<'q>(
        &mut self,
        query: impl Into<Cow<'q, str>>,
        params: &[&dyn tiberius::ToSql],
    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
        let client = &mut *self.client;
        client.query(query, params).await
    }

    /// Runs [`Client::query_streaming`] inside this transaction.
    pub fn query_streaming<'q>(
        &mut self,
        query: impl Into<Cow<'q, str>>,
        params: &[&dyn tiberius::ToSql],
    ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + '_ {
        self.client.query_streaming(query, params)
    }

    /// Runs [`Client::simple_query`] inside this transaction.
    pub async fn simple_query<'q>(
        &mut self,
        query: impl Into<Cow<'q, str>>,
    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
        let client = &mut *self.client;
        client.simple_query(query).await
    }

    /// Ends the transaction with `ROLLBACK TRANSACTION`.
    pub async fn rollback(self) -> Result<(), SqlServerError> {
        self.finish("ROLLBACK TRANSACTION").await
    }

    /// Ends the transaction with `COMMIT TRANSACTION`.
    pub async fn commit(self) -> Result<(), SqlServerError> {
        self.finish("COMMIT TRANSACTION").await
    }

    /// Marks the guard as closed, so `Drop` does not attempt a rollback, then runs `statement`.
    async fn finish(mut self, statement: &'static str) -> Result<(), SqlServerError> {
        self.closed = true;
        let _rows = self.client.simple_query(statement).await?;
        Ok(())
    }
}
impl Drop for Transaction<'_> {
fn drop(&mut self) {
if !self.closed {
let _fut = self.client.simple_query("ROLLBACK TRANSACTION");
}
}
}
/// Transaction isolation levels that can be set with `SET TRANSACTION ISOLATION LEVEL`.
#[derive(Debug, PartialEq, Eq)]
pub enum TransactionIsolationLevel {
    /// `READ UNCOMMITTED`
    ReadUncommitted,
    /// `READ COMMITTED`
    ReadCommitted,
    /// `REPEATABLE READ`
    RepeatableRead,
    /// `SNAPSHOT`
    Snapshot,
    /// `SERIALIZABLE`
    Serializable,
}
impl TransactionIsolationLevel {
    /// Returns the SQL keyword(s) naming this level, as used after
    /// `SET TRANSACTION ISOLATION LEVEL`.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::ReadUncommitted => "READ UNCOMMITTED",
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Snapshot => "SNAPSHOT",
            Self::Serializable => "SERIALIZABLE",
        }
    }

    /// Decodes the numeric isolation level read from `sys.dm_exec_sessions`.
    ///
    /// Note that `4` is `SERIALIZABLE` and `5` is `SNAPSHOT`. Any value outside `1..=5` is an
    /// error.
    fn try_from_sql_server(val: i16) -> Result<TransactionIsolationLevel, anyhow::Error> {
        Ok(match val {
            1 => Self::ReadUncommitted,
            2 => Self::ReadCommitted,
            3 => Self::RepeatableRead,
            4 => Self::Serializable,
            5 => Self::Snapshot,
            x => anyhow::bail!("unknown level {x}"),
        })
    }
}
/// Reply sent from the [`Connection`] back to the [`Client`] for a single [`Request`].
#[derive(Derivative)]
#[derivative(Debug)]
enum Response {
    /// Answer to [`RequestKind::Execute`]: the affected-row counts reported by tiberius.
    Execute {
        rows_affected: SmallVec<[u64; 1]>,
    },
    /// Fully buffered rows; the answer to [`RequestKind::Query`] and [`RequestKind::SimpleQuery`].
    Rows(SmallVec<[tiberius::Row; 1]>),
    /// Answer to [`RequestKind::QueryStreamed`]. Rows are forwarded through this bounded channel
    /// as the connection reads them.
    RowStream {
        // The receiver is excluded from `Debug` output.
        #[derivative(Debug = "ignore")]
        stream: tokio::sync::mpsc::Receiver<Result<tiberius::Row, SqlServerError>>,
    },
}
/// A unit of work queued by the [`Client`] for the [`Connection`] task.
#[derive(Debug)]
struct Request {
    /// Where the [`Connection`] sends the outcome. If the receiver is gone, the outcome is
    /// discarded.
    tx: oneshot::Sender<Result<Response, SqlServerError>>,
    /// The operation to perform.
    kind: RequestKind,
}
/// The operation a [`Request`] asks the [`Connection`] to perform.
///
/// Parameters are stored as [`OwnedColumnData`] so they can be moved to the connection task.
/// They are excluded from `Debug` output; the connection logs each request's `Debug`
/// representation, so parameter values stay out of the logs.
#[derive(Derivative)]
#[derivative(Debug)]
enum RequestKind {
    /// Run `query` with `params` via tiberius `execute`; answered with `Response::Execute`.
    Execute {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run `query` with `params` and buffer its single result set into `Response::Rows`.
    Query {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run `query` with `params` and forward its rows through `Response::RowStream`.
    QueryStreamed {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run `query` (no parameters) via tiberius `simple_query`; answered with `Response::Rows`.
    /// The rows are empty if the statement produced no result set.
    SimpleQuery {
        query: String,
    },
}
/// The half of a SQL Server connection that owns the underlying [`tiberius::Client`].
///
/// Created together with a [`Client`] by [`Client::connect`] / [`Client::connect_raw`]. It does
/// nothing until awaited (via its [`IntoFuture`] impl). The resulting future executes the
/// `Client`'s requests one at a time and completes once the `Client` has been dropped.
pub struct Connection {
    /// Receiving end of the request queue fed by [`Client`].
    rx: UnboundedReceiver<Request>,
    /// The tiberius client all requests are executed on.
    client: tiberius::Client<Compat<TcpStream>>,
}
impl Connection {
async fn run(mut self) {
while let Some(Request { tx, kind }) = self.rx.recv().await {
tracing::debug!(?kind, "processing SQL Server query");
let result = Connection::handle_request(&mut self.client, kind).await;
let (response, maybe_extra_work) = match result {
Ok((response, work)) => (Ok(response), work),
Err(err) => (Err(err), None),
};
let _ = tx.send(response);
if let Some(extra_work) = maybe_extra_work {
extra_work.await;
}
}
tracing::debug!("channel closed, SQL Server InnerClient shutting down");
}
async fn handle_request<'c>(
client: &'c mut tiberius::Client<Compat<TcpStream>>,
kind: RequestKind,
) -> Result<(Response, Option<BoxFuture<'c, ()>>), SqlServerError> {
match kind {
RequestKind::Execute { query, params } => {
#[allow(clippy::as_conversions)]
let params: SmallVec<[&dyn ToSql; 4]> =
params.iter().map(|x| x as &dyn ToSql).collect();
let result = client.execute(query, ¶ms[..]).await?;
match result.rows_affected() {
[] => Err(SqlServerError::InvariantViolated(
"got empty response".into(),
)),
rows_affected => {
let response = Response::Execute {
rows_affected: rows_affected.into(),
};
Ok((response, None))
}
}
}
RequestKind::Query { query, params } => {
#[allow(clippy::as_conversions)]
let params: SmallVec<[&dyn ToSql; 4]> =
params.iter().map(|x| x as &dyn ToSql).collect();
let result = client.query(query, params.as_slice()).await?;
let mut results = result.into_results().await.context("into results")?;
if results.is_empty() {
Err(SqlServerError::InvariantViolated(
"got empty response".into(),
))
} else if results.len() == 1 {
let rows = results.pop().expect("checked len").into();
Ok((Response::Rows(rows), None))
} else {
Err(SqlServerError::ProgrammingError(format!(
"Query only supports 1 statement, got {}",
results.len()
)))
}
}
RequestKind::QueryStreamed { query, params } => {
#[allow(clippy::as_conversions)]
let params: SmallVec<[&dyn ToSql; 4]> =
params.iter().map(|x| x as &dyn ToSql).collect();
let result = client.query(query, params.as_slice()).await?;
let (tx, rx) = tokio::sync::mpsc::channel(256);
let work = Box::pin(async move {
let mut stream = result.into_row_stream();
while let Some(result) = stream.next().await {
if let Err(err) = tx.send(result.err_into()).await {
tracing::warn!(?err, "SQL Server row stream receiver went away");
}
}
tracing::info!("SQL Server row stream complete");
});
Ok((Response::RowStream { stream: rx }, Some(work)))
}
RequestKind::SimpleQuery { query } => {
let result = client.simple_query(query).await?;
let mut results = result.into_results().await.context("into results")?;
if results.is_empty() {
Ok((Response::Rows(smallvec![]), None))
} else if results.len() == 1 {
let rows = results.pop().expect("checked len").into();
Ok((Response::Rows(rows), None))
} else {
Err(SqlServerError::ProgrammingError(format!(
"Simple query only supports 1 statement, got {}",
results.len()
)))
}
}
}
}
}
impl IntoFuture for Connection {
type Output = ();
type IntoFuture = BoxFuture<'static, Self::Output>;
fn into_future(self) -> Self::IntoFuture {
self.run().boxed()
}
}
/// An owned counterpart of [`tiberius::ColumnData`], with one variant per `ColumnData` variant.
///
/// The [`Client`] converts caller-supplied query parameters into this type. A request can then
/// be moved to the [`Connection`] task without borrowing caller data. The [`tiberius::ToSql`]
/// impl lends the values back out when the query runs.
#[derive(Debug)]
enum OwnedColumnData {
    U8(Option<u8>),
    I16(Option<i16>),
    I32(Option<i32>),
    I64(Option<i64>),
    F32(Option<f32>),
    F64(Option<f64>),
    Bit(Option<bool>),
    String(Option<String>),
    Guid(Option<uuid::Uuid>),
    Binary(Option<Vec<u8>>),
    Numeric(Option<tiberius::numeric::Numeric>),
    Xml(Option<tiberius::xml::XmlData>),
    DateTime(Option<tiberius::time::DateTime>),
    SmallDateTime(Option<tiberius::time::SmallDateTime>),
    Time(Option<tiberius::time::Time>),
    Date(Option<tiberius::time::Date>),
    DateTime2(Option<tiberius::time::DateTime2>),
    DateTimeOffset(Option<tiberius::time::DateTimeOffset>),
}
impl<'a> From<tiberius::ColumnData<'a>> for OwnedColumnData {
fn from(value: tiberius::ColumnData<'a>) -> Self {
match value {
tiberius::ColumnData::U8(inner) => OwnedColumnData::U8(inner),
tiberius::ColumnData::I16(inner) => OwnedColumnData::I16(inner),
tiberius::ColumnData::I32(inner) => OwnedColumnData::I32(inner),
tiberius::ColumnData::I64(inner) => OwnedColumnData::I64(inner),
tiberius::ColumnData::F32(inner) => OwnedColumnData::F32(inner),
tiberius::ColumnData::F64(inner) => OwnedColumnData::F64(inner),
tiberius::ColumnData::Bit(inner) => OwnedColumnData::Bit(inner),
tiberius::ColumnData::String(inner) => {
OwnedColumnData::String(inner.map(|s| s.to_string()))
}
tiberius::ColumnData::Guid(inner) => OwnedColumnData::Guid(inner),
tiberius::ColumnData::Binary(inner) => {
OwnedColumnData::Binary(inner.map(|b| b.to_vec()))
}
tiberius::ColumnData::Numeric(inner) => OwnedColumnData::Numeric(inner),
tiberius::ColumnData::Xml(inner) => OwnedColumnData::Xml(inner.map(|x| x.into_owned())),
tiberius::ColumnData::DateTime(inner) => OwnedColumnData::DateTime(inner),
tiberius::ColumnData::SmallDateTime(inner) => OwnedColumnData::SmallDateTime(inner),
tiberius::ColumnData::Time(inner) => OwnedColumnData::Time(inner),
tiberius::ColumnData::Date(inner) => OwnedColumnData::Date(inner),
tiberius::ColumnData::DateTime2(inner) => OwnedColumnData::DateTime2(inner),
tiberius::ColumnData::DateTimeOffset(inner) => OwnedColumnData::DateTimeOffset(inner),
}
}
}
impl tiberius::ToSql for OwnedColumnData {
fn to_sql(&self) -> tiberius::ColumnData<'_> {
match self {
OwnedColumnData::U8(inner) => tiberius::ColumnData::U8(*inner),
OwnedColumnData::I16(inner) => tiberius::ColumnData::I16(*inner),
OwnedColumnData::I32(inner) => tiberius::ColumnData::I32(*inner),
OwnedColumnData::I64(inner) => tiberius::ColumnData::I64(*inner),
OwnedColumnData::F32(inner) => tiberius::ColumnData::F32(*inner),
OwnedColumnData::F64(inner) => tiberius::ColumnData::F64(*inner),
OwnedColumnData::Bit(inner) => tiberius::ColumnData::Bit(*inner),
OwnedColumnData::String(inner) => {
tiberius::ColumnData::String(inner.as_deref().map(Cow::Borrowed))
}
OwnedColumnData::Guid(inner) => tiberius::ColumnData::Guid(*inner),
OwnedColumnData::Binary(inner) => {
tiberius::ColumnData::Binary(inner.as_deref().map(Cow::Borrowed))
}
OwnedColumnData::Numeric(inner) => tiberius::ColumnData::Numeric(*inner),
OwnedColumnData::Xml(inner) => {
tiberius::ColumnData::Xml(inner.as_ref().map(Cow::Borrowed))
}
OwnedColumnData::DateTime(inner) => tiberius::ColumnData::DateTime(*inner),
OwnedColumnData::SmallDateTime(inner) => tiberius::ColumnData::SmallDateTime(*inner),
OwnedColumnData::Time(inner) => tiberius::ColumnData::Time(*inner),
OwnedColumnData::Date(inner) => tiberius::ColumnData::Date(*inner),
OwnedColumnData::DateTime2(inner) => tiberius::ColumnData::DateTime2(*inner),
OwnedColumnData::DateTimeOffset(inner) => tiberius::ColumnData::DateTimeOffset(*inner),
}
}
}
impl<'a, T: tiberius::ToSql> From<&'a T> for OwnedColumnData {
    /// Captures an owned copy of the [`tiberius::ColumnData`] that `value` converts to.
    fn from(value: &'a T) -> Self {
        let borrowed = T::to_sql(value);
        borrowed.into()
    }
}
/// Errors produced by the SQL Server client.
#[derive(Debug, thiserror::Error)]
pub enum SqlServerError {
    /// An error reported by the tiberius driver.
    #[error(transparent)]
    SqlServer(#[from] tiberius::error::Error),
    /// An error from the change-data-capture module (`crate::cdc`).
    #[error(transparent)]
    CdcError(#[from] crate::cdc::CdcError),
    /// A column's SQL Server type is not supported.
    #[error("'{column_type}' from column '{column_name}' is not supported: {reason}")]
    UnsupportedDataType {
        column_name: String,
        column_type: String,
        reason: String,
    },
    /// An I/O error, e.g. while opening the TCP connection in `Client::connect`.
    #[error("sql server client encountered I/O error: {0}")]
    IO(#[from] tokio::io::Error),
    /// A column contained data that could not be interpreted.
    #[error("found invalid data in the column '{column_name}': {error}")]
    InvalidData { column_name: String, error: String },
    /// The server's response did not match what the client expects (e.g. an `execute` that
    /// reports no row counts).
    #[error("invariant was violated: {0}")]
    InvariantViolated(String),
    /// Catch-all for `anyhow` errors, including errors wrapped with `.context(...)`.
    #[error(transparent)]
    Generic(#[from] anyhow::Error),
    /// The API was misused or a request got a response of the wrong kind (e.g. a multi-statement
    /// query where only one statement is supported).
    #[error("programming error! {0}")]
    ProgrammingError(String),
}