1use anyhow::Context;
13use chrono::NaiveDateTime;
14use futures::Stream;
15use itertools::Itertools;
16use mz_ore::cast::CastFrom;
17use mz_ore::retry::RetryResult;
18use smallvec::SmallVec;
19use std::collections::BTreeMap;
20use std::fmt;
21use std::sync::Arc;
22use std::time::Duration;
23use tiberius::numeric::Numeric;
24
25use crate::cdc::{Lsn, RowFilterOption};
26use crate::desc::{
27 SqlServerCaptureInstanceRaw, SqlServerColumnRaw, SqlServerQualifiedTableName,
28 SqlServerTableConstraintRaw, SqlServerTableRaw,
29};
30use crate::{Client, SqlServerError, quote_identifier};
31
32pub async fn get_min_lsn(
36 client: &mut Client,
37 capture_instance: &str,
38) -> Result<Lsn, SqlServerError> {
39 static MIN_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_min_lsn(@P1);";
40 let result = client.query(MIN_LSN_QUERY, &[&capture_instance]).await?;
41
42 mz_ore::soft_assert_eq_or_log!(result.len(), 1);
43 parse_lsn(&result[..1])
44}
45pub async fn get_min_lsn_retry(
50 client: &mut Client,
51 capture_instance: &str,
52 max_retry_duration: Duration,
53) -> Result<Lsn, SqlServerError> {
54 let (_client, lsn_result) = mz_ore::retry::Retry::default()
55 .max_duration(max_retry_duration)
56 .retry_async_with_state(client, |_, client| async {
57 let result = crate::inspect::get_min_lsn(client, capture_instance).await;
58 (client, map_null_lsn_to_retry(result))
59 })
60 .await;
61 let Ok(lsn) = lsn_result else {
62 tracing::warn!("database did not report a minimum LSN in time");
63 return lsn_result;
64 };
65 Ok(lsn)
66}
67
68pub async fn get_max_lsn(client: &mut Client) -> Result<Lsn, SqlServerError> {
76 static MAX_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_max_lsn();";
77 let result = client.simple_query(MAX_LSN_QUERY).await?;
78
79 mz_ore::soft_assert_eq_or_log!(result.len(), 1);
80 parse_lsn(&result[..1])
81}
82
83pub async fn get_min_lsns(
93 client: &mut Client,
94 capture_instances: impl IntoIterator<Item = &str>,
95) -> Result<BTreeMap<Arc<str>, Lsn>, SqlServerError> {
96 let capture_instances: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
97 let values: Vec<_> = capture_instances
98 .iter()
99 .map(|ci| {
100 let ci: &dyn tiberius::ToSql = ci;
101 ci
102 })
103 .collect();
104 let args = (0..capture_instances.len())
105 .map(|i| format!("@P{}", i + 1))
106 .collect::<Vec<_>>()
107 .join(",");
108 let stmt = format!(
109 "SELECT capture_instance, start_lsn FROM cdc.change_tables WHERE capture_instance IN ({args});"
110 );
111 let result = client.query(stmt, &values).await?;
112 let min_lsns = result
113 .into_iter()
114 .map(|row| {
115 let capture_instance: Arc<str> = row
116 .try_get::<&str, _>("capture_instance")?
117 .ok_or_else(|| {
118 SqlServerError::ProgrammingError(
119 "missing column 'capture_instance'".to_string(),
120 )
121 })?
122 .into();
123 let start_lsn: &[u8] = row.try_get("start_lsn")?.ok_or_else(|| {
124 SqlServerError::ProgrammingError("missing column 'start_lsn'".to_string())
125 })?;
126 let min_lsn = Lsn::try_from(start_lsn).map_err(|msg| SqlServerError::InvalidData {
127 column_name: "lsn".to_string(),
128 error: format!("Error parsing LSN for {capture_instance}: {msg}"),
129 })?;
130 Ok::<_, SqlServerError>((capture_instance, min_lsn))
131 })
132 .collect::<Result<_, _>>()?;
133
134 Ok(min_lsns)
135}
136
137pub async fn get_max_lsn_retry(
146 client: &mut Client,
147 max_retry_duration: Duration,
148) -> Result<Lsn, SqlServerError> {
149 let (_client, lsn_result) = mz_ore::retry::Retry::default()
150 .max_duration(max_retry_duration)
151 .retry_async_with_state(client, |_, client| async {
152 let result = crate::inspect::get_max_lsn(client).await;
153 (client, map_null_lsn_to_retry(result))
154 })
155 .await;
156
157 let Ok(lsn) = lsn_result else {
158 tracing::warn!("database did not report a maximum LSN in time");
159 return lsn_result;
160 };
161 Ok(lsn)
162}
163
164fn map_null_lsn_to_retry<T>(result: Result<T, SqlServerError>) -> RetryResult<T, SqlServerError> {
165 match result {
166 Ok(val) => RetryResult::Ok(val),
167 Err(err @ SqlServerError::NullLsn) => RetryResult::RetryableErr(err),
168 Err(other) => RetryResult::FatalErr(other),
169 }
170}
171
172pub async fn increment_lsn(client: &mut Client, lsn: Lsn) -> Result<Lsn, SqlServerError> {
176 static INCREMENT_LSN_QUERY: &str = "SELECT sys.fn_cdc_increment_lsn(@P1);";
177 let result = client
178 .query(INCREMENT_LSN_QUERY, &[&lsn.as_bytes().as_slice()])
179 .await?;
180
181 mz_ore::soft_assert_eq_or_log!(result.len(), 1);
182 parse_lsn(&result[..1])
183}
184
185pub(crate) fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
189 match row {
190 [r] => {
191 let numeric_lsn = r
192 .try_get::<Numeric, _>(0)?
193 .ok_or_else(|| SqlServerError::NullLsn)?;
194 let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
195 column_name: "lsn".to_string(),
196 error: msg,
197 })?;
198 Ok(lsn)
199 }
200 other => Err(SqlServerError::InvalidData {
201 column_name: "lsn".to_string(),
202 error: format!("expected 1 column, got {other:?}"),
203 }),
204 }
205}
206
207fn parse_lsn(result: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
211 match result {
212 [row] => {
213 let val = row
214 .try_get::<&[u8], _>(0)?
215 .ok_or_else(|| SqlServerError::NullLsn)?;
216 if val.is_empty() {
217 Err(SqlServerError::NullLsn)
218 } else {
219 let lsn = Lsn::try_from(val).map_err(|msg| SqlServerError::InvalidData {
220 column_name: "lsn".to_string(),
221 error: msg,
222 })?;
223 Ok(lsn)
224 }
225 }
226 other => Err(SqlServerError::InvalidData {
227 column_name: "lsn".to_string(),
228 error: format!("expected 1 column, got {other:?}"),
229 }),
230 }
231}
232
233pub fn get_changes_asc(
236 client: &mut Client,
237 capture_instance: &str,
238 start_lsn: Lsn,
239 end_lsn: Lsn,
240 filter: RowFilterOption,
241) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send {
242 const START_LSN_COLUMN: &str = "__$start_lsn";
243 let query = format!(
244 "SELECT * FROM cdc.{function}(@P1, @P2, N'{filter}') ORDER BY {START_LSN_COLUMN} ASC;",
245 function = quote_identifier(&format!("fn_cdc_get_all_changes_{capture_instance}"))
246 );
247 client.query_streaming(
248 query,
249 &[
250 &start_lsn.as_bytes().as_slice(),
251 &end_lsn.as_bytes().as_slice(),
252 ],
253 )
254}
255
256pub async fn cleanup_change_table(
270 client: &mut Client,
271 capture_instance: &str,
272 low_water_mark: &Lsn,
273 max_deletes: u32,
274) -> Result<(), SqlServerError> {
275 static GET_LSN_QUERY: &str =
276 "SELECT MAX(start_lsn) FROM cdc.lsn_time_mapping WHERE start_lsn <= @P1";
277 static CLEANUP_QUERY: &str = "
278DECLARE @mz_cleanup_status_bit BIT;
279SET @mz_cleanup_status_bit = 0;
280EXEC sys.sp_cdc_cleanup_change_table
281 @capture_instance = @P1,
282 @low_water_mark = @P2,
283 @threshold = @P3,
284 @fCleanupFailed = @mz_cleanup_status_bit OUTPUT;
285SELECT @mz_cleanup_status_bit;
286 ";
287
288 let max_deletes = i64::cast_from(max_deletes);
289
290 let result = client
294 .query(GET_LSN_QUERY, &[&low_water_mark.as_bytes().as_slice()])
295 .await?;
296 let low_water_mark_to_use = match &result[..] {
297 [row] => row
298 .try_get::<&[u8], _>(0)?
299 .ok_or_else(|| SqlServerError::InvalidData {
300 column_name: "mz_cleanup_status_bit".to_string(),
301 error: "expected a bool, found NULL".to_string(),
302 })?,
303 other => Err(SqlServerError::ProgrammingError(format!(
304 "expected one row for low water mark, found {other:?}"
305 )))?,
306 };
307
308 let result = client
311 .query(
312 CLEANUP_QUERY,
313 &[&capture_instance, &low_water_mark_to_use, &max_deletes],
314 )
315 .await;
316
317 let rows = match result {
318 Ok(rows) => rows,
319 Err(SqlServerError::SqlServer(e)) => {
320 let already_cleaned_up = e.code().map(|code| code == 22957).unwrap_or(false);
324
325 if already_cleaned_up {
326 return Ok(());
327 } else {
328 return Err(SqlServerError::SqlServer(e));
329 }
330 }
331 Err(other) => return Err(other),
332 };
333
334 match &rows[..] {
335 [row] => {
336 let failure =
337 row.try_get::<bool, _>(0)?
338 .ok_or_else(|| SqlServerError::InvalidData {
339 column_name: "mz_cleanup_status_bit".to_string(),
340 error: "expected a bool, found NULL".to_string(),
341 })?;
342
343 if failure {
344 Err(super::cdc::CdcError::CleanupFailed {
345 capture_instance: capture_instance.to_string(),
346 low_water_mark: *low_water_mark,
347 })?
348 } else {
349 Ok(())
350 }
351 }
352 other => Err(SqlServerError::ProgrammingError(format!(
353 "expected one status row, found {other:?}"
354 ))),
355 }
356}
357
/// Query returning one row per (table, column) pair for every table in the
/// current database that has a CDC capture instance, together with the
/// capture instance's name and creation date.
///
/// Deliberately has no terminating `;` or `WHERE` clause: callers append
/// their own filter (see `get_tables_for_capture_instance`) or just a `;`
/// (see `get_tables`), and group the resulting rows with
/// `deserialize_table_columns_to_raw_tables`.
static GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY: &str = "
SELECT
    s.name as schema_name,
    t.name as table_name,
    ch.capture_instance as capture_instance,
    ch.create_date as capture_instance_create_date,
    c.name as col_name,
    ty.name as col_type,
    c.is_nullable as col_nullable,
    c.max_length as col_max_length,
    c.precision as col_precision,
    c.scale as col_scale,
    c.is_computed as col_is_computed
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.columns c ON t.object_id = c.object_id
JOIN sys.types ty ON c.user_type_id = ty.user_type_id
JOIN cdc.change_tables ch ON t.object_id = ch.source_object_id
";
391
392pub async fn get_tables_for_capture_instance(
394 client: &mut Client,
395 capture_instances: impl IntoIterator<Item = &str>,
396) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
397 let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
400 if params.is_empty() {
402 return Ok(Vec::default());
403 }
404
405 #[allow(clippy::as_conversions)]
407 let params_dyn: SmallVec<[_; 1]> = params
408 .iter()
409 .map(|instance| instance as &dyn tiberius::ToSql)
410 .collect();
411 let param_indexes = params
412 .iter()
413 .enumerate()
414 .map(|(idx, _)| format!("@P{}", idx + 1))
416 .join(", ");
417
418 let table_for_capture_instance_query = format!(
419 "{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY} WHERE ch.capture_instance IN ({param_indexes});"
420 );
421
422 let result = client
423 .query(&table_for_capture_instance_query, ¶ms_dyn[..])
424 .await?;
425
426 let tables = deserialize_table_columns_to_raw_tables(&result)?;
427
428 Ok(tables)
429}
430
431pub async fn get_cdc_table_columns(
437 client: &mut Client,
438 capture_instance: &str,
439) -> Result<BTreeMap<Arc<str>, SqlServerColumnRaw>, SqlServerError> {
440 static CDC_COLUMNS_QUERY: &str = "SELECT \
441 c.name AS col_name, \
442 t.name AS col_type, \
443 c.max_length AS col_max_length, \
444 c.precision AS col_precision, \
445 c.scale AS col_scale, \
446 c.is_computed as col_is_computed \
447 FROM \
448 sys.columns AS c \
449 JOIN sys.types AS t ON c.system_type_id = t.system_type_id AND c.user_type_id = t.user_type_id \
450 WHERE \
451 c.object_id = OBJECT_ID(@P1) AND c.name NOT LIKE '__$%' \
452 ORDER BY c.column_id;";
453 let cdc_table_name = format!(
455 "cdc.{table_name}",
456 table_name = quote_identifier(&format!("{capture_instance}_CT"))
457 );
458 let result = client.query(CDC_COLUMNS_QUERY, &[&cdc_table_name]).await?;
459 let mut columns = BTreeMap::new();
460 for row in result.iter() {
461 let column_name: Arc<str> = get_value::<&str>(row, "col_name")?.into();
462 let column = SqlServerColumnRaw {
465 name: Arc::clone(&column_name),
466 data_type: get_value::<&str>(row, "col_type")?.into(),
467 is_nullable: true,
468 max_length: get_value(row, "col_max_length")?,
469 precision: get_value(row, "col_precision")?,
470 scale: get_value(row, "col_scale")?,
471 is_computed: get_value(row, "col_is_computed")?,
472 };
473 columns.insert(column_name, column);
474 }
475 Ok(columns)
476}
477
478pub async fn get_constraints_for_tables(
481 client: &mut Client,
482 schema_table_list: impl Iterator<Item = &(Arc<str>, Arc<str>)>,
483) -> Result<BTreeMap<(Arc<str>, Arc<str>), Vec<SqlServerTableConstraintRaw>>, SqlServerError> {
484 let qualified_table_names: Vec<_> = schema_table_list
485 .map(|(schema, table)| {
486 format!(
487 "{quoted_schema}.{quoted_table}",
488 quoted_schema = quote_identifier(schema),
489 quoted_table = quote_identifier(table)
490 )
491 })
492 .collect();
493
494 if qualified_table_names.is_empty() {
495 return Ok(Default::default());
496 }
497
498 let params = (1..qualified_table_names.len() + 1)
499 .map(|idx| format!("@P{}", idx))
500 .join(", ");
501
502 let query = format!(
505 "SELECT \
506 tc.table_schema, \
507 tc.table_name, \
508 kcu.column_name, \
509 tc.constraint_name, \
510 tc.constraint_type \
511 FROM information_schema.table_constraints tc \
512 JOIN information_schema.key_column_usage kcu \
513 ON kcu.constraint_schema = tc.constraint_schema \
514 AND kcu.constraint_name = tc.constraint_name \
515 AND kcu.table_schema = tc.table_schema \
516 AND kcu.table_name = tc.table_name \
517 WHERE
518 QUOTENAME(tc.table_schema) + '.' + QUOTENAME(tc.table_name) IN ({params})
519 AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
520 ORDER BY tc.table_schema, tc.table_name, tc.constraint_name, kcu.ordinal_position"
521 );
522
523 let query_params: Vec<_> = qualified_table_names
524 .iter()
525 .map(|qualified_name| {
526 let name: &dyn tiberius::ToSql = qualified_name;
527 name
528 })
529 .collect();
530
531 tracing::debug!("query = {query} params = {qualified_table_names:?}");
532 let result = client.query(query, &query_params).await?;
533
534 let mut contraints_by_table: BTreeMap<_, BTreeMap<_, Vec<_>>> = BTreeMap::new();
535 for row in result {
536 let schema_name: Arc<str> = get_value::<&str>(&row, "table_schema")?.into();
537 let table_name: Arc<str> = get_value::<&str>(&row, "table_name")?.into();
538 let column_name = get_value::<&str>(&row, "column_name")?.into();
539 let constraint_name = get_value::<&str>(&row, "constraint_name")?.into();
540 let constraint_type = get_value::<&str>(&row, "constraint_type")?.into();
541
542 contraints_by_table
543 .entry((Arc::clone(&schema_name), Arc::clone(&table_name)))
544 .or_default()
545 .entry((constraint_name, constraint_type))
546 .or_default()
547 .push(column_name);
548 }
549 Ok(contraints_by_table
550 .into_iter()
551 .inspect(|((schema_name, table_name), constraints)| {
552 tracing::debug!("table {schema_name}.{table_name} constraints: {constraints:?}")
553 })
554 .map(|(qualified_name, constraints)| {
555 (
556 qualified_name,
557 constraints
558 .into_iter()
559 .map(|((constraint_name, constraint_type), columns)| {
560 SqlServerTableConstraintRaw {
561 constraint_name,
562 constraint_type,
563 columns,
564 }
565 })
566 .collect(),
567 )
568 })
569 .collect())
570}
571
572pub async fn ensure_database_cdc_enabled(client: &mut Client) -> Result<(), SqlServerError> {
577 static DATABASE_CDC_ENABLED_QUERY: &str =
578 "SELECT is_cdc_enabled FROM sys.databases WHERE database_id = DB_ID();";
579 let result = client.simple_query(DATABASE_CDC_ENABLED_QUERY).await?;
580
581 check_system_result(&result, "database CDC".to_string(), true)?;
582 Ok(())
583}
584
585pub async fn get_latest_restore_history_id(
592 client: &mut Client,
593) -> Result<Option<i32>, SqlServerError> {
594 static LATEST_RESTORE_ID_QUERY: &str = "SELECT TOP 1 restore_history_id \
595 FROM msdb.dbo.restorehistory \
596 WHERE destination_database_name = DB_NAME() \
597 ORDER BY restore_history_id DESC;";
598 let result = client.simple_query(LATEST_RESTORE_ID_QUERY).await?;
599
600 match &result[..] {
601 [] => Ok(None),
602 [row] => Ok(row.try_get::<i32, _>(0)?),
603 other => Err(SqlServerError::InvariantViolated(format!(
604 "expected one row, got {other:?}"
605 ))),
606 }
607}
608
/// A single DDL change recorded for a CDC-tracked table (see
/// `get_ddl_history`, which reads these from `cdc.ddl_history`).
#[derive(Debug)]
pub struct DDLEvent {
    // The LSN at which the DDL change was recorded.
    pub lsn: Lsn,
    // The raw DDL statement text, e.g. `ALTER TABLE ... DROP COLUMN ...`.
    pub ddl_command: Arc<str>,
}
615
impl DDLEvent {
    /// Returns `true` if this DDL command is compatible with continuing to
    /// replicate the columns named in `included_columns`.
    ///
    /// Only `ALTER TABLE ... ALTER COLUMN ...` / `ALTER TABLE ... DROP
    /// COLUMN ...` statements can be incompatible, and only when at least one
    /// of the affected columns appears in `included_columns` (compared
    /// case-insensitively, ignoring `[`, `]`, and `"` quoting). Anything else
    /// -- ADD COLUMN, DROP CONSTRAINT, non-ALTER-TABLE statements -- is
    /// treated as compatible.
    pub fn is_compatible(&self, included_columns: &[Arc<str>]) -> bool {
        let mut words = self.ddl_command.split_ascii_whitespace();
        // Only statements beginning with `ALTER TABLE` are inspected further.
        match (
            words.next().map(str::to_ascii_lowercase).as_deref(),
            words.next().map(str::to_ascii_lowercase).as_deref(),
        ) {
            (Some("alter"), Some("table")) => {
                let mut peekable = words.peekable();
                let mut compatible = true;
                // Scan the remaining tokens, stopping at the first
                // incompatibility.
                while compatible && let Some(token) = peekable.next() {
                    compatible = match token.to_ascii_lowercase().as_str() {
                        "alter" | "drop" => {
                            let target = peekable.next();
                            match target {
                                // `ALTER COLUMN` / `DROP COLUMN`: walk the
                                // (possibly comma-separated) column list.
                                Some(t) if t.eq_ignore_ascii_case("column") => {
                                    let mut all_excluded = true;
                                    while let Some(tok) = peekable.next() {
                                        match tok.to_ascii_lowercase().as_str() {
                                            // Noise tokens that may appear between
                                            // column names.
                                            "if" | "exists" | "," | "column" => continue,
                                            col_str => {
                                                // One token may itself hold several
                                                // comma-separated names; incompatible
                                                // if any is a replicated column.
                                                if !col_str.trim_matches(',').split(',').all(
                                                    |col_name| {
                                                        !included_columns.iter().any(|included| {
                                                            included.eq_ignore_ascii_case(
                                                                col_name.trim_matches(
                                                                    ['[', ']', '"'].as_ref(),
                                                                ),
                                                            )
                                                        })
                                                    },
                                                ) {
                                                    all_excluded = false;
                                                    break;
                                                }
                                                // If this token didn't end with a comma,
                                                // the list continues only if the next
                                                // token starts with one.
                                                if !col_str.ends_with(",") {
                                                    match peekable.peek() {
                                                        Some(x) if x.starts_with(",") => continue,
                                                        _ => break,
                                                    }
                                                }
                                            }
                                        };
                                    }
                                    all_excluded
                                }
                                // `ALTER`/`DROP` with nothing after it: be
                                // conservative and flag incompatible.
                                None => false,
                                // e.g. `DROP CONSTRAINT ...` -- not a column change.
                                _ => true,
                            }
                        }
                        _ => true,
                    }
                }
                compatible
            }
            // Non-`ALTER TABLE` DDL never removes or retypes tracked columns.
            _ => true,
        }
    }
}
693
694pub async fn get_ddl_history(
699 client: &mut Client,
700 capture_instance: &str,
701 from_lsn: &Lsn,
702 to_lsn: &Lsn,
703) -> Result<BTreeMap<SqlServerQualifiedTableName, Vec<DDLEvent>>, SqlServerError> {
704 static DDL_HISTORY_QUERY: &str = "SELECT \
708 s.name AS schema_name, \
709 t.name AS table_name, \
710 dh.ddl_lsn, \
711 dh.ddl_command
712 FROM \
713 cdc.change_tables ct \
714 JOIN cdc.ddl_history dh ON dh.object_id = ct.object_id \
715 JOIN sys.tables t ON t.object_id = dh.source_object_id \
716 JOIN sys.schemas s ON s.schema_id = t.schema_id \
717 WHERE \
718 ct.capture_instance = @P1 \
719 AND dh.ddl_lsn >= @P2 \
720 AND dh.ddl_lsn <= @P3 \
721 ORDER BY ddl_lsn;";
722
723 let result = client
724 .query(
725 DDL_HISTORY_QUERY,
726 &[
727 &capture_instance,
728 &from_lsn.as_bytes().as_slice(),
729 &to_lsn.as_bytes().as_slice(),
730 ],
731 )
732 .await?;
733
734 let mut collector: BTreeMap<_, Vec<_>> = BTreeMap::new();
737 for row in result.iter() {
738 let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
739 let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
740 let lsn: &[u8] = get_value::<&[u8]>(row, "ddl_lsn")?;
741 let ddl_command: Arc<str> = get_value::<&str>(row, "ddl_command")?.into();
742
743 let qualified_table_name = SqlServerQualifiedTableName {
744 schema_name,
745 table_name,
746 };
747 let lsn = Lsn::try_from(lsn).map_err(|lsn_err| SqlServerError::InvalidData {
748 column_name: "ddl_lsn".to_string(),
749 error: lsn_err,
750 })?;
751
752 collector
753 .entry(qualified_table_name)
754 .or_default()
755 .push(DDLEvent { lsn, ddl_command });
756 }
757
758 Ok(collector)
759}
760
761pub async fn ensure_snapshot_isolation_enabled(client: &mut Client) -> Result<(), SqlServerError> {
766 static SNAPSHOT_ISOLATION_QUERY: &str =
767 "SELECT snapshot_isolation_state FROM sys.databases WHERE database_id = DB_ID();";
768 let result = client.simple_query(SNAPSHOT_ISOLATION_QUERY).await?;
769
770 check_system_result(&result, "snapshot isolation".to_string(), 1u8)?;
771 Ok(())
772}
773
774pub async fn ensure_sql_server_agent_running(client: &mut Client) -> Result<(), SqlServerError> {
778 static AGENT_STATUS_QUERY: &str = "SELECT status_desc FROM sys.dm_server_services WHERE servicename LIKE 'SQL Server Agent%';";
779 let result = client.simple_query(AGENT_STATUS_QUERY).await?;
780
781 check_system_result(&result, "SQL Server Agent status".to_string(), "Running")?;
782 Ok(())
783}
784
785pub async fn get_tables(client: &mut Client) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
786 let result = client
787 .simple_query(&format!("{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY};"))
788 .await?;
789
790 let tables = deserialize_table_columns_to_raw_tables(&result)?;
791
792 Ok(tables)
793}
794
795fn get_value<'a, T: tiberius::FromSql<'a>>(
797 row: &'a tiberius::Row,
798 name: &'static str,
799) -> Result<T, SqlServerError> {
800 row.try_get(name)?
801 .ok_or(SqlServerError::MissingColumn(name))
802}
803fn deserialize_table_columns_to_raw_tables(
804 rows: &[tiberius::Row],
805) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
806 let mut tables = BTreeMap::default();
808 for row in rows {
809 let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
810 let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
811 let capture_instance: Arc<str> = get_value::<&str>(row, "capture_instance")?.into();
812 let capture_instance_create_date: NaiveDateTime =
813 get_value::<NaiveDateTime>(row, "capture_instance_create_date")?;
814
815 let column_name = get_value::<&str>(row, "col_name")?.into();
816 let column = SqlServerColumnRaw {
817 name: Arc::clone(&column_name),
818 data_type: get_value::<&str>(row, "col_type")?.into(),
819 is_nullable: get_value(row, "col_nullable")?,
820 max_length: get_value(row, "col_max_length")?,
821 precision: get_value(row, "col_precision")?,
822 scale: get_value(row, "col_scale")?,
823 is_computed: get_value(row, "col_is_computed")?,
824 };
825
826 let columns: &mut Vec<_> = tables
827 .entry((
828 Arc::clone(&schema_name),
829 Arc::clone(&table_name),
830 Arc::clone(&capture_instance),
831 capture_instance_create_date,
832 ))
833 .or_default();
834 columns.push(column);
835 }
836
837 let raw_tables = tables
838 .into_iter()
839 .map(
840 |((schema, name, capture_instance, capture_instance_create_date), columns)| {
841 SqlServerTableRaw {
842 schema_name: schema,
843 name,
844 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
845 name: capture_instance,
846 create_date: capture_instance_create_date.into(),
847 }),
848 columns: columns.into(),
849 }
850 },
851 )
852 .collect();
853 Ok(raw_tables)
854}
855
856pub fn snapshot(
858 client: &mut Client,
859 table: &SqlServerTableRaw,
860) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> {
861 let cols = table
862 .columns
863 .iter()
864 .map(|SqlServerColumnRaw { name, .. }| quote_identifier(name))
865 .join(",");
866 let query = format!(
867 "SELECT {cols} FROM {schema_name}.{table_name};",
868 schema_name = quote_identifier(&table.schema_name),
869 table_name = quote_identifier(&table.name)
870 );
871 client.query_streaming(query, &[])
872}
873
874pub async fn snapshot_size(
876 client: &mut Client,
877 schema: &str,
878 table: &str,
879) -> Result<usize, SqlServerError> {
880 let query = format!(
881 "SELECT COUNT(*) FROM {schema_name}.{table_name};",
882 schema_name = quote_identifier(schema),
883 table_name = quote_identifier(table)
884 );
885 let result = client.query(query, &[]).await?;
886
887 match &result[..] {
888 [row] => match row.try_get::<i32, _>(0)? {
889 Some(count @ 0..) => Ok(usize::try_from(count).expect("known to fit")),
890 Some(negative) => Err(SqlServerError::InvalidData {
891 column_name: "count".to_string(),
892 error: format!("found negative count: {negative}"),
893 }),
894 None => Err(SqlServerError::InvalidData {
895 column_name: "count".to_string(),
896 error: "expected a value found NULL".to_string(),
897 }),
898 },
899 other => Err(SqlServerError::InvariantViolated(format!(
900 "expected one row, got {other:?}"
901 ))),
902 }
903}
904
905fn check_system_result<'a, T>(
907 result: &'a SmallVec<[tiberius::Row; 1]>,
908 name: String,
909 expected: T,
910) -> Result<(), SqlServerError>
911where
912 T: tiberius::FromSql<'a> + Copy + fmt::Debug + fmt::Display + PartialEq,
913{
914 match &result[..] {
915 [row] => {
916 let result: Option<T> = row.try_get(0)?;
917 if result == Some(expected) {
918 Ok(())
919 } else {
920 Err(SqlServerError::InvalidSystemSetting {
921 name,
922 expected: expected.to_string(),
923 actual: format!("{result:?}"),
924 })
925 }
926 }
927 other => Err(SqlServerError::InvariantViolated(format!(
928 "expected 1 row, got {other:?}"
929 ))),
930 }
931}
932
933pub async fn validate_source_privileges(
938 client: &mut Client,
939 capture_instances: impl IntoIterator<Item = &str>,
940) -> Result<(), SqlServerError> {
941 let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
942
943 if params.is_empty() {
944 return Ok(());
945 }
946
947 let params_dyn: SmallVec<[_; 1]> = params
948 .iter()
949 .map(|instance| {
950 let instance: &dyn tiberius::ToSql = instance;
951 instance
952 })
953 .collect();
954
955 let param_indexes = (1..params.len() + 1)
956 .map(|idx| format!("@P{}", idx))
957 .join(", ");
958
959 let capture_instance_query = format!(
961 "
962 SELECT
963 SCHEMA_NAME(o.schema_id) + '.' + o.name AS qualified_table_name,
964 ct.capture_instance AS capture_instance,
965 COALESCE(HAS_PERMS_BY_NAME(SCHEMA_NAME(o.schema_id) + '.' + o.name, 'OBJECT', 'SELECT'), 0) AS table_select,
966 COALESCE(HAS_PERMS_BY_NAME('cdc.' + QUOTENAME(ct.capture_instance + '_CT') , 'OBJECT', 'SELECT'), 0) AS capture_table_select
967 FROM cdc.change_tables ct
968 JOIN sys.objects o ON o.object_id = ct.source_object_id
969 WHERE ct.capture_instance IN ({param_indexes});
970 "
971 );
972
973 let rows = client
974 .query(capture_instance_query, ¶ms_dyn[..])
975 .await?;
976
977 let mut capture_instances_without_perms = vec![];
978 let mut tables_without_perms = vec![];
979
980 for row in rows {
981 let table: &str = row
982 .try_get("qualified_table_name")
983 .context("getting table column")?
984 .ok_or_else(|| anyhow::anyhow!("no table column?"))?;
985
986 let capture_instance: &str = row
987 .try_get("capture_instance")
988 .context("getting capture_instance column")?
989 .ok_or_else(|| anyhow::anyhow!("no capture_instance column?"))?;
990
991 let permitted_table: i32 = row
992 .try_get("table_select")
993 .context("getting table_select column")?
994 .ok_or_else(|| anyhow::anyhow!("no table_select column?"))?;
995
996 let permitted_capture_instance: i32 = row
997 .try_get("capture_table_select")
998 .context("getting capture_table_select column")?
999 .ok_or_else(|| anyhow::anyhow!("no capture_table_select column?"))?;
1000
1001 if permitted_table == 0 {
1002 tables_without_perms.push(table.to_string());
1003 }
1004
1005 if permitted_capture_instance == 0 {
1006 capture_instances_without_perms.push(capture_instance.to_string());
1007 }
1008 }
1009
1010 if !capture_instances_without_perms.is_empty() || !tables_without_perms.is_empty() {
1011 return Err(SqlServerError::AuthorizationError {
1012 tables: tables_without_perms.join(", "),
1013 capture_instances: capture_instances_without_perms.join(", "),
1014 });
1015 }
1016
1017 Ok(())
1018}
1019
#[cfg(test)]
mod tests {
    use super::DDLEvent;
    use std::sync::Arc;

    #[mz_ore::test]
    fn test_ddl_event_is_compatible() {
        // Asserts that `is_compatible(included_columns)` for the given DDL
        // command returns `expected`.
        fn test_case(ddl_command: &str, included_columns: &[Arc<str>], expected: bool) {
            let ddl_event = DDLEvent {
                lsn: Default::default(),
                ddl_command: ddl_command.into(),
            };
            let result = ddl_event.is_compatible(included_columns);
            assert_eq!(
                result, expected,
                "DDL command '{}' with included_columns {:?} expected to be {}, got {}",
                ddl_command, included_columns, expected, result
            );
        }

        // The replicated columns; col1/col2/col5 are intentionally absent.
        // (The fixture previously listed "col4" twice; the duplicate was
        // redundant and has been removed.)
        let included_columns = vec![Arc::from("col3"), Arc::from("col4")];

        test_case(
            "ALTER TABLE my_table ALTER COLUMN col1 INT",
            &included_columns,
            true,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN col2",
            &included_columns,
            true,
        );
        test_case(
            "ALTER TABLE my_table ALTER COLUMN col3 INT",
            &included_columns,
            false,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN col4 INT",
            &included_columns,
            false,
        );
        test_case(
            "CREATE INDEX idx_my_index ON my_table(col1)",
            &included_columns,
            true,
        );
        test_case(
            "DROP INDEX idx_my_index ON my_table",
            &included_columns,
            true,
        );
        test_case(
            "ALTER TABLE my_table ADD COLUMN col5 INT",
            &included_columns,
            true,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN col1, col2",
            &included_columns,
            true,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN col3, col2",
            &included_columns,
            false,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN col3, col4",
            &included_columns,
            false,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN IF EXISTS col1, col2",
            &included_columns,
            true,
        );
        test_case(
            "ALTER TABLE my_table DROP CONSTRAINT constraint_name",
            &included_columns,
            true,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN col1,col3",
            &included_columns,
            false,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN col1,col2",
            &included_columns,
            true,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN col1 ,col2",
            &included_columns,
            true,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN col1 , col2",
            &included_columns,
            true,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN col1 , col3",
            &included_columns,
            false,
        );
        test_case(
            "ALTER TABLE my_table DROP COLUMN col1 , COLUMN col3",
            &included_columns,
            false,
        );
    }
}