1use std::borrow::Cow;
10
11use futures_util::FutureExt;
12
13use crate::{
14 from_row,
15 prelude::{FromRow, StatementLike, ToConnection},
16 tracing_utils::LevelInfo,
17 BinaryProtocol, BoxFuture, Params, QueryResult, ResultSetStream, TextProtocol,
18};
19
/// Something that can be rendered into the raw bytes of a text-protocol query.
///
/// Every `AsQuery` value is also a [`Query`] (executed over the text protocol)
/// via the blanket implementation in this module. Implementors must be
/// `Send + Sync`.
pub trait AsQuery
where
    Self: Send + Sync,
{
    /// Returns the query bytes, borrowing from `self` when possible.
    fn as_query(&self) -> Cow<'_, [u8]>;
}
28
29impl AsQuery for &'_ [u8] {
30 fn as_query(&self) -> Cow<'_, [u8]> {
31 Cow::Borrowed(self)
32 }
33}
34
35macro_rules! impl_as_query_as_ref {
36 ($type: ty) => {
37 impl AsQuery for $type {
38 fn as_query(&self) -> Cow<'_, [u8]> {
39 Cow::Borrowed(self.as_ref())
40 }
41 }
42 };
43}
44
45impl_as_query_as_ref!(Vec<u8>);
46impl_as_query_as_ref!(&Vec<u8>);
47impl_as_query_as_ref!(Box<[u8]>);
48impl_as_query_as_ref!(Cow<'_, [u8]>);
49impl_as_query_as_ref!(std::sync::Arc<[u8]>);
50
51macro_rules! impl_as_query_as_bytes {
52 ($type: ty) => {
53 impl AsQuery for $type {
54 fn as_query(&self) -> Cow<'_, [u8]> {
55 Cow::Borrowed(self.as_bytes())
56 }
57 }
58 };
59}
60
61impl_as_query_as_bytes!(String);
62impl_as_query_as_bytes!(&String);
63impl_as_query_as_bytes!(&str);
64impl_as_query_as_bytes!(Box<str>);
65impl_as_query_as_bytes!(Cow<'_, str>);
66impl_as_query_as_bytes!(std::sync::Arc<str>);
67
68pub trait Query: Send + Sized {
93 type Protocol: crate::prelude::Protocol;
95
96 fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, Self::Protocol>>
100 where
101 Self: 'a,
102 C: ToConnection<'a, 't> + 'a;
103
104 fn first<'a, 't: 'a, T, C>(self, conn: C) -> BoxFuture<'a, Option<T>>
108 where
109 Self: 'a,
110 C: ToConnection<'a, 't> + 'a,
111 T: FromRow + Send + 'static,
112 {
113 async move {
114 let mut result = self.run(conn).await?;
115 let output = if result.is_empty() {
116 None
117 } else {
118 result.next().await?.map(from_row)
119 };
120 result.drop_result().await?;
121 Ok(output)
122 }
123 .boxed()
124 }
125
126 fn fetch<'a, 't: 'a, T, C>(self, conn: C) -> BoxFuture<'a, Vec<T>>
130 where
131 Self: 'a,
132 C: ToConnection<'a, 't> + 'a,
133 T: FromRow + Send + 'static,
134 {
135 async move { self.run(conn).await?.collect_and_drop::<T>().await }.boxed()
136 }
137
138 fn reduce<'a, 't: 'a, T, U, F, C>(self, conn: C, init: U, next: F) -> BoxFuture<'a, U>
142 where
143 Self: 'a,
144 C: ToConnection<'a, 't> + 'a,
145 F: FnMut(U, T) -> U + Send + 'a,
146 T: FromRow + Send + 'static,
147 U: Send + 'a,
148 {
149 async move { self.run(conn).await?.reduce_and_drop(init, next).await }.boxed()
150 }
151
152 fn map<'a, 't: 'a, T, U, F, C>(self, conn: C, mut map: F) -> BoxFuture<'a, Vec<U>>
156 where
157 Self: 'a,
158 C: ToConnection<'a, 't> + 'a,
159 F: FnMut(T) -> U + Send + 'a,
160 T: FromRow + Send + 'static,
161 U: Send + 'a,
162 {
163 async move {
164 self.run(conn)
165 .await?
166 .map_and_drop(|row| map(from_row(row)))
167 .await
168 }
169 .boxed()
170 }
171
172 fn stream<'a, 't: 'a, T, C>(
178 self,
179 conn: C,
180 ) -> BoxFuture<'a, ResultSetStream<'a, 'a, 't, T, Self::Protocol>>
181 where
182 Self: 'a,
183 Self::Protocol: Unpin,
184 T: Unpin + FromRow + Send + 'static,
185 C: ToConnection<'a, 't> + 'a,
186 {
187 async move {
188 self.run(conn)
189 .await?
190 .stream_and_drop()
191 .await
192 .transpose()
193 .expect("At least one result set is expected")
194 }
195 .boxed()
196 }
197
198 fn ignore<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
202 where
203 Self: 'a,
204 C: ToConnection<'a, 't> + 'a,
205 {
206 async move { self.run(conn).await?.drop_result().await }.boxed()
207 }
208}
209
210impl<Q: AsQuery> Query for Q {
211 type Protocol = TextProtocol;
212
213 fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, TextProtocol>>
214 where
215 Self: 'a,
216 C: ToConnection<'a, 't> + 'a,
217 {
218 async move {
219 let mut conn = conn.to_connection().resolve().await?;
220 conn.as_mut().raw_query::<'_, _, LevelInfo>(self).await?;
221 Ok(QueryResult::new(conn))
222 }
223 .boxed()
224 }
225}
226
/// A statement paired with the parameters to execute it with.
///
/// Returned by `WithParams::with`. As a `Query` it executes `query` once with
/// `params`; as a `BatchQuery` `params` is iterated and `query` is executed
/// once per item.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct QueryWithParams<Q, P> {
    /// The statement (or statement text) to execute.
    pub query: Q,
    /// The parameters, or for batches the collection of parameter sets.
    pub params: P,
}
235
236pub trait WithParams: Sized {
238 fn with<P>(self, params: P) -> QueryWithParams<Self, P>;
239}
240
241impl<T: StatementLike> WithParams for T {
242 fn with<P>(self, params: P) -> QueryWithParams<Self, P> {
243 QueryWithParams {
244 query: self,
245 params,
246 }
247 }
248}
249
250impl<Q, P> Query for QueryWithParams<Q, P>
251where
252 Q: StatementLike,
253 P: Into<Params> + Send,
254{
255 type Protocol = BinaryProtocol;
256
257 fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, BinaryProtocol>>
258 where
259 Self: 'a,
260 C: ToConnection<'a, 't> + 'a,
261 {
262 async move {
263 let mut conn = conn.to_connection().resolve().await?;
264
265 let statement = conn.as_mut().get_statement(self.query).await?;
266
267 conn.as_mut()
268 .execute_statement(&statement, self.params.into())
269 .await?;
270
271 Ok(QueryResult::new(conn))
272 }
273 .boxed()
274 }
275}
276
277pub trait BatchQuery {
302 fn batch<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
303 where
304 Self: 'a,
305 C: ToConnection<'a, 't> + 'a;
306}
307
308impl<Q, I, P> BatchQuery for QueryWithParams<Q, I>
309where
310 Q: StatementLike,
311 I: IntoIterator<Item = P> + Send,
312 I::IntoIter: Send,
313 P: Into<Params> + Send,
314{
315 fn batch<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
316 where
317 Self: 'a,
318 C: ToConnection<'a, 't> + 'a,
319 {
320 async move {
321 let mut conn = conn.to_connection().resolve().await?;
322
323 let statement = conn.as_mut().get_statement(self.query).await?;
324
325 for params in self.params {
326 conn.as_mut().execute_statement(&statement, params).await?;
327 }
328
329 Ok(())
330 }
331 .boxed()
332 }
333}
334
#[cfg(test)]
mod tests {
    use crate::{prelude::*, test_misc::get_opts, *};

    #[tokio::test]
    async fn should_run_text_query() -> Result<()> {
        let query_static = "SELECT 1, 2 UNION ALL SELECT 3, 4; SELECT 5, 6;";
        let query_string = String::from(query_static);

        // Exercises every `Query` helper for `$query` against `$conn`.
        // `$query` is re-evaluated for each call because the helpers consume it.
        macro_rules! test {
            ($query:expr, $conn:expr) => {{
                let mut result = $query.run($conn).await?;
                let result1: Vec<(u8, u8)> = result.collect().await?;
                let result2: Vec<(u8, u8)> = result.collect().await?;
                assert_eq!(result1, vec![(1, 2), (3, 4)]);
                assert_eq!(result2, vec![(5, 6)]);

                $query.ignore($conn).await?;

                let result: Option<(u8, u8)> = $query.first($conn).await?;
                assert_eq!(result, Some((1, 2)));

                let result: Vec<(u8, u8)> = $query.fetch($conn).await?;
                assert_eq!(result, vec![(1, 2), (3, 4)]);

                let result = $query
                    .map($conn, |row: (u8, u8)| format!("{:?}", row))
                    .await?;
                assert_eq!(result, vec![String::from("(1, 2)"), String::from("(3, 4)")]);

                let result = $query
                    .reduce($conn, 0_u8, |acc, row: (u8, u8)| acc + row.0 + row.1)
                    .await?;
                assert_eq!(result, 10);
            }};
        }

        let mut conn = Conn::new(get_opts()).await?;
        test!(query_static, &mut conn);
        test!(query_string.as_str(), &mut conn);
        // Cover the owned `String` and `&String` `AsQuery` impls as well.
        // (`&[u8]` is intentionally absent: slices' inherent `first()` would
        // shadow `Query::first` in the macro above.)
        test!(query_string.clone(), &mut conn);
        test!(&query_string, &mut conn);

        let mut tx = conn.start_transaction(Default::default()).await?;
        test!(query_static, &mut tx);
        test!(query_string.as_str(), &mut tx);
        tx.rollback().await?;

        conn.disconnect().await?;

        let pool = Pool::new(get_opts());
        test!(query_static, &pool);
        test!(query_string.as_str(), &pool);
        test!(query_string.clone(), &pool);
        test!(&query_string, &pool);

        let mut tx = pool.start_transaction(Default::default()).await?;
        test!(query_static, &mut tx);
        test!(query_string.as_str(), &mut tx);
        tx.rollback().await?;

        pool.disconnect().await?;

        Ok(())
    }

    #[tokio::test]
    async fn should_run_bin_query() -> Result<()> {
        macro_rules! query {
            (@static) => {
                "SELECT ?, ? UNION ALL SELECT ?, ?"
            };
            (@string) => {
                String::from("SELECT ?, ? UNION ALL SELECT ?, ?")
            };
            (@boxed) => {
                query!(@string).into_boxed_str()
            };
            (@arc) => {
                std::sync::Arc::<str>::from(query!(@boxed))
            };
        }

        let query_string = query!(@string);
        let params_static = ("1", "2", "3", "4");
        let params_string = (
            "1".to_owned(),
            "2".to_owned(),
            "3".to_owned(),
            "4".to_owned(),
        );

        // Exercises every `Query` helper plus `batch` for `$query` with
        // `$params` against `$conn`; `$query`/`$params` are re-evaluated per call.
        macro_rules! test {
            ($query:expr, $params:expr, $conn:expr) => {{
                let query = $query.with($params);
                let mut result = query.run($conn).await?;
                let result1: Vec<(u8, u8)> = result.collect().await?;
                assert_eq!(result1, vec![(1, 2), (3, 4)]);

                $query.with($params).ignore($conn).await?;

                let result: Option<(u8, u8)> = $query.with($params).first($conn).await?;
                assert_eq!(result, Some((1, 2)));

                let result: Vec<(u8, u8)> = $query.with($params).fetch($conn).await?;
                assert_eq!(result, vec![(1, 2), (3, 4)]);

                let result = $query
                    .with($params)
                    .map($conn, |row: (u8, u8)| format!("{:?}", row))
                    .await?;
                assert_eq!(result, vec![String::from("(1, 2)"), String::from("(3, 4)")]);

                let result = $query
                    .with($params)
                    .reduce($conn, 0_u8, |acc, row: (u8, u8)| acc + row.0 + row.1)
                    .await?;
                assert_eq!(result, 10);

                $query
                    .with(vec![$params, $params, $params, $params])
                    .batch($conn)
                    .await?;
            }};
        }

        let mut conn = Conn::new(get_opts()).await?;
        let statement = conn.prep(query!(@static)).await?;
        test!(query!(@static), params_static, &mut conn);
        test!(query!(@string), params_string.clone(), &mut conn);
        test!(query!(@boxed), params_string.clone(), &mut conn);
        test!(query!(@arc), params_string.clone(), &mut conn);
        test!(&query_string, params_string.clone(), &mut conn);
        test!(&statement, params_string.clone(), &mut conn);
        test!(statement.clone(), params_string.clone(), &mut conn);

        let mut tx = conn.start_transaction(Default::default()).await?;
        test!(query!(@static), params_string.clone(), &mut tx);
        test!(query!(@string), params_static, &mut tx);
        test!(&query_string, params_static, &mut tx);
        test!(&statement, params_string.clone(), &mut tx);
        test!(statement.clone(), params_string.clone(), &mut tx);
        tx.rollback().await?;

        conn.disconnect().await?;

        let pool = Pool::new(get_opts());
        test!(query!(@static), params_static, &pool);
        test!(query!(@string), params_string.clone(), &pool);
        test!(&query_string, params_string.clone(), &pool);

        let mut tx = pool.start_transaction(Default::default()).await?;
        test!(query!(@static), params_string.clone(), &mut tx);
        test!(query!(@string), params_static, &mut tx);
        test!(&query_string, params_static, &mut tx);
        tx.rollback().await?;

        pool.disconnect().await?;

        Ok(())
    }
}