Skip to main content

mz_balancerd/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! The balancerd service is a horizontally scalable, stateless, multi-tenant ingress router for
11//! pgwire and HTTPS connections.
12//!
13//! It listens on pgwire and HTTPS ports. When a new pgwire connection starts, the requested user is
14//! authenticated with frontegg from which a tenant id is returned. From that a target internal
15//! hostname is resolved to an IP address, and the connection is proxied to that address which has a
16//! running environmentd's pgwire port. When a new HTTPS connection starts, its SNI hostname is used
17//! to generate an internal hostname that is resolved to an IP address, which is similarly proxied.
18
19mod codec;
20mod dyncfgs;
21
22use std::collections::BTreeMap;
23use std::net::SocketAddr;
24use std::path::PathBuf;
25use std::pin::Pin;
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use anyhow::Context;
31use axum::response::IntoResponse;
32use axum::{Router, routing};
33use bytes::BytesMut;
34use domain::base::{Name, Rtype};
35use domain::rdata::AllRecordData;
36use domain::resolv::StubResolver;
37use futures::TryFutureExt;
38use futures::stream::BoxStream;
39use hyper::StatusCode;
40use hyper_util::rt::TokioIo;
41use launchdarkly_server_sdk as ld;
42use mz_build_info::{BuildInfo, build_info};
43use mz_dyncfg::ConfigSet;
44use mz_frontegg_auth::Authenticator as FronteggAuthentication;
45use mz_ore::cast::CastFrom;
46use mz_ore::id_gen::conn_id_org_uuid;
47use mz_ore::metrics::{ComputedGauge, IntCounter, IntGauge, MetricsRegistry};
48use mz_ore::netio::AsyncReady;
49use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
50use mz_ore::task::{JoinSetExt, spawn};
51use mz_ore::tracing::TracingHandle;
52use mz_ore::{metric, netio};
53use mz_pgwire_common::{
54    ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ErrorResponse, FrontendMessage,
55    FrontendStartupMessage, MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, VERSION_3, decode_startup,
56};
57use mz_server_core::{
58    Connection, ConnectionStream, ListenerHandle, ReloadTrigger, ReloadingSslContext,
59    ReloadingTlsConfig, ServeConfig, ServeDyncfg, TlsCertConfig, TlsMode, listen,
60};
61use openssl::ssl::{NameType, Ssl, SslConnector, SslMethod, SslVerifyMode};
62use prometheus::{IntCounterVec, IntGaugeVec};
63use proxy_header::{ProxiedAddress, ProxyHeader};
64use semver::Version;
65use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
66use tokio::net::TcpStream;
67use tokio::sync::oneshot;
68use tokio::task::JoinSet;
69use tokio_metrics::TaskMetrics;
70use tokio_openssl::SslStream;
71use tokio_postgres::error::SqlState;
72use tower::Service;
73use tracing::{debug, error, warn};
74use uuid::Uuid;
75
76use crate::codec::{BackendMessage, FramedConn};
77use crate::dyncfgs::{
78    INJECT_PROXY_PROTOCOL_HEADER_HTTP, SIGTERM_CONNECTION_WAIT, SIGTERM_LISTEN_WAIT,
79    has_tracing_config_update, tracing_config,
80};
81
82/// Balancer build information.
83pub const BUILD_INFO: BuildInfo = build_info!();
84
85pub struct BalancerConfig {
86    /// Info about which version of the code is running.
87    build_version: Version,
88    /// Listen address for internal HTTP health and metrics server.
89    internal_http_listen_addr: SocketAddr,
90    /// Listen address for pgwire connections.
91    pgwire_listen_addr: SocketAddr,
92    /// Listen address for HTTPS connections.
93    https_listen_addr: SocketAddr,
94    /// DNS resolver for pgwire cancellation requests
95    cancellation_resolver: CancellationResolver,
96    /// DNS resolver.
97    resolver: Resolver,
98    https_sni_addr_template: String,
99    tls: Option<TlsCertConfig>,
100    internal_tls: bool,
101    metrics_registry: MetricsRegistry,
102    reload_certs: BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>,
103    launchdarkly_sdk_key: Option<String>,
104    config_sync_file_path: Option<PathBuf>,
105    config_sync_timeout: Duration,
106    config_sync_loop_interval: Option<Duration>,
107    cloud_provider: Option<String>,
108    cloud_provider_region: Option<String>,
109    tracing_handle: TracingHandle,
110    default_configs: Vec<(String, String)>,
111}
112
113impl BalancerConfig {
114    pub fn new(
115        build_info: &BuildInfo,
116        internal_http_listen_addr: SocketAddr,
117        pgwire_listen_addr: SocketAddr,
118        https_listen_addr: SocketAddr,
119        cancellation_resolver: CancellationResolver,
120        resolver: Resolver,
121        https_sni_addr_template: String,
122        tls: Option<TlsCertConfig>,
123        internal_tls: bool,
124        metrics_registry: MetricsRegistry,
125        reload_certs: ReloadTrigger,
126        launchdarkly_sdk_key: Option<String>,
127        config_sync_file: Option<PathBuf>,
128        config_sync_timeout: Duration,
129        config_sync_loop_interval: Option<Duration>,
130        cloud_provider: Option<String>,
131        cloud_provider_region: Option<String>,
132        tracing_handle: TracingHandle,
133        default_configs: Vec<(String, String)>,
134    ) -> Self {
135        Self {
136            build_version: build_info.semver_version(),
137            internal_http_listen_addr,
138            pgwire_listen_addr,
139            https_listen_addr,
140            cancellation_resolver,
141            resolver,
142            https_sni_addr_template,
143            tls,
144            internal_tls,
145            metrics_registry,
146            reload_certs,
147            launchdarkly_sdk_key,
148            config_sync_file_path: config_sync_file,
149            config_sync_timeout,
150            config_sync_loop_interval,
151            cloud_provider,
152            cloud_provider_region,
153            tracing_handle,
154            default_configs,
155        }
156    }
157}
158
159/// Prometheus monitoring metrics.
160#[derive(Debug)]
161pub struct BalancerMetrics {
162    _uptime: ComputedGauge,
163}
164
165impl BalancerMetrics {
166    /// Returns a new [BalancerMetrics] instance connected to the registry in cfg.
167    pub fn new(cfg: &BalancerConfig) -> Self {
168        let start = Instant::now();
169        let uptime = cfg.metrics_registry.register_computed_gauge(
170            metric!(
171                name: "mz_balancer_metadata_seconds",
172                help: "server uptime, labels are build metadata",
173                const_labels: {
174                    "version" => cfg.build_version,
175                    "build_type" => if cfg!(release) { "release" } else { "debug" }
176                },
177            ),
178            move || start.elapsed().as_secs_f64(),
179        );
180        BalancerMetrics { _uptime: uptime }
181    }
182}
183
184pub struct BalancerService {
185    cfg: BalancerConfig,
186    pub pgwire: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
187    pub https: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
188    pub internal_http: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
189    _metrics: BalancerMetrics,
190    configs: ConfigSet,
191}
192
193impl BalancerService {
194    pub async fn new(cfg: BalancerConfig) -> Result<Self, anyhow::Error> {
195        let pgwire = listen(&cfg.pgwire_listen_addr).await?;
196        let https = listen(&cfg.https_listen_addr).await?;
197        let internal_http = listen(&cfg.internal_http_listen_addr).await?;
198        let metrics = BalancerMetrics::new(&cfg);
199        let mut configs = ConfigSet::default();
200        configs = dyncfgs::all_dyncfgs(configs);
201        dyncfgs::set_defaults(&configs, cfg.default_configs.clone())?;
202        let tracing_handle = cfg.tracing_handle.clone();
203        // Configure dyncfg sync
204        match (
205            cfg.launchdarkly_sdk_key.as_deref(),
206            cfg.config_sync_file_path.as_deref(),
207        ) {
208            (Some(key), None) => {
209                let _ = mz_dyncfg_launchdarkly::sync_launchdarkly_to_configset(
210                    configs.clone(),
211                    &BUILD_INFO,
212                    |builder| {
213                        let region = cfg
214                            .cloud_provider_region
215                            .clone()
216                            .unwrap_or_else(|| String::from("unknown"));
217                        if let Some(provider) = cfg.cloud_provider.clone() {
218                            builder.add_context(
219                                ld::ContextBuilder::new(format!(
220                                    "{}/{}/{}",
221                                    provider, region, cfg.build_version
222                                ))
223                                .kind("balancer")
224                                .set_string("provider", provider)
225                                .set_string("region", region)
226                                .set_string("version", cfg.build_version.to_string())
227                                .build()
228                                .map_err(|e| anyhow::anyhow!(e))?,
229                            );
230                        } else {
231                            builder.add_context(
232                                ld::ContextBuilder::new(format!(
233                                    "{}/{}/{}",
234                                    "unknown", region, cfg.build_version
235                                ))
236                                .anonymous(true) // exclude this user from the dashboard
237                                .kind("balancer")
238                                .set_string("provider", "unknown")
239                                .set_string("region", region)
240                                .set_string("version", cfg.build_version.to_string())
241                                .build()
242                                .map_err(|e| anyhow::anyhow!(e))?,
243                            );
244                        }
245                        Ok(())
246                    },
247                    Some(key),
248                    cfg.config_sync_timeout,
249                    cfg.config_sync_loop_interval,
250                    move |updates, configs| {
251                        if has_tracing_config_update(updates) {
252                            match tracing_config(configs) {
253                                Ok(parameters) => parameters.apply(&tracing_handle),
254                                Err(err) => warn!("unable to update tracing: {err}"),
255                            }
256                        }
257                    },
258                )
259                .await
260                .inspect_err(|e| warn!("LaunchDarkly sync error: {e}"));
261            }
262            (None, Some(path)) => {
263                let _ = mz_dyncfg_file::sync_file_to_configset(
264                    configs.clone(),
265                    path,
266                    cfg.config_sync_timeout,
267                    cfg.config_sync_loop_interval,
268                    move |updates, configs| {
269                        if has_tracing_config_update(updates) {
270                            match tracing_config(configs) {
271                                Ok(parameters) => parameters.apply(&tracing_handle),
272                                Err(err) => warn!("unable to update tracing: {err}"),
273                            }
274                        }
275                    },
276                )
277                .await
278                // If there's an Error, log but continue anyway. If LD is down
279                // we have no way of fetching the previous value of the flag
280                // (unlike the adapter, but it has a durable catalog). The
281                // ConfigSet defaults have been chosen to be good enough if this
282                // is the case.
283                .inspect_err(|e| warn!("File config sync error: {e}"));
284            }
285            (Some(_), Some(_)) => panic!(
286                "must provide either config_sync_file_path or launchdarkly_sdk_key for config syncing",
287            ),
288            (None, None) => {}
289        };
290        Ok(Self {
291            cfg,
292            pgwire,
293            https,
294            internal_http,
295            _metrics: metrics,
296            configs,
297        })
298    }
299
300    pub async fn serve(self) -> Result<(), anyhow::Error> {
301        let (pgwire_tls, https_tls) = match &self.cfg.tls {
302            Some(tls) => {
303                let context = tls.reloading_context(self.cfg.reload_certs)?;
304                (
305                    Some(ReloadingTlsConfig {
306                        context: context.clone(),
307                        mode: TlsMode::Require,
308                    }),
309                    Some(context),
310                )
311            }
312            None => (None, None),
313        };
314
315        let metrics = ServerMetricsConfig::register_into(&self.cfg.metrics_registry);
316
317        let mut set = JoinSet::new();
318        let mut server_handles = Vec::new();
319        let pgwire_addr = self.pgwire.0.local_addr();
320        let https_addr = self.https.0.local_addr();
321        let internal_http_addr = self.internal_http.0.local_addr();
322
323        {
324            let pgwire = PgwireBalancer {
325                resolver: Arc::new(self.cfg.resolver),
326                cancellation_resolver: Arc::new(self.cfg.cancellation_resolver),
327                tls: pgwire_tls,
328                internal_tls: self.cfg.internal_tls,
329                metrics: ServerMetrics::new(metrics.clone(), "pgwire"),
330                now: SYSTEM_TIME.clone(),
331            };
332            let (handle, stream) = self.pgwire;
333            server_handles.push(handle);
334            set.spawn_named(|| "pgwire_stream", {
335                let config_set = self.configs.clone();
336                async move {
337                    mz_server_core::serve(ServeConfig {
338                        server: pgwire,
339                        conns: stream,
340                        dyncfg: Some(ServeDyncfg {
341                            config_set,
342                            sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
343                        }),
344                    })
345                    .await;
346                    warn!("pgwire server exited");
347                }
348            });
349        }
350        {
351            let Some((addr, port)) = self.cfg.https_sni_addr_template.split_once(':') else {
352                panic!("expected port in https_addr_template");
353            };
354            let port: u16 = port.parse().expect("unexpected port");
355            let resolver = StubResolver::new();
356            let https = HttpsBalancer {
357                resolver: Arc::from(resolver),
358                tls: https_tls,
359                resolve_template: Arc::from(addr),
360                port,
361                metrics: Arc::from(ServerMetrics::new(metrics, "https")),
362                configs: self.configs.clone(),
363                internal_tls: self.cfg.internal_tls,
364            };
365            let (handle, stream) = self.https;
366            server_handles.push(handle);
367            set.spawn_named(|| "https_stream", {
368                let config_set = self.configs.clone();
369                async move {
370                    mz_server_core::serve(ServeConfig {
371                        server: https,
372                        conns: stream,
373                        dyncfg: Some(ServeDyncfg {
374                            config_set,
375                            sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
376                        }),
377                    })
378                    .await;
379                    warn!("https server exited");
380                }
381            });
382        }
383        {
384            let router = Router::new()
385                .route(
386                    "/metrics",
387                    routing::get(move || async move {
388                        mz_http_util::handle_prometheus(&self.cfg.metrics_registry).await
389                    }),
390                )
391                .route(
392                    "/api/livez",
393                    routing::get(mz_http_util::handle_liveness_check),
394                )
395                .route("/api/readyz", routing::get(handle_readiness_check));
396            let internal_http = InternalHttpServer { router };
397            let (handle, stream) = self.internal_http;
398            server_handles.push(handle);
399            set.spawn_named(|| "internal_http_stream", async move {
400                mz_server_core::serve(ServeConfig {
401                    server: internal_http,
402                    conns: stream,
403                    // Disable graceful termination because our internal
404                    // monitoring keeps persistent HTTP connections open.
405                    dyncfg: None,
406                })
407                .await;
408                warn!("internal_http server exited");
409            });
410        }
411        #[cfg(unix)]
412        {
413            let mut sigterm =
414                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
415            set.spawn_named(|| "sigterm_handler", async move {
416                sigterm.recv().await;
417                let wait = SIGTERM_LISTEN_WAIT.get(&self.configs);
418                warn!("received signal TERM - delaying for {:?}!", wait);
419                tokio::time::sleep(wait).await;
420                warn!("sigterm delay complete, dropping server handles");
421                drop(server_handles);
422            });
423        }
424
425        println!("balancerd {} listening...", BUILD_INFO.human_version(None));
426        println!(" TLS enabled: {}", self.cfg.tls.is_some());
427        println!(" pgwire address: {}", pgwire_addr);
428        println!(" HTTPS address: {}", https_addr);
429        println!(" internal HTTP address: {}", internal_http_addr);
430
431        // Wait for all tasks to exit, which can happen on SIGTERM.
432        while let Some(res) = set.join_next().await {
433            if let Err(err) = res {
434                error!("serving task failed: {err}")
435            }
436        }
437        Ok(())
438    }
439}
440
441#[allow(clippy::unused_async)]
442async fn handle_readiness_check() -> impl IntoResponse {
443    (StatusCode::OK, "ready")
444}
445
446struct InternalHttpServer {
447    router: Router,
448}
449
450impl mz_server_core::Server for InternalHttpServer {
451    const NAME: &'static str = "internal_http";
452
453    // TODO(jkosh44) consider forwarding the connection UUID to the adapter.
454    fn handle_connection(
455        &self,
456        conn: Connection,
457        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
458    ) -> mz_server_core::ConnectionHandler {
459        let router = self.router.clone();
460        let service = hyper::service::service_fn(move |req| router.clone().call(req));
461        let conn = TokioIo::new(conn);
462
463        Box::pin(async {
464            let http = hyper::server::conn::http1::Builder::new();
465            http.serve_connection(conn, service).err_into().await
466        })
467    }
468}
469
470/// Wraps an IntGauge and automatically `inc`s on init and `drop`s on drop. Callers should not call
471/// `inc().`. Useful for handling multiple task exit points, for example in the case of a panic.
472struct GaugeGuard {
473    gauge: IntGauge,
474}
475
476impl From<IntGauge> for GaugeGuard {
477    fn from(gauge: IntGauge) -> Self {
478        let _self = Self { gauge };
479        _self.gauge.inc();
480        _self
481    }
482}
483
484impl Drop for GaugeGuard {
485    fn drop(&mut self) {
486        self.gauge.dec();
487    }
488}
489
490#[derive(Clone, Debug)]
491struct ServerMetricsConfig {
492    connection_status: IntCounterVec,
493    active_connections: IntGaugeVec,
494    tenant_connections: IntGaugeVec,
495    tenant_connection_rx: IntCounterVec,
496    tenant_connection_tx: IntCounterVec,
497    tenant_pgwire_sni_count: IntCounterVec,
498}
499
500impl ServerMetricsConfig {
501    fn register_into(registry: &MetricsRegistry) -> Self {
502        let connection_status = registry.register(metric!(
503            name: "mz_balancer_connection_status",
504            help: "Count of completed network connections, by status",
505            var_labels: ["source", "status"],
506        ));
507        let active_connections = registry.register(metric!(
508            name: "mz_balancer_connection_active",
509            help: "Count of currently open network connections.",
510            var_labels: ["source"],
511        ));
512        let tenant_connections = registry.register(metric!(
513            name: "mz_balancer_tenant_connection_active",
514            help: "Count of opened network connections by tenant.",
515            var_labels: ["source",  "tenant"]
516        ));
517        let tenant_connection_rx = registry.register(metric!(
518            name: "mz_balancer_tenant_connection_rx",
519            help: "Number of bytes received from a client for a tenant.",
520            var_labels: ["source", "tenant"],
521        ));
522        let tenant_connection_tx = registry.register(metric!(
523            name: "mz_balancer_tenant_connection_tx",
524            help: "Number of bytes sent to a client for a tenant.",
525            var_labels: ["source", "tenant"],
526        ));
527        let tenant_pgwire_sni_count = registry.register(metric!(
528            name: "mz_balancer_tenant_pgwire_sni_count",
529            help: "Count of pgwire connections that have and do not have SNI available per tenant.",
530            var_labels: ["tenant", "has_sni"],
531        ));
532        Self {
533            connection_status,
534            active_connections,
535            tenant_connections,
536            tenant_connection_rx,
537            tenant_connection_tx,
538            tenant_pgwire_sni_count,
539        }
540    }
541}
542
543#[derive(Clone, Debug)]
544struct ServerMetrics {
545    inner: ServerMetricsConfig,
546    source: &'static str,
547}
548
549impl ServerMetrics {
550    fn new(inner: ServerMetricsConfig, source: &'static str) -> Self {
551        let self_ = Self { inner, source };
552
553        // Pre-initialize labels we are planning to use to ensure they are all always emitted as
554        // time series.
555        self_.connection_status(false);
556        self_.connection_status(true);
557        drop(self_.active_connections());
558
559        self_
560    }
561
562    fn connection_status(&self, is_ok: bool) -> IntCounter {
563        self.inner
564            .connection_status
565            .with_label_values(&[self.source, Self::status_label(is_ok)])
566    }
567
568    fn active_connections(&self) -> GaugeGuard {
569        self.inner
570            .active_connections
571            .with_label_values(&[self.source])
572            .into()
573    }
574
575    fn tenant_connections(&self, tenant: &str) -> GaugeGuard {
576        self.inner
577            .tenant_connections
578            .with_label_values(&[self.source, tenant])
579            .into()
580    }
581
582    fn tenant_connections_rx(&self, tenant: &str) -> IntCounter {
583        self.inner
584            .tenant_connection_rx
585            .with_label_values(&[self.source, tenant])
586    }
587
588    fn tenant_connections_tx(&self, tenant: &str) -> IntCounter {
589        self.inner
590            .tenant_connection_tx
591            .with_label_values(&[self.source, tenant])
592    }
593
594    fn tenant_pgwire_sni_count(&self, tenant: &str, has_sni: bool) -> IntCounter {
595        self.inner
596            .tenant_pgwire_sni_count
597            .with_label_values(&[tenant, &has_sni.to_string()])
598    }
599
600    fn status_label(is_ok: bool) -> &'static str {
601        if is_ok { "success" } else { "error" }
602    }
603}
604
/// How the target for a pgwire cancellation request is determined.
///
/// NOTE(review): the code that interprets these variants is not in this part of the file; the
/// per-variant descriptions below are inferred from the payload types — confirm against usage.
pub enum CancellationResolver {
    /// Carries a filesystem path, presumably a directory consulted for lookups.
    Directory(PathBuf),
    /// Carries a fixed string, presumably a static address or address template.
    Static(String),
}
609
610struct PgwireBalancer {
611    tls: Option<ReloadingTlsConfig>,
612    internal_tls: bool,
613    cancellation_resolver: Arc<CancellationResolver>,
614    resolver: Arc<Resolver>,
615    metrics: ServerMetrics,
616    now: NowFn,
617}
618
619impl PgwireBalancer {
620    #[mz_ore::instrument(level = "debug")]
621    async fn run<'a, A>(
622        conn: &'a mut FramedConn<A>,
623        version: i32,
624        params: BTreeMap<String, String>,
625        resolver: &Resolver,
626        tls_mode: Option<TlsMode>,
627        internal_tls: bool,
628        metrics: &ServerMetrics,
629    ) -> Result<(), io::Error>
630    where
631        A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
632    {
633        if version != VERSION_3 {
634            return conn
635                .send(ErrorResponse::fatal(
636                    SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
637                    "server does not support the client's requested protocol version",
638                ))
639                .await;
640        }
641
642        let Some(user) = params.get("user") else {
643            return conn
644                .send(ErrorResponse::fatal(
645                    SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
646                    "user parameter required",
647                ))
648                .await;
649        };
650
651        if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
652            return conn.send(err).await;
653        }
654
655        let resolved = match resolver.resolve(conn, user, metrics).await {
656            Ok(v) => v,
657            Err(err) => {
658                return conn
659                    .send(ErrorResponse::fatal(
660                        SqlState::INVALID_PASSWORD,
661                        err.to_string(),
662                    ))
663                    .await;
664            }
665        };
666
667        let _active_guard = resolved
668            .tenant
669            .as_ref()
670            .map(|tenant| metrics.tenant_connections(tenant));
671        let mut mz_stream =
672            match Self::init_stream(conn, resolved.addr, resolved.password, params, internal_tls)
673                .await
674            {
675                Ok(stream) => stream,
676                Err(e) => {
677                    error!("failed to connect to upstream server: {e}");
678                    return conn
679                        .send(ErrorResponse::fatal(
680                            SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
681                            "upstream server not available",
682                        ))
683                        .await;
684                }
685            };
686
687        let mut client_counter = CountingConn::new(conn.inner_mut());
688
689        // Now blindly shuffle bytes back and forth until closed.
690        // TODO: Limit total memory use.
691        let res = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
692        if let Some(tenant) = &resolved.tenant {
693            metrics
694                .tenant_connections_tx(tenant)
695                .inc_by(u64::cast_from(client_counter.written));
696            metrics
697                .tenant_connections_rx(tenant)
698                .inc_by(u64::cast_from(client_counter.read));
699        }
700        res?;
701
702        Ok(())
703    }
704
705    #[mz_ore::instrument(level = "debug")]
706    async fn init_stream<'a, A>(
707        conn: &'a mut FramedConn<A>,
708        envd_addr: SocketAddr,
709        password: Option<String>,
710        params: BTreeMap<String, String>,
711        internal_tls: bool,
712    ) -> Result<Conn<TcpStream>, anyhow::Error>
713    where
714        A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
715    {
716        let mut mz_stream = TcpStream::connect(envd_addr).await?;
717        let mut buf = BytesMut::new();
718
719        let mut mz_stream = if internal_tls {
720            FrontendStartupMessage::SslRequest.encode(&mut buf)?;
721            mz_stream.write_all(&buf).await?;
722            buf.clear();
723            let mut maybe_ssl_request_response = [0u8; 1];
724            let nread =
725                netio::read_exact_or_eof(&mut mz_stream, &mut maybe_ssl_request_response).await?;
726            if nread == 1 && maybe_ssl_request_response == [ACCEPT_SSL_ENCRYPTION] {
727                // do a TLS handshake
728                let mut builder =
729                    SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
730                // environmentd doesn't yet have a cert we trust, so for now disable verification.
731                builder.set_verify(SslVerifyMode::NONE);
732                let mut ssl = builder
733                    .build()
734                    .configure()?
735                    .into_ssl(&envd_addr.to_string())?;
736                ssl.set_connect_state();
737                Conn::Ssl(SslStream::new(ssl, mz_stream)?)
738            } else {
739                Conn::Unencrypted(mz_stream)
740            }
741        } else {
742            Conn::Unencrypted(mz_stream)
743        };
744
745        // Send initial startup and password messages.
746        let startup = FrontendStartupMessage::Startup {
747            version: VERSION_3,
748            params,
749        };
750        startup.encode(&mut buf)?;
751        mz_stream.write_all(&buf).await?;
752        let client_stream = conn.inner_mut();
753
754        // This early return is important in self managed with SASL mode.
755        // The below code specifically looks for cleartext password requests, but in SASL mode
756        // the server will send a different message type (SASLInitialResponse) that we should
757        // not try to interpret or respond to.
758        // "Why not? That code looks like it should fall back fine?" You may ask.
759        // The below block unconditionally reads 9 bytes from the server. If we don't have
760        // a password or the message isn't a cleartext password request, we forward those 9 bytes
761        // to the client. Then we return the stream to the caller, who will continue shuffling bytes.
762        // The problem is that with TLS enabled between balancerd <-> client, flushing the first 9 bytes
763        // before copying bidirectionally will have the side effect of splitting the auth handshake into
764        // two SSL records. Pgbouncer misbehaves in this scenario, and fails the connection.
765        // PGbouncer shouldn't do this! It's a common footgun of protocols over TLS.
766        // So common in fact that PGbouncer already hit and fixed this issue on the bouncer <-> client side:
767        // once before: https://github.com/pgbouncer/pgbouncer/pull/1058.
768        // We will work to upstream a fix, but in the meantime, this early return avoids the issue entirely.
769        if password.is_none() {
770            return Ok(mz_stream);
771        }
772
773        // Read a single backend message, which may be a password request. Send ours if so.
774        // Otherwise start shuffling bytes. message type (len 1, 'R') + message len (len 4, 8_i32) +
775        // auth type (len 4, 3_i32).
776        let mut maybe_auth_frame = [0; 1 + 4 + 4];
777        let nread = netio::read_exact_or_eof(&mut mz_stream, &mut maybe_auth_frame).await?;
778        // 'R' for auth message, 0008 for message length, 0003 for password cleartext variant.
779        // See: https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD
780        const AUTH_PASSWORD_CLEARTEXT: [u8; 9] = [b'R', 0, 0, 0, 8, 0, 0, 0, 3];
781        if nread == AUTH_PASSWORD_CLEARTEXT.len()
782            && maybe_auth_frame == AUTH_PASSWORD_CLEARTEXT
783            && password.is_some()
784        {
785            // If we got exactly a cleartext password request and have one, send it.
786            let Some(password) = password else {
787                unreachable!("verified some above");
788            };
789            let password = FrontendMessage::Password { password };
790            buf.clear();
791            password.encode(&mut buf)?;
792            mz_stream.write_all(&buf).await?;
793            mz_stream.flush().await?;
794        } else {
795            // Otherwise pass on the bytes we just got. This *might* even be a password request, but
796            // we don't have a password. In which case it can be forwarded up to the client.
797            client_stream.write_all(&maybe_auth_frame[0..nread]).await?;
798        }
799
800        Ok(mz_stream)
801    }
802}
803
804impl mz_server_core::Server for PgwireBalancer {
805    const NAME: &'static str = "pgwire_balancer";
806
807    fn handle_connection(
808        &self,
809        conn: Connection,
810        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
811    ) -> mz_server_core::ConnectionHandler {
812        let tls = self.tls.clone();
813        let internal_tls = self.internal_tls;
814        let resolver = Arc::clone(&self.resolver);
815        let inner_metrics = self.metrics.clone();
816        let outer_metrics = self.metrics.clone();
817        let cancellation_resolver = Arc::clone(&self.cancellation_resolver);
818        let conn_uuid = epoch_to_uuid_v7(&(self.now)());
819        let peer_addr = conn.peer_addr();
820        conn.uuid_handle().set(conn_uuid);
821        Box::pin(async move {
822            // TODO: Try to merge this with pgwire/server.rs to avoid the duplication. May not be
823            // worth it.
824            let active_guard = outer_metrics.active_connections();
825            let result: Result<(), anyhow::Error> = async move {
826                let mut conn = Conn::Unencrypted(conn);
827                loop {
828                    let message = decode_startup(&mut conn).await?;
829                    conn = match message {
830                        // Clients sometimes hang up during the startup sequence, e.g.
831                        // because they receive an unacceptable response to an
832                        // `SslRequest`. This is considered a graceful termination.
833                        None => return Ok(()),
834
835                        Some(FrontendStartupMessage::Startup {
836                            version,
837                            mut params,
838                        }) => {
839                            let mut conn = FramedConn::new(conn);
840                            let rejected =
841                                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION;
842                            let peer_addr = match peer_addr {
843                                Ok(addr) => addr.ip(),
844                                Err(e) => {
845                                    error!("Invalid peer_addr {:?}", e);
846                                    return Ok(conn
847                                        .send(ErrorResponse::fatal(
848                                            rejected,
849                                            "invalid peer address",
850                                        ))
851                                        .await?);
852                                }
853                            };
854                            debug!(
855                                %conn_uuid, %peer_addr,
856                                "starting new pgwire connection in balancer",
857                            );
858                            let prev =
859                                params.insert(CONN_UUID_KEY.to_string(), conn_uuid.to_string());
860                            if prev.is_some() {
861                                return Ok(conn
862                                    .send(ErrorResponse::fatal(
863                                        rejected,
864                                        format!("invalid parameter '{CONN_UUID_KEY}'"),
865                                    ))
866                                    .await?);
867                            }
868
869                            let forwarded_for = params.insert(
870                                MZ_FORWARDED_FOR_KEY.to_string(),
871                                peer_addr.to_string().clone(),
872                            );
873                            if let Some(_) = forwarded_for {
874                                return Ok(conn
875                                    .send(ErrorResponse::fatal(
876                                        rejected,
877                                        format!("invalid parameter '{MZ_FORWARDED_FOR_KEY}'"),
878                                    ))
879                                    .await?);
880                            };
881
882                            Self::run(
883                                &mut conn,
884                                version,
885                                params,
886                                &resolver,
887                                tls.map(|tls| tls.mode),
888                                internal_tls,
889                                &inner_metrics,
890                            )
891                            .await?;
892                            conn.flush().await?;
893                            return Ok(());
894                        }
895
896                        Some(FrontendStartupMessage::CancelRequest {
897                            conn_id,
898                            secret_key,
899                        }) => {
900                            spawn(|| "cancel request", async move {
901                                cancel_request(conn_id, secret_key, &cancellation_resolver).await;
902                            });
903                            // Do not wait on cancel requests to return because cancellation is best
904                            // effort.
905                            return Ok(());
906                        }
907
908                        Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
909                            (Conn::Unencrypted(mut conn), Some(tls)) => {
910                                conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
911                                let mut ssl_stream =
912                                    SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
913                                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
914                                    let _ = ssl_stream.get_mut().shutdown().await;
915                                    return Err(e.into());
916                                }
917                                Conn::Ssl(ssl_stream)
918                            }
919                            (mut conn, _) => {
920                                conn.write_all(&[REJECT_ENCRYPTION]).await?;
921                                conn
922                            }
923                        },
924
925                        Some(FrontendStartupMessage::GssEncRequest) => {
926                            conn.write_all(&[REJECT_ENCRYPTION]).await?;
927                            conn
928                        }
929                    }
930                }
931            }
932            .await;
933            drop(active_guard);
934            outer_metrics.connection_status(result.is_ok()).inc();
935            Ok(())
936        })
937    }
938}
939
/// Wraps a connection and keeps running totals of the bytes read from and written to it.
struct CountingConn<C> {
    /// The wrapped connection.
    inner: C,
    /// Number of bytes read so far.
    read: usize,
    /// Number of bytes written so far.
    written: usize,
}

impl<C> CountingConn<C> {
    /// Wraps `inner`, starting both byte counters at zero.
    fn new(inner: C) -> Self {
        Self {
            inner,
            read: 0,
            written: 0,
        }
    }
}
956
957impl<C> AsyncRead for CountingConn<C>
958where
959    C: AsyncRead + Unpin,
960{
961    fn poll_read(
962        self: Pin<&mut Self>,
963        cx: &mut std::task::Context<'_>,
964        buf: &mut io::ReadBuf<'_>,
965    ) -> std::task::Poll<std::io::Result<()>> {
966        let counter = self.get_mut();
967        let pin = Pin::new(&mut counter.inner);
968        let bytes = buf.filled().len();
969        let poll = pin.poll_read(cx, buf);
970        let bytes = buf.filled().len() - bytes;
971        if let std::task::Poll::Ready(Ok(())) = poll {
972            counter.read += bytes
973        }
974        poll
975    }
976}
977
978impl<C> AsyncWrite for CountingConn<C>
979where
980    C: AsyncWrite + Unpin,
981{
982    fn poll_write(
983        self: Pin<&mut Self>,
984        cx: &mut std::task::Context<'_>,
985        buf: &[u8],
986    ) -> std::task::Poll<Result<usize, std::io::Error>> {
987        let counter = self.get_mut();
988        let pin = Pin::new(&mut counter.inner);
989        let poll = pin.poll_write(cx, buf);
990        if let std::task::Poll::Ready(Ok(bytes)) = poll {
991            counter.written += bytes
992        }
993        poll
994    }
995
996    fn poll_flush(
997        self: Pin<&mut Self>,
998        cx: &mut std::task::Context<'_>,
999    ) -> std::task::Poll<Result<(), std::io::Error>> {
1000        let counter = self.get_mut();
1001        let pin = Pin::new(&mut counter.inner);
1002        pin.poll_flush(cx)
1003    }
1004
1005    fn poll_shutdown(
1006        self: Pin<&mut Self>,
1007        cx: &mut std::task::Context<'_>,
1008    ) -> std::task::Poll<Result<(), std::io::Error>> {
1009        let counter = self.get_mut();
1010        let pin = Pin::new(&mut counter.inner);
1011        pin.poll_shutdown(cx)
1012    }
1013}
1014
1015/// Broadcasts cancellation to all matching environmentds. `conn_id`'s bits [31..20] are the lower
1016/// 12 bits of a UUID for an environmentd/organization. Using that and the template in
1017/// `cancellation_resolver` we generate a hostname. That hostname resolves to all IPs of envds that
1018/// match the UUID (cloud k8s infrastructure maintains that mapping). This function creates a new
1019/// task for each envd and relays the cancellation message to it, broadcasting it to any envd that
1020/// might match the connection.
1021///
1022/// This function returns after it has spawned the tasks, and does not wait for them to complete.
1023/// This is acceptable because cancellation in the Postgres protocol is best effort and has no
1024/// guarantees.
1025///
1026/// The safety of broadcasting this is due to the various randomness in the connection id and secret
1027/// key, which must match exactly in order to execute a query cancellation. The connection id has 19
1028/// bits of randomness, and the secret key the full 32, for a total of 51 bits. That is more than
1029/// 2e15 combinations, enough to nearly certainly prevent two different envds generating identical
1030/// combinations.
1031async fn cancel_request(
1032    conn_id: u32,
1033    secret_key: u32,
1034    cancellation_resolver: &CancellationResolver,
1035) {
1036    let suffix = conn_id_org_uuid(conn_id);
1037    let contents = match cancellation_resolver {
1038        CancellationResolver::Directory(dir) => {
1039            let path = dir.join(&suffix);
1040            match std::fs::read_to_string(&path) {
1041                Ok(contents) => contents,
1042                Err(err) => {
1043                    error!("could not read cancel file {path:?}: {err}");
1044                    return;
1045                }
1046            }
1047        }
1048        CancellationResolver::Static(addr) => addr.to_owned(),
1049    };
1050    let mut all_ips = Vec::new();
1051    for addr in contents.lines() {
1052        let addr = addr.trim();
1053        if addr.is_empty() {
1054            continue;
1055        }
1056        match tokio::net::lookup_host(addr).await {
1057            Ok(ips) => all_ips.extend(ips),
1058            Err(err) => {
1059                error!("{addr} failed resolution: {err}");
1060            }
1061        }
1062    }
1063    let mut buf = BytesMut::with_capacity(16);
1064    let msg = FrontendStartupMessage::CancelRequest {
1065        conn_id,
1066        secret_key,
1067    };
1068    msg.encode(&mut buf).expect("must encode");
1069    let buf = buf.freeze();
1070    for ip in all_ips {
1071        debug!("cancelling {suffix} to {ip}");
1072        let buf = buf.clone();
1073        spawn(|| "cancel request for ip", async move {
1074            let send = async {
1075                let mut stream = TcpStream::connect(&ip).await?;
1076                stream.write_all(&buf).await?;
1077                stream.shutdown().await?;
1078                Ok::<_, io::Error>(())
1079            };
1080            if let Err(err) = send.await {
1081                error!("error mirroring cancel to {ip}: {err}");
1082            }
1083        });
1084    }
1085}
1086
1087struct HttpsBalancer {
1088    resolver: Arc<StubResolver>,
1089    tls: Option<ReloadingSslContext>,
1090    resolve_template: Arc<str>,
1091    port: u16,
1092    metrics: Arc<ServerMetrics>,
1093    configs: ConfigSet,
1094    internal_tls: bool,
1095}
1096
1097impl HttpsBalancer {
1098    async fn resolve(
1099        resolver: &StubResolver,
1100        resolve_template: &str,
1101        port: u16,
1102        servername: Option<&str>,
1103    ) -> Result<ResolvedAddr, anyhow::Error> {
1104        let addr = match &servername {
1105            Some(servername) => resolve_template.replace("{}", servername),
1106            None => resolve_template.to_string(),
1107        };
1108        debug!("https address: {addr}");
1109
1110        // When we lookup the address using SNI, we get a hostname (`3dl07g8zmj91pntk4eo9cfvwe` for
1111        // example), which you convert into a different form for looking up the environment address
1112        // `blncr-3dl07g8zmj91pntk4eo9cfvwe`. When you do a DNS lookup in kubernetes for
1113        // `blncr-3dl07g8zmj91pntk4eo9cfvwe`, you get a CNAME response pointing at environmentd
1114        // `environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local`. This
1115        // is of the form `<service>.<namespace>.svc.cluster.local`. That `<namespace>` is the same
1116        // as the environment name, and is based on the tenant ID. `environment-<tenant_id>-<index>`
1117        // We currently only support a single environment per tenant in a region, so `<index>` is
1118        // always 0. Do not rely on this ending in `-0` so in the future multiple envds are
1119        // supported.
1120
1121        // Attempt to get a tenant.
1122        let tenant = resolver.tenant(&addr).await;
1123
1124        // Now do the regular ip lookup, regardless of if there was a CNAME.
1125        let envd_addr = lookup(&format!("{addr}:{port}")).await?;
1126
1127        Ok(ResolvedAddr {
1128            addr: envd_addr,
1129            password: None,
1130            tenant,
1131        })
1132    }
1133}
1134
1135trait StubResolverExt {
1136    async fn tenant(&self, addr: &str) -> Option<String>;
1137}
1138
1139impl StubResolverExt for StubResolver {
1140    /// Finds the tenant of a DNS address. Errors or lack of cname resolution here are ok, because
1141    /// this is only used for metrics.
1142    async fn tenant(&self, addr: &str) -> Option<String> {
1143        let Ok(dname) = Name::<Vec<_>>::from_str(addr) else {
1144            return None;
1145        };
1146        debug!("resolving tenant for {:?}", addr);
1147        // Lookup the CNAME. If there's a CNAME, find the tenant.
1148        let lookup = self.query((dname, Rtype::CNAME)).await;
1149        if let Ok(lookup) = lookup {
1150            if let Ok(answer) = lookup.answer() {
1151                let res = answer.limit_to::<AllRecordData<_, _>>();
1152                for record in res {
1153                    let Ok(record) = record else {
1154                        continue;
1155                    };
1156                    if record.rtype() != Rtype::CNAME {
1157                        continue;
1158                    }
1159                    let cname = record.data();
1160                    let cname = cname.to_string();
1161                    debug!("cname: {cname}");
1162                    return extract_tenant_from_cname(&cname);
1163                }
1164            }
1165        }
1166        None
1167    }
1168}
1169
1170/// Extracts the tenant from a CNAME.
1171fn extract_tenant_from_cname(cname: &str) -> Option<String> {
1172    let mut parts = cname.split('.');
1173    let _service = parts.next();
1174    let Some(namespace) = parts.next() else {
1175        return None;
1176    };
1177    // Trim off the starting `environmentd-`.
1178    let Some((_, namespace)) = namespace.split_once('-') else {
1179        return None;
1180    };
1181    // Trim off the ending `-0` (or some other number).
1182    let Some((tenant, _)) = namespace.rsplit_once('-') else {
1183        return None;
1184    };
1185    // Convert to a Uuid so that this tenant matches the frontegg resolver exactly, because it
1186    // also uses Uuid::to_string.
1187    let Ok(tenant) = Uuid::parse_str(tenant) else {
1188        error!("cname tenant not a uuid: {tenant}");
1189        return None;
1190    };
1191    Some(tenant.to_string())
1192}
1193
1194impl mz_server_core::Server for HttpsBalancer {
1195    const NAME: &'static str = "https_balancer";
1196
1197    // TODO(jkosh44) consider forwarding the connection UUID to the adapter.
1198    fn handle_connection(
1199        &self,
1200        conn: Connection,
1201        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
1202    ) -> mz_server_core::ConnectionHandler {
1203        let tls_context = self.tls.clone();
1204        let internal_tls = self.internal_tls.clone();
1205        let resolver = Arc::clone(&self.resolver);
1206        let resolve_template = Arc::clone(&self.resolve_template);
1207        let port = self.port;
1208        let inner_metrics = Arc::clone(&self.metrics);
1209        let outer_metrics = Arc::clone(&self.metrics);
1210        let peer_addr = conn.peer_addr();
1211        let inject_proxy_headers = INJECT_PROXY_PROTOCOL_HEADER_HTTP.get(&self.configs);
1212        Box::pin(async move {
1213            let active_guard = inner_metrics.active_connections();
1214            let result: Result<_, anyhow::Error> = Box::pin(async move {
1215                let peer_addr = peer_addr.context("fetching peer addr")?;
1216                let (mut client_stream, servername): (Box<dyn ClientStream>, Option<String>) =
1217                    match tls_context {
1218                        Some(tls_context) => {
1219                            let mut ssl_stream =
1220                                SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
1221                            if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1222                                let _ = ssl_stream.get_mut().shutdown().await;
1223                                return Err(e.into());
1224                            }
1225                            let servername: Option<String> =
1226                                ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1227                                    match sn.split_once('.') {
1228                                        Some((left, _right)) => left,
1229                                        None => sn,
1230                                    }
1231                                    .into()
1232                                });
1233                            debug!("Found sni servername: {servername:?} (https)");
1234                            (Box::new(ssl_stream), servername)
1235                        }
1236                        _ => (Box::new(conn), None),
1237                    };
1238                let resolved =
1239                    Self::resolve(&resolver, &resolve_template, port, servername.as_deref())
1240                        .await?;
1241                let inner_active_guard = resolved
1242                    .tenant
1243                    .as_ref()
1244                    .map(|tenant| inner_metrics.tenant_connections(tenant));
1245                let mut mz_stream = match TcpStream::connect(resolved.addr).await {
1246                    Ok(stream) => stream,
1247                    Err(e) => {
1248                        error!("failed to connect to upstream server: {e}");
1249                        let body = "upstream server not available";
1250                        // We know this is an HTTPs stream (see name
1251                        // HttpsBalancer), but we actually don't care what type
1252                        // of traffic it is and we only use raw tcp streams.In
1253                        // order to respond with HTTP we have to write this as a
1254                        // raw http message.
1255                        let response = format!(
1256                            "HTTP/1.1 502 Bad Gateway\r\n\
1257                             Content-Type: text/plain\r\n\
1258                             Content-Length: {}\r\n\
1259                             Connection: close\r\n\
1260                             \r\n\
1261                             {}",
1262                            body.len(),
1263                            body
1264                        );
1265                        let _ = client_stream.write_all(response.as_bytes()).await;
1266                        let _ = client_stream.shutdown().await;
1267                        return Ok(());
1268                    }
1269                };
1270
1271                if inject_proxy_headers {
1272                    // Write the tcp proxy header
1273                    let addrs = ProxiedAddress::stream(peer_addr, resolved.addr);
1274                    let header = ProxyHeader::with_address(addrs);
1275                    let mut buf = [0u8; 1024];
1276                    let len = header.encode_to_slice_v2(&mut buf)?;
1277                    mz_stream.write_all(&buf[..len]).await?;
1278                }
1279
1280                let mut mz_stream = if internal_tls {
1281                    // do a TLS handshake
1282                    let mut builder =
1283                        SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
1284                    // environmentd doesn't yet have a cert we trust, so for now disable verification.
1285                    builder.set_verify(SslVerifyMode::NONE);
1286                    let mut ssl = builder
1287                        .build()
1288                        .configure()?
1289                        .into_ssl(&resolved.addr.to_string())?;
1290                    ssl.set_connect_state();
1291                    Conn::Ssl(SslStream::new(ssl, mz_stream)?)
1292                } else {
1293                    Conn::Unencrypted(mz_stream)
1294                };
1295
1296                let mut client_counter = CountingConn::new(client_stream);
1297
1298                // Now blindly shuffle bytes back and forth until closed.
1299                // TODO: Limit total memory use.
1300                // See corresponding comment in pgwire implementation about ignoring the error.
1301                let _ = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
1302                if let Some(tenant) = &resolved.tenant {
1303                    inner_metrics
1304                        .tenant_connections_tx(tenant)
1305                        .inc_by(u64::cast_from(client_counter.written));
1306                    inner_metrics
1307                        .tenant_connections_rx(tenant)
1308                        .inc_by(u64::cast_from(client_counter.read));
1309                }
1310                drop(inner_active_guard);
1311                Ok(())
1312            })
1313            .await;
1314            drop(active_guard);
1315            outer_metrics.connection_status(result.is_ok()).inc();
1316            if let Err(e) = result {
1317                debug!("connection error: {e}");
1318            }
1319            Ok(())
1320        })
1321    }
1322}
1323
/// Settings for resolving a pgwire connection's backend from its TLS SNI servername.
#[derive(Debug)]
pub struct SniResolver {
    /// DNS resolver used for the best-effort tenant lookup.
    pub resolver: StubResolver,
    /// Hostname template; each `{}` is replaced with the SNI servername.
    pub template: String,
    /// Port appended to the templated hostname before the address lookup.
    pub port: u16,
}
1330
1331trait ClientStream: AsyncRead + AsyncWrite + Unpin + Send {}
1332impl<T: AsyncRead + AsyncWrite + Unpin + Send> ClientStream for T {}
1333
1334#[derive(Debug)]
1335pub enum Resolver {
1336    Static(String),
1337    MultiTenant(FronteggResolver, Option<SniResolver>),
1338}
1339
1340impl Resolver {
1341    async fn resolve<A>(
1342        &self,
1343        conn: &mut FramedConn<A>,
1344        user: &str,
1345        metrics: &ServerMetrics,
1346    ) -> Result<ResolvedAddr, anyhow::Error>
1347    where
1348        A: AsyncRead + AsyncWrite + Unpin,
1349    {
1350        match self {
1351            Resolver::MultiTenant(
1352                FronteggResolver {
1353                    auth,
1354                    addr_template,
1355                },
1356                sni_resolver,
1357            ) => {
1358                let servername = match conn.inner() {
1359                    Conn::Ssl(ssl_stream) => {
1360                        ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1361                            match sn.split_once('.') {
1362                                Some((left, _right)) => left,
1363                                None => sn,
1364                            }
1365                        })
1366                    }
1367                    Conn::Unencrypted(_) => None,
1368                };
1369                let has_sni = servername.is_some();
1370                // We found an SNi
1371                let resolved_addr = match (servername, sni_resolver) {
1372                    (
1373                        Some(servername),
1374                        Some(SniResolver {
1375                            resolver: stub_resolver,
1376                            template: sni_addr_template,
1377                            port,
1378                        }),
1379                    ) => {
1380                        let sni_addr = sni_addr_template.replace("{}", servername);
1381                        let tenant = stub_resolver.tenant(&sni_addr).await;
1382                        let sni_addr = format!("{sni_addr}:{port}");
1383                        let addr = lookup(&sni_addr).await?;
1384                        if tenant.is_some() {
1385                            debug!("SNI header found for tenant {:?}", tenant);
1386                        }
1387                        ResolvedAddr {
1388                            addr,
1389                            password: None,
1390                            tenant,
1391                        }
1392                    }
1393                    _ => {
1394                        conn.send(BackendMessage::AuthenticationCleartextPassword)
1395                            .await?;
1396                        conn.flush().await?;
1397                        let password = match conn.recv().await? {
1398                            Some(FrontendMessage::Password { password }) => password,
1399                            _ => anyhow::bail!("expected Password message"),
1400                        };
1401
1402                        let auth_response = auth.authenticate(user, &password).await;
1403                        let auth_session = match auth_response {
1404                            Ok((auth_session, _)) => auth_session,
1405                            Err(e) => {
1406                                warn!("pgwire connection failed authentication: {}", e);
1407                                // TODO: fix error codes.
1408                                anyhow::bail!("invalid password");
1409                            }
1410                        };
1411
1412                        let addr =
1413                            addr_template.replace("{}", &auth_session.tenant_id().to_string());
1414                        let addr = lookup(&addr).await?;
1415                        let tenant = Some(auth_session.tenant_id().to_string());
1416                        if tenant.is_some() {
1417                            debug!("SNI header NOT found for tenant {:?}", tenant);
1418                        }
1419                        ResolvedAddr {
1420                            addr,
1421                            password: Some(password),
1422                            tenant,
1423                        }
1424                    }
1425                };
1426                metrics
1427                    .tenant_pgwire_sni_count(
1428                        resolved_addr.tenant.as_deref().unwrap_or("unknown"),
1429                        has_sni,
1430                    )
1431                    .inc();
1432
1433                Ok(resolved_addr)
1434            }
1435            Resolver::Static(addr) => {
1436                let addr = lookup(addr).await?;
1437                Ok(ResolvedAddr {
1438                    addr,
1439                    password: None,
1440                    tenant: None,
1441                })
1442            }
1443        }
1444    }
1445}
1446
1447/// Returns the first IP address resolved from the provided hostname.
1448async fn lookup(name: &str) -> Result<SocketAddr, anyhow::Error> {
1449    let mut addrs = tokio::net::lookup_host(name).await?;
1450    match addrs.next() {
1451        Some(addr) => Ok(addr),
1452        None => {
1453            error!("{name} did not resolve to any addresses");
1454            anyhow::bail!("internal error")
1455        }
1456    }
1457}
1458
/// Settings for resolving a pgwire connection's backend by authenticating with Frontegg.
#[derive(Debug)]
pub struct FronteggResolver {
    /// Authenticator used to check the user's password.
    pub auth: FronteggAuthentication,
    /// Address template; each `{}` is replaced with the authenticated session's tenant id.
    pub addr_template: String,
}
1464
/// Result of resolving a connection to its backend.
#[derive(Debug)]
struct ResolvedAddr {
    /// Socket address of the backend to connect to.
    addr: SocketAddr,
    /// Password collected from the client during resolution, if one was requested.
    password: Option<String>,
    /// Tenant the connection belongs to, when known; used to label metrics.
    tenant: Option<String>,
}
1471
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `extract_tenant_from_cname` against well-formed and malformed CNAMEs.
    #[mz_ore::test]
    fn test_tenant() {
        // The canonical form every accepted CNAME below must produce.
        const TENANT: &str = "58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3";

        let cases: [(&str, Option<&str>); 9] = [
            ("", None),
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                Some(TENANT),
            ),
            // Variously named parts.
            (
                "service.something-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.ssvvcc.cloister.faraway",
                Some(TENANT),
            ),
            // No dashes in uuid.
            (
                "environmentd.environment-58cd23ffa4d74bd0ad85a6ff29cc86c3-0.svc.cluster.local",
                Some(TENANT),
            ),
            // -1234 suffix.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-1234.svc.cluster.local",
                Some(TENANT),
            ),
            // Uppercase.
            (
                "environmentd.environment-58CD23FF-A4D7-4BD0-AD85-A6FF29CC86C3-0.svc.cluster.local",
                Some(TENANT),
            ),
            // No -number suffix.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3.svc.cluster.local",
                None,
            ),
            // No service name.
            (
                "environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
            // Invalid UUID.
            (
                "environmentd.environment-8cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
        ];

        for (name, expect) in cases {
            let cname = extract_tenant_from_cname(name);
            assert_eq!(
                cname.as_deref(),
                expect,
                "{name} got {cname:?} expected {expect:?}"
            );
        }
    }
}