Skip to main content

mz_ore/
sql.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License in the LICENSE file at the
6// root of this repository, or online at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16//! Composable, escape-aware SQL fragment building.
17//!
18//! Use [`crate::sql!`]/[`Sql::new`] for trusted SQL text and
19//! [`Sql::ident`]/[`Sql::literal`] to escape values from untrusted sources.
20//! The escape rules are PostgreSQL's.
21
22use std::borrow::Cow;
23use std::fmt::{self, Write};
24
/// A fragment of SQL text that can be combined with other fragments.
///
/// The text is held either as a borrowed `'static` string or as an owned
/// `String`. Static SQL comes from [`crate::sql!`]; dynamic values are wrapped
/// with [`Sql::ident`] or [`Sql::literal`]. This mirrors psycopg's split
/// between trusted SQL text and escaped values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Sql(Cow<'static, str>);
32
/// Reasons [`sql_template_placeholder_count`] rejects a template.
#[doc(hidden)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlTemplateError {
    /// A `{` that starts neither a `{}` placeholder nor a `{{` escape.
    InvalidOpenBrace,
    /// A `}` that is not the first half of a `}}` escape.
    InvalidCloseBrace,
}

/// Counts the `{}` placeholders in `template`.
///
/// The template is scanned byte by byte. `{{` and `}}` are skipped as escapes,
/// `{}` is counted as one placeholder, and any other use of `{` or `}`
/// (including a brace as the final byte) is reported as an error.
///
/// This is a `const fn`, so it can be evaluated at compile time.
#[doc(hidden)]
pub const fn sql_template_placeholder_count(template: &str) -> Result<usize, SqlTemplateError> {
    let src = template.as_bytes();
    let len = src.len();
    let mut placeholders = 0;
    let mut pos = 0;

    while pos < len {
        let current = src[pos];
        if current != b'{' && current != b'}' {
            pos += 1;
            continue;
        }

        // A brace is only valid as the first byte of a two-byte sequence.
        let has_next = pos + 1 < len;
        if current == b'{' {
            if !has_next {
                return Err(SqlTemplateError::InvalidOpenBrace);
            }
            let next = src[pos + 1];
            if next == b'}' {
                placeholders += 1;
            } else if next != b'{' {
                return Err(SqlTemplateError::InvalidOpenBrace);
            }
        } else if !has_next || src[pos + 1] != b'}' {
            return Err(SqlTemplateError::InvalidCloseBrace);
        }
        pos += 2;
    }

    Ok(placeholders)
}
76
77/// Errors produced by [`Sql::format`].
78#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
79pub enum SqlFormatError {
80    /// The template contained a `{` that was not part of a `{}` placeholder
81    /// or `{{` escape.
82    #[error("SQL format string contains an invalid '{{' sequence")]
83    InvalidOpenBrace,
84    /// The template contained a `}` that was not part of a `{}` placeholder
85    /// or `}}` escape.
86    #[error("SQL format string contains an invalid '}}' sequence")]
87    InvalidCloseBrace,
88    /// The template had more `{}` placeholders than supplied arguments.
89    #[error("SQL format string expected more arguments")]
90    MissingArgument,
91    /// The template had fewer `{}` placeholders than supplied arguments.
92    #[error("SQL format string received too many arguments")]
93    ExtraArgument,
94}
95
96impl Sql {
97    /// Creates a SQL fragment from a static SQL string.
98    pub fn new(sql: &'static str) -> Self {
99        Self(Cow::Borrowed(sql))
100    }
101
102    /// Creates a SQL fragment from an arbitrary owned string, trusting the caller
103    /// that it is safe SQL. Prefer [`Sql::ident`] or [`Sql::literal`] when handling
104    /// untrusted input.
105    pub fn raw_unchecked(sql: String) -> Self {
106        Self(Cow::Owned(sql))
107    }
108
109    /// Wraps the body of an entire SQL request received over an external API
110    /// endpoint that explicitly accepts arbitrary caller-supplied SQL — the
111    /// HTTP/WebSocket SQL API (`/api/sql`, via the [`serde::Deserialize`] impl
112    /// on [`Sql`]) and the MCP query tools.
113    ///
114    /// **Do not use this constructor anywhere else.** It marks the trust
115    /// boundary where a caller hands us a complete SQL request as opaque text.
116    /// Internal code that assembles SQL must build it via [`Sql::new`],
117    /// [`Sql::ident`], [`Sql::literal`], the [`crate::sql!`] macro, or — when the
118    /// SQL is already trusted by other means — [`Sql::raw_unchecked`].
119    pub fn trusted_external_request(sql: String) -> Self {
120        Self(Cow::Owned(sql))
121    }
122
123    /// Creates a SQL fragment by escaping a SQL identifier.
124    pub fn ident(ident: &str) -> Self {
125        // PostgreSQL identifiers are escaped by surrounding with double quotes
126        // and doubling any embedded double quotes.
127        let mut out = String::with_capacity(ident.len() + 2);
128        out.push('"');
129        for ch in ident.chars() {
130            if ch == '"' {
131                out.push('"');
132            }
133            out.push(ch);
134        }
135        out.push('"');
136        Self(Cow::Owned(out))
137    }
138
139    /// Creates a SQL fragment by escaping a SQL literal.
140    pub fn literal(literal: &str) -> Self {
141        // PostgreSQL string literals are escaped by surrounding with single
142        // quotes and doubling any embedded single quotes.
143        let mut out = String::with_capacity(literal.len() + 2);
144        out.push('\'');
145        for ch in literal.chars() {
146            if ch == '\'' {
147                out.push('\'');
148            }
149            out.push(ch);
150        }
151        out.push('\'');
152        Self(Cow::Owned(out))
153    }
154
155    /// Creates a SQL fragment for a PostgreSQL positional parameter (e.g. `$1`).
156    pub fn param(index: usize) -> Self {
157        // PostgreSQL parameters are one-based and limited to 65535.
158        assert!(
159            (1..=65535).contains(&index),
160            "PostgreSQL parameter index out of range: {index}"
161        );
162        let mut out = String::new();
163        out.push('$');
164        let _ = write!(out, "{index}");
165        Self(Cow::Owned(out))
166    }
167
168    /// Joins SQL fragments with a static separator.
169    pub fn join(parts: impl IntoIterator<Item = Sql>, separator: &'static str) -> Self {
170        let mut iter = parts.into_iter();
171        let Some(first) = iter.next() else {
172            return Self(Cow::Borrowed(""));
173        };
174
175        let mut out = first.0;
176        for part in iter {
177            out.to_mut().push_str(separator);
178            out.to_mut().push_str(part.as_str());
179        }
180        Self(out)
181    }
182
183    /// Formats this SQL fragment by replacing each `{}` with the next SQL argument.
184    ///
185    /// Use `{{` and `}}` to escape literal braces.
186    pub fn format(self, args: impl IntoIterator<Item = Sql>) -> Result<Self, SqlFormatError> {
187        let mut args = args.into_iter();
188        let mut out = String::with_capacity(self.0.len());
189        let mut chars = self.0.chars().peekable();
190
191        while let Some(ch) = chars.next() {
192            match ch {
193                '{' => match chars.peek() {
194                    Some('{') => {
195                        chars.next();
196                        out.push('{');
197                    }
198                    Some('}') => {
199                        chars.next();
200                        let arg = args.next().ok_or(SqlFormatError::MissingArgument)?;
201                        out.push_str(arg.as_str());
202                    }
203                    _ => return Err(SqlFormatError::InvalidOpenBrace),
204                },
205                '}' => match chars.peek() {
206                    Some('}') => {
207                        chars.next();
208                        out.push('}');
209                    }
210                    _ => return Err(SqlFormatError::InvalidCloseBrace),
211                },
212                _ => out.push(ch),
213            }
214        }
215
216        if args.next().is_some() {
217            return Err(SqlFormatError::ExtraArgument);
218        }
219        Ok(Sql(Cow::Owned(out)))
220    }
221
222    /// Like [`Sql::format`], but panics on invalid input instead of returning
223    /// an error.
224    ///
225    /// The [`crate::sql!`] macro is the only intended caller; it upholds at
226    /// compile time the invariants this function relies on:
227    ///
228    /// * `{` and `}` appear in the template only as `{{` and `}}` escapes or
229    ///   `{}` placeholders.
230    /// * The number of arguments matches the number of `{}` placeholders.
231    ///
232    /// Do not call this directly; use [`Sql::format`] instead, which reports
233    /// violations as errors.
234    #[doc(hidden)]
235    pub fn format_unchecked(self, args: impl IntoIterator<Item = Sql>) -> Self {
236        let mut args = args.into_iter();
237        let mut out = String::with_capacity(self.0.len());
238        let mut chars = self.0.chars().peekable();
239
240        while let Some(ch) = chars.next() {
241            match ch {
242                '{' => match chars.next().expect("validated in sql! macro") {
243                    '{' => out.push('{'),
244                    '}' => {
245                        let arg = args.next().expect("validated in sql! macro");
246                        out.push_str(arg.as_str());
247                    }
248                    _ => unreachable!("validated in sql! macro"),
249                },
250                '}' => match chars.next().expect("validated in sql! macro") {
251                    '}' => out.push('}'),
252                    _ => unreachable!("validated in sql! macro"),
253                },
254                _ => out.push(ch),
255            }
256        }
257        if cfg!(debug_assertions) {
258            assert!(args.next().is_none(), "validated in sql! macro");
259        }
260        Sql(Cow::Owned(out))
261    }
262
263    /// Returns the underlying SQL string.
264    pub fn as_str(&self) -> &str {
265        self.0.as_ref()
266    }
267
268    /// Consumes this value and returns the SQL string.
269    pub fn into_string(self) -> String {
270        self.0.into_owned()
271    }
272}
273
274impl fmt::Display for Sql {
275    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276        f.write_str(self.as_str())
277    }
278}
279
280impl serde::Serialize for Sql {
281    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
282        self.as_str().serialize(serializer)
283    }
284}
285
286impl<'de> serde::Deserialize<'de> for Sql {
287    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
288        // Deserialization is the trust boundary for SQL arriving over the
289        // external `/api/sql` endpoint. Any other path that produces a `Sql`
290        // value should use the safe constructors (`Sql::new`, `Sql::ident`,
291        // `Sql::literal`, the `sql!` macro) or `Sql::raw_unchecked`.
292        Ok(Sql::trusted_external_request(String::deserialize(
293            deserializer,
294        )?))
295    }
296}
297
298macro_rules! impl_from_integer_for_sql {
299    ($($t:ty),+ $(,)?) => {
300        $(
301            impl From<$t> for Sql {
302                fn from(value: $t) -> Self {
303                    Sql(Cow::Owned(value.to_string()))
304                }
305            }
306        )+
307    };
308}
309
310impl_from_integer_for_sql!(i16, i32, i64, isize, u16, u32, u64, usize);
311
/// Builds a [`Sql`] fragment from a static template and SQL arguments.
///
/// When arguments are given, each `{}` in the template is replaced by the
/// corresponding argument (converted with `Sql::from`), and `{{`/`}}` stand
/// for literal braces. The template's brace syntax and the match between
/// placeholder count and argument count are checked in a `const` context, so a
/// mistake fails the build.
///
/// When no arguments are given, the template is handed to `Sql::new` as-is and
/// braces are not interpreted.
#[macro_export]
macro_rules! sql {
    // Internal helper: maps any argument expression to `()` so that the
    // arguments can be counted without being evaluated.
    (@unit $_arg:expr) => { () };
    ($template:literal $(,)?) => {
        $crate::sql::Sql::new($template)
    };
    ($template:literal, $($arg:expr),+ $(,)?) => {{
        const _: () = {
            let supplied: usize = <[()]>::len(&[$($crate::sql!(@unit $arg)),*]);
            let expected: usize = match $crate::sql::sql_template_placeholder_count($template) {
                Ok(n) => n,
                Err($crate::sql::SqlTemplateError::InvalidOpenBrace) => {
                    panic!("sql!: invalid '{{' in SQL template")
                }
                Err($crate::sql::SqlTemplateError::InvalidCloseBrace) => {
                    panic!("sql!: invalid '}}' in SQL template")
                }
            };
            if supplied != expected {
                panic!("sql!: placeholder count does not match arguments");
            }
        };

        $crate::sql::Sql::new($template).format_unchecked([$($crate::sql::Sql::from($arg)),*])
    }};
}
343
#[cfg(test)]
mod tests {
    use super::{Sql, SqlFormatError};

    /// Rendering expected from both the `Sql::format` and `sql!` composition
    /// tests.
    const COMPOSED: &str = r#"SELECT * FROM "my_table" WHERE col = 'v'"#;

    // Identifiers are double-quoted and embedded double quotes are doubled.
    #[mz_ore::test]
    fn sql_identifier_escaping() {
        let plain = Sql::ident("a");
        assert_eq!(plain.as_str(), r#""a""#);

        let with_quote = Sql::ident(r#"a"b"#);
        assert_eq!(with_quote.as_str(), r#""a""b""#);
    }

    // Literals are single-quoted and embedded single quotes are doubled.
    #[mz_ore::test]
    fn sql_literal_escaping() {
        let plain = Sql::literal("a");
        assert_eq!(plain.as_str(), "'a'");

        let with_quote = Sql::literal("a'b");
        assert_eq!(with_quote.as_str(), "'a''b'");
    }

    // `Sql::format` fills placeholders with the arguments in order.
    #[mz_ore::test]
    fn sql_format_composes_fragments() {
        let template = Sql::new("SELECT * FROM {} WHERE col = {}");
        let rendered = template
            .format([Sql::ident("my_table"), Sql::literal("v")])
            .expect("valid template");
        assert_eq!(rendered.as_str(), COMPOSED);
    }

    // A named placeholder such as `{x}` is rejected.
    #[mz_ore::test]
    fn sql_format_errors_on_invalid_placeholders() {
        let template = Sql::new("SELECT {x}");
        let failure = template
            .format([Sql::ident("t")])
            .expect_err("invalid format");
        assert_eq!(failure, SqlFormatError::InvalidOpenBrace);
    }

    // The macro form produces the same text as `Sql::format`.
    #[mz_ore::test]
    fn sql_macro_composes_fragments() {
        let rendered = crate::sql!(
            "SELECT * FROM {} WHERE col = {}",
            Sql::ident("my_table"),
            Sql::literal("v")
        );
        assert_eq!(rendered.as_str(), COMPOSED);
    }

    // With arguments, `{{`/`}}` collapse to single braces.
    #[mz_ore::test]
    fn sql_macro_escaped_braces() {
        let rendered = crate::sql!("SELECT '{{}}' AS braces, {} AS t", Sql::ident("col"));
        assert_eq!(rendered.as_str(), r#"SELECT '{}' AS braces, "col" AS t"#);
    }

    // Without arguments, the template is used verbatim, braces included.
    #[mz_ore::test]
    fn sql_macro_static_literal() {
        let rendered = crate::sql!("SELECT '{not_a_placeholder}'");
        assert_eq!(rendered.as_str(), "SELECT '{not_a_placeholder}'");
    }
}