1use std::cmp;
11use std::io::{BufRead, Read};
12use std::time::Duration;
13
14use anyhow::{Context, anyhow, bail};
15use byteorder::{NetworkEndian, WriteBytesExt};
16use futures::stream::{FuturesUnordered, StreamExt};
17use maplit::btreemap;
18use prost::Message;
19use prost_reflect::{DescriptorPool, DynamicMessage, MessageDescriptor};
20use rdkafka::message::{Header, OwnedHeaders};
21use rdkafka::producer::FutureRecord;
22use serde::de::DeserializeOwned;
23use tokio::fs;
24use uuid::Uuid;
25
26use crate::action::{self, ControlFlow, State};
27use crate::format::avro::{self, Schema};
28use crate::format::bytes;
29use crate::parser::BuiltinCommand;
30
/// Flush interval, in iterations, for in-flight producer sends.
///
/// `run_ingest` drains every outstanding send future whenever the current
/// iteration index is a multiple of this value, and again after the final
/// iteration.
const INGEST_BATCH_SIZE: isize = 10000;
32
33#[allow(clippy::disallowed_types)]
36fn extract_all_defined_types(
37 schema_json: &str,
38) -> anyhow::Result<std::collections::HashSet<String>> {
39 let value: serde_json::Value = serde_json::from_str(schema_json)
40 .context("parsing schema JSON to extract defined types")?;
41
42 let mut types = std::collections::HashSet::new();
43 collect_defined_types(&value, None, &mut types);
44 Ok(types)
45}
46
47#[allow(clippy::disallowed_types)]
49fn collect_defined_types(
50 value: &serde_json::Value,
51 parent_namespace: Option<&str>,
52 types: &mut std::collections::HashSet<String>,
53) {
54 match value {
55 serde_json::Value::Object(map) => {
56 let namespace = map
58 .get("namespace")
59 .and_then(|v| v.as_str())
60 .or(parent_namespace);
61
62 if let Some(type_val) = map.get("type")
64 && type_val
65 .as_str()
66 .is_some_and(|typ| ["record", "enum", "fixed"].contains(&typ))
67 {
68 if let Some(name) = map.get("name").and_then(|v| v.as_str()) {
69 let fullname = if name.contains('.') {
71 name.to_string()
72 } else if let Some(ns) = namespace {
73 format!("{}.{}", ns, name)
74 } else {
75 name.to_string()
76 };
77 types.insert(fullname);
78 }
79 }
80
81 for entity_type in &["type", "items", "values", "fields"] {
84 if let Some(val) = map.get(*entity_type) {
85 collect_defined_types(val, namespace, types);
86 }
87 }
88 }
89 serde_json::Value::Array(arr) => {
90 for item in arr {
91 collect_defined_types(item, parent_namespace, types);
92 }
93 }
94 _ => {}
95 }
96}
97
98#[allow(clippy::disallowed_types)]
101fn extract_type_references(schema_json: &str) -> anyhow::Result<std::collections::HashSet<String>> {
102 let value: serde_json::Value = serde_json::from_str(schema_json)
103 .context("parsing schema JSON to extract type references")?;
104
105 let mut references = std::collections::HashSet::new();
106 collect_type_references(&value, &mut references);
107 Ok(references)
108}
109
110#[allow(clippy::disallowed_types)]
112fn collect_type_references(
113 value: &serde_json::Value,
114 references: &mut std::collections::HashSet<String>,
115) {
116 match value {
117 serde_json::Value::String(s) => {
118 if s.contains('.')
120 && ![
121 "null", "boolean", "int", "long", "float", "double", "bytes", "string",
122 ]
123 .contains(&s.as_str())
124 {
125 references.insert(s.clone());
126 }
127 }
128 serde_json::Value::Object(map) => {
129 if let Some(type_val) = map.get("type")
132 && type_val
133 .as_str()
134 .is_some_and(|typ| ["record", "enum", "fixed"].contains(&typ))
135 {
136 if let Some(fields) = map.get("fields") {
137 collect_type_references(fields, references);
138 }
139 return;
140 }
141
142 for entity_type in &["type", "items", "values", "fields"] {
145 if let Some(val) = map.get(*entity_type) {
146 collect_type_references(val, references);
147 }
148 }
149 }
150 serde_json::Value::Array(arr) => {
151 for item in arr {
152 collect_type_references(item, references);
153 }
154 }
155 _ => {}
156 }
157}
158
/// How a record's key or value should be encoded, as parsed from the command
/// arguments by `run_ingest` and later turned into a `Transcoder` by
/// `make_transcoder`.
#[derive(Clone)]
enum Format {
    /// JSON input rows are encoded as Avro data.
    Avro {
        /// The Avro schema, as a JSON string.
        schema: String,
        /// When true (and no glue schema version id is set), the schema is
        /// published to the schema registry and each datum is prefixed with a
        /// zero byte followed by the big-endian schema id.
        confluent_wire_format: bool,
        /// When set, each datum is prefixed with the bytes `0x03`, `0x00` and
        /// the 16 bytes of this UUID; this takes precedence over
        /// `confluent_wire_format`.
        glue_schema_version_id: Option<Uuid>,
        /// Schema registry subjects whose schemas this schema refers to.
        /// These are fetched (together with their own references) before the
        /// schema is published and parsed.
        references: Vec<String>,
    },
    /// JSON input rows are encoded as a protobuf message.
    Protobuf {
        /// Path of the descriptor file, resolved relative to the state's
        /// temporary directory.
        descriptor_file: String,
        /// Name of the message to look up in the descriptor file.
        message: String,
        /// When true, each message is prefixed with a zero byte, the
        /// big-endian schema id and `schema_message_id`.
        confluent_wire_format: bool,
        /// Schema registry subject used to look up the schema id; when absent
        /// the topic-derived subject is used instead.
        schema_id_subject: Option<String>,
        /// Single byte written after the schema id in the confluent prefix.
        schema_message_id: u8,
    },
    /// Input rows are passed through as (unescaped) raw bytes.
    Bytes {
        /// When set, input is consumed up to this byte, which is then dropped;
        /// otherwise the whole remaining row is consumed.
        terminator: Option<u8>,
    },
}
182
/// A ready-to-use encoder for one side (key or value) of a record, built from
/// a `Format` by `make_transcoder`.
enum Transcoder {
    /// Emits the bare Avro datum with no prefix.
    PlainAvro {
        /// Parsed schema used to convert and encode each JSON row.
        schema: Schema,
    },
    /// Emits a zero byte, the big-endian `schema_id`, then the Avro datum.
    ConfluentAvro {
        /// Parsed schema used to convert and encode each JSON row.
        schema: Schema,
        /// Id returned by the schema registry when the schema was published.
        schema_id: i32,
    },
    /// Emits the bytes `0x03`, `0x00`, the 16 bytes of `schema_version_id`,
    /// then the Avro datum.
    GlueAvro {
        /// Parsed schema used to convert and encode each JSON row.
        schema: Schema,
        /// UUID written into the prefix of every datum.
        schema_version_id: Uuid,
    },
    /// Emits a protobuf-encoded message, optionally with a confluent prefix.
    Protobuf {
        /// Descriptor of the message type each JSON row is deserialized into.
        message: MessageDescriptor,
        /// Whether to write the prefix (zero byte, big-endian `schema_id`,
        /// `schema_message_id`) before the encoded message.
        confluent_wire_format: bool,
        /// Schema id written in the prefix; `make_transcoder` sets it to 0
        /// when the confluent wire format is not in use.
        schema_id: i32,
        /// Single byte written after the schema id in the prefix.
        schema_message_id: u8,
    },
    /// Emits the input bytes after unescaping them.
    Bytes {
        /// When set, input is consumed up to this byte, which is then dropped;
        /// otherwise the whole remaining input is consumed.
        terminator: Option<u8>,
    },
}
205
206impl Transcoder {
207 fn decode_json<R, T>(row: R) -> Result<Option<T>, anyhow::Error>
208 where
209 R: Read,
210 T: DeserializeOwned,
211 {
212 let deserializer = serde_json::Deserializer::from_reader(row);
213 deserializer
214 .into_iter()
215 .next()
216 .transpose()
217 .context("parsing json")
218 }
219
220 fn transcode<R>(&self, mut row: R) -> Result<Option<Vec<u8>>, anyhow::Error>
221 where
222 R: BufRead,
223 {
224 match self {
225 Transcoder::ConfluentAvro { schema, schema_id } => {
226 if let Some(val) = Self::decode_json(row)? {
227 let val = avro::from_json(&val, schema.top_node())?;
228 let mut out = vec![];
229 out.write_u8(0).unwrap();
235 out.write_i32::<NetworkEndian>(*schema_id).unwrap();
236 out.extend(avro::to_avro_datum(schema, val)?);
237 Ok(Some(out))
238 } else {
239 Ok(None)
240 }
241 }
242 Transcoder::GlueAvro {
243 schema,
244 schema_version_id,
245 } => {
246 if let Some(val) = Self::decode_json(row)? {
247 let val = avro::from_json(&val, schema.top_node())?;
248 let mut out = vec![];
249 out.write_u8(0x03).unwrap();
254 out.write_u8(0x00).unwrap();
255 out.extend_from_slice(schema_version_id.as_bytes());
256 out.extend(avro::to_avro_datum(schema, val)?);
257 Ok(Some(out))
258 } else {
259 Ok(None)
260 }
261 }
262 Transcoder::PlainAvro { schema } => {
263 if let Some(val) = Self::decode_json(row)? {
264 let val = avro::from_json(&val, schema.top_node())?;
265 let mut out = vec![];
266 out.extend(avro::to_avro_datum(schema, val)?);
267 Ok(Some(out))
268 } else {
269 Ok(None)
270 }
271 }
272 Transcoder::Protobuf {
273 message,
274 confluent_wire_format,
275 schema_id,
276 schema_message_id,
277 } => {
278 if let Some(val) = Self::decode_json::<_, serde_json::Value>(row)? {
279 let message = DynamicMessage::deserialize(message.clone(), val)
280 .context("parsing protobuf JSON")?;
281 let mut out = vec![];
282 if *confluent_wire_format {
283 out.write_u8(0).unwrap();
290 out.write_i32::<NetworkEndian>(*schema_id).unwrap();
291 out.write_u8(*schema_message_id).unwrap();
292 }
293 message.encode(&mut out)?;
294 Ok(Some(out))
295 } else {
296 Ok(None)
297 }
298 }
299 Transcoder::Bytes { terminator } => {
300 let mut out = vec![];
301 match terminator {
302 Some(t) => {
303 row.read_until(*t, &mut out)?;
304 if out.last() == Some(t) {
305 out.pop();
306 }
307 }
308 None => {
309 row.read_to_end(&mut out)?;
310 }
311 }
312 if out.is_empty() {
313 Ok(None)
314 } else {
315 Ok(Some(bytes::unescape(&out)?))
316 }
317 }
318 }
319 }
320}
321
322pub async fn run_ingest(
323 mut cmd: BuiltinCommand,
324 state: &mut State,
325) -> Result<ControlFlow, anyhow::Error> {
326 let topic_prefix = format!("testdrive-{}", cmd.args.string("topic")?);
327 let partition = cmd.args.opt_parse::<i32>("partition")?;
328 let start_iteration = cmd.args.opt_parse::<isize>("start-iteration")?.unwrap_or(0);
329 let repeat = cmd.args.opt_parse::<isize>("repeat")?.unwrap_or(1);
330 let omit_key = cmd.args.opt_bool("omit-key")?.unwrap_or(false);
331 let omit_value = cmd.args.opt_bool("omit-value")?.unwrap_or(false);
332 let schema_id_var = cmd.args.opt_parse("set-schema-id-var")?;
333 let key_schema_id_var = cmd.args.opt_parse("set-key-schema-id-var")?;
334 let format = match cmd.args.string("format")?.as_str() {
335 "avro" => {
336 let glue_schema_version_id = cmd
337 .args
338 .opt_string("glue-schema-version-id")
339 .map(|s| Uuid::parse_str(&s).context("parsing glue-schema-version-id as UUID"))
340 .transpose()?;
341 let confluent_wire_format_explicit = cmd.args.opt_bool("confluent-wire-format")?;
342 if glue_schema_version_id.is_some() && confluent_wire_format_explicit == Some(true) {
343 bail!("confluent-wire-format=true is incompatible with glue-schema-version-id");
344 }
345 let confluent_wire_format =
347 confluent_wire_format_explicit.unwrap_or_else(|| glue_schema_version_id.is_none());
348 Format::Avro {
349 schema: cmd.args.string("schema")?,
350 confluent_wire_format,
351 glue_schema_version_id,
352 references: cmd
353 .args
354 .opt_string("references")
355 .map(|s| s.split(',').map(|s| s.to_string()).collect())
356 .unwrap_or_default(),
357 }
358 }
359 "protobuf" => {
360 let descriptor_file = cmd.args.string("descriptor-file")?;
361 let message = cmd.args.string("message")?;
362 Format::Protobuf {
363 descriptor_file,
364 message,
365 confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(false),
368 schema_id_subject: cmd.args.opt_string("schema-id-subject"),
369 schema_message_id: cmd.args.opt_parse::<u8>("schema-message-id")?.unwrap_or(0),
370 }
371 }
372 "bytes" => Format::Bytes { terminator: None },
373 f => bail!("unknown format: {}", f),
374 };
375 let mut key_schema = cmd.args.opt_string("key-schema");
376 let key_format = match cmd.args.opt_string("key-format").as_deref() {
377 Some("avro") => {
378 let key_glue_schema_version_id = cmd
379 .args
380 .opt_string("key-glue-schema-version-id")
381 .map(|s| Uuid::parse_str(&s).context("parsing key-glue-schema-version-id as UUID"))
382 .transpose()?;
383 let confluent_wire_format_explicit = cmd.args.opt_bool("confluent-wire-format")?;
384 if key_glue_schema_version_id.is_some() && confluent_wire_format_explicit == Some(true)
385 {
386 bail!("confluent-wire-format=true is incompatible with key-glue-schema-version-id");
387 }
388 let confluent_wire_format = confluent_wire_format_explicit
390 .unwrap_or_else(|| key_glue_schema_version_id.is_none());
391 Some(Format::Avro {
392 schema: key_schema.take().ok_or_else(|| {
393 anyhow!("key-schema parameter required when key-format is present")
394 })?,
395 confluent_wire_format,
396 glue_schema_version_id: key_glue_schema_version_id,
397 references: cmd
398 .args
399 .opt_string("key-references")
400 .map(|s| s.split(',').map(|s| s.to_string()).collect())
401 .unwrap_or_default(),
402 })
403 }
404 Some("protobuf") => {
405 let descriptor_file = cmd.args.string("key-descriptor-file")?;
406 let message = cmd.args.string("key-message")?;
407 Some(Format::Protobuf {
408 descriptor_file,
409 message,
410 confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(false),
411 schema_id_subject: cmd.args.opt_string("key-schema-id-subject"),
412 schema_message_id: cmd
413 .args
414 .opt_parse::<u8>("key-schema-message-id")?
415 .unwrap_or(0),
416 })
417 }
418 Some("bytes") => Some(Format::Bytes {
419 terminator: match cmd.args.opt_parse::<char>("key-terminator")? {
420 Some(c) => match u8::try_from(c) {
421 Ok(c) => Some(c),
422 Err(_) => bail!("key terminator must be single ASCII character"),
423 },
424 None => Some(b':'),
425 },
426 }),
427 Some(f) => bail!("unknown key format: {}", f),
428 None => None,
429 };
430 if key_schema.is_some() {
431 anyhow::bail!("key-schema specified without a matching key-format");
432 }
433
434 let timestamp = cmd.args.opt_parse("timestamp")?;
435
436 use serde_json::Value;
437 let headers = if let Some(headers_val) = cmd.args.opt_parse::<serde_json::Value>("headers")? {
438 let mut headers = Vec::new();
439 let headers_maps = match headers_val {
440 Value::Array(values) => {
441 let mut headers_map = Vec::new();
442 for value in values {
443 if let Value::Object(m) = value {
444 headers_map.push(m)
445 } else {
446 bail!("`headers` array values must be maps")
447 }
448 }
449 headers_map
450 }
451 Value::Object(v) => vec![v],
452 _ => bail!("`headers` must be a map or an array"),
453 };
454
455 for headers_map in headers_maps {
456 for (k, v) in headers_map.iter() {
457 headers.push((k.clone(), match v {
458 Value::String(val) => Some(val.as_bytes().to_vec()),
459 Value::Array(val) => {
460 let mut values = Vec::new();
461 for value in val {
462 if let Value::Number(int) = value {
463 values.push(u8::try_from(int.as_i64().unwrap()).unwrap())
464 } else {
465 bail!("`headers` value arrays must only contain numbers (to represent bytes)")
466 }
467 }
468 Some(values.clone())
469 },
470 Value::Null => None,
471 _ => bail!("`headers` must have string, int array or null values")
472 }));
473 }
474 }
475 Some(headers)
476 } else {
477 None
478 };
479
480 cmd.args.done()?;
481
482 if let Some(kf) = &key_format {
483 fn is_confluent_format(fmt: &Format) -> Option<bool> {
484 match fmt {
485 Format::Avro {
486 confluent_wire_format,
487 glue_schema_version_id,
488 ..
489 } => {
490 if glue_schema_version_id.is_some() {
492 None
493 } else {
494 Some(*confluent_wire_format)
495 }
496 }
497 Format::Protobuf {
498 confluent_wire_format,
499 ..
500 } => Some(*confluent_wire_format),
501 Format::Bytes { .. } => None,
502 }
503 }
504 match (is_confluent_format(kf), is_confluent_format(&format)) {
505 (Some(false), Some(true)) | (Some(true), Some(false)) => {
506 bail!(
507 "It does not make sense to have the key be in confluent format and not the value, or vice versa."
508 );
509 }
510 _ => {}
511 }
512 }
513
514 let topic_name = &format!("{}-{}", topic_prefix, state.seed);
515 println!(
516 "Ingesting data into Kafka topic {} with start_iteration = {}, repeat = {}",
517 topic_name, start_iteration, repeat
518 );
519
520 let set_schema_id_var = |state: &mut State, schema_id_var, transcoder| match transcoder {
521 &Transcoder::ConfluentAvro { schema_id, .. } | &Transcoder::Protobuf { schema_id, .. } => {
522 state.cmd_vars.insert(schema_id_var, schema_id.to_string());
523 }
524 _ => (),
525 };
526
527 let value_transcoder =
528 make_transcoder(state, format.clone(), format!("{}-value", topic_name)).await?;
529 if let Some(var) = schema_id_var {
530 set_schema_id_var(state, var, &value_transcoder);
531 }
532
533 let key_transcoder = match key_format.clone() {
534 None => None,
535 Some(f) => {
536 let transcoder = make_transcoder(state, f, format!("{}-key", topic_name)).await?;
537 if let Some(var) = key_schema_id_var {
538 set_schema_id_var(state, var, &transcoder);
539 }
540 Some(transcoder)
541 }
542 };
543
544 let mut futs = FuturesUnordered::new();
545
546 for iteration in start_iteration..(start_iteration + repeat) {
547 let iter = &mut cmd.input.iter().peekable();
548
549 for row in iter {
550 let row = action::substitute_vars(
551 row,
552 &btreemap! { "kafka-ingest.iteration".into() => iteration.to_string() },
553 &None,
554 false,
555 )?;
556 let mut row = row.as_bytes();
557 let key = match (omit_key, &key_transcoder) {
558 (true, _) => None,
559 (false, None) => None,
560 (false, Some(kt)) => kt.transcode(&mut row)?,
561 };
562 let value = if omit_value {
563 None
564 } else {
565 value_transcoder
566 .transcode(&mut row)
567 .with_context(|| format!("parsing row: {}", String::from_utf8_lossy(row)))?
568 };
569 let producer = &state.kafka_producer;
570 let timeout = cmp::max(state.default_timeout, Duration::from_secs(1));
571 let headers = headers.clone();
572 futs.push(async move {
573 let mut record: FutureRecord<_, _> = FutureRecord::to(topic_name);
574
575 if let Some(partition) = partition {
576 record = record.partition(partition);
577 }
578 if let Some(key) = &key {
579 record = record.key(key);
580 }
581 if let Some(value) = &value {
582 record = record.payload(value);
583 }
584 if let Some(timestamp) = timestamp {
585 record = record.timestamp(timestamp);
586 }
587 if let Some(headers) = headers {
588 let mut rd_meta = OwnedHeaders::new();
589 for (k, v) in &headers {
590 rd_meta = rd_meta.insert(Header {
591 key: k,
592 value: v.as_deref(),
593 });
594 }
595 record = record.headers(rd_meta);
596 }
597 producer.send(record, timeout).await
598 });
599 }
600
601 if iteration % INGEST_BATCH_SIZE == 0 || iteration == (start_iteration + repeat - 1) {
603 while let Some(res) = futs.next().await {
604 res.map_err(|(e, _message)| e)?;
605 }
606 }
607 }
608 Ok(ControlFlow::Continue)
609}
610
611async fn make_transcoder(
612 state: &State,
613 format: Format,
614 ccsr_subject: String,
615) -> Result<Transcoder, anyhow::Error> {
616 match format {
617 Format::Avro {
618 schema,
619 confluent_wire_format,
620 glue_schema_version_id,
621 references,
622 } => {
623 if let Some(schema_version_id) = glue_schema_version_id {
624 if !references.is_empty() {
625 bail!("schema references are not supported with glue-schema-version-id");
626 }
627 let schema = avro::parse_schema(&schema, &[])
628 .with_context(|| format!("parsing avro schema: {}", schema))?;
629 return Ok(Transcoder::GlueAvro {
630 schema,
631 schema_version_id,
632 });
633 }
634 if confluent_wire_format {
635 #[allow(clippy::disallowed_types)]
639 let mut reference_subjects = vec![];
640 #[allow(clippy::disallowed_types)]
641 let mut seen_subjects: std::collections::HashSet<String> =
642 std::collections::HashSet::new();
643 let mut queue: Vec<String> = references.clone();
644
645 while let Some(ref_name) = queue.pop() {
647 if seen_subjects.contains(&ref_name) {
648 continue;
649 }
650 seen_subjects.insert(ref_name.clone());
651
652 let (subject, ref_deps) = state
653 .ccsr_client
654 .get_subject_with_references(&ref_name)
655 .await
656 .with_context(|| format!("fetching reference {}", ref_name))?;
657
658 for dep in ref_deps {
660 if !seen_subjects.contains(&dep.subject) {
661 queue.push(dep.subject);
662 }
663 }
664
665 let defined_types = extract_all_defined_types(&subject.schema.raw)
667 .with_context(|| {
668 format!("extracting type names from reference schema {}", ref_name)
669 })?;
670 reference_subjects.push((
671 ref_name,
672 subject.version,
673 subject.schema.raw,
674 defined_types,
675 ));
676 }
677
678 reference_subjects.reverse();
682
683 let direct_refs = extract_type_references(&schema)
685 .context("extracting type references from schema")?;
686
687 let mut schema_references = vec![];
690 for type_name in &direct_refs {
691 for (subject_name, version, _, defined_types) in &reference_subjects {
692 if defined_types.contains(type_name) {
693 schema_references.push(mz_ccsr::SchemaReference {
694 name: type_name.clone(),
695 subject: subject_name.clone(),
696 version: *version,
697 });
698 break;
699 }
700 }
701 }
702
703 let reference_raw_schemas: Vec<_> = reference_subjects
705 .into_iter()
706 .map(|(_, _, raw, _)| raw)
707 .collect();
708
709 let schema_id = state
710 .ccsr_client
711 .publish_schema(
712 &ccsr_subject,
713 &schema,
714 mz_ccsr::SchemaType::Avro,
715 &schema_references,
716 )
717 .await
718 .context("publishing to schema registry")?;
719
720 let schema = if reference_raw_schemas.is_empty() {
722 avro::parse_schema(&schema, &[])
723 .with_context(|| format!("parsing avro schema: {}", schema))?
724 } else {
725 let mut parsed_refs: Vec<Schema> = vec![];
728 for raw in &reference_raw_schemas {
729 let schema_value: serde_json::Value = serde_json::from_str(raw)
730 .with_context(|| format!("parsing reference schema JSON: {}", raw))?;
731 let parsed = Schema::parse_with_references(&schema_value, &parsed_refs)
732 .with_context(|| format!("parsing reference avro schema: {}", raw))?;
733 parsed_refs.push(parsed);
734 }
735
736 let schema_value: serde_json::Value = serde_json::from_str(&schema)
738 .with_context(|| format!("parsing schema JSON: {}", schema))?;
739 Schema::parse_with_references(&schema_value, &parsed_refs).with_context(
740 || format!("parsing avro schema with references: {}", schema),
741 )?
742 };
743
744 Ok::<_, anyhow::Error>(Transcoder::ConfluentAvro { schema, schema_id })
745 } else {
746 let schema = avro::parse_schema(&schema, &[])
747 .with_context(|| format!("parsing avro schema: {}", schema))?;
748 Ok(Transcoder::PlainAvro { schema })
749 }
750 }
751 Format::Protobuf {
752 descriptor_file,
753 message,
754 confluent_wire_format,
755 schema_id_subject,
756 schema_message_id,
757 } => {
758 let schema_id = if confluent_wire_format {
759 state
760 .ccsr_client
761 .get_schema_by_subject(schema_id_subject.as_deref().unwrap_or(&ccsr_subject))
762 .await
763 .context("fetching schema from registry")?
764 .id
765 } else {
766 0
767 };
768
769 let bytes = fs::read(state.temp_path.join(descriptor_file))
770 .await
771 .context("reading protobuf descriptor file")?;
772 let fd = DescriptorPool::decode(&*bytes).context("parsing protobuf descriptor file")?;
773 let message = fd
774 .get_message_by_name(&message)
775 .ok_or_else(|| anyhow!("unknown message name {}", message))?;
776 Ok(Transcoder::Protobuf {
777 message,
778 confluent_wire_format,
779 schema_id,
780 schema_message_id,
781 })
782 }
783 Format::Bytes { terminator } => Ok(Transcoder::Bytes { terminator }),
784 }
785}