balancerd/
balancerd.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
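//! Entry point for the balancer service (`balancerd`).
//!
//! It accepts pgwire (SQL) and HTTPS connections and proxies them to the address chosen by the
//! configured resolver. The listen addresses are not fixed: they are supplied on the command
//! line via `--pgwire-listen-addr`, `--https-listen-addr`, and `--internal-http-listen-addr`.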
14
15use std::error::Error;
16use std::net::SocketAddr;
17use std::path::PathBuf;
18use std::time::Duration;
19
20use anyhow::Context;
21use jsonwebtoken::DecodingKey;
22use mz_balancerd::{
23    BUILD_INFO, BalancerConfig, BalancerService, CancellationResolver, FronteggResolver, Resolver,
24};
25use mz_frontegg_auth::{
26    Authenticator, AuthenticatorConfig, DEFAULT_REFRESH_DROP_FACTOR,
27    DEFAULT_REFRESH_DROP_LRU_CACHE_SIZE,
28};
29use mz_orchestrator_tracing::{StaticTracingConfig, TracingCliArgs};
30use mz_ore::cli::{self, CliConfig};
31use mz_ore::error::ErrorExt;
32use mz_ore::metrics::MetricsRegistry;
33use mz_ore::tracing::TracingHandle;
34use mz_server_core::TlsCliArgs;
35use tracing::{Instrument, info_span, warn};
36
// Top-level command-line arguments of the binary, parsed with `clap` in `main`.
//
// NOTE: the comments added here are deliberately plain `//` comments. `///` doc comments on
// clap-derived items are picked up as help text, so using them would change `--help` output.
#[derive(Debug, clap::Parser)]
#[clap(about = "Balancer service", long_about = None)]
struct Args {
    // The subcommand to execute; `main` dispatches on it.
    #[clap(subcommand)]
    command: Command,

    // Tracing options, flattened into the top-level argument list. `main` uses them to
    // configure tracing before the subcommand runs.
    #[clap(flatten)]
    tracing: TracingCliArgs,
}
46
// Subcommands accepted by the binary.
//
// NOTE: plain `//` comments are used on purpose; a `///` comment on a variant would become the
// subcommand's help text.
#[derive(Debug, clap::Subcommand)]
enum Command {
    // Runs the balancer service: `main` passes the parsed `ServiceArgs` to `run`.
    Service(ServiceArgs),
}
51
// Arguments of the `service` subcommand; consumed by `run`, which turns them into a
// `BalancerConfig`.
//
// NOTE: the existing `///` comments double as clap help text and are left untouched. Reviewer
// notes are added as plain `//` comments so the generated `--help` output does not change.
#[derive(Debug, clap::Parser)]
pub struct ServiceArgs {
    // Forwarded unchanged to `BalancerConfig::new`; presumably the address on which pgwire
    // (SQL) connections are accepted — confirm against `BalancerConfig`.
    #[clap(long, value_name = "HOST:PORT")]
    pgwire_listen_addr: SocketAddr,
    // Forwarded unchanged to `BalancerConfig::new`; presumably the address on which HTTPS
    // connections are accepted — confirm against `BalancerConfig`.
    #[clap(long, value_name = "HOST:PORT")]
    https_listen_addr: SocketAddr,
    // TLS options, flattened into this argument list; `run` converts them with
    // `into_config()` and fails startup if that conversion fails.
    #[clap(flatten)]
    tls: TlsCliArgs,
    // Forwarded unchanged to `BalancerConfig::new`; presumably the listen address of the
    // internal HTTP server — confirm against `BalancerConfig`.
    #[clap(long, value_name = "HOST:PORT")]
    internal_http_listen_addr: SocketAddr,

    /// Whether to initiate internal connections over TLS
    #[clap(long)]
    internal_tls: bool,
    /// Static pgwire resolver address to use for local testing.
    // `run` resolves this address once at startup purely as a typo check, then uses it for
    // both the pgwire resolver and the cancellation resolver.
    #[clap(
        long,
        value_name = "HOST:PORT",
        conflicts_with = "frontegg_resolver_template"
    )]
    static_resolver_addr: Option<String>,
    /// Frontegg resolver address template. `{}` is replaced with the user's frontegg tenant id to
    /// get a DNS address. The first IP that address resolves to is the proxy destinations.
    // `run` requires exactly one of this flag and `--static-resolver-addr`.
    #[clap(long,
        value_name = "HOST.{}.NAME:PORT",
        requires_all = &["frontegg_api_token_url", "frontegg_admin_role"],
    )]
    frontegg_resolver_template: Option<String>,
    /// HTTPS resolver address template. `{}` is replaced with the first subdomain of the HTTPS SNI
    /// host address to get a DNS address. The first IP that address resolves to is the proxy
    /// destinations.
    #[clap(long, value_name = "HOST.{}.NAME:PORT")]
    https_resolver_template: String,
    /// Cancellation resolver configmap directory. The org id part of the incoming connection id
    /// (the 12 bits after (and excluding) the first bit) converted to a 3-char UUID string is
    /// appended to this to make a file path. That file is read, and every newline-delimited line
    /// there is DNS resolved, and all returned IPs get a mirrored cancellation request. The lines
    /// in the file must be of the form `host:port`.
    // Only consulted by `run` when the Frontegg resolver is selected; `run` fails startup if
    // the path is not a directory. With a static resolver this value is not read.
    #[clap(
        long,
        value_name = "/path/to/configmap/dir/",
        required_unless_present = "static_resolver_addr"
    )]
    cancellation_resolver_dir: Option<PathBuf>,

    /// JWK used to validate JWTs during Frontegg authentication as a PEM public
    /// key. Can optionally be base64 encoded with the URL-safe alphabet.
    // NOTE(review): `run` passes this string directly to `DecodingKey::from_rsa_pem`; no
    // base64 decoding step is visible in this file — confirm the help text above is accurate.
    // clap does not enforce it, but `run` requires exactly one of this flag and
    // `--frontegg-jwk-file` whenever the Frontegg resolver is selected.
    #[clap(long, env = "FRONTEGG_JWK", requires = "frontegg_resolver_template")]
    frontegg_jwk: Option<String>,
    /// Path of JWK used to validate JWTs during Frontegg authentication as a PEM public key.
    // Read in full by `run` at startup; a read failure aborts startup with the path in the
    // error message.
    #[clap(
        long,
        env = "FRONTEGG_JWK_FILE",
        requires = "frontegg_resolver_template"
    )]
    frontegg_jwk_file: Option<PathBuf>,
    /// The full URL (including path) to the Frontegg api-token endpoint.
    // Guaranteed present by clap (`requires_all` on `frontegg_resolver_template`) whenever the
    // Frontegg resolver is selected, which is why `run` may `expect` it.
    #[clap(
        long,
        env = "FRONTEGG_API_TOKEN_URL",
        requires = "frontegg_resolver_template"
    )]
    frontegg_api_token_url: Option<String>,
    /// The name of the admin role in Frontegg.
    // Guaranteed present by clap (`requires_all` on `frontegg_resolver_template`) whenever the
    // Frontegg resolver is selected, which is why `run` may `expect` it.
    #[clap(
        long,
        env = "FRONTEGG_ADMIN_ROLE",
        requires = "frontegg_resolver_template"
    )]
    frontegg_admin_role: Option<String>,

    /// An SDK key for LaunchDarkly.
    ///
    /// Setting this will enable synchronization of LaunchDarkly features.
    #[clap(long, env = "LAUNCHDARKLY_SDK_KEY")]
    launchdarkly_sdk_key: Option<String>,
    /// The duration at which the LaunchDarkly synchronization times out during startup.
    // Parsed with `humantime::parse_duration`, so values such as the default `30s` are accepted.
    #[clap(
        long,
        env = "CONFIG_SYNC_TIMEOUT",
        value_parser = humantime::parse_duration,
        default_value = "30s"
    )]
    config_sync_timeout: Duration,
    /// The interval in seconds at which to synchronize LaunchDarkly values.
    ///
    /// If this is not explicitly set, the loop that synchronizes LaunchDarkly will not run _even if
    /// [`Self::launchdarkly_sdk_key`] is present_ (however one initial sync is always run).
    // NOTE(review): the value is parsed with `humantime::parse_duration`, not as a bare number
    // of seconds, despite the wording of the help text above.
    #[clap(
        long,
        env = "CONFIG_SYNC_LOOP_INTERVAL",
        value_parser = humantime::parse_duration,
    )]
    config_sync_loop_interval: Option<Duration>,

    /// The cloud provider where the balancer is running.
    #[clap(long, env = "CLOUD_PROVIDER")]
    cloud_provider: Option<String>,
    /// The cloud provider region where the balancer is running.
    #[clap(long, env = "CLOUD_PROVIDER_REGION")]
    cloud_provider_region: Option<String>,
    /// Set startup defaults for dynconfig
    // Comma-delimited list of `KEY=value` pairs, each parsed by `parse_key_val`. When the flag
    // is absent `run` passes an empty list to `BalancerConfig::new`.
    #[clap(long, value_parser = parse_key_val::<String, String>, value_delimiter = ',')]
    default_config: Option<Vec<(String, String)>>,
}
157
158fn main() {
159    let args: Args = cli::parse_args(CliConfig::default());
160
161    // Mirror the tokio Runtime configuration in our production binaries.
162    let ncpus_useful = usize::max(1, std::cmp::min(num_cpus::get(), num_cpus::get_physical()));
163    let runtime = tokio::runtime::Builder::new_multi_thread()
164        .worker_threads(ncpus_useful)
165        .enable_all()
166        .build()
167        .expect("Failed building the Runtime");
168
169    let metrics_registry = MetricsRegistry::new();
170    let (tracing_handle, _tracing_guard) = runtime
171        .block_on(args.tracing.configure_tracing(
172            StaticTracingConfig {
173                service_name: "balancerd",
174                build_info: BUILD_INFO,
175            },
176            metrics_registry.clone(),
177        ))
178        .expect("failed to init tracing");
179
180    runtime.block_on(mz_alloc::register_metrics_into(&metrics_registry));
181
182    let root_span = info_span!("balancer");
183    let res = match args.command {
184        Command::Service(args) => runtime.block_on(run(args, tracing_handle).instrument(root_span)),
185    };
186
187    if let Err(err) = res {
188        panic!("balancer: fatal: {}", err.display_with_causes());
189    }
190    drop(_tracing_guard);
191}
192
193pub async fn run(args: ServiceArgs, tracing_handle: TracingHandle) -> Result<(), anyhow::Error> {
194    let metrics_registry = MetricsRegistry::new();
195    let (resolver, cancellation_resolver) = match (
196        args.static_resolver_addr,
197        args.frontegg_resolver_template,
198    ) {
199        (None, Some(addr_template)) => {
200            let auth = Authenticator::new(
201                AuthenticatorConfig {
202                    admin_api_token_url: args.frontegg_api_token_url.expect("clap enforced"),
203                    decoding_key: match (args.frontegg_jwk, args.frontegg_jwk_file) {
204                        (None, Some(path)) => {
205                            let jwk = std::fs::read(&path).with_context(|| {
206                                format!("read {path:?} for --frontegg-jwk-file")
207                            })?;
208                            DecodingKey::from_rsa_pem(&jwk)?
209                        }
210                        (Some(jwk), None) => DecodingKey::from_rsa_pem(jwk.as_bytes())?,
211                        _ => anyhow::bail!(
212                            "exactly one of --frontegg-jwk or --frontegg-jwk-file must be present"
213                        ),
214                    },
215                    tenant_id: None,
216                    now: mz_ore::now::SYSTEM_TIME.clone(),
217                    admin_role: args.frontegg_admin_role.expect("clap enforced"),
218                    refresh_drop_lru_size: DEFAULT_REFRESH_DROP_LRU_CACHE_SIZE,
219                    refresh_drop_factor: DEFAULT_REFRESH_DROP_FACTOR,
220                },
221                mz_frontegg_auth::Client::environmentd_default(),
222                &metrics_registry,
223            );
224            let cancellation_resolver_dir = args
225                .cancellation_resolver_dir
226                .expect("required unless static resolver present");
227            if !cancellation_resolver_dir.is_dir() {
228                anyhow::bail!("{cancellation_resolver_dir:?} is not a directory");
229            }
230            (
231                Resolver::Frontegg(FronteggResolver {
232                    auth,
233                    addr_template,
234                }),
235                CancellationResolver::Directory(cancellation_resolver_dir),
236            )
237        }
238        (Some(addr), None) => {
239            // As a typo-check, verify that the passed address resolves to at least one IP. This
240            // result isn't recorded anywhere: we re-resolve on each request in case DNS changes.
241            // Here only to cause startup to crash if mis-typed.
242            let mut addrs = tokio::net::lookup_host(&addr)
243                .await
244                .unwrap_or_else(|_| panic!("could not resolve {addr}"));
245            let Some(_resolved) = addrs.next() else {
246                panic!("{addr} did not resolve to any addresses");
247            };
248            drop(addrs);
249
250            (
251                Resolver::Static(addr.clone()),
252                CancellationResolver::Static(addr),
253            )
254        }
255        _ => anyhow::bail!(
256            "exactly one of --static-resolver-addr or --frontegg-resolver-template must be present"
257        ),
258    };
259    let config = BalancerConfig::new(
260        &BUILD_INFO,
261        args.internal_http_listen_addr,
262        args.pgwire_listen_addr,
263        args.https_listen_addr,
264        cancellation_resolver,
265        resolver,
266        args.https_resolver_template,
267        args.tls.into_config()?,
268        args.internal_tls,
269        metrics_registry,
270        mz_server_core::default_cert_reload_ticker(),
271        args.launchdarkly_sdk_key,
272        args.config_sync_timeout,
273        args.config_sync_loop_interval,
274        args.cloud_provider,
275        args.cloud_provider_region,
276        tracing_handle,
277        args.default_config.unwrap_or(vec![]),
278    );
279    let service = BalancerService::new(config).await?;
280    service.serve().await?;
281    warn!("balancer service exited");
282    Ok(())
283}
284
/// Splits `s` at its first `=` and parses the text before it as the key and the text after it as
/// the value.
///
/// Any further `=` characters belong to the value. An error is returned when `s` contains no `=`
/// or when either half fails to parse into its target type.
fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn Error + Send + Sync + 'static>>
where
    T: std::str::FromStr,
    T::Err: Error + Send + Sync + 'static,
    U: std::str::FromStr,
    U::Err: Error + Send + Sync + 'static,
{
    let Some((raw_key, raw_value)) = s.split_once('=') else {
        return Err(format!("invalid KEY=value: no `=` found in `{s}`").into());
    };
    // The key is parsed first so that its error wins when both halves are invalid.
    let key: T = raw_key.parse()?;
    let value: U = raw_value.parse()?;
    Ok((key, value))
}