use std::collections::{BTreeMap, btree_map};
use std::error::Error;
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use anyhow::{Context, anyhow, bail};
use aws_config::SdkConfig;
use crossbeam::channel::{Receiver, Sender, unbounded};
use fancy_regex::Regex;
use mz_ore::collections::CollectionExt;
use mz_ore::error::ErrorExt;
use mz_ore::future::InTask;
use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig, SshTunnelStatus};
use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
use rdkafka::client::{Client, NativeClient, OAuthToken};
use rdkafka::config::{ClientConfig, RDKafkaLogLevel};
use rdkafka::consumer::{ConsumerContext, Rebalance};
use rdkafka::error::{KafkaError, KafkaResult, RDKafkaErrorCode};
use rdkafka::producer::{DefaultProducerContext, DeliveryResult, ProducerContext};
use rdkafka::types::RDKafkaRespErr;
use rdkafka::util::Timeout;
use rdkafka::{ClientContext, Statistics, TopicPartitionList};
use serde::{Deserialize, Serialize};
use tokio::runtime::Handle;
use tokio::sync::watch;
use tracing::{Level, debug, error, info, trace, warn};

use crate::aws;

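/// A reasonable default interval at which to refresh topic metadata.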
pub const DEFAULT_TOPIC_METADATA_REFRESH_INTERVAL: Duration = Duration::from_secs(30);

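/// A client context that logs via `tracing` and forwards librdkafka errors
/// and statistics to interested subscribers.
///
/// Illustrative usage (the consumer setup here is an assumption, not code
/// from this crate):
///
/// ```ignore
/// let (ctx, error_rx) = MzClientContext::with_errors();
/// let consumer: BaseConsumer<MzClientContext> = config.create_with_context(ctx)?;
/// while let Ok(err) = error_rx.try_recv() {
///     println!("kafka error: {err}");
/// }
/// ```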
pub struct MzClientContext {
    /// Sends errors parsed from librdkafka logs and error callbacks to any
    /// interested receiver.
    error_tx: Sender<MzKafkaError>,
    /// Broadcasts the most recent statistics reported by librdkafka.
    statistics_tx: watch::Sender<Statistics>,
}

impl Default for MzClientContext {
    fn default() -> Self {
        Self::with_errors().0
    }
}

impl MzClientContext {
    /// Constructs a new client context and returns the receiving end of the
    /// channel on which parsed Kafka errors are delivered.
    pub fn with_errors() -> (Self, Receiver<MzKafkaError>) {
        let (error_tx, error_rx) = unbounded();
        let (statistics_tx, _) = watch::channel(Default::default());
        let ctx = Self {
            error_tx,
            statistics_tx,
        };
        (ctx, error_rx)
    }

    /// Returns a receiver that is updated with the latest statistics reported
    /// by librdkafka.
    pub fn subscribe_statistics(&self) -> watch::Receiver<Statistics> {
        self.statistics_tx.subscribe()
    }

    fn record_error(&self, msg: &str) {
        let err = match MzKafkaError::from_str(msg) {
            Ok(err) => err,
            Err(()) => {
                warn!(original_error = msg, "failed to parse kafka error");
                MzKafkaError::Internal(msg.to_owned())
            }
        };
        // Ignore send errors: they occur only if the receiver has been
        // dropped, in which case no one cares about the error anymore.
        let _ = self.error_tx.send(err);
    }
}

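/// A structured representation of the raw error strings that librdkafka
/// reports via its logs and error callbacks.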
#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
pub enum MzKafkaError {
    #[error("Invalid username or password")]
    InvalidCredentials,
    #[error("Invalid CA certificate")]
    InvalidCACertificate,
    #[error("Disconnected during handshake; broker might require SSL encryption")]
    SSLEncryptionMaybeRequired,
    #[error("Broker does not support SSL connections")]
    SSLUnsupported,
    #[error("Broker did not provide a certificate")]
    BrokerCertificateMissing,
    #[error("Failed to verify broker certificate")]
    InvalidBrokerCertificate,
    #[error("Connection reset: {0}")]
    ConnectionReset(String),
    #[error("Connection timeout")]
    ConnectionTimeout,
    #[error("Failed to resolve hostname")]
    HostnameResolutionFailed,
    #[error("Unsupported SASL mechanism")]
    UnsupportedSASLMechanism,
    #[error("Unsupported broker version")]
    UnsupportedBrokerVersion,
    #[error("Broker transport failure")]
    BrokerTransportFailure,
    #[error("All brokers down")]
    AllBrokersDown,
    #[error("SASL authentication required")]
    SaslAuthenticationRequired,
    #[error("SASL authentication failed")]
    SaslAuthenticationFailed,
    #[error("SSL authentication required")]
    SslAuthenticationRequired,
    #[error("Unknown topic or partition")]
    UnknownTopicOrPartition,
    #[error("Internal kafka error: {0}")]
    Internal(String),
}

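/// Parses a raw librdkafka error string into a structured `MzKafkaError` by
/// matching on known message fragments.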
impl FromStr for MzKafkaError {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.contains("Authentication failed: Invalid username or password") {
            Ok(Self::InvalidCredentials)
        } else if s.contains("broker certificate could not be verified") {
            Ok(Self::InvalidCACertificate)
        } else if s.contains("connecting to a SSL listener?") {
            Ok(Self::SSLEncryptionMaybeRequired)
        } else if s.contains("client SSL authentication might be required") {
            Ok(Self::SslAuthenticationRequired)
        } else if s.contains("connecting to a PLAINTEXT broker listener") {
            Ok(Self::SSLUnsupported)
        } else if s.contains("Broker did not provide a certificate") {
            Ok(Self::BrokerCertificateMissing)
        } else if s.contains("Failed to verify broker certificate: ") {
            Ok(Self::InvalidBrokerCertificate)
        } else if let Some((_prefix, inner)) = s.split_once("Send failed: ") {
            Ok(Self::ConnectionReset(inner.to_owned()))
        } else if let Some((_prefix, inner)) = s.split_once("Receive failed: ") {
            Ok(Self::ConnectionReset(inner.to_owned()))
        } else if s.contains("request(s) timed out: disconnect") {
            Ok(Self::ConnectionTimeout)
        } else if s.contains("Failed to resolve") {
            Ok(Self::HostnameResolutionFailed)
        } else if s.contains("mechanism handshake failed:") {
            Ok(Self::UnsupportedSASLMechanism)
        } else if s.contains(
            "verify that security.protocol is correctly configured, \
            broker might require SASL authentication",
        ) {
            Ok(Self::SaslAuthenticationRequired)
        } else if s.contains("SASL authentication error: Authentication failed") {
            Ok(Self::SaslAuthenticationFailed)
        } else if s
            .contains("incorrect security.protocol configuration (connecting to a SSL listener?)")
        {
            Ok(Self::SslAuthenticationRequired)
        } else if s.contains("probably due to broker version < 0.10") {
            Ok(Self::UnsupportedBrokerVersion)
        } else if s.contains("Disconnected while requesting ApiVersion")
            || s.contains("Broker transport failure")
            || s.contains("Connection refused")
        {
            Ok(Self::BrokerTransportFailure)
        } else if Regex::new(r"(\d+)/\1 brokers are down")
            .unwrap()
            .is_match(s)
            .unwrap_or_default()
        {
            Ok(Self::AllBrokersDown)
        } else if s.contains("Unknown topic or partition") || s.contains("Unknown partition") {
            Ok(Self::UnknownTopicOrPartition)
        } else {
            Err(())
        }
    }
}

impl ClientContext for MzClientContext {
    fn log(&self, level: rdkafka::config::RDKafkaLogLevel, fac: &str, log_message: &str) {
        use rdkafka::config::RDKafkaLogLevel::*;

        // librdkafka reports some connection failures via the log callback
        // with the `FAIL` facility rather than via the error callback, so
        // sniff those out as errors too.
        if matches!(level, Emerg | Alert | Critical | Error) || fac == "FAIL" {
            self.record_error(log_message);
        }

        // Forward each librdkafka log level to the corresponding tracing
        // macro. Error-level messages are logged at `warn!`; genuine failures
        // surface separately through the error channel.
        match level {
            Emerg | Alert | Critical | Error => {
                warn!(target: "librdkafka", "error: {} {}", fac, log_message);
            }
            Warning => warn!(target: "librdkafka", "warning: {} {}", fac, log_message),
            Notice => info!(target: "librdkafka", "{} {}", fac, log_message),
            Info => info!(target: "librdkafka", "{} {}", fac, log_message),
            Debug => debug!(target: "librdkafka", "{} {}", fac, log_message),
        }
    }

    fn stats(&self, statistics: Statistics) {
        self.statistics_tx.send_replace(statistics);
    }

    fn error(&self, error: KafkaError, reason: &str) {
        self.record_error(reason);
        warn!(target: "librdkafka", "error: {}: {}", error, reason);
    }
}

impl ConsumerContext for MzClientContext {}

impl ProducerContext for MzClientContext {
    type DeliveryOpaque = <DefaultProducerContext as ProducerContext>::DeliveryOpaque;
    fn delivery(
        &self,
        delivery_result: &DeliveryResult<'_>,
        delivery_opaque: Self::DeliveryOpaque,
    ) {
        DefaultProducerContext.delivery(delivery_result, delivery_opaque);
    }
}

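/// The address of a Kafka broker.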
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct BrokerAddr {
    /// The broker's hostname.
    pub host: String,
    /// The broker's port.
    pub port: u16,
}

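/// A rewrite to apply to a broker address before connecting.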
#[derive(Debug, Clone)]
pub struct BrokerRewrite {
    /// The hostname to connect to instead.
    pub host: String,
    /// The port to connect to instead. If `None`, the broker's original port
    /// is used.
    pub port: Option<u16>,
}

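/// The way in which a broker address is rewritten, along with any resource
/// (e.g. an SSH tunnel) that keeps the rewrite alive.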
#[derive(Clone)]
enum BrokerRewriteHandle {
    Simple(BrokerRewrite),
    SshTunnel(ManagedSshTunnelHandle),
    /// The default SSH tunnel could not be established; holds the error
    /// message so that `tunnel_status` can report it.
    FailedDefaultSshTunnel(String),
}

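/// The tunnel to use for brokers that do not have an explicit rewrite
/// registered.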
#[derive(Clone)]
pub enum TunnelConfig {
    /// Tunnel through an SSH bastion host.
    Ssh(SshTunnelConfig),
    /// Route all broker connections through the given static host.
    StaticHost(String),
    /// Connect to brokers directly.
    None,
}

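/// A client context that rewrites broker addresses on the fly, optionally
/// routing connections through SSH tunnels or a static host, and delegates
/// all other behavior to a wrapped context.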
#[derive(Clone)]
pub struct TunnelingClientContext<C> {
    inner: C,
    /// Per-broker rewrites, created up front or on demand as brokers are
    /// resolved.
    rewrites: Arc<Mutex<BTreeMap<BrokerAddr, BrokerRewriteHandle>>>,
    default_tunnel: TunnelConfig,
    in_task: InTask,
    ssh_tunnel_manager: SshTunnelManager,
    ssh_timeout_config: SshTimeoutConfig,
    aws_config: Option<SdkConfig>,
    runtime: Handle,
}

impl<C> TunnelingClientContext<C> {
    /// Constructs a new context that wraps `inner`.
    pub fn new(
        inner: C,
        runtime: Handle,
        ssh_tunnel_manager: SshTunnelManager,
        ssh_timeout_config: SshTimeoutConfig,
        aws_config: Option<SdkConfig>,
        in_task: InTask,
    ) -> TunnelingClientContext<C> {
        TunnelingClientContext {
            inner,
            rewrites: Arc::new(Mutex::new(BTreeMap::new())),
            default_tunnel: TunnelConfig::None,
            in_task,
            ssh_tunnel_manager,
            ssh_timeout_config,
            aws_config,
            runtime,
        }
    }

    /// Sets the default tunnel to use for brokers that have no explicit
    /// rewrite registered.
    pub fn set_default_tunnel(&mut self, tunnel: TunnelConfig) {
        self.default_tunnel = tunnel;
    }

    /// Establishes an SSH tunnel to the specified broker and registers it as
    /// a rewrite for that broker's address.
    pub async fn add_ssh_tunnel(
        &self,
        broker: BrokerAddr,
        tunnel: SshTunnelConfig,
    ) -> Result<(), anyhow::Error> {
        let ssh_tunnel = self
            .ssh_tunnel_manager
            .connect(
                tunnel,
                &broker.host,
                broker.port,
                self.ssh_timeout_config,
                self.in_task,
            )
            .await
            .context("creating ssh tunnel")?;

        let mut rewrites = self.rewrites.lock().expect("poisoned");
        rewrites.insert(broker, BrokerRewriteHandle::SshTunnel(ssh_tunnel));
        Ok(())
    }

    /// Registers a simple (non-tunneled) rewrite for the specified broker.
    pub fn add_broker_rewrite(&self, broker: BrokerAddr, rewrite: BrokerRewrite) {
        let mut rewrites = self.rewrites.lock().expect("poisoned");
        rewrites.insert(broker, BrokerRewriteHandle::Simple(rewrite));
    }

    /// Returns a reference to the wrapped context.
    pub fn inner(&self) -> &C {
        &self.inner
    }

    /// Returns the combined status of all managed SSH tunnels:
    /// [`SshTunnelStatus::Running`] only if every tunnel is healthy, and
    /// otherwise [`SshTunnelStatus::Errored`] with the accumulated error
    /// messages.
    pub fn tunnel_status(&self) -> SshTunnelStatus {
        self.rewrites
            .lock()
            .expect("poisoned")
            .values()
            .map(|handle| match handle {
                BrokerRewriteHandle::SshTunnel(s) => s.check_status(),
                BrokerRewriteHandle::FailedDefaultSshTunnel(e) => {
                    SshTunnelStatus::Errored(e.clone())
                }
                BrokerRewriteHandle::Simple(_) => SshTunnelStatus::Running,
            })
            .fold(SshTunnelStatus::Running, |acc, status| {
                match (acc, status) {
                    (SshTunnelStatus::Running, SshTunnelStatus::Errored(e))
                    | (SshTunnelStatus::Errored(e), SshTunnelStatus::Running) => {
                        SshTunnelStatus::Errored(e)
                    }
                    (SshTunnelStatus::Errored(err), SshTunnelStatus::Errored(e)) => {
                        SshTunnelStatus::Errored(format!("{}, {}", err, e))
                    }
                    (SshTunnelStatus::Running, SshTunnelStatus::Running) => {
                        SshTunnelStatus::Running
                    }
                }
            })
    }
}

impl<C> ClientContext for TunnelingClientContext<C>
where
    C: ClientContext,
{
    const ENABLE_REFRESH_OAUTH_TOKEN: bool = true;

    fn generate_oauth_token(
        &self,
        _oauthbearer_config: Option<&str>,
    ) -> Result<OAuthToken, Box<dyn Error>> {
        info!(target: "librdkafka", "generating OAuth token");

        // librdkafka invokes this callback from one of its own threads, not a
        // tokio thread, so it is acceptable to block on the async token
        // generation here.
        let generate = || {
            let Some(sdk_config) = &self.aws_config else {
                bail!("internal error: AWS configuration missing");
            };

            self.runtime.block_on(aws::generate_auth_token(sdk_config))
        };

        match generate() {
            Ok((token, lifetime_ms)) => {
                info!(target: "librdkafka", %lifetime_ms, "successfully generated OAuth token");
                trace!(target: "librdkafka", %token);
                Ok(OAuthToken {
                    token,
                    lifetime_ms,
                    principal_name: "".to_string(),
                })
            }
            Err(e) => {
                warn!(target: "librdkafka", "failed to generate OAuth token: {e:#}");
                Err(e.into())
            }
        }
    }

    fn resolve_broker_addr(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>, io::Error> {
        let return_rewrite = |rewrite: &BrokerRewriteHandle| -> Result<Vec<SocketAddr>, io::Error> {
            let rewrite = match rewrite {
                BrokerRewriteHandle::Simple(rewrite) => rewrite.clone(),
                BrokerRewriteHandle::SshTunnel(ssh_tunnel) => {
                    // Route the connection through the tunnel's local
                    // listener.
                    let addr = ssh_tunnel.local_addr();
                    BrokerRewrite {
                        host: addr.ip().to_string(),
                        port: Some(addr.port()),
                    }
                }
                BrokerRewriteHandle::FailedDefaultSshTunnel(_) => {
                    // Failed default tunnels are handled by the caller below
                    // and never passed to `return_rewrite`.
                    unreachable!()
                }
            };
            let rewrite_port = rewrite.port.unwrap_or(port);

            info!(
                "rewriting broker {}:{} to {}:{}",
                host, port, rewrite.host, rewrite_port
            );

            (rewrite.host, rewrite_port)
                .to_socket_addrs()
                .map(|addrs| addrs.collect())
        };

        let addr = BrokerAddr {
            host: host.into(),
            port,
        };
        let rewrite = self.rewrites.lock().expect("poisoned").get(&addr).cloned();

        match rewrite {
            None | Some(BrokerRewriteHandle::FailedDefaultSshTunnel(_)) => {
                match &self.default_tunnel {
                    TunnelConfig::Ssh(default_tunnel) => {
                        // Set up an SSH tunnel to this broker on the fly. As
                        // above, librdkafka invokes this callback from its own
                        // thread, so blocking on the runtime is acceptable.
                        let ssh_tunnel = self.runtime.block_on(async {
                            self.ssh_tunnel_manager
                                .connect(
                                    default_tunnel.clone(),
                                    host,
                                    port,
                                    self.ssh_timeout_config,
                                    self.in_task,
                                )
                                .await
                        });
                        match ssh_tunnel {
                            Ok(ssh_tunnel) => {
                                let mut rewrites = self.rewrites.lock().expect("poisoned");
                                let rewrite = match rewrites.entry(addr.clone()) {
                                    btree_map::Entry::Occupied(mut o)
                                        if matches!(
                                            o.get(),
                                            BrokerRewriteHandle::FailedDefaultSshTunnel(_)
                                        ) =>
                                    {
                                        o.insert(BrokerRewriteHandle::SshTunnel(
                                            ssh_tunnel.clone(),
                                        ));
                                        o.into_mut()
                                    }
                                    btree_map::Entry::Occupied(o) => o.into_mut(),
                                    btree_map::Entry::Vacant(v) => {
                                        v.insert(BrokerRewriteHandle::SshTunnel(ssh_tunnel.clone()))
                                    }
                                };

                                return_rewrite(rewrite)
                            }
                            Err(e) => {
                                warn!(
                                    "failed to create ssh tunnel for {:?}: {}",
                                    addr,
                                    e.display_with_causes()
                                );

                                // Record the failure so that `tunnel_status`
                                // reports it.
                                let mut rewrites = self.rewrites.lock().expect("poisoned");
                                rewrites.entry(addr.clone()).or_insert_with(|| {
                                    BrokerRewriteHandle::FailedDefaultSshTunnel(
                                        e.to_string_with_causes(),
                                    )
                                });

                                Err(io::Error::new(
                                    io::ErrorKind::Other,
                                    "creating SSH tunnel failed",
                                ))
                            }
                        }
                    }
                    TunnelConfig::StaticHost(host) => (host.as_str(), port)
                        .to_socket_addrs()
                        .map(|addrs| addrs.collect()),
                    TunnelConfig::None => {
                        (host, port).to_socket_addrs().map(|addrs| addrs.collect())
                    }
                }
            }
            Some(rewrite) => return_rewrite(&rewrite),
        }
    }

    fn log(&self, level: RDKafkaLogLevel, fac: &str, log_message: &str) {
        self.inner.log(level, fac, log_message)
    }

    fn error(&self, error: KafkaError, reason: &str) {
        self.inner.error(error, reason)
    }

    fn stats(&self, statistics: Statistics) {
        self.inner.stats(statistics)
    }

    fn stats_raw(&self, statistics: &[u8]) {
        self.inner.stats_raw(statistics)
    }
}

impl<C> ConsumerContext for TunnelingClientContext<C>
where
    C: ConsumerContext,
{
    fn rebalance(
        &self,
        native_client: &NativeClient,
        err: RDKafkaRespErr,
        tpl: &mut TopicPartitionList,
    ) {
        self.inner.rebalance(native_client, err, tpl)
    }

    fn pre_rebalance<'a>(&self, rebalance: &Rebalance<'a>) {
        self.inner.pre_rebalance(rebalance)
    }

    fn post_rebalance<'a>(&self, rebalance: &Rebalance<'a>) {
        self.inner.post_rebalance(rebalance)
    }

    fn commit_callback(&self, result: KafkaResult<()>, offsets: &TopicPartitionList) {
        self.inner.commit_callback(result, offsets)
    }

    fn main_queue_min_poll_interval(&self) -> Timeout {
        self.inner.main_queue_min_poll_interval()
    }
}

impl<C> ProducerContext for TunnelingClientContext<C>
where
    C: ProducerContext,
{
    type DeliveryOpaque = C::DeliveryOpaque;

    fn delivery(
        &self,
        delivery_result: &DeliveryResult<'_>,
        delivery_opaque: Self::DeliveryOpaque,
    ) {
        self.inner.delivery(delivery_result, delivery_opaque)
    }
}

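/// The ID of a partition within a Kafka topic.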
pub type PartitionId = i32;

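/// An error returned by [`get_partitions`].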
#[derive(Debug, thiserror::Error)]
pub enum GetPartitionsError {
    #[error("Topic does not exist")]
    TopicDoesNotExist,
    #[error(transparent)]
    Kafka(#[from] KafkaError),
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

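/// Fetches the partition IDs of the named topic using the given client.
///
/// Illustrative call (the `consumer` handle and topic name are hypothetical):
///
/// ```ignore
/// let partitions = get_partitions(consumer.client(), "my-topic", Duration::from_secs(10))?;
/// ```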
pub fn get_partitions<C: ClientContext>(
    client: &Client<C>,
    topic: &str,
    timeout: Duration,
) -> Result<Vec<PartitionId>, GetPartitionsError> {
    let meta = client.fetch_metadata(Some(topic), timeout)?;
    if meta.topics().len() != 1 {
        Err(anyhow!(
            "topic {} has {} metadata entries; expected 1",
            topic,
            meta.topics().len()
        ))?;
    }

    fn check_err(err: Option<RDKafkaRespErr>) -> Result<(), GetPartitionsError> {
        match err.map(RDKafkaErrorCode::from) {
            Some(RDKafkaErrorCode::UnknownTopic | RDKafkaErrorCode::UnknownTopicOrPartition) => {
                Err(GetPartitionsError::TopicDoesNotExist)
            }
            Some(code) => Err(anyhow!(code))?,
            None => Ok(()),
        }
    }

    let meta_topic = meta.topics().into_element();
    check_err(meta_topic.error())?;

    if meta_topic.name() != topic {
        Err(anyhow!(
            "got results for wrong topic {} (expected {})",
            meta_topic.name(),
            topic
        ))?;
    }

    let mut partition_ids = Vec::with_capacity(meta_topic.partitions().len());
    for partition_meta in meta_topic.partitions() {
        check_err(partition_meta.error())?;

        partition_ids.push(partition_meta.id());
    }

    if partition_ids.is_empty() {
        Err(GetPartitionsError::TopicDoesNotExist)?;
    }

    Ok(partition_ids)
}

/// The default value for `socket.keepalive.enable`.
pub const DEFAULT_KEEPALIVE: bool = true;
/// The default value for `socket.timeout.ms`.
pub const DEFAULT_SOCKET_TIMEOUT: Duration = Duration::from_secs(60);
/// The default value for `transaction.timeout.ms`.
pub const DEFAULT_TRANSACTION_TIMEOUT: Duration = Duration::from_secs(600);
/// The default value for `socket.connection.setup.timeout.ms`.
pub const DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT: Duration = Duration::from_secs(30);
/// The default timeout for fetching topic metadata.
pub const DEFAULT_FETCH_METADATA_TIMEOUT: Duration = Duration::from_secs(10);
/// The default timeout for fetching progress records.
pub const DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT: Duration = Duration::from_secs(90);

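/// Timeout settings applied to Kafka clients created with
/// [`create_new_client_config`].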
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TimeoutConfig {
    /// Whether to enable TCP keepalive.
    pub keepalive: bool,
    /// The value for `socket.timeout.ms`.
    pub socket_timeout: Duration,
    /// The value for `transaction.timeout.ms`.
    pub transaction_timeout: Duration,
    /// The value for `socket.connection.setup.timeout.ms`.
    pub socket_connection_setup_timeout: Duration,
    /// The timeout to use when fetching topic metadata.
    pub fetch_metadata_timeout: Duration,
    /// The timeout to use when fetching progress records.
    pub progress_record_fetch_timeout: Duration,
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        TimeoutConfig {
            keepalive: DEFAULT_KEEPALIVE,
            socket_timeout: DEFAULT_SOCKET_TIMEOUT,
            transaction_timeout: DEFAULT_TRANSACTION_TIMEOUT,
            socket_connection_setup_timeout: DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT,
            fetch_metadata_timeout: DEFAULT_FETCH_METADATA_TIMEOUT,
            progress_record_fetch_timeout: DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT,
        }
    }
}

impl TimeoutConfig {
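    /// Builds a `TimeoutConfig` from the given values, clamping each value to
    /// the range librdkafka supports and logging an error when a value is out
    /// of range. `None` values fall back to defaults derived from the
    /// transaction timeout.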
    pub fn build(
        keepalive: bool,
        socket_timeout: Option<Duration>,
        transaction_timeout: Duration,
        socket_connection_setup_timeout: Duration,
        fetch_metadata_timeout: Duration,
        progress_record_fetch_timeout: Option<Duration>,
    ) -> TimeoutConfig {
        // librdkafka requires `transaction.timeout.ms` to fit in an `i32` and
        // to be at least 1000ms.
        let transaction_timeout = if transaction_timeout.as_millis() > i32::MAX.try_into().unwrap()
        {
            error!(
                "transaction_timeout ({transaction_timeout:?}) greater than max \
                of {}ms, defaulting to the default of {DEFAULT_TRANSACTION_TIMEOUT:?}",
                i32::MAX
            );
            DEFAULT_TRANSACTION_TIMEOUT
        } else if transaction_timeout.as_millis() < 1000 {
            error!(
                "transaction_timeout ({transaction_timeout:?}) less than min \
                of 1000ms, defaulting to the default of {DEFAULT_TRANSACTION_TIMEOUT:?}"
            );
            DEFAULT_TRANSACTION_TIMEOUT
        } else {
            transaction_timeout
        };

        let progress_record_fetch_timeout_derived_default =
            std::cmp::max(transaction_timeout, DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT);
        let progress_record_fetch_timeout =
            progress_record_fetch_timeout.unwrap_or(progress_record_fetch_timeout_derived_default);
        let progress_record_fetch_timeout = if progress_record_fetch_timeout < transaction_timeout {
            error!(
                "progress record fetch timeout ({progress_record_fetch_timeout:?}) less than transaction \
                timeout ({transaction_timeout:?}), defaulting to transaction timeout {transaction_timeout:?}",
            );
            transaction_timeout
        } else {
            progress_record_fetch_timeout
        };

        // librdkafka requires `socket.timeout.ms` to be within 10ms..300s and
        // no more than 100ms greater than `transaction.timeout.ms`.
        let max_socket_timeout = std::cmp::min(
            transaction_timeout + Duration::from_millis(100),
            Duration::from_secs(300),
        );
        let socket_timeout_derived_default =
            std::cmp::min(max_socket_timeout, DEFAULT_SOCKET_TIMEOUT);
        let socket_timeout = socket_timeout.unwrap_or(socket_timeout_derived_default);
        let socket_timeout = if socket_timeout > max_socket_timeout {
            error!(
                "socket_timeout ({socket_timeout:?}) greater than max \
                of min(300000, transaction.timeout.ms + 100 ({})), \
                defaulting to the maximum of {max_socket_timeout:?}",
                transaction_timeout.as_millis() + 100
            );
            max_socket_timeout
        } else if socket_timeout.as_millis() < 10 {
            error!(
                "socket_timeout ({socket_timeout:?}) less than min \
                of 10ms, defaulting to the default of {socket_timeout_derived_default:?}"
            );
            socket_timeout_derived_default
        } else {
            socket_timeout
        };

        let socket_connection_setup_timeout =
            if socket_connection_setup_timeout.as_millis() > i32::MAX.try_into().unwrap() {
                error!(
                    "socket_connection_setup_timeout ({socket_connection_setup_timeout:?}) \
                    greater than max of {}ms, defaulting to the default \
                    of {DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT:?}",
                    i32::MAX,
                );
                DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT
            } else if socket_connection_setup_timeout.as_millis() < 10 {
                error!(
                    "socket_connection_setup_timeout ({socket_connection_setup_timeout:?}) \
                    less than min of 10ms, defaulting to the default of \
                    {DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT:?}"
                );
                DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT
            } else {
                socket_connection_setup_timeout
            };

        TimeoutConfig {
            keepalive,
            socket_timeout,
            transaction_timeout,
            socket_connection_setup_timeout,
            fetch_metadata_timeout,
            progress_record_fetch_timeout,
        }
    }
}

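/// Creates a new `ClientConfig` with an `INFO` log level and the default
/// timeout configuration.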
pub fn create_new_client_config_simple() -> ClientConfig {
    create_new_client_config(tracing::Level::INFO, Default::default())
}

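/// Creates a new `ClientConfig` whose librdkafka log level is derived from
/// the given tracing level and whose timeouts come from `timeout_config`.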
pub fn create_new_client_config(
    tracing_level: Level,
    timeout_config: TimeoutConfig,
) -> ClientConfig {
    #[allow(clippy::disallowed_methods)]
    let mut config = ClientConfig::new();

    // Forward librdkafka logs at a level derived from the active tracing
    // level.
    let level = if tracing_level >= Level::DEBUG {
        RDKafkaLogLevel::Debug
    } else if tracing_level >= Level::INFO {
        RDKafkaLogLevel::Info
    } else if tracing_level >= Level::WARN {
        RDKafkaLogLevel::Warning
    } else {
        RDKafkaLogLevel::Error
    };
    tracing::debug!(target: "librdkafka", level = ?level, "Determined log level for librdkafka");
    config.set_log_level(level);

    // librdkafka's debug logs are extremely verbose, so enable them only when
    // tracing is at DEBUG or finer.
    if tracing_level >= Level::DEBUG {
        tracing::debug!(target: "librdkafka", "Enabling debug logs for rdkafka");
        config.set("debug", "all");
    }

    if timeout_config.keepalive {
        config.set("socket.keepalive.enable", "true");
    }

    config.set(
        "socket.timeout.ms",
        timeout_config.socket_timeout.as_millis().to_string(),
    );
    config.set(
        "transaction.timeout.ms",
        timeout_config.transaction_timeout.as_millis().to_string(),
    );
    config.set(
        "socket.connection.setup.timeout.ms",
        timeout_config
            .socket_connection_setup_timeout
            .as_millis()
            .to_string(),
    );

    config
}