1mod codec;
20mod dyncfgs;
21
22use std::collections::BTreeMap;
23use std::net::SocketAddr;
24use std::path::PathBuf;
25use std::pin::Pin;
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use anyhow::Context;
31use axum::response::IntoResponse;
32use axum::{Router, routing};
33use bytes::BytesMut;
34use domain::base::{Name, Rtype};
35use domain::rdata::AllRecordData;
36use domain::resolv::StubResolver;
37use futures::TryFutureExt;
38use futures::stream::BoxStream;
39use hyper::StatusCode;
40use hyper_util::rt::TokioIo;
41use launchdarkly_server_sdk as ld;
42use mz_build_info::{BuildInfo, build_info};
43use mz_dyncfg::ConfigSet;
44use mz_frontegg_auth::Authenticator as FronteggAuthentication;
45use mz_ore::cast::CastFrom;
46use mz_ore::id_gen::conn_id_org_uuid;
47use mz_ore::metrics::{ComputedGauge, IntCounter, IntGauge, MetricsRegistry};
48use mz_ore::netio::AsyncReady;
49use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
50use mz_ore::task::{JoinSetExt, spawn};
51use mz_ore::tracing::TracingHandle;
52use mz_ore::{metric, netio};
53use mz_pgwire_common::{
54 ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ErrorResponse, FrontendMessage,
55 FrontendStartupMessage, MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, VERSION_3, decode_startup,
56};
57use mz_server_core::{
58 Connection, ConnectionStream, ListenerHandle, ReloadTrigger, ReloadingSslContext,
59 ReloadingTlsConfig, ServeConfig, ServeDyncfg, TlsCertConfig, TlsMode, listen,
60};
61use openssl::ssl::{NameType, Ssl, SslConnector, SslMethod, SslVerifyMode};
62use prometheus::{IntCounterVec, IntGaugeVec};
63use proxy_header::{ProxiedAddress, ProxyHeader};
64use semver::Version;
65use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
66use tokio::net::TcpStream;
67use tokio::sync::oneshot;
68use tokio::task::JoinSet;
69use tokio_metrics::TaskMetrics;
70use tokio_openssl::SslStream;
71use tokio_postgres::error::SqlState;
72use tower::Service;
73use tracing::{debug, error, warn};
74use uuid::Uuid;
75
76use crate::codec::{BackendMessage, FramedConn};
77use crate::dyncfgs::{
78 INJECT_PROXY_PROTOCOL_HEADER_HTTP, SIGTERM_CONNECTION_WAIT, SIGTERM_LISTEN_WAIT,
79 has_tracing_config_update, tracing_config,
80};
81
/// Build metadata for this crate, captured at compile time by `build_info!`.
///
/// Used for the startup banner and passed to the LaunchDarkly config sync.
pub const BUILD_INFO: BuildInfo = build_info!();
84
/// Configuration consumed by [`BalancerService`]; build it with [`BalancerConfig::new`].
pub struct BalancerConfig {
    /// Semantic version extracted from the `BuildInfo` handed to `new`.
    build_version: Version,
    /// Address the internal HTTP listener (metrics, liveness, readiness) binds to.
    internal_http_listen_addr: SocketAddr,
    /// Address the pgwire listener binds to.
    pgwire_listen_addr: SocketAddr,
    /// Address the HTTPS listener binds to.
    https_listen_addr: SocketAddr,
    /// How pgwire cancel requests locate the backends to forward to.
    cancellation_resolver: CancellationResolver,
    /// How a pgwire connection is mapped to a backend address.
    resolver: Resolver,
    /// Backend address template in `host:port` form. Any `{}` in the host part is replaced
    /// with the server name derived from the client's SNI.
    https_addr_template: String,
    /// Certificate configuration for client-facing TLS; `None` disables TLS.
    tls: Option<TlsCertConfig>,
    /// Whether connections from the balancer to the backend use TLS.
    internal_tls: bool,
    /// Registry that all balancer metrics are registered into.
    metrics_registry: MetricsRegistry,
    /// Stream handed to `TlsCertConfig::reloading_context` when TLS is configured.
    reload_certs: BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>,
    /// LaunchDarkly SDK key; must not be set together with `config_sync_file_path`.
    launchdarkly_sdk_key: Option<String>,
    /// File to sync dynamic configs from; must not be set together with
    /// `launchdarkly_sdk_key`.
    config_sync_file_path: Option<PathBuf>,
    /// Timeout forwarded to the config sync routine.
    config_sync_timeout: Duration,
    /// Loop interval forwarded to the config sync routine.
    config_sync_loop_interval: Option<Duration>,
    /// Cloud provider name, reported in the LaunchDarkly context ("unknown" when absent).
    cloud_provider: Option<String>,
    /// Cloud provider region, reported in the LaunchDarkly context ("unknown" when absent).
    cloud_provider_region: Option<String>,
    /// Handle that dynamically synced tracing parameters are applied to.
    tracing_handle: TracingHandle,
    /// Name/value pairs passed to `dyncfgs::set_defaults` before syncing starts.
    default_configs: Vec<(String, String)>,
}
112
113impl BalancerConfig {
114 pub fn new(
115 build_info: &BuildInfo,
116 internal_http_listen_addr: SocketAddr,
117 pgwire_listen_addr: SocketAddr,
118 https_listen_addr: SocketAddr,
119 cancellation_resolver: CancellationResolver,
120 resolver: Resolver,
121 https_addr_template: String,
122 tls: Option<TlsCertConfig>,
123 internal_tls: bool,
124 metrics_registry: MetricsRegistry,
125 reload_certs: ReloadTrigger,
126 launchdarkly_sdk_key: Option<String>,
127 config_sync_file: Option<PathBuf>,
128 config_sync_timeout: Duration,
129 config_sync_loop_interval: Option<Duration>,
130 cloud_provider: Option<String>,
131 cloud_provider_region: Option<String>,
132 tracing_handle: TracingHandle,
133 default_configs: Vec<(String, String)>,
134 ) -> Self {
135 Self {
136 build_version: build_info.semver_version(),
137 internal_http_listen_addr,
138 pgwire_listen_addr,
139 https_listen_addr,
140 cancellation_resolver,
141 resolver,
142 https_addr_template,
143 tls,
144 internal_tls,
145 metrics_registry,
146 reload_certs,
147 launchdarkly_sdk_key,
148 config_sync_file_path: config_sync_file,
149 config_sync_timeout,
150 config_sync_loop_interval,
151 cloud_provider,
152 cloud_provider_region,
153 tracing_handle,
154 default_configs,
155 }
156 }
157}
158
/// Process-level balancer metrics.
///
/// The gauge is never read through this struct; the field only holds the handle returned
/// by registration.
#[derive(Debug)]
pub struct BalancerMetrics {
    /// Computed gauge reporting seconds elapsed since `BalancerMetrics::new` was called.
    _uptime: ComputedGauge,
}
164
165impl BalancerMetrics {
166 pub fn new(cfg: &BalancerConfig) -> Self {
168 let start = Instant::now();
169 let uptime = cfg.metrics_registry.register_computed_gauge(
170 metric!(
171 name: "mz_balancer_metadata_seconds",
172 help: "server uptime, labels are build metadata",
173 const_labels: {
174 "version" => cfg.build_version,
175 "build_type" => if cfg!(release) { "release" } else { "debug" }
176 },
177 ),
178 move || start.elapsed().as_secs_f64(),
179 );
180 BalancerMetrics { _uptime: uptime }
181 }
182}
183
/// The balancer with all listeners bound, ready to be driven by `serve`.
pub struct BalancerService {
    /// Configuration the service was created from.
    cfg: BalancerConfig,
    /// Handle and incoming-connection stream of the pgwire listener.
    pub pgwire: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Handle and incoming-connection stream of the HTTPS listener.
    pub https: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Handle and incoming-connection stream of the internal HTTP listener.
    pub internal_http: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Keeps the process-level metrics registered for the lifetime of the service.
    _metrics: BalancerMetrics,
    /// Dynamic configuration values, optionally kept in sync with LaunchDarkly or a file.
    configs: ConfigSet,
}
192
193impl BalancerService {
194 pub async fn new(cfg: BalancerConfig) -> Result<Self, anyhow::Error> {
195 let pgwire = listen(&cfg.pgwire_listen_addr).await?;
196 let https = listen(&cfg.https_listen_addr).await?;
197 let internal_http = listen(&cfg.internal_http_listen_addr).await?;
198 let metrics = BalancerMetrics::new(&cfg);
199 let mut configs = ConfigSet::default();
200 configs = dyncfgs::all_dyncfgs(configs);
201 dyncfgs::set_defaults(&configs, cfg.default_configs.clone())?;
202 let tracing_handle = cfg.tracing_handle.clone();
203 match (
205 cfg.launchdarkly_sdk_key.as_deref(),
206 cfg.config_sync_file_path.as_deref(),
207 ) {
208 (Some(key), None) => {
209 mz_dyncfg_launchdarkly::sync_launchdarkly_to_configset(
210 configs.clone(),
211 &BUILD_INFO,
212 |builder| {
213 let region = cfg
214 .cloud_provider_region
215 .clone()
216 .unwrap_or_else(|| String::from("unknown"));
217 if let Some(provider) = cfg.cloud_provider.clone() {
218 builder.add_context(
219 ld::ContextBuilder::new(format!(
220 "{}/{}/{}",
221 provider, region, cfg.build_version
222 ))
223 .kind("balancer")
224 .set_string("provider", provider)
225 .set_string("region", region)
226 .set_string("version", cfg.build_version.to_string())
227 .build()
228 .map_err(|e| anyhow::anyhow!(e))?,
229 );
230 } else {
231 builder.add_context(
232 ld::ContextBuilder::new(format!(
233 "{}/{}/{}",
234 "unknown", region, cfg.build_version
235 ))
236 .anonymous(true) .kind("balancer")
238 .set_string("provider", "unknown")
239 .set_string("region", region)
240 .set_string("version", cfg.build_version.to_string())
241 .build()
242 .map_err(|e| anyhow::anyhow!(e))?,
243 );
244 }
245 Ok(())
246 },
247 Some(key),
248 cfg.config_sync_timeout,
249 cfg.config_sync_loop_interval,
250 move |updates, configs| {
251 if has_tracing_config_update(updates) {
252 match tracing_config(configs) {
253 Ok(parameters) => parameters.apply(&tracing_handle),
254 Err(err) => warn!("unable to update tracing: {err}"),
255 }
256 }
257 },
258 )
259 .await
260 .inspect_err(|e| warn!("LaunchDarkly sync error: {e}"))
261 .ok();
262 }
263 (None, Some(path)) => {
264 mz_dyncfg_file::sync_file_to_configset(
265 configs.clone(),
266 path,
267 cfg.config_sync_timeout,
268 cfg.config_sync_loop_interval,
269 move |updates, configs| {
270 if has_tracing_config_update(updates) {
271 match tracing_config(configs) {
272 Ok(parameters) => parameters.apply(&tracing_handle),
273 Err(err) => warn!("unable to update tracing: {err}"),
274 }
275 }
276 },
277 )
278 .await
279 .inspect_err(|e| warn!("File config sync error: {e}"))
285 .ok();
286 }
287 (Some(_), Some(_)) => panic!(
288 "must provide either config_sync_file_path or launchdarkly_sdk_key for config syncing",
289 ),
290 (None, None) => {}
291 };
292 Ok(Self {
293 cfg,
294 pgwire,
295 https,
296 internal_http,
297 _metrics: metrics,
298 configs,
299 })
300 }
301
302 pub async fn serve(self) -> Result<(), anyhow::Error> {
303 let (pgwire_tls, https_tls) = match &self.cfg.tls {
304 Some(tls) => {
305 let context = tls.reloading_context(self.cfg.reload_certs)?;
306 (
307 Some(ReloadingTlsConfig {
308 context: context.clone(),
309 mode: TlsMode::Require,
310 }),
311 Some(context),
312 )
313 }
314 None => (None, None),
315 };
316
317 let metrics = ServerMetricsConfig::register_into(&self.cfg.metrics_registry);
318
319 let mut set = JoinSet::new();
320 let mut server_handles = Vec::new();
321 let pgwire_addr = self.pgwire.0.local_addr();
322 let https_addr = self.https.0.local_addr();
323 let internal_http_addr = self.internal_http.0.local_addr();
324
325 {
326 let pgwire = PgwireBalancer {
327 resolver: Arc::new(self.cfg.resolver),
328 cancellation_resolver: Arc::new(self.cfg.cancellation_resolver),
329 tls: pgwire_tls,
330 internal_tls: self.cfg.internal_tls,
331 metrics: ServerMetrics::new(metrics.clone(), "pgwire"),
332 now: SYSTEM_TIME.clone(),
333 };
334 let (handle, stream) = self.pgwire;
335 server_handles.push(handle);
336 set.spawn_named(|| "pgwire_stream", {
337 let config_set = self.configs.clone();
338 async move {
339 mz_server_core::serve(ServeConfig {
340 server: pgwire,
341 conns: stream,
342 dyncfg: Some(ServeDyncfg {
343 config_set,
344 sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
345 }),
346 })
347 .await;
348 warn!("pgwire server exited");
349 }
350 });
351 }
352 {
353 let Some((addr, port)) = self.cfg.https_addr_template.split_once(':') else {
354 panic!("expected port in https_addr_template");
355 };
356 let port: u16 = port.parse().expect("unexpected port");
357 let resolver = StubResolver::new();
358 let https = HttpsBalancer {
359 resolver: Arc::from(resolver),
360 tls: https_tls,
361 resolve_template: Arc::from(addr),
362 port,
363 metrics: Arc::from(ServerMetrics::new(metrics, "https")),
364 configs: self.configs.clone(),
365 internal_tls: self.cfg.internal_tls,
366 };
367 let (handle, stream) = self.https;
368 server_handles.push(handle);
369 set.spawn_named(|| "https_stream", {
370 let config_set = self.configs.clone();
371 async move {
372 mz_server_core::serve(ServeConfig {
373 server: https,
374 conns: stream,
375 dyncfg: Some(ServeDyncfg {
376 config_set,
377 sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
378 }),
379 })
380 .await;
381 warn!("https server exited");
382 }
383 });
384 }
385 {
386 let router = Router::new()
387 .route(
388 "/metrics",
389 routing::get(move || async move {
390 mz_http_util::handle_prometheus(&self.cfg.metrics_registry).await
391 }),
392 )
393 .route(
394 "/api/livez",
395 routing::get(mz_http_util::handle_liveness_check),
396 )
397 .route("/api/readyz", routing::get(handle_readiness_check));
398 let internal_http = InternalHttpServer { router };
399 let (handle, stream) = self.internal_http;
400 server_handles.push(handle);
401 set.spawn_named(|| "internal_http_stream", async move {
402 mz_server_core::serve(ServeConfig {
403 server: internal_http,
404 conns: stream,
405 dyncfg: None,
408 })
409 .await;
410 warn!("internal_http server exited");
411 });
412 }
413 #[cfg(unix)]
414 {
415 let mut sigterm =
416 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
417 set.spawn_named(|| "sigterm_handler", async move {
418 sigterm.recv().await;
419 let wait = SIGTERM_LISTEN_WAIT.get(&self.configs);
420 warn!("received signal TERM - delaying for {:?}!", wait);
421 tokio::time::sleep(wait).await;
422 warn!("sigterm delay complete, dropping server handles");
423 drop(server_handles);
424 });
425 }
426
427 println!("balancerd {} listening...", BUILD_INFO.human_version(None));
428 println!(" TLS enabled: {}", self.cfg.tls.is_some());
429 println!(" pgwire address: {}", pgwire_addr);
430 println!(" HTTPS address: {}", https_addr);
431 println!(" internal HTTP address: {}", internal_http_addr);
432
433 while let Some(res) = set.join_next().await {
435 if let Err(err) = res {
436 error!("serving task failed: {err}")
437 }
438 }
439 Ok(())
440 }
441}
442
443#[allow(clippy::unused_async)]
444async fn handle_readiness_check() -> impl IntoResponse {
445 (StatusCode::OK, "ready")
446}
447
/// Server for the internal HTTP listener; every connection is answered by `router`.
struct InternalHttpServer {
    /// Routes exposed on the internal listener.
    router: Router,
}
451
452impl mz_server_core::Server for InternalHttpServer {
453 const NAME: &'static str = "internal_http";
454
455 fn handle_connection(
457 &self,
458 conn: Connection,
459 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
460 ) -> mz_server_core::ConnectionHandler {
461 let router = self.router.clone();
462 let service = hyper::service::service_fn(move |req| router.clone().call(req));
463 let conn = TokioIo::new(conn);
464
465 Box::pin(async {
466 let http = hyper::server::conn::http1::Builder::new();
467 http.serve_connection(conn, service).err_into().await
468 })
469 }
470}
471
/// Guard around an `IntGauge`: the gauge is incremented when the guard is created from it
/// and decremented when the guard is dropped.
struct GaugeGuard {
    /// The gauge adjusted by this guard.
    gauge: IntGauge,
}
477
478impl From<IntGauge> for GaugeGuard {
479 fn from(gauge: IntGauge) -> Self {
480 let _self = Self { gauge };
481 _self.gauge.inc();
482 _self
483 }
484}
485
impl Drop for GaugeGuard {
    /// Decrements the wrapped gauge, undoing the increment made when the guard was created.
    fn drop(&mut self) {
        self.gauge.dec();
    }
}
491
/// The metric families shared by all balancer servers, before a `source` label is fixed.
#[derive(Clone, Debug)]
struct ServerMetricsConfig {
    /// Completed connections, labelled by `source` and `status`.
    connection_status: IntCounterVec,
    /// Currently open connections, labelled by `source`.
    active_connections: IntGaugeVec,
    /// Open connections per tenant, labelled by `source` and `tenant`.
    tenant_connections: IntGaugeVec,
    /// Bytes received from clients, labelled by `source` and `tenant`.
    tenant_connection_rx: IntCounterVec,
    /// Bytes sent to clients, labelled by `source` and `tenant`.
    tenant_connection_tx: IntCounterVec,
    /// pgwire connections with and without SNI, labelled by `tenant` and `has_sni`.
    tenant_pgwire_sni_count: IntCounterVec,
}
501
502impl ServerMetricsConfig {
503 fn register_into(registry: &MetricsRegistry) -> Self {
504 let connection_status = registry.register(metric!(
505 name: "mz_balancer_connection_status",
506 help: "Count of completed network connections, by status",
507 var_labels: ["source", "status"],
508 ));
509 let active_connections = registry.register(metric!(
510 name: "mz_balancer_connection_active",
511 help: "Count of currently open network connections.",
512 var_labels: ["source"],
513 ));
514 let tenant_connections = registry.register(metric!(
515 name: "mz_balancer_tenant_connection_active",
516 help: "Count of opened network connections by tenant.",
517 var_labels: ["source", "tenant"]
518 ));
519 let tenant_connection_rx = registry.register(metric!(
520 name: "mz_balancer_tenant_connection_rx",
521 help: "Number of bytes received from a client for a tenant.",
522 var_labels: ["source", "tenant"],
523 ));
524 let tenant_connection_tx = registry.register(metric!(
525 name: "mz_balancer_tenant_connection_tx",
526 help: "Number of bytes sent to a client for a tenant.",
527 var_labels: ["source", "tenant"],
528 ));
529 let tenant_pgwire_sni_count = registry.register(metric!(
530 name: "mz_balancer_tenant_pgwire_sni_count",
531 help: "Count of pgwire connections that have and do not have SNI available per tenant.",
532 var_labels: ["tenant", "has_sni"],
533 ));
534 Self {
535 connection_status,
536 active_connections,
537 tenant_connections,
538 tenant_connection_rx,
539 tenant_connection_tx,
540 tenant_pgwire_sni_count,
541 }
542 }
543}
544
/// View of [`ServerMetricsConfig`] with the `source` label fixed to one server.
#[derive(Clone, Debug)]
struct ServerMetrics {
    /// The shared metric families.
    inner: ServerMetricsConfig,
    /// Value used for the `source` label (e.g. "pgwire" or "https").
    source: &'static str,
}
550
551impl ServerMetrics {
552 fn new(inner: ServerMetricsConfig, source: &'static str) -> Self {
553 let self_ = Self { inner, source };
554
555 self_.connection_status(false);
558 self_.connection_status(true);
559 drop(self_.active_connections());
560
561 self_
562 }
563
564 fn connection_status(&self, is_ok: bool) -> IntCounter {
565 self.inner
566 .connection_status
567 .with_label_values(&[self.source, Self::status_label(is_ok)])
568 }
569
570 fn active_connections(&self) -> GaugeGuard {
571 self.inner
572 .active_connections
573 .with_label_values(&[self.source])
574 .into()
575 }
576
577 fn tenant_connections(&self, tenant: &str) -> GaugeGuard {
578 self.inner
579 .tenant_connections
580 .with_label_values(&[self.source, tenant])
581 .into()
582 }
583
584 fn tenant_connections_rx(&self, tenant: &str) -> IntCounter {
585 self.inner
586 .tenant_connection_rx
587 .with_label_values(&[self.source, tenant])
588 }
589
590 fn tenant_connections_tx(&self, tenant: &str) -> IntCounter {
591 self.inner
592 .tenant_connection_tx
593 .with_label_values(&[self.source, tenant])
594 }
595
596 fn tenant_pgwire_sni_count(&self, tenant: &str, has_sni: bool) -> IntCounter {
597 self.inner
598 .tenant_pgwire_sni_count
599 .with_label_values(&[tenant, &has_sni.to_string()])
600 }
601
602 fn status_label(is_ok: bool) -> &'static str {
603 if is_ok { "success" } else { "error" }
604 }
605}
606
/// How a pgwire cancel request finds the addresses it is forwarded to.
pub enum CancellationResolver {
    /// A directory containing one file per target, named after the value derived from the
    /// connection id; the file lists one address per line.
    Directory(PathBuf),
    /// A fixed address list, one address per line.
    Static(String),
}
611
/// Server that proxies pgwire connections to a resolved backend.
struct PgwireBalancer {
    /// Client-facing TLS configuration; when `None`, SSL requests are rejected.
    tls: Option<ReloadingTlsConfig>,
    /// Whether to request TLS on the connection to the backend.
    internal_tls: bool,
    /// Used to find the targets of cancel requests.
    cancellation_resolver: Arc<CancellationResolver>,
    /// Used to map a connecting user to a backend address.
    resolver: Arc<Resolver>,
    /// Connection metrics for this server.
    metrics: ServerMetrics,
    /// Clock used to derive the per-connection UUID.
    now: NowFn,
}
620
621impl PgwireBalancer {
622 #[mz_ore::instrument(level = "debug")]
623 async fn run<'a, A>(
624 conn: &'a mut FramedConn<A>,
625 version: i32,
626 params: BTreeMap<String, String>,
627 resolver: &Resolver,
628 tls_mode: Option<TlsMode>,
629 internal_tls: bool,
630 metrics: &ServerMetrics,
631 ) -> Result<(), io::Error>
632 where
633 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
634 {
635 if version != VERSION_3 {
636 return conn
637 .send(ErrorResponse::fatal(
638 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
639 "server does not support the client's requested protocol version",
640 ))
641 .await;
642 }
643
644 let Some(user) = params.get("user") else {
645 return conn
646 .send(ErrorResponse::fatal(
647 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
648 "user parameter required",
649 ))
650 .await;
651 };
652
653 if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
654 return conn.send(err).await;
655 }
656
657 let resolved = match resolver.resolve(conn, user).await {
658 Ok(v) => v,
659 Err(err) => {
660 return conn
661 .send(ErrorResponse::fatal(
662 SqlState::INVALID_PASSWORD,
663 err.to_string(),
664 ))
665 .await;
666 }
667 };
668
669 if let Conn::Ssl(ssl_stream) = conn.inner() {
672 let tenant = resolved.tenant.as_deref().unwrap_or("unknown");
673 let has_sni = ssl_stream.ssl().servername(NameType::HOST_NAME).is_some();
674 metrics.tenant_pgwire_sni_count(tenant, has_sni).inc();
675 }
676
677 let _active_guard = resolved
678 .tenant
679 .as_ref()
680 .map(|tenant| metrics.tenant_connections(tenant));
681 let Ok(mut mz_stream) =
682 Self::init_stream(conn, resolved.addr, resolved.password, params, internal_tls).await
683 else {
684 return Ok(());
685 };
686
687 let mut client_counter = CountingConn::new(conn.inner_mut());
688
689 let res = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
692 if let Some(tenant) = &resolved.tenant {
693 metrics
694 .tenant_connections_tx(tenant)
695 .inc_by(u64::cast_from(client_counter.written));
696 metrics
697 .tenant_connections_rx(tenant)
698 .inc_by(u64::cast_from(client_counter.read));
699 }
700 res?;
701
702 Ok(())
703 }
704
705 #[mz_ore::instrument(level = "debug")]
706 async fn init_stream<'a, A>(
707 conn: &'a mut FramedConn<A>,
708 envd_addr: SocketAddr,
709 password: Option<String>,
710 params: BTreeMap<String, String>,
711 internal_tls: bool,
712 ) -> Result<Conn<TcpStream>, anyhow::Error>
713 where
714 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
715 {
716 let mut mz_stream = TcpStream::connect(envd_addr).await?;
717 let mut buf = BytesMut::new();
718
719 let mut mz_stream = if internal_tls {
720 FrontendStartupMessage::SslRequest.encode(&mut buf)?;
721 mz_stream.write_all(&buf).await?;
722 buf.clear();
723 let mut maybe_ssl_request_response = [0u8; 1];
724 let nread =
725 netio::read_exact_or_eof(&mut mz_stream, &mut maybe_ssl_request_response).await?;
726 if nread == 1 && maybe_ssl_request_response == [ACCEPT_SSL_ENCRYPTION] {
727 let mut builder =
729 SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
730 builder.set_verify(SslVerifyMode::NONE);
732 let mut ssl = builder
733 .build()
734 .configure()?
735 .into_ssl(&envd_addr.to_string())?;
736 ssl.set_connect_state();
737 Conn::Ssl(SslStream::new(ssl, mz_stream)?)
738 } else {
739 Conn::Unencrypted(mz_stream)
740 }
741 } else {
742 Conn::Unencrypted(mz_stream)
743 };
744
745 let startup = FrontendStartupMessage::Startup {
747 version: VERSION_3,
748 params,
749 };
750 startup.encode(&mut buf)?;
751 mz_stream.write_all(&buf).await?;
752 let client_stream = conn.inner_mut();
753
754 let mut maybe_auth_frame = [0; 1 + 4 + 4];
758 let nread = netio::read_exact_or_eof(&mut mz_stream, &mut maybe_auth_frame).await?;
759 const AUTH_PASSWORD_CLEARTEXT: [u8; 9] = [b'R', 0, 0, 0, 8, 0, 0, 0, 3];
762 if nread == AUTH_PASSWORD_CLEARTEXT.len()
763 && maybe_auth_frame == AUTH_PASSWORD_CLEARTEXT
764 && password.is_some()
765 {
766 let Some(password) = password else {
768 unreachable!("verified some above");
769 };
770 let password = FrontendMessage::Password { password };
771 buf.clear();
772 password.encode(&mut buf)?;
773 mz_stream.write_all(&buf).await?;
774 mz_stream.flush().await?;
775 } else {
776 client_stream.write_all(&maybe_auth_frame[0..nread]).await?;
779 }
780
781 Ok(mz_stream)
782 }
783}
784
785impl mz_server_core::Server for PgwireBalancer {
786 const NAME: &'static str = "pgwire_balancer";
787
788 fn handle_connection(
789 &self,
790 conn: Connection,
791 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
792 ) -> mz_server_core::ConnectionHandler {
793 let tls = self.tls.clone();
794 let internal_tls = self.internal_tls;
795 let resolver = Arc::clone(&self.resolver);
796 let inner_metrics = self.metrics.clone();
797 let outer_metrics = self.metrics.clone();
798 let cancellation_resolver = Arc::clone(&self.cancellation_resolver);
799 let conn_uuid = epoch_to_uuid_v7(&(self.now)());
800 let peer_addr = conn.peer_addr();
801 conn.uuid_handle().set(conn_uuid);
802 Box::pin(async move {
803 let active_guard = outer_metrics.active_connections();
806 let result: Result<(), anyhow::Error> = async move {
807 let mut conn = Conn::Unencrypted(conn);
808 loop {
809 let message = decode_startup(&mut conn).await?;
810 conn = match message {
811 None => return Ok(()),
815
816 Some(FrontendStartupMessage::Startup {
817 version,
818 mut params,
819 }) => {
820 let mut conn = FramedConn::new(conn);
821 let peer_addr = match peer_addr {
822 Ok(addr) => addr.ip(),
823 Err(e) => {
824 error!("Invalid peer_addr {:?}", e);
825 return Ok(conn
826 .send(ErrorResponse::fatal(
827 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
828 "invalid peer address",
829 ))
830 .await?);
831 }
832 };
833 debug!(%conn_uuid, %peer_addr, "starting new pgwire connection in balancer");
834 let prev =
835 params.insert(CONN_UUID_KEY.to_string(), conn_uuid.to_string());
836 if prev.is_some() {
837 return Ok(conn
838 .send(ErrorResponse::fatal(
839 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
840 format!("invalid parameter '{CONN_UUID_KEY}'"),
841 ))
842 .await?);
843 }
844
845 if let Some(_) = params.insert(MZ_FORWARDED_FOR_KEY.to_string(), peer_addr.to_string().clone()) {
846 return Ok(conn
847 .send(ErrorResponse::fatal(
848 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
849 format!("invalid parameter '{MZ_FORWARDED_FOR_KEY}'"),
850 ))
851 .await?);
852 };
853
854 Self::run(
855 &mut conn,
856 version,
857 params,
858 &resolver,
859 tls.map(|tls| tls.mode),
860 internal_tls,
861 &inner_metrics,
862 )
863 .await?;
864 conn.flush().await?;
865 return Ok(());
866 }
867
868 Some(FrontendStartupMessage::CancelRequest {
869 conn_id,
870 secret_key,
871 }) => {
872 spawn(|| "cancel request", async move {
873 cancel_request(conn_id, secret_key, &cancellation_resolver).await;
874 });
875 return Ok(());
878 }
879
880 Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
881 (Conn::Unencrypted(mut conn), Some(tls)) => {
882 conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
883 let mut ssl_stream =
884 SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
885 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
886 let _ = ssl_stream.get_mut().shutdown().await;
887 return Err(e.into());
888 }
889 Conn::Ssl(ssl_stream)
890 }
891 (mut conn, _) => {
892 conn.write_all(&[REJECT_ENCRYPTION]).await?;
893 conn
894 }
895 },
896
897 Some(FrontendStartupMessage::GssEncRequest) => {
898 conn.write_all(&[REJECT_ENCRYPTION]).await?;
899 conn
900 }
901 }
902 }
903 }
904 .await;
905 drop(active_guard);
906 outer_metrics.connection_status(result.is_ok()).inc();
907 Ok(())
908 })
909 }
910}
911
/// Wrapper around a stream that tallies the bytes read from and written to it.
struct CountingConn<C> {
    /// The wrapped stream.
    inner: C,
    /// Bytes read from `inner` by successful reads.
    read: usize,
    /// Bytes accepted by `inner` in successful writes.
    written: usize,
}
918
919impl<C> CountingConn<C> {
920 fn new(inner: C) -> Self {
921 CountingConn {
922 inner,
923 read: 0,
924 written: 0,
925 }
926 }
927}
928
929impl<C> AsyncRead for CountingConn<C>
930where
931 C: AsyncRead + Unpin,
932{
933 fn poll_read(
934 self: Pin<&mut Self>,
935 cx: &mut std::task::Context<'_>,
936 buf: &mut io::ReadBuf<'_>,
937 ) -> std::task::Poll<std::io::Result<()>> {
938 let counter = self.get_mut();
939 let pin = Pin::new(&mut counter.inner);
940 let bytes = buf.filled().len();
941 let poll = pin.poll_read(cx, buf);
942 let bytes = buf.filled().len() - bytes;
943 if let std::task::Poll::Ready(Ok(())) = poll {
944 counter.read += bytes
945 }
946 poll
947 }
948}
949
950impl<C> AsyncWrite for CountingConn<C>
951where
952 C: AsyncWrite + Unpin,
953{
954 fn poll_write(
955 self: Pin<&mut Self>,
956 cx: &mut std::task::Context<'_>,
957 buf: &[u8],
958 ) -> std::task::Poll<Result<usize, std::io::Error>> {
959 let counter = self.get_mut();
960 let pin = Pin::new(&mut counter.inner);
961 let poll = pin.poll_write(cx, buf);
962 if let std::task::Poll::Ready(Ok(bytes)) = poll {
963 counter.written += bytes
964 }
965 poll
966 }
967
968 fn poll_flush(
969 self: Pin<&mut Self>,
970 cx: &mut std::task::Context<'_>,
971 ) -> std::task::Poll<Result<(), std::io::Error>> {
972 let counter = self.get_mut();
973 let pin = Pin::new(&mut counter.inner);
974 pin.poll_flush(cx)
975 }
976
977 fn poll_shutdown(
978 self: Pin<&mut Self>,
979 cx: &mut std::task::Context<'_>,
980 ) -> std::task::Poll<Result<(), std::io::Error>> {
981 let counter = self.get_mut();
982 let pin = Pin::new(&mut counter.inner);
983 pin.poll_shutdown(cx)
984 }
985}
986
987async fn cancel_request(
1004 conn_id: u32,
1005 secret_key: u32,
1006 cancellation_resolver: &CancellationResolver,
1007) {
1008 let suffix = conn_id_org_uuid(conn_id);
1009 let contents = match cancellation_resolver {
1010 CancellationResolver::Directory(dir) => {
1011 let path = dir.join(&suffix);
1012 match std::fs::read_to_string(&path) {
1013 Ok(contents) => contents,
1014 Err(err) => {
1015 error!("could not read cancel file {path:?}: {err}");
1016 return;
1017 }
1018 }
1019 }
1020 CancellationResolver::Static(addr) => addr.to_owned(),
1021 };
1022 let mut all_ips = Vec::new();
1023 for addr in contents.lines() {
1024 let addr = addr.trim();
1025 if addr.is_empty() {
1026 continue;
1027 }
1028 match tokio::net::lookup_host(addr).await {
1029 Ok(ips) => all_ips.extend(ips),
1030 Err(err) => {
1031 error!("{addr} failed resolution: {err}");
1032 }
1033 }
1034 }
1035 let mut buf = BytesMut::with_capacity(16);
1036 let msg = FrontendStartupMessage::CancelRequest {
1037 conn_id,
1038 secret_key,
1039 };
1040 msg.encode(&mut buf).expect("must encode");
1041 let buf = buf.freeze();
1042 for ip in all_ips {
1043 debug!("cancelling {suffix} to {ip}");
1044 let buf = buf.clone();
1045 spawn(|| "cancel request for ip", async move {
1046 let send = async {
1047 let mut stream = TcpStream::connect(&ip).await?;
1048 stream.write_all(&buf).await?;
1049 stream.shutdown().await?;
1050 Ok::<_, io::Error>(())
1051 };
1052 if let Err(err) = send.await {
1053 error!("error mirroring cancel to {ip}: {err}");
1054 }
1055 });
1056 }
1057}
1058
/// Server that proxies HTTPS connections to a backend chosen from the client's SNI.
struct HttpsBalancer {
    /// DNS resolver used for the CNAME lookup that identifies the tenant.
    resolver: Arc<StubResolver>,
    /// Client-facing TLS context; when `None`, connections are proxied without TLS
    /// termination and without a server name.
    tls: Option<ReloadingSslContext>,
    /// Host template for the backend address; `{}` is replaced with the server name.
    resolve_template: Arc<str>,
    /// Backend port.
    port: u16,
    /// Connection metrics for this server.
    metrics: Arc<ServerMetrics>,
    /// Dynamic configs; consulted per connection for proxy-protocol header injection.
    configs: ConfigSet,
    /// Whether the connection to the backend uses TLS.
    internal_tls: bool,
}
1068
1069impl HttpsBalancer {
1070 async fn resolve(
1071 resolver: &StubResolver,
1072 resolve_template: &str,
1073 port: u16,
1074 servername: Option<&str>,
1075 ) -> Result<ResolvedAddr, anyhow::Error> {
1076 let addr = match &servername {
1077 Some(servername) => resolve_template.replace("{}", servername),
1078 None => resolve_template.to_string(),
1079 };
1080 debug!("https address: {addr}");
1081
1082 let tenant = Self::tenant(resolver, &addr).await;
1095
1096 let envd_addr = lookup(&format!("{addr}:{port}")).await?;
1098
1099 Ok(ResolvedAddr {
1100 addr: envd_addr,
1101 password: None,
1102 tenant,
1103 })
1104 }
1105
1106 async fn tenant(resolver: &StubResolver, addr: &str) -> Option<String> {
1109 let Ok(dname) = Name::<Vec<_>>::from_str(addr) else {
1110 return None;
1111 };
1112 let lookup = resolver.query((dname, Rtype::CNAME)).await;
1114 if let Ok(lookup) = lookup {
1115 if let Ok(answer) = lookup.answer() {
1116 let res = answer.limit_to::<AllRecordData<_, _>>();
1117 for record in res {
1118 let Ok(record) = record else {
1119 continue;
1120 };
1121 if record.rtype() != Rtype::CNAME {
1122 continue;
1123 }
1124 let cname = record.data();
1125 let cname = cname.to_string();
1126 debug!("cname: {cname}");
1127 return Self::extract_tenant_from_cname(&cname);
1128 }
1129 }
1130 }
1131 None
1132 }
1133
1134 fn extract_tenant_from_cname(cname: &str) -> Option<String> {
1136 let mut parts = cname.split('.');
1137 let _service = parts.next();
1138 let Some(namespace) = parts.next() else {
1139 return None;
1140 };
1141 let Some((_, namespace)) = namespace.split_once('-') else {
1143 return None;
1144 };
1145 let Some((tenant, _)) = namespace.rsplit_once('-') else {
1147 return None;
1148 };
1149 let Ok(tenant) = Uuid::parse_str(tenant) else {
1152 error!("cname tenant not a uuid: {tenant}");
1153 return None;
1154 };
1155 Some(tenant.to_string())
1156 }
1157}
1158
1159impl mz_server_core::Server for HttpsBalancer {
1160 const NAME: &'static str = "https_balancer";
1161
1162 fn handle_connection(
1164 &self,
1165 conn: Connection,
1166 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
1167 ) -> mz_server_core::ConnectionHandler {
1168 let tls_context = self.tls.clone();
1169 let internal_tls = self.internal_tls.clone();
1170 let resolver = Arc::clone(&self.resolver);
1171 let resolve_template = Arc::clone(&self.resolve_template);
1172 let port = self.port;
1173 let inner_metrics = Arc::clone(&self.metrics);
1174 let outer_metrics = Arc::clone(&self.metrics);
1175 let peer_addr = conn.peer_addr();
1176 let inject_proxy_headers = INJECT_PROXY_PROTOCOL_HEADER_HTTP.get(&self.configs);
1177 Box::pin(async move {
1178 let active_guard = inner_metrics.active_connections();
1179 let result: Result<_, anyhow::Error> = Box::pin(async move {
1180 let peer_addr = peer_addr.context("fetching peer addr")?;
1181 let (client_stream, servername): (Box<dyn ClientStream>, Option<String>) =
1182 match tls_context {
1183 Some(tls_context) => {
1184 let mut ssl_stream =
1185 SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
1186 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1187 let _ = ssl_stream.get_mut().shutdown().await;
1188 return Err(e.into());
1189 }
1190 let servername: Option<String> =
1191 ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1192 match sn.split_once('.') {
1193 Some((left, _right)) => left,
1194 None => sn,
1195 }
1196 .into()
1197 });
1198 debug!("servername: {servername:?}");
1199 (Box::new(ssl_stream), servername)
1200 }
1201 _ => (Box::new(conn), None),
1202 };
1203 let resolved =
1204 Self::resolve(&resolver, &resolve_template, port, servername.as_deref())
1205 .await?;
1206 let inner_active_guard = resolved
1207 .tenant
1208 .as_ref()
1209 .map(|tenant| inner_metrics.tenant_connections(tenant));
1210
1211 let mut mz_stream = TcpStream::connect(resolved.addr).await?;
1212
1213 if inject_proxy_headers {
1214 let addrs = ProxiedAddress::stream(peer_addr, resolved.addr);
1216 let header = ProxyHeader::with_address(addrs);
1217 let mut buf = [0u8; 1024];
1218 let len = header.encode_to_slice_v2(&mut buf)?;
1219 mz_stream.write_all(&buf[..len]).await?;
1220 }
1221
1222 let mut mz_stream = if internal_tls {
1223 let mut builder =
1225 SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
1226 builder.set_verify(SslVerifyMode::NONE);
1228 let mut ssl = builder
1229 .build()
1230 .configure()?
1231 .into_ssl(&resolved.addr.to_string())?;
1232 ssl.set_connect_state();
1233 Conn::Ssl(SslStream::new(ssl, mz_stream)?)
1234 } else {
1235 Conn::Unencrypted(mz_stream)
1236 };
1237
1238 let mut client_counter = CountingConn::new(client_stream);
1239
1240 let _ = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
1244 if let Some(tenant) = &resolved.tenant {
1245 inner_metrics
1246 .tenant_connections_tx(tenant)
1247 .inc_by(u64::cast_from(client_counter.written));
1248 inner_metrics
1249 .tenant_connections_rx(tenant)
1250 .inc_by(u64::cast_from(client_counter.read));
1251 }
1252 drop(inner_active_guard);
1253 Ok(())
1254 })
1255 .await;
1256 drop(active_guard);
1257 outer_metrics.connection_status(result.is_ok()).inc();
1258 if let Err(e) = result {
1259 debug!("connection error: {e}");
1260 }
1261 Ok(())
1262 })
1263 }
1264}
1265
1266trait ClientStream: AsyncRead + AsyncWrite + Unpin + Send {}
1267impl<T: AsyncRead + AsyncWrite + Unpin + Send> ClientStream for T {}
1268
/// How the upstream address for a connection is determined.
#[derive(Debug)]
pub enum Resolver {
    /// Always resolve the contained address string via a host lookup.
    Static(String),
    /// Authenticate the user first and derive the address from the
    /// authenticated session's tenant id.
    Frontegg(FronteggResolver),
}
1274
1275impl Resolver {
1276 async fn resolve<A>(
1277 &self,
1278 conn: &mut FramedConn<A>,
1279 user: &str,
1280 ) -> Result<ResolvedAddr, anyhow::Error>
1281 where
1282 A: AsyncRead + AsyncWrite + Unpin,
1283 {
1284 match self {
1285 Resolver::Frontegg(FronteggResolver {
1286 auth,
1287 addr_template,
1288 }) => {
1289 conn.send(BackendMessage::AuthenticationCleartextPassword)
1290 .await?;
1291 conn.flush().await?;
1292 let password = match conn.recv().await? {
1293 Some(FrontendMessage::Password { password }) => password,
1294 _ => anyhow::bail!("expected Password message"),
1295 };
1296
1297 let auth_response = auth.authenticate(user, &password).await;
1298 let auth_session = match auth_response {
1299 Ok(auth_session) => auth_session,
1300 Err(e) => {
1301 warn!("pgwire connection failed authentication: {}", e);
1302 anyhow::bail!("invalid password");
1304 }
1305 };
1306
1307 let addr = addr_template.replace("{}", &auth_session.tenant_id().to_string());
1308 let addr = lookup(&addr).await?;
1309 Ok(ResolvedAddr {
1310 addr,
1311 password: Some(password),
1312 tenant: Some(auth_session.tenant_id().to_string()),
1313 })
1314 }
1315 Resolver::Static(addr) => {
1316 let addr = lookup(addr).await?;
1317 Ok(ResolvedAddr {
1318 addr,
1319 password: None,
1320 tenant: None,
1321 })
1322 }
1323 }
1324 }
1325}
1326
1327async fn lookup(name: &str) -> Result<SocketAddr, anyhow::Error> {
1329 let mut addrs = tokio::net::lookup_host(name).await?;
1330 match addrs.next() {
1331 Some(addr) => Ok(addr),
1332 None => {
1333 error!("{name} did not resolve to any addresses");
1334 anyhow::bail!("internal error")
1335 }
1336 }
1337}
1338
/// Configuration for resolving an address through Frontegg authentication.
#[derive(Debug)]
pub struct FronteggResolver {
    /// Authenticator used to validate a user/password pair.
    pub auth: FronteggAuthentication,
    /// Address template; `{}` is replaced with the authenticated tenant id
    /// before the address is looked up.
    pub addr_template: String,
}
1344
/// Outcome of `Resolver::resolve`.
#[derive(Debug)]
struct ResolvedAddr {
    /// Socket address to connect to.
    addr: SocketAddr,
    /// Password received from the client during resolution; `None` for a
    /// static resolver.
    password: Option<String>,
    /// Tenant id of the authenticated session; `None` for a static resolver.
    tenant: Option<String>,
}
1351
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `HttpsBalancer::extract_tenant_from_cname`:
    /// each entry pairs a CNAME with the tenant id expected from it.
    #[mz_ore::test]
    fn test_tenant() {
        let cases: &[(&str, Option<&str>)] = &[
            // Empty input yields nothing.
            ("", None),
            // Canonical form.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Different service, namespace prefix and domain suffix.
            (
                "service.something-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.ssvvcc.cloister.faraway",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // UUID written without hyphens comes back hyphenated.
            (
                "environmentd.environment-58cd23ffa4d74bd0ad85a6ff29cc86c3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Multi-digit trailing number.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-1234.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Upper-case UUID comes back lower-cased.
            (
                "environmentd.environment-58CD23FF-A4D7-4BD0-AD85-A6FF29CC86C3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Trailing number missing.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3.svc.cluster.local",
                None,
            ),
            // Leading service label missing.
            (
                "environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
            // UUID one character short.
            (
                "environmentd.environment-8cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
        ];
        for &(name, expect) in cases {
            let cname = HttpsBalancer::extract_tenant_from_cname(name);
            assert_eq!(
                cname.as_deref(),
                expect,
                "{name} got {cname:?} expected {expect:?}"
            );
        }
    }
}