mz_sql_server_util/
cdc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::{BTreeMap, BTreeSet};
11use std::fmt;
12use std::sync::Arc;
13use std::time::Duration;
14
15use derivative::Derivative;
16use futures::{Stream, StreamExt};
17use proptest_derive::Arbitrary;
18use serde::{Deserialize, Serialize};
19
20use crate::{Client, SqlServerError, TransactionIsolationLevel};
21
22/// A stream of changes from a table in SQL Server that has CDC enabled.
23///
24/// SQL Server does not have an API to push or notify consumers of changes, so we periodically
25/// poll the upstream source.
26///
27/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/change-data-capture-tables-transact-sql?view=sql-server-ver16>
/// A stream of changes from a table in SQL Server that has CDC enabled.
///
/// SQL Server does not have an API to push or notify consumers of changes, so we periodically
/// poll the upstream source.
///
/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/change-data-capture-tables-transact-sql?view=sql-server-ver16>
pub struct CdcStream<'a> {
    /// Client we use for querying SQL Server.
    client: &'a mut Client,
    /// Upstream capture instances we'll list changes from, each mapped to the [`Lsn`] we will
    /// next list changes from (inclusive).
    ///
    /// `None` means no start LSN has been chosen yet; such entries are replaced with the
    /// database's current max LSN when the stream is first polled.
    capture_instances: BTreeMap<Arc<str>, Option<Lsn>>,
    /// How often we poll the upstream for changes.
    poll_interval: Duration,
}
36
37impl<'a> CdcStream<'a> {
38    pub(crate) fn new(
39        client: &'a mut Client,
40        capture_instances: BTreeMap<Arc<str>, Option<Lsn>>,
41    ) -> Self {
42        CdcStream {
43            client,
44            capture_instances,
45            poll_interval: Duration::from_secs(1),
46        }
47    }
48
49    /// Set the [`Lsn`] that we should start streaming changes from.
50    ///
51    /// If the provided [`Lsn`] is not available, the stream will return an error
52    /// when first polled.
53    pub fn start_lsn(mut self, capture_instance: &str, lsn: Lsn) -> Self {
54        let start_lsn = self
55            .capture_instances
56            .get_mut(capture_instance)
57            .expect("capture instance does not exist");
58        *start_lsn = Some(lsn);
59        self
60    }
61
62    /// The cadence at which we'll poll the upstream SQL Server database for changes.
63    ///
64    /// Default is 1 second.
65    pub fn poll_interval(mut self, interval: Duration) -> Self {
66        self.poll_interval = interval;
67        self
68    }
69
70    /// Takes a snapshot of the upstream table that the specified `capture_instance` is
71    /// replicating changes from.
72    ///
73    /// An optional `instances` parameter can be provided to only snapshot the specified instances.
74    pub async fn snapshot<'b>(
75        &'b mut self,
76        instances: Option<BTreeSet<Arc<str>>>,
77    ) -> Result<
78        (
79            Lsn,
80            impl Stream<Item = (Arc<str>, Result<tiberius::Row, SqlServerError>)> + use<'b, 'a>,
81        ),
82        SqlServerError,
83    > {
84        // Determine what table we need to snapshot.
85        let instances = self
86            .capture_instances
87            .keys()
88            .filter(|i| match instances.as_ref() {
89                // Only snapshot the instance if the filter includes it.
90                Some(filter) => filter.contains(i.as_ref()),
91                None => true,
92            })
93            .map(|i| i.as_ref());
94        let tables =
95            crate::inspect::get_tables_for_capture_instance(self.client, instances).await?;
96        tracing::info!(?tables, "got table for capture instance");
97
98        self.client
99            .set_transaction_isolation(TransactionIsolationLevel::Snapshot)
100            .await?;
101        let txn = self.client.transaction().await?;
102
103        // Get the current LSN of the database.
104        let lsn = crate::inspect::get_max_lsn(txn.client).await?;
105        tracing::info!(?tables, ?lsn, "starting snapshot");
106
107        // Run a `SELECT` query to snapshot the entire table.
108        let stream = async_stream::stream! {
109            // TODO(sql_server3): A stream of streams would be better here than
110            // returning the name with each result, but the lifetimes are tricky.
111            for (capture_instance, schema_name, table_name) in tables {
112                tracing::trace!(%capture_instance, %schema_name, %table_name, "snapshot start");
113
114                let query = format!("SELECT * FROM {schema_name}.{table_name};");
115                let snapshot = txn.client.query_streaming(&query, &[]);
116                let mut snapshot = std::pin::pin!(snapshot);
117                while let Some(result) = snapshot.next().await {
118                    yield (Arc::clone(&capture_instance), result);
119                }
120
121                tracing::trace!(%capture_instance, %schema_name, %table_name, "snapshot end");
122            }
123
124            // Slightly awkward, but if the commit fails we need to conform to
125            // type of the stream.
126            if let Err(e) = txn.commit().await {
127                yield ("commit".into(), Err(e));
128            }
129        };
130
131        Ok((lsn, stream))
132    }
133
134    /// Consume `self` returning a [`Stream`] of [`CdcEvent`]s.
135    pub fn into_stream(mut self) -> impl Stream<Item = Result<CdcEvent, SqlServerError>> + use<'a> {
136        async_stream::try_stream! {
137            // Initialize all of our start LSNs.
138            self.initialize_start_lsns().await?;
139
140            loop {
141                // Measure the tick before we do any operation so the time it takes
142                // to query SQL Server is included in the time that we wait.
143                let next_tick = tokio::time::Instant::now()
144                    .checked_add(self.poll_interval)
145                    .expect("tick interval overflowed!");
146
147                // We always check for changes based on the "global" minimum LSN of any
148                // one capture instance.
149                let maybe_curr_lsn = self.capture_instances.values().filter_map(|x| *x).min();
150                let Some(curr_lsn) = maybe_curr_lsn else {
151                    tracing::warn!("shutting down CDC stream because nothing to replicate");
152                    break;
153                };
154
155                // Get the max LSN for the DB.
156                let new_lsn = crate::inspect::get_max_lsn(self.client).await?;
157                tracing::debug!(?new_lsn, ?curr_lsn, "got max LSN");
158
159                // If the LSN has increased then get all of our changes.
160                if new_lsn > curr_lsn {
161                    for (instance, instance_lsn) in &self.capture_instances {
162                        let Some(instance_lsn) = instance_lsn.as_ref() else {
163                            tracing::error!(?instance, "found uninitialized LSN!");
164                            continue;
165                        };
166
167                        // We've already replicated everything up-to new_lsn, so
168                        // nothing to do.
169                        if new_lsn < *instance_lsn {
170                            continue;
171                        }
172
173                        // List all the changes for the current instance.
174                        let changes = crate::inspect::get_changes(
175                            self.client,
176                            &*instance,
177                            *instance_lsn,
178                            new_lsn,
179                            RowFilterOption::AllUpdateOld,
180                        )
181                        .await?;
182
183                        // TODO(sql_server1): For very large changes it feels bad collecting
184                        // them all in memory, it would be best if we streamed them to the
185                        // caller.
186
187                        let mut events: BTreeMap<Lsn, Vec<Operation>> = BTreeMap::default();
188                        for change in changes {
189                            let (lsn, operation) = Operation::try_parse(change)?;
190                            // Group all the operations by their LSN.
191                            events.entry(lsn).or_default().push(operation);
192                        }
193
194                        for (lsn, changes) in events {
195                            // TODO(sql_server1): Handle these events and notify the upstream
196                            // that we can cleanup this LSN.
197                            let capture_instance = Arc::clone(instance);
198                            let mark_complete = Box::new(move || {
199                                let _capture_isntance = capture_instance;
200                                let _completed_lsn = lsn;
201                            });
202                            let event = CdcEvent::Data {
203                                capture_instance: Arc::clone(instance),
204                                lsn,
205                                changes,
206                                mark_complete,
207                            };
208
209                            yield event;
210                        }
211                    }
212
213                    // Increment our LSN (`get_changes` is inclusive).
214                    let next_lsn = crate::inspect::increment_lsn(self.client, new_lsn).await?;
215                    tracing::debug!(?curr_lsn, ?next_lsn, "incrementing LSN");
216
217                    // Notify our listener that we've emitted all changes __less than__ this LSN.
218                    //
219                    // Note: This aligns well with timely's semantics of progress tracking.
220                    yield CdcEvent::Progress { next_lsn };
221
222                    // We just listed everything upto next_lsn.
223                    for instance_lsn in self.capture_instances.values_mut() {
224                        let instance_lsn = instance_lsn.as_mut().expect("should be initialized");
225                        // Ensure LSNs don't go backwards.
226                        *instance_lsn = std::cmp::max(*instance_lsn, next_lsn);
227                    }
228                }
229
230                tokio::time::sleep_until(next_tick).await;
231            }
232        }
233    }
234
235    /// Determine the [`Lsn`] to start streaming changes from.
236    async fn initialize_start_lsns(&mut self) -> Result<(), SqlServerError> {
237        // First, initialize all start LSNs. If a capture instance didn't have
238        // one specified then we'll start from the current max.
239        let max_lsn = crate::inspect::get_max_lsn(self.client).await?;
240        for (_instance, requsted_lsn) in self.capture_instances.iter_mut() {
241            if requsted_lsn.is_none() {
242                requsted_lsn.replace(max_lsn);
243            }
244        }
245
246        // For each instance, ensure their requested LSN is available.
247        for (instance, requested_lsn) in self.capture_instances.iter() {
248            let requested_lsn = requested_lsn
249                .as_ref()
250                .expect("initialized all values above");
251
252            // Get the minimum Lsn available for this instance.
253            let available_lsn = crate::inspect::get_min_lsn(self.client, &*instance).await?;
254
255            // If we cannot start at our desired LSN, we must return an error!.
256            if *requested_lsn < available_lsn {
257                return Err(CdcError::LsnNotAvailable {
258                    requested: *requested_lsn,
259                    minimum: available_lsn,
260                }
261                .into());
262            }
263        }
264
265        Ok(())
266    }
267}
268
/// A change event from a [`CdcStream`].
#[derive(Derivative)]
#[derivative(Debug)]
pub enum CdcEvent {
    /// Changes have occurred upstream.
    ///
    /// [`CdcStream::into_stream`] emits one of these per capture instance per distinct LSN.
    /// Every operation in `changes` therefore shares the same `capture_instance` and `lsn`.
    Data {
        /// The capture instance these changes are for.
        capture_instance: Arc<str>,
        /// The LSN that this change occurred at.
        lsn: Lsn,
        /// The changes that occurred at `lsn`.
        changes: Vec<Operation>,
        /// When called marks `lsn` as complete allowing the upstream DB to clean up the record.
        ///
        /// NOTE(review): the closure currently built by [`CdcStream`] does not notify the
        /// upstream yet (see the `TODO(sql_server1)` in `CdcStream::into_stream`).
        #[derivative(Debug = "ignore")]
        mark_complete: Box<dyn FnOnce() + Send + Sync>,
    },
    /// We've made progress and observed all the changes less than `next_lsn`.
    Progress {
        /// We've received all of the data for [`Lsn`]s __less than__ this one.
        next_lsn: Lsn,
    },
}
291
292#[derive(Debug, thiserror::Error)]
293pub enum CdcError {
294    #[error("the requested LSN '{requested:?}' is less then the minimum '{minimum:?}'")]
295    LsnNotAvailable { requested: Lsn, minimum: Lsn },
296    #[error("failed to get the required column '{column_name}': {error}")]
297    RequiredColumn {
298        column_name: &'static str,
299        error: String,
300    },
301}
302
/// This type is used to represent the progress of each SQL Server instance in
/// the ingestion dataflow.
///
/// A SQL Server LSN is essentially an opaque binary blob that provides a
/// __total order__ to all transactions within a database. Technically though an
/// LSN has three components:
///
/// 1. A Virtual Log File (VLF) sequence number, bytes [0, 4)
/// 2. Log block number, bytes [4, 8)
/// 3. Log record number, bytes [8, 10)
///
/// Each component is read as a big-endian integer. LSNs are ordered by comparing the
/// components in the order listed above (see [`Lsn::as_structured`]).
///
/// To increment an LSN you need to call the [`sys.fn_cdc_increment_lsn`] T-SQL
/// function.
///
/// [`sys.fn_cdc_increment_lsn`]: https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/sys-fn-cdc-increment-lsn-transact-sql?view=sql-server-ver16
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct Lsn([u8; 10]);
320
321impl Lsn {
322    /// Interpret the provided bytes as an [`Lsn`].
323    pub fn interpret(bytes: [u8; 10]) -> Self {
324        Lsn(bytes)
325    }
326
327    /// Return the underlying byte slice for this [`Lsn`].
328    pub fn as_bytes(&self) -> &[u8] {
329        self.0.as_slice()
330    }
331
332    /// Returns `self` as a [`StructuredLsn`].
333    pub fn as_structured(&self) -> StructuredLsn {
334        let vlf_id: [u8; 4] = self.0[0..4].try_into().expect("known good length");
335        let block_id: [u8; 4] = self.0[4..8].try_into().expect("known good length");
336        let record_id: [u8; 2] = self.0[8..].try_into().expect("known good length");
337
338        StructuredLsn {
339            vlf_id: u32::from_be_bytes(vlf_id),
340            block_id: u32::from_be_bytes(block_id),
341            record_id: u16::from_be_bytes(record_id),
342        }
343    }
344}
345
346impl Ord for Lsn {
347    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
348        self.as_structured().cmp(&other.as_structured())
349    }
350}
351
352impl PartialOrd for Lsn {
353    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
354        Some(self.cmp(other))
355    }
356}
357
358impl TryFrom<&[u8]> for Lsn {
359    type Error = String;
360
361    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
362        let value: [u8; 10] = value
363            .try_into()
364            .map_err(|_| format!("incorrect length, expected 10 got {}", value.len()))?;
365        Ok(Lsn(value))
366    }
367}
368
369impl fmt::Display for Lsn {
370    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
371        write!(f, "{}", hex::encode(&self.0[..]))
372    }
373}
374
// `Lsn` is a `Copy` wrapper around a fixed-size byte array, so it can be stored in a
// columnation region by plain copying.
impl columnation::Columnation for Lsn {
    type InnerRegion = columnation::CopyRegion<Lsn>;
}
378
379impl timely::progress::Timestamp for Lsn {
380    // No need to describe complex summaries.
381    type Summary = ();
382
383    fn minimum() -> Self {
384        Lsn(Default::default())
385    }
386}
387
388impl timely::progress::PathSummary<Lsn> for () {
389    fn results_in(&self, src: &Lsn) -> Option<Lsn> {
390        Some(*src)
391    }
392
393    fn followed_by(&self, _other: &Self) -> Option<Self> {
394        Some(())
395    }
396}
397
398impl timely::progress::timestamp::Refines<()> for Lsn {
399    fn to_inner(_other: ()) -> Self {
400        use timely::progress::Timestamp;
401        Self::minimum()
402    }
403    fn to_outer(self) -> () {}
404
405    fn summarize(_path: <Self as timely::progress::Timestamp>::Summary) -> () {}
406}
407
408impl timely::order::PartialOrder for Lsn {
409    fn less_equal(&self, other: &Self) -> bool {
410        self <= other
411    }
412
413    fn less_than(&self, other: &Self) -> bool {
414        self < other
415    }
416}
417impl timely::order::TotalOrder for Lsn {}
418
/// Structured format of an [`Lsn`].
///
/// Note: The derived impl of [`PartialOrd`] and [`Ord`] relies on the field
/// ordering (most significant component first), so do not change it.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct StructuredLsn {
    /// Virtual Log File (VLF) sequence number: LSN bytes [0, 4) read as big-endian.
    vlf_id: u32,
    /// Log block number: LSN bytes [4, 8) read as big-endian.
    block_id: u32,
    /// Log record number: LSN bytes [8, 10) read as big-endian.
    record_id: u16,
}
429
/// When querying CDC functions like `cdc.fn_cdc_get_all_changes_<capture_instance>` this governs
/// what content is returned.
///
/// Note: There exists another option `All` that excludes the _before_ value of an `UPDATE`. We
/// don't support it for SQL Server sources yet, so it's not included in this enum.
///
/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/cdc-fn-cdc-get-all-changes-capture-instance-transact-sql?view=sql-server-ver16#row_filter_option>
#[derive(Debug, Copy, Clone)]
pub enum RowFilterOption {
    /// Includes both the before and after values of a row when changed because of an `UPDATE`.
    AllUpdateOld,
}
442
443impl RowFilterOption {
444    /// Returns this option formatted in a way that can be used in a query.
445    pub fn to_sql_string(&self) -> &'static str {
446        match self {
447            RowFilterOption::AllUpdateOld => "all update old",
448        }
449    }
450}
451
452impl fmt::Display for RowFilterOption {
453    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
454        write!(f, "{}", self.to_sql_string())
455    }
456}
457
/// Identifies what change was made to the SQL Server table tracked by CDC.
///
/// Each variant carries the complete change row it was parsed from. That row includes CDC
/// metadata columns such as `__$start_lsn` and `__$operation` alongside the table's columns.
#[derive(Debug)]
pub enum Operation {
    /// Row was `INSERT`-ed (`__$operation` = 2).
    Insert(tiberius::Row),
    /// Row was `DELETE`-ed (`__$operation` = 1).
    Delete(tiberius::Row),
    /// Original value of the row when `UPDATE`-ed (`__$operation` = 3).
    UpdateOld(tiberius::Row),
    /// New value of the row when `UPDATE`-ed (`__$operation` = 4).
    UpdateNew(tiberius::Row),
}
470
471impl Operation {
472    /// Parse the provided [`tiberius::Row`] to determine what [`Operation`] occurred.
473    ///
474    /// See <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/cdc-fn-cdc-get-all-changes-capture-instance-transact-sql?view=sql-server-ver16#table-returned>.
475    fn try_parse(data: tiberius::Row) -> Result<(Lsn, Self), SqlServerError> {
476        static START_LSN_COLUMN: &str = "__$start_lsn";
477        static OPERATION_COLUMN: &str = "__$operation";
478
479        let lsn: &[u8] = data
480            .try_get(START_LSN_COLUMN)
481            .map_err(|e| CdcError::RequiredColumn {
482                column_name: START_LSN_COLUMN,
483                error: e.to_string(),
484            })?
485            .ok_or_else(|| CdcError::RequiredColumn {
486                column_name: START_LSN_COLUMN,
487                error: "got null value".to_string(),
488            })?;
489        let operation: i32 = data
490            .try_get(OPERATION_COLUMN)
491            .map_err(|e| CdcError::RequiredColumn {
492                column_name: OPERATION_COLUMN,
493                error: e.to_string(),
494            })?
495            .ok_or_else(|| CdcError::RequiredColumn {
496                column_name: OPERATION_COLUMN,
497                error: "got null value".to_string(),
498            })?;
499
500        let lsn = Lsn::try_from(lsn).map_err(|msg| SqlServerError::InvalidData {
501            column_name: START_LSN_COLUMN.to_string(),
502            error: msg,
503        })?;
504        let operation = match operation {
505            1 => Operation::Delete(data),
506            2 => Operation::Insert(data),
507            3 => Operation::UpdateOld(data),
508            4 => Operation::UpdateNew(data),
509            other => {
510                return Err(SqlServerError::InvalidData {
511                    column_name: OPERATION_COLUMN.to_string(),
512                    error: format!("unrecognized operation {other}"),
513                });
514            }
515        };
516
517        Ok((lsn, operation))
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::Lsn;

    /// Decodes a hex string into an [`Lsn`], panicking on malformed input.
    fn lsn_from_hex(encoded: &str) -> Lsn {
        let bytes = hex::decode(encoded).unwrap();
        Lsn::try_from(&bytes[..]).unwrap()
    }

    #[mz_ore::test]
    fn smoketest_lsn_ordering() {
        let a = lsn_from_hex("0000003D000019B80004");
        let b = lsn_from_hex("0000003D000019F00011");
        let c = lsn_from_hex("0000003D00001A500003");

        // Strictly increasing, including transitively.
        assert!(a < b);
        assert!(b < c);
        assert!(a < c);

        // Equality is reflexive.
        for lsn in [a, b, c] {
            assert_eq!(lsn, lsn);
        }
    }
}