1use std::borrow::Cow;
23use std::fmt::{self, Write};
24
/// A fragment of SQL text.
///
/// The text is held in a [`Cow`], so fragments created from `&'static str`
/// values are stored without copying while fragments built at runtime own
/// their `String`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Sql(Cow<'static, str>);
32
/// Reason a template was rejected by [`sql_template_placeholder_count`].
///
/// Hidden from the docs: it is only public so the `sql!` macro can match on it
/// from its expansion.
#[doc(hidden)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlTemplateError {
    /// A `{` that is not followed by `{` or `}`.
    InvalidOpenBrace,
    /// A `}` that is neither doubled nor the close of a `{}` placeholder.
    InvalidCloseBrace,
}
39
40#[doc(hidden)]
41pub const fn sql_template_placeholder_count(template: &str) -> Result<usize, SqlTemplateError> {
42 let bytes = template.as_bytes();
43 let mut i = 0;
44 let mut count = 0;
45
46 while i < bytes.len() {
47 match bytes[i] {
48 b'{' => {
49 if i + 1 >= bytes.len() {
50 return Err(SqlTemplateError::InvalidOpenBrace);
51 }
52 match bytes[i + 1] {
53 b'{' => i += 2,
54 b'}' => {
55 count += 1;
56 i += 2;
57 }
58 _ => return Err(SqlTemplateError::InvalidOpenBrace),
59 }
60 }
61 b'}' => {
62 if i + 1 >= bytes.len() {
63 return Err(SqlTemplateError::InvalidCloseBrace);
64 }
65 match bytes[i + 1] {
66 b'}' => i += 2,
67 _ => return Err(SqlTemplateError::InvalidCloseBrace),
68 }
69 }
70 _ => i += 1,
71 }
72 }
73
74 Ok(count)
75}
76
77#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
79pub enum SqlFormatError {
80 #[error("SQL format string contains an invalid '{{' sequence")]
83 InvalidOpenBrace,
84 #[error("SQL format string contains an invalid '}}' sequence")]
87 InvalidCloseBrace,
88 #[error("SQL format string expected more arguments")]
90 MissingArgument,
91 #[error("SQL format string received too many arguments")]
93 ExtraArgument,
94}
95
96impl Sql {
97 pub fn new(sql: &'static str) -> Self {
99 Self(Cow::Borrowed(sql))
100 }
101
102 pub fn raw_unchecked(sql: String) -> Self {
106 Self(Cow::Owned(sql))
107 }
108
109 pub fn trusted_external_request(sql: String) -> Self {
120 Self(Cow::Owned(sql))
121 }
122
123 pub fn ident(ident: &str) -> Self {
125 let mut out = String::with_capacity(ident.len() + 2);
128 out.push('"');
129 for ch in ident.chars() {
130 if ch == '"' {
131 out.push('"');
132 }
133 out.push(ch);
134 }
135 out.push('"');
136 Self(Cow::Owned(out))
137 }
138
139 pub fn literal(literal: &str) -> Self {
141 let mut out = String::with_capacity(literal.len() + 2);
144 out.push('\'');
145 for ch in literal.chars() {
146 if ch == '\'' {
147 out.push('\'');
148 }
149 out.push(ch);
150 }
151 out.push('\'');
152 Self(Cow::Owned(out))
153 }
154
155 pub fn param(index: usize) -> Self {
157 assert!(
159 (1..=65535).contains(&index),
160 "PostgreSQL parameter index out of range: {index}"
161 );
162 let mut out = String::new();
163 out.push('$');
164 let _ = write!(out, "{index}");
165 Self(Cow::Owned(out))
166 }
167
168 pub fn join(parts: impl IntoIterator<Item = Sql>, separator: &'static str) -> Self {
170 let mut iter = parts.into_iter();
171 let Some(first) = iter.next() else {
172 return Self(Cow::Borrowed(""));
173 };
174
175 let mut out = first.0;
176 for part in iter {
177 out.to_mut().push_str(separator);
178 out.to_mut().push_str(part.as_str());
179 }
180 Self(out)
181 }
182
183 pub fn format(self, args: impl IntoIterator<Item = Sql>) -> Result<Self, SqlFormatError> {
187 let mut args = args.into_iter();
188 let mut out = String::with_capacity(self.0.len());
189 let mut chars = self.0.chars().peekable();
190
191 while let Some(ch) = chars.next() {
192 match ch {
193 '{' => match chars.peek() {
194 Some('{') => {
195 chars.next();
196 out.push('{');
197 }
198 Some('}') => {
199 chars.next();
200 let arg = args.next().ok_or(SqlFormatError::MissingArgument)?;
201 out.push_str(arg.as_str());
202 }
203 _ => return Err(SqlFormatError::InvalidOpenBrace),
204 },
205 '}' => match chars.peek() {
206 Some('}') => {
207 chars.next();
208 out.push('}');
209 }
210 _ => return Err(SqlFormatError::InvalidCloseBrace),
211 },
212 _ => out.push(ch),
213 }
214 }
215
216 if args.next().is_some() {
217 return Err(SqlFormatError::ExtraArgument);
218 }
219 Ok(Sql(Cow::Owned(out)))
220 }
221
222 #[doc(hidden)]
235 pub fn format_unchecked(self, args: impl IntoIterator<Item = Sql>) -> Self {
236 let mut args = args.into_iter();
237 let mut out = String::with_capacity(self.0.len());
238 let mut chars = self.0.chars().peekable();
239
240 while let Some(ch) = chars.next() {
241 match ch {
242 '{' => match chars.next().expect("validated in sql! macro") {
243 '{' => out.push('{'),
244 '}' => {
245 let arg = args.next().expect("validated in sql! macro");
246 out.push_str(arg.as_str());
247 }
248 _ => unreachable!("validated in sql! macro"),
249 },
250 '}' => match chars.next().expect("validated in sql! macro") {
251 '}' => out.push('}'),
252 _ => unreachable!("validated in sql! macro"),
253 },
254 _ => out.push(ch),
255 }
256 }
257 if cfg!(debug_assertions) {
258 assert!(args.next().is_none(), "validated in sql! macro");
259 }
260 Sql(Cow::Owned(out))
261 }
262
263 pub fn as_str(&self) -> &str {
265 self.0.as_ref()
266 }
267
268 pub fn into_string(self) -> String {
270 self.0.into_owned()
271 }
272}
273
274impl fmt::Display for Sql {
275 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276 f.write_str(self.as_str())
277 }
278}
279
280impl serde::Serialize for Sql {
281 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
282 self.as_str().serialize(serializer)
283 }
284}
285
286impl<'de> serde::Deserialize<'de> for Sql {
287 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
288 Ok(Sql::trusted_external_request(String::deserialize(
293 deserializer,
294 )?))
295 }
296}
297
298macro_rules! impl_from_integer_for_sql {
299 ($($t:ty),+ $(,)?) => {
300 $(
301 impl From<$t> for Sql {
302 fn from(value: $t) -> Self {
303 Sql(Cow::Owned(value.to_string()))
304 }
305 }
306 )+
307 };
308}
309
310impl_from_integer_for_sql!(i16, i32, i64, isize, u16, u32, u64, usize);
311
/// Builds a [`Sql`](crate::sql::Sql) fragment from a string-literal template.
///
/// * `sql!("...")` wraps the literal as-is; braces in it are **not**
///   interpreted.
/// * `sql!("... {} ...", a, b)` substitutes each argument (converted with
///   `Sql::from`) for the corresponding `{}`; `{{` and `}}` produce literal
///   braces. The template syntax and the number of arguments are checked at
///   compile time, so a mismatch is a build error rather than a runtime one.
#[macro_export]
macro_rules! sql {
    // Internal rule: turns any argument into `()` so arguments can be counted
    // in a constant without being evaluated.
    (@unit $_arg:expr) => { () };
    ($template:literal $(,)?) => {
        $crate::sql::Sql::new($template)
    };
    ($template:literal, $($arg:expr),+ $(,)?) => {{
        // Compile-time validation: evaluated during constant evaluation, so
        // any `panic!` below surfaces as a compilation error.
        const _: () = {
            let supplied: usize = <[()]>::len(&[$($crate::sql!(@unit $arg)),+]);
            let expected: usize = match $crate::sql::sql_template_placeholder_count($template) {
                Ok(count) => count,
                Err($crate::sql::SqlTemplateError::InvalidOpenBrace) => {
                    panic!("sql!: invalid '{{' in SQL template")
                }
                Err($crate::sql::SqlTemplateError::InvalidCloseBrace) => {
                    panic!("sql!: invalid '}}' in SQL template")
                }
            };
            if supplied != expected {
                panic!("sql!: placeholder count does not match arguments");
            }
        };

        $crate::sql::Sql::new($template)
            .format_unchecked([$($crate::sql::Sql::from($arg)),+])
    }};
}
343
#[cfg(test)]
mod tests {
    use super::{Sql, SqlFormatError};

    #[mz_ore::test]
    fn sql_identifier_escaping() {
        assert_eq!(Sql::ident("a").as_str(), "\"a\"");
        assert_eq!(Sql::ident("a\"b").as_str(), "\"a\"\"b\"");
    }

    #[mz_ore::test]
    fn sql_literal_escaping() {
        assert_eq!(Sql::literal("a").as_str(), "'a'");
        assert_eq!(Sql::literal("a'b").as_str(), "'a''b'");
    }

    // Each quoting style must only double its own quote character.
    #[mz_ore::test]
    fn sql_quoting_leaves_other_quote_untouched() {
        assert_eq!(Sql::ident("a'b").as_str(), "\"a'b\"");
        assert_eq!(Sql::literal("a\"b").as_str(), "'a\"b'");
    }

    #[mz_ore::test]
    fn sql_param_renders_dollar_index() {
        assert_eq!(Sql::param(1).as_str(), "$1");
        assert_eq!(Sql::param(65535).as_str(), "$65535");
    }

    #[mz_ore::test]
    fn sql_join_inserts_separator() {
        assert_eq!(Sql::join(Vec::<Sql>::new(), ", ").as_str(), "");
        assert_eq!(Sql::join([Sql::ident("a")], ", ").as_str(), "\"a\"");
        assert_eq!(
            Sql::join([Sql::ident("a"), Sql::ident("b"), Sql::ident("c")], ", ").as_str(),
            "\"a\", \"b\", \"c\""
        );
    }

    #[mz_ore::test]
    fn sql_from_integer() {
        assert_eq!(Sql::from(42i32).as_str(), "42");
        assert_eq!(Sql::from(-7i64).as_str(), "-7");
        assert_eq!(Sql::from(0usize).as_str(), "0");
    }

    #[mz_ore::test]
    fn sql_format_composes_fragments() {
        let query = Sql::new("SELECT * FROM {} WHERE col = {}")
            .format([Sql::ident("my_table"), Sql::literal("v")])
            .expect("valid template");
        assert_eq!(query.as_str(), "SELECT * FROM \"my_table\" WHERE col = 'v'");
    }

    #[mz_ore::test]
    fn sql_format_escaped_braces() {
        let query = Sql::new("SELECT '{{}}' AS braces, {} AS p")
            .format([Sql::param(1)])
            .expect("valid template");
        assert_eq!(query.as_str(), "SELECT '{}' AS braces, $1 AS p");
    }

    #[mz_ore::test]
    fn sql_format_errors_on_invalid_placeholders() {
        let err = Sql::new("SELECT {x}")
            .format([Sql::ident("t")])
            .expect_err("invalid format");
        assert_eq!(err, SqlFormatError::InvalidOpenBrace);
    }

    #[mz_ore::test]
    fn sql_format_errors_on_lone_braces() {
        // A `{` at the very end of the template has nothing to pair with.
        let err = Sql::new("SELECT {")
            .format(Vec::<Sql>::new())
            .expect_err("trailing open brace");
        assert_eq!(err, SqlFormatError::InvalidOpenBrace);

        for template in ["SELECT }", "SELECT }x"] {
            let err = Sql::new(template)
                .format(Vec::<Sql>::new())
                .expect_err("lone close brace");
            assert_eq!(err, SqlFormatError::InvalidCloseBrace);
        }
    }

    #[mz_ore::test]
    fn sql_format_errors_on_argument_count_mismatch() {
        let err = Sql::new("SELECT {}, {}")
            .format([Sql::literal("only one")])
            .expect_err("too few arguments");
        assert_eq!(err, SqlFormatError::MissingArgument);

        let err = Sql::new("SELECT {}")
            .format([Sql::literal("a"), Sql::literal("b")])
            .expect_err("too many arguments");
        assert_eq!(err, SqlFormatError::ExtraArgument);
    }

    #[mz_ore::test]
    fn sql_macro_composes_fragments() {
        let query = crate::sql!(
            "SELECT * FROM {} WHERE col = {}",
            Sql::ident("my_table"),
            Sql::literal("v")
        );
        assert_eq!(query.as_str(), "SELECT * FROM \"my_table\" WHERE col = 'v'");
    }

    #[mz_ore::test]
    fn sql_macro_escaped_braces() {
        let query = crate::sql!("SELECT '{{}}' AS braces, {} AS t", Sql::ident("col"));
        assert_eq!(query.as_str(), "SELECT '{}' AS braces, \"col\" AS t");
    }

    // Without arguments the template is passed through untouched, braces
    // included.
    #[mz_ore::test]
    fn sql_macro_static_literal() {
        let query = crate::sql!("SELECT '{not_a_placeholder}'");
        assert_eq!(query.as_str(), "SELECT '{not_a_placeholder}'");
    }
}