mz_sql_server_util/
inspect.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Useful queries to inspect the state of a SQL Server instance.
11
12use anyhow::Context;
13use chrono::NaiveDateTime;
14use futures::Stream;
15use itertools::Itertools;
16use mz_ore::cast::CastFrom;
17use mz_ore::retry::RetryResult;
18use smallvec::SmallVec;
19use std::collections::BTreeMap;
20use std::fmt;
21use std::sync::Arc;
22use std::time::Duration;
23use tiberius::numeric::Numeric;
24
25use crate::cdc::{Lsn, RowFilterOption};
26use crate::desc::{
27    SqlServerCaptureInstanceRaw, SqlServerColumnRaw, SqlServerQualifiedTableName,
28    SqlServerTableConstraintRaw, SqlServerTableRaw,
29};
30use crate::{Client, SqlServerError, quote_identifier};
31
32/// Returns the minimum log sequence number for the specified `capture_instance`.
33///
34/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-min-lsn-transact-sql?view=sql-server-ver16>
35pub async fn get_min_lsn(
36    client: &mut Client,
37    capture_instance: &str,
38) -> Result<Lsn, SqlServerError> {
39    static MIN_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_min_lsn(@P1);";
40    let result = client.query(MIN_LSN_QUERY, &[&capture_instance]).await?;
41
42    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
43    parse_lsn(&result[..1])
44}
45/// Returns the minimum log sequence number for the specified `capture_instance`, retrying
46/// if the log sequence number is not available.
47///
48/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-min-lsn-transact-sql?view=sql-server-ver16>
49pub async fn get_min_lsn_retry(
50    client: &mut Client,
51    capture_instance: &str,
52    max_retry_duration: Duration,
53) -> Result<Lsn, SqlServerError> {
54    let (_client, lsn_result) = mz_ore::retry::Retry::default()
55        .max_duration(max_retry_duration)
56        .retry_async_with_state(client, |_, client| async {
57            let result = crate::inspect::get_min_lsn(client, capture_instance).await;
58            (client, map_null_lsn_to_retry(result))
59        })
60        .await;
61    let Ok(lsn) = lsn_result else {
62        tracing::warn!("database did not report a minimum LSN in time");
63        return lsn_result;
64    };
65    Ok(lsn)
66}
67
68/// Returns the maximum log sequence number for the entire database.
69/// This implementation relies on CDC, which is asynchronous, so may
70/// return an LSN that is less than the maximum LSN of SQL server.
71///
72/// See:
73/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-max-lsn-transact-sql?view=sql-server-ver16>
74/// - <https://groups.google.com/g/debezium/c/47Yg2r166KM/m/lHqtRF2xAQAJ?pli=1>
75pub async fn get_max_lsn(client: &mut Client) -> Result<Lsn, SqlServerError> {
76    static MAX_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_max_lsn();";
77    let result = client.simple_query(MAX_LSN_QUERY).await?;
78
79    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
80    parse_lsn(&result[..1])
81}
82
83/// Retrieves the minumum [`Lsn`] (start_lsn field) from `cdc.change_tables`
84/// for the specified capture instances.
85///
86/// This is based on the `sys.fn_cdc_get_min_lsn` implementation, which has logic
87/// that we want to bypass. Specifically, `sys.fn_cdc_get_min_lsn` returns NULL
88/// if the `start_lsn` in `cdc.change_tables` is less than or equal to the LSN
89/// returned by `sys.fn_cdc_get_max_lsn`.
90///
91/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/cdc-change-tables-transact-sql?view=sql-server-ver16>
92pub async fn get_min_lsns(
93    client: &mut Client,
94    capture_instances: impl IntoIterator<Item = &str>,
95) -> Result<BTreeMap<Arc<str>, Lsn>, SqlServerError> {
96    let capture_instances: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
97    let values: Vec<_> = capture_instances
98        .iter()
99        .map(|ci| {
100            let ci: &dyn tiberius::ToSql = ci;
101            ci
102        })
103        .collect();
104    let args = (0..capture_instances.len())
105        .map(|i| format!("@P{}", i + 1))
106        .collect::<Vec<_>>()
107        .join(",");
108    let stmt = format!(
109        "SELECT capture_instance, start_lsn FROM cdc.change_tables WHERE capture_instance IN ({args});"
110    );
111    let result = client.query(stmt, &values).await?;
112    let min_lsns = result
113        .into_iter()
114        .map(|row| {
115            let capture_instance: Arc<str> = row
116                .try_get::<&str, _>("capture_instance")?
117                .ok_or_else(|| {
118                    SqlServerError::ProgrammingError(
119                        "missing column 'capture_instance'".to_string(),
120                    )
121                })?
122                .into();
123            let start_lsn: &[u8] = row.try_get("start_lsn")?.ok_or_else(|| {
124                SqlServerError::ProgrammingError("missing column 'start_lsn'".to_string())
125            })?;
126            let min_lsn = Lsn::try_from(start_lsn).map_err(|msg| SqlServerError::InvalidData {
127                column_name: "lsn".to_string(),
128                error: format!("Error parsing LSN for {capture_instance}: {msg}"),
129            })?;
130            Ok::<_, SqlServerError>((capture_instance, min_lsn))
131        })
132        .collect::<Result<_, _>>()?;
133
134    Ok(min_lsns)
135}
136
137/// Returns the maximum log sequence number for the entire database, retrying
138/// if the log sequence number is not available. This implementation relies on
139/// CDC, which is asynchronous, so may return an LSN that is less than the
140/// maximum LSN of SQL server.
141///
142/// See:
143/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-max-lsn-transact-sql?view=sql-server-ver16>
144/// - <https://groups.google.com/g/debezium/c/47Yg2r166KM/m/lHqtRF2xAQAJ?pli=1>
145pub async fn get_max_lsn_retry(
146    client: &mut Client,
147    max_retry_duration: Duration,
148) -> Result<Lsn, SqlServerError> {
149    let (_client, lsn_result) = mz_ore::retry::Retry::default()
150        .max_duration(max_retry_duration)
151        .retry_async_with_state(client, |_, client| async {
152            let result = crate::inspect::get_max_lsn(client).await;
153            (client, map_null_lsn_to_retry(result))
154        })
155        .await;
156
157    let Ok(lsn) = lsn_result else {
158        tracing::warn!("database did not report a maximum LSN in time");
159        return lsn_result;
160    };
161    Ok(lsn)
162}
163
164fn map_null_lsn_to_retry<T>(result: Result<T, SqlServerError>) -> RetryResult<T, SqlServerError> {
165    match result {
166        Ok(val) => RetryResult::Ok(val),
167        Err(err @ SqlServerError::NullLsn) => RetryResult::RetryableErr(err),
168        Err(other) => RetryResult::FatalErr(other),
169    }
170}
171
172/// Increments the log sequence number.
173///
174/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-increment-lsn-transact-sql?view=sql-server-ver16>
175pub async fn increment_lsn(client: &mut Client, lsn: Lsn) -> Result<Lsn, SqlServerError> {
176    static INCREMENT_LSN_QUERY: &str = "SELECT sys.fn_cdc_increment_lsn(@P1);";
177    let result = client
178        .query(INCREMENT_LSN_QUERY, &[&lsn.as_bytes().as_slice()])
179        .await?;
180
181    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
182    parse_lsn(&result[..1])
183}
184
185/// Parse an [`Lsn`] in Decimal(25,0) format of the provided [`tiberius::Row`].
186///
187/// Returns an error if the provided slice doesn't have exactly one row.
188pub(crate) fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
189    match row {
190        [r] => {
191            let numeric_lsn = r
192                .try_get::<Numeric, _>(0)?
193                .ok_or_else(|| SqlServerError::NullLsn)?;
194            let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
195                column_name: "lsn".to_string(),
196                error: msg,
197            })?;
198            Ok(lsn)
199        }
200        other => Err(SqlServerError::InvalidData {
201            column_name: "lsn".to_string(),
202            error: format!("expected 1 column, got {other:?}"),
203        }),
204    }
205}
206
207/// Parse an [`Lsn`] from the first column of the provided [`tiberius::Row`].
208///
209/// Returns an error if the provided slice doesn't have exactly one row.
210fn parse_lsn(result: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
211    match result {
212        [row] => {
213            let val = row
214                .try_get::<&[u8], _>(0)?
215                .ok_or_else(|| SqlServerError::NullLsn)?;
216            if val.is_empty() {
217                Err(SqlServerError::NullLsn)
218            } else {
219                let lsn = Lsn::try_from(val).map_err(|msg| SqlServerError::InvalidData {
220                    column_name: "lsn".to_string(),
221                    error: msg,
222                })?;
223                Ok(lsn)
224            }
225        }
226        other => Err(SqlServerError::InvalidData {
227            column_name: "lsn".to_string(),
228            error: format!("expected 1 column, got {other:?}"),
229        }),
230    }
231}
232
233/// Queries the specified capture instance and returns all changes from
234/// `[start_lsn, end_lsn)`, ordered by `start_lsn` in an ascending fashion.
235pub fn get_changes_asc(
236    client: &mut Client,
237    capture_instance: &str,
238    start_lsn: Lsn,
239    end_lsn: Lsn,
240    filter: RowFilterOption,
241) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send {
242    const START_LSN_COLUMN: &str = "__$start_lsn";
243    let query = format!(
244        "SELECT * FROM cdc.{function}(@P1, @P2, N'{filter}') ORDER BY {START_LSN_COLUMN} ASC;",
245        function = quote_identifier(&format!("fn_cdc_get_all_changes_{capture_instance}"))
246    );
247    client.query_streaming(
248        query,
249        &[
250            &start_lsn.as_bytes().as_slice(),
251            &end_lsn.as_bytes().as_slice(),
252        ],
253    )
254}
255
256/// Cleans up the change table associated with the specified `capture_instance` by
257/// deleting `max_deletes` entries with a `start_lsn` less than `low_water_mark`.
258///
259/// Note: At the moment cleanup is kind of "best effort".  If this query succeeds
260/// then at most `max_delete` rows were deleted, but the number of actual rows
261/// deleted is not returned as part of the query. The number of rows _should_ be
262/// present in an informational message (i.e. a Notice) that is returned, but
263/// [`tiberius`] doesn't expose these to us.
264///
265/// TODO(sql_server2): Update [`tiberius`] to return informational messages so we
266/// can determine how many rows got deleted.
267///
268/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sys-sp-cdc-cleanup-change-table-transact-sql?view=sql-server-ver16>.
269pub async fn cleanup_change_table(
270    client: &mut Client,
271    capture_instance: &str,
272    low_water_mark: &Lsn,
273    max_deletes: u32,
274) -> Result<(), SqlServerError> {
275    static GET_LSN_QUERY: &str =
276        "SELECT MAX(start_lsn) FROM cdc.lsn_time_mapping WHERE start_lsn <= @P1";
277    static CLEANUP_QUERY: &str = "
278DECLARE @mz_cleanup_status_bit BIT;
279SET @mz_cleanup_status_bit = 0;
280EXEC sys.sp_cdc_cleanup_change_table
281    @capture_instance = @P1,
282    @low_water_mark = @P2,
283    @threshold = @P3,
284    @fCleanupFailed = @mz_cleanup_status_bit OUTPUT;
285SELECT @mz_cleanup_status_bit;
286    ";
287
288    let max_deletes = i64::cast_from(max_deletes);
289
290    // First we need to get a valid LSN as our low watermark. If we try to cleanup
291    // a change table with an LSN that doesn't exist in the `cdc.lsn_time_mapping`
292    // table we'll get an error code `22964`.
293    let result = client
294        .query(GET_LSN_QUERY, &[&low_water_mark.as_bytes().as_slice()])
295        .await?;
296    let low_water_mark_to_use = match &result[..] {
297        [row] => row
298            .try_get::<&[u8], _>(0)?
299            .ok_or_else(|| SqlServerError::InvalidData {
300                column_name: "mz_cleanup_status_bit".to_string(),
301                error: "expected a bool, found NULL".to_string(),
302            })?,
303        other => Err(SqlServerError::ProgrammingError(format!(
304            "expected one row for low water mark, found {other:?}"
305        )))?,
306    };
307
308    // Once we get a valid LSN that is less than or equal to the provided watermark
309    // we can clean up the specified change table!
310    let result = client
311        .query(
312            CLEANUP_QUERY,
313            &[&capture_instance, &low_water_mark_to_use, &max_deletes],
314        )
315        .await;
316
317    let rows = match result {
318        Ok(rows) => rows,
319        Err(SqlServerError::SqlServer(e)) => {
320            // See these remarks from the SQL Server Documentation.
321            //
322            // <https://learn.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sys-sp-cdc-cleanup-change-table-transact-sql?view=sql-server-ver16#remarks>.
323            let already_cleaned_up = e.code().map(|code| code == 22957).unwrap_or(false);
324
325            if already_cleaned_up {
326                return Ok(());
327            } else {
328                return Err(SqlServerError::SqlServer(e));
329            }
330        }
331        Err(other) => return Err(other),
332    };
333
334    match &rows[..] {
335        [row] => {
336            let failure =
337                row.try_get::<bool, _>(0)?
338                    .ok_or_else(|| SqlServerError::InvalidData {
339                        column_name: "mz_cleanup_status_bit".to_string(),
340                        error: "expected a bool, found NULL".to_string(),
341                    })?;
342
343            if failure {
344                Err(super::cdc::CdcError::CleanupFailed {
345                    capture_instance: capture_instance.to_string(),
346                    low_water_mark: *low_water_mark,
347                })?
348            } else {
349                Ok(())
350            }
351        }
352        other => Err(SqlServerError::ProgrammingError(format!(
353            "expected one status row, found {other:?}"
354        ))),
355    }
356}
357
// Base query that lists every column of every table with CDC (Change Data Capture)
// enabled.
//
// The rows carry the metadata needed to build a [`SqlServerTableRaw`]. The query has no
// `WHERE` clause and no terminating `;`, so callers append their own filter and/or
// terminator.
//
// System tables joined:
// - sys.tables: source tables in the database
// - sys.schemas: schema of each source table
// - sys.columns: column definitions
// - sys.types: data type name of each column
// - cdc.change_tables: links capture instances to their source tables
//
// Columns returned per row:
// - Table identification: schema_name, table_name, capture_instance,
//   capture_instance_create_date
// - Column metadata: col_name, col_type, col_nullable, col_max_length, col_precision,
//   col_scale, col_is_computed
static GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY: &str = "
SELECT
    s.name as schema_name,
    t.name as table_name,
    ch.capture_instance as capture_instance,
    ch.create_date as capture_instance_create_date,
    c.name as col_name,
    ty.name as col_type,
    c.is_nullable as col_nullable,
    c.max_length as col_max_length,
    c.precision as col_precision,
    c.scale as col_scale,
    c.is_computed as col_is_computed
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.columns c ON t.object_id = c.object_id
JOIN sys.types ty ON c.user_type_id = ty.user_type_id
JOIN cdc.change_tables ch ON t.object_id = ch.source_object_id
";
391
392/// Returns the table metadata for the tables that are tracked by the specified `capture_instance`s.
393pub async fn get_tables_for_capture_instance<'a>(
394    client: &mut Client,
395    capture_instances: impl IntoIterator<Item = &str>,
396) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
397    // SQL Server does not have support for array types, so we need to manually construct
398    // the parameterized query.
399    let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
400    // If there are no tables to check for just return an empty list.
401    if params.is_empty() {
402        return Ok(Vec::default());
403    }
404
405    // TODO(sql_server3): Remove this redundant collection.
406    #[allow(clippy::as_conversions)]
407    let params_dyn: SmallVec<[_; 1]> = params
408        .iter()
409        .map(|instance| instance as &dyn tiberius::ToSql)
410        .collect();
411    let param_indexes = params
412        .iter()
413        .enumerate()
414        // Params are 1-based indexed.
415        .map(|(idx, _)| format!("@P{}", idx + 1))
416        .join(", ");
417
418    let table_for_capture_instance_query = format!(
419        "{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY} WHERE ch.capture_instance IN ({param_indexes});"
420    );
421
422    let result = client
423        .query(&table_for_capture_instance_query, &params_dyn[..])
424        .await?;
425
426    let tables = deserialize_table_columns_to_raw_tables(&result)?;
427
428    Ok(tables)
429}
430
431/// Retrieves column metdata from the CDC table maintained by the provided capture instance. The
432/// resulting column information collection is similar to the information collected for the
433/// upstream table, with the exclusion of nullability and primary key constraints, which contain
434/// static values for CDC columns. CDC table schema is automatically generated and does not attempt
435/// to enforce the same constraints on the data as the upstream table.
436pub async fn get_cdc_table_columns(
437    client: &mut Client,
438    capture_instance: &str,
439) -> Result<BTreeMap<Arc<str>, SqlServerColumnRaw>, SqlServerError> {
440    static CDC_COLUMNS_QUERY: &str = "SELECT \
441        c.name AS col_name, \
442        t.name AS col_type, \
443        c.max_length AS col_max_length, \
444        c.precision AS col_precision, \
445        c.scale AS col_scale, \
446        c.is_computed as col_is_computed \
447    FROM \
448        sys.columns AS c \
449    JOIN sys.types AS t ON c.system_type_id = t.system_type_id AND c.user_type_id = t.user_type_id \
450    WHERE \
451        c.object_id = OBJECT_ID(@P1) AND c.name NOT LIKE '__$%' \
452    ORDER BY c.column_id;";
453    // Strings passed into OBJECT_ID must be escaped
454    let cdc_table_name = format!(
455        "cdc.{table_name}",
456        table_name = quote_identifier(&format!("{capture_instance}_CT"))
457    );
458    let result = client.query(CDC_COLUMNS_QUERY, &[&cdc_table_name]).await?;
459    let mut columns = BTreeMap::new();
460    for row in result.iter() {
461        let column_name: Arc<str> = get_value::<&str>(row, "col_name")?.into();
462        // Reusing this struct even though some of the fields aren't needed because it simplifies
463        // comparison with the upstream table metadata
464        let column = SqlServerColumnRaw {
465            name: Arc::clone(&column_name),
466            data_type: get_value::<&str>(row, "col_type")?.into(),
467            is_nullable: true,
468            max_length: get_value(row, "col_max_length")?,
469            precision: get_value(row, "col_precision")?,
470            scale: get_value(row, "col_scale")?,
471            is_computed: get_value(row, "col_is_computed")?,
472        };
473        columns.insert(column_name, column);
474    }
475    Ok(columns)
476}
477
478/// Retrieve primary key and unique constraints for the given tables.  Tables should be provided as
479/// an interator of tuples, where each tuple contains is `(schema_name, table_name)`.
480pub async fn get_constraints_for_tables(
481    client: &mut Client,
482    schema_table_list: impl Iterator<Item = &(Arc<str>, Arc<str>)>,
483) -> Result<BTreeMap<(Arc<str>, Arc<str>), Vec<SqlServerTableConstraintRaw>>, SqlServerError> {
484    let qualified_table_names: Vec<_> = schema_table_list
485        .map(|(schema, table)| {
486            format!(
487                "{quoted_schema}.{quoted_table}",
488                quoted_schema = quote_identifier(schema),
489                quoted_table = quote_identifier(table)
490            )
491        })
492        .collect();
493
494    if qualified_table_names.is_empty() {
495        return Ok(Default::default());
496    }
497
498    let params = (1..qualified_table_names.len() + 1)
499        .map(|idx| format!("@P{}", idx))
500        .join(", ");
501
502    // Because we don't have an object idenfifier for the table(s), this query concatenates the
503    // schema and table name to create a single identifier for the query rather than compose a
504    // complex set of OR conditions for each schema + set of tables in the schema.
505    //
506    // This query may perform poorly due to the condition relying on a concatenated value. We may get
507    // better performance by adding a query constraint on the table names, but it isn't clear at
508    // this time if that is needed.
509    let query = format!(
510        "SELECT \
511        tc.table_schema, \
512        tc.table_name, \
513        ccu.column_name,  \
514        tc.constraint_name, \
515        tc.constraint_type \
516    FROM information_schema.table_constraints tc \
517    JOIN information_schema.constraint_column_usage ccu \
518        ON ccu.constraint_schema = tc.constraint_schema \
519        AND ccu.constraint_name = tc.constraint_name \
520    WHERE
521        QUOTENAME(tc.table_schema) + '.' + QUOTENAME(tc.table_name) IN ({params})
522        AND tc.constraint_type in ('PRIMARY KEY', 'UNIQUE')"
523    );
524
525    let query_params: Vec<_> = qualified_table_names
526        .iter()
527        .map(|qualified_name| {
528            let name: &dyn tiberius::ToSql = qualified_name;
529            name
530        })
531        .collect();
532
533    tracing::debug!("query = {query} params = {qualified_table_names:?}");
534    let result = client.query(query, &query_params).await?;
535
536    let mut contraints_by_table: BTreeMap<_, BTreeMap<_, Vec<_>>> = BTreeMap::new();
537    for row in result {
538        let schema_name: Arc<str> = get_value::<&str>(&row, "table_schema")?.into();
539        let table_name: Arc<str> = get_value::<&str>(&row, "table_name")?.into();
540        let column_name = get_value::<&str>(&row, "column_name")?.into();
541        let constraint_name = get_value::<&str>(&row, "constraint_name")?.into();
542        let constraint_type = get_value::<&str>(&row, "constraint_type")?.into();
543
544        contraints_by_table
545            .entry((Arc::clone(&schema_name), Arc::clone(&table_name)))
546            .or_default()
547            .entry((constraint_name, constraint_type))
548            .or_default()
549            .push(column_name);
550    }
551    Ok(contraints_by_table
552        .into_iter()
553        .inspect(|((schema_name, table_name), constraints)| {
554            tracing::debug!("table {schema_name}.{table_name} constraints: {constraints:?}")
555        })
556        .map(|(qualified_name, constraints)| {
557            (
558                qualified_name,
559                constraints
560                    .into_iter()
561                    .map(|((constraint_name, constraint_type), columns)| {
562                        SqlServerTableConstraintRaw {
563                            constraint_name,
564                            constraint_type,
565                            columns,
566                        }
567                    })
568                    .collect(),
569            )
570        })
571        .collect())
572}
573
574/// Ensure change data capture (CDC) is enabled for the database the provided
575/// `client` is currently connected to.
576///
577/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/track-changes/enable-and-disable-change-data-capture-sql-server?view=sql-server-ver16>
578pub async fn ensure_database_cdc_enabled(client: &mut Client) -> Result<(), SqlServerError> {
579    static DATABASE_CDC_ENABLED_QUERY: &str =
580        "SELECT is_cdc_enabled FROM sys.databases WHERE database_id = DB_ID();";
581    let result = client.simple_query(DATABASE_CDC_ENABLED_QUERY).await?;
582
583    check_system_result(&result, "database CDC".to_string(), true)?;
584    Ok(())
585}
586
587/// Retrieves the largest `restore_history_id` from SQL Server for the current database.  The
588/// `restore_history_id` column is of type `IDENTITY(1,1)` based on `EXEC sp_help restorehistory`.
589/// We expect it to start at 1 and be incremented by 1, with possible gaps in values.
590/// See:
591/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/restorehistory-transact-sql?view=sql-server-ver17>
592/// - <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql-identity-property?view=sql-server-ver17>
593pub async fn get_latest_restore_history_id(
594    client: &mut Client,
595) -> Result<Option<i32>, SqlServerError> {
596    static LATEST_RESTORE_ID_QUERY: &str = "SELECT TOP 1 restore_history_id \
597        FROM msdb.dbo.restorehistory \
598        WHERE destination_database_name = DB_NAME() \
599        ORDER BY restore_history_id DESC;";
600    let result = client.simple_query(LATEST_RESTORE_ID_QUERY).await?;
601
602    match &result[..] {
603        [] => Ok(None),
604        [row] => Ok(row.try_get::<i32, _>(0)?),
605        other => Err(SqlServerError::InvariantViolated(format!(
606            "expected one row, got {other:?}"
607        ))),
608    }
609}
610
611/// A DDL event collected from the `cdc.ddl_history` table.
612#[derive(Debug)]
613pub struct DDLEvent {
614    pub lsn: Lsn,
615    pub ddl_command: Arc<str>,
616}
617
618impl DDLEvent {
619    /// Returns true if the DDL event is a compatible change, or false if it is not.
620    /// This performs a naive parsing of the DDL command looking for modification of columns
621    ///  1. ALTER TABLE .. ALTER COLUMN
622    ///  2. ALTER TABLE .. DROP COLUMN
623    ///
624    /// See <https://learn.microsoft.com/en-us/sql/t-sql/statements/alter-table-transact-sql?view=sql-server-ver17>
625    pub fn is_compatible(&self) -> bool {
626        // TODO (maz): This is currently a basic check that doesn't take into account type changes.
627        // At some point, we will need to move this to SqlServerTableDesc and expand it.
628        let mut words = self.ddl_command.split_ascii_whitespace();
629        match (
630            words.next().map(str::to_ascii_lowercase).as_deref(),
631            words.next().map(str::to_ascii_lowercase).as_deref(),
632        ) {
633            (Some("alter"), Some("table")) => {
634                let mut peekable = words.peekable();
635                let mut compatible = true;
636                while compatible && let Some(token) = peekable.next() {
637                    compatible = match token.to_ascii_lowercase().as_str() {
638                        "alter" | "drop" => peekable
639                            .peek()
640                            .is_some_and(|next_tok| !next_tok.eq_ignore_ascii_case("column")),
641                        _ => true,
642                    }
643                }
644                compatible
645            }
646            _ => true,
647        }
648    }
649}
650
651/// Returns DDL changes made to the source table for the given capture instance.  This follows the
652/// same convention as `cdc.fn_cdc_get_all_changes_<capture_instance>`, in that the range is
653/// inclusive, i.e. `[from_lsn, to_lsn]`. The events are returned in ascending order of
654/// LSN.
655pub async fn get_ddl_history(
656    client: &mut Client,
657    capture_instance: &str,
658    from_lsn: &Lsn,
659    to_lsn: &Lsn,
660) -> Result<BTreeMap<SqlServerQualifiedTableName, Vec<DDLEvent>>, SqlServerError> {
661    // We query the ddl_history table instead of using the stored procedure as there doesn't
662    // appear to be a way to apply filters or projections against output of the stored procedure
663    // without an intermediate table.
664    static DDL_HISTORY_QUERY: &str = "SELECT \
665                s.name AS schema_name, \
666                t.name AS table_name, \
667                dh.ddl_lsn, \
668                dh.ddl_command
669            FROM \
670                cdc.change_tables ct \
671            JOIN cdc.ddl_history dh ON dh.object_id = ct.object_id \
672            JOIN sys.tables t ON t.object_id = dh.source_object_id \
673            JOIN sys.schemas s ON s.schema_id = t.schema_id \
674            WHERE \
675                ct.capture_instance = @P1 \
676                AND dh.ddl_lsn >= @P2 \
677                AND dh.ddl_lsn <= @P3 \
678            ORDER BY ddl_lsn;";
679
680    let result = client
681        .query(
682            DDL_HISTORY_QUERY,
683            &[
684                &capture_instance,
685                &from_lsn.as_bytes().as_slice(),
686                &to_lsn.as_bytes().as_slice(),
687            ],
688        )
689        .await?;
690
691    // SQL server doesn't support array types, and using string_agg to collect LSN
692    // would require more parsing, so we opt for a BTreeMap to accumulate the results.
693    let mut collector: BTreeMap<_, Vec<_>> = BTreeMap::new();
694    for row in result.iter() {
695        let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
696        let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
697        let lsn: &[u8] = get_value::<&[u8]>(row, "ddl_lsn")?;
698        let ddl_command: Arc<str> = get_value::<&str>(row, "ddl_command")?.into();
699
700        let qualified_table_name = SqlServerQualifiedTableName {
701            schema_name,
702            table_name,
703        };
704        let lsn = Lsn::try_from(lsn).map_err(|lsn_err| SqlServerError::InvalidData {
705            column_name: "ddl_lsn".to_string(),
706            error: lsn_err,
707        })?;
708
709        collector
710            .entry(qualified_table_name)
711            .or_default()
712            .push(DDLEvent { lsn, ddl_command });
713    }
714
715    Ok(collector)
716}
717
718/// Ensure the `SNAPSHOT` transaction isolation level is enabled for the
719/// database the provided `client` is currently connected to.
720///
721/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql?view=sql-server-ver16>
722pub async fn ensure_snapshot_isolation_enabled(client: &mut Client) -> Result<(), SqlServerError> {
723    static SNAPSHOT_ISOLATION_QUERY: &str =
724        "SELECT snapshot_isolation_state FROM sys.databases WHERE database_id = DB_ID();";
725    let result = client.simple_query(SNAPSHOT_ISOLATION_QUERY).await?;
726
727    check_system_result(&result, "snapshot isolation".to_string(), 1u8)?;
728    Ok(())
729}
730
731/// Ensure the SQL Server Agent is running.
732///
733/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-dynamic-management-views/sys-dm-server-services-transact-sql?view=azuresqldb-current&viewFallbackFrom=sql-server-ver17>
734pub async fn ensure_sql_server_agent_running(client: &mut Client) -> Result<(), SqlServerError> {
735    static AGENT_STATUS_QUERY: &str = "SELECT status_desc FROM sys.dm_server_services WHERE servicename LIKE 'SQL Server Agent%';";
736    let result = client.simple_query(AGENT_STATUS_QUERY).await?;
737
738    check_system_result(&result, "SQL Server Agent status".to_string(), "Running")?;
739    Ok(())
740}
741
742pub async fn get_tables(client: &mut Client) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
743    let result = client
744        .simple_query(&format!("{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY};"))
745        .await?;
746
747    let tables = deserialize_table_columns_to_raw_tables(&result)?;
748
749    Ok(tables)
750}
751
752// Helper function to retrieve value from a row.
753fn get_value<'a, T: tiberius::FromSql<'a>>(
754    row: &'a tiberius::Row,
755    name: &'static str,
756) -> Result<T, SqlServerError> {
757    row.try_get(name)?
758        .ok_or(SqlServerError::MissingColumn(name))
759}
760fn deserialize_table_columns_to_raw_tables(
761    rows: &[tiberius::Row],
762) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
763    // Group our columns by (schema, name).
764    let mut tables = BTreeMap::default();
765    for row in rows {
766        let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
767        let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
768        let capture_instance: Arc<str> = get_value::<&str>(row, "capture_instance")?.into();
769        let capture_instance_create_date: NaiveDateTime =
770            get_value::<NaiveDateTime>(row, "capture_instance_create_date")?;
771
772        let column_name = get_value::<&str>(row, "col_name")?.into();
773        let column = SqlServerColumnRaw {
774            name: Arc::clone(&column_name),
775            data_type: get_value::<&str>(row, "col_type")?.into(),
776            is_nullable: get_value(row, "col_nullable")?,
777            max_length: get_value(row, "col_max_length")?,
778            precision: get_value(row, "col_precision")?,
779            scale: get_value(row, "col_scale")?,
780            is_computed: get_value(row, "col_is_computed")?,
781        };
782
783        let columns: &mut Vec<_> = tables
784            .entry((
785                Arc::clone(&schema_name),
786                Arc::clone(&table_name),
787                Arc::clone(&capture_instance),
788                capture_instance_create_date,
789            ))
790            .or_default();
791        columns.push(column);
792    }
793
794    let raw_tables = tables
795        .into_iter()
796        .map(
797            |((schema, name, capture_instance, capture_instance_create_date), columns)| {
798                SqlServerTableRaw {
799                    schema_name: schema,
800                    name,
801                    capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
802                        name: capture_instance,
803                        create_date: capture_instance_create_date.into(),
804                    }),
805                    columns: columns.into(),
806                }
807            },
808        )
809        .collect();
810    Ok(raw_tables)
811}
812
813/// Return a [`Stream`] that is the entire snapshot of the specified table.
814pub fn snapshot(
815    client: &mut Client,
816    table: &SqlServerTableRaw,
817) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> {
818    let cols = table
819        .columns
820        .iter()
821        .map(|SqlServerColumnRaw { name, .. }| quote_identifier(name))
822        .join(",");
823    let query = format!(
824        "SELECT {cols} FROM {schema_name}.{table_name};",
825        schema_name = quote_identifier(&table.schema_name),
826        table_name = quote_identifier(&table.name)
827    );
828    client.query_streaming(query, &[])
829}
830
831/// Returns the total number of rows present in the specified table.
832pub async fn snapshot_size(
833    client: &mut Client,
834    schema: &str,
835    table: &str,
836) -> Result<usize, SqlServerError> {
837    let query = format!(
838        "SELECT COUNT(*) FROM {schema_name}.{table_name};",
839        schema_name = quote_identifier(schema),
840        table_name = quote_identifier(table)
841    );
842    let result = client.query(query, &[]).await?;
843
844    match &result[..] {
845        [row] => match row.try_get::<i32, _>(0)? {
846            Some(count @ 0..) => Ok(usize::try_from(count).expect("known to fit")),
847            Some(negative) => Err(SqlServerError::InvalidData {
848                column_name: "count".to_string(),
849                error: format!("found negative count: {negative}"),
850            }),
851            None => Err(SqlServerError::InvalidData {
852                column_name: "count".to_string(),
853                error: "expected a value found NULL".to_string(),
854            }),
855        },
856        other => Err(SqlServerError::InvariantViolated(format!(
857            "expected one row, got {other:?}"
858        ))),
859    }
860}
861
862/// Helper function to parse an expected result from a "system" query.
863fn check_system_result<'a, T>(
864    result: &'a SmallVec<[tiberius::Row; 1]>,
865    name: String,
866    expected: T,
867) -> Result<(), SqlServerError>
868where
869    T: tiberius::FromSql<'a> + Copy + fmt::Debug + fmt::Display + PartialEq,
870{
871    match &result[..] {
872        [row] => {
873            let result: Option<T> = row.try_get(0)?;
874            if result == Some(expected) {
875                Ok(())
876            } else {
877                Err(SqlServerError::InvalidSystemSetting {
878                    name,
879                    expected: expected.to_string(),
880                    actual: format!("{result:?}"),
881                })
882            }
883        }
884        other => Err(SqlServerError::InvariantViolated(format!(
885            "expected 1 row, got {other:?}"
886        ))),
887    }
888}
889
890/// Return a Result that is empty if all tables, columns, and capture instances
891/// have the necessary permissions to and an error if any table, column,
892/// or capture instance does not have the necessary permissions
893/// for tracking changes.
894pub async fn validate_source_privileges<'a>(
895    client: &mut Client,
896    capture_instances: impl IntoIterator<Item = &str>,
897) -> Result<(), SqlServerError> {
898    let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
899
900    if params.is_empty() {
901        return Ok(());
902    }
903
904    let params_dyn: SmallVec<[_; 1]> = params
905        .iter()
906        .map(|instance| {
907            let instance: &dyn tiberius::ToSql = instance;
908            instance
909        })
910        .collect();
911
912    let param_indexes = (1..params.len() + 1)
913        .map(|idx| format!("@P{}", idx))
914        .join(", ");
915
916    // NB(ptravers): we rely on HAS_PERMS_BY_NAME to check both table and column permissions.
917    let capture_instance_query = format!(
918            "
919        SELECT
920            SCHEMA_NAME(o.schema_id) + '.' + o.name AS qualified_table_name,
921            ct.capture_instance AS capture_instance,
922            COALESCE(HAS_PERMS_BY_NAME(SCHEMA_NAME(o.schema_id) + '.' + o.name, 'OBJECT', 'SELECT'), 0) AS table_select,
923            COALESCE(HAS_PERMS_BY_NAME('cdc.' + QUOTENAME(ct.capture_instance + '_CT') , 'OBJECT', 'SELECT'), 0) AS capture_table_select
924        FROM cdc.change_tables ct
925        JOIN sys.objects o ON o.object_id = ct.source_object_id
926        WHERE ct.capture_instance IN ({param_indexes});
927            "
928        );
929
930    let rows = client
931        .query(capture_instance_query, &params_dyn[..])
932        .await?;
933
934    let mut capture_instances_without_perms = vec![];
935    let mut tables_without_perms = vec![];
936
937    for row in rows {
938        let table: &str = row
939            .try_get("qualified_table_name")
940            .context("getting table column")?
941            .ok_or_else(|| anyhow::anyhow!("no table column?"))?;
942
943        let capture_instance: &str = row
944            .try_get("capture_instance")
945            .context("getting capture_instance column")?
946            .ok_or_else(|| anyhow::anyhow!("no capture_instance column?"))?;
947
948        let permitted_table: i32 = row
949            .try_get("table_select")
950            .context("getting table_select column")?
951            .ok_or_else(|| anyhow::anyhow!("no table_select column?"))?;
952
953        let permitted_capture_instance: i32 = row
954            .try_get("capture_table_select")
955            .context("getting capture_table_select column")?
956            .ok_or_else(|| anyhow::anyhow!("no capture_table_select column?"))?;
957
958        if permitted_table == 0 {
959            tables_without_perms.push(table.to_string());
960        }
961
962        if permitted_capture_instance == 0 {
963            capture_instances_without_perms.push(capture_instance.to_string());
964        }
965    }
966
967    if !capture_instances_without_perms.is_empty() || !tables_without_perms.is_empty() {
968        return Err(SqlServerError::AuthorizationError {
969            tables: tables_without_perms.join(", "),
970            capture_instances: capture_instances_without_perms.join(", "),
971        });
972    }
973
974    Ok(())
975}