
mz_testdrive/action/kafka/ingest.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use std::cmp;
use std::io::{BufRead, Read};
use std::time::Duration;

use anyhow::{Context, anyhow, bail};
use byteorder::{NetworkEndian, WriteBytesExt};
use futures::stream::{FuturesUnordered, StreamExt};
use maplit::btreemap;
use prost::Message;
use prost_reflect::{DescriptorPool, DynamicMessage, MessageDescriptor};
use rdkafka::message::{Header, OwnedHeaders};
use rdkafka::producer::FutureRecord;
use serde::de::DeserializeOwned;
use tokio::fs;

use crate::action::{self, ControlFlow, State};
use crate::format::avro::{self, Schema};
use crate::format::bytes;
use crate::parser::BuiltinCommand;

const INGEST_BATCH_SIZE: isize = 10000;

/// Extracts ALL type names defined in an Avro schema (including nested types).
/// Returns a set of fully qualified type names.
#[allow(clippy::disallowed_types)]
fn extract_all_defined_types(
    schema_json: &str,
) -> anyhow::Result<std::collections::HashSet<String>> {
    let value: serde_json::Value = serde_json::from_str(schema_json)
        .context("parsing schema JSON to extract defined types")?;

    let mut types = std::collections::HashSet::new();
    collect_defined_types(&value, None, &mut types);
    Ok(types)
}

/// Recursively collects all named type definitions from an Avro schema.
#[allow(clippy::disallowed_types)]
fn collect_defined_types(
    value: &serde_json::Value,
    parent_namespace: Option<&str>,
    types: &mut std::collections::HashSet<String>,
) {
    match value {
        serde_json::Value::Object(map) => {
            // Get this schema's namespace (falls back to parent's namespace)
            let namespace = map
                .get("namespace")
                .and_then(|v| v.as_str())
                .or(parent_namespace);

            // Check if this is a named type definition (record, enum, or fixed)
            if let Some(type_val) = map.get("type")
                && type_val
                    .as_str()
                    .is_some_and(|typ| ["record", "enum", "fixed"].contains(&typ))
            {
                if let Some(name) = map.get("name").and_then(|v| v.as_str()) {
                    // Construct fully qualified name
                    let fullname = if name.contains('.') {
                        name.to_string()
                    } else if let Some(ns) = namespace {
                        format!("{}.{}", ns, name)
                    } else {
                        name.to_string()
                    };
                    types.insert(fullname);
                }
            }

            // The following types may have references:
            // type field, items (array types), values (map types), and fields (e.g. unions)
            for entity_type in &["type", "items", "values", "fields"] {
                if let Some(val) = map.get(*entity_type) {
                    collect_defined_types(val, namespace, types);
                }
            }
        }
        serde_json::Value::Array(arr) => {
            for item in arr {
                collect_defined_types(item, parent_namespace, types);
            }
        }
        _ => {}
    }
}
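
// A minimal, hedged sketch of the traversal above (hypothetical schema, not
// from the original test suite): a nested enum with no namespace of its own
// inherits the enclosing record's namespace in its fully qualified name.
#[cfg(test)]
mod defined_types_example {
    use super::*;

    #[test]
    fn collects_nested_types_with_inherited_namespace() {
        let schema = r#"{
            "type": "record", "name": "User", "namespace": "com.example",
            "fields": [{
                "name": "status",
                "type": {"type": "enum", "name": "Status", "symbols": ["ON", "OFF"]}
            }]
        }"#;
        let types = extract_all_defined_types(schema).unwrap();
        assert!(types.contains("com.example.User"));
        // `Status` carries no namespace, so it falls back to the parent's.
        assert!(types.contains("com.example.Status"));
    }
}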

/// Extracts all type references from an Avro schema JSON string.
/// This finds all fully qualified type names that are referenced but not defined in the schema.
#[allow(clippy::disallowed_types)]
fn extract_type_references(schema_json: &str) -> anyhow::Result<std::collections::HashSet<String>> {
    let value: serde_json::Value = serde_json::from_str(schema_json)
        .context("parsing schema JSON to extract type references")?;

    let mut references = std::collections::HashSet::new();
    collect_type_references(&value, &mut references);
    Ok(references)
}

/// Recursively collects type references from an Avro schema JSON value.
#[allow(clippy::disallowed_types)]
fn collect_type_references(
    value: &serde_json::Value,
    references: &mut std::collections::HashSet<String>,
) {
    match value {
        serde_json::Value::String(s) => {
            // A string type that contains a dot is likely a fully qualified type reference
            if s.contains('.')
                && ![
                    "null", "boolean", "int", "long", "float", "double", "bytes", "string",
                ]
                .contains(&s.as_str())
            {
                references.insert(s.clone());
            }
        }
        serde_json::Value::Object(map) => {
            // For named types, we want to recurse into the fields, but the named type doesn't
            // get added to references.
            if let Some(type_val) = map.get("type")
                && type_val
                    .as_str()
                    .is_some_and(|typ| ["record", "enum", "fixed"].contains(&typ))
            {
                if let Some(fields) = map.get("fields") {
                    collect_type_references(fields, references);
                }
                return;
            }

            // The following types may have references:
            // type field, items (array types), values (map types), and fields (e.g. unions)
            for entity_type in &["type", "items", "values", "fields"] {
                if let Some(val) = map.get(*entity_type) {
                    collect_type_references(val, references);
                }
            }
        }
        serde_json::Value::Array(arr) => {
            for item in arr {
                collect_type_references(item, references);
            }
        }
        _ => {}
    }
}
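
// A companion sketch (hypothetical schema): dotted names count as references
// unless they are Avro primitives, and the record defined here is itself
// skipped, so only the external `com.example.User` is reported.
#[cfg(test)]
mod type_references_example {
    use super::*;

    #[test]
    fn finds_only_external_references() {
        let schema = r#"{
            "type": "record", "name": "Order", "namespace": "com.example",
            "fields": [
                {"name": "buyer", "type": "com.example.User"},
                {"name": "note", "type": "string"}
            ]
        }"#;
        let refs = extract_type_references(schema).unwrap();
        assert!(refs.contains("com.example.User"));
        // "string" is a primitive and "Order" is defined here, so neither counts.
        assert_eq!(refs.len(), 1);
    }
}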

#[derive(Clone)]
enum Format {
    Avro {
        schema: String,
        confluent_wire_format: bool,
        /// Schema references (subject names) for Confluent Schema Registry
        references: Vec<String>,
    },
    Protobuf {
        descriptor_file: String,
        message: String,
        confluent_wire_format: bool,
        schema_id_subject: Option<String>,
        schema_message_id: u8,
    },
    Bytes {
        terminator: Option<u8>,
    },
}

enum Transcoder {
    PlainAvro {
        schema: Schema,
    },
    ConfluentAvro {
        schema: Schema,
        schema_id: i32,
    },
    Protobuf {
        message: MessageDescriptor,
        confluent_wire_format: bool,
        schema_id: i32,
        schema_message_id: u8,
    },
    Bytes {
        terminator: Option<u8>,
    },
}

impl Transcoder {
    fn decode_json<R, T>(row: R) -> Result<Option<T>, anyhow::Error>
    where
        R: Read,
        T: DeserializeOwned,
    {
        let deserializer = serde_json::Deserializer::from_reader(row);
        deserializer
            .into_iter()
            .next()
            .transpose()
            .context("parsing json")
    }

    fn transcode<R>(&self, mut row: R) -> Result<Option<Vec<u8>>, anyhow::Error>
    where
        R: BufRead,
    {
        match self {
            Transcoder::ConfluentAvro { schema, schema_id } => {
                if let Some(val) = Self::decode_json(row)? {
                    let val = avro::from_json(&val, schema.top_node())?;
                    let mut out = vec![];
                    // The first byte is a magic byte (0) that indicates the Confluent
                    // serialization format version, and the next four bytes are a
                    // 32-bit schema ID.
                    //
                    // https://docs.confluent.io/3.3.0/schema-registry/docs/serializer-formatter.html#wire-format
                    out.write_u8(0).unwrap();
                    out.write_i32::<NetworkEndian>(*schema_id).unwrap();
                    out.extend(avro::to_avro_datum(schema, val)?);
                    Ok(Some(out))
                } else {
                    Ok(None)
                }
            }
            Transcoder::PlainAvro { schema } => {
                if let Some(val) = Self::decode_json(row)? {
                    let val = avro::from_json(&val, schema.top_node())?;
                    let mut out = vec![];
                    out.extend(avro::to_avro_datum(schema, val)?);
                    Ok(Some(out))
                } else {
                    Ok(None)
                }
            }
            Transcoder::Protobuf {
                message,
                confluent_wire_format,
                schema_id,
                schema_message_id,
            } => {
                if let Some(val) = Self::decode_json::<_, serde_json::Value>(row)? {
                    let message = DynamicMessage::deserialize(message.clone(), val)
                        .context("parsing protobuf JSON")?;
                    let mut out = vec![];
                    if *confluent_wire_format {
                        // See: https://github.com/MaterializeInc/database-issues/issues/2837
                        // The first byte is a magic byte (0) that indicates the Confluent
                        // serialization format version, and the next four bytes are a
                        // 32-bit schema ID, which we default to something fun.
                        // And, as we only support single-message proto files for now,
                        // we also set the following message id to 0.
                        out.write_u8(0).unwrap();
                        out.write_i32::<NetworkEndian>(*schema_id).unwrap();
                        out.write_u8(*schema_message_id).unwrap();
                    }
                    message.encode(&mut out)?;
                    Ok(Some(out))
                } else {
                    Ok(None)
                }
            }
            Transcoder::Bytes { terminator } => {
                let mut out = vec![];
                match terminator {
                    Some(t) => {
                        row.read_until(*t, &mut out)?;
                        if out.last() == Some(t) {
                            out.pop();
                        }
                    }
                    None => {
                        row.read_to_end(&mut out)?;
                    }
                }
                if out.is_empty() {
                    Ok(None)
                } else {
                    Ok(Some(bytes::unescape(&out)?))
                }
            }
        }
    }
}
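
// A hedged sketch of `Transcoder::decode_json` (hypothetical input): the
// streaming deserializer yields only the first JSON value on the row and
// leaves the reader positioned after it, which is how a key and a value can
// be read from the same input line one after the other.
#[cfg(test)]
mod decode_json_example {
    use super::*;

    #[test]
    fn reads_first_json_value_only() {
        let row: &[u8] = br#"{"key": 1} {"value": 2}"#;
        let first: Option<serde_json::Value> = Transcoder::decode_json(row).unwrap();
        assert_eq!(first, Some(serde_json::json!({"key": 1})));
    }
}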
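
// A minimal sketch of the Confluent wire-format framing used by the Avro and
// Protobuf arms above: a zero magic byte, then the schema ID as a big-endian
// (network-order) i32, then the encoded payload.
#[cfg(test)]
mod wire_format_example {
    use byteorder::{NetworkEndian, WriteBytesExt};

    #[test]
    fn header_is_magic_byte_then_big_endian_schema_id() {
        let mut out = vec![];
        out.write_u8(0).unwrap();
        out.write_i32::<NetworkEndian>(7).unwrap();
        // One magic byte plus four schema-ID bytes; the payload follows these.
        assert_eq!(out, [0, 0, 0, 0, 7]);
    }
}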
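
// A sketch of the `Bytes` arm (assuming `bytes::unescape` is the identity on
// input without escape sequences): with the default key terminator `:`, the
// same row yields the key and then the value on successive calls.
#[cfg(test)]
mod bytes_terminator_example {
    use super::*;

    #[test]
    fn terminator_splits_key_from_value() {
        let t = Transcoder::Bytes { terminator: Some(b':') };
        let mut row: &[u8] = b"mykey:myvalue";
        // First call consumes up to and including the terminator, then drops it.
        assert_eq!(t.transcode(&mut row).unwrap(), Some(b"mykey".to_vec()));
        // Second call reads the remainder of the row.
        assert_eq!(t.transcode(&mut row).unwrap(), Some(b"myvalue".to_vec()));
    }
}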

pub async fn run_ingest(
    mut cmd: BuiltinCommand,
    state: &mut State,
) -> Result<ControlFlow, anyhow::Error> {
    let topic_prefix = format!("testdrive-{}", cmd.args.string("topic")?);
    let partition = cmd.args.opt_parse::<i32>("partition")?;
    let start_iteration = cmd.args.opt_parse::<isize>("start-iteration")?.unwrap_or(0);
    let repeat = cmd.args.opt_parse::<isize>("repeat")?.unwrap_or(1);
    let omit_key = cmd.args.opt_bool("omit-key")?.unwrap_or(false);
    let omit_value = cmd.args.opt_bool("omit-value")?.unwrap_or(false);
    let schema_id_var = cmd.args.opt_parse("set-schema-id-var")?;
    let key_schema_id_var = cmd.args.opt_parse("set-key-schema-id-var")?;
    let format = match cmd.args.string("format")?.as_str() {
        "avro" => Format::Avro {
            schema: cmd.args.string("schema")?,
            confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(true),
            // TODO (maz): update README!
            references: cmd
                .args
                .opt_string("references")
                .map(|s| s.split(',').map(|s| s.to_string()).collect())
                .unwrap_or_default(),
        },
        "protobuf" => {
            let descriptor_file = cmd.args.string("descriptor-file")?;
            let message = cmd.args.string("message")?;
            Format::Protobuf {
                descriptor_file,
                message,
                // This was introduced after the avro format's confluent-wire-format, so it defaults to
                // false
                confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(false),
                schema_id_subject: cmd.args.opt_string("schema-id-subject"),
                schema_message_id: cmd.args.opt_parse::<u8>("schema-message-id")?.unwrap_or(0),
            }
        }
        "bytes" => Format::Bytes { terminator: None },
        f => bail!("unknown format: {}", f),
    };
    let mut key_schema = cmd.args.opt_string("key-schema");
    let key_format = match cmd.args.opt_string("key-format").as_deref() {
        Some("avro") => Some(Format::Avro {
            schema: key_schema.take().ok_or_else(|| {
                anyhow!("key-schema parameter required when key-format is present")
            })?,
            confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(true),
            references: cmd
                .args
                .opt_string("key-references")
                .map(|s| s.split(',').map(|s| s.to_string()).collect())
                .unwrap_or_default(),
        }),
        Some("protobuf") => {
            let descriptor_file = cmd.args.string("key-descriptor-file")?;
            let message = cmd.args.string("key-message")?;
            Some(Format::Protobuf {
                descriptor_file,
                message,
                confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(false),
                schema_id_subject: cmd.args.opt_string("key-schema-id-subject"),
                schema_message_id: cmd
                    .args
                    .opt_parse::<u8>("key-schema-message-id")?
                    .unwrap_or(0),
            })
        }
        Some("bytes") => Some(Format::Bytes {
            terminator: match cmd.args.opt_parse::<char>("key-terminator")? {
                Some(c) => match u8::try_from(c) {
                    Ok(c) => Some(c),
                    Err(_) => bail!("key terminator must be single ASCII character"),
                },
                None => Some(b':'),
            },
        }),
        Some(f) => bail!("unknown key format: {}", f),
        None => None,
    };
    if key_schema.is_some() {
        anyhow::bail!("key-schema specified without a matching key-format");
    }

    let timestamp = cmd.args.opt_parse("timestamp")?;

    use serde_json::Value;
    let headers = if let Some(headers_val) = cmd.args.opt_parse::<serde_json::Value>("headers")? {
        let mut headers = Vec::new();
        let headers_maps = match headers_val {
            Value::Array(values) => {
                let mut headers_map = Vec::new();
                for value in values {
                    if let Value::Object(m) = value {
                        headers_map.push(m)
                    } else {
                        bail!("`headers` array values must be maps")
                    }
                }
                headers_map
            }
            Value::Object(v) => vec![v],
            _ => bail!("`headers` must be a map or an array"),
        };

        for headers_map in headers_maps {
            for (k, v) in headers_map.iter() {
                headers.push((k.clone(), match v {
                    Value::String(val) => Some(val.as_bytes().to_vec()),
                    Value::Array(val) => {
                        let mut values = Vec::new();
                        for value in val {
                            if let Value::Number(int) = value {
                                values.push(u8::try_from(int.as_i64().unwrap()).unwrap())
                            } else {
                                bail!("`headers` value arrays must only contain numbers (to represent bytes)")
                            }
                        }
                        Some(values.clone())
                    },
                    Value::Null => None,
                    _ => bail!("`headers` must have string, int array or null values")
                }));
            }
        }
        Some(headers)
    } else {
        None
    };

    cmd.args.done()?;

    if let Some(kf) = &key_format {
        fn is_confluent_format(fmt: &Format) -> Option<bool> {
            match fmt {
                Format::Avro {
                    confluent_wire_format,
                    ..
                } => Some(*confluent_wire_format),
                Format::Protobuf {
                    confluent_wire_format,
                    ..
                } => Some(*confluent_wire_format),
                Format::Bytes { .. } => None,
            }
        }
        match (is_confluent_format(kf), is_confluent_format(&format)) {
            (Some(false), Some(true)) | (Some(true), Some(false)) => {
                bail!(
                    "It does not make sense to have the key be in confluent format and not the value, or vice versa."
                );
            }
            _ => {}
        }
    }

    let topic_name = &format!("{}-{}", topic_prefix, state.seed);
    println!(
        "Ingesting data into Kafka topic {} with start_iteration = {}, repeat = {}",
        topic_name, start_iteration, repeat
    );

    let set_schema_id_var = |state: &mut State, schema_id_var, transcoder| match transcoder {
        &Transcoder::ConfluentAvro { schema_id, .. } | &Transcoder::Protobuf { schema_id, .. } => {
            state.cmd_vars.insert(schema_id_var, schema_id.to_string());
        }
        _ => (),
    };

    let value_transcoder =
        make_transcoder(state, format.clone(), format!("{}-value", topic_name)).await?;
    if let Some(var) = schema_id_var {
        set_schema_id_var(state, var, &value_transcoder);
    }

    let key_transcoder = match key_format.clone() {
        None => None,
        Some(f) => {
            let transcoder = make_transcoder(state, f, format!("{}-key", topic_name)).await?;
            if let Some(var) = key_schema_id_var {
                set_schema_id_var(state, var, &transcoder);
            }
            Some(transcoder)
        }
    };

    let mut futs = FuturesUnordered::new();

    for iteration in start_iteration..(start_iteration + repeat) {
        let iter = &mut cmd.input.iter().peekable();

        for row in iter {
            let row = action::substitute_vars(
                row,
                &btreemap! { "kafka-ingest.iteration".into() => iteration.to_string() },
                &None,
                false,
            )?;
            let mut row = row.as_bytes();
            let key = match (omit_key, &key_transcoder) {
                (true, _) => None,
                (false, None) => None,
                (false, Some(kt)) => kt.transcode(&mut row)?,
            };
            let value = if omit_value {
                None
            } else {
                value_transcoder
                    .transcode(&mut row)
                    .with_context(|| format!("parsing row: {}", String::from_utf8_lossy(row)))?
            };
            let producer = &state.kafka_producer;
            let timeout = cmp::max(state.default_timeout, Duration::from_secs(1));
            let headers = headers.clone();
            futs.push(async move {
                let mut record: FutureRecord<_, _> = FutureRecord::to(topic_name);

                if let Some(partition) = partition {
                    record = record.partition(partition);
                }
                if let Some(key) = &key {
                    record = record.key(key);
                }
                if let Some(value) = &value {
                    record = record.payload(value);
                }
                if let Some(timestamp) = timestamp {
                    record = record.timestamp(timestamp);
                }
                if let Some(headers) = headers {
                    let mut rd_meta = OwnedHeaders::new();
                    for (k, v) in &headers {
                        rd_meta = rd_meta.insert(Header {
                            key: k,
                            value: v.as_deref(),
                        });
                    }
                    record = record.headers(rd_meta);
                }
                producer.send(record, timeout).await
            });
        }

        // Reap the futures thus produced periodically or after the last iteration
        if iteration % INGEST_BATCH_SIZE == 0 || iteration == (start_iteration + repeat - 1) {
            while let Some(res) = futs.next().await {
                res.map_err(|(e, _message)| e)?;
            }
        }
    }
    Ok(ControlFlow::Continue)
}
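
// A hedged sketch of the header-value conversion in `run_ingest` above
// (hypothetical values, standalone re-implementation of the inline logic):
// string values become their UTF-8 bytes, arrays of integers become raw
// bytes, and null becomes a header with no value.
#[cfg(test)]
mod headers_example {
    use serde_json::{Value, json};

    #[test]
    fn header_values_become_optional_byte_vectors() {
        let headers = json!({"content-type": "text/plain", "raw": [104, 105], "empty": null});
        let Value::Object(map) = headers else { unreachable!() };
        let converted: Vec<(String, Option<Vec<u8>>)> = map
            .into_iter()
            .map(|(k, v)| {
                let bytes = match v {
                    Value::String(s) => Some(s.into_bytes()),
                    Value::Array(a) => Some(
                        a.iter()
                            .map(|n| u8::try_from(n.as_i64().unwrap()).unwrap())
                            .collect(),
                    ),
                    Value::Null => None,
                    _ => unreachable!(),
                };
                (k, bytes)
            })
            .collect();
        // [104, 105] are the bytes of "hi"; null carries no value at all.
        assert!(converted.contains(&("raw".into(), Some(b"hi".to_vec()))));
        assert!(converted.contains(&("empty".into(), None)));
    }
}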

async fn make_transcoder(
    state: &State,
    format: Format,
    ccsr_subject: String,
) -> Result<Transcoder, anyhow::Error> {
    match format {
        Format::Avro {
            schema,
            confluent_wire_format,
            references,
        } => {
            if confluent_wire_format {
                // Build references list by fetching each subject from the registry.
                // Start with immediate references and automatically resolve transitive ones.
                // We need ALL references for local parsing, but only DIRECT references for the registry.
                #[allow(clippy::disallowed_types)]
                let mut reference_subjects = vec![];
                #[allow(clippy::disallowed_types)]
                let mut seen_subjects: std::collections::HashSet<String> =
                    std::collections::HashSet::new();
                let mut queue: Vec<String> = references.clone();

                // Process queue (as a stack), adding transitive dependencies as we discover them
                while let Some(ref_name) = queue.pop() {
                    if seen_subjects.contains(&ref_name) {
                        continue;
                    }
                    seen_subjects.insert(ref_name.clone());

                    let (subject, ref_deps) = state
                        .ccsr_client
                        .get_subject_with_references(&ref_name)
                        .await
                        .with_context(|| format!("fetching reference {}", ref_name))?;

                    // Add newly discovered dependencies to the queue
                    for dep in ref_deps {
                        if !seen_subjects.contains(&dep.subject) {
                            queue.push(dep.subject);
                        }
                    }

                    // Extract ALL type names defined in this schema (including nested types)
                    let defined_types = extract_all_defined_types(&subject.schema.raw)
                        .with_context(|| {
                            format!("extracting type names from reference schema {}", ref_name)
                        })?;
                    reference_subjects.push((
                        ref_name,
                        subject.version,
                        subject.schema.raw,
                        defined_types,
                    ));
                }

                // Reverse to get dependency order: since we use a stack, dependencies are
                // discovered and added after the schemas that depend on them, so reversing
                // puts dependencies first (required for incremental schema parsing)
                reference_subjects.reverse();

                // Extract types directly referenced by the primary schema
                let direct_refs = extract_type_references(&schema)
                    .context("extracting type references from schema")?;

                // For the registry, create a reference for each type in direct_refs
                // that is defined in one of the reference subjects
                let mut schema_references = vec![];
                for type_name in &direct_refs {
                    for (subject_name, version, _, defined_types) in &reference_subjects {
                        if defined_types.contains(type_name) {
                            schema_references.push(mz_ccsr::SchemaReference {
                                name: type_name.clone(),
                                subject: subject_name.clone(),
                                version: *version,
                            });
                            break;
                        }
                    }
                }

                // For local parsing, we need all reference schemas
                let reference_raw_schemas: Vec<_> = reference_subjects
                    .into_iter()
                    .map(|(_, _, raw, _)| raw)
                    .collect();

                let schema_id = state
                    .ccsr_client
                    .publish_schema(
                        &ccsr_subject,
                        &schema,
                        mz_ccsr::SchemaType::Avro,
                        &schema_references,
                    )
                    .await
                    .context("publishing to schema registry")?;

                // Parse schema, handling references if any
                let schema = if reference_raw_schemas.is_empty() {
                    avro::parse_schema(&schema, &[])
                        .with_context(|| format!("parsing avro schema: {}", schema))?
                } else {
                    // Parse reference schemas incrementally (each may depend on previous ones).
                    // References must be specified in dependency order (dependencies first).
                    let mut parsed_refs: Vec<Schema> = vec![];
                    for raw in &reference_raw_schemas {
                        let schema_value: serde_json::Value = serde_json::from_str(raw)
                            .with_context(|| format!("parsing reference schema JSON: {}", raw))?;
                        let parsed = Schema::parse_with_references(&schema_value, &parsed_refs)
                            .with_context(|| format!("parsing reference avro schema: {}", raw))?;
                        parsed_refs.push(parsed);
                    }

                    // Parse primary schema with all reference types available
                    let schema_value: serde_json::Value = serde_json::from_str(&schema)
                        .with_context(|| format!("parsing schema JSON: {}", schema))?;
                    Schema::parse_with_references(&schema_value, &parsed_refs).with_context(
                        || format!("parsing avro schema with references: {}", schema),
                    )?
                };

                Ok::<_, anyhow::Error>(Transcoder::ConfluentAvro { schema, schema_id })
            } else {
                let schema = avro::parse_schema(&schema, &[])
                    .with_context(|| format!("parsing avro schema: {}", schema))?;
                Ok(Transcoder::PlainAvro { schema })
            }
        }
        Format::Protobuf {
            descriptor_file,
            message,
            confluent_wire_format,
            schema_id_subject,
            schema_message_id,
        } => {
            let schema_id = if confluent_wire_format {
                state
                    .ccsr_client
                    .get_schema_by_subject(schema_id_subject.as_deref().unwrap_or(&ccsr_subject))
                    .await
                    .context("fetching schema from registry")?
                    .id
            } else {
                0
            };

            let bytes = fs::read(state.temp_path.join(descriptor_file))
                .await
                .context("reading protobuf descriptor file")?;
            let fd = DescriptorPool::decode(&*bytes).context("parsing protobuf descriptor file")?;
            let message = fd
                .get_message_by_name(&message)
                .ok_or_else(|| anyhow!("unknown message name {}", message))?;
            Ok(Transcoder::Protobuf {
                message,
                confluent_wire_format,
                schema_id,
                schema_message_id,
            })
        }
        Format::Bytes { terminator } => Ok(Transcoder::Bytes { terminator }),
    }
}
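
// A self-contained sketch of the ordering argument in `make_transcoder`
// (hypothetical subjects, no registry involved): popping from a stack visits
// dependents before the dependencies they pull in, so reversing the visit
// order puts dependencies first.
#[cfg(test)]
mod reference_order_example {
    use std::collections::{BTreeMap, BTreeSet};

    #[test]
    fn reversed_stack_order_puts_dependencies_first() {
        // `order` references `user`; the primary schema references only `order`.
        let deps: BTreeMap<&str, Vec<&str>> =
            [("order", vec!["user"]), ("user", vec![])].into();
        let mut seen = BTreeSet::new();
        let mut queue = vec!["order"];
        let mut visited = vec![];
        while let Some(name) = queue.pop() {
            if !seen.insert(name) {
                continue;
            }
            queue.extend(deps[name].iter().copied());
            visited.push(name);
        }
        visited.reverse();
        // Dependencies come first, as incremental schema parsing requires.
        assert_eq!(visited, ["user", "order"]);
    }
}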
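
// A second sketch tying the two extraction helpers together (hypothetical
// schemas): only types the primary schema references directly are candidates
// for registry references, even when a reference subject defines more types.
#[cfg(test)]
mod direct_reference_example {
    use super::*;

    #[test]
    fn only_directly_referenced_types_are_matched() {
        let primary = r#"{
            "type": "record", "name": "Order", "namespace": "com.example",
            "fields": [{"name": "buyer", "type": "com.example.User"}]
        }"#;
        let reference = r#"{
            "type": "record", "name": "User", "namespace": "com.example",
            "fields": [{
                "name": "status",
                "type": {"type": "enum", "name": "Status", "symbols": ["ON"]}
            }]
        }"#;
        let direct = extract_type_references(primary).unwrap();
        let defined = extract_all_defined_types(reference).unwrap();
        // `Status` is defined transitively but never referenced directly,
        // so only `User` would be reported to the registry.
        assert!(direct.contains("com.example.User"));
        assert!(!direct.contains("com.example.Status"));
        assert!(defined.contains("com.example.Status"));
    }
}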