Skip to main content

mz_sql_server_util/
desc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Metadata about tables, columns, and other objects from SQL Server.
11//!
12//! ### Tables
13//!
14//! When creating a SQL Server source we will query system tables from the
15//! upstream instance to get a [`SqlServerTableRaw`]. From this raw information
16//! we create a [`SqlServerTableDesc`] which describes how the upstream table
17//! will get represented in Materialize.
18//!
19//! ### Rows
20//!
21//! With a [`SqlServerTableDesc`] and an [`mz_repr::RelationDesc`] we can
22//! create a [`SqlServerRowDecoder`] which will be used when running a source
23//! to efficiently decode [`tiberius::Row`]s into [`mz_repr::Row`]s.
24
25use base64::Engine;
26use chrono::{NaiveDateTime, SubsecRound};
27use dec::OrderedDecimal;
28use mz_ore::cast::CastFrom;
29use mz_proto::{IntoRustIfSome, ProtoType, RustType};
30use mz_repr::adt::char::CharLength;
31use mz_repr::adt::numeric::{Numeric, NumericMaxScale};
32use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampPrecision};
33use mz_repr::adt::varchar::VarCharMaxLength;
34use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlColumnType, SqlScalarType};
35use proptest_derive::Arbitrary;
36use serde::{Deserialize, Serialize};
37
38use std::collections::BTreeSet;
39use std::sync::Arc;
40
41use crate::desc::proto_sql_server_table_constraint::ConstraintType;
42use crate::{SqlServerDecodeError, SqlServerError};
43
44include!(concat!(env!("OUT_DIR"), "/mz_sql_server_util.rs"));
45
/// Materialize compatible description of a table in Microsoft SQL Server.
///
/// See [`SqlServerTableRaw`] for the raw information we read from the upstream
/// system.
///
/// Note: We map a [`SqlServerTableDesc`] to a Materialize [`RelationDesc`] as
/// part of purification. Specifically we use this description to generate a
/// SQL statement for the subsource, and it's the _parsing of that statement_
/// which actually generates a [`RelationDesc`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerTableDesc {
    /// Name of the schema that the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// Columns for the table.
    ///
    /// A column whose `column_type` is `None` is excluded from replication,
    /// see [`SqlServerColumnDesc::is_excluded`].
    pub columns: Box<[SqlServerColumnDesc]>,
    /// Constraints for the table.
    pub constraints: Vec<SqlServerTableConstraint>,
}
66
67impl SqlServerTableDesc {
68    /// Creating a [`SqlServerTableDesc`] from a [`SqlServerTableRaw`] description.
69    ///
70    /// Note: Not all columns from SQL Server can be ingested into Materialize. To determine if a
71    /// column is supported see [`SqlServerColumnDesc::decode_type`].
72    pub fn new(
73        raw: SqlServerTableRaw,
74        raw_constraints: Vec<SqlServerTableConstraintRaw>,
75    ) -> Result<Self, SqlServerError> {
76        let columns: Box<[_]> = raw
77            .columns
78            .into_iter()
79            .map(SqlServerColumnDesc::new)
80            .collect();
81        let constraints = raw_constraints
82            .into_iter()
83            .map(SqlServerTableConstraint::try_from)
84            .collect::<Result<Vec<_>, _>>()?;
85        Ok(SqlServerTableDesc {
86            schema_name: raw.schema_name,
87            name: raw.name,
88            columns,
89            constraints,
90        })
91    }
92
93    /// Returns the [`SqlServerQualifiedTableName`] for this [`SqlServerTableDesc`].
94    pub fn qualified_name(&self) -> SqlServerQualifiedTableName {
95        SqlServerQualifiedTableName {
96            schema_name: Arc::clone(&self.schema_name),
97            table_name: Arc::clone(&self.name),
98        }
99    }
100
101    /// Update this [`SqlServerTableDesc`] to represent the specified columns
102    /// as text in Materialize.
103    pub fn apply_text_columns(&mut self, text_columns: &BTreeSet<&str>) {
104        for column in &mut self.columns {
105            if text_columns.contains(column.name.as_ref()) {
106                column.represent_as_text();
107            }
108        }
109    }
110
111    /// Update this [`SqlServerTableDesc`] to exclude the specified columns from being
112    /// replicated into Materialize.
113    pub fn apply_excl_columns(&mut self, excl_columns: &BTreeSet<&str>) {
114        for column in &mut self.columns {
115            if excl_columns.contains(column.name.as_ref()) {
116                column.exclude();
117            }
118        }
119    }
120
121    /// Returns a [`SqlServerRowDecoder`] which can be used to decode [`tiberius::Row`]s into
122    /// [`mz_repr::Row`]s that match the shape of the provided [`RelationDesc`].
123    pub fn decoder(&self, desc: &RelationDesc) -> Result<SqlServerRowDecoder, SqlServerError> {
124        let decoder = SqlServerRowDecoder::try_new(self, desc)?;
125        Ok(decoder)
126    }
127}
128
129impl RustType<ProtoSqlServerTableDesc> for SqlServerTableDesc {
130    fn into_proto(&self) -> ProtoSqlServerTableDesc {
131        ProtoSqlServerTableDesc {
132            name: self.name.to_string(),
133            schema_name: self.schema_name.to_string(),
134            columns: self.columns.iter().map(|c| c.into_proto()).collect(),
135            constraints: self.constraints.iter().map(|c| c.into_proto()).collect(),
136        }
137    }
138
139    fn from_proto(proto: ProtoSqlServerTableDesc) -> Result<Self, mz_proto::TryFromProtoError> {
140        let columns = proto
141            .columns
142            .into_iter()
143            .map(|c| c.into_rust())
144            .collect::<Result<_, _>>()?;
145        let constraints = proto
146            .constraints
147            .into_iter()
148            .map(|c| c.into_rust())
149            .collect::<Result<_, _>>()?;
150        Ok(SqlServerTableDesc {
151            schema_name: proto.schema_name.into(),
152            name: proto.name.into(),
153            columns,
154            constraints,
155        })
156    }
157}
158
/// SQL Server table constraint type (e.g. PRIMARY KEY, UNIQUE, etc.)
///
/// Only the constraint types listed here are representable; any other type
/// name is rejected when converting from the upstream string.
///
/// See <https://learn.microsoft.com/en-us/sql/relational-databases/system-information-schema-views/table-constraints-transact-sql?view=sql-server-ver17>
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Arbitrary)]
pub enum SqlServerTableConstraintType {
    /// A `PRIMARY KEY` constraint.
    PrimaryKey,
    /// A `UNIQUE` constraint.
    Unique,
}
166
167impl TryFrom<String> for SqlServerTableConstraintType {
168    type Error = SqlServerError;
169
170    fn try_from(value: String) -> Result<Self, Self::Error> {
171        match value.as_str() {
172            "PRIMARY KEY" => Ok(Self::PrimaryKey),
173            "UNIQUE" => Ok(Self::Unique),
174            name => Err(SqlServerError::InvalidData {
175                column_name: "constraint_type".into(),
176                error: format!("Unknown constraint type: {name}"),
177            }),
178        }
179    }
180}
181
182impl RustType<proto_sql_server_table_constraint::ConstraintType> for SqlServerTableConstraintType {
183    fn into_proto(&self) -> proto_sql_server_table_constraint::ConstraintType {
184        match self {
185            SqlServerTableConstraintType::PrimaryKey => ConstraintType::PrimaryKey(()),
186            SqlServerTableConstraintType::Unique => ConstraintType::Unique(()),
187        }
188    }
189
190    fn from_proto(
191        proto: proto_sql_server_table_constraint::ConstraintType,
192    ) -> Result<Self, mz_proto::TryFromProtoError> {
193        Ok(match proto {
194            ConstraintType::PrimaryKey(_) => SqlServerTableConstraintType::PrimaryKey,
195            ConstraintType::Unique(_) => SqlServerTableConstraintType::Unique,
196        })
197    }
198}
199
/// SQL Server table constraint.
///
/// See [`SqlServerTableConstraintRaw`] for the raw information this is built from.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Arbitrary)]
pub struct SqlServerTableConstraint {
    /// Name of the constraint.
    pub constraint_name: String,
    /// Kind of constraint, e.g. `PRIMARY KEY` or `UNIQUE`.
    pub constraint_type: SqlServerTableConstraintType,
    /// Names of the columns that the constraint covers.
    pub column_names: Vec<String>,
}
207
208impl TryFrom<SqlServerTableConstraintRaw> for SqlServerTableConstraint {
209    type Error = SqlServerError;
210
211    fn try_from(value: SqlServerTableConstraintRaw) -> Result<Self, Self::Error> {
212        Ok(SqlServerTableConstraint {
213            constraint_name: value.constraint_name,
214            constraint_type: value.constraint_type.try_into()?,
215            column_names: value.columns,
216        })
217    }
218}
219
220impl RustType<ProtoSqlServerTableConstraint> for SqlServerTableConstraint {
221    fn into_proto(&self) -> ProtoSqlServerTableConstraint {
222        ProtoSqlServerTableConstraint {
223            constraint_name: self.constraint_name.clone(),
224            constraint_type: Some(self.constraint_type.into_proto()),
225            column_names: self.column_names.clone(),
226        }
227    }
228
229    fn from_proto(
230        proto: ProtoSqlServerTableConstraint,
231    ) -> Result<Self, mz_proto::TryFromProtoError> {
232        Ok(SqlServerTableConstraint {
233            constraint_name: proto.constraint_name,
234            constraint_type: proto
235                .constraint_type
236                .into_rust_if_some("ProtoSqlServerTableConstraint::constraint_type")?,
237            column_names: proto.column_names,
238        })
239    }
240}
241
/// Partially qualified name of a table from Microsoft SQL Server.
///
/// The name is made up of the schema and the table name only.
///
/// TODO(sql_server3): Change this to use a &str.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SqlServerQualifiedTableName {
    /// Name of the schema that the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub table_name: Arc<str>,
}
250
251impl ToString for SqlServerQualifiedTableName {
252    fn to_string(&self) -> String {
253        format!(
254            "{}.{}",
255            crate::quote_identifier(&self.schema_name),
256            crate::quote_identifier(&self.table_name)
257        )
258    }
259}
260
/// Raw metadata for a table from Microsoft SQL Server.
///
/// See [`SqlServerTableDesc`] for a refined description that is compatible
/// with Materialize.
#[derive(Debug, Clone)]
pub struct SqlServerTableRaw {
    /// Name of the schema the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// The capture instance replicating changes.
    pub capture_instance: Arc<SqlServerCaptureInstanceRaw>,
    /// Columns for the table, as reported by the upstream system.
    pub columns: Arc<[SqlServerColumnRaw]>,
}
276
/// Raw capture instance metadata.
///
/// Referenced from [`SqlServerTableRaw::capture_instance`].
#[derive(Debug, Clone)]
pub struct SqlServerCaptureInstanceRaw {
    /// Name of the capture instance replicating changes.
    pub name: Arc<str>,
    /// The creation date of the capture instance.
    pub create_date: Arc<NaiveDateTime>,
}
285
/// Description of a column from a table in Microsoft SQL Server.
///
/// See [`SqlServerColumnRaw`] for the raw information this is built from.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerColumnDesc {
    /// Name of the column.
    pub name: Arc<str>,
    /// The intended data type of this column in Materialize. `None` indicates this
    /// column should be excluded when replicating into Materialize.
    ///
    /// Note: This type might differ from the `decode_type`, e.g. a user can
    /// specify `TEXT COLUMNS` to decode columns as text.
    pub column_type: Option<SqlColumnType>,
    /// This field is deprecated and will be removed in a future version. This exists only for
    /// the purpose of migrating from old representations.
    pub primary_key_constraint: Option<Arc<str>>,
    /// Rust type we should parse the data from a [`tiberius::Row`] as.
    pub decode_type: SqlServerColumnDecodeType,
    /// Raw type of the column as we read it from upstream.
    ///
    /// This is useful to keep around for debugging purposes.
    pub raw_type: Arc<str>,
}
307
308impl SqlServerColumnDesc {
309    /// Create a [`SqlServerColumnDesc`] from a [`SqlServerColumnRaw`] description.
310    pub fn new(raw: &SqlServerColumnRaw) -> Self {
311        let (column_type, decode_type) = match parse_data_type(raw) {
312            Ok((scalar_type, decode_type)) => {
313                let column_type = scalar_type.nullable(raw.is_nullable);
314                (Some(column_type), decode_type)
315            }
316            Err(err) => {
317                tracing::warn!(
318                    ?err,
319                    ?raw,
320                    "found an unsupported data type when parsing raw data"
321                );
322                (
323                    None,
324                    SqlServerColumnDecodeType::Unsupported {
325                        context: err.reason,
326                    },
327                )
328            }
329        };
330        SqlServerColumnDesc {
331            name: Arc::clone(&raw.name),
332            primary_key_constraint: None,
333            column_type,
334            decode_type,
335            raw_type: Arc::clone(&raw.data_type),
336        }
337    }
338
339    /// Change this [`SqlServerColumnDesc`] to be represented as text in Materialize.
340    pub fn represent_as_text(&mut self) {
341        self.column_type = self
342            .column_type
343            .as_ref()
344            .map(|ct| SqlScalarType::String.nullable(ct.nullable));
345    }
346
347    /// Exclude this [`SqlServerColumnDesc`] from being replicated into Materialize.
348    pub fn exclude(&mut self) {
349        self.column_type = None;
350    }
351
352    /// Check if this [`SqlServerColumnDesc`] is excluded from being replicated into Materialize.
353    pub fn is_excluded(&self) -> bool {
354        self.column_type.is_none()
355    }
356}
357
358impl RustType<ProtoSqlServerColumnDesc> for SqlServerColumnDesc {
359    fn into_proto(&self) -> ProtoSqlServerColumnDesc {
360        ProtoSqlServerColumnDesc {
361            name: self.name.to_string(),
362            column_type: self.column_type.into_proto(),
363            primary_key_constraint: self.primary_key_constraint.as_ref().map(|v| v.to_string()),
364            decode_type: Some(self.decode_type.into_proto()),
365            raw_type: self.raw_type.to_string(),
366        }
367    }
368
369    fn from_proto(proto: ProtoSqlServerColumnDesc) -> Result<Self, mz_proto::TryFromProtoError> {
370        Ok(SqlServerColumnDesc {
371            name: proto.name.into(),
372            column_type: proto.column_type.into_rust()?,
373            primary_key_constraint: proto.primary_key_constraint.map(|v| v.into()),
374            decode_type: proto
375                .decode_type
376                .into_rust_if_some("ProtoSqlServerColumnDesc::decode_type")?,
377            raw_type: proto.raw_type.into(),
378        })
379    }
380}
381
/// The raw datatype from SQL Server is not supported in Materialize.
///
/// Produced by [`parse_data_type`] when a column cannot be mapped to a
/// Materialize type.
#[derive(Debug)]
// NOTE(review): presumably needed because `column_name` and `column_type` are
// only ever surfaced through the `Debug` output — confirm before removing.
#[allow(dead_code)]
pub struct UnsupportedDataType {
    /// Name of the offending column.
    column_name: String,
    /// Name of the column's data type.
    column_type: String,
    /// Human readable explanation of why the column is not supported.
    reason: String,
}
390
391/// Parse a raw data type from SQL Server into a Materialize [`SqlScalarType`].
392///
393/// Returns the [`SqlScalarType`] that we'll map this column to and the [`SqlServerColumnDecodeType`]
394/// that we use to decode the raw value.
395fn parse_data_type(
396    raw: &SqlServerColumnRaw,
397) -> Result<(SqlScalarType, SqlServerColumnDecodeType), UnsupportedDataType> {
398    // The value of a computed column, persisted or not, will be readable by the snapshot, but will
399    // always be NULL in the CDC stream.  This can lead to issues in MZ (e.g. decoding errors,
400    // negative accumulations, etc.).
401    if raw.is_computed {
402        return Err(UnsupportedDataType {
403            column_name: raw.name.to_string(),
404            column_type: format!("{} (computed)", raw.data_type.to_lowercase()),
405            reason: "column is computed".into(),
406        });
407    }
408
409    let scalar =
410        match raw.data_type.to_lowercase().as_str() {
411            "tinyint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::U8),
412            "smallint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::I16),
413            "int" => (SqlScalarType::Int32, SqlServerColumnDecodeType::I32),
414            "bigint" => (SqlScalarType::Int64, SqlServerColumnDecodeType::I64),
415            "bit" => (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool),
416            "decimal" | "numeric" | "money" | "smallmoney" => {
417                // SQL Server supports a precision in the range of [1, 38] and then
418                // the scale is 0 <= scale <= precision.
419                //
420                // Materialize numerics are floating point with a fixed precision of 39.
421                //
422                // See: <https://learn.microsoft.com/en-us/sql/t-sql/data-types/decimal-and-numeric-transact-sql?view=sql-server-ver16#arguments>
423                if raw.precision > 38 || raw.scale > raw.precision {
424                    tracing::warn!(
425                        "unexpected value from SQL Server, precision of {} and scale of {}",
426                        raw.precision,
427                        raw.scale,
428                    );
429                }
430                if raw.precision > 39 {
431                    let reason = format!(
432                        "precision of {} is greater than our maximum of 39",
433                        raw.precision
434                    );
435                    return Err(UnsupportedDataType {
436                        column_name: raw.name.to_string(),
437                        column_type: raw.data_type.to_string(),
438                        reason,
439                    });
440                }
441
442                let raw_scale = usize::cast_from(raw.scale);
443                let max_scale =
444                    NumericMaxScale::try_from(raw_scale).map_err(|_| UnsupportedDataType {
445                        column_type: raw.data_type.to_string(),
446                        column_name: raw.name.to_string(),
447                        reason: format!("scale of {} is too large", raw.scale),
448                    })?;
449                let column_type = SqlScalarType::Numeric {
450                    max_scale: Some(max_scale),
451                };
452
453                (column_type, SqlServerColumnDecodeType::Numeric)
454            }
455            // SQL Server has a few IEEE 754 floating point type names. The underlying type is float(n),
456            // where n is the number of bits used. SQL Server still ends up with only 2 distinct types
457            // as it treats 1 <= n <= 24 as n=24, and 25 <= n <= 53 as n=53.
458            //
459            // Additionally, `real` and `double precision` exist as synonyms of float(24) and float(53),
460            // respectively.  What doesn't appear to be documented is how these appear in `sys.types`.
461            // See <https://learn.microsoft.com/en-us/sql/t-sql/data-types/float-and-real-transact-sql?view=sql-server-ver17>
462            "real" | "float" | "double precision" => match raw.max_length {
463                // Decide the MZ type based on the number of bytes rather than the name, just in case
464                // there is inconsistency among versions.
465                4 => (SqlScalarType::Float32, SqlServerColumnDecodeType::F32),
466                8 => (SqlScalarType::Float64, SqlServerColumnDecodeType::F64),
467                _ => {
468                    return Err(UnsupportedDataType {
469                        column_name: raw.name.to_string(),
470                        column_type: raw.data_type.to_string(),
471                        reason: format!("unsupported length {}", raw.max_length),
472                    });
473                }
474            },
475            dt @ ("char" | "nchar" | "sysname") => {
476                // There isn't a char(max) or nchar(max), so it isn't clear if this condition
477                // is possible.
478                if raw.max_length == -1 {
479                    return Err(UnsupportedDataType {
480                        column_name: raw.name.to_string(),
481                        column_type: raw.data_type.to_string(),
482                        reason: "columns with unlimited size do not support CDC".to_string(),
483                    });
484                }
485
486                let column_type = match dt {
487                    "char" => {
488                        let length =
489                            if raw.max_length != -1 {
490                                let length = CharLength::try_from(i64::from(raw.max_length))
491                                    .map_err(|e| UnsupportedDataType {
492                                        column_name: raw.name.to_string(),
493                                        column_type: raw.data_type.to_string(),
494                                        reason: e.to_string(),
495                                    })?;
496                                Some(length)
497                            } else {
498                                None
499                            };
500                        SqlScalarType::Char { length }
501                    }
502                    // Determining the max character count for these types is difficult
503                    // because of different character encodings, so we fallback to just
504                    // representing them as "text".
505                    "nchar" | "sysname" => SqlScalarType::String,
506                    other => unreachable!("'{other}' checked above"),
507                };
508
509                (column_type, SqlServerColumnDecodeType::String)
510            }
511            "varchar" | "nvarchar" => {
512                // `max text repl size` is 64KB by default.  If a user attempts to insert a value
513                // that exceeds this limit, SQL Server will return an error and the insert fails
514                // with error `7139`.  This is also true for updates that increase the field length
515                // beyond the limit.
516                //
517                // See <https://learn.microsoft.com/en-us/sql/relational-databases/errors-events/database-engine-events-and-errors-7000-to-7999?view=sql-server-ver17>
518                //
519                // If the `max text repl size` changes, it does not affect events already written to
520                // the CDC table, nor does it change the behavior of what CDC captures for updates
521                // to non-LOD columns (based on testing).
522                let max_length =
523                    if raw.max_length != -1 {
524                        let length = VarCharMaxLength::try_from(i64::from(raw.max_length))
525                            .map_err(|e| UnsupportedDataType {
526                                column_name: raw.name.to_string(),
527                                column_type: raw.data_type.to_string(),
528                                reason: e.to_string(),
529                            })?;
530                        Some(length)
531                    } else {
532                        None
533                    };
534                let column_type = SqlScalarType::VarChar { max_length };
535                (column_type, SqlServerColumnDecodeType::String)
536            }
537            "text" | "ntext" | "image" => {
538                // SQL Server docs indicate this should always be 16. There's no
539                // issue if it's not, but it's good to track.
540                mz_ore::soft_assert_eq_no_log!(raw.max_length, 16);
541
542                // TODO(sql_server3): Support UPSERT semantics for SQL Server.
543                return Err(UnsupportedDataType {
544                    column_name: raw.name.to_string(),
545                    column_type: raw.data_type.to_string(),
546                    reason: "columns with unlimited size do not support CDC".to_string(),
547                });
548            }
549            "xml" => {
550                // When the `max_length` is -1 SQL Server will not present us with the "before" value
551                // for updated columns.
552                //
553                // TODO(sql_server3): Support UPSERT semantics for SQL Server.
554                if raw.max_length == -1 {
555                    return Err(UnsupportedDataType {
556                        column_name: raw.name.to_string(),
557                        column_type: raw.data_type.to_string(),
558                        reason: "columns with unlimited size do not support CDC".to_string(),
559                    });
560                }
561                (SqlScalarType::String, SqlServerColumnDecodeType::Xml)
562            }
563            "binary" | "varbinary" => {
564                // [`SqlScalarType`] does not support tracking max_length for binary data. To ensure
565                // columns of type varbinary(max) (Large Object Data) are decoded properly, it is
566                // necessary to know that the length is `max`. `varchar` and `nvarchar` track this
567                // using [`SqlScalarType::VarChar`] max_length field.
568                if raw.max_length == -1 {
569                    return Err(UnsupportedDataType {
570                        column_name: raw.name.to_string(),
571                        column_type: raw.data_type.to_string(),
572                        reason: "columns with unlimited size do not support CDC".to_string(),
573                    });
574                }
575                (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes)
576            }
577            "json" => (SqlScalarType::Jsonb, SqlServerColumnDecodeType::String),
578            "date" => (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate),
579            // SQL Server supports a scale of (and defaults to) 7 digits (aka 100 nanoseconds)
580            // for time related types.
581            //
582            // Internally Materialize supports a scale of 9 (aka nanoseconds), but for Postgres
583            // compatibility we constraint ourselves to a scale of 6 (aka microseconds). By
584            // default we will round values we get from  SQL Server to fit in Materialize.
585            //
586            // TODO(sql_server3): Support a "strict" mode where we're fail the creation of the
587            // source if the scale is too large.
588            // TODO(sql_server3): Support specifying a precision for SqlScalarType::Time.
589            //
590            // See: <https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetime2-transact-sql?view=sql-server-ver16>.
591            "time" => (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime),
592            dt @ ("smalldatetime" | "datetime" | "datetime2" | "datetimeoffset") => {
593                if raw.scale > 7 {
594                    tracing::warn!("unexpected scale '{}' from SQL Server", raw.scale);
595                }
596                if raw.scale > mz_repr::adt::timestamp::MAX_PRECISION {
597                    tracing::warn!("truncating scale of '{}' for '{}'", raw.scale, dt);
598                }
599                let precision = std::cmp::min(raw.scale, mz_repr::adt::timestamp::MAX_PRECISION);
600                let precision =
601                    Some(TimestampPrecision::try_from(i64::from(precision)).expect("known to fit"));
602
603                match dt {
604                    "smalldatetime" | "datetime" | "datetime2" => (
605                        SqlScalarType::Timestamp { precision },
606                        SqlServerColumnDecodeType::NaiveDateTime,
607                    ),
608                    "datetimeoffset" => (
609                        SqlScalarType::TimestampTz { precision },
610                        SqlServerColumnDecodeType::DateTime,
611                    ),
612                    other => unreachable!("'{other}' checked above"),
613                }
614            }
615            "uniqueidentifier" => (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid),
616            // TODO(sql_server3): Support reading the following types, at least as text:
617            //
618            // * geography
619            // * geometry
620            // * json (preview)
621            // * vector (preview)
622            //
623            // None of these types are implemented in `tiberius`, the crate that
624            // provides our SQL Server client, so we'll need to implement support
625            // for decoding them.
626            //
627            // See <https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/355f7890-6e91-4978-ab76-2ded17ee09bc>.
628            other => {
629                return Err(UnsupportedDataType {
630                    column_type: other.to_string(),
631                    column_name: raw.name.to_string(),
632                    reason: format!("'{other}' is unimplemented"),
633                });
634            }
635        };
636    Ok(scalar)
637}
638
/// Raw metadata for a column from a table in Microsoft SQL Server.
///
/// See [`SqlServerColumnDesc`] for a refined description that is compatible
/// with Materialize.
///
/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-catalog-views/sys-columns-transact-sql?view=sql-server-ver16>.
#[derive(Clone, Debug)]
pub struct SqlServerColumnRaw {
    /// Name of this column.
    pub name: Arc<str>,
    /// Name of the data type.
    pub data_type: Arc<str>,
    /// Whether or not the column is nullable.
    pub is_nullable: bool,
    /// Maximum length (in bytes) of the column.
    ///
    /// For `varchar(max)`, `nvarchar(max)`, `varbinary(max)`, or `xml` this will be `-1`. For
    /// `text`, `ntext`, and `image` columns this will be 16.
    ///
    /// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-catalog-views/sys-columns-transact-sql?view=sql-server-ver16>.
    ///
    /// TODO(sql_server2): Validate this value for `json` columns, which were introduced in
    /// Azure SQL 2024.
    pub max_length: i16,
    /// Precision of the column, if numeric-based; otherwise 0.
    pub precision: u8,
    /// Scale of the column, if numeric-based; otherwise 0.
    pub scale: u8,
    /// Whether the column is computed.
    pub is_computed: bool,
}
667
/// Raw metadata for a table constraint.
///
/// See [`SqlServerTableConstraint`] for the refined representation.
#[derive(Clone, Debug)]
pub struct SqlServerTableConstraintRaw {
    /// Name of the constraint.
    pub constraint_name: String,
    /// Type of the constraint as a string, e.g. `PRIMARY KEY` or `UNIQUE`.
    pub constraint_type: String,
    /// Names of the columns that the constraint covers.
    pub columns: Vec<String>,
}
675
/// Rust type that we should use when reading a column from SQL Server.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub enum SqlServerColumnDecodeType {
    /// [`bool`].
    Bool,
    /// [`u8`].
    U8,
    /// [`i16`].
    I16,
    /// [`i32`].
    I32,
    /// [`i64`].
    I64,
    /// [`f32`].
    F32,
    /// [`f64`].
    F64,
    /// Text data.
    String,
    /// Binary data.
    Bytes,
    /// [`uuid::Uuid`].
    Uuid,
    /// [`tiberius::numeric::Numeric`].
    Numeric,
    /// [`tiberius::xml::XmlData`].
    Xml,
    /// [`chrono::NaiveDate`].
    NaiveDate,
    /// [`chrono::NaiveTime`].
    NaiveTime,
    /// [`chrono::DateTime`].
    DateTime,
    /// [`chrono::NaiveDateTime`].
    NaiveDateTime,
    /// Decoding this type isn't supported.
    Unsupported {
        /// Any additional context as to why this type isn't supported.
        context: String,
    },
}
708
709impl SqlServerColumnDecodeType {
710    /// Decode the column with `name` out of the provided `data`.
711    pub fn decode<'a>(
712        &self,
713        data: &'a tiberius::Row,
714        name: &'a str,
715        column: &'a SqlColumnType,
716        arena: &'a RowArena,
717    ) -> Result<Datum<'a>, SqlServerDecodeError> {
718        let maybe_datum = match (&column.scalar_type, self) {
719            (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool) => data
720                .try_get(name)
721                .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool"))?
722                .map(|val: bool| if val { Datum::True } else { Datum::False }),
723            (SqlScalarType::Int16, SqlServerColumnDecodeType::U8) => data
724                .try_get(name)
725                .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8"))?
726                .map(|val: u8| Datum::Int16(i16::cast_from(val))),
727            (SqlScalarType::Int16, SqlServerColumnDecodeType::I16) => data
728                .try_get(name)
729                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16"))?
730                .map(Datum::Int16),
731            (SqlScalarType::Int32, SqlServerColumnDecodeType::I32) => data
732                .try_get(name)
733                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32"))?
734                .map(Datum::Int32),
735            (SqlScalarType::Int64, SqlServerColumnDecodeType::I64) => data
736                .try_get(name)
737                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64"))?
738                .map(Datum::Int64),
739            (SqlScalarType::Float32, SqlServerColumnDecodeType::F32) => data
740                .try_get(name)
741                .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32"))?
742                .map(|val: f32| Datum::Float32(ordered_float::OrderedFloat(val))),
743            (SqlScalarType::Float64, SqlServerColumnDecodeType::F64) => data
744                .try_get(name)
745                .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64"))?
746                .map(|val: f64| Datum::Float64(ordered_float::OrderedFloat(val))),
747            (SqlScalarType::String, SqlServerColumnDecodeType::String) => data
748                .try_get(name)
749                .map_err(|_| SqlServerDecodeError::invalid_column(name, "string"))?
750                .map(Datum::String),
751            (SqlScalarType::Char { length }, SqlServerColumnDecodeType::String) => data
752                .try_get(name)
753                .map_err(|_| SqlServerDecodeError::invalid_column(name, "char"))?
754                .map(|val: &str| match length {
755                    Some(expected) => {
756                        let found_chars = val.chars().count();
757                        let expct_chars = usize::cast_from(expected.into_u32());
758                        if found_chars != expct_chars {
759                            Err(SqlServerDecodeError::invalid_char(
760                                name,
761                                expct_chars,
762                                found_chars,
763                            ))
764                        } else {
765                            Ok(Datum::String(val))
766                        }
767                    }
768                    None => Ok(Datum::String(val)),
769                })
770                .transpose()?,
771            (SqlScalarType::VarChar { max_length }, SqlServerColumnDecodeType::String) => data
772                .try_get(name)
773                .map_err(|_| SqlServerDecodeError::invalid_column(name, "varchar"))?
774                .map(|val: &str| match max_length {
775                    Some(max) => {
776                        let found_chars = val.chars().count();
777                        let max_chars = usize::cast_from(max.into_u32());
778                        if found_chars > max_chars {
779                            Err(SqlServerDecodeError::invalid_varchar(
780                                name,
781                                max_chars,
782                                found_chars,
783                            ))
784                        } else {
785                            Ok(Datum::String(val))
786                        }
787                    }
788                    None => Ok(Datum::String(val)),
789                })
790                .transpose()?,
791            (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes) => data
792                .try_get(name)
793                .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes"))?
794                .map(Datum::Bytes),
795            (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid) => data
796                .try_get(name)
797                .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid"))?
798                .map(Datum::Uuid),
799            (SqlScalarType::Numeric { .. }, SqlServerColumnDecodeType::Numeric) => data
800                .try_get(name)
801                .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric"))?
802                .map(|val: tiberius::numeric::Numeric| {
803                    let numeric = tiberius_numeric_to_mz_numeric(val);
804                    Datum::Numeric(OrderedDecimal(numeric))
805                }),
806            (SqlScalarType::String, SqlServerColumnDecodeType::Xml) => data
807                .try_get(name)
808                .map_err(|_| SqlServerDecodeError::invalid_column(name, "xml"))?
809                .map(|val: &tiberius::xml::XmlData| Datum::String(val.as_ref())),
810            (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate) => data
811                .try_get(name)
812                .map_err(|_| SqlServerDecodeError::invalid_column(name, "date"))?
813                .map(|val: chrono::NaiveDate| {
814                    let date = val
815                        .try_into()
816                        .map_err(|e| SqlServerDecodeError::invalid_date(name, e))?;
817                    Ok::<_, SqlServerDecodeError>(Datum::Date(date))
818                })
819                .transpose()?,
820            (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime) => data
821                .try_get(name)
822                .map_err(|_| SqlServerDecodeError::invalid_column(name, "time"))?
823                .map(|val: chrono::NaiveTime| {
824                    // Postgres' maximum precision is 6 (aka microseconds).
825                    //
826                    // While the Postgres spec supports specifying a precision
827                    // Materialize does not.
828                    let rounded = val.round_subsecs(6);
829                    // Overflowed.
830                    let val = if rounded < val {
831                        val.trunc_subsecs(6)
832                    } else {
833                        val
834                    };
835                    Datum::Time(val)
836                }),
837            (SqlScalarType::Timestamp { precision }, SqlServerColumnDecodeType::NaiveDateTime) => {
838                data.try_get(name)
839                    .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamp"))?
840                    .map(|val: chrono::NaiveDateTime| {
841                        let ts: CheckedTimestamp<chrono::NaiveDateTime> = val
842                            .try_into()
843                            .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
844                        let rounded = ts
845                            .round_to_precision(*precision)
846                            .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
847                        Ok::<_, SqlServerDecodeError>(Datum::Timestamp(rounded))
848                    })
849                    .transpose()?
850            }
851            (SqlScalarType::TimestampTz { precision }, SqlServerColumnDecodeType::DateTime) => data
852                .try_get(name)
853                .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamptz"))?
854                .map(|val: chrono::DateTime<chrono::Utc>| {
855                    let ts: CheckedTimestamp<chrono::DateTime<chrono::Utc>> = val
856                        .try_into()
857                        .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
858                    let rounded = ts
859                        .round_to_precision(*precision)
860                        .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
861                    Ok::<_, SqlServerDecodeError>(Datum::TimestampTz(rounded))
862                })
863                .transpose()?,
864            // We support mapping any type to a string.
865            (SqlScalarType::String, SqlServerColumnDecodeType::Bool) => data
866                .try_get(name)
867                .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool-text"))?
868                .map(|val: bool| {
869                    if val {
870                        Datum::String("true")
871                    } else {
872                        Datum::String("false")
873                    }
874                }),
875            (SqlScalarType::String, SqlServerColumnDecodeType::U8) => data
876                .try_get(name)
877                .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8-text"))?
878                .map(|val: u8| {
879                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
880                }),
881            (SqlScalarType::String, SqlServerColumnDecodeType::I16) => data
882                .try_get(name)
883                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16-text"))?
884                .map(|val: i16| {
885                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
886                }),
887            (SqlScalarType::String, SqlServerColumnDecodeType::I32) => data
888                .try_get(name)
889                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32-text"))?
890                .map(|val: i32| {
891                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
892                }),
893            (SqlScalarType::String, SqlServerColumnDecodeType::I64) => data
894                .try_get(name)
895                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64-text"))?
896                .map(|val: i64| {
897                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
898                }),
899            (SqlScalarType::String, SqlServerColumnDecodeType::F32) => data
900                .try_get(name)
901                .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32-text"))?
902                .map(|val: f32| {
903                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
904                }),
905            (SqlScalarType::String, SqlServerColumnDecodeType::F64) => data
906                .try_get(name)
907                .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64-text"))?
908                .map(|val: f64| {
909                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
910                }),
911            (SqlScalarType::String, SqlServerColumnDecodeType::Uuid) => data
912                .try_get(name)
913                .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid-text"))?
914                .map(|val: uuid::Uuid| {
915                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
916                }),
917            (SqlScalarType::String, SqlServerColumnDecodeType::Bytes) => data
918                .try_get(name)
919                .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes-text"))?
920                .map(|val: &[u8]| {
921                    let encoded = base64::engine::general_purpose::STANDARD.encode(val);
922                    arena.make_datum(|packer| packer.push(Datum::String(&encoded)))
923                }),
924            (SqlScalarType::String, SqlServerColumnDecodeType::Numeric) => data
925                .try_get(name)
926                .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric-text"))?
927                .map(|val: tiberius::numeric::Numeric| {
928                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
929                }),
930            (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDate) => data
931                .try_get(name)
932                .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedate-text"))?
933                .map(|val: chrono::NaiveDate| {
934                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
935                }),
936            (SqlScalarType::String, SqlServerColumnDecodeType::NaiveTime) => data
937                .try_get(name)
938                .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivetime-text"))?
939                .map(|val: chrono::NaiveTime| {
940                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
941                }),
942            (SqlScalarType::String, SqlServerColumnDecodeType::DateTime) => data
943                .try_get(name)
944                .map_err(|_| SqlServerDecodeError::invalid_column(name, "datetime-text"))?
945                .map(|val: chrono::DateTime<chrono::Utc>| {
946                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
947                }),
948            (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDateTime) => data
949                .try_get(name)
950                .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedatetime-text"))?
951                .map(|val: chrono::NaiveDateTime| {
952                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
953                }),
954            (column_type, decode_type) => {
955                return Err(SqlServerDecodeError::Unsupported {
956                    sql_server_type: decode_type.clone(),
957                    mz_type: column_type.clone(),
958                });
959            }
960        };
961
962        match (maybe_datum, column.nullable) {
963            (Some(datum), _) => Ok(datum),
964            (None, true) => Ok(Datum::Null),
965            (None, false) => Err(SqlServerDecodeError::InvalidData {
966                column_name: name.to_string(),
967                // Note: This error string is durably recorded in Persist, do not change.
968                error: "found Null in non-nullable column".to_string(),
969            }),
970        }
971    }
972}
973
974impl RustType<proto_sql_server_column_desc::DecodeType> for SqlServerColumnDecodeType {
975    fn into_proto(&self) -> proto_sql_server_column_desc::DecodeType {
976        match self {
977            SqlServerColumnDecodeType::Bool => proto_sql_server_column_desc::DecodeType::Bool(()),
978            SqlServerColumnDecodeType::U8 => proto_sql_server_column_desc::DecodeType::U8(()),
979            SqlServerColumnDecodeType::I16 => proto_sql_server_column_desc::DecodeType::I16(()),
980            SqlServerColumnDecodeType::I32 => proto_sql_server_column_desc::DecodeType::I32(()),
981            SqlServerColumnDecodeType::I64 => proto_sql_server_column_desc::DecodeType::I64(()),
982            SqlServerColumnDecodeType::F32 => proto_sql_server_column_desc::DecodeType::F32(()),
983            SqlServerColumnDecodeType::F64 => proto_sql_server_column_desc::DecodeType::F64(()),
984            SqlServerColumnDecodeType::String => {
985                proto_sql_server_column_desc::DecodeType::String(())
986            }
987            SqlServerColumnDecodeType::Bytes => proto_sql_server_column_desc::DecodeType::Bytes(()),
988            SqlServerColumnDecodeType::Uuid => proto_sql_server_column_desc::DecodeType::Uuid(()),
989            SqlServerColumnDecodeType::Numeric => {
990                proto_sql_server_column_desc::DecodeType::Numeric(())
991            }
992            SqlServerColumnDecodeType::Xml => proto_sql_server_column_desc::DecodeType::Xml(()),
993            SqlServerColumnDecodeType::NaiveDate => {
994                proto_sql_server_column_desc::DecodeType::NaiveDate(())
995            }
996            SqlServerColumnDecodeType::NaiveTime => {
997                proto_sql_server_column_desc::DecodeType::NaiveTime(())
998            }
999            SqlServerColumnDecodeType::DateTime => {
1000                proto_sql_server_column_desc::DecodeType::DateTime(())
1001            }
1002            SqlServerColumnDecodeType::NaiveDateTime => {
1003                proto_sql_server_column_desc::DecodeType::NaiveDateTime(())
1004            }
1005            SqlServerColumnDecodeType::Unsupported { context } => {
1006                proto_sql_server_column_desc::DecodeType::Unsupported(context.clone())
1007            }
1008        }
1009    }
1010
1011    fn from_proto(
1012        proto: proto_sql_server_column_desc::DecodeType,
1013    ) -> Result<Self, mz_proto::TryFromProtoError> {
1014        let val = match proto {
1015            proto_sql_server_column_desc::DecodeType::Bool(()) => SqlServerColumnDecodeType::Bool,
1016            proto_sql_server_column_desc::DecodeType::U8(()) => SqlServerColumnDecodeType::U8,
1017            proto_sql_server_column_desc::DecodeType::I16(()) => SqlServerColumnDecodeType::I16,
1018            proto_sql_server_column_desc::DecodeType::I32(()) => SqlServerColumnDecodeType::I32,
1019            proto_sql_server_column_desc::DecodeType::I64(()) => SqlServerColumnDecodeType::I64,
1020            proto_sql_server_column_desc::DecodeType::F32(()) => SqlServerColumnDecodeType::F32,
1021            proto_sql_server_column_desc::DecodeType::F64(()) => SqlServerColumnDecodeType::F64,
1022            proto_sql_server_column_desc::DecodeType::String(()) => {
1023                SqlServerColumnDecodeType::String
1024            }
1025            proto_sql_server_column_desc::DecodeType::Bytes(()) => SqlServerColumnDecodeType::Bytes,
1026            proto_sql_server_column_desc::DecodeType::Uuid(()) => SqlServerColumnDecodeType::Uuid,
1027            proto_sql_server_column_desc::DecodeType::Numeric(()) => {
1028                SqlServerColumnDecodeType::Numeric
1029            }
1030            proto_sql_server_column_desc::DecodeType::Xml(()) => SqlServerColumnDecodeType::Xml,
1031            proto_sql_server_column_desc::DecodeType::NaiveDate(()) => {
1032                SqlServerColumnDecodeType::NaiveDate
1033            }
1034            proto_sql_server_column_desc::DecodeType::NaiveTime(()) => {
1035                SqlServerColumnDecodeType::NaiveTime
1036            }
1037            proto_sql_server_column_desc::DecodeType::DateTime(()) => {
1038                SqlServerColumnDecodeType::DateTime
1039            }
1040            proto_sql_server_column_desc::DecodeType::NaiveDateTime(()) => {
1041                SqlServerColumnDecodeType::NaiveDateTime
1042            }
1043            proto_sql_server_column_desc::DecodeType::Unsupported(context) => {
1044                SqlServerColumnDecodeType::Unsupported { context }
1045            }
1046        };
1047        Ok(val)
1048    }
1049}
1050
1051/// Numerics in SQL Server have a maximum precision of 38 digits, where [`Numeric`]s in
1052/// Materialize have a maximum precision of 39 digits, so this conversion is infallible.
1053fn tiberius_numeric_to_mz_numeric(val: tiberius::numeric::Numeric) -> Numeric {
1054    let mut numeric = mz_repr::adt::numeric::cx_datum().from_i128(val.value());
1055    // Use scaleb to adjust the exponent directly, avoiding precision loss from division
1056    // scaleb(x, -n) computes x * 10^(-n)
1057    mz_repr::adt::numeric::cx_datum().scaleb(&mut numeric, &Numeric::from(-i32::from(val.scale())));
1058    numeric
1059}
1060
/// The update mask of a CDC event row, returned by `cdc.fn_cdc_get_all_changes_<capture_instance>`
/// as `__$update_mask`.
///
/// See <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/cdc-fn-cdc-get-all-changes-capture-instance-transact-sql?view=sql-server-ver17>
#[derive(Debug)]
pub struct UpdateMask {
    /// Raw bytes of the `__$update_mask` column.
    ///
    /// [`UpdateMask::data_col_updated`] reads one bit per data column, counting bytes from
    /// the *end* of this buffer.
    mask: Vec<u8>,
}
1069
1070impl TryFrom<&tiberius::Row> for UpdateMask {
1071    type Error = SqlServerDecodeError;
1072
1073    fn try_from(row: &tiberius::Row) -> Result<Self, Self::Error> {
1074        static UPDATE_MASK: &str = "__$update_mask";
1075
1076        let mask: Vec<u8> = row
1077            .try_get::<&[u8], _>(UPDATE_MASK)
1078            .inspect_err(|e| tracing::warn!("Failed extracting update mask: {e:?}"))
1079            .map_err(|_| SqlServerDecodeError::InvalidColumn {
1080                column_name: UPDATE_MASK.to_string(),
1081                as_type: "bytes",
1082            })?
1083            .ok_or_else(|| SqlServerDecodeError::InvalidData {
1084                column_name: UPDATE_MASK.to_string(),
1085                error: "column cannot be null".to_string(),
1086            })?
1087            .into();
1088        Ok(UpdateMask { mask })
1089    }
1090}
1091
1092impl UpdateMask {
1093    /// Returns true if the data column was updated, false otherwise.
1094    ///
1095    /// This function panics if `col_index` exceeds the mask.
1096    ///
1097    /// The [`tiberius::Row`] returned by `cdc.fn_cdc_get_all_changes_<capture_instance>` contains
1098    /// 4 metadata columns used by CDC:
1099    /// - `__$start_lsn`
1100    /// - `__$seqval`
1101    /// - `__$operation`
1102    /// - `__$update_mask`
1103    ///
1104    /// This function will always return false for the first 4 columns.
1105    pub fn data_col_updated(&self, col_index: usize) -> bool {
1106        const CDC_METADATA_COL_COUNT: usize = 4;
1107
1108        if col_index < CDC_METADATA_COL_COUNT {
1109            return false;
1110        }
1111        let adj_col_index = col_index - CDC_METADATA_COL_COUNT;
1112        let byte_offset = adj_col_index / usize::cast_from(u8::BITS);
1113        assert!(
1114            byte_offset < self.mask.len(),
1115            "byte_offset = {byte_offset} mask_len = {}",
1116            self.mask.len()
1117        );
1118        let bit_offset = adj_col_index % usize::cast_from(u8::BITS);
1119        (self.mask[self.mask.len() - byte_offset - 1] >> bit_offset) & 1 == 1
1120    }
1121}
1122
/// A decoder from [`tiberius::Row`] to [`mz_repr::Row`].
///
/// The goal of this type is to perform any expensive "downcasts" so in the hot
/// path of decoding rows we do the minimal amount of work.
#[derive(Debug)]
pub struct SqlServerRowDecoder {
    /// One entry per output column, in the order the columns get packed into the output
    /// row: the name of the upstream column, the Materialize column type to decode into,
    /// and the Rust type used to read the value from SQL Server.
    decoders: Vec<(Arc<str>, SqlColumnType, SqlServerColumnDecodeType)>,
}
1131
1132impl SqlServerRowDecoder {
1133    /// Try to create a [`SqlServerRowDecoder`] that will decode [`tiberius::Row`]s that match
1134    /// the shape of the provided [`SqlServerTableDesc`], to [`mz_repr::Row`]s that match the
1135    /// shape of the provided [`RelationDesc`].
1136    pub fn try_new(
1137        table: &SqlServerTableDesc,
1138        desc: &RelationDesc,
1139    ) -> Result<Self, SqlServerError> {
1140        let decoders = desc
1141            .iter()
1142            .map(|(col_name, col_type)| {
1143                let sql_server_col = table
1144                    .columns
1145                    .iter()
1146                    .find(|col| col.name.as_ref() == col_name.as_str())
1147                    .ok_or_else(|| {
1148                        // TODO(sql_server2): Structured Error.
1149                        anyhow::anyhow!("no SQL Server column with name {col_name} found")
1150                    })?;
1151                let Some(sql_server_col_typ) = sql_server_col.column_type.as_ref() else {
1152                    return Err(SqlServerError::ProgrammingError(format!(
1153                        "programming error, {col_name} should have been exluded",
1154                    )));
1155                };
1156
1157                // This shouldn't be true, but be defensive.
1158                //
1159                // TODO(sql_server2): Maybe allow the Materialize column type to be
1160                // more nullable than our decoding type?
1161                //
1162                // Sad. Our timestamp types don't roundtrip their precision through
1163                // parsing so we ignore the mismatch here.
1164                let matches = match (&sql_server_col_typ.scalar_type, &col_type.scalar_type) {
1165                    (SqlScalarType::Timestamp { .. }, SqlScalarType::Timestamp { .. })
1166                    | (SqlScalarType::TimestampTz { .. }, SqlScalarType::TimestampTz { .. }) => {
1167                        // Types match so check nullability.
1168                        sql_server_col_typ.nullable == col_type.nullable
1169                    }
1170                    (_, _) => sql_server_col_typ == col_type,
1171                };
1172                if !matches {
1173                    return Err(SqlServerError::ProgrammingError(format!(
1174                        "programming error, {col_name} has mismatched type {:?} vs {:?}",
1175                        sql_server_col.column_type, col_type
1176                    )));
1177                }
1178
1179                let name = Arc::clone(&sql_server_col.name);
1180                let decoder = sql_server_col.decode_type.clone();
1181                // Note: We specifically use the `SqlColumnType` from the SqlServerTableDesc
1182                // because it retains precision.
1183                //
1184                // See: <https://github.com/MaterializeInc/database-issues/issues/3179>.
1185                let col_typ = sql_server_col_typ.clone();
1186
1187                Ok::<_, SqlServerError>((name, col_typ, decoder))
1188            })
1189            .collect::<Result<_, _>>()?;
1190
1191        Ok(SqlServerRowDecoder { decoders })
1192    }
1193
1194    /// Decode data from the provided [`tiberius::Row`] into the provided [`Row`].
1195    ///
1196    /// For updates, the new row data is provided in the event the data contains Large Object Data
1197    /// (e.g. varchar(max)). [`SqlServerRowDecoder::decode()`] will decode the [`UpdateMask`] from
1198    /// the new row and retrieve LOD values from the new row for any LOD column that was not updated
1199    /// in the old row.
1200    pub fn decode(
1201        &self,
1202        data: &tiberius::Row,
1203        row: &mut Row,
1204        arena: &RowArena,
1205        new_data: Option<&tiberius::Row>,
1206    ) -> Result<(), SqlServerDecodeError> {
1207        let mut packer = row.packer();
1208
1209        for (col_name, col_type, decoder) in &self.decoders {
1210            let datum = decoder.decode(data, col_name, col_type, arena)?;
1211
1212            let datum = if let Some(new_data) = new_data
1213                && matches!(
1214                    col_type.scalar_type,
1215                    SqlScalarType::VarChar { max_length: None }
1216                )
1217                && matches!(datum, Datum::Null)
1218            {
1219                let update_mask = UpdateMask::try_from(new_data)?;
1220                let col_index = new_data
1221                    .columns()
1222                    .iter()
1223                    .position(|c| c.name() == col_name.as_ref())
1224                    .expect("column exists");
1225                // The only time it is valid to pull the LOD column value from the new row
1226                // is if the LOD column was *not* updated. The mask check is necessary to
1227                // differentiate between updating a non-LOD column (the LOD column in old row
1228                // is NULL) and updating a LOD column where the old value is NULL.
1229                if !update_mask.data_col_updated(col_index) {
1230                    decoder.decode(new_data, col_name, col_type, arena)?
1231                } else {
1232                    datum
1233                }
1234            } else {
1235                datum
1236            };
1237
1238            packer.push(datum);
1239        }
1240        Ok(())
1241    }
1242
1243    pub fn included_column_names(&self) -> Vec<Arc<str>> {
1244        self.decoders
1245            .iter()
1246            .map(|decoder| Arc::clone(&decoder.0))
1247            .collect()
1248    }
1249}
1250
1251#[cfg(test)]
1252mod tests {
1253    use std::collections::BTreeSet;
1254    use std::sync::Arc;
1255
1256    use chrono::NaiveDateTime;
1257    use itertools::Itertools;
1258    use mz_ore::assert_contains;
1259    use mz_ore::collections::CollectionExt;
1260    use mz_repr::adt::numeric::NumericMaxScale;
1261    use mz_repr::adt::varchar::VarCharMaxLength;
1262    use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlScalarType};
1263    use tiberius::RowTestExt;
1264
1265    use crate::desc::{
1266        SqlServerCaptureInstanceRaw, SqlServerColumnDecodeType, SqlServerColumnDesc,
1267        SqlServerTableDesc, SqlServerTableRaw, tiberius_numeric_to_mz_numeric,
1268    };
1269
1270    use super::SqlServerColumnRaw;
1271
1272    impl SqlServerColumnRaw {
1273        /// Create a new [`SqlServerColumnRaw`]. The specified `data_type` is
1274        /// _not_ checked for validity.
1275        fn new(name: &str, data_type: &str) -> Self {
1276            SqlServerColumnRaw {
1277                name: name.into(),
1278                data_type: data_type.into(),
1279                is_nullable: false,
1280                max_length: 0,
1281                precision: 0,
1282                scale: 0,
1283                is_computed: false,
1284            }
1285        }
1286
1287        fn nullable(mut self, nullable: bool) -> Self {
1288            self.is_nullable = nullable;
1289            self
1290        }
1291
1292        fn max_length(mut self, max_length: i16) -> Self {
1293            self.max_length = max_length;
1294            self
1295        }
1296
1297        fn precision(mut self, precision: u8) -> Self {
1298            self.precision = precision;
1299            self
1300        }
1301
1302        fn scale(mut self, scale: u8) -> Self {
1303            self.scale = scale;
1304            self
1305        }
1306    }
1307
1308    #[mz_ore::test]
1309    fn smoketest_column_raw() {
1310        let raw = SqlServerColumnRaw::new("foo", "bit");
1311        let col = SqlServerColumnDesc::new(&raw);
1312
1313        assert_eq!(&*col.name, "foo");
1314        assert_eq!(col.column_type, Some(SqlScalarType::Bool.nullable(false)));
1315        assert_eq!(col.decode_type, SqlServerColumnDecodeType::Bool);
1316
1317        let raw = SqlServerColumnRaw::new("foo", "decimal")
1318            .precision(20)
1319            .scale(10);
1320        let col = SqlServerColumnDesc::new(&raw);
1321
1322        let col_type = SqlScalarType::Numeric {
1323            max_scale: Some(NumericMaxScale::try_from(10i64).expect("known valid")),
1324        }
1325        .nullable(false);
1326        assert_eq!(col.column_type, Some(col_type));
1327        assert_eq!(col.decode_type, SqlServerColumnDecodeType::Numeric);
1328    }
1329
    /// Checks that raw columns we cannot handle get an `Unsupported` decode type.
    #[mz_ore::test]
    fn smoketest_column_raw_invalid() {
        // An unknown data type is reported as unimplemented.
        let raw = SqlServerColumnRaw::new("foo", "bad_data_type");
        let desc = SqlServerColumnDesc::new(&raw);
        let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
            panic!("unexpected decode type {desc:?}");
        };
        assert_contains!(context, "'bad_data_type' is unimplemented");

        // A decimal with a precision of 100 is rejected.
        let raw = SqlServerColumnRaw::new("foo", "decimal")
            .precision(100)
            .scale(10);
        let desc = SqlServerColumnDesc::new(&raw);
        assert!(matches!(
            desc.decode_type,
            SqlServerColumnDecodeType::Unsupported { .. }
        ));

        // A `max_length` of -1 marks an unlimited size column, which CDC does not support.
        let raw = SqlServerColumnRaw::new("foo", "varbinary").max_length(-1);
        let desc = SqlServerColumnDesc::new(&raw);
        let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
            panic!("unexpected decode type {desc:?}");
        };
        assert_contains!(context, "columns with unlimited size do not support CDC");
    }
1355
1356    #[mz_ore::test]
1357    fn smoketest_decoder() {
1358        let sql_server_columns = [
1359            SqlServerColumnRaw::new("a", "varchar").max_length(16),
1360            SqlServerColumnRaw::new("b", "int").nullable(true),
1361            SqlServerColumnRaw::new("c", "bit"),
1362        ];
1363        let sql_server_desc = SqlServerTableRaw {
1364            schema_name: "my_schema".into(),
1365            name: "my_table".into(),
1366            capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1367                name: "my_table_CT".into(),
1368                create_date: NaiveDateTime::parse_from_str(
1369                    "2024-01-01 00:00:00",
1370                    "%Y-%m-%d %H:%M:%S",
1371                )
1372                .unwrap()
1373                .into(),
1374            }),
1375            columns: sql_server_columns.into(),
1376        };
1377        let sql_server_desc = SqlServerTableDesc::new(sql_server_desc, vec![]).unwrap();
1378
1379        let max_length = Some(VarCharMaxLength::try_from(16).unwrap());
1380        let relation_desc = RelationDesc::builder()
1381            .with_column("a", SqlScalarType::VarChar { max_length }.nullable(false))
1382            // Note: In the upstream table 'c' is ordered after 'b'.
1383            .with_column("c", SqlScalarType::Bool.nullable(false))
1384            .with_column("b", SqlScalarType::Int32.nullable(true))
1385            .finish();
1386
1387        // This decoder should shape the SQL Server Rows into Rows compatible with the RelationDesc.
1388        let decoder = sql_server_desc
1389            .decoder(&relation_desc)
1390            .expect("known valid");
1391
1392        let sql_server_columns = [
1393            tiberius::Column::new("a".to_string(), tiberius::ColumnType::BigVarChar),
1394            tiberius::Column::new("b".to_string(), tiberius::ColumnType::Int4),
1395            tiberius::Column::new("c".to_string(), tiberius::ColumnType::Bit),
1396        ];
1397
1398        let data_a = [
1399            tiberius::ColumnData::String(Some("hello world".into())),
1400            tiberius::ColumnData::I32(Some(42)),
1401            tiberius::ColumnData::Bit(Some(true)),
1402        ];
1403        let sql_server_row_a = tiberius::Row::build(
1404            sql_server_columns
1405                .iter()
1406                .cloned()
1407                .zip_eq(data_a.into_iter()),
1408        );
1409
1410        let data_b = [
1411            tiberius::ColumnData::String(Some("foo bar".into())),
1412            tiberius::ColumnData::I32(None),
1413            tiberius::ColumnData::Bit(Some(false)),
1414        ];
1415        let sql_server_row_b =
1416            tiberius::Row::build(sql_server_columns.into_iter().zip_eq(data_b.into_iter()));
1417
1418        let mut rnd_row = Row::default();
1419        let arena = RowArena::default();
1420
1421        decoder
1422            .decode(&sql_server_row_a, &mut rnd_row, &arena, None)
1423            .unwrap();
1424        assert_eq!(
1425            &rnd_row,
1426            &Row::pack_slice(&[Datum::String("hello world"), Datum::True, Datum::Int32(42)])
1427        );
1428
1429        decoder
1430            .decode(&sql_server_row_b, &mut rnd_row, &arena, None)
1431            .unwrap();
1432        assert_eq!(
1433            &rnd_row,
1434            &Row::pack_slice(&[Datum::String("foo bar"), Datum::False, Datum::Null])
1435        );
1436    }
1437
1438    #[mz_ore::test]
1439    fn smoketest_decode_to_string() {
1440        #[track_caller]
1441        fn testcase(
1442            data_type: &'static str,
1443            col_type: tiberius::ColumnType,
1444            col_data: tiberius::ColumnData<'static>,
1445        ) {
1446            let columns = [SqlServerColumnRaw::new("a", data_type)];
1447            let sql_server_desc = SqlServerTableRaw {
1448                schema_name: "my_schema".into(),
1449                name: "my_table".into(),
1450                capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1451                    name: "my_table_CT".into(),
1452                    create_date: NaiveDateTime::parse_from_str(
1453                        "2024-01-01 00:00:00",
1454                        "%Y-%m-%d %H:%M:%S",
1455                    )
1456                    .unwrap()
1457                    .into(),
1458                }),
1459                columns: columns.into(),
1460            };
1461            let mut sql_server_desc = SqlServerTableDesc::new(sql_server_desc, vec![]).unwrap();
1462            sql_server_desc.apply_text_columns(&BTreeSet::from(["a"]));
1463
1464            // We should support decoding every datatype to a string.
1465            let relation_desc = RelationDesc::builder()
1466                .with_column("a", SqlScalarType::String.nullable(false))
1467                .finish();
1468
1469            // This decoder should shape the SQL Server Rows into Rows compatible with the RelationDesc.
1470            let decoder = sql_server_desc
1471                .decoder(&relation_desc)
1472                .expect("known valid");
1473
1474            let sql_server_row = tiberius::Row::build([(
1475                tiberius::Column::new("a".to_string(), col_type),
1476                col_data,
1477            )]);
1478            let mut mz_row = Row::default();
1479            let arena = RowArena::new();
1480            decoder
1481                .decode(&sql_server_row, &mut mz_row, &arena, None)
1482                .unwrap();
1483
1484            let str_datum = mz_row.into_element();
1485            assert!(matches!(str_datum, Datum::String(_)));
1486        }
1487
1488        use tiberius::ColumnData;
1489
1490        testcase(
1491            "bit",
1492            tiberius::ColumnType::Bit,
1493            ColumnData::Bit(Some(true)),
1494        );
1495        testcase(
1496            "bit",
1497            tiberius::ColumnType::Bit,
1498            ColumnData::Bit(Some(false)),
1499        );
1500        testcase(
1501            "tinyint",
1502            tiberius::ColumnType::Int1,
1503            ColumnData::U8(Some(33)),
1504        );
1505        testcase(
1506            "smallint",
1507            tiberius::ColumnType::Int2,
1508            ColumnData::I16(Some(101)),
1509        );
1510        testcase(
1511            "int",
1512            tiberius::ColumnType::Int4,
1513            ColumnData::I32(Some(-42)),
1514        );
1515        {
1516            let datetime = tiberius::time::DateTime::new(10, 300);
1517            testcase(
1518                "datetime",
1519                tiberius::ColumnType::Datetime,
1520                ColumnData::DateTime(Some(datetime)),
1521            );
1522        }
1523    }
1524
1525    #[mz_ore::test]
1526    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
1527    fn smoketest_numeric_conversion() {
1528        let a = tiberius::numeric::Numeric::new_with_scale(12345, 2);
1529        let rnd = tiberius_numeric_to_mz_numeric(a);
1530        let og = mz_repr::adt::numeric::cx_datum().parse("123.45").unwrap();
1531        assert_eq!(og, rnd);
1532
1533        let a = tiberius::numeric::Numeric::new_with_scale(-99999, 5);
1534        let rnd = tiberius_numeric_to_mz_numeric(a);
1535        let og = mz_repr::adt::numeric::cx_datum().parse("-.99999").unwrap();
1536        assert_eq!(og, rnd);
1537
1538        let a = tiberius::numeric::Numeric::new_with_scale(1, 29);
1539        let rnd = tiberius_numeric_to_mz_numeric(a);
1540        let og = mz_repr::adt::numeric::cx_datum()
1541            .parse("0.00000000000000000000000000001")
1542            .unwrap();
1543        assert_eq!(og, rnd);
1544
1545        let a = tiberius::numeric::Numeric::new_with_scale(-111111111111111111, 0);
1546        let rnd = tiberius_numeric_to_mz_numeric(a);
1547        let og = mz_repr::adt::numeric::cx_datum()
1548            .parse("-111111111111111111")
1549            .unwrap();
1550        assert_eq!(og, rnd);
1551    }
1552
1553    // TODO(sql_server2): Proptest the decoder.
1554}