1mod codec;
20mod dyncfgs;
21
22use std::collections::BTreeMap;
23use std::net::SocketAddr;
24use std::path::PathBuf;
25use std::pin::Pin;
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use anyhow::Context;
31use axum::response::IntoResponse;
32use axum::{Router, routing};
33use bytes::BytesMut;
34use domain::base::{Name, Rtype};
35use domain::rdata::AllRecordData;
36use domain::resolv::StubResolver;
37use futures::TryFutureExt;
38use futures::stream::BoxStream;
39use hyper::StatusCode;
40use hyper_util::rt::TokioIo;
41use launchdarkly_server_sdk as ld;
42use mz_build_info::{BuildInfo, build_info};
43use mz_dyncfg::ConfigSet;
44use mz_frontegg_auth::Authenticator as FronteggAuthentication;
45use mz_ore::cast::CastFrom;
46use mz_ore::id_gen::conn_id_org_uuid;
47use mz_ore::metrics::{ComputedGauge, IntCounter, IntGauge, MetricsRegistry};
48use mz_ore::netio::AsyncReady;
49use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
50use mz_ore::task::{JoinSetExt, spawn};
51use mz_ore::tracing::TracingHandle;
52use mz_ore::{metric, netio};
53use mz_pgwire_common::{
54 ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ErrorResponse, FrontendMessage,
55 FrontendStartupMessage, MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, VERSION_3, decode_startup,
56};
57use mz_server_core::{
58 Connection, ConnectionStream, ListenerHandle, ReloadTrigger, ReloadingSslContext,
59 ReloadingTlsConfig, ServeConfig, ServeDyncfg, TlsCertConfig, TlsMode, listen,
60};
61use openssl::ssl::{NameType, Ssl, SslConnector, SslMethod, SslVerifyMode};
62use prometheus::{IntCounterVec, IntGaugeVec};
63use proxy_header::{ProxiedAddress, ProxyHeader};
64use semver::Version;
65use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
66use tokio::net::TcpStream;
67use tokio::sync::oneshot;
68use tokio::task::JoinSet;
69use tokio_metrics::TaskMetrics;
70use tokio_openssl::SslStream;
71use tokio_postgres::error::SqlState;
72use tower::Service;
73use tracing::{debug, error, warn};
74use uuid::Uuid;
75
76use crate::codec::{BackendMessage, FramedConn};
77use crate::dyncfgs::{
78 INJECT_PROXY_PROTOCOL_HEADER_HTTP, SIGTERM_CONNECTION_WAIT, SIGTERM_LISTEN_WAIT,
79 has_tracing_config_update, tracing_config,
80};
81
/// Build information for this binary, produced by the `build_info!` macro.
///
/// Used for the startup banner printed by `BalancerService::serve` and passed
/// as the build info to the LaunchDarkly config sync.
pub const BUILD_INFO: BuildInfo = build_info!();
84
/// Everything needed to construct and run a [`BalancerService`].
pub struct BalancerConfig {
    /// Semver version of this build; used as a metric label and in the
    /// LaunchDarkly context key.
    build_version: Version,
    /// Address for the internal HTTP server (metrics, liveness, readiness).
    internal_http_listen_addr: SocketAddr,
    /// Address on which client pgwire connections are accepted.
    pgwire_listen_addr: SocketAddr,
    /// Address on which client HTTPS connections are accepted.
    https_listen_addr: SocketAddr,
    /// Determines where pgwire cancel requests are forwarded.
    cancellation_resolver: CancellationResolver,
    /// Maps pgwire connections to upstream environment addresses.
    resolver: Resolver,
    /// `host:port` template for HTTPS upstreams; `{}` in the host part is
    /// replaced by the first label of the client's SNI servername.
    https_sni_addr_template: String,
    /// TLS certificates for client-facing listeners; `None` disables TLS.
    tls: Option<TlsCertConfig>,
    /// Whether to use TLS on connections to upstream environments.
    internal_tls: bool,
    /// Registry into which all balancer metrics are registered.
    metrics_registry: MetricsRegistry,
    /// Stream of certificate reload triggers for the TLS context.
    reload_certs: BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>,
    /// LaunchDarkly SDK key for dyncfg syncing; mutually exclusive with
    /// `config_sync_file_path`.
    launchdarkly_sdk_key: Option<String>,
    /// File to sync dyncfgs from; mutually exclusive with
    /// `launchdarkly_sdk_key`.
    config_sync_file_path: Option<PathBuf>,
    /// Timeout passed to the config sync functions.
    config_sync_timeout: Duration,
    /// Loop interval passed to the config sync functions. NOTE(review):
    /// presumably `None` disables periodic re-syncing — confirm in
    /// `mz_dyncfg_*`.
    config_sync_loop_interval: Option<Duration>,
    /// Cloud provider reported in the LaunchDarkly context.
    cloud_provider: Option<String>,
    /// Cloud provider region reported in the LaunchDarkly context.
    cloud_provider_region: Option<String>,
    /// Handle used to apply tracing config updates from dyncfg syncing.
    tracing_handle: TracingHandle,
    /// Default values applied to the dyncfg set at startup.
    default_configs: Vec<(String, String)>,
}
112
113impl BalancerConfig {
114 pub fn new(
115 build_info: &BuildInfo,
116 internal_http_listen_addr: SocketAddr,
117 pgwire_listen_addr: SocketAddr,
118 https_listen_addr: SocketAddr,
119 cancellation_resolver: CancellationResolver,
120 resolver: Resolver,
121 https_sni_addr_template: String,
122 tls: Option<TlsCertConfig>,
123 internal_tls: bool,
124 metrics_registry: MetricsRegistry,
125 reload_certs: ReloadTrigger,
126 launchdarkly_sdk_key: Option<String>,
127 config_sync_file: Option<PathBuf>,
128 config_sync_timeout: Duration,
129 config_sync_loop_interval: Option<Duration>,
130 cloud_provider: Option<String>,
131 cloud_provider_region: Option<String>,
132 tracing_handle: TracingHandle,
133 default_configs: Vec<(String, String)>,
134 ) -> Self {
135 Self {
136 build_version: build_info.semver_version(),
137 internal_http_listen_addr,
138 pgwire_listen_addr,
139 https_listen_addr,
140 cancellation_resolver,
141 resolver,
142 https_sni_addr_template,
143 tls,
144 internal_tls,
145 metrics_registry,
146 reload_certs,
147 launchdarkly_sdk_key,
148 config_sync_file_path: config_sync_file,
149 config_sync_timeout,
150 config_sync_loop_interval,
151 cloud_provider,
152 cloud_provider_region,
153 tracing_handle,
154 default_configs,
155 }
156 }
157}
158
159#[derive(Debug)]
161pub struct BalancerMetrics {
162 _uptime: ComputedGauge,
163}
164
165impl BalancerMetrics {
166 pub fn new(cfg: &BalancerConfig) -> Self {
168 let start = Instant::now();
169 let uptime = cfg.metrics_registry.register_computed_gauge(
170 metric!(
171 name: "mz_balancer_metadata_seconds",
172 help: "server uptime, labels are build metadata",
173 const_labels: {
174 "version" => cfg.build_version,
175 "build_type" => if cfg!(release) { "release" } else { "debug" }
176 },
177 ),
178 move || start.elapsed().as_secs_f64(),
179 );
180 BalancerMetrics { _uptime: uptime }
181 }
182}
183
/// A bound-but-not-yet-serving balancer: one listener per protocol plus the
/// configuration and dynamic configs needed by [`BalancerService::serve`].
pub struct BalancerService {
    /// Static configuration the service was created from.
    cfg: BalancerConfig,
    /// pgwire listener handle and its stream of incoming connections.
    pub pgwire: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// HTTPS listener handle and its stream of incoming connections.
    pub https: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Internal HTTP listener handle and its stream of incoming connections.
    pub internal_http: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Held to keep process metrics registered.
    _metrics: BalancerMetrics,
    /// Dynamic configuration shared with the servers.
    configs: ConfigSet,
}
192
193impl BalancerService {
194 pub async fn new(cfg: BalancerConfig) -> Result<Self, anyhow::Error> {
195 let pgwire = listen(&cfg.pgwire_listen_addr).await?;
196 let https = listen(&cfg.https_listen_addr).await?;
197 let internal_http = listen(&cfg.internal_http_listen_addr).await?;
198 let metrics = BalancerMetrics::new(&cfg);
199 let mut configs = ConfigSet::default();
200 configs = dyncfgs::all_dyncfgs(configs);
201 dyncfgs::set_defaults(&configs, cfg.default_configs.clone())?;
202 let tracing_handle = cfg.tracing_handle.clone();
203 match (
205 cfg.launchdarkly_sdk_key.as_deref(),
206 cfg.config_sync_file_path.as_deref(),
207 ) {
208 (Some(key), None) => {
209 mz_dyncfg_launchdarkly::sync_launchdarkly_to_configset(
210 configs.clone(),
211 &BUILD_INFO,
212 |builder| {
213 let region = cfg
214 .cloud_provider_region
215 .clone()
216 .unwrap_or_else(|| String::from("unknown"));
217 if let Some(provider) = cfg.cloud_provider.clone() {
218 builder.add_context(
219 ld::ContextBuilder::new(format!(
220 "{}/{}/{}",
221 provider, region, cfg.build_version
222 ))
223 .kind("balancer")
224 .set_string("provider", provider)
225 .set_string("region", region)
226 .set_string("version", cfg.build_version.to_string())
227 .build()
228 .map_err(|e| anyhow::anyhow!(e))?,
229 );
230 } else {
231 builder.add_context(
232 ld::ContextBuilder::new(format!(
233 "{}/{}/{}",
234 "unknown", region, cfg.build_version
235 ))
236 .anonymous(true) .kind("balancer")
238 .set_string("provider", "unknown")
239 .set_string("region", region)
240 .set_string("version", cfg.build_version.to_string())
241 .build()
242 .map_err(|e| anyhow::anyhow!(e))?,
243 );
244 }
245 Ok(())
246 },
247 Some(key),
248 cfg.config_sync_timeout,
249 cfg.config_sync_loop_interval,
250 move |updates, configs| {
251 if has_tracing_config_update(updates) {
252 match tracing_config(configs) {
253 Ok(parameters) => parameters.apply(&tracing_handle),
254 Err(err) => warn!("unable to update tracing: {err}"),
255 }
256 }
257 },
258 )
259 .await
260 .inspect_err(|e| warn!("LaunchDarkly sync error: {e}"))
261 .ok();
262 }
263 (None, Some(path)) => {
264 mz_dyncfg_file::sync_file_to_configset(
265 configs.clone(),
266 path,
267 cfg.config_sync_timeout,
268 cfg.config_sync_loop_interval,
269 move |updates, configs| {
270 if has_tracing_config_update(updates) {
271 match tracing_config(configs) {
272 Ok(parameters) => parameters.apply(&tracing_handle),
273 Err(err) => warn!("unable to update tracing: {err}"),
274 }
275 }
276 },
277 )
278 .await
279 .inspect_err(|e| warn!("File config sync error: {e}"))
285 .ok();
286 }
287 (Some(_), Some(_)) => panic!(
288 "must provide either config_sync_file_path or launchdarkly_sdk_key for config syncing",
289 ),
290 (None, None) => {}
291 };
292 Ok(Self {
293 cfg,
294 pgwire,
295 https,
296 internal_http,
297 _metrics: metrics,
298 configs,
299 })
300 }
301
302 pub async fn serve(self) -> Result<(), anyhow::Error> {
303 let (pgwire_tls, https_tls) = match &self.cfg.tls {
304 Some(tls) => {
305 let context = tls.reloading_context(self.cfg.reload_certs)?;
306 (
307 Some(ReloadingTlsConfig {
308 context: context.clone(),
309 mode: TlsMode::Require,
310 }),
311 Some(context),
312 )
313 }
314 None => (None, None),
315 };
316
317 let metrics = ServerMetricsConfig::register_into(&self.cfg.metrics_registry);
318
319 let mut set = JoinSet::new();
320 let mut server_handles = Vec::new();
321 let pgwire_addr = self.pgwire.0.local_addr();
322 let https_addr = self.https.0.local_addr();
323 let internal_http_addr = self.internal_http.0.local_addr();
324
325 {
326 let pgwire = PgwireBalancer {
327 resolver: Arc::new(self.cfg.resolver),
328 cancellation_resolver: Arc::new(self.cfg.cancellation_resolver),
329 tls: pgwire_tls,
330 internal_tls: self.cfg.internal_tls,
331 metrics: ServerMetrics::new(metrics.clone(), "pgwire"),
332 now: SYSTEM_TIME.clone(),
333 };
334 let (handle, stream) = self.pgwire;
335 server_handles.push(handle);
336 set.spawn_named(|| "pgwire_stream", {
337 let config_set = self.configs.clone();
338 async move {
339 mz_server_core::serve(ServeConfig {
340 server: pgwire,
341 conns: stream,
342 dyncfg: Some(ServeDyncfg {
343 config_set,
344 sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
345 }),
346 })
347 .await;
348 warn!("pgwire server exited");
349 }
350 });
351 }
352 {
353 let Some((addr, port)) = self.cfg.https_sni_addr_template.split_once(':') else {
354 panic!("expected port in https_addr_template");
355 };
356 let port: u16 = port.parse().expect("unexpected port");
357 let resolver = StubResolver::new();
358 let https = HttpsBalancer {
359 resolver: Arc::from(resolver),
360 tls: https_tls,
361 resolve_template: Arc::from(addr),
362 port,
363 metrics: Arc::from(ServerMetrics::new(metrics, "https")),
364 configs: self.configs.clone(),
365 internal_tls: self.cfg.internal_tls,
366 };
367 let (handle, stream) = self.https;
368 server_handles.push(handle);
369 set.spawn_named(|| "https_stream", {
370 let config_set = self.configs.clone();
371 async move {
372 mz_server_core::serve(ServeConfig {
373 server: https,
374 conns: stream,
375 dyncfg: Some(ServeDyncfg {
376 config_set,
377 sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
378 }),
379 })
380 .await;
381 warn!("https server exited");
382 }
383 });
384 }
385 {
386 let router = Router::new()
387 .route(
388 "/metrics",
389 routing::get(move || async move {
390 mz_http_util::handle_prometheus(&self.cfg.metrics_registry).await
391 }),
392 )
393 .route(
394 "/api/livez",
395 routing::get(mz_http_util::handle_liveness_check),
396 )
397 .route("/api/readyz", routing::get(handle_readiness_check));
398 let internal_http = InternalHttpServer { router };
399 let (handle, stream) = self.internal_http;
400 server_handles.push(handle);
401 set.spawn_named(|| "internal_http_stream", async move {
402 mz_server_core::serve(ServeConfig {
403 server: internal_http,
404 conns: stream,
405 dyncfg: None,
408 })
409 .await;
410 warn!("internal_http server exited");
411 });
412 }
413 #[cfg(unix)]
414 {
415 let mut sigterm =
416 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
417 set.spawn_named(|| "sigterm_handler", async move {
418 sigterm.recv().await;
419 let wait = SIGTERM_LISTEN_WAIT.get(&self.configs);
420 warn!("received signal TERM - delaying for {:?}!", wait);
421 tokio::time::sleep(wait).await;
422 warn!("sigterm delay complete, dropping server handles");
423 drop(server_handles);
424 });
425 }
426
427 println!("balancerd {} listening...", BUILD_INFO.human_version(None));
428 println!(" TLS enabled: {}", self.cfg.tls.is_some());
429 println!(" pgwire address: {}", pgwire_addr);
430 println!(" HTTPS address: {}", https_addr);
431 println!(" internal HTTP address: {}", internal_http_addr);
432
433 while let Some(res) = set.join_next().await {
435 if let Err(err) = res {
436 error!("serving task failed: {err}")
437 }
438 }
439 Ok(())
440 }
441}
442
443#[allow(clippy::unused_async)]
444async fn handle_readiness_check() -> impl IntoResponse {
445 (StatusCode::OK, "ready")
446}
447
448struct InternalHttpServer {
449 router: Router,
450}
451
452impl mz_server_core::Server for InternalHttpServer {
453 const NAME: &'static str = "internal_http";
454
455 fn handle_connection(
457 &self,
458 conn: Connection,
459 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
460 ) -> mz_server_core::ConnectionHandler {
461 let router = self.router.clone();
462 let service = hyper::service::service_fn(move |req| router.clone().call(req));
463 let conn = TokioIo::new(conn);
464
465 Box::pin(async {
466 let http = hyper::server::conn::http1::Builder::new();
467 http.serve_connection(conn, service).err_into().await
468 })
469 }
470}
471
472struct GaugeGuard {
475 gauge: IntGauge,
476}
477
478impl From<IntGauge> for GaugeGuard {
479 fn from(gauge: IntGauge) -> Self {
480 let _self = Self { gauge };
481 _self.gauge.inc();
482 _self
483 }
484}
485
486impl Drop for GaugeGuard {
487 fn drop(&mut self) {
488 self.gauge.dec();
489 }
490}
491
492#[derive(Clone, Debug)]
493struct ServerMetricsConfig {
494 connection_status: IntCounterVec,
495 active_connections: IntGaugeVec,
496 tenant_connections: IntGaugeVec,
497 tenant_connection_rx: IntCounterVec,
498 tenant_connection_tx: IntCounterVec,
499 tenant_pgwire_sni_count: IntCounterVec,
500}
501
502impl ServerMetricsConfig {
503 fn register_into(registry: &MetricsRegistry) -> Self {
504 let connection_status = registry.register(metric!(
505 name: "mz_balancer_connection_status",
506 help: "Count of completed network connections, by status",
507 var_labels: ["source", "status"],
508 ));
509 let active_connections = registry.register(metric!(
510 name: "mz_balancer_connection_active",
511 help: "Count of currently open network connections.",
512 var_labels: ["source"],
513 ));
514 let tenant_connections = registry.register(metric!(
515 name: "mz_balancer_tenant_connection_active",
516 help: "Count of opened network connections by tenant.",
517 var_labels: ["source", "tenant"]
518 ));
519 let tenant_connection_rx = registry.register(metric!(
520 name: "mz_balancer_tenant_connection_rx",
521 help: "Number of bytes received from a client for a tenant.",
522 var_labels: ["source", "tenant"],
523 ));
524 let tenant_connection_tx = registry.register(metric!(
525 name: "mz_balancer_tenant_connection_tx",
526 help: "Number of bytes sent to a client for a tenant.",
527 var_labels: ["source", "tenant"],
528 ));
529 let tenant_pgwire_sni_count = registry.register(metric!(
530 name: "mz_balancer_tenant_pgwire_sni_count",
531 help: "Count of pgwire connections that have and do not have SNI available per tenant.",
532 var_labels: ["tenant", "has_sni"],
533 ));
534 Self {
535 connection_status,
536 active_connections,
537 tenant_connections,
538 tenant_connection_rx,
539 tenant_connection_tx,
540 tenant_pgwire_sni_count,
541 }
542 }
543}
544
545#[derive(Clone, Debug)]
546struct ServerMetrics {
547 inner: ServerMetricsConfig,
548 source: &'static str,
549}
550
551impl ServerMetrics {
552 fn new(inner: ServerMetricsConfig, source: &'static str) -> Self {
553 let self_ = Self { inner, source };
554
555 self_.connection_status(false);
558 self_.connection_status(true);
559 drop(self_.active_connections());
560
561 self_
562 }
563
564 fn connection_status(&self, is_ok: bool) -> IntCounter {
565 self.inner
566 .connection_status
567 .with_label_values(&[self.source, Self::status_label(is_ok)])
568 }
569
570 fn active_connections(&self) -> GaugeGuard {
571 self.inner
572 .active_connections
573 .with_label_values(&[self.source])
574 .into()
575 }
576
577 fn tenant_connections(&self, tenant: &str) -> GaugeGuard {
578 self.inner
579 .tenant_connections
580 .with_label_values(&[self.source, tenant])
581 .into()
582 }
583
584 fn tenant_connections_rx(&self, tenant: &str) -> IntCounter {
585 self.inner
586 .tenant_connection_rx
587 .with_label_values(&[self.source, tenant])
588 }
589
590 fn tenant_connections_tx(&self, tenant: &str) -> IntCounter {
591 self.inner
592 .tenant_connection_tx
593 .with_label_values(&[self.source, tenant])
594 }
595
596 fn tenant_pgwire_sni_count(&self, tenant: &str, has_sni: bool) -> IntCounter {
597 self.inner
598 .tenant_pgwire_sni_count
599 .with_label_values(&[tenant, &has_sni.to_string()])
600 }
601
602 fn status_label(is_ok: bool) -> &'static str {
603 if is_ok { "success" } else { "error" }
604 }
605}
606
/// Determines where pgwire cancel requests are forwarded.
pub enum CancellationResolver {
    /// A directory holding one file per key returned by
    /// `conn_id_org_uuid(conn_id)`. Each file lists newline-separated
    /// addresses (resolved with `tokio::net::lookup_host`) to forward the
    /// cancellation to.
    Directory(PathBuf),
    /// Newline-separated addresses used for every cancel request.
    Static(String),
}
611
/// Server that proxies pgwire connections to the upstream address chosen by
/// its [`Resolver`].
struct PgwireBalancer {
    /// TLS configuration for client connections; `None` rejects SSL requests.
    tls: Option<ReloadingTlsConfig>,
    /// Whether to request TLS on the upstream connection.
    internal_tls: bool,
    /// Where cancel requests are forwarded.
    cancellation_resolver: Arc<CancellationResolver>,
    /// Resolves a connection's user to its upstream address.
    resolver: Arc<Resolver>,
    /// Connection metrics for this server.
    metrics: ServerMetrics,
    /// Clock used to generate per-connection UUIDs.
    now: NowFn,
}
620
621impl PgwireBalancer {
622 #[mz_ore::instrument(level = "debug")]
623 async fn run<'a, A>(
624 conn: &'a mut FramedConn<A>,
625 version: i32,
626 params: BTreeMap<String, String>,
627 resolver: &Resolver,
628 tls_mode: Option<TlsMode>,
629 internal_tls: bool,
630 metrics: &ServerMetrics,
631 ) -> Result<(), io::Error>
632 where
633 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
634 {
635 if version != VERSION_3 {
636 return conn
637 .send(ErrorResponse::fatal(
638 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
639 "server does not support the client's requested protocol version",
640 ))
641 .await;
642 }
643
644 let Some(user) = params.get("user") else {
645 return conn
646 .send(ErrorResponse::fatal(
647 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
648 "user parameter required",
649 ))
650 .await;
651 };
652
653 if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
654 return conn.send(err).await;
655 }
656
657 let resolved = match resolver.resolve(conn, user, metrics).await {
658 Ok(v) => v,
659 Err(err) => {
660 return conn
661 .send(ErrorResponse::fatal(
662 SqlState::INVALID_PASSWORD,
663 err.to_string(),
664 ))
665 .await;
666 }
667 };
668
669 let _active_guard = resolved
670 .tenant
671 .as_ref()
672 .map(|tenant| metrics.tenant_connections(tenant));
673 let Ok(mut mz_stream) =
674 Self::init_stream(conn, resolved.addr, resolved.password, params, internal_tls).await
675 else {
676 return Ok(());
677 };
678
679 let mut client_counter = CountingConn::new(conn.inner_mut());
680
681 let res = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
684 if let Some(tenant) = &resolved.tenant {
685 metrics
686 .tenant_connections_tx(tenant)
687 .inc_by(u64::cast_from(client_counter.written));
688 metrics
689 .tenant_connections_rx(tenant)
690 .inc_by(u64::cast_from(client_counter.read));
691 }
692 res?;
693
694 Ok(())
695 }
696
697 #[mz_ore::instrument(level = "debug")]
698 async fn init_stream<'a, A>(
699 conn: &'a mut FramedConn<A>,
700 envd_addr: SocketAddr,
701 password: Option<String>,
702 params: BTreeMap<String, String>,
703 internal_tls: bool,
704 ) -> Result<Conn<TcpStream>, anyhow::Error>
705 where
706 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
707 {
708 let mut mz_stream = TcpStream::connect(envd_addr).await?;
709 let mut buf = BytesMut::new();
710
711 let mut mz_stream = if internal_tls {
712 FrontendStartupMessage::SslRequest.encode(&mut buf)?;
713 mz_stream.write_all(&buf).await?;
714 buf.clear();
715 let mut maybe_ssl_request_response = [0u8; 1];
716 let nread =
717 netio::read_exact_or_eof(&mut mz_stream, &mut maybe_ssl_request_response).await?;
718 if nread == 1 && maybe_ssl_request_response == [ACCEPT_SSL_ENCRYPTION] {
719 let mut builder =
721 SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
722 builder.set_verify(SslVerifyMode::NONE);
724 let mut ssl = builder
725 .build()
726 .configure()?
727 .into_ssl(&envd_addr.to_string())?;
728 ssl.set_connect_state();
729 Conn::Ssl(SslStream::new(ssl, mz_stream)?)
730 } else {
731 Conn::Unencrypted(mz_stream)
732 }
733 } else {
734 Conn::Unencrypted(mz_stream)
735 };
736
737 let startup = FrontendStartupMessage::Startup {
739 version: VERSION_3,
740 params,
741 };
742 startup.encode(&mut buf)?;
743 mz_stream.write_all(&buf).await?;
744 let client_stream = conn.inner_mut();
745
746 let mut maybe_auth_frame = [0; 1 + 4 + 4];
750 let nread = netio::read_exact_or_eof(&mut mz_stream, &mut maybe_auth_frame).await?;
751 const AUTH_PASSWORD_CLEARTEXT: [u8; 9] = [b'R', 0, 0, 0, 8, 0, 0, 0, 3];
754 if nread == AUTH_PASSWORD_CLEARTEXT.len()
755 && maybe_auth_frame == AUTH_PASSWORD_CLEARTEXT
756 && password.is_some()
757 {
758 let Some(password) = password else {
760 unreachable!("verified some above");
761 };
762 let password = FrontendMessage::Password { password };
763 buf.clear();
764 password.encode(&mut buf)?;
765 mz_stream.write_all(&buf).await?;
766 mz_stream.flush().await?;
767 } else {
768 client_stream.write_all(&maybe_auth_frame[0..nread]).await?;
771 }
772
773 Ok(mz_stream)
774 }
775}
776
777impl mz_server_core::Server for PgwireBalancer {
778 const NAME: &'static str = "pgwire_balancer";
779
780 fn handle_connection(
781 &self,
782 conn: Connection,
783 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
784 ) -> mz_server_core::ConnectionHandler {
785 let tls = self.tls.clone();
786 let internal_tls = self.internal_tls;
787 let resolver = Arc::clone(&self.resolver);
788 let inner_metrics = self.metrics.clone();
789 let outer_metrics = self.metrics.clone();
790 let cancellation_resolver = Arc::clone(&self.cancellation_resolver);
791 let conn_uuid = epoch_to_uuid_v7(&(self.now)());
792 let peer_addr = conn.peer_addr();
793 conn.uuid_handle().set(conn_uuid);
794 Box::pin(async move {
795 let active_guard = outer_metrics.active_connections();
798 let result: Result<(), anyhow::Error> = async move {
799 let mut conn = Conn::Unencrypted(conn);
800 loop {
801 let message = decode_startup(&mut conn).await?;
802 conn = match message {
803 None => return Ok(()),
807
808 Some(FrontendStartupMessage::Startup {
809 version,
810 mut params,
811 }) => {
812 let mut conn = FramedConn::new(conn);
813 let peer_addr = match peer_addr {
814 Ok(addr) => addr.ip(),
815 Err(e) => {
816 error!("Invalid peer_addr {:?}", e);
817 return Ok(conn
818 .send(ErrorResponse::fatal(
819 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
820 "invalid peer address",
821 ))
822 .await?);
823 }
824 };
825 debug!(%conn_uuid, %peer_addr, "starting new pgwire connection in balancer");
826 let prev =
827 params.insert(CONN_UUID_KEY.to_string(), conn_uuid.to_string());
828 if prev.is_some() {
829 return Ok(conn
830 .send(ErrorResponse::fatal(
831 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
832 format!("invalid parameter '{CONN_UUID_KEY}'"),
833 ))
834 .await?);
835 }
836
837 if let Some(_) = params.insert(MZ_FORWARDED_FOR_KEY.to_string(), peer_addr.to_string().clone()) {
838 return Ok(conn
839 .send(ErrorResponse::fatal(
840 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
841 format!("invalid parameter '{MZ_FORWARDED_FOR_KEY}'"),
842 ))
843 .await?);
844 };
845
846 Self::run(
847 &mut conn,
848 version,
849 params,
850 &resolver,
851 tls.map(|tls| tls.mode),
852 internal_tls,
853 &inner_metrics,
854 )
855 .await?;
856 conn.flush().await?;
857 return Ok(());
858 }
859
860 Some(FrontendStartupMessage::CancelRequest {
861 conn_id,
862 secret_key,
863 }) => {
864 spawn(|| "cancel request", async move {
865 cancel_request(conn_id, secret_key, &cancellation_resolver).await;
866 });
867 return Ok(());
870 }
871
872 Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
873 (Conn::Unencrypted(mut conn), Some(tls)) => {
874 conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
875 let mut ssl_stream =
876 SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
877 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
878 let _ = ssl_stream.get_mut().shutdown().await;
879 return Err(e.into());
880 }
881 Conn::Ssl(ssl_stream)
882 }
883 (mut conn, _) => {
884 conn.write_all(&[REJECT_ENCRYPTION]).await?;
885 conn
886 }
887 },
888
889 Some(FrontendStartupMessage::GssEncRequest) => {
890 conn.write_all(&[REJECT_ENCRYPTION]).await?;
891 conn
892 }
893 }
894 }
895 }
896 .await;
897 drop(active_guard);
898 outer_metrics.connection_status(result.is_ok()).inc();
899 Ok(())
900 })
901 }
902}
903
/// Wraps an I/O stream and counts the bytes successfully read from and
/// written to it.
struct CountingConn<C> {
    /// The wrapped stream.
    inner: C,
    /// Total bytes read from `inner` so far.
    read: usize,
    /// Total bytes written to `inner` so far.
    written: usize,
}

impl<C> CountingConn<C> {
    /// Wraps `inner`, with both counters starting at zero.
    fn new(inner: C) -> Self {
        Self {
            read: 0,
            written: 0,
            inner,
        }
    }
}
920
921impl<C> AsyncRead for CountingConn<C>
922where
923 C: AsyncRead + Unpin,
924{
925 fn poll_read(
926 self: Pin<&mut Self>,
927 cx: &mut std::task::Context<'_>,
928 buf: &mut io::ReadBuf<'_>,
929 ) -> std::task::Poll<std::io::Result<()>> {
930 let counter = self.get_mut();
931 let pin = Pin::new(&mut counter.inner);
932 let bytes = buf.filled().len();
933 let poll = pin.poll_read(cx, buf);
934 let bytes = buf.filled().len() - bytes;
935 if let std::task::Poll::Ready(Ok(())) = poll {
936 counter.read += bytes
937 }
938 poll
939 }
940}
941
942impl<C> AsyncWrite for CountingConn<C>
943where
944 C: AsyncWrite + Unpin,
945{
946 fn poll_write(
947 self: Pin<&mut Self>,
948 cx: &mut std::task::Context<'_>,
949 buf: &[u8],
950 ) -> std::task::Poll<Result<usize, std::io::Error>> {
951 let counter = self.get_mut();
952 let pin = Pin::new(&mut counter.inner);
953 let poll = pin.poll_write(cx, buf);
954 if let std::task::Poll::Ready(Ok(bytes)) = poll {
955 counter.written += bytes
956 }
957 poll
958 }
959
960 fn poll_flush(
961 self: Pin<&mut Self>,
962 cx: &mut std::task::Context<'_>,
963 ) -> std::task::Poll<Result<(), std::io::Error>> {
964 let counter = self.get_mut();
965 let pin = Pin::new(&mut counter.inner);
966 pin.poll_flush(cx)
967 }
968
969 fn poll_shutdown(
970 self: Pin<&mut Self>,
971 cx: &mut std::task::Context<'_>,
972 ) -> std::task::Poll<Result<(), std::io::Error>> {
973 let counter = self.get_mut();
974 let pin = Pin::new(&mut counter.inner);
975 pin.poll_shutdown(cx)
976 }
977}
978
979async fn cancel_request(
996 conn_id: u32,
997 secret_key: u32,
998 cancellation_resolver: &CancellationResolver,
999) {
1000 let suffix = conn_id_org_uuid(conn_id);
1001 let contents = match cancellation_resolver {
1002 CancellationResolver::Directory(dir) => {
1003 let path = dir.join(&suffix);
1004 match std::fs::read_to_string(&path) {
1005 Ok(contents) => contents,
1006 Err(err) => {
1007 error!("could not read cancel file {path:?}: {err}");
1008 return;
1009 }
1010 }
1011 }
1012 CancellationResolver::Static(addr) => addr.to_owned(),
1013 };
1014 let mut all_ips = Vec::new();
1015 for addr in contents.lines() {
1016 let addr = addr.trim();
1017 if addr.is_empty() {
1018 continue;
1019 }
1020 match tokio::net::lookup_host(addr).await {
1021 Ok(ips) => all_ips.extend(ips),
1022 Err(err) => {
1023 error!("{addr} failed resolution: {err}");
1024 }
1025 }
1026 }
1027 let mut buf = BytesMut::with_capacity(16);
1028 let msg = FrontendStartupMessage::CancelRequest {
1029 conn_id,
1030 secret_key,
1031 };
1032 msg.encode(&mut buf).expect("must encode");
1033 let buf = buf.freeze();
1034 for ip in all_ips {
1035 debug!("cancelling {suffix} to {ip}");
1036 let buf = buf.clone();
1037 spawn(|| "cancel request for ip", async move {
1038 let send = async {
1039 let mut stream = TcpStream::connect(&ip).await?;
1040 stream.write_all(&buf).await?;
1041 stream.shutdown().await?;
1042 Ok::<_, io::Error>(())
1043 };
1044 if let Err(err) = send.await {
1045 error!("error mirroring cancel to {ip}: {err}");
1046 }
1047 });
1048 }
1049}
1050
/// Server that proxies HTTPS connections. When TLS is configured, it
/// terminates TLS and picks the upstream from the SNI servername.
struct HttpsBalancer {
    /// DNS resolver used to look up tenants from CNAME records.
    resolver: Arc<StubResolver>,
    /// TLS context for client connections; `None` proxies the raw connection.
    tls: Option<ReloadingSslContext>,
    /// Upstream host template; `{}` is replaced by the servername.
    resolve_template: Arc<str>,
    /// Upstream port.
    port: u16,
    /// Connection metrics for this server.
    metrics: Arc<ServerMetrics>,
    /// Dynamic configs, read for `INJECT_PROXY_PROTOCOL_HEADER_HTTP`.
    configs: ConfigSet,
    /// Whether to use TLS on the upstream connection.
    internal_tls: bool,
}
1060
1061impl HttpsBalancer {
1062 async fn resolve(
1063 resolver: &StubResolver,
1064 resolve_template: &str,
1065 port: u16,
1066 servername: Option<&str>,
1067 ) -> Result<ResolvedAddr, anyhow::Error> {
1068 let addr = match &servername {
1069 Some(servername) => resolve_template.replace("{}", servername),
1070 None => resolve_template.to_string(),
1071 };
1072 debug!("https address: {addr}");
1073
1074 let tenant = resolver.tenant(&addr).await;
1087
1088 let envd_addr = lookup(&format!("{addr}:{port}")).await?;
1090
1091 Ok(ResolvedAddr {
1092 addr: envd_addr,
1093 password: None,
1094 tenant,
1095 })
1096 }
1097}
1098
/// Tenant lookup for the DNS stub resolver.
trait StubResolverExt {
    /// Returns the tenant ID derived from `addr`'s CNAME record. Returns
    /// `None` if no tenant can be determined.
    async fn tenant(&self, addr: &str) -> Option<String>;
}
1102
1103impl StubResolverExt for StubResolver {
1104 async fn tenant(&self, addr: &str) -> Option<String> {
1107 let Ok(dname) = Name::<Vec<_>>::from_str(addr) else {
1108 return None;
1109 };
1110 debug!("resolving tenant for {:?}", addr);
1111 let lookup = self.query((dname, Rtype::CNAME)).await;
1113 if let Ok(lookup) = lookup {
1114 if let Ok(answer) = lookup.answer() {
1115 let res = answer.limit_to::<AllRecordData<_, _>>();
1116 for record in res {
1117 let Ok(record) = record else {
1118 continue;
1119 };
1120 if record.rtype() != Rtype::CNAME {
1121 continue;
1122 }
1123 let cname = record.data();
1124 let cname = cname.to_string();
1125 debug!("cname: {cname}");
1126 return extract_tenant_from_cname(&cname);
1127 }
1128 }
1129 }
1130 None
1131 }
1132}
1133
1134fn extract_tenant_from_cname(cname: &str) -> Option<String> {
1136 let mut parts = cname.split('.');
1137 let _service = parts.next();
1138 let Some(namespace) = parts.next() else {
1139 return None;
1140 };
1141 let Some((_, namespace)) = namespace.split_once('-') else {
1143 return None;
1144 };
1145 let Some((tenant, _)) = namespace.rsplit_once('-') else {
1147 return None;
1148 };
1149 let Ok(tenant) = Uuid::parse_str(tenant) else {
1152 error!("cname tenant not a uuid: {tenant}");
1153 return None;
1154 };
1155 Some(tenant.to_string())
1156}
1157
1158impl mz_server_core::Server for HttpsBalancer {
1159 const NAME: &'static str = "https_balancer";
1160
1161 fn handle_connection(
1163 &self,
1164 conn: Connection,
1165 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
1166 ) -> mz_server_core::ConnectionHandler {
1167 let tls_context = self.tls.clone();
1168 let internal_tls = self.internal_tls.clone();
1169 let resolver = Arc::clone(&self.resolver);
1170 let resolve_template = Arc::clone(&self.resolve_template);
1171 let port = self.port;
1172 let inner_metrics = Arc::clone(&self.metrics);
1173 let outer_metrics = Arc::clone(&self.metrics);
1174 let peer_addr = conn.peer_addr();
1175 let inject_proxy_headers = INJECT_PROXY_PROTOCOL_HEADER_HTTP.get(&self.configs);
1176 Box::pin(async move {
1177 let active_guard = inner_metrics.active_connections();
1178 let result: Result<_, anyhow::Error> = Box::pin(async move {
1179 let peer_addr = peer_addr.context("fetching peer addr")?;
1180 let (client_stream, servername): (Box<dyn ClientStream>, Option<String>) =
1181 match tls_context {
1182 Some(tls_context) => {
1183 let mut ssl_stream =
1184 SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
1185 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1186 let _ = ssl_stream.get_mut().shutdown().await;
1187 return Err(e.into());
1188 }
1189 let servername: Option<String> =
1190 ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1191 match sn.split_once('.') {
1192 Some((left, _right)) => left,
1193 None => sn,
1194 }
1195 .into()
1196 });
1197 debug!("Found sni servername: {servername:?} (https)");
1198 (Box::new(ssl_stream), servername)
1199 }
1200 _ => (Box::new(conn), None),
1201 };
1202 let resolved =
1203 Self::resolve(&resolver, &resolve_template, port, servername.as_deref())
1204 .await?;
1205 let inner_active_guard = resolved
1206 .tenant
1207 .as_ref()
1208 .map(|tenant| inner_metrics.tenant_connections(tenant));
1209
1210 let mut mz_stream = TcpStream::connect(resolved.addr).await?;
1211
1212 if inject_proxy_headers {
1213 let addrs = ProxiedAddress::stream(peer_addr, resolved.addr);
1215 let header = ProxyHeader::with_address(addrs);
1216 let mut buf = [0u8; 1024];
1217 let len = header.encode_to_slice_v2(&mut buf)?;
1218 mz_stream.write_all(&buf[..len]).await?;
1219 }
1220
1221 let mut mz_stream = if internal_tls {
1222 let mut builder =
1224 SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
1225 builder.set_verify(SslVerifyMode::NONE);
1227 let mut ssl = builder
1228 .build()
1229 .configure()?
1230 .into_ssl(&resolved.addr.to_string())?;
1231 ssl.set_connect_state();
1232 Conn::Ssl(SslStream::new(ssl, mz_stream)?)
1233 } else {
1234 Conn::Unencrypted(mz_stream)
1235 };
1236
1237 let mut client_counter = CountingConn::new(client_stream);
1238
1239 let _ = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
1243 if let Some(tenant) = &resolved.tenant {
1244 inner_metrics
1245 .tenant_connections_tx(tenant)
1246 .inc_by(u64::cast_from(client_counter.written));
1247 inner_metrics
1248 .tenant_connections_rx(tenant)
1249 .inc_by(u64::cast_from(client_counter.read));
1250 }
1251 drop(inner_active_guard);
1252 Ok(())
1253 })
1254 .await;
1255 drop(active_guard);
1256 outer_metrics.connection_status(result.is_ok()).inc();
1257 if let Err(e) = result {
1258 debug!("connection error: {e}");
1259 }
1260 Ok(())
1261 })
1262 }
1263}
1264
/// Configuration for resolving a pgwire backend from the client's TLS SNI servername.
#[derive(Debug)]
pub struct SniResolver {
    /// DNS resolver used to look up the tenant for the templated address.
    pub resolver: StubResolver,
    /// Address template; `{}` is replaced with the first label of the SNI servername.
    pub template: String,
    /// Port appended (as `:{port}`) to the templated address before it is looked up.
    pub port: u16,
}
1271
1272trait ClientStream: AsyncRead + AsyncWrite + Unpin + Send {}
1273impl<T: AsyncRead + AsyncWrite + Unpin + Send> ClientStream for T {}
1274
/// Strategy for choosing the backend address of a pgwire connection.
#[derive(Debug)]
pub enum Resolver {
    /// Always connect to this address (looked up via `lookup`).
    Static(String),
    /// Resolve per tenant.
    ///
    /// If the client sent an SNI servername and an `SniResolver` is configured, the
    /// SNI path is used. Otherwise the client's password is authenticated with the
    /// `FronteggResolver`.
    MultiTenant(FronteggResolver, Option<SniResolver>),
}
1280
1281impl Resolver {
1282 async fn resolve<A>(
1283 &self,
1284 conn: &mut FramedConn<A>,
1285 user: &str,
1286 metrics: &ServerMetrics,
1287 ) -> Result<ResolvedAddr, anyhow::Error>
1288 where
1289 A: AsyncRead + AsyncWrite + Unpin,
1290 {
1291 match self {
1292 Resolver::MultiTenant(
1293 FronteggResolver {
1294 auth,
1295 addr_template,
1296 },
1297 sni_resolver,
1298 ) => {
1299 let servername = match conn.inner() {
1300 Conn::Ssl(ssl_stream) => {
1301 ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1302 match sn.split_once('.') {
1303 Some((left, _right)) => left,
1304 None => sn,
1305 }
1306 })
1307 }
1308 Conn::Unencrypted(_) => None,
1309 };
1310 let has_sni = servername.is_some();
1311 let resolved_addr = match (servername, sni_resolver) {
1313 (
1314 Some(servername),
1315 Some(SniResolver {
1316 resolver: stub_resolver,
1317 template: sni_addr_template,
1318 port,
1319 }),
1320 ) => {
1321 let sni_addr = sni_addr_template.replace("{}", servername);
1322 let tenant = stub_resolver.tenant(&sni_addr).await;
1323 let sni_addr = format!("{sni_addr}:{port}");
1324 let addr = lookup(&sni_addr).await?;
1325 if tenant.is_some() {
1326 debug!("SNI header found for tenant {:?}", tenant);
1327 }
1328 ResolvedAddr {
1329 addr,
1330 password: None,
1331 tenant,
1332 }
1333 }
1334 _ => {
1335 conn.send(BackendMessage::AuthenticationCleartextPassword)
1336 .await?;
1337 conn.flush().await?;
1338 let password = match conn.recv().await? {
1339 Some(FrontendMessage::Password { password }) => password,
1340 _ => anyhow::bail!("expected Password message"),
1341 };
1342
1343 let auth_response = auth.authenticate(user, &password).await;
1344 let auth_session = match auth_response {
1345 Ok(auth_session) => auth_session,
1346 Err(e) => {
1347 warn!("pgwire connection failed authentication: {}", e);
1348 anyhow::bail!("invalid password");
1350 }
1351 };
1352
1353 let addr =
1354 addr_template.replace("{}", &auth_session.tenant_id().to_string());
1355 let addr = lookup(&addr).await?;
1356 let tenant = Some(auth_session.tenant_id().to_string());
1357 if tenant.is_some() {
1358 debug!("SNI header NOT found for tenant {:?}", tenant);
1359 }
1360 ResolvedAddr {
1361 addr,
1362 password: Some(password),
1363 tenant,
1364 }
1365 }
1366 };
1367 metrics
1368 .tenant_pgwire_sni_count(
1369 resolved_addr.tenant.as_deref().unwrap_or("unknown"),
1370 has_sni,
1371 )
1372 .inc();
1373
1374 Ok(resolved_addr)
1375 }
1376 Resolver::Static(addr) => {
1377 let addr = lookup(addr).await?;
1378 Ok(ResolvedAddr {
1379 addr,
1380 password: None,
1381 tenant: None,
1382 })
1383 }
1384 }
1385 }
1386}
1387
1388async fn lookup(name: &str) -> Result<SocketAddr, anyhow::Error> {
1390 let mut addrs = tokio::net::lookup_host(name).await?;
1391 match addrs.next() {
1392 Some(addr) => Ok(addr),
1393 None => {
1394 error!("{name} did not resolve to any addresses");
1395 anyhow::bail!("internal error")
1396 }
1397 }
1398}
1399
/// Resolves the pgwire backend by authenticating the client's password with Frontegg.
#[derive(Debug)]
pub struct FronteggResolver {
    /// Authenticator that validates the cleartext password and yields the tenant ID.
    pub auth: FronteggAuthentication,
    /// Address template; `{}` is replaced with the authenticated tenant ID.
    pub addr_template: String,
}
1405
/// Backend chosen for a pgwire connection by `Resolver::resolve`.
///
/// `Debug` is implemented by hand so that the client's cleartext password is
/// never written out if this value is ever `{:?}`-formatted (e.g. into logs).
struct ResolvedAddr {
    /// Socket address of the backend to connect to.
    addr: SocketAddr,
    /// Cleartext password collected from the client during password-based
    /// resolution; `None` when resolution did not prompt for one.
    password: Option<String>,
    /// Tenant the connection was attributed to, when known.
    tenant: Option<String>,
}

impl std::fmt::Debug for ResolvedAddr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResolvedAddr")
            .field("addr", &self.addr)
            // Reveal only whether a password is present, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("tenant", &self.tenant)
            .finish()
    }
}
1412
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `extract_tenant_from_cname`.
    ///
    /// Each case pairs a CNAME with the normalized (lowercase, dashed) tenant UUID
    /// it should yield, or `None` if the name should be rejected.
    #[mz_ore::test]
    fn test_tenant() {
        const CASES: &[(&str, Option<&str>)] = &[
            // Empty input.
            ("", None),
            // Canonical service/environment name.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Different service label and domain suffix.
            (
                "service.something-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.ssvvcc.cloister.faraway",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // UUID written without dashes.
            (
                "environmentd.environment-58cd23ffa4d74bd0ad85a6ff29cc86c3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Multi-digit numeric suffix after the UUID.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-1234.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Uppercase hex digits.
            (
                "environmentd.environment-58CD23FF-A4D7-4BD0-AD85-A6FF29CC86C3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // No numeric suffix after the UUID.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3.svc.cluster.local",
                None,
            ),
            // No leading service label.
            (
                "environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
            // First UUID group is one hex digit short.
            (
                "environmentd.environment-8cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
        ];
        for &(name, expect) in CASES {
            let cname = extract_tenant_from_cname(name);
            assert_eq!(
                cname.as_deref(),
                expect,
                "{name} got {cname:?} expected {expect:?}"
            );
        }
    }
}