1use std::error::Error;
16use std::net::SocketAddr;
17use std::path::PathBuf;
18use std::time::Duration;
19
20use anyhow::Context;
21use jsonwebtoken::DecodingKey;
22use mz_balancerd::{
23 BUILD_INFO, BalancerConfig, BalancerService, CancellationResolver, FronteggResolver, Resolver,
24};
25use mz_frontegg_auth::{
26 Authenticator, AuthenticatorConfig, DEFAULT_REFRESH_DROP_FACTOR,
27 DEFAULT_REFRESH_DROP_LRU_CACHE_SIZE,
28};
29use mz_orchestrator_tracing::{StaticTracingConfig, TracingCliArgs};
30use mz_ore::cli::{self, CliConfig};
31use mz_ore::error::ErrorExt;
32use mz_ore::metrics::MetricsRegistry;
33use mz_ore::tracing::TracingHandle;
34use mz_server_core::TlsCliArgs;
35use tracing::{Instrument, info_span, warn};
36
// Top-level command-line arguments for the balancer binary.
//
// NOTE: plain `//` comments are used in this struct on purpose. clap's derive macros turn
// `///` doc comments on fields into `--help` text, so switching to `///` would change the
// program's output.
#[derive(Debug, clap::Parser)]
#[clap(about = "Balancer service", long_about = None)]
struct Args {
    // The subcommand to execute; `main` dispatches on this value.
    #[clap(subcommand)]
    command: Command,

    // Tracing flags, flattened into the top-level argument list. `main` calls
    // `configure_tracing` on this value before running the subcommand.
    #[clap(flatten)]
    tracing: TracingCliArgs,
}
46
// The subcommands understood by the binary.
//
// NOTE: plain `//` comments are used because clap would surface `///` doc comments on the
// variants as subcommand help text.
#[derive(Debug, clap::Subcommand)]
enum Command {
    // Runs the balancer service; `main` passes the carried `ServiceArgs` to `run`.
    Service(ServiceArgs),
}
51
// Arguments for the `service` subcommand; consumed by `run`.
//
// NOTE: plain `//` comments are used throughout this struct on purpose. clap's derive
// macros turn `///` doc comments into `--help` text, so switching to `///` would change the
// program's output.
#[derive(Debug, clap::Parser)]
pub struct ServiceArgs {
    // Listen address forwarded to `BalancerConfig::new`.
    // Presumably where pgwire (PostgreSQL wire protocol) clients connect — confirm.
    #[clap(long, value_name = "HOST:PORT")]
    pgwire_listen_addr: SocketAddr,
    // Listen address forwarded to `BalancerConfig::new`.
    // Presumably where HTTPS clients connect — confirm.
    #[clap(long, value_name = "HOST:PORT")]
    https_listen_addr: SocketAddr,
    // TLS flags, flattened into this argument list; `run` converts them with `into_config()`.
    #[clap(flatten)]
    tls: TlsCliArgs,
    // Listen address forwarded to `BalancerConfig::new`.
    // Presumably serves the internal HTTP endpoints — confirm.
    #[clap(long, value_name = "HOST:PORT")]
    internal_http_listen_addr: SocketAddr,

    // Boolean switch forwarded to `BalancerConfig::new`.
    // Presumably enables TLS on internal (balancer-to-backend) connections — confirm.
    #[clap(long)]
    internal_tls: bool,
    // Fixed backend address. `run` checks at startup that it resolves, then hands the
    // unresolved string to both `Resolver::Static` and `CancellationResolver::Static`.
    // Mutually exclusive with `--frontegg-resolver-template`.
    #[clap(
        long,
        value_name = "HOST:PORT",
        conflicts_with = "frontegg_resolver_template"
    )]
    static_resolver_addr: Option<String>,
    // Address template used as `FronteggResolver::addr_template`. Setting it also requires
    // `--frontegg-api-token-url` and `--frontegg-admin-role`.
    // Presumably `{}` is replaced with a per-tenant identifier — confirm against the resolver.
    #[clap(long,
        value_name = "HOST.{}.NAME:PORT",
        requires_all = &["frontegg_api_token_url", "frontegg_admin_role"],
    )]
    frontegg_resolver_template: Option<String>,
    // Mandatory address template forwarded to `BalancerConfig::new`.
    // Presumably used to route HTTPS traffic — confirm.
    #[clap(long, value_name = "HOST.{}.NAME:PORT")]
    https_resolver_template: String,
    // Directory backing `CancellationResolver::Directory`. Required unless
    // `--static-resolver-addr` is given; `run` rejects a path that is not a directory.
    #[clap(
        long,
        value_name = "/path/to/configmap/dir/",
        required_unless_present = "static_resolver_addr"
    )]
    cancellation_resolver_dir: Option<PathBuf>,

    // Inline key material, parsed by `run` with `DecodingKey::from_rsa_pem`. In Frontegg
    // mode exactly one of this and `--frontegg-jwk-file` must be set.
    #[clap(long, env = "FRONTEGG_JWK", requires = "frontegg_resolver_template")]
    frontegg_jwk: Option<String>,
    // Path to a file whose contents `run` reads and parses with
    // `DecodingKey::from_rsa_pem`. Alternative to `--frontegg-jwk`.
    #[clap(
        long,
        env = "FRONTEGG_JWK_FILE",
        requires = "frontegg_resolver_template"
    )]
    frontegg_jwk_file: Option<PathBuf>,
    // Becomes `AuthenticatorConfig::admin_api_token_url` in `run`.
    #[clap(
        long,
        env = "FRONTEGG_API_TOKEN_URL",
        requires = "frontegg_resolver_template"
    )]
    frontegg_api_token_url: Option<String>,
    // Becomes `AuthenticatorConfig::admin_role` in `run`.
    #[clap(
        long,
        env = "FRONTEGG_ADMIN_ROLE",
        requires = "frontegg_resolver_template"
    )]
    frontegg_admin_role: Option<String>,

    // Optional LaunchDarkly SDK key, forwarded to `BalancerConfig::new`.
    // Presumably enables syncing configuration from LaunchDarkly — confirm.
    #[clap(long, env = "LAUNCHDARKLY_SDK_KEY")]
    launchdarkly_sdk_key: Option<String>,
    // Duration parsed by `humantime::parse_duration`; defaults to 30 seconds. Forwarded to
    // `BalancerConfig::new`. Presumably bounds the initial configuration sync — confirm.
    #[clap(
        long,
        env = "CONFIG_SYNC_TIMEOUT",
        value_parser = humantime::parse_duration,
        default_value = "30s"
    )]
    config_sync_timeout: Duration,
    // Optional duration parsed by `humantime::parse_duration`, forwarded to
    // `BalancerConfig::new`. Presumably the period of a background sync loop, with `None`
    // disabling it — confirm.
    #[clap(
        long,
        env = "CONFIG_SYNC_LOOP_INTERVAL",
        value_parser = humantime::parse_duration,
    )]
    config_sync_loop_interval: Option<Duration>,

    // Optional cloud provider name, forwarded to `BalancerConfig::new`.
    #[clap(long, env = "CLOUD_PROVIDER")]
    cloud_provider: Option<String>,
    // Optional cloud provider region, forwarded to `BalancerConfig::new`.
    #[clap(long, env = "CLOUD_PROVIDER_REGION")]
    cloud_provider_region: Option<String>,
    // Comma-separated `KEY=value` pairs, each parsed by `parse_key_val`. `run` substitutes
    // an empty list when the flag is absent.
    #[clap(long, value_parser = parse_key_val::<String, String>, value_delimiter = ',')]
    default_config: Option<Vec<(String, String)>>,
}
157
158fn main() {
159 let args: Args = cli::parse_args(CliConfig::default());
160
161 let ncpus_useful = usize::max(1, std::cmp::min(num_cpus::get(), num_cpus::get_physical()));
163 let runtime = tokio::runtime::Builder::new_multi_thread()
164 .worker_threads(ncpus_useful)
165 .enable_all()
166 .build()
167 .expect("Failed building the Runtime");
168
169 let metrics_registry = MetricsRegistry::new();
170 let (tracing_handle, _tracing_guard) = runtime
171 .block_on(args.tracing.configure_tracing(
172 StaticTracingConfig {
173 service_name: "balancerd",
174 build_info: BUILD_INFO,
175 },
176 metrics_registry.clone(),
177 ))
178 .expect("failed to init tracing");
179
180 runtime.block_on(mz_alloc::register_metrics_into(&metrics_registry));
181
182 let root_span = info_span!("balancer");
183 let res = match args.command {
184 Command::Service(args) => runtime.block_on(run(args, tracing_handle).instrument(root_span)),
185 };
186
187 if let Err(err) = res {
188 panic!("balancer: fatal: {}", err.display_with_causes());
189 }
190 drop(_tracing_guard);
191}
192
193pub async fn run(args: ServiceArgs, tracing_handle: TracingHandle) -> Result<(), anyhow::Error> {
194 let metrics_registry = MetricsRegistry::new();
195 let (resolver, cancellation_resolver) = match (
196 args.static_resolver_addr,
197 args.frontegg_resolver_template,
198 ) {
199 (None, Some(addr_template)) => {
200 let auth = Authenticator::new(
201 AuthenticatorConfig {
202 admin_api_token_url: args.frontegg_api_token_url.expect("clap enforced"),
203 decoding_key: match (args.frontegg_jwk, args.frontegg_jwk_file) {
204 (None, Some(path)) => {
205 let jwk = std::fs::read(&path).with_context(|| {
206 format!("read {path:?} for --frontegg-jwk-file")
207 })?;
208 DecodingKey::from_rsa_pem(&jwk)?
209 }
210 (Some(jwk), None) => DecodingKey::from_rsa_pem(jwk.as_bytes())?,
211 _ => anyhow::bail!(
212 "exactly one of --frontegg-jwk or --frontegg-jwk-file must be present"
213 ),
214 },
215 tenant_id: None,
216 now: mz_ore::now::SYSTEM_TIME.clone(),
217 admin_role: args.frontegg_admin_role.expect("clap enforced"),
218 refresh_drop_lru_size: DEFAULT_REFRESH_DROP_LRU_CACHE_SIZE,
219 refresh_drop_factor: DEFAULT_REFRESH_DROP_FACTOR,
220 },
221 mz_frontegg_auth::Client::environmentd_default(),
222 &metrics_registry,
223 );
224 let cancellation_resolver_dir = args
225 .cancellation_resolver_dir
226 .expect("required unless static resolver present");
227 if !cancellation_resolver_dir.is_dir() {
228 anyhow::bail!("{cancellation_resolver_dir:?} is not a directory");
229 }
230 (
231 Resolver::Frontegg(FronteggResolver {
232 auth,
233 addr_template,
234 }),
235 CancellationResolver::Directory(cancellation_resolver_dir),
236 )
237 }
238 (Some(addr), None) => {
239 let mut addrs = tokio::net::lookup_host(&addr)
243 .await
244 .unwrap_or_else(|_| panic!("could not resolve {addr}"));
245 let Some(_resolved) = addrs.next() else {
246 panic!("{addr} did not resolve to any addresses");
247 };
248 drop(addrs);
249
250 (
251 Resolver::Static(addr.clone()),
252 CancellationResolver::Static(addr),
253 )
254 }
255 _ => anyhow::bail!(
256 "exactly one of --static-resolver-addr or --frontegg-resolver-template must be present"
257 ),
258 };
259 let config = BalancerConfig::new(
260 &BUILD_INFO,
261 args.internal_http_listen_addr,
262 args.pgwire_listen_addr,
263 args.https_listen_addr,
264 cancellation_resolver,
265 resolver,
266 args.https_resolver_template,
267 args.tls.into_config()?,
268 args.internal_tls,
269 metrics_registry,
270 mz_server_core::default_cert_reload_ticker(),
271 args.launchdarkly_sdk_key,
272 args.config_sync_timeout,
273 args.config_sync_loop_interval,
274 args.cloud_provider,
275 args.cloud_provider_region,
276 tracing_handle,
277 args.default_config.unwrap_or(vec![]),
278 );
279 let service = BalancerService::new(config).await?;
280 service.serve().await?;
281 warn!("balancer service exited");
282 Ok(())
283}
284
/// Parses a single `KEY=value` argument into a typed `(key, value)` pair.
///
/// The input is split at the first `=`: everything before it is parsed as `T` and
/// everything after it (which may itself contain `=`) is parsed as `U`.
///
/// # Errors
///
/// Returns an error if `s` contains no `=`, or if either half fails to parse into its
/// target type.
fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn Error + Send + Sync + 'static>>
where
    T: std::str::FromStr,
    T::Err: Error + Send + Sync + 'static,
    U: std::str::FromStr,
    U::Err: Error + Send + Sync + 'static,
{
    let Some((raw_key, raw_value)) = s.split_once('=') else {
        return Err(format!("invalid KEY=value: no `=` found in `{s}`").into());
    };
    // The key is parsed first so that a malformed key is reported before a malformed value.
    let key: T = raw_key.parse()?;
    let value: U = raw_value.parse()?;
    Ok((key, value))
}