tiberius/
client.rs

1mod auth;
2mod config;
3mod connection;
4
5mod tls;
6#[cfg(any(
7    feature = "rustls",
8    feature = "native-tls",
9    feature = "vendored-openssl"
10))]
11mod tls_stream;
12
13pub use auth::*;
14pub use config::*;
15pub(crate) use connection::*;
16
17use crate::tds::stream::ReceivedToken;
18use crate::{
19    result::ExecuteResult,
20    tds::{
21        codec::{self, IteratorJoin},
22        stream::{QueryStream, TokenStream},
23    },
24    BulkLoadRequest, ColumnFlag, SqlReadBytes, ToSql,
25};
26use codec::{BatchRequest, ColumnData, PacketHeader, RpcParam, RpcProcId, TokenRpcRequest};
27use enumflags2::BitFlags;
28use futures_util::io::{AsyncRead, AsyncWrite};
29use futures_util::stream::TryStreamExt;
30use std::{borrow::Cow, fmt::Debug};
31
32/// `Client` is the main entry point to the SQL Server, providing query
33/// execution capabilities.
34///
35/// A `Client` is created using the [`Config`], defining the needed
36/// connection options and capabilities.
37///
38/// # Example
39///
40/// ```no_run
41/// # use tiberius::{Config, AuthMethod};
42/// use tokio_util::compat::TokioAsyncWriteCompatExt;
43///
44/// # #[tokio::main]
45/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
46/// let mut config = Config::new();
47///
48/// config.host("0.0.0.0");
49/// config.port(1433);
50/// config.authentication(AuthMethod::sql_server("SA", "<Mys3cureP4ssW0rD>"));
51///
52/// let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
53/// tcp.set_nodelay(true)?;
54/// // Client is ready to use.
55/// let client = tiberius::Client::connect(config, tcp.compat_write()).await?;
56/// # Ok(())
57/// # }
58/// ```
59///
60/// [`Config`]: struct.Config.html
61#[derive(Debug)]
62pub struct Client<S: AsyncRead + AsyncWrite + Unpin + Send> {
63    pub(crate) connection: Connection<S>,
64}
65
66impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
67    /// Uses an instance of [`Config`] to specify the connection
68    /// options required to connect to the database using an established
69    /// tcp connection
70    ///
71    /// [`Config`]: struct.Config.html
72    pub async fn connect(config: Config, tcp_stream: S) -> crate::Result<Client<S>> {
73        Ok(Client {
74            connection: Connection::connect(config, tcp_stream).await?,
75        })
76    }
77
78    /// Executes SQL statements in the SQL Server, returning the number rows
79    /// affected. Useful for `INSERT`, `UPDATE` and `DELETE` statements. The
80    /// `query` can define the parameter placement by annotating them with
81    /// `@PN`, where N is the index of the parameter, starting from `1`. If
82    /// executing multiple queries at a time, delimit them with `;` and refer to
83    /// [`ExecuteResult`] how to get results for the separate queries.
84    ///
85    /// For mapping of Rust types when writing, see the documentation for
86    /// [`ToSql`]. For reading data from the database, see the documentation for
87    /// [`FromSql`].
88    ///
89    /// This API is not quite suitable for dynamic query parameters. In these
90    /// cases using a [`Query`] object might be easier.
91    ///
92    /// # Example
93    ///
94    /// ```no_run
95    /// # use tiberius::Config;
96    /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
97    /// # use std::env;
98    /// # #[tokio::main]
99    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
100    /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
101    /// #     "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
102    /// # );
103    /// # let config = Config::from_ado_string(&c_str)?;
104    /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
105    /// # tcp.set_nodelay(true)?;
106    /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
107    /// let results = client
108    ///     .execute(
109    ///         "INSERT INTO ##Test (id) VALUES (@P1), (@P2), (@P3)",
110    ///         &[&1i32, &2i32, &3i32],
111    ///     )
112    ///     .await?;
113    /// # Ok(())
114    /// # }
115    /// ```
116    ///
117    /// [`ExecuteResult`]: struct.ExecuteResult.html
118    /// [`ToSql`]: trait.ToSql.html
119    /// [`FromSql`]: trait.FromSql.html
120    /// [`Query`]: struct.Query.html
121    pub async fn execute<'a>(
122        &mut self,
123        query: impl Into<Cow<'a, str>>,
124        params: &[&dyn ToSql],
125    ) -> crate::Result<ExecuteResult> {
126        self.connection.flush_stream().await?;
127        let rpc_params = Self::rpc_params(query);
128
129        let params = params.iter().map(|s| s.to_sql());
130        self.rpc_perform_query(RpcProcId::ExecuteSQL, rpc_params, params)
131            .await?;
132
133        ExecuteResult::new(&mut self.connection).await
134    }
135
136    /// Executes SQL statements in the SQL Server, returning resulting rows.
137    /// Useful for `SELECT` statements. The `query` can define the parameter
138    /// placement by annotating them with `@PN`, where N is the index of the
139    /// parameter, starting from `1`. If executing multiple queries at a time,
140    /// delimit them with `;` and refer to [`QueryStream`] on proper stream
141    /// handling.
142    ///
143    /// For mapping of Rust types when writing, see the documentation for
144    /// [`ToSql`]. For reading data from the database, see the documentation for
145    /// [`FromSql`].
146    ///
147    /// This API can be cumbersome for dynamic query parameters. In these cases,
148    /// if fighting too much with the compiler, using a [`Query`] object might be
149    /// easier.
150    ///
151    /// # Example
152    ///
153    /// ```
154    /// # use tiberius::Config;
155    /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
156    /// # use std::env;
157    /// # #[tokio::main]
158    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
159    /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
160    /// #     "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
161    /// # );
162    /// # let config = Config::from_ado_string(&c_str)?;
163    /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
164    /// # tcp.set_nodelay(true)?;
165    /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
166    /// let stream = client
167    ///     .query(
168    ///         "SELECT @P1, @P2, @P3",
169    ///         &[&1i32, &2i32, &3i32],
170    ///     )
171    ///     .await?;
172    /// # Ok(())
173    /// # }
174    /// ```
175    ///
176    /// [`QueryStream`]: struct.QueryStream.html
177    /// [`Query`]: struct.Query.html
178    /// [`ToSql`]: trait.ToSql.html
179    /// [`FromSql`]: trait.FromSql.html
180    pub async fn query<'a, 'b>(
181        &'a mut self,
182        query: impl Into<Cow<'b, str>>,
183        params: &'b [&'b dyn ToSql],
184    ) -> crate::Result<QueryStream<'a>>
185    where
186        'a: 'b,
187    {
188        self.connection.flush_stream().await?;
189        let rpc_params = Self::rpc_params(query);
190
191        let params = params.iter().map(|p| p.to_sql());
192        self.rpc_perform_query(RpcProcId::ExecuteSQL, rpc_params, params)
193            .await?;
194
195        let ts = TokenStream::new(&mut self.connection);
196        let mut result = QueryStream::new(ts.try_unfold());
197        result.forward_to_metadata().await?;
198
199        Ok(result)
200    }
201
202    /// Execute multiple queries, delimited with `;` and return multiple result
203    /// sets; one for each query.
204    ///
205    /// # Example
206    ///
207    /// ```
208    /// # use tiberius::Config;
209    /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
210    /// # use std::env;
211    /// # #[tokio::main]
212    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
213    /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
214    /// #     "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
215    /// # );
216    /// # let config = Config::from_ado_string(&c_str)?;
217    /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
218    /// # tcp.set_nodelay(true)?;
219    /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
220    /// let row = client.simple_query("SELECT 1 AS col").await?.into_row().await?.unwrap();
221    /// assert_eq!(Some(1i32), row.get("col"));
222    /// # Ok(())
223    /// # }
224    /// ```
225    ///
226    /// # Warning
227    ///
228    /// Do not use this with any user specified input. Please resort to prepared
229    /// statements using the [`query`] method.
230    ///
231    /// [`query`]: #method.query
232    pub async fn simple_query<'a, 'b>(
233        &'a mut self,
234        query: impl Into<Cow<'b, str>>,
235    ) -> crate::Result<QueryStream<'a>>
236    where
237        'a: 'b,
238    {
239        self.connection.flush_stream().await?;
240
241        let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
242
243        let id = self.connection.context_mut().next_packet_id();
244        self.connection.send(PacketHeader::batch(id), req).await?;
245
246        let ts = TokenStream::new(&mut self.connection);
247
248        let mut result = QueryStream::new(ts.try_unfold());
249        result.forward_to_metadata().await?;
250
251        Ok(result)
252    }
253
254    /// Execute a `BULK INSERT` statement, efficiantly storing a large number of
255    /// rows to a specified table. Note: make sure the input row follows the same
256    /// schema as the table, otherwise calling `send()` will return an error.
257    ///
258    /// # Example
259    ///
260    /// ```
261    /// # use tiberius::{Config, IntoRow};
262    /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
263    /// # use std::env;
264    /// # #[tokio::main]
265    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
266    /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
267    /// #     "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
268    /// # );
269    /// # let config = Config::from_ado_string(&c_str)?;
270    /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
271    /// # tcp.set_nodelay(true)?;
272    /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
273    /// let create_table = r#"
274    ///     CREATE TABLE ##bulk_test (
275    ///         id INT IDENTITY PRIMARY KEY,
276    ///         val INT NOT NULL
277    ///     )
278    /// "#;
279    ///
280    /// client.simple_query(create_table).await?;
281    ///
282    /// // Start the bulk insert with the client.
283    /// let mut req = client.bulk_insert("##bulk_test").await?;
284    ///
285    /// for i in [0i32, 1i32, 2i32] {
286    ///     let row = (i).into_row();
287    ///
288    ///     // The request will handle flushing to the wire in an optimal way,
289    ///     // balancing between memory usage and IO performance.
290    ///     req.send(row).await?;
291    /// }
292    ///
293    /// // The request must be finalized.
294    /// let res = req.finalize().await?;
295    /// assert_eq!(3, res.total());
296    /// # Ok(())
297    /// # }
298    /// ```
299    pub async fn bulk_insert<'a>(
300        &'a mut self,
301        table: &'a str,
302    ) -> crate::Result<BulkLoadRequest<'a, S>> {
303        // Start the bulk request
304        self.connection.flush_stream().await?;
305
306        // retrieve column metadata from server
307        let query = format!("SELECT TOP 0 * FROM {}", table);
308
309        let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
310
311        let id = self.connection.context_mut().next_packet_id();
312        self.connection.send(PacketHeader::batch(id), req).await?;
313
314        let token_stream = TokenStream::new(&mut self.connection).try_unfold();
315
316        let columns = token_stream
317            .try_fold(None, |mut columns, token| async move {
318                if let ReceivedToken::NewResultset(metadata) = token {
319                    columns = Some(metadata.columns.clone());
320                };
321
322                Ok(columns)
323            })
324            .await?;
325
326        // now start bulk upload
327        let columns: Vec<_> = columns
328            .ok_or_else(|| {
329                crate::Error::Protocol("expecting column metadata from query but not found".into())
330            })?
331            .into_iter()
332            .filter(|column| column.base.flags.contains(ColumnFlag::Updateable))
333            .collect();
334
335        self.connection.flush_stream().await?;
336        let col_data = columns.iter().map(|c| format!("{}", c)).join(", ");
337        let query = format!("INSERT BULK {} ({})", table, col_data);
338
339        let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
340        let id = self.connection.context_mut().next_packet_id();
341
342        self.connection.send(PacketHeader::batch(id), req).await?;
343
344        let ts = TokenStream::new(&mut self.connection);
345        ts.flush_done().await?;
346
347        BulkLoadRequest::new(&mut self.connection, columns)
348    }
349
350    /// Closes this database connection explicitly.
351    pub async fn close(self) -> crate::Result<()> {
352        self.connection.close().await
353    }
354
355    pub(crate) fn rpc_params<'a>(query: impl Into<Cow<'a, str>>) -> Vec<RpcParam<'a>> {
356        vec![
357            RpcParam {
358                name: Cow::Borrowed("stmt"),
359                flags: BitFlags::empty(),
360                value: ColumnData::String(Some(query.into())),
361            },
362            RpcParam {
363                name: Cow::Borrowed("params"),
364                flags: BitFlags::empty(),
365                value: ColumnData::I32(Some(0)),
366            },
367        ]
368    }
369
370    pub(crate) async fn rpc_perform_query<'a, 'b>(
371        &'a mut self,
372        proc_id: RpcProcId,
373        mut rpc_params: Vec<RpcParam<'b>>,
374        params: impl Iterator<Item = ColumnData<'b>>,
375    ) -> crate::Result<()>
376    where
377        'a: 'b,
378    {
379        let mut param_str = String::new();
380
381        for (i, param) in params.enumerate() {
382            if i > 0 {
383                param_str.push(',')
384            }
385            param_str.push_str(&format!("@P{} ", i + 1));
386            param_str.push_str(&param.type_name());
387
388            rpc_params.push(RpcParam {
389                name: Cow::Owned(format!("@P{}", i + 1)),
390                flags: BitFlags::empty(),
391                value: param,
392            });
393        }
394
395        if let Some(params) = rpc_params.iter_mut().find(|x| x.name == "params") {
396            params.value = ColumnData::String(Some(param_str.into()));
397        }
398
399        let req = TokenRpcRequest::new(
400            proc_id,
401            rpc_params,
402            self.connection.context().transaction_descriptor(),
403        );
404
405        let id = self.connection.context_mut().next_packet_id();
406        self.connection.send(PacketHeader::rpc(id), req).await?;
407
408        Ok(())
409    }
410}