mz_testdrive/action/kafka/
verify_data.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::fmt::Debug;
11use std::time::Duration;
12use std::{cmp, str};
13
14use anyhow::{Context, bail, ensure};
15use rdkafka::consumer::{Consumer, StreamConsumer};
16use rdkafka::error::KafkaError;
17use rdkafka::message::{Headers, Message};
18use rdkafka::types::RDKafkaErrorCode;
19use regex::Regex;
20use tokio::pin;
21use tokio_stream::StreamExt;
22
23use crate::action::{ControlFlow, State};
24use crate::format::avro::{self, DebugValue};
25use crate::parser::BuiltinCommand;
26
27#[derive(Debug, Clone, Copy)]
28enum Format {
29    Avro,
30    Json,
31    Bytes,
32    Text,
33}
34
35impl TryFrom<&str> for Format {
36    type Error = anyhow::Error;
37
38    fn try_from(value: &str) -> Result<Self, Self::Error> {
39        match value {
40            "avro" => Ok(Format::Avro),
41            "json" => Ok(Format::Json),
42            "bytes" => Ok(Format::Bytes),
43            "text" => Ok(Format::Text),
44            f => bail!("unknown format: {}", f),
45        }
46    }
47}
48
/// How the key and value of each record are encoded.
#[derive(Debug)]
struct RecordFormat {
    /// Encoding used for message keys.
    key: Format,
    /// Encoding used for message values.
    value: Format,
    /// Whether every message is expected to carry a key. When `false`, keys
    /// are ignored entirely on both the actual and the expected side.
    requires_key: bool,
}
55
/// A decoded key or value payload.
///
/// Records are compared through their pretty-printed `Debug` output, so the
/// payloads held here are never read field-by-field.
// NOTE(review): `dead_code` is presumably silenced because the payloads are
// only ever consumed via the derived `Debug` impl — confirm.
#[allow(dead_code)]
#[derive(Debug, Clone)]
enum DecodedValue {
    /// A decoded Avro datum.
    Avro(DebugValue),
    /// A parsed JSON document.
    Json(serde_json::Value),
    /// The raw payload bytes.
    Bytes(Vec<u8>),
    /// The payload interpreted as UTF-8 text.
    Text(String),
}
64
/// Where the name of the Kafka topic to verify comes from.
#[derive(Debug)]
enum Topic {
    /// The topic is looked up from the catalog for the sink with this name.
    FromSink(String),
    /// The topic name was given directly.
    Named(String),
}
69
/// A single Kafka record, generic over the representation `A` of its key and
/// value (raw bytes straight off the wire, or a decoded value).
#[derive(Debug, Clone)]
struct Record<A> {
    /// One entry per requested header key, in the order the keys were given.
    /// For records read from Kafka, an absent header is recorded as
    /// `<missing>` and a header without a value as `<null>`.
    headers: Vec<String>,
    /// The record key, if any.
    key: Option<A>,
    /// The record value, if any.
    value: Option<A>,
    /// The partition of the record. On the expected side, `None` means the
    /// partition is not checked.
    partition: Option<i32>,
}
77
78async fn get_topic(sink: &str, topic_field: &str, state: &State) -> Result<String, anyhow::Error> {
79    let query = format!(
80        "SELECT {} FROM mz_sinks JOIN mz_kafka_sinks \
81        ON mz_sinks.id = mz_kafka_sinks.id \
82        JOIN mz_schemas s ON s.id = mz_sinks.schema_id \
83        LEFT JOIN mz_databases d ON d.id = s.database_id \
84        WHERE d.name = $1 \
85        AND s.name = $2 \
86        AND mz_sinks.name = $3",
87        topic_field
88    );
89    let sink_fields: Vec<&str> = sink.split('.').collect();
90    let result = state
91        .materialize
92        .pgclient
93        .query_one(
94            query.as_str(),
95            &[&sink_fields[0], &sink_fields[1], &sink_fields[2]],
96        )
97        .await
98        .context("retrieving topic name")?
99        .get(topic_field);
100    Ok(result)
101}
102
103pub async fn run_verify_data(
104    mut cmd: BuiltinCommand,
105    state: &State,
106) -> Result<ControlFlow, anyhow::Error> {
107    let mut format = if let Some(format_str) = cmd.args.opt_string("format") {
108        // If just a single format is provided, the user should specify `key=true` if they expect a
109        // in each message.  However for format=avro, we will conveniently set this based
110        // on the presence of the key schema in the registry, so this argument is not required.
111        let requires_key: bool = cmd.args.opt_bool("key")?.unwrap_or(false);
112        let format_type = format_str.as_str().try_into()?;
113        RecordFormat {
114            key: format_type,
115            value: format_type,
116            requires_key,
117        }
118    } else {
119        let key_format = cmd.args.string("key-format")?.as_str().try_into()?;
120        let value_format = cmd.args.string("value-format")?.as_str().try_into()?;
121        RecordFormat {
122            key: key_format,
123            value: value_format,
124            requires_key: true,
125        }
126    };
127
128    let source = match (cmd.args.opt_string("sink"), cmd.args.opt_string("topic")) {
129        (Some(sink), None) => Topic::FromSink(sink),
130        (None, Some(topic)) => Topic::Named(topic),
131        (Some(_), Some(_)) => bail!("Can't provide both `source` and `topic` to kafka-verify-data"),
132        (None, None) => bail!("kafka-verify-data expects either `source` or `topic`"),
133    };
134
135    let sort_messages = cmd.args.opt_bool("sort-messages")?.unwrap_or(false);
136
137    let header_keys: Vec<_> = cmd
138        .args
139        .opt_string("headers")
140        .map(|s| s.split(',').map(str::to_owned).collect())
141        .unwrap_or_default();
142
143    let expected_messages = cmd.input;
144    if expected_messages.len() == 0 {
145        // verify with 0 messages doesn't check that no messages have been written -
146        // it 'verifies' 0 messages and trivially returns true
147        bail!("kafka-verify-data requires a non-empty list of expected messages");
148    }
149    let partial_search = cmd.args.opt_parse("partial-search")?;
150    let debug_print_only = cmd.args.opt_bool("debug-print-only")?.unwrap_or(false);
151    cmd.args.done()?;
152
153    let topic: String = match &source {
154        Topic::FromSink(sink) => get_topic(sink, "topic", state).await?,
155        Topic::Named(name) => name.clone(),
156    };
157
158    println!("Verifying results in Kafka topic {}", topic);
159
160    let mut config = state.kafka_config.clone();
161    config.set("enable.auto.offset.store", "false");
162
163    let consumer: StreamConsumer = config.create().context("creating kafka consumer")?;
164    consumer
165        .subscribe(&[&topic])
166        .context("subscribing to kafka topic")?;
167
168    let (mut stream_messages_remaining, stream_timeout) = match partial_search {
169        Some(size) => (size, state.default_timeout),
170        None => (expected_messages.len(), Duration::from_secs(15)),
171    };
172
173    let timeout = cmp::max(state.default_timeout, stream_timeout);
174
175    let message_stream = consumer.stream().timeout(timeout);
176    pin!(message_stream);
177
178    // Collect all messages that arrive without timing out. If we trip
179    // the timeout, suppress the error and return what we have. This
180    // is nicer than returning "timeout expired", as the user will
181    // instead get an error message about the expected messages that
182    // were missing.
183    let mut actual_bytes = vec![];
184
185    let start = std::time::Instant::now();
186    let mut topic_created = false;
187
188    while stream_messages_remaining > 0 {
189        match message_stream.next().await {
190            Some(Ok(message)) => {
191                let message = match message {
192                    // We create topics after creating sinks, so we permit
193                    // retries here while waiting for the topic to get created.
194                    Err(KafkaError::MessageConsumption(
195                        RDKafkaErrorCode::UnknownTopicOrPartition,
196                    )) if start.elapsed() < timeout && !topic_created => {
197                        println!("waiting for Kafka topic creation...");
198                        continue;
199                    }
200                    e => e?,
201                };
202
203                stream_messages_remaining -= 1;
204                topic_created = true;
205
206                consumer
207                    .store_offset_from_message(&message)
208                    .context("storing message offset")?;
209
210                let mut headers = vec![];
211                for header_key in &header_keys {
212                    // Expect a unique header with the given key and a UTF8-formatted body.
213                    let hs = message.headers().context("expected headers for message")?;
214                    let mut hs = hs.iter().filter(|i| i.key == header_key);
215                    let h = hs.next();
216                    if hs.next().is_some() {
217                        bail!("expected at most one header with key {header_key}");
218                    }
219                    match h {
220                        None => headers.push("<missing>".into()),
221                        Some(h) => {
222                            let value = str::from_utf8(h.value.unwrap_or(b"<null>"))?;
223                            headers.push(value.into());
224                        }
225                    }
226                }
227
228                actual_bytes.push(Record {
229                    headers,
230                    key: message.key().map(|b| b.to_owned()),
231                    value: message.payload().map(|b| b.to_owned()),
232                    partition: Some(message.partition()),
233                });
234            }
235            Some(Err(e)) => {
236                println!("Received error from Kafka stream consumer: {}", e);
237                break;
238            }
239            None => {
240                break;
241            }
242        }
243    }
244
245    let key_schema = if let Format::Avro = format.key {
246        let schema = state
247            .ccsr_client
248            .get_schema_by_subject(&format!("{}-key", topic))
249            .await
250            .ok()
251            .map(|key_schema| avro::parse_schema(&key_schema.raw).context("parsing avro schema"))
252            .transpose()?;
253        // for avro, we can determine if a key is required based on the presence of the key schema
254        // rather than requiring the user to specify the key=true flag
255        if schema.is_some() {
256            format.requires_key = true;
257        }
258        schema
259    } else {
260        None
261    };
262    let value_schema = if let Format::Avro = format.value {
263        let val_schema = state
264            .ccsr_client
265            .get_schema_by_subject(&format!("{}-value", topic))
266            .await
267            .context("fetching schema")?
268            .raw;
269        Some(avro::parse_schema(&val_schema).context("parsing avro schema")?)
270    } else {
271        None
272    };
273
274    let mut actual_messages = decode_messages(actual_bytes, &key_schema, &value_schema, &format)?;
275
276    if sort_messages {
277        actual_messages.sort_by_key(|r| format!("{:?}", r));
278    }
279
280    if debug_print_only {
281        bail!(
282            "records in sink:\n{}",
283            actual_messages
284                .into_iter()
285                .map(|a| format!("{:#?}", a))
286                .collect::<Vec<_>>()
287                .join("\n")
288        );
289    }
290
291    let expected = parse_expected_messages(
292        expected_messages,
293        key_schema,
294        value_schema,
295        &format,
296        &header_keys,
297    )?;
298
299    verify_with_partial_search(
300        &expected,
301        &actual_messages,
302        &state.regex,
303        &state.regex_replacement,
304        partial_search.is_some(),
305    )?;
306
307    Ok(ControlFlow::Continue)
308}
309
310/// Expect and split out `n` whitespace-delimited headers before the main contents of the 'expect' row.
311fn split_headers(input: &str, n_headers: usize) -> anyhow::Result<(Vec<String>, &str)> {
312    let whitespace = Regex::new("\\s+").expect("building known-valid regex");
313    let mut parts = whitespace.splitn(input, n_headers + 1);
314    let mut headers = Vec::with_capacity(n_headers);
315    for _ in 0..n_headers {
316        headers.push(
317            parts
318                .next()
319                .context("expected another header in the input")?
320                .to_string(),
321        )
322    }
323    let rest = parts
324        .next()
325        .context("expected some contents after any message headers")?;
326
327    ensure!(
328        parts.next().is_none(),
329        "more than n+1 elements from a call to splitn(_, n+1)"
330    );
331
332    Ok((headers, rest))
333}
334
335fn decode_messages(
336    actual_bytes: Vec<Record<Vec<u8>>>,
337    key_schema: &Option<mz_avro::Schema>,
338    value_schema: &Option<mz_avro::Schema>,
339    format: &RecordFormat,
340) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
341    let mut actual_messages = vec![];
342
343    for record in actual_bytes {
344        let Record { key, value, .. } = record;
345        let key = if format.requires_key {
346            match (key, format.key) {
347                (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
348                    avro::from_confluent_bytes(key_schema.as_ref().unwrap(), &bytes)?,
349                ))),
350                (Some(bytes), Format::Json) => Some(DecodedValue::Json(
351                    serde_json::from_slice(&bytes).context("decoding json")?,
352                )),
353                (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
354                (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
355                (None, _) if format.requires_key => bail!("empty message key"),
356                (None, _) => None,
357            }
358        } else {
359            None
360        };
361
362        let value = match (value, format.value) {
363            (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
364                avro::from_confluent_bytes(value_schema.as_ref().unwrap(), &bytes)?,
365            ))),
366            (Some(bytes), Format::Json) => Some(DecodedValue::Json(
367                serde_json::from_slice(&bytes).context("decoding json")?,
368            )),
369            (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
370            (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
371            (None, _) => None,
372        };
373
374        actual_messages.push(Record {
375            headers: record.headers.clone(),
376            key,
377            value,
378            partition: record.partition,
379        });
380    }
381
382    Ok(actual_messages)
383}
384
385fn parse_expected_messages(
386    expected_messages: Vec<String>,
387    key_schema: Option<mz_avro::Schema>,
388    value_schema: Option<mz_avro::Schema>,
389    format: &RecordFormat,
390    header_keys: &[String],
391) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
392    let mut expected = vec![];
393
394    for msg in expected_messages {
395        let (headers, content) = split_headers(&msg, header_keys.len())?;
396        let mut content = content.as_bytes();
397        let mut deserializer = serde_json::Deserializer::from_reader(&mut content).into_iter();
398
399        let key = if format.requires_key {
400            let key: serde_json::Value = deserializer
401                .next()
402                .context("key missing in input line")?
403                .context("parsing json")?;
404
405            Some(match format.key {
406                Format::Avro => DecodedValue::Avro(DebugValue(avro::from_json(
407                    &key,
408                    key_schema.as_ref().unwrap().top_node(),
409                )?)),
410                Format::Json => DecodedValue::Json(key),
411                Format::Bytes => {
412                    unimplemented!("bytes format not yet supported in tests")
413                }
414                Format::Text => DecodedValue::Text(
415                    key.as_str()
416                        .map(|s| s.to_string())
417                        .unwrap_or_else(|| key.to_string()),
418                ),
419            })
420        } else {
421            None
422        };
423
424        let value = match deserializer.next().transpose().context("parsing json")? {
425            None => None,
426            Some(value) if value.as_str() == Some("<null>") => None,
427            Some(value) => match format.value {
428                Format::Avro => Some(DecodedValue::Avro(DebugValue(avro::from_json(
429                    &value,
430                    value_schema.as_ref().unwrap().top_node(),
431                )?))),
432                Format::Json => Some(DecodedValue::Json(value)),
433                Format::Bytes => {
434                    unimplemented!("bytes format not yet supported in tests")
435                }
436                Format::Text => Some(DecodedValue::Text(value.to_string())),
437            },
438        };
439
440        let content =
441            str::from_utf8(content).context("internal error: contents were previously a string")?;
442        let partition = match content.trim().split_once("=") {
443            None if content.trim() != "" => bail!("unexpected cruft at end of line: {content}"),
444            None => None,
445            Some((label, partition)) => {
446                if label != "partition" {
447                    bail!("partition expectation has unexpected label: {label}")
448                }
449                Some(partition.parse().context("parsing expected partition")?)
450            }
451        };
452
453        expected.push(Record {
454            headers,
455            key,
456            value,
457            partition,
458        });
459    }
460
461    Ok(expected)
462}
463
464fn verify_with_partial_search<A>(
465    expected: &[Record<A>],
466    actual: &[Record<A>],
467    regex: &Option<Regex>,
468    regex_replacement: &String,
469    partial_search: bool,
470) -> Result<(), anyhow::Error>
471where
472    A: Debug + Clone,
473{
474    let mut expected = expected.iter();
475    let mut actual = actual.iter();
476    let mut index = 0..;
477
478    let mut found_beginning = !partial_search;
479    let mut expected_item = expected.next();
480    let mut actual_item = actual.next();
481    loop {
482        let i = index.next().expect("known to exist");
483        match (expected_item, actual_item) {
484            (Some(e), Some(a)) => {
485                let mut a = a.clone();
486                if e.partition.is_none() {
487                    a.partition = None;
488                }
489                let e_str = format!("{:#?}", e);
490                let a_str = match &regex {
491                    Some(regex) => regex
492                        .replace_all(&format!("{:#?}", a).to_string(), regex_replacement.as_str())
493                        .to_string(),
494                    _ => format!("{:#?}", a),
495                };
496
497                if e_str != a_str {
498                    if found_beginning {
499                        bail!(
500                            "record {} did not match\nexpected:\n{}\n\nactual:\n{}",
501                            i,
502                            e_str,
503                            a_str,
504                        );
505                    }
506                    actual_item = actual.next();
507                } else {
508                    found_beginning = true;
509                    expected_item = expected.next();
510                    actual_item = actual.next();
511                }
512            }
513            (Some(e), None) => bail!("missing record {}: {:#?}", i, e),
514            (None, Some(a)) => {
515                if !partial_search {
516                    bail!("extra record {}: {:#?}", i, a);
517                }
518                break;
519            }
520            (None, None) => break,
521        }
522    }
523    let expected: Vec<_> = expected.map(|e| format!("{:#?}", e)).collect();
524    let actual: Vec<_> = actual.map(|a| format!("{:#?}", a)).collect();
525
526    if !expected.is_empty() {
527        bail!("missing records:\n{}", expected.join("\n"))
528    } else if !actual.is_empty() && !partial_search {
529        bail!("extra records:\n{}", actual.join("\n"))
530    } else {
531        Ok(())
532    }
533}