1#[cfg(any(
2 feature = "rustls",
3 feature = "native-tls",
4 feature = "vendored-openssl"
5))]
6use crate::client::{tls::TlsPreloginWrapper, tls_stream::create_tls_stream};
7use crate::{
8 client::{tls::MaybeTlsStream, AuthMethod, Config},
9 tds::{
10 codec::{
11 self, Encode, LoginMessage, Packet, PacketCodec, PacketHeader, PacketStatus,
12 PreloginMessage, TokenDone,
13 },
14 stream::TokenStream,
15 Context, HEADER_BYTES,
16 },
17 EncryptionLevel, SqlReadBytes,
18};
19use asynchronous_codec::Framed;
20use bytes::BytesMut;
21#[cfg(any(windows, feature = "integrated-auth-gssapi"))]
22use codec::TokenSspi;
23use futures_util::io::{AsyncRead, AsyncWrite};
24use futures_util::ready;
25use futures_util::sink::SinkExt;
26use futures_util::stream::{Stream, TryStream, TryStreamExt};
27#[cfg(all(unix, feature = "integrated-auth-gssapi"))]
28use libgssapi::{
29 context::{ClientCtx, CtxFlags},
30 credential::{Cred, CredUsage},
31 name::Name,
32 oid::{OidSet, GSS_MECH_KRB5, GSS_NT_KRB5_PRINCIPAL},
33};
34use pretty_hex::*;
35#[cfg(all(unix, feature = "integrated-auth-gssapi"))]
36use std::ops::Deref;
37use std::{cmp, fmt::Debug, io, pin::Pin, task};
38use task::Poll;
39use tracing::{event, Level};
40#[cfg(all(windows, feature = "winauth"))]
41use winauth::{windows::NtlmSspiBuilder, NextBytes};
42
43pub(crate) struct Connection<S>
53where
54 S: AsyncRead + AsyncWrite + Unpin + Send,
55{
56 transport: Framed<MaybeTlsStream<S>, PacketCodec>,
57 flushed: bool,
58 context: Context,
59 buf: BytesMut,
60}
61
62impl<S: AsyncRead + AsyncWrite + Unpin + Send> Debug for Connection<S> {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 f.debug_struct("Connection")
65 .field("transport", &"Framed<..>")
66 .field("flushed", &self.flushed)
67 .field("context", &self.context)
68 .field("buf", &self.buf.as_ref().hex_dump())
69 .finish()
70 }
71}
72
73impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connection<S> {
74 pub(crate) async fn connect(config: Config, tcp_stream: S) -> crate::Result<Connection<S>> {
76 let context = {
77 let mut context = Context::new();
78 context.set_spn(config.get_host(), config.get_port());
79 context
80 };
81
82 let transport = Framed::new(MaybeTlsStream::Raw(tcp_stream), PacketCodec);
83
84 let mut connection = Self {
85 transport,
86 context,
87 flushed: false,
88 buf: BytesMut::new(),
89 };
90
91 let fed_auth_required = matches!(config.auth, AuthMethod::AADToken(_));
92
93 let prelogin = connection
94 .prelogin(config.encryption, fed_auth_required)
95 .await?;
96
97 let encryption = prelogin.negotiated_encryption(config.encryption);
98
99 let connection = connection.tls_handshake(&config, encryption).await?;
100
101 let mut connection = connection
102 .login(
103 config.auth,
104 encryption,
105 config.database,
106 config.host,
107 config.application_name,
108 config.readonly,
109 prelogin,
110 )
111 .await?;
112
113 connection.flush_done().await?;
114
115 Ok(connection)
116 }
117
118 async fn flush_done(&mut self) -> crate::Result<TokenDone> {
120 TokenStream::new(self).flush_done().await
121 }
122
123 #[cfg(any(windows, feature = "integrated-auth-gssapi"))]
124 async fn flush_sspi(&mut self) -> crate::Result<TokenSspi> {
126 TokenStream::new(self).flush_sspi().await
127 }
128
129 #[cfg(any(
130 feature = "rustls",
131 feature = "native-tls",
132 feature = "vendored-openssl"
133 ))]
134 fn post_login_encryption(mut self, encryption: EncryptionLevel) -> Self {
135 if let EncryptionLevel::Off = encryption {
136 event!(
137 Level::WARN,
138 "Turning TLS off after a login. All traffic from here on is not encrypted.",
139 );
140
141 let Self { transport, .. } = self;
142 let tcp = transport.into_inner().into_inner();
143 self.transport = Framed::new(MaybeTlsStream::Raw(tcp), PacketCodec);
144 }
145
146 self
147 }
148
149 #[cfg(not(any(
150 feature = "rustls",
151 feature = "native-tls",
152 feature = "vendored-openssl"
153 )))]
154 fn post_login_encryption(self, _: EncryptionLevel) -> Self {
155 self
156 }
157
158 pub async fn send<E>(&mut self, mut header: PacketHeader, item: E) -> crate::Result<()>
166 where
167 E: Sized + Encode<BytesMut>,
168 {
169 self.flushed = false;
170 let packet_size = (self.context.packet_size() as usize) - HEADER_BYTES;
171
172 let mut payload = BytesMut::new();
173 item.encode(&mut payload)?;
174
175 while !payload.is_empty() {
176 let writable = cmp::min(payload.len(), packet_size);
177 let split_payload = payload.split_to(writable);
178
179 if payload.is_empty() {
180 header.set_status(PacketStatus::EndOfMessage);
181 } else {
182 header.set_status(PacketStatus::NormalMessage);
183 }
184
185 event!(
186 Level::TRACE,
187 "Sending a packet ({} bytes)",
188 split_payload.len() + HEADER_BYTES,
189 );
190
191 self.write_to_wire(header, split_payload).await?;
192 }
193
194 self.flush_sink().await?;
195
196 Ok(())
197 }
198
199 pub(crate) async fn write_to_wire(
206 &mut self,
207 header: PacketHeader,
208 data: BytesMut,
209 ) -> crate::Result<()> {
210 self.flushed = false;
211
212 let packet = Packet::new(header, data);
213 self.transport.send(packet).await?;
214
215 Ok(())
216 }
217
218 pub(crate) async fn flush_sink(&mut self) -> crate::Result<()> {
220 self.transport.flush().await
221 }
222
223 pub async fn flush_stream(&mut self) -> crate::Result<()> {
231 self.buf.truncate(0);
232
233 if self.flushed {
234 return Ok(());
235 }
236
237 while let Some(packet) = self.try_next().await? {
238 event!(
239 Level::WARN,
240 "Flushing unhandled packet from the wire. Please consume your streams!",
241 );
242
243 let is_last = packet.is_last();
244
245 if is_last {
246 break;
247 }
248 }
249
250 Ok(())
251 }
252
253 pub fn is_eof(&self) -> bool {
256 self.flushed && self.buf.is_empty()
257 }
258
259 async fn prelogin(
268 &mut self,
269 encryption: EncryptionLevel,
270 fed_auth_required: bool,
271 ) -> crate::Result<PreloginMessage> {
272 let mut msg = PreloginMessage::new();
273 msg.encryption = encryption;
274 msg.fed_auth_required = fed_auth_required;
275
276 let id = self.context.next_packet_id();
277 self.send(PacketHeader::pre_login(id), msg).await?;
278
279 let response: PreloginMessage = codec::collect_from(self).await?;
280 debug_assert_eq!(response.thread_id, 0);
282 Ok(response)
283 }
284
285 #[allow(clippy::too_many_arguments)]
288 async fn login<'a>(
289 mut self,
290 auth: AuthMethod,
291 encryption: EncryptionLevel,
292 db: Option<String>,
293 server_name: Option<String>,
294 application_name: Option<String>,
295 readonly: bool,
296 prelogin: PreloginMessage,
297 ) -> crate::Result<Self> {
298 let mut login_message = LoginMessage::new();
299
300 if let Some(db) = db {
301 login_message.db_name(db);
302 }
303
304 if let Some(server_name) = server_name {
305 login_message.server_name(server_name);
306 }
307
308 if let Some(app_name) = application_name {
309 login_message.app_name(app_name);
310 }
311
312 login_message.readonly(readonly);
313
314 match auth {
315 #[cfg(all(windows, feature = "winauth"))]
316 AuthMethod::Integrated => {
317 let mut client = NtlmSspiBuilder::new()
318 .target_spn(self.context.spn())
319 .build()?;
320
321 login_message.integrated_security(client.next_bytes(None)?);
322
323 let id = self.context.next_packet_id();
324 self.send(PacketHeader::login(id), login_message).await?;
325
326 self = self.post_login_encryption(encryption);
327
328 let sspi_bytes = self.flush_sspi().await?;
329
330 match client.next_bytes(Some(sspi_bytes.as_ref()))? {
331 Some(sspi_response) => {
332 event!(Level::TRACE, sspi_response_len = sspi_response.len());
333
334 let id = self.context.next_packet_id();
335 let header = PacketHeader::login(id);
336
337 let token = TokenSspi::new(sspi_response);
338 self.send(header, token).await?;
339 }
340 None => unreachable!(),
341 }
342 }
343 #[cfg(all(unix, feature = "integrated-auth-gssapi"))]
344 AuthMethod::Integrated => {
345 let mut s = OidSet::new()?;
346 s.add(&GSS_MECH_KRB5)?;
347
348 let client_cred = Cred::acquire(None, None, CredUsage::Initiate, Some(&s))?;
349
350 let ctx = ClientCtx::new(
351 client_cred,
352 Name::new(self.context.spn().as_bytes(), Some(&GSS_NT_KRB5_PRINCIPAL))?,
353 CtxFlags::GSS_C_MUTUAL_FLAG | CtxFlags::GSS_C_SEQUENCE_FLAG,
354 None,
355 );
356
357 let init_token = ctx.step(None)?;
358
359 login_message.integrated_security(Some(Vec::from(init_token.unwrap().deref())));
360
361 let id = self.context.next_packet_id();
362 self.send(PacketHeader::login(id), login_message).await?;
363
364 self = self.post_login_encryption(encryption);
365
366 let auth_bytes = self.flush_sspi().await?;
367
368 let next_token = match ctx.step(Some(auth_bytes.as_ref()))? {
369 Some(response) => {
370 event!(Level::TRACE, response_len = response.len());
371 TokenSspi::new(Vec::from(response.deref()))
372 }
373 None => {
374 event!(Level::TRACE, response_len = 0);
375 TokenSspi::new(Vec::new())
376 }
377 };
378
379 let id = self.context.next_packet_id();
380 let header = PacketHeader::login(id);
381
382 self.send(header, next_token).await?;
383 }
384 #[cfg(all(windows, feature = "winauth"))]
385 AuthMethod::Windows(auth) => {
386 let spn = self.context.spn().to_string();
387 let builder = winauth::NtlmV2ClientBuilder::new().target_spn(spn);
388 let mut client = builder.build(auth.domain, auth.user, auth.password);
389
390 login_message.integrated_security(client.next_bytes(None)?);
391
392 let id = self.context.next_packet_id();
393 self.send(PacketHeader::login(id), login_message).await?;
394
395 self = self.post_login_encryption(encryption);
396
397 let sspi_bytes = self.flush_sspi().await?;
398
399 match client.next_bytes(Some(sspi_bytes.as_ref()))? {
400 Some(sspi_response) => {
401 event!(Level::TRACE, sspi_response_len = sspi_response.len());
402
403 let id = self.context.next_packet_id();
404 let header = PacketHeader::login(id);
405
406 let token = TokenSspi::new(sspi_response);
407 self.send(header, token).await?;
408 }
409 None => unreachable!(),
410 }
411 }
412 AuthMethod::None => {
413 let id = self.context.next_packet_id();
414 self.send(PacketHeader::login(id), login_message).await?;
415 self = self.post_login_encryption(encryption);
416 }
417 AuthMethod::SqlServer(auth) => {
418 login_message.user_name(auth.user());
419 login_message.password(auth.password());
420
421 let id = self.context.next_packet_id();
422 self.send(PacketHeader::login(id), login_message).await?;
423 self = self.post_login_encryption(encryption);
424 }
425 AuthMethod::AADToken(token) => {
426 login_message.aad_token(token, prelogin.fed_auth_required, prelogin.nonce);
427 let id = self.context.next_packet_id();
428 self.send(PacketHeader::login(id), login_message).await?;
429 self = self.post_login_encryption(encryption);
430 }
431 }
432
433 Ok(self)
434 }
435
436 #[cfg(any(
438 feature = "rustls",
439 feature = "native-tls",
440 feature = "vendored-openssl"
441 ))]
442 async fn tls_handshake(
443 self,
444 config: &Config,
445 encryption: EncryptionLevel,
446 ) -> crate::Result<Self> {
447 if encryption != EncryptionLevel::NotSupported {
448 event!(Level::INFO, "Performing a TLS handshake");
449
450 let Self {
451 transport, context, ..
452 } = self;
453 let mut stream = match transport.into_inner() {
454 MaybeTlsStream::Raw(tcp) => {
455 create_tls_stream(config, TlsPreloginWrapper::new(tcp)).await?
456 }
457 _ => unreachable!(),
458 };
459
460 stream.get_mut().handshake_complete();
461 event!(Level::INFO, "TLS handshake successful");
462
463 let transport = Framed::new(MaybeTlsStream::Tls(stream), PacketCodec);
464
465 Ok(Self {
466 transport,
467 context,
468 flushed: false,
469 buf: BytesMut::new(),
470 })
471 } else {
472 event!(
473 Level::WARN,
474 "TLS encryption is not enabled. All traffic including the login credentials are not encrypted."
475 );
476
477 Ok(self)
478 }
479 }
480
481 #[cfg(not(any(
483 feature = "rustls",
484 feature = "native-tls",
485 feature = "vendored-openssl"
486 )))]
487 async fn tls_handshake(self, _: &Config, _: EncryptionLevel) -> crate::Result<Self> {
488 event!(
489 Level::WARN,
490 "TLS encryption is not enabled. All traffic including the login credentials are not encrypted."
491 );
492
493 Ok(self)
494 }
495
496 pub(crate) async fn close(mut self) -> crate::Result<()> {
497 self.transport.close().await
498 }
499}
500
501impl<S: AsyncRead + AsyncWrite + Unpin + Send> Stream for Connection<S> {
502 type Item = crate::Result<Packet>;
503
504 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
505 let this = self.get_mut();
506
507 match ready!(this.transport.try_poll_next_unpin(cx)) {
508 Some(Ok(packet)) => {
509 this.flushed = packet.is_last();
510 Poll::Ready(Some(Ok(packet)))
511 }
512 Some(Err(e)) => Poll::Ready(Some(Err(e))),
513 None => Poll::Ready(None),
514 }
515 }
516}
517
518impl<S: AsyncRead + AsyncWrite + Unpin + Send> futures_util::io::AsyncRead for Connection<S> {
519 fn poll_read(
520 self: Pin<&mut Self>,
521 cx: &mut task::Context<'_>,
522 buf: &mut [u8],
523 ) -> Poll<io::Result<usize>> {
524 let mut this = self.get_mut();
525 let size = buf.len();
526
527 if this.buf.len() < size {
528 while let Some(item) = ready!(Pin::new(&mut this).try_poll_next(cx)) {
529 match item {
530 Ok(packet) => {
531 let (_, payload) = packet.into_parts();
532 this.buf.extend(payload);
533
534 if this.buf.len() >= size {
535 break;
536 }
537 }
538 Err(e) => {
539 return Poll::Ready(Err(io::Error::new(
540 io::ErrorKind::BrokenPipe,
541 e.to_string(),
542 )))
543 }
544 }
545 }
546
547 if this.buf.len() < size {
549 return Poll::Ready(Err(io::Error::new(
550 io::ErrorKind::UnexpectedEof,
551 "No more packets in the wire",
552 )));
553 }
554 }
555
556 buf.copy_from_slice(this.buf.split_to(size).as_ref());
557 Poll::Ready(Ok(size))
558 }
559}
560
561impl<S: AsyncRead + AsyncWrite + Unpin + Send> SqlReadBytes for Connection<S> {
562 fn debug_buffer(&self) {
564 dbg!(self.buf.as_ref().hex_dump());
565 }
566
567 fn context(&self) -> &Context {
569 &self.context
570 }
571
572 fn context_mut(&mut self) -> &mut Context {
574 &mut self.context
575 }
576}