1use futures_util::FutureExt;
10use mysql_common::{
11 constants::MAX_PAYLOAD_LEN,
12 io::ParseBuf,
13 proto::{Binary, Text},
14 row::RowDeserializer,
15 value::ServerSide,
16};
17
18use std::{fmt, sync::Arc};
19
20use self::{
21 query_result::QueryResult,
22 stmt::Statement,
23 transaction::{Transaction, TxStatus},
24};
25
26use crate::{
27 conn::routines::{PingRoutine, QueryRoutine},
28 consts::CapabilityFlags,
29 error::*,
30 prelude::{FromRow, StatementLike},
31 query::AsQuery,
32 queryable::query_result::ResultSetMeta,
33 tracing_utils::{LevelInfo, LevelTrace, TracingLevel},
34 BoxFuture, Column, Conn, Params, ResultSetStream, Row,
35};
36
37pub mod query_result;
38pub mod stmt;
39pub mod transaction;
40
41pub trait Protocol: fmt::Debug + Send + Sync + 'static {
42 fn result_set_meta(columns: Arc<[Column]>) -> ResultSetMeta;
44 fn read_result_set_row(packet: &[u8], columns: Arc<[Column]>) -> Result<Row>;
45 fn is_last_result_set_packet(capabilities: CapabilityFlags, packet: &[u8]) -> bool {
46 if capabilities.contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
47 packet[0] == 0xFE && packet.len() < MAX_PAYLOAD_LEN
48 } else {
49 packet[0] == 0xFE && packet.len() < 8
50 }
51 }
52}
53
/// Type-level tag for the text protocol.
///
/// Its [`Protocol`] implementation produces `ResultSetMeta::Text` and parses
/// rows with the `Text` row deserializer.
#[derive(Debug)]
pub struct TextProtocol;
57
/// Type-level tag for the binary protocol.
///
/// Its [`Protocol`] implementation produces `ResultSetMeta::Binary` and parses
/// rows with the `Binary` row deserializer.
#[derive(Debug)]
pub struct BinaryProtocol;
61
62impl Protocol for TextProtocol {
63 fn result_set_meta(columns: Arc<[Column]>) -> ResultSetMeta {
64 ResultSetMeta::Text(columns)
65 }
66
67 fn read_result_set_row(packet: &[u8], columns: Arc<[Column]>) -> Result<Row> {
68 ParseBuf(packet)
69 .parse::<RowDeserializer<ServerSide, Text>>(columns)
70 .map(Into::into)
71 .map_err(Into::into)
72 }
73}
74
75impl Protocol for BinaryProtocol {
76 fn result_set_meta(columns: Arc<[Column]>) -> ResultSetMeta {
77 ResultSetMeta::Binary(columns)
78 }
79
80 fn read_result_set_row(packet: &[u8], columns: Arc<[Column]>) -> Result<Row> {
81 ParseBuf(packet)
82 .parse::<RowDeserializer<ServerSide, Binary>>(columns)
83 .map(Into::into)
84 .map_err(Into::into)
85 }
86}
87
88impl Conn {
89 pub(crate) async fn clean_dirty(&mut self) -> Result<()> {
97 self.drop_result().await?;
98 if self.get_tx_status() == TxStatus::RequiresRollback {
99 self.set_tx_status(TxStatus::None);
100 self.exec_drop("ROLLBACK", ()).await?;
101 }
102 Ok(())
103 }
104
105 pub(crate) async fn raw_query<'a, Q, L: TracingLevel>(&'a mut self, query: Q) -> Result<()>
107 where
108 Q: AsQuery + 'a,
109 {
110 self.routine(QueryRoutine::<'_, L>::new(query.as_query().as_ref()))
111 .await
112 }
113
114 pub(crate) fn query_internal<'a, T, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Option<T>>
119 where
120 Q: AsQuery + 'a,
121 T: FromRow + Send + 'static,
122 {
123 async move {
124 self.raw_query::<'_, _, LevelTrace>(query).await?;
125 Ok(QueryResult::<'_, '_, TextProtocol>::new(self)
126 .collect_and_drop::<T>()
127 .await?
128 .pop())
129 }
130 .boxed()
131 }
132}
133
134pub trait Queryable: Send {
138 fn ping(&mut self) -> BoxFuture<'_, ()>;
140
141 fn query_iter<'a, Q>(
143 &'a mut self,
144 query: Q,
145 ) -> BoxFuture<'a, QueryResult<'a, 'static, TextProtocol>>
146 where
147 Q: AsQuery + 'a;
148
149 fn prep<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Statement>
158 where
159 Q: AsQuery + 'a;
160
161 fn close(&mut self, stmt: Statement) -> BoxFuture<'_, ()>;
168
169 fn exec_iter<'a: 's, 's, Q, P>(
173 &'a mut self,
174 stmt: Q,
175 params: P,
176 ) -> BoxFuture<'s, QueryResult<'a, 'static, BinaryProtocol>>
177 where
178 Q: StatementLike + 'a,
179 P: Into<Params>;
180
181 fn query<'a, T, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Vec<T>>
189 where
190 Q: AsQuery + 'a,
191 T: FromRow + Send + 'static,
192 {
193 async move { self.query_iter(query).await?.collect_and_drop::<T>().await }.boxed()
194 }
195
196 fn query_first<'a, T, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Option<T>>
204 where
205 Q: AsQuery + 'a,
206 T: FromRow + Send + 'static,
207 {
208 async move {
209 let mut result = self.query_iter(query).await?;
210 let output = if result.is_empty() {
211 None
212 } else {
213 result.next().await?.map(crate::from_row)
214 };
215 result.drop_result().await?;
216 Ok(output)
217 }
218 .boxed()
219 }
220
221 fn query_map<'a, T, F, Q, U>(&'a mut self, query: Q, mut f: F) -> BoxFuture<'a, Vec<U>>
229 where
230 Q: AsQuery + 'a,
231 T: FromRow + Send + 'static,
232 F: FnMut(T) -> U + Send + 'a,
233 U: Send,
234 {
235 async move {
236 self.query_fold(query, Vec::new(), |mut acc, row| {
237 acc.push(f(crate::from_row(row)));
238 acc
239 })
240 .await
241 }
242 .boxed()
243 }
244
245 fn query_fold<'a, T, F, Q, U>(&'a mut self, query: Q, init: U, mut f: F) -> BoxFuture<'a, U>
253 where
254 Q: AsQuery + 'a,
255 T: FromRow + Send + 'static,
256 F: FnMut(U, T) -> U + Send + 'a,
257 U: Send + 'a,
258 {
259 async move {
260 self.query_iter(query)
261 .await?
262 .reduce_and_drop(init, |acc, row| f(acc, crate::from_row(row)))
263 .await
264 }
265 .boxed()
266 }
267
268 fn query_drop<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, ()>
270 where
271 Q: AsQuery + 'a,
272 {
273 async move { self.query_iter(query).await?.drop_result().await }.boxed()
274 }
275
276 fn exec_batch<'a: 'b, 'b, S, P, I>(&'a mut self, stmt: S, params_iter: I) -> BoxFuture<'b, ()>
280 where
281 S: StatementLike + 'b,
282 I: IntoIterator<Item = P> + Send + 'b,
283 I::IntoIter: Send,
284 P: Into<Params> + Send;
285
286 fn exec<'a: 'b, 'b, T, S, P>(&'a mut self, stmt: S, params: P) -> BoxFuture<'b, Vec<T>>
296 where
297 S: StatementLike + 'b,
298 P: Into<Params> + Send + 'b,
299 T: FromRow + Send + 'static,
300 {
301 async move {
302 self.exec_iter(stmt, params)
303 .await?
304 .collect_and_drop::<T>()
305 .await
306 }
307 .boxed()
308 }
309
310 fn exec_first<'a: 'b, 'b, T, S, P>(&'a mut self, stmt: S, params: P) -> BoxFuture<'b, Option<T>>
320 where
321 S: StatementLike + 'b,
322 P: Into<Params> + Send + 'b,
323 T: FromRow + Send + 'static,
324 {
325 async move {
326 let mut result = self.exec_iter(stmt, params).await?;
327 let row = if result.is_empty() {
328 None
329 } else {
330 result.next().await?
331 };
332 result.drop_result().await?;
333 Ok(row.map(crate::from_row))
334 }
335 .boxed()
336 }
337
338 fn exec_map<'a: 'b, 'b, T, S, P, U, F>(
348 &'a mut self,
349 stmt: S,
350 params: P,
351 mut f: F,
352 ) -> BoxFuture<'b, Vec<U>>
353 where
354 S: StatementLike + 'b,
355 P: Into<Params> + Send + 'b,
356 T: FromRow + Send + 'static,
357 F: FnMut(T) -> U + Send + 'a,
358 U: Send + 'a,
359 {
360 async move {
361 self.exec_fold(stmt, params, Vec::new(), |mut acc, row| {
362 acc.push(f(crate::from_row(row)));
363 acc
364 })
365 .await
366 }
367 .boxed()
368 }
369
370 fn exec_fold<'a: 'b, 'b, T, S, P, U, F>(
380 &'a mut self,
381 stmt: S,
382 params: P,
383 init: U,
384 mut f: F,
385 ) -> BoxFuture<'b, U>
386 where
387 S: StatementLike + 'b,
388 P: Into<Params> + Send + 'b,
389 T: FromRow + Send + 'static,
390 F: FnMut(U, T) -> U + Send + 'a,
391 U: Send + 'a,
392 {
393 async move {
394 self.exec_iter(stmt, params)
395 .await?
396 .reduce_and_drop(init, |acc, row| f(acc, crate::from_row(row)))
397 .await
398 }
399 .boxed()
400 }
401
402 fn exec_drop<'a: 'b, 'b, S, P>(&'a mut self, stmt: S, params: P) -> BoxFuture<'b, ()>
404 where
405 S: StatementLike + 'b,
406 P: Into<Params> + Send + 'b,
407 {
408 async move { self.exec_iter(stmt, params).await?.drop_result().await }.boxed()
409 }
410
411 fn query_stream<'a, T, Q>(
417 &'a mut self,
418 query: Q,
419 ) -> BoxFuture<'a, ResultSetStream<'a, 'a, 'static, T, TextProtocol>>
420 where
421 T: Unpin + FromRow + Send + 'static,
422 Q: AsQuery + 'a,
423 {
424 async move {
425 self.query_iter(query)
426 .await?
427 .stream_and_drop()
428 .await
429 .transpose()
430 .expect("At least one result set is expected")
431 }
432 .boxed()
433 }
434
435 fn exec_stream<'a: 's, 's, T, Q, P>(
441 &'a mut self,
442 stmt: Q,
443 params: P,
444 ) -> BoxFuture<'s, ResultSetStream<'a, 'a, 'static, T, BinaryProtocol>>
445 where
446 T: Unpin + FromRow + Send + 'static,
447 Q: StatementLike + 'a,
448 P: Into<Params> + Send + 's,
449 {
450 async move {
451 self.exec_iter(stmt, params)
452 .await?
453 .stream_and_drop()
454 .await
455 .transpose()
456 .expect("At least one result set is expected")
457 }
458 .boxed()
459 }
460}
461
462impl Queryable for Conn {
463 fn ping(&mut self) -> BoxFuture<'_, ()> {
464 async move {
465 self.routine(PingRoutine).await?;
466 Ok(())
467 }
468 .boxed()
469 }
470
471 fn query_iter<'a, Q>(
472 &'a mut self,
473 query: Q,
474 ) -> BoxFuture<'a, QueryResult<'a, 'static, TextProtocol>>
475 where
476 Q: AsQuery + 'a,
477 {
478 async move {
479 self.raw_query::<'_, _, LevelInfo>(query).await?;
480 Ok(QueryResult::new(self))
481 }
482 .boxed()
483 }
484
485 fn prep<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Statement>
486 where
487 Q: AsQuery + 'a,
488 {
489 async move { self.get_statement(query.as_query()).await }.boxed()
490 }
491
492 fn close(&mut self, stmt: Statement) -> BoxFuture<'_, ()> {
493 async move {
494 self.stmt_cache_mut().remove(stmt.id());
495 self.close_statement(stmt.id()).await
496 }
497 .boxed()
498 }
499
500 fn exec_iter<'a: 's, 's, Q, P>(
501 &'a mut self,
502 stmt: Q,
503 params: P,
504 ) -> BoxFuture<'s, QueryResult<'a, 'static, BinaryProtocol>>
505 where
506 Q: StatementLike + 'a,
507 P: Into<Params>,
508 {
509 let params = params.into();
510 async move {
511 let statement = self.get_statement(stmt).await?;
512 self.execute_statement(&statement, params).await?;
513 Ok(QueryResult::new(self))
514 }
515 .boxed()
516 }
517
518 fn exec_batch<'a: 'b, 'b, S, P, I>(&'a mut self, stmt: S, params_iter: I) -> BoxFuture<'b, ()>
519 where
520 S: StatementLike + 'b,
521 I: IntoIterator<Item = P> + Send + 'b,
522 I::IntoIter: Send,
523 P: Into<Params> + Send,
524 {
525 async move {
526 let statement = self.get_statement(stmt).await?;
527 for params in params_iter {
528 self.execute_statement(&statement, params).await?;
529 QueryResult::<BinaryProtocol>::new(&mut *self)
530 .drop_result()
531 .await?;
532 }
533 Ok(())
534 }
535 .boxed()
536 }
537}
538
/// [`Queryable`] for [`Transaction`]: every required method forwards, with
/// unchanged arguments, to the same-named method on the value stored in
/// field `0`.
// NOTE(review): field `0` is presumably the connection the transaction runs
// on — confirm in the `transaction` module.
impl Queryable for Transaction<'_> {
    // Forwards to `self.0.ping()`.
    fn ping(&mut self) -> BoxFuture<'_, ()> {
        self.0.ping()
    }

    // Forwards to `self.0.query_iter(query)`.
    fn query_iter<'a, Q>(
        &'a mut self,
        query: Q,
    ) -> BoxFuture<'a, QueryResult<'a, 'static, TextProtocol>>
    where
        Q: AsQuery + 'a,
    {
        self.0.query_iter(query)
    }

    // Forwards to `self.0.prep(query)`.
    fn prep<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Statement>
    where
        Q: AsQuery + 'a,
    {
        self.0.prep(query)
    }

    // Forwards to `self.0.close(stmt)`.
    fn close(&mut self, stmt: Statement) -> BoxFuture<'_, ()> {
        self.0.close(stmt)
    }

    // Forwards to `self.0.exec_iter(stmt, params)`.
    fn exec_iter<'a: 's, 's, Q, P>(
        &'a mut self,
        stmt: Q,
        params: P,
    ) -> BoxFuture<'s, QueryResult<'a, 'static, BinaryProtocol>>
    where
        Q: StatementLike + 'a,
        P: Into<Params>,
    {
        self.0.exec_iter(stmt, params)
    }

    // Forwards to `self.0.exec_batch(stmt, params_iter)`.
    fn exec_batch<'a: 'b, 'b, S, P, I>(&'a mut self, stmt: S, params_iter: I) -> BoxFuture<'b, ()>
    where
        S: StatementLike + 'b,
        I: IntoIterator<Item = P> + Send + 'b,
        I::IntoIter: Send,
        P: Into<Params> + Send,
    {
        self.0.exec_batch(stmt, params_iter)
    }
}
587
#[cfg(test)]
mod tests {
    use crate::{error::Result, prelude::*, test_misc::get_opts, Conn};

    /// Runs the same three-column `SELECT` four ways — prepared statement or
    /// raw query text, each with named or positional parameters — and checks
    /// that all four produce the same row.
    #[tokio::test]
    async fn should_prep() -> Result<()> {
        const NAMED: &str = "SELECT :foo, :bar, :foo";
        const POSITIONAL: &str = "SELECT ?, ?, ?";

        let mut conn = Conn::new(get_opts()).await?;

        let stmt_named = conn.prep(NAMED).await?;
        let stmt_positional = conn.prep(POSITIONAL).await?;

        let result_stmt_named: Option<(String, u8, String)> = conn
            .exec_first(&stmt_named, params! { "foo" => "bar", "bar" => 42 })
            .await?;
        let result_str_named: Option<(String, u8, String)> = conn
            .exec_first(NAMED, params! { "foo" => "bar", "bar" => 42 })
            .await?;

        let result_stmt_positional: Option<(String, u8, String)> = conn
            .exec_first(&stmt_positional, ("bar", 42, "bar"))
            .await?;
        // Positional parameters go with the positional query text; using
        // `NAMED` here would leave the positional-text path untested.
        let result_str_positional: Option<(String, u8, String)> = conn
            .exec_first(POSITIONAL, ("bar", 42, "bar"))
            .await?;

        assert_eq!(
            Some(("bar".to_owned(), 42_u8, "bar".to_owned())),
            result_stmt_named
        );
        assert_eq!(result_stmt_named, result_str_named);
        assert_eq!(result_str_named, result_stmt_positional);
        assert_eq!(result_stmt_positional, result_str_positional);

        conn.disconnect().await?;

        Ok(())
    }
}