mz_sql_server_util/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::any::Any;
11use std::borrow::Cow;
12use std::future::IntoFuture;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use anyhow::Context;
17use derivative::Derivative;
18use futures::future::BoxFuture;
19use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
20use mz_ore::result::ResultExt;
21use mz_repr::ScalarType;
22use smallvec::{SmallVec, smallvec};
23use tiberius::ToSql;
24use tokio::net::TcpStream;
25use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
26use tokio::sync::oneshot;
27use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
28
29pub mod cdc;
30pub mod config;
31pub mod desc;
32pub mod inspect;
33
34pub use config::Config;
35pub use desc::{ProtoSqlServerColumnDesc, ProtoSqlServerTableDesc};
36
37use crate::config::TunnelConfig;
38use crate::desc::SqlServerColumnDecodeType;
39
40/// Higher level wrapper around a [`tiberius::Client`] that models transaction
41/// management like other database clients.
42#[derive(Debug)]
43pub struct Client {
44    tx: UnboundedSender<Request>,
45}
46// While a Client could implement Clone, it's not obvious how multiple Clients
47// using the same SQL Server connection would interact, so ban it for now.
48static_assertions::assert_not_impl_all!(Client: Clone);
49
50impl Client {
51    /// Connect to the specified SQL Server instance, returning a [`Client`]
52    /// that can be used to query it and a [`Connection`] that must be polled
53    /// to send and receive results.
54    ///
55    /// TODO(sql_server2): Maybe return a `ClientBuilder` here that implements
56    /// IntoFuture and does the default good thing of moving the `Connection`
57    /// into a tokio task? And a `.raw()` option that will instead return both
58    /// the Client and Connection for manual polling.
59    pub async fn connect(config: Config) -> Result<Self, SqlServerError> {
60        // Setup our tunnelling and return any resources that need to be kept
61        // alive for the duration of the connection.
62        let (tcp, resources): (_, Option<Box<dyn Any + Send + Sync>>) = match &config.tunnel {
63            TunnelConfig::Direct => {
64                let tcp = TcpStream::connect(config.inner.get_addr())
65                    .await
66                    .context("direct")?;
67                (tcp, None)
68            }
69            TunnelConfig::Ssh {
70                config: ssh_config,
71                manager,
72                timeout,
73                host,
74                port,
75            } => {
76                // N.B. If this tunnel is dropped it will close so we need to
77                // keep it alive for the duration of the connection.
78                let tunnel = manager
79                    .connect(ssh_config.clone(), host, *port, *timeout, config.in_task)
80                    .await?;
81                let tcp = TcpStream::connect(tunnel.local_addr())
82                    .await
83                    .context("ssh tunnel")?;
84
85                (tcp, Some(Box::new(tunnel)))
86            }
87            TunnelConfig::AwsPrivatelink { connection_id: _ } => {
88                // TODO(sql_server2): Getting this right is tricky because
89                // there is some subtle logic with hostname validation.
90                return Err(SqlServerError::Generic(anyhow::anyhow!(
91                    "Support PrivateLink connections"
92                )));
93            }
94        };
95
96        tcp.set_nodelay(true)?;
97
98        let (client, connection) = Self::connect_raw(config, tcp, resources).await?;
99        mz_ore::task::spawn(|| "sql-server-client-connection", async move {
100            connection.await
101        });
102
103        Ok(client)
104    }
105
106    pub async fn connect_raw(
107        config: Config,
108        tcp: tokio::net::TcpStream,
109        resources: Option<Box<dyn Any + Send + Sync>>,
110    ) -> Result<(Self, Connection), SqlServerError> {
111        let client = tiberius::Client::connect(config.inner, tcp.compat_write()).await?;
112        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
113
114        // TODO(sql_server2): Add a lot more logging here like the Postgres and MySQL clients have.
115
116        Ok((
117            Client { tx },
118            Connection {
119                rx,
120                client,
121                _resources: resources,
122            },
123        ))
124    }
125
126    /// Executes SQL statements in SQL Server, returning the number of rows effected.
127    ///
128    /// Passthrough method for [`tiberius::Client::execute`].
129    ///
130    /// Note: The returned [`Future`] does not need to be awaited for the query
131    /// to be sent.
132    ///
133    /// [`Future`]: std::future::Future
134    pub async fn execute<'a>(
135        &mut self,
136        query: impl Into<Cow<'a, str>>,
137        params: &[&dyn ToSql],
138    ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
139        let (tx, rx) = tokio::sync::oneshot::channel();
140
141        let params = params
142            .iter()
143            .map(|p| OwnedColumnData::from(p.to_sql()))
144            .collect();
145        let kind = RequestKind::Execute {
146            query: query.into().to_string(),
147            params,
148        };
149        self.tx
150            .send(Request { tx, kind })
151            .context("sending request")?;
152
153        let response = rx.await.context("channel")??;
154        match response {
155            Response::Execute { rows_affected } => Ok(rows_affected),
156            other @ Response::Rows(_) | other @ Response::RowStream { .. } => {
157                Err(SqlServerError::ProgrammingError(format!(
158                    "expected Response::Execute, got {other:?}"
159                )))
160            }
161        }
162    }
163
164    /// Executes SQL statements in SQL Server, returning the resulting rows.
165    ///
166    /// Passthrough method for [`tiberius::Client::query`].
167    ///
168    /// Note: The returned [`Future`] does not need to be awaited for the query
169    /// to be sent.
170    ///
171    /// [`Future`]: std::future::Future
172    pub async fn query<'a>(
173        &mut self,
174        query: impl Into<Cow<'a, str>>,
175        params: &[&dyn tiberius::ToSql],
176    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
177        let (tx, rx) = tokio::sync::oneshot::channel();
178
179        let params = params
180            .iter()
181            .map(|p| OwnedColumnData::from(p.to_sql()))
182            .collect();
183        let kind = RequestKind::Query {
184            query: query.into().to_string(),
185            params,
186        };
187        self.tx
188            .send(Request { tx, kind })
189            .context("sending request")?;
190
191        let response = rx.await.context("channel")??;
192        match response {
193            Response::Rows(rows) => Ok(rows),
194            other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
195                SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
196            ),
197        }
198    }
199
200    /// Executes SQL statements in SQL Server, returning a [`Stream`] of
201    /// resulting rows.
202    ///
203    /// Passthrough method for [`tiberius::Client::query`].
204    pub fn query_streaming<'c, 'q, Q>(
205        &'c mut self,
206        query: Q,
207        params: &[&dyn tiberius::ToSql],
208    ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
209    where
210        Q: Into<Cow<'q, str>>,
211    {
212        let (tx, rx) = tokio::sync::oneshot::channel();
213        let params = params
214            .iter()
215            .map(|p| OwnedColumnData::from(p.to_sql()))
216            .collect();
217        let kind = RequestKind::QueryStreamed {
218            query: query.into().to_string(),
219            params,
220        };
221
222        // Make our initial request which will return a Stream of Rows.
223        let request_future = async move {
224            self.tx
225                .send(Request { tx, kind })
226                .context("sending request")?;
227
228            let response = rx.await.context("channel")??;
229            match response {
230                Response::RowStream { stream } => {
231                    Ok(tokio_stream::wrappers::ReceiverStream::new(stream))
232                }
233                other @ Response::Execute { .. } | other @ Response::Rows(_) => {
234                    Err(SqlServerError::ProgrammingError(format!(
235                        "expected Response::Rows, got {other:?}"
236                    )))
237                }
238            }
239        };
240
241        // "flatten" our initial request into the returned stream.
242        futures::stream::once(request_future).try_flatten()
243    }
244
245    /// Executes multiple queries, delimited with `;` and return multiple
246    /// result sets; one for each query.
247    ///
248    /// Passthrough method for [`tiberius::Client::simple_query`].
249    ///
250    /// Note: The returned [`Future`] does not need to be awaited for the query
251    /// to be sent.
252    ///
253    /// [`Future`]: std::future::Future
254    pub async fn simple_query<'a>(
255        &mut self,
256        query: impl Into<Cow<'a, str>>,
257    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
258        let (tx, rx) = tokio::sync::oneshot::channel();
259        let kind = RequestKind::SimpleQuery {
260            query: query.into().to_string(),
261        };
262        self.tx
263            .send(Request { tx, kind })
264            .context("sending request")?;
265
266        let response = rx.await.context("channel")??;
267        match response {
268            Response::Rows(rows) => Ok(rows),
269            other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
270                SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
271            ),
272        }
273    }
274
275    /// Starts a transaction which is automatically rolled back on drop.
276    ///
277    /// To commit or rollback the transaction, see [`Transaction::commit`] and
278    /// [`Transaction::rollback`] respectively.
279    pub async fn transaction(&mut self) -> Result<Transaction<'_>, SqlServerError> {
280        Transaction::new(self).await
281    }
282
283    /// Sets the transaction isolation level for the current session.
284    pub async fn set_transaction_isolation(
285        &mut self,
286        level: TransactionIsolationLevel,
287    ) -> Result<(), SqlServerError> {
288        let query = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_str());
289        self.simple_query(query).await?;
290        Ok(())
291    }
292
293    /// Returns the current transaction isolation level for the current session.
294    pub async fn get_transaction_isolation(
295        &mut self,
296    ) -> Result<TransactionIsolationLevel, SqlServerError> {
297        const QUERY: &str = "SELECT transaction_isolation_level FROM sys.dm_exec_sessions where session_id = @@SPID;";
298        let rows = self.simple_query(QUERY).await?;
299        match &rows[..] {
300            [row] => {
301                let val: i16 = row
302                    .try_get(0)
303                    .context("getting 0th column")?
304                    .ok_or_else(|| anyhow::anyhow!("no 0th column?"))?;
305                let level = TransactionIsolationLevel::try_from_sql_server(val)?;
306                Ok(level)
307            }
308            other => Err(SqlServerError::InvariantViolated(format!(
309                "expected one row, got {other:?}"
310            ))),
311        }
312    }
313
314    /// Return a [`CdcStream`] that can be used to track changes for the specified
315    /// `capture_instances`.
316    ///
317    /// [`CdcStream`]: crate::cdc::CdcStream
318    pub fn cdc<I>(&mut self, capture_instances: I) -> crate::cdc::CdcStream<'_>
319    where
320        I: IntoIterator,
321        I::Item: Into<Arc<str>>,
322    {
323        let instances = capture_instances
324            .into_iter()
325            .map(|i| (i.into(), None))
326            .collect();
327        crate::cdc::CdcStream::new(self, instances)
328    }
329}
330
331/// A stream of [`tiberius::Row`]s.
332pub type RowStream<'a> =
333    Pin<Box<dyn Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + 'a>>;
334
335#[derive(Debug)]
336pub struct Transaction<'a> {
337    client: &'a mut Client,
338    closed: bool,
339}
340
341impl<'a> Transaction<'a> {
342    async fn new(client: &'a mut Client) -> Result<Self, SqlServerError> {
343        let results = client
344            .simple_query("BEGIN TRANSACTION")
345            .await
346            .context("begin")?;
347        if !results.is_empty() {
348            Err(SqlServerError::InvariantViolated(format!(
349                "expected empty result from BEGIN TRANSACTION. Got: {results:?}"
350            )))
351        } else {
352            Ok(Transaction {
353                client,
354                closed: false,
355            })
356        }
357    }
358
359    /// See [`Client::execute`].
360    pub async fn execute<'q>(
361        &mut self,
362        query: impl Into<Cow<'q, str>>,
363        params: &[&dyn ToSql],
364    ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
365        self.client.execute(query, params).await
366    }
367
368    /// See [`Client::query`].
369    pub async fn query<'q>(
370        &mut self,
371        query: impl Into<Cow<'q, str>>,
372        params: &[&dyn tiberius::ToSql],
373    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
374        self.client.query(query, params).await
375    }
376
377    /// See [`Client::query_streaming`]
378    pub fn query_streaming<'c, 'q, Q>(
379        &'c mut self,
380        query: Q,
381        params: &[&dyn tiberius::ToSql],
382    ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
383    where
384        Q: Into<Cow<'q, str>>,
385    {
386        self.client.query_streaming(query, params)
387    }
388
389    /// See [`Client::simple_query`].
390    pub async fn simple_query<'q>(
391        &mut self,
392        query: impl Into<Cow<'q, str>>,
393    ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
394        self.client.simple_query(query).await
395    }
396
397    /// Rollback the [`Transaction`].
398    pub async fn rollback(mut self) -> Result<(), SqlServerError> {
399        static ROLLBACK_QUERY: &str = "ROLLBACK TRANSACTION";
400        // N.B. Mark closed _before_ running the query. This prevents us from
401        // double closing the transaction if this query itself fails.
402        self.closed = true;
403        self.client.simple_query(ROLLBACK_QUERY).await?;
404        Ok(())
405    }
406
407    /// Commit the [`Transaction`].
408    pub async fn commit(mut self) -> Result<(), SqlServerError> {
409        static COMMIT_QUERY: &str = "COMMIT TRANSACTION";
410        // N.B. Mark closed _before_ running the query. This prevents us from
411        // double closing the transaction if this query itself fails.
412        self.closed = true;
413        self.client.simple_query(COMMIT_QUERY).await?;
414        Ok(())
415    }
416}
417
418impl Drop for Transaction<'_> {
419    fn drop(&mut self) {
420        // Internally the query is synchronously sent down a channel, and the response is what
421        // we await. In other words, we don't need to `.await` here for the query to be run.
422        if !self.closed {
423            let _fut = self.client.simple_query("ROLLBACK TRANSACTION");
424        }
425    }
426}
427
/// Transaction isolation levels defined by Microsoft's SQL Server.
///
/// This is a plain fieldless enum, so it is `Copy` and can be used as a map or
/// set key.
///
/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql>
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionIsolationLevel {
    /// `READ UNCOMMITTED`
    ReadUncommitted,
    /// `READ COMMITTED`
    ReadCommitted,
    /// `REPEATABLE READ`
    RepeatableRead,
    /// `SNAPSHOT`
    Snapshot,
    /// `SERIALIZABLE`
    Serializable,
}
439
440impl TransactionIsolationLevel {
441    /// Return the string representation of a transaction isolation level.
442    fn as_str(&self) -> &'static str {
443        match self {
444            TransactionIsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
445            TransactionIsolationLevel::ReadCommitted => "READ COMMITTED",
446            TransactionIsolationLevel::RepeatableRead => "REPEATABLE READ",
447            TransactionIsolationLevel::Snapshot => "SNAPSHOT",
448            TransactionIsolationLevel::Serializable => "SERIALIZABLE",
449        }
450    }
451
452    /// Try to parse a [`TransactionIsolationLevel`] from the value returned from SQL Server.
453    fn try_from_sql_server(val: i16) -> Result<TransactionIsolationLevel, anyhow::Error> {
454        let level = match val {
455            1 => TransactionIsolationLevel::ReadUncommitted,
456            2 => TransactionIsolationLevel::ReadCommitted,
457            3 => TransactionIsolationLevel::RepeatableRead,
458            4 => TransactionIsolationLevel::Serializable,
459            5 => TransactionIsolationLevel::Snapshot,
460            x => anyhow::bail!("unknown level {x}"),
461        };
462        Ok(level)
463    }
464}
465
466#[derive(Derivative)]
467#[derivative(Debug)]
468enum Response {
469    Execute {
470        rows_affected: SmallVec<[u64; 1]>,
471    },
472    Rows(SmallVec<[tiberius::Row; 1]>),
473    RowStream {
474        #[derivative(Debug = "ignore")]
475        stream: tokio::sync::mpsc::Receiver<Result<tiberius::Row, SqlServerError>>,
476    },
477}
478
479#[derive(Debug)]
480struct Request {
481    tx: oneshot::Sender<Result<Response, SqlServerError>>,
482    kind: RequestKind,
483}
484
485#[derive(Derivative)]
486#[derivative(Debug)]
487enum RequestKind {
488    Execute {
489        query: String,
490        #[derivative(Debug = "ignore")]
491        params: SmallVec<[OwnedColumnData; 4]>,
492    },
493    Query {
494        query: String,
495        #[derivative(Debug = "ignore")]
496        params: SmallVec<[OwnedColumnData; 4]>,
497    },
498    QueryStreamed {
499        query: String,
500        #[derivative(Debug = "ignore")]
501        params: SmallVec<[OwnedColumnData; 4]>,
502    },
503    SimpleQuery {
504        query: String,
505    },
506}
507
508pub struct Connection {
509    /// Other end of the channel that [`Client`] holds.
510    rx: UnboundedReceiver<Request>,
511    /// Actual client that we use to send requests.
512    client: tiberius::Client<Compat<TcpStream>>,
513    /// Resources (e.g. SSH tunnel) that need to be held open for the life of this connection.
514    _resources: Option<Box<dyn Any + Send + Sync>>,
515}
516
517impl Connection {
518    async fn run(mut self) {
519        while let Some(Request { tx, kind }) = self.rx.recv().await {
520            tracing::trace!(?kind, "processing SQL Server query");
521            let result = Connection::handle_request(&mut self.client, kind).await;
522            let (response, maybe_extra_work) = match result {
523                Ok((response, work)) => (Ok(response), work),
524                Err(err) => (Err(err), None),
525            };
526
527            // We don't care if our listener for this query has gone away.
528            let _ = tx.send(response);
529
530            // After we handle a request there might still be something in-flight
531            // that we need to continue driving, e.g. when the response is a
532            // Stream of Rows.
533            if let Some(extra_work) = maybe_extra_work {
534                extra_work.await;
535            }
536        }
537        tracing::debug!("channel closed, SQL Server InnerClient shutting down");
538    }
539
540    async fn handle_request<'c>(
541        client: &'c mut tiberius::Client<Compat<TcpStream>>,
542        kind: RequestKind,
543    ) -> Result<(Response, Option<BoxFuture<'c, ()>>), SqlServerError> {
544        match kind {
545            RequestKind::Execute { query, params } => {
546                #[allow(clippy::as_conversions)]
547                let params: SmallVec<[&dyn ToSql; 4]> =
548                    params.iter().map(|x| x as &dyn ToSql).collect();
549                let result = client.execute(query, &params[..]).await?;
550
551                match result.rows_affected() {
552                    [] => Err(SqlServerError::InvariantViolated(
553                        "got empty response".into(),
554                    )),
555                    rows_affected => {
556                        let response = Response::Execute {
557                            rows_affected: rows_affected.into(),
558                        };
559                        Ok((response, None))
560                    }
561                }
562            }
563            RequestKind::Query { query, params } => {
564                #[allow(clippy::as_conversions)]
565                let params: SmallVec<[&dyn ToSql; 4]> =
566                    params.iter().map(|x| x as &dyn ToSql).collect();
567                let result = client.query(query, params.as_slice()).await?;
568
569                let mut results = result.into_results().await.context("into results")?;
570                if results.is_empty() {
571                    Ok((Response::Rows(smallvec![]), None))
572                } else if results.len() == 1 {
573                    // TODO(sql_server3): Don't use `into_results()` above, instead directly
574                    // push onto a SmallVec to avoid the heap allocations.
575                    let rows = results.pop().expect("checked len").into();
576                    Ok((Response::Rows(rows), None))
577                } else {
578                    Err(SqlServerError::ProgrammingError(format!(
579                        "Query only supports 1 statement, got {}",
580                        results.len()
581                    )))
582                }
583            }
584            RequestKind::QueryStreamed { query, params } => {
585                #[allow(clippy::as_conversions)]
586                let params: SmallVec<[&dyn ToSql; 4]> =
587                    params.iter().map(|x| x as &dyn ToSql).collect();
588                let result = client.query(query, params.as_slice()).await?;
589
590                // ~~ Rust Lifetimes ~~
591                //
592                // What's going on here, why do we have some extra channel and
593                // this 'work' future?
594                //
595                // Remember, we run the actual `tiberius::Client` in a separate
596                // `tokio::task` and the `mz::Client` sends query requests via
597                // a channel, this allows us to "automatically" manage
598                // transactions.
599                //
600                // But the returned `QueryStream` from a `tiberius::Client` has
601                // a lifetime associated with said client running in this
602                // separate task. Thus we cannot send the `QueryStream` back to
603                // the `mz::Client` because the lifetime of these two clients
604                // is not linked at all. The fix is to create a separate owned
605                // channel and return the receiving end, while this work future
606                // pulls events off the `QueryStream` and sends them over the
607                // channel we just returned.
608                let (tx, rx) = tokio::sync::mpsc::channel(256);
609                let work = Box::pin(async move {
610                    let mut stream = result.into_row_stream();
611                    while let Some(result) = stream.next().await {
612                        if let Err(err) = tx.send(result.err_into()).await {
613                            tracing::warn!(?err, "SQL Server row stream receiver went away");
614                        }
615                    }
616                    tracing::info!("SQL Server row stream complete");
617                });
618
619                Ok((Response::RowStream { stream: rx }, Some(work)))
620            }
621            RequestKind::SimpleQuery { query } => {
622                let result = client.simple_query(query).await?;
623
624                let mut results = result.into_results().await.context("into results")?;
625                if results.is_empty() {
626                    Ok((Response::Rows(smallvec![]), None))
627                } else if results.len() == 1 {
628                    // TODO(sql_server3): Don't use `into_results()` above, instead directly
629                    // push onto a SmallVec to avoid the heap allocations.
630                    let rows = results.pop().expect("checked len").into();
631                    Ok((Response::Rows(rows), None))
632                } else {
633                    Err(SqlServerError::ProgrammingError(format!(
634                        "Simple query only supports 1 statement, got {}",
635                        results.len()
636                    )))
637                }
638            }
639        }
640    }
641}
642
643impl IntoFuture for Connection {
644    type Output = ();
645    type IntoFuture = BoxFuture<'static, Self::Output>;
646
647    fn into_future(self) -> Self::IntoFuture {
648        self.run().boxed()
649    }
650}
651
652/// Owned version of [`tiberius::ColumnData`] that can be more easily sent
653/// across threads or through a channel.
654#[derive(Debug)]
655enum OwnedColumnData {
656    U8(Option<u8>),
657    I16(Option<i16>),
658    I32(Option<i32>),
659    I64(Option<i64>),
660    F32(Option<f32>),
661    F64(Option<f64>),
662    Bit(Option<bool>),
663    String(Option<String>),
664    Guid(Option<uuid::Uuid>),
665    Binary(Option<Vec<u8>>),
666    Numeric(Option<tiberius::numeric::Numeric>),
667    Xml(Option<tiberius::xml::XmlData>),
668    DateTime(Option<tiberius::time::DateTime>),
669    SmallDateTime(Option<tiberius::time::SmallDateTime>),
670    Time(Option<tiberius::time::Time>),
671    Date(Option<tiberius::time::Date>),
672    DateTime2(Option<tiberius::time::DateTime2>),
673    DateTimeOffset(Option<tiberius::time::DateTimeOffset>),
674}
675
676impl<'a> From<tiberius::ColumnData<'a>> for OwnedColumnData {
677    fn from(value: tiberius::ColumnData<'a>) -> Self {
678        match value {
679            tiberius::ColumnData::U8(inner) => OwnedColumnData::U8(inner),
680            tiberius::ColumnData::I16(inner) => OwnedColumnData::I16(inner),
681            tiberius::ColumnData::I32(inner) => OwnedColumnData::I32(inner),
682            tiberius::ColumnData::I64(inner) => OwnedColumnData::I64(inner),
683            tiberius::ColumnData::F32(inner) => OwnedColumnData::F32(inner),
684            tiberius::ColumnData::F64(inner) => OwnedColumnData::F64(inner),
685            tiberius::ColumnData::Bit(inner) => OwnedColumnData::Bit(inner),
686            tiberius::ColumnData::String(inner) => {
687                OwnedColumnData::String(inner.map(|s| s.to_string()))
688            }
689            tiberius::ColumnData::Guid(inner) => OwnedColumnData::Guid(inner),
690            tiberius::ColumnData::Binary(inner) => {
691                OwnedColumnData::Binary(inner.map(|b| b.to_vec()))
692            }
693            tiberius::ColumnData::Numeric(inner) => OwnedColumnData::Numeric(inner),
694            tiberius::ColumnData::Xml(inner) => OwnedColumnData::Xml(inner.map(|x| x.into_owned())),
695            tiberius::ColumnData::DateTime(inner) => OwnedColumnData::DateTime(inner),
696            tiberius::ColumnData::SmallDateTime(inner) => OwnedColumnData::SmallDateTime(inner),
697            tiberius::ColumnData::Time(inner) => OwnedColumnData::Time(inner),
698            tiberius::ColumnData::Date(inner) => OwnedColumnData::Date(inner),
699            tiberius::ColumnData::DateTime2(inner) => OwnedColumnData::DateTime2(inner),
700            tiberius::ColumnData::DateTimeOffset(inner) => OwnedColumnData::DateTimeOffset(inner),
701        }
702    }
703}
704
705impl tiberius::ToSql for OwnedColumnData {
706    fn to_sql(&self) -> tiberius::ColumnData<'_> {
707        match self {
708            OwnedColumnData::U8(inner) => tiberius::ColumnData::U8(*inner),
709            OwnedColumnData::I16(inner) => tiberius::ColumnData::I16(*inner),
710            OwnedColumnData::I32(inner) => tiberius::ColumnData::I32(*inner),
711            OwnedColumnData::I64(inner) => tiberius::ColumnData::I64(*inner),
712            OwnedColumnData::F32(inner) => tiberius::ColumnData::F32(*inner),
713            OwnedColumnData::F64(inner) => tiberius::ColumnData::F64(*inner),
714            OwnedColumnData::Bit(inner) => tiberius::ColumnData::Bit(*inner),
715            OwnedColumnData::String(inner) => {
716                tiberius::ColumnData::String(inner.as_deref().map(Cow::Borrowed))
717            }
718            OwnedColumnData::Guid(inner) => tiberius::ColumnData::Guid(*inner),
719            OwnedColumnData::Binary(inner) => {
720                tiberius::ColumnData::Binary(inner.as_deref().map(Cow::Borrowed))
721            }
722            OwnedColumnData::Numeric(inner) => tiberius::ColumnData::Numeric(*inner),
723            OwnedColumnData::Xml(inner) => {
724                tiberius::ColumnData::Xml(inner.as_ref().map(Cow::Borrowed))
725            }
726            OwnedColumnData::DateTime(inner) => tiberius::ColumnData::DateTime(*inner),
727            OwnedColumnData::SmallDateTime(inner) => tiberius::ColumnData::SmallDateTime(*inner),
728            OwnedColumnData::Time(inner) => tiberius::ColumnData::Time(*inner),
729            OwnedColumnData::Date(inner) => tiberius::ColumnData::Date(*inner),
730            OwnedColumnData::DateTime2(inner) => tiberius::ColumnData::DateTime2(*inner),
731            OwnedColumnData::DateTimeOffset(inner) => tiberius::ColumnData::DateTimeOffset(*inner),
732        }
733    }
734}
735
736impl<'a, T: tiberius::ToSql> From<&'a T> for OwnedColumnData {
737    fn from(value: &'a T) -> Self {
738        OwnedColumnData::from(value.to_sql())
739    }
740}
741
742#[derive(Debug, thiserror::Error)]
743pub enum SqlServerError {
744    #[error(transparent)]
745    SqlServer(#[from] tiberius::error::Error),
746    #[error(transparent)]
747    CdcError(#[from] crate::cdc::CdcError),
748    #[error("expected column '{0}' to be present")]
749    MissingColumn(&'static str),
750    #[error("sql server client encountered I/O error: {0}")]
751    IO(#[from] tokio::io::Error),
752    #[error("found invalid data in the column '{column_name}': {error}")]
753    InvalidData { column_name: String, error: String },
754    #[error("got back a null value when querying for the LSN")]
755    NullLsn,
756    #[error("invalid SQL Server system setting '{name}'. Expected '{expected}'. Got '{actual}'.")]
757    InvalidSystemSetting {
758        name: String,
759        expected: String,
760        actual: String,
761    },
762    #[error("invariant was violated: {0}")]
763    InvariantViolated(String),
764    #[error(transparent)]
765    Generic(#[from] anyhow::Error),
766    #[error("programming error! {0}")]
767    ProgrammingError(String),
768}
769
770/// Errors returned from decoding SQL Server rows.
771///
772/// **PLEASE READ**
773///
774/// The string representation of this error type is **durably stored** in a source and thus this
775/// error type needs to be **stable** across releases. For example, if in v11 of Materialize we
776/// fail to decode `Row(["foo bar"])` from SQL Server, we will record the error in the source's
777/// Persist shard. If in v12 of Materialize the user deletes the `Row(["foo bar"])` from their
778/// upstream instance, we need to perfectly retract the error we previously committed.
779///
780/// This means be **very** careful when changing this type.
781#[derive(Debug, thiserror::Error)]
782pub enum SqlServerDecodeError {
783    #[error("column '{column_name}' was invalid when getting as type '{as_type}'")]
784    InvalidColumn {
785        column_name: String,
786        as_type: &'static str,
787    },
788    #[error("found invalid data in the column '{column_name}': {error}")]
789    InvalidData { column_name: String, error: String },
790    #[error("can't decode {sql_server_type:?} as {mz_type:?}")]
791    Unsupported {
792        sql_server_type: SqlServerColumnDecodeType,
793        mz_type: ScalarType,
794    },
795}
796
797impl SqlServerDecodeError {
798    fn invalid_timestamp(name: &str, error: mz_repr::adt::timestamp::TimestampError) -> Self {
799        // These error messages need to remain stable, do not change them.
800        let error = match error {
801            mz_repr::adt::timestamp::TimestampError::OutOfRange => "out of range",
802        };
803        SqlServerDecodeError::InvalidData {
804            column_name: name.to_string(),
805            error: error.to_string(),
806        }
807    }
808
809    fn invalid_date(name: &str, error: mz_repr::adt::date::DateError) -> Self {
810        // These error messages need to remain stable, do not change them.
811        let error = match error {
812            mz_repr::adt::date::DateError::OutOfRange => "out of range",
813        };
814        SqlServerDecodeError::InvalidData {
815            column_name: name.to_string(),
816            error: error.to_string(),
817        }
818    }
819
820    fn invalid_char(name: &str, expected_chars: usize, found_chars: usize) -> Self {
821        SqlServerDecodeError::InvalidData {
822            column_name: name.to_string(),
823            error: format!("expected {expected_chars} chars found {found_chars}"),
824        }
825    }
826
827    fn invalid_varchar(name: &str, max_chars: usize, found_chars: usize) -> Self {
828        SqlServerDecodeError::InvalidData {
829            column_name: name.to_string(),
830            error: format!("expected max {max_chars} chars found {found_chars}"),
831        }
832    }
833
834    fn invalid_column(name: &str, as_type: &'static str) -> Self {
835        SqlServerDecodeError::InvalidColumn {
836            column_name: name.to_string(),
837            as_type,
838        }
839    }
840}