Skip to main content

mz_sql_server_util/
desc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Metadata about tables, columns, and other objects from SQL Server.
11//!
12//! ### Tables
13//!
14//! When creating a SQL Server source we will query system tables from the
15//! upstream instance to get a [`SqlServerTableRaw`]. From this raw information
16//! we create a [`SqlServerTableDesc`] which describes how the upstream table
17//! will get represented in Materialize.
18//!
19//! ### Rows
20//!
21//! With a [`SqlServerTableDesc`] and an [`mz_repr::RelationDesc`] we can
22//! create a [`SqlServerRowDecoder`] which will be used when running a source
23//! to efficiently decode [`tiberius::Row`]s into [`mz_repr::Row`]s.
24
25use base64::Engine;
26use chrono::{NaiveDateTime, SubsecRound};
27use dec::OrderedDecimal;
28use mz_ore::cast::CastFrom;
29use mz_proto::{IntoRustIfSome, ProtoType, RustType};
30use mz_repr::adt::char::CharLength;
31use mz_repr::adt::numeric::{Numeric, NumericMaxScale};
32use mz_repr::adt::timestamp::{CheckedTimestamp, TimestampPrecision};
33use mz_repr::adt::varchar::VarCharMaxLength;
34use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlColumnType, SqlScalarType};
35#[cfg(any(test, feature = "proptest"))]
36use proptest_derive::Arbitrary;
37use serde::{Deserialize, Serialize};
38
39use std::collections::BTreeSet;
40use std::sync::Arc;
41
42use crate::desc::proto_sql_server_table_constraint::ConstraintType;
43use crate::{SqlServerDecodeError, SqlServerError};
44
45include!(concat!(env!("OUT_DIR"), "/mz_sql_server_util.rs"));
46
/// Materialize compatible description of a table in Microsoft SQL Server.
///
/// See [`SqlServerTableRaw`] for the raw information we read from the upstream
/// system.
///
/// Note: We map a [`SqlServerTableDesc`] to a Materialize [`RelationDesc`] as
/// part of purification. Specifically, we use this description to generate a
/// SQL statement for a subsource, and it's the _parsing of that statement_ which
/// actually generates a [`RelationDesc`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
pub struct SqlServerTableDesc {
    /// Name of the schema that the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// Columns for the table.
    ///
    /// Columns that will not be replicated (unsupported upstream types, or columns the
    /// user excluded) stay in this list with a `column_type` of `None`.
    pub columns: Box<[SqlServerColumnDesc]>,
    /// Constraints for the table (`PRIMARY KEY` and `UNIQUE`).
    pub constraints: Vec<SqlServerTableConstraint>,
}
68
69impl SqlServerTableDesc {
70    /// Creating a [`SqlServerTableDesc`] from a [`SqlServerTableRaw`] description.
71    ///
72    /// Note: Not all columns from SQL Server can be ingested into Materialize. To determine if a
73    /// column is supported see [`SqlServerColumnDesc::decode_type`].
74    pub fn new(
75        raw: SqlServerTableRaw,
76        raw_constraints: Vec<SqlServerTableConstraintRaw>,
77    ) -> Result<Self, SqlServerError> {
78        let columns: Box<[_]> = raw
79            .columns
80            .into_iter()
81            .map(SqlServerColumnDesc::new)
82            .collect();
83        let constraints = raw_constraints
84            .into_iter()
85            .map(SqlServerTableConstraint::try_from)
86            .collect::<Result<Vec<_>, _>>()?;
87        Ok(SqlServerTableDesc {
88            schema_name: raw.schema_name,
89            name: raw.name,
90            columns,
91            constraints,
92        })
93    }
94
95    /// Returns the [`SqlServerQualifiedTableName`] for this [`SqlServerTableDesc`].
96    pub fn qualified_name(&self) -> SqlServerQualifiedTableName {
97        SqlServerQualifiedTableName {
98            schema_name: Arc::clone(&self.schema_name),
99            table_name: Arc::clone(&self.name),
100        }
101    }
102
103    /// Update this [`SqlServerTableDesc`] to represent the specified columns
104    /// as text in Materialize.
105    pub fn apply_text_columns(&mut self, text_columns: &BTreeSet<&str>) {
106        for column in &mut self.columns {
107            if text_columns.contains(column.name.as_ref()) {
108                column.represent_as_text();
109            }
110        }
111    }
112
113    /// Update this [`SqlServerTableDesc`] to exclude the specified columns from being
114    /// replicated into Materialize.
115    pub fn apply_excl_columns(&mut self, excl_columns: &BTreeSet<&str>) {
116        for column in &mut self.columns {
117            if excl_columns.contains(column.name.as_ref()) {
118                column.exclude();
119            }
120        }
121    }
122
123    /// Returns a [`SqlServerRowDecoder`] which can be used to decode [`tiberius::Row`]s into
124    /// [`mz_repr::Row`]s that match the shape of the provided [`RelationDesc`].
125    pub fn decoder(&self, desc: &RelationDesc) -> Result<SqlServerRowDecoder, SqlServerError> {
126        let decoder = SqlServerRowDecoder::try_new(self, desc)?;
127        Ok(decoder)
128    }
129}
130
131impl RustType<ProtoSqlServerTableDesc> for SqlServerTableDesc {
132    fn into_proto(&self) -> ProtoSqlServerTableDesc {
133        ProtoSqlServerTableDesc {
134            name: self.name.to_string(),
135            schema_name: self.schema_name.to_string(),
136            columns: self.columns.iter().map(|c| c.into_proto()).collect(),
137            constraints: self.constraints.iter().map(|c| c.into_proto()).collect(),
138        }
139    }
140
141    fn from_proto(proto: ProtoSqlServerTableDesc) -> Result<Self, mz_proto::TryFromProtoError> {
142        let columns = proto
143            .columns
144            .into_iter()
145            .map(|c| c.into_rust())
146            .collect::<Result<_, _>>()?;
147        let constraints = proto
148            .constraints
149            .into_iter()
150            .map(|c| c.into_rust())
151            .collect::<Result<_, _>>()?;
152        Ok(SqlServerTableDesc {
153            schema_name: proto.schema_name.into(),
154            name: proto.name.into(),
155            columns,
156            constraints,
157        })
158    }
159}
160
/// SQL Server table constraint type.
///
/// Only `PRIMARY KEY` and `UNIQUE` constraints are represented; any other
/// `CONSTRAINT_TYPE` string is rejected when converting from a `String`.
///
/// See <https://learn.microsoft.com/en-us/sql/relational-databases/system-information-schema-views/table-constraints-transact-sql?view=sql-server-ver17>
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
pub enum SqlServerTableConstraintType {
    /// A `PRIMARY KEY` constraint.
    PrimaryKey,
    /// A `UNIQUE` constraint.
    Unique,
}
169
170impl TryFrom<String> for SqlServerTableConstraintType {
171    type Error = SqlServerError;
172
173    fn try_from(value: String) -> Result<Self, Self::Error> {
174        match value.as_str() {
175            "PRIMARY KEY" => Ok(Self::PrimaryKey),
176            "UNIQUE" => Ok(Self::Unique),
177            name => Err(SqlServerError::InvalidData {
178                column_name: "constraint_type".into(),
179                error: format!("Unknown constraint type: {name}"),
180            }),
181        }
182    }
183}
184
185impl RustType<proto_sql_server_table_constraint::ConstraintType> for SqlServerTableConstraintType {
186    fn into_proto(&self) -> proto_sql_server_table_constraint::ConstraintType {
187        match self {
188            SqlServerTableConstraintType::PrimaryKey => ConstraintType::PrimaryKey(()),
189            SqlServerTableConstraintType::Unique => ConstraintType::Unique(()),
190        }
191    }
192
193    fn from_proto(
194        proto: proto_sql_server_table_constraint::ConstraintType,
195    ) -> Result<Self, mz_proto::TryFromProtoError> {
196        Ok(match proto {
197            ConstraintType::PrimaryKey(_) => SqlServerTableConstraintType::PrimaryKey,
198            ConstraintType::Unique(_) => SqlServerTableConstraintType::Unique,
199        })
200    }
201}
202
/// SQL Server table constraint.
///
/// Built from a [`SqlServerTableConstraintRaw`] via `TryFrom`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
pub struct SqlServerTableConstraint {
    /// Name of the constraint, as read from upstream.
    pub constraint_name: String,
    /// Kind of constraint (`PRIMARY KEY` or `UNIQUE`).
    pub constraint_type: SqlServerTableConstraintType,
    /// Names of the columns in the constraint, copied from
    /// [`SqlServerTableConstraintRaw::columns`].
    pub column_names: Vec<String>,
}
211
212impl TryFrom<SqlServerTableConstraintRaw> for SqlServerTableConstraint {
213    type Error = SqlServerError;
214
215    fn try_from(value: SqlServerTableConstraintRaw) -> Result<Self, Self::Error> {
216        Ok(SqlServerTableConstraint {
217            constraint_name: value.constraint_name,
218            constraint_type: value.constraint_type.try_into()?,
219            column_names: value.columns,
220        })
221    }
222}
223
224impl RustType<ProtoSqlServerTableConstraint> for SqlServerTableConstraint {
225    fn into_proto(&self) -> ProtoSqlServerTableConstraint {
226        ProtoSqlServerTableConstraint {
227            constraint_name: self.constraint_name.clone(),
228            constraint_type: Some(self.constraint_type.into_proto()),
229            column_names: self.column_names.clone(),
230        }
231    }
232
233    fn from_proto(
234        proto: ProtoSqlServerTableConstraint,
235    ) -> Result<Self, mz_proto::TryFromProtoError> {
236        Ok(SqlServerTableConstraint {
237            constraint_name: proto.constraint_name,
238            constraint_type: proto
239                .constraint_type
240                .into_rust_if_some("ProtoSqlServerTableConstraint::constraint_type")?,
241            column_names: proto.column_names,
242        })
243    }
244}
245
/// Partially qualified name of a table from Microsoft SQL Server, made up of only a
/// schema name and a table name.
///
/// TODO(sql_server3): Change this to use a &str.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SqlServerQualifiedTableName {
    /// Name of the schema the table belongs to (unquoted).
    pub schema_name: Arc<str>,
    /// Name of the table (unquoted).
    pub table_name: Arc<str>,
}
254
255impl ToString for SqlServerQualifiedTableName {
256    fn to_string(&self) -> String {
257        format!(
258            "{}.{}",
259            crate::quote_identifier(&self.schema_name),
260            crate::quote_identifier(&self.table_name)
261        )
262    }
263}
264
/// Raw metadata for a table from Microsoft SQL Server.
///
/// See [`SqlServerTableDesc`] for a refined description that is compatible
/// with Materialize; [`SqlServerTableDesc::new`] consumes this type.
#[derive(Debug, Clone)]
pub struct SqlServerTableRaw {
    /// Name of the schema the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// The capture instance replicating changes.
    pub capture_instance: Arc<SqlServerCaptureInstanceRaw>,
    /// Columns for the table, as read from upstream.
    pub columns: Arc<[SqlServerColumnRaw]>,
}
280
/// Raw capture instance metadata, as read from upstream.
#[derive(Debug, Clone)]
pub struct SqlServerCaptureInstanceRaw {
    /// Name of the capture instance replicating changes.
    pub name: Arc<str>,
    /// The creation date of the capture instance.
    ///
    /// Stored as a [`NaiveDateTime`], i.e. without any time zone information.
    pub create_date: Arc<NaiveDateTime>,
}
289
/// Description of a column from a table in Microsoft SQL Server.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
pub struct SqlServerColumnDesc {
    /// Name of the column.
    pub name: Arc<str>,
    /// The intended data type of this column in Materialize. `None` indicates this
    /// column should be excluded when replicating into Materialize, either because its
    /// upstream type is unsupported or because it was explicitly excluded.
    ///
    /// Note: This type might differ from the `decode_type`, e.g. a user can
    /// specify `TEXT COLUMNS` to represent columns as text.
    pub column_type: Option<SqlColumnType>,
    /// This field is deprecated and will be removed in a future version. It exists only for
    /// the purpose of migrating from old representations.
    ///
    /// [`SqlServerColumnDesc::new`] always sets this to `None`; see
    /// [`SqlServerTableDesc::constraints`] for the constraints of a table.
    pub primary_key_constraint: Option<Arc<str>>,
    /// Rust type we should parse the data from a [`tiberius::Row`] as.
    pub decode_type: SqlServerColumnDecodeType,
    /// Raw type of the column as we read it from upstream.
    ///
    /// This is useful to keep around for debugging purposes.
    pub raw_type: Arc<str>,
}
312
313impl SqlServerColumnDesc {
314    /// Create a [`SqlServerColumnDesc`] from a [`SqlServerColumnRaw`] description.
315    pub fn new(raw: &SqlServerColumnRaw) -> Self {
316        let (column_type, decode_type) = match parse_data_type(raw) {
317            Ok((scalar_type, decode_type)) => {
318                let column_type = scalar_type.nullable(raw.is_nullable);
319                (Some(column_type), decode_type)
320            }
321            Err(err) => {
322                tracing::warn!(
323                    ?err,
324                    ?raw,
325                    "found an unsupported data type when parsing raw data"
326                );
327                (
328                    None,
329                    SqlServerColumnDecodeType::Unsupported {
330                        context: err.reason,
331                    },
332                )
333            }
334        };
335        SqlServerColumnDesc {
336            name: Arc::clone(&raw.name),
337            primary_key_constraint: None,
338            column_type,
339            decode_type,
340            raw_type: Arc::clone(&raw.data_type),
341        }
342    }
343
344    /// Change this [`SqlServerColumnDesc`] to be represented as text in Materialize.
345    pub fn represent_as_text(&mut self) {
346        self.column_type = self
347            .column_type
348            .as_ref()
349            .map(|ct| SqlScalarType::String.nullable(ct.nullable));
350    }
351
352    /// Exclude this [`SqlServerColumnDesc`] from being replicated into Materialize.
353    pub fn exclude(&mut self) {
354        self.column_type = None;
355    }
356
357    /// Check if this [`SqlServerColumnDesc`] is excluded from being replicated into Materialize.
358    pub fn is_excluded(&self) -> bool {
359        self.column_type.is_none()
360    }
361}
362
363impl RustType<ProtoSqlServerColumnDesc> for SqlServerColumnDesc {
364    fn into_proto(&self) -> ProtoSqlServerColumnDesc {
365        ProtoSqlServerColumnDesc {
366            name: self.name.to_string(),
367            column_type: self.column_type.into_proto(),
368            primary_key_constraint: self.primary_key_constraint.as_ref().map(|v| v.to_string()),
369            decode_type: Some(self.decode_type.into_proto()),
370            raw_type: self.raw_type.to_string(),
371        }
372    }
373
374    fn from_proto(proto: ProtoSqlServerColumnDesc) -> Result<Self, mz_proto::TryFromProtoError> {
375        Ok(SqlServerColumnDesc {
376            name: proto.name.into(),
377            column_type: proto.column_type.into_rust()?,
378            primary_key_constraint: proto.primary_key_constraint.map(|v| v.into()),
379            decode_type: proto
380                .decode_type
381                .into_rust_if_some("ProtoSqlServerColumnDesc::decode_type")?,
382            raw_type: proto.raw_type.into(),
383        })
384    }
385}
386
/// The raw datatype from SQL Server is not supported in Materialize.
///
/// Only `reason` is read directly (it becomes the context of
/// [`SqlServerColumnDecodeType::Unsupported`]). The other fields are only surfaced
/// through the derived `Debug` impl when the error is logged, which the `dead_code`
/// lint does not count as a use, hence the `#[allow(dead_code)]`.
#[derive(Debug)]
#[allow(dead_code)]
pub struct UnsupportedDataType {
    /// Name of the column whose type is unsupported.
    column_name: String,
    /// Upstream data type of the column.
    column_type: String,
    /// Human readable explanation of why the type is unsupported.
    reason: String,
}
395
396/// Parse a raw data type from SQL Server into a Materialize [`SqlScalarType`].
397///
398/// Returns the [`SqlScalarType`] that we'll map this column to and the [`SqlServerColumnDecodeType`]
399/// that we use to decode the raw value.
400fn parse_data_type(
401    raw: &SqlServerColumnRaw,
402) -> Result<(SqlScalarType, SqlServerColumnDecodeType), UnsupportedDataType> {
403    // The value of a computed column, persisted or not, will be readable by the snapshot, but will
404    // always be NULL in the CDC stream.  This can lead to issues in MZ (e.g. decoding errors,
405    // negative accumulations, etc.).
406    if raw.is_computed {
407        return Err(UnsupportedDataType {
408            column_name: raw.name.to_string(),
409            column_type: format!("{} (computed)", raw.data_type.to_lowercase()),
410            reason: "column is computed".into(),
411        });
412    }
413
414    let scalar =
415        match raw.data_type.to_lowercase().as_str() {
416            "tinyint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::U8),
417            "smallint" => (SqlScalarType::Int16, SqlServerColumnDecodeType::I16),
418            "int" => (SqlScalarType::Int32, SqlServerColumnDecodeType::I32),
419            "bigint" => (SqlScalarType::Int64, SqlServerColumnDecodeType::I64),
420            "bit" => (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool),
421            "decimal" | "numeric" | "money" | "smallmoney" => {
422                // SQL Server supports a precision in the range of [1, 38] and then
423                // the scale is 0 <= scale <= precision.
424                //
425                // Materialize numerics are floating point with a fixed precision of 39.
426                //
427                // See: <https://learn.microsoft.com/en-us/sql/t-sql/data-types/decimal-and-numeric-transact-sql?view=sql-server-ver16#arguments>
428                if raw.precision > 38 || raw.scale > raw.precision {
429                    tracing::warn!(
430                        "unexpected value from SQL Server, precision of {} and scale of {}",
431                        raw.precision,
432                        raw.scale,
433                    );
434                }
435                if raw.precision > 39 {
436                    let reason = format!(
437                        "precision of {} is greater than our maximum of 39",
438                        raw.precision
439                    );
440                    return Err(UnsupportedDataType {
441                        column_name: raw.name.to_string(),
442                        column_type: raw.data_type.to_string(),
443                        reason,
444                    });
445                }
446
447                let raw_scale = usize::cast_from(raw.scale);
448                let max_scale =
449                    NumericMaxScale::try_from(raw_scale).map_err(|_| UnsupportedDataType {
450                        column_type: raw.data_type.to_string(),
451                        column_name: raw.name.to_string(),
452                        reason: format!("scale of {} is too large", raw.scale),
453                    })?;
454                let column_type = SqlScalarType::Numeric {
455                    max_scale: Some(max_scale),
456                };
457
458                (column_type, SqlServerColumnDecodeType::Numeric)
459            }
460            // SQL Server has a few IEEE 754 floating point type names. The underlying type is float(n),
461            // where n is the number of bits used. SQL Server still ends up with only 2 distinct types
462            // as it treats 1 <= n <= 24 as n=24, and 25 <= n <= 53 as n=53.
463            //
464            // Additionally, `real` and `double precision` exist as synonyms of float(24) and float(53),
465            // respectively.  What doesn't appear to be documented is how these appear in `sys.types`.
466            // See <https://learn.microsoft.com/en-us/sql/t-sql/data-types/float-and-real-transact-sql?view=sql-server-ver17>
467            "real" | "float" | "double precision" => match raw.max_length {
468                // Decide the MZ type based on the number of bytes rather than the name, just in case
469                // there is inconsistency among versions.
470                4 => (SqlScalarType::Float32, SqlServerColumnDecodeType::F32),
471                8 => (SqlScalarType::Float64, SqlServerColumnDecodeType::F64),
472                _ => {
473                    return Err(UnsupportedDataType {
474                        column_name: raw.name.to_string(),
475                        column_type: raw.data_type.to_string(),
476                        reason: format!("unsupported length {}", raw.max_length),
477                    });
478                }
479            },
480            dt @ ("char" | "nchar" | "sysname") => {
481                // There isn't a char(max) or nchar(max), so it isn't clear if this condition
482                // is possible.
483                if raw.max_length == -1 {
484                    return Err(UnsupportedDataType {
485                        column_name: raw.name.to_string(),
486                        column_type: raw.data_type.to_string(),
487                        reason: "columns with unlimited size do not support CDC".to_string(),
488                    });
489                }
490
491                let column_type = match dt {
492                    "char" => {
493                        let length =
494                            if raw.max_length != -1 {
495                                let length = CharLength::try_from(i64::from(raw.max_length))
496                                    .map_err(|e| UnsupportedDataType {
497                                        column_name: raw.name.to_string(),
498                                        column_type: raw.data_type.to_string(),
499                                        reason: e.to_string(),
500                                    })?;
501                                Some(length)
502                            } else {
503                                None
504                            };
505                        SqlScalarType::Char { length }
506                    }
507                    // Determining the max character count for these types is difficult
508                    // because of different character encodings, so we fallback to just
509                    // representing them as "text".
510                    "nchar" | "sysname" => SqlScalarType::String,
511                    other => unreachable!("'{other}' checked above"),
512                };
513
514                (column_type, SqlServerColumnDecodeType::String)
515            }
516            "varchar" | "nvarchar" => {
517                // `max text repl size` is 64KB by default.  If a user attempts to insert a value
518                // that exceeds this limit, SQL Server will return an error and the insert fails
519                // with error `7139`.  This is also true for updates that increase the field length
520                // beyond the limit.
521                //
522                // See <https://learn.microsoft.com/en-us/sql/relational-databases/errors-events/database-engine-events-and-errors-7000-to-7999?view=sql-server-ver17>
523                //
524                // If the `max text repl size` changes, it does not affect events already written to
525                // the CDC table, nor does it change the behavior of what CDC captures for updates
526                // to non-LOD columns (based on testing).
527                let max_length =
528                    if raw.max_length != -1 {
529                        let length = VarCharMaxLength::try_from(i64::from(raw.max_length))
530                            .map_err(|e| UnsupportedDataType {
531                                column_name: raw.name.to_string(),
532                                column_type: raw.data_type.to_string(),
533                                reason: e.to_string(),
534                            })?;
535                        Some(length)
536                    } else {
537                        None
538                    };
539                let column_type = SqlScalarType::VarChar { max_length };
540                (column_type, SqlServerColumnDecodeType::String)
541            }
542            "text" | "ntext" | "image" => {
543                // SQL Server docs indicate this should always be 16. There's no
544                // issue if it's not, but it's good to track.
545                mz_ore::soft_assert_eq_no_log!(raw.max_length, 16);
546
547                // TODO(sql_server3): Support UPSERT semantics for SQL Server.
548                return Err(UnsupportedDataType {
549                    column_name: raw.name.to_string(),
550                    column_type: raw.data_type.to_string(),
551                    reason: "columns with unlimited size do not support CDC".to_string(),
552                });
553            }
554            "xml" => {
555                // When the `max_length` is -1 SQL Server will not present us with the "before" value
556                // for updated columns.
557                //
558                // TODO(sql_server3): Support UPSERT semantics for SQL Server.
559                if raw.max_length == -1 {
560                    return Err(UnsupportedDataType {
561                        column_name: raw.name.to_string(),
562                        column_type: raw.data_type.to_string(),
563                        reason: "columns with unlimited size do not support CDC".to_string(),
564                    });
565                }
566                (SqlScalarType::String, SqlServerColumnDecodeType::Xml)
567            }
568            "binary" | "varbinary" => {
569                // [`SqlScalarType`] does not support tracking max_length for binary data. To ensure
570                // columns of type varbinary(max) (Large Object Data) are decoded properly, it is
571                // necessary to know that the length is `max`. `varchar` and `nvarchar` track this
572                // using [`SqlScalarType::VarChar`] max_length field.
573                if raw.max_length == -1 {
574                    return Err(UnsupportedDataType {
575                        column_name: raw.name.to_string(),
576                        column_type: raw.data_type.to_string(),
577                        reason: "columns with unlimited size do not support CDC".to_string(),
578                    });
579                }
580                (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes)
581            }
582            "json" => (SqlScalarType::Jsonb, SqlServerColumnDecodeType::String),
583            "date" => (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate),
584            // SQL Server supports a scale of (and defaults to) 7 digits (aka 100 nanoseconds)
585            // for time related types.
586            //
587            // Internally Materialize supports a scale of 9 (aka nanoseconds), but for Postgres
588            // compatibility we constraint ourselves to a scale of 6 (aka microseconds). By
589            // default we will round values we get from  SQL Server to fit in Materialize.
590            //
591            // TODO(sql_server3): Support a "strict" mode where we're fail the creation of the
592            // source if the scale is too large.
593            // TODO(sql_server3): Support specifying a precision for SqlScalarType::Time.
594            //
595            // See: <https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetime2-transact-sql?view=sql-server-ver16>.
596            "time" => (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime),
597            dt @ ("smalldatetime" | "datetime" | "datetime2" | "datetimeoffset") => {
598                if raw.scale > 7 {
599                    tracing::warn!("unexpected scale '{}' from SQL Server", raw.scale);
600                }
601                if raw.scale > mz_repr::adt::timestamp::MAX_PRECISION {
602                    tracing::warn!("truncating scale of '{}' for '{}'", raw.scale, dt);
603                }
604                let precision = std::cmp::min(raw.scale, mz_repr::adt::timestamp::MAX_PRECISION);
605                let precision =
606                    Some(TimestampPrecision::try_from(i64::from(precision)).expect("known to fit"));
607
608                match dt {
609                    "smalldatetime" | "datetime" | "datetime2" => (
610                        SqlScalarType::Timestamp { precision },
611                        SqlServerColumnDecodeType::NaiveDateTime,
612                    ),
613                    "datetimeoffset" => (
614                        SqlScalarType::TimestampTz { precision },
615                        SqlServerColumnDecodeType::DateTime,
616                    ),
617                    other => unreachable!("'{other}' checked above"),
618                }
619            }
620            "uniqueidentifier" => (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid),
621            // TODO(sql_server3): Support reading the following types, at least as text:
622            //
623            // * geography
624            // * geometry
625            // * json (preview)
626            // * vector (preview)
627            //
628            // None of these types are implemented in `tiberius`, the crate that
629            // provides our SQL Server client, so we'll need to implement support
630            // for decoding them.
631            //
632            // See <https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/355f7890-6e91-4978-ab76-2ded17ee09bc>.
633            other => {
634                return Err(UnsupportedDataType {
635                    column_type: other.to_string(),
636                    column_name: raw.name.to_string(),
637                    reason: format!("'{other}' is unimplemented"),
638                });
639            }
640        };
641    Ok(scalar)
642}
643
/// Raw metadata for a column from a table in Microsoft SQL Server.
///
/// Converted into a [`SqlServerColumnDesc`] by [`SqlServerColumnDesc::new`].
///
/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-catalog-views/sys-columns-transact-sql?view=sql-server-ver16>.
#[derive(Clone, Debug)]
pub struct SqlServerColumnRaw {
    /// Name of this column.
    pub name: Arc<str>,
    /// Name of the data type, e.g. `nvarchar`. It is lowercased before being matched
    /// against the supported types.
    pub data_type: Arc<str>,
    /// Whether or not the column is nullable.
    pub is_nullable: bool,
    /// Maximum length (in bytes) of the column.
    ///
    /// For `varchar(max)`, `nvarchar(max)`, `varbinary(max)`, or `xml` this will be `-1`. For
    /// `text`, `ntext`, and `image` columns this will be 16.
    ///
    /// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-catalog-views/sys-columns-transact-sql?view=sql-server-ver16>.
    ///
    /// TODO(sql_server2): Validate this value for `json` columns, which were introduced in
    /// Azure SQL in 2024.
    pub max_length: i16,
    /// Precision of the column, if numeric-based; otherwise 0.
    pub precision: u8,
    /// Scale of the column, if numeric-based; otherwise 0.
    ///
    /// Note: this value is also read as the fractional-seconds precision of the
    /// `smalldatetime`/`datetime`/`datetime2`/`datetimeoffset` types when mapping them
    /// into Materialize.
    pub scale: u8,
    /// Whether the column is computed. Computed columns are not supported.
    pub is_computed: bool,
}
672
/// Raw metadata for a table constraint, as read from upstream.
///
/// Converted into a [`SqlServerTableConstraint`] via `TryFrom`.
#[derive(Clone, Debug)]
pub struct SqlServerTableConstraintRaw {
    /// Name of the constraint.
    pub constraint_name: String,
    /// Constraint type string; only `"PRIMARY KEY"` and `"UNIQUE"` can be converted.
    pub constraint_type: String,
    /// Names of the columns that are part of the constraint.
    pub columns: Vec<String>,
}
680
/// Rust type that we should use when reading a column from SQL Server.
///
/// The comments on each variant list the upstream types that `parse_data_type` maps
/// onto it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
pub enum SqlServerColumnDecodeType {
    /// `bool`; used for `bit`.
    Bool,
    /// `u8`; used for `tinyint` (widened to a Materialize `int2`).
    U8,
    /// `i16`; used for `smallint`.
    I16,
    /// `i32`; used for `int`.
    I32,
    /// `i64`; used for `bigint`.
    I64,
    /// `f32`; used for 4-byte `real`/`float` columns.
    F32,
    /// `f64`; used for 8-byte `float`/`double precision` columns.
    F64,
    /// A string; used for `char`, `nchar`, `sysname`, `varchar`, `nvarchar` and `json`.
    String,
    /// Raw bytes; used for `binary` and `varbinary`.
    Bytes,
    /// [`uuid::Uuid`]; used for `uniqueidentifier`.
    Uuid,
    /// [`tiberius::numeric::Numeric`]; used for `decimal`, `numeric`, `money` and
    /// `smallmoney`.
    Numeric,
    /// [`tiberius::xml::XmlData`]; used for `xml`.
    Xml,
    /// [`chrono::NaiveDate`]; used for `date`.
    NaiveDate,
    /// [`chrono::NaiveTime`]; used for `time`.
    NaiveTime,
    /// [`chrono::DateTime`]; used for `datetimeoffset`.
    DateTime,
    /// [`chrono::NaiveDateTime`]; used for `smalldatetime`, `datetime` and `datetime2`.
    NaiveDateTime,
    /// Decoding this type isn't supported.
    Unsupported {
        /// Any additional context as to why this type isn't supported.
        context: String,
    },
}
714
715impl SqlServerColumnDecodeType {
716    /// Decode the column with `name` out of the provided `data`.
717    pub fn decode<'a>(
718        &self,
719        data: &'a tiberius::Row,
720        name: &'a str,
721        column: &'a SqlColumnType,
722        arena: &'a RowArena,
723    ) -> Result<Datum<'a>, SqlServerDecodeError> {
724        let maybe_datum = match (&column.scalar_type, self) {
725            (SqlScalarType::Bool, SqlServerColumnDecodeType::Bool) => data
726                .try_get(name)
727                .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool"))?
728                .map(|val: bool| if val { Datum::True } else { Datum::False }),
729            (SqlScalarType::Int16, SqlServerColumnDecodeType::U8) => data
730                .try_get(name)
731                .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8"))?
732                .map(|val: u8| Datum::Int16(i16::cast_from(val))),
733            (SqlScalarType::Int16, SqlServerColumnDecodeType::I16) => data
734                .try_get(name)
735                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16"))?
736                .map(Datum::Int16),
737            (SqlScalarType::Int32, SqlServerColumnDecodeType::I32) => data
738                .try_get(name)
739                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32"))?
740                .map(Datum::Int32),
741            (SqlScalarType::Int64, SqlServerColumnDecodeType::I64) => data
742                .try_get(name)
743                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64"))?
744                .map(Datum::Int64),
745            (SqlScalarType::Float32, SqlServerColumnDecodeType::F32) => data
746                .try_get(name)
747                .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32"))?
748                .map(|val: f32| Datum::Float32(ordered_float::OrderedFloat(val))),
749            (SqlScalarType::Float64, SqlServerColumnDecodeType::F64) => data
750                .try_get(name)
751                .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64"))?
752                .map(|val: f64| Datum::Float64(ordered_float::OrderedFloat(val))),
753            (SqlScalarType::String, SqlServerColumnDecodeType::String) => data
754                .try_get(name)
755                .map_err(|_| SqlServerDecodeError::invalid_column(name, "string"))?
756                .map(Datum::String),
757            (SqlScalarType::Char { length }, SqlServerColumnDecodeType::String) => data
758                .try_get(name)
759                .map_err(|_| SqlServerDecodeError::invalid_column(name, "char"))?
760                .map(|val: &str| match length {
761                    Some(expected) => {
762                        let found_chars = val.chars().count();
763                        let expct_chars = usize::cast_from(expected.into_u32());
764                        if found_chars != expct_chars {
765                            Err(SqlServerDecodeError::invalid_char(
766                                name,
767                                expct_chars,
768                                found_chars,
769                            ))
770                        } else {
771                            Ok(Datum::String(val))
772                        }
773                    }
774                    None => Ok(Datum::String(val)),
775                })
776                .transpose()?,
777            (SqlScalarType::VarChar { max_length }, SqlServerColumnDecodeType::String) => data
778                .try_get(name)
779                .map_err(|_| SqlServerDecodeError::invalid_column(name, "varchar"))?
780                .map(|val: &str| match max_length {
781                    Some(max) => {
782                        let found_chars = val.chars().count();
783                        let max_chars = usize::cast_from(max.into_u32());
784                        if found_chars > max_chars {
785                            Err(SqlServerDecodeError::invalid_varchar(
786                                name,
787                                max_chars,
788                                found_chars,
789                            ))
790                        } else {
791                            Ok(Datum::String(val))
792                        }
793                    }
794                    None => Ok(Datum::String(val)),
795                })
796                .transpose()?,
797            (SqlScalarType::Bytes, SqlServerColumnDecodeType::Bytes) => data
798                .try_get(name)
799                .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes"))?
800                .map(Datum::Bytes),
801            (SqlScalarType::Uuid, SqlServerColumnDecodeType::Uuid) => data
802                .try_get(name)
803                .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid"))?
804                .map(Datum::Uuid),
805            (SqlScalarType::Numeric { .. }, SqlServerColumnDecodeType::Numeric) => data
806                .try_get(name)
807                .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric"))?
808                .map(|val: tiberius::numeric::Numeric| {
809                    let numeric = tiberius_numeric_to_mz_numeric(val);
810                    Datum::Numeric(OrderedDecimal(numeric))
811                }),
812            (SqlScalarType::String, SqlServerColumnDecodeType::Xml) => data
813                .try_get(name)
814                .map_err(|_| SqlServerDecodeError::invalid_column(name, "xml"))?
815                .map(|val: &tiberius::xml::XmlData| Datum::String(val.as_ref())),
816            (SqlScalarType::Date, SqlServerColumnDecodeType::NaiveDate) => data
817                .try_get(name)
818                .map_err(|_| SqlServerDecodeError::invalid_column(name, "date"))?
819                .map(|val: chrono::NaiveDate| {
820                    let date = val
821                        .try_into()
822                        .map_err(|e| SqlServerDecodeError::invalid_date(name, e))?;
823                    Ok::<_, SqlServerDecodeError>(Datum::Date(date))
824                })
825                .transpose()?,
826            (SqlScalarType::Time, SqlServerColumnDecodeType::NaiveTime) => data
827                .try_get(name)
828                .map_err(|_| SqlServerDecodeError::invalid_column(name, "time"))?
829                .map(|val: chrono::NaiveTime| {
830                    // Postgres' maximum precision is 6 (aka microseconds).
831                    //
832                    // While the Postgres spec supports specifying a precision
833                    // Materialize does not.
834                    let rounded = val.round_subsecs(6);
835                    // Overflowed.
836                    let val = if rounded < val {
837                        val.trunc_subsecs(6)
838                    } else {
839                        val
840                    };
841                    Datum::Time(val)
842                }),
843            (SqlScalarType::Timestamp { precision }, SqlServerColumnDecodeType::NaiveDateTime) => {
844                data.try_get(name)
845                    .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamp"))?
846                    .map(|val: chrono::NaiveDateTime| {
847                        let ts: CheckedTimestamp<chrono::NaiveDateTime> = val
848                            .try_into()
849                            .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
850                        let rounded = ts
851                            .round_to_precision(*precision)
852                            .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
853                        Ok::<_, SqlServerDecodeError>(Datum::Timestamp(rounded))
854                    })
855                    .transpose()?
856            }
857            (SqlScalarType::TimestampTz { precision }, SqlServerColumnDecodeType::DateTime) => data
858                .try_get(name)
859                .map_err(|_| SqlServerDecodeError::invalid_column(name, "timestamptz"))?
860                .map(|val: chrono::DateTime<chrono::Utc>| {
861                    let ts: CheckedTimestamp<chrono::DateTime<chrono::Utc>> = val
862                        .try_into()
863                        .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
864                    let rounded = ts
865                        .round_to_precision(*precision)
866                        .map_err(|e| SqlServerDecodeError::invalid_timestamp(name, e))?;
867                    Ok::<_, SqlServerDecodeError>(Datum::TimestampTz(rounded))
868                })
869                .transpose()?,
870            // We support mapping any type to a string.
871            (SqlScalarType::String, SqlServerColumnDecodeType::Bool) => data
872                .try_get(name)
873                .map_err(|_| SqlServerDecodeError::invalid_column(name, "bool-text"))?
874                .map(|val: bool| {
875                    if val {
876                        Datum::String("true")
877                    } else {
878                        Datum::String("false")
879                    }
880                }),
881            (SqlScalarType::String, SqlServerColumnDecodeType::U8) => data
882                .try_get(name)
883                .map_err(|_| SqlServerDecodeError::invalid_column(name, "u8-text"))?
884                .map(|val: u8| {
885                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
886                }),
887            (SqlScalarType::String, SqlServerColumnDecodeType::I16) => data
888                .try_get(name)
889                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i16-text"))?
890                .map(|val: i16| {
891                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
892                }),
893            (SqlScalarType::String, SqlServerColumnDecodeType::I32) => data
894                .try_get(name)
895                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i32-text"))?
896                .map(|val: i32| {
897                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
898                }),
899            (SqlScalarType::String, SqlServerColumnDecodeType::I64) => data
900                .try_get(name)
901                .map_err(|_| SqlServerDecodeError::invalid_column(name, "i64-text"))?
902                .map(|val: i64| {
903                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
904                }),
905            (SqlScalarType::String, SqlServerColumnDecodeType::F32) => data
906                .try_get(name)
907                .map_err(|_| SqlServerDecodeError::invalid_column(name, "f32-text"))?
908                .map(|val: f32| {
909                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
910                }),
911            (SqlScalarType::String, SqlServerColumnDecodeType::F64) => data
912                .try_get(name)
913                .map_err(|_| SqlServerDecodeError::invalid_column(name, "f64-text"))?
914                .map(|val: f64| {
915                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
916                }),
917            (SqlScalarType::String, SqlServerColumnDecodeType::Uuid) => data
918                .try_get(name)
919                .map_err(|_| SqlServerDecodeError::invalid_column(name, "uuid-text"))?
920                .map(|val: uuid::Uuid| {
921                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
922                }),
923            (SqlScalarType::String, SqlServerColumnDecodeType::Bytes) => data
924                .try_get(name)
925                .map_err(|_| SqlServerDecodeError::invalid_column(name, "bytes-text"))?
926                .map(|val: &[u8]| {
927                    let encoded = base64::engine::general_purpose::STANDARD.encode(val);
928                    arena.make_datum(|packer| packer.push(Datum::String(&encoded)))
929                }),
930            (SqlScalarType::String, SqlServerColumnDecodeType::Numeric) => data
931                .try_get(name)
932                .map_err(|_| SqlServerDecodeError::invalid_column(name, "numeric-text"))?
933                .map(|val: tiberius::numeric::Numeric| {
934                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
935                }),
936            (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDate) => data
937                .try_get(name)
938                .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedate-text"))?
939                .map(|val: chrono::NaiveDate| {
940                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
941                }),
942            (SqlScalarType::String, SqlServerColumnDecodeType::NaiveTime) => data
943                .try_get(name)
944                .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivetime-text"))?
945                .map(|val: chrono::NaiveTime| {
946                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
947                }),
948            (SqlScalarType::String, SqlServerColumnDecodeType::DateTime) => data
949                .try_get(name)
950                .map_err(|_| SqlServerDecodeError::invalid_column(name, "datetime-text"))?
951                .map(|val: chrono::DateTime<chrono::Utc>| {
952                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
953                }),
954            (SqlScalarType::String, SqlServerColumnDecodeType::NaiveDateTime) => data
955                .try_get(name)
956                .map_err(|_| SqlServerDecodeError::invalid_column(name, "naivedatetime-text"))?
957                .map(|val: chrono::NaiveDateTime| {
958                    arena.make_datum(|packer| packer.push(Datum::String(&val.to_string())))
959                }),
960            (column_type, decode_type) => {
961                return Err(SqlServerDecodeError::Unsupported {
962                    sql_server_type: decode_type.clone(),
963                    mz_type: column_type.clone(),
964                });
965            }
966        };
967
968        match (maybe_datum, column.nullable) {
969            (Some(datum), _) => Ok(datum),
970            (None, true) => Ok(Datum::Null),
971            (None, false) => Err(SqlServerDecodeError::InvalidData {
972                column_name: name.to_string(),
973                // Note: This error string is durably recorded in Persist, do not change.
974                error: "found Null in non-nullable column".to_string(),
975            }),
976        }
977    }
978}
979
980impl RustType<proto_sql_server_column_desc::DecodeType> for SqlServerColumnDecodeType {
981    fn into_proto(&self) -> proto_sql_server_column_desc::DecodeType {
982        match self {
983            SqlServerColumnDecodeType::Bool => proto_sql_server_column_desc::DecodeType::Bool(()),
984            SqlServerColumnDecodeType::U8 => proto_sql_server_column_desc::DecodeType::U8(()),
985            SqlServerColumnDecodeType::I16 => proto_sql_server_column_desc::DecodeType::I16(()),
986            SqlServerColumnDecodeType::I32 => proto_sql_server_column_desc::DecodeType::I32(()),
987            SqlServerColumnDecodeType::I64 => proto_sql_server_column_desc::DecodeType::I64(()),
988            SqlServerColumnDecodeType::F32 => proto_sql_server_column_desc::DecodeType::F32(()),
989            SqlServerColumnDecodeType::F64 => proto_sql_server_column_desc::DecodeType::F64(()),
990            SqlServerColumnDecodeType::String => {
991                proto_sql_server_column_desc::DecodeType::String(())
992            }
993            SqlServerColumnDecodeType::Bytes => proto_sql_server_column_desc::DecodeType::Bytes(()),
994            SqlServerColumnDecodeType::Uuid => proto_sql_server_column_desc::DecodeType::Uuid(()),
995            SqlServerColumnDecodeType::Numeric => {
996                proto_sql_server_column_desc::DecodeType::Numeric(())
997            }
998            SqlServerColumnDecodeType::Xml => proto_sql_server_column_desc::DecodeType::Xml(()),
999            SqlServerColumnDecodeType::NaiveDate => {
1000                proto_sql_server_column_desc::DecodeType::NaiveDate(())
1001            }
1002            SqlServerColumnDecodeType::NaiveTime => {
1003                proto_sql_server_column_desc::DecodeType::NaiveTime(())
1004            }
1005            SqlServerColumnDecodeType::DateTime => {
1006                proto_sql_server_column_desc::DecodeType::DateTime(())
1007            }
1008            SqlServerColumnDecodeType::NaiveDateTime => {
1009                proto_sql_server_column_desc::DecodeType::NaiveDateTime(())
1010            }
1011            SqlServerColumnDecodeType::Unsupported { context } => {
1012                proto_sql_server_column_desc::DecodeType::Unsupported(context.clone())
1013            }
1014        }
1015    }
1016
1017    fn from_proto(
1018        proto: proto_sql_server_column_desc::DecodeType,
1019    ) -> Result<Self, mz_proto::TryFromProtoError> {
1020        let val = match proto {
1021            proto_sql_server_column_desc::DecodeType::Bool(()) => SqlServerColumnDecodeType::Bool,
1022            proto_sql_server_column_desc::DecodeType::U8(()) => SqlServerColumnDecodeType::U8,
1023            proto_sql_server_column_desc::DecodeType::I16(()) => SqlServerColumnDecodeType::I16,
1024            proto_sql_server_column_desc::DecodeType::I32(()) => SqlServerColumnDecodeType::I32,
1025            proto_sql_server_column_desc::DecodeType::I64(()) => SqlServerColumnDecodeType::I64,
1026            proto_sql_server_column_desc::DecodeType::F32(()) => SqlServerColumnDecodeType::F32,
1027            proto_sql_server_column_desc::DecodeType::F64(()) => SqlServerColumnDecodeType::F64,
1028            proto_sql_server_column_desc::DecodeType::String(()) => {
1029                SqlServerColumnDecodeType::String
1030            }
1031            proto_sql_server_column_desc::DecodeType::Bytes(()) => SqlServerColumnDecodeType::Bytes,
1032            proto_sql_server_column_desc::DecodeType::Uuid(()) => SqlServerColumnDecodeType::Uuid,
1033            proto_sql_server_column_desc::DecodeType::Numeric(()) => {
1034                SqlServerColumnDecodeType::Numeric
1035            }
1036            proto_sql_server_column_desc::DecodeType::Xml(()) => SqlServerColumnDecodeType::Xml,
1037            proto_sql_server_column_desc::DecodeType::NaiveDate(()) => {
1038                SqlServerColumnDecodeType::NaiveDate
1039            }
1040            proto_sql_server_column_desc::DecodeType::NaiveTime(()) => {
1041                SqlServerColumnDecodeType::NaiveTime
1042            }
1043            proto_sql_server_column_desc::DecodeType::DateTime(()) => {
1044                SqlServerColumnDecodeType::DateTime
1045            }
1046            proto_sql_server_column_desc::DecodeType::NaiveDateTime(()) => {
1047                SqlServerColumnDecodeType::NaiveDateTime
1048            }
1049            proto_sql_server_column_desc::DecodeType::Unsupported(context) => {
1050                SqlServerColumnDecodeType::Unsupported { context }
1051            }
1052        };
1053        Ok(val)
1054    }
1055}
1056
1057/// Numerics in SQL Server have a maximum precision of 38 digits, where [`Numeric`]s in
1058/// Materialize have a maximum precision of 39 digits, so this conversion is infallible.
1059fn tiberius_numeric_to_mz_numeric(val: tiberius::numeric::Numeric) -> Numeric {
1060    let mut numeric = mz_repr::adt::numeric::cx_datum().from_i128(val.value());
1061    // Use scaleb to adjust the exponent directly, avoiding precision loss from division
1062    // scaleb(x, -n) computes x * 10^(-n)
1063    mz_repr::adt::numeric::cx_datum().scaleb(&mut numeric, &Numeric::from(-i32::from(val.scale())));
1064    numeric
1065}
1066
/// The update mask of a CDC event row, returned by `cdc.fn_cdc_get_all_changes_<capture_instance>`
/// as `__$update_mask`.
///
/// Each captured data column is represented by a single bit. The first data column is the
/// least significant bit of the *last* byte of the mask; see
/// [`UpdateMask::data_col_updated`] for the exact mapping.
///
/// See <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/cdc-fn-cdc-get-all-changes-capture-instance-transact-sql?view=sql-server-ver17>
#[derive(Debug)]
pub struct UpdateMask {
    /// Raw bytes of the `__$update_mask` column, exactly as read from the row.
    mask: Vec<u8>,
}
1075
1076impl TryFrom<&tiberius::Row> for UpdateMask {
1077    type Error = SqlServerDecodeError;
1078
1079    fn try_from(row: &tiberius::Row) -> Result<Self, Self::Error> {
1080        static UPDATE_MASK: &str = "__$update_mask";
1081
1082        let mask: Vec<u8> = row
1083            .try_get::<&[u8], _>(UPDATE_MASK)
1084            .inspect_err(|e| tracing::warn!("Failed extracting update mask: {e:?}"))
1085            .map_err(|_| SqlServerDecodeError::InvalidColumn {
1086                column_name: UPDATE_MASK.to_string(),
1087                as_type: "bytes",
1088            })?
1089            .ok_or_else(|| SqlServerDecodeError::InvalidData {
1090                column_name: UPDATE_MASK.to_string(),
1091                error: "column cannot be null".to_string(),
1092            })?
1093            .into();
1094        Ok(UpdateMask { mask })
1095    }
1096}
1097
1098impl UpdateMask {
1099    /// Returns true if the data column was updated, false otherwise.
1100    ///
1101    /// This function panics if `col_index` exceeds the mask.
1102    ///
1103    /// The [`tiberius::Row`] returned by `cdc.fn_cdc_get_all_changes_<capture_instance>` contains
1104    /// 4 metadata columns used by CDC:
1105    /// - `__$start_lsn`
1106    /// - `__$seqval`
1107    /// - `__$operation`
1108    /// - `__$update_mask`
1109    ///
1110    /// This function will always return false for the first 4 columns.
1111    pub fn data_col_updated(&self, col_index: usize) -> bool {
1112        const CDC_METADATA_COL_COUNT: usize = 4;
1113
1114        if col_index < CDC_METADATA_COL_COUNT {
1115            return false;
1116        }
1117        let adj_col_index = col_index - CDC_METADATA_COL_COUNT;
1118        let byte_offset = adj_col_index / usize::cast_from(u8::BITS);
1119        assert!(
1120            byte_offset < self.mask.len(),
1121            "byte_offset = {byte_offset} mask_len = {}",
1122            self.mask.len()
1123        );
1124        let bit_offset = adj_col_index % usize::cast_from(u8::BITS);
1125        (self.mask[self.mask.len() - byte_offset - 1] >> bit_offset) & 1 == 1
1126    }
1127}
1128
/// A decoder from [`tiberius::Row`] to [`mz_repr::Row`].
///
/// The goal of this type is to perform any expensive "downcasts" so in the hot
/// path of decoding rows we do the minimal amount of work.
#[derive(Debug)]
pub struct SqlServerRowDecoder {
    /// One entry per output column, in the order of the [`RelationDesc`] the decoder was
    /// built from. Each entry holds:
    /// - the upstream column name;
    /// - the Materialize column type to produce, taken from the [`SqlServerTableDesc`];
    /// - how to read the value out of a [`tiberius::Row`].
    decoders: Vec<(Arc<str>, SqlColumnType, SqlServerColumnDecodeType)>,
}
1137
1138impl SqlServerRowDecoder {
1139    /// Try to create a [`SqlServerRowDecoder`] that will decode [`tiberius::Row`]s that match
1140    /// the shape of the provided [`SqlServerTableDesc`], to [`mz_repr::Row`]s that match the
1141    /// shape of the provided [`RelationDesc`].
1142    pub fn try_new(
1143        table: &SqlServerTableDesc,
1144        desc: &RelationDesc,
1145    ) -> Result<Self, SqlServerError> {
1146        let decoders = desc
1147            .iter()
1148            .map(|(col_name, col_type)| {
1149                let sql_server_col = table
1150                    .columns
1151                    .iter()
1152                    .find(|col| col.name.as_ref() == col_name.as_str())
1153                    .ok_or_else(|| {
1154                        // TODO(sql_server2): Structured Error.
1155                        anyhow::anyhow!("no SQL Server column with name {col_name} found")
1156                    })?;
1157                let Some(sql_server_col_typ) = sql_server_col.column_type.as_ref() else {
1158                    return Err(SqlServerError::ProgrammingError(format!(
1159                        "programming error, {col_name} should have been exluded",
1160                    )));
1161                };
1162
1163                // This shouldn't be true, but be defensive.
1164                //
1165                // TODO(sql_server2): Maybe allow the Materialize column type to be
1166                // more nullable than our decoding type?
1167                //
1168                // Sad. Our timestamp types don't roundtrip their precision through
1169                // parsing so we ignore the mismatch here.
1170                let matches = match (&sql_server_col_typ.scalar_type, &col_type.scalar_type) {
1171                    (SqlScalarType::Timestamp { .. }, SqlScalarType::Timestamp { .. })
1172                    | (SqlScalarType::TimestampTz { .. }, SqlScalarType::TimestampTz { .. }) => {
1173                        // Types match so check nullability.
1174                        sql_server_col_typ.nullable == col_type.nullable
1175                    }
1176                    (_, _) => sql_server_col_typ == col_type,
1177                };
1178                if !matches {
1179                    return Err(SqlServerError::ProgrammingError(format!(
1180                        "programming error, {col_name} has mismatched type {:?} vs {:?}",
1181                        sql_server_col.column_type, col_type
1182                    )));
1183                }
1184
1185                let name = Arc::clone(&sql_server_col.name);
1186                let decoder = sql_server_col.decode_type.clone();
1187                // Note: We specifically use the `SqlColumnType` from the SqlServerTableDesc
1188                // because it retains precision.
1189                //
1190                // See: <https://github.com/MaterializeInc/database-issues/issues/3179>.
1191                let col_typ = sql_server_col_typ.clone();
1192
1193                Ok::<_, SqlServerError>((name, col_typ, decoder))
1194            })
1195            .collect::<Result<_, _>>()?;
1196
1197        Ok(SqlServerRowDecoder { decoders })
1198    }
1199
1200    /// Decode data from the provided [`tiberius::Row`] into the provided [`Row`].
1201    ///
1202    /// For updates, the new row data is provided in the event the data contains Large Object Data
1203    /// (e.g. varchar(max)). [`SqlServerRowDecoder::decode()`] will decode the [`UpdateMask`] from
1204    /// the new row and retrieve LOD values from the new row for any LOD column that was not updated
1205    /// in the old row.
1206    pub fn decode(
1207        &self,
1208        data: &tiberius::Row,
1209        row: &mut Row,
1210        arena: &RowArena,
1211        new_data: Option<&tiberius::Row>,
1212    ) -> Result<(), SqlServerDecodeError> {
1213        let mut packer = row.packer();
1214
1215        for (col_name, col_type, decoder) in &self.decoders {
1216            let datum = decoder.decode(data, col_name, col_type, arena)?;
1217
1218            let datum = if let Some(new_data) = new_data
1219                && matches!(
1220                    col_type.scalar_type,
1221                    SqlScalarType::VarChar { max_length: None }
1222                )
1223                && matches!(datum, Datum::Null)
1224            {
1225                let update_mask = UpdateMask::try_from(new_data)?;
1226                let col_index = new_data
1227                    .columns()
1228                    .iter()
1229                    .position(|c| c.name() == col_name.as_ref())
1230                    .expect("column exists");
1231                // The only time it is valid to pull the LOD column value from the new row
1232                // is if the LOD column was *not* updated. The mask check is necessary to
1233                // differentiate between updating a non-LOD column (the LOD column in old row
1234                // is NULL) and updating a LOD column where the old value is NULL.
1235                if !update_mask.data_col_updated(col_index) {
1236                    decoder.decode(new_data, col_name, col_type, arena)?
1237                } else {
1238                    datum
1239                }
1240            } else {
1241                datum
1242            };
1243
1244            packer.push(datum);
1245        }
1246        Ok(())
1247    }
1248
1249    pub fn included_column_names(&self) -> Vec<Arc<str>> {
1250        self.decoders
1251            .iter()
1252            .map(|decoder| Arc::clone(&decoder.0))
1253            .collect()
1254    }
1255}
1256
1257#[cfg(test)]
1258mod tests {
1259    use std::collections::BTreeSet;
1260    use std::sync::Arc;
1261
1262    use chrono::NaiveDateTime;
1263    use itertools::Itertools;
1264    use mz_ore::assert_contains;
1265    use mz_ore::collections::CollectionExt;
1266    use mz_repr::adt::numeric::NumericMaxScale;
1267    use mz_repr::adt::varchar::VarCharMaxLength;
1268    use mz_repr::{Datum, RelationDesc, Row, RowArena, SqlScalarType};
1269    use tiberius::RowTestExt;
1270
1271    use crate::desc::{
1272        SqlServerCaptureInstanceRaw, SqlServerColumnDecodeType, SqlServerColumnDesc,
1273        SqlServerTableDesc, SqlServerTableRaw, tiberius_numeric_to_mz_numeric,
1274    };
1275
1276    use super::SqlServerColumnRaw;
1277
1278    impl SqlServerColumnRaw {
1279        /// Create a new [`SqlServerColumnRaw`]. The specified `data_type` is
1280        /// _not_ checked for validity.
1281        fn new(name: &str, data_type: &str) -> Self {
1282            SqlServerColumnRaw {
1283                name: name.into(),
1284                data_type: data_type.into(),
1285                is_nullable: false,
1286                max_length: 0,
1287                precision: 0,
1288                scale: 0,
1289                is_computed: false,
1290            }
1291        }
1292
1293        fn nullable(mut self, nullable: bool) -> Self {
1294            self.is_nullable = nullable;
1295            self
1296        }
1297
1298        fn max_length(mut self, max_length: i16) -> Self {
1299            self.max_length = max_length;
1300            self
1301        }
1302
1303        fn precision(mut self, precision: u8) -> Self {
1304            self.precision = precision;
1305            self
1306        }
1307
1308        fn scale(mut self, scale: u8) -> Self {
1309            self.scale = scale;
1310            self
1311        }
1312    }
1313
1314    #[mz_ore::test]
1315    fn smoketest_column_raw() {
1316        let raw = SqlServerColumnRaw::new("foo", "bit");
1317        let col = SqlServerColumnDesc::new(&raw);
1318
1319        assert_eq!(&*col.name, "foo");
1320        assert_eq!(col.column_type, Some(SqlScalarType::Bool.nullable(false)));
1321        assert_eq!(col.decode_type, SqlServerColumnDecodeType::Bool);
1322
1323        let raw = SqlServerColumnRaw::new("foo", "decimal")
1324            .precision(20)
1325            .scale(10);
1326        let col = SqlServerColumnDesc::new(&raw);
1327
1328        let col_type = SqlScalarType::Numeric {
1329            max_scale: Some(NumericMaxScale::try_from(10i64).expect("known valid")),
1330        }
1331        .nullable(false);
1332        assert_eq!(col.column_type, Some(col_type));
1333        assert_eq!(col.decode_type, SqlServerColumnDecodeType::Numeric);
1334    }
1335
1336    #[mz_ore::test]
1337    fn smoketest_column_raw_invalid() {
1338        let raw = SqlServerColumnRaw::new("foo", "bad_data_type");
1339        let desc = SqlServerColumnDesc::new(&raw);
1340        let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1341            panic!("unexpected decode type {desc:?}");
1342        };
1343        assert_contains!(context, "'bad_data_type' is unimplemented");
1344
1345        let raw = SqlServerColumnRaw::new("foo", "decimal")
1346            .precision(100)
1347            .scale(10);
1348        let desc = SqlServerColumnDesc::new(&raw);
1349        assert!(matches!(
1350            desc.decode_type,
1351            SqlServerColumnDecodeType::Unsupported { .. }
1352        ));
1353
1354        let raw = SqlServerColumnRaw::new("foo", "varbinary").max_length(-1);
1355        let desc = SqlServerColumnDesc::new(&raw);
1356        let SqlServerColumnDecodeType::Unsupported { context } = desc.decode_type else {
1357            panic!("unexpected decode type {desc:?}");
1358        };
1359        assert_contains!(context, "columns with unlimited size do not support CDC");
1360    }
1361
1362    #[mz_ore::test]
1363    fn smoketest_decoder() {
1364        let sql_server_columns = [
1365            SqlServerColumnRaw::new("a", "varchar").max_length(16),
1366            SqlServerColumnRaw::new("b", "int").nullable(true),
1367            SqlServerColumnRaw::new("c", "bit"),
1368        ];
1369        let sql_server_desc = SqlServerTableRaw {
1370            schema_name: "my_schema".into(),
1371            name: "my_table".into(),
1372            capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1373                name: "my_table_CT".into(),
1374                create_date: NaiveDateTime::parse_from_str(
1375                    "2024-01-01 00:00:00",
1376                    "%Y-%m-%d %H:%M:%S",
1377                )
1378                .unwrap()
1379                .into(),
1380            }),
1381            columns: sql_server_columns.into(),
1382        };
1383        let sql_server_desc = SqlServerTableDesc::new(sql_server_desc, vec![]).unwrap();
1384
1385        let max_length = Some(VarCharMaxLength::try_from(16).unwrap());
1386        let relation_desc = RelationDesc::builder()
1387            .with_column("a", SqlScalarType::VarChar { max_length }.nullable(false))
1388            // Note: In the upstream table 'c' is ordered after 'b'.
1389            .with_column("c", SqlScalarType::Bool.nullable(false))
1390            .with_column("b", SqlScalarType::Int32.nullable(true))
1391            .finish();
1392
1393        // This decoder should shape the SQL Server Rows into Rows compatible with the RelationDesc.
1394        let decoder = sql_server_desc
1395            .decoder(&relation_desc)
1396            .expect("known valid");
1397
1398        let sql_server_columns = [
1399            tiberius::Column::new("a".to_string(), tiberius::ColumnType::BigVarChar),
1400            tiberius::Column::new("b".to_string(), tiberius::ColumnType::Int4),
1401            tiberius::Column::new("c".to_string(), tiberius::ColumnType::Bit),
1402        ];
1403
1404        let data_a = [
1405            tiberius::ColumnData::String(Some("hello world".into())),
1406            tiberius::ColumnData::I32(Some(42)),
1407            tiberius::ColumnData::Bit(Some(true)),
1408        ];
1409        let sql_server_row_a = tiberius::Row::build(
1410            sql_server_columns
1411                .iter()
1412                .cloned()
1413                .zip_eq(data_a.into_iter()),
1414        );
1415
1416        let data_b = [
1417            tiberius::ColumnData::String(Some("foo bar".into())),
1418            tiberius::ColumnData::I32(None),
1419            tiberius::ColumnData::Bit(Some(false)),
1420        ];
1421        let sql_server_row_b =
1422            tiberius::Row::build(sql_server_columns.into_iter().zip_eq(data_b.into_iter()));
1423
1424        let mut rnd_row = Row::default();
1425        let arena = RowArena::default();
1426
1427        decoder
1428            .decode(&sql_server_row_a, &mut rnd_row, &arena, None)
1429            .unwrap();
1430        assert_eq!(
1431            &rnd_row,
1432            &Row::pack_slice(&[Datum::String("hello world"), Datum::True, Datum::Int32(42)])
1433        );
1434
1435        decoder
1436            .decode(&sql_server_row_b, &mut rnd_row, &arena, None)
1437            .unwrap();
1438        assert_eq!(
1439            &rnd_row,
1440            &Row::pack_slice(&[Datum::String("foo bar"), Datum::False, Datum::Null])
1441        );
1442    }
1443
1444    #[mz_ore::test]
1445    fn smoketest_decode_to_string() {
1446        #[track_caller]
1447        fn testcase(
1448            data_type: &'static str,
1449            col_type: tiberius::ColumnType,
1450            col_data: tiberius::ColumnData<'static>,
1451        ) {
1452            let columns = [SqlServerColumnRaw::new("a", data_type)];
1453            let sql_server_desc = SqlServerTableRaw {
1454                schema_name: "my_schema".into(),
1455                name: "my_table".into(),
1456                capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
1457                    name: "my_table_CT".into(),
1458                    create_date: NaiveDateTime::parse_from_str(
1459                        "2024-01-01 00:00:00",
1460                        "%Y-%m-%d %H:%M:%S",
1461                    )
1462                    .unwrap()
1463                    .into(),
1464                }),
1465                columns: columns.into(),
1466            };
1467            let mut sql_server_desc = SqlServerTableDesc::new(sql_server_desc, vec![]).unwrap();
1468            sql_server_desc.apply_text_columns(&BTreeSet::from(["a"]));
1469
1470            // We should support decoding every datatype to a string.
1471            let relation_desc = RelationDesc::builder()
1472                .with_column("a", SqlScalarType::String.nullable(false))
1473                .finish();
1474
1475            // This decoder should shape the SQL Server Rows into Rows compatible with the RelationDesc.
1476            let decoder = sql_server_desc
1477                .decoder(&relation_desc)
1478                .expect("known valid");
1479
1480            let sql_server_row = tiberius::Row::build([(
1481                tiberius::Column::new("a".to_string(), col_type),
1482                col_data,
1483            )]);
1484            let mut mz_row = Row::default();
1485            let arena = RowArena::new();
1486            decoder
1487                .decode(&sql_server_row, &mut mz_row, &arena, None)
1488                .unwrap();
1489
1490            let str_datum = mz_row.into_element();
1491            assert!(matches!(str_datum, Datum::String(_)));
1492        }
1493
1494        use tiberius::ColumnData;
1495
1496        testcase(
1497            "bit",
1498            tiberius::ColumnType::Bit,
1499            ColumnData::Bit(Some(true)),
1500        );
1501        testcase(
1502            "bit",
1503            tiberius::ColumnType::Bit,
1504            ColumnData::Bit(Some(false)),
1505        );
1506        testcase(
1507            "tinyint",
1508            tiberius::ColumnType::Int1,
1509            ColumnData::U8(Some(33)),
1510        );
1511        testcase(
1512            "smallint",
1513            tiberius::ColumnType::Int2,
1514            ColumnData::I16(Some(101)),
1515        );
1516        testcase(
1517            "int",
1518            tiberius::ColumnType::Int4,
1519            ColumnData::I32(Some(-42)),
1520        );
1521        {
1522            let datetime = tiberius::time::DateTime::new(10, 300);
1523            testcase(
1524                "datetime",
1525                tiberius::ColumnType::Datetime,
1526                ColumnData::DateTime(Some(datetime)),
1527            );
1528        }
1529    }
1530
1531    #[mz_ore::test]
1532    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `decNumberFromInt32` on OS `linux`
1533    fn smoketest_numeric_conversion() {
1534        let a = tiberius::numeric::Numeric::new_with_scale(12345, 2);
1535        let rnd = tiberius_numeric_to_mz_numeric(a);
1536        let og = mz_repr::adt::numeric::cx_datum().parse("123.45").unwrap();
1537        assert_eq!(og, rnd);
1538
1539        let a = tiberius::numeric::Numeric::new_with_scale(-99999, 5);
1540        let rnd = tiberius_numeric_to_mz_numeric(a);
1541        let og = mz_repr::adt::numeric::cx_datum().parse("-.99999").unwrap();
1542        assert_eq!(og, rnd);
1543
1544        let a = tiberius::numeric::Numeric::new_with_scale(1, 29);
1545        let rnd = tiberius_numeric_to_mz_numeric(a);
1546        let og = mz_repr::adt::numeric::cx_datum()
1547            .parse("0.00000000000000000000000000001")
1548            .unwrap();
1549        assert_eq!(og, rnd);
1550
1551        let a = tiberius::numeric::Numeric::new_with_scale(-111111111111111111, 0);
1552        let rnd = tiberius_numeric_to_mz_numeric(a);
1553        let og = mz_repr::adt::numeric::cx_datum()
1554            .parse("-111111111111111111")
1555            .unwrap();
1556        assert_eq!(og, rnd);
1557    }
1558
1559    // TODO(sql_server2): Proptest the decoder.
1560}