1use crate::codec::FrontendMessage;
2use crate::connection::RequestMessages;
3use crate::copy_out::CopyOutStream;
4use crate::query::RowStream;
5#[cfg(feature = "runtime")]
6use crate::tls::MakeTlsConnect;
7use crate::tls::TlsConnect;
8use crate::types::{BorrowToSql, ToSql, Type};
9#[cfg(feature = "runtime")]
10use crate::Socket;
11use crate::{
12 bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row,
13 SimpleQueryMessage, Statement, ToStatement,
14};
15use bytes::Buf;
16use futures_util::TryStreamExt;
17use postgres_protocol::message::frontend;
18use tokio::io::{AsyncRead, AsyncWrite};
19
20pub struct Transaction<'a> {
25 client: &'a mut Client,
26 savepoint: Option<Savepoint>,
27 done: bool,
28}
29
/// Bookkeeping for a nested transaction implemented with a SQL savepoint.
struct Savepoint {
    /// Nesting level: 1 for a savepoint directly inside the outermost transaction, then +1 per level.
    depth: u32,
    /// Savepoint identifier as sent verbatim in `SAVEPOINT` / `RELEASE` / `ROLLBACK TO`.
    name: String,
}
35
36impl<'a> Drop for Transaction<'a> {
37 fn drop(&mut self) {
38 if self.done {
39 return;
40 }
41
42 let query = if let Some(sp) = self.savepoint.as_ref() {
43 format!("ROLLBACK TO {}", sp.name)
44 } else {
45 "ROLLBACK".to_string()
46 };
47 let buf = self.client.inner().with_buf(|buf| {
48 frontend::query(&query, buf).unwrap();
49 buf.split().freeze()
50 });
51 let _ = self
52 .client
53 .inner()
54 .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
55 }
56}
57
58impl<'a> Transaction<'a> {
59 pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
60 Transaction {
61 client,
62 savepoint: None,
63 done: false,
64 }
65 }
66
67 pub async fn commit(mut self) -> Result<(), Error> {
69 self.done = true;
70 let query = if let Some(sp) = self.savepoint.as_ref() {
71 format!("RELEASE {}", sp.name)
72 } else {
73 "COMMIT".to_string()
74 };
75 self.client.batch_execute(&query).await
76 }
77
78 pub async fn rollback(mut self) -> Result<(), Error> {
82 self.done = true;
83 let query = if let Some(sp) = self.savepoint.as_ref() {
84 format!("ROLLBACK TO {}", sp.name)
85 } else {
86 "ROLLBACK".to_string()
87 };
88 self.client.batch_execute(&query).await
89 }
90
91 pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
93 self.client.prepare(query).await
94 }
95
96 pub async fn prepare_typed(
98 &self,
99 query: &str,
100 parameter_types: &[Type],
101 ) -> Result<Statement, Error> {
102 self.client.prepare_typed(query, parameter_types).await
103 }
104
105 pub async fn query<T>(
107 &self,
108 statement: &T,
109 params: &[&(dyn ToSql + Sync)],
110 ) -> Result<Vec<Row>, Error>
111 where
112 T: ?Sized + ToStatement,
113 {
114 self.client.query(statement, params).await
115 }
116
117 pub async fn query_one<T>(
119 &self,
120 statement: &T,
121 params: &[&(dyn ToSql + Sync)],
122 ) -> Result<Row, Error>
123 where
124 T: ?Sized + ToStatement,
125 {
126 self.client.query_one(statement, params).await
127 }
128
129 pub async fn query_opt<T>(
131 &self,
132 statement: &T,
133 params: &[&(dyn ToSql + Sync)],
134 ) -> Result<Option<Row>, Error>
135 where
136 T: ?Sized + ToStatement,
137 {
138 self.client.query_opt(statement, params).await
139 }
140
141 pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
143 where
144 T: ?Sized + ToStatement,
145 P: BorrowToSql,
146 I: IntoIterator<Item = P>,
147 I::IntoIter: ExactSizeIterator,
148 {
149 self.client.query_raw(statement, params).await
150 }
151
152 pub async fn query_typed(
154 &self,
155 statement: &str,
156 params: &[(&(dyn ToSql + Sync), Type)],
157 ) -> Result<Vec<Row>, Error> {
158 self.client.query_typed(statement, params).await
159 }
160
161 pub async fn query_typed_raw<P, I>(&self, query: &str, params: I) -> Result<RowStream, Error>
163 where
164 P: BorrowToSql,
165 I: IntoIterator<Item = (P, Type)>,
166 {
167 self.client.query_typed_raw(query, params).await
168 }
169
170 pub async fn execute<T>(
172 &self,
173 statement: &T,
174 params: &[&(dyn ToSql + Sync)],
175 ) -> Result<u64, Error>
176 where
177 T: ?Sized + ToStatement,
178 {
179 self.client.execute(statement, params).await
180 }
181
182 pub async fn execute_raw<P, I, T>(&self, statement: &T, params: I) -> Result<u64, Error>
184 where
185 T: ?Sized + ToStatement,
186 P: BorrowToSql,
187 I: IntoIterator<Item = P>,
188 I::IntoIter: ExactSizeIterator,
189 {
190 self.client.execute_raw(statement, params).await
191 }
192
193 pub async fn bind<T>(
202 &self,
203 statement: &T,
204 params: &[&(dyn ToSql + Sync)],
205 ) -> Result<Portal, Error>
206 where
207 T: ?Sized + ToStatement,
208 {
209 self.bind_raw(statement, slice_iter(params)).await
210 }
211
212 pub async fn bind_raw<P, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
216 where
217 T: ?Sized + ToStatement,
218 P: BorrowToSql,
219 I: IntoIterator<Item = P>,
220 I::IntoIter: ExactSizeIterator,
221 {
222 let statement = statement.__convert().into_statement(self.client).await?;
223 bind::bind(self.client.inner(), statement, params).await
224 }
225
226 pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
231 self.query_portal_raw(portal, max_rows)
232 .await?
233 .try_collect()
234 .await
235 }
236
237 pub async fn query_portal_raw(
241 &self,
242 portal: &Portal,
243 max_rows: i32,
244 ) -> Result<RowStream, Error> {
245 query::query_portal(self.client.inner(), portal, max_rows).await
246 }
247
248 pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
250 where
251 T: ?Sized + ToStatement,
252 U: Buf + 'static + Send,
253 {
254 self.client.copy_in(statement).await
255 }
256
257 pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
259 where
260 T: ?Sized + ToStatement,
261 {
262 self.client.copy_out(statement).await
263 }
264
265 pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
267 self.client.simple_query(query).await
268 }
269
270 pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
272 self.client.batch_execute(query).await
273 }
274
275 pub fn cancel_token(&self) -> CancelToken {
277 self.client.cancel_token()
278 }
279
280 #[cfg(feature = "runtime")]
282 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
283 pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
284 where
285 T: MakeTlsConnect<Socket>,
286 {
287 #[allow(deprecated)]
288 self.client.cancel_query(tls).await
289 }
290
291 #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")]
293 pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
294 where
295 S: AsyncRead + AsyncWrite + Unpin,
296 T: TlsConnect<S>,
297 {
298 #[allow(deprecated)]
299 self.client.cancel_query_raw(stream, tls).await
300 }
301
302 pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
304 self._savepoint(None).await
305 }
306
307 pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
309 where
310 I: Into<String>,
311 {
312 self._savepoint(Some(name.into())).await
313 }
314
315 async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
316 let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
317 let name = name.unwrap_or_else(|| format!("sp_{}", depth));
318 let query = format!("SAVEPOINT {}", name);
319 self.batch_execute(&query).await?;
320
321 Ok(Transaction {
322 client: self.client,
323 savepoint: Some(Savepoint { name, depth }),
324 done: false,
325 })
326 }
327
328 pub fn client(&self) -> &Client {
330 self.client
331 }
332}