use std::collections::{BTreeMap, BTreeSet};
use std::convert::Infallible;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Instant;

use differential_dataflow::AsCollection;
use differential_dataflow::containers::TimelyStack;
use futures::StreamExt;
use itertools::Itertools;
use mz_ore::cast::CastFrom;
use mz_ore::future::InTask;
use mz_repr::{Diff, GlobalId, Row, RowArena};
use mz_sql_server_util::cdc::{CdcEvent, Lsn, Operation as CdcOperation};
use mz_sql_server_util::desc::SqlServerRowDecoder;
use mz_sql_server_util::inspect::get_latest_restore_history_id;
use mz_storage_types::errors::{DataflowError, DecodeError, DecodeErrorKind};
use mz_storage_types::sources::SqlServerSourceConnection;
use mz_storage_types::sources::sql_server::{
    CDC_POLL_INTERVAL, MAX_LSN_WAIT, SNAPSHOT_PROGRESS_REPORT_INTERVAL,
};
use mz_timely_util::builder_async::{
    AsyncOutputHandle, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use mz_timely_util::containers::stack::AccountedStackBuilder;
use timely::container::CapacityContainerBuilder;
use timely::dataflow::channels::pushers::Tee;
use timely::dataflow::operators::{CapabilitySet, Concat, Map};
use timely::dataflow::{Scope, Stream as TimelyStream};
use timely::progress::{Antichain, Timestamp};

use crate::source::RawSourceCreationConfig;
use crate::source::sql_server::{
    DefiniteError, ReplicationError, SourceOutputInfo, TransientError,
};
use crate::source::types::{SignaledFuture, SourceMessage, StackedCollection};

static REPL_READER: &str = "reader";

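/// Renders the replication operator for a SQL Server source.
///
/// The operator first snapshots any tables whose outputs have not yet produced
/// data, then tails the CDC stream for every capture instance, returning the
/// decoded rows, a progress stream, an error stream, and a shutdown button.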
pub(crate) fn render<G: Scope<Timestamp = Lsn>>(
    scope: G,
    config: RawSourceCreationConfig,
    outputs: BTreeMap<GlobalId, SourceOutputInfo>,
    source: SqlServerSourceConnection,
) -> (
    StackedCollection<G, (u64, Result<SourceMessage, DataflowError>)>,
    TimelyStream<G, Infallible>,
    TimelyStream<G, ReplicationError>,
    PressOnDropButton,
) {
    let op_name = format!("SqlServerReplicationReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope);

    let (data_output, data_stream) = builder.new_output::<AccountedStackBuilder<_>>();
    let (_upper_output, upper_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
            let [
                data_cap_set,
                upper_cap_set,
                definite_error_cap_set,
            ]: &mut [_; 3] = caps.try_into().unwrap();

            let connection_config = source
                .connection
                .resolve_config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;
            let mut client = mz_sql_server_util::Client::connect(connection_config).await?;

            let worker_id = config.worker_id;

            // Decoder for the rows of each output, keyed by partition index.
            let mut decoder_map: BTreeMap<_, _> = BTreeMap::new();
            // Outputs that still need to be snapshot, keyed by capture instance.
            let mut capture_instance_to_snapshot: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // All outputs, keyed by capture instance.
            let mut capture_instances: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Statistics handles for each export, keyed by capture instance.
            let mut export_statistics: BTreeMap<_, Vec<_>> = BTreeMap::new();

            for (export_id, output) in outputs.iter() {
                if decoder_map
                    .insert(output.partition_index, Arc::clone(&output.decoder))
                    .is_some()
                {
                    panic!("Multiple decoders for output index {}", output.partition_index);
                }
                capture_instances
                    .entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(output.partition_index);

                // An output needs a snapshot if it has not yet produced any data.
                if *output.resume_upper == [Lsn::minimum()] {
                    capture_instance_to_snapshot
                        .entry(Arc::clone(&output.capture_instance))
                        .or_default()
                        .push((output.partition_index, output.initial_lsn));
                }
                export_statistics
                    .entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(
                        config
                            .statistics
                            .get(export_id)
                            .expect("statistics have been initialized")
                            .clone(),
                    );
            }

            // If any output requires a snapshot, initialize the snapshot statistics
            // before any snapshot work begins.
            if !capture_instance_to_snapshot.is_empty() {
                for stats in config.statistics.values() {
                    stats.set_snapshot_records_known(0);
                    stats.set_snapshot_records_staged(0);
                }
            }
            // A single worker is responsible for reading from SQL Server.
            if !config.responsible_for(REPL_READER) {
                return Ok::<_, TransientError>(());
            }

            let snapshot_instances = capture_instance_to_snapshot.keys().map(|i| i.as_ref());

            let snapshot_tables = mz_sql_server_util::inspect::get_tables_for_capture_instance(
                &mut client,
                snapshot_instances,
            )
            .await?;

            // If the upstream database has been restored from a backup since this
            // source was created, previously committed LSNs are no longer meaningful
            // and replication cannot safely continue.
            let current_restore_history_id = get_latest_restore_history_id(&mut client).await?;
            if current_restore_history_id != source.extras.restore_history_id {
                let definite_error = DefiniteError::RestoreHistoryChanged(
                    source.extras.restore_history_id.clone(),
                    current_restore_history_id.clone(),
                );
                tracing::error!(?definite_error, "Restore detected, exiting replication");

                return_definite_error(
                    definite_error,
                    capture_instances
                        .values()
                        .flat_map(|indexes| indexes.iter().copied()),
                    data_output,
                    data_cap_set,
                    definite_error_handle,
                    definite_error_cap_set,
                )
                .await;
                return Ok(());
            }

            // Record the number of rows we expect each snapshot to produce.
            for table in &snapshot_tables {
                let table_total = mz_sql_server_util::inspect::snapshot_size(
                    &mut client,
                    &table.schema_name,
                    &table.name,
                )
                .await?;
                for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                    export_stat.set_snapshot_records_known(u64::cast_from(table_total));
                    export_stat.set_snapshot_records_staged(0);
                }
            }

            let mut cdc_handle = client
                .cdc(capture_instances.keys().cloned())
                .max_lsn_wait(MAX_LSN_WAIT.get(config.config.config_set()));

            // Snapshot each required table, noting the LSN at which the snapshot was
            // taken so replicated changes can later be reconciled against it.
            let snapshot_lsns: BTreeMap<Arc<str>, Lsn> = {
                // Ensure the upstream source is ready for CDC before starting any
                // snapshots.
                cdc_handle.wait_for_ready().await?;

                tracing::info!(%config.worker_id, "timely-{worker_id} upstream is ready");

                let report_interval =
                    SNAPSHOT_PROGRESS_REPORT_INTERVAL.handle(config.config.config_set());
                let mut last_report = Instant::now();
                let mut snapshot_lsns = BTreeMap::new();
                let arena = RowArena::default();

                for table in snapshot_tables {
                    let (snapshot_lsn, snapshot) = cdc_handle
                        .snapshot(&table, config.worker_id, config.id)
                        .await?;

                    tracing::info!(%config.id, %table.name, %table.schema_name, %snapshot_lsn, "timely-{worker_id} snapshot start");

                    let mut snapshot = std::pin::pin!(snapshot);

                    snapshot_lsns.insert(Arc::clone(&table.capture_instance.name), snapshot_lsn);

                    let partition_indexes = capture_instance_to_snapshot
                        .get(&table.capture_instance.name)
                        .unwrap_or_else(|| {
                            panic!(
                                "no snapshot outputs in known capture instances [{}] for capture instance: '{}'",
                                capture_instance_to_snapshot.keys().join(","),
                                table.capture_instance.name
                            );
                        });

                    let mut snapshot_staged = 0;
                    while let Some(result) = snapshot.next().await {
                        let sql_server_row = result.map_err(TransientError::from)?;

                        // Periodically report how many records we have staged.
                        if last_report.elapsed() > report_interval.get() {
                            last_report = Instant::now();
                            for export_stat in
                                export_statistics.get(&table.capture_instance.name).unwrap()
                            {
                                export_stat.set_snapshot_records_staged(snapshot_staged);
                            }
                        }

                        for (partition_idx, _) in partition_indexes {
                            let mut mz_row = Row::default();
                            let decoder = decoder_map.get(partition_idx).expect("decoder for output");

                            // Decode the row; a failure is surfaced as a decode error on
                            // the output rather than halting the dataflow.
                            let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena) {
                                Ok(()) => Ok(SourceMessage {
                                    key: Row::default(),
                                    value: mz_row,
                                    metadata: Row::default(),
                                }),
                                Err(e) => {
                                    let kind = DecodeErrorKind::Text(e.to_string().into());
                                    let raw = format!("{sql_server_row:?}");
                                    Err(DataflowError::DecodeError(Box::new(DecodeError {
                                        kind,
                                        raw: raw.as_bytes().to_vec(),
                                    })))
                                }
                            };
                            // Snapshot updates are always emitted at the minimum LSN.
                            data_output
                                .give_fueled(
                                    &data_cap_set[0],
                                    ((*partition_idx, message), Lsn::minimum(), Diff::ONE),
                                )
                                .await;
                        }
                        snapshot_staged += 1;
                    }

                    tracing::info!(%config.id, %table.name, %table.schema_name, %snapshot_lsn, "timely-{worker_id} snapshot complete");

                    // The snapshot is complete, so the staged count is now the
                    // authoritative count of records.
                    for export_stat in
                        export_statistics.get(&table.capture_instance.name).unwrap()
                    {
                        export_stat.set_snapshot_records_staged(snapshot_staged);
                        export_stat.set_snapshot_records_known(snapshot_staged);
                    }
                }

                snapshot_lsns
            };

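            // For each output that was just snapshot, track the (initial_lsn,
            // snapshot_lsn) pair used to reconcile replication with the snapshot.
            // Snapshot updates are emitted at the minimum LSN, but the snapshot itself
            // reflects the table as of snapshot_lsn, which can be later than the
            // initial_lsn recorded for the output. Replicated changes at or before
            // initial_lsn are skipped entirely, while changes between initial_lsn and
            // snapshot_lsn are "rewound": retracted at the minimum LSN and re-emitted
            // at their commit LSN, so the collection is consistent at every timestamp.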
            let mut rewinds: BTreeMap<_, _> = capture_instance_to_snapshot
                .iter()
                .flat_map(|(capture_instance, export_ids)| {
                    let snapshot_lsn = snapshot_lsns
                        .get(capture_instance)
                        .expect("snapshot lsn must be collected for capture instance");
                    export_ids
                        .iter()
                        .map(|(idx, initial_lsn)| (*idx, (*initial_lsn, *snapshot_lsn)))
                })
                .collect();

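            // A snapshot is always taken at or after the point where initial_lsn was
            // captured, so initial_lsn should never exceed snapshot_lsn.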
            for (initial_lsn, snapshot_lsn) in rewinds.values() {
                assert!(
                    initial_lsn <= snapshot_lsn,
                    "initial_lsn={initial_lsn} snapshot_lsn={snapshot_lsn}"
                );
            }

            tracing::debug!("rewinds to process: {rewinds:?}");

            capture_instance_to_snapshot.clear();

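            // Replication for a capture instance resumes from the smallest resume LSN
            // across all of its outputs.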
            let mut resume_lsns = BTreeMap::new();
            for src_info in outputs.values() {
                let resume_lsn = match src_info.resume_upper.as_option() {
                    Some(lsn) if *lsn != Lsn::minimum() => *lsn,
                    // An output that was just snapshot resumes from just after its
                    // initial_lsn; the rewinds collected above reconcile any overlap
                    // with the snapshot.
                    Some(_) => src_info.initial_lsn.increment(),
                    None => panic!("resume_upper must contain at least one value"),
                };
                resume_lsns
                    .entry(Arc::clone(&src_info.capture_instance))
                    .and_modify(|existing| *existing = std::cmp::min(*existing, resume_lsn))
                    .or_insert(resume_lsn);
            }

            tracing::info!(%config.id, ?resume_lsns, "timely-{} replication starting", config.worker_id);
            for instance in capture_instances.keys() {
                let resume_lsn = resume_lsns
                    .get(instance)
                    .expect("resume_lsn exists for capture instance");
                cdc_handle = cdc_handle.start_lsn(instance, *resume_lsn);
            }

            let cdc_stream = cdc_handle
                .poll_interval(CDC_POLL_INTERVAL.get(config.config.config_set()))
                .into_stream();
            let mut cdc_stream = std::pin::pin!(cdc_stream);

            // Capture instances that have hit a definite error; their data is ignored
            // for the remainder of this run.
            let mut errored_instances = BTreeSet::new();

            // Log the completion of rewinds only once, rather than on every subsequent
            // progress event.
            let mut log_rewinds_complete = true;
            while let Some(event) = cdc_stream.next().await {
                let event = event.map_err(TransientError::from)?;
                tracing::trace!(?config.id, ?event, "got replication event");

                match event {
                    // The upstream will never produce changes at LSNs before next_lsn,
                    // so our capabilities can be downgraded.
                    CdcEvent::Progress { next_lsn } => {
                        tracing::debug!(?config.id, ?next_lsn, "got a closed lsn");
                        // A rewind is finished once the frontier passes its snapshot_lsn.
                        rewinds.retain(|_, (_, snapshot_lsn)| next_lsn <= *snapshot_lsn);
                        // While rewinds are outstanding, hold the data capability at the
                        // minimum LSN so retractions can still be emitted there.
                        if rewinds.is_empty() {
                            if log_rewinds_complete {
                                tracing::debug!("rewinds complete");
                                log_rewinds_complete = false;
                            }
                            data_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        } else {
                            tracing::debug!("rewinds remaining: {:?}", rewinds);
                        }
                        upper_cap_set.downgrade(Antichain::from_elem(next_lsn));
                    }
                    CdcEvent::Data {
                        capture_instance,
                        lsn,
                        changes,
                    } => {
                        // Ignore data for any instance that has already produced a
                        // definite error.
                        if errored_instances.contains(&capture_instance) {
                            continue;
                        }

                        let Some(partition_indexes) = capture_instances.get(&capture_instance)
                        else {
                            let definite_error = DefiniteError::ProgrammingError(format!(
                                "capture instance didn't exist: '{capture_instance}'"
                            ));
                            return_definite_error(
                                definite_error,
                                capture_instances
                                    .values()
                                    .flat_map(|indexes| indexes.iter().copied()),
                                data_output,
                                data_cap_set,
                                definite_error_handle,
                                definite_error_cap_set,
                            )
                            .await;
                            return Ok(());
                        };

                        handle_data_event(
                            changes,
                            partition_indexes,
                            &decoder_map,
                            lsn,
                            &rewinds,
                            &data_output,
                            data_cap_set,
                        )
                        .await?
                    }
                    CdcEvent::SchemaUpdate {
                        capture_instance,
                        table,
                        ddl_event,
                    } => {
                        if !errored_instances.contains(&capture_instance)
                            && !ddl_event.is_compatible()
                        {
                            let Some(partition_indexes) = capture_instances.get(&capture_instance)
                            else {
                                let definite_error = DefiniteError::ProgrammingError(format!(
                                    "capture instance didn't exist: '{capture_instance}'"
                                ));
                                return_definite_error(
                                    definite_error,
                                    capture_instances
                                        .values()
                                        .flat_map(|indexes| indexes.iter().copied()),
                                    data_output,
                                    data_cap_set,
                                    definite_error_handle,
                                    definite_error_cap_set,
                                )
                                .await;
                                return Ok(());
                            };
                            // An incompatible schema change is a definite error for
                            // every output of the capture instance.
                            let error = DefiniteError::IncompatibleSchemaChange(
                                capture_instance.to_string(),
                                table.to_string(),
                            );
                            for partition_idx in partition_indexes {
                                data_output
                                    .give_fueled(
                                        &data_cap_set[0],
                                        (
                                            (*partition_idx, Err(error.clone().into())),
                                            Lsn::minimum(),
                                            Diff::ONE,
                                        ),
                                    )
                                    .await;
                            }
                            errored_instances.insert(capture_instance);
                        }
                    }
                };
            }
            Err(TransientError::ReplicationEOF)
        }))
    });

    let error_stream = definite_errors.concat(&transient_errors.map(ReplicationError::Transient));

    (
        data_stream.as_collection(),
        upper_stream,
        error_stream,
        button.press_on_drop(),
    )
}

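/// Decodes a batch of CDC operations and emits an update for each output in
/// `partition_indexes`, applying any outstanding rewinds: changes at or before
/// an output's initial LSN are skipped, and changes at or before its snapshot
/// LSN are additionally retracted at the minimum LSN.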
async fn handle_data_event(
    changes: Vec<CdcOperation>,
    partition_indexes: &[u64],
    decoder_map: &BTreeMap<u64, Arc<SqlServerRowDecoder>>,
    commit_lsn: Lsn,
    rewinds: &BTreeMap<u64, (Lsn, Lsn)>,
    data_output: &StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_cap_set: &CapabilitySet<Lsn>,
) -> Result<(), TransientError> {
    for change in changes {
        let (sql_server_row, diff) = match change {
            CdcOperation::Insert(sql_server_row) | CdcOperation::UpdateNew(sql_server_row) => {
                (sql_server_row, Diff::ONE)
            }
            CdcOperation::Delete(sql_server_row) | CdcOperation::UpdateOld(sql_server_row) => {
                (sql_server_row, Diff::MINUS_ONE)
            }
        };

        let mut mz_row = Row::default();
        let arena = RowArena::default();

        for partition_idx in partition_indexes {
            let decoder = decoder_map.get(partition_idx).unwrap();

            // Changes at or before an output's initial LSN are already reflected in
            // its snapshot and must be skipped.
            let rewind = rewinds.get(partition_idx);
            if rewind.is_some_and(|(initial_lsn, _)| commit_lsn <= *initial_lsn) {
                continue;
            }

            let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena) {
                Ok(()) => Ok(SourceMessage {
                    key: Row::default(),
                    value: mz_row.clone(),
                    metadata: Row::default(),
                }),
                Err(e) => {
                    let kind = DecodeErrorKind::Text(e.to_string().into());
                    let raw = format!("{sql_server_row:?}");
                    Err(DataflowError::DecodeError(Box::new(DecodeError {
                        kind,
                        raw: raw.as_bytes().to_vec(),
                    })))
                }
            };

            // Changes at or before the snapshot LSN are already present in the
            // snapshot, so retract them at the minimum LSN before re-emitting them
            // at their commit LSN.
            if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                data_output
                    .give_fueled(
                        &data_cap_set[0],
                        ((*partition_idx, message.clone()), Lsn::minimum(), -diff),
                    )
                    .await;
            }
            data_output
                .give_fueled(
                    &data_cap_set[0],
                    ((*partition_idx, message), commit_lsn, diff),
                )
                .await;
        }
    }
    Ok(())
}

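/// Handle for emitting stacked (columnar) updates of `(D, T, Diff)` triples.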
type StackedAsyncOutputHandle<T, D> = AsyncOutputHandle<
    T,
    AccountedStackBuilder<CapacityContainerBuilder<TimelyStack<(D, T, Diff)>>>,
    Tee<T, TimelyStack<(D, T, Diff)>>,
>;

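/// Emits `err` as a definite error: once for every output on the data stream,
/// at the maximum possible LSN, and once on the dedicated error stream.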
async fn return_definite_error(
    err: DefiniteError,
    outputs: impl Iterator<Item = u64>,
    data_handle: StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_capset: &CapabilitySet<Lsn>,
    errs_handle: AsyncOutputHandle<
        Lsn,
        CapacityContainerBuilder<Vec<ReplicationError>>,
        Tee<Lsn, Vec<ReplicationError>>,
    >,
    errs_capset: &CapabilitySet<Lsn>,
) {
    for output_idx in outputs {
        let update = (
            (output_idx, Err(err.clone().into())),
            Lsn {
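                // Emit the definite error at the maximum possible LSN so it sorts
                // beyond any data previously emitted by this source.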
                vlf_id: u32::MAX,
                block_id: u32::MAX,
                record_id: u16::MAX,
            },
            Diff::ONE,
        );
        data_handle.give_fueled(&data_capset[0], update).await;
    }
    errs_handle.give(
        &errs_capset[0],
        ReplicationError::DefiniteError(Rc::new(err)),
    );
}