1#[cfg(any(
2 feature = "rustls",
3 feature = "native-tls",
4 feature = "vendored-openssl"
5))]
6use super::tls_stream::TlsStream;
7use crate::tds::{
8 codec::{Decode, Encode, PacketHeader, PacketStatus, PacketType},
9 HEADER_BYTES,
10};
11use bytes::BytesMut;
12use futures_util::io::{AsyncRead, AsyncWrite};
13use futures_util::ready;
14use std::{
15 cmp, io,
16 pin::Pin,
17 task::{self, Poll},
18};
19use tracing::{event, Level};
20
21pub(crate) enum MaybeTlsStream<S: AsyncRead + AsyncWrite + Unpin + Send> {
23 Raw(S),
24 #[cfg(any(
25 feature = "rustls",
26 feature = "native-tls",
27 feature = "vendored-openssl"
28 ))]
29 Tls(TlsStream<TlsPreloginWrapper<S>>),
30}
31
#[cfg(any(
    feature = "rustls",
    feature = "native-tls",
    feature = "vendored-openssl"
))]
impl<S> MaybeTlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    /// Consumes the wrapper and returns the underlying transport.
    ///
    /// For the `Tls` variant the transport is taken out of the prelogin
    /// wrapper that sits beneath the TLS stream.
    ///
    /// # Panics
    ///
    /// Panics if that wrapper's transport has already been taken.
    pub fn into_inner(self) -> S {
        match self {
            Self::Raw(raw) => raw,
            #[cfg(any(
                feature = "rustls",
                feature = "native-tls",
                feature = "vendored-openssl"
            ))]
            Self::Tls(mut tls) => {
                let wrapper = tls.get_mut();
                wrapper.stream.take().unwrap()
            }
        }
    }
}
50
51impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for MaybeTlsStream<S> {
52 fn poll_read(
53 self: Pin<&mut Self>,
54 cx: &mut task::Context<'_>,
55 buf: &mut [u8],
56 ) -> Poll<io::Result<usize>> {
57 match self.get_mut() {
58 MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
59 #[cfg(any(
60 feature = "rustls",
61 feature = "native-tls",
62 feature = "vendored-openssl"
63 ))]
64 MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
65 }
66 }
67}
68
69impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for MaybeTlsStream<S> {
70 fn poll_write(
71 self: Pin<&mut Self>,
72 cx: &mut task::Context<'_>,
73 buf: &[u8],
74 ) -> Poll<io::Result<usize>> {
75 match self.get_mut() {
76 MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
77 #[cfg(any(
78 feature = "rustls",
79 feature = "native-tls",
80 feature = "vendored-openssl"
81 ))]
82 MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
83 }
84 }
85
86 fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
87 match self.get_mut() {
88 MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
89 #[cfg(any(
90 feature = "rustls",
91 feature = "native-tls",
92 feature = "vendored-openssl"
93 ))]
94 MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
95 }
96 }
97
98 fn poll_close(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
99 match self.get_mut() {
100 MaybeTlsStream::Raw(s) => Pin::new(s).poll_close(cx),
101 #[cfg(any(
102 feature = "rustls",
103 feature = "native-tls",
104 feature = "vendored-openssl"
105 ))]
106 MaybeTlsStream::Tls(s) => Pin::new(s).poll_close(cx),
107 }
108 }
109}
110
111pub(crate) struct TlsPreloginWrapper<S> {
118 stream: Option<S>,
119 pending_handshake: bool,
120
121 header_buf: [u8; HEADER_BYTES],
122 header_pos: usize,
123 read_remaining: usize,
124
125 wr_buf: Vec<u8>,
126 header_written: bool,
127}
128
#[cfg(any(
    feature = "rustls",
    feature = "native-tls",
    feature = "vendored-openssl"
))]
impl<S> TlsPreloginWrapper<S> {
    /// Wraps `stream`, starting in handshake (TDS-framing) mode.
    ///
    /// The write buffer starts with `HEADER_BYTES` zero bytes. They reserve
    /// space for the packet header, which is encoded in place when the
    /// buffered data is flushed.
    pub fn new(stream: S) -> Self {
        Self {
            pending_handshake: true,
            stream: Some(stream),

            header_pos: 0,
            read_remaining: 0,
            header_buf: [0; HEADER_BYTES],

            header_written: false,
            wr_buf: vec![0; HEADER_BYTES],
        }
    }

    /// Leaves handshake mode.
    ///
    /// From this point on, reads and writes are passed straight through to
    /// the underlying stream without TDS framing.
    pub fn handshake_complete(&mut self) {
        self.pending_handshake = false;
    }
}
152
153impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncRead for TlsPreloginWrapper<S> {
154 fn poll_read(
155 mut self: Pin<&mut Self>,
156 cx: &mut task::Context<'_>,
157 buf: &mut [u8],
158 ) -> Poll<io::Result<usize>> {
159 if !self.pending_handshake {
162 return Pin::new(&mut self.stream.as_mut().unwrap()).poll_read(cx, buf);
163 }
164
165 let inner = self.get_mut();
166
167 if !inner.header_buf[inner.header_pos..].is_empty() {
170 while !inner.header_buf[inner.header_pos..].is_empty() {
171 let read = ready!(Pin::new(inner.stream.as_mut().unwrap())
172 .poll_read(cx, &mut inner.header_buf[inner.header_pos..]))?;
173
174 if read == 0 {
175 return Poll::Ready(Ok(0));
176 }
177
178 inner.header_pos += read;
179 }
180
181 let header = PacketHeader::decode(&mut BytesMut::from(&inner.header_buf[..]))
182 .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
183
184 assert_eq!(header.r#type(), PacketType::PreLogin);
186
187 inner.read_remaining = header.length() as usize - HEADER_BYTES;
189
190 event!(
191 Level::TRACE,
192 "Reading packet of {} bytes",
193 inner.read_remaining,
194 );
195 }
196
197 let max_read = cmp::min(inner.read_remaining, buf.len());
198
199 let read = ready!(
201 Pin::new(&mut inner.stream.as_mut().unwrap()).poll_read(cx, &mut buf[..max_read])
202 )?;
203
204 inner.read_remaining -= read;
205
206 if inner.read_remaining == 0 {
208 inner.header_pos = 0;
209 }
210
211 Poll::Ready(Ok(read))
212 }
213}
214
215impl<S: AsyncRead + AsyncWrite + Unpin + Send> AsyncWrite for TlsPreloginWrapper<S> {
216 fn poll_write(
217 mut self: Pin<&mut Self>,
218 cx: &mut task::Context<'_>,
219 buf: &[u8],
220 ) -> Poll<io::Result<usize>> {
221 if !self.pending_handshake {
224 return Pin::new(&mut self.stream.as_mut().unwrap()).poll_write(cx, buf);
225 }
226
227 self.wr_buf.extend_from_slice(buf);
229
230 Poll::Ready(Ok(buf.len()))
231 }
232
233 fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
234 let inner = self.get_mut();
235
236 if inner.pending_handshake && inner.wr_buf.len() > HEADER_BYTES {
238 if !inner.header_written {
239 let mut header = PacketHeader::new(inner.wr_buf.len(), 0);
240
241 header.set_type(PacketType::PreLogin);
242 header.set_status(PacketStatus::EndOfMessage);
243
244 header
245 .encode(&mut &mut inner.wr_buf[0..HEADER_BYTES])
246 .map_err(|_| {
247 io::Error::new(io::ErrorKind::InvalidInput, "Could not encode header.")
248 })?;
249
250 inner.header_written = true;
251 }
252
253 while !inner.wr_buf.is_empty() {
254 event!(
255 Level::TRACE,
256 "Writing a packet of {} bytes",
257 inner.wr_buf.len(),
258 );
259
260 let written = ready!(
261 Pin::new(&mut inner.stream.as_mut().unwrap()).poll_write(cx, &inner.wr_buf)
262 )?;
263
264 inner.wr_buf.drain(..written);
265 }
266
267 inner.wr_buf.resize(HEADER_BYTES, 0);
268 inner.header_written = false;
269 }
270
271 Pin::new(&mut inner.stream.as_mut().unwrap()).poll_flush(cx)
272 }
273
274 fn poll_close(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
275 Pin::new(&mut self.stream.as_mut().unwrap()).poll_close(cx)
276 }
277}