mz_interchange/avro/
encode.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::fmt;
12use std::sync::LazyLock;
13
14use anyhow::Ok;
15use byteorder::{NetworkEndian, WriteBytesExt};
16use chrono::Timelike;
17use itertools::Itertools;
18use mz_avro::Schema;
19use mz_avro::types::{DecimalValue, ToAvro, Value};
20use mz_ore::cast::CastFrom;
21use mz_repr::adt::jsonb::JsonbRef;
22use mz_repr::adt::numeric::{self, NUMERIC_AGG_MAX_PRECISION, NUMERIC_DATUM_MAX_PRECISION};
23use mz_repr::{CatalogItemId, ColumnName, Datum, RelationDesc, Row, SqlColumnType, SqlScalarType};
24use serde_json::json;
25
26use crate::encode::{Encode, TypedDatum, column_names_and_types};
27use crate::envelopes::{self, DBZ_ROW_TYPE_ID, ENVELOPE_CUSTOM_NAMES};
28use crate::json::{SchemaOptions, build_row_schema_json};
29
30// TODO(rkhaitan): this schema intentionally omits the data_collections field
31// that is typically present in Debezium transaction metadata topics. See
32// https://debezium.io/documentation/reference/connectors/postgresql.html#postgresql-transaction-metadata
33// for more information. We chose to omit this field because it is redundant
34// for sinks where each consistency topic corresponds to exactly one sink.
35// We will need to add it in order to be able to reingest sinked topics.
36static DEBEZIUM_TRANSACTION_SCHEMA: LazyLock<Schema> = LazyLock::new(|| {
37    Schema::parse(&json!({
38        "type": "record",
39        "name": "envelope",
40        "fields": [
41            {
42                "name": "id",
43                "type": "string"
44            },
45            {
46                "name": "status",
47                "type": "string"
48            },
49            {
50                "name": "event_count",
51                "type": [
52                  "null",
53                  "long"
54                ]
55            },
56            {
57                "name": "data_collections",
58                "type": [
59                    "null",
60                    {
61                        "type": "array",
62                        "items": {
63                            "type": "record",
64                            "name": "data_collection",
65                            "fields": [
66                                {
67                                    "name": "data_collection",
68                                    "type": "string"
69                                },
70                                {
71                                    "name": "event_count",
72                                    "type": "long"
73                                },
74                            ]
75                        }
76                    }
77                ],
78                "default": null,
79            },
80        ]
81    }))
82    .expect("valid schema constructed")
83});
84
85fn encode_avro_header(buf: &mut Vec<u8>, schema_id: i32) {
86    // The first byte is a magic byte (0) that indicates the Confluent
87    // serialization format version, and the next four bytes are a
88    // 32-bit schema ID.
89    //
90    // https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format
91    buf.write_u8(0).expect("writing to vec cannot fail");
92    buf.write_i32::<NetworkEndian>(schema_id)
93        .expect("writing to vec cannot fail");
94}
95
96fn encode_message_unchecked(
97    schema_id: i32,
98    row: Row,
99    schema: &Schema,
100    columns: &[(ColumnName, SqlColumnType)],
101) -> Vec<u8> {
102    let mut buf = vec![];
103    encode_avro_header(&mut buf, schema_id);
104    let value = encode_datums_as_avro(row.iter(), columns);
105    mz_avro::encode_unchecked(&value, schema, &mut buf);
106    buf
107}
108
109#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
110pub enum DocTarget {
111    Type(CatalogItemId),
112    Field {
113        object_id: CatalogItemId,
114        column_name: ColumnName,
115    },
116}
117
118impl DocTarget {
119    fn id(&self) -> CatalogItemId {
120        match self {
121            DocTarget::Type(object_id) => *object_id,
122            DocTarget::Field { object_id, .. } => *object_id,
123        }
124    }
125}
126
127/// Generates an Avro schema
128pub struct AvroSchemaGenerator {
129    columns: Vec<(ColumnName, SqlColumnType)>,
130    schema: Schema,
131}
132
133impl fmt::Debug for AvroSchemaGenerator {
134    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
135        f.debug_struct("SchemaGenerator")
136            .field("writer_schema", &self.schema())
137            .finish()
138    }
139}
140
141impl AvroSchemaGenerator {
142    pub fn new(
143        desc: RelationDesc,
144        debezium: bool,
145        mut doc_options: BTreeMap<DocTarget, String>,
146        avro_fullname: &str,
147        set_null_defaults: bool,
148        sink_from: Option<CatalogItemId>,
149        use_custom_envelope_names: bool,
150    ) -> Result<Self, anyhow::Error> {
151        let mut columns = column_names_and_types(desc);
152        if debezium {
153            columns = envelopes::dbz_envelope(columns);
154            // With DEBEZIUM envelope the message is wrapped into "before" and "after"
155            // with `DBZ_ROW_TYPE_ID` instead of `sink_from`.
156            // Replacing comments for the columns and type in `sink_from` to `DBZ_ROW_TYPE_ID`.
157            if let Some(sink_from_id) = sink_from {
158                let mut new_column_docs = BTreeMap::new();
159                doc_options.iter().for_each(|(k, v)| {
160                    if k.id() == sink_from_id {
161                        match k {
162                            DocTarget::Field { column_name, .. } => {
163                                new_column_docs.insert(
164                                    DocTarget::Field {
165                                        object_id: DBZ_ROW_TYPE_ID,
166                                        column_name: column_name.clone(),
167                                    },
168                                    v.clone(),
169                                );
170                            }
171                            DocTarget::Type(_) => {
172                                new_column_docs.insert(DocTarget::Type(DBZ_ROW_TYPE_ID), v.clone());
173                            }
174                        }
175                    }
176                });
177                doc_options.append(&mut new_column_docs);
178                doc_options.retain(|k, _v| k.id() != sink_from_id);
179            }
180        }
181        let custom_names = if use_custom_envelope_names {
182            &ENVELOPE_CUSTOM_NAMES
183        } else {
184            &BTreeMap::new()
185        };
186        let row_schema = build_row_schema_json(
187            &columns,
188            avro_fullname,
189            custom_names,
190            sink_from,
191            &SchemaOptions {
192                set_null_defaults,
193                doc_comments: doc_options,
194            },
195        )?;
196        let schema = Schema::parse(&row_schema).expect("valid schema constructed");
197        Ok(AvroSchemaGenerator { columns, schema })
198    }
199
200    pub fn schema(&self) -> &Schema {
201        &self.schema
202    }
203
204    pub fn columns(&self) -> &[(ColumnName, SqlColumnType)] {
205        &self.columns
206    }
207}
208
209/// Manages encoding of Avro-encoded bytes.
210pub struct AvroEncoder {
211    columns: Vec<(ColumnName, SqlColumnType)>,
212    schema: Schema,
213    schema_id: i32,
214}
215
216impl fmt::Debug for AvroEncoder {
217    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
218        f.debug_struct("AvroEncoder")
219            .field("writer_schema", &self.schema)
220            .finish()
221    }
222}
223
224impl AvroEncoder {
225    pub fn new(desc: RelationDesc, debezium: bool, schema: &str, schema_id: i32) -> Self {
226        let mut columns = column_names_and_types(desc);
227        if debezium {
228            columns = envelopes::dbz_envelope(columns);
229        };
230        AvroEncoder {
231            columns,
232            schema: Schema::parse(&serde_json::from_str(schema).expect("valid schema json"))
233                .expect("valid schema"),
234            schema_id,
235        }
236    }
237}
238
239impl Encode for AvroEncoder {
240    fn encode_unchecked(&self, row: Row) -> Vec<u8> {
241        encode_message_unchecked(self.schema_id, row, &self.schema, &self.columns)
242    }
243
244    fn hash(&self, buf: &[u8]) -> u64 {
245        // Compute a stable hash by ignoring the avro header which might contain a
246        // non-deterministic schema id.
247        let (_schema_id, payload) = crate::confluent::extract_avro_header(buf).unwrap();
248        seahash::hash(payload)
249    }
250}
251
252/// Encodes a sequence of `Datum` as Avro (key and value), using supplied column names and types.
253pub fn encode_datums_as_avro<'a, I>(datums: I, names_types: &[(ColumnName, SqlColumnType)]) -> Value
254where
255    I: IntoIterator<Item = Datum<'a>>,
256{
257    let value_fields: Vec<(String, Value)> = names_types
258        .iter()
259        .zip_eq(datums)
260        .map(|((name, typ), datum)| {
261            let name = name.as_str().to_owned();
262            (name, TypedDatum::new(datum, typ).avro())
263        })
264        .collect();
265    let v = Value::Record(value_fields);
266    v
267}
268
269impl<'a> mz_avro::types::ToAvro for TypedDatum<'a> {
270    fn avro(self) -> Value {
271        let TypedDatum { datum, typ } = self;
272        if typ.nullable && datum.is_null() {
273            Value::Union {
274                index: 0,
275                inner: Box::new(Value::Null),
276                n_variants: 2,
277                null_variant: Some(0),
278            }
279        } else {
280            let mut val = match &typ.scalar_type {
281                SqlScalarType::AclItem => Value::String(datum.unwrap_acl_item().to_string()),
282                SqlScalarType::Bool => Value::Boolean(datum.unwrap_bool()),
283                SqlScalarType::PgLegacyChar => {
284                    Value::Fixed(1, datum.unwrap_uint8().to_le_bytes().into())
285                }
286                SqlScalarType::Int16 => Value::Int(i32::from(datum.unwrap_int16())),
287                SqlScalarType::Int32 => Value::Int(datum.unwrap_int32()),
288                SqlScalarType::Int64 => Value::Long(datum.unwrap_int64()),
289                SqlScalarType::UInt16 => {
290                    Value::Fixed(2, datum.unwrap_uint16().to_be_bytes().into())
291                }
292                SqlScalarType::UInt32 => {
293                    Value::Fixed(4, datum.unwrap_uint32().to_be_bytes().into())
294                }
295                SqlScalarType::UInt64 => {
296                    Value::Fixed(8, datum.unwrap_uint64().to_be_bytes().into())
297                }
298                SqlScalarType::Oid
299                | SqlScalarType::RegClass
300                | SqlScalarType::RegProc
301                | SqlScalarType::RegType => {
302                    Value::Fixed(4, datum.unwrap_uint32().to_be_bytes().into())
303                }
304                SqlScalarType::Float32 => Value::Float(datum.unwrap_float32()),
305                SqlScalarType::Float64 => Value::Double(datum.unwrap_float64()),
306                SqlScalarType::Numeric { max_scale } => {
307                    let mut d = datum.unwrap_numeric().0;
308                    let (unscaled, precision, scale) = match max_scale {
309                        Some(max_scale) => {
310                            // Values must be rescaled to resaturate trailing zeroes
311                            numeric::rescale(&mut d, max_scale.into_u8()).unwrap();
312                            (
313                                numeric::numeric_to_twos_complement_be(d).to_vec(),
314                                NUMERIC_DATUM_MAX_PRECISION,
315                                max_scale.into_u8(),
316                            )
317                        }
318                        // Decimals without specified scale must nonetheless be
319                        // expressed as a fixed scale, so we write everything as
320                        // a 78-digit number with a scale of 39, which
321                        // definitively expresses all valid numeric values.
322                        None => (
323                            numeric::numeric_to_twos_complement_wide(d).to_vec(),
324                            NUMERIC_AGG_MAX_PRECISION,
325                            NUMERIC_DATUM_MAX_PRECISION,
326                        ),
327                    };
328                    Value::Decimal(DecimalValue {
329                        unscaled,
330                        precision: usize::cast_from(precision),
331                        scale: usize::cast_from(scale),
332                    })
333                }
334                SqlScalarType::Date => Value::Date(datum.unwrap_date().unix_epoch_days()),
335                SqlScalarType::Time => Value::Long({
336                    let time = datum.unwrap_time();
337                    i64::from(time.num_seconds_from_midnight()) * 1_000_000
338                        + i64::from(time.nanosecond()) / 1_000
339                }),
340                SqlScalarType::Timestamp { .. } => {
341                    Value::Timestamp(datum.unwrap_timestamp().to_naive())
342                }
343                SqlScalarType::TimestampTz { .. } => {
344                    Value::Timestamp(datum.unwrap_timestamptz().to_naive())
345                }
346                // SQL intervals and Avro durations differ quite a lot (signed
347                // vs unsigned, different int sizes), so SQL intervals are their
348                // own bespoke type.
349                SqlScalarType::Interval => Value::Fixed(16, {
350                    let iv = datum.unwrap_interval();
351                    let mut buf = Vec::with_capacity(16);
352                    buf.extend(iv.months.to_le_bytes());
353                    buf.extend(iv.days.to_le_bytes());
354                    buf.extend(iv.micros.to_le_bytes());
355                    debug_assert_eq!(buf.len(), 16);
356                    buf
357                }),
358                SqlScalarType::Bytes => Value::Bytes(Vec::from(datum.unwrap_bytes())),
359                SqlScalarType::String
360                | SqlScalarType::VarChar { .. }
361                | SqlScalarType::PgLegacyName => Value::String(datum.unwrap_str().to_owned()),
362                SqlScalarType::Char { length } => {
363                    let s = mz_repr::adt::char::format_str_pad(datum.unwrap_str(), *length);
364                    Value::String(s)
365                }
366                SqlScalarType::Jsonb => Value::Json(JsonbRef::from_datum(datum).to_serde_json()),
367                SqlScalarType::Uuid => Value::Uuid(datum.unwrap_uuid()),
368                ty @ (SqlScalarType::Array(..)
369                | SqlScalarType::Int2Vector
370                | SqlScalarType::List { .. }) => {
371                    let list = match ty {
372                        SqlScalarType::Array(_) | SqlScalarType::Int2Vector => {
373                            datum.unwrap_array().elements()
374                        }
375                        SqlScalarType::List { .. } => datum.unwrap_list(),
376                        _ => unreachable!(),
377                    };
378
379                    let values = list
380                        .into_iter()
381                        .map(|datum| {
382                            TypedDatum::new(
383                                datum,
384                                &SqlColumnType {
385                                    nullable: true,
386                                    scalar_type: ty.unwrap_collection_element_type().clone(),
387                                },
388                            )
389                            .avro()
390                        })
391                        .collect();
392                    Value::Array(values)
393                }
394                SqlScalarType::Map { value_type, .. } => {
395                    let map = datum.unwrap_map();
396                    let elements = map
397                        .into_iter()
398                        .map(|(key, datum)| {
399                            let value = TypedDatum::new(
400                                datum,
401                                &SqlColumnType {
402                                    nullable: true,
403                                    scalar_type: (**value_type).clone(),
404                                },
405                            )
406                            .avro();
407                            (key.to_string(), value)
408                        })
409                        .collect();
410                    Value::Map(elements)
411                }
412                SqlScalarType::Record { fields, .. } => {
413                    let list = datum.unwrap_list();
414                    let fields = fields
415                        .iter()
416                        .zip(&list)
417                        .map(|((name, typ), datum)| {
418                            let name = name.to_string();
419                            let datum = TypedDatum::new(datum, typ);
420                            let value = datum.avro();
421                            (name, value)
422                        })
423                        .collect();
424                    Value::Record(fields)
425                }
426                SqlScalarType::MzTimestamp => {
427                    Value::String(datum.unwrap_mz_timestamp().to_string())
428                }
429                SqlScalarType::Range { .. } => Value::String(datum.unwrap_range().to_string()),
430                SqlScalarType::MzAclItem => Value::String(datum.unwrap_mz_acl_item().to_string()),
431            };
432            if typ.nullable {
433                val = Value::Union {
434                    index: 1,
435                    inner: Box::new(val),
436                    n_variants: 2,
437                    null_variant: Some(0),
438                };
439            }
440            val
441        }
442    }
443}
444
445pub fn get_debezium_transaction_schema() -> &'static Schema {
446    &DEBEZIUM_TRANSACTION_SCHEMA
447}
448
449pub fn encode_debezium_transaction_unchecked(
450    schema_id: i32,
451    collection: &str,
452    id: &str,
453    status: &str,
454    message_count: Option<i64>,
455) -> Vec<u8> {
456    let mut buf = Vec::new();
457    encode_avro_header(&mut buf, schema_id);
458
459    let transaction_id = Value::String(id.to_owned());
460    let status = Value::String(status.to_owned());
461    let event_count = match message_count {
462        None => Value::Union {
463            index: 0,
464            inner: Box::new(Value::Null),
465            n_variants: 2,
466            null_variant: Some(0),
467        },
468        Some(count) => Value::Union {
469            index: 1,
470            inner: Box::new(Value::Long(count)),
471            n_variants: 2,
472            null_variant: Some(0),
473        },
474    };
475
476    let data_collections = if let Some(message_count) = message_count {
477        let collection = Value::Record(vec![
478            ("data_collection".into(), Value::String(collection.into())),
479            ("event_count".into(), Value::Long(message_count)),
480        ]);
481        Value::Union {
482            index: 1,
483            inner: Box::new(Value::Array(vec![collection])),
484            n_variants: 2,
485            null_variant: Some(0),
486        }
487    } else {
488        Value::Union {
489            index: 0,
490            inner: Box::new(Value::Null),
491            n_variants: 2,
492            null_variant: Some(0),
493        }
494    };
495
496    let record_contents = vec![
497        ("id".into(), transaction_id),
498        ("status".into(), status),
499        ("event_count".into(), event_count),
500        ("data_collections".into(), data_collections),
501    ];
502    let avro = Value::Record(record_contents);
503    debug_assert!(avro.validate(DEBEZIUM_TRANSACTION_SCHEMA.top_node()));
504    mz_avro::encode_unchecked(&avro, &DEBEZIUM_TRANSACTION_SCHEMA, &mut buf);
505    buf
506}