Skip to main content

mz_sql_server_util/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::any::Any;
11use std::borrow::Cow;
12use std::future::IntoFuture;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use anyhow::Context;
17use derivative::Derivative;
18use futures::future::BoxFuture;
19use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
20use mz_ore::result::ResultExt;
21use mz_repr::SqlScalarType;
22use smallvec::{SmallVec, smallvec};
23use tiberius::ToSql;
24use tokio::net::TcpStream;
25use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
26use tokio::sync::oneshot;
27use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
28
29pub mod cdc;
30pub mod config;
31pub mod desc;
32pub mod inspect;
33
34pub use config::Config;
35pub use desc::{ProtoSqlServerColumnDesc, ProtoSqlServerTableDesc};
36
37use crate::cdc::Lsn;
38use crate::config::TunnelConfig;
39use crate::desc::SqlServerColumnDecodeType;
40
41/// Higher level wrapper around a [`tiberius::Client`] that models transaction
42/// management like other database clients.
43#[derive(Debug)]
44pub struct Client {
45    tx: UnboundedSender<Request>,
46    // The configuration used to create this client.
47    config: Config,
48}
49// While a Client could implement Clone, it's not obvious how multiple Clients
50// using the same SQL Server connection would interact, so ban it for now.
51static_assertions::assert_not_impl_all!(Client: Clone);
52
53impl Client {
54    /// Connect to the specified SQL Server instance, returning a [`Client`]
55    /// that can be used to query it and a [`Connection`] that must be polled
56    /// to send and receive results.
57    ///
58    /// TODO(sql_server2): Maybe return a `ClientBuilder` here that implements
59    /// IntoFuture and does the default good thing of moving the `Connection`
60    /// into a tokio task? And a `.raw()` option that will instead return both
61    /// the Client and Connection for manual polling.
62    pub async fn connect(config: Config) -> Result<Self, SqlServerError> {
63        // Setup our tunnelling and return any resources that need to be kept
64        // alive for the duration of the connection.
65        let (tcp, resources): (_, Option<Box<dyn Any + Send + Sync>>) = match &config.tunnel {
66            TunnelConfig::Direct => {
67                let tcp = TcpStream::connect(config.inner.get_addr())
68                    .await
69                    .context("direct")?;
70                (tcp, None)
71            }
72            TunnelConfig::Ssh {
73                config: ssh_config,
74                manager,
75                timeout,
76                host,
77                port,
78            } => {
79                // N.B. If this tunnel is dropped it will close so we need to
80                // keep it alive for the duration of the connection.
81                let tunnel = manager
82                    .connect(ssh_config.clone(), host, *port, *timeout, config.in_task)
83                    .await?;
84                let tcp = TcpStream::connect(tunnel.local_addr())
85                    .await
86                    .context("ssh tunnel")?;
87
88                (tcp, Some(Box::new(tunnel)))
89            }
90            TunnelConfig::AwsPrivatelink {
91                connection_id,
92                port,
93            } => {
94                let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
95                let mut privatelink_addrs =
96                    tokio::net::lookup_host((privatelink_host.clone(), 0)).await?;
97
98                let Some(mut addr) = privatelink_addrs.next() else {
99                    return Err(SqlServerError::InvariantViolated(format!(
100                        "aws privatelink: no addresses found for host {:?}",
101                        privatelink_host
102                    )));
103                };
104
105                addr.set_port(port.clone());
106
107                let tcp = TcpStream::connect(addr)
108                    .await
109                    .context(format!("aws privatelink {:?}", addr))?;
110
111                (tcp, None)
112            }
113        };
114
115        tcp.set_nodelay(true)?;
116
117        let (client, connection) = Self::connect_raw(config, tcp, resources).await?;
118        mz_ore::task::spawn(|| "sql-server-client-connection", async move {
119            connection.await
120        });
121
122        Ok(client)
123    }
124
125    /// Create a new Client instance with the same configuration that created
126    /// this configuration.
127    pub async fn new_connection(&self) -> Result<Self, SqlServerError> {
128        Self::connect(self.config.clone()).await
129    }
130
131    pub async fn connect_raw(
132        config: Config,
133        tcp: tokio::net::TcpStream,
134        resources: Option<Box<dyn Any + Send + Sync>>,
135    ) -> Result<(Self, Connection), SqlServerError> {
136        let client = tiberius::Client::connect(config.inner.clone(), tcp.compat_write()).await?;
137        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
138
139        // TODO(sql_server2): Add a lot more logging here like the Postgres and MySQL clients have.
140
141        Ok((
142            Client { tx, config },
143            Connection {
144                rx,
145                client,
146                _resources: resources,
147            },
148        ))
149    }
150
151    /// Executes SQL statements in SQL Server, returning the number of rows effected.
152    ///
153    /// Passthrough method for [`tiberius::Client::execute`].
154    ///
155    /// Note: The returned [`Future`] does not need to be awaited for the query
156    /// to be sent.
157    ///
158    /// [`Future`]: std::future::Future
159    pub async fn execute<'a>(
160        &mut self,
161        query: impl Into<Cow<'a, str>>,
162        params: &[&dyn ToSql],
163    ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
164        let (tx, rx) = tokio::sync::oneshot::channel();
165
166        let params = params
167            .iter()
168            .map(|p| OwnedColumnData::from(p.to_sql()))
169            .collect();
170        let kind = RequestKind::Execute {
171            query: query.into().to_string(),
172            params,
173        };
174        self.tx
175            .send(Request { tx, kind })
176            .context("sending request")?;
177
178        let response = rx.await.context("channel")??;
179        match response {
180            Response::Execute { rows_affected } => Ok(rows_affected),
181            other @ Response::Rows(_) | other @ Response::RowStream { .. } => {
182                Err(SqlServerError::ProgrammingError(format!(
183                    "expected Response::Execute, got {other:?}"
184                )))
185            }
186        }
187    }
188
189    /// Executes SQL statements in SQL Server, returning the resulting rows.
190    ///
191    /// Passthrough method for [`tiberius::Client::query`].
192    ///
193    /// Note: The returned [`Future`] does not need to be awaited for the query
194    /// to be sent.
195    ///
196    /// [`Future`]: std::future::Future
197    pub async fn query<'a>(
198        &mut self,
199        query: impl Into<Cow<'a, str>>,
200        params: &[&dyn tiberius::ToSql],
201    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
202        let (tx, rx) = tokio::sync::oneshot::channel();
203
204        let params = params
205            .iter()
206            .map(|p| OwnedColumnData::from(p.to_sql()))
207            .collect();
208        let kind = RequestKind::Query {
209            query: query.into().to_string(),
210            params,
211        };
212        self.tx
213            .send(Request { tx, kind })
214            .context("sending request")?;
215
216        let response = rx.await.context("channel")??;
217        match response {
218            Response::Rows(rows) => Ok(rows),
219            other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
220                SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
221            ),
222        }
223    }
224
225    /// Executes SQL statements in SQL Server, returning a [`Stream`] of
226    /// resulting rows.
227    ///
228    /// Passthrough method for [`tiberius::Client::query`].
229    pub fn query_streaming<'c, 'q, Q>(
230        &'c mut self,
231        query: Q,
232        params: &[&dyn tiberius::ToSql],
233    ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
234    where
235        Q: Into<Cow<'q, str>>,
236    {
237        let (tx, rx) = tokio::sync::oneshot::channel();
238        let params = params
239            .iter()
240            .map(|p| OwnedColumnData::from(p.to_sql()))
241            .collect();
242        let kind = RequestKind::QueryStreamed {
243            query: query.into().to_string(),
244            params,
245        };
246
247        // Make our initial request which will return a Stream of Rows.
248        let request_future = async move {
249            self.tx
250                .send(Request { tx, kind })
251                .context("sending request")?;
252
253            let response = rx.await.context("channel")??;
254            match response {
255                Response::RowStream { stream } => {
256                    Ok(tokio_stream::wrappers::ReceiverStream::new(stream))
257                }
258                other @ Response::Execute { .. } | other @ Response::Rows(_) => {
259                    Err(SqlServerError::ProgrammingError(format!(
260                        "expected Response::Rows, got {other:?}"
261                    )))
262                }
263            }
264        };
265
266        // "flatten" our initial request into the returned stream.
267        futures::stream::once(request_future).try_flatten()
268    }
269
270    /// Executes multiple queries, delimited with `;` and return multiple
271    /// result sets; one for each query.
272    ///
273    /// Passthrough method for [`tiberius::Client::simple_query`].
274    ///
275    /// Note: The returned [`Future`] does not need to be awaited for the query
276    /// to be sent.
277    ///
278    /// [`Future`]: std::future::Future
279    pub async fn simple_query<'a>(
280        &mut self,
281        query: impl Into<Cow<'a, str>>,
282    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
283        let (tx, rx) = tokio::sync::oneshot::channel();
284        let kind = RequestKind::SimpleQuery {
285            query: query.into().to_string(),
286        };
287        self.tx
288            .send(Request { tx, kind })
289            .context("sending request")?;
290
291        let response = rx.await.context("channel")??;
292        match response {
293            Response::Rows(rows) => Ok(rows),
294            other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
295                SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
296            ),
297        }
298    }
299
300    /// Starts a transaction which is automatically rolled back on drop.
301    ///
302    /// To commit or rollback the transaction, see [`Transaction::commit`] and
303    /// [`Transaction::rollback`] respectively.
304    pub async fn transaction(&mut self) -> Result<Transaction<'_>, SqlServerError> {
305        Transaction::new(self).await
306    }
307
308    /// Sets the transaction isolation level for the current session.
309    pub async fn set_transaction_isolation(
310        &mut self,
311        level: TransactionIsolationLevel,
312    ) -> Result<(), SqlServerError> {
313        let query = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_str());
314        self.simple_query(query).await?;
315        Ok(())
316    }
317
318    /// Returns the current transaction isolation level for the current session.
319    pub async fn get_transaction_isolation(
320        &mut self,
321    ) -> Result<TransactionIsolationLevel, SqlServerError> {
322        const QUERY: &str = "SELECT transaction_isolation_level FROM sys.dm_exec_sessions where session_id = @@SPID;";
323        let rows = self.simple_query(QUERY).await?;
324        match &rows[..] {
325            [row] => {
326                let val: i16 = row
327                    .try_get(0)
328                    .context("getting 0th column")?
329                    .ok_or_else(|| anyhow::anyhow!("no 0th column?"))?;
330                let level = TransactionIsolationLevel::try_from_sql_server(val)?;
331                Ok(level)
332            }
333            other => Err(SqlServerError::InvariantViolated(format!(
334                "expected one row, got {other:?}"
335            ))),
336        }
337    }
338
339    /// Return a [`CdcStream`] that can be used to track changes for the specified
340    /// `capture_instances`.
341    ///
342    /// [`CdcStream`]: crate::cdc::CdcStream
343    pub fn cdc<I, M>(&mut self, capture_instances: I, metrics: M) -> crate::cdc::CdcStream<'_, M>
344    where
345        I: IntoIterator,
346        I::Item: Into<Arc<str>>,
347        M: SqlServerCdcMetrics,
348    {
349        let instances = capture_instances
350            .into_iter()
351            .map(|i| (i.into(), None))
352            .collect();
353        crate::cdc::CdcStream::new(self, instances, metrics)
354    }
355}
356
357/// A stream of [`tiberius::Row`]s.
358pub type RowStream<'a> =
359    Pin<Box<dyn Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + 'a>>;
360
361#[derive(Debug)]
362pub struct Transaction<'a> {
363    client: &'a mut Client,
364    closed: bool,
365}
366
367impl<'a> Transaction<'a> {
368    async fn new(client: &'a mut Client) -> Result<Self, SqlServerError> {
369        let results = client
370            .simple_query("BEGIN TRANSACTION")
371            .await
372            .context("begin")?;
373        if !results.is_empty() {
374            Err(SqlServerError::InvariantViolated(format!(
375                "expected empty result from BEGIN TRANSACTION. Got: {results:?}"
376            )))
377        } else {
378            Ok(Transaction {
379                client,
380                closed: false,
381            })
382        }
383    }
384
385    /// Creates a savepoint via `SAVE TRANSACTION` with the provided name.
386    /// Creating a savepoint forces a write to the transaction log, which will associate an
387    /// [`Lsn`] with the current transaction.
388    ///
389    /// The savepoint name must follow rules for SQL Server identifiers
390    /// - starts with letter or underscore
391    /// - only contains letters, digits, and underscores
392    /// - no reserved words
393    /// - 32 char max
394    pub async fn create_savepoint(&mut self, savepoint_name: &str) -> Result<(), SqlServerError> {
395        // Limit the name checks to prevent sending a potentially dangerous string to the SQL Server.
396        // We prefer the server do the majority of the validation.
397        if savepoint_name.is_empty()
398            || !savepoint_name
399                .chars()
400                .all(|c| c.is_alphanumeric() || c == '_')
401        {
402            Err(SqlServerError::ProgrammingError(format!(
403                "Invalid savepoint name: '{savepoint_name}"
404            )))?;
405        }
406
407        let stmt = format!("SAVE TRANSACTION {}", quote_identifier(savepoint_name));
408        let _result = self.client.simple_query(stmt).await?;
409        Ok(())
410    }
411
412    /// Retrieve the [`Lsn`] associated with the current session.
413    ///
414    /// MS SQL Server will not assign an [`Lsn`] until a write is performed (e.g. via `SAVE TRANSACTION`).
415    pub async fn get_lsn(&mut self) -> Result<Lsn, SqlServerError> {
416        static CURRENT_LSN_QUERY: &str = "SELECT dt.database_transaction_most_recent_savepoint_lsn \
417            FROM sys.dm_tran_database_transactions dt \
418            JOIN sys.dm_tran_current_transaction ct \
419                ON ct.transaction_id = dt.transaction_id \
420            WHERE dt.database_transaction_most_recent_savepoint_lsn IS NOT NULL";
421        let result = self.client.simple_query(CURRENT_LSN_QUERY).await?;
422        crate::inspect::parse_numeric_lsn(&result)
423    }
424
425    /// Lock the provided table to prevent writes but allow reads, uses `(TABLOCK, HOLDLOCK)`.
426    ///
427    /// This will set the transaction isolation level to `READ COMMITTED` and then obtain the
428    /// lock using a `SELECT` statement that will not read any data from the table.
429    /// The lock is released after transaction commit or rollback.
430    pub async fn lock_table_shared(
431        &mut self,
432        schema: &str,
433        table: &str,
434    ) -> Result<(), SqlServerError> {
435        // Locks in MS SQL server do not behave the same way under all isolation levels. In testing,
436        // it has been observed that if the isolation level is SNAPSHOT, these locks are ineffective.
437        static SET_READ_COMMITTED: &str = "SET TRANSACTION ISOLATION LEVEL READ COMMITTED;";
438        // This query probably seems odd, but there is no LOCK command in MS SQL. Locks are specified
439        // in SELECT using the WITH keyword.  This query does not need to return any rows to lock the table,
440        // hence the 1=0, which is something short that always evaluates to false in this universe.
441        let query = format!(
442            "{SET_READ_COMMITTED}\nSELECT * FROM {schema}.{table} WITH (TABLOCK, HOLDLOCK) WHERE 1=0;",
443            schema = quote_identifier(schema),
444            table = quote_identifier(table)
445        );
446        let _result = self.client.simple_query(query).await?;
447        Ok(())
448    }
449
450    /// See [`Client::execute`].
451    pub async fn execute<'q>(
452        &mut self,
453        query: impl Into<Cow<'q, str>>,
454        params: &[&dyn ToSql],
455    ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
456        self.client.execute(query, params).await
457    }
458
459    /// See [`Client::query`].
460    pub async fn query<'q>(
461        &mut self,
462        query: impl Into<Cow<'q, str>>,
463        params: &[&dyn tiberius::ToSql],
464    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
465        self.client.query(query, params).await
466    }
467
468    /// See [`Client::query_streaming`]
469    pub fn query_streaming<'c, 'q, Q>(
470        &'c mut self,
471        query: Q,
472        params: &[&dyn tiberius::ToSql],
473    ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
474    where
475        Q: Into<Cow<'q, str>>,
476    {
477        self.client.query_streaming(query, params)
478    }
479
480    /// See [`Client::simple_query`].
481    pub async fn simple_query<'q>(
482        &mut self,
483        query: impl Into<Cow<'q, str>>,
484    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
485        self.client.simple_query(query).await
486    }
487
488    /// Rollback the [`Transaction`].
489    pub async fn rollback(mut self) -> Result<(), SqlServerError> {
490        static ROLLBACK_QUERY: &str = "ROLLBACK TRANSACTION";
491        // N.B. Mark closed _before_ running the query. This prevents us from
492        // double closing the transaction if this query itself fails.
493        self.closed = true;
494        self.client.simple_query(ROLLBACK_QUERY).await?;
495        Ok(())
496    }
497
498    /// Commit the [`Transaction`].
499    pub async fn commit(mut self) -> Result<(), SqlServerError> {
500        static COMMIT_QUERY: &str = "COMMIT TRANSACTION";
501        // N.B. Mark closed _before_ running the query. This prevents us from
502        // double closing the transaction if this query itself fails.
503        self.closed = true;
504        self.client.simple_query(COMMIT_QUERY).await?;
505        Ok(())
506    }
507}
508
509impl Drop for Transaction<'_> {
510    fn drop(&mut self) {
511        if !self.closed {
512            // Send the ROLLBACK request directly through the channel, bypassing
513            // the async `simple_query` method. We cannot `.await` in `Drop`, and
514            // merely calling an async fn without awaiting it does nothing (the
515            // future is never polled so the channel send inside never executes).
516            //
517            // We intentionally drop the response receiver since we cannot await
518            // it in a synchronous context. The Connection task will execute the
519            // ROLLBACK and discard the response when the receiver is gone.
520            let (tx, _rx) = oneshot::channel();
521            let kind = RequestKind::SimpleQuery {
522                query: "ROLLBACK TRANSACTION".to_string(),
523            };
524            let _ = self.client.tx.send(Request { tx, kind });
525        }
526    }
527}
528
/// Transaction isolation levels defined by Microsoft's SQL Server.
///
/// This is a plain field-less enum, so it is cheap to copy and can be used
/// as a map or set key.
///
/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql>
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionIsolationLevel {
    ReadUncommitted,
    ReadCommitted,
    RepeatableRead,
    Snapshot,
    Serializable,
}
540
541impl TransactionIsolationLevel {
542    /// Return the string representation of a transaction isolation level.
543    fn as_str(&self) -> &'static str {
544        match self {
545            TransactionIsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
546            TransactionIsolationLevel::ReadCommitted => "READ COMMITTED",
547            TransactionIsolationLevel::RepeatableRead => "REPEATABLE READ",
548            TransactionIsolationLevel::Snapshot => "SNAPSHOT",
549            TransactionIsolationLevel::Serializable => "SERIALIZABLE",
550        }
551    }
552
553    /// Try to parse a [`TransactionIsolationLevel`] from the value returned from SQL Server.
554    fn try_from_sql_server(val: i16) -> Result<TransactionIsolationLevel, anyhow::Error> {
555        let level = match val {
556            1 => TransactionIsolationLevel::ReadUncommitted,
557            2 => TransactionIsolationLevel::ReadCommitted,
558            3 => TransactionIsolationLevel::RepeatableRead,
559            4 => TransactionIsolationLevel::Serializable,
560            5 => TransactionIsolationLevel::Snapshot,
561            x => anyhow::bail!("unknown level {x}"),
562        };
563        Ok(level)
564    }
565}
566
567#[derive(Derivative)]
568#[derivative(Debug)]
569enum Response {
570    Execute {
571        rows_affected: SmallVec<[u64; 1]>,
572    },
573    Rows(SmallVec<[tiberius::Row; 1]>),
574    RowStream {
575        #[derivative(Debug = "ignore")]
576        stream: tokio::sync::mpsc::Receiver<Result<tiberius::Row, SqlServerError>>,
577    },
578}
579
580#[derive(Debug)]
581struct Request {
582    tx: oneshot::Sender<Result<Response, SqlServerError>>,
583    kind: RequestKind,
584}
585
586#[derive(Derivative)]
587#[derivative(Debug)]
588enum RequestKind {
589    Execute {
590        query: String,
591        #[derivative(Debug = "ignore")]
592        params: SmallVec<[OwnedColumnData; 4]>,
593    },
594    Query {
595        query: String,
596        #[derivative(Debug = "ignore")]
597        params: SmallVec<[OwnedColumnData; 4]>,
598    },
599    QueryStreamed {
600        query: String,
601        #[derivative(Debug = "ignore")]
602        params: SmallVec<[OwnedColumnData; 4]>,
603    },
604    SimpleQuery {
605        query: String,
606    },
607}
608
609pub struct Connection {
610    /// Other end of the channel that [`Client`] holds.
611    rx: UnboundedReceiver<Request>,
612    /// Actual client that we use to send requests.
613    client: tiberius::Client<Compat<TcpStream>>,
614    /// Resources (e.g. SSH tunnel) that need to be held open for the life of this connection.
615    _resources: Option<Box<dyn Any + Send + Sync>>,
616}
617
618impl Connection {
619    async fn run(mut self) {
620        while let Some(Request { tx, kind }) = self.rx.recv().await {
621            tracing::trace!(?kind, "processing SQL Server query");
622            let result = Connection::handle_request(&mut self.client, kind).await;
623            let (response, maybe_extra_work) = match result {
624                Ok((response, work)) => (Ok(response), work),
625                Err(err) => (Err(err), None),
626            };
627
628            // We don't care if our listener for this query has gone away.
629            let _ = tx.send(response);
630
631            // After we handle a request there might still be something in-flight
632            // that we need to continue driving, e.g. when the response is a
633            // Stream of Rows.
634            if let Some(extra_work) = maybe_extra_work {
635                extra_work.await;
636            }
637        }
638        tracing::debug!("channel closed, SQL Server InnerClient shutting down");
639    }
640
641    async fn handle_request<'c>(
642        client: &'c mut tiberius::Client<Compat<TcpStream>>,
643        kind: RequestKind,
644    ) -> Result<(Response, Option<BoxFuture<'c, ()>>), SqlServerError> {
645        match kind {
646            RequestKind::Execute { query, params } => {
647                #[allow(clippy::as_conversions)]
648                let params: SmallVec<[&dyn ToSql; 4]> =
649                    params.iter().map(|x| x as &dyn ToSql).collect();
650                let result = client.execute(query, &params[..]).await?;
651
652                match result.rows_affected() {
653                    rows_affected => {
654                        let response = Response::Execute {
655                            rows_affected: rows_affected.into(),
656                        };
657                        Ok((response, None))
658                    }
659                }
660            }
661            RequestKind::Query { query, params } => {
662                #[allow(clippy::as_conversions)]
663                let params: SmallVec<[&dyn ToSql; 4]> =
664                    params.iter().map(|x| x as &dyn ToSql).collect();
665                let result = client.query(query, params.as_slice()).await?;
666
667                let mut results = result.into_results().await.context("into results")?;
668                if results.is_empty() {
669                    Ok((Response::Rows(smallvec![]), None))
670                } else if results.len() == 1 {
671                    // TODO(sql_server3): Don't use `into_results()` above, instead directly
672                    // push onto a SmallVec to avoid the heap allocations.
673                    let rows = results.pop().expect("checked len").into();
674                    Ok((Response::Rows(rows), None))
675                } else {
676                    Err(SqlServerError::ProgrammingError(format!(
677                        "Query only supports 1 statement, got {}",
678                        results.len()
679                    )))
680                }
681            }
682            RequestKind::QueryStreamed { query, params } => {
683                #[allow(clippy::as_conversions)]
684                let params: SmallVec<[&dyn ToSql; 4]> =
685                    params.iter().map(|x| x as &dyn ToSql).collect();
686                let result = client.query(query, params.as_slice()).await?;
687
688                // ~~ Rust Lifetimes ~~
689                //
690                // What's going on here, why do we have some extra channel and
691                // this 'work' future?
692                //
693                // Remember, we run the actual `tiberius::Client` in a separate
694                // `tokio::task` and the `mz::Client` sends query requests via
695                // a channel, this allows us to "automatically" manage
696                // transactions.
697                //
698                // But the returned `QueryStream` from a `tiberius::Client` has
699                // a lifetime associated with said client running in this
700                // separate task. Thus we cannot send the `QueryStream` back to
701                // the `mz::Client` because the lifetime of these two clients
702                // is not linked at all. The fix is to create a separate owned
703                // channel and return the receiving end, while this work future
704                // pulls events off the `QueryStream` and sends them over the
705                // channel we just returned.
706                let (tx, rx) = tokio::sync::mpsc::channel(256);
707                let work = Box::pin(async move {
708                    let mut stream = result.into_row_stream();
709                    while let Some(result) = stream.next().await {
710                        if let Err(err) = tx.send(result.err_into()).await {
711                            tracing::warn!(?err, "SQL Server row stream receiver went away");
712                        }
713                    }
714                    tracing::info!("SQL Server row stream complete");
715                });
716
717                Ok((Response::RowStream { stream: rx }, Some(work)))
718            }
719            RequestKind::SimpleQuery { query } => {
720                let result = client.simple_query(query).await?;
721
722                let mut results = result.into_results().await.context("into results")?;
723                if results.is_empty() {
724                    Ok((Response::Rows(smallvec![]), None))
725                } else if results.len() == 1 {
726                    // TODO(sql_server3): Don't use `into_results()` above, instead directly
727                    // push onto a SmallVec to avoid the heap allocations.
728                    let rows = results.pop().expect("checked len").into();
729                    Ok((Response::Rows(rows), None))
730                } else {
731                    Err(SqlServerError::ProgrammingError(format!(
732                        "Simple query only supports 1 statement, got {}",
733                        results.len()
734                    )))
735                }
736            }
737        }
738    }
739}
740
741impl IntoFuture for Connection {
742    type Output = ();
743    type IntoFuture = BoxFuture<'static, Self::Output>;
744
745    fn into_future(self) -> Self::IntoFuture {
746        self.run().boxed()
747    }
748}
749
750/// Owned version of [`tiberius::ColumnData`] that can be more easily sent
751/// across threads or through a channel.
752#[derive(Debug)]
753enum OwnedColumnData {
754    U8(Option<u8>),
755    I16(Option<i16>),
756    I32(Option<i32>),
757    I64(Option<i64>),
758    F32(Option<f32>),
759    F64(Option<f64>),
760    Bit(Option<bool>),
761    String(Option<String>),
762    Guid(Option<uuid::Uuid>),
763    Binary(Option<Vec<u8>>),
764    Numeric(Option<tiberius::numeric::Numeric>),
765    Xml(Option<tiberius::xml::XmlData>),
766    DateTime(Option<tiberius::time::DateTime>),
767    SmallDateTime(Option<tiberius::time::SmallDateTime>),
768    Time(Option<tiberius::time::Time>),
769    Date(Option<tiberius::time::Date>),
770    DateTime2(Option<tiberius::time::DateTime2>),
771    DateTimeOffset(Option<tiberius::time::DateTimeOffset>),
772}
773
774impl<'a> From<tiberius::ColumnData<'a>> for OwnedColumnData {
775    fn from(value: tiberius::ColumnData<'a>) -> Self {
776        match value {
777            tiberius::ColumnData::U8(inner) => OwnedColumnData::U8(inner),
778            tiberius::ColumnData::I16(inner) => OwnedColumnData::I16(inner),
779            tiberius::ColumnData::I32(inner) => OwnedColumnData::I32(inner),
780            tiberius::ColumnData::I64(inner) => OwnedColumnData::I64(inner),
781            tiberius::ColumnData::F32(inner) => OwnedColumnData::F32(inner),
782            tiberius::ColumnData::F64(inner) => OwnedColumnData::F64(inner),
783            tiberius::ColumnData::Bit(inner) => OwnedColumnData::Bit(inner),
784            tiberius::ColumnData::String(inner) => {
785                OwnedColumnData::String(inner.map(|s| s.to_string()))
786            }
787            tiberius::ColumnData::Guid(inner) => OwnedColumnData::Guid(inner),
788            tiberius::ColumnData::Binary(inner) => {
789                OwnedColumnData::Binary(inner.map(|b| b.to_vec()))
790            }
791            tiberius::ColumnData::Numeric(inner) => OwnedColumnData::Numeric(inner),
792            tiberius::ColumnData::Xml(inner) => OwnedColumnData::Xml(inner.map(|x| x.into_owned())),
793            tiberius::ColumnData::DateTime(inner) => OwnedColumnData::DateTime(inner),
794            tiberius::ColumnData::SmallDateTime(inner) => OwnedColumnData::SmallDateTime(inner),
795            tiberius::ColumnData::Time(inner) => OwnedColumnData::Time(inner),
796            tiberius::ColumnData::Date(inner) => OwnedColumnData::Date(inner),
797            tiberius::ColumnData::DateTime2(inner) => OwnedColumnData::DateTime2(inner),
798            tiberius::ColumnData::DateTimeOffset(inner) => OwnedColumnData::DateTimeOffset(inner),
799        }
800    }
801}
802
803impl tiberius::ToSql for OwnedColumnData {
804    fn to_sql(&self) -> tiberius::ColumnData<'_> {
805        match self {
806            OwnedColumnData::U8(inner) => tiberius::ColumnData::U8(*inner),
807            OwnedColumnData::I16(inner) => tiberius::ColumnData::I16(*inner),
808            OwnedColumnData::I32(inner) => tiberius::ColumnData::I32(*inner),
809            OwnedColumnData::I64(inner) => tiberius::ColumnData::I64(*inner),
810            OwnedColumnData::F32(inner) => tiberius::ColumnData::F32(*inner),
811            OwnedColumnData::F64(inner) => tiberius::ColumnData::F64(*inner),
812            OwnedColumnData::Bit(inner) => tiberius::ColumnData::Bit(*inner),
813            OwnedColumnData::String(inner) => {
814                tiberius::ColumnData::String(inner.as_deref().map(Cow::Borrowed))
815            }
816            OwnedColumnData::Guid(inner) => tiberius::ColumnData::Guid(*inner),
817            OwnedColumnData::Binary(inner) => {
818                tiberius::ColumnData::Binary(inner.as_deref().map(Cow::Borrowed))
819            }
820            OwnedColumnData::Numeric(inner) => tiberius::ColumnData::Numeric(*inner),
821            OwnedColumnData::Xml(inner) => {
822                tiberius::ColumnData::Xml(inner.as_ref().map(Cow::Borrowed))
823            }
824            OwnedColumnData::DateTime(inner) => tiberius::ColumnData::DateTime(*inner),
825            OwnedColumnData::SmallDateTime(inner) => tiberius::ColumnData::SmallDateTime(*inner),
826            OwnedColumnData::Time(inner) => tiberius::ColumnData::Time(*inner),
827            OwnedColumnData::Date(inner) => tiberius::ColumnData::Date(*inner),
828            OwnedColumnData::DateTime2(inner) => tiberius::ColumnData::DateTime2(*inner),
829            OwnedColumnData::DateTimeOffset(inner) => tiberius::ColumnData::DateTimeOffset(*inner),
830        }
831    }
832}
833
834impl<'a, T: tiberius::ToSql> From<&'a T> for OwnedColumnData {
835    fn from(value: &'a T) -> Self {
836        OwnedColumnData::from(value.to_sql())
837    }
838}
839
840#[derive(Debug, thiserror::Error)]
841pub enum SqlServerError {
842    #[error(transparent)]
843    SqlServer(#[from] tiberius::error::Error),
844    #[error(transparent)]
845    CdcError(#[from] crate::cdc::CdcError),
846    #[error("expected column '{0}' to be present")]
847    MissingColumn(&'static str),
848    #[error("sql server client encountered I/O error: {0}")]
849    IO(#[from] tokio::io::Error),
850    #[error("found invalid data in the column '{column_name}': {error}")]
851    InvalidData { column_name: String, error: String },
852    #[error("got back a null value when querying for the LSN")]
853    NullLsn,
854    #[error("invalid SQL Server system setting '{name}'. Expected '{expected}'. Got '{actual}'.")]
855    InvalidSystemSetting {
856        name: String,
857        expected: String,
858        actual: String,
859    },
860    #[error("invariant was violated: {0}")]
861    InvariantViolated(String),
862    #[error(transparent)]
863    Generic(#[from] anyhow::Error),
864    #[error("programming error! {0}")]
865    ProgrammingError(String),
866    #[error(
867        "insufficient permissions for tables [{tables}] or capture instances [{capture_instances}]"
868    )]
869    AuthorizationError {
870        tables: String,
871        capture_instances: String,
872    },
873}
874
875/// Errors returned from decoding SQL Server rows.
876///
877/// **PLEASE READ**
878///
879/// The string representation of this error type is **durably stored** in a source and thus this
880/// error type needs to be **stable** across releases. For example, if in v11 of Materialize we
881/// fail to decode `Row(["foo bar"])` from SQL Server, we will record the error in the source's
882/// Persist shard. If in v12 of Materialize the user deletes the `Row(["foo bar"])` from their
883/// upstream instance, we need to perfectly retract the error we previously committed.
884///
885/// This means be **very** careful when changing this type.
886#[derive(Debug, thiserror::Error)]
887pub enum SqlServerDecodeError {
888    #[error("column '{column_name}' was invalid when getting as type '{as_type}'")]
889    InvalidColumn {
890        column_name: String,
891        as_type: &'static str,
892    },
893    #[error("found invalid data in the column '{column_name}': {error}")]
894    InvalidData { column_name: String, error: String },
895    #[error("can't decode {sql_server_type:?} as {mz_type:?}")]
896    Unsupported {
897        sql_server_type: SqlServerColumnDecodeType,
898        mz_type: SqlScalarType,
899    },
900}
901
902impl SqlServerDecodeError {
903    fn invalid_timestamp(name: &str, error: mz_repr::adt::timestamp::TimestampError) -> Self {
904        // These error messages need to remain stable, do not change them.
905        let error = match error {
906            mz_repr::adt::timestamp::TimestampError::OutOfRange => "out of range",
907        };
908        SqlServerDecodeError::InvalidData {
909            column_name: name.to_string(),
910            error: error.to_string(),
911        }
912    }
913
914    fn invalid_date(name: &str, error: mz_repr::adt::date::DateError) -> Self {
915        // These error messages need to remain stable, do not change them.
916        let error = match error {
917            mz_repr::adt::date::DateError::OutOfRange => "out of range",
918        };
919        SqlServerDecodeError::InvalidData {
920            column_name: name.to_string(),
921            error: error.to_string(),
922        }
923    }
924
925    fn invalid_char(name: &str, expected_chars: usize, found_chars: usize) -> Self {
926        SqlServerDecodeError::InvalidData {
927            column_name: name.to_string(),
928            error: format!("expected {expected_chars} chars found {found_chars}"),
929        }
930    }
931
932    fn invalid_varchar(name: &str, max_chars: usize, found_chars: usize) -> Self {
933        SqlServerDecodeError::InvalidData {
934            column_name: name.to_string(),
935            error: format!("expected max {max_chars} chars found {found_chars}"),
936        }
937    }
938
939    fn invalid_column(name: &str, as_type: &'static str) -> Self {
940        SqlServerDecodeError::InvalidColumn {
941            column_name: name.to_string(),
942            as_type,
943        }
944    }
945}
946
/// Quotes the provided string using '[]' to match SQL Server `QUOTENAME` function. This form
/// of quotes is unaffected by the SQL Server setting `SET QUOTED_IDENTIFIER`.
///
/// The identifier is wrapped in `[` and `]`, and every `]` inside it is doubled to `]]`.
/// Opening brackets are left untouched. No length limit is applied to `ident`.
///
/// See:
/// - <https://learn.microsoft.com/en-us/sql/t-sql/functions/quotename-transact-sql?view=sql-server-ver17>
/// - <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-quoted-identifier-transact-sql?view=sql-server-ver17>
pub fn quote_identifier(ident: &str) -> String {
    // Reserve room for the identifier plus the two surrounding brackets so the
    // common case (no `]` in the identifier) needs exactly one allocation.
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('[');
    for ch in ident.chars() {
        // A closing bracket is escaped by emitting it twice.
        if ch == ']' {
            quoted.push(']');
        }
        quoted.push(ch);
    }
    quoted.push(']');
    quoted
}
959
/// Callbacks that report when a table lock is taken and released, identified
/// by table name.
pub trait SqlServerCdcMetrics {
    /// Called before the table lock is acquired.
    fn snapshot_table_lock_start(&self, table_name: &str);

    /// Called after the table lock is released.
    fn snapshot_table_lock_end(&self, table_name: &str);
}
966
967/// A simple implementation of [`SqlServerCdcMetrics`] that uses the tracing framework to log
968/// the start and end conditions.
969pub struct LoggingSqlServerCdcMetrics;
970
971impl SqlServerCdcMetrics for LoggingSqlServerCdcMetrics {
972    fn snapshot_table_lock_start(&self, table_name: &str) {
973        tracing::info!("snapshot_table_lock_start: {table_name}");
974    }
975
976    fn snapshot_table_lock_end(&self, table_name: &str) {
977        tracing::info!("snapshot_table_lock_end: {table_name}");
978    }
979}
980
#[cfg(test)]
mod test {
    use super::*;

    /// Checks `quote_identifier` against a table of `(input, expected)` pairs.
    #[mz_ore::test]
    fn test_sql_server_escaping() {
        let cases = [
            ("", "[]"),
            ("]", "[]]]"),
            ("a", "[a]"),
            ("cost(]\u{00A3})", "[cost(]]\u{00A3})]"),
            ("[g[o[o][", "[[g[o[o]][]"),
        ];

        for (input, expected) in cases {
            assert_eq!(expected, &quote_identifier(input));
        }
    }
}