use anyhow::Context;
use dec::OrderedDecimal;
use mz_ore::cast::CastFrom;
use mz_proto::{IntoRustIfSome, ProtoType, RustType};
use mz_repr::adt::numeric::{Dec, Numeric, NumericMaxScale};
use mz_repr::{ColumnType, Datum, RelationDesc, Row, ScalarType};
use proptest_derive::Arbitrary;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::SqlServerError;
include!(concat!(env!("OUT_DIR"), "/mz_sql_server_util.rs"));
/// Validated description of a SQL Server table, as produced by
/// [`SqlServerTableDesc::try_new`] from a [`SqlServerTableRaw`].
///
/// Converts to and from `ProtoSqlServerTableDesc` through its `RustType` implementation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerTableDesc {
/// Name of the schema that contains the table.
pub schema_name: Arc<str>,
/// Name of the table.
pub name: Arc<str>,
/// Parsed descriptions of the table's columns.
pub columns: Arc<[SqlServerColumnDesc]>,
/// Whether CDC is enabled for the table; carried over unchanged from the raw description.
pub is_cdc_enabled: bool,
}
impl SqlServerTableDesc {
pub fn try_new(raw: SqlServerTableRaw) -> Result<Self, SqlServerError> {
let columns: Arc<[_]> = raw
.columns
.into_iter()
.map(SqlServerColumnDesc::try_new)
.collect::<Result<_, _>>()?;
Ok(SqlServerTableDesc {
schema_name: raw.schema_name,
name: raw.name,
columns,
is_cdc_enabled: raw.is_cdc_enabled,
})
}
pub fn decoder(&self, desc: &RelationDesc) -> Result<SqlServerRowDecoder, SqlServerError> {
let decoder = SqlServerRowDecoder::try_new(self, desc)?;
Ok(decoder)
}
}
impl RustType<ProtoSqlServerTableDesc> for SqlServerTableDesc {
fn into_proto(&self) -> ProtoSqlServerTableDesc {
ProtoSqlServerTableDesc {
name: self.name.to_string(),
schema_name: self.schema_name.to_string(),
columns: self.columns.iter().map(|c| c.into_proto()).collect(),
is_cdc_enabled: self.is_cdc_enabled,
}
}
fn from_proto(proto: ProtoSqlServerTableDesc) -> Result<Self, mz_proto::TryFromProtoError> {
let columns = proto
.columns
.into_iter()
.map(|c| c.into_rust())
.collect::<Result<_, _>>()?;
Ok(SqlServerTableDesc {
schema_name: proto.schema_name.into(),
name: proto.name.into(),
columns,
is_cdc_enabled: proto.is_cdc_enabled,
})
}
}
/// Unvalidated description of a SQL Server table.
///
/// [`SqlServerTableDesc::try_new`] consumes a value of this type and parses each of its
/// columns.
#[derive(Debug, Clone)]
pub struct SqlServerTableRaw {
/// Name of the schema that contains the table.
pub schema_name: Arc<str>,
/// Name of the table.
pub name: Arc<str>,
/// Unparsed descriptions of the table's columns.
pub columns: Arc<[SqlServerColumnRaw]>,
/// Whether CDC is enabled for the table.
pub is_cdc_enabled: bool,
}
/// Validated description of a single SQL Server column, as produced by
/// [`SqlServerColumnDesc::try_new`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerColumnDesc {
/// Name of the column.
pub name: Arc<str>,
/// The [`ColumnType`] (scalar type plus nullability) values of this column are decoded into.
pub column_type: ColumnType,
/// How values of this column are read out of a `tiberius::Row`.
pub decode_type: SqlServerColumnDecodeType,
}
impl SqlServerColumnDesc {
    /// Parses a raw column into a validated description.
    ///
    /// The column keeps its name, and its nullability is taken from `raw.is_nullable`.
    ///
    /// # Errors
    ///
    /// Returns whatever `parse_data_type` reports when the column's SQL Server type is
    /// not supported.
    pub fn try_new(raw: &SqlServerColumnRaw) -> Result<Self, SqlServerError> {
        parse_data_type(raw).map(|(scalar_type, decode_type)| SqlServerColumnDesc {
            name: Arc::clone(&raw.name),
            column_type: scalar_type.nullable(raw.is_nullable),
            decode_type,
        })
    }
}
impl RustType<ProtoSqlServerColumnDesc> for SqlServerColumnDesc {
    /// Converts this column description into its protobuf representation.
    ///
    /// Both optional proto fields are always populated.
    fn into_proto(&self) -> ProtoSqlServerColumnDesc {
        let column_type = self.column_type.into_proto();
        let decode_type = self.decode_type.into_proto();

        ProtoSqlServerColumnDesc {
            name: self.name.to_string(),
            column_type: Some(column_type),
            decode_type: Some(decode_type),
        }
    }

    /// Rebuilds a column description from its protobuf representation.
    ///
    /// Fails if `column_type` or `decode_type` is missing or cannot be converted;
    /// `column_type` is checked first.
    fn from_proto(proto: ProtoSqlServerColumnDesc) -> Result<Self, mz_proto::TryFromProtoError> {
        let column_type = proto
            .column_type
            .into_rust_if_some("ProtoSqlServerColumnDesc::column_type")?;
        let decode_type = proto
            .decode_type
            .into_rust_if_some("ProtoSqlServerColumnDesc::decode_type")?;

        Ok(SqlServerColumnDesc {
            name: proto.name.into(),
            column_type,
            decode_type,
        })
    }
}
/// Maps the SQL Server data type of `raw` to the [`ScalarType`] its values are decoded
/// into, together with the [`SqlServerColumnDecodeType`] used to read them.
///
/// The type name is matched case-insensitively.
///
/// # Errors
///
/// Returns [`SqlServerError::UnsupportedDataType`] when:
/// * a `decimal`/`numeric` column has a precision greater than 39, or a scale that
///   [`NumericMaxScale`] rejects,
/// * a character, `xml`, or binary column has a `max_length` of `-1`,
/// * the column is `text`, `ntext`, or `image`,
/// * the type name is not recognized (reported with the reason "unimplemented").
fn parse_data_type(
    raw: &SqlServerColumnRaw,
) -> Result<(ScalarType, SqlServerColumnDecodeType), SqlServerError> {
    const UNLIMITED_SIZE: &str = "columns with unlimited size do not support CDC";

    // Every rejection of a recognized type reports the column's name and its type as
    // given upstream; only the reason differs.
    let unsupported = |reason: String| SqlServerError::UnsupportedDataType {
        column_name: raw.name.to_string(),
        column_type: raw.data_type.to_string(),
        reason,
    };

    let data_type = raw.data_type.to_lowercase();
    let parsed = match data_type.as_str() {
        // `tinyint` is read as a `u8` and widened, hence the `Int16` scalar type.
        "tinyint" => (ScalarType::Int16, SqlServerColumnDecodeType::U8),
        "smallint" => (ScalarType::Int16, SqlServerColumnDecodeType::I16),
        "int" => (ScalarType::Int32, SqlServerColumnDecodeType::I32),
        "bigint" => (ScalarType::Int64, SqlServerColumnDecodeType::I64),
        "bit" => (ScalarType::Bool, SqlServerColumnDecodeType::Bool),
        "decimal" | "numeric" => {
            // These values are not expected from SQL Server; log them, but only reject
            // precisions beyond the maximum of 39 enforced below.
            if raw.precision > 38 || raw.scale > raw.precision {
                tracing::warn!(
                    "unexpected value from SQL Server, precision of {} and scale of {}",
                    raw.precision,
                    raw.scale,
                );
            }
            if raw.precision > 39 {
                return Err(unsupported(format!(
                    "precision of {} is greater than our maximum of 39",
                    raw.precision
                )));
            }

            let max_scale = NumericMaxScale::try_from(usize::cast_from(raw.scale))
                .map_err(|_| unsupported(format!("scale of {} is too large", raw.scale)))?;
            (
                ScalarType::Numeric {
                    max_scale: Some(max_scale),
                },
                SqlServerColumnDecodeType::Numeric,
            )
        }
        "real" => (ScalarType::Float32, SqlServerColumnDecodeType::F32),
        // SQL Server names its 8-byte floating point type `float`; `double` is kept so
        // existing callers that pass it continue to work.
        "float" | "double" => (ScalarType::Float64, SqlServerColumnDecodeType::F64),
        "char" | "nchar" | "varchar" | "nvarchar" | "sysname" => {
            // A `max_length` of -1 marks a column of unlimited size.
            if raw.max_length == -1 {
                return Err(unsupported(UNLIMITED_SIZE.to_string()));
            }
            (ScalarType::String, SqlServerColumnDecodeType::String)
        }
        "text" | "ntext" | "image" => {
            // These types are always rejected; the soft assertion only tracks an
            // unexpected `max_length`.
            mz_ore::soft_assert_eq_no_log!(raw.max_length, 16);
            return Err(unsupported(UNLIMITED_SIZE.to_string()));
        }
        "xml" => {
            if raw.max_length == -1 {
                return Err(unsupported(UNLIMITED_SIZE.to_string()));
            }
            (ScalarType::String, SqlServerColumnDecodeType::Xml)
        }
        "binary" | "varbinary" => {
            if raw.max_length == -1 {
                return Err(unsupported(UNLIMITED_SIZE.to_string()));
            }
            (ScalarType::Bytes, SqlServerColumnDecodeType::Bytes)
        }
        // NOTE(review): `SqlServerColumnDecodeType::decode` has no arm that pairs
        // `ScalarType::Jsonb` with the `String` decode type, so decoding a row with a
        // `json` column currently ends in a `ProgrammingError` — confirm this is intended.
        "json" => (ScalarType::Jsonb, SqlServerColumnDecodeType::String),
        "date" => (ScalarType::Date, SqlServerColumnDecodeType::NaiveDate),
        "time" => (ScalarType::Time, SqlServerColumnDecodeType::NaiveTime),
        "smalldatetime" | "datetime" | "datetime2" => (
            ScalarType::Timestamp { precision: None },
            SqlServerColumnDecodeType::NaiveDateTime,
        ),
        "datetimeoffset" => (
            ScalarType::TimestampTz { precision: None },
            SqlServerColumnDecodeType::DateTime,
        ),
        "uniqueidentifier" => (ScalarType::Uuid, SqlServerColumnDecodeType::Uuid),
        // Unknown types are reported using the lower-cased type name.
        other => {
            return Err(SqlServerError::UnsupportedDataType {
                column_type: other.to_string(),
                column_name: raw.name.to_string(),
                reason: "unimplemented".to_string(),
            })
        }
    };
    Ok(parsed)
}
/// Unvalidated description of a single SQL Server column.
///
/// [`SqlServerColumnDesc::try_new`] parses a value of this type into a validated
/// description.
#[derive(Clone, Debug)]
pub struct SqlServerColumnRaw {
/// Name of the column.
pub name: Arc<str>,
/// Name of the column's SQL Server data type; it is matched case-insensitively when parsed.
pub data_type: Arc<str>,
/// Whether the column may contain NULL values.
pub is_nullable: bool,
/// Maximum length of the column; `-1` is treated as "unlimited size" when parsing
/// character, `xml`, and binary columns.
pub max_length: i16,
/// Precision of the column; only consulted for `decimal`/`numeric` columns.
pub precision: u8,
/// Scale of the column; only consulted for `decimal`/`numeric` columns.
pub scale: u8,
}
/// The representation a column's value is read as from a `tiberius::Row` before it is
/// turned into a [`Datum`]; see [`SqlServerColumnDecodeType::decode`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub enum SqlServerColumnDecodeType {
/// Read as a `bool`.
Bool,
/// Read as a `u8` and widened into an `Int16` datum.
U8,
/// Read as an `i16`.
I16,
/// Read as an `i32`.
I32,
/// Read as an `i64`.
I64,
/// Read as an `f32`.
F32,
/// Read as an `f64`.
F64,
/// Read as a string.
String,
/// Read as raw bytes.
Bytes,
/// Read as a UUID.
Uuid,
/// Read as a `tiberius::numeric::Numeric` and re-parsed from its string form.
Numeric,
/// Read as `tiberius::xml::XmlData` and exposed as a string.
Xml,
/// Read as a `chrono::NaiveDate`.
NaiveDate,
/// Read as a time of day.
NaiveTime,
/// Read as a `chrono::DateTime<chrono::Utc>`.
DateTime,
/// Read as a `chrono::NaiveDateTime`.
NaiveDateTime,
}
impl SqlServerColumnDecodeType {
pub fn decode<'a>(
self,
data: &'a tiberius::Row,
name: &'a str,
column: &'a ColumnType,
) -> Result<Datum<'a>, SqlServerError> {
let maybe_datum = match (&column.scalar_type, self) {
(ScalarType::Bool, SqlServerColumnDecodeType::Bool) => data
.try_get(name)
.context("bool")?
.map(|val: bool| if val { Datum::True } else { Datum::False }),
(ScalarType::Int16, SqlServerColumnDecodeType::U8) => data
.try_get(name)
.context("u8")?
.map(|val: u8| Datum::Int16(i16::cast_from(val))),
(ScalarType::Int16, SqlServerColumnDecodeType::I16) => {
data.try_get(name).context("i16")?.map(Datum::Int16)
}
(ScalarType::Int32, SqlServerColumnDecodeType::I32) => {
data.try_get(name).context("i32")?.map(Datum::Int32)
}
(ScalarType::Int64, SqlServerColumnDecodeType::I64) => {
data.try_get(name).context("i64")?.map(Datum::Int64)
}
(ScalarType::Float32, SqlServerColumnDecodeType::F32) => data
.try_get(name)
.context("f32")?
.map(|val: f32| Datum::Float32(ordered_float::OrderedFloat(val))),
(ScalarType::Float64, SqlServerColumnDecodeType::F64) => data
.try_get(name)
.context("f64")?
.map(|val: f64| Datum::Float64(ordered_float::OrderedFloat(val))),
(ScalarType::String, SqlServerColumnDecodeType::String) => {
data.try_get(name).context("string")?.map(Datum::String)
}
(ScalarType::Bytes, SqlServerColumnDecodeType::Bytes) => {
data.try_get(name).context("bytes")?.map(Datum::Bytes)
}
(ScalarType::Uuid, SqlServerColumnDecodeType::Uuid) => {
data.try_get(name).context("uuid")?.map(Datum::Uuid)
}
(ScalarType::Numeric { .. }, SqlServerColumnDecodeType::Numeric) => data
.try_get(name)
.context("numeric")?
.map(|val: tiberius::numeric::Numeric| {
let numeric = Numeric::context()
.parse(val.to_string())
.context("parsing")?;
Ok::<_, SqlServerError>(Datum::Numeric(OrderedDecimal(numeric)))
})
.transpose()?,
(ScalarType::String, SqlServerColumnDecodeType::Xml) => data
.try_get(name)
.context("xml")?
.map(|val: &tiberius::xml::XmlData| Datum::String(val.as_ref())),
(ScalarType::Date, SqlServerColumnDecodeType::NaiveDate) => data
.try_get(name)
.context("date")?
.map(|val: chrono::NaiveDate| {
let date = val.try_into().context("parse date")?;
Ok::<_, SqlServerError>(Datum::Date(date))
})
.transpose()?,
(ScalarType::Time, SqlServerColumnDecodeType::NaiveTime) => {
data.try_get(name).context("time")?.map(Datum::Time)
}
(ScalarType::Timestamp { .. }, SqlServerColumnDecodeType::NaiveDateTime) => data
.try_get(name)
.context("timestamp")?
.map(|val: chrono::NaiveDateTime| {
let ts = val.try_into().context("parse timestamp")?;
Ok::<_, SqlServerError>(Datum::Timestamp(ts))
})
.transpose()?,
(ScalarType::TimestampTz { .. }, SqlServerColumnDecodeType::DateTime) => data
.try_get(name)
.context("timestamptz")?
.map(|val: chrono::DateTime<chrono::Utc>| {
let ts = val.try_into().context("parse timestamptz")?;
Ok::<_, SqlServerError>(Datum::TimestampTz(ts))
})
.transpose()?,
(column_type, decode_type) => {
let msg = format!("don't know how to parse {decode_type:?} as {column_type:?}");
return Err(SqlServerError::ProgrammingError(msg));
}
};
match (maybe_datum, column.nullable) {
(Some(datum), _) => Ok(datum),
(None, true) => Ok(Datum::Null),
(None, false) => Err(SqlServerError::InvalidData {
column_name: name.to_string(),
error: "found Null in non-nullable column".to_string(),
}),
}
}
}
impl RustType<proto_sql_server_column_desc::DecodeType> for SqlServerColumnDecodeType {
fn into_proto(&self) -> proto_sql_server_column_desc::DecodeType {
match self {
SqlServerColumnDecodeType::Bool => proto_sql_server_column_desc::DecodeType::Bool(()),
SqlServerColumnDecodeType::U8 => proto_sql_server_column_desc::DecodeType::U8(()),
SqlServerColumnDecodeType::I16 => proto_sql_server_column_desc::DecodeType::I16(()),
SqlServerColumnDecodeType::I32 => proto_sql_server_column_desc::DecodeType::I32(()),
SqlServerColumnDecodeType::I64 => proto_sql_server_column_desc::DecodeType::I64(()),
SqlServerColumnDecodeType::F32 => proto_sql_server_column_desc::DecodeType::F32(()),
SqlServerColumnDecodeType::F64 => proto_sql_server_column_desc::DecodeType::F64(()),
SqlServerColumnDecodeType::String => {
proto_sql_server_column_desc::DecodeType::String(())
}
SqlServerColumnDecodeType::Bytes => proto_sql_server_column_desc::DecodeType::Bytes(()),
SqlServerColumnDecodeType::Uuid => proto_sql_server_column_desc::DecodeType::Uuid(()),
SqlServerColumnDecodeType::Numeric => {
proto_sql_server_column_desc::DecodeType::Numeric(())
}
SqlServerColumnDecodeType::Xml => proto_sql_server_column_desc::DecodeType::Xml(()),
SqlServerColumnDecodeType::NaiveDate => {
proto_sql_server_column_desc::DecodeType::NaiveDate(())
}
SqlServerColumnDecodeType::NaiveTime => {
proto_sql_server_column_desc::DecodeType::NaiveTime(())
}
SqlServerColumnDecodeType::DateTime => {
proto_sql_server_column_desc::DecodeType::DateTime(())
}
SqlServerColumnDecodeType::NaiveDateTime => {
proto_sql_server_column_desc::DecodeType::NaiveDateTime(())
}
}
}
fn from_proto(
proto: proto_sql_server_column_desc::DecodeType,
) -> Result<Self, mz_proto::TryFromProtoError> {
let val = match proto {
proto_sql_server_column_desc::DecodeType::Bool(()) => SqlServerColumnDecodeType::Bool,
proto_sql_server_column_desc::DecodeType::U8(()) => SqlServerColumnDecodeType::U8,
proto_sql_server_column_desc::DecodeType::I16(()) => SqlServerColumnDecodeType::I16,
proto_sql_server_column_desc::DecodeType::I32(()) => SqlServerColumnDecodeType::I32,
proto_sql_server_column_desc::DecodeType::I64(()) => SqlServerColumnDecodeType::I64,
proto_sql_server_column_desc::DecodeType::F32(()) => SqlServerColumnDecodeType::F32,
proto_sql_server_column_desc::DecodeType::F64(()) => SqlServerColumnDecodeType::F64,
proto_sql_server_column_desc::DecodeType::String(()) => {
SqlServerColumnDecodeType::String
}
proto_sql_server_column_desc::DecodeType::Bytes(()) => SqlServerColumnDecodeType::Bytes,
proto_sql_server_column_desc::DecodeType::Uuid(()) => SqlServerColumnDecodeType::Uuid,
proto_sql_server_column_desc::DecodeType::Numeric(()) => {
SqlServerColumnDecodeType::Numeric
}
proto_sql_server_column_desc::DecodeType::Xml(()) => SqlServerColumnDecodeType::Xml,
proto_sql_server_column_desc::DecodeType::NaiveDate(()) => {
SqlServerColumnDecodeType::NaiveDate
}
proto_sql_server_column_desc::DecodeType::NaiveTime(()) => {
SqlServerColumnDecodeType::NaiveTime
}
proto_sql_server_column_desc::DecodeType::DateTime(()) => {
SqlServerColumnDecodeType::DateTime
}
proto_sql_server_column_desc::DecodeType::NaiveDateTime(()) => {
SqlServerColumnDecodeType::NaiveDateTime
}
};
Ok(val)
}
}
/// Decodes `tiberius::Row`s from a SQL Server table into [`Row`]s.
///
/// Built by [`SqlServerRowDecoder::try_new`] from a table description and the
/// [`RelationDesc`] that decoded rows must match.
pub struct SqlServerRowDecoder {
// One `(column name, column type, decode type)` entry per column of the
// `RelationDesc` the decoder was built from, in that description's column order.
decoders: Vec<(Arc<str>, ColumnType, SqlServerColumnDecodeType)>,
}
impl SqlServerRowDecoder {
    /// Builds a decoder that produces rows shaped like `desc` from rows of `table`.
    ///
    /// Columns are looked up in `table` by name, so the two descriptions may list their
    /// columns in different orders; decoded rows follow the order of `desc`.
    ///
    /// # Errors
    ///
    /// Fails if a column of `desc` has no column of the same name in `table`, or if the
    /// two columns disagree on their [`ColumnType`].
    pub fn try_new(
        table: &SqlServerTableDesc,
        desc: &RelationDesc,
    ) -> Result<Self, SqlServerError> {
        let mut decoders = Vec::new();

        for (col_name, col_type) in desc.iter() {
            let upstream = table
                .columns
                .iter()
                .find(|candidate| candidate.name.as_ref() == col_name.as_str())
                .ok_or_else(|| {
                    anyhow::anyhow!("no SQL Server column with name {col_name} found")
                })?;

            if upstream.column_type != *col_type {
                return Err(SqlServerError::ProgrammingError(format!(
                    "programming error, {col_name} has mismatched type {:?} vs {:?}",
                    upstream.column_type, col_type
                )));
            }

            decoders.push((
                Arc::clone(&upstream.name),
                col_type.clone(),
                upstream.decode_type,
            ));
        }

        Ok(SqlServerRowDecoder { decoders })
    }

    /// Decodes `data` into `row`, one datum per configured column.
    ///
    /// # Errors
    ///
    /// Returns the error of the first column that fails to decode.
    pub fn decode(&self, data: &tiberius::Row, row: &mut Row) -> Result<(), SqlServerError> {
        let mut packer = row.packer();
        for (name, column_type, decode_type) in self.decoders.iter() {
            packer.push(decode_type.decode(data, name, column_type)?);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::desc::{
        SqlServerColumnDecodeType, SqlServerColumnDesc, SqlServerTableDesc, SqlServerTableRaw,
    };

    use super::SqlServerColumnRaw;
    use mz_ore::assert_contains;
    use mz_repr::adt::numeric::NumericMaxScale;
    use mz_repr::{Datum, RelationDesc, Row, ScalarType};
    use tiberius::RowTestExt;

    /// Builder-style helpers for constructing raw columns in tests.
    impl SqlServerColumnRaw {
        /// A non-nullable column with zeroed length, precision, and scale.
        fn new(name: &str, data_type: &str) -> Self {
            SqlServerColumnRaw {
                name: name.into(),
                data_type: data_type.into(),
                is_nullable: false,
                max_length: 0,
                precision: 0,
                scale: 0,
            }
        }

        fn nullable(self, nullable: bool) -> Self {
            SqlServerColumnRaw {
                is_nullable: nullable,
                ..self
            }
        }

        fn max_length(self, max_length: i16) -> Self {
            SqlServerColumnRaw { max_length, ..self }
        }

        fn precision(self, precision: u8) -> Self {
            SqlServerColumnRaw { precision, ..self }
        }

        fn scale(self, scale: u8) -> Self {
            SqlServerColumnRaw { scale, ..self }
        }
    }

    /// Supported types parse into the expected column and decode types.
    #[mz_ore::test]
    fn smoketest_column_raw() {
        let bit_raw = SqlServerColumnRaw::new("foo", "bit");
        let bit_col = SqlServerColumnDesc::try_new(&bit_raw).unwrap();

        assert_eq!(&*bit_col.name, "foo");
        assert_eq!(bit_col.column_type, ScalarType::Bool.nullable(false));
        assert_eq!(bit_col.decode_type, SqlServerColumnDecodeType::Bool);

        let decimal_raw = SqlServerColumnRaw::new("foo", "decimal")
            .precision(20)
            .scale(10);
        let decimal_col = SqlServerColumnDesc::try_new(&decimal_raw).unwrap();

        let max_scale = NumericMaxScale::try_from(10i64).expect("known valid");
        let expected_type = ScalarType::Numeric {
            max_scale: Some(max_scale),
        }
        .nullable(false);
        assert_eq!(decimal_col.column_type, expected_type);
        assert_eq!(decimal_col.decode_type, SqlServerColumnDecodeType::Numeric);
    }

    /// Unsupported types are rejected with a descriptive error.
    #[mz_ore::test]
    fn smoketest_column_raw_invalid() {
        // Unknown type name.
        let unknown = SqlServerColumnRaw::new("foo", "bad_data_type");
        let unknown_err = SqlServerColumnDesc::try_new(&unknown).unwrap_err();
        assert_contains!(
            unknown_err.to_string(),
            "'bad_data_type' from column 'foo'"
        );

        // Precision beyond the supported maximum.
        let too_precise = SqlServerColumnRaw::new("foo", "decimal")
            .precision(100)
            .scale(10);
        let precision_err = SqlServerColumnDesc::try_new(&too_precise).unwrap_err();
        assert_contains!(
            precision_err.to_string(),
            "precision of 100 is greater than our maximum of 39"
        );

        // Column of unlimited size.
        let unlimited = SqlServerColumnRaw::new("foo", "varchar").max_length(-1);
        let unlimited_err = SqlServerColumnDesc::try_new(&unlimited).unwrap_err();
        assert_contains!(
            unlimited_err.to_string(),
            "columns with unlimited size do not support CDC"
        );
    }

    /// Rows are decoded in the column order of the `RelationDesc`, which deliberately
    /// differs from the upstream column order here.
    #[mz_ore::test]
    fn smoketest_decoder() {
        let raw_columns = [
            SqlServerColumnRaw::new("a", "varchar"),
            SqlServerColumnRaw::new("b", "int").nullable(true),
            SqlServerColumnRaw::new("c", "bit"),
        ];
        let raw_table = SqlServerTableRaw {
            schema_name: "my_schema".into(),
            name: "my_table".into(),
            columns: raw_columns.into(),
            is_cdc_enabled: true,
        };
        let table_desc = SqlServerTableDesc::try_new(raw_table).expect("known valid");

        // Note that "b" and "c" are swapped relative to the upstream table.
        let relation_desc = RelationDesc::builder()
            .with_column("a", ScalarType::String.nullable(false))
            .with_column("c", ScalarType::Bool.nullable(false))
            .with_column("b", ScalarType::Int32.nullable(true))
            .finish();
        let decoder = table_desc.decoder(&relation_desc).expect("known valid");

        let upstream_columns = [
            tiberius::Column::new("a".to_string(), tiberius::ColumnType::BigVarChar),
            tiberius::Column::new("b".to_string(), tiberius::ColumnType::Int4),
            tiberius::Column::new("c".to_string(), tiberius::ColumnType::Bit),
        ];

        let values_a = [
            tiberius::ColumnData::String(Some("hello world".into())),
            tiberius::ColumnData::I32(Some(42)),
            tiberius::ColumnData::Bit(Some(true)),
        ];
        let upstream_row_a =
            tiberius::Row::build(upstream_columns.iter().cloned().zip(values_a));

        let values_b = [
            tiberius::ColumnData::String(Some("foo bar".into())),
            tiberius::ColumnData::I32(None),
            tiberius::ColumnData::Bit(Some(false)),
        ];
        let upstream_row_b = tiberius::Row::build(upstream_columns.into_iter().zip(values_b));

        // The same output row is reused for both decodes.
        let mut decoded = Row::default();

        decoder.decode(&upstream_row_a, &mut decoded).unwrap();
        let expected_a =
            Row::pack_slice(&[Datum::String("hello world"), Datum::True, Datum::Int32(42)]);
        assert_eq!(&decoded, &expected_a);

        decoder.decode(&upstream_row_b, &mut decoded).unwrap();
        let expected_b = Row::pack_slice(&[Datum::String("foo bar"), Datum::False, Datum::Null]);
        assert_eq!(&decoded, &expected_b);
    }
}