1use std::ops::Deref;
4
5use crate::{
6 error::Error,
7 ffi,
8 types::{ToSql, ToSqlOutput, ValueRef},
9 Connection, DatabaseName, Result, Row,
10};
11
/// Incrementally built SQL text used to assemble `PRAGMA` statements.
///
/// Keywords, identifiers and literals are appended through dedicated methods so that
/// caller-provided names and values are validated or quoted instead of being spliced in raw.
#[derive(Debug, Default)]
pub struct Sql {
    buf: String,
}
15
16impl Sql {
17 pub fn new() -> Self {
18 Self { buf: String::new() }
19 }
20
21 pub fn push_pragma(&mut self, schema_name: Option<DatabaseName<'_>>, pragma_name: &str) -> Result<()> {
22 self.push_keyword("PRAGMA")?;
23 self.push_space();
24 if let Some(schema_name) = schema_name {
25 self.push_schema_name(schema_name);
26 self.push_dot();
27 }
28 self.push_keyword(pragma_name)
29 }
30
31 pub fn push_keyword(&mut self, keyword: &str) -> Result<()> {
32 if !keyword.is_empty() && is_identifier(keyword) {
33 self.buf.push_str(keyword);
34 Ok(())
35 } else {
36 Err(Error::DuckDBFailure(
37 ffi::Error::new(ffi::DuckDBError),
38 Some(format!("Invalid keyword \"{keyword}\"")),
39 ))
40 }
41 }
42
43 pub fn push_schema_name(&mut self, schema_name: DatabaseName<'_>) {
44 match schema_name {
45 DatabaseName::Main => self.buf.push_str("main"),
46 DatabaseName::Temp => self.buf.push_str("temp"),
47 DatabaseName::Attached(s) => self.push_identifier(s),
48 };
49 }
50
51 pub fn push_identifier(&mut self, s: &str) {
52 if is_identifier(s) {
53 self.buf.push_str(s);
54 } else {
55 self.wrap_and_escape(s, '"');
56 }
57 }
58
59 pub fn push_value(&mut self, value: &dyn ToSql) -> Result<()> {
60 let value = value.to_sql()?;
61 let value = match value {
62 ToSqlOutput::Borrowed(v) => v,
63 ToSqlOutput::Owned(ref v) => ValueRef::from(v),
64 };
65 match value {
66 ValueRef::BigInt(i) => {
67 self.push_int(i);
68 }
69 ValueRef::Double(r) => {
70 self.push_real(r);
71 }
72 ValueRef::Text(s) => {
73 let s = std::str::from_utf8(s)?;
74 self.push_string_literal(s);
75 }
76 _ => {
77 return Err(Error::DuckDBFailure(
78 ffi::Error::new(ffi::DuckDBError),
79 Some(format!("Unsupported value \"{value:?}\"")),
80 ));
81 }
82 };
83 Ok(())
84 }
85
86 pub fn push_string_literal(&mut self, s: &str) {
87 self.wrap_and_escape(s, '\'');
88 }
89
90 pub fn push_int(&mut self, i: i64) {
91 self.buf.push_str(&i.to_string());
92 }
93
94 pub fn push_real(&mut self, f: f64) {
95 self.buf.push_str(&f.to_string());
96 }
97
98 pub fn push_space(&mut self) {
99 self.buf.push(' ');
100 }
101
102 pub fn push_dot(&mut self) {
103 self.buf.push('.');
104 }
105
106 pub fn push_equal_sign(&mut self) {
107 self.buf.push('=');
108 }
109
110 pub fn open_brace(&mut self) {
111 self.buf.push('(');
112 }
113
114 pub fn close_brace(&mut self) {
115 self.buf.push(')');
116 }
117
118 pub fn as_str(&self) -> &str {
119 &self.buf
120 }
121
122 fn wrap_and_escape(&mut self, s: &str, quote: char) {
123 self.buf.push(quote);
124 let chars = s.chars();
125 for ch in chars {
126 if ch == quote {
128 self.buf.push(ch);
129 }
130 self.buf.push(ch)
131 }
132 self.buf.push(quote);
133 }
134}
135
136impl Deref for Sql {
137 type Target = str;
138
139 fn deref(&self) -> &str {
140 self.as_str()
141 }
142}
143
144impl Connection {
145 pub fn pragma_query_value<T, F>(&self, schema_name: Option<DatabaseName<'_>>, pragma_name: &str, f: F) -> Result<T>
150 where
151 F: FnOnce(&Row<'_>) -> Result<T>,
152 {
153 let mut query = Sql::new();
154 query.push_pragma(schema_name, pragma_name)?;
155 self.query_row(&query, [], f)
156 }
157
158 pub fn pragma_query<F>(&self, schema_name: Option<DatabaseName<'_>>, pragma_name: &str, mut f: F) -> Result<()>
160 where
161 F: FnMut(&Row<'_>) -> Result<()>,
162 {
163 let mut query = Sql::new();
164 query.push_pragma(schema_name, pragma_name)?;
165 let mut stmt = self.prepare(&query)?;
166 let mut rows = stmt.query([])?;
167 while let Some(result_row) = rows.next()? {
168 f(result_row)?;
169 }
170 Ok(())
171 }
172
173 pub fn pragma<F>(
179 &self,
180 schema_name: Option<DatabaseName<'_>>,
181 pragma_name: &str,
182 pragma_value: &dyn ToSql,
183 mut f: F,
184 ) -> Result<()>
185 where
186 F: FnMut(&Row<'_>) -> Result<()>,
187 {
188 let mut sql = Sql::new();
189 sql.push_pragma(schema_name, pragma_name)?;
190 sql.open_brace();
194 sql.push_value(pragma_value)?;
195 sql.close_brace();
196 let mut stmt = self.prepare(&sql)?;
197 let mut rows = stmt.query([])?;
198 while let Some(result_row) = rows.next()? {
199 let row = result_row;
200 f(row)?;
201 }
202 Ok(())
203 }
204
205 pub fn pragma_update(
210 &self,
211 schema_name: Option<DatabaseName<'_>>,
212 pragma_name: &str,
213 pragma_value: &dyn ToSql,
214 ) -> Result<()> {
215 let mut sql = Sql::new();
216 sql.push_pragma(schema_name, pragma_name)?;
217 sql.push_equal_sign();
221 sql.push_value(pragma_value)?;
222 self.execute_batch(&sql)
223 }
224
225 pub fn pragma_update_and_check<F, T>(
229 &self,
230 schema_name: Option<DatabaseName<'_>>,
231 pragma_name: &str,
232 pragma_value: &dyn ToSql,
233 f: F,
234 ) -> Result<T>
235 where
236 F: FnOnce(&Row<'_>) -> Result<T>,
237 {
238 let mut sql = Sql::new();
239 sql.push_pragma(schema_name, pragma_name)?;
240 sql.push_equal_sign();
244 sql.push_value(pragma_value)?;
245 self.query_row(&sql, [], f)
246 }
247}
248
249fn is_identifier(s: &str) -> bool {
250 let chars = s.char_indices();
251 for (i, ch) in chars {
252 if i == 0 {
253 if !is_identifier_start(ch) {
254 return false;
255 }
256 } else if !is_identifier_continue(ch) {
257 return false;
258 }
259 }
260 true
261}
262
/// Whether `c` may begin an unquoted identifier.
///
/// Accepted: an ASCII letter, `_`, or any non-ASCII character.
fn is_identifier_start(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '_') || !c.is_ascii()
}
266
/// Whether `c` may appear after the first character of an unquoted identifier.
///
/// Accepted: an ASCII letter or digit, `_`, `$`, or any non-ASCII character.
fn is_identifier_continue(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '$') || !c.is_ascii()
}
270
#[cfg(test)]
mod test {
    use super::Sql;
    use crate::{pragma, Connection, DatabaseName, Result};

    // Reading a scalar pragma with `pragma_query_value` yields a non-empty version string.
    #[test]
    fn pragma_query_value() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let version: String = db.pragma_query_value(None, "version", |row| row.get(0))?;
        assert!(!version.is_empty());
        Ok(())
    }

    // A schema-qualified pragma (`PRAGMA main.version`) is rejected by DuckDB's parser at the
    // `.`; this pins that behaviour so a change in it is noticed.
    #[test]
    fn pragma_query_with_schema() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let res = db.pragma_query(Some(DatabaseName::Main), "version", |_| Ok(()));
        assert_eq!(
            res.unwrap_err().to_string().lines().next().unwrap(),
            "Parser Error: syntax error at or near \".\""
        );
        Ok(())
    }

    // Call form: this issues `PRAGMA table_info('sqlite_master')`. The second field of each
    // row is collected (presumably the column name, per the variable name), and the row count
    // is pinned at 5.
    #[test]
    fn pragma() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let mut columns = Vec::new();
        db.pragma(None, "table_info", &"sqlite_master", |row| {
            let column: String = row.get(1)?;
            columns.push(column);
            Ok(())
        })?;
        assert_eq!(5, columns.len());
        Ok(())
    }

    // Assignment form with a text value: `PRAGMA explain_output='PHYSICAL_ONLY'` must succeed.
    #[test]
    fn pragma_update() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.pragma_update(None, "explain_output", &"PHYSICAL_ONLY")
    }

    // The assignment form produces no result row, so `query_row` reports `QueryReturnedNoRows`.
    #[test]
    fn test_pragma_update_and_check() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let res = db.pragma_update_and_check(None, "explain_output", &"OPTIMIZED_ONLY", |_| Ok(()));
        assert_eq!(res.unwrap_err(), crate::Error::QueryReturnedNoRows);
        Ok(())
    }

    // Words and alphanumerics are identifiers; spaces and `;` are not.
    #[test]
    fn is_identifier() {
        assert!(pragma::is_identifier("full"));
        assert!(pragma::is_identifier("r2d2"));
        assert!(!pragma::is_identifier("sp ce"));
        assert!(!pragma::is_identifier("semi;colon"));
    }

    // A non-identifier attached schema name is double-quoted with the embedded `"` doubled,
    // which neutralises the `";--` injection attempt.
    #[test]
    fn double_quote() {
        let mut sql = Sql::new();
        sql.push_schema_name(DatabaseName::Attached(r#"schema";--"#));
        assert_eq!(r#""schema"";--""#, sql.as_str());
    }

    // String literals are single-quoted with the embedded `'` doubled.
    #[test]
    fn wrap_and_escape() {
        let mut sql = Sql::new();
        sql.push_string_literal("value'; --");
        assert_eq!("'value''; --'", sql.as_str());
    }
}