use std::collections::BTreeMap;
use std::convert::Infallible;
use std::sync::Arc;
use differential_dataflow::AsCollection;
use futures::stream::StreamExt;
use mz_ore::cast::CastFrom;
use mz_ore::iter::IteratorExt;
use mz_repr::{Datum, Diff, Row};
use mz_storage_types::errors::DataflowError;
use mz_storage_types::sources::load_generator::{KeyValueLoadGenerator, LoadGeneratorOutput};
use mz_storage_types::sources::{MzOffset, SourceTimestamp};
use mz_timely_util::builder_async::{OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton};
use mz_timely_util::containers::stack::AccountedStackBuilder;
use rand::rngs::StdRng;
use rand::{RngCore, SeedableRng};
use timely::container::CapacityContainerBuilder;
use timely::dataflow::operators::{Concat, ToStream};
use timely::dataflow::{Scope, Stream};
use timely::progress::Antichain;
use tracing::info;
use crate::healthcheck::{HealthStatusMessage, HealthStatusUpdate, StatusNamespace};
use crate::source::types::{ProgressStatisticsUpdate, SignaledFuture, StackedCollection};
use crate::source::{RawSourceCreationConfig, SourceMessage};
/// Renders the key-value load generator source into `scope`.
///
/// Two operators are built: the statistics operator from
/// [`render_statistics_operator`] (fed by `committed_uppers`) and the main
/// async operator, which has a data output, a progress output and a
/// statistics output.
///
/// The main operator waits for `start_signal`, then (when resuming from
/// offset 0) emits the transactional snapshot for every partition this worker
/// is responsible for, and afterwards emits update batches for those
/// partitions.
///
/// Returns, in order: the data collection, the progress stream (always
/// `Some`), a health stream containing a single `running` status, the
/// statistics stream (the main operator's statistics concatenated with the
/// statistics operator's), and the shutdown buttons of both operators.
///
/// The operator panics if `output_map` has no entry for
/// `LoadGeneratorOutput::Default`.
pub fn render<G: Scope<Timestamp = MzOffset>>(
key_value: KeyValueLoadGenerator,
scope: &mut G,
config: RawSourceCreationConfig,
committed_uppers: impl futures::Stream<Item = Antichain<MzOffset>> + 'static,
start_signal: impl std::future::Future<Output = ()> + 'static,
output_map: BTreeMap<LoadGeneratorOutput, Vec<usize>>,
) -> (
StackedCollection<G, (usize, Result<SourceMessage, DataflowError>)>,
Option<Stream<G, Infallible>>,
Stream<G, HealthStatusMessage>,
Stream<G, ProgressStatisticsUpdate>,
Vec<PressOnDropButton>,
) {
// Steady-state statistics are reported by a separate operator.
let (steady_state_stats_stream, stats_button) =
render_statistics_operator(scope, config.clone(), committed_uppers);
let mut builder = AsyncOperatorBuilder::new(config.name.clone(), scope.clone());
// The progress output handle is never written to; only its capability is
// downgraded below.
let (data_output, stream) = builder.new_output::<AccountedStackBuilder<_>>();
let (_progress_output, progress_stream) = builder.new_output::<CapacityContainerBuilder<_>>();
let (stats_output, stats_stream) = builder.new_output::<CapacityContainerBuilder<_>>();
let busy_signal = Arc::clone(&config.busy_signal);
let button = builder.build(move |caps| {
SignaledFuture::new(busy_signal, async move {
// Exactly three capabilities are expected; they are used below with the
// data, progress and statistics outputs respectively.
let [mut cap, mut progress_cap, stats_cap]: [_; 3] = caps.try_into().unwrap();
// Combine the resume uppers of all outputs into a single frontier.
let resume_upper = Antichain::from_iter(
config
.source_resume_uppers
.values()
.flat_map(|f| f.iter().map(MzOffset::decode_row)),
);
// Without a resume offset there is nothing to produce, so the operator
// exits immediately.
let Some(resume_offset) = resume_upper.into_option() else {
return;
};
let output_indexes = output_map
.get(&LoadGeneratorOutput::Default)
.expect("default output");
info!(
?config.worker_id,
"starting key-value load generator at {}",
resume_offset.offset,
);
cap.downgrade(&resume_offset);
progress_cap.downgrade(&resume_offset);
// No data is produced until the start signal resolves.
start_signal.await;
info!(?config.worker_id, "received key-value load generator start signal");
// Resuming at offset 0 means the snapshot still has to be produced.
let snapshotting = resume_offset.offset == 0;
// One snapshot producer per partition this worker is responsible for.
let mut local_partitions: Vec<_> = (0..key_value.partitions)
.filter_map(|p| {
config
.responsible_for(p)
.then(|| TransactionalSnapshotProducer::new(p, key_value.clone()))
})
.collect();
// The worker responsible for partition 0 also reports the steady-state
// statistics emitted from this operator.
let stats_worker = config.responsible_for(0);
// Workers without any partition report an empty snapshot and exit.
if local_partitions.is_empty() {
stats_output.give(
&stats_cap,
ProgressStatisticsUpdate::Snapshot {
records_known: 0,
records_staged: 0,
},
);
return;
}
// Local partitions * keys per partition * snapshot rounds.
let local_snapshot_size = (u64::cast_from(local_partitions.len()))
* key_value.keys
* key_value.transactional_snapshot_rounds()
/ key_value.partitions;
// Value buffer reused for every generated message.
let mut value_buffer: Vec<u8> = vec![0; usize::cast_from(key_value.value_size)];
let mut upper_offset = if snapshotting {
let snapshot_rounds = key_value.transactional_snapshot_rounds();
if stats_worker {
stats_output.give(
&stats_cap,
ProgressStatisticsUpdate::SteadyState {
offset_known: snapshot_rounds,
offset_committed: 0,
},
);
};
// The progress capability moves to the end of the snapshot before any
// data is emitted; the data capability stays put until the snapshot is
// complete.
progress_cap.downgrade(&MzOffset::from(snapshot_rounds));
let mut emitted = 0;
stats_output.give(
&stats_cap,
ProgressStatisticsUpdate::Snapshot {
records_known: local_snapshot_size,
records_staged: emitted,
},
);
// Take one batch from each partition in turn until every partition has
// finished all of its rounds, reporting progress after every batch.
while local_partitions.iter().any(|si| !si.finished()) {
for sp in local_partitions.iter_mut() {
for u in sp.produce_batch(&mut value_buffer, output_indexes) {
data_output.give_fueled(&cap, u).await;
emitted += 1;
}
stats_output.give(
&stats_cap,
ProgressStatisticsUpdate::Snapshot {
records_known: local_snapshot_size,
records_staged: emitted,
},
);
}
}
// The snapshot is complete: release all snapshot offsets at once.
cap.downgrade(&MzOffset::from(snapshot_rounds));
snapshot_rounds
} else {
// The snapshot was already produced; continue from the resume offset.
cap.downgrade(&resume_offset);
progress_cap.downgrade(&resume_offset);
resume_offset.offset
};
// Replace the snapshot producers with update producers that start at
// `upper_offset`.
let mut local_partitions: Vec<_> = (0..key_value.partitions)
.filter_map(|p| {
config
.responsible_for(p)
.then(|| UpdateProducer::new(p, upper_offset, key_value.clone()))
})
.collect();
// Updates are only produced when a tick interval is configured or the
// snapshot is not transactional.
if !local_partitions.is_empty()
&& (key_value.tick_interval.is_some() || !key_value.transactional_snapshot)
{
let mut interval = key_value.tick_interval.map(tokio::time::interval);
loop {
// Once every partition has finished its quick updates, either pace
// further batches by the tick interval or, without one, stop.
if local_partitions.iter().all(|si| si.finished_quick()) {
if let Some(interval) = &mut interval {
interval.tick().await;
} else {
break;
}
}
// Emit one batch per partition, then advance both capabilities to the
// upper reported by the last producer.
for up in local_partitions.iter_mut() {
let (new_upper, iter) = up.produce_batch(&mut value_buffer, output_indexes);
upper_offset = new_upper;
for u in iter {
data_output.give_fueled(&cap, u).await;
}
}
cap.downgrade(&MzOffset::from(upper_offset));
progress_cap.downgrade(&MzOffset::from(upper_offset));
}
}
// Never complete. NOTE(review): presumably this keeps the capabilities
// alive so the source frontier does not advance to empty; confirm.
std::future::pending::<()>().await;
})
});
// The health stream carries a single `running` status for output 0.
let status = [HealthStatusMessage {
index: 0,
namespace: StatusNamespace::Generator,
update: HealthStatusUpdate::running(),
}]
.to_stream(scope);
let stats_stream = stats_stream.concat(&steady_state_stats_stream);
(
stream.as_collection(),
Some(progress_stream),
status,
stats_stream,
vec![button.press_on_drop(), stats_button],
)
}
/// An endless iterator over the keys of a single partition.
///
/// Starting at `next`, every step advances by `partitions` and wraps around
/// at `keys`.
struct PartitionKeyIterator {
    /// The partition this iterator belongs to.
    partition: u64,
    /// The stride between consecutive keys.
    partitions: u64,
    /// The size of the key space; keys wrap around modulo this value.
    keys: u64,
    /// The key that the next call to `next()` returns.
    next: u64,
}

impl PartitionKeyIterator {
    /// Creates an iterator whose first yielded key is `start_key`.
    ///
    /// # Panics
    ///
    /// Panics if `keys` is not a multiple of `partitions`.
    fn new(partition: u64, partitions: u64, keys: u64, start_key: u64) -> Self {
        assert_eq!(keys % partitions, 0);
        Self {
            partition,
            partitions,
            keys,
            next: start_key,
        }
    }
}

// Implemented on `&mut PartitionKeyIterator` so that adaptors such as `take`
// borrow the iterator instead of consuming it, leaving its position
// observable afterwards.
impl Iterator for &mut PartitionKeyIterator {
    type Item = u64;

    /// Yields the current key and steps to the following one; never returns
    /// `None`.
    fn next(&mut self) -> Option<u64> {
        let following = (self.next + self.partitions) % self.keys;
        Some(std::mem::replace(&mut self.next, following))
    }
}
/// Builds an rng whose seed is fully determined by its three arguments.
///
/// The 32-byte seed holds the little-endian bytes of `source_seed`, `offset`
/// and `partition` in its first 24 bytes; the last 8 bytes stay zero.
fn create_consistent_rng(source_seed: u64, offset: u64, partition: u64) -> StdRng {
    let mut seed = [0u8; 32];
    // Four 8-byte chunks zipped with three words: the last chunk is untouched.
    for (chunk, word) in seed
        .chunks_exact_mut(8)
        .zip([source_seed, offset, partition])
    {
        chunk.copy_from_slice(&word.to_le_bytes());
    }
    StdRng::from_seed(seed)
}
/// Produces the transactional snapshot of a single partition, one batch at a
/// time, round after round.
struct TransactionalSnapshotProducer {
    /// Iterator over the keys of this producer's partition.
    pi: PartitionKeyIterator,
    /// Number of keys taken from `pi` for each batch.
    batch_size: u64,
    /// Batches produced so far in the current round.
    produced_batches: u64,
    /// Batches that make up one round (`keys / partitions / batch_size`).
    expected_batches: u64,
    /// The round currently being produced; also the offset attached to the
    /// updates of that round.
    round: u64,
    /// Number of rounds after which the producer is finished.
    snapshot_rounds: u64,
    /// Value generator, created on the first call to `produce_batch` and
    /// reused by all later calls.
    rng: Option<StdRng>,
    /// Source-level seed fed into the rng.
    seed: u64,
    /// Whether the round is packed into the message metadata.
    include_offset: bool,
}

impl TransactionalSnapshotProducer {
    /// Creates the snapshot producer for `partition`, starting at round 0 with
    /// the partition number as the first key.
    ///
    /// # Panics
    ///
    /// Panics if the keys of one partition do not divide evenly into batches,
    /// or if `keys` is not a multiple of `partitions`.
    fn new(partition: u64, key_value: KeyValueLoadGenerator) -> Self {
        let snapshot_rounds = key_value.transactional_snapshot_rounds();
        let keys_per_partition = key_value.keys / key_value.partitions;
        assert_eq!(keys_per_partition % key_value.batch_size, 0);

        Self {
            pi: PartitionKeyIterator::new(
                partition,
                key_value.partitions,
                key_value.keys,
                partition,
            ),
            batch_size: key_value.batch_size,
            produced_batches: 0,
            expected_batches: keys_per_partition / key_value.batch_size,
            round: 0,
            snapshot_rounds,
            rng: None,
            seed: key_value.seed,
            include_offset: key_value.include_offset.is_some(),
        }
    }

    /// Returns true once all snapshot rounds have been produced.
    fn finished(&self) -> bool {
        self.snapshot_rounds <= self.round
    }

    /// Produces the next batch of snapshot updates.
    ///
    /// Every key of the batch yields one update per entry of
    /// `output_indexes`, all timestamped with the round the batch belongs to.
    /// A finished producer returns an empty iterator. The round bookkeeping is
    /// advanced by this call itself, not by consuming the returned iterator.
    fn produce_batch<'a>(
        &'a mut self,
        value_buffer: &'a mut Vec<u8>,
        output_indexes: &'a [usize],
    ) -> impl Iterator<
        Item = (
            (usize, Result<SourceMessage, DataflowError>),
            MzOffset,
            Diff,
        ),
    > + 'a {
        // Capture everything that describes *this* batch before the
        // bookkeeping below moves on to the next one.
        let exhausted = self.finished();
        let seed = self.seed;
        let round = self.round;
        let partition = self.pi.partition;
        let include_offset = self.include_offset;
        let batch_len = if exhausted {
            0
        } else {
            usize::cast_from(self.batch_size)
        };

        if !exhausted {
            self.produced_batches += 1;
            if self.produced_batches == self.expected_batches {
                self.produced_batches = 0;
                self.round += 1;
            }
        }

        // The rng is only seeded once; later batches continue its stream.
        let rng = self
            .rng
            .get_or_insert_with(|| create_consistent_rng(seed, round, partition));

        self.pi.take(batch_len).flat_map(move |key| {
            rng.fill_bytes(value_buffer.as_mut_slice());
            let metadata = if include_offset {
                Row::pack(&[Datum::UInt64(round)])
            } else {
                Row::default()
            };
            let message = Ok(SourceMessage {
                key: Row::pack_slice(&[Datum::UInt64(key)]),
                value: Row::pack_slice(&[Datum::UInt64(partition), Datum::Bytes(value_buffer)]),
                metadata,
            });
            output_indexes
                .iter()
                .repeat_clone(message)
                .map(move |(output, message)| ((*output, message), MzOffset::from(round), 1))
        })
    }
}
/// Produces post-snapshot update batches for a single partition, one offset
/// per batch.
struct UpdateProducer {
    /// Iterator over the keys of this producer's partition.
    pi: PartitionKeyIterator,
    /// Number of keys taken from `pi` for each batch.
    batch_size: u64,
    /// The offset the next batch is produced at.
    next_offset: u64,
    /// Source-level seed fed into the per-batch rng.
    seed: u64,
    /// The offset at which the "quick" updates are complete.
    expected_quick_offsets: u64,
    /// Whether the offset is packed into the message metadata.
    include_offset: bool,
}

impl UpdateProducer {
    /// Creates the update producer for `partition`, positioned so that its
    /// next batch is produced at `next_offset`.
    ///
    /// `next_offset` must not be smaller than the number of transactional
    /// snapshot rounds; the subtraction below underflows otherwise.
    fn new(partition: u64, next_offset: u64, key_value: KeyValueLoadGenerator) -> Self {
        let snapshot_rounds = key_value.transactional_snapshot_rounds();
        let quick_rounds = key_value.non_transactional_snapshot_rounds();
        let keys = key_value.keys;
        let partitions = key_value.partitions;
        let batch_size = key_value.batch_size;

        // Number of update batches that precede `next_offset`. Each of them
        // moved the key position forward by `batch_size * partitions`.
        let batches_before = next_offset - snapshot_rounds;
        let start_key = (batches_before * batch_size * partitions + partition) % keys;

        // One quick round takes `keys / partitions / batch_size` offsets.
        let offsets_per_round = keys / partitions / batch_size;
        let expected_quick_offsets = offsets_per_round * quick_rounds + snapshot_rounds;

        Self {
            pi: PartitionKeyIterator::new(partition, partitions, keys, start_key),
            batch_size,
            next_offset,
            seed: key_value.seed,
            expected_quick_offsets,
            include_offset: key_value.include_offset.is_some(),
        }
    }

    /// Returns true once the next offset has reached the end of the quick
    /// updates.
    fn finished_quick(&self) -> bool {
        self.expected_quick_offsets <= self.next_offset
    }

    /// Produces the batch for the current offset and advances to the next one.
    ///
    /// Returns the new upper (the offset following the produced batch) along
    /// with the updates. Every key of the batch yields one update per entry
    /// of `output_indexes`. Values come from an rng seeded from the source
    /// seed, the batch offset and the partition.
    fn produce_batch<'a>(
        &'a mut self,
        value_buffer: &'a mut Vec<u8>,
        output_indexes: &'a [usize],
    ) -> (
        u64,
        impl Iterator<
            Item = (
                (usize, Result<SourceMessage, DataflowError>),
                MzOffset,
                Diff,
            ),
        > + 'a,
    ) {
        let offset = self.next_offset;
        let partition = self.pi.partition;
        let include_offset = self.include_offset;
        let batch_len = usize::cast_from(self.batch_size);
        let mut rng = create_consistent_rng(self.seed, offset, partition);

        let updates = self.pi.take(batch_len).flat_map(move |key| {
            rng.fill_bytes(value_buffer.as_mut_slice());
            let metadata = if include_offset {
                Row::pack(&[Datum::UInt64(offset)])
            } else {
                Row::default()
            };
            let message = Ok(SourceMessage {
                key: Row::pack_slice(&[Datum::UInt64(key)]),
                value: Row::pack_slice(&[Datum::UInt64(partition), Datum::Bytes(value_buffer)]),
                metadata,
            });
            output_indexes
                .iter()
                .repeat_clone(message)
                .map(move |(output, message)| ((*output, message), MzOffset::from(offset), 1))
        });

        // The offset advances immediately, regardless of whether the returned
        // iterator is ever consumed.
        self.next_offset += 1;
        (self.next_offset, updates)
    }
}
/// Renders the operator that turns committed uppers into steady-state
/// statistics.
///
/// Only the worker responsible for partition 0 follows `committed_uppers`:
/// for every frontier that carries an offset it reports that offset as both
/// the known and the committed offset, and it stops when the stream ends.
/// Every other worker emits a single all-zero update and exits.
///
/// Returns the statistics stream and the operator's shutdown button.
pub fn render_statistics_operator<G: Scope<Timestamp = MzOffset>>(
    scope: &G,
    config: RawSourceCreationConfig,
    committed_uppers: impl futures::Stream<Item = Antichain<MzOffset>> + 'static,
) -> (Stream<G, ProgressStatisticsUpdate>, PressOnDropButton) {
    let operator_name = format!("key_value_loadgen_statistics:{}", config.id);
    let mut builder = AsyncOperatorBuilder::new(operator_name, scope.clone());
    let (stats_output, stats_stream) = builder.new_output();

    let button = builder.build(move |caps| async move {
        let [stats_cap]: [_; 1] = caps.try_into().unwrap();

        if !config.responsible_for(0) {
            stats_output.give(
                &stats_cap,
                ProgressStatisticsUpdate::SteadyState {
                    offset_known: 0,
                    offset_committed: 0,
                },
            );
            return;
        }

        tokio::pin!(committed_uppers);
        while let Some(frontier) = committed_uppers.next().await {
            // Frontiers without an offset are skipped.
            let Some(committed) = frontier.as_option() else {
                continue;
            };
            stats_output.give(
                &stats_cap,
                ProgressStatisticsUpdate::SteadyState {
                    offset_known: committed.offset,
                    offset_committed: committed.offset,
                },
            );
        }
    });

    (stats_stream, button.press_on_drop())
}
#[cfg(test)]
mod test {
    use super::*;

    /// The generator configuration shared by the resume tests: 126 keys over
    /// 3 partitions, batches of 2 keys, 2 transactional snapshot rounds.
    fn test_generator() -> KeyValueLoadGenerator {
        KeyValueLoadGenerator {
            keys: 126,
            snapshot_rounds: 2,
            transactional_snapshot: true,
            value_size: 1234,
            partitions: 3,
            tick_interval: None,
            batch_size: 2,
            seed: 1234,
            include_offset: None,
        }
    }

    #[mz_ore::test]
    fn test_key_value_loadgen_resume_upper() {
        let generator = test_generator();

        // Resuming partition 1 at offset 5 must start at key 19.
        let resumed = UpdateProducer::new(1, 5, generator.clone());
        assert_eq!(resumed.pi.next, 19);

        // 126 / 2 offsets later the key position has advanced by
        // 63 * 2 * 3 = 378 = 3 * 126 keys, so it wraps to the same start key.
        let wrapped = UpdateProducer::new(1, 5 + 126 / 2, generator);
        assert_eq!(wrapped.pi.next, 19);
    }

    #[mz_ore::test]
    fn test_key_value_loadgen_part_iter() {
        // Partition 1 of 3 over 126 keys, starting at key 1.
        let mut pi = PartitionKeyIterator::new(1, 3, 126, 1);

        let first_two: Vec<u64> = pi.take(2).collect();
        assert_eq!(first_two, vec![1, 4]);

        // Consuming the rest of the rotation must wrap back to the start key.
        let remaining = (126 / 3) - 2;
        let _ = pi.take(remaining).count();
        assert_eq!(pi.next, 1);
    }
}