mz_testdrive/action/
set.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::cmp;
11
12use anyhow::{Context, bail};
13use regex::Regex;
14use tokio::fs;
15
16use crate::action::{ControlFlow, State};
17use crate::parser::BuiltinCommand;
18
19pub const DEFAULT_REGEX_REPLACEMENT: &str = "<regex_match>";
20
21pub fn run_regex_set(
22    mut cmd: BuiltinCommand,
23    state: &mut State,
24) -> Result<ControlFlow, anyhow::Error> {
25    let regex: Regex = cmd.args.parse("match")?;
26    let replacement = cmd
27        .args
28        .opt_string("replacement")
29        .unwrap_or_else(|| DEFAULT_REGEX_REPLACEMENT.into());
30    cmd.args.done()?;
31
32    state.regex = Some(regex);
33    state.regex_replacement = replacement;
34    Ok(ControlFlow::Continue)
35}
36
37pub fn run_regex_unset(
38    cmd: BuiltinCommand,
39    state: &mut State,
40) -> Result<ControlFlow, anyhow::Error> {
41    cmd.args.done()?;
42    state.regex = None;
43    state.regex_replacement = DEFAULT_REGEX_REPLACEMENT.to_string();
44    Ok(ControlFlow::Continue)
45}
46
47pub fn run_sql_timeout(
48    mut cmd: BuiltinCommand,
49    state: &mut State,
50) -> Result<ControlFlow, anyhow::Error> {
51    let duration = cmd.args.string("duration")?;
52    let duration = if duration.to_lowercase() == "default" {
53        None
54    } else {
55        Some(humantime::parse_duration(&duration).context("parsing duration")?)
56    };
57    let force = cmd.args.opt_bool("force")?.unwrap_or(false);
58    cmd.args.done()?;
59    state.timeout = duration.unwrap_or(state.default_timeout);
60    if !force {
61        // Bump the timeout to be at least the default timeout unless the
62        // timeout has been forced.
63        state.timeout = cmp::max(state.timeout, state.default_timeout);
64    }
65    Ok(ControlFlow::Continue)
66}
67
68pub fn run_max_tries(
69    mut cmd: BuiltinCommand,
70    state: &mut State,
71) -> Result<ControlFlow, anyhow::Error> {
72    let max_tries = cmd.args.string("max-tries")?;
73    cmd.args.done()?;
74    state.max_tries = max_tries.parse::<usize>()?;
75    Ok(ControlFlow::Continue)
76}
77
78pub fn run_set_arg_default(
79    cmd: BuiltinCommand,
80    state: &mut State,
81) -> Result<ControlFlow, anyhow::Error> {
82    for (key, val) in cmd.args {
83        let arg_key = format!("arg.{key}");
84        state.cmd_vars.entry(arg_key).or_insert(val);
85    }
86
87    Ok(ControlFlow::Continue)
88}
89
90pub fn set_vars(cmd: BuiltinCommand, state: &mut State) -> Result<ControlFlow, anyhow::Error> {
91    for (key, val) in cmd.args {
92        if val.is_empty() {
93            state.cmd_vars.insert(key, cmd.input.join("\n"));
94        } else {
95            state.cmd_vars.insert(key, val);
96        }
97    }
98
99    Ok(ControlFlow::Continue)
100}
101
102pub async fn run_set_from_sql(
103    mut cmd: BuiltinCommand,
104    state: &mut State,
105) -> Result<ControlFlow, anyhow::Error> {
106    let var = cmd.args.string("var")?;
107    cmd.args.done()?;
108
109    let row = state
110        .materialize
111        .pgclient
112        .query_one(&cmd.input.join("\n"), &[])
113        .await
114        .context("running query")?;
115    if row.columns().len() != 1 {
116        bail!(
117            "set-from-sql query must return exactly one column, but it returned {}",
118            row.columns().len()
119        );
120    }
121    let value: String = row.try_get(0).context("deserializing value as string")?;
122
123    state.cmd_vars.insert(var, value);
124
125    Ok(ControlFlow::Continue)
126}
127
128pub async fn run_set_from_file(
129    cmd: BuiltinCommand,
130    state: &mut State,
131) -> Result<ControlFlow, anyhow::Error> {
132    for (key, path) in cmd.args {
133        println!("Setting {} to contents of {}...", key, path);
134        let contents = fs::read_to_string(&path)
135            .await
136            .with_context(|| format!("reading {path}"))?;
137        state.cmd_vars.insert(key, contents);
138    }
139    Ok(ControlFlow::Continue)
140}