1use std::collections::{BTreeMap, BTreeSet};
13use std::convert::Infallible;
14use std::rc::Rc;
15use std::sync::Arc;
16use std::time::Instant;
17
18use differential_dataflow::AsCollection;
19use differential_dataflow::containers::TimelyStack;
20use futures::StreamExt;
21use itertools::Itertools;
22use mz_ore::cast::CastFrom;
23use mz_ore::collections::HashMap;
24use mz_ore::future::InTask;
25use mz_repr::{Diff, GlobalId, Row, RowArena};
26use mz_sql_server_util::SqlServerCdcMetrics;
27use mz_sql_server_util::cdc::{CdcEvent, Lsn, Operation as CdcOperation};
28use mz_sql_server_util::desc::SqlServerRowDecoder;
29use mz_sql_server_util::inspect::{
30 ensure_database_cdc_enabled, ensure_sql_server_agent_running, get_latest_restore_history_id,
31};
32use mz_storage_types::dyncfgs::SQL_SERVER_SOURCE_VALIDATE_RESTORE_HISTORY;
33use mz_storage_types::errors::{DataflowError, DecodeError, DecodeErrorKind};
34use mz_storage_types::sources::SqlServerSourceConnection;
35use mz_storage_types::sources::sql_server::{
36 CDC_POLL_INTERVAL, MAX_LSN_WAIT, SNAPSHOT_PROGRESS_REPORT_INTERVAL,
37};
38use mz_timely_util::builder_async::{
39 AsyncOutputHandle, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
40};
41use mz_timely_util::containers::stack::AccountedStackBuilder;
42use timely::container::CapacityContainerBuilder;
43use timely::dataflow::operators::{CapabilitySet, Concat, Map};
44use timely::dataflow::{Scope, Stream as TimelyStream};
45use timely::progress::{Antichain, Timestamp};
46
47use crate::metrics::source::sql_server::SqlServerSourceMetrics;
48use crate::source::RawSourceCreationConfig;
49use crate::source::sql_server::{
50 DefiniteError, ReplicationError, SourceOutputInfo, TransientError,
51};
52use crate::source::types::{SignaledFuture, SourceMessage, StackedCollection};
53
/// Name passed to `config.responsible_for` to elect the single worker that
/// runs the replication reader; all other workers exit early.
static REPL_READER: &str = "reader";
60
/// Renders the SQL Server replication dataflow operator.
///
/// Returns:
/// * a collection of decoded source messages, keyed by output partition index,
/// * a data-less stream used only to advance the upper frontier,
/// * a stream of replication errors (definite and transient),
/// * a button that shuts the operator down when dropped.
pub(crate) fn render<G: Scope<Timestamp = Lsn>>(
    scope: G,
    config: RawSourceCreationConfig,
    outputs: BTreeMap<GlobalId, SourceOutputInfo>,
    source: SqlServerSourceConnection,
    metrics: SqlServerSourceMetrics,
) -> (
    StackedCollection<G, (u64, Result<SourceMessage, DataflowError>)>,
    TimelyStream<G, Infallible>,
    TimelyStream<G, ReplicationError>,
    PressOnDropButton,
) {
    let op_name = format!("SqlServerReplicationReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope);

    // Data output plus a second output whose only job is carrying the upper
    // frontier; the upper output never produces data.
    let (data_output, data_stream) = builder.new_output::<AccountedStackBuilder<_>>();
    let (_upper_output, upper_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

    // Output for definite (permanent, non-retryable) errors.
    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
            let [
                data_cap_set,
                upper_cap_set,
                definite_error_cap_set,
            ]: &mut [_; 3] = caps.try_into().unwrap();

            // Resolve secrets and connect to the upstream SQL Server instance.
            let connection_config = source
                .connection
                .resolve_config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;
            let mut client = mz_sql_server_util::Client::connect(connection_config).await?;

            let worker_id = config.worker_id;

            // Per-output bookkeeping:
            // * decoder_map: partition index -> row decoder.
            // * capture_instance_to_snapshot: instances whose outputs still need
            //   an initial snapshot, with each output's (partition, initial_lsn).
            // * capture_instances: instance -> all partition indexes it feeds.
            // * export_statistics: instance -> statistics handles of its exports.
            // * included_columns: partition index -> column names the decoder emits.
            let mut decoder_map: BTreeMap<_, _> = BTreeMap::new();
            let mut capture_instance_to_snapshot: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            let mut capture_instances: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            let mut export_statistics: BTreeMap<_, Vec<_>> = BTreeMap::new();
            let mut included_columns: HashMap<u64, Vec<Arc<str>>> = HashMap::new();

            for (export_id, output) in outputs.iter() {
                let key = output.partition_index;
                if decoder_map.insert(key, Arc::clone(&output.decoder)).is_some() {
                    panic!("Multiple decoders for output index {}", output.partition_index);
                }
                let included_cols = output.decoder.included_column_names();
                included_columns.insert(output.partition_index, included_cols);

                capture_instances
                    .entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(output.partition_index);

                // A resume upper still at the minimum LSN means this output has
                // made no progress yet, so it needs an initial snapshot.
                if *output.resume_upper == [Lsn::minimum()] {
                    capture_instance_to_snapshot
                        .entry(Arc::clone(&output.capture_instance))
                        .or_default()
                        .push((output.partition_index, output.initial_lsn));
                }
                export_statistics.entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(
                        config
                            .statistics
                            .get(export_id)
                            .expect("statistics have been intialized")
                            .clone(),
                    );
            }

            metrics.snapshot_table_count.set(u64::cast_from(capture_instance_to_snapshot.len()));
            // Zero the snapshot statistics up front so they don't report stale
            // values while the snapshot is being set up.
            if !capture_instance_to_snapshot.is_empty() {
                for stats in config.statistics.values() {
                    stats.set_snapshot_records_known(0);
                    stats.set_snapshot_records_staged(0);
                }
            }
            // Replication is read by exactly one worker; the rest return early.
            if !config.responsible_for(REPL_READER) {
                return Ok::<_, TransientError>(());
            }

            let snapshot_instances = capture_instance_to_snapshot
                .keys()
                .map(|i| i.as_ref());

            let snapshot_tables =
                mz_sql_server_util::inspect::get_tables_for_capture_instance(
                    &mut client,
                    snapshot_instances,
                )
                .await?;

            // If the upstream database was restored from a backup since this
            // source was created, previously durable LSNs may no longer be
            // meaningful; depending on the dyncfg this is either a definite
            // error or merely logged.
            let current_restore_history_id = get_latest_restore_history_id(&mut client).await?;
            if current_restore_history_id != source.extras.restore_history_id {
                if SQL_SERVER_SOURCE_VALIDATE_RESTORE_HISTORY.get(config.config.config_set()) {
                    let definite_error = DefiniteError::RestoreHistoryChanged(
                        source.extras.restore_history_id.clone(),
                        current_restore_history_id.clone()
                    );
                    tracing::warn!(?definite_error, "Restore detected, exiting");

                    return_definite_error(
                        definite_error,
                        capture_instances.values().flat_map(|indexes| indexes.iter().copied()),
                        data_output,
                        data_cap_set,
                        definite_error_handle,
                        definite_error_cap_set,
                    ).await;
                    return Ok(());
                } else {
                    tracing::warn!(
                        "Restore history mismatch ignored: expected={expected:?} actual={actual:?}",
                        expected=source.extras.restore_history_id,
                        actual=current_restore_history_id
                    );
                }
            }

            // CDC requires the database to have CDC enabled and the SQL Server
            // Agent running (the agent populates the change tables).
            ensure_database_cdc_enabled(&mut client).await?;
            ensure_sql_server_agent_running(&mut client).await?;

            // Report the estimated size of each table we're about to snapshot,
            // and how long computing that estimate took.
            for table in &snapshot_tables {
                let qualified_table_name = format!("{schema_name}.{table_name}",
                    schema_name = &table.schema_name,
                    table_name = &table.name);
                let size_calc_start = Instant::now();
                let table_total =
                    mz_sql_server_util::inspect::snapshot_size(
                        &mut client,
                        &table.schema_name,
                        &table.name,
                    )
                    .await?;
                metrics.set_snapshot_table_size_latency(
                    &qualified_table_name,
                    size_calc_start.elapsed().as_secs_f64()
                );
                for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                    export_stat.set_snapshot_records_known(u64::cast_from(table_total));
                    export_stat.set_snapshot_records_staged(0);
                }
            }
            let cdc_metrics = PrometheusSqlServerCdcMetrics{inner: &metrics};
            let mut cdc_handle = client
                .cdc(capture_instances.keys().cloned(), cdc_metrics)
                .max_lsn_wait(MAX_LSN_WAIT.get(config.config.config_set()));

            // Snapshot every table that needs one, recording the LSN at which
            // each snapshot was taken so overlapping replication updates can be
            // rewound against it later.
            let snapshot_lsns: BTreeMap<Arc<str>, Lsn> = {
                // Block until the upstream CDC machinery reports a usable LSN.
                cdc_handle.wait_for_ready().await?;

                tracing::info!(%config.worker_id, "timely-{worker_id} upstream is ready");

                let report_interval =
                    SNAPSHOT_PROGRESS_REPORT_INTERVAL.handle(config.config.config_set());
                let mut last_report = Instant::now();
                let mut snapshot_lsns = BTreeMap::new();
                let arena = RowArena::default();

                for table in snapshot_tables {
                    let (snapshot_lsn, snapshot) = cdc_handle
                        .snapshot(&table, config.worker_id, config.id)
                        .await?;

                    tracing::info!(
                        %config.id,
                        %table.name,
                        %table.schema_name,
                        %snapshot_lsn,
                        "timely-{worker_id} snapshot start",
                    );

                    let mut snapshot = std::pin::pin!(snapshot);

                    snapshot_lsns.insert(
                        Arc::clone(&table.capture_instance.name),
                        snapshot_lsn,
                    );

                    let ci_name = &table.capture_instance.name;
                    let partition_indexes = capture_instance_to_snapshot
                        .get(ci_name)
                        .unwrap_or_else(|| {
                            panic!(
                                "no snapshot outputs in known capture \
                                instances [{}] for capture instance: \
                                '{}'",
                                capture_instance_to_snapshot
                                    .keys()
                                    .join(","),
                                ci_name,
                            );
                        });

                    let mut snapshot_staged = 0;
                    while let Some(result) = snapshot.next().await {
                        let sql_server_row =
                            result.map_err(TransientError::from)?;

                        // Periodically publish how many records have been staged.
                        if last_report.elapsed() > report_interval.get() {
                            last_report = Instant::now();
                            let stats =
                                export_statistics.get(ci_name).unwrap();
                            for export_stat in stats {
                                export_stat.set_snapshot_records_staged(
                                    snapshot_staged,
                                );
                            }
                        }

                        // Emit the row once per output fed by this capture
                        // instance, at the minimum LSN so snapshot data sorts
                        // before all replicated updates.
                        for (partition_idx, _) in partition_indexes {
                            let mut mz_row = Row::default();

                            let decoder = decoder_map
                                .get(partition_idx)
                                .expect("decoder for output");
                            let message = decode(
                                decoder,
                                &sql_server_row,
                                &mut mz_row,
                                &arena,
                                None,
                            );
                            data_output
                                .give_fueled(
                                    &data_cap_set[0],
                                    ((*partition_idx, message), Lsn::minimum(), Diff::ONE),
                                )
                                .await;
                        }
                        snapshot_staged += 1;
                    }

                    tracing::info!(
                        %config.id,
                        %table.name,
                        %table.schema_name,
                        %snapshot_lsn,
                        "timely-{worker_id} snapshot complete",
                    );
                    metrics.snapshot_table_count.dec();
                    // Snapshot finished: report the final count as both staged
                    // and known so the two agree even if the earlier size
                    // estimate was off.
                    let stats = export_statistics.get(ci_name).unwrap();
                    for export_stat in stats {
                        export_stat.set_snapshot_records_staged(snapshot_staged);
                        export_stat.set_snapshot_records_known(snapshot_staged);
                    }
                }

                snapshot_lsns
            };

            // For every snapshotted output, remember (initial_lsn, snapshot_lsn):
            // replicated changes in that window must also be emitted (negated) at
            // the minimum LSN to reconcile with the already-emitted snapshot.
            let mut rewinds: BTreeMap<_, _> = capture_instance_to_snapshot
                .iter()
                .flat_map(|(capture_instance, export_ids)|{
                    let snapshot_lsn = snapshot_lsns.get(capture_instance).expect("snapshot lsn must be collected for capture instance");
                    export_ids
                        .iter()
                        .map(|(idx, initial_lsn)| (*idx, (*initial_lsn, *snapshot_lsn)))
                }).collect();

            for (initial_lsn, snapshot_lsn) in rewinds.values() {
                assert!(
                    initial_lsn <= snapshot_lsn,
                    "initial_lsn={initial_lsn} snapshot_lsn={snapshot_lsn}"
                );
            }

            tracing::debug!("rewinds to process: {rewinds:?}");

            capture_instance_to_snapshot.clear();

            // Compute the LSN to resume replication from per capture instance:
            // the minimum resume LSN across its outputs, or one past the initial
            // LSN for outputs that were just snapshotted.
            let mut resume_lsns = BTreeMap::new();
            for src_info in outputs.values() {
                let resume_lsn = match src_info.resume_upper.as_option() {
                    Some(lsn) if *lsn != Lsn::minimum() => *lsn,
                    Some(_) => src_info.initial_lsn.increment(),
                    None => panic!("resume_upper has at least one value"),
                };
                resume_lsns.entry(Arc::clone(&src_info.capture_instance))
                    .and_modify(|existing| *existing = std::cmp::min(*existing, resume_lsn))
                    .or_insert(resume_lsn);
            }

            tracing::info!(%config.id, ?resume_lsns, "timely-{} replication starting", config.worker_id);
            for instance in capture_instances.keys() {
                let resume_lsn = resume_lsns
                    .get(instance)
                    .expect("resume_lsn exists for capture instance");
                cdc_handle = cdc_handle.start_lsn(instance, *resume_lsn);
            }

            // Start streaming CDC events.
            let cdc_stream = cdc_handle
                .poll_interval(CDC_POLL_INTERVAL.get(config.config.config_set()))
                .into_stream();
            let mut cdc_stream = std::pin::pin!(cdc_stream);

            // Partitions that have hit an incompatible schema change; their
            // subsequent updates are ignored.
            let mut errored_partitions = BTreeSet::new();

            // Ensures "rewinds complete" is logged at most once.
            let mut log_rewinds_complete = true;

            // SQL Server updates arrive as separate UpdateOld/UpdateNew events
            // sharing an (lsn, seqval) key; whichever half arrives first is
            // parked here until its partner shows up.
            let mut deferred_updates = BTreeMap::new();

            while let Some(event) = cdc_stream.next().await {
                let event = event.map_err(TransientError::from)?;
                tracing::trace!(?config.id, ?event, "got replication event");

                tracing::trace!("deferred_updates = {deferred_updates:?}");
                match event {
                    CdcEvent::Progress { next_lsn } => {
                        tracing::debug!(?config.id, ?next_lsn, "got a closed lsn");
                        // A rewind is finished once the closed LSN passes its
                        // snapshot LSN. The data capability is only downgraded
                        // once no rewinds remain, because rewound updates are
                        // emitted at the minimum LSN.
                        rewinds.retain(|_, (_, snapshot_lsn)| next_lsn <= *snapshot_lsn);
                        if rewinds.is_empty() {
                            if log_rewinds_complete {
                                tracing::debug!("rewinds complete");
                                log_rewinds_complete = false;
                            }
                            data_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        } else {
                            tracing::debug!("rewinds remaining: {:?}", rewinds);
                        }

                        // Every deferred update half must be paired before its
                        // LSN closes; an unpaired half older than the closed LSN
                        // indicates a bug.
                        if let Some(((deferred_lsn, _seqval), _row)) =
                            deferred_updates.first_key_value()
                            && *deferred_lsn < next_lsn
                        {
                            panic!(
                                "deferred update lsn {deferred_lsn} \
                                < progress lsn {next_lsn}: {:?}",
                                deferred_updates.keys()
                            );
                        }

                        upper_cap_set.downgrade(Antichain::from_elem(next_lsn));
                    }
                    CdcEvent::Data {
                        capture_instance,
                        lsn,
                        changes,
                    } => {
                        // Data for an unknown capture instance is a programming
                        // error: emit a definite error to every output and stop.
                        let Some(partition_indexes) =
                            capture_instances.get(&capture_instance)
                        else {
                            let definite_error =
                                DefiniteError::ProgrammingError(format!(
                                    "capture instance didn't exist: \
                                    '{capture_instance}'"
                                ));
                            return_definite_error(
                                definite_error,
                                capture_instances
                                    .values()
                                    .flat_map(|indexes| {
                                        indexes.iter().copied()
                                    }),
                                data_output,
                                data_cap_set,
                                definite_error_handle,
                                definite_error_cap_set,
                            )
                            .await;
                            return Ok(());
                        };

                        // Split off partitions that previously errored; their
                        // changes are counted as ignored rather than emitted.
                        let (valid_partitions, err_partitions) =
                            partition_indexes
                                .iter()
                                .partition::<Vec<u64>, _>(
                                    |&partition_idx| {
                                        !errored_partitions
                                            .contains(partition_idx)
                                    },
                                );

                        if err_partitions.len() > 0 {
                            metrics.ignored.inc_by(u64::cast_from(changes.len()));
                        }

                        handle_data_event(
                            changes,
                            &valid_partitions,
                            &decoder_map,
                            lsn,
                            &rewinds,
                            &data_output,
                            data_cap_set,
                            &metrics,
                            &mut deferred_updates,
                        ).await?
                    },
                    CdcEvent::SchemaUpdate {
                        capture_instance,
                        table,
                        ddl_event,
                    } => {
                        let Some(partition_indexes) =
                            capture_instances.get(&capture_instance)
                        else {
                            let definite_error =
                                DefiniteError::ProgrammingError(format!(
                                    "capture instance didn't exist: \
                                    '{capture_instance}'"
                                ));
                            return_definite_error(
                                definite_error,
                                capture_instances
                                    .values()
                                    .flat_map(|indexes| {
                                        indexes.iter().copied()
                                    }),
                                data_output,
                                data_cap_set,
                                definite_error_handle,
                                definite_error_cap_set,
                            )
                            .await;
                            return Ok(());
                        };
                        // A DDL event incompatible with the columns an output
                        // includes poisons that partition: emit a definite error
                        // for it and ignore its data from here on.
                        let error =
                            DefiniteError::IncompatibleSchemaChange(
                                capture_instance.to_string(),
                                table.to_string(),
                            );
                        for partition_idx in partition_indexes {
                            let cols = included_columns
                                .get(partition_idx)
                                .unwrap_or_else(|| {
                                    panic!(
                                        "Partition index didn't \
                                        exist: '{partition_idx}'"
                                    )
                                });
                            if !errored_partitions
                                .contains(partition_idx)
                                && !ddl_event.is_compatible(cols)
                            {
                                let msg = Err(
                                    error.clone().into(),
                                );
                                data_output
                                    .give_fueled(
                                        &data_cap_set[0],
                                        ((*partition_idx, msg), ddl_event.lsn, Diff::ONE),
                                    )
                                    .await;
                                errored_partitions.insert(*partition_idx);
                            }
                        }
                    }
                };
            }
            // The CDC stream is expected to run forever; reaching here is a
            // transient error so the operator gets restarted.
            Err(TransientError::ReplicationEOF)
        }))
    });

    let error_stream = definite_errors.concat(&transient_errors.map(ReplicationError::Transient));

    (
        data_stream.as_collection(),
        upper_stream,
        error_stream,
        button.press_on_drop(),
    )
}
618
/// Decodes one batch of CDC changes committed at `commit_lsn` and emits them to
/// `data_output` for every partition in `partition_indexes`.
///
/// Changes falling inside a partition's rewind window (see `rewinds`, keyed by
/// partition index as `(initial_lsn, snapshot_lsn)`) are additionally emitted
/// with the opposite diff at the minimum LSN so they cancel against the
/// already-emitted snapshot. The `UpdateOld`/`UpdateNew` halves of a SQL Server
/// update are paired via `deferred_updates`: the first half seen is parked
/// there, and the change is processed when its partner arrives.
async fn handle_data_event(
    changes: Vec<CdcOperation>,
    partition_indexes: &[u64],
    decoder_map: &BTreeMap<u64, Arc<SqlServerRowDecoder>>,
    commit_lsn: Lsn,
    rewinds: &BTreeMap<u64, (Lsn, Lsn)>,
    data_output: &StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_cap_set: &CapabilitySet<Lsn>,
    metrics: &SqlServerSourceMetrics,
    deferred_updates: &mut BTreeMap<(Lsn, Lsn), CdcOperation>,
) -> Result<(), TransientError> {
    // Scratch row and arena reused across all decodes in this batch.
    let mut mz_row = Row::default();
    let arena = RowArena::default();

    for change in changes {
        // Holds the previously-parked half of an update once both halves are
        // available; stays `None` for inserts and deletes.
        let mut deferred_update: Option<_> = None;
        let (sql_server_row, diff): (_, _) = match change {
            CdcOperation::Insert(sql_server_row) => {
                metrics.inserts.inc();
                (sql_server_row, Diff::ONE)
            }
            CdcOperation::Delete(sql_server_row) => {
                metrics.deletes.inc();
                (sql_server_row, Diff::MINUS_ONE)
            }

            CdcOperation::UpdateNew(seqval, sql_server_row) => {
                metrics.updates.inc();
                // Look for the matching UpdateOld; if it hasn't arrived yet,
                // park this half and move on to the next change.
                deferred_update = deferred_updates.remove(&(commit_lsn, seqval));
                if deferred_update.is_none() {
                    tracing::trace!("capture deferred UpdateNew ({commit_lsn}, {seqval})");
                    deferred_updates.insert(
                        (commit_lsn, seqval),
                        CdcOperation::UpdateNew(seqval, sql_server_row),
                    );
                    continue;
                }
                // Placeholder diff; the update path below emits the real
                // retraction/insertion pair (see the assert_ne! further down).
                (sql_server_row, Diff::ZERO)
            }
            CdcOperation::UpdateOld(seqval, sql_server_row) => {
                deferred_update = deferred_updates.remove(&(commit_lsn, seqval));
                if deferred_update.is_none() {
                    tracing::trace!("capture deferred UpdateOld ({commit_lsn}, {seqval})");
                    deferred_updates.insert(
                        (commit_lsn, seqval),
                        CdcOperation::UpdateOld(seqval, sql_server_row),
                    );
                    continue;
                }
                (sql_server_row, Diff::ZERO)
            }
        };

        for partition_idx in partition_indexes {
            let decoder = decoder_map.get(partition_idx).unwrap();

            // Changes at or before this partition's initial LSN are already
            // reflected in its snapshot; skip them entirely.
            let rewind = rewinds.get(partition_idx);
            if rewind.is_some_and(|(initial_lsn, _)| commit_lsn <= *initial_lsn) {
                continue;
            }

            let (message, diff) = if let Some(ref deferred_update) = deferred_update {
                // Recover (old, new) regardless of which half arrived first:
                // the parked half is in `deferred_update`, the current half in
                // `sql_server_row`.
                let (old_row, new_row) = match deferred_update {
                    CdcOperation::UpdateOld(_seqval, row) => (row, &sql_server_row),
                    CdcOperation::UpdateNew(_seqval, row) => (&sql_server_row, row),
                    // Only update halves are ever parked in `deferred_updates`.
                    CdcOperation::Insert(_) | CdcOperation::Delete(_) => unreachable!(),
                };

                // Retract the old row at the commit LSN. Within the rewind
                // window, also emit it positively at the minimum LSN so the
                // retraction cancels against the snapshot.
                let update_old = decode(decoder, old_row, &mut mz_row, &arena, Some(new_row));
                if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                    data_output
                        .give_fueled(
                            &data_cap_set[0],
                            (
                                (*partition_idx, update_old.clone()),
                                Lsn::minimum(),
                                Diff::ONE,
                            ),
                        )
                        .await;
                }
                data_output
                    .give_fueled(
                        &data_cap_set[0],
                        ((*partition_idx, update_old), commit_lsn, Diff::MINUS_ONE),
                    )
                    .await;

                // The new row is inserted via the shared emission below.
                (
                    decode(decoder, new_row, &mut mz_row, &arena, None),
                    Diff::ONE,
                )
            } else {
                (
                    decode(decoder, &sql_server_row, &mut mz_row, &arena, None),
                    diff,
                )
            };
            // The Diff::ZERO placeholders must have been replaced by now.
            assert_ne!(Diff::ZERO, diff);
            // Emit at the commit LSN, with a negated twin at the minimum LSN
            // when this change falls inside the partition's rewind window.
            if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                data_output
                    .give_fueled(
                        &data_cap_set[0],
                        ((*partition_idx, message.clone()), Lsn::minimum(), -diff),
                    )
                    .await;
            }
            data_output
                .give_fueled(
                    &data_cap_set[0],
                    ((*partition_idx, message), commit_lsn, diff),
                )
                .await;
        }
    }
    Ok(())
}
747
/// Shorthand for the async output handle used by the data output: batches
/// `(data, time, diff)` triples into accounted `TimelyStack` containers.
type StackedAsyncOutputHandle<T, D> = AsyncOutputHandle<
    T,
    AccountedStackBuilder<CapacityContainerBuilder<TimelyStack<(D, T, Diff)>>>,
>;
752
753fn decode(
756 decoder: &SqlServerRowDecoder,
757 row: &tiberius::Row,
758 mz_row: &mut Row,
759 arena: &RowArena,
760 new_row: Option<&tiberius::Row>,
761) -> Result<SourceMessage, DataflowError> {
762 match decoder.decode(row, mz_row, arena, new_row) {
763 Ok(()) => Ok(SourceMessage {
764 key: Row::default(),
765 value: mz_row.clone(),
766 metadata: Row::default(),
767 }),
768 Err(e) => {
769 let kind = DecodeErrorKind::Text(e.to_string().into());
770 let raw = format!("{row:?}");
772 Err(DataflowError::DecodeError(Box::new(DecodeError {
773 kind,
774 raw: raw.as_bytes().to_vec(),
775 })))
776 }
777 }
778}
779
780async fn return_definite_error(
782 err: DefiniteError,
783 outputs: impl Iterator<Item = u64>,
784 data_handle: StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
785 data_capset: &CapabilitySet<Lsn>,
786 errs_handle: AsyncOutputHandle<Lsn, CapacityContainerBuilder<Vec<ReplicationError>>>,
787 errs_capset: &CapabilitySet<Lsn>,
788) {
789 for output_idx in outputs {
790 let update = (
791 (output_idx, Err(err.clone().into())),
792 Lsn {
796 vlf_id: u32::MAX,
797 block_id: u32::MAX,
798 record_id: u16::MAX,
799 },
800 Diff::ONE,
801 );
802 data_handle.give_fueled(&data_capset[0], update).await;
803 }
804 errs_handle.give(
805 &errs_capset[0],
806 ReplicationError::DefiniteError(Rc::new(err)),
807 );
808}
809
/// Adapter exposing [`SqlServerSourceMetrics`] through the
/// [`SqlServerCdcMetrics`] trait that the CDC client accepts.
struct PrometheusSqlServerCdcMetrics<'a> {
    // Borrowed metrics registry the trait callbacks forward into.
    inner: &'a SqlServerSourceMetrics,
}
814
/// Forwards snapshot table-lock lifecycle callbacks from the CDC client into
/// the source's Prometheus metrics.
impl<'a> SqlServerCdcMetrics for PrometheusSqlServerCdcMetrics<'a> {
    fn snapshot_table_lock_start(&self, table_name: &str) {
        // Lock taken: bump the per-table lock count up by one.
        self.inner.update_snapshot_table_lock_count(table_name, 1);
    }

    fn snapshot_table_lock_end(&self, table_name: &str) {
        // Lock released: bring the per-table lock count back down.
        self.inner.update_snapshot_table_lock_count(table_name, -1);
    }
}