1use std::collections::{BTreeMap, BTreeSet};
13use std::rc::Rc;
14use std::sync::Arc;
15use std::time::Instant;
16
17use differential_dataflow::AsCollection;
18use differential_dataflow::containers::TimelyStack;
19use futures::StreamExt;
20use itertools::Itertools;
21use mz_ore::cast::CastFrom;
22use mz_ore::collections::HashMap;
23use mz_ore::future::InTask;
24use mz_repr::{Diff, GlobalId, Row, RowArena};
25use mz_sql_server_util::SqlServerCdcMetrics;
26use mz_sql_server_util::cdc::{CdcEvent, Lsn, Operation as CdcOperation};
27use mz_sql_server_util::desc::SqlServerRowDecoder;
28use mz_sql_server_util::inspect::{
29 ensure_database_cdc_enabled, ensure_sql_server_agent_running, get_latest_restore_history_id,
30};
31use mz_storage_types::dyncfgs::SQL_SERVER_SOURCE_VALIDATE_RESTORE_HISTORY;
32use mz_storage_types::errors::{DataflowError, DecodeError, DecodeErrorKind};
33use mz_storage_types::sources::SqlServerSourceConnection;
34use mz_storage_types::sources::sql_server::{MAX_LSN_WAIT, SNAPSHOT_PROGRESS_REPORT_INTERVAL};
35use mz_timely_util::builder_async::{
36 AsyncOutputHandle, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
37};
38use mz_timely_util::containers::stack::AccountedStackBuilder;
39use timely::container::CapacityContainerBuilder;
40use timely::dataflow::operators::vec::Map;
41use timely::dataflow::operators::{CapabilitySet, Concat};
42use timely::dataflow::{Scope, StreamVec};
43use timely::progress::{Antichain, Timestamp};
44
45use crate::metrics::source::sql_server::SqlServerSourceMetrics;
46use crate::source::RawSourceCreationConfig;
47use crate::source::sql_server::{
48 DefiniteError, ReplicationError, SourceOutputInfo, TransientError,
49};
50use crate::source::types::{SignaledFuture, SourceMessage, StackedCollection};
51
/// Logical task name used with `config.responsible_for` to elect the single
/// worker that performs the actual upstream reading.
///
/// A `const` (rather than `static`) is appropriate here: the value is an
/// immutable compile-time string with no address-identity requirement.
const REPL_READER: &str = "reader";
58
/// Renders a timely dataflow operator that replicates data from SQL Server
/// via CDC (change data capture).
///
/// The operator (run by a single elected worker) first snapshots every output
/// whose `resume_upper` is still the minimum LSN, then tails the CDC change
/// stream for all capture instances, "rewinding" changes that overlap the
/// snapshot so the emitted collection is consistent.
///
/// Returns the decoded data collection, a stream of replication errors, and a
/// button that shuts the operator down when dropped.
pub(crate) fn render<G: Scope<Timestamp = Lsn>>(
    scope: G,
    config: RawSourceCreationConfig,
    outputs: BTreeMap<GlobalId, SourceOutputInfo>,
    source: SqlServerSourceConnection,
    metrics: SqlServerSourceMetrics,
) -> (
    StackedCollection<G, (u64, Result<SourceMessage, DataflowError>)>,
    StreamVec<G, ReplicationError>,
    PressOnDropButton,
) {
    let op_name = format!("SqlServerReplicationReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope);

    // Primary output: decoded rows (or per-row decode errors), timestamped by LSN.
    let (data_output, data_stream) = builder.new_output::<AccountedStackBuilder<_>>();

    // Secondary output: definite errors that permanently halt the source.
    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
            let [
                data_cap_set,
                definite_error_cap_set,
            ]: &mut [_; 2] = caps.try_into().unwrap();

            let connection_config = source
                .connection
                .resolve_config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;
            let mut client = mz_sql_server_util::Client::connect(connection_config).await?;

            let worker_id = config.worker_id;

            // Per-output decoder, keyed by partition index.
            let mut decoder_map: BTreeMap<_, _> = BTreeMap::new();
            // Capture instances that still need an initial snapshot, with the
            // (partition index, initial LSN) pairs to snapshot for each.
            let mut capture_instance_to_snapshot: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Every capture instance and the partition indexes it feeds.
            let mut capture_instances: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Statistics handles per capture instance, for progress reporting.
            let mut export_statistics: BTreeMap<_, Vec<_>> = BTreeMap::new();
            // Column names each output decodes; consulted on schema changes.
            let mut included_columns: HashMap<u64, Vec<Arc<str>>> = HashMap::new();

            for (export_id, output) in outputs.iter() {
                let key = output.partition_index;
                if decoder_map.insert(key, Arc::clone(&output.decoder)).is_some() {
                    panic!("Multiple decoders for output index {}", output.partition_index);
                }
                let included_cols = output.decoder.included_column_names();
                included_columns.insert(output.partition_index, included_cols);

                capture_instances
                    .entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(output.partition_index);

                // A resume upper at the minimum LSN means this output has not
                // committed any data yet, so it still needs a snapshot.
                if *output.resume_upper == [Lsn::minimum()] {
                    capture_instance_to_snapshot
                        .entry(Arc::clone(&output.capture_instance))
                        .or_default()
                        .push((output.partition_index, output.initial_lsn));
                }
                export_statistics.entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(
                        config
                            .statistics
                            .get(export_id)
                            .expect("statistics have been intialized")
                            .clone(),
                    );
            }

            metrics.snapshot_table_count.set(u64::cast_from(capture_instance_to_snapshot.len()));
            // Reset snapshot progress counters before (re)snapshotting.
            if !capture_instance_to_snapshot.is_empty() {
                for stats in config.statistics.values() {
                    stats.set_snapshot_records_known(0);
                    stats.set_snapshot_records_staged(0);
                }
            }
            // Only one worker reads from upstream; all others return early.
            if !config.responsible_for(REPL_READER) {
                return Ok::<_, TransientError>(());
            }

            let snapshot_instances = capture_instance_to_snapshot
                .keys()
                .map(|i| i.as_ref());

            let snapshot_tables =
                mz_sql_server_util::inspect::get_tables_for_capture_instance(
                    &mut client,
                    snapshot_instances,
                )
                .await?;

            // Detect whether the upstream database was restored from a backup
            // since this source was created; if so (and validation is enabled),
            // previously emitted data can no longer be trusted.
            let current_restore_history_id = get_latest_restore_history_id(&mut client).await?;
            if current_restore_history_id != source.extras.restore_history_id {
                if SQL_SERVER_SOURCE_VALIDATE_RESTORE_HISTORY.get(config.config.config_set()) {
                    let definite_error = DefiniteError::RestoreHistoryChanged(
                        source.extras.restore_history_id.clone(),
                        current_restore_history_id.clone()
                    );
                    tracing::warn!(?definite_error, "Restore detected, exiting");

                    return_definite_error(
                        definite_error,
                        capture_instances.values().flat_map(|indexes| indexes.iter().copied()),
                        data_output,
                        data_cap_set,
                        definite_error_handle,
                        definite_error_cap_set,
                    ).await;
                    return Ok(());
                } else {
                    tracing::warn!(
                        "Restore history mismatch ignored: expected={expected:?} actual={actual:?}",
                        expected=source.extras.restore_history_id,
                        actual=current_restore_history_id
                    );
                }
            }

            ensure_database_cdc_enabled(&mut client).await?;
            ensure_sql_server_agent_running(&mut client).await?;

            // Report the upstream size of each table we are about to snapshot,
            // plus how long the size calculation itself took.
            for table in &snapshot_tables {
                let qualified_table_name = format!("{schema_name}.{table_name}",
                    schema_name = &table.schema_name,
                    table_name = &table.name);
                let size_calc_start = Instant::now();
                let table_total =
                    mz_sql_server_util::inspect::snapshot_size(
                        &mut client,
                        &table.schema_name,
                        &table.name,
                    )
                    .await?;
                metrics.set_snapshot_table_size_latency(
                    &qualified_table_name,
                    size_calc_start.elapsed().as_secs_f64()
                );
                for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                    export_stat.set_snapshot_records_known(u64::cast_from(table_total));
                    export_stat.set_snapshot_records_staged(0);
                }
            }
            let cdc_metrics = PrometheusSqlServerCdcMetrics{inner: &metrics};
            let mut cdc_handle = client
                .cdc(capture_instances.keys().cloned(), cdc_metrics)
                .max_lsn_wait(MAX_LSN_WAIT.get(config.config.config_set()));

            // Snapshot each table, recording the LSN at which each snapshot was
            // taken. Changes at or before that LSN are already reflected in the
            // snapshot and must be "rewound" during replication below.
            let snapshot_lsns: BTreeMap<Arc<str>, Lsn> = {
                cdc_handle.wait_for_ready().await?;

                tracing::info!(%config.worker_id, "timely-{worker_id} upstream is ready");

                let report_interval =
                    SNAPSHOT_PROGRESS_REPORT_INTERVAL.handle(config.config.config_set());
                let mut last_report = Instant::now();
                let mut snapshot_lsns = BTreeMap::new();
                let arena = RowArena::default();

                for table in snapshot_tables {
                    let (snapshot_lsn, snapshot) = cdc_handle
                        .snapshot(&table, config.worker_id, config.id)
                        .await?;

                    tracing::info!(
                        %config.id,
                        %table.name,
                        %table.schema_name,
                        %snapshot_lsn,
                        "timely-{worker_id} snapshot start",
                    );

                    let mut snapshot = std::pin::pin!(snapshot);

                    snapshot_lsns.insert(
                        Arc::clone(&table.capture_instance.name),
                        snapshot_lsn,
                    );

                    let ci_name = &table.capture_instance.name;
                    let partition_indexes = capture_instance_to_snapshot
                        .get(ci_name)
                        .unwrap_or_else(|| {
                            panic!(
                                "no snapshot outputs in known capture \
                                instances [{}] for capture instance: \
                                '{}'",
                                capture_instance_to_snapshot
                                    .keys()
                                    .join(","),
                                ci_name,
                            );
                        });

                    // Emit every snapshot row at the minimum LSN so snapshot
                    // data sorts before all replicated changes.
                    let mut snapshot_staged = 0;
                    while let Some(result) = snapshot.next().await {
                        let sql_server_row =
                            result.map_err(TransientError::from)?;

                        // Periodically report how many records we've staged.
                        if last_report.elapsed() > report_interval.get() {
                            last_report = Instant::now();
                            let stats =
                                export_statistics.get(ci_name).unwrap();
                            for export_stat in stats {
                                export_stat.set_snapshot_records_staged(
                                    snapshot_staged,
                                );
                            }
                        }

                        for (partition_idx, _) in partition_indexes {
                            let mut mz_row = Row::default();

                            let decoder = decoder_map
                                .get(partition_idx)
                                .expect("decoder for output");
                            let message = decode(
                                decoder,
                                &sql_server_row,
                                &mut mz_row,
                                &arena,
                                None,
                            );
                            data_output
                                .give_fueled(
                                    &data_cap_set[0],
                                    ((*partition_idx, message), Lsn::minimum(), Diff::ONE),
                                )
                                .await;
                        }
                        snapshot_staged += 1;
                    }

                    tracing::info!(
                        %config.id,
                        %table.name,
                        %table.schema_name,
                        %snapshot_lsn,
                        "timely-{worker_id} snapshot complete",
                    );
                    metrics.snapshot_table_count.dec();
                    // Finalize stats with the actual number of rows read; the
                    // "known" count is overwritten since the earlier estimate
                    // may not match what the snapshot actually produced.
                    let stats = export_statistics.get(ci_name).unwrap();
                    for export_stat in stats {
                        export_stat.set_snapshot_records_staged(snapshot_staged);
                        export_stat.set_snapshot_records_known(snapshot_staged);
                    }
                }

                snapshot_lsns
            };

            // For each snapshotted output, remember (initial_lsn, snapshot_lsn).
            // Replicated changes with LSNs in that range overlap the snapshot
            // and need compensating updates emitted at the minimum LSN.
            let mut rewinds: BTreeMap<_, _> = capture_instance_to_snapshot
                .iter()
                .flat_map(|(capture_instance, export_ids)|{
                    let snapshot_lsn = snapshot_lsns.get(capture_instance).expect("snapshot lsn must be collected for capture instance");
                    export_ids
                        .iter()
                        .map(|(idx, initial_lsn)| (*idx, (*initial_lsn, *snapshot_lsn)))
                }).collect();

            // Sanity check: a snapshot can never be taken before the initial LSN.
            for (initial_lsn, snapshot_lsn) in rewinds.values() {
                assert!(
                    initial_lsn <= snapshot_lsn,
                    "initial_lsn={initial_lsn} snapshot_lsn={snapshot_lsn}"
                );
            }

            tracing::debug!("rewinds to process: {rewinds:?}");

            capture_instance_to_snapshot.clear();

            // Determine where to start replication for each capture instance:
            // the smallest resume point among its outputs.
            let mut resume_lsns = BTreeMap::new();
            for src_info in outputs.values() {
                let resume_lsn = match src_info.resume_upper.as_option() {
                    Some(lsn) if *lsn != Lsn::minimum() => *lsn,
                    // Never-committed outputs resume just past their initial LSN.
                    Some(_) => src_info.initial_lsn.increment(),
                    None => panic!("resume_upper has at least one value"),
                };
                resume_lsns.entry(Arc::clone(&src_info.capture_instance))
                    .and_modify(|existing| *existing = std::cmp::min(*existing, resume_lsn))
                    .or_insert(resume_lsn);
            }

            tracing::info!(%config.id, ?resume_lsns, "timely-{} replication starting", config.worker_id);
            for instance in capture_instances.keys() {
                let resume_lsn = resume_lsns
                    .get(instance)
                    .expect("resume_lsn exists for capture instance");
                cdc_handle = cdc_handle.start_lsn(instance, *resume_lsn);
            }

            let cdc_stream = cdc_handle
                .poll_interval(config.timestamp_interval)
                .into_stream();
            let mut cdc_stream = std::pin::pin!(cdc_stream);

            // Outputs that hit an incompatible schema change; their subsequent
            // changes are counted as ignored rather than emitted.
            let mut errored_partitions = BTreeSet::new();

            // Ensures "rewinds complete" is only logged once.
            let mut log_rewinds_complete = true;

            // An update arrives as separate UpdateOld/UpdateNew operations;
            // the first half observed is parked here, keyed by (lsn, seqval),
            // until its counterpart arrives (see `handle_data_event`).
            let mut deferred_updates = BTreeMap::new();

            while let Some(event) = cdc_stream.next().await {
                let event = event.map_err(TransientError::from)?;
                tracing::trace!(?config.id, ?event, "got replication event");

                tracing::trace!("deferred_updates = {deferred_updates:?}");
                match event {
                    CdcEvent::Progress { next_lsn } => {
                        tracing::debug!(?config.id, ?next_lsn, "got a closed lsn");
                        // A rewind is finished once progress moves past its
                        // snapshot LSN.
                        rewinds.retain(|_, (_, snapshot_lsn)| next_lsn <= *snapshot_lsn);
                        if rewinds.is_empty() {
                            if log_rewinds_complete {
                                tracing::debug!("rewinds complete");
                                log_rewinds_complete = false;
                            }
                            // Only downgrade the capability once all rewinds are
                            // done, since rewound updates are emitted at the
                            // minimum LSN.
                            data_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        } else {
                            tracing::debug!("rewinds remaining: {:?}", rewinds);
                        }

                        // Invariant: no half of an update pair may remain
                        // deferred once its LSN has been closed.
                        if let Some(((deferred_lsn, _seqval), _row)) =
                            deferred_updates.first_key_value()
                            && *deferred_lsn < next_lsn
                        {
                            panic!(
                                "deferred update lsn {deferred_lsn} \
                                < progress lsn {next_lsn}: {:?}",
                                deferred_updates.keys()
                            );
                        }

                    }
                    CdcEvent::Data {
                        capture_instance,
                        lsn,
                        changes,
                    } => {
                        // Receiving data for an unknown capture instance is a
                        // programming error and permanently fails the source.
                        let Some(partition_indexes) =
                            capture_instances.get(&capture_instance)
                        else {
                            let definite_error =
                                DefiniteError::ProgrammingError(format!(
                                    "capture instance didn't exist: \
                                    '{capture_instance}'"
                                ));
                            return_definite_error(
                                definite_error,
                                capture_instances
                                    .values()
                                    .flat_map(|indexes| {
                                        indexes.iter().copied()
                                    }),
                                data_output,
                                data_cap_set,
                                definite_error_handle,
                                definite_error_cap_set,
                            )
                            .await;
                            return Ok(());
                        };

                        // Split outputs into healthy ones (which receive the
                        // changes) and errored ones (whose data is ignored).
                        let (valid_partitions, err_partitions) =
                            partition_indexes
                                .iter()
                                .partition::<Vec<u64>, _>(
                                    |&partition_idx| {
                                        !errored_partitions
                                            .contains(partition_idx)
                                    },
                                );

                        if err_partitions.len() > 0 {
                            metrics.ignored.inc_by(u64::cast_from(changes.len()));
                        }

                        handle_data_event(
                            changes,
                            &valid_partitions,
                            &decoder_map,
                            lsn,
                            &rewinds,
                            &data_output,
                            data_cap_set,
                            &metrics,
                            &mut deferred_updates,
                        ).await?
                    },
                    CdcEvent::SchemaUpdate {
                        capture_instance,
                        table,
                        ddl_event,
                    } => {
                        let Some(partition_indexes) =
                            capture_instances.get(&capture_instance)
                        else {
                            let definite_error =
                                DefiniteError::ProgrammingError(format!(
                                    "capture instance didn't exist: \
                                    '{capture_instance}'"
                                ));
                            return_definite_error(
                                definite_error,
                                capture_instances
                                    .values()
                                    .flat_map(|indexes| {
                                        indexes.iter().copied()
                                    }),
                                data_output,
                                data_cap_set,
                                definite_error_handle,
                                definite_error_cap_set,
                            )
                            .await;
                            return Ok(());
                        };
                        // A DDL event that is incompatible with the columns an
                        // output decodes is a definite error for that output
                        // only; other outputs keep replicating.
                        let error =
                            DefiniteError::IncompatibleSchemaChange(
                                capture_instance.to_string(),
                                table.to_string(),
                            );
                        for partition_idx in partition_indexes {
                            let cols = included_columns
                                .get(partition_idx)
                                .unwrap_or_else(|| {
                                    panic!(
                                        "Partition index didn't \
                                        exist: '{partition_idx}'"
                                    )
                                });
                            if !errored_partitions
                                .contains(partition_idx)
                                && !ddl_event.is_compatible(cols)
                            {
                                let msg = Err(
                                    error.clone().into(),
                                );
                                data_output
                                    .give_fueled(
                                        &data_cap_set[0],
                                        ((*partition_idx, msg), ddl_event.lsn, Diff::ONE),
                                    )
                                    .await;
                                errored_partitions.insert(*partition_idx);
                            }
                        }
                    }
                };
            }
            // The CDC stream is expected to run forever; EOF is transient and
            // the operator will be restarted.
            Err(TransientError::ReplicationEOF)
        }))
    });

    let error_stream = definite_errors.concat(transient_errors.map(ReplicationError::Transient));

    (
        data_stream.as_collection(),
        error_stream,
        button.press_on_drop(),
    )
}
611
/// Decodes a batch of CDC operations committed at `commit_lsn` and emits the
/// resulting updates to `data_output` for every output in `partition_indexes`.
///
/// SQL Server reports an update as a pair of `UpdateOld`/`UpdateNew`
/// operations which may not arrive together; the first half observed is
/// parked in `deferred_updates` (keyed by `(commit_lsn, seqval)`) and the
/// pair is processed when its counterpart shows up.
///
/// For outputs that were snapshotted, `rewinds` maps partition index to
/// `(initial_lsn, snapshot_lsn)`: changes at or before `initial_lsn` are
/// skipped entirely, and changes at or before `snapshot_lsn` additionally get
/// a compensating (inverse) update emitted at the minimum LSN, since their
/// effect is already reflected in the snapshot data.
async fn handle_data_event(
    changes: Vec<CdcOperation>,
    partition_indexes: &[u64],
    decoder_map: &BTreeMap<u64, Arc<SqlServerRowDecoder>>,
    commit_lsn: Lsn,
    rewinds: &BTreeMap<u64, (Lsn, Lsn)>,
    data_output: &StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_cap_set: &CapabilitySet<Lsn>,
    metrics: &SqlServerSourceMetrics,
    deferred_updates: &mut BTreeMap<(Lsn, Lsn), CdcOperation>,
) -> Result<(), TransientError> {
    let mut mz_row = Row::default();
    let arena = RowArena::default();

    for change in changes {
        // Holds the previously-parked half of an update pair, if this
        // operation completes one.
        let mut deferred_update: Option<_> = None;
        let (sql_server_row, diff): (_, _) = match change {
            CdcOperation::Insert(sql_server_row) => {
                metrics.inserts.inc();
                (sql_server_row, Diff::ONE)
            }
            CdcOperation::Delete(sql_server_row) => {
                metrics.deletes.inc();
                (sql_server_row, Diff::MINUS_ONE)
            }

            CdcOperation::UpdateNew(seqval, sql_server_row) => {
                metrics.updates.inc();
                // If the matching UpdateOld was already seen, process the pair
                // now; otherwise park this half and move on.
                deferred_update = deferred_updates.remove(&(commit_lsn, seqval));
                if deferred_update.is_none() {
                    tracing::trace!("capture deferred UpdateNew ({commit_lsn}, {seqval})");
                    deferred_updates.insert(
                        (commit_lsn, seqval),
                        CdcOperation::UpdateNew(seqval, sql_server_row),
                    );
                    continue;
                }
                // Placeholder diff: the real retraction/insertion diffs for an
                // update pair are produced in the emission logic below.
                (sql_server_row, Diff::ZERO)
            }
            CdcOperation::UpdateOld(seqval, sql_server_row) => {
                deferred_update = deferred_updates.remove(&(commit_lsn, seqval));
                if deferred_update.is_none() {
                    tracing::trace!("capture deferred UpdateOld ({commit_lsn}, {seqval})");
                    deferred_updates.insert(
                        (commit_lsn, seqval),
                        CdcOperation::UpdateOld(seqval, sql_server_row),
                    );
                    continue;
                }
                (sql_server_row, Diff::ZERO)
            }
        };

        for partition_idx in partition_indexes {
            let decoder = decoder_map.get(partition_idx).unwrap();

            // Changes at or before this output's initial LSN are already
            // accounted for; skip them entirely.
            let rewind = rewinds.get(partition_idx);
            if rewind.is_some_and(|(initial_lsn, _)| commit_lsn <= *initial_lsn) {
                continue;
            }

            let (message, diff) = if let Some(ref deferred_update) = deferred_update {
                // Pair the old and new row images, regardless of which half of
                // the update arrived first.
                let (old_row, new_row) = match deferred_update {
                    CdcOperation::UpdateOld(_seqval, row) => (row, &sql_server_row),
                    CdcOperation::UpdateNew(_seqval, row) => (&sql_server_row, row),
                    CdcOperation::Insert(_) | CdcOperation::Delete(_) => unreachable!(),
                };

                let update_old = decode(decoder, old_row, &mut mz_row, &arena, Some(new_row));
                // Rewind: assert the old image at the minimum LSN; together
                // with the negated new image emitted below, this cancels the
                // update's effect for data already present in the snapshot.
                if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                    data_output
                        .give_fueled(
                            &data_cap_set[0],
                            (
                                (*partition_idx, update_old.clone()),
                                Lsn::minimum(),
                                Diff::ONE,
                            ),
                        )
                        .await;
                }
                // Retract the old image at the commit LSN.
                data_output
                    .give_fueled(
                        &data_cap_set[0],
                        ((*partition_idx, update_old), commit_lsn, Diff::MINUS_ONE),
                    )
                    .await;

                // The new image is inserted via the shared emission path below.
                (
                    decode(decoder, new_row, &mut mz_row, &arena, None),
                    Diff::ONE,
                )
            } else {
                (
                    decode(decoder, &sql_server_row, &mut mz_row, &arena, None),
                    diff,
                )
            };
            // Diff::ZERO only ever appears as the update-pair placeholder, and
            // that path always replaces it with Diff::ONE above.
            assert_ne!(Diff::ZERO, diff);
            // Rewind: emit the inverse update at the minimum LSN for changes
            // whose effect the snapshot already contains.
            if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                data_output
                    .give_fueled(
                        &data_cap_set[0],
                        ((*partition_idx, message.clone()), Lsn::minimum(), -diff),
                    )
                    .await;
            }
            data_output
                .give_fueled(
                    &data_cap_set[0],
                    ((*partition_idx, message), commit_lsn, diff),
                )
                .await;
        }
    }
    Ok(())
}
740
741type StackedAsyncOutputHandle<T, D> = AsyncOutputHandle<
742 T,
743 AccountedStackBuilder<CapacityContainerBuilder<TimelyStack<(D, T, Diff)>>>,
744>;
745
746fn decode(
749 decoder: &SqlServerRowDecoder,
750 row: &tiberius::Row,
751 mz_row: &mut Row,
752 arena: &RowArena,
753 new_row: Option<&tiberius::Row>,
754) -> Result<SourceMessage, DataflowError> {
755 match decoder.decode(row, mz_row, arena, new_row) {
756 Ok(()) => Ok(SourceMessage {
757 key: Row::default(),
758 value: mz_row.clone(),
759 metadata: Row::default(),
760 }),
761 Err(e) => {
762 let kind = DecodeErrorKind::Text(e.to_string().into());
763 let raw = format!("{row:?}");
765 Err(DataflowError::DecodeError(Box::new(DecodeError {
766 kind,
767 raw: raw.as_bytes().to_vec(),
768 })))
769 }
770 }
771}
772
773async fn return_definite_error(
775 err: DefiniteError,
776 outputs: impl Iterator<Item = u64>,
777 data_handle: StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
778 data_capset: &CapabilitySet<Lsn>,
779 errs_handle: AsyncOutputHandle<Lsn, CapacityContainerBuilder<Vec<ReplicationError>>>,
780 errs_capset: &CapabilitySet<Lsn>,
781) {
782 for output_idx in outputs {
783 let update = (
784 (output_idx, Err(err.clone().into())),
785 Lsn {
789 vlf_id: u32::MAX,
790 block_id: u32::MAX,
791 record_id: u16::MAX,
792 },
793 Diff::ONE,
794 );
795 data_handle.give_fueled(&data_capset[0], update).await;
796 }
797 errs_handle.give(
798 &errs_capset[0],
799 ReplicationError::DefiniteError(Rc::new(err)),
800 );
801}
802
803struct PrometheusSqlServerCdcMetrics<'a> {
805 inner: &'a SqlServerSourceMetrics,
806}
807
808impl<'a> SqlServerCdcMetrics for PrometheusSqlServerCdcMetrics<'a> {
809 fn snapshot_table_lock_start(&self, table_name: &str) {
810 self.inner.update_snapshot_table_lock_count(table_name, 1);
811 }
812
813 fn snapshot_table_lock_end(&self, table_name: &str) {
814 self.inner.update_snapshot_table_lock_count(table_name, -1);
815 }
816}