1use futures_util::FutureExt;
10
11use mysql_common::{
12 constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8MB4_GENERAL_CI, UTF8_GENERAL_CI},
13 crypto,
14 io::ParseBuf,
15 packets::{
16 AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, HandshakePacket,
17 HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, OldEofPacket,
18 ResultSetTerminator, SslRequest,
19 },
20 proto::MySerialize,
21 row::Row,
22};
23
24use std::{
25 borrow::Cow,
26 fmt,
27 future::Future,
28 mem::{self, replace},
29 pin::Pin,
30 str::FromStr,
31 sync::Arc,
32 time::{Duration, Instant},
33};
34
35use crate::{
36 buffer_pool::PooledBuf,
37 conn::{pool::Pool, stmt_cache::StmtCache},
38 consts::{CapabilityFlags, Command, StatusFlags},
39 error::*,
40 io::Stream,
41 opts::Opts,
42 queryable::{
43 query_result::{QueryResult, ResultSetMeta},
44 transaction::TxStatus,
45 BinaryProtocol, Queryable, TextProtocol,
46 },
47 ChangeUserOpts, InfileData, OptsBuilder,
48};
49
50use self::routines::Routine;
51
52#[cfg(feature = "binlog")]
53pub mod binlog_stream;
54pub mod pool;
55pub mod routines;
56pub mod stmt_cache;
57
/// Fallback value (in seconds) used for the connection's wait timeout when
/// `@@wait_timeout` could not be loaded from the server: eight hours.
const DEFAULT_WAIT_TIMEOUT: usize = 8 * 60 * 60;
59
60fn disconnect(mut conn: Conn) {
62 let disconnected = conn.inner.disconnected;
63
64 conn.inner.disconnected = true;
66
67 if !disconnected {
68 if std::thread::panicking() {
70 return;
71 }
72
73 if let Ok(handle) = tokio::runtime::Handle::try_current() {
76 handle.spawn(async move {
77 if let Ok(conn) = conn.cleanup_for_pool().await {
78 let _ = conn.disconnect().await;
79 }
80 });
81 }
82 }
83}
84
/// Metadata of a result set that is attached to a connection.
#[derive(Debug, Clone)]
pub(crate) enum PendingResult {
    /// Metadata as stored by `Conn::set_pending_result`; not yet shared out.
    Pending(ResultSetMeta),
    /// Metadata that `Conn::take_pending_result` wrapped in an `Arc` and handed to a caller.
    Taken(Arc<ResultSetMeta>),
}
93
/// Heap-allocated state of a [`Conn`].
struct ConnInner {
    /// Underlying stream; `None` makes `stream_mut` fail with `ConnectionClosed`.
    stream: Option<Stream>,
    /// Connection id taken from the server handshake.
    id: u32,
    /// Set when the handshake carried a parsable MariaDB server version.
    is_mariadb: bool,
    /// Server version from the handshake, `(0, 0, 0)` if it could not be parsed.
    version: (u16, u16, u16),
    /// Socket path: taken from the options or loaded from `@@socket`.
    socket: Option<String>,
    /// Capabilities in effect (server capabilities masked by the client options).
    capabilities: CapabilityFlags,
    /// Status flags from the handshake or from the last OK packet.
    status: StatusFlags,
    /// Last OK packet; cleared when an error packet is handled.
    last_ok_packet: Option<OkPacket<'static>>,
    /// Last error packet; cleared when an OK packet is handled.
    last_err_packet: Option<mysql_common::packets::ServerError<'static>>,
    /// Set once the handshake response has been written.
    handshake_complete: bool,
    pool: Option<Pool>,
    /// `Ok(None)` – nothing pending, `Ok(Some(_))` – pending result set,
    /// `Err(_)` – stored server error that is yet to be reported.
    pending_result: std::result::Result<Option<PendingResult>, ServerError>,
    tx_status: TxStatus,
    /// Flag set through `Conn::reset_connection`.
    reset_upon_returning_to_a_pool: bool,
    opts: Opts,
    /// Absolute deadline obtained from the pool options; see `Conn::expired`.
    ttl_deadline: Option<Instant>,
    /// Updated by `Conn::touch`.
    last_io: Instant,
    /// Configured or server-reported wait timeout; zero until settings are read.
    wait_timeout: Duration,
    stmt_cache: StmtCache,
    /// Scramble received in the handshake or in an auth switch request.
    nonce: Vec<u8>,
    auth_plugin: AuthPlugin<'static>,
    /// Set when an auth switch has been performed.
    auth_switched: bool,
    /// Key received from the server during `caching_sha2_password` authentication.
    server_key: Option<Vec<u8>>,
    active_since: Instant,
    /// `true` while the connection must not be treated as usable
    /// (explicitly disconnected, broken by an I/O error, or inside a routine).
    pub(crate) disconnected: bool,
    /// Handler installed by `Conn::set_infile_handler`; cleared on reset / change user.
    infile_handler:
        Option<Pin<Box<dyn Future<Output = crate::Result<InfileData>> + Send + Sync + 'static>>>,
}
126
127impl fmt::Debug for ConnInner {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 f.debug_struct("Conn")
130 .field("connection id", &self.id)
131 .field("server version", &self.version)
132 .field("pool", &self.pool)
133 .field("pending_result", &self.pending_result)
134 .field("tx_status", &self.tx_status)
135 .field("stream", &self.stream)
136 .field("options", &self.opts)
137 .field("server_key", &self.server_key)
138 .field("auth_plugin", &self.auth_plugin)
139 .finish()
140 }
141}
142
143impl ConnInner {
144 fn empty(opts: Opts) -> ConnInner {
146 let ttl_deadline = opts.pool_opts().new_connection_ttl_deadline();
147 ConnInner {
148 capabilities: opts.get_capabilities(),
149 status: StatusFlags::empty(),
150 last_ok_packet: None,
151 last_err_packet: None,
152 handshake_complete: false,
153 stream: None,
154 is_mariadb: false,
155 version: (0, 0, 0),
156 id: 0,
157 pending_result: Ok(None),
158 pool: None,
159 tx_status: TxStatus::None,
160 last_io: Instant::now(),
161 wait_timeout: Duration::from_secs(0),
162 stmt_cache: StmtCache::new(opts.stmt_cache_size()),
163 socket: opts.socket().map(Into::into),
164 opts,
165 ttl_deadline,
166 nonce: Vec::default(),
167 auth_plugin: AuthPlugin::MysqlNativePassword,
168 auth_switched: false,
169 disconnected: false,
170 server_key: None,
171 infile_handler: None,
172 reset_upon_returning_to_a_pool: false,
173 active_since: Instant::now(),
174 }
175 }
176
177 fn stream_mut(&mut self) -> Result<&mut Stream> {
181 self.stream
182 .as_mut()
183 .ok_or_else(|| DriverError::ConnectionClosed.into())
184 }
185}
186
/// A connection handle. All state lives in a boxed [`ConnInner`].
#[derive(Debug)]
pub struct Conn {
    /// Boxed connection state.
    inner: Box<ConnInner>,
}
192
193impl Conn {
194 pub fn id(&self) -> u32 {
196 self.inner.id
197 }
198
199 pub fn last_insert_id(&self) -> Option<u64> {
203 self.inner
204 .last_ok_packet
205 .as_ref()
206 .and_then(|ok| ok.last_insert_id())
207 }
208
209 pub fn affected_rows(&self) -> u64 {
212 self.inner
213 .last_ok_packet
214 .as_ref()
215 .map(|ok| ok.affected_rows())
216 .unwrap_or_default()
217 }
218
219 pub fn info(&self) -> Cow<'_, str> {
221 self.inner
222 .last_ok_packet
223 .as_ref()
224 .and_then(|ok| ok.info_str())
225 .unwrap_or_else(|| "".into())
226 }
227
228 pub fn get_warnings(&self) -> u16 {
230 self.inner
231 .last_ok_packet
232 .as_ref()
233 .map(|ok| ok.warnings())
234 .unwrap_or_default()
235 }
236
237 pub fn last_ok_packet(&self) -> Option<&OkPacket<'static>> {
239 self.inner.last_ok_packet.as_ref()
240 }
241
242 pub fn reset_connection(&mut self, reset_connection: bool) {
246 self.inner.reset_upon_returning_to_a_pool = reset_connection;
247 }
248
249 pub(crate) fn stream_mut(&mut self) -> Result<&mut Stream> {
250 self.inner.stream_mut()
251 }
252
253 pub(crate) fn capabilities(&self) -> CapabilityFlags {
254 self.inner.capabilities
255 }
256
257 pub(crate) fn touch(&mut self) {
259 self.inner.last_io = Instant::now();
260 }
261
262 pub(crate) fn reset_seq_id(&mut self) {
264 if let Some(stream) = self.inner.stream.as_mut() {
265 stream.reset_seq_id();
266 }
267 }
268
269 pub(crate) fn sync_seq_id(&mut self) {
271 if let Some(stream) = self.inner.stream.as_mut() {
272 stream.sync_seq_id();
273 }
274 }
275
276 pub(crate) fn handle_ok(&mut self, ok_packet: OkPacket<'static>) {
278 self.inner.status = ok_packet.status_flags();
279 self.inner.last_err_packet = None;
280 self.inner.last_ok_packet = Some(ok_packet);
281 }
282
283 pub(crate) fn handle_err(&mut self, err_packet: ErrPacket<'_>) -> Result<()> {
285 match err_packet {
286 ErrPacket::Error(err) => {
287 self.inner.status = StatusFlags::empty();
288 self.inner.last_ok_packet = None;
289 self.inner.last_err_packet = Some(err.clone().into_owned());
290 Err(Error::from(err))
291 }
292 ErrPacket::Progress(_) => Ok(()),
293 }
294 }
295
296 pub(crate) fn get_tx_status(&self) -> TxStatus {
298 self.inner.tx_status
299 }
300
301 pub(crate) fn set_tx_status(&mut self, tx_status: TxStatus) {
303 self.inner.tx_status = tx_status;
304 }
305
306 pub(crate) fn use_pending_result(
310 &mut self,
311 ) -> std::result::Result<Option<&PendingResult>, ServerError> {
312 if let Err(ref e) = self.inner.pending_result {
313 let e = e.clone();
314 self.inner.pending_result = Ok(None);
315 Err(e)
316 } else {
317 Ok(self.inner.pending_result.as_ref().unwrap().as_ref())
318 }
319 }
320
321 pub(crate) fn get_pending_result(
322 &self,
323 ) -> std::result::Result<Option<&PendingResult>, &ServerError> {
324 self.inner.pending_result.as_ref().map(|x| x.as_ref())
325 }
326
327 pub(crate) fn has_pending_result(&self) -> bool {
328 self.inner.pending_result.is_err() || matches!(self.inner.pending_result, Ok(Some(_)))
329 }
330
331 pub(crate) fn set_pending_result(
333 &mut self,
334 meta: Option<ResultSetMeta>,
335 ) -> std::result::Result<Option<PendingResult>, ServerError> {
336 replace(
337 &mut self.inner.pending_result,
338 Ok(meta.map(PendingResult::Pending)),
339 )
340 }
341
342 pub(crate) fn set_pending_result_error(
343 &mut self,
344 error: ServerError,
345 ) -> std::result::Result<Option<PendingResult>, ServerError> {
346 replace(&mut self.inner.pending_result, Err(error))
347 }
348
349 pub(crate) fn take_pending_result(
351 &mut self,
352 ) -> std::result::Result<Option<Arc<ResultSetMeta>>, ServerError> {
353 let mut output = None;
354
355 self.inner.pending_result = match replace(&mut self.inner.pending_result, Ok(None))? {
356 Some(PendingResult::Pending(x)) => {
357 let meta = Arc::new(x);
358 output = Some(meta.clone());
359 Ok(Some(PendingResult::Taken(meta)))
360 }
361 x => Ok(x),
362 };
363
364 Ok(output)
365 }
366
367 pub(crate) fn status(&self) -> StatusFlags {
369 self.inner.status
370 }
371
372 pub(crate) async fn routine<'a, F, T>(&mut self, mut f: F) -> crate::Result<T>
373 where
374 F: Routine<T> + 'a,
375 {
376 self.inner.disconnected = true;
377 let result = f.call(&mut *self).await;
378 match result {
379 result @ Ok(_) | result @ Err(crate::Error::Server(_)) => {
380 self.inner.disconnected = false;
382 result
383 }
384 Err(err) => {
385 if self.inner.stream.is_some() {
386 self.take_stream().close().await?;
387 }
388 Err(err)
389 }
390 }
391 }
392
393 pub fn server_version(&self) -> (u16, u16, u16) {
395 self.inner.version
396 }
397
398 pub fn opts(&self) -> &Opts {
400 &self.inner.opts
401 }
402
403 pub fn set_infile_handler<T>(&mut self, handler: T)
410 where
411 T: Future<Output = crate::Result<InfileData>>,
412 T: Send + Sync + 'static,
413 {
414 self.inner.infile_handler = Some(Box::pin(handler));
415 }
416
417 fn take_stream(&mut self) -> Stream {
418 self.inner.stream.take().unwrap()
419 }
420
421 pub async fn disconnect(mut self) -> Result<()> {
423 if !self.inner.disconnected {
424 self.inner.disconnected = true;
425 self.write_command_data(Command::COM_QUIT, &[]).await?;
426 let stream = self.take_stream();
427 stream.close().await?;
428 }
429 Ok(())
430 }
431
432 async fn close_conn(mut self) -> Result<()> {
434 self = self.cleanup_for_pool().await?;
435 self.disconnect().await
436 }
437
438 fn is_secure(&self) -> bool {
440 #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
441 {
442 self.inner
443 .stream
444 .as_ref()
445 .map(|x| x.is_secure())
446 .unwrap_or_default()
447 }
448
449 #[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))]
450 false
451 }
452
453 fn is_socket(&self) -> bool {
455 #[cfg(unix)]
456 {
457 self.inner
458 .stream
459 .as_ref()
460 .map(|x| x.is_socket())
461 .unwrap_or_default()
462 }
463
464 #[cfg(not(unix))]
465 false
466 }
467
468 fn take(&mut self) -> Conn {
470 mem::replace(self, Conn::empty(Default::default()))
471 }
472
473 fn empty(opts: Opts) -> Self {
474 Self {
475 inner: Box::new(ConnInner::empty(opts)),
476 }
477 }
478
479 fn setup_stream(&mut self) -> Result<()> {
483 debug_assert!(self.inner.stream.is_some());
484 if let Some(stream) = self.inner.stream.as_mut() {
485 stream.set_tcp_nodelay(self.inner.opts.tcp_nodelay())?;
486 }
487 Ok(())
488 }
489
490 async fn handle_handshake(&mut self) -> Result<()> {
491 let packet = self.read_packet().await?;
492 let handshake = ParseBuf(&packet).parse::<HandshakePacket>(())?;
493
494 self.inner.nonce = {
496 let mut nonce = Vec::from(handshake.scramble_1_ref());
497 nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
498 nonce.resize(20, 0);
501 nonce
502 };
503
504 self.inner.capabilities = handshake.capabilities() & self.inner.opts.get_capabilities();
505 self.inner.version = handshake
506 .maria_db_server_version_parsed()
507 .inspect(|_| self.inner.is_mariadb = true)
508 .or_else(|| handshake.server_version_parsed())
509 .unwrap_or((0, 0, 0));
510 self.inner.id = handshake.connection_id();
511 self.inner.status = handshake.status_flags();
512
513 self.inner.auth_plugin = match handshake.auth_plugin() {
517 Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password,
518 _ => AuthPlugin::MysqlNativePassword,
519 };
520
521 Ok(())
522 }
523
524 async fn switch_to_ssl_if_needed(&mut self) -> Result<()> {
525 if self
526 .inner
527 .opts
528 .get_capabilities()
529 .contains(CapabilityFlags::CLIENT_SSL)
530 {
531 if !self
532 .inner
533 .capabilities
534 .contains(CapabilityFlags::CLIENT_SSL)
535 {
536 return Err(DriverError::NoClientSslFlagFromServer.into());
537 }
538
539 let collation = if self.inner.version >= (5, 5, 3) {
540 UTF8MB4_GENERAL_CI
541 } else {
542 UTF8_GENERAL_CI
543 };
544
545 let ssl_request = SslRequest::new(
546 self.inner.capabilities,
547 DEFAULT_MAX_ALLOWED_PACKET as u32,
548 collation as u8,
549 );
550 self.write_struct(&ssl_request).await?;
551 let conn = self;
552 let ssl_opts = conn.opts().ssl_opts_and_connector().expect("unreachable");
553 let domain = ssl_opts
554 .ssl_opts()
555 .tls_hostname_override()
556 .unwrap_or_else(|| conn.opts().ip_or_hostname())
557 .into();
558 let tls_connector = ssl_opts.build_tls_connector().await?;
559 conn.stream_mut()?
560 .make_secure(domain, &tls_connector)
561 .await?;
562 Ok(())
563 } else {
564 Ok(())
565 }
566 }
567
568 async fn do_handshake_response(&mut self) -> Result<()> {
569 let auth_data = self
570 .inner
571 .auth_plugin
572 .gen_data(self.inner.opts.pass(), &self.inner.nonce);
573
574 let handshake_response = HandshakeResponse::new(
575 auth_data.as_deref(),
576 self.inner.version,
577 self.inner.opts.user().map(|x| x.as_bytes()),
578 self.inner.opts.db_name().map(|x| x.as_bytes()),
579 Some(self.inner.auth_plugin.borrow()),
580 self.capabilities(),
581 Default::default(), self.inner
583 .opts
584 .max_allowed_packet()
585 .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET) as u32,
586 );
587
588 let mut buf = crate::buffer_pool().get();
590 handshake_response.serialize(buf.as_mut());
591
592 self.write_packet(buf).await?;
593 self.inner.handshake_complete = true;
594 Ok(())
595 }
596
597 async fn perform_auth_switch(
598 &mut self,
599 auth_switch_request: AuthSwitchRequest<'_>,
600 ) -> Result<()> {
601 if !self.inner.auth_switched {
602 self.inner.auth_switched = true;
603 self.inner.nonce = auth_switch_request.plugin_data().to_vec();
604
605 if matches!(
606 auth_switch_request.auth_plugin(),
607 AuthPlugin::MysqlOldPassword
608 ) && self.inner.opts.secure_auth()
609 {
610 return Err(DriverError::MysqlOldPasswordDisabled.into());
611 }
612
613 self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned();
614
615 let plugin_data = match &self.inner.auth_plugin {
616 x @ AuthPlugin::CachingSha2Password => {
617 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
618 }
619 x @ AuthPlugin::MysqlNativePassword => {
620 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
621 }
622 x @ AuthPlugin::MysqlOldPassword => {
623 if self.inner.opts.secure_auth() {
624 return Err(DriverError::MysqlOldPasswordDisabled.into());
625 } else {
626 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
627 }
628 }
629 x @ AuthPlugin::MysqlClearPassword => {
630 if self.inner.opts.enable_cleartext_plugin() {
631 x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
632 } else {
633 return Err(DriverError::CleartextPluginDisabled.into());
634 }
635 }
636 x @ AuthPlugin::Ed25519 => x.gen_data(self.inner.opts.pass(), &self.inner.nonce),
637 x @ AuthPlugin::Other(_) => x.gen_data(self.inner.opts.pass(), &self.inner.nonce),
638 };
639
640 if let Some(plugin_data) = plugin_data {
641 self.write_struct(&plugin_data.into_owned()).await?;
642 } else {
643 self.write_packet(crate::buffer_pool().get()).await?;
644 }
645
646 self.continue_auth().await?;
647
648 Ok(())
649 } else {
650 unreachable!("auth_switched flag should be checked by caller")
651 }
652 }
653
654 fn continue_auth(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
655 Box::pin(async move {
658 match self.inner.auth_plugin {
659 AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
660 self.continue_mysql_native_password_auth().await?;
661 Ok(())
662 }
663 AuthPlugin::CachingSha2Password => {
664 self.continue_caching_sha2_password_auth().await?;
665 Ok(())
666 }
667 AuthPlugin::MysqlClearPassword => {
668 if self.inner.opts.enable_cleartext_plugin() {
669 self.continue_mysql_native_password_auth().await?;
670 Ok(())
671 } else {
672 Err(DriverError::CleartextPluginDisabled.into())
673 }
674 }
675 AuthPlugin::Ed25519 => {
676 self.continue_ed25519_auth().await?;
677 Ok(())
678 }
679 AuthPlugin::Other(ref name) => Err(DriverError::UnknownAuthPlugin {
680 name: String::from_utf8_lossy(name.as_ref()).to_string(),
681 }
682 .into()),
683 }
684 })
685 }
686
687 fn switch_to_compression(&mut self) -> Result<()> {
688 if self
689 .capabilities()
690 .contains(CapabilityFlags::CLIENT_COMPRESS)
691 {
692 if let Some(compression) = self.inner.opts.compression() {
693 if let Some(stream) = self.inner.stream.as_mut() {
694 stream.compress(compression);
695 }
696 }
697 }
698 Ok(())
699 }
700
701 async fn continue_ed25519_auth(&mut self) -> Result<()> {
702 let packet = self.read_packet().await?;
703 match packet.first() {
704 Some(0x00) => {
705 Ok(())
707 }
708 Some(0xfe) if !self.inner.auth_switched => {
709 let auth_switch_request = ParseBuf(&packet).parse::<AuthSwitchRequest>(())?;
710 self.perform_auth_switch(auth_switch_request).await
711 }
712 _ => Err(DriverError::UnexpectedPacket {
713 payload: packet.to_vec(),
714 }
715 .into()),
716 }
717 }
718
719 async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> {
720 let packet = self.read_packet().await?;
721 match packet.first() {
722 Some(0x00) => {
723 Ok(())
725 }
726 Some(0x01) => match packet.get(1) {
727 Some(0x03) => {
728 self.drop_packet().await
730 }
731 Some(0x04) => {
732 let pass = self.inner.opts.pass().unwrap_or_default();
733 let mut pass = crate::buffer_pool().get_with(pass.as_bytes());
734 pass.as_mut().push(0);
735
736 if self.is_secure() || self.is_socket() {
737 self.write_packet(pass).await?;
738 } else {
739 if self.inner.server_key.is_none() {
740 self.write_bytes(&[0x02][..]).await?;
741 let packet = self.read_packet().await?;
742 self.inner.server_key = Some(packet[1..].to_vec());
743 }
744 for (i, byte) in pass.as_mut().iter_mut().enumerate() {
745 *byte ^= self.inner.nonce[i % self.inner.nonce.len()];
746 }
747 let encrypted_pass = crypto::encrypt(
748 &pass,
749 self.inner.server_key.as_deref().expect("unreachable"),
750 );
751 self.write_bytes(&encrypted_pass).await?;
752 };
753 self.drop_packet().await?;
754 Ok(())
755 }
756 _ => Err(DriverError::UnexpectedPacket {
757 payload: packet.to_vec(),
758 }
759 .into()),
760 },
761 Some(0xfe) if !self.inner.auth_switched => {
762 let auth_switch_request = ParseBuf(&packet).parse::<AuthSwitchRequest>(())?;
763 self.perform_auth_switch(auth_switch_request).await?;
764 Ok(())
765 }
766 _ => Err(DriverError::UnexpectedPacket {
767 payload: packet.to_vec(),
768 }
769 .into()),
770 }
771 }
772
773 async fn continue_mysql_native_password_auth(&mut self) -> Result<()> {
774 let packet = self.read_packet().await?;
775 match packet.first() {
776 Some(0x00) => Ok(()),
777 Some(0xfe) if !self.inner.auth_switched => {
778 let auth_switch = if packet.len() > 1 {
779 ParseBuf(&packet).parse(())?
780 } else {
781 let _ = ParseBuf(&packet).parse::<OldAuthSwitchRequest>(())?;
782 AuthSwitchRequest::new(
784 "mysql_old_password".as_bytes(),
785 self.inner.nonce.clone(),
786 )
787 };
788 self.perform_auth_switch(auth_switch).await
789 }
790 _ => Err(DriverError::UnexpectedPacket {
791 payload: packet.to_vec(),
792 }
793 .into()),
794 }
795 }
796
797 fn handle_packet(&mut self, packet: &PooledBuf) -> Result<bool> {
799 let ok_packet = if self.has_pending_result() {
800 if self
801 .capabilities()
802 .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF)
803 {
804 ParseBuf(packet)
805 .parse::<OkPacketDeserializer<ResultSetTerminator>>(self.capabilities())
806 .map(|x| x.into_inner())
807 } else {
808 ParseBuf(packet)
809 .parse::<OkPacketDeserializer<OldEofPacket>>(self.capabilities())
810 .map(|x| x.into_inner())
811 }
812 } else {
813 ParseBuf(packet)
814 .parse::<OkPacketDeserializer<CommonOkPacket>>(self.capabilities())
815 .map(|x| x.into_inner())
816 };
817
818 if let Ok(ok_packet) = ok_packet {
819 self.handle_ok(ok_packet.into_owned());
820 } else {
821 let capabilities = if self.inner.handshake_complete {
829 self.capabilities()
830 } else {
831 CapabilityFlags::empty()
832 };
833 let err_packet = ParseBuf(packet).parse::<ErrPacket>(capabilities);
834 if let Ok(err_packet) = err_packet {
835 self.handle_err(err_packet)?;
836 return Ok(true);
837 }
838 }
839
840 Ok(false)
841 }
842
843 pub(crate) async fn read_packet(&mut self) -> Result<PooledBuf> {
844 loop {
845 let packet = crate::io::ReadPacket::new(&mut *self)
846 .await
847 .map_err(|io_err| {
848 self.inner.stream.take();
849 self.inner.disconnected = true;
850 Error::from(io_err)
851 })?;
852 if self.handle_packet(&packet)? {
853 continue;
855 } else {
856 return Ok(packet);
857 }
858 }
859 }
860
861 pub(crate) async fn read_packets(&mut self, n: usize) -> Result<Vec<PooledBuf>> {
863 let mut packets = Vec::with_capacity(n);
864 for _ in 0..n {
865 packets.push(self.read_packet().await?);
866 }
867 Ok(packets)
868 }
869
870 pub(crate) async fn write_packet(&mut self, data: PooledBuf) -> Result<()> {
871 crate::io::WritePacket::new(&mut *self, data)
872 .await
873 .map_err(|io_err| {
874 self.inner.stream.take();
875 self.inner.disconnected = true;
876 From::from(io_err)
877 })
878 }
879
880 pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
882 let buf = crate::buffer_pool().get_with(bytes);
883 self.write_packet(buf).await
884 }
885
886 pub(crate) async fn write_struct<T: MySerialize>(&mut self, x: &T) -> Result<()> {
888 let mut buf = crate::buffer_pool().get();
889 x.serialize(buf.as_mut());
890 self.write_packet(buf).await
891 }
892
893 pub(crate) async fn write_command<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
895 self.clean_dirty().await?;
896 self.reset_seq_id();
897 self.write_struct(cmd).await
898 }
899
900 pub(crate) async fn write_command_raw(&mut self, body: PooledBuf) -> Result<()> {
902 debug_assert!(!body.is_empty());
903 self.clean_dirty().await?;
904 self.reset_seq_id();
905 self.write_packet(body).await
906 }
907
908 pub(crate) async fn write_command_data<T>(&mut self, cmd: Command, cmd_data: T) -> Result<()>
910 where
911 T: AsRef<[u8]>,
912 {
913 let cmd_data = cmd_data.as_ref();
914 let mut buf = crate::buffer_pool().get();
915 let body = buf.as_mut();
916 body.push(cmd as u8);
917 body.extend_from_slice(cmd_data);
918 self.write_command_raw(buf).await
919 }
920
921 async fn drop_packet(&mut self) -> Result<()> {
922 self.read_packet().await?;
923 Ok(())
924 }
925
926 async fn run_init_commands(&mut self) -> Result<()> {
927 let mut init = self.inner.opts.init().to_vec();
928
929 while let Some(query) = init.pop() {
930 self.query_drop(query).await?;
931 }
932
933 Ok(())
934 }
935
936 async fn run_setup_commands(&mut self) -> Result<()> {
937 let mut setup = self.inner.opts.setup().to_vec();
938
939 while let Some(query) = setup.pop() {
940 self.query_drop(query).await?;
941 }
942
943 Ok(())
944 }
945
946 pub fn new<T: Into<Opts>>(opts: T) -> crate::BoxFuture<'static, Conn> {
948 let opts = opts.into();
949 async move {
950 let mut conn = Conn::empty(opts.clone());
951
952 let stream = if let Some(_path) = opts.socket() {
953 #[cfg(unix)]
954 {
955 Stream::connect_socket(_path.to_owned()).await?
956 }
957 #[cfg(not(unix))]
958 return Err(crate::DriverError::NamedPipesDisabled.into());
959 } else {
960 let keepalive = opts
961 .tcp_keepalive()
962 .map(|x| std::time::Duration::from_millis(x.into()));
963 Stream::connect_tcp(opts.hostport_or_url(), keepalive).await?
964 };
965
966 conn.inner.stream = Some(stream);
967 conn.setup_stream()?;
968 conn.handle_handshake().await?;
969 conn.switch_to_ssl_if_needed().await?;
970 conn.do_handshake_response().await?;
971 conn.continue_auth().await?;
972 conn.switch_to_compression()?;
973 conn.read_settings().await?;
974 conn.reconnect_via_socket_if_needed().await?;
975 conn.run_init_commands().await?;
976 conn.run_setup_commands().await?;
977
978 Ok(conn)
979 }
980 .boxed()
981 }
982
983 pub async fn from_url<T: AsRef<str>>(url: T) -> Result<Conn> {
985 Conn::new(Opts::from_str(url.as_ref())?).await
986 }
987
988 async fn reconnect_via_socket_if_needed(&mut self) -> Result<()> {
992 if let Some(socket) = self.inner.socket.as_ref() {
993 let opts = self.inner.opts.clone();
994 if opts.socket().is_none() {
995 let opts = OptsBuilder::from_opts(opts).socket(Some(&**socket));
996 if let Ok(conn) = Conn::new(opts).await {
997 let old_conn = std::mem::replace(self, conn);
998 old_conn.close_conn().await?;
1000 }
1001 }
1002 }
1003 Ok(())
1004 }
1005
1006 async fn read_settings(&mut self) -> Result<()> {
1016 enum Action {
1017 Load(Cfg),
1018 Apply(CfgData),
1019 }
1020
1021 enum CfgData {
1022 MaxAllowedPacket(usize),
1023 WaitTimeout(usize),
1024 }
1025
1026 impl CfgData {
1027 fn apply(&self, conn: &mut Conn) {
1028 match self {
1029 Self::MaxAllowedPacket(value) => {
1030 if let Some(stream) = conn.inner.stream.as_mut() {
1031 stream.set_max_allowed_packet(*value);
1032 }
1033 }
1034 Self::WaitTimeout(value) => {
1035 conn.inner.wait_timeout = Duration::from_secs(*value as u64);
1036 }
1037 }
1038 }
1039 }
1040
1041 enum Cfg {
1042 Socket,
1043 MaxAllowedPacket,
1044 WaitTimeout,
1045 }
1046
1047 impl Cfg {
1048 const fn name(&self) -> &'static str {
1049 match self {
1050 Self::Socket => "@@socket",
1051 Self::MaxAllowedPacket => "@@max_allowed_packet",
1052 Self::WaitTimeout => "@@wait_timeout",
1053 }
1054 }
1055
1056 fn apply(&self, conn: &mut Conn, value: Option<crate::Value>) {
1057 match self {
1058 Cfg::Socket => {
1059 conn.inner.socket = value.and_then(crate::from_value);
1060 }
1061 Cfg::MaxAllowedPacket => {
1062 if let Some(stream) = conn.inner.stream.as_mut() {
1063 stream.set_max_allowed_packet(
1064 value
1065 .and_then(crate::from_value)
1066 .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET),
1067 );
1068 }
1069 }
1070 Cfg::WaitTimeout => {
1071 conn.inner.wait_timeout = Duration::from_secs(
1072 value
1073 .and_then(crate::from_value)
1074 .unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64,
1075 );
1076 }
1077 }
1078 }
1079 }
1080
1081 let mut actions = vec![
1082 if let Some(x) = self.opts().max_allowed_packet() {
1083 Action::Apply(CfgData::MaxAllowedPacket(x))
1084 } else {
1085 Action::Load(Cfg::MaxAllowedPacket)
1086 },
1087 if let Some(x) = self.opts().wait_timeout() {
1088 Action::Apply(CfgData::WaitTimeout(x))
1089 } else {
1090 Action::Load(Cfg::WaitTimeout)
1091 },
1092 ];
1093
1094 if self.inner.opts.prefer_socket() && self.inner.socket.is_none() {
1095 actions.push(Action::Load(Cfg::Socket))
1096 }
1097
1098 let loads = actions
1099 .iter()
1100 .filter_map(|x| match x {
1101 Action::Load(x) => Some(x),
1102 Action::Apply(_) => None,
1103 })
1104 .collect::<Vec<_>>();
1105
1106 let loaded = if !loads.is_empty() {
1107 let query = loads
1108 .iter()
1109 .zip(std::iter::once(' ').chain(std::iter::repeat(',')))
1110 .fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| {
1111 acc.push(prefix);
1112 acc.push_str(cfg.name());
1113 acc
1114 });
1115
1116 self.query_internal::<Row, String>(query)
1117 .await?
1118 .map(|row| row.unwrap())
1119 .unwrap_or_else(|| vec![crate::Value::NULL; loads.len()])
1120 } else {
1121 vec![]
1122 };
1123 let mut loaded = loaded.into_iter();
1124
1125 for action in actions {
1126 match action {
1127 Action::Load(cfg) => cfg.apply(self, loaded.next()),
1128 Action::Apply(cfg) => cfg.apply(self),
1129 }
1130 }
1131
1132 Ok(())
1133 }
1134
1135 fn expired(&self) -> bool {
1138 if let Some(deadline) = self.inner.ttl_deadline {
1139 if Instant::now() > deadline {
1140 return true;
1141 }
1142 }
1143 let ttl = self
1144 .inner
1145 .opts
1146 .conn_ttl()
1147 .unwrap_or(self.inner.wait_timeout);
1148 !ttl.is_zero() && self.idling() > ttl
1149 }
1150
1151 fn idling(&self) -> Duration {
1153 self.inner.last_io.elapsed()
1154 }
1155
1156 pub async fn reset(&mut self) -> Result<bool> {
1163 let supports_com_reset_connection = if self.inner.is_mariadb {
1164 self.inner.version >= (10, 2, 4)
1165 } else {
1166 self.inner.version > (5, 7, 2)
1168 };
1169
1170 if supports_com_reset_connection {
1171 self.routine(routines::ResetRoutine).await?;
1172 self.inner.stmt_cache.clear();
1173 self.inner.infile_handler = None;
1174 self.run_setup_commands().await?;
1175 }
1176
1177 Ok(supports_com_reset_connection)
1178 }
1179
1180 pub async fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
1192 if opts != ChangeUserOpts::default() {
1194 let mut opts_changed = false;
1195 if let Some(user) = opts.user() {
1196 opts_changed |= user != self.opts().user()
1197 };
1198 if let Some(pass) = opts.pass() {
1199 opts_changed |= pass != self.opts().pass()
1200 };
1201 if let Some(db_name) = opts.db_name() {
1202 opts_changed |= db_name != self.opts().db_name()
1203 };
1204 if opts_changed {
1205 if let Some(pool) = self.inner.pool.take() {
1206 pool.cancel_connection();
1207 }
1208 }
1209 }
1210
1211 let conn_opts = &mut self.inner.opts;
1212 opts.update_opts(conn_opts);
1213 self.routine(routines::ChangeUser).await?;
1214 self.inner.stmt_cache.clear();
1215 self.inner.infile_handler = None;
1216 self.run_setup_commands().await?;
1217 Ok(())
1218 }
1219
1220 async fn reset_for_pool(mut self) -> Result<Self> {
1224 if !self.reset().await? {
1225 self.change_user(Default::default()).await?;
1226 }
1227 Ok(self)
1228 }
1229
1230 pub(crate) async fn rollback_transaction(&mut self) -> Result<()> {
1232 debug_assert_ne!(self.inner.tx_status, TxStatus::None);
1233 self.inner.tx_status = TxStatus::None;
1234 self.query_drop("ROLLBACK").await
1235 }
1236
1237 pub(crate) fn more_results_exists(&self) -> bool {
1240 self.status()
1241 .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
1242 }
1243
1244 pub(crate) async fn drop_result(&mut self) -> Result<()> {
1249 let meta = match self.set_pending_result(None)? {
1251 Some(PendingResult::Pending(meta)) => Some(meta),
1252 Some(PendingResult::Taken(meta)) => {
1253 Some(Arc::try_unwrap(meta).expect("Conn::drop_result call on a pending result that may still be droped by someone else"))
1256 }
1257 None => None,
1258 };
1259
1260 let _ = self.set_pending_result(meta);
1261
1262 match self.use_pending_result() {
1263 Ok(Some(PendingResult::Pending(ResultSetMeta::Text(_)))) => {
1264 QueryResult::<'_, '_, TextProtocol>::new(self)
1265 .drop_result()
1266 .await
1267 }
1268 Ok(Some(PendingResult::Pending(ResultSetMeta::Binary(_)))) => {
1269 QueryResult::<'_, '_, BinaryProtocol>::new(self)
1270 .drop_result()
1271 .await
1272 }
1273 Ok(None) => Ok(()),
1274 Ok(Some(PendingResult::Taken(_))) | Err(_) => {
1275 unreachable!("this case must be handled earlier in this function")
1276 }
1277 }
1278 }
1279
1280 async fn cleanup_for_pool(mut self) -> Result<Self> {
1284 loop {
1285 let result = if self.has_pending_result() {
1286 self.drop_result().await
1287 } else if self.inner.tx_status != TxStatus::None {
1288 self.rollback_transaction().await
1289 } else {
1290 break;
1291 };
1292
1293 if let Err(err) = result {
1297 if err.is_fatal() {
1298 return Err(err);
1301 }
1302 }
1303 }
1304 Ok(self)
1305 }
1306}
1307
1308#[cfg(test)]
1309mod test {
1310 use bytes::Bytes;
1311 use futures_util::stream::{self, StreamExt};
1312 use mysql_common::constants::MAX_PAYLOAD_LEN;
1313 use rand::Rng;
1314 use tokio::{io::AsyncWriteExt, net::TcpListener};
1315
1316 use crate::{
1317 from_row, params, prelude::*, test_misc::get_opts, ChangeUserOpts, Conn, Error,
1318 OptsBuilder, Pool, ServerError, Value, WhiteListFsHandler,
1319 };
1320
1321 #[tokio::test]
1322 async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> {
1323 let opts = get_opts().client_found_rows(true);
1324 let mut conn = Conn::new(opts).await.unwrap();
1325
1326 "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
1327 .ignore(&mut conn)
1328 .await?;
1329
1330 "INSERT INTO mysql.found_rows (val) VALUES (1)"
1331 .ignore(&mut conn)
1332 .await?;
1333
1334 assert_eq!(conn.affected_rows(), 1);
1336
1337 "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
1338 .ignore(&mut conn)
1339 .await?;
1340
1341 assert_eq!(conn.affected_rows(), 1);
1344
1345 Ok(())
1346 }
1347
1348 #[tokio::test]
1349 async fn should_not_return_found_rows_if_flag_is_not_set() -> super::Result<()> {
1350 let mut conn = Conn::new(get_opts()).await.unwrap();
1351
1352 "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
1353 .ignore(&mut conn)
1354 .await?;
1355
1356 "INSERT INTO mysql.found_rows (val) VALUES (1)"
1357 .ignore(&mut conn)
1358 .await?;
1359
1360 assert_eq!(conn.affected_rows(), 1);
1362
1363 "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
1364 .ignore(&mut conn)
1365 .await?;
1366
1367 assert_eq!(conn.affected_rows(), 0);
1369
1370 Ok(())
1371 }
1372
/// Compile-time check: the value returned by `get_opts()` must be
/// `Send + Sync`, otherwise this test does not build.
#[test]
fn opts_should_satisfy_send_and_sync() {
    fn assert_send_and_sync<T: Sync + Send>(_value: T) {}

    assert_send_and_sync(get_opts());
}
1379
1380 #[tokio::test]
1381 async fn should_connect_without_database() -> super::Result<()> {
1382 let mut conn: Conn = Conn::new(get_opts().db_name(None::<String>)).await?;
1384 conn.ping().await?;
1385 conn.disconnect().await?;
1386
1387 let mut conn: Conn = Conn::new(get_opts().db_name(Some(""))).await?;
1389 conn.ping().await?;
1390 conn.disconnect().await?;
1391
1392 Ok(())
1393 }
1394
1395 #[tokio::test]
1396 async fn should_clean_state_if_wrapper_is_dropeed() -> super::Result<()> {
1397 let mut conn: Conn = Conn::new(get_opts()).await?;
1398
1399 conn.query_drop("CREATE TEMPORARY TABLE mysql.foo (id SERIAL)")
1400 .await?;
1401
1402 conn.query_iter("SELECT 1").await?;
1404 conn.ping().await?;
1405
1406 let mut tx = conn.start_transaction(Default::default()).await?;
1408 tx.query_drop("INSERT INTO mysql.foo (id) VALUES (42)")
1409 .await?;
1410 tx.exec_iter("SELECT COUNT(*) FROM mysql.foo", ()).await?;
1411 drop(tx);
1412 conn.ping().await?;
1413
1414 let count: u8 = conn
1415 .query_first("SELECT COUNT(*) FROM mysql.foo")
1416 .await?
1417 .unwrap_or_default();
1418
1419 assert_eq!(count, 0);
1420
1421 Ok(())
1422 }
1423
1424 #[tokio::test]
1425 async fn should_connect() -> super::Result<()> {
1426 let mut conn: Conn = Conn::new(get_opts()).await?;
1427 conn.ping().await?;
1428 let plugins: Vec<String> = conn
1429 .query_map("SHOW PLUGINS", |mut row: crate::Row| {
1430 row.take("Name").unwrap()
1431 })
1432 .await?;
1433
1434 let variants = vec![
1436 ("caching_sha2_password", 2_u8, "non-empty"),
1437 ("caching_sha2_password", 2_u8, ""),
1438 ("mysql_native_password", 0_u8, "non-empty"),
1439 ("mysql_native_password", 0_u8, ""),
1440 ]
1441 .into_iter()
1442 .filter(|variant| plugins.iter().any(|p| p == variant.0));
1443
1444 for (plug, val, pass) in variants {
1445 dbg!((plug, val, pass, conn.inner.version));
1446
1447 if plug == "mysql_native_password" && conn.inner.version >= (8, 4, 0) {
1448 continue;
1449 }
1450
1451 let _ = conn.query_drop("DROP USER 'test_user'@'%'").await;
1452
1453 let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug);
1454 conn.query_drop(query).await.unwrap();
1455
1456 if conn.inner.version < (8, 0, 11) {
1457 conn.query_drop(format!("SET old_passwords = {}", val))
1458 .await
1459 .unwrap();
1460 conn.query_drop(format!(
1461 "SET PASSWORD FOR 'test_user'@'%' = PASSWORD('{}')",
1462 pass
1463 ))
1464 .await
1465 .unwrap();
1466 } else {
1467 conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass))
1468 .await
1469 .unwrap();
1470 };
1471
1472 let opts = get_opts()
1473 .user(Some("test_user"))
1474 .pass(Some(pass))
1475 .db_name(None::<String>);
1476 let result = Conn::new(opts).await;
1477
1478 conn.query_drop("DROP USER 'test_user'@'%'").await.unwrap();
1479
1480 result?.disconnect().await?;
1481 }
1482
1483 if crate::test_misc::test_compression() {
1484 assert!(format!("{:?}", conn).contains("Compression"));
1485 }
1486
1487 if crate::test_misc::test_ssl() {
1488 assert!(format!("{:?}", conn).contains("Tls"));
1489 }
1490
1491 conn.disconnect().await?;
1492 Ok(())
1493 }
1494
/// Establishes a connection on a manually created runtime and drops it
/// without calling `disconnect`; the test passes if nothing panics.
#[test]
fn should_not_panic_if_dropped_without_tokio_runtime() {
    let connecting = Conn::new(get_opts());
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        // The established connection is dropped immediately.
        drop(connecting.await.unwrap());
    });
}
1504
1505 #[tokio::test]
1506 async fn should_execute_init_queries_on_new_connection() -> super::Result<()> {
1507 let opts = OptsBuilder::from_opts(get_opts()).init(vec!["SET @a = 42", "SET @b = 'foo'"]);
1508 let mut conn = Conn::new(opts).await?;
1509 let result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
1510 conn.disconnect().await?;
1511 assert_eq!(result, vec![(42, "foo".into())]);
1512 Ok(())
1513 }
1514
1515 #[tokio::test]
1516 async fn should_execute_setup_queries_on_reset() -> super::Result<()> {
1517 let opts = OptsBuilder::from_opts(get_opts()).setup(vec!["SET @a = 42", "SET @b = 'foo'"]);
1518 let mut conn = Conn::new(opts).await?;
1519
1520 let mut result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
1522 assert_eq!(result, vec![(42, "foo".into())]);
1523
1524 if conn.reset().await? {
1526 result = conn.query("SELECT @a, @b").await?;
1527 assert_eq!(result, vec![(42, "foo".into())]);
1528 }
1529
1530 conn.change_user(Default::default()).await?;
1532 result = conn.query("SELECT @a, @b").await?;
1533 assert_eq!(result, vec![(42, "foo".into())]);
1534
1535 conn.disconnect().await?;
1536 Ok(())
1537 }
1538
1539 #[tokio::test]
1540 async fn should_reset_the_connection() -> super::Result<()> {
1541 let mut conn = Conn::new(get_opts()).await?;
1542
1543 assert_eq!(
1544 conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1545 Value::NULL
1546 );
1547
1548 conn.query_drop("SET @foo = 'foo'").await?;
1549
1550 assert_eq!(
1551 conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1552 "foo",
1553 );
1554
1555 if conn.reset().await? {
1556 assert_eq!(
1557 conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1558 Value::NULL
1559 );
1560 } else {
1561 assert_eq!(
1562 conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1563 "foo",
1564 );
1565 }
1566
1567 conn.disconnect().await?;
1568 Ok(())
1569 }
1570
1571 #[tokio::test]
1572 async fn should_change_user() -> super::Result<()> {
1573 type ShouldRunFn = fn(bool, (u16, u16, u16)) -> bool;
1575 type CreateUserFn = fn(bool, (u16, u16, u16), &str) -> Vec<String>;
1577
1578 #[allow(clippy::type_complexity)]
1579 const TEST_MATRIX: [(&str, ShouldRunFn, CreateUserFn); 4] = [
1580 (
1581 "mysql_old_password",
1582 |is_mariadb, version| is_mariadb || version < (5, 7, 0),
1583 |is_mariadb, version, pass| {
1584 if is_mariadb {
1585 vec![
1586 "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password".into(),
1587 "SET old_passwords=1".into(),
1588 format!("ALTER USER '__mats'@'%' IDENTIFIED BY '{pass}'"),
1589 "SET old_passwords=0".into(),
1590 ]
1591 } else if matches!(version, (5, 6, _)) {
1592 vec![
1593 "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password".into(),
1594 format!("SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD('{pass}')"),
1595 ]
1596 } else {
1597 vec![
1598 "CREATE USER '__mats'@'%'".into(),
1599 format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"),
1600 ]
1601 }
1602 },
1603 ),
1604 (
1605 "mysql_native_password",
1606 |is_mariadb, version| is_mariadb || version < (8, 4, 0),
1607 |is_mariadb, version, pass| {
1608 if is_mariadb {
1609 vec![
1610 format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password AS PASSWORD('{pass}')")
1611 ]
1612 } else if version < (8, 0, 0) {
1613 vec![
1614 format!(
1615 "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password"
1616 ),
1617 format!("SET old_passwords = 0"),
1618 format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"),
1619 ]
1620 } else {
1621 vec![
1622 format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password BY '{pass}'")
1623 ]
1624 }
1625 },
1626 ),
1627 (
1628 "caching_sha2_password",
1629 |is_mariadb, version| !is_mariadb && version >= (5, 8, 0),
1630 |_is_mariadb, _version, pass| {
1631 vec![
1632 format!("CREATE USER '__mats'@'%' IDENTIFIED WITH caching_sha2_password BY '{pass}'")
1633 ]
1634 },
1635 ),
1636 (
1637 "client_ed25519",
1638 |is_mariadb, version| is_mariadb && version >= (11, 6, 2),
1639 |_is_mariadb, _version, pass| {
1640 vec![format!(
1641 "CREATE USER '__mats'@'%' IDENTIFIED WITH ed25519 AS PASSWORD('{pass}')"
1642 )]
1643 },
1644 ),
1645 ];
1646
1647 fn random_pass() -> String {
1648 let mut rng = rand::rng();
1649 let pass: [u8; 10] = rng.random();
1650
1651 IntoIterator::into_iter(pass)
1652 .map(|x| ((x % (123 - 97)) + 97) as char)
1653 .collect()
1654 }
1655
1656 let mut conn = Conn::new(get_opts()).await?;
1657
1658 assert_eq!(
1659 conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1660 Value::NULL
1661 );
1662
1663 conn.query_drop("SET @foo = 'foo'").await?;
1664
1665 assert_eq!(
1666 conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1667 "foo",
1668 );
1669
1670 conn.change_user(Default::default()).await?;
1671 assert_eq!(
1672 conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1673 Value::NULL
1674 );
1675
1676 for (i, (plugin, should_run, create_statements)) in TEST_MATRIX.iter().enumerate() {
1677 dbg!(plugin);
1678 let is_mariadb = conn.inner.is_mariadb;
1679 let version = conn.server_version();
1680
1681 if should_run(is_mariadb, version) {
1682 let pass = random_pass();
1683
1684 let result = conn
1685 .query_drop("DROP USER /*!50700 IF EXISTS */ /*M!100103 IF EXISTS */ __mats")
1686 .await;
1687
1688 if matches!(version, (5, 6, _)) && i == 0 {
1689 drop(result);
1691 } else {
1692 result.unwrap();
1693 }
1694
1695 for statement in create_statements(is_mariadb, version, &pass) {
1696 conn.query_drop(dbg!(statement)).await.unwrap();
1697 }
1698
1699 let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap();
1700 conn2
1701 .change_user(
1702 ChangeUserOpts::default()
1703 .with_db_name(None)
1704 .with_user(Some("__mats".into()))
1705 .with_pass(Some(pass)),
1706 )
1707 .await
1708 .unwrap();
1709
1710 let (db, user) = conn2
1711 .query_first::<(Option<String>, String), _>("SELECT DATABASE(), USER();")
1712 .await
1713 .unwrap()
1714 .unwrap();
1715 assert_eq!(db, None);
1716 assert!(user.starts_with("__mats"));
1717
1718 conn2.disconnect().await.unwrap();
1719 }
1720 }
1721
1722 Ok(())
1723 }
1724
1725 #[tokio::test]
1726 async fn should_not_cache_statements_if_stmt_cache_size_is_zero() -> super::Result<()> {
1727 let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0);
1728
1729 let mut conn = Conn::new(opts).await?;
1730 conn.exec_drop("DO ?", (1_u8,)).await?;
1731
1732 let stmt = conn.prep("DO 2").await?;
1733 conn.exec_drop(&stmt, ()).await?;
1734 conn.exec_drop(&stmt, ()).await?;
1735 conn.close(stmt).await?;
1736
1737 conn.exec_drop("DO 3", ()).await?;
1738 conn.exec_batch("DO 4", vec![(), ()]).await?;
1739 conn.exec_first::<u8, _, _>("DO 5", ()).await?;
1740 let row: Option<(crate::Value, usize)> = conn
1741 .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
1742 .await?;
1743
1744 assert_eq!(row.unwrap().1, 1);
1745 assert_eq!(conn.inner.stmt_cache.len(), 0);
1746
1747 conn.disconnect().await?;
1748
1749 Ok(())
1750 }
1751
1752 #[tokio::test]
1753 async fn should_hold_stmt_cache_size_bound() -> super::Result<()> {
1754 let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(3);
1755 let mut conn = Conn::new(opts).await?;
1756 conn.exec_drop("DO 1", ()).await?;
1757 conn.exec_drop("DO 2", ()).await?;
1758 conn.exec_drop("DO 3", ()).await?;
1759 conn.exec_drop("DO 1", ()).await?;
1760 conn.exec_drop("DO 4", ()).await?;
1761 conn.exec_drop("DO 3", ()).await?;
1762 conn.exec_drop("DO 5", ()).await?;
1763 conn.exec_drop("DO 6", ()).await?;
1764 let row_opt = conn
1765 .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
1766 .await?;
1767 let (_, count): (String, usize) = row_opt.unwrap();
1768 assert_eq!(count, 3);
1769 let order = conn
1770 .stmt_cache_ref()
1771 .iter()
1772 .map(|item| item.1.query.0.as_ref())
1773 .collect::<Vec<&[u8]>>();
1774 assert_eq!(order, &[b"DO 6", b"DO 5", b"DO 3"]);
1775 conn.disconnect().await?;
1776 Ok(())
1777 }
1778
1779 #[tokio::test]
1780 async fn should_perform_queries() -> super::Result<()> {
1781 let mut conn = Conn::new(get_opts()).await?;
1782 for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) {
1783 let long_string = "A".repeat(x);
1784 let result: Vec<(String, u8)> = conn
1785 .query(format!(r"SELECT '{}', 231", long_string))
1786 .await?;
1787 assert_eq!((long_string, 231_u8), result[0]);
1788 }
1789 conn.disconnect().await?;
1790 Ok(())
1791 }
1792
1793 #[tokio::test]
1794 async fn should_query_drop() -> super::Result<()> {
1795 let mut conn = Conn::new(get_opts()).await?;
1796 conn.query_drop("CREATE TEMPORARY TABLE tmp (id int DEFAULT 10, name text)")
1797 .await?;
1798 conn.query_drop("INSERT INTO tmp VALUES (1, 'foo')").await?;
1799 let result: Option<u8> = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1800 conn.disconnect().await?;
1801 assert_eq!(result, Some(1_u8));
1802 Ok(())
1803 }
1804
1805 #[tokio::test]
1806 async fn should_prepare_statement() -> super::Result<()> {
1807 let mut conn = Conn::new(get_opts()).await?;
1808 let stmt = conn.prep(r"SELECT ?").await?;
1809 conn.close(stmt).await?;
1810 conn.disconnect().await?;
1811
1812 let mut conn = Conn::new(get_opts()).await?;
1813 let stmt = conn.prep(r"SELECT :foo").await?;
1814
1815 {
1816 let query = String::from("SELECT ?, ?");
1817 let stmt = conn.prep(&*query).await?;
1818 conn.close(stmt).await?;
1819 {
1820 let mut conn = Conn::new(get_opts()).await?;
1821 let stmt = conn.prep(&*query).await?;
1822 conn.close(stmt).await?;
1823 conn.disconnect().await?;
1824 }
1825 }
1826
1827 conn.close(stmt).await?;
1828 conn.disconnect().await?;
1829
1830 Ok(())
1831 }
1832
1833 #[tokio::test]
1834 async fn should_execute_statement() -> super::Result<()> {
1835 let long_string = "A".repeat(18 * 1024 * 1024);
1836 let mut conn = Conn::new(get_opts()).await?;
1837 let stmt = conn.prep(r"SELECT ?").await?;
1838 let result = conn.exec_iter(&stmt, (&long_string,)).await?;
1839 let mut mapped = result.map_and_drop(from_row::<(String,)>).await?;
1840 assert_eq!(mapped.len(), 1);
1841 assert_eq!(mapped.pop(), Some((long_string,)));
1842 let result = conn.exec_iter(&stmt, (42_u8,)).await?;
1843 let collected = result.collect_and_drop::<(u8,)>().await?;
1844 assert_eq!(collected, vec![(42u8,)]);
1845 let result = conn.exec_iter(&stmt, (8_u8,)).await?;
1846 let reduced = result
1847 .reduce_and_drop(2, |mut acc, row| {
1848 acc += from_row::<i32>(row);
1849 acc
1850 })
1851 .await?;
1852 conn.close(stmt).await?;
1853 conn.disconnect().await?;
1854 assert_eq!(reduced, 10);
1855
1856 let mut conn = Conn::new(get_opts()).await?;
1857 let stmt = conn.prep(r"SELECT :foo, :bar, :foo, 3").await?;
1858 let result = conn
1859 .exec_iter(&stmt, params! { "foo" => "quux", "bar" => "baz" })
1860 .await?;
1861 let mut mapped = result
1862 .map_and_drop(from_row::<(String, String, String, u8)>)
1863 .await?;
1864 assert_eq!(mapped.len(), 1);
1865 assert_eq!(
1866 mapped.pop(),
1867 Some(("quux".into(), "baz".into(), "quux".into(), 3))
1868 );
1869 let result = conn
1870 .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1871 .await?;
1872 let collected = result.collect_and_drop::<(u8, u8, u8, u8)>().await?;
1873 assert_eq!(collected, vec![(2, 3, 2, 3)]);
1874 let result = conn
1875 .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1876 .await?;
1877 let reduced = result
1878 .reduce_and_drop(0, |acc, row| {
1879 let (a, b, c, d): (u8, u8, u8, u8) = from_row(row);
1880 acc + a + b + c + d
1881 })
1882 .await?;
1883 conn.close(stmt).await?;
1884 conn.disconnect().await?;
1885 assert_eq!(reduced, 10);
1886 Ok(())
1887 }
1888
1889 #[tokio::test]
1890 async fn should_prep_exec_statement() -> super::Result<()> {
1891 let mut conn = Conn::new(get_opts()).await?;
1892 let result = conn
1893 .exec_iter(r"SELECT :a, :b, :a", params! { "a" => 2, "b" => 3 })
1894 .await?;
1895 let output = result
1896 .map_and_drop(|row| {
1897 let (a, b, c): (u8, u8, u8) = from_row(row);
1898 a * b * c
1899 })
1900 .await?;
1901 conn.disconnect().await?;
1902 assert_eq!(output[0], 12u8);
1903 Ok(())
1904 }
1905
1906 #[tokio::test]
1907 async fn should_first_exec_statement() -> super::Result<()> {
1908 let mut conn = Conn::new(get_opts()).await?;
1909 let output = conn
1910 .exec_first(
1911 r"SELECT :a UNION ALL SELECT :b",
1912 params! { "a" => 2, "b" => 3 },
1913 )
1914 .await?;
1915 conn.disconnect().await?;
1916 assert_eq!(output, Some(2u8));
1917 Ok(())
1918 }
1919
1920 #[tokio::test]
1921 async fn issue_107() -> super::Result<()> {
1922 let mut conn = Conn::new(get_opts()).await?;
1923 conn.query_drop(
1924 r"CREATE TEMPORARY TABLE mysql.issue (
1925 a BIGINT(20) UNSIGNED,
1926 b VARBINARY(16),
1927 c BINARY(32),
1928 d BIGINT(20) UNSIGNED,
1929 e BINARY(32)
1930 )",
1931 )
1932 .await?;
1933 conn.query_drop(
1934 r"INSERT INTO mysql.issue VALUES (
1935 0,
1936 0xC066F966B0860000,
1937 0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
1938 0,
1939 ''
1940 ), (
1941 1,
1942 '',
1943 0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
1944 0,
1945 ''
1946 )",
1947 )
1948 .await?;
1949
1950 let q = "SELECT b, c, d, e FROM mysql.issue";
1951 let result = conn.query_iter(q).await?;
1952
1953 let loaded_structs = result
1954 .map_and_drop(crate::from_row::<(Vec<u8>, Vec<u8>, u64, Vec<u8>)>)
1955 .await?;
1956
1957 conn.disconnect().await?;
1958
1959 assert_eq!(loaded_structs.len(), 2);
1960
1961 Ok(())
1962 }
1963
1964 #[tokio::test]
1965 async fn should_run_transactions() -> super::Result<()> {
1966 let mut conn = Conn::new(get_opts()).await?;
1967 conn.query_drop("CREATE TEMPORARY TABLE tmp (id INT, name TEXT)")
1968 .await?;
1969 let mut transaction = conn.start_transaction(Default::default()).await?;
1970 transaction
1971 .query_drop("INSERT INTO tmp VALUES (1, 'foo'), (2, 'bar')")
1972 .await?;
1973 assert_eq!(transaction.last_insert_id(), None);
1974 assert_eq!(transaction.affected_rows(), 2);
1975 assert_eq!(transaction.get_warnings(), 0);
1976 assert_eq!(transaction.info(), "Records: 2 Duplicates: 0 Warnings: 0");
1977 transaction.commit().await?;
1978 let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1979 assert_eq!(output_opt, Some((2u8,)));
1980 let mut transaction = conn.start_transaction(Default::default()).await?;
1981 transaction
1982 .query_drop("INSERT INTO tmp VALUES (3, 'baz'), (4, 'quux')")
1983 .await?;
1984 let output_opt = transaction
1985 .exec_first("SELECT COUNT(*) FROM tmp", ())
1986 .await?;
1987 assert_eq!(output_opt, Some((4u8,)));
1988 transaction.rollback().await?;
1989 let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1990 assert_eq!(output_opt, Some((2u8,)));
1991
1992 let mut transaction = conn.start_transaction(Default::default()).await?;
1993 transaction
1994 .query_drop("INSERT INTO tmp VALUES (3, 'baz')")
1995 .await?;
1996 drop(transaction); let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1998 assert_eq!(output_opt, Some((2u8,)));
1999
2000 conn.disconnect().await?;
2001 Ok(())
2002 }
2003
2004 #[tokio::test]
2005 async fn should_handle_multiresult_set_with_error() -> super::Result<()> {
2006 const QUERY_FIRST: &str = "SELECT * FROM tmp; SELECT 1; SELECT 2;";
2007 const QUERY_MIDDLE: &str = "SELECT 1; SELECT * FROM tmp; SELECT 2";
2008 let mut conn = Conn::new(get_opts()).await.unwrap();
2009
2010 let result = QUERY_FIRST.run(&mut conn).await;
2012 assert!(matches!(result, Err(Error::Server(_))));
2013
2014 let mut result = QUERY_MIDDLE.run(&mut conn).await.unwrap();
2015
2016 let result_set: Vec<u8> = result.collect().await.unwrap();
2018 assert_eq!(result_set, vec![1]);
2019
2020 let result_set: super::Result<Vec<u8>> = result.collect().await;
2022 assert!(matches!(result_set, Err(Error::Server(_))));
2023
2024 assert!(result.is_empty());
2026
2027 conn.ping().await?;
2028 conn.disconnect().await?;
2029
2030 Ok(())
2031 }
2032
2033 #[tokio::test]
2034 async fn should_handle_binary_multiresult_set_with_error() -> super::Result<()> {
2035 const PROC_DEF_FIRST: &str =
2036 r#"CREATE PROCEDURE err_first() BEGIN SELECT * FROM tmp; SELECT 1; END"#;
2037 const PROC_DEF_MIDDLE: &str =
2038 r#"CREATE PROCEDURE err_middle() BEGIN SELECT 1; SELECT * FROM tmp; SELECT 2; END"#;
2039
2040 let mut conn = Conn::new(get_opts()).await.unwrap();
2041
2042 conn.query_drop("DROP PROCEDURE IF EXISTS err_first")
2043 .await?;
2044 conn.query_iter(PROC_DEF_FIRST).await?;
2045
2046 conn.query_drop("DROP PROCEDURE IF EXISTS err_middle")
2047 .await?;
2048 conn.query_iter(PROC_DEF_MIDDLE).await?;
2049
2050 let result = conn.query_iter("CALL err_first()").await;
2052 assert!(matches!(result, Err(Error::Server(_))));
2053
2054 let mut result = conn.query_iter("CALL err_middle()").await?;
2055
2056 let result_set: Vec<u8> = result.collect().await.unwrap();
2058 assert_eq!(result_set, vec![1]);
2059
2060 let result_set: super::Result<Vec<u8>> = result.collect().await;
2062 assert!(matches!(result_set, Err(Error::Server(_))));
2063
2064 assert!(result.is_empty());
2066
2067 conn.ping().await?;
2068 conn.disconnect().await?;
2069
2070 Ok(())
2071 }
2072
2073 #[tokio::test]
2074 async fn should_handle_multiresult_set_with_local_infile() -> super::Result<()> {
2075 use std::fs::write;
2076
2077 let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
2078 let file_path = file_path.path();
2079 let file_name = file_path.file_name().unwrap();
2080
2081 write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
2082
2083 let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
2084
2085 let mut conn = Conn::new(opts).await.unwrap();
2087 "CREATE TEMPORARY TABLE tmp (a TEXT)".run(&mut conn).await?;
2088
2089 let query = format!(
2090 r#"SELECT * FROM tmp;
2091 LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
2092 LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
2093 SELECT * FROM tmp"#,
2094 file_name.to_str().unwrap(),
2095 file_name.to_str().unwrap(),
2096 );
2097
2098 let mut result = query.run(&mut conn).await?;
2099
2100 let result_set = result.collect::<String>().await?;
2101 assert_eq!(result_set.len(), 0);
2102
2103 let mut no_local_infile = false;
2104
2105 for _ in 0..2 {
2106 match result.collect::<String>().await {
2107 Ok(result_set) => {
2108 assert_eq!(result.affected_rows(), 3);
2109 assert!(result_set.is_empty())
2110 }
2111 Err(Error::Server(ref err)) if err.code == 1148 => {
2112 no_local_infile = true;
2114 break;
2115 }
2116 Err(Error::Server(ref err)) if err.code == 3948 => {
2117 no_local_infile = true;
2120 break;
2121 }
2122 Err(err) => return Err(err),
2123 }
2124 }
2125
2126 if no_local_infile {
2127 assert!(result.is_empty());
2128 assert_eq!(result_set.len(), 0);
2129 } else {
2130 let result_set = result.collect::<String>().await?;
2131 assert_eq!(result_set.len(), 6);
2132 assert_eq!(result_set[0], "AAAAAA");
2133 assert_eq!(result_set[1], "BBBBBB");
2134 assert_eq!(result_set[2], "CCCCCC");
2135 assert_eq!(result_set[3], "AAAAAA");
2136 assert_eq!(result_set[4], "BBBBBB");
2137 assert_eq!(result_set[5], "CCCCCC");
2138 }
2139
2140 conn.ping().await?;
2141 conn.disconnect().await?;
2142
2143 Ok(())
2144 }
2145
2146 #[tokio::test]
2147 async fn should_provide_multiresult_set_metadata() -> super::Result<()> {
2148 let mut c = Conn::new(get_opts()).await?;
2149 c.query_drop("CREATE TEMPORARY TABLE tmp (id INT, foo TEXT)")
2150 .await?;
2151
2152 let mut result = c
2153 .query_iter("SELECT 1; SELECT id, foo FROM tmp WHERE 1 = 2; DO 42; SELECT 2;")
2154 .await?;
2155 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2156
2157 result.for_each(drop).await?;
2158 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 2);
2159
2160 result.for_each(drop).await?;
2161 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 0);
2162
2163 result.for_each(drop).await?;
2164 assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2165
2166 c.disconnect().await?;
2167 Ok(())
2168 }
2169
2170 #[tokio::test]
2171 async fn should_expose_query_result_metadata() -> super::Result<()> {
2172 let pool = Pool::new(get_opts());
2173 let mut c = pool.get_conn().await?;
2174
2175 c.query_drop(
2176 r"
2177 CREATE TEMPORARY TABLE `foo`
2178 ( `id` SERIAL
2179 , `bar_id` varchar(36) NOT NULL
2180 , `baz_id` varchar(36) NOT NULL
2181 , `ctime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP()
2182 , PRIMARY KEY (`id`)
2183 , KEY `bar_idx` (`bar_id`)
2184 , KEY `baz_idx` (`baz_id`)
2185 );",
2186 )
2187 .await?;
2188
2189 const QUERY: &str = "INSERT INTO foo (bar_id, baz_id) VALUES (?, ?)";
2190 let params = ("qwerty", "data.employee_id");
2191
2192 let query_result = c.exec_iter(QUERY, params).await?;
2193 assert_eq!(query_result.last_insert_id(), Some(1));
2194 query_result.drop_result().await?;
2195
2196 c.exec_drop(QUERY, params).await?;
2197 assert_eq!(c.last_insert_id(), Some(2));
2198
2199 let mut tx = c.start_transaction(Default::default()).await?;
2200
2201 tx.exec_drop(QUERY, params).await?;
2202 assert_eq!(tx.last_insert_id(), Some(3));
2203
2204 Ok(())
2205 }
2206
2207 #[tokio::test]
2208 async fn should_handle_local_infile_locally() -> super::Result<()> {
2209 let mut conn = Conn::new(get_opts()).await.unwrap();
2210 conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2211 .await
2212 .unwrap();
2213
2214 conn.set_infile_handler(async move {
2215 Ok(
2216 stream::iter([Bytes::from("AAAAAA\n"), Bytes::from("BBBBBB\nCCCCCC\n")])
2217 .map(Ok)
2218 .boxed(),
2219 )
2220 });
2221
2222 match conn
2223 .query_drop(r#"LOAD DATA LOCAL INFILE "dummy" INTO TABLE tmp;"#)
2224 .await
2225 {
2226 Ok(_) => (),
2227 Err(super::Error::Server(ref err)) if err.code == 1148 => {
2228 return Ok(());
2230 }
2231 Err(super::Error::Server(ref err)) if err.code == 3948 => {
2232 return Ok(());
2235 }
2236 e @ Err(_) => e.unwrap(),
2237 };
2238
2239 let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2240 assert_eq!(result.len(), 3);
2241 assert_eq!(result[0], "AAAAAA");
2242 assert_eq!(result[1], "BBBBBB");
2243 assert_eq!(result[2], "CCCCCC");
2244
2245 Ok(())
2246 }
2247
2248 #[tokio::test]
2249 async fn should_handle_local_infile_globally() -> super::Result<()> {
2250 use std::fs::write;
2251
2252 let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
2253 let file_path = file_path.path();
2254 let file_name = file_path.file_name().unwrap();
2255
2256 write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
2257
2258 let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
2259
2260 let mut conn = Conn::new(opts).await.unwrap();
2261 conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2262 .await
2263 .unwrap();
2264
2265 match conn
2266 .query_drop(format!(
2267 r#"LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;"#,
2268 file_name.to_str().unwrap(),
2269 ))
2270 .await
2271 {
2272 Ok(_) => (),
2273 Err(super::Error::Server(ref err)) if err.code == 1148 => {
2274 return Ok(());
2276 }
2277 Err(super::Error::Server(ref err)) if err.code == 3948 => {
2278 return Ok(());
2281 }
2282 e @ Err(_) => e.unwrap(),
2283 };
2284
2285 let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2286 assert_eq!(result.len(), 3);
2287 assert_eq!(result[0], "AAAAAA");
2288 assert_eq!(result[1], "BBBBBB");
2289 assert_eq!(result[2], "CCCCCC");
2290
2291 Ok(())
2292 }
2293
2294 #[tokio::test]
2295 async fn should_handle_initial_error_packet() {
2296 let header = [
2297 0x68, 0x00, 0x00, 0x00, 0xff, 0x69, 0x04, ];
2302 let error_message = "Host '172.17.0.1' is blocked because of many connection errors; unblock with 'mysqladmin flush-hosts'";
2303
2304 let listener = TcpListener::bind("127.0.0.1:0000").await.unwrap();
2306
2307 let listen_addr = listener.local_addr().unwrap();
2308
2309 tokio::task::spawn(async move {
2310 let (mut stream, _) = listener.accept().await.unwrap();
2311 stream.write_all(&header).await.unwrap();
2312 stream.write_all(error_message.as_bytes()).await.unwrap();
2313 stream.shutdown().await.unwrap();
2314 });
2315
2316 let opts = OptsBuilder::default()
2317 .ip_or_hostname(listen_addr.ip().to_string())
2318 .tcp_port(listen_addr.port());
2319 let server_err = match Conn::new(opts).await {
2320 Err(Error::Server(server_err)) => server_err,
2321 other => panic!("expected server error but got: {:?}", other),
2322 };
2323 assert_eq!(
2324 server_err,
2325 ServerError {
2326 code: 1129,
2327 state: "HY000".to_owned(),
2328 message: error_message.to_owned(),
2329 }
2330 );
2331 }
2332
/// Micro-benchmarks, compiled only with the `nightly` feature.
///
/// Each benchmark creates its own runtime and connection, measures one
/// operation inside `bencher.iter` and disconnects afterwards.
#[cfg(feature = "nightly")]
mod bench {
    use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts};

    /// Text-protocol execution of a trivial statement.
    #[bench]
    fn simple_exec(bencher: &mut test::Bencher) {
        // `block_on` is called through a shared reference, so the runtime
        // binding does not need to be `mut`.
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();

        bencher.iter(|| {
            runtime.block_on(conn.query_drop("DO 1")).unwrap();
        });

        runtime.block_on(conn.disconnect()).unwrap();
    }

    /// Text-protocol query returning a 10000-character string.
    #[bench]
    fn select_large_string(bencher: &mut test::Bencher) {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();

        bencher.iter(|| {
            runtime
                .block_on(conn.query_drop("SELECT REPEAT('A', 10000)"))
                .unwrap();
        });

        runtime.block_on(conn.disconnect()).unwrap();
    }

    /// Execution of a statement that was prepared once up front.
    #[bench]
    fn prepared_exec(bencher: &mut test::Bencher) {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
        let stmt = runtime.block_on(conn.prep("DO 1")).unwrap();

        bencher.iter(|| {
            runtime.block_on(conn.exec_drop(&stmt, ())).unwrap();
        });

        runtime.block_on(conn.close(stmt)).unwrap();
        runtime.block_on(conn.disconnect()).unwrap();
    }

    /// Execution of a query string passed directly to `exec_drop`.
    #[bench]
    fn prepare_and_exec(bencher: &mut test::Bencher) {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();

        bencher.iter(|| {
            runtime.block_on(conn.exec_drop("SELECT ?", (0,))).unwrap();
        });

        runtime.block_on(conn.disconnect()).unwrap();
    }
}
2389}