mz_balancerd/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! The balancerd service is a horizontally scalable, stateless, multi-tenant ingress router for
11//! pgwire and HTTPS connections.
12//!
13//! It listens on pgwire and HTTPS ports. When a new pgwire connection starts, the requested user is
14//! authenticated with frontegg from which a tenant id is returned. From that a target internal
15//! hostname is resolved to an IP address, and the connection is proxied to that address which has a
16//! running environmentd's pgwire port. When a new HTTPS connection starts, its SNI hostname is used
17//! to generate an internal hostname that is resolved to an IP address, which is similarly proxied.
18
19mod codec;
20mod dyncfgs;
21
22use std::collections::BTreeMap;
23use std::net::SocketAddr;
24use std::path::PathBuf;
25use std::pin::Pin;
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use anyhow::Context;
31use axum::response::IntoResponse;
32use axum::{Router, routing};
33use bytes::BytesMut;
34use domain::base::{Name, Rtype};
35use domain::rdata::AllRecordData;
36use domain::resolv::StubResolver;
37use futures::TryFutureExt;
38use futures::stream::BoxStream;
39use hyper::StatusCode;
40use hyper_util::rt::TokioIo;
41use launchdarkly_server_sdk as ld;
42use mz_build_info::{BuildInfo, build_info};
43use mz_dyncfg::ConfigSet;
44use mz_frontegg_auth::Authenticator as FronteggAuthentication;
45use mz_ore::cast::CastFrom;
46use mz_ore::id_gen::conn_id_org_uuid;
47use mz_ore::metrics::{ComputedGauge, IntCounter, IntGauge, MetricsRegistry};
48use mz_ore::netio::AsyncReady;
49use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7};
50use mz_ore::task::{JoinSetExt, spawn};
51use mz_ore::tracing::TracingHandle;
52use mz_ore::{metric, netio};
53use mz_pgwire_common::{
54    ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ErrorResponse, FrontendMessage,
55    FrontendStartupMessage, MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, VERSION_3, decode_startup,
56};
57use mz_server_core::{
58    Connection, ConnectionStream, ListenerHandle, ReloadTrigger, ReloadingSslContext,
59    ReloadingTlsConfig, ServeConfig, ServeDyncfg, TlsCertConfig, TlsMode, listen,
60};
61use openssl::ssl::{NameType, Ssl, SslConnector, SslMethod, SslVerifyMode};
62use prometheus::{IntCounterVec, IntGaugeVec};
63use proxy_header::{ProxiedAddress, ProxyHeader};
64use semver::Version;
65use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
66use tokio::net::TcpStream;
67use tokio::sync::oneshot;
68use tokio::task::JoinSet;
69use tokio_metrics::TaskMetrics;
70use tokio_openssl::SslStream;
71use tokio_postgres::error::SqlState;
72use tower::Service;
73use tracing::{debug, error, warn};
74use uuid::Uuid;
75
76use crate::codec::{BackendMessage, FramedConn};
77use crate::dyncfgs::{
78    INJECT_PROXY_PROTOCOL_HEADER_HTTP, SIGTERM_CONNECTION_WAIT, SIGTERM_LISTEN_WAIT,
79    has_tracing_config_update, tracing_config,
80};
81
/// Balancer build information.
///
/// Passed to the LaunchDarkly dyncfg sync and printed in the startup banner.
pub const BUILD_INFO: BuildInfo = build_info!();
84
/// Configuration for a [`BalancerService`], assembled via [`BalancerConfig::new`].
pub struct BalancerConfig {
    /// Info about which version of the code is running.
    build_version: Version,
    /// Listen address for internal HTTP health and metrics server.
    internal_http_listen_addr: SocketAddr,
    /// Listen address for pgwire connections.
    pgwire_listen_addr: SocketAddr,
    /// Listen address for HTTPS connections.
    https_listen_addr: SocketAddr,
    /// DNS resolver for pgwire cancellation requests
    cancellation_resolver: CancellationResolver,
    /// DNS resolver.
    resolver: Resolver,
    /// Template for the upstream HTTPS address, in `<address template>:<port>` form. It is split
    /// at the first `:` when serving, and the port must parse as a `u16`.
    https_sni_addr_template: String,
    /// TLS certificate configuration for the pgwire and HTTPS listeners. `None` serves both
    /// without TLS.
    tls: Option<TlsCertConfig>,
    /// Whether to attempt TLS on connections from the balancer to upstream servers. For pgwire,
    /// an `SslRequest` is sent upstream before the startup message when this is set.
    internal_tls: bool,
    /// Registry the balancer's metrics are registered into; also served at `/metrics` on the
    /// internal HTTP listener.
    metrics_registry: MetricsRegistry,
    /// Trigger stream handed to `TlsCertConfig::reloading_context` when serving.
    reload_certs: BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>,
    /// LaunchDarkly SDK key. When set (and `config_sync_file_path` is not), dyncfgs are synced
    /// from LaunchDarkly.
    launchdarkly_sdk_key: Option<String>,
    /// Path of a file to sync dyncfgs from. When set (and `launchdarkly_sdk_key` is not), dyncfgs
    /// are synced from this file.
    config_sync_file_path: Option<PathBuf>,
    /// Passed through to the dyncfg sync routine.
    // NOTE(review): presumably bounds the initial sync — confirm against the sync crates.
    config_sync_timeout: Duration,
    /// Passed through to the dyncfg sync routine.
    // NOTE(review): presumably the period of the background re-sync, with `None` disabling
    // it — confirm against the sync crates.
    config_sync_loop_interval: Option<Duration>,
    /// Cloud provider reported in the LaunchDarkly context; `"unknown"` is reported when absent.
    cloud_provider: Option<String>,
    /// Cloud provider region reported in the LaunchDarkly context; `"unknown"` is reported when
    /// absent.
    cloud_provider_region: Option<String>,
    /// Handle that tracing configuration updates from the dyncfg sync are applied to.
    tracing_handle: TracingHandle,
    /// Pairs handed to `dyncfgs::set_defaults` before any sync runs.
    // NOTE(review): presumably (config name, value) pairs — confirm against `dyncfgs`.
    default_configs: Vec<(String, String)>,
}
112
113impl BalancerConfig {
114    pub fn new(
115        build_info: &BuildInfo,
116        internal_http_listen_addr: SocketAddr,
117        pgwire_listen_addr: SocketAddr,
118        https_listen_addr: SocketAddr,
119        cancellation_resolver: CancellationResolver,
120        resolver: Resolver,
121        https_sni_addr_template: String,
122        tls: Option<TlsCertConfig>,
123        internal_tls: bool,
124        metrics_registry: MetricsRegistry,
125        reload_certs: ReloadTrigger,
126        launchdarkly_sdk_key: Option<String>,
127        config_sync_file: Option<PathBuf>,
128        config_sync_timeout: Duration,
129        config_sync_loop_interval: Option<Duration>,
130        cloud_provider: Option<String>,
131        cloud_provider_region: Option<String>,
132        tracing_handle: TracingHandle,
133        default_configs: Vec<(String, String)>,
134    ) -> Self {
135        Self {
136            build_version: build_info.semver_version(),
137            internal_http_listen_addr,
138            pgwire_listen_addr,
139            https_listen_addr,
140            cancellation_resolver,
141            resolver,
142            https_sni_addr_template,
143            tls,
144            internal_tls,
145            metrics_registry,
146            reload_certs,
147            launchdarkly_sdk_key,
148            config_sync_file_path: config_sync_file,
149            config_sync_timeout,
150            config_sync_loop_interval,
151            cloud_provider,
152            cloud_provider_region,
153            tracing_handle,
154            default_configs,
155        }
156    }
157}
158
/// Prometheus monitoring metrics.
#[derive(Debug)]
pub struct BalancerMetrics {
    /// Computed gauge reporting the seconds elapsed since this struct was created. It is never
    /// read; the struct only holds on to it.
    _uptime: ComputedGauge,
}
164
165impl BalancerMetrics {
166    /// Returns a new [BalancerMetrics] instance connected to the registry in cfg.
167    pub fn new(cfg: &BalancerConfig) -> Self {
168        let start = Instant::now();
169        let uptime = cfg.metrics_registry.register_computed_gauge(
170            metric!(
171                name: "mz_balancer_metadata_seconds",
172                help: "server uptime, labels are build metadata",
173                const_labels: {
174                    "version" => cfg.build_version,
175                    "build_type" => if cfg!(release) { "release" } else { "debug" }
176                },
177            ),
178            move || start.elapsed().as_secs_f64(),
179        );
180        BalancerMetrics { _uptime: uptime }
181    }
182}
183
/// The balancer service: its configuration plus the three listeners created by
/// [`BalancerService::new`] and consumed by [`BalancerService::serve`].
pub struct BalancerService {
    /// Configuration the service was created with.
    cfg: BalancerConfig,
    /// Listener handle and incoming connection stream for pgwire.
    pub pgwire: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Listener handle and incoming connection stream for HTTPS.
    pub https: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Listener handle and incoming connection stream for the internal HTTP server.
    pub internal_http: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Held but never read.
    _metrics: BalancerMetrics,
    /// Dynamic configuration values, populated with defaults and by the configured sync source.
    configs: ConfigSet,
}
192
193impl BalancerService {
194    pub async fn new(cfg: BalancerConfig) -> Result<Self, anyhow::Error> {
195        let pgwire = listen(&cfg.pgwire_listen_addr).await?;
196        let https = listen(&cfg.https_listen_addr).await?;
197        let internal_http = listen(&cfg.internal_http_listen_addr).await?;
198        let metrics = BalancerMetrics::new(&cfg);
199        let mut configs = ConfigSet::default();
200        configs = dyncfgs::all_dyncfgs(configs);
201        dyncfgs::set_defaults(&configs, cfg.default_configs.clone())?;
202        let tracing_handle = cfg.tracing_handle.clone();
203        // Configure dyncfg sync
204        match (
205            cfg.launchdarkly_sdk_key.as_deref(),
206            cfg.config_sync_file_path.as_deref(),
207        ) {
208            (Some(key), None) => {
209                mz_dyncfg_launchdarkly::sync_launchdarkly_to_configset(
210                    configs.clone(),
211                    &BUILD_INFO,
212                    |builder| {
213                        let region = cfg
214                            .cloud_provider_region
215                            .clone()
216                            .unwrap_or_else(|| String::from("unknown"));
217                        if let Some(provider) = cfg.cloud_provider.clone() {
218                            builder.add_context(
219                                ld::ContextBuilder::new(format!(
220                                    "{}/{}/{}",
221                                    provider, region, cfg.build_version
222                                ))
223                                .kind("balancer")
224                                .set_string("provider", provider)
225                                .set_string("region", region)
226                                .set_string("version", cfg.build_version.to_string())
227                                .build()
228                                .map_err(|e| anyhow::anyhow!(e))?,
229                            );
230                        } else {
231                            builder.add_context(
232                                ld::ContextBuilder::new(format!(
233                                    "{}/{}/{}",
234                                    "unknown", region, cfg.build_version
235                                ))
236                                .anonymous(true) // exclude this user from the dashboard
237                                .kind("balancer")
238                                .set_string("provider", "unknown")
239                                .set_string("region", region)
240                                .set_string("version", cfg.build_version.to_string())
241                                .build()
242                                .map_err(|e| anyhow::anyhow!(e))?,
243                            );
244                        }
245                        Ok(())
246                    },
247                    Some(key),
248                    cfg.config_sync_timeout,
249                    cfg.config_sync_loop_interval,
250                    move |updates, configs| {
251                        if has_tracing_config_update(updates) {
252                            match tracing_config(configs) {
253                                Ok(parameters) => parameters.apply(&tracing_handle),
254                                Err(err) => warn!("unable to update tracing: {err}"),
255                            }
256                        }
257                    },
258                )
259                .await
260                .inspect_err(|e| warn!("LaunchDarkly sync error: {e}"))
261                .ok();
262            }
263            (None, Some(path)) => {
264                mz_dyncfg_file::sync_file_to_configset(
265                    configs.clone(),
266                    path,
267                    cfg.config_sync_timeout,
268                    cfg.config_sync_loop_interval,
269                    move |updates, configs| {
270                        if has_tracing_config_update(updates) {
271                            match tracing_config(configs) {
272                                Ok(parameters) => parameters.apply(&tracing_handle),
273                                Err(err) => warn!("unable to update tracing: {err}"),
274                            }
275                        }
276                    },
277                )
278                .await
279                // If there's an Error, log but continue anyway. If LD is down
280                // we have no way of fetching the previous value of the flag
281                // (unlike the adapter, but it has a durable catalog). The
282                // ConfigSet defaults have been chosen to be good enough if this
283                // is the case.
284                .inspect_err(|e| warn!("File config sync error: {e}"))
285                .ok();
286            }
287            (Some(_), Some(_)) => panic!(
288                "must provide either config_sync_file_path or launchdarkly_sdk_key for config syncing",
289            ),
290            (None, None) => {}
291        };
292        Ok(Self {
293            cfg,
294            pgwire,
295            https,
296            internal_http,
297            _metrics: metrics,
298            configs,
299        })
300    }
301
302    pub async fn serve(self) -> Result<(), anyhow::Error> {
303        let (pgwire_tls, https_tls) = match &self.cfg.tls {
304            Some(tls) => {
305                let context = tls.reloading_context(self.cfg.reload_certs)?;
306                (
307                    Some(ReloadingTlsConfig {
308                        context: context.clone(),
309                        mode: TlsMode::Require,
310                    }),
311                    Some(context),
312                )
313            }
314            None => (None, None),
315        };
316
317        let metrics = ServerMetricsConfig::register_into(&self.cfg.metrics_registry);
318
319        let mut set = JoinSet::new();
320        let mut server_handles = Vec::new();
321        let pgwire_addr = self.pgwire.0.local_addr();
322        let https_addr = self.https.0.local_addr();
323        let internal_http_addr = self.internal_http.0.local_addr();
324
325        {
326            let pgwire = PgwireBalancer {
327                resolver: Arc::new(self.cfg.resolver),
328                cancellation_resolver: Arc::new(self.cfg.cancellation_resolver),
329                tls: pgwire_tls,
330                internal_tls: self.cfg.internal_tls,
331                metrics: ServerMetrics::new(metrics.clone(), "pgwire"),
332                now: SYSTEM_TIME.clone(),
333            };
334            let (handle, stream) = self.pgwire;
335            server_handles.push(handle);
336            set.spawn_named(|| "pgwire_stream", {
337                let config_set = self.configs.clone();
338                async move {
339                    mz_server_core::serve(ServeConfig {
340                        server: pgwire,
341                        conns: stream,
342                        dyncfg: Some(ServeDyncfg {
343                            config_set,
344                            sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
345                        }),
346                    })
347                    .await;
348                    warn!("pgwire server exited");
349                }
350            });
351        }
352        {
353            let Some((addr, port)) = self.cfg.https_sni_addr_template.split_once(':') else {
354                panic!("expected port in https_addr_template");
355            };
356            let port: u16 = port.parse().expect("unexpected port");
357            let resolver = StubResolver::new();
358            let https = HttpsBalancer {
359                resolver: Arc::from(resolver),
360                tls: https_tls,
361                resolve_template: Arc::from(addr),
362                port,
363                metrics: Arc::from(ServerMetrics::new(metrics, "https")),
364                configs: self.configs.clone(),
365                internal_tls: self.cfg.internal_tls,
366            };
367            let (handle, stream) = self.https;
368            server_handles.push(handle);
369            set.spawn_named(|| "https_stream", {
370                let config_set = self.configs.clone();
371                async move {
372                    mz_server_core::serve(ServeConfig {
373                        server: https,
374                        conns: stream,
375                        dyncfg: Some(ServeDyncfg {
376                            config_set,
377                            sigterm_wait_config: &SIGTERM_CONNECTION_WAIT,
378                        }),
379                    })
380                    .await;
381                    warn!("https server exited");
382                }
383            });
384        }
385        {
386            let router = Router::new()
387                .route(
388                    "/metrics",
389                    routing::get(move || async move {
390                        mz_http_util::handle_prometheus(&self.cfg.metrics_registry).await
391                    }),
392                )
393                .route(
394                    "/api/livez",
395                    routing::get(mz_http_util::handle_liveness_check),
396                )
397                .route("/api/readyz", routing::get(handle_readiness_check));
398            let internal_http = InternalHttpServer { router };
399            let (handle, stream) = self.internal_http;
400            server_handles.push(handle);
401            set.spawn_named(|| "internal_http_stream", async move {
402                mz_server_core::serve(ServeConfig {
403                    server: internal_http,
404                    conns: stream,
405                    // Disable graceful termination because our internal
406                    // monitoring keeps persistent HTTP connections open.
407                    dyncfg: None,
408                })
409                .await;
410                warn!("internal_http server exited");
411            });
412        }
413        #[cfg(unix)]
414        {
415            let mut sigterm =
416                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
417            set.spawn_named(|| "sigterm_handler", async move {
418                sigterm.recv().await;
419                let wait = SIGTERM_LISTEN_WAIT.get(&self.configs);
420                warn!("received signal TERM - delaying for {:?}!", wait);
421                tokio::time::sleep(wait).await;
422                warn!("sigterm delay complete, dropping server handles");
423                drop(server_handles);
424            });
425        }
426
427        println!("balancerd {} listening...", BUILD_INFO.human_version(None));
428        println!(" TLS enabled: {}", self.cfg.tls.is_some());
429        println!(" pgwire address: {}", pgwire_addr);
430        println!(" HTTPS address: {}", https_addr);
431        println!(" internal HTTP address: {}", internal_http_addr);
432
433        // Wait for all tasks to exit, which can happen on SIGTERM.
434        while let Some(res) = set.join_next().await {
435            if let Err(err) = res {
436                error!("serving task failed: {err}")
437            }
438        }
439        Ok(())
440    }
441}
442
443#[allow(clippy::unused_async)]
444async fn handle_readiness_check() -> impl IntoResponse {
445    (StatusCode::OK, "ready")
446}
447
/// Server for the internal HTTP listener.
struct InternalHttpServer {
    /// Router that every incoming request is dispatched to; cloned per connection.
    router: Router,
}
451
452impl mz_server_core::Server for InternalHttpServer {
453    const NAME: &'static str = "internal_http";
454
455    // TODO(jkosh44) consider forwarding the connection UUID to the adapter.
456    fn handle_connection(
457        &self,
458        conn: Connection,
459        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
460    ) -> mz_server_core::ConnectionHandler {
461        let router = self.router.clone();
462        let service = hyper::service::service_fn(move |req| router.clone().call(req));
463        let conn = TokioIo::new(conn);
464
465        Box::pin(async {
466            let http = hyper::server::conn::http1::Builder::new();
467            http.serve_connection(conn, service).err_into().await
468        })
469    }
470}
471
/// Wraps an IntGauge, automatically `inc`ing it on init and `dec`ing it on drop. Callers should
/// not call `inc()` or `dec()` themselves. Useful for handling multiple task exit points, for
/// example in the case of a panic.
struct GaugeGuard {
    /// The gauge that is decremented when the guard is dropped.
    gauge: IntGauge,
}
477
478impl From<IntGauge> for GaugeGuard {
479    fn from(gauge: IntGauge) -> Self {
480        let _self = Self { gauge };
481        _self.gauge.inc();
482        _self
483    }
484}
485
impl Drop for GaugeGuard {
    /// Decrements the wrapped gauge, undoing the increment made when the guard was created.
    fn drop(&mut self) {
        self.gauge.dec();
    }
}
491
/// The balancer's connection metric vectors, registered once and shared by the per-source
/// [`ServerMetrics`] views.
#[derive(Clone, Debug)]
struct ServerMetricsConfig {
    /// Count of completed network connections, labelled by `source` and `status`.
    connection_status: IntCounterVec,
    /// Count of currently open network connections, labelled by `source`.
    active_connections: IntGaugeVec,
    /// Count of opened network connections, labelled by `source` and `tenant`.
    tenant_connections: IntGaugeVec,
    /// Bytes received from a client, labelled by `source` and `tenant`.
    tenant_connection_rx: IntCounterVec,
    /// Bytes sent to a client, labelled by `source` and `tenant`.
    tenant_connection_tx: IntCounterVec,
    /// Count of pgwire connections with and without SNI, labelled by `tenant` and `has_sni`.
    tenant_pgwire_sni_count: IntCounterVec,
}
501
502impl ServerMetricsConfig {
503    fn register_into(registry: &MetricsRegistry) -> Self {
504        let connection_status = registry.register(metric!(
505            name: "mz_balancer_connection_status",
506            help: "Count of completed network connections, by status",
507            var_labels: ["source", "status"],
508        ));
509        let active_connections = registry.register(metric!(
510            name: "mz_balancer_connection_active",
511            help: "Count of currently open network connections.",
512            var_labels: ["source"],
513        ));
514        let tenant_connections = registry.register(metric!(
515            name: "mz_balancer_tenant_connection_active",
516            help: "Count of opened network connections by tenant.",
517            var_labels: ["source",  "tenant"]
518        ));
519        let tenant_connection_rx = registry.register(metric!(
520            name: "mz_balancer_tenant_connection_rx",
521            help: "Number of bytes received from a client for a tenant.",
522            var_labels: ["source", "tenant"],
523        ));
524        let tenant_connection_tx = registry.register(metric!(
525            name: "mz_balancer_tenant_connection_tx",
526            help: "Number of bytes sent to a client for a tenant.",
527            var_labels: ["source", "tenant"],
528        ));
529        let tenant_pgwire_sni_count = registry.register(metric!(
530            name: "mz_balancer_tenant_pgwire_sni_count",
531            help: "Count of pgwire connections that have and do not have SNI available per tenant.",
532            var_labels: ["tenant", "has_sni"],
533        ));
534        Self {
535            connection_status,
536            active_connections,
537            tenant_connections,
538            tenant_connection_rx,
539            tenant_connection_tx,
540            tenant_pgwire_sni_count,
541        }
542    }
543}
544
/// A view of the shared [`ServerMetricsConfig`] bound to one `source` label value.
#[derive(Clone, Debug)]
struct ServerMetrics {
    /// The shared metric vectors.
    inner: ServerMetricsConfig,
    /// Value used for the `source` label on the metrics that carry one.
    source: &'static str,
}
550
551impl ServerMetrics {
552    fn new(inner: ServerMetricsConfig, source: &'static str) -> Self {
553        let self_ = Self { inner, source };
554
555        // Pre-initialize labels we are planning to use to ensure they are all always emitted as
556        // time series.
557        self_.connection_status(false);
558        self_.connection_status(true);
559        drop(self_.active_connections());
560
561        self_
562    }
563
564    fn connection_status(&self, is_ok: bool) -> IntCounter {
565        self.inner
566            .connection_status
567            .with_label_values(&[self.source, Self::status_label(is_ok)])
568    }
569
570    fn active_connections(&self) -> GaugeGuard {
571        self.inner
572            .active_connections
573            .with_label_values(&[self.source])
574            .into()
575    }
576
577    fn tenant_connections(&self, tenant: &str) -> GaugeGuard {
578        self.inner
579            .tenant_connections
580            .with_label_values(&[self.source, tenant])
581            .into()
582    }
583
584    fn tenant_connections_rx(&self, tenant: &str) -> IntCounter {
585        self.inner
586            .tenant_connection_rx
587            .with_label_values(&[self.source, tenant])
588    }
589
590    fn tenant_connections_tx(&self, tenant: &str) -> IntCounter {
591        self.inner
592            .tenant_connection_tx
593            .with_label_values(&[self.source, tenant])
594    }
595
596    fn tenant_pgwire_sni_count(&self, tenant: &str, has_sni: bool) -> IntCounter {
597        self.inner
598            .tenant_pgwire_sni_count
599            .with_label_values(&[tenant, &has_sni.to_string()])
600    }
601
602    fn status_label(is_ok: bool) -> &'static str {
603        if is_ok { "success" } else { "error" }
604    }
605}
606
/// How the balancer resolves the destination of pgwire cancellation requests.
pub enum CancellationResolver {
    // NOTE(review): presumably a directory consulted to locate the target — confirm against
    // the code that consumes this enum.
    Directory(PathBuf),
    // NOTE(review): presumably one fixed address used for every request — confirm against
    // the code that consumes this enum.
    Static(String),
}
611
/// Server for the pgwire listener: proxies each client connection to an upstream server.
struct PgwireBalancer {
    /// TLS configuration for client connections; `None` when the balancer has no TLS config.
    tls: Option<ReloadingTlsConfig>,
    /// Whether to attempt TLS on the connection to the upstream server.
    internal_tls: bool,
    /// Resolver for pgwire cancellation requests.
    cancellation_resolver: Arc<CancellationResolver>,
    /// Resolver that maps a connecting user to an upstream address.
    resolver: Arc<Resolver>,
    /// Connection metrics for this server.
    metrics: ServerMetrics,
    /// Clock used to derive each connection's UUID.
    now: NowFn,
}
620
621impl PgwireBalancer {
622    #[mz_ore::instrument(level = "debug")]
623    async fn run<'a, A>(
624        conn: &'a mut FramedConn<A>,
625        version: i32,
626        params: BTreeMap<String, String>,
627        resolver: &Resolver,
628        tls_mode: Option<TlsMode>,
629        internal_tls: bool,
630        metrics: &ServerMetrics,
631    ) -> Result<(), io::Error>
632    where
633        A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
634    {
635        if version != VERSION_3 {
636            return conn
637                .send(ErrorResponse::fatal(
638                    SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
639                    "server does not support the client's requested protocol version",
640                ))
641                .await;
642        }
643
644        let Some(user) = params.get("user") else {
645            return conn
646                .send(ErrorResponse::fatal(
647                    SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
648                    "user parameter required",
649                ))
650                .await;
651        };
652
653        if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
654            return conn.send(err).await;
655        }
656
657        let resolved = match resolver.resolve(conn, user, metrics).await {
658            Ok(v) => v,
659            Err(err) => {
660                return conn
661                    .send(ErrorResponse::fatal(
662                        SqlState::INVALID_PASSWORD,
663                        err.to_string(),
664                    ))
665                    .await;
666            }
667        };
668
669        let _active_guard = resolved
670            .tenant
671            .as_ref()
672            .map(|tenant| metrics.tenant_connections(tenant));
673        let Ok(mut mz_stream) =
674            Self::init_stream(conn, resolved.addr, resolved.password, params, internal_tls).await
675        else {
676            return Ok(());
677        };
678
679        let mut client_counter = CountingConn::new(conn.inner_mut());
680
681        // Now blindly shuffle bytes back and forth until closed.
682        // TODO: Limit total memory use.
683        let res = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
684        if let Some(tenant) = &resolved.tenant {
685            metrics
686                .tenant_connections_tx(tenant)
687                .inc_by(u64::cast_from(client_counter.written));
688            metrics
689                .tenant_connections_rx(tenant)
690                .inc_by(u64::cast_from(client_counter.read));
691        }
692        res?;
693
694        Ok(())
695    }
696
697    #[mz_ore::instrument(level = "debug")]
698    async fn init_stream<'a, A>(
699        conn: &'a mut FramedConn<A>,
700        envd_addr: SocketAddr,
701        password: Option<String>,
702        params: BTreeMap<String, String>,
703        internal_tls: bool,
704    ) -> Result<Conn<TcpStream>, anyhow::Error>
705    where
706        A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
707    {
708        let mut mz_stream = TcpStream::connect(envd_addr).await?;
709        let mut buf = BytesMut::new();
710
711        let mut mz_stream = if internal_tls {
712            FrontendStartupMessage::SslRequest.encode(&mut buf)?;
713            mz_stream.write_all(&buf).await?;
714            buf.clear();
715            let mut maybe_ssl_request_response = [0u8; 1];
716            let nread =
717                netio::read_exact_or_eof(&mut mz_stream, &mut maybe_ssl_request_response).await?;
718            if nread == 1 && maybe_ssl_request_response == [ACCEPT_SSL_ENCRYPTION] {
719                // do a TLS handshake
720                let mut builder =
721                    SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
722                // environmentd doesn't yet have a cert we trust, so for now disable verification.
723                builder.set_verify(SslVerifyMode::NONE);
724                let mut ssl = builder
725                    .build()
726                    .configure()?
727                    .into_ssl(&envd_addr.to_string())?;
728                ssl.set_connect_state();
729                Conn::Ssl(SslStream::new(ssl, mz_stream)?)
730            } else {
731                Conn::Unencrypted(mz_stream)
732            }
733        } else {
734            Conn::Unencrypted(mz_stream)
735        };
736
737        // Send initial startup and password messages.
738        let startup = FrontendStartupMessage::Startup {
739            version: VERSION_3,
740            params,
741        };
742        startup.encode(&mut buf)?;
743        mz_stream.write_all(&buf).await?;
744        let client_stream = conn.inner_mut();
745
746        // This early return is important in self managed with SASL mode.
747        // The below code specifically looks for cleartext password requests, but in SASL mode
748        // the server will send a different message type (SASLInitialResponse) that we should
749        // not try to interpret or respond to.
750        // "Why not? That code looks like it should fall back fine?" You may ask.
751        // The below block unconditionally reads 9 bytes from the server. If we don't have
752        // a password or the message isn't a cleartext password request, we forward those 9 bytes
753        // to the client. Then we return the stream to the caller, who will continue shuffling bytes.
754        // The problem is that with TLS enabled between balancerd <-> client, flushing the first 9 bytes
755        // before copying bidirectionally will have the side effect of splitting the auth handshake into
756        // two SSL records. Pgbouncer misbehaves in this scenario, and fails the connection.
757        // PGbouncer shouldn't do this! It's a common footgun of protocols over TLS.
758        // So common in fact that PGbouncer already hit and fixed this issue on the bouncer <-> client side:
759        // once before: https://github.com/pgbouncer/pgbouncer/pull/1058.
760        // We will work to upstream a fix, but in the meantime, this early return avoids the issue entirely.
761        if password.is_none() {
762            return Ok(mz_stream);
763        }
764
765        // Read a single backend message, which may be a password request. Send ours if so.
766        // Otherwise start shuffling bytes. message type (len 1, 'R') + message len (len 4, 8_i32) +
767        // auth type (len 4, 3_i32).
768        let mut maybe_auth_frame = [0; 1 + 4 + 4];
769        let nread = netio::read_exact_or_eof(&mut mz_stream, &mut maybe_auth_frame).await?;
770        // 'R' for auth message, 0008 for message length, 0003 for password cleartext variant.
771        // See: https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD
772        const AUTH_PASSWORD_CLEARTEXT: [u8; 9] = [b'R', 0, 0, 0, 8, 0, 0, 0, 3];
773        if nread == AUTH_PASSWORD_CLEARTEXT.len()
774            && maybe_auth_frame == AUTH_PASSWORD_CLEARTEXT
775            && password.is_some()
776        {
777            // If we got exactly a cleartext password request and have one, send it.
778            let Some(password) = password else {
779                unreachable!("verified some above");
780            };
781            let password = FrontendMessage::Password { password };
782            buf.clear();
783            password.encode(&mut buf)?;
784            mz_stream.write_all(&buf).await?;
785            mz_stream.flush().await?;
786        } else {
787            // Otherwise pass on the bytes we just got. This *might* even be a password request, but
788            // we don't have a password. In which case it can be forwarded up to the client.
789            client_stream.write_all(&maybe_auth_frame[0..nread]).await?;
790        }
791
792        Ok(mz_stream)
793    }
794}
795
796impl mz_server_core::Server for PgwireBalancer {
797    const NAME: &'static str = "pgwire_balancer";
798
799    fn handle_connection(
800        &self,
801        conn: Connection,
802        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
803    ) -> mz_server_core::ConnectionHandler {
804        let tls = self.tls.clone();
805        let internal_tls = self.internal_tls;
806        let resolver = Arc::clone(&self.resolver);
807        let inner_metrics = self.metrics.clone();
808        let outer_metrics = self.metrics.clone();
809        let cancellation_resolver = Arc::clone(&self.cancellation_resolver);
810        let conn_uuid = epoch_to_uuid_v7(&(self.now)());
811        let peer_addr = conn.peer_addr();
812        conn.uuid_handle().set(conn_uuid);
813        Box::pin(async move {
814            // TODO: Try to merge this with pgwire/server.rs to avoid the duplication. May not be
815            // worth it.
816            let active_guard = outer_metrics.active_connections();
817            let result: Result<(), anyhow::Error> = async move {
818                let mut conn = Conn::Unencrypted(conn);
819                loop {
820                    let message = decode_startup(&mut conn).await?;
821                    conn = match message {
822                        // Clients sometimes hang up during the startup sequence, e.g.
823                        // because they receive an unacceptable response to an
824                        // `SslRequest`. This is considered a graceful termination.
825                        None => return Ok(()),
826
827                        Some(FrontendStartupMessage::Startup {
828                            version,
829                            mut params,
830                        }) => {
831                            let mut conn = FramedConn::new(conn);
832                            let peer_addr = match peer_addr {
833                                Ok(addr) => addr.ip(),
834                                Err(e) => {
835                                    error!("Invalid peer_addr {:?}", e);
836                                    return Ok(conn
837                                        .send(ErrorResponse::fatal(
838                                            SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
839                                            "invalid peer address",
840                                        ))
841                                        .await?);
842                                }
843                            };
844                            debug!(%conn_uuid, %peer_addr,  "starting new pgwire connection in balancer");
845                            let prev =
846                                params.insert(CONN_UUID_KEY.to_string(), conn_uuid.to_string());
847                            if prev.is_some() {
848                                return Ok(conn
849                                    .send(ErrorResponse::fatal(
850                                        SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
851                                        format!("invalid parameter '{CONN_UUID_KEY}'"),
852                                    ))
853                                    .await?);
854                            }
855
856                            if let Some(_) = params.insert(MZ_FORWARDED_FOR_KEY.to_string(), peer_addr.to_string().clone()) {
857                                return Ok(conn
858                                    .send(ErrorResponse::fatal(
859                                        SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
860                                        format!("invalid parameter '{MZ_FORWARDED_FOR_KEY}'"),
861                                    ))
862                                    .await?);
863                            };
864
865                            Self::run(
866                                &mut conn,
867                                version,
868                                params,
869                                &resolver,
870                                tls.map(|tls| tls.mode),
871                                internal_tls,
872                                &inner_metrics,
873                            )
874                            .await?;
875                            conn.flush().await?;
876                            return Ok(());
877                        }
878
879                        Some(FrontendStartupMessage::CancelRequest {
880                            conn_id,
881                            secret_key,
882                        }) => {
883                            spawn(|| "cancel request", async move {
884                                cancel_request(conn_id, secret_key, &cancellation_resolver).await;
885                            });
886                            // Do not wait on cancel requests to return because cancellation is best
887                            // effort.
888                            return Ok(());
889                        }
890
891                        Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
892                            (Conn::Unencrypted(mut conn), Some(tls)) => {
893                                conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
894                                let mut ssl_stream =
895                                    SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
896                                if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
897                                    let _ = ssl_stream.get_mut().shutdown().await;
898                                    return Err(e.into());
899                                }
900                                Conn::Ssl(ssl_stream)
901                            }
902                            (mut conn, _) => {
903                                conn.write_all(&[REJECT_ENCRYPTION]).await?;
904                                conn
905                            }
906                        },
907
908                        Some(FrontendStartupMessage::GssEncRequest) => {
909                            conn.write_all(&[REJECT_ENCRYPTION]).await?;
910                            conn
911                        }
912                    }
913                }
914            }
915            .await;
916            drop(active_guard);
917            outer_metrics.connection_status(result.is_ok()).inc();
918            Ok(())
919        })
920    }
921}
922
/// Wraps a connection and tallies the bytes exchanged through it.
struct CountingConn<C> {
    /// The wrapped connection.
    inner: C,
    /// Running total of bytes read.
    read: usize,
    /// Running total of bytes written.
    written: usize,
}

impl<C> CountingConn<C> {
    /// Wraps `inner` with both byte counters starting at zero.
    fn new(inner: C) -> Self {
        let (read, written) = (0, 0);
        Self {
            inner,
            read,
            written,
        }
    }
}
939
940impl<C> AsyncRead for CountingConn<C>
941where
942    C: AsyncRead + Unpin,
943{
944    fn poll_read(
945        self: Pin<&mut Self>,
946        cx: &mut std::task::Context<'_>,
947        buf: &mut io::ReadBuf<'_>,
948    ) -> std::task::Poll<std::io::Result<()>> {
949        let counter = self.get_mut();
950        let pin = Pin::new(&mut counter.inner);
951        let bytes = buf.filled().len();
952        let poll = pin.poll_read(cx, buf);
953        let bytes = buf.filled().len() - bytes;
954        if let std::task::Poll::Ready(Ok(())) = poll {
955            counter.read += bytes
956        }
957        poll
958    }
959}
960
961impl<C> AsyncWrite for CountingConn<C>
962where
963    C: AsyncWrite + Unpin,
964{
965    fn poll_write(
966        self: Pin<&mut Self>,
967        cx: &mut std::task::Context<'_>,
968        buf: &[u8],
969    ) -> std::task::Poll<Result<usize, std::io::Error>> {
970        let counter = self.get_mut();
971        let pin = Pin::new(&mut counter.inner);
972        let poll = pin.poll_write(cx, buf);
973        if let std::task::Poll::Ready(Ok(bytes)) = poll {
974            counter.written += bytes
975        }
976        poll
977    }
978
979    fn poll_flush(
980        self: Pin<&mut Self>,
981        cx: &mut std::task::Context<'_>,
982    ) -> std::task::Poll<Result<(), std::io::Error>> {
983        let counter = self.get_mut();
984        let pin = Pin::new(&mut counter.inner);
985        pin.poll_flush(cx)
986    }
987
988    fn poll_shutdown(
989        self: Pin<&mut Self>,
990        cx: &mut std::task::Context<'_>,
991    ) -> std::task::Poll<Result<(), std::io::Error>> {
992        let counter = self.get_mut();
993        let pin = Pin::new(&mut counter.inner);
994        pin.poll_shutdown(cx)
995    }
996}
997
998/// Broadcasts cancellation to all matching environmentds. `conn_id`'s bits [31..20] are the lower
999/// 12 bits of a UUID for an environmentd/organization. Using that and the template in
1000/// `cancellation_resolver` we generate a hostname. That hostname resolves to all IPs of envds that
1001/// match the UUID (cloud k8s infrastructure maintains that mapping). This function creates a new
1002/// task for each envd and relays the cancellation message to it, broadcasting it to any envd that
1003/// might match the connection.
1004///
1005/// This function returns after it has spawned the tasks, and does not wait for them to complete.
1006/// This is acceptable because cancellation in the Postgres protocol is best effort and has no
1007/// guarantees.
1008///
1009/// The safety of broadcasting this is due to the various randomness in the connection id and secret
1010/// key, which must match exactly in order to execute a query cancellation. The connection id has 19
1011/// bits of randomness, and the secret key the full 32, for a total of 51 bits. That is more than
1012/// 2e15 combinations, enough to nearly certainly prevent two different envds generating identical
1013/// combinations.
1014async fn cancel_request(
1015    conn_id: u32,
1016    secret_key: u32,
1017    cancellation_resolver: &CancellationResolver,
1018) {
1019    let suffix = conn_id_org_uuid(conn_id);
1020    let contents = match cancellation_resolver {
1021        CancellationResolver::Directory(dir) => {
1022            let path = dir.join(&suffix);
1023            match std::fs::read_to_string(&path) {
1024                Ok(contents) => contents,
1025                Err(err) => {
1026                    error!("could not read cancel file {path:?}: {err}");
1027                    return;
1028                }
1029            }
1030        }
1031        CancellationResolver::Static(addr) => addr.to_owned(),
1032    };
1033    let mut all_ips = Vec::new();
1034    for addr in contents.lines() {
1035        let addr = addr.trim();
1036        if addr.is_empty() {
1037            continue;
1038        }
1039        match tokio::net::lookup_host(addr).await {
1040            Ok(ips) => all_ips.extend(ips),
1041            Err(err) => {
1042                error!("{addr} failed resolution: {err}");
1043            }
1044        }
1045    }
1046    let mut buf = BytesMut::with_capacity(16);
1047    let msg = FrontendStartupMessage::CancelRequest {
1048        conn_id,
1049        secret_key,
1050    };
1051    msg.encode(&mut buf).expect("must encode");
1052    let buf = buf.freeze();
1053    for ip in all_ips {
1054        debug!("cancelling {suffix} to {ip}");
1055        let buf = buf.clone();
1056        spawn(|| "cancel request for ip", async move {
1057            let send = async {
1058                let mut stream = TcpStream::connect(&ip).await?;
1059                stream.write_all(&buf).await?;
1060                stream.shutdown().await?;
1061                Ok::<_, io::Error>(())
1062            };
1063            if let Err(err) = send.await {
1064                error!("error mirroring cancel to {ip}: {err}");
1065            }
1066        });
1067    }
1068}
1069
/// State shared by every connection handled by the HTTPS balancer.
struct HttpsBalancer {
    /// DNS resolver used for the CNAME query that yields a connection's tenant.
    resolver: Arc<StubResolver>,
    /// TLS context for client connections. When `None`, the raw connection is proxied and no
    /// SNI servername is available.
    tls: Option<ReloadingSslContext>,
    /// Hostname template; `{}` is replaced with the SNI servername when one is present.
    resolve_template: Arc<str>,
    /// Port appended to the templated hostname before the address lookup.
    port: u16,
    /// Connection-level and per-tenant metrics.
    metrics: Arc<ServerMetrics>,
    /// Dynamic configuration, read per connection to decide whether a proxy protocol header is
    /// sent upstream.
    configs: ConfigSet,
    /// Whether the upstream connection is wrapped in TLS.
    internal_tls: bool,
}
1079
1080impl HttpsBalancer {
1081    async fn resolve(
1082        resolver: &StubResolver,
1083        resolve_template: &str,
1084        port: u16,
1085        servername: Option<&str>,
1086    ) -> Result<ResolvedAddr, anyhow::Error> {
1087        let addr = match &servername {
1088            Some(servername) => resolve_template.replace("{}", servername),
1089            None => resolve_template.to_string(),
1090        };
1091        debug!("https address: {addr}");
1092
1093        // When we lookup the address using SNI, we get a hostname (`3dl07g8zmj91pntk4eo9cfvwe` for
1094        // example), which you convert into a different form for looking up the environment address
1095        // `blncr-3dl07g8zmj91pntk4eo9cfvwe`. When you do a DNS lookup in kubernetes for
1096        // `blncr-3dl07g8zmj91pntk4eo9cfvwe`, you get a CNAME response pointing at environmentd
1097        // `environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local`. This
1098        // is of the form `<service>.<namespace>.svc.cluster.local`. That `<namespace>` is the same
1099        // as the environment name, and is based on the tenant ID. `environment-<tenant_id>-<index>`
1100        // We currently only support a single environment per tenant in a region, so `<index>` is
1101        // always 0. Do not rely on this ending in `-0` so in the future multiple envds are
1102        // supported.
1103
1104        // Attempt to get a tenant.
1105        let tenant = resolver.tenant(&addr).await;
1106
1107        // Now do the regular ip lookup, regardless of if there was a CNAME.
1108        let envd_addr = lookup(&format!("{addr}:{port}")).await?;
1109
1110        Ok(ResolvedAddr {
1111            addr: envd_addr,
1112            password: None,
1113            tenant,
1114        })
1115    }
1116}
1117
/// Extension trait that adds a tenant lookup to a DNS resolver.
trait StubResolverExt {
    /// Returns the tenant associated with the DNS address `addr`, or `None` when it cannot be
    /// determined.
    async fn tenant(&self, addr: &str) -> Option<String>;
}
1121
1122impl StubResolverExt for StubResolver {
1123    /// Finds the tenant of a DNS address. Errors or lack of cname resolution here are ok, because
1124    /// this is only used for metrics.
1125    async fn tenant(&self, addr: &str) -> Option<String> {
1126        let Ok(dname) = Name::<Vec<_>>::from_str(addr) else {
1127            return None;
1128        };
1129        debug!("resolving tenant for {:?}", addr);
1130        // Lookup the CNAME. If there's a CNAME, find the tenant.
1131        let lookup = self.query((dname, Rtype::CNAME)).await;
1132        if let Ok(lookup) = lookup {
1133            if let Ok(answer) = lookup.answer() {
1134                let res = answer.limit_to::<AllRecordData<_, _>>();
1135                for record in res {
1136                    let Ok(record) = record else {
1137                        continue;
1138                    };
1139                    if record.rtype() != Rtype::CNAME {
1140                        continue;
1141                    }
1142                    let cname = record.data();
1143                    let cname = cname.to_string();
1144                    debug!("cname: {cname}");
1145                    return extract_tenant_from_cname(&cname);
1146                }
1147            }
1148        }
1149        None
1150    }
1151}
1152
1153/// Extracts the tenant from a CNAME.
1154fn extract_tenant_from_cname(cname: &str) -> Option<String> {
1155    let mut parts = cname.split('.');
1156    let _service = parts.next();
1157    let Some(namespace) = parts.next() else {
1158        return None;
1159    };
1160    // Trim off the starting `environmentd-`.
1161    let Some((_, namespace)) = namespace.split_once('-') else {
1162        return None;
1163    };
1164    // Trim off the ending `-0` (or some other number).
1165    let Some((tenant, _)) = namespace.rsplit_once('-') else {
1166        return None;
1167    };
1168    // Convert to a Uuid so that this tenant matches the frontegg resolver exactly, because it
1169    // also uses Uuid::to_string.
1170    let Ok(tenant) = Uuid::parse_str(tenant) else {
1171        error!("cname tenant not a uuid: {tenant}");
1172        return None;
1173    };
1174    Some(tenant.to_string())
1175}
1176
1177impl mz_server_core::Server for HttpsBalancer {
1178    const NAME: &'static str = "https_balancer";
1179
1180    // TODO(jkosh44) consider forwarding the connection UUID to the adapter.
1181    fn handle_connection(
1182        &self,
1183        conn: Connection,
1184        _tokio_metrics_intervals: impl Iterator<Item = TaskMetrics> + Send + 'static,
1185    ) -> mz_server_core::ConnectionHandler {
1186        let tls_context = self.tls.clone();
1187        let internal_tls = self.internal_tls.clone();
1188        let resolver = Arc::clone(&self.resolver);
1189        let resolve_template = Arc::clone(&self.resolve_template);
1190        let port = self.port;
1191        let inner_metrics = Arc::clone(&self.metrics);
1192        let outer_metrics = Arc::clone(&self.metrics);
1193        let peer_addr = conn.peer_addr();
1194        let inject_proxy_headers = INJECT_PROXY_PROTOCOL_HEADER_HTTP.get(&self.configs);
1195        Box::pin(async move {
1196            let active_guard = inner_metrics.active_connections();
1197            let result: Result<_, anyhow::Error> = Box::pin(async move {
1198                let peer_addr = peer_addr.context("fetching peer addr")?;
1199                let (client_stream, servername): (Box<dyn ClientStream>, Option<String>) =
1200                    match tls_context {
1201                        Some(tls_context) => {
1202                            let mut ssl_stream =
1203                                SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
1204                            if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
1205                                let _ = ssl_stream.get_mut().shutdown().await;
1206                                return Err(e.into());
1207                            }
1208                            let servername: Option<String> =
1209                                ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1210                                    match sn.split_once('.') {
1211                                        Some((left, _right)) => left,
1212                                        None => sn,
1213                                    }
1214                                    .into()
1215                                });
1216                            debug!("Found sni servername: {servername:?} (https)");
1217                            (Box::new(ssl_stream), servername)
1218                        }
1219                        _ => (Box::new(conn), None),
1220                    };
1221                let resolved =
1222                    Self::resolve(&resolver, &resolve_template, port, servername.as_deref())
1223                        .await?;
1224                let inner_active_guard = resolved
1225                    .tenant
1226                    .as_ref()
1227                    .map(|tenant| inner_metrics.tenant_connections(tenant));
1228
1229                let mut mz_stream = TcpStream::connect(resolved.addr).await?;
1230
1231                if inject_proxy_headers {
1232                    // Write the tcp proxy header
1233                    let addrs = ProxiedAddress::stream(peer_addr, resolved.addr);
1234                    let header = ProxyHeader::with_address(addrs);
1235                    let mut buf = [0u8; 1024];
1236                    let len = header.encode_to_slice_v2(&mut buf)?;
1237                    mz_stream.write_all(&buf[..len]).await?;
1238                }
1239
1240                let mut mz_stream = if internal_tls {
1241                    // do a TLS handshake
1242                    let mut builder =
1243                        SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
1244                    // environmentd doesn't yet have a cert we trust, so for now disable verification.
1245                    builder.set_verify(SslVerifyMode::NONE);
1246                    let mut ssl = builder
1247                        .build()
1248                        .configure()?
1249                        .into_ssl(&resolved.addr.to_string())?;
1250                    ssl.set_connect_state();
1251                    Conn::Ssl(SslStream::new(ssl, mz_stream)?)
1252                } else {
1253                    Conn::Unencrypted(mz_stream)
1254                };
1255
1256                let mut client_counter = CountingConn::new(client_stream);
1257
1258                // Now blindly shuffle bytes back and forth until closed.
1259                // TODO: Limit total memory use.
1260                // See corresponding comment in pgwire implementation about ignoring the error.
1261                let _ = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
1262                if let Some(tenant) = &resolved.tenant {
1263                    inner_metrics
1264                        .tenant_connections_tx(tenant)
1265                        .inc_by(u64::cast_from(client_counter.written));
1266                    inner_metrics
1267                        .tenant_connections_rx(tenant)
1268                        .inc_by(u64::cast_from(client_counter.read));
1269                }
1270                drop(inner_active_guard);
1271                Ok(())
1272            })
1273            .await;
1274            drop(active_guard);
1275            outer_metrics.connection_status(result.is_ok()).inc();
1276            if let Err(e) = result {
1277                debug!("connection error: {e}");
1278            }
1279            Ok(())
1280        })
1281    }
1282}
1283
/// Settings for resolving a pgwire connection's upstream address from its TLS SNI servername.
#[derive(Debug)]
pub struct SniResolver {
    /// DNS resolver used for the CNAME query that yields the tenant.
    pub resolver: StubResolver,
    /// Hostname template; `{}` is replaced with the first label of the SNI servername.
    pub template: String,
    /// Port appended to the templated hostname before the address lookup.
    pub port: u16,
}
1290
/// Combines the I/O bounds required of a client connection into a single trait, so that TLS and
/// plaintext streams can both be stored as `Box<dyn ClientStream>`.
trait ClientStream: AsyncRead + AsyncWrite + Unpin + Send {}
// Blanket implementation: every type that satisfies the bounds is a `ClientStream`.
impl<T: AsyncRead + AsyncWrite + Unpin + Send> ClientStream for T {}
1293
/// Strategy for choosing the upstream address of a pgwire connection.
#[derive(Debug)]
pub enum Resolver {
    /// Always resolve the given fixed address.
    Static(String),
    /// Resolve per tenant: from the SNI servername when an `SniResolver` is configured and the
    /// client sent one, otherwise by authenticating the client's password with Frontegg.
    MultiTenant(FronteggResolver, Option<SniResolver>),
}
1299
1300impl Resolver {
1301    async fn resolve<A>(
1302        &self,
1303        conn: &mut FramedConn<A>,
1304        user: &str,
1305        metrics: &ServerMetrics,
1306    ) -> Result<ResolvedAddr, anyhow::Error>
1307    where
1308        A: AsyncRead + AsyncWrite + Unpin,
1309    {
1310        match self {
1311            Resolver::MultiTenant(
1312                FronteggResolver {
1313                    auth,
1314                    addr_template,
1315                },
1316                sni_resolver,
1317            ) => {
1318                let servername = match conn.inner() {
1319                    Conn::Ssl(ssl_stream) => {
1320                        ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
1321                            match sn.split_once('.') {
1322                                Some((left, _right)) => left,
1323                                None => sn,
1324                            }
1325                        })
1326                    }
1327                    Conn::Unencrypted(_) => None,
1328                };
1329                let has_sni = servername.is_some();
1330                // We found an SNi
1331                let resolved_addr = match (servername, sni_resolver) {
1332                    (
1333                        Some(servername),
1334                        Some(SniResolver {
1335                            resolver: stub_resolver,
1336                            template: sni_addr_template,
1337                            port,
1338                        }),
1339                    ) => {
1340                        let sni_addr = sni_addr_template.replace("{}", servername);
1341                        let tenant = stub_resolver.tenant(&sni_addr).await;
1342                        let sni_addr = format!("{sni_addr}:{port}");
1343                        let addr = lookup(&sni_addr).await?;
1344                        if tenant.is_some() {
1345                            debug!("SNI header found for tenant {:?}", tenant);
1346                        }
1347                        ResolvedAddr {
1348                            addr,
1349                            password: None,
1350                            tenant,
1351                        }
1352                    }
1353                    _ => {
1354                        conn.send(BackendMessage::AuthenticationCleartextPassword)
1355                            .await?;
1356                        conn.flush().await?;
1357                        let password = match conn.recv().await? {
1358                            Some(FrontendMessage::Password { password }) => password,
1359                            _ => anyhow::bail!("expected Password message"),
1360                        };
1361
1362                        let auth_response = auth.authenticate(user, &password).await;
1363                        let auth_session = match auth_response {
1364                            Ok(auth_session) => auth_session,
1365                            Err(e) => {
1366                                warn!("pgwire connection failed authentication: {}", e);
1367                                // TODO: fix error codes.
1368                                anyhow::bail!("invalid password");
1369                            }
1370                        };
1371
1372                        let addr =
1373                            addr_template.replace("{}", &auth_session.tenant_id().to_string());
1374                        let addr = lookup(&addr).await?;
1375                        let tenant = Some(auth_session.tenant_id().to_string());
1376                        if tenant.is_some() {
1377                            debug!("SNI header NOT found for tenant {:?}", tenant);
1378                        }
1379                        ResolvedAddr {
1380                            addr,
1381                            password: Some(password),
1382                            tenant,
1383                        }
1384                    }
1385                };
1386                metrics
1387                    .tenant_pgwire_sni_count(
1388                        resolved_addr.tenant.as_deref().unwrap_or("unknown"),
1389                        has_sni,
1390                    )
1391                    .inc();
1392
1393                Ok(resolved_addr)
1394            }
1395            Resolver::Static(addr) => {
1396                let addr = lookup(addr).await?;
1397                Ok(ResolvedAddr {
1398                    addr,
1399                    password: None,
1400                    tenant: None,
1401                })
1402            }
1403        }
1404    }
1405}
1406
1407/// Returns the first IP address resolved from the provided hostname.
1408async fn lookup(name: &str) -> Result<SocketAddr, anyhow::Error> {
1409    let mut addrs = tokio::net::lookup_host(name).await?;
1410    match addrs.next() {
1411        Some(addr) => Ok(addr),
1412        None => {
1413            error!("{name} did not resolve to any addresses");
1414            anyhow::bail!("internal error")
1415        }
1416    }
1417}
1418
/// Settings for resolving a pgwire connection's upstream address via Frontegg authentication.
#[derive(Debug)]
pub struct FronteggResolver {
    /// Authenticator used to validate the user and password supplied by the client.
    pub auth: FronteggAuthentication,
    /// Address template; `{}` is replaced with the authenticated tenant id and the result is
    /// passed to the address lookup as-is.
    pub addr_template: String,
}
1424
/// Outcome of resolving where a connection should be proxied.
#[derive(Debug)]
struct ResolvedAddr {
    /// Socket address of the upstream to connect to.
    addr: SocketAddr,
    /// Password collected from the client during Frontegg authentication; `None` on every other
    /// resolution path.
    password: Option<String>,
    /// Tenant the connection belongs to, when it could be determined.
    tenant: Option<String>,
}
1431
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `extract_tenant_from_cname` against well-formed and malformed CNAMEs.
    #[mz_ore::test]
    fn test_tenant() {
        // Canonical rendering of the tenant id used by every accepted case.
        const TENANT: Option<&str> = Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3");

        let cases: &[(&str, Option<&str>)] = &[
            ("", None),
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                TENANT,
            ),
            // Variously named parts.
            (
                "service.something-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.ssvvcc.cloister.faraway",
                TENANT,
            ),
            // No dashes in uuid.
            (
                "environmentd.environment-58cd23ffa4d74bd0ad85a6ff29cc86c3-0.svc.cluster.local",
                TENANT,
            ),
            // -1234 suffix.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-1234.svc.cluster.local",
                TENANT,
            ),
            // Uppercase.
            (
                "environmentd.environment-58CD23FF-A4D7-4BD0-AD85-A6FF29CC86C3-0.svc.cluster.local",
                TENANT,
            ),
            // No -number suffix.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3.svc.cluster.local",
                None,
            ),
            // No service name.
            (
                "environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
            // Invalid UUID.
            (
                "environmentd.environment-8cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
        ];

        for &(name, expect) in cases {
            let cname = extract_tenant_from_cname(name);
            assert_eq!(
                cname.as_deref(),
                expect,
                "{name} got {cname:?} expected {expect:?}"
            );
        }
    }
}