mz_testdrive/action/
duckdb.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10mod execute;
11mod query;
12
13pub use execute::run_execute;
14pub use query::run_query;
15
16use std::path::PathBuf;
17use std::sync::{Arc, Mutex};
18
19use anyhow::Context;
20use duckdb::Connection;
21
22use crate::action::State;
23
24/// Gets or creates a DuckDB connection with the given name.
25pub(crate) async fn get_or_create_connection(
26    state: &mut State,
27    name: String,
28) -> Result<Arc<Mutex<Connection>>, anyhow::Error> {
29    if let Some(conn) = state.duckdb_clients.get(&name) {
30        return Ok(Arc::clone(conn));
31    }
32
33    let temp_path = state.temp_path.clone();
34    let conn = create_connection(temp_path).await?;
35    let conn = Arc::new(Mutex::new(conn));
36    state.duckdb_clients.insert(name, Arc::clone(&conn));
37    Ok(conn)
38}
39
40async fn create_connection(temp_path: PathBuf) -> Result<Connection, anyhow::Error> {
41    mz_ore::task::spawn_blocking(
42        || "duckdb_connect".to_string(),
43        move || {
44            let conn =
45                Connection::open_in_memory().context("opening in-memory DuckDB connection")?;
46
47            let ext_dir = temp_path.join("duckdb_extensions");
48            conn.execute(
49                &format!("SET extension_directory = '{}';", ext_dir.display()),
50                [],
51            )
52            .context("setting extension_directory")?;
53
54            conn.execute_batch("INSTALL iceberg; LOAD iceberg;")
55                .context("installing iceberg extension")?;
56            conn.execute_batch("INSTALL httpfs; LOAD httpfs;")
57                .context("installing httpfs extension")?;
58
59            Ok::<_, anyhow::Error>(conn)
60        },
61    )
62    .await
63}