// mz_testdrive/action/s3.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10// The metadata for Arrow `Field` type requires `std::collections::HashMap`, which is disallowed.
11#[allow(clippy::disallowed_types)]
12use std::collections::HashMap;
13use std::pin::Pin;
14use std::str;
15use std::sync::Arc;
16use std::thread;
17use std::time::Duration;
18
19use anyhow::Context;
20use anyhow::bail;
21use arrow::array::{
22    ArrayRef, BinaryBuilder, BooleanArray, Date32Array, Decimal128Array, FixedSizeBinaryBuilder,
23    Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, Int64Builder,
24    IntervalDayTimeArray, IntervalYearMonthArray, ListBuilder, MapBuilder, StringArray,
25    StringBuilder, StructArray, Time32SecondArray, TimestampMillisecondArray, UInt8Array,
26    UInt16Array, UInt32Array, UInt64Array,
27};
28use arrow::datatypes::{DataType, Field, IntervalDayTime, IntervalUnit, Schema, TimeUnit};
29use arrow::record_batch::RecordBatch;
30use arrow::util::display::ArrayFormatter;
31use arrow::util::display::FormatOptions;
32use async_compression::tokio::bufread::{BzEncoder, GzipEncoder, XzEncoder, ZstdEncoder};
33use chrono::{NaiveDate, NaiveDateTime, NaiveTime, Timelike};
34use parquet::arrow::ArrowWriter;
35use parquet::basic::{BrotliLevel, Compression as ParquetCompression, GzipLevel, ZstdLevel};
36use parquet::file::properties::WriterProperties;
37use regex::Regex;
38use tokio::io::{AsyncRead, AsyncReadExt};
39
40use crate::action::file::Compression;
41use crate::action::file::build_compression;
42use crate::action::file::build_contents;
43use crate::action::{ControlFlow, State};
44use crate::parser::BuiltinCommand;
45
46pub async fn run_verify_data(
47    mut cmd: BuiltinCommand,
48    state: &State,
49) -> Result<ControlFlow, anyhow::Error> {
50    let mut expected_body = cmd
51        .input
52        .into_iter()
53        // Strip suffix to allow lines with trailing whitespace
54        .map(|line| {
55            line.trim_end_matches("// allow-trailing-whitespace")
56                .to_string()
57        })
58        .collect::<Vec<String>>();
59    let bucket: String = cmd.args.parse("bucket")?;
60    let key: String = cmd.args.parse("key")?;
61    let sort_rows = cmd.args.opt_bool("sort-rows")?.unwrap_or(false);
62    cmd.args.done()?;
63
64    println!("Verifying contents of S3 bucket {bucket} key {key}...");
65
66    let client = mz_aws_util::s3::new_client(&state.aws_config);
67
68    // List the path until the INCOMPLETE sentinel file disappears so we know the
69    // data is complete.
70    let mut attempts = 0;
71    let all_files;
72    loop {
73        attempts += 1;
74        if attempts > 10 {
75            bail!("found incomplete sentinel file in path {key} after 10 attempts")
76        }
77
78        let files = client
79            .list_objects_v2()
80            .bucket(&bucket)
81            .prefix(&format!("{}/", key))
82            .send()
83            .await?;
84        match files.contents {
85            Some(files)
86                if files
87                    .iter()
88                    .any(|obj| obj.key().map_or(false, |key| key.contains("INCOMPLETE"))) =>
89            {
90                thread::sleep(Duration::from_secs(1))
91            }
92            None => bail!("no files found in bucket {bucket} key {key}"),
93            Some(files) => {
94                all_files = files;
95                break;
96            }
97        }
98    }
99
100    let mut rows = vec![];
101    for obj in all_files.iter() {
102        let file = client
103            .get_object()
104            .bucket(&bucket)
105            .key(obj.key().unwrap())
106            .send()
107            .await?;
108        let bytes = file.body.collect().await?.into_bytes();
109
110        let new_rows = match obj.key().unwrap() {
111            key if key.ends_with(".csv") => {
112                let actual_body = str::from_utf8(bytes.as_ref())?;
113                actual_body.lines().map(|l| l.to_string()).collect()
114            }
115            key if key.ends_with(".parquet") => rows_from_parquet(bytes),
116            key => bail!("unexpected file type: {key}"),
117        };
118        rows.extend(new_rows);
119    }
120    if sort_rows {
121        expected_body.sort();
122        rows.sort();
123    }
124    if rows != expected_body {
125        bail!(
126            "content did not match\nexpected:\n{:?}\n\nactual:\n{:?}",
127            expected_body,
128            rows
129        );
130    }
131
132    Ok(ControlFlow::Continue)
133}
134
135pub async fn run_verify_keys(
136    mut cmd: BuiltinCommand,
137    state: &State,
138) -> Result<ControlFlow, anyhow::Error> {
139    let bucket: String = cmd.args.parse("bucket")?;
140    let prefix_path: String = cmd.args.parse("prefix-path")?;
141    let key_pattern: Regex = cmd.args.parse("key-pattern")?;
142    let num_attempts = cmd.args.opt_parse("num-attempts")?.unwrap_or(30);
143    cmd.args.done()?;
144
145    println!("Verifying {key_pattern} in S3 bucket {bucket} path {prefix_path}...");
146
147    let client = mz_aws_util::s3::new_client(&state.aws_config);
148
149    let mut attempts = 0;
150    while attempts <= num_attempts {
151        attempts += 1;
152        let files = client
153            .list_objects_v2()
154            .bucket(&bucket)
155            .prefix(&format!("{}/", prefix_path))
156            .send()
157            .await?;
158        match files.contents {
159            Some(files) => {
160                let files: Vec<_> = files
161                    .iter()
162                    .filter(|obj| key_pattern.is_match(obj.key().unwrap()))
163                    .map(|obj| obj.key().unwrap())
164                    .collect();
165                if !files.is_empty() {
166                    println!("Found matching files: {files:?}");
167                    return Ok(ControlFlow::Continue);
168                }
169            }
170            _ => thread::sleep(Duration::from_secs(1)),
171        }
172    }
173
174    bail!("Did not find matching files in bucket {bucket} prefix {prefix_path}");
175}
176
177fn rows_from_parquet(bytes: bytes::Bytes) -> Vec<String> {
178    let reader =
179        parquet::arrow::arrow_reader::ParquetRecordBatchReader::try_new(bytes, 1_000_000).unwrap();
180
181    let mut ret = vec![];
182    let format_options = FormatOptions::default();
183    for batch in reader {
184        let batch = batch.unwrap();
185        let converters = batch
186            .columns()
187            .iter()
188            .map(|a| ArrayFormatter::try_new(a.as_ref(), &format_options).unwrap())
189            .collect::<Vec<_>>();
190
191        for row_idx in 0..batch.num_rows() {
192            let mut buf = String::new();
193            for (col_idx, converter) in converters.iter().enumerate() {
194                if col_idx > 0 {
195                    buf.push_str(" ");
196                }
197                converter.value(row_idx).write(&mut buf).unwrap();
198            }
199            ret.push(buf);
200        }
201    }
202    ret
203}
204
205pub async fn run_upload(
206    mut cmd: BuiltinCommand,
207    state: &State,
208) -> Result<ControlFlow, anyhow::Error> {
209    let bucket = cmd.args.string("bucket")?;
210    let count: Option<usize> = cmd.args.opt_parse("count")?;
211
212    let keys: Vec<String> = if let Some(count) = count {
213        // Bulk mode uses `key-prefix` + `i` + optional `key-suffix`,
214        let prefix = cmd.args.string("key-prefix")?;
215        let suffix = cmd.args.opt_string("key-suffix").unwrap_or_default();
216        (0..count).map(|i| format!("{prefix}{i}{suffix}")).collect()
217    } else {
218        // Single-file mode uses `key`.
219        vec![cmd.args.string("key")?]
220    };
221
222    let compression = build_compression(&mut cmd)?;
223    let content = build_contents(&mut cmd)?;
224
225    let aws_client = mz_aws_util::s3::new_client(&state.aws_config);
226
227    // TODO(parkmycar): Stream data to S3. The ByteStream type from the AWS config is a bit
228    // cumbersome to work with, so for now just stick with this.
229    let mut body = vec![];
230    for line in content {
231        body.extend(&line);
232        body.push(b'\n');
233    }
234
235    let mut reader: Pin<Box<dyn AsyncRead + Send + Sync>> = match compression {
236        Compression::None => Box::pin(&body[..]),
237        Compression::Gzip => Box::pin(GzipEncoder::new(&body[..])),
238        Compression::Bzip2 => Box::pin(BzEncoder::new(&body[..])),
239        Compression::Xz => Box::pin(XzEncoder::new(&body[..])),
240        Compression::Zstd => Box::pin(ZstdEncoder::new(&body[..])),
241    };
242    let mut content = vec![];
243    reader
244        .read_to_end(&mut content)
245        .await
246        .context("compressing")?;
247
248    // Upload the file(s) to S3.
249    println!(
250        "Uploading {} files to S3 bucket, starting with '{bucket}/{}'",
251        keys.len(),
252        keys.first().map(String::as_str).unwrap_or("<none>")
253    );
254    for key in &keys {
255        aws_client
256            .put_object()
257            .bucket(&bucket)
258            .key(key)
259            .body(content.clone().into())
260            .send()
261            .await
262            .context("s3 put")?;
263    }
264
265    Ok(ControlFlow::Continue)
266}
267
268pub async fn run_set_presigned_url(
269    mut cmd: BuiltinCommand,
270    state: &mut State,
271) -> Result<ControlFlow, anyhow::Error> {
272    let key = cmd.args.string("key")?;
273    let bucket = cmd.args.string("bucket")?;
274    let var_name = cmd.args.string("var-name")?;
275
276    let aws_client = mz_aws_util::s3::new_client(&state.aws_config);
277    let presign_config = mz_aws_util::s3::new_presigned_config();
278    let request = aws_client
279        .get_object()
280        .bucket(&bucket)
281        .key(&key)
282        .presigned(presign_config)
283        .await
284        .context("s3 presign")?;
285
286    println!("Setting '{var_name}' to presigned URL for {bucket}/{key}");
287    state.cmd_vars.insert(var_name, request.uri().to_string());
288
289    Ok(ControlFlow::Continue)
290}
291
292/// Generates parquet files covering a wide range of Arrow types and uploads them to S3 with
293/// multiple compression variants. This is the Rust equivalent of the Python
294/// `generate_parquet_files()` function.
295///
296/// Uploads:
297/// - `{key-prefix}` (uncompressed)
298/// - `{key-prefix}.snappy`
299/// - `{key-prefix}.gzip`
300/// - `{key-prefix}.brotli`
301/// - `{key-prefix}.zstd`
302/// - `{key-prefix}.lz4`
303pub async fn run_upload_parquet_types(
304    mut cmd: BuiltinCommand,
305    state: &State,
306) -> Result<ControlFlow, anyhow::Error> {
307    let bucket = cmd.args.string("bucket")?;
308    let key_prefix = cmd.args.string("key-prefix")?;
309    cmd.args.done()?;
310
311    let batch = build_parquet_types_batch().context("building parquet types batch")?;
312
313    let compressions = vec![
314        ("".to_string(), ParquetCompression::UNCOMPRESSED),
315        (".snappy".to_string(), ParquetCompression::SNAPPY),
316        (
317            ".gzip".to_string(),
318            ParquetCompression::GZIP(GzipLevel::default()),
319        ),
320        (
321            ".brotli".to_string(),
322            ParquetCompression::BROTLI(BrotliLevel::default()),
323        ),
324        (
325            ".zstd".to_string(),
326            ParquetCompression::ZSTD(ZstdLevel::default()),
327        ),
328        (".lz4".to_string(), ParquetCompression::LZ4_RAW),
329    ];
330
331    let client = mz_aws_util::s3::new_client(&state.aws_config);
332
333    for (suffix, compression) in compressions {
334        let key = format!("{key_prefix}{suffix}");
335        println!("Uploading parquet types file to S3 bucket {bucket}/{key}");
336
337        let props = WriterProperties::builder()
338            .set_compression(compression)
339            .build();
340        let mut buf = Vec::new();
341        {
342            let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props))
343                .context("creating parquet writer")?;
344            writer.write(&batch).context("writing parquet batch")?;
345            writer.close().context("closing parquet writer")?;
346        }
347
348        client
349            .put_object()
350            .bucket(&bucket)
351            .key(&key)
352            .body(buf.into())
353            .send()
354            .await
355            .context("s3 put")?;
356    }
357
358    Ok(ControlFlow::Continue)
359}
360
/// Builds a three-row [`RecordBatch`] exercising a wide range of Arrow types:
/// integers of all widths, floats, bool, utf8, binary, date/time/timestamp,
/// list, decimal, JSON-as-utf8, struct, uuid (fixed-size binary with the
/// `arrow.uuid` extension metadata), both interval units, and map.
///
/// Every column must contain exactly three rows, and the column order in the
/// final `RecordBatch::try_new` call must match the schema's field order.
// Using `as ArrayRef` is necessary when creating the struct array because the inner arrays have different types.
// The metadata for Arrow `Field` type requires `std::collections::HashMap`, which is disallowed.
#[allow(clippy::as_conversions, clippy::disallowed_types)]
fn build_parquet_types_batch() -> Result<RecordBatch, anyhow::Error> {
    let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();

    // date32: days since epoch
    let date_values: Vec<i32> = [
        NaiveDate::from_ymd_opt(2025, 11, 1).unwrap(),
        NaiveDate::from_ymd_opt(2025, 11, 2).unwrap(),
        NaiveDate::from_ymd_opt(2025, 11, 3).unwrap(),
    ]
    .into_iter()
    .map(|d| d.signed_duration_since(epoch).num_days() as i32)
    .collect();

    // timestamp(ms): ms since epoch (no timezone)
    let datetime_values: Vec<i64> = [
        NaiveDateTime::new(
            NaiveDate::from_ymd_opt(2025, 11, 1).unwrap(),
            NaiveTime::from_hms_opt(10, 0, 0).unwrap(),
        ),
        NaiveDateTime::new(
            NaiveDate::from_ymd_opt(2025, 11, 1).unwrap(),
            NaiveTime::from_hms_opt(11, 30, 0).unwrap(),
        ),
        NaiveDateTime::new(
            NaiveDate::from_ymd_opt(2025, 11, 1).unwrap(),
            NaiveTime::from_hms_opt(12, 0, 0).unwrap(),
        ),
    ]
    .into_iter()
    .map(|dt| dt.and_utc().timestamp_millis())
    .collect();

    // time32(s): seconds since midnight
    let time_values: Vec<i32> = [
        NaiveTime::from_hms_opt(9, 0, 0).unwrap(),
        NaiveTime::from_hms_opt(10, 30, 15).unwrap(),
        NaiveTime::from_hms_opt(11, 45, 30).unwrap(),
    ]
    .into_iter()
    .map(|t| t.num_seconds_from_midnight() as i32)
    .collect();

    // list<int64>: [-1, 2], [3, 4, 5], []
    // NOTE: builder append order is significant — values are pushed into the
    // child builder, then `append(true)` closes out the current list row.
    let mut list_builder = ListBuilder::new(Int64Builder::new());
    for &val in &[-1i64, 2] {
        list_builder.values().append_value(val);
    }
    list_builder.append(true);
    for &val in &[3i64, 4, 5] {
        list_builder.values().append_value(val);
    }
    list_builder.append(true);
    list_builder.append(true); // empty list
    let list_array = Arc::new(list_builder.finish());

    // decimal128(precision=10, scale=5): -54.321, 123.45, null
    // (raw i128 values are the decimal scaled by 10^5)
    let decimal_array = Arc::new(
        Decimal128Array::from(vec![Some(-5_432_100i128), Some(12_345_000i128), None])
            .with_precision_and_scale(10, 5)
            .context("setting decimal precision/scale")?,
    );

    // struct/record: (name text, age int32, avg float64)
    let struct_array = Arc::new(StructArray::from(vec![
        (
            Arc::new(Field::new("name", DataType::Utf8, true)),
            Arc::new(StringArray::from(vec!["Taco", "Burger", "SlimJim"])) as ArrayRef,
        ),
        (
            Arc::new(Field::new("age", DataType::Int32, true)),
            Arc::new(Int32Array::from(vec![3, 2, 1])) as ArrayRef,
        ),
        (
            Arc::new(Field::new("avg", DataType::Float64, true)),
            Arc::new(Float64Array::from(vec![2.2, 4.5, 1.14])) as ArrayRef,
        ),
    ]));

    // uuid: FixedSizeBinary(16) with arrow.uuid extension metadata
    let mut uuid_builder = FixedSizeBinaryBuilder::with_capacity(3, 16);
    for uuid_str in &[
        "badc0deb-adc0-deba-dc0d-ebadc0debadc",
        "deadbeef-dead-4eef-8eef-deaddeadbeef",
        "00000000-0000-0000-0000-000000000000",
    ] {
        let uuid_val = uuid::Uuid::parse_str(uuid_str).context("parsing uuid")?;
        uuid_builder
            .append_value(uuid_val.as_bytes())
            .context("appending uuid bytes")?;
    }
    let uuid_array = Arc::new(uuid_builder.finish());

    // variable-length binary
    let mut binary_builder = BinaryBuilder::new();
    binary_builder.append_value(b"raw1");
    binary_builder.append_value(b"raw2");
    binary_builder.append_value(b"raw3");
    let binary_array = Arc::new(binary_builder.finish());

    // interval(year-month): stored as i32 months
    let interval_ym_array = Arc::new(IntervalYearMonthArray::from(vec![1i32, 13, -2]));

    // interval(day-time): stored as (days: i32, milliseconds: i32)
    let interval_dt_array = Arc::new(IntervalDayTimeArray::from(vec![
        IntervalDayTime {
            days: 1,
            milliseconds: 500,
        },
        IntervalDayTime {
            days: 30,
            milliseconds: 0,
        },
        IntervalDayTime {
            days: -1,
            milliseconds: -1000,
        },
    ]));

    // map<utf8, utf8>: {"k1": "v1", "k2": "v2"}, {"k3": "v3"}, {}
    // As with the list builder, key/value pairs are appended to the child
    // builders and `append(true)` closes out each map row.
    let mut map_builder = MapBuilder::new(None, StringBuilder::new(), StringBuilder::new());
    map_builder.keys().append_value("k1");
    map_builder.values().append_value("v1");
    map_builder.keys().append_value("k2");
    map_builder.values().append_value("v2");
    map_builder.append(true).context("appending map row 0")?;
    map_builder.keys().append_value("k3");
    map_builder.values().append_value("v3");
    map_builder.append(true).context("appending map row 1")?;
    map_builder.append(true).context("appending map row 2")?; // empty map
    let map_array = Arc::new(map_builder.finish());

    // Tag the uuid column with the `arrow.uuid` extension name so readers
    // recognize the FixedSizeBinary(16) values as UUIDs.
    let mut uuid_metadata = HashMap::new();
    uuid_metadata.insert("ARROW:extension:name".to_string(), "arrow.uuid".to_string());

    let schema = Arc::new(Schema::new(vec![
        Field::new("int8_col", DataType::Int8, true),
        Field::new("uint8_col", DataType::UInt8, true),
        Field::new("int16_col", DataType::Int16, true),
        Field::new("uint16_col", DataType::UInt16, true),
        Field::new("int32_col", DataType::Int32, true),
        Field::new("uint32_col", DataType::UInt32, true),
        Field::new("int64_col", DataType::Int64, true),
        Field::new("uint64_col", DataType::UInt64, true),
        Field::new("float32_col", DataType::Float32, true),
        Field::new("float64_col", DataType::Float64, true),
        Field::new("bool_col", DataType::Boolean, true),
        Field::new("string_col", DataType::Utf8, true),
        Field::new("binary_col", DataType::Binary, true),
        Field::new("date32_col", DataType::Date32, true),
        Field::new(
            "timestamp_ms_col",
            DataType::Timestamp(TimeUnit::Millisecond, None),
            true,
        ),
        Field::new("time32_col", DataType::Time32(TimeUnit::Second), true),
        Field::new(
            "list_col",
            DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
            true,
        ),
        Field::new("decimal_col", DataType::Decimal128(10, 5), true),
        Field::new("json_col", DataType::Utf8, true),
        Field::new(
            "record_col",
            DataType::Struct(
                vec![
                    Field::new("name", DataType::Utf8, true),
                    Field::new("age", DataType::Int32, true),
                    Field::new("avg", DataType::Float64, true),
                ]
                .into(),
            ),
            true,
        ),
        Field::new("uuid_col", DataType::FixedSizeBinary(16), false).with_metadata(uuid_metadata),
        Field::new(
            "interval_ym_col",
            DataType::Interval(IntervalUnit::YearMonth),
            true,
        ),
        Field::new(
            "interval_dt_col",
            DataType::Interval(IntervalUnit::DayTime),
            true,
        ),
        Field::new(
            "map_col",
            DataType::Map(
                Arc::new(Field::new(
                    "entries",
                    DataType::Struct(
                        vec![
                            Field::new("keys", DataType::Utf8, false),
                            Field::new("values", DataType::Utf8, true),
                        ]
                        .into(),
                    ),
                    false,
                )),
                false,
            ),
            true,
        ),
    ]));

    // Column order here must match the schema field order above, one array
    // per field, each with exactly three rows.
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int8Array::from(vec![-1i8, 2, 3])),
            Arc::new(UInt8Array::from(vec![10u8, 20, 30])),
            Arc::new(Int16Array::from(vec![-1000i16, 2000, 3000])),
            Arc::new(UInt16Array::from(vec![10000u16, 20000, 30000])),
            Arc::new(Int32Array::from(vec![-100000i32, 200000, 300000])),
            Arc::new(UInt32Array::from(vec![1000000u32, 2000000, 3000000])),
            Arc::new(Int64Array::from(vec![
                -1_000_000_000i64,
                2_000_000_000,
                3_000_000_000,
            ])),
            Arc::new(UInt64Array::from(vec![
                1_000_000_000_000_000_000u64,
                2_000_000_000_000_000_000,
                3_000_000_000_000_000_000,
            ])),
            Arc::new(Float32Array::from(vec![-1.0f32, 2.5, 3.7])),
            Arc::new(Float64Array::from(vec![-1.0f64, 2.5, 3.7])),
            Arc::new(BooleanArray::from(vec![true, false, true])),
            Arc::new(StringArray::from(vec!["apple", "banana", "cherry"])),
            binary_array,
            Arc::new(Date32Array::from(date_values)),
            Arc::new(TimestampMillisecondArray::from(datetime_values)),
            Arc::new(Time32SecondArray::from(time_values)),
            list_array,
            decimal_array,
            Arc::new(StringArray::from(vec![
                r#"{"a": 5, "b": { "c": 1.1 } }"#,
                r#"{ "d": "str", "e" : [1,2,3] }"#,
                "{}",
            ])),
            struct_array,
            uuid_array,
            interval_ym_array,
            interval_dt_array,
            map_array,
        ],
    )
    .context("building record batch")?;

    Ok(batch)
}