1use std::any::Any;
11use std::borrow::Cow;
12use std::future::IntoFuture;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use anyhow::Context;
17use derivative::Derivative;
18use futures::future::BoxFuture;
19use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
20use mz_ore::netio::DUMMY_DNS_PORT;
21use mz_ore::result::ResultExt;
22use mz_repr::SqlScalarType;
23use smallvec::{SmallVec, smallvec};
24use tiberius::ToSql;
25use tokio::net::TcpStream;
26use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
27use tokio::sync::oneshot;
28use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
29
30pub mod cdc;
31pub mod config;
32pub mod desc;
33pub mod inspect;
34
35pub use config::Config;
36pub use desc::{ProtoSqlServerColumnDesc, ProtoSqlServerTableDesc};
37
38use crate::cdc::Lsn;
39use crate::config::TunnelConfig;
40use crate::desc::SqlServerColumnDecodeType;
41
/// A handle to a single SQL Server connection.
///
/// The handle does not talk to the server itself: every request is sent over an unbounded
/// channel to a [`Connection`] that owns the underlying [`tiberius::Client`] (see
/// [`Client::connect_raw`]). The [`Config`] is retained so that [`Client::new_connection`] can
/// open another connection with identical settings.
#[derive(Debug)]
pub struct Client {
    /// Sending half of the request channel; the receiving half lives in the [`Connection`].
    tx: UnboundedSender<Request>,
    /// The configuration this client was created from, reused by `new_connection`.
    config: Config,
}
// `Client` must never be `Clone`; a second connection is obtained via `new_connection`.
// NOTE(review): presumably so that one `Client` always corresponds to one server session — confirm.
static_assertions::assert_not_impl_all!(Client: Clone);
53
54impl Client {
55 pub async fn connect(config: Config) -> Result<Self, SqlServerError> {
64 let (tcp, resources): (_, Option<Box<dyn Any + Send + Sync>>) = match &config.tunnel {
67 TunnelConfig::Direct => {
68 let tcp = TcpStream::connect(config.inner.get_addr())
69 .await
70 .context("direct")?;
71 (tcp, None)
72 }
73 TunnelConfig::Ssh {
74 config: ssh_config,
75 manager,
76 timeout,
77 host,
78 port,
79 } => {
80 let tunnel = manager
83 .connect(ssh_config.clone(), host, *port, *timeout, config.in_task)
84 .await?;
85 let tcp = TcpStream::connect(tunnel.local_addr())
86 .await
87 .context("ssh tunnel")?;
88
89 (tcp, Some(Box::new(tunnel)))
90 }
91 TunnelConfig::AwsPrivatelink {
92 connection_id,
93 port,
94 } => {
95 let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
96 let mut privatelink_addrs =
97 tokio::net::lookup_host((privatelink_host.clone(), DUMMY_DNS_PORT)).await?;
98
99 let Some(mut addr) = privatelink_addrs.next() else {
100 return Err(SqlServerError::InvariantViolated(format!(
101 "aws privatelink: no addresses found for host {:?}",
102 privatelink_host
103 )));
104 };
105
106 addr.set_port(port.clone());
107
108 let tcp = TcpStream::connect(addr)
109 .await
110 .context(format!("aws privatelink {:?}", addr))?;
111
112 (tcp, None)
113 }
114 };
115
116 tcp.set_nodelay(true)?;
117
118 let (client, connection) = Self::connect_raw(config, tcp, resources).await?;
119 mz_ore::task::spawn(|| "sql-server-client-connection", async move {
120 connection.await
121 });
122
123 Ok(client)
124 }
125
126 pub async fn new_connection(&self) -> Result<Self, SqlServerError> {
129 Self::connect(self.config.clone()).await
130 }
131
132 pub async fn connect_raw(
133 config: Config,
134 tcp: tokio::net::TcpStream,
135 resources: Option<Box<dyn Any + Send + Sync>>,
136 ) -> Result<(Self, Connection), SqlServerError> {
137 let client = tiberius::Client::connect(config.inner.clone(), tcp.compat_write()).await?;
138 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
139
140 Ok((
143 Client { tx, config },
144 Connection {
145 rx,
146 client,
147 _resources: resources,
148 },
149 ))
150 }
151
152 pub async fn execute<'a>(
161 &mut self,
162 query: impl Into<Cow<'a, str>>,
163 params: &[&dyn ToSql],
164 ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
165 let (tx, rx) = tokio::sync::oneshot::channel();
166
167 let params = params
168 .iter()
169 .map(|p| OwnedColumnData::from(p.to_sql()))
170 .collect();
171 let kind = RequestKind::Execute {
172 query: query.into().to_string(),
173 params,
174 };
175 self.tx
176 .send(Request { tx, kind })
177 .context("sending request")?;
178
179 let response = rx.await.context("channel")??;
180 match response {
181 Response::Execute { rows_affected } => Ok(rows_affected),
182 other @ Response::Rows(_) | other @ Response::RowStream { .. } => {
183 Err(SqlServerError::ProgrammingError(format!(
184 "expected Response::Execute, got {other:?}"
185 )))
186 }
187 }
188 }
189
190 pub async fn query<'a>(
199 &mut self,
200 query: impl Into<Cow<'a, str>>,
201 params: &[&dyn tiberius::ToSql],
202 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
203 let (tx, rx) = tokio::sync::oneshot::channel();
204
205 let params = params
206 .iter()
207 .map(|p| OwnedColumnData::from(p.to_sql()))
208 .collect();
209 let kind = RequestKind::Query {
210 query: query.into().to_string(),
211 params,
212 };
213 self.tx
214 .send(Request { tx, kind })
215 .context("sending request")?;
216
217 let response = rx.await.context("channel")??;
218 match response {
219 Response::Rows(rows) => Ok(rows),
220 other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
221 SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
222 ),
223 }
224 }
225
226 pub fn query_streaming<'c, 'q, Q>(
231 &'c mut self,
232 query: Q,
233 params: &[&dyn tiberius::ToSql],
234 ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
235 where
236 Q: Into<Cow<'q, str>>,
237 {
238 let (tx, rx) = tokio::sync::oneshot::channel();
239 let params = params
240 .iter()
241 .map(|p| OwnedColumnData::from(p.to_sql()))
242 .collect();
243 let kind = RequestKind::QueryStreamed {
244 query: query.into().to_string(),
245 params,
246 };
247
248 let request_future = async move {
250 self.tx
251 .send(Request { tx, kind })
252 .context("sending request")?;
253
254 let response = rx.await.context("channel")??;
255 match response {
256 Response::RowStream { stream } => {
257 Ok(tokio_stream::wrappers::ReceiverStream::new(stream))
258 }
259 other @ Response::Execute { .. } | other @ Response::Rows(_) => {
260 Err(SqlServerError::ProgrammingError(format!(
261 "expected Response::Rows, got {other:?}"
262 )))
263 }
264 }
265 };
266
267 futures::stream::once(request_future).try_flatten()
269 }
270
271 pub async fn simple_query<'a>(
281 &mut self,
282 query: impl Into<Cow<'a, str>>,
283 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
284 let (tx, rx) = tokio::sync::oneshot::channel();
285 let kind = RequestKind::SimpleQuery {
286 query: query.into().to_string(),
287 };
288 self.tx
289 .send(Request { tx, kind })
290 .context("sending request")?;
291
292 let response = rx.await.context("channel")??;
293 match response {
294 Response::Rows(rows) => Ok(rows),
295 other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
296 SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
297 ),
298 }
299 }
300
301 pub async fn transaction(&mut self) -> Result<Transaction<'_>, SqlServerError> {
306 Transaction::new(self).await
307 }
308
309 pub async fn set_transaction_isolation(
311 &mut self,
312 level: TransactionIsolationLevel,
313 ) -> Result<(), SqlServerError> {
314 let query = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_str());
315 self.simple_query(query).await?;
316 Ok(())
317 }
318
319 pub async fn get_transaction_isolation(
321 &mut self,
322 ) -> Result<TransactionIsolationLevel, SqlServerError> {
323 const QUERY: &str = "SELECT transaction_isolation_level FROM sys.dm_exec_sessions where session_id = @@SPID;";
324 let rows = self.simple_query(QUERY).await?;
325 match &rows[..] {
326 [row] => {
327 let val: i16 = row
328 .try_get(0)
329 .context("getting 0th column")?
330 .ok_or_else(|| anyhow::anyhow!("no 0th column?"))?;
331 let level = TransactionIsolationLevel::try_from_sql_server(val)?;
332 Ok(level)
333 }
334 other => Err(SqlServerError::InvariantViolated(format!(
335 "expected one row, got {other:?}"
336 ))),
337 }
338 }
339
340 pub fn cdc<I, M>(&mut self, capture_instances: I, metrics: M) -> crate::cdc::CdcStream<'_, M>
345 where
346 I: IntoIterator,
347 I::Item: Into<Arc<str>>,
348 M: SqlServerCdcMetrics,
349 {
350 let instances = capture_instances
351 .into_iter()
352 .map(|i| (i.into(), None))
353 .collect();
354 crate::cdc::CdcStream::new(self, instances, metrics)
355 }
356}
357
/// A boxed, `Send` stream of rows (or errors) that may borrow data for `'a`.
pub type RowStream<'a> =
    Pin<Box<dyn Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + 'a>>;
361
/// An open transaction on a [`Client`], started with `BEGIN TRANSACTION`.
///
/// The transaction mutably borrows its client, so no other work can be issued on that client
/// until the transaction is finished with `commit`/`rollback` or dropped.
#[derive(Debug)]
pub struct Transaction<'a> {
    client: &'a mut Client,
    /// Set once `commit` or `rollback` has been issued; the `Drop` impl consults this flag.
    closed: bool,
}
367
368impl<'a> Transaction<'a> {
369 async fn new(client: &'a mut Client) -> Result<Self, SqlServerError> {
370 let results = client
371 .simple_query("BEGIN TRANSACTION")
372 .await
373 .context("begin")?;
374 if !results.is_empty() {
375 Err(SqlServerError::InvariantViolated(format!(
376 "expected empty result from BEGIN TRANSACTION. Got: {results:?}"
377 )))
378 } else {
379 Ok(Transaction {
380 client,
381 closed: false,
382 })
383 }
384 }
385
386 pub async fn create_savepoint(&mut self, savepoint_name: &str) -> Result<(), SqlServerError> {
396 if savepoint_name.is_empty()
399 || !savepoint_name
400 .chars()
401 .all(|c| c.is_alphanumeric() || c == '_')
402 {
403 Err(SqlServerError::ProgrammingError(format!(
404 "Invalid savepoint name: '{savepoint_name}"
405 )))?;
406 }
407
408 let stmt = format!("SAVE TRANSACTION [{savepoint_name}]");
409 let _result = self.client.simple_query(stmt).await?;
410 Ok(())
411 }
412
413 pub async fn get_lsn(&mut self) -> Result<Lsn, SqlServerError> {
417 static CURRENT_LSN_QUERY: &str = "SELECT dt.database_transaction_most_recent_savepoint_lsn \
418 FROM sys.dm_tran_database_transactions dt \
419 JOIN sys.dm_tran_current_transaction ct \
420 ON ct.transaction_id = dt.transaction_id \
421 WHERE dt.database_transaction_most_recent_savepoint_lsn IS NOT NULL";
422 let result = self.client.simple_query(CURRENT_LSN_QUERY).await?;
423 crate::inspect::parse_numeric_lsn(&result)
424 }
425
426 pub async fn lock_table_shared(
432 &mut self,
433 schema: &str,
434 table: &str,
435 ) -> Result<(), SqlServerError> {
436 static SET_READ_COMMITTED: &str = "SET TRANSACTION ISOLATION LEVEL READ COMMITTED;";
439 let query = format!(
443 "{SET_READ_COMMITTED}\nSELECT * FROM [{schema}].[{table}] WITH (TABLOCK, HOLDLOCK) WHERE 1=0;"
444 );
445 let _result = self.client.simple_query(query).await?;
446 Ok(())
447 }
448
449 pub async fn execute<'q>(
451 &mut self,
452 query: impl Into<Cow<'q, str>>,
453 params: &[&dyn ToSql],
454 ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
455 self.client.execute(query, params).await
456 }
457
458 pub async fn query<'q>(
460 &mut self,
461 query: impl Into<Cow<'q, str>>,
462 params: &[&dyn tiberius::ToSql],
463 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
464 self.client.query(query, params).await
465 }
466
467 pub fn query_streaming<'c, 'q, Q>(
469 &'c mut self,
470 query: Q,
471 params: &[&dyn tiberius::ToSql],
472 ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
473 where
474 Q: Into<Cow<'q, str>>,
475 {
476 self.client.query_streaming(query, params)
477 }
478
479 pub async fn simple_query<'q>(
481 &mut self,
482 query: impl Into<Cow<'q, str>>,
483 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
484 self.client.simple_query(query).await
485 }
486
487 pub async fn rollback(mut self) -> Result<(), SqlServerError> {
489 static ROLLBACK_QUERY: &str = "ROLLBACK TRANSACTION";
490 self.closed = true;
493 self.client.simple_query(ROLLBACK_QUERY).await?;
494 Ok(())
495 }
496
497 pub async fn commit(mut self) -> Result<(), SqlServerError> {
499 static COMMIT_QUERY: &str = "COMMIT TRANSACTION";
500 self.closed = true;
503 self.client.simple_query(COMMIT_QUERY).await?;
504 Ok(())
505 }
506}
507
508impl Drop for Transaction<'_> {
509 fn drop(&mut self) {
510 if !self.closed {
513 let _fut = self.client.simple_query("ROLLBACK TRANSACTION");
514 }
515 }
516}
517
/// A SQL Server transaction isolation level.
///
/// Used with `Client::set_transaction_isolation` and `Client::get_transaction_isolation`.
/// The enum is fieldless, so it is cheap to copy and can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionIsolationLevel {
    ReadUncommitted,
    ReadCommitted,
    RepeatableRead,
    Snapshot,
    Serializable,
}
529
530impl TransactionIsolationLevel {
531 fn as_str(&self) -> &'static str {
533 match self {
534 TransactionIsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
535 TransactionIsolationLevel::ReadCommitted => "READ COMMITTED",
536 TransactionIsolationLevel::RepeatableRead => "REPEATABLE READ",
537 TransactionIsolationLevel::Snapshot => "SNAPSHOT",
538 TransactionIsolationLevel::Serializable => "SERIALIZABLE",
539 }
540 }
541
542 fn try_from_sql_server(val: i16) -> Result<TransactionIsolationLevel, anyhow::Error> {
544 let level = match val {
545 1 => TransactionIsolationLevel::ReadUncommitted,
546 2 => TransactionIsolationLevel::ReadCommitted,
547 3 => TransactionIsolationLevel::RepeatableRead,
548 4 => TransactionIsolationLevel::Serializable,
549 5 => TransactionIsolationLevel::Snapshot,
550 x => anyhow::bail!("unknown level {x}"),
551 };
552 Ok(level)
553 }
554}
555
/// The reply sent by the [`Connection`] task for a single [`Request`].
#[derive(Derivative)]
#[derivative(Debug)]
enum Response {
    /// Result of [`RequestKind::Execute`].
    Execute {
        rows_affected: SmallVec<[u64; 1]>,
    },
    /// Rows of the single result set produced by [`RequestKind::Query`] or
    /// [`RequestKind::SimpleQuery`].
    Rows(SmallVec<[tiberius::Row; 1]>),
    /// Result of [`RequestKind::QueryStreamed`]: rows are pushed into this channel by the
    /// connection task as they are read.
    RowStream {
        // Omitted from `Debug` output.
        #[derivative(Debug = "ignore")]
        stream: tokio::sync::mpsc::Receiver<Result<tiberius::Row, SqlServerError>>,
    },
}
568
/// A unit of work for the [`Connection`] task, along with where to send its result.
#[derive(Debug)]
struct Request {
    /// Receives the outcome of handling `kind`.
    tx: oneshot::Sender<Result<Response, SqlServerError>>,
    /// What the connection task should run.
    kind: RequestKind,
}
574
/// The kinds of statements a [`Client`] can ask the [`Connection`] task to run.
///
/// Parameters are excluded from `Debug` output (they may be large or sensitive).
#[derive(Derivative)]
#[derivative(Debug)]
enum RequestKind {
    /// A parameterized statement; answered with [`Response::Execute`].
    Execute {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// A parameterized query whose rows are collected; answered with [`Response::Rows`].
    Query {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// A parameterized query whose rows are streamed; answered with [`Response::RowStream`].
    QueryStreamed {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// An unparameterized query; answered with [`Response::Rows`].
    SimpleQuery {
        query: String,
    },
}
597
/// The task side of a [`Client`]: owns the [`tiberius::Client`] and executes requests received
/// over the channel one at a time.
pub struct Connection {
    /// Requests sent by the paired [`Client`].
    rx: UnboundedReceiver<Request>,
    /// The underlying driver connection.
    client: tiberius::Client<Compat<TcpStream>>,
    /// State that must live as long as the connection (e.g. the SSH tunnel handle created in
    /// `Client::connect`); held only to keep it alive.
    _resources: Option<Box<dyn Any + Send + Sync>>,
}
606
607impl Connection {
608 async fn run(mut self) {
609 while let Some(Request { tx, kind }) = self.rx.recv().await {
610 tracing::trace!(?kind, "processing SQL Server query");
611 let result = Connection::handle_request(&mut self.client, kind).await;
612 let (response, maybe_extra_work) = match result {
613 Ok((response, work)) => (Ok(response), work),
614 Err(err) => (Err(err), None),
615 };
616
617 let _ = tx.send(response);
619
620 if let Some(extra_work) = maybe_extra_work {
624 extra_work.await;
625 }
626 }
627 tracing::debug!("channel closed, SQL Server InnerClient shutting down");
628 }
629
630 async fn handle_request<'c>(
631 client: &'c mut tiberius::Client<Compat<TcpStream>>,
632 kind: RequestKind,
633 ) -> Result<(Response, Option<BoxFuture<'c, ()>>), SqlServerError> {
634 match kind {
635 RequestKind::Execute { query, params } => {
636 #[allow(clippy::as_conversions)]
637 let params: SmallVec<[&dyn ToSql; 4]> =
638 params.iter().map(|x| x as &dyn ToSql).collect();
639 let result = client.execute(query, ¶ms[..]).await?;
640
641 match result.rows_affected() {
642 rows_affected => {
643 let response = Response::Execute {
644 rows_affected: rows_affected.into(),
645 };
646 Ok((response, None))
647 }
648 }
649 }
650 RequestKind::Query { query, params } => {
651 #[allow(clippy::as_conversions)]
652 let params: SmallVec<[&dyn ToSql; 4]> =
653 params.iter().map(|x| x as &dyn ToSql).collect();
654 let result = client.query(query, params.as_slice()).await?;
655
656 let mut results = result.into_results().await.context("into results")?;
657 if results.is_empty() {
658 Ok((Response::Rows(smallvec![]), None))
659 } else if results.len() == 1 {
660 let rows = results.pop().expect("checked len").into();
663 Ok((Response::Rows(rows), None))
664 } else {
665 Err(SqlServerError::ProgrammingError(format!(
666 "Query only supports 1 statement, got {}",
667 results.len()
668 )))
669 }
670 }
671 RequestKind::QueryStreamed { query, params } => {
672 #[allow(clippy::as_conversions)]
673 let params: SmallVec<[&dyn ToSql; 4]> =
674 params.iter().map(|x| x as &dyn ToSql).collect();
675 let result = client.query(query, params.as_slice()).await?;
676
677 let (tx, rx) = tokio::sync::mpsc::channel(256);
696 let work = Box::pin(async move {
697 let mut stream = result.into_row_stream();
698 while let Some(result) = stream.next().await {
699 if let Err(err) = tx.send(result.err_into()).await {
700 tracing::warn!(?err, "SQL Server row stream receiver went away");
701 }
702 }
703 tracing::info!("SQL Server row stream complete");
704 });
705
706 Ok((Response::RowStream { stream: rx }, Some(work)))
707 }
708 RequestKind::SimpleQuery { query } => {
709 let result = client.simple_query(query).await?;
710
711 let mut results = result.into_results().await.context("into results")?;
712 if results.is_empty() {
713 Ok((Response::Rows(smallvec![]), None))
714 } else if results.len() == 1 {
715 let rows = results.pop().expect("checked len").into();
718 Ok((Response::Rows(rows), None))
719 } else {
720 Err(SqlServerError::ProgrammingError(format!(
721 "Simple query only supports 1 statement, got {}",
722 results.len()
723 )))
724 }
725 }
726 }
727 }
728}
729
730impl IntoFuture for Connection {
731 type Output = ();
732 type IntoFuture = BoxFuture<'static, Self::Output>;
733
734 fn into_future(self) -> Self::IntoFuture {
735 self.run().boxed()
736 }
737}
738
/// An owned counterpart of [`tiberius::ColumnData`].
///
/// Query parameters are converted into this type so they can be moved across the channel to the
/// [`Connection`] task without borrowing from the caller. Each variant mirrors the
/// `tiberius::ColumnData` variant of the same name.
#[derive(Debug)]
enum OwnedColumnData {
    U8(Option<u8>),
    I16(Option<i16>),
    I32(Option<i32>),
    I64(Option<i64>),
    F32(Option<f32>),
    F64(Option<f64>),
    Bit(Option<bool>),
    String(Option<String>),
    Guid(Option<uuid::Uuid>),
    Binary(Option<Vec<u8>>),
    Numeric(Option<tiberius::numeric::Numeric>),
    Xml(Option<tiberius::xml::XmlData>),
    DateTime(Option<tiberius::time::DateTime>),
    SmallDateTime(Option<tiberius::time::SmallDateTime>),
    Time(Option<tiberius::time::Time>),
    Date(Option<tiberius::time::Date>),
    DateTime2(Option<tiberius::time::DateTime2>),
    DateTimeOffset(Option<tiberius::time::DateTimeOffset>),
}
762
763impl<'a> From<tiberius::ColumnData<'a>> for OwnedColumnData {
764 fn from(value: tiberius::ColumnData<'a>) -> Self {
765 match value {
766 tiberius::ColumnData::U8(inner) => OwnedColumnData::U8(inner),
767 tiberius::ColumnData::I16(inner) => OwnedColumnData::I16(inner),
768 tiberius::ColumnData::I32(inner) => OwnedColumnData::I32(inner),
769 tiberius::ColumnData::I64(inner) => OwnedColumnData::I64(inner),
770 tiberius::ColumnData::F32(inner) => OwnedColumnData::F32(inner),
771 tiberius::ColumnData::F64(inner) => OwnedColumnData::F64(inner),
772 tiberius::ColumnData::Bit(inner) => OwnedColumnData::Bit(inner),
773 tiberius::ColumnData::String(inner) => {
774 OwnedColumnData::String(inner.map(|s| s.to_string()))
775 }
776 tiberius::ColumnData::Guid(inner) => OwnedColumnData::Guid(inner),
777 tiberius::ColumnData::Binary(inner) => {
778 OwnedColumnData::Binary(inner.map(|b| b.to_vec()))
779 }
780 tiberius::ColumnData::Numeric(inner) => OwnedColumnData::Numeric(inner),
781 tiberius::ColumnData::Xml(inner) => OwnedColumnData::Xml(inner.map(|x| x.into_owned())),
782 tiberius::ColumnData::DateTime(inner) => OwnedColumnData::DateTime(inner),
783 tiberius::ColumnData::SmallDateTime(inner) => OwnedColumnData::SmallDateTime(inner),
784 tiberius::ColumnData::Time(inner) => OwnedColumnData::Time(inner),
785 tiberius::ColumnData::Date(inner) => OwnedColumnData::Date(inner),
786 tiberius::ColumnData::DateTime2(inner) => OwnedColumnData::DateTime2(inner),
787 tiberius::ColumnData::DateTimeOffset(inner) => OwnedColumnData::DateTimeOffset(inner),
788 }
789 }
790}
791
792impl tiberius::ToSql for OwnedColumnData {
793 fn to_sql(&self) -> tiberius::ColumnData<'_> {
794 match self {
795 OwnedColumnData::U8(inner) => tiberius::ColumnData::U8(*inner),
796 OwnedColumnData::I16(inner) => tiberius::ColumnData::I16(*inner),
797 OwnedColumnData::I32(inner) => tiberius::ColumnData::I32(*inner),
798 OwnedColumnData::I64(inner) => tiberius::ColumnData::I64(*inner),
799 OwnedColumnData::F32(inner) => tiberius::ColumnData::F32(*inner),
800 OwnedColumnData::F64(inner) => tiberius::ColumnData::F64(*inner),
801 OwnedColumnData::Bit(inner) => tiberius::ColumnData::Bit(*inner),
802 OwnedColumnData::String(inner) => {
803 tiberius::ColumnData::String(inner.as_deref().map(Cow::Borrowed))
804 }
805 OwnedColumnData::Guid(inner) => tiberius::ColumnData::Guid(*inner),
806 OwnedColumnData::Binary(inner) => {
807 tiberius::ColumnData::Binary(inner.as_deref().map(Cow::Borrowed))
808 }
809 OwnedColumnData::Numeric(inner) => tiberius::ColumnData::Numeric(*inner),
810 OwnedColumnData::Xml(inner) => {
811 tiberius::ColumnData::Xml(inner.as_ref().map(Cow::Borrowed))
812 }
813 OwnedColumnData::DateTime(inner) => tiberius::ColumnData::DateTime(*inner),
814 OwnedColumnData::SmallDateTime(inner) => tiberius::ColumnData::SmallDateTime(*inner),
815 OwnedColumnData::Time(inner) => tiberius::ColumnData::Time(*inner),
816 OwnedColumnData::Date(inner) => tiberius::ColumnData::Date(*inner),
817 OwnedColumnData::DateTime2(inner) => tiberius::ColumnData::DateTime2(*inner),
818 OwnedColumnData::DateTimeOffset(inner) => tiberius::ColumnData::DateTimeOffset(*inner),
819 }
820 }
821}
822
823impl<'a, T: tiberius::ToSql> From<&'a T> for OwnedColumnData {
824 fn from(value: &'a T) -> Self {
825 OwnedColumnData::from(value.to_sql())
826 }
827}
828
/// Errors produced by the SQL Server client in this crate.
#[derive(Debug, thiserror::Error)]
pub enum SqlServerError {
    /// An error reported by the `tiberius` driver.
    #[error(transparent)]
    SqlServer(#[from] tiberius::error::Error),
    /// An error from the `cdc` module.
    #[error(transparent)]
    CdcError(#[from] crate::cdc::CdcError),
    /// A column expected in a result row was not present.
    #[error("expected column '{0}' to be present")]
    MissingColumn(&'static str),
    /// An I/O error (e.g. while connecting or resolving the PrivateLink host).
    #[error("sql server client encountered I/O error: {0}")]
    IO(#[from] tokio::io::Error),
    /// A column contained a value that could not be interpreted.
    #[error("found invalid data in the column '{column_name}': {error}")]
    InvalidData { column_name: String, error: String },
    /// The LSN query returned NULL.
    #[error("got back a null value when querying for the LSN")]
    NullLsn,
    /// A server setting did not have the value this client requires.
    #[error("invalid SQL Server system setting '{name}'. Expected '{expected}'. Got '{actual}'.")]
    InvalidSystemSetting {
        name: String,
        expected: String,
        actual: String,
    },
    /// Something the code assumed to hold did not (e.g. an unexpected result shape).
    #[error("invariant was violated: {0}")]
    InvariantViolated(String),
    /// Any other error, usually carrying `anyhow` context.
    #[error(transparent)]
    Generic(#[from] anyhow::Error),
    /// The API was misused, e.g. a request received the wrong kind of response.
    #[error("programming error! {0}")]
    ProgrammingError(String),
    /// The user lacks permission on the listed tables or capture instances.
    #[error(
        "insufficient permissions for tables [{tables}] or capture instances [{capture_instances}]"
    )]
    AuthorizationError {
        tables: String,
        capture_instances: String,
    },
}
863
/// Errors raised while decoding SQL Server column values.
#[derive(Debug, thiserror::Error)]
pub enum SqlServerDecodeError {
    /// The column could not be read as the requested type `as_type`.
    #[error("column '{column_name}' was invalid when getting as type '{as_type}'")]
    InvalidColumn {
        column_name: String,
        as_type: &'static str,
    },
    /// The column was read but its value is not acceptable (e.g. out of range, wrong length).
    #[error("found invalid data in the column '{column_name}': {error}")]
    InvalidData { column_name: String, error: String },
    /// There is no supported conversion from this SQL Server type to the requested type.
    #[error("can't decode {sql_server_type:?} as {mz_type:?}")]
    Unsupported {
        sql_server_type: SqlServerColumnDecodeType,
        mz_type: SqlScalarType,
    },
}
890
891impl SqlServerDecodeError {
892 fn invalid_timestamp(name: &str, error: mz_repr::adt::timestamp::TimestampError) -> Self {
893 let error = match error {
895 mz_repr::adt::timestamp::TimestampError::OutOfRange => "out of range",
896 };
897 SqlServerDecodeError::InvalidData {
898 column_name: name.to_string(),
899 error: error.to_string(),
900 }
901 }
902
903 fn invalid_date(name: &str, error: mz_repr::adt::date::DateError) -> Self {
904 let error = match error {
906 mz_repr::adt::date::DateError::OutOfRange => "out of range",
907 };
908 SqlServerDecodeError::InvalidData {
909 column_name: name.to_string(),
910 error: error.to_string(),
911 }
912 }
913
914 fn invalid_char(name: &str, expected_chars: usize, found_chars: usize) -> Self {
915 SqlServerDecodeError::InvalidData {
916 column_name: name.to_string(),
917 error: format!("expected {expected_chars} chars found {found_chars}"),
918 }
919 }
920
921 fn invalid_varchar(name: &str, max_chars: usize, found_chars: usize) -> Self {
922 SqlServerDecodeError::InvalidData {
923 column_name: name.to_string(),
924 error: format!("expected max {max_chars} chars found {found_chars}"),
925 }
926 }
927
928 fn invalid_column(name: &str, as_type: &'static str) -> Self {
929 SqlServerDecodeError::InvalidColumn {
930 column_name: name.to_string(),
931 as_type,
932 }
933 }
934}
935
/// Quotes `ident` as a SQL Server bracket-delimited identifier.
///
/// The result is wrapped in `[` … `]`, and every `]` inside `ident` is doubled. The output can
/// therefore be spliced into a statement as a single identifier; for example, `a]b` becomes
/// `[a]]b]`.
pub fn quote_identifier(ident: &str) -> String {
    let mut quoted = String::with_capacity(ident.len() + 2);
    quoted.push('[');
    for ch in ident.chars() {
        if ch == ']' {
            // A literal `]` is escaped by doubling it.
            quoted.push(']');
        }
        quoted.push(ch);
    }
    quoted.push(']');
    quoted
}
948
/// Hooks for observing table locking during a CDC snapshot.
///
/// NOTE(review): the call sites live outside this file (presumably in `cdc`) — confirm exactly
/// when each hook fires.
pub trait SqlServerCdcMetrics {
    /// Called when a snapshot lock on `table_name` starts.
    fn snapshot_table_lock_start(&self, table_name: &str);
    /// Called when the snapshot lock on `table_name` ends.
    fn snapshot_table_lock_end(&self, table_name: &str);
}
955
/// A [`SqlServerCdcMetrics`] implementation that only emits `tracing` log events.
pub struct LoggingSqlServerCdcMetrics;
959
960impl SqlServerCdcMetrics for LoggingSqlServerCdcMetrics {
961 fn snapshot_table_lock_start(&self, table_name: &str) {
962 tracing::info!("snapshot_table_lock_start: {table_name}");
963 }
964
965 fn snapshot_table_lock_end(&self, table_name: &str) {
966 tracing::info!("snapshot_table_lock_end: {table_name}");
967 }
968}
969
#[cfg(test)]
mod test {
    use super::*;

    #[mz_ore::test]
    fn test_sql_server_escaping() {
        // (raw identifier, expected bracket-quoted form)
        let cases = [
            ("", "[]"),
            ("]", "[]]]"),
            ("a", "[a]"),
            ("cost(]\u{00A3})", "[cost(]]\u{00A3})]"),
            ("[g[o[o][", "[[g[o[o]][]"),
        ];
        for (input, expected) in cases {
            assert_eq!(expected, quote_identifier(input), "input: {input:?}");
        }
    }
}