Skip to main content

mz_testdrive/action/kafka/
verify_data.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::fmt::Debug;
11use std::time::Duration;
12use std::{cmp, str};
13
14use anyhow::{Context, bail, ensure};
15use rdkafka::consumer::{Consumer, StreamConsumer};
16use rdkafka::error::KafkaError;
17use rdkafka::message::{Headers, Message};
18use rdkafka::types::RDKafkaErrorCode;
19use regex::Regex;
20use tokio::pin;
21use tokio_stream::StreamExt;
22
23use crate::action::{ControlFlow, State};
24use crate::format::avro::{self, DebugValue};
25use crate::parser::BuiltinCommand;
26
/// Encoding of a Kafka message key or value, as selected by the `format`,
/// `key-format` and `value-format` arguments of `kafka-verify-data`.
#[derive(Debug, Clone, Copy)]
enum Format {
    /// Avro, decoded with `avro::from_confluent_bytes` against a schema
    /// fetched from the schema registry.
    Avro,
    /// JSON, decoded with `serde_json`.
    Json,
    /// Raw bytes, kept as-is.
    Bytes,
    /// UTF-8 text.
    Text,
}
34
35impl TryFrom<&str> for Format {
36    type Error = anyhow::Error;
37
38    fn try_from(value: &str) -> Result<Self, Self::Error> {
39        match value {
40            "avro" => Ok(Format::Avro),
41            "json" => Ok(Format::Json),
42            "bytes" => Ok(Format::Bytes),
43            "text" => Ok(Format::Text),
44            f => bail!("unknown format: {}", f),
45        }
46    }
47}
48
/// How the key and value of each record are encoded, and whether a key is
/// expected at all.
#[derive(Debug)]
struct RecordFormat {
    /// Encoding of the message key.
    key: Format,
    /// Encoding of the message value.
    value: Format,
    /// Whether every message is expected to carry a key. When `false`, keys
    /// are neither decoded from the topic nor parsed from the expected input.
    requires_key: bool,
}
55
/// A message key or value after decoding according to its [`Format`].
///
/// Records are compared through their `Debug` rendering, so the payloads
/// only need to be `Debug`.
// NOTE(review): `dead_code` is presumably allowed because the payloads are
// only ever observed through the derived `Debug` output — confirm.
#[allow(dead_code)]
#[derive(Debug, Clone)]
enum DecodedValue {
    /// A decoded Avro datum.
    Avro(DebugValue),
    /// A parsed JSON document.
    Json(serde_json::Value),
    /// The raw, undecoded bytes.
    Bytes(Vec<u8>),
    /// A UTF-8 string.
    Text(String),
}
64
/// Where the Kafka topic to verify comes from.
#[derive(Debug)]
enum Topic {
    /// The topic is looked up in the catalog from the name of a sink.
    FromSink(String),
    /// The topic name is given directly.
    Named(String),
}
69
/// A single Kafka record, either as read from the topic or as written in the
/// expected input.
///
/// `A` is the representation of the key and value: `Vec<u8>` for raw records
/// and [`DecodedValue`] once decoded.
#[derive(Debug, Clone)]
struct Record<A> {
    /// One entry per requested header key, in the order the keys were given.
    headers: Vec<String>,
    /// The message key, if any.
    key: Option<A>,
    /// The message value, if any.
    value: Option<A>,
    /// The partition of the record. For expected records this is `None`
    /// unless a `partition=N` suffix was given, in which case the partition
    /// is not compared.
    partition: Option<i32>,
}
77
78async fn get_topic(sink: &str, topic_field: &str, state: &State) -> Result<String, anyhow::Error> {
79    let query = format!(
80        "SELECT {} FROM mz_sinks JOIN mz_kafka_sinks \
81        ON mz_sinks.id = mz_kafka_sinks.id \
82        JOIN mz_schemas s ON s.id = mz_sinks.schema_id \
83        LEFT JOIN mz_databases d ON d.id = s.database_id \
84        WHERE d.name = $1 \
85        AND s.name = $2 \
86        AND mz_sinks.name = $3",
87        topic_field
88    );
89    let sink_fields: Vec<&str> = sink.split('.').collect();
90    let result = state
91        .materialize
92        .pgclient
93        .query_one(
94            query.as_str(),
95            &[&sink_fields[0], &sink_fields[1], &sink_fields[2]],
96        )
97        .await
98        .context("retrieving topic name")?
99        .get(topic_field);
100    Ok(result)
101}
102
103pub async fn run_verify_data(
104    mut cmd: BuiltinCommand,
105    state: &State,
106) -> Result<ControlFlow, anyhow::Error> {
107    let mut format = if let Some(format_str) = cmd.args.opt_string("format") {
108        // If just a single format is provided, the user should specify `key=true` if they expect a
109        // in each message.  However for format=avro, we will conveniently set this based
110        // on the presence of the key schema in the registry, so this argument is not required.
111        let requires_key: bool = cmd.args.opt_bool("key")?.unwrap_or(false);
112        let format_type = format_str.as_str().try_into()?;
113        RecordFormat {
114            key: format_type,
115            value: format_type,
116            requires_key,
117        }
118    } else {
119        let key_format = cmd.args.string("key-format")?.as_str().try_into()?;
120        let value_format = cmd.args.string("value-format")?.as_str().try_into()?;
121        RecordFormat {
122            key: key_format,
123            value: value_format,
124            requires_key: true,
125        }
126    };
127
128    let source = match (cmd.args.opt_string("sink"), cmd.args.opt_string("topic")) {
129        (Some(sink), None) => Topic::FromSink(sink),
130        (None, Some(topic)) => Topic::Named(topic),
131        (Some(_), Some(_)) => bail!("Can't provide both `source` and `topic` to kafka-verify-data"),
132        (None, None) => bail!("kafka-verify-data expects either `source` or `topic`"),
133    };
134
135    let sort_messages = cmd.args.opt_bool("sort-messages")?.unwrap_or(false);
136
137    let header_keys: Vec<_> = cmd
138        .args
139        .opt_string("headers")
140        .map(|s| s.split(',').map(str::to_owned).collect())
141        .unwrap_or_default();
142
143    let expected_messages = cmd.input;
144    if expected_messages.len() == 0 {
145        // verify with 0 messages doesn't check that no messages have been written -
146        // it 'verifies' 0 messages and trivially returns true
147        bail!("kafka-verify-data requires a non-empty list of expected messages");
148    }
149    let partial_search = cmd.args.opt_parse("partial-search")?;
150    let debug_print_only = cmd.args.opt_bool("debug-print-only")?.unwrap_or(false);
151    cmd.args.done()?;
152
153    let topic: String = match &source {
154        Topic::FromSink(sink) => get_topic(sink, "topic", state).await?,
155        Topic::Named(name) => name.clone(),
156    };
157
158    println!("Verifying results in Kafka topic {}", topic);
159
160    let mut config = state.kafka_config.clone();
161    config.set("enable.auto.offset.store", "false");
162
163    let consumer: StreamConsumer = config.create().context("creating kafka consumer")?;
164    consumer
165        .subscribe(&[&topic])
166        .context("subscribing to kafka topic")?;
167
168    let (mut stream_messages_remaining, stream_timeout) = match partial_search {
169        Some(size) => (size, state.timeout),
170        None => (expected_messages.len(), Duration::from_secs(15)),
171    };
172
173    let timeout = cmp::max(state.timeout, stream_timeout);
174
175    let message_stream = consumer.stream().timeout(timeout);
176    pin!(message_stream);
177
178    // Collect all messages that arrive without timing out. If we trip
179    // the timeout, suppress the error and return what we have. This
180    // is nicer than returning "timeout expired", as the user will
181    // instead get an error message about the expected messages that
182    // were missing.
183    let mut actual_bytes = vec![];
184
185    let start = std::time::Instant::now();
186    let mut topic_created = false;
187
188    while stream_messages_remaining > 0 {
189        match message_stream.next().await {
190            Some(Ok(message)) => {
191                let message = match message {
192                    // We create topics after creating sinks, so we permit
193                    // retries here while waiting for the topic to get created.
194                    Err(KafkaError::MessageConsumption(
195                        RDKafkaErrorCode::UnknownTopicOrPartition,
196                    )) if start.elapsed() < timeout && !topic_created => {
197                        println!("waiting for Kafka topic creation...");
198                        continue;
199                    }
200                    e => e?,
201                };
202
203                stream_messages_remaining -= 1;
204                topic_created = true;
205
206                consumer
207                    .store_offset_from_message(&message)
208                    .context("storing message offset")?;
209
210                let mut headers = vec![];
211                for header_key in &header_keys {
212                    // Expect a unique header with the given key and a UTF8-formatted body.
213                    let hs = message.headers().context("expected headers for message")?;
214                    let mut hs = hs.iter().filter(|i| i.key == header_key);
215                    let h = hs.next();
216                    if hs.next().is_some() {
217                        bail!("expected at most one header with key {header_key}");
218                    }
219                    match h {
220                        None => headers.push("<missing>".into()),
221                        Some(h) => {
222                            let value = str::from_utf8(h.value.unwrap_or(b"<null>"))?;
223                            headers.push(value.into());
224                        }
225                    }
226                }
227
228                actual_bytes.push(Record {
229                    headers,
230                    key: message.key().map(|b| b.to_owned()),
231                    value: message.payload().map(|b| b.to_owned()),
232                    partition: Some(message.partition()),
233                });
234            }
235            Some(Err(e)) => {
236                println!("Received error from Kafka stream consumer: {}", e);
237                break;
238            }
239            None => {
240                break;
241            }
242        }
243    }
244
245    let key_schema = if let Format::Avro = format.key {
246        let schema = state
247            .ccsr_client
248            .get_schema_by_subject(&format!("{}-key", topic))
249            .await
250            .ok()
251            .map(|key_schema| {
252                avro::parse_schema(&key_schema.raw, &[]).context("parsing avro schema")
253            })
254            .transpose()?;
255        // for avro, we can determine if a key is required based on the presence of the key schema
256        // rather than requiring the user to specify the key=true flag
257        if schema.is_some() {
258            format.requires_key = true;
259        }
260        schema
261    } else {
262        None
263    };
264    let value_schema = if let Format::Avro = format.value {
265        let val_schema = state
266            .ccsr_client
267            .get_schema_by_subject(&format!("{}-value", topic))
268            .await
269            .context("fetching schema")?
270            .raw;
271        Some(avro::parse_schema(&val_schema, &[]).context("parsing avro schema")?)
272    } else {
273        None
274    };
275
276    let mut actual_messages = decode_messages(actual_bytes, &key_schema, &value_schema, &format)?;
277
278    if sort_messages {
279        actual_messages.sort_by_key(|r| format!("{:?}", r));
280    }
281
282    if debug_print_only {
283        bail!(
284            "records in sink:\n{}",
285            actual_messages
286                .into_iter()
287                .map(|a| format!("{:#?}", a))
288                .collect::<Vec<_>>()
289                .join("\n")
290        );
291    }
292
293    let expected = parse_expected_messages(
294        expected_messages,
295        key_schema,
296        value_schema,
297        &format,
298        &header_keys,
299    )?;
300
301    verify_with_partial_search(
302        &expected,
303        &actual_messages,
304        &state.regex,
305        &state.regex_replacement,
306        partial_search.is_some(),
307    )?;
308
309    Ok(ControlFlow::Continue)
310}
311
312/// Expect and split out `n` whitespace-delimited headers before the main contents of the 'expect' row.
313fn split_headers(input: &str, n_headers: usize) -> anyhow::Result<(Vec<String>, &str)> {
314    let whitespace = Regex::new("\\s+").expect("building known-valid regex");
315    let mut parts = whitespace.splitn(input, n_headers + 1);
316    let mut headers = Vec::with_capacity(n_headers);
317    for _ in 0..n_headers {
318        headers.push(
319            parts
320                .next()
321                .context("expected another header in the input")?
322                .to_string(),
323        )
324    }
325    let rest = parts
326        .next()
327        .context("expected some contents after any message headers")?;
328
329    ensure!(
330        parts.next().is_none(),
331        "more than n+1 elements from a call to splitn(_, n+1)"
332    );
333
334    Ok((headers, rest))
335}
336
337fn decode_messages(
338    actual_bytes: Vec<Record<Vec<u8>>>,
339    key_schema: &Option<mz_avro::Schema>,
340    value_schema: &Option<mz_avro::Schema>,
341    format: &RecordFormat,
342) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
343    let mut actual_messages = vec![];
344
345    for record in actual_bytes {
346        let Record { key, value, .. } = record;
347        let key = if format.requires_key {
348            match (key, format.key) {
349                (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
350                    avro::from_confluent_bytes(key_schema.as_ref().unwrap(), &bytes)?,
351                ))),
352                (Some(bytes), Format::Json) => Some(DecodedValue::Json(
353                    serde_json::from_slice(&bytes).context("decoding json")?,
354                )),
355                (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
356                (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
357                (None, _) if format.requires_key => bail!("empty message key"),
358                (None, _) => None,
359            }
360        } else {
361            None
362        };
363
364        let value = match (value, format.value) {
365            (Some(bytes), Format::Avro) => Some(DecodedValue::Avro(DebugValue(
366                avro::from_confluent_bytes(value_schema.as_ref().unwrap(), &bytes)?,
367            ))),
368            (Some(bytes), Format::Json) => Some(DecodedValue::Json(
369                serde_json::from_slice(&bytes).context("decoding json")?,
370            )),
371            (Some(bytes), Format::Bytes) => Some(DecodedValue::Bytes(bytes)),
372            (Some(bytes), Format::Text) => Some(DecodedValue::Text(String::from_utf8(bytes)?)),
373            (None, _) => None,
374        };
375
376        actual_messages.push(Record {
377            headers: record.headers.clone(),
378            key,
379            value,
380            partition: record.partition,
381        });
382    }
383
384    Ok(actual_messages)
385}
386
387fn parse_expected_messages(
388    expected_messages: Vec<String>,
389    key_schema: Option<mz_avro::Schema>,
390    value_schema: Option<mz_avro::Schema>,
391    format: &RecordFormat,
392    header_keys: &[String],
393) -> Result<Vec<Record<DecodedValue>>, anyhow::Error> {
394    let mut expected = vec![];
395
396    for msg in expected_messages {
397        let (headers, content) = split_headers(&msg, header_keys.len())?;
398        let mut content = content.as_bytes();
399        let mut deserializer = serde_json::Deserializer::from_reader(&mut content).into_iter();
400
401        let key = if format.requires_key {
402            let key: serde_json::Value = deserializer
403                .next()
404                .context("key missing in input line")?
405                .context("parsing json")?;
406
407            Some(match format.key {
408                Format::Avro => DecodedValue::Avro(DebugValue(avro::from_json(
409                    &key,
410                    key_schema.as_ref().unwrap().top_node(),
411                )?)),
412                Format::Json => DecodedValue::Json(key),
413                Format::Bytes => {
414                    unimplemented!("bytes format not yet supported in tests")
415                }
416                Format::Text => DecodedValue::Text(
417                    key.as_str()
418                        .map(|s| s.to_string())
419                        .unwrap_or_else(|| key.to_string()),
420                ),
421            })
422        } else {
423            None
424        };
425
426        let value = match deserializer.next().transpose().context("parsing json")? {
427            None => None,
428            Some(value) if value.as_str() == Some("<null>") => None,
429            Some(value) => match format.value {
430                Format::Avro => Some(DecodedValue::Avro(DebugValue(avro::from_json(
431                    &value,
432                    value_schema.as_ref().unwrap().top_node(),
433                )?))),
434                Format::Json => Some(DecodedValue::Json(value)),
435                Format::Bytes => {
436                    unimplemented!("bytes format not yet supported in tests")
437                }
438                Format::Text => Some(DecodedValue::Text(value.to_string())),
439            },
440        };
441
442        let content =
443            str::from_utf8(content).context("internal error: contents were previously a string")?;
444        let partition = match content.trim().split_once("=") {
445            None if content.trim() != "" => bail!("unexpected cruft at end of line: {content}"),
446            None => None,
447            Some((label, partition)) => {
448                if label != "partition" {
449                    bail!("partition expectation has unexpected label: {label}")
450                }
451                Some(partition.parse().context("parsing expected partition")?)
452            }
453        };
454
455        expected.push(Record {
456            headers,
457            key,
458            value,
459            partition,
460        });
461    }
462
463    Ok(expected)
464}
465
466fn verify_with_partial_search<A>(
467    expected: &[Record<A>],
468    actual: &[Record<A>],
469    regex: &Option<Regex>,
470    regex_replacement: &String,
471    partial_search: bool,
472) -> Result<(), anyhow::Error>
473where
474    A: Debug + Clone,
475{
476    let mut expected = expected.iter();
477    let mut actual = actual.iter();
478    let mut index = 0..;
479
480    let mut found_beginning = !partial_search;
481    let mut expected_item = expected.next();
482    let mut actual_item = actual.next();
483    loop {
484        let i = index.next().expect("known to exist");
485        match (expected_item, actual_item) {
486            (Some(e), Some(a)) => {
487                let mut a = a.clone();
488                if e.partition.is_none() {
489                    a.partition = None;
490                }
491                let e_str = format!("{:#?}", e);
492                let a_str = match &regex {
493                    Some(regex) => regex
494                        .replace_all(&format!("{:#?}", a).to_string(), regex_replacement.as_str())
495                        .to_string(),
496                    _ => format!("{:#?}", a),
497                };
498
499                if e_str != a_str {
500                    if found_beginning {
501                        bail!(
502                            "record {} did not match\nexpected:\n{}\n\nactual:\n{}",
503                            i,
504                            e_str,
505                            a_str,
506                        );
507                    }
508                    actual_item = actual.next();
509                } else {
510                    found_beginning = true;
511                    expected_item = expected.next();
512                    actual_item = actual.next();
513                }
514            }
515            (Some(e), None) => bail!("missing record {}: {:#?}", i, e),
516            (None, Some(a)) => {
517                if !partial_search {
518                    bail!("extra record {}: {:#?}", i, a);
519                }
520                break;
521            }
522            (None, None) => break,
523        }
524    }
525    let expected: Vec<_> = expected.map(|e| format!("{:#?}", e)).collect();
526    let actual: Vec<_> = actual.map(|a| format!("{:#?}", a)).collect();
527
528    if !expected.is_empty() {
529        bail!("missing records:\n{}", expected.join("\n"))
530    } else if !actual.is_empty() && !partial_search {
531        bail!("extra records:\n{}", actual.join("\n"))
532    } else {
533        Ok(())
534    }
535}