1use base64::Engine;
26use chrono::{NaiveDateTime, SubsecRound};
27use dec::OrderedDecimal;
28use mz_ore::cast::CastFrom;
29use mz_proto::{IntoRustIfSome, ProtoType, RustType};
30use mz_repr::adt::char::CharLength;
31use mz_repr::adt::numeric::{Numeric, NumericMaxScale};
32use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampPrecision};
33use mz_repr::adt::varchar::VarCharMaxLength;
34use mz_repr::{ColumnType, Datum, RelationDesc, Row, RowArena, ScalarType};
35use proptest_derive::Arbitrary;
36use serde::{Deserialize, Serialize};
37
38use std::collections::BTreeSet;
39use std::sync::Arc;
40
41use crate::{SqlServerDecodeError, SqlServerError};
42
43include!(concat!(env!("OUT_DIR"), "/mz_sql_server_util.rs"));
44
/// Description of a SQL Server table: its schema, its name, and its columns.
///
/// Built from a [`SqlServerTableRaw`] with [`SqlServerTableDesc::new`] and
/// convertible to and from `ProtoSqlServerTableDesc`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerTableDesc {
    /// Name of the schema that the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// Descriptions of the columns of the table.
    pub columns: Box<[SqlServerColumnDesc]>,
}
63
64impl SqlServerTableDesc {
65 pub fn new(raw: SqlServerTableRaw) -> Self {
70 let columns: Box<[_]> = raw
71 .columns
72 .into_iter()
73 .map(SqlServerColumnDesc::new)
74 .collect();
75 SqlServerTableDesc {
76 schema_name: raw.schema_name,
77 name: raw.name,
78 columns,
79 }
80 }
81
82 pub fn qualified_name(&self) -> SqlServerQualifiedTableName {
84 SqlServerQualifiedTableName {
85 schema_name: Arc::clone(&self.schema_name),
86 table_name: Arc::clone(&self.name),
87 }
88 }
89
90 pub fn apply_text_columns(&mut self, text_columns: &BTreeSet<&str>) {
93 for column in &mut self.columns {
94 if text_columns.contains(column.name.as_ref()) {
95 column.represent_as_text();
96 }
97 }
98 }
99
100 pub fn apply_excl_columns(&mut self, excl_columns: &BTreeSet<&str>) {
103 for column in &mut self.columns {
104 if excl_columns.contains(column.name.as_ref()) {
105 column.exclude();
106 }
107 }
108 }
109
110 pub fn decoder(&self, desc: &RelationDesc) -> Result<SqlServerRowDecoder, SqlServerError> {
113 let decoder = SqlServerRowDecoder::try_new(self, desc)?;
114 Ok(decoder)
115 }
116}
117
118impl RustType<ProtoSqlServerTableDesc> for SqlServerTableDesc {
119 fn into_proto(&self) -> ProtoSqlServerTableDesc {
120 ProtoSqlServerTableDesc {
121 name: self.name.to_string(),
122 schema_name: self.schema_name.to_string(),
123 columns: self.columns.iter().map(|c| c.into_proto()).collect(),
124 }
125 }
126
127 fn from_proto(proto: ProtoSqlServerTableDesc) -> Result<Self, mz_proto::TryFromProtoError> {
128 let columns = proto
129 .columns
130 .into_iter()
131 .map(|c| c.into_rust())
132 .collect::<Result<_, _>>()?;
133 Ok(SqlServerTableDesc {
134 schema_name: proto.schema_name.into(),
135 name: proto.name.into(),
136 columns,
137 })
138 }
139}
140
/// Schema-qualified name of a SQL Server table.
///
/// The derived ordering compares `schema_name` first and `table_name` second.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SqlServerQualifiedTableName {
    /// Name of the schema that contains the table.
    pub schema_name: Arc<str>,
    /// Name of the table within that schema.
    pub table_name: Arc<str>,
}
149
/// Raw metadata for a SQL Server table, before its column types are parsed.
///
/// Converted into a [`SqlServerTableDesc`] by [`SqlServerTableDesc::new`].
#[derive(Debug, Clone)]
pub struct SqlServerTableRaw {
    /// Name of the schema that the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// Capture instance associated with the table.
    // NOTE(review): presumably the CDC capture instance that tracks this
    // table; confirm against the code that populates it.
    pub capture_instance: Arc<SqlServerCaptureInstanceRaw>,
    /// Raw, unparsed descriptions of the columns of the table.
    pub columns: Arc<[SqlServerColumnRaw]>,
}
165
/// Raw metadata describing a capture instance.
#[derive(Debug, Clone)]
pub struct SqlServerCaptureInstanceRaw {
    /// Name of the capture instance.
    pub name: Arc<str>,
    /// Date on which the capture instance was created.
    // NOTE(review): no timezone is attached (`NaiveDateTime`); confirm which
    // timezone the upstream value is expressed in.
    pub create_date: Arc<NaiveDateTime>,
}
174
/// Description of a single column of a SQL Server table.
///
/// Built from a [`SqlServerColumnRaw`] with [`SqlServerColumnDesc::new`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerColumnDesc {
    /// Name of the column.
    pub name: Arc<str>,
    /// The Materialize type used to represent the column.
    ///
    /// `None` when the SQL Server type could not be parsed (see
    /// [`SqlServerColumnDesc::new`]) or when the column has been excluded
    /// (see [`SqlServerColumnDesc::exclude`]).
    pub column_type: Option<ColumnType>,
    /// Primary key constraint associated with the column, if any; copied
    /// verbatim from the raw column metadata.
    pub primary_key_constraint: Option<Arc<str>>,
    /// How values of this column are read out of a SQL Server row.
    pub decode_type: SqlServerColumnDecodeType,
    /// The data type name from the raw column metadata, unmodified.
    pub raw_type: Arc<str>,
}
195
196impl SqlServerColumnDesc {
197 pub fn new(raw: &SqlServerColumnRaw) -> Self {
199 let (column_type, decode_type) = match parse_data_type(raw) {
200 Ok((scalar_type, decode_type)) => {
201 let column_type = scalar_type.nullable(raw.is_nullable);
202 (Some(column_type), decode_type)
203 }
204 Err(err) => {
205 tracing::warn!(
206 ?err,
207 ?raw,
208 "found an unsupported data type when parsing raw data"
209 );
210 (
211 None,
212 SqlServerColumnDecodeType::Unsupported {
213 context: err.reason,
214 },
215 )
216 }
217 };
218 SqlServerColumnDesc {
219 name: Arc::clone(&raw.name),
220 primary_key_constraint: raw.primary_key_constraint.clone(),
221 column_type,
222 decode_type,
223 raw_type: Arc::clone(&raw.data_type),
224 }
225 }
226
227 pub fn is_supported(&self) -> bool {
229 !matches!(
230 self.decode_type,
231 SqlServerColumnDecodeType::Unsupported { .. }
232 )
233 }
234
235 pub fn represent_as_text(&mut self) {
237 self.column_type = self
238 .column_type
239 .as_ref()
240 .map(|ct| ScalarType::String.nullable(ct.nullable));
241 }
242
243 pub fn exclude(&mut self) {
245 self.column_type = None;
246 }
247
248 pub fn is_excluded(&self) -> bool {
250 self.column_type.is_none()
251 }
252}
253
254impl RustType<ProtoSqlServerColumnDesc> for SqlServerColumnDesc {
255 fn into_proto(&self) -> ProtoSqlServerColumnDesc {
256 ProtoSqlServerColumnDesc {
257 name: self.name.to_string(),
258 column_type: self.column_type.into_proto(),
259 primary_key_constraint: self.primary_key_constraint.as_ref().map(|v| v.to_string()),
260 decode_type: Some(self.decode_type.into_proto()),
261 raw_type: self.raw_type.to_string(),
262 }
263 }
264
265 fn from_proto(proto: ProtoSqlServerColumnDesc) -> Result<Self, mz_proto::TryFromProtoError> {
266 Ok(SqlServerColumnDesc {
267 name: proto.name.into(),
268 column_type: proto.column_type.into_rust()?,
269 primary_key_constraint: proto.primary_key_constraint.map(|v| v.into()),
270 decode_type: proto
271 .decode_type
272 .into_rust_if_some("ProtoSqlServerColumnDesc::decode_type")?,
273 raw_type: proto.raw_type.into(),
274 })
275 }
276}
277
/// Error returned by `parse_data_type` when a SQL Server column type cannot
/// be represented.
#[derive(Debug)]
// `column_name` and `column_type` are only surfaced through the derived
// `Debug` output (e.g. when the error is logged), hence the allowance.
#[allow(dead_code)]
pub struct UnsupportedDataType {
    /// Name of the column with the unsupported type.
    column_name: String,
    /// Name of the unsupported data type.
    column_type: String,
    /// Human-readable explanation of why the type is not supported.
    reason: String,
}
286
287fn parse_data_type(
292 raw: &SqlServerColumnRaw,
293) -> Result<(ScalarType, SqlServerColumnDecodeType), UnsupportedDataType> {
294 let scalar = match raw.data_type.to_lowercase().as_str() {
295 "tinyint" => (ScalarType::Int16, SqlServerColumnDecodeType::U8),
296 "smallint" => (ScalarType::Int16, SqlServerColumnDecodeType::I16),
297 "int" => (ScalarType::Int32, SqlServerColumnDecodeType::I32),
298 "bigint" => (ScalarType::Int64, SqlServerColumnDecodeType::I64),
299 "bit" => (ScalarType::Bool, SqlServerColumnDecodeType::Bool),
300 "decimal" | "numeric" | "money" | "smallmoney" => {
301 if raw.precision > 38 || raw.scale > raw.precision {
308 tracing::warn!(
309 "unexpected value from SQL Server, precision of {} and scale of {}",
310 raw.precision,
311 raw.scale,
312 );
313 }
314 if raw.precision > 39 {
315 let reason = format!(
316 "precision of {} is greater than our maximum of 39",
317 raw.precision
318 );
319 return Err(UnsupportedDataType {
320 column_name: raw.name.to_string(),
321 column_type: raw.data_type.to_string(),
322 reason,
323 });
324 }
325
326 let raw_scale = usize::cast_from(raw.scale);
327 let max_scale =
328 NumericMaxScale::try_from(raw_scale).map_err(|_| UnsupportedDataType {
329 column_type: raw.data_type.to_string(),
330 column_name: raw.name.to_string(),
331 reason: format!("scale of {} is too large", raw.scale),
332 })?;
333 let column_type = ScalarType::Numeric {
334 max_scale: Some(max_scale),
335 };
336
337 (column_type, SqlServerColumnDecodeType::Numeric)
338 }
339 "real" => (ScalarType::Float32, SqlServerColumnDecodeType::F32),
340 "double" => (ScalarType::Float64, SqlServerColumnDecodeType::F64),
341 dt @ ("char" | "nchar" | "varchar" | "nvarchar" | "sysname") => {
342 if raw.max_length == -1 {
347 return Err(UnsupportedDataType {
348 column_name: raw.name.to_string(),
349 column_type: raw.data_type.to_string(),
350 reason: "columns with unlimited size do not support CDC".to_string(),
351 });
352 }
353
354 let column_type = match dt {
355 "char" => {
356 let length = CharLength::try_from(i64::from(raw.max_length)).map_err(|e| {
357 UnsupportedDataType {
358 column_name: raw.name.to_string(),
359 column_type: raw.data_type.to_string(),
360 reason: e.to_string(),
361 }
362 })?;
363 ScalarType::Char {
364 length: Some(length),
365 }
366 }
367 "varchar" => {
368 let length =
369 VarCharMaxLength::try_from(i64::from(raw.max_length)).map_err(|e| {
370 UnsupportedDataType {
371 column_name: raw.name.to_string(),
372 column_type: raw.data_type.to_string(),
373 reason: e.to_string(),
374 }
375 })?;
376 ScalarType::VarChar {
377 max_length: Some(length),
378 }
379 }
380 "nchar" | "nvarchar" | "sysname" => ScalarType::String,
384 other => unreachable!("'{other}' checked above"),
385 };
386
387 (column_type, SqlServerColumnDecodeType::String)
388 }
389 "text" | "ntext" | "image" => {
390 mz_ore::soft_assert_eq_no_log!(raw.max_length, 16);
393
394 return Err(UnsupportedDataType {
396 column_name: raw.name.to_string(),
397 column_type: raw.data_type.to_string(),
398 reason: "columns with unlimited size do not support CDC".to_string(),
399 });
400 }
401 "xml" => {
402 if raw.max_length == -1 {
407 return Err(UnsupportedDataType {
408 column_name: raw.name.to_string(),
409 column_type: raw.data_type.to_string(),
410 reason: "columns with unlimited size do not support CDC".to_string(),
411 });
412 }
413 (ScalarType::String, SqlServerColumnDecodeType::Xml)
414 }
415 "binary" | "varbinary" => {
416 if raw.max_length == -1 {
422 return Err(UnsupportedDataType {
423 column_name: raw.name.to_string(),
424 column_type: raw.data_type.to_string(),
425 reason: "columns with unlimited size do not support CDC".to_string(),
426 });
427 }
428
429 (ScalarType::Bytes, SqlServerColumnDecodeType::Bytes)
430 }
431 "json" => (ScalarType::Jsonb, SqlServerColumnDecodeType::String),
432 "date" => (ScalarType::Date, SqlServerColumnDecodeType::NaiveDate),
433 "time" => (ScalarType::Time, SqlServerColumnDecodeType::NaiveTime),
446 dt @ ("smalldatetime" | "datetime" | "datetime2" | "datetimeoffset") => {
447 if raw.scale > 7 {
448 tracing::warn!("unexpected scale '{}' from SQL Server", raw.scale);
449 }
450 if raw.scale > mz_repr::adt::timestamp::MAX_PRECISION {
451 tracing::warn!("truncating scale of '{}' for '{}'", raw.scale, dt);
452 }
453 let precision = std::cmp::min(raw.scale, mz_repr::adt::timestamp::MAX_PRECISION);
454 let precision =
455 Some(TimestampPrecision::try_from(i64::from(precision)).expect("known to fit"));
456
457 match dt {
458 "smalldatetime" | "datetime" | "datetime2" => (
459 ScalarType::Timestamp { precision },
460 SqlServerColumnDecodeType::NaiveDateTime,
461 ),
462 "datetimeoffset" => (
463 ScalarType::TimestampTz { precision },
464 SqlServerColumnDecodeType::DateTime,
465 ),
466 other => unreachable!("'{other}' checked above"),
467 }
468 }
469 "uniqueidentifier" => (ScalarType::Uuid, SqlServerColumnDecodeType::Uuid),
470 other => {
483 return Err(UnsupportedDataType {
484 column_type: other.to_string(),
485 column_name: raw.name.to_string(),
486 reason: format!("'{other}' is unimplemented"),
487 });
488 }
489 };
490 Ok(scalar)
491}
492
/// Raw metadata for a single SQL Server column, before its type is parsed.
#[derive(Clone, Debug)]
pub struct SqlServerColumnRaw {
    /// Name of the column.
    pub name: Arc<str>,
    /// Name of the SQL Server data type; matched case-insensitively when
    /// parsed.
    pub data_type: Arc<str>,
    /// Whether the column may contain NULL values.
    pub is_nullable: bool,
    /// Primary key constraint associated with the column, if any.
    // NOTE(review): presumably the constraint's name; confirm against the
    // query that populates this field.
    pub primary_key_constraint: Option<Arc<str>>,
    /// Maximum length of the column. A value of `-1` is treated as an
    /// unlimited-size column when the type is parsed.
    // NOTE(review): units (bytes vs. characters) are not visible here.
    pub max_length: i16,
    /// Precision of the column; consulted when parsing numeric types.
    pub precision: u8,
    /// Scale of the column; consulted when parsing numeric and date-time
    /// types.
    pub scale: u8,
}
521
/// How a value of a column is read out of a [`tiberius::Row`].
///
/// Each variant names the Rust type that
/// [`SqlServerColumnDecodeType::decode`] requests from the row.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub enum SqlServerColumnDecodeType {
    /// Read as a `bool`.
    Bool,
    /// Read as a `u8`.
    U8,
    /// Read as an `i16`.
    I16,
    /// Read as an `i32`.
    I32,
    /// Read as an `i64`.
    I64,
    /// Read as an `f32`.
    F32,
    /// Read as an `f64`.
    F64,
    /// Read as a `&str`.
    String,
    /// Read as a `&[u8]`.
    Bytes,
    /// Read as a [`uuid::Uuid`].
    Uuid,
    /// Read as a [`tiberius::numeric::Numeric`].
    Numeric,
    /// Read as a [`tiberius::xml::XmlData`].
    Xml,
    /// Read as a [`chrono::NaiveDate`].
    NaiveDate,
    /// Read as a [`chrono::NaiveTime`].
    NaiveTime,
    /// Read as a [`chrono::DateTime`] in UTC.
    DateTime,
    /// Read as a [`chrono::NaiveDateTime`].
    NaiveDateTime,
    /// The column's type is not supported and cannot be decoded.
    Unsupported {
        /// Explanation of why the type is not supported.
        context: String,
    },
}
554
555impl SqlServerColumnDecodeType {
556 pub fn decode<'a>(
558 &self,
559 data: &'a tiberius::Row,
560 name: &'a str,
561 column: &'a ColumnType,
562 arena: &'a RowArena,
563 ) -> Result<Datum<'a>, SqlServerDecodeError> {
564 let maybe_datum = match (&column.scalar_type, self) {
565 (ScalarType::Bool, SqlServerColumnDecodeType::Bool) => data
566 .try_get(name)
567 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool"))?
568 .map(|val: bool| if val { Datum::True } else { Datum::False }),
569 (ScalarType::Int16, SqlServerColumnDecodeType::U8) => data
570 .try_get(name)
571 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8"))?
572 .map(|val: u8| Datum::Int16(i16::cast_from(val))),
573 (ScalarType::Int16, SqlServerColumnDecodeType::I16) => data
574 .try_get(name)
575 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16"))?
576 .map(Datum::Int16),
577 (ScalarType::Int32, SqlServerColumnDecodeType::I32) => data
578 .try_get(name)
579 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32"))?
580 .map(Datum::Int32),
581 (ScalarType::Int64, SqlServerColumnDecodeType::I64) => data
582 .try_get(name)
583 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64"))?
584 .map(Datum::Int64),
585 (ScalarType::Float32, SqlServerColumnDecodeType::F32) => data
586 .try_get(name)
587 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32"))?
588 .map(|val: f32| Datum::Float32(ordered_float::OrderedFloat(val))),
589 (ScalarType::Float64, SqlServerColumnDecodeType::F64) => data
590 .try_get(name)
591 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64"))?
592 .map(|val: f64| Datum::Float64(ordered_float::OrderedFloat(val))),
593 (ScalarType::String, SqlServerColumnDecodeType::String) => data
594 .try_get(name)
595 .map_err(|_| SqlServerDecodeError::invalid_column(name, "string"))?
596 .map(Datum::String),
597 (ScalarType::Char { length }, SqlServerColumnDecodeType::String) => data
598 .try_get(name)
599 .map_err(|_| SqlServerDecodeError::invalid_column(name, "char"))?
600 .map(|val: &str| match length {
601 Some(expected) => {
602 let found_chars = val.chars().count();
603 let expct_chars = usize::cast_from(expected.into_u32());
604 if found_chars != expct_chars {
605 Err(SqlServerDecodeError::invalid_char(
606 name,
607 expct_chars,
608 found_chars,
609 ))
610 } else {
611 Ok(Datum::String(val))
612 }
613 }
614 None => Ok(Datum::String(val)),
615 })
616 .transpose()?,
617 (ScalarType::VarChar { max_length }, SqlServerColumnDecodeType::String) => data
618 .try_get(name)
619 .map_err(|_| SqlServerDecodeError::invalid_column(name, "varchar"))?
620 .map(|val: &str| match max_length {
621 Some(max) => {
622 let found_chars = val.chars().count();
623 let max_chars = usize::cast_from(max.into_u32());
624 if found_chars > max_chars {
625 Err(SqlServerDecodeError::invalid_varchar(
626 name,
627 max_chars,
628 found_chars,
629 ))
630 } else {
631 Ok(Datum::String(val))
632 }
633 }
634 None => Ok(Datum::String(val)),
635 })
636 .transpose()?,
637 (ScalarType::Bytes, SqlServerColumnDecodeType::Bytes) => data
638 .try_get(name)
639 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes"))?
640 .map(Datum::Bytes),
641 (ScalarType::Uuid, SqlServerColumnDecodeType::Uuid) => data
642 .try_get(name)
643 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid"))?
644 .map(Datum::Uuid),
645 (ScalarType::Numeric { .. }, SqlServerColumnDecodeType::Numeric) => data
646 .try_get(name)
647 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric"))?
648 .map(|val: tiberius::numeric::Numeric| {
649 let numeric = tiberius_numeric_to_mz_numeric(val);
650 Datum::Numeric(OrderedDecimal(numeric))
651 }),
652 (ScalarType::String, SqlServerColumnDecodeType::Xml) => data
653 .try_get(name)
654 .map_err(|_| SqlServerDecodeError::invalid_column(name, "xml"))?
655 .map(|val: &tiberius::xml::XmlData| Datum::String(val.as_ref())),
656 (ScalarType::Date, SqlServerColumnDecodeType::NaiveDate) => data
657 .try_get(name)
658 .map_err(|_| SqlServerDecodeError::invalid_column(name, "date"))?
659 .map(|val: chrono::NaiveDate| {
660 let date = val
661 .try_into()
662 .map_err(|e| SqlServerDecodeError::invalid_date(name, e))?;
663 Ok::<_, SqlServerDecodeError>(Datum::Date(date))
664 })
665 .transpose()?,
666 (ScalarType::Time, SqlServerColumnDecodeType::NaiveTime) => data
667 .try_get(name)
668 .map_err(|_| SqlServerDecodeError::invalid_column(name, "time"))?
669 .map(|val: chrono::NaiveTime| {
670 let rounded = val.round_subsecs(6);
675 let val = if rounded < val {
677 val.trunc_subsecs(6)
678 } else {
679 val
680 };
681 Datum::Time(val)
682 }),
683 (ScalarType::Timestamp { precision }, SqlServerColumnDecodeType::NaiveDateTime) => data
684 .try_get(name)
685 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamp"))?
686 .map(|val: chrono::NaiveDateTime| {
687 let ts: CheckedTimestamp<chrono::NaiveDateTime> = val
688 .try_into()
689 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
690 let rounded = ts
691 .round_to_precision(*precision)
692 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
693 Ok::<_, SqlServerDecodeError>(Datum::Timestamp(rounded))
694 })
695 .transpose()?,
696 (ScalarType::TimestampTz { precision }, SqlServerColumnDecodeType::DateTime) => data
697 .try_get(name)
698 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamptz"))?
699 .map(|val: chrono::DateTime<chrono::Utc>| {
700 let ts: CheckedTimestamp<chrono::DateTime<chrono::Utc>> = val
701 .try_into()
702 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
703 let rounded = ts
704 .round_to_precision(*precision)
705 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
706 Ok::<_, SqlServerDecodeError>(Datum::TimestampTz(rounded))
707 })
708 .transpose()?,
709 (ScalarType::String, SqlServerColumnDecodeType::Bool) => data
711 .try_get(name)
712 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool-text"))?
713 .map(|val: bool| {
714 if val {
715 Datum::String("true")
716 } else {
717 Datum::String("false")
718 }
719 }),
720 (ScalarType::String, SqlServerColumnDecodeType::U8) => data
721 .try_get(name)
722 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8-text"))?
723 .map(|val: u8| {
724 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
725 }),
726 (ScalarType::String, SqlServerColumnDecodeType::I16) => data
727 .try_get(name)
728 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16-text"))?
729 .map(|val: i16| {
730 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
731 }),
732 (ScalarType::String, SqlServerColumnDecodeType::I32) => data
733 .try_get(name)
734 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32-text"))?
735 .map(|val: i32| {
736 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
737 }),
738 (ScalarType::String, SqlServerColumnDecodeType::I64) => data
739 .try_get(name)
740 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64-text"))?
741 .map(|val: i64| {
742 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
743 }),
744 (ScalarType::String, SqlServerColumnDecodeType::F32) => data
745 .try_get(name)
746 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32-text"))?
747 .map(|val: f32| {
748 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
749 }),
750 (ScalarType::String, SqlServerColumnDecodeType::F64) => data
751 .try_get(name)
752 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64-text"))?
753 .map(|val: f64| {
754 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
755 }),
756 (ScalarType::String, SqlServerColumnDecodeType::Uuid) => data
757 .try_get(name)
758 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid-text"))?
759 .map(|val: uuid::Uuid| {
760 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
761 }),
762 (ScalarType::String, SqlServerColumnDecodeType::Bytes) => data
763 .try_get(name)
764 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes-text"))?
765 .map(|val: &[u8]| {
766 let encoded = base64::engine::general_purpose::STANDARD.encode(val);
767 arena.make_datum(|packer| packer.push(Datum::String(&encoded)))
768 }),
769 (ScalarType::String, SqlServerColumnDecodeType::Numeric) => data
770 .try_get(name)
771 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric-text"))?
772 .map(|val: tiberius::numeric::Numeric| {
773 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
774 }),
775 (ScalarType::String, SqlServerColumnDecodeType::NaiveDate) => data
776 .try_get(name)
777 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedate-text"))?
778 .map(|val: chrono::NaiveDate| {
779 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
780 }),
781 (ScalarType::String, SqlServerColumnDecodeType::NaiveTime) => data
782 .try_get(name)
783 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivetime-text"))?
784 .map(|val: chrono::NaiveTime| {
785 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
786 }),
787 (ScalarType::String, SqlServerColumnDecodeType::DateTime) => data
788 .try_get(name)
789 .map_err(|_| SqlServerDecodeError::invalid_column(name, "datetime-text"))?
790 .map(|val: chrono::DateTime<chrono::Utc>| {
791 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
792 }),
793 (ScalarType::String, SqlServerColumnDecodeType::NaiveDateTime) => data
794 .try_get(name)
795 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedatetime-text"))?
796 .map(|val: chrono::NaiveDateTime| {
797 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
798 }),
799 (column_type, decode_type) => {
800 return Err(SqlServerDecodeError::Unsupported {
801 sql_server_type: decode_type.clone(),
802 mz_type: column_type.clone(),
803 });
804 }
805 };
806
807 match (maybe_datum, column.nullable) {
808 (Some(datum), _) => Ok(datum),
809 (None, true) => Ok(Datum::Null),
810 (None, false) => Err(SqlServerDecodeError::InvalidData {
811 column_name: name.to_string(),
812 error: "found Null in non-nullable column".to_string(),
814 }),
815 }
816 }
817}
818
819impl RustType<proto_sql_server_column_desc::DecodeType> for SqlServerColumnDecodeType {
820 fn into_proto(&self) -> proto_sql_server_column_desc::DecodeType {
821 match self {
822 SqlServerColumnDecodeType::Bool => proto_sql_server_column_desc::DecodeType::Bool(()),
823 SqlServerColumnDecodeType::U8 => proto_sql_server_column_desc::DecodeType::U8(()),
824 SqlServerColumnDecodeType::I16 => proto_sql_server_column_desc::DecodeType::I16(()),
825 SqlServerColumnDecodeType::I32 => proto_sql_server_column_desc::DecodeType::I32(()),
826 SqlServerColumnDecodeType::I64 => proto_sql_server_column_desc::DecodeType::I64(()),
827 SqlServerColumnDecodeType::F32 => proto_sql_server_column_desc::DecodeType::F32(()),
828 SqlServerColumnDecodeType::F64 => proto_sql_server_column_desc::DecodeType::F64(()),
829 SqlServerColumnDecodeType::String => {
830 proto_sql_server_column_desc::DecodeType::String(())
831 }
832 SqlServerColumnDecodeType::Bytes => proto_sql_server_column_desc::DecodeType::Bytes(()),
833 SqlServerColumnDecodeType::Uuid => proto_sql_server_column_desc::DecodeType::Uuid(()),
834 SqlServerColumnDecodeType::Numeric => {
835 proto_sql_server_column_desc::DecodeType::Numeric(())
836 }
837 SqlServerColumnDecodeType::Xml => proto_sql_server_column_desc::DecodeType::Xml(()),
838 SqlServerColumnDecodeType::NaiveDate => {
839 proto_sql_server_column_desc::DecodeType::NaiveDate(())
840 }
841 SqlServerColumnDecodeType::NaiveTime => {
842 proto_sql_server_column_desc::DecodeType::NaiveTime(())
843 }
844 SqlServerColumnDecodeType::DateTime => {
845 proto_sql_server_column_desc::DecodeType::DateTime(())
846 }
847 SqlServerColumnDecodeType::NaiveDateTime => {
848 proto_sql_server_column_desc::DecodeType::NaiveDateTime(())
849 }
850 SqlServerColumnDecodeType::Unsupported { context } => {
851 proto_sql_server_column_desc::DecodeType::Unsupported(context.clone())
852 }
853 }
854 }
855
856 fn from_proto(
857 proto: proto_sql_server_column_desc::DecodeType,
858 ) -> Result<Self, mz_proto::TryFromProtoError> {
859 let val = match proto {
860 proto_sql_server_column_desc::DecodeType::Bool(()) => SqlServerColumnDecodeType::Bool,
861 proto_sql_server_column_desc::DecodeType::U8(()) => SqlServerColumnDecodeType::U8,
862 proto_sql_server_column_desc::DecodeType::I16(()) => SqlServerColumnDecodeType::I16,
863 proto_sql_server_column_desc::DecodeType::I32(()) => SqlServerColumnDecodeType::I32,
864 proto_sql_server_column_desc::DecodeType::I64(()) => SqlServerColumnDecodeType::I64,
865 proto_sql_server_column_desc::DecodeType::F32(()) => SqlServerColumnDecodeType::F32,
866 proto_sql_server_column_desc::DecodeType::F64(()) => SqlServerColumnDecodeType::F64,
867 proto_sql_server_column_desc::DecodeType::String(()) => {
868 SqlServerColumnDecodeType::String
869 }
870 proto_sql_server_column_desc::DecodeType::Bytes(()) => SqlServerColumnDecodeType::Bytes,
871 proto_sql_server_column_desc::DecodeType::Uuid(()) => SqlServerColumnDecodeType::Uuid,
872 proto_sql_server_column_desc::DecodeType::Numeric(()) => {
873 SqlServerColumnDecodeType::Numeric
874 }
875 proto_sql_server_column_desc::DecodeType::Xml(()) => SqlServerColumnDecodeType::Xml,
876 proto_sql_server_column_desc::DecodeType::NaiveDate(()) => {
877 SqlServerColumnDecodeType::NaiveDate
878 }
879 proto_sql_server_column_desc::DecodeType::NaiveTime(()) => {
880 SqlServerColumnDecodeType::NaiveTime
881 }
882 proto_sql_server_column_desc::DecodeType::DateTime(()) => {
883 SqlServerColumnDecodeType::DateTime
884 }
885 proto_sql_server_column_desc::DecodeType::NaiveDateTime(()) => {
886 SqlServerColumnDecodeType::NaiveDateTime
887 }
888 proto_sql_server_column_desc::DecodeType::Unsupported(context) => {
889 SqlServerColumnDecodeType::Unsupported { context }
890 }
891 };
892 Ok(val)
893 }
894}
895
896fn tiberius_numeric_to_mz_numeric(val: tiberius::numeric::Numeric) -> Numeric {
899 let mut numeric = mz_repr::adt::numeric::cx_datum().from_i128(val.value());
900 mz_repr::adt::numeric::cx_datum().scaleb(&mut numeric, &Numeric::from(-i32::from(val.scale())));
903 numeric
904}
905
/// Decodes a [`tiberius::Row`] into a [`Row`], one column at a time.
///
/// Created with [`SqlServerRowDecoder::try_new`].
#[derive(Debug)]
pub struct SqlServerRowDecoder {
    /// For every output column, in output order: the SQL Server column name,
    /// the Materialize type to decode into, and how to decode it.
    decoders: Vec<(Arc<str>, ColumnType, SqlServerColumnDecodeType)>,
}
914
915impl SqlServerRowDecoder {
916 pub fn try_new(
920 table: &SqlServerTableDesc,
921 desc: &RelationDesc,
922 ) -> Result<Self, SqlServerError> {
923 let decoders = desc
924 .iter()
925 .map(|(col_name, col_type)| {
926 let sql_server_col = table
927 .columns
928 .iter()
929 .find(|col| col.name.as_ref() == col_name.as_str())
930 .ok_or_else(|| {
931 anyhow::anyhow!("no SQL Server column with name {col_name} found")
933 })?;
934 let Some(sql_server_col_typ) = sql_server_col.column_type.as_ref() else {
935 return Err(SqlServerError::ProgrammingError(format!(
936 "programming error, {col_name} should have been exluded",
937 )));
938 };
939
940 let matches = match (&sql_server_col_typ.scalar_type, &col_type.scalar_type) {
948 (ScalarType::Timestamp { .. }, ScalarType::Timestamp { .. })
949 | (ScalarType::TimestampTz { .. }, ScalarType::TimestampTz { .. }) => {
950 sql_server_col_typ.nullable == col_type.nullable
952 }
953 (_, _) => sql_server_col_typ == col_type,
954 };
955 if !matches {
956 return Err(SqlServerError::ProgrammingError(format!(
957 "programming error, {col_name} has mismatched type {:?} vs {:?}",
958 sql_server_col.column_type, col_type
959 )));
960 }
961
962 let name = Arc::clone(&sql_server_col.name);
963 let decoder = sql_server_col.decode_type.clone();
964 let col_typ = sql_server_col_typ.clone();
969
970 Ok::<_, SqlServerError>((name, col_typ, decoder))
971 })
972 .collect::<Result<_, _>>()?;
973
974 Ok(SqlServerRowDecoder { decoders })
975 }
976
977 pub fn decode(
979 &self,
980 data: &tiberius::Row,
981 row: &mut Row,
982 arena: &RowArena,
983 ) -> Result<(), SqlServerDecodeError> {
984 let mut packer = row.packer();
985 for (col_name, col_type, decoder) in &self.decoders {
986 let datum = decoder.decode(data, col_name, col_type, arena)?;
987 packer.push(datum);
988 }
989 Ok(())
990 }
991}
992
993#[cfg(test)]
994mod tests {
995 use std::collections::BTreeSet;
996 use std::sync::Arc;
997
998 use chrono::NaiveDateTime;
999 use itertools::Itertools;
1000 use mz_ore::assert_contains;
1001 use mz_ore::collections::CollectionExt;
1002 use mz_repr::adt::numeric::NumericMaxScale;
1003 use mz_repr::adt::varchar::VarCharMaxLength;
1004 use mz_repr::{Datum, RelationDesc, Row, RowArena, ScalarType};
1005 use tiberius::RowTestExt;
1006
1007 use crate::desc::{
1008 SqlServerCaptureInstanceRaw, SqlServerColumnDecodeType, SqlServerColumnDesc,
1009 SqlServerTableDesc, SqlServerTableRaw, tiberius_numeric_to_mz_numeric,
1010 };
1011
1012 use super::SqlServerColumnRaw;
1013
1014 impl SqlServerColumnRaw {
1015 fn new(name: &str, data_type: &str) -> Self {
1018 SqlServerColumnRaw {
1019 name: name.into(),
1020 data_type: data_type.into(),
1021 is_nullable: false,
1022 primary_key_constraint: None,
1023 max_length: 0,
1024 precision: 0,
1025 scale: 0,
1026 }
1027 }
1028
1029 fn nullable(mut self, nullable: bool) -> Self {
1030 self.is_nullable = nullable;
1031 self
1032 }
1033
1034 fn max_length(mut self, max_length: i16) -> Self {
1035 self.max_length = max_length;
1036 self
1037 }
1038
1039 fn precision(mut self, precision: u8) -> Self {
1040 self.precision = precision;
1041 self
1042 }
1043
1044 fn scale(mut self, scale: u8) -> Self {
1045 self.scale = scale;
1046 self
1047 }
1048 }
1049
1050 #[mz_ore::test]
1051 fn smoketest_column_raw() {
1052 let raw = SqlServerColumnRaw::new("foo", "bit");
1053 let col = SqlServerColumnDesc::new(&raw);
1054
1055 assert_eq!(&*col.name, "foo");
1056 assert_eq!(col.column_type, Some(ScalarType::Bool.nullable(false)));
1057 assert_eq!(col.decode_type, SqlServerColumnDecodeType::Bool);
1058
1059 let raw = SqlServerColumnRaw::new("foo", "decimal")
1060 .precision(20)
1061 .scale(10);
1062 let col = SqlServerColumnDesc::new(&raw);
1063
1064 let col_type = ScalarType::Numeric {
1065 max_scale: Some(NumericMaxScale::try_from(10i64).expect("known valid")),
1066 }
1067 .nullable(false);
1068 assert_eq!(col.column_type, Some(col_type));
1069 assert_eq!(col.decode_type, SqlServerColumnDecodeType::Numeric);
1070 }
1071
1072 #[mz_ore::test]
1073 fn smoketest_column_raw_invalid() {
1074 let raw = SqlServerColumnRaw::new("foo", "bad_data_type");
1075 let desc = SqlServerColumnDesc::new(&raw);
1076 let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1077 panic!("unexpected decode type {desc:?}");
1078 };
1079 assert_contains!(context, "'bad_data_type' is unimplemented");
1080
1081 let raw = SqlServerColumnRaw::new("foo", "decimal")
1082 .precision(100)
1083 .scale(10);
1084 let desc = SqlServerColumnDesc::new(&raw);
1085 assert!(!desc.is_supported());
1086
1087 let raw = SqlServerColumnRaw::new("foo", "varchar").max_length(-1);
1088 let desc = SqlServerColumnDesc::new(&raw);
1089 let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1090 panic!("unexpected decode type {desc:?}");
1091 };
1092 assert_contains!(context, "columns with unlimited size do not support CDC");
1093 }
1094
1095 #[mz_ore::test]
1096 fn smoketest_decoder() {
1097 let sql_server_columns = [
1098 SqlServerColumnRaw::new("a", "varchar").max_length(16),
1099 SqlServerColumnRaw::new("b", "int").nullable(true),
1100 SqlServerColumnRaw::new("c", "bit"),
1101 ];
1102 let sql_server_desc = SqlServerTableRaw {
1103 schema_name: "my_schema".into(),
1104 name: "my_table".into(),
1105 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1106 name: "my_table_CT".into(),
1107 create_date: NaiveDateTime::parse_from_str(
1108 "2024-01-01 00:00:00",
1109 "%Y-%m-%d %H:%M:%S",
1110 )
1111 .unwrap()
1112 .into(),
1113 }),
1114 columns: sql_server_columns.into(),
1115 };
1116 let sql_server_desc = SqlServerTableDesc::new(sql_server_desc);
1117
1118 let max_length = Some(VarCharMaxLength::try_from(16).unwrap());
1119 let relation_desc = RelationDesc::builder()
1120 .with_column("a", ScalarType::VarChar { max_length }.nullable(false))
1121 .with_column("c", ScalarType::Bool.nullable(false))
1123 .with_column("b", ScalarType::Int32.nullable(true))
1124 .finish();
1125
1126 let decoder = sql_server_desc
1128 .decoder(&relation_desc)
1129 .expect("known valid");
1130
1131 let sql_server_columns = [
1132 tiberius::Column::new("a".to_string(), tiberius::ColumnType::BigVarChar),
1133 tiberius::Column::new("b".to_string(), tiberius::ColumnType::Int4),
1134 tiberius::Column::new("c".to_string(), tiberius::ColumnType::Bit),
1135 ];
1136
1137 let data_a = [
1138 tiberius::ColumnData::String(Some("hello world".into())),
1139 tiberius::ColumnData::I32(Some(42)),
1140 tiberius::ColumnData::Bit(Some(true)),
1141 ];
1142 let sql_server_row_a = tiberius::Row::build(
1143 sql_server_columns
1144 .iter()
1145 .cloned()
1146 .zip_eq(data_a.into_iter()),
1147 );
1148
1149 let data_b = [
1150 tiberius::ColumnData::String(Some("foo bar".into())),
1151 tiberius::ColumnData::I32(None),
1152 tiberius::ColumnData::Bit(Some(false)),
1153 ];
1154 let sql_server_row_b =
1155 tiberius::Row::build(sql_server_columns.into_iter().zip_eq(data_b.into_iter()));
1156
1157 let mut rnd_row = Row::default();
1158 let arena = RowArena::default();
1159
1160 decoder
1161 .decode(&sql_server_row_a, &mut rnd_row, &arena)
1162 .unwrap();
1163 assert_eq!(
1164 &rnd_row,
1165 &Row::pack_slice(&[Datum::String("hello world"), Datum::True, Datum::Int32(42)])
1166 );
1167
1168 decoder
1169 .decode(&sql_server_row_b, &mut rnd_row, &arena)
1170 .unwrap();
1171 assert_eq!(
1172 &rnd_row,
1173 &Row::pack_slice(&[Datum::String("foo bar"), Datum::False, Datum::Null])
1174 );
1175 }
1176
1177 #[mz_ore::test]
1178 fn smoketest_decode_to_string() {
1179 #[track_caller]
1180 fn testcase(
1181 data_type: &'static str,
1182 col_type: tiberius::ColumnType,
1183 col_data: tiberius::ColumnData<'static>,
1184 ) {
1185 let columns = [SqlServerColumnRaw::new("a", data_type)];
1186 let sql_server_desc = SqlServerTableRaw {
1187 schema_name: "my_schema".into(),
1188 name: "my_table".into(),
1189 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1190 name: "my_table_CT".into(),
1191 create_date: NaiveDateTime::parse_from_str(
1192 "2024-01-01 00:00:00",
1193 "%Y-%m-%d %H:%M:%S",
1194 )
1195 .unwrap()
1196 .into(),
1197 }),
1198 columns: columns.into(),
1199 };
1200 let mut sql_server_desc = SqlServerTableDesc::new(sql_server_desc);
1201 sql_server_desc.apply_text_columns(&BTreeSet::from(["a"]));
1202
1203 let relation_desc = RelationDesc::builder()
1205 .with_column("a", ScalarType::String.nullable(false))
1206 .finish();
1207
1208 let decoder = sql_server_desc
1210 .decoder(&relation_desc)
1211 .expect("known valid");
1212
1213 let sql_server_row = tiberius::Row::build([(
1214 tiberius::Column::new("a".to_string(), col_type),
1215 col_data,
1216 )]);
1217 let mut mz_row = Row::default();
1218 let arena = RowArena::new();
1219 decoder
1220 .decode(&sql_server_row, &mut mz_row, &arena)
1221 .unwrap();
1222
1223 let str_datum = mz_row.into_element();
1224 assert!(matches!(str_datum, Datum::String(_)));
1225 }
1226
1227 use tiberius::{ColumnData, ColumnType};
1228
1229 testcase("bit", ColumnType::Bit, ColumnData::Bit(Some(true)));
1230 testcase("bit", ColumnType::Bit, ColumnData::Bit(Some(false)));
1231 testcase("tinyint", ColumnType::Int1, ColumnData::U8(Some(33)));
1232 testcase("smallint", ColumnType::Int2, ColumnData::I16(Some(101)));
1233 testcase("int", ColumnType::Int4, ColumnData::I32(Some(-42)));
1234 {
1235 let datetime = tiberius::time::DateTime::new(10, 300);
1236 testcase(
1237 "datetime",
1238 ColumnType::Datetime,
1239 ColumnData::DateTime(Some(datetime)),
1240 );
1241 }
1242 }
1243
1244 #[mz_ore::test]
1245 #[cfg_attr(miri, ignore)] fn smoketest_numeric_conversion() {
1247 let a = tiberius::numeric::Numeric::new_with_scale(12345, 2);
1248 let rnd = tiberius_numeric_to_mz_numeric(a);
1249 let og = mz_repr::adt::numeric::cx_datum().parse("123.45").unwrap();
1250 assert_eq!(og, rnd);
1251
1252 let a = tiberius::numeric::Numeric::new_with_scale(-99999, 5);
1253 let rnd = tiberius_numeric_to_mz_numeric(a);
1254 let og = mz_repr::adt::numeric::cx_datum().parse("-.99999").unwrap();
1255 assert_eq!(og, rnd);
1256
1257 let a = tiberius::numeric::Numeric::new_with_scale(1, 29);
1258 let rnd = tiberius_numeric_to_mz_numeric(a);
1259 let og = mz_repr::adt::numeric::cx_datum()
1260 .parse("0.00000000000000000000000000001")
1261 .unwrap();
1262 assert_eq!(og, rnd);
1263
1264 let a = tiberius::numeric::Numeric::new_with_scale(-111111111111111111, 0);
1265 let rnd = tiberius_numeric_to_mz_numeric(a);
1266 let og = mz_repr::adt::numeric::cx_datum()
1267 .parse("-111111111111111111")
1268 .unwrap();
1269 assert_eq!(og, rnd);
1270 }
1271
1272 }