Skip to main content

mz_testdrive/
parser.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::borrow::ToOwned;
11use std::collections::{BTreeMap, btree_map};
12use std::error::Error;
13use std::fmt::Write;
14use std::str::FromStr;
15use std::sync::LazyLock;
16
17use anyhow::{Context, anyhow, bail};
18use regex::Regex;
19
20use crate::error::PosError;
21
22#[derive(Debug, Clone)]
23pub struct PosCommand {
24    pub pos: usize,
25    pub command: Command,
26}
27
/// An inclusive `[min, max]` version range attached to a command through a
/// bracketed constraint such as `>[version>=5] ...`.
#[derive(Debug, Clone)]
pub struct VersionConstraint {
    /// Lowest matching version (inclusive).
    pub min: i32,
    /// Highest matching version (inclusive).
    pub max: i32,
}
34
35#[derive(Debug, Clone)]
36pub enum Command {
37    Builtin(BuiltinCommand, Option<VersionConstraint>),
38    Sql(SqlCommand, Option<VersionConstraint>),
39    FailSql(FailSqlCommand, Option<VersionConstraint>),
40}
41
42#[derive(Debug, Clone)]
43pub struct BuiltinCommand {
44    pub name: String,
45    pub args: ArgMap,
46    pub input: Vec<String>,
47}
48
49impl BuiltinCommand {
50    pub fn assert_no_input(&self) -> Result<(), anyhow::Error> {
51        if !self.input.is_empty() {
52            bail!("{} action does not take input", self.name);
53        }
54        Ok(())
55    }
56}
57
/// The output a SQL query is expected to produce.
#[derive(Debug, Clone)]
pub enum SqlOutput {
    /// The complete result set, row by row.
    Full {
        /// Column names, present when the expected output starts with a
        /// header line followed by a separator line of dashes.
        column_names: Option<Vec<String>>,
        /// Expected rows; each row is the list of its column values.
        expected_rows: Vec<Vec<String>>,
    },
    /// A digest of the result set, written as `<N> values hashing to <MD5>`.
    Hashed {
        /// Expected number of values.
        num_values: usize,
        /// Expected digest, as written in the test file.
        md5: String,
    },
}
69#[derive(Debug, Clone)]
70pub struct SqlCommand {
71    pub query: String,
72    pub expected_output: SqlOutput,
73    pub expected_start: usize,
74    pub expected_end: usize,
75}
76
77#[derive(Debug, Clone)]
78pub struct FailSqlCommand {
79    pub query: String,
80    pub expected_error: SqlExpectedError,
81    pub expected_detail: Option<String>,
82    pub expected_hint: Option<String>,
83}
84
/// The expected error of a failing query, chosen by the prefix of the
/// expected-error line. The matching itself happens outside this parser.
#[derive(Debug, Clone)]
pub enum SqlExpectedError {
    /// Text from a `contains:` line.
    Contains(String),
    /// Text from an `exact:` line.
    Exact(String),
    /// Pattern from a `regex:` line.
    Regex(String),
    /// The line was exactly `timeout`.
    Timeout,
}
92
93pub(crate) fn parse(line_reader: &mut LineReader) -> Result<Vec<PosCommand>, PosError> {
94    let mut out = Vec::new();
95    while let Some((pos, line)) = line_reader.peek() {
96        let pos = *pos;
97        let command = match line.chars().next() {
98            Some('$') => {
99                let version = parse_version_constraint(line_reader)?;
100                Command::Builtin(parse_builtin(line_reader)?, version)
101            }
102            Some('>') => {
103                let version = parse_version_constraint(line_reader)?;
104                Command::Sql(parse_sql(line_reader)?, version)
105            }
106            Some('?') => {
107                let version = parse_version_constraint(line_reader)?;
108                Command::Sql(parse_explain_sql(line_reader)?, version)
109            }
110            Some('!') => {
111                let version = parse_version_constraint(line_reader)?;
112                Command::FailSql(parse_fail_sql(line_reader)?, version)
113            }
114            Some('#') => {
115                // Comment line.
116                line_reader.next();
117                continue;
118            }
119            Some(x) => {
120                return Err(PosError {
121                    source: anyhow!(format!("unexpected input line at beginning of file: {}", x)),
122                    pos: Some(pos),
123                });
124            }
125            None => {
126                return Err(PosError {
127                    source: anyhow!("unexpected input line at beginning of file"),
128                    pos: Some(pos),
129                });
130            }
131        };
132        out.push(PosCommand { command, pos });
133    }
134    Ok(out)
135}
136
137fn parse_builtin(line_reader: &mut LineReader) -> Result<BuiltinCommand, PosError> {
138    let (pos, line) = line_reader.next().unwrap();
139    let mut builtin_reader = BuiltinReader::new(&line, pos);
140    let name = match builtin_reader.next() {
141        Some(Ok((_, s))) => s,
142        Some(Err(e)) => return Err(e),
143        None => {
144            return Err(PosError {
145                source: anyhow!("command line is missing command name"),
146                pos: Some(pos),
147            });
148        }
149    };
150    let mut args = BTreeMap::new();
151    for el in builtin_reader {
152        let (pos, token) = el?;
153        let pieces: Vec<_> = token.splitn(2, '=').collect();
154        let pieces = match pieces.as_slice() {
155            [key, value] => vec![*key, *value],
156            [key] => vec![*key, ""],
157            _ => {
158                return Err(PosError {
159                    source: anyhow!("command argument is not in required key=value format"),
160                    pos: Some(pos),
161                });
162            }
163        };
164        validate_ident(pieces[0]).map_err(|e| PosError::new(e, pos))?;
165
166        if let Some(original) = args.insert(pieces[0].to_owned(), pieces[1].to_owned()) {
167            return Err(PosError {
168                source: anyhow!(
169                    "argument '{}' specified twice: {} & {}",
170                    pieces[0],
171                    original,
172                    pieces[1]
173                ),
174                pos: Some(pos),
175            });
176        };
177    }
178    Ok(BuiltinCommand {
179        name,
180        args: ArgMap(args),
181        input: slurp_all(line_reader),
182    })
183}
184
185/// Validate that the string is an allowed variable name (lowercase letters, numbers and dashes)
186pub fn validate_ident(name: &str) -> Result<(), anyhow::Error> {
187    static VALID_KEY_REGEX: LazyLock<Regex> =
188        LazyLock::new(|| Regex::new("^[a-z0-9\\-]*$").unwrap());
189    if !VALID_KEY_REGEX.is_match(name) {
190        bail!(
191            "invalid builtin argument name '{}': \
192             only lowercase letters, numbers, and hyphens allowed",
193            name
194        );
195    }
196    Ok(())
197}
198
199fn parse_version_constraint(
200    line_reader: &mut LineReader,
201) -> Result<Option<VersionConstraint>, PosError> {
202    let (pos, line) = line_reader.next().unwrap();
203    if line[1..2].to_string() != "[" {
204        line_reader.push(&line);
205        return Ok(None);
206    }
207    let closed_brace_pos = match line.find(']') {
208        Some(x) => x,
209        None => {
210            return Err(PosError {
211                source: anyhow!("version-constraint: found no closing brace"),
212                pos: Some(pos),
213            });
214        }
215    };
216    let mut begin_version_kw = 2;
217    const MIN_VERSION: i32 = 0;
218    let mut min_version = MIN_VERSION;
219    if line.as_bytes()[2].is_ascii_digit() {
220        let Some(op_pos) = line.find('<') else {
221            return Err(PosError {
222                source: anyhow!("version-constraint: initial number but no '<' following"),
223                pos: Some(pos),
224            });
225        };
226        let min_version_str = line[2..op_pos].to_string();
227        match min_version_str.parse::<i32>() {
228            Ok(mv) => min_version = mv,
229            Err(_) => {
230                return Err(PosError {
231                    source: anyhow!(
232                        "version-constraint: invalid version number {}",
233                        min_version_str
234                    ),
235                    pos: Some(pos),
236                });
237            }
238        };
239
240        if line.as_bytes()[op_pos + 1] == b'=' {
241            begin_version_kw = op_pos + 2;
242        } else {
243            begin_version_kw = op_pos + 1;
244            min_version += 1;
245        }
246    };
247
248    let version_start = begin_version_kw + "version".len();
249    if line[begin_version_kw..version_start].to_string() != "version" {
250        return Err(PosError {
251            source: anyhow!(
252                "version-constraint: invalid property {} (found '{}', expected 'version' {begin_version_kw})",
253                &line[2..closed_brace_pos],
254                &line[begin_version_kw..version_start]
255            ),
256            pos: Some(pos),
257        });
258    }
259    let remainder = line[closed_brace_pos + 1..].to_string();
260    line_reader.push(&remainder);
261    const MAX_VERSION: i32 = 9999999;
262
263    if version_start >= closed_brace_pos && min_version != MIN_VERSION {
264        return Ok(Some(VersionConstraint {
265            min: min_version,
266            max: MAX_VERSION,
267        }));
268    }
269
270    let version_pos = if line.as_bytes()[version_start + 1].is_ascii_digit() {
271        version_start + 1
272    } else {
273        version_start + 2
274    };
275    let version = match line[version_pos..closed_brace_pos].parse::<i32>() {
276        Ok(x) => x,
277        Err(_) => {
278            return Err(PosError {
279                source: anyhow!(
280                    "version-constraint: invalid version number {}",
281                    &line[version_pos..closed_brace_pos]
282                ),
283                pos: Some(pos),
284            });
285        }
286    };
287
288    match &line[version_start..version_pos] {
289        "=" => Ok(Some(VersionConstraint {
290            min: version,
291            max: version,
292        })),
293        "<=" => Ok(Some(VersionConstraint {
294            min: min_version,
295            max: version,
296        })),
297        "<" => Ok(Some(VersionConstraint {
298            min: min_version,
299            max: version - 1,
300        })),
301        ">=" if min_version == MIN_VERSION => Ok(Some(VersionConstraint {
302            min: version,
303            max: MAX_VERSION,
304        })),
305        ">" if min_version == MIN_VERSION => Ok(Some(VersionConstraint {
306            min: version + 1,
307            max: MAX_VERSION,
308        })),
309        ">=" | ">" => Err(PosError {
310            source: anyhow!(
311                "version-constraint: found comparison operator {} with a set minimum version {min_version}",
312                &line[version_start..version_pos]
313            ),
314            pos: Some(pos),
315        }),
316        _ => Err(PosError {
317            source: anyhow!(
318                "version-constraint: unknown comparison operator {}",
319                &line[version_start..version_pos]
320            ),
321            pos: Some(pos),
322        }),
323    }
324}
325
326fn parse_sql(line_reader: &mut LineReader) -> Result<SqlCommand, PosError> {
327    let (_, line1) = line_reader.next().unwrap();
328    let query = line1[1..].trim().to_owned();
329    let expected_start = line_reader.consumed_raw_pos;
330    let line2 = slurp_one(line_reader);
331    let line3 = slurp_one(line_reader);
332    let mut column_names = None;
333    let mut expected_rows = Vec::new();
334    static HASH_REGEX: LazyLock<Regex> =
335        LazyLock::new(|| Regex::new(r"^(\S+) values hashing to (\S+)$").unwrap());
336    match (line2, line3) {
337        (Some((pos2, line2)), Some((pos3, line3))) => {
338            if line3.len() >= 3 && line3.chars().all(|c| c == '-') {
339                column_names = Some(split_line(pos2, &line2)?);
340            } else {
341                expected_rows.push(split_line(pos2, &line2)?);
342                expected_rows.push(split_line(pos3, &line3)?);
343            }
344        }
345        (Some((pos2, line2)), None) => match HASH_REGEX.captures(&line2) {
346            Some(captures) => match captures[1].parse::<usize>() {
347                Ok(num_values) => {
348                    return Ok(SqlCommand {
349                        query,
350                        expected_output: SqlOutput::Hashed {
351                            num_values,
352                            md5: captures[2].to_owned(),
353                        },
354                        expected_start: 0,
355                        expected_end: 0,
356                    });
357                }
358                Err(err) => {
359                    return Err(PosError {
360                        source: anyhow!("Error parsing number of expected rows: {}", err),
361                        pos: Some(pos2),
362                    });
363                }
364            },
365            None => expected_rows.push(split_line(pos2, &line2)?),
366        },
367        _ => (),
368    }
369    while let Some((pos, line)) = slurp_one(line_reader) {
370        expected_rows.push(split_line(pos, &line)?)
371    }
372    let expected_end = line_reader.consumed_raw_pos;
373    Ok(SqlCommand {
374        query,
375        expected_output: SqlOutput::Full {
376            column_names,
377            expected_rows,
378        },
379        expected_start,
380        expected_end,
381    })
382}
383
384fn parse_explain_sql(line_reader: &mut LineReader) -> Result<SqlCommand, PosError> {
385    let (_, line1) = line_reader.next().unwrap();
386    let expected_start = line_reader.consumed_raw_pos;
387    // This is a bit of a hack to extract the next chunk of the file with
388    // blank lines intact. Ideally the `LineReader` would expose the API we
389    // need directly, but that would require a large refactor.
390    let mut expected_output: String = line_reader
391        .inner
392        .lines()
393        .filter(|l| !matches!(l.chars().next(), Some('#')))
394        .take_while(|l| !is_sigil(l.chars().next()))
395        .fold(String::new(), |mut output, l| {
396            let _ = write!(output, "{}\n", l);
397            output
398        });
399    while expected_output.ends_with("\n\n") {
400        expected_output.pop();
401    }
402    // We parsed the multiline expected_output directly using line_reader.inner
403    // above.
404    slurp_all(line_reader);
405    let expected_end = line_reader.consumed_raw_pos;
406
407    Ok(SqlCommand {
408        query: line1[1..].trim().to_owned(),
409        expected_output: SqlOutput::Full {
410            column_names: None,
411            expected_rows: vec![vec![expected_output]],
412        },
413        expected_start,
414        expected_end,
415    })
416}
417
418fn parse_fail_sql(line_reader: &mut LineReader) -> Result<FailSqlCommand, PosError> {
419    let (pos, line1) = line_reader.next().unwrap();
420    let line2 = slurp_one(line_reader);
421    let (err_pos, expected_error) = match line2 {
422        Some((err_pos, line2)) => (err_pos, line2),
423        None => {
424            return Err(PosError {
425                pos: Some(pos),
426                source: anyhow!("failing SQL command is missing expected error message"),
427            });
428        }
429    };
430    let query = line1[1..].trim().to_string();
431
432    let expected_error = if let Some(e) = expected_error.strip_prefix("regex:") {
433        SqlExpectedError::Regex(e.trim().into())
434    } else if let Some(e) = expected_error.strip_prefix("contains:") {
435        SqlExpectedError::Contains(e.trim().into())
436    } else if let Some(e) = expected_error.strip_prefix("exact:") {
437        SqlExpectedError::Exact(e.trim().into())
438    } else if expected_error == "timeout" {
439        SqlExpectedError::Timeout
440    } else {
441        return Err(PosError {
442            pos: Some(err_pos),
443            source: anyhow!(
444                "Query error must start with match specifier (`regex:`|`contains:`|`exact:`|`timeout`)"
445            ),
446        });
447    };
448
449    let extra_error = |line_reader: &mut LineReader, prefix| {
450        if let Some((_pos, line)) = line_reader.peek() {
451            if let Some(_) = line.strip_prefix(prefix) {
452                let line = line_reader
453                    .next()
454                    .map(|(_, line)| line)
455                    .unwrap()
456                    .strip_prefix(prefix)
457                    .map(|line| line.to_string())
458                    .unwrap();
459                Some(line.trim().to_string())
460            } else {
461                None
462            }
463        } else {
464            None
465        }
466    };
467    // Expect `hint` to always follow `detail` if they are both present, for now.
468    let expected_detail = extra_error(line_reader, "detail:");
469    let expected_hint = extra_error(line_reader, "hint:");
470
471    Ok(FailSqlCommand {
472        query: query.trim().to_string(),
473        expected_error,
474        expected_detail,
475        expected_hint,
476    })
477}
478
479fn split_line(pos: usize, line: &str) -> Result<Vec<String>, PosError> {
480    let mut out = Vec::new();
481    let mut field = String::new();
482    let mut in_quotes = None;
483    let mut escaping = false;
484    for (i, c) in line.char_indices() {
485        if in_quotes.is_none() && c.is_whitespace() {
486            if !field.is_empty() {
487                out.push(field);
488                field = String::new();
489            }
490        } else if c == '"' && !escaping {
491            if in_quotes.is_none() {
492                in_quotes = Some(i)
493            } else {
494                in_quotes = None;
495                out.push(field);
496                field = String::new();
497            }
498        } else if c == '\\' && !escaping && in_quotes.is_some() {
499            escaping = true;
500        } else if escaping {
501            field.push(match c {
502                'n' => '\n',
503                't' => '\t',
504                'r' => '\r',
505                '0' => '\0',
506                c => c,
507            });
508            escaping = false;
509        } else {
510            field.push(c);
511        }
512    }
513    if let Some(i) = in_quotes {
514        return Err(PosError {
515            source: anyhow!("unterminated quote"),
516            pos: Some(pos + i),
517        });
518    }
519    if !field.is_empty() {
520        out.push(field);
521    }
522    Ok(out)
523}
524
525fn slurp_all(line_reader: &mut LineReader) -> Vec<String> {
526    let mut out = Vec::new();
527    while let Some((_, line)) = slurp_one(line_reader) {
528        out.push(line);
529    }
530    out
531}
532
533fn slurp_one(line_reader: &mut LineReader) -> Option<(usize, String)> {
534    while let Some((_, line)) = line_reader.peek() {
535        match line.chars().next() {
536            Some('#') => {
537                // Comment line. Skip.
538                let _ = line_reader.next();
539            }
540            Some('$') | Some('>') | Some('!') | Some('?') => return None,
541            Some('\\') => {
542                return line_reader.next().map(|(pos, mut line)| {
543                    line.remove(0);
544                    (pos, line)
545                });
546            }
547            _ => return line_reader.next(),
548        }
549    }
550    None
551}
552
/// Splits a testdrive script into logical lines and tracks positions for
/// error reporting.
///
/// A line starting with `$` absorbs following lines that begin with two
/// spaces, joining them with a single space. A line starting with `>`, `!` or
/// `?` absorbs them with the newline kept. Whitespace-only lines are skipped.
pub struct LineReader<'a> {
    /// Input not yet read.
    inner: &'a str,
    /// Line buffered by `peek` or `push`. `None` means nothing is buffered;
    /// `Some(None)` means end of input was peeked.
    // The nested option is intentional: the two `None` levels mean different
    // things.
    #[allow(clippy::option_option)]
    next: Option<Option<(usize, String)>>,

    /// Current 1-based source line number.
    src_line: usize,
    /// Position assigned to the next line read from `inner`.
    pos: usize,
    /// Maps positions to the `(line, column)` they begin at; consulted by
    /// `line_col`.
    pos_map: BTreeMap<usize, (usize, usize)>,
    /// Input byte offset advanced by `read_one` as lines are read, including
    /// lines that have only been peeked.
    raw_pos: usize,
    // Position one byte past the end of the most recently *consumed* line —
    // i.e. a line returned by an external call to `next()`. `peek()` reads a
    // line from `inner` and advances `raw_pos`, but the line is not consumed
    // until a caller takes it via `next()`. `consumed_raw_pos` therefore lags
    // `raw_pos` by one peeked-but-not-consumed line, which is what callers
    // need when slicing the input string at command boundaries.
    consumed_raw_pos: usize,
}
570
571impl<'a> LineReader<'a> {
572    pub fn new(inner: &'a str) -> LineReader<'a> {
573        let mut pos_map = BTreeMap::new();
574        pos_map.insert(0, (1, 1));
575        LineReader {
576            inner,
577            src_line: 1,
578            next: None,
579            pos: 0,
580            pos_map,
581            raw_pos: 0,
582            consumed_raw_pos: 0,
583        }
584    }
585
586    fn peek(&mut self) -> Option<&(usize, String)> {
587        if self.next.is_none() {
588            self.next = Some(self.read_one())
589        }
590        self.next.as_ref().unwrap().as_ref()
591    }
592
593    pub fn line_col(&self, pos: usize) -> (usize, usize) {
594        let (base_pos, (line, col)) = self.pos_map.range(..=pos).next_back().unwrap();
595        (*line, col + (pos - base_pos))
596    }
597
598    fn push(&mut self, text: &String) {
599        self.next = Some(Some((0usize, text.to_string())));
600    }
601
602    fn read_one(&mut self) -> Option<(usize, String)> {
603        if self.inner.is_empty() {
604            return None;
605        }
606        let mut fold_newlines = is_non_sql_sigil(self.inner.chars().next());
607        let mut handle_newlines = is_sql_sigil(self.inner.chars().next());
608        let mut line = String::new();
609        let mut chars = self.inner.char_indices().fuse().peekable();
610        while let Some((i, c)) = chars.next() {
611            if c == '\n' {
612                self.src_line += 1;
613                if fold_newlines && self.inner.get(i + 1..i + 3) == Some("  ") {
614                    // Chomp the newline and one space. This ensures a SQL query
615                    // that is split over two lines does not become invalid. For $ commands the
616                    // newline should not be removed so that the argument parser can handle the
617                    // arguments correctly.
618                    chars.next();
619                    self.pos_map.insert(self.pos + i, (self.src_line, 2));
620                    continue;
621                } else if handle_newlines && self.inner.get(i + 1..i + 3) == Some("  ") {
622                    // Chomp the two spaces after newline. This ensures a SQL query
623                    // that is split over two lines does not become invalid, and keeping the
624                    // newline ensures that comments don't remove the following lines.
625                    line.push(c);
626                    chars.next();
627                    chars.next();
628                    self.pos_map.insert(self.pos + i + 1, (self.src_line, 2));
629                    continue;
630                } else if line.chars().all(char::is_whitespace) {
631                    line.clear();
632                    fold_newlines = is_non_sql_sigil(chars.peek().map(|c| c.1));
633                    handle_newlines = is_sql_sigil(chars.peek().map(|c| c.1));
634                    self.pos_map.insert(self.pos, (self.src_line, 1));
635                    continue;
636                }
637                let pos = self.pos;
638                self.pos += i;
639                self.raw_pos += i + 1; // Include \n character in count
640                self.pos_map.insert(self.pos, (self.src_line, 1));
641                self.inner = &self.inner[i + 1..];
642                return Some((pos, line));
643            }
644            line.push(c)
645        }
646        self.inner = "";
647        if !line.chars().all(char::is_whitespace) {
648            Some((self.pos, line))
649        } else {
650            None
651        }
652    }
653}
654
655impl<'a> Iterator for LineReader<'a> {
656    type Item = (usize, String);
657
658    fn next(&mut self) -> Option<Self::Item> {
659        let item = if let Some(next) = self.next.take() {
660            next
661        } else {
662            self.read_one()
663        };
664        if item.is_some() {
665            // Sync `consumed_raw_pos` with `raw_pos` only when the line is
666            // actually returned to an external caller. `peek()` advances
667            // `raw_pos` via `read_one()` without consuming, so updating
668            // `consumed_raw_pos` here keeps it pointing at the end of the most
669            // recently consumed line — even if a later line has been peeked.
670            self.consumed_raw_pos = self.raw_pos;
671        }
672        item
673    }
674}
675
/// Returns whether `c` starts a new command of any kind.
fn is_sigil(c: Option<char>) -> bool {
    is_non_sql_sigil(c) || is_sql_sigil(c)
}

/// Returns whether `c` is a sigil that introduces a SQL command: `>`, `!` or
/// `?`.
fn is_sql_sigil(c: Option<char>) -> bool {
    c.is_some_and(|ch| ">!?".contains(ch))
}

/// Returns whether `c` is the `$` sigil that introduces a builtin command.
fn is_non_sql_sigil(c: Option<char>) -> bool {
    c == Some('$')
}
687
/// Tokenizer over the text of a `$` builtin command line, yielding the
/// command name and then each argument.
struct BuiltinReader<'a> {
    /// Text not yet tokenized.
    inner: &'a str,
    /// Position used when reporting tokens and errors; advanced as tokens are
    /// consumed.
    pos: usize,
}
692
693impl<'a> BuiltinReader<'a> {
694    fn new(line: &str, pos: usize) -> BuiltinReader<'_> {
695        BuiltinReader {
696            inner: &line[1..],
697            pos,
698        }
699    }
700}
701
impl<'a> Iterator for BuiltinReader<'a> {
    type Item = Result<(usize, String), PosError>;

    /// Returns the next space-separated token as `(position, token)`, or
    /// `None` once only spaces remain.
    ///
    /// Only the space character separates tokens. Spaces inside a `{...}` or
    /// `[...]` group, or inside double quotes, do not. Double quotes outside
    /// any group toggle quoting and are removed from the token; quotes inside
    /// a group are kept verbatim. Errors are returned for:
    /// * a close brace/bracket with no matching open one;
    /// * a close that does not match the innermost open (reported as the
    ///   open one being unterminated);
    /// * a group left open at the end of input;
    /// * an unterminated double quote.
    fn next(&mut self) -> Option<Self::Item> {
        if self.inner.is_empty() {
            return None;
        }

        let mut iter = self.inner.char_indices().peekable();

        // Skip leading spaces and move `self.pos` forward to the token start.
        // NOTE(review): the `i` used further down is also relative to the
        // start of `self.inner`, so on the first call these skipped spaces
        // are counted into `self.pos` twice. On later calls `self.inner`
        // already starts at a token, so this adds 0. Confirm whether the
        // resulting positions are intended.
        while let Some((i, c)) = iter.peek() {
            if c == &' ' {
                iter.next();
            } else {
                self.pos += i;
                break;
            }
        }

        let mut token = String::new();
        // Stack of currently open '{' / '[' characters.
        let mut nesting = Vec::new();
        // Set once an unnested, unquoted space ends the token; the next
        // non-space character is the start of the following token.
        let mut done = false;
        let mut quoted = false;
        for (i, c) in iter {
            if c == ' ' && nesting.is_empty() && !quoted {
                done = true;
                continue;
            } else if done {
                // NOTE(review): `done` is only set while `nesting` is empty,
                // and the very next non-space character lands here, so this
                // check appears unreachable.
                if let Some(nested) = nesting.last() {
                    return Some(Err(PosError {
                        pos: Some(self.pos + i),
                        source: anyhow!(
                            "command argument has unterminated open {}",
                            if nested == &'{' { "brace" } else { "bracket" }
                        ),
                    }));
                }
                // Emit the finished token and resume at the next one.
                let pos = self.pos;
                self.pos += i;
                self.inner = &self.inner[i..];
                return Some(Ok((pos, token)));
            } else if (c == '{' || c == '[') && !quoted {
                nesting.push(c);
            } else if (c == '}' || c == ']') && !quoted {
                if let Some(nested) = nesting.last() {
                    if (nested == &'{' && c == '}') || (nested == &'[' && c == ']') {
                        nesting.pop();
                    } else {
                        // Mismatched close: report the innermost open group.
                        return Some(Err(PosError {
                            pos: Some(self.pos + i),
                            source: anyhow!(
                                "command argument has unterminated open {}",
                                if nested == &'{' { "brace" } else { "bracket" }
                            ),
                        }));
                    }
                } else {
                    return Some(Err(PosError {
                        pos: Some(self.pos + i),
                        source: anyhow!(
                            "command argument has unbalanced close {}",
                            if c == '}' { "brace" } else { "bracket" }
                        ),
                    }));
                }
            } else if c == '"' && nesting.is_empty() {
                // remove the double quote for un-nested commands such as: command="\dt public"
                // keep the quotes when inside of a nested object such as: schema={ "type" : "array" }
                quoted = !quoted;
                continue;
            }
            token.push(c);
        }

        // Input ended inside a group: point at the last remaining character.
        if let Some(nested) = nesting.last() {
            return Some(Err(PosError {
                pos: Some(self.pos + self.inner.len() - 1),
                source: anyhow!(
                    "command argument has unterminated open {}",
                    if nested == &'{' { "brace" } else { "bracket" }
                ),
            }));
        }

        if quoted {
            return Some(Err(PosError {
                pos: Some(self.pos),
                source: anyhow!("command argument has unterminated open double quote",),
            }));
        }

        // Last token on the line: nothing is left to read afterwards.
        self.inner = "";
        if token.is_empty() {
            None
        } else {
            Some(Ok((self.pos, token)))
        }
    }
}
801
/// The `key=value` arguments of a builtin command, keyed by argument name.
///
/// The accessors remove each argument they read, so [`ArgMap::done`] can
/// report any argument the command did not recognize.
#[derive(Debug, Clone)]
pub struct ArgMap(BTreeMap<String, String>);
804
805impl ArgMap {
806    pub fn values_mut(&mut self) -> btree_map::ValuesMut<'_, String, String> {
807        self.0.values_mut()
808    }
809
810    pub fn opt_string(&mut self, name: &str) -> Option<String> {
811        self.0.remove(name)
812    }
813
814    pub fn string(&mut self, name: &str) -> Result<String, anyhow::Error> {
815        self.opt_string(name)
816            .ok_or_else(|| anyhow!("missing {} parameter", name))
817    }
818
819    pub fn opt_parse<T>(&mut self, name: &str) -> Result<Option<T>, anyhow::Error>
820    where
821        T: FromStr,
822        T::Err: Error + Send + Sync + 'static,
823    {
824        match self.opt_string(name) {
825            Some(val) => {
826                let t = val
827                    .parse()
828                    .with_context(|| format!("parsing {} parameter", name))?;
829                Ok(Some(t))
830            }
831            None => Ok(None),
832        }
833    }
834
835    pub fn parse<T>(&mut self, name: &str) -> Result<T, anyhow::Error>
836    where
837        T: FromStr,
838        T::Err: Error + Send + Sync + 'static,
839    {
840        match self.opt_parse(name) {
841            Ok(None) => bail!("missing {} parameter", name),
842            Ok(Some(t)) => Ok(t),
843            Err(err) => Err(err),
844        }
845    }
846
847    pub fn opt_bool(&mut self, name: &str) -> Result<Option<bool>, anyhow::Error> {
848        self.opt_string(name)
849            .map(|val| {
850                if val == "true" {
851                    Ok(true)
852                } else if val == "false" {
853                    Ok(false)
854                } else {
855                    bail!("bad value for boolean parameter {}: {}", name, val);
856                }
857            })
858            .transpose()
859    }
860
861    pub fn done(&self) -> Result<(), anyhow::Error> {
862        if let Some(name) = self.0.keys().next() {
863            bail!("unknown parameter {}", name);
864        }
865        Ok(())
866    }
867}
868
869impl IntoIterator for ArgMap {
870    type Item = (String, String);
871    type IntoIter = btree_map::IntoIter<String, String>;
872
873    fn into_iter(self) -> Self::IntoIter {
874        self.0.into_iter()
875    }
876}