deadpool_postgres/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![deny(
4    nonstandard_style,
5    rust_2018_idioms,
6    rustdoc::broken_intra_doc_links,
7    rustdoc::private_intra_doc_links
8)]
9#![forbid(non_ascii_idents, unsafe_code)]
10#![warn(
11    deprecated_in_future,
12    missing_copy_implementations,
13    missing_debug_implementations,
14    missing_docs,
15    unreachable_pub,
16    unused_import_braces,
17    unused_labels,
18    unused_lifetimes,
19    unused_qualifications,
20    unused_results
21)]
22
23mod config;
24
25use std::{
26    borrow::Cow,
27    collections::HashMap,
28    fmt,
29    ops::{Deref, DerefMut},
30    sync::{
31        atomic::{AtomicUsize, Ordering},
32        Arc, Mutex, RwLock, Weak,
33    },
34};
35
36use deadpool::{async_trait, managed};
37use tokio::spawn;
38use tokio_postgres::{
39    tls::MakeTlsConnect, tls::TlsConnect, types::Type, Client as PgClient, Config as PgConfig,
40    Error, IsolationLevel, Socket, Statement, Transaction as PgTransaction,
41    TransactionBuilder as PgTransactionBuilder,
42};
43
44pub use tokio_postgres;
45
46pub use self::config::{
47    ChannelBinding, Config, ConfigError, ManagerConfig, RecyclingMethod, SslMode,
48    TargetSessionAttrs,
49};
50
51pub use deadpool::managed::reexports::*;
52deadpool::managed_reexports!(
53    "tokio_postgres",
54    Manager,
55    deadpool::managed::Object<Manager>,
56    Error,
57    ConfigError
58);
59
60/// Type alias for [`Object`]
61pub type Client = Object;
62
63type RecycleResult = deadpool::managed::RecycleResult<Error>;
64type RecycleError = deadpool::managed::RecycleError<Error>;
65
66/// [`Manager`] for creating and recycling PostgreSQL connections.
67///
68/// [`Manager`]: managed::Manager
69pub struct Manager {
70    config: ManagerConfig,
71    pg_config: PgConfig,
72    connect: Box<dyn Connect>,
73    /// [`StatementCaches`] of [`Client`]s handed out by the [`Pool`].
74    pub statement_caches: StatementCaches,
75}
76
77impl Manager {
78    /// Creates a new [`Manager`] using the given [`tokio_postgres::Config`] and
79    /// `tls` connector.
80    pub fn new<T>(pg_config: tokio_postgres::Config, tls: T) -> Self
81    where
82        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
83        T::Stream: Sync + Send,
84        T::TlsConnect: Sync + Send,
85        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
86    {
87        Self::from_config(pg_config, tls, ManagerConfig::default())
88    }
89
90    /// Create a new [`Manager`] using the given [`tokio_postgres::Config`], and
91    /// `tls` connector and [`ManagerConfig`].
92    pub fn from_config<T>(pg_config: tokio_postgres::Config, tls: T, config: ManagerConfig) -> Self
93    where
94        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
95        T::Stream: Sync + Send,
96        T::TlsConnect: Sync + Send,
97        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
98    {
99        Self {
100            config,
101            pg_config,
102            connect: Box::new(ConnectImpl { tls }),
103            statement_caches: StatementCaches::default(),
104        }
105    }
106}
107
108impl fmt::Debug for Manager {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        f.debug_struct("Manager")
111            .field("config", &self.config)
112            .field("pg_config", &self.pg_config)
113            //.field("connect", &self.connect)
114            .field("statement_caches", &self.statement_caches)
115            .finish()
116    }
117}
118
119#[async_trait]
120impl managed::Manager for Manager {
121    type Type = ClientWrapper;
122    type Error = Error;
123
124    async fn create(&self) -> Result<ClientWrapper, Error> {
125        let client = self.connect.connect(&self.pg_config).await?;
126        let client_wrapper = ClientWrapper::new(client);
127        self.statement_caches
128            .attach(&client_wrapper.statement_cache);
129        Ok(client_wrapper)
130    }
131
132    async fn recycle(&self, client: &mut ClientWrapper) -> RecycleResult {
133        if client.is_closed() {
134            log::info!(target: "deadpool.postgres", "Connection could not be recycled: Connection closed");
135            return Err(RecycleError::StaticMessage("Connection closed"));
136        }
137        match self.config.recycling_method.query() {
138            Some(sql) => match client.simple_query(sql).await {
139                Ok(_) => Ok(()),
140                Err(e) => {
141                    log::info!(target: "deadpool.postgres", "Connection could not be recycled: {}", e);
142                    Err(e.into())
143                }
144            },
145            None => Ok(()),
146        }
147    }
148
149    fn detach(&self, object: &mut ClientWrapper) {
150        self.statement_caches.detach(&object.statement_cache);
151    }
152}
153
154#[async_trait]
155trait Connect: Sync + Send {
156    async fn connect(&self, pg_config: &PgConfig) -> Result<PgClient, Error>;
157}
158
/// [`Connect`] implementation backed by a cloneable TLS connector `T`.
///
/// The trait bounds live only on the [`Connect`] impl, not on the struct, so
/// the type itself places no requirements on `T`.
struct ConnectImpl<T> {
    /// TLS connector; a clone of it is passed to every new connection.
    tls: T,
}
168
169#[async_trait]
170impl<T> Connect for ConnectImpl<T>
171where
172    T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
173    T::Stream: Sync + Send,
174    T::TlsConnect: Sync + Send,
175    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
176{
177    async fn connect(&self, pg_config: &PgConfig) -> Result<PgClient, Error> {
178        let (client, connection) = pg_config.connect(self.tls.clone()).await?;
179        drop(spawn(async move {
180            if let Err(e) = connection.await {
181                log::warn!(target: "deadpool.postgres", "Connection error: {}", e);
182            }
183        }));
184        Ok(client)
185    }
186}
187
188/// Structure holding a reference to all [`StatementCache`]s and providing
189/// access for clearing all caches and removing single statements from them.
190#[derive(Default, Debug)]
191pub struct StatementCaches {
192    caches: Mutex<Vec<Weak<StatementCache>>>,
193}
194
195impl StatementCaches {
196    fn attach(&self, cache: &Arc<StatementCache>) {
197        let cache = Arc::downgrade(cache);
198        self.caches.lock().unwrap().push(cache);
199    }
200
201    fn detach(&self, cache: &Arc<StatementCache>) {
202        let cache = Arc::downgrade(cache);
203        self.caches.lock().unwrap().retain(|sc| !sc.ptr_eq(&cache));
204    }
205
206    /// Clears [`StatementCache`] of all connections which were handed out by a
207    /// [`Manager`].
208    pub fn clear(&self) {
209        let caches = self.caches.lock().unwrap();
210        for cache in caches.iter() {
211            if let Some(cache) = cache.upgrade() {
212                cache.clear();
213            }
214        }
215    }
216
217    /// Removes statement from all caches which were handed out by a
218    /// [`Manager`].
219    pub fn remove(&self, query: &str, types: &[Type]) {
220        let caches = self.caches.lock().unwrap();
221        for cache in caches.iter() {
222            if let Some(cache) = cache.upgrade() {
223                drop(cache.remove(query, types));
224            }
225        }
226    }
227}
228
229impl fmt::Debug for StatementCache {
230    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231        f.debug_struct("ClientWrapper")
232            //.field("map", &self.map)
233            .field("size", &self.size)
234            .finish()
235    }
236}
237
238// Allows us to use owned keys in a `HashMap`, but still be able to call `get`
239// with borrowed keys instead of allocating them each time.
240#[derive(Debug, Eq, Hash, PartialEq)]
241struct StatementCacheKey<'a> {
242    query: Cow<'a, str>,
243    types: Cow<'a, [Type]>,
244}
245
246/// Representation of a cache of [`Statement`]s.
247///
248/// [`StatementCache`] is bound to one [`Client`], and [`Statement`]s generated
249/// by that [`Client`] must not be used with other [`Client`]s.
250///
251/// It can be used like that:
252/// ```rust,ignore
253/// let client = pool.get().await?;
254/// let stmt = client
255///     .statement_cache
256///     .prepare(&client, "SELECT 1")
257///     .await;
258/// let rows = client.query(stmt, &[]).await?;
259/// ...
260/// ```
261///
262/// Normally, you probably want to use the [`ClientWrapper::prepare_cached()`]
263/// and [`ClientWrapper::prepare_typed_cached()`] methods instead (or the
264/// similar ones on [`Transaction`]).
265pub struct StatementCache {
266    map: RwLock<HashMap<StatementCacheKey<'static>, Statement>>,
267    size: AtomicUsize,
268}
269
270impl StatementCache {
271    fn new() -> Self {
272        Self {
273            map: RwLock::new(HashMap::new()),
274            size: AtomicUsize::new(0),
275        }
276    }
277
278    /// Returns current size of this [`StatementCache`].
279    pub fn size(&self) -> usize {
280        self.size.load(Ordering::Relaxed)
281    }
282
283    /// Clears this [`StatementCache`].
284    ///
285    /// **Important:** This only clears the [`StatementCache`] of one [`Client`]
286    /// instance. If you want to clear the [`StatementCache`] of all [`Client`]s
287    /// you should be calling `pool.manager().statement_caches.clear()` instead.
288    pub fn clear(&self) {
289        let mut map = self.map.write().unwrap();
290        map.clear();
291        self.size.store(0, Ordering::Relaxed);
292    }
293
294    /// Removes a [`Statement`] from this [`StatementCache`].
295    ///
296    /// **Important:** This only removes a [`Statement`] from one [`Client`]
297    /// cache. If you want to remove a [`Statement`] from all
298    /// [`StatementCaches`] you should be calling
299    /// `pool.manager().statement_caches.remove()` instead.
300    pub fn remove(&self, query: &str, types: &[Type]) -> Option<Statement> {
301        let key = StatementCacheKey {
302            query: Cow::Owned(query.to_owned()),
303            types: Cow::Owned(types.to_owned()),
304        };
305        let mut map = self.map.write().unwrap();
306        let removed = map.remove(&key);
307        if removed.is_some() {
308            let _ = self.size.fetch_sub(1, Ordering::Relaxed);
309        }
310        removed
311    }
312
313    /// Returns a [`Statement`] from this [`StatementCache`].
314    fn get(&self, query: &str, types: &[Type]) -> Option<Statement> {
315        let key = StatementCacheKey {
316            query: Cow::Borrowed(query),
317            types: Cow::Borrowed(types),
318        };
319        self.map.read().unwrap().get(&key).map(ToOwned::to_owned)
320    }
321
322    /// Inserts a [`Statement`] into this [`StatementCache`].
323    fn insert(&self, query: &str, types: &[Type], stmt: Statement) {
324        let key = StatementCacheKey {
325            query: Cow::Owned(query.to_owned()),
326            types: Cow::Owned(types.to_owned()),
327        };
328        let mut map = self.map.write().unwrap();
329        if map.insert(key, stmt).is_none() {
330            let _ = self.size.fetch_add(1, Ordering::Relaxed);
331        }
332    }
333
334    /// Creates a new prepared [`Statement`] using this [`StatementCache`], if
335    /// possible.
336    ///
337    /// See [`tokio_postgres::Client::prepare()`].
338    pub async fn prepare(&self, client: &PgClient, query: &str) -> Result<Statement, Error> {
339        self.prepare_typed(client, query, &[]).await
340    }
341
342    /// Creates a new prepared [`Statement`] with specifying its [`Type`]s
343    /// explicitly using this [`StatementCache`], if possible.
344    ///
345    /// See [`tokio_postgres::Client::prepare_typed()`].
346    pub async fn prepare_typed(
347        &self,
348        client: &PgClient,
349        query: &str,
350        types: &[Type],
351    ) -> Result<Statement, Error> {
352        match self.get(query, types) {
353            Some(statement) => Ok(statement),
354            None => {
355                let stmt = client.prepare_typed(query, types).await?;
356                self.insert(query, types, stmt.clone());
357                Ok(stmt)
358            }
359        }
360    }
361}
362
363/// Wrapper around [`tokio_postgres::Client`] with a [`StatementCache`].
364#[derive(Debug)]
365pub struct ClientWrapper {
366    /// Original [`PgClient`].
367    client: PgClient,
368
369    /// [`StatementCache`] of this client.
370    pub statement_cache: Arc<StatementCache>,
371}
372
373impl ClientWrapper {
374    /// Create a new [`ClientWrapper`] instance using the given
375    /// [`tokio_postgres::Client`].
376    #[must_use]
377    pub fn new(client: PgClient) -> Self {
378        Self {
379            client,
380            statement_cache: Arc::new(StatementCache::new()),
381        }
382    }
383
384    /// Like [`tokio_postgres::Transaction::prepare()`], but uses an existing
385    /// [`Statement`] from the [`StatementCache`] if possible.
386    pub async fn prepare_cached(&self, query: &str) -> Result<Statement, Error> {
387        self.statement_cache.prepare(&self.client, query).await
388    }
389
390    /// Like [`tokio_postgres::Transaction::prepare_typed()`], but uses an
391    /// existing [`Statement`] from the [`StatementCache`] if possible.
392    pub async fn prepare_typed_cached(
393        &self,
394        query: &str,
395        types: &[Type],
396    ) -> Result<Statement, Error> {
397        self.statement_cache
398            .prepare_typed(&self.client, query, types)
399            .await
400    }
401
402    /// Like [`tokio_postgres::Client::transaction()`], but returns a wrapped
403    /// [`Transaction`] with a [`StatementCache`].
404    #[allow(unused_lifetimes)] // false positive
405    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
406        Ok(Transaction {
407            txn: PgClient::transaction(&mut self.client).await?,
408            statement_cache: self.statement_cache.clone(),
409        })
410    }
411
412    /// Like [`tokio_postgres::Client::build_transaction()`], but creates a
413    /// wrapped [`Transaction`] with a [`StatementCache`].
414    pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
415        TransactionBuilder {
416            builder: self.client.build_transaction(),
417            statement_cache: self.statement_cache.clone(),
418        }
419    }
420}
421
422impl Deref for ClientWrapper {
423    type Target = PgClient;
424
425    fn deref(&self) -> &PgClient {
426        &self.client
427    }
428}
429
430impl DerefMut for ClientWrapper {
431    fn deref_mut(&mut self) -> &mut PgClient {
432        &mut self.client
433    }
434}
435
436/// Wrapper around [`tokio_postgres::Transaction`] with a [`StatementCache`]
437/// from the [`Client`] object it was created by.
438pub struct Transaction<'a> {
439    /// Original [`PgTransaction`].
440    txn: PgTransaction<'a>,
441
442    /// [`StatementCache`] of this [`Transaction`].
443    pub statement_cache: Arc<StatementCache>,
444}
445
446impl<'a> Transaction<'a> {
447    /// Like [`tokio_postgres::Transaction::prepare()`], but uses an existing
448    /// [`Statement`] from the [`StatementCache`] if possible.
449    pub async fn prepare_cached(&self, query: &str) -> Result<Statement, Error> {
450        self.statement_cache.prepare(self.client(), query).await
451    }
452
453    /// Like [`tokio_postgres::Transaction::prepare_typed()`], but uses an
454    /// existing [`Statement`] from the [`StatementCache`] if possible.
455    pub async fn prepare_typed_cached(
456        &self,
457        query: &str,
458        types: &[Type],
459    ) -> Result<Statement, Error> {
460        self.statement_cache
461            .prepare_typed(self.client(), query, types)
462            .await
463    }
464
465    /// Like [`tokio_postgres::Transaction::commit()`].
466    pub async fn commit(self) -> Result<(), Error> {
467        self.txn.commit().await
468    }
469
470    /// Like [`tokio_postgres::Transaction::rollback()`].
471    pub async fn rollback(self) -> Result<(), Error> {
472        self.txn.rollback().await
473    }
474
475    /// Like [`tokio_postgres::Transaction::transaction()`], but returns a
476    /// wrapped [`Transaction`] with a [`StatementCache`].
477    #[allow(unused_lifetimes)] // false positive
478    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
479        Ok(Transaction {
480            txn: PgTransaction::transaction(&mut self.txn).await?,
481            statement_cache: self.statement_cache.clone(),
482        })
483    }
484
485    /// Like [`tokio_postgres::Transaction::savepoint()`], but returns a wrapped
486    /// [`Transaction`] with a [`StatementCache`].
487    #[allow(unused_lifetimes)] // false positive
488    pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
489    where
490        I: Into<String>,
491    {
492        Ok(Transaction {
493            txn: PgTransaction::savepoint(&mut self.txn, name).await?,
494            statement_cache: self.statement_cache.clone(),
495        })
496    }
497}
498
499impl<'a> fmt::Debug for Transaction<'a> {
500    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
501        f.debug_struct("Transaction")
502            //.field("txn", &self.txn)
503            .field("statement_cache", &self.statement_cache)
504            .finish()
505    }
506}
507
508impl<'a> Deref for Transaction<'a> {
509    type Target = PgTransaction<'a>;
510
511    fn deref(&self) -> &PgTransaction<'a> {
512        &self.txn
513    }
514}
515
516impl<'a> DerefMut for Transaction<'a> {
517    fn deref_mut(&mut self) -> &mut PgTransaction<'a> {
518        &mut self.txn
519    }
520}
521
522/// Wrapper around [`tokio_postgres::TransactionBuilder`] with a
523/// [`StatementCache`] from the [`Client`] object it was created by.
524#[must_use = "builder does nothing itself, use `.start()` to use it"]
525pub struct TransactionBuilder<'a> {
526    /// Original [`PgTransactionBuilder`].
527    builder: PgTransactionBuilder<'a>,
528
529    /// [`StatementCache`] of this [`TransactionBuilder`].
530    statement_cache: Arc<StatementCache>,
531}
532
533impl<'a> TransactionBuilder<'a> {
534    /// Sets the isolation level of the transaction.
535    ///
536    /// Like [`tokio_postgres::TransactionBuilder::isolation_level()`].
537    pub fn isolation_level(self, isolation_level: IsolationLevel) -> Self {
538        Self {
539            builder: self.builder.isolation_level(isolation_level),
540            statement_cache: self.statement_cache,
541        }
542    }
543
544    /// Sets the access mode of the transaction.
545    ///
546    /// Like [`tokio_postgres::TransactionBuilder::read_only()`].
547    pub fn read_only(self, read_only: bool) -> Self {
548        Self {
549            builder: self.builder.read_only(read_only),
550            statement_cache: self.statement_cache,
551        }
552    }
553
554    /// Sets the deferrability of the transaction.
555    ///
556    /// If the transaction is also serializable and read only, creation
557    /// of the transaction may block, but when it completes the transaction
558    /// is able to run with less overhead and a guarantee that it will not
559    /// be aborted due to serialization failure.
560    ///
561    /// Like [`tokio_postgres::TransactionBuilder::deferrable()`].
562    pub fn deferrable(self, deferrable: bool) -> Self {
563        Self {
564            builder: self.builder.deferrable(deferrable),
565            statement_cache: self.statement_cache,
566        }
567    }
568
569    /// Begins the [`Transaction`].
570    ///
571    /// The transaction will roll back by default - use the commit method
572    /// to commit it.
573    ///
574    /// Like [`tokio_postgres::TransactionBuilder::start()`].
575    pub async fn start(self) -> Result<Transaction<'a>, Error> {
576        Ok(Transaction {
577            txn: self.builder.start().await?,
578            statement_cache: self.statement_cache,
579        })
580    }
581}
582
583impl<'a> fmt::Debug for TransactionBuilder<'a> {
584    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
585        f.debug_struct("TransactionBuilder")
586            //.field("builder", &self.builder)
587            .field("statement_cache", &self.statement_cache)
588            .finish()
589    }
590}
591
592impl<'a> Deref for TransactionBuilder<'a> {
593    type Target = PgTransactionBuilder<'a>;
594
595    fn deref(&self) -> &Self::Target {
596        &self.builder
597    }
598}
599
600impl<'a> DerefMut for TransactionBuilder<'a> {
601    fn deref_mut(&mut self) -> &mut Self::Target {
602        &mut self.builder
603    }
604}