deadpool_postgres/
config.rs

1//! Configuration used for [`Pool`] creation.
2
3use std::{env, fmt, time::Duration};
4
5#[cfg(feature = "serde")]
6use serde_1 as serde;
7use tokio_postgres::{
8    config::{
9        ChannelBinding as PgChannelBinding, SslMode as PgSslMode,
10        TargetSessionAttrs as PgTargetSessionAttrs,
11    },
12    tls::{MakeTlsConnect, TlsConnect},
13    Socket,
14};
15
16use crate::{CreatePoolError, PoolBuilder, Runtime};
17
18use super::{Pool, PoolConfig};
19
20/// Configuration object.
21///
22/// # Example (from environment)
23///
24/// By enabling the `serde` feature you can read the configuration using the
25/// [`config`](https://crates.io/crates/config) crate as following:
26/// ```env
27/// PG__HOST=pg.example.com
28/// PG__USER=john_doe
29/// PG__PASSWORD=topsecret
30/// PG__DBNAME=example
31/// PG__POOL__MAX_SIZE=16
32/// PG__POOL__TIMEOUTS__WAIT__SECS=5
33/// PG__POOL__TIMEOUTS__WAIT__NANOS=0
34/// ```
35/// ```rust
36/// # use serde_1 as serde;
37/// #
38/// #[derive(serde::Deserialize, serde::Serialize)]
39/// # #[serde(crate = "serde_1")]
40/// struct Config {
41///     pg: deadpool_postgres::Config,
42/// }
43/// impl Config {
44///     pub fn from_env() -> Result<Self, config::ConfigError> {
45///         let mut cfg = config::Config::builder()
46///            .add_source(config::Environment::default().separator("__"))
47///            .build()?;
48///            cfg.try_deserialize()
49///     }
50/// }
51/// ```
52#[derive(Clone, Debug, Default)]
53#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
54#[cfg_attr(feature = "serde", serde(crate = "serde_1"))]
55pub struct Config {
56    /// See [`tokio_postgres::Config::user`].
57    pub user: Option<String>,
58    /// See [`tokio_postgres::Config::password`].
59    pub password: Option<String>,
60    /// See [`tokio_postgres::Config::dbname`].
61    pub dbname: Option<String>,
62    /// See [`tokio_postgres::Config::options`].
63    pub options: Option<String>,
64    /// See [`tokio_postgres::Config::application_name`].
65    pub application_name: Option<String>,
66    /// See [`tokio_postgres::Config::ssl_mode`].
67    pub ssl_mode: Option<SslMode>,
68    /// This is similar to [`Config::hosts`] but only allows one host to be
69    /// specified.
70    ///
71    /// Unlike [`tokio_postgres::Config`] this structure differentiates between
72    /// one host and more than one host. This makes it possible to store this
73    /// configuration in an environment variable.
74    ///
75    /// See [`tokio_postgres::Config::host`].
76    pub host: Option<String>,
77    /// See [`tokio_postgres::Config::host`].
78    pub hosts: Option<Vec<String>>,
79    /// This is similar to [`Config::ports`] but only allows one port to be
80    /// specified.
81    ///
82    /// Unlike [`tokio_postgres::Config`] this structure differentiates between
83    /// one port and more than one port. This makes it possible to store this
84    /// configuration in an environment variable.
85    ///
86    /// See [`tokio_postgres::Config::port`].
87    pub port: Option<u16>,
88    /// See [`tokio_postgres::Config::port`].
89    pub ports: Option<Vec<u16>>,
90    /// See [`tokio_postgres::Config::connect_timeout`].
91    pub connect_timeout: Option<Duration>,
92    /// See [`tokio_postgres::Config::keepalives`].
93    pub keepalives: Option<bool>,
94    /// See [`tokio_postgres::Config::keepalives_idle`].
95    pub keepalives_idle: Option<Duration>,
96    /// See [`tokio_postgres::Config::target_session_attrs`].
97    pub target_session_attrs: Option<TargetSessionAttrs>,
98    /// See [`tokio_postgres::Config::channel_binding`].
99    pub channel_binding: Option<ChannelBinding>,
100
101    /// [`Manager`] configuration.
102    ///
103    /// [`Manager`]: super::Manager
104    pub manager: Option<ManagerConfig>,
105
106    /// [`Pool`] configuration.
107    pub pool: Option<PoolConfig>,
108}
109
/// This error is returned if there is something wrong with the configuration.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ConfigError {
    /// This variant is returned if the `dbname` is missing from the config.
    DbnameMissing,
    /// This variant is returned if the `dbname` contains an empty string.
    DbnameEmpty,
}

impl fmt::Display for ConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::DbnameMissing => f.write_str("configuration property \"dbname\" not found"),
            Self::DbnameEmpty => {
                f.write_str("configuration property \"dbname\" contains an empty string")
            }
        }
    }
}

// `ConfigError` carries no underlying cause, so the default `source()`
// (returning `None`) is correct.
impl std::error::Error for ConfigError {}
132
133impl Config {
134    /// Create a new [`Config`] instance with default values. This function is
135    /// identical to [`Config::default()`].
136    #[must_use]
137    pub fn new() -> Self {
138        Self::default()
139    }
140
141    /// Creates a new [`Pool`] using this [`Config`].
142    ///
143    /// # Errors
144    ///
145    /// See [`CreatePoolError`] for details.
146    pub fn create_pool<T>(&self, runtime: Option<Runtime>, tls: T) -> Result<Pool, CreatePoolError>
147    where
148        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
149        T::Stream: Sync + Send,
150        T::TlsConnect: Sync + Send,
151        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
152    {
153        let mut builder = self.builder(tls).map_err(CreatePoolError::Config)?;
154        if let Some(runtime) = runtime {
155            builder = builder.runtime(runtime);
156        }
157        builder.build().map_err(CreatePoolError::Build)
158    }
159
160    /// Creates a new [`PoolBuilder`] using this [`Config`].
161    ///
162    /// # Errors
163    ///
164    /// See [`ConfigError`] and [`tokio_postgres::Error`] for details.
165    pub fn builder<T>(&self, tls: T) -> Result<PoolBuilder, ConfigError>
166    where
167        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
168        T::Stream: Sync + Send,
169        T::TlsConnect: Sync + Send,
170        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
171    {
172        let pg_config = self.get_pg_config()?;
173        let manager_config = self.get_manager_config();
174        let manager = crate::Manager::from_config(pg_config, tls, manager_config);
175        let pool_config = self.get_pool_config();
176        Ok(Pool::builder(manager).config(pool_config))
177    }
178
179    /// Returns [`tokio_postgres::Config`] which can be used to connect to
180    /// the database.
181    #[allow(unused_results)]
182    pub fn get_pg_config(&self) -> Result<tokio_postgres::Config, ConfigError> {
183        let mut cfg = tokio_postgres::Config::new();
184        if let Some(user) = &self.user {
185            cfg.user(user.as_str());
186        } else if let Ok(user) = env::var("USER") {
187            cfg.user(user.as_str());
188        }
189        if let Some(password) = &self.password {
190            cfg.password(password);
191        }
192        match &self.dbname {
193            Some(dbname) => match dbname.as_str() {
194                "" => return Err(ConfigError::DbnameMissing),
195                dbname => cfg.dbname(dbname),
196            },
197            None => return Err(ConfigError::DbnameEmpty),
198        };
199        if let Some(options) = &self.options {
200            cfg.options(options.as_str());
201        }
202        if let Some(application_name) = &self.application_name {
203            cfg.application_name(application_name.as_str());
204        }
205        if let Some(host) = &self.host {
206            cfg.host(host.as_str());
207        }
208        if let Some(hosts) = &self.hosts {
209            for host in hosts.iter() {
210                cfg.host(host.as_str());
211            }
212        }
213        if self.host.is_none() && self.hosts.is_none() {
214            // Systems that support it default to unix domain sockets.
215            #[cfg(unix)]
216            {
217                cfg.host_path("/run/postgresql");
218                cfg.host_path("/var/run/postgresql");
219                cfg.host_path("/tmp");
220            }
221            // Windows and other systems use 127.0.0.1 instead.
222            #[cfg(not(unix))]
223            cfg.host("127.0.0.1");
224        }
225        if let Some(port) = self.port {
226            cfg.port(port);
227        }
228        if let Some(ports) = &self.ports {
229            for port in ports.iter() {
230                cfg.port(*port);
231            }
232        }
233        if let Some(connect_timeout) = self.connect_timeout {
234            cfg.connect_timeout(connect_timeout);
235        }
236        if let Some(keepalives) = self.keepalives {
237            cfg.keepalives(keepalives);
238        }
239        if let Some(keepalives_idle) = self.keepalives_idle {
240            cfg.keepalives_idle(keepalives_idle);
241        }
242        if let Some(mode) = self.ssl_mode {
243            cfg.ssl_mode(mode.into());
244        }
245        Ok(cfg)
246    }
247
248    /// Returns [`ManagerConfig`] which can be used to construct a
249    /// [`deadpool::managed::Pool`] instance.
250    #[must_use]
251    pub fn get_manager_config(&self) -> ManagerConfig {
252        self.manager.clone().unwrap_or_default()
253    }
254
255    /// Returns [`deadpool::managed::PoolConfig`] which can be used to construct
256    /// a [`deadpool::managed::Pool`] instance.
257    #[must_use]
258    pub fn get_pool_config(&self) -> PoolConfig {
259        self.pool.unwrap_or_default()
260    }
261}
262
/// Possible methods of how a connection is recycled.
///
/// The default is [`Fast`]. If your application depends on a particular
/// recycling method, set it explicitly in your [`ManagerConfig`] instead of
/// relying on the default.
///
/// [`Fast`]: RecyclingMethod::Fast
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_1"))]
pub enum RecyclingMethod {
    /// Only run [`Client::is_closed()`][1] when recycling existing connections.
    ///
    /// Unless you have special needs this is a safe choice.
    ///
    /// [1]: tokio_postgres::Client::is_closed
    Fast,

    /// Run [`Client::is_closed()`][1] and execute a test query.
    ///
    /// This is slower, but checks that the database connection is actually
    /// usable before it is handed out. Normally, [`Client::is_closed()`][1]
    /// should be enough to filter out bad connections, but under some
    /// circumstances (e.g. hard-closed network connections) it's possible that
    /// [`Client::is_closed()`][1] returns `false` while the connection is
    /// dead. Without a test query you will then receive an error on your first
    /// query.
    ///
    /// [1]: tokio_postgres::Client::is_closed
    Verified,

    /// Like [`Verified`], but instead of the test query execute the following
    /// sequence of statements, which resets the session to a pristine state:
    /// ```sql
    /// CLOSE ALL;
    /// SET SESSION AUTHORIZATION DEFAULT;
    /// RESET ALL;
    /// UNLISTEN *;
    /// SELECT pg_advisory_unlock_all();
    /// DISCARD TEMP;
    /// DISCARD SEQUENCES;
    /// ```
    ///
    /// This is similar to calling `DISCARD ALL`, but doesn't call
    /// `DEALLOCATE ALL` and `DISCARD PLANS`, so that the statement cache is
    /// not rendered ineffective.
    ///
    /// [`Verified`]: RecyclingMethod::Verified
    Clean,

    /// Like [`Verified`], but executes the given custom SQL instead of the
    /// test query.
    ///
    /// [`Verified`]: RecyclingMethod::Verified
    Custom(String),
}

impl Default for RecyclingMethod {
    fn default() -> Self {
        Self::Fast
    }
}

impl RecyclingMethod {
    /// SQL used by [`RecyclingMethod::Clean`]: the statements of `DISCARD ALL`
    /// minus `DEALLOCATE ALL` and `DISCARD PLANS`.
    const DISCARD_SQL: &'static str = "\
        CLOSE ALL; \
        SET SESSION AUTHORIZATION DEFAULT; \
        RESET ALL; \
        UNLISTEN *; \
        SELECT pg_advisory_unlock_all(); \
        DISCARD TEMP; \
        DISCARD SEQUENCES;\
    ";

    /// Returns SQL query to be executed when recycling a connection.
    ///
    /// - [`RecyclingMethod::Fast`]: `None`, meaning no query is executed.
    /// - [`RecyclingMethod::Verified`]: an empty query string.
    /// - [`RecyclingMethod::Clean`]: the statement sequence documented on
    ///   that variant.
    /// - [`RecyclingMethod::Custom`]: the user supplied SQL.
    #[must_use]
    pub fn query(&self) -> Option<&str> {
        match self {
            Self::Fast => None,
            Self::Verified => Some(""),
            Self::Clean => Some(Self::DISCARD_SQL),
            Self::Custom(sql) => Some(sql),
        }
    }
}
347
348/// Configuration object for a [`Manager`].
349///
350/// This currently only makes it possible to specify which [`RecyclingMethod`]
351/// should be used when retrieving existing objects from the [`Pool`].
352///
353/// [`Manager`]: super::Manager
354#[derive(Clone, Debug, Default)]
355#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
356#[cfg_attr(feature = "serde", serde(crate = "serde_1"))]
357pub struct ManagerConfig {
358    /// Method of how a connection is recycled. See [`RecyclingMethod`].
359    pub recycling_method: RecyclingMethod,
360}
361
362/// Properties required of a session.
363///
364/// This is a 1:1 copy of the [`PgTargetSessionAttrs`] enumeration.
365/// This is duplicated here in order to add support for the
366/// [`serde::Deserialize`] trait which is required for the [`serde`] support.
367#[derive(Clone, Copy, Debug, Eq, PartialEq)]
368#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
369#[cfg_attr(feature = "serde", serde(crate = "serde_1"))]
370#[non_exhaustive]
371pub enum TargetSessionAttrs {
372    /// No special properties are required.
373    Any,
374
375    /// The session must allow writes.
376    ReadWrite,
377}
378
379impl From<TargetSessionAttrs> for PgTargetSessionAttrs {
380    fn from(attrs: TargetSessionAttrs) -> Self {
381        match attrs {
382            TargetSessionAttrs::Any => Self::Any,
383            TargetSessionAttrs::ReadWrite => Self::ReadWrite,
384        }
385    }
386}
387
388/// TLS configuration.
389///
390/// This is a 1:1 copy of the [`PgSslMode`] enumeration.
391/// This is duplicated here in order to add support for the
392/// [`serde::Deserialize`] trait which is required for the [`serde`] support.
393#[derive(Clone, Copy, Debug, Eq, PartialEq)]
394#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
395#[cfg_attr(feature = "serde", serde(crate = "serde_1"))]
396#[non_exhaustive]
397pub enum SslMode {
398    /// Do not use TLS.
399    Disable,
400
401    /// Attempt to connect with TLS but allow sessions without.
402    Prefer,
403
404    /// Require the use of TLS.
405    Require,
406}
407
408impl From<SslMode> for PgSslMode {
409    fn from(mode: SslMode) -> Self {
410        match mode {
411            SslMode::Disable => Self::Disable,
412            SslMode::Prefer => Self::Prefer,
413            SslMode::Require => Self::Require,
414        }
415    }
416}
417
418/// Channel binding configuration.
419///
420/// This is a 1:1 copy of the [`PgChannelBinding`] enumeration.
421/// This is duplicated here in order to add support for the
422/// [`serde::Deserialize`] trait which is required for the [`serde`] support.
423#[derive(Clone, Copy, Debug, Eq, PartialEq)]
424#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
425#[cfg_attr(feature = "serde", serde(crate = "serde_1"))]
426#[non_exhaustive]
427pub enum ChannelBinding {
428    /// Do not use channel binding.
429    Disable,
430
431    /// Attempt to use channel binding but allow sessions without.
432    Prefer,
433
434    /// Require the use of channel binding.
435    Require,
436}
437
438impl From<ChannelBinding> for PgChannelBinding {
439    fn from(cb: ChannelBinding) -> Self {
440        match cb {
441            ChannelBinding::Disable => Self::Disable,
442            ChannelBinding::Prefer => Self::Prefer,
443            ChannelBinding::Require => Self::Require,
444        }
445    }
446}