mz_storage/source/sql_server/replication.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Code to render the ingestion dataflow of a [`SqlServerSourceConnection`].

use std::collections::{BTreeMap, BTreeSet};
use std::convert::Infallible;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Instant;

use differential_dataflow::AsCollection;
use differential_dataflow::containers::TimelyStack;
use futures::StreamExt;
use itertools::Itertools;
use mz_ore::cast::CastFrom;
use mz_ore::future::InTask;
use mz_repr::{Diff, GlobalId, Row, RowArena};
use mz_sql_server_util::cdc::{CdcEvent, Lsn, Operation as CdcOperation};
use mz_sql_server_util::desc::SqlServerRowDecoder;
use mz_sql_server_util::inspect::get_latest_restore_history_id;
use mz_storage_types::errors::{DataflowError, DecodeError, DecodeErrorKind};
use mz_storage_types::sources::SqlServerSourceConnection;
use mz_storage_types::sources::sql_server::{
    CDC_POLL_INTERVAL, MAX_LSN_WAIT, SNAPSHOT_PROGRESS_REPORT_INTERVAL,
};
use mz_timely_util::builder_async::{
    AsyncOutputHandle, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use mz_timely_util::containers::stack::AccountedStackBuilder;
use timely::container::CapacityContainerBuilder;
use timely::dataflow::channels::pushers::Tee;
use timely::dataflow::operators::{CapabilitySet, Concat, Map};
use timely::dataflow::{Scope, Stream as TimelyStream};
use timely::progress::{Antichain, Timestamp};

use crate::source::RawSourceCreationConfig;
use crate::source::sql_server::{
    DefiniteError, ReplicationError, SourceOutputInfo, TransientError,
};
use crate::source::types::{SignaledFuture, SourceMessage, StackedCollection};

/// Used as a partition ID to determine the worker that is responsible for
/// reading data from SQL Server.
///
/// TODO(sql_server2): It's possible we could have different workers
/// replicate different tables, if we're using SQL Server's CDC features.
static REPL_READER: &str = "reader";

pub(crate) fn render<G: Scope<Timestamp = Lsn>>(
    scope: G,
    config: RawSourceCreationConfig,
    outputs: BTreeMap<GlobalId, SourceOutputInfo>,
    source: SqlServerSourceConnection,
) -> (
    StackedCollection<G, (u64, Result<SourceMessage, DataflowError>)>,
    TimelyStream<G, Infallible>,
    TimelyStream<G, ReplicationError>,
    PressOnDropButton,
) {
    let op_name = format!("SqlServerReplicationReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope);

    let (data_output, data_stream) = builder.new_output::<AccountedStackBuilder<_>>();
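    // The upper output never carries data; its capability is downgraded as the CDC
    // stream reports progress, which communicates the source's upstream frontier.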
    let (_upper_output, upper_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

    // Captures DefiniteErrors that affect the entire source, including all outputs.
    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
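            // Capabilities are provided in the order the outputs were declared above:
            // data, upper, then definite errors.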
            let [
                data_cap_set,
                upper_cap_set,
                definite_error_cap_set,
            ]: &mut [_; 3] = caps.try_into().unwrap();

            let connection_config = source
                .connection
                .resolve_config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;
            let mut client = mz_sql_server_util::Client::connect(connection_config).await?;

            let worker_id = config.worker_id;

            // The decoder is specific to the export, and each export pulls data from a specific capture instance.
            let mut decoder_map: BTreeMap<_, _> = BTreeMap::new();
            // Maps a capture instance to the output indexes (and initial LSNs) of only
            // those outputs that this worker will snapshot.
            let mut capture_instance_to_snapshot: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Maps a capture instance to the output indexes of all outputs of this worker.
            let mut capture_instances: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Export statistics for each capture instance.
            let mut export_statistics: BTreeMap<_, Vec<_>> = BTreeMap::new();

            for (export_id, output) in outputs.iter() {
                if decoder_map.insert(output.partition_index, Arc::clone(&output.decoder)).is_some() {
                    panic!("Multiple decoders for output index {}", output.partition_index);
                }
                capture_instances
                    .entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(output.partition_index);

                if *output.resume_upper == [Lsn::minimum()] {
                    capture_instance_to_snapshot
                        .entry(Arc::clone(&output.capture_instance))
                        .or_default()
                        .push((output.partition_index, output.initial_lsn));
                }
                export_statistics.entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(
                        config
                            .statistics
                            .get(export_id)
                            .expect("statistics have been initialized")
                            .clone(),
                    );
            }

            // Eagerly emit an event if we have tables to snapshot.
            // A worker *must* emit a count even if it is not responsible for snapshotting a
            // table, because statistics summarization returns null if any worker hasn't set
            // a value. This also resets snapshot stats for any exports that are not snapshotting.
            if !capture_instance_to_snapshot.is_empty() {
                for stats in config.statistics.values() {
                    stats.set_snapshot_records_known(0);
                    stats.set_snapshot_records_staged(0);
                }
            }
            // We need to emit statistics before we exit.
            // TODO(sql_server2): Run ingestions across multiple workers.
            if !config.responsible_for(REPL_READER) {
                return Ok::<_, TransientError>(());
            }

            let snapshot_instances = capture_instance_to_snapshot
                .keys()
                .map(|i| i.as_ref());

            // TODO (maz): we can avoid this query by using SourceOutputInfo
            let snapshot_tables = mz_sql_server_util::inspect::get_tables_for_capture_instance(
                &mut client,
                snapshot_instances,
            )
            .await?;

            // Validate that the restore_history_id hasn't changed.
            let current_restore_history_id = get_latest_restore_history_id(&mut client).await?;
            if current_restore_history_id != source.extras.restore_history_id {
                let definite_error = DefiniteError::RestoreHistoryChanged(
                    source.extras.restore_history_id.clone(),
                    current_restore_history_id.clone(),
                );
                tracing::warn!(?definite_error, "Restore detected, exiting");

                return_definite_error(
                    definite_error,
                    capture_instances.values().flat_map(|indexes| indexes.iter().copied()),
                    data_output,
                    data_cap_set,
                    definite_error_handle,
                    definite_error_cap_set,
                )
                .await;
                return Ok(());
            }

            // We first calculate all the total rows we need to fetch across all tables. Since this
            // happens outside the snapshot transaction the totals might be off, so we won't assert
            // that we get exactly this many rows later.
            for table in &snapshot_tables {
                let table_total = mz_sql_server_util::inspect::snapshot_size(
                    &mut client,
                    &table.schema_name,
                    &table.name,
                )
                .await?;
                for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                    export_stat.set_snapshot_records_known(u64::cast_from(table_total));
                    export_stat.set_snapshot_records_staged(0);
                }
            }

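            // Set up a CDC handle over every capture instance this source reads from.
            // MAX_LSN_WAIT bounds how long we wait for the upstream to report a usable LSN.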
            let mut cdc_handle = client
                .cdc(capture_instances.keys().cloned())
                .max_lsn_wait(MAX_LSN_WAIT.get(config.config.config_set()));

            // Snapshot any instance that requires it.
            // Each table snapshot will have its own LSN captured at the moment of snapshotting.
            let snapshot_lsns: BTreeMap<Arc<str>, Lsn> = {
                // Before starting a transaction where the LSN will not advance, ensure
                // the upstream DB is ready for CDC.
                cdc_handle.wait_for_ready().await?;

                // Intentionally logged at info for debugging. This section won't be entered
                // often, but if there are problems here it will be much easier to troubleshoot
                // knowing where a stall or hang might be happening.
                tracing::info!(%config.worker_id, "timely-{worker_id} upstream is ready");

                let report_interval =
                    SNAPSHOT_PROGRESS_REPORT_INTERVAL.handle(config.config.config_set());
                let mut last_report = Instant::now();
                let mut snapshot_lsns = BTreeMap::new();
                let arena = RowArena::default();

                for table in snapshot_tables {
                    // TODO(sql_server3): filter columns to only select columns required for Source.
                    let (snapshot_lsn, snapshot) = cdc_handle
                        .snapshot(&table, config.worker_id, config.id)
                        .await?;

                    tracing::info!(%config.id, %table.name, %table.schema_name, %snapshot_lsn, "timely-{worker_id} snapshot start");

                    let mut snapshot = std::pin::pin!(snapshot);

                    snapshot_lsns.insert(Arc::clone(&table.capture_instance.name), snapshot_lsn);

                    let partition_indexes = capture_instance_to_snapshot.get(&table.capture_instance.name)
                        .unwrap_or_else(|| {
                            panic!("no snapshot outputs in known capture instances [{}] for capture instance: '{}'", capture_instance_to_snapshot.keys().join(","), table.capture_instance.name);
                        });

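                    // Number of rows staged so far for this table; reported to statistics
                    // periodically and again once the snapshot finishes.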
                    let mut snapshot_staged = 0;
                    while let Some(result) = snapshot.next().await {
                        let sql_server_row = result.map_err(TransientError::from)?;

                        if last_report.elapsed() > report_interval.get() {
                            last_report = Instant::now();
                            for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                                export_stat.set_snapshot_records_staged(snapshot_staged);
                            }
                        }

                        for (partition_idx, _) in partition_indexes {
                            // Decode the SQL Server row into an MZ one.
                            let mut mz_row = Row::default();

                            let decoder = decoder_map.get(partition_idx).expect("decoder for output");
                            // Try to decode a row, returning a SourceError if it fails.
                            let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena) {
                                Ok(()) => Ok(SourceMessage {
                                    key: Row::default(),
                                    value: mz_row,
                                    metadata: Row::default(),
                                }),
                                Err(e) => {
                                    let kind = DecodeErrorKind::Text(e.to_string().into());
                                    // TODO(sql_server2): Get the raw bytes from `tiberius`.
                                    let raw = format!("{sql_server_row:?}");
                                    Err(DataflowError::DecodeError(Box::new(DecodeError {
                                        kind,
                                        raw: raw.as_bytes().to_vec(),
                                    })))
                                }
                            };
                            data_output
                                .give_fueled(
                                    &data_cap_set[0],
                                    ((*partition_idx, message), Lsn::minimum(), Diff::ONE),
                                )
                                .await;
                        }
                        snapshot_staged += 1;
                    }

                    tracing::info!(%config.id, %table.name, %table.schema_name, %snapshot_lsn, "timely-{worker_id} snapshot complete");

                    // Final update for the snapshot statistics: report the staged count as the
                    // known total as well, since the earlier total was only an estimate.
                    for export_stat in export_statistics.get(&table.capture_instance.name).unwrap() {
                        export_stat.set_snapshot_records_staged(snapshot_staged);
                        export_stat.set_snapshot_records_known(snapshot_staged);
                    }
                }

                snapshot_lsns
            };

            // Rewinds need to keep track of 2 timestamps to ensure that
            // all replicas emit the same set of updates for any given timestamp.
            // These are the initial_lsn and snapshot_lsn, where initial_lsn must be
            // less than or equal to snapshot_lsn.
            //
            // - events at an LSN less than or equal to initial_lsn are ignored
            // - events at an LSN greater than initial_lsn and less than or equal to
            //   snapshot_lsn are retracted at Lsn::minimum(), and emitted at the commit_lsn
            // - events at an LSN greater than snapshot_lsn are emitted at the commit_lsn
            //
            // where the commit_lsn is the upstream LSN that the event was committed at
            //
            // If initial_lsn == snapshot_lsn, all CDC events at LSNs up to and including the
            // snapshot_lsn are ignored, and no rewinds are issued.
            let mut rewinds: BTreeMap<_, _> = capture_instance_to_snapshot
                .iter()
                .flat_map(|(capture_instance, export_ids)| {
                    let snapshot_lsn = snapshot_lsns.get(capture_instance).expect("snapshot lsn must be collected for capture instance");
                    export_ids
                        .iter()
                        .map(|(idx, initial_lsn)| (*idx, (*initial_lsn, *snapshot_lsn)))
                }).collect();

            // For now, we assert that initial_lsn captured during purification is less
            // than or equal to snapshot_lsn. If that was not true, it would mean that
            // we observed a SQL Server DB that appeared to go back in time.
            // TODO (maz): not ideal to do this after snapshot, move this into
            // CdcStream::snapshot after https://github.com/MaterializeInc/materialize/pull/32979 is merged.
            for (initial_lsn, snapshot_lsn) in rewinds.values() {
                assert!(
                    initial_lsn <= snapshot_lsn,
                    "initial_lsn={initial_lsn} snapshot_lsn={snapshot_lsn}"
                );
            }

            tracing::debug!("rewinds to process: {rewinds:?}");

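            // Snapshotting is finished, so this bookkeeping is no longer needed.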
            capture_instance_to_snapshot.clear();

            // The resumption point for each capture instance is the minimum resume LSN
            // across all of its outputs.
            let mut resume_lsns = BTreeMap::new();
            for src_info in outputs.values() {
                let resume_lsn = match src_info.resume_upper.as_option() {
                    Some(lsn) if *lsn != Lsn::minimum() => *lsn,
                    // initial_lsn is the max lsn observed, but the resume lsn
                    // is the next lsn that should be read. After a snapshot, initial_lsn
                    // has been read, so replication will start at the next available lsn.
                    Some(_) => src_info.initial_lsn.increment(),
                    None => panic!("resume_upper has at least one value"),
                };
                resume_lsns.entry(Arc::clone(&src_info.capture_instance))
                    .and_modify(|existing| *existing = std::cmp::min(*existing, resume_lsn))
                    .or_insert(resume_lsn);
            }

            tracing::info!(%config.id, ?resume_lsns, "timely-{} replication starting", config.worker_id);
            for instance in capture_instances.keys() {
                let resume_lsn = resume_lsns
                    .get(instance)
                    .expect("resume_lsn exists for capture instance");
                cdc_handle = cdc_handle.start_lsn(instance, *resume_lsn);
            }

            // Off to the races! Replicate data from SQL Server.
            let cdc_stream = cdc_handle
                .poll_interval(CDC_POLL_INTERVAL.get(config.config.config_set()))
                .into_stream();
            let mut cdc_stream = std::pin::pin!(cdc_stream);

            let mut errored_instances = BTreeSet::new();

            // TODO(sql_server2): We should emit `ProgressStatisticsUpdate::SteadyState` messages
            // here, when we receive progress events. What stops us from doing this now is our
            // 10-byte LSN doesn't fit into the 8-byte integer that the progress event uses.
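            // Ensures the "rewinds complete" message below is only logged once.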
            let mut log_rewinds_complete = true;
            while let Some(event) = cdc_stream.next().await {
                let event = event.map_err(TransientError::from)?;
                tracing::trace!(?config.id, ?event, "got replication event");

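                // Each event either advances the upstream frontier (`Progress`), carries
                // changed rows (`Data`), or reports an upstream schema change (`SchemaUpdate`).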
                match event {
                    // We've received all of the changes up to this LSN, so
                    // downgrade our capability.
                    CdcEvent::Progress { next_lsn } => {
                        tracing::debug!(?config.id, ?next_lsn, "got a closed lsn");
                        // We cannot downgrade the data capability until all rewinds have been
                        // processed, since we must still be able to emit data at the minimum offset.
                        rewinds.retain(|_, (_, snapshot_lsn)| next_lsn <= *snapshot_lsn);
                        if rewinds.is_empty() {
                            if log_rewinds_complete {
                                tracing::debug!("rewinds complete");
                                log_rewinds_complete = false;
                            }
                            data_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        } else {
                            tracing::debug!("rewinds remaining: {:?}", rewinds);
                        }
                        upper_cap_set.downgrade(Antichain::from_elem(next_lsn));
                    }
                    // We've got new data! Let's process it.
                    CdcEvent::Data {
                        capture_instance,
                        lsn,
                        changes,
                    } => {
                        if errored_instances.contains(&capture_instance) {
                            // Outputs for this capture instance are in an errored state, so
                            // their updates are not emitted.
                            continue;
                        }

                        let Some(partition_indexes) = capture_instances.get(&capture_instance) else {
                            let definite_error = DefiniteError::ProgrammingError(format!(
                                "capture instance didn't exist: '{capture_instance}'"
                            ));
                            return_definite_error(
                                definite_error,
                                capture_instances.values().flat_map(|indexes| indexes.iter().copied()),
                                data_output,
                                data_cap_set,
                                definite_error_handle,
                                definite_error_cap_set,
                            )
                            .await;
                            return Ok(());
                        };

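                        // Decode and emit the changes, applying any pending rewinds for
                        // these outputs.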
                        handle_data_event(
                            changes,
                            partition_indexes,
                            &decoder_map,
                            lsn,
                            &rewinds,
                            &data_output,
                            data_cap_set,
                        )
                        .await?
                    },
                    CdcEvent::SchemaUpdate { capture_instance, table, ddl_event } => {
                        if !errored_instances.contains(&capture_instance)
                            && !ddl_event.is_compatible()
                        {
                            let Some(partition_indexes) = capture_instances.get(&capture_instance) else {
                                let definite_error = DefiniteError::ProgrammingError(format!(
                                    "capture instance didn't exist: '{capture_instance}'"
                                ));
                                return_definite_error(
                                    definite_error,
                                    capture_instances.values().flat_map(|indexes| indexes.iter().copied()),
                                    data_output,
                                    data_cap_set,
                                    definite_error_handle,
                                    definite_error_cap_set,
                                )
                                .await;
                                return Ok(());
                            };
                            let error = DefiniteError::IncompatibleSchemaChange(
                                capture_instance.to_string(),
                                table.to_string(),
                            );
                            for partition_idx in partition_indexes {
                                data_output
                                    .give_fueled(
                                        &data_cap_set[0],
                                        ((*partition_idx, Err(error.clone().into())), ddl_event.lsn, Diff::ONE),
                                    )
                                    .await;
                            }
                            errored_instances.insert(capture_instance);
                        }
                    }
                };
            }
            Err(TransientError::ReplicationEOF)
        }))
    });

    let error_stream = definite_errors.concat(&transient_errors.map(ReplicationError::Transient));

    (
        data_stream.as_collection(),
        upper_stream,
        error_stream,
        button.press_on_drop(),
    )
}

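/// Decodes a batch of CDC changes from a single capture instance and emits the
/// resulting updates to every output fed by that instance.
///
/// For outputs that still have a pending rewind, changes with a commit LSN at or
/// before `initial_lsn` are skipped entirely, and changes at or before
/// `snapshot_lsn` are additionally retracted at `Lsn::minimum()` so that the
/// snapshot plus the rewound updates remain consistent across replicas.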
async fn handle_data_event(
    changes: Vec<CdcOperation>,
    partition_indexes: &[u64],
    decoder_map: &BTreeMap<u64, Arc<SqlServerRowDecoder>>,
    commit_lsn: Lsn,
    rewinds: &BTreeMap<u64, (Lsn, Lsn)>,
    data_output: &StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_cap_set: &CapabilitySet<Lsn>,
) -> Result<(), TransientError> {
    for change in changes {
        let (sql_server_row, diff): (_, _) = match change {
            CdcOperation::Insert(sql_server_row) | CdcOperation::UpdateNew(sql_server_row) => {
                (sql_server_row, Diff::ONE)
            }
            CdcOperation::Delete(sql_server_row) | CdcOperation::UpdateOld(sql_server_row) => {
                (sql_server_row, Diff::MINUS_ONE)
            }
        };

        // Row buffer and arena reused when decoding this change for each output.
        let mut mz_row = Row::default();
        let arena = RowArena::default();

        for partition_idx in partition_indexes {
            let decoder = decoder_map.get(partition_idx).unwrap();

            let rewind = rewinds.get(partition_idx);
            // We must continue here to avoid decoding and emitting. We don't have to compare with
            // snapshot_lsn as we are guaranteed that initial_lsn <= snapshot_lsn.
            if rewind.is_some_and(|(initial_lsn, _)| commit_lsn <= *initial_lsn) {
                continue;
            }

            // Try to decode a row, returning a SourceError if it fails.
            let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena) {
                Ok(()) => Ok(SourceMessage {
                    key: Row::default(),
                    value: mz_row.clone(),
                    metadata: Row::default(),
                }),
                Err(e) => {
                    let kind = DecodeErrorKind::Text(e.to_string().into());
                    // TODO(sql_server2): Get the raw bytes from `tiberius`.
                    let raw = format!("{sql_server_row:?}");
                    Err(DataflowError::DecodeError(Box::new(DecodeError {
                        kind,
                        raw: raw.as_bytes().to_vec(),
                    })))
                }
            };

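            // This change was already reflected in the snapshot, so retract it at the
            // minimum LSN before (re)emitting it at its commit LSN below.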
            if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                data_output
                    .give_fueled(
                        &data_cap_set[0],
                        ((*partition_idx, message.clone()), Lsn::minimum(), -diff),
                    )
                    .await;
            }
            data_output
                .give_fueled(
                    &data_cap_set[0],
                    ((*partition_idx, message), commit_lsn, diff),
                )
                .await;
        }
    }
    Ok(())
}

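/// Async output handle for the stacked `(data, time, diff)` updates emitted by the
/// replication operator.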
type StackedAsyncOutputHandle<T, D> = AsyncOutputHandle<
    T,
    AccountedStackBuilder<CapacityContainerBuilder<TimelyStack<(D, T, Diff)>>>,
    Tee<T, TimelyStack<(D, T, Diff)>>,
>;

/// Helper method to return a "definite" error upstream.
async fn return_definite_error(
    err: DefiniteError,
    outputs: impl Iterator<Item = u64>,
    data_handle: StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_capset: &CapabilitySet<Lsn>,
    errs_handle: AsyncOutputHandle<
        Lsn,
        CapacityContainerBuilder<Vec<ReplicationError>>,
        Tee<Lsn, Vec<ReplicationError>>,
    >,
    errs_capset: &CapabilitySet<Lsn>,
) {
    for output_idx in outputs {
        let update = (
            (output_idx, Err(err.clone().into())),
            // Select an LSN that should not conflict with a previously observed LSN. Ideally
            // we could identify the LSN that resulted in the definite error so that all replicas
            // would emit the same updates for the same times.
            Lsn {
                vlf_id: u32::MAX,
                block_id: u32::MAX,
                record_id: u16::MAX,
            },
            Diff::ONE,
        );
        data_handle.give_fueled(&data_capset[0], update).await;
    }
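    // Also surface the error on the dedicated definite-error output.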
    errs_handle.give(
        &errs_capset[0],
        ReplicationError::DefiniteError(Rc::new(err)),
    );
}