Skip to main content

mz_testdrive/action/kafka/
verify_data.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::fmt::Debug;
11use std::time::Duration;
12use std::{cmp, str};
13
14use anyhow::{Context, bail, ensure};
15use mz_postgres_util::{Sql, query_one, sql};
16use rdkafka::consumer::{Consumer, StreamConsumer};
17use rdkafka::error::KafkaError;
18use rdkafka::message::{Headers, Message};
19use rdkafka::types::RDKafkaErrorCode;
20use regex::Regex;
21use tokio::pin;
22use tokio_stream::StreamExt;
23
24use crate::action::{ControlFlow, State};
25use crate::format::avro::{self, DebugValue};
26use crate::parser::BuiltinCommand;
27
/// Wire encoding of a Kafka message key or value.
///
/// Each variant corresponds to one of the format names accepted by the
/// `format`, `key-format` and `value-format` arguments.
#[derive(Clone, Copy, Debug)]
enum Format {
    /// Selected by the name `avro`.
    Avro,
    /// Selected by the name `json`.
    Json,
    /// Selected by the name `bytes`.
    Bytes,
    /// Selected by the name `text`.
    Text,
}
35
36impl TryFrom<&str> for Format {
37    type Error = anyhow::Error;
38
39    fn try_from(value: &str) -> Result<Self, Self::Error> {
40        match value {
41            "avro" => Ok(Format::Avro),
42            "json" => Ok(Format::Json),
43            "bytes" => Ok(Format::Bytes),
44            "text" => Ok(Format::Text),
45            f => bail!("unknown format: {}", f),
46        }
47    }
48}
49
50#[derive(Debug)]
51struct RecordFormat {
52    key: Format,
53    value: Format,
54    requires_key: bool,
55}
56
57#[allow(dead_code)]
58#[derive(Debug, Clone)]
59enum DecodedValue {
60    Avro(DebugValue),
61    Json(serde_json::Value),
62    Bytes(Vec<u8>),
63    Text(String),
64}
65
/// Where the name of the Kafka topic to verify comes from.
enum Topic {
    /// The topic is looked up from the sink with this name.
    FromSink(String),
    /// The topic name is given directly.
    Named(String),
}
70
/// One Kafka message, generic over the representation `A` of its key and value
/// (raw bytes straight off the topic, or a decoded value).
///
/// The field order is significant: records are compared and sorted through
/// their `Debug` output, which lists the fields in declaration order.
#[derive(Clone, Debug)]
struct Record<A> {
    /// Values of the requested headers, in the order the header keys were given.
    headers: Vec<String>,
    /// Message key, if any.
    key: Option<A>,
    /// Message value, if any.
    value: Option<A>,
    /// Partition of the message; `None` in an expected record means the
    /// partition is not checked.
    partition: Option<i32>,
}
78
79async fn get_topic(sink: &str, topic_field: &str, state: &State) -> Result<String, anyhow::Error> {
80    let query = sql!(
81        "SELECT {} FROM mz_sinks JOIN mz_kafka_sinks \
82         ON mz_sinks.id = mz_kafka_sinks.id \
83         JOIN mz_schemas s ON s.id = mz_sinks.schema_id \
84         LEFT JOIN mz_databases d ON d.id = s.database_id \
85         WHERE d.name = $1 \
86         AND s.name = $2 \
87         AND mz_sinks.name = $3",
88        Sql::ident(topic_field)
89    );
90    let sink_fields: Vec<&str> = sink.split('.').collect();
91    let result = query_one(
92        &state.materialize.pgclient,
93        query,
94        &[&sink_fields[0], &sink_fields[1], &sink_fields[2]],
95    )
96    .await
97    .context("retrieving topic name")?
98    .get(topic_field);
99    Ok(result)
100}
101
102pub async fn run_verify_data(
103    mut cmd: BuiltinCommand,
104    state: &State,
105) -> Result<ControlFlow, anyhow::Error> {
106    let mut format = if let Some(format_str) = cmd.args.opt_string("format") {
107        // If just a single format is provided, the user should specify `key=true` if they expect a
108        // in each message.  However for format=avro, we will conveniently set this based
109        // on the presence of the key schema in the registry, so this argument is not required.
110        let requires_key: bool = cmd.args.opt_bool("key")?.unwrap_or(false);
111        let format_type = format_str.as_str().try_into()?;
112        RecordFormat {
113            key: format_type,
114            value: format_type,
115            requires_key,
116        }
117    } else {
118        let key_format = cmd.args.string("key-format")?.as_str().try_into()?;
119        let value_format = cmd.args.string("value-format")?.as_str().try_into()?;
120        RecordFormat {
121            key: key_format,
122            value: value_format,
123            requires_key: true,
124        }
125    };
126
127    let source = match (cmd.args.opt_string("sink"), cmd.args.opt_string("topic")) {
128        (Some(sink), None) => Topic::FromSink(sink),
129        (None, Some(topic)) => Topic::Named(topic),
130        (Some(_), Some(_)) => bail!("Can't provide both `source` and `topic` to kafka-verify-data"),
131        (None, None) => bail!("kafka-verify-data expects either `source` or `topic`"),
132    };
133
134    let sort_messages = cmd.args.opt_bool("sort-messages")?.unwrap_or(false);
135
136    let header_keys: Vec<_> = cmd
137        .args
138        .opt_string("headers")
139        .map(|s| s.split(',').map(str::to_owned).collect())
140        .unwrap_or_default();
141
142    let expected_messages = cmd.input;
143    if expected_messages.len() == 0 {
144        // verify with 0 messages doesn't check that no messages have been written -
145        // it 'verifies' 0 messages and trivially returns true
146        bail!("kafka-verify-data requires a non-empty list of expected messages");
147    }
148    let partial_search = cmd.args.opt_parse("partial-search")?;
149    let debug_print_only = cmd.args.opt_bool("debug-print-only")?.unwrap_or(false);
150    cmd.args.done()?;
151
152    let topic: String = match &source {
153        Topic::FromSink(sink) => get_topic(sink, "topic", state).await?,
154        Topic::Named(name) => name.clone(),
155    };
156
157    println!("Verifying results in Kafka topic {}", topic);
158
159    let mut config = state.kafka_config.clone();
160    config.set("enable.auto.offset.store", "false");
161
162    let consumer: StreamConsumer = config.create().context("creating kafka consumer")?;
163    consumer
164        .subscribe(&[&topic])
165        .context("subscribing to kafka topic")?;
166
167    let (mut stream_messages_remaining, stream_timeout) = match partial_search {
168        Some(size) => (size, state.timeout),
169        None => (expected_messages.len(), Duration::from_secs(15)),
170    };
171
172    let timeout = cmp::max(state.timeout, stream_timeout);
173
174    let message_stream = consumer.stream().timeout(timeout);
175    pin!(message_stream);
176
177    // Collect all messages that arrive without timing out. If we trip
178    // the timeout, suppress the error and return what we have. This
179    // is nicer than returning "timeout expired", as the user will
180    // instead get an error message about the expected messages that
181    // were missing.
182    let mut actual_bytes = vec![];
183
184    let start = std::time::Instant::now();
185    let mut topic_created = false;
186
187    while stream_messages_remaining > 0 {
188        match message_stream.next().await {
189            Some(Ok(message)) => {
190                let message = match message {
191                    // We create topics after creating sinks, so we permit
192                    // retries here while waiting for the topic to get created.
193                    Err(KafkaError::MessageConsumption(
194                        RDKafkaErrorCode::UnknownTopicOrPartition,
195                    )) if start.elapsed() < timeout && !topic_created => {
196                        println!("waiting for Kafka topic creation...");
197                        continue;
198                    }
199                    e => e?,
200                };
201
202                stream_messages_remaining -= 1;
203                topic_created = true;
204
205                consumer
206                    .store_offset_from_message(&message)
207                    .context("storing message offset")?;
208
209                let mut headers = vec![];
210                for header_key in &header_keys {
211                    // Expect a unique header with the given key and a UTF8-formatted body.
212                    let hs = message.headers().context("expected headers for message")?;
213                    let mut hs = hs.iter().filter(|i| i.key == header_key);
214                    let h = hs.next();
215                    if hs.next().is_some() {
216                        bail!("expected at most one header with key {header_key}");
217                    }
218                    match h {
219                        None => headers.push("<missing>".into()),
220                        Some(h) => {
221                            let value = str::from_utf8(h.value.unwrap_or(b"<null>"))?;
222                            headers.push(value.into());
223                        }
224                    }
225                }
226
227                actual_bytes.push(Record {
228                    headers,
229                    key: message.key().map(|b| b.to_owned()),
230                    value: message.payload().map(|b| b.to_owned()),
231                    partition: Some(message.partition()),
232                });
233            }
234            Some(Err(e)) => {
235                println!("Received error from Kafka stream consumer: {}", e);
236                break;
237            }
238            None => {
239                break;
240            }
241        }
242    }
243
244    let key_schema = if let Format::Avro = format.key {
245        let schema = state
246            .ccsr_client
247            .get_schema_by_subject(&format!("{}-key", topic))
248            .await
249            .ok()
250            .map(|key_schema| {
251                avro::parse_schema(&key_schema.raw, &[]).context("parsing avro schema")
252            })
253            .transpose()?;
254        // for avro, we can determine if a key is required based on the presence of the key schema
255        // rather than requiring the user to specify the key=true flag
256        if schema.is_some() {
257            format.requires_key = true;
258        }
259        schema
260    } else {
261        None
262    };
263    let value_schema = if let Format::Avro = format.value {
264        let val_schema = state
265            .ccsr_client
266            .get_schema_by_subject(&format!("{}-value", topic))
267            .await
268            .context("fetching schema")?
269            .raw;
270        Some(avro::parse_schema(&val_schema, &[]).context("parsing avro schema")?)
271    } else {
272        None
273    };
274
275    let mut actual_messages = decode_messages(actual_bytes, &key_schema, &value_schema, &format)?;
276
277    if sort_messages {
278        actual_messages.sort_by_key(|r| format!("{:?}", r));
279    }
280
281    if debug_print_only {
282        bail!(
283            "records in sink:\n{}",
284            actual_messages
285                .into_iter()
286                .map(|a| format!("{:#?}", a))
287                .collect::<Vec<_>>()
288                .join("\n")
289        );
290    }
291
292    let expected = parse_expected_messages(
293        expected_messages,
294        key_schema,
295        value_schema,
296        &format,
297        &header_keys,
298    )?;
299
300    verify_with_partial_search(
301        &expected,
302        &actual_messages,
303        &state.regex,
304        &state.regex_replacement,
305        partial_search.is_some(),
306    )?;
307
308    Ok(ControlFlow::Continue)
309}
310
311/// Expect and split out `n` whitespace-delimited headers before the main contents of the 'expect' row.
312fn split_headers(input: &str, n_headers: usize) -> anyhow::Result<(Vec<String>, &str)> {
313    let whitespace = Regex::new("\\s+").expect("building known-valid regex");
314    let mut parts = whitespace.splitn(input, n_headers + 1);
315    let mut headers = Vec::with_capacity(n_headers);
316    for _ in 0..n_headers {
317        headers.push(
318            parts
319                .next()
320                .context("expected another header in the input")?
321                .to_string(),
322        )
323    }
324    let rest = parts
325        .next()
326        .context("expected some contents after any message headers")?;
327
328    ensure!(
329        parts.next().is_none(),
330        "more than n+1 elements from a call to splitn(_, n+1)"
331    );
332
333    Ok((headers, rest))
334}
335
336fn decode_messages(
337    actual_bytes: Vec<Record<Vec<u8>>>,
338    key_schema: &Option<mz_avro::Schema>,
339    value_schema: &Option<mz_avro::Schema>,
340    format: &RecordFormat,
341) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
342    let mut actual_messages = vec![];
343
344    for record in actual_bytes {
345        let Record { key, value, .. } = record;
346        let key = if format.requires_key {
347            match (key, format.key) {
348                (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
349                    avro::from_confluent_bytes(key_schema.as_ref().unwrap(), &bytes)?,
350                ))),
351                (Some(bytes), Format::Json) => Some(DecodedValue::Json(
352                    serde_json::from_slice(&bytes).context("decoding json")?,
353                )),
354                (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
355                (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
356                (None, _) if format.requires_key => bail!("empty message key"),
357                (None, _) => None,
358            }
359        } else {
360            None
361        };
362
363        let value = match (value, format.value) {
364            (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
365                avro::from_confluent_bytes(value_schema.as_ref().unwrap(), &bytes)?,
366            ))),
367            (Some(bytes), Format::Json) => Some(DecodedValue::Json(
368                serde_json::from_slice(&bytes).context("decoding json")?,
369            )),
370            (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
371            (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
372            (None, _) => None,
373        };
374
375        actual_messages.push(Record {
376            headers: record.headers.clone(),
377            key,
378            value,
379            partition: record.partition,
380        });
381    }
382
383    Ok(actual_messages)
384}
385
386fn parse_expected_messages(
387    expected_messages: Vec<String>,
388    key_schema: Option<mz_avro::Schema>,
389    value_schema: Option<mz_avro::Schema>,
390    format: &RecordFormat,
391    header_keys: &[String],
392) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
393    let mut expected = vec![];
394
395    for msg in expected_messages {
396        let (headers, content) = split_headers(&msg, header_keys.len())?;
397        let mut content = content.as_bytes();
398        let mut deserializer = serde_json::Deserializer::from_reader(&mut content).into_iter();
399
400        let key = if format.requires_key {
401            let key: serde_json::Value = deserializer
402                .next()
403                .context("key missing in input line")?
404                .context("parsing json")?;
405
406            Some(match format.key {
407                Format::Avro => DecodedValue::Avro(DebugValue(avro::from_json(
408                    &key,
409                    key_schema.as_ref().unwrap().top_node(),
410                )?)),
411                Format::Json => DecodedValue::Json(key),
412                Format::Bytes => {
413                    unimplemented!("bytes format not yet supported in tests")
414                }
415                Format::Text => DecodedValue::Text(
416                    key.as_str()
417                        .map(|s| s.to_string())
418                        .unwrap_or_else(|| key.to_string()),
419                ),
420            })
421        } else {
422            None
423        };
424
425        let value = match deserializer.next().transpose().context("parsing json")? {
426            None => None,
427            Some(value) if value.as_str() == Some("<null>") => None,
428            Some(value) => match format.value {
429                Format::Avro => Some(DecodedValue::Avro(DebugValue(avro::from_json(
430                    &value,
431                    value_schema.as_ref().unwrap().top_node(),
432                )?))),
433                Format::Json => Some(DecodedValue::Json(value)),
434                Format::Bytes => {
435                    unimplemented!("bytes format not yet supported in tests")
436                }
437                Format::Text => Some(DecodedValue::Text(value.to_string())),
438            },
439        };
440
441        let content =
442            str::from_utf8(content).context("internal error: contents were previously a string")?;
443        let partition = match content.trim().split_once("=") {
444            None if content.trim() != "" => bail!("unexpected cruft at end of line: {content}"),
445            None => None,
446            Some((label, partition)) => {
447                if label != "partition" {
448                    bail!("partition expectation has unexpected label: {label}")
449                }
450                Some(partition.parse().context("parsing expected partition")?)
451            }
452        };
453
454        expected.push(Record {
455            headers,
456            key,
457            value,
458            partition,
459        });
460    }
461
462    Ok(expected)
463}
464
465fn verify_with_partial_search<A>(
466    expected: &[Record<A>],
467    actual: &[Record<A>],
468    regex: &Option<Regex>,
469    regex_replacement: &String,
470    partial_search: bool,
471) -> Result<(), anyhow::Error>
472where
473    A: Debug + Clone,
474{
475    let mut expected = expected.iter();
476    let mut actual = actual.iter();
477    let mut index = 0..;
478
479    let mut found_beginning = !partial_search;
480    let mut expected_item = expected.next();
481    let mut actual_item = actual.next();
482    loop {
483        let i = index.next().expect("known to exist");
484        match (expected_item, actual_item) {
485            (Some(e), Some(a)) => {
486                let mut a = a.clone();
487                if e.partition.is_none() {
488                    a.partition = None;
489                }
490                let e_str = format!("{:#?}", e);
491                let a_str = match &regex {
492                    Some(regex) => regex
493                        .replace_all(&format!("{:#?}", a).to_string(), regex_replacement.as_str())
494                        .to_string(),
495                    _ => format!("{:#?}", a),
496                };
497
498                if e_str != a_str {
499                    if found_beginning {
500                        bail!(
501                            "record {} did not match\nexpected:\n{}\n\nactual:\n{}",
502                            i,
503                            e_str,
504                            a_str,
505                        );
506                    }
507                    actual_item = actual.next();
508                } else {
509                    found_beginning = true;
510                    expected_item = expected.next();
511                    actual_item = actual.next();
512                }
513            }
514            (Some(e), None) => bail!("missing record {}: {:#?}", i, e),
515            (None, Some(a)) => {
516                if !partial_search {
517                    bail!("extra record {}: {:#?}", i, a);
518                }
519                break;
520            }
521            (None, None) => break,
522        }
523    }
524    let expected: Vec<_> = expected.map(|e| format!("{:#?}", e)).collect();
525    let actual: Vec<_> = actual.map(|a| format!("{:#?}", a)).collect();
526
527    if !expected.is_empty() {
528        bail!("missing records:\n{}", expected.join("\n"))
529    } else if !actual.is_empty() && !partial_search {
530        bail!("extra records:\n{}", actual.join("\n"))
531    } else {
532        Ok(())
533    }
534}