mysql_async/conn/routines/
exec.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use std::mem;

use futures_core::future::BoxFuture;
use futures_util::FutureExt;
use mysql_common::{packets::ComStmtExecuteRequestBuilder, params::Params};
#[cfg(feature = "tracing")]
use tracing::{field, info_span, Level, Span};

use crate::{BinaryProtocol, Conn, DriverError, Statement};

use super::Routine;

/// A routine that executes `COM_STMT_EXECUTE`.
///
/// It borrows an already prepared [`Statement`] and owns the parameters that
/// will be bound to it when the routine is driven against a connection.
#[derive(Debug, Clone)]
pub struct ExecRoutine<'a> {
    /// The prepared statement to execute.
    stmt: &'a Statement,
    /// Parameters to bind: named, positional or empty.
    params: Params,
}

impl<'s> ExecRoutine<'s> {
    /// Builds a routine that will execute `stmt` with the given `params`.
    pub fn new(stmt: &'s Statement, params: Params) -> Self {
        ExecRoutine { stmt, params }
    }
}

impl Routine<()> for ExecRoutine<'_> {
    /// Executes the prepared statement on `conn` and reads its result set.
    ///
    /// Named parameters are first converted to positional ones using the
    /// statement's `named_params`. The positional parameter count is then
    /// checked against `Statement::num_params()` before the
    /// `COM_STMT_EXECUTE` packet is written.
    ///
    /// # Errors
    ///
    /// * [`DriverError::NamedParamsForPositionalQuery`] if named parameters are
    ///   supplied for a statement that has no named parameters.
    /// * [`DriverError::StmtParamsMismatch`] if the number of supplied
    ///   parameters differs from the number the statement requires.
    /// * Any error produced while converting named parameters, sending long
    ///   data, writing the command, or reading the result set.
    fn call<'a>(&'a mut self, conn: &'a mut Conn) -> BoxFuture<'a, crate::Result<()>> {
        #[cfg(feature = "tracing")]
        let span = info_span!(
            "mysql_async::exec",
            mysql_async.connection.id = conn.id(),
            mysql_async.statement.id = self.stmt.id(),
            mysql_async.query.params = field::Empty,
        );

        let fut = async move {
            // Loops at most twice: a `Named` pass rewrites `self.params` into
            // positional form, then the next iteration executes it.
            loop {
                match self.params {
                    Params::Positional(ref params) => {
                        #[cfg(feature = "tracing")]
                        if tracing::span_enabled!(Level::DEBUG) {
                            // The params may contain sensitive data. Restrict to DEBUG.
                            // `join` only places the separator *between* values, so the
                            // recorded field does not end with a dangling ", ".
                            let ps = params
                                .iter()
                                .map(|p| p.as_sql(true))
                                .collect::<Vec<_>>()
                                .join(", ");
                            Span::current().record("mysql_async.query.params", ps);
                        }

                        if self.stmt.num_params() as usize != params.len() {
                            // Saturate instead of a plain `as u16` cast, which would wrap
                            // counts above `u16::MAX` and misreport the supplied number.
                            let supplied = params.len().min(usize::from(u16::MAX)) as u16;
                            let error = DriverError::StmtParamsMismatch {
                                required: self.stmt.num_params(),
                                supplied,
                            }
                            .into();
                            return Err(error);
                        }

                        let (body, as_long_data) =
                            ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(params);

                        // The builder reports whether some params must be sent
                        // separately via `send_long_data` before the execute packet.
                        if as_long_data {
                            conn.send_long_data(self.stmt.id(), params.iter()).await?;
                        }

                        conn.write_command(&body).await?;
                        conn.read_result_set::<BinaryProtocol>(true).await?;
                        break;
                    }
                    Params::Named(_) => {
                        if self.stmt.named_params.is_empty() {
                            let error = DriverError::NamedParamsForPositionalQuery.into();
                            return Err(error);
                        }

                        // `into_positional` consumes the params, so move them out and
                        // leave a placeholder until the converted value is stored.
                        let named = mem::replace(&mut self.params, Params::Empty);
                        self.params = named.into_positional(&self.stmt.named_params)?;

                        continue;
                    }
                    Params::Empty => {
                        if self.stmt.num_params() > 0 {
                            let error = DriverError::StmtParamsMismatch {
                                required: self.stmt.num_params(),
                                supplied: 0,
                            }
                            .into();
                            return Err(error);
                        }

                        let (body, _) =
                            ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(&[]);
                        conn.write_command(&body).await?;
                        conn.read_result_set::<BinaryProtocol>(true).await?;
                        break;
                    }
                }
            }
            Ok(())
        };

        #[cfg(feature = "tracing")]
        let fut = instrument_result!(fut, span);

        fut.boxed()
    }
}