// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Helpers for working with Kafka's client API.

use anyhow::bail;
use aws_config::SdkConfig;
use fancy_regex::Regex;
use std::collections::{BTreeMap, btree_map};
use std::error::Error;
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};
use std::str::FromStr;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use tokio::sync::watch;

use anyhow::{Context, anyhow};
use crossbeam::channel::{Receiver, Sender, unbounded};
use mz_ore::collections::CollectionExt;
use mz_ore::error::ErrorExt;
use mz_ore::future::InTask;
use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig, SshTunnelStatus};
use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
use rdkafka::client::{Client, NativeClient, OAuthToken};
use rdkafka::config::{ClientConfig, RDKafkaLogLevel};
use rdkafka::consumer::{ConsumerContext, Rebalance};
use rdkafka::error::{KafkaError, KafkaResult, RDKafkaErrorCode};
use rdkafka::producer::{DefaultProducerContext, DeliveryResult, ProducerContext};
use rdkafka::types::RDKafkaRespErr;
use rdkafka::util::Timeout;
use rdkafka::{ClientContext, Statistics, TopicPartitionList};
use serde::{Deserialize, Serialize};
use tokio::runtime::Handle;
use tracing::{Level, debug, error, info, trace, warn};

use crate::aws;

/// A reasonable default interval at which to refresh topic metadata. This is
/// configurable at the source level.
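///
/// # Example
///
/// A minimal sketch (not part of the original source) of applying this
/// interval to a client config via librdkafka's standard
/// `topic.metadata.refresh.interval.ms` option:
///
/// ```ignore
/// let mut config = create_new_client_config_simple();
/// config.set(
///     "topic.metadata.refresh.interval.ms",
///     DEFAULT_TOPIC_METADATA_REFRESH_INTERVAL.as_millis().to_string(),
/// );
/// ```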
// 30s may seem infrequent, but the default is 5m. More frequent metadata
// refresh rates are surprising to Kafka users, as topic partition counts hardly
// ever change in production.
pub const DEFAULT_TOPIC_METADATA_REFRESH_INTERVAL: Duration = Duration::from_secs(30);

/// A `ClientContext` implementation that uses `tracing` instead of `log`
/// macros.
///
/// All code in Materialize that constructs Kafka clients should use this
/// context or a custom context that delegates the `log` and `error` methods to
/// this implementation.
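///
/// # Example
///
/// A minimal sketch (not part of the original source) of building a consumer
/// with this context; `config` is assumed to be an already-populated
/// [`ClientConfig`]:
///
/// ```ignore
/// use rdkafka::consumer::StreamConsumer;
///
/// let consumer: StreamConsumer<MzClientContext> =
///     config.create_with_context(MzClientContext::default())?;
/// ```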
pub struct MzClientContext {
    /// Forwards errors observed in librdkafka logs to any interested receiver.
    error_tx: Sender<MzKafkaError>,
    /// A tokio watch that retains the last statistics received by rdkafka and provides async
    /// notifications to anyone interested in subscribing.
    statistics_tx: watch::Sender<Statistics>,
}

impl Default for MzClientContext {
    fn default() -> Self {
        Self::with_errors().0
    }
}

impl MzClientContext {
    /// Constructs a new client context and returns an mpsc `Receiver` that can be used to learn
    /// about librdkafka errors.
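    ///
    /// # Example
    ///
    /// A minimal sketch (not part of the original source) of draining the
    /// error channel after handing the context to a client:
    ///
    /// ```ignore
    /// let (ctx, error_rx) = MzClientContext::with_errors();
    /// // ... build a consumer or producer with `ctx` ...
    /// while let Ok(err) = error_rx.try_recv() {
    ///     tracing::warn!("kafka error: {err}");
    /// }
    /// ```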
    // `crossbeam` channel receivers can be cloned, but this is intended to be used as an mpsc
    // channel, until we upgrade to `1.72`, where the std mpsc sender is `Sync`.
    pub fn with_errors() -> (Self, Receiver<MzKafkaError>) {
        let (error_tx, error_rx) = unbounded();
        let (statistics_tx, _) = watch::channel(Default::default());
        let ctx = Self {
            error_tx,
            statistics_tx,
        };
        (ctx, error_rx)
    }

    /// Creates a tokio watch subscription for statistics reported by librdkafka. The
    /// `statistics.interval.ms` option must be set for this stream to contain any values.
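    ///
    /// # Example
    ///
    /// A minimal sketch (not part of the original source) of awaiting
    /// statistics updates from a context `ctx`:
    ///
    /// ```ignore
    /// let mut stats_rx = ctx.subscribe_statistics();
    /// while stats_rx.changed().await.is_ok() {
    ///     tracing::info!("librdkafka stats: {:?}", *stats_rx.borrow());
    /// }
    /// ```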
    pub fn subscribe_statistics(&self) -> watch::Receiver<Statistics> {
        self.statistics_tx.subscribe()
    }

    fn record_error(&self, msg: &str) {
        let err = match MzKafkaError::from_str(msg) {
            Ok(err) => err,
            Err(()) => {
                warn!(original_error = msg, "failed to parse kafka error");
                MzKafkaError::Internal(msg.to_owned())
            }
        };
        // If no one cares about errors we drop them on the floor
        let _ = self.error_tx.send(err);
    }
}

/// A structured error type for errors reported by librdkafka through its logs.
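///
/// # Example
///
/// A minimal sketch (not part of the original source) of classifying a raw
/// librdkafka log message:
///
/// ```ignore
/// use std::str::FromStr;
///
/// let err = MzKafkaError::from_str("SASL authentication error: Authentication failed")
///     .expect("recognized message");
/// assert_eq!(err, MzKafkaError::SaslAuthenticationFailed);
/// ```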
#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
pub enum MzKafkaError {
    /// Invalid username or password
    #[error("Invalid username or password")]
    InvalidCredentials,
    /// Missing CA certificate
    #[error("Invalid CA certificate")]
    InvalidCACertificate,
    /// Broker might require SSL encryption
    #[error("Disconnected during handshake; broker might require SSL encryption")]
    SSLEncryptionMaybeRequired,
    /// Broker does not support SSL connections
    #[error("Broker does not support SSL connections")]
    SSLUnsupported,
    /// Broker did not provide a certificate
    #[error("Broker did not provide a certificate")]
    BrokerCertificateMissing,
    /// Failed to verify broker certificate
    #[error("Failed to verify broker certificate")]
    InvalidBrokerCertificate,
    /// Connection reset
    #[error("Connection reset: {0}")]
    ConnectionReset(String),
    /// Connection timeout
    #[error("Connection timeout")]
    ConnectionTimeout,
    /// Failed to resolve hostname
    #[error("Failed to resolve hostname")]
    HostnameResolutionFailed,
    /// Unsupported SASL mechanism
    #[error("Unsupported SASL mechanism")]
    UnsupportedSASLMechanism,
    /// Unsupported broker version
    #[error("Unsupported broker version")]
    UnsupportedBrokerVersion,
    /// Connection to broker failed
    #[error("Broker transport failure")]
    BrokerTransportFailure,
    /// All brokers down
    #[error("All brokers down")]
    AllBrokersDown,
    /// SASL authentication required
    #[error("SASL authentication required")]
    SaslAuthenticationRequired,
    /// SASL authentication failed
    #[error("SASL authentication failed")]
    SaslAuthenticationFailed,
    /// SSL authentication required
    #[error("SSL authentication required")]
    SslAuthenticationRequired,
    /// Unknown topic or partition
    #[error("Unknown topic or partition")]
    UnknownTopicOrPartition,
    /// An internal kafka error
    #[error("Internal kafka error: {0}")]
    Internal(String),
}

impl FromStr for MzKafkaError {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.contains("Authentication failed: Invalid username or password") {
            Ok(Self::InvalidCredentials)
        } else if s.contains("broker certificate could not be verified") {
            Ok(Self::InvalidCACertificate)
        } else if s.contains("connecting to a SSL listener?") {
            Ok(Self::SSLEncryptionMaybeRequired)
        } else if s.contains("client SSL authentication might be required") {
            Ok(Self::SslAuthenticationRequired)
        } else if s.contains("connecting to a PLAINTEXT broker listener") {
            Ok(Self::SSLUnsupported)
        } else if s.contains("Broker did not provide a certificate") {
            Ok(Self::BrokerCertificateMissing)
        } else if s.contains("Failed to verify broker certificate: ") {
            Ok(Self::InvalidBrokerCertificate)
        } else if let Some((_prefix, inner)) = s.split_once("Send failed: ") {
            Ok(Self::ConnectionReset(inner.to_owned()))
        } else if let Some((_prefix, inner)) = s.split_once("Receive failed: ") {
            Ok(Self::ConnectionReset(inner.to_owned()))
        } else if s.contains("request(s) timed out: disconnect") {
            Ok(Self::ConnectionTimeout)
        } else if s.contains("Failed to resolve") {
            Ok(Self::HostnameResolutionFailed)
        } else if s.contains("mechanism handshake failed:") {
            Ok(Self::UnsupportedSASLMechanism)
        } else if s.contains(
            "verify that security.protocol is correctly configured, \
            broker might require SASL authentication",
        ) {
            Ok(Self::SaslAuthenticationRequired)
        } else if s.contains("SASL authentication error: Authentication failed") {
            Ok(Self::SaslAuthenticationFailed)
        } else if s
            .contains("incorrect security.protocol configuration (connecting to a SSL listener?)")
        {
            Ok(Self::SslAuthenticationRequired)
        } else if s.contains("probably due to broker version < 0.10") {
            Ok(Self::UnsupportedBrokerVersion)
        } else if s.contains("Disconnected while requesting ApiVersion")
            || s.contains("Broker transport failure")
            || s.contains("Connection refused")
        {
            Ok(Self::BrokerTransportFailure)
        } else if Regex::new(r"(\d+)/\1 brokers are down")
            .unwrap()
            .is_match(s)
            .unwrap_or_default()
        {
            Ok(Self::AllBrokersDown)
        } else if s.contains("Unknown topic or partition") || s.contains("Unknown partition") {
            Ok(Self::UnknownTopicOrPartition)
        } else {
            Err(())
        }
    }
}

impl ClientContext for MzClientContext {
    fn log(&self, level: rdkafka::config::RDKafkaLogLevel, fac: &str, log_message: &str) {
        use rdkafka::config::RDKafkaLogLevel::*;

        // Sniff out log messages that indicate errors.
        //
        // We consider any event at error, critical, alert, or emergency level,
        // for self-explanatory reasons. We also consider any event with a
        // facility of `FAIL`. librdkafka often uses info or warn level for
        // these `FAIL` events, but as they always indicate a failure to connect
        // to a broker we want to always treat them as errors.
        if matches!(level, Emerg | Alert | Critical | Error) || fac == "FAIL" {
            self.record_error(log_message);
        }

        // Copied from https://docs.rs/rdkafka/0.28.0/src/rdkafka/client.rs.html#58-79
        // but using `tracing`
        match level {
            Emerg | Alert | Critical | Error => {
                // We downgrade error messages to `warn!` level to avoid
                // sending the errors to Sentry. Most errors are customer
                // configuration problems that are not appropriate to send to
                // Sentry.
                warn!(target: "librdkafka", "error: {} {}", fac, log_message);
            }
            Warning => warn!(target: "librdkafka", "warning: {} {}", fac, log_message),
            Notice => info!(target: "librdkafka", "{} {}", fac, log_message),
            Info => info!(target: "librdkafka", "{} {}", fac, log_message),
            Debug => debug!(target: "librdkafka", "{} {}", fac, log_message),
        }
    }

    fn stats(&self, statistics: Statistics) {
        self.statistics_tx.send_replace(statistics);
    }

    fn error(&self, error: KafkaError, reason: &str) {
        self.record_error(reason);
        // Refer to the comment in the `log` callback.
        warn!(target: "librdkafka", "error: {}: {}", error, reason);
    }
}

impl ConsumerContext for MzClientContext {}

impl ProducerContext for MzClientContext {
    type DeliveryOpaque = <DefaultProducerContext as ProducerContext>::DeliveryOpaque;
    fn delivery(
        &self,
        delivery_result: &DeliveryResult<'_>,
        delivery_opaque: Self::DeliveryOpaque,
    ) {
        DefaultProducerContext.delivery(delivery_result, delivery_opaque);
    }
}

/// The address of a Kafka broker.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct BrokerAddr {
    /// The broker's hostname.
    pub host: String,
    /// The broker's port.
    pub port: u16,
}

/// Rewrites a broker address.
///
/// For use with [`TunnelingClientContext`].
#[derive(Debug, Clone)]
pub struct BrokerRewrite {
    /// The rewritten hostname.
    pub host: String,
    /// The rewritten port.
    ///
    /// If unspecified, the broker's original port is left unchanged.
    pub port: Option<u16>,
}

#[derive(Clone)]
enum BrokerRewriteHandle {
    Simple(BrokerRewrite),
    SshTunnel(
        // This ensures the ssh tunnel is not shut down.
        ManagedSshTunnelHandle,
    ),
    /// For _default_ ssh tunnels, we store an error if _creation_
    /// of the tunnel failed, so that `tunnel_status` can return it.
    FailedDefaultSshTunnel(String),
}

/// Configures how a [`TunnelingClientContext`] rewrites broker hosts and
/// ports.
#[derive(Clone)]
pub enum TunnelConfig {
    /// Tunnel connections through the given SSH tunnel configuration.
    Ssh(SshTunnelConfig),
    /// Rewrites internal hosts to the given value; used for PrivateLink.
    StaticHost(String),
    /// Performs no rewrites.
    None,
}

/// A client context that supports rewriting broker addresses.
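///
/// # Example
///
/// A minimal sketch (not part of the original source) of wiring this context
/// up; `ssh_tunnel_manager`, `ssh_timeout_config`, and the `InTask::No` choice
/// are assumptions standing in for application-provided values:
///
/// ```ignore
/// let ctx = TunnelingClientContext::new(
///     MzClientContext::default(),
///     tokio::runtime::Handle::current(),
///     ssh_tunnel_manager,
///     ssh_timeout_config,
///     None, // no AWS config, so AWS IAM authentication is unavailable
///     InTask::No,
/// );
/// ```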
#[derive(Clone)]
pub struct TunnelingClientContext<C> {
    inner: C,
    rewrites: Arc<Mutex<BTreeMap<BrokerAddr, BrokerRewriteHandle>>>,
    default_tunnel: TunnelConfig,
    in_task: InTask,
    ssh_tunnel_manager: SshTunnelManager,
    ssh_timeout_config: SshTimeoutConfig,
    aws_config: Option<SdkConfig>,
    runtime: Handle,
}

impl<C> TunnelingClientContext<C> {
    /// Constructs a new context that wraps `inner`.
    pub fn new(
        inner: C,
        runtime: Handle,
        ssh_tunnel_manager: SshTunnelManager,
        ssh_timeout_config: SshTimeoutConfig,
        aws_config: Option<SdkConfig>,
        in_task: InTask,
    ) -> TunnelingClientContext<C> {
        TunnelingClientContext {
            inner,
            rewrites: Arc::new(Mutex::new(BTreeMap::new())),
            default_tunnel: TunnelConfig::None,
            in_task,
            ssh_tunnel_manager,
            ssh_timeout_config,
            aws_config,
            runtime,
        }
    }

    /// Sets the default tunnel configuration.
    ///
    /// Connections to brokers that aren't covered by a specific rewrite or
    /// tunnel will use this tunnel instead.
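    ///
    /// # Example
    ///
    /// A minimal sketch (not part of the original source), e.g. for a
    /// PrivateLink-style setup with a hypothetical endpoint address:
    ///
    /// ```ignore
    /// ctx.set_default_tunnel(TunnelConfig::StaticHost("vpce-123.example.com".into()));
    /// ```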
    pub fn set_default_tunnel(&mut self, tunnel: TunnelConfig) {
        self.default_tunnel = tunnel;
    }

    /// Adds an SSH tunnel for a specific broker.
    ///
    /// Overrides the existing SSH tunnel or rewrite for this broker, if any.
    ///
    /// Holding a handle to the tunnel allows the rewrite to evolve over time:
    /// for example, the ssh tunnel's address changes if the tunnel fails and
    /// restarts.
    pub async fn add_ssh_tunnel(
        &self,
        broker: BrokerAddr,
        tunnel: SshTunnelConfig,
    ) -> Result<(), anyhow::Error> {
        let ssh_tunnel = self
            .ssh_tunnel_manager
            .connect(
                tunnel,
                &broker.host,
                broker.port,
                self.ssh_timeout_config,
                self.in_task,
            )
            .await
            .context("creating ssh tunnel")?;

        let mut rewrites = self.rewrites.lock().expect("poisoned");
        rewrites.insert(broker, BrokerRewriteHandle::SshTunnel(ssh_tunnel));
        Ok(())
    }

    /// Adds a broker rewrite rule.
    ///
    /// Overrides the existing SSH tunnel or rewrite for this broker, if any.
    ///
    /// `rewrite` is a `BrokerRewrite` that specifies how to rewrite the address for `broker`.
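    ///
    /// # Example
    ///
    /// A minimal sketch (not part of the original source) that redirects a
    /// broker to a local forwarder while keeping the original port:
    ///
    /// ```ignore
    /// ctx.add_broker_rewrite(
    ///     BrokerAddr { host: "broker1.example.com".into(), port: 9092 },
    ///     BrokerRewrite { host: "127.0.0.1".into(), port: None },
    /// );
    /// ```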
    pub fn add_broker_rewrite(&self, broker: BrokerAddr, rewrite: BrokerRewrite) {
        let mut rewrites = self.rewrites.lock().expect("poisoned");
        rewrites.insert(broker, BrokerRewriteHandle::Simple(rewrite));
    }

    /// Returns a reference to the wrapped context.
    pub fn inner(&self) -> &C {
        &self.inner
    }

    /// Returns a _consolidated_ `SshTunnelStatus` that communicates the status
    /// of all active ssh tunnels `self` knows about.
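    ///
    /// # Example
    ///
    /// A minimal sketch (not part of the original source) of surfacing the
    /// consolidated status:
    ///
    /// ```ignore
    /// if let SshTunnelStatus::Errored(e) = ctx.tunnel_status() {
    ///     tracing::warn!("ssh tunnel error: {e}");
    /// }
    /// ```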
    pub fn tunnel_status(&self) -> SshTunnelStatus {
        self.rewrites
            .lock()
            .expect("poisoned")
            .values()
            .map(|handle| match handle {
                BrokerRewriteHandle::SshTunnel(s) => s.check_status(),
                BrokerRewriteHandle::FailedDefaultSshTunnel(e) => {
                    SshTunnelStatus::Errored(e.clone())
                }
                BrokerRewriteHandle::Simple(_) => SshTunnelStatus::Running,
            })
            .fold(SshTunnelStatus::Running, |acc, status| {
                match (acc, status) {
                    (SshTunnelStatus::Running, SshTunnelStatus::Errored(e))
                    | (SshTunnelStatus::Errored(e), SshTunnelStatus::Running) => {
                        SshTunnelStatus::Errored(e)
                    }
                    (SshTunnelStatus::Errored(err), SshTunnelStatus::Errored(e)) => {
                        SshTunnelStatus::Errored(format!("{}, {}", err, e))
                    }
                    (SshTunnelStatus::Running, SshTunnelStatus::Running) => {
                        SshTunnelStatus::Running
                    }
                }
            })
    }
}

impl<C> ClientContext for TunnelingClientContext<C>
where
    C: ClientContext,
{
    const ENABLE_REFRESH_OAUTH_TOKEN: bool = true;

    fn generate_oauth_token(
        &self,
        _oauthbearer_config: Option<&str>,
    ) -> Result<OAuthToken, Box<dyn Error>> {
        // NOTE(benesch): We abuse the `TunnelingClientContext` to handle AWS
        // IAM authentication because it's used in exactly the right places and
        // already has a handle to the Tokio runtime. It might be slightly
        // cleaner to have a separate `AwsIamAuthenticatingClientContext`, but
        // that would be quite a bit of additional plumbing.

        // NOTE(benesch): at the moment, the only OAUTHBEARER authentication we
        // support is AWS IAM, so we can assume that if this method is invoked
        // AWS IAM is desired. We may need to generalize this in the future.

        info!(target: "librdkafka", "generating OAuth token");

        let generate = || {
            let Some(sdk_config) = &self.aws_config else {
                bail!("internal error: AWS configuration missing");
            };

            self.runtime.block_on(aws::generate_auth_token(sdk_config))
        };

        match generate() {
            Ok((token, lifetime_ms)) => {
                info!(target: "librdkafka", %lifetime_ms, "successfully generated OAuth token");
                trace!(target: "librdkafka", %token);
                Ok(OAuthToken {
                    token,
                    lifetime_ms,
                    principal_name: "".to_string(),
                })
            }
            Err(e) => {
                warn!(target: "librdkafka", "failed to generate OAuth token: {e:#}");
                Err(e.into())
            }
        }
    }

    fn resolve_broker_addr(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>, io::Error> {
        let return_rewrite = |rewrite: &BrokerRewriteHandle| -> Result<Vec<SocketAddr>, io::Error> {
            let rewrite = match rewrite {
                BrokerRewriteHandle::Simple(rewrite) => rewrite.clone(),
                BrokerRewriteHandle::SshTunnel(ssh_tunnel) => {
                    // The port for this can change over time, as the ssh tunnel is maintained through
                    // errors.
                    let addr = ssh_tunnel.local_addr();
                    BrokerRewrite {
                        host: addr.ip().to_string(),
                        port: Some(addr.port()),
                    }
                }
                BrokerRewriteHandle::FailedDefaultSshTunnel(_) => {
                    unreachable!()
                }
            };
            let rewrite_port = rewrite.port.unwrap_or(port);

            info!(
                "rewriting broker {}:{} to {}:{}",
                host, port, rewrite.host, rewrite_port
            );

            (rewrite.host, rewrite_port)
                .to_socket_addrs()
                .map(|addrs| addrs.collect())
        };

        let addr = BrokerAddr {
            host: host.into(),
            port,
        };
        let rewrite = self.rewrites.lock().expect("poisoned").get(&addr).cloned();

        match rewrite {
            None | Some(BrokerRewriteHandle::FailedDefaultSshTunnel(_)) => {
                match &self.default_tunnel {
                    TunnelConfig::Ssh(default_tunnel) => {
                        // Multiple users could all run `connect` at the same time; only one ssh
                        // tunnel will ever be connected, and only one will be inserted into the
                        // map.
                        let ssh_tunnel = self.runtime.block_on(async {
                            self.ssh_tunnel_manager
                                .connect(
                                    default_tunnel.clone(),
                                    host,
                                    port,
                                    self.ssh_timeout_config,
                                    self.in_task,
                                )
                                .await
                        });
                        match ssh_tunnel {
                            Ok(ssh_tunnel) => {
                                let mut rewrites = self.rewrites.lock().expect("poisoned");
                                let rewrite = match rewrites.entry(addr.clone()) {
                                    btree_map::Entry::Occupied(mut o)
                                        if matches!(
                                            o.get(),
                                            BrokerRewriteHandle::FailedDefaultSshTunnel(_)
                                        ) =>
                                    {
                                        o.insert(BrokerRewriteHandle::SshTunnel(
                                            ssh_tunnel.clone(),
                                        ));
                                        o.into_mut()
                                    }
                                    btree_map::Entry::Occupied(o) => o.into_mut(),
                                    btree_map::Entry::Vacant(v) => {
                                        v.insert(BrokerRewriteHandle::SshTunnel(ssh_tunnel.clone()))
                                    }
                                };

                                return_rewrite(rewrite)
                            }
                            Err(e) => {
                                warn!(
                                    "failed to create ssh tunnel for {:?}: {}",
                                    addr,
                                    e.display_with_causes()
                                );

                                // Write an error if no one else has already written one.
                                let mut rewrites = self.rewrites.lock().expect("poisoned");
                                rewrites.entry(addr.clone()).or_insert_with(|| {
                                    BrokerRewriteHandle::FailedDefaultSshTunnel(
                                        e.to_string_with_causes(),
                                    )
                                });

                                Err(io::Error::new(
                                    io::ErrorKind::Other,
                                    "creating SSH tunnel failed",
                                ))
                            }
                        }
                    }
                    TunnelConfig::StaticHost(host) => (host.as_str(), port)
                        .to_socket_addrs()
                        .map(|addrs| addrs.collect()),
                    TunnelConfig::None => {
                        (host, port).to_socket_addrs().map(|addrs| addrs.collect())
                    }
                }
            }
            Some(rewrite) => return_rewrite(&rewrite),
        }
    }

    fn log(&self, level: RDKafkaLogLevel, fac: &str, log_message: &str) {
        self.inner.log(level, fac, log_message)
    }

    fn error(&self, error: KafkaError, reason: &str) {
        self.inner.error(error, reason)
    }

    fn stats(&self, statistics: Statistics) {
        self.inner.stats(statistics)
    }

    fn stats_raw(&self, statistics: &[u8]) {
        self.inner.stats_raw(statistics)
    }
}

impl<C> ConsumerContext for TunnelingClientContext<C>
where
    C: ConsumerContext,
{
    fn rebalance(
        &self,
        native_client: &NativeClient,
        err: RDKafkaRespErr,
        tpl: &mut TopicPartitionList,
    ) {
        self.inner.rebalance(native_client, err, tpl)
    }

    fn pre_rebalance<'a>(&self, rebalance: &Rebalance<'a>) {
        self.inner.pre_rebalance(rebalance)
    }

    fn post_rebalance<'a>(&self, rebalance: &Rebalance<'a>) {
        self.inner.post_rebalance(rebalance)
    }

    fn commit_callback(&self, result: KafkaResult<()>, offsets: &TopicPartitionList) {
        self.inner.commit_callback(result, offsets)
    }

    fn main_queue_min_poll_interval(&self) -> Timeout {
        self.inner.main_queue_min_poll_interval()
    }
}

impl<C> ProducerContext for TunnelingClientContext<C>
where
    C: ProducerContext,
{
    type DeliveryOpaque = C::DeliveryOpaque;

    fn delivery(
        &self,
        delivery_result: &DeliveryResult<'_>,
        delivery_opaque: Self::DeliveryOpaque,
    ) {
        self.inner.delivery(delivery_result, delivery_opaque)
    }
}

/// Id of a partition in a topic.
pub type PartitionId = i32;

/// The error returned by [`get_partitions`].
#[derive(Debug, thiserror::Error)]
pub enum GetPartitionsError {
    /// The specified topic does not exist.
    #[error("Topic does not exist")]
    TopicDoesNotExist,
    /// A Kafka error.
    #[error(transparent)]
    Kafka(#[from] KafkaError),
    /// An unstructured error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

/// Retrieves the IDs of the partitions of the given `topic` using the given `client`.
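///
/// # Example
///
/// A minimal sketch (not part of the original source); `client` is assumed to
/// be an existing `Client<MzClientContext>`:
///
/// ```ignore
/// match get_partitions(&client, "my-topic", DEFAULT_FETCH_METADATA_TIMEOUT) {
///     Ok(partitions) => println!("topic has {} partitions", partitions.len()),
///     Err(GetPartitionsError::TopicDoesNotExist) => println!("no such topic"),
///     Err(e) => println!("metadata fetch failed: {e}"),
/// }
/// ```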
pub fn get_partitions<C: ClientContext>(
    client: &Client<C>,
    topic: &str,
    timeout: Duration,
) -> Result<Vec<PartitionId>, GetPartitionsError> {
    let meta = client.fetch_metadata(Some(topic), timeout)?;
    if meta.topics().len() != 1 {
        Err(anyhow!(
            "topic {} has {} metadata entries; expected 1",
            topic,
            meta.topics().len()
        ))?;
    }

    fn check_err(err: Option<RDKafkaRespErr>) -> Result<(), GetPartitionsError> {
        match err.map(RDKafkaErrorCode::from) {
            Some(RDKafkaErrorCode::UnknownTopic | RDKafkaErrorCode::UnknownTopicOrPartition) => {
                Err(GetPartitionsError::TopicDoesNotExist)
            }
            Some(code) => Err(anyhow!(code))?,
            None => Ok(()),
        }
    }

    let meta_topic = meta.topics().into_element();
    check_err(meta_topic.error())?;

    if meta_topic.name() != topic {
        Err(anyhow!(
            "got results for wrong topic {} (expected {})",
            meta_topic.name(),
            topic
        ))?;
    }

    let mut partition_ids = Vec::with_capacity(meta_topic.partitions().len());
    for partition_meta in meta_topic.partitions() {
        check_err(partition_meta.error())?;

        partition_ids.push(partition_meta.id());
    }

    if partition_ids.is_empty() {
        Err(GetPartitionsError::TopicDoesNotExist)?;
    }

    Ok(partition_ids)
}

/// Default to `true`, as keepalives have no downsides <https://github.com/confluentinc/librdkafka/issues/283>.
pub const DEFAULT_KEEPALIVE: bool = true;
/// The `rdkafka` default.
/// - <https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md>
pub const DEFAULT_SOCKET_TIMEOUT: Duration = Duration::from_secs(60);
/// Increased from the `rdkafka` default.
/// - <https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md>
pub const DEFAULT_TRANSACTION_TIMEOUT: Duration = Duration::from_secs(600);
/// The `rdkafka` default.
/// - <https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md>
pub const DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT: Duration = Duration::from_secs(30);
/// A reasonable default timeout when fetching metadata or partitions.
pub const DEFAULT_FETCH_METADATA_TIMEOUT: Duration = Duration::from_secs(10);
/// The timeout for reading records from the progress topic. Set to something slightly longer than
/// the idle transaction timeout (60s) to wait out any stuck producers.
pub const DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT: Duration = Duration::from_secs(90);

/// Configurable timeouts for Kafka connections.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TimeoutConfig {
    /// Whether or not to enable TCP keepalives.
    pub keepalive: bool,
    /// The timeout for network requests. Can't be more than 100ms longer than
    /// `transaction_timeout`.
    pub socket_timeout: Duration,
    /// The timeout for transactions.
    pub transaction_timeout: Duration,
    /// The timeout for setting up network connections.
    pub socket_connection_setup_timeout: Duration,
    /// The timeout for fetching metadata from upstream.
    pub fetch_metadata_timeout: Duration,
    /// The timeout for reading records from the progress topic.
    pub progress_record_fetch_timeout: Duration,
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        TimeoutConfig {
            keepalive: DEFAULT_KEEPALIVE,
            socket_timeout: DEFAULT_SOCKET_TIMEOUT,
            transaction_timeout: DEFAULT_TRANSACTION_TIMEOUT,
            socket_connection_setup_timeout: DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT,
            fetch_metadata_timeout: DEFAULT_FETCH_METADATA_TIMEOUT,
            progress_record_fetch_timeout: DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT,
        }
    }
}

impl TimeoutConfig {
    /// Builds a `TimeoutConfig` from the given parameters. Parameters outside the supported
    /// range are replaced with defaults and logged as errors.
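    ///
    /// # Example
    ///
    /// A minimal sketch (not part of the original source): a 30s transaction
    /// timeout with the other timeouts defaulted. Since `socket_timeout` is
    /// `None`, it is derived from the transaction timeout (capped at the 60s
    /// default):
    ///
    /// ```ignore
    /// let config = TimeoutConfig::build(
    ///     true,                    // keepalive
    ///     None,                    // socket_timeout: derive it
    ///     Duration::from_secs(30), // transaction_timeout
    ///     Duration::from_secs(30), // socket_connection_setup_timeout
    ///     Duration::from_secs(10), // fetch_metadata_timeout
    ///     None,                    // progress_record_fetch_timeout: derive it
    /// );
    /// assert_eq!(config.socket_timeout, Duration::from_millis(30_100));
    /// ```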
    pub fn build(
        keepalive: bool,
        socket_timeout: Option<Duration>,
        transaction_timeout: Duration,
        socket_connection_setup_timeout: Duration,
        fetch_metadata_timeout: Duration,
        progress_record_fetch_timeout: Option<Duration>,
    ) -> TimeoutConfig {
        // Constrain values based on ranges here:
        // <https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md>
        //
        // Note we error log but do not fail as this is called in a non-fallible
        // LD-sync in the adapter.

        let transaction_timeout = if transaction_timeout.as_millis() > i32::MAX.try_into().unwrap()
        {
            error!(
                "transaction_timeout ({transaction_timeout:?}) greater than max \
                of {}, defaulting to the default of {DEFAULT_TRANSACTION_TIMEOUT:?}",
                i32::MAX
            );
            DEFAULT_TRANSACTION_TIMEOUT
        } else if transaction_timeout.as_millis() < 1000 {
            error!(
                "transaction_timeout ({transaction_timeout:?}) less than min \
                of 1000ms, defaulting to the default of {DEFAULT_TRANSACTION_TIMEOUT:?}"
            );
            DEFAULT_TRANSACTION_TIMEOUT
        } else {
            transaction_timeout
        };

        let progress_record_fetch_timeout_derived_default =
            std::cmp::max(transaction_timeout, DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT);
        let progress_record_fetch_timeout =
            progress_record_fetch_timeout.unwrap_or(progress_record_fetch_timeout_derived_default);
        let progress_record_fetch_timeout = if progress_record_fetch_timeout < transaction_timeout {
            error!(
                "progress record fetch ({progress_record_fetch_timeout:?}) less than transaction \
                timeout ({transaction_timeout:?}), defaulting to transaction timeout {transaction_timeout:?}",
            );
            transaction_timeout
        } else {
            progress_record_fetch_timeout
        };

        // The documented max here is `300000`, but rdkafka bans `socket.timeout.ms` being more
        // than `transaction.timeout.ms` + 100ms.
        let max_socket_timeout = std::cmp::min(
            transaction_timeout + Duration::from_millis(100),
            Duration::from_secs(300),
        );
        let socket_timeout_derived_default =
            std::cmp::min(max_socket_timeout, DEFAULT_SOCKET_TIMEOUT);
        let socket_timeout = socket_timeout.unwrap_or(socket_timeout_derived_default);
        let socket_timeout = if socket_timeout > max_socket_timeout {
            error!(
                "socket_timeout ({socket_timeout:?}) greater than max \
                of min(300000, transaction.timeout.ms + 100 ({})), \
                defaulting to the maximum of {max_socket_timeout:?}",
                transaction_timeout.as_millis() + 100
            );
            max_socket_timeout
        } else if socket_timeout.as_millis() < 10 {
            error!(
                "socket_timeout ({socket_timeout:?}) less than min \
                of 10ms, defaulting to the default of {socket_timeout_derived_default:?}"
            );
            socket_timeout_derived_default
        } else {
            socket_timeout
        };

        let socket_connection_setup_timeout =
            if socket_connection_setup_timeout.as_millis() > i32::MAX.try_into().unwrap() {
                error!(
                    "socket_connection_setup_timeout ({socket_connection_setup_timeout:?}) \
                    greater than max of {}ms, defaulting to the default \
                    of {DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT:?}",
                    i32::MAX,
                );
                DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT
            } else if socket_connection_setup_timeout.as_millis() < 10 {
                error!(
                    "socket_connection_setup_timeout ({socket_connection_setup_timeout:?}) \
                    less than min of 10ms, defaulting to the default of \
                    {DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT:?}"
                );
                DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT
            } else {
                socket_connection_setup_timeout
            };

        TimeoutConfig {
            keepalive,
            socket_timeout,
            transaction_timeout,
            socket_connection_setup_timeout,
            fetch_metadata_timeout,
            progress_record_fetch_timeout,
        }
    }
}

/// A simpler version of [`create_new_client_config`] that defaults
/// the `log_level` to `INFO` and should only be used in tests.
pub fn create_new_client_config_simple() -> ClientConfig {
    create_new_client_config(tracing::Level::INFO, Default::default())
}

/// Builds a new [`rdkafka`] [`ClientConfig`] with its `log_level` set correctly
/// based on the passed-through [`tracing::Level`]. This level should be
/// determined for `target: "librdkafka"`.
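///
/// # Example
///
/// A minimal sketch (not part of the original source) of building a config and
/// layering connection-specific options on top:
///
/// ```ignore
/// let mut config = create_new_client_config(tracing::Level::INFO, TimeoutConfig::default());
/// config.set("bootstrap.servers", "broker1:9092");
/// ```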
pub fn create_new_client_config(
    tracing_level: Level,
    timeout_config: TimeoutConfig,
) -> ClientConfig {
    #[allow(clippy::disallowed_methods)]
    let mut config = ClientConfig::new();

    let level = if tracing_level >= Level::DEBUG {
        RDKafkaLogLevel::Debug
    } else if tracing_level >= Level::INFO {
        RDKafkaLogLevel::Info
    } else if tracing_level >= Level::WARN {
        RDKafkaLogLevel::Warning
    } else {
        RDKafkaLogLevel::Error
    };
    // WARNING WARNING WARNING
    //
    // For whatever reason, if you change this `target` to something else, this
    // log line might break. I (guswynn) did some extensive investigation with
    // the tracing folks, and we learned that this edge case only happens with
    // 1. a different target
    // 2. only this file (so far as I can tell)
    // 3. only in certain subscriber combinations
    // 4. only if the `tracing-log` feature is on.
    //
    // Our conclusion was that one of our dependencies is doing something
    // problematic with `log`.
    //
    // For now, this works, and prints a nice log line exactly when we want it.
    //
    // TODO(guswynn): when we can remove `tracing-log`, remove this warning
    tracing::debug!(target: "librdkafka", level = ?level, "Determined log level for librdkafka");
    config.set_log_level(level);

    // Patch the librdkafka debug log system into the Rust `log` ecosystem. This
    // is a very simple integration at the moment; enabling `debug`-level logs
    // for the `librdkafka` target enables the full firehose of librdkafka
    // debug logs. We may want to investigate finer-grained control.
    if tracing_level >= Level::DEBUG {
        tracing::debug!(target: "librdkafka", "Enabling debug logs for rdkafka");
        config.set("debug", "all");
    }

    if timeout_config.keepalive {
        config.set("socket.keepalive.enable", "true");
    }

    config.set(
        "socket.timeout.ms",
        timeout_config.socket_timeout.as_millis().to_string(),
    );
    config.set(
        "transaction.timeout.ms",
        timeout_config.transaction_timeout.as_millis().to_string(),
    );
    config.set(
        "socket.connection.setup.timeout.ms",
        timeout_config
            .socket_connection_setup_timeout
            .as_millis()
            .to_string(),
    );

    config
}