1use std::any::Any;
11use std::borrow::Cow;
12use std::future::IntoFuture;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use anyhow::Context;
17use derivative::Derivative;
18use futures::future::BoxFuture;
19use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
20use mz_ore::result::ResultExt;
21use mz_repr::SqlScalarType;
22use smallvec::{SmallVec, smallvec};
23use tiberius::ToSql;
24use tokio::net::TcpStream;
25use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
26use tokio::sync::oneshot;
27use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
28
29pub mod cdc;
30pub mod config;
31pub mod desc;
32pub mod inspect;
33
34pub use config::Config;
35pub use desc::{ProtoSqlServerColumnDesc, ProtoSqlServerTableDesc};
36
37use crate::cdc::Lsn;
38use crate::config::TunnelConfig;
39use crate::desc::SqlServerColumnDecodeType;
40
/// Handle for issuing queries against SQL Server.
///
/// Requests are sent over an unbounded channel to a [`Connection`], which owns the
/// underlying `tiberius` client and services the requests one at a time.
#[derive(Debug)]
pub struct Client {
    /// Sends requests to the paired [`Connection`].
    tx: UnboundedSender<Request>,
    /// Configuration this client was created with; reused by `Client::new_connection`.
    config: Config,
}
// NOTE(review): presumably `Client` must not be `Clone` so that `&mut Client` (e.g. the one
// held by `Transaction`) implies exclusive use of the underlying connection — confirm.
static_assertions::assert_not_impl_all!(Client: Clone);
52
53impl Client {
54 pub async fn connect(config: Config) -> Result<Self, SqlServerError> {
63 let (tcp, resources): (_, Option<Box<dyn Any + Send + Sync>>) = match &config.tunnel {
66 TunnelConfig::Direct { resolved_addresses } => {
67 let tcp = if resolved_addresses.is_empty() {
68 TcpStream::connect(config.inner.get_addr()).await
69 } else {
70 TcpStream::connect(resolved_addresses.as_ref()).await
71 }
72 .context("direct")?;
73 (tcp, None)
74 }
75 TunnelConfig::Ssh {
76 config: ssh_config,
77 manager,
78 timeout,
79 host,
80 port,
81 } => {
82 let tunnel = manager
85 .connect(ssh_config.clone(), host, *port, *timeout, config.in_task)
86 .await?;
87 let tcp = TcpStream::connect(tunnel.local_addr())
88 .await
89 .context("ssh tunnel")?;
90
91 (tcp, Some(Box::new(tunnel)))
92 }
93 TunnelConfig::AwsPrivatelink {
94 connection_id,
95 port,
96 } => {
97 let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
98 let tcp = TcpStream::connect((privatelink_host.as_str(), *port))
99 .await
100 .context(format!("aws privatelink {:?}", privatelink_host))?;
101
102 (tcp, None)
103 }
104 };
105
106 tcp.set_nodelay(true)?;
107
108 let (client, connection) = Self::connect_raw(config, tcp, resources).await?;
109 mz_ore::task::spawn(|| "sql-server-client-connection", async move {
110 connection.await
111 });
112
113 Ok(client)
114 }
115
116 pub async fn new_connection(&self) -> Result<Self, SqlServerError> {
119 Self::connect(self.config.clone()).await
120 }
121
122 pub async fn connect_raw(
123 config: Config,
124 tcp: tokio::net::TcpStream,
125 resources: Option<Box<dyn Any + Send + Sync>>,
126 ) -> Result<(Self, Connection), SqlServerError> {
127 let client = tiberius::Client::connect(config.inner.clone(), tcp.compat_write()).await?;
128 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
129
130 Ok((
133 Client { tx, config },
134 Connection {
135 rx,
136 client,
137 _resources: resources,
138 },
139 ))
140 }
141
142 pub async fn execute<'a>(
151 &mut self,
152 query: impl Into<Cow<'a, str>>,
153 params: &[&dyn ToSql],
154 ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
155 let (tx, rx) = tokio::sync::oneshot::channel();
156
157 let params = params
158 .iter()
159 .map(|p| OwnedColumnData::from(p.to_sql()))
160 .collect();
161 let kind = RequestKind::Execute {
162 query: query.into().to_string(),
163 params,
164 };
165 self.tx
166 .send(Request { tx, kind })
167 .context("sending request")?;
168
169 let response = rx.await.context("channel")??;
170 match response {
171 Response::Execute { rows_affected } => Ok(rows_affected),
172 other @ Response::Rows(_) | other @ Response::RowStream { .. } => {
173 Err(SqlServerError::ProgrammingError(format!(
174 "expected Response::Execute, got {other:?}"
175 )))
176 }
177 }
178 }
179
180 pub async fn query<'a>(
189 &mut self,
190 query: impl Into<Cow<'a, str>>,
191 params: &[&dyn tiberius::ToSql],
192 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
193 let (tx, rx) = tokio::sync::oneshot::channel();
194
195 let params = params
196 .iter()
197 .map(|p| OwnedColumnData::from(p.to_sql()))
198 .collect();
199 let kind = RequestKind::Query {
200 query: query.into().to_string(),
201 params,
202 };
203 self.tx
204 .send(Request { tx, kind })
205 .context("sending request")?;
206
207 let response = rx.await.context("channel")??;
208 match response {
209 Response::Rows(rows) => Ok(rows),
210 other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
211 SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
212 ),
213 }
214 }
215
216 pub fn query_streaming<'c, 'q, Q>(
221 &'c mut self,
222 query: Q,
223 params: &[&dyn tiberius::ToSql],
224 ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
225 where
226 Q: Into<Cow<'q, str>>,
227 {
228 let (tx, rx) = tokio::sync::oneshot::channel();
229 let params = params
230 .iter()
231 .map(|p| OwnedColumnData::from(p.to_sql()))
232 .collect();
233 let kind = RequestKind::QueryStreamed {
234 query: query.into().to_string(),
235 params,
236 };
237
238 let request_future = async move {
240 self.tx
241 .send(Request { tx, kind })
242 .context("sending request")?;
243
244 let response = rx.await.context("channel")??;
245 match response {
246 Response::RowStream { stream } => {
247 Ok(tokio_stream::wrappers::ReceiverStream::new(stream))
248 }
249 other @ Response::Execute { .. } | other @ Response::Rows(_) => {
250 Err(SqlServerError::ProgrammingError(format!(
251 "expected Response::Rows, got {other:?}"
252 )))
253 }
254 }
255 };
256
257 futures::stream::once(request_future).try_flatten()
259 }
260
261 pub async fn simple_query<'a>(
271 &mut self,
272 query: impl Into<Cow<'a, str>>,
273 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
274 let (tx, rx) = tokio::sync::oneshot::channel();
275 let kind = RequestKind::SimpleQuery {
276 query: query.into().to_string(),
277 };
278 self.tx
279 .send(Request { tx, kind })
280 .context("sending request")?;
281
282 let response = rx.await.context("channel")??;
283 match response {
284 Response::Rows(rows) => Ok(rows),
285 other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
286 SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
287 ),
288 }
289 }
290
291 pub async fn transaction(&mut self) -> Result<Transaction<'_>, SqlServerError> {
296 Transaction::new(self).await
297 }
298
299 pub async fn set_transaction_isolation(
301 &mut self,
302 level: TransactionIsolationLevel,
303 ) -> Result<(), SqlServerError> {
304 let query = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_str());
305 self.simple_query(query).await?;
306 Ok(())
307 }
308
309 pub async fn get_transaction_isolation(
311 &mut self,
312 ) -> Result<TransactionIsolationLevel, SqlServerError> {
313 const QUERY: &str = "SELECT transaction_isolation_level FROM sys.dm_exec_sessions where session_id = @@SPID;";
314 let rows = self.simple_query(QUERY).await?;
315 match &rows[..] {
316 [row] => {
317 let val: i16 = row
318 .try_get(0)
319 .context("getting 0th column")?
320 .ok_or_else(|| anyhow::anyhow!("no 0th column?"))?;
321 let level = TransactionIsolationLevel::try_from_sql_server(val)?;
322 Ok(level)
323 }
324 other => Err(SqlServerError::InvariantViolated(format!(
325 "expected one row, got {other:?}"
326 ))),
327 }
328 }
329
330 pub fn cdc<I, M>(&mut self, capture_instances: I, metrics: M) -> crate::cdc::CdcStream<'_, M>
335 where
336 I: IntoIterator,
337 I::Item: Into<Arc<str>>,
338 M: SqlServerCdcMetrics,
339 {
340 let instances = capture_instances
341 .into_iter()
342 .map(|i| (i.into(), None))
343 .collect();
344 crate::cdc::CdcStream::new(self, instances, metrics)
345 }
346}
347
/// A pinned, boxed, `Send` stream of rows (or errors) that may borrow for `'a`.
pub type RowStream<'a> =
    Pin<Box<dyn Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + 'a>>;
351
/// An open SQL Server transaction on a mutably borrowed [`Client`].
///
/// Created by `Client::transaction`. If it is dropped without `commit` or `rollback`
/// having been called, its `Drop` impl enqueues a best-effort `ROLLBACK TRANSACTION`.
#[derive(Debug)]
pub struct Transaction<'a> {
    /// Client the transaction's statements are sent through.
    client: &'a mut Client,
    /// Set by `commit`/`rollback` so that `Drop` does not issue another rollback.
    closed: bool,
}
357
358impl<'a> Transaction<'a> {
359 async fn new(client: &'a mut Client) -> Result<Self, SqlServerError> {
360 let tx = Transaction {
363 client,
364 closed: false,
365 };
366 let results = tx
367 .client
368 .simple_query("BEGIN TRANSACTION")
369 .await
370 .context("begin")?;
371 if !results.is_empty() {
372 Err(SqlServerError::InvariantViolated(format!(
373 "expected empty result from BEGIN TRANSACTION. Got: {results:?}"
374 )))
375 } else {
376 Ok(tx)
377 }
378 }
379
380 pub async fn create_savepoint(&mut self, savepoint_name: &str) -> Result<(), SqlServerError> {
390 if savepoint_name.is_empty()
393 || !savepoint_name
394 .chars()
395 .all(|c| c.is_alphanumeric() || c == '_')
396 {
397 Err(SqlServerError::ProgrammingError(format!(
398 "Invalid savepoint name: '{savepoint_name}"
399 )))?;
400 }
401
402 let stmt = format!("SAVE TRANSACTION {}", quote_identifier(savepoint_name));
403 let _result = self.client.simple_query(stmt).await?;
404 Ok(())
405 }
406
407 pub async fn get_lsn(&mut self) -> Result<Lsn, SqlServerError> {
411 static CURRENT_LSN_QUERY: &str = "SELECT dt.database_transaction_most_recent_savepoint_lsn \
412 FROM sys.dm_tran_database_transactions dt \
413 JOIN sys.dm_tran_current_transaction ct \
414 ON ct.transaction_id = dt.transaction_id \
415 WHERE dt.database_transaction_most_recent_savepoint_lsn IS NOT NULL";
416 let result = self.client.simple_query(CURRENT_LSN_QUERY).await?;
417 crate::inspect::parse_numeric_lsn(&result)
418 }
419
420 pub async fn lock_table_shared(
426 &mut self,
427 schema: &str,
428 table: &str,
429 ) -> Result<(), SqlServerError> {
430 static SET_READ_COMMITTED: &str = "SET TRANSACTION ISOLATION LEVEL READ COMMITTED;";
433 let query = format!(
437 "{SET_READ_COMMITTED}\nSELECT * FROM {schema}.{table} WITH (TABLOCK, HOLDLOCK) WHERE 1=0;",
438 schema = quote_identifier(schema),
439 table = quote_identifier(table)
440 );
441 let _result = self.client.simple_query(query).await?;
442 Ok(())
443 }
444
445 pub async fn execute<'q>(
447 &mut self,
448 query: impl Into<Cow<'q, str>>,
449 params: &[&dyn ToSql],
450 ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
451 self.client.execute(query, params).await
452 }
453
454 pub async fn query<'q>(
456 &mut self,
457 query: impl Into<Cow<'q, str>>,
458 params: &[&dyn tiberius::ToSql],
459 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
460 self.client.query(query, params).await
461 }
462
463 pub fn query_streaming<'c, 'q, Q>(
465 &'c mut self,
466 query: Q,
467 params: &[&dyn tiberius::ToSql],
468 ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
469 where
470 Q: Into<Cow<'q, str>>,
471 {
472 self.client.query_streaming(query, params)
473 }
474
475 pub async fn simple_query<'q>(
477 &mut self,
478 query: impl Into<Cow<'q, str>>,
479 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
480 self.client.simple_query(query).await
481 }
482
483 pub async fn rollback(mut self) -> Result<(), SqlServerError> {
485 static ROLLBACK_QUERY: &str = "ROLLBACK TRANSACTION";
486 self.closed = true;
489 self.client.simple_query(ROLLBACK_QUERY).await?;
490 Ok(())
491 }
492
493 pub async fn commit(mut self) -> Result<(), SqlServerError> {
495 static COMMIT_QUERY: &str = "COMMIT TRANSACTION";
496 self.closed = true;
499 self.client.simple_query(COMMIT_QUERY).await?;
500 Ok(())
501 }
502}
503
504impl Drop for Transaction<'_> {
505 fn drop(&mut self) {
506 if !self.closed {
507 let (tx, _rx) = oneshot::channel();
516 let kind = RequestKind::SimpleQuery {
517 query: "ROLLBACK TRANSACTION".to_string(),
518 };
519 let _ = self.client.tx.send(Request { tx, kind });
520 }
521 }
522}
523
/// SQL Server transaction isolation levels.
///
/// Converted to SQL by `TransactionIsolationLevel::as_str` and parsed from
/// `sys.dm_exec_sessions.transaction_isolation_level` by `try_from_sql_server`.
#[derive(Debug, PartialEq, Eq)]
pub enum TransactionIsolationLevel {
    ReadUncommitted,
    ReadCommitted,
    RepeatableRead,
    Snapshot,
    Serializable,
}
535
536impl TransactionIsolationLevel {
537 fn as_str(&self) -> &'static str {
539 match self {
540 TransactionIsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
541 TransactionIsolationLevel::ReadCommitted => "READ COMMITTED",
542 TransactionIsolationLevel::RepeatableRead => "REPEATABLE READ",
543 TransactionIsolationLevel::Snapshot => "SNAPSHOT",
544 TransactionIsolationLevel::Serializable => "SERIALIZABLE",
545 }
546 }
547
548 fn try_from_sql_server(val: i16) -> Result<TransactionIsolationLevel, anyhow::Error> {
550 let level = match val {
551 1 => TransactionIsolationLevel::ReadUncommitted,
552 2 => TransactionIsolationLevel::ReadCommitted,
553 3 => TransactionIsolationLevel::RepeatableRead,
554 4 => TransactionIsolationLevel::Serializable,
555 5 => TransactionIsolationLevel::Snapshot,
556 x => anyhow::bail!("unknown level {x}"),
557 };
558 Ok(level)
559 }
560}
561
/// Reply sent by the [`Connection`] task for a [`Request`].
#[derive(Derivative)]
#[derivative(Debug)]
enum Response {
    /// Result of `RequestKind::Execute`: rows-affected counts as reported by tiberius'
    /// `ExecuteResult::rows_affected`.
    Execute {
        rows_affected: SmallVec<[u64; 1]>,
    },
    /// The single (possibly empty) result set of a `Query` or `SimpleQuery` request.
    Rows(SmallVec<[tiberius::Row; 1]>),
    /// Receiver that the connection task feeds with rows for a `QueryStreamed` request.
    RowStream {
        /// Omitted from `Debug` output.
        #[derivative(Debug = "ignore")]
        stream: tokio::sync::mpsc::Receiver<Result<tiberius::Row, SqlServerError>>,
    },
}
574
/// A unit of work sent from a [`Client`] to its [`Connection`].
#[derive(Debug)]
struct Request {
    /// Where the connection task sends the outcome of `kind`.
    tx: oneshot::Sender<Result<Response, SqlServerError>>,
    /// What to execute.
    kind: RequestKind,
}
580
/// The operations a [`Connection`] can perform.
///
/// `params` are omitted from `Debug` output. `Connection::run` logs each `RequestKind` at
/// trace level; NOTE(review): presumably this keeps parameter values out of the logs.
#[derive(Derivative)]
#[derivative(Debug)]
enum RequestKind {
    /// Run via tiberius `Client::execute`; answered with `Response::Execute`.
    Execute {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run via tiberius `Client::query`; answered with `Response::Rows`.
    Query {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run via tiberius `Client::query`; answered with `Response::RowStream`.
    QueryStreamed {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run via tiberius `Client::simple_query` (no parameters); answered with
    /// `Response::Rows`.
    SimpleQuery {
        query: String,
    },
}
603
/// Owns the `tiberius` client and services [`Request`]s sent by the paired [`Client`].
///
/// The `Client`'s requests only complete while this is being driven, e.g. via its
/// `IntoFuture` impl.
pub struct Connection {
    /// Incoming requests from the paired [`Client`].
    rx: UnboundedReceiver<Request>,
    /// The underlying SQL Server connection.
    client: tiberius::Client<Compat<TcpStream>>,
    /// Values that must live as long as the connection (e.g. an SSH tunnel handle); held
    /// only so they are dropped together with it.
    _resources: Option<Box<dyn Any + Send + Sync>>,
}
612
613impl Connection {
614 async fn run(mut self) {
615 while let Some(Request { tx, kind }) = self.rx.recv().await {
616 tracing::trace!(?kind, "processing SQL Server query");
617 let result = Connection::handle_request(&mut self.client, kind).await;
618 let (response, maybe_extra_work) = match result {
619 Ok((response, work)) => (Ok(response), work),
620 Err(err) => (Err(err), None),
621 };
622
623 let _ = tx.send(response);
625
626 if let Some(extra_work) = maybe_extra_work {
630 extra_work.await;
631 }
632 }
633 tracing::debug!("channel closed, SQL Server InnerClient shutting down");
634 }
635
636 async fn handle_request<'c>(
637 client: &'c mut tiberius::Client<Compat<TcpStream>>,
638 kind: RequestKind,
639 ) -> Result<(Response, Option<BoxFuture<'c, ()>>), SqlServerError> {
640 match kind {
641 RequestKind::Execute { query, params } => {
642 #[allow(clippy::as_conversions)]
643 let params: SmallVec<[&dyn ToSql; 4]> =
644 params.iter().map(|x| x as &dyn ToSql).collect();
645 let result = client.execute(query, ¶ms[..]).await?;
646
647 match result.rows_affected() {
648 rows_affected => {
649 let response = Response::Execute {
650 rows_affected: rows_affected.into(),
651 };
652 Ok((response, None))
653 }
654 }
655 }
656 RequestKind::Query { query, params } => {
657 #[allow(clippy::as_conversions)]
658 let params: SmallVec<[&dyn ToSql; 4]> =
659 params.iter().map(|x| x as &dyn ToSql).collect();
660 let result = client.query(query, params.as_slice()).await?;
661
662 let mut results = result.into_results().await.context("into results")?;
663 if results.is_empty() {
664 Ok((Response::Rows(smallvec![]), None))
665 } else if results.len() == 1 {
666 let rows = results.pop().expect("checked len").into();
669 Ok((Response::Rows(rows), None))
670 } else {
671 Err(SqlServerError::ProgrammingError(format!(
672 "Query only supports 1 statement, got {}",
673 results.len()
674 )))
675 }
676 }
677 RequestKind::QueryStreamed { query, params } => {
678 #[allow(clippy::as_conversions)]
679 let params: SmallVec<[&dyn ToSql; 4]> =
680 params.iter().map(|x| x as &dyn ToSql).collect();
681 let result = client.query(query, params.as_slice()).await?;
682
683 let (tx, rx) = tokio::sync::mpsc::channel(256);
702 let work = Box::pin(async move {
703 let mut stream = result.into_row_stream();
704 while let Some(result) = stream.next().await {
705 if let Err(err) = tx.send(result.err_into()).await {
706 tracing::warn!(?err, "SQL Server row stream receiver went away");
707 }
708 }
709 tracing::info!("SQL Server row stream complete");
710 });
711
712 Ok((Response::RowStream { stream: rx }, Some(work)))
713 }
714 RequestKind::SimpleQuery { query } => {
715 let result = client.simple_query(query).await?;
716
717 let mut results = result.into_results().await.context("into results")?;
718 if results.is_empty() {
719 Ok((Response::Rows(smallvec![]), None))
720 } else if results.len() == 1 {
721 let rows = results.pop().expect("checked len").into();
724 Ok((Response::Rows(rows), None))
725 } else {
726 Err(SqlServerError::ProgrammingError(format!(
727 "Simple query only supports 1 statement, got {}",
728 results.len()
729 )))
730 }
731 }
732 }
733 }
734}
735
736impl IntoFuture for Connection {
737 type Output = ();
738 type IntoFuture = BoxFuture<'static, Self::Output>;
739
740 fn into_future(self) -> Self::IntoFuture {
741 self.run().boxed()
742 }
743}
744
/// Owned counterpart of `tiberius::ColumnData`, one variant per `ColumnData` variant.
///
/// Query parameters are converted into this type so they can be sent to the
/// [`Connection`] task without borrowing from the caller.
#[derive(Debug)]
enum OwnedColumnData {
    U8(Option<u8>),
    I16(Option<i16>),
    I32(Option<i32>),
    I64(Option<i64>),
    F32(Option<f32>),
    F64(Option<f64>),
    Bit(Option<bool>),
    String(Option<String>),
    Guid(Option<uuid::Uuid>),
    Binary(Option<Vec<u8>>),
    Numeric(Option<tiberius::numeric::Numeric>),
    Xml(Option<tiberius::xml::XmlData>),
    DateTime(Option<tiberius::time::DateTime>),
    SmallDateTime(Option<tiberius::time::SmallDateTime>),
    Time(Option<tiberius::time::Time>),
    Date(Option<tiberius::time::Date>),
    DateTime2(Option<tiberius::time::DateTime2>),
    DateTimeOffset(Option<tiberius::time::DateTimeOffset>),
}
768
769impl<'a> From<tiberius::ColumnData<'a>> for OwnedColumnData {
770 fn from(value: tiberius::ColumnData<'a>) -> Self {
771 match value {
772 tiberius::ColumnData::U8(inner) => OwnedColumnData::U8(inner),
773 tiberius::ColumnData::I16(inner) => OwnedColumnData::I16(inner),
774 tiberius::ColumnData::I32(inner) => OwnedColumnData::I32(inner),
775 tiberius::ColumnData::I64(inner) => OwnedColumnData::I64(inner),
776 tiberius::ColumnData::F32(inner) => OwnedColumnData::F32(inner),
777 tiberius::ColumnData::F64(inner) => OwnedColumnData::F64(inner),
778 tiberius::ColumnData::Bit(inner) => OwnedColumnData::Bit(inner),
779 tiberius::ColumnData::String(inner) => {
780 OwnedColumnData::String(inner.map(|s| s.to_string()))
781 }
782 tiberius::ColumnData::Guid(inner) => OwnedColumnData::Guid(inner),
783 tiberius::ColumnData::Binary(inner) => {
784 OwnedColumnData::Binary(inner.map(|b| b.to_vec()))
785 }
786 tiberius::ColumnData::Numeric(inner) => OwnedColumnData::Numeric(inner),
787 tiberius::ColumnData::Xml(inner) => OwnedColumnData::Xml(inner.map(|x| x.into_owned())),
788 tiberius::ColumnData::DateTime(inner) => OwnedColumnData::DateTime(inner),
789 tiberius::ColumnData::SmallDateTime(inner) => OwnedColumnData::SmallDateTime(inner),
790 tiberius::ColumnData::Time(inner) => OwnedColumnData::Time(inner),
791 tiberius::ColumnData::Date(inner) => OwnedColumnData::Date(inner),
792 tiberius::ColumnData::DateTime2(inner) => OwnedColumnData::DateTime2(inner),
793 tiberius::ColumnData::DateTimeOffset(inner) => OwnedColumnData::DateTimeOffset(inner),
794 }
795 }
796}
797
798impl tiberius::ToSql for OwnedColumnData {
799 fn to_sql(&self) -> tiberius::ColumnData<'_> {
800 match self {
801 OwnedColumnData::U8(inner) => tiberius::ColumnData::U8(*inner),
802 OwnedColumnData::I16(inner) => tiberius::ColumnData::I16(*inner),
803 OwnedColumnData::I32(inner) => tiberius::ColumnData::I32(*inner),
804 OwnedColumnData::I64(inner) => tiberius::ColumnData::I64(*inner),
805 OwnedColumnData::F32(inner) => tiberius::ColumnData::F32(*inner),
806 OwnedColumnData::F64(inner) => tiberius::ColumnData::F64(*inner),
807 OwnedColumnData::Bit(inner) => tiberius::ColumnData::Bit(*inner),
808 OwnedColumnData::String(inner) => {
809 tiberius::ColumnData::String(inner.as_deref().map(Cow::Borrowed))
810 }
811 OwnedColumnData::Guid(inner) => tiberius::ColumnData::Guid(*inner),
812 OwnedColumnData::Binary(inner) => {
813 tiberius::ColumnData::Binary(inner.as_deref().map(Cow::Borrowed))
814 }
815 OwnedColumnData::Numeric(inner) => tiberius::ColumnData::Numeric(*inner),
816 OwnedColumnData::Xml(inner) => {
817 tiberius::ColumnData::Xml(inner.as_ref().map(Cow::Borrowed))
818 }
819 OwnedColumnData::DateTime(inner) => tiberius::ColumnData::DateTime(*inner),
820 OwnedColumnData::SmallDateTime(inner) => tiberius::ColumnData::SmallDateTime(*inner),
821 OwnedColumnData::Time(inner) => tiberius::ColumnData::Time(*inner),
822 OwnedColumnData::Date(inner) => tiberius::ColumnData::Date(*inner),
823 OwnedColumnData::DateTime2(inner) => tiberius::ColumnData::DateTime2(*inner),
824 OwnedColumnData::DateTimeOffset(inner) => tiberius::ColumnData::DateTimeOffset(*inner),
825 }
826 }
827}
828
829impl<'a, T: tiberius::ToSql> From<&'a T> for OwnedColumnData {
830 fn from(value: &'a T) -> Self {
831 OwnedColumnData::from(value.to_sql())
832 }
833}
834
/// Errors produced by the SQL Server client.
#[derive(Debug, thiserror::Error)]
pub enum SqlServerError {
    /// Error reported by the `tiberius` driver.
    #[error(transparent)]
    SqlServer(#[from] tiberius::error::Error),
    /// Error from the CDC module.
    #[error(transparent)]
    CdcError(#[from] crate::cdc::CdcError),
    /// A column named by the payload was expected but absent.
    #[error("expected column '{0}' to be present")]
    MissingColumn(&'static str),
    /// I/O failure, e.g. while configuring the TCP socket.
    #[error("sql server client encountered I/O error: {0}")]
    IO(#[from] tokio::io::Error),
    /// A column contained data that could not be interpreted.
    #[error("found invalid data in the column '{column_name}': {error}")]
    InvalidData { column_name: String, error: String },
    /// A query for the LSN returned NULL.
    #[error("got back a null value when querying for the LSN")]
    NullLsn,
    /// A SQL Server system setting did not have the expected value.
    #[error("invalid SQL Server system setting '{name}'. Expected '{expected}'. Got '{actual}'.")]
    InvalidSystemSetting {
        name: String,
        expected: String,
        actual: String,
    },
    /// The server responded in a way that breaks an assumption of this client.
    #[error("invariant was violated: {0}")]
    InvariantViolated(String),
    /// Any other error, typically produced via `anyhow::Context`.
    #[error(transparent)]
    Generic(#[from] anyhow::Error),
    /// The client was used incorrectly, e.g. a request got an unexpected response kind.
    #[error("programming error! {0}")]
    ProgrammingError(String),
    /// Missing permissions on the listed tables or capture instances.
    #[error(
        "insufficient permissions for tables [{tables}] or capture instances [{capture_instances}]"
    )]
    AuthorizationError {
        tables: String,
        capture_instances: String,
    },
}
869
/// Errors produced while decoding SQL Server column values into Materialize types.
#[derive(Debug, thiserror::Error)]
pub enum SqlServerDecodeError {
    /// Reading column `column_name` as `as_type` failed.
    #[error("column '{column_name}' was invalid when getting as type '{as_type}'")]
    InvalidColumn {
        column_name: String,
        as_type: &'static str,
    },
    /// The column's value was read but is not acceptable (e.g. out of range, wrong length).
    #[error("found invalid data in the column '{column_name}': {error}")]
    InvalidData { column_name: String, error: String },
    /// No decoding exists from this SQL Server type to this Materialize type.
    #[error("can't decode {sql_server_type:?} as {mz_type:?}")]
    Unsupported {
        sql_server_type: SqlServerColumnDecodeType,
        mz_type: SqlScalarType,
    },
}
896
897impl SqlServerDecodeError {
898 fn invalid_timestamp(name: &str, error: mz_repr::adt::timestamp::TimestampError) -> Self {
899 let error = match error {
901 mz_repr::adt::timestamp::TimestampError::OutOfRange => "out of range",
902 };
903 SqlServerDecodeError::InvalidData {
904 column_name: name.to_string(),
905 error: error.to_string(),
906 }
907 }
908
909 fn invalid_date(name: &str, error: mz_repr::adt::date::DateError) -> Self {
910 let error = match error {
912 mz_repr::adt::date::DateError::OutOfRange => "out of range",
913 };
914 SqlServerDecodeError::InvalidData {
915 column_name: name.to_string(),
916 error: error.to_string(),
917 }
918 }
919
920 fn invalid_char(name: &str, expected_chars: usize, found_chars: usize) -> Self {
921 SqlServerDecodeError::InvalidData {
922 column_name: name.to_string(),
923 error: format!("expected {expected_chars} chars found {found_chars}"),
924 }
925 }
926
927 fn invalid_varchar(name: &str, max_chars: usize, found_chars: usize) -> Self {
928 SqlServerDecodeError::InvalidData {
929 column_name: name.to_string(),
930 error: format!("expected max {max_chars} chars found {found_chars}"),
931 }
932 }
933
934 fn invalid_column(name: &str, as_type: &'static str) -> Self {
935 SqlServerDecodeError::InvalidColumn {
936 column_name: name.to_string(),
937 as_type,
938 }
939 }
940}
941
/// Quotes `ident` as a SQL Server bracket-delimited identifier.
///
/// The result is wrapped in `[` `]`, and every `]` inside the identifier is doubled.
pub fn quote_identifier(ident: &str) -> String {
    let mut out = String::with_capacity(ident.len() + 2);
    out.push('[');
    for ch in ident.chars() {
        if ch == ']' {
            out.push_str("]]");
        } else {
            out.push(ch);
        }
    }
    out.push(']');
    out
}
954
/// Metrics hooks accepted by `Client::cdc`.
///
/// NOTE(review): the call sites live outside this file; the method docs below follow the
/// method names — confirm against `crate::cdc`.
pub trait SqlServerCdcMetrics {
    /// Reports that a snapshot table lock on `table_name` is starting.
    fn snapshot_table_lock_start(&self, table_name: &str);
    /// Reports that the snapshot table lock on `table_name` has ended.
    fn snapshot_table_lock_end(&self, table_name: &str);
}
961
/// [`SqlServerCdcMetrics`] implementation that only emits `tracing::info!` events.
pub struct LoggingSqlServerCdcMetrics;
965
966impl SqlServerCdcMetrics for LoggingSqlServerCdcMetrics {
967 fn snapshot_table_lock_start(&self, table_name: &str) {
968 tracing::info!("snapshot_table_lock_start: {table_name}");
969 }
970
971 fn snapshot_table_lock_end(&self, table_name: &str) {
972 tracing::info!("snapshot_table_lock_end: {table_name}");
973 }
974}
975
#[cfg(test)]
mod test {
    use super::*;

    /// `quote_identifier` wraps in brackets and doubles embedded `]` only.
    #[mz_ore::test]
    fn test_sql_server_escaping() {
        assert_eq!("[]", &quote_identifier(""));
        assert_eq!("[]]]", &quote_identifier("]"));
        assert_eq!("[a]", &quote_identifier("a"));
        assert_eq!("[cost(]]\u{00A3})]", &quote_identifier("cost(]\u{00A3})"));
        assert_eq!("[[g[o[o]][]", &quote_identifier("[g[o[o]["));
    }
}
988}