1use base64::Engine;
26use chrono::{NaiveDateTime, SubsecRound};
27use dec::OrderedDecimal;
28use mz_ore::cast::CastFrom;
29use mz_proto::{IntoRustIfSome, ProtoType, RustType};
30use mz_repr::adt::char::CharLength;
31use mz_repr::adt::numeric::{Numeric, NumericMaxScale};
32use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampPrecision};
33use mz_repr::adt::varchar::VarCharMaxLength;
34use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlColumnType, SqlScalarType};
35use proptest_derive::Arbitrary;
36use serde::{Deserialize, Serialize};
37
38use std::collections::BTreeSet;
39use std::sync::Arc;
40
41use crate::desc::proto_sql_server_table_constraint::ConstraintType;
42use crate::{SqlServerDecodeError, SqlServerError};
43
44include!(concat!(env!("OUT_DIR"), "/mz_sql_server_util.rs"));
45
/// Description of a SQL Server table: its name, its columns, and its constraints.
///
/// Built from raw upstream metadata with [`SqlServerTableDesc::new`] and convertible to and
/// from its protobuf form, `ProtoSqlServerTableDesc`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerTableDesc {
    /// Name of the schema the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// Descriptions of the table's columns.
    pub columns: Box<[SqlServerColumnDesc]>,
    /// Constraints (primary key or unique) known for the table.
    pub constraints: Vec<SqlServerTableConstraint>,
}
66
67impl SqlServerTableDesc {
68 pub fn new(
73 raw: SqlServerTableRaw,
74 raw_constraints: Vec<SqlServerTableConstraintRaw>,
75 ) -> Result<Self, SqlServerError> {
76 let columns: Box<[_]> = raw
77 .columns
78 .into_iter()
79 .map(SqlServerColumnDesc::new)
80 .collect();
81 let constraints = raw_constraints
82 .into_iter()
83 .map(SqlServerTableConstraint::try_from)
84 .collect::<Result<Vec<_>, _>>()?;
85 Ok(SqlServerTableDesc {
86 schema_name: raw.schema_name,
87 name: raw.name,
88 columns,
89 constraints,
90 })
91 }
92
93 pub fn qualified_name(&self) -> SqlServerQualifiedTableName {
95 SqlServerQualifiedTableName {
96 schema_name: Arc::clone(&self.schema_name),
97 table_name: Arc::clone(&self.name),
98 }
99 }
100
101 pub fn apply_text_columns(&mut self, text_columns: &BTreeSet<&str>) {
104 for column in &mut self.columns {
105 if text_columns.contains(column.name.as_ref()) {
106 column.represent_as_text();
107 }
108 }
109 }
110
111 pub fn apply_excl_columns(&mut self, excl_columns: &BTreeSet<&str>) {
114 for column in &mut self.columns {
115 if excl_columns.contains(column.name.as_ref()) {
116 column.exclude();
117 }
118 }
119 }
120
121 pub fn decoder(&self, desc: &RelationDesc) -> Result<SqlServerRowDecoder, SqlServerError> {
124 let decoder = SqlServerRowDecoder::try_new(self, desc)?;
125 Ok(decoder)
126 }
127}
128
129impl RustType<ProtoSqlServerTableDesc> for SqlServerTableDesc {
130 fn into_proto(&self) -> ProtoSqlServerTableDesc {
131 ProtoSqlServerTableDesc {
132 name: self.name.to_string(),
133 schema_name: self.schema_name.to_string(),
134 columns: self.columns.iter().map(|c| c.into_proto()).collect(),
135 constraints: self.constraints.iter().map(|c| c.into_proto()).collect(),
136 }
137 }
138
139 fn from_proto(proto: ProtoSqlServerTableDesc) -> Result<Self, mz_proto::TryFromProtoError> {
140 let columns = proto
141 .columns
142 .into_iter()
143 .map(|c| c.into_rust())
144 .collect::<Result<_, _>>()?;
145 let constraints = proto
146 .constraints
147 .into_iter()
148 .map(|c| c.into_rust())
149 .collect::<Result<_, _>>()?;
150 Ok(SqlServerTableDesc {
151 schema_name: proto.schema_name.into(),
152 name: proto.name.into(),
153 columns,
154 constraints,
155 })
156 }
157}
158
/// The kinds of table constraint that are recognized.
///
/// Parsed from the strings `"PRIMARY KEY"` and `"UNIQUE"` by the `TryFrom<String>` impl.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Arbitrary)]
pub enum SqlServerTableConstraintType {
    /// A `PRIMARY KEY` constraint.
    PrimaryKey,
    /// A `UNIQUE` constraint.
    Unique,
}
166
167impl TryFrom<String> for SqlServerTableConstraintType {
168 type Error = SqlServerError;
169
170 fn try_from(value: String) -> Result<Self, Self::Error> {
171 match value.as_str() {
172 "PRIMARY KEY" => Ok(Self::PrimaryKey),
173 "UNIQUE" => Ok(Self::Unique),
174 name => Err(SqlServerError::InvalidData {
175 column_name: "constraint_type".into(),
176 error: format!("Unknown constraint type: {name}"),
177 }),
178 }
179 }
180}
181
182impl RustType<proto_sql_server_table_constraint::ConstraintType> for SqlServerTableConstraintType {
183 fn into_proto(&self) -> proto_sql_server_table_constraint::ConstraintType {
184 match self {
185 SqlServerTableConstraintType::PrimaryKey => ConstraintType::PrimaryKey(()),
186 SqlServerTableConstraintType::Unique => ConstraintType::Unique(()),
187 }
188 }
189
190 fn from_proto(
191 proto: proto_sql_server_table_constraint::ConstraintType,
192 ) -> Result<Self, mz_proto::TryFromProtoError> {
193 Ok(match proto {
194 ConstraintType::PrimaryKey(_) => SqlServerTableConstraintType::PrimaryKey,
195 ConstraintType::Unique(_) => SqlServerTableConstraintType::Unique,
196 })
197 }
198}
199
/// A validated constraint on a SQL Server table.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Arbitrary)]
pub struct SqlServerTableConstraint {
    /// Name of the constraint.
    pub constraint_name: String,
    /// Kind of the constraint.
    pub constraint_type: SqlServerTableConstraintType,
    /// Names of the columns the constraint covers.
    pub column_names: Vec<String>,
}
207
208impl TryFrom<SqlServerTableConstraintRaw> for SqlServerTableConstraint {
209 type Error = SqlServerError;
210
211 fn try_from(value: SqlServerTableConstraintRaw) -> Result<Self, Self::Error> {
212 Ok(SqlServerTableConstraint {
213 constraint_name: value.constraint_name,
214 constraint_type: value.constraint_type.try_into()?,
215 column_names: value.columns,
216 })
217 }
218}
219
220impl RustType<ProtoSqlServerTableConstraint> for SqlServerTableConstraint {
221 fn into_proto(&self) -> ProtoSqlServerTableConstraint {
222 ProtoSqlServerTableConstraint {
223 constraint_name: self.constraint_name.clone(),
224 constraint_type: Some(self.constraint_type.into_proto()),
225 column_names: self.column_names.clone(),
226 }
227 }
228
229 fn from_proto(
230 proto: ProtoSqlServerTableConstraint,
231 ) -> Result<Self, mz_proto::TryFromProtoError> {
232 Ok(SqlServerTableConstraint {
233 constraint_name: proto.constraint_name,
234 constraint_type: proto
235 .constraint_type
236 .into_rust_if_some("ProtoSqlServerTableConstraint::constraint_type")?,
237 column_names: proto.column_names,
238 })
239 }
240}
241
/// The schema-qualified name of a SQL Server table.
///
/// Ordering compares `schema_name` first and then `table_name` (derived field order).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SqlServerQualifiedTableName {
    /// Name of the schema the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub table_name: Arc<str>,
}
250
251impl ToString for SqlServerQualifiedTableName {
252 fn to_string(&self) -> String {
253 format!(
254 "{}.{}",
255 crate::quote_identifier(&self.schema_name),
256 crate::quote_identifier(&self.table_name)
257 )
258 }
259}
260
/// Raw, unvalidated metadata about a SQL Server table.
///
/// Turned into a [`SqlServerTableDesc`] by [`SqlServerTableDesc::new`].
#[derive(Debug, Clone)]
pub struct SqlServerTableRaw {
    /// Name of the schema the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// The capture instance associated with this table.
    // NOTE(review): presumably the CDC capture instance tracking the table — confirm.
    pub capture_instance: Arc<SqlServerCaptureInstanceRaw>,
    /// Raw metadata for each column of the table.
    pub columns: Arc<[SqlServerColumnRaw]>,
}
276
/// Raw metadata about a capture instance.
#[derive(Debug, Clone)]
pub struct SqlServerCaptureInstanceRaw {
    /// Name of the capture instance.
    pub name: Arc<str>,
    /// Date the capture instance was created.
    // NOTE(review): the time zone of this naive timestamp is not visible here — confirm.
    pub create_date: Arc<NaiveDateTime>,
}
285
/// Description of a single column of a SQL Server table.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerColumnDesc {
    /// Name of the column.
    pub name: Arc<str>,
    /// The type used to represent the column.
    ///
    /// `None` when the column's type is unsupported or the column has been excluded (see
    /// [`SqlServerColumnDesc::exclude`] and [`SqlServerColumnDesc::is_excluded`]).
    pub column_type: Option<SqlColumnType>,
    /// Name of a primary key constraint this column is part of, if any.
    ///
    /// [`SqlServerColumnDesc::new`] always sets this to `None`.
    // NOTE(review): presumably kept for compatibility with previously serialized
    // descriptions — confirm.
    pub primary_key_constraint: Option<Arc<str>>,
    /// How values of this column are decoded from a `tiberius::Row`.
    pub decode_type: SqlServerColumnDecodeType,
    /// The type name of the column exactly as reported by SQL Server.
    pub raw_type: Arc<str>,
}
307
308impl SqlServerColumnDesc {
309 pub fn new(raw: &SqlServerColumnRaw) -> Self {
311 let (column_type, decode_type) = match parse_data_type(raw) {
312 Ok((scalar_type, decode_type)) => {
313 let column_type = scalar_type.nullable(raw.is_nullable);
314 (Some(column_type), decode_type)
315 }
316 Err(err) => {
317 tracing::warn!(
318 ?err,
319 ?raw,
320 "found an unsupported data type when parsing raw data"
321 );
322 (
323 None,
324 SqlServerColumnDecodeType::Unsupported {
325 context: err.reason,
326 },
327 )
328 }
329 };
330 SqlServerColumnDesc {
331 name: Arc::clone(&raw.name),
332 primary_key_constraint: None,
333 column_type,
334 decode_type,
335 raw_type: Arc::clone(&raw.data_type),
336 }
337 }
338
339 pub fn represent_as_text(&mut self) {
341 self.column_type = self
342 .column_type
343 .as_ref()
344 .map(|ct| SqlScalarType::String.nullable(ct.nullable));
345 }
346
347 pub fn exclude(&mut self) {
349 self.column_type = None;
350 }
351
352 pub fn is_excluded(&self) -> bool {
354 self.column_type.is_none()
355 }
356}
357
358impl RustType<ProtoSqlServerColumnDesc> for SqlServerColumnDesc {
359 fn into_proto(&self) -> ProtoSqlServerColumnDesc {
360 ProtoSqlServerColumnDesc {
361 name: self.name.to_string(),
362 column_type: self.column_type.into_proto(),
363 primary_key_constraint: self.primary_key_constraint.as_ref().map(|v| v.to_string()),
364 decode_type: Some(self.decode_type.into_proto()),
365 raw_type: self.raw_type.to_string(),
366 }
367 }
368
369 fn from_proto(proto: ProtoSqlServerColumnDesc) -> Result<Self, mz_proto::TryFromProtoError> {
370 Ok(SqlServerColumnDesc {
371 name: proto.name.into(),
372 column_type: proto.column_type.into_rust()?,
373 primary_key_constraint: proto.primary_key_constraint.map(|v| v.into()),
374 decode_type: proto
375 .decode_type
376 .into_rust_if_some("ProtoSqlServerColumnDesc::decode_type")?,
377 raw_type: proto.raw_type.into(),
378 })
379 }
380}
381
/// Error returned by `parse_data_type` when a column's type cannot be represented.
#[derive(Debug)]
// NOTE(review): `column_name` and `column_type` appear to be read only through the `Debug`
// output (e.g. when logged), hence the `dead_code` allowance — confirm.
#[allow(dead_code)]
pub struct UnsupportedDataType {
    /// Name of the offending column.
    column_name: String,
    /// Type of the offending column.
    column_type: String,
    /// Human readable explanation of why the type is unsupported.
    reason: String,
}
390
391fn parse_data_type(
396 raw: &SqlServerColumnRaw,
397) -> Result<(SqlScalarType, SqlServerColumnDecodeType), UnsupportedDataType> {
398 if raw.is_computed {
402 return Err(UnsupportedDataType {
403 column_name: raw.name.to_string(),
404 column_type: format!("{} (computed)", raw.data_type.to_lowercase()),
405 reason: "column is computed".into(),
406 });
407 }
408
409 let scalar =
410 match raw.data_type.to_lowercase().as_str() {
411 "tinyint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::U8),
412 "smallint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::I16),
413 "int" => (SqlScalarType::Int32, SqlServerColumnDecodeType::I32),
414 "bigint" => (SqlScalarType::Int64, SqlServerColumnDecodeType::I64),
415 "bit" => (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool),
416 "decimal" | "numeric" | "money" | "smallmoney" => {
417 if raw.precision > 38 || raw.scale > raw.precision {
424 tracing::warn!(
425 "unexpected value from SQL Server, precision of {} and scale of {}",
426 raw.precision,
427 raw.scale,
428 );
429 }
430 if raw.precision > 39 {
431 let reason = format!(
432 "precision of {} is greater than our maximum of 39",
433 raw.precision
434 );
435 return Err(UnsupportedDataType {
436 column_name: raw.name.to_string(),
437 column_type: raw.data_type.to_string(),
438 reason,
439 });
440 }
441
442 let raw_scale = usize::cast_from(raw.scale);
443 let max_scale =
444 NumericMaxScale::try_from(raw_scale).map_err(|_| UnsupportedDataType {
445 column_type: raw.data_type.to_string(),
446 column_name: raw.name.to_string(),
447 reason: format!("scale of {} is too large", raw.scale),
448 })?;
449 let column_type = SqlScalarType::Numeric {
450 max_scale: Some(max_scale),
451 };
452
453 (column_type, SqlServerColumnDecodeType::Numeric)
454 }
455 "real" | "float" | "double precision" => match raw.max_length {
463 4 => (SqlScalarType::Float32, SqlServerColumnDecodeType::F32),
466 8 => (SqlScalarType::Float64, SqlServerColumnDecodeType::F64),
467 _ => {
468 return Err(UnsupportedDataType {
469 column_name: raw.name.to_string(),
470 column_type: raw.data_type.to_string(),
471 reason: format!("unsupported length {}", raw.max_length),
472 });
473 }
474 },
475 dt @ ("char" | "nchar" | "sysname") => {
476 if raw.max_length == -1 {
479 return Err(UnsupportedDataType {
480 column_name: raw.name.to_string(),
481 column_type: raw.data_type.to_string(),
482 reason: "columns with unlimited size do not support CDC".to_string(),
483 });
484 }
485
486 let column_type = match dt {
487 "char" => {
488 let length =
489 if raw.max_length != -1 {
490 let length = CharLength::try_from(i64::from(raw.max_length))
491 .map_err(|e| UnsupportedDataType {
492 column_name: raw.name.to_string(),
493 column_type: raw.data_type.to_string(),
494 reason: e.to_string(),
495 })?;
496 Some(length)
497 } else {
498 None
499 };
500 SqlScalarType::Char { length }
501 }
502 "nchar" | "sysname" => SqlScalarType::String,
506 other => unreachable!("'{other}' checked above"),
507 };
508
509 (column_type, SqlServerColumnDecodeType::String)
510 }
511 "varchar" | "nvarchar" => {
512 let max_length =
523 if raw.max_length != -1 {
524 let length = VarCharMaxLength::try_from(i64::from(raw.max_length))
525 .map_err(|e| UnsupportedDataType {
526 column_name: raw.name.to_string(),
527 column_type: raw.data_type.to_string(),
528 reason: e.to_string(),
529 })?;
530 Some(length)
531 } else {
532 None
533 };
534 let column_type = SqlScalarType::VarChar { max_length };
535 (column_type, SqlServerColumnDecodeType::String)
536 }
537 "text" | "ntext" | "image" => {
538 mz_ore::soft_assert_eq_no_log!(raw.max_length, 16);
541
542 return Err(UnsupportedDataType {
544 column_name: raw.name.to_string(),
545 column_type: raw.data_type.to_string(),
546 reason: "columns with unlimited size do not support CDC".to_string(),
547 });
548 }
549 "xml" => {
550 if raw.max_length == -1 {
555 return Err(UnsupportedDataType {
556 column_name: raw.name.to_string(),
557 column_type: raw.data_type.to_string(),
558 reason: "columns with unlimited size do not support CDC".to_string(),
559 });
560 }
561 (SqlScalarType::String, SqlServerColumnDecodeType::Xml)
562 }
563 "binary" | "varbinary" => {
564 if raw.max_length == -1 {
569 return Err(UnsupportedDataType {
570 column_name: raw.name.to_string(),
571 column_type: raw.data_type.to_string(),
572 reason: "columns with unlimited size do not support CDC".to_string(),
573 });
574 }
575 (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes)
576 }
577 "json" => (SqlScalarType::Jsonb, SqlServerColumnDecodeType::String),
578 "date" => (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate),
579 "time" => (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime),
592 dt @ ("smalldatetime" | "datetime" | "datetime2" | "datetimeoffset") => {
593 if raw.scale > 7 {
594 tracing::warn!("unexpected scale '{}' from SQL Server", raw.scale);
595 }
596 if raw.scale > mz_repr::adt::timestamp::MAX_PRECISION {
597 tracing::warn!("truncating scale of '{}' for '{}'", raw.scale, dt);
598 }
599 let precision = std::cmp::min(raw.scale, mz_repr::adt::timestamp::MAX_PRECISION);
600 let precision =
601 Some(TimestampPrecision::try_from(i64::from(precision)).expect("known to fit"));
602
603 match dt {
604 "smalldatetime" | "datetime" | "datetime2" => (
605 SqlScalarType::Timestamp { precision },
606 SqlServerColumnDecodeType::NaiveDateTime,
607 ),
608 "datetimeoffset" => (
609 SqlScalarType::TimestampTz { precision },
610 SqlServerColumnDecodeType::DateTime,
611 ),
612 other => unreachable!("'{other}' checked above"),
613 }
614 }
615 "uniqueidentifier" => (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid),
616 other => {
629 return Err(UnsupportedDataType {
630 column_type: other.to_string(),
631 column_name: raw.name.to_string(),
632 reason: format!("'{other}' is unimplemented"),
633 });
634 }
635 };
636 Ok(scalar)
637}
638
/// Raw, unvalidated metadata about a column of a SQL Server table.
#[derive(Clone, Debug)]
pub struct SqlServerColumnRaw {
    /// Name of the column.
    pub name: Arc<str>,
    /// Name of the column's data type; matched case-insensitively when parsed.
    pub data_type: Arc<str>,
    /// Whether the column may contain `NULL`.
    pub is_nullable: bool,
    /// Maximum length of the column as reported by SQL Server.
    ///
    /// A value of `-1` is treated as "unlimited size" when the type is parsed.
    // NOTE(review): presumably measured in bytes rather than characters — confirm.
    pub max_length: i16,
    /// Precision of the column; consulted for decimal/numeric/money types.
    pub precision: u8,
    /// Scale of the column; consulted for numeric and date-time types.
    pub scale: u8,
    /// Whether the column is computed; computed columns are reported as unsupported.
    pub is_computed: bool,
}
667
/// Raw, unvalidated description of a table constraint.
///
/// Validated into a [`SqlServerTableConstraint`] via `TryFrom`.
#[derive(Clone, Debug)]
pub struct SqlServerTableConstraintRaw {
    /// Name of the constraint.
    pub constraint_name: String,
    /// Kind of the constraint as a string, e.g. `"PRIMARY KEY"` or `"UNIQUE"`.
    pub constraint_type: String,
    /// Names of the columns the constraint covers.
    pub columns: Vec<String>,
}
675
/// The Rust type that a column's values are read as from a `tiberius::Row`.
///
/// See [`SqlServerColumnDecodeType::decode`] for how each variant is turned into a [`Datum`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub enum SqlServerColumnDecodeType {
    /// Read as `bool`.
    Bool,
    /// Read as `u8`.
    U8,
    /// Read as `i16`.
    I16,
    /// Read as `i32`.
    I32,
    /// Read as `i64`.
    I64,
    /// Read as `f32`.
    F32,
    /// Read as `f64`.
    F64,
    /// Read as `&str`.
    String,
    /// Read as `&[u8]`.
    Bytes,
    /// Read as `uuid::Uuid`.
    Uuid,
    /// Read as `tiberius::numeric::Numeric`.
    Numeric,
    /// Read as `tiberius::xml::XmlData`.
    Xml,
    /// Read as `chrono::NaiveDate`.
    NaiveDate,
    /// Read as `chrono::NaiveTime`.
    NaiveTime,
    /// Read as `chrono::DateTime<chrono::Utc>`.
    DateTime,
    /// Read as `chrono::NaiveDateTime`.
    NaiveDateTime,
    /// The column's type is not supported and cannot be decoded.
    Unsupported {
        /// Explanation of why the type is unsupported.
        context: String,
    },
}
708
709impl SqlServerColumnDecodeType {
710 pub fn decode<'a>(
712 &self,
713 data: &'a tiberius::Row,
714 name: &'a str,
715 column: &'a SqlColumnType,
716 arena: &'a RowArena,
717 ) -> Result<Datum<'a>, SqlServerDecodeError> {
718 let maybe_datum = match (&column.scalar_type, self) {
719 (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool) => data
720 .try_get(name)
721 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool"))?
722 .map(|val: bool| if val { Datum::True } else { Datum::False }),
723 (SqlScalarType::Int16, SqlServerColumnDecodeType::U8) => data
724 .try_get(name)
725 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8"))?
726 .map(|val: u8| Datum::Int16(i16::cast_from(val))),
727 (SqlScalarType::Int16, SqlServerColumnDecodeType::I16) => data
728 .try_get(name)
729 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16"))?
730 .map(Datum::Int16),
731 (SqlScalarType::Int32, SqlServerColumnDecodeType::I32) => data
732 .try_get(name)
733 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32"))?
734 .map(Datum::Int32),
735 (SqlScalarType::Int64, SqlServerColumnDecodeType::I64) => data
736 .try_get(name)
737 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64"))?
738 .map(Datum::Int64),
739 (SqlScalarType::Float32, SqlServerColumnDecodeType::F32) => data
740 .try_get(name)
741 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32"))?
742 .map(|val: f32| Datum::Float32(ordered_float::OrderedFloat(val))),
743 (SqlScalarType::Float64, SqlServerColumnDecodeType::F64) => data
744 .try_get(name)
745 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64"))?
746 .map(|val: f64| Datum::Float64(ordered_float::OrderedFloat(val))),
747 (SqlScalarType::String, SqlServerColumnDecodeType::String) => data
748 .try_get(name)
749 .map_err(|_| SqlServerDecodeError::invalid_column(name, "string"))?
750 .map(Datum::String),
751 (SqlScalarType::Char { length }, SqlServerColumnDecodeType::String) => data
752 .try_get(name)
753 .map_err(|_| SqlServerDecodeError::invalid_column(name, "char"))?
754 .map(|val: &str| match length {
755 Some(expected) => {
756 let found_chars = val.chars().count();
757 let expct_chars = usize::cast_from(expected.into_u32());
758 if found_chars != expct_chars {
759 Err(SqlServerDecodeError::invalid_char(
760 name,
761 expct_chars,
762 found_chars,
763 ))
764 } else {
765 Ok(Datum::String(val))
766 }
767 }
768 None => Ok(Datum::String(val)),
769 })
770 .transpose()?,
771 (SqlScalarType::VarChar { max_length }, SqlServerColumnDecodeType::String) => data
772 .try_get(name)
773 .map_err(|_| SqlServerDecodeError::invalid_column(name, "varchar"))?
774 .map(|val: &str| match max_length {
775 Some(max) => {
776 let found_chars = val.chars().count();
777 let max_chars = usize::cast_from(max.into_u32());
778 if found_chars > max_chars {
779 Err(SqlServerDecodeError::invalid_varchar(
780 name,
781 max_chars,
782 found_chars,
783 ))
784 } else {
785 Ok(Datum::String(val))
786 }
787 }
788 None => Ok(Datum::String(val)),
789 })
790 .transpose()?,
791 (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes) => data
792 .try_get(name)
793 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes"))?
794 .map(Datum::Bytes),
795 (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid) => data
796 .try_get(name)
797 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid"))?
798 .map(Datum::Uuid),
799 (SqlScalarType::Numeric { .. }, SqlServerColumnDecodeType::Numeric) => data
800 .try_get(name)
801 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric"))?
802 .map(|val: tiberius::numeric::Numeric| {
803 let numeric = tiberius_numeric_to_mz_numeric(val);
804 Datum::Numeric(OrderedDecimal(numeric))
805 }),
806 (SqlScalarType::String, SqlServerColumnDecodeType::Xml) => data
807 .try_get(name)
808 .map_err(|_| SqlServerDecodeError::invalid_column(name, "xml"))?
809 .map(|val: &tiberius::xml::XmlData| Datum::String(val.as_ref())),
810 (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate) => data
811 .try_get(name)
812 .map_err(|_| SqlServerDecodeError::invalid_column(name, "date"))?
813 .map(|val: chrono::NaiveDate| {
814 let date = val
815 .try_into()
816 .map_err(|e| SqlServerDecodeError::invalid_date(name, e))?;
817 Ok::<_, SqlServerDecodeError>(Datum::Date(date))
818 })
819 .transpose()?,
820 (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime) => data
821 .try_get(name)
822 .map_err(|_| SqlServerDecodeError::invalid_column(name, "time"))?
823 .map(|val: chrono::NaiveTime| {
824 let rounded = val.round_subsecs(6);
829 let val = if rounded < val {
831 val.trunc_subsecs(6)
832 } else {
833 val
834 };
835 Datum::Time(val)
836 }),
837 (SqlScalarType::Timestamp { precision }, SqlServerColumnDecodeType::NaiveDateTime) => {
838 data.try_get(name)
839 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamp"))?
840 .map(|val: chrono::NaiveDateTime| {
841 let ts: CheckedTimestamp<chrono::NaiveDateTime> = val
842 .try_into()
843 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
844 let rounded = ts
845 .round_to_precision(*precision)
846 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
847 Ok::<_, SqlServerDecodeError>(Datum::Timestamp(rounded))
848 })
849 .transpose()?
850 }
851 (SqlScalarType::TimestampTz { precision }, SqlServerColumnDecodeType::DateTime) => data
852 .try_get(name)
853 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamptz"))?
854 .map(|val: chrono::DateTime<chrono::Utc>| {
855 let ts: CheckedTimestamp<chrono::DateTime<chrono::Utc>> = val
856 .try_into()
857 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
858 let rounded = ts
859 .round_to_precision(*precision)
860 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
861 Ok::<_, SqlServerDecodeError>(Datum::TimestampTz(rounded))
862 })
863 .transpose()?,
864 (SqlScalarType::String, SqlServerColumnDecodeType::Bool) => data
866 .try_get(name)
867 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool-text"))?
868 .map(|val: bool| {
869 if val {
870 Datum::String("true")
871 } else {
872 Datum::String("false")
873 }
874 }),
875 (SqlScalarType::String, SqlServerColumnDecodeType::U8) => data
876 .try_get(name)
877 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8-text"))?
878 .map(|val: u8| {
879 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
880 }),
881 (SqlScalarType::String, SqlServerColumnDecodeType::I16) => data
882 .try_get(name)
883 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16-text"))?
884 .map(|val: i16| {
885 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
886 }),
887 (SqlScalarType::String, SqlServerColumnDecodeType::I32) => data
888 .try_get(name)
889 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32-text"))?
890 .map(|val: i32| {
891 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
892 }),
893 (SqlScalarType::String, SqlServerColumnDecodeType::I64) => data
894 .try_get(name)
895 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64-text"))?
896 .map(|val: i64| {
897 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
898 }),
899 (SqlScalarType::String, SqlServerColumnDecodeType::F32) => data
900 .try_get(name)
901 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32-text"))?
902 .map(|val: f32| {
903 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
904 }),
905 (SqlScalarType::String, SqlServerColumnDecodeType::F64) => data
906 .try_get(name)
907 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64-text"))?
908 .map(|val: f64| {
909 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
910 }),
911 (SqlScalarType::String, SqlServerColumnDecodeType::Uuid) => data
912 .try_get(name)
913 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid-text"))?
914 .map(|val: uuid::Uuid| {
915 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
916 }),
917 (SqlScalarType::String, SqlServerColumnDecodeType::Bytes) => data
918 .try_get(name)
919 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes-text"))?
920 .map(|val: &[u8]| {
921 let encoded = base64::engine::general_purpose::STANDARD.encode(val);
922 arena.make_datum(|packer| packer.push(Datum::String(&encoded)))
923 }),
924 (SqlScalarType::String, SqlServerColumnDecodeType::Numeric) => data
925 .try_get(name)
926 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric-text"))?
927 .map(|val: tiberius::numeric::Numeric| {
928 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
929 }),
930 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDate) => data
931 .try_get(name)
932 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedate-text"))?
933 .map(|val: chrono::NaiveDate| {
934 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
935 }),
936 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveTime) => data
937 .try_get(name)
938 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivetime-text"))?
939 .map(|val: chrono::NaiveTime| {
940 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
941 }),
942 (SqlScalarType::String, SqlServerColumnDecodeType::DateTime) => data
943 .try_get(name)
944 .map_err(|_| SqlServerDecodeError::invalid_column(name, "datetime-text"))?
945 .map(|val: chrono::DateTime<chrono::Utc>| {
946 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
947 }),
948 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDateTime) => data
949 .try_get(name)
950 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedatetime-text"))?
951 .map(|val: chrono::NaiveDateTime| {
952 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
953 }),
954 (column_type, decode_type) => {
955 return Err(SqlServerDecodeError::Unsupported {
956 sql_server_type: decode_type.clone(),
957 mz_type: column_type.clone(),
958 });
959 }
960 };
961
962 match (maybe_datum, column.nullable) {
963 (Some(datum), _) => Ok(datum),
964 (None, true) => Ok(Datum::Null),
965 (None, false) => Err(SqlServerDecodeError::InvalidData {
966 column_name: name.to_string(),
967 error: "found Null in non-nullable column".to_string(),
969 }),
970 }
971 }
972}
973
974impl RustType<proto_sql_server_column_desc::DecodeType> for SqlServerColumnDecodeType {
975 fn into_proto(&self) -> proto_sql_server_column_desc::DecodeType {
976 match self {
977 SqlServerColumnDecodeType::Bool => proto_sql_server_column_desc::DecodeType::Bool(()),
978 SqlServerColumnDecodeType::U8 => proto_sql_server_column_desc::DecodeType::U8(()),
979 SqlServerColumnDecodeType::I16 => proto_sql_server_column_desc::DecodeType::I16(()),
980 SqlServerColumnDecodeType::I32 => proto_sql_server_column_desc::DecodeType::I32(()),
981 SqlServerColumnDecodeType::I64 => proto_sql_server_column_desc::DecodeType::I64(()),
982 SqlServerColumnDecodeType::F32 => proto_sql_server_column_desc::DecodeType::F32(()),
983 SqlServerColumnDecodeType::F64 => proto_sql_server_column_desc::DecodeType::F64(()),
984 SqlServerColumnDecodeType::String => {
985 proto_sql_server_column_desc::DecodeType::String(())
986 }
987 SqlServerColumnDecodeType::Bytes => proto_sql_server_column_desc::DecodeType::Bytes(()),
988 SqlServerColumnDecodeType::Uuid => proto_sql_server_column_desc::DecodeType::Uuid(()),
989 SqlServerColumnDecodeType::Numeric => {
990 proto_sql_server_column_desc::DecodeType::Numeric(())
991 }
992 SqlServerColumnDecodeType::Xml => proto_sql_server_column_desc::DecodeType::Xml(()),
993 SqlServerColumnDecodeType::NaiveDate => {
994 proto_sql_server_column_desc::DecodeType::NaiveDate(())
995 }
996 SqlServerColumnDecodeType::NaiveTime => {
997 proto_sql_server_column_desc::DecodeType::NaiveTime(())
998 }
999 SqlServerColumnDecodeType::DateTime => {
1000 proto_sql_server_column_desc::DecodeType::DateTime(())
1001 }
1002 SqlServerColumnDecodeType::NaiveDateTime => {
1003 proto_sql_server_column_desc::DecodeType::NaiveDateTime(())
1004 }
1005 SqlServerColumnDecodeType::Unsupported { context } => {
1006 proto_sql_server_column_desc::DecodeType::Unsupported(context.clone())
1007 }
1008 }
1009 }
1010
1011 fn from_proto(
1012 proto: proto_sql_server_column_desc::DecodeType,
1013 ) -> Result<Self, mz_proto::TryFromProtoError> {
1014 let val = match proto {
1015 proto_sql_server_column_desc::DecodeType::Bool(()) => SqlServerColumnDecodeType::Bool,
1016 proto_sql_server_column_desc::DecodeType::U8(()) => SqlServerColumnDecodeType::U8,
1017 proto_sql_server_column_desc::DecodeType::I16(()) => SqlServerColumnDecodeType::I16,
1018 proto_sql_server_column_desc::DecodeType::I32(()) => SqlServerColumnDecodeType::I32,
1019 proto_sql_server_column_desc::DecodeType::I64(()) => SqlServerColumnDecodeType::I64,
1020 proto_sql_server_column_desc::DecodeType::F32(()) => SqlServerColumnDecodeType::F32,
1021 proto_sql_server_column_desc::DecodeType::F64(()) => SqlServerColumnDecodeType::F64,
1022 proto_sql_server_column_desc::DecodeType::String(()) => {
1023 SqlServerColumnDecodeType::String
1024 }
1025 proto_sql_server_column_desc::DecodeType::Bytes(()) => SqlServerColumnDecodeType::Bytes,
1026 proto_sql_server_column_desc::DecodeType::Uuid(()) => SqlServerColumnDecodeType::Uuid,
1027 proto_sql_server_column_desc::DecodeType::Numeric(()) => {
1028 SqlServerColumnDecodeType::Numeric
1029 }
1030 proto_sql_server_column_desc::DecodeType::Xml(()) => SqlServerColumnDecodeType::Xml,
1031 proto_sql_server_column_desc::DecodeType::NaiveDate(()) => {
1032 SqlServerColumnDecodeType::NaiveDate
1033 }
1034 proto_sql_server_column_desc::DecodeType::NaiveTime(()) => {
1035 SqlServerColumnDecodeType::NaiveTime
1036 }
1037 proto_sql_server_column_desc::DecodeType::DateTime(()) => {
1038 SqlServerColumnDecodeType::DateTime
1039 }
1040 proto_sql_server_column_desc::DecodeType::NaiveDateTime(()) => {
1041 SqlServerColumnDecodeType::NaiveDateTime
1042 }
1043 proto_sql_server_column_desc::DecodeType::Unsupported(context) => {
1044 SqlServerColumnDecodeType::Unsupported { context }
1045 }
1046 };
1047 Ok(val)
1048 }
1049}
1050
1051fn tiberius_numeric_to_mz_numeric(val: tiberius::numeric::Numeric) -> Numeric {
1054 let mut numeric = mz_repr::adt::numeric::cx_datum().from_i128(val.value());
1055 mz_repr::adt::numeric::cx_datum().scaleb(&mut numeric, &Numeric::from(-i32::from(val.scale())));
1058 numeric
1059}
1060
/// The update mask of a change row, read from its `__$update_mask` column.
///
/// Each bit records whether one data column was updated; see
/// [`UpdateMask::data_col_updated`] for the bit layout.
#[derive(Debug)]
pub struct UpdateMask {
    /// Raw bytes of the mask; the last byte holds the bits of the first data columns.
    mask: Vec<u8>,
}
1069
1070impl TryFrom<&tiberius::Row> for UpdateMask {
1071 type Error = SqlServerDecodeError;
1072
1073 fn try_from(row: &tiberius::Row) -> Result<Self, Self::Error> {
1074 static UPDATE_MASK: &str = "__$update_mask";
1075
1076 let mask: Vec<u8> = row
1077 .try_get::<&[u8], _>(UPDATE_MASK)
1078 .inspect_err(|e| tracing::warn!("Failed extracting update mask: {e:?}"))
1079 .map_err(|_| SqlServerDecodeError::InvalidColumn {
1080 column_name: UPDATE_MASK.to_string(),
1081 as_type: "bytes",
1082 })?
1083 .ok_or_else(|| SqlServerDecodeError::InvalidData {
1084 column_name: UPDATE_MASK.to_string(),
1085 error: "column cannot be null".to_string(),
1086 })?
1087 .into();
1088 Ok(UpdateMask { mask })
1089 }
1090}
1091
1092impl UpdateMask {
1093 pub fn data_col_updated(&self, col_index: usize) -> bool {
1106 const CDC_METADATA_COL_COUNT: usize = 4;
1107
1108 if col_index < CDC_METADATA_COL_COUNT {
1109 return false;
1110 }
1111 let adj_col_index = col_index - CDC_METADATA_COL_COUNT;
1112 let byte_offset = adj_col_index / usize::cast_from(u8::BITS);
1113 assert!(
1114 byte_offset < self.mask.len(),
1115 "byte_offset = {byte_offset} mask_len = {}",
1116 self.mask.len()
1117 );
1118 let bit_offset = adj_col_index % usize::cast_from(u8::BITS);
1119 (self.mask[self.mask.len() - byte_offset - 1] >> bit_offset) & 1 == 1
1120 }
1121}
1122
/// Decodes rows of a SQL Server table, one column at a time.
///
/// Created with [`SqlServerRowDecoder::try_new`].
#[derive(Debug)]
pub struct SqlServerRowDecoder {
    /// For every column to decode: its name, the type it is decoded into, and how its value
    /// is read from a `tiberius::Row`.
    decoders: Vec<(Arc<str>, SqlColumnType, SqlServerColumnDecodeType)>,
}
1131
1132impl SqlServerRowDecoder {
1133 pub fn try_new(
1137 table: &SqlServerTableDesc,
1138 desc: &RelationDesc,
1139 ) -> Result<Self, SqlServerError> {
1140 let decoders = desc
1141 .iter()
1142 .map(|(col_name, col_type)| {
1143 let sql_server_col = table
1144 .columns
1145 .iter()
1146 .find(|col| col.name.as_ref() == col_name.as_str())
1147 .ok_or_else(|| {
1148 anyhow::anyhow!("no SQL Server column with name {col_name} found")
1150 })?;
1151 let Some(sql_server_col_typ) = sql_server_col.column_type.as_ref() else {
1152 return Err(SqlServerError::ProgrammingError(format!(
1153 "programming error, {col_name} should have been exluded",
1154 )));
1155 };
1156
1157 let matches = match (&sql_server_col_typ.scalar_type, &col_type.scalar_type) {
1165 (SqlScalarType::Timestamp { .. }, SqlScalarType::Timestamp { .. })
1166 | (SqlScalarType::TimestampTz { .. }, SqlScalarType::TimestampTz { .. }) => {
1167 sql_server_col_typ.nullable == col_type.nullable
1169 }
1170 (_, _) => sql_server_col_typ == col_type,
1171 };
1172 if !matches {
1173 return Err(SqlServerError::ProgrammingError(format!(
1174 "programming error, {col_name} has mismatched type {:?} vs {:?}",
1175 sql_server_col.column_type, col_type
1176 )));
1177 }
1178
1179 let name = Arc::clone(&sql_server_col.name);
1180 let decoder = sql_server_col.decode_type.clone();
1181 let col_typ = sql_server_col_typ.clone();
1186
1187 Ok::<_, SqlServerError>((name, col_typ, decoder))
1188 })
1189 .collect::<Result<_, _>>()?;
1190
1191 Ok(SqlServerRowDecoder { decoders })
1192 }
1193
1194 pub fn decode(
1201 &self,
1202 data: &tiberius::Row,
1203 row: &mut Row,
1204 arena: &RowArena,
1205 new_data: Option<&tiberius::Row>,
1206 ) -> Result<(), SqlServerDecodeError> {
1207 let mut packer = row.packer();
1208
1209 for (col_name, col_type, decoder) in &self.decoders {
1210 let datum = decoder.decode(data, col_name, col_type, arena)?;
1211
1212 let datum = if let Some(new_data) = new_data
1213 && matches!(
1214 col_type.scalar_type,
1215 SqlScalarType::VarChar { max_length: None }
1216 )
1217 && matches!(datum, Datum::Null)
1218 {
1219 let update_mask = UpdateMask::try_from(new_data)?;
1220 let col_index = new_data
1221 .columns()
1222 .iter()
1223 .position(|c| c.name() == col_name.as_ref())
1224 .expect("column exists");
1225 if !update_mask.data_col_updated(col_index) {
1230 decoder.decode(new_data, col_name, col_type, arena)?
1231 } else {
1232 datum
1233 }
1234 } else {
1235 datum
1236 };
1237
1238 packer.push(datum);
1239 }
1240 Ok(())
1241 }
1242
1243 pub fn included_column_names(&self) -> Vec<Arc<str>> {
1244 self.decoders
1245 .iter()
1246 .map(|decoder| Arc::clone(&decoder.0))
1247 .collect()
1248 }
1249}
1250
1251#[cfg(test)]
1252mod tests {
1253 use std::collections::BTreeSet;
1254 use std::sync::Arc;
1255
1256 use chrono::NaiveDateTime;
1257 use itertools::Itertools;
1258 use mz_ore::assert_contains;
1259 use mz_ore::collections::CollectionExt;
1260 use mz_repr::adt::numeric::NumericMaxScale;
1261 use mz_repr::adt::varchar::VarCharMaxLength;
1262 use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlScalarType};
1263 use tiberius::RowTestExt;
1264
1265 use crate::desc::{
1266 SqlServerCaptureInstanceRaw, SqlServerColumnDecodeType, SqlServerColumnDesc,
1267 SqlServerTableDesc, SqlServerTableRaw, tiberius_numeric_to_mz_numeric,
1268 };
1269
1270 use super::SqlServerColumnRaw;
1271
1272 impl SqlServerColumnRaw {
1273 fn new(name: &str, data_type: &str) -> Self {
1276 SqlServerColumnRaw {
1277 name: name.into(),
1278 data_type: data_type.into(),
1279 is_nullable: false,
1280 max_length: 0,
1281 precision: 0,
1282 scale: 0,
1283 is_computed: false,
1284 }
1285 }
1286
1287 fn nullable(mut self, nullable: bool) -> Self {
1288 self.is_nullable = nullable;
1289 self
1290 }
1291
1292 fn max_length(mut self, max_length: i16) -> Self {
1293 self.max_length = max_length;
1294 self
1295 }
1296
1297 fn precision(mut self, precision: u8) -> Self {
1298 self.precision = precision;
1299 self
1300 }
1301
1302 fn scale(mut self, scale: u8) -> Self {
1303 self.scale = scale;
1304 self
1305 }
1306 }
1307
1308 #[mz_ore::test]
1309 fn smoketest_column_raw() {
1310 let raw = SqlServerColumnRaw::new("foo", "bit");
1311 let col = SqlServerColumnDesc::new(&raw);
1312
1313 assert_eq!(&*col.name, "foo");
1314 assert_eq!(col.column_type, Some(SqlScalarType::Bool.nullable(false)));
1315 assert_eq!(col.decode_type, SqlServerColumnDecodeType::Bool);
1316
1317 let raw = SqlServerColumnRaw::new("foo", "decimal")
1318 .precision(20)
1319 .scale(10);
1320 let col = SqlServerColumnDesc::new(&raw);
1321
1322 let col_type = SqlScalarType::Numeric {
1323 max_scale: Some(NumericMaxScale::try_from(10i64).expect("known valid")),
1324 }
1325 .nullable(false);
1326 assert_eq!(col.column_type, Some(col_type));
1327 assert_eq!(col.decode_type, SqlServerColumnDecodeType::Numeric);
1328 }
1329
1330 #[mz_ore::test]
1331 fn smoketest_column_raw_invalid() {
1332 let raw = SqlServerColumnRaw::new("foo", "bad_data_type");
1333 let desc = SqlServerColumnDesc::new(&raw);
1334 let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1335 panic!("unexpected decode type {desc:?}");
1336 };
1337 assert_contains!(context, "'bad_data_type' is unimplemented");
1338
1339 let raw = SqlServerColumnRaw::new("foo", "decimal")
1340 .precision(100)
1341 .scale(10);
1342 let desc = SqlServerColumnDesc::new(&raw);
1343 assert!(matches!(
1344 desc.decode_type,
1345 SqlServerColumnDecodeType::Unsupported { .. }
1346 ));
1347
1348 let raw = SqlServerColumnRaw::new("foo", "varbinary").max_length(-1);
1349 let desc = SqlServerColumnDesc::new(&raw);
1350 let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1351 panic!("unexpected decode type {desc:?}");
1352 };
1353 assert_contains!(context, "columns with unlimited size do not support CDC");
1354 }
1355
1356 #[mz_ore::test]
1357 fn smoketest_decoder() {
1358 let sql_server_columns = [
1359 SqlServerColumnRaw::new("a", "varchar").max_length(16),
1360 SqlServerColumnRaw::new("b", "int").nullable(true),
1361 SqlServerColumnRaw::new("c", "bit"),
1362 ];
1363 let sql_server_desc = SqlServerTableRaw {
1364 schema_name: "my_schema".into(),
1365 name: "my_table".into(),
1366 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1367 name: "my_table_CT".into(),
1368 create_date: NaiveDateTime::parse_from_str(
1369 "2024-01-01 00:00:00",
1370 "%Y-%m-%d %H:%M:%S",
1371 )
1372 .unwrap()
1373 .into(),
1374 }),
1375 columns: sql_server_columns.into(),
1376 };
1377 let sql_server_desc = SqlServerTableDesc::new(sql_server_desc, vec![]).unwrap();
1378
1379 let max_length = Some(VarCharMaxLength::try_from(16).unwrap());
1380 let relation_desc = RelationDesc::builder()
1381 .with_column("a", SqlScalarType::VarChar { max_length }.nullable(false))
1382 .with_column("c", SqlScalarType::Bool.nullable(false))
1384 .with_column("b", SqlScalarType::Int32.nullable(true))
1385 .finish();
1386
1387 let decoder = sql_server_desc
1389 .decoder(&relation_desc)
1390 .expect("known valid");
1391
1392 let sql_server_columns = [
1393 tiberius::Column::new("a".to_string(), tiberius::ColumnType::BigVarChar),
1394 tiberius::Column::new("b".to_string(), tiberius::ColumnType::Int4),
1395 tiberius::Column::new("c".to_string(), tiberius::ColumnType::Bit),
1396 ];
1397
1398 let data_a = [
1399 tiberius::ColumnData::String(Some("hello world".into())),
1400 tiberius::ColumnData::I32(Some(42)),
1401 tiberius::ColumnData::Bit(Some(true)),
1402 ];
1403 let sql_server_row_a = tiberius::Row::build(
1404 sql_server_columns
1405 .iter()
1406 .cloned()
1407 .zip_eq(data_a.into_iter()),
1408 );
1409
1410 let data_b = [
1411 tiberius::ColumnData::String(Some("foo bar".into())),
1412 tiberius::ColumnData::I32(None),
1413 tiberius::ColumnData::Bit(Some(false)),
1414 ];
1415 let sql_server_row_b =
1416 tiberius::Row::build(sql_server_columns.into_iter().zip_eq(data_b.into_iter()));
1417
1418 let mut rnd_row = Row::default();
1419 let arena = RowArena::default();
1420
1421 decoder
1422 .decode(&sql_server_row_a, &mut rnd_row, &arena, None)
1423 .unwrap();
1424 assert_eq!(
1425 &rnd_row,
1426 &Row::pack_slice(&[Datum::String("hello world"), Datum::True, Datum::Int32(42)])
1427 );
1428
1429 decoder
1430 .decode(&sql_server_row_b, &mut rnd_row, &arena, None)
1431 .unwrap();
1432 assert_eq!(
1433 &rnd_row,
1434 &Row::pack_slice(&[Datum::String("foo bar"), Datum::False, Datum::Null])
1435 );
1436 }
1437
1438 #[mz_ore::test]
1439 fn smoketest_decode_to_string() {
1440 #[track_caller]
1441 fn testcase(
1442 data_type: &'static str,
1443 col_type: tiberius::ColumnType,
1444 col_data: tiberius::ColumnData<'static>,
1445 ) {
1446 let columns = [SqlServerColumnRaw::new("a", data_type)];
1447 let sql_server_desc = SqlServerTableRaw {
1448 schema_name: "my_schema".into(),
1449 name: "my_table".into(),
1450 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1451 name: "my_table_CT".into(),
1452 create_date: NaiveDateTime::parse_from_str(
1453 "2024-01-01 00:00:00",
1454 "%Y-%m-%d %H:%M:%S",
1455 )
1456 .unwrap()
1457 .into(),
1458 }),
1459 columns: columns.into(),
1460 };
1461 let mut sql_server_desc = SqlServerTableDesc::new(sql_server_desc, vec![]).unwrap();
1462 sql_server_desc.apply_text_columns(&BTreeSet::from(["a"]));
1463
1464 let relation_desc = RelationDesc::builder()
1466 .with_column("a", SqlScalarType::String.nullable(false))
1467 .finish();
1468
1469 let decoder = sql_server_desc
1471 .decoder(&relation_desc)
1472 .expect("known valid");
1473
1474 let sql_server_row = tiberius::Row::build([(
1475 tiberius::Column::new("a".to_string(), col_type),
1476 col_data,
1477 )]);
1478 let mut mz_row = Row::default();
1479 let arena = RowArena::new();
1480 decoder
1481 .decode(&sql_server_row, &mut mz_row, &arena, None)
1482 .unwrap();
1483
1484 let str_datum = mz_row.into_element();
1485 assert!(matches!(str_datum, Datum::String(_)));
1486 }
1487
1488 use tiberius::ColumnData;
1489
1490 testcase(
1491 "bit",
1492 tiberius::ColumnType::Bit,
1493 ColumnData::Bit(Some(true)),
1494 );
1495 testcase(
1496 "bit",
1497 tiberius::ColumnType::Bit,
1498 ColumnData::Bit(Some(false)),
1499 );
1500 testcase(
1501 "tinyint",
1502 tiberius::ColumnType::Int1,
1503 ColumnData::U8(Some(33)),
1504 );
1505 testcase(
1506 "smallint",
1507 tiberius::ColumnType::Int2,
1508 ColumnData::I16(Some(101)),
1509 );
1510 testcase(
1511 "int",
1512 tiberius::ColumnType::Int4,
1513 ColumnData::I32(Some(-42)),
1514 );
1515 {
1516 let datetime = tiberius::time::DateTime::new(10, 300);
1517 testcase(
1518 "datetime",
1519 tiberius::ColumnType::Datetime,
1520 ColumnData::DateTime(Some(datetime)),
1521 );
1522 }
1523 }
1524
1525 #[mz_ore::test]
1526 #[cfg_attr(miri, ignore)] fn smoketest_numeric_conversion() {
1528 let a = tiberius::numeric::Numeric::new_with_scale(12345, 2);
1529 let rnd = tiberius_numeric_to_mz_numeric(a);
1530 let og = mz_repr::adt::numeric::cx_datum().parse("123.45").unwrap();
1531 assert_eq!(og, rnd);
1532
1533 let a = tiberius::numeric::Numeric::new_with_scale(-99999, 5);
1534 let rnd = tiberius_numeric_to_mz_numeric(a);
1535 let og = mz_repr::adt::numeric::cx_datum().parse("-.99999").unwrap();
1536 assert_eq!(og, rnd);
1537
1538 let a = tiberius::numeric::Numeric::new_with_scale(1, 29);
1539 let rnd = tiberius_numeric_to_mz_numeric(a);
1540 let og = mz_repr::adt::numeric::cx_datum()
1541 .parse("0.00000000000000000000000000001")
1542 .unwrap();
1543 assert_eq!(og, rnd);
1544
1545 let a = tiberius::numeric::Numeric::new_with_scale(-111111111111111111, 0);
1546 let rnd = tiberius_numeric_to_mz_numeric(a);
1547 let og = mz_repr::adt::numeric::cx_datum()
1548 .parse("-111111111111111111")
1549 .unwrap();
1550 assert_eq!(og, rnd);
1551 }
1552
1553 }