1use crate::client::errors::ConnectionError;
39use crate::config::{Profile, SslMode};
40use crate::info;
41use mz_postgres_util::Sql;
42use std::collections::BTreeMap;
43use tokio_postgres::types::ToSql;
44use tokio_postgres::{Client as PgClient, NoTls, Row, SimpleQueryMessage, Transaction};
45
/// A connection to the `materialize` database opened from a [`Profile`].
///
/// Owns the `tokio_postgres` client handle together with the profile used to
/// open it. Constructed by [`Client::connect_with_profile`]; the underlying
/// connection is driven by a background task spawned at connect time. Related
/// operations are reached through borrowed accessor structs such as the one
/// returned by [`Client::deployments`].
pub struct Client {
    /// Underlying `tokio_postgres` client handle.
    client: PgClient,
    /// Profile this client was connected with (exposed via [`Client::profile`]).
    profile: Profile,
}
58
/// Shared borrow of a [`Client`], returned by [`Client::deployments`].
pub struct DeploymentsClient<'a> {
    /// The connection the accessor's methods run against.
    pub(crate) client: &'a Client,
}
63
/// Exclusive (`&mut`) borrow of a [`Client`], returned by
/// [`Client::deployments_mut`].
///
/// Unlike the other accessor structs this holds `&mut Client`, presumably so
/// its methods can call `&mut self` APIs such as `Client::begin_transaction`.
pub struct DeploymentsClientMut<'a> {
    /// The connection the accessor's methods run against.
    pub(crate) client: &'a mut Client,
}
68
/// Shared borrow of a [`Client`], returned by [`Client::introspection`].
pub struct IntrospectionClient<'a> {
    /// The connection the accessor's methods run against.
    pub(crate) client: &'a Client,
}
73
/// Shared borrow of a [`Client`], returned by [`Client::validation`].
pub struct ValidationClient<'a> {
    /// The connection the accessor's methods run against.
    pub(crate) client: &'a Client,
}
78
/// Shared borrow of a [`Client`], returned by [`Client::types`].
pub struct TypeInfoClient<'a> {
    /// The connection the accessor's methods run against.
    pub(crate) client: &'a Client,
}
83
/// Shared borrow of a [`Client`], returned by [`Client::provisioning`].
pub struct ProvisioningClient<'a> {
    /// The connection the accessor's methods run against.
    pub(crate) client: &'a Client,
}
88
/// Shared borrow of a [`Client`], returned by [`Client::dev_overlays`].
pub struct DevOverlaysClient<'a> {
    /// The connection the accessor's methods run against.
    pub(crate) client: &'a Client,
}
93
/// Sent to the server as the session's `application_name` startup parameter
/// for every connection opened by `Client::connect_with_profile_inner`.
const APPLICATION_NAME: &str = "mz-deploy";
95
96impl Client {
97 pub async fn connect_with_profile(profile: Profile) -> Result<Self, ConnectionError> {
109 Self::connect_with_profile_inner(profile, true).await
110 }
111
112 pub(crate) async fn connect_with_profile_no_pin(
123 profile: Profile,
124 ) -> Result<Self, ConnectionError> {
125 Self::connect_with_profile_inner(profile, false).await
126 }
127
128 async fn connect_with_profile_inner(
129 profile: Profile,
130 pin_server_cluster: bool,
131 ) -> Result<Self, ConnectionError> {
132 let host = profile.require_host()?;
133 let mut config = tokio_postgres::Config::new();
134 config.host(host);
135 config.port(profile.port);
136 config.user(&profile.username);
137 config.dbname("materialize");
138 if let Some(password) = &profile.password {
139 config.password(password.as_str());
140 }
141 config.application_name(APPLICATION_NAME);
142
143 let mut effective_options = profile.options.clone();
144 if pin_server_cluster {
145 effective_options.insert(
146 "cluster".to_string(),
147 crate::client::SERVER_CLUSTER_NAME.to_string(),
148 );
149 }
150 if let Some(inner) = build_options_string(&effective_options) {
151 config.options(&inner);
152 }
153
154 let mode = profile.sslmode.unwrap_or_else(|| default_sslmode(host));
155 let hunt: Vec<&std::path::Path> =
156 DEFAULT_CA_PATHS.iter().map(std::path::Path::new).collect();
157 let spec = plan_connector(mode, profile.sslrootcert.as_deref(), host, &hunt, |p| {
158 p.exists()
159 })?;
160 let connector = build_connector(spec)?;
161
162 config.ssl_mode(tokio_ssl_mode(mode));
163
164 type BoxConnection =
168 Box<dyn Future<Output = Result<(), tokio_postgres::Error>> + Send + Unpin>;
169 let (client, connection): (PgClient, BoxConnection) = match connector {
170 Connector::NoTls => {
171 let (client, connection) = config
172 .connect(NoTls)
173 .await
174 .map_err(|source| classify_connect_error(source, &profile, mode))?;
175 (client, Box::new(connection))
176 }
177 Connector::Tls(tls) => {
178 let (client, connection) = config
179 .connect(tls)
180 .await
181 .map_err(|source| classify_connect_error(source, &profile, mode))?;
182 (client, Box::new(connection))
183 }
184 };
185
186 mz_ore::task::spawn(|| "mz-deploy-connection", async move {
187 if let Err(e) = connection.await {
188 info!("connection error: {}", e);
189 }
190 });
191
192 Ok(Client { client, profile })
193 }
194
195 pub fn profile(&self) -> &Profile {
197 &self.profile
198 }
199
200 pub(crate) async fn begin_transaction(&mut self) -> Result<Transaction<'_>, ConnectionError> {
202 self.client
203 .transaction()
204 .await
205 .map_err(ConnectionError::Query)
206 }
207
208 pub fn deployments(&self) -> DeploymentsClient<'_> {
210 DeploymentsClient { client: self }
211 }
212
213 pub fn deployments_mut(&mut self) -> DeploymentsClientMut<'_> {
215 DeploymentsClientMut { client: self }
216 }
217
218 pub fn introspection(&self) -> IntrospectionClient<'_> {
220 IntrospectionClient { client: self }
221 }
222
223 pub fn validation(&self) -> ValidationClient<'_> {
225 ValidationClient { client: self }
226 }
227
228 pub fn types(&self) -> TypeInfoClient<'_> {
230 TypeInfoClient { client: self }
231 }
232
233 pub fn provisioning(&self) -> ProvisioningClient<'_> {
235 ProvisioningClient { client: self }
236 }
237
238 pub fn dev_overlays(&self) -> DevOverlaysClient<'_> {
240 DevOverlaysClient { client: self }
241 }
242
243 pub async fn execute(
245 &self,
246 statement: &str,
247 params: &[&(dyn ToSql + Sync)],
248 ) -> Result<u64, ConnectionError> {
249 mz_postgres_util::execute(
250 &self.client,
251 Sql::raw_unchecked(statement.to_string()),
252 params,
253 )
254 .await
255 .map_err(ConnectionError::from)
256 }
257
258 pub async fn query_one(
260 &self,
261 statement: &str,
262 params: &[&(dyn ToSql + Sync)],
263 ) -> Result<Row, ConnectionError> {
264 mz_postgres_util::query_one(
265 &self.client,
266 Sql::raw_unchecked(statement.to_string()),
267 params,
268 )
269 .await
270 .map_err(ConnectionError::from)
271 }
272
273 pub async fn query(
275 &self,
276 statement: &str,
277 params: &[&(dyn ToSql + Sync)],
278 ) -> Result<Vec<Row>, ConnectionError> {
279 mz_postgres_util::query(
280 &self.client,
281 Sql::raw_unchecked(statement.to_string()),
282 params,
283 )
284 .await
285 .map_err(ConnectionError::from)
286 }
287
288 pub async fn simple_query(
290 &self,
291 query: &str,
292 ) -> Result<Vec<SimpleQueryMessage>, ConnectionError> {
293 mz_postgres_util::simple_query(&self.client, Sql::raw_unchecked(query.to_string()))
294 .await
295 .map_err(ConnectionError::from)
296 }
297
298 pub async fn batch_execute(&self, query: &str) -> Result<(), ConnectionError> {
300 mz_postgres_util::batch_execute(&self.client, Sql::raw_unchecked(query.to_string()))
301 .await
302 .map_err(ConnectionError::from)
303 }
304}
305
/// Well-known CA bundle locations probed, in order, for the verifying
/// `sslmode`s when no explicit `sslrootcert` is configured.
///
/// The first path that exists is used (see `resolve_ca_source`). If none
/// exists, OpenSSL's compiled-in default verify paths are used instead.
const DEFAULT_CA_PATHS: &[&str] = &[
    // System bundle commonly found on macOS (and some Linux distributions).
    "/etc/ssl/cert.pem",
    // Homebrew OpenSSL 3 (Apple Silicon prefix, then Intel prefix).
    "/opt/homebrew/etc/openssl@3/cert.pem",
    "/usr/local/etc/openssl@3/cert.pem",
    // Homebrew unversioned OpenSSL (same two prefixes).
    "/opt/homebrew/etc/openssl/cert.pem",
    "/usr/local/etc/openssl/cert.pem",
    // Debian / Ubuntu style bundle.
    "/etc/ssl/certs/ca-certificates.crt",
    // Fedora / RHEL style bundle.
    "/etc/pki/tls/certs/ca-bundle.crt",
    // openSUSE style bundle.
    "/etc/ssl/ca-bundle.pem",
];
320
321pub(crate) fn default_sslmode(host: &str) -> SslMode {
328 if is_loopback_host(host) {
329 SslMode::Prefer
330 } else {
331 SslMode::Require
332 }
333}
334
/// Reports whether `host` names the local machine.
///
/// Accepts `host` in either of these forms:
/// - the name `localhost`, compared ASCII case-insensitively because DNS names
///   are case-insensitive, so `LocalHost` counts too;
/// - any IPv4 or IPv6 loopback literal (`127.0.0.0/8`, `::1`).
///
/// An IPv6 literal may be wrapped in brackets as in URLs (`[::1]`). Any other
/// input returns `false`, and no DNS lookup is performed.
pub(crate) fn is_loopback_host(host: &str) -> bool {
    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }
    // Only strip brackets when both are present; a lone `[` is left in place
    // so the literal fails to parse.
    let unbracketed = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    matches!(unbracketed.parse::<std::net::IpAddr>(), Ok(ip) if ip.is_loopback())
}
353
354fn tokio_ssl_mode(mode: SslMode) -> tokio_postgres::config::SslMode {
355 use tokio_postgres::config::SslMode as TokioMode;
356 match mode {
357 SslMode::Disable => TokioMode::Disable,
358 SslMode::Prefer => TokioMode::Prefer,
359 SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => TokioMode::Require,
360 }
361}
362
/// Identity the server certificate must match; only planned for `verify-full`.
#[derive(Debug)]
enum HostCheck {
    /// A DNS name, applied with `X509VerifyParam::set_host` in `build_connector`.
    Dns(String),
    /// An IP literal, applied with `X509VerifyParam::set_ip` in `build_connector`.
    Ip(std::net::IpAddr),
}
371
/// TLS plan produced by `plan_connector` and consumed by `build_connector`.
///
/// Holds only plain data (no OpenSSL context), so the planning step can be
/// unit-tested without building one (see the `plan_*` tests).
#[derive(Debug)]
enum ConnectorSpec {
    /// Connect without TLS (`sslmode=disable`).
    NoTls,
    /// Connect over TLS.
    Tls {
        /// Peer verification mode: `NONE` for `prefer`/`require`, `PEER` for
        /// `verify-ca`/`verify-full`.
        verify: openssl::ssl::SslVerifyMode,
        /// Expected server identity; `Some` only for `verify-full`.
        host_check: Option<HostCheck>,
        /// Where trusted CA certificates are loaded from.
        ca_source: CaSource,
    },
}
383
/// Where `build_connector` loads trusted CA certificates from.
#[derive(Debug)]
enum CaSource {
    /// Load nothing; used for the non-verifying modes (`prefer`/`require`).
    None,
    /// A CA file named explicitly via the profile's `sslrootcert`.
    Explicit(std::path::PathBuf),
    /// The first existing entry of the hunt list (`DEFAULT_CA_PATHS` in
    /// production).
    Hunted(std::path::PathBuf),
    /// No file was found; fall back to OpenSSL's default verify paths
    /// (`SslContextBuilder::set_default_verify_paths`).
    DefaultVerifyPaths,
}
399
/// The connector actually handed to `tokio_postgres::Config::connect`.
///
/// `Debug` is implemented by hand, printing `Tls(...)` without the
/// connector's contents (presumably because `MakeTlsConnector` offers no
/// useful `Debug` output).
enum Connector {
    /// Plain TCP, no TLS.
    NoTls,
    /// TLS via OpenSSL.
    Tls(postgres_openssl::MakeTlsConnector),
}
405
406impl std::fmt::Debug for Connector {
407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 match self {
409 Connector::NoTls => write!(f, "Connector::NoTls"),
410 Connector::Tls(_) => write!(f, "Connector::Tls(...)"),
411 }
412 }
413}
414
415fn plan_connector(
425 mode: SslMode,
426 sslrootcert: Option<&std::path::Path>,
427 host: &str,
428 hunt_candidates: &[&std::path::Path],
429 ca_exists: impl Fn(&std::path::Path) -> bool,
430) -> Result<ConnectorSpec, ConnectionError> {
431 use openssl::ssl::SslVerifyMode;
432
433 match mode {
434 SslMode::Disable => Ok(ConnectorSpec::NoTls),
435 SslMode::Prefer | SslMode::Require => Ok(ConnectorSpec::Tls {
436 verify: SslVerifyMode::NONE,
437 host_check: None,
438 ca_source: CaSource::None,
439 }),
440 SslMode::VerifyCa | SslMode::VerifyFull => {
441 let ca_source = resolve_ca_source(sslrootcert, hunt_candidates, ca_exists)?;
442 let host_check = if matches!(mode, SslMode::VerifyFull) {
443 Some(match host.parse::<std::net::IpAddr>() {
444 Ok(ip) => HostCheck::Ip(ip),
445 Err(_) => HostCheck::Dns(host.to_string()),
446 })
447 } else {
448 None
449 };
450 Ok(ConnectorSpec::Tls {
451 verify: SslVerifyMode::PEER,
452 host_check,
453 ca_source,
454 })
455 }
456 }
457}
458
459fn resolve_ca_source(
460 explicit: Option<&std::path::Path>,
461 hunt_candidates: &[&std::path::Path],
462 ca_exists: impl Fn(&std::path::Path) -> bool,
463) -> Result<CaSource, ConnectionError> {
464 if let Some(path) = explicit {
465 if ca_exists(path) {
466 return Ok(CaSource::Explicit(path.to_path_buf()));
467 } else {
468 return Err(ConnectionError::TlsCaNotFound);
469 }
470 }
471 for candidate in hunt_candidates {
472 if ca_exists(candidate) {
473 return Ok(CaSource::Hunted(candidate.to_path_buf()));
474 }
475 }
476 Ok(CaSource::DefaultVerifyPaths)
477}
478
479fn build_connector(spec: ConnectorSpec) -> Result<Connector, ConnectionError> {
482 use openssl::ssl::{SslConnector, SslMethod};
483
484 match spec {
485 ConnectorSpec::NoTls => Ok(Connector::NoTls),
486 ConnectorSpec::Tls {
487 verify,
488 host_check,
489 ca_source,
490 } => {
491 let mut builder = SslConnector::builder(SslMethod::tls()).map_err(|e| {
492 ConnectionError::Message(format!("Failed to create TLS builder: {}", e))
493 })?;
494
495 match ca_source {
496 CaSource::None => {}
497 CaSource::Explicit(path) | CaSource::Hunted(path) => {
498 builder
499 .set_ca_file(&path)
500 .map_err(|_| ConnectionError::TlsCaNotFound)?;
501 }
502 CaSource::DefaultVerifyPaths => {
503 builder
504 .set_default_verify_paths()
505 .map_err(|_| ConnectionError::TlsCaNotFound)?;
506 }
507 }
508
509 builder.set_verify(verify);
510
511 if let Some(check) = host_check {
512 let param = builder.verify_param_mut();
513 match check {
514 HostCheck::Dns(name) => {
515 param
516 .set_host(&name)
517 .map_err(|e| ConnectionError::Message(format!("{}", e)))?;
518 }
519 HostCheck::Ip(ip) => {
520 param
521 .set_ip(ip)
522 .map_err(|e| ConnectionError::Message(format!("{}", e)))?;
523 }
524 }
525 }
526
527 Ok(Connector::Tls(postgres_openssl::MakeTlsConnector::new(
528 builder.build(),
529 )))
530 }
531 }
532}
533
534fn classify_connect_error(
545 source: tokio_postgres::Error,
546 profile: &Profile,
547 mode: SslMode,
548) -> ConnectionError {
549 let host = profile.host.clone().unwrap_or_default();
552 if matches!(mode, SslMode::VerifyCa | SslMode::VerifyFull) {
553 if let Some(ssl_msg) = ssl_error_in_chain(&source) {
554 let hostname_suffix = if ssl_msg.contains("hostname mismatch")
555 || ssl_msg.contains("Hostname mismatch")
556 || ssl_msg.contains("IP address mismatch")
557 {
558 " (hostname mismatch)"
559 } else {
560 ""
561 };
562 return ConnectionError::TlsVerification {
563 host,
564 port: profile.port,
565 hostname_suffix,
566 source,
567 };
568 }
569 }
570
571 if matches!(
572 mode,
573 SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull
574 ) && message_indicates_tls_refused(&source)
575 {
576 return ConnectionError::TlsRequiredNotSupported {
577 host,
578 port: profile.port,
579 source,
580 };
581 }
582
583 ConnectionError::Connect {
584 host,
585 port: profile.port,
586 source,
587 }
588}
589
590fn ssl_error_in_chain(err: &tokio_postgres::Error) -> Option<String> {
593 let mut cur: &(dyn std::error::Error + 'static) = err;
594 while let Some(source) = std::error::Error::source(cur) {
595 if source.is::<openssl::error::ErrorStack>() {
596 return Some(source.to_string());
597 }
598 cur = source;
599 }
600 None
601}
602
603fn message_indicates_tls_refused(err: &tokio_postgres::Error) -> bool {
609 matches_tls_refused_message(&err.to_string())
610}
611
/// Returns `true` if `msg` contains one of the (case-sensitive) phrases used to
/// report that the server refused or does not support TLS.
fn matches_tls_refused_message(msg: &str) -> bool {
    const REFUSAL_PHRASES: [&str; 3] = [
        "TLS was required",
        "server does not support TLS",
        "server does not support SSL",
    ];
    REFUSAL_PHRASES.iter().any(|phrase| msg.contains(*phrase))
}
623
/// Escapes `value` for use inside a space-separated `options` string.
///
/// Each backslash and each space is prefixed with a backslash, so the value is
/// not split at its spaces. Every other character, including other whitespace,
/// is copied unchanged.
fn escape_options_value(value: &str) -> String {
    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut escaped, ch| {
            if ch == '\\' || ch == ' ' {
                escaped.push('\\');
            }
            escaped.push(ch);
            escaped
        })
}
641
642pub(crate) fn build_options_string(options: &BTreeMap<String, String>) -> Option<String> {
649 if options.is_empty() {
650 return None;
651 }
652 let joined = options
653 .iter()
654 .map(|(k, v)| format!("-c {k}={}", escape_options_value(v)))
655 .collect::<Vec<_>>()
656 .join(" ");
657 Some(joined)
658}
659
#[cfg(test)]
mod tests {
    // Unit tests for option-string escaping, TLS planning/building, loopback
    // detection, and TLS-refusal message matching. None of them open a
    // network connection; the OpenSSL-backed ones are skipped under miri.
    use super::*;
    use std::path::Path;

    #[mz_ore::test]
    fn test_escape_options_value_plain() {
        assert_eq!(escape_options_value("prod"), "prod");
    }

    #[mz_ore::test]
    fn test_escape_options_value_space() {
        assert_eq!(escape_options_value("prod cluster"), r"prod\ cluster");
    }

    #[mz_ore::test]
    fn test_escape_options_value_backslash() {
        assert_eq!(escape_options_value(r"a\b"), r"a\\b");
    }

    #[mz_ore::test]
    fn test_escape_options_value_mixed() {
        assert_eq!(escape_options_value(r"a \b"), r"a\ \\b");
    }

    #[mz_ore::test]
    fn test_build_options_string_empty() {
        let options: BTreeMap<String, String> = BTreeMap::new();
        assert_eq!(build_options_string(&options), None);
    }

    #[mz_ore::test]
    fn test_build_options_string_single() {
        let mut options = BTreeMap::new();
        options.insert("cluster".to_string(), "prod".to_string());
        assert_eq!(
            build_options_string(&options),
            Some("-c cluster=prod".to_string())
        );
    }

    #[mz_ore::test]
    fn test_build_options_string_multiple_sorted() {
        let mut options = BTreeMap::new();
        options.insert("search_path".to_string(), "public".to_string());
        options.insert("cluster".to_string(), "prod".to_string());
        assert_eq!(
            build_options_string(&options),
            Some("-c cluster=prod -c search_path=public".to_string())
        );
    }

    #[mz_ore::test]
    fn test_build_options_string_escapes_value_space() {
        let mut options = BTreeMap::new();
        options.insert("cluster".to_string(), "prod cluster".to_string());
        assert_eq!(
            build_options_string(&options),
            Some(r"-c cluster=prod\ cluster".to_string())
        );
    }

    #[mz_ore::test]
    fn test_build_options_string_escapes_value_backslash() {
        let mut options = BTreeMap::new();
        options.insert("cluster".to_string(), r"a\b".to_string());
        assert_eq!(
            build_options_string(&options),
            Some(r"-c cluster=a\\b".to_string())
        );
    }

    #[mz_ore::test]
    fn loopback_host_accepts_localhost_and_loopback_literals() {
        assert!(is_loopback_host("localhost"));
        assert!(is_loopback_host("127.0.0.1"));
        assert!(is_loopback_host("127.1.2.3"));
        assert!(is_loopback_host("::1"));
        assert!(is_loopback_host("[::1]"));
    }

    #[mz_ore::test]
    fn loopback_host_rejects_other_hosts() {
        assert!(!is_loopback_host("example.com"));
        assert!(!is_loopback_host("10.0.0.5"));
        assert!(!is_loopback_host("[10.0.0.5"));
        assert!(!is_loopback_host(""));
    }

    #[mz_ore::test]
    fn default_sslmode_prefers_only_for_loopback() {
        assert!(matches!(default_sslmode("localhost"), SslMode::Prefer));
        assert!(matches!(default_sslmode("127.0.0.1"), SslMode::Prefer));
        assert!(matches!(default_sslmode("db.example.com"), SslMode::Require));
    }

    #[mz_ore::test]
    fn tokio_ssl_mode_maps_verify_modes_to_require() {
        use tokio_postgres::config::SslMode as TokioMode;
        assert!(matches!(tokio_ssl_mode(SslMode::Disable), TokioMode::Disable));
        assert!(matches!(tokio_ssl_mode(SslMode::Prefer), TokioMode::Prefer));
        assert!(matches!(tokio_ssl_mode(SslMode::Require), TokioMode::Require));
        assert!(matches!(tokio_ssl_mode(SslMode::VerifyCa), TokioMode::Require));
        assert!(matches!(tokio_ssl_mode(SslMode::VerifyFull), TokioMode::Require));
    }

    #[mz_ore::test]
    fn plan_disable_produces_notls() {
        let spec = plan_connector(SslMode::Disable, None, "example.com", &[], |_| false).unwrap();
        assert!(matches!(spec, ConnectorSpec::NoTls));
    }

    #[mz_ore::test]
    fn plan_prefer_and_require_have_verify_none_and_no_ca() {
        for mode in [SslMode::Prefer, SslMode::Require] {
            let spec = plan_connector(mode, None, "example.com", &[], |_| true).unwrap();
            match spec {
                ConnectorSpec::Tls {
                    verify,
                    host_check,
                    ca_source,
                } => {
                    assert_eq!(verify, openssl::ssl::SslVerifyMode::NONE);
                    assert!(host_check.is_none());
                    assert!(matches!(ca_source, CaSource::None));
                }
                ConnectorSpec::NoTls => panic!("expected Tls for {:?}, got NoTls", mode),
            }
        }
    }

    #[mz_ore::test]
    fn plan_verify_ca_has_peer_verify_no_host_check() {
        let spec = plan_connector(
            SslMode::VerifyCa,
            None,
            "example.com",
            &[Path::new("/does/not/exist"), Path::new("/tmp/fake-ca.pem")],
            |p| p == Path::new("/tmp/fake-ca.pem"),
        )
        .unwrap();
        match spec {
            ConnectorSpec::Tls {
                verify,
                host_check,
                ca_source,
            } => {
                assert_eq!(verify, openssl::ssl::SslVerifyMode::PEER);
                assert!(host_check.is_none());
                assert!(
                    matches!(ca_source, CaSource::Hunted(p) if p == Path::new("/tmp/fake-ca.pem"))
                );
            }
            ConnectorSpec::NoTls => panic!("expected Tls, got NoTls"),
        }
    }

    #[mz_ore::test]
    fn plan_hunt_picks_first_existing_candidate() {
        let spec = plan_connector(
            SslMode::VerifyCa,
            None,
            "example.com",
            &[Path::new("/first.pem"), Path::new("/second.pem")],
            |_| true,
        )
        .unwrap();
        match spec {
            ConnectorSpec::Tls {
                ca_source: CaSource::Hunted(p),
                ..
            } => assert_eq!(p, Path::new("/first.pem")),
            other => panic!("expected Tls/Hunted, got {:?}", other),
        }
    }

    #[mz_ore::test]
    fn plan_verify_full_dns_host_check() {
        let spec = plan_connector(
            SslMode::VerifyFull,
            None,
            "example.com",
            &[Path::new("/tmp/fake-ca.pem")],
            |_| true,
        )
        .unwrap();
        match spec {
            ConnectorSpec::Tls {
                host_check: Some(HostCheck::Dns(ref name)),
                ..
            } => assert_eq!(name, "example.com"),
            other => panic!("expected Tls with Dns host check, got {:?}", other),
        }
    }

    #[mz_ore::test]
    fn plan_verify_full_ip_host_check() {
        let spec = plan_connector(
            SslMode::VerifyFull,
            None,
            "10.0.0.5",
            &[Path::new("/tmp/fake-ca.pem")],
            |_| true,
        )
        .unwrap();
        match spec {
            ConnectorSpec::Tls {
                host_check: Some(HostCheck::Ip(ip)),
                ..
            } => assert_eq!(ip, "10.0.0.5".parse::<std::net::IpAddr>().unwrap()),
            other => panic!("expected Tls with Ip host check, got {:?}", other),
        }
    }

    #[mz_ore::test]
    fn plan_explicit_sslrootcert_wins_over_hunt() {
        let explicit = std::path::PathBuf::from("/my/ca.pem");
        let spec = plan_connector(
            SslMode::VerifyCa,
            Some(&explicit),
            "example.com",
            &[Path::new("/tmp/should-be-ignored.pem")],
            |p| p == explicit.as_path(),
        )
        .unwrap();
        match spec {
            ConnectorSpec::Tls {
                ca_source: CaSource::Explicit(p),
                ..
            } => assert_eq!(p, explicit),
            other => panic!("expected Tls/Explicit, got {:?}", other),
        }
    }

    #[mz_ore::test]
    fn plan_explicit_sslrootcert_missing_is_ca_not_found() {
        let explicit = std::path::PathBuf::from("/no/such/file.pem");
        let err = plan_connector(
            SslMode::VerifyCa,
            Some(&explicit),
            "example.com",
            &[Path::new("/tmp/fake-ca.pem")],
            |_| false,
        )
        .unwrap_err();
        assert!(matches!(err, ConnectionError::TlsCaNotFound));
    }

    #[mz_ore::test]
    fn plan_no_ca_sources_at_all_falls_back_to_default_verify_paths() {
        let spec = plan_connector(
            SslMode::VerifyFull,
            None,
            "example.com",
            &[Path::new("/nope1"), Path::new("/nope2")],
            |_| false,
        )
        .unwrap();
        match spec {
            ConnectorSpec::Tls {
                ca_source: CaSource::DefaultVerifyPaths,
                ..
            } => {}
            other => panic!("expected Tls/DefaultVerifyPaths, got {:?}", other),
        }
    }

    #[mz_ore::test]
    fn build_disable_returns_notls() {
        let connector = build_connector(ConnectorSpec::NoTls).unwrap();
        assert!(matches!(connector, Connector::NoTls));
    }

    #[cfg_attr(miri, ignore)]
    #[mz_ore::test]
    fn build_prefer_returns_tls_no_ca_work() {
        let connector = build_connector(ConnectorSpec::Tls {
            verify: openssl::ssl::SslVerifyMode::NONE,
            host_check: None,
            ca_source: CaSource::None,
        })
        .unwrap();
        assert!(matches!(connector, Connector::Tls(_)));
    }

    #[cfg_attr(miri, ignore)]
    #[mz_ore::test]
    fn build_verify_modes_without_ca_file_return_tls() {
        for host_check in [
            None,
            Some(HostCheck::Dns("example.com".to_string())),
            Some(HostCheck::Ip("10.0.0.5".parse().unwrap())),
        ] {
            let connector = build_connector(ConnectorSpec::Tls {
                verify: openssl::ssl::SslVerifyMode::PEER,
                host_check,
                ca_source: CaSource::None,
            })
            .unwrap();
            assert!(matches!(connector, Connector::Tls(_)));
        }
    }

    #[cfg_attr(miri, ignore)]
    #[mz_ore::test]
    fn build_explicit_missing_ca_returns_ca_not_found() {
        let err = build_connector(ConnectorSpec::Tls {
            verify: openssl::ssl::SslVerifyMode::PEER,
            host_check: None,
            ca_source: CaSource::Explicit(std::path::PathBuf::from("/absolutely/not/a/real/file")),
        })
        .unwrap_err();
        assert!(matches!(err, ConnectionError::TlsCaNotFound));
    }

    #[mz_ore::test]
    fn matches_tls_refused_tls_was_required() {
        assert!(matches_tls_refused_message(
            "some prefix: TLS was required but not provided"
        ));
    }

    #[mz_ore::test]
    fn matches_tls_refused_does_not_support_tls() {
        assert!(matches_tls_refused_message(
            "error: server does not support TLS"
        ));
    }

    #[mz_ore::test]
    fn matches_tls_refused_does_not_support_ssl() {
        assert!(matches_tls_refused_message(
            "error: server does not support SSL"
        ));
    }

    #[mz_ore::test]
    fn matches_tls_refused_unrelated_message() {
        assert!(!matches_tls_refused_message("connection refused"));
        assert!(!matches_tls_refused_message("database does not exist"));
        assert!(!matches_tls_refused_message(""));
    }
}