1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![deny(
4 nonstandard_style,
5 rust_2018_idioms,
6 rustdoc::broken_intra_doc_links,
7 rustdoc::private_intra_doc_links
8)]
9#![forbid(non_ascii_idents, unsafe_code)]
10#![warn(
11 deprecated_in_future,
12 missing_copy_implementations,
13 missing_debug_implementations,
14 missing_docs,
15 unreachable_pub,
16 unused_import_braces,
17 unused_labels,
18 unused_lifetimes,
19 unused_qualifications,
20 unused_results
21)]
22
23mod config;
24
25use std::{
26 borrow::Cow,
27 collections::HashMap,
28 fmt,
29 ops::{Deref, DerefMut},
30 sync::{
31 atomic::{AtomicUsize, Ordering},
32 Arc, Mutex, RwLock, Weak,
33 },
34};
35
36use deadpool::{async_trait, managed};
37use tokio::spawn;
38use tokio_postgres::{
39 tls::MakeTlsConnect, tls::TlsConnect, types::Type, Client as PgClient, Config as PgConfig,
40 Error, IsolationLevel, Socket, Statement, Transaction as PgTransaction,
41 TransactionBuilder as PgTransactionBuilder,
42};
43
44pub use tokio_postgres;
45
46pub use self::config::{
47 ChannelBinding, Config, ConfigError, ManagerConfig, RecyclingMethod, SslMode,
48 TargetSessionAttrs,
49};
50
51pub use deadpool::managed::reexports::*;
52deadpool::managed_reexports!(
53 "tokio_postgres",
54 Manager,
55 deadpool::managed::Object<Manager>,
56 Error,
57 ConfigError
58);
59
/// Type alias for the pooled object type produced by this crate's `Pool`
/// (the `Object` alias generated by `managed_reexports!` above).
pub type Client = Object;

// Shorthands for deadpool's recycle result/error types, specialised to this
// crate's error type (`tokio_postgres::Error`).
type RecycleResult = deadpool::managed::RecycleResult<Error>;
type RecycleError = deadpool::managed::RecycleError<Error>;
65
/// Manager for creating and recycling PostgreSQL connections.
///
/// New connections are opened through a type-erased TLS connector, and each
/// connection's statement cache is registered in `statement_caches`.
pub struct Manager {
    // Settings that control recycling (see `recycle`).
    config: ManagerConfig,
    // Connection parameters passed to `tokio_postgres` when connecting.
    pg_config: PgConfig,
    // Boxed connector so `Manager` need not be generic over the TLS type.
    connect: Box<dyn Connect>,
    /// Registry of the statement caches of all connections created by this
    /// manager, used to clear or evict statements across all of them.
    pub statement_caches: StatementCaches,
}
76
77impl Manager {
78 pub fn new<T>(pg_config: tokio_postgres::Config, tls: T) -> Self
81 where
82 T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
83 T::Stream: Sync + Send,
84 T::TlsConnect: Sync + Send,
85 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
86 {
87 Self::from_config(pg_config, tls, ManagerConfig::default())
88 }
89
90 pub fn from_config<T>(pg_config: tokio_postgres::Config, tls: T, config: ManagerConfig) -> Self
93 where
94 T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
95 T::Stream: Sync + Send,
96 T::TlsConnect: Sync + Send,
97 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
98 {
99 Self {
100 config,
101 pg_config,
102 connect: Box::new(ConnectImpl { tls }),
103 statement_caches: StatementCaches::default(),
104 }
105 }
106}
107
108impl fmt::Debug for Manager {
109 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110 f.debug_struct("Manager")
111 .field("config", &self.config)
112 .field("pg_config", &self.pg_config)
113 .field("statement_caches", &self.statement_caches)
115 .finish()
116 }
117}
118
119#[async_trait]
120impl managed::Manager for Manager {
121 type Type = ClientWrapper;
122 type Error = Error;
123
124 async fn create(&self) -> Result<ClientWrapper, Error> {
125 let client = self.connect.connect(&self.pg_config).await?;
126 let client_wrapper = ClientWrapper::new(client);
127 self.statement_caches
128 .attach(&client_wrapper.statement_cache);
129 Ok(client_wrapper)
130 }
131
132 async fn recycle(&self, client: &mut ClientWrapper) -> RecycleResult {
133 if client.is_closed() {
134 log::info!(target: "deadpool.postgres", "Connection could not be recycled: Connection closed");
135 return Err(RecycleError::StaticMessage("Connection closed"));
136 }
137 match self.config.recycling_method.query() {
138 Some(sql) => match client.simple_query(sql).await {
139 Ok(_) => Ok(()),
140 Err(e) => {
141 log::info!(target: "deadpool.postgres", "Connection could not be recycled: {}", e);
142 Err(e.into())
143 }
144 },
145 None => Ok(()),
146 }
147 }
148
149 fn detach(&self, object: &mut ClientWrapper) {
150 self.statement_caches.detach(&object.statement_cache);
151 }
152}
153
/// Internal abstraction over opening a new `tokio_postgres` connection.
///
/// It lets `Manager` store its TLS connector type-erased as
/// `Box<dyn Connect>`.
#[async_trait]
trait Connect: Sync + Send {
    /// Connects using `pg_config` and returns the resulting client.
    async fn connect(&self, pg_config: &PgConfig) -> Result<PgClient, Error>;
}
158
/// `Connect` implementation backed by a `MakeTlsConnect` TLS connector.
///
/// The TLS trait bounds are required only by the `Connect` impl, not by the
/// data layout, so they live on that impl instead of being repeated on the
/// struct definition.
struct ConnectImpl<T> {
    /// TLS connector; a clone of it is used for every new connection.
    tls: T,
}
168
169#[async_trait]
170impl<T> Connect for ConnectImpl<T>
171where
172 T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
173 T::Stream: Sync + Send,
174 T::TlsConnect: Sync + Send,
175 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
176{
177 async fn connect(&self, pg_config: &PgConfig) -> Result<PgClient, Error> {
178 let (client, connection) = pg_config.connect(self.tls.clone()).await?;
179 drop(spawn(async move {
180 if let Err(e) = connection.await {
181 log::warn!(target: "deadpool.postgres", "Connection error: {}", e);
182 }
183 }));
184 Ok(client)
185 }
186}
187
188#[derive(Default, Debug)]
191pub struct StatementCaches {
192 caches: Mutex<Vec<Weak<StatementCache>>>,
193}
194
195impl StatementCaches {
196 fn attach(&self, cache: &Arc<StatementCache>) {
197 let cache = Arc::downgrade(cache);
198 self.caches.lock().unwrap().push(cache);
199 }
200
201 fn detach(&self, cache: &Arc<StatementCache>) {
202 let cache = Arc::downgrade(cache);
203 self.caches.lock().unwrap().retain(|sc| !sc.ptr_eq(&cache));
204 }
205
206 pub fn clear(&self) {
209 let caches = self.caches.lock().unwrap();
210 for cache in caches.iter() {
211 if let Some(cache) = cache.upgrade() {
212 cache.clear();
213 }
214 }
215 }
216
217 pub fn remove(&self, query: &str, types: &[Type]) {
220 let caches = self.caches.lock().unwrap();
221 for cache in caches.iter() {
222 if let Some(cache) = cache.upgrade() {
223 drop(cache.remove(query, types));
224 }
225 }
226 }
227}
228
229impl fmt::Debug for StatementCache {
230 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231 f.debug_struct("ClientWrapper")
232 .field("size", &self.size)
234 .finish()
235 }
236}
237
/// Key under which prepared statements are cached: the SQL text together
/// with the requested parameter types.
///
/// Both fields are `Cow`s so that lookups (`get`) can use borrowed data,
/// while keys stored in the map (`StatementCacheKey<'static>`) own theirs.
#[derive(Debug, Eq, Hash, PartialEq)]
struct StatementCacheKey<'a> {
    query: Cow<'a, str>,
    types: Cow<'a, [Type]>,
}
245
/// Cache of prepared statements belonging to a single connection.
pub struct StatementCache {
    // Cached statements, keyed by query text and parameter types.
    map: RwLock<HashMap<StatementCacheKey<'static>, Statement>>,
    // Number of entries in `map`. It is updated while the write lock is held,
    // and read lock-free by `size()`.
    size: AtomicUsize,
}
269
270impl StatementCache {
271 fn new() -> Self {
272 Self {
273 map: RwLock::new(HashMap::new()),
274 size: AtomicUsize::new(0),
275 }
276 }
277
278 pub fn size(&self) -> usize {
280 self.size.load(Ordering::Relaxed)
281 }
282
283 pub fn clear(&self) {
289 let mut map = self.map.write().unwrap();
290 map.clear();
291 self.size.store(0, Ordering::Relaxed);
292 }
293
294 pub fn remove(&self, query: &str, types: &[Type]) -> Option<Statement> {
301 let key = StatementCacheKey {
302 query: Cow::Owned(query.to_owned()),
303 types: Cow::Owned(types.to_owned()),
304 };
305 let mut map = self.map.write().unwrap();
306 let removed = map.remove(&key);
307 if removed.is_some() {
308 let _ = self.size.fetch_sub(1, Ordering::Relaxed);
309 }
310 removed
311 }
312
313 fn get(&self, query: &str, types: &[Type]) -> Option<Statement> {
315 let key = StatementCacheKey {
316 query: Cow::Borrowed(query),
317 types: Cow::Borrowed(types),
318 };
319 self.map.read().unwrap().get(&key).map(ToOwned::to_owned)
320 }
321
322 fn insert(&self, query: &str, types: &[Type], stmt: Statement) {
324 let key = StatementCacheKey {
325 query: Cow::Owned(query.to_owned()),
326 types: Cow::Owned(types.to_owned()),
327 };
328 let mut map = self.map.write().unwrap();
329 if map.insert(key, stmt).is_none() {
330 let _ = self.size.fetch_add(1, Ordering::Relaxed);
331 }
332 }
333
334 pub async fn prepare(&self, client: &PgClient, query: &str) -> Result<Statement, Error> {
339 self.prepare_typed(client, query, &[]).await
340 }
341
342 pub async fn prepare_typed(
347 &self,
348 client: &PgClient,
349 query: &str,
350 types: &[Type],
351 ) -> Result<Statement, Error> {
352 match self.get(query, types) {
353 Some(statement) => Ok(statement),
354 None => {
355 let stmt = client.prepare_typed(query, types).await?;
356 self.insert(query, types, stmt.clone());
357 Ok(stmt)
358 }
359 }
360 }
361}
362
/// Wrapper around `tokio_postgres::Client` that adds a per-connection
/// `StatementCache`.
#[derive(Debug)]
pub struct ClientWrapper {
    // Underlying connection; also reachable through `Deref`/`DerefMut`.
    client: PgClient,

    /// Cache of prepared statements for this connection. It is shared (via
    /// `Arc`) with any `Transaction` started from this wrapper.
    pub statement_cache: Arc<StatementCache>,
}
372
373impl ClientWrapper {
374 #[must_use]
377 pub fn new(client: PgClient) -> Self {
378 Self {
379 client,
380 statement_cache: Arc::new(StatementCache::new()),
381 }
382 }
383
384 pub async fn prepare_cached(&self, query: &str) -> Result<Statement, Error> {
387 self.statement_cache.prepare(&self.client, query).await
388 }
389
390 pub async fn prepare_typed_cached(
393 &self,
394 query: &str,
395 types: &[Type],
396 ) -> Result<Statement, Error> {
397 self.statement_cache
398 .prepare_typed(&self.client, query, types)
399 .await
400 }
401
402 #[allow(unused_lifetimes)] pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
406 Ok(Transaction {
407 txn: PgClient::transaction(&mut self.client).await?,
408 statement_cache: self.statement_cache.clone(),
409 })
410 }
411
412 pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
415 TransactionBuilder {
416 builder: self.client.build_transaction(),
417 statement_cache: self.statement_cache.clone(),
418 }
419 }
420}
421
422impl Deref for ClientWrapper {
423 type Target = PgClient;
424
425 fn deref(&self) -> &PgClient {
426 &self.client
427 }
428}
429
430impl DerefMut for ClientWrapper {
431 fn deref_mut(&mut self) -> &mut PgClient {
432 &mut self.client
433 }
434}
435
/// Wrapper around `tokio_postgres::Transaction` that carries the statement
/// cache of the connection it was started on.
pub struct Transaction<'a> {
    // Underlying transaction; also reachable through `Deref`/`DerefMut`.
    txn: PgTransaction<'a>,

    /// Statement cache shared with the `ClientWrapper` (or parent
    /// `Transaction`) this transaction was started from.
    pub statement_cache: Arc<StatementCache>,
}
445
446impl<'a> Transaction<'a> {
447 pub async fn prepare_cached(&self, query: &str) -> Result<Statement, Error> {
450 self.statement_cache.prepare(self.client(), query).await
451 }
452
453 pub async fn prepare_typed_cached(
456 &self,
457 query: &str,
458 types: &[Type],
459 ) -> Result<Statement, Error> {
460 self.statement_cache
461 .prepare_typed(self.client(), query, types)
462 .await
463 }
464
465 pub async fn commit(self) -> Result<(), Error> {
467 self.txn.commit().await
468 }
469
470 pub async fn rollback(self) -> Result<(), Error> {
472 self.txn.rollback().await
473 }
474
475 #[allow(unused_lifetimes)] pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
479 Ok(Transaction {
480 txn: PgTransaction::transaction(&mut self.txn).await?,
481 statement_cache: self.statement_cache.clone(),
482 })
483 }
484
485 #[allow(unused_lifetimes)] pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
489 where
490 I: Into<String>,
491 {
492 Ok(Transaction {
493 txn: PgTransaction::savepoint(&mut self.txn, name).await?,
494 statement_cache: self.statement_cache.clone(),
495 })
496 }
497}
498
499impl<'a> fmt::Debug for Transaction<'a> {
500 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
501 f.debug_struct("Transaction")
502 .field("statement_cache", &self.statement_cache)
504 .finish()
505 }
506}
507
508impl<'a> Deref for Transaction<'a> {
509 type Target = PgTransaction<'a>;
510
511 fn deref(&self) -> &PgTransaction<'a> {
512 &self.txn
513 }
514}
515
516impl<'a> DerefMut for Transaction<'a> {
517 fn deref_mut(&mut self) -> &mut PgTransaction<'a> {
518 &mut self.txn
519 }
520}
521
/// Wrapper around `tokio_postgres::TransactionBuilder` that hands the
/// connection's statement cache on to the `Transaction` it starts.
#[must_use = "builder does nothing itself, use `.start()` to use it"]
pub struct TransactionBuilder<'a> {
    // Wrapped `tokio_postgres` builder; the setters forward to it.
    builder: PgTransactionBuilder<'a>,

    // Statement cache moved into the `Transaction` created by `start`.
    statement_cache: Arc<StatementCache>,
}
532
533impl<'a> TransactionBuilder<'a> {
534 pub fn isolation_level(self, isolation_level: IsolationLevel) -> Self {
538 Self {
539 builder: self.builder.isolation_level(isolation_level),
540 statement_cache: self.statement_cache,
541 }
542 }
543
544 pub fn read_only(self, read_only: bool) -> Self {
548 Self {
549 builder: self.builder.read_only(read_only),
550 statement_cache: self.statement_cache,
551 }
552 }
553
554 pub fn deferrable(self, deferrable: bool) -> Self {
563 Self {
564 builder: self.builder.deferrable(deferrable),
565 statement_cache: self.statement_cache,
566 }
567 }
568
569 pub async fn start(self) -> Result<Transaction<'a>, Error> {
576 Ok(Transaction {
577 txn: self.builder.start().await?,
578 statement_cache: self.statement_cache,
579 })
580 }
581}
582
583impl<'a> fmt::Debug for TransactionBuilder<'a> {
584 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
585 f.debug_struct("TransactionBuilder")
586 .field("statement_cache", &self.statement_cache)
588 .finish()
589 }
590}
591
592impl<'a> Deref for TransactionBuilder<'a> {
593 type Target = PgTransactionBuilder<'a>;
594
595 fn deref(&self) -> &Self::Target {
596 &self.builder
597 }
598}
599
600impl<'a> DerefMut for TransactionBuilder<'a> {
601 fn deref_mut(&mut self) -> &mut Self::Target {
602 &mut self.builder
603 }
604}