mz_storage/source/sql_server/replication.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Code to render the ingestion dataflow of a [`SqlServerSource`].
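//!
//! The [`render`] operator produces a stacked collection of decoded source
//! messages, an empty stream whose frontier tracks the replication upper, a
//! stream of replication errors, and a stream of progress statistics, plus a
//! button that shuts the operator down when dropped.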

use std::collections::BTreeMap;
use std::convert::Infallible;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Instant;

use differential_dataflow::AsCollection;
use differential_dataflow::containers::TimelyStack;
use futures::StreamExt;
use itertools::Itertools;
use mz_ore::cast::CastFrom;
use mz_ore::future::InTask;
use mz_repr::{Diff, GlobalId, Row, RowArena};
use mz_sql_server_util::cdc::{CdcEvent, Lsn, Operation as CdcOperation};
use mz_sql_server_util::inspect::get_latest_restore_history_id;
use mz_storage_types::errors::{DataflowError, DecodeError, DecodeErrorKind};
use mz_storage_types::sources::SqlServerSource;
use mz_storage_types::sources::sql_server::{
    CDC_POLL_INTERVAL, MAX_LSN_WAIT, SNAPSHOT_PROGRESS_REPORT_INTERVAL,
};
use mz_timely_util::builder_async::{
    AsyncOutputHandle, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use mz_timely_util::containers::stack::AccountedStackBuilder;
use timely::container::CapacityContainerBuilder;
use timely::dataflow::channels::pushers::Tee;
use timely::dataflow::operators::{CapabilitySet, Concat, Map};
use timely::dataflow::{Scope, Stream as TimelyStream};
use timely::progress::{Antichain, Timestamp};

use crate::source::RawSourceCreationConfig;
use crate::source::sql_server::{
    DefiniteError, ReplicationError, SourceOutputInfo, TransientError,
};
use crate::source::types::{
    ProgressStatisticsUpdate, SignaledFuture, SourceMessage, StackedCollection,
};

/// Used as a partition ID to determine the worker that is responsible for
/// reading data from SQL Server.
///
/// TODO(sql_server2): It's possible we could have different workers
/// replicate different tables, if we're using SQL Server's CDC features.
static REPL_READER: &str = "reader";

pub(crate) fn render<G: Scope<Timestamp = Lsn>>(
    scope: G,
    config: RawSourceCreationConfig,
    outputs: BTreeMap<GlobalId, SourceOutputInfo>,
    source: SqlServerSource,
) -> (
    StackedCollection<G, (u64, Result<SourceMessage, DataflowError>)>,
    TimelyStream<G, Infallible>,
    TimelyStream<G, ReplicationError>,
    TimelyStream<G, ProgressStatisticsUpdate>,
    PressOnDropButton,
) {
    let op_name = format!("SqlServerReplicationReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope);

    let (data_output, data_stream) = builder.new_output::<AccountedStackBuilder<_>>();
    let (_upper_output, upper_stream) = builder.new_output::<CapacityContainerBuilder<_>>();
    let (stats_output, stats_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

    // Captures DefiniteErrors that affect the entire source, including all outputs
    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
            let [
                data_cap_set,
                upper_cap_set,
                stats_cap,
                definite_error_cap_set,
            ]: &mut [_; 4] = caps.try_into().unwrap();

            // TODO(sql_server2): Run ingestions across multiple workers.
            if !config.responsible_for(REPL_READER) {
                return Ok::<_, TransientError>(());
            }

            let connection_config = source
                .connection
                .resolve_config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;
            let mut client = mz_sql_server_util::Client::connect(connection_config).await?;

            let worker_id = config.worker_id;

            // The decoder is specific to the export, and each export pulls data from a
            // specific capture instance.
            let mut decoder_map: BTreeMap<_, _> = BTreeMap::new();
            // Maps a capture instance to the output indexes of only those outputs that
            // this worker will snapshot.
            let mut capture_instance_to_snapshot: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();
            // Maps a capture instance to the output indexes of all outputs of this worker.
            let mut capture_instances: BTreeMap<Arc<str>, Vec<_>> = BTreeMap::new();

            for output in outputs.values() {
                if decoder_map.insert(output.partition_index, Arc::clone(&output.decoder)).is_some() {
                    panic!("Multiple decoders for output index {}", output.partition_index);
                }
                capture_instances
                    .entry(Arc::clone(&output.capture_instance))
                    .or_default()
                    .push(output.partition_index);

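                // An output whose resume upper is still the minimum LSN has not
                // committed any data yet, and so still needs a snapshot.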
                if *output.resume_upper == [Lsn::minimum()] {
                    capture_instance_to_snapshot
                        .entry(Arc::clone(&output.capture_instance))
                        .or_default()
                        .push((output.partition_index, output.initial_lsn));
                }
            }

            let snapshot_instances = capture_instance_to_snapshot
                .keys()
                .map(|i| i.as_ref());

            let snapshot_tables = mz_sql_server_util::inspect::get_tables_for_capture_instance(
                &mut client,
                snapshot_instances,
            )
            .await?;

            // Validate that the restore_history_id hasn't changed.
            let current_restore_history_id = get_latest_restore_history_id(&mut client).await?;
            if current_restore_history_id != source.extras.restore_history_id {
                let definite_error = DefiniteError::RestoreHistoryChanged(
                    source.extras.restore_history_id.clone(),
                    current_restore_history_id.clone(),
                );
                tracing::error!(?definite_error, "Restore detected, exiting replication");

                return_definite_error(
                    definite_error.clone(),
                    capture_instances
                        .values()
                        .flat_map(|indexes| indexes.iter().copied()),
                    data_output,
                    data_cap_set,
                    definite_error_handle,
                    definite_error_cap_set,
                )
                .await;
                return Ok(());
            }

            let mut cdc_handle = client
                .cdc(capture_instances.keys().cloned())
                .max_lsn_wait(MAX_LSN_WAIT.get(config.config.config_set()));

            // Snapshot any instance that requires it.
            // Each table snapshot will have its own LSN captured at the moment of snapshotting.
            let snapshot_lsns: BTreeMap<Arc<str>, Lsn> = {
                // Small helper closure.
                let emit_stats = |cap, known: usize, total: usize| {
                    let update = ProgressStatisticsUpdate::Snapshot {
                        records_known: u64::cast_from(known),
                        records_staged: u64::cast_from(total),
                    };
                    tracing::debug!(?config.id, %known, %total, "snapshot progress");
                    stats_output.give(cap, update);
                };
                // Before starting a transaction where the LSN will not advance, ensure
                // the upstream DB is ready for CDC.
                cdc_handle.wait_for_ready().await?;

                // Intentionally logged at info for debugging. This section isn't entered
                // often, but if there are problems here, it is much easier to troubleshoot
                // knowing where a stall or hang might be happening.
                tracing::info!(%config.worker_id, "timely-{worker_id} upstream is ready");

                // Eagerly emit an event if we have tables to snapshot.
                if !snapshot_tables.is_empty() {
                    emit_stats(&stats_cap[0], 0, 0);
                }

                let report_interval =
                    SNAPSHOT_PROGRESS_REPORT_INTERVAL.handle(config.config.config_set());
                let mut last_report = Instant::now();
                let mut records_known: usize = 0;
                let mut records_total: usize = 0;
                let mut snapshot_lsns = BTreeMap::new();
                let arena = RowArena::default();

                for table in snapshot_tables {
                    // TODO(sql_server3): filter columns to only select columns required for Source.
                    let (snapshot_lsn, size, snapshot) = cdc_handle
                        .snapshot(&table, config.worker_id, config.id)
                        .await?;

                    tracing::info!(
                        %config.id, %table.name, %table.schema_name, %snapshot_lsn, %size,
                        "timely-{worker_id} snapshot start"
                    );

                    let mut snapshot = std::pin::pin!(snapshot);

                    snapshot_lsns.insert(Arc::clone(&table.capture_instance.name), snapshot_lsn);

                    records_known = records_known.saturating_add(size);
                    emit_stats(&stats_cap[0], records_known, records_total);

                    let partition_indexes = capture_instance_to_snapshot
                        .get(&table.capture_instance.name)
                        .unwrap_or_else(|| {
                            panic!(
                                "no snapshot outputs in known capture instances [{}] for capture instance: '{}'",
                                capture_instance_to_snapshot.keys().join(","),
                                table.capture_instance.name,
                            );
                        });

                    while let Some(result) = snapshot.next().await {
                        let sql_server_row = result.map_err(TransientError::from)?;

                        records_total = records_total.saturating_add(1);

                        if last_report.elapsed() > report_interval.get() {
                            last_report = Instant::now();
                            emit_stats(&stats_cap[0], records_known, records_total);
                        }

                        for (partition_idx, _) in partition_indexes {
                            // Decode the SQL Server row into an MZ one.
                            let mut mz_row = Row::default();

                            let decoder = decoder_map.get(partition_idx).expect("decoder for output");
                            // Try to decode a row, returning a SourceError if it fails.
                            let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena) {
                                Ok(()) => Ok(SourceMessage {
                                    key: Row::default(),
                                    value: mz_row,
                                    metadata: Row::default(),
                                }),
                                Err(e) => {
                                    let kind = DecodeErrorKind::Text(e.to_string().into());
                                    // TODO(sql_server2): Get the raw bytes from `tiberius`.
                                    let raw = format!("{sql_server_row:?}");
                                    Err(DataflowError::DecodeError(Box::new(DecodeError {
                                        kind,
                                        raw: raw.as_bytes().to_vec(),
                                    })))
                                }
                            };
                            data_output
                                .give_fueled(
                                    &data_cap_set[0],
                                    ((*partition_idx, message), Lsn::minimum(), Diff::ONE),
                                )
                                .await;
                        }
                    }

                    tracing::info!(
                        %config.id, %table.name, %table.schema_name, %snapshot_lsn, %size,
                        "timely-{worker_id} snapshot complete"
                    );
                }

                mz_ore::soft_assert_eq_or_log!(
                    records_known,
                    records_total,
                    "snapshot size did not match total records received",
                );
                emit_stats(&stats_cap[0], records_known, records_total);

                snapshot_lsns
            };

            // Rewinds need to keep track of 2 timestamps to ensure that
            // all replicas emit the same set of updates for any given timestamp.
            // These are the initial_lsn and snapshot_lsn, where initial_lsn must be
            // less than or equal to snapshot_lsn.
            //
            // - events at an LSN less than or equal to initial_lsn are ignored
            // - events at an LSN greater than initial_lsn and less than or equal to
            //   snapshot_lsn are retracted at Lsn::minimum(), and emitted at the commit_lsn
            // - events at an LSN greater than snapshot_lsn are emitted at the commit_lsn
            //
            // where the commit_lsn is the upstream LSN that the event was committed at
            //
            // If initial_lsn == snapshot_lsn, all CDC events at LSNs up to and including the
            // snapshot_lsn are ignored, and no rewinds are issued.
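            //
            // Illustrative example (hypothetical LSN values; real LSNs are structured):
            // with initial_lsn = 10 and snapshot_lsn = 15 for some output,
            //   - a change committed at LSN 8 is ignored,
            //   - a change committed at LSN 12 is retracted at Lsn::minimum() and
            //     emitted at LSN 12,
            //   - a change committed at LSN 20 is emitted at LSN 20.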
            let mut rewinds: BTreeMap<_, _> = capture_instance_to_snapshot
                .iter()
                .flat_map(|(capture_instance, export_ids)| {
                    let snapshot_lsn = snapshot_lsns
                        .get(capture_instance)
                        .expect("snapshot lsn must be collected for capture instance");
                    export_ids
                        .iter()
                        .map(|(idx, initial_lsn)| (*idx, (*initial_lsn, *snapshot_lsn)))
                })
                .collect();

            // For now, we assert that the initial_lsn captured during purification is less
            // than or equal to snapshot_lsn. If that were not true, it would mean that we
            // observed a SQL Server DB that appeared to go back in time.
            // TODO (maz): not ideal to do this after the snapshot; move this into
            // CdcStream::snapshot after https://github.com/MaterializeInc/materialize/pull/32979 is merged.
            for (initial_lsn, snapshot_lsn) in rewinds.values() {
                assert!(
                    initial_lsn <= snapshot_lsn,
                    "initial_lsn={initial_lsn} snapshot_lsn={snapshot_lsn}"
                );
            }

            tracing::debug!("rewinds to process: {rewinds:?}");

            capture_instance_to_snapshot.clear();

            // The resumption point is the minimum LSN that has been observed per capture instance.
            let mut resume_lsns = BTreeMap::new();
            for src_info in outputs.values() {
                let resume_lsn = match src_info.resume_upper.as_option() {
                    Some(lsn) if *lsn != Lsn::minimum() => *lsn,
                    // initial_lsn is the max LSN observed, but the resume LSN is the next
                    // LSN that should be read. After a snapshot, initial_lsn has already
                    // been read, so replication will start at the next available LSN.
                    Some(_) => src_info.initial_lsn.increment(),
                    None => panic!("resume_upper must contain at least one element"),
                };
                resume_lsns
                    .entry(Arc::clone(&src_info.capture_instance))
                    .and_modify(|existing| *existing = std::cmp::min(*existing, resume_lsn))
                    .or_insert(resume_lsn);
            }
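            // Illustrative example (hypothetical values): if two outputs share a capture
            // instance and resolve to resume LSNs X and Y, replication for that instance
            // starts from min(X, Y); a freshly snapshotted output resumes at
            // initial_lsn.increment(), the first LSN it has not yet read.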

            tracing::info!(
                %config.id, ?resume_lsns,
                "timely-{} replication starting", config.worker_id
            );
            for instance in capture_instances.keys() {
                let resume_lsn = resume_lsns
                    .get(instance)
                    .expect("resume_lsn exists for capture instance");
                cdc_handle = cdc_handle.start_lsn(instance, *resume_lsn);
            }

            // Off to the races! Replicate data from SQL Server.
            let cdc_stream = cdc_handle
                .poll_interval(CDC_POLL_INTERVAL.get(config.config.config_set()))
                .into_stream();
            let mut cdc_stream = std::pin::pin!(cdc_stream);

            // TODO(sql_server2): We should emit `ProgressStatisticsUpdate::SteadyState` messages
            // here, when we receive progress events. What stops us from doing this now is our
            // 10-byte LSN doesn't fit into the 8-byte integer that the progress event uses.
            let mut log_rewinds_complete = true;
            while let Some(event) = cdc_stream.next().await {
                let event = event.map_err(TransientError::from)?;
                tracing::trace!(?config.id, ?event, "got replication event");

                let (capture_instance, commit_lsn, changes) = match event {
                    // We've received all of the changes up to this LSN, so
                    // downgrade our capability.
                    CdcEvent::Progress { next_lsn } => {
                        tracing::debug!(?config.id, ?next_lsn, "got a closed lsn");
                        // We cannot downgrade the capability until rewinds have been
                        // processed; we must still be able to produce data at the
                        // minimum offset.
                        rewinds.retain(|_, (_, snapshot_lsn)| next_lsn <= *snapshot_lsn);
                        if rewinds.is_empty() {
                            if log_rewinds_complete {
                                tracing::debug!("rewinds complete");
                                log_rewinds_complete = false;
                            }
                            data_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        } else {
                            tracing::debug!("rewinds remaining: {:?}", rewinds);
                        }
                        upper_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        continue;
                    }
                    // We've got new data! Let's process it.
                    CdcEvent::Data {
                        capture_instance,
                        lsn,
                        changes,
                    } => (capture_instance, lsn, changes),
                };

                let Some(partition_indexes) = capture_instances.get(&capture_instance) else {
                    let definite_error = DefiniteError::ProgrammingError(format!(
                        "capture instance didn't exist: '{capture_instance}'"
                    ));
                    let () = return_definite_error(
                        definite_error.clone(),
                        capture_instances
                            .values()
                            .flat_map(|indexes| indexes.iter().copied()),
                        data_output,
                        data_cap_set,
                        definite_error_handle,
                        definite_error_cap_set,
                    )
                    .await;
                    return Ok(());
                };

                for change in changes {
                    let (sql_server_row, diff): (_, _) = match change {
                        CdcOperation::Insert(sql_server_row)
                        | CdcOperation::UpdateNew(sql_server_row) => (sql_server_row, Diff::ONE),
                        CdcOperation::Delete(sql_server_row)
                        | CdcOperation::UpdateOld(sql_server_row) => {
                            (sql_server_row, Diff::MINUS_ONE)
                        }
                    };

                    // Allocate a row and arena once per change; they are reused when
                    // decoding for each output.
                    let mut mz_row = Row::default();
                    let arena = RowArena::default();

                    for partition_idx in partition_indexes {
                        let decoder = decoder_map.get(partition_idx).unwrap();

                        let rewind = rewinds.get(partition_idx);
                        // We must continue here to avoid decoding and emitting. We don't
                        // have to compare with snapshot_lsn as we are guaranteed that
                        // initial_lsn <= snapshot_lsn.
                        if rewind.is_some_and(|(initial_lsn, _)| commit_lsn <= *initial_lsn) {
                            continue;
                        }

                        // Try to decode a row, returning a SourceError if it fails.
                        let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena) {
                            Ok(()) => Ok(SourceMessage {
                                key: Row::default(),
                                value: mz_row.clone(),
                                metadata: Row::default(),
                            }),
                            Err(e) => {
                                let kind = DecodeErrorKind::Text(e.to_string().into());
                                // TODO(sql_server2): Get the raw bytes from `tiberius`.
                                let raw = format!("{sql_server_row:?}");
                                Err(DataflowError::DecodeError(Box::new(DecodeError {
                                    kind,
                                    raw: raw.as_bytes().to_vec(),
                                })))
                            }
                        };

                        if rewind.is_some_and(|(_, snapshot_lsn)| commit_lsn <= *snapshot_lsn) {
                            data_output
                                .give_fueled(
                                    &data_cap_set[0],
                                    ((*partition_idx, message.clone()), Lsn::minimum(), -diff),
                                )
                                .await;
                        }
                        data_output
                            .give_fueled(
                                &data_cap_set[0],
                                ((*partition_idx, message), commit_lsn, diff),
                            )
                            .await;
                    }
                }
            }
            Err(TransientError::ReplicationEOF)
        }))
    });

    let error_stream = definite_errors.concat(&transient_errors.map(ReplicationError::Transient));

    (
        data_stream.as_collection(),
        upper_stream,
        error_stream,
        stats_stream,
        button.press_on_drop(),
    )
}

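/// Async output handle for the stacked data output, matching the
/// [`AccountedStackBuilder`] container builder used by the data output in [`render`].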
type StackedAsyncOutputHandle<T, D> = AsyncOutputHandle<
    T,
    AccountedStackBuilder<CapacityContainerBuilder<TimelyStack<(D, T, Diff)>>>,
    Tee<T, TimelyStack<(D, T, Diff)>>,
>;

/// Helper method to return a "definite" error upstream.
async fn return_definite_error(
    err: DefiniteError,
    outputs: impl Iterator<Item = u64>,
    data_handle: StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_capset: &CapabilitySet<Lsn>,
    errs_handle: AsyncOutputHandle<
        Lsn,
        CapacityContainerBuilder<Vec<ReplicationError>>,
        Tee<Lsn, Vec<ReplicationError>>,
    >,
    errs_capset: &CapabilitySet<Lsn>,
) {
    for output_idx in outputs {
        let update = (
            (output_idx, Err(err.clone().into())),
            // Select an LSN that should not conflict with a previously observed LSN. Ideally
            // we could identify the LSN that resulted in the definite error so that all replicas
            // would emit the same updates for the same times.
            Lsn {
                vlf_id: u32::MAX,
                block_id: u32::MAX,
                record_id: u16::MAX,
            },
            Diff::ONE,
        );
        data_handle.give_fueled(&data_capset[0], update).await;
    }
    errs_handle.give(
        &errs_capset[0],
        ReplicationError::DefiniteError(Rc::new(err)),
    );
}