#[cfg(any(
feature = "rustls",
feature = "native-tls",
feature = "vendored-openssl"
))]
use crate::client::{tls::TlsPreloginWrapper, tls_stream::create_tls_stream};
use crate::{
client::{tls::MaybeTlsStream, AuthMethod, Config},
tds::{
codec::{
self, Encode, LoginMessage, Packet, PacketCodec, PacketHeader, PacketStatus,
PreloginMessage, TokenDone,
},
stream::TokenStream,
Context, HEADER_BYTES,
},
EncryptionLevel, SqlReadBytes,
};
use asynchronous_codec::Framed;
use bytes::BytesMut;
#[cfg(any(windows, feature = "integrated-auth-gssapi"))]
use codec::TokenSspi;
use futures::{ready, AsyncRead, AsyncWrite, SinkExt, Stream, TryStream, TryStreamExt};
#[cfg(all(unix, feature = "integrated-auth-gssapi"))]
use libgssapi::{
context::{ClientCtx, CtxFlags},
credential::{Cred, CredUsage},
name::Name,
oid::{OidSet, GSS_MECH_KRB5, GSS_NT_KRB5_PRINCIPAL},
};
use pretty_hex::*;
#[cfg(all(unix, feature = "integrated-auth-gssapi"))]
use std::ops::Deref;
use std::{cmp, fmt::Debug, io, pin::Pin, task};
use task::Poll;
use tracing::{event, Level};
#[cfg(all(windows, feature = "winauth"))]
use winauth::{windows::NtlmSspiBuilder, NextBytes};
/// A packet-oriented connection to the server on top of the byte stream `S`.
///
/// Outgoing messages are split into packets by `send`; incoming packets are
/// exposed both as a `Stream` of packets and as an `AsyncRead` over the
/// concatenated packet payloads.
pub(crate) struct Connection<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
/// Framed transport that encodes/decodes packets with `PacketCodec`; the
/// inner stream is either the raw stream or a TLS-wrapped one.
transport: Framed<MaybeTlsStream<S>, PacketCodec>,
/// Set from `Packet::is_last()` whenever a packet is read, and reset to
/// `false` whenever something is written to the wire.
flushed: bool,
/// Per-connection state read by this type: the SPN, the packet size and the
/// packet id counter.
context: Context,
/// Payload bytes of received packets that `poll_read` has not handed out yet.
buf: BytesMut,
}
// Diagnostic formatting: the transport is rendered as a fixed placeholder and
// the pending read buffer as a hex dump.
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Debug for Connection<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pending = self.buf.as_ref().hex_dump();

        let mut out = f.debug_struct("Connection");
        out.field("transport", &"Framed<..>");
        out.field("flushed", &self.flushed);
        out.field("context", &self.context);
        out.field("buf", &pending);
        out.finish()
    }
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connection<S> {
/// Creates a connection over `tcp_stream` and logs in using `config`.
///
/// The steps run strictly in this order: prelogin exchange, TLS handshake
/// with the encryption level negotiated from the prelogin response, login,
/// and finally reading the server's reply through `flush_done`.
///
/// # Errors
///
/// Returns the first error produced by any of those steps.
pub(crate) async fn connect(config: Config, tcp_stream: S) -> crate::Result<Connection<S>> {
    let mut context = Context::new();
    context.set_spn(config.get_host(), config.get_port());

    // The prelogin exchange always happens on the raw, unencrypted stream.
    let mut conn = Self {
        transport: Framed::new(MaybeTlsStream::Raw(tcp_stream), PacketCodec),
        context,
        flushed: false,
        buf: BytesMut::new(),
    };

    let wants_fed_auth = matches!(config.auth, AuthMethod::AADToken(_));
    let prelogin = conn.prelogin(config.encryption, wants_fed_auth).await?;
    let encryption = prelogin.negotiated_encryption(config.encryption);

    let conn = conn.tls_handshake(&config, encryption).await?;

    let login = conn.login(
        config.auth,
        encryption,
        config.database,
        config.host,
        config.application_name,
        prelogin,
    );
    let mut conn = login.await?;

    conn.flush_done().await?;
    Ok(conn)
}
/// Wraps this connection in a `TokenStream` and returns the `TokenDone`
/// produced by `TokenStream::flush_done`.
async fn flush_done(&mut self) -> crate::Result<TokenDone> {
TokenStream::new(self).flush_done().await
}
/// Wraps this connection in a `TokenStream` and returns the `TokenSspi`
/// produced by `TokenStream::flush_sspi`.
///
/// Only compiled on Windows or with the `integrated-auth-gssapi` feature.
#[cfg(any(windows, feature = "integrated-auth-gssapi"))]
async fn flush_sspi(&mut self) -> crate::Result<TokenSspi> {
TokenStream::new(self).flush_sspi().await
}
/// Applies the negotiated encryption level once the login packet is sent.
///
/// For `EncryptionLevel::Off` the framed transport is taken apart and
/// rebuilt around `MaybeTlsStream::Raw`, so everything that follows is sent
/// unencrypted. For every other level the connection is returned untouched.
#[cfg(any(
    feature = "rustls",
    feature = "native-tls",
    feature = "vendored-openssl"
))]
fn post_login_encryption(self, encryption: EncryptionLevel) -> Self {
    if !matches!(encryption, EncryptionLevel::Off) {
        return self;
    }

    event!(
        Level::WARN,
        "Turning TLS off after a login. All traffic from here on is not encrypted.",
    );

    let Self {
        transport,
        context,
        flushed,
        buf,
    } = self;

    // First `into_inner` unwraps the `Framed`, the second one the `MaybeTlsStream`.
    let tcp = transport.into_inner().into_inner();

    Self {
        transport: Framed::new(MaybeTlsStream::Raw(tcp), PacketCodec),
        context,
        flushed,
        buf,
    }
}
/// Variant used when no TLS feature is compiled in: returns the connection
/// unchanged and ignores the encryption level.
#[cfg(not(any(
feature = "rustls",
feature = "native-tls",
feature = "vendored-openssl"
)))]
fn post_login_encryption(self, _: EncryptionLevel) -> Self {
self
}
/// Encodes `item` and sends it as one message, split into as many packets as
/// the context's packet size requires.
///
/// Every packet reuses `header`; only its status changes: `NormalMessage` for
/// all packets but the last, `EndOfMessage` for the last one. The sink is
/// flushed after the final packet. An item that encodes to zero bytes
/// produces no packet at all, only the flush.
///
/// # Errors
///
/// Returns an error if encoding, writing a packet or flushing fails.
pub async fn send<E>(&mut self, mut header: PacketHeader, item: E) -> crate::Result<()>
where
    E: Sized + Encode<BytesMut>,
{
    self.flushed = false;

    // Room left for payload in one packet once the header is accounted for.
    let max_payload = (self.context.packet_size() as usize) - HEADER_BYTES;

    let mut remaining = BytesMut::new();
    item.encode(&mut remaining)?;

    while !remaining.is_empty() {
        let take = remaining.len().min(max_payload);
        let chunk = remaining.split_to(take);

        let status = if remaining.is_empty() {
            PacketStatus::EndOfMessage
        } else {
            PacketStatus::NormalMessage
        };
        header.set_status(status);

        event!(
            Level::TRACE,
            "Sending a packet ({} bytes)",
            chunk.len() + HEADER_BYTES,
        );

        self.write_to_wire(header, chunk).await?;
    }

    self.flush_sink().await
}
/// Sends a single packet built from `header` and `data` through the framed
/// transport, and marks the connection as not flushed.
pub(crate) async fn write_to_wire(
    &mut self,
    header: PacketHeader,
    data: BytesMut,
) -> crate::Result<()> {
    self.flushed = false;

    self.transport.send(Packet::new(header, data)).await?;

    Ok(())
}
/// Flushes the write half of the framed transport.
pub(crate) async fn flush_sink(&mut self) -> crate::Result<()> {
self.transport.flush().await
}
/// Discards whatever is left of the current incoming message.
///
/// The local read buffer is always emptied. If the last packet of the
/// message has not been seen yet, packets are read and dropped (logging a
/// warning for each) until one reports `is_last()` or the stream ends.
///
/// # Errors
///
/// Returns an error if reading a packet from the transport fails.
pub async fn flush_stream(&mut self) -> crate::Result<()> {
    self.buf.clear();

    if !self.flushed {
        while let Some(packet) = self.try_next().await? {
            event!(
                Level::WARN,
                "Flushing unhandled packet from the wire. Please consume your streams!",
            );

            if packet.is_last() {
                break;
            }
        }
    }

    Ok(())
}
/// Returns `true` when the last packet of the current message has been read
/// and no buffered payload bytes remain.
pub fn is_eof(&self) -> bool {
    if !self.flushed {
        return false;
    }

    self.buf.is_empty()
}
/// Sends a prelogin message carrying the requested `encryption` level and
/// the `fed_auth_required` flag, then returns the server's prelogin reply.
///
/// In debug builds this asserts that the reply's `thread_id` is zero.
async fn prelogin(
    &mut self,
    encryption: EncryptionLevel,
    fed_auth_required: bool,
) -> crate::Result<PreloginMessage> {
    let request = {
        let mut message = PreloginMessage::new();
        message.encryption = encryption;
        message.fed_auth_required = fed_auth_required;
        message
    };

    let header = PacketHeader::pre_login(self.context.next_packet_id());
    self.send(header, request).await?;

    let reply: PreloginMessage = codec::collect_from(self).await?;
    debug_assert_eq!(reply.thread_id, 0);

    Ok(reply)
}
/// Builds and sends the login message for the chosen `auth` method and
/// returns the connection afterwards.
///
/// The optional database, server and application names are copied into the
/// login message when present. Every arm then follows the same order: send
/// the login packet, call `post_login_encryption` with the negotiated
/// `encryption`, and, for the SSPI-based methods, read one SSPI token from the
/// server and answer it with a second login packet.
///
/// The server's final reply is not consumed here; `connect` does that via
/// `flush_done`.
///
/// NOTE(review): the lifetime parameter `'a` is not used by the signature.
async fn login<'a>(
mut self,
auth: AuthMethod,
encryption: EncryptionLevel,
db: Option<String>,
server_name: Option<String>,
application_name: Option<String>,
prelogin: PreloginMessage,
) -> crate::Result<Self> {
let mut login_message = LoginMessage::new();
if let Some(db) = db {
login_message.db_name(db);
}
if let Some(server_name) = server_name {
login_message.server_name(server_name);
}
if let Some(app_name) = application_name {
login_message.app_name(app_name);
}
match auth {
// Integrated authentication on Windows through the `winauth` SSPI client.
#[cfg(all(windows, feature = "winauth"))]
AuthMethod::Integrated => {
let mut client = NtlmSspiBuilder::new()
.target_spn(self.context.spn())
.build()?;
// The first client token travels inside the login message itself.
login_message.integrated_security(client.next_bytes(None)?);
let id = self.context.next_packet_id();
self.send(PacketHeader::login(id), login_message).await?;
self = self.post_login_encryption(encryption);
// Feed the server's SSPI token back to the client to get the response.
let sspi_bytes = self.flush_sspi().await?;
match client.next_bytes(Some(sspi_bytes.as_ref()))? {
Some(sspi_response) => {
event!(Level::TRACE, sspi_response_len = sspi_response.len());
let id = self.context.next_packet_id();
let header = PacketHeader::login(id);
let token = TokenSspi::new(sspi_response);
self.send(header, token).await?;
}
// NOTE(review): panics if the client yields no response bytes here.
None => unreachable!(),
}
}
// Integrated authentication on Unix through GSSAPI with the Kerberos 5 mechanism.
#[cfg(all(unix, feature = "integrated-auth-gssapi"))]
AuthMethod::Integrated => {
let mut s = OidSet::new()?;
s.add(&GSS_MECH_KRB5)?;
let client_cred = Cred::acquire(None, None, CredUsage::Initiate, Some(&s))?;
let ctx = ClientCtx::new(
client_cred,
Name::new(self.context.spn().as_bytes(), Some(&GSS_NT_KRB5_PRINCIPAL))?,
CtxFlags::GSS_C_MUTUAL_FLAG | CtxFlags::GSS_C_SEQUENCE_FLAG,
None,
);
let init_token = ctx.step(None)?;
// NOTE(review): `unwrap` panics if the first step returns no token.
login_message.integrated_security(Some(Vec::from(init_token.unwrap().deref())));
let id = self.context.next_packet_id();
self.send(PacketHeader::login(id), login_message).await?;
self = self.post_login_encryption(encryption);
let auth_bytes = self.flush_sspi().await?;
// Unlike the Windows arms, a missing follow-up token is answered with an
// empty SSPI token instead of being treated as unreachable.
let next_token = match ctx.step(Some(auth_bytes.as_ref()))? {
Some(response) => {
event!(Level::TRACE, response_len = response.len());
TokenSspi::new(Vec::from(response.deref()))
}
None => {
event!(Level::TRACE, response_len = 0);
TokenSspi::new(Vec::new())
}
};
let id = self.context.next_packet_id();
let header = PacketHeader::login(id);
self.send(header, next_token).await?;
}
// Explicit Windows credentials (domain, user, password) through the
// `winauth` NTLMv2 client.
#[cfg(all(windows, feature = "winauth"))]
AuthMethod::Windows(auth) => {
let spn = self.context.spn().to_string();
let builder = winauth::NtlmV2ClientBuilder::new().target_spn(spn);
let mut client = builder.build(auth.domain, auth.user, auth.password);
login_message.integrated_security(client.next_bytes(None)?);
let id = self.context.next_packet_id();
self.send(PacketHeader::login(id), login_message).await?;
self = self.post_login_encryption(encryption);
let sspi_bytes = self.flush_sspi().await?;
match client.next_bytes(Some(sspi_bytes.as_ref()))? {
Some(sspi_response) => {
event!(Level::TRACE, sspi_response_len = sspi_response.len());
let id = self.context.next_packet_id();
let header = PacketHeader::login(id);
let token = TokenSspi::new(sspi_response);
self.send(header, token).await?;
}
// NOTE(review): panics if the client yields no response bytes here.
None => unreachable!(),
}
}
// No credentials: the login message is sent as built above.
AuthMethod::None => {
let id = self.context.next_packet_id();
self.send(PacketHeader::login(id), login_message).await?;
self = self.post_login_encryption(encryption);
}
// User name and password carried in the login message.
AuthMethod::SqlServer(auth) => {
login_message.user_name(auth.user());
login_message.password(auth.password());
let id = self.context.next_packet_id();
self.send(PacketHeader::login(id), login_message).await?;
self = self.post_login_encryption(encryption);
}
// AAD token, combined with the fed-auth flag and nonce from the prelogin reply.
AuthMethod::AADToken(token) => {
login_message.aad_token(token, prelogin.fed_auth_required, prelogin.nonce);
let id = self.context.next_packet_id();
self.send(PacketHeader::login(id), login_message).await?;
self = self.post_login_encryption(encryption);
}
}
Ok(self)
}
/// Upgrades the raw transport to TLS unless the negotiated level is
/// `EncryptionLevel::NotSupported`.
///
/// When encryption is not supported a warning is logged and the connection
/// is returned as is. Otherwise the raw stream is wrapped in a
/// `TlsPreloginWrapper`, the TLS stream is created from it, the wrapper is
/// told the handshake is complete, and a fresh connection is built around
/// the TLS stream with an empty read buffer and `flushed` reset.
///
/// # Panics
///
/// Panics if the transport is not `MaybeTlsStream::Raw` when called.
#[cfg(any(
    feature = "rustls",
    feature = "native-tls",
    feature = "vendored-openssl"
))]
async fn tls_handshake(
    self,
    config: &Config,
    encryption: EncryptionLevel,
) -> crate::Result<Self> {
    if encryption == EncryptionLevel::NotSupported {
        event!(
            Level::WARN,
            "TLS encryption is not enabled. All traffic including the login credentials are not encrypted."
        );
        return Ok(self);
    }

    event!(Level::INFO, "Performing a TLS handshake");

    let Self {
        transport, context, ..
    } = self;

    let raw = match transport.into_inner() {
        MaybeTlsStream::Raw(tcp) => tcp,
        _ => unreachable!(),
    };

    let mut tls = create_tls_stream(config, TlsPreloginWrapper::new(raw)).await?;
    tls.get_mut().handshake_complete();
    event!(Level::INFO, "TLS handshake successful");

    Ok(Self {
        transport: Framed::new(MaybeTlsStream::Tls(tls), PacketCodec),
        context,
        flushed: false,
        buf: BytesMut::new(),
    })
}
/// Variant used when no TLS feature is compiled in: logs a warning that
/// traffic is unencrypted and returns the connection unchanged, ignoring both
/// arguments.
#[cfg(not(any(
feature = "rustls",
feature = "native-tls",
feature = "vendored-openssl"
)))]
async fn tls_handshake(self, _: &Config, _: EncryptionLevel) -> crate::Result<Self> {
event!(
Level::WARN,
"TLS encryption is not enabled. All traffic including the login credentials are not encrypted."
);
Ok(self)
}
/// Consumes the connection and closes the framed transport.
pub(crate) async fn close(mut self) -> crate::Result<()> {
self.transport.close().await
}
}
// Packet-level view of the connection: yields packets straight from the framed
// transport and records on `flushed` whether the latest packet was the last
// one of its message.
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Stream for Connection<S> {
    type Item = crate::Result<Packet>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        let conn = self.get_mut();
        let polled = ready!(conn.transport.try_poll_next_unpin(cx));

        // Errors and end-of-stream leave `flushed` as it was.
        if let Some(Ok(packet)) = &polled {
            conn.flushed = packet.is_last();
        }

        Poll::Ready(polled)
    }
}
// Byte-level view of the connection: packet payloads are concatenated into
// `self.buf` and handed out in exactly the amounts requested.
impl<S: AsyncRead + AsyncWrite + Unpin + Send> futures::AsyncRead for Connection<S> {
    /// Fills `buf` completely or fails; a short read is never returned.
    ///
    /// Packets are pulled from the connection's packet stream until at least
    /// `buf.len()` payload bytes are buffered. Bytes gathered before a
    /// `Poll::Pending` stay in the internal buffer for the next call.
    ///
    /// # Errors
    ///
    /// * `BrokenPipe` if reading a packet fails (the message is the
    ///   stringified underlying error).
    /// * `UnexpectedEof` if the packet stream ends before enough bytes are
    ///   available.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let size = buf.len();

        if this.buf.len() < size {
            while let Some(item) = ready!(Pin::new(&mut *this).try_poll_next(cx)) {
                match item {
                    Ok(packet) => {
                        let (_, payload) = packet.into_parts();
                        // Bulk copy of the payload rather than appending it one
                        // byte at a time through `Extend<u8>`.
                        this.buf.extend_from_slice(&payload);

                        if this.buf.len() >= size {
                            break;
                        }
                    }
                    Err(e) => {
                        return Poll::Ready(Err(io::Error::new(
                            io::ErrorKind::BrokenPipe,
                            e.to_string(),
                        )))
                    }
                }
            }

            // The stream ended before the requested amount was available.
            if this.buf.len() < size {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "No more packets in the wire",
                )));
            }
        }

        buf.copy_from_slice(this.buf.split_to(size).as_ref());
        Poll::Ready(Ok(size))
    }
}
// Exposes the connection's context and a debugging aid to code that reads
// through the `SqlReadBytes` trait.
impl<S: AsyncRead + AsyncWrite + Unpin + Send> SqlReadBytes for Connection<S> {
/// Prints a hex dump of the not-yet-consumed read buffer via `dbg!`.
fn debug_buffer(&self) {
dbg!(self.buf.as_ref().hex_dump());
}
/// Shared access to the connection context.
fn context(&self) -> &Context {
&self.context
}
/// Exclusive access to the connection context.
fn context_mut(&mut self) -> &mut Context {
&mut self.context
}
}