1mod codec;
20mod dyncfgs;
21
22use std::collections::BTreeMap;
23use std::net::SocketAddr;
24use std::path::PathBuf;
25use std::pin::Pin;
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use anyhow::Context;
31use axum::response::IntoResponse;
32use axum::{Router, routing};
33use bytes::BytesMut;
34use domain::base::{Name, Rtype};
35use domain::rdata::AllRecordData;
36use domain::resolv::StubResolver;
37use futures::TryFutureExt;
38use futures::stream::BoxStream;
39use hyper::StatusCode;
40use hyper_util::rt::TokioIo;
41use launchdarkly_server_sdk as ld;
42use mz_build_info::{BuildInfo, build_info};
43use mz_dyncfg::ConfigSet;
44use mz_frontegg_auth::Authenticator as FronteggAuthentication;
45use mz_ore::cast::CastFrom;
46use mz_ore::id_gen::conn_id_org_uuid;
47use mz_ore::metrics::{ComputedGauge, IntCounter, IntGauge, MetricsRegistry};
48use mz_ore::netio::AsyncReady;
49use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
50use mz_ore::task::{JoinSetExt, spawn};
51use mz_ore::tracing::TracingHandle;
52use mz_ore::{metric, netio};
53use mz_pgwire_common::{
54 ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ErrorResponse, FrontendMessage,
55 FrontendStartupMessage, MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, VERSION_3, decode_startup,
56};
57use mz_server_core::{
58 Connection, ConnectionStream, ListenerHandle, ReloadTrigger, ReloadingSslContext,
59 ReloadingTlsConfig, ServeConfig, ServeDyncfg, TlsCertConfig, TlsMode, listen,
60};
61use openssl::ssl::{NameType, Ssl, SslConnector, SslMethod, SslVerifyMode};
62use prometheus::{IntCounterVec, IntGaugeVec};
63use proxy_header::{ProxiedAddress, ProxyHeader};
64use semver::Version;
65use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
66use tokio::net::TcpStream;
67use tokio::sync::oneshot;
68use tokio::task::JoinSet;
69use tokio_metrics::TaskMetrics;
70use tokio_openssl::SslStream;
71use tokio_postgres::error::SqlState;
72use tower::Service;
73use tracing::{debug, error, warn};
74use uuid::Uuid;
75
76use crate::codec::{BackendMessage, FramedConn};
77use crate::dyncfgs::{
78 INJECT_PROXY_PROTOCOL_HEADER_HTTP, SIGTERM_CONNECTION_WAIT, SIGTERM_LISTEN_WAIT,
79 has_tracing_config_update, tracing_config,
80};
81
/// Build metadata for this crate, produced by the `build_info!` macro.
///
/// Used below for the startup banner and handed to the LaunchDarkly config
/// sync.
pub const BUILD_INFO: BuildInfo = build_info!();
84
/// Static configuration for a [`BalancerService`].
///
/// Built through [`BalancerConfig::new`] and consumed by
/// [`BalancerService::new`].
pub struct BalancerConfig {
    /// Semantic version of this build, taken from the `BuildInfo` given to
    /// [`BalancerConfig::new`]. Reported as a metric label and in the
    /// LaunchDarkly context.
    build_version: Version,
    /// Address the internal HTTP listener (metrics, liveness, readiness)
    /// binds to.
    internal_http_listen_addr: SocketAddr,
    /// Address the pgwire listener binds to.
    pgwire_listen_addr: SocketAddr,
    /// Address the HTTPS listener binds to.
    https_listen_addr: SocketAddr,
    /// How to find the addresses that pgwire cancel requests are forwarded to.
    cancellation_resolver: CancellationResolver,
    /// Resolver used to pick the upstream server for a pgwire connection.
    resolver: Resolver,
    /// `host:port` template for HTTPS upstreams. `serve` splits it at the
    /// first `:` and panics if there is no port; `{}` in the host part is
    /// replaced with the SNI server name when resolving.
    https_sni_addr_template: String,
    /// TLS certificate configuration for client connections, if any.
    tls: Option<TlsCertConfig>,
    /// Whether connections to upstream servers are wrapped in TLS.
    internal_tls: bool,
    /// Registry all balancer metrics are registered into.
    metrics_registry: MetricsRegistry,
    /// Stream handed to `TlsCertConfig::reloading_context` when TLS is
    /// configured; presumably each item triggers a certificate reload.
    reload_certs: BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>,
    /// LaunchDarkly SDK key; when set (and `config_sync_file_path` is not),
    /// dynamic configs are synced from LaunchDarkly.
    launchdarkly_sdk_key: Option<String>,
    /// Path of a config file; when set (and `launchdarkly_sdk_key` is not),
    /// dynamic configs are synced from this file.
    config_sync_file_path: Option<PathBuf>,
    /// Timeout passed to whichever config sync is started.
    config_sync_timeout: Duration,
    /// Loop interval passed to whichever config sync is started.
    config_sync_loop_interval: Option<Duration>,
    /// Cloud provider name, used to build the LaunchDarkly context.
    cloud_provider: Option<String>,
    /// Cloud provider region, used to build the LaunchDarkly context
    /// (`"unknown"` when absent).
    cloud_provider_region: Option<String>,
    /// Handle used to apply tracing parameters when the synced configs change.
    tracing_handle: TracingHandle,
    /// `(name, value)` pairs applied as dynamic config defaults at startup.
    default_configs: Vec<(String, String)>,
}
112
113impl BalancerConfig {
114 pub fn new(
115 build_info: &BuildInfo,
116 internal_http_listen_addr: SocketAddr,
117 pgwire_listen_addr: SocketAddr,
118 https_listen_addr: SocketAddr,
119 cancellation_resolver: CancellationResolver,
120 resolver: Resolver,
121 https_sni_addr_template: String,
122 tls: Option<TlsCertConfig>,
123 internal_tls: bool,
124 metrics_registry: MetricsRegistry,
125 reload_certs: ReloadTrigger,
126 launchdarkly_sdk_key: Option<String>,
127 config_sync_file: Option<PathBuf>,
128 config_sync_timeout: Duration,
129 config_sync_loop_interval: Option<Duration>,
130 cloud_provider: Option<String>,
131 cloud_provider_region: Option<String>,
132 tracing_handle: TracingHandle,
133 default_configs: Vec<(String, String)>,
134 ) -> Self {
135 Self {
136 build_version: build_info.semver_version(),
137 internal_http_listen_addr,
138 pgwire_listen_addr,
139 https_listen_addr,
140 cancellation_resolver,
141 resolver,
142 https_sni_addr_template,
143 tls,
144 internal_tls,
145 metrics_registry,
146 reload_certs,
147 launchdarkly_sdk_key,
148 config_sync_file_path: config_sync_file,
149 config_sync_timeout,
150 config_sync_loop_interval,
151 cloud_provider,
152 cloud_provider_region,
153 tracing_handle,
154 default_configs,
155 }
156 }
157}
158
/// Process-level metrics for the balancer.
#[derive(Debug)]
pub struct BalancerMetrics {
    /// Uptime gauge registered by [`BalancerMetrics::new`]. The field is
    /// stored but never read in this file.
    _uptime: ComputedGauge,
}
164
165impl BalancerMetrics {
166 pub fn new(cfg: &BalancerConfig) -> Self {
168 let start = Instant::now();
169 let uptime = cfg.metrics_registry.register_computed_gauge(
170 metric!(
171 name: "mz_balancer_metadata_seconds",
172 help: "server uptime, labels are build metadata",
173 const_labels: {
174 "version" => cfg.build_version,
175 "build_type" => if cfg!(release) { "release" } else { "debug" }
176 },
177 ),
178 move || start.elapsed().as_secs_f64(),
179 );
180 BalancerMetrics { _uptime: uptime }
181 }
182}
183
/// A balancer with its listeners bound, ready to be run with
/// [`BalancerService::serve`].
pub struct BalancerService {
    /// The configuration this service was created from.
    cfg: BalancerConfig,
    /// Handle and incoming-connection stream of the pgwire listener.
    pub pgwire: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Handle and incoming-connection stream of the HTTPS listener.
    pub https: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Handle and incoming-connection stream of the internal HTTP listener.
    pub internal_http: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Process-level metrics; stored but never read in this file.
    _metrics: BalancerMetrics,
    /// Dynamic configuration values, optionally kept in sync with
    /// LaunchDarkly or a file.
    configs: ConfigSet,
}
192
193impl BalancerService {
194 pub async fn new(cfg: BalancerConfig) -> Result<Self, anyhow::Error> {
195 let pgwire = listen(&cfg.pgwire_listen_addr).await?;
196 let https = listen(&cfg.https_listen_addr).await?;
197 let internal_http = listen(&cfg.internal_http_listen_addr).await?;
198 let metrics = BalancerMetrics::new(&cfg);
199 let mut configs = ConfigSet::default();
200 configs = dyncfgs::all_dyncfgs(configs);
201 dyncfgs::set_defaults(&configs, cfg.default_configs.clone())?;
202 let tracing_handle = cfg.tracing_handle.clone();
203 match (
205 cfg.launchdarkly_sdk_key.as_deref(),
206 cfg.config_sync_file_path.as_deref(),
207 ) {
208 (Some(key), None) => {
209 let _ = mz_dyncfg_launchdarkly::sync_launchdarkly_to_configset(
210 configs.clone(),
211 &BUILD_INFO,
212 |builder| {
213 let region = cfg
214 .cloud_provider_region
215 .clone()
216 .unwrap_or_else(|| String::from("unknown"));
217 if let Some(provider) = cfg.cloud_provider.clone() {
218 builder.add_context(
219 ld::ContextBuilder::new(format!(
220 "{}/{}/{}",
221 provider, region, cfg.build_version
222 ))
223 .kind("balancer")
224 .set_string("provider", provider)
225 .set_string("region", region)
226 .set_string("version", cfg.build_version.to_string())
227 .build()
228 .map_err(|e| anyhow::anyhow!(e))?,
229 );
230 } else {
231 builder.add_context(
232 ld::ContextBuilder::new(format!(
233 "{}/{}/{}",
234 "unknown", region, cfg.build_version
235 ))
236 .anonymous(true) .kind("balancer")
238 .set_string("provider", "unknown")
239 .set_string("region", region)
240 .set_string("version", cfg.build_version.to_string())
241 .build()
242 .map_err(|e| anyhow::anyhow!(e))?,
243 );
244 }
245 Ok(())
246 },
247 Some(key),
248 cfg.config_sync_timeout,
249 cfg.config_sync_loop_interval,
250 move |updates, configs| {
251 if has_tracing_config_update(updates) {
252 match tracing_config(configs) {
253 Ok(parameters) => parameters.apply(&tracing_handle),
254 Err(err) => warn!("unable to update tracing: {err}"),
255 }
256 }
257 },
258 )
259 .await
260 .inspect_err(|e| warn!("LaunchDarkly sync error: {e}"));
261 }
262 (None, Some(path)) => {
263 let _ = mz_dyncfg_file::sync_file_to_configset(
264 configs.clone(),
265 path,
266 cfg.config_sync_timeout,
267 cfg.config_sync_loop_interval,
268 move |updates, configs| {
269 if has_tracing_config_update(updates) {
270 match tracing_config(configs) {
271 Ok(parameters) => parameters.apply(&tracing_handle),
272 Err(err) => warn!("unable to update tracing: {err}"),
273 }
274 }
275 },
276 )
277 .await
278 .inspect_err(|e| warn!("File config sync error: {e}"));
284 }
285 (Some(_), Some(_)) => panic!(
286 "must provide either config_sync_file_path or launchdarkly_sdk_key for config syncing",
287 ),
288 (None, None) => {}
289 };
290 Ok(Self {
291 cfg,
292 pgwire,
293 https,
294 internal_http,
295 _metrics: metrics,
296 configs,
297 })
298 }
299
    /// Runs the pgwire, HTTPS and internal HTTP servers until every spawned
    /// task has finished.
    ///
    /// Each server runs as its own task in a `JoinSet`. On unix an additional
    /// task waits for SIGTERM, sleeps for `SIGTERM_LISTEN_WAIT`, and then drops
    /// the listener handles.
    ///
    /// # Errors
    ///
    /// Returns an error if the reloading TLS context cannot be created or, on
    /// unix, if the SIGTERM handler cannot be installed. Failures of the
    /// spawned tasks are only logged.
    ///
    /// # Panics
    ///
    /// Panics if `https_sni_addr_template` contains no `:` or if the part
    /// after the first `:` does not parse as a `u16` port.
    pub async fn serve(self) -> Result<(), anyhow::Error> {
        // With TLS configured, both frontends share one reloading context:
        // pgwire wraps it with `TlsMode::Require`, HTTPS uses it directly.
        let (pgwire_tls, https_tls) = match &self.cfg.tls {
            Some(tls) => {
                let context = tls.reloading_context(self.cfg.reload_certs)?;
                (
                    Some(ReloadingTlsConfig {
                        context: context.clone(),
                        mode: TlsMode::Require,
                    }),
                    Some(context),
                )
            }
            None => (None, None),
        };

        let metrics = ServerMetricsConfig::register_into(&self.cfg.metrics_registry);

        let mut set = JoinSet::new();
        let mut server_handles = Vec::new();
        // Capture the bound addresses now; the listeners are moved into the
        // server tasks below.
        let pgwire_addr = self.pgwire.0.local_addr();
        let https_addr = self.https.0.local_addr();
        let internal_http_addr = self.internal_http.0.local_addr();

        // pgwire server.
        {
            let pgwire = PgwireBalancer {
                resolver: Arc::new(self.cfg.resolver),
                cancellation_resolver: Arc::new(self.cfg.cancellation_resolver),
                tls: pgwire_tls,
                internal_tls: self.cfg.internal_tls,
                metrics: ServerMetrics::new(metrics.clone(), "pgwire"),
                now: SYSTEM_TIME.clone(),
            };
            let (handle, stream) = self.pgwire;
            server_handles.push(handle);
            set.spawn_named(|| "pgwire_stream", {
                let config_set = self.configs.clone();
                async move {
                    mz_server_core::serve(ServeConfig {
                        server: pgwire,
                        conns: stream,
                        dyncfg: Some(ServeDyncfg {
                            config_set,
                            sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
                        }),
                    })
                    .await;
                    warn!("pgwire server exited");
                }
            });
        }
        // HTTPS server.
        {
            // The template is `host:port`; only the host part is kept as the
            // resolve template.
            let Some((addr, port)) = self.cfg.https_sni_addr_template.split_once(':') else {
                panic!("expected port in https_addr_template");
            };
            let port: u16 = port.parse().expect("unexpected port");
            let resolver = StubResolver::new();
            let https = HttpsBalancer {
                resolver: Arc::from(resolver),
                tls: https_tls,
                resolve_template: Arc::from(addr),
                port,
                metrics: Arc::from(ServerMetrics::new(metrics, "https")),
                configs: self.configs.clone(),
                internal_tls: self.cfg.internal_tls,
            };
            let (handle, stream) = self.https;
            server_handles.push(handle);
            set.spawn_named(|| "https_stream", {
                let config_set = self.configs.clone();
                async move {
                    mz_server_core::serve(ServeConfig {
                        server: https,
                        conns: stream,
                        dyncfg: Some(ServeDyncfg {
                            config_set,
                            sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
                        }),
                    })
                    .await;
                    warn!("https server exited");
                }
            });
        }
        // Internal HTTP server: metrics plus liveness and readiness probes.
        {
            let router = Router::new()
                .route(
                    "/metrics",
                    routing::get(move || async move {
                        mz_http_util::handle_prometheus(&self.cfg.metrics_registry).await
                    }),
                )
                .route(
                    "/api/livez",
                    routing::get(mz_http_util::handle_liveness_check),
                )
                .route("/api/readyz", routing::get(handle_readiness_check));
            let internal_http = InternalHttpServer { router };
            let (handle, stream) = self.internal_http;
            server_handles.push(handle);
            set.spawn_named(|| "internal_http_stream", async move {
                mz_server_core::serve(ServeConfig {
                    server: internal_http,
                    conns: stream,
                    // No dyncfg is passed for the internal server.
                    dyncfg: None,
                })
                .await;
                warn!("internal_http server exited");
            });
        }
        // SIGTERM handling: wait for the signal, delay for the configured
        // duration, then drop all listener handles (presumably this stops the
        // listeners from accepting new connections; confirm against
        // `ListenerHandle`).
        #[cfg(unix)]
        {
            let mut sigterm =
                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
            set.spawn_named(|| "sigterm_handler", async move {
                sigterm.recv().await;
                let wait = SIGTERM_LISTEN_WAIT.get(&self.configs);
                warn!("received signal TERM - delaying for {:?}!", wait);
                tokio::time::sleep(wait).await;
                warn!("sigterm delay complete, dropping server handles");
                drop(server_handles);
            });
        }

        println!("balancerd {} listening...", BUILD_INFO.human_version(None));
        println!(" TLS enabled: {}", self.cfg.tls.is_some());
        println!(" pgwire address: {}", pgwire_addr);
        println!(" HTTPS address: {}", https_addr);
        println!(" internal HTTP address: {}", internal_http_addr);

        // Wait for every task to finish; a task failure is logged, not
        // returned.
        while let Some(res) = set.join_next().await {
            if let Err(err) = res {
                error!("serving task failed: {err}")
            }
        }
        Ok(())
    }
439}
440
441#[allow(clippy::unused_async)]
442async fn handle_readiness_check() -> impl IntoResponse {
443 (StatusCode::OK, "ready")
444}
445
/// Server for the balancer's internal HTTP endpoints.
struct InternalHttpServer {
    /// Router that every incoming request is dispatched to.
    router: Router,
}
449
450impl mz_server_core::Server for InternalHttpServer {
451 const NAME: &'static str = "internal_http";
452
453 fn handle_connection(
455 &self,
456 conn: Connection,
457 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
458 ) -> mz_server_core::ConnectionHandler {
459 let router = self.router.clone();
460 let service = hyper::service::service_fn(move |req| router.clone().call(req));
461 let conn = TokioIo::new(conn);
462
463 Box::pin(async {
464 let http = hyper::server::conn::http1::Builder::new();
465 http.serve_connection(conn, service).err_into().await
466 })
467 }
468}
469
/// Guard that keeps a gauge incremented for as long as it is alive: the gauge
/// is incremented when the guard is created via `From<IntGauge>` and
/// decremented when the guard is dropped.
struct GaugeGuard {
    /// The gauge this guard adjusts.
    gauge: IntGauge,
}
475
476impl From<IntGauge> for GaugeGuard {
477 fn from(gauge: IntGauge) -> Self {
478 let _self = Self { gauge };
479 _self.gauge.inc();
480 _self
481 }
482}
483
impl Drop for GaugeGuard {
    /// Undoes the increment performed when the guard was created.
    fn drop(&mut self) {
        self.gauge.dec();
    }
}
489
/// The labelled metric families shared by the balancer's servers.
///
/// Created by [`ServerMetricsConfig::register_into`]; a [`ServerMetrics`]
/// fills in the `source` label for one specific server.
#[derive(Clone, Debug)]
struct ServerMetricsConfig {
    /// Completed connections, labelled by `source` and `status`.
    connection_status: IntCounterVec,
    /// Currently open connections, labelled by `source`.
    active_connections: IntGaugeVec,
    /// Open connections, labelled by `source` and `tenant`.
    tenant_connections: IntGaugeVec,
    /// Bytes received from clients, labelled by `source` and `tenant`.
    tenant_connection_rx: IntCounterVec,
    /// Bytes sent to clients, labelled by `source` and `tenant`.
    tenant_connection_tx: IntCounterVec,
    /// pgwire connections with and without SNI, labelled by `tenant` and
    /// `has_sni`.
    tenant_pgwire_sni_count: IntCounterVec,
}
499
500impl ServerMetricsConfig {
501 fn register_into(registry: &MetricsRegistry) -> Self {
502 let connection_status = registry.register(metric!(
503 name: "mz_balancer_connection_status",
504 help: "Count of completed network connections, by status",
505 var_labels: ["source", "status"],
506 ));
507 let active_connections = registry.register(metric!(
508 name: "mz_balancer_connection_active",
509 help: "Count of currently open network connections.",
510 var_labels: ["source"],
511 ));
512 let tenant_connections = registry.register(metric!(
513 name: "mz_balancer_tenant_connection_active",
514 help: "Count of opened network connections by tenant.",
515 var_labels: ["source", "tenant"]
516 ));
517 let tenant_connection_rx = registry.register(metric!(
518 name: "mz_balancer_tenant_connection_rx",
519 help: "Number of bytes received from a client for a tenant.",
520 var_labels: ["source", "tenant"],
521 ));
522 let tenant_connection_tx = registry.register(metric!(
523 name: "mz_balancer_tenant_connection_tx",
524 help: "Number of bytes sent to a client for a tenant.",
525 var_labels: ["source", "tenant"],
526 ));
527 let tenant_pgwire_sni_count = registry.register(metric!(
528 name: "mz_balancer_tenant_pgwire_sni_count",
529 help: "Count of pgwire connections that have and do not have SNI available per tenant.",
530 var_labels: ["tenant", "has_sni"],
531 ));
532 Self {
533 connection_status,
534 active_connections,
535 tenant_connections,
536 tenant_connection_rx,
537 tenant_connection_tx,
538 tenant_pgwire_sni_count,
539 }
540 }
541}
542
/// View over [`ServerMetricsConfig`] for one server, supplying that server's
/// `source` label value to every metric that has one.
#[derive(Clone, Debug)]
struct ServerMetrics {
    /// The shared metric families.
    inner: ServerMetricsConfig,
    /// Value used for the `source` label (e.g. `"pgwire"` or `"https"`).
    source: &'static str,
}
548
549impl ServerMetrics {
550 fn new(inner: ServerMetricsConfig, source: &'static str) -> Self {
551 let self_ = Self { inner, source };
552
553 self_.connection_status(false);
556 self_.connection_status(true);
557 drop(self_.active_connections());
558
559 self_
560 }
561
562 fn connection_status(&self, is_ok: bool) -> IntCounter {
563 self.inner
564 .connection_status
565 .with_label_values(&[self.source, Self::status_label(is_ok)])
566 }
567
568 fn active_connections(&self) -> GaugeGuard {
569 self.inner
570 .active_connections
571 .with_label_values(&[self.source])
572 .into()
573 }
574
575 fn tenant_connections(&self, tenant: &str) -> GaugeGuard {
576 self.inner
577 .tenant_connections
578 .with_label_values(&[self.source, tenant])
579 .into()
580 }
581
582 fn tenant_connections_rx(&self, tenant: &str) -> IntCounter {
583 self.inner
584 .tenant_connection_rx
585 .with_label_values(&[self.source, tenant])
586 }
587
588 fn tenant_connections_tx(&self, tenant: &str) -> IntCounter {
589 self.inner
590 .tenant_connection_tx
591 .with_label_values(&[self.source, tenant])
592 }
593
594 fn tenant_pgwire_sni_count(&self, tenant: &str, has_sni: bool) -> IntCounter {
595 self.inner
596 .tenant_pgwire_sni_count
597 .with_label_values(&[tenant, &has_sni.to_string()])
598 }
599
600 fn status_label(is_ok: bool) -> &'static str {
601 if is_ok { "success" } else { "error" }
602 }
603}
604
/// Where pgwire cancel requests are forwarded to.
pub enum CancellationResolver {
    /// A directory of per-connection files. The file named by
    /// `conn_id_org_uuid(conn_id)` is read and each non-empty line is treated
    /// as an address to resolve.
    Directory(PathBuf),
    /// A fixed string of addresses, one per line, used for every cancel
    /// request.
    Static(String),
}
609
/// Server that proxies pgwire connections to an upstream server chosen by
/// `resolver`.
struct PgwireBalancer {
    /// TLS configuration for client connections; `None` means SSL requests
    /// are rejected.
    tls: Option<ReloadingTlsConfig>,
    /// Whether the connection to the upstream server is upgraded to TLS.
    internal_tls: bool,
    /// Where cancel requests are forwarded to.
    cancellation_resolver: Arc<CancellationResolver>,
    /// Resolves a client connection to an upstream address.
    resolver: Arc<Resolver>,
    /// Metrics labelled for this server.
    metrics: ServerMetrics,
    /// Clock used to derive a UUID for each incoming connection.
    now: NowFn,
}
618
619impl PgwireBalancer {
620 #[mz_ore::instrument(level = "debug")]
621 async fn run<'a, A>(
622 conn: &'a mut FramedConn<A>,
623 version: i32,
624 params: BTreeMap<String, String>,
625 resolver: &Resolver,
626 tls_mode: Option<TlsMode>,
627 internal_tls: bool,
628 metrics: &ServerMetrics,
629 ) -> Result<(), io::Error>
630 where
631 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
632 {
633 if version != VERSION_3 {
634 return conn
635 .send(ErrorResponse::fatal(
636 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
637 "server does not support the client's requested protocol version",
638 ))
639 .await;
640 }
641
642 let Some(user) = params.get("user") else {
643 return conn
644 .send(ErrorResponse::fatal(
645 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
646 "user parameter required",
647 ))
648 .await;
649 };
650
651 if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
652 return conn.send(err).await;
653 }
654
655 let resolved = match resolver.resolve(conn, user, metrics).await {
656 Ok(v) => v,
657 Err(err) => {
658 return conn
659 .send(ErrorResponse::fatal(
660 SqlState::INVALID_PASSWORD,
661 err.to_string(),
662 ))
663 .await;
664 }
665 };
666
667 let _active_guard = resolved
668 .tenant
669 .as_ref()
670 .map(|tenant| metrics.tenant_connections(tenant));
671 let mut mz_stream =
672 match Self::init_stream(conn, resolved.addr, resolved.password, params, internal_tls)
673 .await
674 {
675 Ok(stream) => stream,
676 Err(e) => {
677 error!("failed to connect to upstream server: {e}");
678 return conn
679 .send(ErrorResponse::fatal(
680 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
681 "upstream server not available",
682 ))
683 .await;
684 }
685 };
686
687 let mut client_counter = CountingConn::new(conn.inner_mut());
688
689 let res = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
692 if let Some(tenant) = &resolved.tenant {
693 metrics
694 .tenant_connections_tx(tenant)
695 .inc_by(u64::cast_from(client_counter.written));
696 metrics
697 .tenant_connections_rx(tenant)
698 .inc_by(u64::cast_from(client_counter.read));
699 }
700 res?;
701
702 Ok(())
703 }
704
705 #[mz_ore::instrument(level = "debug")]
706 async fn init_stream<'a, A>(
707 conn: &'a mut FramedConn<A>,
708 envd_addr: SocketAddr,
709 password: Option<String>,
710 params: BTreeMap<String, String>,
711 internal_tls: bool,
712 ) -> Result<Conn<TcpStream>, anyhow::Error>
713 where
714 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
715 {
716 let mut mz_stream = TcpStream::connect(envd_addr).await?;
717 let mut buf = BytesMut::new();
718
719 let mut mz_stream = if internal_tls {
720 FrontendStartupMessage::SslRequest.encode(&mut buf)?;
721 mz_stream.write_all(&buf).await?;
722 buf.clear();
723 let mut maybe_ssl_request_response = [0u8; 1];
724 let nread =
725 netio::read_exact_or_eof(&mut mz_stream, &mut maybe_ssl_request_response).await?;
726 if nread == 1 && maybe_ssl_request_response == [ACCEPT_SSL_ENCRYPTION] {
727 let mut builder =
729 SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
730 builder.set_verify(SslVerifyMode::NONE);
732 let mut ssl = builder
733 .build()
734 .configure()?
735 .into_ssl(&envd_addr.to_string())?;
736 ssl.set_connect_state();
737 Conn::Ssl(SslStream::new(ssl, mz_stream)?)
738 } else {
739 Conn::Unencrypted(mz_stream)
740 }
741 } else {
742 Conn::Unencrypted(mz_stream)
743 };
744
745 let startup = FrontendStartupMessage::Startup {
747 version: VERSION_3,
748 params,
749 };
750 startup.encode(&mut buf)?;
751 mz_stream.write_all(&buf).await?;
752 let client_stream = conn.inner_mut();
753
754 if password.is_none() {
770 return Ok(mz_stream);
771 }
772
773 let mut maybe_auth_frame = [0; 1 + 4 + 4];
777 let nread = netio::read_exact_or_eof(&mut mz_stream, &mut maybe_auth_frame).await?;
778 const AUTH_PASSWORD_CLEARTEXT: [u8; 9] = [b'R', 0, 0, 0, 8, 0, 0, 0, 3];
781 if nread == AUTH_PASSWORD_CLEARTEXT.len()
782 && maybe_auth_frame == AUTH_PASSWORD_CLEARTEXT
783 && password.is_some()
784 {
785 let Some(password) = password else {
787 unreachable!("verified some above");
788 };
789 let password = FrontendMessage::Password { password };
790 buf.clear();
791 password.encode(&mut buf)?;
792 mz_stream.write_all(&buf).await?;
793 mz_stream.flush().await?;
794 } else {
795 client_stream.write_all(&maybe_auth_frame[0..nread]).await?;
798 }
799
800 Ok(mz_stream)
801 }
802}
803
804impl mz_server_core::Server for PgwireBalancer {
805 const NAME: &'static str = "pgwire_balancer";
806
807 fn handle_connection(
808 &self,
809 conn: Connection,
810 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
811 ) -> mz_server_core::ConnectionHandler {
812 let tls = self.tls.clone();
813 let internal_tls = self.internal_tls;
814 let resolver = Arc::clone(&self.resolver);
815 let inner_metrics = self.metrics.clone();
816 let outer_metrics = self.metrics.clone();
817 let cancellation_resolver = Arc::clone(&self.cancellation_resolver);
818 let conn_uuid = epoch_to_uuid_v7(&(self.now)());
819 let peer_addr = conn.peer_addr();
820 conn.uuid_handle().set(conn_uuid);
821 Box::pin(async move {
822 let active_guard = outer_metrics.active_connections();
825 let result: Result<(), anyhow::Error> = async move {
826 let mut conn = Conn::Unencrypted(conn);
827 loop {
828 let message = decode_startup(&mut conn).await?;
829 conn = match message {
830 None => return Ok(()),
834
835 Some(FrontendStartupMessage::Startup {
836 version,
837 mut params,
838 }) => {
839 let mut conn = FramedConn::new(conn);
840 let rejected =
841 SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION;
842 let peer_addr = match peer_addr {
843 Ok(addr) => addr.ip(),
844 Err(e) => {
845 error!("Invalid peer_addr {:?}", e);
846 return Ok(conn
847 .send(ErrorResponse::fatal(
848 rejected,
849 "invalid peer address",
850 ))
851 .await?);
852 }
853 };
854 debug!(
855 %conn_uuid, %peer_addr,
856 "starting new pgwire connection in balancer",
857 );
858 let prev =
859 params.insert(CONN_UUID_KEY.to_string(), conn_uuid.to_string());
860 if prev.is_some() {
861 return Ok(conn
862 .send(ErrorResponse::fatal(
863 rejected,
864 format!("invalid parameter '{CONN_UUID_KEY}'"),
865 ))
866 .await?);
867 }
868
869 let forwarded_for = params.insert(
870 MZ_FORWARDED_FOR_KEY.to_string(),
871 peer_addr.to_string().clone(),
872 );
873 if let Some(_) = forwarded_for {
874 return Ok(conn
875 .send(ErrorResponse::fatal(
876 rejected,
877 format!("invalid parameter '{MZ_FORWARDED_FOR_KEY}'"),
878 ))
879 .await?);
880 };
881
882 Self::run(
883 &mut conn,
884 version,
885 params,
886 &resolver,
887 tls.map(|tls| tls.mode),
888 internal_tls,
889 &inner_metrics,
890 )
891 .await?;
892 conn.flush().await?;
893 return Ok(());
894 }
895
896 Some(FrontendStartupMessage::CancelRequest {
897 conn_id,
898 secret_key,
899 }) => {
900 spawn(|| "cancel request", async move {
901 cancel_request(conn_id, secret_key, &cancellation_resolver).await;
902 });
903 return Ok(());
906 }
907
908 Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
909 (Conn::Unencrypted(mut conn), Some(tls)) => {
910 conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
911 let mut ssl_stream =
912 SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
913 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
914 let _ = ssl_stream.get_mut().shutdown().await;
915 return Err(e.into());
916 }
917 Conn::Ssl(ssl_stream)
918 }
919 (mut conn, _) => {
920 conn.write_all(&[REJECT_ENCRYPTION]).await?;
921 conn
922 }
923 },
924
925 Some(FrontendStartupMessage::GssEncRequest) => {
926 conn.write_all(&[REJECT_ENCRYPTION]).await?;
927 conn
928 }
929 }
930 }
931 }
932 .await;
933 drop(active_guard);
934 outer_metrics.connection_status(result.is_ok()).inc();
935 Ok(())
936 })
937 }
938}
939
/// Wrapper around a connection that tallies the bytes passing through it.
struct CountingConn<C> {
    /// The wrapped connection.
    inner: C,
    /// Total bytes read from `inner` so far.
    read: usize,
    /// Total bytes written to `inner` so far.
    written: usize,
}

impl<C> CountingConn<C> {
    /// Wraps `inner` with both byte counters starting at zero.
    fn new(inner: C) -> Self {
        Self {
            inner,
            read: 0,
            written: 0,
        }
    }
}
956
957impl<C> AsyncRead for CountingConn<C>
958where
959 C: AsyncRead + Unpin,
960{
961 fn poll_read(
962 self: Pin<&mut Self>,
963 cx: &mut std::task::Context<'_>,
964 buf: &mut io::ReadBuf<'_>,
965 ) -> std::task::Poll<std::io::Result<()>> {
966 let counter = self.get_mut();
967 let pin = Pin::new(&mut counter.inner);
968 let bytes = buf.filled().len();
969 let poll = pin.poll_read(cx, buf);
970 let bytes = buf.filled().len() - bytes;
971 if let std::task::Poll::Ready(Ok(())) = poll {
972 counter.read += bytes
973 }
974 poll
975 }
976}
977
978impl<C> AsyncWrite for CountingConn<C>
979where
980 C: AsyncWrite + Unpin,
981{
982 fn poll_write(
983 self: Pin<&mut Self>,
984 cx: &mut std::task::Context<'_>,
985 buf: &[u8],
986 ) -> std::task::Poll<Result<usize, std::io::Error>> {
987 let counter = self.get_mut();
988 let pin = Pin::new(&mut counter.inner);
989 let poll = pin.poll_write(cx, buf);
990 if let std::task::Poll::Ready(Ok(bytes)) = poll {
991 counter.written += bytes
992 }
993 poll
994 }
995
996 fn poll_flush(
997 self: Pin<&mut Self>,
998 cx: &mut std::task::Context<'_>,
999 ) -> std::task::Poll<Result<(), std::io::Error>> {
1000 let counter = self.get_mut();
1001 let pin = Pin::new(&mut counter.inner);
1002 pin.poll_flush(cx)
1003 }
1004
1005 fn poll_shutdown(
1006 self: Pin<&mut Self>,
1007 cx: &mut std::task::Context<'_>,
1008 ) -> std::task::Poll<Result<(), std::io::Error>> {
1009 let counter = self.get_mut();
1010 let pin = Pin::new(&mut counter.inner);
1011 pin.poll_shutdown(cx)
1012 }
1013}
1014
1015async fn cancel_request(
1032 conn_id: u32,
1033 secret_key: u32,
1034 cancellation_resolver: &CancellationResolver,
1035) {
1036 let suffix = conn_id_org_uuid(conn_id);
1037 let contents = match cancellation_resolver {
1038 CancellationResolver::Directory(dir) => {
1039 let path = dir.join(&suffix);
1040 match std::fs::read_to_string(&path) {
1041 Ok(contents) => contents,
1042 Err(err) => {
1043 error!("could not read cancel file {path:?}: {err}");
1044 return;
1045 }
1046 }
1047 }
1048 CancellationResolver::Static(addr) => addr.to_owned(),
1049 };
1050 let mut all_ips = Vec::new();
1051 for addr in contents.lines() {
1052 let addr = addr.trim();
1053 if addr.is_empty() {
1054 continue;
1055 }
1056 match tokio::net::lookup_host(addr).await {
1057 Ok(ips) => all_ips.extend(ips),
1058 Err(err) => {
1059 error!("{addr} failed resolution: {err}");
1060 }
1061 }
1062 }
1063 let mut buf = BytesMut::with_capacity(16);
1064 let msg = FrontendStartupMessage::CancelRequest {
1065 conn_id,
1066 secret_key,
1067 };
1068 msg.encode(&mut buf).expect("must encode");
1069 let buf = buf.freeze();
1070 for ip in all_ips {
1071 debug!("cancelling {suffix} to {ip}");
1072 let buf = buf.clone();
1073 spawn(|| "cancel request for ip", async move {
1074 let send = async {
1075 let mut stream = TcpStream::connect(&ip).await?;
1076 stream.write_all(&buf).await?;
1077 stream.shutdown().await?;
1078 Ok::<_, io::Error>(())
1079 };
1080 if let Err(err) = send.await {
1081 error!("error mirroring cancel to {ip}: {err}");
1082 }
1083 });
1084 }
1085}
1086
/// Server that proxies HTTPS connections to an upstream server derived from
/// the client's SNI server name.
struct HttpsBalancer {
    /// DNS resolver used to look up the tenant behind an upstream address.
    resolver: Arc<StubResolver>,
    /// TLS context for terminating client connections; `None` means the
    /// client stream is used as-is.
    tls: Option<ReloadingSslContext>,
    /// Host template for the upstream address; `{}` is replaced with the SNI
    /// server name.
    resolve_template: Arc<str>,
    /// Port of the upstream server.
    port: u16,
    /// Metrics labelled for this server.
    metrics: Arc<ServerMetrics>,
    /// Dynamic configs, read per connection.
    configs: ConfigSet,
    /// Whether the connection to the upstream server is wrapped in TLS.
    internal_tls: bool,
}
1096
1097impl HttpsBalancer {
1098 async fn resolve(
1099 resolver: &StubResolver,
1100 resolve_template: &str,
1101 port: u16,
1102 servername: Option<&str>,
1103 ) -> Result<ResolvedAddr, anyhow::Error> {
1104 let addr = match &servername {
1105 Some(servername) => resolve_template.replace("{}", servername),
1106 None => resolve_template.to_string(),
1107 };
1108 debug!("https address: {addr}");
1109
1110 let tenant = resolver.tenant(&addr).await;
1123
1124 let envd_addr = lookup(&format!("{addr}:{port}")).await?;
1126
1127 Ok(ResolvedAddr {
1128 addr: envd_addr,
1129 password: None,
1130 tenant,
1131 })
1132 }
1133}
1134
/// Extension trait adding a tenant lookup to a DNS resolver.
trait StubResolverExt {
    /// Returns the tenant associated with `addr`, or `None` if it cannot be
    /// determined.
    async fn tenant(&self, addr: &str) -> Option<String>;
}
1138
1139impl StubResolverExt for StubResolver {
1140 async fn tenant(&self, addr: &str) -> Option<String> {
1143 let Ok(dname) = Name::<Vec<_>>::from_str(addr) else {
1144 return None;
1145 };
1146 debug!("resolving tenant for {:?}", addr);
1147 let lookup = self.query((dname, Rtype::CNAME)).await;
1149 if let Ok(lookup) = lookup {
1150 if let Ok(answer) = lookup.answer() {
1151 let res = answer.limit_to::<AllRecordData<_, _>>();
1152 for record in res {
1153 let Ok(record) = record else {
1154 continue;
1155 };
1156 if record.rtype() != Rtype::CNAME {
1157 continue;
1158 }
1159 let cname = record.data();
1160 let cname = cname.to_string();
1161 debug!("cname: {cname}");
1162 return extract_tenant_from_cname(&cname);
1163 }
1164 }
1165 }
1166 None
1167 }
1168}
1169
1170fn extract_tenant_from_cname(cname: &str) -> Option<String> {
1172 let mut parts = cname.split('.');
1173 let _service = parts.next();
1174 let Some(namespace) = parts.next() else {
1175 return None;
1176 };
1177 let Some((_, namespace)) = namespace.split_once('-') else {
1179 return None;
1180 };
1181 let Some((tenant, _)) = namespace.rsplit_once('-') else {
1183 return None;
1184 };
1185 let Ok(tenant) = Uuid::parse_str(tenant) else {
1188 error!("cname tenant not a uuid: {tenant}");
1189 return None;
1190 };
1191 Some(tenant.to_string())
1192}
1193
1194impl mz_server_core::Server for HttpsBalancer {
1195 const NAME: &'static str = "https_balancer";
1196
1197 fn handle_connection(
1199 &self,
1200 conn: Connection,
1201 _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
1202 ) -> mz_server_core::ConnectionHandler {
1203 let tls_context = self.tls.clone();
1204 let internal_tls = self.internal_tls.clone();
1205 let resolver = Arc::clone(&self.resolver);
1206 let resolve_template = Arc::clone(&self.resolve_template);
1207 let port = self.port;
1208 let inner_metrics = Arc::clone(&self.metrics);
1209 let outer_metrics = Arc::clone(&self.metrics);
1210 let peer_addr = conn.peer_addr();
1211 let inject_proxy_headers = INJECT_PROXY_PROTOCOL_HEADER_HTTP.get(&self.configs);
1212 Box::pin(async move {
1213 let active_guard = inner_metrics.active_connections();
1214 let result: Result<_, anyhow::Error> = Box::pin(async move {
1215 let peer_addr = peer_addr.context("fetching peer addr")?;
1216 let (mut client_stream, servername): (Box<dyn ClientStream>, Option<String>) =
1217 match tls_context {
1218 Some(tls_context) => {
1219 let mut ssl_stream =
1220 SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
1221 if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1222 let _ = ssl_stream.get_mut().shutdown().await;
1223 return Err(e.into());
1224 }
1225 let servername: Option<String> =
1226 ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1227 match sn.split_once('.') {
1228 Some((left, _right)) => left,
1229 None => sn,
1230 }
1231 .into()
1232 });
1233 debug!("Found sni servername: {servername:?} (https)");
1234 (Box::new(ssl_stream), servername)
1235 }
1236 _ => (Box::new(conn), None),
1237 };
1238 let resolved =
1239 Self::resolve(&resolver, &resolve_template, port, servername.as_deref())
1240 .await?;
1241 let inner_active_guard = resolved
1242 .tenant
1243 .as_ref()
1244 .map(|tenant| inner_metrics.tenant_connections(tenant));
1245 let mut mz_stream = match TcpStream::connect(resolved.addr).await {
1246 Ok(stream) => stream,
1247 Err(e) => {
1248 error!("failed to connect to upstream server: {e}");
1249 let body = "upstream server not available";
1250 let response = format!(
1256 "HTTP/1.1 502 Bad Gateway\r\n\
1257 Content-Type: text/plain\r\n\
1258 Content-Length: {}\r\n\
1259 Connection: close\r\n\
1260 \r\n\
1261 {}",
1262 body.len(),
1263 body
1264 );
1265 let _ = client_stream.write_all(response.as_bytes()).await;
1266 let _ = client_stream.shutdown().await;
1267 return Ok(());
1268 }
1269 };
1270
1271 if inject_proxy_headers {
1272 let addrs = ProxiedAddress::stream(peer_addr, resolved.addr);
1274 let header = ProxyHeader::with_address(addrs);
1275 let mut buf = [0u8; 1024];
1276 let len = header.encode_to_slice_v2(&mut buf)?;
1277 mz_stream.write_all(&buf[..len]).await?;
1278 }
1279
1280 let mut mz_stream = if internal_tls {
1281 let mut builder =
1283 SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
1284 builder.set_verify(SslVerifyMode::NONE);
1286 let mut ssl = builder
1287 .build()
1288 .configure()?
1289 .into_ssl(&resolved.addr.to_string())?;
1290 ssl.set_connect_state();
1291 Conn::Ssl(SslStream::new(ssl, mz_stream)?)
1292 } else {
1293 Conn::Unencrypted(mz_stream)
1294 };
1295
1296 let mut client_counter = CountingConn::new(client_stream);
1297
1298 let _ = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
1302 if let Some(tenant) = &resolved.tenant {
1303 inner_metrics
1304 .tenant_connections_tx(tenant)
1305 .inc_by(u64::cast_from(client_counter.written));
1306 inner_metrics
1307 .tenant_connections_rx(tenant)
1308 .inc_by(u64::cast_from(client_counter.read));
1309 }
1310 drop(inner_active_guard);
1311 Ok(())
1312 })
1313 .await;
1314 drop(active_guard);
1315 outer_metrics.connection_status(result.is_ok()).inc();
1316 if let Err(e) = result {
1317 debug!("connection error: {e}");
1318 }
1319 Ok(())
1320 })
1321 }
1322}
1323
1324#[derive(Debug)]
1325pub struct SniResolver {
1326 pub resolver: StubResolver,
1327 pub template: String,
1328 pub port: u16,
1329}
1330
1331trait ClientStream: AsyncRead + AsyncWrite + Unpin + Send {}
1332impl<T: AsyncRead + AsyncWrite + Unpin + Send> ClientStream for T {}
1333
1334#[derive(Debug)]
1335pub enum Resolver {
1336 Static(String),
1337 MultiTenant(FronteggResolver, Option<SniResolver>),
1338}
1339
1340impl Resolver {
1341 async fn resolve<A>(
1342 &self,
1343 conn: &mut FramedConn<A>,
1344 user: &str,
1345 metrics: &ServerMetrics,
1346 ) -> Result<ResolvedAddr, anyhow::Error>
1347 where
1348 A: AsyncRead + AsyncWrite + Unpin,
1349 {
1350 match self {
1351 Resolver::MultiTenant(
1352 FronteggResolver {
1353 auth,
1354 addr_template,
1355 },
1356 sni_resolver,
1357 ) => {
1358 let servername = match conn.inner() {
1359 Conn::Ssl(ssl_stream) => {
1360 ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1361 match sn.split_once('.') {
1362 Some((left, _right)) => left,
1363 None => sn,
1364 }
1365 })
1366 }
1367 Conn::Unencrypted(_) => None,
1368 };
1369 let has_sni = servername.is_some();
1370 let resolved_addr = match (servername, sni_resolver) {
1372 (
1373 Some(servername),
1374 Some(SniResolver {
1375 resolver: stub_resolver,
1376 template: sni_addr_template,
1377 port,
1378 }),
1379 ) => {
1380 let sni_addr = sni_addr_template.replace("{}", servername);
1381 let tenant = stub_resolver.tenant(&sni_addr).await;
1382 let sni_addr = format!("{sni_addr}:{port}");
1383 let addr = lookup(&sni_addr).await?;
1384 if tenant.is_some() {
1385 debug!("SNI header found for tenant {:?}", tenant);
1386 }
1387 ResolvedAddr {
1388 addr,
1389 password: None,
1390 tenant,
1391 }
1392 }
1393 _ => {
1394 conn.send(BackendMessage::AuthenticationCleartextPassword)
1395 .await?;
1396 conn.flush().await?;
1397 let password = match conn.recv().await? {
1398 Some(FrontendMessage::Password { password }) => password,
1399 _ => anyhow::bail!("expected Password message"),
1400 };
1401
1402 let auth_response = auth.authenticate(user, &password).await;
1403 let auth_session = match auth_response {
1404 Ok((auth_session, _)) => auth_session,
1405 Err(e) => {
1406 warn!("pgwire connection failed authentication: {}", e);
1407 anyhow::bail!("invalid password");
1409 }
1410 };
1411
1412 let addr =
1413 addr_template.replace("{}", &auth_session.tenant_id().to_string());
1414 let addr = lookup(&addr).await?;
1415 let tenant = Some(auth_session.tenant_id().to_string());
1416 if tenant.is_some() {
1417 debug!("SNI header NOT found for tenant {:?}", tenant);
1418 }
1419 ResolvedAddr {
1420 addr,
1421 password: Some(password),
1422 tenant,
1423 }
1424 }
1425 };
1426 metrics
1427 .tenant_pgwire_sni_count(
1428 resolved_addr.tenant.as_deref().unwrap_or("unknown"),
1429 has_sni,
1430 )
1431 .inc();
1432
1433 Ok(resolved_addr)
1434 }
1435 Resolver::Static(addr) => {
1436 let addr = lookup(addr).await?;
1437 Ok(ResolvedAddr {
1438 addr,
1439 password: None,
1440 tenant: None,
1441 })
1442 }
1443 }
1444 }
1445}
1446
1447async fn lookup(name: &str) -> Result<SocketAddr, anyhow::Error> {
1449 let mut addrs = tokio::net::lookup_host(name).await?;
1450 match addrs.next() {
1451 Some(addr) => Ok(addr),
1452 None => {
1453 error!("{name} did not resolve to any addresses");
1454 anyhow::bail!("internal error")
1455 }
1456 }
1457}
1458
1459#[derive(Debug)]
1460pub struct FronteggResolver {
1461 pub auth: FronteggAuthentication,
1462 pub addr_template: String,
1463}
1464
/// Outcome of resolving a connection to an upstream server.
// NOTE(review): the derived `Debug` prints `password` in clear text; avoid
// logging this struct with `{:?}`.
#[derive(Debug)]
struct ResolvedAddr {
    /// Socket address of the upstream server.
    addr: SocketAddr,
    /// Password received from the client while resolving, if one was
    /// requested.
    password: Option<String>,
    /// Tenant the connection belongs to, when it could be determined.
    tenant: Option<String>,
}
1471
#[cfg(test)]
mod tests {
    use super::*;

    /// `(cname, expected tenant)` pairs for `extract_tenant_from_cname`.
    const TENANT_CASES: &[(&str, Option<&str>)] = &[
        // Empty input.
        ("", None),
        // Well-formed name.
        (
            "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
            Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
        ),
        // Different service label, namespace prefix and domain suffix.
        (
            "service.something-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.ssvvcc.cloister.faraway",
            Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
        ),
        // UUID written without hyphens.
        (
            "environmentd.environment-58cd23ffa4d74bd0ad85a6ff29cc86c3-0.svc.cluster.local",
            Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
        ),
        // Multi-digit trailing number.
        (
            "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-1234.svc.cluster.local",
            Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
        ),
        // Uppercase UUID comes back lowercased.
        (
            "environmentd.environment-58CD23FF-A4D7-4BD0-AD85-A6FF29CC86C3-0.svc.cluster.local",
            Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
        ),
        // Missing trailing "-<number>" after the UUID.
        (
            "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3.svc.cluster.local",
            None,
        ),
        // Missing leading service label.
        (
            "environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
            None,
        ),
        // UUID is one character short.
        (
            "environmentd.environment-8cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
            None,
        ),
    ];

    #[mz_ore::test]
    fn test_tenant() {
        for &(name, expect) in TENANT_CASES {
            let cname = extract_tenant_from_cname(name);
            assert_eq!(
                cname.as_deref(),
                expect,
                "{name} got {cname:?} expected {expect:?}"
            );
        }
    }
}