tiberius/client.rs
1mod auth;
2mod config;
3mod connection;
4
5mod tls;
6#[cfg(any(
7 feature = "rustls",
8 feature = "native-tls",
9 feature = "vendored-openssl"
10))]
11mod tls_stream;
12
13pub use auth::*;
14pub use config::*;
15pub(crate) use connection::*;
16
17use crate::tds::stream::ReceivedToken;
18use crate::{
19 result::ExecuteResult,
20 tds::{
21 codec::{self, IteratorJoin},
22 stream::{QueryStream, TokenStream},
23 },
24 BulkLoadRequest, ColumnFlag, SqlReadBytes, ToSql,
25};
26use codec::{BatchRequest, ColumnData, PacketHeader, RpcParam, RpcProcId, TokenRpcRequest};
27use enumflags2::BitFlags;
28use futures_util::io::{AsyncRead, AsyncWrite};
29use futures_util::stream::TryStreamExt;
30use std::{borrow::Cow, fmt::Debug};
31
32/// `Client` is the main entry point to the SQL Server, providing query
33/// execution capabilities.
34///
35/// A `Client` is created using the [`Config`], defining the needed
36/// connection options and capabilities.
37///
38/// # Example
39///
40/// ```no_run
41/// # use tiberius::{Config, AuthMethod};
42/// use tokio_util::compat::TokioAsyncWriteCompatExt;
43///
44/// # #[tokio::main]
45/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
46/// let mut config = Config::new();
47///
48/// config.host("0.0.0.0");
49/// config.port(1433);
50/// config.authentication(AuthMethod::sql_server("SA", "<Mys3cureP4ssW0rD>"));
51///
52/// let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
53/// tcp.set_nodelay(true)?;
54/// // Client is ready to use.
55/// let client = tiberius::Client::connect(config, tcp.compat_write()).await?;
56/// # Ok(())
57/// # }
58/// ```
59///
60/// [`Config`]: struct.Config.html
61#[derive(Debug)]
62pub struct Client<S: AsyncRead + AsyncWrite + Unpin + Send> {
63 pub(crate) connection: Connection<S>,
64}
65
66impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
67 /// Uses an instance of [`Config`] to specify the connection
68 /// options required to connect to the database using an established
69 /// tcp connection
70 ///
71 /// [`Config`]: struct.Config.html
72 pub async fn connect(config: Config, tcp_stream: S) -> crate::Result<Client<S>> {
73 Ok(Client {
74 connection: Connection::connect(config, tcp_stream).await?,
75 })
76 }
77
78 /// Executes SQL statements in the SQL Server, returning the number rows
79 /// affected. Useful for `INSERT`, `UPDATE` and `DELETE` statements. The
80 /// `query` can define the parameter placement by annotating them with
81 /// `@PN`, where N is the index of the parameter, starting from `1`. If
82 /// executing multiple queries at a time, delimit them with `;` and refer to
83 /// [`ExecuteResult`] how to get results for the separate queries.
84 ///
85 /// For mapping of Rust types when writing, see the documentation for
86 /// [`ToSql`]. For reading data from the database, see the documentation for
87 /// [`FromSql`].
88 ///
89 /// This API is not quite suitable for dynamic query parameters. In these
90 /// cases using a [`Query`] object might be easier.
91 ///
92 /// # Example
93 ///
94 /// ```no_run
95 /// # use tiberius::Config;
96 /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
97 /// # use std::env;
98 /// # #[tokio::main]
99 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
100 /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
101 /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
102 /// # );
103 /// # let config = Config::from_ado_string(&c_str)?;
104 /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
105 /// # tcp.set_nodelay(true)?;
106 /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
107 /// let results = client
108 /// .execute(
109 /// "INSERT INTO ##Test (id) VALUES (@P1), (@P2), (@P3)",
110 /// &[&1i32, &2i32, &3i32],
111 /// )
112 /// .await?;
113 /// # Ok(())
114 /// # }
115 /// ```
116 ///
117 /// [`ExecuteResult`]: struct.ExecuteResult.html
118 /// [`ToSql`]: trait.ToSql.html
119 /// [`FromSql`]: trait.FromSql.html
120 /// [`Query`]: struct.Query.html
121 pub async fn execute<'a>(
122 &mut self,
123 query: impl Into<Cow<'a, str>>,
124 params: &[&dyn ToSql],
125 ) -> crate::Result<ExecuteResult> {
126 self.connection.flush_stream().await?;
127 let rpc_params = Self::rpc_params(query);
128
129 let params = params.iter().map(|s| s.to_sql());
130 self.rpc_perform_query(RpcProcId::ExecuteSQL, rpc_params, params)
131 .await?;
132
133 ExecuteResult::new(&mut self.connection).await
134 }
135
136 /// Executes SQL statements in the SQL Server, returning resulting rows.
137 /// Useful for `SELECT` statements. The `query` can define the parameter
138 /// placement by annotating them with `@PN`, where N is the index of the
139 /// parameter, starting from `1`. If executing multiple queries at a time,
140 /// delimit them with `;` and refer to [`QueryStream`] on proper stream
141 /// handling.
142 ///
143 /// For mapping of Rust types when writing, see the documentation for
144 /// [`ToSql`]. For reading data from the database, see the documentation for
145 /// [`FromSql`].
146 ///
147 /// This API can be cumbersome for dynamic query parameters. In these cases,
148 /// if fighting too much with the compiler, using a [`Query`] object might be
149 /// easier.
150 ///
151 /// # Example
152 ///
153 /// ```
154 /// # use tiberius::Config;
155 /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
156 /// # use std::env;
157 /// # #[tokio::main]
158 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
159 /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
160 /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
161 /// # );
162 /// # let config = Config::from_ado_string(&c_str)?;
163 /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
164 /// # tcp.set_nodelay(true)?;
165 /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
166 /// let stream = client
167 /// .query(
168 /// "SELECT @P1, @P2, @P3",
169 /// &[&1i32, &2i32, &3i32],
170 /// )
171 /// .await?;
172 /// # Ok(())
173 /// # }
174 /// ```
175 ///
176 /// [`QueryStream`]: struct.QueryStream.html
177 /// [`Query`]: struct.Query.html
178 /// [`ToSql`]: trait.ToSql.html
179 /// [`FromSql`]: trait.FromSql.html
180 pub async fn query<'a, 'b>(
181 &'a mut self,
182 query: impl Into<Cow<'b, str>>,
183 params: &'b [&'b dyn ToSql],
184 ) -> crate::Result<QueryStream<'a>>
185 where
186 'a: 'b,
187 {
188 self.connection.flush_stream().await?;
189 let rpc_params = Self::rpc_params(query);
190
191 let params = params.iter().map(|p| p.to_sql());
192 self.rpc_perform_query(RpcProcId::ExecuteSQL, rpc_params, params)
193 .await?;
194
195 let ts = TokenStream::new(&mut self.connection);
196 let mut result = QueryStream::new(ts.try_unfold());
197 result.forward_to_metadata().await?;
198
199 Ok(result)
200 }
201
202 /// Execute multiple queries, delimited with `;` and return multiple result
203 /// sets; one for each query.
204 ///
205 /// # Example
206 ///
207 /// ```
208 /// # use tiberius::Config;
209 /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
210 /// # use std::env;
211 /// # #[tokio::main]
212 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
213 /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
214 /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
215 /// # );
216 /// # let config = Config::from_ado_string(&c_str)?;
217 /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
218 /// # tcp.set_nodelay(true)?;
219 /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
220 /// let row = client.simple_query("SELECT 1 AS col").await?.into_row().await?.unwrap();
221 /// assert_eq!(Some(1i32), row.get("col"));
222 /// # Ok(())
223 /// # }
224 /// ```
225 ///
226 /// # Warning
227 ///
228 /// Do not use this with any user specified input. Please resort to prepared
229 /// statements using the [`query`] method.
230 ///
231 /// [`query`]: #method.query
232 pub async fn simple_query<'a, 'b>(
233 &'a mut self,
234 query: impl Into<Cow<'b, str>>,
235 ) -> crate::Result<QueryStream<'a>>
236 where
237 'a: 'b,
238 {
239 self.connection.flush_stream().await?;
240
241 let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
242
243 let id = self.connection.context_mut().next_packet_id();
244 self.connection.send(PacketHeader::batch(id), req).await?;
245
246 let ts = TokenStream::new(&mut self.connection);
247
248 let mut result = QueryStream::new(ts.try_unfold());
249 result.forward_to_metadata().await?;
250
251 Ok(result)
252 }
253
254 /// Execute a `BULK INSERT` statement, efficiantly storing a large number of
255 /// rows to a specified table. Note: make sure the input row follows the same
256 /// schema as the table, otherwise calling `send()` will return an error.
257 ///
258 /// # Example
259 ///
260 /// ```
261 /// # use tiberius::{Config, IntoRow};
262 /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
263 /// # use std::env;
264 /// # #[tokio::main]
265 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
266 /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
267 /// # "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
268 /// # );
269 /// # let config = Config::from_ado_string(&c_str)?;
270 /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
271 /// # tcp.set_nodelay(true)?;
272 /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
273 /// let create_table = r#"
274 /// CREATE TABLE ##bulk_test (
275 /// id INT IDENTITY PRIMARY KEY,
276 /// val INT NOT NULL
277 /// )
278 /// "#;
279 ///
280 /// client.simple_query(create_table).await?;
281 ///
282 /// // Start the bulk insert with the client.
283 /// let mut req = client.bulk_insert("##bulk_test").await?;
284 ///
285 /// for i in [0i32, 1i32, 2i32] {
286 /// let row = (i).into_row();
287 ///
288 /// // The request will handle flushing to the wire in an optimal way,
289 /// // balancing between memory usage and IO performance.
290 /// req.send(row).await?;
291 /// }
292 ///
293 /// // The request must be finalized.
294 /// let res = req.finalize().await?;
295 /// assert_eq!(3, res.total());
296 /// # Ok(())
297 /// # }
298 /// ```
299 pub async fn bulk_insert<'a>(
300 &'a mut self,
301 table: &'a str,
302 ) -> crate::Result<BulkLoadRequest<'a, S>> {
303 // Start the bulk request
304 self.connection.flush_stream().await?;
305
306 // retrieve column metadata from server
307 let query = format!("SELECT TOP 0 * FROM {}", table);
308
309 let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
310
311 let id = self.connection.context_mut().next_packet_id();
312 self.connection.send(PacketHeader::batch(id), req).await?;
313
314 let token_stream = TokenStream::new(&mut self.connection).try_unfold();
315
316 let columns = token_stream
317 .try_fold(None, |mut columns, token| async move {
318 if let ReceivedToken::NewResultset(metadata) = token {
319 columns = Some(metadata.columns.clone());
320 };
321
322 Ok(columns)
323 })
324 .await?;
325
326 // now start bulk upload
327 let columns: Vec<_> = columns
328 .ok_or_else(|| {
329 crate::Error::Protocol("expecting column metadata from query but not found".into())
330 })?
331 .into_iter()
332 .filter(|column| column.base.flags.contains(ColumnFlag::Updateable))
333 .collect();
334
335 self.connection.flush_stream().await?;
336 let col_data = columns.iter().map(|c| format!("{}", c)).join(", ");
337 let query = format!("INSERT BULK {} ({})", table, col_data);
338
339 let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
340 let id = self.connection.context_mut().next_packet_id();
341
342 self.connection.send(PacketHeader::batch(id), req).await?;
343
344 let ts = TokenStream::new(&mut self.connection);
345 ts.flush_done().await?;
346
347 BulkLoadRequest::new(&mut self.connection, columns)
348 }
349
350 /// Closes this database connection explicitly.
351 pub async fn close(self) -> crate::Result<()> {
352 self.connection.close().await
353 }
354
355 pub(crate) fn rpc_params<'a>(query: impl Into<Cow<'a, str>>) -> Vec<RpcParam<'a>> {
356 vec![
357 RpcParam {
358 name: Cow::Borrowed("stmt"),
359 flags: BitFlags::empty(),
360 value: ColumnData::String(Some(query.into())),
361 },
362 RpcParam {
363 name: Cow::Borrowed("params"),
364 flags: BitFlags::empty(),
365 value: ColumnData::I32(Some(0)),
366 },
367 ]
368 }
369
370 pub(crate) async fn rpc_perform_query<'a, 'b>(
371 &'a mut self,
372 proc_id: RpcProcId,
373 mut rpc_params: Vec<RpcParam<'b>>,
374 params: impl Iterator<Item = ColumnData<'b>>,
375 ) -> crate::Result<()>
376 where
377 'a: 'b,
378 {
379 let mut param_str = String::new();
380
381 for (i, param) in params.enumerate() {
382 if i > 0 {
383 param_str.push(',')
384 }
385 param_str.push_str(&format!("@P{} ", i + 1));
386 param_str.push_str(¶m.type_name());
387
388 rpc_params.push(RpcParam {
389 name: Cow::Owned(format!("@P{}", i + 1)),
390 flags: BitFlags::empty(),
391 value: param,
392 });
393 }
394
395 if let Some(params) = rpc_params.iter_mut().find(|x| x.name == "params") {
396 params.value = ColumnData::String(Some(param_str.into()));
397 }
398
399 let req = TokenRpcRequest::new(
400 proc_id,
401 rpc_params,
402 self.connection.context().transaction_descriptor(),
403 );
404
405 let id = self.connection.context_mut().next_packet_id();
406 self.connection.send(PacketHeader::rpc(id), req).await?;
407
408 Ok(())
409 }
410}