use std::collections::{BTreeMap, BTreeSet};
use std::convert::Infallible;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Instant;

use differential_dataflow::AsCollection;
use differential_dataflow::containers::TimelyStack;
use futures::StreamExt;
use itertools::Itertools;
use mz_ore::cast::CastFrom;
use mz_ore::future::InTask;
use mz_repr::{Diff, GlobalId, Row, RowArena};
use mz_sql_server_util::cdc::{CdcEvent, Lsn, Operation as CdcOperation};
use mz_sql_server_util::desc::SqlServerRowDecoder;
use mz_sql_server_util::inspect::get_latest_restore_history_id;
use mz_storage_types::errors::{DataflowError, DecodeError, DecodeErrorKind};
use mz_storage_types::sources::SqlServerSourceConnection;
use mz_storage_types::sources::sql_server::{
    CDC_POLL_INTERVAL, MAX_LSN_WAIT, SNAPSHOT_PROGRESS_REPORT_INTERVAL,
};
use mz_timely_util::builder_async::{
    AsyncOutputHandle, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use mz_timely_util::containers::stack::AccountedStackBuilder;
use timely::container::CapacityContainerBuilder;
use timely::dataflow::channels::pushers::Tee;
use timely::dataflow::operators::{CapabilitySet, Concat, Map};
use timely::dataflow::{Scope, Stream as TimelyStream};
use timely::progress::{Antichain, Timestamp};

use crate::source::RawSourceCreationConfig;
use crate::source::sql_server::{
    DefiniteError, ReplicationError, SourceOutputInfo, TransientError,
};
use crate::source::types::{SignaledFuture, SourceMessage, StackedCollection};
static REPL_READER: &str = "reader";
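
/// Render the replication operator for a SQL Server source.
///
/// Returns the collection of replicated updates, a stream used to drive the
/// source's upper frontier, a stream of replication errors, and a button that
/// shuts down the operator when dropped.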
pub(crate) fn render<G: Scope<Timestamp = Lsn>>(
    scope: G,
    config: RawSourceCreationConfig,
    outputs: BTreeMap<GlobalId, SourceOutputInfo>,
    source: SqlServerSourceConnection,
) -> (
    StackedCollection<G, (u64, Result<SourceMessage, DataflowError>)>,
    TimelyStream<G, Infallible>,
    TimelyStream<G, ReplicationError>,
    PressOnDropButton,
) {
    let op_name = format!("SqlServerReplicationReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope);

    let (data_output, data_stream) = builder.new_output::<AccountedStackBuilder<_>>();
    let (_upper_output, upper_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
            let [
                data_cap_set,
                upper_cap_set,
                definite_error_cap_set,
            ]: &mut [_; 3] = caps.try_into().unwrap();

            let connection_config = source
                .connection
                .resolve_config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;
            let mut client = mz_sql_server_util::Client::connect(connection_config).await?;

            let worker_id = config.worker_id;

            // Decoder for the rows of each output, keyed by partition index.
            let mut decoder_map: BTreeMap<_, _> = BTreeMap::new();
            // Outputs that still need to be snapshot, keyed by capture instance.
            let mut capture_instance_to_snapshot: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Partition indexes fed by each capture instance.
            let mut capture_instances: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Statistics handles for the exports backed by each capture instance.
            let mut export_statistics: BTreeMap<_, Vec<_>> = BTreeMap::new();

            for (export_id, output) in outputs.iter() {
                if decoder_map
                    .insert(output.partition_index, Arc::clone(&output.decoder))
                    .is_some()
                {
                    panic!("Multiple decoders for output index {}", output.partition_index);
                }
                capture_instances
                    .entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(output.partition_index);

                // An output whose resume upper is still the minimum LSN has never
                // committed any data, so it must first be snapshot.
                if *output.resume_upper == [Lsn::minimum()] {
                    capture_instance_to_snapshot
                        .entry(Arc::clone(&output.capture_instance))
                        .or_default()
                        .push((output.partition_index, output.initial_lsn));
                }
                export_statistics
                    .entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(
                        config
                            .statistics
                            .get(export_id)
                            .expect("statistics have been initialized")
                            .clone(),
                    );
            }

            // Reset snapshot statistics on every worker so stale values from a
            // previous run aren't reported; the reading worker updates them below.
            if !capture_instance_to_snapshot.is_empty() {
                for stats in config.statistics.values() {
                    stats.set_snapshot_records_known(0);
                    stats.set_snapshot_records_staged(0);
                }
            }
            // Only a single worker reads from the replication stream.
            if !config.responsible_for(REPL_READER) {
                return Ok::<_, TransientError>(());
            }

            let snapshot_instances = capture_instance_to_snapshot
                .keys()
                .map(|i| i.as_ref());

            let snapshot_tables = mz_sql_server_util::inspect::get_tables_for_capture_instance(
                &mut client,
                snapshot_instances,
            )
            .await?;

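            // If the upstream database has been restored from a backup since this
            // source was created, previously committed LSNs are no longer valid.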
            let current_restore_history_id = get_latest_restore_history_id(&mut client).await?;
            if current_restore_history_id != source.extras.restore_history_id {
                let definite_error = DefiniteError::RestoreHistoryChanged(
                    source.extras.restore_history_id.clone(),
                    current_restore_history_id.clone(),
                );
                tracing::warn!(?definite_error, "restore detected, exiting");

                return_definite_error(
                    definite_error,
                    capture_instances
                        .values()
                        .flat_map(|indexes| indexes.iter().copied()),
                    data_output,
                    data_cap_set,
                    definite_error_handle,
                    definite_error_cap_set,
                )
                .await;
                return Ok(());
            }

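            // Report the expected size of each snapshot before staging any data.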
            for table in &snapshot_tables {
                let table_total = mz_sql_server_util::inspect::snapshot_size(
                    &mut client,
                    &table.schema_name,
                    &table.name,
                )
                .await?;
                for export_stat in export_statistics
                    .get(&table.capture_instance.name)
                    .unwrap()
                {
                    export_stat.set_snapshot_records_known(u64::cast_from(table_total));
                    export_stat.set_snapshot_records_staged(0);
                }
            }

            let mut cdc_handle = client
                .cdc(capture_instances.keys().cloned())
                .max_lsn_wait(MAX_LSN_WAIT.get(config.config.config_set()));

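            // Snapshot any tables that need it, recording the LSN each snapshot was
            // taken at so the replication stream can later be rewound against it.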
            let snapshot_lsns: BTreeMap<Arc<str>, Lsn> = {
                // Wait until the upstream is ready to serve CDC data before
                // starting any snapshots.
                cdc_handle.wait_for_ready().await?;

                tracing::info!(%config.worker_id, "timely-{worker_id} upstream is ready");

                let report_interval =
                    SNAPSHOT_PROGRESS_REPORT_INTERVAL.handle(config.config.config_set());
                let mut last_report = Instant::now();
                let mut snapshot_lsns = BTreeMap::new();
                let arena = RowArena::default();

                for table in snapshot_tables {
                    let (snapshot_lsn, snapshot) = cdc_handle
                        .snapshot(&table, config.worker_id, config.id)
                        .await?;

                    tracing::info!(%config.id, %table.name, %table.schema_name, %snapshot_lsn, "timely-{worker_id} snapshot start");

                    let mut snapshot = std::pin::pin!(snapshot);

                    snapshot_lsns.insert(Arc::clone(&table.capture_instance.name), snapshot_lsn);

                    let partition_indexes = capture_instance_to_snapshot
                        .get(&table.capture_instance.name)
                        .unwrap_or_else(|| {
                            panic!(
                                "no snapshot outputs in known capture instances [{}] for capture instance: '{}'",
                                capture_instance_to_snapshot.keys().join(","),
                                table.capture_instance.name
                            );
                        });

                    let mut snapshot_staged = 0;
                    while let Some(result) = snapshot.next().await {
                        let sql_server_row = result.map_err(TransientError::from)?;

                        // Periodically report how many records we've staged.
                        if last_report.elapsed() > report_interval.get() {
                            last_report = Instant::now();
                            for export_stat in export_statistics
                                .get(&table.capture_instance.name)
                                .unwrap()
                            {
                                export_stat.set_snapshot_records_staged(snapshot_staged);
                            }
                        }

                        // Decode the row for each output fed by this capture instance.
                        for (partition_idx, _) in partition_indexes {
                            let mut mz_row = Row::default();

                            let decoder =
                                decoder_map.get(partition_idx).expect("decoder for output");
                            let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena)
                            {
                                Ok(()) => Ok(SourceMessage {
                                    key: Row::default(),
                                    value: mz_row,
                                    metadata: Row::default(),
                                }),
                                Err(e) => {
                                    let kind = DecodeErrorKind::Text(e.to_string().into());
                                    let raw = format!("{sql_server_row:?}");
                                    Err(DataflowError::DecodeError(Box::new(DecodeError {
                                        kind,
                                        raw: raw.as_bytes().to_vec(),
                                    })))
                                }
                            };
                            // Snapshot updates are emitted at the minimum LSN.
                            data_output
                                .give_fueled(
                                    &data_cap_set[0],
                                    ((*partition_idx, message), Lsn::minimum(), Diff::ONE),
                                )
                                .await;
                        }
                        snapshot_staged += 1;
                    }

                    tracing::info!(%config.id, %table.name, %table.schema_name, %snapshot_lsn, "timely-{worker_id} snapshot complete");

                    // The snapshot is complete, so we now know the exact record
                    // count; correct the earlier estimate.
                    for export_stat in export_statistics
                        .get(&table.capture_instance.name)
                        .unwrap()
                    {
                        export_stat.set_snapshot_records_staged(snapshot_staged);
                        export_stat.set_snapshot_records_known(snapshot_staged);
                    }
                }

                snapshot_lsns
            };

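            // For every output that was just snapshot, record its initial LSN and the
            // LSN its snapshot ran at. Replicated changes at or before the snapshot
            // LSN are already reflected in the snapshot, so they must be rewound to
            // keep the output consistent.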
            let mut rewinds: BTreeMap<_, _> = capture_instance_to_snapshot
                .iter()
                .flat_map(|(capture_instance, snapshot_outputs)| {
                    let snapshot_lsn = snapshot_lsns
                        .get(capture_instance)
                        .expect("snapshot lsn must be collected for capture instance");
                    snapshot_outputs
                        .iter()
                        .map(|(idx, initial_lsn)| (*idx, (*initial_lsn, *snapshot_lsn)))
                })
                .collect();

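            // initial_lsn is recorded before the snapshot is taken, so it can never
            // exceed the LSN the snapshot ran at.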
            for (initial_lsn, snapshot_lsn) in rewinds.values() {
                assert!(
                    initial_lsn <= snapshot_lsn,
                    "initial_lsn={initial_lsn} snapshot_lsn={snapshot_lsn}"
                );
            }

            tracing::debug!("rewinds to process: {rewinds:?}");

            capture_instance_to_snapshot.clear();

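            // Determine the LSN to resume replication from for each capture
            // instance, which is the minimum resume LSN across its outputs.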
            let mut resume_lsns = BTreeMap::new();
            for src_info in outputs.values() {
                let resume_lsn = match src_info.resume_upper.as_option() {
                    Some(lsn) if *lsn != Lsn::minimum() => *lsn,
                    // An output still at the minimum upper was just snapshot at
                    // initial_lsn, so replication resumes at the next LSN.
                    Some(_) => src_info.initial_lsn.increment(),
                    None => panic!("resume_upper must have at least one value"),
                };
                resume_lsns
                    .entry(Arc::clone(&src_info.capture_instance))
                    .and_modify(|existing| *existing = std::cmp::min(*existing, resume_lsn))
                    .or_insert(resume_lsn);
            }

            tracing::info!(%config.id, ?resume_lsns, "timely-{} replication starting", config.worker_id);
            for instance in capture_instances.keys() {
                let resume_lsn = resume_lsns
                    .get(instance)
                    .expect("resume_lsn exists for capture instance");
                cdc_handle = cdc_handle.start_lsn(instance, *resume_lsn);
            }

            // Begin replicating data from SQL Server.
            let cdc_stream = cdc_handle
                .poll_interval(CDC_POLL_INTERVAL.get(config.config.config_set()))
                .into_stream();
            let mut cdc_stream = std::pin::pin!(cdc_stream);

            // Capture instances that have hit an unrecoverable error.
            let mut errored_instances = BTreeSet::new();

            // Ensures "rewinds complete" is only logged once.
            let mut log_rewinds_complete = true;
            while let Some(event) = cdc_stream.next().await {
                let event = event.map_err(TransientError::from)?;
                tracing::trace!(?config.id, ?event, "got replication event");

                match event {
                    // The stream is closed up to next_lsn, so it is safe to
                    // downgrade our frontiers.
                    CdcEvent::Progress { next_lsn } => {
                        tracing::debug!(?config.id, ?next_lsn, "got a closed lsn");
                        // A rewind is complete once the frontier passes its
                        // snapshot LSN.
                        rewinds.retain(|_, (_, snapshot_lsn)| next_lsn <= *snapshot_lsn);
                        // Rewinds emit updates at the minimum LSN, so the data
                        // frontier can only advance once they are all complete.
                        if rewinds.is_empty() {
                            if log_rewinds_complete {
                                tracing::debug!("rewinds complete");
                                log_rewinds_complete = false;
                            }
                            data_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        } else {
                            tracing::debug!("rewinds remaining: {:?}", rewinds);
                        }
                        upper_cap_set.downgrade(Antichain::from_elem(next_lsn));
                    }
                    CdcEvent::Data {
                        capture_instance,
                        lsn,
                        changes,
                    } => {
                        // This capture instance has already hit an unrecoverable
                        // error, so drop any further data for it.
                        if errored_instances.contains(&capture_instance) {
                            continue;
                        }

                        let Some(partition_indexes) = capture_instances.get(&capture_instance)
                        else {
                            let definite_error = DefiniteError::ProgrammingError(format!(
                                "capture instance didn't exist: '{capture_instance}'"
                            ));
                            return_definite_error(
                                definite_error,
                                capture_instances
                                    .values()
                                    .flat_map(|indexes| indexes.iter().copied()),
                                data_output,
                                data_cap_set,
                                definite_error_handle,
                                definite_error_cap_set,
                            )
                            .await;
                            return Ok(());
                        };

                        handle_data_event(
                            changes,
                            partition_indexes,
                            &decoder_map,
                            lsn,
                            &rewinds,
                            &data_output,
                            data_cap_set,
                        )
                        .await?
                    }
                    CdcEvent::SchemaUpdate {
                        capture_instance,
                        table,
                        ddl_event,
                    } => {
                        if !errored_instances.contains(&capture_instance)
                            && !ddl_event.is_compatible()
                        {
                            let Some(partition_indexes) = capture_instances.get(&capture_instance)
                            else {
                                let definite_error = DefiniteError::ProgrammingError(format!(
                                    "capture instance didn't exist: '{capture_instance}'"
                                ));
                                return_definite_error(
                                    definite_error,
                                    capture_instances
                                        .values()
                                        .flat_map(|indexes| indexes.iter().copied()),
                                    data_output,
                                    data_cap_set,
                                    definite_error_handle,
                                    definite_error_cap_set,
                                )
                                .await;
                                return Ok(());
                            };
                            // An incompatible schema change is a definite error for
                            // every output of this capture instance, but other
                            // instances can keep replicating.
                            let error = DefiniteError::IncompatibleSchemaChange(
                                capture_instance.to_string(),
                                table.to_string(),
                            );
                            for partition_idx in partition_indexes {
                                data_output
                                    .give_fueled(
                                        &data_cap_set[0],
                                        (
                                            (*partition_idx, Err(error.clone().into())),
                                            ddl_event.lsn,
                                            Diff::ONE,
                                        ),
                                    )
                                    .await;
                            }
                            errored_instances.insert(capture_instance);
                        }
                    }
                }
            }
            // The replication stream should never end on its own.
            Err(TransientError::ReplicationEOF)
        }))
    });

    let error_stream = definite_errors.concat(&transient_errors.map(ReplicationError::Transient));

    (
        data_stream.as_collection(),
        upper_stream,
        error_stream,
        button.press_on_drop(),
    )
}

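/// Decode and emit a batch of replicated changes to every output fed by the
/// given capture instance, applying any outstanding rewinds so that changes
/// already reflected in a snapshot are not double counted.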
async fn handle_data_event(
    changes: Vec<CdcOperation>,
    partition_indexes: &[u64],
    decoder_map: &BTreeMap<u64, Arc<SqlServerRowDecoder>>,
    commit_lsn: Lsn,
    rewinds: &BTreeMap<u64, (Lsn, Lsn)>,
    data_output: &StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_cap_set: &CapabilitySet<Lsn>,
) -> Result<(), TransientError> {
    for change in changes {
        let (sql_server_row, diff) = match change {
            CdcOperation::Insert(sql_server_row) | CdcOperation::UpdateNew(sql_server_row) => {
                (sql_server_row, Diff::ONE)
            }
            CdcOperation::Delete(sql_server_row) | CdcOperation::UpdateOld(sql_server_row) => {
                (sql_server_row, Diff::MINUS_ONE)
            }
        };

        let mut mz_row = Row::default();
        let arena = RowArena::default();

        for partition_idx in partition_indexes {
            let decoder = decoder_map.get(partition_idx).unwrap();

            // Changes at or before initial_lsn are already present in the snapshot
            // and must not be applied a second time.
            let rewind = rewinds.get(partition_idx);
            if rewind.is_some_and(|(initial_lsn, _)| commit_lsn <= *initial_lsn) {
                continue;
            }

            let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena) {
                Ok(()) => Ok(SourceMessage {
                    key: Row::default(),
                    value: mz_row.clone(),
                    metadata: Row::default(),
                }),
                Err(e) => {
                    let kind = DecodeErrorKind::Text(e.to_string().into());
                    let raw = format!("{sql_server_row:?}");
                    Err(DataflowError::DecodeError(Box::new(DecodeError {
                        kind,
                        raw: raw.as_bytes().to_vec(),
                    })))
                }
            };

            // Changes at or before snapshot_lsn are also reflected in the snapshot,
            // so emit a compensating retraction at the minimum LSN; combined with
            // the emission below, the change takes effect exactly once, at its
            // commit LSN.
            if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                data_output
                    .give_fueled(
                        &data_cap_set[0],
                        ((*partition_idx, message.clone()), Lsn::minimum(), -diff),
                    )
                    .await;
            }
            data_output
                .give_fueled(
                    &data_cap_set[0],
                    ((*partition_idx, message), commit_lsn, diff),
                )
                .await;
        }
    }
    Ok(())
}

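/// Alias for the async output handle used to emit stacked (columnar) batches
/// of updates.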
type StackedAsyncOutputHandle<T, D> = AsyncOutputHandle<
    T,
    AccountedStackBuilder<CapacityContainerBuilder<TimelyStack<(D, T, Diff)>>>,
    Tee<T, TimelyStack<(D, T, Diff)>>,
>;

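/// Emit `err` as a definite error for every output and to the error stream.
///
/// The data updates are stamped with the maximum possible LSN so they sort
/// after any other update the source emits.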
async fn return_definite_error(
    err: DefiniteError,
    outputs: impl Iterator<Item = u64>,
    data_handle: StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_capset: &CapabilitySet<Lsn>,
    errs_handle: AsyncOutputHandle<
        Lsn,
        CapacityContainerBuilder<Vec<ReplicationError>>,
        Tee<Lsn, Vec<ReplicationError>>,
    >,
    errs_capset: &CapabilitySet<Lsn>,
) {
    for output_idx in outputs {
        let update = (
            (output_idx, Err(err.clone().into())),
            // The maximum possible LSN, so the error sorts after all data.
            Lsn {
                vlf_id: u32::MAX,
                block_id: u32::MAX,
                record_id: u16::MAX,
            },
            Diff::ONE,
        );
        data_handle.give_fueled(&data_capset[0], update).await;
    }
    errs_handle.give(
        &errs_capset[0],
        ReplicationError::DefiniteError(Rc::new(err)),
    );
}