mz_testdrive/action/s3.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10// The metadata for Arrow `Field` type requires `std::collections::HashMap`, which is disallowed.
11#[allow(clippy::disallowed_types)]
12use std::collections::HashMap;
13use std::pin::Pin;
14use std::str;
15use std::sync::Arc;
16use std::thread;
17use std::time::Duration;
18
19use anyhow::Context;
20use anyhow::bail;
21use arrow::array::{
22    ArrayRef, BinaryBuilder, BooleanArray, Date32Array, Decimal128Array, FixedSizeBinaryBuilder,
23    Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, Int64Builder,
24    ListBuilder, StringArray, StructArray, Time32SecondArray, TimestampMillisecondArray,
25    UInt8Array, UInt16Array, UInt32Array, UInt64Array,
26};
27use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
28use arrow::record_batch::RecordBatch;
29use arrow::util::display::ArrayFormatter;
30use arrow::util::display::FormatOptions;
31use async_compression::tokio::bufread::{BzEncoder, GzipEncoder, XzEncoder, ZstdEncoder};
32use chrono::{NaiveDate, NaiveDateTime, NaiveTime, Timelike};
33use parquet::arrow::ArrowWriter;
34use parquet::basic::{BrotliLevel, Compression as ParquetCompression, GzipLevel, ZstdLevel};
35use parquet::file::properties::WriterProperties;
36use regex::Regex;
37use tokio::io::{AsyncRead, AsyncReadExt};
38
39use crate::action::file::Compression;
40use crate::action::file::build_compression;
41use crate::action::file::build_contents;
42use crate::action::{ControlFlow, State};
43use crate::parser::BuiltinCommand;
44
45pub async fn run_verify_data(
46    mut cmd: BuiltinCommand,
47    state: &State,
48) -> Result<ControlFlow, anyhow::Error> {
49    let mut expected_body = cmd
50        .input
51        .into_iter()
52        // Strip suffix to allow lines with trailing whitespace
53        .map(|line| {
54            line.trim_end_matches("// allow-trailing-whitespace")
55                .to_string()
56        })
57        .collect::<Vec<String>>();
58    let bucket: String = cmd.args.parse("bucket")?;
59    let key: String = cmd.args.parse("key")?;
60    let sort_rows = cmd.args.opt_bool("sort-rows")?.unwrap_or(false);
61    cmd.args.done()?;
62
63    println!("Verifying contents of S3 bucket {bucket} key {key}...");
64
65    let client = mz_aws_util::s3::new_client(&state.aws_config);
66
67    // List the path until the INCOMPLETE sentinel file disappears so we know the
68    // data is complete.
69    let mut attempts = 0;
70    let all_files;
71    loop {
72        attempts += 1;
73        if attempts > 10 {
74            bail!("found incomplete sentinel file in path {key} after 10 attempts")
75        }
76
77        let files = client
78            .list_objects_v2()
79            .bucket(&bucket)
80            .prefix(&format!("{}/", key))
81            .send()
82            .await?;
83        match files.contents {
84            Some(files)
85                if files
86                    .iter()
87                    .any(|obj| obj.key().map_or(false, |key| key.contains("INCOMPLETE"))) =>
88            {
89                thread::sleep(Duration::from_secs(1))
90            }
91            None => bail!("no files found in bucket {bucket} key {key}"),
92            Some(files) => {
93                all_files = files;
94                break;
95            }
96        }
97    }
98
99    let mut rows = vec![];
100    for obj in all_files.iter() {
101        let file = client
102            .get_object()
103            .bucket(&bucket)
104            .key(obj.key().unwrap())
105            .send()
106            .await?;
107        let bytes = file.body.collect().await?.into_bytes();
108
109        let new_rows = match obj.key().unwrap() {
110            key if key.ends_with(".csv") => {
111                let actual_body = str::from_utf8(bytes.as_ref())?;
112                actual_body.lines().map(|l| l.to_string()).collect()
113            }
114            key if key.ends_with(".parquet") => rows_from_parquet(bytes),
115            key => bail!("unexpected file type: {key}"),
116        };
117        rows.extend(new_rows);
118    }
119    if sort_rows {
120        expected_body.sort();
121        rows.sort();
122    }
123    if rows != expected_body {
124        bail!(
125            "content did not match\nexpected:\n{:?}\n\nactual:\n{:?}",
126            expected_body,
127            rows
128        );
129    }
130
131    Ok(ControlFlow::Continue)
132}
133
134pub async fn run_verify_keys(
135    mut cmd: BuiltinCommand,
136    state: &State,
137) -> Result<ControlFlow, anyhow::Error> {
138    let bucket: String = cmd.args.parse("bucket")?;
139    let prefix_path: String = cmd.args.parse("prefix-path")?;
140    let key_pattern: Regex = cmd.args.parse("key-pattern")?;
141    let num_attempts = cmd.args.opt_parse("num-attempts")?.unwrap_or(30);
142    cmd.args.done()?;
143
144    println!("Verifying {key_pattern} in S3 bucket {bucket} path {prefix_path}...");
145
146    let client = mz_aws_util::s3::new_client(&state.aws_config);
147
148    let mut attempts = 0;
149    while attempts <= num_attempts {
150        attempts += 1;
151        let files = client
152            .list_objects_v2()
153            .bucket(&bucket)
154            .prefix(&format!("{}/", prefix_path))
155            .send()
156            .await?;
157        match files.contents {
158            Some(files) => {
159                let files: Vec<_> = files
160                    .iter()
161                    .filter(|obj| key_pattern.is_match(obj.key().unwrap()))
162                    .map(|obj| obj.key().unwrap())
163                    .collect();
164                if !files.is_empty() {
165                    println!("Found matching files: {files:?}");
166                    return Ok(ControlFlow::Continue);
167                }
168            }
169            _ => thread::sleep(Duration::from_secs(1)),
170        }
171    }
172
173    bail!("Did not find matching files in bucket {bucket} prefix {prefix_path}");
174}
175
176fn rows_from_parquet(bytes: bytes::Bytes) -> Vec<String> {
177    let reader =
178        parquet::arrow::arrow_reader::ParquetRecordBatchReader::try_new(bytes, 1_000_000).unwrap();
179
180    let mut ret = vec![];
181    let format_options = FormatOptions::default();
182    for batch in reader {
183        let batch = batch.unwrap();
184        let converters = batch
185            .columns()
186            .iter()
187            .map(|a| ArrayFormatter::try_new(a.as_ref(), &format_options).unwrap())
188            .collect::<Vec<_>>();
189
190        for row_idx in 0..batch.num_rows() {
191            let mut buf = String::new();
192            for (col_idx, converter) in converters.iter().enumerate() {
193                if col_idx > 0 {
194                    buf.push_str(" ");
195                }
196                converter.value(row_idx).write(&mut buf).unwrap();
197            }
198            ret.push(buf);
199        }
200    }
201    ret
202}
203
204pub async fn run_upload(
205    mut cmd: BuiltinCommand,
206    state: &State,
207) -> Result<ControlFlow, anyhow::Error> {
208    let bucket = cmd.args.string("bucket")?;
209    let count: Option<usize> = cmd.args.opt_parse("count")?;
210
211    let keys: Vec<String> = if let Some(count) = count {
212        // Bulk mode uses `key-prefix` + `i` + optional `key-suffix`,
213        let prefix = cmd.args.string("key-prefix")?;
214        let suffix = cmd.args.opt_string("key-suffix").unwrap_or_default();
215        (0..count).map(|i| format!("{prefix}{i}{suffix}")).collect()
216    } else {
217        // Single-file mode uses `key`.
218        vec![cmd.args.string("key")?]
219    };
220
221    let compression = build_compression(&mut cmd)?;
222    let content = build_contents(&mut cmd)?;
223
224    let aws_client = mz_aws_util::s3::new_client(&state.aws_config);
225
226    // TODO(parkmycar): Stream data to S3. The ByteStream type from the AWS config is a bit
227    // cumbersome to work with, so for now just stick with this.
228    let mut body = vec![];
229    for line in content {
230        body.extend(&line);
231        body.push(b'\n');
232    }
233
234    let mut reader: Pin<Box<dyn AsyncRead + Send + Sync>> = match compression {
235        Compression::None => Box::pin(&body[..]),
236        Compression::Gzip => Box::pin(GzipEncoder::new(&body[..])),
237        Compression::Bzip2 => Box::pin(BzEncoder::new(&body[..])),
238        Compression::Xz => Box::pin(XzEncoder::new(&body[..])),
239        Compression::Zstd => Box::pin(ZstdEncoder::new(&body[..])),
240    };
241    let mut content = vec![];
242    reader
243        .read_to_end(&mut content)
244        .await
245        .context("compressing")?;
246
247    // Upload the file(s) to S3.
248    println!(
249        "Uploading {} files to S3 bucket, starting with '{bucket}/{}'",
250        keys.len(),
251        keys.first().map(String::as_str).unwrap_or("<none>")
252    );
253    for key in &keys {
254        aws_client
255            .put_object()
256            .bucket(&bucket)
257            .key(key)
258            .body(content.clone().into())
259            .send()
260            .await
261            .context("s3 put")?;
262    }
263
264    Ok(ControlFlow::Continue)
265}
266
267pub async fn run_set_presigned_url(
268    mut cmd: BuiltinCommand,
269    state: &mut State,
270) -> Result<ControlFlow, anyhow::Error> {
271    let key = cmd.args.string("key")?;
272    let bucket = cmd.args.string("bucket")?;
273    let var_name = cmd.args.string("var-name")?;
274
275    let aws_client = mz_aws_util::s3::new_client(&state.aws_config);
276    let presign_config = mz_aws_util::s3::new_presigned_config();
277    let request = aws_client
278        .get_object()
279        .bucket(&bucket)
280        .key(&key)
281        .presigned(presign_config)
282        .await
283        .context("s3 presign")?;
284
285    println!("Setting '{var_name}' to presigned URL for {bucket}/{key}");
286    state.cmd_vars.insert(var_name, request.uri().to_string());
287
288    Ok(ControlFlow::Continue)
289}
290
291/// Generates parquet files covering a wide range of Arrow types and uploads them to S3 with
292/// multiple compression variants. This is the Rust equivalent of the Python
293/// `generate_parquet_files()` function.
294///
295/// Uploads:
296/// - `{key-prefix}` (uncompressed)
297/// - `{key-prefix}.snappy`
298/// - `{key-prefix}.gzip`
299/// - `{key-prefix}.brotli`
300/// - `{key-prefix}.zstd`
301/// - `{key-prefix}.lz4`
302pub async fn run_upload_parquet_types(
303    mut cmd: BuiltinCommand,
304    state: &State,
305) -> Result<ControlFlow, anyhow::Error> {
306    let bucket = cmd.args.string("bucket")?;
307    let key_prefix = cmd.args.string("key-prefix")?;
308    cmd.args.done()?;
309
310    let batch = build_parquet_types_batch().context("building parquet types batch")?;
311
312    let compressions = vec![
313        ("".to_string(), ParquetCompression::UNCOMPRESSED),
314        (".snappy".to_string(), ParquetCompression::SNAPPY),
315        (
316            ".gzip".to_string(),
317            ParquetCompression::GZIP(GzipLevel::default()),
318        ),
319        (
320            ".brotli".to_string(),
321            ParquetCompression::BROTLI(BrotliLevel::default()),
322        ),
323        (
324            ".zstd".to_string(),
325            ParquetCompression::ZSTD(ZstdLevel::default()),
326        ),
327        (".lz4".to_string(), ParquetCompression::LZ4_RAW),
328    ];
329
330    let client = mz_aws_util::s3::new_client(&state.aws_config);
331
332    for (suffix, compression) in compressions {
333        let key = format!("{key_prefix}{suffix}");
334        println!("Uploading parquet types file to S3 bucket {bucket}/{key}");
335
336        let props = WriterProperties::builder()
337            .set_compression(compression)
338            .build();
339        let mut buf = Vec::new();
340        {
341            let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props))
342                .context("creating parquet writer")?;
343            writer.write(&batch).context("writing parquet batch")?;
344            writer.close().context("closing parquet writer")?;
345        }
346
347        client
348            .put_object()
349            .bucket(&bucket)
350            .key(&key)
351            .body(buf.into())
352            .send()
353            .await
354            .context("s3 put")?;
355    }
356
357    Ok(ControlFlow::Continue)
358}
359
// Using `as ArrayRef` is necessary when creating the struct array because the inner arrays have different types.
// The metadata for Arrow `Field` type requires `std::collections::HashMap`, which is disallowed.
#[allow(clippy::as_conversions, clippy::disallowed_types)]
/// Builds a three-row `RecordBatch` covering a wide range of Arrow types:
/// all the standard integer/float widths, bool, utf8, binary, date32,
/// timestamp(ms), time32(s), list<int64>, decimal128, a JSON string column,
/// a struct column, and a FixedSizeBinary(16) column tagged with the
/// `arrow.uuid` extension metadata.
///
/// NOTE: the schema's field order and the column vector passed to
/// `RecordBatch::try_new` below must stay in exact lockstep — each array is
/// matched to its field positionally.
fn build_parquet_types_batch() -> Result<RecordBatch, anyhow::Error> {
    let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();

    // date32: days since epoch
    let date_values: Vec<i32> = [
        NaiveDate::from_ymd_opt(2025, 11, 1).unwrap(),
        NaiveDate::from_ymd_opt(2025, 11, 2).unwrap(),
        NaiveDate::from_ymd_opt(2025, 11, 3).unwrap(),
    ]
    .into_iter()
    .map(|d| d.signed_duration_since(epoch).num_days() as i32)
    .collect();

    // timestamp(ms): ms since epoch (no timezone)
    let datetime_values: Vec<i64> = [
        NaiveDateTime::new(
            NaiveDate::from_ymd_opt(2025, 11, 1).unwrap(),
            NaiveTime::from_hms_opt(10, 0, 0).unwrap(),
        ),
        NaiveDateTime::new(
            NaiveDate::from_ymd_opt(2025, 11, 1).unwrap(),
            NaiveTime::from_hms_opt(11, 30, 0).unwrap(),
        ),
        NaiveDateTime::new(
            NaiveDate::from_ymd_opt(2025, 11, 1).unwrap(),
            NaiveTime::from_hms_opt(12, 0, 0).unwrap(),
        ),
    ]
    .into_iter()
    .map(|dt| dt.and_utc().timestamp_millis())
    .collect();

    // time32(s): seconds since midnight
    let time_values: Vec<i32> = [
        NaiveTime::from_hms_opt(9, 0, 0).unwrap(),
        NaiveTime::from_hms_opt(10, 30, 15).unwrap(),
        NaiveTime::from_hms_opt(11, 45, 30).unwrap(),
    ]
    .into_iter()
    .map(|t| t.num_seconds_from_midnight() as i32)
    .collect();

    // list<int64>: [-1, 2], [3, 4, 5], []
    // `append(true)` closes out the current (possibly empty) list entry.
    let mut list_builder = ListBuilder::new(Int64Builder::new());
    for &val in &[-1i64, 2] {
        list_builder.values().append_value(val);
    }
    list_builder.append(true);
    for &val in &[3i64, 4, 5] {
        list_builder.values().append_value(val);
    }
    list_builder.append(true);
    list_builder.append(true); // empty list
    let list_array = Arc::new(list_builder.finish());

    // decimal128(precision=10, scale=5): -54.321, 123.45, null
    // Values are stored as unscaled i128, i.e. -54.321 => -5_432_100 at scale 5.
    let decimal_array = Arc::new(
        Decimal128Array::from(vec![Some(-5_432_100i128), Some(12_345_000i128), None])
            .with_precision_and_scale(10, 5)
            .context("setting decimal precision/scale")?,
    );

    // struct/record: (name text, age int32, avg float64)
    let struct_array = Arc::new(StructArray::from(vec![
        (
            Arc::new(Field::new("name", DataType::Utf8, true)),
            Arc::new(StringArray::from(vec!["Taco", "Burger", "SlimJim"])) as ArrayRef,
        ),
        (
            Arc::new(Field::new("age", DataType::Int32, true)),
            Arc::new(Int32Array::from(vec![3, 2, 1])) as ArrayRef,
        ),
        (
            Arc::new(Field::new("avg", DataType::Float64, true)),
            Arc::new(Float64Array::from(vec![2.2, 4.5, 1.14])) as ArrayRef,
        ),
    ]));

    // uuid: FixedSizeBinary(16) with arrow.uuid extension metadata
    let mut uuid_builder = FixedSizeBinaryBuilder::with_capacity(3, 16);
    for uuid_str in &[
        "badc0deb-adc0-deba-dc0d-ebadc0debadc",
        "deadbeef-dead-4eef-8eef-deaddeadbeef",
        "00000000-0000-0000-0000-000000000000",
    ] {
        let uuid_val = uuid::Uuid::parse_str(uuid_str).context("parsing uuid")?;
        uuid_builder
            .append_value(uuid_val.as_bytes())
            .context("appending uuid bytes")?;
    }
    let uuid_array = Arc::new(uuid_builder.finish());

    // variable-length binary
    let mut binary_builder = BinaryBuilder::new();
    binary_builder.append_value(b"raw1");
    binary_builder.append_value(b"raw2");
    binary_builder.append_value(b"raw3");
    let binary_array = Arc::new(binary_builder.finish());

    // The arrow.uuid extension name lets readers interpret the fixed-size
    // binary column as UUIDs rather than raw bytes.
    let mut uuid_metadata = HashMap::new();
    uuid_metadata.insert("ARROW:extension:name".to_string(), "arrow.uuid".to_string());

    let schema = Arc::new(Schema::new(vec![
        Field::new("int8_col", DataType::Int8, true),
        Field::new("uint8_col", DataType::UInt8, true),
        Field::new("int16_col", DataType::Int16, true),
        Field::new("uint16_col", DataType::UInt16, true),
        Field::new("int32_col", DataType::Int32, true),
        Field::new("uint32_col", DataType::UInt32, true),
        Field::new("int64_col", DataType::Int64, true),
        Field::new("uint64_col", DataType::UInt64, true),
        Field::new("float32_col", DataType::Float32, true),
        Field::new("float64_col", DataType::Float64, true),
        Field::new("bool_col", DataType::Boolean, true),
        Field::new("string_col", DataType::Utf8, true),
        Field::new("binary_col", DataType::Binary, true),
        Field::new("date32_col", DataType::Date32, true),
        Field::new(
            "timestamp_ms_col",
            DataType::Timestamp(TimeUnit::Millisecond, None),
            true,
        ),
        Field::new("time32_col", DataType::Time32(TimeUnit::Second), true),
        Field::new(
            "list_col",
            DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
            true,
        ),
        Field::new("decimal_col", DataType::Decimal128(10, 5), true),
        Field::new("json_col", DataType::Utf8, true),
        Field::new(
            "record_col",
            DataType::Struct(
                vec![
                    Field::new("name", DataType::Utf8, true),
                    Field::new("age", DataType::Int32, true),
                    Field::new("avg", DataType::Float64, true),
                ]
                .into(),
            ),
            true,
        ),
        Field::new("uuid_col", DataType::FixedSizeBinary(16), false).with_metadata(uuid_metadata),
    ]));

    // Columns are positional: this vector must mirror the schema order above.
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int8Array::from(vec![-1i8, 2, 3])),
            Arc::new(UInt8Array::from(vec![10u8, 20, 30])),
            Arc::new(Int16Array::from(vec![-1000i16, 2000, 3000])),
            Arc::new(UInt16Array::from(vec![10000u16, 20000, 30000])),
            Arc::new(Int32Array::from(vec![-100000i32, 200000, 300000])),
            Arc::new(UInt32Array::from(vec![1000000u32, 2000000, 3000000])),
            Arc::new(Int64Array::from(vec![
                -1_000_000_000i64,
                2_000_000_000,
                3_000_000_000,
            ])),
            Arc::new(UInt64Array::from(vec![
                1_000_000_000_000_000_000u64,
                2_000_000_000_000_000_000,
                3_000_000_000_000_000_000,
            ])),
            Arc::new(Float32Array::from(vec![-1.0f32, 2.5, 3.7])),
            Arc::new(Float64Array::from(vec![-1.0f64, 2.5, 3.7])),
            Arc::new(BooleanArray::from(vec![true, false, true])),
            Arc::new(StringArray::from(vec!["apple", "banana", "cherry"])),
            binary_array,
            Arc::new(Date32Array::from(date_values)),
            Arc::new(TimestampMillisecondArray::from(datetime_values)),
            Arc::new(Time32SecondArray::from(time_values)),
            list_array,
            decimal_array,
            Arc::new(StringArray::from(vec![
                r#"{"a": 5, "b": { "c": 1.1 } }"#,
                r#"{ "d": "str", "e" : [1,2,3] }"#,
                "{}",
            ])),
            struct_array,
            uuid_array,
        ],
    )
    .context("building record batch")?;

    Ok(batch)
}