mz_interchange/avro/
decode.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11
12use anyhow::{Context, Error};
13use mz_avro::error::{DecodeError, Error as AvroError};
14use mz_avro::{
15    AvroArrayAccess, AvroDecode, AvroDeserializer, AvroMapAccess, AvroRead, AvroRecordAccess,
16    GeneralDeserializer, ValueDecoder, ValueOrReader, give_value,
17};
18use mz_ore::error::ErrorExt;
19use mz_repr::adt::date::Date;
20use mz_repr::adt::jsonb::JsonbPacker;
21use mz_repr::adt::numeric;
22use mz_repr::adt::timestamp::CheckedTimestamp;
23use mz_repr::{Datum, Row, RowPacker};
24use ordered_float::OrderedFloat;
25use tracing::trace;
26use uuid::Uuid;
27
28use crate::avro::ConfluentAvroResolver;
29
/// Manages decoding of Avro-encoded bytes.
///
/// A `Decoder` owns the buffers it needs, so one instance can be reused for
/// many calls to [`Decoder::decode`].
#[derive(Debug)]
pub struct Decoder {
    /// Resolver asked, once per message, for the schema to decode with.
    csr_avro: ConfluentAvroResolver,
    /// Name included in the trace message emitted after a successful decode.
    debug_name: String,
    /// Scratch byte buffer lent to the `AvroFlatDecoder` during `decode`.
    buf1: Vec<u8>,
    /// Row that each `decode` call packs into; it is cloned to produce the
    /// returned `Row`.
    row_buf: Row,
}
38
#[cfg(test)]
mod tests {
    use mz_ore::assert_err;
    use mz_repr::{Datum, Row};

    use crate::avro::Decoder;

    /// A decode that fails must not leave state behind that corrupts the next,
    /// successful decode on the same `Decoder`.
    #[mz_ore::test(tokio::test)]
    async fn test_error_followed_by_success() {
        let schema = r#"{
"type": "record",
"name": "test",
"fields": [{"name": "f1", "type": "int"}, {"name": "f2", "type": "int"}]
}"#;
        let mut decoder = Decoder::new(schema, None, "Test".to_string(), false).unwrap();

        // Not a valid Avro blob for the schema above: the outer (transient)
        // result is fine, the inner (decode) result is an error.
        let mut invalid_blob: &[u8] = &[0];
        let first_attempt = decoder.decode(&mut invalid_blob).await.unwrap();
        assert_err!(first_attempt);

        // This blob makes both ints in the value zero, and must decode to
        // exactly that row despite the earlier failure.
        let mut valid_blob: &[u8] = &[0, 0];
        let decoded_row = decoder.decode(&mut valid_blob).await.unwrap().unwrap();
        let expected_row = Row::pack([Datum::Int32(0), Datum::Int32(0)]);
        assert_eq!(decoded_row, expected_row);
    }
}
66
67impl Decoder {
68    /// Creates a new `Decoder`
69    ///
70    /// The provided schema is called the "reader schema", which is the schema
71    /// that we are expecting to use to decode records. The records may indicate
72    /// that they are encoded with a different schema; as long as those.
73    pub fn new(
74        reader_schema: &str,
75        ccsr_client: Option<mz_ccsr::Client>,
76        debug_name: String,
77        confluent_wire_format: bool,
78    ) -> anyhow::Result<Decoder> {
79        let csr_avro =
80            ConfluentAvroResolver::new(reader_schema, ccsr_client, confluent_wire_format)?;
81
82        Ok(Decoder {
83            csr_avro,
84            debug_name,
85            buf1: vec![],
86            row_buf: Row::default(),
87        })
88    }
89
90    /// Decodes Avro-encoded `bytes` into a `Row`.
91    pub async fn decode(&mut self, bytes: &mut &[u8]) -> Result<Result<Row, Error>, Error> {
92        // Clear out any bytes that might be left over from
93        // an earlier run. This can happen if the
94        // `dsr.deserialize` call returns an error,
95        // causing us to return early.
96        let mut packer = self.row_buf.packer();
97        // The outer Result describes transient errors so use ? here to propagate
98        let (bytes2, resolved_schema, csr_schema_id) = match self.csr_avro.resolve(bytes).await? {
99            Ok(ok) => ok,
100            Err(err) => return Ok(Err(err)),
101        };
102        *bytes = bytes2;
103        let dec = AvroFlatDecoder {
104            packer: &mut packer,
105            buf: &mut self.buf1,
106            is_top: true,
107        };
108        let dsr = GeneralDeserializer {
109            schema: resolved_schema.top_node(),
110        };
111        let result = dsr
112            .deserialize(bytes, dec)
113            .with_context(|| {
114                format!(
115                    "unable to decode row {}",
116                    match csr_schema_id {
117                        Some(id) => format!("(Avro schema id = {:?})", id),
118                        None => "".to_string(),
119                    }
120                )
121            })
122            .map(|_| self.row_buf.clone());
123        if result.is_ok() {
124            trace!(
125                "[customer-data] Decoded row {:?} in {}",
126                self.row_buf, self.debug_name
127            );
128        }
129        Ok(result)
130    }
131}
132
/// An `AvroDecode` implementation that pushes decoded values straight into a
/// `RowPacker` as `Datum`s.
#[derive(Debug)]
pub struct AvroFlatDecoder<'a, 'row> {
    /// Destination that every decoded value is pushed into.
    pub packer: &'a mut RowPacker<'row>,
    /// Scratch buffer that raw bytes are read into before being converted to
    /// a datum. Its contents are overwritten by each use.
    pub buf: &'a mut Vec<u8>,
    /// Whether this decoder is handling the top-level value. A top-level
    /// record has its fields pushed directly into `packer`; a record decoded
    /// with `is_top == false` is wrapped in a list instead.
    pub is_top: bool,
}
139
140impl<'a, 'row> AvroDecode for AvroFlatDecoder<'a, 'row> {
141    type Out = ();
142    #[inline]
143    fn record<R: AvroRead, A: AvroRecordAccess<R>>(
144        self,
145        a: &mut A,
146    ) -> Result<Self::Out, AvroError> {
147        let mut str_buf = std::mem::take(self.buf);
148        let mut pack_record = |rp: &mut RowPacker| -> Result<(), AvroError> {
149            let mut expected = 0;
150            let mut stash = vec![];
151            // The idea here is that if the deserializer gives us fields in the order we're expecting,
152            // we can decode them directly into the row.
153            // If not, we need to decode them into a Value (the old, slow decoding path) and stash them,
154            // so that we can put everything in the right order at the end.
155            //
156            // TODO(btv) - this is pretty bad, as a misordering at the top of the schema graph will
157            // cause the _entire_ chunk under it to be decoded in the slow way!
158            // Maybe instead, we should decode to separate sub-Rows and then add an API
159            // to Row that just copies in the bytes from another one.
160            while let Some((_name, idx, f)) = a.next_field()? {
161                if idx == expected {
162                    expected += 1;
163                    f.decode_field(AvroFlatDecoder {
164                        packer: rp,
165                        buf: &mut str_buf,
166                        is_top: false,
167                    })?;
168                } else {
169                    let val = f.decode_field(ValueDecoder)?;
170                    stash.push((idx, val));
171                }
172            }
173            stash.sort_by_key(|(idx, _val)| *idx);
174            for (idx, val) in stash {
175                assert!(idx == expected);
176                expected += 1;
177                let dec = AvroFlatDecoder {
178                    packer: rp,
179                    buf: &mut str_buf,
180                    is_top: false,
181                };
182                give_value(dec, &val)?;
183            }
184            Ok(())
185        };
186        if self.is_top {
187            pack_record(self.packer)?;
188        } else {
189            self.packer.push_list_with(pack_record)?;
190        }
191        *self.buf = str_buf;
192        Ok(())
193    }
194    #[inline]
195    fn union_branch<'b, R: AvroRead, D: AvroDeserializer>(
196        self,
197        idx: usize,
198        n_variants: usize,
199        null_variant: Option<usize>,
200        deserializer: D,
201        reader: &'b mut R,
202    ) -> Result<Self::Out, AvroError> {
203        if null_variant == Some(idx) {
204            for _ in 0..n_variants - 1 {
205                self.packer.push(Datum::Null)
206            }
207        } else {
208            let mut deserializer = Some(deserializer);
209            for i in 0..n_variants {
210                let dec = AvroFlatDecoder {
211                    packer: self.packer,
212                    buf: self.buf,
213                    is_top: false,
214                };
215                if null_variant != Some(i) {
216                    if i == idx {
217                        deserializer.take().unwrap().deserialize(reader, dec)?;
218                    } else {
219                        self.packer.push(Datum::Null)
220                    }
221                }
222            }
223        }
224        Ok(())
225    }
226
227    #[inline]
228    fn enum_variant(self, symbol: &str, _idx: usize) -> Result<Self::Out, AvroError> {
229        self.packer.push(Datum::String(symbol));
230        Ok(())
231    }
232    #[inline]
233    fn scalar(self, scalar: mz_avro::types::Scalar) -> Result<Self::Out, AvroError> {
234        match scalar {
235            mz_avro::types::Scalar::Null => self.packer.push(Datum::Null),
236            mz_avro::types::Scalar::Boolean(val) => {
237                if val {
238                    self.packer.push(Datum::True)
239                } else {
240                    self.packer.push(Datum::False)
241                }
242            }
243            mz_avro::types::Scalar::Int(val) => self.packer.push(Datum::Int32(val)),
244            mz_avro::types::Scalar::Long(val) => self.packer.push(Datum::Int64(val)),
245            mz_avro::types::Scalar::Float(val) => {
246                self.packer.push(Datum::Float32(OrderedFloat(val)))
247            }
248            mz_avro::types::Scalar::Double(val) => {
249                self.packer.push(Datum::Float64(OrderedFloat(val)))
250            }
251            mz_avro::types::Scalar::Date(val) => self.packer.push(Datum::Date(
252                Date::from_unix_epoch(val).map_err(|_| DecodeError::DateOutOfRange(val))?,
253            )),
254            mz_avro::types::Scalar::Timestamp(val) => self.packer.push(Datum::Timestamp(
255                CheckedTimestamp::from_timestamplike(val)
256                    .map_err(|_| DecodeError::TimestampOutOfRange(val))?,
257            )),
258        }
259        Ok(())
260    }
261
262    #[inline]
263    fn decimal<'b, R: AvroRead>(
264        self,
265        _precision: usize,
266        scale: usize,
267        r: ValueOrReader<'b, &'b [u8], R>,
268    ) -> Result<Self::Out, AvroError> {
269        let mut buf = match r {
270            ValueOrReader::Value(val) => val.to_vec(),
271            ValueOrReader::Reader { len, r } => {
272                self.buf.resize_with(len, Default::default);
273                r.read_exact(self.buf)?;
274                let v = self.buf.clone();
275                v
276            }
277        };
278
279        let scale = u8::try_from(scale).map_err(|_| {
280            DecodeError::Custom(format!(
281                "Error decoding decimal: scale must fit within u8, but got scale {}",
282                scale,
283            ))
284        })?;
285
286        let n = numeric::twos_complement_be_to_numeric(&mut buf, scale)
287            .map_err(|e| e.to_string_with_causes())
288            .map_err(DecodeError::Custom)?;
289
290        if n.is_special()
291            || numeric::get_precision(&n) > u32::from(numeric::NUMERIC_DATUM_MAX_PRECISION)
292        {
293            return Err(AvroError::Decode(DecodeError::Custom(format!(
294                "Error decoding numeric: exceeds maximum precision {}",
295                numeric::NUMERIC_DATUM_MAX_PRECISION
296            ))));
297        }
298
299        self.packer.push(Datum::from(n));
300
301        Ok(())
302    }
303
304    #[inline]
305    fn bytes<'b, R: AvroRead>(
306        self,
307        r: ValueOrReader<'b, &'b [u8], R>,
308    ) -> Result<Self::Out, AvroError> {
309        let buf = match r {
310            ValueOrReader::Value(val) => val,
311            ValueOrReader::Reader { len, r } => {
312                self.buf.resize_with(len, Default::default);
313                r.read_exact(self.buf)?;
314                self.buf
315            }
316        };
317        self.packer.push(Datum::Bytes(buf));
318        Ok(())
319    }
320    #[inline]
321    fn string<'b, R: AvroRead>(
322        self,
323        r: ValueOrReader<'b, &'b str, R>,
324    ) -> Result<Self::Out, AvroError> {
325        let s = match r {
326            ValueOrReader::Value(val) => val,
327            ValueOrReader::Reader { len, r } => {
328                // TODO - this copy is unnecessary,
329                // we should special case to just look at the bytes
330                // directly when r is &[u8].
331                // It probably doesn't make a huge difference though.
332                self.buf.resize_with(len, Default::default);
333                r.read_exact(self.buf)?;
334                std::str::from_utf8(self.buf).map_err(|_| DecodeError::StringUtf8Error)?
335            }
336        };
337        self.packer.push(Datum::String(s));
338        Ok(())
339    }
340    #[inline]
341    fn json<'b, R: AvroRead>(
342        self,
343        r: ValueOrReader<'b, &'b serde_json::Value, R>,
344    ) -> Result<Self::Out, AvroError> {
345        match r {
346            ValueOrReader::Value(val) => {
347                JsonbPacker::new(self.packer)
348                    .pack_serde_json(val.clone())
349                    .map_err(|e| {
350                        // Technically, these are not the original bytes;
351                        // they've gone through a deserialize-serialize
352                        // round trip. Hopefully they will be close enough to still
353                        // be useful for debugging.
354                        let bytes = val.to_string().into_bytes();
355
356                        DecodeError::BadJson {
357                            category: e.classify(),
358                            bytes,
359                        }
360                    })?;
361            }
362            ValueOrReader::Reader { len, r } => {
363                self.buf.resize_with(len, Default::default);
364                r.read_exact(self.buf)?;
365                JsonbPacker::new(self.packer)
366                    .pack_slice(self.buf)
367                    .map_err(|e| DecodeError::BadJson {
368                        category: e.classify(),
369                        bytes: self.buf.to_owned(),
370                    })?;
371            }
372        }
373        Ok(())
374    }
375    #[inline]
376    fn uuid<'b, R: AvroRead>(
377        self,
378        r: ValueOrReader<'b, &'b [u8], R>,
379    ) -> Result<Self::Out, AvroError> {
380        let buf = match r {
381            ValueOrReader::Value(val) => val,
382            ValueOrReader::Reader { len, r } => {
383                self.buf.resize_with(len, Default::default);
384                r.read_exact(self.buf)?;
385                self.buf
386            }
387        };
388        let s = std::str::from_utf8(buf).map_err(|_e| DecodeError::UuidUtf8Error)?;
389        self.packer.push(Datum::Uuid(
390            Uuid::parse_str(s).map_err(DecodeError::BadUuid)?,
391        ));
392        Ok(())
393    }
394    #[inline]
395    fn fixed<'b, R: AvroRead>(
396        self,
397        r: ValueOrReader<'b, &'b [u8], R>,
398    ) -> Result<Self::Out, AvroError> {
399        self.bytes(r)
400    }
401    #[inline]
402    fn array<A: AvroArrayAccess>(mut self, a: &mut A) -> Result<Self::Out, AvroError> {
403        self.is_top = false;
404        let mut str_buf = std::mem::take(self.buf);
405        self.packer.push_list_with(|rp| -> Result<(), AvroError> {
406            loop {
407                let next = AvroFlatDecoder {
408                    packer: rp,
409                    buf: &mut str_buf,
410                    is_top: false,
411                };
412                if a.decode_next(next)?.is_none() {
413                    break;
414                }
415            }
416            Ok(())
417        })?;
418        *self.buf = str_buf;
419        Ok(())
420    }
421    #[inline]
422    fn map<A: AvroMapAccess>(self, a: &mut A) -> Result<Self::Out, AvroError> {
423        // Map (key, value) pairs need to be unique and ordered.
424        let mut map = BTreeMap::new();
425        while let Some((name, f)) = a.next_entry()? {
426            map.insert(name, f.decode_field(ValueDecoder)?);
427        }
428        self.packer
429            .push_dict_with(|packer| -> Result<(), AvroError> {
430                for (key, val) in map {
431                    packer.push(Datum::String(key.as_str()));
432                    give_value(
433                        AvroFlatDecoder {
434                            packer,
435                            buf: &mut vec![],
436                            is_top: false,
437                        },
438                        &val,
439                    )?;
440                }
441                Ok(())
442            })?;
443
444        Ok(())
445    }
446}
447
/// A pair of optional values of the same type, labelled `before` and `after`.
///
/// NOTE(review): presumably the previous and new versions of a changed
/// record, with `None` standing for "absent" — no use is visible in this
/// file, so confirm against callers.
#[derive(Clone, Debug)]
pub struct DiffPair<T> {
    /// The "before" value, if any.
    pub before: Option<T>,
    /// The "after" value, if any.
    pub after: Option<T>,
}