1use std::borrow::Cow;
10
11use futures_util::FutureExt;
12
13use crate::{
14 connection_like::ToConnectionResult,
15 from_row,
16 prelude::{FromRow, StatementLike, ToConnection},
17 tracing_utils::LevelInfo,
18 BinaryProtocol, BoxFuture, Params, QueryResult, ResultSetStream, TextProtocol,
19};
20
/// Types that can be sent to the server as the text of a query.
///
/// The query is exposed as raw bytes through [`AsQuery::as_query`]; the implementations in this
/// module hand out a borrow of `self` instead of copying the text.
pub trait AsQuery
where
    Self: Send + Sync,
{
    /// Returns the query text as a byte slice (borrowed or owned).
    fn as_query(&self) -> Cow<'_, [u8]>;
}
29
30impl AsQuery for &'_ [u8] {
31 fn as_query(&self) -> Cow<'_, [u8]> {
32 Cow::Borrowed(self)
33 }
34}
35
36macro_rules! impl_as_query_as_ref {
37 ($type: ty) => {
38 impl AsQuery for $type {
39 fn as_query(&self) -> Cow<'_, [u8]> {
40 Cow::Borrowed(self.as_ref())
41 }
42 }
43 };
44}
45
46impl_as_query_as_ref!(Vec<u8>);
47impl_as_query_as_ref!(&Vec<u8>);
48impl_as_query_as_ref!(Box<[u8]>);
49impl_as_query_as_ref!(Cow<'_, [u8]>);
50impl_as_query_as_ref!(std::sync::Arc<[u8]>);
51
52macro_rules! impl_as_query_as_bytes {
53 ($type: ty) => {
54 impl AsQuery for $type {
55 fn as_query(&self) -> Cow<'_, [u8]> {
56 Cow::Borrowed(self.as_bytes())
57 }
58 }
59 };
60}
61
62impl_as_query_as_bytes!(String);
63impl_as_query_as_bytes!(&String);
64impl_as_query_as_bytes!(&str);
65impl_as_query_as_bytes!(Box<str>);
66impl_as_query_as_bytes!(Cow<'_, str>);
67impl_as_query_as_bytes!(std::sync::Arc<str>);
68
69pub trait Query: Send + Sized {
94 type Protocol: crate::prelude::Protocol;
96
97 fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, Self::Protocol>>
101 where
102 Self: 'a,
103 C: ToConnection<'a, 't> + 'a;
104
105 fn first<'a, 't: 'a, T, C>(self, conn: C) -> BoxFuture<'a, Option<T>>
109 where
110 Self: 'a,
111 C: ToConnection<'a, 't> + 'a,
112 T: FromRow + Send + 'static,
113 {
114 async move {
115 let mut result = self.run(conn).await?;
116 let output = if result.is_empty() {
117 None
118 } else {
119 result.next().await?.map(from_row)
120 };
121 result.drop_result().await?;
122 Ok(output)
123 }
124 .boxed()
125 }
126
127 fn fetch<'a, 't: 'a, T, C>(self, conn: C) -> BoxFuture<'a, Vec<T>>
131 where
132 Self: 'a,
133 C: ToConnection<'a, 't> + 'a,
134 T: FromRow + Send + 'static,
135 {
136 async move { self.run(conn).await?.collect_and_drop::<T>().await }.boxed()
137 }
138
139 fn reduce<'a, 't: 'a, T, U, F, C>(self, conn: C, init: U, next: F) -> BoxFuture<'a, U>
143 where
144 Self: 'a,
145 C: ToConnection<'a, 't> + 'a,
146 F: FnMut(U, T) -> U + Send + 'a,
147 T: FromRow + Send + 'static,
148 U: Send + 'a,
149 {
150 async move { self.run(conn).await?.reduce_and_drop(init, next).await }.boxed()
151 }
152
153 fn map<'a, 't: 'a, T, U, F, C>(self, conn: C, mut map: F) -> BoxFuture<'a, Vec<U>>
157 where
158 Self: 'a,
159 C: ToConnection<'a, 't> + 'a,
160 F: FnMut(T) -> U + Send + 'a,
161 T: FromRow + Send + 'static,
162 U: Send + 'a,
163 {
164 async move {
165 self.run(conn)
166 .await?
167 .map_and_drop(|row| map(from_row(row)))
168 .await
169 }
170 .boxed()
171 }
172
173 fn stream<'a, 't: 'a, T, C>(
179 self,
180 conn: C,
181 ) -> BoxFuture<'a, ResultSetStream<'a, 'a, 't, T, Self::Protocol>>
182 where
183 Self: 'a,
184 Self::Protocol: Unpin,
185 T: Unpin + FromRow + Send + 'static,
186 C: ToConnection<'a, 't> + 'a,
187 {
188 async move {
189 self.run(conn)
190 .await?
191 .stream_and_drop()
192 .await
193 .transpose()
194 .expect("At least one result set is expected")
195 }
196 .boxed()
197 }
198
199 fn ignore<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
203 where
204 Self: 'a,
205 C: ToConnection<'a, 't> + 'a,
206 {
207 async move { self.run(conn).await?.drop_result().await }.boxed()
208 }
209}
210
211impl<Q: AsQuery> Query for Q {
212 type Protocol = TextProtocol;
213
214 fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, TextProtocol>>
215 where
216 Self: 'a,
217 C: ToConnection<'a, 't> + 'a,
218 {
219 async move {
220 let mut conn = match conn.to_connection() {
221 ToConnectionResult::Immediate(conn) => conn,
222 ToConnectionResult::Mediate(fut) => fut.await?,
223 };
224 conn.raw_query::<'_, _, LevelInfo>(self).await?;
225 Ok(QueryResult::new(conn))
226 }
227 .boxed()
228 }
229}
230
/// A query (statement text or a prepared statement) paired with the parameters to bind to it.
///
/// Usually built with [`WithParams::with`]. When run as a [`Query`], `params` is a single
/// parameter set. When run as a [`BatchQuery`], it is a collection of parameter sets.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct QueryWithParams<Q, P> {
    /// The statement-like value to execute.
    pub query: Q,
    /// The parameter set (or sets) to bind.
    pub params: P,
}
239
240pub trait WithParams: Sized {
242 fn with<P>(self, params: P) -> QueryWithParams<Self, P>;
243}
244
245impl<T: StatementLike> WithParams for T {
246 fn with<P>(self, params: P) -> QueryWithParams<Self, P> {
247 QueryWithParams {
248 query: self,
249 params,
250 }
251 }
252}
253
254impl<Q, P> Query for QueryWithParams<Q, P>
255where
256 Q: StatementLike,
257 P: Into<Params> + Send,
258{
259 type Protocol = BinaryProtocol;
260
261 fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, BinaryProtocol>>
262 where
263 Self: 'a,
264 C: ToConnection<'a, 't> + 'a,
265 {
266 async move {
267 let mut conn = match conn.to_connection() {
268 ToConnectionResult::Immediate(conn) => conn,
269 ToConnectionResult::Mediate(fut) => fut.await?,
270 };
271
272 let statement = conn.get_statement(self.query).await?;
273
274 conn.execute_statement(&statement, self.params.into())
275 .await?;
276
277 Ok(QueryResult::new(conn))
278 }
279 .boxed()
280 }
281}
282
/// Executes a statement once for each parameter set in a collection.
///
/// The implementation for [`QueryWithParams`] in this module obtains the statement once via
/// `get_statement`. It then executes it for every item of the parameter collection, in
/// iteration order, and stops at the first error.
pub trait BatchQuery {
    /// Runs the batch on a connection obtained from `conn`. The individual results are not
    /// returned; the future resolves to `()` on success.
    fn batch<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
    where
        Self: 'a,
        C: ToConnection<'a, 't> + 'a;
}
313
314impl<Q, I, P> BatchQuery for QueryWithParams<Q, I>
315where
316 Q: StatementLike,
317 I: IntoIterator<Item = P> + Send,
318 I::IntoIter: Send,
319 P: Into<Params> + Send,
320{
321 fn batch<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
322 where
323 Self: 'a,
324 C: ToConnection<'a, 't> + 'a,
325 {
326 async move {
327 let mut conn = match conn.to_connection() {
328 ToConnectionResult::Immediate(conn) => conn,
329 ToConnectionResult::Mediate(fut) => fut.await?,
330 };
331
332 let statement = conn.get_statement(self.query).await?;
333
334 for params in self.params {
335 conn.execute_statement(&statement, params).await?;
336 }
337
338 Ok(())
339 }
340 .boxed()
341 }
342}
343
// Integration tests for the `Query`/`BatchQuery` helpers. They open real connections via
// `get_opts()`, so they presumably need a reachable MySQL server configured for the test
// environment — TODO confirm.
#[cfg(test)]
mod tests {
    use crate::{prelude::*, test_misc::get_opts, *};

    /// Exercises every `Query` helper for text-protocol queries (`&str`), using a plain
    /// connection, a transaction on it, a pool, and a transaction started from the pool.
    #[tokio::test]
    async fn should_run_text_query() -> Result<()> {
        // Two statements, so two result sets: rows (1, 2), (3, 4), then row (5, 6).
        let query_static = "SELECT 1, 2 UNION ALL SELECT 3, 4; SELECT 5, 6;";
        let query_string = String::from(query_static);

        // Runs the same query through `run`, `ignore`, `first`, `fetch`, `map` and `reduce`.
        // `$conn` is re-evaluated on each use, so every helper gets a fresh borrow.
        macro_rules! test {
            ($query:expr, $conn:expr) => {{
                // `run` exposes both result sets; each `collect` consumes one of them.
                let mut result = $query.run($conn).await?;
                let result1: Vec<(u8, u8)> = result.collect().await?;
                let result2: Vec<(u8, u8)> = result.collect().await?;
                assert_eq!(result1, vec![(1, 2), (3, 4)]);
                assert_eq!(result2, vec![(5, 6)]);

                $query.ignore($conn).await?;

                // The remaining helpers only surface rows from the first result set.
                let result: Option<(u8, u8)> = $query.first($conn).await?;
                assert_eq!(result, Some((1, 2)));

                let result: Vec<(u8, u8)> = $query.fetch($conn).await?;
                assert_eq!(result, vec![(1, 2), (3, 4)]);

                let result = $query
                    .map($conn, |row: (u8, u8)| format!("{:?}", row))
                    .await?;
                assert_eq!(result, vec![String::from("(1, 2)"), String::from("(3, 4)")]);

                // 1 + 2 + 3 + 4 == 10.
                let result = $query
                    .reduce($conn, 0_u8, |acc, row: (u8, u8)| acc + row.0 + row.1)
                    .await?;
                assert_eq!(result, 10);
            }};
        }

        let mut conn = Conn::new(get_opts()).await?;
        test!(query_static, &mut conn);
        test!(query_string.as_str(), &mut conn);

        let mut tx = conn.start_transaction(Default::default()).await?;
        test!(query_static, &mut tx);
        test!(query_string.as_str(), &mut tx);
        tx.rollback().await?;

        conn.disconnect().await?;

        let pool = Pool::new(get_opts());
        test!(query_static, &pool);
        test!(query_string.as_str(), &pool);

        let mut tx = pool.start_transaction(Default::default()).await?;
        test!(query_static, &mut tx);
        test!(query_string.as_str(), &mut tx);
        tx.rollback().await?;

        pool.disconnect().await?;

        Ok(())
    }

    /// Exercises every `Query` helper plus `batch` for binary-protocol (parameterized)
    /// queries. It covers several statement-like query types, both `&str` and `String`
    /// parameter tuples, and several connection kinds.
    #[tokio::test]
    async fn should_run_bin_query() -> Result<()> {
        // Produces the same SQL text as `&str`, `String`, `Box<str>` or `Arc<str>`.
        macro_rules! query {
            (@static) => {
                "SELECT ?, ? UNION ALL SELECT ?, ?"
            };
            (@string) => {
                String::from("SELECT ?, ? UNION ALL SELECT ?, ?")
            };
            (@boxed) => {
                query!(@string).into_boxed_str()
            };
            (@arc) => {
                std::sync::Arc::<str>::from(query!(@boxed))
            };
        }

        let query_string = query!(@string);
        let params_static = ("1", "2", "3", "4");
        let params_string = (
            "1".to_owned(),
            "2".to_owned(),
            "3".to_owned(),
            "4".to_owned(),
        );

        // `$query` and `$params` are re-evaluated for every helper call, so non-`Copy`
        // arguments must be passed as fresh expressions (e.g. `params_string.clone()`).
        macro_rules! test {
            ($query:expr, $params:expr, $conn:expr) => {{
                let query = { $query.with($params) };
                let mut result = query.run($conn).await?;
                let result1: Vec<(u8, u8)> = result.collect().await?;
                assert_eq!(result1, vec![(1, 2), (3, 4)]);

                $query.with($params).ignore($conn).await?;

                let result: Option<(u8, u8)> = $query.with($params).first($conn).await?;
                assert_eq!(result, Some((1, 2)));

                let result: Vec<(u8, u8)> = $query.with($params).fetch($conn).await?;
                assert_eq!(result, vec![(1, 2), (3, 4)]);

                let result = $query
                    .with($params)
                    .map($conn, |row: (u8, u8)| format!("{:?}", row))
                    .await?;
                assert_eq!(result, vec![String::from("(1, 2)"), String::from("(3, 4)")]);

                let result = $query
                    .with($params)
                    .reduce($conn, 0_u8, |acc, row: (u8, u8)| acc + row.0 + row.1)
                    .await?;
                assert_eq!(result, 10);

                // Executes the statement once per parameter tuple (four times); the results
                // are not inspected, only that no error is returned.
                $query
                    .with(vec![$params, $params, $params, $params])
                    .batch($conn)
                    .await?;
            }};
        }

        let mut conn = Conn::new(get_opts()).await?;
        // An explicitly prepared statement, used both by reference and by value below.
        let statement = conn.prep(query!(@static)).await?;
        test!(query!(@static), params_static, &mut conn);
        test!(query!(@string), params_string.clone(), &mut conn);
        test!(query!(@boxed), params_string.clone(), &mut conn);
        test!(query!(@arc), params_string.clone(), &mut conn);
        test!(&query_string, params_string.clone(), &mut conn);
        test!(&statement, params_string.clone(), &mut conn);
        test!(statement.clone(), params_string.clone(), &mut conn);

        let mut tx = conn.start_transaction(Default::default()).await?;
        test!(query!(@static), params_string.clone(), &mut tx);
        test!(query!(@string), params_static, &mut tx);
        test!(&query_string, params_static, &mut tx);
        test!(&statement, params_string.clone(), &mut tx);
        test!(statement.clone(), params_string.clone(), &mut tx);
        tx.rollback().await?;

        conn.disconnect().await?;

        let pool = Pool::new(get_opts());
        test!(query!(@static), params_static, &pool);
        test!(query!(@string), params_string.clone(), &pool);
        test!(&query_string, params_string.clone(), &pool);

        let mut tx = pool.start_transaction(Default::default()).await?;
        test!(query!(@static), params_string.clone(), &mut tx);
        test!(query!(@string), params_static, &mut tx);
        test!(&query_string, params_static, &mut tx);
        tx.rollback().await?;

        pool.disconnect().await?;

        Ok(())
    }
}