Skip to main content

mz_clusterd_test_driver/
data.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Synthetic data generation and direct persist writes. This supports the
11//! direct-write *use case*; the mechanism does not depend on it.
12
13use std::sync::Arc;
14
15use mz_ore::cast::CastFrom;
16use mz_persist_client::Diagnostics;
17use mz_persist_client::PersistClient;
18use mz_persist_types::ShardId;
19use mz_persist_types::codec_impls::UnitSchema;
20use mz_repr::{Datum, RelationDesc, Row, SqlColumnType, SqlScalarType, Timestamp};
21use mz_storage_types::StorageDiff;
22use mz_storage_types::sources::SourceData;
23use timely::progress::Antichain;
24
/// An owned scalar value that can be packed into a [`Row`].
///
/// [`Datum`] only borrows its `String` and `Bytes` payloads, so both the
/// synthetic generators and explicit script-provided values need an owning
/// counterpart; this enum is that shared representation. The variant set is
/// deliberately minimal: floats, numerics, and temporal types can be added,
/// together with their synthetic/parse rules, once a scenario requires them.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Cell {
    /// The SQL `NULL` value; only valid in a nullable column.
    Null,
    /// A `smallint` value.
    Int16(i16),
    /// An `integer` value.
    Int32(i32),
    /// A `bigint` value.
    Int64(i64),
    /// A `boolean` value.
    Bool(bool),
    /// A `text` value.
    Str(String),
    /// A `bytea` value.
    Bytes(Vec<u8>),
}
48
49impl Cell {
50    /// Borrow this cell as a [`Datum`] for packing.
51    pub fn datum(&self) -> Datum<'_> {
52        match self {
53            Cell::Null => Datum::Null,
54            Cell::Int16(v) => Datum::from(*v),
55            Cell::Int32(v) => Datum::from(*v),
56            Cell::Int64(v) => Datum::from(*v),
57            Cell::Bool(v) => Datum::from(*v),
58            Cell::Str(v) => Datum::String(v),
59            Cell::Bytes(v) => Datum::Bytes(v),
60        }
61    }
62}
63
64/// Pack one row from owned cells, in column order.
65pub fn pack_cells(cells: &[Cell]) -> Row {
66    let mut row = Row::default();
67    let mut packer = row.packer();
68    for cell in cells {
69        packer.push(cell.datum());
70    }
71    row
72}
73
74/// A two-column `(bigint, text)` schema used as the default for the generators
75/// and the script reader.
76pub fn sample_desc() -> RelationDesc {
77    RelationDesc::builder()
78        .with_column(
79            "id",
80            SqlColumnType {
81                scalar_type: SqlScalarType::Int64,
82                nullable: false,
83            },
84        )
85        .with_column(
86            "payload",
87            SqlColumnType {
88                scalar_type: SqlScalarType::String,
89                nullable: false,
90            },
91        )
92        .finish()
93}
94
95/// A deterministic synthetic value for a column of `scalar_type` at row index `i`.
96///
97/// `pad` widens `text` columns so callers can target a byte budget. Each value is
98/// a function of `i`, so a row is distinct per `i` as long as the schema has at
99/// least one wide-enough column (int, text, or bytes); an all-`bool` schema would
100/// collide. Unsupported types panic — schema construction rejects them first.
101pub fn synth_cell(scalar_type: &SqlScalarType, i: u64, pad: usize) -> Cell {
102    match scalar_type {
103        // Narrow ints wrap within their non-negative range; use a wide column as
104        // the id for large row counts. The modulus keeps the value in range, so
105        // `try_from` never fails.
106        SqlScalarType::Int16 => Cell::Int16(i16::try_from(i % 0x8000).expect("fits i16")),
107        SqlScalarType::Int32 => Cell::Int32(i32::try_from(i % 0x8000_0000).expect("fits i32")),
108        SqlScalarType::Int64 => Cell::Int64(i64::try_from(i).expect("row index fits i64")),
109        SqlScalarType::Bool => Cell::Bool(i % 2 == 0),
110        SqlScalarType::String => Cell::Str(format!("{:0>width$}", i, width = pad)),
111        SqlScalarType::Bytes => Cell::Bytes(format!("{:0>width$}", i, width = pad).into_bytes()),
112        other => panic!("synth_cell: unsupported scalar type {other:?}"),
113    }
114}
115
116/// Generate `n` synthetic rows for `desc`, with row indices running
117/// `start..start + n`.
118///
119/// Successive batches over disjoint index ranges produce distinct rows that never
120/// consolidate, so a downstream count equals the total rows written (provided the
121/// schema carries a wide-enough column; see [`synth_cell`]).
122pub fn synth_rows(desc: &RelationDesc, start: u64, n: u64, pad: usize) -> Vec<Row> {
123    let types: Vec<SqlScalarType> = desc.iter_types().map(|c| c.scalar_type.clone()).collect();
124    (start..start + n)
125        .map(|i| {
126            let cells: Vec<Cell> = types.iter().map(|t| synth_cell(t, i, pad)).collect();
127            pack_cells(&cells)
128        })
129        .collect()
130}
131
132/// Builds `n` rows of the [`sample_desc`] schema; `payload` is `pad` bytes wide so
133/// callers can target a byte budget (≈ `n * (pad + overhead)`).
134pub fn sample_rows(n: u64, pad: usize) -> Vec<Row> {
135    sample_rows_from(0, n, pad)
136}
137
138/// Like [`sample_rows`], but ids run `start..start + n`. Successive batches with
139/// disjoint id ranges produce distinct rows that never consolidate with each
140/// other, so a downstream count equals the total rows written.
141pub fn sample_rows_from(start: u64, n: u64, pad: usize) -> Vec<Row> {
142    synth_rows(&sample_desc(), start, n, pad)
143}
144
145/// Writes `rows` to `shard` at `ts`, advancing `upper` to `ts+1`. All rows are
146/// inserted with diff `+1`. Returns once the append succeeds.
147pub async fn write_rows_single_ts(
148    client: &PersistClient,
149    shard: ShardId,
150    desc: &RelationDesc,
151    rows: &[Row],
152    ts: Timestamp,
153) -> anyhow::Result<()> {
154    let mut writer = client
155        .open_writer::<SourceData, (), Timestamp, StorageDiff>(
156            shard,
157            Arc::new(desc.clone()),
158            Arc::new(UnitSchema),
159            Diagnostics {
160                shard_name: "driver-data".to_string(),
161                handle_purpose: "headless driver write".to_string(),
162            },
163        )
164        .await?;
165
166    let updates: Vec<_> = rows
167        .iter()
168        .map(|r| ((SourceData(Ok(r.clone())), ()), ts, 1i64))
169        .collect();
170    let lower = Antichain::from_elem(ts);
171    let upper = Antichain::from_elem(ts.step_forward());
172    writer
173        .compare_and_append(&updates, lower, upper)
174        .await?
175        .map_err(|e| anyhow::anyhow!("{e}"))?;
176    Ok(())
177}
178
179/// Writes `rows` spread across timestamps `0..n_ts` (row `i` at time `i % n_ts`)
180/// in a single append that seals `[0, n_ts)`. All rows are inserted with diff
181/// `+1`. This is one `compare_and_append` regardless of `n_ts` — persist accepts
182/// updates at any timestamp within the sealed range — so it stays fast even for
183/// very large `n_ts` (a per-timestamp append would be `n_ts` consensus
184/// round-trips).
185pub async fn write_rows_spread(
186    client: &PersistClient,
187    shard: ShardId,
188    desc: &RelationDesc,
189    rows: &[Row],
190    n_ts: u64,
191) -> anyhow::Result<()> {
192    assert!(n_ts > 0, "n_ts must be positive");
193    let mut writer = client
194        .open_writer::<SourceData, (), Timestamp, StorageDiff>(
195            shard,
196            Arc::new(desc.clone()),
197            Arc::new(UnitSchema),
198            Diagnostics {
199                shard_name: "driver-data".to_string(),
200                handle_purpose: "headless driver spread write".to_string(),
201            },
202        )
203        .await?;
204    let updates: Vec<_> = rows
205        .iter()
206        .enumerate()
207        .map(|(i, r)| {
208            let t = u64::cast_from(i) % n_ts;
209            ((SourceData(Ok(r.clone())), ()), Timestamp::from(t), 1i64)
210        })
211        .collect();
212    let lower = Antichain::from_elem(Timestamp::from(0));
213    let upper = Antichain::from_elem(Timestamp::from(n_ts));
214    writer
215        .compare_and_append(&updates, lower, upper)
216        .await?
217        .map_err(|e| anyhow::anyhow!("{e}"))?;
218    Ok(())
219}
220
/// Estimates how many rows are needed to reach roughly `target_bytes`, given
/// `pad`-wide payloads.
///
/// Each row is costed at `pad` plus a fixed 24-byte overhead. That overhead is
/// approximate, so this is coarse sizing, not an exact byte count. The result is
/// always at least 1, even when `target_bytes` is smaller than one row.
///
/// The per-row cost saturates instead of overflowing, so an extreme `pad` gives
/// a very large per-row cost (and hence a small row count) rather than a panic
/// in debug builds or a wrapped, wildly inflated count in release builds.
pub fn rows_for_bytes(target_bytes: u64, pad: usize) -> u64 {
    /// Approximate fixed per-row overhead, in bytes, on top of the payload.
    const ROW_OVERHEAD_BYTES: u64 = 24;

    // `usize` is at most 64 bits wide on supported targets, so the fallback is
    // not taken there; it only keeps the conversion total.
    let payload_bytes = u64::try_from(pad).unwrap_or(u64::MAX);
    // At least `ROW_OVERHEAD_BYTES`, so the division below never divides by zero.
    let per_row = payload_bytes.saturating_add(ROW_OVERHEAD_BYTES);
    (target_bytes / per_row).max(1)
}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persist_host::PersistHost;
    use mz_persist_types::PersistLocation;

    /// Opens a leased reader on `shard`, snapshots it at `as_of`, and returns the
    /// sum of the diffs of everything fetched. Shared by the write tests so each
    /// one only spells out its write and its expectation.
    async fn snapshot_diff_sum(
        client: &PersistClient,
        shard: ShardId,
        desc: &RelationDesc,
        as_of: u64,
    ) -> i64 {
        let mut reader = client
            .open_leased_reader::<SourceData, (), Timestamp, StorageDiff>(
                shard,
                Arc::new(desc.clone()),
                Arc::new(UnitSchema),
                Diagnostics::from_purpose("snapshot"),
                true,
            )
            .await
            .expect("reader");
        let as_of = Antichain::from_elem(Timestamp::from(as_of));
        let contents = reader.snapshot_and_fetch(as_of).await.expect("snapshot");
        contents.iter().map(|(_, _, d)| *d).sum()
    }

    /// Writing 1000 rows at timestamp 0 and snapshotting at 0 sees all of them.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn write_then_snapshot_counts() {
        let host = PersistHost::start(PersistLocation::new_in_mem())
            .await
            .expect("host");
        let client = host.client().await.expect("client");
        let shard = ShardId::new();
        let desc = sample_desc();
        let rows = sample_rows(1000, 16);
        write_rows_single_ts(&client, shard, &desc, &rows, Timestamp::from(0))
            .await
            .expect("write");

        let count = snapshot_diff_sum(&client, shard, &desc, 0).await;
        assert_eq!(count, 1000);
    }

    /// Spreading 1000 rows over timestamps `0..8` and snapshotting at the last
    /// written timestamp sees all of them.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn spread_write_snapshot_counts() {
        let host = PersistHost::start(PersistLocation::new_in_mem())
            .await
            .expect("host");
        let client = host.client().await.expect("client");
        let shard = ShardId::new();
        let desc = sample_desc();
        let rows = sample_rows(1000, 16);
        write_rows_spread(&client, shard, &desc, &rows, 8)
            .await
            .expect("spread write");

        // Snapshot at the last written timestamp (7); all 1000 rows must be present.
        let count = snapshot_diff_sum(&client, shard, &desc, 7).await;
        assert_eq!(count, 1000);
    }

    /// Pins the sizing arithmetic: per-row cost is `pad + 24`, floored at 1 row.
    #[mz_ore::test]
    fn rows_for_bytes_basic() {
        assert_eq!(rows_for_bytes(0, 16), 1); // always at least 1
        assert_eq!(rows_for_bytes(39, 16), 1); // less than one 40-byte row
        assert_eq!(rows_for_bytes(4_000, 16), 100); // 4000 / (16 + 24)
        assert_eq!(rows_for_bytes(1_000_000, 64), 11_363); // 1_000_000 / (64 + 24)
        assert!(rows_for_bytes(1_000_000, 64) > 0);
    }
}