use std::pin::Pin;
use std::task::{Context, Poll};
use async_trait::async_trait;
use mz_ore::netio::AsyncReady;
use mz_server_core::TlsMode;
use tokio::io::{self, AsyncRead, AsyncWrite, Interest, ReadBuf, Ready};
use tokio_openssl::SslStream;
use tokio_postgres::error::SqlState;
use crate::ErrorResponse;
/// Key name `mz_connection_uuid`, associated with a per-connection UUID.
///
/// NOTE(review): presumably used as a startup/session parameter name through
/// which a connection's UUID is communicated; confirm against callers.
pub const CONN_UUID_KEY: &str = "mz_connection_uuid";
/// Key name `mz_forwarded_for`.
///
/// NOTE(review): by analogy with HTTP `X-Forwarded-For`, this presumably
/// carries the originating client's address when a connection is relayed by an
/// intermediary; confirm against callers before relying on that meaning.
pub const MZ_FORWARDED_FOR_KEY: &str = "mz_forwarded_for";
/// A connection over transport `A` that is either plaintext or TLS-encrypted.
///
/// The `AsyncRead`, `AsyncWrite`, and `AsyncReady` impls for this type
/// delegate to whichever stream the active variant holds, so callers can treat
/// both cases uniformly.
#[derive(Debug)]
pub enum Conn<A> {
    /// Plaintext connection directly over the raw transport.
    Unencrypted(A),
    /// Connection encrypted with TLS (via `tokio_openssl`) layered on top of
    /// the raw transport.
    Ssl(SslStream<A>),
}
impl<A> Conn<A> {
    /// Returns a mutable reference to the raw transport underneath this
    /// connection.
    ///
    /// For [`Conn::Ssl`] this is the stream beneath the TLS layer (obtained via
    /// `SslStream::get_mut`), not the encrypted stream itself.
    pub fn inner_mut(&mut self) -> &mut A {
        match self {
            Conn::Ssl(tls_stream) => tls_stream.get_mut(),
            Conn::Unencrypted(raw) => raw,
        }
    }

    /// Checks that this connection's encryption state is acceptable under the
    /// configured `tls_mode`.
    ///
    /// # Errors
    ///
    /// Returns a fatal [`ErrorResponse`] with SQLSTATE
    /// `SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION` when `tls_mode` is
    /// [`TlsMode::Require`] but the connection is unencrypted.
    ///
    /// # Panics
    ///
    /// Panics if the connection is encrypted while `tls_mode` is `None`. That
    /// combination is treated as impossible.
    pub fn ensure_tls_compatibility(
        &self,
        tls_mode: &Option<TlsMode>,
    ) -> Result<(), ErrorResponse> {
        match self {
            Conn::Ssl(_) => match tls_mode {
                // NOTE(review): presumably callers only negotiate TLS when a
                // TLS mode is configured; confirm before relying on this.
                None => unreachable!(),
                Some(TlsMode::Allow) | Some(TlsMode::Require) => Ok(()),
            },
            Conn::Unencrypted(_) => match tls_mode {
                Some(TlsMode::Require) => Err(ErrorResponse::fatal(
                    SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                    "TLS encryption is required",
                )),
                None | Some(TlsMode::Allow) => Ok(()),
            },
        }
    }
}
impl<A> AsyncRead for Conn<A>
where
    A: AsyncRead + AsyncWrite + Unpin,
{
    /// Reads from whichever stream backs this connection: the TLS stream for
    /// [`Conn::Ssl`], or the raw transport for [`Conn::Unencrypted`].
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf,
    ) -> Poll<io::Result<()>> {
        // The inner streams are re-pinned with `Pin::new`, which requires them
        // to be `Unpin`. That lets us safely unwrap the outer pin here.
        let this = Pin::into_inner(self);
        match this {
            Conn::Ssl(tls_stream) => Pin::new(tls_stream).poll_read(cx, buf),
            Conn::Unencrypted(raw) => Pin::new(raw).poll_read(cx, buf),
        }
    }
}
impl<A> AsyncWrite for Conn<A>
where
    A: AsyncRead + AsyncWrite + Unpin,
{
    /// Writes `buf` to whichever stream backs this connection.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            Conn::Unencrypted(inner) => Pin::new(inner).poll_write(cx, buf),
            Conn::Ssl(inner) => Pin::new(inner).poll_write(cx, buf),
        }
    }

    /// Forwards a vectored write to the active stream.
    ///
    /// Without this override, the trait's default would write only the first
    /// non-empty slice per call. That would hide scatter/gather support from
    /// an inner transport that has it.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            Conn::Unencrypted(inner) => Pin::new(inner).poll_write_vectored(cx, bufs),
            Conn::Ssl(inner) => Pin::new(inner).poll_write_vectored(cx, bufs),
        }
    }

    /// Reports whether the active stream has an efficient vectored-write
    /// implementation.
    ///
    /// This must agree with `poll_write_vectored` above. Otherwise callers that
    /// consult it would never attempt vectored writes.
    fn is_write_vectored(&self) -> bool {
        match self {
            Conn::Unencrypted(inner) => inner.is_write_vectored(),
            Conn::Ssl(inner) => inner.is_write_vectored(),
        }
    }

    /// Flushes whichever stream backs this connection.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Conn::Unencrypted(inner) => Pin::new(inner).poll_flush(cx),
            Conn::Ssl(inner) => Pin::new(inner).poll_flush(cx),
        }
    }

    /// Shuts down whichever stream backs this connection.
    ///
    /// For [`Conn::Ssl`] this delegates to the TLS stream's own shutdown rather
    /// than to the raw transport.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Conn::Unencrypted(inner) => Pin::new(inner).poll_shutdown(cx),
            Conn::Ssl(inner) => Pin::new(inner).poll_shutdown(cx),
        }
    }
}
#[async_trait]
impl<A> AsyncReady for Conn<A>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Sync + Unpin,
{
    /// Delegates the readiness wait for `interest` to the stream backing this
    /// connection.
    ///
    /// For [`Conn::Ssl`] this uses the `SslStream`'s own `AsyncReady` impl. For
    /// [`Conn::Unencrypted`] it uses the raw transport's impl.
    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
        let readiness = match self {
            Conn::Ssl(tls_stream) => tls_stream.ready(interest).await?,
            Conn::Unencrypted(raw) => raw.ready(interest).await?,
        };
        Ok(readiness)
    }
}