tiberius/client/
connection.rs

1#[cfg(any(
2    feature = "rustls",
3    feature = "native-tls",
4    feature = "vendored-openssl"
5))]
6use crate::client::{tls::TlsPreloginWrapper, tls_stream::create_tls_stream};
7use crate::{
8    client::{tls::MaybeTlsStream, AuthMethod, Config},
9    tds::{
10        codec::{
11            self, Encode, LoginMessage, Packet, PacketCodec, PacketHeader, PacketStatus,
12            PreloginMessage, TokenDone,
13        },
14        stream::TokenStream,
15        Context, HEADER_BYTES,
16    },
17    EncryptionLevel, SqlReadBytes,
18};
19use asynchronous_codec::Framed;
20use bytes::BytesMut;
21#[cfg(any(windows, feature = "integrated-auth-gssapi"))]
22use codec::TokenSspi;
23use futures_util::io::{AsyncRead, AsyncWrite};
24use futures_util::ready;
25use futures_util::sink::SinkExt;
26use futures_util::stream::{Stream, TryStream, TryStreamExt};
27#[cfg(all(unix, feature = "integrated-auth-gssapi"))]
28use libgssapi::{
29    context::{ClientCtx, CtxFlags},
30    credential::{Cred, CredUsage},
31    name::Name,
32    oid::{OidSet, GSS_MECH_KRB5, GSS_NT_KRB5_PRINCIPAL},
33};
34use pretty_hex::*;
35#[cfg(all(unix, feature = "integrated-auth-gssapi"))]
36use std::ops::Deref;
37use std::{cmp, fmt::Debug, io, pin::Pin, task};
38use task::Poll;
39use tracing::{event, Level};
40#[cfg(all(windows, feature = "winauth"))]
41use winauth::{windows::NtlmSspiBuilder, NextBytes};
42
43/// A `Connection` is an abstraction between the [`Client`] and the server. It
44/// can be used as a `Stream` to fetch [`Packet`]s from and to `send` packets
45/// splitting them to the negotiated limit automatically.
46///
47/// `Connection` is not meant to use directly, but as an abstraction layer for
48/// the numerous `Stream`s for easy packet handling.
49///
50/// [`Client`]: struct.Encode.html
51/// [`Packet`]: ../protocol/codec/struct.Packet.html
52pub(crate) struct Connection<S>
53where
54    S: AsyncRead + AsyncWrite + Unpin + Send,
55{
56    transport: Framed<MaybeTlsStream<S>, PacketCodec>,
57    flushed: bool,
58    context: Context,
59    buf: BytesMut,
60}
61
62impl<S: AsyncRead + AsyncWrite + Unpin + Send> Debug for Connection<S> {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        f.debug_struct("Connection")
65            .field("transport", &"Framed<..>")
66            .field("flushed", &self.flushed)
67            .field("context", &self.context)
68            .field("buf", &self.buf.as_ref().hex_dump())
69            .finish()
70    }
71}
72
73impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connection<S> {
74    /// Creates a new connection
75    pub(crate) async fn connect(config: Config, tcp_stream: S) -> crate::Result<Connection<S>> {
76        let context = {
77            let mut context = Context::new();
78            context.set_spn(config.get_host(), config.get_port());
79            context
80        };
81
82        let transport = Framed::new(MaybeTlsStream::Raw(tcp_stream), PacketCodec);
83
84        let mut connection = Self {
85            transport,
86            context,
87            flushed: false,
88            buf: BytesMut::new(),
89        };
90
91        let fed_auth_required = matches!(config.auth, AuthMethod::AADToken(_));
92
93        let prelogin = connection
94            .prelogin(config.encryption, fed_auth_required)
95            .await?;
96
97        let encryption = prelogin.negotiated_encryption(config.encryption);
98
99        let connection = connection.tls_handshake(&config, encryption).await?;
100
101        let mut connection = connection
102            .login(
103                config.auth,
104                encryption,
105                config.database,
106                config.host,
107                config.application_name,
108                config.readonly,
109                prelogin,
110            )
111            .await?;
112
113        connection.flush_done().await?;
114
115        Ok(connection)
116    }
117
118    /// Flush the incoming token stream until receiving `DONE` token.
119    async fn flush_done(&mut self) -> crate::Result<TokenDone> {
120        TokenStream::new(self).flush_done().await
121    }
122
123    #[cfg(any(windows, feature = "integrated-auth-gssapi"))]
124    /// Flush the incoming token stream until receiving `SSPI` token.
125    async fn flush_sspi(&mut self) -> crate::Result<TokenSspi> {
126        TokenStream::new(self).flush_sspi().await
127    }
128
129    #[cfg(any(
130        feature = "rustls",
131        feature = "native-tls",
132        feature = "vendored-openssl"
133    ))]
134    fn post_login_encryption(mut self, encryption: EncryptionLevel) -> Self {
135        if let EncryptionLevel::Off = encryption {
136            event!(
137                Level::WARN,
138                "Turning TLS off after a login. All traffic from here on is not encrypted.",
139            );
140
141            let Self { transport, .. } = self;
142            let tcp = transport.into_inner().into_inner();
143            self.transport = Framed::new(MaybeTlsStream::Raw(tcp), PacketCodec);
144        }
145
146        self
147    }
148
149    #[cfg(not(any(
150        feature = "rustls",
151        feature = "native-tls",
152        feature = "vendored-openssl"
153    )))]
154    fn post_login_encryption(self, _: EncryptionLevel) -> Self {
155        self
156    }
157
158    /// Send an item to the wire. Header should define the item type and item should implement
159    /// [`Encode`], defining the byte structure for the wire.
160    ///
161    /// The `send` will split the packet into multiple packets if bigger than
162    /// the negotiated packet size, and handle flushing to the wire in an optimal way.
163    ///
164    /// [`Encode`]: ../protocol/codec/trait.Encode.html
165    pub async fn send<E>(&mut self, mut header: PacketHeader, item: E) -> crate::Result<()>
166    where
167        E: Sized + Encode<BytesMut>,
168    {
169        self.flushed = false;
170        let packet_size = (self.context.packet_size() as usize) - HEADER_BYTES;
171
172        let mut payload = BytesMut::new();
173        item.encode(&mut payload)?;
174
175        while !payload.is_empty() {
176            let writable = cmp::min(payload.len(), packet_size);
177            let split_payload = payload.split_to(writable);
178
179            if payload.is_empty() {
180                header.set_status(PacketStatus::EndOfMessage);
181            } else {
182                header.set_status(PacketStatus::NormalMessage);
183            }
184
185            event!(
186                Level::TRACE,
187                "Sending a packet ({} bytes)",
188                split_payload.len() + HEADER_BYTES,
189            );
190
191            self.write_to_wire(header, split_payload).await?;
192        }
193
194        self.flush_sink().await?;
195
196        Ok(())
197    }
198
199    /// Sends a packet of data to the database.
200    ///
201    /// # Warning
202    ///
203    /// Please be sure the packet size doesn't exceed the largest allowed size
204    /// dictaded by the server.
205    pub(crate) async fn write_to_wire(
206        &mut self,
207        header: PacketHeader,
208        data: BytesMut,
209    ) -> crate::Result<()> {
210        self.flushed = false;
211
212        let packet = Packet::new(header, data);
213        self.transport.send(packet).await?;
214
215        Ok(())
216    }
217
218    /// Sends all pending packages to the wire.
219    pub(crate) async fn flush_sink(&mut self) -> crate::Result<()> {
220        self.transport.flush().await
221    }
222
223    /// Cleans the packet stream from previous use. It is important to use the
224    /// whole stream before using the connection again. Flushing the stream
225    /// makes sure we don't have any old data causing undefined behaviour after
226    /// previous queries.
227    ///
228    /// Calling this will slow down the queries if stream is still dirty if all
229    /// results are not handled.
230    pub async fn flush_stream(&mut self) -> crate::Result<()> {
231        self.buf.truncate(0);
232
233        if self.flushed {
234            return Ok(());
235        }
236
237        while let Some(packet) = self.try_next().await? {
238            event!(
239                Level::WARN,
240                "Flushing unhandled packet from the wire. Please consume your streams!",
241            );
242
243            let is_last = packet.is_last();
244
245            if is_last {
246                break;
247            }
248        }
249
250        Ok(())
251    }
252
253    /// True if the underlying stream has no more data and is consumed
254    /// completely.
255    pub fn is_eof(&self) -> bool {
256        self.flushed && self.buf.is_empty()
257    }
258
259    /// A message sent by the client to set up context for login. The server
260    /// responds to a client PRELOGIN message with a message of packet header
261    /// type 0x04 and with the packet data containing a PRELOGIN structure.
262    ///
263    /// This message stream is also used to wrap the TLS handshake payload if
264    /// encryption is needed. In this scenario, where PRELOGIN message is
265    /// transporting the TLS handshake payload, the packet data is simply the
266    /// raw bytes of the TLS handshake payload.
267    async fn prelogin(
268        &mut self,
269        encryption: EncryptionLevel,
270        fed_auth_required: bool,
271    ) -> crate::Result<PreloginMessage> {
272        let mut msg = PreloginMessage::new();
273        msg.encryption = encryption;
274        msg.fed_auth_required = fed_auth_required;
275
276        let id = self.context.next_packet_id();
277        self.send(PacketHeader::pre_login(id), msg).await?;
278
279        let response: PreloginMessage = codec::collect_from(self).await?;
280        // threadid (should be empty when sent from server to client)
281        debug_assert_eq!(response.thread_id, 0);
282        Ok(response)
283    }
284
285    /// Defines the login record rules with SQL Server. Authentication with
286    /// connection options.
287    #[allow(clippy::too_many_arguments)]
288    async fn login<'a>(
289        mut self,
290        auth: AuthMethod,
291        encryption: EncryptionLevel,
292        db: Option<String>,
293        server_name: Option<String>,
294        application_name: Option<String>,
295        readonly: bool,
296        prelogin: PreloginMessage,
297    ) -> crate::Result<Self> {
298        let mut login_message = LoginMessage::new();
299
300        if let Some(db) = db {
301            login_message.db_name(db);
302        }
303
304        if let Some(server_name) = server_name {
305            login_message.server_name(server_name);
306        }
307
308        if let Some(app_name) = application_name {
309            login_message.app_name(app_name);
310        }
311
312        login_message.readonly(readonly);
313
314        match auth {
315            #[cfg(all(windows, feature = "winauth"))]
316            AuthMethod::Integrated => {
317                let mut client = NtlmSspiBuilder::new()
318                    .target_spn(self.context.spn())
319                    .build()?;
320
321                login_message.integrated_security(client.next_bytes(None)?);
322
323                let id = self.context.next_packet_id();
324                self.send(PacketHeader::login(id), login_message).await?;
325
326                self = self.post_login_encryption(encryption);
327
328                let sspi_bytes = self.flush_sspi().await?;
329
330                match client.next_bytes(Some(sspi_bytes.as_ref()))? {
331                    Some(sspi_response) => {
332                        event!(Level::TRACE, sspi_response_len = sspi_response.len());
333
334                        let id = self.context.next_packet_id();
335                        let header = PacketHeader::login(id);
336
337                        let token = TokenSspi::new(sspi_response);
338                        self.send(header, token).await?;
339                    }
340                    None => unreachable!(),
341                }
342            }
343            #[cfg(all(unix, feature = "integrated-auth-gssapi"))]
344            AuthMethod::Integrated => {
345                let mut s = OidSet::new()?;
346                s.add(&GSS_MECH_KRB5)?;
347
348                let client_cred = Cred::acquire(None, None, CredUsage::Initiate, Some(&s))?;
349
350                let ctx = ClientCtx::new(
351                    client_cred,
352                    Name::new(self.context.spn().as_bytes(), Some(&GSS_NT_KRB5_PRINCIPAL))?,
353                    CtxFlags::GSS_C_MUTUAL_FLAG | CtxFlags::GSS_C_SEQUENCE_FLAG,
354                    None,
355                );
356
357                let init_token = ctx.step(None)?;
358
359                login_message.integrated_security(Some(Vec::from(init_token.unwrap().deref())));
360
361                let id = self.context.next_packet_id();
362                self.send(PacketHeader::login(id), login_message).await?;
363
364                self = self.post_login_encryption(encryption);
365
366                let auth_bytes = self.flush_sspi().await?;
367
368                let next_token = match ctx.step(Some(auth_bytes.as_ref()))? {
369                    Some(response) => {
370                        event!(Level::TRACE, response_len = response.len());
371                        TokenSspi::new(Vec::from(response.deref()))
372                    }
373                    None => {
374                        event!(Level::TRACE, response_len = 0);
375                        TokenSspi::new(Vec::new())
376                    }
377                };
378
379                let id = self.context.next_packet_id();
380                let header = PacketHeader::login(id);
381
382                self.send(header, next_token).await?;
383            }
384            #[cfg(all(windows, feature = "winauth"))]
385            AuthMethod::Windows(auth) => {
386                let spn = self.context.spn().to_string();
387                let builder = winauth::NtlmV2ClientBuilder::new().target_spn(spn);
388                let mut client = builder.build(auth.domain, auth.user, auth.password);
389
390                login_message.integrated_security(client.next_bytes(None)?);
391
392                let id = self.context.next_packet_id();
393                self.send(PacketHeader::login(id), login_message).await?;
394
395                self = self.post_login_encryption(encryption);
396
397                let sspi_bytes = self.flush_sspi().await?;
398
399                match client.next_bytes(Some(sspi_bytes.as_ref()))? {
400                    Some(sspi_response) => {
401                        event!(Level::TRACE, sspi_response_len = sspi_response.len());
402
403                        let id = self.context.next_packet_id();
404                        let header = PacketHeader::login(id);
405
406                        let token = TokenSspi::new(sspi_response);
407                        self.send(header, token).await?;
408                    }
409                    None => unreachable!(),
410                }
411            }
412            AuthMethod::None => {
413                let id = self.context.next_packet_id();
414                self.send(PacketHeader::login(id), login_message).await?;
415                self = self.post_login_encryption(encryption);
416            }
417            AuthMethod::SqlServer(auth) => {
418                login_message.user_name(auth.user());
419                login_message.password(auth.password());
420
421                let id = self.context.next_packet_id();
422                self.send(PacketHeader::login(id), login_message).await?;
423                self = self.post_login_encryption(encryption);
424            }
425            AuthMethod::AADToken(token) => {
426                login_message.aad_token(token, prelogin.fed_auth_required, prelogin.nonce);
427                let id = self.context.next_packet_id();
428                self.send(PacketHeader::login(id), login_message).await?;
429                self = self.post_login_encryption(encryption);
430            }
431        }
432
433        Ok(self)
434    }
435
436    /// Implements the TLS handshake with the SQL Server.
437    #[cfg(any(
438        feature = "rustls",
439        feature = "native-tls",
440        feature = "vendored-openssl"
441    ))]
442    async fn tls_handshake(
443        self,
444        config: &Config,
445        encryption: EncryptionLevel,
446    ) -> crate::Result<Self> {
447        if encryption != EncryptionLevel::NotSupported {
448            event!(Level::INFO, "Performing a TLS handshake");
449
450            let Self {
451                transport, context, ..
452            } = self;
453            let mut stream = match transport.into_inner() {
454                MaybeTlsStream::Raw(tcp) => {
455                    create_tls_stream(config, TlsPreloginWrapper::new(tcp)).await?
456                }
457                _ => unreachable!(),
458            };
459
460            stream.get_mut().handshake_complete();
461            event!(Level::INFO, "TLS handshake successful");
462
463            let transport = Framed::new(MaybeTlsStream::Tls(stream), PacketCodec);
464
465            Ok(Self {
466                transport,
467                context,
468                flushed: false,
469                buf: BytesMut::new(),
470            })
471        } else {
472            event!(
473                Level::WARN,
474                "TLS encryption is not enabled. All traffic including the login credentials are not encrypted."
475            );
476
477            Ok(self)
478        }
479    }
480
481    /// Implements the TLS handshake with the SQL Server.
482    #[cfg(not(any(
483        feature = "rustls",
484        feature = "native-tls",
485        feature = "vendored-openssl"
486    )))]
487    async fn tls_handshake(self, _: &Config, _: EncryptionLevel) -> crate::Result<Self> {
488        event!(
489            Level::WARN,
490            "TLS encryption is not enabled. All traffic including the login credentials are not encrypted."
491        );
492
493        Ok(self)
494    }
495
496    pub(crate) async fn close(mut self) -> crate::Result<()> {
497        self.transport.close().await
498    }
499}
500
501impl<S: AsyncRead + AsyncWrite + Unpin + Send> Stream for Connection<S> {
502    type Item = crate::Result<Packet>;
503
504    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
505        let this = self.get_mut();
506
507        match ready!(this.transport.try_poll_next_unpin(cx)) {
508            Some(Ok(packet)) => {
509                this.flushed = packet.is_last();
510                Poll::Ready(Some(Ok(packet)))
511            }
512            Some(Err(e)) => Poll::Ready(Some(Err(e))),
513            None => Poll::Ready(None),
514        }
515    }
516}
517
518impl<S: AsyncRead + AsyncWrite + Unpin + Send> futures_util::io::AsyncRead for Connection<S> {
519    fn poll_read(
520        self: Pin<&mut Self>,
521        cx: &mut task::Context<'_>,
522        buf: &mut [u8],
523    ) -> Poll<io::Result<usize>> {
524        let mut this = self.get_mut();
525        let size = buf.len();
526
527        if this.buf.len() < size {
528            while let Some(item) = ready!(Pin::new(&mut this).try_poll_next(cx)) {
529                match item {
530                    Ok(packet) => {
531                        let (_, payload) = packet.into_parts();
532                        this.buf.extend(payload);
533
534                        if this.buf.len() >= size {
535                            break;
536                        }
537                    }
538                    Err(e) => {
539                        return Poll::Ready(Err(io::Error::new(
540                            io::ErrorKind::BrokenPipe,
541                            e.to_string(),
542                        )))
543                    }
544                }
545            }
546
547            // Got EOF before having all the data.
548            if this.buf.len() < size {
549                return Poll::Ready(Err(io::Error::new(
550                    io::ErrorKind::UnexpectedEof,
551                    "No more packets in the wire",
552                )));
553            }
554        }
555
556        buf.copy_from_slice(this.buf.split_to(size).as_ref());
557        Poll::Ready(Ok(size))
558    }
559}
560
561impl<S: AsyncRead + AsyncWrite + Unpin + Send> SqlReadBytes for Connection<S> {
562    /// Hex dump of the current buffer.
563    fn debug_buffer(&self) {
564        dbg!(self.buf.as_ref().hex_dump());
565    }
566
567    /// The current execution context.
568    fn context(&self) -> &Context {
569        &self.context
570    }
571
572    /// A mutable reference to the current execution context.
573    fn context_mut(&mut self) -> &mut Context {
574        &mut self.context
575    }
576}