// mz_kafka_util/client.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
//! Helpers for working with Kafka's client API.

use std::collections::{BTreeMap, btree_map};
use std::error::Error;
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};
use std::str::FromStr;
use std::sync::{Arc, Mutex, OnceLock};
use std::time::Duration;

use anyhow::{Context, anyhow, bail};
use aws_config::SdkConfig;
use crossbeam::channel::{Receiver, Sender, unbounded};
use fancy_regex::Regex;
use mz_ore::collections::CollectionExt;
use mz_ore::error::ErrorExt;
use mz_ore::future::InTask;
use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig, SshTunnelStatus};
use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
use rdkafka::client::{Client, NativeClient, OAuthToken};
use rdkafka::config::{ClientConfig, RDKafkaLogLevel};
use rdkafka::consumer::{ConsumerContext, Rebalance};
use rdkafka::error::{KafkaError, KafkaResult, RDKafkaErrorCode};
use rdkafka::producer::{DefaultProducerContext, DeliveryResult, ProducerContext};
use rdkafka::types::RDKafkaRespErr;
use rdkafka::util::Timeout;
use rdkafka::{ClientContext, Statistics, TopicPartitionList};
use serde::{Deserialize, Serialize};
use tokio::runtime::Handle;
use tokio::sync::watch;
use tracing::{Level, debug, error, info, trace, warn};

use crate::aws;
45
/// A reasonable default timeout when refreshing topic metadata. This is configured
/// at a source level.
// 30s may seem infrequent, but the default is 5m. More frequent metadata
// refresh rates are surprising to Kafka users, as topic partition counts hardly
// ever change in production.
pub const DEFAULT_TOPIC_METADATA_REFRESH_INTERVAL: Duration = Duration::from_secs(30);
52
/// A `ClientContext` implementation that uses `tracing` instead of `log`
/// macros.
///
/// All code in Materialize that constructs Kafka clients should use this
/// context or a custom context that delegates the `log` and `error` methods to
/// this implementation.
pub struct MzClientContext {
    /// Sender half of the channel over which parsed librdkafka error logs are
    /// forwarded; the matching receiver is handed out by `with_errors`.
    error_tx: Sender<MzKafkaError>,
    /// A tokio watch that retains the last statistics received by rdkafka and provides async
    /// notifications to anyone interested in subscribing.
    statistics_tx: watch::Sender<Statistics>,
}
66
67impl Default for MzClientContext {
68    fn default() -> Self {
69        Self::with_errors().0
70    }
71}
72
73impl MzClientContext {
74    /// Constructs a new client context and returns an mpsc `Receiver` that can be used to learn
75    /// about librdkafka errors.
76    // `crossbeam` channel receivers can be cloned, but this is intended to be used as a mpsc,
77    // until we upgrade to `1.72` and the std mpsc sender is `Sync`.
78    pub fn with_errors() -> (Self, Receiver<MzKafkaError>) {
79        let (error_tx, error_rx) = unbounded();
80        let (statistics_tx, _) = watch::channel(Default::default());
81        let ctx = Self {
82            error_tx,
83            statistics_tx,
84        };
85        (ctx, error_rx)
86    }
87
88    /// Creates a tokio Watch subscription for statistics reported by librdkafka. It is necessary
89    /// that the `statistics.ms.interval` is set for this stream to contain any values.
90    pub fn subscribe_statistics(&self) -> watch::Receiver<Statistics> {
91        self.statistics_tx.subscribe()
92    }
93
94    fn record_error(&self, msg: &str) {
95        let err = match MzKafkaError::from_str(msg) {
96            Ok(err) => err,
97            Err(()) => {
98                warn!(original_error = msg, "failed to parse kafka error");
99                MzKafkaError::Internal(msg.to_owned())
100            }
101        };
102        // If no one cares about errors we drop them on the floor
103        let _ = self.error_tx.send(err);
104    }
105}
106
/// A structured error type for errors reported by librdkafka through its logs.
#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
pub enum MzKafkaError {
    /// Invalid username or password
    #[error("Invalid username or password")]
    InvalidCredentials,
    /// Missing CA certificate
    #[error("Invalid CA certificate")]
    InvalidCACertificate,
    /// Broker might require SSL encryption
    #[error("Disconnected during handshake; broker might require SSL encryption")]
    SSLEncryptionMaybeRequired,
    /// Broker does not support SSL connections
    #[error("Broker does not support SSL connections")]
    SSLUnsupported,
    /// Broker did not provide a certificate
    #[error("Broker did not provide a certificate")]
    BrokerCertificateMissing,
    /// Failed to verify broker certificate
    #[error("Failed to verify broker certificate")]
    InvalidBrokerCertificate,
    /// Connection reset
    #[error("Connection reset: {0}")]
    ConnectionReset(String),
    /// Connection timeout
    #[error("Connection timeout")]
    ConnectionTimeout,
    /// Failed to resolve hostname
    #[error("Failed to resolve hostname")]
    HostnameResolutionFailed,
    /// Unsupported SASL mechanism
    #[error("Unsupported SASL mechanism")]
    UnsupportedSASLMechanism,
    /// Unsupported broker version
    #[error("Unsupported broker version")]
    UnsupportedBrokerVersion,
    /// Connection to broker failed
    #[error("Broker transport failure")]
    BrokerTransportFailure,
    /// All brokers down
    #[error("All brokers down")]
    AllBrokersDown,
    /// SASL authentication required
    #[error("SASL authentication required")]
    SaslAuthenticationRequired,
    /// SASL authentication failed
    #[error("SASL authentication failed")]
    SaslAuthenticationFailed,
    /// SSL authentication required
    #[error("SSL authentication required")]
    SslAuthenticationRequired,
    /// Unknown topic or partition
    #[error("Unknown topic or partition")]
    UnknownTopicOrPartition,
    /// An internal kafka error
    #[error("Internal kafka error: {0}")]
    Internal(String),
}
165
166impl FromStr for MzKafkaError {
167    type Err = ();
168
169    fn from_str(s: &str) -> Result<Self, Self::Err> {
170        if s.contains("Authentication failed: Invalid username or password") {
171            Ok(Self::InvalidCredentials)
172        } else if s.contains("broker certificate could not be verified") {
173            Ok(Self::InvalidCACertificate)
174        } else if s.contains("connecting to a SSL listener?") {
175            Ok(Self::SSLEncryptionMaybeRequired)
176        } else if s.contains("client SSL authentication might be required") {
177            Ok(Self::SslAuthenticationRequired)
178        } else if s.contains("connecting to a PLAINTEXT broker listener") {
179            Ok(Self::SSLUnsupported)
180        } else if s.contains("Broker did not provide a certificate") {
181            Ok(Self::BrokerCertificateMissing)
182        } else if s.contains("Failed to verify broker certificate: ") {
183            Ok(Self::InvalidBrokerCertificate)
184        } else if let Some((_prefix, inner)) = s.split_once("Send failed: ") {
185            Ok(Self::ConnectionReset(inner.to_owned()))
186        } else if let Some((_prefix, inner)) = s.split_once("Receive failed: ") {
187            Ok(Self::ConnectionReset(inner.to_owned()))
188        } else if s.contains("request(s) timed out: disconnect") {
189            Ok(Self::ConnectionTimeout)
190        } else if s.contains("Failed to resolve") {
191            Ok(Self::HostnameResolutionFailed)
192        } else if s.contains("mechanism handshake failed:") {
193            Ok(Self::UnsupportedSASLMechanism)
194        } else if s.contains(
195            "verify that security.protocol is correctly configured, \
196            broker might require SASL authentication",
197        ) {
198            Ok(Self::SaslAuthenticationRequired)
199        } else if s.contains("SASL authentication error: Authentication failed") {
200            Ok(Self::SaslAuthenticationFailed)
201        } else if s
202            .contains("incorrect security.protocol configuration (connecting to a SSL listener?)")
203        {
204            Ok(Self::SslAuthenticationRequired)
205        } else if s.contains("probably due to broker version < 0.10") {
206            Ok(Self::UnsupportedBrokerVersion)
207        } else if s.contains("Disconnected while requesting ApiVersion")
208            || s.contains("Broker transport failure")
209            || s.contains("Connection refused")
210        {
211            Ok(Self::BrokerTransportFailure)
212        } else if Regex::new(r"(\d+)/\1 brokers are down")
213            .unwrap()
214            .is_match(s)
215            .unwrap_or_default()
216        {
217            Ok(Self::AllBrokersDown)
218        } else if s.contains("Unknown topic or partition") || s.contains("Unknown partition") {
219            Ok(Self::UnknownTopicOrPartition)
220        } else {
221            Err(())
222        }
223    }
224}
225
226impl ClientContext for MzClientContext {
227    fn log(&self, level: rdkafka::config::RDKafkaLogLevel, fac: &str, log_message: &str) {
228        use rdkafka::config::RDKafkaLogLevel::*;
229
230        // Sniff out log messages that indicate errors.
231        //
232        // We consider any event at error, critical, alert, or emergency level,
233        // for self explanatory reasons. We also consider any event with a
234        // facility of `FAIL`. librdkafka often uses info or warn level for
235        // these `FAIL` events, but as they always indicate a failure to connect
236        // to a broker we want to always treat them as errors.
237        if matches!(level, Emerg | Alert | Critical | Error) || fac == "FAIL" {
238            self.record_error(log_message);
239        }
240
241        // Copied from https://docs.rs/rdkafka/0.28.0/src/rdkafka/client.rs.html#58-79
242        // but using `tracing`
243        match level {
244            Emerg | Alert | Critical | Error => {
245                // We downgrade error messages to `warn!` level to avoid
246                // sending the errors to Sentry. Most errors are customer
247                // configuration problems that are not appropriate to send to
248                // Sentry.
249                warn!(target: "librdkafka", "error: {} {}", fac, log_message);
250            }
251            Warning => warn!(target: "librdkafka", "warning: {} {}", fac, log_message),
252            Notice => info!(target: "librdkafka", "{} {}", fac, log_message),
253            Info => info!(target: "librdkafka", "{} {}", fac, log_message),
254            Debug => debug!(target: "librdkafka", "{} {}", fac, log_message),
255        }
256    }
257
258    fn stats(&self, statistics: Statistics) {
259        self.statistics_tx.send_replace(statistics);
260    }
261
262    fn error(&self, error: KafkaError, reason: &str) {
263        self.record_error(reason);
264        // Refer to the comment in the `log` callback.
265        warn!(target: "librdkafka", "error: {}: {}", error, reason);
266    }
267}
268
// Default consumer behavior suffices; `MzClientContext` only customizes
// logging, statistics, and error reporting.
impl ConsumerContext for MzClientContext {}
270
271impl ProducerContext for MzClientContext {
272    type DeliveryOpaque = <DefaultProducerContext as ProducerContext>::DeliveryOpaque;
273    fn delivery(
274        &self,
275        delivery_result: &DeliveryResult<'_>,
276        delivery_opaque: Self::DeliveryOpaque,
277    ) {
278        DefaultProducerContext.delivery(delivery_result, delivery_opaque);
279    }
280}
281
/// The address of a Kafka broker.
// `Ord`/`PartialOrd` are derived so this can key the `BTreeMap` of rewrites
// in `TunnelingClientContext`.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct BrokerAddr {
    /// The broker's hostname.
    pub host: String,
    /// The broker's port.
    pub port: u16,
}
290
291impl BrokerAddr {
292    /// Attempt to resolve this broker address into a list of socket addresses.
293    pub fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
294        Ok((self.host.as_str(), self.port).to_socket_addrs()?.collect())
295    }
296}
297
/// Rewrites a broker address.
///
/// For use with [`TunnelingClientContext`].
#[derive(Debug, Clone)]
pub struct BrokerRewrite {
    /// The rewritten hostname.
    pub host: String,
    /// The rewritten port.
    ///
    /// If unspecified, the broker's original port is left unchanged.
    pub port: Option<u16>,
}
310
311impl BrokerRewrite {
312    /// Apply the rewrite to this broker address.
313    pub fn rewrite(&self, address: &BrokerAddr) -> BrokerAddr {
314        BrokerAddr {
315            host: self.host.clone(),
316            port: self.port.unwrap_or(address.port),
317        }
318    }
319}
320
#[derive(Clone)]
// The recorded outcome of rewriting one broker: a static rewrite, a live SSH
// tunnel, or a failed attempt at the default SSH tunnel.
enum BrokerRewriteHandle {
    /// A static host/port rewrite.
    Simple(BrokerRewrite),
    SshTunnel(
        // This ensures the ssh tunnel is not shutdown.
        ManagedSshTunnelHandle,
    ),
    /// For _default_ ssh tunnels, we store an error if _creation_
    /// of the tunnel failed, so that `tunnel_status` can return it.
    FailedDefaultSshTunnel(String),
}
332
#[derive(Clone)]
/// Parsed from a string, with optional leading and trailing '*' wildcards.
/// E.g. `*.example.com:9092` matches any address ending in that suffix.
pub struct ConnectionRulePattern {
    /// If true, allow any combination of characters before the literal match.
    pub prefix_wildcard: bool,
    /// We expect the broker's host:port to match these characters in their entirety.
    pub literal_match: String,
    /// If true, allow any combination of characters after the literal match.
    pub suffix_wildcard: bool,
}
343
344impl std::fmt::Display for ConnectionRulePattern {
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        if self.prefix_wildcard {
347            f.write_str("*")?;
348        }
349        f.write_str(&self.literal_match)?;
350        if self.suffix_wildcard {
351            f.write_str("*")?;
352        }
353        Ok(())
354    }
355}
356
357impl ConnectionRulePattern {
358    /// Does this "{host}:{port}" address fit the pattern?
359    pub fn matches(&self, address: &str) -> bool {
360        if self.prefix_wildcard {
361            if self.suffix_wildcard {
362                address.contains(&self.literal_match)
363            } else {
364                address.ends_with(&self.literal_match)
365            }
366        } else if self.suffix_wildcard {
367            address.starts_with(&self.literal_match)
368        } else {
369            address == self.literal_match
370        }
371    }
372}
373
#[derive(Clone)]
/// Given a host address, map it to a different host.
pub struct HostMappingRules {
    /// Map matching hosts to a different host. First applicable rule wins,
    /// so order the rules from most to least specific.
    pub rules: Vec<(ConnectionRulePattern, BrokerRewrite)>,
}
380
381impl HostMappingRules {
382    /// Rewrite this broker address according to the rules. Returns `None` when
383    /// no rule matches.
384    pub fn rewrite(&self, src: &BrokerAddr) -> Option<BrokerAddr> {
385        let address = format!("{}:{}", src.host, src.port);
386        for (pattern, dst) in &self.rules {
387            if pattern.matches(&address) {
388                let result = dst.rewrite(src);
389                info!(
390                    "HostMappingRules: broker {}:{} matched pattern '{}' -> rewriting to {}:{}",
391                    src.host, src.port, pattern, result.host, result.port,
392                );
393                return Some(result);
394            }
395        }
396
397        warn!(
398            "HostMappingRules: broker {}:{} matched no rules, using original address",
399            src.host, src.port,
400        );
401        None
402    }
403}
404
/// Tunneling clients
/// used for re-writing ports / hosts.
///
/// This is the default rewriting strategy applied by
/// [`TunnelingClientContext`] to brokers without an explicit rewrite.
#[derive(Clone)]
pub enum TunnelConfig {
    /// Tunnel config option for SSH tunnels
    Ssh(SshTunnelConfig),
    /// Re-writes internal hosts using the value, used for privatelink
    StaticHost(String),
    /// Re-writes internal hosts according to an ordered list of rules, also used for privatelink
    Rules(HostMappingRules),
    /// Performs no re-writes
    None,
}
418
/// A client context that supports rewriting broker addresses.
#[derive(Clone)]
pub struct TunnelingClientContext<C> {
    /// The wrapped context all non-rewriting callbacks delegate to.
    inner: C,
    /// Per-broker rewrites, keyed by original broker address. Shared across
    /// clones of this context.
    rewrites: Arc<Mutex<BTreeMap<BrokerAddr, BrokerRewriteHandle>>>,
    /// Strategy applied to brokers with no entry in `rewrites`.
    default_tunnel: TunnelConfig,
    /// Whether SSH connections are established from within a task.
    in_task: InTask,
    /// Shared manager for (re)establishing SSH tunnels.
    ssh_tunnel_manager: SshTunnelManager,
    /// Timeouts applied when creating SSH tunnels.
    ssh_timeout_config: SshTimeoutConfig,
    /// AWS configuration, required for AWS IAM OAuth token generation.
    aws_config: Option<SdkConfig>,
    /// Tokio runtime handle used to block on async work from librdkafka
    /// callback threads.
    runtime: Handle,
}
431
impl<C> TunnelingClientContext<C> {
    /// Constructs a new context that wraps `inner`.
    pub fn new(
        inner: C,
        runtime: Handle,
        ssh_tunnel_manager: SshTunnelManager,
        ssh_timeout_config: SshTimeoutConfig,
        aws_config: Option<SdkConfig>,
        in_task: InTask,
    ) -> TunnelingClientContext<C> {
        TunnelingClientContext {
            inner,
            rewrites: Arc::new(Mutex::new(BTreeMap::new())),
            // No rewriting takes place until `set_default_tunnel` or one of
            // the `add_*` methods is called.
            default_tunnel: TunnelConfig::None,
            in_task,
            ssh_tunnel_manager,
            ssh_timeout_config,
            aws_config,
            runtime,
        }
    }

    /// Sets the default rewriting strategy.
    ///
    /// Connections to brokers that aren't specified in other rewrites will be
    /// handled according to the given `tunnel` configuration.
    pub fn set_default_tunnel(&mut self, tunnel: TunnelConfig) {
        self.default_tunnel = tunnel;
    }

    /// Adds an SSH tunnel for a specific broker.
    ///
    /// Overrides the existing SSH tunnel or rewrite for this broker, if any.
    ///
    /// This tunnel allows the rewrite to evolve over time, for example, if
    /// the ssh tunnel's address changes if it fails and restarts.
    pub async fn add_ssh_tunnel(
        &self,
        broker: BrokerAddr,
        tunnel: SshTunnelConfig,
    ) -> Result<(), anyhow::Error> {
        let ssh_tunnel = self
            .ssh_tunnel_manager
            .connect(
                tunnel,
                &broker.host,
                broker.port,
                self.ssh_timeout_config,
                self.in_task,
            )
            .await
            .context("creating ssh tunnel")?;

        let mut rewrites = self.rewrites.lock().expect("poisoned");
        rewrites.insert(broker, BrokerRewriteHandle::SshTunnel(ssh_tunnel));
        Ok(())
    }

    /// Adds a broker rewrite rule.
    ///
    /// Overrides the existing SSH tunnel or rewrite for this broker, if any.
    ///
    /// `rewrite` is `BrokerRewrite` that specifies how to rewrite the address for `broker`.
    pub fn add_broker_rewrite(&self, broker: BrokerAddr, rewrite: BrokerRewrite) {
        let mut rewrites = self.rewrites.lock().expect("poisoned");
        rewrites.insert(broker, BrokerRewriteHandle::Simple(rewrite));
    }

    /// Returns a reference to the wrapped context.
    pub fn inner(&self) -> &C {
        &self.inner
    }

    /// Returns a _consolidated_ `SshTunnelStatus` that communicates the status
    /// of all active ssh tunnels `self` knows about.
    ///
    /// The result is `Errored` (with all error messages joined by `", "`) if
    /// any tunnel has failed, and `Running` otherwise.
    pub fn tunnel_status(&self) -> SshTunnelStatus {
        self.rewrites
            .lock()
            .expect("poisoned")
            .values()
            .map(|handle| match handle {
                BrokerRewriteHandle::SshTunnel(s) => s.check_status(),
                BrokerRewriteHandle::FailedDefaultSshTunnel(e) => {
                    SshTunnelStatus::Errored(e.clone())
                }
                // Static rewrites involve no tunnel and thus cannot fail.
                BrokerRewriteHandle::Simple(_) => SshTunnelStatus::Running,
            })
            .fold(SshTunnelStatus::Running, |acc, status| {
                match (acc, status) {
                    (SshTunnelStatus::Running, SshTunnelStatus::Errored(e))
                    | (SshTunnelStatus::Errored(e), SshTunnelStatus::Running) => {
                        SshTunnelStatus::Errored(e)
                    }
                    (SshTunnelStatus::Errored(err), SshTunnelStatus::Errored(e)) => {
                        SshTunnelStatus::Errored(format!("{}, {}", err, e))
                    }
                    (SshTunnelStatus::Running, SshTunnelStatus::Running) => {
                        SshTunnelStatus::Running
                    }
                }
            })
    }
}
535
impl<C> ClientContext for TunnelingClientContext<C>
where
    C: ClientContext,
{
    // Required so librdkafka invokes `generate_oauth_token` when a token is
    // needed or about to expire.
    const ENABLE_REFRESH_OAUTH_TOKEN: bool = true;

    fn generate_oauth_token(
        &self,
        _oauthbearer_config: Option<&str>,
    ) -> Result<OAuthToken, Box<dyn Error>> {
        // NOTE(benesch): We abuse the `TunnelingClientContext` to handle AWS
        // IAM authentication because it's used in exactly the right places and
        // already has a handle to the Tokio runtime. It might be slightly
        // cleaner to have a separate `AwsIamAuthenticatingClientContext`, but
        // that would be quite a bit of additional plumbing.

        // NOTE(benesch): at the moment, the only OAUTHBEARER authentication we
        // support is AWS IAM, so we can assume that if this method is invoked
        // AWS IAM is desired. We may need to generalize this in the future.

        info!(target: "librdkafka", "generating OAuth token");

        let generate = || {
            let Some(sdk_config) = &self.aws_config else {
                bail!("internal error: AWS configuration missing");
            };

            // This callback runs on a librdkafka thread, so blocking on the
            // Tokio runtime here is acceptable.
            self.runtime.block_on(aws::generate_auth_token(sdk_config))
        };

        match generate() {
            Ok((token, lifetime_ms)) => {
                info!(target: "librdkafka", %lifetime_ms, "successfully generated OAuth token");
                // The token itself is sensitive, so it is only logged at
                // trace level.
                trace!(target: "librdkafka", %token);
                Ok(OAuthToken {
                    token,
                    lifetime_ms,
                    principal_name: "".to_string(),
                })
            }
            Err(e) => {
                warn!(target: "librdkafka", "failed to generate OAuth token: {e:#}");
                Err(e.into())
            }
        }
    }

    /// Look up the broker's address in our book of rewrites.
    /// If we've already rewritten it before, reuse the existing rewrite.
    /// Otherwise, use our "default tunnel" rewriting strategy to attempt to rewrite this broker's address
    /// and record it in the book of rewrites.
    fn resolve_broker_addr(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>, io::Error> {
        info!("kafka: resolve_broker_addr called for {}:{}", host, port);
        // Turn an existing (successful) rewrite handle into resolved socket
        // addresses.
        let return_rewrite = |rewrite: &BrokerRewriteHandle| -> Result<Vec<SocketAddr>, io::Error> {
            let rewrite = match rewrite {
                BrokerRewriteHandle::Simple(rewrite) => rewrite.clone(),
                BrokerRewriteHandle::SshTunnel(ssh_tunnel) => {
                    // The port for this can change over time, as the ssh tunnel is maintained through
                    // errors.
                    let addr = ssh_tunnel.local_addr();
                    BrokerRewrite {
                        host: addr.ip().to_string(),
                        port: Some(addr.port()),
                    }
                }
                BrokerRewriteHandle::FailedDefaultSshTunnel(_) => {
                    // Callers only pass successful rewrites here; failed
                    // default tunnels are retried in the branch below.
                    unreachable!()
                }
            };
            let rewrite_port = rewrite.port.unwrap_or(port);

            info!(
                "rewriting broker {}:{} to {}:{}",
                host, port, rewrite.host, rewrite_port
            );

            (rewrite.host, rewrite_port)
                .to_socket_addrs()
                .map(|addrs| addrs.collect())
        };

        let addr = BrokerAddr {
            host: host.into(),
            port,
        };
        // Take a snapshot of any existing rewrite; the lock is not held while
        // connecting below.
        let rewrite = self.rewrites.lock().expect("poisoned").get(&addr).cloned();

        match rewrite {
            // No (successful) broker address rewrite exists yet.
            None | Some(BrokerRewriteHandle::FailedDefaultSshTunnel(_)) => {
                // "Default tunnel" is actually the configured rewriting strategy used for brokers we haven't already rewritten.
                match &self.default_tunnel {
                    // This "default tunnel" is actually a default tunnel.
                    // Try connecting so we have a valid rewrite for this broker address.
                    TunnelConfig::Ssh(default_tunnel) => {
                        // Multiple users could all run `connect` at the same time; only one ssh
                        // tunnel will ever be connected, and only one will be inserted into the
                        // map.
                        let ssh_tunnel = self.runtime.block_on(async {
                            self.ssh_tunnel_manager
                                .connect(
                                    default_tunnel.clone(),
                                    host,
                                    port,
                                    self.ssh_timeout_config,
                                    self.in_task,
                                )
                                .await
                        });
                        match ssh_tunnel {
                            // Use the tunnel we just created, but only if nobody beat us in the race.
                            Ok(ssh_tunnel) => {
                                let mut rewrites = self.rewrites.lock().expect("poisoned");
                                let rewrite = match rewrites.entry(addr.clone()) {
                                    // A previously-recorded failure is
                                    // replaced by our fresh, working tunnel.
                                    btree_map::Entry::Occupied(mut o)
                                        if matches!(
                                            o.get(),
                                            BrokerRewriteHandle::FailedDefaultSshTunnel(_)
                                        ) =>
                                    {
                                        o.insert(BrokerRewriteHandle::SshTunnel(
                                            ssh_tunnel.clone(),
                                        ));
                                        o.into_mut()
                                    }
                                    // Someone else won the race: use their
                                    // handle; ours is dropped.
                                    btree_map::Entry::Occupied(o) => o.into_mut(),
                                    btree_map::Entry::Vacant(v) => {
                                        v.insert(BrokerRewriteHandle::SshTunnel(ssh_tunnel.clone()))
                                    }
                                };

                                return_rewrite(rewrite)
                            }
                            // We couldn't connect. Someone else will have to try again.
                            Err(e) => {
                                warn!(
                                    "failed to create ssh tunnel for {:?}: {}",
                                    addr,
                                    e.display_with_causes()
                                );

                                // Write an error if no one else has already written one.
                                let mut rewrites = self.rewrites.lock().expect("poisoned");
                                rewrites.entry(addr.clone()).or_insert_with(|| {
                                    BrokerRewriteHandle::FailedDefaultSshTunnel(
                                        e.to_string_with_causes(),
                                    )
                                });

                                Err(io::Error::new(
                                    io::ErrorKind::Other,
                                    "creating SSH tunnel failed",
                                ))
                            }
                        }
                    }
                    // Our rewrite strategy is to use a specific host, e.g. a PrivateLink endpoint.
                    TunnelConfig::StaticHost(host) => (host.as_str(), port)
                        .to_socket_addrs()
                        .map(|addrs| addrs.collect()),
                    // Rewrite according to the routing rules.
                    TunnelConfig::Rules(rules) => {
                        // If no rules match, just use the address as-is.
                        let resolved = rules.rewrite(&addr).unwrap_or_else(|| addr.clone());
                        match resolved.to_socket_addrs() {
                            Ok(addrs) => {
                                info!(
                                    "kafka: resolve_broker_addr {}:{} -> {}:{} resolved to {:?}",
                                    host, port, resolved.host, resolved.port, addrs,
                                );
                                Ok(addrs)
                            }
                            Err(e) => {
                                warn!(
                                    "kafka: resolve_broker_addr {}:{} -> {}:{} DNS resolution FAILED: {e}",
                                    host, port, resolved.host, resolved.port,
                                );
                                Err(e)
                            }
                        }
                    }
                    // We leave the broker's address as it is.
                    TunnelConfig::None => {
                        (host, port).to_socket_addrs().map(|addrs| addrs.collect())
                    }
                }
            }
            // This broker's address was already rewritten. Reuse the existing rewrite.
            Some(rewrite) => {
                info!(
                    "kafka: resolve_broker_addr {}:{} using cached rewrite",
                    host, port
                );
                return_rewrite(&rewrite)
            }
        }
    }

    // The remaining callbacks delegate directly to the wrapped context.

    fn log(&self, level: RDKafkaLogLevel, fac: &str, log_message: &str) {
        self.inner.log(level, fac, log_message)
    }

    fn error(&self, error: KafkaError, reason: &str) {
        self.inner.error(error, reason)
    }

    fn stats(&self, statistics: Statistics) {
        self.inner.stats(statistics)
    }

    fn stats_raw(&self, statistics: &[u8]) {
        self.inner.stats_raw(statistics)
    }
}
750
// All consumer callbacks delegate directly to the wrapped context; only
// broker address resolution is customized by `TunnelingClientContext`.
impl<C> ConsumerContext for TunnelingClientContext<C>
where
    C: ConsumerContext,
{
    fn rebalance(
        &self,
        native_client: &NativeClient,
        err: RDKafkaRespErr,
        tpl: &mut TopicPartitionList,
    ) {
        self.inner.rebalance(native_client, err, tpl)
    }

    fn pre_rebalance<'a>(&self, rebalance: &Rebalance<'a>) {
        self.inner.pre_rebalance(rebalance)
    }

    fn post_rebalance<'a>(&self, rebalance: &Rebalance<'a>) {
        self.inner.post_rebalance(rebalance)
    }

    fn commit_callback(&self, result: KafkaResult<()>, offsets: &TopicPartitionList) {
        self.inner.commit_callback(result, offsets)
    }

    fn main_queue_min_poll_interval(&self) -> Timeout {
        self.inner.main_queue_min_poll_interval()
    }
}
780
// Delegating `ProducerContext` implementation: the delivery callback and its
// opaque payload type are forwarded, unmodified, to the wrapped context.
impl<C> ProducerContext for TunnelingClientContext<C>
where
    C: ProducerContext,
{
    type DeliveryOpaque = C::DeliveryOpaque;

    fn delivery(
        &self,
        delivery_result: &DeliveryResult<'_>,
        delivery_opaque: Self::DeliveryOpaque,
    ) {
        self.inner.delivery(delivery_result, delivery_opaque)
    }
}
795
/// Id of a partition in a topic.
///
/// Kafka's wire protocol represents partition ids as 32-bit signed integers;
/// this alias documents that intent at use sites.
pub type PartitionId = i32;
798
/// The error returned by [`get_partitions`].
#[derive(Debug, thiserror::Error)]
pub enum GetPartitionsError {
    /// The specified topic does not exist (or, from the caller's perspective,
    /// reports zero partitions).
    #[error("Topic does not exist")]
    TopicDoesNotExist,
    /// A Kafka error reported by the underlying client.
    #[error(transparent)]
    Kafka(#[from] KafkaError),
    /// An unstructured error (e.g. unexpected metadata shape).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
812
813/// Retrieve number of partitions for a given `topic` using the given `client`
814pub fn get_partitions<C: ClientContext>(
815    client: &Client<C>,
816    topic: &str,
817    timeout: Duration,
818) -> Result<Vec<PartitionId>, GetPartitionsError> {
819    let meta = client.fetch_metadata(Some(topic), timeout)?;
820    if meta.topics().len() != 1 {
821        Err(anyhow!(
822            "topic {} has {} metadata entries; expected 1",
823            topic,
824            meta.topics().len()
825        ))?;
826    }
827
828    fn check_err(err: Option<RDKafkaRespErr>) -> Result<(), GetPartitionsError> {
829        match err.map(RDKafkaErrorCode::from) {
830            Some(RDKafkaErrorCode::UnknownTopic | RDKafkaErrorCode::UnknownTopicOrPartition) => {
831                Err(GetPartitionsError::TopicDoesNotExist)
832            }
833            Some(code) => Err(anyhow!(code))?,
834            None => Ok(()),
835        }
836    }
837
838    let meta_topic = meta.topics().into_element();
839    check_err(meta_topic.error())?;
840
841    if meta_topic.name() != topic {
842        Err(anyhow!(
843            "got results for wrong topic {} (expected {})",
844            meta_topic.name(),
845            topic
846        ))?;
847    }
848
849    let mut partition_ids = Vec::with_capacity(meta_topic.partitions().len());
850    for partition_meta in meta_topic.partitions() {
851        check_err(partition_meta.error())?;
852
853        partition_ids.push(partition_meta.id());
854    }
855
856    if partition_ids.len() == 0 {
857        Err(GetPartitionsError::TopicDoesNotExist)?;
858    }
859
860    Ok(partition_ids)
861}
862
/// Whether to enable TCP keepalive by default. Default to true as keepalives
/// have no downsides <https://github.com/confluentinc/librdkafka/issues/283>.
pub const DEFAULT_KEEPALIVE: bool = true;
/// Default `socket.timeout.ms`; matches the `rdkafka` default.
/// - <https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md>
pub const DEFAULT_SOCKET_TIMEOUT: Duration = Duration::from_secs(60);
/// Default `transaction.timeout.ms`; increased from the rdkafka default.
/// - <https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md>
pub const DEFAULT_TRANSACTION_TIMEOUT: Duration = Duration::from_secs(600);
/// Default `socket.connection.setup.timeout.ms`; matches the `rdkafka` default.
/// - <https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md>
pub const DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT: Duration = Duration::from_secs(30);
/// A reasonable default timeout when fetching metadata or partitions.
pub const DEFAULT_FETCH_METADATA_TIMEOUT: Duration = Duration::from_secs(10);
/// The timeout for reading records from the progress topic. Set to something slightly longer than
/// the idle transaction timeout (60s) to wait out any stuck producers.
pub const DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT: Duration = Duration::from_secs(90);
879
/// Configurable timeouts for Kafka connections.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TimeoutConfig {
    /// Whether or not to enable TCP keepalive (`socket.keepalive.enable`).
    pub keepalive: bool,
    /// The timeout for network requests. Can't be more than 100ms longer than
    /// `transaction_timeout`.
    pub socket_timeout: Duration,
    /// The timeout for transactions.
    pub transaction_timeout: Duration,
    /// The timeout for setting up network connections.
    pub socket_connection_setup_timeout: Duration,
    /// The timeout for fetching metadata from upstream.
    pub fetch_metadata_timeout: Duration,
    /// The timeout for reading records from the progress topic.
    pub progress_record_fetch_timeout: Duration,
}
897
898impl Default for TimeoutConfig {
899    fn default() -> Self {
900        TimeoutConfig {
901            keepalive: DEFAULT_KEEPALIVE,
902            socket_timeout: DEFAULT_SOCKET_TIMEOUT,
903            transaction_timeout: DEFAULT_TRANSACTION_TIMEOUT,
904            socket_connection_setup_timeout: DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT,
905            fetch_metadata_timeout: DEFAULT_FETCH_METADATA_TIMEOUT,
906            progress_record_fetch_timeout: DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT,
907        }
908    }
909}
910
911impl TimeoutConfig {
912    /// Build a `TcpTimeoutConfig` from the given parameters. Parameters outside the supported
913    /// range are defaulted and cause an error log.
914    pub fn build(
915        keepalive: bool,
916        socket_timeout: Option<Duration>,
917        transaction_timeout: Duration,
918        socket_connection_setup_timeout: Duration,
919        fetch_metadata_timeout: Duration,
920        progress_record_fetch_timeout: Option<Duration>,
921    ) -> TimeoutConfig {
922        // Constrain values based on ranges here:
923        // <https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md>
924        //
925        // Note we error log but do not fail as this is called in a non-fallible
926        // LD-sync in the adapter.
927
928        let transaction_timeout = if transaction_timeout.as_millis() > i32::MAX.try_into().unwrap()
929        {
930            error!(
931                "transaction_timeout ({transaction_timeout:?}) greater than max \
932                of {}, defaulting to the default of {DEFAULT_TRANSACTION_TIMEOUT:?}",
933                i32::MAX
934            );
935            DEFAULT_TRANSACTION_TIMEOUT
936        } else if transaction_timeout.as_millis() < 1000 {
937            error!(
938                "transaction_timeout ({transaction_timeout:?}) less than max \
939                of 1000ms, defaulting to the default of {DEFAULT_TRANSACTION_TIMEOUT:?}"
940            );
941            DEFAULT_TRANSACTION_TIMEOUT
942        } else {
943            transaction_timeout
944        };
945
946        let progress_record_fetch_timeout_derived_default =
947            std::cmp::max(transaction_timeout, DEFAULT_PROGRESS_RECORD_FETCH_TIMEOUT);
948        let progress_record_fetch_timeout =
949            progress_record_fetch_timeout.unwrap_or(progress_record_fetch_timeout_derived_default);
950        let progress_record_fetch_timeout = if progress_record_fetch_timeout < transaction_timeout {
951            error!(
952                "progress record fetch ({progress_record_fetch_timeout:?}) less than transaction \
953                timeout ({transaction_timeout:?}), defaulting to transaction timeout {transaction_timeout:?}",
954            );
955            transaction_timeout
956        } else {
957            progress_record_fetch_timeout
958        };
959
960        // The documented max here is `300000`, but rdkafka bans `socket.timeout.ms` being more
961        // than `transaction.timeout.ms` + 100ms.
962        let max_socket_timeout = std::cmp::min(
963            transaction_timeout + Duration::from_millis(100),
964            Duration::from_secs(300),
965        );
966        let socket_timeout_derived_default =
967            std::cmp::min(max_socket_timeout, DEFAULT_SOCKET_TIMEOUT);
968        let socket_timeout = socket_timeout.unwrap_or(socket_timeout_derived_default);
969        let socket_timeout = if socket_timeout > max_socket_timeout {
970            error!(
971                "socket_timeout ({socket_timeout:?}) greater than max \
972                of min(30000, transaction.timeout.ms + 100 ({})), \
973                defaulting to the maximum of {max_socket_timeout:?}",
974                transaction_timeout.as_millis() + 100
975            );
976            max_socket_timeout
977        } else if socket_timeout.as_millis() < 10 {
978            error!(
979                "socket_timeout ({socket_timeout:?}) less than min \
980                of 10ms, defaulting to the default of {socket_timeout_derived_default:?}"
981            );
982            socket_timeout_derived_default
983        } else {
984            socket_timeout
985        };
986
987        let socket_connection_setup_timeout =
988            if socket_connection_setup_timeout.as_millis() > i32::MAX.try_into().unwrap() {
989                error!(
990                    "socket_connection_setup_timeout ({socket_connection_setup_timeout:?}) \
991                    greater than max of {}ms, defaulting to the default \
992                    of {DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT:?}",
993                    i32::MAX,
994                );
995                DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT
996            } else if socket_connection_setup_timeout.as_millis() < 10 {
997                error!(
998                    "socket_connection_setup_timeout ({socket_connection_setup_timeout:?}) \
999                    less than max of 10ms, defaulting to the default of \
1000                {DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT:?}"
1001                );
1002                DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT
1003            } else {
1004                socket_connection_setup_timeout
1005            };
1006
1007        TimeoutConfig {
1008            keepalive,
1009            socket_timeout,
1010            transaction_timeout,
1011            socket_connection_setup_timeout,
1012            fetch_metadata_timeout,
1013            progress_record_fetch_timeout,
1014        }
1015    }
1016}
1017
1018/// A simpler version of [`create_new_client_config`] that defaults
1019/// the `log_level` to `INFO` and should only be used in tests.
1020pub fn create_new_client_config_simple() -> ClientConfig {
1021    create_new_client_config(tracing::Level::INFO, Default::default())
1022}
1023
/// Build a new [`rdkafka`] [`ClientConfig`] with its `log_level` set correctly
/// based on the passed through [`tracing::Level`]. This level should be
/// determined for `target: "librdkafka"`.
///
/// Also applies the keepalive and timeout settings from `timeout_config` to
/// the returned configuration.
pub fn create_new_client_config(
    tracing_level: Level,
    timeout_config: TimeoutConfig,
) -> ClientConfig {
    #[allow(clippy::disallowed_methods)]
    let mut config = ClientConfig::new();

    // NOTE: `tracing::Level` orders by verbosity (TRACE is greatest), so
    // `>= Level::DEBUG` means "DEBUG or more verbose".
    let level = if tracing_level >= Level::DEBUG {
        RDKafkaLogLevel::Debug
    } else if tracing_level >= Level::INFO {
        RDKafkaLogLevel::Info
    } else if tracing_level >= Level::WARN {
        RDKafkaLogLevel::Warning
    } else {
        RDKafkaLogLevel::Error
    };
    // WARNING WARNING WARNING
    //
    // For whatever reason, if you change this `target` to something else, this
    // log line might break. I (guswynn) did some extensive investigation with
    // the tracing folks, and we learned that this edge case only happens with
    // 1. a different target
    // 2. only this file (so far as I can tell)
    // 3. only in certain subscriber combinations
    // 4. only if the `tracing-log` feature is on.
    //
    // Our conclusion was that one of our dependencies is doing something
    // problematic with `log`.
    //
    // For now, this works, and prints a nice log line exactly when we want it.
    //
    // TODO(guswynn): when we can remove `tracing-log`, remove this warning
    tracing::debug!(target: "librdkafka", level = ?level, "Determined log level for librdkafka");
    config.set_log_level(level);

    // Patch the librdkafka debug log system into the Rust `log` ecosystem. This
    // is a very simple integration at the moment; enabling `debug`-level logs
    // for the `librdkafka` target enables the full firehose of librdkafka
    // debug logs. We may want to investigate finer-grained control.
    if tracing_level >= Level::DEBUG {
        tracing::debug!(target: "librdkafka", "Enabling debug logs for rdkafka");
        config.set("debug", "all");
    }

    if timeout_config.keepalive {
        config.set("socket.keepalive.enable", "true");
    }

    // Apply the (already validated/clamped) timeouts as librdkafka settings,
    // expressed in milliseconds.
    config.set(
        "socket.timeout.ms",
        timeout_config.socket_timeout.as_millis().to_string(),
    );
    config.set(
        "transaction.timeout.ms",
        timeout_config.transaction_timeout.as_millis().to_string(),
    );
    config.set(
        "socket.connection.setup.timeout.ms",
        timeout_config
            .socket_connection_setup_timeout
            .as_millis()
            .to_string(),
    );

    config
}
1093
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn test_connection_rule_pattern_matches() {
        // Each case is (prefix_wildcard, literal_match, suffix_wildcard,
        // an address the pattern should match, and one it should not).
        let cases = [
            // Exact literal match.
            (false, "broker:9092", false, "broker:9092", "other:9092"),
            // Any host, fixed port.
            (true, ":9092", false, "any-host:9092", "broker:9093"),
            // Fixed host, any port.
            (false, "broker:", true, "broker:9092", "other:9092"),
            // Substring anywhere in the address.
            (true, "broker", true, "my-broker-host:1234", "other:9092"),
        ];
        for (prefix_wildcard, literal_match, suffix_wildcard, matching, non_matching) in cases {
            let pattern = ConnectionRulePattern {
                prefix_wildcard,
                literal_match: literal_match.to_string(),
                suffix_wildcard,
            };
            assert!(pattern.matches(matching));
            assert!(!pattern.matches(non_matching));
        }
    }
}