// mysql_async/conn/routines/prepare.rs

1use std::{borrow::Cow, sync::Arc};
2
3use futures_core::future::BoxFuture;
4use futures_util::FutureExt;
5use mysql_common::constants::Command;
6#[cfg(feature = "tracing")]
7use tracing::{field, info_span, Level, Span};
8
9use crate::{queryable::stmt::StmtInner, Conn};
10
11use super::Routine;
12
/// A routine that performs `COM_STMT_PREPARE`.
#[derive(Debug, Clone)]
pub struct PrepareRoutine {
    /// Raw query bytes sent with `COM_STMT_PREPARE`.
    ///
    /// Stored behind an `Arc` so cloning the routine, or handing the query to the
    /// resulting statement, only bumps a reference count.
    query: Arc<[u8]>,
}

impl PrepareRoutine {
    /// Creates a routine that will prepare `raw_query`.
    ///
    /// The bytes are copied exactly once into a shared `Arc<[u8]>`, whether `raw_query`
    /// is borrowed or owned. This avoids the intermediate `Vec`/`Box<[u8]>` allocation
    /// that `into_owned().into_boxed_slice()` would require for a borrowed input.
    pub fn new(raw_query: Cow<'_, [u8]>) -> Self {
        Self {
            query: Arc::from(&*raw_query),
        }
    }
}
26
27impl Routine<Arc<StmtInner>> for PrepareRoutine {
28    fn call<'a>(&'a mut self, conn: &'a mut Conn) -> BoxFuture<'a, crate::Result<Arc<StmtInner>>> {
29        #[cfg(feature = "tracing")]
30        let span = info_span!(
31            "mysql_async::prepare",
32            mysql_async.connection.id = conn.id(),
33            mysql_async.statement.id = field::Empty,
34            mysql_async.query.sql = field::Empty,
35        );
36        #[cfg(feature = "tracing")]
37        if tracing::span_enabled!(Level::DEBUG) {
38            // The statement may contain sensitive data. Restrict to DEBUG.
39            span.record(
40                "mysql_async.query.sql",
41                String::from_utf8_lossy(&*self.query).as_ref(),
42            );
43        }
44
45        let fut = async move {
46            conn.write_command_data(Command::COM_STMT_PREPARE, &self.query)
47                .await?;
48
49            let packet = conn.read_packet().await?;
50            let mut inner_stmt = StmtInner::from_payload(&packet, conn.id(), self.query.clone())?;
51
52            #[cfg(feature = "tracing")]
53            Span::current().record("mysql_async.statement.id", inner_stmt.id());
54
55            if inner_stmt.num_params() > 0 {
56                let params = conn.read_column_defs(inner_stmt.num_params()).await?;
57                inner_stmt = inner_stmt.with_params(params);
58            }
59
60            if inner_stmt.num_columns() > 0 {
61                let columns = conn.read_column_defs(inner_stmt.num_columns()).await?;
62                inner_stmt = inner_stmt.with_columns(columns);
63            }
64
65            Ok(Arc::new(inner_stmt))
66        };
67
68        #[cfg(feature = "tracing")]
69        let fut = instrument_result!(fut, span);
70
71        fut.boxed()
72    }
73}