mz_sql_server_util/
inspect.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Useful queries to inspect the state of a SQL Server instance.
11
12use anyhow::Context;
13use chrono::NaiveDateTime;
14use futures::Stream;
15use itertools::Itertools;
16use mz_ore::cast::CastFrom;
17use mz_ore::retry::RetryResult;
18use smallvec::SmallVec;
19use std::collections::BTreeMap;
20use std::fmt;
21use std::sync::Arc;
22use std::time::Duration;
23use tiberius::numeric::Numeric;
24
25use crate::cdc::{Lsn, RowFilterOption};
26use crate::desc::{SqlServerCaptureInstanceRaw, SqlServerColumnRaw, SqlServerTableRaw};
27use crate::{Client, SqlServerError};
28
29/// Returns the minimum log sequence number for the specified `capture_instance`.
30///
31/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-min-lsn-transact-sql?view=sql-server-ver16>
32pub async fn get_min_lsn(
33    client: &mut Client,
34    capture_instance: &str,
35) -> Result<Lsn, SqlServerError> {
36    static MIN_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_min_lsn(@P1);";
37    let result = client.query(MIN_LSN_QUERY, &[&capture_instance]).await?;
38
39    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
40    parse_lsn(&result[..1])
41}
42/// Returns the minimum log sequence number for the specified `capture_instance`, retrying
43/// if the log sequence number is not available.
44///
45/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-min-lsn-transact-sql?view=sql-server-ver16>
46pub async fn get_min_lsn_retry(
47    client: &mut Client,
48    capture_instance: &str,
49    max_retry_duration: Duration,
50) -> Result<Lsn, SqlServerError> {
51    let (_client, lsn_result) = mz_ore::retry::Retry::default()
52        .max_duration(max_retry_duration)
53        .retry_async_with_state(client, |_, client| async {
54            let result = crate::inspect::get_min_lsn(client, capture_instance).await;
55            (client, map_null_lsn_to_retry(result))
56        })
57        .await;
58    let Ok(lsn) = lsn_result else {
59        tracing::warn!("database did not report a minimum LSN in time");
60        return lsn_result;
61    };
62    Ok(lsn)
63}
64
65/// Returns the maximum log sequence number for the entire database.
66/// This implementation relies on CDC, which is asynchronous, so may
67/// return an LSN that is less than the maximum LSN of SQL server.
68///
69/// See:
70/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-max-lsn-transact-sql?view=sql-server-ver16>
71/// - <https://groups.google.com/g/debezium/c/47Yg2r166KM/m/lHqtRF2xAQAJ?pli=1>
72pub async fn get_max_lsn(client: &mut Client) -> Result<Lsn, SqlServerError> {
73    static MAX_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_max_lsn();";
74    let result = client.simple_query(MAX_LSN_QUERY).await?;
75
76    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
77    parse_lsn(&result[..1])
78}
79
80/// Retrieves the minumum [`Lsn`] (start_lsn field) from `cdc.change_tables`
81/// for the specified capture instances.
82///
83/// This is based on the `sys.fn_cdc_get_min_lsn` implementation, which has logic
84/// that we want to bypass. Specifically, `sys.fn_cdc_get_min_lsn` returns NULL
85/// if the `start_lsn` in `cdc.change_tables` is less than or equal to the LSN
86/// returned by `sys.fn_cdc_get_max_lsn`.
87///
88/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/cdc-change-tables-transact-sql?view=sql-server-ver16>
89pub async fn get_min_lsns(
90    client: &mut Client,
91    capture_instances: impl IntoIterator<Item = &str>,
92) -> Result<BTreeMap<Arc<str>, Lsn>, SqlServerError> {
93    let capture_instances: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
94    let values: Vec<_> = capture_instances
95        .iter()
96        .map(|ci| {
97            let ci: &dyn tiberius::ToSql = ci;
98            ci
99        })
100        .collect();
101    let args = (0..capture_instances.len())
102        .map(|i| format!("@P{}", i + 1))
103        .collect::<Vec<_>>()
104        .join(",");
105    let stmt = format!(
106        "SELECT capture_instance, start_lsn FROM cdc.change_tables WHERE capture_instance IN ({args});"
107    );
108    let result = client.query(stmt, &values).await?;
109    let min_lsns = result
110        .into_iter()
111        .map(|row| {
112            let capture_instance: Arc<str> = row
113                .try_get::<&str, _>("capture_instance")?
114                .ok_or_else(|| {
115                    SqlServerError::ProgrammingError(
116                        "missing column 'capture_instance'".to_string(),
117                    )
118                })?
119                .into();
120            let start_lsn: &[u8] = row.try_get("start_lsn")?.ok_or_else(|| {
121                SqlServerError::ProgrammingError("missing column 'start_lsn'".to_string())
122            })?;
123            let min_lsn = Lsn::try_from(start_lsn).map_err(|msg| SqlServerError::InvalidData {
124                column_name: "lsn".to_string(),
125                error: format!("Error parsing LSN for {capture_instance}: {msg}"),
126            })?;
127            Ok::<_, SqlServerError>((capture_instance, min_lsn))
128        })
129        .collect::<Result<_, _>>()?;
130
131    Ok(min_lsns)
132}
133
134/// Returns the maximum log sequence number for the entire database, retrying
135/// if the log sequence number is not available. This implementation relies on
136/// CDC, which is asynchronous, so may return an LSN that is less than the
137/// maximum LSN of SQL server.
138///
139/// See:
140/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-max-lsn-transact-sql?view=sql-server-ver16>
141/// - <https://groups.google.com/g/debezium/c/47Yg2r166KM/m/lHqtRF2xAQAJ?pli=1>
142pub async fn get_max_lsn_retry(
143    client: &mut Client,
144    max_retry_duration: Duration,
145) -> Result<Lsn, SqlServerError> {
146    let (_client, lsn_result) = mz_ore::retry::Retry::default()
147        .max_duration(max_retry_duration)
148        .retry_async_with_state(client, |_, client| async {
149            let result = crate::inspect::get_max_lsn(client).await;
150            (client, map_null_lsn_to_retry(result))
151        })
152        .await;
153
154    let Ok(lsn) = lsn_result else {
155        tracing::warn!("database did not report a maximum LSN in time");
156        return lsn_result;
157    };
158    Ok(lsn)
159}
160
161fn map_null_lsn_to_retry<T>(result: Result<T, SqlServerError>) -> RetryResult<T, SqlServerError> {
162    match result {
163        Ok(val) => RetryResult::Ok(val),
164        Err(err @ SqlServerError::NullLsn) => RetryResult::RetryableErr(err),
165        Err(other) => RetryResult::FatalErr(other),
166    }
167}
168
169/// Increments the log sequence number.
170///
171/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-increment-lsn-transact-sql?view=sql-server-ver16>
172pub async fn increment_lsn(client: &mut Client, lsn: Lsn) -> Result<Lsn, SqlServerError> {
173    static INCREMENT_LSN_QUERY: &str = "SELECT sys.fn_cdc_increment_lsn(@P1);";
174    let result = client
175        .query(INCREMENT_LSN_QUERY, &[&lsn.as_bytes().as_slice()])
176        .await?;
177
178    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
179    parse_lsn(&result[..1])
180}
181
182/// Parse an [`Lsn`] in Decimal(25,0) format of the provided [`tiberius::Row`].
183///
184/// Returns an error if the provided slice doesn't have exactly one row.
185pub(crate) fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
186    match row {
187        [r] => {
188            let numeric_lsn = r
189                .try_get::<Numeric, _>(0)?
190                .ok_or_else(|| SqlServerError::NullLsn)?;
191            let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
192                column_name: "lsn".to_string(),
193                error: msg,
194            })?;
195            Ok(lsn)
196        }
197        other => Err(SqlServerError::InvalidData {
198            column_name: "lsn".to_string(),
199            error: format!("expected 1 column, got {other:?}"),
200        }),
201    }
202}
203
204/// Parse an [`Lsn`] from the first column of the provided [`tiberius::Row`].
205///
206/// Returns an error if the provided slice doesn't have exactly one row.
207fn parse_lsn(result: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
208    match result {
209        [row] => {
210            let val = row
211                .try_get::<&[u8], _>(0)?
212                .ok_or_else(|| SqlServerError::NullLsn)?;
213            if val.is_empty() {
214                Err(SqlServerError::NullLsn)
215            } else {
216                let lsn = Lsn::try_from(val).map_err(|msg| SqlServerError::InvalidData {
217                    column_name: "lsn".to_string(),
218                    error: msg,
219                })?;
220                Ok(lsn)
221            }
222        }
223        other => Err(SqlServerError::InvalidData {
224            column_name: "lsn".to_string(),
225            error: format!("expected 1 column, got {other:?}"),
226        }),
227    }
228}
229
230/// Queries the specified capture instance and returns all changes from
231/// `[start_lsn, end_lsn)`, ordered by `start_lsn` in an ascending fashion.
232///
233/// TODO(sql_server2): This presents an opportunity for SQL injection. We should create a stored
234/// procedure using `QUOTENAME` to sanitize the input for the capture instance provided by the
235/// user.
236pub fn get_changes_asc(
237    client: &mut Client,
238    capture_instance: &str,
239    start_lsn: Lsn,
240    end_lsn: Lsn,
241    filter: RowFilterOption,
242) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send {
243    const START_LSN_COLUMN: &str = "__$start_lsn";
244    let query = format!(
245        "SELECT * FROM cdc.fn_cdc_get_all_changes_{capture_instance}(@P1, @P2, N'{filter}') ORDER BY {START_LSN_COLUMN} ASC;"
246    );
247    client.query_streaming(
248        query,
249        &[
250            &start_lsn.as_bytes().as_slice(),
251            &end_lsn.as_bytes().as_slice(),
252        ],
253    )
254}
255
256/// Cleans up the change table associated with the specified `capture_instance` by
257/// deleting `max_deletes` entries with a `start_lsn` less than `low_water_mark`.
258///
259/// Note: At the moment cleanup is kind of "best effort".  If this query succeeds
260/// then at most `max_delete` rows were deleted, but the number of actual rows
261/// deleted is not returned as part of the query. The number of rows _should_ be
262/// present in an informational message (i.e. a Notice) that is returned, but
263/// [`tiberius`] doesn't expose these to us.
264///
265/// TODO(sql_server2): Update [`tiberius`] to return informational messages so we
266/// can determine how many rows got deleted.
267///
268/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sys-sp-cdc-cleanup-change-table-transact-sql?view=sql-server-ver16>.
269pub async fn cleanup_change_table(
270    client: &mut Client,
271    capture_instance: &str,
272    low_water_mark: &Lsn,
273    max_deletes: u32,
274) -> Result<(), SqlServerError> {
275    static GET_LSN_QUERY: &str =
276        "SELECT MAX(start_lsn) FROM cdc.lsn_time_mapping WHERE start_lsn <= @P1";
277    static CLEANUP_QUERY: &str = "
278DECLARE @mz_cleanup_status_bit BIT;
279SET @mz_cleanup_status_bit = 0;
280EXEC sys.sp_cdc_cleanup_change_table
281    @capture_instance = @P1,
282    @low_water_mark = @P2,
283    @threshold = @P3,
284    @fCleanupFailed = @mz_cleanup_status_bit OUTPUT;
285SELECT @mz_cleanup_status_bit;
286    ";
287
288    let max_deletes = i64::cast_from(max_deletes);
289
290    // First we need to get a valid LSN as our low watermark. If we try to cleanup
291    // a change table with an LSN that doesn't exist in the `cdc.lsn_time_mapping`
292    // table we'll get an error code `22964`.
293    let result = client
294        .query(GET_LSN_QUERY, &[&low_water_mark.as_bytes().as_slice()])
295        .await?;
296    let low_water_mark_to_use = match &result[..] {
297        [row] => row
298            .try_get::<&[u8], _>(0)?
299            .ok_or_else(|| SqlServerError::InvalidData {
300                column_name: "mz_cleanup_status_bit".to_string(),
301                error: "expected a bool, found NULL".to_string(),
302            })?,
303        other => Err(SqlServerError::ProgrammingError(format!(
304            "expected one row for low water mark, found {other:?}"
305        )))?,
306    };
307
308    // Once we get a valid LSN that is less than or equal to the provided watermark
309    // we can clean up the specified change table!
310    let result = client
311        .query(
312            CLEANUP_QUERY,
313            &[&capture_instance, &low_water_mark_to_use, &max_deletes],
314        )
315        .await;
316
317    let rows = match result {
318        Ok(rows) => rows,
319        Err(SqlServerError::SqlServer(e)) => {
320            // See these remarks from the SQL Server Documentation.
321            //
322            // <https://learn.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sys-sp-cdc-cleanup-change-table-transact-sql?view=sql-server-ver16#remarks>.
323            let already_cleaned_up = e.code().map(|code| code == 22957).unwrap_or(false);
324
325            if already_cleaned_up {
326                return Ok(());
327            } else {
328                return Err(SqlServerError::SqlServer(e));
329            }
330        }
331        Err(other) => return Err(other),
332    };
333
334    match &rows[..] {
335        [row] => {
336            let failure =
337                row.try_get::<bool, _>(0)?
338                    .ok_or_else(|| SqlServerError::InvalidData {
339                        column_name: "mz_cleanup_status_bit".to_string(),
340                        error: "expected a bool, found NULL".to_string(),
341                    })?;
342
343            if failure {
344                Err(super::cdc::CdcError::CleanupFailed {
345                    capture_instance: capture_instance.to_string(),
346                    low_water_mark: *low_water_mark,
347                })?
348            } else {
349                Ok(())
350            }
351        }
352        other => Err(SqlServerError::ProgrammingError(format!(
353            "expected one status row, found {other:?}"
354        ))),
355    }
356}
357
/// Retrieves all columns in tables that have CDC (Change Data Capture) enabled.
///
/// Returns the metadata needed to create instances of [`SqlServerTableRaw`]. Callers may
/// append a `WHERE` clause (and must append the terminating `;`).
///
/// The query joins several system tables:
/// - `sys.tables`: Source tables in the database
/// - `sys.schemas`: Schema information for proper table identification
/// - `sys.columns`: Column definitions including nullability
/// - `sys.types`: Data type information for each column
/// - `cdc.change_tables`: CDC configuration linking capture instances to source tables
/// - `information_schema` views: To identify primary key constraints
///
/// For each column (and capture instance), it returns:
/// - Table identification (`schema_name`, `table_name`, `capture_instance`)
/// - Column metadata (name, type, nullable, max_length, precision, scale)
/// - Primary key information (constraint name if the column is part of a PK)
///
/// `key_column_usage` lists columns of *every* key constraint (PRIMARY KEY, UNIQUE,
/// FOREIGN KEY). It is therefore inner-joined with the PRIMARY KEY rows of
/// `table_constraints` first, and only that result is left-joined onto the columns. This
/// yields at most one row per column and capture instance. Left-joining
/// `key_column_usage` directly would emit one row per constraint a column takes part in.
static GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY: &str = "
SELECT
    s.name as schema_name,
    t.name as table_name,
    ch.capture_instance as capture_instance,
    ch.create_date as capture_instance_create_date,
    c.name as col_name,
    ty.name as col_type,
    c.is_nullable as col_nullable,
    c.max_length as col_max_length,
    c.precision as col_precision,
    c.scale as col_scale,
    tc.constraint_name AS col_primary_key_constraint
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.columns c ON t.object_id = c.object_id
JOIN sys.types ty ON c.user_type_id = ty.user_type_id
JOIN cdc.change_tables ch ON t.object_id = ch.source_object_id
LEFT JOIN (
    information_schema.key_column_usage kc
    INNER JOIN information_schema.table_constraints tc
        ON tc.constraint_catalog = kc.constraint_catalog
        AND tc.constraint_schema = kc.constraint_schema
        AND tc.constraint_name = kc.constraint_name
        AND tc.table_schema = kc.table_schema
        AND tc.table_name = kc.table_name
        AND tc.constraint_type = 'PRIMARY KEY'
)
    ON kc.table_schema = s.name
    AND kc.table_name = t.name
    AND kc.column_name = c.name
";
404
405/// Returns the table metadata for the tables that are tracked by the specified `capture_instance`s.
406pub async fn get_tables_for_capture_instance<'a>(
407    client: &mut Client,
408    capture_instances: impl IntoIterator<Item = &str>,
409) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
410    // SQL Server does not have support for array types, so we need to manually construct
411    // the parameterized query.
412    let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
413    // If there are no tables to check for just return an empty list.
414    if params.is_empty() {
415        return Ok(Vec::default());
416    }
417
418    // TODO(sql_server3): Remove this redundant collection.
419    #[allow(clippy::as_conversions)]
420    let params_dyn: SmallVec<[_; 1]> = params
421        .iter()
422        .map(|instance| instance as &dyn tiberius::ToSql)
423        .collect();
424    let param_indexes = params
425        .iter()
426        .enumerate()
427        // Params are 1-based indexed.
428        .map(|(idx, _)| format!("@P{}", idx + 1))
429        .join(", ");
430
431    let table_for_capture_instance_query = format!(
432        "{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY} WHERE ch.capture_instance IN ({param_indexes});"
433    );
434
435    let result = client
436        .query(&table_for_capture_instance_query, &params_dyn[..])
437        .await?;
438
439    let tables = deserialize_table_columns_to_raw_tables(&result)?;
440
441    Ok(tables)
442}
443
444/// Ensure change data capture (CDC) is enabled for the database the provided
445/// `client` is currently connected to.
446///
447/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/track-changes/enable-and-disable-change-data-capture-sql-server?view=sql-server-ver16>
448pub async fn ensure_database_cdc_enabled(client: &mut Client) -> Result<(), SqlServerError> {
449    static DATABASE_CDC_ENABLED_QUERY: &str =
450        "SELECT is_cdc_enabled FROM sys.databases WHERE database_id = DB_ID();";
451    let result = client.simple_query(DATABASE_CDC_ENABLED_QUERY).await?;
452
453    check_system_result(&result, "database CDC".to_string(), true)?;
454    Ok(())
455}
456
457/// Retrieves the largest `restore_history_id` from SQL Server for the current database.  The
458/// `restore_history_id` column is of type `IDENTITY(1,1)` based on `EXEC sp_help restorehistory`.
459/// We expect it to start at 1 and be incremented by 1, with possible gaps in values.
460/// See:
461/// - <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/restorehistory-transact-sql?view=sql-server-ver17>
462/// - <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql-identity-property?view=sql-server-ver17>
463pub async fn get_latest_restore_history_id(
464    client: &mut Client,
465) -> Result<Option<i32>, SqlServerError> {
466    static LATEST_RESTORE_ID_QUERY: &str = "SELECT TOP 1 restore_history_id \
467        FROM msdb.dbo.restorehistory \
468        WHERE destination_database_name = DB_NAME() \
469        ORDER BY restore_history_id DESC;";
470    let result = client.simple_query(LATEST_RESTORE_ID_QUERY).await?;
471
472    match &result[..] {
473        [] => Ok(None),
474        [row] => Ok(row.try_get::<i32, _>(0)?),
475        other => Err(SqlServerError::InvariantViolated(format!(
476            "expected one row, got {other:?}"
477        ))),
478    }
479}
480
481/// Ensure the `SNAPSHOT` transaction isolation level is enabled for the
482/// database the provided `client` is currently connected to.
483///
484/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql?view=sql-server-ver16>
485pub async fn ensure_snapshot_isolation_enabled(client: &mut Client) -> Result<(), SqlServerError> {
486    static SNAPSHOT_ISOLATION_QUERY: &str =
487        "SELECT snapshot_isolation_state FROM sys.databases WHERE database_id = DB_ID();";
488    let result = client.simple_query(SNAPSHOT_ISOLATION_QUERY).await?;
489
490    check_system_result(&result, "snapshot isolation".to_string(), 1u8)?;
491    Ok(())
492}
493
494pub async fn get_tables(client: &mut Client) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
495    let result = client
496        .simple_query(&format!("{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY};"))
497        .await?;
498
499    let tables = deserialize_table_columns_to_raw_tables(&result)?;
500
501    Ok(tables)
502}
503
504fn deserialize_table_columns_to_raw_tables(
505    rows: &[tiberius::Row],
506) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
507    fn get_value<'a, T: tiberius::FromSql<'a>>(
508        row: &'a tiberius::Row,
509        name: &'static str,
510    ) -> Result<T, SqlServerError> {
511        row.try_get(name)?
512            .ok_or(SqlServerError::MissingColumn(name))
513    }
514
515    // Group our columns by (schema, name).
516    let mut tables = BTreeMap::default();
517    for row in rows {
518        let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
519        let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
520        let capture_instance: Arc<str> = get_value::<&str>(row, "capture_instance")?.into();
521        let capture_instance_create_date: NaiveDateTime =
522            get_value::<NaiveDateTime>(row, "capture_instance_create_date")?;
523        let primary_key_constraint: Option<Arc<str>> = row
524            .try_get::<&str, _>("col_primary_key_constraint")?
525            .map(|v| v.into());
526
527        let column_name = get_value::<&str>(row, "col_name")?.into();
528        let column = SqlServerColumnRaw {
529            name: Arc::clone(&column_name),
530            data_type: get_value::<&str>(row, "col_type")?.into(),
531            is_nullable: get_value(row, "col_nullable")?,
532            primary_key_constraint,
533            max_length: get_value(row, "col_max_length")?,
534            precision: get_value(row, "col_precision")?,
535            scale: get_value(row, "col_scale")?,
536        };
537
538        let columns: &mut Vec<_> = tables
539            .entry((
540                Arc::clone(&schema_name),
541                Arc::clone(&table_name),
542                Arc::clone(&capture_instance),
543                capture_instance_create_date,
544            ))
545            .or_default();
546        columns.push(column);
547    }
548
549    // Flatten into our raw Table description.
550    let raw_tables = tables
551        .into_iter()
552        .map(
553            |((schema, name, capture_instance, capture_instance_create_date), columns)| {
554                SqlServerTableRaw {
555                    schema_name: schema,
556                    name,
557                    capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
558                        name: capture_instance,
559                        create_date: capture_instance_create_date.into(),
560                    }),
561                    columns: columns.into(),
562                }
563            },
564        )
565        .collect::<Vec<SqlServerTableRaw>>();
566
567    Ok(raw_tables)
568}
569
570/// Return a [`Stream`] that is the entire snapshot of the specified table.
571pub fn snapshot(
572    client: &mut Client,
573    schema: &str,
574    table: &str,
575) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> {
576    let query = format!("SELECT * FROM {schema}.{table};");
577    client.query_streaming(query, &[])
578}
579
580/// Returns the total number of rows present in the specified table.
581pub async fn snapshot_size(
582    client: &mut Client,
583    schema: &str,
584    table: &str,
585) -> Result<usize, SqlServerError> {
586    let query = format!("SELECT COUNT(*) FROM {schema}.{table};");
587    let result = client.query(query, &[]).await?;
588
589    match &result[..] {
590        [row] => match row.try_get::<i32, _>(0)? {
591            Some(count @ 0..) => Ok(usize::try_from(count).expect("known to fit")),
592            Some(negative) => Err(SqlServerError::InvalidData {
593                column_name: "count".to_string(),
594                error: format!("found negative count: {negative}"),
595            }),
596            None => Err(SqlServerError::InvalidData {
597                column_name: "count".to_string(),
598                error: "expected a value found NULL".to_string(),
599            }),
600        },
601        other => Err(SqlServerError::InvariantViolated(format!(
602            "expected one row, got {other:?}"
603        ))),
604    }
605}
606
607/// Helper function to parse an expected result from a "system" query.
608fn check_system_result<'a, T>(
609    result: &'a SmallVec<[tiberius::Row; 1]>,
610    name: String,
611    expected: T,
612) -> Result<(), SqlServerError>
613where
614    T: tiberius::FromSql<'a> + Copy + fmt::Debug + fmt::Display + PartialEq,
615{
616    match &result[..] {
617        [row] => {
618            let result: Option<T> = row.try_get(0)?;
619            if result == Some(expected) {
620                Ok(())
621            } else {
622                Err(SqlServerError::InvalidSystemSetting {
623                    name,
624                    expected: expected.to_string(),
625                    actual: format!("{result:?}"),
626                })
627            }
628        }
629        other => Err(SqlServerError::InvariantViolated(format!(
630            "expected 1 row, got {other:?}"
631        ))),
632    }
633}
634
635/// Return a Result that is empty if all tables, columns, and capture instances
636/// have the necessary permissions to and an error if any table, column,
637/// or capture instance does not have the necessary permissions
638/// for tracking changes.
639pub async fn validate_source_privileges<'a>(
640    client: &mut Client,
641    capture_instances: impl IntoIterator<Item = &str>,
642) -> Result<(), SqlServerError> {
643    let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
644
645    if params.is_empty() {
646        return Ok(());
647    }
648
649    let params_dyn: SmallVec<[_; 1]> = params
650        .iter()
651        .map(|instance| {
652            let instance: &dyn tiberius::ToSql = instance;
653            instance
654        })
655        .collect();
656
657    let param_indexes = (1..params.len() + 1)
658        .map(|idx| format!("@P{}", idx))
659        .join(", ");
660
661    // NB(ptravers): we rely on HAS_PERMS_BY_NAME to check both table and column permissions.
662    let capture_instance_query = format!(
663            "
664        SELECT
665            SCHEMA_NAME(o.schema_id) + '.' + o.name AS qualified_table_name,
666            ct.capture_instance AS capture_instance,
667            COALESCE(HAS_PERMS_BY_NAME(SCHEMA_NAME(o.schema_id) + '.' + o.name, 'OBJECT', 'SELECT'), 0) AS table_select,
668            COALESCE(HAS_PERMS_BY_NAME('cdc.' + ct.capture_instance + '_CT', 'OBJECT', 'SELECT'), 0) AS capture_table_select
669        FROM cdc.change_tables ct
670        JOIN sys.objects o ON o.object_id = ct.source_object_id
671        WHERE ct.capture_instance IN ({param_indexes});
672            "
673        );
674
675    let rows = client
676        .query(capture_instance_query, &params_dyn[..])
677        .await?;
678
679    let mut capture_instances_without_perms = vec![];
680    let mut tables_without_perms = vec![];
681
682    for row in rows {
683        let table: &str = row
684            .try_get("qualified_table_name")
685            .context("getting table column")?
686            .ok_or_else(|| anyhow::anyhow!("no table column?"))?;
687
688        let capture_instance: &str = row
689            .try_get("capture_instance")
690            .context("getting capture_instance column")?
691            .ok_or_else(|| anyhow::anyhow!("no capture_instance column?"))?;
692
693        let permitted_table: i32 = row
694            .try_get("table_select")
695            .context("getting table_select column")?
696            .ok_or_else(|| anyhow::anyhow!("no table_select column?"))?;
697
698        let permitted_capture_instance: i32 = row
699            .try_get("capture_table_select")
700            .context("getting capture_table_select column")?
701            .ok_or_else(|| anyhow::anyhow!("no capture_table_select column?"))?;
702
703        if permitted_table == 0 {
704            tables_without_perms.push(table.to_string());
705        }
706
707        if permitted_capture_instance == 0 {
708            capture_instances_without_perms.push(capture_instance.to_string());
709        }
710    }
711
712    if !capture_instances_without_perms.is_empty() || !tables_without_perms.is_empty() {
713        return Err(SqlServerError::AuthorizationError {
714            tables: tables_without_perms.join(", "),
715            capture_instances: capture_instances_without_perms.join(", "),
716        });
717    }
718
719    Ok(())
720}