// tiberius/tds/codec/header.rs

1use super::{Decode, Encode};
2use crate::Error;
3use bytes::{Buf, BufMut, BytesMut};
4use std::convert::TryFrom;
5
6uint_enum! {
7    /// the type of the packet [2.2.3.1.1]#[repr(u32)]
8    #[repr(u8)]
9    pub enum PacketType {
10        SQLBatch = 1,
11        /// unused
12        PreTDSv7Login = 2,
13        Rpc = 3,
14        TabularResult = 4,
15        AttentionSignal = 6,
16        BulkLoad = 7,
17        /// Federated Authentication Token
18        Fat = 8,
19        TransactionManagerReq = 14,
20        TDSv7Login = 16,
21        Sspi = 17,
22        PreLogin = 18,
23    }
24}
25
26uint_enum! {
27    /// the message state [2.2.3.1.2]
28    #[repr(u8)]
29    pub enum PacketStatus {
30        NormalMessage = 0,
31        EndOfMessage = 1,
32        /// [client to server ONLY] (EndOfMessage also required)
33        IgnoreEvent = 3,
34        /// [client to server ONLY] [>= TDSv7.1]
35        ResetConnection = 0x08,
36        /// [client to server ONLY] [>= TDSv7.3]
37        ResetConnectionSkipTran = 0x10,
38    }
39}
40
41/// packet header consisting of 8 bytes [2.2.3.1]
42#[derive(Debug, Clone, Copy)]
43pub(crate) struct PacketHeader {
44    ty: PacketType,
45    status: PacketStatus,
46    /// [BE] the length of the packet (including the 8 header bytes)
47    /// must match the negotiated size sending from client to server [since TDSv7.3] after login
48    /// (only if not EndOfMessage)
49    length: u16,
50    /// [BE] the process ID on the server, for debugging purposes only
51    spid: u16,
52    /// packet id
53    id: u8,
54    /// currently unused
55    window: u8,
56}
57
58impl PacketHeader {
59    pub fn new(length: usize, id: u8) -> PacketHeader {
60        assert!(length <= u16::max_value() as usize);
61        PacketHeader {
62            ty: PacketType::TDSv7Login,
63            status: PacketStatus::ResetConnection,
64            length: length as u16,
65            spid: 0,
66            id,
67            window: 0,
68        }
69    }
70
71    pub fn rpc(id: u8) -> Self {
72        Self {
73            ty: PacketType::Rpc,
74            status: PacketStatus::NormalMessage,
75            ..Self::new(0, id)
76        }
77    }
78
79    pub fn pre_login(id: u8) -> Self {
80        Self {
81            ty: PacketType::PreLogin,
82            status: PacketStatus::EndOfMessage,
83            ..Self::new(0, id)
84        }
85    }
86
87    pub fn login(id: u8) -> Self {
88        Self {
89            ty: PacketType::TDSv7Login,
90            status: PacketStatus::EndOfMessage,
91            ..Self::new(0, id)
92        }
93    }
94
95    pub fn batch(id: u8) -> Self {
96        Self {
97            ty: PacketType::SQLBatch,
98            status: PacketStatus::NormalMessage,
99            ..Self::new(0, id)
100        }
101    }
102
103    pub fn bulk_load(id: u8) -> Self {
104        Self {
105            ty: PacketType::BulkLoad,
106            status: PacketStatus::NormalMessage,
107            ..Self::new(0, id)
108        }
109    }
110
111    pub fn set_status(&mut self, status: PacketStatus) {
112        self.status = status;
113    }
114
115    pub fn set_type(&mut self, ty: PacketType) {
116        self.ty = ty;
117    }
118
119    pub fn status(&self) -> PacketStatus {
120        self.status
121    }
122
123    pub fn r#type(&self) -> PacketType {
124        self.ty
125    }
126
127    pub fn length(&self) -> u16 {
128        self.length
129    }
130}
131
132impl<B> Encode<B> for PacketHeader
133where
134    B: BufMut,
135{
136    fn encode(self, dst: &mut B) -> crate::Result<()> {
137        dst.put_u8(self.ty as u8);
138        dst.put_u8(self.status as u8);
139        dst.put_u16(self.length);
140        dst.put_u16(self.spid);
141        dst.put_u8(self.id);
142        dst.put_u8(self.window);
143
144        Ok(())
145    }
146}
147
148impl Decode<BytesMut> for PacketHeader {
149    fn decode(src: &mut BytesMut) -> crate::Result<Self>
150    where
151        Self: Sized,
152    {
153        let raw_ty = src.get_u8();
154
155        let ty = PacketType::try_from(raw_ty).map_err(|_| {
156            Error::Protocol(format!("header: invalid packet type: {}", raw_ty).into())
157        })?;
158
159        let status = PacketStatus::try_from(src.get_u8())
160            .map_err(|_| Error::Protocol("header: invalid packet status".into()))?;
161
162        let header = PacketHeader {
163            ty,
164            status,
165            length: src.get_u16(),
166            spid: src.get_u16(),
167            id: src.get_u8(),
168            window: src.get_u8(),
169        };
170
171        Ok(header)
172    }
173}