1use anyhow::Context;
13use chrono::NaiveDateTime;
14use futures::Stream;
15use itertools::Itertools;
16use mz_ore::cast::CastFrom;
17use mz_ore::retry::RetryResult;
18use smallvec::SmallVec;
19use std::collections::BTreeMap;
20use std::fmt;
21use std::sync::Arc;
22use std::time::Duration;
23use tiberius::numeric::Numeric;
24
25use crate::cdc::{Lsn, RowFilterOption};
26use crate::desc::{
27 SqlServerCaptureInstanceRaw, SqlServerColumnRaw, SqlServerQualifiedTableName, SqlServerTableRaw,
28};
29use crate::{Client, SqlServerError, quote_identifier};
30
31pub async fn get_min_lsn(
35 client: &mut Client,
36 capture_instance: &str,
37) -> Result<Lsn, SqlServerError> {
38 static MIN_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_min_lsn(@P1);";
39 let result = client.query(MIN_LSN_QUERY, &[&capture_instance]).await?;
40
41 mz_ore::soft_assert_eq_or_log!(result.len(), 1);
42 parse_lsn(&result[..1])
43}
44pub async fn get_min_lsn_retry(
49 client: &mut Client,
50 capture_instance: &str,
51 max_retry_duration: Duration,
52) -> Result<Lsn, SqlServerError> {
53 let (_client, lsn_result) = mz_ore::retry::Retry::default()
54 .max_duration(max_retry_duration)
55 .retry_async_with_state(client, |_, client| async {
56 let result = crate::inspect::get_min_lsn(client, capture_instance).await;
57 (client, map_null_lsn_to_retry(result))
58 })
59 .await;
60 let Ok(lsn) = lsn_result else {
61 tracing::warn!("database did not report a minimum LSN in time");
62 return lsn_result;
63 };
64 Ok(lsn)
65}
66
67pub async fn get_max_lsn(client: &mut Client) -> Result<Lsn, SqlServerError> {
75 static MAX_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_max_lsn();";
76 let result = client.simple_query(MAX_LSN_QUERY).await?;
77
78 mz_ore::soft_assert_eq_or_log!(result.len(), 1);
79 parse_lsn(&result[..1])
80}
81
82pub async fn get_min_lsns(
92 client: &mut Client,
93 capture_instances: impl IntoIterator<Item = &str>,
94) -> Result<BTreeMap<Arc<str>, Lsn>, SqlServerError> {
95 let capture_instances: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
96 let values: Vec<_> = capture_instances
97 .iter()
98 .map(|ci| {
99 let ci: &dyn tiberius::ToSql = ci;
100 ci
101 })
102 .collect();
103 let args = (0..capture_instances.len())
104 .map(|i| format!("@P{}", i + 1))
105 .collect::<Vec<_>>()
106 .join(",");
107 let stmt = format!(
108 "SELECT capture_instance, start_lsn FROM cdc.change_tables WHERE capture_instance IN ({args});"
109 );
110 let result = client.query(stmt, &values).await?;
111 let min_lsns = result
112 .into_iter()
113 .map(|row| {
114 let capture_instance: Arc<str> = row
115 .try_get::<&str, _>("capture_instance")?
116 .ok_or_else(|| {
117 SqlServerError::ProgrammingError(
118 "missing column 'capture_instance'".to_string(),
119 )
120 })?
121 .into();
122 let start_lsn: &[u8] = row.try_get("start_lsn")?.ok_or_else(|| {
123 SqlServerError::ProgrammingError("missing column 'start_lsn'".to_string())
124 })?;
125 let min_lsn = Lsn::try_from(start_lsn).map_err(|msg| SqlServerError::InvalidData {
126 column_name: "lsn".to_string(),
127 error: format!("Error parsing LSN for {capture_instance}: {msg}"),
128 })?;
129 Ok::<_, SqlServerError>((capture_instance, min_lsn))
130 })
131 .collect::<Result<_, _>>()?;
132
133 Ok(min_lsns)
134}
135
136pub async fn get_max_lsn_retry(
145 client: &mut Client,
146 max_retry_duration: Duration,
147) -> Result<Lsn, SqlServerError> {
148 let (_client, lsn_result) = mz_ore::retry::Retry::default()
149 .max_duration(max_retry_duration)
150 .retry_async_with_state(client, |_, client| async {
151 let result = crate::inspect::get_max_lsn(client).await;
152 (client, map_null_lsn_to_retry(result))
153 })
154 .await;
155
156 let Ok(lsn) = lsn_result else {
157 tracing::warn!("database did not report a maximum LSN in time");
158 return lsn_result;
159 };
160 Ok(lsn)
161}
162
163fn map_null_lsn_to_retry<T>(result: Result<T, SqlServerError>) -> RetryResult<T, SqlServerError> {
164 match result {
165 Ok(val) => RetryResult::Ok(val),
166 Err(err @ SqlServerError::NullLsn) => RetryResult::RetryableErr(err),
167 Err(other) => RetryResult::FatalErr(other),
168 }
169}
170
171pub async fn increment_lsn(client: &mut Client, lsn: Lsn) -> Result<Lsn, SqlServerError> {
175 static INCREMENT_LSN_QUERY: &str = "SELECT sys.fn_cdc_increment_lsn(@P1);";
176 let result = client
177 .query(INCREMENT_LSN_QUERY, &[&lsn.as_bytes().as_slice()])
178 .await?;
179
180 mz_ore::soft_assert_eq_or_log!(result.len(), 1);
181 parse_lsn(&result[..1])
182}
183
184pub(crate) fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
188 match row {
189 [r] => {
190 let numeric_lsn = r
191 .try_get::<Numeric, _>(0)?
192 .ok_or_else(|| SqlServerError::NullLsn)?;
193 let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
194 column_name: "lsn".to_string(),
195 error: msg,
196 })?;
197 Ok(lsn)
198 }
199 other => Err(SqlServerError::InvalidData {
200 column_name: "lsn".to_string(),
201 error: format!("expected 1 column, got {other:?}"),
202 }),
203 }
204}
205
206fn parse_lsn(result: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
210 match result {
211 [row] => {
212 let val = row
213 .try_get::<&[u8], _>(0)?
214 .ok_or_else(|| SqlServerError::NullLsn)?;
215 if val.is_empty() {
216 Err(SqlServerError::NullLsn)
217 } else {
218 let lsn = Lsn::try_from(val).map_err(|msg| SqlServerError::InvalidData {
219 column_name: "lsn".to_string(),
220 error: msg,
221 })?;
222 Ok(lsn)
223 }
224 }
225 other => Err(SqlServerError::InvalidData {
226 column_name: "lsn".to_string(),
227 error: format!("expected 1 column, got {other:?}"),
228 }),
229 }
230}
231
232pub fn get_changes_asc(
235 client: &mut Client,
236 capture_instance: &str,
237 start_lsn: Lsn,
238 end_lsn: Lsn,
239 filter: RowFilterOption,
240) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send {
241 const START_LSN_COLUMN: &str = "__$start_lsn";
242 let query = format!(
243 "SELECT * FROM cdc.{function}(@P1, @P2, N'{filter}') ORDER BY {START_LSN_COLUMN} ASC;",
244 function = quote_identifier(&format!("fn_cdc_get_all_changes_{capture_instance}"))
245 );
246 client.query_streaming(
247 query,
248 &[
249 &start_lsn.as_bytes().as_slice(),
250 &end_lsn.as_bytes().as_slice(),
251 ],
252 )
253}
254
255pub async fn cleanup_change_table(
269 client: &mut Client,
270 capture_instance: &str,
271 low_water_mark: &Lsn,
272 max_deletes: u32,
273) -> Result<(), SqlServerError> {
274 static GET_LSN_QUERY: &str =
275 "SELECT MAX(start_lsn) FROM cdc.lsn_time_mapping WHERE start_lsn <= @P1";
276 static CLEANUP_QUERY: &str = "
277DECLARE @mz_cleanup_status_bit BIT;
278SET @mz_cleanup_status_bit = 0;
279EXEC sys.sp_cdc_cleanup_change_table
280 @capture_instance = @P1,
281 @low_water_mark = @P2,
282 @threshold = @P3,
283 @fCleanupFailed = @mz_cleanup_status_bit OUTPUT;
284SELECT @mz_cleanup_status_bit;
285 ";
286
287 let max_deletes = i64::cast_from(max_deletes);
288
289 let result = client
293 .query(GET_LSN_QUERY, &[&low_water_mark.as_bytes().as_slice()])
294 .await?;
295 let low_water_mark_to_use = match &result[..] {
296 [row] => row
297 .try_get::<&[u8], _>(0)?
298 .ok_or_else(|| SqlServerError::InvalidData {
299 column_name: "mz_cleanup_status_bit".to_string(),
300 error: "expected a bool, found NULL".to_string(),
301 })?,
302 other => Err(SqlServerError::ProgrammingError(format!(
303 "expected one row for low water mark, found {other:?}"
304 )))?,
305 };
306
307 let result = client
310 .query(
311 CLEANUP_QUERY,
312 &[&capture_instance, &low_water_mark_to_use, &max_deletes],
313 )
314 .await;
315
316 let rows = match result {
317 Ok(rows) => rows,
318 Err(SqlServerError::SqlServer(e)) => {
319 let already_cleaned_up = e.code().map(|code| code == 22957).unwrap_or(false);
323
324 if already_cleaned_up {
325 return Ok(());
326 } else {
327 return Err(SqlServerError::SqlServer(e));
328 }
329 }
330 Err(other) => return Err(other),
331 };
332
333 match &rows[..] {
334 [row] => {
335 let failure =
336 row.try_get::<bool, _>(0)?
337 .ok_or_else(|| SqlServerError::InvalidData {
338 column_name: "mz_cleanup_status_bit".to_string(),
339 error: "expected a bool, found NULL".to_string(),
340 })?;
341
342 if failure {
343 Err(super::cdc::CdcError::CleanupFailed {
344 capture_instance: capture_instance.to_string(),
345 low_water_mark: *low_water_mark,
346 })?
347 } else {
348 Ok(())
349 }
350 }
351 other => Err(SqlServerError::ProgrammingError(format!(
352 "expected one status row, found {other:?}"
353 ))),
354 }
355}
356
/// Column metadata for every table that has a CDC capture instance.
///
/// Yields one row per (column, capture instance). `col_primary_key_constraint` is the name
/// of the table's primary key constraint when the column belongs to it, otherwise NULL.
///
/// The primary-key lookup is done in a derived table that only contains PRIMARY KEY
/// columns. Left-joining `key_column_usage` directly would match every constraint a column
/// participates in (FOREIGN KEY, UNIQUE, ...) and emit duplicate rows for that column.
/// A table has at most one primary key, so this join adds at most one match per column.
///
/// The string has no trailing `;` so callers can append a `WHERE` clause.
static GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY: &str = "
SELECT
    s.name as schema_name,
    t.name as table_name,
    ch.capture_instance as capture_instance,
    ch.create_date as capture_instance_create_date,
    c.name as col_name,
    ty.name as col_type,
    c.is_nullable as col_nullable,
    c.max_length as col_max_length,
    c.precision as col_precision,
    c.scale as col_scale,
    c.is_computed as col_is_computed,
    pk.constraint_name AS col_primary_key_constraint
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.columns c ON t.object_id = c.object_id
JOIN sys.types ty ON c.user_type_id = ty.user_type_id
JOIN cdc.change_tables ch ON t.object_id = ch.source_object_id
LEFT JOIN (
    SELECT kc.table_schema, kc.table_name, kc.column_name, tc.constraint_name
    FROM information_schema.key_column_usage kc
    JOIN information_schema.table_constraints tc
        ON tc.constraint_catalog = kc.constraint_catalog
        AND tc.constraint_schema = kc.constraint_schema
        AND tc.constraint_name = kc.constraint_name
        AND tc.table_schema = kc.table_schema
        AND tc.table_name = kc.table_name
    WHERE tc.constraint_type = 'PRIMARY KEY'
) pk
    ON pk.table_schema = s.name
    AND pk.table_name = t.name
    AND pk.column_name = c.name
";
404
405pub async fn get_tables_for_capture_instance<'a>(
407 client: &mut Client,
408 capture_instances: impl IntoIterator<Item = &str>,
409) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
410 let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
413 if params.is_empty() {
415 return Ok(Vec::default());
416 }
417
418 #[allow(clippy::as_conversions)]
420 let params_dyn: SmallVec<[_; 1]> = params
421 .iter()
422 .map(|instance| instance as &dyn tiberius::ToSql)
423 .collect();
424 let param_indexes = params
425 .iter()
426 .enumerate()
427 .map(|(idx, _)| format!("@P{}", idx + 1))
429 .join(", ");
430
431 let table_for_capture_instance_query = format!(
432 "{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY} WHERE ch.capture_instance IN ({param_indexes});"
433 );
434
435 let result = client
436 .query(&table_for_capture_instance_query, ¶ms_dyn[..])
437 .await?;
438
439 let tables = deserialize_table_columns_to_raw_tables(&result)?;
440
441 Ok(tables)
442}
443
444pub async fn get_cdc_table_columns(
450 client: &mut Client,
451 capture_instance: &str,
452) -> Result<BTreeMap<Arc<str>, SqlServerColumnRaw>, SqlServerError> {
453 static CDC_COLUMNS_QUERY: &str = "SELECT \
454 c.name AS col_name, \
455 t.name AS col_type, \
456 c.max_length AS col_max_length, \
457 c.precision AS col_precision, \
458 c.scale AS col_scale, \
459 c.is_computed as col_is_computed \
460 FROM \
461 sys.columns AS c \
462 JOIN sys.types AS t ON c.system_type_id = t.system_type_id AND c.user_type_id = t.user_type_id \
463 WHERE \
464 c.object_id = OBJECT_ID(@P1) AND c.name NOT LIKE '__$%' \
465 ORDER BY c.column_id;";
466 let cdc_table_name = format!(
468 "cdc.{table_name}",
469 table_name = quote_identifier(&format!("{capture_instance}_CT"))
470 );
471 let result = client.query(CDC_COLUMNS_QUERY, &[&cdc_table_name]).await?;
472 let mut columns = BTreeMap::new();
473 for row in result.iter() {
474 let column_name: Arc<str> = get_value::<&str>(row, "col_name")?.into();
475 let column = SqlServerColumnRaw {
478 name: Arc::clone(&column_name),
479 data_type: get_value::<&str>(row, "col_type")?.into(),
480 is_nullable: true,
481 primary_key_constraint: None,
482 max_length: get_value(row, "col_max_length")?,
483 precision: get_value(row, "col_precision")?,
484 scale: get_value(row, "col_scale")?,
485 is_computed: get_value(row, "col_is_computed")?,
486 };
487 columns.insert(column_name, column);
488 }
489 Ok(columns)
490}
491
492pub async fn ensure_database_cdc_enabled(client: &mut Client) -> Result<(), SqlServerError> {
497 static DATABASE_CDC_ENABLED_QUERY: &str =
498 "SELECT is_cdc_enabled FROM sys.databases WHERE database_id = DB_ID();";
499 let result = client.simple_query(DATABASE_CDC_ENABLED_QUERY).await?;
500
501 check_system_result(&result, "database CDC".to_string(), true)?;
502 Ok(())
503}
504
505pub async fn get_latest_restore_history_id(
512 client: &mut Client,
513) -> Result<Option<i32>, SqlServerError> {
514 static LATEST_RESTORE_ID_QUERY: &str = "SELECT TOP 1 restore_history_id \
515 FROM msdb.dbo.restorehistory \
516 WHERE destination_database_name = DB_NAME() \
517 ORDER BY restore_history_id DESC;";
518 let result = client.simple_query(LATEST_RESTORE_ID_QUERY).await?;
519
520 match &result[..] {
521 [] => Ok(None),
522 [row] => Ok(row.try_get::<i32, _>(0)?),
523 other => Err(SqlServerError::InvariantViolated(format!(
524 "expected one row, got {other:?}"
525 ))),
526 }
527}
528
/// One DDL statement recorded for a capture instance, as built by `get_ddl_history` from
/// rows of `cdc.ddl_history`.
#[derive(Debug)]
pub struct DDLEvent {
    /// LSN of the DDL statement (the `ddl_lsn` column).
    pub lsn: Lsn,
    /// Raw text of the DDL statement (the `ddl_command` column).
    pub ddl_command: Arc<str>,
}
535
536impl DDLEvent {
537 pub fn is_compatible(&self) -> bool {
544 let mut words = self.ddl_command.split_ascii_whitespace();
547 match (
548 words.next().map(str::to_ascii_lowercase).as_deref(),
549 words.next().map(str::to_ascii_lowercase).as_deref(),
550 ) {
551 (Some("alter"), Some("table")) => {
552 let mut peekable = words.peekable();
553 let mut compatible = true;
554 while compatible && let Some(token) = peekable.next() {
555 compatible = match token.to_ascii_lowercase().as_str() {
556 "alter" | "drop" => peekable
557 .peek()
558 .is_some_and(|next_tok| !next_tok.eq_ignore_ascii_case("column")),
559 _ => true,
560 }
561 }
562 compatible
563 }
564 _ => true,
565 }
566 }
567}
568
569pub async fn get_ddl_history(
574 client: &mut Client,
575 capture_instance: &str,
576 from_lsn: &Lsn,
577 to_lsn: &Lsn,
578) -> Result<BTreeMap<SqlServerQualifiedTableName, Vec<DDLEvent>>, SqlServerError> {
579 static DDL_HISTORY_QUERY: &str = "SELECT \
583 s.name AS schema_name, \
584 t.name AS table_name, \
585 dh.ddl_lsn, \
586 dh.ddl_command
587 FROM \
588 cdc.change_tables ct \
589 JOIN cdc.ddl_history dh ON dh.object_id = ct.object_id \
590 JOIN sys.tables t ON t.object_id = dh.source_object_id \
591 JOIN sys.schemas s ON s.schema_id = t.schema_id \
592 WHERE \
593 ct.capture_instance = @P1 \
594 AND dh.ddl_lsn >= @P2 \
595 AND dh.ddl_lsn <= @P3 \
596 ORDER BY ddl_lsn;";
597
598 let result = client
599 .query(
600 DDL_HISTORY_QUERY,
601 &[
602 &capture_instance,
603 &from_lsn.as_bytes().as_slice(),
604 &to_lsn.as_bytes().as_slice(),
605 ],
606 )
607 .await?;
608
609 let mut collector: BTreeMap<_, Vec<_>> = BTreeMap::new();
612 for row in result.iter() {
613 let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
614 let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
615 let lsn: &[u8] = get_value::<&[u8]>(row, "ddl_lsn")?;
616 let ddl_command: Arc<str> = get_value::<&str>(row, "ddl_command")?.into();
617
618 let qualified_table_name = SqlServerQualifiedTableName {
619 schema_name,
620 table_name,
621 };
622 let lsn = Lsn::try_from(lsn).map_err(|lsn_err| SqlServerError::InvalidData {
623 column_name: "ddl_lsn".to_string(),
624 error: lsn_err,
625 })?;
626
627 collector
628 .entry(qualified_table_name)
629 .or_default()
630 .push(DDLEvent { lsn, ddl_command });
631 }
632
633 Ok(collector)
634}
635
636pub async fn ensure_snapshot_isolation_enabled(client: &mut Client) -> Result<(), SqlServerError> {
641 static SNAPSHOT_ISOLATION_QUERY: &str =
642 "SELECT snapshot_isolation_state FROM sys.databases WHERE database_id = DB_ID();";
643 let result = client.simple_query(SNAPSHOT_ISOLATION_QUERY).await?;
644
645 check_system_result(&result, "snapshot isolation".to_string(), 1u8)?;
646 Ok(())
647}
648
649pub async fn ensure_sql_server_agent_running(client: &mut Client) -> Result<(), SqlServerError> {
653 static AGENT_STATUS_QUERY: &str = "SELECT status_desc FROM sys.dm_server_services WHERE servicename LIKE 'SQL Server Agent%';";
654 let result = client.simple_query(AGENT_STATUS_QUERY).await?;
655
656 check_system_result(&result, "SQL Server Agent status".to_string(), "Running")?;
657 Ok(())
658}
659
660pub async fn get_tables(client: &mut Client) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
661 let result = client
662 .simple_query(&format!("{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY};"))
663 .await?;
664
665 let tables = deserialize_table_columns_to_raw_tables(&result)?;
666
667 Ok(tables)
668}
669
670fn get_value<'a, T: tiberius::FromSql<'a>>(
672 row: &'a tiberius::Row,
673 name: &'static str,
674) -> Result<T, SqlServerError> {
675 row.try_get(name)?
676 .ok_or(SqlServerError::MissingColumn(name))
677}
678
679fn deserialize_table_columns_to_raw_tables(
680 rows: &[tiberius::Row],
681) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
682 let mut tables = BTreeMap::default();
684 for row in rows {
685 let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
686 let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
687 let capture_instance: Arc<str> = get_value::<&str>(row, "capture_instance")?.into();
688 let capture_instance_create_date: NaiveDateTime =
689 get_value::<NaiveDateTime>(row, "capture_instance_create_date")?;
690 let primary_key_constraint: Option<Arc<str>> = row
691 .try_get::<&str, _>("col_primary_key_constraint")?
692 .map(|v| v.into());
693
694 let column_name = get_value::<&str>(row, "col_name")?.into();
695 let column = SqlServerColumnRaw {
696 name: Arc::clone(&column_name),
697 data_type: get_value::<&str>(row, "col_type")?.into(),
698 is_nullable: get_value(row, "col_nullable")?,
699 primary_key_constraint,
700 max_length: get_value(row, "col_max_length")?,
701 precision: get_value(row, "col_precision")?,
702 scale: get_value(row, "col_scale")?,
703 is_computed: get_value(row, "col_is_computed")?,
704 };
705
706 let columns: &mut Vec<_> = tables
707 .entry((
708 Arc::clone(&schema_name),
709 Arc::clone(&table_name),
710 Arc::clone(&capture_instance),
711 capture_instance_create_date,
712 ))
713 .or_default();
714 columns.push(column);
715 }
716
717 let raw_tables = tables
719 .into_iter()
720 .map(
721 |((schema, name, capture_instance, capture_instance_create_date), columns)| {
722 SqlServerTableRaw {
723 schema_name: schema,
724 name,
725 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
726 name: capture_instance,
727 create_date: capture_instance_create_date.into(),
728 }),
729 columns: columns.into(),
730 }
731 },
732 )
733 .collect::<Vec<SqlServerTableRaw>>();
734
735 Ok(raw_tables)
736}
737
738pub fn snapshot(
740 client: &mut Client,
741 table: &SqlServerTableRaw,
742) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> {
743 let cols = table
744 .columns
745 .iter()
746 .map(|SqlServerColumnRaw { name, .. }| quote_identifier(name))
747 .join(",");
748 let query = format!(
749 "SELECT {cols} FROM {schema_name}.{table_name};",
750 schema_name = quote_identifier(&table.schema_name),
751 table_name = quote_identifier(&table.name)
752 );
753 client.query_streaming(query, &[])
754}
755
756pub async fn snapshot_size(
758 client: &mut Client,
759 schema: &str,
760 table: &str,
761) -> Result<usize, SqlServerError> {
762 let query = format!(
763 "SELECT COUNT(*) FROM {schema_name}.{table_name};",
764 schema_name = quote_identifier(schema),
765 table_name = quote_identifier(table)
766 );
767 let result = client.query(query, &[]).await?;
768
769 match &result[..] {
770 [row] => match row.try_get::<i32, _>(0)? {
771 Some(count @ 0..) => Ok(usize::try_from(count).expect("known to fit")),
772 Some(negative) => Err(SqlServerError::InvalidData {
773 column_name: "count".to_string(),
774 error: format!("found negative count: {negative}"),
775 }),
776 None => Err(SqlServerError::InvalidData {
777 column_name: "count".to_string(),
778 error: "expected a value found NULL".to_string(),
779 }),
780 },
781 other => Err(SqlServerError::InvariantViolated(format!(
782 "expected one row, got {other:?}"
783 ))),
784 }
785}
786
787fn check_system_result<'a, T>(
789 result: &'a SmallVec<[tiberius::Row; 1]>,
790 name: String,
791 expected: T,
792) -> Result<(), SqlServerError>
793where
794 T: tiberius::FromSql<'a> + Copy + fmt::Debug + fmt::Display + PartialEq,
795{
796 match &result[..] {
797 [row] => {
798 let result: Option<T> = row.try_get(0)?;
799 if result == Some(expected) {
800 Ok(())
801 } else {
802 Err(SqlServerError::InvalidSystemSetting {
803 name,
804 expected: expected.to_string(),
805 actual: format!("{result:?}"),
806 })
807 }
808 }
809 other => Err(SqlServerError::InvariantViolated(format!(
810 "expected 1 row, got {other:?}"
811 ))),
812 }
813}
814
815pub async fn validate_source_privileges<'a>(
820 client: &mut Client,
821 capture_instances: impl IntoIterator<Item = &str>,
822) -> Result<(), SqlServerError> {
823 let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
824
825 if params.is_empty() {
826 return Ok(());
827 }
828
829 let params_dyn: SmallVec<[_; 1]> = params
830 .iter()
831 .map(|instance| {
832 let instance: &dyn tiberius::ToSql = instance;
833 instance
834 })
835 .collect();
836
837 let param_indexes = (1..params.len() + 1)
838 .map(|idx| format!("@P{}", idx))
839 .join(", ");
840
841 let capture_instance_query = format!(
843 "
844 SELECT
845 SCHEMA_NAME(o.schema_id) + '.' + o.name AS qualified_table_name,
846 ct.capture_instance AS capture_instance,
847 COALESCE(HAS_PERMS_BY_NAME(SCHEMA_NAME(o.schema_id) + '.' + o.name, 'OBJECT', 'SELECT'), 0) AS table_select,
848 COALESCE(HAS_PERMS_BY_NAME('cdc.' + QUOTENAME(ct.capture_instance + '_CT') , 'OBJECT', 'SELECT'), 0) AS capture_table_select
849 FROM cdc.change_tables ct
850 JOIN sys.objects o ON o.object_id = ct.source_object_id
851 WHERE ct.capture_instance IN ({param_indexes});
852 "
853 );
854
855 let rows = client
856 .query(capture_instance_query, ¶ms_dyn[..])
857 .await?;
858
859 let mut capture_instances_without_perms = vec![];
860 let mut tables_without_perms = vec![];
861
862 for row in rows {
863 let table: &str = row
864 .try_get("qualified_table_name")
865 .context("getting table column")?
866 .ok_or_else(|| anyhow::anyhow!("no table column?"))?;
867
868 let capture_instance: &str = row
869 .try_get("capture_instance")
870 .context("getting capture_instance column")?
871 .ok_or_else(|| anyhow::anyhow!("no capture_instance column?"))?;
872
873 let permitted_table: i32 = row
874 .try_get("table_select")
875 .context("getting table_select column")?
876 .ok_or_else(|| anyhow::anyhow!("no table_select column?"))?;
877
878 let permitted_capture_instance: i32 = row
879 .try_get("capture_table_select")
880 .context("getting capture_table_select column")?
881 .ok_or_else(|| anyhow::anyhow!("no capture_table_select column?"))?;
882
883 if permitted_table == 0 {
884 tables_without_perms.push(table.to_string());
885 }
886
887 if permitted_capture_instance == 0 {
888 capture_instances_without_perms.push(capture_instance.to_string());
889 }
890 }
891
892 if !capture_instances_without_perms.is_empty() || !tables_without_perms.is_empty() {
893 return Err(SqlServerError::AuthorizationError {
894 tables: tables_without_perms.join(", "),
895 capture_instances: capture_instances_without_perms.join(", "),
896 });
897 }
898
899 Ok(())
900}