1use base64::Engine;
26use chrono::{NaiveDateTime, SubsecRound};
27use dec::OrderedDecimal;
28use mz_ore::cast::CastFrom;
29use mz_proto::{IntoRustIfSome, ProtoType, RustType};
30use mz_repr::adt::char::CharLength;
31use mz_repr::adt::numeric::{Numeric, NumericMaxScale};
32use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampPrecision};
33use mz_repr::adt::varchar::VarCharMaxLength;
34use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlColumnType, SqlScalarType};
35use proptest_derive::Arbitrary;
36use serde::{Deserialize, Serialize};
37
38use std::collections::BTreeSet;
39use std::sync::Arc;
40
41use crate::{SqlServerDecodeError, SqlServerError};
42
43include!(concat!(env!("OUT_DIR"), "/mz_sql_server_util.rs"));
44
/// Description of a SQL Server table: its schema name, its table name, and a
/// [`SqlServerColumnDesc`] for each of its columns.
///
/// Built from a [`SqlServerTableRaw`] with [`SqlServerTableDesc::new`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerTableDesc {
    /// Name of the schema the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// Descriptions of the table's columns.
    pub columns: Box<[SqlServerColumnDesc]>,
}
63
64impl SqlServerTableDesc {
65 pub fn new(raw: SqlServerTableRaw) -> Self {
70 let columns: Box<[_]> = raw
71 .columns
72 .into_iter()
73 .map(SqlServerColumnDesc::new)
74 .collect();
75 SqlServerTableDesc {
76 schema_name: raw.schema_name,
77 name: raw.name,
78 columns,
79 }
80 }
81
82 pub fn qualified_name(&self) -> SqlServerQualifiedTableName {
84 SqlServerQualifiedTableName {
85 schema_name: Arc::clone(&self.schema_name),
86 table_name: Arc::clone(&self.name),
87 }
88 }
89
90 pub fn apply_text_columns(&mut self, text_columns: &BTreeSet<&str>) {
93 for column in &mut self.columns {
94 if text_columns.contains(column.name.as_ref()) {
95 column.represent_as_text();
96 }
97 }
98 }
99
100 pub fn apply_excl_columns(&mut self, excl_columns: &BTreeSet<&str>) {
103 for column in &mut self.columns {
104 if excl_columns.contains(column.name.as_ref()) {
105 column.exclude();
106 }
107 }
108 }
109
110 pub fn decoder(&self, desc: &RelationDesc) -> Result<SqlServerRowDecoder, SqlServerError> {
113 let decoder = SqlServerRowDecoder::try_new(self, desc)?;
114 Ok(decoder)
115 }
116}
117
118impl RustType<ProtoSqlServerTableDesc> for SqlServerTableDesc {
119 fn into_proto(&self) -> ProtoSqlServerTableDesc {
120 ProtoSqlServerTableDesc {
121 name: self.name.to_string(),
122 schema_name: self.schema_name.to_string(),
123 columns: self.columns.iter().map(|c| c.into_proto()).collect(),
124 }
125 }
126
127 fn from_proto(proto: ProtoSqlServerTableDesc) -> Result<Self, mz_proto::TryFromProtoError> {
128 let columns = proto
129 .columns
130 .into_iter()
131 .map(|c| c.into_rust())
132 .collect::<Result<_, _>>()?;
133 Ok(SqlServerTableDesc {
134 schema_name: proto.schema_name.into(),
135 name: proto.name.into(),
136 columns,
137 })
138 }
139}
140
/// Schema-qualified name of a SQL Server table.
///
/// Formats as `[schema_name].[table_name]`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SqlServerQualifiedTableName {
    /// Name of the schema the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub table_name: Arc<str>,
}

// Implement `Display` rather than `ToString` directly: `to_string()` keeps
// working through the blanket `impl<T: Display> ToString for T`, and the name
// additionally becomes usable with `format!`, `write!`, and friends.
impl std::fmt::Display for SqlServerQualifiedTableName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[{}].[{}]", self.schema_name, self.table_name)
    }
}
155
/// Raw metadata about a SQL Server table, before its column types have been
/// interpreted.
///
/// Turned into a [`SqlServerTableDesc`] by [`SqlServerTableDesc::new`].
#[derive(Debug, Clone)]
pub struct SqlServerTableRaw {
    /// Name of the schema the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// Capture instance associated with this table.
    ///
    /// NOTE(review): presumably the SQL Server CDC capture instance that
    /// tracks this table — confirm against the code that populates it.
    pub capture_instance: Arc<SqlServerCaptureInstanceRaw>,
    /// Raw metadata for each column of the table.
    pub columns: Arc<[SqlServerColumnRaw]>,
}
171
/// Raw metadata about a capture instance: its name and its creation date.
#[derive(Debug, Clone)]
pub struct SqlServerCaptureInstanceRaw {
    /// Name of the capture instance.
    pub name: Arc<str>,
    /// Date the capture instance was created.
    ///
    /// NOTE(review): a [`NaiveDateTime`] carries no time zone; which zone the
    /// value is expressed in is not visible here — confirm with the producer.
    pub create_date: Arc<NaiveDateTime>,
}
180
/// Description of a single SQL Server column and how to decode its values.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerColumnDesc {
    /// Name of the column.
    pub name: Arc<str>,
    /// Type the column is decoded into.
    ///
    /// `None` marks the column as excluded (see
    /// [`SqlServerColumnDesc::is_excluded`]): it is `None` when the raw data
    /// type could not be parsed, or after [`SqlServerColumnDesc::exclude`].
    pub column_type: Option<SqlColumnType>,
    /// Primary key constraint recorded for this column in the raw metadata, if
    /// any.
    ///
    /// NOTE(review): presumably the name of the constraint the column
    /// participates in — confirm against the code that populates it.
    pub primary_key_constraint: Option<Arc<str>>,
    /// How values of this column are read out of a SQL Server row.
    pub decode_type: SqlServerColumnDecodeType,
    /// The data type name exactly as it appeared in the raw column metadata.
    pub raw_type: Arc<str>,
}
201
202impl SqlServerColumnDesc {
203 pub fn new(raw: &SqlServerColumnRaw) -> Self {
205 let (column_type, decode_type) = match parse_data_type(raw) {
206 Ok((scalar_type, decode_type)) => {
207 let column_type = scalar_type.nullable(raw.is_nullable);
208 (Some(column_type), decode_type)
209 }
210 Err(err) => {
211 tracing::warn!(
212 ?err,
213 ?raw,
214 "found an unsupported data type when parsing raw data"
215 );
216 (
217 None,
218 SqlServerColumnDecodeType::Unsupported {
219 context: err.reason,
220 },
221 )
222 }
223 };
224 SqlServerColumnDesc {
225 name: Arc::clone(&raw.name),
226 primary_key_constraint: raw.primary_key_constraint.clone(),
227 column_type,
228 decode_type,
229 raw_type: Arc::clone(&raw.data_type),
230 }
231 }
232
233 pub fn represent_as_text(&mut self) {
235 self.column_type = self
236 .column_type
237 .as_ref()
238 .map(|ct| SqlScalarType::String.nullable(ct.nullable));
239 }
240
241 pub fn exclude(&mut self) {
243 self.column_type = None;
244 }
245
246 pub fn is_excluded(&self) -> bool {
248 self.column_type.is_none()
249 }
250}
251
252impl RustType<ProtoSqlServerColumnDesc> for SqlServerColumnDesc {
253 fn into_proto(&self) -> ProtoSqlServerColumnDesc {
254 ProtoSqlServerColumnDesc {
255 name: self.name.to_string(),
256 column_type: self.column_type.into_proto(),
257 primary_key_constraint: self.primary_key_constraint.as_ref().map(|v| v.to_string()),
258 decode_type: Some(self.decode_type.into_proto()),
259 raw_type: self.raw_type.to_string(),
260 }
261 }
262
263 fn from_proto(proto: ProtoSqlServerColumnDesc) -> Result<Self, mz_proto::TryFromProtoError> {
264 Ok(SqlServerColumnDesc {
265 name: proto.name.into(),
266 column_type: proto.column_type.into_rust()?,
267 primary_key_constraint: proto.primary_key_constraint.map(|v| v.into()),
268 decode_type: proto
269 .decode_type
270 .into_rust_if_some("ProtoSqlServerColumnDesc::decode_type")?,
271 raw_type: proto.raw_type.into(),
272 })
273 }
274}
275
/// Error returned by `parse_data_type` when a column's SQL Server data type
/// cannot be mapped to a supported type.
#[derive(Debug)]
// Within the visible code only `reason` is read directly; `column_name` and
// `column_type` are only surfaced through the derived `Debug` output.
#[allow(dead_code)]
pub struct UnsupportedDataType {
    /// Name of the offending column.
    column_name: String,
    /// Data type of the offending column.
    column_type: String,
    /// Human-readable explanation of why the type is unsupported.
    reason: String,
}
284
285fn parse_data_type(
290 raw: &SqlServerColumnRaw,
291) -> Result<(SqlScalarType, SqlServerColumnDecodeType), UnsupportedDataType> {
292 if raw.is_computed {
296 return Err(UnsupportedDataType {
297 column_name: raw.name.to_string(),
298 column_type: format!("{} (computed)", raw.data_type.to_lowercase()),
299 reason: "column is computed".into(),
300 });
301 }
302
303 let scalar = match raw.data_type.to_lowercase().as_str() {
304 "tinyint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::U8),
305 "smallint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::I16),
306 "int" => (SqlScalarType::Int32, SqlServerColumnDecodeType::I32),
307 "bigint" => (SqlScalarType::Int64, SqlServerColumnDecodeType::I64),
308 "bit" => (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool),
309 "decimal" | "numeric" | "money" | "smallmoney" => {
310 if raw.precision > 38 || raw.scale > raw.precision {
317 tracing::warn!(
318 "unexpected value from SQL Server, precision of {} and scale of {}",
319 raw.precision,
320 raw.scale,
321 );
322 }
323 if raw.precision > 39 {
324 let reason = format!(
325 "precision of {} is greater than our maximum of 39",
326 raw.precision
327 );
328 return Err(UnsupportedDataType {
329 column_name: raw.name.to_string(),
330 column_type: raw.data_type.to_string(),
331 reason,
332 });
333 }
334
335 let raw_scale = usize::cast_from(raw.scale);
336 let max_scale =
337 NumericMaxScale::try_from(raw_scale).map_err(|_| UnsupportedDataType {
338 column_type: raw.data_type.to_string(),
339 column_name: raw.name.to_string(),
340 reason: format!("scale of {} is too large", raw.scale),
341 })?;
342 let column_type = SqlScalarType::Numeric {
343 max_scale: Some(max_scale),
344 };
345
346 (column_type, SqlServerColumnDecodeType::Numeric)
347 }
348 "real" | "float" | "double precision" => match raw.max_length {
356 4 => (SqlScalarType::Float32, SqlServerColumnDecodeType::F32),
359 8 => (SqlScalarType::Float64, SqlServerColumnDecodeType::F64),
360 _ => {
361 return Err(UnsupportedDataType {
362 column_name: raw.name.to_string(),
363 column_type: raw.data_type.to_string(),
364 reason: format!("unsupported length {}", raw.max_length),
365 });
366 }
367 },
368 dt @ ("char" | "nchar" | "varchar" | "nvarchar" | "sysname") => {
369 if raw.max_length == -1 {
374 return Err(UnsupportedDataType {
375 column_name: raw.name.to_string(),
376 column_type: raw.data_type.to_string(),
377 reason: "columns with unlimited size do not support CDC".to_string(),
378 });
379 }
380
381 let column_type = match dt {
382 "char" => {
383 let length = CharLength::try_from(i64::from(raw.max_length)).map_err(|e| {
384 UnsupportedDataType {
385 column_name: raw.name.to_string(),
386 column_type: raw.data_type.to_string(),
387 reason: e.to_string(),
388 }
389 })?;
390 SqlScalarType::Char {
391 length: Some(length),
392 }
393 }
394 "varchar" => {
395 let length =
396 VarCharMaxLength::try_from(i64::from(raw.max_length)).map_err(|e| {
397 UnsupportedDataType {
398 column_name: raw.name.to_string(),
399 column_type: raw.data_type.to_string(),
400 reason: e.to_string(),
401 }
402 })?;
403 SqlScalarType::VarChar {
404 max_length: Some(length),
405 }
406 }
407 "nchar" | "nvarchar" | "sysname" => SqlScalarType::String,
411 other => unreachable!("'{other}' checked above"),
412 };
413
414 (column_type, SqlServerColumnDecodeType::String)
415 }
416 "text" | "ntext" | "image" => {
417 mz_ore::soft_assert_eq_no_log!(raw.max_length, 16);
420
421 return Err(UnsupportedDataType {
423 column_name: raw.name.to_string(),
424 column_type: raw.data_type.to_string(),
425 reason: "columns with unlimited size do not support CDC".to_string(),
426 });
427 }
428 "xml" => {
429 if raw.max_length == -1 {
434 return Err(UnsupportedDataType {
435 column_name: raw.name.to_string(),
436 column_type: raw.data_type.to_string(),
437 reason: "columns with unlimited size do not support CDC".to_string(),
438 });
439 }
440 (SqlScalarType::String, SqlServerColumnDecodeType::Xml)
441 }
442 "binary" | "varbinary" => {
443 if raw.max_length == -1 {
449 return Err(UnsupportedDataType {
450 column_name: raw.name.to_string(),
451 column_type: raw.data_type.to_string(),
452 reason: "columns with unlimited size do not support CDC".to_string(),
453 });
454 }
455
456 (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes)
457 }
458 "json" => (SqlScalarType::Jsonb, SqlServerColumnDecodeType::String),
459 "date" => (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate),
460 "time" => (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime),
473 dt @ ("smalldatetime" | "datetime" | "datetime2" | "datetimeoffset") => {
474 if raw.scale > 7 {
475 tracing::warn!("unexpected scale '{}' from SQL Server", raw.scale);
476 }
477 if raw.scale > mz_repr::adt::timestamp::MAX_PRECISION {
478 tracing::warn!("truncating scale of '{}' for '{}'", raw.scale, dt);
479 }
480 let precision = std::cmp::min(raw.scale, mz_repr::adt::timestamp::MAX_PRECISION);
481 let precision =
482 Some(TimestampPrecision::try_from(i64::from(precision)).expect("known to fit"));
483
484 match dt {
485 "smalldatetime" | "datetime" | "datetime2" => (
486 SqlScalarType::Timestamp { precision },
487 SqlServerColumnDecodeType::NaiveDateTime,
488 ),
489 "datetimeoffset" => (
490 SqlScalarType::TimestampTz { precision },
491 SqlServerColumnDecodeType::DateTime,
492 ),
493 other => unreachable!("'{other}' checked above"),
494 }
495 }
496 "uniqueidentifier" => (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid),
497 other => {
510 return Err(UnsupportedDataType {
511 column_type: other.to_string(),
512 column_name: raw.name.to_string(),
513 reason: format!("'{other}' is unimplemented"),
514 });
515 }
516 };
517 Ok(scalar)
518}
519
/// Raw metadata about a single SQL Server column, before its data type has
/// been interpreted.
#[derive(Clone, Debug)]
pub struct SqlServerColumnRaw {
    /// Name of the column.
    pub name: Arc<str>,
    /// Name of the column's data type; it is matched case-insensitively when
    /// parsed.
    pub data_type: Arc<str>,
    /// Whether the column accepts `NULL` values.
    pub is_nullable: bool,
    /// Primary key constraint recorded for this column, if any.
    ///
    /// NOTE(review): presumably the name of the constraint the column
    /// participates in — confirm against the code that populates it.
    pub primary_key_constraint: Option<Arc<str>>,
    /// Maximum length reported for the column.
    ///
    /// A value of `-1` is treated as "unlimited size" when the data type is
    /// parsed. For floating point types the values `4` and `8` select 32-bit
    /// and 64-bit floats respectively.
    pub max_length: i16,
    /// Precision reported for the column; checked for the numeric types.
    pub precision: u8,
    /// Scale reported for the column; used for the numeric and date-time
    /// types.
    pub scale: u8,
    /// Whether the column is computed; computed columns are reported as
    /// unsupported when the data type is parsed.
    pub is_computed: bool,
}
550
/// The Rust-side type a SQL Server column value is read as before being
/// converted into a [`Datum`]; see [`SqlServerColumnDecodeType::decode`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub enum SqlServerColumnDecodeType {
    /// Read as a `bool`.
    Bool,
    /// Read as a `u8`.
    U8,
    /// Read as an `i16`.
    I16,
    /// Read as an `i32`.
    I32,
    /// Read as an `i64`.
    I64,
    /// Read as an `f32`.
    F32,
    /// Read as an `f64`.
    F64,
    /// Read as a string slice.
    String,
    /// Read as a byte slice.
    Bytes,
    /// Read as a UUID.
    Uuid,
    /// Read as a [`tiberius::numeric::Numeric`].
    Numeric,
    /// Read as [`tiberius::xml::XmlData`].
    Xml,
    /// Read as a [`chrono::NaiveDate`].
    NaiveDate,
    /// Read as a [`chrono::NaiveTime`].
    NaiveTime,
    /// Read as a [`chrono::DateTime`] in UTC.
    DateTime,
    /// Read as a [`chrono::NaiveDateTime`].
    NaiveDateTime,
    /// The column's data type is not supported and cannot be decoded.
    Unsupported {
        /// Explanation of why the data type is unsupported.
        context: String,
    },
}
583
584impl SqlServerColumnDecodeType {
585 pub fn decode<'a>(
587 &self,
588 data: &'a tiberius::Row,
589 name: &'a str,
590 column: &'a SqlColumnType,
591 arena: &'a RowArena,
592 ) -> Result<Datum<'a>, SqlServerDecodeError> {
593 let maybe_datum = match (&column.scalar_type, self) {
594 (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool) => data
595 .try_get(name)
596 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool"))?
597 .map(|val: bool| if val { Datum::True } else { Datum::False }),
598 (SqlScalarType::Int16, SqlServerColumnDecodeType::U8) => data
599 .try_get(name)
600 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8"))?
601 .map(|val: u8| Datum::Int16(i16::cast_from(val))),
602 (SqlScalarType::Int16, SqlServerColumnDecodeType::I16) => data
603 .try_get(name)
604 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16"))?
605 .map(Datum::Int16),
606 (SqlScalarType::Int32, SqlServerColumnDecodeType::I32) => data
607 .try_get(name)
608 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32"))?
609 .map(Datum::Int32),
610 (SqlScalarType::Int64, SqlServerColumnDecodeType::I64) => data
611 .try_get(name)
612 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64"))?
613 .map(Datum::Int64),
614 (SqlScalarType::Float32, SqlServerColumnDecodeType::F32) => data
615 .try_get(name)
616 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32"))?
617 .map(|val: f32| Datum::Float32(ordered_float::OrderedFloat(val))),
618 (SqlScalarType::Float64, SqlServerColumnDecodeType::F64) => data
619 .try_get(name)
620 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64"))?
621 .map(|val: f64| Datum::Float64(ordered_float::OrderedFloat(val))),
622 (SqlScalarType::String, SqlServerColumnDecodeType::String) => data
623 .try_get(name)
624 .map_err(|_| SqlServerDecodeError::invalid_column(name, "string"))?
625 .map(Datum::String),
626 (SqlScalarType::Char { length }, SqlServerColumnDecodeType::String) => data
627 .try_get(name)
628 .map_err(|_| SqlServerDecodeError::invalid_column(name, "char"))?
629 .map(|val: &str| match length {
630 Some(expected) => {
631 let found_chars = val.chars().count();
632 let expct_chars = usize::cast_from(expected.into_u32());
633 if found_chars != expct_chars {
634 Err(SqlServerDecodeError::invalid_char(
635 name,
636 expct_chars,
637 found_chars,
638 ))
639 } else {
640 Ok(Datum::String(val))
641 }
642 }
643 None => Ok(Datum::String(val)),
644 })
645 .transpose()?,
646 (SqlScalarType::VarChar { max_length }, SqlServerColumnDecodeType::String) => data
647 .try_get(name)
648 .map_err(|_| SqlServerDecodeError::invalid_column(name, "varchar"))?
649 .map(|val: &str| match max_length {
650 Some(max) => {
651 let found_chars = val.chars().count();
652 let max_chars = usize::cast_from(max.into_u32());
653 if found_chars > max_chars {
654 Err(SqlServerDecodeError::invalid_varchar(
655 name,
656 max_chars,
657 found_chars,
658 ))
659 } else {
660 Ok(Datum::String(val))
661 }
662 }
663 None => Ok(Datum::String(val)),
664 })
665 .transpose()?,
666 (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes) => data
667 .try_get(name)
668 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes"))?
669 .map(Datum::Bytes),
670 (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid) => data
671 .try_get(name)
672 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid"))?
673 .map(Datum::Uuid),
674 (SqlScalarType::Numeric { .. }, SqlServerColumnDecodeType::Numeric) => data
675 .try_get(name)
676 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric"))?
677 .map(|val: tiberius::numeric::Numeric| {
678 let numeric = tiberius_numeric_to_mz_numeric(val);
679 Datum::Numeric(OrderedDecimal(numeric))
680 }),
681 (SqlScalarType::String, SqlServerColumnDecodeType::Xml) => data
682 .try_get(name)
683 .map_err(|_| SqlServerDecodeError::invalid_column(name, "xml"))?
684 .map(|val: &tiberius::xml::XmlData| Datum::String(val.as_ref())),
685 (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate) => data
686 .try_get(name)
687 .map_err(|_| SqlServerDecodeError::invalid_column(name, "date"))?
688 .map(|val: chrono::NaiveDate| {
689 let date = val
690 .try_into()
691 .map_err(|e| SqlServerDecodeError::invalid_date(name, e))?;
692 Ok::<_, SqlServerDecodeError>(Datum::Date(date))
693 })
694 .transpose()?,
695 (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime) => data
696 .try_get(name)
697 .map_err(|_| SqlServerDecodeError::invalid_column(name, "time"))?
698 .map(|val: chrono::NaiveTime| {
699 let rounded = val.round_subsecs(6);
704 let val = if rounded < val {
706 val.trunc_subsecs(6)
707 } else {
708 val
709 };
710 Datum::Time(val)
711 }),
712 (SqlScalarType::Timestamp { precision }, SqlServerColumnDecodeType::NaiveDateTime) => {
713 data.try_get(name)
714 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamp"))?
715 .map(|val: chrono::NaiveDateTime| {
716 let ts: CheckedTimestamp<chrono::NaiveDateTime> = val
717 .try_into()
718 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
719 let rounded = ts
720 .round_to_precision(*precision)
721 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
722 Ok::<_, SqlServerDecodeError>(Datum::Timestamp(rounded))
723 })
724 .transpose()?
725 }
726 (SqlScalarType::TimestampTz { precision }, SqlServerColumnDecodeType::DateTime) => data
727 .try_get(name)
728 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamptz"))?
729 .map(|val: chrono::DateTime<chrono::Utc>| {
730 let ts: CheckedTimestamp<chrono::DateTime<chrono::Utc>> = val
731 .try_into()
732 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
733 let rounded = ts
734 .round_to_precision(*precision)
735 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
736 Ok::<_, SqlServerDecodeError>(Datum::TimestampTz(rounded))
737 })
738 .transpose()?,
739 (SqlScalarType::String, SqlServerColumnDecodeType::Bool) => data
741 .try_get(name)
742 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool-text"))?
743 .map(|val: bool| {
744 if val {
745 Datum::String("true")
746 } else {
747 Datum::String("false")
748 }
749 }),
750 (SqlScalarType::String, SqlServerColumnDecodeType::U8) => data
751 .try_get(name)
752 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8-text"))?
753 .map(|val: u8| {
754 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
755 }),
756 (SqlScalarType::String, SqlServerColumnDecodeType::I16) => data
757 .try_get(name)
758 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16-text"))?
759 .map(|val: i16| {
760 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
761 }),
762 (SqlScalarType::String, SqlServerColumnDecodeType::I32) => data
763 .try_get(name)
764 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32-text"))?
765 .map(|val: i32| {
766 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
767 }),
768 (SqlScalarType::String, SqlServerColumnDecodeType::I64) => data
769 .try_get(name)
770 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64-text"))?
771 .map(|val: i64| {
772 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
773 }),
774 (SqlScalarType::String, SqlServerColumnDecodeType::F32) => data
775 .try_get(name)
776 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32-text"))?
777 .map(|val: f32| {
778 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
779 }),
780 (SqlScalarType::String, SqlServerColumnDecodeType::F64) => data
781 .try_get(name)
782 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64-text"))?
783 .map(|val: f64| {
784 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
785 }),
786 (SqlScalarType::String, SqlServerColumnDecodeType::Uuid) => data
787 .try_get(name)
788 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid-text"))?
789 .map(|val: uuid::Uuid| {
790 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
791 }),
792 (SqlScalarType::String, SqlServerColumnDecodeType::Bytes) => data
793 .try_get(name)
794 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes-text"))?
795 .map(|val: &[u8]| {
796 let encoded = base64::engine::general_purpose::STANDARD.encode(val);
797 arena.make_datum(|packer| packer.push(Datum::String(&encoded)))
798 }),
799 (SqlScalarType::String, SqlServerColumnDecodeType::Numeric) => data
800 .try_get(name)
801 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric-text"))?
802 .map(|val: tiberius::numeric::Numeric| {
803 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
804 }),
805 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDate) => data
806 .try_get(name)
807 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedate-text"))?
808 .map(|val: chrono::NaiveDate| {
809 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
810 }),
811 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveTime) => data
812 .try_get(name)
813 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivetime-text"))?
814 .map(|val: chrono::NaiveTime| {
815 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
816 }),
817 (SqlScalarType::String, SqlServerColumnDecodeType::DateTime) => data
818 .try_get(name)
819 .map_err(|_| SqlServerDecodeError::invalid_column(name, "datetime-text"))?
820 .map(|val: chrono::DateTime<chrono::Utc>| {
821 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
822 }),
823 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDateTime) => data
824 .try_get(name)
825 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedatetime-text"))?
826 .map(|val: chrono::NaiveDateTime| {
827 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
828 }),
829 (column_type, decode_type) => {
830 return Err(SqlServerDecodeError::Unsupported {
831 sql_server_type: decode_type.clone(),
832 mz_type: column_type.clone(),
833 });
834 }
835 };
836
837 match (maybe_datum, column.nullable) {
838 (Some(datum), _) => Ok(datum),
839 (None, true) => Ok(Datum::Null),
840 (None, false) => Err(SqlServerDecodeError::InvalidData {
841 column_name: name.to_string(),
842 error: "found Null in non-nullable column".to_string(),
844 }),
845 }
846 }
847}
848
849impl RustType<proto_sql_server_column_desc::DecodeType> for SqlServerColumnDecodeType {
850 fn into_proto(&self) -> proto_sql_server_column_desc::DecodeType {
851 match self {
852 SqlServerColumnDecodeType::Bool => proto_sql_server_column_desc::DecodeType::Bool(()),
853 SqlServerColumnDecodeType::U8 => proto_sql_server_column_desc::DecodeType::U8(()),
854 SqlServerColumnDecodeType::I16 => proto_sql_server_column_desc::DecodeType::I16(()),
855 SqlServerColumnDecodeType::I32 => proto_sql_server_column_desc::DecodeType::I32(()),
856 SqlServerColumnDecodeType::I64 => proto_sql_server_column_desc::DecodeType::I64(()),
857 SqlServerColumnDecodeType::F32 => proto_sql_server_column_desc::DecodeType::F32(()),
858 SqlServerColumnDecodeType::F64 => proto_sql_server_column_desc::DecodeType::F64(()),
859 SqlServerColumnDecodeType::String => {
860 proto_sql_server_column_desc::DecodeType::String(())
861 }
862 SqlServerColumnDecodeType::Bytes => proto_sql_server_column_desc::DecodeType::Bytes(()),
863 SqlServerColumnDecodeType::Uuid => proto_sql_server_column_desc::DecodeType::Uuid(()),
864 SqlServerColumnDecodeType::Numeric => {
865 proto_sql_server_column_desc::DecodeType::Numeric(())
866 }
867 SqlServerColumnDecodeType::Xml => proto_sql_server_column_desc::DecodeType::Xml(()),
868 SqlServerColumnDecodeType::NaiveDate => {
869 proto_sql_server_column_desc::DecodeType::NaiveDate(())
870 }
871 SqlServerColumnDecodeType::NaiveTime => {
872 proto_sql_server_column_desc::DecodeType::NaiveTime(())
873 }
874 SqlServerColumnDecodeType::DateTime => {
875 proto_sql_server_column_desc::DecodeType::DateTime(())
876 }
877 SqlServerColumnDecodeType::NaiveDateTime => {
878 proto_sql_server_column_desc::DecodeType::NaiveDateTime(())
879 }
880 SqlServerColumnDecodeType::Unsupported { context } => {
881 proto_sql_server_column_desc::DecodeType::Unsupported(context.clone())
882 }
883 }
884 }
885
886 fn from_proto(
887 proto: proto_sql_server_column_desc::DecodeType,
888 ) -> Result<Self, mz_proto::TryFromProtoError> {
889 let val = match proto {
890 proto_sql_server_column_desc::DecodeType::Bool(()) => SqlServerColumnDecodeType::Bool,
891 proto_sql_server_column_desc::DecodeType::U8(()) => SqlServerColumnDecodeType::U8,
892 proto_sql_server_column_desc::DecodeType::I16(()) => SqlServerColumnDecodeType::I16,
893 proto_sql_server_column_desc::DecodeType::I32(()) => SqlServerColumnDecodeType::I32,
894 proto_sql_server_column_desc::DecodeType::I64(()) => SqlServerColumnDecodeType::I64,
895 proto_sql_server_column_desc::DecodeType::F32(()) => SqlServerColumnDecodeType::F32,
896 proto_sql_server_column_desc::DecodeType::F64(()) => SqlServerColumnDecodeType::F64,
897 proto_sql_server_column_desc::DecodeType::String(()) => {
898 SqlServerColumnDecodeType::String
899 }
900 proto_sql_server_column_desc::DecodeType::Bytes(()) => SqlServerColumnDecodeType::Bytes,
901 proto_sql_server_column_desc::DecodeType::Uuid(()) => SqlServerColumnDecodeType::Uuid,
902 proto_sql_server_column_desc::DecodeType::Numeric(()) => {
903 SqlServerColumnDecodeType::Numeric
904 }
905 proto_sql_server_column_desc::DecodeType::Xml(()) => SqlServerColumnDecodeType::Xml,
906 proto_sql_server_column_desc::DecodeType::NaiveDate(()) => {
907 SqlServerColumnDecodeType::NaiveDate
908 }
909 proto_sql_server_column_desc::DecodeType::NaiveTime(()) => {
910 SqlServerColumnDecodeType::NaiveTime
911 }
912 proto_sql_server_column_desc::DecodeType::DateTime(()) => {
913 SqlServerColumnDecodeType::DateTime
914 }
915 proto_sql_server_column_desc::DecodeType::NaiveDateTime(()) => {
916 SqlServerColumnDecodeType::NaiveDateTime
917 }
918 proto_sql_server_column_desc::DecodeType::Unsupported(context) => {
919 SqlServerColumnDecodeType::Unsupported { context }
920 }
921 };
922 Ok(val)
923 }
924}
925
926fn tiberius_numeric_to_mz_numeric(val: tiberius::numeric::Numeric) -> Numeric {
929 let mut numeric = mz_repr::adt::numeric::cx_datum().from_i128(val.value());
930 mz_repr::adt::numeric::cx_datum().scaleb(&mut numeric, &Numeric::from(-i32::from(val.scale())));
933 numeric
934}
935
/// Decodes a [`tiberius::Row`] into a [`Row`], one column at a time.
///
/// Created with [`SqlServerRowDecoder::try_new`].
#[derive(Debug)]
pub struct SqlServerRowDecoder {
    /// For each output column, in output order: the SQL Server column name,
    /// the type to decode into, and how to read the value from the row.
    decoders: Vec<(Arc<str>, SqlColumnType, SqlServerColumnDecodeType)>,
}
944
945impl SqlServerRowDecoder {
946 pub fn try_new(
950 table: &SqlServerTableDesc,
951 desc: &RelationDesc,
952 ) -> Result<Self, SqlServerError> {
953 let decoders = desc
954 .iter()
955 .map(|(col_name, col_type)| {
956 let sql_server_col = table
957 .columns
958 .iter()
959 .find(|col| col.name.as_ref() == col_name.as_str())
960 .ok_or_else(|| {
961 anyhow::anyhow!("no SQL Server column with name {col_name} found")
963 })?;
964 let Some(sql_server_col_typ) = sql_server_col.column_type.as_ref() else {
965 return Err(SqlServerError::ProgrammingError(format!(
966 "programming error, {col_name} should have been exluded",
967 )));
968 };
969
970 let matches = match (&sql_server_col_typ.scalar_type, &col_type.scalar_type) {
978 (SqlScalarType::Timestamp { .. }, SqlScalarType::Timestamp { .. })
979 | (SqlScalarType::TimestampTz { .. }, SqlScalarType::TimestampTz { .. }) => {
980 sql_server_col_typ.nullable == col_type.nullable
982 }
983 (_, _) => sql_server_col_typ == col_type,
984 };
985 if !matches {
986 return Err(SqlServerError::ProgrammingError(format!(
987 "programming error, {col_name} has mismatched type {:?} vs {:?}",
988 sql_server_col.column_type, col_type
989 )));
990 }
991
992 let name = Arc::clone(&sql_server_col.name);
993 let decoder = sql_server_col.decode_type.clone();
994 let col_typ = sql_server_col_typ.clone();
999
1000 Ok::<_, SqlServerError>((name, col_typ, decoder))
1001 })
1002 .collect::<Result<_, _>>()?;
1003
1004 Ok(SqlServerRowDecoder { decoders })
1005 }
1006
1007 pub fn decode(
1009 &self,
1010 data: &tiberius::Row,
1011 row: &mut Row,
1012 arena: &RowArena,
1013 ) -> Result<(), SqlServerDecodeError> {
1014 let mut packer = row.packer();
1015 for (col_name, col_type, decoder) in &self.decoders {
1016 let datum = decoder.decode(data, col_name, col_type, arena)?;
1017 packer.push(datum);
1018 }
1019 Ok(())
1020 }
1021}
1022
1023#[cfg(test)]
1024mod tests {
1025 use std::collections::BTreeSet;
1026 use std::sync::Arc;
1027
1028 use chrono::NaiveDateTime;
1029 use itertools::Itertools;
1030 use mz_ore::assert_contains;
1031 use mz_ore::collections::CollectionExt;
1032 use mz_repr::adt::numeric::NumericMaxScale;
1033 use mz_repr::adt::varchar::VarCharMaxLength;
1034 use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlScalarType};
1035 use tiberius::RowTestExt;
1036
1037 use crate::desc::{
1038 SqlServerCaptureInstanceRaw, SqlServerColumnDecodeType, SqlServerColumnDesc,
1039 SqlServerTableDesc, SqlServerTableRaw, tiberius_numeric_to_mz_numeric,
1040 };
1041
1042 use super::SqlServerColumnRaw;
1043
1044 impl SqlServerColumnRaw {
1045 fn new(name: &str, data_type: &str) -> Self {
1048 SqlServerColumnRaw {
1049 name: name.into(),
1050 data_type: data_type.into(),
1051 is_nullable: false,
1052 primary_key_constraint: None,
1053 max_length: 0,
1054 precision: 0,
1055 scale: 0,
1056 is_computed: false,
1057 }
1058 }
1059
1060 fn nullable(mut self, nullable: bool) -> Self {
1061 self.is_nullable = nullable;
1062 self
1063 }
1064
1065 fn max_length(mut self, max_length: i16) -> Self {
1066 self.max_length = max_length;
1067 self
1068 }
1069
1070 fn precision(mut self, precision: u8) -> Self {
1071 self.precision = precision;
1072 self
1073 }
1074
1075 fn scale(mut self, scale: u8) -> Self {
1076 self.scale = scale;
1077 self
1078 }
1079 }
1080
1081 #[mz_ore::test]
1082 fn smoketest_column_raw() {
1083 let raw = SqlServerColumnRaw::new("foo", "bit");
1084 let col = SqlServerColumnDesc::new(&raw);
1085
1086 assert_eq!(&*col.name, "foo");
1087 assert_eq!(col.column_type, Some(SqlScalarType::Bool.nullable(false)));
1088 assert_eq!(col.decode_type, SqlServerColumnDecodeType::Bool);
1089
1090 let raw = SqlServerColumnRaw::new("foo", "decimal")
1091 .precision(20)
1092 .scale(10);
1093 let col = SqlServerColumnDesc::new(&raw);
1094
1095 let col_type = SqlScalarType::Numeric {
1096 max_scale: Some(NumericMaxScale::try_from(10i64).expect("known valid")),
1097 }
1098 .nullable(false);
1099 assert_eq!(col.column_type, Some(col_type));
1100 assert_eq!(col.decode_type, SqlServerColumnDecodeType::Numeric);
1101 }
1102
1103 #[mz_ore::test]
1104 fn smoketest_column_raw_invalid() {
1105 let raw = SqlServerColumnRaw::new("foo", "bad_data_type");
1106 let desc = SqlServerColumnDesc::new(&raw);
1107 let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1108 panic!("unexpected decode type {desc:?}");
1109 };
1110 assert_contains!(context, "'bad_data_type' is unimplemented");
1111
1112 let raw = SqlServerColumnRaw::new("foo", "decimal")
1113 .precision(100)
1114 .scale(10);
1115 let desc = SqlServerColumnDesc::new(&raw);
1116 assert!(matches!(
1117 desc.decode_type,
1118 SqlServerColumnDecodeType::Unsupported { .. }
1119 ));
1120
1121 let raw = SqlServerColumnRaw::new("foo", "varchar").max_length(-1);
1122 let desc = SqlServerColumnDesc::new(&raw);
1123 let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1124 panic!("unexpected decode type {desc:?}");
1125 };
1126 assert_contains!(context, "columns with unlimited size do not support CDC");
1127 }
1128
1129 #[mz_ore::test]
1130 fn smoketest_decoder() {
1131 let sql_server_columns = [
1132 SqlServerColumnRaw::new("a", "varchar").max_length(16),
1133 SqlServerColumnRaw::new("b", "int").nullable(true),
1134 SqlServerColumnRaw::new("c", "bit"),
1135 ];
1136 let sql_server_desc = SqlServerTableRaw {
1137 schema_name: "my_schema".into(),
1138 name: "my_table".into(),
1139 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1140 name: "my_table_CT".into(),
1141 create_date: NaiveDateTime::parse_from_str(
1142 "2024-01-01 00:00:00",
1143 "%Y-%m-%d %H:%M:%S",
1144 )
1145 .unwrap()
1146 .into(),
1147 }),
1148 columns: sql_server_columns.into(),
1149 };
1150 let sql_server_desc = SqlServerTableDesc::new(sql_server_desc);
1151
1152 let max_length = Some(VarCharMaxLength::try_from(16).unwrap());
1153 let relation_desc = RelationDesc::builder()
1154 .with_column("a", SqlScalarType::VarChar { max_length }.nullable(false))
1155 .with_column("c", SqlScalarType::Bool.nullable(false))
1157 .with_column("b", SqlScalarType::Int32.nullable(true))
1158 .finish();
1159
1160 let decoder = sql_server_desc
1162 .decoder(&relation_desc)
1163 .expect("known valid");
1164
1165 let sql_server_columns = [
1166 tiberius::Column::new("a".to_string(), tiberius::ColumnType::BigVarChar),
1167 tiberius::Column::new("b".to_string(), tiberius::ColumnType::Int4),
1168 tiberius::Column::new("c".to_string(), tiberius::ColumnType::Bit),
1169 ];
1170
1171 let data_a = [
1172 tiberius::ColumnData::String(Some("hello world".into())),
1173 tiberius::ColumnData::I32(Some(42)),
1174 tiberius::ColumnData::Bit(Some(true)),
1175 ];
1176 let sql_server_row_a = tiberius::Row::build(
1177 sql_server_columns
1178 .iter()
1179 .cloned()
1180 .zip_eq(data_a.into_iter()),
1181 );
1182
1183 let data_b = [
1184 tiberius::ColumnData::String(Some("foo bar".into())),
1185 tiberius::ColumnData::I32(None),
1186 tiberius::ColumnData::Bit(Some(false)),
1187 ];
1188 let sql_server_row_b =
1189 tiberius::Row::build(sql_server_columns.into_iter().zip_eq(data_b.into_iter()));
1190
1191 let mut rnd_row = Row::default();
1192 let arena = RowArena::default();
1193
1194 decoder
1195 .decode(&sql_server_row_a, &mut rnd_row, &arena)
1196 .unwrap();
1197 assert_eq!(
1198 &rnd_row,
1199 &Row::pack_slice(&[Datum::String("hello world"), Datum::True, Datum::Int32(42)])
1200 );
1201
1202 decoder
1203 .decode(&sql_server_row_b, &mut rnd_row, &arena)
1204 .unwrap();
1205 assert_eq!(
1206 &rnd_row,
1207 &Row::pack_slice(&[Datum::String("foo bar"), Datum::False, Datum::Null])
1208 );
1209 }
1210
1211 #[mz_ore::test]
1212 fn smoketest_decode_to_string() {
1213 #[track_caller]
1214 fn testcase(
1215 data_type: &'static str,
1216 col_type: tiberius::ColumnType,
1217 col_data: tiberius::ColumnData<'static>,
1218 ) {
1219 let columns = [SqlServerColumnRaw::new("a", data_type)];
1220 let sql_server_desc = SqlServerTableRaw {
1221 schema_name: "my_schema".into(),
1222 name: "my_table".into(),
1223 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1224 name: "my_table_CT".into(),
1225 create_date: NaiveDateTime::parse_from_str(
1226 "2024-01-01 00:00:00",
1227 "%Y-%m-%d %H:%M:%S",
1228 )
1229 .unwrap()
1230 .into(),
1231 }),
1232 columns: columns.into(),
1233 };
1234 let mut sql_server_desc = SqlServerTableDesc::new(sql_server_desc);
1235 sql_server_desc.apply_text_columns(&BTreeSet::from(["a"]));
1236
1237 let relation_desc = RelationDesc::builder()
1239 .with_column("a", SqlScalarType::String.nullable(false))
1240 .finish();
1241
1242 let decoder = sql_server_desc
1244 .decoder(&relation_desc)
1245 .expect("known valid");
1246
1247 let sql_server_row = tiberius::Row::build([(
1248 tiberius::Column::new("a".to_string(), col_type),
1249 col_data,
1250 )]);
1251 let mut mz_row = Row::default();
1252 let arena = RowArena::new();
1253 decoder
1254 .decode(&sql_server_row, &mut mz_row, &arena)
1255 .unwrap();
1256
1257 let str_datum = mz_row.into_element();
1258 assert!(matches!(str_datum, Datum::String(_)));
1259 }
1260
1261 use tiberius::ColumnData;
1262
1263 testcase(
1264 "bit",
1265 tiberius::ColumnType::Bit,
1266 ColumnData::Bit(Some(true)),
1267 );
1268 testcase(
1269 "bit",
1270 tiberius::ColumnType::Bit,
1271 ColumnData::Bit(Some(false)),
1272 );
1273 testcase(
1274 "tinyint",
1275 tiberius::ColumnType::Int1,
1276 ColumnData::U8(Some(33)),
1277 );
1278 testcase(
1279 "smallint",
1280 tiberius::ColumnType::Int2,
1281 ColumnData::I16(Some(101)),
1282 );
1283 testcase(
1284 "int",
1285 tiberius::ColumnType::Int4,
1286 ColumnData::I32(Some(-42)),
1287 );
1288 {
1289 let datetime = tiberius::time::DateTime::new(10, 300);
1290 testcase(
1291 "datetime",
1292 tiberius::ColumnType::Datetime,
1293 ColumnData::DateTime(Some(datetime)),
1294 );
1295 }
1296 }
1297
1298 #[mz_ore::test]
1299 #[cfg_attr(miri, ignore)] fn smoketest_numeric_conversion() {
1301 let a = tiberius::numeric::Numeric::new_with_scale(12345, 2);
1302 let rnd = tiberius_numeric_to_mz_numeric(a);
1303 let og = mz_repr::adt::numeric::cx_datum().parse("123.45").unwrap();
1304 assert_eq!(og, rnd);
1305
1306 let a = tiberius::numeric::Numeric::new_with_scale(-99999, 5);
1307 let rnd = tiberius_numeric_to_mz_numeric(a);
1308 let og = mz_repr::adt::numeric::cx_datum().parse("-.99999").unwrap();
1309 assert_eq!(og, rnd);
1310
1311 let a = tiberius::numeric::Numeric::new_with_scale(1, 29);
1312 let rnd = tiberius_numeric_to_mz_numeric(a);
1313 let og = mz_repr::adt::numeric::cx_datum()
1314 .parse("0.00000000000000000000000000001")
1315 .unwrap();
1316 assert_eq!(og, rnd);
1317
1318 let a = tiberius::numeric::Numeric::new_with_scale(-111111111111111111, 0);
1319 let rnd = tiberius_numeric_to_mz_numeric(a);
1320 let og = mz_repr::adt::numeric::cx_datum()
1321 .parse("-111111111111111111")
1322 .unwrap();
1323 assert_eq!(og, rnd);
1324 }
1325
1326 }