mz_testdrive/action/kafka/ingest.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use std::cmp;
use std::io::{BufRead, Read};
use std::time::Duration;

use anyhow::{Context, anyhow, bail};
use byteorder::{NetworkEndian, WriteBytesExt};
use futures::stream::{FuturesUnordered, StreamExt};
use maplit::btreemap;
use prost::Message;
use prost_reflect::{DescriptorPool, DynamicMessage, MessageDescriptor};
use rdkafka::message::{Header, OwnedHeaders};
use rdkafka::producer::FutureRecord;
use serde::de::DeserializeOwned;
use tokio::fs;

use crate::action::{self, ControlFlow, State};
use crate::format::avro::{self, Schema};
use crate::format::bytes;
use crate::parser::BuiltinCommand;

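// Maximum number of `repeat` iterations whose producer sends are buffered
// before the outstanding futures are awaited (see the batching logic in
// `run_ingest`).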
const INGEST_BATCH_SIZE: isize = 10000;

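/// The key or value encoding requested by a `kafka-ingest` command, as parsed
/// from its `format`/`key-format` arguments.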
#[derive(Clone)]
enum Format {
    Avro {
        schema: String,
        confluent_wire_format: bool,
    },
    Protobuf {
        descriptor_file: String,
        message: String,
        confluent_wire_format: bool,
        schema_id_subject: Option<String>,
        schema_message_id: u8,
    },
    Bytes {
        terminator: Option<u8>,
    },
}

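/// A prepared encoder that turns one row of testdrive input into the bytes
/// written to Kafka for the chosen [`Format`].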
enum Transcoder {
    PlainAvro {
        schema: Schema,
    },
    ConfluentAvro {
        schema: Schema,
        schema_id: i32,
    },
    Protobuf {
        message: MessageDescriptor,
        confluent_wire_format: bool,
        schema_id: i32,
        schema_message_id: u8,
    },
    Bytes {
        terminator: Option<u8>,
    },
}

impl Transcoder {
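    /// Reads at most one JSON value from `row`, returning `Ok(None)` when the
    /// input is empty.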
    fn decode_json<R, T>(row: R) -> Result<Option<T>, anyhow::Error>
    where
        R: Read,
        T: DeserializeOwned,
    {
        let deserializer = serde_json::Deserializer::from_reader(row);
        deserializer
            .into_iter()
            .next()
            .transpose()
            .context("parsing json")
    }

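    /// Encodes a single input row according to `self`, returning `Ok(None)`
    /// for empty input so that callers can skip the key or value entirely.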
    fn transcode<R>(&self, mut row: R) -> Result<Option<Vec<u8>>, anyhow::Error>
    where
        R: BufRead,
    {
        match self {
            Transcoder::ConfluentAvro { schema, schema_id } => {
                if let Some(val) = Self::decode_json(row)? {
                    let val = avro::from_json(&val, schema.top_node())?;
                    let mut out = vec![];
                    // The first byte is a magic byte (0) that indicates the Confluent
                    // serialization format version, and the next four bytes are a
                    // 32-bit schema ID.
                    //
                    // https://docs.confluent.io/3.3.0/schema-registry/docs/serializer-formatter.html#wire-format
                    out.write_u8(0).unwrap();
                    out.write_i32::<NetworkEndian>(*schema_id).unwrap();
                    out.extend(avro::to_avro_datum(schema, val)?);
                    Ok(Some(out))
                } else {
                    Ok(None)
                }
            }
            Transcoder::PlainAvro { schema } => {
                if let Some(val) = Self::decode_json(row)? {
                    let val = avro::from_json(&val, schema.top_node())?;
                    let mut out = vec![];
                    out.extend(avro::to_avro_datum(schema, val)?);
                    Ok(Some(out))
                } else {
                    Ok(None)
                }
            }
            Transcoder::Protobuf {
                message,
                confluent_wire_format,
                schema_id,
                schema_message_id,
            } => {
                if let Some(val) = Self::decode_json::<_, serde_json::Value>(row)? {
                    let message = DynamicMessage::deserialize(message.clone(), val)
                        .context("parsing protobuf JSON")?;
                    let mut out = vec![];
                    if *confluent_wire_format {
                        // See: https://github.com/MaterializeInc/database-issues/issues/2837
                        // The first byte is a magic byte (0) that indicates the Confluent
                        // serialization format version, and the next four bytes are a
                        // 32-bit schema ID, which we default to something fun.
                        // And, as we only support single-message proto files for now,
                        // we also set the following message id to 0.
                        out.write_u8(0).unwrap();
                        out.write_i32::<NetworkEndian>(*schema_id).unwrap();
                        out.write_u8(*schema_message_id).unwrap();
                    }
                    message.encode(&mut out)?;
                    Ok(Some(out))
                } else {
                    Ok(None)
                }
            }
            Transcoder::Bytes { terminator } => {
                let mut out = vec![];
                match terminator {
                    Some(t) => {
                        row.read_until(*t, &mut out)?;
                        if out.last() == Some(t) {
                            out.pop();
                        }
                    }
                    None => {
                        row.read_to_end(&mut out)?;
                    }
                }
                if out.is_empty() {
                    Ok(None)
                } else {
                    Ok(Some(bytes::unescape(&out)?))
                }
            }
        }
    }
}

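/// Runs a `kafka-ingest` action: parses the command's arguments, builds the
/// key and value transcoders (publishing or fetching schemas as needed), and
/// produces the transcoded input rows to the seeded testdrive topic.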
pub async fn run_ingest(
    mut cmd: BuiltinCommand,
    state: &mut State,
) -> Result<ControlFlow, anyhow::Error> {
    let topic_prefix = format!("testdrive-{}", cmd.args.string("topic")?);
    let partition = cmd.args.opt_parse::<i32>("partition")?;
    let start_iteration = cmd.args.opt_parse::<isize>("start-iteration")?.unwrap_or(0);
    let repeat = cmd.args.opt_parse::<isize>("repeat")?.unwrap_or(1);
    let omit_key = cmd.args.opt_bool("omit-key")?.unwrap_or(false);
    let omit_value = cmd.args.opt_bool("omit-value")?.unwrap_or(false);
    let schema_id_var = cmd.args.opt_parse("set-schema-id-var")?;
    let key_schema_id_var = cmd.args.opt_parse("set-key-schema-id-var")?;
    let format = match cmd.args.string("format")?.as_str() {
        "avro" => Format::Avro {
            schema: cmd.args.string("schema")?,
            confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(true),
        },
        "protobuf" => {
            let descriptor_file = cmd.args.string("descriptor-file")?;
            let message = cmd.args.string("message")?;
            Format::Protobuf {
                descriptor_file,
                message,
                // This was introduced after the avro format's confluent-wire-format, so it
                // defaults to false.
                confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(false),
                schema_id_subject: cmd.args.opt_string("schema-id-subject"),
                schema_message_id: cmd.args.opt_parse::<u8>("schema-message-id")?.unwrap_or(0),
            }
        }
        "bytes" => Format::Bytes { terminator: None },
        f => bail!("unknown format: {}", f),
    };
    let mut key_schema = cmd.args.opt_string("key-schema");
    let key_format = match cmd.args.opt_string("key-format").as_deref() {
        Some("avro") => Some(Format::Avro {
            schema: key_schema.take().ok_or_else(|| {
                anyhow!("key-schema parameter required when key-format is present")
            })?,
            confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(true),
        }),
        Some("protobuf") => {
            let descriptor_file = cmd.args.string("key-descriptor-file")?;
            let message = cmd.args.string("key-message")?;
            Some(Format::Protobuf {
                descriptor_file,
                message,
                confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(false),
                schema_id_subject: cmd.args.opt_string("key-schema-id-subject"),
                schema_message_id: cmd
                    .args
                    .opt_parse::<u8>("key-schema-message-id")?
                    .unwrap_or(0),
            })
        }
        Some("bytes") => Some(Format::Bytes {
            terminator: match cmd.args.opt_parse::<char>("key-terminator")? {
                Some(c) => match u8::try_from(c) {
                    Ok(c) => Some(c),
                    Err(_) => bail!("key terminator must be a single ASCII character"),
                },
                None => Some(b':'),
            },
        }),
        Some(f) => bail!("unknown key format: {}", f),
        None => None,
    };
    if key_schema.is_some() {
        anyhow::bail!("key-schema specified without a matching key-format");
    }

    let timestamp = cmd.args.opt_parse("timestamp")?;

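    // `headers` may be a single JSON object or an array of objects. String
    // values are passed through as UTF-8 bytes, arrays of integers are treated
    // as raw bytes, and null produces a header with no value.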
    use serde_json::Value;
    let headers = if let Some(headers_val) = cmd.args.opt_parse::<serde_json::Value>("headers")? {
        let mut headers = Vec::new();
        let headers_maps = match headers_val {
            Value::Array(values) => {
                let mut headers_map = Vec::new();
                for value in values {
                    if let Value::Object(m) = value {
                        headers_map.push(m)
                    } else {
                        bail!("`headers` array values must be maps")
                    }
                }
                headers_map
            }
            Value::Object(v) => vec![v],
            _ => bail!("`headers` must be a map or an array"),
        };

        for headers_map in headers_maps {
            for (k, v) in headers_map.iter() {
                headers.push((k.clone(), match v {
                    Value::String(val) => Some(val.as_bytes().to_vec()),
                    Value::Array(val) => {
                        let mut values = Vec::new();
                        for value in val {
                            if let Value::Number(int) = value {
                                values.push(u8::try_from(int.as_i64().unwrap()).unwrap())
                            } else {
                                bail!("`headers` value arrays must only contain numbers (to represent bytes)")
                            }
                        }
                        Some(values.clone())
                    },
                    Value::Null => None,
                    _ => bail!("`headers` must have string, int array or null values")
                }));
            }
        }
        Some(headers)
    } else {
        None
    };

    cmd.args.done()?;

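    // Reject mixed envelopes: the Confluent wire format must be used for both
    // the key and the value, or for neither. The bytes format is exempt.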
    if let Some(kf) = &key_format {
        fn is_confluent_format(fmt: &Format) -> Option<bool> {
            match fmt {
                Format::Avro {
                    confluent_wire_format,
                    ..
                } => Some(*confluent_wire_format),
                Format::Protobuf {
                    confluent_wire_format,
                    ..
                } => Some(*confluent_wire_format),
                Format::Bytes { .. } => None,
            }
        }
        match (is_confluent_format(kf), is_confluent_format(&format)) {
            (Some(false), Some(true)) | (Some(true), Some(false)) => {
                bail!(
                    "It does not make sense to have the key be in confluent format and not the value, or vice versa."
                );
            }
            _ => {}
        }
    }

    let topic_name = &format!("{}-{}", topic_prefix, state.seed);
    println!(
        "Ingesting data into Kafka topic {} with start_iteration = {}, repeat = {}",
        topic_name, start_iteration, repeat
    );

    let set_schema_id_var = |state: &mut State, schema_id_var, transcoder| match transcoder {
        &Transcoder::ConfluentAvro { schema_id, .. } | &Transcoder::Protobuf { schema_id, .. } => {
            state.cmd_vars.insert(schema_id_var, schema_id.to_string());
        }
        _ => (),
    };

    let value_transcoder =
        make_transcoder(state, format.clone(), format!("{}-value", topic_name)).await?;
    if let Some(var) = schema_id_var {
        set_schema_id_var(state, var, &value_transcoder);
    }

    let key_transcoder = match key_format.clone() {
        None => None,
        Some(f) => {
            let transcoder = make_transcoder(state, f, format!("{}-key", topic_name)).await?;
            if let Some(var) = key_schema_id_var {
                set_schema_id_var(state, var, &transcoder);
            }
            Some(transcoder)
        }
    };

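    // Produce every input row for every iteration, keeping the sends in flight
    // and draining them every `INGEST_BATCH_SIZE` iterations (and after the
    // final iteration) so the number of outstanding requests stays bounded.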
    let mut futs = FuturesUnordered::new();

    for iteration in start_iteration..(start_iteration + repeat) {
        let iter = &mut cmd.input.iter().peekable();

        for row in iter {
            let row = action::substitute_vars(
                row,
                &btreemap! { "kafka-ingest.iteration".into() => iteration.to_string() },
                &None,
                false,
            )?;
            let mut row = row.as_bytes();
            let key = match (omit_key, &key_transcoder) {
                (true, _) => None,
                (false, None) => None,
                (false, Some(kt)) => kt.transcode(&mut row)?,
            };
            let value = if omit_value {
                None
            } else {
                value_transcoder
                    .transcode(&mut row)
                    .with_context(|| format!("parsing row: {}", String::from_utf8_lossy(row)))?
            };
            let producer = &state.kafka_producer;
            let timeout = cmp::max(state.default_timeout, Duration::from_secs(1));
            let headers = headers.clone();
            futs.push(async move {
                let mut record: FutureRecord<_, _> = FutureRecord::to(topic_name);

                if let Some(partition) = partition {
                    record = record.partition(partition);
                }
                if let Some(key) = &key {
                    record = record.key(key);
                }
                if let Some(value) = &value {
                    record = record.payload(value);
                }
                if let Some(timestamp) = timestamp {
                    record = record.timestamp(timestamp);
                }
                if let Some(headers) = headers {
                    let mut rd_meta = OwnedHeaders::new();
                    for (k, v) in &headers {
                        rd_meta = rd_meta.insert(Header {
                            key: k,
                            value: v.as_deref(),
                        });
                    }
                    record = record.headers(rd_meta);
                }
                producer.send(record, timeout).await
            });
        }

        // Reap the futures thus produced periodically or after the last iteration.
        if iteration % INGEST_BATCH_SIZE == 0 || iteration == (start_iteration + repeat - 1) {
            while let Some(res) = futs.next().await {
                res.map_err(|(e, _message)| e)?;
            }
        }
    }
    Ok(ControlFlow::Continue)
}

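/// Builds the [`Transcoder`] for `format`. When the Confluent wire format is
/// requested, the Avro schema is published to the schema registry under
/// `ccsr_subject`, and the Protobuf schema ID is looked up there.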
async fn make_transcoder(
    state: &State,
    format: Format,
    ccsr_subject: String,
) -> Result<Transcoder, anyhow::Error> {
    match format {
        Format::Avro {
            schema,
            confluent_wire_format,
        } => {
            if confluent_wire_format {
                let schema_id = state
                    .ccsr_client
                    .publish_schema(&ccsr_subject, &schema, mz_ccsr::SchemaType::Avro, &[])
                    .await
                    .context("publishing to schema registry")?;
                let schema = avro::parse_schema(&schema)
                    .with_context(|| format!("parsing avro schema: {}", schema))?;
                Ok::<_, anyhow::Error>(Transcoder::ConfluentAvro { schema, schema_id })
            } else {
                let schema = avro::parse_schema(&schema)
                    .with_context(|| format!("parsing avro schema: {}", schema))?;
                Ok(Transcoder::PlainAvro { schema })
            }
        }
        Format::Protobuf {
            descriptor_file,
            message,
            confluent_wire_format,
            schema_id_subject,
            schema_message_id,
        } => {
            let schema_id = if confluent_wire_format {
                state
                    .ccsr_client
                    .get_schema_by_subject(schema_id_subject.as_deref().unwrap_or(&ccsr_subject))
                    .await
                    .context("fetching schema from registry")?
                    .id
            } else {
                0
            };

            let bytes = fs::read(state.temp_path.join(descriptor_file))
                .await
                .context("reading protobuf descriptor file")?;
            let fd = DescriptorPool::decode(&*bytes).context("parsing protobuf descriptor file")?;
            let message = fd
                .get_message_by_name(&message)
                .ok_or_else(|| anyhow!("unknown message name {}", message))?;
            Ok(Transcoder::Protobuf {
                message,
                confluent_wire_format,
                schema_id,
                schema_message_id,
            })
        }
        Format::Bytes { terminator } => Ok(Transcoder::Bytes { terminator }),
    }
}