1use futures_util::FutureExt;
10use mysql_common::{
11 io::ParseBuf,
12 named_params::ParsedNamedParams,
13 packets::{ComStmtClose, StmtPacket},
14};
15
16use std::{borrow::Cow, sync::Arc};
17
18use crate::{
19 conn::routines::{ExecRoutine, PrepareRoutine},
20 consts::CapabilityFlags,
21 error::*,
22 Column, Params,
23};
24
25use super::AsQuery;
26
/// Outcome of [`StatementLike::to_statement`].
pub enum ToStatementResult<'a> {
    /// The statement is already available and can be used as-is.
    Immediate(Statement),
    /// The statement still has to be obtained. Awaiting this boxed future yields
    /// either the statement or the error that prevented it (for query text this
    /// may be a parsing or preparation failure).
    Mediate(crate::BoxFuture<'a, Statement>),
}
34
/// Anything that can be turned into a [`Statement`] on a given connection.
///
/// Implemented in this module for query text (`str` and `[u8]` in borrowed,
/// owned, boxed, reference-counted and `Cow` forms), for [`Statement`] itself,
/// and for references to cloneable implementors.
pub trait StatementLike: Send + Sync {
    /// Converts `self` into a statement usable on `conn`.
    ///
    /// The conversion either completes immediately
    /// ([`ToStatementResult::Immediate`]) or hands back a future that may borrow
    /// `conn` for `'a` ([`ToStatementResult::Mediate`]); the `Self: 'a` bound
    /// lets that future capture `self`.
    fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
    where
        Self: 'a;
}
41
42fn to_statement_move<'a, T: AsQuery + 'a>(
43 stmt: T,
44 conn: &'a mut crate::Conn,
45) -> ToStatementResult<'a> {
46 let fut = async move {
47 let query = stmt.as_query();
48 let parsed = ParsedNamedParams::parse(query.as_ref())?;
49 let inner_stmt = match conn.get_cached_stmt(parsed.query()) {
50 Some(inner_stmt) => inner_stmt,
51 None => {
52 conn.prepare_statement(Cow::Borrowed(parsed.query()))
53 .await?
54 }
55 };
56 Ok(Statement::new(
57 inner_stmt,
58 parsed
59 .params()
60 .iter()
61 .map(|x| x.as_ref().to_vec())
62 .collect::<Vec<_>>(),
63 ))
64 }
65 .boxed();
66 ToStatementResult::Mediate(fut)
67}
68
69impl StatementLike for Cow<'_, str> {
70 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
71 where
72 Self: 'a,
73 {
74 to_statement_move(self, conn)
75 }
76}
77
78impl StatementLike for &'_ str {
79 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
80 where
81 Self: 'a,
82 {
83 to_statement_move(self, conn)
84 }
85}
86
87impl StatementLike for String {
88 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
89 where
90 Self: 'a,
91 {
92 to_statement_move(self, conn)
93 }
94}
95
96impl StatementLike for Box<str> {
97 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
98 where
99 Self: 'a,
100 {
101 to_statement_move(self, conn)
102 }
103}
104
105impl StatementLike for Arc<str> {
106 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
107 where
108 Self: 'a,
109 {
110 to_statement_move(self, conn)
111 }
112}
113
114impl StatementLike for Cow<'_, [u8]> {
115 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
116 where
117 Self: 'a,
118 {
119 to_statement_move(self, conn)
120 }
121}
122
123impl StatementLike for &'_ [u8] {
124 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
125 where
126 Self: 'a,
127 {
128 to_statement_move(self, conn)
129 }
130}
131
132impl StatementLike for Vec<u8> {
133 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
134 where
135 Self: 'a,
136 {
137 to_statement_move(self, conn)
138 }
139}
140
141impl StatementLike for Box<[u8]> {
142 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
143 where
144 Self: 'a,
145 {
146 to_statement_move(self, conn)
147 }
148}
149
150impl StatementLike for Arc<[u8]> {
151 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
152 where
153 Self: 'a,
154 {
155 to_statement_move(self, conn)
156 }
157}
158
159impl StatementLike for Statement {
160 fn to_statement<'a>(self, _conn: &'a mut crate::Conn) -> ToStatementResult<'static>
161 where
162 Self: 'a,
163 {
164 ToStatementResult::Immediate(self.clone())
165 }
166}
167
168impl<T: StatementLike + Clone> StatementLike for &'_ T {
169 fn to_statement<'a>(self, conn: &'a mut crate::Conn) -> ToStatementResult<'a>
170 where
171 Self: 'a,
172 {
173 self.clone().to_statement(conn)
174 }
175}
176
/// Metadata of a prepared statement; shared by [`Statement`] handles through an `Arc`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct StmtInner {
    /// Raw query bytes this statement was created for (as passed to `from_payload`).
    pub(crate) raw_query: Arc<[u8]>,
    /// Column definitions; `None` when empty or not attached yet.
    columns: Option<Box<[Column]>>,
    /// Parameter definitions; `None` when empty or not attached yet.
    params: Option<Box<[Column]>>,
    /// Parsed statement packet; source of the statement id and the parameter/column counts.
    stmt_packet: StmtPacket,
    /// Connection id recorded when the statement was created (as passed to `from_payload`).
    connection_id: u32,
}
186
187impl StmtInner {
188 pub(crate) fn from_payload(
189 pld: &[u8],
190 connection_id: u32,
191 raw_query: Arc<[u8]>,
192 ) -> std::io::Result<Self> {
193 let stmt_packet = ParseBuf(pld).parse(())?;
194
195 Ok(Self {
196 raw_query,
197 columns: None,
198 params: None,
199 stmt_packet,
200 connection_id,
201 })
202 }
203
204 pub(crate) fn with_params(mut self, params: Vec<Column>) -> Self {
205 self.params = if params.is_empty() {
206 None
207 } else {
208 Some(params.into_boxed_slice())
209 };
210 self
211 }
212
213 pub(crate) fn with_columns(mut self, columns: Vec<Column>) -> Self {
214 self.columns = if columns.is_empty() {
215 None
216 } else {
217 Some(columns.into_boxed_slice())
218 };
219 self
220 }
221
222 pub(crate) fn columns(&self) -> &[Column] {
223 self.columns.as_ref().map(AsRef::as_ref).unwrap_or(&[])
224 }
225
226 pub(crate) fn params(&self) -> &[Column] {
227 self.params.as_ref().map(AsRef::as_ref).unwrap_or(&[])
228 }
229
230 pub(crate) fn id(&self) -> u32 {
231 self.stmt_packet.statement_id()
232 }
233
234 pub(crate) const fn connection_id(&self) -> u32 {
235 self.connection_id
236 }
237
238 pub(crate) fn num_params(&self) -> u16 {
239 self.stmt_packet.num_params()
240 }
241
242 pub(crate) fn num_columns(&self) -> u16 {
243 self.stmt_packet.num_columns()
244 }
245}
246
/// A prepared-statement handle: shared statement metadata plus the names of the
/// named parameters parsed from the original query text.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Statement {
    /// Shared metadata (ids, counts, column and parameter definitions).
    pub(crate) inner: Arc<StmtInner>,
    /// Named-parameter names collected from `ParsedNamedParams::params`, as raw bytes.
    pub(crate) named_params: Vec<Vec<u8>>,
}
256
257impl Statement {
258 pub(crate) fn new(inner: Arc<StmtInner>, named_params: Vec<Vec<u8>>) -> Self {
259 Self {
260 inner,
261 named_params,
262 }
263 }
264
265 pub fn columns(&self) -> &[Column] {
267 self.inner.columns()
268 }
269
270 pub fn params(&self) -> &[Column] {
272 self.inner.params()
273 }
274
275 pub fn id(&self) -> u32 {
277 self.inner.id()
278 }
279
280 pub fn connection_id(&self) -> u32 {
282 self.inner.connection_id()
283 }
284
285 pub fn num_params(&self) -> u16 {
287 self.inner.num_params()
288 }
289
290 pub fn num_columns(&self) -> u16 {
292 self.inner.num_columns()
293 }
294}
295
296impl crate::Conn {
297 pub(crate) async fn read_column_defs<U>(&mut self, num: U) -> Result<Vec<Column>>
301 where
302 U: Into<usize>,
303 {
304 let num = num.into();
305 debug_assert!(num > 0);
306 let packets = self.read_packets(num).await?;
307 let defs = packets
308 .into_iter()
309 .map(|x| ParseBuf(&x).parse(()))
310 .collect::<std::result::Result<Vec<Column>, _>>()
311 .map_err(Error::from)?;
312
313 if !self
314 .capabilities()
315 .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF)
316 {
317 self.read_packet().await?;
318 }
319
320 Ok(defs)
321 }
322
323 pub(crate) async fn get_statement<U>(&mut self, stmt_like: U) -> Result<Statement>
325 where
326 U: StatementLike,
327 {
328 match stmt_like.to_statement(self) {
329 ToStatementResult::Immediate(statement) => Ok(statement),
330 ToStatementResult::Mediate(statement) => statement.await,
331 }
332 }
333
334 async fn prepare_statement(&mut self, raw_query: Cow<'_, [u8]>) -> Result<Arc<StmtInner>> {
338 let inner_stmt = self.routine(PrepareRoutine::new(raw_query)).await?;
339
340 if let Some(old_stmt) = self.cache_stmt(&inner_stmt) {
341 self.close_statement(old_stmt.id()).await?;
342 }
343
344 Ok(inner_stmt)
345 }
346
347 pub(crate) async fn execute_statement<P>(
349 &mut self,
350 statement: &Statement,
351 params: P,
352 ) -> Result<()>
353 where
354 P: Into<Params>,
355 {
356 self.routine(ExecRoutine::new(statement, params.into()))
357 .await?;
358 Ok(())
359 }
360
361 pub(crate) async fn close_statement(&mut self, id: u32) -> Result<()> {
363 self.stmt_cache_mut().remove(id);
364 self.write_command(&ComStmtClose::new(id)).await
365 }
366}