Skip to main content

mz_postgres_client/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! A Postgres client that uses deadpool as a connection pool and comes with
11//! common/default configuration options.
12
13#![warn(missing_docs, missing_debug_implementations)]
14#![warn(
15    clippy::cast_possible_truncation,
16    clippy::cast_precision_loss,
17    clippy::cast_sign_loss,
18    clippy::clone_on_ref_ptr
19)]
20
21pub mod error;
22pub mod metrics;
23
24use std::fmt::Write;
25use std::sync::Arc;
26use std::sync::atomic::{AtomicU64, Ordering};
27use std::time::{Duration, Instant};
28
29use deadpool_postgres::tokio_postgres::Config;
30use deadpool_postgres::{
31    Hook, HookError, HookErrorCause, Manager, ManagerConfig, Object, Pool, PoolError,
32    RecyclingMethod, Runtime, Status,
33};
34use mz_ore::cast::{CastFrom, CastLossy};
35use mz_ore::now::SYSTEM_TIME;
36use mz_ore::url::SensitiveUrl;
37use tracing::debug;
38
39use crate::error::PostgresError;
40use crate::metrics::PostgresClientMetrics;
41
/// Configuration knobs for [PostgresClient].
///
/// Implementors supply the tunable values that a [PostgresClient] reads when
/// it builds its connection pool and when it decides whether to recycle a
/// connection.
pub trait PostgresClientKnobs: Send + Sync + std::fmt::Debug {
    /// Upper bound on the number of connections the pool may hold.
    fn connection_pool_max_size(&self) -> usize;
    /// Longest time a caller may wait to obtain a connection, if a limit is
    /// configured at all.
    fn connection_pool_max_wait(&self) -> Option<Duration>;
    /// Minimum TTL of a connection. Connections are expected to be culled
    /// routinely so that load on the backing store stays balanced.
    fn connection_pool_ttl(&self) -> Duration;
    /// Minimum spacing between two TTLed connections. Staggering the
    /// reconnections avoids stampeding the backing store.
    fn connection_pool_ttl_stagger(&self) -> Duration;
    /// How long to wait for a connection to be established before retrying.
    fn connect_timeout(&self) -> Duration;
    /// TCP user timeout applied to connections.
    fn tcp_user_timeout(&self) -> Duration;
    /// Idle time on a connection after which a TCP keepalive packet is sent.
    fn keepalives_idle(&self) -> Duration;
    /// Interval between successive TCP keepalive probes.
    fn keepalives_interval(&self) -> Duration;
    /// Number of TCP keepalive probes sent before the connection is dropped.
    fn keepalives_retries(&self) -> u32;
    /// Server-side `statement_timeout` to set on each connection. Zero is a
    /// sentinel meaning "do not set a statement timeout".
    fn statement_timeout(&self) -> Duration;
}
68
69/// Configuration for creating a [PostgresClient].
70#[derive(Clone, Debug)]
71pub struct PostgresClientConfig {
72    url: SensitiveUrl,
73    knobs: Arc<dyn PostgresClientKnobs>,
74    metrics: PostgresClientMetrics,
75}
76
77impl PostgresClientConfig {
78    /// Returns a new [PostgresClientConfig] for use in production.
79    pub fn new(
80        url: SensitiveUrl,
81        knobs: Arc<dyn PostgresClientKnobs>,
82        metrics: PostgresClientMetrics,
83    ) -> Self {
84        PostgresClientConfig {
85            url,
86            knobs,
87            metrics,
88        }
89    }
90}
91
92/// A Postgres client wrapper that uses deadpool as a connection pool.
93pub struct PostgresClient {
94    pool: Pool,
95    metrics: PostgresClientMetrics,
96}
97
98impl std::fmt::Debug for PostgresClient {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.debug_struct("PostgresClient").finish_non_exhaustive()
101    }
102}
103
104impl PostgresClient {
105    /// Open a [PostgresClient] using the given `config`.
106    pub fn open(config: PostgresClientConfig) -> Result<Self, PostgresError> {
107        let mut pg_config: Config = config.url.to_string_unredacted().parse()?;
108        pg_config.connect_timeout(config.knobs.connect_timeout());
109        pg_config.tcp_user_timeout(config.knobs.tcp_user_timeout());
110
111        // Configuring keepalives is important to ensure we can detect broken connections quickly.
112        // TCP_USER_TIMEOUT is not sufficient as it only enforces a timeout on ACKs for transmitted
113        // data, which only helps if we... transmit data.
114        pg_config.keepalives(true);
115        pg_config.keepalives_idle(config.knobs.keepalives_idle());
116        pg_config.keepalives_interval(config.knobs.keepalives_interval());
117        pg_config.keepalives_retries(config.knobs.keepalives_retries());
118
119        let tls = mz_tls_util::make_tls(&pg_config).map_err(|tls_err| match tls_err {
120            mz_tls_util::TlsError::Generic(e) => PostgresError::Indeterminate(e),
121            mz_tls_util::TlsError::OpenSsl(e) => PostgresError::Indeterminate(anyhow::anyhow!(e)),
122        })?;
123
124        let manager = Manager::from_config(
125            pg_config,
126            tls,
127            ManagerConfig {
128                recycling_method: RecyclingMethod::Fast,
129            },
130        );
131
132        let last_ttl_connection = AtomicU64::new(0);
133        let connections_created = config.metrics.connpool_connections_created.clone();
134        let ttl_reconnections = config.metrics.connpool_ttl_reconnections.clone();
135        let knobs = Arc::clone(&config.knobs);
136        let builder = Pool::builder(manager);
137        let builder = match config.knobs.connection_pool_max_wait() {
138            None => builder,
139            Some(wait) => builder.wait_timeout(Some(wait)).runtime(Runtime::Tokio1),
140        };
141        let pool = builder
142            .max_size(config.knobs.connection_pool_max_size())
143            .post_create(Hook::async_fn(move |client, _| {
144                connections_created.inc();
145                let knobs = Arc::clone(&knobs);
146                Box::pin(async move {
147                    debug!("opened new consensus postgres connection");
148                    let mut setup = String::from(
149                        "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL SERIALIZABLE",
150                    );
151                    // A zero `statement_timeout` is our sentinel for "leave it
152                    // unset". We only emit the `SET` when non-zero so we don't
153                    // override a timeout configured out of band.
154                    let statement_timeout = knobs.statement_timeout();
155                    if !statement_timeout.is_zero() {
156                        // A bare integer value for `statement_timeout` is
157                        // interpreted as milliseconds.
158                        write!(
159                            setup,
160                            "; SET statement_timeout = {}",
161                            statement_timeout.as_millis()
162                        )
163                        .expect("writing to a String never fails");
164                    }
165                    // This hook must return `tokio_postgres::Error`; using
166                    // `mz_postgres_util` wrappers would change the error type.
167                    #[allow(clippy::disallowed_methods)]
168                    client
169                        .batch_execute(&setup)
170                        .await
171                        .map_err(|e| HookError::Abort(HookErrorCause::Backend(e)))
172                })
173            }))
174            .pre_recycle(Hook::sync_fn(move |_client, conn_metrics| {
175                // proactively TTL connections to rebalance load to Postgres/CRDB. this helps
176                // fix skew when downstream DB operations (e.g. CRDB rolling restart) result
177                // in uneven load to each node, and works to reduce the # of connections
178                // maintained by the pool after bursty workloads.
179
180                // add a bias towards TTLing older connections first
181                if conn_metrics.age() < config.knobs.connection_pool_ttl() {
182                    return Ok(());
183                }
184
185                let last_ttl = last_ttl_connection.load(Ordering::SeqCst);
186                let now = (SYSTEM_TIME)();
187                let elapsed_since_last_ttl = Duration::from_millis(now.saturating_sub(last_ttl));
188
189                // stagger out reconnections to avoid stampeding the DB
190                if elapsed_since_last_ttl > config.knobs.connection_pool_ttl_stagger()
191                    && last_ttl_connection
192                        .compare_exchange_weak(last_ttl, now, Ordering::SeqCst, Ordering::SeqCst)
193                        .is_ok()
194                {
195                    ttl_reconnections.inc();
196                    return Err(HookError::Continue(Some(HookErrorCause::Message(
197                        "connection has been TTLed".to_string(),
198                    ))));
199                }
200
201                Ok(())
202            }))
203            .build()
204            .expect("postgres connection pool built with incorrect parameters");
205
206        Ok(PostgresClient {
207            pool,
208            metrics: config.metrics,
209        })
210    }
211
212    fn status_metrics(&self, status: Status) {
213        self.metrics
214            .connpool_available
215            .set(f64::cast_lossy(status.available));
216        self.metrics.connpool_size.set(u64::cast_from(status.size));
217        // Don't bother reporting the maximum size of the pool... we know that from config.
218    }
219
220    /// Gets connection from the pool or waits for one to become available.
221    pub async fn get_connection(&self) -> Result<Object, PoolError> {
222        let start = Instant::now();
223        // note that getting the pool size here requires briefly locking the pool
224        self.status_metrics(self.pool.status());
225        let res = self.pool.get().await;
226        if let Err(PoolError::Backend(err)) = &res {
227            debug!("error establishing connection: {}", err);
228            self.metrics.connpool_connection_errors.inc();
229        }
230        self.metrics
231            .connpool_acquire_seconds
232            .inc_by(start.elapsed().as_secs_f64());
233        self.metrics.connpool_acquires.inc();
234        self.status_metrics(self.pool.status());
235        res
236    }
237}