Skip to main content

mz_interchange/avro/
decode.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11
12use anyhow::{Context, Error};
13use mz_avro::error::{DecodeError, Error as AvroError};
14use mz_avro::{
15    AvroArrayAccess, AvroDecode, AvroDeserializer, AvroMapAccess, AvroRead, AvroRecordAccess,
16    GeneralDeserializer, ValueDecoder, ValueOrReader, give_value,
17};
18use mz_ore::error::ErrorExt;
19use mz_repr::adt::date::Date;
20use mz_repr::adt::jsonb::JsonbPacker;
21use mz_repr::adt::numeric;
22use mz_repr::adt::timestamp::CheckedTimestamp;
23use mz_repr::{Datum, Row, RowPacker};
24use ordered_float::OrderedFloat;
25use serde::{Deserialize, Serialize};
26use tracing::trace;
27use uuid::Uuid;
28
29use crate::avro::AvroSchemaResolver;
30
/// Manages decoding of Avro-encoded bytes.
///
/// Holds the schema resolver together with scratch state that is reused across calls to
/// [`Decoder::decode`].
#[derive(Debug)]
pub struct Decoder {
    /// Consulted on every `decode` call to obtain the schema to decode a record with, the
    /// bytes to decode, and an optional schema id used in error messages.
    csr_avro: AvroSchemaResolver,
    /// Human-readable name of this decoder, included in trace logging.
    debug_name: String,
    /// Scratch buffer lent to [`AvroFlatDecoder`] while decoding a record.
    buf1: Vec<u8>,
    /// Row reused for every decoded record; it is cleared at the start of each `decode` call
    /// and cloned to produce the returned `Row`.
    row_buf: Row,
}
39
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use mz_repr::{Datum, Row};

    use crate::avro::{Decoder, WriterSchemaProvider};

    /// Reader schema for a record with two `int` fields.
    const TWO_INTS: &str = r#"{
"type": "record",
"name": "test",
"fields": [{"name": "f1", "type": "int"}, {"name": "f2", "type": "int"}]
}"#;

    /// A failed decode must not leave state behind that corrupts the next, valid decode.
    #[mz_ore::test(tokio::test)]
    async fn test_error_followed_by_success() {
        let mut decoder =
            Decoder::new(TWO_INTS, &[], WriterSchemaProvider::None, "Test".to_string()).unwrap();

        // Not a valid Avro blob for the two-int schema, so the inner result must be an error.
        let mut invalid: &[u8] = &[0];
        let first = decoder.decode(&mut invalid).await.unwrap();
        assert_err!(first);

        // Both ints encoded as zero; decoding must now succeed with exactly that row.
        let mut both_zero: &[u8] = &[0, 0];
        let second = decoder.decode(&mut both_zero).await.unwrap().unwrap();
        let expected = Row::pack([Datum::Int32(0), Datum::Int32(0)]);
        assert_eq!(second, expected);
    }
}
68
69impl Decoder {
70    /// Creates a new `Decoder`
71    ///
72    /// The provided schema is called the "reader schema", which is the schema
73    /// that we are expecting to use to decode records. The records may indicate
74    /// that they are encoded with a different schema; as long as those.
75    ///
76    /// The `reader_reference_schemas` parameter provides schemas for types that
77    /// are referenced by the reader schema but defined in separate schemas.
78    /// These should be provided in dependency order (dependencies first).
79    pub fn new(
80        reader_schema: &str,
81        reader_reference_schemas: &[String],
82        writer_schemas: crate::avro::WriterSchemaProvider,
83        debug_name: String,
84    ) -> anyhow::Result<Decoder> {
85        let csr_avro =
86            AvroSchemaResolver::new(reader_schema, reader_reference_schemas, writer_schemas)?;
87
88        Ok(Decoder {
89            csr_avro,
90            debug_name,
91            buf1: vec![],
92            row_buf: Row::default(),
93        })
94    }
95
96    /// Decodes Avro-encoded `bytes` into a `Row`.
97    pub async fn decode(&mut self, bytes: &mut &[u8]) -> Result<Result<Row, Error>, Error> {
98        // Clear out any bytes that might be left over from
99        // an earlier run. This can happen if the
100        // `dsr.deserialize` call returns an error,
101        // causing us to return early.
102        let mut packer = self.row_buf.packer();
103        // The outer Result describes transient errors so use ? here to propagate
104        let (bytes2, resolved_schema, csr_schema_id) = match self.csr_avro.resolve(bytes).await? {
105            Ok(ok) => ok,
106            Err(err) => return Ok(Err(err)),
107        };
108        *bytes = bytes2;
109        let dec = AvroFlatDecoder {
110            packer: &mut packer,
111            buf: &mut self.buf1,
112            is_top: true,
113        };
114        let dsr = GeneralDeserializer {
115            schema: resolved_schema.top_node(),
116        };
117        let result = dsr
118            .deserialize(bytes, dec)
119            .with_context(|| {
120                format!(
121                    "unable to decode row {}",
122                    match &csr_schema_id {
123                        Some(id) => format!("(Avro schema id = {:?})", id),
124                        None => "".to_string(),
125                    }
126                )
127            })
128            .map(|_| self.row_buf.clone());
129        if result.is_ok() {
130            trace!(
131                "[customer-data] Decoded row {:?} in {}",
132                self.row_buf, self.debug_name
133            );
134        }
135        Ok(result)
136    }
137}
138
/// An [`AvroDecode`] implementation that packs decoded Avro values directly into a
/// [`RowPacker`].
#[derive(Debug)]
pub struct AvroFlatDecoder<'a, 'row> {
    /// Destination for the decoded datums.
    pub packer: &'a mut RowPacker<'row>,
    /// Reusable scratch buffer. Values that must be read out of the underlying reader
    /// (bytes, strings, JSON, UUIDs, decimals) are copied here before being packed.
    pub buf: &'a mut Vec<u8>,
    /// Whether this decoder sits at the top of the schema. A top-level record's fields are
    /// packed directly into the row; a nested record is packed as a list datum.
    pub is_top: bool,
}
145
146impl<'a, 'row> AvroDecode for AvroFlatDecoder<'a, 'row> {
147    type Out = ();
148    #[inline]
149    fn record<R: AvroRead, A: AvroRecordAccess<R>>(
150        self,
151        a: &mut A,
152    ) -> Result<Self::Out, AvroError> {
153        let mut str_buf = std::mem::take(self.buf);
154        let mut pack_record = |rp: &mut RowPacker| -> Result<(), AvroError> {
155            let mut expected = 0;
156            let mut stash = vec![];
157            // The idea here is that if the deserializer gives us fields in the order we're expecting,
158            // we can decode them directly into the row.
159            // If not, we need to decode them into a Value (the old, slow decoding path) and stash them,
160            // so that we can put everything in the right order at the end.
161            //
162            // TODO(btv) - this is pretty bad, as a misordering at the top of the schema graph will
163            // cause the _entire_ chunk under it to be decoded in the slow way!
164            // Maybe instead, we should decode to separate sub-Rows and then add an API
165            // to Row that just copies in the bytes from another one.
166            while let Some((_name, idx, f)) = a.next_field()? {
167                if idx == expected {
168                    expected += 1;
169                    f.decode_field(AvroFlatDecoder {
170                        packer: rp,
171                        buf: &mut str_buf,
172                        is_top: false,
173                    })?;
174                } else {
175                    let val = f.decode_field(ValueDecoder)?;
176                    stash.push((idx, val));
177                }
178            }
179            stash.sort_by_key(|(idx, _val)| *idx);
180            for (idx, val) in stash {
181                assert!(idx == expected);
182                expected += 1;
183                let dec = AvroFlatDecoder {
184                    packer: rp,
185                    buf: &mut str_buf,
186                    is_top: false,
187                };
188                give_value(dec, &val)?;
189            }
190            Ok(())
191        };
192        if self.is_top {
193            pack_record(self.packer)?;
194        } else {
195            self.packer.push_list_with(pack_record)?;
196        }
197        *self.buf = str_buf;
198        Ok(())
199    }
200    #[inline]
201    fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
202        self,
203        idx: usize,
204        n_variants: usize,
205        null_variant: Option<usize>,
206        deserializer: D,
207        reader: &'b mut R,
208    ) -> Result<Self::Out, AvroError> {
209        if null_variant == Some(idx) {
210            for _ in 0..n_variants - 1 {
211                self.packer.push(Datum::Null)
212            }
213        } else {
214            let mut deserializer = Some(deserializer);
215            for i in 0..n_variants {
216                let dec = AvroFlatDecoder {
217                    packer: self.packer,
218                    buf: self.buf,
219                    is_top: false,
220                };
221                if null_variant != Some(i) {
222                    if i == idx {
223                        deserializer.take().unwrap().deserialize(reader, dec)?;
224                    } else {
225                        self.packer.push(Datum::Null)
226                    }
227                }
228            }
229        }
230        Ok(())
231    }
232
233    #[inline]
234    fn enum_variant(self, symbol: &str, _idx: usize) -> Result<Self::Out, AvroError> {
235        self.packer.push(Datum::String(symbol));
236        Ok(())
237    }
238    #[inline]
239    fn scalar(self, scalar: mz_avro::types::Scalar) -> Result<Self::Out, AvroError> {
240        match scalar {
241            mz_avro::types::Scalar::Null => self.packer.push(Datum::Null),
242            mz_avro::types::Scalar::Boolean(val) => {
243                if val {
244                    self.packer.push(Datum::True)
245                } else {
246                    self.packer.push(Datum::False)
247                }
248            }
249            mz_avro::types::Scalar::Int(val) => self.packer.push(Datum::Int32(val)),
250            mz_avro::types::Scalar::Long(val) => self.packer.push(Datum::Int64(val)),
251            mz_avro::types::Scalar::Float(val) => {
252                self.packer.push(Datum::Float32(OrderedFloat(val)))
253            }
254            mz_avro::types::Scalar::Double(val) => {
255                self.packer.push(Datum::Float64(OrderedFloat(val)))
256            }
257            mz_avro::types::Scalar::Date(val) => self.packer.push(Datum::Date(
258                Date::from_unix_epoch(val).map_err(|_| DecodeError::DateOutOfRange(val))?,
259            )),
260            mz_avro::types::Scalar::Timestamp(val) => self.packer.push(Datum::Timestamp(
261                CheckedTimestamp::from_timestamplike(val)
262                    .map_err(|_| DecodeError::TimestampOutOfRange(val))?,
263            )),
264        }
265        Ok(())
266    }
267
268    #[inline]
269    fn decimal<'b, R: AvroRead>(
270        self,
271        _precision: usize,
272        scale: usize,
273        r: ValueOrReader<'b, &'b [u8], R>,
274    ) -> Result<Self::Out, AvroError> {
275        let mut buf = match r {
276            ValueOrReader::Value(val) => val.to_vec(),
277            ValueOrReader::Reader { len, r } => {
278                self.buf.resize_with(len, Default::default);
279                r.read_exact(self.buf)?;
280                let v = self.buf.clone();
281                v
282            }
283        };
284
285        let scale = u8::try_from(scale).map_err(|_| {
286            DecodeError::Custom(format!(
287                "Error decoding decimal: scale must fit within u8, but got scale {}",
288                scale,
289            ))
290        })?;
291
292        let n = numeric::twos_complement_be_to_numeric(&mut buf, scale)
293            .map_err(|e| e.to_string_with_causes())
294            .map_err(DecodeError::Custom)?;
295
296        if n.is_special()
297            || numeric::get_precision(&n) > u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
298        {
299            return Err(AvroError::Decode(DecodeError::Custom(format!(
300                "Error decoding numeric: exceeds maximum precision {}",
301                numeric::NUMERIC_DATUM_MAX_PRECISION
302            ))));
303        }
304
305        self.packer.push(Datum::from(n));
306
307        Ok(())
308    }
309
310    #[inline]
311    fn bytes<'b, R: AvroRead>(
312        self,
313        r: ValueOrReader<'b, &'b [u8], R>,
314    ) -> Result<Self::Out, AvroError> {
315        let buf = match r {
316            ValueOrReader::Value(val) => val,
317            ValueOrReader::Reader { len, r } => {
318                self.buf.resize_with(len, Default::default);
319                r.read_exact(self.buf)?;
320                self.buf
321            }
322        };
323        self.packer.push(Datum::Bytes(buf));
324        Ok(())
325    }
326    #[inline]
327    fn string<'b, R: AvroRead>(
328        self,
329        r: ValueOrReader<'b, &'b str, R>,
330    ) -> Result<Self::Out, AvroError> {
331        let s = match r {
332            ValueOrReader::Value(val) => val,
333            ValueOrReader::Reader { len, r } => {
334                // TODO - this copy is unnecessary,
335                // we should special case to just look at the bytes
336                // directly when r is &[u8].
337                // It probably doesn't make a huge difference though.
338                self.buf.resize_with(len, Default::default);
339                r.read_exact(self.buf)?;
340                std::str::from_utf8(self.buf).map_err(|_| DecodeError::StringUtf8Error)?
341            }
342        };
343        self.packer.push(Datum::String(s));
344        Ok(())
345    }
346    #[inline]
347    fn json<'b, R: AvroRead>(
348        self,
349        r: ValueOrReader<'b, &'b serde_json::Value, R>,
350    ) -> Result<Self::Out, AvroError> {
351        match r {
352            ValueOrReader::Value(val) => {
353                JsonbPacker::new(self.packer)
354                    .pack_serde_json(val.clone())
355                    .map_err(|e| {
356                        // Technically, these are not the original bytes;
357                        // they've gone through a deserialize-serialize
358                        // round trip. Hopefully they will be close enough to still
359                        // be useful for debugging.
360                        let bytes = val.to_string().into_bytes();
361
362                        DecodeError::BadJson {
363                            category: e.classify(),
364                            bytes,
365                        }
366                    })?;
367            }
368            ValueOrReader::Reader { len, r } => {
369                self.buf.resize_with(len, Default::default);
370                r.read_exact(self.buf)?;
371                JsonbPacker::new(self.packer)
372                    .pack_slice(self.buf)
373                    .map_err(|e| DecodeError::BadJson {
374                        category: e.classify(),
375                        bytes: self.buf.to_owned(),
376                    })?;
377            }
378        }
379        Ok(())
380    }
381    #[inline]
382    fn uuid<'b, R: AvroRead>(
383        self,
384        r: ValueOrReader<'b, &'b [u8], R>,
385    ) -> Result<Self::Out, AvroError> {
386        let buf = match r {
387            ValueOrReader::Value(val) => val,
388            ValueOrReader::Reader { len, r } => {
389                self.buf.resize_with(len, Default::default);
390                r.read_exact(self.buf)?;
391                self.buf
392            }
393        };
394        let s = std::str::from_utf8(buf).map_err(|_e| DecodeError::UuidUtf8Error)?;
395        self.packer.push(Datum::Uuid(
396            Uuid::parse_str(s).map_err(DecodeError::BadUuid)?,
397        ));
398        Ok(())
399    }
400    #[inline]
401    fn fixed<'b, R: AvroRead>(
402        self,
403        r: ValueOrReader<'b, &'b [u8], R>,
404    ) -> Result<Self::Out, AvroError> {
405        self.bytes(r)
406    }
407    #[inline]
408    fn array<A: AvroArrayAccess>(mut self, a: &mut A) -> Result<Self::Out, AvroError> {
409        self.is_top = false;
410        let mut str_buf = std::mem::take(self.buf);
411        self.packer.push_list_with(|rp| -> Result<(), AvroError> {
412            loop {
413                let next = AvroFlatDecoder {
414                    packer: rp,
415                    buf: &mut str_buf,
416                    is_top: false,
417                };
418                if a.decode_next(next)?.is_none() {
419                    break;
420                }
421            }
422            Ok(())
423        })?;
424        *self.buf = str_buf;
425        Ok(())
426    }
427    #[inline]
428    fn map<A: AvroMapAccess>(self, a: &mut A) -> Result<Self::Out, AvroError> {
429        // Map (key, value) pairs need to be unique and ordered.
430        let mut map = BTreeMap::new();
431        while let Some((name, f)) = a.next_entry()? {
432            map.insert(name, f.decode_field(ValueDecoder)?);
433        }
434        self.packer
435            .push_dict_with(|packer| -> Result<(), AvroError> {
436                for (key, val) in map {
437                    packer.push(Datum::String(key.as_str()));
438                    give_value(
439                        AvroFlatDecoder {
440                            packer,
441                            buf: &mut vec![],
442                            is_top: false,
443                        },
444                        &val,
445                    )?;
446                }
447                Ok(())
448            })?;
449
450        Ok(())
451    }
452}
453
/// A pair of optional `before` and `after` values.
///
/// NOTE(review): presumably these are the old and new versions of a changed record — confirm
/// against callers, since nothing in this section constructs or consumes it.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DiffPair<T> {
    /// The `before` value, if present.
    pub before: Option<T>,
    /// The `after` value, if present.
    pub after: Option<T>,
}