Skip to main content

mz_sql_server_util/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::any::Any;
11use std::borrow::Cow;
12use std::future::IntoFuture;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use anyhow::Context;
17use derivative::Derivative;
18use futures::future::BoxFuture;
19use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
20use mz_ore::result::ResultExt;
21use mz_repr::SqlScalarType;
22use smallvec::{SmallVec, smallvec};
23use tiberius::ToSql;
24use tokio::net::TcpStream;
25use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
26use tokio::sync::oneshot;
27use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
28
29pub mod cdc;
30pub mod config;
31pub mod desc;
32pub mod inspect;
33
34pub use config::Config;
35pub use desc::{ProtoSqlServerColumnDesc, ProtoSqlServerTableDesc};
36
37use crate::cdc::Lsn;
38use crate::config::TunnelConfig;
39use crate::desc::SqlServerColumnDecodeType;
40
41/// Higher level wrapper around a [`tiberius::Client`] that models transaction
42/// management like other database clients.
43#[derive(Debug)]
44pub struct Client {
45    tx: UnboundedSender<Request>,
46    // The configuration used to create this client.
47    config: Config,
48}
49// While a Client could implement Clone, it's not obvious how multiple Clients
50// using the same SQL Server connection would interact, so ban it for now.
51static_assertions::assert_not_impl_all!(Client: Clone);
52
53impl Client {
54    /// Connect to the specified SQL Server instance, returning a [`Client`]
55    /// that can be used to query it and a [`Connection`] that must be polled
56    /// to send and receive results.
57    ///
58    /// TODO(sql_server2): Maybe return a `ClientBuilder` here that implements
59    /// IntoFuture and does the default good thing of moving the `Connection`
60    /// into a tokio task? And a `.raw()` option that will instead return both
61    /// the Client and Connection for manual polling.
62    pub async fn connect(config: Config) -> Result<Self, SqlServerError> {
63        // Setup our tunnelling and return any resources that need to be kept
64        // alive for the duration of the connection.
65        let (tcp, resources): (_, Option<Box<dyn Any + Send + Sync>>) = match &config.tunnel {
66            TunnelConfig::Direct { resolved_addresses } => {
67                let tcp = if resolved_addresses.is_empty() {
68                    TcpStream::connect(config.inner.get_addr()).await
69                } else {
70                    TcpStream::connect(resolved_addresses.as_ref()).await
71                }
72                .context("direct")?;
73                (tcp, None)
74            }
75            TunnelConfig::Ssh {
76                config: ssh_config,
77                manager,
78                timeout,
79                host,
80                port,
81            } => {
82                // N.B. If this tunnel is dropped it will close so we need to
83                // keep it alive for the duration of the connection.
84                let tunnel = manager
85                    .connect(ssh_config.clone(), host, *port, *timeout, config.in_task)
86                    .await?;
87                let tcp = TcpStream::connect(tunnel.local_addr())
88                    .await
89                    .context("ssh tunnel")?;
90
91                (tcp, Some(Box::new(tunnel)))
92            }
93            TunnelConfig::AwsPrivatelink {
94                connection_id,
95                port,
96            } => {
97                let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
98                let tcp = TcpStream::connect((privatelink_host.as_str(), *port))
99                    .await
100                    .context(format!("aws privatelink {:?}", privatelink_host))?;
101
102                (tcp, None)
103            }
104        };
105
106        tcp.set_nodelay(true)?;
107
108        let (client, connection) = Self::connect_raw(config, tcp, resources).await?;
109        mz_ore::task::spawn(|| "sql-server-client-connection", async move {
110            connection.await
111        });
112
113        Ok(client)
114    }
115
116    /// Create a new Client instance with the same configuration that created
117    /// this configuration.
118    pub async fn new_connection(&self) -> Result<Self, SqlServerError> {
119        Self::connect(self.config.clone()).await
120    }
121
122    pub async fn connect_raw(
123        config: Config,
124        tcp: tokio::net::TcpStream,
125        resources: Option<Box<dyn Any + Send + Sync>>,
126    ) -> Result<(Self, Connection), SqlServerError> {
127        let client = tiberius::Client::connect(config.inner.clone(), tcp.compat_write()).await?;
128        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
129
130        // TODO(sql_server2): Add a lot more logging here like the Postgres and MySQL clients have.
131
132        Ok((
133            Client { tx, config },
134            Connection {
135                rx,
136                client,
137                _resources: resources,
138            },
139        ))
140    }
141
142    /// Executes SQL statements in SQL Server, returning the number of rows effected.
143    ///
144    /// Passthrough method for [`tiberius::Client::execute`].
145    ///
146    /// Note: The returned [`Future`] does not need to be awaited for the query
147    /// to be sent.
148    ///
149    /// [`Future`]: std::future::Future
150    pub async fn execute<'a>(
151        &mut self,
152        query: impl Into<Cow<'a, str>>,
153        params: &[&dyn ToSql],
154    ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
155        let (tx, rx) = tokio::sync::oneshot::channel();
156
157        let params = params
158            .iter()
159            .map(|p| OwnedColumnData::from(p.to_sql()))
160            .collect();
161        let kind = RequestKind::Execute {
162            query: query.into().to_string(),
163            params,
164        };
165        self.tx
166            .send(Request { tx, kind })
167            .context("sending request")?;
168
169        let response = rx.await.context("channel")??;
170        match response {
171            Response::Execute { rows_affected } => Ok(rows_affected),
172            other @ Response::Rows(_) | other @ Response::RowStream { .. } => {
173                Err(SqlServerError::ProgrammingError(format!(
174                    "expected Response::Execute, got {other:?}"
175                )))
176            }
177        }
178    }
179
180    /// Executes SQL statements in SQL Server, returning the resulting rows.
181    ///
182    /// Passthrough method for [`tiberius::Client::query`].
183    ///
184    /// Note: The returned [`Future`] does not need to be awaited for the query
185    /// to be sent.
186    ///
187    /// [`Future`]: std::future::Future
188    pub async fn query<'a>(
189        &mut self,
190        query: impl Into<Cow<'a, str>>,
191        params: &[&dyn tiberius::ToSql],
192    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
193        let (tx, rx) = tokio::sync::oneshot::channel();
194
195        let params = params
196            .iter()
197            .map(|p| OwnedColumnData::from(p.to_sql()))
198            .collect();
199        let kind = RequestKind::Query {
200            query: query.into().to_string(),
201            params,
202        };
203        self.tx
204            .send(Request { tx, kind })
205            .context("sending request")?;
206
207        let response = rx.await.context("channel")??;
208        match response {
209            Response::Rows(rows) => Ok(rows),
210            other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
211                SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
212            ),
213        }
214    }
215
216    /// Executes SQL statements in SQL Server, returning a [`Stream`] of
217    /// resulting rows.
218    ///
219    /// Passthrough method for [`tiberius::Client::query`].
220    pub fn query_streaming<'c, 'q, Q>(
221        &'c mut self,
222        query: Q,
223        params: &[&dyn tiberius::ToSql],
224    ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
225    where
226        Q: Into<Cow<'q, str>>,
227    {
228        let (tx, rx) = tokio::sync::oneshot::channel();
229        let params = params
230            .iter()
231            .map(|p| OwnedColumnData::from(p.to_sql()))
232            .collect();
233        let kind = RequestKind::QueryStreamed {
234            query: query.into().to_string(),
235            params,
236        };
237
238        // Make our initial request which will return a Stream of Rows.
239        let request_future = async move {
240            self.tx
241                .send(Request { tx, kind })
242                .context("sending request")?;
243
244            let response = rx.await.context("channel")??;
245            match response {
246                Response::RowStream { stream } => {
247                    Ok(tokio_stream::wrappers::ReceiverStream::new(stream))
248                }
249                other @ Response::Execute { .. } | other @ Response::Rows(_) => {
250                    Err(SqlServerError::ProgrammingError(format!(
251                        "expected Response::Rows, got {other:?}"
252                    )))
253                }
254            }
255        };
256
257        // "flatten" our initial request into the returned stream.
258        futures::stream::once(request_future).try_flatten()
259    }
260
261    /// Executes multiple queries, delimited with `;` and return multiple
262    /// result sets; one for each query.
263    ///
264    /// Passthrough method for [`tiberius::Client::simple_query`].
265    ///
266    /// Note: The returned [`Future`] does not need to be awaited for the query
267    /// to be sent.
268    ///
269    /// [`Future`]: std::future::Future
270    pub async fn simple_query<'a>(
271        &mut self,
272        query: impl Into<Cow<'a, str>>,
273    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
274        let (tx, rx) = tokio::sync::oneshot::channel();
275        let kind = RequestKind::SimpleQuery {
276            query: query.into().to_string(),
277        };
278        self.tx
279            .send(Request { tx, kind })
280            .context("sending request")?;
281
282        let response = rx.await.context("channel")??;
283        match response {
284            Response::Rows(rows) => Ok(rows),
285            other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
286                SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
287            ),
288        }
289    }
290
291    /// Starts a transaction which is automatically rolled back on drop.
292    ///
293    /// To commit or rollback the transaction, see [`Transaction::commit`] and
294    /// [`Transaction::rollback`] respectively.
295    pub async fn transaction(&mut self) -> Result<Transaction<'_>, SqlServerError> {
296        Transaction::new(self).await
297    }
298
299    /// Sets the transaction isolation level for the current session.
300    pub async fn set_transaction_isolation(
301        &mut self,
302        level: TransactionIsolationLevel,
303    ) -> Result<(), SqlServerError> {
304        let query = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_str());
305        self.simple_query(query).await?;
306        Ok(())
307    }
308
309    /// Returns the current transaction isolation level for the current session.
310    pub async fn get_transaction_isolation(
311        &mut self,
312    ) -> Result<TransactionIsolationLevel, SqlServerError> {
313        const QUERY: &str = "SELECT transaction_isolation_level FROM sys.dm_exec_sessions where session_id = @@SPID;";
314        let rows = self.simple_query(QUERY).await?;
315        match &rows[..] {
316            [row] => {
317                let val: i16 = row
318                    .try_get(0)
319                    .context("getting 0th column")?
320                    .ok_or_else(|| anyhow::anyhow!("no 0th column?"))?;
321                let level = TransactionIsolationLevel::try_from_sql_server(val)?;
322                Ok(level)
323            }
324            other => Err(SqlServerError::InvariantViolated(format!(
325                "expected one row, got {other:?}"
326            ))),
327        }
328    }
329
330    /// Return a [`CdcStream`] that can be used to track changes for the specified
331    /// `capture_instances`.
332    ///
333    /// [`CdcStream`]: crate::cdc::CdcStream
334    pub fn cdc<I, M>(&mut self, capture_instances: I, metrics: M) -> crate::cdc::CdcStream<'_, M>
335    where
336        I: IntoIterator,
337        I::Item: Into<Arc<str>>,
338        M: SqlServerCdcMetrics,
339    {
340        let instances = capture_instances
341            .into_iter()
342            .map(|i| (i.into(), None))
343            .collect();
344        crate::cdc::CdcStream::new(self, instances, metrics)
345    }
346}
347
348/// A stream of [`tiberius::Row`]s.
349pub type RowStream<'a> =
350    Pin<Box<dyn Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + 'a>>;
351
352#[derive(Debug)]
353pub struct Transaction<'a> {
354    client: &'a mut Client,
355    closed: bool,
356}
357
358impl<'a> Transaction<'a> {
359    async fn new(client: &'a mut Client) -> Result<Self, SqlServerError> {
360        // Construct the guard *before* awaiting BEGIN to avoid a potential race where
361        // transaction is started on remote, but the transaction is cancelled before returning.
362        let tx = Transaction {
363            client,
364            closed: false,
365        };
366        let results = tx
367            .client
368            .simple_query("BEGIN TRANSACTION")
369            .await
370            .context("begin")?;
371        if !results.is_empty() {
372            Err(SqlServerError::InvariantViolated(format!(
373                "expected empty result from BEGIN TRANSACTION. Got: {results:?}"
374            )))
375        } else {
376            Ok(tx)
377        }
378    }
379
380    /// Creates a savepoint via `SAVE TRANSACTION` with the provided name.
381    /// Creating a savepoint forces a write to the transaction log, which will associate an
382    /// [`Lsn`] with the current transaction.
383    ///
384    /// The savepoint name must follow rules for SQL Server identifiers
385    /// - starts with letter or underscore
386    /// - only contains letters, digits, and underscores
387    /// - no reserved words
388    /// - 32 char max
389    pub async fn create_savepoint(&mut self, savepoint_name: &str) -> Result<(), SqlServerError> {
390        // Limit the name checks to prevent sending a potentially dangerous string to the SQL Server.
391        // We prefer the server do the majority of the validation.
392        if savepoint_name.is_empty()
393            || !savepoint_name
394                .chars()
395                .all(|c| c.is_alphanumeric() || c == '_')
396        {
397            Err(SqlServerError::ProgrammingError(format!(
398                "Invalid savepoint name: '{savepoint_name}"
399            )))?;
400        }
401
402        let stmt = format!("SAVE TRANSACTION {}", quote_identifier(savepoint_name));
403        let _result = self.client.simple_query(stmt).await?;
404        Ok(())
405    }
406
407    /// Retrieve the [`Lsn`] associated with the current session.
408    ///
409    /// MS SQL Server will not assign an [`Lsn`] until a write is performed (e.g. via `SAVE TRANSACTION`).
410    pub async fn get_lsn(&mut self) -> Result<Lsn, SqlServerError> {
411        static CURRENT_LSN_QUERY: &str = "SELECT dt.database_transaction_most_recent_savepoint_lsn \
412            FROM sys.dm_tran_database_transactions dt \
413            JOIN sys.dm_tran_current_transaction ct \
414                ON ct.transaction_id = dt.transaction_id \
415            WHERE dt.database_transaction_most_recent_savepoint_lsn IS NOT NULL";
416        let result = self.client.simple_query(CURRENT_LSN_QUERY).await?;
417        crate::inspect::parse_numeric_lsn(&result)
418    }
419
420    /// Lock the provided table to prevent writes but allow reads, uses `(TABLOCK, HOLDLOCK)`.
421    ///
422    /// This will set the transaction isolation level to `READ COMMITTED` and then obtain the
423    /// lock using a `SELECT` statement that will not read any data from the table.
424    /// The lock is released after transaction commit or rollback.
425    pub async fn lock_table_shared(
426        &mut self,
427        schema: &str,
428        table: &str,
429    ) -> Result<(), SqlServerError> {
430        // Locks in MS SQL server do not behave the same way under all isolation levels. In testing,
431        // it has been observed that if the isolation level is SNAPSHOT, these locks are ineffective.
432        static SET_READ_COMMITTED: &str = "SET TRANSACTION ISOLATION LEVEL READ COMMITTED;";
433        // This query probably seems odd, but there is no LOCK command in MS SQL. Locks are specified
434        // in SELECT using the WITH keyword.  This query does not need to return any rows to lock the table,
435        // hence the 1=0, which is something short that always evaluates to false in this universe.
436        let query = format!(
437            "{SET_READ_COMMITTED}\nSELECT * FROM {schema}.{table} WITH (TABLOCK, HOLDLOCK) WHERE 1=0;",
438            schema = quote_identifier(schema),
439            table = quote_identifier(table)
440        );
441        let _result = self.client.simple_query(query).await?;
442        Ok(())
443    }
444
445    /// See [`Client::execute`].
446    pub async fn execute<'q>(
447        &mut self,
448        query: impl Into<Cow<'q, str>>,
449        params: &[&dyn ToSql],
450    ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
451        self.client.execute(query, params).await
452    }
453
454    /// See [`Client::query`].
455    pub async fn query<'q>(
456        &mut self,
457        query: impl Into<Cow<'q, str>>,
458        params: &[&dyn tiberius::ToSql],
459    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
460        self.client.query(query, params).await
461    }
462
463    /// See [`Client::query_streaming`]
464    pub fn query_streaming<'c, 'q, Q>(
465        &'c mut self,
466        query: Q,
467        params: &[&dyn tiberius::ToSql],
468    ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
469    where
470        Q: Into<Cow<'q, str>>,
471    {
472        self.client.query_streaming(query, params)
473    }
474
475    /// See [`Client::simple_query`].
476    pub async fn simple_query<'q>(
477        &mut self,
478        query: impl Into<Cow<'q, str>>,
479    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
480        self.client.simple_query(query).await
481    }
482
483    /// Rollback the [`Transaction`].
484    pub async fn rollback(mut self) -> Result<(), SqlServerError> {
485        static ROLLBACK_QUERY: &str = "ROLLBACK TRANSACTION";
486        // N.B. Mark closed _before_ running the query. This prevents us from
487        // double closing the transaction if this query itself fails.
488        self.closed = true;
489        self.client.simple_query(ROLLBACK_QUERY).await?;
490        Ok(())
491    }
492
493    /// Commit the [`Transaction`].
494    pub async fn commit(mut self) -> Result<(), SqlServerError> {
495        static COMMIT_QUERY: &str = "COMMIT TRANSACTION";
496        // N.B. Mark closed _before_ running the query. This prevents us from
497        // double closing the transaction if this query itself fails.
498        self.closed = true;
499        self.client.simple_query(COMMIT_QUERY).await?;
500        Ok(())
501    }
502}
503
504impl Drop for Transaction<'_> {
505    fn drop(&mut self) {
506        if !self.closed {
507            // Send the ROLLBACK request directly through the channel, bypassing
508            // the async `simple_query` method. We cannot `.await` in `Drop`, and
509            // merely calling an async fn without awaiting it does nothing (the
510            // future is never polled so the channel send inside never executes).
511            //
512            // We intentionally drop the response receiver since we cannot await
513            // it in a synchronous context. The Connection task will execute the
514            // ROLLBACK and discard the response when the receiver is gone.
515            let (tx, _rx) = oneshot::channel();
516            let kind = RequestKind::SimpleQuery {
517                query: "ROLLBACK TRANSACTION".to_string(),
518            };
519            let _ = self.client.tx.send(Request { tx, kind });
520        }
521    }
522}
523
/// The transaction isolation levels supported by Microsoft SQL Server.
///
/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql>
#[derive(Debug)]
#[derive(PartialEq, Eq)]
pub enum TransactionIsolationLevel {
    /// `READ UNCOMMITTED`
    ReadUncommitted,
    /// `READ COMMITTED`
    ReadCommitted,
    /// `REPEATABLE READ`
    RepeatableRead,
    /// `SNAPSHOT`
    Snapshot,
    /// `SERIALIZABLE`
    Serializable,
}
535
536impl TransactionIsolationLevel {
537    /// Return the string representation of a transaction isolation level.
538    fn as_str(&self) -> &'static str {
539        match self {
540            TransactionIsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
541            TransactionIsolationLevel::ReadCommitted => "READ COMMITTED",
542            TransactionIsolationLevel::RepeatableRead => "REPEATABLE READ",
543            TransactionIsolationLevel::Snapshot => "SNAPSHOT",
544            TransactionIsolationLevel::Serializable => "SERIALIZABLE",
545        }
546    }
547
548    /// Try to parse a [`TransactionIsolationLevel`] from the value returned from SQL Server.
549    fn try_from_sql_server(val: i16) -> Result<TransactionIsolationLevel, anyhow::Error> {
550        let level = match val {
551            1 => TransactionIsolationLevel::ReadUncommitted,
552            2 => TransactionIsolationLevel::ReadCommitted,
553            3 => TransactionIsolationLevel::RepeatableRead,
554            4 => TransactionIsolationLevel::Serializable,
555            5 => TransactionIsolationLevel::Snapshot,
556            x => anyhow::bail!("unknown level {x}"),
557        };
558        Ok(level)
559    }
560}
561
562#[derive(Derivative)]
563#[derivative(Debug)]
564enum Response {
565    Execute {
566        rows_affected: SmallVec<[u64; 1]>,
567    },
568    Rows(SmallVec<[tiberius::Row; 1]>),
569    RowStream {
570        #[derivative(Debug = "ignore")]
571        stream: tokio::sync::mpsc::Receiver<Result<tiberius::Row, SqlServerError>>,
572    },
573}
574
575#[derive(Debug)]
576struct Request {
577    tx: oneshot::Sender<Result<Response, SqlServerError>>,
578    kind: RequestKind,
579}
580
581#[derive(Derivative)]
582#[derivative(Debug)]
583enum RequestKind {
584    Execute {
585        query: String,
586        #[derivative(Debug = "ignore")]
587        params: SmallVec<[OwnedColumnData; 4]>,
588    },
589    Query {
590        query: String,
591        #[derivative(Debug = "ignore")]
592        params: SmallVec<[OwnedColumnData; 4]>,
593    },
594    QueryStreamed {
595        query: String,
596        #[derivative(Debug = "ignore")]
597        params: SmallVec<[OwnedColumnData; 4]>,
598    },
599    SimpleQuery {
600        query: String,
601    },
602}
603
604pub struct Connection {
605    /// Other end of the channel that [`Client`] holds.
606    rx: UnboundedReceiver<Request>,
607    /// Actual client that we use to send requests.
608    client: tiberius::Client<Compat<TcpStream>>,
609    /// Resources (e.g. SSH tunnel) that need to be held open for the life of this connection.
610    _resources: Option<Box<dyn Any + Send + Sync>>,
611}
612
613impl Connection {
614    async fn run(mut self) {
615        while let Some(Request { tx, kind }) = self.rx.recv().await {
616            tracing::trace!(?kind, "processing SQL Server query");
617            let result = Connection::handle_request(&mut self.client, kind).await;
618            let (response, maybe_extra_work) = match result {
619                Ok((response, work)) => (Ok(response), work),
620                Err(err) => (Err(err), None),
621            };
622
623            // We don't care if our listener for this query has gone away.
624            let _ = tx.send(response);
625
626            // After we handle a request there might still be something in-flight
627            // that we need to continue driving, e.g. when the response is a
628            // Stream of Rows.
629            if let Some(extra_work) = maybe_extra_work {
630                extra_work.await;
631            }
632        }
633        tracing::debug!("channel closed, SQL Server InnerClient shutting down");
634    }
635
636    async fn handle_request<'c>(
637        client: &'c mut tiberius::Client<Compat<TcpStream>>,
638        kind: RequestKind,
639    ) -> Result<(Response, Option<BoxFuture<'c, ()>>), SqlServerError> {
640        match kind {
641            RequestKind::Execute { query, params } => {
642                #[allow(clippy::as_conversions)]
643                let params: SmallVec<[&dyn ToSql; 4]> =
644                    params.iter().map(|x| x as &dyn ToSql).collect();
645                let result = client.execute(query, &params[..]).await?;
646
647                match result.rows_affected() {
648                    rows_affected => {
649                        let response = Response::Execute {
650                            rows_affected: rows_affected.into(),
651                        };
652                        Ok((response, None))
653                    }
654                }
655            }
656            RequestKind::Query { query, params } => {
657                #[allow(clippy::as_conversions)]
658                let params: SmallVec<[&dyn ToSql; 4]> =
659                    params.iter().map(|x| x as &dyn ToSql).collect();
660                let result = client.query(query, params.as_slice()).await?;
661
662                let mut results = result.into_results().await.context("into results")?;
663                if results.is_empty() {
664                    Ok((Response::Rows(smallvec![]), None))
665                } else if results.len() == 1 {
666                    // TODO(sql_server3): Don't use `into_results()` above, instead directly
667                    // push onto a SmallVec to avoid the heap allocations.
668                    let rows = results.pop().expect("checked len").into();
669                    Ok((Response::Rows(rows), None))
670                } else {
671                    Err(SqlServerError::ProgrammingError(format!(
672                        "Query only supports 1 statement, got {}",
673                        results.len()
674                    )))
675                }
676            }
677            RequestKind::QueryStreamed { query, params } => {
678                #[allow(clippy::as_conversions)]
679                let params: SmallVec<[&dyn ToSql; 4]> =
680                    params.iter().map(|x| x as &dyn ToSql).collect();
681                let result = client.query(query, params.as_slice()).await?;
682
683                // ~~ Rust Lifetimes ~~
684                //
685                // What's going on here, why do we have some extra channel and
686                // this 'work' future?
687                //
688                // Remember, we run the actual `tiberius::Client` in a separate
689                // `tokio::task` and the `mz::Client` sends query requests via
690                // a channel, this allows us to "automatically" manage
691                // transactions.
692                //
693                // But the returned `QueryStream` from a `tiberius::Client` has
694                // a lifetime associated with said client running in this
695                // separate task. Thus we cannot send the `QueryStream` back to
696                // the `mz::Client` because the lifetime of these two clients
697                // is not linked at all. The fix is to create a separate owned
698                // channel and return the receiving end, while this work future
699                // pulls events off the `QueryStream` and sends them over the
700                // channel we just returned.
701                let (tx, rx) = tokio::sync::mpsc::channel(256);
702                let work = Box::pin(async move {
703                    let mut stream = result.into_row_stream();
704                    while let Some(result) = stream.next().await {
705                        if let Err(err) = tx.send(result.err_into()).await {
706                            tracing::warn!(?err, "SQL Server row stream receiver went away");
707                        }
708                    }
709                    tracing::info!("SQL Server row stream complete");
710                });
711
712                Ok((Response::RowStream { stream: rx }, Some(work)))
713            }
714            RequestKind::SimpleQuery { query } => {
715                let result = client.simple_query(query).await?;
716
717                let mut results = result.into_results().await.context("into results")?;
718                if results.is_empty() {
719                    Ok((Response::Rows(smallvec![]), None))
720                } else if results.len() == 1 {
721                    // TODO(sql_server3): Don't use `into_results()` above, instead directly
722                    // push onto a SmallVec to avoid the heap allocations.
723                    let rows = results.pop().expect("checked len").into();
724                    Ok((Response::Rows(rows), None))
725                } else {
726                    Err(SqlServerError::ProgrammingError(format!(
727                        "Simple query only supports 1 statement, got {}",
728                        results.len()
729                    )))
730                }
731            }
732        }
733    }
734}
735
736impl IntoFuture for Connection {
737    type Output = ();
738    type IntoFuture = BoxFuture<'static, Self::Output>;
739
740    fn into_future(self) -> Self::IntoFuture {
741        self.run().boxed()
742    }
743}
744
745/// Owned version of [`tiberius::ColumnData`] that can be more easily sent
746/// across threads or through a channel.
747#[derive(Debug)]
748enum OwnedColumnData {
749    U8(Option<u8>),
750    I16(Option<i16>),
751    I32(Option<i32>),
752    I64(Option<i64>),
753    F32(Option<f32>),
754    F64(Option<f64>),
755    Bit(Option<bool>),
756    String(Option<String>),
757    Guid(Option<uuid::Uuid>),
758    Binary(Option<Vec<u8>>),
759    Numeric(Option<tiberius::numeric::Numeric>),
760    Xml(Option<tiberius::xml::XmlData>),
761    DateTime(Option<tiberius::time::DateTime>),
762    SmallDateTime(Option<tiberius::time::SmallDateTime>),
763    Time(Option<tiberius::time::Time>),
764    Date(Option<tiberius::time::Date>),
765    DateTime2(Option<tiberius::time::DateTime2>),
766    DateTimeOffset(Option<tiberius::time::DateTimeOffset>),
767}
768
769impl<'a> From<tiberius::ColumnData<'a>> for OwnedColumnData {
770    fn from(value: tiberius::ColumnData<'a>) -> Self {
771        match value {
772            tiberius::ColumnData::U8(inner) => OwnedColumnData::U8(inner),
773            tiberius::ColumnData::I16(inner) => OwnedColumnData::I16(inner),
774            tiberius::ColumnData::I32(inner) => OwnedColumnData::I32(inner),
775            tiberius::ColumnData::I64(inner) => OwnedColumnData::I64(inner),
776            tiberius::ColumnData::F32(inner) => OwnedColumnData::F32(inner),
777            tiberius::ColumnData::F64(inner) => OwnedColumnData::F64(inner),
778            tiberius::ColumnData::Bit(inner) => OwnedColumnData::Bit(inner),
779            tiberius::ColumnData::String(inner) => {
780                OwnedColumnData::String(inner.map(|s| s.to_string()))
781            }
782            tiberius::ColumnData::Guid(inner) => OwnedColumnData::Guid(inner),
783            tiberius::ColumnData::Binary(inner) => {
784                OwnedColumnData::Binary(inner.map(|b| b.to_vec()))
785            }
786            tiberius::ColumnData::Numeric(inner) => OwnedColumnData::Numeric(inner),
787            tiberius::ColumnData::Xml(inner) => OwnedColumnData::Xml(inner.map(|x| x.into_owned())),
788            tiberius::ColumnData::DateTime(inner) => OwnedColumnData::DateTime(inner),
789            tiberius::ColumnData::SmallDateTime(inner) => OwnedColumnData::SmallDateTime(inner),
790            tiberius::ColumnData::Time(inner) => OwnedColumnData::Time(inner),
791            tiberius::ColumnData::Date(inner) => OwnedColumnData::Date(inner),
792            tiberius::ColumnData::DateTime2(inner) => OwnedColumnData::DateTime2(inner),
793            tiberius::ColumnData::DateTimeOffset(inner) => OwnedColumnData::DateTimeOffset(inner),
794        }
795    }
796}
797
798impl tiberius::ToSql for OwnedColumnData {
799    fn to_sql(&self) -> tiberius::ColumnData<'_> {
800        match self {
801            OwnedColumnData::U8(inner) => tiberius::ColumnData::U8(*inner),
802            OwnedColumnData::I16(inner) => tiberius::ColumnData::I16(*inner),
803            OwnedColumnData::I32(inner) => tiberius::ColumnData::I32(*inner),
804            OwnedColumnData::I64(inner) => tiberius::ColumnData::I64(*inner),
805            OwnedColumnData::F32(inner) => tiberius::ColumnData::F32(*inner),
806            OwnedColumnData::F64(inner) => tiberius::ColumnData::F64(*inner),
807            OwnedColumnData::Bit(inner) => tiberius::ColumnData::Bit(*inner),
808            OwnedColumnData::String(inner) => {
809                tiberius::ColumnData::String(inner.as_deref().map(Cow::Borrowed))
810            }
811            OwnedColumnData::Guid(inner) => tiberius::ColumnData::Guid(*inner),
812            OwnedColumnData::Binary(inner) => {
813                tiberius::ColumnData::Binary(inner.as_deref().map(Cow::Borrowed))
814            }
815            OwnedColumnData::Numeric(inner) => tiberius::ColumnData::Numeric(*inner),
816            OwnedColumnData::Xml(inner) => {
817                tiberius::ColumnData::Xml(inner.as_ref().map(Cow::Borrowed))
818            }
819            OwnedColumnData::DateTime(inner) => tiberius::ColumnData::DateTime(*inner),
820            OwnedColumnData::SmallDateTime(inner) => tiberius::ColumnData::SmallDateTime(*inner),
821            OwnedColumnData::Time(inner) => tiberius::ColumnData::Time(*inner),
822            OwnedColumnData::Date(inner) => tiberius::ColumnData::Date(*inner),
823            OwnedColumnData::DateTime2(inner) => tiberius::ColumnData::DateTime2(*inner),
824            OwnedColumnData::DateTimeOffset(inner) => tiberius::ColumnData::DateTimeOffset(*inner),
825        }
826    }
827}
828
829impl<'a, T: tiberius::ToSql> From<&'a T> for OwnedColumnData {
830    fn from(value: &'a T) -> Self {
831        OwnedColumnData::from(value.to_sql())
832    }
833}
834
835#[derive(Debug, thiserror::Error)]
836pub enum SqlServerError {
837    #[error(transparent)]
838    SqlServer(#[from] tiberius::error::Error),
839    #[error(transparent)]
840    CdcError(#[from] crate::cdc::CdcError),
841    #[error("expected column '{0}' to be present")]
842    MissingColumn(&'static str),
843    #[error("sql server client encountered I/O error: {0}")]
844    IO(#[from] tokio::io::Error),
845    #[error("found invalid data in the column '{column_name}': {error}")]
846    InvalidData { column_name: String, error: String },
847    #[error("got back a null value when querying for the LSN")]
848    NullLsn,
849    #[error("invalid SQL Server system setting '{name}'. Expected '{expected}'. Got '{actual}'.")]
850    InvalidSystemSetting {
851        name: String,
852        expected: String,
853        actual: String,
854    },
855    #[error("invariant was violated: {0}")]
856    InvariantViolated(String),
857    #[error(transparent)]
858    Generic(#[from] anyhow::Error),
859    #[error("programming error! {0}")]
860    ProgrammingError(String),
861    #[error(
862        "insufficient permissions for tables [{tables}] or capture instances [{capture_instances}]"
863    )]
864    AuthorizationError {
865        tables: String,
866        capture_instances: String,
867    },
868}
869
870/// Errors returned from decoding SQL Server rows.
871///
872/// **PLEASE READ**
873///
874/// The string representation of this error type is **durably stored** in a source and thus this
875/// error type needs to be **stable** across releases. For example, if in v11 of Materialize we
876/// fail to decode `Row(["foo bar"])` from SQL Server, we will record the error in the source's
877/// Persist shard. If in v12 of Materialize the user deletes the `Row(["foo bar"])` from their
878/// upstream instance, we need to perfectly retract the error we previously committed.
879///
880/// This means be **very** careful when changing this type.
881#[derive(Debug, thiserror::Error)]
882pub enum SqlServerDecodeError {
883    #[error("column '{column_name}' was invalid when getting as type '{as_type}'")]
884    InvalidColumn {
885        column_name: String,
886        as_type: &'static str,
887    },
888    #[error("found invalid data in the column '{column_name}': {error}")]
889    InvalidData { column_name: String, error: String },
890    #[error("can't decode {sql_server_type:?} as {mz_type:?}")]
891    Unsupported {
892        sql_server_type: SqlServerColumnDecodeType,
893        mz_type: SqlScalarType,
894    },
895}
896
897impl SqlServerDecodeError {
898    fn invalid_timestamp(name: &str, error: mz_repr::adt::timestamp::TimestampError) -> Self {
899        // These error messages need to remain stable, do not change them.
900        let error = match error {
901            mz_repr::adt::timestamp::TimestampError::OutOfRange => "out of range",
902        };
903        SqlServerDecodeError::InvalidData {
904            column_name: name.to_string(),
905            error: error.to_string(),
906        }
907    }
908
909    fn invalid_date(name: &str, error: mz_repr::adt::date::DateError) -> Self {
910        // These error messages need to remain stable, do not change them.
911        let error = match error {
912            mz_repr::adt::date::DateError::OutOfRange => "out of range",
913        };
914        SqlServerDecodeError::InvalidData {
915            column_name: name.to_string(),
916            error: error.to_string(),
917        }
918    }
919
920    fn invalid_char(name: &str, expected_chars: usize, found_chars: usize) -> Self {
921        SqlServerDecodeError::InvalidData {
922            column_name: name.to_string(),
923            error: format!("expected {expected_chars} chars found {found_chars}"),
924        }
925    }
926
927    fn invalid_varchar(name: &str, max_chars: usize, found_chars: usize) -> Self {
928        SqlServerDecodeError::InvalidData {
929            column_name: name.to_string(),
930            error: format!("expected max {max_chars} chars found {found_chars}"),
931        }
932    }
933
934    fn invalid_column(name: &str, as_type: &'static str) -> Self {
935        SqlServerDecodeError::InvalidColumn {
936            column_name: name.to_string(),
937            as_type,
938        }
939    }
940}
941
942/// Quotes the provided string using '[]' to match SQL Server `QUOTENAME` function. This form
943/// of quotes is unaffected by the SQL Server setting `SET QUOTED_IDENTIFIER`.
944///
945/// See:
946/// - <https://learn.microsoft.com/en-us/sql/t-sql/functions/quotename-transact-sql?view=sql-server-ver17>
947/// - <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-quoted-identifier-transact-sql?view=sql-server-ver17>
948pub fn quote_identifier(ident: &str) -> String {
949    let mut quoted = ident.replace(']', "]]");
950    quoted.insert(0, '[');
951    quoted.push(']');
952    quoted
953}
954
955pub trait SqlServerCdcMetrics {
956    /// Called before the table lock is aquired
957    fn snapshot_table_lock_start(&self, table_name: &str);
958    /// Called after the table lock is released
959    fn snapshot_table_lock_end(&self, table_name: &str);
960}
961
962/// A simple implementation of [`SqlServerCdcMetrics`] that uses the tracing framework to log
963/// the start and end conditions.
964pub struct LoggingSqlServerCdcMetrics;
965
966impl SqlServerCdcMetrics for LoggingSqlServerCdcMetrics {
967    fn snapshot_table_lock_start(&self, table_name: &str) {
968        tracing::info!("snapshot_table_lock_start: {table_name}");
969    }
970
971    fn snapshot_table_lock_end(&self, table_name: &str) {
972        tracing::info!("snapshot_table_lock_end: {table_name}");
973    }
974}
975
976#[cfg(test)]
977mod test {
978    use super::*;
979
980    #[mz_ore::test]
981    fn test_sql_server_escaping() {
982        assert_eq!("[]", &quote_identifier(""));
983        assert_eq!("[]]]", &quote_identifier("]"));
984        assert_eq!("[a]", &quote_identifier("a"));
985        assert_eq!("[cost(]]\u{00A3})]", &quote_identifier("cost(]\u{00A3})"));
986        assert_eq!("[[g[o[o]][]", &quote_identifier("[g[o[o]["));
987    }
988}