1use std::error::Error;
16use std::net::SocketAddr;
17use std::path::PathBuf;
18use std::time::Duration;
19
20use anyhow::Context;
21use domain::resolv::StubResolver;
22use jsonwebtoken::DecodingKey;
23use mz_balancerd::{
24 BUILD_INFO, BalancerConfig, BalancerService, CancellationResolver, FronteggResolver, Resolver,
25 SniResolver,
26};
27use mz_frontegg_auth::{
28 Authenticator, AuthenticatorConfig, DEFAULT_REFRESH_DROP_FACTOR,
29 DEFAULT_REFRESH_DROP_LRU_CACHE_SIZE,
30};
31use mz_orchestrator_tracing::{StaticTracingConfig, TracingCliArgs};
32use mz_ore::cli::{self, CliConfig};
33use mz_ore::error::ErrorExt;
34use mz_ore::metrics::MetricsRegistry;
35use mz_ore::tracing::TracingHandle;
36use mz_server_core::TlsCliArgs;
37use tracing::{Instrument, info_span, warn};
38
// Top-level command-line arguments for the balancer binary.
//
// Plain `//` comments (not `///`) are used on this clap-derived type so that
// the generated `--help` text is left unchanged.
#[derive(Debug, clap::Parser)]
#[clap(about = "Balancer service", long_about = None)]
struct Args {
    // The subcommand to run.
    #[clap(subcommand)]
    command: Command,

    // Tracing/logging options, flattened into the top-level flags; `main`
    // consumes them via `configure_tracing` before running the subcommand.
    #[clap(flatten)]
    tracing: TracingCliArgs,
}
48
// Subcommands accepted by the balancer binary.
#[derive(Debug, clap::Subcommand)]
enum Command {
    // Run the balancer service with the given arguments (clap's default
    // naming exposes this as the `service` subcommand).
    Service(ServiceArgs),
}
53
// Arguments for the `service` subcommand.
//
// Routing runs in exactly one of two modes:
//   * static: `--static-resolver-addr` sends every connection to one address;
//   * multi-tenant: `--frontegg-resolver-template` plus the Frontegg flags.
// clap rejects setting both; setting neither is rejected at runtime.
//
// Plain `//` comments (not `///`) are used on this clap-derived type so that
// the generated `--help` text is left unchanged.
#[derive(Debug, clap::Parser)]
pub struct ServiceArgs {
    // Address (HOST:PORT) to listen on for pgwire connections.
    #[clap(long, value_name = "HOST:PORT")]
    pgwire_listen_addr: SocketAddr,
    // Address (HOST:PORT) to listen on for HTTPS connections.
    #[clap(long, value_name = "HOST:PORT")]
    https_listen_addr: SocketAddr,
    // TLS options; converted with `into_config()` when the balancer config is
    // built.
    #[clap(flatten)]
    tls: TlsCliArgs,
    // Address (HOST:PORT) for the internal HTTP listener.
    #[clap(long, value_name = "HOST:PORT")]
    internal_http_listen_addr: SocketAddr,

    // Boolean switch passed through to the balancer config unchanged.
    // NOTE(review): presumably enables TLS on internal connections — confirm
    // against `BalancerConfig`.
    #[clap(long)]
    internal_tls: bool,
    // Static mode: a single HOST:PORT that all connections are routed to (also
    // used for cancellation). Conflicts with `--frontegg-resolver-template`.
    // It must resolve to at least one address at startup.
    #[clap(
        long,
        value_name = "HOST:PORT",
        conflicts_with = "frontegg_resolver_template"
    )]
    static_resolver_addr: Option<String>,
    // Multi-tenant mode: address template used by the Frontegg resolver.
    // Requires `--frontegg-api-token-url` and `--frontegg-admin-role`, and
    // exactly one of `--frontegg-jwk` / `--frontegg-jwk-file` (the latter rule
    // is checked at runtime, not by clap).
    // NOTE(review): `{}` is presumably substituted with a per-tenant value —
    // confirm in `FronteggResolver`.
    #[clap(long,
        value_name = "HOST.{}.NAME:PORT",
        requires_all = &["frontegg_api_token_url", "frontegg_admin_role"],
    )]
    frontegg_resolver_template: Option<String>,
    // Address template for the HTTPS resolver; also accepted as
    // `--https-resolver-template`. Always required.
    #[clap(
        long,
        value_name = "HOST.{}.NAME:PORT",
        visible_alias = "https-resolver-template"
    )]
    https_sni_resolver_template: String,
    // Optional template for SNI-based pgwire routing, of the form
    // `TEMPLATE:PORT`. It is split at the last `:` and the port must fit in a
    // u16. Only used in multi-tenant mode; ignored in static mode.
    #[clap(long, value_name = "HOST.{}.NAME:PORT")]
    pgwire_sni_resolver_template: Option<String>,
    // Directory used by the cancellation resolver in multi-tenant mode.
    // Required unless `--static-resolver-addr` is set; must be a directory.
    #[clap(
        long,
        value_name = "/path/to/configmap/dir/",
        required_unless_present = "static_resolver_addr"
    )]
    cancellation_resolver_dir: Option<PathBuf>,

    // RSA key in PEM format, given inline, used as the Frontegg authenticator's
    // decoding key. Mutually exclusive with `--frontegg-jwk-file`.
    #[clap(long, env = "FRONTEGG_JWK", requires = "frontegg_resolver_template")]
    frontegg_jwk: Option<String>,
    // Path to a file containing the PEM-encoded RSA key described for
    // `--frontegg-jwk`. Mutually exclusive with `--frontegg-jwk`.
    #[clap(
        long,
        env = "FRONTEGG_JWK_FILE",
        requires = "frontegg_resolver_template"
    )]
    frontegg_jwk_file: Option<PathBuf>,
    // Passed to the Frontegg authenticator as `admin_api_token_url`.
    #[clap(
        long,
        env = "FRONTEGG_API_TOKEN_URL",
        requires = "frontegg_resolver_template"
    )]
    frontegg_api_token_url: Option<String>,
    // Passed to the Frontegg authenticator as `admin_role`.
    #[clap(
        long,
        env = "FRONTEGG_ADMIN_ROLE",
        requires = "frontegg_resolver_template"
    )]
    frontegg_admin_role: Option<String>,
    // Optional LaunchDarkly SDK key, passed through to the balancer config.
    #[clap(long, env = "LAUNCHDARKLY_SDK_KEY")]
    launchdarkly_sdk_key: Option<String>,
    // Optional config-sync file path, passed through to the balancer config.
    #[clap(long, env = "CONFIG_SYNC_FILE_PATH")]
    config_sync_file_path: Option<PathBuf>,
    // Config-sync timeout in humantime syntax (e.g. `30s`, `2m`); default 30s.
    #[clap(
        long,
        env = "CONFIG_SYNC_TIMEOUT",
        value_parser = humantime::parse_duration,
        default_value = "30s"
    )]
    config_sync_timeout: Duration,
    // Optional config-sync loop interval in humantime syntax; unset by default.
    #[clap(
        long,
        env = "CONFIG_SYNC_LOOP_INTERVAL",
        value_parser = humantime::parse_duration,
    )]
    config_sync_loop_interval: Option<Duration>,

    // Optional cloud provider name, passed through to the balancer config.
    #[clap(long, env = "CLOUD_PROVIDER")]
    cloud_provider: Option<String>,
    // Optional cloud provider region, passed through to the balancer config.
    #[clap(long, env = "CLOUD_PROVIDER_REGION")]
    cloud_provider_region: Option<String>,
    // Comma-separated `KEY=VALUE` pairs, each split at its first `=` by
    // `parse_key_val`. Treated as an empty list when omitted.
    #[clap(long, value_parser = parse_key_val::<String, String>, value_delimiter = ',')]
    default_config: Option<Vec<(String, String)>>,
}
171
172fn main() {
173 let args: Args = cli::parse_args(CliConfig::default());
174
175 let ncpus_useful = usize::max(1, std::cmp::min(num_cpus::get(), num_cpus::get_physical()));
177 let runtime = tokio::runtime::Builder::new_multi_thread()
178 .worker_threads(ncpus_useful)
179 .enable_all()
180 .build()
181 .expect("Failed building the Runtime");
182
183 let metrics_registry = MetricsRegistry::new();
184 let tracing_handle = runtime
185 .block_on(args.tracing.configure_tracing(
186 StaticTracingConfig {
187 service_name: "balancerd",
188 build_info: BUILD_INFO,
189 },
190 metrics_registry.clone(),
191 ))
192 .expect("failed to init tracing");
193
194 runtime.block_on(mz_alloc::register_metrics_into(&metrics_registry));
195
196 let root_span = info_span!("balancer");
197 let res = match args.command {
198 Command::Service(args) => runtime.block_on(run(args, tracing_handle).instrument(root_span)),
199 };
200
201 if let Err(err) = res {
202 panic!("balancer: fatal: {}", err.display_with_causes());
203 }
204}
205
206pub async fn run(args: ServiceArgs, tracing_handle: TracingHandle) -> Result<(), anyhow::Error> {
207 let metrics_registry = MetricsRegistry::new();
208 let (resolver, cancellation_resolver) = match (
209 args.static_resolver_addr,
210 args.frontegg_resolver_template,
211 ) {
212 (None, Some(addr_template)) => {
213 let auth = Authenticator::new(
214 AuthenticatorConfig {
215 admin_api_token_url: args.frontegg_api_token_url.expect("clap enforced"),
216 decoding_key: match (args.frontegg_jwk, args.frontegg_jwk_file) {
217 (None, Some(path)) => {
218 let jwk = std::fs::read(&path).with_context(|| {
219 format!("read {path:?} for --frontegg-jwk-file")
220 })?;
221 DecodingKey::from_rsa_pem(&jwk)?
222 }
223 (Some(jwk), None) => DecodingKey::from_rsa_pem(jwk.as_bytes())?,
224 _ => anyhow::bail!(
225 "exactly one of --frontegg-jwk or --frontegg-jwk-file must be present"
226 ),
227 },
228 tenant_id: None,
229 now: mz_ore::now::SYSTEM_TIME.clone(),
230 admin_role: args.frontegg_admin_role.expect("clap enforced"),
231 refresh_drop_lru_size: DEFAULT_REFRESH_DROP_LRU_CACHE_SIZE,
232 refresh_drop_factor: DEFAULT_REFRESH_DROP_FACTOR,
233 },
234 mz_frontegg_auth::Client::environmentd_default(),
235 &metrics_registry,
236 );
237 let cancellation_resolver_dir = args
238 .cancellation_resolver_dir
239 .expect("required unless static resolver present");
240 if !cancellation_resolver_dir.is_dir() {
241 anyhow::bail!("{cancellation_resolver_dir:?} is not a directory");
242 }
243 (
244 Resolver::MultiTenant(
245 FronteggResolver {
246 auth,
247 addr_template,
248 },
249 match args.pgwire_sni_resolver_template {
250 None => None,
251 Some(template) => {
252 let (template, port) = template
253 .rsplit_once(':')
254 .map(|(t, p)| {
255 (
256 t.to_owned(),
257 p.parse::<u16>().expect(
258 "invalid port for pgwire_sni_resolver_template",
259 ),
260 )
261 })
262 .expect("invalid port for pgwire_sni_resolver_template");
263 Some(SniResolver {
264 resolver: StubResolver::new(),
265 template,
266 port,
267 })
268 }
269 },
270 ),
271 CancellationResolver::Directory(cancellation_resolver_dir),
272 )
273 }
274 (Some(addr), None) => {
275 let mut addrs = tokio::net::lookup_host(&addr)
279 .await
280 .unwrap_or_else(|_| panic!("could not resolve {addr}"));
281 let Some(_resolved) = addrs.next() else {
282 panic!("{addr} did not resolve to any addresses");
283 };
284 drop(addrs);
285
286 (
287 Resolver::Static(addr.clone()),
288 CancellationResolver::Static(addr),
289 )
290 }
291 _ => anyhow::bail!(
292 "exactly one of --static-resolver-addr or --frontegg-resolver-template must be present"
293 ),
294 };
295 let config = BalancerConfig::new(
296 &BUILD_INFO,
297 args.internal_http_listen_addr,
298 args.pgwire_listen_addr,
299 args.https_listen_addr,
300 cancellation_resolver,
301 resolver,
302 args.https_sni_resolver_template,
303 args.tls.into_config()?,
304 args.internal_tls,
305 metrics_registry,
306 mz_server_core::default_cert_reload_ticker(),
307 args.launchdarkly_sdk_key,
308 args.config_sync_file_path,
309 args.config_sync_timeout,
310 args.config_sync_loop_interval,
311 args.cloud_provider,
312 args.cloud_provider_region,
313 tracing_handle,
314 args.default_config.unwrap_or(vec![]),
315 );
316 let service = BalancerService::new(config).await?;
317 service.serve().await?;
318 warn!("balancer service exited");
319 Ok(())
320}
321
/// Parses one `KEY=VALUE` pair, as accepted by `--default-config`.
///
/// The input is split at the first `=`. Everything after it, including any
/// further `=` characters, is the value. Each half is then parsed with its
/// type's [`FromStr`](std::str::FromStr) implementation, key first.
///
/// # Errors
///
/// Returns an error if `s` contains no `=`, or if either half fails to parse.
fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn Error + Send + Sync + 'static>>
where
    T: std::str::FromStr,
    T::Err: Error + Send + Sync + 'static,
    U: std::str::FromStr,
    U::Err: Error + Send + Sync + 'static,
{
    let Some((raw_key, raw_value)) = s.split_once('=') else {
        return Err(format!("invalid KEY=value: no `=` found in `{s}`").into());
    };
    let key = raw_key.parse::<T>()?;
    let value = raw_value.parse::<U>()?;
    Ok((key, value))
}