tiberius/
sql_read_bytes.rs

1use crate::tds::Context;
2use bytes::Buf;
3use futures_util::io::AsyncRead;
4use pin_project_lite::pin_project;
5use std::io::ErrorKind::UnexpectedEof;
6use std::{future::Future, io, mem::size_of, pin::Pin, task};
7use task::Poll;
8
macro_rules! varchar_reader {
    ($name:ident, $length_reader:ident) => {
        pin_project! {
            /// Future reading a length-prefixed UTF-16LE string.
            ///
            /// The prefix is decoded with the length-reader type given to the macro and
            /// counts UTF-16 code units. Every byte taken from `src` is kept in the struct,
            /// so the future resumes correctly after `Poll::Pending` even when `src`
            /// delivered only part of the prefix or payload before pending.
            #[doc(hidden)]
            pub struct $name<R> {
                #[pin]
                src: R,
                // Raw bytes of the length prefix read so far; 16 bytes is enough for any
                // integer prefix up to `u128`.
                len_buf: [u8; 16],
                // Number of prefix bytes already stored in `len_buf`.
                len_read: usize,
                // Length in UTF-16 code units, set once the prefix is complete.
                length: Option<usize>,
                // Raw UTF-16LE payload, sized to `2 * length` once the prefix is known.
                buf: Vec<u8>,
                // Number of payload bytes already stored in `buf`.
                read: usize
            }
        }

        #[allow(dead_code)]
        impl<R> $name<R> {
            pub(crate) fn new(src: R) -> Self {
                Self {
                    src,
                    len_buf: [0; 16],
                    len_read: 0,
                    length: None,
                    buf: Vec::new(),
                    read: 0,
                }
            }
        }

        impl<R> Future for $name<R>
        where
            R: AsyncRead,
        {
            type Output = io::Result<String>;

            fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
                /// Size in bytes of the integer a length-reader future resolves to.
                fn prefix_width<F, T>(_: &F) -> usize
                where
                    F: Future<Output = io::Result<T>>,
                {
                    size_of::<T>()
                }

                let mut me = self.project();

                // Step 1: read and decode the length prefix. Progress is kept in
                // `len_buf`/`len_read`, so a Pending here loses nothing.
                if me.length.is_none() {
                    let empty: &[u8] = &[];
                    let width = prefix_width(&$length_reader::new(empty));

                    while *me.len_read < width {
                        let start = *me.len_read;
                        match me.src.as_mut().poll_read(cx, &mut me.len_buf[start..width]) {
                            Poll::Pending => return Poll::Pending,
                            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                            Poll::Ready(Ok(0)) => return Poll::Ready(Err(UnexpectedEof.into())),
                            Poll::Ready(Ok(n)) => *me.len_read += n,
                        }
                    }

                    // Decode through the configured reader so width and byte order stay
                    // defined by the macro argument. A byte-slice source is always ready.
                    let mut decode = $length_reader::new(&me.len_buf[..width]);
                    let length = match Pin::new(&mut decode).poll(cx) {
                        Poll::Ready(Ok(length)) => length as usize,
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => unreachable!("reading from a byte slice never pends"),
                    };

                    *me.length = Some(length);
                    // Each UTF-16 code unit occupies two bytes on the wire.
                    *me.buf = vec![0; length * 2];
                }

                // Step 2: fill the payload buffer; progress lives in `read`.
                while *me.read < me.buf.len() {
                    let start = *me.read;
                    match me.src.as_mut().poll_read(cx, &mut me.buf[start..]) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Ready(Ok(0)) => return Poll::Ready(Err(UnexpectedEof.into())),
                        Poll::Ready(Ok(n)) => *me.read += n,
                    }
                }

                // Step 3: decode the little-endian code units; unpaired surrogates are
                // rejected as invalid data.
                let units = me
                    .buf
                    .chunks_exact(2)
                    .map(|pair| u16::from_le_bytes([pair[0], pair[1]]));

                let decoded = char::decode_utf16(units)
                    .collect::<Result<String, _>>()
                    .map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidData, "Invalid UTF-16 data.")
                    });

                Poll::Ready(decoded)
            }
        }
    };
}
94
macro_rules! bytes_reader {
    ($name:ident, $ty:ty, $reader:ident) => {
        bytes_reader!($name, $ty, $reader, size_of::<$ty>());
    };
    ($name:ident, $ty:ty, $reader:ident, $bytes:expr) => {
        pin_project! {
            /// Future decoding one fixed-width value with the macro's `bytes::Buf` getter.
            #[doc(hidden)]
            pub struct $name<R> {
                #[pin]
                src: R,
                // Scratch space holding the encoded value.
                buf: [u8; $bytes],
                // How many bytes of `buf` are filled; kept across `Poll::Pending`.
                read: u8,
            }
        }

        #[allow(dead_code)]
        impl<R> $name<R> {
            pub(crate) fn new(src: R) -> Self {
                Self {
                    src,
                    buf: [0; $bytes],
                    read: 0,
                }
            }
        }

        impl<R> Future for $name<R>
        where
            R: AsyncRead,
        {
            type Output = io::Result<$ty>;

            fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
                let mut this = self.project();
                let total = $bytes as u8;

                // Pull bytes until the scratch buffer is full. Once it is, the loop is
                // skipped and the stored bytes are decoded again on any later poll.
                while *this.read < total {
                    let start = *this.read as usize;
                    match this.src.as_mut().poll_read(cx, &mut this.buf[start..]) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                        // A zero-byte read means the source ended mid-value.
                        Poll::Ready(Ok(0)) => return Poll::Ready(Err(UnexpectedEof.into())),
                        Poll::Ready(Ok(count)) => *this.read += count as u8,
                    }
                }

                Poll::Ready(Ok(Buf::$reader(&mut &this.buf[..])))
            }
        }
    };
}
156
157/// The `SqlReadBytes` trait is used to read bytes from the wire.
158// Many of the methods have an `allow(dead_code)` attribute because they are not currently used but they could be anytime in the future.
159pub(crate) trait SqlReadBytes: AsyncRead + Unpin {
160    // Pretty-print current wire content.
161    #[allow(dead_code)]
162    fn debug_buffer(&self);
163
164    // The client state.
165    fn context(&self) -> &Context;
166
167    // A mutable reference to the SQL client state.
168    fn context_mut(&mut self) -> &mut Context;
169
170    // Read a single i8 value.
171    #[allow(dead_code)]
172    fn read_i8(&mut self) -> ReadI8<&mut Self>
173    where
174        Self: Unpin,
175    {
176        ReadI8::new(self)
177    }
178
179    // Read a single byte value.
180    fn read_u8(&mut self) -> ReadU8<&mut Self>
181    where
182        Self: Unpin,
183    {
184        ReadU8::new(self)
185    }
186
187    // Read a single big-endian u32 value.
188    fn read_u32(&mut self) -> ReadU32Be<&mut Self>
189    where
190        Self: Unpin,
191    {
192        ReadU32Be::new(self)
193    }
194
195    // Read a single big-endian f32 value.
196    #[allow(dead_code)]
197    fn read_f32(&mut self) -> ReadF32<&mut Self>
198    where
199        Self: Unpin,
200    {
201        ReadF32::new(self)
202    }
203
204    // Read a single big-endian f64 value.
205    #[allow(dead_code)]
206    fn read_f64(&mut self) -> ReadF64<&mut Self>
207    where
208        Self: Unpin,
209    {
210        ReadF64::new(self)
211    }
212
213    // Read a single f32 value.
214    fn read_f32_le(&mut self) -> ReadF32Le<&mut Self>
215    where
216        Self: Unpin,
217    {
218        ReadF32Le::new(self)
219    }
220
221    // Read a single f64 value.
222    fn read_f64_le(&mut self) -> ReadF64Le<&mut Self>
223    where
224        Self: Unpin,
225    {
226        ReadF64Le::new(self)
227    }
228
229    // Read a single u16 value.
230    fn read_u16_le(&mut self) -> ReadU16Le<&mut Self>
231    where
232        Self: Unpin,
233    {
234        ReadU16Le::new(self)
235    }
236
237    // Read a single u32 value.
238    fn read_u32_le(&mut self) -> ReadU32Le<&mut Self>
239    where
240        Self: Unpin,
241    {
242        ReadU32Le::new(self)
243    }
244
245    // Read a single u64 value.
246    fn read_u64_le(&mut self) -> ReadU64Le<&mut Self>
247    where
248        Self: Unpin,
249    {
250        ReadU64Le::new(self)
251    }
252
253    // Read a single u128 value.
254    #[allow(dead_code)]
255    fn read_u128_le(&mut self) -> ReadU128Le<&mut Self>
256    where
257        Self: Unpin,
258    {
259        ReadU128Le::new(self)
260    }
261
262    // Read a single i16 value.
263    fn read_i16_le(&mut self) -> ReadI16Le<&mut Self>
264    where
265        Self: Unpin,
266    {
267        ReadI16Le::new(self)
268    }
269
270    // Read a single i32 value.
271    fn read_i32_le(&mut self) -> ReadI32Le<&mut Self>
272    where
273        Self: Unpin,
274    {
275        ReadI32Le::new(self)
276    }
277
278    // Read a single i64 value.
279    fn read_i64_le(&mut self) -> ReadI64Le<&mut Self>
280    where
281        Self: Unpin,
282    {
283        ReadI64Le::new(self)
284    }
285
286    // Read a single i128 value.
287    #[allow(dead_code)]
288    fn read_i128_le(&mut self) -> ReadI128Le<&mut Self>
289    where
290        Self: Unpin,
291    {
292        ReadI128Le::new(self)
293    }
294
295    // A variable-length character stream defined by a length-field of an u8.
296    fn read_b_varchar(&mut self) -> ReadBVarchar<&mut Self>
297    where
298        Self: Unpin,
299    {
300        ReadBVarchar::new(self)
301    }
302
303    // A variable-length character stream defined by a length-field of an u16.
304    fn read_us_varchar(&mut self) -> ReadUSVarchar<&mut Self>
305    where
306        Self: Unpin,
307    {
308        ReadUSVarchar::new(self)
309    }
310}
311
312varchar_reader!(ReadBVarchar, ReadU8);
313varchar_reader!(ReadUSVarchar, ReadU16Le);
314
315bytes_reader!(ReadI8, i8, get_i8);
316bytes_reader!(ReadU8, u8, get_u8);
317bytes_reader!(ReadU32Be, u32, get_u32);
318
319bytes_reader!(ReadU16Le, u16, get_u16_le);
320bytes_reader!(ReadU32Le, u32, get_u32_le);
321bytes_reader!(ReadU64Le, u64, get_u64_le);
322bytes_reader!(ReadU128Le, u128, get_u128_le);
323
324bytes_reader!(ReadI16Le, i16, get_i16_le);
325bytes_reader!(ReadI32Le, i32, get_i32_le);
326bytes_reader!(ReadI64Le, i64, get_i64_le);
327bytes_reader!(ReadI128Le, i128, get_i128_le);
328
329bytes_reader!(ReadF32, f32, get_f32);
330bytes_reader!(ReadF64, f64, get_f64);
331
332bytes_reader!(ReadF32Le, f32, get_f32_le);
333bytes_reader!(ReadF64Le, f64, get_f64_le);
334
#[cfg(test)]
pub(crate) mod test_utils {
    use crate::tds::Context;
    use crate::SqlReadBytes;
    use bytes::BytesMut;
    use futures_util::io::AsyncRead;
    use std::io;
    use std::pin::Pin;
    use std::task::Poll;

    /// Test helper: turns an in-memory buffer into a `SqlReadBytes` source so decode
    /// logic can be exercised in loop-back tests without a connection.
    pub(crate) trait IntoSqlReadBytes {
        type T: SqlReadBytes;
        fn into_sql_read_bytes(self) -> Self::T;
    }

    impl IntoSqlReadBytes for BytesMut {
        type T = BytesMutReader;

        fn into_sql_read_bytes(self) -> Self::T {
            BytesMutReader { buf: self }
        }
    }

    /// A `SqlReadBytes` source backed by a `BytesMut`.
    pub(crate) struct BytesMutReader {
        buf: BytesMut,
    }

    impl AsyncRead for BytesMutReader {
        /// Fills `buf` completely from the front of the backing buffer, or fails with
        /// `UnexpectedEof` when fewer bytes remain. Never returns `Pending`.
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            let wanted = buf.len();
            let reader = self.get_mut();

            if wanted > reader.buf.len() {
                // Got EOF before having all the data.
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "No more packets in the wire",
                )));
            }

            let chunk = reader.buf.split_to(wanted);
            buf.copy_from_slice(&chunk[..]);
            Poll::Ready(Ok(wanted))
        }
    }

    impl SqlReadBytes for BytesMutReader {
        fn debug_buffer(&self) {
            todo!()
        }

        fn context(&self) -> &Context {
            todo!()
        }

        fn context_mut(&mut self) -> &mut Context {
            todo!()
        }
    }
}