1use crate::tds::codec::TokenSspi;
2use crate::{
3 client::Connection,
4 tds::codec::{
5 TokenColMetaData, TokenDone, TokenEnvChange, TokenError, TokenFeatureExtAck, TokenInfo,
6 TokenLoginAck, TokenOrder, TokenReturnValue, TokenRow,
7 },
8 Error, SqlReadBytes, TokenType,
9};
10use futures_util::{
11 io::{AsyncRead, AsyncWrite},
12 stream::{BoxStream, TryStreamExt},
13};
14use std::{convert::TryFrom, sync::Arc};
15use tracing::{event, Level};
16
17#[derive(Debug)]
18#[allow(dead_code)]
19pub enum ReceivedToken {
20 NewResultset(Arc<TokenColMetaData<'static>>),
21 Row(TokenRow<'static>),
22 Done(TokenDone),
23 DoneInProc(TokenDone),
24 DoneProc(TokenDone),
25 ReturnStatus(u32),
26 ReturnValue(TokenReturnValue),
27 Order(TokenOrder),
28 EnvChange(TokenEnvChange),
29 Info(TokenInfo),
30 LoginAck(TokenLoginAck),
31 Sspi(TokenSspi),
32 FeatureExtAck(TokenFeatureExtAck),
33 Error(TokenError),
34}
35
36pub(crate) struct TokenStream<'a, S: AsyncRead + AsyncWrite + Unpin + Send> {
37 conn: &'a mut Connection<S>,
38 last_error: Option<Error>,
39}
40
41impl<'a, S> TokenStream<'a, S>
42where
43 S: AsyncRead + AsyncWrite + Unpin + Send,
44{
45 pub(crate) fn new(conn: &'a mut Connection<S>) -> Self {
46 Self {
47 conn,
48 last_error: None,
49 }
50 }
51
52 pub(crate) async fn flush_done(self) -> crate::Result<TokenDone> {
53 let mut stream = self.try_unfold();
54 let mut last_error = None;
55 let mut routing = None;
56
57 loop {
58 match stream.try_next().await? {
59 Some(ReceivedToken::Error(error)) => {
60 if last_error.is_none() {
61 last_error = Some(error);
62 }
63 }
64 Some(ReceivedToken::Done(token)) => match (last_error, routing) {
65 (Some(error), _) => return Err(Error::Server(error)),
66 (_, Some(routing)) => return Err(routing),
67 (_, _) => return Ok(token),
68 },
69 Some(ReceivedToken::EnvChange(TokenEnvChange::Routing { host, port })) => {
70 routing = Some(Error::Routing { host, port });
71 }
72 Some(_) => (),
73 None => return Err(crate::Error::Protocol("Never got DONE token.".into())),
74 }
75 }
76 }
77
78 #[cfg(any(windows, feature = "integrated-auth-gssapi"))]
79 pub(crate) async fn flush_sspi(self) -> crate::Result<TokenSspi> {
80 let mut stream = self.try_unfold();
81 let mut last_error = None;
82
83 loop {
84 match stream.try_next().await? {
85 Some(ReceivedToken::Error(error)) => {
86 if last_error.is_none() {
87 last_error = Some(error);
88 }
89 }
90 Some(ReceivedToken::Sspi(token)) => return Ok(token),
91 Some(_) => (),
92 None => match last_error {
93 Some(err) => return Err(crate::Error::Server(err)),
94 None => return Err(crate::Error::Protocol("Never got SSPI token.".into())),
95 },
96 }
97 }
98 }
99
100 async fn get_col_metadata(&mut self) -> crate::Result<ReceivedToken> {
101 let meta = Arc::new(TokenColMetaData::decode(self.conn).await?);
102 self.conn.context_mut().set_last_meta(meta.clone());
103
104 event!(Level::TRACE, ?meta);
105
106 Ok(ReceivedToken::NewResultset(meta))
107 }
108
109 async fn get_row(&mut self) -> crate::Result<ReceivedToken> {
110 let return_value = TokenRow::decode(self.conn).await?;
111
112 event!(Level::TRACE, message = ?return_value);
113 Ok(ReceivedToken::Row(return_value))
114 }
115
116 async fn get_nbc_row(&mut self) -> crate::Result<ReceivedToken> {
117 let return_value = TokenRow::decode_nbc(self.conn).await?;
118
119 event!(Level::TRACE, message = ?return_value);
120 Ok(ReceivedToken::Row(return_value))
121 }
122
123 async fn get_return_value(&mut self) -> crate::Result<ReceivedToken> {
124 let return_value = TokenReturnValue::decode(self.conn).await?;
125 event!(Level::TRACE, message = ?return_value);
126 Ok(ReceivedToken::ReturnValue(return_value))
127 }
128
129 async fn get_return_status(&mut self) -> crate::Result<ReceivedToken> {
130 let status = self.conn.read_u32_le().await?;
131 Ok(ReceivedToken::ReturnStatus(status))
132 }
133
134 async fn get_error(&mut self) -> crate::Result<ReceivedToken> {
135 let err = TokenError::decode(self.conn).await?;
136
137 if self.last_error.is_none() {
138 self.last_error = Some(Error::Server(err.clone()));
139 }
140
141 event!(Level::ERROR, message = %err.message, code = err.code);
142 Ok(ReceivedToken::Error(err))
143 }
144
145 async fn get_order(&mut self) -> crate::Result<ReceivedToken> {
146 let order = TokenOrder::decode(self.conn).await?;
147 event!(Level::TRACE, message = ?order);
148 Ok(ReceivedToken::Order(order))
149 }
150
151 async fn get_done_value(&mut self) -> crate::Result<ReceivedToken> {
152 let done = TokenDone::decode(self.conn).await?;
153 event!(Level::TRACE, "{}", done);
154 Ok(ReceivedToken::Done(done))
155 }
156
157 async fn get_done_proc_value(&mut self) -> crate::Result<ReceivedToken> {
158 let done = TokenDone::decode(self.conn).await?;
159 event!(Level::TRACE, "{}", done);
160 Ok(ReceivedToken::DoneProc(done))
161 }
162
163 async fn get_done_in_proc_value(&mut self) -> crate::Result<ReceivedToken> {
164 let done = TokenDone::decode(self.conn).await?;
165 event!(Level::TRACE, "{}", done);
166 Ok(ReceivedToken::DoneInProc(done))
167 }
168
169 async fn get_env_change(&mut self) -> crate::Result<ReceivedToken> {
170 let change = TokenEnvChange::decode(self.conn).await?;
171
172 match change {
173 TokenEnvChange::PacketSize(new_size, _) => {
174 self.conn.context_mut().set_packet_size(new_size);
175 }
176 TokenEnvChange::BeginTransaction(desc) => {
177 self.conn.context_mut().set_transaction_descriptor(desc);
178 }
179 TokenEnvChange::CommitTransaction
180 | TokenEnvChange::RollbackTransaction
181 | TokenEnvChange::DefectTransaction => {
182 self.conn.context_mut().set_transaction_descriptor([0; 8]);
183 }
184 _ => (),
185 }
186
187 event!(Level::INFO, "{}", change);
188
189 Ok(ReceivedToken::EnvChange(change))
190 }
191
192 async fn get_info(&mut self) -> crate::Result<ReceivedToken> {
193 let info = TokenInfo::decode(self.conn).await?;
194 event!(Level::INFO, "{}", info.message);
195 Ok(ReceivedToken::Info(info))
196 }
197
198 async fn get_login_ack(&mut self) -> crate::Result<ReceivedToken> {
199 let ack = TokenLoginAck::decode(self.conn).await?;
200 event!(Level::INFO, "{} version {}", ack.prog_name, ack.version);
201 Ok(ReceivedToken::LoginAck(ack))
202 }
203
204 async fn get_feature_ext_ack(&mut self) -> crate::Result<ReceivedToken> {
205 let ack = TokenFeatureExtAck::decode(self.conn).await?;
206 event!(
207 Level::INFO,
208 "FeatureExtAck with {} features",
209 ack.features.len()
210 );
211 Ok(ReceivedToken::FeatureExtAck(ack))
212 }
213
214 async fn get_sspi(&mut self) -> crate::Result<ReceivedToken> {
215 let sspi = TokenSspi::decode_async(self.conn).await?;
216 event!(Level::TRACE, "SSPI response");
217 Ok(ReceivedToken::Sspi(sspi))
218 }
219
220 pub fn try_unfold(self) -> BoxStream<'a, crate::Result<ReceivedToken>> {
221 let stream = futures_util::stream::try_unfold(self, |mut this| async move {
222 if this.conn.is_eof() {
223 match this.last_error {
224 None => return Ok(None),
225 Some(error) => return Err(error),
226 }
227 }
228
229 let ty_byte = this.conn.read_u8().await?;
230
231 let ty = TokenType::try_from(ty_byte)
232 .map_err(|_| Error::Protocol(format!("invalid token type {:x}", ty_byte).into()))?;
233
234 let token = match ty {
235 TokenType::ReturnStatus => this.get_return_status().await?,
236 TokenType::ColMetaData => this.get_col_metadata().await?,
237 TokenType::Row => this.get_row().await?,
238 TokenType::NbcRow => this.get_nbc_row().await?,
239 TokenType::Done => this.get_done_value().await?,
240 TokenType::DoneProc => this.get_done_proc_value().await?,
241 TokenType::DoneInProc => this.get_done_in_proc_value().await?,
242 TokenType::ReturnValue => this.get_return_value().await?,
243 TokenType::Error => this.get_error().await?,
244 TokenType::Order => this.get_order().await?,
245 TokenType::EnvChange => this.get_env_change().await?,
246 TokenType::Info => this.get_info().await?,
247 TokenType::LoginAck => this.get_login_ack().await?,
248 TokenType::Sspi => this.get_sspi().await?,
249 TokenType::FeatureExtAck => this.get_feature_ext_ack().await?,
250 _ => panic!("Token {:?} unimplemented!", ty),
251 };
252
253 Ok(Some((token, this)))
254 });
255
256 Box::pin(stream)
257 }
258}