1use anyhow::Context;
13use chrono::NaiveDateTime;
14use futures::Stream;
15use itertools::Itertools;
16use mz_ore::cast::CastFrom;
17use mz_ore::retry::RetryResult;
18use smallvec::SmallVec;
19use std::collections::BTreeMap;
20use std::fmt;
21use std::sync::Arc;
22use std::time::Duration;
23use tiberius::numeric::Numeric;
24
25use crate::cdc::{Lsn, RowFilterOption};
26use crate::desc::{SqlServerCaptureInstanceRaw, SqlServerColumnRaw, SqlServerTableRaw};
27use crate::{Client, SqlServerError};
28
29pub async fn get_min_lsn(
33 client: &mut Client,
34 capture_instance: &str,
35) -> Result<Lsn, SqlServerError> {
36 static MIN_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_min_lsn(@P1);";
37 let result = client.query(MIN_LSN_QUERY, &[&capture_instance]).await?;
38
39 mz_ore::soft_assert_eq_or_log!(result.len(), 1);
40 parse_lsn(&result[..1])
41}
42pub async fn get_min_lsn_retry(
47 client: &mut Client,
48 capture_instance: &str,
49 max_retry_duration: Duration,
50) -> Result<Lsn, SqlServerError> {
51 let (_client, lsn_result) = mz_ore::retry::Retry::default()
52 .max_duration(max_retry_duration)
53 .retry_async_with_state(client, |_, client| async {
54 let result = crate::inspect::get_min_lsn(client, capture_instance).await;
55 (client, map_null_lsn_to_retry(result))
56 })
57 .await;
58 let Ok(lsn) = lsn_result else {
59 tracing::warn!("database did not report a minimum LSN in time");
60 return lsn_result;
61 };
62 Ok(lsn)
63}
64
65pub async fn get_max_lsn(client: &mut Client) -> Result<Lsn, SqlServerError> {
73 static MAX_LSN_QUERY: &str = "SELECT sys.fn_cdc_get_max_lsn();";
74 let result = client.simple_query(MAX_LSN_QUERY).await?;
75
76 mz_ore::soft_assert_eq_or_log!(result.len(), 1);
77 parse_lsn(&result[..1])
78}
79
80pub async fn get_min_lsns(
90 client: &mut Client,
91 capture_instances: impl IntoIterator<Item = &str>,
92) -> Result<BTreeMap<Arc<str>, Lsn>, SqlServerError> {
93 let capture_instances: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
94 let values: Vec<_> = capture_instances
95 .iter()
96 .map(|ci| {
97 let ci: &dyn tiberius::ToSql = ci;
98 ci
99 })
100 .collect();
101 let args = (0..capture_instances.len())
102 .map(|i| format!("@P{}", i + 1))
103 .collect::<Vec<_>>()
104 .join(",");
105 let stmt = format!(
106 "SELECT capture_instance, start_lsn FROM cdc.change_tables WHERE capture_instance IN ({args});"
107 );
108 let result = client.query(stmt, &values).await?;
109 let min_lsns = result
110 .into_iter()
111 .map(|row| {
112 let capture_instance: Arc<str> = row
113 .try_get::<&str, _>("capture_instance")?
114 .ok_or_else(|| {
115 SqlServerError::ProgrammingError(
116 "missing column 'capture_instance'".to_string(),
117 )
118 })?
119 .into();
120 let start_lsn: &[u8] = row.try_get("start_lsn")?.ok_or_else(|| {
121 SqlServerError::ProgrammingError("missing column 'start_lsn'".to_string())
122 })?;
123 let min_lsn = Lsn::try_from(start_lsn).map_err(|msg| SqlServerError::InvalidData {
124 column_name: "lsn".to_string(),
125 error: format!("Error parsing LSN for {capture_instance}: {msg}"),
126 })?;
127 Ok::<_, SqlServerError>((capture_instance, min_lsn))
128 })
129 .collect::<Result<_, _>>()?;
130
131 Ok(min_lsns)
132}
133
134pub async fn get_max_lsn_retry(
143 client: &mut Client,
144 max_retry_duration: Duration,
145) -> Result<Lsn, SqlServerError> {
146 let (_client, lsn_result) = mz_ore::retry::Retry::default()
147 .max_duration(max_retry_duration)
148 .retry_async_with_state(client, |_, client| async {
149 let result = crate::inspect::get_max_lsn(client).await;
150 (client, map_null_lsn_to_retry(result))
151 })
152 .await;
153
154 let Ok(lsn) = lsn_result else {
155 tracing::warn!("database did not report a maximum LSN in time");
156 return lsn_result;
157 };
158 Ok(lsn)
159}
160
161fn map_null_lsn_to_retry<T>(result: Result<T, SqlServerError>) -> RetryResult<T, SqlServerError> {
162 match result {
163 Ok(val) => RetryResult::Ok(val),
164 Err(err @ SqlServerError::NullLsn) => RetryResult::RetryableErr(err),
165 Err(other) => RetryResult::FatalErr(other),
166 }
167}
168
169pub async fn increment_lsn(client: &mut Client, lsn: Lsn) -> Result<Lsn, SqlServerError> {
173 static INCREMENT_LSN_QUERY: &str = "SELECT sys.fn_cdc_increment_lsn(@P1);";
174 let result = client
175 .query(INCREMENT_LSN_QUERY, &[&lsn.as_bytes().as_slice()])
176 .await?;
177
178 mz_ore::soft_assert_eq_or_log!(result.len(), 1);
179 parse_lsn(&result[..1])
180}
181
182pub(crate) fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
186 match row {
187 [r] => {
188 let numeric_lsn = r
189 .try_get::<Numeric, _>(0)?
190 .ok_or_else(|| SqlServerError::NullLsn)?;
191 let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
192 column_name: "lsn".to_string(),
193 error: msg,
194 })?;
195 Ok(lsn)
196 }
197 other => Err(SqlServerError::InvalidData {
198 column_name: "lsn".to_string(),
199 error: format!("expected 1 column, got {other:?}"),
200 }),
201 }
202}
203
204fn parse_lsn(result: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
208 match result {
209 [row] => {
210 let val = row
211 .try_get::<&[u8], _>(0)?
212 .ok_or_else(|| SqlServerError::NullLsn)?;
213 if val.is_empty() {
214 Err(SqlServerError::NullLsn)
215 } else {
216 let lsn = Lsn::try_from(val).map_err(|msg| SqlServerError::InvalidData {
217 column_name: "lsn".to_string(),
218 error: msg,
219 })?;
220 Ok(lsn)
221 }
222 }
223 other => Err(SqlServerError::InvalidData {
224 column_name: "lsn".to_string(),
225 error: format!("expected 1 column, got {other:?}"),
226 }),
227 }
228}
229
230pub fn get_changes_asc(
237 client: &mut Client,
238 capture_instance: &str,
239 start_lsn: Lsn,
240 end_lsn: Lsn,
241 filter: RowFilterOption,
242) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send {
243 const START_LSN_COLUMN: &str = "__$start_lsn";
244 let query = format!(
245 "SELECT * FROM cdc.fn_cdc_get_all_changes_{capture_instance}(@P1, @P2, N'{filter}') ORDER BY {START_LSN_COLUMN} ASC;"
246 );
247 client.query_streaming(
248 query,
249 &[
250 &start_lsn.as_bytes().as_slice(),
251 &end_lsn.as_bytes().as_slice(),
252 ],
253 )
254}
255
256pub async fn cleanup_change_table(
270 client: &mut Client,
271 capture_instance: &str,
272 low_water_mark: &Lsn,
273 max_deletes: u32,
274) -> Result<(), SqlServerError> {
275 static GET_LSN_QUERY: &str =
276 "SELECT MAX(start_lsn) FROM cdc.lsn_time_mapping WHERE start_lsn <= @P1";
277 static CLEANUP_QUERY: &str = "
278DECLARE @mz_cleanup_status_bit BIT;
279SET @mz_cleanup_status_bit = 0;
280EXEC sys.sp_cdc_cleanup_change_table
281 @capture_instance = @P1,
282 @low_water_mark = @P2,
283 @threshold = @P3,
284 @fCleanupFailed = @mz_cleanup_status_bit OUTPUT;
285SELECT @mz_cleanup_status_bit;
286 ";
287
288 let max_deletes = i64::cast_from(max_deletes);
289
290 let result = client
294 .query(GET_LSN_QUERY, &[&low_water_mark.as_bytes().as_slice()])
295 .await?;
296 let low_water_mark_to_use = match &result[..] {
297 [row] => row
298 .try_get::<&[u8], _>(0)?
299 .ok_or_else(|| SqlServerError::InvalidData {
300 column_name: "mz_cleanup_status_bit".to_string(),
301 error: "expected a bool, found NULL".to_string(),
302 })?,
303 other => Err(SqlServerError::ProgrammingError(format!(
304 "expected one row for low water mark, found {other:?}"
305 )))?,
306 };
307
308 let result = client
311 .query(
312 CLEANUP_QUERY,
313 &[&capture_instance, &low_water_mark_to_use, &max_deletes],
314 )
315 .await;
316
317 let rows = match result {
318 Ok(rows) => rows,
319 Err(SqlServerError::SqlServer(e)) => {
320 let already_cleaned_up = e.code().map(|code| code == 22957).unwrap_or(false);
324
325 if already_cleaned_up {
326 return Ok(());
327 } else {
328 return Err(SqlServerError::SqlServer(e));
329 }
330 }
331 Err(other) => return Err(other),
332 };
333
334 match &rows[..] {
335 [row] => {
336 let failure =
337 row.try_get::<bool, _>(0)?
338 .ok_or_else(|| SqlServerError::InvalidData {
339 column_name: "mz_cleanup_status_bit".to_string(),
340 error: "expected a bool, found NULL".to_string(),
341 })?;
342
343 if failure {
344 Err(super::cdc::CdcError::CleanupFailed {
345 capture_instance: capture_instance.to_string(),
346 low_water_mark: *low_water_mark,
347 })?
348 } else {
349 Ok(())
350 }
351 }
352 other => Err(SqlServerError::ProgrammingError(format!(
353 "expected one status row, found {other:?}"
354 ))),
355 }
356}
357
/// Base query that describes the columns of every table referenced by `cdc.change_tables`.
///
/// Each result row carries the schema and table name, the capture instance and its creation
/// date, the column's name, type, nullability, max length, precision and scale, and the name
/// of the primary key constraint the column takes part in (NULL when there is none).
///
/// The text deliberately has no `WHERE` clause and no terminating `;`: callers in this file
/// append their own filter and/or terminator.
///
/// NOTE(review): the join on `information_schema.key_column_usage` is not restricted to
/// primary keys, so a column that takes part in several key constraints looks like it would
/// be returned once per constraint — confirm whether duplicates are possible/handled.
static GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY: &str = "
SELECT
    s.name as schema_name,
    t.name as table_name,
    ch.capture_instance as capture_instance,
    ch.create_date as capture_instance_create_date,
    c.name as col_name,
    ty.name as col_type,
    c.is_nullable as col_nullable,
    c.max_length as col_max_length,
    c.precision as col_precision,
    c.scale as col_scale,
    tc.constraint_name AS col_primary_key_constraint
FROM sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.columns c ON t.object_id = c.object_id
JOIN sys.types ty ON c.user_type_id = ty.user_type_id
JOIN cdc.change_tables ch ON t.object_id = ch.source_object_id
LEFT JOIN information_schema.key_column_usage kc
    ON kc.table_schema = s.name
    AND kc.table_name = t.name
    AND kc.column_name = c.name
LEFT JOIN information_schema.table_constraints tc
    ON tc.constraint_catalog = kc.constraint_catalog
    AND tc.constraint_schema = kc.constraint_schema
    AND tc.constraint_name = kc.constraint_name
    AND tc.table_schema = kc.table_schema
    AND tc.table_name = kc.table_name
    AND tc.constraint_type = 'PRIMARY KEY'
";
404
405pub async fn get_tables_for_capture_instance<'a>(
407 client: &mut Client,
408 capture_instances: impl IntoIterator<Item = &str>,
409) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
410 let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
413 if params.is_empty() {
415 return Ok(Vec::default());
416 }
417
418 #[allow(clippy::as_conversions)]
420 let params_dyn: SmallVec<[_; 1]> = params
421 .iter()
422 .map(|instance| instance as &dyn tiberius::ToSql)
423 .collect();
424 let param_indexes = params
425 .iter()
426 .enumerate()
427 .map(|(idx, _)| format!("@P{}", idx + 1))
429 .join(", ");
430
431 let table_for_capture_instance_query = format!(
432 "{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY} WHERE ch.capture_instance IN ({param_indexes});"
433 );
434
435 let result = client
436 .query(&table_for_capture_instance_query, ¶ms_dyn[..])
437 .await?;
438
439 let tables = deserialize_table_columns_to_raw_tables(&result)?;
440
441 Ok(tables)
442}
443
444pub async fn ensure_database_cdc_enabled(client: &mut Client) -> Result<(), SqlServerError> {
449 static DATABASE_CDC_ENABLED_QUERY: &str =
450 "SELECT is_cdc_enabled FROM sys.databases WHERE database_id = DB_ID();";
451 let result = client.simple_query(DATABASE_CDC_ENABLED_QUERY).await?;
452
453 check_system_result(&result, "database CDC".to_string(), true)?;
454 Ok(())
455}
456
457pub async fn get_latest_restore_history_id(
464 client: &mut Client,
465) -> Result<Option<i32>, SqlServerError> {
466 static LATEST_RESTORE_ID_QUERY: &str = "SELECT TOP 1 restore_history_id \
467 FROM msdb.dbo.restorehistory \
468 WHERE destination_database_name = DB_NAME() \
469 ORDER BY restore_history_id DESC;";
470 let result = client.simple_query(LATEST_RESTORE_ID_QUERY).await?;
471
472 match &result[..] {
473 [] => Ok(None),
474 [row] => Ok(row.try_get::<i32, _>(0)?),
475 other => Err(SqlServerError::InvariantViolated(format!(
476 "expected one row, got {other:?}"
477 ))),
478 }
479}
480
481pub async fn ensure_snapshot_isolation_enabled(client: &mut Client) -> Result<(), SqlServerError> {
486 static SNAPSHOT_ISOLATION_QUERY: &str =
487 "SELECT snapshot_isolation_state FROM sys.databases WHERE database_id = DB_ID();";
488 let result = client.simple_query(SNAPSHOT_ISOLATION_QUERY).await?;
489
490 check_system_result(&result, "snapshot isolation".to_string(), 1u8)?;
491 Ok(())
492}
493
494pub async fn get_tables(client: &mut Client) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
495 let result = client
496 .simple_query(&format!("{GET_COLUMNS_FOR_TABLES_WITH_CDC_QUERY};"))
497 .await?;
498
499 let tables = deserialize_table_columns_to_raw_tables(&result)?;
500
501 Ok(tables)
502}
503
504fn deserialize_table_columns_to_raw_tables(
505 rows: &[tiberius::Row],
506) -> Result<Vec<SqlServerTableRaw>, SqlServerError> {
507 fn get_value<'a, T: tiberius::FromSql<'a>>(
508 row: &'a tiberius::Row,
509 name: &'static str,
510 ) -> Result<T, SqlServerError> {
511 row.try_get(name)?
512 .ok_or(SqlServerError::MissingColumn(name))
513 }
514
515 let mut tables = BTreeMap::default();
517 for row in rows {
518 let schema_name: Arc<str> = get_value::<&str>(row, "schema_name")?.into();
519 let table_name: Arc<str> = get_value::<&str>(row, "table_name")?.into();
520 let capture_instance: Arc<str> = get_value::<&str>(row, "capture_instance")?.into();
521 let capture_instance_create_date: NaiveDateTime =
522 get_value::<NaiveDateTime>(row, "capture_instance_create_date")?;
523 let primary_key_constraint: Option<Arc<str>> = row
524 .try_get::<&str, _>("col_primary_key_constraint")?
525 .map(|v| v.into());
526
527 let column_name = get_value::<&str>(row, "col_name")?.into();
528 let column = SqlServerColumnRaw {
529 name: Arc::clone(&column_name),
530 data_type: get_value::<&str>(row, "col_type")?.into(),
531 is_nullable: get_value(row, "col_nullable")?,
532 primary_key_constraint,
533 max_length: get_value(row, "col_max_length")?,
534 precision: get_value(row, "col_precision")?,
535 scale: get_value(row, "col_scale")?,
536 };
537
538 let columns: &mut Vec<_> = tables
539 .entry((
540 Arc::clone(&schema_name),
541 Arc::clone(&table_name),
542 Arc::clone(&capture_instance),
543 capture_instance_create_date,
544 ))
545 .or_default();
546 columns.push(column);
547 }
548
549 let raw_tables = tables
551 .into_iter()
552 .map(
553 |((schema, name, capture_instance, capture_instance_create_date), columns)| {
554 SqlServerTableRaw {
555 schema_name: schema,
556 name,
557 capture_instance: Arc::new(SqlServerCaptureInstanceRaw {
558 name: capture_instance,
559 create_date: capture_instance_create_date.into(),
560 }),
561 columns: columns.into(),
562 }
563 },
564 )
565 .collect::<Vec<SqlServerTableRaw>>();
566
567 Ok(raw_tables)
568}
569
570pub fn snapshot(
572 client: &mut Client,
573 schema: &str,
574 table: &str,
575) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> {
576 let query = format!("SELECT * FROM {schema}.{table};");
577 client.query_streaming(query, &[])
578}
579
580pub async fn snapshot_size(
582 client: &mut Client,
583 schema: &str,
584 table: &str,
585) -> Result<usize, SqlServerError> {
586 let query = format!("SELECT COUNT(*) FROM {schema}.{table};");
587 let result = client.query(query, &[]).await?;
588
589 match &result[..] {
590 [row] => match row.try_get::<i32, _>(0)? {
591 Some(count @ 0..) => Ok(usize::try_from(count).expect("known to fit")),
592 Some(negative) => Err(SqlServerError::InvalidData {
593 column_name: "count".to_string(),
594 error: format!("found negative count: {negative}"),
595 }),
596 None => Err(SqlServerError::InvalidData {
597 column_name: "count".to_string(),
598 error: "expected a value found NULL".to_string(),
599 }),
600 },
601 other => Err(SqlServerError::InvariantViolated(format!(
602 "expected one row, got {other:?}"
603 ))),
604 }
605}
606
607fn check_system_result<'a, T>(
609 result: &'a SmallVec<[tiberius::Row; 1]>,
610 name: String,
611 expected: T,
612) -> Result<(), SqlServerError>
613where
614 T: tiberius::FromSql<'a> + Copy + fmt::Debug + fmt::Display + PartialEq,
615{
616 match &result[..] {
617 [row] => {
618 let result: Option<T> = row.try_get(0)?;
619 if result == Some(expected) {
620 Ok(())
621 } else {
622 Err(SqlServerError::InvalidSystemSetting {
623 name,
624 expected: expected.to_string(),
625 actual: format!("{result:?}"),
626 })
627 }
628 }
629 other => Err(SqlServerError::InvariantViolated(format!(
630 "expected 1 row, got {other:?}"
631 ))),
632 }
633}
634
635pub async fn validate_source_privileges<'a>(
640 client: &mut Client,
641 capture_instances: impl IntoIterator<Item = &str>,
642) -> Result<(), SqlServerError> {
643 let params: SmallVec<[_; 1]> = capture_instances.into_iter().collect();
644
645 if params.is_empty() {
646 return Ok(());
647 }
648
649 let params_dyn: SmallVec<[_; 1]> = params
650 .iter()
651 .map(|instance| {
652 let instance: &dyn tiberius::ToSql = instance;
653 instance
654 })
655 .collect();
656
657 let param_indexes = (1..params.len() + 1)
658 .map(|idx| format!("@P{}", idx))
659 .join(", ");
660
661 let capture_instance_query = format!(
663 "
664 SELECT
665 SCHEMA_NAME(o.schema_id) + '.' + o.name AS qualified_table_name,
666 ct.capture_instance AS capture_instance,
667 COALESCE(HAS_PERMS_BY_NAME(SCHEMA_NAME(o.schema_id) + '.' + o.name, 'OBJECT', 'SELECT'), 0) AS table_select,
668 COALESCE(HAS_PERMS_BY_NAME('cdc.' + ct.capture_instance + '_CT', 'OBJECT', 'SELECT'), 0) AS capture_table_select
669 FROM cdc.change_tables ct
670 JOIN sys.objects o ON o.object_id = ct.source_object_id
671 WHERE ct.capture_instance IN ({param_indexes});
672 "
673 );
674
675 let rows = client
676 .query(capture_instance_query, ¶ms_dyn[..])
677 .await?;
678
679 let mut capture_instances_without_perms = vec![];
680 let mut tables_without_perms = vec![];
681
682 for row in rows {
683 let table: &str = row
684 .try_get("qualified_table_name")
685 .context("getting table column")?
686 .ok_or_else(|| anyhow::anyhow!("no table column?"))?;
687
688 let capture_instance: &str = row
689 .try_get("capture_instance")
690 .context("getting capture_instance column")?
691 .ok_or_else(|| anyhow::anyhow!("no capture_instance column?"))?;
692
693 let permitted_table: i32 = row
694 .try_get("table_select")
695 .context("getting table_select column")?
696 .ok_or_else(|| anyhow::anyhow!("no table_select column?"))?;
697
698 let permitted_capture_instance: i32 = row
699 .try_get("capture_table_select")
700 .context("getting capture_table_select column")?
701 .ok_or_else(|| anyhow::anyhow!("no capture_table_select column?"))?;
702
703 if permitted_table == 0 {
704 tables_without_perms.push(table.to_string());
705 }
706
707 if permitted_capture_instance == 0 {
708 capture_instances_without_perms.push(capture_instance.to_string());
709 }
710 }
711
712 if !capture_instances_without_perms.is_empty() || !tables_without_perms.is_empty() {
713 return Err(SqlServerError::AuthorizationError {
714 tables: tables_without_perms.join(", "),
715 capture_instances: capture_instances_without_perms.join(", "),
716 });
717 }
718
719 Ok(())
720}