mysql_async/conn/routines/
exec.rs
1use std::mem;
2
3use futures_core::future::BoxFuture;
4use futures_util::FutureExt;
5use mysql_common::{packets::ComStmtExecuteRequestBuilder, params::Params};
6#[cfg(feature = "tracing")]
7use tracing::{field, info_span, Level, Span};
8
9use crate::{BinaryProtocol, Conn, DriverError, Statement};
10
11use super::Routine;
12
13#[derive(Debug, Clone)]
15pub struct ExecRoutine<'a> {
16 stmt: &'a Statement,
17 params: Params,
18}
19
20impl<'a> ExecRoutine<'a> {
21 pub fn new(stmt: &'a Statement, params: Params) -> Self {
22 Self { stmt, params }
23 }
24}
25
26impl Routine<()> for ExecRoutine<'_> {
27 fn call<'a>(&'a mut self, conn: &'a mut Conn) -> BoxFuture<'a, crate::Result<()>> {
28 #[cfg(feature = "tracing")]
29 let span = info_span!(
30 "mysql_async::exec",
31 mysql_async.connection.id = conn.id(),
32 mysql_async.statement.id = self.stmt.id(),
33 mysql_async.query.params = field::Empty,
34 );
35
36 let fut = async move {
37 loop {
38 match self.params {
39 Params::Positional(ref params) => {
40 #[cfg(feature = "tracing")]
41 if tracing::span_enabled!(Level::DEBUG) {
42 let sep = std::iter::repeat(", ");
46 let ps = params
47 .iter()
48 .map(|p| p.as_sql(true))
49 .zip(sep)
50 .map(|(val, sep)| val + sep)
51 .collect::<String>();
52 Span::current().record("mysql_async.query.params", ps);
53 }
54
55 if self.stmt.num_params() as usize != params.len() {
56 Err(DriverError::StmtParamsMismatch {
57 required: self.stmt.num_params(),
58 supplied: params.len() as u16,
59 })?
60 }
61
62 let (body, as_long_data) =
63 ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(params);
64
65 if as_long_data {
66 conn.send_long_data(self.stmt.id(), params.iter()).await?;
67 }
68
69 conn.write_command(&body).await?;
70 conn.read_result_set::<BinaryProtocol>(true).await?;
71 break;
72 }
73 Params::Named(_) => {
74 if self.stmt.named_params.is_empty() {
75 let error = DriverError::NamedParamsForPositionalQuery.into();
76 return Err(error);
77 }
78
79 let named = mem::replace(&mut self.params, Params::Empty);
80 self.params = named.into_positional(&self.stmt.named_params)?;
81
82 continue;
83 }
84 Params::Empty => {
85 if self.stmt.num_params() > 0 {
86 let error = DriverError::StmtParamsMismatch {
87 required: self.stmt.num_params(),
88 supplied: 0,
89 }
90 .into();
91 return Err(error);
92 }
93
94 let (body, _) =
95 ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(&[]);
96 conn.write_command(&body).await?;
97 conn.read_result_set::<BinaryProtocol>(true).await?;
98 break;
99 }
100 }
101 }
102 Ok(())
103 };
104
105 #[cfg(feature = "tracing")]
106 let fut = instrument_result!(fut, span);
107
108 fut.boxed()
109 }
110}