mz_testdrive/action/postgres/
execute.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use anyhow::{Context, anyhow, bail};
11use mz_ore::task;
12use tokio_postgres::Client;
13
14use crate::action::{ControlFlow, State};
15use crate::parser::BuiltinCommand;
16use crate::util::postgres::postgres_client;
17
18async fn execute_input(cmd: BuiltinCommand, client: &Client) -> Result<(), anyhow::Error> {
19    for query in cmd.input {
20        println!(">> {}", query);
21        client
22            .batch_execute(&query)
23            .await
24            .context("executing postgres query")?;
25    }
26    Ok(())
27}
28
29pub async fn run_execute(
30    mut cmd: BuiltinCommand,
31    state: &State,
32) -> Result<ControlFlow, anyhow::Error> {
33    let connection = cmd.args.string("connection")?;
34    let background = cmd.args.opt_bool("background")?.unwrap_or(false);
35    cmd.args.done()?;
36
37    match (connection.starts_with("postgres://"), background) {
38        (true, true) => {
39            let (client_inner, _) = postgres_client(&connection, state.default_timeout).await?;
40            task::spawn(|| "postgres-execute", async move {
41                match execute_input(cmd, &client_inner).await {
42                    Ok(_) => {}
43                    Err(e) => println!("Error in backgrounded postgres-execute query: {e}"),
44                }
45            });
46        }
47        (false, true) => bail!("cannot use 'background' arg with referenced connection"),
48        (true, false) => {
49            let (client_inner, _) = postgres_client(&connection, state.default_timeout).await?;
50            execute_input(cmd, &client_inner).await?;
51        }
52        (false, false) => {
53            let client = state
54                .postgres_clients
55                .get(&connection)
56                .ok_or_else(|| anyhow!("connection '{}' not found", &connection))?;
57            execute_input(cmd, client).await?;
58        }
59    }
60
61    Ok(ControlFlow::Continue)
62}