balancerd/
balancerd.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
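//! Entry point for `balancerd`, the balancer service binary.
//!
//! The `service` subcommand takes the pgwire, HTTPS, and internal HTTP listen
//! addresses as command-line flags (`--pgwire-listen-addr`, `--https-listen-addr`,
//! `--internal-http-listen-addr`); no ports are hard-coded in this file.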
14
15use std::error::Error;
16use std::net::SocketAddr;
17use std::path::PathBuf;
18use std::time::Duration;
19
20use anyhow::Context;
21use domain::resolv::StubResolver;
22use jsonwebtoken::DecodingKey;
23use mz_balancerd::{
24    BUILD_INFO, BalancerConfig, BalancerService, CancellationResolver, FronteggResolver, Resolver,
25    SniResolver,
26};
27use mz_frontegg_auth::{
28    Authenticator, AuthenticatorConfig, DEFAULT_REFRESH_DROP_FACTOR,
29    DEFAULT_REFRESH_DROP_LRU_CACHE_SIZE,
30};
31use mz_orchestrator_tracing::{StaticTracingConfig, TracingCliArgs};
32use mz_ore::cli::{self, CliConfig};
33use mz_ore::error::ErrorExt;
34use mz_ore::metrics::MetricsRegistry;
35use mz_ore::tracing::TracingHandle;
36use mz_server_core::TlsCliArgs;
37use tracing::{Instrument, info_span, warn};
38
39#[derive(Debug, clap::Parser)]
40#[clap(about = "Balancer service", long_about = None)]
41struct Args {
42    #[clap(subcommand)]
43    command: Command,
44
45    #[clap(flatten)]
46    tracing: TracingCliArgs,
47}
48
49#[derive(Debug, clap::Subcommand)]
50enum Command {
51    Service(ServiceArgs),
52}
53
54#[derive(Debug, clap::Parser)]
55pub struct ServiceArgs {
56    #[clap(long, value_name = "HOST:PORT")]
57    pgwire_listen_addr: SocketAddr,
58    #[clap(long, value_name = "HOST:PORT")]
59    https_listen_addr: SocketAddr,
60    #[clap(flatten)]
61    tls: TlsCliArgs,
62    #[clap(long, value_name = "HOST:PORT")]
63    internal_http_listen_addr: SocketAddr,
64
65    /// Whether to initiate internal connections over TLS
66    #[clap(long)]
67    internal_tls: bool,
68    /// Static pgwire resolver address to use for local testing.
69    #[clap(
70        long,
71        value_name = "HOST:PORT",
72        conflicts_with = "frontegg_resolver_template"
73    )]
74    static_resolver_addr: Option<String>,
75    /// Frontegg resolver address template. `{}` is replaced with the user's frontegg tenant id to
76    /// get a DNS address. The first IP that address resolves to is the proxy destinations.
77    #[clap(long,
78        value_name = "HOST.{}.NAME:PORT",
79        requires_all = &["frontegg_api_token_url", "frontegg_admin_role"],
80    )]
81    frontegg_resolver_template: Option<String>,
82    /// HTTPS resolver address template. `{}` is replaced with the first subdomain of the HTTPS SNI
83    /// host address to get a DNS address. The first IP that address resolves to is the proxy
84    /// destinations.
85    #[clap(
86        long,
87        value_name = "HOST.{}.NAME:PORT",
88        visible_alias = "https-resolver-template"
89    )]
90    https_sni_resolver_template: String,
91    /// PGWIRE sni resolver address template. `{}` is replaced with the first subdomain of the PGWIRE SNI
92    /// host address to get a DNS address. The first IP that address resolves to is the proxy
93    /// destinations.
94    #[clap(long, value_name = "HOST.{}.NAME:PORT")]
95    pgwire_sni_resolver_template: Option<String>,
96    /// Cancellation resolver configmap directory. The org id part of the incoming connection id
97    /// (the 12 bits after (and excluding) the first bit) converted to a 3-char UUID string is
98    /// appended to this to make a file path. That file is read, and every newline-delimited line
99    /// there is DNS resolved, and all returned IPs get a mirrored cancellation request. The lines
100    /// in the file must be of the form `host:port`.
101    #[clap(
102        long,
103        value_name = "/path/to/configmap/dir/",
104        required_unless_present = "static_resolver_addr"
105    )]
106    cancellation_resolver_dir: Option<PathBuf>,
107
108    /// JWK used to validate JWTs during Frontegg authentication as a PEM public
109    /// key. Can optionally be base64 encoded with the URL-safe alphabet.
110    #[clap(long, env = "FRONTEGG_JWK", requires = "frontegg_resolver_template")]
111    frontegg_jwk: Option<String>,
112    /// Path of JWK used to validate JWTs during Frontegg authentication as a PEM public key.
113    #[clap(
114        long,
115        env = "FRONTEGG_JWK_FILE",
116        requires = "frontegg_resolver_template"
117    )]
118    frontegg_jwk_file: Option<PathBuf>,
119    /// The full URL (including path) to the Frontegg api-token endpoint.
120    #[clap(
121        long,
122        env = "FRONTEGG_API_TOKEN_URL",
123        requires = "frontegg_resolver_template"
124    )]
125    frontegg_api_token_url: Option<String>,
126    /// The name of the admin role in Frontegg.
127    #[clap(
128        long,
129        env = "FRONTEGG_ADMIN_ROLE",
130        requires = "frontegg_resolver_template"
131    )]
132    frontegg_admin_role: Option<String>,
133    /// An SDK key for LaunchDarkly.
134    ///
135    /// Setting this will enable synchronization of LaunchDarkly features.
136    #[clap(long, env = "LAUNCHDARKLY_SDK_KEY")]
137    launchdarkly_sdk_key: Option<String>,
138    /// Path to a JSON file containing system parameter values.
139    /// If specified, this file will be used instead of LaunchDarkly for configuration.
140    #[clap(long, env = "CONFIG_SYNC_FILE_PATH")]
141    config_sync_file_path: Option<PathBuf>,
142    /// The duration at which the LaunchDarkly synchronization times out during startup.
143    #[clap(
144        long,
145        env = "CONFIG_SYNC_TIMEOUT",
146        value_parser = humantime::parse_duration,
147        default_value = "30s"
148    )]
149    config_sync_timeout: Duration,
150    /// The interval in seconds at which to synchronize LaunchDarkly values.
151    ///
152    /// If this is not explicitly set, the loop that synchronizes LaunchDarkly will not run _even if
153    /// [`Self::launchdarkly_sdk_key`] is present_ (however one initial sync is always run).
154    #[clap(
155        long,
156        env = "CONFIG_SYNC_LOOP_INTERVAL",
157        value_parser = humantime::parse_duration,
158    )]
159    config_sync_loop_interval: Option<Duration>,
160
161    /// The cloud provider where the balancer is running.
162    #[clap(long, env = "CLOUD_PROVIDER")]
163    cloud_provider: Option<String>,
164    /// The cloud provider region where the balancer is running.
165    #[clap(long, env = "CLOUD_PROVIDER_REGION")]
166    cloud_provider_region: Option<String>,
167    /// Set startup defaults for dynconfig
168    #[clap(long, value_parser = parse_key_val::<String, String>, value_delimiter = ',')]
169    default_config: Option<Vec<(String, String)>>,
170}
171
172fn main() {
173    let args: Args = cli::parse_args(CliConfig::default());
174
175    // Mirror the tokio Runtime configuration in our production binaries.
176    let ncpus_useful = usize::max(1, std::cmp::min(num_cpus::get(), num_cpus::get_physical()));
177    let runtime = tokio::runtime::Builder::new_multi_thread()
178        .worker_threads(ncpus_useful)
179        .enable_all()
180        .build()
181        .expect("Failed building the Runtime");
182
183    let metrics_registry = MetricsRegistry::new();
184    let tracing_handle = runtime
185        .block_on(args.tracing.configure_tracing(
186            StaticTracingConfig {
187                service_name: "balancerd",
188                build_info: BUILD_INFO,
189            },
190            metrics_registry.clone(),
191        ))
192        .expect("failed to init tracing");
193
194    runtime.block_on(mz_alloc::register_metrics_into(&metrics_registry));
195
196    let root_span = info_span!("balancer");
197    let res = match args.command {
198        Command::Service(args) => runtime.block_on(run(args, tracing_handle).instrument(root_span)),
199    };
200
201    if let Err(err) = res {
202        panic!("balancer: fatal: {}", err.display_with_causes());
203    }
204}
205
206pub async fn run(args: ServiceArgs, tracing_handle: TracingHandle) -> Result<(), anyhow::Error> {
207    let metrics_registry = MetricsRegistry::new();
208    let (resolver, cancellation_resolver) = match (
209        args.static_resolver_addr,
210        args.frontegg_resolver_template,
211    ) {
212        (None, Some(addr_template)) => {
213            let auth = Authenticator::new(
214                AuthenticatorConfig {
215                    admin_api_token_url: args.frontegg_api_token_url.expect("clap enforced"),
216                    decoding_key: match (args.frontegg_jwk, args.frontegg_jwk_file) {
217                        (None, Some(path)) => {
218                            let jwk = std::fs::read(&path).with_context(|| {
219                                format!("read {path:?} for --frontegg-jwk-file")
220                            })?;
221                            DecodingKey::from_rsa_pem(&jwk)?
222                        }
223                        (Some(jwk), None) => DecodingKey::from_rsa_pem(jwk.as_bytes())?,
224                        _ => anyhow::bail!(
225                            "exactly one of --frontegg-jwk or --frontegg-jwk-file must be present"
226                        ),
227                    },
228                    tenant_id: None,
229                    now: mz_ore::now::SYSTEM_TIME.clone(),
230                    admin_role: args.frontegg_admin_role.expect("clap enforced"),
231                    refresh_drop_lru_size: DEFAULT_REFRESH_DROP_LRU_CACHE_SIZE,
232                    refresh_drop_factor: DEFAULT_REFRESH_DROP_FACTOR,
233                },
234                mz_frontegg_auth::Client::environmentd_default(),
235                &metrics_registry,
236            );
237            let cancellation_resolver_dir = args
238                .cancellation_resolver_dir
239                .expect("required unless static resolver present");
240            if !cancellation_resolver_dir.is_dir() {
241                anyhow::bail!("{cancellation_resolver_dir:?} is not a directory");
242            }
243            (
244                Resolver::MultiTenant(
245                    FronteggResolver {
246                        auth,
247                        addr_template,
248                    },
249                    match args.pgwire_sni_resolver_template {
250                        None => None,
251                        Some(template) => {
252                            let (template, port) = template
253                                .rsplit_once(':')
254                                .map(|(t, p)| {
255                                    (
256                                        t.to_owned(),
257                                        p.parse::<u16>().expect(
258                                            "invalid port for pgwire_sni_resolver_template",
259                                        ),
260                                    )
261                                })
262                                .expect("invalid port for pgwire_sni_resolver_template");
263                            Some(SniResolver {
264                                resolver: StubResolver::new(),
265                                template,
266                                port,
267                            })
268                        }
269                    },
270                ),
271                CancellationResolver::Directory(cancellation_resolver_dir),
272            )
273        }
274        (Some(addr), None) => {
275            // As a typo-check, verify that the passed address resolves to at least one IP. This
276            // result isn't recorded anywhere: we re-resolve on each request in case DNS changes.
277            // Here only to cause startup to crash if mistyped.
278            let mut addrs = tokio::net::lookup_host(&addr)
279                .await
280                .unwrap_or_else(|_| panic!("could not resolve {addr}"));
281            let Some(_resolved) = addrs.next() else {
282                panic!("{addr} did not resolve to any addresses");
283            };
284            drop(addrs);
285
286            (
287                Resolver::Static(addr.clone()),
288                CancellationResolver::Static(addr),
289            )
290        }
291        _ => anyhow::bail!(
292            "exactly one of --static-resolver-addr or --frontegg-resolver-template must be present"
293        ),
294    };
295    let config = BalancerConfig::new(
296        &BUILD_INFO,
297        args.internal_http_listen_addr,
298        args.pgwire_listen_addr,
299        args.https_listen_addr,
300        cancellation_resolver,
301        resolver,
302        args.https_sni_resolver_template,
303        args.tls.into_config()?,
304        args.internal_tls,
305        metrics_registry,
306        mz_server_core::default_cert_reload_ticker(),
307        args.launchdarkly_sdk_key,
308        args.config_sync_file_path,
309        args.config_sync_timeout,
310        args.config_sync_loop_interval,
311        args.cloud_provider,
312        args.cloud_provider_region,
313        tracing_handle,
314        args.default_config.unwrap_or(vec![]),
315    );
316    let service = BalancerService::new(config).await?;
317    service.serve().await?;
318    warn!("balancer service exited");
319    Ok(())
320}
321
/// Splits a `KEY=value` argument at its first `=` and parses each half.
///
/// Everything before the first `=` is parsed as `T`; everything after it
/// (including any further `=` characters) is parsed as `U`.
///
/// # Errors
///
/// Returns an error if `s` contains no `=`, or if either half fails to parse
/// into its target type.
fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn Error + Send + Sync + 'static>>
where
    T: std::str::FromStr,
    T::Err: Error + Send + Sync + 'static,
    U: std::str::FromStr,
    U::Err: Error + Send + Sync + 'static,
{
    let Some((raw_key, raw_value)) = s.split_once('=') else {
        return Err(format!("invalid KEY=value: no `=` found in `{s}`").into());
    };
    let key: T = raw_key.parse()?;
    let value: U = raw_value.parse()?;
    Ok((key, value))
}