1use crate::connection::ConnectionRef;
2use crate::{CancelToken, CopyInWriter, CopyOutReader, Portal, RowIter, Statement, ToStatement};
3use tokio_postgres::types::{BorrowToSql, ToSql, Type};
4use tokio_postgres::{Error, Row, SimpleQueryMessage};
5
6pub struct Transaction<'a> {
11 connection: ConnectionRef<'a>,
12 transaction: Option<tokio_postgres::Transaction<'a>>,
13}
14
15impl<'a> Drop for Transaction<'a> {
16 fn drop(&mut self) {
17 if let Some(transaction) = self.transaction.take() {
18 let _ = self.connection.block_on(transaction.rollback());
19 }
20 }
21}
22
23impl<'a> Transaction<'a> {
24 pub(crate) fn new(
25 connection: ConnectionRef<'a>,
26 transaction: tokio_postgres::Transaction<'a>,
27 ) -> Transaction<'a> {
28 Transaction {
29 connection,
30 transaction: Some(transaction),
31 }
32 }
33
34 pub fn commit(mut self) -> Result<(), Error> {
36 self.connection
37 .block_on(self.transaction.take().unwrap().commit())
38 }
39
40 pub fn rollback(mut self) -> Result<(), Error> {
44 self.connection
45 .block_on(self.transaction.take().unwrap().rollback())
46 }
47
48 pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
50 self.connection
51 .block_on(self.transaction.as_ref().unwrap().prepare(query))
52 }
53
54 pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> {
56 self.connection.block_on(
57 self.transaction
58 .as_ref()
59 .unwrap()
60 .prepare_typed(query, types),
61 )
62 }
63
64 pub fn execute<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error>
66 where
67 T: ?Sized + ToStatement,
68 {
69 self.connection
70 .block_on(self.transaction.as_ref().unwrap().execute(query, params))
71 }
72
73 pub fn query<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>, Error>
75 where
76 T: ?Sized + ToStatement,
77 {
78 self.connection
79 .block_on(self.transaction.as_ref().unwrap().query(query, params))
80 }
81
82 pub fn query_one<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Row, Error>
84 where
85 T: ?Sized + ToStatement,
86 {
87 self.connection
88 .block_on(self.transaction.as_ref().unwrap().query_one(query, params))
89 }
90
91 pub fn query_opt<T>(
93 &mut self,
94 query: &T,
95 params: &[&(dyn ToSql + Sync)],
96 ) -> Result<Option<Row>, Error>
97 where
98 T: ?Sized + ToStatement,
99 {
100 self.connection
101 .block_on(self.transaction.as_ref().unwrap().query_opt(query, params))
102 }
103
104 pub fn query_raw<T, P, I>(&mut self, query: &T, params: I) -> Result<RowIter<'_>, Error>
106 where
107 T: ?Sized + ToStatement,
108 P: BorrowToSql,
109 I: IntoIterator<Item = P>,
110 I::IntoIter: ExactSizeIterator,
111 {
112 let stream = self
113 .connection
114 .block_on(self.transaction.as_ref().unwrap().query_raw(query, params))?;
115 Ok(RowIter::new(self.connection.as_ref(), stream))
116 }
117
118 pub fn query_typed(
120 &mut self,
121 statement: &str,
122 params: &[(&(dyn ToSql + Sync), Type)],
123 ) -> Result<Vec<Row>, Error> {
124 self.connection.block_on(
125 self.transaction
126 .as_ref()
127 .unwrap()
128 .query_typed(statement, params),
129 )
130 }
131
132 pub fn query_typed_raw<P, I>(&mut self, query: &str, params: I) -> Result<RowIter<'_>, Error>
134 where
135 P: BorrowToSql,
136 I: IntoIterator<Item = (P, Type)>,
137 {
138 let stream = self.connection.block_on(
139 self.transaction
140 .as_ref()
141 .unwrap()
142 .query_typed_raw(query, params),
143 )?;
144 Ok(RowIter::new(self.connection.as_ref(), stream))
145 }
146
147 pub fn bind<T>(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Portal, Error>
158 where
159 T: ?Sized + ToStatement,
160 {
161 self.connection
162 .block_on(self.transaction.as_ref().unwrap().bind(query, params))
163 }
164
165 pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
170 self.connection.block_on(
171 self.transaction
172 .as_ref()
173 .unwrap()
174 .query_portal(portal, max_rows),
175 )
176 }
177
178 pub fn query_portal_raw(
180 &mut self,
181 portal: &Portal,
182 max_rows: i32,
183 ) -> Result<RowIter<'_>, Error> {
184 let stream = self.connection.block_on(
185 self.transaction
186 .as_ref()
187 .unwrap()
188 .query_portal_raw(portal, max_rows),
189 )?;
190 Ok(RowIter::new(self.connection.as_ref(), stream))
191 }
192
193 pub fn copy_in<T>(&mut self, query: &T) -> Result<CopyInWriter<'_>, Error>
195 where
196 T: ?Sized + ToStatement,
197 {
198 let sink = self
199 .connection
200 .block_on(self.transaction.as_ref().unwrap().copy_in(query))?;
201 Ok(CopyInWriter::new(self.connection.as_ref(), sink))
202 }
203
204 pub fn copy_out<T>(&mut self, query: &T) -> Result<CopyOutReader<'_>, Error>
206 where
207 T: ?Sized + ToStatement,
208 {
209 let stream = self
210 .connection
211 .block_on(self.transaction.as_ref().unwrap().copy_out(query))?;
212 Ok(CopyOutReader::new(self.connection.as_ref(), stream))
213 }
214
215 pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
217 self.connection
218 .block_on(self.transaction.as_ref().unwrap().simple_query(query))
219 }
220
221 pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
223 self.connection
224 .block_on(self.transaction.as_ref().unwrap().batch_execute(query))
225 }
226
227 pub fn cancel_token(&self) -> CancelToken {
229 CancelToken::new(self.transaction.as_ref().unwrap().cancel_token())
230 }
231
232 pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
234 let transaction = self
235 .connection
236 .block_on(self.transaction.as_mut().unwrap().transaction())?;
237 Ok(Transaction::new(self.connection.as_ref(), transaction))
238 }
239
240 pub fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
242 where
243 I: Into<String>,
244 {
245 let transaction = self
246 .connection
247 .block_on(self.transaction.as_mut().unwrap().savepoint(name))?;
248 Ok(Transaction::new(self.connection.as_ref(), transaction))
249 }
250}