mod codec;
mod dyncfgs;
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::Context;
use axum::response::IntoResponse;
use axum::{routing, Router};
use bytes::BytesMut;
use domain::base::{Dname, Rtype};
use domain::rdata::AllRecordData;
use domain::resolv::StubResolver;
use futures::stream::BoxStream;
use futures::TryFutureExt;
use hyper::StatusCode;
use hyper_util::rt::TokioIo;
use launchdarkly_server_sdk as ld;
use mz_build_info::{build_info, BuildInfo};
use mz_dyncfg::ConfigSet;
use mz_frontegg_auth::Authenticator as FronteggAuthentication;
use mz_ore::cast::CastFrom;
use mz_ore::id_gen::conn_id_org_uuid;
use mz_ore::metrics::{ComputedGauge, IntCounter, IntGauge, MetricsRegistry};
use mz_ore::netio::AsyncReady;
use mz_ore::task::{spawn, JoinSetExt};
use mz_ore::tracing::TracingHandle;
use mz_ore::{metric, netio};
use mz_pgwire_common::{
decode_startup, Conn, ErrorResponse, FrontendMessage, FrontendStartupMessage,
ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, VERSION_3,
};
use mz_server_core::{
listen, Connection, ConnectionStream, ListenerHandle, ReloadTrigger, ReloadingSslContext,
ReloadingTlsConfig, ServeConfig, ServeDyncfg, TlsCertConfig, TlsMode,
};
use openssl::ssl::{NameType, Ssl, SslConnector, SslMethod, SslVerifyMode};
use prometheus::{IntCounterVec, IntGaugeVec};
use proxy_header::{ProxiedAddress, ProxyHeader};
use semver::Version;
use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::task::JoinSet;
use tokio_openssl::SslStream;
use tokio_postgres::error::SqlState;
use tower::Service;
use tracing::{debug, error, warn};
use uuid::Uuid;
use crate::codec::{BackendMessage, FramedConn};
use crate::dyncfgs::{
has_tracing_config_update, tracing_config, INJECT_PROXY_PROTOCOL_HEADER_HTTP, SIGTERM_WAIT,
};
/// Build metadata for this binary, produced by the `build_info!` macro.
pub const BUILD_INFO: BuildInfo = build_info!();
/// Configuration for a [`BalancerService`]: listen addresses, upstream
/// resolution, TLS, metrics, LaunchDarkly sync settings and tracing.
pub struct BalancerConfig {
    /// Semver version of this build (metric label and LaunchDarkly context).
    build_version: Version,
    /// Address the internal HTTP server (metrics, liveness, readiness) binds to.
    internal_http_listen_addr: SocketAddr,
    /// Address the pgwire balancer binds to.
    pgwire_listen_addr: SocketAddr,
    /// Address the HTTPS balancer binds to.
    https_listen_addr: SocketAddr,
    /// Source of upstream addresses for forwarded pgwire cancel requests.
    cancellation_resolver: CancellationResolver,
    /// Maps pgwire connections to an upstream address.
    resolver: Resolver,
    /// `host:port` template for HTTPS upstreams; `{}` in the host part is
    /// replaced with the first label of the client's SNI server name, if any.
    https_addr_template: String,
    /// Client-facing TLS certificate configuration; `None` disables TLS.
    tls: Option<TlsCertConfig>,
    /// Whether upstream connections should use TLS.
    internal_tls: bool,
    /// Registry all balancer metrics are registered into.
    metrics_registry: MetricsRegistry,
    /// Trigger stream handed to `TlsCertConfig::reloading_context`.
    reload_certs: BoxStream<'static, Option<oneshot::Sender<Result<(), anyhow::Error>>>>,
    /// LaunchDarkly SDK key used to sync dynamic configs, if any.
    launchdarkly_sdk_key: Option<String>,
    /// Timeout for the initial LaunchDarkly config sync.
    config_sync_timeout: Duration,
    /// Interval for the LaunchDarkly config sync loop, if any.
    config_sync_loop_interval: Option<Duration>,
    /// Cloud provider reported in the LaunchDarkly context.
    cloud_provider: Option<String>,
    /// Cloud provider region reported in the LaunchDarkly context.
    cloud_provider_region: Option<String>,
    /// Handle tracing config updates are applied to.
    tracing_handle: TracingHandle,
    /// Default dynamic config values, applied before the LaunchDarkly sync.
    default_configs: Vec<(String, String)>,
}

impl BalancerConfig {
    /// Assembles a [`BalancerConfig`]; the build version is derived from `build_info`.
    pub fn new(
        build_info: &BuildInfo,
        internal_http_listen_addr: SocketAddr,
        pgwire_listen_addr: SocketAddr,
        https_listen_addr: SocketAddr,
        cancellation_resolver: CancellationResolver,
        resolver: Resolver,
        https_addr_template: String,
        tls: Option<TlsCertConfig>,
        internal_tls: bool,
        metrics_registry: MetricsRegistry,
        reload_certs: ReloadTrigger,
        launchdarkly_sdk_key: Option<String>,
        config_sync_timeout: Duration,
        config_sync_loop_interval: Option<Duration>,
        cloud_provider: Option<String>,
        cloud_provider_region: Option<String>,
        tracing_handle: TracingHandle,
        default_configs: Vec<(String, String)>,
    ) -> Self {
        let build_version = build_info.semver_version();
        BalancerConfig {
            build_version,
            // Listeners.
            pgwire_listen_addr,
            https_listen_addr,
            internal_http_listen_addr,
            // Upstream resolution.
            resolver,
            cancellation_resolver,
            https_addr_template,
            // TLS.
            tls,
            internal_tls,
            reload_certs,
            // Observability and dynamic configuration.
            metrics_registry,
            launchdarkly_sdk_key,
            config_sync_timeout,
            config_sync_loop_interval,
            cloud_provider,
            cloud_provider_region,
            tracing_handle,
            default_configs,
        }
    }
}
/// Process-level metrics for the balancer.
#[derive(Debug)]
pub struct BalancerMetrics {
    /// Held so the computed uptime gauge lives as long as this struct; never read.
    _uptime: ComputedGauge,
}

impl BalancerMetrics {
    /// Registers `mz_balancer_metadata_seconds` in `cfg.metrics_registry`.
    ///
    /// The gauge reports the seconds elapsed since this call. It carries the
    /// build version and build type as constant labels.
    pub fn new(cfg: &BalancerConfig) -> Self {
        // NOTE(review): `release` is not a cfg that Cargo sets by itself, so this
        // is "release" only if the build passes `--cfg release`. Confirm the build
        // tooling does that; otherwise `cfg!(debug_assertions)` may be intended.
        let build_type = if cfg!(release) { "release" } else { "debug" };
        let started_at = Instant::now();
        let elapsed_secs = move || started_at.elapsed().as_secs_f64();
        let uptime_gauge = cfg.metrics_registry.register_computed_gauge(
            metric!(
                name: "mz_balancer_metadata_seconds",
                help: "server uptime, labels are build metadata",
                const_labels: {
                    "version" => cfg.build_version,
                    "build_type" => build_type
                },
            ),
            elapsed_secs,
        );
        Self {
            _uptime: uptime_gauge,
        }
    }
}
/// The balancer service. Holds listeners created by `listen` for pgwire, HTTPS
/// and internal HTTP, plus the dynamic configuration set. Call
/// [`BalancerService::serve`] to start accepting connections.
pub struct BalancerService {
    cfg: BalancerConfig,
    /// pgwire listener handle and its incoming-connection stream.
    pub pgwire: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// HTTPS listener handle and its incoming-connection stream.
    pub https: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Internal HTTP (metrics/health) listener handle and its incoming-connection stream.
    pub internal_http: (ListenerHandle, Pin<Box<dyn ConnectionStream>>),
    /// Held for the lifetime of the service; never read.
    _metrics: BalancerMetrics,
    /// Dynamic configs, seeded with defaults and synced via `mz_dyncfg_launchdarkly`.
    configs: ConfigSet,
}
impl BalancerService {
    /// Binds the pgwire, HTTPS and internal HTTP listeners, registers process
    /// metrics, and initializes the dynamic configuration set.
    ///
    /// Dynamic configs are seeded from `cfg.default_configs` and then synced from
    /// LaunchDarkly. A sync failure is logged and otherwise ignored, so the
    /// service still starts. Config updates that touch tracing settings are
    /// applied to `cfg.tracing_handle`.
    ///
    /// # Errors
    ///
    /// Returns an error if any listener fails to bind or the default configs
    /// cannot be applied.
    pub async fn new(cfg: BalancerConfig) -> Result<Self, anyhow::Error> {
        let pgwire = listen(&cfg.pgwire_listen_addr).await?;
        let https = listen(&cfg.https_listen_addr).await?;
        let internal_http = listen(&cfg.internal_http_listen_addr).await?;
        let metrics = BalancerMetrics::new(&cfg);
        let configs = dyncfgs::all_dyncfgs(ConfigSet::default());
        dyncfgs::set_defaults(&configs, cfg.default_configs.clone())?;
        let tracing_handle = cfg.tracing_handle.clone();
        if let Err(err) = mz_dyncfg_launchdarkly::sync_launchdarkly_to_configset(
            configs.clone(),
            &BUILD_INFO,
            // Builds the LaunchDarkly evaluation context for this balancer.
            |builder| {
                let region = cfg
                    .cloud_provider_region
                    .clone()
                    .unwrap_or_else(|| String::from("unknown"));
                if let Some(provider) = cfg.cloud_provider.clone() {
                    builder.add_context(
                        ld::ContextBuilder::new(format!(
                            "{}/{}/{}",
                            provider, region, cfg.build_version
                        ))
                        .kind("balancer")
                        .set_string("provider", provider)
                        .set_string("region", region)
                        .set_string("version", cfg.build_version.to_string())
                        .build()
                        .map_err(|e| anyhow::anyhow!(e))?,
                    );
                } else {
                    // No cloud provider configured: report it as "unknown" and
                    // mark the context anonymous.
                    builder.add_context(
                        ld::ContextBuilder::new(format!(
                            "{}/{}/{}",
                            "unknown", region, cfg.build_version
                        ))
                        .anonymous(true)
                        .kind("balancer")
                        .set_string("provider", "unknown")
                        .set_string("region", region)
                        .set_string("version", cfg.build_version.to_string())
                        .build()
                        .map_err(|e| anyhow::anyhow!(e))?,
                    );
                }
                Ok(())
            },
            cfg.launchdarkly_sdk_key.as_deref(),
            cfg.config_sync_timeout,
            cfg.config_sync_loop_interval,
            // Re-applies tracing settings whenever a tracing-related config changes.
            move |updates, configs| {
                if has_tracing_config_update(updates) {
                    match tracing_config(configs) {
                        Ok(parameters) => parameters.apply(&tracing_handle),
                        Err(err) => warn!("unable to update tracing: {err}"),
                    }
                }
            },
        )
        .await
        {
            warn!("LaunchDarkly sync error: {err}");
        }
        Ok(Self {
            cfg,
            pgwire,
            https,
            internal_http,
            _metrics: metrics,
            configs,
        })
    }

    /// Serves pgwire, HTTPS and internal HTTP connections until all server tasks exit.
    ///
    /// Each server runs in its own task. On Unix, a SIGTERM drops all listener
    /// handles. NOTE(review): this presumably stops new accepts; confirm in
    /// `mz_server_core`. Failed (e.g. panicked) server tasks are logged rather
    /// than returned.
    ///
    /// # Errors
    ///
    /// Returns an error if `https_addr_template` is not `host:port` with a valid
    /// `u16` port, if the TLS context cannot be built, or if the SIGTERM handler
    /// cannot be installed.
    pub async fn serve(self) -> Result<(), anyhow::Error> {
        // Validate the HTTPS upstream template before any server is spawned, so a
        // malformed configuration is reported as an error instead of a panic.
        let (https_resolve_template, https_port) = {
            let (addr, port) = self
                .cfg
                .https_addr_template
                .split_once(':')
                .context("expected port in https_addr_template")?;
            let port: u16 = port
                .parse()
                .with_context(|| format!("unexpected port in https_addr_template: {port}"))?;
            (Arc::<str>::from(addr), port)
        };
        let (pgwire_tls, https_tls) = match &self.cfg.tls {
            Some(tls) => {
                let context = tls.reloading_context(self.cfg.reload_certs)?;
                (
                    Some(ReloadingTlsConfig {
                        context: context.clone(),
                        mode: TlsMode::Require,
                    }),
                    Some(context),
                )
            }
            None => (None, None),
        };
        let metrics = ServerMetricsConfig::register_into(&self.cfg.metrics_registry);
        let mut set = JoinSet::new();
        // Listener handles; dropped together by the SIGTERM handler.
        let mut server_handles = Vec::new();
        let pgwire_addr = self.pgwire.0.local_addr();
        let https_addr = self.https.0.local_addr();
        let internal_http_addr = self.internal_http.0.local_addr();
        {
            let pgwire = PgwireBalancer {
                resolver: Arc::new(self.cfg.resolver),
                cancellation_resolver: Arc::new(self.cfg.cancellation_resolver),
                tls: pgwire_tls,
                internal_tls: self.cfg.internal_tls,
                metrics: ServerMetrics::new(metrics.clone(), "pgwire"),
            };
            let (handle, stream) = self.pgwire;
            server_handles.push(handle);
            set.spawn_named(|| "pgwire_stream", {
                let config_set = self.configs.clone();
                async move {
                    mz_server_core::serve(ServeConfig {
                        server: pgwire,
                        conns: stream,
                        dyncfg: Some(ServeDyncfg {
                            config_set,
                            sigterm_wait_config: &SIGTERM_WAIT,
                        }),
                    })
                    .await;
                    warn!("pgwire server exited");
                }
            });
        }
        {
            let resolver = StubResolver::new();
            let https = HttpsBalancer {
                resolver: Arc::from(resolver),
                tls: https_tls,
                resolve_template: https_resolve_template,
                port: https_port,
                metrics: Arc::from(ServerMetrics::new(metrics, "https")),
                configs: self.configs.clone(),
                internal_tls: self.cfg.internal_tls,
            };
            let (handle, stream) = self.https;
            server_handles.push(handle);
            set.spawn_named(|| "https_stream", {
                let config_set = self.configs.clone();
                async move {
                    mz_server_core::serve(ServeConfig {
                        server: https,
                        conns: stream,
                        dyncfg: Some(ServeDyncfg {
                            config_set,
                            sigterm_wait_config: &SIGTERM_WAIT,
                        }),
                    })
                    .await;
                    warn!("https server exited");
                }
            });
        }
        {
            let router = Router::new()
                .route(
                    "/metrics",
                    routing::get(move || async move {
                        mz_http_util::handle_prometheus(&self.cfg.metrics_registry).await
                    }),
                )
                .route(
                    "/api/livez",
                    routing::get(mz_http_util::handle_liveness_check),
                )
                .route("/api/readyz", routing::get(handle_readiness_check));
            let internal_http = InternalHttpServer { router };
            let (handle, stream) = self.internal_http;
            server_handles.push(handle);
            set.spawn_named(|| "internal_http_stream", async move {
                mz_server_core::serve(ServeConfig {
                    server: internal_http,
                    conns: stream,
                    dyncfg: None,
                })
                .await;
                warn!("internal_http server exited");
            });
        }
        #[cfg(unix)]
        {
            let mut sigterm =
                tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
            set.spawn_named(|| "sigterm_handler", async move {
                sigterm.recv().await;
                warn!("received signal TERM");
                drop(server_handles);
            });
        }
        println!("balancerd {} listening...", BUILD_INFO.human_version(None));
        println!(" TLS enabled: {}", self.cfg.tls.is_some());
        println!(" pgwire address: {}", pgwire_addr);
        println!(" HTTPS address: {}", https_addr);
        println!(" internal HTTP address: {}", internal_http_addr);
        while let Some(res) = set.join_next().await {
            if let Err(err) = res {
                error!("serving task failed: {err}")
            }
        }
        Ok(())
    }
}
/// Readiness probe handler: always answers `200 OK` with the body `ready`.
// Kept `async` (nothing is awaited) so it can be passed directly to `routing::get`.
#[allow(clippy::unused_async)]
async fn handle_readiness_check() -> impl IntoResponse {
    let body = "ready";
    (StatusCode::OK, body)
}
/// Serves an axum [`Router`] (the internal metrics/health router) over HTTP/1.
struct InternalHttpServer {
    router: Router,
}

impl mz_server_core::Server for InternalHttpServer {
    const NAME: &'static str = "internal_http";

    /// Drives one HTTP/1 connection and dispatches every request to a clone of the router.
    fn handle_connection(&self, conn: Connection) -> mz_server_core::ConnectionHandler {
        let io = TokioIo::new(conn);
        let app = self.router.clone();
        let svc = hyper::service::service_fn(move |request| app.clone().call(request));
        Box::pin(async move {
            let builder = hyper::server::conn::http1::Builder::new();
            builder.serve_connection(io, svc).err_into().await
        })
    }
}
/// RAII guard that increments a gauge when created and decrements it when dropped.
struct GaugeGuard {
    gauge: IntGauge,
}

impl From<IntGauge> for GaugeGuard {
    /// Increments `gauge` and returns a guard that undoes the increment on drop.
    fn from(gauge: IntGauge) -> Self {
        gauge.inc();
        GaugeGuard { gauge }
    }
}

impl Drop for GaugeGuard {
    fn drop(&mut self) {
        // Balance the increment performed in `from`.
        self.gauge.dec();
    }
}
/// Metric families shared by all balancer servers.
///
/// Most are labeled by `source` (the server name); the SNI counter is labeled
/// by tenant and `has_sni` only.
#[derive(Clone, Debug)]
struct ServerMetricsConfig {
    /// Completed connections by `source` and `status`.
    connection_status: IntCounterVec,
    /// Currently open connections by `source`.
    active_connections: IntGaugeVec,
    /// Currently open connections by `source` and `tenant`.
    tenant_connections: IntGaugeVec,
    /// Bytes received from clients by `source` and `tenant`.
    tenant_connection_rx: IntCounterVec,
    /// Bytes sent to clients by `source` and `tenant`.
    tenant_connection_tx: IntCounterVec,
    /// pgwire connections by `tenant` and whether SNI was present.
    tenant_pgwire_sni_count: IntCounterVec,
}

impl ServerMetricsConfig {
    /// Registers every balancer server metric family in `registry`.
    fn register_into(registry: &MetricsRegistry) -> Self {
        // Struct-literal fields are evaluated in the order written, so the
        // registration order matches the listing below.
        Self {
            connection_status: registry.register(metric!(
                name: "mz_balancer_connection_status",
                help: "Count of completed network connections, by status",
                var_labels: ["source", "status"],
            )),
            active_connections: registry.register(metric!(
                name: "mz_balancer_connection_active",
                help: "Count of currently open network connections.",
                var_labels: ["source"],
            )),
            tenant_connections: registry.register(metric!(
                name: "mz_balancer_tenant_connection_active",
                help: "Count of opened network connections by tenant.",
                var_labels: ["source", "tenant"]
            )),
            tenant_connection_rx: registry.register(metric!(
                name: "mz_balancer_tenant_connection_rx",
                help: "Number of bytes received from a client for a tenant.",
                var_labels: ["source", "tenant"],
            )),
            tenant_connection_tx: registry.register(metric!(
                name: "mz_balancer_tenant_connection_tx",
                help: "Number of bytes sent to a client for a tenant.",
                var_labels: ["source", "tenant"],
            )),
            tenant_pgwire_sni_count: registry.register(metric!(
                name: "mz_balancer_tenant_pgwire_sni_count",
                help: "Count of pgwire connections that have and do not have SNI available per tenant.",
                var_labels: ["tenant", "has_sni"],
            )),
        }
    }
}
#[derive(Clone, Debug)]
struct ServerMetrics {
inner: ServerMetricsConfig,
source: &'static str,
}
impl ServerMetrics {
fn new(inner: ServerMetricsConfig, source: &'static str) -> Self {
let self_ = Self { inner, source };
self_.connection_status(false);
self_.connection_status(true);
drop(self_.active_connections());
self_
}
fn connection_status(&self, is_ok: bool) -> IntCounter {
self.inner
.connection_status
.with_label_values(&[self.source, Self::status_label(is_ok)])
}
fn active_connections(&self) -> GaugeGuard {
self.inner
.active_connections
.with_label_values(&[self.source])
.into()
}
fn tenant_connections(&self, tenant: &str) -> GaugeGuard {
self.inner
.tenant_connections
.with_label_values(&[self.source, tenant])
.into()
}
fn tenant_connections_rx(&self, tenant: &str) -> IntCounter {
self.inner
.tenant_connection_rx
.with_label_values(&[self.source, tenant])
}
fn tenant_connections_tx(&self, tenant: &str) -> IntCounter {
self.inner
.tenant_connection_tx
.with_label_values(&[self.source, tenant])
}
fn tenant_pgwire_sni_count(&self, tenant: &str, has_sni: bool) -> IntCounter {
self.inner
.tenant_pgwire_sni_count
.with_label_values(&[tenant, &has_sni.to_string()])
}
fn status_label(is_ok: bool) -> &'static str {
if is_ok {
"success"
} else {
"error"
}
}
}
/// Source of the upstream address(es) a pgwire cancel request is forwarded to.
///
/// Either way, the contents are split into lines. Lines are trimmed, empty
/// lines are skipped, and each remaining line is resolved with DNS.
pub enum CancellationResolver {
    /// Directory holding one file per `conn_id_org_uuid(conn_id)` suffix, each
    /// listing upstream addresses one per line.
    Directory(PathBuf),
    /// Fixed upstream address list, one address per line.
    Static(String),
}
/// pgwire server that resolves each client connection to an upstream and then
/// proxies bytes between them.
struct PgwireBalancer {
    /// Client-facing TLS configuration; when `None`, SSL requests are rejected.
    tls: Option<ReloadingTlsConfig>,
    /// Whether to request TLS on the upstream connection.
    internal_tls: bool,
    /// Where cancel requests are forwarded to.
    cancellation_resolver: Arc<CancellationResolver>,
    /// Maps a client connection to its upstream address.
    resolver: Arc<Resolver>,
    metrics: ServerMetrics,
}
impl PgwireBalancer {
#[mz_ore::instrument(level = "debug")]
async fn run<'a, A>(
conn: &'a mut FramedConn<A>,
version: i32,
params: BTreeMap<String, String>,
resolver: &Resolver,
tls_mode: Option<TlsMode>,
internal_tls: bool,
metrics: &ServerMetrics,
) -> Result<(), io::Error>
where
A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
{
if version != VERSION_3 {
return conn
.send(ErrorResponse::fatal(
SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
"server does not support the client's requested protocol version",
))
.await;
}
let Some(user) = params.get("user") else {
return conn
.send(ErrorResponse::fatal(
SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
"user parameter required",
))
.await;
};
if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
return conn.send(err).await;
}
let resolved = match resolver.resolve(conn, user).await {
Ok(v) => v,
Err(err) => {
return conn
.send(ErrorResponse::fatal(
SqlState::INVALID_PASSWORD,
err.to_string(),
))
.await;
}
};
if let Conn::Ssl(ssl_stream) = conn.inner() {
let tenant = resolved.tenant.as_deref().unwrap_or("unknown");
let has_sni = ssl_stream.ssl().servername(NameType::HOST_NAME).is_some();
metrics.tenant_pgwire_sni_count(tenant, has_sni).inc();
}
let _active_guard = resolved
.tenant
.as_ref()
.map(|tenant| metrics.tenant_connections(tenant));
let Ok(mut mz_stream) =
Self::init_stream(conn, resolved.addr, resolved.password, params, internal_tls).await
else {
return Ok(());
};
let mut client_counter = CountingConn::new(conn.inner_mut());
let res = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
if let Some(tenant) = &resolved.tenant {
metrics
.tenant_connections_tx(tenant)
.inc_by(u64::cast_from(client_counter.written));
metrics
.tenant_connections_rx(tenant)
.inc_by(u64::cast_from(client_counter.read));
}
res?;
Ok(())
}
#[mz_ore::instrument(level = "debug")]
async fn init_stream<'a, A>(
conn: &'a mut FramedConn<A>,
envd_addr: SocketAddr,
password: Option<String>,
params: BTreeMap<String, String>,
internal_tls: bool,
) -> Result<Conn<TcpStream>, anyhow::Error>
where
A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
{
let mut mz_stream = TcpStream::connect(envd_addr).await?;
let mut buf = BytesMut::new();
let mut mz_stream = if internal_tls {
FrontendStartupMessage::SslRequest.encode(&mut buf)?;
mz_stream.write_all(&buf).await?;
buf.clear();
let mut maybe_ssl_request_response = [0u8; 1];
let nread =
netio::read_exact_or_eof(&mut mz_stream, &mut maybe_ssl_request_response).await?;
if nread == 1 && maybe_ssl_request_response == [ACCEPT_SSL_ENCRYPTION] {
let mut builder =
SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
builder.set_verify(SslVerifyMode::NONE);
let mut ssl = builder
.build()
.configure()?
.into_ssl(&envd_addr.to_string())?;
ssl.set_connect_state();
Conn::Ssl(SslStream::new(ssl, mz_stream)?)
} else {
Conn::Unencrypted(mz_stream)
}
} else {
Conn::Unencrypted(mz_stream)
};
let startup = FrontendStartupMessage::Startup {
version: VERSION_3,
params,
};
startup.encode(&mut buf)?;
mz_stream.write_all(&buf).await?;
let client_stream = conn.inner_mut();
let mut maybe_auth_frame = [0; 1 + 4 + 4];
let nread = netio::read_exact_or_eof(&mut mz_stream, &mut maybe_auth_frame).await?;
const AUTH_PASSWORD_CLEARTEXT: [u8; 9] = [b'R', 0, 0, 0, 8, 0, 0, 0, 3];
if nread == AUTH_PASSWORD_CLEARTEXT.len()
&& maybe_auth_frame == AUTH_PASSWORD_CLEARTEXT
&& password.is_some()
{
let Some(password) = password else {
unreachable!("verified some above");
};
let password = FrontendMessage::Password { password };
buf.clear();
password.encode(&mut buf)?;
mz_stream.write_all(&buf).await?;
mz_stream.flush().await?;
} else {
client_stream.write_all(&maybe_auth_frame[0..nread]).await?;
}
Ok(mz_stream)
}
}
impl mz_server_core::Server for PgwireBalancer {
const NAME: &'static str = "pgwire_balancer";
fn handle_connection(&self, conn: Connection) -> mz_server_core::ConnectionHandler {
let tls = self.tls.clone();
let internal_tls = self.internal_tls;
let resolver = Arc::clone(&self.resolver);
let inner_metrics = self.metrics.clone();
let outer_metrics = self.metrics.clone();
let cancellation_resolver = Arc::clone(&self.cancellation_resolver);
let conn_uuid = Uuid::new_v4();
let peer_addr = conn.peer_addr();
conn.uuid_handle().set(conn_uuid);
Box::pin(async move {
let active_guard = outer_metrics.active_connections();
let result: Result<(), anyhow::Error> = async move {
let mut conn = Conn::Unencrypted(conn);
loop {
let message = decode_startup(&mut conn).await?;
conn = match message {
None => return Ok(()),
Some(FrontendStartupMessage::Startup {
version,
mut params,
}) => {
let mut conn = FramedConn::new(conn);
let peer_addr = match peer_addr {
Ok(addr) => addr.ip(),
Err(e) => {
error!("Invalid peer_addr {:?}", e);
return Ok(conn
.send(ErrorResponse::fatal(
SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
"invalid peer address",
))
.await?);
}
};
debug!(%conn_uuid, %peer_addr, "starting new pgwire connection in balancer");
let prev =
params.insert(CONN_UUID_KEY.to_string(), conn_uuid.to_string());
if prev.is_some() {
return Ok(conn
.send(ErrorResponse::fatal(
SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
format!("invalid parameter '{CONN_UUID_KEY}'"),
))
.await?);
}
if let Some(_) = params.insert(MZ_FORWARDED_FOR_KEY.to_string(), peer_addr.to_string().clone()) {
return Ok(conn
.send(ErrorResponse::fatal(
SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
format!("invalid parameter '{MZ_FORWARDED_FOR_KEY}'"),
))
.await?);
};
Self::run(
&mut conn,
version,
params,
&resolver,
tls.map(|tls| tls.mode),
internal_tls,
&inner_metrics,
)
.await?;
conn.flush().await?;
return Ok(());
}
Some(FrontendStartupMessage::CancelRequest {
conn_id,
secret_key,
}) => {
spawn(|| "cancel request", async move {
cancel_request(conn_id, secret_key, &cancellation_resolver).await;
});
return Ok(());
}
Some(FrontendStartupMessage::SslRequest) => match (conn, &tls) {
(Conn::Unencrypted(mut conn), Some(tls)) => {
conn.write_all(&[ACCEPT_SSL_ENCRYPTION]).await?;
let mut ssl_stream =
SslStream::new(Ssl::new(&tls.context.get())?, conn)?;
if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
let _ = ssl_stream.get_mut().shutdown().await;
return Err(e.into());
}
Conn::Ssl(ssl_stream)
}
(mut conn, _) => {
conn.write_all(&[REJECT_ENCRYPTION]).await?;
conn
}
},
Some(FrontendStartupMessage::GssEncRequest) => {
conn.write_all(&[REJECT_ENCRYPTION]).await?;
conn
}
}
}
}
.await;
drop(active_guard);
outer_metrics.connection_status(result.is_ok()).inc();
Ok(())
})
}
}
/// Stream wrapper that counts the bytes successfully read from and written to `inner`.
struct CountingConn<C> {
    inner: C,
    /// Total bytes read from `inner`.
    read: usize,
    /// Total bytes written to `inner`.
    written: usize,
}

impl<C> CountingConn<C> {
    /// Wraps `inner` with both counters at zero.
    fn new(inner: C) -> Self {
        Self {
            inner,
            read: 0,
            written: 0,
        }
    }
}

impl<C> AsyncRead for CountingConn<C>
where
    C: AsyncRead + Unpin,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = Pin::into_inner(self);
        let filled_before = buf.filled().len();
        let outcome = Pin::new(&mut this.inner).poll_read(cx, buf);
        let newly_filled = buf.filled().len() - filled_before;
        // Only successful reads count.
        if matches!(outcome, std::task::Poll::Ready(Ok(()))) {
            this.read += newly_filled;
        }
        outcome
    }
}

impl<C> AsyncWrite for CountingConn<C>
where
    C: AsyncWrite + Unpin,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        let this = Pin::into_inner(self);
        let outcome = Pin::new(&mut this.inner).poll_write(cx, buf);
        if let std::task::Poll::Ready(Ok(accepted)) = &outcome {
            this.written += *accepted;
        }
        outcome
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        Pin::new(&mut Pin::into_inner(self).inner).poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        Pin::new(&mut Pin::into_inner(self).inner).poll_shutdown(cx)
    }
}
/// Forwards a pgwire cancel request to every upstream that might own the connection.
///
/// Candidate addresses come from `cancellation_resolver`: either the file named
/// `conn_id_org_uuid(conn_id)` inside the configured directory, or a static
/// string. Each trimmed, non-empty line is resolved with DNS. The encoded
/// `CancelRequest` is then sent to every resolved address, each in its own
/// background task. All failures are logged; nothing is returned.
async fn cancel_request(
    conn_id: u32,
    secret_key: u32,
    cancellation_resolver: &CancellationResolver,
) {
    let suffix = conn_id_org_uuid(conn_id);
    let contents = match cancellation_resolver {
        CancellationResolver::Directory(dir) => {
            let path = dir.join(&suffix);
            // Async file I/O: a blocking `std::fs` read here would stall the
            // executor thread this task runs on.
            match tokio::fs::read_to_string(&path).await {
                Ok(contents) => contents,
                Err(err) => {
                    error!("could not read cancel file {path:?}: {err}");
                    return;
                }
            }
        }
        CancellationResolver::Static(addr) => addr.to_owned(),
    };
    let mut all_ips = Vec::new();
    for addr in contents.lines() {
        let addr = addr.trim();
        if addr.is_empty() {
            continue;
        }
        match tokio::net::lookup_host(addr).await {
            Ok(ips) => all_ips.extend(ips),
            Err(err) => {
                error!("{addr} failed resolution: {err}");
            }
        }
    }
    let mut buf = BytesMut::with_capacity(16);
    let msg = FrontendStartupMessage::CancelRequest {
        conn_id,
        secret_key,
    };
    msg.encode(&mut buf).expect("must encode");
    let buf = buf.freeze();
    for ip in all_ips {
        debug!("cancelling {suffix} to {ip}");
        let buf = buf.clone();
        spawn(|| "cancel request for ip", async move {
            let send = async {
                let mut stream = TcpStream::connect(&ip).await?;
                stream.write_all(&buf).await?;
                stream.shutdown().await?;
                Ok::<_, io::Error>(())
            };
            if let Err(err) = send.await {
                error!("error mirroring cancel to {ip}: {err}");
            }
        });
    }
}
/// HTTPS server that terminates client TLS (when configured), chooses the
/// upstream from the SNI server name, and proxies bytes to it.
struct HttpsBalancer {
    /// DNS resolver used to look up the upstream's CNAME for tenant attribution.
    resolver: Arc<StubResolver>,
    /// Client-facing TLS context; when `None`, the raw connection is proxied as-is.
    tls: Option<ReloadingSslContext>,
    /// Upstream host template; `{}` is replaced with the first SNI label.
    resolve_template: Arc<str>,
    /// Upstream port.
    port: u16,
    metrics: Arc<ServerMetrics>,
    /// Dynamic configs, read per connection (e.g. PROXY header injection).
    configs: ConfigSet,
    /// Whether to use TLS on the upstream connection.
    internal_tls: bool,
}
impl HttpsBalancer {
    /// Resolves the upstream for an HTTPS connection.
    ///
    /// `{}` in `resolve_template` is replaced with `servername` when one is
    /// given; otherwise the template is used verbatim. The tenant is derived
    /// from the host's CNAME, and the address from a DNS lookup of `host:port`.
    ///
    /// # Errors
    ///
    /// Fails if the address lookup fails or yields no addresses.
    async fn resolve(
        resolver: &StubResolver,
        resolve_template: &str,
        port: u16,
        servername: Option<&str>,
    ) -> Result<ResolvedAddr, anyhow::Error> {
        let addr = servername.map_or_else(
            || resolve_template.to_string(),
            |name| resolve_template.replace("{}", name),
        );
        debug!("https address: {addr}");
        let tenant = Self::tenant(resolver, &addr).await;
        let envd_addr = lookup(&format!("{addr}:{port}")).await?;
        Ok(ResolvedAddr {
            addr: envd_addr,
            password: None,
            tenant,
        })
    }

    /// Looks up the CNAME of `addr` and extracts a tenant ID from the first
    /// CNAME record in the answer. Any parse or DNS failure yields `None`.
    async fn tenant(resolver: &StubResolver, addr: &str) -> Option<String> {
        let dname = Dname::<Vec<_>>::from_str(addr).ok()?;
        let response = resolver.query((dname, Rtype::Cname)).await.ok()?;
        let answer = response.answer().ok()?;
        let cname = answer
            .limit_to::<AllRecordData<_, _>>()
            .filter_map(|record| record.ok())
            .find(|record| record.rtype() == Rtype::Cname)?
            .data()
            .to_string();
        debug!("cname: {cname}");
        Self::extract_tenant_from_cname(&cname)
    }

    /// Extracts the tenant UUID from a CNAME whose second label has the form
    /// `<prefix>-<uuid>-<suffix>`, e.g. `environmentd.environment-<uuid>-0.svc...`.
    ///
    /// The UUID is returned in hyphenated lowercase form. A non-UUID candidate
    /// is logged and yields `None`.
    fn extract_tenant_from_cname(cname: &str) -> Option<String> {
        let namespace = cname.split('.').nth(1)?;
        let (_, after_prefix) = namespace.split_once('-')?;
        let (tenant, _) = after_prefix.rsplit_once('-')?;
        match Uuid::parse_str(tenant) {
            Ok(uuid) => Some(uuid.to_string()),
            Err(_) => {
                error!("cname tenant not a uuid: {tenant}");
                None
            }
        }
    }
}
impl mz_server_core::Server for HttpsBalancer {
    const NAME: &'static str = "https_balancer";

    /// Handles one HTTPS connection.
    ///
    /// The steps are:
    /// - Terminate client TLS when configured, keeping the first label of the
    ///   SNI server name.
    /// - Resolve the upstream from that name.
    /// - Optionally write a PROXY protocol v2 header with the client's address.
    /// - Optionally wrap the upstream in TLS (certificate verification disabled).
    /// - Copy bytes in both directions.
    ///
    /// Errors are logged at debug level and counted in the `connection_status`
    /// metric; the handler itself always resolves to `Ok`.
    fn handle_connection(&self, conn: Connection) -> mz_server_core::ConnectionHandler {
        let tls_context = self.tls.clone();
        let internal_tls = self.internal_tls;
        let resolver = Arc::clone(&self.resolver);
        let resolve_template = Arc::clone(&self.resolve_template);
        let port = self.port;
        let inner_metrics = Arc::clone(&self.metrics);
        let outer_metrics = Arc::clone(&self.metrics);
        let peer_addr = conn.peer_addr();
        // Read once per connection, so config changes only affect new connections.
        let inject_proxy_headers = INJECT_PROXY_PROTOCOL_HEADER_HTTP.get(&self.configs);
        Box::pin(async move {
            let active_guard = outer_metrics.active_connections();
            let result: Result<_, anyhow::Error> = Box::pin(async move {
                let peer_addr = peer_addr.context("fetching peer addr")?;
                let (client_stream, servername): (Box<dyn ClientStream>, Option<String>) =
                    match tls_context {
                        Some(tls_context) => {
                            let mut ssl_stream =
                                SslStream::new(Ssl::new(&tls_context.get())?, conn)?;
                            if let Err(e) = Pin::new(&mut ssl_stream).accept().await {
                                let _ = ssl_stream.get_mut().shutdown().await;
                                return Err(e.into());
                            }
                            // Keep only the first DNS label of the SNI server name.
                            let servername: Option<String> =
                                ssl_stream.ssl().servername(NameType::HOST_NAME).map(|sn| {
                                    match sn.split_once('.') {
                                        Some((left, _right)) => left,
                                        None => sn,
                                    }
                                    .into()
                                });
                            debug!("servername: {servername:?}");
                            (Box::new(ssl_stream), servername)
                        }
                        None => (Box::new(conn), None),
                    };
                let resolved =
                    Self::resolve(&resolver, &resolve_template, port, servername.as_deref())
                        .await?;
                let inner_active_guard = resolved
                    .tenant
                    .as_ref()
                    .map(|tenant| inner_metrics.tenant_connections(tenant));
                let mut mz_stream = TcpStream::connect(resolved.addr).await?;
                if inject_proxy_headers {
                    // The PROXY header precedes any upstream TLS handshake.
                    let addrs = ProxiedAddress::stream(peer_addr, resolved.addr);
                    let header = ProxyHeader::with_address(addrs);
                    let mut buf = [0u8; 1024];
                    let len = header.encode_to_slice_v2(&mut buf)?;
                    mz_stream.write_all(&buf[..len]).await?;
                }
                let mut mz_stream = if internal_tls {
                    let mut builder =
                        SslConnector::builder(SslMethod::tls()).expect("Error creating builder.");
                    builder.set_verify(SslVerifyMode::NONE);
                    let mut ssl = builder
                        .build()
                        .configure()?
                        .into_ssl(&resolved.addr.to_string())?;
                    ssl.set_connect_state();
                    Conn::Ssl(SslStream::new(ssl, mz_stream)?)
                } else {
                    Conn::Unencrypted(mz_stream)
                };
                let mut client_counter = CountingConn::new(client_stream);
                // NOTE(review): unlike the pgwire path, copy errors are discarded
                // here, so the connection is counted as a success even if the copy
                // failed. Confirm this is intended.
                let _ = tokio::io::copy_bidirectional(&mut client_counter, &mut mz_stream).await;
                if let Some(tenant) = &resolved.tenant {
                    inner_metrics
                        .tenant_connections_tx(tenant)
                        .inc_by(u64::cast_from(client_counter.written));
                    inner_metrics
                        .tenant_connections_rx(tenant)
                        .inc_by(u64::cast_from(client_counter.read));
                }
                drop(inner_active_guard);
                Ok(())
            })
            .await;
            drop(active_guard);
            outer_metrics.connection_status(result.is_ok()).inc();
            if let Err(e) = result {
                debug!("connection error: {e}");
            }
            Ok(())
        })
    }
}
/// Object-safe bundle of the bounds needed for a client stream, so TLS and plain
/// TCP connections can both be held as `Box<dyn ClientStream>`.
trait ClientStream: AsyncRead + AsyncWrite + Unpin + Send {}
// Blanket impl: every type meeting the bounds is a `ClientStream`.
impl<T: AsyncRead + AsyncWrite + Unpin + Send> ClientStream for T {}
/// Strategy for mapping a pgwire connection to its upstream address.
#[derive(Debug)]
pub enum Resolver {
    /// Always connect to this `host:port`, resolved via DNS.
    Static(String),
    /// Authenticate the client's password with Frontegg and derive the upstream
    /// address from the tenant ID.
    Frontegg(FronteggResolver),
}
impl Resolver {
    /// Resolves the upstream address for a pgwire connection by `user`.
    ///
    /// For [`Resolver::Static`], the fixed address is looked up; no password or
    /// tenant is known. For [`Resolver::Frontegg`], the client is asked for a
    /// cleartext password over `conn`, `user` is authenticated with it, and the
    /// tenant ID is substituted into the address template. The password and
    /// tenant are returned alongside the address.
    ///
    /// # Errors
    ///
    /// Fails if the client does not reply with a password, authentication fails
    /// (reported as "invalid password"), or DNS resolution fails.
    async fn resolve<A>(
        &self,
        conn: &mut FramedConn<A>,
        user: &str,
    ) -> Result<ResolvedAddr, anyhow::Error>
    where
        A: AsyncRead + AsyncWrite + Unpin,
    {
        let (auth, addr_template) = match self {
            Resolver::Static(addr) => {
                let addr = lookup(addr).await?;
                return Ok(ResolvedAddr {
                    addr,
                    password: None,
                    tenant: None,
                });
            }
            Resolver::Frontegg(FronteggResolver {
                auth,
                addr_template,
            }) => (auth, addr_template),
        };

        conn.send(BackendMessage::AuthenticationCleartextPassword)
            .await?;
        conn.flush().await?;
        let Some(FrontendMessage::Password { password }) = conn.recv().await? else {
            anyhow::bail!("expected Password message");
        };

        let auth_session = auth
            .authenticate(user, &password)
            .await
            .map_err(|e| {
                warn!("pgwire connection failed authentication: {}", e);
                anyhow::anyhow!("invalid password")
            })?;

        let tenant = auth_session.tenant_id().to_string();
        let addr = lookup(&addr_template.replace("{}", &tenant)).await?;
        Ok(ResolvedAddr {
            addr,
            password: Some(password),
            tenant: Some(tenant),
        })
    }
}
async fn lookup(name: &str) -> Result<SocketAddr, anyhow::Error> {
let mut addrs = tokio::net::lookup_host(name).await?;
match addrs.next() {
Some(addr) => Ok(addr),
None => {
error!("{name} did not resolve to any addresses");
anyhow::bail!("internal error")
}
}
}
/// Configuration for [`Resolver::Frontegg`].
#[derive(Debug)]
pub struct FronteggResolver {
    /// Authenticator used to check the client's pgwire password.
    pub auth: FronteggAuthentication,
    /// Upstream `host:port` template; `{}` is replaced with the authenticated tenant ID.
    pub addr_template: String,
}
/// Result of resolving a client connection to an upstream.
#[derive(Debug)]
struct ResolvedAddr {
    /// Upstream socket address to connect to.
    addr: SocketAddr,
    /// Password to replay if the upstream requests cleartext auth, when known.
    password: Option<String>,
    /// Tenant ID used for per-tenant metrics, when known.
    tenant: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `HttpsBalancer::extract_tenant_from_cname`.
    #[mz_ore::test]
    fn test_tenant() {
        const CASES: &[(&str, Option<&str>)] = &[
            // Empty input has no namespace label.
            ("", None),
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Label names other than "environment" are accepted.
            (
                "service.something-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.ssvvcc.cloister.faraway",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Un-hyphenated UUIDs are normalized.
            (
                "environmentd.environment-58cd23ffa4d74bd0ad85a6ff29cc86c3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-1234.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Uppercase UUIDs are lowercased.
            (
                "environmentd.environment-58CD23FF-A4D7-4BD0-AD85-A6FF29CC86C3-0.svc.cluster.local",
                Some("58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3"),
            ),
            // Missing trailing suffix.
            (
                "environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3.svc.cluster.local",
                None,
            ),
            // Missing service label.
            (
                "environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
            // Truncated UUID.
            (
                "environmentd.environment-8cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local",
                None,
            ),
        ];
        for &(name, expect) in CASES {
            let cname = HttpsBalancer::extract_tenant_from_cname(name);
            assert_eq!(
                cname.as_deref(),
                expect,
                "{name} got {cname:?} expected {expect:?}"
            );
        }
    }
}