Skip to main content

mz_testdrive/action/
set.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::cmp;
11
12use anyhow::{Context, bail};
13use mz_postgres_util::query_one_prepared;
14use regex::Regex;
15use tokio::fs;
16
17use crate::action::{ControlFlow, State};
18use crate::parser::BuiltinCommand;
19
/// Replacement text used by [`run_regex_set`] when no `replacement` argument
/// is supplied, and restored by [`run_regex_unset`].
pub const DEFAULT_REGEX_REPLACEMENT: &str = "<regex_match>";
21
22pub fn run_regex_set(
23    mut cmd: BuiltinCommand,
24    state: &mut State,
25) -> Result<ControlFlow, anyhow::Error> {
26    let regex: Regex = cmd.args.parse("match")?;
27    let replacement = cmd
28        .args
29        .opt_string("replacement")
30        .unwrap_or_else(|| DEFAULT_REGEX_REPLACEMENT.into());
31    cmd.args.done()?;
32
33    state.regex = Some(regex);
34    state.regex_replacement = replacement;
35    Ok(ControlFlow::Continue)
36}
37
38pub fn run_regex_unset(
39    cmd: BuiltinCommand,
40    state: &mut State,
41) -> Result<ControlFlow, anyhow::Error> {
42    cmd.args.done()?;
43    state.regex = None;
44    state.regex_replacement = DEFAULT_REGEX_REPLACEMENT.to_string();
45    Ok(ControlFlow::Continue)
46}
47
48pub fn run_sql_timeout(
49    mut cmd: BuiltinCommand,
50    state: &mut State,
51) -> Result<ControlFlow, anyhow::Error> {
52    let duration = cmd.args.string("duration")?;
53    let duration = if duration.to_lowercase() == "default" {
54        None
55    } else {
56        Some(humantime::parse_duration(&duration).context("parsing duration")?)
57    };
58    let force = cmd.args.opt_bool("force")?.unwrap_or(false);
59    cmd.args.done()?;
60    state.timeout = duration.unwrap_or(state.default_timeout);
61    if !force {
62        // Bump the timeout to be at least the default timeout unless the
63        // timeout has been forced.
64        state.timeout = cmp::max(state.timeout, state.default_timeout);
65    }
66    Ok(ControlFlow::Continue)
67}
68
69pub fn run_max_tries(
70    mut cmd: BuiltinCommand,
71    state: &mut State,
72) -> Result<ControlFlow, anyhow::Error> {
73    let max_tries = cmd.args.string("max-tries")?;
74    cmd.args.done()?;
75    state.max_tries = max_tries.parse::<usize>()?;
76    Ok(ControlFlow::Continue)
77}
78
79pub fn run_set_arg_default(
80    cmd: BuiltinCommand,
81    state: &mut State,
82) -> Result<ControlFlow, anyhow::Error> {
83    for (key, val) in cmd.args {
84        let arg_key = format!("arg.{key}");
85        state.cmd_vars.entry(arg_key).or_insert(val);
86    }
87
88    Ok(ControlFlow::Continue)
89}
90
91pub fn set_vars(cmd: BuiltinCommand, state: &mut State) -> Result<ControlFlow, anyhow::Error> {
92    for (key, val) in cmd.args {
93        if val.is_empty() {
94            state.cmd_vars.insert(key, cmd.input.join("\n"));
95        } else {
96            state.cmd_vars.insert(key, val);
97        }
98    }
99
100    Ok(ControlFlow::Continue)
101}
102
103pub async fn run_set_from_sql(
104    mut cmd: BuiltinCommand,
105    state: &mut State,
106) -> Result<ControlFlow, anyhow::Error> {
107    let var = cmd.args.string("var")?;
108    cmd.args.done()?;
109
110    let query = cmd.input.join("\n");
111    let statement = state
112        .materialize
113        .pgclient
114        .prepare(&query)
115        .await
116        .context("preparing query")?;
117    let row = query_one_prepared(&state.materialize.pgclient, &statement, &[])
118        .await
119        .context("running query")?;
120    if row.columns().len() != 1 {
121        bail!(
122            "set-from-sql query must return exactly one column, but it returned {}",
123            row.columns().len()
124        );
125    }
126    let value: String = row.try_get(0).context("deserializing value as string")?;
127
128    state.cmd_vars.insert(var, value);
129
130    Ok(ControlFlow::Continue)
131}
132
133pub async fn run_set_from_file(
134    cmd: BuiltinCommand,
135    state: &mut State,
136) -> Result<ControlFlow, anyhow::Error> {
137    for (key, path) in cmd.args {
138        println!("Setting {} to contents of {}...", key, path);
139        let contents = fs::read_to_string(&path)
140            .await
141            .with_context(|| format!("reading {path}"))?;
142        state.cmd_vars.insert(key, contents);
143    }
144    Ok(ControlFlow::Continue)
145}