1use std::any::Any;
11use std::borrow::Cow;
12use std::future::IntoFuture;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use anyhow::Context;
17use derivative::Derivative;
18use futures::future::BoxFuture;
19use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
20use mz_ore::result::ResultExt;
21use mz_repr::SqlScalarType;
22use smallvec::{SmallVec, smallvec};
23use tiberius::ToSql;
24use tokio::net::TcpStream;
25use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
26use tokio::sync::oneshot;
27use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
28
29pub mod cdc;
30pub mod config;
31pub mod desc;
32pub mod inspect;
33
34pub use config::Config;
35pub use desc::{ProtoSqlServerColumnDesc, ProtoSqlServerTableDesc};
36
37use crate::cdc::Lsn;
38use crate::config::TunnelConfig;
39use crate::desc::SqlServerColumnDecodeType;
40
/// Handle used to issue queries against a SQL Server instance.
///
/// A `Client` does not talk to the server itself. Every request is sent over an unbounded
/// channel and executed by the paired `Connection`, which owns the underlying
/// `tiberius::Client`.
#[derive(Debug)]
pub struct Client {
    /// Sending half of the request channel; the receiving half lives in the `Connection`.
    tx: UnboundedSender<Request>,
    /// Configuration this client was created with; cloned by [`Client::new_connection`].
    config: Config,
}
// Compile-time assertion that `Client` does not implement `Clone`.
static_assertions::assert_not_impl_all!(Client: Clone);
52
53impl Client {
54 pub async fn connect(config: Config) -> Result<Self, SqlServerError> {
63 let (tcp, resources): (_, Option<Box<dyn Any + Send + Sync>>) = match &config.tunnel {
66 TunnelConfig::Direct => {
67 let tcp = TcpStream::connect(config.inner.get_addr())
68 .await
69 .context("direct")?;
70 (tcp, None)
71 }
72 TunnelConfig::Ssh {
73 config: ssh_config,
74 manager,
75 timeout,
76 host,
77 port,
78 } => {
79 let tunnel = manager
82 .connect(ssh_config.clone(), host, *port, *timeout, config.in_task)
83 .await?;
84 let tcp = TcpStream::connect(tunnel.local_addr())
85 .await
86 .context("ssh tunnel")?;
87
88 (tcp, Some(Box::new(tunnel)))
89 }
90 TunnelConfig::AwsPrivatelink {
91 connection_id,
92 port,
93 } => {
94 let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
95 let mut privatelink_addrs =
96 tokio::net::lookup_host((privatelink_host.clone(), 0)).await?;
97
98 let Some(mut addr) = privatelink_addrs.next() else {
99 return Err(SqlServerError::InvariantViolated(format!(
100 "aws privatelink: no addresses found for host {:?}",
101 privatelink_host
102 )));
103 };
104
105 addr.set_port(port.clone());
106
107 let tcp = TcpStream::connect(addr)
108 .await
109 .context(format!("aws privatelink {:?}", addr))?;
110
111 (tcp, None)
112 }
113 };
114
115 tcp.set_nodelay(true)?;
116
117 let (client, connection) = Self::connect_raw(config, tcp, resources).await?;
118 mz_ore::task::spawn(|| "sql-server-client-connection", async move {
119 connection.await
120 });
121
122 Ok(client)
123 }
124
125 pub async fn new_connection(&self) -> Result<Self, SqlServerError> {
128 Self::connect(self.config.clone()).await
129 }
130
131 pub async fn connect_raw(
132 config: Config,
133 tcp: tokio::net::TcpStream,
134 resources: Option<Box<dyn Any + Send + Sync>>,
135 ) -> Result<(Self, Connection), SqlServerError> {
136 let client = tiberius::Client::connect(config.inner.clone(), tcp.compat_write()).await?;
137 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
138
139 Ok((
142 Client { tx, config },
143 Connection {
144 rx,
145 client,
146 _resources: resources,
147 },
148 ))
149 }
150
151 pub async fn execute<'a>(
160 &mut self,
161 query: impl Into<Cow<'a, str>>,
162 params: &[&dyn ToSql],
163 ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
164 let (tx, rx) = tokio::sync::oneshot::channel();
165
166 let params = params
167 .iter()
168 .map(|p| OwnedColumnData::from(p.to_sql()))
169 .collect();
170 let kind = RequestKind::Execute {
171 query: query.into().to_string(),
172 params,
173 };
174 self.tx
175 .send(Request { tx, kind })
176 .context("sending request")?;
177
178 let response = rx.await.context("channel")??;
179 match response {
180 Response::Execute { rows_affected } => Ok(rows_affected),
181 other @ Response::Rows(_) | other @ Response::RowStream { .. } => {
182 Err(SqlServerError::ProgrammingError(format!(
183 "expected Response::Execute, got {other:?}"
184 )))
185 }
186 }
187 }
188
189 pub async fn query<'a>(
198 &mut self,
199 query: impl Into<Cow<'a, str>>,
200 params: &[&dyn tiberius::ToSql],
201 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
202 let (tx, rx) = tokio::sync::oneshot::channel();
203
204 let params = params
205 .iter()
206 .map(|p| OwnedColumnData::from(p.to_sql()))
207 .collect();
208 let kind = RequestKind::Query {
209 query: query.into().to_string(),
210 params,
211 };
212 self.tx
213 .send(Request { tx, kind })
214 .context("sending request")?;
215
216 let response = rx.await.context("channel")??;
217 match response {
218 Response::Rows(rows) => Ok(rows),
219 other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
220 SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
221 ),
222 }
223 }
224
225 pub fn query_streaming<'c, 'q, Q>(
230 &'c mut self,
231 query: Q,
232 params: &[&dyn tiberius::ToSql],
233 ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
234 where
235 Q: Into<Cow<'q, str>>,
236 {
237 let (tx, rx) = tokio::sync::oneshot::channel();
238 let params = params
239 .iter()
240 .map(|p| OwnedColumnData::from(p.to_sql()))
241 .collect();
242 let kind = RequestKind::QueryStreamed {
243 query: query.into().to_string(),
244 params,
245 };
246
247 let request_future = async move {
249 self.tx
250 .send(Request { tx, kind })
251 .context("sending request")?;
252
253 let response = rx.await.context("channel")??;
254 match response {
255 Response::RowStream { stream } => {
256 Ok(tokio_stream::wrappers::ReceiverStream::new(stream))
257 }
258 other @ Response::Execute { .. } | other @ Response::Rows(_) => {
259 Err(SqlServerError::ProgrammingError(format!(
260 "expected Response::Rows, got {other:?}"
261 )))
262 }
263 }
264 };
265
266 futures::stream::once(request_future).try_flatten()
268 }
269
270 pub async fn simple_query<'a>(
280 &mut self,
281 query: impl Into<Cow<'a, str>>,
282 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
283 let (tx, rx) = tokio::sync::oneshot::channel();
284 let kind = RequestKind::SimpleQuery {
285 query: query.into().to_string(),
286 };
287 self.tx
288 .send(Request { tx, kind })
289 .context("sending request")?;
290
291 let response = rx.await.context("channel")??;
292 match response {
293 Response::Rows(rows) => Ok(rows),
294 other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
295 SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
296 ),
297 }
298 }
299
300 pub async fn transaction(&mut self) -> Result<Transaction<'_>, SqlServerError> {
305 Transaction::new(self).await
306 }
307
308 pub async fn set_transaction_isolation(
310 &mut self,
311 level: TransactionIsolationLevel,
312 ) -> Result<(), SqlServerError> {
313 let query = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_str());
314 self.simple_query(query).await?;
315 Ok(())
316 }
317
318 pub async fn get_transaction_isolation(
320 &mut self,
321 ) -> Result<TransactionIsolationLevel, SqlServerError> {
322 const QUERY: &str = "SELECT transaction_isolation_level FROM sys.dm_exec_sessions where session_id = @@SPID;";
323 let rows = self.simple_query(QUERY).await?;
324 match &rows[..] {
325 [row] => {
326 let val: i16 = row
327 .try_get(0)
328 .context("getting 0th column")?
329 .ok_or_else(|| anyhow::anyhow!("no 0th column?"))?;
330 let level = TransactionIsolationLevel::try_from_sql_server(val)?;
331 Ok(level)
332 }
333 other => Err(SqlServerError::InvariantViolated(format!(
334 "expected one row, got {other:?}"
335 ))),
336 }
337 }
338
339 pub fn cdc<I, M>(&mut self, capture_instances: I, metrics: M) -> crate::cdc::CdcStream<'_, M>
344 where
345 I: IntoIterator,
346 I::Item: Into<Arc<str>>,
347 M: SqlServerCdcMetrics,
348 {
349 let instances = capture_instances
350 .into_iter()
351 .map(|i| (i.into(), None))
352 .collect();
353 crate::cdc::CdcStream::new(self, instances, metrics)
354 }
355}
356
/// A boxed, pinned, `Send` stream of rows (or errors) that may borrow for `'a`.
pub type RowStream<'a> =
    Pin<Box<dyn Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + 'a>>;
360
/// An open SQL Server transaction on a mutably borrowed [`Client`].
///
/// Created through [`Client::transaction`] and finished with [`Transaction::commit`] or
/// [`Transaction::rollback`].
#[derive(Debug)]
pub struct Transaction<'a> {
    /// Client on which the transaction was started.
    client: &'a mut Client,
    /// Set once `commit` or `rollback` has been called; consulted by the `Drop` impl to
    /// decide whether a rollback is still needed.
    closed: bool,
}
366
367impl<'a> Transaction<'a> {
368 async fn new(client: &'a mut Client) -> Result<Self, SqlServerError> {
369 let results = client
370 .simple_query("BEGIN TRANSACTION")
371 .await
372 .context("begin")?;
373 if !results.is_empty() {
374 Err(SqlServerError::InvariantViolated(format!(
375 "expected empty result from BEGIN TRANSACTION. Got: {results:?}"
376 )))
377 } else {
378 Ok(Transaction {
379 client,
380 closed: false,
381 })
382 }
383 }
384
385 pub async fn create_savepoint(&mut self, savepoint_name: &str) -> Result<(), SqlServerError> {
395 if savepoint_name.is_empty()
398 || !savepoint_name
399 .chars()
400 .all(|c| c.is_alphanumeric() || c == '_')
401 {
402 Err(SqlServerError::ProgrammingError(format!(
403 "Invalid savepoint name: '{savepoint_name}"
404 )))?;
405 }
406
407 let stmt = format!("SAVE TRANSACTION {}", quote_identifier(savepoint_name));
408 let _result = self.client.simple_query(stmt).await?;
409 Ok(())
410 }
411
412 pub async fn get_lsn(&mut self) -> Result<Lsn, SqlServerError> {
416 static CURRENT_LSN_QUERY: &str = "SELECT dt.database_transaction_most_recent_savepoint_lsn \
417 FROM sys.dm_tran_database_transactions dt \
418 JOIN sys.dm_tran_current_transaction ct \
419 ON ct.transaction_id = dt.transaction_id \
420 WHERE dt.database_transaction_most_recent_savepoint_lsn IS NOT NULL";
421 let result = self.client.simple_query(CURRENT_LSN_QUERY).await?;
422 crate::inspect::parse_numeric_lsn(&result)
423 }
424
425 pub async fn lock_table_shared(
431 &mut self,
432 schema: &str,
433 table: &str,
434 ) -> Result<(), SqlServerError> {
435 static SET_READ_COMMITTED: &str = "SET TRANSACTION ISOLATION LEVEL READ COMMITTED;";
438 let query = format!(
442 "{SET_READ_COMMITTED}\nSELECT * FROM {schema}.{table} WITH (TABLOCK, HOLDLOCK) WHERE 1=0;",
443 schema = quote_identifier(schema),
444 table = quote_identifier(table)
445 );
446 let _result = self.client.simple_query(query).await?;
447 Ok(())
448 }
449
450 pub async fn execute<'q>(
452 &mut self,
453 query: impl Into<Cow<'q, str>>,
454 params: &[&dyn ToSql],
455 ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
456 self.client.execute(query, params).await
457 }
458
459 pub async fn query<'q>(
461 &mut self,
462 query: impl Into<Cow<'q, str>>,
463 params: &[&dyn tiberius::ToSql],
464 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
465 self.client.query(query, params).await
466 }
467
468 pub fn query_streaming<'c, 'q, Q>(
470 &'c mut self,
471 query: Q,
472 params: &[&dyn tiberius::ToSql],
473 ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
474 where
475 Q: Into<Cow<'q, str>>,
476 {
477 self.client.query_streaming(query, params)
478 }
479
480 pub async fn simple_query<'q>(
482 &mut self,
483 query: impl Into<Cow<'q, str>>,
484 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
485 self.client.simple_query(query).await
486 }
487
488 pub async fn rollback(mut self) -> Result<(), SqlServerError> {
490 static ROLLBACK_QUERY: &str = "ROLLBACK TRANSACTION";
491 self.closed = true;
494 self.client.simple_query(ROLLBACK_QUERY).await?;
495 Ok(())
496 }
497
498 pub async fn commit(mut self) -> Result<(), SqlServerError> {
500 static COMMIT_QUERY: &str = "COMMIT TRANSACTION";
501 self.closed = true;
504 self.client.simple_query(COMMIT_QUERY).await?;
505 Ok(())
506 }
507}
508
509impl Drop for Transaction<'_> {
510 fn drop(&mut self) {
511 if !self.closed {
514 let _fut = self.client.simple_query("ROLLBACK TRANSACTION");
515 }
516 }
517}
518
/// Transaction isolation levels that can be set on, or read from, a SQL Server session.
#[derive(Debug, PartialEq, Eq)]
pub enum TransactionIsolationLevel {
    /// `READ UNCOMMITTED`
    ReadUncommitted,
    /// `READ COMMITTED`
    ReadCommitted,
    /// `REPEATABLE READ`
    RepeatableRead,
    /// `SNAPSHOT`
    Snapshot,
    /// `SERIALIZABLE`
    Serializable,
}
530
531impl TransactionIsolationLevel {
532 fn as_str(&self) -> &'static str {
534 match self {
535 TransactionIsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
536 TransactionIsolationLevel::ReadCommitted => "READ COMMITTED",
537 TransactionIsolationLevel::RepeatableRead => "REPEATABLE READ",
538 TransactionIsolationLevel::Snapshot => "SNAPSHOT",
539 TransactionIsolationLevel::Serializable => "SERIALIZABLE",
540 }
541 }
542
543 fn try_from_sql_server(val: i16) -> Result<TransactionIsolationLevel, anyhow::Error> {
545 let level = match val {
546 1 => TransactionIsolationLevel::ReadUncommitted,
547 2 => TransactionIsolationLevel::ReadCommitted,
548 3 => TransactionIsolationLevel::RepeatableRead,
549 4 => TransactionIsolationLevel::Serializable,
550 5 => TransactionIsolationLevel::Snapshot,
551 x => anyhow::bail!("unknown level {x}"),
552 };
553 Ok(level)
554 }
555}
556
/// Reply sent from the `Connection` back to the `Client` for a single `Request`.
#[derive(Derivative)]
#[derivative(Debug)]
enum Response {
    /// Reply to `RequestKind::Execute`: the affected-row counts.
    Execute {
        rows_affected: SmallVec<[u64; 1]>,
    },
    /// Reply to `RequestKind::Query` and `RequestKind::SimpleQuery`: the rows of the single
    /// result set (empty if the statement produced none).
    Rows(SmallVec<[tiberius::Row; 1]>),
    /// Reply to `RequestKind::QueryStreamed`: rows are delivered over `stream`.
    RowStream {
        // Left out of the `Debug` output.
        #[derivative(Debug = "ignore")]
        stream: tokio::sync::mpsc::Receiver<Result<tiberius::Row, SqlServerError>>,
    },
}
569
/// A unit of work sent from a `Client` to its `Connection`.
#[derive(Debug)]
struct Request {
    /// One-shot channel on which the connection sends the outcome of this request.
    tx: oneshot::Sender<Result<Response, SqlServerError>>,
    /// What the connection is asked to do.
    kind: RequestKind,
}
575
/// The different operations a `Client` can ask its `Connection` to perform.
///
/// Parameters are left out of the `Debug` output of every variant.
#[derive(Derivative)]
#[derivative(Debug)]
enum RequestKind {
    /// Execute a parameterized statement; answered with `Response::Execute`.
    Execute {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run a parameterized query and collect all rows; answered with `Response::Rows`.
    Query {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run a parameterized query and stream the rows; answered with `Response::RowStream`.
    QueryStreamed {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Run a query without parameters and collect all rows; answered with `Response::Rows`.
    SimpleQuery {
        query: String,
    },
}
598
/// The half of a client/connection pair that owns the `tiberius::Client` and executes the
/// requests sent by the [`Client`].
///
/// A `Connection` does nothing until it is driven, e.g. by awaiting it (see its
/// `IntoFuture` impl).
pub struct Connection {
    /// Receiving half of the request channel.
    rx: UnboundedReceiver<Request>,
    /// The underlying SQL Server client that executes the requests.
    client: tiberius::Client<Compat<TcpStream>>,
    /// Opaque resources (such as a tunnel) that are never read, only kept alive until the
    /// connection is dropped.
    _resources: Option<Box<dyn Any + Send + Sync>>,
}
607
608impl Connection {
609 async fn run(mut self) {
610 while let Some(Request { tx, kind }) = self.rx.recv().await {
611 tracing::trace!(?kind, "processing SQL Server query");
612 let result = Connection::handle_request(&mut self.client, kind).await;
613 let (response, maybe_extra_work) = match result {
614 Ok((response, work)) => (Ok(response), work),
615 Err(err) => (Err(err), None),
616 };
617
618 let _ = tx.send(response);
620
621 if let Some(extra_work) = maybe_extra_work {
625 extra_work.await;
626 }
627 }
628 tracing::debug!("channel closed, SQL Server InnerClient shutting down");
629 }
630
631 async fn handle_request<'c>(
632 client: &'c mut tiberius::Client<Compat<TcpStream>>,
633 kind: RequestKind,
634 ) -> Result<(Response, Option<BoxFuture<'c, ()>>), SqlServerError> {
635 match kind {
636 RequestKind::Execute { query, params } => {
637 #[allow(clippy::as_conversions)]
638 let params: SmallVec<[&dyn ToSql; 4]> =
639 params.iter().map(|x| x as &dyn ToSql).collect();
640 let result = client.execute(query, ¶ms[..]).await?;
641
642 match result.rows_affected() {
643 rows_affected => {
644 let response = Response::Execute {
645 rows_affected: rows_affected.into(),
646 };
647 Ok((response, None))
648 }
649 }
650 }
651 RequestKind::Query { query, params } => {
652 #[allow(clippy::as_conversions)]
653 let params: SmallVec<[&dyn ToSql; 4]> =
654 params.iter().map(|x| x as &dyn ToSql).collect();
655 let result = client.query(query, params.as_slice()).await?;
656
657 let mut results = result.into_results().await.context("into results")?;
658 if results.is_empty() {
659 Ok((Response::Rows(smallvec![]), None))
660 } else if results.len() == 1 {
661 let rows = results.pop().expect("checked len").into();
664 Ok((Response::Rows(rows), None))
665 } else {
666 Err(SqlServerError::ProgrammingError(format!(
667 "Query only supports 1 statement, got {}",
668 results.len()
669 )))
670 }
671 }
672 RequestKind::QueryStreamed { query, params } => {
673 #[allow(clippy::as_conversions)]
674 let params: SmallVec<[&dyn ToSql; 4]> =
675 params.iter().map(|x| x as &dyn ToSql).collect();
676 let result = client.query(query, params.as_slice()).await?;
677
678 let (tx, rx) = tokio::sync::mpsc::channel(256);
697 let work = Box::pin(async move {
698 let mut stream = result.into_row_stream();
699 while let Some(result) = stream.next().await {
700 if let Err(err) = tx.send(result.err_into()).await {
701 tracing::warn!(?err, "SQL Server row stream receiver went away");
702 }
703 }
704 tracing::info!("SQL Server row stream complete");
705 });
706
707 Ok((Response::RowStream { stream: rx }, Some(work)))
708 }
709 RequestKind::SimpleQuery { query } => {
710 let result = client.simple_query(query).await?;
711
712 let mut results = result.into_results().await.context("into results")?;
713 if results.is_empty() {
714 Ok((Response::Rows(smallvec![]), None))
715 } else if results.len() == 1 {
716 let rows = results.pop().expect("checked len").into();
719 Ok((Response::Rows(rows), None))
720 } else {
721 Err(SqlServerError::ProgrammingError(format!(
722 "Simple query only supports 1 statement, got {}",
723 results.len()
724 )))
725 }
726 }
727 }
728 }
729}
730
731impl IntoFuture for Connection {
732 type Output = ();
733 type IntoFuture = BoxFuture<'static, Self::Output>;
734
735 fn into_future(self) -> Self::IntoFuture {
736 self.run().boxed()
737 }
738}
739
/// Owned counterpart of `tiberius::ColumnData`.
///
/// Has one variant per `tiberius::ColumnData` variant handled by the `From` conversion,
/// with borrowed payloads (strings, binary, XML) replaced by owned ones. Used to carry
/// query parameters inside a `RequestKind`.
#[derive(Debug)]
enum OwnedColumnData {
    U8(Option<u8>),
    I16(Option<i16>),
    I32(Option<i32>),
    I64(Option<i64>),
    F32(Option<f32>),
    F64(Option<f64>),
    Bit(Option<bool>),
    String(Option<String>),
    Guid(Option<uuid::Uuid>),
    Binary(Option<Vec<u8>>),
    Numeric(Option<tiberius::numeric::Numeric>),
    Xml(Option<tiberius::xml::XmlData>),
    DateTime(Option<tiberius::time::DateTime>),
    SmallDateTime(Option<tiberius::time::SmallDateTime>),
    Time(Option<tiberius::time::Time>),
    Date(Option<tiberius::time::Date>),
    DateTime2(Option<tiberius::time::DateTime2>),
    DateTimeOffset(Option<tiberius::time::DateTimeOffset>),
}
763
764impl<'a> From<tiberius::ColumnData<'a>> for OwnedColumnData {
765 fn from(value: tiberius::ColumnData<'a>) -> Self {
766 match value {
767 tiberius::ColumnData::U8(inner) => OwnedColumnData::U8(inner),
768 tiberius::ColumnData::I16(inner) => OwnedColumnData::I16(inner),
769 tiberius::ColumnData::I32(inner) => OwnedColumnData::I32(inner),
770 tiberius::ColumnData::I64(inner) => OwnedColumnData::I64(inner),
771 tiberius::ColumnData::F32(inner) => OwnedColumnData::F32(inner),
772 tiberius::ColumnData::F64(inner) => OwnedColumnData::F64(inner),
773 tiberius::ColumnData::Bit(inner) => OwnedColumnData::Bit(inner),
774 tiberius::ColumnData::String(inner) => {
775 OwnedColumnData::String(inner.map(|s| s.to_string()))
776 }
777 tiberius::ColumnData::Guid(inner) => OwnedColumnData::Guid(inner),
778 tiberius::ColumnData::Binary(inner) => {
779 OwnedColumnData::Binary(inner.map(|b| b.to_vec()))
780 }
781 tiberius::ColumnData::Numeric(inner) => OwnedColumnData::Numeric(inner),
782 tiberius::ColumnData::Xml(inner) => OwnedColumnData::Xml(inner.map(|x| x.into_owned())),
783 tiberius::ColumnData::DateTime(inner) => OwnedColumnData::DateTime(inner),
784 tiberius::ColumnData::SmallDateTime(inner) => OwnedColumnData::SmallDateTime(inner),
785 tiberius::ColumnData::Time(inner) => OwnedColumnData::Time(inner),
786 tiberius::ColumnData::Date(inner) => OwnedColumnData::Date(inner),
787 tiberius::ColumnData::DateTime2(inner) => OwnedColumnData::DateTime2(inner),
788 tiberius::ColumnData::DateTimeOffset(inner) => OwnedColumnData::DateTimeOffset(inner),
789 }
790 }
791}
792
793impl tiberius::ToSql for OwnedColumnData {
794 fn to_sql(&self) -> tiberius::ColumnData<'_> {
795 match self {
796 OwnedColumnData::U8(inner) => tiberius::ColumnData::U8(*inner),
797 OwnedColumnData::I16(inner) => tiberius::ColumnData::I16(*inner),
798 OwnedColumnData::I32(inner) => tiberius::ColumnData::I32(*inner),
799 OwnedColumnData::I64(inner) => tiberius::ColumnData::I64(*inner),
800 OwnedColumnData::F32(inner) => tiberius::ColumnData::F32(*inner),
801 OwnedColumnData::F64(inner) => tiberius::ColumnData::F64(*inner),
802 OwnedColumnData::Bit(inner) => tiberius::ColumnData::Bit(*inner),
803 OwnedColumnData::String(inner) => {
804 tiberius::ColumnData::String(inner.as_deref().map(Cow::Borrowed))
805 }
806 OwnedColumnData::Guid(inner) => tiberius::ColumnData::Guid(*inner),
807 OwnedColumnData::Binary(inner) => {
808 tiberius::ColumnData::Binary(inner.as_deref().map(Cow::Borrowed))
809 }
810 OwnedColumnData::Numeric(inner) => tiberius::ColumnData::Numeric(*inner),
811 OwnedColumnData::Xml(inner) => {
812 tiberius::ColumnData::Xml(inner.as_ref().map(Cow::Borrowed))
813 }
814 OwnedColumnData::DateTime(inner) => tiberius::ColumnData::DateTime(*inner),
815 OwnedColumnData::SmallDateTime(inner) => tiberius::ColumnData::SmallDateTime(*inner),
816 OwnedColumnData::Time(inner) => tiberius::ColumnData::Time(*inner),
817 OwnedColumnData::Date(inner) => tiberius::ColumnData::Date(*inner),
818 OwnedColumnData::DateTime2(inner) => tiberius::ColumnData::DateTime2(*inner),
819 OwnedColumnData::DateTimeOffset(inner) => tiberius::ColumnData::DateTimeOffset(*inner),
820 }
821 }
822}
823
824impl<'a, T: tiberius::ToSql> From<&'a T> for OwnedColumnData {
825 fn from(value: &'a T) -> Self {
826 OwnedColumnData::from(value.to_sql())
827 }
828}
829
/// Errors returned by the SQL Server client in this crate.
#[derive(Debug, thiserror::Error)]
pub enum SqlServerError {
    /// An error reported by `tiberius`.
    #[error(transparent)]
    SqlServer(#[from] tiberius::error::Error),
    /// An error from the change data capture module.
    #[error(transparent)]
    CdcError(#[from] crate::cdc::CdcError),
    /// A column that was expected in a result is missing.
    #[error("expected column '{0}' to be present")]
    MissingColumn(&'static str),
    /// An I/O error, e.g. while resolving a host or configuring the socket.
    #[error("sql server client encountered I/O error: {0}")]
    IO(#[from] tokio::io::Error),
    /// A column contained data that could not be interpreted.
    #[error("found invalid data in the column '{column_name}': {error}")]
    InvalidData { column_name: String, error: String },
    /// Querying for an LSN returned `NULL`.
    #[error("got back a null value when querying for the LSN")]
    NullLsn,
    /// A SQL Server system setting does not have the expected value.
    #[error("invalid SQL Server system setting '{name}'. Expected '{expected}'. Got '{actual}'.")]
    InvalidSystemSetting {
        name: String,
        expected: String,
        actual: String,
    },
    /// Something that was assumed to always hold did not.
    #[error("invariant was violated: {0}")]
    InvariantViolated(String),
    /// Any other error, typically one that was given additional context via `anyhow`.
    #[error(transparent)]
    Generic(#[from] anyhow::Error),
    /// The client was used incorrectly, or received a response of an unexpected kind.
    #[error("programming error! {0}")]
    ProgrammingError(String),
    /// Permissions are missing for the listed tables or capture instances.
    #[error(
        "insufficient permissions for tables [{tables}] or capture instances [{capture_instances}]"
    )]
    AuthorizationError {
        tables: String,
        capture_instances: String,
    },
}
864
/// Errors that describe why a SQL Server value could not be decoded.
#[derive(Debug, thiserror::Error)]
pub enum SqlServerDecodeError {
    /// The column could not be read as the requested type.
    #[error("column '{column_name}' was invalid when getting as type '{as_type}'")]
    InvalidColumn {
        column_name: String,
        as_type: &'static str,
    },
    /// The column was read, but its value is not acceptable (e.g. out of range or too long).
    #[error("found invalid data in the column '{column_name}': {error}")]
    InvalidData { column_name: String, error: String },
    /// There is no decoding from the SQL Server type to the requested type.
    #[error("can't decode {sql_server_type:?} as {mz_type:?}")]
    Unsupported {
        sql_server_type: SqlServerColumnDecodeType,
        mz_type: SqlScalarType,
    },
}
891
892impl SqlServerDecodeError {
893 fn invalid_timestamp(name: &str, error: mz_repr::adt::timestamp::TimestampError) -> Self {
894 let error = match error {
896 mz_repr::adt::timestamp::TimestampError::OutOfRange => "out of range",
897 };
898 SqlServerDecodeError::InvalidData {
899 column_name: name.to_string(),
900 error: error.to_string(),
901 }
902 }
903
904 fn invalid_date(name: &str, error: mz_repr::adt::date::DateError) -> Self {
905 let error = match error {
907 mz_repr::adt::date::DateError::OutOfRange => "out of range",
908 };
909 SqlServerDecodeError::InvalidData {
910 column_name: name.to_string(),
911 error: error.to_string(),
912 }
913 }
914
915 fn invalid_char(name: &str, expected_chars: usize, found_chars: usize) -> Self {
916 SqlServerDecodeError::InvalidData {
917 column_name: name.to_string(),
918 error: format!("expected {expected_chars} chars found {found_chars}"),
919 }
920 }
921
922 fn invalid_varchar(name: &str, max_chars: usize, found_chars: usize) -> Self {
923 SqlServerDecodeError::InvalidData {
924 column_name: name.to_string(),
925 error: format!("expected max {max_chars} chars found {found_chars}"),
926 }
927 }
928
929 fn invalid_column(name: &str, as_type: &'static str) -> Self {
930 SqlServerDecodeError::InvalidColumn {
931 column_name: name.to_string(),
932 as_type,
933 }
934 }
935}
936
/// Quotes `ident` for use as a SQL Server identifier.
///
/// The identifier is wrapped in square brackets and every `]` inside it is doubled, e.g.
/// `a]b` becomes `[a]]b]`.
pub fn quote_identifier(ident: &str) -> String {
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('[');
    for c in ident.chars() {
        // A closing bracket is escaped by writing it twice.
        if c == ']' {
            quoted.push(']');
        }
        quoted.push(c);
    }
    quoted.push(']');
    quoted
}
949
/// Hooks for observing a CDC stream; passed to [`Client::cdc`].
pub trait SqlServerCdcMetrics {
    /// Reports the start of the snapshot table lock for `table_name`.
    // NOTE(review): the exact call sites live in the `cdc` module; confirm the timing there.
    fn snapshot_table_lock_start(&self, table_name: &str);
    /// Reports the end of the snapshot table lock for `table_name`.
    fn snapshot_table_lock_end(&self, table_name: &str);
}
956
/// A [`SqlServerCdcMetrics`] implementation that only logs each event at `info` level.
pub struct LoggingSqlServerCdcMetrics;

impl SqlServerCdcMetrics for LoggingSqlServerCdcMetrics {
    /// Logs the start of the snapshot table lock for `table_name`.
    fn snapshot_table_lock_start(&self, table_name: &str) {
        tracing::info!("snapshot_table_lock_start: {table_name}");
    }

    /// Logs the end of the snapshot table lock for `table_name`.
    fn snapshot_table_lock_end(&self, table_name: &str) {
        tracing::info!("snapshot_table_lock_end: {table_name}");
    }
}
970
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that `quote_identifier` wraps identifiers in brackets and doubles every `]`,
    /// while leaving `[` and non-ASCII characters untouched.
    #[mz_ore::test]
    fn test_sql_server_escaping() {
        assert_eq!("[]", &quote_identifier(""));
        assert_eq!("[]]]", &quote_identifier("]"));
        assert_eq!("[a]", &quote_identifier("a"));
        assert_eq!("[cost(]]\u{00A3})]", &quote_identifier("cost(]\u{00A3})"));
        assert_eq!("[[g[o[o]][]", &quote_identifier("[g[o[o]["));
    }
}