// mysql_async/conn/mod.rs

// Copyright (c) 2016 Anatoly Ikorsky
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.
8
9use futures_util::FutureExt;
10
11use mysql_common::{
12    constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8MB4_GENERAL_CI, UTF8_GENERAL_CI},
13    crypto,
14    io::ParseBuf,
15    packets::{
16        AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, HandshakePacket,
17        HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, OldEofPacket,
18        ResultSetTerminator, SslRequest,
19    },
20    proto::MySerialize,
21    row::Row,
22};
23
24use std::{
25    borrow::Cow,
26    fmt,
27    future::Future,
28    mem::{self, replace},
29    pin::Pin,
30    str::FromStr,
31    sync::Arc,
32    time::{Duration, Instant},
33};
34
35use crate::{
36    buffer_pool::PooledBuf,
37    conn::{pool::Pool, stmt_cache::StmtCache},
38    consts::{CapabilityFlags, Command, StatusFlags},
39    error::*,
40    io::Stream,
41    opts::Opts,
42    queryable::{
43        query_result::{QueryResult, ResultSetMeta},
44        transaction::TxStatus,
45        BinaryProtocol, Queryable, TextProtocol,
46    },
47    ChangeUserOpts, InfileData, OptsBuilder,
48};
49
50use self::routines::Routine;
51
52#[cfg(feature = "binlog")]
53pub mod binlog_stream;
54pub mod pool;
55pub mod routines;
56pub mod stmt_cache;
57
/// Default value for the server's `wait_timeout` variable.
///
/// 28 800 matches MySQL's documented default for `wait_timeout` (in seconds).
/// NOTE(review): the use site is outside this view; confirm the unit there.
const DEFAULT_WAIT_TIMEOUT: usize = 28_800;
59
60/// Helper that asynchronously disconnects the givent connection on the default tokio executor.
61fn disconnect(mut conn: Conn) {
62    let disconnected = conn.inner.disconnected;
63
64    // Mark conn as disconnected.
65    conn.inner.disconnected = true;
66
67    if !disconnected {
68        // We shouldn't call tokio::spawn if unwinding
69        if std::thread::panicking() {
70            return;
71        }
72
73        // Server will report broken connection if spawn fails.
74        // this might fail if, say, the runtime is shutting down, but we've done what we could
75        if let Ok(handle) = tokio::runtime::Handle::try_current() {
76            handle.spawn(async move {
77                if let Ok(conn) = conn.cleanup_for_pool().await {
78                    let _ = conn.disconnect().await;
79                }
80            });
81        }
82    }
83}
84
85/// Pending result set.
86#[derive(Debug, Clone)]
87pub(crate) enum PendingResult {
88    /// There is a pending result set.
89    Pending(ResultSetMeta),
90    /// Result set metadata was taken but not yet consumed.
91    Taken(Arc<ResultSetMeta>),
92}
93
94/// Mysql connection
95struct ConnInner {
96    stream: Option<Stream>,
97    id: u32,
98    is_mariadb: bool,
99    version: (u16, u16, u16),
100    socket: Option<String>,
101    capabilities: CapabilityFlags,
102    status: StatusFlags,
103    last_ok_packet: Option<OkPacket<'static>>,
104    last_err_packet: Option<mysql_common::packets::ServerError<'static>>,
105    handshake_complete: bool,
106    pool: Option<Pool>,
107    pending_result: std::result::Result<Option<PendingResult>, ServerError>,
108    tx_status: TxStatus,
109    reset_upon_returning_to_a_pool: bool,
110    opts: Opts,
111    ttl_deadline: Option<Instant>,
112    last_io: Instant,
113    wait_timeout: Duration,
114    stmt_cache: StmtCache,
115    nonce: Vec<u8>,
116    auth_plugin: AuthPlugin<'static>,
117    auth_switched: bool,
118    server_key: Option<Vec<u8>>,
119    active_since: Instant,
120    /// Connection is already disconnected.
121    pub(crate) disconnected: bool,
122    /// One-time connection-level infile handler.
123    infile_handler:
124        Option<Pin<Box<dyn Future<Output = crate::Result<InfileData>> + Send + Sync + 'static>>>,
125}
126
127impl fmt::Debug for ConnInner {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        f.debug_struct("Conn")
130            .field("connection id", &self.id)
131            .field("server version", &self.version)
132            .field("pool", &self.pool)
133            .field("pending_result", &self.pending_result)
134            .field("tx_status", &self.tx_status)
135            .field("stream", &self.stream)
136            .field("options", &self.opts)
137            .field("server_key", &self.server_key)
138            .field("auth_plugin", &self.auth_plugin)
139            .finish()
140    }
141}
142
143impl ConnInner {
144    /// Constructs an empty connection.
145    fn empty(opts: Opts) -> ConnInner {
146        let ttl_deadline = opts.pool_opts().new_connection_ttl_deadline();
147        ConnInner {
148            capabilities: opts.get_capabilities(),
149            status: StatusFlags::empty(),
150            last_ok_packet: None,
151            last_err_packet: None,
152            handshake_complete: false,
153            stream: None,
154            is_mariadb: false,
155            version: (0, 0, 0),
156            id: 0,
157            pending_result: Ok(None),
158            pool: None,
159            tx_status: TxStatus::None,
160            last_io: Instant::now(),
161            wait_timeout: Duration::from_secs(0),
162            stmt_cache: StmtCache::new(opts.stmt_cache_size()),
163            socket: opts.socket().map(Into::into),
164            opts,
165            ttl_deadline,
166            nonce: Vec::default(),
167            auth_plugin: AuthPlugin::MysqlNativePassword,
168            auth_switched: false,
169            disconnected: false,
170            server_key: None,
171            infile_handler: None,
172            reset_upon_returning_to_a_pool: false,
173            active_since: Instant::now(),
174        }
175    }
176
177    /// Returns mutable reference to a connection stream.
178    ///
179    /// Returns `DriverError::ConnectionClosed` if there is no stream.
180    fn stream_mut(&mut self) -> Result<&mut Stream> {
181        self.stream
182            .as_mut()
183            .ok_or_else(|| DriverError::ConnectionClosed.into())
184    }
185}
186
187/// MySql server connection.
188#[derive(Debug)]
189pub struct Conn {
190    inner: Box<ConnInner>,
191}
192
193impl Conn {
194    /// Returns connection identifier.
195    pub fn id(&self) -> u32 {
196        self.inner.id
197    }
198
199    /// Returns the ID generated by a query (usually `INSERT`) on a table with a column having the
200    /// `AUTO_INCREMENT` attribute. Returns `None` if there was no previous query on the connection
201    /// or if the query did not update an AUTO_INCREMENT value.
202    pub fn last_insert_id(&self) -> Option<u64> {
203        self.inner
204            .last_ok_packet
205            .as_ref()
206            .and_then(|ok| ok.last_insert_id())
207    }
208
209    /// Returns the number of rows affected by the last `INSERT`, `UPDATE`, `REPLACE` or `DELETE`
210    /// query.
211    pub fn affected_rows(&self) -> u64 {
212        self.inner
213            .last_ok_packet
214            .as_ref()
215            .map(|ok| ok.affected_rows())
216            .unwrap_or_default()
217    }
218
219    /// Text information, as reported by the server in the last OK packet, or an empty string.
220    pub fn info(&self) -> Cow<'_, str> {
221        self.inner
222            .last_ok_packet
223            .as_ref()
224            .and_then(|ok| ok.info_str())
225            .unwrap_or_else(|| "".into())
226    }
227
228    /// Number of warnings, as reported by the server in the last OK packet, or `0`.
229    pub fn get_warnings(&self) -> u16 {
230        self.inner
231            .last_ok_packet
232            .as_ref()
233            .map(|ok| ok.warnings())
234            .unwrap_or_default()
235    }
236
237    /// Returns a reference to the last OK packet.
238    pub fn last_ok_packet(&self) -> Option<&OkPacket<'static>> {
239        self.inner.last_ok_packet.as_ref()
240    }
241
242    /// Turns on/off automatic connection reset (see [`crate::PoolOpts::with_reset_connection`]).
243    ///
244    /// Only makes sense for pooled connections.
245    pub fn reset_connection(&mut self, reset_connection: bool) {
246        self.inner.reset_upon_returning_to_a_pool = reset_connection;
247    }
248
249    pub(crate) fn stream_mut(&mut self) -> Result<&mut Stream> {
250        self.inner.stream_mut()
251    }
252
253    pub(crate) fn capabilities(&self) -> CapabilityFlags {
254        self.inner.capabilities
255    }
256
257    /// Will update last IO time for this connection.
258    pub(crate) fn touch(&mut self) {
259        self.inner.last_io = Instant::now();
260    }
261
262    /// Will set packet sequence id to `0`.
263    pub(crate) fn reset_seq_id(&mut self) {
264        if let Some(stream) = self.inner.stream.as_mut() {
265            stream.reset_seq_id();
266        }
267    }
268
269    /// Will syncronize sequence ids between compressed and uncompressed codecs.
270    pub(crate) fn sync_seq_id(&mut self) {
271        if let Some(stream) = self.inner.stream.as_mut() {
272            stream.sync_seq_id();
273        }
274    }
275
276    /// Handles OK packet.
277    pub(crate) fn handle_ok(&mut self, ok_packet: OkPacket<'static>) {
278        self.inner.status = ok_packet.status_flags();
279        self.inner.last_err_packet = None;
280        self.inner.last_ok_packet = Some(ok_packet);
281    }
282
283    /// Handles ERR packet.
284    pub(crate) fn handle_err(&mut self, err_packet: ErrPacket<'_>) -> Result<()> {
285        match err_packet {
286            ErrPacket::Error(err) => {
287                self.inner.status = StatusFlags::empty();
288                self.inner.last_ok_packet = None;
289                self.inner.last_err_packet = Some(err.clone().into_owned());
290                Err(Error::from(err))
291            }
292            ErrPacket::Progress(_) => Ok(()),
293        }
294    }
295
296    /// Returns the current transaction status.
297    pub(crate) fn get_tx_status(&self) -> TxStatus {
298        self.inner.tx_status
299    }
300
301    /// Sets the given transaction status for this connection.
302    pub(crate) fn set_tx_status(&mut self, tx_status: TxStatus) {
303        self.inner.tx_status = tx_status;
304    }
305
306    /// Returns pending result metadata, if any.
307    ///
308    /// If `Some(_)`, then result is not yet consumed.
309    pub(crate) fn use_pending_result(
310        &mut self,
311    ) -> std::result::Result<Option<&PendingResult>, ServerError> {
312        if let Err(ref e) = self.inner.pending_result {
313            let e = e.clone();
314            self.inner.pending_result = Ok(None);
315            Err(e)
316        } else {
317            Ok(self.inner.pending_result.as_ref().unwrap().as_ref())
318        }
319    }
320
321    pub(crate) fn get_pending_result(
322        &self,
323    ) -> std::result::Result<Option<&PendingResult>, &ServerError> {
324        self.inner.pending_result.as_ref().map(|x| x.as_ref())
325    }
326
327    pub(crate) fn has_pending_result(&self) -> bool {
328        self.inner.pending_result.is_err() || matches!(self.inner.pending_result, Ok(Some(_)))
329    }
330
331    /// Sets the given pening result metadata for this connection. Returns the previous value.
332    pub(crate) fn set_pending_result(
333        &mut self,
334        meta: Option<ResultSetMeta>,
335    ) -> std::result::Result<Option<PendingResult>, ServerError> {
336        replace(
337            &mut self.inner.pending_result,
338            Ok(meta.map(PendingResult::Pending)),
339        )
340    }
341
342    pub(crate) fn set_pending_result_error(
343        &mut self,
344        error: ServerError,
345    ) -> std::result::Result<Option<PendingResult>, ServerError> {
346        replace(&mut self.inner.pending_result, Err(error))
347    }
348
349    /// Gives the currently pending result to a caller for consumption.
350    pub(crate) fn take_pending_result(
351        &mut self,
352    ) -> std::result::Result<Option<Arc<ResultSetMeta>>, ServerError> {
353        let mut output = None;
354
355        self.inner.pending_result = match replace(&mut self.inner.pending_result, Ok(None))? {
356            Some(PendingResult::Pending(x)) => {
357                let meta = Arc::new(x);
358                output = Some(meta.clone());
359                Ok(Some(PendingResult::Taken(meta)))
360            }
361            x => Ok(x),
362        };
363
364        Ok(output)
365    }
366
367    /// Returns current status flags.
368    pub(crate) fn status(&self) -> StatusFlags {
369        self.inner.status
370    }
371
372    pub(crate) async fn routine<'a, F, T>(&mut self, mut f: F) -> crate::Result<T>
373    where
374        F: Routine<T> + 'a,
375    {
376        self.inner.disconnected = true;
377        let result = f.call(&mut *self).await;
378        match result {
379            result @ Ok(_) | result @ Err(crate::Error::Server(_)) => {
380                // either OK or non-fatal error
381                self.inner.disconnected = false;
382                result
383            }
384            Err(err) => {
385                if self.inner.stream.is_some() {
386                    self.take_stream().close().await?;
387                }
388                Err(err)
389            }
390        }
391    }
392
393    /// Returns server version.
394    pub fn server_version(&self) -> (u16, u16, u16) {
395        self.inner.version
396    }
397
398    /// Returns connection options.
399    pub fn opts(&self) -> &Opts {
400        &self.inner.opts
401    }
402
403    /// Setup _local_ `LOCAL INFILE` handler (see ["LOCAL INFILE Handlers"][2] section
404    /// of the crate-level docs).
405    ///
406    /// It'll overwrite existing _local_ handler, if any.
407    ///
408    /// [2]: ../mysql_async/#local-infile-handlers
409    pub fn set_infile_handler<T>(&mut self, handler: T)
410    where
411        T: Future<Output = crate::Result<InfileData>>,
412        T: Send + Sync + 'static,
413    {
414        self.inner.infile_handler = Some(Box::pin(handler));
415    }
416
417    fn take_stream(&mut self) -> Stream {
418        self.inner.stream.take().unwrap()
419    }
420
421    /// Disconnects this connection from server.
422    pub async fn disconnect(mut self) -> Result<()> {
423        if !self.inner.disconnected {
424            self.inner.disconnected = true;
425            self.write_command_data(Command::COM_QUIT, &[]).await?;
426            let stream = self.take_stream();
427            stream.close().await?;
428        }
429        Ok(())
430    }
431
432    /// Closes the connection.
433    async fn close_conn(mut self) -> Result<()> {
434        self = self.cleanup_for_pool().await?;
435        self.disconnect().await
436    }
437
438    /// Returns true if io stream is encrypted.
439    fn is_secure(&self) -> bool {
440        #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
441        {
442            self.inner
443                .stream
444                .as_ref()
445                .map(|x| x.is_secure())
446                .unwrap_or_default()
447        }
448
449        #[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))]
450        false
451    }
452
453    /// Returns true if io stream is socket.
454    fn is_socket(&self) -> bool {
455        #[cfg(unix)]
456        {
457            self.inner
458                .stream
459                .as_ref()
460                .map(|x| x.is_socket())
461                .unwrap_or_default()
462        }
463
464        #[cfg(not(unix))]
465        false
466    }
467
468    /// Hacky way to move connection through &mut. `self` becomes unusable.
469    fn take(&mut self) -> Conn {
470        mem::replace(self, Conn::empty(Default::default()))
471    }
472
473    fn empty(opts: Opts) -> Self {
474        Self {
475            inner: Box::new(ConnInner::empty(opts)),
476        }
477    }
478
479    /// Set `io::Stream` options as defined in the `Opts` of the connection.
480    ///
481    /// Requires that self.inner.stream is Some
482    fn setup_stream(&mut self) -> Result<()> {
483        debug_assert!(self.inner.stream.is_some());
484        if let Some(stream) = self.inner.stream.as_mut() {
485            stream.set_tcp_nodelay(self.inner.opts.tcp_nodelay())?;
486        }
487        Ok(())
488    }
489
490    async fn handle_handshake(&mut self) -> Result<()> {
491        let packet = self.read_packet().await?;
492        let handshake = ParseBuf(&packet).parse::<HandshakePacket>(())?;
493
494        // Handshake scramble is always 21 bytes length (20 + zero terminator)
495        self.inner.nonce = {
496            let mut nonce = Vec::from(handshake.scramble_1_ref());
497            nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
498            // Trim zero terminator. Fill with zeroes if nonce
499            // is somehow smaller than 20 bytes (this matches the server behavior).
500            nonce.resize(20, 0);
501            nonce
502        };
503
504        self.inner.capabilities = handshake.capabilities() & self.inner.opts.get_capabilities();
505        self.inner.version = handshake
506            .maria_db_server_version_parsed()
507            .inspect(|_| self.inner.is_mariadb = true)
508            .or_else(|| handshake.server_version_parsed())
509            .unwrap_or((0, 0, 0));
510        self.inner.id = handshake.connection_id();
511        self.inner.status = handshake.status_flags();
512
513        // Allow only CachingSha2Password and MysqlNativePassword here
514        // because sha256_password is deprecated and other plugins won't
515        // appear here.
516        self.inner.auth_plugin = match handshake.auth_plugin() {
517            Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password,
518            _ => AuthPlugin::MysqlNativePassword,
519        };
520
521        Ok(())
522    }
523
524    async fn switch_to_ssl_if_needed(&mut self) -> Result<()> {
525        if self
526            .inner
527            .opts
528            .get_capabilities()
529            .contains(CapabilityFlags::CLIENT_SSL)
530        {
531            if !self
532                .inner
533                .capabilities
534                .contains(CapabilityFlags::CLIENT_SSL)
535            {
536                return Err(DriverError::NoClientSslFlagFromServer.into());
537            }
538
539            let collation = if self.inner.version >= (5, 5, 3) {
540                UTF8MB4_GENERAL_CI
541            } else {
542                UTF8_GENERAL_CI
543            };
544
545            let ssl_request = SslRequest::new(
546                self.inner.capabilities,
547                DEFAULT_MAX_ALLOWED_PACKET as u32,
548                collation as u8,
549            );
550            self.write_struct(&ssl_request).await?;
551            let conn = self;
552            let ssl_opts = conn.opts().ssl_opts_and_connector().expect("unreachable");
553            let domain = ssl_opts
554                .ssl_opts()
555                .tls_hostname_override()
556                .unwrap_or_else(|| conn.opts().ip_or_hostname())
557                .into();
558            let tls_connector = ssl_opts.build_tls_connector().await?;
559            conn.stream_mut()?
560                .make_secure(domain, &tls_connector)
561                .await?;
562            Ok(())
563        } else {
564            Ok(())
565        }
566    }
567
568    async fn do_handshake_response(&mut self) -> Result<()> {
569        let auth_data = self
570            .inner
571            .auth_plugin
572            .gen_data(self.inner.opts.pass(), &self.inner.nonce);
573
574        let handshake_response = HandshakeResponse::new(
575            auth_data.as_deref(),
576            self.inner.version,
577            self.inner.opts.user().map(|x| x.as_bytes()),
578            self.inner.opts.db_name().map(|x| x.as_bytes()),
579            Some(self.inner.auth_plugin.borrow()),
580            self.capabilities(),
581            Default::default(), // TODO: Add support
582            self.inner
583                .opts
584                .max_allowed_packet()
585                .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET) as u32,
586        );
587
588        // Serialize here to satisfy borrow checker.
589        let mut buf = crate::buffer_pool().get();
590        handshake_response.serialize(buf.as_mut());
591
592        self.write_packet(buf).await?;
593        self.inner.handshake_complete = true;
594        Ok(())
595    }
596
597    async fn perform_auth_switch(
598        &mut self,
599        auth_switch_request: AuthSwitchRequest<'_>,
600    ) -> Result<()> {
601        if !self.inner.auth_switched {
602            self.inner.auth_switched = true;
603            self.inner.nonce = auth_switch_request.plugin_data().to_vec();
604
605            if matches!(
606                auth_switch_request.auth_plugin(),
607                AuthPlugin::MysqlOldPassword
608            ) && self.inner.opts.secure_auth()
609            {
610                return Err(DriverError::MysqlOldPasswordDisabled.into());
611            }
612
613            self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned();
614
615            let plugin_data = match &self.inner.auth_plugin {
616                x @ AuthPlugin::CachingSha2Password => {
617                    x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
618                }
619                x @ AuthPlugin::MysqlNativePassword => {
620                    x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
621                }
622                x @ AuthPlugin::MysqlOldPassword => {
623                    if self.inner.opts.secure_auth() {
624                        return Err(DriverError::MysqlOldPasswordDisabled.into());
625                    } else {
626                        x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
627                    }
628                }
629                x @ AuthPlugin::MysqlClearPassword => {
630                    if self.inner.opts.enable_cleartext_plugin() {
631                        x.gen_data(self.inner.opts.pass(), &self.inner.nonce)
632                    } else {
633                        return Err(DriverError::CleartextPluginDisabled.into());
634                    }
635                }
636                x @ AuthPlugin::Ed25519 => x.gen_data(self.inner.opts.pass(), &self.inner.nonce),
637                x @ AuthPlugin::Other(_) => x.gen_data(self.inner.opts.pass(), &self.inner.nonce),
638            };
639
640            if let Some(plugin_data) = plugin_data {
641                self.write_struct(&plugin_data.into_owned()).await?;
642            } else {
643                self.write_packet(crate::buffer_pool().get()).await?;
644            }
645
646            self.continue_auth().await?;
647
648            Ok(())
649        } else {
650            unreachable!("auth_switched flag should be checked by caller")
651        }
652    }
653
654    fn continue_auth(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
655        // NOTE: we need to box this since it may recurse
656        // see https://github.com/rust-lang/rust/issues/46415#issuecomment-528099782
657        Box::pin(async move {
658            match self.inner.auth_plugin {
659                AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
660                    self.continue_mysql_native_password_auth().await?;
661                    Ok(())
662                }
663                AuthPlugin::CachingSha2Password => {
664                    self.continue_caching_sha2_password_auth().await?;
665                    Ok(())
666                }
667                AuthPlugin::MysqlClearPassword => {
668                    if self.inner.opts.enable_cleartext_plugin() {
669                        self.continue_mysql_native_password_auth().await?;
670                        Ok(())
671                    } else {
672                        Err(DriverError::CleartextPluginDisabled.into())
673                    }
674                }
675                AuthPlugin::Ed25519 => {
676                    self.continue_ed25519_auth().await?;
677                    Ok(())
678                }
679                AuthPlugin::Other(ref name) => Err(DriverError::UnknownAuthPlugin {
680                    name: String::from_utf8_lossy(name.as_ref()).to_string(),
681                }
682                .into()),
683            }
684        })
685    }
686
687    fn switch_to_compression(&mut self) -> Result<()> {
688        if self
689            .capabilities()
690            .contains(CapabilityFlags::CLIENT_COMPRESS)
691        {
692            if let Some(compression) = self.inner.opts.compression() {
693                if let Some(stream) = self.inner.stream.as_mut() {
694                    stream.compress(compression);
695                }
696            }
697        }
698        Ok(())
699    }
700
701    async fn continue_ed25519_auth(&mut self) -> Result<()> {
702        let packet = self.read_packet().await?;
703        match packet.first() {
704            Some(0x00) => {
705                // ok packet for empty password
706                Ok(())
707            }
708            Some(0xfe) if !self.inner.auth_switched => {
709                let auth_switch_request = ParseBuf(&packet).parse::<AuthSwitchRequest>(())?;
710                self.perform_auth_switch(auth_switch_request).await
711            }
712            _ => Err(DriverError::UnexpectedPacket {
713                payload: packet.to_vec(),
714            }
715            .into()),
716        }
717    }
718
719    async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> {
720        let packet = self.read_packet().await?;
721        match packet.first() {
722            Some(0x00) => {
723                // ok packet for empty password
724                Ok(())
725            }
726            Some(0x01) => match packet.get(1) {
727                Some(0x03) => {
728                    // auth ok
729                    self.drop_packet().await
730                }
731                Some(0x04) => {
732                    let pass = self.inner.opts.pass().unwrap_or_default();
733                    let mut pass = crate::buffer_pool().get_with(pass.as_bytes());
734                    pass.as_mut().push(0);
735
736                    if self.is_secure() || self.is_socket() {
737                        self.write_packet(pass).await?;
738                    } else {
739                        if self.inner.server_key.is_none() {
740                            self.write_bytes(&[0x02][..]).await?;
741                            let packet = self.read_packet().await?;
742                            self.inner.server_key = Some(packet[1..].to_vec());
743                        }
744                        for (i, byte) in pass.as_mut().iter_mut().enumerate() {
745                            *byte ^= self.inner.nonce[i % self.inner.nonce.len()];
746                        }
747                        let encrypted_pass = crypto::encrypt(
748                            &pass,
749                            self.inner.server_key.as_deref().expect("unreachable"),
750                        );
751                        self.write_bytes(&encrypted_pass).await?;
752                    };
753                    self.drop_packet().await?;
754                    Ok(())
755                }
756                _ => Err(DriverError::UnexpectedPacket {
757                    payload: packet.to_vec(),
758                }
759                .into()),
760            },
761            Some(0xfe) if !self.inner.auth_switched => {
762                let auth_switch_request = ParseBuf(&packet).parse::<AuthSwitchRequest>(())?;
763                self.perform_auth_switch(auth_switch_request).await?;
764                Ok(())
765            }
766            _ => Err(DriverError::UnexpectedPacket {
767                payload: packet.to_vec(),
768            }
769            .into()),
770        }
771    }
772
773    async fn continue_mysql_native_password_auth(&mut self) -> Result<()> {
774        let packet = self.read_packet().await?;
775        match packet.first() {
776            Some(0x00) => Ok(()),
777            Some(0xfe) if !self.inner.auth_switched => {
778                let auth_switch = if packet.len() > 1 {
779                    ParseBuf(&packet).parse(())?
780                } else {
781                    let _ = ParseBuf(&packet).parse::<OldAuthSwitchRequest>(())?;
782                    // map OldAuthSwitch to AuthSwitch with mysql_old_password plugin
783                    AuthSwitchRequest::new(
784                        "mysql_old_password".as_bytes(),
785                        self.inner.nonce.clone(),
786                    )
787                };
788                self.perform_auth_switch(auth_switch).await
789            }
790            _ => Err(DriverError::UnexpectedPacket {
791                payload: packet.to_vec(),
792            }
793            .into()),
794        }
795    }
796
797    /// Returns `true` for ProgressReport packet.
798    fn handle_packet(&mut self, packet: &PooledBuf) -> Result<bool> {
799        let ok_packet = if self.has_pending_result() {
800            if self
801                .capabilities()
802                .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF)
803            {
804                ParseBuf(packet)
805                    .parse::<OkPacketDeserializer<ResultSetTerminator>>(self.capabilities())
806                    .map(|x| x.into_inner())
807            } else {
808                ParseBuf(packet)
809                    .parse::<OkPacketDeserializer<OldEofPacket>>(self.capabilities())
810                    .map(|x| x.into_inner())
811            }
812        } else {
813            ParseBuf(packet)
814                .parse::<OkPacketDeserializer<CommonOkPacket>>(self.capabilities())
815                .map(|x| x.into_inner())
816        };
817
818        if let Ok(ok_packet) = ok_packet {
819            self.handle_ok(ok_packet.into_owned());
820        } else {
821            // If we haven't completed the handshake the server will not be aware of our
822            // capabilities and so it will behave as if we have none. In particular, the error
823            // packet will not contain a SQL State field even if our capabilities do contain the
824            // `CLIENT_PROTOCOL_41` flag. Therefore it is necessary to parse an incoming packet
825            // with no capability assumptions if we have not completed the handshake.
826            //
827            // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html
828            let capabilities = if self.inner.handshake_complete {
829                self.capabilities()
830            } else {
831                CapabilityFlags::empty()
832            };
833            let err_packet = ParseBuf(packet).parse::<ErrPacket>(capabilities);
834            if let Ok(err_packet) = err_packet {
835                self.handle_err(err_packet)?;
836                return Ok(true);
837            }
838        }
839
840        Ok(false)
841    }
842
843    pub(crate) async fn read_packet(&mut self) -> Result<PooledBuf> {
844        loop {
845            let packet = crate::io::ReadPacket::new(&mut *self)
846                .await
847                .map_err(|io_err| {
848                    self.inner.stream.take();
849                    self.inner.disconnected = true;
850                    Error::from(io_err)
851                })?;
852            if self.handle_packet(&packet)? {
853                // ignore progress report
854                continue;
855            } else {
856                return Ok(packet);
857            }
858        }
859    }
860
861    /// Returns future that reads packets from a server.
862    pub(crate) async fn read_packets(&mut self, n: usize) -> Result<Vec<PooledBuf>> {
863        let mut packets = Vec::with_capacity(n);
864        for _ in 0..n {
865            packets.push(self.read_packet().await?);
866        }
867        Ok(packets)
868    }
869
870    pub(crate) async fn write_packet(&mut self, data: PooledBuf) -> Result<()> {
871        crate::io::WritePacket::new(&mut *self, data)
872            .await
873            .map_err(|io_err| {
874                self.inner.stream.take();
875                self.inner.disconnected = true;
876                From::from(io_err)
877            })
878    }
879
880    /// Writes bytes to a server.
881    pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
882        let buf = crate::buffer_pool().get_with(bytes);
883        self.write_packet(buf).await
884    }
885
886    /// Sends a serializable structure to a server.
887    pub(crate) async fn write_struct<T: MySerialize>(&mut self, x: &T) -> Result<()> {
888        let mut buf = crate::buffer_pool().get();
889        x.serialize(buf.as_mut());
890        self.write_packet(buf).await
891    }
892
893    /// Sends a command to a server.
894    pub(crate) async fn write_command<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
895        self.clean_dirty().await?;
896        self.reset_seq_id();
897        self.write_struct(cmd).await
898    }
899
900    /// Returns future that sends full command body to a server.
901    pub(crate) async fn write_command_raw(&mut self, body: PooledBuf) -> Result<()> {
902        debug_assert!(!body.is_empty());
903        self.clean_dirty().await?;
904        self.reset_seq_id();
905        self.write_packet(body).await
906    }
907
908    /// Returns future that writes command to a server.
909    pub(crate) async fn write_command_data<T>(&mut self, cmd: Command, cmd_data: T) -> Result<()>
910    where
911        T: AsRef<[u8]>,
912    {
913        let cmd_data = cmd_data.as_ref();
914        let mut buf = crate::buffer_pool().get();
915        let body = buf.as_mut();
916        body.push(cmd as u8);
917        body.extend_from_slice(cmd_data);
918        self.write_command_raw(buf).await
919    }
920
921    async fn drop_packet(&mut self) -> Result<()> {
922        self.read_packet().await?;
923        Ok(())
924    }
925
926    async fn run_init_commands(&mut self) -> Result<()> {
927        let mut init = self.inner.opts.init().to_vec();
928
929        while let Some(query) = init.pop() {
930            self.query_drop(query).await?;
931        }
932
933        Ok(())
934    }
935
936    async fn run_setup_commands(&mut self) -> Result<()> {
937        let mut setup = self.inner.opts.setup().to_vec();
938
939        while let Some(query) = setup.pop() {
940            self.query_drop(query).await?;
941        }
942
943        Ok(())
944    }
945
946    /// Returns a future that resolves to [`Conn`].
947    pub fn new<T: Into<Opts>>(opts: T) -> crate::BoxFuture<'static, Conn> {
948        let opts = opts.into();
949        async move {
950            let mut conn = Conn::empty(opts.clone());
951
952            let stream = if let Some(_path) = opts.socket() {
953                #[cfg(unix)]
954                {
955                    Stream::connect_socket(_path.to_owned()).await?
956                }
957                #[cfg(not(unix))]
958                return Err(crate::DriverError::NamedPipesDisabled.into());
959            } else {
960                let keepalive = opts
961                    .tcp_keepalive()
962                    .map(|x| std::time::Duration::from_millis(x.into()));
963                Stream::connect_tcp(opts.hostport_or_url(), keepalive).await?
964            };
965
966            conn.inner.stream = Some(stream);
967            conn.setup_stream()?;
968            conn.handle_handshake().await?;
969            conn.switch_to_ssl_if_needed().await?;
970            conn.do_handshake_response().await?;
971            conn.continue_auth().await?;
972            conn.switch_to_compression()?;
973            conn.read_settings().await?;
974            conn.reconnect_via_socket_if_needed().await?;
975            conn.run_init_commands().await?;
976            conn.run_setup_commands().await?;
977
978            Ok(conn)
979        }
980        .boxed()
981    }
982
983    /// Returns a future that resolves to [`Conn`].
984    pub async fn from_url<T: AsRef<str>>(url: T) -> Result<Conn> {
985        Conn::new(Opts::from_str(url.as_ref())?).await
986    }
987
988    /// Will try to reconnect via socket using socket address in `self.inner.socket`.
989    ///
990    /// Won't try to reconnect if socket connection is already enforced in [`Opts`].
991    async fn reconnect_via_socket_if_needed(&mut self) -> Result<()> {
992        if let Some(socket) = self.inner.socket.as_ref() {
993            let opts = self.inner.opts.clone();
994            if opts.socket().is_none() {
995                let opts = OptsBuilder::from_opts(opts).socket(Some(&**socket));
996                if let Ok(conn) = Conn::new(opts).await {
997                    let old_conn = std::mem::replace(self, conn);
998                    // tidy up the old connection
999                    old_conn.close_conn().await?;
1000                }
1001            }
1002        }
1003        Ok(())
1004    }
1005
    /// Configures the connection based on server settings. In particular:
    ///
    /// * It reads and stores the socket address inside the connection, unless a socket address
    ///   is already known to the connection (`self.inner.socket`) or `prefer_socket` is `false`.
    ///
    /// * It reads and stores `max_allowed_packet` in the connection unless it's already in [`Opts`]
    ///
    /// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`]
    ///
    /// All values that have to come from the server are fetched with a single `SELECT` of the
    /// corresponding system variables. If that query yields no row, every loaded value is treated
    /// as `NULL`.
    async fn read_settings(&mut self) -> Result<()> {
        /// What to do for a single setting: fetch it from the server, or apply a known value.
        enum Action {
            Load(Cfg),
            Apply(CfgData),
        }

        /// A setting value taken from [`Opts`].
        enum CfgData {
            MaxAllowedPacket(usize),
            WaitTimeout(usize),
        }

        impl CfgData {
            /// Stores this value in `conn`.
            fn apply(&self, conn: &mut Conn) {
                match self {
                    Self::MaxAllowedPacket(value) => {
                        if let Some(stream) = conn.inner.stream.as_mut() {
                            stream.set_max_allowed_packet(*value);
                        }
                    }
                    Self::WaitTimeout(value) => {
                        conn.inner.wait_timeout = Duration::from_secs(*value as u64);
                    }
                }
            }
        }

        /// A setting that has to be loaded from the server.
        enum Cfg {
            Socket,
            MaxAllowedPacket,
            WaitTimeout,
        }

        impl Cfg {
            /// The server system variable holding this setting.
            const fn name(&self) -> &'static str {
                match self {
                    Self::Socket => "@@socket",
                    Self::MaxAllowedPacket => "@@max_allowed_packet",
                    Self::WaitTimeout => "@@wait_timeout",
                }
            }

            /// Stores the loaded `value` in `conn`.
            ///
            /// A missing value leaves the socket unset, and falls back to
            /// `DEFAULT_MAX_ALLOWED_PACKET` / `DEFAULT_WAIT_TIMEOUT` for the other settings.
            /// NOTE(review): a `NULL` value presumably converts to `None` through
            /// `crate::from_value` as well; confirm.
            fn apply(&self, conn: &mut Conn, value: Option<crate::Value>) {
                match self {
                    Cfg::Socket => {
                        conn.inner.socket = value.and_then(crate::from_value);
                    }
                    Cfg::MaxAllowedPacket => {
                        if let Some(stream) = conn.inner.stream.as_mut() {
                            stream.set_max_allowed_packet(
                                value
                                    .and_then(crate::from_value)
                                    .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET),
                            );
                        }
                    }
                    Cfg::WaitTimeout => {
                        conn.inner.wait_timeout = Duration::from_secs(
                            value
                                .and_then(crate::from_value)
                                .unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64,
                        );
                    }
                }
            }
        }

        // Values present in `Opts` take precedence over the server's values.
        let mut actions = vec![
            if let Some(x) = self.opts().max_allowed_packet() {
                Action::Apply(CfgData::MaxAllowedPacket(x))
            } else {
                Action::Load(Cfg::MaxAllowedPacket)
            },
            if let Some(x) = self.opts().wait_timeout() {
                Action::Apply(CfgData::WaitTimeout(x))
            } else {
                Action::Load(Cfg::WaitTimeout)
            },
        ];

        if self.inner.opts.prefer_socket() && self.inner.socket.is_none() {
            actions.push(Action::Load(Cfg::Socket))
        }

        // Settings to fetch, in the same relative order as in `actions`. The loop at the end
        // relies on this order to pair each `Action::Load` with its column of the result.
        let loads = actions
            .iter()
            .filter_map(|x| match x {
                Action::Load(x) => Some(x),
                Action::Apply(_) => None,
            })
            .collect::<Vec<_>>();

        let loaded = if !loads.is_empty() {
            // Builds e.g. "SELECT @@max_allowed_packet,@@wait_timeout".
            let query = loads
                .iter()
                .zip(std::iter::once(' ').chain(std::iter::repeat(',')))
                .fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| {
                    acc.push(prefix);
                    acc.push_str(cfg.name());
                    acc
                });

            // Takes the column values of the returned row; without a row, every loaded value
            // is treated as NULL.
            self.query_internal::<Row, String>(query)
                .await?
                .map(|row| row.unwrap())
                .unwrap_or_else(|| vec![crate::Value::NULL; loads.len()])
        } else {
            vec![]
        };
        let mut loaded = loaded.into_iter();

        // Each `Load` consumes the next loaded column, in query order.
        for action in actions {
            match action {
                Action::Load(cfg) => cfg.apply(self, loaded.next()),
                Action::Apply(cfg) => cfg.apply(self),
            }
        }

        Ok(())
    }
1134
1135    /// Returns true if time since last IO exceeds `wait_timeout`
1136    /// (or `conn_ttl` if specified in opts).
1137    fn expired(&self) -> bool {
1138        if let Some(deadline) = self.inner.ttl_deadline {
1139            if Instant::now() > deadline {
1140                return true;
1141            }
1142        }
1143        let ttl = self
1144            .inner
1145            .opts
1146            .conn_ttl()
1147            .unwrap_or(self.inner.wait_timeout);
1148        !ttl.is_zero() && self.idling() > ttl
1149    }
1150
1151    /// Returns duration since last IO.
1152    fn idling(&self) -> Duration {
1153        self.inner.last_io.elapsed()
1154    }
1155
1156    /// Executes [`COM_RESET_CONNECTION`][1].
1157    ///
1158    /// Returns `false` if command is not supported (requires MySql >5.7.2, MariaDb >10.2.3).
1159    /// For older versions consider using [`Conn::change_user`].
1160    ///
1161    /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-reset-connection.html
1162    pub async fn reset(&mut self) -> Result<bool> {
1163        let supports_com_reset_connection = if self.inner.is_mariadb {
1164            self.inner.version >= (10, 2, 4)
1165        } else {
1166            // assuming mysql
1167            self.inner.version > (5, 7, 2)
1168        };
1169
1170        if supports_com_reset_connection {
1171            self.routine(routines::ResetRoutine).await?;
1172            self.inner.stmt_cache.clear();
1173            self.inner.infile_handler = None;
1174            self.run_setup_commands().await?;
1175        }
1176
1177        Ok(supports_com_reset_connection)
1178    }
1179
1180    /// Executes [`COM_CHANGE_USER`][1].
1181    ///
1182    /// This might be used as an older and slower alternative to `COM_RESET_CONNECTION` that
1183    /// works on MySql prior to 5.7.3 (MariaDb prior ot 10.2.4).
1184    ///
1185    /// ## Note
1186    ///
1187    /// * Using non-default `opts` for a pooled connection is discouraging.
1188    /// * Connection options will be permanently updated.
1189    ///
1190    /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html
1191    pub async fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
1192        // We'll kick this connection from a pool if opts are changed.
1193        if opts != ChangeUserOpts::default() {
1194            let mut opts_changed = false;
1195            if let Some(user) = opts.user() {
1196                opts_changed |= user != self.opts().user()
1197            };
1198            if let Some(pass) = opts.pass() {
1199                opts_changed |= pass != self.opts().pass()
1200            };
1201            if let Some(db_name) = opts.db_name() {
1202                opts_changed |= db_name != self.opts().db_name()
1203            };
1204            if opts_changed {
1205                if let Some(pool) = self.inner.pool.take() {
1206                    pool.cancel_connection();
1207                }
1208            }
1209        }
1210
1211        let conn_opts = &mut self.inner.opts;
1212        opts.update_opts(conn_opts);
1213        self.routine(routines::ChangeUser).await?;
1214        self.inner.stmt_cache.clear();
1215        self.inner.infile_handler = None;
1216        self.run_setup_commands().await?;
1217        Ok(())
1218    }
1219
1220    /// Resets the connection upon returning it to a pool.
1221    ///
1222    /// Will invoke `COM_CHANGE_USER` if `COM_RESET_CONNECTION` is not supported.
1223    async fn reset_for_pool(mut self) -> Result<Self> {
1224        if !self.reset().await? {
1225            self.change_user(Default::default()).await?;
1226        }
1227        Ok(self)
1228    }
1229
1230    /// Requires that `self.inner.tx_status != TxStatus::None`
1231    pub(crate) async fn rollback_transaction(&mut self) -> Result<()> {
1232        debug_assert_ne!(self.inner.tx_status, TxStatus::None);
1233        self.inner.tx_status = TxStatus::None;
1234        self.query_drop("ROLLBACK").await
1235    }
1236
1237    /// Returns `true` if `SERVER_MORE_RESULTS_EXISTS` flag is contained
1238    /// in status flags of the connection.
1239    pub(crate) fn more_results_exists(&self) -> bool {
1240        self.status()
1241            .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
1242    }
1243
    /// The purpose of this function is to clean up a pending result set
    /// for a prematurely dropped connection or query result.
    ///
    /// Requires that there are no other references to the pending result.
    ///
    /// # Panics
    ///
    /// Panics if the pending result was taken and its `ResultSetMeta` is still shared
    /// (`Arc::try_unwrap` fails).
    pub(crate) async fn drop_result(&mut self) -> Result<()> {
        // Map everything into `PendingResult::Pending`
        // (`set_pending_result(None)` hands back the previous pending result, which is matched here.)
        let meta = match self.set_pending_result(None)? {
            Some(PendingResult::Pending(meta)) => Some(meta),
            Some(PendingResult::Taken(meta)) => {
                // This also asserts that there is only one reference left to the taken ResultSetMeta,
                // therefore this result set must be dropped here since it won't be dropped anywhere else.
                Some(Arc::try_unwrap(meta).expect("Conn::drop_result call on a pending result that may still be droped by someone else"))
            }
            None => None,
        };

        // Re-install the (now `Pending`) metadata so that `use_pending_result` sees it below.
        let _ = self.set_pending_result(meta);

        // Drain the pending result set with the protocol it was produced by.
        match self.use_pending_result() {
            Ok(Some(PendingResult::Pending(ResultSetMeta::Text(_)))) => {
                QueryResult::<'_, '_, TextProtocol>::new(self)
                    .drop_result()
                    .await
            }
            Ok(Some(PendingResult::Pending(ResultSetMeta::Binary(_)))) => {
                QueryResult::<'_, '_, BinaryProtocol>::new(self)
                    .drop_result()
                    .await
            }
            Ok(None) => Ok((/* this case does not require an action */)),
            Ok(Some(PendingResult::Taken(_))) | Err(_) => {
                unreachable!("this case must be handled earlier in this function")
            }
        }
    }
1279
1280    /// This function will drop pending result and rollback a transaction, if needed.
1281    ///
1282    /// The purpose of this function, is to cleanup the connection while returning it to a [`Pool`].
1283    async fn cleanup_for_pool(mut self) -> Result<Self> {
1284        loop {
1285            let result = if self.has_pending_result() {
1286                self.drop_result().await
1287            } else if self.inner.tx_status != TxStatus::None {
1288                self.rollback_transaction().await
1289            } else {
1290                break;
1291            };
1292
1293            // The connection was dropped and we assume that it was dropped intentionally,
1294            // so we'll ignore non-fatal errors during cleanup (also there is no direct caller
1295            // to return this error to).
1296            if let Err(err) = result {
1297                if err.is_fatal() {
1298                    // This means that connection is completely broken
1299                    // and shouldn't return to a pool.
1300                    return Err(err);
1301                }
1302            }
1303        }
1304        Ok(self)
1305    }
1306}
1307
1308#[cfg(test)]
1309mod test {
1310    use bytes::Bytes;
1311    use futures_util::stream::{self, StreamExt};
1312    use mysql_common::constants::MAX_PAYLOAD_LEN;
1313    use rand::Rng;
1314    use tokio::{io::AsyncWriteExt, net::TcpListener};
1315
1316    use crate::{
1317        from_row, params, prelude::*, test_misc::get_opts, ChangeUserOpts, Conn, Error,
1318        OptsBuilder, Pool, ServerError, Value, WhiteListFsHandler,
1319    };
1320
1321    #[tokio::test]
1322    async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> {
1323        let opts = get_opts().client_found_rows(true);
1324        let mut conn = Conn::new(opts).await.unwrap();
1325
1326        "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
1327            .ignore(&mut conn)
1328            .await?;
1329
1330        "INSERT INTO mysql.found_rows (val) VALUES (1)"
1331            .ignore(&mut conn)
1332            .await?;
1333
1334        // Inserted one row, affected should be one.
1335        assert_eq!(conn.affected_rows(), 1);
1336
1337        "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
1338            .ignore(&mut conn)
1339            .await?;
1340
1341        // The query doesn't affect any rows, but due to us wanting FOUND rows,
1342        // this has to return one.
1343        assert_eq!(conn.affected_rows(), 1);
1344
1345        Ok(())
1346    }
1347
1348    #[tokio::test]
1349    async fn should_not_return_found_rows_if_flag_is_not_set() -> super::Result<()> {
1350        let mut conn = Conn::new(get_opts()).await.unwrap();
1351
1352        "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
1353            .ignore(&mut conn)
1354            .await?;
1355
1356        "INSERT INTO mysql.found_rows (val) VALUES (1)"
1357            .ignore(&mut conn)
1358            .await?;
1359
1360        // Inserted one row, affected should be one.
1361        assert_eq!(conn.affected_rows(), 1);
1362
1363        "UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
1364            .ignore(&mut conn)
1365            .await?;
1366
1367        // The query doesn't affect any rows.
1368        assert_eq!(conn.affected_rows(), 0);
1369
1370        Ok(())
1371    }
1372
1373    #[test]
1374    fn opts_should_satisfy_send_and_sync() {
1375        struct A<T: Sync + Send>(T);
1376        #[allow(clippy::unnecessary_operation)]
1377        A(get_opts());
1378    }
1379
1380    #[tokio::test]
1381    async fn should_connect_without_database() -> super::Result<()> {
1382        // no database name
1383        let mut conn: Conn = Conn::new(get_opts().db_name(None::<String>)).await?;
1384        conn.ping().await?;
1385        conn.disconnect().await?;
1386
1387        // empty database name
1388        let mut conn: Conn = Conn::new(get_opts().db_name(Some(""))).await?;
1389        conn.ping().await?;
1390        conn.disconnect().await?;
1391
1392        Ok(())
1393    }
1394
    /// Dropping an unconsumed query result, or an uncommitted transaction holding one, must
    /// leave the connection usable. Changes made in the dropped transaction must not persist.
    #[tokio::test]
    async fn should_clean_state_if_wrapper_is_dropeed() -> super::Result<()> {
        let mut conn: Conn = Conn::new(get_opts()).await?;

        conn.query_drop("CREATE TEMPORARY TABLE mysql.foo (id SERIAL)")
            .await?;

        // dropped query:
        // the result of `query_iter` is dropped unread; `ping` checks the connection still works.
        conn.query_iter("SELECT 1").await?;
        conn.ping().await?;

        // dropped query in dropped transaction:
        let mut tx = conn.start_transaction(Default::default()).await?;
        tx.query_drop("INSERT INTO mysql.foo (id) VALUES (42)")
            .await?;
        tx.exec_iter("SELECT COUNT(*) FROM mysql.foo", ()).await?;
        drop(tx);
        conn.ping().await?;

        // The INSERT from the dropped (never committed) transaction must not be visible.
        let count: u8 = conn
            .query_first("SELECT COUNT(*) FROM mysql.foo")
            .await?
            .unwrap_or_default();

        assert_eq!(count, 0);

        Ok(())
    }
1423
    /// Connects with every installed authentication plugin from the list below, using both
    /// non-empty and empty passwords. Also checks that compression/TLS show up in the
    /// connection's `Debug` output when the test environment enables them.
    #[tokio::test]
    async fn should_connect() -> super::Result<()> {
        let mut conn: Conn = Conn::new(get_opts()).await?;
        conn.ping().await?;
        let plugins: Vec<String> = conn
            .query_map("SHOW PLUGINS", |mut row: crate::Row| {
                row.take("Name").unwrap()
            })
            .await?;

        // Should connect with any combination of supported plugin and empty-nonempty password.
        // Tuple: (plugin, `old_passwords` value used on servers older than 8.0.11, password).
        let variants = vec![
            ("caching_sha2_password", 2_u8, "non-empty"),
            ("caching_sha2_password", 2_u8, ""),
            ("mysql_native_password", 0_u8, "non-empty"),
            ("mysql_native_password", 0_u8, ""),
        ]
        .into_iter()
        .filter(|variant| plugins.iter().any(|p| p == variant.0));

        for (plug, val, pass) in variants {
            dbg!((plug, val, pass, conn.inner.version));

            // NOTE(review): presumably skipped because mysql_native_password is not enabled by
            // default starting with MySQL 8.4; confirm.
            if plug == "mysql_native_password" && conn.inner.version >= (8, 4, 0) {
                continue;
            }

            // The user may not exist yet, so a failure here is ignored.
            let _ = conn.query_drop("DROP USER 'test_user'@'%'").await;

            let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug);
            conn.query_drop(query).await.unwrap();

            if conn.inner.version < (8, 0, 11) {
                conn.query_drop(format!("SET old_passwords = {}", val))
                    .await
                    .unwrap();
                conn.query_drop(format!(
                    "SET PASSWORD FOR 'test_user'@'%' = PASSWORD('{}')",
                    pass
                ))
                .await
                .unwrap();
            } else {
                conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass))
                    .await
                    .unwrap();
            };

            let opts = get_opts()
                .user(Some("test_user"))
                .pass(Some(pass))
                .db_name(None::<String>);
            let result = Conn::new(opts).await;

            // Drop the user before inspecting `result`, so cleanup happens even if connecting failed.
            conn.query_drop("DROP USER 'test_user'@'%'").await.unwrap();

            result?.disconnect().await?;
        }

        if crate::test_misc::test_compression() {
            assert!(format!("{:?}", conn).contains("Compression"));
        }

        if crate::test_misc::test_ssl() {
            assert!(format!("{:?}", conn).contains("Tls"));
        }

        conn.disconnect().await?;
        Ok(())
    }
1494
1495    #[test]
1496    fn should_not_panic_if_dropped_without_tokio_runtime() {
1497        let fut = Conn::new(get_opts());
1498        let runtime = tokio::runtime::Runtime::new().unwrap();
1499        runtime.block_on(async {
1500            fut.await.unwrap();
1501        });
1502        // connection will drop here
1503    }
1504
1505    #[tokio::test]
1506    async fn should_execute_init_queries_on_new_connection() -> super::Result<()> {
1507        let opts = OptsBuilder::from_opts(get_opts()).init(vec!["SET @a = 42", "SET @b = 'foo'"]);
1508        let mut conn = Conn::new(opts).await?;
1509        let result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
1510        conn.disconnect().await?;
1511        assert_eq!(result, vec![(42, "foo".into())]);
1512        Ok(())
1513    }
1514
1515    #[tokio::test]
1516    async fn should_execute_setup_queries_on_reset() -> super::Result<()> {
1517        let opts = OptsBuilder::from_opts(get_opts()).setup(vec!["SET @a = 42", "SET @b = 'foo'"]);
1518        let mut conn = Conn::new(opts).await?;
1519
1520        // initial run
1521        let mut result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?;
1522        assert_eq!(result, vec![(42, "foo".into())]);
1523
1524        // after reset
1525        if conn.reset().await? {
1526            result = conn.query("SELECT @a, @b").await?;
1527            assert_eq!(result, vec![(42, "foo".into())]);
1528        }
1529
1530        // after change user
1531        conn.change_user(Default::default()).await?;
1532        result = conn.query("SELECT @a, @b").await?;
1533        assert_eq!(result, vec![(42, "foo".into())]);
1534
1535        conn.disconnect().await?;
1536        Ok(())
1537    }
1538
1539    #[tokio::test]
1540    async fn should_reset_the_connection() -> super::Result<()> {
1541        let mut conn = Conn::new(get_opts()).await?;
1542
1543        assert_eq!(
1544            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1545            Value::NULL
1546        );
1547
1548        conn.query_drop("SET @foo = 'foo'").await?;
1549
1550        assert_eq!(
1551            conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1552            "foo",
1553        );
1554
1555        if conn.reset().await? {
1556            assert_eq!(
1557                conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
1558                Value::NULL
1559            );
1560        } else {
1561            assert_eq!(
1562                conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
1563                "foo",
1564            );
1565        }
1566
1567        conn.disconnect().await?;
1568        Ok(())
1569    }
1570
    /// `change_user` with default options clears session state. For every authentication plugin
    /// applicable to the current server, a fresh connection can switch to a newly created user
    /// `__mats` with no database selected.
    #[tokio::test]
    async fn should_change_user() -> super::Result<()> {
        /// Whether particular authentication plugin should be tested on the current database.
        type ShouldRunFn = fn(bool, (u16, u16, u16)) -> bool;
        /// Generates `CREATE USER` and `SET PASSWORD` statements
        type CreateUserFn = fn(bool, (u16, u16, u16), &str) -> Vec<String>;

        // Each entry: (plugin name, should it run on this server, statements creating `__mats`
        // with the given password). Both callbacks receive `(is_mariadb, server_version)`.
        #[allow(clippy::type_complexity)]
        const TEST_MATRIX: [(&str, ShouldRunFn, CreateUserFn); 4] = [
            (
                "mysql_old_password",
                |is_mariadb, version| is_mariadb || version < (5, 7, 0),
                |is_mariadb, version, pass| {
                    if is_mariadb {
                        vec![
                            "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password".into(),
                            "SET old_passwords=1".into(),
                            format!("ALTER USER '__mats'@'%' IDENTIFIED BY '{pass}'"),
                            "SET old_passwords=0".into(),
                        ]
                    } else if matches!(version, (5, 6, _)) {
                        vec![
                            "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password".into(),
                            format!("SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD('{pass}')"),
                        ]
                    } else {
                        vec![
                            "CREATE USER '__mats'@'%'".into(),
                            format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"),
                        ]
                    }
                },
            ),
            (
                "mysql_native_password",
                |is_mariadb, version| is_mariadb || version < (8, 4, 0),
                |is_mariadb, version, pass| {
                    if is_mariadb {
                        vec![
                            format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password AS PASSWORD('{pass}')")
                        ]
                    } else if version < (8, 0, 0) {
                        vec![
                            format!(
                                "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password"
                            ),
                            format!("SET old_passwords = 0"),
                            format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"),
                        ]
                    } else {
                        vec![
                            format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password BY '{pass}'")
                        ]
                    }
                },
            ),
            (
                "caching_sha2_password",
                // NOTE(review): there is no MySQL 5.8 series, so this bound effectively means
                // 8.0+; confirm the intended minimum version.
                |is_mariadb, version| !is_mariadb && version >= (5, 8, 0),
                |_is_mariadb, _version, pass| {
                    vec![
                        format!("CREATE USER '__mats'@'%' IDENTIFIED WITH caching_sha2_password BY '{pass}'")
                    ]
                },
            ),
            (
                "client_ed25519",
                |is_mariadb, version| is_mariadb && version >= (11, 6, 2),
                |_is_mariadb, _version, pass| {
                    vec![format!(
                        "CREATE USER '__mats'@'%' IDENTIFIED WITH ed25519 AS PASSWORD('{pass}')"
                    )]
                },
            ),
        ];

        /// Returns a random password of 10 lowercase ASCII letters (`'a'..='z'`).
        fn random_pass() -> String {
            let mut rng = rand::rng();
            let pass: [u8; 10] = rng.random();

            IntoIterator::into_iter(pass)
                .map(|x| ((x % (123 - 97)) + 97) as char)
                .collect()
        }

        let mut conn = Conn::new(get_opts()).await?;

        assert_eq!(
            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
            Value::NULL
        );

        conn.query_drop("SET @foo = 'foo'").await?;

        assert_eq!(
            conn.query_first::<String, _>("SELECT @foo").await?.unwrap(),
            "foo",
        );

        // `change_user` with default options must clear the user variable.
        conn.change_user(Default::default()).await?;
        assert_eq!(
            conn.query_first::<Value, _>("SELECT @foo").await?.unwrap(),
            Value::NULL
        );

        for (i, (plugin, should_run, create_statements)) in TEST_MATRIX.iter().enumerate() {
            dbg!(plugin);
            let is_mariadb = conn.inner.is_mariadb;
            let version = conn.server_version();

            if should_run(is_mariadb, version) {
                let pass = random_pass();

                let result = conn
                    .query_drop("DROP USER /*!50700 IF EXISTS */ /*M!100103 IF EXISTS */ __mats")
                    .await;

                if matches!(version, (5, 6, _)) && i == 0 {
                    // IF EXISTS is not supported on 5.6 so the query will fail on the first iteration
                    drop(result);
                } else {
                    result.unwrap();
                }

                for statement in create_statements(is_mariadb, version, &pass) {
                    conn.query_drop(dbg!(statement)).await.unwrap();
                }

                // NOTE(review): `secure_auth(false)` is presumably needed so that the
                // `mysql_old_password` case is accepted; confirm.
                let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap();
                conn2
                    .change_user(
                        ChangeUserOpts::default()
                            .with_db_name(None)
                            .with_user(Some("__mats".into()))
                            .with_pass(Some(pass)),
                    )
                    .await
                    .unwrap();

                // After the switch no database is selected and the session belongs to `__mats`.
                let (db, user) = conn2
                    .query_first::<(Option<String>, String), _>("SELECT DATABASE(), USER();")
                    .await
                    .unwrap()
                    .unwrap();
                assert_eq!(db, None);
                assert!(user.starts_with("__mats"));

                conn2.disconnect().await.unwrap();
            }
        }

        Ok(())
    }
1724
1725    #[tokio::test]
1726    async fn should_not_cache_statements_if_stmt_cache_size_is_zero() -> super::Result<()> {
1727        let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0);
1728
1729        let mut conn = Conn::new(opts).await?;
1730        conn.exec_drop("DO ?", (1_u8,)).await?;
1731
1732        let stmt = conn.prep("DO 2").await?;
1733        conn.exec_drop(&stmt, ()).await?;
1734        conn.exec_drop(&stmt, ()).await?;
1735        conn.close(stmt).await?;
1736
1737        conn.exec_drop("DO 3", ()).await?;
1738        conn.exec_batch("DO 4", vec![(), ()]).await?;
1739        conn.exec_first::<u8, _, _>("DO 5", ()).await?;
1740        let row: Option<(crate::Value, usize)> = conn
1741            .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
1742            .await?;
1743
1744        assert_eq!(row.unwrap().1, 1);
1745        assert_eq!(conn.inner.stmt_cache.len(), 0);
1746
1747        conn.disconnect().await?;
1748
1749        Ok(())
1750    }
1751
1752    #[tokio::test]
1753    async fn should_hold_stmt_cache_size_bound() -> super::Result<()> {
1754        let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(3);
1755        let mut conn = Conn::new(opts).await?;
1756        conn.exec_drop("DO 1", ()).await?;
1757        conn.exec_drop("DO 2", ()).await?;
1758        conn.exec_drop("DO 3", ()).await?;
1759        conn.exec_drop("DO 1", ()).await?;
1760        conn.exec_drop("DO 4", ()).await?;
1761        conn.exec_drop("DO 3", ()).await?;
1762        conn.exec_drop("DO 5", ()).await?;
1763        conn.exec_drop("DO 6", ()).await?;
1764        let row_opt = conn
1765            .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
1766            .await?;
1767        let (_, count): (String, usize) = row_opt.unwrap();
1768        assert_eq!(count, 3);
1769        let order = conn
1770            .stmt_cache_ref()
1771            .iter()
1772            .map(|item| item.1.query.0.as_ref())
1773            .collect::<Vec<&[u8]>>();
1774        assert_eq!(order, &[b"DO 6", b"DO 5", b"DO 3"]);
1775        conn.disconnect().await?;
1776        Ok(())
1777    }
1778
1779    #[tokio::test]
1780    async fn should_perform_queries() -> super::Result<()> {
1781        let mut conn = Conn::new(get_opts()).await?;
1782        for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) {
1783            let long_string = "A".repeat(x);
1784            let result: Vec<(String, u8)> = conn
1785                .query(format!(r"SELECT '{}', 231", long_string))
1786                .await?;
1787            assert_eq!((long_string, 231_u8), result[0]);
1788        }
1789        conn.disconnect().await?;
1790        Ok(())
1791    }
1792
1793    #[tokio::test]
1794    async fn should_query_drop() -> super::Result<()> {
1795        let mut conn = Conn::new(get_opts()).await?;
1796        conn.query_drop("CREATE TEMPORARY TABLE tmp (id int DEFAULT 10, name text)")
1797            .await?;
1798        conn.query_drop("INSERT INTO tmp VALUES (1, 'foo')").await?;
1799        let result: Option<u8> = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1800        conn.disconnect().await?;
1801        assert_eq!(result, Some(1_u8));
1802        Ok(())
1803    }
1804
1805    #[tokio::test]
1806    async fn should_prepare_statement() -> super::Result<()> {
1807        let mut conn = Conn::new(get_opts()).await?;
1808        let stmt = conn.prep(r"SELECT ?").await?;
1809        conn.close(stmt).await?;
1810        conn.disconnect().await?;
1811
1812        let mut conn = Conn::new(get_opts()).await?;
1813        let stmt = conn.prep(r"SELECT :foo").await?;
1814
1815        {
1816            let query = String::from("SELECT ?, ?");
1817            let stmt = conn.prep(&*query).await?;
1818            conn.close(stmt).await?;
1819            {
1820                let mut conn = Conn::new(get_opts()).await?;
1821                let stmt = conn.prep(&*query).await?;
1822                conn.close(stmt).await?;
1823                conn.disconnect().await?;
1824            }
1825        }
1826
1827        conn.close(stmt).await?;
1828        conn.disconnect().await?;
1829
1830        Ok(())
1831    }
1832
1833    #[tokio::test]
1834    async fn should_execute_statement() -> super::Result<()> {
1835        let long_string = "A".repeat(18 * 1024 * 1024);
1836        let mut conn = Conn::new(get_opts()).await?;
1837        let stmt = conn.prep(r"SELECT ?").await?;
1838        let result = conn.exec_iter(&stmt, (&long_string,)).await?;
1839        let mut mapped = result.map_and_drop(from_row::<(String,)>).await?;
1840        assert_eq!(mapped.len(), 1);
1841        assert_eq!(mapped.pop(), Some((long_string,)));
1842        let result = conn.exec_iter(&stmt, (42_u8,)).await?;
1843        let collected = result.collect_and_drop::<(u8,)>().await?;
1844        assert_eq!(collected, vec![(42u8,)]);
1845        let result = conn.exec_iter(&stmt, (8_u8,)).await?;
1846        let reduced = result
1847            .reduce_and_drop(2, |mut acc, row| {
1848                acc += from_row::<i32>(row);
1849                acc
1850            })
1851            .await?;
1852        conn.close(stmt).await?;
1853        conn.disconnect().await?;
1854        assert_eq!(reduced, 10);
1855
1856        let mut conn = Conn::new(get_opts()).await?;
1857        let stmt = conn.prep(r"SELECT :foo, :bar, :foo, 3").await?;
1858        let result = conn
1859            .exec_iter(&stmt, params! { "foo" => "quux", "bar" => "baz" })
1860            .await?;
1861        let mut mapped = result
1862            .map_and_drop(from_row::<(String, String, String, u8)>)
1863            .await?;
1864        assert_eq!(mapped.len(), 1);
1865        assert_eq!(
1866            mapped.pop(),
1867            Some(("quux".into(), "baz".into(), "quux".into(), 3))
1868        );
1869        let result = conn
1870            .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1871            .await?;
1872        let collected = result.collect_and_drop::<(u8, u8, u8, u8)>().await?;
1873        assert_eq!(collected, vec![(2, 3, 2, 3)]);
1874        let result = conn
1875            .exec_iter(&stmt, params! { "foo" => 2, "bar" => 3 })
1876            .await?;
1877        let reduced = result
1878            .reduce_and_drop(0, |acc, row| {
1879                let (a, b, c, d): (u8, u8, u8, u8) = from_row(row);
1880                acc + a + b + c + d
1881            })
1882            .await?;
1883        conn.close(stmt).await?;
1884        conn.disconnect().await?;
1885        assert_eq!(reduced, 10);
1886        Ok(())
1887    }
1888
1889    #[tokio::test]
1890    async fn should_prep_exec_statement() -> super::Result<()> {
1891        let mut conn = Conn::new(get_opts()).await?;
1892        let result = conn
1893            .exec_iter(r"SELECT :a, :b, :a", params! { "a" => 2, "b" => 3 })
1894            .await?;
1895        let output = result
1896            .map_and_drop(|row| {
1897                let (a, b, c): (u8, u8, u8) = from_row(row);
1898                a * b * c
1899            })
1900            .await?;
1901        conn.disconnect().await?;
1902        assert_eq!(output[0], 12u8);
1903        Ok(())
1904    }
1905
1906    #[tokio::test]
1907    async fn should_first_exec_statement() -> super::Result<()> {
1908        let mut conn = Conn::new(get_opts()).await?;
1909        let output = conn
1910            .exec_first(
1911                r"SELECT :a UNION ALL SELECT :b",
1912                params! { "a" => 2, "b" => 3 },
1913            )
1914            .await?;
1915        conn.disconnect().await?;
1916        assert_eq!(output, Some(2u8));
1917        Ok(())
1918    }
1919
1920    #[tokio::test]
1921    async fn issue_107() -> super::Result<()> {
1922        let mut conn = Conn::new(get_opts()).await?;
1923        conn.query_drop(
1924            r"CREATE TEMPORARY TABLE mysql.issue (
1925                    a BIGINT(20) UNSIGNED,
1926                    b VARBINARY(16),
1927                    c BINARY(32),
1928                    d BIGINT(20) UNSIGNED,
1929                    e BINARY(32)
1930                )",
1931        )
1932        .await?;
1933        conn.query_drop(
1934            r"INSERT INTO mysql.issue VALUES (
1935                    0,
1936                    0xC066F966B0860000,
1937                    0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
1938                    0,
1939                    ''
1940                ), (
1941                    1,
1942                    '',
1943                    0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
1944                    0,
1945                    ''
1946                )",
1947        )
1948        .await?;
1949
1950        let q = "SELECT b, c, d, e FROM mysql.issue";
1951        let result = conn.query_iter(q).await?;
1952
1953        let loaded_structs = result
1954            .map_and_drop(crate::from_row::<(Vec<u8>, Vec<u8>, u64, Vec<u8>)>)
1955            .await?;
1956
1957        conn.disconnect().await?;
1958
1959        assert_eq!(loaded_structs.len(), 2);
1960
1961        Ok(())
1962    }
1963
1964    #[tokio::test]
1965    async fn should_run_transactions() -> super::Result<()> {
1966        let mut conn = Conn::new(get_opts()).await?;
1967        conn.query_drop("CREATE TEMPORARY TABLE tmp (id INT, name TEXT)")
1968            .await?;
1969        let mut transaction = conn.start_transaction(Default::default()).await?;
1970        transaction
1971            .query_drop("INSERT INTO tmp VALUES (1, 'foo'), (2, 'bar')")
1972            .await?;
1973        assert_eq!(transaction.last_insert_id(), None);
1974        assert_eq!(transaction.affected_rows(), 2);
1975        assert_eq!(transaction.get_warnings(), 0);
1976        assert_eq!(transaction.info(), "Records: 2  Duplicates: 0  Warnings: 0");
1977        transaction.commit().await?;
1978        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1979        assert_eq!(output_opt, Some((2u8,)));
1980        let mut transaction = conn.start_transaction(Default::default()).await?;
1981        transaction
1982            .query_drop("INSERT INTO tmp VALUES (3, 'baz'), (4, 'quux')")
1983            .await?;
1984        let output_opt = transaction
1985            .exec_first("SELECT COUNT(*) FROM tmp", ())
1986            .await?;
1987        assert_eq!(output_opt, Some((4u8,)));
1988        transaction.rollback().await?;
1989        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1990        assert_eq!(output_opt, Some((2u8,)));
1991
1992        let mut transaction = conn.start_transaction(Default::default()).await?;
1993        transaction
1994            .query_drop("INSERT INTO tmp VALUES (3, 'baz')")
1995            .await?;
1996        drop(transaction); // implicit rollback
1997        let output_opt = conn.query_first("SELECT COUNT(*) FROM tmp").await?;
1998        assert_eq!(output_opt, Some((2u8,)));
1999
2000        conn.disconnect().await?;
2001        Ok(())
2002    }
2003
2004    #[tokio::test]
2005    async fn should_handle_multiresult_set_with_error() -> super::Result<()> {
2006        const QUERY_FIRST: &str = "SELECT * FROM tmp; SELECT 1; SELECT 2;";
2007        const QUERY_MIDDLE: &str = "SELECT 1; SELECT * FROM tmp; SELECT 2";
2008        let mut conn = Conn::new(get_opts()).await.unwrap();
2009
2010        // if error is in the first result set, then query should return it immediately.
2011        let result = QUERY_FIRST.run(&mut conn).await;
2012        assert!(matches!(result, Err(Error::Server(_))));
2013
2014        let mut result = QUERY_MIDDLE.run(&mut conn).await.unwrap();
2015
2016        // first result set will contain one row
2017        let result_set: Vec<u8> = result.collect().await.unwrap();
2018        assert_eq!(result_set, vec![1]);
2019
2020        // second result set will contain an error.
2021        let result_set: super::Result<Vec<u8>> = result.collect().await;
2022        assert!(matches!(result_set, Err(Error::Server(_))));
2023
2024        // there will be no third result set
2025        assert!(result.is_empty());
2026
2027        conn.ping().await?;
2028        conn.disconnect().await?;
2029
2030        Ok(())
2031    }
2032
2033    #[tokio::test]
2034    async fn should_handle_binary_multiresult_set_with_error() -> super::Result<()> {
2035        const PROC_DEF_FIRST: &str =
2036            r#"CREATE PROCEDURE err_first() BEGIN SELECT * FROM tmp; SELECT 1; END"#;
2037        const PROC_DEF_MIDDLE: &str =
2038            r#"CREATE PROCEDURE err_middle() BEGIN SELECT 1; SELECT * FROM tmp; SELECT 2; END"#;
2039
2040        let mut conn = Conn::new(get_opts()).await.unwrap();
2041
2042        conn.query_drop("DROP PROCEDURE IF EXISTS err_first")
2043            .await?;
2044        conn.query_iter(PROC_DEF_FIRST).await?;
2045
2046        conn.query_drop("DROP PROCEDURE IF EXISTS err_middle")
2047            .await?;
2048        conn.query_iter(PROC_DEF_MIDDLE).await?;
2049
2050        // if error is in the first result set, then query should return it immediately.
2051        let result = conn.query_iter("CALL err_first()").await;
2052        assert!(matches!(result, Err(Error::Server(_))));
2053
2054        let mut result = conn.query_iter("CALL err_middle()").await?;
2055
2056        // first result set will contain one row
2057        let result_set: Vec<u8> = result.collect().await.unwrap();
2058        assert_eq!(result_set, vec![1]);
2059
2060        // second result set will contain an error.
2061        let result_set: super::Result<Vec<u8>> = result.collect().await;
2062        assert!(matches!(result_set, Err(Error::Server(_))));
2063
2064        // there will be no third result set
2065        assert!(result.is_empty());
2066
2067        conn.ping().await?;
2068        conn.disconnect().await?;
2069
2070        Ok(())
2071    }
2072
2073    #[tokio::test]
2074    async fn should_handle_multiresult_set_with_local_infile() -> super::Result<()> {
2075        use std::fs::write;
2076
2077        let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
2078        let file_path = file_path.path();
2079        let file_name = file_path.file_name().unwrap();
2080
2081        write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
2082
2083        let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
2084
2085        // LOCAL INFILE in the middle of a multi-result set should not break anything.
2086        let mut conn = Conn::new(opts).await.unwrap();
2087        "CREATE TEMPORARY TABLE tmp (a TEXT)".run(&mut conn).await?;
2088
2089        let query = format!(
2090            r#"SELECT * FROM tmp;
2091            LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
2092            LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;
2093            SELECT * FROM tmp"#,
2094            file_name.to_str().unwrap(),
2095            file_name.to_str().unwrap(),
2096        );
2097
2098        let mut result = query.run(&mut conn).await?;
2099
2100        let result_set = result.collect::<String>().await?;
2101        assert_eq!(result_set.len(), 0);
2102
2103        let mut no_local_infile = false;
2104
2105        for _ in 0..2 {
2106            match result.collect::<String>().await {
2107                Ok(result_set) => {
2108                    assert_eq!(result.affected_rows(), 3);
2109                    assert!(result_set.is_empty())
2110                }
2111                Err(Error::Server(ref err)) if err.code == 1148 => {
2112                    // The used command is not allowed with this MySQL version
2113                    no_local_infile = true;
2114                    break;
2115                }
2116                Err(Error::Server(ref err)) if err.code == 3948 => {
2117                    // Loading local data is disabled;
2118                    // this must be enabled on both the client and server sides
2119                    no_local_infile = true;
2120                    break;
2121                }
2122                Err(err) => return Err(err),
2123            }
2124        }
2125
2126        if no_local_infile {
2127            assert!(result.is_empty());
2128            assert_eq!(result_set.len(), 0);
2129        } else {
2130            let result_set = result.collect::<String>().await?;
2131            assert_eq!(result_set.len(), 6);
2132            assert_eq!(result_set[0], "AAAAAA");
2133            assert_eq!(result_set[1], "BBBBBB");
2134            assert_eq!(result_set[2], "CCCCCC");
2135            assert_eq!(result_set[3], "AAAAAA");
2136            assert_eq!(result_set[4], "BBBBBB");
2137            assert_eq!(result_set[5], "CCCCCC");
2138        }
2139
2140        conn.ping().await?;
2141        conn.disconnect().await?;
2142
2143        Ok(())
2144    }
2145
2146    #[tokio::test]
2147    async fn should_provide_multiresult_set_metadata() -> super::Result<()> {
2148        let mut c = Conn::new(get_opts()).await?;
2149        c.query_drop("CREATE TEMPORARY TABLE tmp (id INT, foo TEXT)")
2150            .await?;
2151
2152        let mut result = c
2153            .query_iter("SELECT 1; SELECT id, foo FROM tmp WHERE 1 = 2; DO 42; SELECT 2;")
2154            .await?;
2155        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2156
2157        result.for_each(drop).await?;
2158        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 2);
2159
2160        result.for_each(drop).await?;
2161        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 0);
2162
2163        result.for_each(drop).await?;
2164        assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);
2165
2166        c.disconnect().await?;
2167        Ok(())
2168    }
2169
2170    #[tokio::test]
2171    async fn should_expose_query_result_metadata() -> super::Result<()> {
2172        let pool = Pool::new(get_opts());
2173        let mut c = pool.get_conn().await?;
2174
2175        c.query_drop(
2176            r"
2177            CREATE TEMPORARY TABLE `foo`
2178                ( `id` SERIAL
2179                , `bar_id` varchar(36) NOT NULL
2180                , `baz_id` varchar(36) NOT NULL
2181                , `ctime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP()
2182                , PRIMARY KEY (`id`)
2183                , KEY `bar_idx` (`bar_id`)
2184                , KEY `baz_idx` (`baz_id`)
2185            );",
2186        )
2187        .await?;
2188
2189        const QUERY: &str = "INSERT INTO foo (bar_id, baz_id) VALUES (?, ?)";
2190        let params = ("qwerty", "data.employee_id");
2191
2192        let query_result = c.exec_iter(QUERY, params).await?;
2193        assert_eq!(query_result.last_insert_id(), Some(1));
2194        query_result.drop_result().await?;
2195
2196        c.exec_drop(QUERY, params).await?;
2197        assert_eq!(c.last_insert_id(), Some(2));
2198
2199        let mut tx = c.start_transaction(Default::default()).await?;
2200
2201        tx.exec_drop(QUERY, params).await?;
2202        assert_eq!(tx.last_insert_id(), Some(3));
2203
2204        Ok(())
2205    }
2206
2207    #[tokio::test]
2208    async fn should_handle_local_infile_locally() -> super::Result<()> {
2209        let mut conn = Conn::new(get_opts()).await.unwrap();
2210        conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2211            .await
2212            .unwrap();
2213
2214        conn.set_infile_handler(async move {
2215            Ok(
2216                stream::iter([Bytes::from("AAAAAA\n"), Bytes::from("BBBBBB\nCCCCCC\n")])
2217                    .map(Ok)
2218                    .boxed(),
2219            )
2220        });
2221
2222        match conn
2223            .query_drop(r#"LOAD DATA LOCAL INFILE "dummy" INTO TABLE tmp;"#)
2224            .await
2225        {
2226            Ok(_) => (),
2227            Err(super::Error::Server(ref err)) if err.code == 1148 => {
2228                // The used command is not allowed with this MySQL version
2229                return Ok(());
2230            }
2231            Err(super::Error::Server(ref err)) if err.code == 3948 => {
2232                // Loading local data is disabled;
2233                // this must be enabled on both the client and server sides
2234                return Ok(());
2235            }
2236            e @ Err(_) => e.unwrap(),
2237        };
2238
2239        let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2240        assert_eq!(result.len(), 3);
2241        assert_eq!(result[0], "AAAAAA");
2242        assert_eq!(result[1], "BBBBBB");
2243        assert_eq!(result[2], "CCCCCC");
2244
2245        Ok(())
2246    }
2247
2248    #[tokio::test]
2249    async fn should_handle_local_infile_globally() -> super::Result<()> {
2250        use std::fs::write;
2251
2252        let file_path = tempfile::Builder::new().tempfile_in("").unwrap();
2253        let file_path = file_path.path();
2254        let file_name = file_path.file_name().unwrap();
2255
2256        write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?;
2257
2258        let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..])));
2259
2260        let mut conn = Conn::new(opts).await.unwrap();
2261        conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);")
2262            .await
2263            .unwrap();
2264
2265        match conn
2266            .query_drop(format!(
2267                r#"LOAD DATA LOCAL INFILE "{}" INTO TABLE tmp;"#,
2268                file_name.to_str().unwrap(),
2269            ))
2270            .await
2271        {
2272            Ok(_) => (),
2273            Err(super::Error::Server(ref err)) if err.code == 1148 => {
2274                // The used command is not allowed with this MySQL version
2275                return Ok(());
2276            }
2277            Err(super::Error::Server(ref err)) if err.code == 3948 => {
2278                // Loading local data is disabled;
2279                // this must be enabled on both the client and server sides
2280                return Ok(());
2281            }
2282            e @ Err(_) => e.unwrap(),
2283        };
2284
2285        let result: Vec<String> = conn.query("SELECT * FROM tmp").await?;
2286        assert_eq!(result.len(), 3);
2287        assert_eq!(result[0], "AAAAAA");
2288        assert_eq!(result[1], "BBBBBB");
2289        assert_eq!(result[2], "CCCCCC");
2290
2291        Ok(())
2292    }
2293
2294    #[tokio::test]
2295    async fn should_handle_initial_error_packet() {
2296        let header = [
2297            0x68, 0x00, 0x00, // packet_length
2298            0x00, // sequence
2299            0xff, // error_header
2300            0x69, 0x04, // error_code
2301        ];
2302        let error_message = "Host '172.17.0.1' is blocked because of many connection errors; unblock with 'mysqladmin flush-hosts'";
2303
2304        // Create a fake MySQL server that immediately replies with an error packet.
2305        let listener = TcpListener::bind("127.0.0.1:0000").await.unwrap();
2306
2307        let listen_addr = listener.local_addr().unwrap();
2308
2309        tokio::task::spawn(async move {
2310            let (mut stream, _) = listener.accept().await.unwrap();
2311            stream.write_all(&header).await.unwrap();
2312            stream.write_all(error_message.as_bytes()).await.unwrap();
2313            stream.shutdown().await.unwrap();
2314        });
2315
2316        let opts = OptsBuilder::default()
2317            .ip_or_hostname(listen_addr.ip().to_string())
2318            .tcp_port(listen_addr.port());
2319        let server_err = match Conn::new(opts).await {
2320            Err(Error::Server(server_err)) => server_err,
2321            other => panic!("expected server error but got: {:?}", other),
2322        };
2323        assert_eq!(
2324            server_err,
2325            ServerError {
2326                code: 1129,
2327                state: "HY000".to_owned(),
2328                message: error_message.to_owned(),
2329            }
2330        );
2331    }
2332
2333    #[cfg(feature = "nightly")]
2334    mod bench {
2335        use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts};
2336
2337        #[bench]
2338        fn simple_exec(bencher: &mut test::Bencher) {
2339            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2340            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2341
2342            bencher.iter(|| {
2343                runtime.block_on(conn.query_drop("DO 1")).unwrap();
2344            });
2345
2346            runtime.block_on(conn.disconnect()).unwrap();
2347        }
2348
2349        #[bench]
2350        fn select_large_string(bencher: &mut test::Bencher) {
2351            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2352            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2353
2354            bencher.iter(|| {
2355                runtime
2356                    .block_on(conn.query_drop("SELECT REPEAT('A', 10000)"))
2357                    .unwrap();
2358            });
2359
2360            runtime.block_on(conn.disconnect()).unwrap();
2361        }
2362
2363        #[bench]
2364        fn prepared_exec(bencher: &mut test::Bencher) {
2365            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2366            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2367            let stmt = runtime.block_on(conn.prep("DO 1")).unwrap();
2368
2369            bencher.iter(|| {
2370                runtime.block_on(conn.exec_drop(&stmt, ())).unwrap();
2371            });
2372
2373            runtime.block_on(conn.close(stmt)).unwrap();
2374            runtime.block_on(conn.disconnect()).unwrap();
2375        }
2376
2377        #[bench]
2378        fn prepare_and_exec(bencher: &mut test::Bencher) {
2379            let mut runtime = tokio::runtime::Runtime::new().unwrap();
2380            let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap();
2381
2382            bencher.iter(|| {
2383                runtime.block_on(conn.exec_drop("SELECT ?", (0,))).unwrap();
2384            });
2385
2386            runtime.block_on(conn.disconnect()).unwrap();
2387        }
2388    }
2389}