tiberius/tds/codec/
decode.rs

1use super::{Packet, PacketCodec, PacketHeader, HEADER_BYTES};
2use crate::Error;
3use asynchronous_codec::Decoder;
4use bytes::{Buf, BytesMut};
5use tracing::{event, Level};
6
7pub trait Decode<B: Buf> {
8    fn decode(src: &mut B) -> crate::Result<Self>
9    where
10        Self: Sized;
11}
12
13impl Decoder for PacketCodec {
14    type Item = Packet;
15    type Error = Error;
16
17    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
18        if src.len() < HEADER_BYTES {
19            src.reserve(HEADER_BYTES);
20            return Ok(None);
21        }
22
23        let header = PacketHeader::decode(&mut BytesMut::from(&src[0..HEADER_BYTES]))?;
24        let length = header.length() as usize;
25
26        if src.len() < length {
27            src.reserve(length);
28            return Ok(None);
29        }
30
31        event!(
32            Level::TRACE,
33            "Reading a {:?} ({} bytes)",
34            header.r#type(),
35            length,
36        );
37
38        let header = PacketHeader::decode(src)?;
39
40        if length < HEADER_BYTES {
41            return Err(Error::Protocol("Invalid packet length".into()));
42        }
43
44        let payload = src.split_to(length - HEADER_BYTES);
45
46        Ok(Some(Packet::new(header, payload)))
47    }
48
49    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
50        match self.decode(buf)? {
51            Some(frame) => Ok(Some(frame)),
52            None => {
53                if buf.is_empty() {
54                    Ok(None)
55                } else {
56                    Err(
57                        std::io::Error::new(std::io::ErrorKind::Other, "bytes remaining on stream")
58                            .into(),
59                    )
60                }
61            }
62        }
63    }
64}