1use crate::tds::Context;
2use bytes::Buf;
3use futures_util::io::AsyncRead;
4use pin_project_lite::pin_project;
5use std::io::ErrorKind::UnexpectedEof;
6use std::{future::Future, io, mem::size_of, pin::Pin, task};
7use task::Poll;
8
/// Generates a future that reads a length-prefixed UTF-16LE string.
///
/// `$name` is the generated future type. `$length_reader` is a
/// `bytes_reader!`-generated type that decodes the length prefix. The prefix
/// counts UTF-16 code units, not bytes. Each code unit is then read as a
/// little-endian `u16`.
///
/// The partially read length prefix and the partially read code unit are both
/// stored inside the future. As a result, bytes consumed before a
/// `Poll::Pending` are not lost, and the next poll resumes exactly where the
/// previous one stopped. This works by re-seeding the reader's private
/// `buf`/`read` fields. Those fields are reachable here because both macros
/// expand in this module.
///
/// # Errors
///
/// - Errors from the underlying reader are passed through, including
///   `UnexpectedEof`.
/// - `InvalidData` is returned if the collected code units are not valid
///   UTF-16.
macro_rules! varchar_reader {
    ($name:ident, $length_reader:ident) => {
        pin_project! {
            #[doc(hidden)]
            pub struct $name<R> {
                #[pin]
                src: R,
                // Progress of the length prefix. Its `src` is a unit placeholder;
                // only the `buf`/`read` state is carried between polls.
                length_state: $length_reader<()>,
                length: Option<usize>,
                buf: Option<Vec<u16>>,
                // Progress of the code unit currently being read, kept for the
                // same reason as `length_state`.
                unit_state: ReadU16Le<()>,
                read: usize,
            }
        }

        #[allow(dead_code)]
        impl<R> $name<R> {
            pub(crate) fn new(src: R) -> Self {
                Self {
                    src,
                    length_state: $length_reader::new(()),
                    length: None,
                    buf: None,
                    unit_state: ReadU16Le::new(()),
                    read: 0,
                }
            }
        }

        impl<R> Future for $name<R>
        where
            R: AsyncRead,
        {
            type Output = io::Result<String>;

            fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
                let mut me = self.project();

                if me.length.is_none() {
                    // Rebuild the length reader around the live source, seeded
                    // with any bytes that earlier polls already consumed.
                    let mut length_reader = $length_reader {
                        src: me.src.as_mut(),
                        buf: me.length_state.buf,
                        read: me.length_state.read,
                    };
                    let polled = Pin::new(&mut length_reader).poll(cx);
                    // Save progress before any early return below.
                    me.length_state.buf = length_reader.buf;
                    me.length_state.read = length_reader.read;

                    let length = match polled {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Ready(Ok(length)) => usize::from(length),
                    };

                    *me.length = Some(length);
                    *me.buf = Some(Vec::with_capacity(length));
                }

                let len = me.length.expect("length prefix is decoded above");
                let units = me
                    .buf
                    .as_mut()
                    .expect("buffer is allocated together with the length");

                while *me.read < len {
                    // Same resumable pattern as for the length prefix.
                    let mut unit_reader = ReadU16Le {
                        src: me.src.as_mut(),
                        buf: me.unit_state.buf,
                        read: me.unit_state.read,
                    };
                    let polled = Pin::new(&mut unit_reader).poll(cx);
                    me.unit_state.buf = unit_reader.buf;
                    me.unit_state.read = unit_reader.read;

                    match polled {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Ready(Ok(unit)) => {
                            units.push(unit);
                            *me.read += 1;
                            // A completed reader keeps returning its value, so
                            // start the next code unit from a fresh state.
                            *me.unit_state = ReadU16Le::new(());
                        }
                    }
                }

                // Reached on first completion and on any later re-poll; both
                // decode the same collected code units.
                Poll::Ready(String::from_utf16(units.as_slice()).map_err(|_| {
                    io::Error::new(io::ErrorKind::InvalidData, "Invalid UTF-16 data.")
                }))
            }
        }
    };
}
94
/// Generates a future that decodes one fixed-width value of type `$ty`.
///
/// - `$reader` names the [`Buf`] method used to decode the filled buffer
///   (for example `get_u16_le` or `get_u32`).
/// - `$bytes` is the number of bytes to read and defaults to
///   `size_of::<$ty>()`.
///
/// Bytes collected so far live in `buf`, and their count in `read`. A short
/// read or a `Poll::Pending` therefore never loses data. Polling again after
/// completion yields the same value.
///
/// # Errors
///
/// - Errors from the source are passed through.
/// - If the source reports end of input before the buffer is full, the
///   future fails with `UnexpectedEof`.
macro_rules! bytes_reader {
    ($name:ident, $ty:ty, $reader:ident) => {
        bytes_reader!($name, $ty, $reader, size_of::<$ty>());
    };
    ($name:ident, $ty:ty, $reader:ident, $bytes:expr) => {
        pin_project! {
            #[doc(hidden)]
            pub struct $name<R> {
                #[pin]
                src: R,
                buf: [u8; $bytes],
                read: u8,
            }
        }

        #[allow(dead_code)]
        impl<R> $name<R> {
            pub(crate) fn new(src: R) -> Self {
                Self {
                    src,
                    buf: [0; $bytes],
                    read: 0,
                }
            }
        }

        impl<R> Future for $name<R>
        where
            R: AsyncRead,
        {
            type Output = io::Result<$ty>;

            fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
                let mut this = self.project();
                let wanted = $bytes as u8;

                // Fill the rest of the scratch buffer. Once it is full, the loop
                // is skipped and the stored bytes are decoded again.
                while *this.read < wanted {
                    let offset = *this.read as usize;
                    let got = match this.src.as_mut().poll_read(cx, &mut this.buf[offset..]) {
                        Poll::Ready(Ok(0)) => return Poll::Ready(Err(UnexpectedEof.into())),
                        Poll::Ready(Ok(n)) => n,
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => return Poll::Pending,
                    };
                    *this.read += got as u8;
                }

                let mut filled: &[u8] = &this.buf[..];
                Poll::Ready(Ok(Buf::$reader(&mut filled)))
            }
        }
    };
}
156
157pub(crate) trait SqlReadBytes: AsyncRead + Unpin {
160 #[allow(dead_code)]
162 fn debug_buffer(&self);
163
164 fn context(&self) -> &Context;
166
167 fn context_mut(&mut self) -> &mut Context;
169
170 #[allow(dead_code)]
172 fn read_i8(&mut self) -> ReadI8<&mut Self>
173 where
174 Self: Unpin,
175 {
176 ReadI8::new(self)
177 }
178
179 fn read_u8(&mut self) -> ReadU8<&mut Self>
181 where
182 Self: Unpin,
183 {
184 ReadU8::new(self)
185 }
186
187 fn read_u32(&mut self) -> ReadU32Be<&mut Self>
189 where
190 Self: Unpin,
191 {
192 ReadU32Be::new(self)
193 }
194
195 #[allow(dead_code)]
197 fn read_f32(&mut self) -> ReadF32<&mut Self>
198 where
199 Self: Unpin,
200 {
201 ReadF32::new(self)
202 }
203
204 #[allow(dead_code)]
206 fn read_f64(&mut self) -> ReadF64<&mut Self>
207 where
208 Self: Unpin,
209 {
210 ReadF64::new(self)
211 }
212
213 fn read_f32_le(&mut self) -> ReadF32Le<&mut Self>
215 where
216 Self: Unpin,
217 {
218 ReadF32Le::new(self)
219 }
220
221 fn read_f64_le(&mut self) -> ReadF64Le<&mut Self>
223 where
224 Self: Unpin,
225 {
226 ReadF64Le::new(self)
227 }
228
229 fn read_u16_le(&mut self) -> ReadU16Le<&mut Self>
231 where
232 Self: Unpin,
233 {
234 ReadU16Le::new(self)
235 }
236
237 fn read_u32_le(&mut self) -> ReadU32Le<&mut Self>
239 where
240 Self: Unpin,
241 {
242 ReadU32Le::new(self)
243 }
244
245 fn read_u64_le(&mut self) -> ReadU64Le<&mut Self>
247 where
248 Self: Unpin,
249 {
250 ReadU64Le::new(self)
251 }
252
253 #[allow(dead_code)]
255 fn read_u128_le(&mut self) -> ReadU128Le<&mut Self>
256 where
257 Self: Unpin,
258 {
259 ReadU128Le::new(self)
260 }
261
262 fn read_i16_le(&mut self) -> ReadI16Le<&mut Self>
264 where
265 Self: Unpin,
266 {
267 ReadI16Le::new(self)
268 }
269
270 fn read_i32_le(&mut self) -> ReadI32Le<&mut Self>
272 where
273 Self: Unpin,
274 {
275 ReadI32Le::new(self)
276 }
277
278 fn read_i64_le(&mut self) -> ReadI64Le<&mut Self>
280 where
281 Self: Unpin,
282 {
283 ReadI64Le::new(self)
284 }
285
286 #[allow(dead_code)]
288 fn read_i128_le(&mut self) -> ReadI128Le<&mut Self>
289 where
290 Self: Unpin,
291 {
292 ReadI128Le::new(self)
293 }
294
295 fn read_b_varchar(&mut self) -> ReadBVarchar<&mut Self>
297 where
298 Self: Unpin,
299 {
300 ReadBVarchar::new(self)
301 }
302
303 fn read_us_varchar(&mut self) -> ReadUSVarchar<&mut Self>
305 where
306 Self: Unpin,
307 {
308 ReadUSVarchar::new(self)
309 }
310}
311
312varchar_reader!(ReadBVarchar, ReadU8);
313varchar_reader!(ReadUSVarchar, ReadU16Le);
314
315bytes_reader!(ReadI8, i8, get_i8);
316bytes_reader!(ReadU8, u8, get_u8);
317bytes_reader!(ReadU32Be, u32, get_u32);
318
319bytes_reader!(ReadU16Le, u16, get_u16_le);
320bytes_reader!(ReadU32Le, u32, get_u32_le);
321bytes_reader!(ReadU64Le, u64, get_u64_le);
322bytes_reader!(ReadU128Le, u128, get_u128_le);
323
324bytes_reader!(ReadI16Le, i16, get_i16_le);
325bytes_reader!(ReadI32Le, i32, get_i32_le);
326bytes_reader!(ReadI64Le, i64, get_i64_le);
327bytes_reader!(ReadI128Le, i128, get_i128_le);
328
329bytes_reader!(ReadF32, f32, get_f32);
330bytes_reader!(ReadF64, f64, get_f64);
331
332bytes_reader!(ReadF32Le, f32, get_f32_le);
333bytes_reader!(ReadF64Le, f64, get_f64_le);
334
#[cfg(test)]
pub(crate) mod test_utils {
    use crate::tds::Context;
    use crate::SqlReadBytes;
    use bytes::BytesMut;
    use futures_util::io::AsyncRead;
    use std::io;
    use std::pin::Pin;
    use std::task::Poll;

    /// Turns an in-memory value into a [`SqlReadBytes`] source for tests.
    pub(crate) trait IntoSqlReadBytes {
        /// Reader type produced by the conversion.
        type T: SqlReadBytes;

        /// Consumes `self` and returns a reader over its contents.
        fn into_sql_read_bytes(self) -> Self::T;
    }

    impl IntoSqlReadBytes for BytesMut {
        type T = BytesMutReader;

        fn into_sql_read_bytes(self) -> Self::T {
            BytesMutReader { buf: self }
        }
    }

    /// Test reader that serves bytes out of a `BytesMut`.
    ///
    /// - Reads are all-or-nothing: a read either fills the whole destination
    ///   slice or fails with `UnexpectedEof`.
    /// - A read never returns `Pending`.
    /// - The `SqlReadBytes` accessors are stubs that panic via `todo!()`.
    pub(crate) struct BytesMutReader {
        buf: BytesMut,
    }

    impl AsyncRead for BytesMutReader {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            let wanted = buf.len();
            let remaining = &mut self.get_mut().buf;

            let outcome = if remaining.len() >= wanted {
                let chunk = remaining.split_to(wanted);
                buf.copy_from_slice(&chunk);
                Ok(wanted)
            } else {
                // Short reads are refused so callers see a hard EOF instead.
                Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "No more packets in the wire",
                ))
            };

            Poll::Ready(outcome)
        }
    }

    impl SqlReadBytes for BytesMutReader {
        fn debug_buffer(&self) {
            todo!()
        }

        fn context(&self) -> &Context {
            todo!()
        }

        fn context_mut(&mut self) -> &mut Context {
            todo!()
        }
    }
}