use std::collections::{BTreeMap, BTreeSet};
use std::fmt;
use std::sync::Arc;
use std::time::Duration;

use derivative::Derivative;
use futures::{Stream, StreamExt};
use proptest_derive::Arbitrary;
use serde::{Deserialize, Serialize};

use crate::{Client, SqlServerError, TransactionIsolationLevel};

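/// A stream of changes from the CDC (change data capture) feature of an
/// upstream Microsoft SQL Server instance.
///
/// A minimal usage sketch from within this crate (assumes an already
/// connected [`Client`] named `client` and a capture instance named
/// `"dbo_my_table"`, both of which are hypothetical):
///
/// ```ignore
/// let instances = [(Arc::from("dbo_my_table"), None)].into_iter().collect();
/// let events = CdcStream::new(&mut client, instances)
///     .poll_interval(Duration::from_secs(5))
///     .into_stream();
/// let mut events = std::pin::pin!(events);
/// while let Some(event) = events.next().await {
///     println!("{event:?}");
/// }
/// ```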
pub struct CdcStream<'a> {
    /// Client we use for querying the upstream SQL Server instance.
    client: &'a mut Client,
    /// The capture instances we'll stream changes from, and the optional LSN
    /// each one should start at.
    capture_instances: BTreeMap<Arc<str>, Option<Lsn>>,
    /// How frequently we poll the upstream instance for new changes.
    poll_interval: Duration,
}

impl<'a> CdcStream<'a> {
    pub(crate) fn new(
        client: &'a mut Client,
        capture_instances: BTreeMap<Arc<str>, Option<Lsn>>,
    ) -> Self {
        CdcStream {
            client,
            capture_instances,
            poll_interval: Duration::from_secs(1),
        }
    }

    /// Set the [`Lsn`] that streaming should start from for the specified
    /// capture instance.
    ///
    /// # Panics
    ///
    /// Panics if `capture_instance` is not tracked by this stream.
    pub fn start_lsn(mut self, capture_instance: &str, lsn: Lsn) -> Self {
        let start_lsn = self
            .capture_instances
            .get_mut(capture_instance)
            .expect("capture instance does not exist");
        *start_lsn = Some(lsn);
        self
    }

    /// Set how frequently the upstream instance is polled for new changes.
    pub fn poll_interval(mut self, interval: Duration) -> Self {
        self.poll_interval = interval;
        self
    }

    /// Takes a consistent snapshot of the tables backing the specified
    /// capture instances (or all tracked capture instances, if `instances`
    /// is `None`), returning the [`Lsn`] at which the snapshot was taken.
    pub async fn snapshot<'b>(
        &'b mut self,
        instances: Option<BTreeSet<Arc<str>>>,
    ) -> Result<
        (
            Lsn,
            impl Stream<Item = (Arc<str>, Result<tiberius::Row, SqlServerError>)> + use<'b, 'a>,
        ),
        SqlServerError,
    > {
        let instances = self
            .capture_instances
            .keys()
            .filter(|i| match instances.as_ref() {
                Some(filter) => filter.contains(i.as_ref()),
                None => true,
            })
            .map(|i| i.as_ref());
        let tables =
            crate::inspect::get_tables_for_capture_instance(self.client, instances).await?;
        tracing::info!(?tables, "got tables for capture instances");

        self.client
            .set_transaction_isolation(TransactionIsolationLevel::Snapshot)
            .await?;
        let txn = self.client.transaction().await?;

        let lsn = crate::inspect::get_max_lsn(txn.client).await?;
        tracing::info!(?tables, ?lsn, "starting snapshot");

        let stream = async_stream::stream! {
            for (capture_instance, schema_name, table_name) in tables {
                tracing::trace!(%capture_instance, %schema_name, %table_name, "snapshot start");

                let query = format!("SELECT * FROM {schema_name}.{table_name};");
                let snapshot = txn.client.query_streaming(&query, &[]);
                let mut snapshot = std::pin::pin!(snapshot);
                while let Some(result) = snapshot.next().await {
                    yield (Arc::clone(&capture_instance), result);
                }

                tracing::trace!(%capture_instance, %schema_name, %table_name, "snapshot end");
            }

            if let Err(e) = txn.commit().await {
                yield ("commit".into(), Err(e));
            }
        };

        Ok((lsn, stream))
    }

    /// Consumes `self`, returning a [`Stream`] of [`CdcEvent`]s.
    ///
    /// The upstream instance is polled for new changes once per
    /// [`poll_interval`](CdcStream::poll_interval).
    pub fn into_stream(mut self) -> impl Stream<Item = Result<CdcEvent, SqlServerError>> + use<'a> {
        async_stream::try_stream! {
            self.initialize_start_lsns().await?;

            loop {
                let next_tick = tokio::time::Instant::now()
                    .checked_add(self.poll_interval)
                    .expect("tick interval overflowed!");

                // The earliest LSN we still need to replicate, across all capture instances.
                let maybe_curr_lsn = self.capture_instances.values().filter_map(|x| *x).min();
                let Some(curr_lsn) = maybe_curr_lsn else {
                    tracing::warn!("shutting down CDC stream because nothing to replicate");
                    break;
                };

                let new_lsn = crate::inspect::get_max_lsn(self.client).await?;
                tracing::debug!(?new_lsn, ?curr_lsn, "got max LSN");

                // Only query for changes if the upstream has made progress.
                if new_lsn > curr_lsn {
                    for (instance, instance_lsn) in &self.capture_instances {
                        let Some(instance_lsn) = instance_lsn.as_ref() else {
                            tracing::error!(?instance, "found uninitialized LSN!");
                            continue;
                        };

                        if new_lsn < *instance_lsn {
                            continue;
                        }

                        let changes = crate::inspect::get_changes(
                            self.client,
                            &*instance,
                            *instance_lsn,
                            new_lsn,
                            RowFilterOption::AllUpdateOld,
                        )
                        .await?;

                        // Group the changes by the LSN they occurred at.
                        let mut events: BTreeMap<Lsn, Vec<Operation>> = BTreeMap::default();
                        for change in changes {
                            let (lsn, operation) = Operation::try_parse(change)?;
                            events.entry(lsn).or_default().push(operation);
                        }

                        for (lsn, changes) in events {
                            let capture_instance = Arc::clone(instance);
                            // N.B. currently a no-op.
                            let mark_complete = Box::new(move || {
                                let _capture_instance = capture_instance;
                                let _completed_lsn = lsn;
                            });
                            let event = CdcEvent::Data {
                                capture_instance: Arc::clone(instance),
                                lsn,
                                changes,
                                mark_complete,
                            };

                            yield event;
                        }
                    }

                    let next_lsn = crate::inspect::increment_lsn(self.client, new_lsn).await?;
                    tracing::debug!(?curr_lsn, ?next_lsn, "incrementing LSN");

                    yield CdcEvent::Progress { next_lsn };

                    // Record that everything before `next_lsn` has been emitted.
                    for instance_lsn in self.capture_instances.values_mut() {
                        let instance_lsn = instance_lsn.as_mut().expect("should be initialized");
                        *instance_lsn = std::cmp::max(*instance_lsn, next_lsn);
                    }
                }

                tokio::time::sleep_until(next_tick).await;
            }
        }
    }

    /// For any capture instance that doesn't have a start [`Lsn`] specified,
    /// default it to the current maximum LSN, then ensure every requested LSN
    /// is still available in the upstream change tables.
    async fn initialize_start_lsns(&mut self) -> Result<(), SqlServerError> {
        let max_lsn = crate::inspect::get_max_lsn(self.client).await?;
        for (_instance, requested_lsn) in self.capture_instances.iter_mut() {
            if requested_lsn.is_none() {
                requested_lsn.replace(max_lsn);
            }
        }

        for (instance, requested_lsn) in self.capture_instances.iter() {
            let requested_lsn = requested_lsn
                .as_ref()
                .expect("initialized all values above");

            let available_lsn = crate::inspect::get_min_lsn(self.client, &*instance).await?;

            if *requested_lsn < available_lsn {
                return Err(CdcError::LsnNotAvailable {
                    requested: *requested_lsn,
                    minimum: available_lsn,
                }
                .into());
            }
        }

        Ok(())
    }
}

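/// An event in a CDC stream, as yielded by [`CdcStream::into_stream`].
///
/// A sketch of how a consumer might handle these events (the
/// `process_change` and `update_frontier` helpers are hypothetical):
///
/// ```ignore
/// match event {
///     CdcEvent::Data { capture_instance, lsn, changes, mark_complete } => {
///         for change in changes {
///             process_change(&capture_instance, lsn, change);
///         }
///         // Notify the stream that this batch of changes has been handled.
///         mark_complete();
///     }
///     // All changes with an LSN less than `next_lsn` have been yielded.
///     CdcEvent::Progress { next_lsn } => update_frontier(next_lsn),
/// }
/// ```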
#[derive(Derivative)]
#[derivative(Debug)]
pub enum CdcEvent {
    /// A batch of changes for a single capture instance, all at the same [`Lsn`].
    Data {
        capture_instance: Arc<str>,
        lsn: Lsn,
        changes: Vec<Operation>,
        /// Callback for marking this batch of changes as complete.
        #[derivative(Debug = "ignore")]
        mark_complete: Box<dyn FnOnce() + Send + Sync>,
    },
    /// All changes with an LSN less than `next_lsn` have been yielded.
    Progress {
        next_lsn: Lsn,
    },
}

#[derive(Debug, thiserror::Error)]
pub enum CdcError {
    #[error("the requested LSN '{requested:?}' is less than the minimum '{minimum:?}'")]
    LsnNotAvailable { requested: Lsn, minimum: Lsn },
    #[error("failed to get the required column '{column_name}': {error}")]
    RequiredColumn {
        column_name: &'static str,
        error: String,
    },
}

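/// A SQL Server log sequence number (LSN): 10 bytes interpreted big-endian as
/// a 4-byte virtual log file ID, a 4-byte log block ID, and a 2-byte record
/// ID (see [`Lsn::as_structured`]).
///
/// A small illustrative example (values are arbitrary):
///
/// ```ignore
/// let lsn = Lsn::interpret([0, 0, 0, 0x3D, 0, 0, 0x19, 0xB8, 0, 0x04]);
/// assert_eq!(lsn.to_string(), "0000003d000019b80004");
/// ```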
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct Lsn([u8; 10]);

impl Lsn {
    /// Interpret the provided bytes as an [`Lsn`].
    pub fn interpret(bytes: [u8; 10]) -> Self {
        Lsn(bytes)
    }

    /// Return the raw bytes of this [`Lsn`].
    pub fn as_bytes(&self) -> &[u8] {
        self.0.as_slice()
    }

    /// Decompose this [`Lsn`] into its individual parts.
    pub fn as_structured(&self) -> StructuredLsn {
        let vlf_id: [u8; 4] = self.0[0..4].try_into().expect("known good length");
        let block_id: [u8; 4] = self.0[4..8].try_into().expect("known good length");
        let record_id: [u8; 2] = self.0[8..].try_into().expect("known good length");

        StructuredLsn {
            vlf_id: u32::from_be_bytes(vlf_id),
            block_id: u32::from_be_bytes(block_id),
            record_id: u16::from_be_bytes(record_id),
        }
    }
}

impl Ord for Lsn {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.as_structured().cmp(&other.as_structured())
    }
}

impl PartialOrd for Lsn {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl TryFrom<&[u8]> for Lsn {
    type Error = String;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        let value: [u8; 10] = value
            .try_into()
            .map_err(|_| format!("incorrect length, expected 10 got {}", value.len()))?;
        Ok(Lsn(value))
    }
}

impl fmt::Display for Lsn {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", hex::encode(&self.0[..]))
    }
}

impl columnation::Columnation for Lsn {
    type InnerRegion = columnation::CopyRegion<Lsn>;
}

impl timely::progress::Timestamp for Lsn {
    type Summary = ();

    fn minimum() -> Self {
        Lsn(Default::default())
    }
}

impl timely::progress::PathSummary<Lsn> for () {
    fn results_in(&self, src: &Lsn) -> Option<Lsn> {
        Some(*src)
    }

    fn followed_by(&self, _other: &Self) -> Option<Self> {
        Some(())
    }
}

impl timely::progress::timestamp::Refines<()> for Lsn {
    fn to_inner(_other: ()) -> Self {
        use timely::progress::Timestamp;
        Self::minimum()
    }
    fn to_outer(self) -> () {}

    fn summarize(_path: <Self as timely::progress::Timestamp>::Summary) -> () {}
}

impl timely::order::PartialOrder for Lsn {
    fn less_equal(&self, other: &Self) -> bool {
        self <= other
    }

    fn less_than(&self, other: &Self) -> bool {
        self < other
    }
}
impl timely::order::TotalOrder for Lsn {}

/// An [`Lsn`] broken out into its component parts.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct StructuredLsn {
    vlf_id: u32,
    block_id: u32,
    record_id: u16,
}

/// Argument to SQL Server's `cdc.fn_cdc_get_all_changes_<capture_instance>`
/// function that controls which change rows are returned.
#[derive(Debug, Copy, Clone)]
pub enum RowFilterOption {
    /// Return all changes, plus the old values for update operations.
    AllUpdateOld,
}

impl RowFilterOption {
    /// Returns this option formatted so it can be used in a SQL query.
    pub fn to_sql_string(&self) -> &'static str {
        match self {
            RowFilterOption::AllUpdateOld => "all update old",
        }
    }
}

impl fmt::Display for RowFilterOption {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.to_sql_string())
    }
}

/// A single change from a CDC change table, tagged with the kind of operation
/// reported in the `__$operation` column.
#[derive(Debug)]
pub enum Operation {
    /// Row was inserted (`__$operation = 2`).
    Insert(tiberius::Row),
    /// Row was deleted (`__$operation = 1`).
    Delete(tiberius::Row),
    /// Previous value of an updated row (`__$operation = 3`).
    UpdateOld(tiberius::Row),
    /// New value of an updated row (`__$operation = 4`).
    UpdateNew(tiberius::Row),
}

impl Operation {
    fn try_parse(data: tiberius::Row) -> Result<(Lsn, Self), SqlServerError> {
        static START_LSN_COLUMN: &str = "__$start_lsn";
        static OPERATION_COLUMN: &str = "__$operation";

        let lsn: &[u8] = data
            .try_get(START_LSN_COLUMN)
            .map_err(|e| CdcError::RequiredColumn {
                column_name: START_LSN_COLUMN,
                error: e.to_string(),
            })?
            .ok_or_else(|| CdcError::RequiredColumn {
                column_name: START_LSN_COLUMN,
                error: "got null value".to_string(),
            })?;
        let operation: i32 = data
            .try_get(OPERATION_COLUMN)
            .map_err(|e| CdcError::RequiredColumn {
                column_name: OPERATION_COLUMN,
                error: e.to_string(),
            })?
            .ok_or_else(|| CdcError::RequiredColumn {
                column_name: OPERATION_COLUMN,
                error: "got null value".to_string(),
            })?;

        let lsn = Lsn::try_from(lsn).map_err(|msg| SqlServerError::InvalidData {
            column_name: START_LSN_COLUMN.to_string(),
            error: msg,
        })?;
        let operation = match operation {
            1 => Operation::Delete(data),
            2 => Operation::Insert(data),
            3 => Operation::UpdateOld(data),
            4 => Operation::UpdateNew(data),
            other => {
                return Err(SqlServerError::InvalidData {
                    column_name: OPERATION_COLUMN.to_string(),
                    error: format!("unrecognized operation {other}"),
                });
            }
        };

        Ok((lsn, operation))
    }
}

#[cfg(test)]
mod tests {
    use super::Lsn;

    #[mz_ore::test]
    fn smoketest_lsn_ordering() {
        let a = hex::decode("0000003D000019B80004").unwrap();
        let a = Lsn::try_from(&a[..]).unwrap();

        let b = hex::decode("0000003D000019F00011").unwrap();
        let b = Lsn::try_from(&b[..]).unwrap();

        let c = hex::decode("0000003D00001A500003").unwrap();
        let c = Lsn::try_from(&c[..]).unwrap();

        assert!(a < b);
        assert!(b < c);
        assert!(a < c);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);
    }
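
    #[mz_ore::test]
    fn smoketest_lsn_structured_parts() {
        // Sketch of the byte layout: 4-byte VLF ID, 4-byte block ID, and
        // 2-byte record ID, all interpreted as big-endian.
        let raw = hex::decode("0000003D000019B80004").unwrap();
        let lsn = Lsn::try_from(&raw[..]).unwrap();

        let structured = lsn.as_structured();
        assert_eq!(structured.vlf_id, 0x3D);
        assert_eq!(structured.block_id, 0x19B8);
        assert_eq!(structured.record_id, 0x4);

        // Displaying an LSN produces its lowercase hex encoding.
        assert_eq!(lsn.to_string(), "0000003d000019b80004");
        assert_eq!(lsn.as_bytes(), &raw[..]);
    }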
}