mz_storage/source/sql_server/
replication.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Code to render the ingestion dataflow of a [`SqlServerSource`].
11
12use std::collections::{BTreeMap, BTreeSet};
13use std::convert::Infallible;
14use std::rc::Rc;
15use std::sync::Arc;
16use std::time::Instant;
17
18use differential_dataflow::AsCollection;
19use differential_dataflow::containers::TimelyStack;
20use futures::StreamExt;
21use mz_ore::cast::CastFrom;
22use mz_ore::future::InTask;
23use mz_repr::{Diff, GlobalId, Row, RowArena};
24use mz_sql_server_util::cdc::{CdcEvent, Lsn, Operation as CdcOperation};
25use mz_storage_types::errors::{DataflowError, DecodeError, DecodeErrorKind};
26use mz_storage_types::sources::SqlServerSource;
27use mz_storage_types::sources::sql_server::{
28    CDC_POLL_INTERVAL, SNAPSHOT_MAX_LSN_WAIT, SNAPSHOT_PROGRESS_REPORT_INTERVAL,
29};
30use mz_timely_util::builder_async::{
31    AsyncOutputHandle, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
32};
33use mz_timely_util::containers::stack::AccountedStackBuilder;
34use timely::container::CapacityContainerBuilder;
35use timely::dataflow::channels::pushers::Tee;
36use timely::dataflow::operators::{CapabilitySet, Concat, Map};
37use timely::dataflow::{Scope, Stream as TimelyStream};
38use timely::progress::{Antichain, Timestamp};
39
40use crate::source::RawSourceCreationConfig;
41use crate::source::sql_server::{
42    DefiniteError, ReplicationError, SourceOutputInfo, TransientError,
43};
44use crate::source::types::{
45    ProgressStatisticsUpdate, SignaledFuture, SourceMessage, StackedCollection,
46};
47
/// Used as a partition ID (passed to `config.responsible_for` below) to
/// determine the single worker that is responsible for reading data from
/// SQL Server.
///
/// TODO(sql_server2): It's possible we could have different workers
/// replicate different tables, if we're using SQL Server's CDC features.
static REPL_READER: &str = "reader";
54
/// Renders the ingestion dataflow operator that replicates a [`SqlServerSource`].
///
/// A single worker (the one `responsible_for` [`REPL_READER`]) connects to the
/// upstream SQL Server instance, snapshots any capture instances that have not
/// yet made progress, and then tails the CDC stream, decoding each upstream row
/// into a [`Row`] tagged with its output's partition index.
///
/// Returns:
/// * a collection of decoded data (or per-row decode errors), keyed by the
///   output's partition index,
/// * a stream that can carry no data (`Infallible`) and is used only to
///   communicate the source's upper frontier,
/// * a stream of replication errors (definite errors emitted by this operator,
///   concatenated with transient errors from the fallible operator body),
/// * a stream of snapshot progress statistics, and
/// * a button that shuts the operator down when dropped.
pub(crate) fn render<G: Scope<Timestamp = Lsn>>(
    scope: G,
    config: RawSourceCreationConfig,
    outputs: BTreeMap<GlobalId, SourceOutputInfo>,
    source: SqlServerSource,
) -> (
    StackedCollection<G, (u64, Result<SourceMessage, DataflowError>)>,
    TimelyStream<G, Infallible>,
    TimelyStream<G, ReplicationError>,
    TimelyStream<G, ProgressStatisticsUpdate>,
    PressOnDropButton,
) {
    let op_name = format!("SqlServerReplicationReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope);

    let (data_output, data_stream) = builder.new_output::<AccountedStackBuilder<_>>();
    let (_upper_output, upper_stream) = builder.new_output::<CapacityContainerBuilder<_>>();
    let (stats_output, stats_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

    // Captures DefiniteErrors that affect the entire source, including all outputs
    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
            // One capability (set) per output, in the order the outputs were
            // declared above.
            let [
                data_cap_set,
                upper_cap_set,
                stats_cap,
                definite_error_cap_set,
            ]: &mut [_; 4] = caps.try_into().unwrap();

            // TODO(sql_server2): Run ingestions across multiple workers.
            if !config.responsible_for(REPL_READER) {
                return Ok::<_, TransientError>(());
            }

            // Establish a connection to the upstream SQL Server instance.
            let connection_config = source
                .connection
                .resolve_config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;
            let mut client = mz_sql_server_util::Client::connect(connection_config).await?;

            // Partition indexes of all outputs; used when a definite error
            // needs to be surfaced to every output at once.
            let output_indexes: Vec<_> = outputs
                .values()
                .map(|v| usize::cast_from(v.partition_index))
                .collect();

            // Instances that have already made progress do not need to be snapshotted.
            // A resume upper of `[Lsn::minimum()]` indicates no progress yet.
            let needs_snapshot: BTreeSet<_> = outputs
                .values()
                .filter_map(|output| {
                    if *output.resume_upper == [Lsn::minimum()] {
                        Some(Arc::clone(&output.capture_instance))
                    } else {
                        None
                    }
                })
                .collect();

            // Map from a SQL Server 'capture instance' to Materialize collection.
            let capture_instances: BTreeMap<_, _> = outputs
                .values()
                .map(|output| {
                    (
                        Arc::clone(&output.capture_instance),
                        (output.partition_index, Arc::clone(&output.decoder)),
                    )
                })
                .collect();
            let mut cdc_handle = client
                .cdc(capture_instances.keys().cloned())
                .max_lsn_wait(SNAPSHOT_MAX_LSN_WAIT.get(config.config.config_set()));

            // Snapshot any instances that require it.
            let snapshot_lsn = {
                // Small helper closure to report snapshot progress; `known` is
                // the expected total record count and `total` is the number of
                // records staged so far.
                let emit_stats = |cap, known: usize, total: usize| {
                    let update = ProgressStatisticsUpdate::Snapshot {
                        records_known: u64::cast_from(known),
                        records_staged: u64::cast_from(total),
                    };
                    tracing::debug!(?config.id, %known, %total, "snapshot progress");
                    stats_output.give(cap, update);
                };

                tracing::debug!(?config.id, ?needs_snapshot, "starting snapshot");
                // Eagerly emit an event if we have tables to snapshot.
                if !needs_snapshot.is_empty() {
                    emit_stats(&stats_cap[0], 0, 0);
                }

                let (snapshot_lsn, snapshot_stats, snapshot_streams) =
                    cdc_handle.snapshot(Some(needs_snapshot)).await?;
                // All snapshot data is emitted at exactly the LSN the snapshot
                // was taken at.
                let snapshot_cap = data_cap_set.delayed(&snapshot_lsn);

                // As we stream rows for the snapshot we'll track the total we've seen.
                let mut records_total: usize = 0;
                let records_known = snapshot_stats.values().sum();
                let report_interval =
                    SNAPSHOT_PROGRESS_REPORT_INTERVAL.handle(config.config.config_set());
                let mut last_report = Instant::now();
                if !snapshot_stats.is_empty() {
                    emit_stats(&stats_cap[0], records_known, 0);
                }

                // Begin streaming our snapshots!
                let mut snapshot_streams = std::pin::pin!(snapshot_streams);
                while let Some((capture_instance, data)) = snapshot_streams.next().await {
                    let sql_server_row = data.map_err(TransientError::from)?;
                    records_total = records_total.saturating_add(1);

                    // Rate-limit progress reporting to the configured interval.
                    if last_report.elapsed() > report_interval.get() {
                        last_report = Instant::now();
                        emit_stats(&stats_cap[0], records_known, records_total);
                    }

                    // Decode the SQL Server row into an MZ one.
                    let (partition_idx, decoder) =
                        capture_instances.get(&capture_instance).ok_or_else(|| {
                            let msg =
                                format!("capture instance didn't exist: '{capture_instance}'");
                            TransientError::ProgrammingError(msg)
                        })?;

                    // Try to decode a row, returning a SourceError if it fails.
                    let mut mz_row = Row::default();
                    let arena = RowArena::default();
                    let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena) {
                        Ok(()) => Ok(SourceMessage {
                            key: Row::default(),
                            value: mz_row,
                            metadata: Row::default(),
                        }),
                        Err(e) => {
                            let kind = DecodeErrorKind::Text(e.to_string().into());
                            // TODO(sql_server2): Get the raw bytes from `tiberius`.
                            let raw = format!("{sql_server_row:?}");
                            Err(DataflowError::DecodeError(Box::new(DecodeError {
                                kind,
                                raw: raw.as_bytes().to_vec(),
                            })))
                        }
                    };
                    data_output
                        .give_fueled(
                            &snapshot_cap,
                            ((*partition_idx, message), snapshot_lsn, Diff::ONE),
                        )
                        .await;
                }

                mz_ore::soft_assert_eq_or_log!(
                    records_known,
                    records_total,
                    "snapshot size did not match total records received",
                );
                // Final progress report: everything we know about is staged.
                emit_stats(&stats_cap[0], records_known, records_total);

                snapshot_lsn
            };

            // Start replicating from the LSN __after__ we took the snapshot.
            let replication_start_lsn = snapshot_lsn.increment();

            // Set all of the LSNs to start replicating from.
            for output_info in outputs.values() {
                match output_info.resume_upper.as_option() {
                    // We just snapshotted this instance, so use the snapshot LSN.
                    Some(lsn) => {
                        let initial_lsn = if *lsn == Lsn::minimum() {
                            replication_start_lsn
                        } else {
                            *lsn
                        };
                        cdc_handle =
                            cdc_handle.start_lsn(&output_info.capture_instance, initial_lsn);
                    }
                    None => unreachable!("empty resume upper?"),
                }
            }

            // Off to the races! Replicate data from SQL Server.
            let cdc_stream = cdc_handle
                .poll_interval(CDC_POLL_INTERVAL.get(config.config.config_set()))
                .into_stream();
            let mut cdc_stream = std::pin::pin!(cdc_stream);

            // TODO(sql_server2): We should emit `ProgressStatisticsUpdate::SteadyState` messages
            // here, when we receive progress events. What stops us from doing this now is our
            // 10-byte LSN doesn't fit into the 8-byte integer that the progress event uses.
            while let Some(event) = cdc_stream.next().await {
                let event = event.map_err(TransientError::from)?;
                tracing::trace!(?config.id, ?event, "got replication event");

                let (capture_instance, commit_lsn, changes) = match event {
                    // We've received all of the changes up-to this LSN, so
                    // downgrade our capability.
                    CdcEvent::Progress { next_lsn } => {
                        tracing::debug!(?config.id, ?next_lsn, "got a closed lsn");
                        data_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        upper_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        continue;
                    }
                    // We've got new data! Let's process it.
                    CdcEvent::Data {
                        capture_instance,
                        lsn,
                        changes,
                    } => (capture_instance, lsn, changes),
                };

                // Decode the SQL Server row into an MZ one.
                let Some((partition_idx, decoder)) = capture_instances.get(&capture_instance)
                else {
                    // Receiving data for an unknown capture instance is a bug
                    // on our side, and it poisons the whole source: report a
                    // definite error on every output and stop.
                    let definite_error = DefiniteError::ProgrammingError(format!(
                        "capture instance didn't exist: '{capture_instance}'"
                    ));
                    let () = return_definite_error(
                        definite_error,
                        &output_indexes[..],
                        data_output,
                        data_cap_set,
                        definite_error_handle,
                        definite_error_cap_set,
                    )
                    .await;
                    return Ok(());
                };

                for change in changes {
                    // Updates arrive as a retraction of the old row and an
                    // insertion of the new one.
                    let (sql_server_row, diff): (_, _) = match change {
                        CdcOperation::Insert(sql_server_row)
                        | CdcOperation::UpdateNew(sql_server_row) => (sql_server_row, Diff::ONE),
                        CdcOperation::Delete(sql_server_row)
                        | CdcOperation::UpdateOld(sql_server_row) => {
                            (sql_server_row, Diff::MINUS_ONE)
                        }
                    };

                    // Try to decode a row, returning a SourceError if it fails.
                    let mut mz_row = Row::default();
                    let arena = RowArena::default();
                    let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena) {
                        Ok(()) => Ok(SourceMessage {
                            key: Row::default(),
                            value: mz_row,
                            metadata: Row::default(),
                        }),
                        Err(e) => {
                            let kind = DecodeErrorKind::Text(e.to_string().into());
                            // TODO(sql_server2): Get the raw bytes from `tiberius`.
                            let raw = format!("{sql_server_row:?}");
                            Err(DataflowError::DecodeError(Box::new(DecodeError {
                                kind,
                                raw: raw.as_bytes().to_vec(),
                            })))
                        }
                    };
                    // Emit the update at the LSN it was committed at upstream.
                    data_output
                        .give_fueled(
                            &data_cap_set[0],
                            ((*partition_idx, message), commit_lsn, diff),
                        )
                        .await;
                }
            }

            // The CDC stream is not expected to end; if it does, surface a
            // transient error.
            Err(TransientError::ReplicationEOF)
        }))
    });

    let error_stream = definite_errors.concat(&transient_errors.map(ReplicationError::Transient));

    (
        data_stream.as_collection(),
        upper_stream,
        error_stream,
        stats_stream,
        button.press_on_drop(),
    )
}
342
/// Convenience alias for the flavor of [`AsyncOutputHandle`] used by the data
/// output of the replication operator: `(D, T, Diff)` updates batched into
/// [`TimelyStack`] containers built via an [`AccountedStackBuilder`].
type StackedAsyncOutputHandle<T, D> = AsyncOutputHandle<
    T,
    AccountedStackBuilder<CapacityContainerBuilder<TimelyStack<(D, T, Diff)>>>,
    Tee<T, TimelyStack<(D, T, Diff)>>,
>;
348
349/// Helper method to return a "definite" error upstream.
350async fn return_definite_error(
351    err: DefiniteError,
352    outputs: &[usize],
353    data_handle: StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
354    data_capset: &CapabilitySet<Lsn>,
355    errs_handle: AsyncOutputHandle<
356        Lsn,
357        CapacityContainerBuilder<Vec<ReplicationError>>,
358        Tee<Lsn, Vec<ReplicationError>>,
359    >,
360    errs_capset: &CapabilitySet<Lsn>,
361) {
362    for output_idx in outputs {
363        let update = (
364            (u64::cast_from(*output_idx), Err(err.clone().into())),
365            // TODO(sql_server1): Provide the correct LSN.
366            Lsn::minimum(),
367            Diff::ONE,
368        );
369        data_handle.give_fueled(&data_capset[0], update).await;
370    }
371    errs_handle.give(
372        &errs_capset[0],
373        ReplicationError::DefiniteError(Rc::new(err)),
374    );
375}