Skip to main content

mz_testdrive/action/postgres/
execute.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use anyhow::{Context, anyhow, bail};
11use mz_ore::task;
12use tokio_postgres::Client;
13
14use crate::action::{ControlFlow, State};
15use crate::parser::BuiltinCommand;
16use crate::util::postgres::postgres_client;
17
18async fn execute_input(cmd: BuiltinCommand, client: &Client) -> Result<(), anyhow::Error> {
19    for query in cmd.input {
20        println!(">> {}", query);
21        // `query` is raw SQL from testdrive input and may contain multiple
22        // statements; this command intentionally forwards it verbatim.
23        #[allow(clippy::disallowed_methods)]
24        client
25            .batch_execute(&query)
26            .await
27            .context("executing postgres query")?;
28    }
29    Ok(())
30}
31
32pub async fn run_execute(
33    mut cmd: BuiltinCommand,
34    state: &State,
35) -> Result<ControlFlow, anyhow::Error> {
36    let connection = cmd.args.string("connection")?;
37    let background = cmd.args.opt_bool("background")?.unwrap_or(false);
38    cmd.args.done()?;
39
40    match (connection.starts_with("postgres://"), background) {
41        (true, true) => {
42            let (client_inner, _) = postgres_client(&connection, state.default_timeout).await?;
43            task::spawn(|| "postgres-execute", async move {
44                match execute_input(cmd, &client_inner).await {
45                    Ok(_) => {}
46                    Err(e) => println!("Error in backgrounded postgres-execute query: {e}"),
47                }
48            });
49        }
50        (false, true) => bail!("cannot use 'background' arg with referenced connection"),
51        (true, false) => {
52            let (client_inner, _) = postgres_client(&connection, state.default_timeout).await?;
53            execute_input(cmd, &client_inner).await?;
54        }
55        (false, false) => {
56            let client = state
57                .postgres_clients
58                .get(&connection)
59                .ok_or_else(|| anyhow!("connection '{}' not found", &connection))?;
60            execute_input(cmd, client).await?;
61        }
62    }
63
64    Ok(ControlFlow::Continue)
65}