Skip to main content

mz_sql_server_util/
inspect.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Useful queries to inspect the state of a SQL Server instance.
11
12use anyhow::Context;
13use chrono::NaiveDateTime;
14use futures::Stream;
15use itertools::Itertools;
16use mz_ore::cast::CastFrom;
17use mz_ore::retry::RetryResult;
18use smallvec::SmallVec;
19use std::collections::BTreeMap;
20use std::fmt;
21use std::sync::Arc;
22use std::time::Duration;
23use tiberius::numeric::Numeric;
24
25use crate::cdc::{Lsn, RowFilterOption};
26use crate::desc::{
27    SqlServerCaptureInstanceRaw, SqlServerColumnRaw, SqlServerQualifiedTableName,
28    SqlServerTableConstraintRaw, SqlServerTableRaw,
29};
30use crate::{Client, SqlServerError, quote_identifier};
31
32/// Returns the minimum log sequence number for the specified `capture_instance`.
33///
34/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-min-lsn-transact-sql?view=sql-server-ver16>
35pub async fn get_min_lsn(
36    client: &mut Client,
37    capture_instance: &str,
38) -> Result<Lsn, SqlServerError> {
39    static MIN_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_min_lsn(@P1);";
40    let result = client.query(MIN_LSN_QUERY, &[&capture_instance]).await?;
41
42    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
43    parse_lsn(&result[..1])
44}
45/// Returns the minimum log sequence number for the specified `capture_instance`, retrying
46/// if the log sequence number is not available.
47///
48/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-min-lsn-transact-sql?view=sql-server-ver16>
49pub async fn get_min_lsn_retry(
50    client: &mut Client,
51    capture_instance: &str,
52    max_retry_duration: Duration,
53) -> Result<Lsn, SqlServerError> {
54    let (_client, lsn_result) = mz_ore::retry::Retry::default()
55        .max_duration(max_retry_duration)
56        .retry_async_with_state(client, |_, client| async {
57            let result = crate::inspect::get_min_lsn(client, capture_instance).await;
58            (client, map_null_lsn_to_retry(result))
59        })
60        .await;
61    let Ok(lsn) = lsn_result else {
62        tracing::warn!("database did not report a minimum LSN in time");
63        return lsn_result;
64    };
65    Ok(lsn)
66}
67
68/// Returns the maximum log sequence number for the entire database.
69/// This implementation relies on CDC, which is asynchronous, so may
70/// return an LSN that is less than the maximum LSN of SQL server.
71///
72/// See:
73/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-max-lsn-transact-sql?view=sql-server-ver16>
74/// - <https://groups.google.com/g/debezium/c/47Yg2r166KM/m/lHqtRF2xAQAJ?pli=1>
75pub async fn get_max_lsn(client: &mut Client) -> Result<Lsn, SqlServerError> {
76    static MAX_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_max_lsn();";
77    let result = client.simple_query(MAX_LSN_QUERY).await?;
78
79    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
80    parse_lsn(&result[..1])
81}
82
83/// Retrieves the minumum [`Lsn`] (start_lsn field) from `cdc.change_tables`
84/// for the specified capture instances.
85///
86/// This is based on the `sys.fn_cdc_get_min_lsn` implementation, which has logic
87/// that we want to bypass. Specifically, `sys.fn_cdc_get_min_lsn` returns NULL
88/// if the `start_lsn` in `cdc.change_tables` is less than or equal to the LSN
89/// returned by `sys.fn_cdc_get_max_lsn`.
90///
91/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/cdc-change-tables-transact-sql?view=sql-server-ver16>
92pub async fn get_min_lsns(
93    client: &mut Client,
94    capture_instances: impl IntoIterator<Item = &str>,
95) -> Result<BTreeMap<Arc<str>, Lsn>, SqlServerError> {
96    let capture_instances: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
97    let values: Vec<_> = capture_instances
98        .iter()
99        .map(|ci| {
100            let ci: &dyn tiberius::ToSql = ci;
101            ci
102        })
103        .collect();
104    let args = (0..capture_instances.len())
105        .map(|i| format!("@P{}", i + 1))
106        .collect::<Vec<_>>()
107        .join(",");
108    let stmt = format!(
109        "SELECT capture_instance, start_lsn FROM cdc.change_tables WHERE capture_instance IN ({args});"
110    );
111    let result = client.query(stmt, &values).await?;
112    let min_lsns = result
113        .into_iter()
114        .map(|row| {
115            let capture_instance: Arc<str> = row
116                .try_get::<&str, _>("capture_instance")?
117                .ok_or_else(|| {
118                    SqlServerError::ProgrammingError(
119                        "missing column 'capture_instance'".to_string(),
120                    )
121                })?
122                .into();
123            let start_lsn: &[u8] = row.try_get("start_lsn")?.ok_or_else(|| {
124                SqlServerError::ProgrammingError("missing column 'start_lsn'".to_string())
125            })?;
126            let min_lsn = Lsn::try_from(start_lsn).map_err(|msg| SqlServerError::InvalidData {
127                column_name: "lsn".to_string(),
128                error: format!("Error parsing LSN for {capture_instance}: {msg}"),
129            })?;
130            Ok::<_, SqlServerError>((capture_instance, min_lsn))
131        })
132        .collect::<Result<_, _>>()?;
133
134    Ok(min_lsns)
135}
136
137/// Returns the maximum log sequence number for the entire database, retrying
138/// if the log sequence number is not available. This implementation relies on
139/// CDC, which is asynchronous, so may return an LSN that is less than the
140/// maximum LSN of SQL server.
141///
142/// See:
143/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-max-lsn-transact-sql?view=sql-server-ver16>
144/// - <https://groups.google.com/g/debezium/c/47Yg2r166KM/m/lHqtRF2xAQAJ?pli=1>
145pub async fn get_max_lsn_retry(
146    client: &mut Client,
147    max_retry_duration: Duration,
148) -> Result<Lsn, SqlServerError> {
149    let (_client, lsn_result) = mz_ore::retry::Retry::default()
150        .max_duration(max_retry_duration)
151        .retry_async_with_state(client, |_, client| async {
152            let result = crate::inspect::get_max_lsn(client).await;
153            (client, map_null_lsn_to_retry(result))
154        })
155        .await;
156
157    let Ok(lsn) = lsn_result else {
158        tracing::warn!("database did not report a maximum LSN in time");
159        return lsn_result;
160    };
161    Ok(lsn)
162}
163
164fn map_null_lsn_to_retry<T>(result: Result<T, SqlServerError>) -> RetryResult<T, SqlServerError> {
165    match result {
166        Ok(val) => RetryResult::Ok(val),
167        Err(err @ SqlServerError::NullLsn) => RetryResult::RetryableErr(err),
168        Err(other) => RetryResult::FatalErr(other),
169    }
170}
171
172/// Increments the log sequence number.
173///
174/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-increment-lsn-transact-sql?view=sql-server-ver16>
175pub async fn increment_lsn(client: &mut Client, lsn: Lsn) -> Result<Lsn, SqlServerError> {
176    static INCREMENT_LSN_QUERY: &str = "SELECT sys.fn_cdc_increment_lsn(@P1);";
177    let result = client
178        .query(INCREMENT_LSN_QUERY, &[&lsn.as_bytes().as_slice()])
179        .await?;
180
181    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
182    parse_lsn(&result[..1])
183}
184
185/// Parse an [`Lsn`] in Decimal(25,0) format of the provided [`tiberius::Row`].
186///
187/// Returns an error if the provided slice doesn't have exactly one row.
188pub(crate) fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
189    match row {
190        [r] => {
191            let numeric_lsn = r
192                .try_get::<Numeric, _>(0)?
193                .ok_or_else(|| SqlServerError::NullLsn)?;
194            let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
195                column_name: "lsn".to_string(),
196                error: msg,
197            })?;
198            Ok(lsn)
199        }
200        other => Err(SqlServerError::InvalidData {
201            column_name: "lsn".to_string(),
202            error: format!("expected 1 column, got {other:?}"),
203        }),
204    }
205}
206
207/// Parse an [`Lsn`] from the first column of the provided [`tiberius::Row`].
208///
209/// Returns an error if the provided slice doesn't have exactly one row.
210fn parse_lsn(result: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
211    match result {
212        [row] => {
213            let val = row
214                .try_get::<&[u8], _>(0)?
215                .ok_or_else(|| SqlServerError::NullLsn)?;
216            if val.is_empty() {
217                Err(SqlServerError::NullLsn)
218            } else {
219                let lsn = Lsn::try_from(val).map_err(|msg| SqlServerError::InvalidData {
220                    column_name: "lsn".to_string(),
221                    error: msg,
222                })?;
223                Ok(lsn)
224            }
225        }
226        other => Err(SqlServerError::InvalidData {
227            column_name: "lsn".to_string(),
228            error: format!("expected 1 column, got {other:?}"),
229        }),
230    }
231}
232
233/// Queries the specified capture instance and returns all changes from
234/// `[start_lsn, end_lsn)`, ordered by `start_lsn` in an ascending fashion.
235pub fn get_changes_asc(
236    client: &mut Client,
237    capture_instance: &str,
238    start_lsn: Lsn,
239    end_lsn: Lsn,
240    filter: RowFilterOption,
241) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send {
242    const START_LSN_COLUMN: &str = "__$start_lsn";
243    let query = format!(
244        "SELECT * FROM cdc.{function}(@P1, @P2, N'{filter}') ORDER BY {START_LSN_COLUMN} ASC;",
245        function = quote_identifier(&format!("fn_cdc_get_all_changes_{capture_instance}"))
246    );
247    client.query_streaming(
248        query,
249        &[
250            &start_lsn.as_bytes().as_slice(),
251            &end_lsn.as_bytes().as_slice(),
252        ],
253    )
254}
255
256/// Cleans up the change table associated with the specified `capture_instance` by
257/// deleting `max_deletes` entries with a `start_lsn` less than `low_water_mark`.
258///
259/// Note: At the moment cleanup is kind of "best effort".  If this query succeeds
260/// then at most `max_delete` rows were deleted, but the number of actual rows
261/// deleted is not returned as part of the query. The number of rows _should_ be
262/// present in an informational message (i.e. a Notice) that is returned, but
263/// [`tiberius`] doesn't expose these to us.
264///
265/// TODO(sql_server2): Update [`tiberius`] to return informational messages so we
266/// can determine how many rows got deleted.
267///
268/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sys-sp-cdc-cleanup-change-table-transact-sql?view=sql-server-ver16>.
269pub async fn cleanup_change_table(
270    client: &mut Client,
271    capture_instance: &str,
272    low_water_mark: &Lsn,
273    max_deletes: u32,
274) -> Result<(), SqlServerError> {
275    static GET_LSN_QUERY: &str =
276        "SELECT MAX(start_lsn) FROM cdc.lsn_time_mapping WHERE start_lsn <= @P1";
277    static CLEANUP_QUERY: &str = "
278DECLARE @mz_cleanup_status_bit BIT;
279SET @mz_cleanup_status_bit = 0;
280EXEC sys.sp_cdc_cleanup_change_table
281    @capture_instance = @P1,
282    @low_water_mark = @P2,
283    @threshold = @P3,
284    @fCleanupFailed = @mz_cleanup_status_bit OUTPUT;
285SELECT @mz_cleanup_status_bit;
286    ";
287
288    let max_deletes = i64::cast_from(max_deletes);
289
290    // First we need to get a valid LSN as our low watermark. If we try to cleanup
291    // a change table with an LSN that doesn't exist in the `cdc.lsn_time_mapping`
292    // table we'll get an error code `22964`.
293    let result = client
294        .query(GET_LSN_QUERY, &[&low_water_mark.as_bytes().as_slice()])
295        .await?;
296    let low_water_mark_to_use = match &result[..] {
297        [row] => row
298            .try_get::<&[u8], _>(0)?
299            .ok_or_else(|| SqlServerError::InvalidData {
300                column_name: "mz_cleanup_status_bit".to_string(),
301                error: "expected a bool, found NULL".to_string(),
302            })?,
303        other => Err(SqlServerError::ProgrammingError(format!(
304            "expected one row for low water mark, found {other:?}"
305        )))?,
306    };
307
308    // Once we get a valid LSN that is less than or equal to the provided watermark
309    // we can clean up the specified change table!
310    let result = client
311        .query(
312            CLEANUP_QUERY,
313            &[&capture_instance, &low_water_mark_to_use, &max_deletes],
314        )
315        .await;
316
317    let rows = match result {
318        Ok(rows) => rows,
319        Err(SqlServerError::SqlServer(e)) => {
320            // See these remarks from the SQL Server Documentation.
321            //
322            // <https://learn.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sys-sp-cdc-cleanup-change-table-transact-sql?view=sql-server-ver16#remarks>.
323            let already_cleaned_up = e.code().map(|code| code == 22957).unwrap_or(false);
324
325            if already_cleaned_up {
326                return Ok(());
327            } else {
328                return Err(SqlServerError::SqlServer(e));
329            }
330        }
331        Err(other) => return Err(other),
332    };
333
334    match &rows[..] {
335        [row] => {
336            let failure =
337                row.try_get::<bool, _>(0)?
338                    .ok_or_else(|| SqlServerError::InvalidData {
339                        column_name: "mz_cleanup_status_bit".to_string(),
340                        error: "expected a bool, found NULL".to_string(),
341                    })?;
342
343            if failure {
344                Err(super::cdc::CdcError::CleanupFailed {
345                    capture_instance: capture_instance.to_string(),
346                    low_water_mark: *low_water_mark,
347                })?
348            } else {
349                Ok(())
350            }
351        }
352        other => Err(SqlServerError::ProgrammingError(format!(
353            "expected one status row, found {other:?}"
354        ))),
355    }
356}
357
/// Base query that retrieves all columns in tables that have CDC (Change Data
/// Capture) enabled.
///
/// Returns metadata needed to create an instance of [`SqlServerTableRaw`].
///
/// The query has no `WHERE` clause and no terminating semicolon, so callers can
/// append their own filter (e.g. on `ch.capture_instance`) before running it.
///
/// The query joins several system tables:
/// - sys.tables: Source tables in the database
/// - sys.schemas: Schema information for proper table identification
/// - sys.columns: Column definitions including nullability
/// - sys.types: Data type information for each column
/// - cdc.change_tables: CDC configuration linking capture instances to source tables
///
/// For each column, it returns:
/// - Table identification (schema_name, table_name, capture_instance)
/// - Capture instance creation date (capture_instance_create_date)
/// - Column metadata (name, type, nullable, max_length, precision, scale, is_computed)
static GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY: &str = "
SELECT
    s.name as schema_name,
    t.name as table_name,
    ch.capture_instance as capture_instance,
    ch.create_date as capture_instance_create_date,
    c.name as col_name,
    ty.name as col_type,
    c.is_nullable as col_nullable,
    c.max_length as col_max_length,
    c.precision as col_precision,
    c.scale as col_scale,
    c.is_computed as col_is_computed
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.columns c ON t.object_id = c.object_id
JOIN sys.types ty ON c.user_type_id = ty.user_type_id
JOIN cdc.change_tables ch ON t.object_id = ch.source_object_id
";
391
392/// Returns the table metadata for the tables that are tracked by the specified `capture_instance`s.
393pub async fn get_tables_for_capture_instance(
394    client: &mut Client,
395    capture_instances: impl IntoIterator<Item = &str>,
396) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
397    // SQL Server does not have support for array types, so we need to manually construct
398    // the parameterized query.
399    let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
400    // If there are no tables to check for just return an empty list.
401    if params.is_empty() {
402        return Ok(Vec::default());
403    }
404
405    // TODO(sql_server3): Remove this redundant collection.
406    #[allow(clippy::as_conversions)]
407    let params_dyn: SmallVec<[_; 1]> = params
408        .iter()
409        .map(|instance| instance as &dyn tiberius::ToSql)
410        .collect();
411    let param_indexes = params
412        .iter()
413        .enumerate()
414        // Params are 1-based indexed.
415        .map(|(idx, _)| format!("@P{}", idx + 1))
416        .join(", ");
417
418    let table_for_capture_instance_query = format!(
419        "{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY} WHERE ch.capture_instance IN ({param_indexes});"
420    );
421
422    let result = client
423        .query(&table_for_capture_instance_query, &params_dyn[..])
424        .await?;
425
426    let tables = deserialize_table_columns_to_raw_tables(&result)?;
427
428    Ok(tables)
429}
430
431/// Retrieves column metdata from the CDC table maintained by the provided capture instance. The
432/// resulting column information collection is similar to the information collected for the
433/// upstream table, with the exclusion of nullability and primary key constraints, which contain
434/// static values for CDC columns. CDC table schema is automatically generated and does not attempt
435/// to enforce the same constraints on the data as the upstream table.
436pub async fn get_cdc_table_columns(
437    client: &mut Client,
438    capture_instance: &str,
439) -> Result<BTreeMap<Arc<str>, SqlServerColumnRaw>, SqlServerError> {
440    static CDC_COLUMNS_QUERY: &str = "SELECT \
441        c.name AS col_name, \
442        t.name AS col_type, \
443        c.max_length AS col_max_length, \
444        c.precision AS col_precision, \
445        c.scale AS col_scale, \
446        c.is_computed as col_is_computed \
447    FROM \
448        sys.columns AS c \
449    JOIN sys.types AS t ON c.system_type_id = t.system_type_id AND c.user_type_id = t.user_type_id \
450    WHERE \
451        c.object_id = OBJECT_ID(@P1) AND c.name NOT LIKE '__$%' \
452    ORDER BY c.column_id;";
453    // Strings passed into OBJECT_ID must be escaped
454    let cdc_table_name = format!(
455        "cdc.{table_name}",
456        table_name = quote_identifier(&format!("{capture_instance}_CT"))
457    );
458    let result = client.query(CDC_COLUMNS_QUERY, &[&cdc_table_name]).await?;
459    let mut columns = BTreeMap::new();
460    for row in result.iter() {
461        let column_name: Arc<str> = get_value::<&str>(row, "col_name")?.into();
462        // Reusing this struct even though some of the fields aren't needed because it simplifies
463        // comparison with the upstream table metadata
464        let column = SqlServerColumnRaw {
465            name: Arc::clone(&column_name),
466            data_type: get_value::<&str>(row, "col_type")?.into(),
467            is_nullable: true,
468            max_length: get_value(row, "col_max_length")?,
469            precision: get_value(row, "col_precision")?,
470            scale: get_value(row, "col_scale")?,
471            is_computed: get_value(row, "col_is_computed")?,
472        };
473        columns.insert(column_name, column);
474    }
475    Ok(columns)
476}
477
478/// Retrieve primary key and unique constraints for the given tables.  Tables should be provided as
479/// an interator of tuples, where each tuple contains is `(schema_name, table_name)`.
480pub async fn get_constraints_for_tables(
481    client: &mut Client,
482    schema_table_list: impl Iterator<Item = &(Arc<str>, Arc<str>)>,
483) -> Result<BTreeMap<(Arc<str>, Arc<str>), Vec<SqlServerTableConstraintRaw>>, SqlServerError> {
484    let qualified_table_names: Vec<_> = schema_table_list
485        .map(|(schema, table)| {
486            format!(
487                "{quoted_schema}.{quoted_table}",
488                quoted_schema = quote_identifier(schema),
489                quoted_table = quote_identifier(table)
490            )
491        })
492        .collect();
493
494    if qualified_table_names.is_empty() {
495        return Ok(Default::default());
496    }
497
498    let params = (1..qualified_table_names.len() + 1)
499        .map(|idx| format!("@P{}", idx))
500        .join(", ");
501
502    // KEY_COLUMN_USAGE (not CONSTRAINT_COLUMN_USAGE) because it exposes
503    // ORDINAL_POSITION, letting us preserve composite-key column order.
504    let query = format!(
505        "SELECT \
506        tc.table_schema, \
507        tc.table_name, \
508        kcu.column_name, \
509        tc.constraint_name, \
510        tc.constraint_type \
511    FROM information_schema.table_constraints tc \
512    JOIN information_schema.key_column_usage kcu \
513        ON kcu.constraint_schema = tc.constraint_schema \
514        AND kcu.constraint_name = tc.constraint_name \
515        AND kcu.table_schema = tc.table_schema \
516        AND kcu.table_name = tc.table_name \
517    WHERE
518        QUOTENAME(tc.table_schema) + '.' + QUOTENAME(tc.table_name) IN ({params})
519        AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
520    ORDER BY tc.table_schema, tc.table_name, tc.constraint_name, kcu.ordinal_position"
521    );
522
523    let query_params: Vec<_> = qualified_table_names
524        .iter()
525        .map(|qualified_name| {
526            let name: &dyn tiberius::ToSql = qualified_name;
527            name
528        })
529        .collect();
530
531    tracing::debug!("query = {query} params = {qualified_table_names:?}");
532    let result = client.query(query, &query_params).await?;
533
534    let mut contraints_by_table: BTreeMap<_, BTreeMap<_, Vec<_>>> = BTreeMap::new();
535    for row in result {
536        let schema_name: Arc<str> = get_value::<&str>(&row, "table_schema")?.into();
537        let table_name: Arc<str> = get_value::<&str>(&row, "table_name")?.into();
538        let column_name = get_value::<&str>(&row, "column_name")?.into();
539        let constraint_name = get_value::<&str>(&row, "constraint_name")?.into();
540        let constraint_type = get_value::<&str>(&row, "constraint_type")?.into();
541
542        contraints_by_table
543            .entry((Arc::clone(&schema_name), Arc::clone(&table_name)))
544            .or_default()
545            .entry((constraint_name, constraint_type))
546            .or_default()
547            .push(column_name);
548    }
549    Ok(contraints_by_table
550        .into_iter()
551        .inspect(|((schema_name, table_name), constraints)| {
552            tracing::debug!("table {schema_name}.{table_name} constraints: {constraints:?}")
553        })
554        .map(|(qualified_name, constraints)| {
555            (
556                qualified_name,
557                constraints
558                    .into_iter()
559                    .map(|((constraint_name, constraint_type), columns)| {
560                        SqlServerTableConstraintRaw {
561                            constraint_name,
562                            constraint_type,
563                            columns,
564                        }
565                    })
566                    .collect(),
567            )
568        })
569        .collect())
570}
571
572/// Ensure change data capture (CDC) is enabled for the database the provided
573/// `client` is currently connected to.
574///
575/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/track-changes/enable-and-disable-change-data-capture-sql-server?view=sql-server-ver16>
576pub async fn ensure_database_cdc_enabled(client: &mut Client) -> Result<(), SqlServerError> {
577    static DATABASE_CDC_ENABLED_QUERY: &str =
578        "SELECT is_cdc_enabled FROM sys.databases WHERE database_id = DB_ID();";
579    let result = client.simple_query(DATABASE_CDC_ENABLED_QUERY).await?;
580
581    check_system_result(&result, "database CDC".to_string(), true)?;
582    Ok(())
583}
584
585/// Retrieves the largest `restore_history_id` from SQL Server for the current database.  The
586/// `restore_history_id` column is of type `IDENTITY(1,1)` based on `EXEC sp_help restorehistory`.
587/// We expect it to start at 1 and be incremented by 1, with possible gaps in values.
588/// See:
589/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/restorehistory-transact-sql?view=sql-server-ver17>
590/// - <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql-identity-property?view=sql-server-ver17>
591pub async fn get_latest_restore_history_id(
592    client: &mut Client,
593) -> Result<Option<i32>, SqlServerError> {
594    static LATEST_RESTORE_ID_QUERY: &str = "SELECT TOP 1 restore_history_id \
595        FROM msdb.dbo.restorehistory \
596        WHERE destination_database_name = DB_NAME() \
597        ORDER BY restore_history_id DESC;";
598    let result = client.simple_query(LATEST_RESTORE_ID_QUERY).await?;
599
600    match &result[..] {
601        [] => Ok(None),
602        [row] => Ok(row.try_get::<i32, _>(0)?),
603        other => Err(SqlServerError::InvariantViolated(format!(
604            "expected one row, got {other:?}"
605        ))),
606    }
607}
608
609/// A DDL event collected from the `cdc.ddl_history` table.
610#[derive(Debug)]
611pub struct DDLEvent {
612    pub lsn: Lsn,
613    pub ddl_command: Arc<str>,
614}
615
616impl DDLEvent {
617    /// Returns true if the DDL event is a compatible change, or false if it is not.
618    /// This performs a naive parsing of the DDL command looking for modification of columns
619    ///  1. ALTER TABLE .. ALTER COLUMN
620    ///  2. ALTER TABLE .. DROP COLUMN
621    ///
622    /// See <https://learn.microsoft.com/en-us/sql/t-sql/statements/alter-table-transact-sql?view=sql-server-ver17>
623    pub fn is_compatible(&self, included_columns: &[Arc<str>]) -> bool {
624        // TODO (maz): This is currently a basic check that doesn't take into account type changes.
625        // At some point, we will need to move this to SqlServerTableDesc and expand it.
626        let mut words = self.ddl_command.split_ascii_whitespace();
627        match (
628            words.next().map(str::to_ascii_lowercase).as_deref(),
629            words.next().map(str::to_ascii_lowercase).as_deref(),
630        ) {
631            (Some("alter"), Some("table")) => {
632                let mut peekable = words.peekable();
633                let mut compatible = true;
634                while compatible && let Some(token) = peekable.next() {
635                    compatible = match token.to_ascii_lowercase().as_str() {
636                        "alter" | "drop" => {
637                            let target = peekable.next();
638                            match target {
639                                // Targeting a column
640                                Some(t) if t.eq_ignore_ascii_case("column") => {
641                                    let mut all_excluded = true;
642                                    while let Some(tok) = peekable.next() {
643                                        // The column name(s) can be preceeded by the pair of keywords "IF EXISTS", so we want to skip those.
644                                        match tok.to_ascii_lowercase().as_str() {
645                                            "if" | "exists" | "," | "column" => continue,
646                                            col_str => {
647                                                // If any column is in the included list, then it is not okay to alter/drop it
648                                                // The col_str token may be a comma-separated list of columns as whitespace is not required
649                                                // between column names in SQL Server DDL.
650                                                if !col_str.trim_matches(',').split(',').all(
651                                                    |col_name| {
652                                                        !included_columns.iter().any(|included| {
653                                                            included.eq_ignore_ascii_case(
654                                                                col_name.trim_matches(
655                                                                    ['[', ']', '"'].as_ref(),
656                                                                ),
657                                                            )
658                                                        })
659                                                    },
660                                                ) {
661                                                    all_excluded = false;
662                                                    break;
663                                                }
664                                                // If this is the only/last column, then we can break out of the while loop.
665                                                // Check if this string has no trailing comma, and if not, peek to see if the next token
666                                                // contains a leading comma.
667                                                if !col_str.ends_with(",") {
668                                                    match peekable.peek() {
669                                                        Some(x) if x.starts_with(",") => continue,
670                                                        _ => break,
671                                                    }
672                                                }
673                                            }
674                                        };
675                                    }
676                                    all_excluded
677                                }
678                                // No target token after "alter" or "drop"
679                                None => false,
680                                // Other targets are considered compatible
681                                _ => true,
682                            }
683                        }
684                        _ => true,
685                    }
686                }
687                compatible
688            }
689            _ => true,
690        }
691    }
692}
693
694/// Returns DDL changes made to the source table for the given capture instance.  This follows the
695/// same convention as `cdc.fn_cdc_get_all_changes_<capture_instance>`, in that the range is
696/// inclusive, i.e. `[from_lsn, to_lsn]`. The events are returned in ascending order of
697/// LSN.
698pub async fn get_ddl_history(
699    client: &mut Client,
700    capture_instance: &str,
701    from_lsn: &Lsn,
702    to_lsn: &Lsn,
703) -> Result<BTreeMap<SqlServerQualifiedTableName, Vec<DDLEvent>>, SqlServerError> {
704    // We query the ddl_history table instead of using the stored procedure as there doesn't
705    // appear to be a way to apply filters or projections against output of the stored procedure
706    // without an intermediate table.
707    static DDL_HISTORY_QUERY: &str = "SELECT \
708                s.name AS schema_name, \
709                t.name AS table_name, \
710                dh.ddl_lsn, \
711                dh.ddl_command
712            FROM \
713                cdc.change_tables ct \
714            JOIN cdc.ddl_history dh ON dh.object_id = ct.object_id \
715            JOIN sys.tables t ON t.object_id = dh.source_object_id \
716            JOIN sys.schemas s ON s.schema_id = t.schema_id \
717            WHERE \
718                ct.capture_instance = @P1 \
719                AND dh.ddl_lsn >= @P2 \
720                AND dh.ddl_lsn <= @P3 \
721            ORDER BY ddl_lsn;";
722
723    let result = client
724        .query(
725            DDL_HISTORY_QUERY,
726            &[
727                &capture_instance,
728                &from_lsn.as_bytes().as_slice(),
729                &to_lsn.as_bytes().as_slice(),
730            ],
731        )
732        .await?;
733
734    // SQL server doesn't support array types, and using string_agg to collect LSN
735    // would require more parsing, so we opt for a BTreeMap to accumulate the results.
736    let mut collector: BTreeMap<_, Vec<_>> = BTreeMap::new();
737    for row in result.iter() {
738        let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
739        let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
740        let lsn: &[u8] = get_value::<&[u8]>(row, "ddl_lsn")?;
741        let ddl_command: Arc<str> = get_value::<&str>(row, "ddl_command")?.into();
742
743        let qualified_table_name = SqlServerQualifiedTableName {
744            schema_name,
745            table_name,
746        };
747        let lsn = Lsn::try_from(lsn).map_err(|lsn_err| SqlServerError::InvalidData {
748            column_name: "ddl_lsn".to_string(),
749            error: lsn_err,
750        })?;
751
752        collector
753            .entry(qualified_table_name)
754            .or_default()
755            .push(DDLEvent { lsn, ddl_command });
756    }
757
758    Ok(collector)
759}
760
761/// Ensure the `SNAPSHOT` transaction isolation level is enabled for the
762/// database the provided `client` is currently connected to.
763///
764/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql?view=sql-server-ver16>
765pub async fn ensure_snapshot_isolation_enabled(client: &mut Client) -> Result<(), SqlServerError> {
766    static SNAPSHOT_ISOLATION_QUERY: &str =
767        "SELECT snapshot_isolation_state FROM sys.databases WHERE database_id = DB_ID();";
768    let result = client.simple_query(SNAPSHOT_ISOLATION_QUERY).await?;
769
770    check_system_result(&result, "snapshot isolation".to_string(), 1u8)?;
771    Ok(())
772}
773
774/// Ensure the SQL Server Agent is running.
775///
776/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-dynamic-management-views/sys-dm-server-services-transact-sql?view=azuresqldb-current&viewFallbackFrom=sql-server-ver17>
777pub async fn ensure_sql_server_agent_running(client: &mut Client) -> Result<(), SqlServerError> {
778    static AGENT_STATUS_QUERY: &str = "SELECT status_desc FROM sys.dm_server_services WHERE servicename LIKE 'SQL Server Agent%';";
779    let result = client.simple_query(AGENT_STATUS_QUERY).await?;
780
781    check_system_result(&result, "SQL Server Agent status".to_string(), "Running")?;
782    Ok(())
783}
784
785pub async fn get_tables(client: &mut Client) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
786    let result = client
787        .simple_query(&format!("{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY};"))
788        .await?;
789
790    let tables = deserialize_table_columns_to_raw_tables(&result)?;
791
792    Ok(tables)
793}
794
795// Helper function to retrieve value from a row.
796fn get_value<'a, T: tiberius::FromSql<'a>>(
797    row: &'a tiberius::Row,
798    name: &'static str,
799) -> Result<T, SqlServerError> {
800    row.try_get(name)?
801        .ok_or(SqlServerError::MissingColumn(name))
802}
803fn deserialize_table_columns_to_raw_tables(
804    rows: &[tiberius::Row],
805) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
806    // Group our columns by (schema, name).
807    let mut tables = BTreeMap::default();
808    for row in rows {
809        let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
810        let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
811        let capture_instance: Arc<str> = get_value::<&str>(row, "capture_instance")?.into();
812        let capture_instance_create_date: NaiveDateTime =
813            get_value::<NaiveDateTime>(row, "capture_instance_create_date")?;
814
815        let column_name = get_value::<&str>(row, "col_name")?.into();
816        let column = SqlServerColumnRaw {
817            name: Arc::clone(&column_name),
818            data_type: get_value::<&str>(row, "col_type")?.into(),
819            is_nullable: get_value(row, "col_nullable")?,
820            max_length: get_value(row, "col_max_length")?,
821            precision: get_value(row, "col_precision")?,
822            scale: get_value(row, "col_scale")?,
823            is_computed: get_value(row, "col_is_computed")?,
824        };
825
826        let columns: &mut Vec<_> = tables
827            .entry((
828                Arc::clone(&schema_name),
829                Arc::clone(&table_name),
830                Arc::clone(&capture_instance),
831                capture_instance_create_date,
832            ))
833            .or_default();
834        columns.push(column);
835    }
836
837    let raw_tables = tables
838        .into_iter()
839        .map(
840            |((schema, name, capture_instance, capture_instance_create_date), columns)| {
841                SqlServerTableRaw {
842                    schema_name: schema,
843                    name,
844                    capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
845                        name: capture_instance,
846                        create_date: capture_instance_create_date.into(),
847                    }),
848                    columns: columns.into(),
849                }
850            },
851        )
852        .collect();
853    Ok(raw_tables)
854}
855
856/// Return a [`Stream`] that is the entire snapshot of the specified table.
857pub fn snapshot(
858    client: &mut Client,
859    table: &SqlServerTableRaw,
860) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> {
861    let cols = table
862        .columns
863        .iter()
864        .map(|SqlServerColumnRaw { name, .. }| quote_identifier(name))
865        .join(",");
866    let query = format!(
867        "SELECT {cols} FROM {schema_name}.{table_name};",
868        schema_name = quote_identifier(&table.schema_name),
869        table_name = quote_identifier(&table.name)
870    );
871    client.query_streaming(query, &[])
872}
873
874/// Returns the total number of rows present in the specified table.
875pub async fn snapshot_size(
876    client: &mut Client,
877    schema: &str,
878    table: &str,
879) -> Result<usize, SqlServerError> {
880    let query = format!(
881        "SELECT COUNT(*) FROM {schema_name}.{table_name};",
882        schema_name = quote_identifier(schema),
883        table_name = quote_identifier(table)
884    );
885    let result = client.query(query, &[]).await?;
886
887    match &result[..] {
888        [row] => match row.try_get::<i32, _>(0)? {
889            Some(count @ 0..) => Ok(usize::try_from(count).expect("known to fit")),
890            Some(negative) => Err(SqlServerError::InvalidData {
891                column_name: "count".to_string(),
892                error: format!("found negative count: {negative}"),
893            }),
894            None => Err(SqlServerError::InvalidData {
895                column_name: "count".to_string(),
896                error: "expected a value found NULL".to_string(),
897            }),
898        },
899        other => Err(SqlServerError::InvariantViolated(format!(
900            "expected one row, got {other:?}"
901        ))),
902    }
903}
904
905/// Helper function to parse an expected result from a "system" query.
906fn check_system_result<'a, T>(
907    result: &'a SmallVec<[tiberius::Row; 1]>,
908    name: String,
909    expected: T,
910) -> Result<(), SqlServerError>
911where
912    T: tiberius::FromSql<'a> + Copy + fmt::Debug + fmt::Display + PartialEq,
913{
914    match &result[..] {
915        [row] => {
916            let result: Option<T> = row.try_get(0)?;
917            if result == Some(expected) {
918                Ok(())
919            } else {
920                Err(SqlServerError::InvalidSystemSetting {
921                    name,
922                    expected: expected.to_string(),
923                    actual: format!("{result:?}"),
924                })
925            }
926        }
927        other => Err(SqlServerError::InvariantViolated(format!(
928            "expected 1 row, got {other:?}"
929        ))),
930    }
931}
932
933/// Return a Result that is empty if all tables, columns, and capture instances
934/// have the necessary permissions to and an error if any table, column,
935/// or capture instance does not have the necessary permissions
936/// for tracking changes.
937pub async fn validate_source_privileges(
938    client: &mut Client,
939    capture_instances: impl IntoIterator<Item = &str>,
940) -> Result<(), SqlServerError> {
941    let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
942
943    if params.is_empty() {
944        return Ok(());
945    }
946
947    let params_dyn: SmallVec<[_; 1]> = params
948        .iter()
949        .map(|instance| {
950            let instance: &dyn tiberius::ToSql = instance;
951            instance
952        })
953        .collect();
954
955    let param_indexes = (1..params.len() + 1)
956        .map(|idx| format!("@P{}", idx))
957        .join(", ");
958
959    // NB(ptravers): we rely on HAS_PERMS_BY_NAME to check both table and column permissions.
960    let capture_instance_query = format!(
961            "
962        SELECT
963            SCHEMA_NAME(o.schema_id) + '.' + o.name AS qualified_table_name,
964            ct.capture_instance AS capture_instance,
965            COALESCE(HAS_PERMS_BY_NAME(SCHEMA_NAME(o.schema_id) + '.' + o.name, 'OBJECT', 'SELECT'), 0) AS table_select,
966            COALESCE(HAS_PERMS_BY_NAME('cdc.' + QUOTENAME(ct.capture_instance + '_CT') , 'OBJECT', 'SELECT'), 0) AS capture_table_select
967        FROM cdc.change_tables ct
968        JOIN sys.objects o ON o.object_id = ct.source_object_id
969        WHERE ct.capture_instance IN ({param_indexes});
970            "
971        );
972
973    let rows = client
974        .query(capture_instance_query, &params_dyn[..])
975        .await?;
976
977    let mut capture_instances_without_perms = vec![];
978    let mut tables_without_perms = vec![];
979
980    for row in rows {
981        let table: &str = row
982            .try_get("qualified_table_name")
983            .context("getting table column")?
984            .ok_or_else(|| anyhow::anyhow!("no table column?"))?;
985
986        let capture_instance: &str = row
987            .try_get("capture_instance")
988            .context("getting capture_instance column")?
989            .ok_or_else(|| anyhow::anyhow!("no capture_instance column?"))?;
990
991        let permitted_table: i32 = row
992            .try_get("table_select")
993            .context("getting table_select column")?
994            .ok_or_else(|| anyhow::anyhow!("no table_select column?"))?;
995
996        let permitted_capture_instance: i32 = row
997            .try_get("capture_table_select")
998            .context("getting capture_table_select column")?
999            .ok_or_else(|| anyhow::anyhow!("no capture_table_select column?"))?;
1000
1001        if permitted_table == 0 {
1002            tables_without_perms.push(table.to_string());
1003        }
1004
1005        if permitted_capture_instance == 0 {
1006            capture_instances_without_perms.push(capture_instance.to_string());
1007        }
1008    }
1009
1010    if !capture_instances_without_perms.is_empty() || !tables_without_perms.is_empty() {
1011        return Err(SqlServerError::AuthorizationError {
1012            tables: tables_without_perms.join(", "),
1013            capture_instances: capture_instances_without_perms.join(", "),
1014        });
1015    }
1016
1017    Ok(())
1018}
1019
#[cfg(test)]
mod tests {
    use super::DDLEvent;
    use std::sync::Arc;

    /// Runs `DDLEvent::is_compatible` over a table of DDL commands, each paired with the
    /// result expected when `col3` and `col4` are the included columns.
    #[mz_ore::test]
    fn test_ddl_event_is_compatible() {
        let columns: Vec<Arc<str>> = vec![Arc::from("col3"), Arc::from("col4"), Arc::from("col4")];
        let included_columns: &[Arc<str>] = columns.as_slice();

        // (DDL command, expected compatibility)
        let cases: &[(&str, bool)] = &[
            ("ALTER TABLE my_table ALTER COLUMN col1 INT", true),
            ("ALTER TABLE my_table DROP COLUMN col2", true),
            ("ALTER TABLE my_table ALTER COLUMN col3 INT", false),
            ("ALTER TABLE my_table DROP COLUMN col4 INT", false),
            ("CREATE INDEX idx_my_index ON my_table(col1)", true),
            ("DROP INDEX idx_my_index ON my_table", true),
            ("ALTER TABLE my_table ADD COLUMN col5 INT", true),
            ("ALTER TABLE my_table DROP COLUMN col1, col2", true),
            ("ALTER TABLE my_table DROP COLUMN col3, col2", false),
            ("ALTER TABLE my_table DROP COLUMN col3, col4", false),
            ("ALTER TABLE my_table DROP COLUMN IF EXISTS col1, col2", true),
            ("ALTER TABLE my_table DROP CONSTRAINT constraint_name", true),
            ("ALTER TABLE my_table DROP COLUMN col1,col3", false),
            ("ALTER TABLE my_table DROP COLUMN col1,col2", true),
            ("ALTER TABLE my_table DROP COLUMN col1 ,col2", true),
            ("ALTER TABLE my_table DROP COLUMN col1 , col2", true),
            ("ALTER TABLE my_table DROP COLUMN col1 , col3", false),
            ("ALTER TABLE my_table DROP COLUMN col1 , COLUMN col3", false),
        ];

        for &(ddl_command, expected) in cases {
            let ddl_event = DDLEvent {
                lsn: Default::default(),
                ddl_command: ddl_command.into(),
            };
            let result = ddl_event.is_compatible(included_columns);
            assert_eq!(
                result, expected,
                "DDL command '{}' with included_columns {:?} expected to be {}, got {}",
                ddl_command, included_columns, expected, result
            );
        }
    }
}