1use std::any::Any;
11use std::borrow::Cow;
12use std::future::IntoFuture;
13use std::pin::Pin;
14use std::sync::Arc;
15
16use anyhow::Context;
17use derivative::Derivative;
18use futures::future::BoxFuture;
19use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
20use mz_ore::netio::DUMMY_DNS_PORT;
21use mz_ore::result::ResultExt;
22use mz_repr::SqlScalarType;
23use smallvec::{SmallVec, smallvec};
24use tiberius::ToSql;
25use tokio::net::TcpStream;
26use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
27use tokio::sync::oneshot;
28use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
29
30pub mod cdc;
31pub mod config;
32pub mod desc;
33pub mod inspect;
34
35pub use config::Config;
36pub use desc::{ProtoSqlServerColumnDesc, ProtoSqlServerTableDesc};
37
38use crate::cdc::Lsn;
39use crate::config::TunnelConfig;
40use crate::desc::SqlServerColumnDecodeType;
41
42#[derive(Debug)]
45pub struct Client {
46 tx: UnboundedSender<Request>,
47 config: Config,
49}
50static_assertions::assert_not_impl_all!(Client: Clone);
53
54impl Client {
55 pub async fn connect(config: Config) -> Result<Self, SqlServerError> {
64 let (tcp, resources): (_, Option<Box<dyn Any + Send + Sync>>) = match &config.tunnel {
67 TunnelConfig::Direct => {
68 let tcp = TcpStream::connect(config.inner.get_addr())
69 .await
70 .context("direct")?;
71 (tcp, None)
72 }
73 TunnelConfig::Ssh {
74 config: ssh_config,
75 manager,
76 timeout,
77 host,
78 port,
79 } => {
80 let tunnel = manager
83 .connect(ssh_config.clone(), host, *port, *timeout, config.in_task)
84 .await?;
85 let tcp = TcpStream::connect(tunnel.local_addr())
86 .await
87 .context("ssh tunnel")?;
88
89 (tcp, Some(Box::new(tunnel)))
90 }
91 TunnelConfig::AwsPrivatelink {
92 connection_id,
93 port,
94 } => {
95 let privatelink_host = mz_cloud_resources::vpc_endpoint_name(*connection_id);
96 let mut privatelink_addrs =
97 tokio::net::lookup_host((privatelink_host.clone(), DUMMY_DNS_PORT)).await?;
98
99 let Some(mut addr) = privatelink_addrs.next() else {
100 return Err(SqlServerError::InvariantViolated(format!(
101 "aws privatelink: no addresses found for host {:?}",
102 privatelink_host
103 )));
104 };
105
106 addr.set_port(port.clone());
107
108 let tcp = TcpStream::connect(addr)
109 .await
110 .context(format!("aws privatelink {:?}", addr))?;
111
112 (tcp, None)
113 }
114 };
115
116 tcp.set_nodelay(true)?;
117
118 let (client, connection) = Self::connect_raw(config, tcp, resources).await?;
119 mz_ore::task::spawn(|| "sql-server-client-connection", async move {
120 connection.await
121 });
122
123 Ok(client)
124 }
125
126 pub async fn new_connection(&self) -> Result<Self, SqlServerError> {
129 Self::connect(self.config.clone()).await
130 }
131
132 pub async fn connect_raw(
133 config: Config,
134 tcp: tokio::net::TcpStream,
135 resources: Option<Box<dyn Any + Send + Sync>>,
136 ) -> Result<(Self, Connection), SqlServerError> {
137 let client = tiberius::Client::connect(config.inner.clone(), tcp.compat_write()).await?;
138 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
139
140 Ok((
143 Client { tx, config },
144 Connection {
145 rx,
146 client,
147 _resources: resources,
148 },
149 ))
150 }
151
152 pub async fn execute<'a>(
161 &mut self,
162 query: impl Into<Cow<'a, str>>,
163 params: &[&dyn ToSql],
164 ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
165 let (tx, rx) = tokio::sync::oneshot::channel();
166
167 let params = params
168 .iter()
169 .map(|p| OwnedColumnData::from(p.to_sql()))
170 .collect();
171 let kind = RequestKind::Execute {
172 query: query.into().to_string(),
173 params,
174 };
175 self.tx
176 .send(Request { tx, kind })
177 .context("sending request")?;
178
179 let response = rx.await.context("channel")??;
180 match response {
181 Response::Execute { rows_affected } => Ok(rows_affected),
182 other @ Response::Rows(_) | other @ Response::RowStream { .. } => {
183 Err(SqlServerError::ProgrammingError(format!(
184 "expected Response::Execute, got {other:?}"
185 )))
186 }
187 }
188 }
189
190 pub async fn query<'a>(
199 &mut self,
200 query: impl Into<Cow<'a, str>>,
201 params: &[&dyn tiberius::ToSql],
202 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
203 let (tx, rx) = tokio::sync::oneshot::channel();
204
205 let params = params
206 .iter()
207 .map(|p| OwnedColumnData::from(p.to_sql()))
208 .collect();
209 let kind = RequestKind::Query {
210 query: query.into().to_string(),
211 params,
212 };
213 self.tx
214 .send(Request { tx, kind })
215 .context("sending request")?;
216
217 let response = rx.await.context("channel")??;
218 match response {
219 Response::Rows(rows) => Ok(rows),
220 other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
221 SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
222 ),
223 }
224 }
225
226 pub fn query_streaming<'c, 'q, Q>(
231 &'c mut self,
232 query: Q,
233 params: &[&dyn tiberius::ToSql],
234 ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
235 where
236 Q: Into<Cow<'q, str>>,
237 {
238 let (tx, rx) = tokio::sync::oneshot::channel();
239 let params = params
240 .iter()
241 .map(|p| OwnedColumnData::from(p.to_sql()))
242 .collect();
243 let kind = RequestKind::QueryStreamed {
244 query: query.into().to_string(),
245 params,
246 };
247
248 let request_future = async move {
250 self.tx
251 .send(Request { tx, kind })
252 .context("sending request")?;
253
254 let response = rx.await.context("channel")??;
255 match response {
256 Response::RowStream { stream } => {
257 Ok(tokio_stream::wrappers::ReceiverStream::new(stream))
258 }
259 other @ Response::Execute { .. } | other @ Response::Rows(_) => {
260 Err(SqlServerError::ProgrammingError(format!(
261 "expected Response::Rows, got {other:?}"
262 )))
263 }
264 }
265 };
266
267 futures::stream::once(request_future).try_flatten()
269 }
270
271 pub async fn simple_query<'a>(
281 &mut self,
282 query: impl Into<Cow<'a, str>>,
283 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
284 let (tx, rx) = tokio::sync::oneshot::channel();
285 let kind = RequestKind::SimpleQuery {
286 query: query.into().to_string(),
287 };
288 self.tx
289 .send(Request { tx, kind })
290 .context("sending request")?;
291
292 let response = rx.await.context("channel")??;
293 match response {
294 Response::Rows(rows) => Ok(rows),
295 other @ Response::Execute { .. } | other @ Response::RowStream { .. } => Err(
296 SqlServerError::ProgrammingError(format!("expected Response::Rows, got {other:?}")),
297 ),
298 }
299 }
300
301 pub async fn transaction(&mut self) -> Result<Transaction<'_>, SqlServerError> {
306 Transaction::new(self).await
307 }
308
309 pub async fn set_transaction_isolation(
311 &mut self,
312 level: TransactionIsolationLevel,
313 ) -> Result<(), SqlServerError> {
314 let query = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_str());
315 self.simple_query(query).await?;
316 Ok(())
317 }
318
319 pub async fn get_transaction_isolation(
321 &mut self,
322 ) -> Result<TransactionIsolationLevel, SqlServerError> {
323 const QUERY: &str = "SELECT transaction_isolation_level FROM sys.dm_exec_sessions where session_id = @@SPID;";
324 let rows = self.simple_query(QUERY).await?;
325 match &rows[..] {
326 [row] => {
327 let val: i16 = row
328 .try_get(0)
329 .context("getting 0th column")?
330 .ok_or_else(|| anyhow::anyhow!("no 0th column?"))?;
331 let level = TransactionIsolationLevel::try_from_sql_server(val)?;
332 Ok(level)
333 }
334 other => Err(SqlServerError::InvariantViolated(format!(
335 "expected one row, got {other:?}"
336 ))),
337 }
338 }
339
340 pub fn cdc<I>(&mut self, capture_instances: I) -> crate::cdc::CdcStream<'_>
345 where
346 I: IntoIterator,
347 I::Item: Into<Arc<str>>,
348 {
349 let instances = capture_instances
350 .into_iter()
351 .map(|i| (i.into(), None))
352 .collect();
353 crate::cdc::CdcStream::new(self, instances)
354 }
355}
356
357pub type RowStream<'a> =
359 Pin<Box<dyn Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + 'a>>;
360
361#[derive(Debug)]
362pub struct Transaction<'a> {
363 client: &'a mut Client,
364 closed: bool,
365}
366
367impl<'a> Transaction<'a> {
368 async fn new(client: &'a mut Client) -> Result<Self, SqlServerError> {
369 let results = client
370 .simple_query("BEGIN TRANSACTION")
371 .await
372 .context("begin")?;
373 if !results.is_empty() {
374 Err(SqlServerError::InvariantViolated(format!(
375 "expected empty result from BEGIN TRANSACTION. Got: {results:?}"
376 )))
377 } else {
378 Ok(Transaction {
379 client,
380 closed: false,
381 })
382 }
383 }
384
385 pub async fn create_savepoint(&mut self, savepoint_name: &str) -> Result<(), SqlServerError> {
395 if savepoint_name.is_empty()
398 || !savepoint_name
399 .chars()
400 .all(|c| c.is_alphanumeric() || c == '_')
401 {
402 Err(SqlServerError::ProgrammingError(format!(
403 "Invalid savepoint name: '{savepoint_name}"
404 )))?;
405 }
406
407 let stmt = format!("SAVE TRANSACTION [{savepoint_name}]");
408 let _result = self.client.simple_query(stmt).await?;
409 Ok(())
410 }
411
412 pub async fn get_lsn(&mut self) -> Result<Lsn, SqlServerError> {
416 static CURRENT_LSN_QUERY: &str = "SELECT dt.database_transaction_most_recent_savepoint_lsn \
417 FROM sys.dm_tran_database_transactions dt \
418 JOIN sys.dm_tran_current_transaction ct \
419 ON ct.transaction_id = dt.transaction_id \
420 WHERE dt.database_transaction_most_recent_savepoint_lsn IS NOT NULL";
421 let result = self.client.simple_query(CURRENT_LSN_QUERY).await?;
422 crate::inspect::parse_numeric_lsn(&result)
423 }
424
425 pub async fn lock_table_shared(
431 &mut self,
432 schema: &str,
433 table: &str,
434 ) -> Result<(), SqlServerError> {
435 static SET_READ_COMMITTED: &str = "SET TRANSACTION ISOLATION LEVEL READ COMMITTED;";
438 let query = format!(
442 "{SET_READ_COMMITTED}\nSELECT * FROM [{schema}].[{table}] WITH (TABLOCK, HOLDLOCK) WHERE 1=0;"
443 );
444 let _result = self.client.simple_query(query).await?;
445 Ok(())
446 }
447
448 pub async fn execute<'q>(
450 &mut self,
451 query: impl Into<Cow<'q, str>>,
452 params: &[&dyn ToSql],
453 ) -> Result<SmallVec<[u64; 1]>, SqlServerError> {
454 self.client.execute(query, params).await
455 }
456
457 pub async fn query<'q>(
459 &mut self,
460 query: impl Into<Cow<'q, str>>,
461 params: &[&dyn tiberius::ToSql],
462 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
463 self.client.query(query, params).await
464 }
465
466 pub fn query_streaming<'c, 'q, Q>(
468 &'c mut self,
469 query: Q,
470 params: &[&dyn tiberius::ToSql],
471 ) -> impl Stream<Item = Result<tiberius::Row, SqlServerError>> + Send + use<'c, Q>
472 where
473 Q: Into<Cow<'q, str>>,
474 {
475 self.client.query_streaming(query, params)
476 }
477
478 pub async fn simple_query<'q>(
480 &mut self,
481 query: impl Into<Cow<'q, str>>,
482 ) -> Result<SmallVec<[tiberius::Row; 1]>, SqlServerError> {
483 self.client.simple_query(query).await
484 }
485
486 pub async fn rollback(mut self) -> Result<(), SqlServerError> {
488 static ROLLBACK_QUERY: &str = "ROLLBACK TRANSACTION";
489 self.closed = true;
492 self.client.simple_query(ROLLBACK_QUERY).await?;
493 Ok(())
494 }
495
496 pub async fn commit(mut self) -> Result<(), SqlServerError> {
498 static COMMIT_QUERY: &str = "COMMIT TRANSACTION";
499 self.closed = true;
502 self.client.simple_query(COMMIT_QUERY).await?;
503 Ok(())
504 }
505}
506
507impl Drop for Transaction<'_> {
508 fn drop(&mut self) {
509 if !self.closed {
512 let _fut = self.client.simple_query("ROLLBACK TRANSACTION");
513 }
514 }
515}
516
/// Transaction isolation levels that can be set on, or read back from, a SQL Server session.
#[derive(Debug, PartialEq, Eq)]
pub enum TransactionIsolationLevel {
    /// `READ UNCOMMITTED`
    ReadUncommitted,
    /// `READ COMMITTED`
    ReadCommitted,
    /// `REPEATABLE READ`
    RepeatableRead,
    /// `SNAPSHOT`
    Snapshot,
    /// `SERIALIZABLE`
    Serializable,
}
528
529impl TransactionIsolationLevel {
530 fn as_str(&self) -> &'static str {
532 match self {
533 TransactionIsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
534 TransactionIsolationLevel::ReadCommitted => "READ COMMITTED",
535 TransactionIsolationLevel::RepeatableRead => "REPEATABLE READ",
536 TransactionIsolationLevel::Snapshot => "SNAPSHOT",
537 TransactionIsolationLevel::Serializable => "SERIALIZABLE",
538 }
539 }
540
541 fn try_from_sql_server(val: i16) -> Result<TransactionIsolationLevel, anyhow::Error> {
543 let level = match val {
544 1 => TransactionIsolationLevel::ReadUncommitted,
545 2 => TransactionIsolationLevel::ReadCommitted,
546 3 => TransactionIsolationLevel::RepeatableRead,
547 4 => TransactionIsolationLevel::Serializable,
548 5 => TransactionIsolationLevel::Snapshot,
549 x => anyhow::bail!("unknown level {x}"),
550 };
551 Ok(level)
552 }
553}
554
/// Reply sent by the [`Connection`] task for a single [`Request`].
#[derive(Derivative)]
#[derivative(Debug)]
enum Response {
    /// Reply to [`RequestKind::Execute`]: the values from `tiberius`' `rows_affected()`.
    Execute {
        rows_affected: SmallVec<[u64; 1]>,
    },
    /// Reply to [`RequestKind::Query`] or [`RequestKind::SimpleQuery`]: the rows of the single
    /// result set, or empty if there was none.
    Rows(SmallVec<[tiberius::Row; 1]>),
    /// Reply to [`RequestKind::QueryStreamed`]: receiving end of a bounded channel that the
    /// connection task fills with rows after sending this response.
    RowStream {
        // Omitted from `Debug` output.
        #[derivative(Debug = "ignore")]
        stream: tokio::sync::mpsc::Receiver<Result<tiberius::Row, SqlServerError>>,
    },
}
567
568#[derive(Debug)]
569struct Request {
570 tx: oneshot::Sender<Result<Response, SqlServerError>>,
571 kind: RequestKind,
572}
573
/// The operations a [`Connection`] task knows how to perform.
///
/// Parameters are excluded from `Debug` output. The connection task logs `kind` with `Debug` at
/// trace level. NOTE(review): presumably the exclusion keeps parameter values out of logs —
/// confirm.
#[derive(Derivative)]
#[derivative(Debug)]
enum RequestKind {
    /// Executes a statement; answered with [`Response::Execute`].
    Execute {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Runs a parameterized query and buffers its rows; answered with [`Response::Rows`].
    Query {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Runs a parameterized query and streams its rows; answered with [`Response::RowStream`].
    QueryStreamed {
        query: String,
        #[derivative(Debug = "ignore")]
        params: SmallVec<[OwnedColumnData; 4]>,
    },
    /// Runs an unparameterized query and buffers its rows; answered with [`Response::Rows`].
    SimpleQuery {
        query: String,
    },
}
596
597pub struct Connection {
598 rx: UnboundedReceiver<Request>,
600 client: tiberius::Client<Compat<TcpStream>>,
602 _resources: Option<Box<dyn Any + Send + Sync>>,
604}
605
606impl Connection {
607 async fn run(mut self) {
608 while let Some(Request { tx, kind }) = self.rx.recv().await {
609 tracing::trace!(?kind, "processing SQL Server query");
610 let result = Connection::handle_request(&mut self.client, kind).await;
611 let (response, maybe_extra_work) = match result {
612 Ok((response, work)) => (Ok(response), work),
613 Err(err) => (Err(err), None),
614 };
615
616 let _ = tx.send(response);
618
619 if let Some(extra_work) = maybe_extra_work {
623 extra_work.await;
624 }
625 }
626 tracing::debug!("channel closed, SQL Server InnerClient shutting down");
627 }
628
629 async fn handle_request<'c>(
630 client: &'c mut tiberius::Client<Compat<TcpStream>>,
631 kind: RequestKind,
632 ) -> Result<(Response, Option<BoxFuture<'c, ()>>), SqlServerError> {
633 match kind {
634 RequestKind::Execute { query, params } => {
635 #[allow(clippy::as_conversions)]
636 let params: SmallVec<[&dyn ToSql; 4]> =
637 params.iter().map(|x| x as &dyn ToSql).collect();
638 let result = client.execute(query, ¶ms[..]).await?;
639
640 match result.rows_affected() {
641 rows_affected => {
642 let response = Response::Execute {
643 rows_affected: rows_affected.into(),
644 };
645 Ok((response, None))
646 }
647 }
648 }
649 RequestKind::Query { query, params } => {
650 #[allow(clippy::as_conversions)]
651 let params: SmallVec<[&dyn ToSql; 4]> =
652 params.iter().map(|x| x as &dyn ToSql).collect();
653 let result = client.query(query, params.as_slice()).await?;
654
655 let mut results = result.into_results().await.context("into results")?;
656 if results.is_empty() {
657 Ok((Response::Rows(smallvec![]), None))
658 } else if results.len() == 1 {
659 let rows = results.pop().expect("checked len").into();
662 Ok((Response::Rows(rows), None))
663 } else {
664 Err(SqlServerError::ProgrammingError(format!(
665 "Query only supports 1 statement, got {}",
666 results.len()
667 )))
668 }
669 }
670 RequestKind::QueryStreamed { query, params } => {
671 #[allow(clippy::as_conversions)]
672 let params: SmallVec<[&dyn ToSql; 4]> =
673 params.iter().map(|x| x as &dyn ToSql).collect();
674 let result = client.query(query, params.as_slice()).await?;
675
676 let (tx, rx) = tokio::sync::mpsc::channel(256);
695 let work = Box::pin(async move {
696 let mut stream = result.into_row_stream();
697 while let Some(result) = stream.next().await {
698 if let Err(err) = tx.send(result.err_into()).await {
699 tracing::warn!(?err, "SQL Server row stream receiver went away");
700 }
701 }
702 tracing::info!("SQL Server row stream complete");
703 });
704
705 Ok((Response::RowStream { stream: rx }, Some(work)))
706 }
707 RequestKind::SimpleQuery { query } => {
708 let result = client.simple_query(query).await?;
709
710 let mut results = result.into_results().await.context("into results")?;
711 if results.is_empty() {
712 Ok((Response::Rows(smallvec![]), None))
713 } else if results.len() == 1 {
714 let rows = results.pop().expect("checked len").into();
717 Ok((Response::Rows(rows), None))
718 } else {
719 Err(SqlServerError::ProgrammingError(format!(
720 "Simple query only supports 1 statement, got {}",
721 results.len()
722 )))
723 }
724 }
725 }
726 }
727}
728
729impl IntoFuture for Connection {
730 type Output = ();
731 type IntoFuture = BoxFuture<'static, Self::Output>;
732
733 fn into_future(self) -> Self::IntoFuture {
734 self.run().boxed()
735 }
736}
737
738#[derive(Debug)]
741enum OwnedColumnData {
742 U8(Option<u8>),
743 I16(Option<i16>),
744 I32(Option<i32>),
745 I64(Option<i64>),
746 F32(Option<f32>),
747 F64(Option<f64>),
748 Bit(Option<bool>),
749 String(Option<String>),
750 Guid(Option<uuid::Uuid>),
751 Binary(Option<Vec<u8>>),
752 Numeric(Option<tiberius::numeric::Numeric>),
753 Xml(Option<tiberius::xml::XmlData>),
754 DateTime(Option<tiberius::time::DateTime>),
755 SmallDateTime(Option<tiberius::time::SmallDateTime>),
756 Time(Option<tiberius::time::Time>),
757 Date(Option<tiberius::time::Date>),
758 DateTime2(Option<tiberius::time::DateTime2>),
759 DateTimeOffset(Option<tiberius::time::DateTimeOffset>),
760}
761
762impl<'a> From<tiberius::ColumnData<'a>> for OwnedColumnData {
763 fn from(value: tiberius::ColumnData<'a>) -> Self {
764 match value {
765 tiberius::ColumnData::U8(inner) => OwnedColumnData::U8(inner),
766 tiberius::ColumnData::I16(inner) => OwnedColumnData::I16(inner),
767 tiberius::ColumnData::I32(inner) => OwnedColumnData::I32(inner),
768 tiberius::ColumnData::I64(inner) => OwnedColumnData::I64(inner),
769 tiberius::ColumnData::F32(inner) => OwnedColumnData::F32(inner),
770 tiberius::ColumnData::F64(inner) => OwnedColumnData::F64(inner),
771 tiberius::ColumnData::Bit(inner) => OwnedColumnData::Bit(inner),
772 tiberius::ColumnData::String(inner) => {
773 OwnedColumnData::String(inner.map(|s| s.to_string()))
774 }
775 tiberius::ColumnData::Guid(inner) => OwnedColumnData::Guid(inner),
776 tiberius::ColumnData::Binary(inner) => {
777 OwnedColumnData::Binary(inner.map(|b| b.to_vec()))
778 }
779 tiberius::ColumnData::Numeric(inner) => OwnedColumnData::Numeric(inner),
780 tiberius::ColumnData::Xml(inner) => OwnedColumnData::Xml(inner.map(|x| x.into_owned())),
781 tiberius::ColumnData::DateTime(inner) => OwnedColumnData::DateTime(inner),
782 tiberius::ColumnData::SmallDateTime(inner) => OwnedColumnData::SmallDateTime(inner),
783 tiberius::ColumnData::Time(inner) => OwnedColumnData::Time(inner),
784 tiberius::ColumnData::Date(inner) => OwnedColumnData::Date(inner),
785 tiberius::ColumnData::DateTime2(inner) => OwnedColumnData::DateTime2(inner),
786 tiberius::ColumnData::DateTimeOffset(inner) => OwnedColumnData::DateTimeOffset(inner),
787 }
788 }
789}
790
791impl tiberius::ToSql for OwnedColumnData {
792 fn to_sql(&self) -> tiberius::ColumnData<'_> {
793 match self {
794 OwnedColumnData::U8(inner) => tiberius::ColumnData::U8(*inner),
795 OwnedColumnData::I16(inner) => tiberius::ColumnData::I16(*inner),
796 OwnedColumnData::I32(inner) => tiberius::ColumnData::I32(*inner),
797 OwnedColumnData::I64(inner) => tiberius::ColumnData::I64(*inner),
798 OwnedColumnData::F32(inner) => tiberius::ColumnData::F32(*inner),
799 OwnedColumnData::F64(inner) => tiberius::ColumnData::F64(*inner),
800 OwnedColumnData::Bit(inner) => tiberius::ColumnData::Bit(*inner),
801 OwnedColumnData::String(inner) => {
802 tiberius::ColumnData::String(inner.as_deref().map(Cow::Borrowed))
803 }
804 OwnedColumnData::Guid(inner) => tiberius::ColumnData::Guid(*inner),
805 OwnedColumnData::Binary(inner) => {
806 tiberius::ColumnData::Binary(inner.as_deref().map(Cow::Borrowed))
807 }
808 OwnedColumnData::Numeric(inner) => tiberius::ColumnData::Numeric(*inner),
809 OwnedColumnData::Xml(inner) => {
810 tiberius::ColumnData::Xml(inner.as_ref().map(Cow::Borrowed))
811 }
812 OwnedColumnData::DateTime(inner) => tiberius::ColumnData::DateTime(*inner),
813 OwnedColumnData::SmallDateTime(inner) => tiberius::ColumnData::SmallDateTime(*inner),
814 OwnedColumnData::Time(inner) => tiberius::ColumnData::Time(*inner),
815 OwnedColumnData::Date(inner) => tiberius::ColumnData::Date(*inner),
816 OwnedColumnData::DateTime2(inner) => tiberius::ColumnData::DateTime2(*inner),
817 OwnedColumnData::DateTimeOffset(inner) => tiberius::ColumnData::DateTimeOffset(*inner),
818 }
819 }
820}
821
822impl<'a, T: tiberius::ToSql> From<&'a T> for OwnedColumnData {
823 fn from(value: &'a T) -> Self {
824 OwnedColumnData::from(value.to_sql())
825 }
826}
827
/// Errors produced by the SQL Server client.
#[derive(Debug, thiserror::Error)]
pub enum SqlServerError {
    /// Error reported by the underlying `tiberius` client.
    #[error(transparent)]
    SqlServer(#[from] tiberius::error::Error),
    /// Error from the `cdc` module.
    #[error(transparent)]
    CdcError(#[from] crate::cdc::CdcError),
    /// A column that was expected in a result was not present.
    #[error("expected column '{0}' to be present")]
    MissingColumn(&'static str),
    /// I/O failure, e.g. while connecting the TCP stream or resolving a PrivateLink host.
    #[error("sql server client encountered I/O error: {0}")]
    IO(#[from] tokio::io::Error),
    /// A column held data that could not be interpreted.
    #[error("found invalid data in the column '{column_name}': {error}")]
    InvalidData { column_name: String, error: String },
    /// The LSN query returned NULL.
    #[error("got back a null value when querying for the LSN")]
    NullLsn,
    /// A server setting did not have the expected value.
    #[error("invalid SQL Server system setting '{name}'. Expected '{expected}'. Got '{actual}'.")]
    InvalidSystemSetting {
        name: String,
        expected: String,
        actual: String,
    },
    /// An assumption about server behavior did not hold (e.g. an unexpected number of rows).
    #[error("invariant was violated: {0}")]
    InvariantViolated(String),
    /// Catch-all for `anyhow` errors, e.g. those created with `.context(..)`.
    #[error(transparent)]
    Generic(#[from] anyhow::Error),
    /// Misuse of this API, e.g. an invalid savepoint name or an unexpected response kind.
    #[error("programming error! {0}")]
    ProgrammingError(String),
    /// The user lacks permissions on some tables or capture instances.
    #[error(
        "insufficient permissions for tables [{tables}] or capture instances [{capture_instances}]"
    )]
    AuthorizationError {
        tables: String,
        capture_instances: String,
    },
}
862
/// Errors produced while decoding SQL Server column values.
#[derive(Debug, thiserror::Error)]
pub enum SqlServerDecodeError {
    /// The column could not be read as the requested `as_type`.
    #[error("column '{column_name}' was invalid when getting as type '{as_type}'")]
    InvalidColumn {
        column_name: String,
        as_type: &'static str,
    },
    /// The column's value was present but not valid (e.g. out of range, wrong length).
    #[error("found invalid data in the column '{column_name}': {error}")]
    InvalidData { column_name: String, error: String },
    /// No decoding exists from this SQL Server column type to this Materialize scalar type.
    #[error("can't decode {sql_server_type:?} as {mz_type:?}")]
    Unsupported {
        sql_server_type: SqlServerColumnDecodeType,
        mz_type: SqlScalarType,
    },
}
889
890impl SqlServerDecodeError {
891 fn invalid_timestamp(name: &str, error: mz_repr::adt::timestamp::TimestampError) -> Self {
892 let error = match error {
894 mz_repr::adt::timestamp::TimestampError::OutOfRange => "out of range",
895 };
896 SqlServerDecodeError::InvalidData {
897 column_name: name.to_string(),
898 error: error.to_string(),
899 }
900 }
901
902 fn invalid_date(name: &str, error: mz_repr::adt::date::DateError) -> Self {
903 let error = match error {
905 mz_repr::adt::date::DateError::OutOfRange => "out of range",
906 };
907 SqlServerDecodeError::InvalidData {
908 column_name: name.to_string(),
909 error: error.to_string(),
910 }
911 }
912
913 fn invalid_char(name: &str, expected_chars: usize, found_chars: usize) -> Self {
914 SqlServerDecodeError::InvalidData {
915 column_name: name.to_string(),
916 error: format!("expected {expected_chars} chars found {found_chars}"),
917 }
918 }
919
920 fn invalid_varchar(name: &str, max_chars: usize, found_chars: usize) -> Self {
921 SqlServerDecodeError::InvalidData {
922 column_name: name.to_string(),
923 error: format!("expected max {max_chars} chars found {found_chars}"),
924 }
925 }
926
927 fn invalid_column(name: &str, as_type: &'static str) -> Self {
928 SqlServerDecodeError::InvalidColumn {
929 column_name: name.to_string(),
930 as_type,
931 }
932 }
933}