mz_storage/source/sql_server/replication.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Code to render the ingestion dataflow of a [`SqlServerSourceConnection`].
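//!
//! The rendered operator produces three streams: decoded data updates tagged with a
//! per-output partition index, a progress ("upper") stream, and a stream of
//! [`ReplicationError`]s; see [`render`].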

use std::collections::{BTreeMap, BTreeSet};
use std::convert::Infallible;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Instant;

use differential_dataflow::AsCollection;
use differential_dataflow::containers::TimelyStack;
use futures::StreamExt;
use itertools::Itertools;
use mz_ore::cast::CastFrom;
use mz_ore::collections::HashMap;
use mz_ore::future::InTask;
use mz_repr::{Diff, GlobalId, Row, RowArena};
use mz_sql_server_util::SqlServerCdcMetrics;
use mz_sql_server_util::cdc::{CdcEvent, Lsn, Operation as CdcOperation};
use mz_sql_server_util::desc::SqlServerRowDecoder;
use mz_sql_server_util::inspect::{
    ensure_database_cdc_enabled, ensure_sql_server_agent_running, get_latest_restore_history_id,
};
use mz_storage_types::dyncfgs::SQL_SERVER_SOURCE_VALIDATE_RESTORE_HISTORY;
use mz_storage_types::errors::{DataflowError, DecodeError, DecodeErrorKind};
use mz_storage_types::sources::SqlServerSourceConnection;
use mz_storage_types::sources::sql_server::{
    CDC_POLL_INTERVAL, MAX_LSN_WAIT, SNAPSHOT_PROGRESS_REPORT_INTERVAL,
};
use mz_timely_util::builder_async::{
    AsyncOutputHandle, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use mz_timely_util::containers::stack::AccountedStackBuilder;
use timely::container::CapacityContainerBuilder;
use timely::dataflow::operators::{CapabilitySet, Concat, Map};
use timely::dataflow::{Scope, Stream as TimelyStream};
use timely::progress::{Antichain, Timestamp};

use crate::metrics::source::sql_server::SqlServerSourceMetrics;
use crate::source::RawSourceCreationConfig;
use crate::source::sql_server::{
    DefiniteError, ReplicationError, SourceOutputInfo, TransientError,
};
use crate::source::types::{SignaledFuture, SourceMessage, StackedCollection};

/// Used as a partition ID to determine the worker that is responsible for
/// reading data from SQL Server.
///
/// TODO(sql_server2): It's possible we could have different workers
/// replicate different tables, if we're using SQL Server's CDC features.
static REPL_READER: &str = "reader";

pub(crate) fn render<G: Scope<Timestamp = Lsn>>(
    scope: G,
    config: RawSourceCreationConfig,
    outputs: BTreeMap<GlobalId, SourceOutputInfo>,
    source: SqlServerSourceConnection,
    metrics: SqlServerSourceMetrics,
) -> (
    StackedCollection<G, (u64, Result<SourceMessage, DataflowError>)>,
    TimelyStream<G, Infallible>,
    TimelyStream<G, ReplicationError>,
    PressOnDropButton,
) {
    let op_name = format!("SqlServerReplicationReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope);

    let (data_output, data_stream) = builder.new_output::<AccountedStackBuilder<_>>();
    let (_upper_output, upper_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

    // Captures DefiniteErrors that affect the entire source, including all outputs.
    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
            let [
                data_cap_set,
                upper_cap_set,
                definite_error_cap_set,
            ]: &mut [_; 3] = caps.try_into().unwrap();

            let connection_config = source
                .connection
                .resolve_config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;
            let mut client = mz_sql_server_util::Client::connect(connection_config).await?;

            let worker_id = config.worker_id;

            // The decoder is specific to the export, and each export pulls data from a
            // specific capture instance.
            let mut decoder_map: BTreeMap<_, _> = BTreeMap::new();
            // Maps the 'capture instance' to the output indexes of only those outputs that
            // this worker will snapshot.
            let mut capture_instance_to_snapshot: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Maps the 'capture instance' to the output indexes of all outputs of this worker.
            let mut capture_instances: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Export statistics for a given capture instance.
            let mut export_statistics: BTreeMap<_, Vec<_>> = BTreeMap::new();
            // Maps each output index to its included columns so we can check whether schema
            // updates are valid on a per-output basis.
            let mut included_columns: HashMap<u64, Vec<Arc<str>>> = HashMap::new();

            for (export_id, output) in outputs.iter() {
                if decoder_map.insert(output.partition_index, Arc::clone(&output.decoder)).is_some() {
                    panic!("Multiple decoders for output index {}", output.partition_index);
                }
                // Collect the included columns from the decoder for schema change validation.
                // The decoder serves as an effective source of truth for which columns are
                // "included", as we only care about the columns that are being decoded and
                // replicated.
                let included_cols = output.decoder.included_column_names();
                included_columns.insert(output.partition_index, included_cols);

                capture_instances
                    .entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(output.partition_index);

                if *output.resume_upper == [Lsn::minimum()] {
                    capture_instance_to_snapshot
                        .entry(Arc::clone(&output.capture_instance))
                        .or_default()
                        .push((output.partition_index, output.initial_lsn));
                }
                export_statistics.entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(
                        config
                            .statistics
                            .get(export_id)
                            .expect("statistics have been initialized")
                            .clone(),
                    );
            }

            // Eagerly emit an event if we have tables to snapshot.
            // A worker *must* emit a count even if it is not responsible for snapshotting a
            // table, as statistics summarization will return null if any worker hasn't set a
            // value. This will also reset snapshot stats for any exports not snapshotting.
            metrics.snapshot_table_count.set(u64::cast_from(capture_instance_to_snapshot.len()));
            if !capture_instance_to_snapshot.is_empty() {
                for stats in config.statistics.values() {
                    stats.set_snapshot_records_known(0);
                    stats.set_snapshot_records_staged(0);
                }
            }
            // We need to emit statistics before we exit.
            // TODO(sql_server2): Run ingestions across multiple workers.
            if !config.responsible_for(REPL_READER) {
                return Ok::<_, TransientError>(());
            }

            let snapshot_instances = capture_instance_to_snapshot
                .keys()
                .map(|i| i.as_ref());

            // TODO (maz): we can avoid this query by using SourceOutputInfo.
            let snapshot_tables = mz_sql_server_util::inspect::get_tables_for_capture_instance(
                &mut client,
                snapshot_instances,
            )
            .await?;

            // Validate that the restore_history_id hasn't changed.
            let current_restore_history_id = get_latest_restore_history_id(&mut client).await?;
            if current_restore_history_id != source.extras.restore_history_id {
                if SQL_SERVER_SOURCE_VALIDATE_RESTORE_HISTORY.get(config.config.config_set()) {
                    let definite_error = DefiniteError::RestoreHistoryChanged(
                        source.extras.restore_history_id.clone(),
                        current_restore_history_id.clone(),
                    );
                    tracing::warn!(?definite_error, "Restore detected, exiting");

                    return_definite_error(
                        definite_error,
                        capture_instances.values().flat_map(|indexes| indexes.iter().copied()),
                        data_output,
                        data_cap_set,
                        definite_error_handle,
                        definite_error_cap_set,
                    )
                    .await;
                    return Ok(());
                } else {
                    tracing::warn!(
                        "Restore history mismatch ignored: expected={expected:?} actual={actual:?}",
                        expected = source.extras.restore_history_id,
                        actual = current_restore_history_id
                    );
                }
            }

            // For AOAG, it's possible that the dataflow restarted and is now connected to a
            // different SQL Server, which may not have CDC enabled correctly.
            ensure_database_cdc_enabled(&mut client).await?;
            ensure_sql_server_agent_running(&mut client).await?;

            // We first calculate the total number of rows we need to fetch across all tables.
            // Since this happens outside the snapshot transaction, the totals might be off, so
            // we won't assert that we get exactly this many rows later.
            for table in &snapshot_tables {
                let qualified_table_name = format!("{schema_name}.{table_name}",
                    schema_name = &table.schema_name,
                    table_name = &table.name);
                let size_calc_start = Instant::now();
                let table_total = mz_sql_server_util::inspect::snapshot_size(
                    &mut client,
                    &table.schema_name,
                    &table.name,
                )
                .await?;
                metrics.set_snapshot_table_size_latency(
                    &qualified_table_name,
                    size_calc_start.elapsed().as_secs_f64(),
                );
                for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                    export_stat.set_snapshot_records_known(u64::cast_from(table_total));
                    export_stat.set_snapshot_records_staged(0);
                }
            }
            let cdc_metrics = PrometheusSqlServerCdcMetrics { inner: &metrics };
            let mut cdc_handle = client
                .cdc(capture_instances.keys().cloned(), cdc_metrics)
                .max_lsn_wait(MAX_LSN_WAIT.get(config.config.config_set()));

            // Snapshot any instance that requires it.
            // Each table snapshot will have its own LSN captured at the moment of snapshotting.
            let snapshot_lsns: BTreeMap<Arc<str>, Lsn> = {
                // Before starting a transaction where the LSN will not advance, ensure
                // the upstream DB is ready for CDC.
                cdc_handle.wait_for_ready().await?;

                // Intentionally logging this at info for debugging. This section won't get
                // entered often, but if there are problems here, it will be much easier to
                // troubleshoot knowing where a stall/hang might be happening.
                tracing::info!(%config.worker_id, "timely-{worker_id} upstream is ready");

                let report_interval =
                    SNAPSHOT_PROGRESS_REPORT_INTERVAL.handle(config.config.config_set());
                let mut last_report = Instant::now();
                let mut snapshot_lsns = BTreeMap::new();
                let arena = RowArena::default();

                for table in snapshot_tables {
                    // TODO(sql_server3): filter columns to only select columns required for Source.
                    let (snapshot_lsn, snapshot) = cdc_handle
                        .snapshot(&table, config.worker_id, config.id)
                        .await?;

                    tracing::info!(%config.id, %table.name, %table.schema_name, %snapshot_lsn, "timely-{worker_id} snapshot start");

                    let mut snapshot = std::pin::pin!(snapshot);

                    snapshot_lsns.insert(Arc::clone(&table.capture_instance.name), snapshot_lsn);

                    let partition_indexes = capture_instance_to_snapshot
                        .get(&table.capture_instance.name)
                        .unwrap_or_else(|| {
                            panic!(
                                "no snapshot outputs in known capture instances [{}] for capture instance: '{}'",
                                capture_instance_to_snapshot.keys().join(","),
                                table.capture_instance.name
                            );
                        });

                    let mut snapshot_staged = 0;
                    while let Some(result) = snapshot.next().await {
                        let sql_server_row = result.map_err(TransientError::from)?;

                        if last_report.elapsed() > report_interval.get() {
                            last_report = Instant::now();
                            for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                                export_stat.set_snapshot_records_staged(snapshot_staged);
                            }
                        }

                        for (partition_idx, _) in partition_indexes {
                            // Decode the SQL Server row into an MZ one.
                            let mut mz_row = Row::default();

                            let decoder = decoder_map.get(partition_idx).expect("decoder for output");
                            // Try to decode a row, returning a SourceError if it fails.
                            let message = decode(decoder, &sql_server_row, &mut mz_row, &arena, None);
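                            // Snapshot rows are emitted at `Lsn::minimum()`; if the same row
                            // changed between `initial_lsn` and `snapshot_lsn`, the rewind
                            // handling below retracts/re-emits it so that every replica sees
                            // a consistent history.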
                            data_output
                                .give_fueled(
                                    &data_cap_set[0],
                                    ((*partition_idx, message), Lsn::minimum(), Diff::ONE),
                                )
                                .await;
                        }
                        snapshot_staged += 1;
                    }

                    tracing::info!(%config.id, %table.name, %table.schema_name, %snapshot_lsn, "timely-{worker_id} snapshot complete");
                    metrics.snapshot_table_count.dec();
                    // Final update for snapshot_staged; use the staged count as the known total,
                    // since the earlier value was only an estimate.
                    for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                        export_stat.set_snapshot_records_staged(snapshot_staged);
                        export_stat.set_snapshot_records_known(snapshot_staged);
                    }
                }

                snapshot_lsns
            };

            // Rewinds need to keep track of 2 timestamps to ensure that
            // all replicas emit the same set of updates for any given timestamp.
            // These are the initial_lsn and snapshot_lsn, where initial_lsn must be
            // less than or equal to snapshot_lsn.
            //
            // - events at an LSN less than or equal to initial_lsn are ignored
            // - events at an LSN greater than initial_lsn and less than or equal to
            //   snapshot_lsn are retracted at Lsn::minimum(), and emitted at the commit_lsn
            // - events at an LSN greater than snapshot_lsn are emitted at the commit_lsn
            //
            // where the commit_lsn is the upstream LSN that the event was committed at.
            //
            // If initial_lsn == snapshot_lsn, all CDC events at LSNs up to and including the
            // snapshot_lsn are ignored, and no rewinds are issued.
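            //
            // For example (illustrative values): with initial_lsn = 10 and snapshot_lsn = 12,
            // an event committed at LSN 9 is ignored, an event committed at LSN 11 is
            // retracted at Lsn::minimum() and emitted at LSN 11, and an event committed at
            // LSN 13 is only emitted at LSN 13.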
            let mut rewinds: BTreeMap<_, _> = capture_instance_to_snapshot
                .iter()
                .flat_map(|(capture_instance, export_ids)| {
                    let snapshot_lsn = snapshot_lsns
                        .get(capture_instance)
                        .expect("snapshot lsn must be collected for capture instance");
                    export_ids
                        .iter()
                        .map(|(idx, initial_lsn)| (*idx, (*initial_lsn, *snapshot_lsn)))
                })
                .collect();

            // For now, we assert that the initial_lsn captured during purification is less
            // than or equal to snapshot_lsn. If that were not true, it would mean that
            // we observed a SQL Server DB that appeared to go back in time.
            // TODO (maz): not ideal to do this after snapshot, move this into
            // CdcStream::snapshot after https://github.com/MaterializeInc/materialize/pull/32979 is merged.
            for (initial_lsn, snapshot_lsn) in rewinds.values() {
                assert!(
                    initial_lsn <= snapshot_lsn,
                    "initial_lsn={initial_lsn} snapshot_lsn={snapshot_lsn}"
                );
            }

            tracing::debug!("rewinds to process: {rewinds:?}");

            capture_instance_to_snapshot.clear();

            // The resumption point is the minimum resume LSN observed across the outputs of
            // each capture instance.
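            // If, say, two outputs share a capture instance but have resume LSNs of 12 and 15
            // (illustrative values), replication restarts from 12 so that the less-advanced
            // output does not miss any changes.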
            let mut resume_lsns = BTreeMap::new();
            for src_info in outputs.values() {
                let resume_lsn = match src_info.resume_upper.as_option() {
                    Some(lsn) if *lsn != Lsn::minimum() => *lsn,
                    // initial_lsn is the max LSN observed, but the resume LSN is the next LSN
                    // that should be read. After a snapshot, initial_lsn has been read, so
                    // replication will start at the next available LSN.
                    Some(_) => src_info.initial_lsn.increment(),
                    None => panic!("resume_upper has at least one value"),
                };
                resume_lsns.entry(Arc::clone(&src_info.capture_instance))
                    .and_modify(|existing| *existing = std::cmp::min(*existing, resume_lsn))
                    .or_insert(resume_lsn);
            }

            tracing::info!(%config.id, ?resume_lsns, "timely-{} replication starting", config.worker_id);
            for instance in capture_instances.keys() {
                let resume_lsn = resume_lsns
                    .get(instance)
                    .expect("resume_lsn exists for capture instance");
                cdc_handle = cdc_handle.start_lsn(instance, *resume_lsn);
            }

            // Off to the races! Replicate data from SQL Server.
            let cdc_stream = cdc_handle
                .poll_interval(CDC_POLL_INTERVAL.get(config.config.config_set()))
                .into_stream();
            let mut cdc_stream = std::pin::pin!(cdc_stream);

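            // Partition indexes that have hit an incompatible schema change; subsequent changes
            // for them are counted as ignored rather than emitted.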
            let mut errored_partitions = BTreeSet::new();

            // TODO(sql_server2): We should emit `ProgressStatisticsUpdate::SteadyState` messages
            // here, when we receive progress events. What stops us from doing this now is that
            // our 10-byte LSN doesn't fit into the 8-byte integer that the progress event uses.
            let mut log_rewinds_complete = true;

            // deferred_updates temporarily stores rows for UPDATE operations to support large
            // object (LOB) data types (e.g. varchar(max), nvarchar(max)). The value of a
            // LOB column will be NULL for the old row (operation = 3) if the value of the
            // field did not change. The field data will be available in the new row
            // (operation = 4).
            // The CDC stream implementation emits a [`CdcEvent::Data`] event, which contains a
            // batch of operations. There is no guarantee that both old and new rows will
            // exist in a single batch, so deferred updates must be tracked across multiple data
            // events.
            //
            // In the current implementation, schema change events won't be emitted between the
            // old and new rows.
            //
            // See <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/cdc-capture-instance-ct-transact-sql?view=sql-server-ver17#large-object-data-types>
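            //
            // For example, a single UPDATE produces an UpdateOld (operation = 3) and an
            // UpdateNew (operation = 4) entry that share the same (commit LSN, seqval) key;
            // whichever half arrives first is stashed here and the pair is emitted together
            // by `handle_data_event`.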
            let mut deferred_updates = BTreeMap::new();

            while let Some(event) = cdc_stream.next().await {
                let event = event.map_err(TransientError::from)?;
                tracing::trace!(?config.id, ?event, "got replication event");

                tracing::trace!("deferred_updates = {deferred_updates:?}");
                match event {
                    // We've received all of the changes up to this LSN, so
                    // downgrade our capability.
                    CdcEvent::Progress { next_lsn } => {
                        tracing::debug!(?config.id, ?next_lsn, "got a closed lsn");
                        // We cannot downgrade the capability until rewinds have been processed,
                        // as we must be able to produce data at the minimum offset.
                        rewinds.retain(|_, (_, snapshot_lsn)| next_lsn <= *snapshot_lsn);
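                        // Once `next_lsn` has advanced past an output's `snapshot_lsn`, no
                        // further retractions at `Lsn::minimum()` can be needed for it, so its
                        // rewind entry is dropped.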
                        if rewinds.is_empty() {
                            if log_rewinds_complete {
                                tracing::debug!("rewinds complete");
                                log_rewinds_complete = false;
                            }
                            data_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        } else {
                            tracing::debug!("rewinds remaining: {:?}", rewinds);
                        }

                        // Events are emitted in LSN order for a given capture instance; if any
                        // deferred updates remain when the LSN progresses, it is a bug.
                        if let Some(((deferred_lsn, _seqval), _row)) = deferred_updates.first_key_value()
                            && *deferred_lsn < next_lsn
                        {
                            panic!(
                                "deferred update lsn {deferred_lsn} < progress lsn {next_lsn}: {:?}",
                                deferred_updates.keys()
                            );
                        }

                        upper_cap_set.downgrade(Antichain::from_elem(next_lsn));
                    }
                    // We've got new data! Let's process it.
                    CdcEvent::Data {
                        capture_instance,
                        lsn,
                        changes,
                    } => {
                        let Some(partition_indexes) = capture_instances.get(&capture_instance) else {
                            let definite_error = DefiniteError::ProgrammingError(format!(
                                "capture instance didn't exist: '{capture_instance}'"
                            ));
                            return_definite_error(
                                definite_error,
                                capture_instances.values().flat_map(|indexes| indexes.iter().copied()),
                                data_output,
                                data_cap_set,
                                definite_error_handle,
                                definite_error_cap_set,
                            )
                            .await;
                            return Ok(());
                        };

                        let (valid_partitions, err_partitions) = partition_indexes
                            .iter()
                            .partition::<Vec<u64>, _>(|&partition_idx| {
                                !errored_partitions.contains(partition_idx)
                            });

                        if !err_partitions.is_empty() {
                            metrics.ignored.inc_by(u64::cast_from(changes.len()));
                        }

                        handle_data_event(
                            changes,
                            &valid_partitions,
                            &decoder_map,
                            lsn,
                            &rewinds,
                            &data_output,
                            data_cap_set,
                            &metrics,
                            &mut deferred_updates,
                        ).await?
                    },
                    CdcEvent::SchemaUpdate { capture_instance, table, ddl_event } => {
                        let Some(partition_indexes) = capture_instances.get(&capture_instance) else {
                            let definite_error = DefiniteError::ProgrammingError(format!(
                                "capture instance didn't exist: '{capture_instance}'"
                            ));
                            return_definite_error(
                                definite_error,
                                capture_instances.values().flat_map(|indexes| indexes.iter().copied()),
                                data_output,
                                data_cap_set,
                                definite_error_handle,
                                definite_error_cap_set,
                            )
                            .await;
                            return Ok(());
                        };
                        let error = DefiniteError::IncompatibleSchemaChange(
                            capture_instance.to_string(),
                            table.to_string(),
                        );
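                        // A DDL event only poisons an output if the change is incompatible with
                        // that output's included columns; changes that `is_compatible` considers
                        // harmless for those columns are simply ignored.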
                        for partition_idx in partition_indexes {
                            if !errored_partitions.contains(partition_idx)
                                && !ddl_event.is_compatible(
                                    included_columns.get(partition_idx).unwrap_or_else(|| {
                                        panic!("Partition index didn't exist: '{partition_idx}'")
                                    }),
                                )
                            {
                                data_output
                                    .give_fueled(
                                        &data_cap_set[0],
                                        ((*partition_idx, Err(error.clone().into())), ddl_event.lsn, Diff::ONE),
                                    )
                                    .await;
                                errored_partitions.insert(*partition_idx);
                            }
                        }
                    }
                };
            }
            Err(TransientError::ReplicationEOF)
        }))
    });

    let error_stream = definite_errors.concat(&transient_errors.map(ReplicationError::Transient));

    (
        data_stream.as_collection(),
        upper_stream,
        error_stream,
        button.press_on_drop(),
    )
}

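/// Decodes a batch of CDC operations for a single capture instance and emits the results to
/// `data_output` at `commit_lsn`, pairing split `UpdateOld`/`UpdateNew` halves via
/// `deferred_updates` and emitting compensating updates at `Lsn::minimum()` (rewinds) for
/// outputs that were snapshotted in this run.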
async fn handle_data_event(
    changes: Vec<CdcOperation>,
    partition_indexes: &[u64],
    decoder_map: &BTreeMap<u64, Arc<SqlServerRowDecoder>>,
    commit_lsn: Lsn,
    rewinds: &BTreeMap<u64, (Lsn, Lsn)>,
    data_output: &StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_cap_set: &CapabilitySet<Lsn>,
    metrics: &SqlServerSourceMetrics,
    deferred_updates: &mut BTreeMap<(Lsn, Lsn), CdcOperation>,
) -> Result<(), TransientError> {
    let mut mz_row = Row::default();
    let arena = RowArena::default();

    for change in changes {
        // deferred_update is only valid for a single iteration of the loop. It is set once
        // both the old and new update rows have been seen. It will be decoded and emitted to
        // the appropriate outputs. Its life now fulfilled, it will return to whence it came.
        let mut deferred_update: Option<_> = None;
        let (sql_server_row, diff): (_, _) = match change {
            CdcOperation::Insert(sql_server_row) => {
                metrics.inserts.inc();
                (sql_server_row, Diff::ONE)
            }
            CdcOperation::Delete(sql_server_row) => {
                metrics.deletes.inc();
                (sql_server_row, Diff::MINUS_ONE)
            }

            // Updates are not ordered by seqval, so either the old or the new row could be
            // observed first. The first update row is stashed; when the second arrives, both
            // are processed.
            CdcOperation::UpdateNew(seqval, sql_server_row) => {
                // Arbitrarily choosing to update metrics on the new row.
                metrics.updates.inc();
                deferred_update = deferred_updates.remove(&(commit_lsn, seqval));
                if deferred_update.is_none() {
                    tracing::trace!("capture deferred UpdateNew ({commit_lsn}, {seqval})");
                    deferred_updates.insert(
                        (commit_lsn, seqval),
                        CdcOperation::UpdateNew(seqval, sql_server_row),
                    );
                    continue;
                }
                // This is overridden below when the updates are ordered.
                (sql_server_row, Diff::ZERO)
            }
            CdcOperation::UpdateOld(seqval, sql_server_row) => {
                deferred_update = deferred_updates.remove(&(commit_lsn, seqval));
                if deferred_update.is_none() {
                    tracing::trace!("capture deferred UpdateOld ({commit_lsn}, {seqval})");
                    deferred_updates.insert(
                        (commit_lsn, seqval),
                        CdcOperation::UpdateOld(seqval, sql_server_row),
                    );
                    continue;
                }
                // This is overridden below when the updates are ordered.
                (sql_server_row, Diff::ZERO)
            }
        };

        // Try to decode the input row for each output.
        for partition_idx in partition_indexes {
            let decoder = decoder_map.get(partition_idx).unwrap();

            let rewind = rewinds.get(partition_idx);
            // We must continue here to avoid decoding and emitting. We don't have to compare with
            // snapshot_lsn as we are guaranteed that initial_lsn <= snapshot_lsn.
            if rewind.is_some_and(|(initial_lsn, _)| commit_lsn <= *initial_lsn) {
                continue;
            }

            let (message, diff) = if let Some(ref deferred_update) = deferred_update {
                let (old_row, new_row) = match deferred_update {
                    CdcOperation::UpdateOld(_seqval, row) => (row, &sql_server_row),
                    CdcOperation::UpdateNew(_seqval, row) => (&sql_server_row, row),
                    CdcOperation::Insert(_) | CdcOperation::Delete(_) => unreachable!(),
                };

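                // Decode the old image; the new row is passed along as well because CDC may
                // report large-object columns as NULL in the old image (see the
                // `deferred_updates` comment in `render`).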
                let update_old = decode(decoder, old_row, &mut mz_row, &arena, Some(new_row));
                if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                    data_output
                        .give_fueled(
                            &data_cap_set[0],
                            (
                                (*partition_idx, update_old.clone()),
                                Lsn::minimum(),
                                Diff::ONE,
                            ),
                        )
                        .await;
                }
                data_output
                    .give_fueled(
                        &data_cap_set[0],
                        ((*partition_idx, update_old), commit_lsn, Diff::MINUS_ONE),
                    )
                    .await;

                (
                    decode(decoder, new_row, &mut mz_row, &arena, None),
                    Diff::ONE,
                )
            } else {
                (
                    decode(decoder, &sql_server_row, &mut mz_row, &arena, None),
                    diff,
                )
            };
            assert_ne!(Diff::ZERO, diff);
            if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                data_output
                    .give_fueled(
                        &data_cap_set[0],
                        ((*partition_idx, message.clone()), Lsn::minimum(), -diff),
                    )
                    .await;
            }
            data_output
                .give_fueled(
                    &data_cap_set[0],
                    ((*partition_idx, message), commit_lsn, diff),
                )
                .await;
        }
    }
    Ok(())
}

type StackedAsyncOutputHandle<T, D> = AsyncOutputHandle<
    T,
    AccountedStackBuilder<CapacityContainerBuilder<TimelyStack<(D, T, Diff)>>>,
>;

/// Helper method to decode a row from a [`tiberius::Row`] (or two of them in the case of an
/// update) to a [`Row`]. This centralizes the decoding and the mapping to a result.
fn decode(
    decoder: &SqlServerRowDecoder,
    row: &tiberius::Row,
    mz_row: &mut Row,
    arena: &RowArena,
    new_row: Option<&tiberius::Row>,
) -> Result<SourceMessage, DataflowError> {
    match decoder.decode(row, mz_row, arena, new_row) {
        Ok(()) => Ok(SourceMessage {
            key: Row::default(),
            value: mz_row.clone(),
            metadata: Row::default(),
        }),
        Err(e) => {
            let kind = DecodeErrorKind::Text(e.to_string().into());
            // TODO(sql_server2): Get the raw bytes from `tiberius`.
            let raw = format!("{row:?}");
            Err(DataflowError::DecodeError(Box::new(DecodeError {
                kind,
                raw: raw.as_bytes().to_vec(),
            })))
        }
    }
}

/// Helper method to return a "definite" error upstream.
async fn return_definite_error(
    err: DefiniteError,
    outputs: impl Iterator<Item = u64>,
    data_handle: StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_capset: &CapabilitySet<Lsn>,
    errs_handle: AsyncOutputHandle<Lsn, CapacityContainerBuilder<Vec<ReplicationError>>>,
    errs_capset: &CapabilitySet<Lsn>,
) {
    for output_idx in outputs {
        let update = (
            (output_idx, Err(err.clone().into())),
            // Select an LSN that should not conflict with a previously observed LSN. Ideally
            // we could identify the LSN that resulted in the definite error so that all replicas
            // would emit the same updates for the same times.
            Lsn {
                vlf_id: u32::MAX,
                block_id: u32::MAX,
                record_id: u16::MAX,
            },
            Diff::ONE,
        );
        data_handle.give_fueled(&data_capset[0], update).await;
    }
    errs_handle.give(
        &errs_capset[0],
        ReplicationError::DefiniteError(Rc::new(err)),
    );
}

/// Provides an implementation of [`SqlServerCdcMetrics`] that updates [`SqlServerSourceMetrics`].
struct PrometheusSqlServerCdcMetrics<'a> {
    inner: &'a SqlServerSourceMetrics,
}

impl<'a> SqlServerCdcMetrics for PrometheusSqlServerCdcMetrics<'a> {
    fn snapshot_table_lock_start(&self, table_name: &str) {
        self.inner.update_snapshot_table_lock_count(table_name, 1);
    }

    fn snapshot_table_lock_end(&self, table_name: &str) {
        self.inner.update_snapshot_table_lock_count(table_name, -1);
    }
}