Skip to main content

mz_sql_server_util/
cdc.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Replicate a table from SQL Server using their Change-Data-Capture (CDC) primitives.
11//!
12//! This module provides a [`CdcStream`] type that provides the following API for
13//! replicating a table:
14//!
15//! 1. [`CdcStream::snapshot`] returns an initial snapshot of a table and the [`Lsn`] at
16//!    which the snapshot was taken.
17//! 2. [`CdcStream::into_stream`] returns a [`futures::Stream`] of [`CdcEvent`]s
18//!    optionally from the [`Lsn`] returned in step 1.
19//!
20//! The snapshot process is responsible for identifying an [`Lsn`] that corresponds to
21//! a point-in-time view of the data for the table(s) being copied. Similarly to
22//! MySQL, Microsoft SQL server, as far as we know, does not provide an API to
23//! achieve this.
24//!
25//! SQL Server `SNAPSHOT` isolation provides guarantees that a reader will only
//! see writes committed before the transaction began. More specifically, this
//! snapshot is implemented using versions that are visible based on the
28//! transaction sequence number (`XSN`). The `XSN` is set at the first
29//! read or write, not at `BEGIN TRANSACTION`, see [here](https://learn.microsoft.com/en-us/sql/relational-databases/sql-server-transaction-locking-and-row-versioning-guide?view=sql-server-ver17).
30//! This provides us a suitable starting point for capturing the table data.
31//! To force an `XSN` to be assigned, experiments have shown that a table must
32//! be read. We choose a well-known table that we should already have access to,
33//! [cdc.change_tables](https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/cdc-change-tables-transact-sql?view=sql-server-ver17),
34//! and read a single value from it.
35//!
36//! Due to the asynchronous nature of CDC, we can assume that the [`Lsn`]
37//! returned from any CDC tables or CDC functions will always be stale,
38//! in relation to the source table that CDC is tracking. The system table
39//! [sys.dm_tran_database_transactions](https://learn.microsoft.com/en-us/sql/relational-databases/system-dynamic-management-views/sys-dm-tran-database-transactions-transact-sql?view=sql-server-ver17)
40//! will contain an [`Lsn`] for any transaction that performs a write operation.
41//! Creating a savepoint using [SAVE TRANSACTION](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/save-transaction-transact-sql?view=sql-server-ver17)
42//! is sufficient to generate an [`Lsn`] in this case.
43//!
//! To ensure that the point-in-time view is established atomically with
45//! collection of the [`Lsn`], we lock the tables to prevent writes from being
46//! interleaved between the 2 commands (read to establish `XSN` and creation of
47//! the savepoint).
48//!
49//! SQL server supports table locks, but those will only be released
50//! once the outermost transaction completes. For this reason, this module
51//! uses two connections for the snapshot process. The first connection is used
52//! to initiate a transaction and lock the upstream tables under
53//! [`TransactionIsolationLevel::ReadCommitted`] isolation. While the first
54//! connection maintains the locks, the second connection starts a
55//! transaction with [`TransactionIsolationLevel::Snapshot`] isolation and
56//! creates a savepoint. Once the savepoint is created, SQL server has assigned
//! an [`Lsn`] and the first connection rolls back the transaction.
58//! The [`Lsn`] and snapshot are captured by the second connection within the
59//! existing transaction.
60//!
//! After completing the snapshot we use [`crate::inspect::get_changes_asc`] which will return
//! all changes within an inclusive `[lower, upper]` range of [`Lsn`]s.
63
64use std::collections::BTreeMap;
65use std::fmt;
66use std::sync::Arc;
67use std::time::Duration;
68
69use derivative::Derivative;
70use futures::{Stream, StreamExt};
71use mz_repr::GlobalId;
72#[cfg(any(test, feature = "proptest"))]
73use proptest_derive::Arbitrary;
74use serde::{Deserialize, Serialize};
75use tiberius::numeric::Numeric;
76
77use crate::desc::{SqlServerQualifiedTableName, SqlServerTableRaw};
78use crate::inspect::DDLEvent;
79use crate::{Client, SqlServerCdcMetrics, SqlServerError, TransactionIsolationLevel};
80
81/// A stream of changes from a table in SQL Server that has CDC enabled.
82///
83/// SQL Server does not have an API to push or notify consumers of changes, so we periodically
84/// poll the upstream source.
85///
86/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/change-data-capture-tables-transact-sql?view=sql-server-ver16>
87pub struct CdcStream<'a, M: SqlServerCdcMetrics> {
88    /// Client we use for querying SQL Server.
89    client: &'a mut Client,
90    /// Upstream capture instances we'll list changes from.
91    capture_instances: BTreeMap<Arc<str>, Option<Lsn>>,
92    /// How often we poll the upstream for changes.
93    poll_interval: Duration,
94    /// How long we'll wait for SQL Server to return a max LSN before taking a snapshot.
95    ///
96    /// Note: When CDC is first enabled in an instance of SQL Server it can take a moment
97    /// for it to "completely" startup. Before starting a `TRANSACTION` for our snapshot
98    /// we'll wait this duration for SQL Server to report an [`Lsn`] and thus indicate CDC is
99    /// ready to go.
100    max_lsn_wait: Duration,
101    /// Metrics.
102    metrics: M,
103}
104
105impl<'a, M: SqlServerCdcMetrics> CdcStream<'a, M> {
106    pub(crate) fn new(
107        client: &'a mut Client,
108        capture_instances: BTreeMap<Arc<str>, Option<Lsn>>,
109        metrics: M,
110    ) -> Self {
111        CdcStream {
112            client,
113            capture_instances,
114            poll_interval: Duration::from_secs(1),
115            max_lsn_wait: Duration::from_secs(10),
116            metrics,
117        }
118    }
119
120    /// Set the [`Lsn`] that we should start streaming changes from.
121    ///
122    /// If the provided [`Lsn`] is not available, the stream will return an error
123    /// when first polled.
124    pub fn start_lsn(mut self, capture_instance: &str, lsn: Lsn) -> Self {
125        let start_lsn = self
126            .capture_instances
127            .get_mut(capture_instance)
128            .expect("capture instance does not exist");
129        *start_lsn = Some(lsn);
130        self
131    }
132
133    /// The cadence at which we'll poll the upstream SQL Server database for changes.
134    ///
135    /// Default is 1 second.
136    pub fn poll_interval(mut self, interval: Duration) -> Self {
137        self.poll_interval = interval;
138        self
139    }
140
141    /// The max duration we'll wait for SQL Server to return an [`Lsn`] before taking a
142    /// snapshot.
143    ///
144    /// When CDC is first enabled in SQL Server it can take a moment before it is fully
145    /// setup and starts reporting LSNs.
146    ///
147    /// Default is 10 seconds.
148    pub fn max_lsn_wait(mut self, wait: Duration) -> Self {
149        self.max_lsn_wait = wait;
150        self
151    }
152
153    /// Takes a snapshot of the upstream table that the specified `table` represents.
154    pub async fn snapshot<'b>(
155        &'b mut self,
156        table: &SqlServerTableRaw,
157        worker_id: usize,
158        source_id: GlobalId,
159    ) -> Result<
160        (
161            Lsn,
162            impl Stream<Item = Result<tiberius::Row, SqlServerError>>,
163        ),
164        SqlServerError,
165    > {
166        static SAVEPOINT_NAME: &str = "_mz_snap_";
167
168        // The client that will be used for fencing does not need any special isolation level
169        // as it will be just be locking the table(s).
170        let mut fencing_client = self.client.new_connection().await?;
171        let mut fence_txn = fencing_client.transaction().await?;
172        let qualified_table_name = format!(
173            "{schema_name}.{table_name}",
174            schema_name = &table.schema_name,
175            table_name = &table.name
176        );
177        self.metrics
178            .snapshot_table_lock_start(&qualified_table_name);
179        let result: Result<_, SqlServerError> = async {
180            fence_txn
181                .lock_table_shared(&table.schema_name, &table.name)
182                .await?;
183            tracing::info!(%source_id, %table.schema_name, %table.name, "timely-{worker_id} locked table");
184
185            self.client
186                .set_transaction_isolation(TransactionIsolationLevel::Snapshot)
187                .await?;
188            let mut txn = self.client.transaction().await?;
189            // Creating a savepoint forces a write to the transaction log, which will
190            // assign an LSN, but it does not force a transaction sequence number to be
191            // assigned as far as I can tell.  I have not observed any entries added to
192            // `sys.dm_tran_active_snapshot_database_transactions` when creating a savepoint
193            // or when reading system views to retrieve the LSN.
194            //
195            // We choose cdc.change_tables because it is a system table that will exist
196            // when CDC is enabled, it has a well known schema, and as a CDC client,
197            // we should be able to read from it already.
198            let res = txn
199                .simple_query("SELECT TOP 1 object_id FROM cdc.change_tables")
200                .await?;
201            if res.len() != 1 {
202                Err(SqlServerError::InvariantViolated(
203                    "No objects found in cdc.change_tables".into(),
204                ))?
205            }
206
207            // Because the table is locked, any write operation has either
208            // completed, or is blocked. The LSN and XSN acquired now will represent a
209            // consistent point-in-time view, such that any committed write will be
210            // visible to this snapshot and the LSN of such a write will be less than
211            // or equal to the LSN captured here. Creating the savepoint sets the LSN,
212            // we can read it after rolling back the locks.
213            txn.create_savepoint(SAVEPOINT_NAME).await?;
214            tracing::info!(%source_id, %table.schema_name, %table.name, %SAVEPOINT_NAME, "timely-{worker_id} created savepoint");
215
216            // Once the savepoint is created (which establishes the XSN and captures the LSN),
217            // the table no longer needs to be locked. Any writes that happen to the upstream table
218            // will have an LSN higher than our captured LSN, and will be read from CDC.
219            fence_txn.rollback().await?;
220
221            Ok(txn)
222        }.await;
223        self.metrics.snapshot_table_lock_end(&qualified_table_name);
224        let mut txn = result?;
225        let lsn = txn.get_lsn().await?;
226
227        tracing::info!(%source_id, ?lsn, "timely-{worker_id} starting snapshot");
228        let rows = async_stream::try_stream! {
229            {
230                let snapshot_stream = crate::inspect::snapshot(txn.client, table);
231                tokio::pin!(snapshot_stream);
232
233                while let Some(row) = snapshot_stream.next().await {
234                    yield row?;
235                }
236            }
237
238            txn.rollback().await?
239        };
240
241        Ok((lsn, rows))
242    }
243
244    /// Consume `self` returning a [`Stream`] of [`CdcEvent`]s.
245    pub fn into_stream(
246        mut self,
247    ) -> impl Stream<Item = Result<CdcEvent, SqlServerError>> + use<'a, M> {
248        async_stream::try_stream! {
249            // Initialize all of our start LSNs.
250            self.initialize_start_lsns().await?;
251
252            // When starting the stream we'll emit one progress event if we've already observed
253            // everything the DB currently has.
254            if let Some(starting_lsn) = self.capture_instances.values().filter_map(|x| *x).min() {
255                let db_curr_lsn = crate::inspect::get_max_lsn(self.client).await?;
256                let next_lsn = db_curr_lsn.increment();
257                if starting_lsn >= db_curr_lsn {
258                    tracing::debug!(
259                        %starting_lsn,
260                        %db_curr_lsn,
261                        %next_lsn,
262                        "yielding initial progress",
263                    );
264                    yield CdcEvent::Progress { next_lsn };
265                }
266            }
267
268            loop {
269                // Measure the tick before we do any operation so the time it takes
270                // to query SQL Server is included in the time that we wait.
271                let next_tick = tokio::time::Instant::now()
272                    .checked_add(self.poll_interval)
273                    .expect("tick interval overflowed!");
274
275                // We always check for changes based on the "global" minimum LSN of any
276                // one capture instance.
277                let maybe_curr_lsn = self.capture_instances.values().filter_map(|x| *x).min();
278                let Some(curr_lsn) = maybe_curr_lsn else {
279                    tracing::warn!("shutting down CDC stream because nothing to replicate");
280                    break;
281                };
282
283                // Get the max LSN for the DB.
284                let db_max_lsn = crate::inspect::get_max_lsn(self.client).await?;
285                tracing::debug!(?db_max_lsn, ?curr_lsn, "got max LSN");
286
287                // If the LSN of the DB has increased then get all of our changes.
288                if db_max_lsn > curr_lsn {
289                    for (instance, instance_lsn) in &self.capture_instances {
290                        let Some(instance_lsn) = instance_lsn.as_ref() else {
291                            tracing::error!(?instance, "found uninitialized LSN!");
292                            continue;
293                        };
294
295                        // We've already replicated everything up-to db_max_lsn, so
296                        // nothing to do.
297                        if db_max_lsn < *instance_lsn {
298                            continue;
299                        }
300
301                        {
302                            // Get a stream of all the changes for the current instance.
303                            let changes = crate::inspect::get_changes_asc(
304                                self.client,
305                                &*instance,
306                                *instance_lsn,
307                                db_max_lsn,
308                                RowFilterOption::AllUpdateOld,
309                            )
310                            // TODO(sql_server3): Make this chunk size configurable.
311                            .ready_chunks(64);
312                            let mut changes = std::pin::pin!(changes);
313
314                            // Map and stream all the rows to our listener.
315                            while let Some(chunk) = changes.next().await {
316                                // Group events by LSN.
317                                //
318                                // TODO(sql_server3): Can we maybe re-use this BTreeMap or these Vec
319                                // allocations? Something to be careful of is shrinking the allocations
320                                // if/when they grow to large, e.g. from a large spike of changes.
321                                // Alternatively we could also use a single Vec here since we know the
322                                // changes are ordered by LSN.
323                                let mut events: BTreeMap<Lsn, Vec<Operation>> = BTreeMap::default();
324                                for change in chunk {
325                                    let (lsn, operation) = change.and_then(Operation::try_parse)?;
326                                    events.entry(lsn).or_default().push(operation);
327                                }
328
329                                // Emit the groups of events.
330                                for (lsn, changes) in events {
331                                    yield CdcEvent::Data {
332                                        capture_instance: Arc::clone(instance),
333                                        lsn,
334                                        changes,
335                                    };
336                                }
337                            }
338                        }
339
340                        let ddl_history = crate::inspect::get_ddl_history(
341                            self.client, instance, instance_lsn, &db_max_lsn,
342                        ).await?;
343                        for (table, ddl_events) in ddl_history {
344                            for ddl_event in ddl_events {
345                                yield CdcEvent::SchemaUpdate {
346                                    capture_instance: Arc::clone(instance),
347                                    table: table.clone(),
348                                    ddl_event
349                                }
350                            }
351                        }
352                    }
353
354                    // Increment our LSN (`get_changes` is inclusive).
355                    //
356                    // TODO(sql_server2): We should occassionally check to see how close the LSN we
357                    // generate is to the LSN returned from incrementing via SQL Server itself.
358                    let next_lsn = db_max_lsn.increment();
359                    tracing::debug!(?curr_lsn, ?next_lsn, "incrementing LSN");
360
361                    // Notify our listener that we've emitted all changes __less than__ this LSN.
362                    //
363                    // Note: This aligns well with timely's semantics of progress tracking.
364                    yield CdcEvent::Progress { next_lsn };
365
366                    // We just listed everything upto next_lsn.
367                    for instance_lsn in self.capture_instances.values_mut() {
368                        let instance_lsn = instance_lsn.as_mut().expect("should be initialized");
369                        // Ensure LSNs don't go backwards.
370                        *instance_lsn = std::cmp::max(*instance_lsn, next_lsn);
371                    }
372                }
373
374                tokio::time::sleep_until(next_tick).await;
375            }
376        }
377    }
378
379    /// Determine the [`Lsn`] to start streaming changes from.
380    async fn initialize_start_lsns(&mut self) -> Result<(), SqlServerError> {
381        // First, initialize all start LSNs. If a capture instance didn't have
382        // one specified then we'll start from the current max.
383        let max_lsn = crate::inspect::get_max_lsn(self.client).await?;
384        for (_instance, requsted_lsn) in self.capture_instances.iter_mut() {
385            if requsted_lsn.is_none() {
386                requsted_lsn.replace(max_lsn);
387            }
388        }
389
390        // For each instance, ensure their requested LSN is available.
391        for (instance, requested_lsn) in self.capture_instances.iter() {
392            let requested_lsn = requested_lsn
393                .as_ref()
394                .expect("initialized all values above");
395
396            // Get the minimum Lsn available for this instance.
397            let available_lsn = crate::inspect::get_min_lsn(self.client, &*instance).await?;
398
399            // If we cannot start at our desired LSN, we must return an error!.
400            if *requested_lsn < available_lsn {
401                return Err(CdcError::LsnNotAvailable {
402                    capture_instance: Arc::clone(instance),
403                    requested: *requested_lsn,
404                    minimum: available_lsn,
405                }
406                .into());
407            }
408        }
409
410        Ok(())
411    }
412
413    /// If CDC was recently enabled on an instance of SQL Server then it will report
414    /// `NULL` for the minimum LSN of a capture instance and/or the maximum LSN of the
415    /// entire database.
416    ///
417    /// This method runs a retry loop that waits for the upstream DB to report good
418    /// values. It should be called before taking the initial [`CdcStream::snapshot`]
419    /// to ensure the system is ready to proceed with CDC.
420    pub async fn wait_for_ready(&mut self) -> Result<(), SqlServerError> {
421        // Ensure all of the capture instances are reporting an LSN.
422        for instance in self.capture_instances.keys() {
423            crate::inspect::get_min_lsn_retry(self.client, instance, self.max_lsn_wait).await?;
424        }
425
426        // Ensure the database is reporting a max LSN.
427        crate::inspect::get_max_lsn_retry(self.client, self.max_lsn_wait).await?;
428
429        Ok(())
430    }
431}
432
433/// A change event from a [`CdcStream`].
434#[derive(Derivative)]
435#[derivative(Debug)]
436pub enum CdcEvent {
437    /// Changes have occurred upstream.
438    Data {
439        /// The capture instance these changes are for.
440        capture_instance: Arc<str>,
441        /// The LSN that this change occurred at.
442        lsn: Lsn,
443        /// The change itself.
444        changes: Vec<Operation>,
445    },
446    /// We've made progress and observed all the changes less than `next_lsn`.
447    Progress {
448        /// We've received all of the data for [`Lsn`]s __less than__ this one.
449        next_lsn: Lsn,
450    },
451    /// DDL change has occured for the upstream table.
452    SchemaUpdate {
453        /// The capture instance.
454        capture_instance: Arc<str>,
455        /// The upstream table that was updated.
456        table: SqlServerQualifiedTableName,
457        /// DDL event
458        ddl_event: DDLEvent,
459    },
460}
461
462#[derive(Debug, thiserror::Error)]
463pub enum CdcError {
464    #[error(
465        "the requested LSN '{requested:?}' is less than the minimum '{minimum:?}' for `{capture_instance}'"
466    )]
467    LsnNotAvailable {
468        capture_instance: Arc<str>,
469        requested: Lsn,
470        minimum: Lsn,
471    },
472    #[error("failed to get the required column '{column_name}': {error}")]
473    RequiredColumn {
474        column_name: &'static str,
475        error: String,
476    },
477    #[error("failed to cleanup values for '{capture_instance}' at {low_water_mark}")]
478    CleanupFailed {
479        capture_instance: String,
480        low_water_mark: Lsn,
481    },
482}
483
484/// This type is used to represent the progress of each SQL Server instance in
485/// the ingestion dataflow.
486///
487/// A SQL Server LSN is a three part "number" that provides a __total order__
488/// to all transations within a database. Interally we don't really care what
489/// these parts mean, but they are:
490///
491/// 1. A Virtual Log File (VLF) sequence number, bytes [0, 4)
492/// 2. Log block number, bytes [4, 8)
493/// 3. Log record number, bytes [8, 10)
494///
495/// For more info on log sequence numbers in SQL Server see:
496/// <https://learn.microsoft.com/en-us/sql/relational-databases/sql-server-transaction-log-architecture-and-management-guide?view=sql-server-ver16#Logical_Arch>
497///
498/// Note: The derived impl of [`PartialOrd`] and [`Ord`] relies on the field
499/// ordering so do not change it.
500#[derive(
501    Default,
502    Copy,
503    Clone,
504    Debug,
505    Eq,
506    PartialEq,
507    PartialOrd,
508    Ord,
509    Hash,
510    Serialize,
511    Deserialize
512)]
513#[cfg_attr(any(test, feature = "proptest"), derive(Arbitrary))]
514pub struct Lsn {
515    /// Virtual Log File sequence number.
516    pub vlf_id: u32,
517    /// Log block number.
518    pub block_id: u32,
519    /// Log record number.
520    pub record_id: u16,
521}
522
523impl Lsn {
524    const SIZE: usize = 10;
525
526    /// Interpret the provided bytes as an [`Lsn`].
527    pub fn try_from_bytes(bytes: &[u8]) -> Result<Self, String> {
528        if bytes.len() != Self::SIZE {
529            return Err(format!("incorrect length, expected 10 got {}", bytes.len()));
530        }
531
532        let vlf_id: [u8; 4] = bytes[0..4].try_into().expect("known good length");
533        let block_id: [u8; 4] = bytes[4..8].try_into().expect("known good length");
534        let record_id: [u8; 2] = bytes[8..].try_into().expect("known good length");
535
536        Ok(Lsn {
537            vlf_id: u32::from_be_bytes(vlf_id),
538            block_id: u32::from_be_bytes(block_id),
539            record_id: u16::from_be_bytes(record_id),
540        })
541    }
542
543    /// Return the underlying byte slice for this [`Lsn`].
544    pub fn as_bytes(&self) -> [u8; 10] {
545        let mut raw: [u8; Self::SIZE] = [0; 10];
546
547        raw[0..4].copy_from_slice(&self.vlf_id.to_be_bytes());
548        raw[4..8].copy_from_slice(&self.block_id.to_be_bytes());
549        raw[8..].copy_from_slice(&self.record_id.to_be_bytes());
550
551        raw
552    }
553
554    /// Increment this [`Lsn`].
555    ///
556    /// The returned [`Lsn`] may not exist upstream yet, but it's guaranteed to
557    /// sort greater than `self`.
558    pub fn increment(self) -> Lsn {
559        let (record_id, carry) = self.record_id.overflowing_add(1);
560        let (block_id, carry) = self.block_id.overflowing_add(carry.into());
561        let (vlf_id, overflow) = self.vlf_id.overflowing_add(carry.into());
562        assert!(!overflow, "overflowed Lsn, {self:?}");
563
564        Lsn {
565            vlf_id,
566            block_id,
567            record_id,
568        }
569    }
570
571    /// Drops the `record_id` portion of the [`Lsn`] so we can fit an "abbreviation"
572    /// of this [`Lsn`] into a [`u64`], without losing the total order.
573    pub fn abbreviate(&self) -> u64 {
574        let mut abbreviated: u64 = 0;
575
576        #[allow(clippy::as_conversions)]
577        {
578            abbreviated += (self.vlf_id as u64) << 32;
579            abbreviated += self.block_id as u64;
580        }
581
582        abbreviated
583    }
584}
585
586impl fmt::Display for Lsn {
587    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
588        write!(f, "{}:{}:{}", self.vlf_id, self.block_id, self.record_id)
589    }
590}
591
592impl TryFrom<&[u8]> for Lsn {
593    type Error = String;
594
595    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
596        Lsn::try_from_bytes(value)
597    }
598}
599
600impl TryFrom<Numeric> for Lsn {
601    type Error = String;
602
603    fn try_from(value: Numeric) -> Result<Self, Self::Error> {
604        if value.dec_part() != 0 {
605            return Err(format!(
606                "LSN expect Numeric(25,0), but found decimal portion {}",
607                value.dec_part()
608            ));
609        }
610        let mut decimal_lsn = value.int_part();
611        // LSN is composed of 4 bytes : 4 bytes : 2 bytes
612        // and MS provided the method to decode that here
613        // https://github.com/microsoft/sql-server-samples/blob/master/samples/features/ssms-templates/Sql/Change%20Data%20Capture/Enumeration/Create%20Function%20fn_convertnumericlsntobinary.sql
614
615        let vlf_id = u32::try_from(decimal_lsn / 10_i128.pow(15))
616            .map_err(|e| format!("Failed to decode vlf_id for lsn {decimal_lsn}: {e:?}"))?;
617        decimal_lsn -= i128::from(vlf_id) * 10_i128.pow(15);
618
619        let block_id = u32::try_from(decimal_lsn / 10_i128.pow(5))
620            .map_err(|e| format!("Failed to decode block_id for lsn {decimal_lsn}: {e:?}"))?;
621        decimal_lsn -= i128::from(block_id) * 10_i128.pow(5);
622
623        let record_id = u16::try_from(decimal_lsn)
624            .map_err(|e| format!("Failed to decode record_id for lsn {decimal_lsn}: {e:?}"))?;
625
626        Ok(Lsn {
627            vlf_id,
628            block_id,
629            record_id,
630        })
631    }
632}
633
634impl columnation::Columnation for Lsn {
635    type InnerRegion = columnation::CopyRegion<Lsn>;
636}
637
638impl timely::progress::Timestamp for Lsn {
639    // No need to describe complex summaries.
640    type Summary = ();
641
642    fn minimum() -> Self {
643        Lsn::default()
644    }
645}
646
647impl timely::progress::PathSummary<Lsn> for () {
648    fn results_in(&self, src: &Lsn) -> Option<Lsn> {
649        Some(*src)
650    }
651
652    fn followed_by(&self, _other: &Self) -> Option<Self> {
653        Some(())
654    }
655}
656
657impl timely::progress::timestamp::Refines<()> for Lsn {
658    fn to_inner(_other: ()) -> Self {
659        use timely::progress::Timestamp;
660        Self::minimum()
661    }
662    fn to_outer(self) -> () {}
663
664    fn summarize(_path: <Self as timely::progress::Timestamp>::Summary) -> () {}
665}
666
667impl timely::order::PartialOrder for Lsn {
668    fn less_equal(&self, other: &Self) -> bool {
669        self <= other
670    }
671
672    fn less_than(&self, other: &Self) -> bool {
673        self < other
674    }
675}
676impl timely::order::TotalOrder for Lsn {}
677
678/// Structured format of an [`Lsn`].
679///
680/// Note: The derived impl of [`PartialOrd`] and [`Ord`] relies on the field
681/// ordering so do not change it.
682#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
683pub struct StructuredLsn {
684    vlf_id: u32,
685    block_id: u32,
686    record_id: u16,
687}
688
689/// When querying CDC functions like `cdc.fn_cdc_get_all_changes_<capture_instance>` this governs
690/// what content is returned.
691///
692/// Note: There exists another option `All` that exclude the _before_ value from an `UPDATE`. We
693/// don't support this for SQL Server sources yet, so it's not included in this enum.
694///
695/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/cdc-fn-cdc-get-all-changes-capture-instance-transact-sql?view=sql-server-ver16#row_filter_option>
696#[derive(Debug, Copy, Clone)]
697pub enum RowFilterOption {
698    /// Includes both the before and after values of a row when changed because of an `UPDATE`.
699    AllUpdateOld,
700}
701
702impl RowFilterOption {
703    /// Returns this option formatted in a way that can be used in a query.
704    pub fn to_sql_string(&self) -> &'static str {
705        match self {
706            RowFilterOption::AllUpdateOld => "all update old",
707        }
708    }
709}
710
711impl fmt::Display for RowFilterOption {
712    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
713        write!(f, "{}", self.to_sql_string())
714    }
715}
716
717/// Identifies what change was made to the SQL Server table tracked by CDC.
718#[derive(Debug)]
719pub enum Operation {
720    /// Row was `INSERT`-ed.
721    Insert(tiberius::Row),
722    /// Row was `DELETE`-ed.
723    Delete(tiberius::Row),
724    /// Original value of the row when `UPDATE`-ed.
725    UpdateOld(Lsn, tiberius::Row),
726    /// New value of the row when `UPDATE`-ed.
727    UpdateNew(Lsn, tiberius::Row),
728}
729
730impl Operation {
731    /// Parse the provided [`tiberius::Row`] to determine what [`Operation`] occurred.
732    ///
733    /// See <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/cdc-fn-cdc-get-all-changes-capture-instance-transact-sql?view=sql-server-ver16#table-returned>.
734    fn try_parse(data: tiberius::Row) -> Result<(Lsn, Self), SqlServerError> {
735        static START_LSN_COLUMN: &str = "__$start_lsn";
736        static OPERATION_COLUMN: &str = "__$operation";
737        static SEQVAL_COLUMN: &str = "__$seqval";
738
739        let lsn: &[u8] = data
740            .try_get(START_LSN_COLUMN)
741            .map_err(|e| CdcError::RequiredColumn {
742                column_name: START_LSN_COLUMN,
743                error: e.to_string(),
744            })?
745            .ok_or_else(|| CdcError::RequiredColumn {
746                column_name: START_LSN_COLUMN,
747                error: "got null value".to_string(),
748            })?;
749        let operation: i32 = data
750            .try_get(OPERATION_COLUMN)
751            .map_err(|e| CdcError::RequiredColumn {
752                column_name: OPERATION_COLUMN,
753                error: e.to_string(),
754            })?
755            .ok_or_else(|| CdcError::RequiredColumn {
756                column_name: OPERATION_COLUMN,
757                error: "got null value".to_string(),
758            })?;
759        let seqval: &[u8] = data
760            .try_get(SEQVAL_COLUMN)
761            .map_err(|e| CdcError::RequiredColumn {
762                column_name: SEQVAL_COLUMN,
763                error: e.to_string(),
764            })?
765            .ok_or_else(|| CdcError::RequiredColumn {
766                column_name: SEQVAL_COLUMN,
767                error: "got null value".to_string(),
768            })?;
769
770        let lsn = Lsn::try_from(lsn).map_err(|msg| SqlServerError::InvalidData {
771            column_name: START_LSN_COLUMN.to_string(),
772            error: msg,
773        })?;
774        let seqval = Lsn::try_from(seqval).map_err(|msg| SqlServerError::InvalidData {
775            column_name: SEQVAL_COLUMN.to_string(),
776            error: msg,
777        })?;
778
779        let operation = match operation {
780            1 => Operation::Delete(data),
781            2 => Operation::Insert(data),
782            3 => Operation::UpdateOld(seqval, data),
783            4 => Operation::UpdateNew(seqval, data),
784            other => {
785                return Err(SqlServerError::InvalidData {
786                    column_name: OPERATION_COLUMN.to_string(),
787                    error: format!("unrecognized operation {other}"),
788                });
789            }
790        };
791
792        Ok((lsn, operation))
793    }
794}
795
#[cfg(test)]
mod tests {
    use super::Lsn;
    use proptest::prelude::*;
    use tiberius::numeric::Numeric;

    /// Parses a hex-encoded binary LSN, panicking if either the hex or the LSN is malformed.
    fn lsn_from_hex(encoded: &str) -> Lsn {
        let raw = hex::decode(encoded).unwrap();
        Lsn::try_from(&raw[..]).unwrap()
    }

    #[mz_ore::test]
    fn smoketest_lsn_ordering() {
        let [a, b, c] = [
            "0000003D000019B80004",
            "0000003D000019F00011",
            "0000003D00001A500003",
        ]
        .map(lsn_from_hex);

        // The inputs are listed in ascending order...
        assert!(a < b);
        assert!(b < c);
        assert!(a < c);

        // ...and every LSN compares equal to itself.
        for lsn in [a, b, c] {
            assert_eq!(lsn, lsn);
        }
    }

    #[mz_ore::test]
    fn smoketest_lsn_roundtrips() {
        let samples = [
            "0000003D000019B80004",
            "0000003D000019F00011",
            "0000003D00001A500003",
        ];
        for encoded_hex in samples {
            let original = hex::decode(encoded_hex).unwrap();
            let parsed = Lsn::try_from(&original[..]).unwrap();
            let reencoded = parsed.as_bytes();
            assert_eq!(original[..], reencoded[..]);
        }
    }

    #[mz_ore::test]
    fn proptest_lsn_roundtrips() {
        // Any 10 bytes must decode to an LSN that re-encodes to the same bytes.
        proptest!(|(raw in any::<[u8; 10]>())| {
            let parsed = Lsn::try_from_bytes(&raw[..]).unwrap();
            let reencoded = parsed.as_bytes();
            assert_eq!(&raw[..], &reencoded[..]);
        })
    }

    #[mz_ore::test]
    fn proptest_lsn_increment() {
        // Incrementing must always produce a strictly larger LSN.
        proptest!(|(raw in any::<[u8; 10]>())| {
            let before = Lsn::try_from_bytes(&raw[..]).unwrap();
            let after = before.increment();
            assert!(before < after);
        })
    }

    #[mz_ore::test]
    fn proptest_lsn_abbreviate_total_order() {
        // Abbreviation must preserve ordering across any number of increments.
        proptest!(|(raw in any::<[u8; 10]>(), steps in any::<u8>())| {
            let start = Lsn::try_from_bytes(&raw[..]).unwrap();
            let end = (0..steps).fold(start, |current, _| current.increment());

            let start_abbrev = start.abbreviate();
            let end_abbrev = end.abbreviate();
            assert!(start_abbrev <= end_abbrev);
        })
    }

    #[mz_ore::test]
    fn test_numeric_lsn_ordering() {
        let [a, b, c, d] = [
            45_0000008784_00001_i128,
            45_0000008784_00002_i128,
            45_0000008785_00002_i128,
            49_0000008784_00002_i128,
        ]
        .map(|value| Lsn::try_from(Numeric::new_with_scale(value, 0)).unwrap());

        // Ordering is driven by the record, then block, then VLF component, as encoded above.
        assert!(a < b);
        assert!(b < c);
        assert!(c < d);
        assert!(a < d);

        for lsn in [a, b, c, d] {
            assert_eq!(lsn, lsn);
        }
    }

    #[mz_ore::test]
    fn test_numeric_lsn_invalid() {
        // A non-zero scale (i.e. a fractional value) can never be an LSN.
        let fractional = Numeric::new_with_scale(1, 20);
        assert!(Lsn::try_from(fractional).is_err());

        let out_of_range = [
            4294967296_0000000000_00000_i128, // vlf_id is too large
            1_4294967296_00000_i128,          // block_id is too large
            1_0000000001_65536_i128,          // record_id is too large
            -49_0000008784_00002_i128,        // negative is invalid
        ];
        for value in out_of_range {
            let candidate = Numeric::new_with_scale(value, 0);
            assert!(Lsn::try_from(candidate).is_err());
        }
    }
}