mz_sql_server_util/
desc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Metadata about tables, columns, and other objects from SQL Server.
11//!
12//! ### Tables
13//!
14//! When creating a SQL Server source we will query system tables from the
15//! upstream instance to get a [`SqlServerTableRaw`]. From this raw information
16//! we create a [`SqlServerTableDesc`] which describes how the upstream table
17//! will get represented in Materialize.
18//!
19//! ### Rows
20//!
21//! With a [`SqlServerTableDesc`] and an [`mz_repr::RelationDesc`] we can
22//! create a [`SqlServerRowDecoder`] which will be used when running a source
23//! to efficiently decode [`tiberius::Row`]s into [`mz_repr::Row`]s.
24
25use anyhow::Context;
26use dec::OrderedDecimal;
27use mz_ore::cast::CastFrom;
28use mz_proto::{IntoRustIfSome, ProtoType, RustType};
29use mz_repr::adt::numeric::{Dec, Numeric, NumericMaxScale};
30use mz_repr::{ColumnType, Datum, RelationDesc, Row, ScalarType};
31use proptest_derive::Arbitrary;
32use serde::{Deserialize, Serialize};
33
34use std::sync::Arc;
35
36use crate::SqlServerError;
37
38include!(concat!(env!("OUT_DIR"), "/mz_sql_server_util.rs"));
39
/// Materialize compatible description of a table in Microsoft SQL Server.
///
/// Instances are built from raw upstream metadata with [`SqlServerTableDesc::try_new`]
/// and can be round-tripped through the `ProtoSqlServerTableDesc` protobuf message.
///
/// See [`SqlServerTableRaw`] for the raw information we read from the upstream
/// system.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerTableDesc {
    /// Name of the schema that the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// Columns for the table.
    pub columns: Arc<[SqlServerColumnDesc]>,
    /// Is CDC enabled for this table, required to replicate into Materialize.
    pub is_cdc_enabled: bool,
}
55
56impl SqlServerTableDesc {
57    /// Try creating a [`SqlServerTableDesc`] from a [`SqlServerTableRaw`] description.
58    ///
59    /// Returns an error if the raw table description is not compatible with Materialize.
60    pub fn try_new(raw: SqlServerTableRaw) -> Result<Self, SqlServerError> {
61        let columns: Arc<[_]> = raw
62            .columns
63            .into_iter()
64            .map(SqlServerColumnDesc::try_new)
65            .collect::<Result<_, _>>()?;
66        Ok(SqlServerTableDesc {
67            schema_name: raw.schema_name,
68            name: raw.name,
69            columns,
70            is_cdc_enabled: raw.is_cdc_enabled,
71        })
72    }
73
74    /// Returns a [`SqlServerRowDecoder`] which can be used to decode [`tiberius::Row`]s into
75    /// [`mz_repr::Row`]s that match the shape of the provided [`RelationDesc`].
76    pub fn decoder(&self, desc: &RelationDesc) -> Result<SqlServerRowDecoder, SqlServerError> {
77        let decoder = SqlServerRowDecoder::try_new(self, desc)?;
78        Ok(decoder)
79    }
80}
81
82impl RustType<ProtoSqlServerTableDesc> for SqlServerTableDesc {
83    fn into_proto(&self) -> ProtoSqlServerTableDesc {
84        ProtoSqlServerTableDesc {
85            name: self.name.to_string(),
86            schema_name: self.schema_name.to_string(),
87            columns: self.columns.iter().map(|c| c.into_proto()).collect(),
88            is_cdc_enabled: self.is_cdc_enabled,
89        }
90    }
91
92    fn from_proto(proto: ProtoSqlServerTableDesc) -> Result<Self, mz_proto::TryFromProtoError> {
93        let columns = proto
94            .columns
95            .into_iter()
96            .map(|c| c.into_rust())
97            .collect::<Result<_, _>>()?;
98        Ok(SqlServerTableDesc {
99            schema_name: proto.schema_name.into(),
100            name: proto.name.into(),
101            columns,
102            is_cdc_enabled: proto.is_cdc_enabled,
103        })
104    }
105}
106
/// Raw metadata for a table from Microsoft SQL Server.
///
/// This is the unvalidated input to [`SqlServerTableDesc::try_new`].
///
/// See [`SqlServerTableDesc`] for a refined description that is compatible
/// with Materialize.
#[derive(Debug, Clone)]
pub struct SqlServerTableRaw {
    /// Name of the schema the table belongs to.
    pub schema_name: Arc<str>,
    /// Name of the table.
    pub name: Arc<str>,
    /// Columns for the table.
    pub columns: Arc<[SqlServerColumnRaw]>,
    /// Whether or not CDC is enabled for this table.
    pub is_cdc_enabled: bool,
}
122
/// Description of a column from a table in Microsoft SQL Server.
///
/// Pairs the Materialize type a column is mapped to with the Rust type used to
/// read its values out of a [`tiberius::Row`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub struct SqlServerColumnDesc {
    /// Name of the column.
    pub name: Arc<str>,
    /// The intended data type of this column in Materialize.
    pub column_type: ColumnType,
    /// Rust type we should parse the data from a [`tiberius::Row`] as.
    pub decode_type: SqlServerColumnDecodeType,
}
133
134impl SqlServerColumnDesc {
135    /// Try creating a [`SqlServerColumnDesc`] from a [`SqlServerColumnRaw`] description.
136    ///
137    /// Returns an error if the upstream column is not compatible with Materialize, e.g. the
138    /// data type doesn't support CDC.
139    pub fn try_new(raw: &SqlServerColumnRaw) -> Result<Self, SqlServerError> {
140        let (scalar_type, decode_type) = parse_data_type(raw)?;
141        Ok(SqlServerColumnDesc {
142            name: Arc::clone(&raw.name),
143            column_type: scalar_type.nullable(raw.is_nullable),
144            decode_type,
145        })
146    }
147}
148
149impl RustType<ProtoSqlServerColumnDesc> for SqlServerColumnDesc {
150    fn into_proto(&self) -> ProtoSqlServerColumnDesc {
151        ProtoSqlServerColumnDesc {
152            name: self.name.to_string(),
153            column_type: Some(self.column_type.into_proto()),
154            decode_type: Some(self.decode_type.into_proto()),
155        }
156    }
157
158    fn from_proto(proto: ProtoSqlServerColumnDesc) -> Result<Self, mz_proto::TryFromProtoError> {
159        Ok(SqlServerColumnDesc {
160            name: proto.name.into(),
161            column_type: proto
162                .column_type
163                .into_rust_if_some("ProtoSqlServerColumnDesc::column_type")?,
164            decode_type: proto
165                .decode_type
166                .into_rust_if_some("ProtoSqlServerColumnDesc::decode_type")?,
167        })
168    }
169}
170
171/// Parse a raw data type from SQL Server into a Materialize [`ScalarType`].
172///
173/// Returns the [`ScalarType`] that we'll map this column to and the [`SqlServerColumnDecodeType`]
174/// that we use to decode the raw value.
175fn parse_data_type(
176    raw: &SqlServerColumnRaw,
177) -> Result<(ScalarType, SqlServerColumnDecodeType), SqlServerError> {
178    let scalar = match raw.data_type.to_lowercase().as_str() {
179        "tinyint" => (ScalarType::Int16, SqlServerColumnDecodeType::U8),
180        "smallint" => (ScalarType::Int16, SqlServerColumnDecodeType::I16),
181        "int" => (ScalarType::Int32, SqlServerColumnDecodeType::I32),
182        "bigint" => (ScalarType::Int64, SqlServerColumnDecodeType::I64),
183        "bit" => (ScalarType::Bool, SqlServerColumnDecodeType::Bool),
184        "decimal" | "numeric" => {
185            // SQL Server supports a precision in the range of [1, 38] and then
186            // the scale is 0 <= scale <= precision.
187            //
188            // Materialize numerics are floating point with a fixed precision of 39.
189            //
190            // See: <https://learn.microsoft.com/en-us/sql/t-sql/data-types/decimal-and-numeric-transact-sql?view=sql-server-ver16#arguments>
191            if raw.precision > 38 || raw.scale > raw.precision {
192                tracing::warn!(
193                    "unexpected value from SQL Server, precision of {} and scale of {}",
194                    raw.precision,
195                    raw.scale,
196                );
197            }
198            if raw.precision > 39 {
199                let reason = format!(
200                    "precision of {} is greater than our maximum of 39",
201                    raw.precision
202                );
203                return Err(SqlServerError::UnsupportedDataType {
204                    column_name: raw.name.to_string(),
205                    column_type: raw.data_type.to_string(),
206                    reason,
207                });
208            }
209
210            let raw_scale = usize::cast_from(raw.scale);
211            let max_scale = NumericMaxScale::try_from(raw_scale).map_err(|_| {
212                SqlServerError::UnsupportedDataType {
213                    column_type: raw.data_type.to_string(),
214                    column_name: raw.name.to_string(),
215                    reason: format!("scale of {} is too large", raw.scale),
216                }
217            })?;
218            let column_type = ScalarType::Numeric {
219                max_scale: Some(max_scale),
220            };
221
222            (column_type, SqlServerColumnDecodeType::Numeric)
223        }
224        "real" => (ScalarType::Float32, SqlServerColumnDecodeType::F32),
225        "double" => (ScalarType::Float64, SqlServerColumnDecodeType::F64),
226        "char" | "nchar" | "varchar" | "nvarchar" | "sysname" => {
227            // When the `max_length` is -1 SQL Server will not present us with the "before" value
228            // for updated columns.
229            //
230            // TODO(sql_server3): Support UPSERT semantics for SQL Server.
231            if raw.max_length == -1 {
232                return Err(SqlServerError::UnsupportedDataType {
233                    column_name: raw.name.to_string(),
234                    column_type: raw.data_type.to_string(),
235                    reason: "columns with unlimited size do not support CDC".to_string(),
236                });
237            }
238
239            (ScalarType::String, SqlServerColumnDecodeType::String)
240        }
241        "text" | "ntext" | "image" => {
242            // SQL Server docs indicate this should always be 16. There's no
243            // issue if it's not, but it's good to track.
244            mz_ore::soft_assert_eq_no_log!(raw.max_length, 16);
245
246            // TODO(sql_server3): Support UPSERT semantics for SQL Server.
247            return Err(SqlServerError::UnsupportedDataType {
248                column_name: raw.name.to_string(),
249                column_type: raw.data_type.to_string(),
250                reason: "columns with unlimited size do not support CDC".to_string(),
251            });
252        }
253        "xml" => {
254            // When the `max_length` is -1 SQL Server will not present us with the "before" value
255            // for updated columns.
256            //
257            // TODO(sql_server3): Support UPSERT semantics for SQL Server.
258            if raw.max_length == -1 {
259                return Err(SqlServerError::UnsupportedDataType {
260                    column_name: raw.name.to_string(),
261                    column_type: raw.data_type.to_string(),
262                    reason: "columns with unlimited size do not support CDC".to_string(),
263                });
264            }
265            (ScalarType::String, SqlServerColumnDecodeType::Xml)
266        }
267        "binary" | "varbinary" => {
268            // When the `max_length` is -1 if this column changes as part of an `UPDATE`
269            // or `DELETE` statement, SQL Server will not provide the "old" value for
270            // this column, but we need this value so we can emit a retraction.
271            //
272            // TODO(sql_server3): Support UPSERT semantics for SQL Server.
273            if raw.max_length == -1 {
274                return Err(SqlServerError::UnsupportedDataType {
275                    column_name: raw.name.to_string(),
276                    column_type: raw.data_type.to_string(),
277                    reason: "columns with unlimited size do not support CDC".to_string(),
278                });
279            }
280
281            (ScalarType::Bytes, SqlServerColumnDecodeType::Bytes)
282        }
283        "json" => (ScalarType::Jsonb, SqlServerColumnDecodeType::String),
284        "date" => (ScalarType::Date, SqlServerColumnDecodeType::NaiveDate),
285        "time" => (ScalarType::Time, SqlServerColumnDecodeType::NaiveTime),
286        // TODO(sql_server1): We should probably specify a precision here.
287        "smalldatetime" | "datetime" | "datetime2" => (
288            ScalarType::Timestamp { precision: None },
289            SqlServerColumnDecodeType::NaiveDateTime,
290        ),
291        // TODO(sql_server1): We should probably specify a precision here.
292        "datetimeoffset" => (
293            ScalarType::TimestampTz { precision: None },
294            SqlServerColumnDecodeType::DateTime,
295        ),
296        "uniqueidentifier" => (ScalarType::Uuid, SqlServerColumnDecodeType::Uuid),
297        // TODO(sql_server1): Support more data types.
298        other => {
299            return Err(SqlServerError::UnsupportedDataType {
300                column_type: other.to_string(),
301                column_name: raw.name.to_string(),
302                reason: "unimplemented".to_string(),
303            });
304        }
305    };
306    Ok(scalar)
307}
308
/// Raw metadata for a column from a table in Microsoft SQL Server.
///
/// This is the unvalidated input to [`SqlServerColumnDesc::try_new`].
///
/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-catalog-views/sys-columns-transact-sql?view=sql-server-ver16>.
#[derive(Clone, Debug)]
pub struct SqlServerColumnRaw {
    /// Name of this column.
    pub name: Arc<str>,
    /// Name of the data type.
    pub data_type: Arc<str>,
    /// Whether or not the column is nullable.
    pub is_nullable: bool,
    /// Maximum length (in bytes) of the column.
    ///
    /// For `varchar(max)`, `nvarchar(max)`, `varbinary(max)`, or `xml` this will be `-1`. For
    /// `text`, `ntext`, and `image` columns this will be 16.
    ///
    /// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-catalog-views/sys-columns-transact-sql?view=sql-server-ver16>.
    ///
    /// TODO(sql_server2): Validate this value for `json` columns, which were introduced in
    /// Azure SQL 2024.
    pub max_length: i16,
    /// Precision of the column, if numeric-based; otherwise 0.
    pub precision: u8,
    /// Scale of the column, if numeric-based; otherwise 0.
    pub scale: u8,
}
335
/// Rust type that we should use when reading a column from SQL Server.
///
/// Each variant names the Rust type that [`SqlServerColumnDecodeType::decode`] asks the
/// [`tiberius::Row`] for before converting the value into a [`Datum`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Arbitrary)]
pub enum SqlServerColumnDecodeType {
    /// [`bool`].
    Bool,
    /// [`u8`].
    U8,
    /// [`i16`].
    I16,
    /// [`i32`].
    I32,
    /// [`i64`].
    I64,
    /// [`f32`].
    F32,
    /// [`f64`].
    F64,
    /// String data.
    String,
    /// Raw bytes.
    Bytes,
    /// [`uuid::Uuid`].
    Uuid,
    /// [`tiberius::numeric::Numeric`].
    Numeric,
    /// [`tiberius::xml::XmlData`].
    Xml,
    /// [`chrono::NaiveDate`].
    NaiveDate,
    /// [`chrono::NaiveTime`].
    NaiveTime,
    /// [`chrono::DateTime`].
    DateTime,
    /// [`chrono::NaiveDateTime`].
    NaiveDateTime,
}
363
impl SqlServerColumnDecodeType {
    /// Decode the column with `name` out of the provided `data`.
    ///
    /// `column` is the Materialize type the value is decoded into; together with `self` it
    /// selects how the value is read from the row and which [`Datum`] variant is produced.
    ///
    /// A missing (`NULL`) value decodes to [`Datum::Null`] when `column` is nullable.
    ///
    /// # Errors
    ///
    /// * [`SqlServerError::ProgrammingError`] if there is no known way to decode `self` into
    ///   `column`'s scalar type.
    /// * [`SqlServerError::InvalidData`] if the value is `NULL` but `column` is not nullable.
    /// * Any error from reading the value out of `data` or from converting it into the
    ///   Materialize representation.
    pub fn decode<'a>(
        self,
        data: &'a tiberius::Row,
        name: &'a str,
        column: &'a ColumnType,
    ) -> Result<Datum<'a>, SqlServerError> {
        // `None` here means the upstream value was `NULL`; nullability is checked below.
        let maybe_datum = match (&column.scalar_type, self) {
            (ScalarType::Bool, SqlServerColumnDecodeType::Bool) => data
                .try_get(name)
                .context("bool")?
                .map(|val: bool| if val { Datum::True } else { Datum::False }),
            // The unsigned 8-bit value is widened into a signed 16-bit integer.
            (ScalarType::Int16, SqlServerColumnDecodeType::U8) => data
                .try_get(name)
                .context("u8")?
                .map(|val: u8| Datum::Int16(i16::cast_from(val))),
            (ScalarType::Int16, SqlServerColumnDecodeType::I16) => {
                data.try_get(name).context("i16")?.map(Datum::Int16)
            }
            (ScalarType::Int32, SqlServerColumnDecodeType::I32) => {
                data.try_get(name).context("i32")?.map(Datum::Int32)
            }
            (ScalarType::Int64, SqlServerColumnDecodeType::I64) => {
                data.try_get(name).context("i64")?.map(Datum::Int64)
            }
            (ScalarType::Float32, SqlServerColumnDecodeType::F32) => data
                .try_get(name)
                .context("f32")?
                .map(|val: f32| Datum::Float32(ordered_float::OrderedFloat(val))),
            (ScalarType::Float64, SqlServerColumnDecodeType::F64) => data
                .try_get(name)
                .context("f64")?
                .map(|val: f64| Datum::Float64(ordered_float::OrderedFloat(val))),
            (ScalarType::String, SqlServerColumnDecodeType::String) => {
                data.try_get(name).context("string")?.map(Datum::String)
            }
            (ScalarType::Bytes, SqlServerColumnDecodeType::Bytes) => {
                data.try_get(name).context("bytes")?.map(Datum::Bytes)
            }
            (ScalarType::Uuid, SqlServerColumnDecodeType::Uuid) => {
                data.try_get(name).context("uuid")?.map(Datum::Uuid)
            }
            // Numerics are converted by formatting the upstream value as a string and
            // re-parsing it as a Materialize numeric.
            (ScalarType::Numeric { .. }, SqlServerColumnDecodeType::Numeric) => data
                .try_get(name)
                .context("numeric")?
                .map(|val: tiberius::numeric::Numeric| {
                    // TODO(sql_server3): Make decimal parsing more performant.
                    let numeric = Numeric::context()
                        .parse(val.to_string())
                        .context("parsing")?;
                    Ok::<_, SqlServerError>(Datum::Numeric(OrderedDecimal(numeric)))
                })
                .transpose()?,
            // XML values are surfaced as plain strings.
            (ScalarType::String, SqlServerColumnDecodeType::Xml) => data
                .try_get(name)
                .context("xml")?
                .map(|val: &tiberius::xml::XmlData| Datum::String(val.as_ref())),
            (ScalarType::Date, SqlServerColumnDecodeType::NaiveDate) => data
                .try_get(name)
                .context("date")?
                .map(|val: chrono::NaiveDate| {
                    let date = val.try_into().context("parse date")?;
                    Ok::<_, SqlServerError>(Datum::Date(date))
                })
                .transpose()?,
            // TODO(sql_server1): SQL Server's time related types support a resolution
            // of 100 nanoseconds, while Postgres supports 1,000 nanoseconds (aka 1 microsecond).
            //
            // Internally we can support 1 nanosecond precision, but we should exercise
            // this case and see what breaks.
            (ScalarType::Time, SqlServerColumnDecodeType::NaiveTime) => {
                data.try_get(name).context("time")?.map(Datum::Time)
            }
            (ScalarType::Timestamp { .. }, SqlServerColumnDecodeType::NaiveDateTime) => data
                .try_get(name)
                .context("timestamp")?
                .map(|val: chrono::NaiveDateTime| {
                    let ts = val.try_into().context("parse timestamp")?;
                    Ok::<_, SqlServerError>(Datum::Timestamp(ts))
                })
                .transpose()?,
            (ScalarType::TimestampTz { .. }, SqlServerColumnDecodeType::DateTime) => data
                .try_get(name)
                .context("timestamptz")?
                .map(|val: chrono::DateTime<chrono::Utc>| {
                    let ts = val.try_into().context("parse timestamptz")?;
                    Ok::<_, SqlServerError>(Datum::TimestampTz(ts))
                })
                .transpose()?,
            // Any other pairing of scalar type and decode type is not supported.
            (column_type, decode_type) => {
                let msg = format!("don't know how to parse {decode_type:?} as {column_type:?}");
                return Err(SqlServerError::ProgrammingError(msg));
            }
        };

        // Reconcile a missing value with the nullability of the target column.
        match (maybe_datum, column.nullable) {
            (Some(datum), _) => Ok(datum),
            (None, true) => Ok(Datum::Null),
            (None, false) => Err(SqlServerError::InvalidData {
                column_name: name.to_string(),
                error: "found Null in non-nullable column".to_string(),
            }),
        }
    }
}
470
471impl RustType<proto_sql_server_column_desc::DecodeType> for SqlServerColumnDecodeType {
472    fn into_proto(&self) -> proto_sql_server_column_desc::DecodeType {
473        match self {
474            SqlServerColumnDecodeType::Bool => proto_sql_server_column_desc::DecodeType::Bool(()),
475            SqlServerColumnDecodeType::U8 => proto_sql_server_column_desc::DecodeType::U8(()),
476            SqlServerColumnDecodeType::I16 => proto_sql_server_column_desc::DecodeType::I16(()),
477            SqlServerColumnDecodeType::I32 => proto_sql_server_column_desc::DecodeType::I32(()),
478            SqlServerColumnDecodeType::I64 => proto_sql_server_column_desc::DecodeType::I64(()),
479            SqlServerColumnDecodeType::F32 => proto_sql_server_column_desc::DecodeType::F32(()),
480            SqlServerColumnDecodeType::F64 => proto_sql_server_column_desc::DecodeType::F64(()),
481            SqlServerColumnDecodeType::String => {
482                proto_sql_server_column_desc::DecodeType::String(())
483            }
484            SqlServerColumnDecodeType::Bytes => proto_sql_server_column_desc::DecodeType::Bytes(()),
485            SqlServerColumnDecodeType::Uuid => proto_sql_server_column_desc::DecodeType::Uuid(()),
486            SqlServerColumnDecodeType::Numeric => {
487                proto_sql_server_column_desc::DecodeType::Numeric(())
488            }
489            SqlServerColumnDecodeType::Xml => proto_sql_server_column_desc::DecodeType::Xml(()),
490            SqlServerColumnDecodeType::NaiveDate => {
491                proto_sql_server_column_desc::DecodeType::NaiveDate(())
492            }
493            SqlServerColumnDecodeType::NaiveTime => {
494                proto_sql_server_column_desc::DecodeType::NaiveTime(())
495            }
496            SqlServerColumnDecodeType::DateTime => {
497                proto_sql_server_column_desc::DecodeType::DateTime(())
498            }
499            SqlServerColumnDecodeType::NaiveDateTime => {
500                proto_sql_server_column_desc::DecodeType::NaiveDateTime(())
501            }
502        }
503    }
504
505    fn from_proto(
506        proto: proto_sql_server_column_desc::DecodeType,
507    ) -> Result<Self, mz_proto::TryFromProtoError> {
508        let val = match proto {
509            proto_sql_server_column_desc::DecodeType::Bool(()) => SqlServerColumnDecodeType::Bool,
510            proto_sql_server_column_desc::DecodeType::U8(()) => SqlServerColumnDecodeType::U8,
511            proto_sql_server_column_desc::DecodeType::I16(()) => SqlServerColumnDecodeType::I16,
512            proto_sql_server_column_desc::DecodeType::I32(()) => SqlServerColumnDecodeType::I32,
513            proto_sql_server_column_desc::DecodeType::I64(()) => SqlServerColumnDecodeType::I64,
514            proto_sql_server_column_desc::DecodeType::F32(()) => SqlServerColumnDecodeType::F32,
515            proto_sql_server_column_desc::DecodeType::F64(()) => SqlServerColumnDecodeType::F64,
516            proto_sql_server_column_desc::DecodeType::String(()) => {
517                SqlServerColumnDecodeType::String
518            }
519            proto_sql_server_column_desc::DecodeType::Bytes(()) => SqlServerColumnDecodeType::Bytes,
520            proto_sql_server_column_desc::DecodeType::Uuid(()) => SqlServerColumnDecodeType::Uuid,
521            proto_sql_server_column_desc::DecodeType::Numeric(()) => {
522                SqlServerColumnDecodeType::Numeric
523            }
524            proto_sql_server_column_desc::DecodeType::Xml(()) => SqlServerColumnDecodeType::Xml,
525            proto_sql_server_column_desc::DecodeType::NaiveDate(()) => {
526                SqlServerColumnDecodeType::NaiveDate
527            }
528            proto_sql_server_column_desc::DecodeType::NaiveTime(()) => {
529                SqlServerColumnDecodeType::NaiveTime
530            }
531            proto_sql_server_column_desc::DecodeType::DateTime(()) => {
532                SqlServerColumnDecodeType::DateTime
533            }
534            proto_sql_server_column_desc::DecodeType::NaiveDateTime(()) => {
535                SqlServerColumnDecodeType::NaiveDateTime
536            }
537        };
538        Ok(val)
539    }
540}
541
/// A decoder from [`tiberius::Row`] to [`mz_repr::Row`].
///
/// The goal of this type is to perform any expensive "downcasts" so in the hot
/// path of decoding rows we do the minimal amount of work.
pub struct SqlServerRowDecoder {
    /// One entry per output column: the name to look the value up by in the upstream row,
    /// the Materialize type to decode into, and the Rust type to read the value as.
    decoders: Vec<(Arc<str>, ColumnType, SqlServerColumnDecodeType)>,
}
549
550impl SqlServerRowDecoder {
551    /// Try to create a [`SqlServerRowDecoder`] that will decode [`tiberius::Row`]s that match
552    /// the shape of the provided [`SqlServerTableDesc`], to [`mz_repr::Row`]s that match the
553    /// shape of the provided [`RelationDesc`].
554    pub fn try_new(
555        table: &SqlServerTableDesc,
556        desc: &RelationDesc,
557    ) -> Result<Self, SqlServerError> {
558        let decoders = desc
559            .iter()
560            .map(|(col_name, col_type)| {
561                let sql_server_col = table
562                    .columns
563                    .iter()
564                    .find(|col| col.name.as_ref() == col_name.as_str())
565                    .ok_or_else(|| {
566                        // TODO(sql_server2): Structured Error.
567                        anyhow::anyhow!("no SQL Server column with name {col_name} found")
568                    })?;
569
570                // This shouldn't be true, but be defensive.
571                //
572                // TODO(sql_server2): Maybe allow the Materialize column type to be more nullable
573                // than our decoding type?
574                if &sql_server_col.column_type != col_type {
575                    return Err(SqlServerError::ProgrammingError(format!(
576                        "programming error, {col_name} has mismatched type {:?} vs {:?}",
577                        sql_server_col.column_type, col_type
578                    )));
579                }
580
581                let name = Arc::clone(&sql_server_col.name);
582                let decoder = sql_server_col.decode_type;
583
584                Ok::<_, SqlServerError>((name, col_type.clone(), decoder))
585            })
586            .collect::<Result<_, _>>()?;
587
588        Ok(SqlServerRowDecoder { decoders })
589    }
590
591    /// Decode data from the provided [`tiberius::Row`] into the provided [`Row`].
592    pub fn decode(&self, data: &tiberius::Row, row: &mut Row) -> Result<(), SqlServerError> {
593        let mut packer = row.packer();
594        for (col_name, col_type, decoder) in &self.decoders {
595            let datum = decoder.decode(data, col_name, col_type)?;
596            packer.push(datum);
597        }
598        Ok(())
599    }
600}
601
#[cfg(test)]
mod tests {
    use crate::desc::{
        SqlServerColumnDecodeType, SqlServerColumnDesc, SqlServerTableDesc, SqlServerTableRaw,
    };

    use super::SqlServerColumnRaw;
    use mz_ore::assert_contains;
    use mz_repr::adt::numeric::NumericMaxScale;
    use mz_repr::{Datum, RelationDesc, Row, ScalarType};
    use tiberius::RowTestExt;

    // Test-only builder helpers for assembling raw column metadata.
    impl SqlServerColumnRaw {
        /// Create a new [`SqlServerColumnRaw`]. The specified `data_type` is
        /// _not_ checked for validity.
        ///
        /// The column starts out non-nullable with `max_length`, `precision`, and `scale`
        /// all set to 0.
        fn new(name: &str, data_type: &str) -> Self {
            SqlServerColumnRaw {
                name: name.into(),
                data_type: data_type.into(),
                is_nullable: false,
                max_length: 0,
                precision: 0,
                scale: 0,
            }
        }

        /// Sets whether the column is nullable.
        fn nullable(mut self, nullable: bool) -> Self {
            self.is_nullable = nullable;
            self
        }

        /// Sets the maximum length of the column.
        fn max_length(mut self, max_length: i16) -> Self {
            self.max_length = max_length;
            self
        }

        /// Sets the precision of the column.
        fn precision(mut self, precision: u8) -> Self {
            self.precision = precision;
            self
        }

        /// Sets the scale of the column.
        fn scale(mut self, scale: u8) -> Self {
            self.scale = scale;
            self
        }
    }

    /// Valid raw columns are converted into the expected column descriptions.
    #[mz_ore::test]
    fn smoketest_column_raw() {
        let raw = SqlServerColumnRaw::new("foo", "bit");
        let col = SqlServerColumnDesc::try_new(&raw).unwrap();

        assert_eq!(&*col.name, "foo");
        assert_eq!(col.column_type, ScalarType::Bool.nullable(false));
        assert_eq!(col.decode_type, SqlServerColumnDecodeType::Bool);

        let raw = SqlServerColumnRaw::new("foo", "decimal")
            .precision(20)
            .scale(10);
        let col = SqlServerColumnDesc::try_new(&raw).unwrap();

        // The upstream scale is carried over as the maximum scale of the numeric.
        let col_type = ScalarType::Numeric {
            max_scale: Some(NumericMaxScale::try_from(10i64).expect("known valid")),
        }
        .nullable(false);
        assert_eq!(col.column_type, col_type);
        assert_eq!(col.decode_type, SqlServerColumnDecodeType::Numeric);
    }

    /// Unsupported raw columns are rejected with a descriptive error.
    #[mz_ore::test]
    fn smoketest_column_raw_invalid() {
        // Unknown data type.
        let raw = SqlServerColumnRaw::new("foo", "bad_data_type");
        let err = SqlServerColumnDesc::try_new(&raw).unwrap_err();
        assert_contains!(err.to_string(), "'bad_data_type' from column 'foo'");

        // Precision beyond what Materialize numerics support.
        let raw = SqlServerColumnRaw::new("foo", "decimal")
            .precision(100)
            .scale(10);
        let err = SqlServerColumnDesc::try_new(&raw).unwrap_err();
        assert_contains!(
            err.to_string(),
            "precision of 100 is greater than our maximum of 39"
        );

        // A `max_length` of -1 marks a column of unlimited size.
        let raw = SqlServerColumnRaw::new("foo", "varchar").max_length(-1);
        let err = SqlServerColumnDesc::try_new(&raw).unwrap_err();
        assert_contains!(
            err.to_string(),
            "columns with unlimited size do not support CDC"
        );
    }

    /// Rows are decoded into the column order of the `RelationDesc`, and `NULL`s in nullable
    /// columns become `Datum::Null`.
    #[mz_ore::test]
    fn smoketest_decoder() {
        let sql_server_columns = [
            SqlServerColumnRaw::new("a", "varchar"),
            SqlServerColumnRaw::new("b", "int").nullable(true),
            SqlServerColumnRaw::new("c", "bit"),
        ];
        let sql_server_desc = SqlServerTableRaw {
            schema_name: "my_schema".into(),
            name: "my_table".into(),
            columns: sql_server_columns.into(),
            is_cdc_enabled: true,
        };
        let sql_server_desc = SqlServerTableDesc::try_new(sql_server_desc).expect("known valid");

        let relation_desc = RelationDesc::builder()
            .with_column("a", ScalarType::String.nullable(false))
            // Note: In the upstream table 'c' is ordered after 'b'.
            .with_column("c", ScalarType::Bool.nullable(false))
            .with_column("b", ScalarType::Int32.nullable(true))
            .finish();

        // This decoder should shape the SQL Server Rows into Rows compatible with the RelationDesc.
        let decoder = sql_server_desc
            .decoder(&relation_desc)
            .expect("known valid");

        // Column metadata for the upstream rows, in upstream order.
        let sql_server_columns = [
            tiberius::Column::new("a".to_string(), tiberius::ColumnType::BigVarChar),
            tiberius::Column::new("b".to_string(), tiberius::ColumnType::Int4),
            tiberius::Column::new("c".to_string(), tiberius::ColumnType::Bit),
        ];

        // Row A has a value in every column.
        let data_a = [
            tiberius::ColumnData::String(Some("hello world".into())),
            tiberius::ColumnData::I32(Some(42)),
            tiberius::ColumnData::Bit(Some(true)),
        ];
        let sql_server_row_a =
            tiberius::Row::build(sql_server_columns.iter().cloned().zip(data_a.into_iter()));

        // Row B has a `NULL` in the nullable column 'b'.
        let data_b = [
            tiberius::ColumnData::String(Some("foo bar".into())),
            tiberius::ColumnData::I32(None),
            tiberius::ColumnData::Bit(Some(false)),
        ];
        let sql_server_row_b =
            tiberius::Row::build(sql_server_columns.into_iter().zip(data_b.into_iter()));

        // The same `Row` is reused for both decodes.
        let mut rnd_row = Row::default();
        decoder.decode(&sql_server_row_a, &mut rnd_row).unwrap();
        assert_eq!(
            &rnd_row,
            &Row::pack_slice(&[Datum::String("hello world"), Datum::True, Datum::Int32(42)])
        );

        decoder.decode(&sql_server_row_b, &mut rnd_row).unwrap();
        assert_eq!(
            &rnd_row,
            &Row::pack_slice(&[Datum::String("foo bar"), Datum::False, Datum::Null])
        );
    }

    // TODO(sql_server2): Proptest the decoder.
}