mz_testdrive/action/duckdb/query.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use anyhow::{Context, anyhow, bail};
11use duckdb::types::ValueRef;
12
13use crate::action::duckdb::get_or_create_connection;
14use crate::action::{ControlFlow, State};
15use crate::parser::BuiltinCommand;
16
17pub async fn run_query(
18    mut cmd: BuiltinCommand,
19    state: &mut State,
20) -> Result<ControlFlow, anyhow::Error> {
21    let name = cmd.args.string("name")?;
22    let sort_rows = cmd.args.opt_bool("sort-rows")?.unwrap_or(false);
23    cmd.args.done()?;
24
25    // First line is the query, remaining lines are expected output
26    let mut lines = cmd.input.into_iter();
27    let query = lines
28        .next()
29        .ok_or_else(|| anyhow!("duckdb-query requires a query as the first input line"))?;
30    let mut expected_rows: Vec<String> = lines.collect();
31
32    let conn = get_or_create_connection(state, name).await?;
33
34    let mut actual_rows = mz_ore::task::spawn_blocking(
35        || "duckdb_query".to_string(),
36        move || {
37            let conn = conn.lock().map_err(|e| anyhow!("lock poisoned: {}", e))?;
38            println!(">> {}", query);
39            let mut stmt = conn.prepare(&query).context("preparing DuckDB query")?;
40            let mut rows = stmt.query([]).context("executing DuckDB query")?;
41
42            let mut result = Vec::new();
43            while let Some(row) = rows.next()? {
44                // Get column count from the row's statement
45                let column_count = row.as_ref().column_count();
46                let mut row_values = Vec::with_capacity(column_count);
47                for i in 0..column_count {
48                    let val = row.get_ref(i)?;
49                    let formatted = format_value(&val);
50                    row_values.push(formatted);
51                }
52                result.push(row_values.join(" "));
53            }
54            Ok::<_, anyhow::Error>(result)
55        },
56    )
57    .await?;
58
59    if sort_rows {
60        expected_rows.sort();
61        actual_rows.sort();
62    }
63
64    if actual_rows != expected_rows {
65        bail!(
66            "DuckDB query result mismatch\nexpected ({} rows):\n{}\n\nactual ({} rows):\n{}",
67            expected_rows.len(),
68            expected_rows.join("\n"),
69            actual_rows.len(),
70            actual_rows.join("\n")
71        );
72    }
73
74    Ok(ControlFlow::Continue)
75}
76
77fn format_value(val: &ValueRef) -> String {
78    match val {
79        ValueRef::Null => "<null>".to_string(),
80        ValueRef::Boolean(b) => b.to_string(),
81        ValueRef::TinyInt(i) => i.to_string(),
82        ValueRef::SmallInt(i) => i.to_string(),
83        ValueRef::Int(i) => i.to_string(),
84        ValueRef::BigInt(i) => i.to_string(),
85        ValueRef::HugeInt(i) => i.to_string(),
86        ValueRef::UTinyInt(i) => i.to_string(),
87        ValueRef::USmallInt(i) => i.to_string(),
88        ValueRef::UInt(i) => i.to_string(),
89        ValueRef::UBigInt(i) => i.to_string(),
90        ValueRef::Float(f) => f.to_string(),
91        ValueRef::Double(f) => f.to_string(),
92        ValueRef::Text(bytes) => String::from_utf8_lossy(bytes).to_string(),
93        ValueRef::Blob(bytes) => format!("{:?}", bytes),
94        _ => format!("{:?}", val),
95    }
96}