use std::fmt::Debug;
use std::time::Duration;
use std::{cmp, str};

use anyhow::{Context, bail, ensure};
use rdkafka::consumer::{Consumer, StreamConsumer};
use rdkafka::error::KafkaError;
use rdkafka::message::{Headers, Message};
use rdkafka::types::RDKafkaErrorCode;
use regex::Regex;
use tokio::pin;
use tokio_stream::StreamExt;

use crate::action::{ControlFlow, State};
use crate::format::avro::{self, DebugValue};
use crate::parser::BuiltinCommand;

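/// The supported decodings for a Kafka record's key or value, as named by the
/// `format`/`key-format`/`value-format` arguments.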
#[derive(Debug, Clone, Copy)]
enum Format {
    Avro,
    Json,
    Bytes,
    Text,
}

impl TryFrom<&str> for Format {
    type Error = anyhow::Error;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        match value {
            "avro" => Ok(Format::Avro),
            "json" => Ok(Format::Json),
            "bytes" => Ok(Format::Bytes),
            "text" => Ok(Format::Text),
            f => bail!("unknown format: {}", f),
        }
    }
}

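/// How records are decoded: the key and value formats, plus whether every record
/// is expected to carry a key.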
#[derive(Debug)]
struct RecordFormat {
    key: Format,
    value: Format,
    requires_key: bool,
}

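/// A decoded key or value, wrapped so records can be compared and pretty-printed
/// uniformly regardless of format.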
#[allow(dead_code)]
#[derive(Debug, Clone)]
enum DecodedValue {
    Avro(DebugValue),
    Json(serde_json::Value),
    Bytes(Vec<u8>),
    Text(String),
}

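/// Where the data to verify lives: a topic named directly, or the topic backing a
/// Materialize sink.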
enum Topic {
    FromSink(String),
    Named(String),
}

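/// A single record (consumed or expected): its header values, optional key and
/// value, and optionally the partition it landed on or is expected to land on.
/// The type parameter distinguishes raw bytes from decoded values.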
#[derive(Debug, Clone)]
struct Record<A> {
    headers: Vec<String>,
    key: Option<A>,
    value: Option<A>,
    partition: Option<i32>,
}

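/// Resolves the Kafka topic for a sink named as `database.schema.sink` by querying
/// Materialize's `mz_sinks`/`mz_kafka_sinks` catalog tables.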
async fn get_topic(sink: &str, topic_field: &str, state: &State) -> Result<String, anyhow::Error> {
    let query = format!(
        "SELECT {} FROM mz_sinks JOIN mz_kafka_sinks \
         ON mz_sinks.id = mz_kafka_sinks.id \
         JOIN mz_schemas s ON s.id = mz_sinks.schema_id \
         LEFT JOIN mz_databases d ON d.id = s.database_id \
         WHERE d.name = $1 \
         AND s.name = $2 \
         AND mz_sinks.name = $3",
        topic_field
    );
    let sink_fields: Vec<&str> = sink.split('.').collect();
    let result = state
        .materialize
        .pgclient
        .query_one(
            query.as_str(),
            &[&sink_fields[0], &sink_fields[1], &sink_fields[2]],
        )
        .await
        .context("retrieving topic name")?
        .get(topic_field);
    Ok(result)
}

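/// Runs the `kafka-verify-data` testdrive action: consumes records from a Kafka
/// topic (named directly via `topic=...`, or resolved from a sink via `sink=...`)
/// and compares them against the expected messages given as the command's input.
///
/// A rough sketch of an invocation, for orientation only; the topic name and data
/// are made up, and the exact script syntax is defined by the testdrive parser,
/// not by this function:
///
/// ```text
/// $ kafka-verify-data format=json topic=testdrive-data sort-messages=true
/// {"a": 1}
/// {"a": 2}
/// ```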
pub async fn run_verify_data(
    mut cmd: BuiltinCommand,
    state: &State,
) -> Result<ControlFlow, anyhow::Error> {
    let mut format = if let Some(format_str) = cmd.args.opt_string("format") {
        let requires_key: bool = cmd.args.opt_bool("key")?.unwrap_or(false);
        let format_type = format_str.as_str().try_into()?;
        RecordFormat {
            key: format_type,
            value: format_type,
            requires_key,
        }
    } else {
        let key_format = cmd.args.string("key-format")?.as_str().try_into()?;
        let value_format = cmd.args.string("value-format")?.as_str().try_into()?;
        RecordFormat {
            key: key_format,
            value: value_format,
            requires_key: true,
        }
    };

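    // Exactly one of `sink` or `topic` must name the data to verify.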
    let source = match (cmd.args.opt_string("sink"), cmd.args.opt_string("topic")) {
        (Some(sink), None) => Topic::FromSink(sink),
        (None, Some(topic)) => Topic::Named(topic),
        (Some(_), Some(_)) => bail!("can't provide both `sink` and `topic` to kafka-verify-data"),
        (None, None) => bail!("kafka-verify-data expects either `sink` or `topic`"),
    };

    let sort_messages = cmd.args.opt_bool("sort-messages")?.unwrap_or(false);

    let header_keys: Vec<_> = cmd
        .args
        .opt_string("headers")
        .map(|s| s.split(',').map(str::to_owned).collect())
        .unwrap_or_default();

    let expected_messages = cmd.input;
    if expected_messages.is_empty() {
        bail!("kafka-verify-data requires a non-empty list of expected messages");
    }
    let partial_search = cmd.args.opt_parse("partial-search")?;
    let debug_print_only = cmd.args.opt_bool("debug-print-only")?.unwrap_or(false);
    cmd.args.done()?;

    let topic: String = match &source {
        Topic::FromSink(sink) => get_topic(sink, "topic", state).await?,
        Topic::Named(name) => name.clone(),
    };

    println!("Verifying results in Kafka topic {}", topic);

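    // Offsets are stored explicitly via `store_offset_from_message` below rather than
    // letting librdkafka auto-store them.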
    let mut config = state.kafka_config.clone();
    config.set("enable.auto.offset.store", "false");

    let consumer: StreamConsumer = config.create().context("creating kafka consumer")?;
    consumer
        .subscribe(&[&topic])
        .context("subscribing to kafka topic")?;

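    // With `partial-search=N`, read up to N messages; otherwise read as many messages
    // as there are expected input lines.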
    let (mut stream_messages_remaining, stream_timeout) = match partial_search {
        Some(size) => (size, state.default_timeout),
        None => (expected_messages.len(), Duration::from_secs(15)),
    };

    let timeout = cmp::max(state.default_timeout, stream_timeout);

    let message_stream = consumer.stream().timeout(timeout);
    pin!(message_stream);

    let mut actual_bytes = vec![];

    let start = std::time::Instant::now();
    let mut topic_created = false;

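    // Consume raw records until enough messages have been read, the stream ends or
    // times out, or the consumer reports an error.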
    while stream_messages_remaining > 0 {
        match message_stream.next().await {
            Some(Ok(message)) => {
                let message = match message {
                    Err(KafkaError::MessageConsumption(
                        RDKafkaErrorCode::UnknownTopicOrPartition,
                    )) if start.elapsed() < timeout && !topic_created => {
                        println!("waiting for Kafka topic creation...");
                        continue;
                    }
                    e => e?,
                };

                stream_messages_remaining -= 1;
                topic_created = true;

                consumer
                    .store_offset_from_message(&message)
                    .context("storing message offset")?;

                let mut headers = vec![];
                for header_key in &header_keys {
                    let hs = message.headers().context("expected headers for message")?;
                    let mut hs = hs.iter().filter(|i| i.key == header_key);
                    let h = hs.next();
                    if hs.next().is_some() {
                        bail!("expected at most one header with key {header_key}");
                    }
                    match h {
                        None => headers.push("<missing>".into()),
                        Some(h) => {
                            let value = str::from_utf8(h.value.unwrap_or(b"<null>"))?;
                            headers.push(value.into());
                        }
                    }
                }

                actual_bytes.push(Record {
                    headers,
                    key: message.key().map(|b| b.to_owned()),
                    value: message.payload().map(|b| b.to_owned()),
                    partition: Some(message.partition()),
                });
            }
            Some(Err(e)) => {
                println!("Received error from Kafka stream consumer: {}", e);
                break;
            }
            None => {
                break;
            }
        }
    }

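    // For Avro, fetch schemas from the schema registry. A key schema is optional, but
    // if one exists the records are expected to carry keys; a value schema is required
    // whenever the value format is Avro.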
    let key_schema = if let Format::Avro = format.key {
        let schema = state
            .ccsr_client
            .get_schema_by_subject(&format!("{}-key", topic))
            .await
            .ok()
            .map(|key_schema| avro::parse_schema(&key_schema.raw).context("parsing avro schema"))
            .transpose()?;
        if schema.is_some() {
            format.requires_key = true;
        }
        schema
    } else {
        None
    };
    let value_schema = if let Format::Avro = format.value {
        let val_schema = state
            .ccsr_client
            .get_schema_by_subject(&format!("{}-value", topic))
            .await
            .context("fetching schema")?
            .raw;
        Some(avro::parse_schema(&val_schema).context("parsing avro schema")?)
    } else {
        None
    };

    let mut actual_messages = decode_messages(actual_bytes, &key_schema, &value_schema, &format)?;

    if sort_messages {
        actual_messages.sort_by_key(|r| format!("{:?}", r));
    }

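    // With `debug-print-only=true`, dump the decoded records and fail instead of
    // verifying them against the expected messages.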
    if debug_print_only {
        bail!(
            "records in sink:\n{}",
            actual_messages
                .into_iter()
                .map(|a| format!("{:#?}", a))
                .collect::<Vec<_>>()
                .join("\n")
        );
    }

    let expected = parse_expected_messages(
        expected_messages,
        key_schema,
        value_schema,
        &format,
        &header_keys,
    )?;

    verify_with_partial_search(
        &expected,
        &actual_messages,
        &state.regex,
        &state.regex_replacement,
        partial_search.is_some(),
    )?;

    Ok(ControlFlow::Continue)
}

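/// Splits the first `n_headers` whitespace-separated tokens off an expected-message
/// line, returning those header values and the remaining content.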
fn split_headers(input: &str, n_headers: usize) -> anyhow::Result<(Vec<String>, &str)> {
    let whitespace = Regex::new("\\s+").expect("building known-valid regex");
    let mut parts = whitespace.splitn(input, n_headers + 1);
    let mut headers = Vec::with_capacity(n_headers);
    for _ in 0..n_headers {
        headers.push(
            parts
                .next()
                .context("expected another header in the input")?
                .to_string(),
        )
    }
    let rest = parts
        .next()
        .context("expected some contents after any message headers")?;

    ensure!(
        parts.next().is_none(),
        "more than n+1 elements from a call to splitn(_, n+1)"
    );

    Ok((headers, rest))
}

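/// Decodes the raw key and value bytes of each consumed record into `DecodedValue`s
/// according to `format`, using the Avro schemas where the format requires them.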
fn decode_messages(
    actual_bytes: Vec<Record<Vec<u8>>>,
    key_schema: &Option<mz_avro::Schema>,
    value_schema: &Option<mz_avro::Schema>,
    format: &RecordFormat,
) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
    let mut actual_messages = vec![];

    for record in actual_bytes {
        let Record { key, value, .. } = record;
        let key = if format.requires_key {
            match (key, format.key) {
                (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
                    avro::from_confluent_bytes(key_schema.as_ref().unwrap(), &bytes)?,
                ))),
                (Some(bytes), Format::Json) => Some(DecodedValue::Json(
                    serde_json::from_slice(&bytes).context("decoding json")?,
                )),
                (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
                (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
                (None, _) if format.requires_key => bail!("empty message key"),
                (None, _) => None,
            }
        } else {
            None
        };

        let value = match (value, format.value) {
            (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
                avro::from_confluent_bytes(value_schema.as_ref().unwrap(), &bytes)?,
            ))),
            (Some(bytes), Format::Json) => Some(DecodedValue::Json(
                serde_json::from_slice(&bytes).context("decoding json")?,
            )),
            (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
            (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
            (None, _) => None,
        };

        actual_messages.push(Record {
            headers: record.headers.clone(),
            key,
            value,
            partition: record.partition,
        });
    }

    Ok(actual_messages)
}

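/// Parses the expected-message lines from the test script into `Record`s: any header
/// values first, then a JSON key (when a key is required), then a JSON value, then an
/// optional trailing `partition=N` expectation. An illustrative line (names and values
/// made up): `h1 {"id": 1} {"count": 2} partition=0`.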
fn parse_expected_messages(
    expected_messages: Vec<String>,
    key_schema: Option<mz_avro::Schema>,
    value_schema: Option<mz_avro::Schema>,
    format: &RecordFormat,
    header_keys: &[String],
) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
    let mut expected = vec![];

    for msg in expected_messages {
        let (headers, content) = split_headers(&msg, header_keys.len())?;
        let mut content = content.as_bytes();
        let mut deserializer = serde_json::Deserializer::from_reader(&mut content).into_iter();

        let key = if format.requires_key {
            let key: serde_json::Value = deserializer
                .next()
                .context("key missing in input line")?
                .context("parsing json")?;

            Some(match format.key {
                Format::Avro => DecodedValue::Avro(DebugValue(avro::from_json(
                    &key,
                    key_schema.as_ref().unwrap().top_node(),
                )?)),
                Format::Json => DecodedValue::Json(key),
                Format::Bytes => {
                    unimplemented!("bytes format not yet supported in tests")
                }
                Format::Text => DecodedValue::Text(
                    key.as_str()
                        .map(|s| s.to_string())
                        .unwrap_or_else(|| key.to_string()),
                ),
            })
        } else {
            None
        };

        let value = match deserializer.next().transpose().context("parsing json")? {
            None => None,
            Some(value) if value.as_str() == Some("<null>") => None,
            Some(value) => match format.value {
                Format::Avro => Some(DecodedValue::Avro(DebugValue(avro::from_json(
                    &value,
                    value_schema.as_ref().unwrap().top_node(),
                )?))),
                Format::Json => Some(DecodedValue::Json(value)),
                Format::Bytes => {
                    unimplemented!("bytes format not yet supported in tests")
                }
                Format::Text => Some(DecodedValue::Text(value.to_string())),
            },
        };

        let content =
            str::from_utf8(content).context("internal error: contents were previously a string")?;
        let partition = match content.trim().split_once("=") {
            None if content.trim() != "" => bail!("unexpected cruft at end of line: {content}"),
            None => None,
            Some((label, partition)) => {
                if label != "partition" {
                    bail!("partition expectation has unexpected label: {label}")
                }
                Some(partition.parse().context("parsing expected partition")?)
            }
        };

        expected.push(Record {
            headers,
            key,
            value,
            partition,
        });
    }

    Ok(expected)
}

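/// Compares expected records against actual records. With `partial_search`, the
/// expected sequence only needs to match a contiguous run somewhere within the actual
/// records, and trailing extras are ignored; otherwise the two sequences must match
/// exactly, in order. An expected record without a partition matches any partition,
/// and the optional regex/replacement is applied to the debug rendering of each
/// actual record before comparison.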
fn verify_with_partial_search<A>(
    expected: &[Record<A>],
    actual: &[Record<A>],
    regex: &Option<Regex>,
    regex_replacement: &String,
    partial_search: bool,
) -> Result<(), anyhow::Error>
where
    A: Debug + Clone,
{
    let mut expected = expected.iter();
    let mut actual = actual.iter();
    let mut index = 0..;

    let mut found_beginning = !partial_search;
    let mut expected_item = expected.next();
    let mut actual_item = actual.next();
    loop {
        let i = index.next().expect("known to exist");
        match (expected_item, actual_item) {
            (Some(e), Some(a)) => {
                let mut a = a.clone();
                if e.partition.is_none() {
                    a.partition = None;
                }
                let e_str = format!("{:#?}", e);
                let a_str = match &regex {
                    Some(regex) => regex
                        .replace_all(&format!("{:#?}", a).to_string(), regex_replacement.as_str())
                        .to_string(),
                    _ => format!("{:#?}", a),
                };

                if e_str != a_str {
                    if found_beginning {
                        bail!(
                            "record {} did not match\nexpected:\n{}\n\nactual:\n{}",
                            i,
                            e_str,
                            a_str,
                        );
                    }
                    actual_item = actual.next();
                } else {
                    found_beginning = true;
                    expected_item = expected.next();
                    actual_item = actual.next();
                }
            }
            (Some(e), None) => bail!("missing record {}: {:#?}", i, e),
            (None, Some(a)) => {
                if !partial_search {
                    bail!("extra record {}: {:#?}", i, a);
                }
                break;
            }
            (None, None) => break,
        }
    }
    let expected: Vec<_> = expected.map(|e| format!("{:#?}", e)).collect();
    let actual: Vec<_> = actual.map(|a| format!("{:#?}", a)).collect();

    if !expected.is_empty() {
        bail!("missing records:\n{}", expected.join("\n"))
    } else if !actual.is_empty() && !partial_search {
        bail!("extra records:\n{}", actual.join("\n"))
    } else {
        Ok(())
    }
}