use std::fmt::Debug;
use std::time::Duration;
use std::{cmp, str};

use anyhow::{Context, bail, ensure};
use rdkafka::consumer::{Consumer, StreamConsumer};
use rdkafka::error::KafkaError;
use rdkafka::message::{Headers, Message};
use rdkafka::types::RDKafkaErrorCode;
use regex::Regex;
use tokio::pin;
use tokio_stream::StreamExt;

use crate::action::{ControlFlow, State};
use crate::format::avro::{self, DebugValue};
use crate::parser::BuiltinCommand;

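/// The wire format of a record key or value.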
#[derive(Debug, Clone, Copy)]
enum Format {
    Avro,
    Json,
    Bytes,
    Text,
}

impl TryFrom<&str> for Format {
    type Error = anyhow::Error;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        match value {
            "avro" => Ok(Format::Avro),
            "json" => Ok(Format::Json),
            "bytes" => Ok(Format::Bytes),
            "text" => Ok(Format::Text),
            f => bail!("unknown format: {}", f),
        }
    }
}

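/// How record keys and values should be decoded, and whether each record is
/// required to have a key.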
#[derive(Debug)]
struct RecordFormat {
    key: Format,
    value: Format,
    requires_key: bool,
}

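/// A record key or value after decoding from its wire format.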
#[allow(dead_code)]
#[derive(Debug, Clone)]
enum DecodedValue {
    Avro(DebugValue),
    Json(serde_json::Value),
    Bytes(Vec<u8>),
    Text(String),
}

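/// Where the topic to verify comes from: resolved from a named sink in the
/// catalog, or given explicitly by name.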
enum Topic {
    FromSink(String),
    Named(String),
}

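/// A single Kafka record, generic over the representation of its key and value.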
#[derive(Debug, Clone)]
struct Record<A> {
    headers: Vec<String>,
    key: Option<A>,
    value: Option<A>,
    partition: Option<i32>,
}

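/// Looks up the Kafka topic backing the given fully-qualified sink name
/// (`database.schema.sink`) via the Materialize catalog tables.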
async fn get_topic(sink: &str, topic_field: &str, state: &State) -> Result<String, anyhow::Error> {
    let query = format!(
        "SELECT {} FROM mz_sinks JOIN mz_kafka_sinks \
         ON mz_sinks.id = mz_kafka_sinks.id \
         JOIN mz_schemas s ON s.id = mz_sinks.schema_id \
         LEFT JOIN mz_databases d ON d.id = s.database_id \
         WHERE d.name = $1 \
         AND s.name = $2 \
         AND mz_sinks.name = $3",
        topic_field
    );
    let sink_fields: Vec<&str> = sink.split('.').collect();
    let result = state
        .materialize
        .pgclient
        .query_one(
            query.as_str(),
            &[&sink_fields[0], &sink_fields[1], &sink_fields[2]],
        )
        .await
        .context("retrieving topic name")?
        .get(topic_field);
    Ok(result)
}

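/// Implements the `kafka-verify-data` testdrive action: consumes records from
/// the sink's (or the named) topic, decodes them, and compares them against
/// the expected records given in the command input.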
pub async fn run_verify_data(
    mut cmd: BuiltinCommand,
    state: &State,
) -> Result<ControlFlow, anyhow::Error> {
    let mut format = if let Some(format_str) = cmd.args.opt_string("format") {
        let requires_key: bool = cmd.args.opt_bool("key")?.unwrap_or(false);
        let format_type = format_str.as_str().try_into()?;
        RecordFormat {
            key: format_type,
            value: format_type,
            requires_key,
        }
    } else {
        let key_format = cmd.args.string("key-format")?.as_str().try_into()?;
        let value_format = cmd.args.string("value-format")?.as_str().try_into()?;
        RecordFormat {
            key: key_format,
            value: value_format,
            requires_key: true,
        }
    };

    let source = match (cmd.args.opt_string("sink"), cmd.args.opt_string("topic")) {
        (Some(sink), None) => Topic::FromSink(sink),
        (None, Some(topic)) => Topic::Named(topic),
        (Some(_), Some(_)) => bail!("Can't provide both `sink` and `topic` to kafka-verify-data"),
        (None, None) => bail!("kafka-verify-data expects either `sink` or `topic`"),
    };

    let sort_messages = cmd.args.opt_bool("sort-messages")?.unwrap_or(false);

    let header_keys: Vec<_> = cmd
        .args
        .opt_string("headers")
        .map(|s| s.split(',').map(str::to_owned).collect())
        .unwrap_or_default();

    let expected_messages = cmd.input;
    if expected_messages.is_empty() {
        bail!("kafka-verify-data requires a non-empty list of expected messages");
    }
    let partial_search = cmd.args.opt_parse("partial-search")?;
    let debug_print_only = cmd.args.opt_bool("debug-print-only")?.unwrap_or(false);
    cmd.args.done()?;

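    // Resolve the topic name, either from the sink's catalog entry or as given.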
    let topic: String = match &source {
        Topic::FromSink(sink) => get_topic(sink, "topic", state).await?,
        Topic::Named(name) => name.clone(),
    };

    println!("Verifying results in Kafka topic {}", topic);

    let mut config = state.kafka_config.clone();
    config.set("enable.auto.offset.store", "false");

    let consumer: StreamConsumer = config.create().context("creating kafka consumer")?;
    consumer
        .subscribe(&[&topic])
        .context("subscribing to kafka topic")?;

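    // `partial-search=N` bounds how many messages to read from the stream;
    // otherwise read exactly as many messages as were listed as expected.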
    let (mut stream_messages_remaining, stream_timeout) = match partial_search {
        Some(size) => (size, state.timeout),
        None => (expected_messages.len(), Duration::from_secs(15)),
    };

    let timeout = cmp::max(state.timeout, stream_timeout);

    let message_stream = consumer.stream().timeout(timeout);
    pin!(message_stream);

    let mut actual_bytes = vec![];

    let start = std::time::Instant::now();
    let mut topic_created = false;

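    // Consume messages until enough have been read or the stream times out,
    // tolerating "unknown topic" errors while the topic is still being created.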
    while stream_messages_remaining > 0 {
        match message_stream.next().await {
            Some(Ok(message)) => {
                let message = match message {
                    Err(KafkaError::MessageConsumption(
                        RDKafkaErrorCode::UnknownTopicOrPartition,
                    )) if start.elapsed() < timeout && !topic_created => {
                        println!("waiting for Kafka topic creation...");
                        continue;
                    }
                    e => e?,
                };

                stream_messages_remaining -= 1;
                topic_created = true;

                consumer
                    .store_offset_from_message(&message)
                    .context("storing message offset")?;

                let mut headers = vec![];
                for header_key in &header_keys {
                    let hs = message.headers().context("expected headers for message")?;
                    let mut hs = hs.iter().filter(|i| i.key == header_key);
                    let h = hs.next();
                    if hs.next().is_some() {
                        bail!("expected at most one header with key {header_key}");
                    }
                    match h {
                        None => headers.push("<missing>".into()),
                        Some(h) => {
                            let value = str::from_utf8(h.value.unwrap_or(b"<null>"))?;
                            headers.push(value.into());
                        }
                    }
                }

                actual_bytes.push(Record {
                    headers,
                    key: message.key().map(|b| b.to_owned()),
                    value: message.payload().map(|b| b.to_owned()),
                    partition: Some(message.partition()),
                });
            }
            Some(Err(e)) => {
                println!("Received error from Kafka stream consumer: {}", e);
                break;
            }
            None => {
                break;
            }
        }
    }

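    // For Avro-formatted records, fetch the key and value schemas from the
    // schema registry. The key schema is optional, but if one exists, every
    // record is required to have a key.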
    let key_schema = if let Format::Avro = format.key {
        let schema = state
            .ccsr_client
            .get_schema_by_subject(&format!("{}-key", topic))
            .await
            .ok()
            .map(|key_schema| {
                avro::parse_schema(&key_schema.raw, &[]).context("parsing avro schema")
            })
            .transpose()?;
        if schema.is_some() {
            format.requires_key = true;
        }
        schema
    } else {
        None
    };
    let value_schema = if let Format::Avro = format.value {
        let val_schema = state
            .ccsr_client
            .get_schema_by_subject(&format!("{}-value", topic))
            .await
            .context("fetching schema")?
            .raw;
        Some(avro::parse_schema(&val_schema, &[]).context("parsing avro schema")?)
    } else {
        None
    };

    let mut actual_messages = decode_messages(actual_bytes, &key_schema, &value_schema, &format)?;

    if sort_messages {
        actual_messages.sort_by_key(|r| format!("{:?}", r));
    }

    if debug_print_only {
        bail!(
            "records in sink:\n{}",
            actual_messages
                .into_iter()
                .map(|a| format!("{:#?}", a))
                .collect::<Vec<_>>()
                .join("\n")
        );
    }

    let expected = parse_expected_messages(
        expected_messages,
        key_schema,
        value_schema,
        &format,
        &header_keys,
    )?;

    verify_with_partial_search(
        &expected,
        &actual_messages,
        &state.regex,
        &state.regex_replacement,
        partial_search.is_some(),
    )?;

    Ok(ControlFlow::Continue)
}

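/// Splits off the first `n_headers` whitespace-separated tokens from an
/// expected-message line, returning the header values and the remaining text.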
fn split_headers(input: &str, n_headers: usize) -> anyhow::Result<(Vec<String>, &str)> {
    let whitespace = Regex::new("\\s+").expect("building known-valid regex");
    let mut parts = whitespace.splitn(input, n_headers + 1);
    let mut headers = Vec::with_capacity(n_headers);
    for _ in 0..n_headers {
        headers.push(
            parts
                .next()
                .context("expected another header in the input")?
                .to_string(),
        )
    }
    let rest = parts
        .next()
        .context("expected some contents after any message headers")?;

    ensure!(
        parts.next().is_none(),
        "more than n+1 elements from a call to splitn(_, n+1)"
    );

    Ok((headers, rest))
}

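/// Decodes the raw key and value bytes of each consumed record according to
/// the configured formats.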
fn decode_messages(
    actual_bytes: Vec<Record<Vec<u8>>>,
    key_schema: &Option<mz_avro::Schema>,
    value_schema: &Option<mz_avro::Schema>,
    format: &RecordFormat,
) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
    let mut actual_messages = vec![];

    for record in actual_bytes {
        let Record { key, value, .. } = record;
        let key = if format.requires_key {
            match (key, format.key) {
                (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
                    avro::from_confluent_bytes(key_schema.as_ref().unwrap(), &bytes)?,
                ))),
                (Some(bytes), Format::Json) => Some(DecodedValue::Json(
                    serde_json::from_slice(&bytes).context("decoding json")?,
                )),
                (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
                (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
                (None, _) if format.requires_key => bail!("empty message key"),
                (None, _) => None,
            }
        } else {
            None
        };

        let value = match (value, format.value) {
            (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
                avro::from_confluent_bytes(value_schema.as_ref().unwrap(), &bytes)?,
            ))),
            (Some(bytes), Format::Json) => Some(DecodedValue::Json(
                serde_json::from_slice(&bytes).context("decoding json")?,
            )),
            (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
            (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
            (None, _) => None,
        };

        actual_messages.push(Record {
            headers: record.headers.clone(),
            key,
            value,
            partition: record.partition,
        });
    }

    Ok(actual_messages)
}

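/// Parses the expected records from the command input into the same decoded
/// representation as the actual records so the two can be compared
/// structurally.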
fn parse_expected_messages(
    expected_messages: Vec<String>,
    key_schema: Option<mz_avro::Schema>,
    value_schema: Option<mz_avro::Schema>,
    format: &RecordFormat,
    header_keys: &[String],
) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
    let mut expected = vec![];

    for msg in expected_messages {
        let (headers, content) = split_headers(&msg, header_keys.len())?;
        let mut content = content.as_bytes();
        let mut deserializer = serde_json::Deserializer::from_reader(&mut content).into_iter();

        let key = if format.requires_key {
            let key: serde_json::Value = deserializer
                .next()
                .context("key missing in input line")?
                .context("parsing json")?;

            Some(match format.key {
                Format::Avro => DecodedValue::Avro(DebugValue(avro::from_json(
                    &key,
                    key_schema.as_ref().unwrap().top_node(),
                )?)),
                Format::Json => DecodedValue::Json(key),
                Format::Bytes => {
                    unimplemented!("bytes format not yet supported in tests")
                }
                Format::Text => DecodedValue::Text(
                    key.as_str()
                        .map(|s| s.to_string())
                        .unwrap_or_else(|| key.to_string()),
                ),
            })
        } else {
            None
        };

        let value = match deserializer.next().transpose().context("parsing json")? {
            None => None,
            Some(value) if value.as_str() == Some("<null>") => None,
            Some(value) => match format.value {
                Format::Avro => Some(DecodedValue::Avro(DebugValue(avro::from_json(
                    &value,
                    value_schema.as_ref().unwrap().top_node(),
                )?))),
                Format::Json => Some(DecodedValue::Json(value)),
                Format::Bytes => {
                    unimplemented!("bytes format not yet supported in tests")
                }
                Format::Text => Some(DecodedValue::Text(value.to_string())),
            },
        };

        let content =
            str::from_utf8(content).context("internal error: contents were previously a string")?;
        let partition = match content.trim().split_once("=") {
            None if content.trim() != "" => bail!("unexpected cruft at end of line: {content}"),
            None => None,
            Some((label, partition)) => {
                if label != "partition" {
                    bail!("partition expectation has unexpected label: {label}")
                }
                Some(partition.parse().context("parsing expected partition")?)
            }
        };

        expected.push(Record {
            headers,
            key,
            value,
            partition,
        });
    }

    Ok(expected)
}

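/// Compares expected and actual records in order. With `partial_search`, the
/// expected records only need to appear as a contiguous run somewhere in the
/// actual records, and trailing extra records are permitted.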
fn verify_with_partial_search<A>(
    expected: &[Record<A>],
    actual: &[Record<A>],
    regex: &Option<Regex>,
    regex_replacement: &String,
    partial_search: bool,
) -> Result<(), anyhow::Error>
where
    A: Debug + Clone,
{
    let mut expected = expected.iter();
    let mut actual = actual.iter();
    let mut index = 0..;

    let mut found_beginning = !partial_search;
    let mut expected_item = expected.next();
    let mut actual_item = actual.next();
    loop {
        let i = index.next().expect("known to exist");
        match (expected_item, actual_item) {
            (Some(e), Some(a)) => {
                let mut a = a.clone();
                if e.partition.is_none() {
                    a.partition = None;
                }
                let e_str = format!("{:#?}", e);
                let a_str = match &regex {
                    Some(regex) => regex
                        .replace_all(&format!("{:#?}", a), regex_replacement.as_str())
                        .to_string(),
                    _ => format!("{:#?}", a),
                };

                if e_str != a_str {
                    if found_beginning {
                        bail!(
                            "record {} did not match\nexpected:\n{}\n\nactual:\n{}",
                            i,
                            e_str,
                            a_str,
                        );
                    }
                    actual_item = actual.next();
                } else {
                    found_beginning = true;
                    expected_item = expected.next();
                    actual_item = actual.next();
                }
            }
            (Some(e), None) => bail!("missing record {}: {:#?}", i, e),
            (None, Some(a)) => {
                if !partial_search {
                    bail!("extra record {}: {:#?}", i, a);
                }
                break;
            }
            (None, None) => break,
        }
    }
    let expected: Vec<_> = expected.map(|e| format!("{:#?}", e)).collect();
    let actual: Vec<_> = actual.map(|a| format!("{:#?}", a)).collect();

    if !expected.is_empty() {
        bail!("missing records:\n{}", expected.join("\n"))
    } else if !actual.is_empty() && !partial_search {
        bail!("extra records:\n{}", actual.join("\n"))
    } else {
        Ok(())
    }
}