mod codec;
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use axum::response::IntoResponse;
use axum::{routing, Router};
use bytes::BytesMut;
use domain::base::{Dname, Rtype};
use domain::rdata::AllRecordData;
use domain::resolv::StubResolver;
use futures::stream::BoxStream;
use futures::TryFutureExt;
use hyper::StatusCode;
use mz_build_info::{build_info, BuildInfo};
use mz_frontegg_auth::Authenticator as FronteggAuthentication;
use mz_ore::cast::CastFrom;
use mz_ore::id_gen::conn_id_org_uuid;
use mz_ore::metrics::{ComputedGauge, IntCounter, IntGauge, MetricsRegistry};
use mz_ore::netio::AsyncReady;
use mz_ore::task::{spawn, JoinSetExt};
use mz_ore::{metric, netio};
use mz_pgwire_common::{
decode_startup, Conn, ErrorResponse, FrontendMessage, FrontendStartupMessage,
ACCEPT_SSL_ENCRYPTION, REJECT_ENCRYPTION, VERSION_3,
};
use mz_server_core::{
listen, ConnectionStream, ListenerHandle, ReloadTrigger, ReloadingSslContext,
ReloadingTlsConfig, TlsCertConfig, TlsMode,
};
use openssl::ssl::{NameType, Ssl};
use prometheus::{IntCounterVec, IntGaugeVec};
use semver::Version;
use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::task::JoinSet;
use tokio_openssl::SslStream;
use tokio_postgres::error::SqlState;
use tracing::{debug, error, warn};
use uuid::Uuid;
use crate::codec::{BackendMessage, FramedConn};
/// Build metadata for this crate, produced by the `build_info!` macro.
///
/// Used by [`BalancerService::serve`] to print the human-readable version at startup.
pub const BUILD_INFO: BuildInfo = build_info!();
/// Everything [`BalancerService`] needs in order to bind its listeners and serve traffic.
pub struct BalancerConfig {
/// Passed to the pgwire and HTTPS servers when they are started; the internal HTTP
/// server is always started with `None` instead.
sigterm_wait: Option<Duration>,
/// Semantic version of the build, exported as the `version` label of the uptime metric.
build_version: Version,
/// Address the internal HTTP listener (metrics, liveness, readiness) binds to.
internal_http_listen_addr: SocketAddr,
/// Address the pgwire listener binds to.
pgwire_listen_addr: SocketAddr,
/// Address the HTTPS listener binds to.
https_listen_addr: SocketAddr,
/// Directory holding the per-connection files read when a pgwire cancel request arrives.
/// When `None`, cancel requests are dropped.
cancellation_resolver_dir: Option<PathBuf>,
/// Strategy used to map a pgwire connection to an upstream address.
resolver: Resolver,
/// `host:port` template for HTTPS upstreams; `{}` in the host part is replaced with the
/// first label of the client's SNI name.
https_addr_template: String,
/// Certificate configuration. `None` disables TLS on both pgwire and HTTPS.
tls: Option<TlsCertConfig>,
/// Registry into which all balancer metrics are registered.
metrics_registry: MetricsRegistry,
/// Stream handed to `TlsCertConfig::reloading_context` when TLS is configured.
reload_certs: BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>,
}
impl BalancerConfig {
    /// Assembles a [`BalancerConfig`] from its individual parts.
    ///
    /// `build_info` is consulted only for its semantic version; every other argument is
    /// stored unchanged in the field of the same name.
    pub fn new(
        build_info: &BuildInfo,
        sigterm_wait: Option<Duration>,
        internal_http_listen_addr: SocketAddr,
        pgwire_listen_addr: SocketAddr,
        https_listen_addr: SocketAddr,
        cancellation_resolver_dir: Option<PathBuf>,
        resolver: Resolver,
        https_addr_template: String,
        tls: Option<TlsCertConfig>,
        metrics_registry: MetricsRegistry,
        reload_certs: ReloadTrigger,
    ) -> Self {
        let build_version = build_info.semver_version();
        Self {
            sigterm_wait,
            build_version,
            internal_http_listen_addr,
            pgwire_listen_addr,
            https_listen_addr,
            cancellation_resolver_dir,
            resolver,
            https_addr_template,
            tls,
            metrics_registry,
            reload_certs,
        }
    }
}
/// Balancer-wide metrics that are not tied to a particular listener.
#[derive(Debug)]
pub struct BalancerMetrics {
/// Handle to the registered uptime gauge. It is stored but never read in this file
/// (hence the underscore).
_uptime: ComputedGauge,
}
impl BalancerMetrics {
    /// Registers the `mz_balancer_metadata_seconds` gauge into the registry held by `cfg`.
    ///
    /// The gauge reports the number of seconds elapsed since this constructor ran and
    /// carries the build version and build type as constant labels.
    pub fn new(cfg: &BalancerConfig) -> Self {
        let start = Instant::now();
        // `release` is not a cfg name that rustc or cargo define, so testing it always
        // yields `false` and every build was labelled "debug". `debug_assertions` is the
        // compiler-provided cfg that distinguishes the default dev and release profiles.
        let uptime = cfg.metrics_registry.register_computed_gauge(
            metric!(
                name: "mz_balancer_metadata_seconds",
                help: "server uptime, labels are build metadata",
                const_labels: {
                    "version" => cfg.build_version,
                    "build_type" => if cfg!(debug_assertions) { "debug" } else { "release" }
                },
            ),
            move || start.elapsed().as_secs_f64(),
        );
        BalancerMetrics { _uptime: uptime }
    }
}
/// A balancer whose listeners are bound but not yet serving; call `serve` to run it.
pub struct BalancerService {
/// Configuration the service was created from.
cfg: BalancerConfig,
/// Handle and connection stream of the pgwire listener.
pub pgwire: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
/// Handle and connection stream of the HTTPS listener.
pub https: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
/// Handle and connection stream of the internal HTTP listener.
pub internal_http: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
/// Stored so the balancer-wide metrics are owned by the service; never read in this file.
_metrics: BalancerMetrics,
}
impl BalancerService {
/// Binds the pgwire, HTTPS and internal HTTP listeners (in that order) and registers the
/// balancer-wide metrics.
///
/// # Errors
///
/// Returns the error of the first listener that fails to bind.
pub async fn new(cfg: BalancerConfig) -> Result<Self, anyhow::Error> {
let pgwire = listen(&cfg.pgwire_listen_addr).await?;
let https = listen(&cfg.https_listen_addr).await?;
let internal_http = listen(&cfg.internal_http_listen_addr).await?;
let metrics = BalancerMetrics::new(&cfg);
Ok(Self {
cfg,
pgwire,
https,
internal_http,
_metrics: metrics,
})
}
/// Starts the pgwire, HTTPS and internal HTTP servers and waits until all spawned tasks
/// have finished.
///
/// # Errors
///
/// Fails if the TLS context cannot be built, if the configured cancellation directory is
/// not a directory, or (on unix) if the SIGTERM handler cannot be installed. Failures of
/// the serving tasks themselves are only logged.
///
/// # Panics
///
/// Panics if `https_addr_template` contains no `:` or if the text after the first `:`
/// does not parse as a `u16` port.
pub async fn serve(self) -> Result<(), anyhow::Error> {
// With TLS configured, a single reloading context is built and shared: pgwire receives
// it together with `TlsMode::Require`, HTTPS receives the bare context.
let (pgwire_tls, https_tls) = match &self.cfg.tls {
Some(tls) => {
let context = tls.reloading_context(self.cfg.reload_certs)?;
(
Some(ReloadingTlsConfig {
context: context.clone(),
mode: TlsMode::Require,
}),
Some(context),
)
}
None => (None, None),
};
let metrics = ServerMetricsConfig::register_into(&self.cfg.metrics_registry);
let mut set = JoinSet::new();
let mut server_handles = Vec::new();
// Capture the bound addresses now; the listener tuples are moved into the blocks below.
let pgwire_addr = self.pgwire.0.local_addr();
let https_addr = self.https.0.local_addr();
let internal_http_addr = self.internal_http.0.local_addr();
// pgwire server.
{
if let Some(dir) = &self.cfg.cancellation_resolver_dir {
if !dir.is_dir() {
anyhow::bail!("{dir:?} is not a directory");
}
}
let cancellation_resolver = self.cfg.cancellation_resolver_dir.map(Arc::new);
let pgwire = PgwireBalancer {
resolver: Arc::new(self.cfg.resolver),
cancellation_resolver,
tls: pgwire_tls,
metrics: ServerMetrics::new(metrics.clone(), "pgwire"),
};
let (handle, stream) = self.pgwire;
server_handles.push(handle);
set.spawn_named(|| "pgwire_stream", async move {
mz_server_core::serve(stream, pgwire, self.cfg.sigterm_wait).await;
warn!("pgwire server exited");
});
}
// HTTPS server.
{
// The template is split at the first ':' into a host template and a port.
let Some((addr, port)) = self.cfg.https_addr_template.split_once(':') else {
panic!("expected port in https_addr_template");
};
let port: u16 = port.parse().expect("unexpected port");
let resolver = StubResolver::new();
let https = HttpsBalancer {
resolver: Arc::from(resolver),
tls: https_tls,
resolve_template: Arc::from(addr),
port,
metrics: Arc::from(ServerMetrics::new(metrics, "https")),
};
let (handle, stream) = self.https;
server_handles.push(handle);
set.spawn_named(|| "https_stream", async move {
mz_server_core::serve(stream, https, self.cfg.sigterm_wait).await;
warn!("https server exited");
});
}
// Internal HTTP server: Prometheus metrics plus liveness and readiness probes.
{
let router = Router::new()
.route(
"/metrics",
routing::get(move || async move {
mz_http_util::handle_prometheus(&self.cfg.metrics_registry).await
}),
)
.route(
"/api/livez",
routing::get(mz_http_util::handle_liveness_check),
)
.route("/api/readyz", routing::get(handle_readiness_check));
let internal_http = InternalHttpServer { router };
let (handle, stream) = self.internal_http;
server_handles.push(handle);
// Unlike the two servers above, this one is started without a SIGTERM wait.
set.spawn_named(|| "internal_http_stream", async move {
mz_server_core::serve(stream, internal_http, None).await;
warn!("internal_http server exited");
});
}
#[cfg(unix)]
{
let mut sigterm =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
set.spawn_named(|| "sigterm_handler", async move {
sigterm.recv().await;
warn!("received signal TERM");
// NOTE(review): dropping the listener handles presumably stops the listeners
// from accepting new connections -- confirm against `mz_server_core`.
drop(server_handles);
});
}
println!("balancerd {} listening...", BUILD_INFO.human_version());
println!(" TLS enabled: {}", self.cfg.tls.is_some());
println!(" pgwire address: {}", pgwire_addr);
println!(" HTTPS address: {}", https_addr);
println!(" internal HTTP address: {}", internal_http_addr);
// Wait for every spawned task; a failed join is logged but not returned as an error.
while let Some(res) = set.join_next().await {
if let Err(err) = res {
error!("serving task failed: {err}")
}
}
Ok(())
}
}
/// Readiness probe handler: unconditionally answers `200 OK` with the body `ready`.
// The function never awaits; it is `async` because it is mounted as a route handler.
#[allow(clippy::unused_async)]
async fn handle_readiness_check() -> impl IntoResponse {
    let status = StatusCode::OK;
    (status, "ready")
}
/// Server for the internal HTTP listener; every connection is served by a clone of `router`.
struct InternalHttpServer {
/// Routes exposed on the internal HTTP port.
router: Router,
}
impl mz_server_core::Server for InternalHttpServer {
    const NAME: &'static str = "internal_http";

    /// Serves one accepted TCP connection over HTTP using a clone of the router.
    ///
    /// Any error produced while serving is converted into the handler's error type.
    fn handle_connection(&self, conn: TcpStream) -> mz_server_core::ConnectionHandler {
        let router = self.router.clone();
        Box::pin(async move {
            let http = hyper::server::conn::Http::new();
            let connection = http.serve_connection(conn, router);
            connection.err_into().await
        })
    }
}
/// Guard that keeps a gauge incremented for as long as the guard is alive: the gauge is
/// incremented when the guard is built via `From<IntGauge>` and decremented on drop.
struct GaugeGuard {
/// The gauge this guard will decrement when dropped.
gauge: IntGauge,
}
impl From<IntGauge> for GaugeGuard {
    /// Increments `gauge` and wraps it in a guard that owns it.
    fn from(gauge: IntGauge) -> Self {
        gauge.inc();
        Self { gauge }
    }
}
impl Drop for GaugeGuard {
/// Decrements the wrapped gauge, undoing the increment performed at construction.
fn drop(&mut self) {
self.gauge.dec();
}
}
/// The metric vectors shared by all balancer servers; see [`ServerMetrics`] for the
/// per-server view that fixes the `source` label.
#[derive(Clone, Debug)]
struct ServerMetricsConfig {
/// Completed connections, labelled by `source` and `status`.
connection_status: IntCounterVec,
/// Currently open connections, labelled by `source`.
active_connections: IntGaugeVec,
/// Open connections per tenant, labelled by `source` and `tenant`.
tenant_connections: IntGaugeVec,
/// Bytes received from clients per tenant, labelled by `source` and `tenant`.
tenant_connection_rx: IntCounterVec,
/// Bytes sent to clients per tenant, labelled by `source` and `tenant`.
tenant_connection_tx: IntCounterVec,
/// pgwire connections with and without SNI, labelled by `tenant` and `has_sni`.
tenant_pgwire_sni_count: IntCounterVec,
}
impl ServerMetricsConfig {
    /// Registers every balancer connection metric into `registry` and returns the handles.
    ///
    /// Fields are initialised in declaration order, so the metrics are registered in the
    /// order they are listed here.
    fn register_into(registry: &MetricsRegistry) -> Self {
        Self {
            connection_status: registry.register(metric!(
                name: "mz_balancer_connection_status",
                help: "Count of completed network connections, by status",
                var_labels: ["source", "status"],
            )),
            active_connections: registry.register(metric!(
                name: "mz_balancer_connection_active",
                help: "Count of currently open network connections.",
                var_labels: ["source"],
            )),
            tenant_connections: registry.register(metric!(
                name: "mz_balancer_tenant_connection_active",
                help: "Count of opened network connections by tenant.",
                var_labels: ["source", "tenant"]
            )),
            tenant_connection_rx: registry.register(metric!(
                name: "mz_balancer_tenant_connection_rx",
                help: "Number of bytes received from a client for a tenant.",
                var_labels: ["source", "tenant"],
            )),
            tenant_connection_tx: registry.register(metric!(
                name: "mz_balancer_tenant_connection_tx",
                help: "Number of bytes sent to a client for a tenant.",
                var_labels: ["source", "tenant"],
            )),
            tenant_pgwire_sni_count: registry.register(metric!(
                name: "mz_balancer_tenant_pgwire_sni_count",
                help: "Count of pgwire connections that have and do not have SNI available per tenant.",
                var_labels: ["tenant", "has_sni"],
            )),
        }
    }
}
/// View over [`ServerMetricsConfig`] that supplies a fixed `source` label value.
#[derive(Clone, Debug)]
struct ServerMetrics {
/// The shared metric vectors.
inner: ServerMetricsConfig,
/// Value used for the `source` label (`"pgwire"` or `"https"` in this file).
source: &'static str,
}
impl ServerMetrics {
fn new(inner: ServerMetricsConfig, source: &'static str) -> Self {
let self_ = Self { inner, source };
self_.connection_status(false);
self_.connection_status(true);
drop(self_.active_connections());
self_
}
fn connection_status(&self, is_ok: bool) -> IntCounter {
self.inner
.connection_status
.with_label_values(&[self.source, Self::status_label(is_ok)])
}
fn active_connections(&self) -> GaugeGuard {
self.inner
.active_connections
.with_label_values(&[self.source])
.into()
}
fn tenant_connections(&self, tenant: &str) -> GaugeGuard {
self.inner
.tenant_connections
.with_label_values(&[self.source, tenant])
.into()
}
fn tenant_connections_rx(&self, tenant: &str) -> IntCounter {
self.inner
.tenant_connection_rx
.with_label_values(&[self.source, tenant])
}
fn tenant_connections_tx(&self, tenant: &str) -> IntCounter {
self.inner
.tenant_connection_tx
.with_label_values(&[self.source, tenant])
}
fn tenant_pgwire_sni_count(&self, tenant: &str, has_sni: bool) -> IntCounter {
self.inner
.tenant_pgwire_sni_count
.with_label_values(&[tenant, &has_sni.to_string()])
}
fn status_label(is_ok: bool) -> &'static str {
if is_ok {
"success"
} else {
"error"
}
}
}
/// Server that proxies pgwire connections to the upstream chosen by a [`Resolver`].
struct PgwireBalancer {
/// TLS context and mode; `None` means SSL requests are rejected.
tls: Option<ReloadingTlsConfig>,
/// Directory consulted for cancel requests; `None` means cancel requests are dropped.
cancellation_resolver: Option<Arc<PathBuf>>,
/// Resolver mapping a connection's user to an upstream address.
resolver: Arc<Resolver>,
/// Metrics reported for this server.
metrics: ServerMetrics,
}
impl PgwireBalancer {
#[mz_ore::instrument(level = "debug")]
async fn run<'a, A>(
conn: &'a mut FramedConn<A>,
version: i32,
params: BTreeMap<String, String>,
resolver: &Resolver,
tls_mode: Option<TlsMode>,
metrics: &ServerMetrics,
) -> Result<(), io::Error>
where
A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
{
if version != VERSION_3 {
return conn
.send(ErrorResponse::fatal(
SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
"server does not support the client's requested protocol version",
))
.await;
}
let Some(user) = params.get("user") else {
return conn
.send(ErrorResponse::fatal(
SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
"user parameter required",
))
.await;
};
if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
return conn.send(err).await;
}
let resolved = match resolver.resolve(conn, user).await {
Ok(v) => v,
Err(err) => {
return conn
.send(ErrorResponse::fatal(
SqlState::INVALID_PASSWORD,
err.to_string(),
))
.await;
}
};
if let Conn::Ssl(ssl_stream) = conn.inner() {
let tenant = resolved.tenant.as_deref().unwrap_or_else(|| "unknown");
let has_sni = ssl_stream.ssl().servername(NameType::HOST_NAME).is_some();
metrics.tenant_pgwire_sni_count(tenant, has_sni).inc();
}
let _active_guard = resolved
.tenant
.as_ref()
.map(|tenant| metrics.tenant_connections(tenant));
let Ok(mut mz_stream) =
Self::init_stream(conn, resolved.addr, resolved.password, params).await
else {
return Ok(());
};
let mut client_counter = CountingConn::new(conn.inner_mut());
let _ = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
if let Some(tenant) = &resolved.tenant {
metrics
.tenant_connections_tx(tenant)
.inc_by(u64::cast_from(client_counter.written));
metrics
.tenant_connections_rx(tenant)
.inc_by(u64::cast_from(client_counter.read));
}
Ok(())
}
#[mz_ore::instrument(level = "debug")]
async fn init_stream<'a, A>(
conn: &'a mut FramedConn<A>,
envd_addr: SocketAddr,
password: Option<String>,
params: BTreeMap<String, String>,
) -> Result<TcpStream, anyhow::Error>
where
A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
{
let mut mz_stream = TcpStream::connect(envd_addr).await?;
let mut buf = BytesMut::new();
let startup = FrontendStartupMessage::Startup {
version: VERSION_3,
params,
};
startup.encode(&mut buf)?;
mz_stream.write_all(&buf).await?;
let client_stream = conn.inner_mut();
let mut maybe_auth_frame = [0; 1 + 4 + 4];
let nread = netio::read_exact_or_eof(&mut mz_stream, &mut maybe_auth_frame).await?;
const AUTH_PASSWORD_CLEARTEXT: [u8; 9] = [b'R', 0, 0, 0, 8, 0, 0, 0, 3];
if nread == AUTH_PASSWORD_CLEARTEXT.len()
&& maybe_auth_frame == AUTH_PASSWORD_CLEARTEXT
&& password.is_some()
{
let Some(password) = password else {
unreachable!("verified some above");
};
let password = FrontendMessage::Password { password };
buf.clear();
password.encode(&mut buf)?;
mz_stream.write_all(&buf).await?;
mz_stream.flush().await?;
} else {
client_stream.write_all(&maybe_auth_frame[0..nread]).await?;
}
Ok(mz_stream)
}
}
impl mz_server_core::Server for PgwireBalancer {
const NAME: &'static str = "pgwire_balancer";
fn handle_connection(&self, conn: TcpStream) -> mz_server_core::ConnectionHandler {
let tls = self.tls.clone();
let resolver = Arc::clone(&self.resolver);
let inner_metrics = self.metrics.clone();
let outer_metrics = self.metrics.clone();
let cancellation_resolver = self.cancellation_resolver.clone();
Box::pin(async move {
let active_guard = outer_metrics.active_connections();
let result: Result<(), anyhow::Error> = async move {
let mut conn = Conn::Unencrypted(conn);
loop {
let message = decode_startup(&mut conn).await?;
conn = match message {
None => return Ok(()),
Some(FrontendStartupMessage::Startup { version, params }) => {
let mut conn = FramedConn::new(conn);
Self::run(
&mut conn,
version,
params,
&resolver,
tls.map(|tls| tls.mode),
&inner_metrics,
)
.await?;
conn.flush().await?;
return Ok(());
}
Some(FrontendStartupMessage::CancelRequest {
conn_id,
secret_key,
}) => {
if let Some(resolver) = cancellation_resolver {
spawn(|| "cancel request", async move {
cancel_request(conn_id, secret_key, &resolver).await;
});
}
return Ok(());
}
Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
(Conn::Unencrypted(mut conn), Some(tls)) => {
conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
let mut ssl_stream =
SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
let _ = ssl_stream.get_mut().shutdown().await;
return Err(e.into());
}
Conn::Ssl(ssl_stream)
}
(mut conn, _) => {
conn.write_all(&[REJECT_ENCRYPTION]).await?;
conn
}
},
Some(FrontendStartupMessage::GssEncRequest) => {
conn.write_all(&[REJECT_ENCRYPTION]).await?;
conn
}
}
}
}
.await;
drop(active_guard);
outer_metrics.connection_status(result.is_ok()).inc();
Ok(())
})
}
}
/// Wrapper around a connection that tallies the bytes passing through it.
struct CountingConn<C> {
/// The wrapped connection.
inner: C,
/// Total bytes obtained by successful reads from `inner`.
read: usize,
/// Total bytes accepted by successful writes to `inner`.
written: usize,
}
impl<C> CountingConn<C> {
fn new(inner: C) -> Self {
CountingConn {
inner,
read: 0,
written: 0,
}
}
}
impl<C> AsyncRead for CountingConn<C>
where
    C: AsyncRead + Unpin,
{
    /// Delegates to the wrapped reader and, when the read completes successfully, adds the
    /// number of bytes it appended to `buf` to the `read` counter.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        let filled_before = buf.filled().len();
        let outcome = Pin::new(&mut this.inner).poll_read(cx, buf);
        // Growth of the filled region during this call.
        let appended = buf.filled().len() - filled_before;
        if matches!(outcome, std::task::Poll::Ready(Ok(()))) {
            this.read += appended;
        }
        outcome
    }
}
impl<C> AsyncWrite for CountingConn<C>
where
    C: AsyncWrite + Unpin,
{
    /// Delegates to the wrapped writer and adds the number of bytes it accepted to the
    /// `written` counter when the write completes successfully.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        let this = self.get_mut();
        let outcome = Pin::new(&mut this.inner).poll_write(cx, buf);
        if let std::task::Poll::Ready(Ok(accepted)) = &outcome {
            this.written += *accepted;
        }
        outcome
    }

    /// Forwards the flush to the wrapped writer; counters are unaffected.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
    }

    /// Forwards the shutdown to the wrapped writer; counters are unaffected.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
    }
}
async fn cancel_request(conn_id: u32, secret_key: u32, cancellation_resolver: &PathBuf) {
let suffix = conn_id_org_uuid(conn_id);
let path = cancellation_resolver.join(&suffix);
let contents = match std::fs::read_to_string(&path) {
Ok(contents) => contents,
Err(err) => {
error!("could not read cancel file {path:?}: {err}");
return;
}
};
let mut all_ips = Vec::new();
for addr in contents.lines() {
let addr = addr.trim();
if addr.is_empty() {
continue;
}
match tokio::net::lookup_host(addr).await {
Ok(ips) => all_ips.extend(ips),
Err(err) => {
error!("{addr} failed resolution: {err}");
}
}
}
let mut buf = BytesMut::with_capacity(16);
let msg = FrontendStartupMessage::CancelRequest {
conn_id,
secret_key,
};
msg.encode(&mut buf).expect("must encode");
let buf = buf.freeze();
for ip in all_ips {
debug!("cancelling {suffix} to {ip}");
let buf = buf.clone();
spawn(|| "cancel request for ip", async move {
let send = async {
let mut stream = TcpStream::connect(&ip).await?;
stream.write_all(&buf).await?;
stream.shutdown().await?;
Ok::<_, io::Error>(())
};
if let Err(err) = send.await {
error!("error mirroring cancel to {ip}: {err}");
}
});
}
}
/// Server that proxies HTTPS connections to an upstream derived from the client's SNI name.
struct HttpsBalancer {
/// DNS resolver used for the CNAME lookup that identifies the tenant.
resolver: Arc<StubResolver>,
/// TLS context; `None` means connections are proxied without TLS termination.
tls: Option<ReloadingSslContext>,
/// Host template; `{}` is replaced with the first label of the SNI name.
resolve_template: Arc<str>,
/// Port of the upstream.
port: u16,
/// Metrics reported for this server.
metrics: Arc<ServerMetrics>,
}
impl HttpsBalancer {
    /// Works out the upstream address, and the tenant if it can be determined.
    ///
    /// With a `servername`, every `{}` in `resolve_template` is replaced by it; without
    /// one the template is used verbatim. The tenant comes from a CNAME lookup of the
    /// resulting host, the socket address from resolving `host:port`.
    ///
    /// # Errors
    ///
    /// Fails if the host cannot be resolved to a socket address. A failed tenant lookup is
    /// not an error; the tenant is simply `None`.
    async fn resolve(
        resolver: &StubResolver,
        resolve_template: &str,
        port: u16,
        servername: Option<&str>,
    ) -> Result<ResolvedAddr, anyhow::Error> {
        let addr = servername.map_or_else(
            || resolve_template.to_string(),
            |servername| resolve_template.replace("{}", servername),
        );
        debug!("https address: {addr}");
        let tenant = Self::tenant(resolver, &addr).await;
        let envd_addr = lookup(&format!("{addr}:{port}")).await?;
        Ok(ResolvedAddr {
            addr: envd_addr,
            password: None,
            tenant,
        })
    }

    /// Looks up the CNAME of `addr` and extracts the tenant from the first CNAME record.
    ///
    /// Returns `None` when `addr` is not a valid domain name, when the query or the parsing
    /// of its answer section fails, when no CNAME record is present, or when the CNAME does
    /// not have the expected shape.
    async fn tenant(resolver: &StubResolver, addr: &str) -> Option<String> {
        let dname = Dname::<Vec<_>>::from_str(addr).ok()?;
        let lookup = resolver.query((dname, Rtype::Cname)).await.ok()?;
        let answer = lookup.answer().ok()?;
        for record in answer.limit_to::<AllRecordData<_, _>>() {
            // Records that fail to parse are skipped.
            let Ok(record) = record else {
                continue;
            };
            if record.rtype() == Rtype::Cname {
                let cname = record.data().to_string();
                debug!("cname: {cname}");
                return Self::extract_tenant_from_cname(&cname);
            }
        }
        None
    }

    /// Extracts the tenant UUID from a CNAME of the form
    /// `<service>.<prefix>-<tenant uuid>-<suffix>.<rest>`.
    ///
    /// Only the second dot-separated label is inspected: everything up to its first `-`
    /// and everything from its last `-` is discarded, and the remainder must parse as a
    /// UUID. The UUID is returned in its canonical lowercase, hyphenated form.
    fn extract_tenant_from_cname(cname: &str) -> Option<String> {
        let namespace = cname.split('.').nth(1)?;
        let (_, remainder) = namespace.split_once('-')?;
        let (tenant, _) = remainder.rsplit_once('-')?;
        match Uuid::parse_str(tenant) {
            Ok(uuid) => Some(uuid.to_string()),
            Err(_) => {
                error!("cname tenant not a uuid: {tenant}");
                None
            }
        }
    }
}
impl mz_server_core::Server for HttpsBalancer {
const NAME: &'static str = "https_balancer";
/// Proxies one accepted connection to the upstream chosen from the client's SNI name.
///
/// With TLS configured, the TLS handshake is performed first and the first dot-separated
/// label of the SNI host name (if any) is used to fill the resolve template. Bytes are
/// then copied in both directions between client and upstream, and per-tenant byte
/// counters are updated when the tenant is known.
///
/// The handler always returns `Ok(())`; failures are counted in the connection-status
/// metric and logged at debug level.
fn handle_connection(&self, conn: TcpStream) -> mz_server_core::ConnectionHandler {
let tls_context = self.tls.clone();
let resolver = Arc::clone(&self.resolver);
let resolve_template = Arc::clone(&self.resolve_template);
let port = self.port;
let inner_metrics = Arc::clone(&self.metrics);
let outer_metrics = Arc::clone(&self.metrics);
Box::pin(async move {
// Held until just before the outcome is recorded below.
let active_guard = inner_metrics.active_connections();
let result: Result<_, anyhow::Error> = Box::pin(async move {
// The client side is boxed because it is either a TLS stream or the raw TCP stream.
let (client_stream, servername): (Box<dyn ClientStream>, Option<String>) =
match tls_context {
Some(tls_context) => {
let mut ssl_stream =
SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
// Best-effort shutdown of the underlying stream.
let _ = ssl_stream.get_mut().shutdown().await;
return Err(e.into());
}
// Keep only the part of the SNI name before the first '.'.
let servername: Option<String> =
ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
match sn.split_once('.') {
Some((left, _right)) => left,
None => sn,
}
.into()
});
debug!("servername: {servername:?}");
(Box::new(ssl_stream), servername)
}
_ => (Box::new(conn), None),
};
let resolved =
Self::resolve(&resolver, &resolve_template, port, servername.as_deref())
.await?;
// Tenant gauge guard; only created when the tenant could be determined.
let inner_active_guard = resolved
.tenant
.as_ref()
.map(|tenant| inner_metrics.tenant_connections(tenant));
let mut mz_stream = TcpStream::connect(resolved.addr).await?;
let mut client_counter = CountingConn::new(client_stream);
// The outcome of the copy is ignored; byte counts are recorded either way.
let _ = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
if let Some(tenant) = &resolved.tenant {
inner_metrics
.tenant_connections_tx(tenant)
.inc_by(u64::cast_from(client_counter.written));
inner_metrics
.tenant_connections_rx(tenant)
.inc_by(u64::cast_from(client_counter.read));
}
drop(inner_active_guard);
Ok(())
})
.await;
drop(active_guard);
outer_metrics.connection_status(result.is_ok()).inc();
if let Err(e) = result {
debug!("connection error: {e}");
}
Ok(())
})
}
}
/// Object-safe alias for "a sendable, unpin, bidirectional async stream", so that TLS and
/// plain TCP client streams can be stored behind the same `Box<dyn ClientStream>`.
trait ClientStream: AsyncRead + AsyncWrite + Unpin + Send {}
// Blanket implementation: every type meeting the bounds is a `ClientStream`.
impl<T: AsyncRead + AsyncWrite + Unpin + Send> ClientStream for T {}
/// How a pgwire connection is mapped to an upstream address.
#[derive(Debug)]
pub enum Resolver {
/// Always resolve the contained host string; no authentication is performed.
Static(String),
/// Authenticate the user's password via Frontegg and derive the address from the tenant.
Frontegg(FronteggResolver),
}
impl Resolver {
    /// Determines the upstream address for `user` on `conn`.
    ///
    /// * `Frontegg`: asks the client for a cleartext password, authenticates it, fills the
    ///   address template with the tenant id and resolves the result. The password and the
    ///   tenant id are returned alongside the address.
    /// * `Static`: resolves the configured address; no password or tenant is returned.
    ///
    /// # Errors
    ///
    /// Fails on connection I/O errors, when the client does not answer with a password
    /// message, when authentication fails (reported as "invalid password", with the cause
    /// only logged), or when the address does not resolve.
    async fn resolve<A>(
        &self,
        conn: &mut FramedConn<A>,
        user: &str,
    ) -> Result<ResolvedAddr, anyhow::Error>
    where
        A: AsyncRead + AsyncWrite + Unpin,
    {
        match self {
            Resolver::Static(addr) => {
                let addr = lookup(addr).await?;
                Ok(ResolvedAddr {
                    addr,
                    password: None,
                    tenant: None,
                })
            }
            Resolver::Frontegg(FronteggResolver {
                auth,
                addr_template,
            }) => {
                // Request the password from the client and wait for its reply.
                conn.send(BackendMessage::AuthenticationCleartextPassword)
                    .await?;
                conn.flush().await?;
                let Some(FrontendMessage::Password { password }) = conn.recv().await? else {
                    anyhow::bail!("expected Password message");
                };

                let auth_session = match auth.authenticate(user.into(), &password).await {
                    Ok(session) => session,
                    Err(e) => {
                        warn!("pgwire connection failed authentication: {}", e);
                        anyhow::bail!("invalid password");
                    }
                };

                let tenant = auth_session.tenant_id().to_string();
                let addr = lookup(&addr_template.replace("{}", &tenant)).await?;
                Ok(ResolvedAddr {
                    addr,
                    password: Some(password),
                    tenant: Some(tenant),
                })
            }
        }
    }
}
async fn lookup(name: &str) -> Result<SocketAddr, anyhow::Error> {
let mut addrs = tokio::net::lookup_host(name).await?;
match addrs.next() {
Some(addr) => Ok(addr),
None => {
error!("{name} did not resolve to any addresses");
anyhow::bail!("internal error")
}
}
}
/// Settings for [`Resolver::Frontegg`].
#[derive(Debug)]
pub struct FronteggResolver {
/// Authenticator used to validate the user's password.
pub auth: FronteggAuthentication,
/// Address template; `{}` is replaced with the authenticated tenant id.
pub addr_template: String,
}
/// Result of resolving a connection to its upstream.
#[derive(Debug)]
struct ResolvedAddr {
/// Socket address of the upstream.
addr: SocketAddr,
/// Password collected from the client, if the resolver asked for one.
password: Option<String>,
/// Tenant the connection belongs to, if it could be determined.
tenant: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `HttpsBalancer::extract_tenant_from_cname`.
    #[mz_ore::test]
    fn test_tenant() {
        let cases = [
            // Empty input has no second label.
            ("", None),
            // Canonical shape.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Service name, prefix and trailing labels are not inspected.
            (
                "service.something-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.ssvvcc.cloister.faraway",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // A UUID without hyphens is accepted and normalised.
            (
                "environmentd.environment-58cd23ffa4d74bd0ad85a6ff29cc86c3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // The trailing suffix may have several digits.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-1234.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // An uppercase UUID is normalised to lowercase.
            (
                "environmentd.environment-58CD23FF-A4D7-4BD0-AD85-A6FF29CC86C3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Without a trailing suffix the last UUID group is stripped instead.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3.svc.cluster.local",
                None,
            ),
            // Without a leading service label the second label is `svc`.
            (
                "environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
            // A UUID that is one digit short does not parse.
            (
                "environmentd.environment-8cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
        ];
        for (name, expect) in cases {
            let cname = HttpsBalancer::extract_tenant_from_cname(name);
            assert_eq!(
                cname.as_deref(),
                expect,
                "{name} got {cname:?} expected {expect:?}"
            );
        }
    }
}