mysql_async/
query.rs

1// Copyright (c) 2020 Anatoly Ikorsky
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use std::borrow::Cow;
10
11use futures_util::FutureExt;
12
13use crate::{
14    connection_like::ToConnectionResult,
15    from_row,
16    prelude::{FromRow, StatementLike, ToConnection},
17    tracing_utils::LevelInfo,
18    BinaryProtocol, BoxFuture, Params, QueryResult, ResultSetStream, TextProtocol,
19};
20
/// A value whose contents can be sent to the server as the text of a query.
///
/// Implemented for the "string-ish" standard library types (`String`, `&str`, `Box<str>`,
/// `Cow<str>`, `Arc<str>`, ...) and for byte-slice-like types (`Vec<u8>`, `&[u8]`,
/// `Box<[u8]>`, `Cow<[u8]>`, `Arc<[u8]>`, ...). Raw bytes are accepted because MySQL does not
/// require query text to be valid UTF-8.
pub trait AsQuery: Send + Sync {
    /// Returns the raw bytes of the query; the returned value may borrow from `self`.
    fn as_query(&self) -> Cow<'_, [u8]>;
}
29
30impl AsQuery for &'_ [u8] {
31    fn as_query(&self) -> Cow<'_, [u8]> {
32        Cow::Borrowed(self)
33    }
34}
35
36macro_rules! impl_as_query_as_ref {
37    ($type: ty) => {
38        impl AsQuery for $type {
39            fn as_query(&self) -> Cow<'_, [u8]> {
40                Cow::Borrowed(self.as_ref())
41            }
42        }
43    };
44}
45
46impl_as_query_as_ref!(Vec<u8>);
47impl_as_query_as_ref!(&Vec<u8>);
48impl_as_query_as_ref!(Box<[u8]>);
49impl_as_query_as_ref!(Cow<'_, [u8]>);
50impl_as_query_as_ref!(std::sync::Arc<[u8]>);
51
52macro_rules! impl_as_query_as_bytes {
53    ($type: ty) => {
54        impl AsQuery for $type {
55            fn as_query(&self) -> Cow<'_, [u8]> {
56                Cow::Borrowed(self.as_bytes())
57            }
58        }
59    };
60}
61
62impl_as_query_as_bytes!(String);
63impl_as_query_as_bytes!(&String);
64impl_as_query_as_bytes!(&str);
65impl_as_query_as_bytes!(Box<str>);
66impl_as_query_as_bytes!(Cow<'_, str>);
67impl_as_query_as_bytes!(std::sync::Arc<str>);
68
69/// MySql text query.
70///
71/// This trait covers the set of `query*` methods on the `Queryable` trait.
72///
73/// Example:
74///
75/// ```rust
76/// # use mysql_async::test_misc::get_opts;
77/// # #[tokio::main]
78/// # async fn main() -> mysql_async::Result<()> {
79/// use mysql_async::*;
80/// use mysql_async::prelude::*;
81/// let pool = Pool::new(get_opts());
82///
83/// // text protocol query
84/// let num: Option<u32> = "SELECT 42".first(&pool).await?;
85/// assert_eq!(num, Some(42));
86///
87/// // binary protocol query (prepared statement)
88/// let row: Option<(u32, String)> = "SELECT ?, ?".with((42, "foo")).first(&pool).await?;
89/// assert_eq!(row.unwrap(), (42, "foo".into()));
90///
91/// # Ok(()) }
92/// ```
93pub trait Query: Send + Sized {
94    /// Query protocol.
95    type Protocol: crate::prelude::Protocol;
96
97    /// This method corresponds to [`Queryable::query_iter`][query_iter].
98    ///
99    /// [query_iter]: crate::prelude::Queryable::query_iter
100    fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, Self::Protocol>>
101    where
102        Self: 'a,
103        C: ToConnection<'a, 't> + 'a;
104
105    /// This methods corresponds to [`Queryable::query_first`][query_first].
106    ///
107    /// [query_first]: crate::prelude::Queryable::query_first
108    fn first<'a, 't: 'a, T, C>(self, conn: C) -> BoxFuture<'a, Option<T>>
109    where
110        Self: 'a,
111        C: ToConnection<'a, 't> + 'a,
112        T: FromRow + Send + 'static,
113    {
114        async move {
115            let mut result = self.run(conn).await?;
116            let output = if result.is_empty() {
117                None
118            } else {
119                result.next().await?.map(from_row)
120            };
121            result.drop_result().await?;
122            Ok(output)
123        }
124        .boxed()
125    }
126
127    /// This methods corresponds to [`Queryable::query`][query].
128    ///
129    /// [query]: crate::prelude::Queryable::query
130    fn fetch<'a, 't: 'a, T, C>(self, conn: C) -> BoxFuture<'a, Vec<T>>
131    where
132        Self: 'a,
133        C: ToConnection<'a, 't> + 'a,
134        T: FromRow + Send + 'static,
135    {
136        async move { self.run(conn).await?.collect_and_drop::<T>().await }.boxed()
137    }
138
139    /// This methods corresponds to [`Queryable::query_fold`][query_fold].
140    ///
141    /// [query_fold]: crate::prelude::Queryable::query_fold
142    fn reduce<'a, 't: 'a, T, U, F, C>(self, conn: C, init: U, next: F) -> BoxFuture<'a, U>
143    where
144        Self: 'a,
145        C: ToConnection<'a, 't> + 'a,
146        F: FnMut(U, T) -> U + Send + 'a,
147        T: FromRow + Send + 'static,
148        U: Send + 'a,
149    {
150        async move { self.run(conn).await?.reduce_and_drop(init, next).await }.boxed()
151    }
152
153    /// This methods corresponds to [`Queryable::query_map`][query_map].
154    ///
155    /// [query_map]: crate::prelude::Queryable::query_map
156    fn map<'a, 't: 'a, T, U, F, C>(self, conn: C, mut map: F) -> BoxFuture<'a, Vec<U>>
157    where
158        Self: 'a,
159        C: ToConnection<'a, 't> + 'a,
160        F: FnMut(T) -> U + Send + 'a,
161        T: FromRow + Send + 'static,
162        U: Send + 'a,
163    {
164        async move {
165            self.run(conn)
166                .await?
167                .map_and_drop(|row| map(from_row(row)))
168                .await
169        }
170        .boxed()
171    }
172
173    /// Returns a stream over the first result set.
174    ///
175    /// This method corresponds to [`QueryResult::stream_and_drop`][stream_and_drop].
176    ///
177    /// [stream_and_drop]: crate::QueryResult::stream_and_drop
178    fn stream<'a, 't: 'a, T, C>(
179        self,
180        conn: C,
181    ) -> BoxFuture<'a, ResultSetStream<'a, 'a, 't, T, Self::Protocol>>
182    where
183        Self: 'a,
184        Self::Protocol: Unpin,
185        T: Unpin + FromRow + Send + 'static,
186        C: ToConnection<'a, 't> + 'a,
187    {
188        async move {
189            self.run(conn)
190                .await?
191                .stream_and_drop()
192                .await
193                .transpose()
194                .expect("At least one result set is expected")
195        }
196        .boxed()
197    }
198
199    /// This method corresponds to [`Queryable::query_drop`][query_drop].
200    ///
201    /// [query_drop]: crate::prelude::Queryable::query_drop
202    fn ignore<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
203    where
204        Self: 'a,
205        C: ToConnection<'a, 't> + 'a,
206    {
207        async move { self.run(conn).await?.drop_result().await }.boxed()
208    }
209}
210
211impl<Q: AsQuery> Query for Q {
212    type Protocol = TextProtocol;
213
214    fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, TextProtocol>>
215    where
216        Self: 'a,
217        C: ToConnection<'a, 't> + 'a,
218    {
219        async move {
220            let mut conn = match conn.to_connection() {
221                ToConnectionResult::Immediate(conn) => conn,
222                ToConnectionResult::Mediate(fut) => fut.await?,
223            };
224            conn.raw_query::<'_, _, LevelInfo>(self).await?;
225            Ok(QueryResult::new(conn))
226        }
227        .boxed()
228    }
229}
230
/// A query paired with the parameters it should be executed with.
///
/// Values of this type are usually created with [`WithParams::with`]. When executed through
/// [`Query`], the query is run as a prepared statement over the binary protocol. When executed
/// through [`BatchQuery`], `params` is an iterable and the statement runs once per item.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryWithParams<Q, P> {
    /// The statement to execute: query text or an already prepared statement.
    pub query: Q,
    /// Parameters to bind (for [`BatchQuery`], an iterable of parameter sets).
    pub params: P,
}
239
240/// Helper, that constructs [`QueryWithParams`].
241pub trait WithParams: Sized {
242    fn with<P>(self, params: P) -> QueryWithParams<Self, P>;
243}
244
245impl<T: StatementLike> WithParams for T {
246    fn with<P>(self, params: P) -> QueryWithParams<Self, P> {
247        QueryWithParams {
248            query: self,
249            params,
250        }
251    }
252}
253
254impl<Q, P> Query for QueryWithParams<Q, P>
255where
256    Q: StatementLike,
257    P: Into<Params> + Send,
258{
259    type Protocol = BinaryProtocol;
260
261    fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, BinaryProtocol>>
262    where
263        Self: 'a,
264        C: ToConnection<'a, 't> + 'a,
265    {
266        async move {
267            let mut conn = match conn.to_connection() {
268                ToConnectionResult::Immediate(conn) => conn,
269                ToConnectionResult::Mediate(fut) => fut.await?,
270            };
271
272            let statement = conn.get_statement(self.query).await?;
273
274            conn.execute_statement(&statement, self.params.into())
275                .await?;
276
277            Ok(QueryResult::new(conn))
278        }
279        .boxed()
280    }
281}
282
283/// Helper trait for batch statement execution.
284///
285/// This trait covers the [`Queryable::exec_batch`][exec_batch] method.
286///
287/// Example:
288///
289/// ```rust
290/// # use mysql_async::test_misc::get_opts;
291/// # #[tokio::main]
292/// # async fn main() -> mysql_async::Result<()> {
293/// use mysql_async::*;
294/// use mysql_async::prelude::*;
295///
296/// let pool = Pool::new(get_opts());
297///
298/// // This will prepare `DO ?` and execute `DO 0`, `DO 1`, `DO 2` and so on.
299/// "DO ?"
300///     .with((0..10).map(|x| (x,)))
301///     .batch(&pool)
302///     .await?;
303/// # Ok(()) }
304/// ```
305///
306/// [exec_batch]: crate::prelude::Queryable::exec_batch
307pub trait BatchQuery {
308    fn batch<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
309    where
310        Self: 'a,
311        C: ToConnection<'a, 't> + 'a;
312}
313
314impl<Q, I, P> BatchQuery for QueryWithParams<Q, I>
315where
316    Q: StatementLike,
317    I: IntoIterator<Item = P> + Send,
318    I::IntoIter: Send,
319    P: Into<Params> + Send,
320{
321    fn batch<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
322    where
323        Self: 'a,
324        C: ToConnection<'a, 't> + 'a,
325    {
326        async move {
327            let mut conn = match conn.to_connection() {
328                ToConnectionResult::Immediate(conn) => conn,
329                ToConnectionResult::Mediate(fut) => fut.await?,
330            };
331
332            let statement = conn.get_statement(self.query).await?;
333
334            for params in self.params {
335                conn.execute_statement(&statement, params).await?;
336            }
337
338            Ok(())
339        }
340        .boxed()
341    }
342}
343
// Integration tests for `Query` and `BatchQuery`. They connect to a live server using
// `test_misc::get_opts()`.
#[cfg(test)]
mod tests {
    use crate::{prelude::*, test_misc::get_opts, *};

    // Runs every `Query` method with plain query text (`&str`). Targets are a `Conn`, a
    // transaction on it, a `Pool`, and a transaction on the pool.
    #[tokio::test]
    async fn should_run_text_query() -> Result<()> {
        // Two result sets: `(1, 2), (3, 4)` followed by `(5, 6)`.
        let query_static = "SELECT 1, 2 UNION ALL SELECT 3, 4; SELECT 5, 6;";
        let query_string = String::from(query_static);

        // `$query` and `$conn` are substituted, and thus re-evaluated, at every use. Each call
        // below therefore gets a fresh borrow of the connection-like value.
        macro_rules! test {
            ($query:expr, $conn:expr) => {{
                // `run` exposes every result set: the first `collect` reads the first result
                // set, the second `collect` reads the second one.
                let mut result = $query.run($conn).await?;
                let result1: Vec<(u8, u8)> = result.collect().await?;
                let result2: Vec<(u8, u8)> = result.collect().await?;
                assert_eq!(result1, vec![(1, 2), (3, 4)]);
                assert_eq!(result2, vec![(5, 6)]);

                $query.ignore($conn).await?;

                // The remaining helpers only observe the first result set; the assertions
                // below never see `(5, 6)`.
                let result: Option<(u8, u8)> = $query.first($conn).await?;
                assert_eq!(result, Some((1, 2)));

                let result: Vec<(u8, u8)> = $query.fetch($conn).await?;
                assert_eq!(result, vec![(1, 2), (3, 4)]);

                let result = $query
                    .map($conn, |row: (u8, u8)| format!("{:?}", row))
                    .await?;
                assert_eq!(result, vec![String::from("(1, 2)"), String::from("(3, 4)")]);

                // 1 + 2 + 3 + 4 == 10.
                let result = $query
                    .reduce($conn, 0_u8, |acc, row: (u8, u8)| acc + row.0 + row.1)
                    .await?;
                assert_eq!(result, 10);
            }};
        }

        let mut conn = Conn::new(get_opts()).await?;
        test!(query_static, &mut conn);
        test!(query_string.as_str(), &mut conn);

        // The transaction is rolled back before `conn` is disconnected.
        let mut tx = conn.start_transaction(Default::default()).await?;
        test!(query_static, &mut tx);
        test!(query_string.as_str(), &mut tx);
        tx.rollback().await?;

        conn.disconnect().await?;

        let pool = Pool::new(get_opts());
        test!(query_static, &pool);
        test!(query_string.as_str(), &pool);

        let mut tx = pool.start_transaction(Default::default()).await?;
        test!(query_static, &mut tx);
        test!(query_string.as_str(), &mut tx);
        tx.rollback().await?;

        pool.disconnect().await?;

        Ok(())
    }

    // Runs every `Query` method and `BatchQuery::batch` on `QueryWithParams`. The statement
    // is given in several forms: `&str`, `String`, `Box<str>`, `Arc<str>`, `&String`, a
    // borrowed prepared statement and an owned clone of it. Parameters are `&str` or
    // `String` tuples.
    #[tokio::test]
    async fn should_run_bin_query() -> Result<()> {
        // Each arm builds the same query text as a different owned/borrowed type. Every
        // invocation constructs a fresh value.
        macro_rules! query {
            (@static) => {
                "SELECT ?, ? UNION ALL SELECT ?, ?"
            };
            (@string) => {
                String::from("SELECT ?, ? UNION ALL SELECT ?, ?")
            };
            (@boxed) => {
                query!(@string).into_boxed_str()
            };
            (@arc) => {
                std::sync::Arc::<str>::from(query!(@boxed))
            };
        }

        let query_string = query!(@string);
        // String-typed parameters; the assertions below expect the resulting rows to decode
        // as `(u8, u8)`.
        let params_static = ("1", "2", "3", "4");
        let params_string = (
            "1".to_owned(),
            "2".to_owned(),
            "3".to_owned(),
            "4".to_owned(),
        );

        // Every `$query.with($params)` re-evaluates both fragments, so owned arguments such as
        // `query!(@string)` or `params_string.clone()` are rebuilt for each call.
        macro_rules! test {
            ($query:expr, $params:expr, $conn:expr) => {{
                // `query` here is a local value; it does not clash with the `query!` macro
                // because macros live in a separate namespace.
                let query = { $query.with($params) };
                let mut result = query.run($conn).await?;
                let result1: Vec<(u8, u8)> = result.collect().await?;
                assert_eq!(result1, vec![(1, 2), (3, 4)]);

                $query.with($params).ignore($conn).await?;

                let result: Option<(u8, u8)> = $query.with($params).first($conn).await?;
                assert_eq!(result, Some((1, 2)));

                let result: Vec<(u8, u8)> = $query.with($params).fetch($conn).await?;
                assert_eq!(result, vec![(1, 2), (3, 4)]);

                let result = $query
                    .with($params)
                    .map($conn, |row: (u8, u8)| format!("{:?}", row))
                    .await?;
                assert_eq!(result, vec![String::from("(1, 2)"), String::from("(3, 4)")]);

                let result = $query
                    .with($params)
                    .reduce($conn, 0_u8, |acc, row: (u8, u8)| acc + row.0 + row.1)
                    .await?;
                assert_eq!(result, 10);

                // Batch execution with four identical parameter sets; only success is checked.
                $query
                    .with(vec![$params, $params, $params, $params])
                    .batch($conn)
                    .await?;
            }};
        }

        let mut conn = Conn::new(get_opts()).await?;
        // Prepared once up front so it can be passed both by reference and as an owned clone.
        let statement = conn.prep(query!(@static)).await?;
        test!(query!(@static), params_static, &mut conn);
        test!(query!(@string), params_string.clone(), &mut conn);
        // NOTE(review): the `@boxed` and `@arc` forms are only exercised against `conn`, not
        // against a transaction or the pool.
        test!(query!(@boxed), params_string.clone(), &mut conn);
        test!(query!(@arc), params_string.clone(), &mut conn);
        test!(&query_string, params_string.clone(), &mut conn);
        test!(&statement, params_string.clone(), &mut conn);
        test!(statement.clone(), params_string.clone(), &mut conn);

        let mut tx = conn.start_transaction(Default::default()).await?;
        test!(query!(@static), params_string.clone(), &mut tx);
        test!(query!(@string), params_static, &mut tx);
        test!(&query_string, params_static, &mut tx);
        test!(&statement, params_string.clone(), &mut tx);
        test!(statement.clone(), params_string.clone(), &mut tx);
        tx.rollback().await?;

        conn.disconnect().await?;

        let pool = Pool::new(get_opts());
        test!(query!(@static), params_static, &pool);
        test!(query!(@string), params_string.clone(), &pool);
        test!(&query_string, params_string.clone(), &pool);

        let mut tx = pool.start_transaction(Default::default()).await?;
        test!(query!(@static), params_string.clone(), &mut tx);
        test!(query!(@string), params_static, &mut tx);
        test!(&query_string, params_static, &mut tx);
        tx.rollback().await?;

        pool.disconnect().await?;

        Ok(())
    }
}
499        Ok(())
500    }
501}