1use std::borrow::ToOwned;
11use std::collections::{BTreeMap, btree_map};
12use std::error::Error;
13use std::fmt::Write;
14use std::str::FromStr;
15use std::sync::LazyLock;
16
17use anyhow::{Context, anyhow, bail};
18use regex::Regex;
19
20use crate::error::PosError;
21
22#[derive(Debug, Clone)]
23pub struct PosCommand {
24 pub pos: usize,
25 pub command: Command,
26}
27
28#[derive(Debug, Clone)]
30pub struct VersionConstraint {
31 pub min: i32,
32 pub max: i32,
33}
34
35#[derive(Debug, Clone)]
36pub enum Command {
37 Builtin(BuiltinCommand, Option<VersionConstraint>),
38 Sql(SqlCommand, Option<VersionConstraint>),
39 FailSql(FailSqlCommand, Option<VersionConstraint>),
40}
41
42#[derive(Debug, Clone)]
43pub struct BuiltinCommand {
44 pub name: String,
45 pub args: ArgMap,
46 pub input: Vec<String>,
47}
48
49impl BuiltinCommand {
50 pub fn assert_no_input(&self) -> Result<(), anyhow::Error> {
51 if !self.input.is_empty() {
52 bail!("{} action does not take input", self.name);
53 }
54 Ok(())
55 }
56}
57
/// The expected result of a [`SqlCommand`].
#[derive(Clone, Debug)]
pub enum SqlOutput {
    /// The complete expected result. For `?` commands this is a single row
    /// holding the whole expected text as one field.
    Full {
        /// Column names, present when the first output line was followed by a
        /// line of three or more `-` characters.
        column_names: Option<Vec<String>>,
        /// The expected rows, each split into fields.
        expected_rows: Vec<Vec<String>>,
    },
    /// A digest of the result, written as `<num_values> values hashing to <md5>`.
    Hashed {
        /// The number given before `values hashing to`.
        num_values: usize,
        /// The digest given after `values hashing to`.
        md5: String,
    },
}

/// A `>` or `?` SQL command and its expected output.
#[derive(Clone, Debug)]
pub struct SqlCommand {
    /// The SQL text, with the sigil removed and surrounding whitespace trimmed.
    pub query: String,
    /// What the query is expected to return.
    pub expected_output: SqlOutput,
    /// The reader's raw byte offset taken before the expected output was read
    /// (0 for hashed output).
    /// NOTE(review): presumably used to rewrite the expected output in place — confirm.
    pub expected_start: usize,
    /// The reader's raw byte offset taken after the expected output was read
    /// (0 for hashed output).
    pub expected_end: usize,
}

/// A `!` SQL command that is expected to fail.
#[derive(Clone, Debug)]
pub struct FailSqlCommand {
    /// The SQL text, with the sigil removed and surrounding whitespace trimmed.
    pub query: String,
    /// The expected error, from the line following the query.
    pub expected_error: SqlExpectedError,
    /// The trimmed text of an optional `detail:` line.
    pub expected_detail: Option<String>,
    /// The trimmed text of an optional `hint:` line.
    pub expected_hint: Option<String>,
}

/// How a failing SQL command's error is specified.
#[derive(Clone, Debug)]
pub enum SqlExpectedError {
    /// From a `contains:` line.
    Contains(String),
    /// From an `exact:` line.
    Exact(String),
    /// From a `regex:` line.
    Regex(String),
    /// From a line that is exactly `timeout`.
    Timeout,
}
92
93pub(crate) fn parse(line_reader: &mut LineReader) -> Result<Vec<PosCommand>, PosError> {
94 let mut out = Vec::new();
95 while let Some((pos, line)) = line_reader.peek() {
96 let pos = *pos;
97 let command = match line.chars().next() {
98 Some('$') => {
99 let version = parse_version_constraint(line_reader)?;
100 Command::Builtin(parse_builtin(line_reader)?, version)
101 }
102 Some('>') => {
103 let version = parse_version_constraint(line_reader)?;
104 Command::Sql(parse_sql(line_reader)?, version)
105 }
106 Some('?') => {
107 let version = parse_version_constraint(line_reader)?;
108 Command::Sql(parse_explain_sql(line_reader)?, version)
109 }
110 Some('!') => {
111 let version = parse_version_constraint(line_reader)?;
112 Command::FailSql(parse_fail_sql(line_reader)?, version)
113 }
114 Some('#') => {
115 line_reader.next();
117 continue;
118 }
119 Some(x) => {
120 return Err(PosError {
121 source: anyhow!(format!("unexpected input line at beginning of file: {}", x)),
122 pos: Some(pos),
123 });
124 }
125 None => {
126 return Err(PosError {
127 source: anyhow!("unexpected input line at beginning of file"),
128 pos: Some(pos),
129 });
130 }
131 };
132 out.push(PosCommand { command, pos });
133 }
134 Ok(out)
135}
136
137fn parse_builtin(line_reader: &mut LineReader) -> Result<BuiltinCommand, PosError> {
138 let (pos, line) = line_reader.next().unwrap();
139 let mut builtin_reader = BuiltinReader::new(&line, pos);
140 let name = match builtin_reader.next() {
141 Some(Ok((_, s))) => s,
142 Some(Err(e)) => return Err(e),
143 None => {
144 return Err(PosError {
145 source: anyhow!("command line is missing command name"),
146 pos: Some(pos),
147 });
148 }
149 };
150 let mut args = BTreeMap::new();
151 for el in builtin_reader {
152 let (pos, token) = el?;
153 let pieces: Vec<_> = token.splitn(2, '=').collect();
154 let pieces = match pieces.as_slice() {
155 [key, value] => vec![*key, *value],
156 [key] => vec![*key, ""],
157 _ => {
158 return Err(PosError {
159 source: anyhow!("command argument is not in required key=value format"),
160 pos: Some(pos),
161 });
162 }
163 };
164 validate_ident(pieces[0]).map_err(|e| PosError::new(e, pos))?;
165
166 if let Some(original) = args.insert(pieces[0].to_owned(), pieces[1].to_owned()) {
167 return Err(PosError {
168 source: anyhow!(
169 "argument '{}' specified twice: {} & {}",
170 pieces[0],
171 original,
172 pieces[1]
173 ),
174 pos: Some(pos),
175 });
176 };
177 }
178 Ok(BuiltinCommand {
179 name,
180 args: ArgMap(args),
181 input: slurp_all(line_reader),
182 })
183}
184
185pub fn validate_ident(name: &str) -> Result<(), anyhow::Error> {
187 static VALID_KEY_REGEX: LazyLock<Regex> =
188 LazyLock::new(|| Regex::new("^[a-z0-9\\-]*$").unwrap());
189 if !VALID_KEY_REGEX.is_match(name) {
190 bail!(
191 "invalid builtin argument name '{}': \
192 only lowercase letters, numbers, and hyphens allowed",
193 name
194 );
195 }
196 Ok(())
197}
198
199fn parse_version_constraint(
200 line_reader: &mut LineReader,
201) -> Result<Option<VersionConstraint>, PosError> {
202 let (pos, line) = line_reader.next().unwrap();
203 if line[1..2].to_string() != "[" {
204 line_reader.push(&line);
205 return Ok(None);
206 }
207 let closed_brace_pos = match line.find(']') {
208 Some(x) => x,
209 None => {
210 return Err(PosError {
211 source: anyhow!("version-constraint: found no closing brace"),
212 pos: Some(pos),
213 });
214 }
215 };
216 if line[2..9].to_string() != "version" {
217 return Err(PosError {
218 source: anyhow!(
219 "version-constraint: invalid property {}",
220 line[2..closed_brace_pos].to_string()
221 ),
222 pos: Some(pos),
223 });
224 }
225 let remainder = line[closed_brace_pos + 1..].to_string();
226 line_reader.push(&remainder);
227 const MIN_VERSION: i32 = 0;
228 const MAX_VERSION: i32 = 9999999;
229 let version_pos = if line.as_bytes()[10].is_ascii_digit() {
230 10
231 } else {
232 11
233 };
234 let version = match line[version_pos..closed_brace_pos].parse::<i32>() {
235 Ok(x) => x,
236 Err(_) => {
237 return Err(PosError {
238 source: anyhow!(
239 "version-constraint: invalid version number {}",
240 line[version_pos..closed_brace_pos].to_string()
241 ),
242 pos: Some(pos),
243 });
244 }
245 };
246
247 match &line[9..version_pos] {
248 "=" => Ok(Some(VersionConstraint {
249 min: version,
250 max: version,
251 })),
252 "<=" => Ok(Some(VersionConstraint {
253 min: MIN_VERSION,
254 max: version,
255 })),
256 "<" => Ok(Some(VersionConstraint {
257 min: MIN_VERSION,
258 max: version - 1,
259 })),
260 ">=" => Ok(Some(VersionConstraint {
261 min: version,
262 max: MAX_VERSION,
263 })),
264 ">" => Ok(Some(VersionConstraint {
265 min: version + 1,
266 max: MAX_VERSION,
267 })),
268 _ => Err(PosError {
269 source: anyhow!(
270 "version-constraint: unknown comparison operator {}",
271 line[9..version_pos].to_string()
272 ),
273 pos: Some(pos),
274 }),
275 }
276}
277
278fn parse_sql(line_reader: &mut LineReader) -> Result<SqlCommand, PosError> {
279 let (_, line1) = line_reader.next().unwrap();
280 let query = line1[1..].trim().to_owned();
281 let expected_start = line_reader.raw_pos;
282 let line2 = slurp_one(line_reader);
283 let line3 = slurp_one(line_reader);
284 let mut column_names = None;
285 let mut expected_rows = Vec::new();
286 static HASH_REGEX: LazyLock<Regex> =
287 LazyLock::new(|| Regex::new(r"^(\S+) values hashing to (\S+)$").unwrap());
288 match (line2, line3) {
289 (Some((pos2, line2)), Some((pos3, line3))) => {
290 if line3.len() >= 3 && line3.chars().all(|c| c == '-') {
291 column_names = Some(split_line(pos2, &line2)?);
292 } else {
293 expected_rows.push(split_line(pos2, &line2)?);
294 expected_rows.push(split_line(pos3, &line3)?);
295 }
296 }
297 (Some((pos2, line2)), None) => match HASH_REGEX.captures(&line2) {
298 Some(captures) => match captures[1].parse::<usize>() {
299 Ok(num_values) => {
300 return Ok(SqlCommand {
301 query,
302 expected_output: SqlOutput::Hashed {
303 num_values,
304 md5: captures[2].to_owned(),
305 },
306 expected_start: 0,
307 expected_end: 0,
308 });
309 }
310 Err(err) => {
311 return Err(PosError {
312 source: anyhow!("Error parsing number of expected rows: {}", err),
313 pos: Some(pos2),
314 });
315 }
316 },
317 None => expected_rows.push(split_line(pos2, &line2)?),
318 },
319 _ => (),
320 }
321 while let Some((pos, line)) = slurp_one(line_reader) {
322 expected_rows.push(split_line(pos, &line)?)
323 }
324 let expected_end = line_reader.raw_pos;
325 Ok(SqlCommand {
326 query,
327 expected_output: SqlOutput::Full {
328 column_names,
329 expected_rows,
330 },
331 expected_start,
332 expected_end,
333 })
334}
335
336fn parse_explain_sql(line_reader: &mut LineReader) -> Result<SqlCommand, PosError> {
337 let (_, line1) = line_reader.next().unwrap();
338 let expected_start = line_reader.raw_pos;
339 let mut expected_output: String = line_reader
343 .inner
344 .lines()
345 .filter(|l| !matches!(l.chars().next(), Some('#')))
346 .take_while(|l| !is_sigil(l.chars().next()))
347 .fold(String::new(), |mut output, l| {
348 let _ = write!(output, "{}\n", l);
349 output
350 });
351 while expected_output.ends_with("\n\n") {
352 expected_output.pop();
353 }
354 slurp_all(line_reader);
357 let expected_end = line_reader.raw_pos;
358
359 Ok(SqlCommand {
360 query: line1[1..].trim().to_owned(),
361 expected_output: SqlOutput::Full {
362 column_names: None,
363 expected_rows: vec![vec![expected_output]],
364 },
365 expected_start,
366 expected_end,
367 })
368}
369
370fn parse_fail_sql(line_reader: &mut LineReader) -> Result<FailSqlCommand, PosError> {
371 let (pos, line1) = line_reader.next().unwrap();
372 let line2 = slurp_one(line_reader);
373 let (err_pos, expected_error) = match line2 {
374 Some((err_pos, line2)) => (err_pos, line2),
375 None => {
376 return Err(PosError {
377 pos: Some(pos),
378 source: anyhow!("failing SQL command is missing expected error message"),
379 });
380 }
381 };
382 let query = line1[1..].trim().to_string();
383
384 let expected_error = if let Some(e) = expected_error.strip_prefix("regex:") {
385 SqlExpectedError::Regex(e.trim().into())
386 } else if let Some(e) = expected_error.strip_prefix("contains:") {
387 SqlExpectedError::Contains(e.trim().into())
388 } else if let Some(e) = expected_error.strip_prefix("exact:") {
389 SqlExpectedError::Exact(e.trim().into())
390 } else if expected_error == "timeout" {
391 SqlExpectedError::Timeout
392 } else {
393 return Err(PosError {
394 pos: Some(err_pos),
395 source: anyhow!(
396 "Query error must start with match specifier (`regex:`|`contains:`|`exact:`|`timeout`)"
397 ),
398 });
399 };
400
401 let extra_error = |line_reader: &mut LineReader, prefix| {
402 if let Some((_pos, line)) = line_reader.peek() {
403 if let Some(_) = line.strip_prefix(prefix) {
404 let line = line_reader
405 .next()
406 .map(|(_, line)| line)
407 .unwrap()
408 .strip_prefix(prefix)
409 .map(|line| line.to_string())
410 .unwrap();
411 Some(line.trim().to_string())
412 } else {
413 None
414 }
415 } else {
416 None
417 }
418 };
419 let expected_detail = extra_error(line_reader, "detail:");
421 let expected_hint = extra_error(line_reader, "hint:");
422
423 Ok(FailSqlCommand {
424 query: query.trim().to_string(),
425 expected_error,
426 expected_detail,
427 expected_hint,
428 })
429}
430
431fn split_line(pos: usize, line: &str) -> Result<Vec<String>, PosError> {
432 let mut out = Vec::new();
433 let mut field = String::new();
434 let mut in_quotes = None;
435 let mut escaping = false;
436 for (i, c) in line.char_indices() {
437 if in_quotes.is_none() && c.is_whitespace() {
438 if !field.is_empty() {
439 out.push(field);
440 field = String::new();
441 }
442 } else if c == '"' && !escaping {
443 if in_quotes.is_none() {
444 in_quotes = Some(i)
445 } else {
446 in_quotes = None;
447 out.push(field);
448 field = String::new();
449 }
450 } else if c == '\\' && !escaping && in_quotes.is_some() {
451 escaping = true;
452 } else if escaping {
453 field.push(match c {
454 'n' => '\n',
455 't' => '\t',
456 'r' => '\r',
457 '0' => '\0',
458 c => c,
459 });
460 escaping = false;
461 } else {
462 field.push(c);
463 }
464 }
465 if let Some(i) = in_quotes {
466 return Err(PosError {
467 source: anyhow!("unterminated quote"),
468 pos: Some(pos + i),
469 });
470 }
471 if !field.is_empty() {
472 out.push(field);
473 }
474 Ok(out)
475}
476
477fn slurp_all(line_reader: &mut LineReader) -> Vec<String> {
478 let mut out = Vec::new();
479 while let Some((_, line)) = slurp_one(line_reader) {
480 out.push(line);
481 }
482 out
483}
484
485fn slurp_one(line_reader: &mut LineReader) -> Option<(usize, String)> {
486 while let Some((_, line)) = line_reader.peek() {
487 match line.chars().next() {
488 Some('#') => {
489 let _ = line_reader.next();
491 }
492 Some('$') | Some('>') | Some('!') | Some('?') => return None,
493 Some('\\') => {
494 return line_reader.next().map(|(pos, mut line)| {
495 line.remove(0);
496 (pos, line)
497 });
498 }
499 _ => return line_reader.next(),
500 }
501 }
502 None
503}
504
505pub struct LineReader<'a> {
506 inner: &'a str,
507 #[allow(clippy::option_option)]
508 next: Option<Option<(usize, String)>>,
509
510 src_line: usize,
511 pos: usize,
512 pos_map: BTreeMap<usize, (usize, usize)>,
513 raw_pos: usize,
514}
515
516impl<'a> LineReader<'a> {
517 pub fn new(inner: &'a str) -> LineReader<'a> {
518 let mut pos_map = BTreeMap::new();
519 pos_map.insert(0, (1, 1));
520 LineReader {
521 inner,
522 src_line: 1,
523 next: None,
524 pos: 0,
525 pos_map,
526 raw_pos: 0,
527 }
528 }
529
530 fn peek(&mut self) -> Option<&(usize, String)> {
531 if self.next.is_none() {
532 self.next = Some(self.next())
533 }
534 self.next.as_ref().unwrap().as_ref()
535 }
536
537 pub fn line_col(&self, pos: usize) -> (usize, usize) {
538 let (base_pos, (line, col)) = self.pos_map.range(..=pos).next_back().unwrap();
539 (*line, col + (pos - base_pos))
540 }
541
542 fn push(&mut self, text: &String) {
543 self.next = Some(Some((0usize, text.to_string())));
544 }
545}
546
547impl<'a> Iterator for LineReader<'a> {
548 type Item = (usize, String);
549
550 fn next(&mut self) -> Option<Self::Item> {
551 if let Some(next) = self.next.take() {
552 return next;
553 }
554 if self.inner.is_empty() {
555 return None;
556 }
557 let mut fold_newlines = is_non_sql_sigil(self.inner.chars().next());
558 let mut handle_newlines = is_sql_sigil(self.inner.chars().next());
559 let mut line = String::new();
560 let mut chars = self.inner.char_indices().fuse().peekable();
561 while let Some((i, c)) = chars.next() {
562 if c == '\n' {
563 self.src_line += 1;
564 if fold_newlines && self.inner.get(i + 1..i + 3) == Some(" ") {
565 chars.next();
570 self.pos_map.insert(self.pos + i, (self.src_line, 2));
571 continue;
572 } else if handle_newlines && self.inner.get(i + 1..i + 3) == Some(" ") {
573 line.push(c);
577 chars.next();
578 chars.next();
579 self.pos_map.insert(self.pos + i + 1, (self.src_line, 2));
580 continue;
581 } else if line.chars().all(char::is_whitespace) {
582 line.clear();
583 fold_newlines = is_non_sql_sigil(chars.peek().map(|c| c.1));
584 handle_newlines = is_sql_sigil(chars.peek().map(|c| c.1));
585 self.pos_map.insert(self.pos, (self.src_line, 1));
586 continue;
587 }
588 let pos = self.pos;
589 self.pos += i;
590 self.raw_pos += i + 1; self.pos_map.insert(self.pos, (self.src_line, 1));
592 self.inner = &self.inner[i + 1..];
593 return Some((pos, line));
594 }
595 line.push(c)
596 }
597 self.inner = "";
598 if !line.chars().all(char::is_whitespace) {
599 Some((self.pos, line))
600 } else {
601 None
602 }
603 }
604}
605
/// Returns whether `c` is any command sigil: `$`, `>`, `!` or `?`.
fn is_sigil(c: Option<char>) -> bool {
    is_non_sql_sigil(c) || is_sql_sigil(c)
}

/// Returns whether `c` is a SQL command sigil: `>`, `!` or `?`.
fn is_sql_sigil(c: Option<char>) -> bool {
    c.is_some_and(|ch| ">!?".contains(ch))
}

/// Returns whether `c` is the builtin command sigil `$`.
fn is_non_sql_sigil(c: Option<char>) -> bool {
    c == Some('$')
}
617
618struct BuiltinReader<'a> {
619 inner: &'a str,
620 pos: usize,
621}
622
623impl<'a> BuiltinReader<'a> {
624 fn new(line: &str, pos: usize) -> BuiltinReader {
625 BuiltinReader {
626 inner: &line[1..],
627 pos,
628 }
629 }
630}
631
632impl<'a> Iterator for BuiltinReader<'a> {
633 type Item = Result<(usize, String), PosError>;
634
635 fn next(&mut self) -> Option<Self::Item> {
636 if self.inner.is_empty() {
637 return None;
638 }
639
640 let mut iter = self.inner.char_indices().peekable();
641
642 while let Some((i, c)) = iter.peek() {
643 if c == &' ' {
644 iter.next();
645 } else {
646 self.pos += i;
647 break;
648 }
649 }
650
651 let mut token = String::new();
652 let mut nesting = Vec::new();
653 let mut done = false;
654 let mut quoted = false;
655 for (i, c) in iter {
656 if c == ' ' && nesting.is_empty() && !quoted {
657 done = true;
658 continue;
659 } else if done {
660 if let Some(nested) = nesting.last() {
661 return Some(Err(PosError {
662 pos: Some(self.pos + i),
663 source: anyhow!(
664 "command argument has unterminated open {}",
665 if nested == &'{' { "brace" } else { "bracket" }
666 ),
667 }));
668 }
669 let pos = self.pos;
670 self.pos += i;
671 self.inner = &self.inner[i..];
672 return Some(Ok((pos, token)));
673 } else if (c == '{' || c == '[') && !quoted {
674 nesting.push(c);
675 } else if (c == '}' || c == ']') && !quoted {
676 if let Some(nested) = nesting.last() {
677 if (nested == &'{' && c == '}') || (nested == &'[' && c == ']') {
678 nesting.pop();
679 } else {
680 return Some(Err(PosError {
681 pos: Some(self.pos + i),
682 source: anyhow!(
683 "command argument has unterminated open {}",
684 if nested == &'{' { "brace" } else { "bracket" }
685 ),
686 }));
687 }
688 } else {
689 return Some(Err(PosError {
690 pos: Some(self.pos + i),
691 source: anyhow!(
692 "command argument has unbalanced close {}",
693 if c == '}' { "brace" } else { "bracket" }
694 ),
695 }));
696 }
697 } else if c == '"' && nesting.is_empty() {
698 quoted = !quoted;
701 continue;
702 }
703 token.push(c);
704 }
705
706 if let Some(nested) = nesting.last() {
707 return Some(Err(PosError {
708 pos: Some(self.pos + self.inner.len() - 1),
709 source: anyhow!(
710 "command argument has unterminated open {}",
711 if nested == &'{' { "brace" } else { "bracket" }
712 ),
713 }));
714 }
715
716 if quoted {
717 return Some(Err(PosError {
718 pos: Some(self.pos),
719 source: anyhow!("command argument has unterminated open double quote",),
720 }));
721 }
722
723 self.inner = "";
724 if token.is_empty() {
725 None
726 } else {
727 Some(Ok((self.pos, token)))
728 }
729 }
730}
731
732#[derive(Debug, Clone)]
733pub struct ArgMap(BTreeMap<String, String>);
734
735impl ArgMap {
736 pub fn values_mut(&mut self) -> btree_map::ValuesMut<String, String> {
737 self.0.values_mut()
738 }
739
740 pub fn opt_string(&mut self, name: &str) -> Option<String> {
741 self.0.remove(name)
742 }
743
744 pub fn string(&mut self, name: &str) -> Result<String, anyhow::Error> {
745 self.opt_string(name)
746 .ok_or_else(|| anyhow!("missing {} parameter", name))
747 }
748
749 pub fn opt_parse<T>(&mut self, name: &str) -> Result<Option<T>, anyhow::Error>
750 where
751 T: FromStr,
752 T::Err: Error + Send + Sync + 'static,
753 {
754 match self.opt_string(name) {
755 Some(val) => {
756 let t = val
757 .parse()
758 .with_context(|| format!("parsing {} parameter", name))?;
759 Ok(Some(t))
760 }
761 None => Ok(None),
762 }
763 }
764
765 pub fn parse<T>(&mut self, name: &str) -> Result<T, anyhow::Error>
766 where
767 T: FromStr,
768 T::Err: Error + Send + Sync + 'static,
769 {
770 match self.opt_parse(name) {
771 Ok(None) => bail!("missing {} parameter", name),
772 Ok(Some(t)) => Ok(t),
773 Err(err) => Err(err),
774 }
775 }
776
777 pub fn opt_bool(&mut self, name: &str) -> Result<Option<bool>, anyhow::Error> {
778 self.opt_string(name)
779 .map(|val| {
780 if val == "true" {
781 Ok(true)
782 } else if val == "false" {
783 Ok(false)
784 } else {
785 bail!("bad value for boolean parameter {}: {}", name, val);
786 }
787 })
788 .transpose()
789 }
790
791 pub fn done(&self) -> Result<(), anyhow::Error> {
792 if let Some(name) = self.0.keys().next() {
793 bail!("unknown parameter {}", name);
794 }
795 Ok(())
796 }
797}
798
799impl IntoIterator for ArgMap {
800 type Item = (String, String);
801 type IntoIter = btree_map::IntoIter<String, String>;
802
803 fn into_iter(self) -> Self::IntoIter {
804 self.0.into_iter()
805 }
806}