mz_sql_server_util/
inspect.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Useful queries to inspect the state of a SQL Server instance.
11
12use anyhow::Context;
13use chrono::NaiveDateTime;
14use futures::Stream;
15use itertools::Itertools;
16use mz_ore::cast::CastFrom;
17use mz_ore::retry::RetryResult;
18use smallvec::SmallVec;
19use std::collections::BTreeMap;
20use std::fmt;
21use std::sync::Arc;
22use std::time::Duration;
23use tiberius::numeric::Numeric;
24
25use crate::cdc::{Lsn, RowFilterOption};
26use crate::desc::{
27    SqlServerCaptureInstanceRaw, SqlServerColumnRaw, SqlServerQualifiedTableName, SqlServerTableRaw,
28};
29use crate::{Client, SqlServerError};
30
31/// Returns the minimum log sequence number for the specified `capture_instance`.
32///
33/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-min-lsn-transact-sql?view=sql-server-ver16>
34pub async fn get_min_lsn(
35    client: &mut Client,
36    capture_instance: &str,
37) -> Result<Lsn, SqlServerError> {
38    static MIN_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_min_lsn(@P1);";
39    let result = client.query(MIN_LSN_QUERY, &[&capture_instance]).await?;
40
41    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
42    parse_lsn(&result[..1])
43}
44/// Returns the minimum log sequence number for the specified `capture_instance`, retrying
45/// if the log sequence number is not available.
46///
47/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-min-lsn-transact-sql?view=sql-server-ver16>
48pub async fn get_min_lsn_retry(
49    client: &mut Client,
50    capture_instance: &str,
51    max_retry_duration: Duration,
52) -> Result<Lsn, SqlServerError> {
53    let (_client, lsn_result) = mz_ore::retry::Retry::default()
54        .max_duration(max_retry_duration)
55        .retry_async_with_state(client, |_, client| async {
56            let result = crate::inspect::get_min_lsn(client, capture_instance).await;
57            (client, map_null_lsn_to_retry(result))
58        })
59        .await;
60    let Ok(lsn) = lsn_result else {
61        tracing::warn!("database did not report a minimum LSN in time");
62        return lsn_result;
63    };
64    Ok(lsn)
65}
66
67/// Returns the maximum log sequence number for the entire database.
68/// This implementation relies on CDC, which is asynchronous, so may
69/// return an LSN that is less than the maximum LSN of SQL server.
70///
71/// See:
72/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-max-lsn-transact-sql?view=sql-server-ver16>
73/// - <https://groups.google.com/g/debezium/c/47Yg2r166KM/m/lHqtRF2xAQAJ?pli=1>
74pub async fn get_max_lsn(client: &mut Client) -> Result<Lsn, SqlServerError> {
75    static MAX_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_max_lsn();";
76    let result = client.simple_query(MAX_LSN_QUERY).await?;
77
78    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
79    parse_lsn(&result[..1])
80}
81
82/// Retrieves the minumum [`Lsn`] (start_lsn field) from `cdc.change_tables`
83/// for the specified capture instances.
84///
85/// This is based on the `sys.fn_cdc_get_min_lsn` implementation, which has logic
86/// that we want to bypass. Specifically, `sys.fn_cdc_get_min_lsn` returns NULL
87/// if the `start_lsn` in `cdc.change_tables` is less than or equal to the LSN
88/// returned by `sys.fn_cdc_get_max_lsn`.
89///
90/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/cdc-change-tables-transact-sql?view=sql-server-ver16>
91pub async fn get_min_lsns(
92    client: &mut Client,
93    capture_instances: impl IntoIterator<Item = &str>,
94) -> Result<BTreeMap<Arc<str>, Lsn>, SqlServerError> {
95    let capture_instances: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
96    let values: Vec<_> = capture_instances
97        .iter()
98        .map(|ci| {
99            let ci: &dyn tiberius::ToSql = ci;
100            ci
101        })
102        .collect();
103    let args = (0..capture_instances.len())
104        .map(|i| format!("@P{}", i + 1))
105        .collect::<Vec<_>>()
106        .join(",");
107    let stmt = format!(
108        "SELECT capture_instance, start_lsn FROM cdc.change_tables WHERE capture_instance IN ({args});"
109    );
110    let result = client.query(stmt, &values).await?;
111    let min_lsns = result
112        .into_iter()
113        .map(|row| {
114            let capture_instance: Arc<str> = row
115                .try_get::<&str, _>("capture_instance")?
116                .ok_or_else(|| {
117                    SqlServerError::ProgrammingError(
118                        "missing column 'capture_instance'".to_string(),
119                    )
120                })?
121                .into();
122            let start_lsn: &[u8] = row.try_get("start_lsn")?.ok_or_else(|| {
123                SqlServerError::ProgrammingError("missing column 'start_lsn'".to_string())
124            })?;
125            let min_lsn = Lsn::try_from(start_lsn).map_err(|msg| SqlServerError::InvalidData {
126                column_name: "lsn".to_string(),
127                error: format!("Error parsing LSN for {capture_instance}: {msg}"),
128            })?;
129            Ok::<_, SqlServerError>((capture_instance, min_lsn))
130        })
131        .collect::<Result<_, _>>()?;
132
133    Ok(min_lsns)
134}
135
136/// Returns the maximum log sequence number for the entire database, retrying
137/// if the log sequence number is not available. This implementation relies on
138/// CDC, which is asynchronous, so may return an LSN that is less than the
139/// maximum LSN of SQL server.
140///
141/// See:
142/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-max-lsn-transact-sql?view=sql-server-ver16>
143/// - <https://groups.google.com/g/debezium/c/47Yg2r166KM/m/lHqtRF2xAQAJ?pli=1>
144pub async fn get_max_lsn_retry(
145    client: &mut Client,
146    max_retry_duration: Duration,
147) -> Result<Lsn, SqlServerError> {
148    let (_client, lsn_result) = mz_ore::retry::Retry::default()
149        .max_duration(max_retry_duration)
150        .retry_async_with_state(client, |_, client| async {
151            let result = crate::inspect::get_max_lsn(client).await;
152            (client, map_null_lsn_to_retry(result))
153        })
154        .await;
155
156    let Ok(lsn) = lsn_result else {
157        tracing::warn!("database did not report a maximum LSN in time");
158        return lsn_result;
159    };
160    Ok(lsn)
161}
162
163fn map_null_lsn_to_retry<T>(result: Result<T, SqlServerError>) -> RetryResult<T, SqlServerError> {
164    match result {
165        Ok(val) => RetryResult::Ok(val),
166        Err(err @ SqlServerError::NullLsn) => RetryResult::RetryableErr(err),
167        Err(other) => RetryResult::FatalErr(other),
168    }
169}
170
171/// Increments the log sequence number.
172///
173/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-increment-lsn-transact-sql?view=sql-server-ver16>
174pub async fn increment_lsn(client: &mut Client, lsn: Lsn) -> Result<Lsn, SqlServerError> {
175    static INCREMENT_LSN_QUERY: &str = "SELECT sys.fn_cdc_increment_lsn(@P1);";
176    let result = client
177        .query(INCREMENT_LSN_QUERY, &[&lsn.as_bytes().as_slice()])
178        .await?;
179
180    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
181    parse_lsn(&result[..1])
182}
183
184/// Parse an [`Lsn`] in Decimal(25,0) format of the provided [`tiberius::Row`].
185///
186/// Returns an error if the provided slice doesn't have exactly one row.
187pub(crate) fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
188    match row {
189        [r] => {
190            let numeric_lsn = r
191                .try_get::<Numeric, _>(0)?
192                .ok_or_else(|| SqlServerError::NullLsn)?;
193            let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
194                column_name: "lsn".to_string(),
195                error: msg,
196            })?;
197            Ok(lsn)
198        }
199        other => Err(SqlServerError::InvalidData {
200            column_name: "lsn".to_string(),
201            error: format!("expected 1 column, got {other:?}"),
202        }),
203    }
204}
205
206/// Parse an [`Lsn`] from the first column of the provided [`tiberius::Row`].
207///
208/// Returns an error if the provided slice doesn't have exactly one row.
209fn parse_lsn(result: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
210    match result {
211        [row] => {
212            let val = row
213                .try_get::<&[u8], _>(0)?
214                .ok_or_else(|| SqlServerError::NullLsn)?;
215            if val.is_empty() {
216                Err(SqlServerError::NullLsn)
217            } else {
218                let lsn = Lsn::try_from(val).map_err(|msg| SqlServerError::InvalidData {
219                    column_name: "lsn".to_string(),
220                    error: msg,
221                })?;
222                Ok(lsn)
223            }
224        }
225        other => Err(SqlServerError::InvalidData {
226            column_name: "lsn".to_string(),
227            error: format!("expected 1 column, got {other:?}"),
228        }),
229    }
230}
231
232/// Queries the specified capture instance and returns all changes from
233/// `[start_lsn, end_lsn)`, ordered by `start_lsn` in an ascending fashion.
234///
235/// TODO(sql_server2): This presents an opportunity for SQL injection. We should create a stored
236/// procedure using `QUOTENAME` to sanitize the input for the capture instance provided by the
237/// user.
238pub fn get_changes_asc(
239    client: &mut Client,
240    capture_instance: &str,
241    start_lsn: Lsn,
242    end_lsn: Lsn,
243    filter: RowFilterOption,
244) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send {
245    const START_LSN_COLUMN: &str = "__$start_lsn";
246    let query = format!(
247        "SELECT * FROM cdc.fn_cdc_get_all_changes_{capture_instance}(@P1, @P2, N'{filter}') ORDER BY {START_LSN_COLUMN} ASC;"
248    );
249    client.query_streaming(
250        query,
251        &[
252            &start_lsn.as_bytes().as_slice(),
253            &end_lsn.as_bytes().as_slice(),
254        ],
255    )
256}
257
258/// Cleans up the change table associated with the specified `capture_instance` by
259/// deleting `max_deletes` entries with a `start_lsn` less than `low_water_mark`.
260///
261/// Note: At the moment cleanup is kind of "best effort".  If this query succeeds
262/// then at most `max_delete` rows were deleted, but the number of actual rows
263/// deleted is not returned as part of the query. The number of rows _should_ be
264/// present in an informational message (i.e. a Notice) that is returned, but
265/// [`tiberius`] doesn't expose these to us.
266///
267/// TODO(sql_server2): Update [`tiberius`] to return informational messages so we
268/// can determine how many rows got deleted.
269///
270/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sys-sp-cdc-cleanup-change-table-transact-sql?view=sql-server-ver16>.
271pub async fn cleanup_change_table(
272    client: &mut Client,
273    capture_instance: &str,
274    low_water_mark: &Lsn,
275    max_deletes: u32,
276) -> Result<(), SqlServerError> {
277    static GET_LSN_QUERY: &str =
278        "SELECT MAX(start_lsn) FROM cdc.lsn_time_mapping WHERE start_lsn <= @P1";
279    static CLEANUP_QUERY: &str = "
280DECLARE @mz_cleanup_status_bit BIT;
281SET @mz_cleanup_status_bit = 0;
282EXEC sys.sp_cdc_cleanup_change_table
283    @capture_instance = @P1,
284    @low_water_mark = @P2,
285    @threshold = @P3,
286    @fCleanupFailed = @mz_cleanup_status_bit OUTPUT;
287SELECT @mz_cleanup_status_bit;
288    ";
289
290    let max_deletes = i64::cast_from(max_deletes);
291
292    // First we need to get a valid LSN as our low watermark. If we try to cleanup
293    // a change table with an LSN that doesn't exist in the `cdc.lsn_time_mapping`
294    // table we'll get an error code `22964`.
295    let result = client
296        .query(GET_LSN_QUERY, &[&low_water_mark.as_bytes().as_slice()])
297        .await?;
298    let low_water_mark_to_use = match &result[..] {
299        [row] => row
300            .try_get::<&[u8], _>(0)?
301            .ok_or_else(|| SqlServerError::InvalidData {
302                column_name: "mz_cleanup_status_bit".to_string(),
303                error: "expected a bool, found NULL".to_string(),
304            })?,
305        other => Err(SqlServerError::ProgrammingError(format!(
306            "expected one row for low water mark, found {other:?}"
307        )))?,
308    };
309
310    // Once we get a valid LSN that is less than or equal to the provided watermark
311    // we can clean up the specified change table!
312    let result = client
313        .query(
314            CLEANUP_QUERY,
315            &[&capture_instance, &low_water_mark_to_use, &max_deletes],
316        )
317        .await;
318
319    let rows = match result {
320        Ok(rows) => rows,
321        Err(SqlServerError::SqlServer(e)) => {
322            // See these remarks from the SQL Server Documentation.
323            //
324            // <https://learn.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sys-sp-cdc-cleanup-change-table-transact-sql?view=sql-server-ver16#remarks>.
325            let already_cleaned_up = e.code().map(|code| code == 22957).unwrap_or(false);
326
327            if already_cleaned_up {
328                return Ok(());
329            } else {
330                return Err(SqlServerError::SqlServer(e));
331            }
332        }
333        Err(other) => return Err(other),
334    };
335
336    match &rows[..] {
337        [row] => {
338            let failure =
339                row.try_get::<bool, _>(0)?
340                    .ok_or_else(|| SqlServerError::InvalidData {
341                        column_name: "mz_cleanup_status_bit".to_string(),
342                        error: "expected a bool, found NULL".to_string(),
343                    })?;
344
345            if failure {
346                Err(super::cdc::CdcError::CleanupFailed {
347                    capture_instance: capture_instance.to_string(),
348                    low_water_mark: *low_water_mark,
349                })?
350            } else {
351                Ok(())
352            }
353        }
354        other => Err(SqlServerError::ProgrammingError(format!(
355            "expected one status row, found {other:?}"
356        ))),
357    }
358}
359
/// Retrieves all columns in tables that have CDC (Change Data Capture) enabled.
///
/// Returns the metadata needed to create instances of [`SqlServerTableRaw`].
///
/// The query joins several system tables:
/// - sys.tables: Source tables in the database
/// - sys.schemas: Schema information for proper table identification
/// - sys.columns: Column definitions including nullability
/// - sys.types: Data type information for each column
/// - cdc.change_tables: CDC configuration linking capture instances to source tables
/// - information_schema views: To identify primary key constraints
///
/// For each column (once per capture instance tracking its table), it returns:
/// - Table identification (schema_name, table_name, capture_instance)
/// - Column metadata (name, type, nullable, max_length, precision, scale)
/// - Primary key information (constraint name if the column is part of a PK)
///
/// `key_column_usage` lists a column once per key constraint it takes part in, including
/// FOREIGN KEY and UNIQUE constraints. Left-joining it directly onto the column would emit an
/// extra NULL-constraint row per such constraint, duplicating the column. To avoid that, it
/// is inner-joined with `table_constraints`, restricted to `PRIMARY KEY`, *before* the result
/// is left-joined onto the column.
///
/// The query has no `WHERE` clause or trailing `;` so callers can append either.
static GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY: &str = "
SELECT
    s.name as schema_name,
    t.name as table_name,
    ch.capture_instance as capture_instance,
    ch.create_date as capture_instance_create_date,
    c.name as col_name,
    ty.name as col_type,
    c.is_nullable as col_nullable,
    c.max_length as col_max_length,
    c.precision as col_precision,
    c.scale as col_scale,
    tc.constraint_name AS col_primary_key_constraint
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.columns c ON t.object_id = c.object_id
JOIN sys.types ty ON c.user_type_id = ty.user_type_id
JOIN cdc.change_tables ch ON t.object_id = ch.source_object_id
LEFT JOIN (
    information_schema.key_column_usage kc
    JOIN information_schema.table_constraints tc
        ON tc.constraint_catalog = kc.constraint_catalog
        AND tc.constraint_schema = kc.constraint_schema
        AND tc.constraint_name = kc.constraint_name
        AND tc.table_schema = kc.table_schema
        AND tc.table_name = kc.table_name
        AND tc.constraint_type = 'PRIMARY KEY'
)
    ON kc.table_schema = s.name
    AND kc.table_name = t.name
    AND kc.column_name = c.name
";
406
407/// Returns the table metadata for the tables that are tracked by the specified `capture_instance`s.
408pub async fn get_tables_for_capture_instance<'a>(
409    client: &mut Client,
410    capture_instances: impl IntoIterator<Item = &str>,
411) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
412    // SQL Server does not have support for array types, so we need to manually construct
413    // the parameterized query.
414    let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
415    // If there are no tables to check for just return an empty list.
416    if params.is_empty() {
417        return Ok(Vec::default());
418    }
419
420    // TODO(sql_server3): Remove this redundant collection.
421    #[allow(clippy::as_conversions)]
422    let params_dyn: SmallVec<[_; 1]> = params
423        .iter()
424        .map(|instance| instance as &dyn tiberius::ToSql)
425        .collect();
426    let param_indexes = params
427        .iter()
428        .enumerate()
429        // Params are 1-based indexed.
430        .map(|(idx, _)| format!("@P{}", idx + 1))
431        .join(", ");
432
433    let table_for_capture_instance_query = format!(
434        "{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY} WHERE ch.capture_instance IN ({param_indexes});"
435    );
436
437    let result = client
438        .query(&table_for_capture_instance_query, &params_dyn[..])
439        .await?;
440
441    let tables = deserialize_table_columns_to_raw_tables(&result)?;
442
443    Ok(tables)
444}
445
446/// Retrieves column metdata from the CDC table maintained by the provided capture instance. The
447/// resulting column information collection is similar to the information collected for the
448/// upstream table, with the exclusion of nullability and primary key constraints, which contain
449/// static values for CDC columns. CDC table schema is automatically generated and does not attempt
450/// to enforce the same constraints on the data as the upstream table.
451pub async fn get_cdc_table_columns(
452    client: &mut Client,
453    capture_instance: &str,
454) -> Result<BTreeMap<Arc<str>, SqlServerColumnRaw>, SqlServerError> {
455    static CDC_COLUMNS_QUERY: &str = "SELECT \
456        c.name AS col_name, \
457        t.name AS col_type, \
458        c.max_length AS col_max_length, \
459        c.precision AS col_precision, \
460        c.scale AS col_scale \
461    FROM \
462        sys.columns AS c \
463    JOIN sys.types AS t ON c.system_type_id = t.system_type_id AND c.user_type_id = t.user_type_id \
464    WHERE \
465        c.object_id = OBJECT_ID(@P1) AND c.name NOT LIKE '__$%' \
466    ORDER BY c.column_id;";
467    let cdc_table_name = format!("cdc.{capture_instance}_CT");
468    let result = client.query(CDC_COLUMNS_QUERY, &[&cdc_table_name]).await?;
469    let mut columns = BTreeMap::new();
470    for row in result.iter() {
471        let column_name: Arc<str> = get_value::<&str>(row, "col_name")?.into();
472        // Reusing this struct even though some of the fields aren't needed because it simplifies
473        // comparison with the upstream table metadata
474        let column = SqlServerColumnRaw {
475            name: Arc::clone(&column_name),
476            data_type: get_value::<&str>(row, "col_type")?.into(),
477            is_nullable: true,
478            primary_key_constraint: None,
479            max_length: get_value(row, "col_max_length")?,
480            precision: get_value(row, "col_precision")?,
481            scale: get_value(row, "col_scale")?,
482        };
483        columns.insert(column_name, column);
484    }
485    Ok(columns)
486}
487
488/// Ensure change data capture (CDC) is enabled for the database the provided
489/// `client` is currently connected to.
490///
491/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/track-changes/enable-and-disable-change-data-capture-sql-server?view=sql-server-ver16>
492pub async fn ensure_database_cdc_enabled(client: &mut Client) -> Result<(), SqlServerError> {
493    static DATABASE_CDC_ENABLED_QUERY: &str =
494        "SELECT is_cdc_enabled FROM sys.databases WHERE database_id = DB_ID();";
495    let result = client.simple_query(DATABASE_CDC_ENABLED_QUERY).await?;
496
497    check_system_result(&result, "database CDC".to_string(), true)?;
498    Ok(())
499}
500
501/// Retrieves the largest `restore_history_id` from SQL Server for the current database.  The
502/// `restore_history_id` column is of type `IDENTITY(1,1)` based on `EXEC sp_help restorehistory`.
503/// We expect it to start at 1 and be incremented by 1, with possible gaps in values.
504/// See:
505/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/restorehistory-transact-sql?view=sql-server-ver17>
506/// - <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql-identity-property?view=sql-server-ver17>
507pub async fn get_latest_restore_history_id(
508    client: &mut Client,
509) -> Result<Option<i32>, SqlServerError> {
510    static LATEST_RESTORE_ID_QUERY: &str = "SELECT TOP 1 restore_history_id \
511        FROM msdb.dbo.restorehistory \
512        WHERE destination_database_name = DB_NAME() \
513        ORDER BY restore_history_id DESC;";
514    let result = client.simple_query(LATEST_RESTORE_ID_QUERY).await?;
515
516    match &result[..] {
517        [] => Ok(None),
518        [row] => Ok(row.try_get::<i32, _>(0)?),
519        other => Err(SqlServerError::InvariantViolated(format!(
520            "expected one row, got {other:?}"
521        ))),
522    }
523}
524
/// A DDL event collected from the `cdc.ddl_history` table.
#[derive(Debug)]
pub struct DDLEvent {
    /// LSN at which the DDL change was recorded (the `ddl_lsn` column).
    pub lsn: Lsn,
    /// Text of the DDL statement (the `ddl_command` column).
    pub ddl_command: Arc<str>,
}
531
532impl DDLEvent {
533    /// Returns true if the DDL event is a compatible change, or false if it is not.
534    /// This performs a naive parsing of the DDL command looking for modification of columns
535    ///  1. ALTER TABLE .. ALTER COLUMN
536    ///  2. ALTER TABLE .. DROP COLUMN
537    ///
538    /// See <https://learn.microsoft.com/en-us/sql/t-sql/statements/alter-table-transact-sql?view=sql-server-ver17>
539    pub fn is_compatible(&self) -> bool {
540        // TODO (maz): This is currently a basic check that doesn't take into account type changes.
541        // At some point, we will need to move this to SqlServerTableDesc and expand it.
542        let mut words = self.ddl_command.split_ascii_whitespace();
543        match (
544            words.next().map(str::to_ascii_lowercase).as_deref(),
545            words.next().map(str::to_ascii_lowercase).as_deref(),
546        ) {
547            (Some("alter"), Some("table")) => {
548                let mut peekable = words.peekable();
549                let mut compatible = true;
550                while compatible && let Some(token) = peekable.next() {
551                    compatible = match token.to_ascii_lowercase().as_str() {
552                        "alter" | "drop" => peekable
553                            .peek()
554                            .is_some_and(|next_tok| !next_tok.eq_ignore_ascii_case("column")),
555                        _ => true,
556                    }
557                }
558                compatible
559            }
560            _ => true,
561        }
562    }
563}
564
565/// Returns DDL changes made to the source table for the given capture instance.  This follows the
566/// same convention as `cdc.fn_cdc_get_all_changes_<capture_instance>`, in that the range is
567/// inclusive, i.e. `[from_lsn, to_lsn]`. The events are returned in ascending order of
568/// LSN.
569pub async fn get_ddl_history(
570    client: &mut Client,
571    capture_instance: &str,
572    from_lsn: &Lsn,
573    to_lsn: &Lsn,
574) -> Result<BTreeMap<SqlServerQualifiedTableName, Vec<DDLEvent>>, SqlServerError> {
575    // We query the ddl_history table instead of using the stored procedure as there doesn't
576    // appear to be a way to apply filters or projections against output of the stored procedure
577    // without an intermediate table.
578    static DDL_HISTORY_QUERY: &str = "SELECT \
579                s.name AS schema_name, \
580                t.name AS table_name, \
581                dh.ddl_lsn, \
582                dh.ddl_command
583            FROM \
584                cdc.change_tables ct \
585            JOIN cdc.ddl_history dh ON dh.object_id = ct.object_id \
586            JOIN sys.tables t ON t.object_id = dh.source_object_id \
587            JOIN sys.schemas s ON s.schema_id = t.schema_id \
588            WHERE \
589                ct.capture_instance = @P1 \
590                AND dh.ddl_lsn >= @P2 \
591                AND dh.ddl_lsn <= @P3 \
592            ORDER BY ddl_lsn;";
593
594    let result = client
595        .query(
596            DDL_HISTORY_QUERY,
597            &[
598                &capture_instance,
599                &from_lsn.as_bytes().as_slice(),
600                &to_lsn.as_bytes().as_slice(),
601            ],
602        )
603        .await?;
604
605    // SQL server doesn't support array types, and using string_agg to collect LSN
606    // would require more parsing, so we opt for a BTreeMap to accumulate the results.
607    let mut collector: BTreeMap<_, Vec<_>> = BTreeMap::new();
608    for row in result.iter() {
609        let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
610        let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
611        let lsn: &[u8] = get_value::<&[u8]>(row, "ddl_lsn")?;
612        let ddl_command: Arc<str> = get_value::<&str>(row, "ddl_command")?.into();
613
614        let qualified_table_name = SqlServerQualifiedTableName {
615            schema_name,
616            table_name,
617        };
618        let lsn = Lsn::try_from(lsn).map_err(|lsn_err| SqlServerError::InvalidData {
619            column_name: "ddl_lsn".to_string(),
620            error: lsn_err,
621        })?;
622
623        collector
624            .entry(qualified_table_name)
625            .or_default()
626            .push(DDLEvent { lsn, ddl_command });
627    }
628
629    Ok(collector)
630}
631
632/// Ensure the `SNAPSHOT` transaction isolation level is enabled for the
633/// database the provided `client` is currently connected to.
634///
635/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql?view=sql-server-ver16>
636pub async fn ensure_snapshot_isolation_enabled(client: &mut Client) -> Result<(), SqlServerError> {
637    static SNAPSHOT_ISOLATION_QUERY: &str =
638        "SELECT snapshot_isolation_state FROM sys.databases WHERE database_id = DB_ID();";
639    let result = client.simple_query(SNAPSHOT_ISOLATION_QUERY).await?;
640
641    check_system_result(&result, "snapshot isolation".to_string(), 1u8)?;
642    Ok(())
643}
644
645pub async fn get_tables(client: &mut Client) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
646    let result = client
647        .simple_query(&format!("{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY};"))
648        .await?;
649
650    let tables = deserialize_table_columns_to_raw_tables(&result)?;
651
652    Ok(tables)
653}
654
655// Helper function to retrieve value from a row.
656fn get_value<'a, T: tiberius::FromSql<'a>>(
657    row: &'a tiberius::Row,
658    name: &'static str,
659) -> Result<T, SqlServerError> {
660    row.try_get(name)?
661        .ok_or(SqlServerError::MissingColumn(name))
662}
663
664fn deserialize_table_columns_to_raw_tables(
665    rows: &[tiberius::Row],
666) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
667    // Group our columns by (schema, name).
668    let mut tables = BTreeMap::default();
669    for row in rows {
670        let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
671        let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
672        let capture_instance: Arc<str> = get_value::<&str>(row, "capture_instance")?.into();
673        let capture_instance_create_date: NaiveDateTime =
674            get_value::<NaiveDateTime>(row, "capture_instance_create_date")?;
675        let primary_key_constraint: Option<Arc<str>> = row
676            .try_get::<&str, _>("col_primary_key_constraint")?
677            .map(|v| v.into());
678
679        let column_name = get_value::<&str>(row, "col_name")?.into();
680        let column = SqlServerColumnRaw {
681            name: Arc::clone(&column_name),
682            data_type: get_value::<&str>(row, "col_type")?.into(),
683            is_nullable: get_value(row, "col_nullable")?,
684            primary_key_constraint,
685            max_length: get_value(row, "col_max_length")?,
686            precision: get_value(row, "col_precision")?,
687            scale: get_value(row, "col_scale")?,
688        };
689
690        let columns: &mut Vec<_> = tables
691            .entry((
692                Arc::clone(&schema_name),
693                Arc::clone(&table_name),
694                Arc::clone(&capture_instance),
695                capture_instance_create_date,
696            ))
697            .or_default();
698        columns.push(column);
699    }
700
701    // Flatten into our raw Table description.
702    let raw_tables = tables
703        .into_iter()
704        .map(
705            |((schema, name, capture_instance, capture_instance_create_date), columns)| {
706                SqlServerTableRaw {
707                    schema_name: schema,
708                    name,
709                    capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
710                        name: capture_instance,
711                        create_date: capture_instance_create_date.into(),
712                    }),
713                    columns: columns.into(),
714                }
715            },
716        )
717        .collect::<Vec<SqlServerTableRaw>>();
718
719    Ok(raw_tables)
720}
721
722/// Return a [`Stream`] that is the entire snapshot of the specified table.
723pub fn snapshot(
724    client: &mut Client,
725    schema: &str,
726    table: &str,
727) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> {
728    let query = format!("SELECT * FROM {schema}.{table};");
729    client.query_streaming(query, &[])
730}
731
732/// Returns the total number of rows present in the specified table.
733pub async fn snapshot_size(
734    client: &mut Client,
735    schema: &str,
736    table: &str,
737) -> Result<usize, SqlServerError> {
738    let query = format!("SELECT COUNT(*) FROM {schema}.{table};");
739    let result = client.query(query, &[]).await?;
740
741    match &result[..] {
742        [row] => match row.try_get::<i32, _>(0)? {
743            Some(count @ 0..) => Ok(usize::try_from(count).expect("known to fit")),
744            Some(negative) => Err(SqlServerError::InvalidData {
745                column_name: "count".to_string(),
746                error: format!("found negative count: {negative}"),
747            }),
748            None => Err(SqlServerError::InvalidData {
749                column_name: "count".to_string(),
750                error: "expected a value found NULL".to_string(),
751            }),
752        },
753        other => Err(SqlServerError::InvariantViolated(format!(
754            "expected one row, got {other:?}"
755        ))),
756    }
757}
758
759/// Helper function to parse an expected result from a "system" query.
760fn check_system_result<'a, T>(
761    result: &'a SmallVec<[tiberius::Row; 1]>,
762    name: String,
763    expected: T,
764) -> Result<(), SqlServerError>
765where
766    T: tiberius::FromSql<'a> + Copy + fmt::Debug + fmt::Display + PartialEq,
767{
768    match &result[..] {
769        [row] => {
770            let result: Option<T> = row.try_get(0)?;
771            if result == Some(expected) {
772                Ok(())
773            } else {
774                Err(SqlServerError::InvalidSystemSetting {
775                    name,
776                    expected: expected.to_string(),
777                    actual: format!("{result:?}"),
778                })
779            }
780        }
781        other => Err(SqlServerError::InvariantViolated(format!(
782            "expected 1 row, got {other:?}"
783        ))),
784    }
785}
786
787/// Return a Result that is empty if all tables, columns, and capture instances
788/// have the necessary permissions to and an error if any table, column,
789/// or capture instance does not have the necessary permissions
790/// for tracking changes.
791pub async fn validate_source_privileges<'a>(
792    client: &mut Client,
793    capture_instances: impl IntoIterator<Item = &str>,
794) -> Result<(), SqlServerError> {
795    let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
796
797    if params.is_empty() {
798        return Ok(());
799    }
800
801    let params_dyn: SmallVec<[_; 1]> = params
802        .iter()
803        .map(|instance| {
804            let instance: &dyn tiberius::ToSql = instance;
805            instance
806        })
807        .collect();
808
809    let param_indexes = (1..params.len() + 1)
810        .map(|idx| format!("@P{}", idx))
811        .join(", ");
812
813    // NB(ptravers): we rely on HAS_PERMS_BY_NAME to check both table and column permissions.
814    let capture_instance_query = format!(
815            "
816        SELECT
817            SCHEMA_NAME(o.schema_id) + '.' + o.name AS qualified_table_name,
818            ct.capture_instance AS capture_instance,
819            COALESCE(HAS_PERMS_BY_NAME(SCHEMA_NAME(o.schema_id) + '.' + o.name, 'OBJECT', 'SELECT'), 0) AS table_select,
820            COALESCE(HAS_PERMS_BY_NAME('cdc.' + ct.capture_instance + '_CT', 'OBJECT', 'SELECT'), 0) AS capture_table_select
821        FROM cdc.change_tables ct
822        JOIN sys.objects o ON o.object_id = ct.source_object_id
823        WHERE ct.capture_instance IN ({param_indexes});
824            "
825        );
826
827    let rows = client
828        .query(capture_instance_query, &params_dyn[..])
829        .await?;
830
831    let mut capture_instances_without_perms = vec![];
832    let mut tables_without_perms = vec![];
833
834    for row in rows {
835        let table: &str = row
836            .try_get("qualified_table_name")
837            .context("getting table column")?
838            .ok_or_else(|| anyhow::anyhow!("no table column?"))?;
839
840        let capture_instance: &str = row
841            .try_get("capture_instance")
842            .context("getting capture_instance column")?
843            .ok_or_else(|| anyhow::anyhow!("no capture_instance column?"))?;
844
845        let permitted_table: i32 = row
846            .try_get("table_select")
847            .context("getting table_select column")?
848            .ok_or_else(|| anyhow::anyhow!("no table_select column?"))?;
849
850        let permitted_capture_instance: i32 = row
851            .try_get("capture_table_select")
852            .context("getting capture_table_select column")?
853            .ok_or_else(|| anyhow::anyhow!("no capture_table_select column?"))?;
854
855        if permitted_table == 0 {
856            tables_without_perms.push(table.to_string());
857        }
858
859        if permitted_capture_instance == 0 {
860            capture_instances_without_perms.push(capture_instance.to_string());
861        }
862    }
863
864    if !capture_instances_without_perms.is_empty() || !tables_without_perms.is_empty() {
865        return Err(SqlServerError::AuthorizationError {
866            tables: tables_without_perms.join(", "),
867            capture_instances: capture_instances_without_perms.join(", "),
868        });
869    }
870
871    Ok(())
872}