1use std::cmp;
11use std::io::{BufRead, Read};
12use std::time::Duration;
13
14use anyhow::{Context, anyhow, bail};
15use byteorder::{NetworkEndian, WriteBytesExt};
16use futures::stream::{FuturesUnordered, StreamExt};
17use maplit::btreemap;
18use prost::Message;
19use prost_reflect::{DescriptorPool, DynamicMessage, MessageDescriptor};
20use rdkafka::message::{Header, OwnedHeaders};
21use rdkafka::producer::FutureRecord;
22use serde::de::DeserializeOwned;
23use tokio::fs;
24
25use crate::action::{self, ControlFlow, State};
26use crate::format::avro::{self, Schema};
27use crate::format::bytes;
28use crate::parser::BuiltinCommand;
29
/// Drain the in-flight produce futures every this many iterations, bounding
/// the number of outstanding Kafka records during a large `repeat` ingest.
const INGEST_BATCH_SIZE: isize = 10000;
31
32#[allow(clippy::disallowed_types)]
35fn extract_all_defined_types(
36 schema_json: &str,
37) -> anyhow::Result<std::collections::HashSet<String>> {
38 let value: serde_json::Value = serde_json::from_str(schema_json)
39 .context("parsing schema JSON to extract defined types")?;
40
41 let mut types = std::collections::HashSet::new();
42 collect_defined_types(&value, None, &mut types);
43 Ok(types)
44}
45
46#[allow(clippy::disallowed_types)]
48fn collect_defined_types(
49 value: &serde_json::Value,
50 parent_namespace: Option<&str>,
51 types: &mut std::collections::HashSet<String>,
52) {
53 match value {
54 serde_json::Value::Object(map) => {
55 let namespace = map
57 .get("namespace")
58 .and_then(|v| v.as_str())
59 .or(parent_namespace);
60
61 if let Some(type_val) = map.get("type")
63 && type_val
64 .as_str()
65 .is_some_and(|typ| ["record", "enum", "fixed"].contains(&typ))
66 {
67 if let Some(name) = map.get("name").and_then(|v| v.as_str()) {
68 let fullname = if name.contains('.') {
70 name.to_string()
71 } else if let Some(ns) = namespace {
72 format!("{}.{}", ns, name)
73 } else {
74 name.to_string()
75 };
76 types.insert(fullname);
77 }
78 }
79
80 for entity_type in &["type", "items", "values", "fields"] {
83 if let Some(val) = map.get(*entity_type) {
84 collect_defined_types(val, namespace, types);
85 }
86 }
87 }
88 serde_json::Value::Array(arr) => {
89 for item in arr {
90 collect_defined_types(item, parent_namespace, types);
91 }
92 }
93 _ => {}
94 }
95}
96
97#[allow(clippy::disallowed_types)]
100fn extract_type_references(schema_json: &str) -> anyhow::Result<std::collections::HashSet<String>> {
101 let value: serde_json::Value = serde_json::from_str(schema_json)
102 .context("parsing schema JSON to extract type references")?;
103
104 let mut references = std::collections::HashSet::new();
105 collect_type_references(&value, &mut references);
106 Ok(references)
107}
108
109#[allow(clippy::disallowed_types)]
111fn collect_type_references(
112 value: &serde_json::Value,
113 references: &mut std::collections::HashSet<String>,
114) {
115 match value {
116 serde_json::Value::String(s) => {
117 if s.contains('.')
119 && ![
120 "null", "boolean", "int", "long", "float", "double", "bytes", "string",
121 ]
122 .contains(&s.as_str())
123 {
124 references.insert(s.clone());
125 }
126 }
127 serde_json::Value::Object(map) => {
128 if let Some(type_val) = map.get("type")
131 && type_val
132 .as_str()
133 .is_some_and(|typ| ["record", "enum", "fixed"].contains(&typ))
134 {
135 if let Some(fields) = map.get("fields") {
136 collect_type_references(fields, references);
137 }
138 return;
139 }
140
141 for entity_type in &["type", "items", "values", "fields"] {
144 if let Some(val) = map.get(*entity_type) {
145 collect_type_references(val, references);
146 }
147 }
148 }
149 serde_json::Value::Array(arr) => {
150 for item in arr {
151 collect_type_references(item, references);
152 }
153 }
154 _ => {}
155 }
156}
157
/// The encoding requested for ingested rows, parsed from the testdrive
/// command's `format`/`key-format` arguments.
#[derive(Clone)]
enum Format {
    /// JSON rows transcoded to Avro against `schema`.
    Avro {
        /// Raw Avro schema JSON.
        schema: String,
        /// Whether to frame each datum in the Confluent wire format and
        /// publish the schema to the schema registry.
        confluent_wire_format: bool,
        /// Schema registry subjects the schema references (comma-separated
        /// `references`/`key-references` argument).
        references: Vec<String>,
    },
    /// JSON rows transcoded to protobuf via a compiled descriptor set.
    Protobuf {
        /// Path to the descriptor set, relative to the testdrive temp dir.
        descriptor_file: String,
        /// Fully qualified name of the message to encode.
        message: String,
        /// Whether to frame each message in the Confluent wire format.
        confluent_wire_format: bool,
        /// Subject to resolve the schema ID from when it differs from the
        /// default `<topic>-value`/`<topic>-key` subject.
        schema_id_subject: Option<String>,
        /// Message index written after the schema ID in Confluent framing.
        schema_message_id: u8,
    },
    /// Raw bytes, optionally split at a terminator byte.
    Bytes {
        terminator: Option<u8>,
    },
}
177
/// A ready-to-use encoder mapping one input row to the bytes produced to
/// Kafka, built from a [`Format`] by `make_transcoder`.
enum Transcoder {
    /// Bare Avro datum with no wire-format header.
    PlainAvro {
        schema: Schema,
    },
    /// Avro datum preceded by the Confluent wire-format header
    /// (magic byte 0 + 4-byte big-endian schema ID).
    ConfluentAvro {
        schema: Schema,
        schema_id: i32,
    },
    /// Encoded protobuf message, optionally preceded by the Confluent
    /// header plus a message index byte.
    Protobuf {
        message: MessageDescriptor,
        confluent_wire_format: bool,
        schema_id: i32,
        /// Message index written after the schema ID when framing.
        schema_message_id: u8,
    },
    /// Raw (escape-decoded) bytes, optionally read up to `terminator`.
    Bytes {
        terminator: Option<u8>,
    },
}
196
impl Transcoder {
    /// Reads at most one JSON value from `row`.
    ///
    /// Returns `Ok(None)` when the reader yields no value at all
    /// (e.g. an empty row).
    fn decode_json<R, T>(row: R) -> Result<Option<T>, anyhow::Error>
    where
        R: Read,
        T: DeserializeOwned,
    {
        let deserializer = serde_json::Deserializer::from_reader(row);
        deserializer
            .into_iter()
            .next()
            .transpose()
            .context("parsing json")
    }

    /// Encodes a single input row according to `self`, returning `Ok(None)`
    /// when the row contains no datum.
    fn transcode<R>(&self, mut row: R) -> Result<Option<Vec<u8>>, anyhow::Error>
    where
        R: BufRead,
    {
        match self {
            Transcoder::ConfluentAvro { schema, schema_id } => {
                if let Some(val) = Self::decode_json(row)? {
                    let val = avro::from_json(&val, schema.top_node())?;
                    let mut out = vec![];
                    // Confluent wire format header: magic byte 0 followed by
                    // the 4-byte big-endian schema ID, then the Avro datum.
                    // Writes into a Vec cannot fail, so unwrap is safe.
                    out.write_u8(0).unwrap();
                    out.write_i32::<NetworkEndian>(*schema_id).unwrap();
                    out.extend(avro::to_avro_datum(schema, val)?);
                    Ok(Some(out))
                } else {
                    Ok(None)
                }
            }
            Transcoder::PlainAvro { schema } => {
                // Same as above, but without the Confluent framing.
                if let Some(val) = Self::decode_json(row)? {
                    let val = avro::from_json(&val, schema.top_node())?;
                    let mut out = vec![];
                    out.extend(avro::to_avro_datum(schema, val)?);
                    Ok(Some(out))
                } else {
                    Ok(None)
                }
            }
            Transcoder::Protobuf {
                message,
                confluent_wire_format,
                schema_id,
                schema_message_id,
            } => {
                if let Some(val) = Self::decode_json::<_, serde_json::Value>(row)? {
                    let message = DynamicMessage::deserialize(message.clone(), val)
                        .context("parsing protobuf JSON")?;
                    let mut out = vec![];
                    if *confluent_wire_format {
                        // Confluent framing for protobuf: magic byte 0, the
                        // 4-byte big-endian schema ID, then the message index.
                        out.write_u8(0).unwrap();
                        out.write_i32::<NetworkEndian>(*schema_id).unwrap();
                        out.write_u8(*schema_message_id).unwrap();
                    }
                    message.encode(&mut out)?;
                    Ok(Some(out))
                } else {
                    Ok(None)
                }
            }
            Transcoder::Bytes { terminator } => {
                let mut out = vec![];
                match terminator {
                    Some(t) => {
                        // Read up to (and including) the terminator, then
                        // strip it if present.
                        row.read_until(*t, &mut out)?;
                        if out.last() == Some(t) {
                            out.pop();
                        }
                    }
                    None => {
                        row.read_to_end(&mut out)?;
                    }
                }
                if out.is_empty() {
                    Ok(None)
                } else {
                    // Decode testdrive byte escapes (e.g. \xNN) before producing.
                    Ok(Some(bytes::unescape(&out)?))
                }
            }
        }
    }
}
292
293pub async fn run_ingest(
294 mut cmd: BuiltinCommand,
295 state: &mut State,
296) -> Result<ControlFlow, anyhow::Error> {
297 let topic_prefix = format!("testdrive-{}", cmd.args.string("topic")?);
298 let partition = cmd.args.opt_parse::<i32>("partition")?;
299 let start_iteration = cmd.args.opt_parse::<isize>("start-iteration")?.unwrap_or(0);
300 let repeat = cmd.args.opt_parse::<isize>("repeat")?.unwrap_or(1);
301 let omit_key = cmd.args.opt_bool("omit-key")?.unwrap_or(false);
302 let omit_value = cmd.args.opt_bool("omit-value")?.unwrap_or(false);
303 let schema_id_var = cmd.args.opt_parse("set-schema-id-var")?;
304 let key_schema_id_var = cmd.args.opt_parse("set-key-schema-id-var")?;
305 let format = match cmd.args.string("format")?.as_str() {
306 "avro" => Format::Avro {
307 schema: cmd.args.string("schema")?,
308 confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(true),
309 references: cmd
311 .args
312 .opt_string("references")
313 .map(|s| s.split(',').map(|s| s.to_string()).collect())
314 .unwrap_or_default(),
315 },
316 "protobuf" => {
317 let descriptor_file = cmd.args.string("descriptor-file")?;
318 let message = cmd.args.string("message")?;
319 Format::Protobuf {
320 descriptor_file,
321 message,
322 confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(false),
325 schema_id_subject: cmd.args.opt_string("schema-id-subject"),
326 schema_message_id: cmd.args.opt_parse::<u8>("schema-message-id")?.unwrap_or(0),
327 }
328 }
329 "bytes" => Format::Bytes { terminator: None },
330 f => bail!("unknown format: {}", f),
331 };
332 let mut key_schema = cmd.args.opt_string("key-schema");
333 let key_format = match cmd.args.opt_string("key-format").as_deref() {
334 Some("avro") => Some(Format::Avro {
335 schema: key_schema.take().ok_or_else(|| {
336 anyhow!("key-schema parameter required when key-format is present")
337 })?,
338 confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(true),
339 references: cmd
340 .args
341 .opt_string("key-references")
342 .map(|s| s.split(',').map(|s| s.to_string()).collect())
343 .unwrap_or_default(),
344 }),
345 Some("protobuf") => {
346 let descriptor_file = cmd.args.string("key-descriptor-file")?;
347 let message = cmd.args.string("key-message")?;
348 Some(Format::Protobuf {
349 descriptor_file,
350 message,
351 confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(false),
352 schema_id_subject: cmd.args.opt_string("key-schema-id-subject"),
353 schema_message_id: cmd
354 .args
355 .opt_parse::<u8>("key-schema-message-id")?
356 .unwrap_or(0),
357 })
358 }
359 Some("bytes") => Some(Format::Bytes {
360 terminator: match cmd.args.opt_parse::<char>("key-terminator")? {
361 Some(c) => match u8::try_from(c) {
362 Ok(c) => Some(c),
363 Err(_) => bail!("key terminator must be single ASCII character"),
364 },
365 None => Some(b':'),
366 },
367 }),
368 Some(f) => bail!("unknown key format: {}", f),
369 None => None,
370 };
371 if key_schema.is_some() {
372 anyhow::bail!("key-schema specified without a matching key-format");
373 }
374
375 let timestamp = cmd.args.opt_parse("timestamp")?;
376
377 use serde_json::Value;
378 let headers = if let Some(headers_val) = cmd.args.opt_parse::<serde_json::Value>("headers")? {
379 let mut headers = Vec::new();
380 let headers_maps = match headers_val {
381 Value::Array(values) => {
382 let mut headers_map = Vec::new();
383 for value in values {
384 if let Value::Object(m) = value {
385 headers_map.push(m)
386 } else {
387 bail!("`headers` array values must be maps")
388 }
389 }
390 headers_map
391 }
392 Value::Object(v) => vec![v],
393 _ => bail!("`headers` must be a map or an array"),
394 };
395
396 for headers_map in headers_maps {
397 for (k, v) in headers_map.iter() {
398 headers.push((k.clone(), match v {
399 Value::String(val) => Some(val.as_bytes().to_vec()),
400 Value::Array(val) => {
401 let mut values = Vec::new();
402 for value in val {
403 if let Value::Number(int) = value {
404 values.push(u8::try_from(int.as_i64().unwrap()).unwrap())
405 } else {
406 bail!("`headers` value arrays must only contain numbers (to represent bytes)")
407 }
408 }
409 Some(values.clone())
410 },
411 Value::Null => None,
412 _ => bail!("`headers` must have string, int array or null values")
413 }));
414 }
415 }
416 Some(headers)
417 } else {
418 None
419 };
420
421 cmd.args.done()?;
422
423 if let Some(kf) = &key_format {
424 fn is_confluent_format(fmt: &Format) -> Option<bool> {
425 match fmt {
426 Format::Avro {
427 confluent_wire_format,
428 ..
429 } => Some(*confluent_wire_format),
430 Format::Protobuf {
431 confluent_wire_format,
432 ..
433 } => Some(*confluent_wire_format),
434 Format::Bytes { .. } => None,
435 }
436 }
437 match (is_confluent_format(kf), is_confluent_format(&format)) {
438 (Some(false), Some(true)) | (Some(true), Some(false)) => {
439 bail!(
440 "It does not make sense to have the key be in confluent format and not the value, or vice versa."
441 );
442 }
443 _ => {}
444 }
445 }
446
447 let topic_name = &format!("{}-{}", topic_prefix, state.seed);
448 println!(
449 "Ingesting data into Kafka topic {} with start_iteration = {}, repeat = {}",
450 topic_name, start_iteration, repeat
451 );
452
453 let set_schema_id_var = |state: &mut State, schema_id_var, transcoder| match transcoder {
454 &Transcoder::ConfluentAvro { schema_id, .. } | &Transcoder::Protobuf { schema_id, .. } => {
455 state.cmd_vars.insert(schema_id_var, schema_id.to_string());
456 }
457 _ => (),
458 };
459
460 let value_transcoder =
461 make_transcoder(state, format.clone(), format!("{}-value", topic_name)).await?;
462 if let Some(var) = schema_id_var {
463 set_schema_id_var(state, var, &value_transcoder);
464 }
465
466 let key_transcoder = match key_format.clone() {
467 None => None,
468 Some(f) => {
469 let transcoder = make_transcoder(state, f, format!("{}-key", topic_name)).await?;
470 if let Some(var) = key_schema_id_var {
471 set_schema_id_var(state, var, &transcoder);
472 }
473 Some(transcoder)
474 }
475 };
476
477 let mut futs = FuturesUnordered::new();
478
479 for iteration in start_iteration..(start_iteration + repeat) {
480 let iter = &mut cmd.input.iter().peekable();
481
482 for row in iter {
483 let row = action::substitute_vars(
484 row,
485 &btreemap! { "kafka-ingest.iteration".into() => iteration.to_string() },
486 &None,
487 false,
488 )?;
489 let mut row = row.as_bytes();
490 let key = match (omit_key, &key_transcoder) {
491 (true, _) => None,
492 (false, None) => None,
493 (false, Some(kt)) => kt.transcode(&mut row)?,
494 };
495 let value = if omit_value {
496 None
497 } else {
498 value_transcoder
499 .transcode(&mut row)
500 .with_context(|| format!("parsing row: {}", String::from_utf8_lossy(row)))?
501 };
502 let producer = &state.kafka_producer;
503 let timeout = cmp::max(state.default_timeout, Duration::from_secs(1));
504 let headers = headers.clone();
505 futs.push(async move {
506 let mut record: FutureRecord<_, _> = FutureRecord::to(topic_name);
507
508 if let Some(partition) = partition {
509 record = record.partition(partition);
510 }
511 if let Some(key) = &key {
512 record = record.key(key);
513 }
514 if let Some(value) = &value {
515 record = record.payload(value);
516 }
517 if let Some(timestamp) = timestamp {
518 record = record.timestamp(timestamp);
519 }
520 if let Some(headers) = headers {
521 let mut rd_meta = OwnedHeaders::new();
522 for (k, v) in &headers {
523 rd_meta = rd_meta.insert(Header {
524 key: k,
525 value: v.as_deref(),
526 });
527 }
528 record = record.headers(rd_meta);
529 }
530 producer.send(record, timeout).await
531 });
532 }
533
534 if iteration % INGEST_BATCH_SIZE == 0 || iteration == (start_iteration + repeat - 1) {
536 while let Some(res) = futs.next().await {
537 res.map_err(|(e, _message)| e)?;
538 }
539 }
540 }
541 Ok(ControlFlow::Continue)
542}
543
/// Builds the [`Transcoder`] for `format`, interacting with the schema
/// registry under `ccsr_subject` when the Confluent wire format is in use
/// (publishing the Avro schema, or looking up the protobuf schema ID).
async fn make_transcoder(
    state: &State,
    format: Format,
    ccsr_subject: String,
) -> Result<Transcoder, anyhow::Error> {
    match format {
        Format::Avro {
            schema,
            confluent_wire_format,
            references,
        } => {
            if confluent_wire_format {
                // (subject name, version, raw schema JSON, types it defines)
                #[allow(clippy::disallowed_types)]
                let mut reference_subjects = vec![];
                #[allow(clippy::disallowed_types)]
                let mut seen_subjects: std::collections::HashSet<String> =
                    std::collections::HashSet::new();
                let mut queue: Vec<String> = references.clone();

                // Transitively fetch every referenced subject (and the
                // subjects those references themselves depend on).
                while let Some(ref_name) = queue.pop() {
                    if seen_subjects.contains(&ref_name) {
                        continue;
                    }
                    seen_subjects.insert(ref_name.clone());

                    let (subject, ref_deps) = state
                        .ccsr_client
                        .get_subject_with_references(&ref_name)
                        .await
                        .with_context(|| format!("fetching reference {}", ref_name))?;

                    for dep in ref_deps {
                        if !seen_subjects.contains(&dep.subject) {
                            queue.push(dep.subject);
                        }
                    }

                    // Record which named types this subject defines so the
                    // main schema's references can be matched to subjects.
                    let defined_types = extract_all_defined_types(&subject.schema.raw)
                        .with_context(|| {
                            format!("extracting type names from reference schema {}", ref_name)
                        })?;
                    reference_subjects.push((
                        ref_name,
                        subject.version,
                        subject.schema.raw,
                        defined_types,
                    ));
                }

                // Reverse so that (transitive) dependencies come before the
                // schemas that use them — the parse loop below relies on
                // earlier entries to resolve names in later ones.
                reference_subjects.reverse();

                let direct_refs = extract_type_references(&schema)
                    .context("extracting type references from schema")?;

                // Map each type the main schema references to the subject
                // that defines it (first match wins).
                let mut schema_references = vec![];
                for type_name in &direct_refs {
                    for (subject_name, version, _, defined_types) in &reference_subjects {
                        if defined_types.contains(type_name) {
                            schema_references.push(mz_ccsr::SchemaReference {
                                name: type_name.clone(),
                                subject: subject_name.clone(),
                                version: *version,
                            });
                            break;
                        }
                    }
                }

                let reference_raw_schemas: Vec<_> = reference_subjects
                    .into_iter()
                    .map(|(_, _, raw, _)| raw)
                    .collect();

                let schema_id = state
                    .ccsr_client
                    .publish_schema(
                        &ccsr_subject,
                        &schema,
                        mz_ccsr::SchemaType::Avro,
                        &schema_references,
                    )
                    .await
                    .context("publishing to schema registry")?;

                let schema = if reference_raw_schemas.is_empty() {
                    avro::parse_schema(&schema, &[])
                        .with_context(|| format!("parsing avro schema: {}", schema))?
                } else {
                    // Parse references in dependency order so each can
                    // resolve named types from the ones parsed before it.
                    let mut parsed_refs: Vec<Schema> = vec![];
                    for raw in &reference_raw_schemas {
                        let schema_value: serde_json::Value = serde_json::from_str(raw)
                            .with_context(|| format!("parsing reference schema JSON: {}", raw))?;
                        let parsed = Schema::parse_with_references(&schema_value, &parsed_refs)
                            .with_context(|| format!("parsing reference avro schema: {}", raw))?;
                        parsed_refs.push(parsed);
                    }

                    // Finally parse the main schema against all references.
                    let schema_value: serde_json::Value = serde_json::from_str(&schema)
                        .with_context(|| format!("parsing schema JSON: {}", schema))?;
                    Schema::parse_with_references(&schema_value, &parsed_refs).with_context(
                        || format!("parsing avro schema with references: {}", schema),
                    )?
                };

                Ok::<_, anyhow::Error>(Transcoder::ConfluentAvro { schema, schema_id })
            } else {
                let schema = avro::parse_schema(&schema, &[])
                    .with_context(|| format!("parsing avro schema: {}", schema))?;
                Ok(Transcoder::PlainAvro { schema })
            }
        }
        Format::Protobuf {
            descriptor_file,
            message,
            confluent_wire_format,
            schema_id_subject,
            schema_message_id,
        } => {
            // For protobuf the schema must already be registered; only its
            // ID is fetched (0 when no Confluent framing is requested).
            let schema_id = if confluent_wire_format {
                state
                    .ccsr_client
                    .get_schema_by_subject(schema_id_subject.as_deref().unwrap_or(&ccsr_subject))
                    .await
                    .context("fetching schema from registry")?
                    .id
            } else {
                0
            };

            let bytes = fs::read(state.temp_path.join(descriptor_file))
                .await
                .context("reading protobuf descriptor file")?;
            let fd = DescriptorPool::decode(&*bytes).context("parsing protobuf descriptor file")?;
            let message = fd
                .get_message_by_name(&message)
                .ok_or_else(|| anyhow!("unknown message name {}", message))?;
            Ok(Transcoder::Protobuf {
                message,
                confluent_wire_format,
                schema_id,
                schema_message_id,
            })
        }
        Format::Bytes { terminator } => Ok(Transcoder::Bytes { terminator }),
    }
}