1use base64::Engine;
26use chrono::{NaiveDateTime, SubsecRound};
27use dec::OrderedDecimal;
28use mz_ore::cast::CastFrom;
29use mz_proto::{IntoRustIfSome, ProtoType, RustType};
30use mz_repr::adt::char::CharLength;
31use mz_repr::adt::numeric::{Numeric, NumericMaxScale};
32use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampPrecision};
33use mz_repr::adt::varchar::VarCharMaxLength;
34use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlColumnType, SqlScalarType};
35use proptest_derive::Arbitrary;
36use serde::{Deserialize, Serialize};
37
38use std::collections::BTreeSet;
39use std::sync::Arc;
40
41use crate::{SqlServerDecodeError, SqlServerError};
42
43include!(concat!(env!("OUT_DIR"), "/mz_sql_server_util.rs"));
44
/// Description of a SQL Server table, built from [`SqlServerTableRaw`] by
/// [`SqlServerTableDesc::new`], with each column's type parsed into a
/// [`SqlServerColumnDesc`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerTableDesc {
    /// Name of the schema the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table within `schema_name`.
    pub name: Arc<str>,
    /// Columns of the table, in the order they were read from the raw description.
    pub columns: Box<[SqlServerColumnDesc]>,
}
63
64impl SqlServerTableDesc {
65 pub fn new(raw: SqlServerTableRaw) -> Self {
70 let columns: Box<[_]> = raw
71 .columns
72 .into_iter()
73 .map(SqlServerColumnDesc::new)
74 .collect();
75 SqlServerTableDesc {
76 schema_name: raw.schema_name,
77 name: raw.name,
78 columns,
79 }
80 }
81
82 pub fn qualified_name(&self) -> SqlServerQualifiedTableName {
84 SqlServerQualifiedTableName {
85 schema_name: Arc::clone(&self.schema_name),
86 table_name: Arc::clone(&self.name),
87 }
88 }
89
90 pub fn apply_text_columns(&mut self, text_columns: &BTreeSet<&str>) {
93 for column in &mut self.columns {
94 if text_columns.contains(column.name.as_ref()) {
95 column.represent_as_text();
96 }
97 }
98 }
99
100 pub fn apply_excl_columns(&mut self, excl_columns: &BTreeSet<&str>) {
103 for column in &mut self.columns {
104 if excl_columns.contains(column.name.as_ref()) {
105 column.exclude();
106 }
107 }
108 }
109
110 pub fn decoder(&self, desc: &RelationDesc) -> Result<SqlServerRowDecoder, SqlServerError> {
113 let decoder = SqlServerRowDecoder::try_new(self, desc)?;
114 Ok(decoder)
115 }
116}
117
118impl RustType<ProtoSqlServerTableDesc> for SqlServerTableDesc {
119 fn into_proto(&self) -> ProtoSqlServerTableDesc {
120 ProtoSqlServerTableDesc {
121 name: self.name.to_string(),
122 schema_name: self.schema_name.to_string(),
123 columns: self.columns.iter().map(|c| c.into_proto()).collect(),
124 }
125 }
126
127 fn from_proto(proto: ProtoSqlServerTableDesc) -> Result<Self, mz_proto::TryFromProtoError> {
128 let columns = proto
129 .columns
130 .into_iter()
131 .map(|c| c.into_rust())
132 .collect::<Result<_, _>>()?;
133 Ok(SqlServerTableDesc {
134 schema_name: proto.schema_name.into(),
135 name: proto.name.into(),
136 columns,
137 })
138 }
139}
140
/// Schema-qualified name of a SQL Server table.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SqlServerQualifiedTableName {
    /// Name of the schema.
    pub schema_name: Arc<str>,
    /// Name of the table within `schema_name`.
    pub table_name: Arc<str>,
}
149
150impl ToString for SqlServerQualifiedTableName {
151 fn to_string(&self) -> String {
152 format!("[{}].[{}]", self.schema_name, self.table_name)
153 }
154}
155
/// Raw, unparsed metadata for a SQL Server table; converted into a
/// [`SqlServerTableDesc`] by [`SqlServerTableDesc::new`].
#[derive(Debug, Clone)]
pub struct SqlServerTableRaw {
    /// Name of the schema the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table within `schema_name`.
    pub name: Arc<str>,
    /// Capture instance associated with this table. It is not carried over into
    /// `SqlServerTableDesc`.
    pub capture_instance: Arc<SqlServerCaptureInstanceRaw>,
    /// Raw column descriptions, in table order.
    pub columns: Arc<[SqlServerColumnRaw]>,
}
171
/// Raw description of a capture instance for a table.
#[derive(Debug, Clone)]
pub struct SqlServerCaptureInstanceRaw {
    /// Name of the capture instance.
    pub name: Arc<str>,
    /// Creation timestamp of the capture instance, as reported by SQL Server
    /// (presumably in server-local time; TODO confirm).
    pub create_date: Arc<NaiveDateTime>,
}
180
/// Parsed description of a single SQL Server column.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerColumnDesc {
    /// Name of the column.
    pub name: Arc<str>,
    /// Materialize type used to represent the column. `None` when the column is
    /// excluded or its SQL Server type is unsupported.
    pub column_type: Option<SqlColumnType>,
    /// Primary key constraint copied from the raw column description, if any.
    pub primary_key_constraint: Option<Arc<str>>,
    /// How values of this column are read out of a `tiberius::Row`.
    pub decode_type: SqlServerColumnDecodeType,
    /// The column's data type name exactly as reported by SQL Server.
    pub raw_type: Arc<str>,
}
201
202impl SqlServerColumnDesc {
203 pub fn new(raw: &SqlServerColumnRaw) -> Self {
205 let (column_type, decode_type) = match parse_data_type(raw) {
206 Ok((scalar_type, decode_type)) => {
207 let column_type = scalar_type.nullable(raw.is_nullable);
208 (Some(column_type), decode_type)
209 }
210 Err(err) => {
211 tracing::warn!(
212 ?err,
213 ?raw,
214 "found an unsupported data type when parsing raw data"
215 );
216 (
217 None,
218 SqlServerColumnDecodeType::Unsupported {
219 context: err.reason,
220 },
221 )
222 }
223 };
224 SqlServerColumnDesc {
225 name: Arc::clone(&raw.name),
226 primary_key_constraint: raw.primary_key_constraint.clone(),
227 column_type,
228 decode_type,
229 raw_type: Arc::clone(&raw.data_type),
230 }
231 }
232
233 pub fn is_supported(&self) -> bool {
235 !matches!(
236 self.decode_type,
237 SqlServerColumnDecodeType::Unsupported { .. }
238 )
239 }
240
241 pub fn represent_as_text(&mut self) {
243 self.column_type = self
244 .column_type
245 .as_ref()
246 .map(|ct| SqlScalarType::String.nullable(ct.nullable));
247 }
248
249 pub fn exclude(&mut self) {
251 self.column_type = None;
252 }
253
254 pub fn is_excluded(&self) -> bool {
256 self.column_type.is_none()
257 }
258}
259
260impl RustType<ProtoSqlServerColumnDesc> for SqlServerColumnDesc {
261 fn into_proto(&self) -> ProtoSqlServerColumnDesc {
262 ProtoSqlServerColumnDesc {
263 name: self.name.to_string(),
264 column_type: self.column_type.into_proto(),
265 primary_key_constraint: self.primary_key_constraint.as_ref().map(|v| v.to_string()),
266 decode_type: Some(self.decode_type.into_proto()),
267 raw_type: self.raw_type.to_string(),
268 }
269 }
270
271 fn from_proto(proto: ProtoSqlServerColumnDesc) -> Result<Self, mz_proto::TryFromProtoError> {
272 Ok(SqlServerColumnDesc {
273 name: proto.name.into(),
274 column_type: proto.column_type.into_rust()?,
275 primary_key_constraint: proto.primary_key_constraint.map(|v| v.into()),
276 decode_type: proto
277 .decode_type
278 .into_rust_if_some("ProtoSqlServerColumnDesc::decode_type")?,
279 raw_type: proto.raw_type.into(),
280 })
281 }
282}
283
/// Explains why a SQL Server column's data type could not be mapped to a Materialize
/// type; produced by `parse_data_type`.
#[derive(Debug)]
// `column_name` and `column_type` are only surfaced through the derived `Debug` output
// (e.g. the `?err` field logged in `SqlServerColumnDesc::new`). The dead-code lint
// does not count that as a read, hence the allow.
#[allow(dead_code)]
pub struct UnsupportedDataType {
    /// Name of the offending column.
    column_name: String,
    /// The column's data type name. This is the raw name, except for unrecognized types,
    /// where it is the lowercased name.
    column_type: String,
    /// Human-readable explanation. `SqlServerColumnDesc::new` stores it as the
    /// `context` of `SqlServerColumnDecodeType::Unsupported`.
    reason: String,
}
292
293fn parse_data_type(
298 raw: &SqlServerColumnRaw,
299) -> Result<(SqlScalarType, SqlServerColumnDecodeType), UnsupportedDataType> {
300 let scalar = match raw.data_type.to_lowercase().as_str() {
301 "tinyint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::U8),
302 "smallint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::I16),
303 "int" => (SqlScalarType::Int32, SqlServerColumnDecodeType::I32),
304 "bigint" => (SqlScalarType::Int64, SqlServerColumnDecodeType::I64),
305 "bit" => (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool),
306 "decimal" | "numeric" | "money" | "smallmoney" => {
307 if raw.precision > 38 || raw.scale > raw.precision {
314 tracing::warn!(
315 "unexpected value from SQL Server, precision of {} and scale of {}",
316 raw.precision,
317 raw.scale,
318 );
319 }
320 if raw.precision > 39 {
321 let reason = format!(
322 "precision of {} is greater than our maximum of 39",
323 raw.precision
324 );
325 return Err(UnsupportedDataType {
326 column_name: raw.name.to_string(),
327 column_type: raw.data_type.to_string(),
328 reason,
329 });
330 }
331
332 let raw_scale = usize::cast_from(raw.scale);
333 let max_scale =
334 NumericMaxScale::try_from(raw_scale).map_err(|_| UnsupportedDataType {
335 column_type: raw.data_type.to_string(),
336 column_name: raw.name.to_string(),
337 reason: format!("scale of {} is too large", raw.scale),
338 })?;
339 let column_type = SqlScalarType::Numeric {
340 max_scale: Some(max_scale),
341 };
342
343 (column_type, SqlServerColumnDecodeType::Numeric)
344 }
345 "real" | "float" | "double precision" => match raw.max_length {
353 4 => (SqlScalarType::Float32, SqlServerColumnDecodeType::F32),
356 8 => (SqlScalarType::Float64, SqlServerColumnDecodeType::F64),
357 _ => {
358 return Err(UnsupportedDataType {
359 column_name: raw.name.to_string(),
360 column_type: raw.data_type.to_string(),
361 reason: format!("unsupported length {}", raw.max_length),
362 });
363 }
364 },
365 dt @ ("char" | "nchar" | "varchar" | "nvarchar" | "sysname") => {
366 if raw.max_length == -1 {
371 return Err(UnsupportedDataType {
372 column_name: raw.name.to_string(),
373 column_type: raw.data_type.to_string(),
374 reason: "columns with unlimited size do not support CDC".to_string(),
375 });
376 }
377
378 let column_type = match dt {
379 "char" => {
380 let length = CharLength::try_from(i64::from(raw.max_length)).map_err(|e| {
381 UnsupportedDataType {
382 column_name: raw.name.to_string(),
383 column_type: raw.data_type.to_string(),
384 reason: e.to_string(),
385 }
386 })?;
387 SqlScalarType::Char {
388 length: Some(length),
389 }
390 }
391 "varchar" => {
392 let length =
393 VarCharMaxLength::try_from(i64::from(raw.max_length)).map_err(|e| {
394 UnsupportedDataType {
395 column_name: raw.name.to_string(),
396 column_type: raw.data_type.to_string(),
397 reason: e.to_string(),
398 }
399 })?;
400 SqlScalarType::VarChar {
401 max_length: Some(length),
402 }
403 }
404 "nchar" | "nvarchar" | "sysname" => SqlScalarType::String,
408 other => unreachable!("'{other}' checked above"),
409 };
410
411 (column_type, SqlServerColumnDecodeType::String)
412 }
413 "text" | "ntext" | "image" => {
414 mz_ore::soft_assert_eq_no_log!(raw.max_length, 16);
417
418 return Err(UnsupportedDataType {
420 column_name: raw.name.to_string(),
421 column_type: raw.data_type.to_string(),
422 reason: "columns with unlimited size do not support CDC".to_string(),
423 });
424 }
425 "xml" => {
426 if raw.max_length == -1 {
431 return Err(UnsupportedDataType {
432 column_name: raw.name.to_string(),
433 column_type: raw.data_type.to_string(),
434 reason: "columns with unlimited size do not support CDC".to_string(),
435 });
436 }
437 (SqlScalarType::String, SqlServerColumnDecodeType::Xml)
438 }
439 "binary" | "varbinary" => {
440 if raw.max_length == -1 {
446 return Err(UnsupportedDataType {
447 column_name: raw.name.to_string(),
448 column_type: raw.data_type.to_string(),
449 reason: "columns with unlimited size do not support CDC".to_string(),
450 });
451 }
452
453 (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes)
454 }
455 "json" => (SqlScalarType::Jsonb, SqlServerColumnDecodeType::String),
456 "date" => (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate),
457 "time" => (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime),
470 dt @ ("smalldatetime" | "datetime" | "datetime2" | "datetimeoffset") => {
471 if raw.scale > 7 {
472 tracing::warn!("unexpected scale '{}' from SQL Server", raw.scale);
473 }
474 if raw.scale > mz_repr::adt::timestamp::MAX_PRECISION {
475 tracing::warn!("truncating scale of '{}' for '{}'", raw.scale, dt);
476 }
477 let precision = std::cmp::min(raw.scale, mz_repr::adt::timestamp::MAX_PRECISION);
478 let precision =
479 Some(TimestampPrecision::try_from(i64::from(precision)).expect("known to fit"));
480
481 match dt {
482 "smalldatetime" | "datetime" | "datetime2" => (
483 SqlScalarType::Timestamp { precision },
484 SqlServerColumnDecodeType::NaiveDateTime,
485 ),
486 "datetimeoffset" => (
487 SqlScalarType::TimestampTz { precision },
488 SqlServerColumnDecodeType::DateTime,
489 ),
490 other => unreachable!("'{other}' checked above"),
491 }
492 }
493 "uniqueidentifier" => (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid),
494 other => {
507 return Err(UnsupportedDataType {
508 column_type: other.to_string(),
509 column_name: raw.name.to_string(),
510 reason: format!("'{other}' is unimplemented"),
511 });
512 }
513 };
514 Ok(scalar)
515}
516
/// Raw, unparsed metadata for a single SQL Server column.
#[derive(Clone, Debug)]
pub struct SqlServerColumnRaw {
    /// Name of the column.
    pub name: Arc<str>,
    /// Data type name as reported by SQL Server. `parse_data_type` matches it
    /// case-insensitively.
    pub data_type: Arc<str>,
    /// Whether the column accepts `NULL`.
    pub is_nullable: bool,
    /// Name of the primary key constraint the column is part of, if any (presumably;
    /// confirm against the metadata query).
    pub primary_key_constraint: Option<Arc<str>>,
    /// Maximum length of the column. `parse_data_type` treats `-1` as an
    /// unlimited-size column. For `real`/`float` it uses this value (4 or 8) to choose
    /// between 32-bit and 64-bit floats.
    pub max_length: i16,
    /// Numeric precision; used for `decimal`/`numeric`/`money`/`smallmoney` columns.
    pub precision: u8,
    /// Numeric scale. Used for decimal columns and as the fractional-second precision
    /// of the datetime family of types.
    pub scale: u8,
}
545
/// How a column's value is read out of a `tiberius::Row`.
///
/// Each variant names the Rust type that `decode` requests from the row. The resulting
/// datum is either the matching Materialize type or, for columns represented as text, a
/// string rendering of the value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub enum SqlServerColumnDecodeType {
    /// Read as `bool` (`bit` columns).
    Bool,
    /// Read as `u8` (`tinyint` columns) and widened to `Int16`.
    U8,
    /// Read as `i16`.
    I16,
    /// Read as `i32`.
    I32,
    /// Read as `i64`.
    I64,
    /// Read as `f32`.
    F32,
    /// Read as `f64`.
    F64,
    /// Read as `&str` (character types and `json`).
    String,
    /// Read as `&[u8]`.
    Bytes,
    /// Read as `uuid::Uuid`.
    Uuid,
    /// Read as `tiberius::numeric::Numeric`.
    Numeric,
    /// Read as `tiberius::xml::XmlData`.
    Xml,
    /// Read as `chrono::NaiveDate`.
    NaiveDate,
    /// Read as `chrono::NaiveTime`.
    NaiveTime,
    /// Read as `chrono::DateTime<chrono::Utc>` (`datetimeoffset` columns).
    DateTime,
    /// Read as `chrono::NaiveDateTime` (`smalldatetime`, `datetime`, `datetime2`).
    NaiveDateTime,
    /// The column's type is not supported; decoding always fails.
    Unsupported {
        /// Why the type is unsupported.
        context: String,
    },
}
578
579impl SqlServerColumnDecodeType {
580 pub fn decode<'a>(
582 &self,
583 data: &'a tiberius::Row,
584 name: &'a str,
585 column: &'a SqlColumnType,
586 arena: &'a RowArena,
587 ) -> Result<Datum<'a>, SqlServerDecodeError> {
588 let maybe_datum = match (&column.scalar_type, self) {
589 (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool) => data
590 .try_get(name)
591 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool"))?
592 .map(|val: bool| if val { Datum::True } else { Datum::False }),
593 (SqlScalarType::Int16, SqlServerColumnDecodeType::U8) => data
594 .try_get(name)
595 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8"))?
596 .map(|val: u8| Datum::Int16(i16::cast_from(val))),
597 (SqlScalarType::Int16, SqlServerColumnDecodeType::I16) => data
598 .try_get(name)
599 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16"))?
600 .map(Datum::Int16),
601 (SqlScalarType::Int32, SqlServerColumnDecodeType::I32) => data
602 .try_get(name)
603 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32"))?
604 .map(Datum::Int32),
605 (SqlScalarType::Int64, SqlServerColumnDecodeType::I64) => data
606 .try_get(name)
607 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64"))?
608 .map(Datum::Int64),
609 (SqlScalarType::Float32, SqlServerColumnDecodeType::F32) => data
610 .try_get(name)
611 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32"))?
612 .map(|val: f32| Datum::Float32(ordered_float::OrderedFloat(val))),
613 (SqlScalarType::Float64, SqlServerColumnDecodeType::F64) => data
614 .try_get(name)
615 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64"))?
616 .map(|val: f64| Datum::Float64(ordered_float::OrderedFloat(val))),
617 (SqlScalarType::String, SqlServerColumnDecodeType::String) => data
618 .try_get(name)
619 .map_err(|_| SqlServerDecodeError::invalid_column(name, "string"))?
620 .map(Datum::String),
621 (SqlScalarType::Char { length }, SqlServerColumnDecodeType::String) => data
622 .try_get(name)
623 .map_err(|_| SqlServerDecodeError::invalid_column(name, "char"))?
624 .map(|val: &str| match length {
625 Some(expected) => {
626 let found_chars = val.chars().count();
627 let expct_chars = usize::cast_from(expected.into_u32());
628 if found_chars != expct_chars {
629 Err(SqlServerDecodeError::invalid_char(
630 name,
631 expct_chars,
632 found_chars,
633 ))
634 } else {
635 Ok(Datum::String(val))
636 }
637 }
638 None => Ok(Datum::String(val)),
639 })
640 .transpose()?,
641 (SqlScalarType::VarChar { max_length }, SqlServerColumnDecodeType::String) => data
642 .try_get(name)
643 .map_err(|_| SqlServerDecodeError::invalid_column(name, "varchar"))?
644 .map(|val: &str| match max_length {
645 Some(max) => {
646 let found_chars = val.chars().count();
647 let max_chars = usize::cast_from(max.into_u32());
648 if found_chars > max_chars {
649 Err(SqlServerDecodeError::invalid_varchar(
650 name,
651 max_chars,
652 found_chars,
653 ))
654 } else {
655 Ok(Datum::String(val))
656 }
657 }
658 None => Ok(Datum::String(val)),
659 })
660 .transpose()?,
661 (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes) => data
662 .try_get(name)
663 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes"))?
664 .map(Datum::Bytes),
665 (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid) => data
666 .try_get(name)
667 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid"))?
668 .map(Datum::Uuid),
669 (SqlScalarType::Numeric { .. }, SqlServerColumnDecodeType::Numeric) => data
670 .try_get(name)
671 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric"))?
672 .map(|val: tiberius::numeric::Numeric| {
673 let numeric = tiberius_numeric_to_mz_numeric(val);
674 Datum::Numeric(OrderedDecimal(numeric))
675 }),
676 (SqlScalarType::String, SqlServerColumnDecodeType::Xml) => data
677 .try_get(name)
678 .map_err(|_| SqlServerDecodeError::invalid_column(name, "xml"))?
679 .map(|val: &tiberius::xml::XmlData| Datum::String(val.as_ref())),
680 (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate) => data
681 .try_get(name)
682 .map_err(|_| SqlServerDecodeError::invalid_column(name, "date"))?
683 .map(|val: chrono::NaiveDate| {
684 let date = val
685 .try_into()
686 .map_err(|e| SqlServerDecodeError::invalid_date(name, e))?;
687 Ok::<_, SqlServerDecodeError>(Datum::Date(date))
688 })
689 .transpose()?,
690 (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime) => data
691 .try_get(name)
692 .map_err(|_| SqlServerDecodeError::invalid_column(name, "time"))?
693 .map(|val: chrono::NaiveTime| {
694 let rounded = val.round_subsecs(6);
699 let val = if rounded < val {
701 val.trunc_subsecs(6)
702 } else {
703 val
704 };
705 Datum::Time(val)
706 }),
707 (SqlScalarType::Timestamp { precision }, SqlServerColumnDecodeType::NaiveDateTime) => {
708 data.try_get(name)
709 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamp"))?
710 .map(|val: chrono::NaiveDateTime| {
711 let ts: CheckedTimestamp<chrono::NaiveDateTime> = val
712 .try_into()
713 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
714 let rounded = ts
715 .round_to_precision(*precision)
716 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
717 Ok::<_, SqlServerDecodeError>(Datum::Timestamp(rounded))
718 })
719 .transpose()?
720 }
721 (SqlScalarType::TimestampTz { precision }, SqlServerColumnDecodeType::DateTime) => data
722 .try_get(name)
723 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamptz"))?
724 .map(|val: chrono::DateTime<chrono::Utc>| {
725 let ts: CheckedTimestamp<chrono::DateTime<chrono::Utc>> = val
726 .try_into()
727 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
728 let rounded = ts
729 .round_to_precision(*precision)
730 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
731 Ok::<_, SqlServerDecodeError>(Datum::TimestampTz(rounded))
732 })
733 .transpose()?,
734 (SqlScalarType::String, SqlServerColumnDecodeType::Bool) => data
736 .try_get(name)
737 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool-text"))?
738 .map(|val: bool| {
739 if val {
740 Datum::String("true")
741 } else {
742 Datum::String("false")
743 }
744 }),
745 (SqlScalarType::String, SqlServerColumnDecodeType::U8) => data
746 .try_get(name)
747 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8-text"))?
748 .map(|val: u8| {
749 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
750 }),
751 (SqlScalarType::String, SqlServerColumnDecodeType::I16) => data
752 .try_get(name)
753 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16-text"))?
754 .map(|val: i16| {
755 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
756 }),
757 (SqlScalarType::String, SqlServerColumnDecodeType::I32) => data
758 .try_get(name)
759 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32-text"))?
760 .map(|val: i32| {
761 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
762 }),
763 (SqlScalarType::String, SqlServerColumnDecodeType::I64) => data
764 .try_get(name)
765 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64-text"))?
766 .map(|val: i64| {
767 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
768 }),
769 (SqlScalarType::String, SqlServerColumnDecodeType::F32) => data
770 .try_get(name)
771 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32-text"))?
772 .map(|val: f32| {
773 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
774 }),
775 (SqlScalarType::String, SqlServerColumnDecodeType::F64) => data
776 .try_get(name)
777 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64-text"))?
778 .map(|val: f64| {
779 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
780 }),
781 (SqlScalarType::String, SqlServerColumnDecodeType::Uuid) => data
782 .try_get(name)
783 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid-text"))?
784 .map(|val: uuid::Uuid| {
785 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
786 }),
787 (SqlScalarType::String, SqlServerColumnDecodeType::Bytes) => data
788 .try_get(name)
789 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes-text"))?
790 .map(|val: &[u8]| {
791 let encoded = base64::engine::general_purpose::STANDARD.encode(val);
792 arena.make_datum(|packer| packer.push(Datum::String(&encoded)))
793 }),
794 (SqlScalarType::String, SqlServerColumnDecodeType::Numeric) => data
795 .try_get(name)
796 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric-text"))?
797 .map(|val: tiberius::numeric::Numeric| {
798 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
799 }),
800 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDate) => data
801 .try_get(name)
802 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedate-text"))?
803 .map(|val: chrono::NaiveDate| {
804 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
805 }),
806 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveTime) => data
807 .try_get(name)
808 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivetime-text"))?
809 .map(|val: chrono::NaiveTime| {
810 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
811 }),
812 (SqlScalarType::String, SqlServerColumnDecodeType::DateTime) => data
813 .try_get(name)
814 .map_err(|_| SqlServerDecodeError::invalid_column(name, "datetime-text"))?
815 .map(|val: chrono::DateTime<chrono::Utc>| {
816 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
817 }),
818 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDateTime) => data
819 .try_get(name)
820 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedatetime-text"))?
821 .map(|val: chrono::NaiveDateTime| {
822 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
823 }),
824 (column_type, decode_type) => {
825 return Err(SqlServerDecodeError::Unsupported {
826 sql_server_type: decode_type.clone(),
827 mz_type: column_type.clone(),
828 });
829 }
830 };
831
832 match (maybe_datum, column.nullable) {
833 (Some(datum), _) => Ok(datum),
834 (None, true) => Ok(Datum::Null),
835 (None, false) => Err(SqlServerDecodeError::InvalidData {
836 column_name: name.to_string(),
837 error: "found Null in non-nullable column".to_string(),
839 }),
840 }
841 }
842}
843
844impl RustType<proto_sql_server_column_desc::DecodeType> for SqlServerColumnDecodeType {
845 fn into_proto(&self) -> proto_sql_server_column_desc::DecodeType {
846 match self {
847 SqlServerColumnDecodeType::Bool => proto_sql_server_column_desc::DecodeType::Bool(()),
848 SqlServerColumnDecodeType::U8 => proto_sql_server_column_desc::DecodeType::U8(()),
849 SqlServerColumnDecodeType::I16 => proto_sql_server_column_desc::DecodeType::I16(()),
850 SqlServerColumnDecodeType::I32 => proto_sql_server_column_desc::DecodeType::I32(()),
851 SqlServerColumnDecodeType::I64 => proto_sql_server_column_desc::DecodeType::I64(()),
852 SqlServerColumnDecodeType::F32 => proto_sql_server_column_desc::DecodeType::F32(()),
853 SqlServerColumnDecodeType::F64 => proto_sql_server_column_desc::DecodeType::F64(()),
854 SqlServerColumnDecodeType::String => {
855 proto_sql_server_column_desc::DecodeType::String(())
856 }
857 SqlServerColumnDecodeType::Bytes => proto_sql_server_column_desc::DecodeType::Bytes(()),
858 SqlServerColumnDecodeType::Uuid => proto_sql_server_column_desc::DecodeType::Uuid(()),
859 SqlServerColumnDecodeType::Numeric => {
860 proto_sql_server_column_desc::DecodeType::Numeric(())
861 }
862 SqlServerColumnDecodeType::Xml => proto_sql_server_column_desc::DecodeType::Xml(()),
863 SqlServerColumnDecodeType::NaiveDate => {
864 proto_sql_server_column_desc::DecodeType::NaiveDate(())
865 }
866 SqlServerColumnDecodeType::NaiveTime => {
867 proto_sql_server_column_desc::DecodeType::NaiveTime(())
868 }
869 SqlServerColumnDecodeType::DateTime => {
870 proto_sql_server_column_desc::DecodeType::DateTime(())
871 }
872 SqlServerColumnDecodeType::NaiveDateTime => {
873 proto_sql_server_column_desc::DecodeType::NaiveDateTime(())
874 }
875 SqlServerColumnDecodeType::Unsupported { context } => {
876 proto_sql_server_column_desc::DecodeType::Unsupported(context.clone())
877 }
878 }
879 }
880
881 fn from_proto(
882 proto: proto_sql_server_column_desc::DecodeType,
883 ) -> Result<Self, mz_proto::TryFromProtoError> {
884 let val = match proto {
885 proto_sql_server_column_desc::DecodeType::Bool(()) => SqlServerColumnDecodeType::Bool,
886 proto_sql_server_column_desc::DecodeType::U8(()) => SqlServerColumnDecodeType::U8,
887 proto_sql_server_column_desc::DecodeType::I16(()) => SqlServerColumnDecodeType::I16,
888 proto_sql_server_column_desc::DecodeType::I32(()) => SqlServerColumnDecodeType::I32,
889 proto_sql_server_column_desc::DecodeType::I64(()) => SqlServerColumnDecodeType::I64,
890 proto_sql_server_column_desc::DecodeType::F32(()) => SqlServerColumnDecodeType::F32,
891 proto_sql_server_column_desc::DecodeType::F64(()) => SqlServerColumnDecodeType::F64,
892 proto_sql_server_column_desc::DecodeType::String(()) => {
893 SqlServerColumnDecodeType::String
894 }
895 proto_sql_server_column_desc::DecodeType::Bytes(()) => SqlServerColumnDecodeType::Bytes,
896 proto_sql_server_column_desc::DecodeType::Uuid(()) => SqlServerColumnDecodeType::Uuid,
897 proto_sql_server_column_desc::DecodeType::Numeric(()) => {
898 SqlServerColumnDecodeType::Numeric
899 }
900 proto_sql_server_column_desc::DecodeType::Xml(()) => SqlServerColumnDecodeType::Xml,
901 proto_sql_server_column_desc::DecodeType::NaiveDate(()) => {
902 SqlServerColumnDecodeType::NaiveDate
903 }
904 proto_sql_server_column_desc::DecodeType::NaiveTime(()) => {
905 SqlServerColumnDecodeType::NaiveTime
906 }
907 proto_sql_server_column_desc::DecodeType::DateTime(()) => {
908 SqlServerColumnDecodeType::DateTime
909 }
910 proto_sql_server_column_desc::DecodeType::NaiveDateTime(()) => {
911 SqlServerColumnDecodeType::NaiveDateTime
912 }
913 proto_sql_server_column_desc::DecodeType::Unsupported(context) => {
914 SqlServerColumnDecodeType::Unsupported { context }
915 }
916 };
917 Ok(val)
918 }
919}
920
921fn tiberius_numeric_to_mz_numeric(val: tiberius::numeric::Numeric) -> Numeric {
924 let mut numeric = mz_repr::adt::numeric::cx_datum().from_i128(val.value());
925 mz_repr::adt::numeric::cx_datum().scaleb(&mut numeric, &Numeric::from(-i32::from(val.scale())));
928 numeric
929}
930
/// Decodes `tiberius::Row`s into Materialize [`Row`]s; built by
/// [`SqlServerRowDecoder::try_new`].
#[derive(Debug)]
pub struct SqlServerRowDecoder {
    /// One entry per output column, in the column order of the `RelationDesc` passed to
    /// `try_new`: the SQL Server column name, its Materialize type, and how to read it.
    decoders: Vec<(Arc<str>, SqlColumnType, SqlServerColumnDecodeType)>,
}
939
940impl SqlServerRowDecoder {
941 pub fn try_new(
945 table: &SqlServerTableDesc,
946 desc: &RelationDesc,
947 ) -> Result<Self, SqlServerError> {
948 let decoders = desc
949 .iter()
950 .map(|(col_name, col_type)| {
951 let sql_server_col = table
952 .columns
953 .iter()
954 .find(|col| col.name.as_ref() == col_name.as_str())
955 .ok_or_else(|| {
956 anyhow::anyhow!("no SQL Server column with name {col_name} found")
958 })?;
959 let Some(sql_server_col_typ) = sql_server_col.column_type.as_ref() else {
960 return Err(SqlServerError::ProgrammingError(format!(
961 "programming error, {col_name} should have been exluded",
962 )));
963 };
964
965 let matches = match (&sql_server_col_typ.scalar_type, &col_type.scalar_type) {
973 (SqlScalarType::Timestamp { .. }, SqlScalarType::Timestamp { .. })
974 | (SqlScalarType::TimestampTz { .. }, SqlScalarType::TimestampTz { .. }) => {
975 sql_server_col_typ.nullable == col_type.nullable
977 }
978 (_, _) => sql_server_col_typ == col_type,
979 };
980 if !matches {
981 return Err(SqlServerError::ProgrammingError(format!(
982 "programming error, {col_name} has mismatched type {:?} vs {:?}",
983 sql_server_col.column_type, col_type
984 )));
985 }
986
987 let name = Arc::clone(&sql_server_col.name);
988 let decoder = sql_server_col.decode_type.clone();
989 let col_typ = sql_server_col_typ.clone();
994
995 Ok::<_, SqlServerError>((name, col_typ, decoder))
996 })
997 .collect::<Result<_, _>>()?;
998
999 Ok(SqlServerRowDecoder { decoders })
1000 }
1001
1002 pub fn decode(
1004 &self,
1005 data: &tiberius::Row,
1006 row: &mut Row,
1007 arena: &RowArena,
1008 ) -> Result<(), SqlServerDecodeError> {
1009 let mut packer = row.packer();
1010 for (col_name, col_type, decoder) in &self.decoders {
1011 let datum = decoder.decode(data, col_name, col_type, arena)?;
1012 packer.push(datum);
1013 }
1014 Ok(())
1015 }
1016}
1017
1018#[cfg(test)]
1019mod tests {
1020 use std::collections::BTreeSet;
1021 use std::sync::Arc;
1022
1023 use chrono::NaiveDateTime;
1024 use itertools::Itertools;
1025 use mz_ore::assert_contains;
1026 use mz_ore::collections::CollectionExt;
1027 use mz_repr::adt::numeric::NumericMaxScale;
1028 use mz_repr::adt::varchar::VarCharMaxLength;
1029 use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlScalarType};
1030 use tiberius::RowTestExt;
1031
1032 use crate::desc::{
1033 SqlServerCaptureInstanceRaw, SqlServerColumnDecodeType, SqlServerColumnDesc,
1034 SqlServerTableDesc, SqlServerTableRaw, tiberius_numeric_to_mz_numeric,
1035 };
1036
1037 use super::SqlServerColumnRaw;
1038
1039 impl SqlServerColumnRaw {
1040 fn new(name: &str, data_type: &str) -> Self {
1043 SqlServerColumnRaw {
1044 name: name.into(),
1045 data_type: data_type.into(),
1046 is_nullable: false,
1047 primary_key_constraint: None,
1048 max_length: 0,
1049 precision: 0,
1050 scale: 0,
1051 }
1052 }
1053
1054 fn nullable(mut self, nullable: bool) -> Self {
1055 self.is_nullable = nullable;
1056 self
1057 }
1058
1059 fn max_length(mut self, max_length: i16) -> Self {
1060 self.max_length = max_length;
1061 self
1062 }
1063
1064 fn precision(mut self, precision: u8) -> Self {
1065 self.precision = precision;
1066 self
1067 }
1068
1069 fn scale(mut self, scale: u8) -> Self {
1070 self.scale = scale;
1071 self
1072 }
1073 }
1074
1075 #[mz_ore::test]
1076 fn smoketest_column_raw() {
1077 let raw = SqlServerColumnRaw::new("foo", "bit");
1078 let col = SqlServerColumnDesc::new(&raw);
1079
1080 assert_eq!(&*col.name, "foo");
1081 assert_eq!(col.column_type, Some(SqlScalarType::Bool.nullable(false)));
1082 assert_eq!(col.decode_type, SqlServerColumnDecodeType::Bool);
1083
1084 let raw = SqlServerColumnRaw::new("foo", "decimal")
1085 .precision(20)
1086 .scale(10);
1087 let col = SqlServerColumnDesc::new(&raw);
1088
1089 let col_type = SqlScalarType::Numeric {
1090 max_scale: Some(NumericMaxScale::try_from(10i64).expect("known valid")),
1091 }
1092 .nullable(false);
1093 assert_eq!(col.column_type, Some(col_type));
1094 assert_eq!(col.decode_type, SqlServerColumnDecodeType::Numeric);
1095 }
1096
1097 #[mz_ore::test]
1098 fn smoketest_column_raw_invalid() {
1099 let raw = SqlServerColumnRaw::new("foo", "bad_data_type");
1100 let desc = SqlServerColumnDesc::new(&raw);
1101 let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1102 panic!("unexpected decode type {desc:?}");
1103 };
1104 assert_contains!(context, "'bad_data_type' is unimplemented");
1105
1106 let raw = SqlServerColumnRaw::new("foo", "decimal")
1107 .precision(100)
1108 .scale(10);
1109 let desc = SqlServerColumnDesc::new(&raw);
1110 assert!(!desc.is_supported());
1111
1112 let raw = SqlServerColumnRaw::new("foo", "varchar").max_length(-1);
1113 let desc = SqlServerColumnDesc::new(&raw);
1114 let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1115 panic!("unexpected decode type {desc:?}");
1116 };
1117 assert_contains!(context, "columns with unlimited size do not support CDC");
1118 }
1119
1120 #[mz_ore::test]
1121 fn smoketest_decoder() {
1122 let sql_server_columns = [
1123 SqlServerColumnRaw::new("a", "varchar").max_length(16),
1124 SqlServerColumnRaw::new("b", "int").nullable(true),
1125 SqlServerColumnRaw::new("c", "bit"),
1126 ];
1127 let sql_server_desc = SqlServerTableRaw {
1128 schema_name: "my_schema".into(),
1129 name: "my_table".into(),
1130 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1131 name: "my_table_CT".into(),
1132 create_date: NaiveDateTime::parse_from_str(
1133 "2024-01-01 00:00:00",
1134 "%Y-%m-%d %H:%M:%S",
1135 )
1136 .unwrap()
1137 .into(),
1138 }),
1139 columns: sql_server_columns.into(),
1140 };
1141 let sql_server_desc = SqlServerTableDesc::new(sql_server_desc);
1142
1143 let max_length = Some(VarCharMaxLength::try_from(16).unwrap());
1144 let relation_desc = RelationDesc::builder()
1145 .with_column("a", SqlScalarType::VarChar { max_length }.nullable(false))
1146 .with_column("c", SqlScalarType::Bool.nullable(false))
1148 .with_column("b", SqlScalarType::Int32.nullable(true))
1149 .finish();
1150
1151 let decoder = sql_server_desc
1153 .decoder(&relation_desc)
1154 .expect("known valid");
1155
1156 let sql_server_columns = [
1157 tiberius::Column::new("a".to_string(), tiberius::ColumnType::BigVarChar),
1158 tiberius::Column::new("b".to_string(), tiberius::ColumnType::Int4),
1159 tiberius::Column::new("c".to_string(), tiberius::ColumnType::Bit),
1160 ];
1161
1162 let data_a = [
1163 tiberius::ColumnData::String(Some("hello world".into())),
1164 tiberius::ColumnData::I32(Some(42)),
1165 tiberius::ColumnData::Bit(Some(true)),
1166 ];
1167 let sql_server_row_a = tiberius::Row::build(
1168 sql_server_columns
1169 .iter()
1170 .cloned()
1171 .zip_eq(data_a.into_iter()),
1172 );
1173
1174 let data_b = [
1175 tiberius::ColumnData::String(Some("foo bar".into())),
1176 tiberius::ColumnData::I32(None),
1177 tiberius::ColumnData::Bit(Some(false)),
1178 ];
1179 let sql_server_row_b =
1180 tiberius::Row::build(sql_server_columns.into_iter().zip_eq(data_b.into_iter()));
1181
1182 let mut rnd_row = Row::default();
1183 let arena = RowArena::default();
1184
1185 decoder
1186 .decode(&sql_server_row_a, &mut rnd_row, &arena)
1187 .unwrap();
1188 assert_eq!(
1189 &rnd_row,
1190 &Row::pack_slice(&[Datum::String("hello world"), Datum::True, Datum::Int32(42)])
1191 );
1192
1193 decoder
1194 .decode(&sql_server_row_b, &mut rnd_row, &arena)
1195 .unwrap();
1196 assert_eq!(
1197 &rnd_row,
1198 &Row::pack_slice(&[Datum::String("foo bar"), Datum::False, Datum::Null])
1199 );
1200 }
1201
1202 #[mz_ore::test]
1203 fn smoketest_decode_to_string() {
1204 #[track_caller]
1205 fn testcase(
1206 data_type: &'static str,
1207 col_type: tiberius::ColumnType,
1208 col_data: tiberius::ColumnData<'static>,
1209 ) {
1210 let columns = [SqlServerColumnRaw::new("a", data_type)];
1211 let sql_server_desc = SqlServerTableRaw {
1212 schema_name: "my_schema".into(),
1213 name: "my_table".into(),
1214 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1215 name: "my_table_CT".into(),
1216 create_date: NaiveDateTime::parse_from_str(
1217 "2024-01-01 00:00:00",
1218 "%Y-%m-%d %H:%M:%S",
1219 )
1220 .unwrap()
1221 .into(),
1222 }),
1223 columns: columns.into(),
1224 };
1225 let mut sql_server_desc = SqlServerTableDesc::new(sql_server_desc);
1226 sql_server_desc.apply_text_columns(&BTreeSet::from(["a"]));
1227
1228 let relation_desc = RelationDesc::builder()
1230 .with_column("a", SqlScalarType::String.nullable(false))
1231 .finish();
1232
1233 let decoder = sql_server_desc
1235 .decoder(&relation_desc)
1236 .expect("known valid");
1237
1238 let sql_server_row = tiberius::Row::build([(
1239 tiberius::Column::new("a".to_string(), col_type),
1240 col_data,
1241 )]);
1242 let mut mz_row = Row::default();
1243 let arena = RowArena::new();
1244 decoder
1245 .decode(&sql_server_row, &mut mz_row, &arena)
1246 .unwrap();
1247
1248 let str_datum = mz_row.into_element();
1249 assert!(matches!(str_datum, Datum::String(_)));
1250 }
1251
1252 use tiberius::ColumnData;
1253
1254 testcase(
1255 "bit",
1256 tiberius::ColumnType::Bit,
1257 ColumnData::Bit(Some(true)),
1258 );
1259 testcase(
1260 "bit",
1261 tiberius::ColumnType::Bit,
1262 ColumnData::Bit(Some(false)),
1263 );
1264 testcase(
1265 "tinyint",
1266 tiberius::ColumnType::Int1,
1267 ColumnData::U8(Some(33)),
1268 );
1269 testcase(
1270 "smallint",
1271 tiberius::ColumnType::Int2,
1272 ColumnData::I16(Some(101)),
1273 );
1274 testcase(
1275 "int",
1276 tiberius::ColumnType::Int4,
1277 ColumnData::I32(Some(-42)),
1278 );
1279 {
1280 let datetime = tiberius::time::DateTime::new(10, 300);
1281 testcase(
1282 "datetime",
1283 tiberius::ColumnType::Datetime,
1284 ColumnData::DateTime(Some(datetime)),
1285 );
1286 }
1287 }
1288
1289 #[mz_ore::test]
1290 #[cfg_attr(miri, ignore)] fn smoketest_numeric_conversion() {
1292 let a = tiberius::numeric::Numeric::new_with_scale(12345, 2);
1293 let rnd = tiberius_numeric_to_mz_numeric(a);
1294 let og = mz_repr::adt::numeric::cx_datum().parse("123.45").unwrap();
1295 assert_eq!(og, rnd);
1296
1297 let a = tiberius::numeric::Numeric::new_with_scale(-99999, 5);
1298 let rnd = tiberius_numeric_to_mz_numeric(a);
1299 let og = mz_repr::adt::numeric::cx_datum().parse("-.99999").unwrap();
1300 assert_eq!(og, rnd);
1301
1302 let a = tiberius::numeric::Numeric::new_with_scale(1, 29);
1303 let rnd = tiberius_numeric_to_mz_numeric(a);
1304 let og = mz_repr::adt::numeric::cx_datum()
1305 .parse("0.00000000000000000000000000001")
1306 .unwrap();
1307 assert_eq!(og, rnd);
1308
1309 let a = tiberius::numeric::Numeric::new_with_scale(-111111111111111111, 0);
1310 let rnd = tiberius_numeric_to_mz_numeric(a);
1311 let og = mz_repr::adt::numeric::cx_datum()
1312 .parse("-111111111111111111")
1313 .unwrap();
1314 assert_eq!(og, rnd);
1315 }
1316
1317 }