use std::collections::{BTreeMap, btree_map};
use std::error::Error;
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use anyhow::{Context, anyhow, bail};
use aws_config::SdkConfig;
use crossbeam::channel::{Receiver, Sender, unbounded};
use fancy_regex::Regex;
use mz_ore::collections::CollectionExt;
use mz_ore::error::ErrorExt;
use mz_ore::future::InTask;
use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig, SshTunnelStatus};
use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
use rdkafka::client::{Client, NativeClient, OAuthToken};
use rdkafka::config::{ClientConfig, RDKafkaLogLevel};
use rdkafka::consumer::{ConsumerContext, Rebalance};
use rdkafka::error::{KafkaError, KafkaResult, RDKafkaErrorCode};
use rdkafka::producer::{DefaultProducerContext, DeliveryResult, ProducerContext};
use rdkafka::types::RDKafkaRespErr;
use rdkafka::util::Timeout;
use rdkafka::{ClientContext, Statistics, TopicPartitionList};
use serde::{Deserialize, Serialize};
use tokio::runtime::Handle;
use tokio::sync::watch;
use tracing::{Level, debug, error, info, trace, warn};

use crate::aws;

/// The default interval at which a Kafka client refreshes its topic metadata.
pub const DEFAULT_TOPIC_METADATA_REFRESH_INTERVAL: Duration = Duration::from_secs(30);

/// A client context for use with rdkafka clients that forwards librdkafka
/// errors and statistics onto channels, so they can be inspected from outside
/// the client's callback threads.
pub struct MzClientContext {
    /// The sender side of a channel onto which parsed Kafka errors are placed.
    error_tx: Sender<MzKafkaError>,
    /// A channel that always holds the most recent statistics report.
    statistics_tx: watch::Sender<Statistics>,
}

impl Default for MzClientContext {
    fn default() -> Self {
        Self::with_errors().0
    }
}

impl MzClientContext {
    /// Constructs a new client context, paired with a receiver that yields
    /// the errors recorded by the context.
    pub fn with_errors() -> (Self, Receiver<MzKafkaError>) {
        let (error_tx, error_rx) = unbounded();
        let (statistics_tx, _) = watch::channel(Default::default());
        let ctx = Self {
            error_tx,
            statistics_tx,
        };
        (ctx, error_rx)
    }

    /// Returns a watch receiver that tracks the latest `Statistics` reported
    /// by librdkafka.
    pub fn subscribe_statistics(&self) -> watch::Receiver<Statistics> {
        self.statistics_tx.subscribe()
    }

    /// Parses the given message into an `MzKafkaError` and sends it to any
    /// subscribed error receivers, falling back to `MzKafkaError::Internal`
    /// when the message is unrecognized.
    fn record_error(&self, msg: &str) {
        let err = match MzKafkaError::from_str(msg) {
            Ok(err) => err,
            Err(()) => {
                warn!(original_error = msg, "failed to parse kafka error");
                MzKafkaError::Internal(msg.to_owned())
            }
        };
        // Ignore send errors: they only occur when no receiver is listening.
        let _ = self.error_tx.send(err);
    }
}
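
// A minimal usage sketch (hypothetical wiring; `config` is assumed to be a
// `ClientConfig` prepared by the caller): create the context via
// `with_errors`, hand it to an rdkafka client, and drain parsed errors from
// the returned channel.
//
//     let (ctx, error_rx) = MzClientContext::with_errors();
//     let consumer: BaseConsumer<MzClientContext> = config.create_with_context(ctx)?;
//     while let Ok(err) = error_rx.try_recv() {
//         warn!("kafka connection error: {err}");
//     }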

/// A structured error variant parsed from one of librdkafka's free-form
/// error messages.
#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
pub enum MzKafkaError {
    #[error("Invalid username or password")]
    InvalidCredentials,
    #[error("Invalid CA certificate")]
    InvalidCACertificate,
    #[error("Disconnected during handshake; broker might require SSL encryption")]
    SSLEncryptionMaybeRequired,
    #[error("Broker does not support SSL connections")]
    SSLUnsupported,
    #[error("Broker did not provide a certificate")]
    BrokerCertificateMissing,
    #[error("Failed to verify broker certificate")]
    InvalidBrokerCertificate,
    #[error("Connection reset: {0}")]
    ConnectionReset(String),
    #[error("Connection timeout")]
    ConnectionTimeout,
    #[error("Failed to resolve hostname")]
    HostnameResolutionFailed,
    #[error("Unsupported SASL mechanism")]
    UnsupportedSASLMechanism,
    #[error("Unsupported broker version")]
    UnsupportedBrokerVersion,
    #[error("Broker transport failure")]
    BrokerTransportFailure,
    #[error("All brokers down")]
    AllBrokersDown,
    #[error("SASL authentication required")]
    SaslAuthenticationRequired,
    #[error("SASL authentication failed")]
    SaslAuthenticationFailed,
    #[error("SSL authentication required")]
    SslAuthenticationRequired,
    #[error("Unknown topic or partition")]
    UnknownTopicOrPartition,
    /// A Kafka error that could not be classified further.
    #[error("Internal kafka error: {0}")]
    Internal(String),
}

impl FromStr for MzKafkaError {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.contains("Authentication failed: Invalid username or password") {
            Ok(Self::InvalidCredentials)
        } else if s.contains("broker certificate could not be verified") {
            Ok(Self::InvalidCACertificate)
        } else if s.contains("connecting to a SSL listener?") {
            Ok(Self::SSLEncryptionMaybeRequired)
        } else if s.contains("client SSL authentication might be required") {
            Ok(Self::SslAuthenticationRequired)
        } else if s.contains("connecting to a PLAINTEXT broker listener") {
            Ok(Self::SSLUnsupported)
        } else if s.contains("Broker did not provide a certificate") {
            Ok(Self::BrokerCertificateMissing)
        } else if s.contains("Failed to verify broker certificate: ") {
            Ok(Self::InvalidBrokerCertificate)
        } else if let Some((_prefix, inner)) = s.split_once("Send failed: ") {
            Ok(Self::ConnectionReset(inner.to_owned()))
        } else if let Some((_prefix, inner)) = s.split_once("Receive failed: ") {
            Ok(Self::ConnectionReset(inner.to_owned()))
        } else if s.contains("request(s) timed out: disconnect") {
            Ok(Self::ConnectionTimeout)
        } else if s.contains("Failed to resolve") {
            Ok(Self::HostnameResolutionFailed)
        } else if s.contains("mechanism handshake failed:") {
            Ok(Self::UnsupportedSASLMechanism)
        } else if s.contains(
            "verify that security.protocol is correctly configured, \
            broker might require SASL authentication",
        ) {
            Ok(Self::SaslAuthenticationRequired)
        } else if s.contains("SASL authentication error: Authentication failed") {
            Ok(Self::SaslAuthenticationFailed)
        } else if s
            .contains("incorrect security.protocol configuration (connecting to a SSL listener?)")
        {
            Ok(Self::SslAuthenticationRequired)
        } else if s.contains("probably due to broker version < 0.10") {
            Ok(Self::UnsupportedBrokerVersion)
        } else if s.contains("Disconnected while requesting ApiVersion")
            || s.contains("Broker transport failure")
            || s.contains("Connection refused")
        {
            Ok(Self::BrokerTransportFailure)
        } else if Regex::new(r"(\d+)/\1 brokers are down")
            .unwrap()
            .is_match(s)
            .unwrap_or_default()
        {
            Ok(Self::AllBrokersDown)
        } else if s.contains("Unknown topic or partition") || s.contains("Unknown partition") {
            Ok(Self::UnknownTopicOrPartition)
        } else {
            Err(())
        }
    }
}
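
// A sketch of how the classifier above maps a raw librdkafka message (the
// message text here is illustrative):
//
//     let err = MzKafkaError::from_str(
//         "SASL authentication error: Authentication failed: Invalid username or password",
//     );
//     assert_eq!(err, Ok(MzKafkaError::InvalidCredentials));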

impl ClientContext for MzClientContext {
    fn log(&self, level: rdkafka::config::RDKafkaLogLevel, fac: &str, log_message: &str) {
        use rdkafka::config::RDKafkaLogLevel::*;

        // Messages with the "FAIL" facility report connection failures even
        // when they do not arrive at error severity, so record those too.
        if matches!(level, Emerg | Alert | Critical | Error) || fac == "FAIL" {
            self.record_error(log_message);
        }

        // Error-level librdkafka logs are downgraded to `warn!`: they are
        // surfaced through `record_error` above and are frequently transient.
        match level {
            Emerg | Alert | Critical | Error => {
                warn!(target: "librdkafka", "error: {} {}", fac, log_message);
            }
            Warning => warn!(target: "librdkafka", "warning: {} {}", fac, log_message),
            Notice => info!(target: "librdkafka", "{} {}", fac, log_message),
            Info => info!(target: "librdkafka", "{} {}", fac, log_message),
            Debug => debug!(target: "librdkafka", "{} {}", fac, log_message),
        }
    }

    fn stats(&self, statistics: Statistics) {
        self.statistics_tx.send_replace(statistics);
    }

    fn error(&self, error: KafkaError, reason: &str) {
        self.record_error(reason);
        warn!(target: "librdkafka", "error: {}: {}", error, reason);
    }
}

impl ConsumerContext for MzClientContext {}

impl ProducerContext for MzClientContext {
    type DeliveryOpaque = <DefaultProducerContext as ProducerContext>::DeliveryOpaque;
    fn delivery(
        &self,
        delivery_result: &DeliveryResult<'_>,
        delivery_opaque: Self::DeliveryOpaque,
    ) {
        DefaultProducerContext.delivery(delivery_result, delivery_opaque);
    }
}

/// The address of a Kafka broker.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct BrokerAddr {
    /// The broker's hostname.
    pub host: String,
    /// The broker's port.
    pub port: u16,
}

impl BrokerAddr {
    /// Resolves this broker address to a list of socket addresses.
    pub fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
        Ok((self.host.as_str(), self.port).to_socket_addrs()?.collect())
    }
}

/// Rewrites a broker address.
#[derive(Debug, Clone)]
pub struct BrokerRewrite {
    /// The rewritten hostname.
    pub host: String,
    /// The rewritten port. If unset, the broker's original port is retained.
    pub port: Option<u16>,
}

impl BrokerRewrite {
    /// Applies this rewrite to the given broker address.
    pub fn rewrite(&self, address: &BrokerAddr) -> BrokerAddr {
        BrokerAddr {
            host: self.host.clone(),
            port: self.port.unwrap_or(address.port),
        }
    }
}
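
// A quick sketch of the port-preserving behavior (hostnames are
// illustrative):
//
//     let rewrite = BrokerRewrite { host: "proxy.internal".into(), port: None };
//     let addr = BrokerAddr { host: "broker-0".into(), port: 9092 };
//     assert_eq!(rewrite.rewrite(&addr).port, 9092); // original port retained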

#[derive(Clone)]
enum BrokerRewriteHandle {
    /// A rewrite to a fixed address.
    Simple(BrokerRewrite),
    /// A rewrite through an SSH tunnel whose local address is looked up when
    /// the broker is resolved.
    SshTunnel(ManagedSshTunnelHandle),
    /// The error string of a default SSH tunnel that failed to connect, kept
    /// so that `tunnel_status` can report the failure.
    FailedDefaultSshTunnel(String),
}

/// A pattern for matching broker addresses: a literal string that may be
/// extended by a wildcard on either side.
#[derive(Clone)]
pub struct ConnectionRulePattern {
    /// Whether the pattern allows an arbitrary prefix before the literal.
    pub prefix_wildcard: bool,
    /// The literal portion of the pattern.
    pub literal_match: String,
    /// Whether the pattern allows an arbitrary suffix after the literal.
    pub suffix_wildcard: bool,
}

impl std::fmt::Display for ConnectionRulePattern {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.prefix_wildcard {
            f.write_str("*")?;
        }
        f.write_str(&self.literal_match)?;
        if self.suffix_wildcard {
            f.write_str("*")?;
        }
        Ok(())
    }
}

impl ConnectionRulePattern {
    /// Reports whether this pattern matches the given `host:port` address.
    pub fn matches(&self, address: &str) -> bool {
        if self.prefix_wildcard {
            if self.suffix_wildcard {
                address.contains(&self.literal_match)
            } else {
                address.ends_with(&self.literal_match)
            }
        } else if self.suffix_wildcard {
            address.starts_with(&self.literal_match)
        } else {
            address == self.literal_match
        }
    }
}

/// An ordered set of rules that rewrite broker addresses to alternate
/// addresses. The first matching rule wins.
#[derive(Clone)]
pub struct HostMappingRules {
    /// Pairs of patterns and the rewrites to apply when a pattern matches.
    pub rules: Vec<(ConnectionRulePattern, BrokerRewrite)>,
}

impl HostMappingRules {
    /// Rewrites the given broker address according to the first matching
    /// rule, or returns `None` if no rule matches.
    pub fn rewrite(&self, src: &BrokerAddr) -> Option<BrokerAddr> {
        let address = format!("{}:{}", src.host, src.port);
        for (pattern, dst) in &self.rules {
            if pattern.matches(&address) {
                let result = dst.rewrite(src);
                info!(
                    "HostMappingRules: broker {}:{} matched pattern '{}' -> rewriting to {}:{}",
                    src.host, src.port, pattern, result.host, result.port,
                );
                return Some(result);
            }
        }

        warn!(
            "HostMappingRules: broker {}:{} matched no rules, using original address",
            src.host, src.port,
        );
        None
    }
}
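
// A sketch of a rule set (hostnames and ports are illustrative): any broker
// whose advertised address ends in ".internal:9092" is redirected to a fixed
// gateway; everything else resolves unchanged.
//
//     let rules = HostMappingRules {
//         rules: vec![(
//             ConnectionRulePattern {
//                 prefix_wildcard: true,
//                 literal_match: ".internal:9092".into(),
//                 suffix_wildcard: false,
//             },
//             BrokerRewrite {
//                 host: "gateway.example.com".into(),
//                 port: Some(19092),
//             },
//         )],
//     };
//     assert!(rules
//         .rewrite(&BrokerAddr { host: "b-0.internal".into(), port: 9092 })
//         .is_some());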

/// The tunneling configuration that a `TunnelingClientContext` applies to
/// brokers that have no broker-specific override.
#[derive(Clone)]
pub enum TunnelConfig {
    /// Connect to all brokers through the given SSH tunnel.
    Ssh(SshTunnelConfig),
    /// Connect to all brokers through the given host, keeping each broker's
    /// original port.
    StaticHost(String),
    /// Rewrite broker addresses according to a set of host mapping rules.
    Rules(HostMappingRules),
    /// Connect to brokers directly, with no tunneling.
    None,
}

/// An rdkafka client context that routes broker connections through tunnels
/// and applies broker address rewrites, delegating all other behavior to an
/// inner context.
#[derive(Clone)]
pub struct TunnelingClientContext<C> {
    inner: C,
    rewrites: Arc<Mutex<BTreeMap<BrokerAddr, BrokerRewriteHandle>>>,
    default_tunnel: TunnelConfig,
    in_task: InTask,
    ssh_tunnel_manager: SshTunnelManager,
    ssh_timeout_config: SshTimeoutConfig,
    aws_config: Option<SdkConfig>,
    runtime: Handle,
}

impl<C> TunnelingClientContext<C> {
    /// Constructs a new context that wraps `inner`.
    pub fn new(
        inner: C,
        runtime: Handle,
        ssh_tunnel_manager: SshTunnelManager,
        ssh_timeout_config: SshTimeoutConfig,
        aws_config: Option<SdkConfig>,
        in_task: InTask,
    ) -> TunnelingClientContext<C> {
        TunnelingClientContext {
            inner,
            rewrites: Arc::new(Mutex::new(BTreeMap::new())),
            default_tunnel: TunnelConfig::None,
            in_task,
            ssh_tunnel_manager,
            ssh_timeout_config,
            aws_config,
            runtime,
        }
    }

    /// Sets the tunnel to apply to brokers that have no broker-specific
    /// override registered.
    pub fn set_default_tunnel(&mut self, tunnel: TunnelConfig) {
        self.default_tunnel = tunnel;
    }

    /// Establishes an SSH tunnel to the given broker and registers it, so
    /// that future connections to that broker are routed through the tunnel.
    pub async fn add_ssh_tunnel(
        &self,
        broker: BrokerAddr,
        tunnel: SshTunnelConfig,
    ) -> Result<(), anyhow::Error> {
        let ssh_tunnel = self
            .ssh_tunnel_manager
            .connect(
                tunnel,
                &broker.host,
                broker.port,
                self.ssh_timeout_config,
                self.in_task,
            )
            .await
            .context("creating ssh tunnel")?;

        let mut rewrites = self.rewrites.lock().expect("poisoned");
        rewrites.insert(broker, BrokerRewriteHandle::SshTunnel(ssh_tunnel));
        Ok(())
    }

    /// Registers a rewrite that redirects connections to `broker` to the
    /// rewritten address.
    pub fn add_broker_rewrite(&self, broker: BrokerAddr, rewrite: BrokerRewrite) {
        let mut rewrites = self.rewrites.lock().expect("poisoned");
        rewrites.insert(broker, BrokerRewriteHandle::Simple(rewrite));
    }

    /// Returns a reference to the wrapped context.
    pub fn inner(&self) -> &C {
        &self.inner
    }

    /// Returns the combined status of all SSH tunnels managed by this
    /// context, concatenating the error messages of any failed tunnels.
    pub fn tunnel_status(&self) -> SshTunnelStatus {
        self.rewrites
            .lock()
            .expect("poisoned")
            .values()
            .map(|handle| match handle {
                BrokerRewriteHandle::SshTunnel(s) => s.check_status(),
                BrokerRewriteHandle::FailedDefaultSshTunnel(e) => {
                    SshTunnelStatus::Errored(e.clone())
                }
                BrokerRewriteHandle::Simple(_) => SshTunnelStatus::Running,
            })
            .fold(SshTunnelStatus::Running, |acc, status| {
                match (acc, status) {
                    (SshTunnelStatus::Running, SshTunnelStatus::Errored(e))
                    | (SshTunnelStatus::Errored(e), SshTunnelStatus::Running) => {
                        SshTunnelStatus::Errored(e)
                    }
                    (SshTunnelStatus::Errored(err), SshTunnelStatus::Errored(e)) => {
                        SshTunnelStatus::Errored(format!("{}, {}", err, e))
                    }
                    (SshTunnelStatus::Running, SshTunnelStatus::Running) => {
                        SshTunnelStatus::Running
                    }
                }
            })
    }
}
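
// A construction sketch (names are illustrative; `runtime`, `tunnel_manager`,
// `ssh_timeouts`, and `config` are assumed to exist in the caller's
// environment):
//
//     let mut ctx = TunnelingClientContext::new(
//         MzClientContext::default(),
//         runtime.handle().clone(),
//         tunnel_manager,
//         ssh_timeouts,
//         None, // no AWS config; IAM-based OAuth unused
//         InTask::Yes,
//     );
//     ctx.set_default_tunnel(TunnelConfig::StaticHost("kafka-gateway".into()));
//     let consumer: BaseConsumer<_> = config.create_with_context(ctx)?;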

impl<C> ClientContext for TunnelingClientContext<C>
where
    C: ClientContext,
{
    const ENABLE_REFRESH_OAUTH_TOKEN: bool = true;

    fn generate_oauth_token(
        &self,
        _oauthbearer_config: Option<&str>,
    ) -> Result<OAuthToken, Box<dyn Error>> {
        info!(target: "librdkafka", "generating OAuth token");

        let generate = || {
            let Some(sdk_config) = &self.aws_config else {
                bail!("internal error: AWS configuration missing");
            };

            self.runtime.block_on(aws::generate_auth_token(sdk_config))
        };

        match generate() {
            Ok((token, lifetime_ms)) => {
                info!(target: "librdkafka", %lifetime_ms, "successfully generated OAuth token");
                trace!(target: "librdkafka", %token);
                Ok(OAuthToken {
                    token,
                    lifetime_ms,
                    principal_name: "".to_string(),
                })
            }
            Err(e) => {
                warn!(target: "librdkafka", "failed to generate OAuth token: {e:#}");
                Err(e.into())
            }
        }
    }

    fn resolve_broker_addr(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>, io::Error> {
        info!("kafka: resolve_broker_addr called for {}:{}", host, port);
        // Converts a registered rewrite into resolved socket addresses.
        let return_rewrite = |rewrite: &BrokerRewriteHandle| -> Result<Vec<SocketAddr>, io::Error> {
            let rewrite = match rewrite {
                BrokerRewriteHandle::Simple(rewrite) => rewrite.clone(),
                BrokerRewriteHandle::SshTunnel(ssh_tunnel) => {
                    // The tunnel listens on a local ephemeral port.
                    let addr = ssh_tunnel.local_addr();
                    BrokerRewrite {
                        host: addr.ip().to_string(),
                        port: Some(addr.port()),
                    }
                }
                BrokerRewriteHandle::FailedDefaultSshTunnel(_) => {
                    unreachable!()
                }
            };
            let rewrite_port = rewrite.port.unwrap_or(port);

            info!(
                "rewriting broker {}:{} to {}:{}",
                host, port, rewrite.host, rewrite_port
            );

            (rewrite.host, rewrite_port)
                .to_socket_addrs()
                .map(|addrs| addrs.collect())
        };

        let addr = BrokerAddr {
            host: host.into(),
            port,
        };
        let rewrite = self.rewrites.lock().expect("poisoned").get(&addr).cloned();

        match rewrite {
            None | Some(BrokerRewriteHandle::FailedDefaultSshTunnel(_)) => {
                // No broker-specific rewrite is registered (or a previous
                // default tunnel failed and should be retried): fall back to
                // the default tunnel configuration.
                match &self.default_tunnel {
                    TunnelConfig::Ssh(default_tunnel) => {
                        // Synchronously build the tunnel, as this method is
                        // invoked from librdkafka's non-async threads.
                        let ssh_tunnel = self.runtime.block_on(async {
                            self.ssh_tunnel_manager
                                .connect(
                                    default_tunnel.clone(),
                                    host,
                                    port,
                                    self.ssh_timeout_config,
                                    self.in_task,
                                )
                                .await
                        });
                        match ssh_tunnel {
                            Ok(ssh_tunnel) => {
                                let mut rewrites = self.rewrites.lock().expect("poisoned");
                                let rewrite = match rewrites.entry(addr.clone()) {
                                    btree_map::Entry::Occupied(mut o)
                                        if matches!(
                                            o.get(),
                                            BrokerRewriteHandle::FailedDefaultSshTunnel(_)
                                        ) =>
                                    {
                                        o.insert(BrokerRewriteHandle::SshTunnel(
                                            ssh_tunnel.clone(),
                                        ));
                                        o.into_mut()
                                    }
                                    btree_map::Entry::Occupied(o) => o.into_mut(),
                                    btree_map::Entry::Vacant(v) => {
                                        v.insert(BrokerRewriteHandle::SshTunnel(ssh_tunnel.clone()))
                                    }
                                };

                                return_rewrite(rewrite)
                            }
                            Err(e) => {
                                warn!(
                                    "failed to create ssh tunnel for {:?}: {}",
                                    addr,
                                    e.display_with_causes()
                                );

                                // Record the failure so `tunnel_status` can
                                // report it.
                                let mut rewrites = self.rewrites.lock().expect("poisoned");
                                rewrites.entry(addr.clone()).or_insert_with(|| {
                                    BrokerRewriteHandle::FailedDefaultSshTunnel(
                                        e.to_string_with_causes(),
                                    )
                                });

                                Err(io::Error::new(
                                    io::ErrorKind::Other,
                                    "creating SSH tunnel failed",
                                ))
                            }
                        }
                    }
                    TunnelConfig::StaticHost(host) => (host.as_str(), port)
                        .to_socket_addrs()
                        .map(|addrs| addrs.collect()),
                    TunnelConfig::Rules(rules) => {
                        let resolved = rules.rewrite(&addr).unwrap_or_else(|| addr.clone());
                        match resolved.to_socket_addrs() {
                            Ok(addrs) => {
                                info!(
                                    "kafka: resolve_broker_addr {}:{} -> {}:{} resolved to {:?}",
                                    host, port, resolved.host, resolved.port, addrs,
                                );
                                Ok(addrs)
                            }
                            Err(e) => {
                                warn!(
                                    "kafka: resolve_broker_addr {}:{} -> {}:{} DNS resolution FAILED: {e}",
                                    host, port, resolved.host, resolved.port,
                                );
                                Err(e)
                            }
                        }
                    }
                    TunnelConfig::None => {
                        (host, port).to_socket_addrs().map(|addrs| addrs.collect())
                    }
                }
            }
            Some(rewrite) => {
                info!(
                    "kafka: resolve_broker_addr {}:{} using cached rewrite",
                    host, port
                );
                return_rewrite(&rewrite)
            }
        }
    }

    fn log(&self, level: RDKafkaLogLevel, fac: &str, log_message: &str) {
        self.inner.log(level, fac, log_message)
    }

    fn error(&self, error: KafkaError, reason: &str) {
        self.inner.error(error, reason)
    }

    fn stats(&self, statistics: Statistics) {
        self.inner.stats(statistics)
    }

    fn stats_raw(&self, statistics: &[u8]) {
        self.inner.stats_raw(statistics)
    }
}

impl<C> ConsumerContext for TunnelingClientContext<C>
where
    C: ConsumerContext,
{
    fn rebalance(
        &self,
        native_client: &NativeClient,
        err: RDKafkaRespErr,
        tpl: &mut TopicPartitionList,
    ) {
        self.inner.rebalance(native_client, err, tpl)
    }

    fn pre_rebalance<'a>(&self, rebalance: &Rebalance<'a>) {
        self.inner.pre_rebalance(rebalance)
    }

    fn post_rebalance<'a>(&self, rebalance: &Rebalance<'a>) {
        self.inner.post_rebalance(rebalance)
    }

    fn commit_callback(&self, result: KafkaResult<()>, offsets: &TopicPartitionList) {
        self.inner.commit_callback(result, offsets)
    }

    fn main_queue_min_poll_interval(&self) -> Timeout {
        self.inner.main_queue_min_poll_interval()
    }
}

impl<C> ProducerContext for TunnelingClientContext<C>
where
    C: ProducerContext,
{
    type DeliveryOpaque = C::DeliveryOpaque;

    fn delivery(
        &self,
        delivery_result: &DeliveryResult<'_>,
        delivery_opaque: Self::DeliveryOpaque,
    ) {
        self.inner.delivery(delivery_result, delivery_opaque)
    }
}

/// The type of a Kafka partition ID.
pub type PartitionId = i32;

/// An error returned by `get_partitions`.
#[derive(Debug, thiserror::Error)]
pub enum GetPartitionsError {
    /// The specified topic does not exist.
    #[error("Topic does not exist")]
    TopicDoesNotExist,
    /// A Kafka error.
    #[error(transparent)]
    Kafka(#[from] KafkaError),
    /// An unstructured error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

/// Retrieves the partition IDs of the named topic, returning
/// `GetPartitionsError::TopicDoesNotExist` if the topic is unknown or
/// reports no partitions.
pub fn get_partitions<C: ClientContext>(
    client: &Client<C>,
    topic: &str,
    timeout: Duration,
) -> Result<Vec<PartitionId>, GetPartitionsError> {
    let meta = client.fetch_metadata(Some(topic), timeout)?;
    if meta.topics().len() != 1 {
        Err(anyhow!(
            "topic {} has {} metadata entries; expected 1",
            topic,
            meta.topics().len()
        ))?;
    }

    fn check_err(err: Option<RDKafkaRespErr>) -> Result<(), GetPartitionsError> {
        match err.map(RDKafkaErrorCode::from) {
            Some(RDKafkaErrorCode::UnknownTopic | RDKafkaErrorCode::UnknownTopicOrPartition) => {
                Err(GetPartitionsError::TopicDoesNotExist)
            }
            Some(code) => Err(anyhow!(code))?,
            None => Ok(()),
        }
    }

    let meta_topic = meta.topics().into_element();
    check_err(meta_topic.error())?;

    if meta_topic.name() != topic {
        Err(anyhow!(
            "got results for wrong topic {} (expected {})",
            meta_topic.name(),
            topic
        ))?;
    }

    let mut partition_ids = Vec::with_capacity(meta_topic.partitions().len());
    for partition_meta in meta_topic.partitions() {
        check_err(partition_meta.error())?;

        partition_ids.push(partition_meta.id());
    }

    // An empty partition list makes the topic unusable; treat it the same as
    // a missing topic.
    if partition_ids.is_empty() {
        Err(GetPartitionsError::TopicDoesNotExist)?;
    }

    Ok(partition_ids)
}
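
// A call sketch (topic name is illustrative; `client` is any
// `rdkafka::client::Client`):
//
//     let partitions = get_partitions(&client, "my-topic", DEFAULT_FETCH_METADATA_TIMEOUT)?;
//     info!("topic has {} partitions", partitions.len());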

/// The default value for `socket.keepalive.enable` in Kafka client configs.
pub const DEFAULT_KEEPALIVE: bool = true;
/// The default value for `socket.timeout.ms`.
pub const DEFAULT_SOCKET_TIMEOUT: Duration = Duration::from_secs(60);
/// The default value for `transaction.timeout.ms`.
pub const DEFAULT_TRANSACTION_TIMEOUT: Duration = Duration::from_secs(600);
/// The default value for `socket.connection.setup.timeout.ms`.
pub const DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT: Duration = Duration::from_secs(30);
/// The default timeout for fetching topic metadata.
pub const DEFAULT_FETCH_METADATA_TIMEOUT: Duration = Duration::from_secs(10);
/// The default timeout for fetching progress records.
pub const DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT: Duration = Duration::from_secs(90);

/// A set of timeouts applied to Kafka client configurations.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TimeoutConfig {
    /// Whether to enable TCP keepalive on the client's sockets.
    pub keepalive: bool,
    /// The `socket.timeout.ms` value.
    pub socket_timeout: Duration,
    /// The `transaction.timeout.ms` value.
    pub transaction_timeout: Duration,
    /// The `socket.connection.setup.timeout.ms` value.
    pub socket_connection_setup_timeout: Duration,
    /// The timeout to use when fetching topic metadata.
    pub fetch_metadata_timeout: Duration,
    /// The timeout to use when fetching progress records.
    pub progress_record_fetch_timeout: Duration,
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        TimeoutConfig {
            keepalive: DEFAULT_KEEPALIVE,
            socket_timeout: DEFAULT_SOCKET_TIMEOUT,
            transaction_timeout: DEFAULT_TRANSACTION_TIMEOUT,
            socket_connection_setup_timeout: DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT,
            fetch_metadata_timeout: DEFAULT_FETCH_METADATA_TIMEOUT,
            progress_record_fetch_timeout: DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT,
        }
    }
}

impl TimeoutConfig {
    /// Builds a `TimeoutConfig` from individual settings, clamping each value
    /// to the ranges librdkafka accepts and falling back to the defaults
    /// above (with an error log) when a value is out of range.
    pub fn build(
        keepalive: bool,
        socket_timeout: Option<Duration>,
        transaction_timeout: Duration,
        socket_connection_setup_timeout: Duration,
        fetch_metadata_timeout: Duration,
        progress_record_fetch_timeout: Option<Duration>,
    ) -> TimeoutConfig {
        // `transaction.timeout.ms` must fit in an `i32` and be at least
        // 1000ms.
        let transaction_timeout = if transaction_timeout.as_millis() > i32::MAX.try_into().unwrap()
        {
            error!(
                "transaction_timeout ({transaction_timeout:?}) greater than max \
                of {}ms, defaulting to the default of {DEFAULT_TRANSACTION_TIMEOUT:?}",
                i32::MAX
            );
            DEFAULT_TRANSACTION_TIMEOUT
        } else if transaction_timeout.as_millis() < 1000 {
            error!(
                "transaction_timeout ({transaction_timeout:?}) less than min \
                of 1000ms, defaulting to the default of {DEFAULT_TRANSACTION_TIMEOUT:?}"
            );
            DEFAULT_TRANSACTION_TIMEOUT
        } else {
            transaction_timeout
        };

        // The progress record fetch timeout defaults to the transaction
        // timeout and must not be shorter than it.
        let progress_record_fetch_timeout_derived_default =
            std::cmp::max(transaction_timeout, DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT);
        let progress_record_fetch_timeout =
            progress_record_fetch_timeout.unwrap_or(progress_record_fetch_timeout_derived_default);
        let progress_record_fetch_timeout = if progress_record_fetch_timeout < transaction_timeout {
            error!(
                "progress record fetch ({progress_record_fetch_timeout:?}) less than transaction \
                timeout ({transaction_timeout:?}), defaulting to transaction timeout {transaction_timeout:?}",
            );
            transaction_timeout
        } else {
            progress_record_fetch_timeout
        };

        // Clamp the socket timeout to at most 100ms above the transaction
        // timeout, and to 300s overall.
        let max_socket_timeout = std::cmp::min(
            transaction_timeout + Duration::from_millis(100),
            Duration::from_secs(300),
        );
        let socket_timeout_derived_default =
            std::cmp::min(max_socket_timeout, DEFAULT_SOCKET_TIMEOUT);
        let socket_timeout = socket_timeout.unwrap_or(socket_timeout_derived_default);
        let socket_timeout = if socket_timeout > max_socket_timeout {
            error!(
                "socket_timeout ({socket_timeout:?}) greater than max \
                of min(300000, transaction.timeout.ms + 100 ({})), \
                defaulting to the maximum of {max_socket_timeout:?}",
                transaction_timeout.as_millis() + 100
            );
            max_socket_timeout
        } else if socket_timeout.as_millis() < 10 {
            error!(
                "socket_timeout ({socket_timeout:?}) less than min \
                of 10ms, defaulting to the default of {socket_timeout_derived_default:?}"
            );
            socket_timeout_derived_default
        } else {
            socket_timeout
        };

        let socket_connection_setup_timeout =
            if socket_connection_setup_timeout.as_millis() > i32::MAX.try_into().unwrap() {
                error!(
                    "socket_connection_setup_timeout ({socket_connection_setup_timeout:?}) \
                    greater than max of {}ms, defaulting to the default \
                    of {DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT:?}",
                    i32::MAX,
                );
                DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT
            } else if socket_connection_setup_timeout.as_millis() < 10 {
                error!(
                    "socket_connection_setup_timeout ({socket_connection_setup_timeout:?}) \
                    less than min of 10ms, defaulting to the default of \
                    {DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT:?}"
                );
                DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT
            } else {
                socket_connection_setup_timeout
            };

        TimeoutConfig {
            keepalive,
            socket_timeout,
            transaction_timeout,
            socket_connection_setup_timeout,
            fetch_metadata_timeout,
            progress_record_fetch_timeout,
        }
    }
}
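
// A clamping sketch: with a 2000s transaction timeout, a requested 600s
// socket timeout exceeds min(300s, transaction timeout + 100ms) and is
// clamped, while the progress record fetch timeout inherits the transaction
// timeout.
//
//     let cfg = TimeoutConfig::build(
//         true,
//         Some(Duration::from_secs(600)),
//         Duration::from_secs(2000),
//         DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT,
//         DEFAULT_FETCH_METADATA_TIMEOUT,
//         None,
//     );
//     assert_eq!(cfg.socket_timeout, Duration::from_secs(300));
//     assert_eq!(cfg.progress_record_fetch_timeout, Duration::from_secs(2000));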

/// Creates a new `ClientConfig` with the default log level and timeouts.
pub fn create_new_client_config_simple() -> ClientConfig {
    create_new_client_config(tracing::Level::INFO, Default::default())
}

/// Creates a new `ClientConfig` whose librdkafka log level tracks the given
/// tracing level and whose timeouts come from `timeout_config`.
pub fn create_new_client_config(
    tracing_level: Level,
    timeout_config: TimeoutConfig,
) -> ClientConfig {
    #[allow(clippy::disallowed_methods)]
    let mut config = ClientConfig::new();

    let level = if tracing_level >= Level::DEBUG {
        RDKafkaLogLevel::Debug
    } else if tracing_level >= Level::INFO {
        RDKafkaLogLevel::Info
    } else if tracing_level >= Level::WARN {
        RDKafkaLogLevel::Warning
    } else {
        RDKafkaLogLevel::Error
    };
    tracing::debug!(target: "librdkafka", level = ?level, "Determined log level for librdkafka");
    config.set_log_level(level);

    // At DEBUG and above, also enable librdkafka's own debug contexts.
    if tracing_level >= Level::DEBUG {
        tracing::debug!(target: "librdkafka", "Enabling debug logs for rdkafka");
        config.set("debug", "all");
    }

    if timeout_config.keepalive {
        config.set("socket.keepalive.enable", "true");
    }

    config.set(
        "socket.timeout.ms",
        timeout_config.socket_timeout.as_millis().to_string(),
    );
    config.set(
        "transaction.timeout.ms",
        timeout_config.transaction_timeout.as_millis().to_string(),
    );
    config.set(
        "socket.connection.setup.timeout.ms",
        timeout_config
            .socket_connection_setup_timeout
            .as_millis()
            .to_string(),
    );

    config
}
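
// A usage sketch (bootstrap address is illustrative):
//
//     let mut config = create_new_client_config(tracing::Level::INFO, TimeoutConfig::default());
//     config.set("bootstrap.servers", "broker-0:9092");
//     let (ctx, _errors) = MzClientContext::with_errors();
//     let producer: BaseProducer<MzClientContext> = config.create_with_context(ctx)?;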

#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn test_connection_rule_pattern_matches() {
        let p = ConnectionRulePattern {
            prefix_wildcard: false,
            literal_match: "broker:9092".to_string(),
            suffix_wildcard: false,
        };
        assert!(p.matches("broker:9092"));
        assert!(!p.matches("other:9092"));

        let p = ConnectionRulePattern {
            prefix_wildcard: true,
            literal_match: ":9092".to_string(),
            suffix_wildcard: false,
        };
        assert!(p.matches("any-host:9092"));
        assert!(!p.matches("broker:9093"));

        let p = ConnectionRulePattern {
            prefix_wildcard: false,
            literal_match: "broker:".to_string(),
            suffix_wildcard: true,
        };
        assert!(p.matches("broker:9092"));
        assert!(!p.matches("other:9092"));

        let p = ConnectionRulePattern {
            prefix_wildcard: true,
            literal_match: "broker".to_string(),
            suffix_wildcard: true,
        };
        assert!(p.matches("my-broker-host:1234"));
        assert!(!p.matches("other:9092"));
    }
}