tiberius/client/
tls.rs

1#[cfg(any(
2    feature = "rustls",
3    feature = "native-tls",
4    feature = "vendored-openssl"
5))]
6use super::tls_stream::TlsStream;
7use crate::tds::{
8    codec::{Decode, Encode, PacketHeader, PacketStatus, PacketType},
9    HEADER_BYTES,
10};
11use bytes::BytesMut;
12use futures_util::io::{AsyncRead, AsyncWrite};
13use futures_util::ready;
14use std::{
15    cmp, io,
16    pin::Pin,
17    task::{self, Poll},
18};
19use tracing::{event, Level};
20
/// A wrapper to handle either TLS or bare connections.
///
/// Both variants implement `AsyncRead` and `AsyncWrite` by forwarding every
/// call to the stream they hold, so the rest of the client can use the
/// connection without caring whether it is encrypted.
pub(crate) enum MaybeTlsStream<S: AsyncRead + AsyncWrite + Unpin + Send> {
    /// The bare stream, used as-is.
    Raw(S),
    /// A TLS session layered on top of a `TlsPreloginWrapper` around the
    /// stream. Only compiled in when one of the TLS features is enabled.
    #[cfg(any(
        feature = "rustls",
        feature = "native-tls",
        feature = "vendored-openssl"
    ))]
    Tls(TlsStream<TlsPreloginWrapper<S>>),
}
31
#[cfg(any(
    feature = "rustls",
    feature = "native-tls",
    feature = "vendored-openssl"
))]
impl<S: AsyncRead + AsyncWrite + Unpin + Send> MaybeTlsStream<S> {
    /// Consumes the wrapper and hands back the underlying stream.
    ///
    /// For a TLS connection the stream is moved out of the prelogin wrapper
    /// that sits below the TLS session.
    ///
    /// # Panics
    ///
    /// Panics if the prelogin wrapper no longer holds its stream.
    pub fn into_inner(self) -> S {
        match self {
            Self::Raw(stream) => stream,
            #[cfg(any(
                feature = "rustls",
                feature = "native-tls",
                feature = "vendored-openssl"
            ))]
            Self::Tls(mut session) => {
                let wrapper = session.get_mut();

                wrapper.stream.take().unwrap()
            }
        }
    }
}
50
51impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for MaybeTlsStream<S> {
52    fn poll_read(
53        self: Pin<&mut Self>,
54        cx: &mut task::Context<'_>,
55        buf: &mut [u8],
56    ) -> Poll<io::Result<usize>> {
57        match self.get_mut() {
58            MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
59            #[cfg(any(
60                feature = "rustls",
61                feature = "native-tls",
62                feature = "vendored-openssl"
63            ))]
64            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
65        }
66    }
67}
68
69impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for MaybeTlsStream<S> {
70    fn poll_write(
71        self: Pin<&mut Self>,
72        cx: &mut task::Context<'_>,
73        buf: &[u8],
74    ) -> Poll<io::Result<usize>> {
75        match self.get_mut() {
76            MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
77            #[cfg(any(
78                feature = "rustls",
79                feature = "native-tls",
80                feature = "vendored-openssl"
81            ))]
82            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
83        }
84    }
85
86    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
87        match self.get_mut() {
88            MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
89            #[cfg(any(
90                feature = "rustls",
91                feature = "native-tls",
92                feature = "vendored-openssl"
93            ))]
94            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
95        }
96    }
97
98    fn poll_close(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
99        match self.get_mut() {
100            MaybeTlsStream::Raw(s) => Pin::new(s).poll_close(cx),
101            #[cfg(any(
102                feature = "rustls",
103                feature = "native-tls",
104                feature = "vendored-openssl"
105            ))]
106            MaybeTlsStream::Tls(s) => Pin::new(s).poll_close(cx),
107        }
108    }
109}
110
/// During the TLS handshake the server expects to receive, and sends back,
/// ordinary TDS packets. To use a common TLS library we therefore need a
/// wrapper that handles the packet framing at this stage.
///
/// While the handshake is pending the wrapper strips the TDS header from
/// incoming data and prepends one to outgoing data. Once the handshake is
/// complete it simply passes all calls through to the underlying connection.
pub(crate) struct TlsPreloginWrapper<S> {
    // The wrapped connection; `None` once it has been taken out.
    stream: Option<S>,
    // `true` until `handshake_complete` is called; TDS framing is applied
    // only while this is set.
    pending_handshake: bool,

    // Bytes of the incoming packet header collected so far.
    header_buf: [u8; HEADER_BYTES],
    // Number of bytes of `header_buf` already filled.
    header_pos: usize,
    // Payload bytes of the current incoming packet that are still unread.
    read_remaining: usize,

    // Outgoing data buffered during the handshake. The first `HEADER_BYTES`
    // bytes are reserved for the packet header.
    wr_buf: Vec<u8>,
    // Whether the header has already been encoded into `wr_buf` for the
    // packet currently being flushed.
    header_written: bool,
}
128
#[cfg(any(
    feature = "rustls",
    feature = "native-tls",
    feature = "vendored-openssl"
))]
impl<S> TlsPreloginWrapper<S> {
    /// Wraps `stream` and starts out in handshake mode, with an empty read
    /// state and a write buffer that only holds room for a packet header.
    pub fn new(stream: S) -> Self {
        // The front of the write buffer is a placeholder for the TDS header.
        let wr_buf = vec![0u8; HEADER_BYTES];

        Self {
            stream: Some(stream),
            pending_handshake: true,
            header_buf: [0u8; HEADER_BYTES],
            header_pos: 0,
            read_remaining: 0,
            wr_buf,
            header_written: false,
        }
    }

    /// Leaves handshake mode; from now on reads and writes go straight to
    /// the wrapped stream.
    pub fn handshake_complete(&mut self) {
        self.pending_handshake = false;
    }
}
152
153impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<S> {
154    fn poll_read(
155        mut self: Pin<&mut Self>,
156        cx: &mut task::Context<'_>,
157        buf: &mut [u8],
158    ) -> Poll<io::Result<usize>> {
159        // Normal operation does not need any extra treatment, we handle packets
160        // in the codec.
161        if !self.pending_handshake {
162            return Pin::new(&mut self.stream.as_mut().unwrap()).poll_read(cx, buf);
163        }
164
165        let inner = self.get_mut();
166
167        // Read the headers separately and do not send them to the Tls
168        // connection handling.
169        if !inner.header_buf[inner.header_pos..].is_empty() {
170            while !inner.header_buf[inner.header_pos..].is_empty() {
171                let read = ready!(Pin::new(inner.stream.as_mut().unwrap())
172                    .poll_read(cx, &mut inner.header_buf[inner.header_pos..]))?;
173
174                if read == 0 {
175                    return Poll::Ready(Ok(0));
176                }
177
178                inner.header_pos += read;
179            }
180
181            let header = PacketHeader::decode(&mut BytesMut::from(&inner.header_buf[..]))
182                .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
183
184            // We only get pre-login packets in the handshake process.
185            assert_eq!(header.r#type(), PacketType::PreLogin);
186
187            // And we know from this point on how much data we should expect
188            inner.read_remaining = header.length() as usize - HEADER_BYTES;
189
190            event!(
191                Level::TRACE,
192                "Reading packet of {} bytes",
193                inner.read_remaining,
194            );
195        }
196
197        let max_read = cmp::min(inner.read_remaining, buf.len());
198
199        // TLS connector gets whatever we have after the header.
200        let read = ready!(
201            Pin::new(&mut inner.stream.as_mut().unwrap()).poll_read(cx, &mut buf[..max_read])
202        )?;
203
204        inner.read_remaining -= read;
205
206        // All data is read, after this we're expecting a new header.
207        if inner.read_remaining == 0 {
208            inner.header_pos = 0;
209        }
210
211        Poll::Ready(Ok(read))
212    }
213}
214
215impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsPreloginWrapper<S> {
216    fn poll_write(
217        mut self: Pin<&mut Self>,
218        cx: &mut task::Context<'_>,
219        buf: &[u8],
220    ) -> Poll<io::Result<usize>> {
221        // Normal operation does not need any extra treatment, we handle
222        // packets in the codec.
223        if !self.pending_handshake {
224            return Pin::new(&mut self.stream.as_mut().unwrap()).poll_write(cx, buf);
225        }
226
227        // Buffering data.
228        self.wr_buf.extend_from_slice(buf);
229
230        Poll::Ready(Ok(buf.len()))
231    }
232
233    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
234        let inner = self.get_mut();
235
236        // If on handshake mode, wraps the data to a TDS packet before sending.
237        if inner.pending_handshake && inner.wr_buf.len() > HEADER_BYTES {
238            if !inner.header_written {
239                let mut header = PacketHeader::new(inner.wr_buf.len(), 0);
240
241                header.set_type(PacketType::PreLogin);
242                header.set_status(PacketStatus::EndOfMessage);
243
244                header
245                    .encode(&mut &mut inner.wr_buf[0..HEADER_BYTES])
246                    .map_err(|_| {
247                        io::Error::new(io::ErrorKind::InvalidInput, "Could not encode header.")
248                    })?;
249
250                inner.header_written = true;
251            }
252
253            while !inner.wr_buf.is_empty() {
254                event!(
255                    Level::TRACE,
256                    "Writing a packet of {} bytes",
257                    inner.wr_buf.len(),
258                );
259
260                let written = ready!(
261                    Pin::new(&mut inner.stream.as_mut().unwrap()).poll_write(cx, &inner.wr_buf)
262                )?;
263
264                inner.wr_buf.drain(..written);
265            }
266
267            inner.wr_buf.resize(HEADER_BYTES, 0);
268            inner.header_written = false;
269        }
270
271        Pin::new(&mut inner.stream.as_mut().unwrap()).poll_flush(cx)
272    }
273
274    fn poll_close(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
275        Pin::new(&mut self.stream.as_mut().unwrap()).poll_close(cx)
276    }
277}