1use std::sync::Arc;
14
15use mz_ore::cast::CastFrom;
16use mz_persist_client::Diagnostics;
17use mz_persist_client::PersistClient;
18use mz_persist_types::ShardId;
19use mz_persist_types::codec_impls::UnitSchema;
20use mz_repr::{Datum, RelationDesc, Row, SqlColumnType, SqlScalarType, Timestamp};
21use mz_storage_types::StorageDiff;
22use mz_storage_types::sources::SourceData;
23use timely::progress::Antichain;
24
/// One owned column value used to assemble synthetic [`Row`]s.
///
/// Values are converted to a borrowed [`Datum`] by `Cell::datum` when packed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Cell {
    /// SQL `NULL`.
    Null,
    /// A 16-bit signed integer.
    Int16(i16),
    /// A 32-bit signed integer.
    Int32(i32),
    /// A 64-bit signed integer.
    Int64(i64),
    /// A boolean.
    Bool(bool),
    /// An owned UTF-8 string.
    Str(String),
    /// An owned byte string.
    Bytes(Vec<u8>),
}
48
49impl Cell {
50 pub fn datum(&self) -> Datum<'_> {
52 match self {
53 Cell::Null => Datum::Null,
54 Cell::Int16(v) => Datum::from(*v),
55 Cell::Int32(v) => Datum::from(*v),
56 Cell::Int64(v) => Datum::from(*v),
57 Cell::Bool(v) => Datum::from(*v),
58 Cell::Str(v) => Datum::String(v),
59 Cell::Bytes(v) => Datum::Bytes(v),
60 }
61 }
62}
63
64pub fn pack_cells(cells: &[Cell]) -> Row {
66 let mut row = Row::default();
67 let mut packer = row.packer();
68 for cell in cells {
69 packer.push(cell.datum());
70 }
71 row
72}
73
74pub fn sample_desc() -> RelationDesc {
77 RelationDesc::builder()
78 .with_column(
79 "id",
80 SqlColumnType {
81 scalar_type: SqlScalarType::Int64,
82 nullable: false,
83 },
84 )
85 .with_column(
86 "payload",
87 SqlColumnType {
88 scalar_type: SqlScalarType::String,
89 nullable: false,
90 },
91 )
92 .finish()
93}
94
95pub fn synth_cell(scalar_type: &SqlScalarType, i: u64, pad: usize) -> Cell {
102 match scalar_type {
103 SqlScalarType::Int16 => Cell::Int16(i16::try_from(i % 0x8000).expect("fits i16")),
107 SqlScalarType::Int32 => Cell::Int32(i32::try_from(i % 0x8000_0000).expect("fits i32")),
108 SqlScalarType::Int64 => Cell::Int64(i64::try_from(i).expect("row index fits i64")),
109 SqlScalarType::Bool => Cell::Bool(i % 2 == 0),
110 SqlScalarType::String => Cell::Str(format!("{:0>width$}", i, width = pad)),
111 SqlScalarType::Bytes => Cell::Bytes(format!("{:0>width$}", i, width = pad).into_bytes()),
112 other => panic!("synth_cell: unsupported scalar type {other:?}"),
113 }
114}
115
116pub fn synth_rows(desc: &RelationDesc, start: u64, n: u64, pad: usize) -> Vec<Row> {
123 let types: Vec<SqlScalarType> = desc.iter_types().map(|c| c.scalar_type.clone()).collect();
124 (start..start + n)
125 .map(|i| {
126 let cells: Vec<Cell> = types.iter().map(|t| synth_cell(t, i, pad)).collect();
127 pack_cells(&cells)
128 })
129 .collect()
130}
131
132pub fn sample_rows(n: u64, pad: usize) -> Vec<Row> {
135 sample_rows_from(0, n, pad)
136}
137
138pub fn sample_rows_from(start: u64, n: u64, pad: usize) -> Vec<Row> {
142 synth_rows(&sample_desc(), start, n, pad)
143}
144
145pub async fn write_rows_single_ts(
148 client: &PersistClient,
149 shard: ShardId,
150 desc: &RelationDesc,
151 rows: &[Row],
152 ts: Timestamp,
153) -> anyhow::Result<()> {
154 let mut writer = client
155 .open_writer::<SourceData, (), Timestamp, StorageDiff>(
156 shard,
157 Arc::new(desc.clone()),
158 Arc::new(UnitSchema),
159 Diagnostics {
160 shard_name: "driver-data".to_string(),
161 handle_purpose: "headless driver write".to_string(),
162 },
163 )
164 .await?;
165
166 let updates: Vec<_> = rows
167 .iter()
168 .map(|r| ((SourceData(Ok(r.clone())), ()), ts, 1i64))
169 .collect();
170 let lower = Antichain::from_elem(ts);
171 let upper = Antichain::from_elem(ts.step_forward());
172 writer
173 .compare_and_append(&updates, lower, upper)
174 .await?
175 .map_err(|e| anyhow::anyhow!("{e}"))?;
176 Ok(())
177}
178
179pub async fn write_rows_spread(
186 client: &PersistClient,
187 shard: ShardId,
188 desc: &RelationDesc,
189 rows: &[Row],
190 n_ts: u64,
191) -> anyhow::Result<()> {
192 assert!(n_ts > 0, "n_ts must be positive");
193 let mut writer = client
194 .open_writer::<SourceData, (), Timestamp, StorageDiff>(
195 shard,
196 Arc::new(desc.clone()),
197 Arc::new(UnitSchema),
198 Diagnostics {
199 shard_name: "driver-data".to_string(),
200 handle_purpose: "headless driver spread write".to_string(),
201 },
202 )
203 .await?;
204 let updates: Vec<_> = rows
205 .iter()
206 .enumerate()
207 .map(|(i, r)| {
208 let t = u64::cast_from(i) % n_ts;
209 ((SourceData(Ok(r.clone())), ()), Timestamp::from(t), 1i64)
210 })
211 .collect();
212 let lower = Antichain::from_elem(Timestamp::from(0));
213 let upper = Antichain::from_elem(Timestamp::from(n_ts));
214 writer
215 .compare_and_append(&updates, lower, upper)
216 .await?
217 .map_err(|e| anyhow::anyhow!("{e}"))?;
218 Ok(())
219}
220
221pub fn rows_for_bytes(target_bytes: u64, pad: usize) -> u64 {
224 let per_row = u64::cast_from(pad) + 24;
225 (target_bytes / per_row).max(1)
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persist_host::PersistHost;
    use mz_persist_types::PersistLocation;

    /// Opens a leased reader on `shard` and returns the sum of the diffs of
    /// every update visible in a snapshot at `as_of`.
    ///
    /// The reader is expired explicitly before returning.
    async fn snapshot_diff_sum(
        client: &PersistClient,
        shard: ShardId,
        desc: &RelationDesc,
        as_of: Timestamp,
    ) -> i64 {
        let mut reader = client
            .open_leased_reader::<SourceData, (), Timestamp, StorageDiff>(
                shard,
                Arc::new(desc.clone()),
                Arc::new(UnitSchema),
                Diagnostics::from_purpose("snapshot"),
                true,
            )
            .await
            .expect("reader");
        let contents = reader
            .snapshot_and_fetch(Antichain::from_elem(as_of))
            .await
            .expect("snapshot");
        let total: i64 = contents.iter().map(|(_, _, d)| *d).sum();
        reader.expire().await;
        total
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn write_then_snapshot_counts() {
        let host = PersistHost::start(PersistLocation::new_in_mem())
            .await
            .expect("host");
        let client = host.client().await.expect("client");
        let shard = ShardId::new();
        let desc = sample_desc();
        let rows = sample_rows(1000, 16);
        write_rows_single_ts(&client, shard, &desc, &rows, Timestamp::from(0))
            .await
            .expect("write");

        let count = snapshot_diff_sum(&client, shard, &desc, Timestamp::from(0)).await;
        assert_eq!(count, 1000);
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn spread_write_snapshot_counts() {
        let host = PersistHost::start(PersistLocation::new_in_mem())
            .await
            .expect("host");
        let client = host.client().await.expect("client");
        let shard = ShardId::new();
        let desc = sample_desc();
        let rows = sample_rows(1000, 16);
        write_rows_spread(&client, shard, &desc, &rows, 8)
            .await
            .expect("spread write");

        // Updates land at times 0..8. `as_of` 7 is the last time below the new
        // upper (8), so the snapshot sees every row.
        let count = snapshot_diff_sum(&client, shard, &desc, Timestamp::from(7)).await;
        assert_eq!(count, 1000);
    }

    #[mz_ore::test]
    fn rows_for_bytes_basic() {
        // A zero-byte budget still yields one row.
        assert_eq!(rows_for_bytes(0, 16), 1);
        // A budget smaller than one row also yields one row.
        assert_eq!(rows_for_bytes(23, 0), 1);
        // 64 bytes of padding plus the fixed 24 gives 88 bytes per row.
        assert_eq!(rows_for_bytes(1_000_000, 64), 1_000_000 / 88);
    }

    #[mz_ore::test]
    fn synth_cell_values() {
        assert_eq!(
            synth_cell(&SqlScalarType::Int16, 0x8000 + 7, 0),
            Cell::Int16(7)
        );
        assert_eq!(
            synth_cell(&SqlScalarType::Int32, 0x8000_0000 + 9, 0),
            Cell::Int32(9)
        );
        assert_eq!(synth_cell(&SqlScalarType::Int64, 42, 0), Cell::Int64(42));
        assert_eq!(synth_cell(&SqlScalarType::Bool, 3, 0), Cell::Bool(false));
        assert_eq!(
            synth_cell(&SqlScalarType::Bytes, 12, 3),
            Cell::Bytes(b"012".to_vec())
        );
        // Numbers wider than `pad` are not truncated.
        assert_eq!(
            synth_cell(&SqlScalarType::String, 12345, 3),
            Cell::Str("12345".to_string())
        );
    }

    #[mz_ore::test]
    fn synth_rows_are_deterministic_and_padded() {
        let rows = sample_rows_from(3, 2, 4);
        assert_eq!(rows.len(), 2);
        assert_eq!(
            rows[0],
            pack_cells(&[Cell::Int64(3), Cell::Str("0003".to_string())])
        );
        assert_eq!(
            rows[1],
            pack_cells(&[Cell::Int64(4), Cell::Str("0004".to_string())])
        );
        assert_eq!(sample_rows(2, 4), sample_rows_from(0, 2, 4));
        assert!(sample_rows(0, 4).is_empty());
    }
}