1use base64::Engine;
26use chrono::{NaiveDateTime, SubsecRound};
27use dec::OrderedDecimal;
28use mz_ore::cast::CastFrom;
29use mz_proto::{IntoRustIfSome, ProtoType, RustType};
30use mz_repr::adt::char::CharLength;
31use mz_repr::adt::numeric::{Numeric, NumericMaxScale};
32use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampPrecision};
33use mz_repr::adt::varchar::VarCharMaxLength;
34use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlColumnType, SqlScalarType};
35use proptest_derive::Arbitrary;
36use serde::{Deserialize, Serialize};
37
38use std::collections::BTreeSet;
39use std::sync::Arc;
40
41use crate::{SqlServerDecodeError, SqlServerError};
42
43include!(concat!(env!("OUT_DIR"), "/mz_sql_server_util.rs"));
44
45#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
55pub struct SqlServerTableDesc {
56 pub schema_name: Arc<str>,
58 pub name: Arc<str>,
60 pub columns: Box<[SqlServerColumnDesc]>,
62}
63
64impl SqlServerTableDesc {
65 pub fn new(raw: SqlServerTableRaw) -> Self {
70 let columns: Box<[_]> = raw
71 .columns
72 .into_iter()
73 .map(SqlServerColumnDesc::new)
74 .collect();
75 SqlServerTableDesc {
76 schema_name: raw.schema_name,
77 name: raw.name,
78 columns,
79 }
80 }
81
82 pub fn qualified_name(&self) -> SqlServerQualifiedTableName {
84 SqlServerQualifiedTableName {
85 schema_name: Arc::clone(&self.schema_name),
86 table_name: Arc::clone(&self.name),
87 }
88 }
89
90 pub fn apply_text_columns(&mut self, text_columns: &BTreeSet<&str>) {
93 for column in &mut self.columns {
94 if text_columns.contains(column.name.as_ref()) {
95 column.represent_as_text();
96 }
97 }
98 }
99
100 pub fn apply_excl_columns(&mut self, excl_columns: &BTreeSet<&str>) {
103 for column in &mut self.columns {
104 if excl_columns.contains(column.name.as_ref()) {
105 column.exclude();
106 }
107 }
108 }
109
110 pub fn decoder(&self, desc: &RelationDesc) -> Result<SqlServerRowDecoder, SqlServerError> {
113 let decoder = SqlServerRowDecoder::try_new(self, desc)?;
114 Ok(decoder)
115 }
116}
117
118impl RustType<ProtoSqlServerTableDesc> for SqlServerTableDesc {
119 fn into_proto(&self) -> ProtoSqlServerTableDesc {
120 ProtoSqlServerTableDesc {
121 name: self.name.to_string(),
122 schema_name: self.schema_name.to_string(),
123 columns: self.columns.iter().map(|c| c.into_proto()).collect(),
124 }
125 }
126
127 fn from_proto(proto: ProtoSqlServerTableDesc) -> Result<Self, mz_proto::TryFromProtoError> {
128 let columns = proto
129 .columns
130 .into_iter()
131 .map(|c| c.into_rust())
132 .collect::<Result<_, _>>()?;
133 Ok(SqlServerTableDesc {
134 schema_name: proto.schema_name.into(),
135 name: proto.name.into(),
136 columns,
137 })
138 }
139}
140
/// A table name qualified by the schema that contains it.
///
/// The derived ordering compares `schema_name` first, then `table_name`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SqlServerQualifiedTableName {
    /// Name of the schema.
    pub schema_name: Arc<str>,
    /// Name of the table within the schema.
    pub table_name: Arc<str>,
}
149
150#[derive(Debug, Clone)]
155pub struct SqlServerTableRaw {
156 pub schema_name: Arc<str>,
158 pub name: Arc<str>,
160 pub capture_instance: Arc<SqlServerCaptureInstanceRaw>,
162 pub columns: Arc<[SqlServerColumnRaw]>,
164}
165
166#[derive(Debug, Clone)]
168pub struct SqlServerCaptureInstanceRaw {
169 pub name: Arc<str>,
171 pub create_date: Arc<NaiveDateTime>,
173}
174
175#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
177pub struct SqlServerColumnDesc {
178 pub name: Arc<str>,
180 pub column_type: Option<SqlColumnType>,
186 pub primary_key_constraint: Option<Arc<str>>,
188 pub decode_type: SqlServerColumnDecodeType,
190 pub raw_type: Arc<str>,
194}
195
196impl SqlServerColumnDesc {
197 pub fn new(raw: &SqlServerColumnRaw) -> Self {
199 let (column_type, decode_type) = match parse_data_type(raw) {
200 Ok((scalar_type, decode_type)) => {
201 let column_type = scalar_type.nullable(raw.is_nullable);
202 (Some(column_type), decode_type)
203 }
204 Err(err) => {
205 tracing::warn!(
206 ?err,
207 ?raw,
208 "found an unsupported data type when parsing raw data"
209 );
210 (
211 None,
212 SqlServerColumnDecodeType::Unsupported {
213 context: err.reason,
214 },
215 )
216 }
217 };
218 SqlServerColumnDesc {
219 name: Arc::clone(&raw.name),
220 primary_key_constraint: raw.primary_key_constraint.clone(),
221 column_type,
222 decode_type,
223 raw_type: Arc::clone(&raw.data_type),
224 }
225 }
226
227 pub fn is_supported(&self) -> bool {
229 !matches!(
230 self.decode_type,
231 SqlServerColumnDecodeType::Unsupported { .. }
232 )
233 }
234
235 pub fn represent_as_text(&mut self) {
237 self.column_type = self
238 .column_type
239 .as_ref()
240 .map(|ct| SqlScalarType::String.nullable(ct.nullable));
241 }
242
243 pub fn exclude(&mut self) {
245 self.column_type = None;
246 }
247
248 pub fn is_excluded(&self) -> bool {
250 self.column_type.is_none()
251 }
252}
253
254impl RustType<ProtoSqlServerColumnDesc> for SqlServerColumnDesc {
255 fn into_proto(&self) -> ProtoSqlServerColumnDesc {
256 ProtoSqlServerColumnDesc {
257 name: self.name.to_string(),
258 column_type: self.column_type.into_proto(),
259 primary_key_constraint: self.primary_key_constraint.as_ref().map(|v| v.to_string()),
260 decode_type: Some(self.decode_type.into_proto()),
261 raw_type: self.raw_type.to_string(),
262 }
263 }
264
265 fn from_proto(proto: ProtoSqlServerColumnDesc) -> Result<Self, mz_proto::TryFromProtoError> {
266 Ok(SqlServerColumnDesc {
267 name: proto.name.into(),
268 column_type: proto.column_type.into_rust()?,
269 primary_key_constraint: proto.primary_key_constraint.map(|v| v.into()),
270 decode_type: proto
271 .decode_type
272 .into_rust_if_some("ProtoSqlServerColumnDesc::decode_type")?,
273 raw_type: proto.raw_type.into(),
274 })
275 }
276}
277
/// Error returned by `parse_data_type` for a SQL Server column type that
/// cannot be mapped to a Materialize type.
#[derive(Debug)]
// In the visible code only `reason` is read directly; the other fields are
// surfaced through the derived `Debug` output, hence the `dead_code` allowance.
#[allow(dead_code)]
pub struct UnsupportedDataType {
    /// Name of the offending column.
    column_name: String,
    /// Data type of the offending column.
    column_type: String,
    /// Explanation of why the type is not supported.
    reason: String,
}
286
287fn parse_data_type(
292 raw: &SqlServerColumnRaw,
293) -> Result<(SqlScalarType, SqlServerColumnDecodeType), UnsupportedDataType> {
294 let scalar = match raw.data_type.to_lowercase().as_str() {
295 "tinyint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::U8),
296 "smallint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::I16),
297 "int" => (SqlScalarType::Int32, SqlServerColumnDecodeType::I32),
298 "bigint" => (SqlScalarType::Int64, SqlServerColumnDecodeType::I64),
299 "bit" => (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool),
300 "decimal" | "numeric" | "money" | "smallmoney" => {
301 if raw.precision > 38 || raw.scale > raw.precision {
308 tracing::warn!(
309 "unexpected value from SQL Server, precision of {} and scale of {}",
310 raw.precision,
311 raw.scale,
312 );
313 }
314 if raw.precision > 39 {
315 let reason = format!(
316 "precision of {} is greater than our maximum of 39",
317 raw.precision
318 );
319 return Err(UnsupportedDataType {
320 column_name: raw.name.to_string(),
321 column_type: raw.data_type.to_string(),
322 reason,
323 });
324 }
325
326 let raw_scale = usize::cast_from(raw.scale);
327 let max_scale =
328 NumericMaxScale::try_from(raw_scale).map_err(|_| UnsupportedDataType {
329 column_type: raw.data_type.to_string(),
330 column_name: raw.name.to_string(),
331 reason: format!("scale of {} is too large", raw.scale),
332 })?;
333 let column_type = SqlScalarType::Numeric {
334 max_scale: Some(max_scale),
335 };
336
337 (column_type, SqlServerColumnDecodeType::Numeric)
338 }
339 "real" => (SqlScalarType::Float32, SqlServerColumnDecodeType::F32),
340 "double" => (SqlScalarType::Float64, SqlServerColumnDecodeType::F64),
341 dt @ ("char" | "nchar" | "varchar" | "nvarchar" | "sysname") => {
342 if raw.max_length == -1 {
347 return Err(UnsupportedDataType {
348 column_name: raw.name.to_string(),
349 column_type: raw.data_type.to_string(),
350 reason: "columns with unlimited size do not support CDC".to_string(),
351 });
352 }
353
354 let column_type = match dt {
355 "char" => {
356 let length = CharLength::try_from(i64::from(raw.max_length)).map_err(|e| {
357 UnsupportedDataType {
358 column_name: raw.name.to_string(),
359 column_type: raw.data_type.to_string(),
360 reason: e.to_string(),
361 }
362 })?;
363 SqlScalarType::Char {
364 length: Some(length),
365 }
366 }
367 "varchar" => {
368 let length =
369 VarCharMaxLength::try_from(i64::from(raw.max_length)).map_err(|e| {
370 UnsupportedDataType {
371 column_name: raw.name.to_string(),
372 column_type: raw.data_type.to_string(),
373 reason: e.to_string(),
374 }
375 })?;
376 SqlScalarType::VarChar {
377 max_length: Some(length),
378 }
379 }
380 "nchar" | "nvarchar" | "sysname" => SqlScalarType::String,
384 other => unreachable!("'{other}' checked above"),
385 };
386
387 (column_type, SqlServerColumnDecodeType::String)
388 }
389 "text" | "ntext" | "image" => {
390 mz_ore::soft_assert_eq_no_log!(raw.max_length, 16);
393
394 return Err(UnsupportedDataType {
396 column_name: raw.name.to_string(),
397 column_type: raw.data_type.to_string(),
398 reason: "columns with unlimited size do not support CDC".to_string(),
399 });
400 }
401 "xml" => {
402 if raw.max_length == -1 {
407 return Err(UnsupportedDataType {
408 column_name: raw.name.to_string(),
409 column_type: raw.data_type.to_string(),
410 reason: "columns with unlimited size do not support CDC".to_string(),
411 });
412 }
413 (SqlScalarType::String, SqlServerColumnDecodeType::Xml)
414 }
415 "binary" | "varbinary" => {
416 if raw.max_length == -1 {
422 return Err(UnsupportedDataType {
423 column_name: raw.name.to_string(),
424 column_type: raw.data_type.to_string(),
425 reason: "columns with unlimited size do not support CDC".to_string(),
426 });
427 }
428
429 (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes)
430 }
431 "json" => (SqlScalarType::Jsonb, SqlServerColumnDecodeType::String),
432 "date" => (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate),
433 "time" => (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime),
446 dt @ ("smalldatetime" | "datetime" | "datetime2" | "datetimeoffset") => {
447 if raw.scale > 7 {
448 tracing::warn!("unexpected scale '{}' from SQL Server", raw.scale);
449 }
450 if raw.scale > mz_repr::adt::timestamp::MAX_PRECISION {
451 tracing::warn!("truncating scale of '{}' for '{}'", raw.scale, dt);
452 }
453 let precision = std::cmp::min(raw.scale, mz_repr::adt::timestamp::MAX_PRECISION);
454 let precision =
455 Some(TimestampPrecision::try_from(i64::from(precision)).expect("known to fit"));
456
457 match dt {
458 "smalldatetime" | "datetime" | "datetime2" => (
459 SqlScalarType::Timestamp { precision },
460 SqlServerColumnDecodeType::NaiveDateTime,
461 ),
462 "datetimeoffset" => (
463 SqlScalarType::TimestampTz { precision },
464 SqlServerColumnDecodeType::DateTime,
465 ),
466 other => unreachable!("'{other}' checked above"),
467 }
468 }
469 "uniqueidentifier" => (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid),
470 other => {
483 return Err(UnsupportedDataType {
484 column_type: other.to_string(),
485 column_name: raw.name.to_string(),
486 reason: format!("'{other}' is unimplemented"),
487 });
488 }
489 };
490 Ok(scalar)
491}
492
/// Raw metadata for a single SQL Server column, before it is mapped to a
/// Materialize type by `parse_data_type`.
#[derive(Clone, Debug)]
pub struct SqlServerColumnRaw {
    /// Name of the column.
    pub name: Arc<str>,
    /// Name of the column's SQL Server data type, e.g. `"int"` or `"varchar"`.
    pub data_type: Arc<str>,
    /// Whether the column accepts `NULL` values.
    pub is_nullable: bool,
    /// Copied verbatim into [`SqlServerColumnDesc::primary_key_constraint`].
    // NOTE(review): presumably the name of the primary key constraint this
    // column is part of, if any; confirm against the metadata query.
    pub primary_key_constraint: Option<Arc<str>>,
    /// Maximum length reported for the column. `parse_data_type` treats `-1`
    /// as "unlimited size".
    pub max_length: i16,
    /// Precision reported for the column; consulted for numeric types.
    pub precision: u8,
    /// Scale reported for the column; consulted for numeric and date/time
    /// types.
    pub scale: u8,
}
521
522#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
524pub enum SqlServerColumnDecodeType {
525 Bool,
526 U8,
527 I16,
528 I32,
529 I64,
530 F32,
531 F64,
532 String,
533 Bytes,
534 Uuid,
536 Numeric,
538 Xml,
540 NaiveDate,
542 NaiveTime,
544 DateTime,
546 NaiveDateTime,
548 Unsupported {
550 context: String,
552 },
553}
554
555impl SqlServerColumnDecodeType {
556 pub fn decode<'a>(
558 &self,
559 data: &'a tiberius::Row,
560 name: &'a str,
561 column: &'a SqlColumnType,
562 arena: &'a RowArena,
563 ) -> Result<Datum<'a>, SqlServerDecodeError> {
564 let maybe_datum = match (&column.scalar_type, self) {
565 (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool) => data
566 .try_get(name)
567 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool"))?
568 .map(|val: bool| if val { Datum::True } else { Datum::False }),
569 (SqlScalarType::Int16, SqlServerColumnDecodeType::U8) => data
570 .try_get(name)
571 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8"))?
572 .map(|val: u8| Datum::Int16(i16::cast_from(val))),
573 (SqlScalarType::Int16, SqlServerColumnDecodeType::I16) => data
574 .try_get(name)
575 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16"))?
576 .map(Datum::Int16),
577 (SqlScalarType::Int32, SqlServerColumnDecodeType::I32) => data
578 .try_get(name)
579 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32"))?
580 .map(Datum::Int32),
581 (SqlScalarType::Int64, SqlServerColumnDecodeType::I64) => data
582 .try_get(name)
583 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64"))?
584 .map(Datum::Int64),
585 (SqlScalarType::Float32, SqlServerColumnDecodeType::F32) => data
586 .try_get(name)
587 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32"))?
588 .map(|val: f32| Datum::Float32(ordered_float::OrderedFloat(val))),
589 (SqlScalarType::Float64, SqlServerColumnDecodeType::F64) => data
590 .try_get(name)
591 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64"))?
592 .map(|val: f64| Datum::Float64(ordered_float::OrderedFloat(val))),
593 (SqlScalarType::String, SqlServerColumnDecodeType::String) => data
594 .try_get(name)
595 .map_err(|_| SqlServerDecodeError::invalid_column(name, "string"))?
596 .map(Datum::String),
597 (SqlScalarType::Char { length }, SqlServerColumnDecodeType::String) => data
598 .try_get(name)
599 .map_err(|_| SqlServerDecodeError::invalid_column(name, "char"))?
600 .map(|val: &str| match length {
601 Some(expected) => {
602 let found_chars = val.chars().count();
603 let expct_chars = usize::cast_from(expected.into_u32());
604 if found_chars != expct_chars {
605 Err(SqlServerDecodeError::invalid_char(
606 name,
607 expct_chars,
608 found_chars,
609 ))
610 } else {
611 Ok(Datum::String(val))
612 }
613 }
614 None => Ok(Datum::String(val)),
615 })
616 .transpose()?,
617 (SqlScalarType::VarChar { max_length }, SqlServerColumnDecodeType::String) => data
618 .try_get(name)
619 .map_err(|_| SqlServerDecodeError::invalid_column(name, "varchar"))?
620 .map(|val: &str| match max_length {
621 Some(max) => {
622 let found_chars = val.chars().count();
623 let max_chars = usize::cast_from(max.into_u32());
624 if found_chars > max_chars {
625 Err(SqlServerDecodeError::invalid_varchar(
626 name,
627 max_chars,
628 found_chars,
629 ))
630 } else {
631 Ok(Datum::String(val))
632 }
633 }
634 None => Ok(Datum::String(val)),
635 })
636 .transpose()?,
637 (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes) => data
638 .try_get(name)
639 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes"))?
640 .map(Datum::Bytes),
641 (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid) => data
642 .try_get(name)
643 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid"))?
644 .map(Datum::Uuid),
645 (SqlScalarType::Numeric { .. }, SqlServerColumnDecodeType::Numeric) => data
646 .try_get(name)
647 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric"))?
648 .map(|val: tiberius::numeric::Numeric| {
649 let numeric = tiberius_numeric_to_mz_numeric(val);
650 Datum::Numeric(OrderedDecimal(numeric))
651 }),
652 (SqlScalarType::String, SqlServerColumnDecodeType::Xml) => data
653 .try_get(name)
654 .map_err(|_| SqlServerDecodeError::invalid_column(name, "xml"))?
655 .map(|val: &tiberius::xml::XmlData| Datum::String(val.as_ref())),
656 (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate) => data
657 .try_get(name)
658 .map_err(|_| SqlServerDecodeError::invalid_column(name, "date"))?
659 .map(|val: chrono::NaiveDate| {
660 let date = val
661 .try_into()
662 .map_err(|e| SqlServerDecodeError::invalid_date(name, e))?;
663 Ok::<_, SqlServerDecodeError>(Datum::Date(date))
664 })
665 .transpose()?,
666 (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime) => data
667 .try_get(name)
668 .map_err(|_| SqlServerDecodeError::invalid_column(name, "time"))?
669 .map(|val: chrono::NaiveTime| {
670 let rounded = val.round_subsecs(6);
675 let val = if rounded < val {
677 val.trunc_subsecs(6)
678 } else {
679 val
680 };
681 Datum::Time(val)
682 }),
683 (SqlScalarType::Timestamp { precision }, SqlServerColumnDecodeType::NaiveDateTime) => {
684 data.try_get(name)
685 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamp"))?
686 .map(|val: chrono::NaiveDateTime| {
687 let ts: CheckedTimestamp<chrono::NaiveDateTime> = val
688 .try_into()
689 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
690 let rounded = ts
691 .round_to_precision(*precision)
692 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
693 Ok::<_, SqlServerDecodeError>(Datum::Timestamp(rounded))
694 })
695 .transpose()?
696 }
697 (SqlScalarType::TimestampTz { precision }, SqlServerColumnDecodeType::DateTime) => data
698 .try_get(name)
699 .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamptz"))?
700 .map(|val: chrono::DateTime<chrono::Utc>| {
701 let ts: CheckedTimestamp<chrono::DateTime<chrono::Utc>> = val
702 .try_into()
703 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
704 let rounded = ts
705 .round_to_precision(*precision)
706 .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
707 Ok::<_, SqlServerDecodeError>(Datum::TimestampTz(rounded))
708 })
709 .transpose()?,
710 (SqlScalarType::String, SqlServerColumnDecodeType::Bool) => data
712 .try_get(name)
713 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool-text"))?
714 .map(|val: bool| {
715 if val {
716 Datum::String("true")
717 } else {
718 Datum::String("false")
719 }
720 }),
721 (SqlScalarType::String, SqlServerColumnDecodeType::U8) => data
722 .try_get(name)
723 .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8-text"))?
724 .map(|val: u8| {
725 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
726 }),
727 (SqlScalarType::String, SqlServerColumnDecodeType::I16) => data
728 .try_get(name)
729 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16-text"))?
730 .map(|val: i16| {
731 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
732 }),
733 (SqlScalarType::String, SqlServerColumnDecodeType::I32) => data
734 .try_get(name)
735 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32-text"))?
736 .map(|val: i32| {
737 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
738 }),
739 (SqlScalarType::String, SqlServerColumnDecodeType::I64) => data
740 .try_get(name)
741 .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64-text"))?
742 .map(|val: i64| {
743 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
744 }),
745 (SqlScalarType::String, SqlServerColumnDecodeType::F32) => data
746 .try_get(name)
747 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32-text"))?
748 .map(|val: f32| {
749 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
750 }),
751 (SqlScalarType::String, SqlServerColumnDecodeType::F64) => data
752 .try_get(name)
753 .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64-text"))?
754 .map(|val: f64| {
755 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
756 }),
757 (SqlScalarType::String, SqlServerColumnDecodeType::Uuid) => data
758 .try_get(name)
759 .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid-text"))?
760 .map(|val: uuid::Uuid| {
761 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
762 }),
763 (SqlScalarType::String, SqlServerColumnDecodeType::Bytes) => data
764 .try_get(name)
765 .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes-text"))?
766 .map(|val: &[u8]| {
767 let encoded = base64::engine::general_purpose::STANDARD.encode(val);
768 arena.make_datum(|packer| packer.push(Datum::String(&encoded)))
769 }),
770 (SqlScalarType::String, SqlServerColumnDecodeType::Numeric) => data
771 .try_get(name)
772 .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric-text"))?
773 .map(|val: tiberius::numeric::Numeric| {
774 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
775 }),
776 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDate) => data
777 .try_get(name)
778 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedate-text"))?
779 .map(|val: chrono::NaiveDate| {
780 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
781 }),
782 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveTime) => data
783 .try_get(name)
784 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivetime-text"))?
785 .map(|val: chrono::NaiveTime| {
786 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
787 }),
788 (SqlScalarType::String, SqlServerColumnDecodeType::DateTime) => data
789 .try_get(name)
790 .map_err(|_| SqlServerDecodeError::invalid_column(name, "datetime-text"))?
791 .map(|val: chrono::DateTime<chrono::Utc>| {
792 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
793 }),
794 (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDateTime) => data
795 .try_get(name)
796 .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedatetime-text"))?
797 .map(|val: chrono::NaiveDateTime| {
798 arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
799 }),
800 (column_type, decode_type) => {
801 return Err(SqlServerDecodeError::Unsupported {
802 sql_server_type: decode_type.clone(),
803 mz_type: column_type.clone(),
804 });
805 }
806 };
807
808 match (maybe_datum, column.nullable) {
809 (Some(datum), _) => Ok(datum),
810 (None, true) => Ok(Datum::Null),
811 (None, false) => Err(SqlServerDecodeError::InvalidData {
812 column_name: name.to_string(),
813 error: "found Null in non-nullable column".to_string(),
815 }),
816 }
817 }
818}
819
820impl RustType<proto_sql_server_column_desc::DecodeType> for SqlServerColumnDecodeType {
821 fn into_proto(&self) -> proto_sql_server_column_desc::DecodeType {
822 match self {
823 SqlServerColumnDecodeType::Bool => proto_sql_server_column_desc::DecodeType::Bool(()),
824 SqlServerColumnDecodeType::U8 => proto_sql_server_column_desc::DecodeType::U8(()),
825 SqlServerColumnDecodeType::I16 => proto_sql_server_column_desc::DecodeType::I16(()),
826 SqlServerColumnDecodeType::I32 => proto_sql_server_column_desc::DecodeType::I32(()),
827 SqlServerColumnDecodeType::I64 => proto_sql_server_column_desc::DecodeType::I64(()),
828 SqlServerColumnDecodeType::F32 => proto_sql_server_column_desc::DecodeType::F32(()),
829 SqlServerColumnDecodeType::F64 => proto_sql_server_column_desc::DecodeType::F64(()),
830 SqlServerColumnDecodeType::String => {
831 proto_sql_server_column_desc::DecodeType::String(())
832 }
833 SqlServerColumnDecodeType::Bytes => proto_sql_server_column_desc::DecodeType::Bytes(()),
834 SqlServerColumnDecodeType::Uuid => proto_sql_server_column_desc::DecodeType::Uuid(()),
835 SqlServerColumnDecodeType::Numeric => {
836 proto_sql_server_column_desc::DecodeType::Numeric(())
837 }
838 SqlServerColumnDecodeType::Xml => proto_sql_server_column_desc::DecodeType::Xml(()),
839 SqlServerColumnDecodeType::NaiveDate => {
840 proto_sql_server_column_desc::DecodeType::NaiveDate(())
841 }
842 SqlServerColumnDecodeType::NaiveTime => {
843 proto_sql_server_column_desc::DecodeType::NaiveTime(())
844 }
845 SqlServerColumnDecodeType::DateTime => {
846 proto_sql_server_column_desc::DecodeType::DateTime(())
847 }
848 SqlServerColumnDecodeType::NaiveDateTime => {
849 proto_sql_server_column_desc::DecodeType::NaiveDateTime(())
850 }
851 SqlServerColumnDecodeType::Unsupported { context } => {
852 proto_sql_server_column_desc::DecodeType::Unsupported(context.clone())
853 }
854 }
855 }
856
857 fn from_proto(
858 proto: proto_sql_server_column_desc::DecodeType,
859 ) -> Result<Self, mz_proto::TryFromProtoError> {
860 let val = match proto {
861 proto_sql_server_column_desc::DecodeType::Bool(()) => SqlServerColumnDecodeType::Bool,
862 proto_sql_server_column_desc::DecodeType::U8(()) => SqlServerColumnDecodeType::U8,
863 proto_sql_server_column_desc::DecodeType::I16(()) => SqlServerColumnDecodeType::I16,
864 proto_sql_server_column_desc::DecodeType::I32(()) => SqlServerColumnDecodeType::I32,
865 proto_sql_server_column_desc::DecodeType::I64(()) => SqlServerColumnDecodeType::I64,
866 proto_sql_server_column_desc::DecodeType::F32(()) => SqlServerColumnDecodeType::F32,
867 proto_sql_server_column_desc::DecodeType::F64(()) => SqlServerColumnDecodeType::F64,
868 proto_sql_server_column_desc::DecodeType::String(()) => {
869 SqlServerColumnDecodeType::String
870 }
871 proto_sql_server_column_desc::DecodeType::Bytes(()) => SqlServerColumnDecodeType::Bytes,
872 proto_sql_server_column_desc::DecodeType::Uuid(()) => SqlServerColumnDecodeType::Uuid,
873 proto_sql_server_column_desc::DecodeType::Numeric(()) => {
874 SqlServerColumnDecodeType::Numeric
875 }
876 proto_sql_server_column_desc::DecodeType::Xml(()) => SqlServerColumnDecodeType::Xml,
877 proto_sql_server_column_desc::DecodeType::NaiveDate(()) => {
878 SqlServerColumnDecodeType::NaiveDate
879 }
880 proto_sql_server_column_desc::DecodeType::NaiveTime(()) => {
881 SqlServerColumnDecodeType::NaiveTime
882 }
883 proto_sql_server_column_desc::DecodeType::DateTime(()) => {
884 SqlServerColumnDecodeType::DateTime
885 }
886 proto_sql_server_column_desc::DecodeType::NaiveDateTime(()) => {
887 SqlServerColumnDecodeType::NaiveDateTime
888 }
889 proto_sql_server_column_desc::DecodeType::Unsupported(context) => {
890 SqlServerColumnDecodeType::Unsupported { context }
891 }
892 };
893 Ok(val)
894 }
895}
896
897fn tiberius_numeric_to_mz_numeric(val: tiberius::numeric::Numeric) -> Numeric {
900 let mut numeric = mz_repr::adt::numeric::cx_datum().from_i128(val.value());
901 mz_repr::adt::numeric::cx_datum().scaleb(&mut numeric, &Numeric::from(-i32::from(val.scale())));
904 numeric
905}
906
907#[derive(Debug)]
912pub struct SqlServerRowDecoder {
913 decoders: Vec<(Arc<str>, SqlColumnType, SqlServerColumnDecodeType)>,
914}
915
916impl SqlServerRowDecoder {
917 pub fn try_new(
921 table: &SqlServerTableDesc,
922 desc: &RelationDesc,
923 ) -> Result<Self, SqlServerError> {
924 let decoders = desc
925 .iter()
926 .map(|(col_name, col_type)| {
927 let sql_server_col = table
928 .columns
929 .iter()
930 .find(|col| col.name.as_ref() == col_name.as_str())
931 .ok_or_else(|| {
932 anyhow::anyhow!("no SQL Server column with name {col_name} found")
934 })?;
935 let Some(sql_server_col_typ) = sql_server_col.column_type.as_ref() else {
936 return Err(SqlServerError::ProgrammingError(format!(
937 "programming error, {col_name} should have been exluded",
938 )));
939 };
940
941 let matches = match (&sql_server_col_typ.scalar_type, &col_type.scalar_type) {
949 (SqlScalarType::Timestamp { .. }, SqlScalarType::Timestamp { .. })
950 | (SqlScalarType::TimestampTz { .. }, SqlScalarType::TimestampTz { .. }) => {
951 sql_server_col_typ.nullable == col_type.nullable
953 }
954 (_, _) => sql_server_col_typ == col_type,
955 };
956 if !matches {
957 return Err(SqlServerError::ProgrammingError(format!(
958 "programming error, {col_name} has mismatched type {:?} vs {:?}",
959 sql_server_col.column_type, col_type
960 )));
961 }
962
963 let name = Arc::clone(&sql_server_col.name);
964 let decoder = sql_server_col.decode_type.clone();
965 let col_typ = sql_server_col_typ.clone();
970
971 Ok::<_, SqlServerError>((name, col_typ, decoder))
972 })
973 .collect::<Result<_, _>>()?;
974
975 Ok(SqlServerRowDecoder { decoders })
976 }
977
978 pub fn decode(
980 &self,
981 data: &tiberius::Row,
982 row: &mut Row,
983 arena: &RowArena,
984 ) -> Result<(), SqlServerDecodeError> {
985 let mut packer = row.packer();
986 for (col_name, col_type, decoder) in &self.decoders {
987 let datum = decoder.decode(data, col_name, col_type, arena)?;
988 packer.push(datum);
989 }
990 Ok(())
991 }
992}
993
994#[cfg(test)]
995mod tests {
996 use chrono::NaiveDateTime;
997 use std::collections::BTreeSet;
998 use std::sync::Arc;
999
1000 use crate::desc::{
1001 SqlServerCaptureInstanceRaw, SqlServerColumnDecodeType, SqlServerColumnDesc,
1002 SqlServerTableDesc, SqlServerTableRaw, tiberius_numeric_to_mz_numeric,
1003 };
1004
1005 use super::SqlServerColumnRaw;
1006 use mz_ore::assert_contains;
1007 use mz_ore::collections::CollectionExt;
1008 use mz_repr::adt::numeric::NumericMaxScale;
1009 use mz_repr::adt::varchar::VarCharMaxLength;
1010 use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlScalarType};
1011 use tiberius::RowTestExt;
1012
1013 impl SqlServerColumnRaw {
1014 fn new(name: &str, data_type: &str) -> Self {
1017 SqlServerColumnRaw {
1018 name: name.into(),
1019 data_type: data_type.into(),
1020 is_nullable: false,
1021 primary_key_constraint: None,
1022 max_length: 0,
1023 precision: 0,
1024 scale: 0,
1025 }
1026 }
1027
1028 fn nullable(mut self, nullable: bool) -> Self {
1029 self.is_nullable = nullable;
1030 self
1031 }
1032
1033 fn max_length(mut self, max_length: i16) -> Self {
1034 self.max_length = max_length;
1035 self
1036 }
1037
1038 fn precision(mut self, precision: u8) -> Self {
1039 self.precision = precision;
1040 self
1041 }
1042
1043 fn scale(mut self, scale: u8) -> Self {
1044 self.scale = scale;
1045 self
1046 }
1047 }
1048
1049 #[mz_ore::test]
1050 fn smoketest_column_raw() {
1051 let raw = SqlServerColumnRaw::new("foo", "bit");
1052 let col = SqlServerColumnDesc::new(&raw);
1053
1054 assert_eq!(&*col.name, "foo");
1055 assert_eq!(col.column_type, Some(SqlScalarType::Bool.nullable(false)));
1056 assert_eq!(col.decode_type, SqlServerColumnDecodeType::Bool);
1057
1058 let raw = SqlServerColumnRaw::new("foo", "decimal")
1059 .precision(20)
1060 .scale(10);
1061 let col = SqlServerColumnDesc::new(&raw);
1062
1063 let col_type = SqlScalarType::Numeric {
1064 max_scale: Some(NumericMaxScale::try_from(10i64).expect("known valid")),
1065 }
1066 .nullable(false);
1067 assert_eq!(col.column_type, Some(col_type));
1068 assert_eq!(col.decode_type, SqlServerColumnDecodeType::Numeric);
1069 }
1070
1071 #[mz_ore::test]
1072 fn smoketest_column_raw_invalid() {
1073 let raw = SqlServerColumnRaw::new("foo", "bad_data_type");
1074 let desc = SqlServerColumnDesc::new(&raw);
1075 let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1076 panic!("unexpected decode type {desc:?}");
1077 };
1078 assert_contains!(context, "'bad_data_type' is unimplemented");
1079
1080 let raw = SqlServerColumnRaw::new("foo", "decimal")
1081 .precision(100)
1082 .scale(10);
1083 let desc = SqlServerColumnDesc::new(&raw);
1084 assert!(!desc.is_supported());
1085
1086 let raw = SqlServerColumnRaw::new("foo", "varchar").max_length(-1);
1087 let desc = SqlServerColumnDesc::new(&raw);
1088 let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1089 panic!("unexpected decode type {desc:?}");
1090 };
1091 assert_contains!(context, "columns with unlimited size do not support CDC");
1092 }
1093
1094 #[mz_ore::test]
1095 fn smoketest_decoder() {
1096 let sql_server_columns = [
1097 SqlServerColumnRaw::new("a", "varchar").max_length(16),
1098 SqlServerColumnRaw::new("b", "int").nullable(true),
1099 SqlServerColumnRaw::new("c", "bit"),
1100 ];
1101 let sql_server_desc = SqlServerTableRaw {
1102 schema_name: "my_schema".into(),
1103 name: "my_table".into(),
1104 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1105 name: "my_table_CT".into(),
1106 create_date: NaiveDateTime::parse_from_str(
1107 "2024-01-01 00:00:00",
1108 "%Y-%m-%d %H:%M:%S",
1109 )
1110 .unwrap()
1111 .into(),
1112 }),
1113 columns: sql_server_columns.into(),
1114 };
1115 let sql_server_desc = SqlServerTableDesc::new(sql_server_desc);
1116
1117 let max_length = Some(VarCharMaxLength::try_from(16).unwrap());
1118 let relation_desc = RelationDesc::builder()
1119 .with_column("a", SqlScalarType::VarChar { max_length }.nullable(false))
1120 .with_column("c", SqlScalarType::Bool.nullable(false))
1122 .with_column("b", SqlScalarType::Int32.nullable(true))
1123 .finish();
1124
1125 let decoder = sql_server_desc
1127 .decoder(&relation_desc)
1128 .expect("known valid");
1129
1130 let sql_server_columns = [
1131 tiberius::Column::new("a".to_string(), tiberius::ColumnType::BigVarChar),
1132 tiberius::Column::new("b".to_string(), tiberius::ColumnType::Int4),
1133 tiberius::Column::new("c".to_string(), tiberius::ColumnType::Bit),
1134 ];
1135
1136 let data_a = [
1137 tiberius::ColumnData::String(Some("hello world".into())),
1138 tiberius::ColumnData::I32(Some(42)),
1139 tiberius::ColumnData::Bit(Some(true)),
1140 ];
1141 let sql_server_row_a =
1142 tiberius::Row::build(sql_server_columns.iter().cloned().zip(data_a.into_iter()));
1143
1144 let data_b = [
1145 tiberius::ColumnData::String(Some("foo bar".into())),
1146 tiberius::ColumnData::I32(None),
1147 tiberius::ColumnData::Bit(Some(false)),
1148 ];
1149 let sql_server_row_b =
1150 tiberius::Row::build(sql_server_columns.into_iter().zip(data_b.into_iter()));
1151
1152 let mut rnd_row = Row::default();
1153 let arena = RowArena::default();
1154
1155 decoder
1156 .decode(&sql_server_row_a, &mut rnd_row, &arena)
1157 .unwrap();
1158 assert_eq!(
1159 &rnd_row,
1160 &Row::pack_slice(&[Datum::String("hello world"), Datum::True, Datum::Int32(42)])
1161 );
1162
1163 decoder
1164 .decode(&sql_server_row_b, &mut rnd_row, &arena)
1165 .unwrap();
1166 assert_eq!(
1167 &rnd_row,
1168 &Row::pack_slice(&[Datum::String("foo bar"), Datum::False, Datum::Null])
1169 );
1170 }
1171
1172 #[mz_ore::test]
1173 fn smoketest_decode_to_string() {
1174 #[track_caller]
1175 fn testcase(
1176 data_type: &'static str,
1177 col_type: tiberius::ColumnType,
1178 col_data: tiberius::ColumnData<'static>,
1179 ) {
1180 let columns = [SqlServerColumnRaw::new("a", data_type)];
1181 let sql_server_desc = SqlServerTableRaw {
1182 schema_name: "my_schema".into(),
1183 name: "my_table".into(),
1184 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1185 name: "my_table_CT".into(),
1186 create_date: NaiveDateTime::parse_from_str(
1187 "2024-01-01 00:00:00",
1188 "%Y-%m-%d %H:%M:%S",
1189 )
1190 .unwrap()
1191 .into(),
1192 }),
1193 columns: columns.into(),
1194 };
1195 let mut sql_server_desc = SqlServerTableDesc::new(sql_server_desc);
1196 sql_server_desc.apply_text_columns(&BTreeSet::from(["a"]));
1197
1198 let relation_desc = RelationDesc::builder()
1200 .with_column("a", SqlScalarType::String.nullable(false))
1201 .finish();
1202
1203 let decoder = sql_server_desc
1205 .decoder(&relation_desc)
1206 .expect("known valid");
1207
1208 let sql_server_row = tiberius::Row::build([(
1209 tiberius::Column::new("a".to_string(), col_type),
1210 col_data,
1211 )]);
1212 let mut mz_row = Row::default();
1213 let arena = RowArena::new();
1214 decoder
1215 .decode(&sql_server_row, &mut mz_row, &arena)
1216 .unwrap();
1217
1218 let str_datum = mz_row.into_element();
1219 assert!(matches!(str_datum, Datum::String(_)));
1220 }
1221
1222 use tiberius::{ColumnData, ColumnType};
1223
1224 testcase("bit", ColumnType::Bit, ColumnData::Bit(Some(true)));
1225 testcase("bit", ColumnType::Bit, ColumnData::Bit(Some(false)));
1226 testcase("tinyint", ColumnType::Int1, ColumnData::U8(Some(33)));
1227 testcase("smallint", ColumnType::Int2, ColumnData::I16(Some(101)));
1228 testcase("int", ColumnType::Int4, ColumnData::I32(Some(-42)));
1229 {
1230 let datetime = tiberius::time::DateTime::new(10, 300);
1231 testcase(
1232 "datetime",
1233 ColumnType::Datetime,
1234 ColumnData::DateTime(Some(datetime)),
1235 );
1236 }
1237 }
1238
1239 #[mz_ore::test]
1240 #[cfg_attr(miri, ignore)] fn smoketest_numeric_conversion() {
1242 let a = tiberius::numeric::Numeric::new_with_scale(12345, 2);
1243 let rnd = tiberius_numeric_to_mz_numeric(a);
1244 let og = mz_repr::adt::numeric::cx_datum().parse("123.45").unwrap();
1245 assert_eq!(og, rnd);
1246
1247 let a = tiberius::numeric::Numeric::new_with_scale(-99999, 5);
1248 let rnd = tiberius_numeric_to_mz_numeric(a);
1249 let og = mz_repr::adt::numeric::cx_datum().parse("-.99999").unwrap();
1250 assert_eq!(og, rnd);
1251
1252 let a = tiberius::numeric::Numeric::new_with_scale(1, 29);
1253 let rnd = tiberius_numeric_to_mz_numeric(a);
1254 let og = mz_repr::adt::numeric::cx_datum()
1255 .parse("0.00000000000000000000000000001")
1256 .unwrap();
1257 assert_eq!(og, rnd);
1258
1259 let a = tiberius::numeric::Numeric::new_with_scale(-111111111111111111, 0);
1260 let rnd = tiberius_numeric_to_mz_numeric(a);
1261 let og = mz_repr::adt::numeric::cx_datum()
1262 .parse("-111111111111111111")
1263 .unwrap();
1264 assert_eq!(og, rnd);
1265 }
1266
1267 }