mysql_async/conn/
mod.rs

1// Copyright (c) 2016 Anatoly Ikorsky
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use futures_util::FutureExt;
10
11use mysql_common::{
12    constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8MB4_GENERAL_CI, UTF8_GENERAL_CI},
13    crypto,
14    io::ParseBuf,
15    packets::{
16        AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, HandshakePacket,
17        HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, OldEofPacket,
18        ResultSetTerminator, SslRequest,
19    },
20    proto::MySerialize,
21    row::Row,
22};
23
24use std::{
25    borrow::Cow,
26    fmt,
27    future::Future,
28    mem::{self, replace},
29    pin::Pin,
30    str::FromStr,
31    sync::Arc,
32    time::{Duration, Instant},
33};
34
35use crate::{
36    buffer_pool::PooledBuf,
37    conn::{pool::Pool, stmt_cache::StmtCache},
38    consts::{CapabilityFlags, Command, StatusFlags},
39    error::*,
40    io::Stream,
41    opts::Opts,
42    queryable::{
43        query_result::{QueryResult, ResultSetMeta},
44        transaction::TxStatus,
45        BinaryProtocol, Queryable, TextProtocol,
46    },
47    ChangeUserOpts, InfileData, OptsBuilder,
48};
49
50use self::routines::Routine;
51
52#[cfg(feature = "binlog")]
53pub mod binlog_stream;
54pub mod pool;
55pub mod routines;
56pub mod stmt_cache;
57
/// Default `wait_timeout`, in seconds (8 hours).
///
/// NOTE(review): presumably the fallback used when the server-reported `wait_timeout` is not
/// available; 28800 s matches MySQL's documented default — confirm against the call sites.
const DEFAULT_WAIT_TIMEOUT: usize = 8 * 60 * 60;
59
60/// Helper that asynchronously disconnects the givent connection on the default tokio executor.
61fn disconnect(mut conn: Conn) {
62    let disconnected = conn.inner.disconnected;
63
64    // Mark conn as disconnected.
65    conn.inner.disconnected = true;
66
67    if !disconnected {
68        // We shouldn't call tokio::spawn if unwinding
69        if std::thread::panicking() {
70            return;
71        }
72
73        // Server will report broken connection if spawn fails.
74        // this might fail if, say, the runtime is shutting down, but we've done what we could
75        if let Ok(handle) = tokio::runtime::Handle::try_current() {
76            handle.spawn(async move {
77                if let Ok(conn) = conn.cleanup_for_pool().await {
78                    let _ = conn.disconnect().await;
79                }
80            });
81        }
82    }
83}
84
85/// Pending result set.
86#[derive(Debug, Clone)]
87pub(crate) enum PendingResult {
88    /// There is a pending result set.
89    Pending(ResultSetMeta),
90    /// Result set metadata was taken but not yet consumed.
91    Taken(Arc<ResultSetMeta>),
92}
93
/// Internal state of a MySql connection; `Conn` owns it through a `Box`.
///
/// Field order also fixes drop order (Rust drops fields top to bottom), so keep it stable.
struct ConnInner {
    /// Underlying stream; `None` once it was taken (e.g. closed or dropped after an I/O error).
    stream: Option<Stream>,
    /// Connection id reported by the server in its handshake (`0` before that).
    id: u32,
    /// Set when the handshake carried a parseable MariaDB server version.
    is_mariadb: bool,
    /// Server version `(major, minor, patch)`; `(0, 0, 0)` when unknown.
    version: (u16, u16, u16),
    /// Socket path taken from the options, if any.
    socket: Option<String>,
    /// Capabilities from the options, narrowed to the server's capabilities by the handshake.
    capabilities: CapabilityFlags,
    /// Status flags from the handshake / last OK packet; cleared when an ERR packet arrives.
    status: StatusFlags,
    /// Last OK packet received; cleared when an ERR packet arrives.
    last_ok_packet: Option<OkPacket<'static>>,
    /// Last server error received; cleared when an OK packet arrives.
    last_err_packet: Option<mysql_common::packets::ServerError<'static>>,
    /// Set once the handshake response was sent; until then ERR packets are parsed without
    /// assuming any client capabilities.
    handshake_complete: bool,
    /// Pool this connection belongs to, if it is pooled.
    pool: Option<Pool>,
    /// Pending result set, or a deferred server error to be reported to the next user.
    pending_result: std::result::Result<Option<PendingResult>, ServerError>,
    /// Current transaction status.
    tx_status: TxStatus,
    /// Whether to reset the connection when it goes back to a pool (see `reset_connection`).
    reset_upon_returning_to_a_pool: bool,
    /// Connection options.
    opts: Opts,
    /// Deadline obtained from the pool options at construction time, if any.
    ttl_deadline: Option<Instant>,
    /// Time of the last I/O, updated by `touch`.
    last_io: Instant,
    /// NOTE(review): presumably the server's `wait_timeout`; starts at zero — confirm where set.
    wait_timeout: Duration,
    /// Prepared-statement cache, sized from the options.
    stmt_cache: StmtCache,
    /// Auth scramble from the handshake or from the latest auth switch request.
    nonce: Vec<u8>,
    /// Auth plugin currently in use.
    auth_plugin: AuthPlugin<'static>,
    /// Set once an auth switch was performed; a second switch is rejected.
    auth_switched: bool,
    /// Server public key fetched for `caching_sha2_password` full auth, cached once fetched.
    server_key: Option<Vec<u8>>,
    /// Set to "now" when the state is constructed.
    active_since: Instant,
    /// Connection is already disconnected.
    pub(crate) disconnected: bool,
    /// One-time connection-level infile handler.
    infile_handler:
        Option<Pin<Box<dyn Future<Output = crate::Result<InfileData>> + Send + Sync + 'static>>>,
}
126
127impl fmt::Debug for ConnInner {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        f.debug_struct("Conn")
130            .field("connection id", &self.id)
131            .field("server version", &self.version)
132            .field("pool", &self.pool)
133            .field("pending_result", &self.pending_result)
134            .field("tx_status", &self.tx_status)
135            .field("stream", &self.stream)
136            .field("options", &self.opts)
137            .field("server_key", &self.server_key)
138            .field("auth_plugin", &self.auth_plugin)
139            .finish()
140    }
141}
142
143impl ConnInner {
144    /// Constructs an empty connection.
145    fn empty(opts: Opts) -> ConnInner {
146        let ttl_deadline = opts.pool_opts().new_connection_ttl_deadline();
147        ConnInner {
148            capabilities: opts.get_capabilities(),
149            status: StatusFlags::empty(),
150            last_ok_packet: None,
151            last_err_packet: None,
152            handshake_complete: false,
153            stream: None,
154            is_mariadb: false,
155            version: (0, 0, 0),
156            id: 0,
157            pending_result: Ok(None),
158            pool: None,
159            tx_status: TxStatus::None,
160            last_io: Instant::now(),
161            wait_timeout: Duration::from_secs(0),
162            stmt_cache: StmtCache::new(opts.stmt_cache_size()),
163            socket: opts.socket().map(Into::into),
164            opts,
165            ttl_deadline,
166            nonce: Vec::default(),
167            auth_plugin: AuthPlugin::MysqlNativePassword,
168            auth_switched: false,
169            disconnected: false,
170            server_key: None,
171            infile_handler: None,
172            reset_upon_returning_to_a_pool: false,
173            active_since: Instant::now(),
174        }
175    }
176
177    /// Returns mutable reference to a connection stream.
178    ///
179    /// Returns `DriverError::ConnectionClosed` if there is no stream.
180    fn stream_mut(&mut self) -> Result<&mut Stream> {
181        self.stream
182            .as_mut()
183            .ok_or_else(|| DriverError::ConnectionClosed.into())
184    }
185}
186
187/// MySql server connection.
188#[derive(Debug)]
189pub struct Conn {
190    inner: Box<ConnInner>,
191}
192
193impl Conn {
194    /// Returns connection identifier.
195    pub fn id(&self) -> u32 {
196        self.inner.id
197    }
198
199    /// Returns the ID generated by a query (usually `INSERT`) on a table with a column having the
200    /// `AUTO_INCREMENT` attribute. Returns `None` if there was no previous query on the connection
201    /// or if the query did not update an AUTO_INCREMENT value.
202    pub fn last_insert_id(&self) -> Option<u64> {
203        self.inner
204            .last_ok_packet
205            .as_ref()
206            .and_then(|ok| ok.last_insert_id())
207    }
208
209    /// Returns the number of rows affected by the last `INSERT`, `UPDATE`, `REPLACE` or `DELETE`
210    /// query.
211    pub fn affected_rows(&self) -> u64 {
212        self.inner
213            .last_ok_packet
214            .as_ref()
215            .map(|ok| ok.affected_rows())
216            .unwrap_or_default()
217    }
218
219    /// Text information, as reported by the server in the last OK packet, or an empty string.
220    pub fn info(&self) -> Cow<'_, str> {
221        self.inner
222            .last_ok_packet
223            .as_ref()
224            .and_then(|ok| ok.info_str())
225            .unwrap_or_else(|| "".into())
226    }
227
228    /// Number of warnings, as reported by the server in the last OK packet, or `0`.
229    pub fn get_warnings(&self) -> u16 {
230        self.inner
231            .last_ok_packet
232            .as_ref()
233            .map(|ok| ok.warnings())
234            .unwrap_or_default()
235    }
236
237    /// Returns a reference to the last OK packet.
238    pub fn last_ok_packet(&self) -> Option<&OkPacket<'static>> {
239        self.inner.last_ok_packet.as_ref()
240    }
241
242    /// Turns on/off automatic connection reset (see [`crate::PoolOpts::with_reset_connection`]).
243    ///
244    /// Only makes sense for pooled connections.
245    pub fn reset_connection(&mut self, reset_connection: bool) {
246        self.inner.reset_upon_returning_to_a_pool = reset_connection;
247    }
248
249    pub(crate) fn stream_mut(&mut self) -> Result<&mut Stream> {
250        self.inner.stream_mut()
251    }
252
253    pub(crate) fn capabilities(&self) -> CapabilityFlags {
254        self.inner.capabilities
255    }
256
257    /// Will update last IO time for this connection.
258    pub(crate) fn touch(&mut self) {
259        self.inner.last_io = Instant::now();
260    }
261
262    /// Will set packet sequence id to `0`.
263    pub(crate) fn reset_seq_id(&mut self) {
264        if let Some(stream) = self.inner.stream.as_mut() {
265            stream.reset_seq_id();
266        }
267    }
268
269    /// Will syncronize sequence ids between compressed and uncompressed codecs.
270    pub(crate) fn sync_seq_id(&mut self) {
271        if let Some(stream) = self.inner.stream.as_mut() {
272            stream.sync_seq_id();
273        }
274    }
275
276    /// Handles OK packet.
277    pub(crate) fn handle_ok(&mut self, ok_packet: OkPacket<'static>) {
278        self.inner.status = ok_packet.status_flags();
279        self.inner.last_err_packet = None;
280        self.inner.last_ok_packet = Some(ok_packet);
281    }
282
283    /// Handles ERR packet.
284    pub(crate) fn handle_err(&mut self, err_packet: ErrPacket<'_>) -> Result<()> {
285        match err_packet {
286            ErrPacket::Error(err) => {
287                self.inner.status = StatusFlags::empty();
288                self.inner.last_ok_packet = None;
289                self.inner.last_err_packet = Some(err.clone().into_owned());
290                Err(Error::from(err))
291            }
292            ErrPacket::Progress(_) => Ok(()),
293        }
294    }
295
296    /// Returns the current transaction status.
297    pub(crate) fn get_tx_status(&self) -> TxStatus {
298        self.inner.tx_status
299    }
300
301    /// Sets the given transaction status for this connection.
302    pub(crate) fn set_tx_status(&mut self, tx_status: TxStatus) {
303        self.inner.tx_status = tx_status;
304    }
305
306    /// Returns pending result metadata, if any.
307    ///
308    /// If `Some(_)`, then result is not yet consumed.
309    pub(crate) fn use_pending_result(
310        &mut self,
311    ) -> std::result::Result<Option<&PendingResult>, ServerError> {
312        if let Err(ref e) = self.inner.pending_result {
313            let e = e.clone();
314            self.inner.pending_result = Ok(None);
315            Err(e)
316        } else {
317            Ok(self.inner.pending_result.as_ref().unwrap().as_ref())
318        }
319    }
320
321    pub(crate) fn get_pending_result(
322        &self,
323    ) -> std::result::Result<Option<&PendingResult>, &ServerError> {
324        self.inner.pending_result.as_ref().map(|x| x.as_ref())
325    }
326
327    pub(crate) fn has_pending_result(&self) -> bool {
328        self.inner.pending_result.is_err() || matches!(self.inner.pending_result, Ok(Some(_)))
329    }
330
331    /// Sets the given pening result metadata for this connection. Returns the previous value.
332    pub(crate) fn set_pending_result(
333        &mut self,
334        meta: Option<ResultSetMeta>,
335    ) -> std::result::Result<Option<PendingResult>, ServerError> {
336        replace(
337            &mut self.inner.pending_result,
338            Ok(meta.map(PendingResult::Pending)),
339        )
340    }
341
342    pub(crate) fn set_pending_result_error(
343        &mut self,
344        error: ServerError,
345    ) -> std::result::Result<Option<PendingResult>, ServerError> {
346        replace(&mut self.inner.pending_result, Err(error))
347    }
348
349    /// Gives the currently pending result to a caller for consumption.
350    pub(crate) fn take_pending_result(
351        &mut self,
352    ) -> std::result::Result<Option<Arc<ResultSetMeta>>, ServerError> {
353        let mut output = None;
354
355        self.inner.pending_result = match replace(&mut self.inner.pending_result, Ok(None))? {
356            Some(PendingResult::Pending(x)) => {
357                let meta = Arc::new(x);
358                output = Some(meta.clone());
359                Ok(Some(PendingResult::Taken(meta)))
360            }
361            x => Ok(x),
362        };
363
364        Ok(output)
365    }
366
367    /// Returns current status flags.
368    pub(crate) fn status(&self) -> StatusFlags {
369        self.inner.status
370    }
371
372    pub(crate) async fn routine<'a, F, T>(&mut self, mut f: F) -> crate::Result<T>
373    where
374        F: Routine<T> + 'a,
375    {
376        self.inner.disconnected = true;
377        let result = f.call(&mut *self).await;
378        match result {
379            result @ Ok(_) | result @ Err(crate::Error::Server(_)) => {
380                // either OK or non-fatal error
381                self.inner.disconnected = false;
382                result
383            }
384            Err(err) => {
385                if self.inner.stream.is_some() {
386                    self.take_stream().close().await?;
387                }
388                Err(err)
389            }
390        }
391    }
392
393    /// Returns server version.
394    pub fn server_version(&self) -> (u16, u16, u16) {
395        self.inner.version
396    }
397
398    /// Returns connection options.
399    pub fn opts(&self) -> &Opts {
400        &self.inner.opts
401    }
402
403    /// Setup _local_ `LOCAL INFILE` handler (see ["LOCAL INFILE Handlers"][2] section
404    /// of the crate-level docs).
405    ///
406    /// It'll overwrite existing _local_ handler, if any.
407    ///
408    /// [2]: ../mysql_async/#local-infile-handlers
409    pub fn set_infile_handler<T>(&mut self, handler: T)
410    where
411        T: Future<Output = crate::Result<InfileData>>,
412        T: Send + Sync + 'static,
413    {
414        self.inner.infile_handler = Some(Box::pin(handler));
415    }
416
417    fn take_stream(&mut self) -> Stream {
418        self.inner.stream.take().unwrap()
419    }
420
421    /// Disconnects this connection from server.
422    pub async fn disconnect(mut self) -> Result<()> {
423        if !self.inner.disconnected {
424            self.inner.disconnected = true;
425            self.write_command_data(Command::COM_QUIT, &[]).await?;
426            let stream = self.take_stream();
427            stream.close().await?;
428        }
429        Ok(())
430    }
431
432    /// Closes the connection.
433    async fn close_conn(mut self) -> Result<()> {
434        self = self.cleanup_for_pool().await?;
435        self.disconnect().await
436    }
437
438    /// Returns true if io stream is encrypted.
439    fn is_secure(&self) -> bool {
440        #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
441        {
442            self.inner
443                .stream
444                .as_ref()
445                .map(|x| x.is_secure())
446                .unwrap_or_default()
447        }
448
449        #[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))]
450        false
451    }
452
453    /// Returns true if io stream is socket.
454    fn is_socket(&self) -> bool {
455        #[cfg(unix)]
456        {
457            self.inner
458                .stream
459                .as_ref()
460                .map(|x| x.is_socket())
461                .unwrap_or_default()
462        }
463
464        #[cfg(not(unix))]
465        false
466    }
467
468    /// Hacky way to move connection through &mut. `self` becomes unusable.
469    fn take(&mut self) -> Conn {
470        mem::replace(self, Conn::empty(Default::default()))
471    }
472
473    fn empty(opts: Opts) -> Self {
474        Self {
475            inner: Box::new(ConnInner::empty(opts)),
476        }
477    }
478
479    /// Set `io::Stream` options as defined in the `Opts` of the connection.
480    ///
481    /// Requires that self.inner.stream is Some
482    fn setup_stream(&mut self) -> Result<()> {
483        debug_assert!(self.inner.stream.is_some());
484        if let Some(stream) = self.inner.stream.as_mut() {
485            stream.set_tcp_nodelay(self.inner.opts.tcp_nodelay())?;
486        }
487        Ok(())
488    }
489
490    async fn handle_handshake(&mut self) -> Result<()> {
491        let packet = self.read_packet().await?;
492        let handshake = ParseBuf(&packet).parse::<HandshakePacket>(())?;
493
494        // Handshake scramble is always 21 bytes length (20 + zero terminator)
495        self.inner.nonce = {
496            let mut nonce = Vec::from(handshake.scramble_1_ref());
497            nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
498            // Trim zero terminator. Fill with zeroes if nonce
499            // is somehow smaller than 20 bytes (this matches the server behavior).
500            nonce.resize(20, 0);
501            nonce
502        };
503
504        self.inner.capabilities = handshake.capabilities() & self.inner.opts.get_capabilities();
505        self.inner.version = handshake
506            .maria_db_server_version_parsed()
507            .inspect(|_| self.inner.is_mariadb = true)
508            .or_else(|| handshake.server_version_parsed())
509            .unwrap_or((0, 0, 0));
510        self.inner.id = handshake.connection_id();
511        self.inner.status = handshake.status_flags();
512
513        // Allow only CachingSha2Password and MysqlNativePassword here
514        // because sha256_password is deprecated and other plugins won't
515        // appear here.
516        self.inner.auth_plugin = match handshake.auth_plugin() {
517            Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password,
518            _ => AuthPlugin::MysqlNativePassword,
519        };
520
521        Ok(())
522    }
523
524    async fn switch_to_ssl_if_needed(&mut self) -> Result<()> {
525        if self
526            .inner
527            .opts
528            .get_capabilities()
529            .contains(CapabilityFlags::CLIENT_SSL)
530        {
531            if !self
532                .inner
533                .capabilities
534                .contains(CapabilityFlags::CLIENT_SSL)
535            {
536                return Err(DriverError::NoClientSslFlagFromServer.into());
537            }
538
539            let collation = if self.inner.version >= (5, 5, 3) {
540                UTF8MB4_GENERAL_CI
541            } else {
542                UTF8_GENERAL_CI
543            };
544
545            let ssl_request = SslRequest::new(
546                self.inner.capabilities,
547                DEFAULT_MAX_ALLOWED_PACKET as u32,
548                collation as u8,
549            );
550            self.write_struct(&ssl_request).await?;
551            let conn = self;
552            let ssl_opts = conn.opts().ssl_opts().cloned().expect("unreachable");
553            let domain = ssl_opts
554                .tls_hostname_override()
555                .unwrap_or_else(|| conn.opts().ip_or_hostname())
556                .into();
557            conn.stream_mut()?.make_secure(domain, ssl_opts).await?;
558            Ok(())
559        } else {
560            Ok(())
561        }
562    }
563
564    async fn do_handshake_response(&mut self) -> Result<()> {
565        let auth_data = self
566            .inner
567            .auth_plugin
568            .gen_data(self.inner.opts.pass(), &self.inner.nonce);
569
570        let handshake_response = HandshakeResponse::new(
571            auth_data.as_deref(),
572            self.inner.version,
573            self.inner.opts.user().map(|x| x.as_bytes()),
574            self.inner.opts.db_name().map(|x| x.as_bytes()),
575            Some(self.inner.auth_plugin.borrow()),
576            self.capabilities(),
577            Default::default(), // TODO: Add support
578            self.inner
579                .opts
580                .max_allowed_packet()
581                .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET) as u32,
582        );
583
584        // Serialize here to satisfy borrow checker.
585        let mut buf = crate::buffer_pool().get();
586        handshake_response.serialize(buf.as_mut());
587
588        self.write_packet(buf).await?;
589        self.inner.handshake_complete = true;
590        Ok(())
591    }
592
593    async fn perform_auth_switch(
594        &mut self,
595        auth_switch_request: AuthSwitchRequest<'_>,
596    ) -> Result<()> {
597        if !self.inner.auth_switched {
598            self.inner.auth_switched = true;
599            self.inner.nonce = auth_switch_request.plugin_data().to_vec();
600
601            if matches!(
602                auth_switch_request.auth_plugin(),
603                AuthPlugin::MysqlOldPassword
604            ) && self.inner.opts.secure_auth()
605            {
606                return Err(DriverError::MysqlOldPasswordDisabled.into());
607            }
608
609            self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned();
610
611            let plugin_data = match &self.inner.auth_plugin {
612                x @ AuthPlugin::CachingSha2Password => {
613                    x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
614                }
615                x @ AuthPlugin::MysqlNativePassword => {
616                    x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
617                }
618                x @ AuthPlugin::MysqlOldPassword => {
619                    if self.inner.opts.secure_auth() {
620                        return Err(DriverError::MysqlOldPasswordDisabled.into());
621                    } else {
622                        x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
623                    }
624                }
625                x @ AuthPlugin::MysqlClearPassword => {
626                    if self.inner.opts.enable_cleartext_plugin() {
627                        x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
628                    } else {
629                        return Err(DriverError::CleartextPluginDisabled.into());
630                    }
631                }
632                x @ AuthPlugin::Other(_) => x.gen_data(self.inner.opts.pass(), &self.inner.nonce),
633            };
634
635            if let Some(plugin_data) = plugin_data {
636                self.write_struct(&plugin_data.into_owned()).await?;
637            } else {
638                self.write_packet(crate::buffer_pool().get()).await?;
639            }
640
641            self.continue_auth().await?;
642
643            Ok(())
644        } else {
645            unreachable!("auth_switched flag should be checked by caller")
646        }
647    }
648
649    fn continue_auth(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
650        // NOTE: we need to box this since it may recurse
651        // see https://github.com/rust-lang/rust/issues/46415#issuecomment-528099782
652        Box::pin(async move {
653            match self.inner.auth_plugin {
654                AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
655                    self.continue_mysql_native_password_auth().await?;
656                    Ok(())
657                }
658                AuthPlugin::CachingSha2Password => {
659                    self.continue_caching_sha2_password_auth().await?;
660                    Ok(())
661                }
662                AuthPlugin::MysqlClearPassword => {
663                    if self.inner.opts.enable_cleartext_plugin() {
664                        self.continue_mysql_native_password_auth().await?;
665                        Ok(())
666                    } else {
667                        Err(DriverError::CleartextPluginDisabled.into())
668                    }
669                }
670                AuthPlugin::Other(ref name) => Err(DriverError::UnknownAuthPlugin {
671                    name: String::from_utf8_lossy(name.as_ref()).to_string(),
672                }
673                .into()),
674            }
675        })
676    }
677
678    fn switch_to_compression(&mut self) -> Result<()> {
679        if self
680            .capabilities()
681            .contains(CapabilityFlags::CLIENT_COMPRESS)
682        {
683            if let Some(compression) = self.inner.opts.compression() {
684                if let Some(stream) = self.inner.stream.as_mut() {
685                    stream.compress(compression);
686                }
687            }
688        }
689        Ok(())
690    }
691
692    async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> {
693        let packet = self.read_packet().await?;
694        match packet.first() {
695            Some(0x00) => {
696                // ok packet for empty password
697                Ok(())
698            }
699            Some(0x01) => match packet.get(1) {
700                Some(0x03) => {
701                    // auth ok
702                    self.drop_packet().await
703                }
704                Some(0x04) => {
705                    let pass = self.inner.opts.pass().unwrap_or_default();
706                    let mut pass = crate::buffer_pool().get_with(pass.as_bytes());
707                    pass.as_mut().push(0);
708
709                    if self.is_secure() || self.is_socket() {
710                        self.write_packet(pass).await?;
711                    } else {
712                        if self.inner.server_key.is_none() {
713                            self.write_bytes(&[0x02][..]).await?;
714                            let packet = self.read_packet().await?;
715                            self.inner.server_key = Some(packet[1..].to_vec());
716                        }
717                        for (i, byte) in pass.as_mut().iter_mut().enumerate() {
718                            *byte ^= self.inner.nonce[i % self.inner.nonce.len()];
719                        }
720                        let encrypted_pass = crypto::encrypt(
721                            &pass,
722                            self.inner.server_key.as_deref().expect("unreachable"),
723                        );
724                        self.write_bytes(&encrypted_pass).await?;
725                    };
726                    self.drop_packet().await?;
727                    Ok(())
728                }
729                _ => Err(DriverError::UnexpectedPacket {
730                    payload: packet.to_vec(),
731                }
732                .into()),
733            },
734            Some(0xfe) if !self.inner.auth_switched => {
735                let auth_switch_request = ParseBuf(&packet).parse::<AuthSwitchRequest>(())?;
736                self.perform_auth_switch(auth_switch_request).await?;
737                Ok(())
738            }
739            _ => Err(DriverError::UnexpectedPacket {
740                payload: packet.to_vec(),
741            }
742            .into()),
743        }
744    }
745
746    async fn continue_mysql_native_password_auth(&mut self) -> Result<()> {
747        let packet = self.read_packet().await?;
748        match packet.first() {
749            Some(0x00) => Ok(()),
750            Some(0xfe) if !self.inner.auth_switched => {
751                let auth_switch = if packet.len() > 1 {
752                    ParseBuf(&packet).parse(())?
753                } else {
754                    let _ = ParseBuf(&packet).parse::<OldAuthSwitchRequest>(())?;
755                    // map OldAuthSwitch to AuthSwitch with mysql_old_password plugin
756                    AuthSwitchRequest::new(
757                        "mysql_old_password".as_bytes(),
758                        self.inner.nonce.clone(),
759                    )
760                };
761                self.perform_auth_switch(auth_switch).await
762            }
763            _ => Err(DriverError::UnexpectedPacket {
764                payload: packet.to_vec(),
765            }
766            .into()),
767        }
768    }
769
770    /// Returns `true` for ProgressReport packet.
771    fn handle_packet(&mut self, packet: &PooledBuf) -> Result<bool> {
772        let ok_packet = if self.has_pending_result() {
773            if self
774                .capabilities()
775                .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF)
776            {
777                ParseBuf(packet)
778                    .parse::<OkPacketDeserializer<ResultSetTerminator>>(self.capabilities())
779                    .map(|x| x.into_inner())
780            } else {
781                ParseBuf(packet)
782                    .parse::<OkPacketDeserializer<OldEofPacket>>(self.capabilities())
783                    .map(|x| x.into_inner())
784            }
785        } else {
786            ParseBuf(packet)
787                .parse::<OkPacketDeserializer<CommonOkPacket>>(self.capabilities())
788                .map(|x| x.into_inner())
789        };
790
791        if let Ok(ok_packet) = ok_packet {
792            self.handle_ok(ok_packet.into_owned());
793        } else {
794            // If we haven't completed the handshake the server will not be aware of our
795            // capabilities and so it will behave as if we have none. In particular, the error
796            // packet will not contain a SQL State field even if our capabilities do contain the
797            // `CLIENT_PROTOCOL_41` flag. Therefore it is necessary to parse an incoming packet
798            // with no capability assumptions if we have not completed the handshake.
799            //
800            // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html
801            let capabilities = if self.inner.handshake_complete {
802                self.capabilities()
803            } else {
804                CapabilityFlags::empty()
805            };
806            let err_packet = ParseBuf(packet).parse::<ErrPacket>(capabilities);
807            if let Ok(err_packet) = err_packet {
808                self.handle_err(err_packet)?;
809                return Ok(true);
810            }
811        }
812
813        Ok(false)
814    }
815
816    pub(crate) async fn read_packet(&mut self) -> Result<PooledBuf> {
817        loop {
818            let packet = crate::io::ReadPacket::new(&mut *self)
819                .await
820                .map_err(|io_err| {
821                    self.inner.stream.take();
822                    self.inner.disconnected = true;
823                    Error::from(io_err)
824                })?;
825            if self.handle_packet(&packet)? {
826                // ignore progress report
827                continue;
828            } else {
829                return Ok(packet);
830            }
831        }
832    }
833
834    /// Returns future that reads packets from a server.
835    pub(crate) async fn read_packets(&mut self, n: usize) -> Result<Vec<PooledBuf>> {
836        let mut packets = Vec::with_capacity(n);
837        for _ in 0..n {
838            packets.push(self.read_packet().await?);
839        }
840        Ok(packets)
841    }
842
843    pub(crate) async fn write_packet(&mut self, data: PooledBuf) -> Result<()> {
844        crate::io::WritePacket::new(&mut *self, data)
845            .await
846            .map_err(|io_err| {
847                self.inner.stream.take();
848                self.inner.disconnected = true;
849                From::from(io_err)
850            })
851    }
852
853    /// Writes bytes to a server.
854    pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
855        let buf = crate::buffer_pool().get_with(bytes);
856        self.write_packet(buf).await
857    }
858
859    /// Sends a serializable structure to a server.
860    pub(crate) async fn write_struct<T: MySerialize>(&mut self, x: &T) -> Result<()> {
861        let mut buf = crate::buffer_pool().get();
862        x.serialize(buf.as_mut());
863        self.write_packet(buf).await
864    }
865
    /// Sends a command to a server.
    ///
    /// Calls `clean_dirty` first (presumably to finish whatever a previous command left
    /// behind — TODO confirm), then resets the packet sequence id and writes the serialized
    /// command as a single packet.
    pub(crate) async fn write_command<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
        self.clean_dirty().await?;
        // A command starts a new packet exchange, so the sequence id starts over.
        self.reset_seq_id();
        self.write_struct(cmd).await
    }
872
    /// Returns future that sends full command body to a server.
    ///
    /// `body` must already contain the command byte followed by its payload; it must not be
    /// empty (asserted in debug builds). Like `write_command`, it calls `clean_dirty` and
    /// resets the packet sequence id before writing.
    pub(crate) async fn write_command_raw(&mut self, body: PooledBuf) -> Result<()> {
        debug_assert!(!body.is_empty());
        self.clean_dirty().await?;
        // A command starts a new packet exchange, so the sequence id starts over.
        self.reset_seq_id();
        self.write_packet(body).await
    }
880
881    /// Returns future that writes command to a server.
882    pub(crate) async fn write_command_data<T>(&mut self, cmd: Command, cmd_data: T) -> Result<()>
883    where
884        T: AsRef<[u8]>,
885    {
886        let cmd_data = cmd_data.as_ref();
887        let mut buf = crate::buffer_pool().get();
888        let body = buf.as_mut();
889        body.push(cmd as u8);
890        body.extend_from_slice(cmd_data);
891        self.write_command_raw(buf).await
892    }
893
894    async fn drop_packet(&mut self) -> Result<()> {
895        self.read_packet().await?;
896        Ok(())
897    }
898
899    async fn run_init_commands(&mut self) -> Result<()> {
900        let mut init = self.inner.opts.init().to_vec();
901
902        while let Some(query) = init.pop() {
903            self.query_drop(query).await?;
904        }
905
906        Ok(())
907    }
908
909    async fn run_setup_commands(&mut self) -> Result<()> {
910        let mut setup = self.inner.opts.setup().to_vec();
911
912        while let Some(query) = setup.pop() {
913            self.query_drop(query).await?;
914        }
915
916        Ok(())
917    }
918
    /// Returns a future that resolves to [`Conn`].
    ///
    /// The connection is set up in a fixed order:
    /// 1. open the transport (a socket connection if `opts.socket()` is set, TCP otherwise);
    /// 2. `setup_stream`, `handle_handshake`, `switch_to_ssl_if_needed`,
    ///    `do_handshake_response`, `continue_auth` and `switch_to_compression`;
    /// 3. `read_settings`, then `reconnect_via_socket_if_needed`;
    /// 4. the init commands, then the setup commands.
    ///
    /// Any failing step makes the future resolve to that error.
    pub fn new<T: Into<Opts>>(opts: T) -> crate::BoxFuture<'static, Conn> {
        let opts = opts.into();
        async move {
            let mut conn = Conn::empty(opts.clone());

            let stream = if let Some(_path) = opts.socket() {
                // Socket connections are only implemented on unix targets; on other targets
                // a configured socket is rejected with `NamedPipesDisabled`.
                #[cfg(unix)]
                {
                    Stream::connect_socket(_path.to_owned()).await?
                }
                #[cfg(not(unix))]
                return Err(crate::DriverError::NamedPipesDisabled.into());
            } else {
                // `tcp_keepalive` is interpreted as a number of milliseconds.
                let keepalive = opts
                    .tcp_keepalive()
                    .map(|x| std::time::Duration::from_millis(x.into()));
                Stream::connect_tcp(opts.hostport_or_url(), keepalive).await?
            };

            conn.inner.stream = Some(stream);
            conn.setup_stream()?;
            conn.handle_handshake().await?;
            conn.switch_to_ssl_if_needed().await?;
            conn.do_handshake_response().await?;
            conn.continue_auth().await?;
            conn.switch_to_compression()?;
            conn.read_settings().await?;
            // NOTE(review): when this replaces `conn`, the replacement was built by a nested
            // `Conn::new`, which already ran the init and setup commands; the two calls below
            // then run them a second time on the same connection. Confirm this is intended.
            conn.reconnect_via_socket_if_needed().await?;
            conn.run_init_commands().await?;
            conn.run_setup_commands().await?;

            Ok(conn)
        }
        .boxed()
    }
955
956    /// Returns a future that resolves to [`Conn`].
957    pub async fn from_url<T: AsRef<str>>(url: T) -> Result<Conn> {
958        Conn::new(Opts::from_str(url.as_ref())?).await
959    }
960
961    /// Will try to reconnect via socket using socket address in `self.inner.socket`.
962    ///
963    /// Won't try to reconnect if socket connection is already enforced in [`Opts`].
964    async fn reconnect_via_socket_if_needed(&mut self) -> Result<()> {
965        if let Some(socket) = self.inner.socket.as_ref() {
966            let opts = self.inner.opts.clone();
967            if opts.socket().is_none() {
968                let opts = OptsBuilder::from_opts(opts).socket(Some(&**socket));
969                if let Ok(conn) = Conn::new(opts).await {
970                    let old_conn = std::mem::replace(self, conn);
971                    // tidy up the old connection
972                    old_conn.close_conn().await?;
973                }
974            }
975        }
976        Ok(())
977    }
978
    /// Configures the connection based on server settings. In particular:
    ///
    /// * It reads and stores the socket address inside the connection, unless
    ///   `self.inner.socket` is already set or `prefer_socket` is `false`.
    ///
    /// * It reads and stores `max_allowed_packet` in the connection unless it's already in [`Opts`]
    ///
    /// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`]
    ///
    /// All values that have to be read are fetched with a single `SELECT @@…` query. If that
    /// query returns no row, every loaded value is treated as `NULL`, which falls back to
    /// `DEFAULT_MAX_ALLOWED_PACKET`, `DEFAULT_WAIT_TIMEOUT`, or no socket.
    async fn read_settings(&mut self) -> Result<()> {
        // What to do for one setting: `Load` it from the server, or `Apply` a value that is
        // already known from `Opts`.
        enum Action {
            Load(Cfg),
            Apply(CfgData),
        }

        // A setting value taken from `Opts`.
        enum CfgData {
            MaxAllowedPacket(usize),
            WaitTimeout(usize),
        }

        impl CfgData {
            fn apply(&self, conn: &mut Conn) {
                match self {
                    Self::MaxAllowedPacket(value) => {
                        if let Some(stream) = conn.inner.stream.as_mut() {
                            stream.set_max_allowed_packet(*value);
                        }
                    }
                    Self::WaitTimeout(value) => {
                        // `wait_timeout` is expressed in seconds.
                        conn.inner.wait_timeout = Duration::from_secs(*value as u64);
                    }
                }
            }
        }

        // A setting that has to be loaded from the server.
        enum Cfg {
            Socket,
            MaxAllowedPacket,
            WaitTimeout,
        }

        impl Cfg {
            /// The system variable selected for this setting.
            const fn name(&self) -> &'static str {
                match self {
                    Self::Socket => "@@socket",
                    Self::MaxAllowedPacket => "@@max_allowed_packet",
                    Self::WaitTimeout => "@@wait_timeout",
                }
            }

            /// Stores the loaded `value`; `None`/`NULL` falls back to the defaults below.
            fn apply(&self, conn: &mut Conn, value: Option<crate::Value>) {
                match self {
                    Cfg::Socket => {
                        conn.inner.socket = value.and_then(crate::from_value);
                    }
                    Cfg::MaxAllowedPacket => {
                        if let Some(stream) = conn.inner.stream.as_mut() {
                            stream.set_max_allowed_packet(
                                value
                                    .and_then(crate::from_value)
                                    .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET),
                            );
                        }
                    }
                    Cfg::WaitTimeout => {
                        conn.inner.wait_timeout = Duration::from_secs(
                            value
                                .and_then(crate::from_value)
                                .unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64,
                        );
                    }
                }
            }
        }

        let mut actions = vec![
            if let Some(x) = self.opts().max_allowed_packet() {
                Action::Apply(CfgData::MaxAllowedPacket(x))
            } else {
                Action::Load(Cfg::MaxAllowedPacket)
            },
            if let Some(x) = self.opts().wait_timeout() {
                Action::Apply(CfgData::WaitTimeout(x))
            } else {
                Action::Load(Cfg::WaitTimeout)
            },
        ];

        if self.inner.opts.prefer_socket() && self.inner.socket.is_none() {
            actions.push(Action::Load(Cfg::Socket))
        }

        // The settings to query, in the same relative order as in `actions`.
        let loads = actions
            .iter()
            .filter_map(|x| match x {
                Action::Load(x) => Some(x),
                Action::Apply(_) => None,
            })
            .collect::<Vec<_>>();

        let loaded = if !loads.is_empty() {
            // Builds e.g. `SELECT @@max_allowed_packet,@@wait_timeout,@@socket`: the first
            // name is preceded by a space, every following one by a comma.
            let query = loads
                .iter()
                .zip(std::iter::once(' ').chain(std::iter::repeat(',')))
                .fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| {
                    acc.push(prefix);
                    acc.push_str(cfg.name());
                    acc
                });

            // One column per entry of `loads`; with no row, use NULL for each of them.
            self.query_internal::<Row, String>(query)
                .await?
                .map(|row| row.unwrap())
                .unwrap_or_else(|| vec![crate::Value::NULL; loads.len()])
        } else {
            vec![]
        };
        let mut loaded = loaded.into_iter();

        // `actions` is walked in the same order the columns were selected in, so each
        // `Load` consumes exactly its own column.
        for action in actions {
            match action {
                Action::Load(cfg) => cfg.apply(self, loaded.next()),
                Action::Apply(cfg) => cfg.apply(self),
            }
        }

        Ok(())
    }
1107
1108    /// Returns true if time since last IO exceeds `wait_timeout`
1109    /// (or `conn_ttl` if specified in opts).
1110    fn expired(&self) -> bool {
1111        if let Some(deadline) = self.inner.ttl_deadline {
1112            if Instant::now() > deadline {
1113                return true;
1114            }
1115        }
1116        let ttl = self
1117            .inner
1118            .opts
1119            .conn_ttl()
1120            .unwrap_or(self.inner.wait_timeout);
1121        !ttl.is_zero() && self.idling() > ttl
1122    }
1123
    /// Returns duration since last IO.
    ///
    /// `expired` compares this idle time against the connection TTL.
    fn idling(&self) -> Duration {
        self.inner.last_io.elapsed()
    }
1128
1129    /// Executes [`COM_RESET_CONNECTION`][1].
1130    ///
1131    /// Returns `false` if command is not supported (requires MySql >5.7.2, MariaDb >10.2.3).
1132    /// For older versions consider using [`Conn::change_user`].
1133    ///
1134    /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-reset-connection.html
1135    pub async fn reset(&mut self) -> Result<bool> {
1136        let supports_com_reset_connection = if self.inner.is_mariadb {
1137            self.inner.version >= (10, 2, 4)
1138        } else {
1139            // assuming mysql
1140            self.inner.version > (5, 7, 2)
1141        };
1142
1143        if supports_com_reset_connection {
1144            self.routine(routines::ResetRoutine).await?;
1145            self.inner.stmt_cache.clear();
1146            self.inner.infile_handler = None;
1147            self.run_setup_commands().await?;
1148        }
1149
1150        Ok(supports_com_reset_connection)
1151    }
1152
1153    /// Executes [`COM_CHANGE_USER`][1].
1154    ///
1155    /// This might be used as an older and slower alternative to `COM_RESET_CONNECTION` that
1156    /// works on MySql prior to 5.7.3 (MariaDb prior ot 10.2.4).
1157    ///
1158    /// ## Note
1159    ///
1160    /// * Using non-default `opts` for a pooled connection is discouraging.
1161    /// * Connection options will be permanently updated.
1162    ///
1163    /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html
1164    pub async fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
1165        // We'll kick this connection from a pool if opts are changed.
1166        if opts != ChangeUserOpts::default() {
1167            let mut opts_changed = false;
1168            if let Some(user) = opts.user() {
1169                opts_changed |= user != self.opts().user()
1170            };
1171            if let Some(pass) = opts.pass() {
1172                opts_changed |= pass != self.opts().pass()
1173            };
1174            if let Some(db_name) = opts.db_name() {
1175                opts_changed |= db_name != self.opts().db_name()
1176            };
1177            if opts_changed {
1178                if let Some(pool) = self.inner.pool.take() {
1179                    pool.cancel_connection();
1180                }
1181            }
1182        }
1183
1184        let conn_opts = &mut self.inner.opts;
1185        opts.update_opts(conn_opts);
1186        self.routine(routines::ChangeUser).await?;
1187        self.inner.stmt_cache.clear();
1188        self.inner.infile_handler = None;
1189        self.run_setup_commands().await?;
1190        Ok(())
1191    }
1192
1193    /// Resets the connection upon returning it to a pool.
1194    ///
1195    /// Will invoke `COM_CHANGE_USER` if `COM_RESET_CONNECTION` is not supported.
1196    async fn reset_for_pool(mut self) -> Result<Self> {
1197        if !self.reset().await? {
1198            self.change_user(Default::default()).await?;
1199        }
1200        Ok(self)
1201    }
1202
    /// Rolls back the current transaction by issuing `ROLLBACK`.
    ///
    /// Requires that `self.inner.tx_status != TxStatus::None`
    /// (checked only in debug builds).
    async fn rollback_transaction(&mut self) -> Result<()> {
        debug_assert_ne!(self.inner.tx_status, TxStatus::None);
        // The status is cleared before the query is sent, so it stays `None` even if the
        // `ROLLBACK` query itself fails.
        self.inner.tx_status = TxStatus::None;
        self.query_drop("ROLLBACK").await
    }
1209
1210    /// Returns `true` if `SERVER_MORE_RESULTS_EXISTS` flag is contained
1211    /// in status flags of the connection.
1212    pub(crate) fn more_results_exists(&self) -> bool {
1213        self.status()
1214            .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
1215    }
1216
    /// The purpose of this function is to cleanup a pending result set
    /// for a prematurely dropped connection or query result.
    ///
    /// Requires that there are no other references to the pending result.
    ///
    /// # Panics
    ///
    /// Panics if the pending result is `PendingResult::Taken` and its `Arc` still has other
    /// owners.
    pub(crate) async fn drop_result(&mut self) -> Result<()> {
        // Map everything into `PendingResult::Pending`: take the current pending result out of
        // the connection (`set_pending_result(None)` yields the previous one) and recover the
        // owned `ResultSetMeta` from it.
        let meta = match self.set_pending_result(None)? {
            Some(PendingResult::Pending(meta)) => Some(meta),
            Some(PendingResult::Taken(meta)) => {
                // This also asserts that there is only one reference left to the taken ResultSetMeta,
                // therefore this result set must be dropped here since it won't be dropped anywhere else.
                Some(Arc::try_unwrap(meta).expect("Conn::drop_result call on a pending result that may still be droped by someone else"))
            }
            None => None,
        };

        // Put the metadata back so `use_pending_result` below sees it as pending; the returned
        // value is deliberately ignored.
        let _ = self.set_pending_result(meta);

        // Drain the pending result set with the protocol it was produced by.
        match self.use_pending_result() {
            Ok(Some(PendingResult::Pending(ResultSetMeta::Text(_)))) => {
                QueryResult::<'_, '_, TextProtocol>::new(self)
                    .drop_result()
                    .await
            }
            Ok(Some(PendingResult::Pending(ResultSetMeta::Binary(_)))) => {
                QueryResult::<'_, '_, BinaryProtocol>::new(self)
                    .drop_result()
                    .await
            }
            Ok(None) => Ok((/* this case does not require an action */)),
            Ok(Some(PendingResult::Taken(_))) | Err(_) => {
                unreachable!("this case must be handled earlier in this function")
            }
        }
    }
1252
1253    /// This function will drop pending result and rollback a transaction, if needed.
1254    ///
1255    /// The purpose of this function, is to cleanup the connection while returning it to a [`Pool`].
1256    async fn cleanup_for_pool(mut self) -> Result<Self> {
1257        loop {
1258            let result = if self.has_pending_result() {
1259                self.drop_result().await
1260            } else if self.inner.tx_status != TxStatus::None {
1261                self.rollback_transaction().await
1262            } else {
1263                break;
1264            };
1265
1266            // The connection was dropped and we assume that it was dropped intentionally,
1267            // so we'll ignore non-fatal errors during cleanup (also there is no direct caller
1268            // to return this error to).
1269            if let Err(err) = result {
1270                if err.is_fatal() {
1271                    // This means that connection is completely broken
1272                    // and shouldn't return to a pool.
1273                    return Err(err);
1274                }
1275            }
1276        }
1277        Ok(self)
1278    }
1279}
1280
1281#[cfg(test)]
1282mod test {
1283    use bytes::Bytes;
1284    use futures_util::stream::{self, StreamExt};
1285    use mysql_common::constants::MAX_PAYLOAD_LEN;
1286    use rand::Fill;
1287    use tokio::{io::AsyncWriteExt, net::TcpListener};
1288
1289    use crate::{
1290        from_row, params, prelude::*, test_misc::get_opts, ChangeUserOpts, Conn, Error,
1291        OptsBuilder, Pool, ServerError, Value, WhiteListFsHandler,
1292    };
1293
1294    #[tokio::test]
1295    async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> {
1296        let opts = get_opts().client_found_rows(true);
1297        let mut conn = Conn::new(opts).await.unwrap();
1298
1299        "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
1300            .ignore(&mut conn)
1301            .await?;
1302
1303        "INSERT INTO mysql.found_rows (val) VALUES (1)"
1304            .ignore(&mut conn)
1305            .await?;
1306
1307        // Inserted one row, affected should be one.
1308        assert_eq!(conn.affected_rows(), 1);
1309
1310        "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
1311            .ignore(&mut conn)
1312            .await?;
1313
1314        // The query doesn't affect any rows, but due to us wanting FOUND rows,
1315        // this has to return one.
1316        assert_eq!(conn.affected_rows(), 1);
1317
1318        Ok(())
1319    }
1320
1321    #[tokio::test]
1322    async fn should_not_return_found_rows_if_flag_is_not_set() -> super::Result<()> {
1323        let mut conn = Conn::new(get_opts()).await.unwrap();
1324
1325        "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
1326            .ignore(&mut conn)
1327            .await?;
1328
1329        "INSERT INTO mysql.found_rows (val) VALUES (1)"
1330            .ignore(&mut conn)
1331            .await?;
1332
1333        // Inserted one row, affected should be one.
1334        assert_eq!(conn.affected_rows(), 1);
1335
1336        "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
1337            .ignore(&mut conn)
1338            .await?;
1339
1340        // The query doesn't affect any rows.
1341        assert_eq!(conn.affected_rows(), 0);
1342
1343        Ok(())
1344    }
1345
1346    #[test]
1347    fn opts_should_satisfy_send_and_sync() {
1348        struct A<T: Sync + Send>(T);
1349        #[allow(clippy::unnecessary_operation)]
1350        A(get_opts());
1351    }
1352
1353    #[tokio::test]
1354    async fn should_connect_without_database() -> super::Result<()> {
1355        // no database name
1356        let mut conn: Conn = Conn::new(get_opts().db_name(None::<String>)).await?;
1357        conn.ping().await?;
1358        conn.disconnect().await?;
1359
1360        // empty database name
1361        let mut conn: Conn = Conn::new(get_opts().db_name(Some(""))).await?;
1362        conn.ping().await?;
1363        conn.disconnect().await?;
1364
1365        Ok(())
1366    }
1367
1368    #[tokio::test]
1369    async fn should_clean_state_if_wrapper_is_dropeed() -> super::Result<()> {
1370        let mut conn: Conn = Conn::new(get_opts()).await?;
1371
1372        conn.query_drop("CREATE TEMPORARY TABLE mysql.foo (id SERIAL)")
1373            .await?;
1374
1375        // dropped query:
1376        conn.query_iter("SELECT 1").await?;
1377        conn.ping().await?;
1378
1379        // dropped query in dropped transaction:
1380        let mut tx = conn.start_transaction(Default::default()).await?;
1381        tx.query_drop("INSERT INTO mysql.foo (id) VALUES (42)")
1382            .await?;
1383        tx.exec_iter("SELECT COUNT(*) FROM mysql.foo", ()).await?;
1384        drop(tx);
1385        conn.ping().await?;
1386
1387        let count: u8 = conn
1388            .query_first("SELECT COUNT(*) FROM mysql.foo")
1389            .await?
1390            .unwrap_or_default();
1391
1392        assert_eq!(count, 0);
1393
1394        Ok(())
1395    }
1396
1397    #[tokio::test]
1398    async fn should_connect() -> super::Result<()> {
1399        let mut conn: Conn = Conn::new(get_opts()).await?;
1400        conn.ping().await?;
1401        let plugins: Vec<String> = conn
1402            .query_map("SHOW PLUGINS", |mut row: crate::Row| {
1403                row.take("Name").unwrap()
1404            })
1405            .await?;
1406
1407        // Should connect with any combination of supported plugin and empty-nonempty password.
1408        let variants = vec![
1409            ("caching_sha2_password", 2_u8, "non-empty"),
1410            ("caching_sha2_password", 2_u8, ""),
1411            ("mysql_native_password", 0_u8, "non-empty"),
1412            ("mysql_native_password", 0_u8, ""),
1413        ]
1414        .into_iter()
1415        .filter(|variant| plugins.iter().any(|p| p == variant.0));
1416
1417        for (plug, val, pass) in variants {
1418            dbg!((plug, val, pass, conn.inner.version));
1419
1420            if plug == "mysql_native_password" && conn.inner.version >= (8, 4, 0) {
1421                continue;
1422            }
1423
1424            let _ = conn.query_drop("DROP USER 'test_user'@'%'").await;
1425
1426            let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug);
1427            conn.query_drop(query).await.unwrap();
1428
1429            if conn.inner.version < (8, 0, 11) {
1430                conn.query_drop(format!("SET old_passwords = {}", val))
1431                    .await
1432                    .unwrap();
1433                conn.query_drop(format!(
1434                    "SET PASSWORD FOR 'test_user'@'%' = PASSWORD('{}')",
1435                    pass
1436                ))
1437                .await
1438                .unwrap();
1439            } else {
1440                conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass))
1441                    .await
1442                    .unwrap();
1443            };
1444
1445            let opts = get_opts()
1446                .user(Some("test_user"))
1447                .pass(Some(pass))
1448                .db_name(None::<String>);
1449            let result = Conn::new(opts).await;
1450
1451            conn.query_drop("DROP USER 'test_user'@'%'").await.unwrap();
1452
1453            result?.disconnect().await?;
1454        }
1455
1456        if crate::test_misc::test_compression() {
1457            assert!(format!("{:?}", conn).contains("Compression"));
1458        }
1459
1460        if crate::test_misc::test_ssl() {
1461            assert!(format!("{:?}", conn).contains("Tls"));
1462        }
1463
1464        conn.disconnect().await?;
1465        Ok(())
1466    }
1467
1468    #[test]
1469    fn should_not_panic_if_dropped_without_tokio_runtime() {
1470        let fut = Conn::new(get_opts());
1471        let runtime = tokio::runtime::Runtime::new().unwrap();
1472        runtime.block_on(async {
1473            fut.await.unwrap();
1474        });
1475        // connection will drop here
1476    }
1477
1478    #[tokio::test]
1479    async fn should_execute_init_queries_on_new_connection() -> super::Result<()> {
1480        let opts = OptsBuilder::from_opts(get_opts()).init(vec!["SET @a = 42", "SET @b = 'foo'"]);
1481        let mut conn = Conn::new(opts).await?;
1482        let result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
1483        conn.disconnect().await?;
1484        assert_eq!(result, vec![(42, "foo".into())]);
1485        Ok(())
1486    }
1487
1488    #[tokio::test]
1489    async fn should_execute_setup_queries_on_reset() -> super::Result<()> {
1490        let opts = OptsBuilder::from_opts(get_opts()).setup(vec!["SET @a = 42", "SET @b = 'foo'"]);
1491        let mut conn = Conn::new(opts).await?;
1492
1493        // initial run
1494        let mut result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
1495        assert_eq!(result, vec![(42, "foo".into())]);
1496
1497        // after reset
1498        if conn.reset().await? {
1499            result = conn.query("SELECT @a, @b").await?;
1500            assert_eq!(result, vec![(42, "foo".into())]);
1501        }
1502
1503        // after change user
1504        conn.change_user(Default::default()).await?;
1505        result = conn.query("SELECT @a, @b").await?;
1506        assert_eq!(result, vec![(42, "foo".into())]);
1507
1508        conn.disconnect().await?;
1509        Ok(())
1510    }
1511
1512    #[tokio::test]
1513    async fn should_reset_the_connection() -> super::Result<()> {
1514        let mut conn = Conn::new(get_opts()).await?;
1515
1516        assert_eq!(
1517            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1518            Value::NULL
1519        );
1520
1521        conn.query_drop("SET @foo = 'foo'").await?;
1522
1523        assert_eq!(
1524            conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1525            "foo",
1526        );
1527
1528        if conn.reset().await? {
1529            assert_eq!(
1530                conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1531                Value::NULL
1532            );
1533        } else {
1534            assert_eq!(
1535                conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1536                "foo",
1537            );
1538        }
1539
1540        conn.disconnect().await?;
1541        Ok(())
1542    }
1543
1544    #[tokio::test]
1545    async fn should_change_user() -> super::Result<()> {
1546        let mut conn = Conn::new(get_opts()).await?;
1547        assert_eq!(
1548            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1549            Value::NULL
1550        );
1551
1552        conn.query_drop("SET @foo = 'foo'").await?;
1553
1554        assert_eq!(
1555            conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1556            "foo",
1557        );
1558
1559        conn.change_user(Default::default()).await?;
1560        assert_eq!(
1561            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1562            Value::NULL
1563        );
1564
1565        let plugins: &[&str] = if !conn.inner.is_mariadb && conn.server_version() >= (5, 8, 0) {
1566            &["mysql_native_password", "caching_sha2_password"]
1567        } else {
1568            &["mysql_native_password"]
1569        };
1570
1571        for (i, plugin) in plugins.iter().enumerate() {
1572            if *plugin == "mysql_native_password" && conn.server_version() >= (8, 4, 0) {
1573                continue;
1574            }
1575
1576            let mut rng = rand::thread_rng();
1577            let mut pass = [0u8; 10];
1578            pass.try_fill(&mut rng).unwrap();
1579            let pass: String = IntoIterator::into_iter(pass)
1580                .map(|x| ((x % (123 - 97)) + 97) as char)
1581                .collect();
1582
1583            let result = conn
1584                .query_drop("DROP USER /*!50700 IF EXISTS */ /*M!100103 IF EXISTS */ __mats")
1585                .await;
1586            if matches!(conn.server_version(), (5, 6, _)) && i == 0 {
1587                // IF EXISTS is not supported on 5.6 so the query will fail on the first iteration
1588                drop(result);
1589            } else {
1590                result.unwrap();
1591            }
1592
1593            if conn.inner.is_mariadb || conn.server_version() < (5, 7, 0) {
1594                if matches!(conn.server_version(), (5, 6, _)) {
1595                    conn.query_drop("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password")
1596                        .await
1597                        .unwrap();
1598                    conn.query_drop(format!(
1599                        "SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD({})",
1600                        Value::from(pass.clone()).as_sql(false)
1601                    ))
1602                    .await
1603                    .unwrap();
1604                } else {
1605                    conn.query_drop("CREATE USER '__mats'@'%'").await.unwrap();
1606                    conn.query_drop(format!(
1607                        "SET PASSWORD FOR '__mats'@'%' = PASSWORD({})",
1608                        Value::from(pass.clone()).as_sql(false)
1609                    ))
1610                    .await
1611                    .unwrap();
1612                }
1613            } else {
1614                conn.query_drop(format!(
1615                    "CREATE USER '__mats'@'%' IDENTIFIED WITH {} BY {}",
1616                    plugin,
1617                    Value::from(pass.clone()).as_sql(false)
1618                ))
1619                .await
1620                .unwrap();
1621            };
1622
1623            let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap();
1624            conn2
1625                .change_user(
1626                    ChangeUserOpts::default()
1627                        .with_db_name(None)
1628                        .with_user(Some("__mats".into()))
1629                        .with_pass(Some(pass)),
1630                )
1631                .await
1632                .unwrap();
1633            let (db, user) = conn2
1634                .query_first::<(Option<String>, String), _>("SELECT DATABASE(), USER();")
1635                .await
1636                .unwrap()
1637                .unwrap();
1638            assert_eq!(db, None);
1639            assert!(user.starts_with("__mats"));
1640
1641            conn2.disconnect().await.unwrap();
1642        }
1643
1644        conn.disconnect().await?;
1645        Ok(())
1646    }
1647
1648    #[tokio::test]
1649    async fn should_not_cache_statements_if_stmt_cache_size_is_zero() -> super::Result<()> {
1650        let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0);
1651
1652        let mut conn = Conn::new(opts).await?;
1653        conn.exec_drop("DO ?", (1_u8,)).await?;
1654
1655        let stmt = conn.prep("DO 2").await?;
1656        conn.exec_drop(&stmt, ()).await?;
1657        conn.exec_drop(&stmt, ()).await?;
1658        conn.close(stmt).await?;
1659
1660        conn.exec_drop("DO 3", ()).await?;
1661        conn.exec_batch("DO 4", vec![(), ()]).await?;
1662        conn.exec_first::<u8, _, _>("DO 5", ()).await?;
1663        let row: Option<(crate::Value, usize)> = conn
1664            .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
1665            .await?;
1666
1667        assert_eq!(row.unwrap().1, 1);
1668        assert_eq!(conn.inner.stmt_cache.len(), 0);
1669
1670        conn.disconnect().await?;
1671
1672        Ok(())
1673    }
1674
1675    #[tokio::test]
1676    async fn should_hold_stmt_cache_size_bound() -> super::Result<()> {
1677        let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(3);
1678        let mut conn = Conn::new(opts).await?;
1679        conn.exec_drop("DO 1", ()).await?;
1680        conn.exec_drop("DO 2", ()).await?;
1681        conn.exec_drop("DO 3", ()).await?;
1682        conn.exec_drop("DO 1", ()).await?;
1683        conn.exec_drop("DO 4", ()).await?;
1684        conn.exec_drop("DO 3", ()).await?;
1685        conn.exec_drop("DO 5", ()).await?;
1686        conn.exec_drop("DO 6", ()).await?;
1687        let row_opt = conn
1688            .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
1689            .await?;
1690        let (_, count): (String, usize) = row_opt.unwrap();
1691        assert_eq!(count, 3);
1692        let order = conn
1693            .stmt_cache_ref()
1694            .iter()
1695            .map(|item| item.1.query.0.as_ref())
1696            .collect::<Vec<&[u8]>>();
1697        assert_eq!(order, &[b"DO 6", b"DO 5", b"DO 3"]);
1698        conn.disconnect().await?;
1699        Ok(())
1700    }
1701
1702    #[tokio::test]
1703    async fn should_perform_queries() -> super::Result<()> {
1704        let mut conn = Conn::new(get_opts()).await?;
1705        for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) {
1706            let long_string = "A".repeat(x);
1707            let result: Vec<(String, u8)> = conn
1708                .query(format!(r"SELECT '{}', 231", long_string))
1709                .await?;
1710            assert_eq!((long_string, 231_u8), result[0]);
1711        }
1712        conn.disconnect().await?;
1713        Ok(())
1714    }
1715
1716    #[tokio::test]
1717    async fn should_query_drop() -> super::Result<()> {
1718        let mut conn = Conn::new(get_opts()).await?;
1719        conn.query_drop("CREATE TEMPORARY TABLE tmp (id int DEFAULT 10, name text)")
1720            .await?;
1721        conn.query_drop("INSERT INTO tmp VALUES (1, 'foo')").await?;
1722        let result: Option<u8> = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1723        conn.disconnect().await?;
1724        assert_eq!(result, Some(1_u8));
1725        Ok(())
1726    }
1727
1728    #[tokio::test]
1729    async fn should_prepare_statement() -> super::Result<()> {
1730        let mut conn = Conn::new(get_opts()).await?;
1731        let stmt = conn.prep(r"SELECT ?").await?;
1732        conn.close(stmt).await?;
1733        conn.disconnect().await?;
1734
1735        let mut conn = Conn::new(get_opts()).await?;
1736        let stmt = conn.prep(r"SELECT :foo").await?;
1737
1738        {
1739            let query = String::from("SELECT ?, ?");
1740            let stmt = conn.prep(&*query).await?;
1741            conn.close(stmt).await?;
1742            {
1743                let mut conn = Conn::new(get_opts()).await?;
1744                let stmt = conn.prep(&*query).await?;
1745                conn.close(stmt).await?;
1746                conn.disconnect().await?;
1747            }
1748        }
1749
1750        conn.close(stmt).await?;
1751        conn.disconnect().await?;
1752
1753        Ok(())
1754    }
1755
1756    #[tokio::test]
1757    async fn should_execute_statement() -> super::Result<()> {
1758        let long_string = "A".repeat(18 * 1024 * 1024);
1759        let mut conn = Conn::new(get_opts()).await?;
1760        let stmt = conn.prep(r"SELECT ?").await?;
1761        let result = conn.exec_iter(&stmt, (&long_string,)).await?;
1762        let mut mapped = result.map_and_drop(from_row::<(String,)>).await?;
1763        assert_eq!(mapped.len(), 1);
1764        assert_eq!(mapped.pop(), Some((long_string,)));
1765        let result = conn.exec_iter(&stmt, (42_u8,)).await?;
1766        let collected = result.collect_and_drop::<(u8,)>().await?;
1767        assert_eq!(collected, vec![(42u8,)]);
1768        let result = conn.exec_iter(&stmt, (8_u8,)).await?;
1769        let reduced = result
1770            .reduce_and_drop(2, |mut acc, row| {
1771                acc += from_row::<i32>(row);
1772                acc
1773            })
1774            .await?;
1775        conn.close(stmt).await?;
1776        conn.disconnect().await?;
1777        assert_eq!(reduced, 10);
1778
1779        let mut conn = Conn::new(get_opts()).await?;
1780        let stmt = conn.prep(r"SELECT :foo, :bar, :foo, 3").await?;
1781        let result = conn
1782            .exec_iter(&stmt, params! { "foo" => "quux", "bar" => "baz" })
1783            .await?;
1784        let mut mapped = result
1785            .map_and_drop(from_row::<(String, String, String, u8)>)
1786            .await?;
1787        assert_eq!(mapped.len(), 1);
1788        assert_eq!(
1789            mapped.pop(),
1790            Some(("quux".into(), "baz".into(), "quux".into(), 3))
1791        );
1792        let result = conn
1793            .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1794            .await?;
1795        let collected = result.collect_and_drop::<(u8, u8, u8, u8)>().await?;
1796        assert_eq!(collected, vec![(2, 3, 2, 3)]);
1797        let result = conn
1798            .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1799            .await?;
1800        let reduced = result
1801            .reduce_and_drop(0, |acc, row| {
1802                let (a, b, c, d): (u8, u8, u8, u8) = from_row(row);
1803                acc + a + b + c + d
1804            })
1805            .await?;
1806        conn.close(stmt).await?;
1807        conn.disconnect().await?;
1808        assert_eq!(reduced, 10);
1809        Ok(())
1810    }
1811
1812    #[tokio::test]
1813    async fn should_prep_exec_statement() -> super::Result<()> {
1814        let mut conn = Conn::new(get_opts()).await?;
1815        let result = conn
1816            .exec_iter(r"SELECT :a, :b, :a", params! { "a" => 2, "b" => 3 })
1817            .await?;
1818        let output = result
1819            .map_and_drop(|row| {
1820                let (a, b, c): (u8, u8, u8) = from_row(row);
1821                a * b * c
1822            })
1823            .await?;
1824        conn.disconnect().await?;
1825        assert_eq!(output[0], 12u8);
1826        Ok(())
1827    }
1828
1829    #[tokio::test]
1830    async fn should_first_exec_statement() -> super::Result<()> {
1831        let mut conn = Conn::new(get_opts()).await?;
1832        let output = conn
1833            .exec_first(
1834                r"SELECT :a UNION ALL SELECT :b",
1835                params! { "a" => 2, "b" => 3 },
1836            )
1837            .await?;
1838        conn.disconnect().await?;
1839        assert_eq!(output, Some(2u8));
1840        Ok(())
1841    }
1842
1843    #[tokio::test]
1844    async fn issue_107() -> super::Result<()> {
1845        let mut conn = Conn::new(get_opts()).await?;
1846        conn.query_drop(
1847            r"CREATE TEMPORARY TABLE mysql.issue (
1848                    a BIGINT(20) UNSIGNED,
1849                    b VARBINARY(16),
1850                    c BINARY(32),
1851                    d BIGINT(20) UNSIGNED,
1852                    e BINARY(32)
1853                )",
1854        )
1855        .await?;
1856        conn.query_drop(
1857            r"INSERT INTO mysql.issue VALUES (
1858                    0,
1859                    0xC066F966B0860000,
1860                    0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
1861                    0,
1862                    ''
1863                ), (
1864                    1,
1865                    '',
1866                    0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
1867                    0,
1868                    ''
1869                )",
1870        )
1871        .await?;
1872
1873        let q = "SELECT b, c, d, e FROM mysql.issue";
1874        let result = conn.query_iter(q).await?;
1875
1876        let loaded_structs = result
1877            .map_and_drop(crate::from_row::<(Vec<u8>, Vec<u8>, u64, Vec<u8>)>)
1878            .await?;
1879
1880        conn.disconnect().await?;
1881
1882        assert_eq!(loaded_structs.len(), 2);
1883
1884        Ok(())
1885    }
1886
1887    #[tokio::test]
1888    async fn should_run_transactions() -> super::Result<()> {
1889        let mut conn = Conn::new(get_opts()).await?;
1890        conn.query_drop("CREATE TEMPORARY TABLE tmp (id INT, name TEXT)")
1891            .await?;
1892        let mut transaction = conn.start_transaction(Default::default()).await?;
1893        transaction
1894            .query_drop("INSERT INTO tmp VALUES (1, 'foo'), (2, 'bar')")
1895            .await?;
1896        assert_eq!(transaction.last_insert_id(), None);
1897        assert_eq!(transaction.affected_rows(), 2);
1898        assert_eq!(transaction.get_warnings(), 0);
1899        assert_eq!(transaction.info(), "Records: 2  Duplicates: 0  Warnings: 0");
1900        transaction.commit().await?;
1901        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1902        assert_eq!(output_opt, Some((2u8,)));
1903        let mut transaction = conn.start_transaction(Default::default()).await?;
1904        transaction
1905            .query_drop("INSERT INTO tmp VALUES (3, 'baz'), (4, 'quux')")
1906            .await?;
1907        let output_opt = transaction
1908            .exec_first("SELECT COUNT(*) FROM tmp", ())
1909            .await?;
1910        assert_eq!(output_opt, Some((4u8,)));
1911        transaction.rollback().await?;
1912        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1913        assert_eq!(output_opt, Some((2u8,)));
1914
1915        let mut transaction = conn.start_transaction(Default::default()).await?;
1916        transaction
1917            .query_drop("INSERT INTO tmp VALUES (3, 'baz')")
1918            .await?;
1919        drop(transaction); // implicit rollback
1920        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1921        assert_eq!(output_opt, Some((2u8,)));
1922
1923        conn.disconnect().await?;
1924        Ok(())
1925    }
1926
1927    #[tokio::test]
1928    async fn should_handle_multiresult_set_with_error() -> super::Result<()> {
1929        const QUERY_FIRST: &str = "SELECT * FROM tmp; SELECT 1; SELECT 2;";
1930        const QUERY_MIDDLE: &str = "SELECT 1; SELECT * FROM tmp; SELECT 2";
1931        let mut conn = Conn::new(get_opts()).await.unwrap();
1932
1933        // if error is in the first result set, then query should return it immediately.
1934        let result = QUERY_FIRST.run(&mut conn).await;
1935        assert!(matches!(result, Err(Error::Server(_))));
1936
1937        let mut result = QUERY_MIDDLE.run(&mut conn).await.unwrap();
1938
1939        // first result set will contain one row
1940        let result_set: Vec<u8> = result.collect().await.unwrap();
1941        assert_eq!(result_set, vec![1]);
1942
1943        // second result set will contain an error.
1944        let result_set: super::Result<Vec<u8>> = result.collect().await;
1945        assert!(matches!(result_set, Err(Error::Server(_))));
1946
1947        // there will be no third result set
1948        assert!(result.is_empty());
1949
1950        conn.ping().await?;
1951        conn.disconnect().await?;
1952
1953        Ok(())
1954    }
1955
1956    #[tokio::test]
1957    async fn should_handle_binary_multiresult_set_with_error() -> super::Result<()> {
1958        const PROC_DEF_FIRST: &str =
1959            r#"CREATE PROCEDURE err_first() BEGIN SELECT * FROM tmp; SELECT 1; END"#;
1960        const PROC_DEF_MIDDLE: &str =
1961            r#"CREATE PROCEDURE err_middle() BEGIN SELECT 1; SELECT * FROM tmp; SELECT 2; END"#;
1962
1963        let mut conn = Conn::new(get_opts()).await.unwrap();
1964
1965        conn.query_drop("DROP PROCEDURE IF EXISTS err_first")
1966            .await?;
1967        conn.query_iter(PROC_DEF_FIRST).await?;
1968
1969        conn.query_drop("DROP PROCEDURE IF EXISTS err_middle")
1970            .await?;
1971        conn.query_iter(PROC_DEF_MIDDLE).await?;
1972
1973        // if error is in the first result set, then query should return it immediately.
1974        let result = conn.query_iter("CALL err_first()").await;
1975        assert!(matches!(result, Err(Error::Server(_))));
1976
1977        let mut result = conn.query_iter("CALL err_middle()").await?;
1978
1979        // first result set will contain one row
1980        let result_set: Vec<u8> = result.collect().await.unwrap();
1981        assert_eq!(result_set, vec![1]);
1982
1983        // second result set will contain an error.
1984        let result_set: super::Result<Vec<u8>> = result.collect().await;
1985        assert!(matches!(result_set, Err(Error::Server(_))));
1986
1987        // there will be no third result set
1988        assert!(result.is_empty());
1989
1990        conn.ping().await?;
1991        conn.disconnect().await?;
1992
1993        Ok(())
1994    }
1995
1996    #[tokio::test]
1997    async fn should_handle_multiresult_set_with_local_infile() -> super::Result<()> {
1998        use std::fs::write;
1999
2000        let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
2001        let file_path = file_path.path();
2002        let file_name = file_path.file_name().unwrap();
2003
2004        write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
2005
2006        let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
2007
2008        // LOCAL INFILE in the middle of a multi-result set should not break anything.
2009        let mut conn = Conn::new(opts).await.unwrap();
2010        "CREATE TEMPORARY TABLE tmp (a TEXT)".run(&mut conn).await?;
2011
2012        let query = format!(
2013            r#"SELECT * FROM tmp;
2014            LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
2015            LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
2016            SELECT * FROM tmp"#,
2017            file_name.to_str().unwrap(),
2018            file_name.to_str().unwrap(),
2019        );
2020
2021        let mut result = query.run(&mut conn).await?;
2022
2023        let result_set = result.collect::<String>().await?;
2024        assert_eq!(result_set.len(), 0);
2025
2026        let mut no_local_infile = false;
2027
2028        for _ in 0..2 {
2029            match result.collect::<String>().await {
2030                Ok(result_set) => {
2031                    assert_eq!(result.affected_rows(), 3);
2032                    assert!(result_set.is_empty())
2033                }
2034                Err(Error::Server(ref err)) if err.code == 1148 => {
2035                    // The used command is not allowed with this MySQL version
2036                    no_local_infile = true;
2037                    break;
2038                }
2039                Err(Error::Server(ref err)) if err.code == 3948 => {
2040                    // Loading local data is disabled;
2041                    // this must be enabled on both the client and server sides
2042                    no_local_infile = true;
2043                    break;
2044                }
2045                Err(err) => return Err(err),
2046            }
2047        }
2048
2049        if no_local_infile {
2050            assert!(result.is_empty());
2051            assert_eq!(result_set.len(), 0);
2052        } else {
2053            let result_set = result.collect::<String>().await?;
2054            assert_eq!(result_set.len(), 6);
2055            assert_eq!(result_set[0], "AAAAAA");
2056            assert_eq!(result_set[1], "BBBBBB");
2057            assert_eq!(result_set[2], "CCCCCC");
2058            assert_eq!(result_set[3], "AAAAAA");
2059            assert_eq!(result_set[4], "BBBBBB");
2060            assert_eq!(result_set[5], "CCCCCC");
2061        }
2062
2063        conn.ping().await?;
2064        conn.disconnect().await?;
2065
2066        Ok(())
2067    }
2068
2069    #[tokio::test]
2070    async fn should_provide_multiresult_set_metadata() -> super::Result<()> {
2071        let mut c = Conn::new(get_opts()).await?;
2072        c.query_drop("CREATE TEMPORARY TABLE tmp (id INT, foo TEXT)")
2073            .await?;
2074
2075        let mut result = c
2076            .query_iter("SELECT 1; SELECT id, foo FROM tmp WHERE 1 = 2; DO 42; SELECT 2;")
2077            .await?;
2078        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2079
2080        result.for_each(drop).await?;
2081        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 2);
2082
2083        result.for_each(drop).await?;
2084        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 0);
2085
2086        result.for_each(drop).await?;
2087        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2088
2089        c.disconnect().await?;
2090        Ok(())
2091    }
2092
2093    #[tokio::test]
2094    async fn should_expose_query_result_metadata() -> super::Result<()> {
2095        let pool = Pool::new(get_opts());
2096        let mut c = pool.get_conn().await?;
2097
2098        c.query_drop(
2099            r"
2100            CREATE TEMPORARY TABLE `foo`
2101                ( `id` SERIAL
2102                , `bar_id` varchar(36) NOT NULL
2103                , `baz_id` varchar(36) NOT NULL
2104                , `ctime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP()
2105                , PRIMARY KEY (`id`)
2106                , KEY `bar_idx` (`bar_id`)
2107                , KEY `baz_idx` (`baz_id`)
2108            );",
2109        )
2110        .await?;
2111
2112        const QUERY: &str = "INSERT INTO foo (bar_id, baz_id) VALUES (?, ?)";
2113        let params = ("qwerty", "data.employee_id");
2114
2115        let query_result = c.exec_iter(QUERY, params).await?;
2116        assert_eq!(query_result.last_insert_id(), Some(1));
2117        query_result.drop_result().await?;
2118
2119        c.exec_drop(QUERY, params).await?;
2120        assert_eq!(c.last_insert_id(), Some(2));
2121
2122        let mut tx = c.start_transaction(Default::default()).await?;
2123
2124        tx.exec_drop(QUERY, params).await?;
2125        assert_eq!(tx.last_insert_id(), Some(3));
2126
2127        Ok(())
2128    }
2129
2130    #[tokio::test]
2131    async fn should_handle_local_infile_locally() -> super::Result<()> {
2132        let mut conn = Conn::new(get_opts()).await.unwrap();
2133        conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2134            .await
2135            .unwrap();
2136
2137        conn.set_infile_handler(async move {
2138            Ok(
2139                stream::iter([Bytes::from("AAAAAA\n"), Bytes::from("BBBBBB\nCCCCCC\n")])
2140                    .map(Ok)
2141                    .boxed(),
2142            )
2143        });
2144
2145        match conn
2146            .query_drop(r#"LOAD DATA LOCAL INFILE "dummy" INTO TABLE tmp;"#)
2147            .await
2148        {
2149            Ok(_) => (),
2150            Err(super::Error::Server(ref err)) if err.code == 1148 => {
2151                // The used command is not allowed with this MySQL version
2152                return Ok(());
2153            }
2154            Err(super::Error::Server(ref err)) if err.code == 3948 => {
2155                // Loading local data is disabled;
2156                // this must be enabled on both the client and server sides
2157                return Ok(());
2158            }
2159            e @ Err(_) => e.unwrap(),
2160        };
2161
2162        let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2163        assert_eq!(result.len(), 3);
2164        assert_eq!(result[0], "AAAAAA");
2165        assert_eq!(result[1], "BBBBBB");
2166        assert_eq!(result[2], "CCCCCC");
2167
2168        Ok(())
2169    }
2170
2171    #[tokio::test]
2172    async fn should_handle_local_infile_globally() -> super::Result<()> {
2173        use std::fs::write;
2174
2175        let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
2176        let file_path = file_path.path();
2177        let file_name = file_path.file_name().unwrap();
2178
2179        write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
2180
2181        let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
2182
2183        let mut conn = Conn::new(opts).await.unwrap();
2184        conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2185            .await
2186            .unwrap();
2187
2188        match conn
2189            .query_drop(format!(
2190                r#"LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;"#,
2191                file_name.to_str().unwrap(),
2192            ))
2193            .await
2194        {
2195            Ok(_) => (),
2196            Err(super::Error::Server(ref err)) if err.code == 1148 => {
2197                // The used command is not allowed with this MySQL version
2198                return Ok(());
2199            }
2200            Err(super::Error::Server(ref err)) if err.code == 3948 => {
2201                // Loading local data is disabled;
2202                // this must be enabled on both the client and server sides
2203                return Ok(());
2204            }
2205            e @ Err(_) => e.unwrap(),
2206        };
2207
2208        let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2209        assert_eq!(result.len(), 3);
2210        assert_eq!(result[0], "AAAAAA");
2211        assert_eq!(result[1], "BBBBBB");
2212        assert_eq!(result[2], "CCCCCC");
2213
2214        Ok(())
2215    }
2216
2217    #[tokio::test]
2218    async fn should_handle_initial_error_packet() {
2219        let header = [
2220            0x68, 0x00, 0x00, // packet_length
2221            0x00, // sequence
2222            0xff, // error_header
2223            0x69, 0x04, // error_code
2224        ];
2225        let error_message = "Host '172.17.0.1' is blocked because of many connection errors; unblock with 'mysqladmin flush-hosts'";
2226
2227        // Create a fake MySQL server that immediately replies with an error packet.
2228        let listener = TcpListener::bind("127.0.0.1:0000").await.unwrap();
2229
2230        let listen_addr = listener.local_addr().unwrap();
2231
2232        tokio::task::spawn(async move {
2233            let (mut stream, _) = listener.accept().await.unwrap();
2234            stream.write_all(&header).await.unwrap();
2235            stream.write_all(error_message.as_bytes()).await.unwrap();
2236            stream.shutdown().await.unwrap();
2237        });
2238
2239        let opts = OptsBuilder::default()
2240            .ip_or_hostname(listen_addr.ip().to_string())
2241            .tcp_port(listen_addr.port());
2242        let server_err = match Conn::new(opts).await {
2243            Err(Error::Server(server_err)) => server_err,
2244            other => panic!("expected server error but got: {:?}", other),
2245        };
2246        assert_eq!(
2247            server_err,
2248            ServerError {
2249                code: 1129,
2250                state: "HY000".to_owned(),
2251                message: error_message.to_owned(),
2252            }
2253        );
2254    }
2255
2256    #[cfg(feature = "nightly")]
2257    mod bench {
2258        use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts};
2259
2260        #[bench]
2261        fn simple_exec(bencher: &mut test::Bencher) {
2262            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2263            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2264
2265            bencher.iter(|| {
2266                runtime.block_on(conn.query_drop("DO 1")).unwrap();
2267            });
2268
2269            runtime.block_on(conn.disconnect()).unwrap();
2270        }
2271
2272        #[bench]
2273        fn select_large_string(bencher: &mut test::Bencher) {
2274            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2275            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2276
2277            bencher.iter(|| {
2278                runtime
2279                    .block_on(conn.query_drop("SELECT REPEAT('A', 10000)"))
2280                    .unwrap();
2281            });
2282
2283            runtime.block_on(conn.disconnect()).unwrap();
2284        }
2285
2286        #[bench]
2287        fn prepared_exec(bencher: &mut test::Bencher) {
2288            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2289            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2290            let stmt = runtime.block_on(conn.prep("DO 1")).unwrap();
2291
2292            bencher.iter(|| {
2293                runtime.block_on(conn.exec_drop(&stmt, ())).unwrap();
2294            });
2295
2296            runtime.block_on(conn.close(stmt)).unwrap();
2297            runtime.block_on(conn.disconnect()).unwrap();
2298        }
2299
2300        #[bench]
2301        fn prepare_and_exec(bencher: &mut test::Bencher) {
2302            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2303            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2304
2305            bencher.iter(|| {
2306                runtime.block_on(conn.exec_drop("SELECT ?", (0,))).unwrap();
2307            });
2308
2309            runtime.block_on(conn.disconnect()).unwrap();
2310        }
2311    }
2312}