mz_balancerd/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! The balancerd service is a horizontally scalable, stateless, multi-tenant ingress router for
11//! pgwire and HTTPS connections.
12//!
//! It listens on pgwire and HTTPS ports. When a new pgwire connection starts, the requested user is
//! authenticated with Frontegg, which returns a tenant id. From that tenant id, a target internal
//! hostname is resolved to an IP address, and the connection is proxied to that address, where an
//! environmentd pgwire port is listening. When a new HTTPS connection starts, its SNI hostname is
//! used to generate an internal hostname that is resolved to an IP address, to which the connection
//! is similarly proxied.
18
19mod codec;
20mod dyncfgs;
21
22use std::collections::BTreeMap;
23use std::net::SocketAddr;
24use std::path::PathBuf;
25use std::pin::Pin;
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use anyhow::Context;
31use axum::response::IntoResponse;
32use axum::{Router, routing};
33use bytes::BytesMut;
34use domain::base::{Dname, Rtype};
35use domain::rdata::AllRecordData;
36use domain::resolv::StubResolver;
37use futures::TryFutureExt;
38use futures::stream::BoxStream;
39use hyper::StatusCode;
40use hyper_util::rt::TokioIo;
41use launchdarkly_server_sdk as ld;
42use mz_build_info::{BuildInfo, build_info};
43use mz_dyncfg::ConfigSet;
44use mz_frontegg_auth::Authenticator as FronteggAuthentication;
45use mz_ore::cast::CastFrom;
46use mz_ore::id_gen::conn_id_org_uuid;
47use mz_ore::metrics::{ComputedGauge, IntCounter, IntGauge, MetricsRegistry};
48use mz_ore::netio::AsyncReady;
49use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
50use mz_ore::task::{JoinSetExt, spawn};
51use mz_ore::tracing::TracingHandle;
52use mz_ore::{metric, netio};
53use mz_pgwire_common::{
54    ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ErrorResponse, FrontendMessage,
55    FrontendStartupMessage, MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, VERSION_3, decode_startup,
56};
57use mz_server_core::{
58    Connection, ConnectionStream, ListenerHandle, ReloadTrigger, ReloadingSslContext,
59    ReloadingTlsConfig, ServeConfig, ServeDyncfg, TlsCertConfig, TlsMode, listen,
60};
61use openssl::ssl::{NameType, Ssl, SslConnector, SslMethod, SslVerifyMode};
62use prometheus::{IntCounterVec, IntGaugeVec};
63use proxy_header::{ProxiedAddress, ProxyHeader};
64use semver::Version;
65use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
66use tokio::net::TcpStream;
67use tokio::sync::oneshot;
68use tokio::task::JoinSet;
69use tokio_openssl::SslStream;
70use tokio_postgres::error::SqlState;
71use tower::Service;
72use tracing::{debug, error, warn};
73use uuid::Uuid;
74
75use crate::codec::{BackendMessage, FramedConn};
76use crate::dyncfgs::{
77    INJECT_PROXY_PROTOCOL_HEADER_HTTP, SIGTERM_CONNECTION_WAIT, SIGTERM_LISTEN_WAIT,
78    has_tracing_config_update, tracing_config,
79};
80
81/// Balancer build information.
82pub const BUILD_INFO: BuildInfo = build_info!();
83
/// Static configuration from which a [`BalancerService`] is created.
pub struct BalancerConfig {
    /// Info about which version of the code is running.
    build_version: Version,
    /// Listen address for internal HTTP health and metrics server.
    internal_http_listen_addr: SocketAddr,
    /// Listen address for pgwire connections.
    pgwire_listen_addr: SocketAddr,
    /// Listen address for HTTPS connections.
    https_listen_addr: SocketAddr,
    /// DNS resolver for pgwire cancellation requests.
    cancellation_resolver: CancellationResolver,
    /// DNS resolver.
    resolver: Resolver,
    /// Backend address for HTTPS connections, written as `<address template>:<port>`. It is split
    /// at the first `:` when serving starts and the port must parse as a `u16`.
    https_addr_template: String,
    /// TLS certificate configuration for the public pgwire and HTTPS listeners. `None` disables
    /// TLS on both.
    tls: Option<TlsCertConfig>,
    /// Whether to attempt TLS on the connections the balancer opens to backends.
    internal_tls: bool,
    /// Registry in which all balancer metrics are registered; it is also what the internal
    /// `/metrics` endpoint serves.
    metrics_registry: MetricsRegistry,
    /// Stream of certificate reload requests, handed to the reloading TLS context.
    reload_certs: BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>,
    /// SDK key passed to the LaunchDarkly config sync.
    /// NOTE(review): presumably `None` disables the sync -- confirm in `mz_dyncfg_launchdarkly`.
    launchdarkly_sdk_key: Option<String>,
    /// Timeout passed to the LaunchDarkly config sync.
    config_sync_timeout: Duration,
    /// Loop interval passed to the LaunchDarkly config sync.
    config_sync_loop_interval: Option<Duration>,
    /// Cloud provider reported in the LaunchDarkly context; `"unknown"` is reported when unset.
    cloud_provider: Option<String>,
    /// Cloud provider region reported in the LaunchDarkly context; `"unknown"` is reported when
    /// unset.
    cloud_provider_region: Option<String>,
    /// Handle used to apply tracing parameters when the dynamic tracing configuration changes.
    tracing_handle: TracingHandle,
    /// Name/value pairs applied as dynamic configuration defaults before the first sync.
    default_configs: Vec<(String, String)>,
}
110
111impl BalancerConfig {
112    pub fn new(
113        build_info: &BuildInfo,
114        internal_http_listen_addr: SocketAddr,
115        pgwire_listen_addr: SocketAddr,
116        https_listen_addr: SocketAddr,
117        cancellation_resolver: CancellationResolver,
118        resolver: Resolver,
119        https_addr_template: String,
120        tls: Option<TlsCertConfig>,
121        internal_tls: bool,
122        metrics_registry: MetricsRegistry,
123        reload_certs: ReloadTrigger,
124        launchdarkly_sdk_key: Option<String>,
125        config_sync_timeout: Duration,
126        config_sync_loop_interval: Option<Duration>,
127        cloud_provider: Option<String>,
128        cloud_provider_region: Option<String>,
129        tracing_handle: TracingHandle,
130        default_configs: Vec<(String, String)>,
131    ) -> Self {
132        Self {
133            build_version: build_info.semver_version(),
134            internal_http_listen_addr,
135            pgwire_listen_addr,
136            https_listen_addr,
137            cancellation_resolver,
138            resolver,
139            https_addr_template,
140            tls,
141            internal_tls,
142            metrics_registry,
143            reload_certs,
144            launchdarkly_sdk_key,
145            config_sync_timeout,
146            config_sync_loop_interval,
147            cloud_provider,
148            cloud_provider_region,
149            tracing_handle,
150            default_configs,
151        }
152    }
153}
154
/// Prometheus monitoring metrics.
#[derive(Debug)]
pub struct BalancerMetrics {
    /// Handle to the registered uptime gauge. It is never read after construction (hence the
    /// leading underscore); it is stored so the handle lives as long as this struct.
    _uptime: ComputedGauge,
}
160
161impl BalancerMetrics {
162    /// Returns a new [BalancerMetrics] instance connected to the registry in cfg.
163    pub fn new(cfg: &BalancerConfig) -> Self {
164        let start = Instant::now();
165        let uptime = cfg.metrics_registry.register_computed_gauge(
166            metric!(
167                name: "mz_balancer_metadata_seconds",
168                help: "server uptime, labels are build metadata",
169                const_labels: {
170                    "version" => cfg.build_version,
171                    "build_type" => if cfg!(release) { "release" } else { "debug" }
172                },
173            ),
174            move || start.elapsed().as_secs_f64(),
175        );
176        BalancerMetrics { _uptime: uptime }
177    }
178}
179
/// A balancer whose listeners have been created and whose dynamic configuration has been
/// initialized. Call [`BalancerService::serve`] to start handling connections.
pub struct BalancerService {
    /// Static configuration the service was created from.
    cfg: BalancerConfig,
    /// Listener handle and incoming connection stream for pgwire.
    pub pgwire: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Listener handle and incoming connection stream for HTTPS.
    pub https: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Listener handle and incoming connection stream for the internal HTTP server.
    pub internal_http: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Metrics registered at construction; never read, only kept for the service's lifetime.
    _metrics: BalancerMetrics,
    /// Dynamic configuration values, initialized with defaults and synced from LaunchDarkly.
    configs: ConfigSet,
}
188
189impl BalancerService {
190    pub async fn new(cfg: BalancerConfig) -> Result<Self, anyhow::Error> {
191        let pgwire = listen(&cfg.pgwire_listen_addr).await?;
192        let https = listen(&cfg.https_listen_addr).await?;
193        let internal_http = listen(&cfg.internal_http_listen_addr).await?;
194        let metrics = BalancerMetrics::new(&cfg);
195        let mut configs = ConfigSet::default();
196        configs = dyncfgs::all_dyncfgs(configs);
197        dyncfgs::set_defaults(&configs, cfg.default_configs.clone())?;
198        let tracing_handle = cfg.tracing_handle.clone();
199        if let Err(err) = mz_dyncfg_launchdarkly::sync_launchdarkly_to_configset(
200            configs.clone(),
201            &BUILD_INFO,
202            |builder| {
203                let region = cfg
204                    .cloud_provider_region
205                    .clone()
206                    .unwrap_or_else(|| String::from("unknown"));
207                if let Some(provider) = cfg.cloud_provider.clone() {
208                    builder.add_context(
209                        ld::ContextBuilder::new(format!(
210                            "{}/{}/{}",
211                            provider, region, cfg.build_version
212                        ))
213                        .kind("balancer")
214                        .set_string("provider", provider)
215                        .set_string("region", region)
216                        .set_string("version", cfg.build_version.to_string())
217                        .build()
218                        .map_err(|e| anyhow::anyhow!(e))?,
219                    );
220                } else {
221                    builder.add_context(
222                        ld::ContextBuilder::new(format!(
223                            "{}/{}/{}",
224                            "unknown", region, cfg.build_version
225                        ))
226                        .anonymous(true) // exclude this user from the dashboard
227                        .kind("balancer")
228                        .set_string("provider", "unknown")
229                        .set_string("region", region)
230                        .set_string("version", cfg.build_version.to_string())
231                        .build()
232                        .map_err(|e| anyhow::anyhow!(e))?,
233                    );
234                }
235                Ok(())
236            },
237            cfg.launchdarkly_sdk_key.as_deref(),
238            cfg.config_sync_timeout,
239            cfg.config_sync_loop_interval,
240            move |updates, configs| {
241                if has_tracing_config_update(updates) {
242                    match tracing_config(configs) {
243                        Ok(parameters) => parameters.apply(&tracing_handle),
244                        Err(err) => warn!("unable to update tracing: {err}"),
245                    }
246                }
247            },
248        )
249        .await
250        {
251            // Log but continue anyway. If LD is down we have no way of fetching the previous value
252            // of the flag (unlike the adapter, but it has a durable catalog). The ConfigSet
253            // defaults have been chosen to be good enough if this is the case.
254            warn!("LaunchDarkly sync error: {err}");
255        }
256        Ok(Self {
257            cfg,
258            pgwire,
259            https,
260            internal_http,
261            _metrics: metrics,
262            configs,
263        })
264    }
265
266    pub async fn serve(self) -> Result<(), anyhow::Error> {
267        let (pgwire_tls, https_tls) = match &self.cfg.tls {
268            Some(tls) => {
269                let context = tls.reloading_context(self.cfg.reload_certs)?;
270                (
271                    Some(ReloadingTlsConfig {
272                        context: context.clone(),
273                        mode: TlsMode::Require,
274                    }),
275                    Some(context),
276                )
277            }
278            None => (None, None),
279        };
280
281        let metrics = ServerMetricsConfig::register_into(&self.cfg.metrics_registry);
282
283        let mut set = JoinSet::new();
284        let mut server_handles = Vec::new();
285        let pgwire_addr = self.pgwire.0.local_addr();
286        let https_addr = self.https.0.local_addr();
287        let internal_http_addr = self.internal_http.0.local_addr();
288
289        {
290            let pgwire = PgwireBalancer {
291                resolver: Arc::new(self.cfg.resolver),
292                cancellation_resolver: Arc::new(self.cfg.cancellation_resolver),
293                tls: pgwire_tls,
294                internal_tls: self.cfg.internal_tls,
295                metrics: ServerMetrics::new(metrics.clone(), "pgwire"),
296                now: SYSTEM_TIME.clone(),
297            };
298            let (handle, stream) = self.pgwire;
299            server_handles.push(handle);
300            set.spawn_named(|| "pgwire_stream", {
301                let config_set = self.configs.clone();
302                async move {
303                    mz_server_core::serve(ServeConfig {
304                        server: pgwire,
305                        conns: stream,
306                        dyncfg: Some(ServeDyncfg {
307                            config_set,
308                            sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
309                        }),
310                    })
311                    .await;
312                    warn!("pgwire server exited");
313                }
314            });
315        }
316        {
317            let Some((addr, port)) = self.cfg.https_addr_template.split_once(':') else {
318                panic!("expected port in https_addr_template");
319            };
320            let port: u16 = port.parse().expect("unexpected port");
321            let resolver = StubResolver::new();
322            let https = HttpsBalancer {
323                resolver: Arc::from(resolver),
324                tls: https_tls,
325                resolve_template: Arc::from(addr),
326                port,
327                metrics: Arc::from(ServerMetrics::new(metrics, "https")),
328                configs: self.configs.clone(),
329                internal_tls: self.cfg.internal_tls,
330            };
331            let (handle, stream) = self.https;
332            server_handles.push(handle);
333            set.spawn_named(|| "https_stream", {
334                let config_set = self.configs.clone();
335                async move {
336                    mz_server_core::serve(ServeConfig {
337                        server: https,
338                        conns: stream,
339                        dyncfg: Some(ServeDyncfg {
340                            config_set,
341                            sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
342                        }),
343                    })
344                    .await;
345                    warn!("https server exited");
346                }
347            });
348        }
349        {
350            let router = Router::new()
351                .route(
352                    "/metrics",
353                    routing::get(move || async move {
354                        mz_http_util::handle_prometheus(&self.cfg.metrics_registry).await
355                    }),
356                )
357                .route(
358                    "/api/livez",
359                    routing::get(mz_http_util::handle_liveness_check),
360                )
361                .route("/api/readyz", routing::get(handle_readiness_check));
362            let internal_http = InternalHttpServer { router };
363            let (handle, stream) = self.internal_http;
364            server_handles.push(handle);
365            set.spawn_named(|| "internal_http_stream", async move {
366                mz_server_core::serve(ServeConfig {
367                    server: internal_http,
368                    conns: stream,
369                    // Disable graceful termination because our internal
370                    // monitoring keeps persistent HTTP connections open.
371                    dyncfg: None,
372                })
373                .await;
374                warn!("internal_http server exited");
375            });
376        }
377        #[cfg(unix)]
378        {
379            let mut sigterm =
380                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
381            set.spawn_named(|| "sigterm_handler", async move {
382                sigterm.recv().await;
383                let wait = SIGTERM_LISTEN_WAIT.get(&self.configs);
384                warn!("received signal TERM - delaying for {:?}!", wait);
385                tokio::time::sleep(wait).await;
386                warn!("sigterm delay complete, dropping server handles");
387                drop(server_handles);
388            });
389        }
390
391        println!("balancerd {} listening...", BUILD_INFO.human_version(None));
392        println!(" TLS enabled: {}", self.cfg.tls.is_some());
393        println!(" pgwire address: {}", pgwire_addr);
394        println!(" HTTPS address: {}", https_addr);
395        println!(" internal HTTP address: {}", internal_http_addr);
396
397        // Wait for all tasks to exit, which can happen on SIGTERM.
398        while let Some(res) = set.join_next().await {
399            if let Err(err) = res {
400                error!("serving task failed: {err}")
401            }
402        }
403        Ok(())
404    }
405}
406
407#[allow(clippy::unused_async)]
408async fn handle_readiness_check() -> impl IntoResponse {
409    (StatusCode::OK, "ready")
410}
411
/// Server for the internal HTTP listener (metrics, liveness, and readiness endpoints).
struct InternalHttpServer {
    /// Router that every incoming request on the internal listener is dispatched to.
    router: Router,
}
415
416impl mz_server_core::Server for InternalHttpServer {
417    const NAME: &'static str = "internal_http";
418
419    // TODO(jkosh44) consider forwarding the connection UUID to the adapter.
420    fn handle_connection(&self, conn: Connection) -> mz_server_core::ConnectionHandler {
421        let router = self.router.clone();
422        let service = hyper::service::service_fn(move |req| router.clone().call(req));
423        let conn = TokioIo::new(conn);
424
425        Box::pin(async {
426            let http = hyper::server::conn::http1::Builder::new();
427            http.serve_connection(conn, service).err_into().await
428        })
429    }
430}
431
/// Wraps an IntGauge and automatically `inc`s on init and `dec`s on drop. Callers should not call
/// `inc()` themselves. Useful for handling multiple task exit points, for example in the case of
/// a panic.
struct GaugeGuard {
    /// The gauge that is incremented on construction and decremented on drop.
    gauge: IntGauge,
}
437
438impl From<IntGauge> for GaugeGuard {
439    fn from(gauge: IntGauge) -> Self {
440        let _self = Self { gauge };
441        _self.gauge.inc();
442        _self
443    }
444}
445
impl Drop for GaugeGuard {
    /// Decrements the wrapped gauge, undoing the increment performed when the guard was created.
    fn drop(&mut self) {
        self.gauge.dec();
    }
}
451
/// The connection metric families shared by all balancer servers. Per-server views are created
/// by wrapping this in a [`ServerMetrics`] with a fixed `source` label.
#[derive(Clone, Debug)]
struct ServerMetricsConfig {
    /// Count of completed network connections, labeled by `source` and `status`.
    connection_status: IntCounterVec,
    /// Count of currently open network connections, labeled by `source`.
    active_connections: IntGaugeVec,
    /// Count of open network connections, labeled by `source` and `tenant`.
    tenant_connections: IntGaugeVec,
    /// Bytes received from clients, labeled by `source` and `tenant`.
    tenant_connection_rx: IntCounterVec,
    /// Bytes sent to clients, labeled by `source` and `tenant`.
    tenant_connection_tx: IntCounterVec,
    /// Count of pgwire connections with and without SNI, labeled by `tenant` and `has_sni`.
    tenant_pgwire_sni_count: IntCounterVec,
}
461
462impl ServerMetricsConfig {
463    fn register_into(registry: &MetricsRegistry) -> Self {
464        let connection_status = registry.register(metric!(
465            name: "mz_balancer_connection_status",
466            help: "Count of completed network connections, by status",
467            var_labels: ["source", "status"],
468        ));
469        let active_connections = registry.register(metric!(
470            name: "mz_balancer_connection_active",
471            help: "Count of currently open network connections.",
472            var_labels: ["source"],
473        ));
474        let tenant_connections = registry.register(metric!(
475            name: "mz_balancer_tenant_connection_active",
476            help: "Count of opened network connections by tenant.",
477            var_labels: ["source",  "tenant"]
478        ));
479        let tenant_connection_rx = registry.register(metric!(
480            name: "mz_balancer_tenant_connection_rx",
481            help: "Number of bytes received from a client for a tenant.",
482            var_labels: ["source", "tenant"],
483        ));
484        let tenant_connection_tx = registry.register(metric!(
485            name: "mz_balancer_tenant_connection_tx",
486            help: "Number of bytes sent to a client for a tenant.",
487            var_labels: ["source", "tenant"],
488        ));
489        let tenant_pgwire_sni_count = registry.register(metric!(
490            name: "mz_balancer_tenant_pgwire_sni_count",
491            help: "Count of pgwire connections that have and do not have SNI available per tenant.",
492            var_labels: ["tenant", "has_sni"],
493        ));
494        Self {
495            connection_status,
496            active_connections,
497            tenant_connections,
498            tenant_connection_rx,
499            tenant_connection_tx,
500            tenant_pgwire_sni_count,
501        }
502    }
503}
504
/// A view of [`ServerMetricsConfig`] for one server, which fills in the `source` label.
#[derive(Clone, Debug)]
struct ServerMetrics {
    /// The shared metric families.
    inner: ServerMetricsConfig,
    /// Value used for the `source` label on metrics that have one.
    source: &'static str,
}
510
511impl ServerMetrics {
512    fn new(inner: ServerMetricsConfig, source: &'static str) -> Self {
513        let self_ = Self { inner, source };
514
515        // Pre-initialize labels we are planning to use to ensure they are all always emitted as
516        // time series.
517        self_.connection_status(false);
518        self_.connection_status(true);
519        drop(self_.active_connections());
520
521        self_
522    }
523
524    fn connection_status(&self, is_ok: bool) -> IntCounter {
525        self.inner
526            .connection_status
527            .with_label_values(&[self.source, Self::status_label(is_ok)])
528    }
529
530    fn active_connections(&self) -> GaugeGuard {
531        self.inner
532            .active_connections
533            .with_label_values(&[self.source])
534            .into()
535    }
536
537    fn tenant_connections(&self, tenant: &str) -> GaugeGuard {
538        self.inner
539            .tenant_connections
540            .with_label_values(&[self.source, tenant])
541            .into()
542    }
543
544    fn tenant_connections_rx(&self, tenant: &str) -> IntCounter {
545        self.inner
546            .tenant_connection_rx
547            .with_label_values(&[self.source, tenant])
548    }
549
550    fn tenant_connections_tx(&self, tenant: &str) -> IntCounter {
551        self.inner
552            .tenant_connection_tx
553            .with_label_values(&[self.source, tenant])
554    }
555
556    fn tenant_pgwire_sni_count(&self, tenant: &str, has_sni: bool) -> IntCounter {
557        self.inner
558            .tenant_pgwire_sni_count
559            .with_label_values(&[tenant, &has_sni.to_string()])
560    }
561
562    fn status_label(is_ok: bool) -> &'static str {
563        if is_ok { "success" } else { "error" }
564    }
565}
566
/// How targets for pgwire cancellation requests are resolved.
///
/// NOTE(review): the code that consumes these variants is not visible here; the per-variant
/// descriptions below are assumptions to confirm.
pub enum CancellationResolver {
    /// Presumably resolves targets using the contents of this directory -- confirm.
    Directory(PathBuf),
    /// Presumably resolves every request using this fixed address string -- confirm.
    Static(String),
}
571
/// Server that accepts pgwire connections and proxies them to the resolved backend.
struct PgwireBalancer {
    /// TLS configuration for client connections, or `None` when TLS is disabled.
    tls: Option<ReloadingTlsConfig>,
    /// Whether to attempt TLS on the connection opened to the backend.
    internal_tls: bool,
    /// Resolver for pgwire cancellation requests.
    cancellation_resolver: Arc<CancellationResolver>,
    /// Resolver that maps the connecting user to a backend address.
    resolver: Arc<Resolver>,
    /// Connection metrics for this server.
    metrics: ServerMetrics,
    /// Clock whose current time is used to generate each connection's UUID.
    now: NowFn,
}
580
581impl PgwireBalancer {
582    #[mz_ore::instrument(level = "debug")]
583    async fn run<'a, A>(
584        conn: &'a mut FramedConn<A>,
585        version: i32,
586        params: BTreeMap<String, String>,
587        resolver: &Resolver,
588        tls_mode: Option<TlsMode>,
589        internal_tls: bool,
590        metrics: &ServerMetrics,
591    ) -> Result<(), io::Error>
592    where
593        A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
594    {
595        if version != VERSION_3 {
596            return conn
597                .send(ErrorResponse::fatal(
598                    SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
599                    "server does not support the client's requested protocol version",
600                ))
601                .await;
602        }
603
604        let Some(user) = params.get("user") else {
605            return conn
606                .send(ErrorResponse::fatal(
607                    SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
608                    "user parameter required",
609                ))
610                .await;
611        };
612
613        if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
614            return conn.send(err).await;
615        }
616
617        let resolved = match resolver.resolve(conn, user).await {
618            Ok(v) => v,
619            Err(err) => {
620                return conn
621                    .send(ErrorResponse::fatal(
622                        SqlState::INVALID_PASSWORD,
623                        err.to_string(),
624                    ))
625                    .await;
626            }
627        };
628
629        // Count the # of pgwire connections that have SNI available / unavailable
630        // per tenant. In the future we may want to remove non-SNI connections.
631        if let Conn::Ssl(ssl_stream) = conn.inner() {
632            let tenant = resolved.tenant.as_deref().unwrap_or("unknown");
633            let has_sni = ssl_stream.ssl().servername(NameType::HOST_NAME).is_some();
634            metrics.tenant_pgwire_sni_count(tenant, has_sni).inc();
635        }
636
637        let _active_guard = resolved
638            .tenant
639            .as_ref()
640            .map(|tenant| metrics.tenant_connections(tenant));
641        let Ok(mut mz_stream) =
642            Self::init_stream(conn, resolved.addr, resolved.password, params, internal_tls).await
643        else {
644            return Ok(());
645        };
646
647        let mut client_counter = CountingConn::new(conn.inner_mut());
648
649        // Now blindly shuffle bytes back and forth until closed.
650        // TODO: Limit total memory use.
651        let res = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
652        if let Some(tenant) = &resolved.tenant {
653            metrics
654                .tenant_connections_tx(tenant)
655                .inc_by(u64::cast_from(client_counter.written));
656            metrics
657                .tenant_connections_rx(tenant)
658                .inc_by(u64::cast_from(client_counter.read));
659        }
660        res?;
661
662        Ok(())
663    }
664
665    #[mz_ore::instrument(level = "debug")]
666    async fn init_stream<'a, A>(
667        conn: &'a mut FramedConn<A>,
668        envd_addr: SocketAddr,
669        password: Option<String>,
670        params: BTreeMap<String, String>,
671        internal_tls: bool,
672    ) -> Result<Conn<TcpStream>, anyhow::Error>
673    where
674        A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
675    {
676        let mut mz_stream = TcpStream::connect(envd_addr).await?;
677        let mut buf = BytesMut::new();
678
679        let mut mz_stream = if internal_tls {
680            FrontendStartupMessage::SslRequest.encode(&mut buf)?;
681            mz_stream.write_all(&buf).await?;
682            buf.clear();
683            let mut maybe_ssl_request_response = [0u8; 1];
684            let nread =
685                netio::read_exact_or_eof(&mut mz_stream, &mut maybe_ssl_request_response).await?;
686            if nread == 1 && maybe_ssl_request_response == [ACCEPT_SSL_ENCRYPTION] {
687                // do a TLS handshake
688                let mut builder =
689                    SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
690                // environmentd doesn't yet have a cert we trust, so for now disable verification.
691                builder.set_verify(SslVerifyMode::NONE);
692                let mut ssl = builder
693                    .build()
694                    .configure()?
695                    .into_ssl(&envd_addr.to_string())?;
696                ssl.set_connect_state();
697                Conn::Ssl(SslStream::new(ssl, mz_stream)?)
698            } else {
699                Conn::Unencrypted(mz_stream)
700            }
701        } else {
702            Conn::Unencrypted(mz_stream)
703        };
704
705        // Send initial startup and password messages.
706        let startup = FrontendStartupMessage::Startup {
707            version: VERSION_3,
708            params,
709        };
710        startup.encode(&mut buf)?;
711        mz_stream.write_all(&buf).await?;
712        let client_stream = conn.inner_mut();
713
714        // Read a single backend message, which may be a password request. Send ours if so.
715        // Otherwise start shuffling bytes. message type (len 1, 'R') + message len (len 4, 8_i32) +
716        // auth type (len 4, 3_i32).
717        let mut maybe_auth_frame = [0; 1 + 4 + 4];
718        let nread = netio::read_exact_or_eof(&mut mz_stream, &mut maybe_auth_frame).await?;
719        // 'R' for auth message, 0008 for message length, 0003 for password cleartext variant.
720        // See: https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD
721        const AUTH_PASSWORD_CLEARTEXT: [u8; 9] = [b'R', 0, 0, 0, 8, 0, 0, 0, 3];
722        if nread == AUTH_PASSWORD_CLEARTEXT.len()
723            && maybe_auth_frame == AUTH_PASSWORD_CLEARTEXT
724            && password.is_some()
725        {
726            // If we got exactly a cleartext password request and have one, send it.
727            let Some(password) = password else {
728                unreachable!("verified some above");
729            };
730            let password = FrontendMessage::Password { password };
731            buf.clear();
732            password.encode(&mut buf)?;
733            mz_stream.write_all(&buf).await?;
734            mz_stream.flush().await?;
735        } else {
736            // Otherwise pass on the bytes we just got. This *might* even be a password request, but
737            // we don't have a password. In which case it can be forwarded up to the client.
738            client_stream.write_all(&maybe_auth_frame[0..nread]).await?;
739        }
740
741        Ok(mz_stream)
742    }
743}
744
745impl mz_server_core::Server for PgwireBalancer {
746    const NAME: &'static str = "pgwire_balancer";
747
748    fn handle_connection(&self, conn: Connection) -> mz_server_core::ConnectionHandler {
749        let tls = self.tls.clone();
750        let internal_tls = self.internal_tls;
751        let resolver = Arc::clone(&self.resolver);
752        let inner_metrics = self.metrics.clone();
753        let outer_metrics = self.metrics.clone();
754        let cancellation_resolver = Arc::clone(&self.cancellation_resolver);
755        let conn_uuid = epoch_to_uuid_v7(&(self.now)());
756        let peer_addr = conn.peer_addr();
757        conn.uuid_handle().set(conn_uuid);
758        Box::pin(async move {
759            // TODO: Try to merge this with pgwire/server.rs to avoid the duplication. May not be
760            // worth it.
761            let active_guard = outer_metrics.active_connections();
762            let result: Result<(), anyhow::Error> = async move {
763                let mut conn = Conn::Unencrypted(conn);
764                loop {
765                    let message = decode_startup(&mut conn).await?;
766                    conn = match message {
767                        // Clients sometimes hang up during the startup sequence, e.g.
768                        // because they receive an unacceptable response to an
769                        // `SslRequest`. This is considered a graceful termination.
770                        None => return Ok(()),
771
772                        Some(FrontendStartupMessage::Startup {
773                            version,
774                            mut params,
775                        }) => {
776                            let mut conn = FramedConn::new(conn);
777                            let peer_addr = match peer_addr {
778                                Ok(addr) => addr.ip(),
779                                Err(e) => {
780                                    error!("Invalid peer_addr {:?}", e);
781                                    return Ok(conn
782                                        .send(ErrorResponse::fatal(
783                                            SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
784                                            "invalid peer address",
785                                        ))
786                                        .await?);
787                                }
788                            };
789                            debug!(%conn_uuid, %peer_addr,  "starting new pgwire connection in balancer");
790                            let prev =
791                                params.insert(CONN_UUID_KEY.to_string(), conn_uuid.to_string());
792                            if prev.is_some() {
793                                return Ok(conn
794                                    .send(ErrorResponse::fatal(
795                                        SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
796                                        format!("invalid parameter '{CONN_UUID_KEY}'"),
797                                    ))
798                                    .await?);
799                            }
800
801                            if let Some(_) = params.insert(MZ_FORWARDED_FOR_KEY.to_string(), peer_addr.to_string().clone()) {
802                                return Ok(conn
803                                    .send(ErrorResponse::fatal(
804                                        SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
805                                        format!("invalid parameter '{MZ_FORWARDED_FOR_KEY}'"),
806                                    ))
807                                    .await?);
808                            };
809
810                            Self::run(
811                                &mut conn,
812                                version,
813                                params,
814                                &resolver,
815                                tls.map(|tls| tls.mode),
816                                internal_tls,
817                                &inner_metrics,
818                            )
819                            .await?;
820                            conn.flush().await?;
821                            return Ok(());
822                        }
823
824                        Some(FrontendStartupMessage::CancelRequest {
825                            conn_id,
826                            secret_key,
827                        }) => {
828                            spawn(|| "cancel request", async move {
829                                cancel_request(conn_id, secret_key, &cancellation_resolver).await;
830                            });
831                            // Do not wait on cancel requests to return because cancellation is best
832                            // effort.
833                            return Ok(());
834                        }
835
836                        Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
837                            (Conn::Unencrypted(mut conn), Some(tls)) => {
838                                conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
839                                let mut ssl_stream =
840                                    SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
841                                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
842                                    let _ = ssl_stream.get_mut().shutdown().await;
843                                    return Err(e.into());
844                                }
845                                Conn::Ssl(ssl_stream)
846                            }
847                            (mut conn, _) => {
848                                conn.write_all(&[REJECT_ENCRYPTION]).await?;
849                                conn
850                            }
851                        },
852
853                        Some(FrontendStartupMessage::GssEncRequest) => {
854                            conn.write_all(&[REJECT_ENCRYPTION]).await?;
855                            conn
856                        }
857                    }
858                }
859            }
860            .await;
861            drop(active_guard);
862            outer_metrics.connection_status(result.is_ok()).inc();
863            Ok(())
864        })
865    }
866}
867
/// A connection wrapper that tallies the bytes read from and written to the connection it wraps.
struct CountingConn<C> {
    /// The wrapped connection that all I/O is delegated to.
    inner: C,
    /// Total number of bytes read through this wrapper so far.
    read: usize,
    /// Total number of bytes written through this wrapper so far.
    written: usize,
}

impl<C> CountingConn<C> {
    /// Wraps `inner`, starting both byte counters at zero.
    fn new(inner: C) -> Self {
        Self {
            read: 0,
            written: 0,
            inner,
        }
    }
}
884
885impl<C> AsyncRead for CountingConn<C>
886where
887    C: AsyncRead + Unpin,
888{
889    fn poll_read(
890        self: Pin<&mut Self>,
891        cx: &mut std::task::Context<'_>,
892        buf: &mut io::ReadBuf<'_>,
893    ) -> std::task::Poll<std::io::Result<()>> {
894        let counter = self.get_mut();
895        let pin = Pin::new(&mut counter.inner);
896        let bytes = buf.filled().len();
897        let poll = pin.poll_read(cx, buf);
898        let bytes = buf.filled().len() - bytes;
899        if let std::task::Poll::Ready(Ok(())) = poll {
900            counter.read += bytes
901        }
902        poll
903    }
904}
905
906impl<C> AsyncWrite for CountingConn<C>
907where
908    C: AsyncWrite + Unpin,
909{
910    fn poll_write(
911        self: Pin<&mut Self>,
912        cx: &mut std::task::Context<'_>,
913        buf: &[u8],
914    ) -> std::task::Poll<Result<usize, std::io::Error>> {
915        let counter = self.get_mut();
916        let pin = Pin::new(&mut counter.inner);
917        let poll = pin.poll_write(cx, buf);
918        if let std::task::Poll::Ready(Ok(bytes)) = poll {
919            counter.written += bytes
920        }
921        poll
922    }
923
924    fn poll_flush(
925        self: Pin<&mut Self>,
926        cx: &mut std::task::Context<'_>,
927    ) -> std::task::Poll<Result<(), std::io::Error>> {
928        let counter = self.get_mut();
929        let pin = Pin::new(&mut counter.inner);
930        pin.poll_flush(cx)
931    }
932
933    fn poll_shutdown(
934        self: Pin<&mut Self>,
935        cx: &mut std::task::Context<'_>,
936    ) -> std::task::Poll<Result<(), std::io::Error>> {
937        let counter = self.get_mut();
938        let pin = Pin::new(&mut counter.inner);
939        pin.poll_shutdown(cx)
940    }
941}
942
943/// Broadcasts cancellation to all matching environmentds. `conn_id`'s bits [31..20] are the lower
944/// 12 bits of a UUID for an environmentd/organization. Using that and the template in
945/// `cancellation_resolver` we generate a hostname. That hostname resolves to all IPs of envds that
946/// match the UUID (cloud k8s infrastructure maintains that mapping). This function creates a new
947/// task for each envd and relays the cancellation message to it, broadcasting it to any envd that
948/// might match the connection.
949///
950/// This function returns after it has spawned the tasks, and does not wait for them to complete.
951/// This is acceptable because cancellation in the Postgres protocol is best effort and has no
952/// guarantees.
953///
954/// The safety of broadcasting this is due to the various randomness in the connection id and secret
955/// key, which must match exactly in order to execute a query cancellation. The connection id has 19
956/// bits of randomness, and the secret key the full 32, for a total of 51 bits. That is more than
957/// 2e15 combinations, enough to nearly certainly prevent two different envds generating identical
958/// combinations.
959async fn cancel_request(
960    conn_id: u32,
961    secret_key: u32,
962    cancellation_resolver: &CancellationResolver,
963) {
964    let suffix = conn_id_org_uuid(conn_id);
965    let contents = match cancellation_resolver {
966        CancellationResolver::Directory(dir) => {
967            let path = dir.join(&suffix);
968            match std::fs::read_to_string(&path) {
969                Ok(contents) => contents,
970                Err(err) => {
971                    error!("could not read cancel file {path:?}: {err}");
972                    return;
973                }
974            }
975        }
976        CancellationResolver::Static(addr) => addr.to_owned(),
977    };
978    let mut all_ips = Vec::new();
979    for addr in contents.lines() {
980        let addr = addr.trim();
981        if addr.is_empty() {
982            continue;
983        }
984        match tokio::net::lookup_host(addr).await {
985            Ok(ips) => all_ips.extend(ips),
986            Err(err) => {
987                error!("{addr} failed resolution: {err}");
988            }
989        }
990    }
991    let mut buf = BytesMut::with_capacity(16);
992    let msg = FrontendStartupMessage::CancelRequest {
993        conn_id,
994        secret_key,
995    };
996    msg.encode(&mut buf).expect("must encode");
997    let buf = buf.freeze();
998    for ip in all_ips {
999        debug!("cancelling {suffix} to {ip}");
1000        let buf = buf.clone();
1001        spawn(|| "cancel request for ip", async move {
1002            let send = async {
1003                let mut stream = TcpStream::connect(&ip).await?;
1004                stream.write_all(&buf).await?;
1005                stream.shutdown().await?;
1006                Ok::<_, io::Error>(())
1007            };
1008            if let Err(err) = send.await {
1009                error!("error mirroring cancel to {ip}: {err}");
1010            }
1011        });
1012    }
1013}
1014
/// Server state for proxying incoming HTTPS connections to an upstream address.
struct HttpsBalancer {
    /// DNS resolver used for the CNAME lookup that determines a connection's tenant.
    resolver: Arc<StubResolver>,
    /// TLS context for accepting client connections. When `None`, connections are proxied
    /// without a TLS handshake and no SNI servername is available.
    tls: Option<ReloadingSslContext>,
    /// Hostname template for the upstream; `{}` is replaced with the SNI-derived servername.
    resolve_template: Arc<str>,
    /// Port appended to the templated hostname when resolving the upstream address.
    port: u16,
    /// Metrics updated for every handled connection.
    metrics: Arc<ServerMetrics>,
    /// Dynamic configuration, read to decide whether a PROXY protocol header is injected.
    configs: ConfigSet,
    /// Whether the connection to the upstream is wrapped in TLS.
    internal_tls: bool,
}
1024
1025impl HttpsBalancer {
1026    async fn resolve(
1027        resolver: &StubResolver,
1028        resolve_template: &str,
1029        port: u16,
1030        servername: Option<&str>,
1031    ) -> Result<ResolvedAddr, anyhow::Error> {
1032        let addr = match &servername {
1033            Some(servername) => resolve_template.replace("{}", servername),
1034            None => resolve_template.to_string(),
1035        };
1036        debug!("https address: {addr}");
1037
1038        // When we lookup the address using SNI, we get a hostname (`3dl07g8zmj91pntk4eo9cfvwe` for
1039        // example), which you convert into a different form for looking up the environment address
1040        // `blncr-3dl07g8zmj91pntk4eo9cfvwe`. When you do a DNS lookup in kubernetes for
1041        // `blncr-3dl07g8zmj91pntk4eo9cfvwe`, you get a CNAME response pointing at environmentd
1042        // `environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local`. This
1043        // is of the form `<service>.<namespace>.svc.cluster.local`. That `<namespace>` is the same
1044        // as the environment name, and is based on the tenant ID. `environment-<tenant_id>-<index>`
1045        // We currently only support a single environment per tenant in a region, so `<index>` is
1046        // always 0. Do not rely on this ending in `-0` so in the future multiple envds are
1047        // supported.
1048
1049        // Attempt to get a tenant.
1050        let tenant = Self::tenant(resolver, &addr).await;
1051
1052        // Now do the regular ip lookup, regardless of if there was a CNAME.
1053        let envd_addr = lookup(&format!("{addr}:{port}")).await?;
1054
1055        Ok(ResolvedAddr {
1056            addr: envd_addr,
1057            password: None,
1058            tenant,
1059        })
1060    }
1061
1062    /// Finds the tenant of a DNS address. Errors or lack of cname resolution here are ok, because
1063    /// this is only used for metrics.
1064    async fn tenant(resolver: &StubResolver, addr: &str) -> Option<String> {
1065        let Ok(dname) = Dname::<Vec<_>>::from_str(addr) else {
1066            return None;
1067        };
1068        // Lookup the CNAME. If there's a CNAME, find the tenant.
1069        let lookup = resolver.query((dname, Rtype::Cname)).await;
1070        if let Ok(lookup) = lookup {
1071            if let Ok(answer) = lookup.answer() {
1072                let res = answer.limit_to::<AllRecordData<_, _>>();
1073                for record in res {
1074                    let Ok(record) = record else {
1075                        continue;
1076                    };
1077                    if record.rtype() != Rtype::Cname {
1078                        continue;
1079                    }
1080                    let cname = record.data();
1081                    let cname = cname.to_string();
1082                    debug!("cname: {cname}");
1083                    return Self::extract_tenant_from_cname(&cname);
1084                }
1085            }
1086        }
1087        None
1088    }
1089
1090    /// Extracts the tenant from a CNAME.
1091    fn extract_tenant_from_cname(cname: &str) -> Option<String> {
1092        let mut parts = cname.split('.');
1093        let _service = parts.next();
1094        let Some(namespace) = parts.next() else {
1095            return None;
1096        };
1097        // Trim off the starting `environmentd-`.
1098        let Some((_, namespace)) = namespace.split_once('-') else {
1099            return None;
1100        };
1101        // Trim off the ending `-0` (or some other number).
1102        let Some((tenant, _)) = namespace.rsplit_once('-') else {
1103            return None;
1104        };
1105        // Convert to a Uuid so that this tenant matches the frontegg resolver exactly, because it
1106        // also uses Uuid::to_string.
1107        let Ok(tenant) = Uuid::parse_str(tenant) else {
1108            error!("cname tenant not a uuid: {tenant}");
1109            return None;
1110        };
1111        Some(tenant.to_string())
1112    }
1113}
1114
1115impl mz_server_core::Server for HttpsBalancer {
1116    const NAME: &'static str = "https_balancer";
1117
1118    // TODO(jkosh44) consider forwarding the connection UUID to the adapter.
1119    fn handle_connection(&self, conn: Connection) -> mz_server_core::ConnectionHandler {
1120        let tls_context = self.tls.clone();
1121        let internal_tls = self.internal_tls.clone();
1122        let resolver = Arc::clone(&self.resolver);
1123        let resolve_template = Arc::clone(&self.resolve_template);
1124        let port = self.port;
1125        let inner_metrics = Arc::clone(&self.metrics);
1126        let outer_metrics = Arc::clone(&self.metrics);
1127        let peer_addr = conn.peer_addr();
1128        let inject_proxy_headers = INJECT_PROXY_PROTOCOL_HEADER_HTTP.get(&self.configs);
1129        Box::pin(async move {
1130            let active_guard = inner_metrics.active_connections();
1131            let result: Result<_, anyhow::Error> = Box::pin(async move {
1132                let peer_addr = peer_addr.context("fetching peer addr")?;
1133                let (client_stream, servername): (Box<dyn ClientStream>, Option<String>) =
1134                    match tls_context {
1135                        Some(tls_context) => {
1136                            let mut ssl_stream =
1137                                SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
1138                            if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1139                                let _ = ssl_stream.get_mut().shutdown().await;
1140                                return Err(e.into());
1141                            }
1142                            let servername: Option<String> =
1143                                ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1144                                    match sn.split_once('.') {
1145                                        Some((left, _right)) => left,
1146                                        None => sn,
1147                                    }
1148                                    .into()
1149                                });
1150                            debug!("servername: {servername:?}");
1151                            (Box::new(ssl_stream), servername)
1152                        }
1153                        _ => (Box::new(conn), None),
1154                    };
1155                let resolved =
1156                    Self::resolve(&resolver, &resolve_template, port, servername.as_deref())
1157                        .await?;
1158                let inner_active_guard = resolved
1159                    .tenant
1160                    .as_ref()
1161                    .map(|tenant| inner_metrics.tenant_connections(tenant));
1162
1163                let mut mz_stream = TcpStream::connect(resolved.addr).await?;
1164
1165                if inject_proxy_headers {
1166                    // Write the tcp proxy header
1167                    let addrs = ProxiedAddress::stream(peer_addr, resolved.addr);
1168                    let header = ProxyHeader::with_address(addrs);
1169                    let mut buf = [0u8; 1024];
1170                    let len = header.encode_to_slice_v2(&mut buf)?;
1171                    mz_stream.write_all(&buf[..len]).await?;
1172                }
1173
1174                let mut mz_stream = if internal_tls {
1175                    // do a TLS handshake
1176                    let mut builder =
1177                        SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
1178                    // environmentd doesn't yet have a cert we trust, so for now disable verification.
1179                    builder.set_verify(SslVerifyMode::NONE);
1180                    let mut ssl = builder
1181                        .build()
1182                        .configure()?
1183                        .into_ssl(&resolved.addr.to_string())?;
1184                    ssl.set_connect_state();
1185                    Conn::Ssl(SslStream::new(ssl, mz_stream)?)
1186                } else {
1187                    Conn::Unencrypted(mz_stream)
1188                };
1189
1190                let mut client_counter = CountingConn::new(client_stream);
1191
1192                // Now blindly shuffle bytes back and forth until closed.
1193                // TODO: Limit total memory use.
1194                // See corresponding comment in pgwire implementation about ignoring the error.
1195                let _ = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
1196                if let Some(tenant) = &resolved.tenant {
1197                    inner_metrics
1198                        .tenant_connections_tx(tenant)
1199                        .inc_by(u64::cast_from(client_counter.written));
1200                    inner_metrics
1201                        .tenant_connections_rx(tenant)
1202                        .inc_by(u64::cast_from(client_counter.read));
1203                }
1204                drop(inner_active_guard);
1205                Ok(())
1206            })
1207            .await;
1208            drop(active_guard);
1209            outer_metrics.connection_status(result.is_ok()).inc();
1210            if let Err(e) = result {
1211                debug!("connection error: {e}");
1212            }
1213            Ok(())
1214        })
1215    }
1216}
1217
/// A sendable, unpin, bidirectional byte stream. Used as a trait object so that differently
/// typed client connections can be handled through a single `Box<dyn ClientStream>`.
trait ClientStream: AsyncRead + AsyncWrite + Unpin + Send {}
// Blanket implementation: every type satisfying the bounds is a `ClientStream`.
impl<T: AsyncRead + AsyncWrite + Unpin + Send> ClientStream for T {}
1220
/// Strategy for determining the upstream address of a pgwire connection.
#[derive(Debug)]
pub enum Resolver {
    /// Always resolve to the given address; no authentication is performed.
    Static(String),
    /// Authenticate the user with Frontegg and derive the address from the resulting tenant id.
    Frontegg(FronteggResolver),
}
1226
1227impl Resolver {
1228    async fn resolve<A>(
1229        &self,
1230        conn: &mut FramedConn<A>,
1231        user: &str,
1232    ) -> Result<ResolvedAddr, anyhow::Error>
1233    where
1234        A: AsyncRead + AsyncWrite + Unpin,
1235    {
1236        match self {
1237            Resolver::Frontegg(FronteggResolver {
1238                auth,
1239                addr_template,
1240            }) => {
1241                conn.send(BackendMessage::AuthenticationCleartextPassword)
1242                    .await?;
1243                conn.flush().await?;
1244                let password = match conn.recv().await? {
1245                    Some(FrontendMessage::Password { password }) => password,
1246                    _ => anyhow::bail!("expected Password message"),
1247                };
1248
1249                let auth_response = auth.authenticate(user, &password).await;
1250                let auth_session = match auth_response {
1251                    Ok(auth_session) => auth_session,
1252                    Err(e) => {
1253                        warn!("pgwire connection failed authentication: {}", e);
1254                        // TODO: fix error codes.
1255                        anyhow::bail!("invalid password");
1256                    }
1257                };
1258
1259                let addr = addr_template.replace("{}", &auth_session.tenant_id().to_string());
1260                let addr = lookup(&addr).await?;
1261                Ok(ResolvedAddr {
1262                    addr,
1263                    password: Some(password),
1264                    tenant: Some(auth_session.tenant_id().to_string()),
1265                })
1266            }
1267            Resolver::Static(addr) => {
1268                let addr = lookup(addr).await?;
1269                Ok(ResolvedAddr {
1270                    addr,
1271                    password: None,
1272                    tenant: None,
1273                })
1274            }
1275        }
1276    }
1277}
1278
1279/// Returns the first IP address resolved from the provided hostname.
1280async fn lookup(name: &str) -> Result<SocketAddr, anyhow::Error> {
1281    let mut addrs = tokio::net::lookup_host(name).await?;
1282    match addrs.next() {
1283        Some(addr) => Ok(addr),
1284        None => {
1285            error!("{name} did not resolve to any addresses");
1286            anyhow::bail!("internal error")
1287        }
1288    }
1289}
1290
/// Configuration for resolving upstream addresses via Frontegg authentication.
#[derive(Debug)]
pub struct FronteggResolver {
    /// Authenticator used to validate a user's password and obtain their tenant id.
    pub auth: FronteggAuthentication,
    /// Address template; `{}` is replaced with the authenticated tenant id before lookup.
    pub addr_template: String,
}
1296
/// The outcome of resolving where a connection should be proxied to.
#[derive(Debug)]
struct ResolvedAddr {
    /// Socket address of the upstream to connect to.
    addr: SocketAddr,
    /// Password collected from the client during resolution, if one was requested.
    password: Option<String>,
    /// Tenant id associated with the connection, if it could be determined.
    tenant: Option<String>,
}
1303
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks tenant extraction from CNAMEs against a table of `(cname, expected tenant)` pairs.
    #[mz_ore::test]
    fn test_tenant() {
        let cases = [
            // Empty input.
            ("", None),
            // Canonical form.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Variously named parts.
            (
                "service.something-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.ssvvcc.cloister.faraway",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // No dashes in uuid.
            (
                "environmentd.environment-58cd23ffa4d74bd0ad85a6ff29cc86c3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // -1234 suffix.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-1234.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Uppercase.
            (
                "environmentd.environment-58CD23FF-A4D7-4BD0-AD85-A6FF29CC86C3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // No -number suffix.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3.svc.cluster.local",
                None,
            ),
            // No service name.
            (
                "environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
            // Invalid UUID.
            (
                "environmentd.environment-8cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
        ];
        for (name, expect) in cases {
            let cname = HttpsBalancer::extract_tenant_from_cname(name);
            let got = cname.as_deref();
            assert_eq!(got, expect, "{name} got {cname:?} expected {expect:?}");
        }
    }
}