// mysql_async/conn/routines/prepare.rs

use std::{borrow::Cow, sync::Arc};

use futures_core::future::BoxFuture;
use futures_util::FutureExt;
use mysql_common::constants::Command;
#[cfg(feature = "tracing")]
use tracing::{field, info_span, Level, Span};

use crate::{queryable::stmt::StmtInner, Conn};

use super::Routine;

/// A routine that performs `COM_STMT_PREPARE`.
///
/// The raw SQL text is held in a shared, immutable `Arc<[u8]>`, so cloning the
/// routine only bumps a reference count instead of copying the query bytes.
#[derive(Debug, Clone)]
pub struct PrepareRoutine {
    /// Raw bytes of the statement to prepare.
    query: Arc<[u8]>,
}

impl PrepareRoutine {
    /// Creates a routine that will prepare `raw_query`.
    ///
    /// The query bytes are placed into a single `Arc<[u8]>` allocation. A borrowed
    /// query is copied once. An owned `Vec` is copied straight into the shared
    /// buffer, without first being converted into an intermediate boxed slice.
    pub fn new(raw_query: Cow<'_, [u8]>) -> Self {
        Self {
            query: Arc::from(raw_query),
        }
    }
}

impl Routine<Arc<StmtInner>> for PrepareRoutine {
    /// Prepares the stored query on `conn` and yields the resulting statement.
    ///
    /// The returned future performs these steps in order:
    /// 1. Writes a `COM_STMT_PREPARE` command carrying the query bytes.
    /// 2. Reads one reply packet and builds a [`StmtInner`] from it.
    /// 3. Reads parameter definitions, then column definitions, each only when
    ///    the statement reports a non-zero count.
    ///
    /// With the `tracing` feature, the work runs inside a `mysql_async::prepare`
    /// span. The SQL text is attached to that span only when `DEBUG` is enabled.
    fn call<'a>(&'a mut self, conn: &'a mut Conn) -> BoxFuture<'a, crate::Result<Arc<StmtInner>>> {
        #[cfg(feature = "tracing")]
        let span = {
            let prepare_span = info_span!(
                "mysql_async::prepare",
                mysql_async.connection.id = conn.id(),
                mysql_async.statement.id = field::Empty,
                mysql_async.query.sql = field::Empty,
            );
            // The statement may contain sensitive data, so only record it at DEBUG.
            if tracing::span_enabled!(Level::DEBUG) {
                let sql_text = String::from_utf8_lossy(&self.query);
                prepare_span.record("mysql_async.query.sql", &*sql_text);
            }
            prepare_span
        };

        let task = async move {
            conn.write_command_data(Command::COM_STMT_PREPARE, &self.query)
                .await?;

            // The statement is built from the first reply packet; its id and
            // parameter/column counts decide what is read next.
            let reply = conn.read_packet().await?;
            let connection_id = conn.id();
            let mut stmt = StmtInner::from_payload(&reply, connection_id, Arc::clone(&self.query))?;

            #[cfg(feature = "tracing")]
            Span::current().record("mysql_async.statement.id", stmt.id());

            // Parameter definitions are read before column definitions.
            let param_count = stmt.num_params();
            if param_count > 0 {
                let param_defs = conn.read_column_defs(param_count).await?;
                stmt = stmt.with_params(param_defs);
            }

            let column_count = stmt.num_columns();
            if column_count > 0 {
                let column_defs = conn.read_column_defs(column_count).await?;
                stmt = stmt.with_columns(column_defs);
            }

            Ok(Arc::new(stmt))
        };

        #[cfg(feature = "tracing")]
        let task = instrument_result!(task, span);

        task.boxed()
    }
}