mysql_async/io/
mod.rs

1// Copyright (c) 2016 Anatoly Ikorsky
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9pub use self::{read_packet::ReadPacket, write_packet::WritePacket};
10
11use bytes::BytesMut;
12use futures_core::{ready, stream};
13use mysql_common::proto::codec::PacketCodec as PacketCodecInner;
14use pin_project::pin_project;
15#[cfg(any(unix, windows))]
16use socket2::{Socket as Socket2Socket, TcpKeepalive};
17#[cfg(unix)]
18use tokio::io::AsyncWriteExt;
19use tokio::{
20    io::{AsyncRead, AsyncWrite, ErrorKind::Interrupted, ReadBuf},
21    net::TcpStream,
22};
23use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts};
24
25#[cfg(unix)]
26use std::path::Path;
27use std::{
28    fmt,
29    future::Future,
30    io::{
31        self,
32        ErrorKind::{BrokenPipe, NotConnected, Other},
33    },
34    mem::replace,
35    net::SocketAddr,
36    ops::{Deref, DerefMut},
37    pin::Pin,
38    task::{Context, Poll},
39    time::Duration,
40};
41
42use crate::{
43    buffer_pool::PooledBuf,
44    error::IoError,
45    opts::{HostPortOrUrl, SslOpts, DEFAULT_PORT},
46};
47
48#[cfg(unix)]
49use crate::io::socket::Socket;
50
51mod tls;
52
53macro_rules! with_interrupted {
54    ($e:expr) => {
55        loop {
56            match $e {
57                Poll::Ready(Err(err)) if err.kind() == Interrupted => continue,
58                x => break x,
59            }
60        }
61    };
62}
63
64mod read_packet;
65mod socket;
66mod write_packet;
67
68#[derive(Debug)]
69pub struct PacketCodec {
70    inner: PacketCodecInner,
71    decode_buf: PooledBuf,
72}
73
74impl Default for PacketCodec {
75    fn default() -> Self {
76        Self {
77            inner: Default::default(),
78            decode_buf: crate::buffer_pool().get(),
79        }
80    }
81}
82
83impl Deref for PacketCodec {
84    type Target = PacketCodecInner;
85
86    fn deref(&self) -> &Self::Target {
87        &self.inner
88    }
89}
90
91impl DerefMut for PacketCodec {
92    fn deref_mut(&mut self) -> &mut Self::Target {
93        &mut self.inner
94    }
95}
96
97impl Decoder for PacketCodec {
98    type Item = PooledBuf;
99    type Error = IoError;
100
101    fn decode(&mut self, src: &mut BytesMut) -> std::result::Result<Option<Self::Item>, IoError> {
102        if self.inner.decode(src, self.decode_buf.as_mut())? {
103            let new_buf = crate::buffer_pool().get();
104            Ok(Some(replace(&mut self.decode_buf, new_buf)))
105        } else {
106            Ok(None)
107        }
108    }
109}
110
111impl Encoder<PooledBuf> for PacketCodec {
112    type Error = IoError;
113
114    fn encode(&mut self, item: PooledBuf, dst: &mut BytesMut) -> std::result::Result<(), IoError> {
115        Ok(self.inner.encode(&mut item.as_ref(), dst)?)
116    }
117}
118
119#[pin_project(project = EndpointProj)]
120#[derive(Debug)]
121pub(crate) enum Endpoint {
122    Plain(Option<TcpStream>),
123    #[cfg(feature = "native-tls-tls")]
124    Secure(#[pin] tokio_native_tls::TlsStream<TcpStream>),
125    #[cfg(feature = "rustls-tls")]
126    Secure(#[pin] tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
127    #[cfg(unix)]
128    Socket(#[pin] Socket),
129}
130
131/// This future will check that TcpStream is live.
132///
133/// This check is similar to a one, implemented by GitHub team for the go-sql-driver/mysql.
134#[derive(Debug)]
135struct CheckTcpStream<'a>(&'a mut TcpStream);
136
137impl Future for CheckTcpStream<'_> {
138    type Output = io::Result<()>;
139    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
140        match self.0.poll_read_ready(cx) {
141            Poll::Ready(Ok(())) => {
142                // stream is readable
143                let mut buf = [0_u8; 1];
144                match self.0.try_read(&mut buf) {
145                    Ok(0) => Poll::Ready(Err(io::Error::new(BrokenPipe, "broken pipe"))),
146                    Ok(_) => Poll::Ready(Err(io::Error::new(Other, "stream should be empty"))),
147                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Ready(Ok(())),
148                    Err(err) => Poll::Ready(Err(err)),
149                }
150            }
151            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
152            Poll::Pending => Poll::Ready(Ok(())),
153        }
154    }
155}
156
157impl Endpoint {
158    #[cfg(unix)]
159    fn is_socket(&self) -> bool {
160        matches!(self, Self::Socket(_))
161    }
162
163    /// Checks, that connection is alive.
164    async fn check(&mut self) -> std::result::Result<(), IoError> {
165        //return Ok(());
166        match self {
167            Endpoint::Plain(Some(stream)) => {
168                CheckTcpStream(stream).await?;
169                Ok(())
170            }
171            #[cfg(feature = "native-tls-tls")]
172            Endpoint::Secure(tls_stream) => {
173                CheckTcpStream(tls_stream.get_mut().get_mut().get_mut()).await?;
174                Ok(())
175            }
176            #[cfg(feature = "rustls-tls")]
177            Endpoint::Secure(tls_stream) => {
178                let stream = tls_stream.get_mut().0;
179                CheckTcpStream(stream).await?;
180                Ok(())
181            }
182            #[cfg(unix)]
183            Endpoint::Socket(socket) => {
184                let _ = socket.write(&[]).await?;
185                Ok(())
186            }
187            Endpoint::Plain(None) => unreachable!(),
188        }
189    }
190
191    #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
192    pub fn is_secure(&self) -> bool {
193        matches!(self, Endpoint::Secure(_))
194    }
195
196    #[cfg(all(not(feature = "native-tls-tls"), not(feature = "rustls")))]
197    pub async fn make_secure(
198        &mut self,
199        _domain: String,
200        _ssl_opts: crate::SslOpts,
201    ) -> crate::error::Result<()> {
202        panic!(
203            "Client had asked for TLS connection but TLS support is disabled. \
204            Please enable one of the following features: [\"native-tls-tls\", \"rustls-tls\"]"
205        )
206    }
207
208    pub fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> {
209        match *self {
210            Endpoint::Plain(Some(ref stream)) => stream.set_nodelay(val)?,
211            Endpoint::Plain(None) => unreachable!(),
212            #[cfg(feature = "native-tls-tls")]
213            Endpoint::Secure(ref stream) => {
214                stream.get_ref().get_ref().get_ref().set_nodelay(val)?
215            }
216            #[cfg(feature = "rustls-tls")]
217            Endpoint::Secure(ref stream) => {
218                let stream = stream.get_ref().0;
219                stream.set_nodelay(val)?;
220            }
221            #[cfg(unix)]
222            Endpoint::Socket(_) => (/* inapplicable */),
223        }
224        Ok(())
225    }
226}
227
228impl From<TcpStream> for Endpoint {
229    fn from(stream: TcpStream) -> Self {
230        Endpoint::Plain(Some(stream))
231    }
232}
233
234#[cfg(unix)]
235impl From<Socket> for Endpoint {
236    fn from(socket: Socket) -> Self {
237        Endpoint::Socket(socket)
238    }
239}
240
241#[cfg(feature = "native-tls-tls")]
242impl From<tokio_native_tls::TlsStream<TcpStream>> for Endpoint {
243    fn from(stream: tokio_native_tls::TlsStream<TcpStream>) -> Self {
244        Endpoint::Secure(stream)
245    }
246}
247
248/* TODO
249#[cfg(feature = "rustls-tls")]
250*/
251
252impl AsyncRead for Endpoint {
253    fn poll_read(
254        self: Pin<&mut Self>,
255        cx: &mut Context<'_>,
256        buf: &mut ReadBuf<'_>,
257    ) -> Poll<std::result::Result<(), tokio::io::Error>> {
258        let mut this = self.project();
259        with_interrupted!(match this {
260            EndpointProj::Plain(ref mut stream) => {
261                Pin::new(stream.as_mut().unwrap()).poll_read(cx, buf)
262            }
263            #[cfg(feature = "native-tls-tls")]
264            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
265            #[cfg(feature = "rustls-tls")]
266            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
267            #[cfg(unix)]
268            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_read(cx, buf),
269        })
270    }
271}
272
273impl AsyncWrite for Endpoint {
274    fn poll_write(
275        self: Pin<&mut Self>,
276        cx: &mut Context,
277        buf: &[u8],
278    ) -> Poll<std::result::Result<usize, tokio::io::Error>> {
279        let mut this = self.project();
280        with_interrupted!(match this {
281            EndpointProj::Plain(ref mut stream) => {
282                Pin::new(stream.as_mut().unwrap()).poll_write(cx, buf)
283            }
284            #[cfg(feature = "native-tls-tls")]
285            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
286            #[cfg(feature = "rustls-tls")]
287            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
288            #[cfg(unix)]
289            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_write(cx, buf),
290        })
291    }
292
293    fn poll_flush(
294        self: Pin<&mut Self>,
295        cx: &mut Context,
296    ) -> Poll<std::result::Result<(), tokio::io::Error>> {
297        let mut this = self.project();
298        with_interrupted!(match this {
299            EndpointProj::Plain(ref mut stream) => {
300                Pin::new(stream.as_mut().unwrap()).poll_flush(cx)
301            }
302            #[cfg(feature = "native-tls-tls")]
303            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
304            #[cfg(feature = "rustls-tls")]
305            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
306            #[cfg(unix)]
307            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_flush(cx),
308        })
309    }
310
311    fn poll_shutdown(
312        self: Pin<&mut Self>,
313        cx: &mut Context,
314    ) -> Poll<std::result::Result<(), tokio::io::Error>> {
315        let mut this = self.project();
316        with_interrupted!(match this {
317            EndpointProj::Plain(ref mut stream) => {
318                Pin::new(stream.as_mut().unwrap()).poll_shutdown(cx)
319            }
320            #[cfg(feature = "native-tls-tls")]
321            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
322            #[cfg(feature = "rustls-tls")]
323            EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
324            #[cfg(unix)]
325            EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_shutdown(cx),
326        })
327    }
328}
329
330/// A Stream, connected to MySql server.
331pub struct Stream {
332    closed: bool,
333    pub(crate) codec: Option<Box<Framed<Endpoint, PacketCodec>>>,
334}
335
336impl fmt::Debug for Stream {
337    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
338        write!(
339            f,
340            "Stream (endpoint={:?})",
341            self.codec.as_ref().unwrap().get_ref()
342        )
343    }
344}
345
346impl Stream {
347    #[cfg(unix)]
348    fn new<T: Into<Endpoint>>(endpoint: T) -> Self {
349        let endpoint = endpoint.into();
350
351        Self {
352            closed: false,
353            codec: Box::new(Framed::new(endpoint, PacketCodec::default())).into(),
354        }
355    }
356
357    pub(crate) async fn connect_tcp(
358        addr: &HostPortOrUrl,
359        keepalive: Option<Duration>,
360    ) -> io::Result<Stream> {
361        let tcp_stream = match addr {
362            HostPortOrUrl::HostPort {
363                host,
364                port,
365                resolved_ips,
366            } => match resolved_ips {
367                Some(ips) => {
368                    let addrs = ips
369                        .iter()
370                        .map(|ip| SocketAddr::new(*ip, *port))
371                        .collect::<Vec<_>>();
372                    TcpStream::connect(&*addrs).await?
373                }
374                None => TcpStream::connect((host.as_str(), *port)).await?,
375            },
376            HostPortOrUrl::Url(url) => {
377                let addrs = url.socket_addrs(|| Some(DEFAULT_PORT))?;
378                TcpStream::connect(&*addrs).await?
379            }
380        };
381
382        #[cfg(any(unix, windows))]
383        if let Some(duration) = keepalive {
384            #[cfg(unix)]
385            let socket = {
386                use std::os::unix::prelude::*;
387                let fd = tcp_stream.as_raw_fd();
388                unsafe { Socket2Socket::from_raw_fd(fd) }
389            };
390            #[cfg(windows)]
391            let socket = {
392                use std::os::windows::prelude::*;
393                let sock = tcp_stream.as_raw_socket();
394                unsafe { Socket2Socket::from_raw_socket(sock) }
395            };
396            socket.set_tcp_keepalive(&TcpKeepalive::new().with_time(duration))?;
397            std::mem::forget(socket);
398        }
399
400        Ok(Stream {
401            closed: false,
402            codec: Box::new(Framed::new(tcp_stream.into(), PacketCodec::default())).into(),
403        })
404    }
405
406    #[cfg(unix)]
407    pub(crate) async fn connect_socket<P: AsRef<Path>>(path: P) -> io::Result<Stream> {
408        Ok(Stream::new(Socket::new(path).await?))
409    }
410
411    pub(crate) fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> {
412        self.codec.as_ref().unwrap().get_ref().set_tcp_nodelay(val)
413    }
414
415    pub(crate) async fn make_secure(
416        &mut self,
417        domain: String,
418        ssl_opts: SslOpts,
419    ) -> crate::error::Result<()> {
420        let codec = self.codec.take().unwrap();
421        let FramedParts { mut io, codec, .. } = codec.into_parts();
422        io.make_secure(domain, ssl_opts).await?;
423        let codec = Framed::new(io, codec);
424        self.codec = Some(Box::new(codec));
425        Ok(())
426    }
427
428    #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
429    pub(crate) fn is_secure(&self) -> bool {
430        self.codec.as_ref().unwrap().get_ref().is_secure()
431    }
432
433    #[cfg(unix)]
434    pub(crate) fn is_socket(&self) -> bool {
435        self.codec.as_ref().unwrap().get_ref().is_socket()
436    }
437
438    pub(crate) fn reset_seq_id(&mut self) {
439        if let Some(codec) = self.codec.as_mut() {
440            codec.codec_mut().reset_seq_id();
441        }
442    }
443
444    pub(crate) fn sync_seq_id(&mut self) {
445        if let Some(codec) = self.codec.as_mut() {
446            codec.codec_mut().sync_seq_id();
447        }
448    }
449
450    pub(crate) fn set_max_allowed_packet(&mut self, max_allowed_packet: usize) {
451        if let Some(codec) = self.codec.as_mut() {
452            codec.codec_mut().max_allowed_packet = max_allowed_packet;
453        }
454    }
455
456    pub(crate) fn compress(&mut self, level: crate::Compression) {
457        if let Some(codec) = self.codec.as_mut() {
458            codec.codec_mut().compress(level);
459        }
460    }
461
462    /// Checks, that connection is alive.
463    pub(crate) async fn check(&mut self) -> std::result::Result<(), IoError> {
464        if let Some(codec) = self.codec.as_mut() {
465            codec.get_mut().check().await?;
466        }
467        Ok(())
468    }
469
470    pub(crate) async fn close(mut self) -> std::result::Result<(), IoError> {
471        self.closed = true;
472        if let Some(mut codec) = self.codec {
473            use futures_sink::Sink;
474            futures_util::future::poll_fn(|cx| match Pin::new(&mut *codec).poll_close(cx) {
475                Poll::Ready(Err(IoError::Io(err))) if err.kind() == NotConnected => {
476                    Poll::Ready(Ok(()))
477                }
478                x => x,
479            })
480            .await?;
481        }
482        Ok(())
483    }
484}
485
486impl stream::Stream for Stream {
487    type Item = std::result::Result<PooledBuf, IoError>;
488
489    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
490        if !self.closed {
491            let item = ready!(Pin::new(self.codec.as_mut().unwrap()).poll_next(cx)).transpose()?;
492            Poll::Ready(Ok(item).transpose())
493        } else {
494            Poll::Ready(None)
495        }
496    }
497}
498
499#[cfg(test)]
500mod test {
501    #[cfg(unix)] // no sane way to retrieve current keepalive value on windows
502    #[tokio::test]
503    async fn should_connect_with_keepalive() {
504        use crate::{test_misc::get_opts, Conn};
505
506        let opts = get_opts()
507            .tcp_keepalive(Some(42_000_u32))
508            .prefer_socket(false);
509        let mut conn: Conn = Conn::new(opts).await.unwrap();
510        let stream = conn.stream_mut().unwrap();
511        let endpoint = stream.codec.as_mut().unwrap().get_ref();
512        let stream = match endpoint {
513            super::Endpoint::Plain(Some(stream)) => stream,
514            #[cfg(feature = "rustls-tls")]
515            super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0,
516            #[cfg(feature = "native-tls-tls")]
517            super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().get_ref().get_ref(),
518            _ => unreachable!(),
519        };
520        let sock = unsafe {
521            use std::os::unix::prelude::*;
522            let raw = stream.as_raw_fd();
523            socket2::Socket::from_raw_fd(raw)
524        };
525
526        assert_eq!(
527            sock.keepalive_time().unwrap(),
528            std::time::Duration::from_millis(42_000),
529        );
530
531        std::mem::forget(sock);
532
533        conn.disconnect().await.unwrap();
534    }
535}