1use std::collections::{BTreeMap, BTreeSet};
13use std::convert::Infallible;
14use std::rc::Rc;
15use std::sync::Arc;
16use std::time::Instant;
17
18use differential_dataflow::AsCollection;
19use differential_dataflow::containers::TimelyStack;
20use futures::StreamExt;
21use itertools::Itertools;
22use mz_ore::cast::CastFrom;
23use mz_ore::collections::HashMap;
24use mz_ore::future::InTask;
25use mz_repr::{Diff, GlobalId, Row, RowArena};
26use mz_sql_server_util::SqlServerCdcMetrics;
27use mz_sql_server_util::cdc::{CdcEvent, Lsn, Operation as CdcOperation};
28use mz_sql_server_util::desc::SqlServerRowDecoder;
29use mz_sql_server_util::inspect::{
30 ensure_database_cdc_enabled, ensure_sql_server_agent_running, get_latest_restore_history_id,
31};
32use mz_storage_types::dyncfgs::SQL_SERVER_SOURCE_VALIDATE_RESTORE_HISTORY;
33use mz_storage_types::errors::{DataflowError, DecodeError, DecodeErrorKind};
34use mz_storage_types::sources::SqlServerSourceConnection;
35use mz_storage_types::sources::sql_server::{
36 CDC_POLL_INTERVAL, MAX_LSN_WAIT, SNAPSHOT_PROGRESS_REPORT_INTERVAL,
37};
38use mz_timely_util::builder_async::{
39 AsyncOutputHandle, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
40};
41use mz_timely_util::containers::stack::AccountedStackBuilder;
42use timely::container::CapacityContainerBuilder;
43use timely::dataflow::operators::{CapabilitySet, Concat, Map};
44use timely::dataflow::{Scope, Stream as TimelyStream};
45use timely::progress::{Antichain, Timestamp};
46
47use crate::metrics::source::sql_server::SqlServerSourceMetrics;
48use crate::source::RawSourceCreationConfig;
49use crate::source::sql_server::{
50 DefiniteError, ReplicationError, SourceOutputInfo, TransientError,
51};
52use crate::source::types::{SignaledFuture, SourceMessage, StackedCollection};
53
/// Key passed to `config.responsible_for` to elect the single worker that
/// runs the snapshot and replication loop; all other workers return early.
static REPL_READER: &str = "reader";
60
/// Renders the async Timely operator that replicates a SQL Server source via CDC.
///
/// Returns:
/// * a [`StackedCollection`] of `((partition index, message-or-error), Lsn, Diff)`
///   updates,
/// * a progress-only stream (no data flows; its frontier communicates the
///   upstream upper),
/// * a stream of [`ReplicationError`]s (definite errors concatenated with
///   transient ones), and
/// * a [`PressOnDropButton`] that shuts the operator down when dropped.
///
/// Only the worker for which `config.responsible_for(REPL_READER)` holds does
/// any reading; the others set up statistics/metrics and return immediately.
pub(crate) fn render<G: Scope<Timestamp = Lsn>>(
    scope: G,
    config: RawSourceCreationConfig,
    outputs: BTreeMap<GlobalId, SourceOutputInfo>,
    source: SqlServerSourceConnection,
    metrics: SqlServerSourceMetrics,
) -> (
    StackedCollection<G, (u64, Result<SourceMessage, DataflowError>)>,
    TimelyStream<G, Infallible>,
    TimelyStream<G, ReplicationError>,
    PressOnDropButton,
) {
    let op_name = format!("SqlServerReplicationReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope);

    // Three outputs: data, a progress-only "upper" output, and definite errors.
    let (data_output, data_stream) = builder.new_output::<AccountedStackBuilder<_>>();
    let (_upper_output, upper_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
            // One capability set per output, in declaration order.
            let [
                data_cap_set,
                upper_cap_set,
                definite_error_cap_set,
            ]: &mut [_; 3] = caps.try_into().unwrap();

            let connection_config = source
                .connection
                .resolve_config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;
            let mut client = mz_sql_server_util::Client::connect(connection_config).await?;

            let worker_id = config.worker_id;

            // Decoder for each output, keyed by partition index.
            let mut decoder_map: BTreeMap<_, _> = BTreeMap::new();
            // Capture instances whose outputs still need a snapshot, mapped to
            // the `(partition index, initial_lsn)` pairs to snapshot.
            let mut capture_instance_to_snapshot: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Every capture instance and the partition indexes it feeds.
            let mut capture_instances: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Statistics handles grouped by capture instance.
            let mut export_statistics: BTreeMap<_, Vec<_>> = BTreeMap::new();
            // Column names each partition's decoder includes; consulted when a
            // DDL event arrives to decide schema-change compatibility.
            let mut included_columns: HashMap<u64, Vec<Arc<str>>> = HashMap::new();

            for (export_id, output) in outputs.iter() {
                if decoder_map.insert(output.partition_index, Arc::clone(&output.decoder)).is_some() {
                    panic!("Multiple decoders for output index {}", output.partition_index);
                }
                let included_cols = output.decoder.included_column_names();
                included_columns.insert(output.partition_index, included_cols);

                capture_instances
                    .entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(output.partition_index);

                // A resume upper of exactly `[Lsn::minimum()]` means this
                // output has not been snapshot yet.
                if *output.resume_upper == [Lsn::minimum()] {
                    capture_instance_to_snapshot
                        .entry(Arc::clone(&output.capture_instance))
                        .or_default()
                        .push((output.partition_index, output.initial_lsn));
                }
                export_statistics.entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(
                        config
                            .statistics
                            .get(export_id)
                            .expect("statistics have been intialized")
                            .clone(),
                    );
            }

            metrics.snapshot_table_count.set(u64::cast_from(capture_instance_to_snapshot.len()));
            // Zero snapshot statistics up front so they read as 0 until the
            // snapshot loop below reports real counts.
            if !capture_instance_to_snapshot.is_empty() {
                for stats in config.statistics.values() {
                    stats.set_snapshot_records_known(0);
                    stats.set_snapshot_records_staged(0);
                }
            }
            // A single worker owns reading; everyone else is done here.
            if !config.responsible_for(REPL_READER) {
                return Ok::<_, TransientError>(());
            }

            let snapshot_instances = capture_instance_to_snapshot
                .keys()
                .map(|i| i.as_ref());

            let snapshot_tables = mz_sql_server_util::inspect::get_tables_for_capture_instance(&mut client, snapshot_instances).await?;

            // If the upstream database was restored since this source was
            // created, its restore history id changes. Depending on the
            // dyncfg, either emit a definite error or log and continue.
            let current_restore_history_id = get_latest_restore_history_id(&mut client).await?;
            if current_restore_history_id != source.extras.restore_history_id {
                if SQL_SERVER_SOURCE_VALIDATE_RESTORE_HISTORY.get(config.config.config_set()) {
                    let definite_error = DefiniteError::RestoreHistoryChanged(
                        source.extras.restore_history_id.clone(),
                        current_restore_history_id.clone()
                    );
                    tracing::warn!(?definite_error, "Restore detected, exiting");

                    return_definite_error(
                        definite_error,
                        capture_instances.values().flat_map(|indexes| indexes.iter().copied()),
                        data_output,
                        data_cap_set,
                        definite_error_handle,
                        definite_error_cap_set,
                    ).await;
                    return Ok(());
                } else {
                    tracing::warn!(
                        "Restore history mismatch ignored: expected={expected:?} actual={actual:?}",
                        expected=source.extras.restore_history_id,
                        actual=current_restore_history_id
                    );
                }
            }

            // Sanity-check the upstream configuration before starting.
            ensure_database_cdc_enabled(&mut client).await?;
            ensure_sql_server_agent_running(&mut client).await?;

            // Report the upstream size of each table we are about to
            // snapshot, and how long the size calculation took.
            for table in &snapshot_tables {
                let qualified_table_name = format!("{schema_name}.{table_name}",
                    schema_name = &table.schema_name,
                    table_name = &table.name);
                let size_calc_start = Instant::now();
                let table_total = mz_sql_server_util::inspect::snapshot_size(&mut client, &table.schema_name, &table.name).await?;
                metrics.set_snapshot_table_size_latency(
                    &qualified_table_name,
                    size_calc_start.elapsed().as_secs_f64()
                );
                for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                    export_stat.set_snapshot_records_known(u64::cast_from(table_total));
                    export_stat.set_snapshot_records_staged(0);
                }
            }
            let cdc_metrics = PrometheusSqlServerCdcMetrics{inner: &metrics};
            let mut cdc_handle = client
                .cdc(capture_instances.keys().cloned(), cdc_metrics)
                .max_lsn_wait(MAX_LSN_WAIT.get(config.config.config_set()));

            // Snapshot phase: read each table in full, emitting every row at
            // `Lsn::minimum()`, and record the LSN each snapshot was taken at.
            let snapshot_lsns: BTreeMap<Arc<str>, Lsn> = {
                cdc_handle.wait_for_ready().await?;

                tracing::info!(%config.worker_id, "timely-{worker_id} upstream is ready");

                let report_interval =
                    SNAPSHOT_PROGRESS_REPORT_INTERVAL.handle(config.config.config_set());
                let mut last_report = Instant::now();
                let mut snapshot_lsns = BTreeMap::new();
                let arena = RowArena::default();

                for table in snapshot_tables {
                    let (snapshot_lsn, snapshot) = cdc_handle
                        .snapshot(&table, config.worker_id, config.id)
                        .await?;

                    tracing::info!(%config.id, %table.name, %table.schema_name, %snapshot_lsn, "timely-{worker_id} snapshot start");

                    let mut snapshot = std::pin::pin!(snapshot);

                    snapshot_lsns.insert(Arc::clone(&table.capture_instance.name), snapshot_lsn);

                    let partition_indexes = capture_instance_to_snapshot.get(&table.capture_instance.name)
                        .unwrap_or_else(|| {
                            panic!("no snapshot outputs in known capture instances [{}] for capture instance: '{}'", capture_instance_to_snapshot.keys().join(","), table.capture_instance.name);
                        });

                    let mut snapshot_staged = 0;
                    while let Some(result) = snapshot.next().await {
                        let sql_server_row = result.map_err(TransientError::from)?;

                        // Periodically report staged-row progress.
                        if last_report.elapsed() > report_interval.get() {
                            last_report = Instant::now();
                            for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                                export_stat.set_snapshot_records_staged(snapshot_staged);
                            }
                        }

                        // Fan the row out to every partition fed by this
                        // capture instance, each with its own decoder.
                        for (partition_idx, _) in partition_indexes {
                            let mut mz_row = Row::default();

                            let decoder = decoder_map.get(partition_idx).expect("decoder for output");
                            let message = decode(decoder, &sql_server_row, &mut mz_row, &arena, None);
                            data_output
                                .give_fueled(
                                    &data_cap_set[0],
                                    ((*partition_idx, message), Lsn::minimum(), Diff::ONE),
                                )
                                .await;
                        }
                        snapshot_staged += 1;
                    }

                    tracing::info!(%config.id, %table.name, %table.schema_name, %snapshot_lsn, "timely-{worker_id} snapshot complete");
                    metrics.snapshot_table_count.dec();
                    // Once a table completes, known == staged == rows emitted.
                    for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                        export_stat.set_snapshot_records_staged(snapshot_staged);
                        export_stat.set_snapshot_records_known(snapshot_staged);
                    }
                }

                snapshot_lsns
            };

            // Per snapshotted partition: `(initial_lsn, snapshot_lsn)`. The
            // replication loop uses this to offset changes that are already
            // reflected in the snapshot (see `handle_data_event`).
            let mut rewinds: BTreeMap<_, _> = capture_instance_to_snapshot
                .iter()
                .flat_map(|(capture_instance, export_ids)|{
                    let snapshot_lsn = snapshot_lsns.get(capture_instance).expect("snapshot lsn must be collected for capture instance");
                    export_ids
                        .iter()
                        .map(|(idx, initial_lsn)| (*idx, (*initial_lsn, *snapshot_lsn)))
                }).collect();

            // Invariant: a snapshot can only have been taken at or after the
            // LSN the output started from.
            for (initial_lsn, snapshot_lsn) in rewinds.values() {
                assert!(
                    initial_lsn <= snapshot_lsn,
                    "initial_lsn={initial_lsn} snapshot_lsn={snapshot_lsn}"
                );
            }

            tracing::debug!("rewinds to process: {rewinds:?}");

            capture_instance_to_snapshot.clear();

            // Replication resumes per capture instance at the minimum resume
            // LSN over all of its outputs.
            let mut resume_lsns = BTreeMap::new();
            for src_info in outputs.values() {
                let resume_lsn = match src_info.resume_upper.as_option() {
                    Some(lsn) if *lsn != Lsn::minimum() => *lsn,
                    // Freshly snapshotted output: resume just past initial_lsn.
                    Some(_) => src_info.initial_lsn.increment(),
                    None => panic!("resume_upper has at least one value"),
                };
                resume_lsns.entry(Arc::clone(&src_info.capture_instance))
                    .and_modify(|existing| *existing = std::cmp::min(*existing, resume_lsn))
                    .or_insert(resume_lsn);
            }

            tracing::info!(%config.id, ?resume_lsns, "timely-{} replication starting", config.worker_id);
            for instance in capture_instances.keys() {
                let resume_lsn = resume_lsns
                    .get(instance)
                    .expect("resume_lsn exists for capture instance");
                cdc_handle = cdc_handle.start_lsn(instance, *resume_lsn);
            }

            let cdc_stream = cdc_handle
                .poll_interval(CDC_POLL_INTERVAL.get(config.config.config_set()))
                .into_stream();
            let mut cdc_stream = std::pin::pin!(cdc_stream);

            // Partitions that hit an incompatible schema change; subsequent
            // changes for them are ignored (and counted in metrics.ignored).
            let mut errored_partitions = BTreeSet::new();

            // Ensures "rewinds complete" is logged only once.
            let mut log_rewinds_complete = true;

            // Unpaired halves of updates (`UpdateOld`/`UpdateNew`), keyed by
            // `(lsn, seqval)`; paired and drained by `handle_data_event`.
            let mut deferred_updates = BTreeMap::new();

            while let Some(event) = cdc_stream.next().await {
                let event = event.map_err(TransientError::from)?;
                tracing::trace!(?config.id, ?event, "got replication event");

                tracing::trace!("deferred_updates = {deferred_updates:?}");
                match event {
                    CdcEvent::Progress { next_lsn } => {
                        tracing::debug!(?config.id, ?next_lsn, "got a closed lsn");
                        // A rewind is done once the closed LSN moves past its
                        // snapshot LSN; only then may the data frontier
                        // advance (rewind retractions are emitted at
                        // `Lsn::minimum()`).
                        rewinds.retain(|_, (_, snapshot_lsn)| next_lsn <= *snapshot_lsn);
                        if rewinds.is_empty() {
                            if log_rewinds_complete {
                                tracing::debug!("rewinds complete");
                                log_rewinds_complete = false;
                            }
                            data_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        } else {
                            tracing::debug!("rewinds remaining: {:?}", rewinds);
                        }

                        // Invariant: any buffered update half strictly before
                        // the closed LSN should have found its partner.
                        if let Some(((deferred_lsn, _seqval), _row)) = deferred_updates.first_key_value()
                            && *deferred_lsn < next_lsn
                        {
                            panic!(
                                "deferred update lsn {deferred_lsn} < progress lsn {next_lsn}: {:?}",
                                deferred_updates.keys()
                            );
                        }

                        upper_cap_set.downgrade(Antichain::from_elem(next_lsn));
                    }
                    CdcEvent::Data {
                        capture_instance,
                        lsn,
                        changes,
                    } => {
                        let Some(partition_indexes) = capture_instances.get(&capture_instance) else {
                            let definite_error = DefiniteError::ProgrammingError(format!(
                                "capture instance didn't exist: '{capture_instance}'"
                            ));
                            return_definite_error(
                                definite_error,
                                capture_instances.values().flat_map(|indexes| indexes.iter().copied()),
                                data_output,
                                data_cap_set,
                                definite_error_handle,
                                definite_error_cap_set,
                            )
                            .await;
                            return Ok(());
                        };

                        // Split off partitions that previously errored out.
                        let (valid_partitions, err_partitions) = partition_indexes.iter().partition::<Vec<u64>, _>(|&partition_idx| {
                            !errored_partitions.contains(partition_idx)
                        });

                        if err_partitions.len() > 0 {
                            metrics.ignored.inc_by(u64::cast_from(changes.len()));
                        }

                        handle_data_event(
                            changes,
                            &valid_partitions,
                            &decoder_map,
                            lsn,
                            &rewinds,
                            &data_output,
                            data_cap_set,
                            &metrics,
                            &mut deferred_updates,
                        ).await?
                    },
                    CdcEvent::SchemaUpdate { capture_instance, table, ddl_event } => {
                        let Some(partition_indexes) = capture_instances.get(&capture_instance) else {
                            let definite_error = DefiniteError::ProgrammingError(format!(
                                "capture instance didn't exist: '{capture_instance}'"
                            ));
                            return_definite_error(
                                definite_error,
                                capture_instances.values().flat_map(|indexes| indexes.iter().copied()),
                                data_output,
                                data_cap_set,
                                definite_error_handle,
                                definite_error_cap_set,
                            )
                            .await;
                            return Ok(());
                        };
                        let error = DefiniteError::IncompatibleSchemaChange(
                            capture_instance.to_string(),
                            table.to_string()
                        );
                        // Emit a definite error for — and from then on ignore —
                        // any partition whose included columns are incompatible
                        // with the DDL change.
                        for partition_idx in partition_indexes {
                            if !errored_partitions.contains(partition_idx) && !ddl_event.is_compatible(included_columns.get(partition_idx).unwrap_or_else(|| panic!("Partition index didn't exist: '{partition_idx}'"))) {
                                data_output
                                    .give_fueled(
                                        &data_cap_set[0],
                                        ((*partition_idx, Err(error.clone().into())), ddl_event.lsn, Diff::ONE),
                                    )
                                    .await;
                                errored_partitions.insert(*partition_idx);
                            }
                        }
                    }
                };
            }
            // The CDC stream polls forever; ending is a transient error.
            Err(TransientError::ReplicationEOF)
        }))
    });

    let error_stream = definite_errors.concat(&transient_errors.map(ReplicationError::Transient));

    (
        data_stream.as_collection(),
        upper_stream,
        error_stream,
        button.press_on_drop(),
    )
}
518
/// Decodes a batch of CDC changes for one commit LSN and emits them to the
/// data output for every partition in `partition_indexes`.
///
/// Updates arrive as two separate operations (`UpdateOld` and `UpdateNew`)
/// sharing a `seqval`. Whichever half arrives first is stashed in
/// `deferred_updates` (keyed by `(commit_lsn, seqval)`) and processing of
/// that change is skipped until its counterpart arrives.
///
/// `rewinds` maps partition index -> `(initial_lsn, snapshot_lsn)` for
/// outputs that were just snapshotted: changes at or before `initial_lsn`
/// are skipped entirely, and changes at or before `snapshot_lsn` are
/// additionally emitted with the inverse diff at `Lsn::minimum()` so they
/// offset data already contained in the snapshot.
async fn handle_data_event(
    changes: Vec<CdcOperation>,
    partition_indexes: &[u64],
    decoder_map: &BTreeMap<u64, Arc<SqlServerRowDecoder>>,
    commit_lsn: Lsn,
    rewinds: &BTreeMap<u64, (Lsn, Lsn)>,
    data_output: &StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_cap_set: &CapabilitySet<Lsn>,
    metrics: &SqlServerSourceMetrics,
    deferred_updates: &mut BTreeMap<(Lsn, Lsn), CdcOperation>,
) -> Result<(), TransientError> {
    let mut mz_row = Row::default();
    let arena = RowArena::default();

    for change in changes {
        // The previously-stashed other half of an update pair, if any.
        let mut deferred_update: Option<_> = None;
        let (sql_server_row, diff): (_, _) = match change {
            CdcOperation::Insert(sql_server_row) => {
                metrics.inserts.inc();
                (sql_server_row, Diff::ONE)
            }
            CdcOperation::Delete(sql_server_row) => {
                metrics.deletes.inc();
                (sql_server_row, Diff::MINUS_ONE)
            }

            CdcOperation::UpdateNew(seqval, sql_server_row) => {
                metrics.updates.inc();
                // If the matching UpdateOld hasn't been seen yet, stash this
                // half and move on to the next change.
                deferred_update = deferred_updates.remove(&(commit_lsn, seqval));
                if deferred_update.is_none() {
                    tracing::trace!("capture deferred UpdateNew ({commit_lsn}, {seqval})");
                    deferred_updates.insert(
                        (commit_lsn, seqval),
                        CdcOperation::UpdateNew(seqval, sql_server_row),
                    );
                    continue;
                }
                // Placeholder; the real diffs for updates are produced below
                // (asserted non-zero before emission).
                (sql_server_row, Diff::ZERO)
            }
            CdcOperation::UpdateOld(seqval, sql_server_row) => {
                deferred_update = deferred_updates.remove(&(commit_lsn, seqval));
                if deferred_update.is_none() {
                    tracing::trace!("capture deferred UpdateOld ({commit_lsn}, {seqval})");
                    deferred_updates.insert(
                        (commit_lsn, seqval),
                        CdcOperation::UpdateOld(seqval, sql_server_row),
                    );
                    continue;
                }
                (sql_server_row, Diff::ZERO)
            }
        };

        for partition_idx in partition_indexes {
            let decoder = decoder_map.get(partition_idx).unwrap();

            let rewind = rewinds.get(partition_idx);
            // Changes at or before initial_lsn predate this output; skip.
            if rewind.is_some_and(|(initial_lsn, _)| commit_lsn <= *initial_lsn) {
                continue;
            }

            let (message, diff) = if let Some(ref deferred_update) = deferred_update {
                // Work out which row of the pair is old vs. new: the stashed
                // half is `deferred_update`, the current one `sql_server_row`.
                let (old_row, new_row) = match deferred_update {
                    CdcOperation::UpdateOld(_seqval, row) => (row, &sql_server_row),
                    CdcOperation::UpdateNew(_seqval, row) => (&sql_server_row, row),
                    CdcOperation::Insert(_) | CdcOperation::Delete(_) => unreachable!(),
                };

                let update_old = decode(decoder, old_row, &mut mz_row, &arena, Some(new_row));
                // Change already reflected in the snapshot: emit the inverse
                // (re-add the old row) at Lsn::minimum() to offset it.
                if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                    data_output
                        .give_fueled(
                            &data_cap_set[0],
                            (
                                (*partition_idx, update_old.clone()),
                                Lsn::minimum(),
                                Diff::ONE,
                            ),
                        )
                        .await;
                }
                // Retract the old row at the commit LSN...
                data_output
                    .give_fueled(
                        &data_cap_set[0],
                        ((*partition_idx, update_old), commit_lsn, Diff::MINUS_ONE),
                    )
                    .await;

                // ...and insert the new row via the shared emission below.
                (
                    decode(decoder, new_row, &mut mz_row, &arena, None),
                    Diff::ONE,
                )
            } else {
                (
                    decode(decoder, &sql_server_row, &mut mz_row, &arena, None),
                    diff,
                )
            };
            assert_ne!(Diff::ZERO, diff);
            // Same snapshot-offset logic for inserts/deletes/new rows.
            if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                data_output
                    .give_fueled(
                        &data_cap_set[0],
                        ((*partition_idx, message.clone()), Lsn::minimum(), -diff),
                    )
                    .await;
            }
            data_output
                .give_fueled(
                    &data_cap_set[0],
                    ((*partition_idx, message), commit_lsn, diff),
                )
                .await;
        }
    }
    Ok(())
}
647
/// Shorthand for the handle type of the replication operator's data output
/// (stacked `(datum, time, diff)` containers).
type StackedAsyncOutputHandle<T, D> = AsyncOutputHandle<
    T,
    AccountedStackBuilder<CapacityContainerBuilder<TimelyStack<(D, T, Diff)>>>,
>;
652
653fn decode(
656 decoder: &SqlServerRowDecoder,
657 row: &tiberius::Row,
658 mz_row: &mut Row,
659 arena: &RowArena,
660 new_row: Option<&tiberius::Row>,
661) -> Result<SourceMessage, DataflowError> {
662 match decoder.decode(row, mz_row, arena, new_row) {
663 Ok(()) => Ok(SourceMessage {
664 key: Row::default(),
665 value: mz_row.clone(),
666 metadata: Row::default(),
667 }),
668 Err(e) => {
669 let kind = DecodeErrorKind::Text(e.to_string().into());
670 let raw = format!("{row:?}");
672 Err(DataflowError::DecodeError(Box::new(DecodeError {
673 kind,
674 raw: raw.as_bytes().to_vec(),
675 })))
676 }
677 }
678}
679
680async fn return_definite_error(
682 err: DefiniteError,
683 outputs: impl Iterator<Item = u64>,
684 data_handle: StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
685 data_capset: &CapabilitySet<Lsn>,
686 errs_handle: AsyncOutputHandle<Lsn, CapacityContainerBuilder<Vec<ReplicationError>>>,
687 errs_capset: &CapabilitySet<Lsn>,
688) {
689 for output_idx in outputs {
690 let update = (
691 (output_idx, Err(err.clone().into())),
692 Lsn {
696 vlf_id: u32::MAX,
697 block_id: u32::MAX,
698 record_id: u16::MAX,
699 },
700 Diff::ONE,
701 );
702 data_handle.give_fueled(&data_capset[0], update).await;
703 }
704 errs_handle.give(
705 &errs_capset[0],
706 ReplicationError::DefiniteError(Rc::new(err)),
707 );
708}
709
/// Adapter exposing [`SqlServerSourceMetrics`] through the
/// [`SqlServerCdcMetrics`] trait that the CDC client expects.
struct PrometheusSqlServerCdcMetrics<'a> {
    inner: &'a SqlServerSourceMetrics,
}
714
impl<'a> SqlServerCdcMetrics for PrometheusSqlServerCdcMetrics<'a> {
    // Bumps the per-table snapshot lock count by +1 when a lock is taken...
    fn snapshot_table_lock_start(&self, table_name: &str) {
        self.inner.update_snapshot_table_lock_count(table_name, 1);
    }

    // ...and by -1 when it ends, so the gauge tracks locks currently held.
    fn snapshot_table_lock_end(&self, table_name: &str) {
        self.inner.update_snapshot_table_lock_count(table_name, -1);
    }
}
723}