kgen/
kgen.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::convert::{TryFrom, TryInto};
12use std::iter;
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::time::Duration;
15
16use anyhow::bail;
17use chrono::DateTime;
18use crossbeam::thread;
19use mz_avro::Schema;
20use mz_avro::schema::{SchemaNode, SchemaPiece, SchemaPieceOrNamed};
21use mz_avro::types::{DecimalValue, Value};
22use mz_kafka_util::client::MzClientContext;
23use mz_ore::cast::CastFrom;
24use mz_ore::cli::{self, CliConfig};
25use mz_ore::retry::Retry;
26use rand::distributions::uniform::SampleUniform;
27use rand::distributions::{Alphanumeric, Bernoulli, Uniform, WeightedIndex};
28use rand::prelude::{Distribution, ThreadRng};
29use rand::thread_rng;
30use rdkafka::error::KafkaError;
31use rdkafka::producer::{BaseRecord, Producer, ThreadedProducer};
32use rdkafka::types::RDKafkaErrorCode;
33use rdkafka::util::Timeout;
34use serde_json::Map;
35use url::Url;
36
/// A random value generator: a callable that takes the thread-local RNG and
/// produces a value of type `R`.
///
/// `FnMut` closures are not `Clone` as trait objects, so this trait adds
/// `clone_box`, which lets `Box<dyn Generator<R>>` implement `Clone`.
trait Generator<R>: FnMut(&mut ThreadRng) -> R + Send + Sync {
    /// Returns a boxed copy of this generator.
    fn clone_box(&self) -> Box<dyn Generator<R>>;
}
40
41impl<F, R> Generator<R> for F
42where
43    F: FnMut(&mut ThreadRng) -> R + Clone + Send + Sync + 'static,
44{
45    fn clone_box(&self) -> Box<dyn Generator<R>> {
46        Box::new(self.clone())
47    }
48}
49
50impl<R> Clone for Box<dyn Generator<R>>
51where
52    R: 'static,
53{
54    fn clone(&self) -> Box<dyn Generator<R>> {
55        (**self).clone_box()
56    }
57}
58
/// Generates random Avro values that conform to a schema, using one
/// registered generator per schema piece.
#[derive(Clone)]
struct RandomAvroGenerator<'a> {
    // Generator functions for each piece of the schema. These map keys are
    // morally `*const SchemaPiece`s, but represented as `usize`s so that they
    // implement `Send`.
    // Looked up by `gen_inner` for `Int` and `Date` pieces.
    ints: BTreeMap<usize, Box<dyn Generator<i32>>>,
    // Looked up by `gen_inner` for `Long`, `TimestampMilli` and
    // `TimestampMicro` pieces.
    longs: BTreeMap<usize, Box<dyn Generator<i64>>>,
    // Produces the raw bytes of a string; `gen_inner` converts them with
    // `String::from_utf8`.
    strings: BTreeMap<usize, Box<dyn Generator<Vec<u8>>>>,
    bytes: BTreeMap<usize, Box<dyn Generator<Vec<u8>>>>,
    // Produces the index of the union variant to generate.
    unions: BTreeMap<usize, Box<dyn Generator<usize>>>,
    // Produces the index of the enum symbol to generate.
    enums: BTreeMap<usize, Box<dyn Generator<usize>>>,
    bools: BTreeMap<usize, Box<dyn Generator<bool>>>,
    floats: BTreeMap<usize, Box<dyn Generator<f32>>>,
    doubles: BTreeMap<usize, Box<dyn Generator<f64>>>,
    // Produces the unscaled value of a decimal as raw bytes.
    decimals: BTreeMap<usize, Box<dyn Generator<Vec<u8>>>>,
    // Produces the number of elements to generate for an array.
    array_lens: BTreeMap<usize, Box<dyn Generator<usize>>>,

    // Root node of the schema; `generate` starts here.
    schema: SchemaNode<'a>,
}
78
79impl<'a> RandomAvroGenerator<'a> {
80    fn gen_inner(&mut self, node: SchemaNode, rng: &mut ThreadRng) -> Value {
81        // TODO(benesch): rewrite to avoid `as`.
82        #[allow(clippy::as_conversions)]
83        let p = &*node.inner as *const _ as usize;
84        match node.inner {
85            SchemaPiece::Null => Value::Null,
86            SchemaPiece::Boolean => {
87                let val = self.bools.get_mut(&p).unwrap()(rng);
88                Value::Boolean(val)
89            }
90            SchemaPiece::Int => {
91                let val = self.ints.get_mut(&p).unwrap()(rng);
92                Value::Int(val)
93            }
94            SchemaPiece::Long => {
95                let val = self.longs.get_mut(&p).unwrap()(rng);
96                Value::Long(val)
97            }
98            SchemaPiece::Float => {
99                let val = self.floats.get_mut(&p).unwrap()(rng);
100                Value::Float(val)
101            }
102            SchemaPiece::Double => {
103                let val = self.doubles.get_mut(&p).unwrap()(rng);
104                Value::Double(val)
105            }
106            SchemaPiece::Date => {
107                let days = self.ints.get_mut(&p).unwrap()(rng);
108                Value::Date(days)
109            }
110            SchemaPiece::TimestampMilli => {
111                let millis = self.longs.get_mut(&p).unwrap()(rng);
112
113                let seconds = millis / 1000;
114                // TODO(benesch): rewrite to avoid `as`.
115                #[allow(clippy::as_conversions)]
116                let fraction = (millis % 1000) as u32;
117                let val = DateTime::from_timestamp(seconds, fraction * 1_000_000).unwrap();
118                Value::Timestamp(val.naive_utc())
119            }
120            SchemaPiece::TimestampMicro => {
121                let micros = self.longs.get_mut(&p).unwrap()(rng);
122
123                let seconds = micros / 1_000_000;
124                // TODO(benesch): rewrite to avoid `as`.
125                #[allow(clippy::as_conversions)]
126                let fraction = (micros % 1_000_000) as u32;
127                let val = DateTime::from_timestamp(seconds, fraction * 1_000).unwrap();
128                Value::Timestamp(val.naive_utc())
129            }
130            SchemaPiece::Decimal {
131                precision,
132                scale,
133                fixed_size: _,
134            } => {
135                let unscaled = self.decimals.get_mut(&p).unwrap()(rng);
136                Value::Decimal(DecimalValue {
137                    unscaled,
138                    precision: *precision,
139                    scale: *scale,
140                })
141            }
142            SchemaPiece::Bytes => {
143                let val = self.bytes.get_mut(&p).unwrap()(rng);
144                Value::Bytes(val)
145            }
146            SchemaPiece::String => {
147                let buf = self.strings.get_mut(&p).unwrap()(rng);
148                let val = String::from_utf8(buf).unwrap();
149                Value::String(val)
150            }
151            SchemaPiece::Json => unreachable!(),
152            SchemaPiece::Uuid => unreachable!(),
153            SchemaPiece::Array(inner) => {
154                let len = self.array_lens.get_mut(&p).unwrap()(rng);
155                let next = node.step(&**inner);
156                let inner_vals = (0..len).map(move |_| self.gen_inner(next, rng)).collect();
157                Value::Array(inner_vals)
158            }
159            SchemaPiece::Map(_inner) => {
160                // let len = self.array_lens.get_mut(&p).unwrap()();
161                // let key_f = self.map_keys.get_mut(&p).unwrap();
162                // let next = node.step(&**inner);
163                // let inner_entries = (0..len)
164                //     .map(|_| {
165                //         let mut key_buf = vec![];
166                //         key_f(&mut key_buf);
167                //         let key = String::from_utf8(key_buf).unwrap();
168                //         let val = self.gen_inner(next);
169                //         (key, val)
170                //     })
171                //     .collect();
172                // Value::Map(inner_entries)
173                unreachable!()
174            }
175            SchemaPiece::Union(us) => {
176                let index = self.unions.get_mut(&p).unwrap()(rng);
177                let next = node.step(&us.variants()[index]);
178                let null_variant = us
179                    .variants()
180                    .iter()
181                    .position(|v| v == &SchemaPieceOrNamed::Piece(SchemaPiece::Null));
182                let inner = Box::new(self.gen_inner(next, rng));
183                Value::Union {
184                    index,
185                    inner,
186                    n_variants: us.variants().len(),
187                    null_variant,
188                }
189            }
190            SchemaPiece::ResolveIntTsMilli
191            | SchemaPiece::ResolveIntTsMicro
192            | SchemaPiece::ResolveDateTimestamp
193            | SchemaPiece::ResolveIntLong
194            | SchemaPiece::ResolveIntFloat
195            | SchemaPiece::ResolveIntDouble
196            | SchemaPiece::ResolveLongFloat
197            | SchemaPiece::ResolveLongDouble
198            | SchemaPiece::ResolveFloatDouble
199            | SchemaPiece::ResolveConcreteUnion { .. }
200            | SchemaPiece::ResolveUnionUnion { .. }
201            | SchemaPiece::ResolveUnionConcrete { .. }
202            | SchemaPiece::ResolveRecord { .. }
203            | SchemaPiece::ResolveEnum { .. } => {
204                unreachable!("We never resolve schemas, so seeing this is impossible")
205            }
206            SchemaPiece::Record { fields, .. } => {
207                let fields = fields
208                    .iter()
209                    .map(|f| {
210                        let k = f.name.clone();
211                        let next = node.step(&f.schema);
212                        let v = self.gen_inner(next, rng);
213                        (k, v)
214                    })
215                    .collect();
216                Value::Record(fields)
217            }
218            SchemaPiece::Enum { symbols, .. } => {
219                let i = self.enums.get_mut(&p).unwrap()(rng);
220                Value::Enum(i, symbols[i].clone())
221            }
222            SchemaPiece::Fixed { size: _ } => unreachable!(),
223        }
224    }
225    pub fn generate(&mut self, rng: &mut ThreadRng) -> Value {
226        self.gen_inner(self.schema, rng)
227    }
228    fn new_inner(
229        &mut self,
230        node: SchemaNode<'a>,
231        annotations: &Map<String, serde_json::Value>,
232        field_name: Option<&str>,
233    ) {
234        fn bool_dist(
235            json: &serde_json::Value,
236        ) -> impl FnMut(&mut ThreadRng) -> bool + Clone + use<> {
237            let x = json.as_f64().unwrap();
238            let dist = Bernoulli::new(x).unwrap();
239            move |rng| dist.sample(rng)
240        }
241        fn integral_dist<T>(
242            json: &serde_json::Value,
243        ) -> impl FnMut(&mut ThreadRng) -> T + Clone + use<T>
244        where
245            T: SampleUniform + TryFrom<i64> + Clone,
246            T::Sampler: Clone,
247            <T as TryFrom<i64>>::Error: std::fmt::Debug,
248        {
249            let x = json.as_array().unwrap();
250            let (min, max): (T, T) = (
251                x[0].as_i64().unwrap().try_into().unwrap(),
252                x[1].as_i64().unwrap().try_into().unwrap(),
253            );
254            let dist = Uniform::new_inclusive(min, max);
255            move |rng| dist.sample(rng)
256        }
257        fn float_dist(
258            json: &serde_json::Value,
259        ) -> impl FnMut(&mut ThreadRng) -> f32 + Clone + use<> {
260            let x = json.as_array().unwrap();
261            // TODO(benesch): rewrite to avoid `as`.
262            #[allow(clippy::as_conversions)]
263            let (min, max) = (x[0].as_f64().unwrap() as f32, x[1].as_f64().unwrap() as f32);
264            let dist = Uniform::new_inclusive(min, max);
265            move |rng| dist.sample(rng)
266        }
267        fn double_dist(
268            json: &serde_json::Value,
269        ) -> impl FnMut(&mut ThreadRng) -> f64 + Clone + use<> {
270            let x = json.as_array().unwrap();
271            let (min, max) = (x[0].as_f64().unwrap(), x[1].as_f64().unwrap());
272            let dist = Uniform::new_inclusive(min, max);
273            move |rng| dist.sample(rng)
274        }
275        fn string_dist(
276            json: &serde_json::Value,
277        ) -> impl FnMut(&mut ThreadRng) -> Vec<u8> + Clone + use<> {
278            let mut len = integral_dist::<usize>(json);
279            move |rng| {
280                let len = len(rng);
281                let cd = Alphanumeric;
282                iter::repeat_with(|| cd.sample(rng)).take(len).collect()
283            }
284        }
285        fn bytes_dist(
286            json: &serde_json::Value,
287        ) -> impl FnMut(&mut ThreadRng) -> Vec<u8> + Clone + use<> {
288            let mut len = integral_dist::<usize>(json);
289            move |rng| {
290                let len = len(rng);
291                let bd = Uniform::new_inclusive(0, 255);
292                iter::repeat_with(|| bd.sample(rng)).take(len).collect()
293            }
294        }
295        fn decimal_dist(
296            json: &serde_json::Value,
297            precision: usize,
298        ) -> impl FnMut(&mut ThreadRng) -> Vec<u8> + Clone + use<> {
299            let x = json.as_array().unwrap();
300            let (min, max): (i64, i64) = (x[0].as_i64().unwrap(), x[1].as_i64().unwrap());
301            // Ensure values fit within precision bounds.
302            let precision_limit = 10i64
303                .checked_pow(u32::try_from(precision).unwrap())
304                .unwrap();
305            assert!(
306                precision_limit >= max,
307                "max value of {} exceeds value expressable with precision {}",
308                max,
309                precision
310            );
311            assert!(
312                precision_limit >= min.abs(),
313                "min value of {} exceeds value expressable with precision {}",
314                min,
315                precision
316            );
317            let dist = Uniform::<i64>::new_inclusive(min, max);
318            move |rng| dist.sample(rng).to_be_bytes().to_vec()
319        }
320        // TODO(benesch): rewrite to avoid `as`.
321        #[allow(clippy::as_conversions)]
322        let p = &*node.inner as *const _ as usize;
323
324        let dist_json = field_name.and_then(|fn_| annotations.get(fn_));
325        let err = format!(
326            "Distribution annotation not found: {}",
327            field_name.unwrap_or("(None)")
328        );
329        match node.inner {
330            SchemaPiece::Null => {}
331            SchemaPiece::Boolean => {
332                let dist = bool_dist(dist_json.expect(&err));
333                self.bools.insert(p, Box::new(dist));
334            }
335            SchemaPiece::Int => {
336                let dist = integral_dist(dist_json.expect(&err));
337                self.ints.insert(p, Box::new(dist));
338            }
339            SchemaPiece::Long => {
340                let dist = integral_dist(dist_json.expect(&err));
341                self.longs.insert(p, Box::new(dist));
342            }
343            SchemaPiece::Float => {
344                let dist = float_dist(dist_json.expect(&err));
345                self.floats.insert(p, Box::new(dist));
346            }
347            SchemaPiece::Double => {
348                let dist = double_dist(dist_json.expect(&err));
349                self.doubles.insert(p, Box::new(dist));
350            }
351            SchemaPiece::Date => {}
352            SchemaPiece::TimestampMilli => {}
353            SchemaPiece::TimestampMicro => {}
354            SchemaPiece::Decimal {
355                precision,
356                scale: _,
357                fixed_size: _,
358            } => {
359                let dist = decimal_dist(dist_json.expect(&err), *precision);
360                self.decimals.insert(p, Box::new(dist));
361            }
362            SchemaPiece::Bytes => {
363                let len_dist_json = annotations
364                    .get(&format!("{}.len", field_name.unwrap()))
365                    .unwrap();
366                let dist = bytes_dist(len_dist_json);
367                self.bytes.insert(p, Box::new(dist));
368            }
369            SchemaPiece::String => {
370                let len_dist_json = annotations
371                    .get(&format!("{}.len", field_name.unwrap()))
372                    .unwrap();
373                let dist = string_dist(len_dist_json);
374                self.strings.insert(p, Box::new(dist));
375            }
376            SchemaPiece::Json => unimplemented!(),
377            SchemaPiece::Uuid => unimplemented!(),
378            SchemaPiece::Array(inner) => {
379                let fn_ = field_name.unwrap();
380                let len_dist_json = annotations.get(&format!("{}.len", fn_)).unwrap();
381                let len = integral_dist::<usize>(len_dist_json);
382                self.array_lens.insert(p, Box::new(len));
383                let item_fn = format!("{}[]", fn_);
384                self.new_inner(node.step(&**inner), annotations, Some(&item_fn))
385            }
386            SchemaPiece::Map(_) => unimplemented!(),
387            SchemaPiece::Union(us) => {
388                let variant_jsons = dist_json.expect(&err).as_array().unwrap();
389                assert!(variant_jsons.len() == us.variants().len());
390                let probabilities = variant_jsons.iter().map(|v| v.as_f64().unwrap());
391                let dist = WeightedIndex::new(probabilities).unwrap();
392                let f = move |rng: &mut ThreadRng| dist.sample(rng);
393                self.unions.insert(p, Box::new(f));
394                let fn_ = field_name.unwrap();
395                for (i, v) in us.variants().iter().enumerate() {
396                    let fn_ = format!("{}.{}", fn_, i);
397                    self.new_inner(node.step(v), annotations, Some(&fn_))
398                }
399            }
400            SchemaPiece::Record {
401                doc: _,
402                fields,
403                lookup: _,
404            } => {
405                let name = node.name.unwrap();
406                for f in fields {
407                    let fn_ = format!("{}.{}::{}", name.namespace(), name.base_name(), f.name);
408                    self.new_inner(node.step(&f.schema), annotations, Some(&fn_));
409                }
410            }
411            SchemaPiece::Enum {
412                doc: _,
413                symbols: _,
414                default_idx: _,
415            } => unimplemented!(),
416            SchemaPiece::Fixed { size: _ } => unimplemented!(),
417            SchemaPiece::ResolveIntTsMilli
418            | SchemaPiece::ResolveIntTsMicro
419            | SchemaPiece::ResolveDateTimestamp
420            | SchemaPiece::ResolveIntLong
421            | SchemaPiece::ResolveIntFloat
422            | SchemaPiece::ResolveIntDouble
423            | SchemaPiece::ResolveLongFloat
424            | SchemaPiece::ResolveLongDouble
425            | SchemaPiece::ResolveFloatDouble
426            | SchemaPiece::ResolveConcreteUnion { .. }
427            | SchemaPiece::ResolveUnionUnion { .. }
428            | SchemaPiece::ResolveUnionConcrete { .. }
429            | SchemaPiece::ResolveRecord { .. }
430            | SchemaPiece::ResolveEnum { .. } => unreachable!(),
431        };
432    }
433    pub fn new(schema: &'a Schema, annotations: &serde_json::Value) -> Self {
434        let mut self_ = Self {
435            ints: Default::default(),
436            longs: Default::default(),
437            strings: Default::default(),
438            bytes: Default::default(),
439            unions: Default::default(),
440            enums: Default::default(),
441            bools: Default::default(),
442            floats: Default::default(),
443            doubles: Default::default(),
444            decimals: Default::default(),
445            array_lens: Default::default(),
446            schema: schema.top_node(),
447        };
448        self_.new_inner(schema.top_node(), annotations.as_object().unwrap(), None);
449        self_
450    }
451}
452
/// Produces the serialized bytes of a Kafka record key or value.
#[derive(Clone)]
enum ValueGenerator<'a> {
    /// Random bytes: `len` yields the number of bytes and `bytes` yields
    /// each individual byte.
    UniformBytes {
        len: Uniform<usize>,
        bytes: Uniform<u8>,
    },
    /// Random Avro data. The output is a zero byte, then `schema_id` in
    /// big-endian byte order, then the Avro encoding of a value generated by
    /// `inner` for `schema`.
    RandomAvro {
        inner: RandomAvroGenerator<'a>,
        schema: &'a Schema,
        schema_id: i32,
    },
}
465
466impl<'a> ValueGenerator<'a> {
467    pub fn next_value(&mut self, out: &mut Vec<u8>, rng: &mut ThreadRng) {
468        match self {
469            ValueGenerator::UniformBytes { len, bytes } => {
470                let len = len.sample(rng);
471                let sample = || bytes.sample(rng);
472                out.clear();
473                out.extend(iter::repeat_with(sample).take(len));
474            }
475            ValueGenerator::RandomAvro {
476                inner,
477                schema,
478                schema_id,
479            } => {
480                let value = inner.generate(rng);
481                out.clear();
482                out.push(0);
483                for b in schema_id.to_be_bytes().iter() {
484                    out.push(*b);
485                }
486                debug_assert!(value.validate(schema.top_node()));
487                mz_avro::encode_unchecked(&value, schema, out);
488            }
489        }
490    }
491}
492
// Format of the generated record keys, selected with `--keys`.
//
// NOTE: plain `//` comments are used on purpose; `///` comments on a
// `clap::ValueEnum` would be picked up as command-line help text.
#[derive(clap::ValueEnum, PartialEq, Debug, Clone)]
pub enum KeyFormat {
    // Keys are random Avro data generated from `--avro-key-schema` and
    // `--avro-key-distribution`.
    Avro,
    // Keys are big-endian `u64`s sampled uniformly from
    // `--key-min..=--key-max`.
    Random,
    // Keys are the big-endian `u64` value of a counter shared by all
    // producer threads.
    Sequential,
}
499
// Format of the generated record values, selected with `--values`.
//
// NOTE: plain `//` comments are used on purpose; `///` comments on a
// `clap::ValueEnum` would be picked up as command-line help text.
#[derive(clap::ValueEnum, PartialEq, Debug, Clone)]
pub enum ValueFormat {
    // Values are random bytes whose length is sampled uniformly from
    // `--min-message-size..=--max-message-size`.
    Bytes,
    // Values are random Avro data generated from `--avro-schema` and
    // `--avro-distribution`.
    Avro,
}
505
506/// Write random data to Kafka.
507#[derive(clap::Parser)]
508struct Args {
509    // == Kafka configuration arguments. ==
510    /// Address of one or more Kafka nodes, comma separated, in the Kafka
511    /// cluster to connect to.
512    #[clap(short = 'b', long, default_value = "localhost:9092")]
513    bootstrap_server: String,
514    /// URL of the schema registry to connect to, if using Avro keys or values.
515    #[clap(short = 's', long, default_value = "http://localhost:8081")]
516    schema_registry_url: Url,
517    /// Topic into which to write records.
518    #[clap(short = 't', long = "topic")]
519    topic: String,
520    /// Number of records to write.
521    #[clap(short = 'n', long = "num-records")]
522    num_records: usize,
523    /// Number of partitions over which records should be distributed in a
524    /// round-robin fashion, regardless of the value of the keys of these
525    /// records.
526    ///
527    /// The default value, 0, indicates that Kafka's default strategy of
528    /// distributing writes based upon the hash of their keys should be used
529    /// instead.
530    #[clap(long, default_value = "0")]
531    partitions_round_robin: usize,
532    /// The number of threads to use.
533    ///
534    /// If zero, uses the number of physical CPUs on the machine.
535    #[structopt(long, default_value = "0")]
536    threads: usize,
537
538    // == Key arguments. ==
539    /// Format in which to generate keys.
540    #[clap(
541        short = 'k',
542        long = "keys",
543        ignore_case = true,
544        value_enum,
545        default_value = "sequential"
546    )]
547    key_format: KeyFormat,
548    /// Minimum key value to generate, if using random-formatted keys.
549    #[clap(long, required_if_eq("key_format", "random"))]
550    key_min: Option<u64>,
551    /// Maximum key value to generate, if using random-formatted keys.
552    #[clap(long, required_if_eq("key_format", "random"))]
553    key_max: Option<u64>,
554    /// Schema describing Avro key data to randomly generate, if using
555    /// Avro-formatted keys.
556    #[clap(long, required_if_eq("key_format", "avro"))]
557    avro_key_schema: Option<Schema>,
558    /// JSON object describing the distribution parameters for each field of
559    /// the Avro key object, if using Avro-formatted keys.
560    #[clap(long, required_if_eq("key_format", "avro"))]
561    avro_key_distribution: Option<serde_json::Value>,
562
563    // == Value arguments. ==
564    /// Format in which to generate values.
565    #[clap(
566        short = 'v',
567        long = "values",
568        ignore_case = true,
569        value_enum,
570        default_value = "bytes"
571    )]
572    value_format: ValueFormat,
573    /// Minimum value size to generate, if using bytes-formatted values.
574    #[clap(
575        short = 'm',
576        long = "min-message-size",
577        required_if_eq("value_format", "bytes")
578    )]
579    min_value_size: Option<usize>,
580    /// Maximum value size to generate, if using bytes-formatted values.
581    #[clap(
582        short = 'M',
583        long = "max-message-size",
584        required_if_eq("value_format", "bytes")
585    )]
586    max_value_size: Option<usize>,
587    /// Schema describing Avro value data to randomly generate, if using
588    /// Avro-formatted values.
589    #[clap(long = "avro-schema", required_if_eq("value_format", "avro"))]
590    avro_value_schema: Option<Schema>,
591    /// JSON object describing the distribution parameters for each field of
592    /// the Avro value object, if using Avro-formatted keys.
593    #[clap(long = "avro-distribution", required_if_eq("value_format", "avro"))]
594    avro_value_distribution: Option<serde_json::Value>,
595
596    // == Output control. ==
597    /// Suppress printing progress messages.
598    #[clap(short = 'q', long)]
599    quiet: bool,
600}
601
602#[tokio::main]
603async fn main() -> anyhow::Result<()> {
604    let args: Args = cli::parse_args(CliConfig::default());
605
606    let value_gen = match args.value_format {
607        ValueFormat::Bytes => {
608            // Clap may one day be able to do this validation automatically.
609            // See: https://github.com/clap-rs/clap/discussions/2039
610            if args.avro_value_schema.is_some() {
611                bail!("cannot specify --avro-schema without --values=avro");
612            }
613            if args.avro_value_distribution.is_some() {
614                bail!("cannot specify --avro-distribution without --values=avro");
615            }
616            let len =
617                Uniform::new_inclusive(args.min_value_size.unwrap(), args.max_value_size.unwrap());
618            let bytes = Uniform::new_inclusive(0, 255);
619
620            ValueGenerator::UniformBytes { len, bytes }
621        }
622        ValueFormat::Avro => {
623            // Clap may one day be able to do this validation automatically.
624            // See: https://github.com/clap-rs/clap/discussions/2039
625            if args.min_value_size.is_some() {
626                bail!("cannot specify --min-message-size without --values=bytes");
627            }
628            if args.max_value_size.is_some() {
629                bail!("cannot specify --max-message-size without --values=bytes");
630            }
631            let value_schema = args.avro_value_schema.as_ref().unwrap();
632            let ccsr = mz_ccsr::ClientConfig::new(args.schema_registry_url.clone()).build()?;
633            let schema_id = ccsr
634                .publish_schema(
635                    &format!("{}-value", args.topic),
636                    &value_schema.to_string(),
637                    mz_ccsr::SchemaType::Avro,
638                    &[],
639                )
640                .await?;
641            let generator =
642                RandomAvroGenerator::new(value_schema, &args.avro_value_distribution.unwrap());
643            ValueGenerator::RandomAvro {
644                inner: generator,
645                schema: value_schema,
646                schema_id,
647            }
648        }
649    };
650
651    let key_gen = match args.key_format {
652        KeyFormat::Avro => {
653            // Clap may one day be able to do this validation automatically.
654            // See: https://github.com/clap-rs/clap/discussions/2039
655            if args.key_min.is_some() {
656                bail!("cannot specify --key-min without --keys=bytes");
657            }
658            if args.key_max.is_some() {
659                bail!("cannot specify --key-max without --keys=bytes");
660            }
661            let key_schema = args.avro_key_schema.as_ref().unwrap();
662            let ccsr = mz_ccsr::ClientConfig::new(args.schema_registry_url).build()?;
663            let key_schema_id = ccsr
664                .publish_schema(
665                    &format!("{}-key", args.topic),
666                    &key_schema.to_string(),
667                    mz_ccsr::SchemaType::Avro,
668                    &[],
669                )
670                .await?;
671            let generator =
672                RandomAvroGenerator::new(key_schema, &args.avro_key_distribution.unwrap());
673            Some(ValueGenerator::RandomAvro {
674                inner: generator,
675                schema: key_schema,
676                schema_id: key_schema_id,
677            })
678        }
679        _ => {
680            // Clap may one day be able to do this validation automatically.
681            // See: https://github.com/clap-rs/clap/discussions/2039
682            if args.avro_key_schema.is_some() {
683                bail!("cannot specify --avro-key-schema without --keys=avro");
684            }
685            if args.avro_key_distribution.is_some() {
686                bail!("cannot specify --avro-key-distribution without --keys=avro");
687            }
688            None
689        }
690    };
691    let key_dist = if let KeyFormat::Random = args.key_format {
692        Some(Uniform::new_inclusive(
693            args.key_min.unwrap(),
694            args.key_max.unwrap(),
695        ))
696    } else {
697        None
698    };
699
700    let threads = if args.threads == 0 {
701        num_cpus::get_physical()
702    } else {
703        args.threads
704    };
705    println!("Using {} threads...", threads);
706
707    let counter = AtomicUsize::new(0);
708    thread::scope(|scope| {
709        for thread in 0..threads {
710            let counter = &counter;
711            let topic = &args.topic;
712            let mut key_gen = key_gen.clone();
713            let mut value_gen = value_gen.clone();
714            let producer: ThreadedProducer<mz_kafka_util::client::MzClientContext> =
715                mz_kafka_util::client::create_new_client_config_simple()
716                    .set("bootstrap.servers", args.bootstrap_server.to_string())
717                    .create_with_context(MzClientContext::default())
718                    .unwrap();
719            let mut key_buf = vec![];
720            let mut value_buf = vec![];
721            let mut n = args.num_records / threads;
722            if thread < args.num_records % threads {
723                n += 1;
724            }
725            scope.spawn(move |_| {
726                let mut rng = thread_rng();
727                for _ in 0..n {
728                    let i = counter.fetch_add(1, Ordering::Relaxed);
729                    if !args.quiet && i % 100_000 == 0 {
730                        eprintln!("Generating message {}", i);
731                    }
732                    value_gen.next_value(&mut value_buf, &mut rng);
733                    if let Some(key_gen) = key_gen.as_mut() {
734                        key_gen.next_value(&mut key_buf, &mut rng);
735                    } else if let Some(key_dist) = key_dist.as_ref() {
736                        key_buf.clear();
737                        key_buf.extend(key_dist.sample(&mut rng).to_be_bytes().iter())
738                    } else {
739                        key_buf.clear();
740                        key_buf.extend(u64::cast_from(i).to_be_bytes().iter())
741                    };
742
743                    let mut rec = BaseRecord::to(topic).key(&key_buf).payload(&value_buf);
744                    // TODO(benesch): rewrite to avoid `as`.
745                    #[allow(clippy::as_conversions)]
746                    if args.partitions_round_robin != 0 {
747                        rec = rec.partition((i % args.partitions_round_robin) as i32);
748                    }
749                    let mut rec = Some(rec);
750
751                    Retry::default()
752                        .clamp_backoff(Duration::from_secs(1))
753                        .retry(|_| match producer.send(rec.take().unwrap()) {
754                            Ok(()) => Ok(()),
755                            Err((
756                                e @ KafkaError::MessageProduction(RDKafkaErrorCode::QueueFull),
757                                r,
758                            )) => {
759                                rec = Some(r);
760                                Err(e)
761                            }
762                            Err((e, _)) => panic!("unexpected Kafka error: {}", e),
763                        })
764                        .expect("unable to produce to Kafka");
765                }
766                producer.flush(Timeout::Never).unwrap();
767            });
768        }
769    })
770    .unwrap();
771
772    Ok(())
773}