1use std::cmp;
11use std::io::{BufRead, Read};
12use std::time::Duration;
13
14use anyhow::{Context, anyhow, bail};
15use byteorder::{NetworkEndian, WriteBytesExt};
16use futures::stream::{FuturesUnordered, StreamExt};
17use maplit::btreemap;
18use prost::Message;
19use prost_reflect::{DescriptorPool, DynamicMessage, MessageDescriptor};
20use rdkafka::message::{Header, OwnedHeaders};
21use rdkafka::producer::FutureRecord;
22use serde::de::DeserializeOwned;
23use tokio::fs;
24
25use crate::action::{self, ControlFlow, State};
26use crate::format::avro::{self, Schema};
27use crate::format::bytes;
28use crate::parser::BuiltinCommand;
29
/// Flush interval for `run_ingest`, measured in iterations (not rows): pending
/// sends are awaited whenever the iteration number is a multiple of this value,
/// and once more on the final iteration.
const INGEST_BATCH_SIZE: isize = 10000;
31
/// Encoding options for a key or value, as collected from the command's
/// arguments. `make_transcoder` turns a `Format` into a ready-to-use
/// `Transcoder`.
#[derive(Clone)]
enum Format {
    /// Rows are JSON that is re-encoded as Avro.
    Avro {
        /// Avro schema text; parsed (and optionally published) by
        /// `make_transcoder`.
        schema: String,
        /// Whether encoded datums are prefixed with a zero byte and the
        /// schema ID.
        confluent_wire_format: bool,
    },
    /// Rows are JSON that is re-encoded as a protobuf message.
    Protobuf {
        /// Path of the descriptor file, resolved against `state.temp_path`.
        descriptor_file: String,
        /// Name of the message to look up in the descriptor pool.
        message: String,
        /// Whether encoded messages are prefixed with a zero byte, the
        /// schema ID, and `schema_message_id`.
        confluent_wire_format: bool,
        /// Registry subject used to look up the schema ID; when absent the
        /// subject passed to `make_transcoder` is used instead.
        schema_id_subject: Option<String>,
        /// Single byte written after the schema ID in the prefix.
        schema_message_id: u8,
    },
    /// Rows are passed through as (escaped) raw bytes.
    Bytes {
        /// Byte that ends this field within a row; `None` means the field
        /// extends to the end of the row.
        terminator: Option<u8>,
    },
}
49
/// A fully prepared encoder that turns one textual row into the bytes sent to
/// Kafka. Built from a `Format` by `make_transcoder`.
enum Transcoder {
    /// Avro datum with no prefix.
    PlainAvro {
        /// Parsed Avro schema used to convert and encode the JSON row.
        schema: Schema,
    },
    /// Avro datum prefixed with a zero byte and the big-endian schema ID.
    ConfluentAvro {
        /// Parsed Avro schema used to convert and encode the JSON row.
        schema: Schema,
        /// ID returned when the schema was published to the registry.
        schema_id: i32,
    },
    /// Protobuf message, optionally carrying the Confluent-style prefix.
    Protobuf {
        /// Descriptor of the message type the JSON row is deserialized into.
        message: MessageDescriptor,
        /// Whether the prefix (zero byte, schema ID, `schema_message_id`) is
        /// written before the encoded message.
        confluent_wire_format: bool,
        /// Schema ID fetched from the registry, or 0 when the prefix is not
        /// in use.
        schema_id: i32,
        /// Single byte written after the schema ID in the prefix.
        schema_message_id: u8,
    },
    /// Raw bytes, unescaped via `bytes::unescape`.
    Bytes {
        /// Byte that ends this field within a row; `None` means the field
        /// extends to the end of the row.
        terminator: Option<u8>,
    },
}
68
69impl Transcoder {
70 fn decode_json<R, T>(row: R) -> Result<Option<T>, anyhow::Error>
71 where
72 R: Read,
73 T: DeserializeOwned,
74 {
75 let deserializer = serde_json::Deserializer::from_reader(row);
76 deserializer
77 .into_iter()
78 .next()
79 .transpose()
80 .context("parsing json")
81 }
82
83 fn transcode<R>(&self, mut row: R) -> Result<Option<Vec<u8>>, anyhow::Error>
84 where
85 R: BufRead,
86 {
87 match self {
88 Transcoder::ConfluentAvro { schema, schema_id } => {
89 if let Some(val) = Self::decode_json(row)? {
90 let val = avro::from_json(&val, schema.top_node())?;
91 let mut out = vec![];
92 out.write_u8(0).unwrap();
98 out.write_i32::<NetworkEndian>(*schema_id).unwrap();
99 out.extend(avro::to_avro_datum(schema, val)?);
100 Ok(Some(out))
101 } else {
102 Ok(None)
103 }
104 }
105 Transcoder::PlainAvro { schema } => {
106 if let Some(val) = Self::decode_json(row)? {
107 let val = avro::from_json(&val, schema.top_node())?;
108 let mut out = vec![];
109 out.extend(avro::to_avro_datum(schema, val)?);
110 Ok(Some(out))
111 } else {
112 Ok(None)
113 }
114 }
115 Transcoder::Protobuf {
116 message,
117 confluent_wire_format,
118 schema_id,
119 schema_message_id,
120 } => {
121 if let Some(val) = Self::decode_json::<_, serde_json::Value>(row)? {
122 let message = DynamicMessage::deserialize(message.clone(), val)
123 .context("parsing protobuf JSON")?;
124 let mut out = vec![];
125 if *confluent_wire_format {
126 out.write_u8(0).unwrap();
133 out.write_i32::<NetworkEndian>(*schema_id).unwrap();
134 out.write_u8(*schema_message_id).unwrap();
135 }
136 message.encode(&mut out)?;
137 Ok(Some(out))
138 } else {
139 Ok(None)
140 }
141 }
142 Transcoder::Bytes { terminator } => {
143 let mut out = vec![];
144 match terminator {
145 Some(t) => {
146 row.read_until(*t, &mut out)?;
147 if out.last() == Some(t) {
148 out.pop();
149 }
150 }
151 None => {
152 row.read_to_end(&mut out)?;
153 }
154 }
155 if out.is_empty() {
156 Ok(None)
157 } else {
158 Ok(Some(bytes::unescape(&out)?))
159 }
160 }
161 }
162 }
163}
164
165pub async fn run_ingest(
166 mut cmd: BuiltinCommand,
167 state: &mut State,
168) -> Result<ControlFlow, anyhow::Error> {
169 let topic_prefix = format!("testdrive-{}", cmd.args.string("topic")?);
170 let partition = cmd.args.opt_parse::<i32>("partition")?;
171 let start_iteration = cmd.args.opt_parse::<isize>("start-iteration")?.unwrap_or(0);
172 let repeat = cmd.args.opt_parse::<isize>("repeat")?.unwrap_or(1);
173 let omit_key = cmd.args.opt_bool("omit-key")?.unwrap_or(false);
174 let omit_value = cmd.args.opt_bool("omit-value")?.unwrap_or(false);
175 let schema_id_var = cmd.args.opt_parse("set-schema-id-var")?;
176 let key_schema_id_var = cmd.args.opt_parse("set-key-schema-id-var")?;
177 let format = match cmd.args.string("format")?.as_str() {
178 "avro" => Format::Avro {
179 schema: cmd.args.string("schema")?,
180 confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(true),
181 },
182 "protobuf" => {
183 let descriptor_file = cmd.args.string("descriptor-file")?;
184 let message = cmd.args.string("message")?;
185 Format::Protobuf {
186 descriptor_file,
187 message,
188 confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(false),
191 schema_id_subject: cmd.args.opt_string("schema-id-subject"),
192 schema_message_id: cmd.args.opt_parse::<u8>("schema-message-id")?.unwrap_or(0),
193 }
194 }
195 "bytes" => Format::Bytes { terminator: None },
196 f => bail!("unknown format: {}", f),
197 };
198 let mut key_schema = cmd.args.opt_string("key-schema");
199 let key_format = match cmd.args.opt_string("key-format").as_deref() {
200 Some("avro") => Some(Format::Avro {
201 schema: key_schema.take().ok_or_else(|| {
202 anyhow!("key-schema parameter required when key-format is present")
203 })?,
204 confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(true),
205 }),
206 Some("protobuf") => {
207 let descriptor_file = cmd.args.string("key-descriptor-file")?;
208 let message = cmd.args.string("key-message")?;
209 Some(Format::Protobuf {
210 descriptor_file,
211 message,
212 confluent_wire_format: cmd.args.opt_bool("confluent-wire-format")?.unwrap_or(false),
213 schema_id_subject: cmd.args.opt_string("key-schema-id-subject"),
214 schema_message_id: cmd
215 .args
216 .opt_parse::<u8>("key-schema-message-id")?
217 .unwrap_or(0),
218 })
219 }
220 Some("bytes") => Some(Format::Bytes {
221 terminator: match cmd.args.opt_parse::<char>("key-terminator")? {
222 Some(c) => match u8::try_from(c) {
223 Ok(c) => Some(c),
224 Err(_) => bail!("key terminator must be single ASCII character"),
225 },
226 None => Some(b':'),
227 },
228 }),
229 Some(f) => bail!("unknown key format: {}", f),
230 None => None,
231 };
232 if key_schema.is_some() {
233 anyhow::bail!("key-schema specified without a matching key-format");
234 }
235
236 let timestamp = cmd.args.opt_parse("timestamp")?;
237
238 use serde_json::Value;
239 let headers = if let Some(headers_val) = cmd.args.opt_parse::<serde_json::Value>("headers")? {
240 let mut headers = Vec::new();
241 let headers_maps = match headers_val {
242 Value::Array(values) => {
243 let mut headers_map = Vec::new();
244 for value in values {
245 if let Value::Object(m) = value {
246 headers_map.push(m)
247 } else {
248 bail!("`headers` array values must be maps")
249 }
250 }
251 headers_map
252 }
253 Value::Object(v) => vec![v],
254 _ => bail!("`headers` must be a map or an array"),
255 };
256
257 for headers_map in headers_maps {
258 for (k, v) in headers_map.iter() {
259 headers.push((k.clone(), match v {
260 Value::String(val) => Some(val.as_bytes().to_vec()),
261 Value::Array(val) => {
262 let mut values = Vec::new();
263 for value in val {
264 if let Value::Number(int) = value {
265 values.push(u8::try_from(int.as_i64().unwrap()).unwrap())
266 } else {
267 bail!("`headers` value arrays must only contain numbers (to represent bytes)")
268 }
269 }
270 Some(values.clone())
271 },
272 Value::Null => None,
273 _ => bail!("`headers` must have string, int array or null values")
274 }));
275 }
276 }
277 Some(headers)
278 } else {
279 None
280 };
281
282 cmd.args.done()?;
283
284 if let Some(kf) = &key_format {
285 fn is_confluent_format(fmt: &Format) -> Option<bool> {
286 match fmt {
287 Format::Avro {
288 confluent_wire_format,
289 ..
290 } => Some(*confluent_wire_format),
291 Format::Protobuf {
292 confluent_wire_format,
293 ..
294 } => Some(*confluent_wire_format),
295 Format::Bytes { .. } => None,
296 }
297 }
298 match (is_confluent_format(kf), is_confluent_format(&format)) {
299 (Some(false), Some(true)) | (Some(true), Some(false)) => {
300 bail!(
301 "It does not make sense to have the key be in confluent format and not the value, or vice versa."
302 );
303 }
304 _ => {}
305 }
306 }
307
308 let topic_name = &format!("{}-{}", topic_prefix, state.seed);
309 println!(
310 "Ingesting data into Kafka topic {} with start_iteration = {}, repeat = {}",
311 topic_name, start_iteration, repeat
312 );
313
314 let set_schema_id_var = |state: &mut State, schema_id_var, transcoder| match transcoder {
315 &Transcoder::ConfluentAvro { schema_id, .. } | &Transcoder::Protobuf { schema_id, .. } => {
316 state.cmd_vars.insert(schema_id_var, schema_id.to_string());
317 }
318 _ => (),
319 };
320
321 let value_transcoder =
322 make_transcoder(state, format.clone(), format!("{}-value", topic_name)).await?;
323 if let Some(var) = schema_id_var {
324 set_schema_id_var(state, var, &value_transcoder);
325 }
326
327 let key_transcoder = match key_format.clone() {
328 None => None,
329 Some(f) => {
330 let transcoder = make_transcoder(state, f, format!("{}-key", topic_name)).await?;
331 if let Some(var) = key_schema_id_var {
332 set_schema_id_var(state, var, &transcoder);
333 }
334 Some(transcoder)
335 }
336 };
337
338 let mut futs = FuturesUnordered::new();
339
340 for iteration in start_iteration..(start_iteration + repeat) {
341 let iter = &mut cmd.input.iter().peekable();
342
343 for row in iter {
344 let row = action::substitute_vars(
345 row,
346 &btreemap! { "kafka-ingest.iteration".into() => iteration.to_string() },
347 &None,
348 false,
349 )?;
350 let mut row = row.as_bytes();
351 let key = match (omit_key, &key_transcoder) {
352 (true, _) => None,
353 (false, None) => None,
354 (false, Some(kt)) => kt.transcode(&mut row)?,
355 };
356 let value = if omit_value {
357 None
358 } else {
359 value_transcoder
360 .transcode(&mut row)
361 .with_context(|| format!("parsing row: {}", String::from_utf8_lossy(row)))?
362 };
363 let producer = &state.kafka_producer;
364 let timeout = cmp::max(state.default_timeout, Duration::from_secs(1));
365 let headers = headers.clone();
366 futs.push(async move {
367 let mut record: FutureRecord<_, _> = FutureRecord::to(topic_name);
368
369 if let Some(partition) = partition {
370 record = record.partition(partition);
371 }
372 if let Some(key) = &key {
373 record = record.key(key);
374 }
375 if let Some(value) = &value {
376 record = record.payload(value);
377 }
378 if let Some(timestamp) = timestamp {
379 record = record.timestamp(timestamp);
380 }
381 if let Some(headers) = headers {
382 let mut rd_meta = OwnedHeaders::new();
383 for (k, v) in &headers {
384 rd_meta = rd_meta.insert(Header {
385 key: k,
386 value: v.as_deref(),
387 });
388 }
389 record = record.headers(rd_meta);
390 }
391 producer.send(record, timeout).await
392 });
393 }
394
395 if iteration % INGEST_BATCH_SIZE == 0 || iteration == (start_iteration + repeat - 1) {
397 while let Some(res) = futs.next().await {
398 res.map_err(|(e, _message)| e)?;
399 }
400 }
401 }
402 Ok(ControlFlow::Continue)
403}
404
405async fn make_transcoder(
406 state: &State,
407 format: Format,
408 ccsr_subject: String,
409) -> Result<Transcoder, anyhow::Error> {
410 match format {
411 Format::Avro {
412 schema,
413 confluent_wire_format,
414 } => {
415 if confluent_wire_format {
416 let schema_id = state
417 .ccsr_client
418 .publish_schema(&ccsr_subject, &schema, mz_ccsr::SchemaType::Avro, &[])
419 .await
420 .context("publishing to schema registry")?;
421 let schema = avro::parse_schema(&schema)
422 .with_context(|| format!("parsing avro schema: {}", schema))?;
423 Ok::<_, anyhow::Error>(Transcoder::ConfluentAvro { schema, schema_id })
424 } else {
425 let schema = avro::parse_schema(&schema)
426 .with_context(|| format!("parsing avro schema: {}", schema))?;
427 Ok(Transcoder::PlainAvro { schema })
428 }
429 }
430 Format::Protobuf {
431 descriptor_file,
432 message,
433 confluent_wire_format,
434 schema_id_subject,
435 schema_message_id,
436 } => {
437 let schema_id = if confluent_wire_format {
438 state
439 .ccsr_client
440 .get_schema_by_subject(schema_id_subject.as_deref().unwrap_or(&ccsr_subject))
441 .await
442 .context("fetching schema from registry")?
443 .id
444 } else {
445 0
446 };
447
448 let bytes = fs::read(state.temp_path.join(descriptor_file))
449 .await
450 .context("reading protobuf descriptor file")?;
451 let fd = DescriptorPool::decode(&*bytes).context("parsing protobuf descriptor file")?;
452 let message = fd
453 .get_message_by_name(&message)
454 .ok_or_else(|| anyhow!("unknown message name {}", message))?;
455 Ok(Transcoder::Protobuf {
456 message,
457 confluent_wire_format,
458 schema_id,
459 schema_message_id,
460 })
461 }
462 Format::Bytes { terminator } => Ok(Transcoder::Bytes { terminator }),
463 }
464}