use std::collections::{BTreeMap, BTreeSet};
use std::convert::Infallible;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Instant;

use differential_dataflow::AsCollection;
use differential_dataflow::containers::TimelyStack;
use futures::StreamExt;
use mz_ore::cast::CastFrom;
use mz_ore::future::InTask;
use mz_repr::{Diff, GlobalId, Row, RowArena};
use mz_sql_server_util::cdc::{CdcEvent, Lsn, Operation as CdcOperation};
use mz_storage_types::errors::{DataflowError, DecodeError, DecodeErrorKind};
use mz_storage_types::sources::SqlServerSource;
use mz_storage_types::sources::sql_server::{
    CDC_POLL_INTERVAL, SNAPSHOT_MAX_LSN_WAIT, SNAPSHOT_PROGRESS_REPORT_INTERVAL,
};
use mz_timely_util::builder_async::{
    AsyncOutputHandle, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use mz_timely_util::containers::stack::AccountedStackBuilder;
use timely::container::CapacityContainerBuilder;
use timely::dataflow::channels::pushers::Tee;
use timely::dataflow::operators::{CapabilitySet, Concat, Map};
use timely::dataflow::{Scope, Stream as TimelyStream};
use timely::progress::{Antichain, Timestamp};

use crate::source::RawSourceCreationConfig;
use crate::source::sql_server::{
    DefiniteError, ReplicationError, SourceOutputInfo, TransientError,
};
use crate::source::types::{
    ProgressStatisticsUpdate, SignaledFuture, SourceMessage, StackedCollection,
};

static REPL_READER: &str = "reader";

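/// Renders the replication operator for a SQL Server source: outputs that are resuming from the
/// minimum [`Lsn`] are snapshotted first, then the upstream CDC tables are tailed. Returns the
/// decoded updates, a frontier-only stream, the error stream, progress statistics, and a button
/// that shuts the operator down when dropped.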
pub(crate) fn render<G: Scope<Timestamp = Lsn>>(
    scope: G,
    config: RawSourceCreationConfig,
    outputs: BTreeMap<GlobalId, SourceOutputInfo>,
    source: SqlServerSource,
) -> (
    StackedCollection<G, (u64, Result<SourceMessage, DataflowError>)>,
    TimelyStream<G, Infallible>,
    TimelyStream<G, ReplicationError>,
    TimelyStream<G, ProgressStatisticsUpdate>,
    PressOnDropButton,
) {
    let op_name = format!("SqlServerReplicationReader({})", config.id);
    let mut builder = AsyncOperatorBuilder::new(op_name, scope);

    let (data_output, data_stream) = builder.new_output::<AccountedStackBuilder<_>>();
    let (_upper_output, upper_stream) = builder.new_output::<CapacityContainerBuilder<_>>();
    let (stats_output, stats_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

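    // Definite errors get a dedicated output, which is later concatenated with transient errors
    // into the returned error stream.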
    let (definite_error_handle, definite_errors) =
        builder.new_output::<CapacityContainerBuilder<_>>();

    let (button, transient_errors) = builder.build_fallible(move |caps| {
        let busy_signal = Arc::clone(&config.busy_signal);
        Box::pin(SignaledFuture::new(busy_signal, async move {
            let [
                data_cap_set,
                upper_cap_set,
                stats_cap,
                definite_error_cap_set,
            ]: &mut [_; 4] = caps.try_into().unwrap();

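            // Only the worker responsible for `REPL_READER` reads from the upstream instance;
            // every other worker exits immediately.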
            if !config.responsible_for(REPL_READER) {
                return Ok::<_, TransientError>(());
            }

            let connection_config = source
                .connection
                .resolve_config(
                    &config.config.connection_context.secrets_reader,
                    &config.config,
                    InTask::Yes,
                )
                .await?;
            let mut client = mz_sql_server_util::Client::connect(connection_config).await?;

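            // Partition index of every output of this source, used when an error needs to be
            // broadcast to all of them.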
            let output_indexes: Vec<_> = outputs
                .values()
                .map(|v| usize::cast_from(v.partition_index))
                .collect();
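
            // Outputs whose resume upper is still the minimum LSN have never produced any data
            // and therefore need an initial snapshot.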
            let needs_snapshot: BTreeSet<_> = outputs
                .values()
                .filter_map(|output| {
                    if *output.resume_upper == [Lsn::minimum()] {
                        Some(Arc::clone(&output.capture_instance))
                    } else {
                        None
                    }
                })
                .collect();

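            // Map each capture instance to the partition index and decoder of the output it
            // feeds.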
            let capture_instances: BTreeMap<_, _> = outputs
                .values()
                .map(|output| {
                    (
                        Arc::clone(&output.capture_instance),
                        (output.partition_index, Arc::clone(&output.decoder)),
                    )
                })
                .collect();
            let mut cdc_handle = client
                .cdc(capture_instances.keys().cloned())
                .max_lsn_wait(SNAPSHOT_MAX_LSN_WAIT.get(config.config.config_set()));

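            // Snapshot any capture instances that need it, remembering the LSN the snapshot was
            // taken at so replication can resume from there.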
            let snapshot_lsn = {
                let emit_stats = |cap, known: usize, total: usize| {
                    let update = ProgressStatisticsUpdate::Snapshot {
                        records_known: u64::cast_from(known),
                        records_staged: u64::cast_from(total),
                    };
                    tracing::debug!(?config.id, %known, %total, "snapshot progress");
                    stats_output.give(cap, update);
                };

                tracing::debug!(?config.id, ?needs_snapshot, "starting snapshot");
                if !needs_snapshot.is_empty() {
                    emit_stats(&stats_cap[0], 0, 0);
                }

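                // Begin the snapshot. The data capability is held back to the snapshot LSN so
                // that every snapshot update is emitted at that timestamp.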
                let (snapshot_lsn, snapshot_stats, snapshot_streams) =
                    cdc_handle.snapshot(Some(needs_snapshot)).await?;
                let snapshot_cap = data_cap_set.delayed(&snapshot_lsn);

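                // Track how many records we expect versus how many we have received so far, so
                // snapshot progress can be reported periodically.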
                let mut records_total: usize = 0;
                let records_known = snapshot_stats.values().sum();
                let report_interval =
                    SNAPSHOT_PROGRESS_REPORT_INTERVAL.handle(config.config.config_set());
                let mut last_report = Instant::now();
                if !snapshot_stats.is_empty() {
                    emit_stats(&stats_cap[0], records_known, 0);
                }

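                // Decode and emit every row produced by the snapshot streams.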
                let mut snapshot_streams = std::pin::pin!(snapshot_streams);
                while let Some((capture_instance, data)) = snapshot_streams.next().await {
                    let sql_server_row = data.map_err(TransientError::from)?;
                    records_total = records_total.saturating_add(1);

                    if last_report.elapsed() > report_interval.get() {
                        last_report = Instant::now();
                        emit_stats(&stats_cap[0], records_known, records_total);
                    }

                    let (partition_idx, decoder) =
                        capture_instances.get(&capture_instance).ok_or_else(|| {
                            let msg =
                                format!("capture instance didn't exist: '{capture_instance}'");
                            TransientError::ProgrammingError(msg)
                        })?;

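                    // Decode the SQL Server row into a Materialize row; decode failures are
                    // emitted as `DecodeError`s on the data output rather than failing the
                    // operator.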
                    let mut mz_row = Row::default();
                    let arena = RowArena::default();
                    let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena) {
                        Ok(()) => Ok(SourceMessage {
                            key: Row::default(),
                            value: mz_row,
                            metadata: Row::default(),
                        }),
                        Err(e) => {
                            let kind = DecodeErrorKind::Text(e.to_string().into());
                            let raw = format!("{sql_server_row:?}");
                            Err(DataflowError::DecodeError(Box::new(DecodeError {
                                kind,
                                raw: raw.as_bytes().to_vec(),
                            })))
                        }
                    };
                    data_output
                        .give_fueled(
                            &snapshot_cap,
                            ((*partition_idx, message), snapshot_lsn, Diff::ONE),
                        )
                        .await;
                }

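                // Everything the upfront statistics promised should now have been received.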
                mz_ore::soft_assert_eq_or_log!(
                    records_known,
                    records_total,
                    "snapshot size did not match total records received",
                );
                emit_stats(&stats_cap[0], records_known, records_total);

                snapshot_lsn
            };

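            // The snapshot covers everything up to and including `snapshot_lsn`, so replication
            // for freshly snapshotted outputs begins at the next LSN.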
            let replication_start_lsn = snapshot_lsn.increment();

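            // Tell the CDC handle where replication starts for each capture instance: outputs
            // that were just snapshotted start at `replication_start_lsn`, the rest resume from
            // their existing upper.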
            for output_info in outputs.values() {
                match output_info.resume_upper.as_option() {
                    Some(lsn) => {
                        let initial_lsn = if *lsn == Lsn::minimum() {
                            replication_start_lsn
                        } else {
                            *lsn
                        };
                        cdc_handle =
                            cdc_handle.start_lsn(&output_info.capture_instance, initial_lsn);
                    }
                    None => unreachable!("empty resume upper?"),
                }
            }

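            // Switch to tailing the CDC tables, polling SQL Server at the configured interval.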
            let cdc_stream = cdc_handle
                .poll_interval(CDC_POLL_INTERVAL.get(config.config.config_set()))
                .into_stream();
            let mut cdc_stream = std::pin::pin!(cdc_stream);

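            // Main replication loop: decode data events into updates and downgrade our
            // capabilities whenever a progress event arrives.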
            while let Some(event) = cdc_stream.next().await {
                let event = event.map_err(TransientError::from)?;
                tracing::trace!(?config.id, ?event, "got replication event");

                let (capture_instance, commit_lsn, changes) = match event {
                    CdcEvent::Progress { next_lsn } => {
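                        // A progress event means everything strictly less than `next_lsn` has
                        // been seen, so it is safe to downgrade our capabilities to it.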
                        tracing::debug!(?config.id, ?next_lsn, "got a closed lsn");
                        data_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        upper_cap_set.downgrade(Antichain::from_elem(next_lsn));
                        continue;
                    }
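                    // A data event carries the changes committed at `lsn` for a single capture
                    // instance.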
                    CdcEvent::Data {
                        capture_instance,
                        lsn,
                        changes,
                    } => (capture_instance, lsn, changes),
                };

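                // Receiving changes for a capture instance we don't know about is a definite
                // error: report it on every output and stop.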
                let Some((partition_idx, decoder)) = capture_instances.get(&capture_instance)
                else {
                    let definite_error = DefiniteError::ProgrammingError(format!(
                        "capture instance didn't exist: '{capture_instance}'"
                    ));
                    let () = return_definite_error(
                        definite_error,
                        &output_indexes[..],
                        data_output,
                        data_cap_set,
                        definite_error_handle,
                        definite_error_cap_set,
                    )
                    .await;
                    return Ok(());
                };

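                // Inserts and the new half of updates are additions; deletes and the old half
                // of updates are retractions.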
                for change in changes {
                    let (sql_server_row, diff): (_, _) = match change {
                        CdcOperation::Insert(sql_server_row)
                        | CdcOperation::UpdateNew(sql_server_row) => (sql_server_row, Diff::ONE),
                        CdcOperation::Delete(sql_server_row)
                        | CdcOperation::UpdateOld(sql_server_row) => {
                            (sql_server_row, Diff::MINUS_ONE)
                        }
                    };

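                    // Decode the row just like during the snapshot, surfacing decode failures
                    // as errors on the data output.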
                    let mut mz_row = Row::default();
                    let arena = RowArena::default();
                    let message = match decoder.decode(&sql_server_row, &mut mz_row, &arena) {
                        Ok(()) => Ok(SourceMessage {
                            key: Row::default(),
                            value: mz_row,
                            metadata: Row::default(),
                        }),
                        Err(e) => {
                            let kind = DecodeErrorKind::Text(e.to_string().into());
                            let raw = format!("{sql_server_row:?}");
                            Err(DataflowError::DecodeError(Box::new(DecodeError {
                                kind,
                                raw: raw.as_bytes().to_vec(),
                            })))
                        }
                    };
                    data_output
                        .give_fueled(
                            &data_cap_set[0],
                            ((*partition_idx, message), commit_lsn, diff),
                        )
                        .await;
                }
            }

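            // The CDC stream should never terminate on its own; if it does, treat it as a
            // transient error rather than completing successfully.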
            Err(TransientError::ReplicationEOF)
        }))
    });

    let error_stream = definite_errors.concat(&transient_errors.map(ReplicationError::Transient));

    (
        data_stream.as_collection(),
        upper_stream,
        error_stream,
        stats_stream,
        button.press_on_drop(),
    )
}

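/// Convenience alias for the stacked output handle used for the data output of the replication
/// operator.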
type StackedAsyncOutputHandle<T, D> = AsyncOutputHandle<
    T,
    AccountedStackBuilder<CapacityContainerBuilder<TimelyStack<(D, T, Diff)>>>,
    Tee<T, TimelyStack<(D, T, Diff)>>,
>;

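/// Emits `err` as a [`DataflowError`] to every output listed in `outputs` and as a
/// [`ReplicationError::DefiniteError`] on the error output.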
async fn return_definite_error(
    err: DefiniteError,
    outputs: &[usize],
    data_handle: StackedAsyncOutputHandle<Lsn, (u64, Result<SourceMessage, DataflowError>)>,
    data_capset: &CapabilitySet<Lsn>,
    errs_handle: AsyncOutputHandle<
        Lsn,
        CapacityContainerBuilder<Vec<ReplicationError>>,
        Tee<Lsn, Vec<ReplicationError>>,
    >,
    errs_capset: &CapabilitySet<Lsn>,
) {
    for output_idx in outputs {
        let update = (
            (u64::cast_from(*output_idx), Err(err.clone().into())),
            Lsn::minimum(),
            Diff::ONE,
        );
        data_handle.give_fueled(&data_capset[0], update).await;
    }
    errs_handle.give(
        &errs_capset[0],
        ReplicationError::DefiniteError(Rc::new(err)),
    );
}