duckdb/
pragma.rs

1//! Pragma helpers
2
3use std::ops::Deref;
4
5use crate::{
6    error::Error,
7    ffi,
8    types::{ToSql, ToSqlOutput, ValueRef},
9    Connection, DatabaseName, Result, Row,
10};
11
/// Incremental builder for the text of a `PRAGMA` statement.
///
/// The statement is accumulated in a single owned string. The `push_*`
/// methods append keywords, identifiers, literals and punctuation to it, and
/// `as_str` exposes the text built so far.
///
/// `Default` yields the same empty builder as `Sql::new()`.
#[derive(Debug, Default)]
pub struct Sql {
    /// SQL text accumulated so far.
    buf: String,
}
15
16impl Sql {
17    pub fn new() -> Self {
18        Self { buf: String::new() }
19    }
20
21    pub fn push_pragma(&mut self, schema_name: Option<DatabaseName<'_>>, pragma_name: &str) -> Result<()> {
22        self.push_keyword("PRAGMA")?;
23        self.push_space();
24        if let Some(schema_name) = schema_name {
25            self.push_schema_name(schema_name);
26            self.push_dot();
27        }
28        self.push_keyword(pragma_name)
29    }
30
31    pub fn push_keyword(&mut self, keyword: &str) -> Result<()> {
32        if !keyword.is_empty() && is_identifier(keyword) {
33            self.buf.push_str(keyword);
34            Ok(())
35        } else {
36            Err(Error::DuckDBFailure(
37                ffi::Error::new(ffi::DuckDBError),
38                Some(format!("Invalid keyword \"{keyword}\"")),
39            ))
40        }
41    }
42
43    pub fn push_schema_name(&mut self, schema_name: DatabaseName<'_>) {
44        match schema_name {
45            DatabaseName::Main => self.buf.push_str("main"),
46            DatabaseName::Temp => self.buf.push_str("temp"),
47            DatabaseName::Attached(s) => self.push_identifier(s),
48        };
49    }
50
51    pub fn push_identifier(&mut self, s: &str) {
52        if is_identifier(s) {
53            self.buf.push_str(s);
54        } else {
55            self.wrap_and_escape(s, '"');
56        }
57    }
58
59    pub fn push_value(&mut self, value: &dyn ToSql) -> Result<()> {
60        let value = value.to_sql()?;
61        let value = match value {
62            ToSqlOutput::Borrowed(v) => v,
63            ToSqlOutput::Owned(ref v) => ValueRef::from(v),
64        };
65        match value {
66            ValueRef::BigInt(i) => {
67                self.push_int(i);
68            }
69            ValueRef::Double(r) => {
70                self.push_real(r);
71            }
72            ValueRef::Text(s) => {
73                let s = std::str::from_utf8(s)?;
74                self.push_string_literal(s);
75            }
76            _ => {
77                return Err(Error::DuckDBFailure(
78                    ffi::Error::new(ffi::DuckDBError),
79                    Some(format!("Unsupported value \"{value:?}\"")),
80                ));
81            }
82        };
83        Ok(())
84    }
85
86    pub fn push_string_literal(&mut self, s: &str) {
87        self.wrap_and_escape(s, '\'');
88    }
89
90    pub fn push_int(&mut self, i: i64) {
91        self.buf.push_str(&i.to_string());
92    }
93
94    pub fn push_real(&mut self, f: f64) {
95        self.buf.push_str(&f.to_string());
96    }
97
98    pub fn push_space(&mut self) {
99        self.buf.push(' ');
100    }
101
102    pub fn push_dot(&mut self) {
103        self.buf.push('.');
104    }
105
106    pub fn push_equal_sign(&mut self) {
107        self.buf.push('=');
108    }
109
110    pub fn open_brace(&mut self) {
111        self.buf.push('(');
112    }
113
114    pub fn close_brace(&mut self) {
115        self.buf.push(')');
116    }
117
118    pub fn as_str(&self) -> &str {
119        &self.buf
120    }
121
122    fn wrap_and_escape(&mut self, s: &str, quote: char) {
123        self.buf.push(quote);
124        let chars = s.chars();
125        for ch in chars {
126            // escape `quote` by doubling it
127            if ch == quote {
128                self.buf.push(ch);
129            }
130            self.buf.push(ch)
131        }
132        self.buf.push(quote);
133    }
134}
135
136impl Deref for Sql {
137    type Target = str;
138
139    fn deref(&self) -> &str {
140        self.as_str()
141    }
142}
143
144impl Connection {
145    /// Query the current value of `pragma_name`.
146    ///
147    /// Some pragmas will return multiple rows/values which cannot be retrieved
148    /// with this method.
149    pub fn pragma_query_value<T, F>(&self, schema_name: Option<DatabaseName<'_>>, pragma_name: &str, f: F) -> Result<T>
150    where
151        F: FnOnce(&Row<'_>) -> Result<T>,
152    {
153        let mut query = Sql::new();
154        query.push_pragma(schema_name, pragma_name)?;
155        self.query_row(&query, [], f)
156    }
157
158    /// Query the current rows/values of `pragma_name`.
159    pub fn pragma_query<F>(&self, schema_name: Option<DatabaseName<'_>>, pragma_name: &str, mut f: F) -> Result<()>
160    where
161        F: FnMut(&Row<'_>) -> Result<()>,
162    {
163        let mut query = Sql::new();
164        query.push_pragma(schema_name, pragma_name)?;
165        let mut stmt = self.prepare(&query)?;
166        let mut rows = stmt.query([])?;
167        while let Some(result_row) = rows.next()? {
168            f(result_row)?;
169        }
170        Ok(())
171    }
172
173    /// Query the current value(s) of `pragma_name` associated to `pragma_value`.
174    ///
175    /// This method can be used with query-only pragmas which need an argument
176    /// (e.g. `table_info('one_tbl')`) or pragmas which returns value(s)
177    /// (e.g. `integrity_check`).
178    pub fn pragma<F>(
179        &self,
180        schema_name: Option<DatabaseName<'_>>,
181        pragma_name: &str,
182        pragma_value: &dyn ToSql,
183        mut f: F,
184    ) -> Result<()>
185    where
186        F: FnMut(&Row<'_>) -> Result<()>,
187    {
188        let mut sql = Sql::new();
189        sql.push_pragma(schema_name, pragma_name)?;
190        // The argument may be either in parentheses
191        // or it may be separated from the pragma name by an equal sign.
192        // The two syntaxes yield identical results.
193        sql.open_brace();
194        sql.push_value(pragma_value)?;
195        sql.close_brace();
196        let mut stmt = self.prepare(&sql)?;
197        let mut rows = stmt.query([])?;
198        while let Some(result_row) = rows.next()? {
199            let row = result_row;
200            f(row)?;
201        }
202        Ok(())
203    }
204
205    /// Set a new value to `pragma_name`.
206    ///
207    /// Some pragmas will return the updated value which cannot be retrieved
208    /// with this method.
209    pub fn pragma_update(
210        &self,
211        schema_name: Option<DatabaseName<'_>>,
212        pragma_name: &str,
213        pragma_value: &dyn ToSql,
214    ) -> Result<()> {
215        let mut sql = Sql::new();
216        sql.push_pragma(schema_name, pragma_name)?;
217        // The argument may be either in parentheses
218        // or it may be separated from the pragma name by an equal sign.
219        // The two syntaxes yield identical results.
220        sql.push_equal_sign();
221        sql.push_value(pragma_value)?;
222        self.execute_batch(&sql)
223    }
224
225    /// Set a new value to `pragma_name` and return the updated value.
226    ///
227    /// Only few pragmas automatically return the updated value.
228    pub fn pragma_update_and_check<F, T>(
229        &self,
230        schema_name: Option<DatabaseName<'_>>,
231        pragma_name: &str,
232        pragma_value: &dyn ToSql,
233        f: F,
234    ) -> Result<T>
235    where
236        F: FnOnce(&Row<'_>) -> Result<T>,
237    {
238        let mut sql = Sql::new();
239        sql.push_pragma(schema_name, pragma_name)?;
240        // The argument may be either in parentheses
241        // or it may be separated from the pragma name by an equal sign.
242        // The two syntaxes yield identical results.
243        sql.push_equal_sign();
244        sql.push_value(pragma_value)?;
245        self.query_row(&sql, [], f)
246    }
247}
248
249fn is_identifier(s: &str) -> bool {
250    let chars = s.char_indices();
251    for (i, ch) in chars {
252        if i == 0 {
253            if !is_identifier_start(ch) {
254                return false;
255            }
256        } else if !is_identifier_continue(ch) {
257            return false;
258        }
259    }
260    true
261}
262
/// Reports whether `c` may begin a bare identifier: an ASCII letter, an
/// underscore, or any character outside the ASCII range.
fn is_identifier_start(c: char) -> bool {
    matches!(c, 'A'..='Z' | 'a'..='z' | '_') || !c.is_ascii()
}
266
/// Reports whether `c` may appear after the first character of a bare
/// identifier: an ASCII letter or digit, `_`, `$`, or any character outside
/// the ASCII range.
fn is_identifier_continue(c: char) -> bool {
    matches!(c, '$' | '_' | '0'..='9' | 'A'..='Z' | 'a'..='z') || !c.is_ascii()
}
270
#[cfg(test)]
mod test {
    use super::Sql;
    use crate::{pragma, Connection, DatabaseName, Result};

    // A value-returning pragma yields a non-empty string.
    #[test]
    fn pragma_query_value() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let reported: String = conn.pragma_query_value(None, "version", |row| row.get(0))?;
        assert!(!reported.is_empty());
        Ok(())
    }

    // A schema-qualified pragma name is rejected by the parser at the dot.
    #[test]
    fn pragma_query_with_schema() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let outcome = conn.pragma_query(Some(DatabaseName::Main), "version", |_| Ok(()));
        let message = outcome.unwrap_err().to_string();
        let first_line = message.lines().next().unwrap();
        assert_eq!(first_line, "Parser Error: syntax error at or near \".\"");
        Ok(())
    }

    // A pragma taking an argument invokes the callback once per row.
    #[test]
    fn pragma() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut names: Vec<String> = Vec::new();
        conn.pragma(None, "table_info", &"sqlite_master", |row| {
            let name: String = row.get(1)?;
            names.push(name);
            Ok(())
        })?;
        assert_eq!(names.len(), 5);
        Ok(())
    }

    // Updating a pragma with a text value succeeds.
    #[test]
    fn pragma_update() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.pragma_update(None, "explain_output", &"PHYSICAL_ONLY")?;
        Ok(())
    }

    // This pragma returns no row after an update, so the check reports
    // `QueryReturnedNoRows`.
    #[test]
    fn test_pragma_update_and_check() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let outcome = conn.pragma_update_and_check(None, "explain_output", &"OPTIMIZED_ONLY", |_| Ok(()));
        assert_eq!(outcome.unwrap_err(), crate::Error::QueryReturnedNoRows);
        Ok(())
    }

    // Plain names are identifiers; whitespace and punctuation are not.
    #[test]
    fn is_identifier() {
        for accepted in ["full", "r2d2"] {
            assert!(pragma::is_identifier(accepted));
        }
        for rejected in ["sp ce", "semi;colon"] {
            assert!(!pragma::is_identifier(rejected));
        }
    }

    // A non-identifier schema name is double-quoted with inner quotes doubled.
    #[test]
    fn double_quote() {
        let mut builder = Sql::new();
        builder.push_schema_name(DatabaseName::Attached("schema\";--"));
        assert_eq!(builder.as_str(), "\"schema\"\";--\"");
    }

    // A string literal is single-quoted with inner quotes doubled.
    #[test]
    fn wrap_and_escape() {
        let mut builder = Sql::new();
        builder.push_string_literal("value'; --");
        assert_eq!(builder.as_str(), "'value''; --'");
    }
}