mysql_async/conn/routines/exec.rs

1use std::mem;
2
3use futures_core::future::BoxFuture;
4use futures_util::FutureExt;
5use mysql_common::{packets::ComStmtExecuteRequestBuilder, params::Params};
6#[cfg(feature = "tracing")]
7use tracing::{field, info_span, Level, Span};
8
9use crate::{BinaryProtocol, Conn, DriverError, Statement};
10
11use super::Routine;
12
13/// A routine that executes `COM_STMT_EXECUTE`.
14#[derive(Debug, Clone)]
15pub struct ExecRoutine<'a> {
16    stmt: &'a Statement,
17    params: Params,
18}
19
20impl<'a> ExecRoutine<'a> {
21    pub fn new(stmt: &'a Statement, params: Params) -> Self {
22        Self { stmt, params }
23    }
24}
25
26impl Routine<()> for ExecRoutine<'_> {
27    fn call<'a>(&'a mut self, conn: &'a mut Conn) -> BoxFuture<'a, crate::Result<()>> {
28        #[cfg(feature = "tracing")]
29        let span = info_span!(
30            "mysql_async::exec",
31            mysql_async.connection.id = conn.id(),
32            mysql_async.statement.id = self.stmt.id(),
33            mysql_async.query.params = field::Empty,
34        );
35
36        let fut = async move {
37            loop {
38                match self.params {
39                    Params::Positional(ref params) => {
40                        #[cfg(feature = "tracing")]
41                        if tracing::span_enabled!(Level::DEBUG) {
42                            // The params may contain sensitive data. Restrict to DEBUG.
43                            // TODO: make more efficient
44                            // TODO: use intersperse() once stable
45                            let sep = std::iter::repeat(", ");
46                            let ps = params
47                                .iter()
48                                .map(|p| p.as_sql(true))
49                                .zip(sep)
50                                .map(|(val, sep)| val + sep)
51                                .collect::<String>();
52                            Span::current().record("mysql_async.query.params", ps);
53                        }
54
55                        if self.stmt.num_params() as usize != params.len() {
56                            Err(DriverError::StmtParamsMismatch {
57                                required: self.stmt.num_params(),
58                                supplied: params.len() as u16,
59                            })?
60                        }
61
62                        let (body, as_long_data) =
63                            ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(params);
64
65                        if as_long_data {
66                            conn.send_long_data(self.stmt.id(), params.iter()).await?;
67                        }
68
69                        conn.write_command(&body).await?;
70                        conn.read_result_set::<BinaryProtocol>(true).await?;
71                        break;
72                    }
73                    Params::Named(_) => {
74                        if self.stmt.named_params.is_empty() {
75                            let error = DriverError::NamedParamsForPositionalQuery.into();
76                            return Err(error);
77                        }
78
79                        let named = mem::replace(&mut self.params, Params::Empty);
80                        self.params = named.into_positional(&self.stmt.named_params)?;
81
82                        continue;
83                    }
84                    Params::Empty => {
85                        if self.stmt.num_params() > 0 {
86                            let error = DriverError::StmtParamsMismatch {
87                                required: self.stmt.num_params(),
88                                supplied: 0,
89                            }
90                            .into();
91                            return Err(error);
92                        }
93
94                        let (body, _) =
95                            ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(&[]);
96                        conn.write_command(&body).await?;
97                        conn.read_result_set::<BinaryProtocol>(true).await?;
98                        break;
99                    }
100                }
101            }
102            Ok(())
103        };
104
105        #[cfg(feature = "tracing")]
106        let fut = instrument_result!(fut, span);
107
108        fut.boxed()
109    }
110}