mz_sql_server_util/
cdc.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Replicate a table from SQL Server using their Change-Data-Capture (CDC) primitives.
//!
//! This module provides a [`CdcStream`] type that provides the following API for
//! replicating a table:
//!
//! 1. [`CdcStream::snapshot`] returns an initial snapshot of a table and the [`Lsn`] at
//!    which the snapshot was taken.
//! 2. [`CdcStream::into_stream`] returns a [`futures::Stream`] of [`CdcEvent`]s,
//!    optionally from the [`Lsn`] returned in step 1.
//!
//! The snapshot process is responsible for identifying an [`Lsn`] that corresponds to
//! a point-in-time view of the data for the table(s) being copied. Similar to
//! MySQL, Microsoft SQL Server, as far as we know, does not provide an API to
//! achieve this.
//!
//! SQL Server `SNAPSHOT` isolation guarantees that a reader will only
//! see writes committed before the transaction began. More specifically, this
//! snapshot is implemented using versions that are visible based on the
//! transaction sequence number (`XSN`). The `XSN` is set at the first
//! read or write, not at `BEGIN TRANSACTION`, see [here](https://learn.microsoft.com/en-us/sql/relational-databases/sql-server-transaction-locking-and-row-versioning-guide?view=sql-server-ver17).
//! This provides us a suitable starting point for capturing the table data.
//! To force an `XSN` to be assigned, experiments have shown that a table must
//! be read. We choose a well-known table that we should already have access to,
//! [cdc.change_tables](https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/cdc-change-tables-transact-sql?view=sql-server-ver17),
//! and read a single value from it.
//!
//! Due to the asynchronous nature of CDC, we can assume that the [`Lsn`]
//! returned from any CDC tables or CDC functions will always be stale
//! in relation to the source table that CDC is tracking. The system table
//! [sys.dm_tran_database_transactions](https://learn.microsoft.com/en-us/sql/relational-databases/system-dynamic-management-views/sys-dm-tran-database-transactions-transact-sql?view=sql-server-ver17)
//! will contain an [`Lsn`] for any transaction that performs a write operation.
//! Creating a savepoint using [SAVE TRANSACTION](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/save-transaction-transact-sql?view=sql-server-ver17)
//! is sufficient to generate an [`Lsn`] in this case.
//!
//! To ensure that the point-in-time view is established atomically with
//! collection of the [`Lsn`], we lock the tables to prevent writes from being
//! interleaved between the two commands (the read to establish the `XSN` and
//! the creation of the savepoint).
//!
//! SQL Server supports table locks, but those will only be released
//! once the outermost transaction completes. For this reason, this module
//! uses two connections for the snapshot process. The first connection is used
//! to initiate a transaction and lock the upstream tables under
//! [`TransactionIsolationLevel::ReadCommitted`] isolation. While the first
//! connection maintains the locks, the second connection starts a
//! transaction with [`TransactionIsolationLevel::Snapshot`] isolation and
//! creates a savepoint. Once the savepoint is created, SQL Server has assigned
//! an [`Lsn`] and the first connection rolls back its transaction.
//! The [`Lsn`] and snapshot are captured by the second connection within the
//! existing transaction.
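//!
//! Roughly, the sequence of operations is as follows, where connection A is the
//! fencing connection and connection B takes the snapshot. This is an
//! illustrative summary of what [`CdcStream::snapshot`] does, not the literal
//! SQL that gets issued:
//!
//! ```text
//! A: BEGIN TRANSACTION (READ COMMITTED)
//! A: <acquire a shared lock on the upstream table(s)>
//! B: BEGIN TRANSACTION (SNAPSHOT)
//! B: SELECT TOP 1 object_id FROM cdc.change_tables    -- assigns the XSN
//! B: SAVE TRANSACTION _mz_snap_                        -- forces an LSN to be assigned
//! A: ROLLBACK                                          -- releases the locks
//! B: <read the LSN from sys.dm_tran_database_transactions>
//! B: <copy the table contents>                         -- the snapshot itself
//! B: ROLLBACK
//! ```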
//!
//! After completing the snapshot we use [`crate::inspect::get_changes_asc`], which
//! will return all changes between a `[lower, upper)` bound of [`Lsn`]s.
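//!
//! A rough sketch of how these pieces fit together is shown below. This is not
//! a compilable doctest: it assumes an already connected [`Client`] named
//! `client`, a capture instance named `dbo_my_table`, and `table_desc`,
//! `worker_id`, and `source_id` values appropriate for the source being
//! replicated. The stream is constructed here with the crate-internal
//! [`CdcStream::new`].
//!
//! ```ignore
//! let instance: Arc<str> = "dbo_my_table".into();
//! let mut cdc = CdcStream::new(&mut client, BTreeMap::from([(Arc::clone(&instance), None)]));
//! cdc.wait_for_ready().await?;
//!
//! // 1. Take a snapshot and remember the LSN it was taken at.
//! let snapshot_lsn = {
//!     let (lsn, _size, snapshot) = cdc.snapshot(&table_desc, worker_id, source_id).await?;
//!     futures::pin_mut!(snapshot);
//!     while let Some(row) = snapshot.next().await {
//!         let _row = row?;
//!         // ... decode and stage the snapshot row ...
//!     }
//!     lsn
//! };
//!
//! // 2. Stream changes made at or after the snapshot LSN.
//! let events = cdc.start_lsn(&instance, snapshot_lsn).into_stream();
//! futures::pin_mut!(events);
//! while let Some(event) = events.next().await {
//!     match event? {
//!         CdcEvent::Data { lsn, changes, .. } => { /* apply `changes` at `lsn` */ }
//!         CdcEvent::Progress { next_lsn } => { /* all data below `next_lsn` has been seen */ }
//!     }
//! }
//! ```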

use std::collections::BTreeMap;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;

use derivative::Derivative;
use futures::{Stream, StreamExt};
use mz_repr::GlobalId;
use proptest_derive::Arbitrary;
use serde::{Deserialize, Serialize};
use tiberius::numeric::Numeric;

use crate::desc::SqlServerTableRaw;
use crate::{Client, SqlServerError, TransactionIsolationLevel};

/// A stream of changes from a table in SQL Server that has CDC enabled.
///
/// SQL Server does not have an API to push or notify consumers of changes, so we periodically
/// poll the upstream source.
///
/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-tables/change-data-capture-tables-transact-sql?view=sql-server-ver16>
pub struct CdcStream<'a> {
    /// Client we use for querying SQL Server.
    client: &'a mut Client,
    /// Upstream capture instances we'll list changes from.
    capture_instances: BTreeMap<Arc<str>, Option<Lsn>>,
    /// How often we poll the upstream for changes.
    poll_interval: Duration,
    /// How long we'll wait for SQL Server to return a max LSN before taking a snapshot.
    ///
    /// Note: When CDC is first enabled in an instance of SQL Server it can take a moment
    /// for it to "completely" start up. Before starting a `TRANSACTION` for our snapshot
    /// we'll wait this duration for SQL Server to report an [`Lsn`] and thus indicate CDC is
    /// ready to go.
    max_lsn_wait: Duration,
}

impl<'a> CdcStream<'a> {
    pub(crate) fn new(
        client: &'a mut Client,
        capture_instances: BTreeMap<Arc<str>, Option<Lsn>>,
    ) -> Self {
        CdcStream {
            client,
            capture_instances,
            poll_interval: Duration::from_secs(1),
            max_lsn_wait: Duration::from_secs(10),
        }
    }

    /// Set the [`Lsn`] that we should start streaming changes from.
    ///
    /// If the provided [`Lsn`] is not available, the stream will return an error
    /// when first polled.
    pub fn start_lsn(mut self, capture_instance: &str, lsn: Lsn) -> Self {
        let start_lsn = self
            .capture_instances
            .get_mut(capture_instance)
            .expect("capture instance does not exist");
        *start_lsn = Some(lsn);
        self
    }

    /// The cadence at which we'll poll the upstream SQL Server database for changes.
    ///
    /// Default is 1 second.
    pub fn poll_interval(mut self, interval: Duration) -> Self {
        self.poll_interval = interval;
        self
    }

    /// The max duration we'll wait for SQL Server to return an [`Lsn`] before taking a
    /// snapshot.
    ///
    /// When CDC is first enabled in SQL Server it can take a moment before it is fully
    /// set up and starts reporting LSNs.
    ///
    /// Default is 10 seconds.
    pub fn max_lsn_wait(mut self, wait: Duration) -> Self {
        self.max_lsn_wait = wait;
        self
    }

    /// Takes a snapshot of the upstream table that the specified `table` represents.
    pub async fn snapshot<'b>(
        &'b mut self,
        table: &SqlServerTableRaw,
        worker_id: usize,
        source_id: GlobalId,
    ) -> Result<
        (
            Lsn,
            usize,
            impl Stream<Item = Result<tiberius::Row, SqlServerError>>,
        ),
        SqlServerError,
    > {
        static SAVEPOINT_NAME: &str = "_mz_snap_";

        // The client that will be used for fencing does not need any special isolation level
        // as it will just be locking the table(s).
        let mut fencing_client = self.client.new_connection().await?;
        let mut fence_txn = fencing_client.transaction().await?;
        fence_txn
            .lock_table_shared(&table.schema_name, &table.name)
            .await?;
        tracing::info!(%source_id, %table.schema_name, %table.name, "timely-{worker_id} locked table");

        self.client
            .set_transaction_isolation(TransactionIsolationLevel::Snapshot)
            .await?;
        let mut txn = self.client.transaction().await?;
        // Creating a savepoint forces a write to the transaction log, which will
        // assign an LSN, but it does not force a transaction sequence number to be
        // assigned as far as I can tell. I have not observed any entries added to
        // `sys.dm_tran_active_snapshot_database_transactions` when creating a savepoint
        // or when reading system views to retrieve the LSN.
        //
        // We choose cdc.change_tables because it is a system table that will exist
        // when CDC is enabled, it has a well known schema, and as a CDC client,
        // we should be able to read from it already.
        let res = txn
            .simple_query("SELECT TOP 1 object_id FROM cdc.change_tables")
            .await?;
        if res.len() != 1 {
            Err(SqlServerError::InvariantViolated(
                "No objects found in cdc.change_tables".into(),
            ))?
        }

        // Because the table is locked, any write operation has either
        // completed or is blocked. The LSN and XSN acquired now will represent a
        // consistent point-in-time view, such that any committed write will be
        // visible to this snapshot and the LSN of such a write will be less than
        // or equal to the LSN captured here. Creating the savepoint sets the LSN,
        // and we can read it after rolling back the locks.
        txn.create_savepoint(SAVEPOINT_NAME).await?;
        tracing::info!(%source_id, %table.schema_name, %table.name, %SAVEPOINT_NAME, "timely-{worker_id} created savepoint");

        // Once the savepoint is created (which establishes the XSN and captures the LSN),
        // the table no longer needs to be locked. Any writes that happen to the upstream table
        // will have an LSN higher than our captured LSN, and will be read from CDC.
        fence_txn.rollback().await?;

        let lsn = txn.get_lsn().await?;

        tracing::info!(%source_id, ?lsn, "timely-{worker_id} starting snapshot");

        tracing::trace!(%source_id, %table.capture_instance.name, %table.schema_name, %table.name, "timely-{worker_id} snapshot stats start");
        let size =
            crate::inspect::snapshot_size(txn.client, &table.schema_name, &table.name).await?;
        tracing::trace!(%source_id, %table.capture_instance.name, %table.schema_name, %table.name, "timely-{worker_id} snapshot stats end");
        let schema_name = &*table.schema_name;
        let table_name = &*table.name;
        let rows = async_stream::try_stream! {
            {
                let snapshot_stream = crate::inspect::snapshot(txn.client, &*schema_name, &*table_name);
                tokio::pin!(snapshot_stream);

                while let Some(row) = snapshot_stream.next().await {
                    yield row?;
                }
            }

            txn.rollback().await?
        };

        Ok((lsn, size, rows))
    }

    /// Consume `self` returning a [`Stream`] of [`CdcEvent`]s.
    pub fn into_stream(mut self) -> impl Stream<Item = Result<CdcEvent, SqlServerError>> + use<'a> {
        async_stream::try_stream! {
            // Initialize all of our start LSNs.
            self.initialize_start_lsns().await?;

            // When starting the stream we'll emit one progress event if we've already observed
            // everything the DB currently has.
            if let Some(starting_lsn) = self.capture_instances.values().filter_map(|x| *x).min() {
                let db_curr_lsn = crate::inspect::get_max_lsn(self.client).await?;
                let next_lsn = db_curr_lsn.increment();
                if starting_lsn >= db_curr_lsn {
                    tracing::debug!(
                        %starting_lsn,
                        %db_curr_lsn,
                        %next_lsn,
                        "yielding initial progress",
                    );
                    yield CdcEvent::Progress { next_lsn };
                }
            }

            loop {
                // Measure the tick before we do any operation so the time it takes
                // to query SQL Server is included in the time that we wait.
                let next_tick = tokio::time::Instant::now()
                    .checked_add(self.poll_interval)
                    .expect("tick interval overflowed!");

                // We always check for changes based on the "global" minimum LSN of any
                // one capture instance.
                let maybe_curr_lsn = self.capture_instances.values().filter_map(|x| *x).min();
                let Some(curr_lsn) = maybe_curr_lsn else {
                    tracing::warn!("shutting down CDC stream because nothing to replicate");
                    break;
                };

                // Get the max LSN for the DB.
                let db_max_lsn = crate::inspect::get_max_lsn(self.client).await?;
                tracing::debug!(?db_max_lsn, ?curr_lsn, "got max LSN");

                // If the LSN of the DB has increased then get all of our changes.
                if db_max_lsn > curr_lsn {
                    for (instance, instance_lsn) in &self.capture_instances {
                        let Some(instance_lsn) = instance_lsn.as_ref() else {
                            tracing::error!(?instance, "found uninitialized LSN!");
                            continue;
                        };

                        // We've already replicated everything up to db_max_lsn, so
                        // there is nothing to do.
                        if db_max_lsn < *instance_lsn {
                            continue;
                        }

                        // Get a stream of all the changes for the current instance.
                        let changes = crate::inspect::get_changes_asc(
                            self.client,
                            &*instance,
                            *instance_lsn,
                            db_max_lsn,
                            RowFilterOption::AllUpdateOld,
                        )
                        // TODO(sql_server3): Make this chunk size configurable.
                        .ready_chunks(64);
                        let mut changes = std::pin::pin!(changes);

                        // Map and stream all the rows to our listener.
                        while let Some(chunk) = changes.next().await {
                            // Group events by LSN.
                            //
                            // TODO(sql_server3): Can we maybe re-use this BTreeMap or these Vec
                            // allocations? Something to be careful of is shrinking the allocations
                            // if/when they grow too large, e.g. from a large spike of changes.
                            // Alternatively we could also use a single Vec here since we know the
                            // changes are ordered by LSN.
                            let mut events: BTreeMap<Lsn, Vec<Operation>> = BTreeMap::default();
                            for change in chunk {
                                let (lsn, operation) = change.and_then(Operation::try_parse)?;
                                events.entry(lsn).or_default().push(operation);
                            }

                            // Emit the groups of events.
                            for (lsn, changes) in events {
                                yield CdcEvent::Data {
                                    capture_instance: Arc::clone(instance),
                                    lsn,
                                    changes,
                                };
                            }
                        }
                    }

                    // Increment our LSN (`get_changes` is inclusive).
                    //
                    // TODO(sql_server2): We should occasionally check to see how close the LSN we
                    // generate is to the LSN returned from incrementing via SQL Server itself.
                    let next_lsn = db_max_lsn.increment();
                    tracing::debug!(?curr_lsn, ?next_lsn, "incrementing LSN");

                    // Notify our listener that we've emitted all changes __less than__ this LSN.
                    //
                    // Note: This aligns well with timely's semantics of progress tracking.
                    yield CdcEvent::Progress { next_lsn };

                    // We just listed everything up to next_lsn.
                    for instance_lsn in self.capture_instances.values_mut() {
                        let instance_lsn = instance_lsn.as_mut().expect("should be initialized");
                        // Ensure LSNs don't go backwards.
                        *instance_lsn = std::cmp::max(*instance_lsn, next_lsn);
                    }
                }

                tokio::time::sleep_until(next_tick).await;
            }
        }
    }

    /// Determine the [`Lsn`] to start streaming changes from.
    async fn initialize_start_lsns(&mut self) -> Result<(), SqlServerError> {
        // First, initialize all start LSNs. If a capture instance didn't have
        // one specified then we'll start from the current max.
        let max_lsn = crate::inspect::get_max_lsn(self.client).await?;
        for (_instance, requested_lsn) in self.capture_instances.iter_mut() {
            if requested_lsn.is_none() {
                requested_lsn.replace(max_lsn);
            }
        }

        // For each instance, ensure their requested LSN is available.
        for (instance, requested_lsn) in self.capture_instances.iter() {
            let requested_lsn = requested_lsn
                .as_ref()
                .expect("initialized all values above");

            // Get the minimum Lsn available for this instance.
            let available_lsn = crate::inspect::get_min_lsn(self.client, &*instance).await?;

            // If we cannot start at our desired LSN, we must return an error.
            if *requested_lsn < available_lsn {
                return Err(CdcError::LsnNotAvailable {
                    capture_instance: Arc::clone(instance),
                    requested: *requested_lsn,
                    minimum: available_lsn,
                }
                .into());
            }
        }

        Ok(())
    }

    /// If CDC was recently enabled on an instance of SQL Server then it will report
    /// `NULL` for the minimum LSN of a capture instance and/or the maximum LSN of the
    /// entire database.
    ///
    /// This method runs a retry loop that waits for the upstream DB to report good
    /// values. It should be called before taking the initial [`CdcStream::snapshot`]
    /// to ensure the system is ready to proceed with CDC.
    pub async fn wait_for_ready(&mut self) -> Result<(), SqlServerError> {
        // Ensure all of the capture instances are reporting an LSN.
        for instance in self.capture_instances.keys() {
            crate::inspect::get_min_lsn_retry(self.client, instance, self.max_lsn_wait).await?;
        }

        // Ensure the database is reporting a max LSN.
        crate::inspect::get_max_lsn_retry(self.client, self.max_lsn_wait).await?;

        Ok(())
    }
}

/// A change event from a [`CdcStream`].
#[derive(Derivative)]
#[derivative(Debug)]
pub enum CdcEvent {
    /// Changes have occurred upstream.
    Data {
        /// The capture instance these changes are for.
        capture_instance: Arc<str>,
        /// The LSN that this change occurred at.
        lsn: Lsn,
        /// The changes themselves.
        changes: Vec<Operation>,
    },
    /// We've made progress and observed all the changes less than `next_lsn`.
    Progress {
        /// We've received all of the data for [`Lsn`]s __less than__ this one.
        next_lsn: Lsn,
    },
}

#[derive(Debug, thiserror::Error)]
pub enum CdcError {
    #[error(
        "the requested LSN '{requested:?}' is less than the minimum '{minimum:?}' for '{capture_instance}'"
    )]
    LsnNotAvailable {
        capture_instance: Arc<str>,
        requested: Lsn,
        minimum: Lsn,
    },
    #[error("failed to get the required column '{column_name}': {error}")]
    RequiredColumn {
        column_name: &'static str,
        error: String,
    },
    #[error("failed to cleanup values for '{capture_instance}' at {low_water_mark}")]
    CleanupFailed {
        capture_instance: String,
        low_water_mark: Lsn,
    },
}

/// This type is used to represent the progress of each SQL Server instance in
/// the ingestion dataflow.
///
/// A SQL Server LSN is a three-part "number" that provides a __total order__
/// to all transactions within a database. Internally we don't really care what
/// these parts mean, but they are:
///
/// 1. A Virtual Log File (VLF) sequence number, bytes [0, 4)
/// 2. Log block number, bytes [4, 8)
/// 3. Log record number, bytes [8, 10)
///
/// For more info on log sequence numbers in SQL Server see:
/// <https://learn.microsoft.com/en-us/sql/relational-databases/sql-server-transaction-log-architecture-and-management-guide?view=sql-server-ver16#Logical_Arch>
///
/// Note: The derived impl of [`PartialOrd`] and [`Ord`] relies on the field
/// ordering so do not change it.
#[derive(
    Default,
    Copy,
    Clone,
    Debug,
    Eq,
    PartialEq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    Arbitrary,
)]
pub struct Lsn {
    /// Virtual Log File sequence number.
    pub vlf_id: u32,
    /// Log block number.
    pub block_id: u32,
    /// Log record number.
    pub record_id: u16,
}

impl Lsn {
    const SIZE: usize = 10;

    /// Interpret the provided bytes as an [`Lsn`].
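    ///
    /// The 10 bytes are interpreted big-endian as `vlf_id` (4 bytes), `block_id`
    /// (4 bytes), and `record_id` (2 bytes). For example, using one of the byte
    /// strings from the tests below:
    ///
    /// ```ignore
    /// let raw = hex::decode("0000003D000019B80004").unwrap();
    /// let lsn = Lsn::try_from_bytes(&raw).unwrap();
    /// assert_eq!(lsn, Lsn { vlf_id: 0x3D, block_id: 0x19B8, record_id: 0x4 });
    /// ```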
    pub fn try_from_bytes(bytes: &[u8]) -> Result<Self, String> {
        if bytes.len() != Self::SIZE {
            return Err(format!("incorrect length, expected 10 got {}", bytes.len()));
        }

        let vlf_id: [u8; 4] = bytes[0..4].try_into().expect("known good length");
        let block_id: [u8; 4] = bytes[4..8].try_into().expect("known good length");
        let record_id: [u8; 2] = bytes[8..].try_into().expect("known good length");

        Ok(Lsn {
            vlf_id: u32::from_be_bytes(vlf_id),
            block_id: u32::from_be_bytes(block_id),
            record_id: u16::from_be_bytes(record_id),
        })
    }

    /// Return the underlying bytes for this [`Lsn`].
    pub fn as_bytes(&self) -> [u8; 10] {
        let mut raw: [u8; Self::SIZE] = [0; 10];

        raw[0..4].copy_from_slice(&self.vlf_id.to_be_bytes());
        raw[4..8].copy_from_slice(&self.block_id.to_be_bytes());
        raw[8..].copy_from_slice(&self.record_id.to_be_bytes());

        raw
    }

    /// Increment this [`Lsn`].
    ///
    /// The returned [`Lsn`] may not exist upstream yet, but it's guaranteed to
    /// sort greater than `self`.
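    ///
    /// A small illustration of the carry behavior, with made-up values:
    ///
    /// ```ignore
    /// let lsn = Lsn { vlf_id: 1, block_id: 2, record_id: u16::MAX };
    /// let next = lsn.increment();
    /// assert_eq!(next, Lsn { vlf_id: 1, block_id: 3, record_id: 0 });
    /// assert!(lsn < next);
    /// ```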
    pub fn increment(self) -> Lsn {
        let (record_id, carry) = self.record_id.overflowing_add(1);
        let (block_id, carry) = self.block_id.overflowing_add(carry.into());
        let (vlf_id, overflow) = self.vlf_id.overflowing_add(carry.into());
        assert!(!overflow, "overflowed Lsn, {self:?}");

        Lsn {
            vlf_id,
            block_id,
            record_id,
        }
    }

    /// Drops the `record_id` portion of the [`Lsn`] so we can fit an "abbreviation"
    /// of this [`Lsn`] into a [`u64`], without losing the total order.
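    ///
    /// For example (made-up values), two [`Lsn`]s that differ only in `record_id`
    /// abbreviate to the same value, so only the `<=` relation is preserved:
    ///
    /// ```ignore
    /// let a = Lsn { vlf_id: 1, block_id: 2, record_id: 3 };
    /// let b = Lsn { vlf_id: 1, block_id: 2, record_id: 9 };
    /// let c = Lsn { vlf_id: 1, block_id: 3, record_id: 0 };
    /// assert_eq!(a.abbreviate(), b.abbreviate());
    /// assert!(b.abbreviate() < c.abbreviate());
    /// ```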
    pub fn abbreviate(&self) -> u64 {
        let mut abbreviated: u64 = 0;

        #[allow(clippy::as_conversions)]
        {
            abbreviated += (self.vlf_id as u64) << 32;
            abbreviated += self.block_id as u64;
        }

        abbreviated
    }
}

impl fmt::Display for Lsn {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}:{}:{}", self.vlf_id, self.block_id, self.record_id)
    }
}

impl TryFrom<&[u8]> for Lsn {
    type Error = String;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        Lsn::try_from_bytes(value)
    }
}

impl TryFrom<Numeric> for Lsn {
    type Error = String;

    fn try_from(value: Numeric) -> Result<Self, Self::Error> {
        if value.dec_part() != 0 {
            return Err(format!(
                "LSN expects Numeric(25,0), but found decimal portion {}",
                value.dec_part()
            ));
        }
        let mut decimal_lsn = value.int_part();
        // LSN is composed of 4 bytes : 4 bytes : 2 bytes
        // and MS provided the method to decode that here
        // https://github.com/microsoft/sql-server-samples/blob/master/samples/features/ssms-templates/Sql/Change%20Data%20Capture/Enumeration/Create%20Function%20fn_convertnumericlsntobinary.sql
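        //
        // For example, the decimal value 45_0000008784_00001 (one of the values used
        // in the tests below) decodes to vlf_id = 45, block_id = 8784, record_id = 1.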

        let vlf_id = u32::try_from(decimal_lsn / 10_i128.pow(15))
            .map_err(|e| format!("Failed to decode vlf_id for lsn {decimal_lsn}: {e:?}"))?;
        decimal_lsn -= i128::from(vlf_id) * 10_i128.pow(15);

        let block_id = u32::try_from(decimal_lsn / 10_i128.pow(5))
            .map_err(|e| format!("Failed to decode block_id for lsn {decimal_lsn}: {e:?}"))?;
        decimal_lsn -= i128::from(block_id) * 10_i128.pow(5);

        let record_id = u16::try_from(decimal_lsn)
            .map_err(|e| format!("Failed to decode record_id for lsn {decimal_lsn}: {e:?}"))?;

        Ok(Lsn {
            vlf_id,
            block_id,
            record_id,
        })
    }
}

impl columnation::Columnation for Lsn {
    type InnerRegion = columnation::CopyRegion<Lsn>;
}

impl timely::progress::Timestamp for Lsn {
    // No need to describe complex summaries.
    type Summary = ();

    fn minimum() -> Self {
        Lsn::default()
    }
}

impl timely::progress::PathSummary<Lsn> for () {
    fn results_in(&self, src: &Lsn) -> Option<Lsn> {
        Some(*src)
    }

    fn followed_by(&self, _other: &Self) -> Option<Self> {
        Some(())
    }
}

impl timely::progress::timestamp::Refines<()> for Lsn {
    fn to_inner(_other: ()) -> Self {
        use timely::progress::Timestamp;
        Self::minimum()
    }
    fn to_outer(self) -> () {}

    fn summarize(_path: <Self as timely::progress::Timestamp>::Summary) -> () {}
}

impl timely::order::PartialOrder for Lsn {
    fn less_equal(&self, other: &Self) -> bool {
        self <= other
    }

    fn less_than(&self, other: &Self) -> bool {
        self < other
    }
}
impl timely::order::TotalOrder for Lsn {}

/// Structured format of an [`Lsn`].
///
/// Note: The derived impl of [`PartialOrd`] and [`Ord`] relies on the field
/// ordering so do not change it.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct StructuredLsn {
    vlf_id: u32,
    block_id: u32,
    record_id: u16,
}

/// When querying CDC functions like `cdc.fn_cdc_get_all_changes_<capture_instance>` this governs
/// what content is returned.
///
/// Note: There exists another option `All` that excludes the _before_ value from an `UPDATE`. We
/// don't support this for SQL Server sources yet, so it's not included in this enum.
///
/// See: <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/cdc-fn-cdc-get-all-changes-capture-instance-transact-sql?view=sql-server-ver16#row_filter_option>
#[derive(Debug, Copy, Clone)]
pub enum RowFilterOption {
    /// Includes both the before and after values of a row when changed because of an `UPDATE`.
    AllUpdateOld,
}

impl RowFilterOption {
    /// Returns this option formatted in a way that can be used in a query.
    pub fn to_sql_string(&self) -> &'static str {
        match self {
            RowFilterOption::AllUpdateOld => "all update old",
        }
    }
}

impl fmt::Display for RowFilterOption {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.to_sql_string())
    }
}

/// Identifies what change was made to the SQL Server table tracked by CDC.
#[derive(Debug)]
pub enum Operation {
    /// Row was `INSERT`-ed.
    Insert(tiberius::Row),
    /// Row was `DELETE`-ed.
    Delete(tiberius::Row),
    /// Original value of the row when `UPDATE`-ed.
    UpdateOld(tiberius::Row),
    /// New value of the row when `UPDATE`-ed.
    UpdateNew(tiberius::Row),
}

impl Operation {
    /// Parse the provided [`tiberius::Row`] to determine what [`Operation`] occurred.
    ///
    /// See <https://learn.microsoft.com/en-us/sql/relational-databases/system-functions/cdc-fn-cdc-get-all-changes-capture-instance-transact-sql?view=sql-server-ver16#table-returned>.
    fn try_parse(data: tiberius::Row) -> Result<(Lsn, Self), SqlServerError> {
        static START_LSN_COLUMN: &str = "__$start_lsn";
        static OPERATION_COLUMN: &str = "__$operation";

        let lsn: &[u8] = data
            .try_get(START_LSN_COLUMN)
            .map_err(|e| CdcError::RequiredColumn {
                column_name: START_LSN_COLUMN,
                error: e.to_string(),
            })?
            .ok_or_else(|| CdcError::RequiredColumn {
                column_name: START_LSN_COLUMN,
                error: "got null value".to_string(),
            })?;
        let operation: i32 = data
            .try_get(OPERATION_COLUMN)
            .map_err(|e| CdcError::RequiredColumn {
                column_name: OPERATION_COLUMN,
                error: e.to_string(),
            })?
            .ok_or_else(|| CdcError::RequiredColumn {
                column_name: OPERATION_COLUMN,
                error: "got null value".to_string(),
            })?;

        let lsn = Lsn::try_from(lsn).map_err(|msg| SqlServerError::InvalidData {
            column_name: START_LSN_COLUMN.to_string(),
            error: msg,
        })?;
        let operation = match operation {
            1 => Operation::Delete(data),
            2 => Operation::Insert(data),
            3 => Operation::UpdateOld(data),
            4 => Operation::UpdateNew(data),
            other => {
                return Err(SqlServerError::InvalidData {
                    column_name: OPERATION_COLUMN.to_string(),
                    error: format!("unrecognized operation {other}"),
                });
            }
        };

        Ok((lsn, operation))
    }
}

#[cfg(test)]
mod tests {
    use super::Lsn;
    use proptest::prelude::*;
    use tiberius::numeric::Numeric;

    #[mz_ore::test]
    fn smoketest_lsn_ordering() {
        let a = hex::decode("0000003D000019B80004").unwrap();
        let a = Lsn::try_from(&a[..]).unwrap();

        let b = hex::decode("0000003D000019F00011").unwrap();
        let b = Lsn::try_from(&b[..]).unwrap();

        let c = hex::decode("0000003D00001A500003").unwrap();
        let c = Lsn::try_from(&c[..]).unwrap();

        assert!(a < b);
        assert!(b < c);
        assert!(a < c);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);
    }

    #[mz_ore::test]
    fn smoketest_lsn_roundtrips() {
        #[track_caller]
        fn test_case(hex: &str) {
            let og = hex::decode(hex).unwrap();
            let lsn = Lsn::try_from(&og[..]).unwrap();
            let rnd = lsn.as_bytes();
            assert_eq!(og[..], rnd[..]);
        }

        test_case("0000003D000019B80004");
        test_case("0000003D000019F00011");
        test_case("0000003D00001A500003");
    }

    #[mz_ore::test]
    fn proptest_lsn_roundtrips() {
        #[track_caller]
        fn test_case(bytes: [u8; 10]) {
            let lsn = Lsn::try_from_bytes(&bytes[..]).unwrap();
            let rnd = lsn.as_bytes();
            assert_eq!(&bytes[..], &rnd[..]);
        }
        proptest!(|(random_bytes in any::<[u8; 10]>())| {
            test_case(random_bytes)
        })
    }

    #[mz_ore::test]
    fn proptest_lsn_increment() {
        #[track_caller]
        fn test_case(bytes: [u8; 10]) {
            let lsn = Lsn::try_from_bytes(&bytes[..]).unwrap();
            let new = lsn.increment();
            assert!(lsn < new);
        }
        proptest!(|(random_bytes in any::<[u8; 10]>())| {
            test_case(random_bytes)
        })
    }

    #[mz_ore::test]
    fn proptest_lsn_abbreviate_total_order() {
        #[track_caller]
        fn test_case(bytes: [u8; 10], num_increment: u8) {
            let lsn = Lsn::try_from_bytes(&bytes[..]).unwrap();
            let mut new = lsn;
            for _ in 0..num_increment {
                new = new.increment();
            }

            let a = lsn.abbreviate();
            let b = new.abbreviate();

            assert!(a <= b);
        }
        proptest!(|(random_bytes in any::<[u8; 10]>(), num_increment in any::<u8>())| {
            test_case(random_bytes, num_increment)
        })
    }

    #[mz_ore::test]
    fn test_numeric_lsn_ordering() {
        let a = Lsn::try_from(Numeric::new_with_scale(45_0000008784_00001_i128, 0)).unwrap();
        let b = Lsn::try_from(Numeric::new_with_scale(45_0000008784_00002_i128, 0)).unwrap();
        let c = Lsn::try_from(Numeric::new_with_scale(45_0000008785_00002_i128, 0)).unwrap();
        let d = Lsn::try_from(Numeric::new_with_scale(49_0000008784_00002_i128, 0)).unwrap();
        assert!(a < b);
        assert!(b < c);
        assert!(c < d);
        assert!(a < d);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);
        assert_eq!(d, d);
    }

    #[mz_ore::test]
    fn test_numeric_lsn_invalid() {
        let with_decimal = Numeric::new_with_scale(1, 20);
        assert!(Lsn::try_from(with_decimal).is_err());

        for v in [
            4294967296_0000000000_00000_i128, // vlf_id is too large
            1_4294967296_00000_i128,          // block_id is too large
            1_0000000001_65536_i128,          // record_id is too large
            -49_0000008784_00002_i128,        // negative is invalid
        ] {
            let invalid_lsn = Numeric::new_with_scale(v, 0);
            assert!(Lsn::try_from(invalid_lsn).is_err());
        }
    }
863}