1use base64::Engine;
26use chrono::{NaiveDateTime, SubsecRound};
27use dec::OrderedDecimal;
28use mz_ore::cast::CastFrom;
29use mz_proto::{IntoRustIfSome, ProtoType, RustType};
30use mz_repr::adt::char::CharLength;
31use mz_repr::adt::numeric::{Numeric, NumericMaxScale};
32use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampPrecision};
33use mz_repr::adt::varchar::VarCharMaxLength;
34use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlColumnType, SqlScalarType};
35#[cfg(any(test, feature = "proptest"))]
36use proptest_derive::Arbitrary;
37use serde::{Deserialize, Serialize};
38
39use std::collections::BTreeSet;
40use std::sync::Arc;
41
42use crate::desc::proto_sql_server_table_constraint::ConstraintType;
43use crate::{SqlServerDecodeError, SqlServerError};
44
45include!(concat!(env!("OUT_DIR"), "/mz_sql_server_util.rs"));
46
/// Materialize's description of a SQL Server table.
///
/// Built from a [`SqlServerTableRaw`] and its raw constraints by
/// [`SqlServerTableDesc::new`], or decoded from its protobuf form. Columns whose
/// `column_type` is `None` are excluded (see [`SqlServerColumnDesc::is_excluded`]).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
pub struct SqlServerTableDesc {
    /// Name of the schema that contains the table.
    pub schema_name: Arc<str>,
    /// Name of the table within `schema_name`.
    pub name: Arc<str>,
    /// Descriptions of the table's columns, in the order of the raw description.
    pub columns: Box<[SqlServerColumnDesc]>,
    /// `PRIMARY KEY` and `UNIQUE` constraints defined on the table.
    pub constraints: Vec<SqlServerTableConstraint>,
}
68
69impl SqlServerTableDesc {
70 pub fn new(
75 raw: SqlServerTableRaw,
76 raw_constraints: Vec<SqlServerTableConstraintRaw>,
77 ) -> Result<Self, SqlServerError> {
78 let columns: Box<[_]> = raw
79 .columns
80 .into_iter()
81 .map(SqlServerColumnDesc::new)
82 .collect();
83 let constraints = raw_constraints
84 .into_iter()
85 .map(SqlServerTableConstraint::try_from)
86 .collect::<Result<Vec<_>, _>>()?;
87 Ok(SqlServerTableDesc {
88 schema_name: raw.schema_name,
89 name: raw.name,
90 columns,
91 constraints,
92 })
93 }
94
95 pub fn qualified_name(&self) -> SqlServerQualifiedTableName {
97 SqlServerQualifiedTableName {
98 schema_name: Arc::clone(&self.schema_name),
99 table_name: Arc::clone(&self.name),
100 }
101 }
102
103 pub fn apply_text_columns(&mut self, text_columns: &BTreeSet<&str>) {
106 for column in &mut self.columns {
107 if text_columns.contains(column.name.as_ref()) {
108 column.represent_as_text();
109 }
110 }
111 }
112
113 pub fn apply_excl_columns(&mut self, excl_columns: &BTreeSet<&str>) {
116 for column in &mut self.columns {
117 if excl_columns.contains(column.name.as_ref()) {
118 column.exclude();
119 }
120 }
121 }
122
123 pub fn decoder(&self, desc: &RelationDesc) -> Result<SqlServerRowDecoder, SqlServerError> {
126 let decoder = SqlServerRowDecoder::try_new(self, desc)?;
127 Ok(decoder)
128 }
129}
130
131impl RustType<ProtoSqlServerTableDesc> for SqlServerTableDesc {
132 fn into_proto(&self) -> ProtoSqlServerTableDesc {
133 ProtoSqlServerTableDesc {
134 name: self.name.to_string(),
135 schema_name: self.schema_name.to_string(),
136 columns: self.columns.iter().map(|c| c.into_proto()).collect(),
137 constraints: self.constraints.iter().map(|c| c.into_proto()).collect(),
138 }
139 }
140
141 fn from_proto(proto: ProtoSqlServerTableDesc) -> Result<Self, mz_proto::TryFromProtoError> {
142 let columns = proto
143 .columns
144 .into_iter()
145 .map(|c| c.into_rust())
146 .collect::<Result<_, _>>()?;
147 let constraints = proto
148 .constraints
149 .into_iter()
150 .map(|c| c.into_rust())
151 .collect::<Result<_, _>>()?;
152 Ok(SqlServerTableDesc {
153 schema_name: proto.schema_name.into(),
154 name: proto.name.into(),
155 columns,
156 constraints,
157 })
158 }
159}
160
/// Kind of a SQL Server table constraint that is tracked.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
pub enum SqlServerTableConstraintType {
    /// A `PRIMARY KEY` constraint.
    PrimaryKey,
    /// A `UNIQUE` constraint.
    Unique,
}
169
170impl TryFrom<String> for SqlServerTableConstraintType {
171 type Error = SqlServerError;
172
173 fn try_from(value: String) -> Result<Self, Self::Error> {
174 match value.as_str() {
175 "PRIMARY KEY" => Ok(Self::PrimaryKey),
176 "UNIQUE" => Ok(Self::Unique),
177 name => Err(SqlServerError::InvalidData {
178 column_name: "constraint_type".into(),
179 error: format!("Unknown constraint type: {name}"),
180 }),
181 }
182 }
183}
184
185impl RustType<proto_sql_server_table_constraint::ConstraintType> for SqlServerTableConstraintType {
186 fn into_proto(&self) -> proto_sql_server_table_constraint::ConstraintType {
187 match self {
188 SqlServerTableConstraintType::PrimaryKey => ConstraintType::PrimaryKey(()),
189 SqlServerTableConstraintType::Unique => ConstraintType::Unique(()),
190 }
191 }
192
193 fn from_proto(
194 proto: proto_sql_server_table_constraint::ConstraintType,
195 ) -> Result<Self, mz_proto::TryFromProtoError> {
196 Ok(match proto {
197 ConstraintType::PrimaryKey(_) => SqlServerTableConstraintType::PrimaryKey,
198 ConstraintType::Unique(_) => SqlServerTableConstraintType::Unique,
199 })
200 }
201}
202
/// A `PRIMARY KEY` or `UNIQUE` constraint defined on a SQL Server table.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
pub struct SqlServerTableConstraint {
    /// Name of the constraint.
    pub constraint_name: String,
    /// Kind of the constraint.
    pub constraint_type: SqlServerTableConstraintType,
    /// Names of the columns covered by the constraint.
    pub column_names: Vec<String>,
}
211
212impl TryFrom<SqlServerTableConstraintRaw> for SqlServerTableConstraint {
213 type Error = SqlServerError;
214
215 fn try_from(value: SqlServerTableConstraintRaw) -> Result<Self, Self::Error> {
216 Ok(SqlServerTableConstraint {
217 constraint_name: value.constraint_name,
218 constraint_type: value.constraint_type.try_into()?,
219 column_names: value.columns,
220 })
221 }
222}
223
224impl RustType<ProtoSqlServerTableConstraint> for SqlServerTableConstraint {
225 fn into_proto(&self) -> ProtoSqlServerTableConstraint {
226 ProtoSqlServerTableConstraint {
227 constraint_name: self.constraint_name.clone(),
228 constraint_type: Some(self.constraint_type.into_proto()),
229 column_names: self.column_names.clone(),
230 }
231 }
232
233 fn from_proto(
234 proto: ProtoSqlServerTableConstraint,
235 ) -> Result<Self, mz_proto::TryFromProtoError> {
236 Ok(SqlServerTableConstraint {
237 constraint_name: proto.constraint_name,
238 constraint_type: proto
239 .constraint_type
240 .into_rust_if_some("ProtoSqlServerTableConstraint::constraint_type")?,
241 column_names: proto.column_names,
242 })
243 }
244}
245
/// A schema-qualified SQL Server table name.
///
/// The derived ordering compares `schema_name` first, then `table_name`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SqlServerQualifiedTableName {
    /// Schema containing the table.
    pub schema_name: Arc<str>,
    /// Name of the table within the schema.
    pub table_name: Arc<str>,
}
254
255impl ToString for SqlServerQualifiedTableName {
256 fn to_string(&self) -> String {
257 format!(
258 "{}.{}",
259 crate::quote_identifier(&self.schema_name),
260 crate::quote_identifier(&self.table_name)
261 )
262 }
263}
264
/// Raw, unparsed description of a SQL Server table.
///
/// Converted into a [`SqlServerTableDesc`] by [`SqlServerTableDesc::new`].
#[derive(Debug, Clone)]
pub struct SqlServerTableRaw {
    /// Name of the schema that contains the table.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// Capture instance associated with the table.
    // NOTE(review): presumably the CDC capture instance changes are read from; confirm.
    pub capture_instance: Arc<SqlServerCaptureInstanceRaw>,
    /// Raw descriptions of the table's columns.
    pub columns: Arc<[SqlServerColumnRaw]>,
}
280
/// Raw description of a capture instance.
// NOTE(review): presumably a SQL Server CDC capture instance; confirm.
#[derive(Debug, Clone)]
pub struct SqlServerCaptureInstanceRaw {
    /// Name of the capture instance.
    pub name: Arc<str>,
    /// Creation date of the capture instance, as a timestamp without time zone.
    pub create_date: Arc<NaiveDateTime>,
}
289
/// Materialize's description of a single SQL Server column.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
pub struct SqlServerColumnDesc {
    /// Column name.
    pub name: Arc<str>,
    /// Materialize type of the column. `None` means the column is excluded,
    /// either explicitly (`exclude`) or because its SQL Server type is
    /// unsupported.
    pub column_type: Option<SqlColumnType>,
    /// Presumably the name of the primary key constraint this column belongs
    /// to, if any.
    // NOTE(review): `SqlServerColumnDesc::new` always sets this to `None`; confirm where it is populated.
    pub primary_key_constraint: Option<Arc<str>>,
    /// How values of this column are read from a `tiberius::Row`.
    pub decode_type: SqlServerColumnDecodeType,
    /// The raw SQL Server data type name, copied from `SqlServerColumnRaw::data_type`.
    pub raw_type: Arc<str>,
}
312
313impl SqlServerColumnDesc {
314 pub fn new(raw: &SqlServerColumnRaw) -> Self {
316 let (column_type, decode_type) = match parse_data_type(raw) {
317 Ok((scalar_type, decode_type)) => {
318 let column_type = scalar_type.nullable(raw.is_nullable);
319 (Some(column_type), decode_type)
320 }
321 Err(err) => {
322 tracing::warn!(
323 ?err,
324 ?raw,
325 "found an unsupported data type when parsing raw data"
326 );
327 (
328 None,
329 SqlServerColumnDecodeType::Unsupported {
330 context: err.reason,
331 },
332 )
333 }
334 };
335 SqlServerColumnDesc {
336 name: Arc::clone(&raw.name),
337 primary_key_constraint: None,
338 column_type,
339 decode_type,
340 raw_type: Arc::clone(&raw.data_type),
341 }
342 }
343
344 pub fn represent_as_text(&mut self) {
346 self.column_type = self
347 .column_type
348 .as_ref()
349 .map(|ct| SqlScalarType::String.nullable(ct.nullable));
350 }
351
352 pub fn exclude(&mut self) {
354 self.column_type = None;
355 }
356
357 pub fn is_excluded(&self) -> bool {
359 self.column_type.is_none()
360 }
361}
362
363impl RustType<ProtoSqlServerColumnDesc> for SqlServerColumnDesc {
364 fn into_proto(&self) -> ProtoSqlServerColumnDesc {
365 ProtoSqlServerColumnDesc {
366 name: self.name.to_string(),
367 column_type: self.column_type.into_proto(),
368 primary_key_constraint: self.primary_key_constraint.as_ref().map(|v| v.to_string()),
369 decode_type: Some(self.decode_type.into_proto()),
370 raw_type: self.raw_type.to_string(),
371 }
372 }
373
374 fn from_proto(proto: ProtoSqlServerColumnDesc) -> Result<Self, mz_proto::TryFromProtoError> {
375 Ok(SqlServerColumnDesc {
376 name: proto.name.into(),
377 column_type: proto.column_type.into_rust()?,
378 primary_key_constraint: proto.primary_key_constraint.map(|v| v.into()),
379 decode_type: proto
380 .decode_type
381 .into_rust_if_some("ProtoSqlServerColumnDesc::decode_type")?,
382 raw_type: proto.raw_type.into(),
383 })
384 }
385}
386
/// Explains why a SQL Server column's data type could not be mapped to a
/// Materialize type by `parse_data_type`.
#[derive(Debug)]
// `column_name` and `column_type` are only surfaced through the derived `Debug`
// impl (e.g. when the error is logged), which the dead-code lint does not count
// as a read.
#[allow(dead_code)]
pub struct UnsupportedDataType {
    /// Name of the offending column.
    column_name: String,
    /// The column's SQL Server data type, as reported in the error.
    column_type: String,
    /// Human-readable reason the type is unsupported.
    reason: String,
}
395
396fn parse_data_type(
401 raw: &SqlServerColumnRaw,
402) -> Result<(SqlScalarType, SqlServerColumnDecodeType), UnsupportedDataType> {
403 if raw.is_computed {
407 return Err(UnsupportedDataType {
408 column_name: raw.name.to_string(),
409 column_type: format!("{} (computed)", raw.data_type.to_lowercase()),
410 reason: "column is computed".into(),
411 });
412 }
413
414 let scalar =
415 match raw.data_type.to_lowercase().as_str() {
416 "tinyint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::U8),
417 "smallint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::I16),
418 "int" => (SqlScalarType::Int32, SqlServerColumnDecodeType::I32),
419 "bigint" => (SqlScalarType::Int64, SqlServerColumnDecodeType::I64),
420 "bit" => (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool),
421 "decimal" | "numeric" | "money" | "smallmoney" => {
422 if raw.precision > 38 || raw.scale > raw.precision {
429 tracing::warn!(
430 "unexpected value from SQL Server, precision of {} and scale of {}",
431 raw.precision,
432 raw.scale,
433 );
434 }
435 if raw.precision > 39 {
436 let reason = format!(
437 "precision of {} is greater than our maximum of 39",
438 raw.precision
439 );
440 return Err(UnsupportedDataType {
441 column_name: raw.name.to_string(),
442 column_type: raw.data_type.to_string(),
443 reason,
444 });
445 }
446
447 let raw_scale = usize::cast_from(raw.scale);
448 let max_scale =
449 NumericMaxScale::try_from(raw_scale).map_err(|_| UnsupportedDataType {
450 column_type: raw.data_type.to_string(),
451 column_name: raw.name.to_string(),
452 reason: format!("scale of {} is too large", raw.scale),
453 })?;
454 let column_type = SqlScalarType::Numeric {
455 max_scale: Some(max_scale),
456 };
457
458 (column_type, SqlServerColumnDecodeType::Numeric)
459 }
460 "real" | "float" | "double precision" => match raw.max_length {
468 4 => (SqlScalarType::Float32, SqlServerColumnDecodeType::F32),
471 8 => (SqlScalarType::Float64, SqlServerColumnDecodeType::F64),
472 _ => {
473 return Err(UnsupportedDataType {
474 column_name: raw.name.to_string(),
475 column_type: raw.data_type.to_string(),
476 reason: format!("unsupported length {}", raw.max_length),
477 });
478 }
479 },
480 dt @ ("char" | "nchar" | "sysname") => {
481 if raw.max_length == -1 {
484 return Err(UnsupportedDataType {
485 column_name: raw.name.to_string(),
486 column_type: raw.data_type.to_string(),
487 reason: "columns with unlimited size do not support CDC".to_string(),
488 });
489 }
490
491 let column_type = match dt {
492 "char" => {
493 let length =
494 if raw.max_length != -1 {
495 let length = CharLength::try_from(i64::from(raw.max_length))
496 .map_err(|e| UnsupportedDataType {
497 column_name: raw.name.to_string(),
498 column_type: raw.data_type.to_string(),
499 reason: e.to_string(),
500 })?;
501 Some(length)
502 } else {
503 None
504 };
505 SqlScalarType::Char { length }
506 }
507 "nchar" | "sysname" => SqlScalarType::String,
511 other => unreachable!("'{other}' checked above"),
512 };
513
514 (column_type, SqlServerColumnDecodeType::String)
515 }
516 "varchar" | "nvarchar" => {
517 let max_length =
528 if raw.max_length != -1 {
529 let length = VarCharMaxLength::try_from(i64::from(raw.max_length))
530 .map_err(|e| UnsupportedDataType {
531 column_name: raw.name.to_string(),
532 column_type: raw.data_type.to_string(),
533 reason: e.to_string(),
534 })?;
535 Some(length)
536 } else {
537 None
538 };
539 let column_type = SqlScalarType::VarChar { max_length };
540 (column_type, SqlServerColumnDecodeType::String)
541 }
542 "text" | "ntext" | "image" => {
543 mz_ore::soft_assert_eq_no_log!(raw.max_length, 16);
546
547 return Err(UnsupportedDataType {
549 column_name: raw.name.to_string(),
550 column_type: raw.data_type.to_string(),
551 reason: "columns with unlimited size do not support CDC".to_string(),
552 });
553 }
554 "xml" => {
555 if raw.max_length == -1 {
560 return Err(UnsupportedDataType {
561 column_name: raw.name.to_string(),
562 column_type: raw.data_type.to_string(),
563 reason: "columns with unlimited size do not support CDC".to_string(),
564 });
565 }
566 (SqlScalarType::String, SqlServerColumnDecodeType::Xml)
567 }
568 "binary" | "varbinary" => {
569 if raw.max_length == -1 {
574 return Err(UnsupportedDataType {
575 column_name: raw.name.to_string(),
576 column_type: raw.data_type.to_string(),
577 reason: "columns with unlimited size do not support CDC".to_string(),
578 });
579 }
580 (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes)
581 }
582 "json" => (SqlScalarType::Jsonb, SqlServerColumnDecodeType::String),
583 "date" => (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate),
584 "time" => (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime),
597 dt @ ("smalldatetime" | "datetime" | "datetime2" | "datetimeoffset") => {
598 if raw.scale > 7 {
599 tracing::warn!("unexpected scale '{}' from SQL Server", raw.scale);
600 }
601 if raw.scale > mz_repr::adt::timestamp::MAX_PRECISION {
602 tracing::warn!("truncating scale of '{}' for '{}'", raw.scale, dt);
603 }
604 let precision = std::cmp::min(raw.scale, mz_repr::adt::timestamp::MAX_PRECISION);
605 let precision =
606 Some(TimestampPrecision::try_from(i64::from(precision)).expect("known to fit"));
607
608 match dt {
609 "smalldatetime" | "datetime" | "datetime2" => (
610 SqlScalarType::Timestamp { precision },
611 SqlServerColumnDecodeType::NaiveDateTime,
612 ),
613 "datetimeoffset" => (
614 SqlScalarType::TimestampTz { precision },
615 SqlServerColumnDecodeType::DateTime,
616 ),
617 other => unreachable!("'{other}' checked above"),
618 }
619 }
620 "uniqueidentifier" => (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid),
621 other => {
634 return Err(UnsupportedDataType {
635 column_type: other.to_string(),
636 column_name: raw.name.to_string(),
637 reason: format!("'{other}' is unimplemented"),
638 });
639 }
640 };
641 Ok(scalar)
642}
643
/// Raw metadata for a single SQL Server column, before it is mapped to a
/// Materialize type by [`SqlServerColumnDesc::new`].
#[derive(Clone, Debug)]
pub struct SqlServerColumnRaw {
    /// Column name.
    pub name: Arc<str>,
    /// SQL Server data type name (e.g. `int`, `nvarchar`). Matched
    /// case-insensitively when parsing.
    pub data_type: Arc<str>,
    /// Whether the column accepts NULL.
    pub is_nullable: bool,
    /// Declared maximum length. `-1` is treated as unlimited size. For
    /// floating-point types, `4` maps to `Float32` and `8` to `Float64`.
    pub max_length: i16,
    /// Declared precision (used for numeric types).
    pub precision: u8,
    /// Declared scale. For numeric types this is the number of fractional
    /// digits. For `smalldatetime`/`datetime`/`datetime2`/`datetimeoffset` it
    /// is the fractional-second precision.
    pub scale: u8,
    /// Whether the column is computed. Computed columns are not supported.
    pub is_computed: bool,
}
672
/// Raw description of a table constraint, before conversion into a
/// [`SqlServerTableConstraint`].
#[derive(Clone, Debug)]
pub struct SqlServerTableConstraintRaw {
    /// Name of the constraint.
    pub constraint_name: String,
    /// Constraint type. Only `"PRIMARY KEY"` and `"UNIQUE"` are accepted on
    /// conversion.
    pub constraint_type: String,
    /// Names of the columns covered by the constraint.
    pub columns: Vec<String>,
}
680
/// How the value of a SQL Server column is read out of a `tiberius::Row`,
/// i.e. which Rust type it is requested as before being converted into a
/// [`Datum`] by [`SqlServerColumnDecodeType::decode`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
pub enum SqlServerColumnDecodeType {
    /// Read as `bool` (`bit` columns).
    Bool,
    /// Read as `u8` (`tinyint` columns) and widened to `Int16`.
    U8,
    /// Read as `i16` (`smallint` columns).
    I16,
    /// Read as `i32` (`int` columns).
    I32,
    /// Read as `i64` (`bigint` columns).
    I64,
    /// Read as `f32` (4-byte floating-point columns).
    F32,
    /// Read as `f64` (8-byte floating-point columns).
    F64,
    /// Read as `&str` (character and `json` columns).
    String,
    /// Read as `&[u8]` (`binary`/`varbinary` columns).
    Bytes,
    /// Read as `uuid::Uuid` (`uniqueidentifier` columns).
    Uuid,
    /// Read as `tiberius::numeric::Numeric` (`decimal`, `numeric`, `money`
    /// and `smallmoney` columns).
    Numeric,
    /// Read as `tiberius::xml::XmlData` (`xml` columns).
    Xml,
    /// Read as `chrono::NaiveDate` (`date` columns).
    NaiveDate,
    /// Read as `chrono::NaiveTime` (`time` columns).
    NaiveTime,
    /// Read as `chrono::DateTime<chrono::Utc>` (`datetimeoffset` columns).
    DateTime,
    /// Read as `chrono::NaiveDateTime` (`smalldatetime`, `datetime` and
    /// `datetime2` columns).
    NaiveDateTime,
    /// The column's type is unsupported. `decode` has no rule for it and
    /// returns an error.
    Unsupported {
        /// Why the type is unsupported.
        context: String,
    },
}
714
715impl SqlServerColumnDecodeType {
716 pub fn decode<'a>(
718 &self,
719 data: &'a tiberius::Row,
720 name: &'a str,
721 column: &'a SqlColumnType,
722 arena: &'a RowArena,
723 ) -> Result<Datum<'a>, SqlServerDecodeError> {
724 let maybe_datum = match (&column.scalar_type, self) {
725 (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool) => data
726 .try_get(name)
727 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool"))?
728 .map(|val: bool| if val { Datum::True } else { Datum::False }),
729 (SqlScalarType::Int16, SqlServerColumnDecodeType::U8) => data
730 .try_get(name)
731 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8"))?
732 .map(|val: u8| Datum::Int16(i16::cast_from(val))),
733 (SqlScalarType::Int16, SqlServerColumnDecodeType::I16) => data
734 .try_get(name)
735 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16"))?
736 .map(Datum::Int16),
737 (SqlScalarType::Int32, SqlServerColumnDecodeType::I32) => data
738 .try_get(name)
739 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32"))?
740 .map(Datum::Int32),
741 (SqlScalarType::Int64, SqlServerColumnDecodeType::I64) => data
742 .try_get(name)
743 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64"))?
744 .map(Datum::Int64),
745 (SqlScalarType::Float32, SqlServerColumnDecodeType::F32) => data
746 .try_get(name)
747 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32"))?
748 .map(|val: f32| Datum::Float32(ordered_float::OrderedFloat(val))),
749 (SqlScalarType::Float64, SqlServerColumnDecodeType::F64) => data
750 .try_get(name)
751 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64"))?
752 .map(|val: f64| Datum::Float64(ordered_float::OrderedFloat(val))),
753 (SqlScalarType::String, SqlServerColumnDecodeType::String) => data
754 .try_get(name)
755 .map_err(|_| SqlServerDecodeError::invalid_column(name, "string"))?
756 .map(Datum::String),
757 (SqlScalarType::Char { length }, SqlServerColumnDecodeType::String) => data
758 .try_get(name)
759 .map_err(|_| SqlServerDecodeError::invalid_column(name, "char"))?
760 .map(|val: &str| match length {
761 Some(expected) => {
762 let found_chars = val.chars().count();
763 let expct_chars = usize::cast_from(expected.into_u32());
764 if found_chars != expct_chars {
765 Err(SqlServerDecodeError::invalid_char(
766 name,
767 expct_chars,
768 found_chars,
769 ))
770 } else {
771 Ok(Datum::String(val))
772 }
773 }
774 None => Ok(Datum::String(val)),
775 })
776 .transpose()?,
777 (SqlScalarType::VarChar { max_length }, SqlServerColumnDecodeType::String) => data
778 .try_get(name)
779 .map_err(|_| SqlServerDecodeError::invalid_column(name, "varchar"))?
780 .map(|val: &str| match max_length {
781 Some(max) => {
782 let found_chars = val.chars().count();
783 let max_chars = usize::cast_from(max.into_u32());
784 if found_chars > max_chars {
785 Err(SqlServerDecodeError::invalid_varchar(
786 name,
787 max_chars,
788 found_chars,
789 ))
790 } else {
791 Ok(Datum::String(val))
792 }
793 }
794 None => Ok(Datum::String(val)),
795 })
796 .transpose()?,
797 (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes) => data
798 .try_get(name)
799 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes"))?
800 .map(Datum::Bytes),
801 (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid) => data
802 .try_get(name)
803 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid"))?
804 .map(Datum::Uuid),
805 (SqlScalarType::Numeric { .. }, SqlServerColumnDecodeType::Numeric) => data
806 .try_get(name)
807 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric"))?
808 .map(|val: tiberius::numeric::Numeric| {
809 let numeric = tiberius_numeric_to_mz_numeric(val);
810 Datum::Numeric(OrderedDecimal(numeric))
811 }),
812 (SqlScalarType::String, SqlServerColumnDecodeType::Xml) => data
813 .try_get(name)
814 .map_err(|_| SqlServerDecodeError::invalid_column(name, "xml"))?
815 .map(|val: &tiberius::xml::XmlData| Datum::String(val.as_ref())),
816 (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate) => data
817 .try_get(name)
818 .map_err(|_| SqlServerDecodeError::invalid_column(name, "date"))?
819 .map(|val: chrono::NaiveDate| {
820 let date = val
821 .try_into()
822 .map_err(|e| SqlServerDecodeError::invalid_date(name, e))?;
823 Ok::<_, SqlServerDecodeError>(Datum::Date(date))
824 })
825 .transpose()?,
826 (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime) => data
827 .try_get(name)
828 .map_err(|_| SqlServerDecodeError::invalid_column(name, "time"))?
829 .map(|val: chrono::NaiveTime| {
830 let rounded = val.round_subsecs(6);
835 let val = if rounded < val {
837 val.trunc_subsecs(6)
838 } else {
839 val
840 };
841 Datum::Time(val)
842 }),
843 (SqlScalarType::Timestamp { precision }, SqlServerColumnDecodeType::NaiveDateTime) => {
844 data.try_get(name)
845 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamp"))?
846 .map(|val: chrono::NaiveDateTime| {
847 let ts: CheckedTimestamp<chrono::NaiveDateTime> = val
848 .try_into()
849 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
850 let rounded = ts
851 .round_to_precision(*precision)
852 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
853 Ok::<_, SqlServerDecodeError>(Datum::Timestamp(rounded))
854 })
855 .transpose()?
856 }
857 (SqlScalarType::TimestampTz { precision }, SqlServerColumnDecodeType::DateTime) => data
858 .try_get(name)
859 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamptz"))?
860 .map(|val: chrono::DateTime<chrono::Utc>| {
861 let ts: CheckedTimestamp<chrono::DateTime<chrono::Utc>> = val
862 .try_into()
863 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
864 let rounded = ts
865 .round_to_precision(*precision)
866 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
867 Ok::<_, SqlServerDecodeError>(Datum::TimestampTz(rounded))
868 })
869 .transpose()?,
870 (SqlScalarType::String, SqlServerColumnDecodeType::Bool) => data
872 .try_get(name)
873 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool-text"))?
874 .map(|val: bool| {
875 if val {
876 Datum::String("true")
877 } else {
878 Datum::String("false")
879 }
880 }),
881 (SqlScalarType::String, SqlServerColumnDecodeType::U8) => data
882 .try_get(name)
883 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8-text"))?
884 .map(|val: u8| {
885 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
886 }),
887 (SqlScalarType::String, SqlServerColumnDecodeType::I16) => data
888 .try_get(name)
889 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16-text"))?
890 .map(|val: i16| {
891 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
892 }),
893 (SqlScalarType::String, SqlServerColumnDecodeType::I32) => data
894 .try_get(name)
895 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32-text"))?
896 .map(|val: i32| {
897 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
898 }),
899 (SqlScalarType::String, SqlServerColumnDecodeType::I64) => data
900 .try_get(name)
901 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64-text"))?
902 .map(|val: i64| {
903 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
904 }),
905 (SqlScalarType::String, SqlServerColumnDecodeType::F32) => data
906 .try_get(name)
907 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32-text"))?
908 .map(|val: f32| {
909 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
910 }),
911 (SqlScalarType::String, SqlServerColumnDecodeType::F64) => data
912 .try_get(name)
913 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64-text"))?
914 .map(|val: f64| {
915 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
916 }),
917 (SqlScalarType::String, SqlServerColumnDecodeType::Uuid) => data
918 .try_get(name)
919 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid-text"))?
920 .map(|val: uuid::Uuid| {
921 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
922 }),
923 (SqlScalarType::String, SqlServerColumnDecodeType::Bytes) => data
924 .try_get(name)
925 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes-text"))?
926 .map(|val: &[u8]| {
927 let encoded = base64::engine::general_purpose::STANDARD.encode(val);
928 arena.make_datum(|packer| packer.push(Datum::String(&encoded)))
929 }),
930 (SqlScalarType::String, SqlServerColumnDecodeType::Numeric) => data
931 .try_get(name)
932 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric-text"))?
933 .map(|val: tiberius::numeric::Numeric| {
934 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
935 }),
936 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDate) => data
937 .try_get(name)
938 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedate-text"))?
939 .map(|val: chrono::NaiveDate| {
940 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
941 }),
942 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveTime) => data
943 .try_get(name)
944 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivetime-text"))?
945 .map(|val: chrono::NaiveTime| {
946 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
947 }),
948 (SqlScalarType::String, SqlServerColumnDecodeType::DateTime) => data
949 .try_get(name)
950 .map_err(|_| SqlServerDecodeError::invalid_column(name, "datetime-text"))?
951 .map(|val: chrono::DateTime<chrono::Utc>| {
952 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
953 }),
954 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDateTime) => data
955 .try_get(name)
956 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedatetime-text"))?
957 .map(|val: chrono::NaiveDateTime| {
958 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
959 }),
960 (column_type, decode_type) => {
961 return Err(SqlServerDecodeError::Unsupported {
962 sql_server_type: decode_type.clone(),
963 mz_type: column_type.clone(),
964 });
965 }
966 };
967
968 match (maybe_datum, column.nullable) {
969 (Some(datum), _) => Ok(datum),
970 (None, true) => Ok(Datum::Null),
971 (None, false) => Err(SqlServerDecodeError::InvalidData {
972 column_name: name.to_string(),
973 error: "found Null in non-nullable column".to_string(),
975 }),
976 }
977 }
978}
979
980impl RustType<proto_sql_server_column_desc::DecodeType> for SqlServerColumnDecodeType {
981 fn into_proto(&self) -> proto_sql_server_column_desc::DecodeType {
982 match self {
983 SqlServerColumnDecodeType::Bool => proto_sql_server_column_desc::DecodeType::Bool(()),
984 SqlServerColumnDecodeType::U8 => proto_sql_server_column_desc::DecodeType::U8(()),
985 SqlServerColumnDecodeType::I16 => proto_sql_server_column_desc::DecodeType::I16(()),
986 SqlServerColumnDecodeType::I32 => proto_sql_server_column_desc::DecodeType::I32(()),
987 SqlServerColumnDecodeType::I64 => proto_sql_server_column_desc::DecodeType::I64(()),
988 SqlServerColumnDecodeType::F32 => proto_sql_server_column_desc::DecodeType::F32(()),
989 SqlServerColumnDecodeType::F64 => proto_sql_server_column_desc::DecodeType::F64(()),
990 SqlServerColumnDecodeType::String => {
991 proto_sql_server_column_desc::DecodeType::String(())
992 }
993 SqlServerColumnDecodeType::Bytes => proto_sql_server_column_desc::DecodeType::Bytes(()),
994 SqlServerColumnDecodeType::Uuid => proto_sql_server_column_desc::DecodeType::Uuid(()),
995 SqlServerColumnDecodeType::Numeric => {
996 proto_sql_server_column_desc::DecodeType::Numeric(())
997 }
998 SqlServerColumnDecodeType::Xml => proto_sql_server_column_desc::DecodeType::Xml(()),
999 SqlServerColumnDecodeType::NaiveDate => {
1000 proto_sql_server_column_desc::DecodeType::NaiveDate(())
1001 }
1002 SqlServerColumnDecodeType::NaiveTime => {
1003 proto_sql_server_column_desc::DecodeType::NaiveTime(())
1004 }
1005 SqlServerColumnDecodeType::DateTime => {
1006 proto_sql_server_column_desc::DecodeType::DateTime(())
1007 }
1008 SqlServerColumnDecodeType::NaiveDateTime => {
1009 proto_sql_server_column_desc::DecodeType::NaiveDateTime(())
1010 }
1011 SqlServerColumnDecodeType::Unsupported { context } => {
1012 proto_sql_server_column_desc::DecodeType::Unsupported(context.clone())
1013 }
1014 }
1015 }
1016
1017 fn from_proto(
1018 proto: proto_sql_server_column_desc::DecodeType,
1019 ) -> Result<Self, mz_proto::TryFromProtoError> {
1020 let val = match proto {
1021 proto_sql_server_column_desc::DecodeType::Bool(()) => SqlServerColumnDecodeType::Bool,
1022 proto_sql_server_column_desc::DecodeType::U8(()) => SqlServerColumnDecodeType::U8,
1023 proto_sql_server_column_desc::DecodeType::I16(()) => SqlServerColumnDecodeType::I16,
1024 proto_sql_server_column_desc::DecodeType::I32(()) => SqlServerColumnDecodeType::I32,
1025 proto_sql_server_column_desc::DecodeType::I64(()) => SqlServerColumnDecodeType::I64,
1026 proto_sql_server_column_desc::DecodeType::F32(()) => SqlServerColumnDecodeType::F32,
1027 proto_sql_server_column_desc::DecodeType::F64(()) => SqlServerColumnDecodeType::F64,
1028 proto_sql_server_column_desc::DecodeType::String(()) => {
1029 SqlServerColumnDecodeType::String
1030 }
1031 proto_sql_server_column_desc::DecodeType::Bytes(()) => SqlServerColumnDecodeType::Bytes,
1032 proto_sql_server_column_desc::DecodeType::Uuid(()) => SqlServerColumnDecodeType::Uuid,
1033 proto_sql_server_column_desc::DecodeType::Numeric(()) => {
1034 SqlServerColumnDecodeType::Numeric
1035 }
1036 proto_sql_server_column_desc::DecodeType::Xml(()) => SqlServerColumnDecodeType::Xml,
1037 proto_sql_server_column_desc::DecodeType::NaiveDate(()) => {
1038 SqlServerColumnDecodeType::NaiveDate
1039 }
1040 proto_sql_server_column_desc::DecodeType::NaiveTime(()) => {
1041 SqlServerColumnDecodeType::NaiveTime
1042 }
1043 proto_sql_server_column_desc::DecodeType::DateTime(()) => {
1044 SqlServerColumnDecodeType::DateTime
1045 }
1046 proto_sql_server_column_desc::DecodeType::NaiveDateTime(()) => {
1047 SqlServerColumnDecodeType::NaiveDateTime
1048 }
1049 proto_sql_server_column_desc::DecodeType::Unsupported(context) => {
1050 SqlServerColumnDecodeType::Unsupported { context }
1051 }
1052 };
1053 Ok(val)
1054 }
1055}
1056
1057fn tiberius_numeric_to_mz_numeric(val: tiberius::numeric::Numeric) -> Numeric {
1060 let mut numeric = mz_repr::adt::numeric::cx_datum().from_i128(val.value());
1061 mz_repr::adt::numeric::cx_datum().scaleb(&mut numeric, &Numeric::from(-i32::from(val.scale())));
1064 numeric
1065}
1066
/// The bitmask read from the `__$update_mask` column of a change row. It
/// records which data columns were updated (see
/// [`UpdateMask::data_col_updated`]).
#[derive(Debug)]
pub struct UpdateMask {
    /// Raw bytes of the `__$update_mask` column.
    mask: Vec<u8>,
}
1075
1076impl TryFrom<&tiberius::Row> for UpdateMask {
1077 type Error = SqlServerDecodeError;
1078
1079 fn try_from(row: &tiberius::Row) -> Result<Self, Self::Error> {
1080 static UPDATE_MASK: &str = "__$update_mask";
1081
1082 let mask: Vec<u8> = row
1083 .try_get::<&[u8], _>(UPDATE_MASK)
1084 .inspect_err(|e| tracing::warn!("Failed extracting update mask: {e:?}"))
1085 .map_err(|_| SqlServerDecodeError::InvalidColumn {
1086 column_name: UPDATE_MASK.to_string(),
1087 as_type: "bytes",
1088 })?
1089 .ok_or_else(|| SqlServerDecodeError::InvalidData {
1090 column_name: UPDATE_MASK.to_string(),
1091 error: "column cannot be null".to_string(),
1092 })?
1093 .into();
1094 Ok(UpdateMask { mask })
1095 }
1096}
1097
1098impl UpdateMask {
1099 pub fn data_col_updated(&self, col_index: usize) -> bool {
1112 const CDC_METADATA_COL_COUNT: usize = 4;
1113
1114 if col_index < CDC_METADATA_COL_COUNT {
1115 return false;
1116 }
1117 let adj_col_index = col_index - CDC_METADATA_COL_COUNT;
1118 let byte_offset = adj_col_index / usize::cast_from(u8::BITS);
1119 assert!(
1120 byte_offset < self.mask.len(),
1121 "byte_offset = {byte_offset} mask_len = {}",
1122 self.mask.len()
1123 );
1124 let bit_offset = adj_col_index % usize::cast_from(u8::BITS);
1125 (self.mask[self.mask.len() - byte_offset - 1] >> bit_offset) & 1 == 1
1126 }
1127}
1128
/// Decodes `tiberius::Row`s into Materialize [`Row`]s for a fixed
/// [`RelationDesc`]. Built by [`SqlServerRowDecoder::try_new`].
#[derive(Debug)]
pub struct SqlServerRowDecoder {
    /// One entry per column of the `RelationDesc`, in its order. Each entry
    /// holds the SQL Server column name, the column's Materialize type (as
    /// recorded on the SQL Server column description), and how to decode it.
    decoders: Vec<(Arc<str>, SqlColumnType, SqlServerColumnDecodeType)>,
}
1137
1138impl SqlServerRowDecoder {
1139 pub fn try_new(
1143 table: &SqlServerTableDesc,
1144 desc: &RelationDesc,
1145 ) -> Result<Self, SqlServerError> {
1146 let decoders = desc
1147 .iter()
1148 .map(|(col_name, col_type)| {
1149 let sql_server_col = table
1150 .columns
1151 .iter()
1152 .find(|col| col.name.as_ref() == col_name.as_str())
1153 .ok_or_else(|| {
1154 anyhow::anyhow!("no SQL Server column with name {col_name} found")
1156 })?;
1157 let Some(sql_server_col_typ) = sql_server_col.column_type.as_ref() else {
1158 return Err(SqlServerError::ProgrammingError(format!(
1159 "programming error, {col_name} should have been exluded",
1160 )));
1161 };
1162
1163 let matches = match (&sql_server_col_typ.scalar_type, &col_type.scalar_type) {
1171 (SqlScalarType::Timestamp { .. }, SqlScalarType::Timestamp { .. })
1172 | (SqlScalarType::TimestampTz { .. }, SqlScalarType::TimestampTz { .. }) => {
1173 sql_server_col_typ.nullable == col_type.nullable
1175 }
1176 (_, _) => sql_server_col_typ == col_type,
1177 };
1178 if !matches {
1179 return Err(SqlServerError::ProgrammingError(format!(
1180 "programming error, {col_name} has mismatched type {:?} vs {:?}",
1181 sql_server_col.column_type, col_type
1182 )));
1183 }
1184
1185 let name = Arc::clone(&sql_server_col.name);
1186 let decoder = sql_server_col.decode_type.clone();
1187 let col_typ = sql_server_col_typ.clone();
1192
1193 Ok::<_, SqlServerError>((name, col_typ, decoder))
1194 })
1195 .collect::<Result<_, _>>()?;
1196
1197 Ok(SqlServerRowDecoder { decoders })
1198 }
1199
1200 pub fn decode(
1207 &self,
1208 data: &tiberius::Row,
1209 row: &mut Row,
1210 arena: &RowArena,
1211 new_data: Option<&tiberius::Row>,
1212 ) -> Result<(), SqlServerDecodeError> {
1213 let mut packer = row.packer();
1214
1215 for (col_name, col_type, decoder) in &self.decoders {
1216 let datum = decoder.decode(data, col_name, col_type, arena)?;
1217
1218 let datum = if let Some(new_data) = new_data
1219 && matches!(
1220 col_type.scalar_type,
1221 SqlScalarType::VarChar { max_length: None }
1222 )
1223 && matches!(datum, Datum::Null)
1224 {
1225 let update_mask = UpdateMask::try_from(new_data)?;
1226 let col_index = new_data
1227 .columns()
1228 .iter()
1229 .position(|c| c.name() == col_name.as_ref())
1230 .expect("column exists");
1231 if !update_mask.data_col_updated(col_index) {
1236 decoder.decode(new_data, col_name, col_type, arena)?
1237 } else {
1238 datum
1239 }
1240 } else {
1241 datum
1242 };
1243
1244 packer.push(datum);
1245 }
1246 Ok(())
1247 }
1248
1249 pub fn included_column_names(&self) -> Vec<Arc<str>> {
1250 self.decoders
1251 .iter()
1252 .map(|decoder| Arc::clone(&decoder.0))
1253 .collect()
1254 }
1255}
1256
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::sync::Arc;

    use chrono::NaiveDateTime;
    use itertools::Itertools;
    use mz_ore::assert_contains;
    use mz_ore::collections::CollectionExt;
    use mz_repr::adt::numeric::NumericMaxScale;
    use mz_repr::adt::varchar::VarCharMaxLength;
    use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlScalarType};
    use tiberius::RowTestExt;

    use crate::desc::{
        SqlServerCaptureInstanceRaw, SqlServerColumnDecodeType, SqlServerColumnDesc,
        SqlServerTableDesc, SqlServerTableRaw, UpdateMask, tiberius_numeric_to_mz_numeric,
    };

    use super::SqlServerColumnRaw;

    impl SqlServerColumnRaw {
        /// Test helper: a non-nullable, non-computed column with zero length,
        /// precision and scale. Adjust it with the builder methods below.
        fn new(name: &str, data_type: &str) -> Self {
            SqlServerColumnRaw {
                name: name.into(),
                data_type: data_type.into(),
                is_nullable: false,
                max_length: 0,
                precision: 0,
                scale: 0,
                is_computed: false,
            }
        }

        fn nullable(mut self, nullable: bool) -> Self {
            self.is_nullable = nullable;
            self
        }

        fn max_length(mut self, max_length: i16) -> Self {
            self.max_length = max_length;
            self
        }

        fn precision(mut self, precision: u8) -> Self {
            self.precision = precision;
            self
        }

        fn scale(mut self, scale: u8) -> Self {
            self.scale = scale;
            self
        }
    }

    #[mz_ore::test]
    fn smoketest_column_raw() {
        let raw = SqlServerColumnRaw::new("foo", "bit");
        let col = SqlServerColumnDesc::new(&raw);

        assert_eq!(&*col.name, "foo");
        assert_eq!(col.column_type, Some(SqlScalarType::Bool.nullable(false)));
        assert_eq!(col.decode_type, SqlServerColumnDecodeType::Bool);

        let raw = SqlServerColumnRaw::new("foo", "decimal")
            .precision(20)
            .scale(10);
        let col = SqlServerColumnDesc::new(&raw);

        let col_type = SqlScalarType::Numeric {
            max_scale: Some(NumericMaxScale::try_from(10i64).expect("known valid")),
        }
        .nullable(false);
        assert_eq!(col.column_type, Some(col_type));
        assert_eq!(col.decode_type, SqlServerColumnDecodeType::Numeric);
    }

    #[mz_ore::test]
    fn smoketest_column_raw_invalid() {
        let raw = SqlServerColumnRaw::new("foo", "bad_data_type");
        let desc = SqlServerColumnDesc::new(&raw);
        let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
            panic!("unexpected decode type {desc:?}");
        };
        assert_contains!(context, "'bad_data_type' is unimplemented");

        let raw = SqlServerColumnRaw::new("foo", "decimal")
            .precision(100)
            .scale(10);
        let desc = SqlServerColumnDesc::new(&raw);
        assert!(matches!(
            desc.decode_type,
            SqlServerColumnDecodeType::Unsupported { .. }
        ));

        let raw = SqlServerColumnRaw::new("foo", "varbinary").max_length(-1);
        let desc = SqlServerColumnDesc::new(&raw);
        let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
            panic!("unexpected decode type {desc:?}");
        };
        assert_contains!(context, "columns with unlimited size do not support CDC");
    }

    #[mz_ore::test]
    fn smoketest_decoder() {
        let sql_server_columns = [
            SqlServerColumnRaw::new("a", "varchar").max_length(16),
            SqlServerColumnRaw::new("b", "int").nullable(true),
            SqlServerColumnRaw::new("c", "bit"),
        ];
        let sql_server_desc = SqlServerTableRaw {
            schema_name: "my_schema".into(),
            name: "my_table".into(),
            capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
                name: "my_table_CT".into(),
                create_date: NaiveDateTime::parse_from_str(
                    "2024-01-01 00:00:00",
                    "%Y-%m-%d %H:%M:%S",
                )
                .unwrap()
                .into(),
            }),
            columns: sql_server_columns.into(),
        };
        let sql_server_desc = SqlServerTableDesc::new(sql_server_desc, vec![]).unwrap();

        let max_length = Some(VarCharMaxLength::try_from(16).unwrap());
        // The relation deliberately orders its columns differently from the SQL
        // Server table. This checks that decoding matches columns by name, not by
        // position.
        let relation_desc = RelationDesc::builder()
            .with_column("a", SqlScalarType::VarChar { max_length }.nullable(false))
            .with_column("c", SqlScalarType::Bool.nullable(false))
            .with_column("b", SqlScalarType::Int32.nullable(true))
            .finish();

        let decoder = sql_server_desc
            .decoder(&relation_desc)
            .expect("known valid");

        let sql_server_columns = [
            tiberius::Column::new("a".to_string(), tiberius::ColumnType::BigVarChar),
            tiberius::Column::new("b".to_string(), tiberius::ColumnType::Int4),
            tiberius::Column::new("c".to_string(), tiberius::ColumnType::Bit),
        ];

        let data_a = [
            tiberius::ColumnData::String(Some("hello world".into())),
            tiberius::ColumnData::I32(Some(42)),
            tiberius::ColumnData::Bit(Some(true)),
        ];
        let sql_server_row_a = tiberius::Row::build(
            sql_server_columns
                .iter()
                .cloned()
                .zip_eq(data_a.into_iter()),
        );

        let data_b = [
            tiberius::ColumnData::String(Some("foo bar".into())),
            tiberius::ColumnData::I32(None),
            tiberius::ColumnData::Bit(Some(false)),
        ];
        let sql_server_row_b =
            tiberius::Row::build(sql_server_columns.into_iter().zip_eq(data_b.into_iter()));

        let mut rnd_row = Row::default();
        let arena = RowArena::default();

        decoder
            .decode(&sql_server_row_a, &mut rnd_row, &arena, None)
            .unwrap();
        assert_eq!(
            &rnd_row,
            &Row::pack_slice(&[Datum::String("hello world"), Datum::True, Datum::Int32(42)])
        );

        // Reusing `rnd_row` also checks that `decode` replaces its previous contents.
        decoder
            .decode(&sql_server_row_b, &mut rnd_row, &arena, None)
            .unwrap();
        assert_eq!(
            &rnd_row,
            &Row::pack_slice(&[Datum::String("foo bar"), Datum::False, Datum::Null])
        );
    }

    #[mz_ore::test]
    fn smoketest_decode_to_string() {
        #[track_caller]
        fn testcase(
            data_type: &'static str,
            col_type: tiberius::ColumnType,
            col_data: tiberius::ColumnData<'static>,
        ) {
            let columns = [SqlServerColumnRaw::new("a", data_type)];
            let sql_server_desc = SqlServerTableRaw {
                schema_name: "my_schema".into(),
                name: "my_table".into(),
                capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
                    name: "my_table_CT".into(),
                    create_date: NaiveDateTime::parse_from_str(
                        "2024-01-01 00:00:00",
                        "%Y-%m-%d %H:%M:%S",
                    )
                    .unwrap()
                    .into(),
                }),
                columns: columns.into(),
            };
            let mut sql_server_desc = SqlServerTableDesc::new(sql_server_desc, vec![]).unwrap();
            sql_server_desc.apply_text_columns(&BTreeSet::from(["a"]));

            // With `a` marked as a text column, it is expected to decode as a string
            // regardless of its SQL Server type.
            let relation_desc = RelationDesc::builder()
                .with_column("a", SqlScalarType::String.nullable(false))
                .finish();

            let decoder = sql_server_desc
                .decoder(&relation_desc)
                .expect("known valid");

            let sql_server_row = tiberius::Row::build([(
                tiberius::Column::new("a".to_string(), col_type),
                col_data,
            )]);
            let mut mz_row = Row::default();
            let arena = RowArena::new();
            decoder
                .decode(&sql_server_row, &mut mz_row, &arena, None)
                .unwrap();

            let str_datum = mz_row.into_element();
            assert!(matches!(str_datum, Datum::String(_)));
        }

        use tiberius::ColumnData;

        testcase(
            "bit",
            tiberius::ColumnType::Bit,
            ColumnData::Bit(Some(true)),
        );
        testcase(
            "bit",
            tiberius::ColumnType::Bit,
            ColumnData::Bit(Some(false)),
        );
        testcase(
            "tinyint",
            tiberius::ColumnType::Int1,
            ColumnData::U8(Some(33)),
        );
        testcase(
            "smallint",
            tiberius::ColumnType::Int2,
            ColumnData::I16(Some(101)),
        );
        testcase(
            "int",
            tiberius::ColumnType::Int4,
            ColumnData::I32(Some(-42)),
        );
        {
            let datetime = tiberius::time::DateTime::new(10, 300);
            testcase(
                "datetime",
                tiberius::ColumnType::Datetime,
                ColumnData::DateTime(Some(datetime)),
            );
        }
    }

    #[mz_ore::test]
    // Skipped under Miri, presumably because the decimal arithmetic goes through FFI,
    // which Miri cannot execute (TODO confirm).
    #[cfg_attr(miri, ignore)]
    fn smoketest_numeric_conversion() {
        let a = tiberius::numeric::Numeric::new_with_scale(12345, 2);
        let rnd = tiberius_numeric_to_mz_numeric(a);
        let og = mz_repr::adt::numeric::cx_datum().parse("123.45").unwrap();
        assert_eq!(og, rnd);

        let a = tiberius::numeric::Numeric::new_with_scale(-99999, 5);
        let rnd = tiberius_numeric_to_mz_numeric(a);
        let og = mz_repr::adt::numeric::cx_datum().parse("-.99999").unwrap();
        assert_eq!(og, rnd);

        let a = tiberius::numeric::Numeric::new_with_scale(1, 29);
        let rnd = tiberius_numeric_to_mz_numeric(a);
        let og = mz_repr::adt::numeric::cx_datum()
            .parse("0.00000000000000000000000000001")
            .unwrap();
        assert_eq!(og, rnd);

        let a = tiberius::numeric::Numeric::new_with_scale(-111111111111111111, 0);
        let rnd = tiberius_numeric_to_mz_numeric(a);
        let og = mz_repr::adt::numeric::cx_datum()
            .parse("-111111111111111111")
            .unwrap();
        assert_eq!(og, rnd);
    }

    #[mz_ore::test]
    fn smoketest_update_mask() {
        // The last byte covers the first eight data columns (bit 0 first); the byte
        // before it covers the next eight.
        let mask = UpdateMask {
            mask: vec![0b0000_0001, 0b1000_0010],
        };

        // The leading CDC metadata columns are never reported as updated.
        for col_index in 0..4 {
            assert!(!mask.data_col_updated(col_index));
        }

        // Data columns 0..8 (row positions 4..12) map to bits of the last byte.
        assert!(!mask.data_col_updated(4));
        assert!(mask.data_col_updated(5));
        assert!(!mask.data_col_updated(6));
        assert!(mask.data_col_updated(11));

        // Data columns 8.. (row positions 12..) map to bits of the preceding byte.
        assert!(mask.data_col_updated(12));
        assert!(!mask.data_col_updated(13));
    }
}