1mod codec;
20mod dyncfgs;
21
22use std::collections::BTreeMap;
23use std::net::SocketAddr;
24use std::path::PathBuf;
25use std::pin::Pin;
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use anyhow::Context;
31use axum::response::IntoResponse;
32use axum::{Router, routing};
33use bytes::BytesMut;
34use domain::base::{Dname, Rtype};
35use domain::rdata::AllRecordData;
36use domain::resolv::StubResolver;
37use futures::TryFutureExt;
38use futures::stream::BoxStream;
39use hyper::StatusCode;
40use hyper_util::rt::TokioIo;
41use launchdarkly_server_sdk as ld;
42use mz_build_info::{BuildInfo, build_info};
43use mz_dyncfg::ConfigSet;
44use mz_frontegg_auth::Authenticator as FronteggAuthentication;
45use mz_ore::cast::CastFrom;
46use mz_ore::id_gen::conn_id_org_uuid;
47use mz_ore::metrics::{ComputedGauge, IntCounter, IntGauge, MetricsRegistry};
48use mz_ore::netio::AsyncReady;
49use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
50use mz_ore::task::{JoinSetExt, spawn};
51use mz_ore::tracing::TracingHandle;
52use mz_ore::{metric, netio};
53use mz_pgwire_common::{
54 ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ErrorResponse, FrontendMessage,
55 FrontendStartupMessage, MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, VERSION_3, decode_startup,
56};
57use mz_server_core::{
58 Connection, ConnectionStream, ListenerHandle, ReloadTrigger, ReloadingSslContext,
59 ReloadingTlsConfig, ServeConfig, ServeDyncfg, TlsCertConfig, TlsMode, listen,
60};
61use openssl::ssl::{NameType, Ssl, SslConnector, SslMethod, SslVerifyMode};
62use prometheus::{IntCounterVec, IntGaugeVec};
63use proxy_header::{ProxiedAddress, ProxyHeader};
64use semver::Version;
65use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
66use tokio::net::TcpStream;
67use tokio::sync::oneshot;
68use tokio::task::JoinSet;
69use tokio_openssl::SslStream;
70use tokio_postgres::error::SqlState;
71use tower::Service;
72use tracing::{debug, error, warn};
73use uuid::Uuid;
74
75use crate::codec::{BackendMessage, FramedConn};
76use crate::dyncfgs::{
77 INJECT_PROXY_PROTOCOL_HEADER_HTTP, SIGTERM_CONNECTION_WAIT, SIGTERM_LISTEN_WAIT,
78 has_tracing_config_update, tracing_config,
79};
80
/// Build information for this crate, produced by the `build_info!` macro.
///
/// Used for the startup banner printed by `BalancerService::serve` and passed to the
/// LaunchDarkly config sync in `BalancerService::new`.
pub const BUILD_INFO: BuildInfo = build_info!();
83
/// Configuration for a [`BalancerService`].
///
/// All fields are private; build one with [`BalancerConfig::new`].
pub struct BalancerConfig {
    /// Semantic version of the build, taken from the `BuildInfo` given to `new`.
    build_version: Version,
    /// Address the internal HTTP server (metrics and health endpoints) listens on.
    internal_http_listen_addr: SocketAddr,
    /// Address the pgwire balancer listens on.
    pgwire_listen_addr: SocketAddr,
    /// Address the HTTPS balancer listens on.
    https_listen_addr: SocketAddr,
    /// How to find the addresses that pgwire cancel requests are forwarded to.
    cancellation_resolver: CancellationResolver,
    /// How to find the upstream address for a pgwire connection.
    resolver: Resolver,
    /// `host:port` template for HTTPS upstreams. `serve` splits it at the first `:`;
    /// any `{}` in the host part is replaced with the client's server name.
    https_addr_template: String,
    /// Certificate configuration for client-facing TLS; `None` disables TLS.
    tls: Option<TlsCertConfig>,
    /// Whether connections from the balancer to the upstream use TLS.
    internal_tls: bool,
    /// Registry that all balancer metrics are registered into.
    metrics_registry: MetricsRegistry,
    /// Stream handed to `TlsCertConfig::reloading_context` when `serve` sets up TLS.
    reload_certs: BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>,
    /// SDK key passed to the LaunchDarkly config sync, if any.
    launchdarkly_sdk_key: Option<String>,
    /// Timeout passed to the LaunchDarkly config sync.
    config_sync_timeout: Duration,
    /// Loop interval passed to the LaunchDarkly config sync, if any.
    config_sync_loop_interval: Option<Duration>,
    /// Cloud provider name used in the LaunchDarkly context; "unknown" is used if absent.
    cloud_provider: Option<String>,
    /// Cloud region name used in the LaunchDarkly context; "unknown" is used if absent.
    cloud_provider_region: Option<String>,
    /// Handle used to apply tracing configuration updates.
    tracing_handle: TracingHandle,
    /// Name/value pairs passed to `dyncfgs::set_defaults` at startup.
    default_configs: Vec<(String, String)>,
}
110
111impl BalancerConfig {
112 pub fn new(
113 build_info: &BuildInfo,
114 internal_http_listen_addr: SocketAddr,
115 pgwire_listen_addr: SocketAddr,
116 https_listen_addr: SocketAddr,
117 cancellation_resolver: CancellationResolver,
118 resolver: Resolver,
119 https_addr_template: String,
120 tls: Option<TlsCertConfig>,
121 internal_tls: bool,
122 metrics_registry: MetricsRegistry,
123 reload_certs: ReloadTrigger,
124 launchdarkly_sdk_key: Option<String>,
125 config_sync_timeout: Duration,
126 config_sync_loop_interval: Option<Duration>,
127 cloud_provider: Option<String>,
128 cloud_provider_region: Option<String>,
129 tracing_handle: TracingHandle,
130 default_configs: Vec<(String, String)>,
131 ) -> Self {
132 Self {
133 build_version: build_info.semver_version(),
134 internal_http_listen_addr,
135 pgwire_listen_addr,
136 https_listen_addr,
137 cancellation_resolver,
138 resolver,
139 https_addr_template,
140 tls,
141 internal_tls,
142 metrics_registry,
143 reload_certs,
144 launchdarkly_sdk_key,
145 config_sync_timeout,
146 config_sync_loop_interval,
147 cloud_provider,
148 cloud_provider_region,
149 tracing_handle,
150 default_configs,
151 }
152 }
153}
154
/// Process-level metrics for the balancer.
#[derive(Debug)]
pub struct BalancerMetrics {
    /// Handle to the computed uptime gauge registered by `BalancerMetrics::new`.
    /// It is stored but never read.
    _uptime: ComputedGauge,
}
160
161impl BalancerMetrics {
162 pub fn new(cfg: &BalancerConfig) -> Self {
164 let start = Instant::now();
165 let uptime = cfg.metrics_registry.register_computed_gauge(
166 metric!(
167 name: "mz_balancer_metadata_seconds",
168 help: "server uptime, labels are build metadata",
169 const_labels: {
170 "version" => cfg.build_version,
171 "build_type" => if cfg!(release) { "release" } else { "debug" }
172 },
173 ),
174 move || start.elapsed().as_secs_f64(),
175 );
176 BalancerMetrics { _uptime: uptime }
177 }
178}
179
/// A balancer with all of its listeners bound, ready to be run with `serve`.
pub struct BalancerService {
    /// The configuration this service was created from.
    cfg: BalancerConfig,
    /// Listener handle and incoming connection stream for the pgwire balancer.
    pub pgwire: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Listener handle and incoming connection stream for the HTTPS balancer.
    pub https: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Listener handle and incoming connection stream for the internal HTTP server.
    pub internal_http: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Process-level metrics; stored but never read.
    _metrics: BalancerMetrics,
    /// Dynamic configuration values, synced from LaunchDarkly in `new`.
    configs: ConfigSet,
}
188
189impl BalancerService {
190 pub async fn new(cfg: BalancerConfig) -> Result<Self, anyhow::Error> {
191 let pgwire = listen(&cfg.pgwire_listen_addr).await?;
192 let https = listen(&cfg.https_listen_addr).await?;
193 let internal_http = listen(&cfg.internal_http_listen_addr).await?;
194 let metrics = BalancerMetrics::new(&cfg);
195 let mut configs = ConfigSet::default();
196 configs = dyncfgs::all_dyncfgs(configs);
197 dyncfgs::set_defaults(&configs, cfg.default_configs.clone())?;
198 let tracing_handle = cfg.tracing_handle.clone();
199 if let Err(err) = mz_dyncfg_launchdarkly::sync_launchdarkly_to_configset(
200 configs.clone(),
201 &BUILD_INFO,
202 |builder| {
203 let region = cfg
204 .cloud_provider_region
205 .clone()
206 .unwrap_or_else(|| String::from("unknown"));
207 if let Some(provider) = cfg.cloud_provider.clone() {
208 builder.add_context(
209 ld::ContextBuilder::new(format!(
210 "{}/{}/{}",
211 provider, region, cfg.build_version
212 ))
213 .kind("balancer")
214 .set_string("provider", provider)
215 .set_string("region", region)
216 .set_string("version", cfg.build_version.to_string())
217 .build()
218 .map_err(|e| anyhow::anyhow!(e))?,
219 );
220 } else {
221 builder.add_context(
222 ld::ContextBuilder::new(format!(
223 "{}/{}/{}",
224 "unknown", region, cfg.build_version
225 ))
226 .anonymous(true) .kind("balancer")
228 .set_string("provider", "unknown")
229 .set_string("region", region)
230 .set_string("version", cfg.build_version.to_string())
231 .build()
232 .map_err(|e| anyhow::anyhow!(e))?,
233 );
234 }
235 Ok(())
236 },
237 cfg.launchdarkly_sdk_key.as_deref(),
238 cfg.config_sync_timeout,
239 cfg.config_sync_loop_interval,
240 move |updates, configs| {
241 if has_tracing_config_update(updates) {
242 match tracing_config(configs) {
243 Ok(parameters) => parameters.apply(&tracing_handle),
244 Err(err) => warn!("unable to update tracing: {err}"),
245 }
246 }
247 },
248 )
249 .await
250 {
251 warn!("LaunchDarkly sync error: {err}");
255 }
256 Ok(Self {
257 cfg,
258 pgwire,
259 https,
260 internal_http,
261 _metrics: metrics,
262 configs,
263 })
264 }
265
266 pub async fn serve(self) -> Result<(), anyhow::Error> {
267 let (pgwire_tls, https_tls) = match &self.cfg.tls {
268 Some(tls) => {
269 let context = tls.reloading_context(self.cfg.reload_certs)?;
270 (
271 Some(ReloadingTlsConfig {
272 context: context.clone(),
273 mode: TlsMode::Require,
274 }),
275 Some(context),
276 )
277 }
278 None => (None, None),
279 };
280
281 let metrics = ServerMetricsConfig::register_into(&self.cfg.metrics_registry);
282
283 let mut set = JoinSet::new();
284 let mut server_handles = Vec::new();
285 let pgwire_addr = self.pgwire.0.local_addr();
286 let https_addr = self.https.0.local_addr();
287 let internal_http_addr = self.internal_http.0.local_addr();
288
289 {
290 let pgwire = PgwireBalancer {
291 resolver: Arc::new(self.cfg.resolver),
292 cancellation_resolver: Arc::new(self.cfg.cancellation_resolver),
293 tls: pgwire_tls,
294 internal_tls: self.cfg.internal_tls,
295 metrics: ServerMetrics::new(metrics.clone(), "pgwire"),
296 now: SYSTEM_TIME.clone(),
297 };
298 let (handle, stream) = self.pgwire;
299 server_handles.push(handle);
300 set.spawn_named(|| "pgwire_stream", {
301 let config_set = self.configs.clone();
302 async move {
303 mz_server_core::serve(ServeConfig {
304 server: pgwire,
305 conns: stream,
306 dyncfg: Some(ServeDyncfg {
307 config_set,
308 sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
309 }),
310 })
311 .await;
312 warn!("pgwire server exited");
313 }
314 });
315 }
316 {
317 let Some((addr, port)) = self.cfg.https_addr_template.split_once(':') else {
318 panic!("expected port in https_addr_template");
319 };
320 let port: u16 = port.parse().expect("unexpected port");
321 let resolver = StubResolver::new();
322 let https = HttpsBalancer {
323 resolver: Arc::from(resolver),
324 tls: https_tls,
325 resolve_template: Arc::from(addr),
326 port,
327 metrics: Arc::from(ServerMetrics::new(metrics, "https")),
328 configs: self.configs.clone(),
329 internal_tls: self.cfg.internal_tls,
330 };
331 let (handle, stream) = self.https;
332 server_handles.push(handle);
333 set.spawn_named(|| "https_stream", {
334 let config_set = self.configs.clone();
335 async move {
336 mz_server_core::serve(ServeConfig {
337 server: https,
338 conns: stream,
339 dyncfg: Some(ServeDyncfg {
340 config_set,
341 sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
342 }),
343 })
344 .await;
345 warn!("https server exited");
346 }
347 });
348 }
349 {
350 let router = Router::new()
351 .route(
352 "/metrics",
353 routing::get(move || async move {
354 mz_http_util::handle_prometheus(&self.cfg.metrics_registry).await
355 }),
356 )
357 .route(
358 "/api/livez",
359 routing::get(mz_http_util::handle_liveness_check),
360 )
361 .route("/api/readyz", routing::get(handle_readiness_check));
362 let internal_http = InternalHttpServer { router };
363 let (handle, stream) = self.internal_http;
364 server_handles.push(handle);
365 set.spawn_named(|| "internal_http_stream", async move {
366 mz_server_core::serve(ServeConfig {
367 server: internal_http,
368 conns: stream,
369 dyncfg: None,
372 })
373 .await;
374 warn!("internal_http server exited");
375 });
376 }
377 #[cfg(unix)]
378 {
379 let mut sigterm =
380 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
381 set.spawn_named(|| "sigterm_handler", async move {
382 sigterm.recv().await;
383 let wait = SIGTERM_LISTEN_WAIT.get(&self.configs);
384 warn!("received signal TERM - delaying for {:?}!", wait);
385 tokio::time::sleep(wait).await;
386 warn!("sigterm delay complete, dropping server handles");
387 drop(server_handles);
388 });
389 }
390
391 println!("balancerd {} listening...", BUILD_INFO.human_version(None));
392 println!(" TLS enabled: {}", self.cfg.tls.is_some());
393 println!(" pgwire address: {}", pgwire_addr);
394 println!(" HTTPS address: {}", https_addr);
395 println!(" internal HTTP address: {}", internal_http_addr);
396
397 while let Some(res) = set.join_next().await {
399 if let Err(err) = res {
400 error!("serving task failed: {err}")
401 }
402 }
403 Ok(())
404 }
405}
406
/// Readiness probe handler: always responds with status `200 OK` and the body `ready`.
// The function does no awaiting; it is `async` so it can be used as a route handler.
#[allow(clippy::unused_async)]
async fn handle_readiness_check() -> impl IntoResponse {
    (StatusCode::OK, "ready")
}
411
/// Server for the internal HTTP endpoints (`/metrics`, `/api/livez`, `/api/readyz`).
struct InternalHttpServer {
    /// Router that every incoming connection is served with.
    router: Router,
}
415
416impl mz_server_core::Server for InternalHttpServer {
417 const NAME: &'static str = "internal_http";
418
419 fn handle_connection(&self, conn: Connection) -> mz_server_core::ConnectionHandler {
421 let router = self.router.clone();
422 let service = hyper::service::service_fn(move |req| router.clone().call(req));
423 let conn = TokioIo::new(conn);
424
425 Box::pin(async {
426 let http = hyper::server::conn::http1::Builder::new();
427 http.serve_connection(conn, service).err_into().await
428 })
429 }
430}
431
/// Guard that keeps a gauge incremented for as long as it is alive.
///
/// The gauge is incremented when the guard is created via `From<IntGauge>` and
/// decremented when the guard is dropped.
struct GaugeGuard {
    /// The gauge to decrement on drop.
    gauge: IntGauge,
}
437
438impl From<IntGauge> for GaugeGuard {
439 fn from(gauge: IntGauge) -> Self {
440 let _self = Self { gauge };
441 _self.gauge.inc();
442 _self
443 }
444}
445
impl Drop for GaugeGuard {
    /// Decrements the wrapped gauge, balancing the increment done at construction.
    fn drop(&mut self) {
        self.gauge.dec();
    }
}
451
/// The metric families shared by all balancer servers.
///
/// `ServerMetrics` wraps this with a fixed `source` label for each server.
#[derive(Clone, Debug)]
struct ServerMetricsConfig {
    /// Completed connections, labelled by `source` and `status`.
    connection_status: IntCounterVec,
    /// Currently open connections, labelled by `source`.
    active_connections: IntGaugeVec,
    /// Open connections per tenant, labelled by `source` and `tenant`.
    tenant_connections: IntGaugeVec,
    /// Bytes received from clients per tenant, labelled by `source` and `tenant`.
    tenant_connection_rx: IntCounterVec,
    /// Bytes sent to clients per tenant, labelled by `source` and `tenant`.
    tenant_connection_tx: IntCounterVec,
    /// pgwire connections with and without SNI, labelled by `tenant` and `has_sni`.
    tenant_pgwire_sni_count: IntCounterVec,
}
461
462impl ServerMetricsConfig {
463 fn register_into(registry: &MetricsRegistry) -> Self {
464 let connection_status = registry.register(metric!(
465 name: "mz_balancer_connection_status",
466 help: "Count of completed network connections, by status",
467 var_labels: ["source", "status"],
468 ));
469 let active_connections = registry.register(metric!(
470 name: "mz_balancer_connection_active",
471 help: "Count of currently open network connections.",
472 var_labels: ["source"],
473 ));
474 let tenant_connections = registry.register(metric!(
475 name: "mz_balancer_tenant_connection_active",
476 help: "Count of opened network connections by tenant.",
477 var_labels: ["source", "tenant"]
478 ));
479 let tenant_connection_rx = registry.register(metric!(
480 name: "mz_balancer_tenant_connection_rx",
481 help: "Number of bytes received from a client for a tenant.",
482 var_labels: ["source", "tenant"],
483 ));
484 let tenant_connection_tx = registry.register(metric!(
485 name: "mz_balancer_tenant_connection_tx",
486 help: "Number of bytes sent to a client for a tenant.",
487 var_labels: ["source", "tenant"],
488 ));
489 let tenant_pgwire_sni_count = registry.register(metric!(
490 name: "mz_balancer_tenant_pgwire_sni_count",
491 help: "Count of pgwire connections that have and do not have SNI available per tenant.",
492 var_labels: ["tenant", "has_sni"],
493 ));
494 Self {
495 connection_status,
496 active_connections,
497 tenant_connections,
498 tenant_connection_rx,
499 tenant_connection_tx,
500 tenant_pgwire_sni_count,
501 }
502 }
503}
504
/// The shared metric families bound to one server's `source` label.
#[derive(Clone, Debug)]
struct ServerMetrics {
    /// The underlying metric families.
    inner: ServerMetricsConfig,
    /// Value used for the `source` label, e.g. "pgwire" or "https".
    source: &'static str,
}
510
511impl ServerMetrics {
512 fn new(inner: ServerMetricsConfig, source: &'static str) -> Self {
513 let self_ = Self { inner, source };
514
515 self_.connection_status(false);
518 self_.connection_status(true);
519 drop(self_.active_connections());
520
521 self_
522 }
523
524 fn connection_status(&self, is_ok: bool) -> IntCounter {
525 self.inner
526 .connection_status
527 .with_label_values(&[self.source, Self::status_label(is_ok)])
528 }
529
530 fn active_connections(&self) -> GaugeGuard {
531 self.inner
532 .active_connections
533 .with_label_values(&[self.source])
534 .into()
535 }
536
537 fn tenant_connections(&self, tenant: &str) -> GaugeGuard {
538 self.inner
539 .tenant_connections
540 .with_label_values(&[self.source, tenant])
541 .into()
542 }
543
544 fn tenant_connections_rx(&self, tenant: &str) -> IntCounter {
545 self.inner
546 .tenant_connection_rx
547 .with_label_values(&[self.source, tenant])
548 }
549
550 fn tenant_connections_tx(&self, tenant: &str) -> IntCounter {
551 self.inner
552 .tenant_connection_tx
553 .with_label_values(&[self.source, tenant])
554 }
555
556 fn tenant_pgwire_sni_count(&self, tenant: &str, has_sni: bool) -> IntCounter {
557 self.inner
558 .tenant_pgwire_sni_count
559 .with_label_values(&[tenant, &has_sni.to_string()])
560 }
561
562 fn status_label(is_ok: bool) -> &'static str {
563 if is_ok { "success" } else { "error" }
564 }
565}
566
/// How `cancel_request` finds the addresses a pgwire cancel request is forwarded to.
pub enum CancellationResolver {
    /// A directory holding one file per connection, named by
    /// `conn_id_org_uuid(conn_id)`. The file lists one address per line.
    Directory(PathBuf),
    /// A fixed address list used for every cancel request, one address per line.
    Static(String),
}
571
/// Server that proxies pgwire connections to the upstream chosen by `resolver`.
struct PgwireBalancer {
    /// Client-facing TLS configuration; `None` means SSL requests are rejected.
    tls: Option<ReloadingTlsConfig>,
    /// Whether to request TLS on the connection to the upstream.
    internal_tls: bool,
    /// Where to forward cancel requests.
    cancellation_resolver: Arc<CancellationResolver>,
    /// How to find the upstream address for a connection.
    resolver: Arc<Resolver>,
    /// Metrics labelled with this server's source.
    metrics: ServerMetrics,
    /// Clock used to derive the per-connection UUID.
    now: NowFn,
}
580
581impl PgwireBalancer {
582 #[mz_ore::instrument(level = "debug")]
583 async fn run<'a, A>(
584 conn: &'a mut FramedConn<A>,
585 version: i32,
586 params: BTreeMap<String, String>,
587 resolver: &Resolver,
588 tls_mode: Option<TlsMode>,
589 internal_tls: bool,
590 metrics: &ServerMetrics,
591 ) -> Result<(), io::Error>
592 where
593 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
594 {
595 if version != VERSION_3 {
596 return conn
597 .send(ErrorResponse::fatal(
598 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
599 "server does not support the client's requested protocol version",
600 ))
601 .await;
602 }
603
604 let Some(user) = params.get("user") else {
605 return conn
606 .send(ErrorResponse::fatal(
607 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
608 "user parameter required",
609 ))
610 .await;
611 };
612
613 if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
614 return conn.send(err).await;
615 }
616
617 let resolved = match resolver.resolve(conn, user).await {
618 Ok(v) => v,
619 Err(err) => {
620 return conn
621 .send(ErrorResponse::fatal(
622 SqlState::INVALID_PASSWORD,
623 err.to_string(),
624 ))
625 .await;
626 }
627 };
628
629 if let Conn::Ssl(ssl_stream) = conn.inner() {
632 let tenant = resolved.tenant.as_deref().unwrap_or("unknown");
633 let has_sni = ssl_stream.ssl().servername(NameType::HOST_NAME).is_some();
634 metrics.tenant_pgwire_sni_count(tenant, has_sni).inc();
635 }
636
637 let _active_guard = resolved
638 .tenant
639 .as_ref()
640 .map(|tenant| metrics.tenant_connections(tenant));
641 let Ok(mut mz_stream) =
642 Self::init_stream(conn, resolved.addr, resolved.password, params, internal_tls).await
643 else {
644 return Ok(());
645 };
646
647 let mut client_counter = CountingConn::new(conn.inner_mut());
648
649 let res = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
652 if let Some(tenant) = &resolved.tenant {
653 metrics
654 .tenant_connections_tx(tenant)
655 .inc_by(u64::cast_from(client_counter.written));
656 metrics
657 .tenant_connections_rx(tenant)
658 .inc_by(u64::cast_from(client_counter.read));
659 }
660 res?;
661
662 Ok(())
663 }
664
665 #[mz_ore::instrument(level = "debug")]
666 async fn init_stream<'a, A>(
667 conn: &'a mut FramedConn<A>,
668 envd_addr: SocketAddr,
669 password: Option<String>,
670 params: BTreeMap<String, String>,
671 internal_tls: bool,
672 ) -> Result<Conn<TcpStream>, anyhow::Error>
673 where
674 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
675 {
676 let mut mz_stream = TcpStream::connect(envd_addr).await?;
677 let mut buf = BytesMut::new();
678
679 let mut mz_stream = if internal_tls {
680 FrontendStartupMessage::SslRequest.encode(&mut buf)?;
681 mz_stream.write_all(&buf).await?;
682 buf.clear();
683 let mut maybe_ssl_request_response = [0u8; 1];
684 let nread =
685 netio::read_exact_or_eof(&mut mz_stream, &mut maybe_ssl_request_response).await?;
686 if nread == 1 && maybe_ssl_request_response == [ACCEPT_SSL_ENCRYPTION] {
687 let mut builder =
689 SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
690 builder.set_verify(SslVerifyMode::NONE);
692 let mut ssl = builder
693 .build()
694 .configure()?
695 .into_ssl(&envd_addr.to_string())?;
696 ssl.set_connect_state();
697 Conn::Ssl(SslStream::new(ssl, mz_stream)?)
698 } else {
699 Conn::Unencrypted(mz_stream)
700 }
701 } else {
702 Conn::Unencrypted(mz_stream)
703 };
704
705 let startup = FrontendStartupMessage::Startup {
707 version: VERSION_3,
708 params,
709 };
710 startup.encode(&mut buf)?;
711 mz_stream.write_all(&buf).await?;
712 let client_stream = conn.inner_mut();
713
714 let mut maybe_auth_frame = [0; 1 + 4 + 4];
718 let nread = netio::read_exact_or_eof(&mut mz_stream, &mut maybe_auth_frame).await?;
719 const AUTH_PASSWORD_CLEARTEXT: [u8; 9] = [b'R', 0, 0, 0, 8, 0, 0, 0, 3];
722 if nread == AUTH_PASSWORD_CLEARTEXT.len()
723 && maybe_auth_frame == AUTH_PASSWORD_CLEARTEXT
724 && password.is_some()
725 {
726 let Some(password) = password else {
728 unreachable!("verified some above");
729 };
730 let password = FrontendMessage::Password { password };
731 buf.clear();
732 password.encode(&mut buf)?;
733 mz_stream.write_all(&buf).await?;
734 mz_stream.flush().await?;
735 } else {
736 client_stream.write_all(&maybe_auth_frame[0..nread]).await?;
739 }
740
741 Ok(mz_stream)
742 }
743}
744
745impl mz_server_core::Server for PgwireBalancer {
746 const NAME: &'static str = "pgwire_balancer";
747
748 fn handle_connection(&self, conn: Connection) -> mz_server_core::ConnectionHandler {
749 let tls = self.tls.clone();
750 let internal_tls = self.internal_tls;
751 let resolver = Arc::clone(&self.resolver);
752 let inner_metrics = self.metrics.clone();
753 let outer_metrics = self.metrics.clone();
754 let cancellation_resolver = Arc::clone(&self.cancellation_resolver);
755 let conn_uuid = epoch_to_uuid_v7(&(self.now)());
756 let peer_addr = conn.peer_addr();
757 conn.uuid_handle().set(conn_uuid);
758 Box::pin(async move {
759 let active_guard = outer_metrics.active_connections();
762 let result: Result<(), anyhow::Error> = async move {
763 let mut conn = Conn::Unencrypted(conn);
764 loop {
765 let message = decode_startup(&mut conn).await?;
766 conn = match message {
767 None => return Ok(()),
771
772 Some(FrontendStartupMessage::Startup {
773 version,
774 mut params,
775 }) => {
776 let mut conn = FramedConn::new(conn);
777 let peer_addr = match peer_addr {
778 Ok(addr) => addr.ip(),
779 Err(e) => {
780 error!("Invalid peer_addr {:?}", e);
781 return Ok(conn
782 .send(ErrorResponse::fatal(
783 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
784 "invalid peer address",
785 ))
786 .await?);
787 }
788 };
789 debug!(%conn_uuid, %peer_addr, "starting new pgwire connection in balancer");
790 let prev =
791 params.insert(CONN_UUID_KEY.to_string(), conn_uuid.to_string());
792 if prev.is_some() {
793 return Ok(conn
794 .send(ErrorResponse::fatal(
795 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
796 format!("invalid parameter '{CONN_UUID_KEY}'"),
797 ))
798 .await?);
799 }
800
801 if let Some(_) = params.insert(MZ_FORWARDED_FOR_KEY.to_string(), peer_addr.to_string().clone()) {
802 return Ok(conn
803 .send(ErrorResponse::fatal(
804 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
805 format!("invalid parameter '{MZ_FORWARDED_FOR_KEY}'"),
806 ))
807 .await?);
808 };
809
810 Self::run(
811 &mut conn,
812 version,
813 params,
814 &resolver,
815 tls.map(|tls| tls.mode),
816 internal_tls,
817 &inner_metrics,
818 )
819 .await?;
820 conn.flush().await?;
821 return Ok(());
822 }
823
824 Some(FrontendStartupMessage::CancelRequest {
825 conn_id,
826 secret_key,
827 }) => {
828 spawn(|| "cancel request", async move {
829 cancel_request(conn_id, secret_key, &cancellation_resolver).await;
830 });
831 return Ok(());
834 }
835
836 Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
837 (Conn::Unencrypted(mut conn), Some(tls)) => {
838 conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
839 let mut ssl_stream =
840 SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
841 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
842 let _ = ssl_stream.get_mut().shutdown().await;
843 return Err(e.into());
844 }
845 Conn::Ssl(ssl_stream)
846 }
847 (mut conn, _) => {
848 conn.write_all(&[REJECT_ENCRYPTION]).await?;
849 conn
850 }
851 },
852
853 Some(FrontendStartupMessage::GssEncRequest) => {
854 conn.write_all(&[REJECT_ENCRYPTION]).await?;
855 conn
856 }
857 }
858 }
859 }
860 .await;
861 drop(active_guard);
862 outer_metrics.connection_status(result.is_ok()).inc();
863 Ok(())
864 })
865 }
866}
867
/// Wrapper around a connection that tallies the bytes read from and written to it.
struct CountingConn<C> {
    /// The wrapped connection.
    inner: C,
    /// Total number of bytes read through this wrapper.
    read: usize,
    /// Total number of bytes written through this wrapper.
    written: usize,
}

impl<C> CountingConn<C> {
    /// Wraps `inner` with both byte counters starting at zero.
    fn new(inner: C) -> Self {
        Self {
            inner,
            read: 0,
            written: 0,
        }
    }
}
884
885impl<C> AsyncRead for CountingConn<C>
886where
887 C: AsyncRead + Unpin,
888{
889 fn poll_read(
890 self: Pin<&mut Self>,
891 cx: &mut std::task::Context<'_>,
892 buf: &mut io::ReadBuf<'_>,
893 ) -> std::task::Poll<std::io::Result<()>> {
894 let counter = self.get_mut();
895 let pin = Pin::new(&mut counter.inner);
896 let bytes = buf.filled().len();
897 let poll = pin.poll_read(cx, buf);
898 let bytes = buf.filled().len() - bytes;
899 if let std::task::Poll::Ready(Ok(())) = poll {
900 counter.read += bytes
901 }
902 poll
903 }
904}
905
906impl<C> AsyncWrite for CountingConn<C>
907where
908 C: AsyncWrite + Unpin,
909{
910 fn poll_write(
911 self: Pin<&mut Self>,
912 cx: &mut std::task::Context<'_>,
913 buf: &[u8],
914 ) -> std::task::Poll<Result<usize, std::io::Error>> {
915 let counter = self.get_mut();
916 let pin = Pin::new(&mut counter.inner);
917 let poll = pin.poll_write(cx, buf);
918 if let std::task::Poll::Ready(Ok(bytes)) = poll {
919 counter.written += bytes
920 }
921 poll
922 }
923
924 fn poll_flush(
925 self: Pin<&mut Self>,
926 cx: &mut std::task::Context<'_>,
927 ) -> std::task::Poll<Result<(), std::io::Error>> {
928 let counter = self.get_mut();
929 let pin = Pin::new(&mut counter.inner);
930 pin.poll_flush(cx)
931 }
932
933 fn poll_shutdown(
934 self: Pin<&mut Self>,
935 cx: &mut std::task::Context<'_>,
936 ) -> std::task::Poll<Result<(), std::io::Error>> {
937 let counter = self.get_mut();
938 let pin = Pin::new(&mut counter.inner);
939 pin.poll_shutdown(cx)
940 }
941}
942
943async fn cancel_request(
960 conn_id: u32,
961 secret_key: u32,
962 cancellation_resolver: &CancellationResolver,
963) {
964 let suffix = conn_id_org_uuid(conn_id);
965 let contents = match cancellation_resolver {
966 CancellationResolver::Directory(dir) => {
967 let path = dir.join(&suffix);
968 match std::fs::read_to_string(&path) {
969 Ok(contents) => contents,
970 Err(err) => {
971 error!("could not read cancel file {path:?}: {err}");
972 return;
973 }
974 }
975 }
976 CancellationResolver::Static(addr) => addr.to_owned(),
977 };
978 let mut all_ips = Vec::new();
979 for addr in contents.lines() {
980 let addr = addr.trim();
981 if addr.is_empty() {
982 continue;
983 }
984 match tokio::net::lookup_host(addr).await {
985 Ok(ips) => all_ips.extend(ips),
986 Err(err) => {
987 error!("{addr} failed resolution: {err}");
988 }
989 }
990 }
991 let mut buf = BytesMut::with_capacity(16);
992 let msg = FrontendStartupMessage::CancelRequest {
993 conn_id,
994 secret_key,
995 };
996 msg.encode(&mut buf).expect("must encode");
997 let buf = buf.freeze();
998 for ip in all_ips {
999 debug!("cancelling {suffix} to {ip}");
1000 let buf = buf.clone();
1001 spawn(|| "cancel request for ip", async move {
1002 let send = async {
1003 let mut stream = TcpStream::connect(&ip).await?;
1004 stream.write_all(&buf).await?;
1005 stream.shutdown().await?;
1006 Ok::<_, io::Error>(())
1007 };
1008 if let Err(err) = send.await {
1009 error!("error mirroring cancel to {ip}: {err}");
1010 }
1011 });
1012 }
1013}
1014
/// Server that proxies HTTPS connections to an upstream chosen from the TLS server name.
struct HttpsBalancer {
    /// DNS resolver used for the CNAME lookup that identifies the tenant.
    resolver: Arc<StubResolver>,
    /// Client-facing TLS context; `None` means connections are proxied without TLS.
    tls: Option<ReloadingSslContext>,
    /// Upstream host template; `{}` is replaced with the client's server name.
    resolve_template: Arc<str>,
    /// Upstream port.
    port: u16,
    /// Metrics labelled with this server's source.
    metrics: Arc<ServerMetrics>,
    /// Dynamic configuration, read for `INJECT_PROXY_PROTOCOL_HEADER_HTTP`.
    configs: ConfigSet,
    /// Whether to use TLS on the connection to the upstream.
    internal_tls: bool,
}
1024
1025impl HttpsBalancer {
1026 async fn resolve(
1027 resolver: &StubResolver,
1028 resolve_template: &str,
1029 port: u16,
1030 servername: Option<&str>,
1031 ) -> Result<ResolvedAddr, anyhow::Error> {
1032 let addr = match &servername {
1033 Some(servername) => resolve_template.replace("{}", servername),
1034 None => resolve_template.to_string(),
1035 };
1036 debug!("https address: {addr}");
1037
1038 let tenant = Self::tenant(resolver, &addr).await;
1051
1052 let envd_addr = lookup(&format!("{addr}:{port}")).await?;
1054
1055 Ok(ResolvedAddr {
1056 addr: envd_addr,
1057 password: None,
1058 tenant,
1059 })
1060 }
1061
1062 async fn tenant(resolver: &StubResolver, addr: &str) -> Option<String> {
1065 let Ok(dname) = Dname::<Vec<_>>::from_str(addr) else {
1066 return None;
1067 };
1068 let lookup = resolver.query((dname, Rtype::Cname)).await;
1070 if let Ok(lookup) = lookup {
1071 if let Ok(answer) = lookup.answer() {
1072 let res = answer.limit_to::<AllRecordData<_, _>>();
1073 for record in res {
1074 let Ok(record) = record else {
1075 continue;
1076 };
1077 if record.rtype() != Rtype::Cname {
1078 continue;
1079 }
1080 let cname = record.data();
1081 let cname = cname.to_string();
1082 debug!("cname: {cname}");
1083 return Self::extract_tenant_from_cname(&cname);
1084 }
1085 }
1086 }
1087 None
1088 }
1089
1090 fn extract_tenant_from_cname(cname: &str) -> Option<String> {
1092 let mut parts = cname.split('.');
1093 let _service = parts.next();
1094 let Some(namespace) = parts.next() else {
1095 return None;
1096 };
1097 let Some((_, namespace)) = namespace.split_once('-') else {
1099 return None;
1100 };
1101 let Some((tenant, _)) = namespace.rsplit_once('-') else {
1103 return None;
1104 };
1105 let Ok(tenant) = Uuid::parse_str(tenant) else {
1108 error!("cname tenant not a uuid: {tenant}");
1109 return None;
1110 };
1111 Some(tenant.to_string())
1112 }
1113}
1114
1115impl mz_server_core::Server for HttpsBalancer {
1116 const NAME: &'static str = "https_balancer";
1117
1118 fn handle_connection(&self, conn: Connection) -> mz_server_core::ConnectionHandler {
1120 let tls_context = self.tls.clone();
1121 let internal_tls = self.internal_tls.clone();
1122 let resolver = Arc::clone(&self.resolver);
1123 let resolve_template = Arc::clone(&self.resolve_template);
1124 let port = self.port;
1125 let inner_metrics = Arc::clone(&self.metrics);
1126 let outer_metrics = Arc::clone(&self.metrics);
1127 let peer_addr = conn.peer_addr();
1128 let inject_proxy_headers = INJECT_PROXY_PROTOCOL_HEADER_HTTP.get(&self.configs);
1129 Box::pin(async move {
1130 let active_guard = inner_metrics.active_connections();
1131 let result: Result<_, anyhow::Error> = Box::pin(async move {
1132 let peer_addr = peer_addr.context("fetching peer addr")?;
1133 let (client_stream, servername): (Box<dyn ClientStream>, Option<String>) =
1134 match tls_context {
1135 Some(tls_context) => {
1136 let mut ssl_stream =
1137 SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
1138 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1139 let _ = ssl_stream.get_mut().shutdown().await;
1140 return Err(e.into());
1141 }
1142 let servername: Option<String> =
1143 ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1144 match sn.split_once('.') {
1145 Some((left, _right)) => left,
1146 None => sn,
1147 }
1148 .into()
1149 });
1150 debug!("servername: {servername:?}");
1151 (Box::new(ssl_stream), servername)
1152 }
1153 _ => (Box::new(conn), None),
1154 };
1155 let resolved =
1156 Self::resolve(&resolver, &resolve_template, port, servername.as_deref())
1157 .await?;
1158 let inner_active_guard = resolved
1159 .tenant
1160 .as_ref()
1161 .map(|tenant| inner_metrics.tenant_connections(tenant));
1162
1163 let mut mz_stream = TcpStream::connect(resolved.addr).await?;
1164
1165 if inject_proxy_headers {
1166 let addrs = ProxiedAddress::stream(peer_addr, resolved.addr);
1168 let header = ProxyHeader::with_address(addrs);
1169 let mut buf = [0u8; 1024];
1170 let len = header.encode_to_slice_v2(&mut buf)?;
1171 mz_stream.write_all(&buf[..len]).await?;
1172 }
1173
1174 let mut mz_stream = if internal_tls {
1175 let mut builder =
1177 SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
1178 builder.set_verify(SslVerifyMode::NONE);
1180 let mut ssl = builder
1181 .build()
1182 .configure()?
1183 .into_ssl(&resolved.addr.to_string())?;
1184 ssl.set_connect_state();
1185 Conn::Ssl(SslStream::new(ssl, mz_stream)?)
1186 } else {
1187 Conn::Unencrypted(mz_stream)
1188 };
1189
1190 let mut client_counter = CountingConn::new(client_stream);
1191
1192 let _ = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
1196 if let Some(tenant) = &resolved.tenant {
1197 inner_metrics
1198 .tenant_connections_tx(tenant)
1199 .inc_by(u64::cast_from(client_counter.written));
1200 inner_metrics
1201 .tenant_connections_rx(tenant)
1202 .inc_by(u64::cast_from(client_counter.read));
1203 }
1204 drop(inner_active_guard);
1205 Ok(())
1206 })
1207 .await;
1208 drop(active_guard);
1209 outer_metrics.connection_status(result.is_ok()).inc();
1210 if let Err(e) = result {
1211 debug!("connection error: {e}");
1212 }
1213 Ok(())
1214 })
1215 }
1216}
1217
1218trait ClientStream: AsyncRead + AsyncWrite + Unpin + Send {}
1219impl<T: AsyncRead + AsyncWrite + Unpin + Send> ClientStream for T {}
1220
/// Strategy used to decide which backend address a pgwire connection is forwarded to.
#[derive(Debug)]
pub enum Resolver {
    /// Always resolve the contained address string, without authenticating the client.
    Static(String),
    /// Authenticate the client's password with Frontegg and derive the address from
    /// the authenticated tenant ID.
    Frontegg(FronteggResolver),
}
1226
1227impl Resolver {
1228 async fn resolve<A>(
1229 &self,
1230 conn: &mut FramedConn<A>,
1231 user: &str,
1232 ) -> Result<ResolvedAddr, anyhow::Error>
1233 where
1234 A: AsyncRead + AsyncWrite + Unpin,
1235 {
1236 match self {
1237 Resolver::Frontegg(FronteggResolver {
1238 auth,
1239 addr_template,
1240 }) => {
1241 conn.send(BackendMessage::AuthenticationCleartextPassword)
1242 .await?;
1243 conn.flush().await?;
1244 let password = match conn.recv().await? {
1245 Some(FrontendMessage::Password { password }) => password,
1246 _ => anyhow::bail!("expected Password message"),
1247 };
1248
1249 let auth_response = auth.authenticate(user, &password).await;
1250 let auth_session = match auth_response {
1251 Ok(auth_session) => auth_session,
1252 Err(e) => {
1253 warn!("pgwire connection failed authentication: {}", e);
1254 anyhow::bail!("invalid password");
1256 }
1257 };
1258
1259 let addr = addr_template.replace("{}", &auth_session.tenant_id().to_string());
1260 let addr = lookup(&addr).await?;
1261 Ok(ResolvedAddr {
1262 addr,
1263 password: Some(password),
1264 tenant: Some(auth_session.tenant_id().to_string()),
1265 })
1266 }
1267 Resolver::Static(addr) => {
1268 let addr = lookup(addr).await?;
1269 Ok(ResolvedAddr {
1270 addr,
1271 password: None,
1272 tenant: None,
1273 })
1274 }
1275 }
1276 }
1277}
1278
1279async fn lookup(name: &str) -> Result<SocketAddr, anyhow::Error> {
1281 let mut addrs = tokio::net::lookup_host(name).await?;
1282 match addrs.next() {
1283 Some(addr) => Ok(addr),
1284 None => {
1285 error!("{name} did not resolve to any addresses");
1286 anyhow::bail!("internal error")
1287 }
1288 }
1289}
1290
/// Settings for resolving a backend address via Frontegg authentication.
#[derive(Debug)]
pub struct FronteggResolver {
    /// Authenticator used to validate the user name and password supplied by the client.
    pub auth: FronteggAuthentication,
    /// Address template; each `{}` is replaced with the authenticated tenant ID
    /// before the resulting name is looked up.
    pub addr_template: String,
}
1296
/// Result of resolving where a client connection should be forwarded.
#[derive(Debug)]
struct ResolvedAddr {
    /// Socket address of the backend to connect to.
    addr: SocketAddr,
    /// Password received from the client during resolution, if one was requested;
    /// `Resolver::resolve` sets it only on the Frontegg path.
    password: Option<String>,
    /// Tenant ID rendered as a string, when known; `None` for static resolution.
    /// When present it selects the per-tenant connection metrics.
    tenant: Option<String>,
}
1303
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `HttpsBalancer::extract_tenant_from_cname` against accepted and
    /// rejected CNAME shapes.
    #[mz_ore::test]
    fn test_tenant() {
        // Canonical (lowercase, hyphenated) form expected for every accepted name.
        const TENANT: &str = "58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3";

        let tests: [(&str, Option<&str>); 9] = [
            // Empty input yields nothing.
            ("", None),
            // The ordinary, fully formed name.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                Some(TENANT),
            ),
            // Different service, prefix, and domain labels are still accepted.
            (
                "service.something-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.ssvvcc.cloister.faraway",
                Some(TENANT),
            ),
            // A UUID written without hyphens comes back hyphenated.
            (
                "environmentd.environment-58cd23ffa4d74bd0ad85a6ff29cc86c3-0.svc.cluster.local",
                Some(TENANT),
            ),
            // A multi-digit numeric suffix is accepted.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-1234.svc.cluster.local",
                Some(TENANT),
            ),
            // An uppercase UUID comes back lowercased.
            (
                "environmentd.environment-58CD23FF-A4D7-4BD0-AD85-A6FF29CC86C3-0.svc.cluster.local",
                Some(TENANT),
            ),
            // Missing numeric suffix after the UUID.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3.svc.cluster.local",
                None,
            ),
            // Missing leading service label.
            (
                "environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
            // UUID is one hex digit short.
            (
                "environmentd.environment-8cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
        ];

        for (name, expect) in tests {
            let cname = HttpsBalancer::extract_tenant_from_cname(name);
            assert_eq!(
                cname.as_deref(),
                expect,
                "{name} got {cname:?} expected {expect:?}"
            );
        }
    }
}