postgres/
connection.rs

1use crate::{Error, Notification};
2use futures_util::{future, pin_mut, Stream};
3use std::collections::VecDeque;
4use std::future::Future;
5use std::ops::{Deref, DerefMut};
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio::runtime::Runtime;
11use tokio_postgres::error::DbError;
12use tokio_postgres::AsyncMessage;
13
14pub struct Connection {
15    runtime: Runtime,
16    connection: Pin<Box<dyn Stream<Item = Result<AsyncMessage, Error>> + Send>>,
17    notifications: VecDeque<Notification>,
18    notice_callback: Arc<dyn Fn(DbError) + Sync + Send>,
19}
20
21impl Connection {
22    pub fn new<S, T>(
23        runtime: Runtime,
24        connection: tokio_postgres::Connection<S, T>,
25        notice_callback: Arc<dyn Fn(DbError) + Sync + Send>,
26    ) -> Connection
27    where
28        S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
29        T: AsyncRead + AsyncWrite + Unpin + 'static + Send,
30    {
31        Connection {
32            runtime,
33            connection: Box::pin(ConnectionStream { connection }),
34            notifications: VecDeque::new(),
35            notice_callback,
36        }
37    }
38
39    pub fn as_ref(&mut self) -> ConnectionRef<'_> {
40        ConnectionRef { connection: self }
41    }
42
43    pub fn enter<F, T>(&self, f: F) -> T
44    where
45        F: FnOnce() -> T,
46    {
47        let _guard = self.runtime.enter();
48        f()
49    }
50
51    pub fn block_on<F, T>(&mut self, future: F) -> Result<T, Error>
52    where
53        F: Future<Output = Result<T, Error>>,
54    {
55        pin_mut!(future);
56        self.poll_block_on(|cx, _, _| future.as_mut().poll(cx))
57    }
58
59    pub fn poll_block_on<F, T>(&mut self, mut f: F) -> Result<T, Error>
60    where
61        F: FnMut(&mut Context<'_>, &mut VecDeque<Notification>, bool) -> Poll<Result<T, Error>>,
62    {
63        let connection = &mut self.connection;
64        let notifications = &mut self.notifications;
65        let notice_callback = &mut self.notice_callback;
66        self.runtime.block_on({
67            future::poll_fn(|cx| {
68                let done = loop {
69                    match connection.as_mut().poll_next(cx) {
70                        Poll::Ready(Some(Ok(AsyncMessage::Notification(notification)))) => {
71                            notifications.push_back(notification);
72                        }
73                        Poll::Ready(Some(Ok(AsyncMessage::Notice(notice)))) => {
74                            notice_callback(notice)
75                        }
76                        Poll::Ready(Some(Ok(_))) => {}
77                        Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
78                        Poll::Ready(None) => break true,
79                        Poll::Pending => break false,
80                    }
81                };
82
83                f(cx, notifications, done)
84            })
85        })
86    }
87
88    pub fn notifications(&self) -> &VecDeque<Notification> {
89        &self.notifications
90    }
91
92    pub fn notifications_mut(&mut self) -> &mut VecDeque<Notification> {
93        &mut self.notifications
94    }
95}
96
97pub struct ConnectionRef<'a> {
98    connection: &'a mut Connection,
99}
100
101// no-op impl to extend the borrow until drop
102impl Drop for ConnectionRef<'_> {
103    #[inline]
104    fn drop(&mut self) {}
105}
106
107impl Deref for ConnectionRef<'_> {
108    type Target = Connection;
109
110    #[inline]
111    fn deref(&self) -> &Connection {
112        self.connection
113    }
114}
115
116impl DerefMut for ConnectionRef<'_> {
117    #[inline]
118    fn deref_mut(&mut self) -> &mut Connection {
119        self.connection
120    }
121}
122
123struct ConnectionStream<S, T> {
124    connection: tokio_postgres::Connection<S, T>,
125}
126
127impl<S, T> Stream for ConnectionStream<S, T>
128where
129    S: AsyncRead + AsyncWrite + Unpin,
130    T: AsyncRead + AsyncWrite + Unpin,
131{
132    type Item = Result<AsyncMessage, Error>;
133
134    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
135        self.connection.poll_message(cx)
136    }
137}