tokio_postgres/transaction.rs

1use crate::codec::FrontendMessage;
2use crate::connection::RequestMessages;
3use crate::copy_out::CopyOutStream;
4use crate::query::RowStream;
5#[cfg(feature = "runtime")]
6use crate::tls::MakeTlsConnect;
7use crate::tls::TlsConnect;
8use crate::types::{BorrowToSql, ToSql, Type};
9#[cfg(feature = "runtime")]
10use crate::Socket;
11use crate::{
12    bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row,
13    SimpleQueryMessage, Statement, ToStatement,
14};
15use bytes::Buf;
16use futures_util::TryStreamExt;
17use postgres_protocol::message::frontend;
18use tokio::io::{AsyncRead, AsyncWrite};
19
20/// A representation of a PostgreSQL database transaction.
21///
22/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
23/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
24pub struct Transaction<'a> {
25    client: &'a mut Client,
26    savepoint: Option<Savepoint>,
27    done: bool,
28}
29
/// Bookkeeping for a nested transaction that is implemented as a PostgreSQL savepoint.
struct Savepoint {
    /// Identifier spliced verbatim into the `SAVEPOINT`, `RELEASE` and `ROLLBACK TO` commands.
    name: String,
    /// Nesting level; the first savepoint created inside a top-level transaction has depth 1.
    depth: u32,
}
35
36impl<'a> Drop for Transaction<'a> {
37    fn drop(&mut self) {
38        if self.done {
39            return;
40        }
41
42        let query = if let Some(sp) = self.savepoint.as_ref() {
43            format!("ROLLBACK TO {}", sp.name)
44        } else {
45            "ROLLBACK".to_string()
46        };
47        let buf = self.client.inner().with_buf(|buf| {
48            frontend::query(&query, buf).unwrap();
49            buf.split().freeze()
50        });
51        let _ = self
52            .client
53            .inner()
54            .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
55    }
56}
57
58impl<'a> Transaction<'a> {
59    pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
60        Transaction {
61            client,
62            savepoint: None,
63            done: false,
64        }
65    }
66
67    /// Consumes the transaction, committing all changes made within it.
68    pub async fn commit(mut self) -> Result<(), Error> {
69        self.done = true;
70        let query = if let Some(sp) = self.savepoint.as_ref() {
71            format!("RELEASE {}", sp.name)
72        } else {
73            "COMMIT".to_string()
74        };
75        self.client.batch_execute(&query).await
76    }
77
78    /// Rolls the transaction back, discarding all changes made within it.
79    ///
80    /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
81    pub async fn rollback(mut self) -> Result<(), Error> {
82        self.done = true;
83        let query = if let Some(sp) = self.savepoint.as_ref() {
84            format!("ROLLBACK TO {}", sp.name)
85        } else {
86            "ROLLBACK".to_string()
87        };
88        self.client.batch_execute(&query).await
89    }
90
91    /// Like `Client::prepare`.
92    pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
93        self.client.prepare(query).await
94    }
95
96    /// Like `Client::prepare_typed`.
97    pub async fn prepare_typed(
98        &self,
99        query: &str,
100        parameter_types: &[Type],
101    ) -> Result<Statement, Error> {
102        self.client.prepare_typed(query, parameter_types).await
103    }
104
105    /// Like `Client::query`.
106    pub async fn query<T>(
107        &self,
108        statement: &T,
109        params: &[&(dyn ToSql + Sync)],
110    ) -> Result<Vec<Row>, Error>
111    where
112        T: ?Sized + ToStatement,
113    {
114        self.client.query(statement, params).await
115    }
116
117    /// Like `Client::query_one`.
118    pub async fn query_one<T>(
119        &self,
120        statement: &T,
121        params: &[&(dyn ToSql + Sync)],
122    ) -> Result<Row, Error>
123    where
124        T: ?Sized + ToStatement,
125    {
126        self.client.query_one(statement, params).await
127    }
128
129    /// Like `Client::query_opt`.
130    pub async fn query_opt<T>(
131        &self,
132        statement: &T,
133        params: &[&(dyn ToSql + Sync)],
134    ) -> Result<Option<Row>, Error>
135    where
136        T: ?Sized + ToStatement,
137    {
138        self.client.query_opt(statement, params).await
139    }
140
141    /// Like `Client::query_raw`.
142    pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
143    where
144        T: ?Sized + ToStatement,
145        P: BorrowToSql,
146        I: IntoIterator<Item = P>,
147        I::IntoIter: ExactSizeIterator,
148    {
149        self.client.query_raw(statement, params).await
150    }
151
152    /// Like `Client::query_typed`.
153    pub async fn query_typed(
154        &self,
155        statement: &str,
156        params: &[(&(dyn ToSql + Sync), Type)],
157    ) -> Result<Vec<Row>, Error> {
158        self.client.query_typed(statement, params).await
159    }
160
161    /// Like `Client::query_typed_raw`.
162    pub async fn query_typed_raw<P, I>(&self, query: &str, params: I) -> Result<RowStream, Error>
163    where
164        P: BorrowToSql,
165        I: IntoIterator<Item = (P, Type)>,
166    {
167        self.client.query_typed_raw(query, params).await
168    }
169
170    /// Like `Client::execute`.
171    pub async fn execute<T>(
172        &self,
173        statement: &T,
174        params: &[&(dyn ToSql + Sync)],
175    ) -> Result<u64, Error>
176    where
177        T: ?Sized + ToStatement,
178    {
179        self.client.execute(statement, params).await
180    }
181
182    /// Like `Client::execute_iter`.
183    pub async fn execute_raw<P, I, T>(&self, statement: &T, params: I) -> Result<u64, Error>
184    where
185        T: ?Sized + ToStatement,
186        P: BorrowToSql,
187        I: IntoIterator<Item = P>,
188        I::IntoIter: ExactSizeIterator,
189    {
190        self.client.execute_raw(statement, params).await
191    }
192
193    /// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried.
194    ///
195    /// Portals only last for the duration of the transaction in which they are created, and can only be used on the
196    /// connection that created them.
197    ///
198    /// # Panics
199    ///
200    /// Panics if the number of parameters provided does not match the number expected.
201    pub async fn bind<T>(
202        &self,
203        statement: &T,
204        params: &[&(dyn ToSql + Sync)],
205    ) -> Result<Portal, Error>
206    where
207        T: ?Sized + ToStatement,
208    {
209        self.bind_raw(statement, slice_iter(params)).await
210    }
211
212    /// A maximally flexible version of [`bind`].
213    ///
214    /// [`bind`]: #method.bind
215    pub async fn bind_raw<P, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
216    where
217        T: ?Sized + ToStatement,
218        P: BorrowToSql,
219        I: IntoIterator<Item = P>,
220        I::IntoIter: ExactSizeIterator,
221    {
222        let statement = statement.__convert().into_statement(self.client).await?;
223        bind::bind(self.client.inner(), statement, params).await
224    }
225
226    /// Continues execution of a portal, returning a stream of the resulting rows.
227    ///
228    /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
229    /// `query_portal`. If the requested number is negative or 0, all rows will be returned.
230    pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
231        self.query_portal_raw(portal, max_rows)
232            .await?
233            .try_collect()
234            .await
235    }
236
237    /// The maximally flexible version of [`query_portal`].
238    ///
239    /// [`query_portal`]: #method.query_portal
240    pub async fn query_portal_raw(
241        &self,
242        portal: &Portal,
243        max_rows: i32,
244    ) -> Result<RowStream, Error> {
245        query::query_portal(self.client.inner(), portal, max_rows).await
246    }
247
248    /// Like `Client::copy_in`.
249    pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
250    where
251        T: ?Sized + ToStatement,
252        U: Buf + 'static + Send,
253    {
254        self.client.copy_in(statement).await
255    }
256
257    /// Like `Client::copy_out`.
258    pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
259    where
260        T: ?Sized + ToStatement,
261    {
262        self.client.copy_out(statement).await
263    }
264
265    /// Like `Client::simple_query`.
266    pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
267        self.client.simple_query(query).await
268    }
269
270    /// Like `Client::batch_execute`.
271    pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
272        self.client.batch_execute(query).await
273    }
274
275    /// Like `Client::cancel_token`.
276    pub fn cancel_token(&self) -> CancelToken {
277        self.client.cancel_token()
278    }
279
280    /// Like `Client::cancel_query`.
281    #[cfg(feature = "runtime")]
282    #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
283    pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
284    where
285        T: MakeTlsConnect<Socket>,
286    {
287        #[allow(deprecated)]
288        self.client.cancel_query(tls).await
289    }
290
291    /// Like `Client::cancel_query_raw`.
292    #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
293    pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
294    where
295        S: AsyncRead + AsyncWrite + Unpin,
296        T: TlsConnect<S>,
297    {
298        #[allow(deprecated)]
299        self.client.cancel_query_raw(stream, tls).await
300    }
301
302    /// Like `Client::transaction`, but creates a nested transaction via a savepoint.
303    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
304        self._savepoint(None).await
305    }
306
307    /// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
308    pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
309    where
310        I: Into<String>,
311    {
312        self._savepoint(Some(name.into())).await
313    }
314
315    async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
316        let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
317        let name = name.unwrap_or_else(|| format!("sp_{}", depth));
318        let query = format!("SAVEPOINT {}", name);
319        self.batch_execute(&query).await?;
320
321        Ok(Transaction {
322            client: self.client,
323            savepoint: Some(Savepoint { name, depth }),
324            done: false,
325        })
326    }
327
328    /// Returns a reference to the underlying `Client`.
329    pub fn client(&self) -> &Client {
330        self.client
331    }
332}