postgres/
transaction.rs

1use crate::connection::ConnectionRef;
2use crate::{CancelToken, CopyInWriter, CopyOutReader, Portal, RowIter, Statement, ToStatement};
3use tokio_postgres::types::{BorrowToSql, ToSql, Type};
4use tokio_postgres::{Error, Row, SimpleQueryMessage};
5
6/// A representation of a PostgreSQL database transaction.
7///
8/// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made
9/// in the transaction. Transactions can be nested, with inner transactions implemented via savepoints.
10pub struct Transaction<'a> {
11    connection: ConnectionRef<'a>,
12    transaction: Option<tokio_postgres::Transaction<'a>>,
13}
14
15impl<'a> Drop for Transaction<'a> {
16    fn drop(&mut self) {
17        if let Some(transaction) = self.transaction.take() {
18            let _ = self.connection.block_on(transaction.rollback());
19        }
20    }
21}
22
23impl<'a> Transaction<'a> {
24    pub(crate) fn new(
25        connection: ConnectionRef<'a>,
26        transaction: tokio_postgres::Transaction<'a>,
27    ) -> Transaction<'a> {
28        Transaction {
29            connection,
30            transaction: Some(transaction),
31        }
32    }
33
34    /// Consumes the transaction, committing all changes made within it.
35    pub fn commit(mut self) -> Result<(), Error> {
36        self.connection
37            .block_on(self.transaction.take().unwrap().commit())
38    }
39
40    /// Rolls the transaction back, discarding all changes made within it.
41    ///
42    /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
43    pub fn rollback(mut self) -> Result<(), Error> {
44        self.connection
45            .block_on(self.transaction.take().unwrap().rollback())
46    }
47
48    /// Like `Client::prepare`.
49    pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
50        self.connection
51            .block_on(self.transaction.as_ref().unwrap().prepare(query))
52    }
53
54    /// Like `Client::prepare_typed`.
55    pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> {
56        self.connection.block_on(
57            self.transaction
58                .as_ref()
59                .unwrap()
60                .prepare_typed(query, types),
61        )
62    }
63
64    /// Like `Client::execute`.
65    pub fn execute<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error>
66    where
67        T: ?Sized + ToStatement,
68    {
69        self.connection
70            .block_on(self.transaction.as_ref().unwrap().execute(query, params))
71    }
72
73    /// Like `Client::query`.
74    pub fn query<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>, Error>
75    where
76        T: ?Sized + ToStatement,
77    {
78        self.connection
79            .block_on(self.transaction.as_ref().unwrap().query(query, params))
80    }
81
82    /// Like `Client::query_one`.
83    pub fn query_one<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Row, Error>
84    where
85        T: ?Sized + ToStatement,
86    {
87        self.connection
88            .block_on(self.transaction.as_ref().unwrap().query_one(query, params))
89    }
90
91    /// Like `Client::query_opt`.
92    pub fn query_opt<T>(
93        &mut self,
94        query: &T,
95        params: &[&(dyn ToSql + Sync)],
96    ) -> Result<Option<Row>, Error>
97    where
98        T: ?Sized + ToStatement,
99    {
100        self.connection
101            .block_on(self.transaction.as_ref().unwrap().query_opt(query, params))
102    }
103
104    /// Like `Client::query_raw`.
105    pub fn query_raw<T, P, I>(&mut self, query: &T, params: I) -> Result<RowIter<'_>, Error>
106    where
107        T: ?Sized + ToStatement,
108        P: BorrowToSql,
109        I: IntoIterator<Item = P>,
110        I::IntoIter: ExactSizeIterator,
111    {
112        let stream = self
113            .connection
114            .block_on(self.transaction.as_ref().unwrap().query_raw(query, params))?;
115        Ok(RowIter::new(self.connection.as_ref(), stream))
116    }
117
118    /// Like `Client::query_typed`.
119    pub fn query_typed(
120        &mut self,
121        statement: &str,
122        params: &[(&(dyn ToSql + Sync), Type)],
123    ) -> Result<Vec<Row>, Error> {
124        self.connection.block_on(
125            self.transaction
126                .as_ref()
127                .unwrap()
128                .query_typed(statement, params),
129        )
130    }
131
132    /// Like `Client::query_typed_raw`.
133    pub fn query_typed_raw<P, I>(&mut self, query: &str, params: I) -> Result<RowIter<'_>, Error>
134    where
135        P: BorrowToSql,
136        I: IntoIterator<Item = (P, Type)>,
137    {
138        let stream = self.connection.block_on(
139            self.transaction
140                .as_ref()
141                .unwrap()
142                .query_typed_raw(query, params),
143        )?;
144        Ok(RowIter::new(self.connection.as_ref(), stream))
145    }
146
147    /// Binds parameters to a statement, creating a "portal".
148    ///
149    /// Portals can be used with the `query_portal` method to page through the results of a query without being forced
150    /// to consume them all immediately.
151    ///
152    /// Portals are automatically closed when the transaction they were created in is closed.
153    ///
154    /// # Panics
155    ///
156    /// Panics if the number of parameters provided does not match the number expected.
157    pub fn bind<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Portal, Error>
158    where
159        T: ?Sized + ToStatement,
160    {
161        self.connection
162            .block_on(self.transaction.as_ref().unwrap().bind(query, params))
163    }
164
165    /// Continues execution of a portal, returning the next set of rows.
166    ///
167    /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
168    /// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned.
169    pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
170        self.connection.block_on(
171            self.transaction
172                .as_ref()
173                .unwrap()
174                .query_portal(portal, max_rows),
175        )
176    }
177
178    /// The maximally flexible version of `query_portal`.
179    pub fn query_portal_raw(
180        &mut self,
181        portal: &Portal,
182        max_rows: i32,
183    ) -> Result<RowIter<'_>, Error> {
184        let stream = self.connection.block_on(
185            self.transaction
186                .as_ref()
187                .unwrap()
188                .query_portal_raw(portal, max_rows),
189        )?;
190        Ok(RowIter::new(self.connection.as_ref(), stream))
191    }
192
193    /// Like `Client::copy_in`.
194    pub fn copy_in<T>(&mut self, query: &T) -> Result<CopyInWriter<'_>, Error>
195    where
196        T: ?Sized + ToStatement,
197    {
198        let sink = self
199            .connection
200            .block_on(self.transaction.as_ref().unwrap().copy_in(query))?;
201        Ok(CopyInWriter::new(self.connection.as_ref(), sink))
202    }
203
204    /// Like `Client::copy_out`.
205    pub fn copy_out<T>(&mut self, query: &T) -> Result<CopyOutReader<'_>, Error>
206    where
207        T: ?Sized + ToStatement,
208    {
209        let stream = self
210            .connection
211            .block_on(self.transaction.as_ref().unwrap().copy_out(query))?;
212        Ok(CopyOutReader::new(self.connection.as_ref(), stream))
213    }
214
215    /// Like `Client::simple_query`.
216    pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
217        self.connection
218            .block_on(self.transaction.as_ref().unwrap().simple_query(query))
219    }
220
221    /// Like `Client::batch_execute`.
222    pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
223        self.connection
224            .block_on(self.transaction.as_ref().unwrap().batch_execute(query))
225    }
226
227    /// Like `Client::cancel_token`.
228    pub fn cancel_token(&self) -> CancelToken {
229        CancelToken::new(self.transaction.as_ref().unwrap().cancel_token())
230    }
231
232    /// Like `Client::transaction`, but creates a nested transaction via a savepoint.
233    pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
234        let transaction = self
235            .connection
236            .block_on(self.transaction.as_mut().unwrap().transaction())?;
237        Ok(Transaction::new(self.connection.as_ref(), transaction))
238    }
239
240    /// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
241    pub fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
242    where
243        I: Into<String>,
244    {
245        let transaction = self
246            .connection
247            .block_on(self.transaction.as_mut().unwrap().savepoint(name))?;
248        Ok(Transaction::new(self.connection.as_ref(), transaction))
249    }
250}