tiberius/tds/stream/token.rs

1use crate::tds::codec::TokenSspi;
2use crate::{
3    client::Connection,
4    tds::codec::{
5        TokenColMetaData, TokenDone, TokenEnvChange, TokenError, TokenFeatureExtAck, TokenInfo,
6        TokenLoginAck, TokenOrder, TokenReturnValue, TokenRow,
7    },
8    Error, SqlReadBytes, TokenType,
9};
10use futures_util::{
11    io::{AsyncRead, AsyncWrite},
12    stream::{BoxStream, TryStreamExt},
13};
14use std::{convert::TryFrom, sync::Arc};
15use tracing::{event, Level};
16
17#[derive(Debug)]
18#[allow(dead_code)]
19pub enum ReceivedToken {
20    NewResultset(Arc<TokenColMetaData<'static>>),
21    Row(TokenRow<'static>),
22    Done(TokenDone),
23    DoneInProc(TokenDone),
24    DoneProc(TokenDone),
25    ReturnStatus(u32),
26    ReturnValue(TokenReturnValue),
27    Order(TokenOrder),
28    EnvChange(TokenEnvChange),
29    Info(TokenInfo),
30    LoginAck(TokenLoginAck),
31    Sspi(TokenSspi),
32    FeatureExtAck(TokenFeatureExtAck),
33    Error(TokenError),
34}
35
36pub(crate) struct TokenStream<'a, S: AsyncRead + AsyncWrite + Unpin + Send> {
37    conn: &'a mut Connection<S>,
38    last_error: Option<Error>,
39}
40
41impl<'a, S> TokenStream<'a, S>
42where
43    S: AsyncRead + AsyncWrite + Unpin + Send,
44{
45    pub(crate) fn new(conn: &'a mut Connection<S>) -> Self {
46        Self {
47            conn,
48            last_error: None,
49        }
50    }
51
52    pub(crate) async fn flush_done(self) -> crate::Result<TokenDone> {
53        let mut stream = self.try_unfold();
54        let mut last_error = None;
55        let mut routing = None;
56
57        loop {
58            match stream.try_next().await? {
59                Some(ReceivedToken::Error(error)) => {
60                    if last_error.is_none() {
61                        last_error = Some(error);
62                    }
63                }
64                Some(ReceivedToken::Done(token)) => match (last_error, routing) {
65                    (Some(error), _) => return Err(Error::Server(error)),
66                    (_, Some(routing)) => return Err(routing),
67                    (_, _) => return Ok(token),
68                },
69                Some(ReceivedToken::EnvChange(TokenEnvChange::Routing { host, port })) => {
70                    routing = Some(Error::Routing { host, port });
71                }
72                Some(_) => (),
73                None => return Err(crate::Error::Protocol("Never got DONE token.".into())),
74            }
75        }
76    }
77
78    #[cfg(any(windows, feature = "integrated-auth-gssapi"))]
79    pub(crate) async fn flush_sspi(self) -> crate::Result<TokenSspi> {
80        let mut stream = self.try_unfold();
81        let mut last_error = None;
82
83        loop {
84            match stream.try_next().await? {
85                Some(ReceivedToken::Error(error)) => {
86                    if last_error.is_none() {
87                        last_error = Some(error);
88                    }
89                }
90                Some(ReceivedToken::Sspi(token)) => return Ok(token),
91                Some(_) => (),
92                None => match last_error {
93                    Some(err) => return Err(crate::Error::Server(err)),
94                    None => return Err(crate::Error::Protocol("Never got SSPI token.".into())),
95                },
96            }
97        }
98    }
99
100    async fn get_col_metadata(&mut self) -> crate::Result<ReceivedToken> {
101        let meta = Arc::new(TokenColMetaData::decode(self.conn).await?);
102        self.conn.context_mut().set_last_meta(meta.clone());
103
104        event!(Level::TRACE, ?meta);
105
106        Ok(ReceivedToken::NewResultset(meta))
107    }
108
109    async fn get_row(&mut self) -> crate::Result<ReceivedToken> {
110        let return_value = TokenRow::decode(self.conn).await?;
111
112        event!(Level::TRACE, message = ?return_value);
113        Ok(ReceivedToken::Row(return_value))
114    }
115
116    async fn get_nbc_row(&mut self) -> crate::Result<ReceivedToken> {
117        let return_value = TokenRow::decode_nbc(self.conn).await?;
118
119        event!(Level::TRACE, message = ?return_value);
120        Ok(ReceivedToken::Row(return_value))
121    }
122
123    async fn get_return_value(&mut self) -> crate::Result<ReceivedToken> {
124        let return_value = TokenReturnValue::decode(self.conn).await?;
125        event!(Level::TRACE, message = ?return_value);
126        Ok(ReceivedToken::ReturnValue(return_value))
127    }
128
129    async fn get_return_status(&mut self) -> crate::Result<ReceivedToken> {
130        let status = self.conn.read_u32_le().await?;
131        Ok(ReceivedToken::ReturnStatus(status))
132    }
133
134    async fn get_error(&mut self) -> crate::Result<ReceivedToken> {
135        let err = TokenError::decode(self.conn).await?;
136
137        if self.last_error.is_none() {
138            self.last_error = Some(Error::Server(err.clone()));
139        }
140
141        event!(Level::ERROR, message = %err.message, code = err.code);
142        Ok(ReceivedToken::Error(err))
143    }
144
145    async fn get_order(&mut self) -> crate::Result<ReceivedToken> {
146        let order = TokenOrder::decode(self.conn).await?;
147        event!(Level::TRACE, message = ?order);
148        Ok(ReceivedToken::Order(order))
149    }
150
151    async fn get_done_value(&mut self) -> crate::Result<ReceivedToken> {
152        let done = TokenDone::decode(self.conn).await?;
153        event!(Level::TRACE, "{}", done);
154        Ok(ReceivedToken::Done(done))
155    }
156
157    async fn get_done_proc_value(&mut self) -> crate::Result<ReceivedToken> {
158        let done = TokenDone::decode(self.conn).await?;
159        event!(Level::TRACE, "{}", done);
160        Ok(ReceivedToken::DoneProc(done))
161    }
162
163    async fn get_done_in_proc_value(&mut self) -> crate::Result<ReceivedToken> {
164        let done = TokenDone::decode(self.conn).await?;
165        event!(Level::TRACE, "{}", done);
166        Ok(ReceivedToken::DoneInProc(done))
167    }
168
169    async fn get_env_change(&mut self) -> crate::Result<ReceivedToken> {
170        let change = TokenEnvChange::decode(self.conn).await?;
171
172        match change {
173            TokenEnvChange::PacketSize(new_size, _) => {
174                self.conn.context_mut().set_packet_size(new_size);
175            }
176            TokenEnvChange::BeginTransaction(desc) => {
177                self.conn.context_mut().set_transaction_descriptor(desc);
178            }
179            TokenEnvChange::CommitTransaction
180            | TokenEnvChange::RollbackTransaction
181            | TokenEnvChange::DefectTransaction => {
182                self.conn.context_mut().set_transaction_descriptor([0; 8]);
183            }
184            _ => (),
185        }
186
187        event!(Level::INFO, "{}", change);
188
189        Ok(ReceivedToken::EnvChange(change))
190    }
191
192    async fn get_info(&mut self) -> crate::Result<ReceivedToken> {
193        let info = TokenInfo::decode(self.conn).await?;
194        event!(Level::INFO, "{}", info.message);
195        Ok(ReceivedToken::Info(info))
196    }
197
198    async fn get_login_ack(&mut self) -> crate::Result<ReceivedToken> {
199        let ack = TokenLoginAck::decode(self.conn).await?;
200        event!(Level::INFO, "{} version {}", ack.prog_name, ack.version);
201        Ok(ReceivedToken::LoginAck(ack))
202    }
203
204    async fn get_feature_ext_ack(&mut self) -> crate::Result<ReceivedToken> {
205        let ack = TokenFeatureExtAck::decode(self.conn).await?;
206        event!(
207            Level::INFO,
208            "FeatureExtAck with {} features",
209            ack.features.len()
210        );
211        Ok(ReceivedToken::FeatureExtAck(ack))
212    }
213
214    async fn get_sspi(&mut self) -> crate::Result<ReceivedToken> {
215        let sspi = TokenSspi::decode_async(self.conn).await?;
216        event!(Level::TRACE, "SSPI response");
217        Ok(ReceivedToken::Sspi(sspi))
218    }
219
220    pub fn try_unfold(self) -> BoxStream<'a, crate::Result<ReceivedToken>> {
221        let stream = futures_util::stream::try_unfold(self, |mut this| async move {
222            if this.conn.is_eof() {
223                match this.last_error {
224                    None => return Ok(None),
225                    Some(error) => return Err(error),
226                }
227            }
228
229            let ty_byte = this.conn.read_u8().await?;
230
231            let ty = TokenType::try_from(ty_byte)
232                .map_err(|_| Error::Protocol(format!("invalid token type {:x}", ty_byte).into()))?;
233
234            let token = match ty {
235                TokenType::ReturnStatus => this.get_return_status().await?,
236                TokenType::ColMetaData => this.get_col_metadata().await?,
237                TokenType::Row => this.get_row().await?,
238                TokenType::NbcRow => this.get_nbc_row().await?,
239                TokenType::Done => this.get_done_value().await?,
240                TokenType::DoneProc => this.get_done_proc_value().await?,
241                TokenType::DoneInProc => this.get_done_in_proc_value().await?,
242                TokenType::ReturnValue => this.get_return_value().await?,
243                TokenType::Error => this.get_error().await?,
244                TokenType::Order => this.get_order().await?,
245                TokenType::EnvChange => this.get_env_change().await?,
246                TokenType::Info => this.get_info().await?,
247                TokenType::LoginAck => this.get_login_ack().await?,
248                TokenType::Sspi => this.get_sspi().await?,
249                TokenType::FeatureExtAck => this.get_feature_ext_ack().await?,
250                _ => panic!("Token {:?} unimplemented!", ty),
251            };
252
253            Ok(Some((token, this)))
254        });
255
256        Box::pin(stream)
257    }
258}