1use std::fmt::Debug;
11use std::time::Duration;
12use std::{cmp, str};
13
14use anyhow::{Context, bail, ensure};
15use mz_postgres_util::{Sql, query_one, sql};
16use rdkafka::consumer::{Consumer, StreamConsumer};
17use rdkafka::error::KafkaError;
18use rdkafka::message::{Headers, Message};
19use rdkafka::types::RDKafkaErrorCode;
20use regex::Regex;
21use tokio::pin;
22use tokio_stream::StreamExt;
23
24use crate::action::{ControlFlow, State};
25use crate::format::avro::{self, DebugValue};
26use crate::parser::BuiltinCommand;
27
/// Wire format of a record key or value, as named by the `format`,
/// `key-format`, and `value-format` arguments of `kafka-verify-data`.
#[derive(Clone, Copy, Debug)]
enum Format {
    /// Confluent-framed Avro, decoded with a schema from the schema registry.
    Avro,
    /// A JSON document.
    Json,
    /// Raw bytes, kept as-is.
    Bytes,
    /// UTF-8 text.
    Text,
}
35
36impl TryFrom<&str> for Format {
37 type Error = anyhow::Error;
38
39 fn try_from(value: &str) -> Result<Self, Self::Error> {
40 match value {
41 "avro" => Ok(Format::Avro),
42 "json" => Ok(Format::Json),
43 "bytes" => Ok(Format::Bytes),
44 "text" => Ok(Format::Text),
45 f => bail!("unknown format: {}", f),
46 }
47 }
48}
49
50#[derive(Debug)]
51struct RecordFormat {
52 key: Format,
53 value: Format,
54 requires_key: bool,
55}
56
57#[allow(dead_code)]
58#[derive(Debug, Clone)]
59enum DecodedValue {
60 Avro(DebugValue),
61 Json(serde_json::Value),
62 Bytes(Vec<u8>),
63 Text(String),
64}
65
/// Where the Kafka topic to verify comes from.
enum Topic {
    /// Look up the topic of this sink, given as `database.schema.name`.
    FromSink(String),
    /// Use this topic name as-is.
    Named(String),
}
70
/// A single Kafka record, either raw (`A = Vec<u8>`) or decoded.
///
/// Field order is significant: records are compared and sorted by their
/// `Debug` output.
#[derive(Debug, Clone)]
struct Record<A> {
    /// Values of the requested headers, in the order the header keys were given.
    headers: Vec<String>,
    /// Record key, if present (or required).
    key: Option<A>,
    /// Record value, if present.
    value: Option<A>,
    /// Partition the record was read from or, for expectations, the expected
    /// partition if one was specified.
    partition: Option<i32>,
}
78
79async fn get_topic(sink: &str, topic_field: &str, state: &State) -> Result<String, anyhow::Error> {
80 let query = sql!(
81 "SELECT {} FROM mz_sinks JOIN mz_kafka_sinks \
82 ON mz_sinks.id = mz_kafka_sinks.id \
83 JOIN mz_schemas s ON s.id = mz_sinks.schema_id \
84 LEFT JOIN mz_databases d ON d.id = s.database_id \
85 WHERE d.name = $1 \
86 AND s.name = $2 \
87 AND mz_sinks.name = $3",
88 Sql::ident(topic_field)
89 );
90 let sink_fields: Vec<&str> = sink.split('.').collect();
91 let result = query_one(
92 &state.materialize.pgclient,
93 query,
94 &[&sink_fields[0], &sink_fields[1], &sink_fields[2]],
95 )
96 .await
97 .context("retrieving topic name")?
98 .get(topic_field);
99 Ok(result)
100}
101
102pub async fn run_verify_data(
103 mut cmd: BuiltinCommand,
104 state: &State,
105) -> Result<ControlFlow, anyhow::Error> {
106 let mut format = if let Some(format_str) = cmd.args.opt_string("format") {
107 let requires_key: bool = cmd.args.opt_bool("key")?.unwrap_or(false);
111 let format_type = format_str.as_str().try_into()?;
112 RecordFormat {
113 key: format_type,
114 value: format_type,
115 requires_key,
116 }
117 } else {
118 let key_format = cmd.args.string("key-format")?.as_str().try_into()?;
119 let value_format = cmd.args.string("value-format")?.as_str().try_into()?;
120 RecordFormat {
121 key: key_format,
122 value: value_format,
123 requires_key: true,
124 }
125 };
126
127 let source = match (cmd.args.opt_string("sink"), cmd.args.opt_string("topic")) {
128 (Some(sink), None) => Topic::FromSink(sink),
129 (None, Some(topic)) => Topic::Named(topic),
130 (Some(_), Some(_)) => bail!("Can't provide both `source` and `topic` to kafka-verify-data"),
131 (None, None) => bail!("kafka-verify-data expects either `source` or `topic`"),
132 };
133
134 let sort_messages = cmd.args.opt_bool("sort-messages")?.unwrap_or(false);
135
136 let header_keys: Vec<_> = cmd
137 .args
138 .opt_string("headers")
139 .map(|s| s.split(',').map(str::to_owned).collect())
140 .unwrap_or_default();
141
142 let expected_messages = cmd.input;
143 if expected_messages.len() == 0 {
144 bail!("kafka-verify-data requires a non-empty list of expected messages");
147 }
148 let partial_search = cmd.args.opt_parse("partial-search")?;
149 let debug_print_only = cmd.args.opt_bool("debug-print-only")?.unwrap_or(false);
150 cmd.args.done()?;
151
152 let topic: String = match &source {
153 Topic::FromSink(sink) => get_topic(sink, "topic", state).await?,
154 Topic::Named(name) => name.clone(),
155 };
156
157 println!("Verifying results in Kafka topic {}", topic);
158
159 let mut config = state.kafka_config.clone();
160 config.set("enable.auto.offset.store", "false");
161
162 let consumer: StreamConsumer = config.create().context("creating kafka consumer")?;
163 consumer
164 .subscribe(&[&topic])
165 .context("subscribing to kafka topic")?;
166
167 let (mut stream_messages_remaining, stream_timeout) = match partial_search {
168 Some(size) => (size, state.timeout),
169 None => (expected_messages.len(), Duration::from_secs(15)),
170 };
171
172 let timeout = cmp::max(state.timeout, stream_timeout);
173
174 let message_stream = consumer.stream().timeout(timeout);
175 pin!(message_stream);
176
177 let mut actual_bytes = vec![];
183
184 let start = std::time::Instant::now();
185 let mut topic_created = false;
186
187 while stream_messages_remaining > 0 {
188 match message_stream.next().await {
189 Some(Ok(message)) => {
190 let message = match message {
191 Err(KafkaError::MessageConsumption(
194 RDKafkaErrorCode::UnknownTopicOrPartition,
195 )) if start.elapsed() < timeout && !topic_created => {
196 println!("waiting for Kafka topic creation...");
197 continue;
198 }
199 e => e?,
200 };
201
202 stream_messages_remaining -= 1;
203 topic_created = true;
204
205 consumer
206 .store_offset_from_message(&message)
207 .context("storing message offset")?;
208
209 let mut headers = vec![];
210 for header_key in &header_keys {
211 let hs = message.headers().context("expected headers for message")?;
213 let mut hs = hs.iter().filter(|i| i.key == header_key);
214 let h = hs.next();
215 if hs.next().is_some() {
216 bail!("expected at most one header with key {header_key}");
217 }
218 match h {
219 None => headers.push("<missing>".into()),
220 Some(h) => {
221 let value = str::from_utf8(h.value.unwrap_or(b"<null>"))?;
222 headers.push(value.into());
223 }
224 }
225 }
226
227 actual_bytes.push(Record {
228 headers,
229 key: message.key().map(|b| b.to_owned()),
230 value: message.payload().map(|b| b.to_owned()),
231 partition: Some(message.partition()),
232 });
233 }
234 Some(Err(e)) => {
235 println!("Received error from Kafka stream consumer: {}", e);
236 break;
237 }
238 None => {
239 break;
240 }
241 }
242 }
243
244 let key_schema = if let Format::Avro = format.key {
245 let schema = state
246 .ccsr_client
247 .get_schema_by_subject(&format!("{}-key", topic))
248 .await
249 .ok()
250 .map(|key_schema| {
251 avro::parse_schema(&key_schema.raw, &[]).context("parsing avro schema")
252 })
253 .transpose()?;
254 if schema.is_some() {
257 format.requires_key = true;
258 }
259 schema
260 } else {
261 None
262 };
263 let value_schema = if let Format::Avro = format.value {
264 let val_schema = state
265 .ccsr_client
266 .get_schema_by_subject(&format!("{}-value", topic))
267 .await
268 .context("fetching schema")?
269 .raw;
270 Some(avro::parse_schema(&val_schema, &[]).context("parsing avro schema")?)
271 } else {
272 None
273 };
274
275 let mut actual_messages = decode_messages(actual_bytes, &key_schema, &value_schema, &format)?;
276
277 if sort_messages {
278 actual_messages.sort_by_key(|r| format!("{:?}", r));
279 }
280
281 if debug_print_only {
282 bail!(
283 "records in sink:\n{}",
284 actual_messages
285 .into_iter()
286 .map(|a| format!("{:#?}", a))
287 .collect::<Vec<_>>()
288 .join("\n")
289 );
290 }
291
292 let expected = parse_expected_messages(
293 expected_messages,
294 key_schema,
295 value_schema,
296 &format,
297 &header_keys,
298 )?;
299
300 verify_with_partial_search(
301 &expected,
302 &actual_messages,
303 &state.regex,
304 &state.regex_replacement,
305 partial_search.is_some(),
306 )?;
307
308 Ok(ControlFlow::Continue)
309}
310
311fn split_headers(input: &str, n_headers: usize) -> anyhow::Result<(Vec<String>, &str)> {
313 let whitespace = Regex::new("\\s+").expect("building known-valid regex");
314 let mut parts = whitespace.splitn(input, n_headers + 1);
315 let mut headers = Vec::with_capacity(n_headers);
316 for _ in 0..n_headers {
317 headers.push(
318 parts
319 .next()
320 .context("expected another header in the input")?
321 .to_string(),
322 )
323 }
324 let rest = parts
325 .next()
326 .context("expected some contents after any message headers")?;
327
328 ensure!(
329 parts.next().is_none(),
330 "more than n+1 elements from a call to splitn(_, n+1)"
331 );
332
333 Ok((headers, rest))
334}
335
336fn decode_messages(
337 actual_bytes: Vec<Record<Vec<u8>>>,
338 key_schema: &Option<mz_avro::Schema>,
339 value_schema: &Option<mz_avro::Schema>,
340 format: &RecordFormat,
341) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
342 let mut actual_messages = vec![];
343
344 for record in actual_bytes {
345 let Record { key, value, .. } = record;
346 let key = if format.requires_key {
347 match (key, format.key) {
348 (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
349 avro::from_confluent_bytes(key_schema.as_ref().unwrap(), &bytes)?,
350 ))),
351 (Some(bytes), Format::Json) => Some(DecodedValue::Json(
352 serde_json::from_slice(&bytes).context("decoding json")?,
353 )),
354 (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
355 (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
356 (None, _) if format.requires_key => bail!("empty message key"),
357 (None, _) => None,
358 }
359 } else {
360 None
361 };
362
363 let value = match (value, format.value) {
364 (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
365 avro::from_confluent_bytes(value_schema.as_ref().unwrap(), &bytes)?,
366 ))),
367 (Some(bytes), Format::Json) => Some(DecodedValue::Json(
368 serde_json::from_slice(&bytes).context("decoding json")?,
369 )),
370 (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
371 (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
372 (None, _) => None,
373 };
374
375 actual_messages.push(Record {
376 headers: record.headers.clone(),
377 key,
378 value,
379 partition: record.partition,
380 });
381 }
382
383 Ok(actual_messages)
384}
385
386fn parse_expected_messages(
387 expected_messages: Vec<String>,
388 key_schema: Option<mz_avro::Schema>,
389 value_schema: Option<mz_avro::Schema>,
390 format: &RecordFormat,
391 header_keys: &[String],
392) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
393 let mut expected = vec![];
394
395 for msg in expected_messages {
396 let (headers, content) = split_headers(&msg, header_keys.len())?;
397 let mut content = content.as_bytes();
398 let mut deserializer = serde_json::Deserializer::from_reader(&mut content).into_iter();
399
400 let key = if format.requires_key {
401 let key: serde_json::Value = deserializer
402 .next()
403 .context("key missing in input line")?
404 .context("parsing json")?;
405
406 Some(match format.key {
407 Format::Avro => DecodedValue::Avro(DebugValue(avro::from_json(
408 &key,
409 key_schema.as_ref().unwrap().top_node(),
410 )?)),
411 Format::Json => DecodedValue::Json(key),
412 Format::Bytes => {
413 unimplemented!("bytes format not yet supported in tests")
414 }
415 Format::Text => DecodedValue::Text(
416 key.as_str()
417 .map(|s| s.to_string())
418 .unwrap_or_else(|| key.to_string()),
419 ),
420 })
421 } else {
422 None
423 };
424
425 let value = match deserializer.next().transpose().context("parsing json")? {
426 None => None,
427 Some(value) if value.as_str() == Some("<null>") => None,
428 Some(value) => match format.value {
429 Format::Avro => Some(DecodedValue::Avro(DebugValue(avro::from_json(
430 &value,
431 value_schema.as_ref().unwrap().top_node(),
432 )?))),
433 Format::Json => Some(DecodedValue::Json(value)),
434 Format::Bytes => {
435 unimplemented!("bytes format not yet supported in tests")
436 }
437 Format::Text => Some(DecodedValue::Text(value.to_string())),
438 },
439 };
440
441 let content =
442 str::from_utf8(content).context("internal error: contents were previously a string")?;
443 let partition = match content.trim().split_once("=") {
444 None if content.trim() != "" => bail!("unexpected cruft at end of line: {content}"),
445 None => None,
446 Some((label, partition)) => {
447 if label != "partition" {
448 bail!("partition expectation has unexpected label: {label}")
449 }
450 Some(partition.parse().context("parsing expected partition")?)
451 }
452 };
453
454 expected.push(Record {
455 headers,
456 key,
457 value,
458 partition,
459 });
460 }
461
462 Ok(expected)
463}
464
465fn verify_with_partial_search<A>(
466 expected: &[Record<A>],
467 actual: &[Record<A>],
468 regex: &Option<Regex>,
469 regex_replacement: &String,
470 partial_search: bool,
471) -> Result<(), anyhow::Error>
472where
473 A: Debug + Clone,
474{
475 let mut expected = expected.iter();
476 let mut actual = actual.iter();
477 let mut index = 0..;
478
479 let mut found_beginning = !partial_search;
480 let mut expected_item = expected.next();
481 let mut actual_item = actual.next();
482 loop {
483 let i = index.next().expect("known to exist");
484 match (expected_item, actual_item) {
485 (Some(e), Some(a)) => {
486 let mut a = a.clone();
487 if e.partition.is_none() {
488 a.partition = None;
489 }
490 let e_str = format!("{:#?}", e);
491 let a_str = match ®ex {
492 Some(regex) => regex
493 .replace_all(&format!("{:#?}", a).to_string(), regex_replacement.as_str())
494 .to_string(),
495 _ => format!("{:#?}", a),
496 };
497
498 if e_str != a_str {
499 if found_beginning {
500 bail!(
501 "record {} did not match\nexpected:\n{}\n\nactual:\n{}",
502 i,
503 e_str,
504 a_str,
505 );
506 }
507 actual_item = actual.next();
508 } else {
509 found_beginning = true;
510 expected_item = expected.next();
511 actual_item = actual.next();
512 }
513 }
514 (Some(e), None) => bail!("missing record {}: {:#?}", i, e),
515 (None, Some(a)) => {
516 if !partial_search {
517 bail!("extra record {}: {:#?}", i, a);
518 }
519 break;
520 }
521 (None, None) => break,
522 }
523 }
524 let expected: Vec<_> = expected.map(|e| format!("{:#?}", e)).collect();
525 let actual: Vec<_> = actual.map(|a| format!("{:#?}", a)).collect();
526
527 if !expected.is_empty() {
528 bail!("missing records:\n{}", expected.join("\n"))
529 } else if !actual.is_empty() && !partial_search {
530 bail!("extra records:\n{}", actual.join("\n"))
531 } else {
532 Ok(())
533 }
534}