mz_sql_server_util/
inspect.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Useful queries to inspect the state of a SQL Server instance.
11
12use anyhow::Context;
13use chrono::NaiveDateTime;
14use futures::Stream;
15use itertools::Itertools;
16use mz_ore::cast::CastFrom;
17use mz_ore::retry::RetryResult;
18use smallvec::SmallVec;
19use std::collections::BTreeMap;
20use std::fmt;
21use std::sync::Arc;
22use std::time::Duration;
23use tiberius::numeric::Numeric;
24
25use crate::cdc::{Lsn, RowFilterOption};
26use crate::desc::{
27    SqlServerCaptureInstanceRaw, SqlServerColumnRaw, SqlServerQualifiedTableName, SqlServerTableRaw,
28};
29use crate::{Client, SqlServerError, quote_identifier};
30
31/// Returns the minimum log sequence number for the specified `capture_instance`.
32///
33/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-min-lsn-transact-sql?view=sql-server-ver16>
34pub async fn get_min_lsn(
35    client: &mut Client,
36    capture_instance: &str,
37) -> Result<Lsn, SqlServerError> {
38    static MIN_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_min_lsn(@P1);";
39    let result = client.query(MIN_LSN_QUERY, &[&capture_instance]).await?;
40
41    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
42    parse_lsn(&result[..1])
43}
44/// Returns the minimum log sequence number for the specified `capture_instance`, retrying
45/// if the log sequence number is not available.
46///
47/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-min-lsn-transact-sql?view=sql-server-ver16>
48pub async fn get_min_lsn_retry(
49    client: &mut Client,
50    capture_instance: &str,
51    max_retry_duration: Duration,
52) -> Result<Lsn, SqlServerError> {
53    let (_client, lsn_result) = mz_ore::retry::Retry::default()
54        .max_duration(max_retry_duration)
55        .retry_async_with_state(client, |_, client| async {
56            let result = crate::inspect::get_min_lsn(client, capture_instance).await;
57            (client, map_null_lsn_to_retry(result))
58        })
59        .await;
60    let Ok(lsn) = lsn_result else {
61        tracing::warn!("database did not report a minimum LSN in time");
62        return lsn_result;
63    };
64    Ok(lsn)
65}
66
67/// Returns the maximum log sequence number for the entire database.
68/// This implementation relies on CDC, which is asynchronous, so may
69/// return an LSN that is less than the maximum LSN of SQL server.
70///
71/// See:
72/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-max-lsn-transact-sql?view=sql-server-ver16>
73/// - <https://groups.google.com/g/debezium/c/47Yg2r166KM/m/lHqtRF2xAQAJ?pli=1>
74pub async fn get_max_lsn(client: &mut Client) -> Result<Lsn, SqlServerError> {
75    static MAX_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_max_lsn();";
76    let result = client.simple_query(MAX_LSN_QUERY).await?;
77
78    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
79    parse_lsn(&result[..1])
80}
81
82/// Retrieves the minumum [`Lsn`] (start_lsn field) from `cdc.change_tables`
83/// for the specified capture instances.
84///
85/// This is based on the `sys.fn_cdc_get_min_lsn` implementation, which has logic
86/// that we want to bypass. Specifically, `sys.fn_cdc_get_min_lsn` returns NULL
87/// if the `start_lsn` in `cdc.change_tables` is less than or equal to the LSN
88/// returned by `sys.fn_cdc_get_max_lsn`.
89///
90/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/cdc-change-tables-transact-sql?view=sql-server-ver16>
91pub async fn get_min_lsns(
92    client: &mut Client,
93    capture_instances: impl IntoIterator<Item = &str>,
94) -> Result<BTreeMap<Arc<str>, Lsn>, SqlServerError> {
95    let capture_instances: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
96    let values: Vec<_> = capture_instances
97        .iter()
98        .map(|ci| {
99            let ci: &dyn tiberius::ToSql = ci;
100            ci
101        })
102        .collect();
103    let args = (0..capture_instances.len())
104        .map(|i| format!("@P{}", i + 1))
105        .collect::<Vec<_>>()
106        .join(",");
107    let stmt = format!(
108        "SELECT capture_instance, start_lsn FROM cdc.change_tables WHERE capture_instance IN ({args});"
109    );
110    let result = client.query(stmt, &values).await?;
111    let min_lsns = result
112        .into_iter()
113        .map(|row| {
114            let capture_instance: Arc<str> = row
115                .try_get::<&str, _>("capture_instance")?
116                .ok_or_else(|| {
117                    SqlServerError::ProgrammingError(
118                        "missing column 'capture_instance'".to_string(),
119                    )
120                })?
121                .into();
122            let start_lsn: &[u8] = row.try_get("start_lsn")?.ok_or_else(|| {
123                SqlServerError::ProgrammingError("missing column 'start_lsn'".to_string())
124            })?;
125            let min_lsn = Lsn::try_from(start_lsn).map_err(|msg| SqlServerError::InvalidData {
126                column_name: "lsn".to_string(),
127                error: format!("Error parsing LSN for {capture_instance}: {msg}"),
128            })?;
129            Ok::<_, SqlServerError>((capture_instance, min_lsn))
130        })
131        .collect::<Result<_, _>>()?;
132
133    Ok(min_lsns)
134}
135
136/// Returns the maximum log sequence number for the entire database, retrying
137/// if the log sequence number is not available. This implementation relies on
138/// CDC, which is asynchronous, so may return an LSN that is less than the
139/// maximum LSN of SQL server.
140///
141/// See:
142/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-max-lsn-transact-sql?view=sql-server-ver16>
143/// - <https://groups.google.com/g/debezium/c/47Yg2r166KM/m/lHqtRF2xAQAJ?pli=1>
144pub async fn get_max_lsn_retry(
145    client: &mut Client,
146    max_retry_duration: Duration,
147) -> Result<Lsn, SqlServerError> {
148    let (_client, lsn_result) = mz_ore::retry::Retry::default()
149        .max_duration(max_retry_duration)
150        .retry_async_with_state(client, |_, client| async {
151            let result = crate::inspect::get_max_lsn(client).await;
152            (client, map_null_lsn_to_retry(result))
153        })
154        .await;
155
156    let Ok(lsn) = lsn_result else {
157        tracing::warn!("database did not report a maximum LSN in time");
158        return lsn_result;
159    };
160    Ok(lsn)
161}
162
163fn map_null_lsn_to_retry<T>(result: Result<T, SqlServerError>) -> RetryResult<T, SqlServerError> {
164    match result {
165        Ok(val) => RetryResult::Ok(val),
166        Err(err @ SqlServerError::NullLsn) => RetryResult::RetryableErr(err),
167        Err(other) => RetryResult::FatalErr(other),
168    }
169}
170
171/// Increments the log sequence number.
172///
173/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-increment-lsn-transact-sql?view=sql-server-ver16>
174pub async fn increment_lsn(client: &mut Client, lsn: Lsn) -> Result<Lsn, SqlServerError> {
175    static INCREMENT_LSN_QUERY: &str = "SELECT sys.fn_cdc_increment_lsn(@P1);";
176    let result = client
177        .query(INCREMENT_LSN_QUERY, &[&lsn.as_bytes().as_slice()])
178        .await?;
179
180    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
181    parse_lsn(&result[..1])
182}
183
184/// Parse an [`Lsn`] in Decimal(25,0) format of the provided [`tiberius::Row`].
185///
186/// Returns an error if the provided slice doesn't have exactly one row.
187pub(crate) fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
188    match row {
189        [r] => {
190            let numeric_lsn = r
191                .try_get::<Numeric, _>(0)?
192                .ok_or_else(|| SqlServerError::NullLsn)?;
193            let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
194                column_name: "lsn".to_string(),
195                error: msg,
196            })?;
197            Ok(lsn)
198        }
199        other => Err(SqlServerError::InvalidData {
200            column_name: "lsn".to_string(),
201            error: format!("expected 1 column, got {other:?}"),
202        }),
203    }
204}
205
206/// Parse an [`Lsn`] from the first column of the provided [`tiberius::Row`].
207///
208/// Returns an error if the provided slice doesn't have exactly one row.
209fn parse_lsn(result: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
210    match result {
211        [row] => {
212            let val = row
213                .try_get::<&[u8], _>(0)?
214                .ok_or_else(|| SqlServerError::NullLsn)?;
215            if val.is_empty() {
216                Err(SqlServerError::NullLsn)
217            } else {
218                let lsn = Lsn::try_from(val).map_err(|msg| SqlServerError::InvalidData {
219                    column_name: "lsn".to_string(),
220                    error: msg,
221                })?;
222                Ok(lsn)
223            }
224        }
225        other => Err(SqlServerError::InvalidData {
226            column_name: "lsn".to_string(),
227            error: format!("expected 1 column, got {other:?}"),
228        }),
229    }
230}
231
232/// Queries the specified capture instance and returns all changes from
233/// `[start_lsn, end_lsn)`, ordered by `start_lsn` in an ascending fashion.
234pub fn get_changes_asc(
235    client: &mut Client,
236    capture_instance: &str,
237    start_lsn: Lsn,
238    end_lsn: Lsn,
239    filter: RowFilterOption,
240) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send {
241    const START_LSN_COLUMN: &str = "__$start_lsn";
242    let query = format!(
243        "SELECT * FROM cdc.{function}(@P1, @P2, N'{filter}') ORDER BY {START_LSN_COLUMN} ASC;",
244        function = quote_identifier(&format!("fn_cdc_get_all_changes_{capture_instance}"))
245    );
246    client.query_streaming(
247        query,
248        &[
249            &start_lsn.as_bytes().as_slice(),
250            &end_lsn.as_bytes().as_slice(),
251        ],
252    )
253}
254
255/// Cleans up the change table associated with the specified `capture_instance` by
256/// deleting `max_deletes` entries with a `start_lsn` less than `low_water_mark`.
257///
258/// Note: At the moment cleanup is kind of "best effort".  If this query succeeds
259/// then at most `max_delete` rows were deleted, but the number of actual rows
260/// deleted is not returned as part of the query. The number of rows _should_ be
261/// present in an informational message (i.e. a Notice) that is returned, but
262/// [`tiberius`] doesn't expose these to us.
263///
264/// TODO(sql_server2): Update [`tiberius`] to return informational messages so we
265/// can determine how many rows got deleted.
266///
267/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sys-sp-cdc-cleanup-change-table-transact-sql?view=sql-server-ver16>.
268pub async fn cleanup_change_table(
269    client: &mut Client,
270    capture_instance: &str,
271    low_water_mark: &Lsn,
272    max_deletes: u32,
273) -> Result<(), SqlServerError> {
274    static GET_LSN_QUERY: &str =
275        "SELECT MAX(start_lsn) FROM cdc.lsn_time_mapping WHERE start_lsn <= @P1";
276    static CLEANUP_QUERY: &str = "
277DECLARE @mz_cleanup_status_bit BIT;
278SET @mz_cleanup_status_bit = 0;
279EXEC sys.sp_cdc_cleanup_change_table
280    @capture_instance = @P1,
281    @low_water_mark = @P2,
282    @threshold = @P3,
283    @fCleanupFailed = @mz_cleanup_status_bit OUTPUT;
284SELECT @mz_cleanup_status_bit;
285    ";
286
287    let max_deletes = i64::cast_from(max_deletes);
288
289    // First we need to get a valid LSN as our low watermark. If we try to cleanup
290    // a change table with an LSN that doesn't exist in the `cdc.lsn_time_mapping`
291    // table we'll get an error code `22964`.
292    let result = client
293        .query(GET_LSN_QUERY, &[&low_water_mark.as_bytes().as_slice()])
294        .await?;
295    let low_water_mark_to_use = match &result[..] {
296        [row] => row
297            .try_get::<&[u8], _>(0)?
298            .ok_or_else(|| SqlServerError::InvalidData {
299                column_name: "mz_cleanup_status_bit".to_string(),
300                error: "expected a bool, found NULL".to_string(),
301            })?,
302        other => Err(SqlServerError::ProgrammingError(format!(
303            "expected one row for low water mark, found {other:?}"
304        )))?,
305    };
306
307    // Once we get a valid LSN that is less than or equal to the provided watermark
308    // we can clean up the specified change table!
309    let result = client
310        .query(
311            CLEANUP_QUERY,
312            &[&capture_instance, &low_water_mark_to_use, &max_deletes],
313        )
314        .await;
315
316    let rows = match result {
317        Ok(rows) => rows,
318        Err(SqlServerError::SqlServer(e)) => {
319            // See these remarks from the SQL Server Documentation.
320            //
321            // <https://learn.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sys-sp-cdc-cleanup-change-table-transact-sql?view=sql-server-ver16#remarks>.
322            let already_cleaned_up = e.code().map(|code| code == 22957).unwrap_or(false);
323
324            if already_cleaned_up {
325                return Ok(());
326            } else {
327                return Err(SqlServerError::SqlServer(e));
328            }
329        }
330        Err(other) => return Err(other),
331    };
332
333    match &rows[..] {
334        [row] => {
335            let failure =
336                row.try_get::<bool, _>(0)?
337                    .ok_or_else(|| SqlServerError::InvalidData {
338                        column_name: "mz_cleanup_status_bit".to_string(),
339                        error: "expected a bool, found NULL".to_string(),
340                    })?;
341
342            if failure {
343                Err(super::cdc::CdcError::CleanupFailed {
344                    capture_instance: capture_instance.to_string(),
345                    low_water_mark: *low_water_mark,
346                })?
347            } else {
348                Ok(())
349            }
350        }
351        other => Err(SqlServerError::ProgrammingError(format!(
352            "expected one status row, found {other:?}"
353        ))),
354    }
355}
356
// Retrieves all columns in tables that have CDC (Change Data Capture) enabled.
//
// Returns metadata needed to create an instance of [`SqlServerTableRaw`].
//
// The statement is deliberately left unterminated: callers in this file append
// either a `WHERE` clause plus `;` or just `;` before running it.
//
// The query joins several system tables:
// - sys.tables: Source tables in the database
// - sys.schemas: Schema information for proper table identification
// - sys.columns: Column definitions including nullability
// - sys.types: Data type information for each column
// - cdc.change_tables: CDC configuration linking capture instances to source tables
// - information_schema views: To identify primary key constraints
//
// For each column, it returns:
// - Table identification (schema_name, table_name, capture_instance)
// - Capture instance creation date (capture_instance_create_date)
// - Column metadata (name, type, nullable, max_length, precision, scale, is_computed)
// - Primary key information (constraint name if the column is part of a PK)
static GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY: &str = "
SELECT
    s.name as schema_name,
    t.name as table_name,
    ch.capture_instance as capture_instance,
    ch.create_date as capture_instance_create_date,
    c.name as col_name,
    ty.name as col_type,
    c.is_nullable as col_nullable,
    c.max_length as col_max_length,
    c.precision as col_precision,
    c.scale as col_scale,
    c.is_computed as col_is_computed,
    tc.constraint_name AS col_primary_key_constraint
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.columns c ON t.object_id = c.object_id
JOIN sys.types ty ON c.user_type_id = ty.user_type_id
JOIN cdc.change_tables ch ON t.object_id = ch.source_object_id
LEFT JOIN information_schema.key_column_usage kc
    ON kc.table_schema = s.name
    AND kc.table_name = t.name
    AND kc.column_name = c.name
LEFT JOIN information_schema.table_constraints tc
    ON tc.constraint_catalog = kc.constraint_catalog
    AND tc.constraint_schema = kc.constraint_schema
    AND tc.constraint_name = kc.constraint_name
    AND tc.table_schema = kc.table_schema
    AND tc.table_name = kc.table_name
    AND tc.constraint_type = 'PRIMARY KEY'
";
404
405/// Returns the table metadata for the tables that are tracked by the specified `capture_instance`s.
406pub async fn get_tables_for_capture_instance<'a>(
407    client: &mut Client,
408    capture_instances: impl IntoIterator<Item = &str>,
409) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
410    // SQL Server does not have support for array types, so we need to manually construct
411    // the parameterized query.
412    let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
413    // If there are no tables to check for just return an empty list.
414    if params.is_empty() {
415        return Ok(Vec::default());
416    }
417
418    // TODO(sql_server3): Remove this redundant collection.
419    #[allow(clippy::as_conversions)]
420    let params_dyn: SmallVec<[_; 1]> = params
421        .iter()
422        .map(|instance| instance as &dyn tiberius::ToSql)
423        .collect();
424    let param_indexes = params
425        .iter()
426        .enumerate()
427        // Params are 1-based indexed.
428        .map(|(idx, _)| format!("@P{}", idx + 1))
429        .join(", ");
430
431    let table_for_capture_instance_query = format!(
432        "{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY} WHERE ch.capture_instance IN ({param_indexes});"
433    );
434
435    let result = client
436        .query(&table_for_capture_instance_query, &params_dyn[..])
437        .await?;
438
439    let tables = deserialize_table_columns_to_raw_tables(&result)?;
440
441    Ok(tables)
442}
443
444/// Retrieves column metdata from the CDC table maintained by the provided capture instance. The
445/// resulting column information collection is similar to the information collected for the
446/// upstream table, with the exclusion of nullability and primary key constraints, which contain
447/// static values for CDC columns. CDC table schema is automatically generated and does not attempt
448/// to enforce the same constraints on the data as the upstream table.
449pub async fn get_cdc_table_columns(
450    client: &mut Client,
451    capture_instance: &str,
452) -> Result<BTreeMap<Arc<str>, SqlServerColumnRaw>, SqlServerError> {
453    static CDC_COLUMNS_QUERY: &str = "SELECT \
454        c.name AS col_name, \
455        t.name AS col_type, \
456        c.max_length AS col_max_length, \
457        c.precision AS col_precision, \
458        c.scale AS col_scale, \
459        c.is_computed as col_is_computed \
460    FROM \
461        sys.columns AS c \
462    JOIN sys.types AS t ON c.system_type_id = t.system_type_id AND c.user_type_id = t.user_type_id \
463    WHERE \
464        c.object_id = OBJECT_ID(@P1) AND c.name NOT LIKE '__$%' \
465    ORDER BY c.column_id;";
466    // Strings passed into OBJECT_ID must be escaped
467    let cdc_table_name = format!(
468        "cdc.{table_name}",
469        table_name = quote_identifier(&format!("{capture_instance}_CT"))
470    );
471    let result = client.query(CDC_COLUMNS_QUERY, &[&cdc_table_name]).await?;
472    let mut columns = BTreeMap::new();
473    for row in result.iter() {
474        let column_name: Arc<str> = get_value::<&str>(row, "col_name")?.into();
475        // Reusing this struct even though some of the fields aren't needed because it simplifies
476        // comparison with the upstream table metadata
477        let column = SqlServerColumnRaw {
478            name: Arc::clone(&column_name),
479            data_type: get_value::<&str>(row, "col_type")?.into(),
480            is_nullable: true,
481            primary_key_constraint: None,
482            max_length: get_value(row, "col_max_length")?,
483            precision: get_value(row, "col_precision")?,
484            scale: get_value(row, "col_scale")?,
485            is_computed: get_value(row, "col_is_computed")?,
486        };
487        columns.insert(column_name, column);
488    }
489    Ok(columns)
490}
491
492/// Ensure change data capture (CDC) is enabled for the database the provided
493/// `client` is currently connected to.
494///
495/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/track-changes/enable-and-disable-change-data-capture-sql-server?view=sql-server-ver16>
496pub async fn ensure_database_cdc_enabled(client: &mut Client) -> Result<(), SqlServerError> {
497    static DATABASE_CDC_ENABLED_QUERY: &str =
498        "SELECT is_cdc_enabled FROM sys.databases WHERE database_id = DB_ID();";
499    let result = client.simple_query(DATABASE_CDC_ENABLED_QUERY).await?;
500
501    check_system_result(&result, "database CDC".to_string(), true)?;
502    Ok(())
503}
504
505/// Retrieves the largest `restore_history_id` from SQL Server for the current database.  The
506/// `restore_history_id` column is of type `IDENTITY(1,1)` based on `EXEC sp_help restorehistory`.
507/// We expect it to start at 1 and be incremented by 1, with possible gaps in values.
508/// See:
509/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/restorehistory-transact-sql?view=sql-server-ver17>
510/// - <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql-identity-property?view=sql-server-ver17>
511pub async fn get_latest_restore_history_id(
512    client: &mut Client,
513) -> Result<Option<i32>, SqlServerError> {
514    static LATEST_RESTORE_ID_QUERY: &str = "SELECT TOP 1 restore_history_id \
515        FROM msdb.dbo.restorehistory \
516        WHERE destination_database_name = DB_NAME() \
517        ORDER BY restore_history_id DESC;";
518    let result = client.simple_query(LATEST_RESTORE_ID_QUERY).await?;
519
520    match &result[..] {
521        [] => Ok(None),
522        [row] => Ok(row.try_get::<i32, _>(0)?),
523        other => Err(SqlServerError::InvariantViolated(format!(
524            "expected one row, got {other:?}"
525        ))),
526    }
527}
528
529/// A DDL event collected from the `cdc.ddl_history` table.
530#[derive(Debug)]
531pub struct DDLEvent {
532    pub lsn: Lsn,
533    pub ddl_command: Arc<str>,
534}
535
536impl DDLEvent {
537    /// Returns true if the DDL event is a compatible change, or false if it is not.
538    /// This performs a naive parsing of the DDL command looking for modification of columns
539    ///  1. ALTER TABLE .. ALTER COLUMN
540    ///  2. ALTER TABLE .. DROP COLUMN
541    ///
542    /// See <https://learn.microsoft.com/en-us/sql/t-sql/statements/alter-table-transact-sql?view=sql-server-ver17>
543    pub fn is_compatible(&self) -> bool {
544        // TODO (maz): This is currently a basic check that doesn't take into account type changes.
545        // At some point, we will need to move this to SqlServerTableDesc and expand it.
546        let mut words = self.ddl_command.split_ascii_whitespace();
547        match (
548            words.next().map(str::to_ascii_lowercase).as_deref(),
549            words.next().map(str::to_ascii_lowercase).as_deref(),
550        ) {
551            (Some("alter"), Some("table")) => {
552                let mut peekable = words.peekable();
553                let mut compatible = true;
554                while compatible && let Some(token) = peekable.next() {
555                    compatible = match token.to_ascii_lowercase().as_str() {
556                        "alter" | "drop" => peekable
557                            .peek()
558                            .is_some_and(|next_tok| !next_tok.eq_ignore_ascii_case("column")),
559                        _ => true,
560                    }
561                }
562                compatible
563            }
564            _ => true,
565        }
566    }
567}
568
569/// Returns DDL changes made to the source table for the given capture instance.  This follows the
570/// same convention as `cdc.fn_cdc_get_all_changes_<capture_instance>`, in that the range is
571/// inclusive, i.e. `[from_lsn, to_lsn]`. The events are returned in ascending order of
572/// LSN.
573pub async fn get_ddl_history(
574    client: &mut Client,
575    capture_instance: &str,
576    from_lsn: &Lsn,
577    to_lsn: &Lsn,
578) -> Result<BTreeMap<SqlServerQualifiedTableName, Vec<DDLEvent>>, SqlServerError> {
579    // We query the ddl_history table instead of using the stored procedure as there doesn't
580    // appear to be a way to apply filters or projections against output of the stored procedure
581    // without an intermediate table.
582    static DDL_HISTORY_QUERY: &str = "SELECT \
583                s.name AS schema_name, \
584                t.name AS table_name, \
585                dh.ddl_lsn, \
586                dh.ddl_command
587            FROM \
588                cdc.change_tables ct \
589            JOIN cdc.ddl_history dh ON dh.object_id = ct.object_id \
590            JOIN sys.tables t ON t.object_id = dh.source_object_id \
591            JOIN sys.schemas s ON s.schema_id = t.schema_id \
592            WHERE \
593                ct.capture_instance = @P1 \
594                AND dh.ddl_lsn >= @P2 \
595                AND dh.ddl_lsn <= @P3 \
596            ORDER BY ddl_lsn;";
597
598    let result = client
599        .query(
600            DDL_HISTORY_QUERY,
601            &[
602                &capture_instance,
603                &from_lsn.as_bytes().as_slice(),
604                &to_lsn.as_bytes().as_slice(),
605            ],
606        )
607        .await?;
608
609    // SQL server doesn't support array types, and using string_agg to collect LSN
610    // would require more parsing, so we opt for a BTreeMap to accumulate the results.
611    let mut collector: BTreeMap<_, Vec<_>> = BTreeMap::new();
612    for row in result.iter() {
613        let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
614        let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
615        let lsn: &[u8] = get_value::<&[u8]>(row, "ddl_lsn")?;
616        let ddl_command: Arc<str> = get_value::<&str>(row, "ddl_command")?.into();
617
618        let qualified_table_name = SqlServerQualifiedTableName {
619            schema_name,
620            table_name,
621        };
622        let lsn = Lsn::try_from(lsn).map_err(|lsn_err| SqlServerError::InvalidData {
623            column_name: "ddl_lsn".to_string(),
624            error: lsn_err,
625        })?;
626
627        collector
628            .entry(qualified_table_name)
629            .or_default()
630            .push(DDLEvent { lsn, ddl_command });
631    }
632
633    Ok(collector)
634}
635
636/// Ensure the `SNAPSHOT` transaction isolation level is enabled for the
637/// database the provided `client` is currently connected to.
638///
639/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql?view=sql-server-ver16>
640pub async fn ensure_snapshot_isolation_enabled(client: &mut Client) -> Result<(), SqlServerError> {
641    static SNAPSHOT_ISOLATION_QUERY: &str =
642        "SELECT snapshot_isolation_state FROM sys.databases WHERE database_id = DB_ID();";
643    let result = client.simple_query(SNAPSHOT_ISOLATION_QUERY).await?;
644
645    check_system_result(&result, "snapshot isolation".to_string(), 1u8)?;
646    Ok(())
647}
648
649/// Ensure the SQL Server Agent is running.
650///
651/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-dynamic-management-views/sys-dm-server-services-transact-sql?view=azuresqldb-current&viewFallbackFrom=sql-server-ver17>
652pub async fn ensure_sql_server_agent_running(client: &mut Client) -> Result<(), SqlServerError> {
653    static AGENT_STATUS_QUERY: &str = "SELECT status_desc FROM sys.dm_server_services WHERE servicename LIKE 'SQL Server Agent%';";
654    let result = client.simple_query(AGENT_STATUS_QUERY).await?;
655
656    check_system_result(&result, "SQL Server Agent status".to_string(), "Running")?;
657    Ok(())
658}
659
660pub async fn get_tables(client: &mut Client) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
661    let result = client
662        .simple_query(&format!("{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY};"))
663        .await?;
664
665    let tables = deserialize_table_columns_to_raw_tables(&result)?;
666
667    Ok(tables)
668}
669
670// Helper function to retrieve value from a row.
671fn get_value<'a, T: tiberius::FromSql<'a>>(
672    row: &'a tiberius::Row,
673    name: &'static str,
674) -> Result<T, SqlServerError> {
675    row.try_get(name)?
676        .ok_or(SqlServerError::MissingColumn(name))
677}
678
679fn deserialize_table_columns_to_raw_tables(
680    rows: &[tiberius::Row],
681) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
682    // Group our columns by (schema, name).
683    let mut tables = BTreeMap::default();
684    for row in rows {
685        let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
686        let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
687        let capture_instance: Arc<str> = get_value::<&str>(row, "capture_instance")?.into();
688        let capture_instance_create_date: NaiveDateTime =
689            get_value::<NaiveDateTime>(row, "capture_instance_create_date")?;
690        let primary_key_constraint: Option<Arc<str>> = row
691            .try_get::<&str, _>("col_primary_key_constraint")?
692            .map(|v| v.into());
693
694        let column_name = get_value::<&str>(row, "col_name")?.into();
695        let column = SqlServerColumnRaw {
696            name: Arc::clone(&column_name),
697            data_type: get_value::<&str>(row, "col_type")?.into(),
698            is_nullable: get_value(row, "col_nullable")?,
699            primary_key_constraint,
700            max_length: get_value(row, "col_max_length")?,
701            precision: get_value(row, "col_precision")?,
702            scale: get_value(row, "col_scale")?,
703            is_computed: get_value(row, "col_is_computed")?,
704        };
705
706        let columns: &mut Vec<_> = tables
707            .entry((
708                Arc::clone(&schema_name),
709                Arc::clone(&table_name),
710                Arc::clone(&capture_instance),
711                capture_instance_create_date,
712            ))
713            .or_default();
714        columns.push(column);
715    }
716
717    // Flatten into our raw Table description.
718    let raw_tables = tables
719        .into_iter()
720        .map(
721            |((schema, name, capture_instance, capture_instance_create_date), columns)| {
722                SqlServerTableRaw {
723                    schema_name: schema,
724                    name,
725                    capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
726                        name: capture_instance,
727                        create_date: capture_instance_create_date.into(),
728                    }),
729                    columns: columns.into(),
730                }
731            },
732        )
733        .collect::<Vec<SqlServerTableRaw>>();
734
735    Ok(raw_tables)
736}
737
738/// Return a [`Stream`] that is the entire snapshot of the specified table.
739pub fn snapshot(
740    client: &mut Client,
741    table: &SqlServerTableRaw,
742) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> {
743    let cols = table
744        .columns
745        .iter()
746        .map(|SqlServerColumnRaw { name, .. }| quote_identifier(name))
747        .join(",");
748    let query = format!(
749        "SELECT {cols} FROM {schema_name}.{table_name};",
750        schema_name = quote_identifier(&table.schema_name),
751        table_name = quote_identifier(&table.name)
752    );
753    client.query_streaming(query, &[])
754}
755
756/// Returns the total number of rows present in the specified table.
757pub async fn snapshot_size(
758    client: &mut Client,
759    schema: &str,
760    table: &str,
761) -> Result<usize, SqlServerError> {
762    let query = format!(
763        "SELECT COUNT(*) FROM {schema_name}.{table_name};",
764        schema_name = quote_identifier(schema),
765        table_name = quote_identifier(table)
766    );
767    let result = client.query(query, &[]).await?;
768
769    match &result[..] {
770        [row] => match row.try_get::<i32, _>(0)? {
771            Some(count @ 0..) => Ok(usize::try_from(count).expect("known to fit")),
772            Some(negative) => Err(SqlServerError::InvalidData {
773                column_name: "count".to_string(),
774                error: format!("found negative count: {negative}"),
775            }),
776            None => Err(SqlServerError::InvalidData {
777                column_name: "count".to_string(),
778                error: "expected a value found NULL".to_string(),
779            }),
780        },
781        other => Err(SqlServerError::InvariantViolated(format!(
782            "expected one row, got {other:?}"
783        ))),
784    }
785}
786
787/// Helper function to parse an expected result from a "system" query.
788fn check_system_result<'a, T>(
789    result: &'a SmallVec<[tiberius::Row; 1]>,
790    name: String,
791    expected: T,
792) -> Result<(), SqlServerError>
793where
794    T: tiberius::FromSql<'a> + Copy + fmt::Debug + fmt::Display + PartialEq,
795{
796    match &result[..] {
797        [row] => {
798            let result: Option<T> = row.try_get(0)?;
799            if result == Some(expected) {
800                Ok(())
801            } else {
802                Err(SqlServerError::InvalidSystemSetting {
803                    name,
804                    expected: expected.to_string(),
805                    actual: format!("{result:?}"),
806                })
807            }
808        }
809        other => Err(SqlServerError::InvariantViolated(format!(
810            "expected 1 row, got {other:?}"
811        ))),
812    }
813}
814
815/// Return a Result that is empty if all tables, columns, and capture instances
816/// have the necessary permissions to and an error if any table, column,
817/// or capture instance does not have the necessary permissions
818/// for tracking changes.
819pub async fn validate_source_privileges<'a>(
820    client: &mut Client,
821    capture_instances: impl IntoIterator<Item = &str>,
822) -> Result<(), SqlServerError> {
823    let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
824
825    if params.is_empty() {
826        return Ok(());
827    }
828
829    let params_dyn: SmallVec<[_; 1]> = params
830        .iter()
831        .map(|instance| {
832            let instance: &dyn tiberius::ToSql = instance;
833            instance
834        })
835        .collect();
836
837    let param_indexes = (1..params.len() + 1)
838        .map(|idx| format!("@P{}", idx))
839        .join(", ");
840
841    // NB(ptravers): we rely on HAS_PERMS_BY_NAME to check both table and column permissions.
842    let capture_instance_query = format!(
843            "
844        SELECT
845            SCHEMA_NAME(o.schema_id) + '.' + o.name AS qualified_table_name,
846            ct.capture_instance AS capture_instance,
847            COALESCE(HAS_PERMS_BY_NAME(SCHEMA_NAME(o.schema_id) + '.' + o.name, 'OBJECT', 'SELECT'), 0) AS table_select,
848            COALESCE(HAS_PERMS_BY_NAME('cdc.' + QUOTENAME(ct.capture_instance + '_CT') , 'OBJECT', 'SELECT'), 0) AS capture_table_select
849        FROM cdc.change_tables ct
850        JOIN sys.objects o ON o.object_id = ct.source_object_id
851        WHERE ct.capture_instance IN ({param_indexes});
852            "
853        );
854
855    let rows = client
856        .query(capture_instance_query, &params_dyn[..])
857        .await?;
858
859    let mut capture_instances_without_perms = vec![];
860    let mut tables_without_perms = vec![];
861
862    for row in rows {
863        let table: &str = row
864            .try_get("qualified_table_name")
865            .context("getting table column")?
866            .ok_or_else(|| anyhow::anyhow!("no table column?"))?;
867
868        let capture_instance: &str = row
869            .try_get("capture_instance")
870            .context("getting capture_instance column")?
871            .ok_or_else(|| anyhow::anyhow!("no capture_instance column?"))?;
872
873        let permitted_table: i32 = row
874            .try_get("table_select")
875            .context("getting table_select column")?
876            .ok_or_else(|| anyhow::anyhow!("no table_select column?"))?;
877
878        let permitted_capture_instance: i32 = row
879            .try_get("capture_table_select")
880            .context("getting capture_table_select column")?
881            .ok_or_else(|| anyhow::anyhow!("no capture_table_select column?"))?;
882
883        if permitted_table == 0 {
884            tables_without_perms.push(table.to_string());
885        }
886
887        if permitted_capture_instance == 0 {
888            capture_instances_without_perms.push(capture_instance.to_string());
889        }
890    }
891
892    if !capture_instances_without_perms.is_empty() || !tables_without_perms.is_empty() {
893        return Err(SqlServerError::AuthorizationError {
894            tables: tables_without_perms.join(", "),
895            capture_instances: capture_instances_without_perms.join(", "),
896        });
897    }
898
899    Ok(())
900}