mz_sql_server_util/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::any::Any;
11use std::borrow::Cow;
12use std::future::IntoFuture;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use anyhow::Context;
17use derivative::Derivative;
18use futures::future::BoxFuture;
19use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
20use mz_ore::netio::DUMMY_DNS_PORT;
21use mz_ore::result::ResultExt;
22use mz_repr::SqlScalarType;
23use smallvec::{SmallVec, smallvec};
24use tiberius::ToSql;
25use tokio::net::TcpStream;
26use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
27use tokio::sync::oneshot;
28use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
29
30pub mod cdc;
31pub mod config;
32pub mod desc;
33pub mod inspect;
34
35pub use config::Config;
36pub use desc::{ProtoSqlServerColumnDesc, ProtoSqlServerTableDesc};
37
38use crate::cdc::Lsn;
39use crate::config::TunnelConfig;
40use crate::desc::SqlServerColumnDecodeType;
41
/// Higher level wrapper around a [`tiberius::Client`] that models transaction
/// management like other database clients.
///
/// A `Client` does not talk to SQL Server directly. It sends [`Request`]s over
/// an unbounded channel whose receiving half is held by a [`Connection`],
/// which owns the actual [`tiberius::Client`].
#[derive(Debug)]
pub struct Client {
    /// Sending half of the request channel; the receiving half lives in the
    /// [`Connection`] created alongside this client.
    tx: UnboundedSender<Request>,
    /// The configuration used to create this client.
    config: Config,
}
50// While a Client could implement Clone, it's not obvious how multiple Clients
51// using the same SQL Server connection would interact, so ban it for now.
52static_assertions::assert_not_impl_all!(Client: Clone);
53
54impl Client {
55    /// Connect to the specified SQL Server instance, returning a [`Client`]
56    /// that can be used to query it and a [`Connection`] that must be polled
57    /// to send and receive results.
58    ///
59    /// TODO(sql_server2): Maybe return a `ClientBuilder` here that implements
60    /// IntoFuture and does the default good thing of moving the `Connection`
61    /// into a tokio task? And a `.raw()` option that will instead return both
62    /// the Client and Connection for manual polling.
63    pub async fn connect(config: Config) -> Result<Self, SqlServerError> {
64        // Setup our tunnelling and return any resources that need to be kept
65        // alive for the duration of the connection.
66        let (tcp, resources): (_, Option<Box<dyn Any + Send + Sync>>) = match &config.tunnel {
67            TunnelConfig::Direct => {
68                let tcp = TcpStream::connect(config.inner.get_addr())
69                    .await
70                    .context("direct")?;
71                (tcp, None)
72            }
73            TunnelConfig::Ssh {
74                config: ssh_config,
75                manager,
76                timeout,
77                host,
78                port,
79            } => {
80                // N.B. If this tunnel is dropped it will close so we need to
81                // keep it alive for the duration of the connection.
82                let tunnel = manager
83                    .connect(ssh_config.clone(), host, *port, *timeout, config.in_task)
84                    .await?;
85                let tcp = TcpStream::connect(tunnel.local_addr())
86                    .await
87                    .context("ssh tunnel")?;
88
89                (tcp, Some(Box::new(tunnel)))
90            }
91            TunnelConfig::AwsPrivatelink {
92                connection_id,
93                port,
94            } => {
95                let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
96                let mut privatelink_addrs =
97                    tokio::net::lookup_host((privatelink_host.clone(), DUMMY_DNS_PORT)).await?;
98
99                let Some(mut addr) = privatelink_addrs.next() else {
100                    return Err(SqlServerError::InvariantViolated(format!(
101                        "aws privatelink: no addresses found for host {:?}",
102                        privatelink_host
103                    )));
104                };
105
106                addr.set_port(port.clone());
107
108                let tcp = TcpStream::connect(addr)
109                    .await
110                    .context(format!("aws privatelink {:?}", addr))?;
111
112                (tcp, None)
113            }
114        };
115
116        tcp.set_nodelay(true)?;
117
118        let (client, connection) = Self::connect_raw(config, tcp, resources).await?;
119        mz_ore::task::spawn(|| "sql-server-client-connection", async move {
120            connection.await
121        });
122
123        Ok(client)
124    }
125
    /// Create a new [`Client`] by connecting again with a clone of the
    /// configuration that was used to create this client.
    ///
    /// See [`Client::connect`] for the connection behavior and its errors.
    pub async fn new_connection(&self) -> Result<Self, SqlServerError> {
        Self::connect(self.config.clone()).await
    }
131
132    pub async fn connect_raw(
133        config: Config,
134        tcp: tokio::net::TcpStream,
135        resources: Option<Box<dyn Any + Send + Sync>>,
136    ) -> Result<(Self, Connection), SqlServerError> {
137        let client = tiberius::Client::connect(config.inner.clone(), tcp.compat_write()).await?;
138        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
139
140        // TODO(sql_server2): Add a lot more logging here like the Postgres and MySQL clients have.
141
142        Ok((
143            Client { tx, config },
144            Connection {
145                rx,
146                client,
147                _resources: resources,
148            },
149        ))
150    }
151
152    /// Executes SQL statements in SQL Server, returning the number of rows effected.
153    ///
154    /// Passthrough method for [`tiberius::Client::execute`].
155    ///
156    /// Note: The returned [`Future`] does not need to be awaited for the query
157    /// to be sent.
158    ///
159    /// [`Future`]: std::future::Future
160    pub async fn execute<'a>(
161        &mut self,
162        query: impl Into<Cow<'a, str>>,
163        params: &[&dyn ToSql],
164    ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
165        let (tx, rx) = tokio::sync::oneshot::channel();
166
167        let params = params
168            .iter()
169            .map(|p| OwnedColumnData::from(p.to_sql()))
170            .collect();
171        let kind = RequestKind::Execute {
172            query: query.into().to_string(),
173            params,
174        };
175        self.tx
176            .send(Request { tx, kind })
177            .context("sending request")?;
178
179        let response = rx.await.context("channel")??;
180        match response {
181            Response::Execute { rows_affected } => Ok(rows_affected),
182            other @ Response::Rows(_) | other @ Response::RowStream { .. } => {
183                Err(SqlServerError::ProgrammingError(format!(
184                    "expected Response::Execute, got {other:?}"
185                )))
186            }
187        }
188    }
189
190    /// Executes SQL statements in SQL Server, returning the resulting rows.
191    ///
192    /// Passthrough method for [`tiberius::Client::query`].
193    ///
194    /// Note: The returned [`Future`] does not need to be awaited for the query
195    /// to be sent.
196    ///
197    /// [`Future`]: std::future::Future
198    pub async fn query<'a>(
199        &mut self,
200        query: impl Into<Cow<'a, str>>,
201        params: &[&dyn tiberius::ToSql],
202    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
203        let (tx, rx) = tokio::sync::oneshot::channel();
204
205        let params = params
206            .iter()
207            .map(|p| OwnedColumnData::from(p.to_sql()))
208            .collect();
209        let kind = RequestKind::Query {
210            query: query.into().to_string(),
211            params,
212        };
213        self.tx
214            .send(Request { tx, kind })
215            .context("sending request")?;
216
217        let response = rx.await.context("channel")??;
218        match response {
219            Response::Rows(rows) => Ok(rows),
220            other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
221                SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
222            ),
223        }
224    }
225
226    /// Executes SQL statements in SQL Server, returning a [`Stream`] of
227    /// resulting rows.
228    ///
229    /// Passthrough method for [`tiberius::Client::query`].
230    pub fn query_streaming<'c, 'q, Q>(
231        &'c mut self,
232        query: Q,
233        params: &[&dyn tiberius::ToSql],
234    ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
235    where
236        Q: Into<Cow<'q, str>>,
237    {
238        let (tx, rx) = tokio::sync::oneshot::channel();
239        let params = params
240            .iter()
241            .map(|p| OwnedColumnData::from(p.to_sql()))
242            .collect();
243        let kind = RequestKind::QueryStreamed {
244            query: query.into().to_string(),
245            params,
246        };
247
248        // Make our initial request which will return a Stream of Rows.
249        let request_future = async move {
250            self.tx
251                .send(Request { tx, kind })
252                .context("sending request")?;
253
254            let response = rx.await.context("channel")??;
255            match response {
256                Response::RowStream { stream } => {
257                    Ok(tokio_stream::wrappers::ReceiverStream::new(stream))
258                }
259                other @ Response::Execute { .. } | other @ Response::Rows(_) => {
260                    Err(SqlServerError::ProgrammingError(format!(
261                        "expected Response::Rows, got {other:?}"
262                    )))
263                }
264            }
265        };
266
267        // "flatten" our initial request into the returned stream.
268        futures::stream::once(request_future).try_flatten()
269    }
270
271    /// Executes multiple queries, delimited with `;` and return multiple
272    /// result sets; one for each query.
273    ///
274    /// Passthrough method for [`tiberius::Client::simple_query`].
275    ///
276    /// Note: The returned [`Future`] does not need to be awaited for the query
277    /// to be sent.
278    ///
279    /// [`Future`]: std::future::Future
280    pub async fn simple_query<'a>(
281        &mut self,
282        query: impl Into<Cow<'a, str>>,
283    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
284        let (tx, rx) = tokio::sync::oneshot::channel();
285        let kind = RequestKind::SimpleQuery {
286            query: query.into().to_string(),
287        };
288        self.tx
289            .send(Request { tx, kind })
290            .context("sending request")?;
291
292        let response = rx.await.context("channel")??;
293        match response {
294            Response::Rows(rows) => Ok(rows),
295            other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
296                SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
297            ),
298        }
299    }
300
    /// Starts a transaction which is automatically rolled back on drop.
    ///
    /// To commit or rollback the transaction, see [`Transaction::commit`] and
    /// [`Transaction::rollback`] respectively.
    ///
    /// The returned [`Transaction`] mutably borrows this client, so the client
    /// cannot be used directly until the transaction is gone.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Transaction::new` reports while beginning the
    /// transaction.
    pub async fn transaction(&mut self) -> Result<Transaction<'_>, SqlServerError> {
        Transaction::new(self).await
    }
308
309    /// Sets the transaction isolation level for the current session.
310    pub async fn set_transaction_isolation(
311        &mut self,
312        level: TransactionIsolationLevel,
313    ) -> Result<(), SqlServerError> {
314        let query = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_str());
315        self.simple_query(query).await?;
316        Ok(())
317    }
318
319    /// Returns the current transaction isolation level for the current session.
320    pub async fn get_transaction_isolation(
321        &mut self,
322    ) -> Result<TransactionIsolationLevel, SqlServerError> {
323        const QUERY: &str = "SELECT transaction_isolation_level FROM sys.dm_exec_sessions where session_id = @@SPID;";
324        let rows = self.simple_query(QUERY).await?;
325        match &rows[..] {
326            [row] => {
327                let val: i16 = row
328                    .try_get(0)
329                    .context("getting 0th column")?
330                    .ok_or_else(|| anyhow::anyhow!("no 0th column?"))?;
331                let level = TransactionIsolationLevel::try_from_sql_server(val)?;
332                Ok(level)
333            }
334            other => Err(SqlServerError::InvariantViolated(format!(
335                "expected one row, got {other:?}"
336            ))),
337        }
338    }
339
340    /// Return a [`CdcStream`] that can be used to track changes for the specified
341    /// `capture_instances`.
342    ///
343    /// [`CdcStream`]: crate::cdc::CdcStream
344    pub fn cdc<I, M>(&mut self, capture_instances: I, metrics: M) -> crate::cdc::CdcStream<'_, M>
345    where
346        I: IntoIterator,
347        I::Item: Into<Arc<str>>,
348        M: SqlServerCdcMetrics,
349    {
350        let instances = capture_instances
351            .into_iter()
352            .map(|i| (i.into(), None))
353            .collect();
354        crate::cdc::CdcStream::new(self, instances, metrics)
355    }
356}
357
/// A boxed, pinned, `Send` stream of [`tiberius::Row`]s.
///
/// Each item is a `Result`, so an error can be yielded in place of a row.
pub type RowStream<'a> =
    Pin<Box<dyn Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + 'a>>;
361
/// A SQL Server transaction opened on a [`Client`].
///
/// The transaction mutably borrows the [`Client`] for its whole lifetime and
/// forwards queries to it. Finish it with [`Transaction::commit`] or
/// [`Transaction::rollback`].
#[derive(Debug)]
pub struct Transaction<'a> {
    /// Client that this transaction's statements are sent through.
    client: &'a mut Client,
    /// Set once [`Transaction::commit`] or [`Transaction::rollback`] has been
    /// called; consulted on drop to decide whether a rollback is still needed.
    closed: bool,
}
367
368impl<'a> Transaction<'a> {
369    async fn new(client: &'a mut Client) -> Result<Self, SqlServerError> {
370        let results = client
371            .simple_query("BEGIN TRANSACTION")
372            .await
373            .context("begin")?;
374        if !results.is_empty() {
375            Err(SqlServerError::InvariantViolated(format!(
376                "expected empty result from BEGIN TRANSACTION. Got: {results:?}"
377            )))
378        } else {
379            Ok(Transaction {
380                client,
381                closed: false,
382            })
383        }
384    }
385
386    /// Creates a savepoint via `SAVE TRANSACTION` with the provided name.
387    /// Creating a savepoint forces a write to the transaction log, which will associate an
388    /// [`Lsn`] with the current transaction.
389    ///
390    /// The savepoint name must follow rules for SQL Server identifiers
391    /// - starts with letter or underscore
392    /// - only contains letters, digits, and underscores
393    /// - no reserved words
394    /// - 32 char max
395    pub async fn create_savepoint(&mut self, savepoint_name: &str) -> Result<(), SqlServerError> {
396        // Limit the name checks to prevent sending a potentially dangerous string to the SQL Server.
397        // We prefer the server do the majority of the validation.
398        if savepoint_name.is_empty()
399            || !savepoint_name
400                .chars()
401                .all(|c| c.is_alphanumeric() || c == '_')
402        {
403            Err(SqlServerError::ProgrammingError(format!(
404                "Invalid savepoint name: '{savepoint_name}"
405            )))?;
406        }
407
408        let stmt = format!("SAVE TRANSACTION [{savepoint_name}]");
409        let _result = self.client.simple_query(stmt).await?;
410        Ok(())
411    }
412
    /// Retrieve the [`Lsn`] associated with the current transaction.
    ///
    /// The query reads the most recent savepoint LSN of the current transaction and skips
    /// entries where that LSN is `NULL`.
    ///
    /// MS SQL Server will not assign an [`Lsn`] until a write is performed (e.g. via `SAVE TRANSACTION`).
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails or if `crate::inspect::parse_numeric_lsn` rejects
    /// the returned rows.
    pub async fn get_lsn(&mut self) -> Result<Lsn, SqlServerError> {
        static CURRENT_LSN_QUERY: &str = "SELECT dt.database_transaction_most_recent_savepoint_lsn \
            FROM sys.dm_tran_database_transactions dt \
            JOIN sys.dm_tran_current_transaction ct \
                ON ct.transaction_id = dt.transaction_id \
            WHERE dt.database_transaction_most_recent_savepoint_lsn IS NOT NULL";
        let result = self.client.simple_query(CURRENT_LSN_QUERY).await?;
        crate::inspect::parse_numeric_lsn(&result)
    }
425
426    /// Lock the provided table to prevent writes but allow reads, uses `(TABLOCK, HOLDLOCK)`.
427    ///
428    /// This will set the transaction isolation level to `READ COMMITTED` and then obtain the
429    /// lock using a `SELECT` statement that will not read any data from the table.
430    /// The lock is released after transaction commit or rollback.
431    pub async fn lock_table_shared(
432        &mut self,
433        schema: &str,
434        table: &str,
435    ) -> Result<(), SqlServerError> {
436        // Locks in MS SQL server do not behave the same way under all isolation levels. In testing,
437        // it has been observed that if the isolation level is SNAPSHOT, these locks are ineffective.
438        static SET_READ_COMMITTED: &str = "SET TRANSACTION ISOLATION LEVEL READ COMMITTED;";
439        // This query probably seems odd, but there is no LOCK command in MS SQL. Locks are specified
440        // in SELECT using the WITH keyword.  This query does not need to return any rows to lock the table,
441        // hence the 1=0, which is something short that always evaluates to false in this universe.
442        let query = format!(
443            "{SET_READ_COMMITTED}\nSELECT * FROM [{schema}].[{table}] WITH (TABLOCK, HOLDLOCK) WHERE 1=0;"
444        );
445        let _result = self.client.simple_query(query).await?;
446        Ok(())
447    }
448
    /// See [`Client::execute`].
    ///
    /// Forwards the call unchanged to the [`Client`] this transaction borrows.
    pub async fn execute<'q>(
        &mut self,
        query: impl Into<Cow<'q, str>>,
        params: &[&dyn ToSql],
    ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
        self.client.execute(query, params).await
    }
457
    /// See [`Client::query`].
    ///
    /// Forwards the call unchanged to the [`Client`] this transaction borrows.
    pub async fn query<'q>(
        &mut self,
        query: impl Into<Cow<'q, str>>,
        params: &[&dyn tiberius::ToSql],
    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
        self.client.query(query, params).await
    }
466
    /// See [`Client::query_streaming`].
    ///
    /// Forwards the call unchanged to the [`Client`] this transaction borrows.
    pub fn query_streaming<'c, 'q, Q>(
        &'c mut self,
        query: Q,
        params: &[&dyn tiberius::ToSql],
    ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
    where
        Q: Into<Cow<'q, str>>,
    {
        self.client.query_streaming(query, params)
    }
478
    /// See [`Client::simple_query`].
    ///
    /// Forwards the call unchanged to the [`Client`] this transaction borrows.
    pub async fn simple_query<'q>(
        &mut self,
        query: impl Into<Cow<'q, str>>,
    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
        self.client.simple_query(query).await
    }
486
487    /// Rollback the [`Transaction`].
488    pub async fn rollback(mut self) -> Result<(), SqlServerError> {
489        static ROLLBACK_QUERY: &str = "ROLLBACK TRANSACTION";
490        // N.B. Mark closed _before_ running the query. This prevents us from
491        // double closing the transaction if this query itself fails.
492        self.closed = true;
493        self.client.simple_query(ROLLBACK_QUERY).await?;
494        Ok(())
495    }
496
497    /// Commit the [`Transaction`].
498    pub async fn commit(mut self) -> Result<(), SqlServerError> {
499        static COMMIT_QUERY: &str = "COMMIT TRANSACTION";
500        // N.B. Mark closed _before_ running the query. This prevents us from
501        // double closing the transaction if this query itself fails.
502        self.closed = true;
503        self.client.simple_query(COMMIT_QUERY).await?;
504        Ok(())
505    }
506}
507
508impl Drop for Transaction<'_> {
509    fn drop(&mut self) {
510        // Internally the query is synchronously sent down a channel, and the response is what
511        // we await. In other words, we don't need to `.await` here for the query to be run.
512        if !self.closed {
513            let _fut = self.client.simple_query("ROLLBACK TRANSACTION");
514        }
515    }
516}
517
/// Transaction isolation levels defined by Microsoft's SQL Server.
///
/// This is a plain fieldless enum, so it is `Copy` and can be passed around by value.
///
/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql>
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionIsolationLevel {
    ReadUncommitted,
    ReadCommitted,
    RepeatableRead,
    Snapshot,
    Serializable,
}
529
530impl TransactionIsolationLevel {
531    /// Return the string representation of a transaction isolation level.
532    fn as_str(&self) -> &'static str {
533        match self {
534            TransactionIsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
535            TransactionIsolationLevel::ReadCommitted => "READ COMMITTED",
536            TransactionIsolationLevel::RepeatableRead => "REPEATABLE READ",
537            TransactionIsolationLevel::Snapshot => "SNAPSHOT",
538            TransactionIsolationLevel::Serializable => "SERIALIZABLE",
539        }
540    }
541
542    /// Try to parse a [`TransactionIsolationLevel`] from the value returned from SQL Server.
543    fn try_from_sql_server(val: i16) -> Result<TransactionIsolationLevel, anyhow::Error> {
544        let level = match val {
545            1 => TransactionIsolationLevel::ReadUncommitted,
546            2 => TransactionIsolationLevel::ReadCommitted,
547            3 => TransactionIsolationLevel::RepeatableRead,
548            4 => TransactionIsolationLevel::Serializable,
549            5 => TransactionIsolationLevel::Snapshot,
550            x => anyhow::bail!("unknown level {x}"),
551        };
552        Ok(level)
553    }
554}
555
/// Reply sent from a [`Connection`] back to the [`Client`] for a single [`Request`].
#[derive(Derivative)]
#[derivative(Debug)]
enum Response {
    /// Reply to [`RequestKind::Execute`].
    Execute {
        /// Affected-row counts as reported by the `tiberius` execute result.
        rows_affected: SmallVec<[u64; 1]>,
    },
    /// Reply to [`RequestKind::Query`] and [`RequestKind::SimpleQuery`]: the rows of the
    /// (at most one) result set.
    Rows(SmallVec<[tiberius::Row; 1]>),
    /// Reply to [`RequestKind::QueryStreamed`].
    RowStream {
        /// Receiving half of the channel that the [`Connection`] pushes rows into.
        #[derivative(Debug = "ignore")]
        stream: tokio::sync::mpsc::Receiver<Result<tiberius::Row, SqlServerError>>,
    },
}
568
/// A unit of work sent from a [`Client`] to its [`Connection`].
#[derive(Debug)]
struct Request {
    /// One-shot channel on which the [`Connection`] sends the outcome of the request.
    tx: oneshot::Sender<Result<Response, SqlServerError>>,
    /// What the [`Connection`] should do.
    kind: RequestKind,
}
574
/// The operation a [`Request`] asks the [`Connection`] to perform.
///
/// Query parameters are left out of the `Debug` output.
#[derive(Derivative)]
#[derivative(Debug)]
enum RequestKind {
    /// Run `query` via `tiberius::Client::execute`; answered with [`Response::Execute`].
    Execute {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run `query` via `tiberius::Client::query` and collect the rows; answered with
    /// [`Response::Rows`].
    Query {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run `query` via `tiberius::Client::query` and stream the rows back; answered with
    /// [`Response::RowStream`].
    QueryStreamed {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run `query` via `tiberius::Client::simple_query`; answered with [`Response::Rows`].
    SimpleQuery {
        query: String,
    },
}
597
/// The half of a SQL Server connection that actually talks to the server.
///
/// A `Connection` receives [`Request`]s from its [`Client`] and runs them one at a time on the
/// wrapped [`tiberius::Client`]. It implements [`IntoFuture`], so it is driven by awaiting it.
pub struct Connection {
    /// Other end of the channel that [`Client`] holds.
    rx: UnboundedReceiver<Request>,
    /// Actual client that we use to send requests.
    client: tiberius::Client<Compat<TcpStream>>,
    /// Resources (e.g. SSH tunnel) that need to be held open for the life of this connection.
    _resources: Option<Box<dyn Any + Send + Sync>>,
}
606
607impl Connection {
608    async fn run(mut self) {
609        while let Some(Request { tx, kind }) = self.rx.recv().await {
610            tracing::trace!(?kind, "processing SQL Server query");
611            let result = Connection::handle_request(&mut self.client, kind).await;
612            let (response, maybe_extra_work) = match result {
613                Ok((response, work)) => (Ok(response), work),
614                Err(err) => (Err(err), None),
615            };
616
617            // We don't care if our listener for this query has gone away.
618            let _ = tx.send(response);
619
620            // After we handle a request there might still be something in-flight
621            // that we need to continue driving, e.g. when the response is a
622            // Stream of Rows.
623            if let Some(extra_work) = maybe_extra_work {
624                extra_work.await;
625            }
626        }
627        tracing::debug!("channel closed, SQL Server InnerClient shutting down");
628    }
629
630    async fn handle_request<'c>(
631        client: &'c mut tiberius::Client<Compat<TcpStream>>,
632        kind: RequestKind,
633    ) -> Result<(Response, Option<BoxFuture<'c, ()>>), SqlServerError> {
634        match kind {
635            RequestKind::Execute { query, params } => {
636                #[allow(clippy::as_conversions)]
637                let params: SmallVec<[&dyn ToSql; 4]> =
638                    params.iter().map(|x| x as &dyn ToSql).collect();
639                let result = client.execute(query, &params[..]).await?;
640
641                match result.rows_affected() {
642                    rows_affected => {
643                        let response = Response::Execute {
644                            rows_affected: rows_affected.into(),
645                        };
646                        Ok((response, None))
647                    }
648                }
649            }
650            RequestKind::Query { query, params } => {
651                #[allow(clippy::as_conversions)]
652                let params: SmallVec<[&dyn ToSql; 4]> =
653                    params.iter().map(|x| x as &dyn ToSql).collect();
654                let result = client.query(query, params.as_slice()).await?;
655
656                let mut results = result.into_results().await.context("into results")?;
657                if results.is_empty() {
658                    Ok((Response::Rows(smallvec![]), None))
659                } else if results.len() == 1 {
660                    // TODO(sql_server3): Don't use `into_results()` above, instead directly
661                    // push onto a SmallVec to avoid the heap allocations.
662                    let rows = results.pop().expect("checked len").into();
663                    Ok((Response::Rows(rows), None))
664                } else {
665                    Err(SqlServerError::ProgrammingError(format!(
666                        "Query only supports 1 statement, got {}",
667                        results.len()
668                    )))
669                }
670            }
671            RequestKind::QueryStreamed { query, params } => {
672                #[allow(clippy::as_conversions)]
673                let params: SmallVec<[&dyn ToSql; 4]> =
674                    params.iter().map(|x| x as &dyn ToSql).collect();
675                let result = client.query(query, params.as_slice()).await?;
676
677                // ~~ Rust Lifetimes ~~
678                //
679                // What's going on here, why do we have some extra channel and
680                // this 'work' future?
681                //
682                // Remember, we run the actual `tiberius::Client` in a separate
683                // `tokio::task` and the `mz::Client` sends query requests via
684                // a channel, this allows us to "automatically" manage
685                // transactions.
686                //
687                // But the returned `QueryStream` from a `tiberius::Client` has
688                // a lifetime associated with said client running in this
689                // separate task. Thus we cannot send the `QueryStream` back to
690                // the `mz::Client` because the lifetime of these two clients
691                // is not linked at all. The fix is to create a separate owned
692                // channel and return the receiving end, while this work future
693                // pulls events off the `QueryStream` and sends them over the
694                // channel we just returned.
695                let (tx, rx) = tokio::sync::mpsc::channel(256);
696                let work = Box::pin(async move {
697                    let mut stream = result.into_row_stream();
698                    while let Some(result) = stream.next().await {
699                        if let Err(err) = tx.send(result.err_into()).await {
700                            tracing::warn!(?err, "SQL Server row stream receiver went away");
701                        }
702                    }
703                    tracing::info!("SQL Server row stream complete");
704                });
705
706                Ok((Response::RowStream { stream: rx }, Some(work)))
707            }
708            RequestKind::SimpleQuery { query } => {
709                let result = client.simple_query(query).await?;
710
711                let mut results = result.into_results().await.context("into results")?;
712                if results.is_empty() {
713                    Ok((Response::Rows(smallvec![]), None))
714                } else if results.len() == 1 {
715                    // TODO(sql_server3): Don't use `into_results()` above, instead directly
716                    // push onto a SmallVec to avoid the heap allocations.
717                    let rows = results.pop().expect("checked len").into();
718                    Ok((Response::Rows(rows), None))
719                } else {
720                    Err(SqlServerError::ProgrammingError(format!(
721                        "Simple query only supports 1 statement, got {}",
722                        results.len()
723                    )))
724                }
725            }
726        }
727    }
728}
729
730impl IntoFuture for Connection {
731    type Output = ();
732    type IntoFuture = BoxFuture<'static, Self::Output>;
733
734    fn into_future(self) -> Self::IntoFuture {
735        self.run().boxed()
736    }
737}
738
/// Owned version of [`tiberius::ColumnData`] that can be more easily sent
/// across threads or through a channel.
///
/// Each variant corresponds to the [`tiberius::ColumnData`] variant of the
/// same name; `None` is the inner value of a variant that carries no data.
/// Variants that borrow in `tiberius` (`String`, `Binary`, `Xml`) hold owned
/// data here.
#[derive(Debug)]
enum OwnedColumnData {
    U8(Option<u8>),
    I16(Option<i16>),
    I32(Option<i32>),
    I64(Option<i64>),
    F32(Option<f32>),
    F64(Option<f64>),
    Bit(Option<bool>),
    String(Option<String>),
    Guid(Option<uuid::Uuid>),
    Binary(Option<Vec<u8>>),
    Numeric(Option<tiberius::numeric::Numeric>),
    Xml(Option<tiberius::xml::XmlData>),
    DateTime(Option<tiberius::time::DateTime>),
    SmallDateTime(Option<tiberius::time::SmallDateTime>),
    Time(Option<tiberius::time::Time>),
    Date(Option<tiberius::time::Date>),
    DateTime2(Option<tiberius::time::DateTime2>),
    DateTimeOffset(Option<tiberius::time::DateTimeOffset>),
}
762
763impl<'a> From<tiberius::ColumnData<'a>> for OwnedColumnData {
764    fn from(value: tiberius::ColumnData<'a>) -> Self {
765        match value {
766            tiberius::ColumnData::U8(inner) => OwnedColumnData::U8(inner),
767            tiberius::ColumnData::I16(inner) => OwnedColumnData::I16(inner),
768            tiberius::ColumnData::I32(inner) => OwnedColumnData::I32(inner),
769            tiberius::ColumnData::I64(inner) => OwnedColumnData::I64(inner),
770            tiberius::ColumnData::F32(inner) => OwnedColumnData::F32(inner),
771            tiberius::ColumnData::F64(inner) => OwnedColumnData::F64(inner),
772            tiberius::ColumnData::Bit(inner) => OwnedColumnData::Bit(inner),
773            tiberius::ColumnData::String(inner) => {
774                OwnedColumnData::String(inner.map(|s| s.to_string()))
775            }
776            tiberius::ColumnData::Guid(inner) => OwnedColumnData::Guid(inner),
777            tiberius::ColumnData::Binary(inner) => {
778                OwnedColumnData::Binary(inner.map(|b| b.to_vec()))
779            }
780            tiberius::ColumnData::Numeric(inner) => OwnedColumnData::Numeric(inner),
781            tiberius::ColumnData::Xml(inner) => OwnedColumnData::Xml(inner.map(|x| x.into_owned())),
782            tiberius::ColumnData::DateTime(inner) => OwnedColumnData::DateTime(inner),
783            tiberius::ColumnData::SmallDateTime(inner) => OwnedColumnData::SmallDateTime(inner),
784            tiberius::ColumnData::Time(inner) => OwnedColumnData::Time(inner),
785            tiberius::ColumnData::Date(inner) => OwnedColumnData::Date(inner),
786            tiberius::ColumnData::DateTime2(inner) => OwnedColumnData::DateTime2(inner),
787            tiberius::ColumnData::DateTimeOffset(inner) => OwnedColumnData::DateTimeOffset(inner),
788        }
789    }
790}
791
792impl tiberius::ToSql for OwnedColumnData {
793    fn to_sql(&self) -> tiberius::ColumnData<'_> {
794        match self {
795            OwnedColumnData::U8(inner) => tiberius::ColumnData::U8(*inner),
796            OwnedColumnData::I16(inner) => tiberius::ColumnData::I16(*inner),
797            OwnedColumnData::I32(inner) => tiberius::ColumnData::I32(*inner),
798            OwnedColumnData::I64(inner) => tiberius::ColumnData::I64(*inner),
799            OwnedColumnData::F32(inner) => tiberius::ColumnData::F32(*inner),
800            OwnedColumnData::F64(inner) => tiberius::ColumnData::F64(*inner),
801            OwnedColumnData::Bit(inner) => tiberius::ColumnData::Bit(*inner),
802            OwnedColumnData::String(inner) => {
803                tiberius::ColumnData::String(inner.as_deref().map(Cow::Borrowed))
804            }
805            OwnedColumnData::Guid(inner) => tiberius::ColumnData::Guid(*inner),
806            OwnedColumnData::Binary(inner) => {
807                tiberius::ColumnData::Binary(inner.as_deref().map(Cow::Borrowed))
808            }
809            OwnedColumnData::Numeric(inner) => tiberius::ColumnData::Numeric(*inner),
810            OwnedColumnData::Xml(inner) => {
811                tiberius::ColumnData::Xml(inner.as_ref().map(Cow::Borrowed))
812            }
813            OwnedColumnData::DateTime(inner) => tiberius::ColumnData::DateTime(*inner),
814            OwnedColumnData::SmallDateTime(inner) => tiberius::ColumnData::SmallDateTime(*inner),
815            OwnedColumnData::Time(inner) => tiberius::ColumnData::Time(*inner),
816            OwnedColumnData::Date(inner) => tiberius::ColumnData::Date(*inner),
817            OwnedColumnData::DateTime2(inner) => tiberius::ColumnData::DateTime2(*inner),
818            OwnedColumnData::DateTimeOffset(inner) => tiberius::ColumnData::DateTimeOffset(*inner),
819        }
820    }
821}
822
823impl<'a, T: tiberius::ToSql> From<&'a T> for OwnedColumnData {
824    fn from(value: &'a T) -> Self {
825        OwnedColumnData::from(value.to_sql())
826    }
827}
828
829#[derive(Debug, thiserror::Error)]
830pub enum SqlServerError {
831    #[error(transparent)]
832    SqlServer(#[from] tiberius::error::Error),
833    #[error(transparent)]
834    CdcError(#[from] crate::cdc::CdcError),
835    #[error("expected column '{0}' to be present")]
836    MissingColumn(&'static str),
837    #[error("sql server client encountered I/O error: {0}")]
838    IO(#[from] tokio::io::Error),
839    #[error("found invalid data in the column '{column_name}': {error}")]
840    InvalidData { column_name: String, error: String },
841    #[error("got back a null value when querying for the LSN")]
842    NullLsn,
843    #[error("invalid SQL Server system setting '{name}'. Expected '{expected}'. Got '{actual}'.")]
844    InvalidSystemSetting {
845        name: String,
846        expected: String,
847        actual: String,
848    },
849    #[error("invariant was violated: {0}")]
850    InvariantViolated(String),
851    #[error(transparent)]
852    Generic(#[from] anyhow::Error),
853    #[error("programming error! {0}")]
854    ProgrammingError(String),
855    #[error(
856        "insufficient permissions for tables [{tables}] or capture instances [{capture_instances}]"
857    )]
858    AuthorizationError {
859        tables: String,
860        capture_instances: String,
861    },
862}
863
864/// Errors returned from decoding SQL Server rows.
865///
866/// **PLEASE READ**
867///
868/// The string representation of this error type is **durably stored** in a source and thus this
869/// error type needs to be **stable** across releases. For example, if in v11 of Materialize we
870/// fail to decode `Row(["foo bar"])` from SQL Server, we will record the error in the source's
871/// Persist shard. If in v12 of Materialize the user deletes the `Row(["foo bar"])` from their
872/// upstream instance, we need to perfectly retract the error we previously committed.
873///
874/// This means be **very** careful when changing this type.
875#[derive(Debug, thiserror::Error)]
876pub enum SqlServerDecodeError {
877    #[error("column '{column_name}' was invalid when getting as type '{as_type}'")]
878    InvalidColumn {
879        column_name: String,
880        as_type: &'static str,
881    },
882    #[error("found invalid data in the column '{column_name}': {error}")]
883    InvalidData { column_name: String, error: String },
884    #[error("can't decode {sql_server_type:?} as {mz_type:?}")]
885    Unsupported {
886        sql_server_type: SqlServerColumnDecodeType,
887        mz_type: SqlScalarType,
888    },
889}
890
891impl SqlServerDecodeError {
892    fn invalid_timestamp(name: &str, error: mz_repr::adt::timestamp::TimestampError) -> Self {
893        // These error messages need to remain stable, do not change them.
894        let error = match error {
895            mz_repr::adt::timestamp::TimestampError::OutOfRange => "out of range",
896        };
897        SqlServerDecodeError::InvalidData {
898            column_name: name.to_string(),
899            error: error.to_string(),
900        }
901    }
902
903    fn invalid_date(name: &str, error: mz_repr::adt::date::DateError) -> Self {
904        // These error messages need to remain stable, do not change them.
905        let error = match error {
906            mz_repr::adt::date::DateError::OutOfRange => "out of range",
907        };
908        SqlServerDecodeError::InvalidData {
909            column_name: name.to_string(),
910            error: error.to_string(),
911        }
912    }
913
914    fn invalid_char(name: &str, expected_chars: usize, found_chars: usize) -> Self {
915        SqlServerDecodeError::InvalidData {
916            column_name: name.to_string(),
917            error: format!("expected {expected_chars} chars found {found_chars}"),
918        }
919    }
920
921    fn invalid_varchar(name: &str, max_chars: usize, found_chars: usize) -> Self {
922        SqlServerDecodeError::InvalidData {
923            column_name: name.to_string(),
924            error: format!("expected max {max_chars} chars found {found_chars}"),
925        }
926    }
927
928    fn invalid_column(name: &str, as_type: &'static str) -> Self {
929        SqlServerDecodeError::InvalidColumn {
930            column_name: name.to_string(),
931            as_type,
932        }
933    }
934}
935
/// Wraps `ident` in square brackets, doubling every `]` it contains, to match
/// the output of the SQL Server `QUOTENAME` function. Bracket quoting is not
/// affected by the SQL Server setting `SET QUOTED_IDENTIFIER`.
///
/// See:
/// - <https://learn.microsoft.com/en-us/sql/t-sql/functions/quotename-transact-sql?view=sql-server-ver17>
/// - <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-quoted-identifier-transact-sql?view=sql-server-ver17>
pub fn quote_identifier(ident: &str) -> String {
    // Room for the two delimiters; any escaped `]` grows the buffer on demand.
    let mut out = String::with_capacity(ident.len() + 2);
    out.push('[');
    for ch in ident.chars() {
        if ch == ']' {
            // A closing bracket is escaped by writing it twice.
            out.push(']');
        }
        out.push(ch);
    }
    out.push(']');
    out
}
948
/// Hooks that are notified around the acquisition and release of a table lock.
pub trait SqlServerCdcMetrics {
    /// Invoked before the table lock is acquired.
    fn snapshot_table_lock_start(&self, table_name: &str);

    /// Invoked after the table lock is released.
    fn snapshot_table_lock_end(&self, table_name: &str);
}
955
956/// A simple implementation of [`SqlServerCdcMetrics`] that uses the tracing framework to log
957/// the start and end conditions.
958pub struct LoggingSqlServerCdcMetrics;
959
960impl SqlServerCdcMetrics for LoggingSqlServerCdcMetrics {
961    fn snapshot_table_lock_start(&self, table_name: &str) {
962        tracing::info!("snapshot_table_lock_start: {table_name}");
963    }
964
965    fn snapshot_table_lock_end(&self, table_name: &str) {
966        tracing::info!("snapshot_table_lock_end: {table_name}");
967    }
968}
969
#[cfg(test)]
mod test {
    use super::*;

    /// Checks bracket quoting, including the doubling of embedded `]`.
    #[mz_ore::test]
    fn test_sql_server_escaping() {
        // (raw identifier, expected quoted form)
        let cases: [(&str, &str); 5] = [
            ("", "[]"),
            ("]", "[]]]"),
            ("a", "[a]"),
            ("cost(]\u{00A3})", "[cost(]]\u{00A3})]"),
            ("[g[o[o][", "[[g[o[o]][]"),
        ];

        for (raw, expected) in cases {
            assert_eq!(expected, quote_identifier(raw), "quoting {raw:?}");
        }
    }
}