mysql_async/
query.rs

1// Copyright (c) 2020 Anatoly Ikorsky
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use std::borrow::Cow;
10
11use futures_util::FutureExt;
12
13use crate::{
14    from_row,
15    prelude::{FromRow, StatementLike, ToConnection},
16    tracing_utils::LevelInfo,
17    BinaryProtocol, BoxFuture, Params, QueryResult, ResultSetStream, TextProtocol,
18};
19
/// Types that can be treated as a MySQL query.
///
/// This trait is implemented for the common "string-ish" standard library types (`String`,
/// `&String`, `&str`, `Box<str>`, `Cow<str>` and `Arc<str>`). It is also implemented for the
/// common byte containers (`&[u8]`, `Vec<u8>`, `&Vec<u8>`, `Box<[u8]>`, `Cow<[u8]>` and
/// `Arc<[u8]>`), since MySQL does not require queries to be valid UTF-8.
///
/// Every type implementing this trait can be run as a text-protocol query through [`Query`].
pub trait AsQuery: Send + Sync {
    /// Returns the raw bytes of the query text.
    ///
    /// The implementations in this module borrow from `self` and never allocate.
    fn as_query(&self) -> Cow<'_, [u8]>;
}
28
29impl AsQuery for &'_ [u8] {
30    fn as_query(&self) -> Cow<'_, [u8]> {
31        Cow::Borrowed(self)
32    }
33}
34
35macro_rules! impl_as_query_as_ref {
36    ($type: ty) => {
37        impl AsQuery for $type {
38            fn as_query(&self) -> Cow<'_, [u8]> {
39                Cow::Borrowed(self.as_ref())
40            }
41        }
42    };
43}
44
45impl_as_query_as_ref!(Vec<u8>);
46impl_as_query_as_ref!(&Vec<u8>);
47impl_as_query_as_ref!(Box<[u8]>);
48impl_as_query_as_ref!(Cow<'_, [u8]>);
49impl_as_query_as_ref!(std::sync::Arc<[u8]>);
50
51macro_rules! impl_as_query_as_bytes {
52    ($type: ty) => {
53        impl AsQuery for $type {
54            fn as_query(&self) -> Cow<'_, [u8]> {
55                Cow::Borrowed(self.as_bytes())
56            }
57        }
58    };
59}
60
61impl_as_query_as_bytes!(String);
62impl_as_query_as_bytes!(&String);
63impl_as_query_as_bytes!(&str);
64impl_as_query_as_bytes!(Box<str>);
65impl_as_query_as_bytes!(Cow<'_, str>);
66impl_as_query_as_bytes!(std::sync::Arc<str>);
67
68/// MySql text query.
69///
70/// This trait covers the set of `query*` methods on the `Queryable` trait.
71///
72/// Example:
73///
74/// ```rust
75/// # use mysql_async::test_misc::get_opts;
76/// # #[tokio::main]
77/// # async fn main() -> mysql_async::Result<()> {
78/// use mysql_async::*;
79/// use mysql_async::prelude::*;
80/// let pool = Pool::new(get_opts());
81///
82/// // text protocol query
83/// let num: Option<u32> = "SELECT 42".first(&pool).await?;
84/// assert_eq!(num, Some(42));
85///
86/// // binary protocol query (prepared statement)
87/// let row: Option<(u32, String)> = "SELECT ?, ?".with((42, "foo")).first(&pool).await?;
88/// assert_eq!(row.unwrap(), (42, "foo".into()));
89///
90/// # Ok(()) }
91/// ```
92pub trait Query: Send + Sized {
93    /// Query protocol.
94    type Protocol: crate::prelude::Protocol;
95
96    /// This method corresponds to [`Queryable::query_iter`][query_iter].
97    ///
98    /// [query_iter]: crate::prelude::Queryable::query_iter
99    fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, Self::Protocol>>
100    where
101        Self: 'a,
102        C: ToConnection<'a, 't> + 'a;
103
104    /// This methods corresponds to [`Queryable::query_first`][query_first].
105    ///
106    /// [query_first]: crate::prelude::Queryable::query_first
107    fn first<'a, 't: 'a, T, C>(self, conn: C) -> BoxFuture<'a, Option<T>>
108    where
109        Self: 'a,
110        C: ToConnection<'a, 't> + 'a,
111        T: FromRow + Send + 'static,
112    {
113        async move {
114            let mut result = self.run(conn).await?;
115            let output = if result.is_empty() {
116                None
117            } else {
118                result.next().await?.map(from_row)
119            };
120            result.drop_result().await?;
121            Ok(output)
122        }
123        .boxed()
124    }
125
126    /// This methods corresponds to [`Queryable::query`][query].
127    ///
128    /// [query]: crate::prelude::Queryable::query
129    fn fetch<'a, 't: 'a, T, C>(self, conn: C) -> BoxFuture<'a, Vec<T>>
130    where
131        Self: 'a,
132        C: ToConnection<'a, 't> + 'a,
133        T: FromRow + Send + 'static,
134    {
135        async move { self.run(conn).await?.collect_and_drop::<T>().await }.boxed()
136    }
137
138    /// This methods corresponds to [`Queryable::query_fold`][query_fold].
139    ///
140    /// [query_fold]: crate::prelude::Queryable::query_fold
141    fn reduce<'a, 't: 'a, T, U, F, C>(self, conn: C, init: U, next: F) -> BoxFuture<'a, U>
142    where
143        Self: 'a,
144        C: ToConnection<'a, 't> + 'a,
145        F: FnMut(U, T) -> U + Send + 'a,
146        T: FromRow + Send + 'static,
147        U: Send + 'a,
148    {
149        async move { self.run(conn).await?.reduce_and_drop(init, next).await }.boxed()
150    }
151
152    /// This methods corresponds to [`Queryable::query_map`][query_map].
153    ///
154    /// [query_map]: crate::prelude::Queryable::query_map
155    fn map<'a, 't: 'a, T, U, F, C>(self, conn: C, mut map: F) -> BoxFuture<'a, Vec<U>>
156    where
157        Self: 'a,
158        C: ToConnection<'a, 't> + 'a,
159        F: FnMut(T) -> U + Send + 'a,
160        T: FromRow + Send + 'static,
161        U: Send + 'a,
162    {
163        async move {
164            self.run(conn)
165                .await?
166                .map_and_drop(|row| map(from_row(row)))
167                .await
168        }
169        .boxed()
170    }
171
172    /// Returns a stream over the first result set.
173    ///
174    /// This method corresponds to [`QueryResult::stream_and_drop`][stream_and_drop].
175    ///
176    /// [stream_and_drop]: crate::QueryResult::stream_and_drop
177    fn stream<'a, 't: 'a, T, C>(
178        self,
179        conn: C,
180    ) -> BoxFuture<'a, ResultSetStream<'a, 'a, 't, T, Self::Protocol>>
181    where
182        Self: 'a,
183        Self::Protocol: Unpin,
184        T: Unpin + FromRow + Send + 'static,
185        C: ToConnection<'a, 't> + 'a,
186    {
187        async move {
188            self.run(conn)
189                .await?
190                .stream_and_drop()
191                .await
192                .transpose()
193                .expect("At least one result set is expected")
194        }
195        .boxed()
196    }
197
198    /// This method corresponds to [`Queryable::query_drop`][query_drop].
199    ///
200    /// [query_drop]: crate::prelude::Queryable::query_drop
201    fn ignore<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
202    where
203        Self: 'a,
204        C: ToConnection<'a, 't> + 'a,
205    {
206        async move { self.run(conn).await?.drop_result().await }.boxed()
207    }
208}
209
210impl<Q: AsQuery> Query for Q {
211    type Protocol = TextProtocol;
212
213    fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, TextProtocol>>
214    where
215        Self: 'a,
216        C: ToConnection<'a, 't> + 'a,
217    {
218        async move {
219            let mut conn = conn.to_connection().resolve().await?;
220            conn.as_mut().raw_query::<'_, _, LevelInfo>(self).await?;
221            Ok(QueryResult::new(conn))
222        }
223        .boxed()
224    }
225}
226
/// Representation of a prepared statement query: a statement paired with its parameters.
///
/// Values of this type are usually built with [`WithParams::with`]. What it can be run as
/// depends on `P`:
/// - When `Q` is [`StatementLike`] and `P: Into<Params>`, it implements [`Query`] and runs over
///   the binary protocol.
/// - When `P` is an `IntoIterator` of parameter sets, it implements [`BatchQuery`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryWithParams<Q, P> {
    /// The statement to execute (e.g. query text or a prepared statement).
    pub query: Q,
    /// Parameters (or, for batch execution, parameter sets) bound to the statement.
    pub params: P,
}
235
236/// Helper, that constructs [`QueryWithParams`].
237pub trait WithParams: Sized {
238    fn with<P>(self, params: P) -> QueryWithParams<Self, P>;
239}
240
241impl<T: StatementLike> WithParams for T {
242    fn with<P>(self, params: P) -> QueryWithParams<Self, P> {
243        QueryWithParams {
244            query: self,
245            params,
246        }
247    }
248}
249
250impl<Q, P> Query for QueryWithParams<Q, P>
251where
252    Q: StatementLike,
253    P: Into<Params> + Send,
254{
255    type Protocol = BinaryProtocol;
256
257    fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, BinaryProtocol>>
258    where
259        Self: 'a,
260        C: ToConnection<'a, 't> + 'a,
261    {
262        async move {
263            let mut conn = conn.to_connection().resolve().await?;
264
265            let statement = conn.as_mut().get_statement(self.query).await?;
266
267            conn.as_mut()
268                .execute_statement(&statement, self.params.into())
269                .await?;
270
271            Ok(QueryResult::new(conn))
272        }
273        .boxed()
274    }
275}
276
277/// Helper trait for batch statement execution.
278///
279/// This trait covers the [`Queryable::exec_batch`][exec_batch] method.
280///
281/// Example:
282///
283/// ```rust
284/// # use mysql_async::test_misc::get_opts;
285/// # #[tokio::main]
286/// # async fn main() -> mysql_async::Result<()> {
287/// use mysql_async::*;
288/// use mysql_async::prelude::*;
289///
290/// let pool = Pool::new(get_opts());
291///
292/// // This will prepare `DO ?` and execute `DO 0`, `DO 1`, `DO 2` and so on.
293/// "DO ?"
294///     .with((0..10).map(|x| (x,)))
295///     .batch(&pool)
296///     .await?;
297/// # Ok(()) }
298/// ```
299///
300/// [exec_batch]: crate::prelude::Queryable::exec_batch
301pub trait BatchQuery {
302    fn batch<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
303    where
304        Self: 'a,
305        C: ToConnection<'a, 't> + 'a;
306}
307
308impl<Q, I, P> BatchQuery for QueryWithParams<Q, I>
309where
310    Q: StatementLike,
311    I: IntoIterator<Item = P> + Send,
312    I::IntoIter: Send,
313    P: Into<Params> + Send,
314{
315    fn batch<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, ()>
316    where
317        Self: 'a,
318        C: ToConnection<'a, 't> + 'a,
319    {
320        async move {
321            let mut conn = conn.to_connection().resolve().await?;
322
323            let statement = conn.as_mut().get_statement(self.query).await?;
324
325            for params in self.params {
326                conn.as_mut().execute_statement(&statement, params).await?;
327            }
328
329            Ok(())
330        }
331        .boxed()
332    }
333}
334
// Integration tests for `Query` / `BatchQuery`. They connect to the server described by
// `get_opts()`, so they need a reachable MySQL instance.
#[cfg(test)]
mod tests {
    use crate::{prelude::*, test_misc::get_opts, *};

    // Exercises every `Query` method for text-protocol queries on a plain connection,
    // a connection-level transaction, a pool, and a pool-level transaction.
    #[tokio::test]
    async fn should_run_text_query() -> Result<()> {
        // Two statements, hence two result sets: {(1, 2), (3, 4)} and {(5, 6)}.
        let query_static = "SELECT 1, 2 UNION ALL SELECT 3, 4; SELECT 5, 6;";
        let query_string = String::from(query_static);

        // Runs the same checks for a given query value and connection-like target.
        macro_rules! test {
            ($query:expr, $conn:expr) => {{
                // `run` leaves the result unread; each `collect` call reads one result set.
                let mut result = $query.run($conn).await?;
                let result1: Vec<(u8, u8)> = result.collect().await?;
                let result2: Vec<(u8, u8)> = result.collect().await?;
                assert_eq!(result1, vec![(1, 2), (3, 4)]);
                assert_eq!(result2, vec![(5, 6)]);

                $query.ignore($conn).await?;

                // The remaining helpers only look at the first result set.
                let result: Option<(u8, u8)> = $query.first($conn).await?;
                assert_eq!(result, Some((1, 2)));

                let result: Vec<(u8, u8)> = $query.fetch($conn).await?;
                assert_eq!(result, vec![(1, 2), (3, 4)]);

                let result = $query
                    .map($conn, |row: (u8, u8)| format!("{:?}", row))
                    .await?;
                assert_eq!(result, vec![String::from("(1, 2)"), String::from("(3, 4)")]);

                // 1 + 2 + 3 + 4 = 10; the second result set (5, 6) is not included.
                let result = $query
                    .reduce($conn, 0_u8, |acc, row: (u8, u8)| acc + row.0 + row.1)
                    .await?;
                assert_eq!(result, 10);
            }};
        }

        let mut conn = Conn::new(get_opts()).await?;
        test!(query_static, &mut conn);
        test!(query_string.as_str(), &mut conn);

        let mut tx = conn.start_transaction(Default::default()).await?;
        test!(query_static, &mut tx);
        test!(query_string.as_str(), &mut tx);
        tx.rollback().await?;

        conn.disconnect().await?;

        let pool = Pool::new(get_opts());
        test!(query_static, &pool);
        test!(query_string.as_str(), &pool);

        let mut tx = pool.start_transaction(Default::default()).await?;
        test!(query_static, &mut tx);
        test!(query_string.as_str(), &mut tx);
        tx.rollback().await?;

        pool.disconnect().await?;

        Ok(())
    }

    // Exercises every `Query` method, plus `BatchQuery::batch`, for parameterized
    // (binary-protocol) queries built from several query representations.
    #[tokio::test]
    async fn should_run_bin_query() -> Result<()> {
        // The same query text as `&'static str`, `String`, `Box<str>` and `Arc<str>`.
        macro_rules! query {
            (@static) => {
                "SELECT ?, ? UNION ALL SELECT ?, ?"
            };
            (@string) => {
                String::from("SELECT ?, ? UNION ALL SELECT ?, ?")
            };
            (@boxed) => {
                query!(@string).into_boxed_str()
            };
            (@arc) => {
                std::sync::Arc::<str>::from(query!(@boxed))
            };
        }

        let query_string = query!(@string);
        let params_static = ("1", "2", "3", "4");
        let params_string = (
            "1".to_owned(),
            "2".to_owned(),
            "3".to_owned(),
            "4".to_owned(),
        );

        // `$query` and `$params` are re-expanded (and so re-evaluated) at every use below,
        // which is why callers pass either a `Copy` tuple or a `.clone()` expression.
        macro_rules! test {
            ($query:expr, $params:expr, $conn:expr) => {{
                let query = { $query.with($params) };
                let mut result = query.run($conn).await?;
                let result1: Vec<(u8, u8)> = result.collect().await?;
                assert_eq!(result1, vec![(1, 2), (3, 4)]);

                $query.with($params).ignore($conn).await?;

                let result: Option<(u8, u8)> = $query.with($params).first($conn).await?;
                assert_eq!(result, Some((1, 2)));

                let result: Vec<(u8, u8)> = $query.with($params).fetch($conn).await?;
                assert_eq!(result, vec![(1, 2), (3, 4)]);

                let result = $query
                    .with($params)
                    .map($conn, |row: (u8, u8)| format!("{:?}", row))
                    .await?;
                assert_eq!(result, vec![String::from("(1, 2)"), String::from("(3, 4)")]);

                let result = $query
                    .with($params)
                    .reduce($conn, 0_u8, |acc, row: (u8, u8)| acc + row.0 + row.1)
                    .await?;
                assert_eq!(result, 10);

                // Batch execution: the same statement run with four parameter sets.
                $query
                    .with(vec![$params, $params, $params, $params])
                    .batch($conn)
                    .await?;
            }};
        }

        let mut conn = Conn::new(get_opts()).await?;
        // A statement prepared up front, used both by reference and by value below.
        let statement = conn.prep(query!(@static)).await?;
        test!(query!(@static), params_static, &mut conn);
        test!(query!(@string), params_string.clone(), &mut conn);
        test!(query!(@boxed), params_string.clone(), &mut conn);
        test!(query!(@arc), params_string.clone(), &mut conn);
        test!(&query_string, params_string.clone(), &mut conn);
        test!(&statement, params_string.clone(), &mut conn);
        test!(statement.clone(), params_string.clone(), &mut conn);

        let mut tx = conn.start_transaction(Default::default()).await?;
        test!(query!(@static), params_string.clone(), &mut tx);
        test!(query!(@string), params_static, &mut tx);
        test!(&query_string, params_static, &mut tx);
        test!(&statement, params_string.clone(), &mut tx);
        test!(statement.clone(), params_string.clone(), &mut tx);
        tx.rollback().await?;

        conn.disconnect().await?;

        let pool = Pool::new(get_opts());
        test!(query!(@static), params_static, &pool);
        test!(query!(@string), params_string.clone(), &pool);
        test!(&query_string, params_string.clone(), &pool);

        let mut tx = pool.start_transaction(Default::default()).await?;
        test!(query!(@static), params_string.clone(), &mut tx);
        test!(query!(@string), params_static, &mut tx);
        test!(&query_string, params_static, &mut tx);
        tx.rollback().await?;

        pool.disconnect().await?;

        Ok(())
    }
}