1mod codec;
20mod dyncfgs;
21
22use std::collections::BTreeMap;
23use std::net::SocketAddr;
24use std::path::PathBuf;
25use std::pin::Pin;
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use anyhow::Context;
31use axum::response::IntoResponse;
32use axum::{Router, routing};
33use bytes::BytesMut;
34use domain::base::{Name, Rtype};
35use domain::rdata::AllRecordData;
36use domain::resolv::StubResolver;
37use futures::TryFutureExt;
38use futures::stream::BoxStream;
39use hyper::StatusCode;
40use hyper_util::rt::TokioIo;
41use launchdarkly_server_sdk as ld;
42use mz_build_info::{BuildInfo, build_info};
43use mz_dyncfg::ConfigSet;
44use mz_frontegg_auth::Authenticator as FronteggAuthentication;
45use mz_ore::cast::CastFrom;
46use mz_ore::id_gen::conn_id_org_uuid;
47use mz_ore::metrics::{ComputedGauge, IntCounter, IntGauge, MetricsRegistry};
48use mz_ore::netio::AsyncReady;
49use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
50use mz_ore::task::{JoinSetExt, spawn};
51use mz_ore::tracing::TracingHandle;
52use mz_ore::{metric, netio};
53use mz_pgwire_common::{
54 ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ErrorResponse, FrontendMessage,
55 FrontendStartupMessage, MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, VERSION_3, decode_startup,
56};
57use mz_server_core::{
58 Connection, ConnectionStream, ListenerHandle, ReloadTrigger, ReloadingSslContext,
59 ReloadingTlsConfig, ServeConfig, ServeDyncfg, TlsCertConfig, TlsMode, listen,
60};
61use openssl::ssl::{NameType, Ssl, SslConnector, SslMethod, SslVerifyMode};
62use prometheus::{IntCounterVec, IntGaugeVec};
63use proxy_header::{ProxiedAddress, ProxyHeader};
64use semver::Version;
65use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
66use tokio::net::TcpStream;
67use tokio::sync::oneshot;
68use tokio::task::JoinSet;
69use tokio_metrics::TaskMetrics;
70use tokio_openssl::SslStream;
71use tokio_postgres::error::SqlState;
72use tower::Service;
73use tracing::{debug, error, warn};
74use uuid::Uuid;
75
76use crate::codec::{BackendMessage, FramedConn};
77use crate::dyncfgs::{
78 INJECT_PROXY_PROTOCOL_HEADER_HTTP, SIGTERM_CONNECTION_WAIT, SIGTERM_LISTEN_WAIT,
79 has_tracing_config_update, tracing_config,
80};
81
82pub const BUILD_INFO: BuildInfo = build_info!();
84
85pub struct BalancerConfig {
86 build_version: Version,
88 internal_http_listen_addr: SocketAddr,
90 pgwire_listen_addr: SocketAddr,
92 https_listen_addr: SocketAddr,
94 cancellation_resolver: CancellationResolver,
96 resolver: Resolver,
98 https_sni_addr_template: String,
99 tls: Option<TlsCertConfig>,
100 internal_tls: bool,
101 metrics_registry: MetricsRegistry,
102 reload_certs: BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>,
103 launchdarkly_sdk_key: Option<String>,
104 config_sync_file_path: Option<PathBuf>,
105 config_sync_timeout: Duration,
106 config_sync_loop_interval: Option<Duration>,
107 cloud_provider: Option<String>,
108 cloud_provider_region: Option<String>,
109 tracing_handle: TracingHandle,
110 default_configs: Vec<(String, String)>,
111}
112
113impl BalancerConfig {
114 pub fn new(
115 build_info: &BuildInfo,
116 internal_http_listen_addr: SocketAddr,
117 pgwire_listen_addr: SocketAddr,
118 https_listen_addr: SocketAddr,
119 cancellation_resolver: CancellationResolver,
120 resolver: Resolver,
121 https_sni_addr_template: String,
122 tls: Option<TlsCertConfig>,
123 internal_tls: bool,
124 metrics_registry: MetricsRegistry,
125 reload_certs: ReloadTrigger,
126 launchdarkly_sdk_key: Option<String>,
127 config_sync_file: Option<PathBuf>,
128 config_sync_timeout: Duration,
129 config_sync_loop_interval: Option<Duration>,
130 cloud_provider: Option<String>,
131 cloud_provider_region: Option<String>,
132 tracing_handle: TracingHandle,
133 default_configs: Vec<(String, String)>,
134 ) -> Self {
135 Self {
136 build_version: build_info.semver_version(),
137 internal_http_listen_addr,
138 pgwire_listen_addr,
139 https_listen_addr,
140 cancellation_resolver,
141 resolver,
142 https_sni_addr_template,
143 tls,
144 internal_tls,
145 metrics_registry,
146 reload_certs,
147 launchdarkly_sdk_key,
148 config_sync_file_path: config_sync_file,
149 config_sync_timeout,
150 config_sync_loop_interval,
151 cloud_provider,
152 cloud_provider_region,
153 tracing_handle,
154 default_configs,
155 }
156 }
157}
158
159#[derive(Debug)]
161pub struct BalancerMetrics {
162 _uptime: ComputedGauge,
163}
164
165impl BalancerMetrics {
166 pub fn new(cfg: &BalancerConfig) -> Self {
168 let start = Instant::now();
169 let uptime = cfg.metrics_registry.register_computed_gauge(
170 metric!(
171 name: "mz_balancer_metadata_seconds",
172 help: "server uptime, labels are build metadata",
173 const_labels: {
174 "version" => cfg.build_version,
175 "build_type" => if cfg!(release) { "release" } else { "debug" }
176 },
177 ),
178 move || start.elapsed().as_secs_f64(),
179 );
180 BalancerMetrics { _uptime: uptime }
181 }
182}
183
184pub struct BalancerService {
185 cfg: BalancerConfig,
186 pub pgwire: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
187 pub https: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
188 pub internal_http: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
189 _metrics: BalancerMetrics,
190 configs: ConfigSet,
191}
192
193impl BalancerService {
194 pub async fn new(cfg: BalancerConfig) -> Result<Self, anyhow::Error> {
195 let pgwire = listen(&cfg.pgwire_listen_addr).await?;
196 let https = listen(&cfg.https_listen_addr).await?;
197 let internal_http = listen(&cfg.internal_http_listen_addr).await?;
198 let metrics = BalancerMetrics::new(&cfg);
199 let mut configs = ConfigSet::default();
200 configs = dyncfgs::all_dyncfgs(configs);
201 dyncfgs::set_defaults(&configs, cfg.default_configs.clone())?;
202 let tracing_handle = cfg.tracing_handle.clone();
203 match (
205 cfg.launchdarkly_sdk_key.as_deref(),
206 cfg.config_sync_file_path.as_deref(),
207 ) {
208 (Some(key), None) => {
209 mz_dyncfg_launchdarkly::sync_launchdarkly_to_configset(
210 configs.clone(),
211 &BUILD_INFO,
212 |builder| {
213 let region = cfg
214 .cloud_provider_region
215 .clone()
216 .unwrap_or_else(|| String::from("unknown"));
217 if let Some(provider) = cfg.cloud_provider.clone() {
218 builder.add_context(
219 ld::ContextBuilder::new(format!(
220 "{}/{}/{}",
221 provider, region, cfg.build_version
222 ))
223 .kind("balancer")
224 .set_string("provider", provider)
225 .set_string("region", region)
226 .set_string("version", cfg.build_version.to_string())
227 .build()
228 .map_err(|e| anyhow::anyhow!(e))?,
229 );
230 } else {
231 builder.add_context(
232 ld::ContextBuilder::new(format!(
233 "{}/{}/{}",
234 "unknown", region, cfg.build_version
235 ))
236 .anonymous(true) .kind("balancer")
238 .set_string("provider", "unknown")
239 .set_string("region", region)
240 .set_string("version", cfg.build_version.to_string())
241 .build()
242 .map_err(|e| anyhow::anyhow!(e))?,
243 );
244 }
245 Ok(())
246 },
247 Some(key),
248 cfg.config_sync_timeout,
249 cfg.config_sync_loop_interval,
250 move |updates, configs| {
251 if has_tracing_config_update(updates) {
252 match tracing_config(configs) {
253 Ok(parameters) => parameters.apply(&tracing_handle),
254 Err(err) => warn!("unable to update tracing: {err}"),
255 }
256 }
257 },
258 )
259 .await
260 .inspect_err(|e| warn!("LaunchDarkly sync error: {e}"))
261 .ok();
262 }
263 (None, Some(path)) => {
264 mz_dyncfg_file::sync_file_to_configset(
265 configs.clone(),
266 path,
267 cfg.config_sync_timeout,
268 cfg.config_sync_loop_interval,
269 move |updates, configs| {
270 if has_tracing_config_update(updates) {
271 match tracing_config(configs) {
272 Ok(parameters) => parameters.apply(&tracing_handle),
273 Err(err) => warn!("unable to update tracing: {err}"),
274 }
275 }
276 },
277 )
278 .await
279 .inspect_err(|e| warn!("File config sync error: {e}"))
285 .ok();
286 }
287 (Some(_), Some(_)) => panic!(
288 "must provide either config_sync_file_path or launchdarkly_sdk_key for config syncing",
289 ),
290 (None, None) => {}
291 };
292 Ok(Self {
293 cfg,
294 pgwire,
295 https,
296 internal_http,
297 _metrics: metrics,
298 configs,
299 })
300 }
301
302 pub async fn serve(self) -> Result<(), anyhow::Error> {
303 let (pgwire_tls, https_tls) = match &self.cfg.tls {
304 Some(tls) => {
305 let context = tls.reloading_context(self.cfg.reload_certs)?;
306 (
307 Some(ReloadingTlsConfig {
308 context: context.clone(),
309 mode: TlsMode::Require,
310 }),
311 Some(context),
312 )
313 }
314 None => (None, None),
315 };
316
317 let metrics = ServerMetricsConfig::register_into(&self.cfg.metrics_registry);
318
319 let mut set = JoinSet::new();
320 let mut server_handles = Vec::new();
321 let pgwire_addr = self.pgwire.0.local_addr();
322 let https_addr = self.https.0.local_addr();
323 let internal_http_addr = self.internal_http.0.local_addr();
324
325 {
326 let pgwire = PgwireBalancer {
327 resolver: Arc::new(self.cfg.resolver),
328 cancellation_resolver: Arc::new(self.cfg.cancellation_resolver),
329 tls: pgwire_tls,
330 internal_tls: self.cfg.internal_tls,
331 metrics: ServerMetrics::new(metrics.clone(), "pgwire"),
332 now: SYSTEM_TIME.clone(),
333 };
334 let (handle, stream) = self.pgwire;
335 server_handles.push(handle);
336 set.spawn_named(|| "pgwire_stream", {
337 let config_set = self.configs.clone();
338 async move {
339 mz_server_core::serve(ServeConfig {
340 server: pgwire,
341 conns: stream,
342 dyncfg: Some(ServeDyncfg {
343 config_set,
344 sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
345 }),
346 })
347 .await;
348 warn!("pgwire server exited");
349 }
350 });
351 }
352 {
353 let Some((addr, port)) = self.cfg.https_sni_addr_template.split_once(':') else {
354 panic!("expected port in https_addr_template");
355 };
356 let port: u16 = port.parse().expect("unexpected port");
357 let resolver = StubResolver::new();
358 let https = HttpsBalancer {
359 resolver: Arc::from(resolver),
360 tls: https_tls,
361 resolve_template: Arc::from(addr),
362 port,
363 metrics: Arc::from(ServerMetrics::new(metrics, "https")),
364 configs: self.configs.clone(),
365 internal_tls: self.cfg.internal_tls,
366 };
367 let (handle, stream) = self.https;
368 server_handles.push(handle);
369 set.spawn_named(|| "https_stream", {
370 let config_set = self.configs.clone();
371 async move {
372 mz_server_core::serve(ServeConfig {
373 server: https,
374 conns: stream,
375 dyncfg: Some(ServeDyncfg {
376 config_set,
377 sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
378 }),
379 })
380 .await;
381 warn!("https server exited");
382 }
383 });
384 }
385 {
386 let router = Router::new()
387 .route(
388 "/metrics",
389 routing::get(move || async move {
390 mz_http_util::handle_prometheus(&self.cfg.metrics_registry).await
391 }),
392 )
393 .route(
394 "/api/livez",
395 routing::get(mz_http_util::handle_liveness_check),
396 )
397 .route("/api/readyz", routing::get(handle_readiness_check));
398 let internal_http = InternalHttpServer { router };
399 let (handle, stream) = self.internal_http;
400 server_handles.push(handle);
401 set.spawn_named(|| "internal_http_stream", async move {
402 mz_server_core::serve(ServeConfig {
403 server: internal_http,
404 conns: stream,
405 dyncfg: None,
408 })
409 .await;
410 warn!("internal_http server exited");
411 });
412 }
413 #[cfg(unix)]
414 {
415 let mut sigterm =
416 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
417 set.spawn_named(|| "sigterm_handler", async move {
418 sigterm.recv().await;
419 let wait = SIGTERM_LISTEN_WAIT.get(&self.configs);
420 warn!("received signal TERM - delaying for {:?}!", wait);
421 tokio::time::sleep(wait).await;
422 warn!("sigterm delay complete, dropping server handles");
423 drop(server_handles);
424 });
425 }
426
427 println!("balancerd {} listening...", BUILD_INFO.human_version(None));
428 println!(" TLS enabled: {}", self.cfg.tls.is_some());
429 println!(" pgwire address: {}", pgwire_addr);
430 println!(" HTTPS address: {}", https_addr);
431 println!(" internal HTTP address: {}", internal_http_addr);
432
433 while let Some(res) = set.join_next().await {
435 if let Err(err) = res {
436 error!("serving task failed: {err}")
437 }
438 }
439 Ok(())
440 }
441}
442
443#[allow(clippy::unused_async)]
444async fn handle_readiness_check() -> impl IntoResponse {
445 (StatusCode::OK, "ready")
446}
447
448struct InternalHttpServer {
449 router: Router,
450}
451
452impl mz_server_core::Server for InternalHttpServer {
453 const NAME: &'static str = "internal_http";
454
455 fn handle_connection(
457 &self,
458 conn: Connection,
459 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
460 ) -> mz_server_core::ConnectionHandler {
461 let router = self.router.clone();
462 let service = hyper::service::service_fn(move |req| router.clone().call(req));
463 let conn = TokioIo::new(conn);
464
465 Box::pin(async {
466 let http = hyper::server::conn::http1::Builder::new();
467 http.serve_connection(conn, service).err_into().await
468 })
469 }
470}
471
472struct GaugeGuard {
475 gauge: IntGauge,
476}
477
478impl From<IntGauge> for GaugeGuard {
479 fn from(gauge: IntGauge) -> Self {
480 let _self = Self { gauge };
481 _self.gauge.inc();
482 _self
483 }
484}
485
486impl Drop for GaugeGuard {
487 fn drop(&mut self) {
488 self.gauge.dec();
489 }
490}
491
492#[derive(Clone, Debug)]
493struct ServerMetricsConfig {
494 connection_status: IntCounterVec,
495 active_connections: IntGaugeVec,
496 tenant_connections: IntGaugeVec,
497 tenant_connection_rx: IntCounterVec,
498 tenant_connection_tx: IntCounterVec,
499 tenant_pgwire_sni_count: IntCounterVec,
500}
501
502impl ServerMetricsConfig {
503 fn register_into(registry: &MetricsRegistry) -> Self {
504 let connection_status = registry.register(metric!(
505 name: "mz_balancer_connection_status",
506 help: "Count of completed network connections, by status",
507 var_labels: ["source", "status"],
508 ));
509 let active_connections = registry.register(metric!(
510 name: "mz_balancer_connection_active",
511 help: "Count of currently open network connections.",
512 var_labels: ["source"],
513 ));
514 let tenant_connections = registry.register(metric!(
515 name: "mz_balancer_tenant_connection_active",
516 help: "Count of opened network connections by tenant.",
517 var_labels: ["source", "tenant"]
518 ));
519 let tenant_connection_rx = registry.register(metric!(
520 name: "mz_balancer_tenant_connection_rx",
521 help: "Number of bytes received from a client for a tenant.",
522 var_labels: ["source", "tenant"],
523 ));
524 let tenant_connection_tx = registry.register(metric!(
525 name: "mz_balancer_tenant_connection_tx",
526 help: "Number of bytes sent to a client for a tenant.",
527 var_labels: ["source", "tenant"],
528 ));
529 let tenant_pgwire_sni_count = registry.register(metric!(
530 name: "mz_balancer_tenant_pgwire_sni_count",
531 help: "Count of pgwire connections that have and do not have SNI available per tenant.",
532 var_labels: ["tenant", "has_sni"],
533 ));
534 Self {
535 connection_status,
536 active_connections,
537 tenant_connections,
538 tenant_connection_rx,
539 tenant_connection_tx,
540 tenant_pgwire_sni_count,
541 }
542 }
543}
544
545#[derive(Clone, Debug)]
546struct ServerMetrics {
547 inner: ServerMetricsConfig,
548 source: &'static str,
549}
550
551impl ServerMetrics {
552 fn new(inner: ServerMetricsConfig, source: &'static str) -> Self {
553 let self_ = Self { inner, source };
554
555 self_.connection_status(false);
558 self_.connection_status(true);
559 drop(self_.active_connections());
560
561 self_
562 }
563
564 fn connection_status(&self, is_ok: bool) -> IntCounter {
565 self.inner
566 .connection_status
567 .with_label_values(&[self.source, Self::status_label(is_ok)])
568 }
569
570 fn active_connections(&self) -> GaugeGuard {
571 self.inner
572 .active_connections
573 .with_label_values(&[self.source])
574 .into()
575 }
576
577 fn tenant_connections(&self, tenant: &str) -> GaugeGuard {
578 self.inner
579 .tenant_connections
580 .with_label_values(&[self.source, tenant])
581 .into()
582 }
583
584 fn tenant_connections_rx(&self, tenant: &str) -> IntCounter {
585 self.inner
586 .tenant_connection_rx
587 .with_label_values(&[self.source, tenant])
588 }
589
590 fn tenant_connections_tx(&self, tenant: &str) -> IntCounter {
591 self.inner
592 .tenant_connection_tx
593 .with_label_values(&[self.source, tenant])
594 }
595
596 fn tenant_pgwire_sni_count(&self, tenant: &str, has_sni: bool) -> IntCounter {
597 self.inner
598 .tenant_pgwire_sni_count
599 .with_label_values(&[tenant, &has_sni.to_string()])
600 }
601
602 fn status_label(is_ok: bool) -> &'static str {
603 if is_ok { "success" } else { "error" }
604 }
605}
606
/// How a pgwire cancel request finds the addresses it should be forwarded to.
pub enum CancellationResolver {
    /// A directory holding one file per connection; the file is read and each non-empty
    /// line is treated as an address to resolve.
    Directory(PathBuf),
    /// A fixed address string, processed line by line in the same way as a file's contents.
    Static(String),
}
611
612struct PgwireBalancer {
613 tls: Option<ReloadingTlsConfig>,
614 internal_tls: bool,
615 cancellation_resolver: Arc<CancellationResolver>,
616 resolver: Arc<Resolver>,
617 metrics: ServerMetrics,
618 now: NowFn,
619}
620
621impl PgwireBalancer {
622 #[mz_ore::instrument(level = "debug")]
623 async fn run<'a, A>(
624 conn: &'a mut FramedConn<A>,
625 version: i32,
626 params: BTreeMap<String, String>,
627 resolver: &Resolver,
628 tls_mode: Option<TlsMode>,
629 internal_tls: bool,
630 metrics: &ServerMetrics,
631 ) -> Result<(), io::Error>
632 where
633 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
634 {
635 if version != VERSION_3 {
636 return conn
637 .send(ErrorResponse::fatal(
638 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
639 "server does not support the client's requested protocol version",
640 ))
641 .await;
642 }
643
644 let Some(user) = params.get("user") else {
645 return conn
646 .send(ErrorResponse::fatal(
647 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
648 "user parameter required",
649 ))
650 .await;
651 };
652
653 if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
654 return conn.send(err).await;
655 }
656
657 let resolved = match resolver.resolve(conn, user, metrics).await {
658 Ok(v) => v,
659 Err(err) => {
660 return conn
661 .send(ErrorResponse::fatal(
662 SqlState::INVALID_PASSWORD,
663 err.to_string(),
664 ))
665 .await;
666 }
667 };
668
669 let _active_guard = resolved
670 .tenant
671 .as_ref()
672 .map(|tenant| metrics.tenant_connections(tenant));
673 let Ok(mut mz_stream) =
674 Self::init_stream(conn, resolved.addr, resolved.password, params, internal_tls).await
675 else {
676 return Ok(());
677 };
678
679 let mut client_counter = CountingConn::new(conn.inner_mut());
680
681 let res = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
684 if let Some(tenant) = &resolved.tenant {
685 metrics
686 .tenant_connections_tx(tenant)
687 .inc_by(u64::cast_from(client_counter.written));
688 metrics
689 .tenant_connections_rx(tenant)
690 .inc_by(u64::cast_from(client_counter.read));
691 }
692 res?;
693
694 Ok(())
695 }
696
697 #[mz_ore::instrument(level = "debug")]
698 async fn init_stream<'a, A>(
699 conn: &'a mut FramedConn<A>,
700 envd_addr: SocketAddr,
701 password: Option<String>,
702 params: BTreeMap<String, String>,
703 internal_tls: bool,
704 ) -> Result<Conn<TcpStream>, anyhow::Error>
705 where
706 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
707 {
708 let mut mz_stream = TcpStream::connect(envd_addr).await?;
709 let mut buf = BytesMut::new();
710
711 let mut mz_stream = if internal_tls {
712 FrontendStartupMessage::SslRequest.encode(&mut buf)?;
713 mz_stream.write_all(&buf).await?;
714 buf.clear();
715 let mut maybe_ssl_request_response = [0u8; 1];
716 let nread =
717 netio::read_exact_or_eof(&mut mz_stream, &mut maybe_ssl_request_response).await?;
718 if nread == 1 && maybe_ssl_request_response == [ACCEPT_SSL_ENCRYPTION] {
719 let mut builder =
721 SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
722 builder.set_verify(SslVerifyMode::NONE);
724 let mut ssl = builder
725 .build()
726 .configure()?
727 .into_ssl(&envd_addr.to_string())?;
728 ssl.set_connect_state();
729 Conn::Ssl(SslStream::new(ssl, mz_stream)?)
730 } else {
731 Conn::Unencrypted(mz_stream)
732 }
733 } else {
734 Conn::Unencrypted(mz_stream)
735 };
736
737 let startup = FrontendStartupMessage::Startup {
739 version: VERSION_3,
740 params,
741 };
742 startup.encode(&mut buf)?;
743 mz_stream.write_all(&buf).await?;
744 let client_stream = conn.inner_mut();
745
746 if password.is_none() {
762 return Ok(mz_stream);
763 }
764
765 let mut maybe_auth_frame = [0; 1 + 4 + 4];
769 let nread = netio::read_exact_or_eof(&mut mz_stream, &mut maybe_auth_frame).await?;
770 const AUTH_PASSWORD_CLEARTEXT: [u8; 9] = [b'R', 0, 0, 0, 8, 0, 0, 0, 3];
773 if nread == AUTH_PASSWORD_CLEARTEXT.len()
774 && maybe_auth_frame == AUTH_PASSWORD_CLEARTEXT
775 && password.is_some()
776 {
777 let Some(password) = password else {
779 unreachable!("verified some above");
780 };
781 let password = FrontendMessage::Password { password };
782 buf.clear();
783 password.encode(&mut buf)?;
784 mz_stream.write_all(&buf).await?;
785 mz_stream.flush().await?;
786 } else {
787 client_stream.write_all(&maybe_auth_frame[0..nread]).await?;
790 }
791
792 Ok(mz_stream)
793 }
794}
795
796impl mz_server_core::Server for PgwireBalancer {
797 const NAME: &'static str = "pgwire_balancer";
798
799 fn handle_connection(
800 &self,
801 conn: Connection,
802 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
803 ) -> mz_server_core::ConnectionHandler {
804 let tls = self.tls.clone();
805 let internal_tls = self.internal_tls;
806 let resolver = Arc::clone(&self.resolver);
807 let inner_metrics = self.metrics.clone();
808 let outer_metrics = self.metrics.clone();
809 let cancellation_resolver = Arc::clone(&self.cancellation_resolver);
810 let conn_uuid = epoch_to_uuid_v7(&(self.now)());
811 let peer_addr = conn.peer_addr();
812 conn.uuid_handle().set(conn_uuid);
813 Box::pin(async move {
814 let active_guard = outer_metrics.active_connections();
817 let result: Result<(), anyhow::Error> = async move {
818 let mut conn = Conn::Unencrypted(conn);
819 loop {
820 let message = decode_startup(&mut conn).await?;
821 conn = match message {
822 None => return Ok(()),
826
827 Some(FrontendStartupMessage::Startup {
828 version,
829 mut params,
830 }) => {
831 let mut conn = FramedConn::new(conn);
832 let peer_addr = match peer_addr {
833 Ok(addr) => addr.ip(),
834 Err(e) => {
835 error!("Invalid peer_addr {:?}", e);
836 return Ok(conn
837 .send(ErrorResponse::fatal(
838 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
839 "invalid peer address",
840 ))
841 .await?);
842 }
843 };
844 debug!(%conn_uuid, %peer_addr, "starting new pgwire connection in balancer");
845 let prev =
846 params.insert(CONN_UUID_KEY.to_string(), conn_uuid.to_string());
847 if prev.is_some() {
848 return Ok(conn
849 .send(ErrorResponse::fatal(
850 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
851 format!("invalid parameter '{CONN_UUID_KEY}'"),
852 ))
853 .await?);
854 }
855
856 if let Some(_) = params.insert(MZ_FORWARDED_FOR_KEY.to_string(), peer_addr.to_string().clone()) {
857 return Ok(conn
858 .send(ErrorResponse::fatal(
859 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
860 format!("invalid parameter '{MZ_FORWARDED_FOR_KEY}'"),
861 ))
862 .await?);
863 };
864
865 Self::run(
866 &mut conn,
867 version,
868 params,
869 &resolver,
870 tls.map(|tls| tls.mode),
871 internal_tls,
872 &inner_metrics,
873 )
874 .await?;
875 conn.flush().await?;
876 return Ok(());
877 }
878
879 Some(FrontendStartupMessage::CancelRequest {
880 conn_id,
881 secret_key,
882 }) => {
883 spawn(|| "cancel request", async move {
884 cancel_request(conn_id, secret_key, &cancellation_resolver).await;
885 });
886 return Ok(());
889 }
890
891 Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
892 (Conn::Unencrypted(mut conn), Some(tls)) => {
893 conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
894 let mut ssl_stream =
895 SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
896 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
897 let _ = ssl_stream.get_mut().shutdown().await;
898 return Err(e.into());
899 }
900 Conn::Ssl(ssl_stream)
901 }
902 (mut conn, _) => {
903 conn.write_all(&[REJECT_ENCRYPTION]).await?;
904 conn
905 }
906 },
907
908 Some(FrontendStartupMessage::GssEncRequest) => {
909 conn.write_all(&[REJECT_ENCRYPTION]).await?;
910 conn
911 }
912 }
913 }
914 }
915 .await;
916 drop(active_guard);
917 outer_metrics.connection_status(result.is_ok()).inc();
918 Ok(())
919 })
920 }
921}
922
/// Wrapper around a connection that tallies the bytes read from and written to it.
struct CountingConn<C> {
    /// The wrapped connection.
    inner: C,
    /// Total bytes read so far.
    read: usize,
    /// Total bytes written so far.
    written: usize,
}

impl<C> CountingConn<C> {
    /// Wraps `inner` with both counters starting at zero.
    fn new(inner: C) -> Self {
        Self {
            read: 0,
            written: 0,
            inner,
        }
    }
}
939
940impl<C> AsyncRead for CountingConn<C>
941where
942 C: AsyncRead + Unpin,
943{
944 fn poll_read(
945 self: Pin<&mut Self>,
946 cx: &mut std::task::Context<'_>,
947 buf: &mut io::ReadBuf<'_>,
948 ) -> std::task::Poll<std::io::Result<()>> {
949 let counter = self.get_mut();
950 let pin = Pin::new(&mut counter.inner);
951 let bytes = buf.filled().len();
952 let poll = pin.poll_read(cx, buf);
953 let bytes = buf.filled().len() - bytes;
954 if let std::task::Poll::Ready(Ok(())) = poll {
955 counter.read += bytes
956 }
957 poll
958 }
959}
960
961impl<C> AsyncWrite for CountingConn<C>
962where
963 C: AsyncWrite + Unpin,
964{
965 fn poll_write(
966 self: Pin<&mut Self>,
967 cx: &mut std::task::Context<'_>,
968 buf: &[u8],
969 ) -> std::task::Poll<Result<usize, std::io::Error>> {
970 let counter = self.get_mut();
971 let pin = Pin::new(&mut counter.inner);
972 let poll = pin.poll_write(cx, buf);
973 if let std::task::Poll::Ready(Ok(bytes)) = poll {
974 counter.written += bytes
975 }
976 poll
977 }
978
979 fn poll_flush(
980 self: Pin<&mut Self>,
981 cx: &mut std::task::Context<'_>,
982 ) -> std::task::Poll<Result<(), std::io::Error>> {
983 let counter = self.get_mut();
984 let pin = Pin::new(&mut counter.inner);
985 pin.poll_flush(cx)
986 }
987
988 fn poll_shutdown(
989 self: Pin<&mut Self>,
990 cx: &mut std::task::Context<'_>,
991 ) -> std::task::Poll<Result<(), std::io::Error>> {
992 let counter = self.get_mut();
993 let pin = Pin::new(&mut counter.inner);
994 pin.poll_shutdown(cx)
995 }
996}
997
998async fn cancel_request(
1015 conn_id: u32,
1016 secret_key: u32,
1017 cancellation_resolver: &CancellationResolver,
1018) {
1019 let suffix = conn_id_org_uuid(conn_id);
1020 let contents = match cancellation_resolver {
1021 CancellationResolver::Directory(dir) => {
1022 let path = dir.join(&suffix);
1023 match std::fs::read_to_string(&path) {
1024 Ok(contents) => contents,
1025 Err(err) => {
1026 error!("could not read cancel file {path:?}: {err}");
1027 return;
1028 }
1029 }
1030 }
1031 CancellationResolver::Static(addr) => addr.to_owned(),
1032 };
1033 let mut all_ips = Vec::new();
1034 for addr in contents.lines() {
1035 let addr = addr.trim();
1036 if addr.is_empty() {
1037 continue;
1038 }
1039 match tokio::net::lookup_host(addr).await {
1040 Ok(ips) => all_ips.extend(ips),
1041 Err(err) => {
1042 error!("{addr} failed resolution: {err}");
1043 }
1044 }
1045 }
1046 let mut buf = BytesMut::with_capacity(16);
1047 let msg = FrontendStartupMessage::CancelRequest {
1048 conn_id,
1049 secret_key,
1050 };
1051 msg.encode(&mut buf).expect("must encode");
1052 let buf = buf.freeze();
1053 for ip in all_ips {
1054 debug!("cancelling {suffix} to {ip}");
1055 let buf = buf.clone();
1056 spawn(|| "cancel request for ip", async move {
1057 let send = async {
1058 let mut stream = TcpStream::connect(&ip).await?;
1059 stream.write_all(&buf).await?;
1060 stream.shutdown().await?;
1061 Ok::<_, io::Error>(())
1062 };
1063 if let Err(err) = send.await {
1064 error!("error mirroring cancel to {ip}: {err}");
1065 }
1066 });
1067 }
1068}
1069
1070struct HttpsBalancer {
1071 resolver: Arc<StubResolver>,
1072 tls: Option<ReloadingSslContext>,
1073 resolve_template: Arc<str>,
1074 port: u16,
1075 metrics: Arc<ServerMetrics>,
1076 configs: ConfigSet,
1077 internal_tls: bool,
1078}
1079
1080impl HttpsBalancer {
1081 async fn resolve(
1082 resolver: &StubResolver,
1083 resolve_template: &str,
1084 port: u16,
1085 servername: Option<&str>,
1086 ) -> Result<ResolvedAddr, anyhow::Error> {
1087 let addr = match &servername {
1088 Some(servername) => resolve_template.replace("{}", servername),
1089 None => resolve_template.to_string(),
1090 };
1091 debug!("https address: {addr}");
1092
1093 let tenant = resolver.tenant(&addr).await;
1106
1107 let envd_addr = lookup(&format!("{addr}:{port}")).await?;
1109
1110 Ok(ResolvedAddr {
1111 addr: envd_addr,
1112 password: None,
1113 tenant,
1114 })
1115 }
1116}
1117
1118trait StubResolverExt {
1119 async fn tenant(&self, addr: &str) -> Option<String>;
1120}
1121
1122impl StubResolverExt for StubResolver {
1123 async fn tenant(&self, addr: &str) -> Option<String> {
1126 let Ok(dname) = Name::<Vec<_>>::from_str(addr) else {
1127 return None;
1128 };
1129 debug!("resolving tenant for {:?}", addr);
1130 let lookup = self.query((dname, Rtype::CNAME)).await;
1132 if let Ok(lookup) = lookup {
1133 if let Ok(answer) = lookup.answer() {
1134 let res = answer.limit_to::<AllRecordData<_, _>>();
1135 for record in res {
1136 let Ok(record) = record else {
1137 continue;
1138 };
1139 if record.rtype() != Rtype::CNAME {
1140 continue;
1141 }
1142 let cname = record.data();
1143 let cname = cname.to_string();
1144 debug!("cname: {cname}");
1145 return extract_tenant_from_cname(&cname);
1146 }
1147 }
1148 }
1149 None
1150 }
1151}
1152
1153fn extract_tenant_from_cname(cname: &str) -> Option<String> {
1155 let mut parts = cname.split('.');
1156 let _service = parts.next();
1157 let Some(namespace) = parts.next() else {
1158 return None;
1159 };
1160 let Some((_, namespace)) = namespace.split_once('-') else {
1162 return None;
1163 };
1164 let Some((tenant, _)) = namespace.rsplit_once('-') else {
1166 return None;
1167 };
1168 let Ok(tenant) = Uuid::parse_str(tenant) else {
1171 error!("cname tenant not a uuid: {tenant}");
1172 return None;
1173 };
1174 Some(tenant.to_string())
1175}
1176
1177impl mz_server_core::Server for HttpsBalancer {
1178 const NAME: &'static str = "https_balancer";
1179
1180 fn handle_connection(
1182 &self,
1183 conn: Connection,
1184 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
1185 ) -> mz_server_core::ConnectionHandler {
1186 let tls_context = self.tls.clone();
1187 let internal_tls = self.internal_tls.clone();
1188 let resolver = Arc::clone(&self.resolver);
1189 let resolve_template = Arc::clone(&self.resolve_template);
1190 let port = self.port;
1191 let inner_metrics = Arc::clone(&self.metrics);
1192 let outer_metrics = Arc::clone(&self.metrics);
1193 let peer_addr = conn.peer_addr();
1194 let inject_proxy_headers = INJECT_PROXY_PROTOCOL_HEADER_HTTP.get(&self.configs);
1195 Box::pin(async move {
1196 let active_guard = inner_metrics.active_connections();
1197 let result: Result<_, anyhow::Error> = Box::pin(async move {
1198 let peer_addr = peer_addr.context("fetching peer addr")?;
1199 let (client_stream, servername): (Box<dyn ClientStream>, Option<String>) =
1200 match tls_context {
1201 Some(tls_context) => {
1202 let mut ssl_stream =
1203 SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
1204 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1205 let _ = ssl_stream.get_mut().shutdown().await;
1206 return Err(e.into());
1207 }
1208 let servername: Option<String> =
1209 ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1210 match sn.split_once('.') {
1211 Some((left, _right)) => left,
1212 None => sn,
1213 }
1214 .into()
1215 });
1216 debug!("Found sni servername: {servername:?} (https)");
1217 (Box::new(ssl_stream), servername)
1218 }
1219 _ => (Box::new(conn), None),
1220 };
1221 let resolved =
1222 Self::resolve(&resolver, &resolve_template, port, servername.as_deref())
1223 .await?;
1224 let inner_active_guard = resolved
1225 .tenant
1226 .as_ref()
1227 .map(|tenant| inner_metrics.tenant_connections(tenant));
1228
1229 let mut mz_stream = TcpStream::connect(resolved.addr).await?;
1230
1231 if inject_proxy_headers {
1232 let addrs = ProxiedAddress::stream(peer_addr, resolved.addr);
1234 let header = ProxyHeader::with_address(addrs);
1235 let mut buf = [0u8; 1024];
1236 let len = header.encode_to_slice_v2(&mut buf)?;
1237 mz_stream.write_all(&buf[..len]).await?;
1238 }
1239
1240 let mut mz_stream = if internal_tls {
1241 let mut builder =
1243 SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
1244 builder.set_verify(SslVerifyMode::NONE);
1246 let mut ssl = builder
1247 .build()
1248 .configure()?
1249 .into_ssl(&resolved.addr.to_string())?;
1250 ssl.set_connect_state();
1251 Conn::Ssl(SslStream::new(ssl, mz_stream)?)
1252 } else {
1253 Conn::Unencrypted(mz_stream)
1254 };
1255
1256 let mut client_counter = CountingConn::new(client_stream);
1257
1258 let _ = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
1262 if let Some(tenant) = &resolved.tenant {
1263 inner_metrics
1264 .tenant_connections_tx(tenant)
1265 .inc_by(u64::cast_from(client_counter.written));
1266 inner_metrics
1267 .tenant_connections_rx(tenant)
1268 .inc_by(u64::cast_from(client_counter.read));
1269 }
1270 drop(inner_active_guard);
1271 Ok(())
1272 })
1273 .await;
1274 drop(active_guard);
1275 outer_metrics.connection_status(result.is_ok()).inc();
1276 if let Err(e) = result {
1277 debug!("connection error: {e}");
1278 }
1279 Ok(())
1280 })
1281 }
1282}
1283
1284#[derive(Debug)]
1285pub struct SniResolver {
1286 pub resolver: StubResolver,
1287 pub template: String,
1288 pub port: u16,
1289}
1290
1291trait ClientStream: AsyncRead + AsyncWrite + Unpin + Send {}
1292impl<T: AsyncRead + AsyncWrite + Unpin + Send> ClientStream for T {}
1293
1294#[derive(Debug)]
1295pub enum Resolver {
1296 Static(String),
1297 MultiTenant(FronteggResolver, Option<SniResolver>),
1298}
1299
1300impl Resolver {
1301 async fn resolve<A>(
1302 &self,
1303 conn: &mut FramedConn<A>,
1304 user: &str,
1305 metrics: &ServerMetrics,
1306 ) -> Result<ResolvedAddr, anyhow::Error>
1307 where
1308 A: AsyncRead + AsyncWrite + Unpin,
1309 {
1310 match self {
1311 Resolver::MultiTenant(
1312 FronteggResolver {
1313 auth,
1314 addr_template,
1315 },
1316 sni_resolver,
1317 ) => {
1318 let servername = match conn.inner() {
1319 Conn::Ssl(ssl_stream) => {
1320 ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1321 match sn.split_once('.') {
1322 Some((left, _right)) => left,
1323 None => sn,
1324 }
1325 })
1326 }
1327 Conn::Unencrypted(_) => None,
1328 };
1329 let has_sni = servername.is_some();
1330 let resolved_addr = match (servername, sni_resolver) {
1332 (
1333 Some(servername),
1334 Some(SniResolver {
1335 resolver: stub_resolver,
1336 template: sni_addr_template,
1337 port,
1338 }),
1339 ) => {
1340 let sni_addr = sni_addr_template.replace("{}", servername);
1341 let tenant = stub_resolver.tenant(&sni_addr).await;
1342 let sni_addr = format!("{sni_addr}:{port}");
1343 let addr = lookup(&sni_addr).await?;
1344 if tenant.is_some() {
1345 debug!("SNI header found for tenant {:?}", tenant);
1346 }
1347 ResolvedAddr {
1348 addr,
1349 password: None,
1350 tenant,
1351 }
1352 }
1353 _ => {
1354 conn.send(BackendMessage::AuthenticationCleartextPassword)
1355 .await?;
1356 conn.flush().await?;
1357 let password = match conn.recv().await? {
1358 Some(FrontendMessage::Password { password }) => password,
1359 _ => anyhow::bail!("expected Password message"),
1360 };
1361
1362 let auth_response = auth.authenticate(user, &password).await;
1363 let auth_session = match auth_response {
1364 Ok(auth_session) => auth_session,
1365 Err(e) => {
1366 warn!("pgwire connection failed authentication: {}", e);
1367 anyhow::bail!("invalid password");
1369 }
1370 };
1371
1372 let addr =
1373 addr_template.replace("{}", &auth_session.tenant_id().to_string());
1374 let addr = lookup(&addr).await?;
1375 let tenant = Some(auth_session.tenant_id().to_string());
1376 if tenant.is_some() {
1377 debug!("SNI header NOT found for tenant {:?}", tenant);
1378 }
1379 ResolvedAddr {
1380 addr,
1381 password: Some(password),
1382 tenant,
1383 }
1384 }
1385 };
1386 metrics
1387 .tenant_pgwire_sni_count(
1388 resolved_addr.tenant.as_deref().unwrap_or("unknown"),
1389 has_sni,
1390 )
1391 .inc();
1392
1393 Ok(resolved_addr)
1394 }
1395 Resolver::Static(addr) => {
1396 let addr = lookup(addr).await?;
1397 Ok(ResolvedAddr {
1398 addr,
1399 password: None,
1400 tenant: None,
1401 })
1402 }
1403 }
1404 }
1405}
1406
1407async fn lookup(name: &str) -> Result<SocketAddr, anyhow::Error> {
1409 let mut addrs = tokio::net::lookup_host(name).await?;
1410 match addrs.next() {
1411 Some(addr) => Ok(addr),
1412 None => {
1413 error!("{name} did not resolve to any addresses");
1414 anyhow::bail!("internal error")
1415 }
1416 }
1417}
1418
1419#[derive(Debug)]
1420pub struct FronteggResolver {
1421 pub auth: FronteggAuthentication,
1422 pub addr_template: String,
1423}
1424
/// Outcome of resolving where a connection should be proxied.
#[derive(Debug)]
struct ResolvedAddr {
    /// Socket address of the backend to connect to.
    addr: SocketAddr,
    /// Cleartext password collected from the client during resolution; `None` when the
    /// client was not asked for one.
    password: Option<String>,
    /// Tenant the connection belongs to, when it could be determined.
    tenant: Option<String>,
}
1431
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `extract_tenant_from_cname` against accepted and rejected names.
    #[mz_ore::test]
    fn test_tenant() {
        // The canonical form every accepted name is expected to produce.
        const TENANT: Option<&str> = Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3");

        let tests: &[(&str, Option<&str>)] = &[
            // Empty input.
            ("", None),
            // Hyphenated lowercase UUID with a single-digit trailing index.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                TENANT,
            ),
            // Different leading label, prefix and domain suffix.
            (
                "service.something-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.ssvvcc.cloister.faraway",
                TENANT,
            ),
            // UUID written without hyphens.
            (
                "environmentd.environment-58cd23ffa4d74bd0ad85a6ff29cc86c3-0.svc.cluster.local",
                TENANT,
            ),
            // Multi-digit trailing index.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-1234.svc.cluster.local",
                TENANT,
            ),
            // Uppercase UUID.
            (
                "environmentd.environment-58CD23FF-A4D7-4BD0-AD85-A6FF29CC86C3-0.svc.cluster.local",
                TENANT,
            ),
            // Trailing index missing.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3.svc.cluster.local",
                None,
            ),
            // Leading label missing.
            (
                "environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
            // UUID one hex digit short.
            (
                "environmentd.environment-8cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
        ];

        for &(name, expect) in tests {
            let cname = extract_tenant_from_cname(name);
            assert_eq!(
                cname.as_deref(),
                expect,
                "{name} got {cname:?} expected {expect:?}"
            );
        }
    }
}