mz_sql_server_util/
inspect.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Useful queries to inspect the state of a SQL Server instance.
11
12use itertools::Itertools;
13use smallvec::SmallVec;
14use std::sync::Arc;
15
16use crate::cdc::{Lsn, RowFilterOption};
17use crate::{Client, SqlServerError};
18
19/// Returns the minimum log sequence number for the specified `capture_instance`.
20///
21/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-min-lsn-transact-sql?view=sql-server-ver16>
22pub async fn get_min_lsn(
23    client: &mut Client,
24    capture_instance: &str,
25) -> Result<Lsn, SqlServerError> {
26    static MIN_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_min_lsn(@P1);";
27    let result = client.query(MIN_LSN_QUERY, &[&capture_instance]).await?;
28
29    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
30    parse_lsn(&result[..1])
31}
32
33/// Returns the maximum log sequence number for the entire database.
34///
35/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-get-max-lsn-transact-sql?view=sql-server-ver16>
36pub async fn get_max_lsn(client: &mut Client) -> Result<Lsn, SqlServerError> {
37    static MAX_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_max_lsn();";
38    let result = client.simple_query(MAX_LSN_QUERY).await?;
39
40    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
41    parse_lsn(&result[..1])
42}
43
44/// Increments the log sequence number.
45///
46/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-increment-lsn-transact-sql?view=sql-server-ver16>
47pub async fn increment_lsn(client: &mut Client, lsn: Lsn) -> Result<Lsn, SqlServerError> {
48    static INCREMENT_LSN_QUERY: &str = "SELECT sys.fn_cdc_increment_lsn(@P1);";
49    let result = client
50        .query(INCREMENT_LSN_QUERY, &[&lsn.as_bytes()])
51        .await?;
52
53    mz_ore::soft_assert_eq_or_log!(result.len(), 1);
54    parse_lsn(&result[..1])
55}
56
57/// Parse an [`Lsn`] from the first column of the provided [`tiberius::Row`].
58///
59/// Returns an error if the provided slice doesn't have exactly one row.
60fn parse_lsn(result: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
61    match result {
62        [row] => {
63            let val = row
64                .try_get::<&[u8], _>(0)?
65                .ok_or_else(|| SqlServerError::InvalidData {
66                    column_name: "lsn".to_string(),
67                    error: "expected LSN at column 0, but found Null".to_string(),
68                })?;
69            let lsn = Lsn::try_from(val).map_err(|msg| SqlServerError::InvalidData {
70                column_name: "lsn".to_string(),
71                error: msg,
72            })?;
73
74            Ok(lsn)
75        }
76        other => Err(SqlServerError::InvalidData {
77            column_name: "lsn".to_string(),
78            error: format!("expected 1 column, got {other:?}"),
79        }),
80    }
81}
82
83/// Queries the specified capture instance and returns all changes from `start_lsn` to `end_lsn`.
84///
85/// TODO(sql_server1): This presents an opportunity for SQL injection. We should create a stored
86/// procedure using `QUOTENAME` to sanitize the input for the capture instance provided by the
87/// user.
88pub async fn get_changes(
89    client: &mut Client,
90    capture_instance: &str,
91    start_lsn: Lsn,
92    end_lsn: Lsn,
93    filter: RowFilterOption,
94) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
95    let query = format!(
96        "SELECT * FROM cdc.fn_cdc_get_all_changes_{capture_instance}(@P1, @P2, N'{filter}');"
97    );
98    let results = client
99        .query(&query, &[&start_lsn.as_bytes(), &end_lsn.as_bytes()])
100        .await?;
101
102    Ok(results)
103}
104
105/// Returns the `(capture_instance, schema_name, table_name)` for the tables
106/// that are tracked by the specified `capture_instance`s.
107pub async fn get_tables_for_capture_instance<'a>(
108    client: &mut Client,
109    capture_instances: impl IntoIterator<Item = &str>,
110) -> Result<Vec<(Arc<str>, Arc<str>, Arc<str>)>, SqlServerError> {
111    // SQL Server does not have support for array types, so we need to manually construct
112    // the parameterized query.
113    let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
114    // TODO(sql_server3): Remove this redundant collection.
115    #[allow(clippy::as_conversions)]
116    let params_dyn: SmallVec<[_; 1]> = params
117        .iter()
118        .map(|instance| instance as &dyn tiberius::ToSql)
119        .collect();
120    let param_indexes = params
121        .iter()
122        .enumerate()
123        // Params are 1-based indexed.
124        .map(|(idx, _)| format!("@P{}", idx + 1))
125        .join(", ");
126
127    let table_for_capture_instance_query = format!(
128        "
129SELECT c.capture_instance, SCHEMA_NAME(o.schema_id) as schema_name, o.name as obj_name
130FROM sys.objects o
131JOIN cdc.change_tables c
132ON o.object_id = c.source_object_id
133WHERE c.capture_instance IN ({param_indexes});"
134    );
135
136    let result = client
137        .query(&table_for_capture_instance_query, &params_dyn[..])
138        .await?;
139    let tables = result
140        .into_iter()
141        .map(|row| {
142            let capture_instance: &str = row.try_get("capture_instance")?.ok_or_else(|| {
143                SqlServerError::ProgrammingError("missing column 'capture_instance'".to_string())
144            })?;
145            let schema_name: &str = row.try_get("schema_name")?.ok_or_else(|| {
146                SqlServerError::ProgrammingError("missing column 'schema_name'".to_string())
147            })?;
148            let table_name: &str = row.try_get("obj_name")?.ok_or_else(|| {
149                SqlServerError::ProgrammingError("missing column 'schema_name'".to_string())
150            })?;
151
152            Ok::<_, SqlServerError>((
153                capture_instance.into(),
154                schema_name.into(),
155                table_name.into(),
156            ))
157        })
158        .collect::<Result<_, _>>()?;
159
160    Ok(tables)
161}