1pub use self::{read_packet::ReadPacket, write_packet::WritePacket};
10
11use bytes::BytesMut;
12use futures_core::{ready, stream};
13use mysql_common::proto::codec::PacketCodec as PacketCodecInner;
14use pin_project::pin_project;
15#[cfg(any(unix, windows))]
16use socket2::{Socket as Socket2Socket, TcpKeepalive};
17#[cfg(unix)]
18use tokio::io::AsyncWriteExt;
19use tokio::{
20 io::{AsyncRead, AsyncWrite, ErrorKind::Interrupted, ReadBuf},
21 net::TcpStream,
22};
23use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts};
24
25#[cfg(unix)]
26use std::path::Path;
27use std::{
28 fmt,
29 future::Future,
30 io::{
31 self,
32 ErrorKind::{BrokenPipe, NotConnected, Other},
33 },
34 mem::replace,
35 net::SocketAddr,
36 ops::{Deref, DerefMut},
37 pin::Pin,
38 task::{Context, Poll},
39 time::Duration,
40};
41
42use crate::{
43 buffer_pool::PooledBuf,
44 error::IoError,
45 opts::{HostPortOrUrl, SslOpts, DEFAULT_PORT},
46};
47
48#[cfg(unix)]
49use crate::io::socket::Socket;
50
51mod tls;
52
53macro_rules! with_interrupted {
54 ($e:expr) => {
55 loop {
56 match $e {
57 Poll::Ready(Err(err)) if err.kind() == Interrupted => continue,
58 x => break x,
59 }
60 }
61 };
62}
63
64mod read_packet;
65mod socket;
66mod write_packet;
67
/// Tokio codec for MySQL packets.
///
/// Wraps `mysql_common`'s packet codec and hands each decoded packet out in a
/// buffer obtained from the crate's buffer pool.
#[derive(Debug)]
pub struct PacketCodec {
    /// `mysql_common` codec state; reached through `Deref`/`DerefMut` for
    /// sequence-id handling, `max_allowed_packet` and compression settings.
    inner: PacketCodecInner,
    /// Buffer passed to the inner decoder; handed out and replaced with a fresh
    /// pooled buffer whenever `decode` reports a complete packet.
    decode_buf: PooledBuf,
}
73
74impl Default for PacketCodec {
75 fn default() -> Self {
76 Self {
77 inner: Default::default(),
78 decode_buf: crate::buffer_pool().get(),
79 }
80 }
81}
82
83impl Deref for PacketCodec {
84 type Target = PacketCodecInner;
85
86 fn deref(&self) -> &Self::Target {
87 &self.inner
88 }
89}
90
91impl DerefMut for PacketCodec {
92 fn deref_mut(&mut self) -> &mut Self::Target {
93 &mut self.inner
94 }
95}
96
97impl Decoder for PacketCodec {
98 type Item = PooledBuf;
99 type Error = IoError;
100
101 fn decode(&mut self, src: &mut BytesMut) -> std::result::Result<Option<Self::Item>, IoError> {
102 if self.inner.decode(src, self.decode_buf.as_mut())? {
103 let new_buf = crate::buffer_pool().get();
104 Ok(Some(replace(&mut self.decode_buf, new_buf)))
105 } else {
106 Ok(None)
107 }
108 }
109}
110
111impl Encoder<PooledBuf> for PacketCodec {
112 type Error = IoError;
113
114 fn encode(&mut self, item: PooledBuf, dst: &mut BytesMut) -> std::result::Result<(), IoError> {
115 Ok(self.inner.encode(&mut item.as_ref(), dst)?)
116 }
117}
118
/// Transport underlying a `Stream`: plain TCP, TLS over TCP, or (on Unix) a
/// path-addressed local socket.
#[pin_project(project = EndpointProj)]
#[derive(Debug)]
pub(crate) enum Endpoint {
    /// Unencrypted TCP. Code in this module treats `Plain(None)` as unreachable.
    /// NOTE(review): presumably `None` only exists transiently while the stream is
    /// moved out during a TLS upgrade — confirm in the `tls` module.
    Plain(Option<TcpStream>),
    /// TLS over TCP using the `native-tls` backend.
    #[cfg(feature = "native-tls-tls")]
    Secure(#[pin] tokio_native_tls::TlsStream<TcpStream>),
    /// TLS over TCP using the `rustls` backend.
    #[cfg(feature = "rustls-tls")]
    Secure(#[pin] tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
    /// Local socket opened from a filesystem path (see `Stream::connect_socket`).
    #[cfg(unix)]
    Socket(#[pin] Socket),
}
130
/// Future that probes a borrowed `TcpStream` for liveness without waiting.
///
/// Resolves to `Ok(())` when no data is readable right now, and to an error when
/// the peer has closed the connection or unexpected bytes are waiting.
#[derive(Debug)]
struct CheckTcpStream<'a>(&'a mut TcpStream);
136
137impl Future for CheckTcpStream<'_> {
138 type Output = io::Result<()>;
139 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
140 match self.0.poll_read_ready(cx) {
141 Poll::Ready(Ok(())) => {
142 let mut buf = [0_u8; 1];
144 match self.0.try_read(&mut buf) {
145 Ok(0) => Poll::Ready(Err(io::Error::new(BrokenPipe, "broken pipe"))),
146 Ok(_) => Poll::Ready(Err(io::Error::new(Other, "stream should be empty"))),
147 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Ready(Ok(())),
148 Err(err) => Poll::Ready(Err(err)),
149 }
150 }
151 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
152 Poll::Pending => Poll::Ready(Ok(())),
153 }
154 }
155}
156
157impl Endpoint {
158 #[cfg(unix)]
159 fn is_socket(&self) -> bool {
160 matches!(self, Self::Socket(_))
161 }
162
163 async fn check(&mut self) -> std::result::Result<(), IoError> {
165 match self {
167 Endpoint::Plain(Some(stream)) => {
168 CheckTcpStream(stream).await?;
169 Ok(())
170 }
171 #[cfg(feature = "native-tls-tls")]
172 Endpoint::Secure(tls_stream) => {
173 CheckTcpStream(tls_stream.get_mut().get_mut().get_mut()).await?;
174 Ok(())
175 }
176 #[cfg(feature = "rustls-tls")]
177 Endpoint::Secure(tls_stream) => {
178 let stream = tls_stream.get_mut().0;
179 CheckTcpStream(stream).await?;
180 Ok(())
181 }
182 #[cfg(unix)]
183 Endpoint::Socket(socket) => {
184 let _ = socket.write(&[]).await?;
185 Ok(())
186 }
187 Endpoint::Plain(None) => unreachable!(),
188 }
189 }
190
191 #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
192 pub fn is_secure(&self) -> bool {
193 matches!(self, Endpoint::Secure(_))
194 }
195
196 #[cfg(all(not(feature = "native-tls-tls"), not(feature = "rustls")))]
197 pub async fn make_secure(
198 &mut self,
199 _domain: String,
200 _ssl_opts: crate::SslOpts,
201 ) -> crate::error::Result<()> {
202 panic!(
203 "Client had asked for TLS connection but TLS support is disabled. \
204 Please enable one of the following features: [\"native-tls-tls\", \"rustls-tls\"]"
205 )
206 }
207
208 pub fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> {
209 match *self {
210 Endpoint::Plain(Some(ref stream)) => stream.set_nodelay(val)?,
211 Endpoint::Plain(None) => unreachable!(),
212 #[cfg(feature = "native-tls-tls")]
213 Endpoint::Secure(ref stream) => {
214 stream.get_ref().get_ref().get_ref().set_nodelay(val)?
215 }
216 #[cfg(feature = "rustls-tls")]
217 Endpoint::Secure(ref stream) => {
218 let stream = stream.get_ref().0;
219 stream.set_nodelay(val)?;
220 }
221 #[cfg(unix)]
222 Endpoint::Socket(_) => (),
223 }
224 Ok(())
225 }
226}
227
228impl From<TcpStream> for Endpoint {
229 fn from(stream: TcpStream) -> Self {
230 Endpoint::Plain(Some(stream))
231 }
232}
233
234#[cfg(unix)]
235impl From<Socket> for Endpoint {
236 fn from(socket: Socket) -> Self {
237 Endpoint::Socket(socket)
238 }
239}
240
#[cfg(feature = "native-tls-tls")]
impl From<tokio_native_tls::TlsStream<TcpStream>> for Endpoint {
    /// Wraps an established `native-tls` session as a secure endpoint.
    fn from(tls: tokio_native_tls::TlsStream<TcpStream>) -> Self {
        Self::Secure(tls)
    }
}
247
248impl AsyncRead for Endpoint {
253 fn poll_read(
254 self: Pin<&mut Self>,
255 cx: &mut Context<'_>,
256 buf: &mut ReadBuf<'_>,
257 ) -> Poll<std::result::Result<(), tokio::io::Error>> {
258 let mut this = self.project();
259 with_interrupted!(match this {
260 EndpointProj::Plain(ref mut stream) => {
261 Pin::new(stream.as_mut().unwrap()).poll_read(cx, buf)
262 }
263 #[cfg(feature = "native-tls-tls")]
264 EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
265 #[cfg(feature = "rustls-tls")]
266 EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
267 #[cfg(unix)]
268 EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_read(cx, buf),
269 })
270 }
271}
272
273impl AsyncWrite for Endpoint {
274 fn poll_write(
275 self: Pin<&mut Self>,
276 cx: &mut Context,
277 buf: &[u8],
278 ) -> Poll<std::result::Result<usize, tokio::io::Error>> {
279 let mut this = self.project();
280 with_interrupted!(match this {
281 EndpointProj::Plain(ref mut stream) => {
282 Pin::new(stream.as_mut().unwrap()).poll_write(cx, buf)
283 }
284 #[cfg(feature = "native-tls-tls")]
285 EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
286 #[cfg(feature = "rustls-tls")]
287 EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
288 #[cfg(unix)]
289 EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_write(cx, buf),
290 })
291 }
292
293 fn poll_flush(
294 self: Pin<&mut Self>,
295 cx: &mut Context,
296 ) -> Poll<std::result::Result<(), tokio::io::Error>> {
297 let mut this = self.project();
298 with_interrupted!(match this {
299 EndpointProj::Plain(ref mut stream) => {
300 Pin::new(stream.as_mut().unwrap()).poll_flush(cx)
301 }
302 #[cfg(feature = "native-tls-tls")]
303 EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
304 #[cfg(feature = "rustls-tls")]
305 EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
306 #[cfg(unix)]
307 EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_flush(cx),
308 })
309 }
310
311 fn poll_shutdown(
312 self: Pin<&mut Self>,
313 cx: &mut Context,
314 ) -> Poll<std::result::Result<(), tokio::io::Error>> {
315 let mut this = self.project();
316 with_interrupted!(match this {
317 EndpointProj::Plain(ref mut stream) => {
318 Pin::new(stream.as_mut().unwrap()).poll_shutdown(cx)
319 }
320 #[cfg(feature = "native-tls-tls")]
321 EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
322 #[cfg(feature = "rustls-tls")]
323 EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
324 #[cfg(unix)]
325 EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_shutdown(cx),
326 })
327 }
328}
329
/// Framed MySQL packet stream over an `Endpoint`.
pub struct Stream {
    /// When `true`, `poll_next` yields `None` without polling the codec. Set by
    /// `close`. NOTE(review): `close` consumes `self`, so in the code visible here
    /// the flag is never observed as `true` — confirm whether anything else sets it.
    closed: bool,
    /// Framed transport. Temporarily `None` while `make_secure` rebuilds it, and
    /// left `None` if that upgrade fails; most accessors `unwrap` it.
    pub(crate) codec: Option<Box<Framed<Endpoint, PacketCodec>>>,
}
335
336impl fmt::Debug for Stream {
337 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
338 write!(
339 f,
340 "Stream (endpoint={:?})",
341 self.codec.as_ref().unwrap().get_ref()
342 )
343 }
344}
345
346impl Stream {
347 #[cfg(unix)]
348 fn new<T: Into<Endpoint>>(endpoint: T) -> Self {
349 let endpoint = endpoint.into();
350
351 Self {
352 closed: false,
353 codec: Box::new(Framed::new(endpoint, PacketCodec::default())).into(),
354 }
355 }
356
357 pub(crate) async fn connect_tcp(
358 addr: &HostPortOrUrl,
359 keepalive: Option<Duration>,
360 ) -> io::Result<Stream> {
361 let tcp_stream = match addr {
362 HostPortOrUrl::HostPort {
363 host,
364 port,
365 resolved_ips,
366 } => match resolved_ips {
367 Some(ips) => {
368 let addrs = ips
369 .iter()
370 .map(|ip| SocketAddr::new(*ip, *port))
371 .collect::<Vec<_>>();
372 TcpStream::connect(&*addrs).await?
373 }
374 None => TcpStream::connect((host.as_str(), *port)).await?,
375 },
376 HostPortOrUrl::Url(url) => {
377 let addrs = url.socket_addrs(|| Some(DEFAULT_PORT))?;
378 TcpStream::connect(&*addrs).await?
379 }
380 };
381
382 #[cfg(any(unix, windows))]
383 if let Some(duration) = keepalive {
384 #[cfg(unix)]
385 let socket = {
386 use std::os::unix::prelude::*;
387 let fd = tcp_stream.as_raw_fd();
388 unsafe { Socket2Socket::from_raw_fd(fd) }
389 };
390 #[cfg(windows)]
391 let socket = {
392 use std::os::windows::prelude::*;
393 let sock = tcp_stream.as_raw_socket();
394 unsafe { Socket2Socket::from_raw_socket(sock) }
395 };
396 socket.set_tcp_keepalive(&TcpKeepalive::new().with_time(duration))?;
397 std::mem::forget(socket);
398 }
399
400 Ok(Stream {
401 closed: false,
402 codec: Box::new(Framed::new(tcp_stream.into(), PacketCodec::default())).into(),
403 })
404 }
405
406 #[cfg(unix)]
407 pub(crate) async fn connect_socket<P: AsRef<Path>>(path: P) -> io::Result<Stream> {
408 Ok(Stream::new(Socket::new(path).await?))
409 }
410
411 pub(crate) fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> {
412 self.codec.as_ref().unwrap().get_ref().set_tcp_nodelay(val)
413 }
414
415 pub(crate) async fn make_secure(
416 &mut self,
417 domain: String,
418 ssl_opts: SslOpts,
419 ) -> crate::error::Result<()> {
420 let codec = self.codec.take().unwrap();
421 let FramedParts { mut io, codec, .. } = codec.into_parts();
422 io.make_secure(domain, ssl_opts).await?;
423 let codec = Framed::new(io, codec);
424 self.codec = Some(Box::new(codec));
425 Ok(())
426 }
427
428 #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
429 pub(crate) fn is_secure(&self) -> bool {
430 self.codec.as_ref().unwrap().get_ref().is_secure()
431 }
432
433 #[cfg(unix)]
434 pub(crate) fn is_socket(&self) -> bool {
435 self.codec.as_ref().unwrap().get_ref().is_socket()
436 }
437
438 pub(crate) fn reset_seq_id(&mut self) {
439 if let Some(codec) = self.codec.as_mut() {
440 codec.codec_mut().reset_seq_id();
441 }
442 }
443
444 pub(crate) fn sync_seq_id(&mut self) {
445 if let Some(codec) = self.codec.as_mut() {
446 codec.codec_mut().sync_seq_id();
447 }
448 }
449
450 pub(crate) fn set_max_allowed_packet(&mut self, max_allowed_packet: usize) {
451 if let Some(codec) = self.codec.as_mut() {
452 codec.codec_mut().max_allowed_packet = max_allowed_packet;
453 }
454 }
455
456 pub(crate) fn compress(&mut self, level: crate::Compression) {
457 if let Some(codec) = self.codec.as_mut() {
458 codec.codec_mut().compress(level);
459 }
460 }
461
462 pub(crate) async fn check(&mut self) -> std::result::Result<(), IoError> {
464 if let Some(codec) = self.codec.as_mut() {
465 codec.get_mut().check().await?;
466 }
467 Ok(())
468 }
469
470 pub(crate) async fn close(mut self) -> std::result::Result<(), IoError> {
471 self.closed = true;
472 if let Some(mut codec) = self.codec {
473 use futures_sink::Sink;
474 futures_util::future::poll_fn(|cx| match Pin::new(&mut *codec).poll_close(cx) {
475 Poll::Ready(Err(IoError::Io(err))) if err.kind() == NotConnected => {
476 Poll::Ready(Ok(()))
477 }
478 x => x,
479 })
480 .await?;
481 }
482 Ok(())
483 }
484}
485
486impl stream::Stream for Stream {
487 type Item = std::result::Result<PooledBuf, IoError>;
488
489 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
490 if !self.closed {
491 let item = ready!(Pin::new(self.codec.as_mut().unwrap()).poll_next(cx)).transpose()?;
492 Poll::Ready(Ok(item).transpose())
493 } else {
494 Poll::Ready(None)
495 }
496 }
497}
498
#[cfg(test)]
mod test {
    /// Connects over TCP with a 42 s keepalive and checks that the option reached
    /// the OS socket.
    #[cfg(unix)]
    #[tokio::test]
    async fn should_connect_with_keepalive() {
        use crate::{test_misc::get_opts, Conn};

        let opts = get_opts()
            .tcp_keepalive(Some(42_000_u32))
            .prefer_socket(false);
        let mut conn: Conn = Conn::new(opts).await.unwrap();
        let stream = conn.stream_mut().unwrap();
        let endpoint = stream.codec.as_ref().unwrap().get_ref();
        let stream = match endpoint {
            super::Endpoint::Plain(Some(stream)) => stream,
            #[cfg(feature = "rustls-tls")]
            super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0,
            #[cfg(feature = "native-tls-tls")]
            super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().get_ref().get_ref(),
            _ => unreachable!(),
        };
        // View the descriptor through socket2 without taking ownership.
        // `ManuallyDrop` keeps it open even if the assertion below panics (a
        // trailing `mem::forget` would be skipped during unwinding and the fd would
        // be closed twice).
        // SAFETY: the fd belongs to `stream`, which stays alive for this scope.
        let sock = std::mem::ManuallyDrop::new(unsafe {
            use std::os::unix::prelude::*;
            socket2::Socket::from_raw_fd(stream.as_raw_fd())
        });

        assert_eq!(
            sock.keepalive_time().unwrap(),
            std::time::Duration::from_millis(42_000),
        );

        conn.disconnect().await.unwrap();
    }
}